In [1]:
%load_ext autoreload
%autoreload 2

import torch 

from vlstm_full import vlstm_fw_torch, vlstm_fwbw_torch



# vLSTM forward/backward implementation (full, stabilized version)

In [2]:
# Element type used for every tensor created in this notebook.
DTYPE = torch.float32
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs end to end on machines without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Seed the torch RNG so the randomly drawn inputs are the same after every
# Restart Kernel -> Run All (on the same device and torch version).
SEED = 0
torch.manual_seed(SEED)

# Tensor dimensions: q/k/v are created with shape (B, NH, S, DH) and the gate
# pre-activations with shape (B, NH, S, 1).
B = 1  # presumably the batch size
S = 5  # presumably the sequence length
NH = 1  # presumably the number of heads
DH = 6  # presumably the per-head feature dimension
# Forwarded as `eps=` to both vLSTM implementations.
EPS = 0.0

In [4]:
# Random input- and forget-gate pre-activations, one scalar per
# (batch, head, step) position. Drawn in the order igate, fgate.
igate_preacts, fgate_preacts = (
    torch.randn(B, NH, S, 1, dtype=DTYPE, device=DEVICE) for _ in range(2)
)

In [5]:
# Random q/k/v inputs, each of shape (B, NH, S, DH). Drawn in the order qs, ks, vs.
qs, ks, vs = (
    torch.randn(B, NH, S, DH, dtype=DTYPE, device=DEVICE) for _ in range(3)
)
vs.shape

torch.Size([1, 1, 5, 6])

## Stabilized backward pass with input & forget gates

### Torch Autograd

In [6]:
# Independent leaf copies of the shared inputs for the autograd reference run,
# so the gradients accumulate on their own `.grad` tensors.
fgate_preacts_pt, igate_preacts_pt, qs_pt, ks_pt, vs_pt = (
    source.clone().detach().requires_grad_(True)
    for source in (fgate_preacts, igate_preacts, qs, ks, vs)
)

In [7]:
# Show the input- and forget-gate pre-activations fed to the reference run.
igate_preacts_pt, fgate_preacts_pt

(tensor([[[[-0.8153],
           [-0.8300],
           [-0.4147],
           [ 0.2378],
           [ 1.3356]]]], device='cuda:0', requires_grad=True),
 tensor([[[[ 1.5931],
           [-0.8049],
           [-1.0801],
           [-1.8037],
           [-1.1816]]]], device='cuda:0', requires_grad=True))

In [8]:
# Reference forward pass on the `_pt` copies; its gradients are taken later
# via `backward()` on the summed output.
retr_val_pt = vlstm_fw_torch(
    qs_pt,
    ks_pt,
    vs_pt,
    igate_preacts_pt,
    fgate_preacts_pt,
    eps=EPS,
)
retr_val_pt.shape

torch.Size([1, 1, 5, 6])

In [9]:
# Reduce the output to a scalar so `backward()` needs no explicit gradient
# argument; every output element then receives an upstream gradient of 1.
torch.sum(retr_val_pt).backward()

In [10]:
# Gradient of the summed reference output with respect to `qs_pt`.
qs_pt.grad

tensor([[[[-7.4937e-03,  9.9400e-02,  4.8439e-02, -1.8549e-01, -2.4355e-01,
            1.3618e-01],
          [ 1.5646e-02,  4.5657e-02,  5.3954e-02, -6.1600e-02, -8.4806e-02,
            1.0692e-02],
          [-3.4843e-01, -3.0921e-04, -2.2357e-01,  2.5050e-01, -1.1053e-01,
            1.6818e-01],
          [ 7.0005e-03,  3.7986e-03,  1.5666e-02, -8.3835e-03,  6.8498e-04,
           -1.4984e-02],
          [-9.7439e-04,  3.1440e-04,  1.4727e-03,  4.9805e-05,  4.8324e-04,
           -1.9327e-03]]]], device='cuda:0')

In [11]:
# Gradient of the summed reference output with respect to `ks_pt`.
ks_pt.grad

tensor([[[[-3.2740e-01,  2.6350e-01, -4.8446e-02,  6.9152e-02, -1.2329e-01,
           -1.0815e-01],
          [-2.4688e-02, -4.6997e-03,  1.4115e-02,  8.7711e-02, -3.2873e-02,
           -2.3681e-02],
          [ 1.0163e-01,  3.6865e-02, -1.3464e-01, -1.6148e-01, -3.7631e-01,
           -7.8293e-02],
          [ 3.1089e-03,  2.4774e-03, -1.8500e-04, -3.3440e-03,  1.9840e-03,
            6.8678e-03],
          [ 2.4160e-04,  8.4634e-05,  1.5980e-04, -2.8352e-04,  1.5039e-04,
            4.4040e-05]]]], device='cuda:0')

In [12]:
# Gradient of the summed reference output with respect to `vs_pt`.
vs_pt.grad

tensor([[[[ 0.2637,  0.2637,  0.2637,  0.2637,  0.2637,  0.2637],
          [ 0.1878,  0.1878,  0.1878,  0.1878,  0.1878,  0.1878],
          [-0.2005, -0.2005, -0.2005, -0.2005, -0.2005, -0.2005],
          [-0.8794, -0.8794, -0.8794, -0.8794, -0.8794, -0.8794],
          [ 0.9068,  0.9068,  0.9068,  0.9068,  0.9068,  0.9068]]]],
       device='cuda:0')

