In [1]:
%load_ext autoreload
%autoreload 2

import torch 

from vlstm_nogatematrices import vlstm_fw_prepare_gate_preacts, vlstm_fw_nogatematrices_nostabilization, vlstm_fwbw_nogatematrices_nostabilization
from vlstm_full import vlstm_fw_torch



# vLSTM forward backward implementation

In [2]:
DTYPE = torch.float32 
DEVICE = torch.device("cuda:0")

In [3]:
B = 1
S = 5
NH = 1
DH = 6
EPS = 0.0

In [4]:
igate_preacts = torch.randn((B, NH, S, 1), dtype=DTYPE, device=DEVICE)
fgate_preacts = torch.randn((B, NH, S, 1), dtype=DTYPE, device=DEVICE)
temp_Ctilde = torch.zeros((B, NH, S, S), dtype=DTYPE, device=DEVICE)
temp_D = torch.zeros((B, NH, S, S), dtype=DTYPE, device=DEVICE)
temp_QK = torch.zeros((B, NH, S, S), dtype=DTYPE, device=DEVICE)
temp_N = torch.zeros((B, NH, S, 1), dtype=DTYPE, device=DEVICE)
temp_B = torch.zeros((B, NH, S, 1), dtype=DTYPE, device=DEVICE)

In [5]:
fgate_mat, igate_mat = vlstm_fw_prepare_gate_preacts(igate_preacts, fgate_preacts)
igate_mat.shape, fgate_mat.shape

(torch.Size([1, 1, 5, 5]), torch.Size([1, 1, 5, 5]))

In [6]:
fgate_mat

tensor([[[[ 0.0000,    -inf,    -inf,    -inf,    -inf],
          [-1.3068,  0.0000,    -inf,    -inf,    -inf],
          [-1.6387, -0.3319,  0.0000,    -inf,    -inf],
          [-2.7734, -1.4667, -1.1348,  0.0000,    -inf],
          [-3.8335, -2.5267, -2.1948, -1.0600,  0.0000]]]], device='cuda:0')

In [7]:
igate_mat

tensor([[[[ 0.0681,    -inf,    -inf,    -inf,    -inf],
          [ 0.0681,  1.1541,    -inf,    -inf,    -inf],
          [ 0.0681,  1.1541, -0.2651,    -inf,    -inf],
          [ 0.0681,  1.1541, -0.2651, -0.3021,    -inf],
          [ 0.0681,  1.1541, -0.2651, -0.3021,  0.8409]]]], device='cuda:0')

In [8]:
qs = torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE)
ks = torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE)
vs = torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE)
vs.shape

torch.Size([1, 1, 5, 6])

## Backward NOT stabilized without input & forget gate

### Torch Autograd

In [9]:
fgate_mat_pt = fgate_mat.clone().detach().requires_grad_(True)
igate_mat_pt = igate_mat.clone().detach().requires_grad_(True)
qs_pt = qs.clone().detach().requires_grad_(True)
ks_pt = ks.clone().detach().requires_grad_(True)
vs_pt = vs.clone().detach().requires_grad_(True)
temp_Ctilde_pt = temp_Ctilde.clone().detach().requires_grad_(True)
temp_D_pt = temp_D.clone().detach().requires_grad_(True)
temp_QK_pt = temp_QK.clone().detach().requires_grad_(True)
temp_N_pt = temp_N.clone().detach().requires_grad_(True)
temp_B_pt = temp_B.clone().detach().requires_grad_(True)

In [10]:
igate_mat_pt, fgate_mat_pt

(tensor([[[[ 0.0681,    -inf,    -inf,    -inf,    -inf],
           [ 0.0681,  1.1541,    -inf,    -inf,    -inf],
           [ 0.0681,  1.1541, -0.2651,    -inf,    -inf],
           [ 0.0681,  1.1541, -0.2651, -0.3021,    -inf],
           [ 0.0681,  1.1541, -0.2651, -0.3021,  0.8409]]]], device='cuda:0',
        requires_grad=True),
 tensor([[[[ 0.0000,    -inf,    -inf,    -inf,    -inf],
           [-1.3068,  0.0000,    -inf,    -inf,    -inf],
           [-1.6387, -0.3319,  0.0000,    -inf,    -inf],
           [-2.7734, -1.4667, -1.1348,  0.0000,    -inf],
           [-3.8335, -2.5267, -2.1948, -1.0600,  0.0000]]]], device='cuda:0',
        requires_grad=True))

