In [1]:
%load_ext autoreload
%autoreload 2

import torch 

from vlstm_nogatematrices import vlstm_fw_nogatematrices_torch, vlstm_fw_prepare_gate_preacts, vlstm_fwbw_nogatematrices_torch
from vlstm_full import vlstm_fw_torch



# vLSTM forward/backward (NOGATES stabilized) implementation

In [2]:
# Tensor settings used when creating the random inputs in the cells below.
DTYPE = torch.float32
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs top-to-bottom on a machine without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Problem sizes: the q/k/v tensors below are created with shape (B, NH, S, DH).
# Presumably batch size, sequence length, number of heads and head dim — confirm.
B, S, NH, DH = 1, 5, 1, 6
# Forwarded as `eps=` to the forward / forward-backward calls of the gradient check.
EPS = 0.0

In [4]:
# Seed the global RNG (CPU and CUDA) so the random inputs, and every number
# derived from them, are reproducible on a fresh "Restart & Run All".
torch.manual_seed(0)

# Random input- and forget-gate pre-activations, each of shape (B, NH, S, 1).
igate_preacts = torch.randn((B, NH, S, 1), dtype=DTYPE, device=DEVICE)
fgate_preacts = torch.randn((B, NH, S, 1), dtype=DTYPE, device=DEVICE)

In [5]:
# Expand the per-step gate pre-activations into gate matrices.
# Note the ordering: arguments go in as (input, forget) while the result is
# unpacked as (forget, input).
fgate_mat, igate_mat = vlstm_fw_prepare_gate_preacts(
    igate_preacts,
    fgate_preacts,
)
(igate_mat.shape, fgate_mat.shape)

(torch.Size([1, 1, 5, 5]), torch.Size([1, 1, 5, 5]))

In [6]:
# Inspect the forget-gate matrix returned by vlstm_fw_prepare_gate_preacts.
# In the recorded output the diagonal is 0 and entries above it are -inf.
fgate_mat

tensor([[[[ 0.0000,    -inf,    -inf,    -inf,    -inf],
          [-0.6975,  0.0000,    -inf,    -inf,    -inf],
          [-1.7448, -1.0474,  0.0000,    -inf,    -inf],
          [-1.9973, -1.2998, -0.2524,  0.0000,    -inf],
          [-3.8352, -3.1378, -2.0904, -1.8379,  0.0000]]]], device='cuda:0')

In [7]:
# Inspect the input-gate matrix returned by vlstm_fw_prepare_gate_preacts.
# In the recorded output each column holds a single value on and below the
# diagonal, with -inf above the diagonal.
igate_mat

tensor([[[[-1.2596,    -inf,    -inf,    -inf,    -inf],
          [-1.2596,  0.3898,    -inf,    -inf,    -inf],
          [-1.2596,  0.3898, -0.1589,    -inf,    -inf],
          [-1.2596,  0.3898, -0.1589, -0.9753,    -inf],
          [-1.2596,  0.3898, -0.1589, -0.9753, -1.7398]]]], device='cuda:0')

In [8]:
# Draw queries, keys and values (in that order) as independent random tensors
# of shape (B, NH, S, DH).
qs, ks, vs = (
    torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE) for _ in range(3)
)
vs.shape

torch.Size([1, 1, 5, 6])

## Backward Stabilized without input & forget gate

### Torch Autograd

In [9]:
# Independent leaf copies of all inputs for the autograd reference run, so the
# gradients accumulate on these `*_pt` tensors and not on the originals.
fgate_mat_pt, igate_mat_pt, qs_pt, ks_pt, vs_pt = (
    src.clone().detach().requires_grad_(True)
    for src in (fgate_mat, igate_mat, qs, ks, vs)
)

In [10]:
# Display the leaf copies of both gate matrices; the recorded output shows the
# same values as the originals, now with requires_grad=True.
igate_mat_pt, fgate_mat_pt

(tensor([[[[-1.2596,    -inf,    -inf,    -inf,    -inf],
           [-1.2596,  0.3898,    -inf,    -inf,    -inf],
           [-1.2596,  0.3898, -0.1589,    -inf,    -inf],
           [-1.2596,  0.3898, -0.1589, -0.9753,    -inf],
           [-1.2596,  0.3898, -0.1589, -0.9753, -1.7398]]]], device='cuda:0',
        requires_grad=True),
 tensor([[[[ 0.0000,    -inf,    -inf,    -inf,    -inf],
           [-0.6975,  0.0000,    -inf,    -inf,    -inf],
           [-1.7448, -1.0474,  0.0000,    -inf,    -inf],
           [-1.9973, -1.2998, -0.2524,  0.0000,    -inf],
           [-3.8352, -3.1378, -2.0904, -1.8379,  0.0000]]]], device='cuda:0',
        requires_grad=True))

