In [1]:
%load_ext autoreload
%autoreload 2

import torch 

from vlstm_fw_bw_torch.vlstm_nogatematrices import vlstm_fw_nogatematrices_torch, vlstm_fw_prepare_gate_preacts, vlstm_fwbw_nogatematrices_torch
from vlstm_fw_bw_torch.vlstm_full import vlstm_fw_torch



# vLSTM forward backward implementation

In [2]:
DTYPE = torch.float32 
DEVICE = torch.device("cuda:0")

In [3]:
B = 1
S = 5
NH = 1
DH = 6
EPS = 0.0

In [4]:
igate_preacts = torch.randn((B, NH, S, 1), dtype=DTYPE, device=DEVICE)
fgate_preacts = torch.randn((B, NH, S, 1), dtype=DTYPE, device=DEVICE)

In [5]:
fgate_mat, igate_mat = vlstm_fw_prepare_gate_preacts(igate_preacts, fgate_preacts)
igate_mat.shape, fgate_mat.shape

(torch.Size([1, 1, 5, 5]), torch.Size([1, 1, 5, 5]))

In [6]:
fgate_mat

tensor([[[[ 0.0000,    -inf,    -inf,    -inf,    -inf],
          [-0.2479,  0.0000,    -inf,    -inf,    -inf],
          [-0.7124, -0.4645,  0.0000,    -inf,    -inf],
          [-2.1037, -1.8558, -1.3913,  0.0000,    -inf],
          [-4.6376, -4.3897, -3.9252, -2.5339,  0.0000]]]], device='cuda:0')

In [7]:
igate_mat

tensor([[[[1.2287,   -inf,   -inf,   -inf,   -inf],
          [1.2287, 1.9802,   -inf,   -inf,   -inf],
          [1.2287, 1.9802, 2.1222,   -inf,   -inf],
          [1.2287, 1.9802, 2.1222, 0.0921,   -inf],
          [1.2287, 1.9802, 2.1222, 0.0921, 0.9398]]]], device='cuda:0')

In [8]:
qs = torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE)
ks = torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE)
vs = torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE)
vs.shape

torch.Size([1, 1, 5, 6])

## Backward without input & forget gate

### Torch Autograd

In [9]:
fgate_mat_pt = fgate_mat.clone().detach().requires_grad_(True)
igate_mat_pt = igate_mat.clone().detach().requires_grad_(True)
qs_pt = qs.clone().detach().requires_grad_(True)
ks_pt = ks.clone().detach().requires_grad_(True)
vs_pt = vs.clone().detach().requires_grad_(True)

In [27]:
igate_mat_pt, fgate_mat_pt

(tensor([[[[-0.9360,    -inf,    -inf],
           [-0.9360, -0.2426,    -inf],
           [-0.9360, -0.2426, -1.8288]]]], device='cuda:0', requires_grad=True),
 tensor([[[[ 0.0000,    -inf,    -inf],
           [-1.5091,  0.0000,    -inf],
           [-2.3670, -0.8579,  0.0000]]]], device='cuda:0', requires_grad=True))

In [10]:
retr_val_pt = vlstm_fw_nogatematrices_torch(qs_pt, ks_pt, vs_pt, igate_mat_pt, fgate_mat_pt, eps=EPS)
retr_val_pt.shape

torch.Size([1, 1, 5, 6])

In [11]:
retr_val_pt.sum().backward()

In [12]:
qs_pt.grad

tensor([[[[-4.4769e-01,  2.0942e+00,  7.2811e-01,  2.6674e+00,  8.4051e-02,
           -2.7446e-01],
          [-1.3069e-01,  2.2627e-01,  2.6196e-01,  2.2561e-01, -1.0939e-01,
           -2.9525e-02],
          [-2.5098e-01,  5.7592e-01, -4.2947e-01, -1.4708e+00, -5.0830e-01,
           -4.9990e-01],
          [-3.0950e-02, -1.5508e-01,  5.2932e-01,  8.5641e-01, -6.5722e-01,
            1.5229e+00],
          [-1.3032e-02,  5.1077e-02, -1.8165e-05, -3.4160e-02,  1.0160e-03,
            2.0104e-02]]]], device='cuda:0')

