In [1]:
%load_ext autoreload
%autoreload 2

import torch 
torch.set_printoptions(linewidth=200, threshold=100000)


from mlstm_parallel import mlstm_torch_autograd
from mlstm_chunkwise._torch_fw_legacy import mlstm_chunkwise_parallel_legacy
from mlstm_chunkwise.torch_fw import mlstm_chunkwise_parallel_fw_looped


# Match the mLSTM chunkwise-parallel outputs to the parallel reference


In [2]:
# params
# Problem sizes; they define the shapes of the random test tensors built below.
S = 12  # seq len
B = 1  # batch size
NH = 1  # num heads
DH = 6  # dim per head

# All test tensors are created with this dtype.
DTYPE = torch.float32
# Use the first CUDA device when one is present; otherwise fall back to the CPU
# so that tensor creation does not fail on machines without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Forwarded as the last positional argument to mlstm_torch_autograd
# (presumably a numerical-stability epsilon -- TODO confirm).
EPS = 0.0

In [3]:
# Seed the global RNG so that every run draws the same inputs.
torch.manual_seed(0)

# Tensors are drawn in the order matQ, matK, matV, vecI, vecF; changing that
# order would change the sampled values.
# matQ / matK / matV have shape (B, NH, S, DH).
matQ, matK, matV = (
    torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE) for _ in range(3)
)
# vecI / vecF have shape (B, NH, S)
# (presumably input / forget gate pre-activations -- TODO confirm).
vecI, vecF = (
    torch.randn((B, NH, S), dtype=DTYPE, device=DEVICE) for _ in range(2)
)

In [4]:
# Output of the parallel implementation; the chunkwise outputs computed in the
# following cells are subtracted from it to measure their deviation.
matH_p = mlstm_torch_autograd(
    matQ,
    matK,
    matV,
    vecI,
    vecF,
    EPS,
)

### chunkwise legacy version.

In [5]:
# The legacy chunkwise function is called with vecI / vecF carrying a trailing
# singleton axis, i.e. shaped (B, NH, S, 1), and with a chunk size of 4.
matH_cpl = mlstm_chunkwise_parallel_legacy(
    matQ,
    matK,
    matV,
    vecI[..., None],
    vecF[..., None],
    chunk_size=4,
)
# Display the full output tensor.
matH_cpl

tensor([[[[-2.5441,  0.7163,  0.4934, -0.1267, -0.1014,  0.4035],
          [-0.9682,  0.2628,  0.1925, -0.0491, -0.0478,  0.1514],
          [ 0.8653, -0.1537,  0.2109, -0.3454,  0.1031, -0.0583],
          [-1.1542,  0.5936,  0.6027,  0.0523, -0.9670,  0.7676],
          [ 0.2532, -0.1195,  0.0188,  0.0912,  0.0605, -0.0186],
          [-0.0608,  0.1141, -0.1181,  0.0315,  0.1147,  0.0054],
          [ 0.7045, -3.0562, -2.1854,  0.2068, -0.0699, -1.0153],
          [-0.2197, -1.3752, -1.6874,  0.2571, -0.3313, -0.6734],
          [ 0.3718, -0.9543, -0.7833,  0.5497,  1.5527, -1.3676],
          [-0.5933,  0.0954, -0.4657,  0.1006, -1.0383,  0.6899],
          [ 0.3316, -0.1104, -0.1473,  0.2832,  0.3755, -0.6177],
          [ 3.8387, -0.8141, -0.0150,  0.6760,  0.5455,  0.1670]]]], device='cuda:0')

