In [1]:
%load_ext autoreload
%autoreload 2

import torch 

from vlstm_nogatematrices import vlstm_fw_prepare_gate_preacts, vlstm_fw_nogatematrices_nostabilization, vlstm_fwbw_nogatematrices_nostabilization
from vlstm_full import vlstm_fw_torch



# vLSTM forward backward implementation

In [2]:
# Global tensor configuration shared by every cell below.
DTYPE = torch.float32
DEVICE = torch.device("cuda:0")

# Fix the RNG so that the random gate pre-activations and q/k/v tensors --
# and therefore every gradient comparison further down -- are identical
# across "Restart Kernel -> Run All" executions.
SEED = 0
torch.manual_seed(SEED);  # trailing ';' hides the returned Generator repr

In [3]:
# Problem dimensions for this small debugging example. The q/k/v tensors
# below are allocated with shape (B, NH, S, DH) and the gate pre-activations
# with shape (B, NH, S, 1).
# NOTE(review): presumably B = batch size, NH = number of heads,
# S = sequence length, DH = head dimension -- confirm against the vlstm modules.
B, NH = 1, 1
S, DH = 5, 6

# Forwarded unchanged as the `eps` argument of the vlstm calls below.
EPS = 0.0

In [4]:
# Keyword arguments shared by every allocation so all tensors agree on
# dtype and device.
tensor_kwargs = {"dtype": DTYPE, "device": DEVICE}

# Random gate pre-activations of shape (B, NH, S, 1). The input gate is drawn
# before the forget gate so the RNG stream is consumed in the same order.
igate_preacts = torch.randn((B, NH, S, 1), **tensor_kwargs)
fgate_preacts = torch.randn((B, NH, S, 1), **tensor_kwargs)

# Zero-initialised scratch tensors that are handed to the vlstm functions:
# three of shape (B, NH, S, S) and one of shape (B, NH, S, 1).
temp_Ctilde, temp_D, temp_QK = (
    torch.zeros((B, NH, S, S), **tensor_kwargs) for _ in range(3)
)
temp_N = torch.zeros((B, NH, S, 1), **tensor_kwargs)

In [5]:
# Turn the per-step gate pre-activations into gate matrices.
# Mind the ordering: the arguments are passed as (igate, fgate) while the
# returned pair is unpacked as (fgate_mat, igate_mat).
fgate_mat, igate_mat = vlstm_fw_prepare_gate_preacts(igate_preacts, fgate_preacts)
# Sanity check: display both shapes (the recorded run shows (1, 1, 5, 5) each).
igate_mat.shape, fgate_mat.shape

(torch.Size([1, 1, 5, 5]), torch.Size([1, 1, 5, 5]))

In [6]:
# Inspect the forget-gate matrix. In the recorded output it is lower-triangular
# with zeros on the diagonal and -inf above it.
fgate_mat

tensor([[[[ 0.0000,    -inf,    -inf,    -inf,    -inf],
          [-0.5457,  0.0000,    -inf,    -inf,    -inf],
          [-1.2204, -0.6747,  0.0000,    -inf,    -inf],
          [-1.3282, -0.7825, -0.1078,  0.0000,    -inf],
          [-2.1286, -1.5828, -0.9082, -0.8003,  0.0000]]]], device='cuda:0')

In [7]:
# Inspect the input-gate matrix. In the recorded output each column holds a
# single repeated value on and below the diagonal, with -inf above it.
igate_mat

tensor([[[[ 2.3054,    -inf,    -inf,    -inf,    -inf],
          [ 2.3054,  0.0847,    -inf,    -inf,    -inf],
          [ 2.3054,  0.0847, -1.0159,    -inf,    -inf],
          [ 2.3054,  0.0847, -1.0159, -1.3262,    -inf],
          [ 2.3054,  0.0847, -1.0159, -1.3262,  0.7103]]]], device='cuda:0')

In [8]:
# Random query / key / value tensors, all of shape (B, NH, S, DH).
# They are drawn in q, k, v order so the RNG stream is consumed as before.
qkv_shape = (B, NH, S, DH)
qs, ks, vs = (
    torch.randn(qkv_shape, dtype=DTYPE, device=DEVICE) for _ in range(3)
)
# Display one shape as a sanity check.
vs.shape

torch.Size([1, 1, 5, 6])

## Backward NOT stabilized without input & forget gate

### Torch Autograd

