In [1]:
%load_ext autoreload
%autoreload 2

import torch 

from vlstm_full import vlstm_fw_torch, vlstm_fwbw_torch

# vLSTM forward and backward pass (full, stabilized version) implementation

In [2]:
# Fix the RNG state so the random inputs drawn further down are identical on
# every "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

# Double precision keeps round-off far below the 1e-8 tolerances used in the
# final comparison of the two implementations.
DTYPE = torch.float64
# Prefer the first GPU, but fall back to the CPU so the notebook also runs on
# machines without CUDA.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Problem size for the check: q/k/v are created with shape (B, NH, S, DH) and
# the gate pre-activations with shape (B, NH, S, 1).
# Presumably B = batch size, NH = number of heads, S = sequence length and
# DH = head dimension -- TODO confirm against `vlstm_full`.
B, NH, S, DH = 1, 1, 5, 6
# Forwarded as `eps=` to both vLSTM implementations.
EPS = 0.0

In [4]:
# Random input-gate and forget-gate pre-activations of shape (B, NH, S, 1).
# The input gate is drawn first, then the forget gate.
igate_preacts, fgate_preacts = (
    torch.randn(B, NH, S, 1, dtype=DTYPE, device=DEVICE) for _ in range(2)
)

In [5]:
# Random qs, ks and vs (drawn in that order), each of shape (B, NH, S, DH).
qs, ks, vs = (
    torch.randn(B, NH, S, DH, dtype=DTYPE, device=DEVICE) for _ in range(3)
)
# Display the shape as a quick sanity check.
vs.shape

torch.Size([1, 1, 5, 6])

## Backward Stabilized with input & forget gate

### Torch Autograd

In [6]:
# Independent leaf copies of every input for the autograd reference run, so
# the gradients accumulate on tensors that are not shared with the second run.
fgate_preacts_pt, igate_preacts_pt, qs_pt, ks_pt, vs_pt = (
    source.clone().detach().requires_grad_(True)
    for source in (fgate_preacts, igate_preacts, qs, ks, vs)
)

In [7]:
# Show the gate pre-activation leaves that are fed to the reference run.
igate_preacts_pt, fgate_preacts_pt

(tensor([[[[ 0.4533],
           [-1.9751],
           [ 1.4485],
           [ 0.3025],
           [ 0.4049]]]], device='cuda:0', dtype=torch.float64,
        requires_grad=True),
 tensor([[[[ 0.2243],
           [ 0.4543],
           [-1.0214],
           [ 0.3320],
           [-0.3535]]]], device='cuda:0', dtype=torch.float64,
        requires_grad=True))

In [8]:
# Reference forward pass; its backward is left to torch autograd.
retr_val_pt = vlstm_fw_torch(
    qs_pt, ks_pt, vs_pt, igate_preacts_pt, fgate_preacts_pt, eps=EPS
)
# Display the output shape as a sanity check.
retr_val_pt.shape

torch.Size([1, 1, 5, 6])

In [9]:
# Reduce the output to a scalar (every output element gets upstream gradient 1)
# and backpropagate to populate `.grad` on the `_pt` leaves.
torch.sum(retr_val_pt).backward()

In [10]:
# Autograd reference gradient w.r.t. `qs_pt`.
qs_pt.grad

tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.1928, -0.0109, -0.1456, -0.1519,  0.1054,  0.0070],
          [-0.9843,  2.2921,  0.7373, -0.9059, -0.5809,  0.2165],
          [-0.0174,  0.0697, -0.0138,  0.0112,  0.0199,  0.0036],
          [ 0.4119,  1.2293,  0.0646,  1.7051, -1.3909,  0.6109]]]],
       device='cuda:0', dtype=torch.float64)

In [11]:
# Autograd reference gradient w.r.t. `ks_pt`.
ks_pt.grad

tensor([[[[-0.0760,  0.1124,  0.0060,  0.4345, -0.2662, -0.0788],
          [ 0.2154,  0.0321,  0.1297,  0.0298, -0.1152, -0.1144],
          [-0.6233,  0.5338, -0.1448,  1.8454, -1.1918,  0.0946],
          [-0.0438,  0.0629, -0.0534, -0.1137,  0.1179,  0.2107],
          [-0.3017,  0.4579, -0.3569, -0.5725,  0.8650,  1.2078]]]],
       device='cuda:0', dtype=torch.float64)

In [12]:
# Autograd reference gradient w.r.t. `vs_pt`.
vs_pt.grad

tensor([[[[ 2.1425,  2.1425,  2.1425,  2.1425,  2.1425,  2.1425],
          [ 0.0120,  0.0120,  0.0120,  0.0120,  0.0120,  0.0120],
          [-0.0269, -0.0269, -0.0269, -0.0269, -0.0269, -0.0269],
          [-0.4831, -0.4831, -0.4831, -0.4831, -0.4831, -0.4831],
          [-0.8412, -0.8412, -0.8412, -0.8412, -0.8412, -0.8412]]]],
       device='cuda:0', dtype=torch.float64)

In [13]:
# Autograd reference gradient w.r.t. the forget-gate pre-activations.
fgate_preacts_pt.grad

tensor([[[[0.0000],
          [0.0769],
          [0.1946],
          [0.1068],
          [0.2175]]]], device='cuda:0', dtype=torch.float64)