In [11]:
retr_val_pt = vlstm_fw_nogatematrices_nostabilization(
    qs_pt, ks_pt, vs_pt, igate_mat_pt, fgate_mat_pt, 
    temp_Ctilde_pt,
    temp_D_pt, 
    temp_QK_pt,
    temp_N_pt,
    temp_B_pt,
    eps=EPS
)
retr_val_pt.shape

torch.Size([1, 1, 5, 6])

In [12]:
retr_val_pt.sum().backward()

In [13]:
temp_Ctilde_pt.grad

tensor([[[[-0.1808,  7.5565,  3.0299, -2.6199, -1.2011],
          [-4.1559, -0.2817, -2.5482, -5.3772, -4.6668],
          [-3.0533,  1.2428, -1.2705, -4.4075, -3.6198],
          [-0.1808,  7.5565,  3.0299, -2.6199, -1.2011],
          [-0.1808,  7.5565,  3.0299, -2.6199, -1.2011]]]], device='cuda:0')

In [14]:
qs_pt.grad

tensor([[[[ 0.0173,  0.0353, -0.0691,  0.0335, -0.0069,  0.0910],
          [-0.0636,  0.3762, -0.4657,  0.0243, -0.0847,  0.8219],
          [ 0.7051, -0.2831, -0.4671,  0.7171,  0.2757, -0.0789],
          [ 0.5619, -1.1695,  1.5253,  1.1713, -0.1283, -3.8332],
          [ 1.3477, -2.1956,  0.5541,  0.2074, -0.5199, -1.2573]]]],
       device='cuda:0')

In [15]:
ks_pt.grad

tensor([[[[ 1.0093, -0.0430, -0.3616,  0.6074,  0.4815,  0.3456],
          [-2.3925, -2.7342,  1.2138,  8.0466,  3.8282, -3.8922],
          [-0.0234,  0.2902,  0.3225,  1.0776,  0.4399,  0.2744],
          [ 0.7681,  0.4138, -0.6249, -2.7095, -1.1854,  0.7790],
          [ 0.0232,  0.0758, -0.8248, -0.7961,  0.4275,  1.3741]]]],
       device='cuda:0')

In [16]:
vs_pt.grad

tensor([[[[-0.0248, -0.0248, -0.0248, -0.0248, -0.0248, -0.0248],
          [ 0.2352,  0.2352,  0.2352,  0.2352,  0.2352,  0.2352],
          [ 0.4537,  0.4537,  0.4537,  0.4537,  0.4537,  0.4537],
          [-1.1271, -1.1271, -1.1271, -1.1271, -1.1271, -1.1271],
          [-0.0572, -0.0572, -0.0572, -0.0572, -0.0572, -0.0572]]]],
       device='cuda:0')

In [17]:
fgate_mat_pt.grad