In [9]:
# Independent leaf copies of every input for the autograd reference run:
# clone() gives each tensor its own storage, detach() cuts it from any
# existing graph, and requires_grad_(True) makes autograd accumulate a
# .grad on it during backward().
(
    fgate_mat_pt,
    igate_mat_pt,
    qs_pt,
    ks_pt,
    vs_pt,
    temp_Ctilde_pt,
    temp_D_pt,
    temp_QK_pt,
    temp_N_pt,
) = (
    source.clone().detach().requires_grad_(True)
    for source in (
        fgate_mat,
        igate_mat,
        qs,
        ks,
        vs,
        temp_Ctilde,
        temp_D,
        temp_QK,
        temp_N,
    )
)

In [10]:
# Display the leaf copies of both gate matrices to confirm they hold the same
# values as the originals and now report requires_grad=True.
igate_mat_pt, fgate_mat_pt

(tensor([[[[ 2.3054,    -inf,    -inf,    -inf,    -inf],
           [ 2.3054,  0.0847,    -inf,    -inf,    -inf],
           [ 2.3054,  0.0847, -1.0159,    -inf,    -inf],
           [ 2.3054,  0.0847, -1.0159, -1.3262,    -inf],
           [ 2.3054,  0.0847, -1.0159, -1.3262,  0.7103]]]], device='cuda:0',
        requires_grad=True),
 tensor([[[[ 0.0000,    -inf,    -inf,    -inf,    -inf],
           [-0.5457,  0.0000,    -inf,    -inf,    -inf],
           [-1.2204, -0.6747,  0.0000,    -inf,    -inf],
           [-1.3282, -0.7825, -0.1078,  0.0000,    -inf],
           [-2.1286, -1.5828, -0.9082, -0.8003,  0.0000]]]], device='cuda:0',
        requires_grad=True))

In [11]:
# Reference forward pass on the *_pt leaf tensors; gradients for this variant
# come from torch autograd (see the "Torch Autograd" heading).
# Arguments are positional here: q, k, v, the input-gate matrix, the
# forget-gate matrix, then the four temp_* scratch tensors.
# NOTE(review): the temp_* tensors look like hooks for reading gradients of
# intermediate quantities -- confirm in vlstm_nogatematrices.
retr_val_pt = vlstm_fw_nogatematrices_nostabilization(
    qs_pt, ks_pt, vs_pt, igate_mat_pt, fgate_mat_pt, 
    temp_Ctilde_pt,
    temp_D_pt, 
    temp_QK_pt,
    temp_N_pt,
    eps=EPS
)
# Display the output shape as a sanity check.
retr_val_pt.shape

torch.Size([1, 1, 5, 6])

In [12]:
# Reduce the output to a scalar with sum() and backpropagate, so that the
# *_pt leaf tensors involved in the computation receive a .grad.
retr_val_pt.sum().backward()

In [13]:
# Autograd reference: gradient accumulated on the temp_Ctilde scratch tensor.
temp_Ctilde_pt.grad

tensor([[[[ 0.0000, -0.0587,  0.2899, -0.1118,  0.1739],
          [ 0.0124, -0.0611,  0.3752, -0.1276,  0.2300],
          [-0.0260, -0.2134,  0.8997, -0.3831,  0.5293],
          [ 0.0133, -0.1824,  0.9803, -0.3597,  0.5934],
          [-0.0992, -0.3727,  1.2520, -0.6205,  0.7113]]]], device='cuda:0')

In [14]:
# Autograd reference: gradient w.r.t. the queries.
qs_pt.grad

tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [-0.0016, -0.0482,  0.0024, -0.0603,  0.0300,  0.0020],
          [-0.0495, -0.1514, -0.1589, -0.1463, -0.0703, -0.1679],
          [-0.0568, -0.0998, -0.1180, -0.1954, -0.0650, -0.1147],
          [-0.1934, -0.0472, -1.1054, -0.4855,  0.3414, -0.2261]]]],
       device='cuda:0')

In [15]:
# Autograd reference: gradient w.r.t. the keys.
ks_pt.grad

tensor([[[[ 6.4270e-03,  1.5413e-02,  3.8854e-03,  1.6204e-02, -3.2575e-02,
            7.6880e-03],
          [ 1.9567e-02,  1.0672e-01,  1.0744e-02, -5.8428e-02,  3.0784e-02,
            6.6472e-02],
          [-1.1618e-01, -2.1264e-01, -1.9881e-04,  1.2807e-01, -8.7206e-02,
           -1.0195e-01],
          [ 2.4762e-02,  1.4069e-01, -1.0870e-02, -7.7597e-02,  3.6510e-03,
            3.2500e-02],
          [ 7.0642e-02, -1.4443e+00,  2.1128e-03,  5.6555e-01,  3.5326e-01,
           -4.5008e-01]]]], device='cuda:0')

