In [3]:
%load_ext autoreload
%autoreload 2

import torch 

from vlstm_nogatematrices import vlstm_fw_nogatematrices_torch, vlstm_fw_prepare_gate_preacts, vlstm_fwbw_nogatematrices_torch
from vlstm_full import vlstm_fw_torch



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# vLSTM forward backward implementation

In [4]:
# Runtime configuration used by the tensor-creating cells in this notebook.
DTYPE = torch.float32
# Fall back to the CPU when no CUDA device is present so that a fresh
# Restart & Run All still works on a machine without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Fixed seed so the torch.randn inputs are identical on every full re-run.
SEED = 0
torch.manual_seed(SEED);  # trailing ';' hides the returned Generator repr

In [5]:
# Sizes of the toy problem; they are combined into the tensor shapes
# (B, NH, S, 1) and (B, NH, S, DH) further down.
# Presumably batch size, number of heads, sequence length and head dimension
# -- TODO confirm against the vlstm_* signatures.
B, NH = 1, 1
S, DH = 5, 6
# Passed as `eps=` to the vlstm_* forward calls.
EPS = 0.0

In [6]:
# Random gate pre-activations of shape (B, NH, S, 1); the input-gate tensor is
# drawn first, then the forget-gate tensor.
igate_preacts, fgate_preacts = (
    torch.randn((B, NH, S, 1), dtype=DTYPE, device=DEVICE) for _ in range(2)
)

In [7]:
# Turn the per-step gate pre-activations into their matrix form.
# NOTE(review): the unpack order (forget first, input second) is the reverse of
# the argument order -- confirm against vlstm_fw_prepare_gate_preacts.
gate_mats = vlstm_fw_prepare_gate_preacts(igate_preacts, fgate_preacts)
fgate_mat, igate_mat = gate_mats
(igate_mat.shape, fgate_mat.shape)

(torch.Size([1, 1, 5, 5]), torch.Size([1, 1, 5, 5]))

In [8]:
# Show the forget-gate matrix returned by vlstm_fw_prepare_gate_preacts.
fgate_mat

tensor([[[[ 0.0000,    -inf,    -inf,    -inf,    -inf],
          [-0.9740,  0.0000,    -inf,    -inf,    -inf],
          [-1.9280, -0.9540,  0.0000,    -inf,    -inf],
          [-2.7621, -1.7881, -0.8342,  0.0000,    -inf],
          [-3.2144, -2.2404, -1.2864, -0.4523,  0.0000]]]], device='cuda:0')

In [9]:
# Show the input-gate matrix returned by vlstm_fw_prepare_gate_preacts.
igate_mat

tensor([[[[1.1734,   -inf,   -inf,   -inf,   -inf],
          [1.1734, 0.1441,   -inf,   -inf,   -inf],
          [1.1734, 0.1441, 0.8900,   -inf,   -inf],
          [1.1734, 0.1441, 0.8900, 0.8338,   -inf],
          [1.1734, 0.1441, 0.8900, 0.8338, 1.9555]]]], device='cuda:0')

In [10]:
# Random queries, keys and values (drawn in that order), each of shape
# (B, NH, S, DH).
qkv_shape = (B, NH, S, DH)
qs, ks, vs = (
    torch.randn(qkv_shape, dtype=DTYPE, device=DEVICE) for _ in range(3)
)
vs.shape

torch.Size([1, 1, 5, 6])

## Backward Stabilized without input & forget gate

### Torch Autograd

In [11]:
# Independent leaf copies of every input for the autograd reference run, so
# gradients land on tensors that are separate from the originals.
fgate_mat_pt, igate_mat_pt, qs_pt, ks_pt, vs_pt = (
    src.detach().clone().requires_grad_(True)
    for src in (fgate_mat, igate_mat, qs, ks, vs)
)

In [27]:
# Show the requires_grad copies of the input- and forget-gate matrices that
# the autograd reference run differentiates.
igate_mat_pt, fgate_mat_pt

(tensor([[[[-0.9360,    -inf,    -inf],
           [-0.9360, -0.2426,    -inf],
           [-0.9360, -0.2426, -1.8288]]]], device='cuda:0', requires_grad=True),
 tensor([[[[ 0.0000,    -inf,    -inf],
           [-1.5091,  0.0000,    -inf],
           [-2.3670, -0.8579,  0.0000]]]], device='cuda:0', requires_grad=True))