In [13]:
ks_pt.grad

tensor([[[[ 3.4601, -4.2889,  7.8884, -0.3676,  3.9176,  2.0089],
          [ 0.2752, -1.1803, -0.5051, -0.6417,  0.1513, -0.2466],
          [ 0.3395, -1.8212, -1.4104, -0.6446, -0.3831,  0.1261],
          [ 0.5125, -1.5310, -1.5425, -0.5138, -1.8191, -0.0684],
          [-0.0185,  0.0261, -0.0306,  0.0285,  0.0112, -0.0305]]]],
       device='cuda:0')

In [14]:
vs_pt.grad

tensor([[[[-1.3292, -1.3292, -1.3292, -1.3292, -1.3292, -1.3292],
          [ 0.4179,  0.4179,  0.4179,  0.4179,  0.4179,  0.4179],
          [ 1.0969,  1.0969,  1.0969,  1.0969,  1.0969,  1.0969],
          [ 0.3466,  0.3466,  0.3466,  0.3466,  0.3466,  0.3466],
          [ 0.9754,  0.9754,  0.9754,  0.9754,  0.9754,  0.9754]]]],
       device='cuda:0')

In [15]:
fgate_mat_pt.grad

tensor([[[[ 1.1273,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.1510, -0.1510,  0.0000,  0.0000,  0.0000],
          [ 0.5477,  0.6955, -1.2433,  0.0000,  0.0000],
          [ 0.0618,  0.1297, -1.2120,  1.0205,  0.0000],
          [-0.0123,  0.0151, -0.0017,  0.0481, -0.0492]]]], device='cuda:0')

In [16]:
igate_mat_pt.grad

tensor([[[[ 1.1273,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.1510, -0.1510,  0.0000,  0.0000,  0.0000],
          [ 0.5477,  0.6955, -1.2433,  0.0000,  0.0000],
          [ 0.0618,  0.1297, -1.2120,  1.0205,  0.0000],
          [-0.0123,  0.0151, -0.0017,  0.0481, -0.0492]]]], device='cuda:0')

### Own backward

In [17]:
fgate_mat_obw = fgate_mat.clone().detach().requires_grad_(True)
igate_mat_obw = igate_mat.clone().detach().requires_grad_(True)
qs_obw = qs.clone().detach().requires_grad_(True)
ks_obw = ks.clone().detach().requires_grad_(True)
vs_obw = vs.clone().detach().requires_grad_(True)

In [18]:
retr_val_obw = vlstm_fwbw_nogatematrices_torch(
    qs_obw, ks_obw, vs_obw, igate_mat_obw, fgate_mat_obw, eps=EPS
)
retr_val_obw.shape

torch.Size([1, 1, 5, 6])

In [19]:
retr_val_obw.sum().backward()

In [20]:
qs_obw.grad-qs_pt.grad

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 8.6222e+00, -8.1084e+00, -1.8159e+01, -5.0896e+00,  9.5891e+00,
            1.0518e+00],
          [ 3.1541e+00, -7.8381e+00,  8.4857e+00,  2.4560e+01,  7.3230e+00,
            7.7596e+00],
          [-3.3665e-01,  1.0360e+00, -1.4654e+00, -3.6361e+00, -3.3282e-01,
           -2.1435e+00],
          [ 2.6664e-03,  1.3476e-02,  1.6184e-02,  4.0560e-02,  4.7088e-02,
           -2.8583e-03]]]], device='cuda:0')

In [21]:
ks_obw.grad-ks_pt.grad