In [11]:
# Reference forward pass on the `*_pt` leaf tensors; its gradients come from
# torch autograd in the next cell.
retr_val_pt = vlstm_fw_nogatematrices_torch(
    qs_pt, ks_pt, vs_pt, igate_mat_pt, fgate_mat_pt, eps=EPS
)
retr_val_pt.shape

torch.Size([1, 1, 5, 6])

In [12]:
# Reduce the output to a scalar (plain sum) and backpropagate to fill `.grad`
# on every `*_pt` leaf tensor.
torch.sum(retr_val_pt).backward()

In [13]:
# Reference gradient of retr_val_pt.sum() with respect to the queries.
qs_pt.grad

tensor([[[[-0.0600, -0.0271,  0.0342, -0.0270, -0.0670, -0.0976],
          [ 1.6059,  1.4116,  0.2235, -0.7741, -0.4435,  0.2759],
          [ 0.5116,  0.4880,  0.2219, -0.2664, -0.2346,  0.0729],
          [ 0.4261,  0.3242,  0.1623, -0.1879, -0.1271,  0.0533],
          [ 0.1224,  0.0346,  0.0641, -0.1076, -0.0618, -0.1150]]]],
       device='cuda:0')

In [14]:
# Reference gradient of retr_val_pt.sum() with respect to the keys.
ks_pt.grad

tensor([[[[ 5.2426e-02, -1.4183e-01,  6.9655e-02,  4.3799e-02, -3.3762e-02,
            1.7235e-02],
          [ 1.2053e+00, -1.7971e+00,  3.0814e-01, -8.8649e-01, -1.3736e-01,
           -3.1060e-01],
          [-2.2055e-02, -1.0415e-01, -1.8058e-02, -1.9704e-01, -1.6115e-02,
           -1.8045e-01],
          [-7.9026e-04,  2.7587e-02, -1.8638e-02,  4.7791e-02,  1.8510e-02,
            6.5603e-02],
          [ 7.0735e-02, -1.2208e-02,  1.1616e-01,  3.8113e-02, -4.9610e-02,
           -2.1831e-01]]]], device='cuda:0')

In [15]:
# Reference gradient of retr_val_pt.sum() with respect to the values.
# In the recorded output every row is constant along the last dimension.
vs_pt.grad

tensor([[[[0.1283, 0.1283, 0.1283, 0.1283, 0.1283, 0.1283],
          [0.0766, 0.0766, 0.0766, 0.0766, 0.0766, 0.0766],
          [0.1690, 0.1690, 0.1690, 0.1690, 0.1690, 0.1690],
          [0.0397, 0.0397, 0.0397, 0.0397, 0.0397, 0.0397],
          [0.5019, 0.5019, 0.5019, 0.5019, 0.5019, 0.5019]]]], device='cuda:0')

In [16]:
# Reference gradient with respect to the forget-gate matrix.
# In the recorded output the entries above the diagonal (the -inf positions)
# receive zero gradient.
fgate_mat_pt.grad

tensor([[[[-0.0247,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0263, -0.4703,  0.0000,  0.0000,  0.0000],
          [-0.0166,  0.2711,  0.0737,  0.0000,  0.0000],
          [-0.0301,  0.1106, -0.0775,  0.0102,  0.0000],
          [-0.0073, -0.0264, -0.0395, -0.0005, -0.4921]]]], device='cuda:0')

In [17]:
# Reference gradient with respect to the input-gate matrix.
# In the recorded output it is identical to the forget-gate matrix gradient.
igate_mat_pt.grad

tensor([[[[-0.0247,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0263, -0.4703,  0.0000,  0.0000,  0.0000],
          [-0.0166,  0.2711,  0.0737,  0.0000,  0.0000],
          [-0.0301,  0.1106, -0.0775,  0.0102,  0.0000],
          [-0.0073, -0.0264, -0.0395, -0.0005, -0.4921]]]], device='cuda:0')

### Own backward

In [18]:
# Second set of independent leaf copies, used by the "own backward" run so its
# gradients can be compared against the autograd reference.
fgate_mat_obw, igate_mat_obw, qs_obw, ks_obw, vs_obw = (
    src.clone().detach().requires_grad_(True)
    for src in (fgate_mat, igate_mat, qs, ks, vs)
)

In [19]:
# Forward pass through the forward-backward variant on the `*_obw` leaf tensors,
# with the same eps as the reference run.
retr_val_obw = vlstm_fwbw_nogatematrices_torch(
    qs_obw,
    ks_obw,
    vs_obw,
    igate_mat_obw,
    fgate_mat_obw,
    eps=EPS,
)
retr_val_obw.shape

