In [1]:
%load_ext autoreload
%autoreload 2

import torch 

from vlstm_Dmat import vlstm_fw_Dtildemat, vlstm_fwbw_Dtildemat

# vLSTM Dmat forward backward implementation

In this notebook we implement the forward and backward pass of the D (decay matrix) construction.

In [2]:
DTYPE = torch.float32 
DEVICE = torch.device("cuda:0")

In [3]:
B = 2
S = 5
NH = 4
DH = 6
EPS = 0.0

In [4]:
igate_preacts = torch.randn((B, NH, S, 1), dtype=DTYPE, device=DEVICE)
fgate_preacts = torch.randn((B, NH, S, 1), dtype=DTYPE, device=DEVICE)
# fgate_preacts = 0.1*torch.ones((B, NH, S, 1), dtype=DTYPE, device=DEVICE)


## Forward / Backward impl

### Torch autograd

In [5]:
igate_preacts_pt = igate_preacts.clone().detach().requires_grad_(True)
fgate_preacts_pt = fgate_preacts.clone().detach().requires_grad_(True)
igate_preacts_pt, fgate_preacts_pt

(tensor([[[[ 1.6217],
           [-1.1963],
           [ 1.2651],
           [-1.0511],
           [ 0.8248]],
 
          [[ 1.7687],
           [-0.2273],
           [-0.3192],
           [-0.1224],
           [ 0.5716]],
 
          [[ 1.7480],
           [ 1.5525],
           [-0.0406],
           [ 1.0113],
           [ 1.5455]],
 
          [[ 2.1709],
           [-0.0638],
           [ 0.3106],
           [-0.2125],
           [ 0.0994]]],
 
 
         [[[ 0.7163],
           [-0.5652],
           [ 0.1756],
           [-0.3835],
           [ 0.2200]],
 
          [[-0.3686],
           [-1.3597],
           [ 0.2573],
           [ 0.0560],
           [ 0.3304]],
 
          [[-0.1345],
           [ 0.9675],
           [ 0.8923],
           [ 0.5020],
           [ 0.1634]],
 
          [[-0.3290],
           [ 0.5576],
           [ 0.7773],
           [-1.5742],
           [ 0.2267]]]], device='cuda:0', requires_grad=True),
 tensor([[[[ 0.0124],
           [-1.0455],
           

In [6]:
dmat_pt = vlstm_fw_Dtildemat(igate_preacts_pt, fgate_preacts_pt)
dmat_pt, dmat_pt.shape

(tensor([[[[ 1.6217,    -inf,    -inf,    -inf,    -inf],
           [ 0.2749, -1.1963,    -inf,    -inf,    -inf],
           [-0.4557, -1.9268,  1.2651,    -inf,    -inf],
           [-1.1081, -2.5793,  0.6127, -1.0511,    -inf],
           [-2.1250, -3.5962, -0.4043, -2.0680,  0.8248]],
 
          [[ 1.7687,    -inf,    -inf,    -inf,    -inf],
           [ 0.8717, -0.2273,    -inf,    -inf,    -inf],
           [ 0.4286, -0.6704, -0.3192,    -inf,    -inf],
           [-0.4977, -1.5967, -1.2455, -0.1224,    -inf],
           [-0.6990, -1.7980, -1.4468, -0.3237,  0.5716]],
 
          [[ 1.7480,    -inf,    -inf,    -inf,    -inf],
           [ 1.4963,  1.5525,    -inf,    -inf,    -inf],
           [ 1.0670,  1.1231, -0.0406,    -inf,    -inf],
           [ 0.6663,  0.7224, -0.4413,  1.0113,    -inf],
           [ 0.4371,  0.4932, -0.6705,  0.7821,  1.5455]],
 
          [[ 2.1709,    -inf,    -inf,    -inf,    -inf],
           [ 2.0383, -0.0638,    -inf,    -inf,    -inf],
     

In [7]:
dmat_pt.sum()

tensor(-inf, device='cuda:0', grad_fn=<SumBackward0>)

In [8]:
dmat_pt.sum().backward()

In [9]:
igate_preacts_pt.grad, fgate_preacts_pt.grad

(tensor([[[[5.],
           [5.],
           [5.],
           [5.],
           [5.]],
 
          [[5.],
           [5.],
           [5.],
           [5.],
           [5.]],
 
          [[5.],
           [5.],
           [5.],
           [5.],
           [5.]],
 
          [[5.],
           [5.],
           [5.],
           [5.],
           [5.]]],
 
 
         [[[5.],
           [5.],
           [5.],
           [5.],
           [5.]],
 
          [[5.],
           [5.],
           [5.],
           [5.],
           [5.]],
 
          [[5.],
           [5.],
           [5.],
           [5.],
           [5.]],
 
          [[5.],
           [5.],
           [5.],
           [5.],
           [5.]]]], device='cuda:0'),
 tensor([[[[0.0000],
           [2.9597],
           [3.1102],
           [2.8753],
           [2.5532]],
 
          [[0.0000],
           [2.3688],
           [2.1476],
           [3.6239],
           [0.7294]],
 
          [[0.0000],
           [0.8900],
           [2.094

### Own backward

In [10]:
igate_preacts_obw = igate_preacts.clone().detach().requires_grad_(True)
fgate_preacts_obw = fgate_preacts.clone().detach().requires_grad_(True)
igate_preacts_obw, igate_preacts_obw.shape, fgate_preacts_obw, fgate_preacts_obw.shape

(tensor([[[[ 1.6217],
           [-1.1963],
           [ 1.2651],
           [-1.0511],
           [ 0.8248]],
 
          [[ 1.7687],
           [-0.2273],
           [-0.3192],
           [-0.1224],
           [ 0.5716]],
 
          [[ 1.7480],
           [ 1.5525],
           [-0.0406],
           [ 1.0113],
           [ 1.5455]],
 
          [[ 2.1709],
           [-0.0638],
           [ 0.3106],
           [-0.2125],
           [ 0.0994]]],
 
 
         [[[ 0.7163],
           [-0.5652],
           [ 0.1756],
           [-0.3835],
           [ 0.2200]],
 
          [[-0.3686],
           [-1.3597],
           [ 0.2573],
           [ 0.0560],
           [ 0.3304]],
 
          [[-0.1345],
           [ 0.9675],
           [ 0.8923],
           [ 0.5020],
           [ 0.1634]],
 
          [[-0.3290],
           [ 0.5576],
           [ 0.7773],
           [-1.5742],
           [ 0.2267]]]], device='cuda:0', requires_grad=True),
 torch.Size([2, 4, 5, 1]),
 tensor([[[[ 0.0124],
      

In [11]:
dmat_obw = vlstm_fwbw_Dtildemat(igate_preacts_obw, fgate_preacts_obw)
dmat_obw, dmat_obw.shape

(tensor([[[[ 1.6217,    -inf,    -inf,    -inf,    -inf],
           [ 0.2749, -1.1963,    -inf,    -inf,    -inf],
           [-0.4557, -1.9268,  1.2651,    -inf,    -inf],
           [-1.1081, -2.5793,  0.6127, -1.0511,    -inf],
           [-2.1250, -3.5962, -0.4043, -2.0680,  0.8248]],
 
          [[ 1.7687,    -inf,    -inf,    -inf,    -inf],
           [ 0.8717, -0.2273,    -inf,    -inf,    -inf],
           [ 0.4286, -0.6704, -0.3192,    -inf,    -inf],
           [-0.4977, -1.5967, -1.2455, -0.1224,    -inf],
           [-0.6990, -1.7980, -1.4468, -0.3237,  0.5716]],
 
          [[ 1.7480,    -inf,    -inf,    -inf,    -inf],
           [ 1.4963,  1.5525,    -inf,    -inf,    -inf],
           [ 1.0670,  1.1231, -0.0406,    -inf,    -inf],
           [ 0.6663,  0.7224, -0.4413,  1.0113,    -inf],
           [ 0.4371,  0.4932, -0.6705,  0.7821,  1.5455]],
 
          [[ 2.1709,    -inf,    -inf,    -inf,    -inf],
           [ 2.0383, -0.0638,    -inf,    -inf,    -inf],
     

In [12]:
dmat_obw - dmat_pt

tensor([[[[0., nan, nan, nan, nan],
          [0., 0., nan, nan, nan],
          [0., 0., 0., nan, nan],
          [0., 0., 0., 0., nan],
          [0., 0., 0., 0., 0.]],

         [[0., nan, nan, nan, nan],
          [0., 0., nan, nan, nan],
          [0., 0., 0., nan, nan],
          [0., 0., 0., 0., nan],
          [0., 0., 0., 0., 0.]],

         [[0., nan, nan, nan, nan],
          [0., 0., nan, nan, nan],
          [0., 0., 0., nan, nan],
          [0., 0., 0., 0., nan],
          [0., 0., 0., 0., 0.]],

         [[0., nan, nan, nan, nan],
          [0., 0., nan, nan, nan],
          [0., 0., 0., nan, nan],
          [0., 0., 0., 0., nan],
          [0., 0., 0., 0., 0.]]],


        [[[0., nan, nan, nan, nan],
          [0., 0., nan, nan, nan],
          [0., 0., 0., nan, nan],
          [0., 0., 0., 0., nan],
          [0., 0., 0., 0., 0.]],

         [[0., nan, nan, nan, nan],
          [0., 0., nan, nan, nan],
          [0., 0., 0., nan, nan],
          [0., 0., 0., 0., nan],


In [13]:
dmat_obw.sum().backward()

In [14]:
# igate_preacts_obw.grad, fgate_preacts_obw.grad

In [15]:
atol = 1e-6
rtol = 1e-6
print(f"Forward match: {torch.allclose(dmat_obw, dmat_pt, atol=atol, rtol=rtol)}")
print(f"igate preact match: {torch.allclose(igate_preacts_obw.grad, igate_preacts_pt.grad, atol=atol, rtol=rtol)}")
print(f"fgate pract match: {torch.allclose(fgate_preacts_obw.grad, fgate_preacts_pt.grad, atol=atol, rtol=rtol)}")

Forward match: True
igate preact match: True
fgate pract match: True


In [16]:
igate_preacts_obw.grad - igate_preacts_pt.grad, fgate_preacts_obw.grad - fgate_preacts_pt.grad

(tensor([[[[0.],
           [0.],
           [0.],
           [0.],
           [0.]],
 
          [[0.],
           [0.],
           [0.],
           [0.],
           [0.]],
 
          [[0.],
           [0.],
           [0.],
           [0.],
           [0.]],
 
          [[0.],
           [0.],
           [0.],
           [0.],
           [0.]]],
 
 
         [[[0.],
           [0.],
           [0.],
           [0.],
           [0.]],
 
          [[0.],
           [0.],
           [0.],
           [0.],
           [0.]],
 
          [[0.],
           [0.],
           [0.],
           [0.],
           [0.]],
 
          [[0.],
           [0.],
           [0.],
           [0.],
           [0.]]]], device='cuda:0'),
 tensor([[[[ 0.0000e+00],
           [ 0.0000e+00],
           [ 0.0000e+00],
           [-2.3842e-07],
           [ 2.3842e-07]],
 
          [[ 0.0000e+00],
           [-2.3842e-07],
           [ 0.0000e+00],
           [ 0.0000e+00],
           [ 0.0000e+00]],
 
         

### Playground

In [17]:
masked_grad_dtilde_mat = torch.tril(torch.ones((S, S), dtype=DTYPE, device=DEVICE), diagonal=-1).unsqueeze(0).unsqueeze(0)

In [18]:
delta_fbar = torch.zeros((B, NH, S, 1), device=DEVICE, dtype=DTYPE)
print(masked_grad_dtilde_mat, masked_grad_dtilde_mat.shape)

# first forget gate index (k=0) does not get a gradient (since it is not used in the forward pass)
for k in range(1, S):
    for j in range(k):
        delta_fbar[:, :, k, 0] += masked_grad_dtilde_mat[:, :, :, j].view(B, NH, -1).sum()
delta_fbar

tensor([[[[0., 0., 0., 0., 0.],
          [1., 0., 0., 0., 0.],
          [1., 1., 0., 0., 0.],
          [1., 1., 1., 0., 0.],
          [1., 1., 1., 1., 0.]]]], device='cuda:0') torch.Size([1, 1, 5, 5])


RuntimeError: shape '[2, 4, -1]' is invalid for input of size 5

In [None]:
masked_grad_dtilde_mat[:, :, :, j].view(B, NH, -1).sum()

tensor(1., device='cuda:0')

In [None]:
dmat_cs = masked_grad_dtilde_mat.cumsum(dim=-1)
dmat_cs

tensor([[[[  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  1.,   1.,   1.,  ...,   1.,   1.,   1.],
          [  1.,   2.,   2.,  ...,   2.,   2.,   2.],
          ...,
          [  1.,   2.,   3.,  ..., 125., 125., 125.],
          [  1.,   2.,   3.,  ..., 126., 126., 126.],
          [  1.,   2.,   3.,  ..., 126., 127., 127.]]]], device='cuda:0')

In [None]:
res = dmat_cs.cumsum(dim=-2)
res, res.shape

(tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
            0.0000e+00, 0.0000e+00],
           [1.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 1.0000e+00,
            1.0000e+00, 1.0000e+00],
           [2.0000e+00, 3.0000e+00, 3.0000e+00,  ..., 3.0000e+00,
            3.0000e+00, 3.0000e+00],
           ...,
           [1.2500e+02, 2.4900e+02, 3.7200e+02,  ..., 7.8750e+03,
            7.8750e+03, 7.8750e+03],
           [1.2600e+02, 2.5100e+02, 3.7500e+02,  ..., 8.0010e+03,
            8.0010e+03, 8.0010e+03],
           [1.2700e+02, 2.5300e+02, 3.7800e+02,  ..., 8.1270e+03,
            8.1280e+03, 8.1280e+03]]]], device='cuda:0'),
 torch.Size([1, 1, 128, 128]))

In [None]:
res[:, :, -1, :-1].shape

torch.Size([1, 1, 127])