In [1]:
import torch
# Use wide output lines (300 chars) and only summarize tensors with more than
# 100000 elements, so the tensor dumps below are printed without "..." elision.
torch.set_printoptions(linewidth=300, threshold=100000)

## PyTorch vLSTM forward - Tiled Computation

This notebook shows that the forward pass of the vLSTM can be computed in a tiled fashion (similar to FlashAttention),
which is a prerequisite for the fused kernels.

In [2]:
%load_ext autoreload
%autoreload 2
from vlstm_parallel_tiled import vlstm_parallel_tiled, vlstm_parallel_tiled_stable
from vlstm_parallel import vlstm_parallel_fw_torch

In [3]:
# params
S = 1024  # sequence length
B = 1  # batch size
NH = 1  # number of heads
DH = 32  # dimension per head
DTYPE = torch.float32
# NOTE(review): the recorded outputs of the tiled runs further down print
# dtype=torch.float16, so they were presumably produced with a different DTYPE
# than the one set here -- restart the kernel and run all cells to refresh them.

# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs top-to-bottom on machines without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
# Build the inputs: queries / keys / values plus the input- and forget-gate
# pre-activations. The RNG is seeded so every run draws the same tensors.
torch.manual_seed(0)
tensor_opts = dict(device=DEVICE, dtype=DTYPE)
qkv_shape = (B, NH, S, DH)
gate_shape = (B, NH, S, 1)

qs = torch.randn(qkv_shape, **tensor_opts)
ks = torch.randn(qkv_shape, **tensor_opts)
vs = torch.randn(qkv_shape, **tensor_opts)
igs = torch.rand(gate_shape, **tensor_opts)
# Deterministic alternative gate values 1, 2, ..., B*NH*S
# (not passed to any call in the visible cells).
igs2 = (1. + torch.arange(B * NH * S, **tensor_opts)).reshape(gate_shape)

# Forget-gate pre-activations drawn uniformly from [r1, r2).
r1 = 3.0
r2 = 6.0
fgs = torch.rand(gate_shape, **tensor_opts) * (r2 - r1) + r1

#! Using the following line instead (forget gates uniform in [0, 1)) shows
#! the numerical instability of the tiled forward pass:
# fgs = torch.rand((B, NH, S, 1), device=DEVICE, dtype=DTYPE)

qs.shape, fgs.shape

(torch.Size([1, 1, 1024, 32]), torch.Size([1, 1, 1024, 1]))

In [5]:
# NOTE(review): disabled call kept for reference only. `vlstm_fw_torch` is not
# imported in the visible cells, so this would not run as-is if uncommented.
# rs = vlstm_fw_torch(
#     queries=qs,
#     keys=ks,
#     values=vs,
#     igate_preact=igs,
#     fgate_preact=fgs,
#     stabilize_rowwise=True,
# )
# rs, rs.shape

In [6]:
# Non-tiled forward pass. `rs` is the baseline the tiled variants are compared
# against further down; `log_matD` and `matC_normalized` are inspected in the
# next cells.
rs, log_matD, matC_normalized = vlstm_parallel_fw_torch(
    queries=qs,
    keys=ks,
    values=vs,
    igate_preact=igs,
    fgate_preact=fgs,
)
# Echo the output tensor together with its shape.
rs, rs.shape