In [16]:
# Autograd reference: gradient w.r.t. the values. In the recorded output each
# row holds one repeated number across the last dimension.
vs_pt.grad

tensor([[[[-2.6701, -2.6701, -2.6701, -2.6701, -2.6701, -2.6701],
          [-0.3439, -0.3439, -0.3439, -0.3439, -0.3439, -0.3439],
          [ 0.0902,  0.0902,  0.0902,  0.0902,  0.0902,  0.0902],
          [ 0.1132,  0.1132,  0.1132,  0.1132,  0.1132,  0.1132],
          [-0.1894, -0.1894, -0.1894, -0.1894, -0.1894, -0.1894]]]],
       device='cuda:0')

In [17]:
# Autograd reference: gradient w.r.t. the forget-gate matrix.
fgate_mat_pt.grad

tensor([[[[-0.0000,  0.0000,  0.0000, -0.0000,  0.0000],
          [-0.0601,  0.0601,  0.0000, -0.0000, -0.0000],
          [-0.0391, -0.1287,  0.1678,  0.0000,  0.0000],
          [-0.0265,  0.0626,  0.0098, -0.0458, -0.0000],
          [ 0.0916,  0.1653,  0.0085, -0.0538, -0.2117]]]], device='cuda:0')

In [18]:
# Autograd reference: gradient w.r.t. the input-gate matrix. In the recorded
# output it is identical to the forget-gate matrix gradient.
igate_mat_pt.grad

tensor([[[[-0.0000,  0.0000,  0.0000, -0.0000,  0.0000],
          [-0.0601,  0.0601,  0.0000, -0.0000, -0.0000],
          [-0.0391, -0.1287,  0.1678,  0.0000,  0.0000],
          [-0.0265,  0.0626,  0.0098, -0.0458, -0.0000],
          [ 0.0916,  0.1653,  0.0085, -0.0538, -0.2117]]]], device='cuda:0')

### Own backward

In [19]:
# Second, independent set of leaf copies for the "own backward" run, so its
# gradients accumulate separately from the autograd reference tensors.
(
    fgate_mat_obw,
    igate_mat_obw,
    qs_obw,
    ks_obw,
    vs_obw,
    temp_Ctilde_obw,
    temp_D_obw,
    temp_QK_obw,
    temp_N_obw,
) = (
    source.clone().detach().requires_grad_(True)
    for source in (
        fgate_mat,
        igate_mat,
        qs,
        ks,
        vs,
        temp_Ctilde,
        temp_D,
        temp_QK,
        temp_N,
    )
)

In [20]:
# Same call as the autograd reference, but through the fwbw variant and on the
# *_obw leaf tensors. Per the "Own backward" heading this is the variant whose
# backward pass is hand-written.
# NOTE(review): presumably implemented as a custom autograd Function -- confirm
# in vlstm_nogatematrices.
retr_val_obw = vlstm_fwbw_nogatematrices_nostabilization(
    qs_obw, ks_obw, vs_obw, igate_mat_obw, fgate_mat_obw, 
    temp_Ctilde_obw,
    temp_D_obw, 
    temp_QK_obw,
    temp_N_obw,
    eps=EPS
)
# Display the output shape as a sanity check.
retr_val_obw.shape

torch.Size([1, 1, 5, 6])

In [21]:
# Same scalar reduction and backward call as for the autograd reference, so
# the .grad values of the *_obw and *_pt tensors are directly comparable.
retr_val_obw.sum().backward()

In [22]:
# (own backward) - (autograd) for temp_N. The recorded differences are at
# most ~6e-8.
temp_N_obw.grad - temp_N_pt.grad

tensor([[[[ 1.4901e-08],
          [-5.9605e-08],
          [ 5.9605e-08],
          [ 0.0000e+00],
          [ 0.0000e+00]]]], device='cuda:0')

In [23]:
# (own backward) - (autograd) for temp_QK. The recorded output shows large
# nonzero entries on and below the diagonal, i.e. a mismatch.
temp_QK_obw.grad - temp_QK_pt.grad

tensor([[[[-6.5122,  0.0000, -0.0000,  0.0000, -0.0000],
          [-3.4286, -0.8048, -0.0000,  0.0000, -0.0000],
          [-0.3922, -0.0931, -1.0860,  0.0000, -0.0000],
          [-0.2629, -0.0627, -0.9775, -0.4760, -0.0000],
          [ 0.2547,  0.0590, -0.4495, -0.1589, -7.8078]]]], device='cuda:0')

