In [1]:
import torch
torch.set_printoptions(linewidth=600, threshold=100000)

## PyTorch vLSTM backward with group norm tiled (headwise layernorm)

This notebook shows what happens when the multi-head layernorm (group norm) is fused with the vLSTM kernel.
We compare the tiled implementation (the template for the Triton and CUDA kernels) with the parallel implementation and check for numerical correctness.

In [2]:
%load_ext autoreload
%autoreload 2
# from vlstm_parallel_tiled import vlstm_parallel_tiled
from vlstm_parallel_w_groupnorm import vlstm_parallel_bw_torch_w_groupnorm, vlstm_parallel_fwbw_torch_w_groupnorm, vlstm_parallel_fwbw_torch_w_groupnorm_full
from vlstm_parallel import vlstm_parallel_fw_torch

from vlstm_parallel_w_groupnorm_tiled_bw import mlstm_parallel_w_groupnorm_torch_tiled_bw, vlstm_parallel_w_groupnorm_torch_bw, construct_log_gate_matrix_tiled, construct_log_gate_matrix_paper

from ln import MultiHeadLayerNorm

## BW parallel with groupnorm

In [3]:
# params: problem sizes and numeric settings for the non-tiled comparison
S = 8 # seq len
B = 1 # batch size
NH = 1 # num heads
DH = 4 # dim per head
DTYPE = torch.float64
# Use the first CUDA device when one is present; otherwise fall back to the CPU
# so the comparison can still be executed on a machine without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
EPS = 0.0  # forwarded as `eps` to the vLSTM forward functions below

# Tile sizes; they are redefined in the tiled section before their first use.
BLOCK_Q = 16
BLOCK_KV = 16

In [4]:
# Draw reproducible random q/k/v, gate pre-activations and backward inputs.
torch.manual_seed(1)

mat_shape = (B, NH, S, DH)  # one DH-dim vector per (batch, head, step)
vec_shape = (B, NH, S, 1)   # one scalar per (batch, head, step)
factory_kwargs = {"device": DEVICE, "dtype": DTYPE}

# NOTE: the order of the randn calls determines the values drawn for the
# fixed seed, so it is kept exactly as is.

# forward inputs
qs = torch.randn(mat_shape, **factory_kwargs)
ks = torch.randn(mat_shape, **factory_kwargs)
vs = torch.randn(mat_shape, **factory_kwargs)
igs = torch.randn(vec_shape, **factory_kwargs)
# igs2 = (1. + torch.arange((B * NH * S), device=DEVICE, dtype=DTYPE)).reshape(B, NH, S, 1)
fgs = torch.randn(vec_shape, **factory_kwargs)

# backward inputs
dH = torch.randn(mat_shape, **factory_kwargs)
vecN = torch.randn(vec_shape, **factory_kwargs)
vecM = torch.randn(vec_shape, **factory_kwargs)

qs.shape, fgs.shape

(torch.Size([1, 1, 8, 4]), torch.Size([1, 1, 8, 1]))

In [5]:
offset = 3.* torch.randn((B, NH, S, DH), device=DEVICE, dtype=DTYPE)

In [6]:
# Head-wise layernorm over NH * DH features, moved to the test device/dtype.
# NOTE(review): its eps (1e-6) differs from EPS used for the vLSTM calls — confirm this is intended.
mh_layernorm = MultiHeadLayerNorm(NH * DH, eps=1e-6)
mh_layernorm = mh_layernorm.to(device=DEVICE, dtype=DTYPE)
mh_layernorm.weight, mh_layernorm.bias

(Parameter containing:
 tensor([0., 0., 0., 0.], device='cuda:0', dtype=torch.float64, requires_grad=True),
 None)

### pytorch

In [7]:
# Fresh leaf copies of the inputs for the autograd reference path, so that
# gradients accumulate on these tensors only.
fgs_pt, igs_pt, qs_pt, ks_pt, vs_pt = (
    src.clone().detach().requires_grad_(True)
    for src in (fgs, igs, qs, ks, vs)
)

In [8]:
# rs = vlstm_fw_torch(
#     queries=qs,
#     keys=ks,
#     values=vs,
#     igate_preact=igs,
#     fgate_preact=fgs,
#     stabilize_rowwise=True,
# )
# rs, rs.shape

In [9]:
qs_pt.grad

In [10]:
# Reference path: forward through vlstm_parallel_fw_torch on the *_pt leaves,
# then the head-wise layernorm. Gradients for this path come from autograd.
fw_inputs_pt = dict(
    queries=qs_pt,
    keys=ks_pt,
    values=vs_pt,
    igate_preact=igs_pt,
    fgate_preact=fgs_pt,
)
rs = vlstm_parallel_fw_torch(**fw_inputs_pt, eps=EPS)
rs_scaled = mh_layernorm(rs)
rs_scaled, rs_scaled.shape

(tensor([[[[-0.2841,  1.3220,  0.3847, -1.4226],
           [ 0.6386,  1.0955, -0.2006, -1.5335],
           [ 0.7039,  1.0709, -1.5130, -0.2619],
           [ 1.6086, -1.0595, -0.5384, -0.0106],
           [ 1.6273, -1.0088, -0.5765, -0.0421],
           [ 0.9918, -1.5975,  0.6775, -0.0719],
           [ 0.7550, -1.5210,  1.0246, -0.2585],
           [ 0.2224, -0.2723,  1.4169, -1.3670]]]], device='cuda:0', dtype=torch.float64, grad_fn=<TransposeBackward0>),
 torch.Size([1, 1, 8, 4]))

In [11]:
((rs_scaled+offset)**2).sum().backward()

In [12]:
rs.shape # (B, NH, S, DH)

