In [1]:
%load_ext autoreload
%autoreload 2

import torch 

from vlstm_full import vlstm_fw_torch, vlstm_fwbw_torch



# vLSTM forward backward (FULL version stabilized) implementation

In [2]:
DTYPE = torch.float32 
DEVICE = torch.device("cuda:0")

In [3]:
B = 1
S = 5
NH = 1
DH = 6
EPS = 0.0

In [4]:
igate_preacts = torch.randn((B, NH, S, 1), dtype=DTYPE, device=DEVICE)
fgate_preacts = torch.randn((B, NH, S, 1), dtype=DTYPE, device=DEVICE)

In [5]:
qs = torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE)
ks = torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE)
vs = torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE)
vs.shape

torch.Size([1, 1, 5, 6])

## Backward Stabilized without input & forget gate

### Torch Autograd

In [6]:
fgate_preacts_pt = fgate_preacts.clone().detach().requires_grad_(True)
igate_preacts_pt = igate_preacts.clone().detach().requires_grad_(True)
qs_pt = qs.clone().detach().requires_grad_(True)
ks_pt = ks.clone().detach().requires_grad_(True)
vs_pt = vs.clone().detach().requires_grad_(True)

In [7]:
igate_preacts_pt, fgate_preacts_pt

(tensor([[[[ 1.7118],
           [-1.1365],
           [-0.4029],
           [-0.8879],
           [ 0.3948]]]], device='cuda:0', requires_grad=True),
 tensor([[[[ 1.2154],
           [ 0.9013],
           [-0.5247],
           [-1.4010],
           [ 0.8880]]]], device='cuda:0', requires_grad=True))

In [8]:
retr_val_pt = vlstm_fw_torch(qs_pt, ks_pt, vs_pt, igate_preacts_pt, fgate_preacts_pt, eps=EPS)
retr_val_pt.shape

torch.Size([1, 1, 5, 6])

In [9]:
retr_val_pt.sum().backward()

In [10]:
qs_pt.grad

tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.1298,  0.0637,  0.2674,  0.1264,  0.3614,  0.0627],
          [ 0.1513, -1.3366, -0.0301,  0.4553,  0.9171,  0.1421],
          [-0.1384,  0.4530, -0.0536,  0.0732,  0.5131,  0.4276],
          [-0.0937, -0.0493,  0.1340,  0.2466,  0.0728,  0.0415]]]],
       device='cuda:0')

In [11]:
ks_pt.grad

tensor([[[[-0.0134,  2.3862,  5.4911,  1.7537,  1.6089,  3.4426],
          [-0.5108,  0.5511,  0.6865,  0.5661,  0.5641,  0.7388],
          [-0.6127,  1.5099,  3.6775,  1.1587,  1.3098,  2.6819],
          [ 0.2164,  0.2657,  0.2599,  0.0459, -0.1289,  0.0800],
          [-0.0444,  0.0211,  0.0079, -0.0196, -0.0178,  0.0470]]]],
       device='cuda:0')

In [12]:
vs_pt.grad

tensor([[[[-4.3285, -4.3285, -4.3285, -4.3285, -4.3285, -4.3285],
          [ 0.1631,  0.1631,  0.1631,  0.1631,  0.1631,  0.1631],
          [ 1.2019,  1.2019,  1.2019,  1.2019,  1.2019,  1.2019],
          [-0.3575, -0.3575, -0.3575, -0.3575, -0.3575, -0.3575],
          [-0.8331, -0.8331, -0.8331, -0.8331, -0.8331, -0.8331]]]],
       device='cuda:0')

In [13]:
fgate_preacts_pt.grad

tensor([[[[ 2.1815e-07],
          [ 2.8293e+00],
          [ 4.8198e+00],
          [ 1.0089e-01],
          [-1.9572e-02]]]], device='cuda:0')

In [14]:
igate_preacts_pt.grad

tensor([[[[ 9.7972],
          [-2.1254],
          [-7.5460],
          [-0.1710],
          [ 0.0671]]]], device='cuda:0')

