In [1]:
%load_ext autoreload
%autoreload 2

import torch 

from vlstm_parallel import vlstm_parallel_fw_torch
from vlstm_recurrent import vlstm_recurrent_sequence_stabilized



# vLSTM forward backward (FULL version stabilized) implementation

In [2]:
DTYPE = torch.float32 
DEVICE = torch.device("cuda:0")

In [3]:
B = 3
S = 12
NH = 2
DH = 4
EPS = 0.0

In [4]:
igate_preacts = torch.randn((B, NH, S, 1), dtype=DTYPE, device=DEVICE)
fgate_preacts = torch.randn((B, NH, S, 1), dtype=DTYPE, device=DEVICE)

In [5]:
qs = torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE)
ks = torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE)
vs = torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE)
vs.shape

torch.Size([3, 2, 12, 4])

In [6]:
y_p = vlstm_parallel_fw_torch(qs, ks, vs, igate_preacts, fgate_preacts, eps=EPS)

In [7]:
y_p

tensor([[[[ 6.3522e-01, -1.6000e-01,  6.0370e-01,  5.1971e-02],
          [ 3.2410e-01,  5.0011e-01,  1.1769e+00,  2.7453e+00],
          [-1.8547e-01, -1.1592e+00, -1.1110e+00, -4.1392e+00],
          [-2.0519e-01, -5.2323e-01, -1.7289e+00, -3.5220e+00],
          [ 9.0273e-01, -1.4253e+00,  4.2384e-01, -1.2577e+00],
          [-1.5017e+00,  3.3967e-01, -8.7799e-02,  1.0062e+00],
          [-1.8978e-01,  1.7306e-01,  4.7722e-01, -8.0357e-01],
          [ 3.5925e-01, -1.3212e-01, -2.4939e-01, -4.2364e-02],
          [-2.4883e-01, -2.9782e-02, -4.6080e-01, -1.1615e+00],
          [-7.6605e-02, -1.9863e-01,  6.9560e-03,  4.8432e-03],
          [-1.2909e-01, -9.6699e-01, -1.5085e+00, -3.6148e-01],
          [ 8.2371e-01, -1.0690e+00, -2.1785e+00, -3.3104e-01]],

         [[ 8.8378e-01, -6.4858e-01, -7.8493e-02,  7.0480e-01],
          [-8.0988e-01,  7.7561e-01,  6.6320e-02, -1.1622e+00],
          [ 1.4728e-02, -1.0073e-01,  1.8832e-02,  2.5906e-01],
          [-4.0753e+00,  1.4054e+00, -

In [8]:
y_r = vlstm_recurrent_sequence_stabilized(qs, ks, vs, igate_preacts, fgate_preacts, normalization_mode="max_abs_sum_C_1", eps=EPS)

In [9]:
y_r

tensor([[[[ 6.3522e-01, -1.6000e-01,  6.0370e-01,  5.1971e-02],
          [ 3.2410e-01,  5.0011e-01,  1.1769e+00,  2.7453e+00],
          [-1.8547e-01, -1.1592e+00, -1.1110e+00, -4.1392e+00],
          [-2.0519e-01, -5.2323e-01, -1.7289e+00, -3.5220e+00],
          [ 9.0273e-01, -1.4253e+00,  4.2384e-01, -1.2577e+00],
          [-1.5017e+00,  3.3967e-01, -8.7799e-02,  1.0062e+00],
          [-1.8978e-01,  1.7306e-01,  4.7722e-01, -8.0357e-01],
          [ 3.5925e-01, -1.3212e-01, -2.4939e-01, -4.2364e-02],
          [-2.4883e-01, -2.9782e-02, -4.6080e-01, -1.1615e+00],
          [-7.6605e-02, -1.9863e-01,  6.9560e-03,  4.8432e-03],
          [-1.2909e-01, -9.6699e-01, -1.5085e+00, -3.6148e-01],
          [ 8.2371e-01, -1.0690e+00, -2.1785e+00, -3.3104e-01]],

         [[ 8.8378e-01, -6.4858e-01, -7.8493e-02,  7.0480e-01],
          [-8.0988e-01,  7.7561e-01,  6.6320e-02, -1.1622e+00],
          [ 1.4728e-02, -1.0073e-01,  1.8832e-02,  2.5906e-01],
          [-4.0753e+00,  1.4054e+00, -

In [10]:
torch.allclose(y_p, y_r, atol=1e-5)

True