In [1]:
import torch
torch.set_printoptions(linewidth=300, threshold=100000)

## PyTorch vLSTM forward - Tiled Computation

This notebook shows that the forward pass of the vLSTM can be computed tile by tile (similar to FlashAttention),
which is a prerequisite for implementing fused kernels.

In [2]:
%load_ext autoreload
%autoreload 2
from vlstm_parallel_tiled import vlstm_parallel_tiled, vlstm_parallel_tiled_stable
from vlstm_parallel import vlstm_parallel_fw_torch

In [3]:
# Experiment parameters: problem sizes, dtype and device used by every cell below.
S = 1024  # sequence length
B = 1  # batch size
NH = 1  # number of heads
DH = 32  # dimension per head
# NOTE(review): the saved cell outputs in this notebook report dtype=torch.float16,
# so they appear to come from a run with a different DTYPE -- re-run all cells
# after changing this value to keep code and outputs in sync.
DTYPE = torch.float32
# Use the first GPU when one is present; otherwise fall back to the CPU so the
# notebook does not fail at the first tensor allocation on machines without CUDA.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
# Build seeded inputs: queries/keys/values plus input- and forget-gate pre-activations.
torch.manual_seed(0)
qkv_shape = (B, NH, S, DH)
gate_shape = (B, NH, S, 1)
tensor_opts = dict(device=DEVICE, dtype=DTYPE)
# The draw order (qs, ks, vs, igs, fgs) is kept fixed so the seeded values do not change.
qs = torch.randn(qkv_shape, **tensor_opts)
ks = torch.randn(qkv_shape, **tensor_opts)
vs = torch.randn(qkv_shape, **tensor_opts)
igs = torch.rand(gate_shape, **tensor_opts)
# Deterministic ramp 1, 2, ..., B*NH*S in gate shape; not referenced by the other visible cells.
igs2 = (torch.arange(B * NH * S, **tensor_opts) + 1.0).reshape(gate_shape)
fgs = torch.rand(gate_shape, **tensor_opts)
qs.shape, fgs.shape

(torch.Size([1, 1, 1024, 32]), torch.Size([1, 1, 1024, 1]))

In [5]:
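# NOTE: disabled call kept for reference only; `vlstm_fw_torch` is not imported in the cells above.
# rs = vlstm_fw_torch(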
#     queries=qs,
#     keys=ks,
#     values=vs,
#     igate_preact=igs,
#     fgate_preact=fgs,
#     stabilize_rowwise=True,
# )
# rs, rs.shape

In [6]:
# Parallel forward pass from `vlstm_parallel`; its output `rs` is the baseline
# that the tiled results are compared against in later cells.
fw_inputs = dict(queries=qs, keys=ks, values=vs, igate_preact=igs, fgate_preact=fgs)
rs, log_matD, matC_normalized = vlstm_parallel_fw_torch(**fw_inputs)
rs, rs.shape

(tensor([[[[ 2.5430e+00, -7.1582e-01, -4.9292e-01,  1.2659e-01,  1.0126e-01, -4.0308e-01,  9.0137e-01,  8.0908e-01, -6.8799e-01,  1.3708e-01,  1.0371e+00,  9.2468e-02, -3.7476e-01, -9.0759e-02,  2.0625e+00, -1.8145e+00, -2.7173e-01,  2.8076e-01, -1.0391e+00,  7.7539e-01,  8.8037e-01,  4.4342e-02,
            -1.4863e+00,  1.1328e+00,  1.3262e+00, -1.2607e+00,  9.4922e-01, -6.5527e-01,  9.0869e-01, -6.2842e-01, -6.5820e-01,  2.0801e+00],
           [-1.7480e+00,  4.2896e-01,  2.9028e-01, -1.4883e+00,  4.7531e-03,  5.6787e-01,  1.1261e-01, -1.4612e-01,  4.1040e-01, -9.0674e-01, -1.0498e+00, -2.7954e-01, -1.0449e+00,  1.8982e-01, -1.1641e+00,  5.9180e-01,  3.5693e-01, -9.7705e-01, -4.2896e-01,  1.2024e-01,  5.6201e-01, -1.1240e+00,
            -6.4026e-02, -7.6465e-01, -1.2979e+00,  8.3252e-01, -1.0684e+00,  9.0332e-01, -1.1787e+00, -9.4531e-01,  9.9976e-02, -1.3682e+00],
           [ 1.1074e+00, -6.6797e-01, -1.4648e+00,  7.7002e-01,  7.6660e-02, -5.7275e-01,  3.4766e-01,  5.6366e-02, -5

In [7]:
# Inspect log_matD from row 50 onwards, restricted to the first 10 columns.
log_matD[..., 50:, :10]

tensor([[[[ -23.8438,  -24.1094,  -23.3594,  -22.1562,  -21.9375,  -21.2344,  -21.2656,  -20.2031,  -20.2344,  -19.4688],
          [ -24.3125,  -24.5781,  -23.8281,  -22.6250,  -22.4062,  -21.7031,  -21.7344,  -20.6719,  -20.7031,  -19.9375],
          [ -24.9531,  -25.2188,  -24.4688,  -23.2656,  -23.0469,  -22.3438,  -22.3594,  -21.3125,  -21.3438,  -20.5781],
          [ -25.5781,  -25.8438,  -25.0938,  -23.8906,  -23.6719,  -22.9688,  -22.9844,  -21.9375,  -21.9688,  -21.2031],
          [ -26.1562,  -26.4219,  -25.6719,  -24.4688,  -24.2500,  -23.5469,  -23.5781,  -22.5156,  -22.5469,  -21.7812],
          [ -26.5312,  -26.7969,  -26.0469,  -24.8438,  -24.6250,  -23.9219,  -23.9531,  -22.8906,  -22.9219,  -22.1562],
          [ -26.9844,  -27.2500,  -26.5000,  -25.2969,  -25.0781,  -24.3750,  -24.3906,  -23.3438,  -23.3750,  -22.6094],
          [ -27.4844,  -27.7500,  -27.0000,  -25.7969,  -25.5781,  -24.8750,  -24.8906,  -23.8438,  -23.8750,  -23.1094],
          [ -27.9531,  -

In [8]:
# Row-wise maximum of log_matD over the last dimension, kept as a size-1 dim for broadcasting.
max_logD = log_matD.max(dim=-1, keepdim=True).values

In [9]:
# Shift every row by its own maximum so the largest entry per row becomes 0.
log_matD_stab = torch.sub(log_matD, max_logD)

In [10]:
# Same window as the log_matD view (rows 50 onwards, first 10 columns), after the row-max shift.
log_matD_stab[..., 50:, :10]

tensor([[[[ -23.9688,  -24.2344,  -23.4844,  -22.2812,  -22.0625,  -21.3594,  -21.3906,  -20.3281,  -20.3594,  -19.5938],
          [ -24.9062,  -25.1719,  -24.4219,  -23.2188,  -23.0000,  -22.2969,  -22.3281,  -21.2656,  -21.2969,  -20.5312],
          [ -25.0938,  -25.3594,  -24.6094,  -23.4062,  -23.1875,  -22.4844,  -22.5000,  -21.4531,  -21.4844,  -20.7188],
          [ -25.5781,  -25.8438,  -25.0938,  -23.8906,  -23.6719,  -22.9688,  -22.9844,  -21.9375,  -21.9688,  -21.2031],
          [ -26.4531,  -26.7188,  -25.9688,  -24.7656,  -24.5469,  -23.8438,  -23.8750,  -22.8125,  -22.8438,  -22.0781],
          [ -27.2344,  -27.5000,  -26.7500,  -25.5469,  -25.3281,  -24.6250,  -24.6562,  -23.5938,  -23.6250,  -22.8594],
          [ -27.8438,  -28.0938,  -27.3438,  -26.1562,  -25.9375,  -25.2188,  -25.2500,  -24.1875,  -24.2188,  -23.4688],
          [ -27.8438,  -28.0938,  -27.3438,  -26.1562,  -25.9375,  -25.2188,  -25.2500,  -24.1875,  -24.2188,  -23.4688],
          [ -28.2188,  -

In [11]:
# Exponentiate a window of the shifted log matrix (rows 100-299, first 20 columns)
# and view it at bfloat16 precision; exp is elementwise, so slicing first gives the same values.
log_matD_stab[:, :, 100:300, :20].exp().to(torch.bfloat16)

tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0

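Outcome of this investigation:

- The mLSTM effectively looks back only about 50 to 70 timesteps, depending on the numerical range of the dtype: entries of the stabilized decay matrix `exp(log_matD_stab)` that lie further in the past show up as zero.
    - bfloat16 and float32 look back about 50 to 70 timesteps.
    - float16 looks back about 35 timesteps.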

In [12]:
# View rows 50-299 and columns 260-299 of the normalized C matrix at float16 precision.
matC_normalized[..., 50:300, 260:300].half()

tensor([[[[ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,
           -0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
           -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00, -0.000

In [13]:
# Tiled forward pass with query tiles and key/value tiles of size 8 each.
tiled_outputs = vlstm_parallel_tiled(
    queries=qs, keys=ks, values=vs,
    igate_preact=igs, fgate_preact=fgs,
    bq_tile_size=8, bkv_tile_size=8,
)
hs, m, l, n = tiled_outputs
# To inspect the result, end the cell with: hs, hs.shape  (optionally also m, l)

q_tiles: 128, torch.Size([1, 1, 8, 32])
kv_tiles: 128, torch.Size([1, 1, 8, 32])
q_idx: 10, kv_idx: 0, m_prev: tensor([[[[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]]]], device='cuda:0', dtype=torch.float16), m: tensor([[[[-36.0000],
          [-36.3438],
          [-36.8125],
          [-37.1875],
          [-37.5938],
          [-38.2500],
          [-38.6250],
          [-39.2188]]]], device='cuda:0', dtype=torch.float16), l_prev: tensor([[[[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]]]], device='cuda:0', dtype=torch.float16), l: tensor([[[[-0.8042],
          [-0.5635],
          [-1.1475],
          [ 1.1768],
          [-1.0508],
          [ 0.9272],
          [-0.2133],
          [-0.5244]]]], device='cuda:0', dtype=torch.float16), n_prev: tensor([[[[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
   

In [14]:
# Largest absolute deviation of the tiled output `hs` from the reference `rs`.
# Previously recorded results (no dtype printed, i.e. default float32 -- TODO confirm):
#   m_prev0 = -10 -> tensor(3.8147e-06, device='cuda:0')
#   m_prev0 = 0   -> tensor(2.8610e-06, device='cuda:0')
(hs - rs).abs().max()

tensor(nan, device='cuda:0', dtype=torch.float16)

In [15]:
# Stabilized tiled forward pass, using the same 8/8 tile sizes as the plain tiled run.
stable_outputs = vlstm_parallel_tiled_stable(
    queries=qs, keys=ks, values=vs,
    igate_preact=igs, fgate_preact=fgs,
    bq_tile_size=8, bkv_tile_size=8,
)
hss, m, l, n = stable_outputs
# To inspect the result, end the cell with: hss, hss.shape  (optionally also m, l)

q_tiles: 128, torch.Size([1, 1, 8, 32])
kv_tiles: 128, torch.Size([1, 1, 8, 32])
q_idx: 10, kv_idx: 0, m_prev: tensor([[[[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]]]], device='cuda:0', dtype=torch.float16), m: tensor([[[[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]]]], device='cuda:0', dtype=torch.float16), l_prev: tensor([[[[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]]]], device='cuda:0', dtype=torch.float16), l: tensor([[[[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]]]], device='cuda:0', dtype=torch.float16), n_prev: tensor([[[[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]]]], device='cuda:0', dtype=torch.float16), n: tensor([[[[1.]

In [16]:
# Largest absolute deviation of the stabilized tiled output `hss` from the reference `rs`.
(hss - rs).abs().max()

tensor(0.0078, device='cuda:0', dtype=torch.float16)