In [10]:
%load_ext autoreload
%autoreload 2

import torch 

from vlstm_Dmat import vlstm_fw_Dtildemat, vlstm_fwbw_Dtildemat

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# vLSTM Dmat forward backward implementation

In this notebook we implement the forward and backward pass of the D (decay matrix) construction.

In [2]:
DTYPE = torch.float32 
DEVICE = torch.device("cuda:0")

In [3]:
B = 1
S = 5
NH = 1
DH = 6
EPS = 0.0

In [4]:
igate_preacts = torch.randn((B, NH, S, 1), dtype=DTYPE, device=DEVICE)
fgate_preacts = torch.randn((B, NH, S, 1), dtype=DTYPE, device=DEVICE)

## Forward / Backward impl

### Torch autograd

In [5]:
igate_preacts_pt = igate_preacts.clone().detach().requires_grad_(True)
fgate_preacts_pt = fgate_preacts.clone().detach().requires_grad_(True)

In [6]:
dmat_pt = vlstm_fw_Dtildemat(igate_preacts_pt, fgate_preacts_pt)
dmat_pt, dmat_pt.shape

(tensor([[[[-0.5286,    -inf,    -inf,    -inf,    -inf],
           [-0.5716,  0.2702,    -inf,    -inf,    -inf],
           [-0.8562, -0.0143,  0.2460,    -inf,    -inf],
           [-1.7083, -0.8665, -0.6061, -0.8402,    -inf],
           [-2.0340, -1.1922, -0.9319, -1.1659,  1.7142]]]], device='cuda:0',
        grad_fn=<AddBackward0>),
 torch.Size([1, 1, 5, 5]))

In [7]:
dmat_pt.sum().backward()

In [8]:
igate_preacts_pt.grad, fgate_preacts_pt.grad

(tensor([[[[5.],
           [5.],
           [5.],
           [5.],
           [5.]]]], device='cuda:0'),
 tensor([[[[0.0000],
           [0.1684],
           [1.4858],
           [3.4411],
           [1.1120]]]], device='cuda:0'))

### Own backward

In [12]:
igate_preacts_obw = igate_preacts.clone().detach().requires_grad_(True)
fgate_preacts_obw = fgate_preacts.clone().detach().requires_grad_(True)

In [13]:
dmat_obw = vlstm_fwbw_Dtildemat(igate_preacts_obw, fgate_preacts_obw)
dmat_obw, dmat_obw.shape

(tensor([[[[-0.5286,    -inf,    -inf,    -inf,    -inf],
           [-0.5716,  0.2702,    -inf,    -inf,    -inf],
           [-0.8562, -0.0143,  0.2460,    -inf,    -inf],
           [-1.7083, -0.8665, -0.6061, -0.8402,    -inf],
           [-2.0340, -1.1922, -0.9319, -1.1659,  1.7142]]]], device='cuda:0',
        grad_fn=<vLSTMFwBwDtildematBackward>),
 torch.Size([1, 1, 5, 5]))

In [14]:
dmat_obw - dmat_pt

tensor([[[[0., nan, nan, nan, nan],
          [0., 0., nan, nan, nan],
          [0., 0., 0., nan, nan],
          [0., 0., 0., 0., nan],
          [0., 0., 0., 0., 0.]]]], device='cuda:0', grad_fn=<SubBackward0>)