In [1]:
import torch

## PyTorch vLSTM forward - Tiled Computation

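This notebook shows that the forward pass of the vLSTM can be computed in a tiled fashion (similar to FlashAttention),
which is a prerequisite for writing fused kernels. The tiled result is compared against the parallel reference implementation.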

In [2]:
%load_ext autoreload
%autoreload 2
from vlstm_parallel_tiled import vlstm_parallel_tiled
from vlstm_parallel import vlstm_parallel_fw_torch

In [3]:
# params
S = 128  # sequence length
B = 1  # batch size
NH = 1  # number of heads
DH = 32  # dimension per head
DTYPE = torch.float32
# Prefer the first CUDA device, but fall back to the CPU so that the notebook
# still runs top to bottom on a machine without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
# Reproducible inputs: queries/keys/values plus input- and forget-gate pre-activations.
torch.manual_seed(0)
tensor_opts = {"device": DEVICE, "dtype": DTYPE}
head_shape = (B, NH, S, DH)
gate_shape = (B, NH, S, 1)

# The draw order (qs, ks, vs, igs, fgs) determines the values produced by the seed.
qs, ks, vs = (torch.randn(head_shape, **tensor_opts) for _ in range(3))
igs = torch.rand(gate_shape, **tensor_opts)
fgs = torch.rand(gate_shape, **tensor_opts)

# igs2: deterministic ramp 1, 2, ..., B*NH*S laid out in the gate shape.
igs2 = (torch.arange(B * NH * S, **tensor_opts) + 1.0).reshape(gate_shape)

qs.shape, fgs.shape

(torch.Size([1, 1, 128, 32]), torch.Size([1, 1, 128, 1]))

In [5]:
# NOTE(review): disabled call, kept for reference only. `vlstm_fw_torch` is not
# imported by the import cells visible above, so un-commenting this as-is would
# presumably fail with a NameError — confirm where it lives before re-enabling.
# rs = vlstm_fw_torch(
#     queries=qs,
#     keys=ks,
#     values=vs,
#     igate_preact=igs,
#     fgate_preact=fgs,
#     stabilize_rowwise=True,
# )
# rs, rs.shape

In [6]:
# Reference output `rs` from the parallel implementation (module `vlstm_parallel`).
reference_inputs = {
    "queries": qs,
    "keys": ks,
    "values": vs,
    "igate_preact": igs,
    "fgate_preact": fgs,
}
rs = vlstm_parallel_fw_torch(**reference_inputs)
rs, rs.shape

(tensor([[[[ 2.5391, -0.7149, -0.4924,  ..., -0.6278, -0.6574,  2.0770],
           [-1.7476,  0.4291,  0.2903,  ..., -0.9455,  0.1000, -1.3680],
           [ 1.1070, -0.6684, -1.4656,  ..., -0.2738, -0.3802,  0.9032],
           ...,
           [ 1.0797, -0.1193, -0.0337,  ..., -0.2342, -0.5335,  0.0711],
           [-1.0232,  0.7267,  0.0330,  ...,  0.0983, -1.5277, -0.0601],
           [ 1.1762, -0.9224,  0.4168,  ..., -0.3047,  2.4751,  1.3114]]]],
        device='cuda:0'),
 torch.Size([1, 1, 128, 32]))

In [7]:
# Tiled forward pass with 8-row query tiles and 8-row key/value tiles.
# The call returns four values, unpacked here as hs, m, l, n.
tiled_out = vlstm_parallel_tiled(
    queries=qs, keys=ks, values=vs,
    igate_preact=igs, fgate_preact=fgs,
    bq_tile_size=8, bkv_tile_size=8,
)
hs, m, l, n = tiled_out
hs, hs.shape  # m and l are intentionally left out of the display

q_tiles: 16, torch.Size([1, 1, 8, 32])
kv_tiles: 16, torch.Size([1, 1, 8, 32])
q_idx: 10, kv_idx: 0, m_prev: tensor([[[[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]]]], device='cuda:0'), m: tensor([[[[-35.9809],
          [-36.3237],
          [-36.7787],
          [-37.1526],
          [-37.5637],
          [-38.2332],
          [-38.6061],
          [-39.1988]]]], device='cuda:0'), l_prev: tensor([[[[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]]]], device='cuda:0'), l: tensor([[[[-0.8050],
          [-0.5583],
          [-1.1367],
          [ 1.1849],
          [-1.0606],
          [ 0.9504],
          [-0.2236],
          [-0.5217]]]], device='cuda:0'), n_prev: tensor([[[[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]]]], device='cuda:0'), n: tensor([[[[4.2296e+15],
         

(tensor([[[[ 2.5391, -0.7149, -0.4924,  ..., -0.6278, -0.6574,  2.0770],
           [-1.7476,  0.4291,  0.2903,  ..., -0.9455,  0.1000, -1.3680],
           [ 1.1070, -0.6684, -1.4656,  ..., -0.2738, -0.3802,  0.9032],
           ...,
           [ 1.0797, -0.1193, -0.0337,  ..., -0.2342, -0.5335,  0.0711],
           [-1.0232,  0.7267,  0.0330,  ...,  0.0983, -1.5277, -0.0601],
           [ 1.1762, -0.9224,  0.4168,  ..., -0.3047,  2.4751,  1.3114]]]],
        device='cuda:0'),
 torch.Size([1, 1, 128, 32]))

In [8]:
# Check the match numerically instead of eyeballing a truncated tensor print:
# the tiled output must agree with the parallel reference up to float32
# rounding error. The elementwise residuals are still displayed for inspection.
torch.testing.assert_close(hs, rs, rtol=1e-4, atol=1e-5)
hs - rs

tensor([[[[ 7.1526e-07, -1.7881e-07, -1.1921e-07,  ..., -1.1921e-07,
           -1.7881e-07,  4.7684e-07],
          [ 0.0000e+00,  0.0000e+00,  2.9802e-08,  ..., -5.9605e-08,
           -7.4506e-09, -1.1921e-07],
          [ 0.0000e+00,  5.9605e-08,  0.0000e+00,  ...,  2.9802e-08,
            2.9802e-08,  0.0000e+00],
          ...,
          [ 1.1921e-07,  4.4703e-08, -4.8429e-08,  ..., -4.4703e-08,
            5.9605e-08, -1.3411e-07],
          [ 0.0000e+00,  0.0000e+00,  3.3528e-07,  ..., -1.4156e-07,
            1.1921e-07, -1.7509e-07],
          [ 1.1921e-07,  0.0000e+00,  2.9802e-08,  ..., -2.9802e-08,
            4.7684e-07, -1.1921e-07]]]], device='cuda:0')