tensor([[[[-3.8407e+00,  6.8097e+00, -8.1687e+00,  1.0264e+01, -8.6000e+00,
            1.1786e+01],
          [-9.0823e+00,  1.6193e+01, -1.9000e+01,  2.4178e+01, -2.0007e+01,
            2.7693e+01],
          [-3.2797e+00,  2.1980e+01,  1.4203e+01,  8.3262e+00, -2.5502e+00,
           -2.9804e+00],
          [-4.1469e-01,  1.2331e+00,  1.2312e+00,  4.1770e-01,  1.4612e+00,
            4.8931e-02],
          [-1.3357e-02,  1.8911e-02, -2.2170e-02,  2.0651e-02,  8.1031e-03,
           -2.2044e-02]]]], device='cuda:0')

In [22]:
vs_obw.grad-vs_pt.grad

tensor([[[[0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0.]]]], device='cuda:0')

In [23]:
fgate_mat_obw.grad-fgate_mat_pt.grad

tensor([[[[ 0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00],
          [-4.2498e+00,  1.2641e+01,  0.0000e+00, -0.0000e+00,  0.0000e+00],
          [-5.6704e+00, -9.0899e+00,  1.9719e+01, -0.0000e+00, -0.0000e+00],
          [-5.6507e-01, -6.1461e-01,  4.5689e+00, -8.1855e-01, -0.0000e+00],
          [ 1.0502e-02, -1.2583e-02,  1.3721e-03, -4.6752e-02, -3.5617e-02]]]],
       device='cuda:0')

### Do gradients match? 

In [24]:
qs_pt.grad - qs_obw.grad

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [-8.6222e+00,  8.1084e+00,  1.8159e+01,  5.0896e+00, -9.5891e+00,
           -1.0518e+00],
          [-3.1541e+00,  7.8381e+00, -8.4857e+00, -2.4560e+01, -7.3230e+00,
           -7.7596e+00],
          [ 3.3665e-01, -1.0360e+00,  1.4654e+00,  3.6361e+00,  3.3282e-01,
            2.1435e+00],
          [-2.6664e-03, -1.3476e-02, -1.6184e-02, -4.0560e-02, -4.7088e-02,
            2.8583e-03]]]], device='cuda:0')

In [25]:
atol = 1e-6
rtol = 1e-6
print(f"Forward match: {torch.allclose(retr_val_pt, retr_val_obw)}")
print(f"qs match: {torch.allclose(qs_pt.grad, qs_obw.grad, atol=atol, rtol=rtol)}")
print(f"ks match: {torch.allclose(ks_pt.grad, ks_obw.grad, atol=atol, rtol=rtol)}")
print(f"vs match: {torch.allclose(vs_pt.grad, vs_obw.grad, atol=atol, rtol=rtol)}")
print(f"fgate_mat match: {torch.allclose(fgate_mat_pt.grad, fgate_mat_obw.grad, atol=atol, rtol=rtol)}")
print(f"igate_mat match: {torch.allclose(igate_mat_pt.grad, igate_mat_obw.grad, atol=atol, rtol=rtol)}")

Forward match: True
qs match: False
ks match: False
vs match: True
fgate_mat match: False
igate_mat match: False


### DEBUG

## Forward without input & forget gate

In [9]:
retr_vals = vlstm_fw_nogatematrices_torch(queries=qs, keys=ks, values=vs, igate_preact=igate_mat, fgate_preact=fgate_mat)
retr_vals.shape

torch.Size([1, 1, 3, 4])

In [10]:
retr_vals_fwbw = vlstm_fwbw_nogatematrices_torch(queries=qs, keys=ks, values=vs, igate_preact=igate_mat, fgate_preact=fgate_mat)
retr_vals_fwbw.shape

torch.Size([1, 1, 3, 4])

### Check if it equals the full version:

In [11]:
# check if equals the full version
retr_vals_full = vlstm_fw_torch(queries=qs, keys=ks, values=vs, igate_preact=igate_preacts, fgate_preact=fgate_preacts)
retr_vals_full.shape

torch.Size([1, 1, 3, 4])

In [12]:
# The implementations match!!!
retr_vals - retr_vals_full

tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]], device='cuda:0')

In [13]:
retr_vals_fwbw - retr_vals_full

tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]], device='cuda:0')