In [31]:
# (own backward) - (autograd) for temp_Ctilde. The recorded output is nonzero
# in every entry, i.e. a mismatch.
temp_Ctilde_obw.grad - temp_Ctilde_pt.grad

tensor([[[[-0.6494, -0.8135, -2.8804, -2.6048, -4.0636],
          [-0.5901, -0.7394, -2.8941, -2.5173, -4.0480],
          [-0.1325, -0.1679, -2.9994, -1.8426, -3.9281],
          [-0.0989, -0.1260, -3.0071, -1.7931, -3.9193],
          [ 0.2134,  0.2641, -3.0790, -1.3325, -3.8374]]]], device='cuda:0')

In [24]:
# (own backward) - (autograd) for the query gradients; nonzero in the
# recorded output, i.e. a mismatch.
qs_obw.grad-qs_pt.grad

tensor([[[[ 5.5456e-01,  3.1493e-01, -9.2367e-02,  2.2284e+00,  3.1072e-02,
           -8.0010e-01],
          [ 3.5906e-01, -1.8177e-03, -5.0489e-02,  1.2086e+00,  1.5092e-01,
           -5.1932e-01],
          [ 1.9192e-01,  2.0818e-01,  3.0078e-01,  5.0736e-01,  2.3261e-01,
            1.8953e-01],
          [ 2.1064e-01,  2.6126e-01,  3.2550e-01,  5.4577e-01,  2.3976e-01,
            1.8080e-01],
          [ 5.9204e-01,  1.0595e-01,  3.0582e+00,  1.4426e+00, -9.2328e-01,
            4.8500e-01]]]], device='cuda:0')

In [25]:
# (own backward) - (autograd) for the key gradients; nonzero in the recorded
# output, i.e. a mismatch.
ks_obw.grad-ks_pt.grad

tensor([[[[-0.7835,  2.4752,  1.8734, -2.6522,  1.0793,  2.9452],
          [-0.1439,  0.2936,  0.0989, -0.1966,  0.1537,  0.3144],
          [ 0.3699,  0.4899,  0.0105, -0.3212,  0.3104,  0.2817],
          [ 0.1337,  0.4918, -0.0539, -0.3048,  0.0694,  0.0967],
          [-0.3811,  7.7917, -0.0114, -3.0511, -1.9058,  2.4281]]]],
       device='cuda:0')

In [26]:
# (own backward) - (autograd) for the value gradients; exactly zero in the
# recorded output, so this gradient agrees.
vs_obw.grad-vs_pt.grad

tensor([[[[0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.]]]], device='cuda:0')

In [27]:
# (own backward) - (autograd) for the forget-gate matrix gradients; the
# recorded output is nonzero on and below the diagonal, i.e. a mismatch.
fgate_mat_obw.grad-fgate_mat_pt.grad

tensor([[[[ 4.7555,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 2.8712,  0.7280, -0.0000,  0.0000,  0.0000],
          [-0.1993, -0.1013, -0.5592,  0.0000, -0.0000],
          [ 0.1969,  0.0432, -0.0299, -0.2284,  0.0000],
          [-0.1971, -0.1172, -0.0210, -0.1156,  1.1420]]]], device='cuda:0')

### Do gradients match? 

In [28]:
# Query-gradient difference with the operands swapped, (autograd) - (own
# backward), so the sign is flipped relative to the earlier qs comparison.
qs_pt.grad - qs_obw.grad

tensor([[[[-5.5456e-01, -3.1493e-01,  9.2367e-02, -2.2284e+00, -3.1072e-02,
            8.0010e-01],
          [-3.5906e-01,  1.8177e-03,  5.0489e-02, -1.2086e+00, -1.5092e-01,
            5.1932e-01],
          [-1.9192e-01, -2.0818e-01, -3.0078e-01, -5.0736e-01, -2.3261e-01,
           -1.8953e-01],
          [-2.1064e-01, -2.6126e-01, -3.2550e-01, -5.4577e-01, -2.3976e-01,
           -1.8080e-01],
          [-5.9204e-01, -1.0595e-01, -3.0582e+00, -1.4426e+00,  9.2328e-01,
           -4.8500e-01]]]], device='cuda:0')

