In [1]:
import torch

## PyTorch vLSTM forward - Tiled Computation

This notebook demonstrates that the forward pass of the vLSTM can be computed in a tiled fashion (similar to FlashAttention),
which is a prerequisite for the fused kernels.

In [2]:
%load_ext autoreload
%autoreload 2
from vlstm_full import vlstm_fw_torch
from vlstm_fw_tiled import vlstm_fw_tiled_torch

In [3]:
# Problem sizes and tensor options shared by the cells that follow.
S = 8  # sequence length
B = 1  # batch size
NH = 1  # number of heads
DH = 4  # dimension per head
DTYPE = torch.float32
# Use the first CUDA device when one is present; otherwise fall back to the CPU
# so that tensors can still be created on machines without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
# Build reproducible inputs: queries, keys, values and the input-/forget-gate
# tensors that are later passed as `igate_preact` / `fgate_preact`.
torch.manual_seed(0)
_factory = {"device": DEVICE, "dtype": DTYPE}
# Random draws happen in the order q, k, v, input gate, forget gate.
qs = torch.randn(B, NH, S, DH, **_factory)
ks = torch.randn(B, NH, S, DH, **_factory)
vs = torch.randn(B, NH, S, DH, **_factory)
igs = torch.rand(B, NH, S, 1, **_factory)
# Deterministic ramp 1, 2, ..., B*NH*S with the same shape as the gate tensors
# (not used by the calls visible in this part of the notebook).
igs2 = torch.arange(1, B * NH * S + 1, **_factory).reshape(B, NH, S, 1)
fgs = torch.rand(B, NH, S, 1, **_factory)
(qs.shape, fgs.shape)

(torch.Size([1, 1, 8, 4]), torch.Size([1, 1, 8, 1]))

In [5]:
# Full (untiled) forward pass with row-wise stabilization enabled; its output
# `rs` is the result the tiled computation is compared against.
_full_inputs = dict(
    queries=qs, keys=ks, values=vs, igate_preact=igs, fgate_preact=fgs
)
rs = vlstm_fw_torch(**_full_inputs, stabilize_rowwise=True)
(rs, rs.shape)

(tensor([[[[-2.5441,  0.7163,  0.4934, -0.1267],
           [-0.0950,  0.1215, -0.2144, -0.2080],
           [-1.1865,  0.5090, -0.2582, -0.4349],
           [ 0.6274, -0.0040, -1.7983,  1.1217],
           [-0.2278,  0.4122, -2.1547,  1.6022],
           [ 0.8848, -0.3820,  0.6012, -0.5168],
           [ 1.9075, -1.2873, -0.1571,  0.0968],
           [ 1.5725, -1.0859, -0.3932,  1.7004]]]], device='cuda:0'),
 torch.Size([1, 1, 8, 4]))

In [6]:
# Tiled forward pass on the same inputs, using a query tile size and a
# key/value tile size of 4.
# NOTE(review): the recorded output of this cell is inf/nan in every row except
# the last one, whereas the full forward pass gives finite values -- the tiled
# implementation needs checking before it can be treated as equivalent.
_tiled_inputs = dict(
    queries=qs, keys=ks, values=vs, igate_preact=igs, fgate_preact=fgs
)
hs = vlstm_fw_tiled_torch(**_tiled_inputs, bq_tile_size=4, bkv_tile_size=4)
(hs, hs.shape)

(tensor([[[[    inf,     nan,     nan,     nan],
           [    nan,     nan,     nan,     nan],
           [    inf,     nan,     nan,     nan],
           [    nan,     nan,     nan,     nan],
           [    nan,     nan,     nan,     nan],
           [    nan,     nan,    -inf,     inf],
           [   -inf,     inf,     inf,    -inf],
           [-2.8085,  0.6834,  0.1378,  2.3397]]]], device='cuda:0'),
 torch.Size([1, 1, 8, 4]))

In [11]:
# Shape of the query tensor.
qs.size()

torch.Size([1, 1, 8, 4])

In [14]:
# Split along dim=2 into chunks of 3 and show the shapes of the first and third
# chunks; the third chunk is shorter because 8 is not a multiple of 3.
_seq_chunks = qs.split(3, dim=2)
(_seq_chunks[0].shape, _seq_chunks[2].shape)

(torch.Size([1, 1, 3, 4]), torch.Size([1, 1, 2, 4]))