tensor([[[[ 4.0874e-02, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
          [-6.0341e-01,  6.0341e-01, -0.0000e+00,  0.0000e+00, -0.0000e+00],
          [-5.6450e-01,  1.2945e+00, -7.2997e-01,  0.0000e+00,  0.0000e+00],
          [-1.7167e-03,  4.5362e+00,  1.8652e-01,  1.8357e+00, -0.0000e+00],
          [-2.9584e-03,  9.7654e-01,  2.2154e-01,  1.1172e+00,  6.8687e-02]]]],
       device='cuda:0')

In [18]:
igate_mat_pt.grad

tensor([[[[ 4.0874e-02, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
          [-6.0341e-01,  6.0341e-01, -0.0000e+00,  0.0000e+00, -0.0000e+00],
          [-5.6450e-01,  1.2945e+00, -7.2997e-01,  0.0000e+00,  0.0000e+00],
          [-1.7167e-03,  4.5362e+00,  1.8652e-01,  1.8357e+00, -0.0000e+00],
          [-2.9584e-03,  9.7654e-01,  2.2154e-01,  1.1172e+00,  6.8687e-02]]]],
       device='cuda:0')

### Own backward

In [19]:
fgate_mat_obw = fgate_mat.clone().detach().requires_grad_(True)
igate_mat_obw = igate_mat.clone().detach().requires_grad_(True)
qs_obw = qs.clone().detach().requires_grad_(True)
ks_obw = ks.clone().detach().requires_grad_(True)
vs_obw = vs.clone().detach().requires_grad_(True)
temp_Ctilde_obw = temp_Ctilde.clone().detach().requires_grad_(True)
temp_D_obw = temp_D.clone().detach().requires_grad_(True)
temp_QK_obw = temp_QK.clone().detach().requires_grad_(True)
temp_N_obw = temp_N.clone().detach().requires_grad_(True)
temp_B_obw = temp_B.clone().detach().requires_grad_(True)

In [20]:
retr_val_obw = vlstm_fwbw_nogatematrices_nostabilization(
    qs_obw, ks_obw, vs_obw, igate_mat_obw, fgate_mat_obw, 
    temp_Ctilde_obw,
    temp_D_obw, 
    temp_QK_obw,
    temp_N_obw,
    temp_B_obw,
    eps=EPS
)
retr_val_obw.shape

torch.Size([1, 1, 5, 6])

In [21]:
retr_val_obw.sum().backward()

In [22]:
temp_N_obw.grad - temp_N_pt.grad

tensor([[[[ 0.0000e+00],
          [-4.7684e-07],
          [ 4.7684e-07],
          [ 0.0000e+00],
          [ 0.0000e+00]]]], device='cuda:0')

In [23]:
temp_B_obw.grad - temp_B_pt.grad

tensor([[[[0.0000e+00],
          [4.7684e-07],
          [4.7684e-07],
          [0.0000e+00],
          [0.0000e+00]]]], device='cuda:0')

In [24]:
temp_QK_obw.grad - temp_QK_pt.grad

tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
          [1.1921e-07, 1.4901e-06, 0.0000e+00, 0.0000e+00, 0.0000e+00],
          [5.9605e-08, 1.1921e-06, 3.5763e-07, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]]]],
       device='cuda:0')

In [25]:
temp_Ctilde_obw.grad - temp_Ctilde_pt.grad

tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
          [4.7684e-07, 4.7684e-07, 4.7684e-07, 9.5367e-07, 4.7684e-07],
          [4.7684e-07, 4.7684e-07, 4.7684e-07, 4.7684e-07, 4.7684e-07],
          [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]]]],
       device='cuda:0')

In [26]:
qs_obw.grad-qs_pt.grad

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 2.6822e-07, -2.6822e-07,  1.1921e-07,  2.9244e-07,  7.4506e-08,
           -4.7684e-07],
          [ 2.3842e-07, -2.6822e-07,  2.3842e-07,  2.3842e-07,  0.0000e+00,
           -5.9605e-07],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00]]]], device='cuda:0')

In [27]:
ks_obw.grad-ks_pt.grad

tensor([[[[ 0.0000e+00, -3.3528e-08,  0.0000e+00,  0.0000e+00, -5.9605e-08,
           -8.9407e-08],
          [-1.4305e-06,  0.0000e+00,  4.7684e-07, -1.9073e-06, -4.7684e-07,
           -2.3842e-07],
          [-7.2643e-08, -2.0862e-07, -2.9802e-08,  0.0000e+00, -2.9802e-08,
           -2.0862e-07],
          [-5.9605e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            5.9605e-08],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -5.9605e-08, -2.9802e-08,
           -1.1921e-07]]]], device='cuda:0')

In [28]:
vs_obw.grad-vs_pt.grad

tensor([[[[0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.]]]], device='cuda:0')

In [29]:
fgate_mat_obw.grad-fgate_mat_pt.grad

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 5.9605e-08, -1.0133e-06,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 5.9605e-08,  4.7684e-07,  2.3842e-07,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]]],
       device='cuda:0')

### Do gradients match? 

In [30]:
qs_pt.grad - qs_obw.grad

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [-2.6822e-07,  2.6822e-07, -1.1921e-07, -2.9244e-07, -7.4506e-08,
            4.7684e-07],
          [-2.3842e-07,  2.6822e-07, -2.3842e-07, -2.3842e-07,  0.0000e+00,
            5.9605e-07],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00]]]], device='cuda:0')