In [12]:
# Reference forward pass on the *_pt leaf tensors; gradients are later taken
# from this output with plain autograd.
retr_val_pt = vlstm_fw_nogatematrices_torch(
    qs_pt,
    ks_pt,
    vs_pt,
    igate_mat_pt,
    fgate_mat_pt,
    eps=EPS,
)
retr_val_pt.shape

torch.Size([1, 1, 5, 6])

In [13]:
# Reduce the output to a scalar with sum() and backpropagate, which fills
# .grad on the *_pt leaf tensors.
retr_val_pt.sum().backward()

In [14]:
# Autograd gradient with respect to the queries.
qs_pt.grad

tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [-0.1898,  0.3535,  0.2526, -0.4513,  0.1192, -1.0204],
          [ 2.1230,  3.6409,  1.8615, -2.9232,  1.3874, -0.3905],
          [ 0.4781, -5.0274,  0.8095,  0.5433,  2.2278, -3.2875],
          [ 1.0696, -2.7739,  0.3917,  0.9173, -0.8761, -1.7354]]]],
       device='cuda:0')

In [15]:
# Autograd gradient with respect to the keys.
ks_pt.grad

tensor([[[[ 0.3145, -0.2250,  0.2156,  0.2999,  0.2635, -0.9822],
          [ 0.3043, -0.2178,  0.2086,  0.2902,  0.2550, -0.9504],
          [ 1.3716, -0.3946, -1.0525,  0.9435, -0.3814, -3.8808],
          [ 0.3076, -2.2874,  4.8265,  0.7353, -0.4333, -2.3362],
          [ 0.4348, -2.9404,  3.5666, -1.7066,  0.1127,  0.1471]]]],
       device='cuda:0')

In [16]:
# Autograd gradient with respect to the values.
vs_pt.grad

tensor([[[[ 2.0732,  2.0732,  2.0732,  2.0732,  2.0732,  2.0732],
          [-1.4744, -1.4744, -1.4744, -1.4744, -1.4744, -1.4744],
          [-1.0673, -1.0673, -1.0673, -1.0673, -1.0673, -1.0673],
          [ 3.3890,  3.3890,  3.3890,  3.3890,  3.3890,  3.3890],
          [-2.6302, -2.6302, -2.6302, -2.6302, -2.6302, -2.6302]]]],
       device='cuda:0')

In [17]:
# Autograd gradient with respect to the forget-gate matrix.
fgate_mat_pt.grad

tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 2.2058, -1.2235,  0.0000,  0.0000,  0.0000],
          [ 0.5836, -0.1235, -1.6474,  0.0000,  0.0000],
          [ 0.2251, -0.0952, -0.5429,  3.0326,  0.0000],
          [ 0.1239, -0.0976,  0.2957,  6.0540, -1.5206]]]], device='cuda:0')

In [18]:
# Autograd gradient with respect to the input-gate matrix.
igate_mat_pt.grad

tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 2.2058, -1.2235,  0.0000,  0.0000,  0.0000],
          [ 0.5836, -0.1235, -1.6474,  0.0000,  0.0000],
          [ 0.2251, -0.0952, -0.5429,  3.0326,  0.0000],
          [ 0.1239, -0.0976,  0.2957,  6.0540, -1.5206]]]], device='cuda:0')

### Own backward

In [19]:
# Second, independent set of leaf copies for the own-backward run, so its
# gradients do not mix with those of the autograd reference tensors.
fgate_mat_obw, igate_mat_obw, qs_obw, ks_obw, vs_obw = (
    src.detach().clone().requires_grad_(True)
    for src in (fgate_mat, igate_mat, qs, ks, vs)
)

In [20]:
# Forward pass through the fwbw variant on the *_obw leaf tensors, using the
# same argument order and eps as the autograd reference call.
obw_inputs = (qs_obw, ks_obw, vs_obw, igate_mat_obw, fgate_mat_obw)
retr_val_obw = vlstm_fwbw_nogatematrices_torch(*obw_inputs, eps=EPS)
retr_val_obw.shape

torch.Size([1, 1, 5, 6])

In [21]:
# Same scalar reduction as for the reference run; fills .grad on the *_obw
# leaf tensors (presumably via the backward of vlstm_fwbw_nogatematrices_torch
# -- confirm in vlstm_nogatematrices).
retr_val_obw.sum().backward()

In [22]:
# Query-gradient difference (own backward minus autograd); non-zero entries
# mark positions where the two disagree.
qs_obw.grad-qs_pt.grad

tensor([[[[ 0.2791, -0.2202,  0.2366,  0.2645,  0.4184, -0.5586],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]]],
       device='cuda:0')