In [6]:
# Element-wise deviation of the legacy chunkwise output from the parallel output.
torch.sub(matH_p, matH_cpl)

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -2.9802e-08],
          [ 4.1723e-07, -1.1921e-07, -7.4506e-08,  2.2352e-08,  7.4506e-09, -4.4703e-08],
          [ 0.0000e+00, -2.9802e-08, -7.4506e-08,  1.1921e-07, -7.4506e-09, -1.1176e-08],
          [ 1.1921e-07, -5.9605e-08,  0.0000e+00, -3.3528e-08,  5.9605e-08,  0.0000e+00],
          [-1.1921e-07,  6.7055e-08, -2.2352e-08,  2.2352e-08, -4.0978e-08,  2.6077e-08],
          [ 1.6764e-07, -6.0350e-07, -4.0978e-07,  1.0431e-07, -2.2352e-08, -5.9558e-07],
          [-5.9605e-08,  2.3842e-07,  2.3842e-07, -5.9605e-08,  0.0000e+00,  3.5763e-07],
          [-4.4703e-08,  5.9605e-07,  4.7684e-07, -8.9407e-08,  0.0000e+00,  4.1723e-07],
          [ 5.9605e-08, -5.9605e-08, -5.9605e-08,  0.0000e+00,  0.0000e+00, -1.1921e-07],
          [ 5.9605e-08, -1.4901e-08,  2.9802e-08, -1.4901e-08,  2.3842e-07,  0.0000e+00],
          [ 0.0000e+00, -1.2666e-07, -4.4703e-08,  2.9802e-08,  5.9605e-08, -1.1921e-07],
          

In [7]:
# Fail loudly if the legacy chunkwise output drifts from the parallel output
# (the recorded float32 discrepancy is about 6e-7), instead of relying on
# someone eyeballing the number below.
torch.testing.assert_close(matH_cpl, matH_p, rtol=1e-5, atol=1e-5)
# Worst-case absolute deviation, displayed as the cell result.
(matH_p - matH_cpl).abs().max()

tensor(6.0350e-07, device='cuda:0')

### chunkwise looped version.

In [8]:
# The looped chunkwise forward takes vecI / vecF as they are, shaped (B, NH, S),
# and its chunk-size keyword is named `seq_chunk_size`.
matH_cplo = mlstm_chunkwise_parallel_fw_looped(
    matQ,
    matK,
    matV,
    vecI,
    vecF,
    seq_chunk_size=4,
)

In [9]:
# Element-wise deviation of the looped chunkwise output from the parallel output.
torch.sub(matH_p, matH_cplo)

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -2.9802e-08],
          [ 4.1723e-07, -1.1921e-07, -7.4506e-08,  2.2352e-08,  7.4506e-09, -4.4703e-08],
          [ 0.0000e+00, -2.9802e-08, -7.4506e-08,  1.1921e-07, -7.4506e-09, -1.1176e-08],
          [ 1.1921e-07, -5.9605e-08,  0.0000e+00, -3.3528e-08,  5.9605e-08,  0.0000e+00],
          [-1.1921e-07,  6.7055e-08, -2.2352e-08,  2.2352e-08, -4.0978e-08,  2.6077e-08],
          [ 1.6764e-07, -6.0350e-07, -4.0978e-07,  1.0431e-07, -2.2352e-08, -5.9558e-07],
          [-5.9605e-08,  2.3842e-07,  2.3842e-07, -5.9605e-08,  0.0000e+00,  3.5763e-07],
          [-4.4703e-08,  5.9605e-07,  4.7684e-07, -8.9407e-08,  0.0000e+00,  4.1723e-07],
          [ 5.9605e-08, -5.9605e-08, -5.9605e-08,  0.0000e+00,  0.0000e+00, -1.1921e-07],
          [ 5.9605e-08, -1.4901e-08,  2.9802e-08, -1.4901e-08,  2.3842e-07,  0.0000e+00],
          [ 0.0000e+00, -1.2666e-07, -4.4703e-08,  2.9802e-08,  5.9605e-08, -1.1921e-07],
          

In [10]:
# Fail loudly if the looped chunkwise output drifts from the parallel output
# (the recorded float32 discrepancy is about 6e-7), instead of relying on
# someone eyeballing the number below.
torch.testing.assert_close(matH_cplo, matH_p, rtol=1e-5, atol=1e-5)
# Worst-case absolute deviation, displayed as the cell result.
(matH_p - matH_cplo).abs().max()

tensor(6.0350e-07, device='cuda:0')