In [1]:
# Load IPython's autoreload extension so edits to the imported project modules
# are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules before executing each cell.
%autoreload 2

import torch 

from vlstm_fw_nogatematrices import vlstm_fw_nogatematrices_torch, vlstm_fw_prepare_gate_preacts, vlstm_fwbw_nogatematrices_torch
from vlstm_fw_full import vlstm_fw_torch



# vLSTM forward and backward implementation

In [2]:
# Global tensor settings used by every tensor created in this notebook.
DTYPE = torch.float32
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs end-to-end on machines without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Sizes of the toy problem; tensors below are created with shape (B, NH, S, DH).
B = 1  # leading axis (presumably batch size)
S = 3  # third axis (presumably sequence length); gate matrices come out as S x S
NH = 1  # second axis (presumably number of heads)
DH = 4  # last axis of q/k/v (presumably head dimension)
EPS = 0.0  # passed as `eps` to the vLSTM forward functions

# Seed the RNG so the random gate pre-activations and q/k/v tensors are the same
# on every Restart & Run All. The outputs recorded in this notebook stem from an
# earlier unseeded run and will differ from a seeded re-run.
SEED = 0
torch.manual_seed(SEED);  # trailing ';' suppresses the Generator repr

In [4]:
# Random gate pre-activations of shape (B, NH, S, 1). The input-gate tensor is
# drawn first, then the forget-gate tensor.
igate_preacts = torch.randn(B, NH, S, 1, dtype=DTYPE, device=DEVICE)
fgate_preacts = torch.randn(B, NH, S, 1, dtype=DTYPE, device=DEVICE)

In [5]:
# Turn the (B, NH, S, 1) gate pre-activations into (B, NH, S, S) gate matrices.
# Note the order: inputs are (igate, fgate) but the helper's result is unpacked
# as (fgate_mat, igate_mat).
fgate_mat, igate_mat = vlstm_fw_prepare_gate_preacts(igate_preacts, fgate_preacts)
# Display both shapes as a sanity check.
igate_mat.shape, fgate_mat.shape

(torch.Size([1, 1, 3, 3]), torch.Size([1, 1, 3, 3]))

In [6]:
# Inspect the forget-gate matrix: in the recorded output it is lower-triangular
# with zeros on the diagonal and -inf above the diagonal.
fgate_mat

tensor([[[[ 0.0000,    -inf,    -inf],
          [-0.4312,  0.0000,    -inf],
          [-1.1061, -0.6749,  0.0000]]]], device='cuda:0')

In [7]:
# Inspect the input-gate matrix: in the recorded output each column repeats one
# value on and below the diagonal, with -inf above the diagonal.
igate_mat

tensor([[[[ 1.2072,    -inf,    -inf],
          [ 1.2072, -1.2218,    -inf],
          [ 1.2072, -1.2218,  1.0656]]]], device='cuda:0')

In [8]:
# Random queries, keys and values (drawn in that order), each (B, NH, S, DH).
qs = torch.randn(B, NH, S, DH, dtype=DTYPE, device=DEVICE)
ks = torch.randn(B, NH, S, DH, dtype=DTYPE, device=DEVICE)
vs = torch.randn(B, NH, S, DH, dtype=DTYPE, device=DEVICE)
# Show the shape of one of them as a sanity check.
vs.shape

torch.Size([1, 1, 3, 4])

## Backward without input & forget gate matrix construction (gate matrices are precomputed)

### Torch Autograd

In [9]:
# Independent leaf copies of the gate matrices and q/k/v for the autograd
# reference run; each copy tracks its own gradient.
fgate_mat_pt, igate_mat_pt, qs_pt, ks_pt, vs_pt = (
    src.detach().clone().requires_grad_(True)
    for src in (fgate_mat, igate_mat, qs, ks, vs)
)

In [10]:
# Reference forward pass on the *_pt leaves; its gradients are obtained via
# PyTorch autograd in the next cell.
retr_val_pt = vlstm_fw_nogatematrices_torch(qs_pt, ks_pt, vs_pt, igate_mat_pt, fgate_mat_pt, eps=EPS)
# Display the output shape.
retr_val_pt.shape