In [31]:
atol = 1e-6
rtol = 1e-6
print(f"Forward match: {torch.allclose(retr_val_pt, retr_val_obw)}")
print(f"qs match: {torch.allclose(qs_pt.grad, qs_obw.grad, atol=atol, rtol=rtol)}")
print(f"ks match: {torch.allclose(ks_pt.grad, ks_obw.grad, atol=atol, rtol=rtol)}")
print(f"vs match: {torch.allclose(vs_pt.grad, vs_obw.grad, atol=atol, rtol=rtol)}")
print(f"fgate_mat match: {torch.allclose(fgate_mat_pt.grad, fgate_mat_obw.grad, atol=atol, rtol=rtol)}")
print(f"igate_mat match: {torch.allclose(igate_mat_pt.grad, igate_mat_obw.grad, atol=atol, rtol=rtol)}")

Forward match: True
qs match: True
ks match: True
vs match: True
fgate_mat match: True
igate_mat match: True


### DEBUG

## Forward without input & forget gate

In [32]:
retr_vals = vlstm_fw_nogatematrices_nostabilization(
    queries=qs,
    keys=ks,
    values=vs,
    igate_preact=igate_mat,
    fgate_preact=fgate_mat,
    temp_Ctilde=temp_Ctilde,
    temp_D=temp_D,
    temp_QK=temp_QK,
    temp_N=temp_N,
    eps=EPS,
)
retr_vals.shape

TypeError: vlstm_fw_nogatematrices_nostabilization() missing 1 required positional argument: 'temp_B'

In [None]:
retr_vals_fwbw = vlstm_fwbw_nogatematrices_nostabilization(
    queries=qs,
    keys=ks,
    values=vs,
    igate_preact=igate_mat,
    fgate_preact=fgate_mat,
    temp_Ctilde=temp_Ctilde,
    temp_D=temp_D,
    temp_QK=temp_QK,
    temp_N=temp_N,
    eps=EPS,
)
retr_vals_fwbw.shape

torch.Size([1, 1, 5, 6])

### Check if it equals the full version:

In [None]:
# check if equals the full version
retr_vals_full = vlstm_fw_torch(queries=qs, keys=ks, values=vs, igate_preact=igate_preacts, fgate_preact=fgate_preacts)
retr_vals_full.shape

torch.Size([1, 1, 5, 6])

In [None]:
# The implementations match!!!
retr_vals - retr_vals_full

tensor([[[[ 8.9407e-08, -8.3447e-07,  9.5367e-07,  1.2517e-06,  1.4901e-07,
            7.7486e-07],
          [ 3.2783e-07, -5.3644e-07,  5.3644e-07,  8.9407e-07, -1.3411e-07,
            5.3644e-07],
          [-7.1526e-07,  6.5565e-07, -7.1526e-07, -1.1325e-06,  5.9605e-07,
           -7.7486e-07],
          [ 2.9802e-07, -7.1526e-07,  7.1526e-07,  1.4305e-06, -2.5332e-07,
            6.5565e-07],
          [ 4.4703e-07, -6.8545e-07,  8.9407e-07,  1.6689e-06, -5.9605e-07,
            3.8743e-07]]]], device='cuda:0')

In [None]:
retr_vals_fwbw - retr_vals_full

tensor([[[[ 8.9407e-08, -8.3447e-07,  9.5367e-07,  1.2517e-06,  1.4901e-07,
            7.7486e-07],
          [ 3.2783e-07, -5.3644e-07,  5.3644e-07,  8.9407e-07, -1.3411e-07,
            5.3644e-07],
          [-7.1526e-07,  6.5565e-07, -7.1526e-07, -1.1325e-06,  5.9605e-07,
           -7.7486e-07],
          [ 2.9802e-07, -7.1526e-07,  7.1526e-07,  1.4305e-06, -2.5332e-07,
            6.5565e-07],
          [ 4.4703e-07, -6.8545e-07,  8.9407e-07,  1.6689e-06, -5.9605e-07,
            3.8743e-07]]]], device='cuda:0')