In [14]:
# Autograd reference gradient w.r.t. the input-gate pre-activations.
igate_preacts_pt.grad

tensor([[[[ 0.1980],
          [ 0.0667],
          [ 0.2788],
          [ 0.1147],
          [-1.3870]]]], device='cuda:0', dtype=torch.float64)

### Own backward

In [15]:
# Second, independent set of leaf copies for the run that uses the
# hand-written backward ("obw" = own backward).
fgate_preacts_obw, igate_preacts_obw, qs_obw, ks_obw, vs_obw = [
    original.clone().detach().requires_grad_(True)
    for original in (fgate_preacts, igate_preacts, qs, ks, vs)
]

In [16]:
# Forward pass through `vlstm_fwbw_torch`, called with the same argument
# order and `eps` as the reference run.
retr_val_obw = vlstm_fwbw_torch(
    qs_obw,
    ks_obw,
    vs_obw,
    igate_preacts_obw,
    fgate_preacts_obw,
    eps=EPS,
)
# Display the output shape as a sanity check.
retr_val_obw.shape

torch.Size([1, 1, 5, 6])

In [17]:
# Same scalar reduction as for the reference run, so both sets of `.grad`
# tensors are directly comparable.
torch.sum(retr_val_obw).backward()

In [18]:
# Element-wise difference (own backward minus autograd) for `qs`; in float64
# only round-off-sized entries (~1e-16) are expected.
qs_obw.grad-qs_pt.grad

tensor([[[[-3.0528e-16,  1.0023e-16, -9.9024e-17,  2.2021e-17, -4.6745e-18,
            1.1148e-16],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00]]]], device='cuda:0', dtype=torch.float64)

In [19]:
# Element-wise difference (own backward minus autograd) for `ks`.
ks_obw.grad - ks_pt.grad

tensor([[[[-2.2204e-16,  5.5511e-17,  1.8475e-16,  0.0000e+00, -1.1102e-16,
            1.5266e-16],
          [ 0.0000e+00,  6.9389e-18, -2.7756e-17,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  1.1102e-16,  0.0000e+00,  4.4409e-16,  0.0000e+00,
           -4.1633e-17],
          [ 0.0000e+00,  1.3878e-17,  0.0000e+00,  1.3878e-17,  0.0000e+00,
            0.0000e+00],
          [ 5.5511e-17,  5.5511e-17,  5.5511e-17,  1.1102e-16,  0.0000e+00,
            0.0000e+00]]]], device='cuda:0', dtype=torch.float64)

In [20]:
# Element-wise difference (own backward minus autograd) for `vs`.
vs_obw.grad-vs_pt.grad

tensor([[[[0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.]]]], device='cuda:0', dtype=torch.float64)

In [21]:
# Element-wise difference (own backward minus autograd) for the forget-gate
# pre-activations.
fgate_preacts_obw.grad-fgate_preacts_pt.grad

tensor([[[[ 0.0000e+00],
          [-2.7756e-17],
          [-2.7756e-17],
          [ 5.5511e-17],
          [-8.3267e-17]]]], device='cuda:0', dtype=torch.float64)

In [22]:
# Element-wise difference (own backward minus autograd) for the input-gate
# pre-activations.
igate_preacts_obw.grad-igate_preacts_pt.grad

tensor([[[[3.8858e-16],
          [0.0000e+00],
          [2.2204e-16],
          [0.0000e+00],
          [2.2204e-16]]]], device='cuda:0', dtype=torch.float64)

### Do gradients match? 

In [23]:
# Autograd minus own backward for `qs` (operands swapped relative to the
# `qs_obw.grad-qs_pt.grad` cell, so this is its negation).
qs_pt.grad - qs_obw.grad

tensor([[[[ 3.0528e-16, -1.0023e-16,  9.9024e-17, -2.2021e-17,  4.6745e-18,
           -1.1148e-16],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00]]]], device='cuda:0', dtype=torch.float64)

In [25]:
# Tolerances shared by every comparison below. They are passed explicitly for
# the forward check too, because torch.allclose otherwise falls back to its
# much looser default rtol=1e-5.
atol = 1e-8
rtol = 1e-8

# (label, autograd reference, own forward/backward implementation)
comparisons = [
    ("Forward", retr_val_pt, retr_val_obw),
    ("qs", qs_pt.grad, qs_obw.grad),
    ("ks", ks_pt.grad, ks_obw.grad),
    ("vs", vs_pt.grad, vs_obw.grad),
    ("fgate_preacts", fgate_preacts_pt.grad, fgate_preacts_obw.grad),
    ("igate_preacts", igate_preacts_pt.grad, igate_preacts_obw.grad),
]

mismatches = []
for label, reference, own in comparisons:
    is_close = torch.allclose(reference, own, atol=atol, rtol=rtol)
    print(f"{label} match: {is_close}")
    if not is_close:
        mismatches.append(label)

# Fail loudly so that "Restart & Run All" cannot pass with a broken backward.
assert not mismatches, f"Mismatch beyond atol={atol}, rtol={rtol}: {mismatches}"

Forward match: True
qs match: True
ks match: True
vs match: True
fgate_preacts match: True
igate_preacts match: True