In [13]:
# Gradient of the summed reference output with respect to the forget-gate
# pre-activations.
fgate_preacts_pt.grad

tensor([[[[ 2.5175e-09],
          [-8.7534e-02],
          [-3.3391e-02],
          [-1.5502e-02],
          [-5.1984e-04]]]], device='cuda:0')

In [14]:
# Gradient of the summed reference output with respect to the input-gate
# pre-activations.
igate_preacts_pt.grad

tensor([[[[0.2038],
          [0.0288],
          [0.1226],
          [0.0174],
          [0.0007]]]], device='cuda:0')

### Own backward

In [15]:
# Second set of independent leaf copies, used only by the `vlstm_fwbw_torch`
# run so its gradients do not mix with those of the reference run.
fgate_preacts_obw, igate_preacts_obw, qs_obw, ks_obw, vs_obw = (
    source.clone().detach().requires_grad_(True)
    for source in (fgate_preacts, igate_preacts, qs, ks, vs)
)

In [16]:
# Forward pass through `vlstm_fwbw_torch` on the `_obw` copies, with the same
# inputs and `eps` as the reference run.
retr_val_obw = vlstm_fwbw_torch(
    qs_obw,
    ks_obw,
    vs_obw,
    igate_preacts_obw,
    fgate_preacts_obw,
    eps=EPS,
)
retr_val_obw.shape

torch.Size([1, 1, 5, 6])

In [17]:
# Same scalar reduction as in the reference run, so both sets of gradients
# come from an upstream gradient of 1 per output element.
torch.sum(retr_val_obw).backward()

In [18]:
# Element-wise difference between the `_obw` and reference gradients of qs.
qs_obw.grad - qs_pt.grad

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [-5.2620e-08, -4.6566e-10, -1.8626e-09,  2.7940e-08,  7.4506e-09,
           -2.7008e-08],
          [-6.9849e-09, -1.5716e-09, -3.6904e-08,  1.5621e-08, -3.4343e-08,
           -2.0373e-08]]]], device='cuda:0')

In [19]:
# Element-wise difference between the `_obw` and reference gradients of ks.
ks_obw.grad - ks_pt.grad

tensor([[[[ 2.9802e-08,  0.0000e+00,  0.0000e+00,  7.4506e-09,  1.4901e-08,
            0.0000e+00],
          [ 0.0000e+00,  4.6566e-10,  9.3132e-10,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 7.4506e-09,  3.7253e-09, -1.4901e-08,  0.0000e+00, -2.9802e-08,
            0.0000e+00],
          [ 2.2352e-08,  1.5134e-08,  3.0414e-09, -2.5146e-08,  1.3970e-08,
            3.7253e-08],
          [-3.7733e-08, -1.3220e-08, -2.4957e-08,  4.4267e-08, -2.3487e-08,
           -6.8794e-09]]]], device='cuda:0')

In [20]:
# Element-wise difference between the `_obw` and reference gradients of vs.
vs_obw.grad - vs_pt.grad

tensor([[[[0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.]]]], device='cuda:0')

In [21]:
# Element-wise difference between the `_obw` and reference gradients of the
# forget-gate pre-activations.
fgate_preacts_obw.grad - fgate_preacts_pt.grad

tensor([[[[-2.5175e-09],
          [-1.4901e-08],
          [ 0.0000e+00],
          [ 3.7253e-09],
          [-8.2073e-09]]]], device='cuda:0')

In [22]:
# Element-wise difference between the `_obw` and reference gradients of the
# input-gate pre-activations.
igate_preacts_obw.grad - igate_preacts_pt.grad

tensor([[[[-2.9802e-08],
          [-7.4506e-09],
          [ 7.4506e-09],
          [ 3.1665e-08],
          [-2.2870e-07]]]], device='cuda:0')

### Do gradients match? 

In [23]:
# Reference-minus-own difference for the qs gradients. This is the negation of
# the own-minus-reference difference displayed earlier in the notebook.
qs_pt.grad - qs_obw.grad

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 5.2620e-08,  4.6566e-10,  1.8626e-09, -2.7940e-08, -7.4506e-09,
            2.7008e-08],
          [ 6.9849e-09,  1.5716e-09,  3.6904e-08, -1.5621e-08,  3.4343e-08,
            2.0373e-08]]]], device='cuda:0')

In [24]:
# Tolerances shared by every comparison below, the forward outputs included,
# so all checks are judged by the same criterion.
atol = 1e-6
rtol = 1e-6

# (label, reference tensor, tensor from the `vlstm_fwbw_torch` run)
comparisons = [
    ("Forward", retr_val_pt, retr_val_obw),
    ("qs", qs_pt.grad, qs_obw.grad),
    ("ks", ks_pt.grad, ks_obw.grad),
    ("vs", vs_pt.grad, vs_obw.grad),
    ("fgate_preacts", fgate_preacts_pt.grad, fgate_preacts_obw.grad),
    ("igate_preacts", igate_preacts_pt.grad, igate_preacts_obw.grad),
]
for label, reference, candidate in comparisons:
    is_close = torch.allclose(reference, candidate, atol=atol, rtol=rtol)
    print(f"{label} match: {is_close}")

Forward match: True
qs match: True
ks match: True
vs match: True
fgate_preacts match: True
igate_preacts match: True