torch.Size([1, 1, 3, 4])

In [11]:
# Reduce the output to a scalar by summing all elements and backpropagate; the
# gradients are stored in the .grad attributes of the *_pt leaves.
retr_val_pt.sum().backward()

In [12]:
# Autograd reference gradient w.r.t. the queries.
qs_pt.grad

tensor([[[[ 1.0449e+00,  1.2792e+00, -6.0733e-01, -2.4505e+00],
          [-8.6931e-03,  1.7846e-03, -5.4862e-03, -2.9291e-03],
          [-7.4326e-01,  6.5229e-01,  3.7502e+00, -1.6537e+00]]]],
       device='cuda:0')

In [13]:
# Autograd reference gradient w.r.t. the keys.
ks_pt.grad

tensor([[[[ 2.9352, -2.2258,  1.2687,  1.8272],
          [ 0.1175, -0.0377,  0.0831,  0.0993],
          [ 2.2288, -0.7318,  1.8055,  2.1428]]]], device='cuda:0')

In [14]:
# Autograd reference gradient w.r.t. the values; in the recorded output every
# row is constant along the last axis.
vs_pt.grad

tensor([[[[-0.2490, -0.2490, -0.2490, -0.2490],
          [-0.2112, -0.2112, -0.2112, -0.2112],
          [ 0.8581,  0.8581,  0.8581,  0.8581]]]], device='cuda:0')

In [15]:
# Autograd reference gradient w.r.t. the forget-gate matrix.
fgate_mat_pt.grad

tensor([[[[ 1.2744e+00,  0.0000e+00,  0.0000e+00],
          [-5.5217e-04,  5.5217e-04,  0.0000e+00],
          [ 6.6943e-01,  3.3163e-01, -1.4477e+00]]]], device='cuda:0')

In [16]:
# Autograd reference gradient w.r.t. the input-gate matrix; the recorded output
# is identical to the forget-gate matrix gradient.
igate_mat_pt.grad

tensor([[[[ 1.2744e+00,  0.0000e+00,  0.0000e+00],
          [-5.5217e-04,  5.5217e-04,  0.0000e+00],
          [ 6.6943e-01,  3.3163e-01, -1.4477e+00]]]], device='cuda:0')

### Own backward

In [17]:
# Fresh leaf copies of the same inputs for the hand-written backward ("obw")
# run, so its gradients can be compared against the autograd reference.
fgate_mat_obw, igate_mat_obw, qs_obw, ks_obw, vs_obw = (
    src.detach().clone().requires_grad_(True)
    for src in (fgate_mat, igate_mat, qs, ks, vs)
)

In [18]:
# Forward pass through the forward+backward variant under test, called with the
# same argument order and eps as the autograd reference run.
retr_val_obw = vlstm_fwbw_nogatematrices_torch(qs_obw, ks_obw, vs_obw, igate_mat_obw, fgate_mat_obw, eps=EPS)
# Display the output shape.
retr_val_obw.shape

torch.Size([1, 1, 3, 4])

In [19]:
# Same scalar loss (sum of all outputs) as the reference run; fills the .grad
# attributes of the *_obw leaves.
retr_val_obw.sum().backward()

In [20]:
# Query-gradient error (own backward minus autograd); in the recorded output
# only the middle row is non-zero.
qs_obw.grad-qs_pt.grad

tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000],
          [-0.1905, -0.0257,  0.0898,  0.1187],
          [ 0.0000,  0.0000,  0.0000,  0.0000]]]], device='cuda:0')

In [21]:
# Key-gradient error (own backward minus autograd); in the recorded output the
# first two rows are non-zero and the last row matches exactly.
ks_obw.grad-ks_pt.grad

tensor([[[[-0.0031, -0.0063, -0.0008,  0.0012],
          [-0.1771, -0.3572, -0.0455,  0.0685],
          [ 0.0000,  0.0000,  0.0000,  0.0000]]]], device='cuda:0')

In [22]:
# Value-gradient error (own backward minus autograd); exactly zero in the
# recorded output.
vs_obw.grad-vs_pt.grad

tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]], device='cuda:0')