torch.Size([1, 1, 5, 6])

In [20]:
# Same scalar reduction as in the reference run; fills `.grad` on the `*_obw`
# leaf tensors.
torch.sum(retr_val_obw).backward()

In [21]:
# Element-wise difference of the query gradients (own backward minus autograd).
qs_obw.grad - qs_pt.grad

tensor([[[[0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.]]]], device='cuda:0')

In [22]:
# Element-wise difference of the key gradients (own backward minus autograd).
ks_obw.grad - ks_pt.grad

tensor([[[[-3.7253e-09,  1.4901e-08, -7.4506e-09,  3.7253e-09,  3.7253e-09,
            3.7253e-09],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  5.9605e-08,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  3.7253e-09, -1.4901e-08,  0.0000e+00,
            1.4901e-08],
          [-5.2387e-10,  0.0000e+00,  0.0000e+00,  0.0000e+00, -3.7253e-09,
            0.0000e+00],
          [ 0.0000e+00,  9.3132e-10,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           -1.4901e-08]]]], device='cuda:0')

In [23]:
# Element-wise difference of the value gradients (own backward minus autograd).
vs_obw.grad - vs_pt.grad

tensor([[[[0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.]]]], device='cuda:0')

In [24]:
# Element-wise difference of the forget-gate matrix gradients
# (own backward minus autograd).
fgate_mat_obw.grad - fgate_mat_pt.grad

tensor([[[[-1.8626e-09,  0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  0.0000e+00, -1.4901e-08,  0.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  5.9605e-08]]]],
       device='cuda:0')

### Do gradients match? 

In [25]:
# Query-gradient difference again, this time autograd minus own backward.
qs_pt.grad - qs_obw.grad

tensor([[[[0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.]]]], device='cuda:0')

In [26]:
# Compare the autograd reference (`*_pt`) against the "own backward" run (`*_obw`).
# Every check, including the forward outputs, uses the same explicit tolerances;
# previously the forward check silently fell back to torch.allclose defaults.
atol = 1e-6
rtol = 1e-6
print(f"Forward match: {torch.allclose(retr_val_pt, retr_val_obw, atol=atol, rtol=rtol)}")
print(f"qs match: {torch.allclose(qs_pt.grad, qs_obw.grad, atol=atol, rtol=rtol)}")
print(f"ks match: {torch.allclose(ks_pt.grad, ks_obw.grad, atol=atol, rtol=rtol)}")
print(f"vs match: {torch.allclose(vs_pt.grad, vs_obw.grad, atol=atol, rtol=rtol)}")
print(f"fgate_mat match: {torch.allclose(fgate_mat_pt.grad, fgate_mat_obw.grad, atol=atol, rtol=rtol)}")
print(f"igate_mat match: {torch.allclose(igate_mat_pt.grad, igate_mat_obw.grad, atol=atol, rtol=rtol)}")

Forward match: True
qs match: True
ks match: True
vs match: True
fgate_mat match: True
igate_mat match: True


## Forward without input & forget gate

In [27]:
# Plain forward pass on the original (non-grad) inputs, using the function's
# default eps. The prepared gate matrices are passed through the
# `igate_preact` / `fgate_preact` keyword arguments.
retr_vals = vlstm_fw_nogatematrices_torch(
    queries=qs,
    keys=ks,
    values=vs,
    igate_preact=igate_mat,
    fgate_preact=fgate_mat,
)
retr_vals.shape

torch.Size([1, 1, 5, 6])

In [28]:
# Same inputs through the forward-backward variant, again with its default eps.
retr_vals_fwbw = vlstm_fwbw_nogatematrices_torch(
    queries=qs,
    keys=ks,
    values=vs,
    igate_preact=igate_mat,
    fgate_preact=fgate_mat,
)
retr_vals_fwbw.shape

torch.Size([1, 1, 5, 6])

### Check if it equals the full version:

In [29]:
# Baseline for the comparison: the full implementation, which is given the raw
# gate pre-activations rather than the prepared gate matrices.
retr_vals_full = vlstm_fw_torch(
    queries=qs,
    keys=ks,
    values=vs,
    igate_preact=igate_preacts,
    fgate_preact=fgate_preacts,
)
retr_vals_full.shape

torch.Size([1, 1, 5, 6])

In [30]:
# The implementations match: the no-gate-matrices forward minus the full
# forward is zero everywhere in the recorded output.
retr_vals - retr_vals_full

tensor([[[[0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.]]]], device='cuda:0')

In [31]:
# Same comparison for the forward-backward variant against the full forward.
retr_vals_fwbw - retr_vals_full

tensor([[[[0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.]]]], device='cuda:0')