In [29]:
# Tolerances shared by every comparison in this cell.
atol = 1e-6
rtol = 1e-6
# Summary of autograd reference (*_pt) versus own backward (*_obw).
# The forward check uses the same atol/rtol as the gradient checks; it
# previously fell back to torch.allclose's defaults, so the forward and
# gradient results were judged against different thresholds.
print(f"Forward match: {torch.allclose(retr_val_pt, retr_val_obw, atol=atol, rtol=rtol)}")
print(f"qs match: {torch.allclose(qs_pt.grad, qs_obw.grad, atol=atol, rtol=rtol)}")
print(f"ks match: {torch.allclose(ks_pt.grad, ks_obw.grad, atol=atol, rtol=rtol)}")
print(f"vs match: {torch.allclose(vs_pt.grad, vs_obw.grad, atol=atol, rtol=rtol)}")
print(f"fgate_mat match: {torch.allclose(fgate_mat_pt.grad, fgate_mat_obw.grad, atol=atol, rtol=rtol)}")
print(f"igate_mat match: {torch.allclose(igate_mat_pt.grad, igate_mat_obw.grad, atol=atol, rtol=rtol)}")

Forward match: True
qs match: False
ks match: False
vs match: True
fgate_mat match: False
igate_mat match: False


### DEBUG

## Forward without input & forget gate

In [32]:
# Forward-only run on the original (non-grad) tensors, using keyword
# arguments. Note that the gate *matrices* are passed through the
# `igate_preact` / `fgate_preact` parameters of this variant.
retr_vals = vlstm_fw_nogatematrices_nostabilization(
    queries=qs,
    keys=ks,
    values=vs,
    igate_preact=igate_mat,
    fgate_preact=fgate_mat,
    temp_Ctilde=temp_Ctilde,
    temp_D=temp_D,
    temp_QK=temp_QK,
    temp_N=temp_N,
    eps=EPS,
)
# Display the output shape as a sanity check.
retr_vals.shape

torch.Size([1, 1, 5, 6])

In [35]:
# Same inputs as the forward-only call, routed through the fwbw variant so
# its forward output can be compared against the full implementation too.
retr_vals_fwbw = vlstm_fwbw_nogatematrices_nostabilization(
    queries=qs,
    keys=ks,
    values=vs,
    igate_preact=igate_mat,
    fgate_preact=fgate_mat,
    temp_Ctilde=temp_Ctilde,
    temp_D=temp_D,
    temp_QK=temp_QK,
    temp_N=temp_N,
    eps=EPS,
)
# Display the output shape as a sanity check.
retr_vals_fwbw.shape

torch.Size([1, 1, 5, 6])

### Check if it equals the full version:

In [36]:
# Check whether the no-gate-matrices variants equal the full version.
# Unlike the calls above, the full implementation receives the raw gate
# pre-activations (igate_preacts / fgate_preacts), not the gate matrices.
retr_vals_full = vlstm_fw_torch(queries=qs, keys=ks, values=vs, igate_preact=igate_preacts, fgate_preact=fgate_preacts)
# Display the output shape as a sanity check.
retr_vals_full.shape

torch.Size([1, 1, 5, 6])

In [37]:
# The implementations match: forward-only variant minus full version.
# The recorded differences are on the order of 1e-6 or smaller.
retr_vals - retr_vals_full

tensor([[[[ 8.9407e-08, -8.3447e-07,  9.5367e-07,  1.2517e-06,  1.4901e-07,
            7.7486e-07],
          [ 3.2783e-07, -5.3644e-07,  5.3644e-07,  8.9407e-07, -1.3411e-07,
            5.3644e-07],
          [-7.1526e-07,  6.5565e-07, -7.1526e-07, -1.1325e-06,  5.9605e-07,
           -7.7486e-07],
          [ 2.9802e-07, -7.1526e-07,  7.1526e-07,  1.4305e-06, -2.5332e-07,
            6.5565e-07],
          [ 4.4703e-07, -6.8545e-07,  8.9407e-07,  1.6689e-06, -5.9605e-07,
            3.8743e-07]]]], device='cuda:0')

In [38]:
# fwbw variant minus full version. The recorded output is identical to the
# forward-only comparison, so both variants produce the same forward result.
retr_vals_fwbw - retr_vals_full

tensor([[[[ 8.9407e-08, -8.3447e-07,  9.5367e-07,  1.2517e-06,  1.4901e-07,
            7.7486e-07],
          [ 3.2783e-07, -5.3644e-07,  5.3644e-07,  8.9407e-07, -1.3411e-07,
            5.3644e-07],
          [-7.1526e-07,  6.5565e-07, -7.1526e-07, -1.1325e-06,  5.9605e-07,
           -7.7486e-07],
          [ 2.9802e-07, -7.1526e-07,  7.1526e-07,  1.4305e-06, -2.5332e-07,
            6.5565e-07],
          [ 4.4703e-07, -6.8545e-07,  8.9407e-07,  1.6689e-06, -5.9605e-07,
            3.8743e-07]]]], device='cuda:0')