In [23]:
# Key-gradient difference (own backward minus autograd); non-zero entries
# mark positions where the two disagree.
ks_obw.grad-ks_pt.grad

tensor([[[[-3.2545e-01,  5.9422e-01, -4.7766e-01, -1.5620e-01,  2.6975e-01,
            6.8648e-02],
          [ 2.9802e-08,  0.0000e+00,  1.4901e-08,  0.0000e+00, -2.9802e-08,
            1.1921e-07],
          [ 0.0000e+00, -8.9407e-08,  0.0000e+00, -1.1921e-07,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00, -2.3842e-07,  0.0000e+00, -5.9605e-08,  0.0000e+00,
            2.3842e-07],
          [ 0.0000e+00,  0.0000e+00, -2.3842e-07,  0.0000e+00, -7.4506e-09,
            1.4901e-08]]]], device='cuda:0')

In [24]:
# Value-gradient difference (own backward minus autograd); all zeros means
# the two agree.
vs_obw.grad-vs_pt.grad

tensor([[[[0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.]]]], device='cuda:0')

In [25]:
# Forget-gate-matrix gradient difference (own backward minus autograd);
# non-zero entries mark positions where the two disagree.
fgate_mat_obw.grad-fgate_mat_pt.grad

tensor([[[[-1.0212e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  1.1921e-07,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -2.3842e-07,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]]],
       device='cuda:0')

### Do gradients match? 

In [26]:
# Query-gradient difference with the opposite sign convention (autograd minus
# own backward); non-zero entries mark mismatches.
qs_pt.grad - qs_obw.grad

tensor([[[[-0.2791,  0.2202, -0.2366, -0.2645, -0.4184,  0.5586],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]]],
       device='cuda:0')

In [27]:
# Compare the autograd reference (*_pt) with the own-backward run (*_obw).
# Every check -- including the forward one, which previously fell back to the
# torch.allclose defaults -- now uses the same explicit tolerances.
atol = 1e-6
rtol = 1e-6
match_checks = {
    "Forward": (retr_val_pt, retr_val_obw),
    "qs": (qs_pt.grad, qs_obw.grad),
    "ks": (ks_pt.grad, ks_obw.grad),
    "vs": (vs_pt.grad, vs_obw.grad),
    "fgate_mat": (fgate_mat_pt.grad, fgate_mat_obw.grad),
    "igate_mat": (igate_mat_pt.grad, igate_mat_obw.grad),
}
# One line per check, in the same order and wording as before.
for label, (reference, candidate) in match_checks.items():
    print(f"{label} match: {torch.allclose(reference, candidate, atol=atol, rtol=rtol)}")

Forward match: True
qs match: False
ks match: False
vs match: True
fgate_mat match: False
igate_mat match: False


### DEBUG

## Forward without input & forget gate

In [9]:
# Forward pass of the gate-matrix variant on the original (non-grad) inputs.
# NOTE(review): eps is left at the function's default here, unlike the
# eps=EPS calls in the backward section -- confirm that is intended.
retr_vals = vlstm_fw_nogatematrices_torch(
    queries=qs,
    keys=ks,
    values=vs,
    igate_preact=igate_mat,
    fgate_preact=fgate_mat,
)
retr_vals.shape

torch.Size([1, 1, 3, 4])

In [10]:
# Forward pass of the fwbw variant with the same keyword arguments as the
# plain gate-matrix forward call.
# NOTE(review): eps is left at the function's default here -- confirm.
retr_vals_fwbw = vlstm_fwbw_nogatematrices_torch(
    queries=qs,
    keys=ks,
    values=vs,
    igate_preact=igate_mat,
    fgate_preact=fgate_mat,
)
retr_vals_fwbw.shape

torch.Size([1, 1, 3, 4])

### Check if it equals the full version:

In [11]:
# Reference output from the full implementation, which is given the raw gate
# pre-activations rather than the prepared gate matrices.
retr_vals_full = vlstm_fw_torch(
    queries=qs,
    keys=ks,
    values=vs,
    igate_preact=igate_preacts,
    fgate_preact=fgate_preacts,
)
retr_vals_full.shape

torch.Size([1, 1, 3, 4])

In [12]:
# Difference between the gate-matrix forward output and the full forward
# output; all zeros means the two implementations match.
retr_vals - retr_vals_full

tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]], device='cuda:0')

In [13]:
# Difference between the fwbw forward output and the full forward output;
# all zeros means the two implementations match.
retr_vals_fwbw - retr_vals_full

tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]], device='cuda:0')