In [1]:
%load_ext autoreload
%autoreload 2

import torch 

from vlstm_parallel import vlstm_parallel_fw_torch
from vlstm_recurrent import vlstm_recurrent_sequence_stabilized



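# vLSTM forward/backward (full, stabilized version) implementation

This notebook compares the parallel forward pass (`vlstm_parallel_fw_torch`) with the stabilized recurrent implementation (`vlstm_recurrent_sequence_stabilized`) on random inputs. Only the forward outputs are checked here.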

In [2]:
DTYPE = torch.float32
# Prefer the first GPU, but fall back to CPU so the notebook (and the
# parallel-vs-recurrent agreement check) still runs on machines without CUDA.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Seed torch's RNG so the random inputs drawn in later cells (and therefore the
# recorded outputs) are reproducible under "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

# Problem sizes for the comparison. The names presumably follow the usual
# convention (B = batch, S = sequence length, NH = heads, DH = head dim).
# NOTE(review): confirm against vlstm_parallel / vlstm_recurrent.
B = 1
S = 12
NH = 1
DH = 4
EPS = 0.0  # passed as `eps` to both implementations

In [4]:
# Random input- and forget-gate pre-activations, each of shape (B, NH, S, 1).
# The trailing dim of 1 presumably means one gate value per (batch, head, step).
igate_preacts, fgate_preacts = (
    torch.randn((B, NH, S, 1), dtype=DTYPE, device=DEVICE) for _ in range(2)
)

In [5]:
# Random query / key / value tensors, each of shape (B, NH, S, DH).
# They are drawn in the order qs, ks, vs.
qs, ks, vs = (
    torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE) for _ in range(3)
)
vs.shape

torch.Size([1, 1, 12, 4])

In [6]:
# Forward pass through the parallel implementation.
y_p = vlstm_parallel_fw_torch(
    qs,
    ks,
    vs,
    igate_preacts,
    fgate_preacts,
    eps=EPS,
)

In [7]:
# Show the parallel-implementation output as the cell result.
y_p

tensor([[[[ 0.0401, -0.6524,  0.0756, -0.1705],
          [-0.0835, -0.5428, -0.1295, -0.1220],
          [ 0.0119,  0.2431,  0.1620,  0.0266],
          [ 4.1489, -0.8343, -1.1976, -0.1249],
          [ 1.0242,  0.5282, -0.4067, -0.4703],
          [-0.1481,  0.1326, -0.2982, -0.3432],
          [-0.9226,  0.4848,  0.3874, -0.2165],
          [ 0.1472,  0.0197, -0.5274, -0.2781],
          [ 0.6186,  0.0995, -0.5075, -0.1090],
          [ 0.0781,  0.1545, -0.0641,  0.4224],
          [ 0.0790,  0.1072,  0.2291, -0.3317],
          [-0.4620, -0.2302, -1.1630, -1.4933]]]], device='cuda:0')

In [8]:
# Same inputs through the stabilized recurrent implementation
# (normalization mode "max_abs_sum_C_1").
y_r = vlstm_recurrent_sequence_stabilized(
    qs,
    ks,
    vs,
    igate_preacts,
    fgate_preacts,
    normalization_mode="max_abs_sum_C_1",
    eps=EPS,
)

In [9]:
# Show the recurrent-implementation output as the cell result.
y_r

tensor([[[[ 0.0401, -0.6524,  0.0756, -0.1705],
          [-0.0835, -0.5428, -0.1295, -0.1220],
          [ 0.0119,  0.2431,  0.1620,  0.0266],
          [ 4.1489, -0.8343, -1.1976, -0.1249],
          [ 1.0242,  0.5282, -0.4067, -0.4703],
          [-0.1481,  0.1326, -0.2982, -0.3432],
          [-0.9226,  0.4848,  0.3874, -0.2165],
          [ 0.1472,  0.0197, -0.5274, -0.2781],
          [ 0.6186,  0.0995, -0.5075, -0.1090],
          [ 0.0781,  0.1545, -0.0641,  0.4224],
          [ 0.0790,  0.1072,  0.2291, -0.3317],
          [-0.4620, -0.2302, -1.1630, -1.4933]]]], device='cuda:0')

In [10]:
# The parallel and recurrent outputs must agree elementwise.
# - Check shapes explicitly, because torch.allclose broadcasts its inputs.
# - Assert instead of only displaying the bool, so a regression makes
#   "Restart & Run All" fail loudly.
assert y_p.shape == y_r.shape, f"shape mismatch: {tuple(y_p.shape)} vs {tuple(y_r.shape)}"
outputs_match = torch.allclose(y_p, y_r, atol=1e-5)
assert outputs_match, f"max abs diff: {(y_p - y_r).abs().max().item():.3e}"
outputs_match

True