(tensor([[[[ 2.5430e+00, -7.1582e-01, -4.9292e-01,  1.2659e-01,  1.0126e-01, -4.0308e-01,  9.0137e-01,  8.0908e-01, -6.8799e-01,  1.3708e-01,  1.0371e+00,  9.2468e-02, -3.7476e-01, -9.0759e-02,  2.0625e+00, -1.8145e+00, -2.7173e-01,  2.8076e-01, -1.0391e+00,  7.7539e-01,  8.8037e-01,  4.4342e-02,
            -1.4863e+00,  1.1328e+00,  1.3262e+00, -1.2607e+00,  9.4922e-01, -6.5527e-01,  9.0869e-01, -6.2842e-01, -6.5820e-01,  2.0801e+00],
           [-1.8379e+00,  4.6167e-01,  3.1323e-01, -1.3340e+00, -7.2937e-03,  5.4932e-01, -2.5940e-03, -2.2144e-01,  4.4189e-01, -8.1934e-01, -1.0488e+00, -2.5830e-01, -8.8379e-01,  1.7859e-01, -1.2656e+00,  7.3096e-01,  3.4717e-01, -8.9795e-01, -2.6221e-01,  1.8478e-02,  3.9795e-01, -1.0010e+00,
             1.1200e-01, -8.0615e-01, -1.3008e+00,  8.8135e-01, -1.0547e+00,  8.7549e-01, -1.1484e+00, -7.6660e-01,  1.6333e-01, -1.4492e+00],
           [ 1.7070e+00, -6.5234e-01, -9.5703e-01,  6.4062e-01,  7.4158e-02, -5.1221e-01,  4.8096e-01,  3.3667e-01, -5

In [7]:
# Peek at log_matD: rows 50 onwards, first 10 columns (last two axes).
log_matD[:, :, 50:, :10]

tensor([[[[-2.4902e-02, -6.7383e-01, -5.4297e-01,  8.5449e-02, -3.4668e-02,  4.1992e-02, -4.4385e-01,  2.3096e-01, -2.7026e-01, -2.4414e-02],
          [-3.5645e-02, -6.8457e-01, -5.5371e-01,  7.4707e-02, -4.5410e-02,  3.1250e-02, -4.5459e-01,  2.2021e-01, -2.8101e-01, -3.5156e-02],
          [-6.9336e-02, -7.1826e-01, -5.8691e-01,  4.1016e-02, -7.8613e-02, -2.4414e-03, -4.8828e-01,  1.8604e-01, -3.1421e-01, -6.8848e-02],
          [-1.0010e-01, -7.4902e-01, -6.1816e-01,  1.0254e-02, -1.0986e-01, -3.3203e-02, -5.1904e-01,  1.5576e-01, -3.4546e-01, -9.9609e-02],
          [-1.2402e-01, -7.7295e-01, -6.4160e-01, -1.3672e-02, -1.3330e-01, -5.7129e-02, -5.4297e-01,  1.3135e-01, -3.6890e-01, -1.2354e-01],
          [-1.2891e-01, -7.7783e-01, -6.4648e-01, -1.8555e-02, -1.3818e-01, -6.2012e-02, -5.4785e-01,  1.2646e-01, -3.7378e-01, -1.2842e-01],
          [-1.3867e-01, -7.8760e-01, -6.5625e-01, -2.8320e-02, -1.4795e-01, -7.1777e-02, -5.5762e-01,  1.1670e-01, -3.8354e-01, -1.3818e-01],
      

In [8]:
# Maximum of log_matD along the last axis; the reduced axis is kept with
# size 1 so the result broadcasts against log_matD.
max_logD = log_matD.max(dim=-1, keepdim=True).values

In [9]:
# Shift every row by its maximum so the largest entry per row becomes 0
# before the values are exponentiated.
log_matD_stab = log_matD - max_logD

In [10]:
# Same window as for log_matD above (rows 50 onwards, first 10 columns),
# now on the max-shifted values.
log_matD_stab[:, :, 50:, :10]

tensor([[[[ -0.7295,  -1.3789,  -1.2480,  -0.6191,  -0.7393,  -0.6626,  -1.1484,  -0.4736,  -0.9746,  -0.7290],
          [ -0.7295,  -1.3789,  -1.2480,  -0.6191,  -0.7393,  -0.6626,  -1.1484,  -0.4736,  -0.9746,  -0.7290],
          [ -0.7295,  -1.3789,  -1.2471,  -0.6191,  -0.7388,  -0.6626,  -1.1484,  -0.4741,  -0.9746,  -0.7290],
          [ -0.7295,  -1.3789,  -1.2480,  -0.6191,  -0.7393,  -0.6626,  -1.1484,  -0.4736,  -0.9746,  -0.7290],
          [ -0.7295,  -1.3789,  -1.2471,  -0.6191,  -0.7388,  -0.6626,  -1.1484,  -0.4741,  -0.9746,  -0.7290],
          [ -0.8296,  -1.4785,  -1.3477,  -0.7192,  -0.8389,  -0.7627,  -1.2480,  -0.5742,  -1.0742,  -0.8291],
          [ -0.9902,  -1.6387,  -1.5078,  -0.8799,  -0.9995,  -0.9233,  -1.4092,  -0.7349,  -1.2354,  -0.9897],
          [ -0.9902,  -1.6387,  -1.5078,  -0.8799,  -0.9995,  -0.9233,  -1.4092,  -0.7349,  -1.2354,  -0.9897],
          [ -0.9902,  -1.6387,  -1.5078,  -0.8799,  -1.0000,  -0.9233,  -1.4092,  -0.7344,  -1.2354,  -0

In [11]:
# Exponentiate the shifted log values and show rows 100-299 / first 20 columns
# after casting to bfloat16, i.e. at bfloat16 precision.
torch.exp(log_matD_stab)[:,:,100:300,:20].to(dtype=torch.bfloat16)

tensor([[[[0.1621, 0.0850, 0.0967, 0.1816, 0.1611, 0.1738, 0.1064, 0.2100, 0.1270, 0.1621, 0.1494, 0.1147, 0.2285, 0.1543, 0.1973, 0.1953, 0.1396, 0.1592, 0.2236, 0.1309],
          [0.1621, 0.0850, 0.0967, 0.1816, 0.1611, 0.1738, 0.1064, 0.2100, 0.1270, 0.1621, 0.1494, 0.1147, 0.2285, 0.1543, 0.1973, 0.1953, 0.1396, 0.1592, 0.2236, 0.1309],
          [0.1621, 0.0850, 0.0967, 0.1816, 0.1611, 0.1738, 0.1064, 0.2100, 0.1270, 0.1621, 0.1494, 0.1152, 0.2285, 0.1543, 0.1973, 0.1953, 0.1396, 0.1592, 0.2236, 0.1309],
          [0.1621, 0.0850, 0.0967, 0.1816, 0.1611, 0.1738, 0.1064, 0.2100, 0.1270, 0.1621, 0.1494, 0.1152, 0.2285, 0.1543, 0.1973, 0.1953, 0.1396, 0.1592, 0.2236, 0.1309],
          [0.1621, 0.0850, 0.0967, 0.1816, 0.1611, 0.1738, 0.1064, 0.2100, 0.1270, 0.1621, 0.1494, 0.1147, 0.2285, 0.1543, 0.1973, 0.1953, 0.1396, 0.1592, 0.2236, 0.1309],
          [0.1621, 0.0850, 0.0967, 0.1816, 0.1611, 0.1738, 0.1064, 0.2100, 0.1270, 0.1621, 0.1494, 0.1152, 0.2285, 0.1543, 0.1973, 0.1953, 0

Outcome of this investigation:
When the forget-gate preactivations `fgs` are initialized with `torch.rand()` (uniform on [0, 1)):
- The mLSTM effectively looks back only about 50 to 70 timesteps, depending on the numerical range of the dtype.
    - bfloat16 and float32 look back about 50 to 70 timesteps.
    - float16 looks back about 35 timesteps.

In [12]:
# Rows 800 onwards and columns 0-299 of matC_normalized, cast to float32 for display.
matC_normalized[:, :, 800:, 0:300].to(dtype=torch.float32)

tensor([[[[-9.5367e-07,  4.1723e-07, -1.0729e-06,  2.3842e-07, -5.9605e-07, -2.0862e-06, -2.3842e-06,  1.7285e-06,  1.3113e-06, -1.3709e-06, -5.3644e-07,  2.5034e-06, -3.5763e-06, -1.7881e-07,  1.4305e-06, -4.1723e-06,  1.8477e-06,  1.1325e-06,  4.5896e-06,  2.3842e-07,  2.2054e-06, -5.3048e-06,
           -3.0398e-06, -3.7551e-06,  5.9605e-08,  1.6689e-06, -1.1921e-06, -5.9605e-08,  0.0000e+00,  4.5896e-06,  2.2650e-06,  4.1723e-07, -1.0729e-06, -1.1921e-06, -1.7881e-07, -2.1458e-06, -4.1723e-07, -7.7486e-07,  3.0398e-06, -4.1723e-07,  4.4107e-06,  1.0729e-06,  4.7684e-06, -3.0994e-06,
           -4.7684e-07, -3.0398e-06,  6.5565e-07, -5.3644e-06,  4.5896e-06, -1.9073e-06, -2.9802e-07,  5.1856e-06, -1.1921e-07,  2.5630e-06,  2.1458e-06, -5.9605e-08, -2.6226e-06,  1.8477e-06,  1.8477e-06,  5.4836e-06, -5.9605e-07,  2.5630e-06,  5.9605e-07, -4.8876e-06,  1.4901e-06,  7.2122e-06,
            9.5367e-06,  1.9670e-06,  2.1040e-05, -3.8743e-06,  5.0664e-06,  1.3113e-06, -5.1260e-06,  9.5367

In [13]:
# Tiled forward pass with query and key/value tile sizes of 8.
# Besides the output `hs` it returns m, l and n -- presumably the per-row
# running statistics of the tiled computation (TODO confirm in vlstm_parallel_tiled).
hs, m, l, n = vlstm_parallel_tiled(
    queries=qs,
    keys=ks,
    values=vs,
    igate_preact=igs,
    fgate_preact=fgs,
    bq_tile_size=8,
    bkv_tile_size=8,
)
# hs, hs.shape#, m, l

q_tiles: 128, torch.Size([1, 1, 8, 32])
kv_tiles: 128, torch.Size([1, 1, 8, 32])
q_idx: 10, kv_idx: 0, m_prev: tensor([[[[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]]]], device='cuda:0', dtype=torch.float16), m: tensor([[[[-0.4048],
          [-0.4077],
          [-0.4175],
          [-0.4224],
          [-0.4292],
          [-0.4712],
          [-0.4761],
          [-0.5024]]]], device='cuda:0', dtype=torch.float16), l_prev: tensor([[[[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]]]], device='cuda:0', dtype=torch.float16), l: tensor([[[[ 0.6074],
          [ 0.4756],
          [-0.6094],
          [ 0.9624],
          [-1.6377],
          [ 3.1211],
          [ 0.7490],
          [-1.4629]]]], device='cuda:0', dtype=torch.float16), n_prev: tensor([[[[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [

In [14]:
# Largest absolute deviation of the tiled output `hs` from the baseline `rs`.
(hs - rs).abs().max()
# Values recorded from earlier runs (m_prev0 is presumably the initial running
# maximum used inside the tiled implementation -- TODO confirm):
#   m_prev0 = -10  ->  tensor(3.8147e-06, device='cuda:0')
#   m_prev0 = 0    ->  tensor(2.8610e-06, device='cuda:0')

tensor(nan, device='cuda:0', dtype=torch.float16)

In [15]:
# Stabilized variant of the tiled forward pass, called with the same inputs and
# tile sizes as the plain tiled version. Note that m, l and n from the previous
# tiled call are overwritten here.
hss, m, l, n = vlstm_parallel_tiled_stable(
    queries=qs,
    keys=ks,
    values=vs,
    igate_preact=igs,
    fgate_preact=fgs,
    bq_tile_size=8,
    bkv_tile_size=8,
)
# hss, hss.shape#, m, l

q_tiles: 128, torch.Size([1, 1, 8, 32])
kv_tiles: 128, torch.Size([1, 1, 8, 32])
q_idx: 10, kv_idx: 0, m_prev: tensor([[[[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]]]], device='cuda:0', dtype=torch.float16), m: tensor([[[[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]]]], device='cuda:0', dtype=torch.float16), l_prev: tensor([[[[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]]]], device='cuda:0', dtype=torch.float16), l: tensor([[[[ 0.4058],
          [ 0.3162],
          [-0.4014],
          [ 0.6313],
          [-1.0664],
          [ 1.9473],
          [ 0.4656],
          [-0.8848]]]], device='cuda:0', dtype=torch.float16), n_prev: tensor([[[[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]]]], device='cuda:0',

In [16]:
# Largest absolute deviation of the stabilized tiled output `hss` from the baseline `rs`.
(hss - rs).abs().max()

tensor(0.5000, device='cuda:0', dtype=torch.float16)