torch.Size([1, 1, 8, 4])

In [13]:
qs_pt.grad

tensor([[[[-5.5424e-05,  3.6611e-05,  2.4270e-05, -4.3784e-05],
          [ 5.5255e-01, -6.7394e-01,  4.7943e-01,  7.8546e-01],
          [-4.3091e+00,  6.6800e-01,  7.5137e+00,  3.4398e-01],
          [-2.7787e-01,  1.4118e+00, -2.0624e-01, -5.2272e-02],
          [ 9.8784e-01, -4.8033e-01,  1.2925e+00, -9.1624e-01],
          [-4.1501e+00, -4.5760e+00, -1.9159e+00, -4.2997e+00],
          [ 1.3622e+00,  3.1406e-01, -1.6447e+00,  5.1623e-01],
          [-2.5732e+01,  1.0798e+01, -2.3911e+01, -9.2743e+00]]]], device='cuda:0', dtype=torch.float64)

In [14]:
rs2 = rs#.transpose(1, 2)
rs2.shape

torch.Size([1, 1, 8, 4])

In [15]:
# Manual normalisation over the last dim (biased std, no eps) to compare
# against the output of mh_layernorm.
row_mean = rs2.mean(-1, keepdim=True)
row_std = rs2.std(-1, keepdim=True, unbiased=False)
rs3 = (rs2 - row_mean) / row_std
# rs4 = rs3.transpose(1, 2)
rs3, rs3.shape

(tensor([[[[-0.2841,  1.3220,  0.3847, -1.4226],
           [ 0.6386,  1.0955, -0.2006, -1.5336],
           [ 0.7039,  1.0709, -1.5130, -0.2619],
           [ 1.6086, -1.0595, -0.5384, -0.0106],
           [ 1.6273, -1.0088, -0.5765, -0.0421],
           [ 0.9918, -1.5975,  0.6775, -0.0719],
           [ 0.7550, -1.5210,  1.0246, -0.2585],
           [ 0.2224, -0.2723,  1.4169, -1.3670]]]], device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward0>),
 torch.Size([1, 1, 8, 4]))

In [16]:
rs3 - rs_scaled

tensor([[[[-7.4993e-07,  3.4896e-06,  1.0154e-06, -3.7551e-06],
          [ 3.4323e-06,  5.8879e-06, -1.0780e-06, -8.2422e-06],
          [ 2.2965e-06,  3.4939e-06, -4.9360e-06, -8.5439e-07],
          [ 2.5082e-06, -1.6520e-06, -8.3956e-07, -1.6605e-08],
          [ 1.5938e-06, -9.8798e-07, -5.6463e-07, -4.1209e-08],
          [ 1.0009e-06, -1.6121e-06,  6.8374e-07, -7.2535e-08],
          [ 5.5659e-07, -1.1214e-06,  7.5538e-07, -1.9061e-07],
          [ 3.9179e-06, -4.7970e-06,  2.4964e-05, -2.4085e-05]]]], device='cuda:0', dtype=torch.float64, grad_fn=<SubBackward0>)

### own backward

In [17]:
# Separate leaf copies for the "own backward" path so its gradients do not
# mix with those of the autograd reference tensors.
fgs_obw, igs_obw, qs_obw, ks_obw, vs_obw = (
    src.clone().detach().requires_grad_(True)
    for src in (fgs, igs, qs, ks, vs)
)

In [18]:
# Forward through the "own backward" variant on the *_obw leaves. Besides the
# hidden states it returns two extra tensors (var_b, var_m) that are inspected
# in a later cell.
hs, var_b, var_m = vlstm_parallel_fwbw_torch_w_groupnorm_full(
    queries=qs_obw,
    keys=ks_obw,
    values=vs_obw,
    igate_preact=igs_obw,
    fgate_preact=fgs_obw,
    eps=EPS,
)
# Only the last expression of a cell is displayed, so the former mid-cell
# `hs, hs.shape` line had no effect and was dropped.
hs_scaled = mh_layernorm(hs)
hs_scaled, hs_scaled.shape

(tensor([[[[-0.2841,  1.3220,  0.3847, -1.4226],
           [ 0.6386,  1.0955, -0.2006, -1.5335],
           [ 0.7039,  1.0709, -1.5130, -0.2619],
           [ 1.6086, -1.0595, -0.5384, -0.0106],
           [ 1.6273, -1.0088, -0.5765, -0.0421],
           [ 0.9918, -1.5975,  0.6775, -0.0719],
           [ 0.7550, -1.5210,  1.0246, -0.2585],
           [ 0.2224, -0.2723,  1.4169, -1.3670]]]], device='cuda:0', dtype=torch.float64, grad_fn=<TransposeBackward0>),
 torch.Size([1, 1, 8, 4]))

In [19]:
hs_scaled - rs_scaled

tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]], device='cuda:0', dtype=torch.float64, grad_fn=<SubBackward0>)

In [20]:
((hs_scaled+offset)**2).sum().backward()

tensor([[[[ 0.0000e+00],
          [ 0.0000e+00],
          [ 0.0000e+00],
          [ 0.0000e+00],
          [-4.0645e-06],
          [ 6.1824e-06],
          [ 1.0703e-05],
          [ 0.0000e+00]]]], device='cuda:0', dtype=torch.float64)


In [21]:
qs_obw.grad, qs_pt.grad