In [23]:
# Forget-gate-matrix gradient error (own backward minus autograd); the recorded
# output deviates in the middle row (the 2.98e-08 entry is presumably float32
# rounding noise).
# NOTE(review): there is no corresponding difference cell for igate_mat, although
# the summary cell further down reports an igate_mat mismatch as well.
fgate_mat_obw.grad-fgate_mat_pt.grad

tensor([[[[ 0.0000e+00, -0.0000e+00,  0.0000e+00],
          [-4.3852e-03, -2.3707e-01,  0.0000e+00],
          [ 0.0000e+00,  2.9802e-08,  0.0000e+00]]]], device='cuda:0')

### Do gradients match? 

In [25]:
# Same query-gradient comparison as the earlier `qs_obw.grad-qs_pt.grad` cell,
# with the operands swapped (so the sign of the output is flipped).
qs_pt.grad - qs_obw.grad

tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.1905,  0.0257, -0.0898, -0.1187],
          [ 0.0000,  0.0000,  0.0000,  0.0000]]]], device='cuda:0')

In [26]:
# Absolute / relative tolerances shared by every comparison in this cell.
atol = 1e-4
rtol = 1e-4
# Compare the autograd reference (*_pt) against the own forward/backward (*_obw).
# The forward check uses the same tolerances as the gradient checks, so all six
# lines report against one consistent criterion.
print(f"Forward match: {torch.allclose(retr_val_pt, retr_val_obw, atol=atol, rtol=rtol)}")
print(f"qs match: {torch.allclose(qs_pt.grad, qs_obw.grad, atol=atol, rtol=rtol)}")
print(f"ks match: {torch.allclose(ks_pt.grad, ks_obw.grad, atol=atol, rtol=rtol)}")
print(f"vs match: {torch.allclose(vs_pt.grad, vs_obw.grad, atol=atol, rtol=rtol)}")
print(f"fgate_mat match: {torch.allclose(fgate_mat_pt.grad, fgate_mat_obw.grad, atol=atol, rtol=rtol)}")
print(f"igate_mat match: {torch.allclose(igate_mat_pt.grad, igate_mat_obw.grad, atol=atol, rtol=rtol)}")

Forward match: True
qs match: False
ks match: False
vs match: True
fgate_mat match: False
igate_mat match: False


### DEBUG

## Forward without input & forget gate matrix construction (gate matrices are precomputed)

In [9]:
# Forward-only variant on the raw (non-leaf) tensors, fed with the prepared gate
# matrices. NOTE(review): `eps` is not passed here, so the function's default is
# used rather than EPS as in the gradient cells - confirm this is intended.
retr_vals = vlstm_fw_nogatematrices_torch(queries=qs, keys=ks, values=vs, igate_preact=igate_mat, fgate_preact=fgate_mat)
# Display the output shape.
retr_vals.shape

torch.Size([1, 1, 3, 4])

In [10]:
# Forward output of the forward+backward variant for the same inputs (again
# without an explicit `eps`, i.e. with the function's default).
retr_vals_fwbw = vlstm_fwbw_nogatematrices_torch(queries=qs, keys=ks, values=vs, igate_preact=igate_mat, fgate_preact=fgate_mat)
# Display the output shape.
retr_vals_fwbw.shape

torch.Size([1, 1, 3, 4])

### Check if it equals the full version:

In [11]:
# Run the full version for comparison. Unlike the "nogatematrices" variants it
# is fed the raw gate pre-activations (igate_preacts / fgate_preacts), not the
# prepared gate matrices.
retr_vals_full = vlstm_fw_torch(queries=qs, keys=ks, values=vs, igate_preact=igate_preacts, fgate_preact=fgate_preacts)
# Display the output shape.
retr_vals_full.shape

torch.Size([1, 1, 3, 4])

In [12]:
# Element-wise difference between the forward-only variant and the full version.
# The implementations match: the recorded difference is exactly zero everywhere.
retr_vals - retr_vals_full

tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]], device='cuda:0')

In [13]:
# Element-wise difference between the forward+backward variant's output and the
# full version; also exactly zero in the recorded output.
retr_vals_fwbw - retr_vals_full

tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]], device='cuda:0')