In [7]:
import torch

## PyTorch vLSTM forward - Tiled Computation

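This notebook shows that the forward pass of the vLSTM can be computed in a tiled fashion (similar to FlashAttention),
which is a prerequisite for fused kernels. The tiled output is compared against the non-tiled parallel reference implementation.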

In [8]:
%load_ext autoreload
%autoreload 2
from vlstm_parallel_tiled import vlstm_parallel_tiled
from vlstm_parallel import vlstm_parallel_fw_torch

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
# Configuration for the tiled-vs-reference comparison.
# All tensors created below use these sizes, dtype and device.
S = 8  # sequence length
B = 1  # batch size
NH = 1  # number of heads
DH = 6  # dimension per head
DTYPE = torch.float32
# Prefer the first GPU, but fall back to CPU so the notebook still runs
# (Restart & Run All) on machines without CUDA.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [10]:
# Seeded random inputs: queries / keys / values of shape (B, NH, S, DH) and
# input- / forget-gate pre-activations of shape (B, NH, S, 1).
# The RNG draw order (qs, ks, vs, igs, fgs) is significant for reproducibility.
torch.manual_seed(0)
qs, ks, vs = (
    torch.randn((B, NH, S, DH), device=DEVICE, dtype=DTYPE) for _ in range(3)
)
igs, fgs = (
    torch.rand((B, NH, S, 1), device=DEVICE, dtype=DTYPE) for _ in range(2)
)
# Deterministic alternative input gate with values 1, 2, ..., B*NH*S
# (consumes no random numbers; not used by the cells shown here).
igs2 = (
    torch.arange(B * NH * S, device=DEVICE, dtype=DTYPE)
    .add(1.)
    .reshape(B, NH, S, 1)
)
qs.shape, fgs.shape

(torch.Size([1, 1, 8, 6]), torch.Size([1, 1, 8, 1]))

In [11]:
# Disabled cell, kept for reference only. `vlstm_fw_torch` is not imported in the
# cells shown in this notebook, so uncommenting this as-is would raise a NameError.
# NOTE(review): presumably an earlier forward variant with row-wise stabilization;
# confirm whether it is still needed or delete this cell.
# rs = vlstm_fw_torch(
#     queries=qs,
#     keys=ks,
#     values=vs,
#     igate_preact=igs,
#     fgate_preact=fgs,
#     stabilize_rowwise=True,
# )
# rs, rs.shape

In [12]:
# Reference result: non-tiled parallel vLSTM forward pass over the full sequence.
rs = vlstm_parallel_fw_torch(
    queries=qs, keys=ks, values=vs, igate_preact=igs, fgate_preact=fgs
)
# Display the output tensor together with its shape.
(rs, rs.shape)

(tensor([[[[-2.5441,  0.7163,  0.4934, -0.1267, -0.1014,  0.4035],
           [-1.6203,  0.4098,  0.3366, -0.0847, -0.1083,  0.2467],
           [ 0.8271,  0.0098,  0.5125, -0.6795,  0.2315,  0.0330],
           [-1.9904,  0.9070,  0.9530,  0.1805, -1.7032,  1.2508],
           [ 1.3889, -0.6382, -0.3210,  0.5379,  0.7543, -0.3709],
           [-0.4355,  0.4663, -0.8395,  0.3217,  0.2152, -0.0767],
           [ 0.8291, -3.9762, -2.8966,  0.0866, -0.2370,  0.0088],
           [-0.8208, -0.8103, -2.0256,  0.4166, -0.6027, -0.1329]]]],
        device='cuda:0'),
 torch.Size([1, 1, 8, 6]))

In [13]:
# Tiled forward pass on the same inputs, using tiles of 4 rows for the queries
# and 4 rows for the keys/values.
# `m` and `l` are two extra return values of the tiled implementation
# (presumably FlashAttention-style running statistics — confirm in vlstm_parallel_tiled).
hs, m, l = vlstm_parallel_tiled(
    queries=qs, keys=ks, values=vs,
    igate_preact=igs, fgate_preact=fgs,
    bq_tile_size=4, bkv_tile_size=4,
)
# Only the hidden states and their shape are displayed; m and l are omitted.
(hs, hs.shape)

q_tiles: 2, torch.Size([1, 1, 4, 6])
kv_tiles: 2, torch.Size([1, 1, 4, 6])


(tensor([[[[-2.5441,  0.7163,  0.4934, -0.1267, -0.1014,  0.4035],
           [-1.6203,  0.4098,  0.3366, -0.0847, -0.1083,  0.2467],
           [ 0.8271,  0.0098,  0.5125, -0.6795,  0.2315,  0.0330],
           [-1.9904,  0.9070,  0.9530,  0.1805, -1.7032,  1.2508],
           [ 1.3889, -0.6382, -0.3210,  0.5379,  0.7543, -0.3709],
           [-0.4355,  0.4663, -0.8395,  0.3217,  0.2152, -0.0767],
           [ 0.8291, -3.9762, -2.8966,  0.0866, -0.2370,  0.0088],
           [-0.8208, -0.8103, -2.0256,  0.4166, -0.6027, -0.1329]]]],
        device='cuda:0'),
 torch.Size([1, 1, 8, 6]))

In [14]:
# Fail loudly if the tiled output deviates from the reference by more than
# rounding noise, instead of relying on a visual check of the difference tensor.
# Tolerances are chosen for float32 inputs; loosen them for lower-precision dtypes.
assert torch.allclose(hs, rs, rtol=1e-5, atol=1e-5), (
    f"tiled output deviates from reference: "
    f"max abs diff = {(hs - rs).abs().max().item():.3e}"
)
# Element-wise difference (tiled minus reference) for inspection.
hs - rs

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  3.7253e-09,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [-1.1921e-07,  5.9605e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 1.1921e-07,  1.1921e-07, -3.5763e-07,  2.3842e-07,  0.0000e+00,
            2.9802e-08],
          [ 2.3842e-07, -2.3842e-07, -2.3842e-07,  0.0000e+00,  4.7684e-07,
           -3.8743e-07],
          [ 5.9605e-08,  2.3842e-07,  0.0000e+00,  0.0000e+00,  5.9605e-08,
           -2.2352e-08],
          [-5.9605e-08,  5.9605e-08,  0.0000e+00,  0.0000e+00,  2.9802e-07,
           -1.1921e-07]]]], device='cuda:0')