(tensor([[[[-5.5424e-05,  3.6611e-05,  2.4270e-05, -4.3784e-05],
           [ 5.5255e-01, -6.7394e-01,  4.7943e-01,  7.8546e-01],
           [-4.3091e+00,  6.6800e-01,  7.5137e+00,  3.4398e-01],
           [-2.7787e-01,  1.4118e+00, -2.0624e-01, -5.2272e-02],
           [ 9.8784e-01, -4.8034e-01,  1.2925e+00, -9.1625e-01],
           [-4.1501e+00, -4.5760e+00, -1.9159e+00, -4.2997e+00],
           [ 1.3622e+00,  3.1406e-01, -1.6447e+00,  5.1623e-01],
           [-2.5732e+01,  1.0798e+01, -2.3911e+01, -9.2743e+00]]]], device='cuda:0', dtype=torch.float64),
 tensor([[[[-5.5424e-05,  3.6611e-05,  2.4270e-05, -4.3784e-05],
           [ 5.5255e-01, -6.7394e-01,  4.7943e-01,  7.8546e-01],
           [-4.3091e+00,  6.6800e-01,  7.5137e+00,  3.4398e-01],
           [-2.7787e-01,  1.4118e+00, -2.0624e-01, -5.2272e-02],
           [ 9.8784e-01, -4.8033e-01,  1.2925e+00, -9.1624e-01],
           [-4.1501e+00, -4.5760e+00, -1.9159e+00, -4.2997e+00],
           [ 1.3622e+00,  3.1406e-01, -1.6447e+0

In [22]:
var_b.abs(), torch.exp(-var_m), var_b.abs() > torch.exp(-var_m)

(tensor([[[[0.2209],
           [0.3124],
           [0.8088],
           [0.3172],
           [2.6139],
           [1.0345],
           [1.5021],
           [0.4596]]]], device='cuda:0', dtype=torch.float64, grad_fn=<AbsBackward0>),
 tensor([[[[0.4257],
           [1.0288],
           [1.6999],
           [0.3699],
           [0.5299],
           [0.7055],
           [1.1781],
           [0.5167]]]], device='cuda:0', dtype=torch.float64, grad_fn=<ExpBackward0>),
 tensor([[[[False],
           [False],
           [False],
           [False],
           [ True],
           [ True],
           [ True],
           [False]]]], device='cuda:0'))

In [23]:
qs_pt.grad - qs_obw.grad, ks_pt.grad - ks_obw.grad, vs_pt.grad - vs_obw.grad

(tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
           [-2.5606e-06,  1.4259e-05, -3.7683e-07,  6.6377e-06],
           [ 7.3301e-06, -4.2211e-06,  1.7718e-06,  1.7949e-06],
           [ 6.4459e-06, -5.9048e-06,  1.1989e-05,  1.4505e-06],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]]], device='cuda:0', dtype=torch.float64),
 tensor([[[[-1.3196e-07, -1.1467e-07, -7.7522e-08, -1.3573e-07],
           [-2.0452e-07, -1.7772e-07, -1.2015e-07, -2.1036e-07],
           [-8.9864e-08, -7.8086e-08, -5.2791e-08, -9.2430e-08],
           [-7.8008e-06, -6.7784e-06, -4.5826e-06, -8.0236e-06],
           [-1.4118e-06, -1.2267e-06, -8.2935e-07, -1.4521e-06],
           [-7.6242e-06,  6.1568e-06, -4.1334e-06, -9.8372e-06],
           [-6.9098e-06,  5.6214e-06, -8.3021e-0

In [24]:
fgs_pt.grad - fgs_obw.grad, igs_pt.grad - igs_obw.grad

(tensor([[[[ 6.0959e-16],
           [ 4.0919e-08],
           [ 5.5045e-09],
           [ 2.2432e-07],
           [ 7.5696e-06],
           [-2.4253e-06],
           [-1.4291e-05],
           [-1.1102e-16]]]], device='cuda:0', dtype=torch.float64),
 tensor([[[[ 5.5821e-08],
           [-4.1876e-08],
           [ 2.4110e-07],
           [ 2.4811e-05],
           [-9.3009e-07],
           [-2.9778e-05],
           [ 2.6454e-06],
           [-3.5527e-15]]]], device='cuda:0', dtype=torch.float64))

In [25]:
torch.allclose(qs_pt.grad, qs_obw.grad, atol=1e-5, rtol=1e-5)

True

In [26]:
# Summary: compare the forward outputs and every input gradient of the
# autograd reference (*_pt) against the own-backward implementation (*_obw).
atol = 1e-5
rtol = 1e-5
# The forward comparison uses torch.allclose's default tolerances.
print(f"Forward match: {torch.allclose(hs_scaled, rs_scaled)}")

# One (label, reference leaf, own-backward leaf) entry per gradient instead of
# five copy-pasted print lines; the printed text is unchanged.
grad_checks = (
    ("qs", qs_pt, qs_obw),
    ("ks", ks_pt, ks_obw),
    ("vs", vs_pt, vs_obw),
    ("fgate_preacts", fgs_pt, fgs_obw),
    ("igate_preacts", igs_pt, igs_obw),
)
for label, leaf_ref, leaf_own in grad_checks:
    print(f"{label} match: {torch.allclose(leaf_ref.grad, leaf_own.grad, atol=atol, rtol=rtol)}")

Forward match: True
qs match: True
ks match: True
vs match: True
fgate_preacts match: True
igate_preacts match: True


In [27]:
## Conclusion: 
# when dividing we get the same gradients; the remaining error of ~1e-5 is due to numerical precision

In [28]:
# tensor([[[[-3.7828e-17, -2.8809e-17, -1.0405e-16,  4.7115e-17],
#           [-1.1102e-16, -4.4409e-16, -8.8818e-16,  4.4409e-16],
#           [ 8.8818e-16, -8.8818e-16,  0.0000e+00, -8.3267e-17],
#           [ 1.2490e-16, -5.5511e-17,  1.3878e-17, -1.7347e-17],
#           [ 1.8937e-06, -1.0558e-05, -2.7356e-06,  1.5993e-05],
#           [ 4.4409e-16,  0.0000e+00,  1.3323e-15,  0.0000e+00],
#           [-6.0601e-07, -9.2155e-07,  4.0914e-06, -5.0413e-07],
#           [-1.2725e-06,  1.7485e-06,  5.4959e-06,  8.2998e-07]]]],
#        device='cuda:0', dtype=torch.float64)

### own backward2

Reimplementation using separate functions for the forward and backward passes, which serve as the ground truth for the kernel implementation.
The results should match those of own backward (1) exactly.

In [29]:
# Leaf copies dedicated to the second own-backward implementation.
fgs_obw2, igs_obw2, qs_obw2, ks_obw2, vs_obw2 = (
    src.clone().detach().requires_grad_(True)
    for src in (fgs, igs, qs, ks, vs)
)

In [30]:
# Forward through the second own-backward variant (separate fw/bw functions)
# on the *_obw2 leaves, followed by the same head-wise layernorm.
hs2, var_b2, var_m2 = vlstm_parallel_fwbw_torch_w_groupnorm(
    queries=qs_obw2,
    keys=ks_obw2,
    values=vs_obw2,
    igate_preact=igs_obw2,
    fgate_preact=fgs_obw2,
    eps=EPS,
)
# Only the last expression of a cell is displayed, so the former mid-cell
# `hs2, hs2.shape` line had no effect and was dropped.
hs_scaled2 = mh_layernorm(hs2)
hs_scaled2, hs_scaled2.shape

(tensor([[[[-0.2841,  1.3220,  0.3847, -1.4226],
           [ 0.6386,  1.0955, -0.2006, -1.5335],
           [ 0.7039,  1.0709, -1.5130, -0.2619],
           [ 1.6086, -1.0595, -0.5384, -0.0106],
           [ 1.6273, -1.0088, -0.5765, -0.0421],
           [ 0.9918, -1.5975,  0.6775, -0.0719],
           [ 0.7550, -1.5210,  1.0246, -0.2585],
           [ 0.2224, -0.2723,  1.4169, -1.3670]]]], device='cuda:0', dtype=torch.float64, grad_fn=<TransposeBackward0>),
 torch.Size([1, 1, 8, 4]))

In [31]:
hs_scaled - hs_scaled2

tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]], device='cuda:0', dtype=torch.float64, grad_fn=<SubBackward0>)

In [32]:
((hs_scaled2+offset)**2).sum().backward()

In [33]:
qs_obw.grad - qs_obw2.grad, ks_obw.grad - ks_obw2.grad, vs_obw.grad - vs_obw2.grad

(tensor([[[[-1.3020e-10,  8.6005e-11,  5.7015e-11, -1.0286e-10],
           [ 5.3706e-07, -6.5505e-07,  4.6599e-07,  7.6344e-07],
           [-2.5350e-06,  3.9298e-07,  4.4202e-06,  2.0236e-07],
           [-7.5124e-07,  3.8170e-06, -5.5758e-07, -1.4132e-07],
           [ 3.7791e-07, -1.8376e-07,  4.9448e-07, -3.5052e-07],
           [-4.0115e-06, -4.4232e-06, -1.8519e-06, -4.1562e-06],
           [ 9.0684e-07,  2.0908e-07, -1.0949e-06,  3.4366e-07],
           [-4.9801e-05,  2.0900e-05, -4.6278e-05, -1.7950e-05]]]], device='cuda:0', dtype=torch.float64),
 tensor([[[[ 6.1895e-06, -2.2341e-06,  1.7210e-06,  2.6788e-06],
           [ 3.8116e-06, -2.7216e-06,  2.7156e-06,  6.0842e-06],
           [ 8.4303e-08,  4.1735e-07, -5.1106e-07, -2.3957e-07],
           [-8.9174e-06, -1.0295e-05, -3.8709e-06,  2.0779e-05],
           [ 1.4558e-06,  2.0089e-06,  4.3810e-07, -1.9617e-06],
           [-1.6445e-05, -2.9992e-05, -1.6734e-05,  5.6504e-05],
           [-1.6525e-05, -2.8659e-05, -1.3668e-0

In [34]:
fgs_obw.grad - fgs_obw2.grad, igs_obw.grad - igs_obw2.grad

(tensor([[[[ 0.0000e+00],
           [-2.8527e-06],
           [-7.4537e-07],
           [-5.4030e-07],
           [ 1.7050e-07],
           [ 1.9369e-06],
           [ 1.4342e-05],
           [ 1.2706e-06]]]], device='cuda:0', dtype=torch.float64),
 tensor([[[[-3.8915e-06],
           [ 2.0035e-06],
           [ 1.2739e-06],
           [ 1.1791e-06],
           [ 2.3387e-06],
           [ 2.3987e-05],
           [-2.0978e-05],
           [-5.9123e-06]]]], device='cuda:0', dtype=torch.float64))

In [35]:
qs_obw.grad, qs_obw2.grad, qs_obw.grad - qs_obw2.grad

(tensor([[[[-5.5424e-05,  3.6611e-05,  2.4270e-05, -4.3784e-05],
           [ 5.5255e-01, -6.7394e-01,  4.7943e-01,  7.8546e-01],
           [-4.3091e+00,  6.6800e-01,  7.5137e+00,  3.4398e-01],
           [-2.7787e-01,  1.4118e+00, -2.0624e-01, -5.2272e-02],
           [ 9.8784e-01, -4.8034e-01,  1.2925e+00, -9.1625e-01],
           [-4.1501e+00, -4.5760e+00, -1.9159e+00, -4.2997e+00],
           [ 1.3622e+00,  3.1406e-01, -1.6447e+00,  5.1623e-01],
           [-2.5732e+01,  1.0798e+01, -2.3911e+01, -9.2743e+00]]]], device='cuda:0', dtype=torch.float64),
 tensor([[[[-5.5424e-05,  3.6611e-05,  2.4270e-05, -4.3784e-05],
           [ 5.5255e-01, -6.7394e-01,  4.7943e-01,  7.8546e-01],
           [-4.3091e+00,  6.6800e-01,  7.5137e+00,  3.4398e-01],
           [-2.7786e-01,  1.4118e+00, -2.0623e-01, -5.2272e-02],
           [ 9.8784e-01, -4.8034e-01,  1.2925e+00, -9.1625e-01],
           [-4.1501e+00, -4.5760e+00, -1.9159e+00, -4.2997e+00],
           [ 1.3622e+00,  3.1406e-01, -1.6447e+0

In [36]:
# Conclusion: the impls match; the remaining error (on the order of 1e-5) is due to numerical precision

## BW parallel TILED with groupnorm

In [37]:
# params: larger problem for the tiled-vs-parallel backward comparison
S = 128 # seq len
B = 1 # batch size
NH = 1 # num heads
DH = 256 # dim per head
DTYPE = torch.float64
# Use the first CUDA device when one is present; otherwise fall back to the CPU
# so the comparison can still be executed on a machine without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
EPS = 0.0

# Tile sizes along the query and key/value sequence dimension.
BLOCK_Q = 64
BLOCK_KV = 32

In [38]:
# Draw reproducible random inputs for the tiled comparison.
torch.manual_seed(1)

mat_shape = (B, NH, S, DH)  # one DH-dim vector per (batch, head, step)
vec_shape = (B, NH, S, 1)   # one scalar per (batch, head, step)
factory_kwargs = {"device": DEVICE, "dtype": DTYPE}

# NOTE: the order of the randn calls determines the values drawn for the
# fixed seed (here fgs is drawn before igs), so it is kept exactly as is.

# forward inputs
qs = torch.randn(mat_shape, **factory_kwargs)
ks = torch.randn(mat_shape, **factory_kwargs)
vs = torch.randn(mat_shape, **factory_kwargs)
# igs2 = (1. + torch.arange((B * NH * S), device=DEVICE, dtype=DTYPE)).reshape(B, NH, S, 1)
fgs = torch.randn(vec_shape, **factory_kwargs)
igs = torch.randn(vec_shape, **factory_kwargs)

# fgs = torch.zeros((B, NH, S, 1), device=DEVICE, dtype=DTYPE)
# igs = torch.zeros_like(fgs) #torch.randn((B, NH, S, 1), device=DEVICE, dtype=DTYPE)

# backward inputs (random stand-ins, not taken from a forward pass)
dH = torch.randn(mat_shape, **factory_kwargs)
vecN = torch.randn(vec_shape, **factory_kwargs)
vecM = torch.randn(vec_shape, **factory_kwargs)

qs.shape, fgs.shape

(torch.Size([1, 1, 128, 256]), torch.Size([1, 1, 128, 1]))

In [39]:
# Parallel (non-tiled) backward: the reference for the tiled version below.
# The gate vectors are passed with their trailing singleton dimension.
parallel_bw_out = vlstm_parallel_w_groupnorm_torch_bw(
    matDeltaHtilde=dH,
    matQ=qs,
    matK=ks,
    matV=vs,
    vecN=vecN,
    vecM=vecM,
    vecI=igs,
    vecF=fgs,
)
dQ_pt_p, dK_pt_p, dV_pt_p, dI_pt_p, dF_pt_p, matDtilde_p = parallel_bw_out

In [40]:
# dQ_pt_p, dK_pt_p, dV_pt_p, dI_pt_p, dF_pt_p

In [41]:
# Log gate matrix for tile indices (idx_BQ, idx_BKV) = (0, 0); the gate
# vectors are passed without their trailing singleton dimension.
logD_tile = construct_log_gate_matrix_tiled(
    vecI=igs.squeeze(-1),
    vecF=fgs.squeeze(-1),
    BQ=BLOCK_Q,
    BKV=BLOCK_KV,
    idx_BQ=0,
    idx_BKV=0,
)
logD_tile

tensor([[[[-7.8305e-01,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf],
          [-8.1149e-01, -1.0101e+00,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf],
          [-2.1644e+00, -2.3630e+00,  4.2740e-01,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,        -inf,    

In [42]:
logD_tile_paper = construct_log_gate_matrix_paper(vecI=igs, vecF=fgs)

In [43]:
torch.exp(logD_tile_paper-vecM)

tensor([[[[5.3393e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000

In [44]:
# logD_tile_paper - logD_tile

In [45]:
# Tiled backward on the same inputs as the parallel reference; here the gate
# vectors are passed without their trailing singleton dimension.
tiled_bw_out = mlstm_parallel_w_groupnorm_torch_tiled_bw(
    matDeltaHtilde=dH,
    matQ=qs,
    matK=ks,
    matV=vs,
    vecN=vecN,
    vecM=vecM,
    vecI=igs.squeeze(-1),
    vecF=fgs.squeeze(-1),
    BLOCK_Q=BLOCK_Q,
    BLOCK_KV=BLOCK_KV,
)
dQ_pt_t, dK_pt_t, dV_pt_t, dI_pt_t, dF_pt_t, matDtilde_t = tiled_bw_out

matQ_tiles: 2, torch.Size([1, 1, 64, 256]) | matK_tiles: 4, torch.Size([1, 1, 32, 256])


In [46]:
dI_pt_t

tensor([[[-3.6833e+01, -8.0605e+00,  4.4622e+01, -6.6804e+01,  4.0738e+02, -8.1777e+00, -4.0472e+01, -9.9472e+01,  5.0582e+01,  1.3343e+01,  1.4263e+01, -5.5691e+01, -3.9848e+02,  7.5334e+01, -5.3141e+00,  1.4131e+02,  9.4419e+01,  6.2539e-01,  3.0202e+01, -4.1684e+01,  5.6370e+01, -5.0319e+02, -2.3440e+02, -9.4508e+02, -2.1177e+02,  4.3332e-02, -8.6712e+00,  1.6152e+01, -2.3372e+01,  3.4062e+01, -1.6858e+01, -2.3527e+02,  2.6515e+02,  1.7115e+01, -5.7261e+00, -1.4449e+01, -1.1371e+02, -5.9737e+01, -8.2768e+01,  1.5959e+03,  3.1347e+00, -5.5506e+01, -1.1206e+03, -5.1060e+02,  2.1892e+02,
          -9.5991e+00,  2.8083e+02,  1.3966e+02,  8.1759e+02,  5.2680e+00, -1.6316e+02,  1.5248e+01, -6.3917e+00,  1.6751e+02,  6.3357e+00,  1.1123e+01, -1.3286e+01,  6.2107e+00, -6.1921e+00,  2.3736e+01, -1.6014e+01,  6.9725e+01, -4.5092e+01, -1.5942e+01, -1.2162e+02,  3.1059e+01, -2.0560e+02, -3.1638e+01, -8.0559e+01, -3.1097e+00,  2.0463e+03, -3.3027e+02,  2.3525e+00, -2.4191e+00, -1.3327e+01, -9.18

In [47]:
# Element-wise difference between the tiled and the parallel dQ, plus the shape.
# The former first line (`dQ_pt_t, dQ_pt_p,`) was a discarded expression that
# displayed nothing, so it was removed; the displayed output is unchanged.
dQ_pt_t - dQ_pt_p, dQ_pt_t.shape

(tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
             0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  

In [48]:
dK_pt_t - dK_pt_p

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.

In [49]:
dV_pt_t - dV_pt_p

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.

In [50]:
dI_pt_t - dI_pt_p.squeeze(-1)

tensor([[[ 7.1054e-15, -1.7764e-15,  0.0000e+00, -1.4211e-14,  5.6843e-14,  1.7764e-15, -1.4211e-14, -2.8422e-14,  2.1316e-14,  0.0000e+00,  3.5527e-15,  7.1054e-15,  1.7053e-13, -1.4211e-14,  0.0000e+00,  0.0000e+00,  0.0000e+00,  2.2204e-15, -3.5527e-15, -1.4211e-14,  0.0000e+00,  0.0000e+00,  2.8422e-14, -1.1369e-13,  0.0000e+00, -2.7756e-17,  0.0000e+00,  3.5527e-15,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -5.6843e-14,  3.5527e-15,  2.6645e-15,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.4211e-14,  0.0000e+00, -3.9968e-15,  1.4211e-14, -2.2737e-13,  0.0000e+00,  2.8422e-14,
           1.7764e-15, -5.6843e-14,  0.0000e+00, -1.1369e-13, -2.6645e-15,  2.8422e-14,  0.0000e+00,  8.8818e-16,  0.0000e+00,  0.0000e+00, -1.7764e-15, -1.7764e-15, -8.8818e-16, -8.8818e-16,  0.0000e+00,  0.0000e+00,  0.0000e+00, -7.1054e-15,  3.5527e-15,  0.0000e+00,  3.5527e-15, -2.8422e-14, -7.1054e-15,  1.4211e-14, -2.1316e-14, -4.5475e-13, -1.1369e-13,  4.4409e-16, -8.8818e-16,  1.0658e-14,  0.00

In [51]:
dF_pt_t - dF_pt_p.squeeze(-1)

tensor([[[ 1.2941e-13,  2.2560e-13,  2.8422e-13,  2.7645e-13,  1.4233e-13,  7.1054e-14,  8.5265e-14,  9.4147e-14,  8.5265e-14,  1.8385e-13,  2.2027e-13,  1.2723e-13,  1.6853e-13,  1.1369e-13,  3.9524e-14,  7.8160e-14,  1.1369e-13,  9.5923e-14,  5.8620e-14,  1.2434e-13,  4.0856e-14,  1.1902e-13,  8.5265e-14,  0.0000e+00,  5.6843e-14, -1.9540e-14, -1.7764e-14, -1.0658e-14, -5.7732e-15, -2.4869e-14, -3.9080e-14, -6.7502e-14, -3.6771e-13, -1.9185e-13, -1.1369e-13, -2.7001e-13, -5.5067e-14, -3.9346e-13, -4.5075e-13, -2.7001e-13, -3.5172e-13, -2.6823e-13, -1.4921e-13,  2.2737e-13, -1.4211e-13,
          -4.2633e-14, -7.1054e-15, -1.7053e-13, -6.0396e-14,  1.1369e-13,  5.6843e-14, -7.1054e-15, -5.3291e-15, -5.3291e-15,  1.3323e-14,  2.4980e-15,  8.2157e-15,  3.7748e-15, -7.1054e-15, -1.9540e-14, -2.1316e-14, -2.1316e-14, -5.3291e-14, -2.4869e-14, -1.9540e-14,  1.9540e-14,  3.5527e-14,  8.8818e-15,  1.4655e-14,  4.2633e-14,  3.6948e-13,  3.6193e-13,  1.5632e-13,  1.4211e-13,  1.9185e-13,  1.13

In [52]:
dF_pt_t, dF_pt_p.squeeze(-1) #, dF_pt_t - dF_pt_p.squeeze(-1)

(tensor([[[ 1.2941e-13, -1.4365e+01, -3.1924e+01, -1.0278e+00, -1.8937e+00,  1.1756e+02,  1.4445e+02,  5.9197e+00, -1.6288e+01,  6.0553e+00, -1.0758e+01, -9.0284e-01, -1.6878e+00, -2.0142e+01,  3.9759e+00,  6.3928e+00, -1.9634e+01,  2.0469e+01,  1.4520e+01, -7.1001e+00, -9.0633e+00,  1.4065e+01, -1.8519e+02, -3.5949e+02, -4.3365e+02, -1.0164e+01, -8.6021e+00, -8.2432e+00, -2.6296e+00, -2.1919e+01,  6.1352e+00,  1.0840e+01,  3.0831e+00,  5.2133e+01,  4.3502e+01,  7.8339e+01,  9.7628e+00,  6.7691e+00, -5.7109e-01, -3.3183e+01, -9.6016e+00, -1.0867e+01, -2.4283e+01, -6.8659e+02, -1.5020e+02,
            9.5957e+01,  5.6723e+01,  1.5374e+02,  3.0433e+01,  5.2916e+02,  4.2974e+02,  2.8088e+00,  9.2523e+00, -4.0836e+00,  5.5201e+00,  3.1760e-01, -1.9104e+00,  1.9914e+00,  2.6752e+00, -5.1529e+00,  9.1402e+00,  9.5180e+00,  2.2419e+01,  3.0390e+01,  1.2740e+01,  1.3459e+01,  1.7243e+01,  3.6843e+00, -1.9825e+00, -5.3785e+01, -1.4062e+02, -3.6739e+00, -9.0785e+01, -8.1116e+00, -1.3535e+01, -1.

In [53]:
# Shift the tiled dF one step to the right along the last (sequence) axis,
# zero-filling position 0, to check for an off-by-one against the parallel dF.
tmp_df = torch.zeros_like(dF_pt_t)
tmp_df[..., 1:] = dF_pt_t[..., :-1]
dF_pt_p.squeeze(-1), tmp_df

(tensor([[[ 0.0000e+00, -1.4365e+01, -3.1924e+01, -1.0278e+00, -1.8937e+00,  1.1756e+02,  1.4445e+02,  5.9197e+00, -1.6288e+01,  6.0553e+00, -1.0758e+01, -9.0284e-01, -1.6878e+00, -2.0142e+01,  3.9759e+00,  6.3928e+00, -1.9634e+01,  2.0469e+01,  1.4520e+01, -7.1001e+00, -9.0633e+00,  1.4065e+01, -1.8519e+02, -3.5949e+02, -4.3365e+02, -1.0164e+01, -8.6021e+00, -8.2432e+00, -2.6296e+00, -2.1919e+01,  6.1352e+00,  1.0840e+01,  3.0831e+00,  5.2133e+01,  4.3502e+01,  7.8339e+01,  9.7628e+00,  6.7691e+00, -5.7109e-01, -3.3183e+01, -9.6016e+00, -1.0867e+01, -2.4283e+01, -6.8659e+02, -1.5020e+02,
            9.5957e+01,  5.6723e+01,  1.5374e+02,  3.0433e+01,  5.2916e+02,  4.2974e+02,  2.8088e+00,  9.2523e+00, -4.0836e+00,  5.5201e+00,  3.1760e-01, -1.9104e+00,  1.9914e+00,  2.6752e+00, -5.1529e+00,  9.1402e+00,  9.5180e+00,  2.2419e+01,  3.0390e+01,  1.2740e+01,  1.3459e+01,  1.7243e+01,  3.6843e+00, -1.9825e+00, -5.3785e+01, -1.4062e+02, -3.6739e+00, -9.0785e+01, -8.1116e+00, -1.3535e+01, -1.

In [54]:
tmp_df - dF_pt_p.squeeze(-1)

tensor([[[ 0.0000e+00,  1.4365e+01,  1.7559e+01, -3.0896e+01,  8.6584e-01, -1.1946e+02, -2.6886e+01,  1.3853e+02,  2.2208e+01, -2.2344e+01,  1.6813e+01, -9.8553e+00,  7.8494e-01,  1.8454e+01, -2.4118e+01, -2.4169e+00,  2.6027e+01, -4.0103e+01,  5.9486e+00,  2.1620e+01,  1.9632e+00, -2.3128e+01,  1.9925e+02,  1.7430e+02,  7.4166e+01, -4.2349e+02, -1.5622e+00, -3.5893e-01, -5.6136e+00,  1.9289e+01, -2.8054e+01, -4.7048e+00,  7.7570e+00, -4.9050e+01,  8.6314e+00, -3.4837e+01,  6.8576e+01,  2.9937e+00,  7.3402e+00,  3.2612e+01, -2.3581e+01,  1.2652e+00,  1.3416e+01,  6.6231e+02, -5.3639e+02,
          -2.4616e+02,  3.9234e+01, -9.7015e+01,  1.2330e+02, -4.9873e+02,  9.9423e+01,  4.2693e+02, -6.4435e+00,  1.3336e+01, -9.6037e+00,  5.2025e+00,  2.2280e+00, -3.9017e+00, -6.8387e-01,  7.8281e+00, -1.4293e+01, -3.7778e-01, -1.2901e+01, -7.9710e+00,  1.7650e+01, -7.1856e-01, -3.7846e+00,  1.3559e+01,  5.6667e+00,  5.1802e+01,  8.6837e+01, -1.3695e+02,  8.7111e+01, -8.2673e+01,  5.4236e+00, -1.85

In [55]:
dF_pt_p.squeeze(-1), dF_pt_t

(tensor([[[ 0.0000e+00, -1.4365e+01, -3.1924e+01, -1.0278e+00, -1.8937e+00,  1.1756e+02,  1.4445e+02,  5.9197e+00, -1.6288e+01,  6.0553e+00, -1.0758e+01, -9.0284e-01, -1.6878e+00, -2.0142e+01,  3.9759e+00,  6.3928e+00, -1.9634e+01,  2.0469e+01,  1.4520e+01, -7.1001e+00, -9.0633e+00,  1.4065e+01, -1.8519e+02, -3.5949e+02, -4.3365e+02, -1.0164e+01, -8.6021e+00, -8.2432e+00, -2.6296e+00, -2.1919e+01,  6.1352e+00,  1.0840e+01,  3.0831e+00,  5.2133e+01,  4.3502e+01,  7.8339e+01,  9.7628e+00,  6.7691e+00, -5.7109e-01, -3.3183e+01, -9.6016e+00, -1.0867e+01, -2.4283e+01, -6.8659e+02, -1.5020e+02,
            9.5957e+01,  5.6723e+01,  1.5374e+02,  3.0433e+01,  5.2916e+02,  4.2974e+02,  2.8088e+00,  9.2523e+00, -4.0836e+00,  5.5201e+00,  3.1760e-01, -1.9104e+00,  1.9914e+00,  2.6752e+00, -5.1529e+00,  9.1402e+00,  9.5180e+00,  2.2419e+01,  3.0390e+01,  1.2740e+01,  1.3459e+01,  1.7243e+01,  3.6843e+00, -1.9825e+00, -5.3785e+01, -1.4062e+02, -3.6739e+00, -9.0785e+01, -8.1116e+00, -1.3535e+01, -1.

In [56]:
# Alternative expression built from the tiled dQ and dK:
# q * dQ - k * dK, summed over the feature dimension.
tmp_df2 = qs * dQ_pt_t - ks * dK_pt_t
tmp_df2.sum(-1)

tensor([[[ 2.8328e+01,  1.1848e+01, -3.6369e+01, -1.2663e+00, -2.8673e+02,  1.7127e+01,  2.4968e+02,  6.3810e+01, -5.4520e+01,  2.3454e+01, -1.3099e+01,  8.2680e-01,  3.8521e+01, -6.3575e+01,  3.3265e+00,  5.0429e+01, -7.4227e+01, -6.5370e+00,  5.7932e+01,  2.9171e+01, -6.1506e+01,  5.0140e+02,  2.0122e+02,  7.4573e+02, -1.4083e+03, -1.4734e+00,  1.1874e+01, -1.4968e+01,  2.1155e+01, -4.2539e+01, -6.1496e+00,  1.2046e+01, -8.7825e+01, -3.3325e+01,  2.3649e+01,  4.3266e+01,  4.9306e+01,  8.9563e+00,  9.5718e+01, -8.3539e+01,  6.3104e+00,  5.4798e+01,  1.0818e+03, -9.3826e+02, -4.7983e+02,
           6.6722e+00,  7.2832e+01,  1.3350e+02, -8.2461e+02,  2.8403e+01,  8.4079e+02, -1.6311e+01,  2.6489e+01, -1.1355e+01,  3.9370e+00,  4.6647e+00, -5.4528e+00, -1.0555e+01,  2.1244e+01, -2.6912e+01, -3.3198e+00, -2.9990e+01,  2.6362e+00,  1.1228e+01,  1.8584e+01, -1.3318e+01,  1.5357e+01,  2.6280e+01,  9.5755e+01,  7.0987e+01, -1.7006e+02,  2.4510e+02, -2.2579e+02,  5.3121e+00,  1.3614e+01, -1.34

In [57]:
# Relative position (query index - key/value index) for every entry of the
# (qIdx, kvIdx) tile.
qIdx = 0
kvIdx = 0

q_start = qIdx * BLOCK_Q
kv_start = kvIdx * BLOCK_KV
bq_idxes = torch.arange(q_start, q_start + BLOCK_Q, device=DEVICE)
kv_idxes = torch.arange(kv_start, kv_start + BLOCK_KV, device=DEVICE)
idx_mask = bq_idxes.unsqueeze(1) - kv_idxes.unsqueeze(0)

In [58]:
idx_mask

tensor([[  0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31],
        [  1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30],
        [  2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29],
        [  3,   2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28],
        [  4,   3,   2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27],
        [  5,   4,   3,   2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -

In [59]:
# Same relative-position matrix for tile (1, 1), but with the key/value
# indices shifted by one.
qIdx = 1
kvIdx = 1

q_start = qIdx * BLOCK_Q
kv_start = kvIdx * BLOCK_KV
bq_idxes = torch.arange(q_start, q_start + BLOCK_Q, device=DEVICE)
kv_idxes = torch.arange(kv_start + 1, kv_start + BLOCK_KV + 1, device=DEVICE)
idx_mask = bq_idxes.unsqueeze(1) - kv_idxes.unsqueeze(0)
idx_mask

tensor([[31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0],
        [32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1],
        [33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10,  9,  8,  7,  6,  5,  4,  3,  2],
        [34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10,  9,  8,  7,  6,  5,  4,  3],
        [35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10,  9,  8,  7,  6,  5,  4],
        [36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10,  9,  8,  7,  6,  5],
        [37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10,  9,  8,  7,  6],
        [38, 37, 36, 35, 34, 33, 3