### Own backward

In [15]:
fgate_preacts_obw = fgate_preacts.clone().detach().requires_grad_(True)
igate_preacts_obw = igate_preacts.clone().detach().requires_grad_(True)
qs_obw = qs.clone().detach().requires_grad_(True)
ks_obw = ks.clone().detach().requires_grad_(True)
vs_obw = vs.clone().detach().requires_grad_(True)

In [16]:
retr_val_obw = vlstm_fwbw_torch(
    qs_obw, ks_obw, vs_obw, igate_preacts_obw, fgate_preacts_obw, eps=EPS
)
retr_val_obw.shape

torch.Size([1, 1, 5, 6])

In [17]:
retr_val_obw.sum().backward()

In [18]:
qs_obw.grad-qs_pt.grad

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 1.1921e-07, -2.3842e-07, -1.0431e-07, -2.3842e-07, -5.9605e-08,
           -2.3842e-07],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [-5.9605e-08, -8.5682e-08,  2.9802e-08,  1.0431e-07, -7.4506e-09,
           -6.7055e-08]]]], device='cuda:0')

In [19]:
ks_obw.grad-ks_pt.grad

tensor([[[[-1.8161e-07,  4.7684e-07,  4.7684e-07,  3.5763e-07,  4.7684e-07,
            7.1526e-07],
          [ 0.0000e+00,  5.9605e-08,  2.3842e-07,  0.0000e+00,  5.9605e-08,
            5.9605e-08],
          [-5.9605e-08,  2.3842e-07,  9.5367e-07,  0.0000e+00,  2.3842e-07,
            7.1526e-07],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -1.1176e-08, -1.4901e-08,
            2.9802e-08],
          [-7.8231e-08,  3.9116e-08,  1.4901e-08, -3.7253e-08, -3.1665e-08,
            8.9407e-08]]]], device='cuda:0')

In [20]:
vs_obw.grad-vs_pt.grad

tensor([[[[0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.]]]], device='cuda:0')

In [21]:
fgate_preacts_obw.grad-fgate_preacts_pt.grad

tensor([[[[-2.1815e-07],
          [ 0.0000e+00],
          [-4.7684e-07],
          [ 7.4506e-09],
          [ 1.3039e-08]]]], device='cuda:0')

In [22]:
igate_preacts_obw.grad-igate_preacts_pt.grad

tensor([[[[ 9.5367e-07],
          [-2.3842e-07],
          [-9.5367e-07],
          [-1.4901e-08],
          [-6.7055e-08]]]], device='cuda:0')

### Do gradients match? 

In [23]:
qs_pt.grad - qs_obw.grad

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [-1.1921e-07,  2.3842e-07,  1.0431e-07,  2.3842e-07,  5.9605e-08,
            2.3842e-07],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 5.9605e-08,  8.5682e-08, -2.9802e-08, -1.0431e-07,  7.4506e-09,
            6.7055e-08]]]], device='cuda:0')

In [24]:
atol = 1e-6
rtol = 1e-6
print(f"Forward match: {torch.allclose(retr_val_pt, retr_val_obw)}")
print(f"qs match: {torch.allclose(qs_pt.grad, qs_obw.grad, atol=atol, rtol=rtol)}")
print(f"ks match: {torch.allclose(ks_pt.grad, ks_obw.grad, atol=atol, rtol=rtol)}")
print(f"vs match: {torch.allclose(vs_pt.grad, vs_obw.grad, atol=atol, rtol=rtol)}")
print(f"fgate_preacts match: {torch.allclose(fgate_preacts_pt.grad, fgate_preacts_obw.grad, atol=atol, rtol=rtol)}")
print(f"igate_preacts match: {torch.allclose(igate_preacts_pt.grad, igate_preacts_obw.grad, atol=atol, rtol=rtol)}")

Forward match: True
qs match: True
ks match: True
vs match: True
fgate_preacts match: True
igate_preacts match: True
