In [1]:
import torch

## PyTorch vLSTM backward with group norm tiled (headwise layernorm)

Shows what happens if we fuse the multi-head layernorm with the vLSTM kernel.
Here we compare the tiled impl (template for the Triton & CUDA kernels) with the parallel impl and check for numerical correctness.

In [2]:
%load_ext autoreload
%autoreload 2
# from vlstm_parallel_tiled import vlstm_parallel_tiled
from vlstm_parallel_w_groupnorm import vlstm_parallel_bw_torch_w_groupnorm, vlstm_parallel_fwbw_torch_w_groupnorm, vlstm_parallel_fwbw_torch_w_groupnorm_full
from vlstm_parallel import vlstm_parallel_fw_torch

from vlstm_parallel_w_groupnorm_tiled_bw import mlstm_parallel_w_groupnorm_torch_tiled_bw, vlstm_parallel_w_groupnorm_torch_bw

from ln import MultiHeadLayerNorm

In [3]:
# params
# Problem sizes are kept tiny so that the printed tensors can be inspected by eye.
S = 8  # seq len
B = 1  # batch size
NH = 1  # num heads
DH = 4  # dim per head
DTYPE = torch.float64
# Use the first CUDA device when one is present (unchanged behaviour); otherwise
# fall back to the CPU so that the pure-PyTorch comparisons still run.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
EPS = 0.0  # forwarded as `eps=` to the vlstm forward functions in later cells

In [4]:
# Build all random inputs from one fixed seed.
# The order of the randn calls is significant: each call advances the RNG, so it
# must stay qs, ks, vs, igs, fgs, dH, vecN, vecM to reproduce the same values.
torch.manual_seed(1)

_mat_shape = (B, NH, S, DH)  # per-head matrices
_vec_shape = (B, NH, S, 1)   # one scalar per time step
_factory = dict(device=DEVICE, dtype=DTYPE)

# forward inputs: queries, keys, values and the gate pre-activations
qs = torch.randn(_mat_shape, **_factory)
ks = torch.randn(_mat_shape, **_factory)
vs = torch.randn(_mat_shape, **_factory)
igs = torch.randn(_vec_shape, **_factory)
# igs2 = (1. + torch.arange((B * NH * S), device=DEVICE, dtype=DTYPE)).reshape(B, NH, S, 1)
fgs = torch.randn(_vec_shape, **_factory)

# backward inputs: random stand-ins fed directly to the hand-written backward
dH = torch.randn(_mat_shape, **_factory)
vecN = torch.randn(_vec_shape, **_factory)
vecM = torch.randn(_vec_shape, **_factory)

qs.shape, fgs.shape

(torch.Size([1, 1, 8, 4]), torch.Size([1, 1, 8, 1]))

In [5]:
# Random shift (std 3) added to the normalized outputs before squaring in the
# loss cells, so that the upstream gradient is not just 2 * output.
offset = torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE) * 3.0

In [6]:
# Head-wise layernorm shared by every implementation compared in this notebook.
# Note: it uses its own eps (1e-6), which is independent of the EPS constant.
mh_layernorm = MultiHeadLayerNorm(NH * DH, eps=1e-6)
mh_layernorm = mh_layernorm.to(device=DEVICE, dtype=DTYPE)
mh_layernorm.weight, mh_layernorm.bias

(Parameter containing:
 tensor([0., 0., 0., 0.], device='cuda:0', dtype=torch.float64,
        requires_grad=True),
 None)

## BW parallel with groupnorm

### pytorch

In [7]:
# Independent leaf copies of the inputs for the autograd reference path, so
# that its gradients do not mix with those of the other implementations.
fgs_pt, igs_pt, qs_pt, ks_pt, vs_pt = (
    src.clone().detach().requires_grad_(True) for src in (fgs, igs, qs, ks, vs)
)

In [8]:
# rs = vlstm_fw_torch(
#     queries=qs,
#     keys=ks,
#     values=vs,
#     igate_preact=igs,
#     fgate_preact=fgs,
#     stabilize_rowwise=True,
# )
# rs, rs.shape

In [9]:
# No backward pass has run on the reference clones at this point, so this is
# expected to show nothing (grad is still None).
qs_pt.grad

In [10]:
# Reference forward pass: parallel vlstm formulation followed by the shared
# head-wise layernorm. Gradients for this path come from plain autograd.
fw_inputs_pt = dict(
    queries=qs_pt,
    keys=ks_pt,
    values=vs_pt,
    igate_preact=igs_pt,
    fgate_preact=fgs_pt,
)
rs = vlstm_parallel_fw_torch(**fw_inputs_pt, eps=EPS)
rs_scaled = mh_layernorm(rs)
rs_scaled, rs_scaled.shape

(tensor([[[[-0.2841,  1.3220,  0.3847, -1.4226],
           [ 0.6386,  1.0955, -0.2006, -1.5335],
           [ 0.7039,  1.0709, -1.5130, -0.2619],
           [ 1.6086, -1.0595, -0.5384, -0.0106],
           [ 1.6273, -1.0088, -0.5765, -0.0421],
           [ 0.9918, -1.5975,  0.6775, -0.0719],
           [ 0.7550, -1.5210,  1.0246, -0.2585],
           [ 0.2224, -0.2723,  1.4169, -1.3670]]]], device='cuda:0',
        dtype=torch.float64, grad_fn=<TransposeBackward0>),
 torch.Size([1, 1, 8, 4]))

In [11]:
# Scalar loss sum((out + offset)^2); its backward pass fills the .grad fields
# of the *_pt leaf tensors.
loss_pt = (rs_scaled + offset).pow(2).sum()
loss_pt.backward()

In [12]:
# Layout check of the un-normalized forward output.
rs.shape # (B, NH, S, DH)

torch.Size([1, 1, 8, 4])

In [13]:
# Reference gradient w.r.t. the queries, as produced by autograd.
qs_pt.grad

tensor([[[[-5.5424e-05,  3.6611e-05,  2.4270e-05, -4.3784e-05],
          [ 5.5255e-01, -6.7394e-01,  4.7943e-01,  7.8546e-01],
          [-4.3091e+00,  6.6800e-01,  7.5137e+00,  3.4398e-01],
          [-2.7787e-01,  1.4118e+00, -2.0624e-01, -5.2272e-02],
          [ 9.8784e-01, -4.8033e-01,  1.2925e+00, -9.1624e-01],
          [-4.1501e+00, -4.5760e+00, -1.9159e+00, -4.2997e+00],
          [ 1.3622e+00,  3.1406e-01, -1.6447e+00,  5.1623e-01],
          [-2.5732e+01,  1.0798e+01, -2.3911e+01, -9.2743e+00]]]],
       device='cuda:0', dtype=torch.float64)

In [14]:
# Alias of the raw (pre-layernorm) forward output; the transpose is commented
# out, so rs2 keeps the same layout as rs.
rs2 = rs#.transpose(1, 2)
rs2.shape

torch.Size([1, 1, 8, 4])

In [15]:
# Hand-rolled normalization over the last axis: subtract the mean and divide by
# the population std (unbiased=False). No eps and no affine parameters.
row_mean = rs2.mean(-1, keepdim=True)
row_std = rs2.std(-1, keepdim=True, unbiased=False)
rs3 = (rs2 - row_mean) / row_std
# rs4 = rs3.transpose(1, 2)
rs3, rs3.shape

(tensor([[[[-0.2841,  1.3220,  0.3847, -1.4226],
           [ 0.6386,  1.0955, -0.2006, -1.5336],
           [ 0.7039,  1.0709, -1.5130, -0.2619],
           [ 1.6086, -1.0595, -0.5384, -0.0106],
           [ 1.6273, -1.0088, -0.5765, -0.0421],
           [ 0.9918, -1.5975,  0.6775, -0.0719],
           [ 0.7550, -1.5210,  1.0246, -0.2585],
           [ 0.2224, -0.2723,  1.4169, -1.3670]]]], device='cuda:0',
        dtype=torch.float64, grad_fn=<DivBackward0>),
 torch.Size([1, 1, 8, 4]))

In [16]:
# Manual normalization minus the module output.
# NOTE(review): the small residual is presumably caused by the eps=1e-6 inside
# mh_layernorm, whereas rs3 divides by the plain std -- confirm against
# MultiHeadLayerNorm.
rs3 - rs_scaled

tensor([[[[-7.4993e-07,  3.4896e-06,  1.0154e-06, -3.7551e-06],
          [ 3.4323e-06,  5.8879e-06, -1.0780e-06, -8.2422e-06],
          [ 2.2965e-06,  3.4939e-06, -4.9360e-06, -8.5439e-07],
          [ 2.5082e-06, -1.6520e-06, -8.3956e-07, -1.6605e-08],
          [ 1.5938e-06, -9.8798e-07, -5.6463e-07, -4.1209e-08],
          [ 1.0009e-06, -1.6121e-06,  6.8374e-07, -7.2535e-08],
          [ 5.5659e-07, -1.1214e-06,  7.5538e-07, -1.9061e-07],
          [ 3.9179e-06, -4.7970e-06,  2.4964e-05, -2.4085e-05]]]],
       device='cuda:0', dtype=torch.float64, grad_fn=<SubBackward0>)

### own backward

In [17]:
# Fresh leaf copies of the inputs for the hand-written backward ("own backward"),
# so its gradients accumulate separately from the autograd reference.
fgs_obw, igs_obw, qs_obw, ks_obw, vs_obw = (
    src.clone().detach().requires_grad_(True) for src in (fgs, igs, qs, ks, vs)
)

In [18]:
# Forward pass through the combined fw+bw implementation ("own backward").
# Besides the hidden states it returns two further tensors (var_b, var_m) that
# are inspected in a later cell.
hs, var_b, var_m = vlstm_parallel_fwbw_torch_w_groupnorm_full(
    queries=qs_obw,
    keys=ks_obw,
    values=vs_obw,
    igate_preact=igs_obw,
    fgate_preact=fgs_obw,
    eps=EPS,
)
# Same layernorm module as on the reference path, so the outputs are directly
# comparable. (A bare `hs, hs.shape` used to precede this line; only the last
# expression of a cell is rendered, so it was dead code and has been removed.)
hs_scaled = mh_layernorm(hs)
hs_scaled, hs_scaled.shape

(tensor([[[[-0.2841,  1.3220,  0.3847, -1.4226],
           [ 0.6386,  1.0955, -0.2006, -1.5335],
           [ 0.7039,  1.0709, -1.5130, -0.2619],
           [ 1.6086, -1.0595, -0.5384, -0.0106],
           [ 1.6273, -1.0088, -0.5765, -0.0421],
           [ 0.9918, -1.5975,  0.6775, -0.0719],
           [ 0.7550, -1.5210,  1.0246, -0.2585],
           [ 0.2224, -0.2723,  1.4169, -1.3670]]]], device='cuda:0',
        dtype=torch.float64, grad_fn=<TransposeBackward0>),
 torch.Size([1, 1, 8, 4]))

In [19]:
# Own implementation minus reference for the normalized forward output;
# expected to be all zeros.
hs_scaled - rs_scaled

tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]], device='cuda:0', dtype=torch.float64,
       grad_fn=<SubBackward0>)

In [20]:
# Same scalar loss as on the reference path, now driving the hand-written backward.
# NOTE(review): backward() returns None, so any tensor shown beneath this cell
# is presumably printed from inside the custom backward -- confirm.
loss_obw = (hs_scaled + offset).pow(2).sum()
loss_obw.backward()

tensor([[[[ 0.0000e+00],
          [ 0.0000e+00],
          [ 0.0000e+00],
          [ 0.0000e+00],
          [-4.0645e-06],
          [ 6.1824e-06],
          [ 1.0703e-05],
          [ 0.0000e+00]]]], device='cuda:0', dtype=torch.float64)


In [21]:
# Side-by-side view of the query gradients: own backward first, autograd
# reference second.
qs_obw.grad, qs_pt.grad

(tensor([[[[-5.5424e-05,  3.6611e-05,  2.4270e-05, -4.3784e-05],
           [ 5.5255e-01, -6.7394e-01,  4.7943e-01,  7.8546e-01],
           [-4.3091e+00,  6.6800e-01,  7.5137e+00,  3.4398e-01],
           [-2.7787e-01,  1.4118e+00, -2.0624e-01, -5.2272e-02],
           [ 9.8784e-01, -4.8034e-01,  1.2925e+00, -9.1625e-01],
           [-4.1501e+00, -4.5760e+00, -1.9159e+00, -4.2997e+00],
           [ 1.3622e+00,  3.1406e-01, -1.6447e+00,  5.1623e-01],
           [-2.5732e+01,  1.0798e+01, -2.3911e+01, -9.2743e+00]]]],
        device='cuda:0', dtype=torch.float64),
 tensor([[[[-5.5424e-05,  3.6611e-05,  2.4270e-05, -4.3784e-05],
           [ 5.5255e-01, -6.7394e-01,  4.7943e-01,  7.8546e-01],
           [-4.3091e+00,  6.6800e-01,  7.5137e+00,  3.4398e-01],
           [-2.7787e-01,  1.4118e+00, -2.0624e-01, -5.2272e-02],
           [ 9.8784e-01, -4.8033e-01,  1.2925e+00, -9.1624e-01],
           [-4.1501e+00, -4.5760e+00, -1.9159e+00, -4.2997e+00],
           [ 1.3622e+00,  3.1406e-01, -1

In [22]:
# Compare |var_b| against exp(-var_m) row by row.
# NOTE(review): presumably these are the two candidates of the normalizer
# max(|b|, exp(-m)); the rows flagged True are also the rows where the query
# gradients deviate from the reference -- confirm against the implementation.
abs_b = var_b.abs()
exp_neg_m = torch.exp(-var_m)
abs_b, exp_neg_m, abs_b > exp_neg_m

(tensor([[[[0.2209],
           [0.3124],
           [0.8088],
           [0.3172],
           [2.6139],
           [1.0345],
           [1.5021],
           [0.4596]]]], device='cuda:0', dtype=torch.float64,
        grad_fn=<AbsBackward0>),
 tensor([[[[0.4257],
           [1.0288],
           [1.6999],
           [0.3699],
           [0.5299],
           [0.7055],
           [1.1781],
           [0.5167]]]], device='cuda:0', dtype=torch.float64,
        grad_fn=<ExpBackward0>),
 tensor([[[[False],
           [False],
           [False],
           [False],
           [ True],
           [ True],
           [ True],
           [False]]]], device='cuda:0'))

In [23]:
# Element-wise gradient differences (reference minus own backward) for q, k, v.
qs_pt.grad - qs_obw.grad, ks_pt.grad - ks_obw.grad, vs_pt.grad - vs_obw.grad

(tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
           [-2.5606e-06,  1.4259e-05, -3.7683e-07,  6.6377e-06],
           [ 7.3301e-06, -4.2211e-06,  1.7718e-06,  1.7949e-06],
           [ 6.4459e-06, -5.9048e-06,  1.1989e-05,  1.4505e-06],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]]],
        device='cuda:0', dtype=torch.float64),
 tensor([[[[-1.3196e-07, -1.1467e-07, -7.7522e-08, -1.3573e-07],
           [-2.0452e-07, -1.7772e-07, -1.2015e-07, -2.1036e-07],
           [-8.9864e-08, -7.8086e-08, -5.2791e-08, -9.2430e-08],
           [-7.8008e-06, -6.7784e-06, -4.5826e-06, -8.0236e-06],
           [-1.4118e-06, -1.2267e-06, -8.2935e-07, -1.4521e-06],
           [-7.6242e-06,  6.1568e-06, -4.1334e-06, -9.8372e-06],
           [-6.9098e-06,  5.6214e-06, -8

In [24]:
# Element-wise gradient differences (reference minus own backward) for the
# forget-gate and input-gate pre-activations.
fgs_pt.grad - fgs_obw.grad, igs_pt.grad - igs_obw.grad

(tensor([[[[ 6.0959e-16],
           [ 4.0919e-08],
           [ 5.5045e-09],
           [ 2.2432e-07],
           [ 7.5696e-06],
           [-2.4253e-06],
           [-1.4291e-05],
           [-1.1102e-16]]]], device='cuda:0', dtype=torch.float64),
 tensor([[[[ 5.5821e-08],
           [-4.1876e-08],
           [ 2.4110e-07],
           [ 2.4811e-05],
           [-9.3009e-07],
           [-2.9778e-05],
           [ 2.6454e-06],
           [-3.5527e-15]]]], device='cuda:0', dtype=torch.float64))

In [25]:
# Spot check on the query gradients only, with absolute and relative
# tolerances of 1e-5.
torch.allclose(qs_pt.grad, qs_obw.grad, atol=1e-5, rtol=1e-5)

True

In [26]:
# Summary of reference (autograd) vs. own backward for every input gradient.
atol = 1e-5
rtol = 1e-5


def _grads_match(ref, own):
    """Return True if the .grad tensors of `ref` and `own` agree within atol/rtol."""
    return torch.allclose(ref.grad, own.grad, atol=atol, rtol=rtol)


# The forward comparison uses torch.allclose's default tolerances.
print(f"Forward match: {torch.allclose(hs_scaled, rs_scaled)}")
print(f"qs match: {_grads_match(qs_pt, qs_obw)}")
print(f"ks match: {_grads_match(ks_pt, ks_obw)}")
print(f"vs match: {_grads_match(vs_pt, vs_obw)}")
print(f"fgate_preacts match: {_grads_match(fgs_pt, fgs_obw)}")
print(f"igate_preacts match: {_grads_match(igs_pt, igs_obw)}")

Forward match: True
qs match: True
ks match: True
vs match: True
fgate_preacts match: True
igate_preacts match: True


In [27]:
## Conclusion:
# dividing we get the same gradients; the remaining error of ~1e-5 is due to numerical precision

In [28]:
# tensor([[[[-3.7828e-17, -2.8809e-17, -1.0405e-16,  4.7115e-17],
#           [-1.1102e-16, -4.4409e-16, -8.8818e-16,  4.4409e-16],
#           [ 8.8818e-16, -8.8818e-16,  0.0000e+00, -8.3267e-17],
#           [ 1.2490e-16, -5.5511e-17,  1.3878e-17, -1.7347e-17],
#           [ 1.8937e-06, -1.0558e-05, -2.7356e-06,  1.5993e-05],
#           [ 4.4409e-16,  0.0000e+00,  1.3323e-15,  0.0000e+00],
#           [-6.0601e-07, -9.2155e-07,  4.0914e-06, -5.0413e-07],
#           [-1.2725e-06,  1.7485e-06,  5.4959e-06,  8.2998e-07]]]],
#        device='cuda:0', dtype=torch.float64)

### own backward2

Reimplementation using separate functions for fw and bw, which serve as ground truth for the kernel impl.
They should match own backward (1) exactly.

In [29]:
# Fresh leaf copies of the inputs for the second hand-written implementation
# (separate fw and bw functions).
fgs_obw2, igs_obw2, qs_obw2, ks_obw2, vs_obw2 = (
    src.clone().detach().requires_grad_(True) for src in (fgs, igs, qs, ks, vs)
)

In [30]:
# Forward pass through the second own implementation (separate fw and bw
# functions). It returns the hidden states plus two further tensors.
hs2, var_b2, var_m2 = vlstm_parallel_fwbw_torch_w_groupnorm(
    queries=qs_obw2,
    keys=ks_obw2,
    values=vs_obw2,
    igate_preact=igs_obw2,
    fgate_preact=fgs_obw2,
    eps=EPS,
)
# Same layernorm module as on the other paths. (A bare `hs2, hs2.shape` used to
# precede this line; only the last expression of a cell is rendered, so it was
# dead code and has been removed.)
hs_scaled2 = mh_layernorm(hs2)
hs_scaled2, hs_scaled2.shape

(tensor([[[[-0.2841,  1.3220,  0.3847, -1.4226],
           [ 0.6386,  1.0955, -0.2006, -1.5335],
           [ 0.7039,  1.0709, -1.5130, -0.2619],
           [ 1.6086, -1.0595, -0.5384, -0.0106],
           [ 1.6273, -1.0088, -0.5765, -0.0421],
           [ 0.9918, -1.5975,  0.6775, -0.0719],
           [ 0.7550, -1.5210,  1.0246, -0.2585],
           [ 0.2224, -0.2723,  1.4169, -1.3670]]]], device='cuda:0',
        dtype=torch.float64, grad_fn=<TransposeBackward0>),
 torch.Size([1, 1, 8, 4]))

In [31]:
# The two own implementations should give the same normalized forward output;
# expected to be all zeros.
hs_scaled - hs_scaled2

tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]], device='cuda:0', dtype=torch.float64,
       grad_fn=<SubBackward0>)

In [32]:
# Same scalar loss as before; fills the .grad fields of the *_obw2 leaf tensors.
loss_obw2 = (hs_scaled2 + offset).pow(2).sum()
loss_obw2.backward()

In [33]:
# Gradient differences between the two own implementations for q, k, v.
qs_obw.grad - qs_obw2.grad, ks_obw.grad - ks_obw2.grad, vs_obw.grad - vs_obw2.grad

(tensor([[[[-1.3020e-10,  8.6005e-11,  5.7015e-11, -1.0286e-10],
           [ 5.3706e-07, -6.5505e-07,  4.6599e-07,  7.6344e-07],
           [-2.5350e-06,  3.9298e-07,  4.4202e-06,  2.0236e-07],
           [-7.5124e-07,  3.8170e-06, -5.5758e-07, -1.4132e-07],
           [ 3.7791e-07, -1.8376e-07,  4.9448e-07, -3.5052e-07],
           [-4.0115e-06, -4.4232e-06, -1.8519e-06, -4.1562e-06],
           [ 9.0684e-07,  2.0908e-07, -1.0949e-06,  3.4366e-07],
           [-4.9801e-05,  2.0900e-05, -4.6278e-05, -1.7950e-05]]]],
        device='cuda:0', dtype=torch.float64),
 tensor([[[[ 6.1895e-06, -2.2341e-06,  1.7210e-06,  2.6788e-06],
           [ 3.8116e-06, -2.7216e-06,  2.7156e-06,  6.0842e-06],
           [ 8.4303e-08,  4.1735e-07, -5.1106e-07, -2.3957e-07],
           [-8.9174e-06, -1.0295e-05, -3.8709e-06,  2.0779e-05],
           [ 1.4558e-06,  2.0089e-06,  4.3810e-07, -1.9617e-06],
           [-1.6445e-05, -2.9992e-05, -1.6734e-05,  5.6504e-05],
           [-1.6525e-05, -2.8659e-05, -1

In [34]:
# Gradient differences between the two own implementations for the forget-gate
# and input-gate pre-activations.
fgs_obw.grad - fgs_obw2.grad, igs_obw.grad - igs_obw2.grad

(tensor([[[[ 0.0000e+00],
           [-2.8527e-06],
           [-7.4537e-07],
           [-5.4030e-07],
           [ 1.7050e-07],
           [ 1.9369e-06],
           [ 1.4342e-05],
           [ 1.2706e-06]]]], device='cuda:0', dtype=torch.float64),
 tensor([[[[-3.8915e-06],
           [ 2.0035e-06],
           [ 1.2739e-06],
           [ 1.1791e-06],
           [ 2.3387e-06],
           [ 2.3987e-05],
           [-2.0978e-05],
           [-5.9123e-06]]]], device='cuda:0', dtype=torch.float64))

In [35]:
# Query gradients of both own implementations, followed by their difference.
qs_obw.grad, qs_obw2.grad, qs_obw.grad - qs_obw2.grad

(tensor([[[[-5.5424e-05,  3.6611e-05,  2.4270e-05, -4.3784e-05],
           [ 5.5255e-01, -6.7394e-01,  4.7943e-01,  7.8546e-01],
           [-4.3091e+00,  6.6800e-01,  7.5137e+00,  3.4398e-01],
           [-2.7787e-01,  1.4118e+00, -2.0624e-01, -5.2272e-02],
           [ 9.8784e-01, -4.8034e-01,  1.2925e+00, -9.1625e-01],
           [-4.1501e+00, -4.5760e+00, -1.9159e+00, -4.2997e+00],
           [ 1.3622e+00,  3.1406e-01, -1.6447e+00,  5.1623e-01],
           [-2.5732e+01,  1.0798e+01, -2.3911e+01, -9.2743e+00]]]],
        device='cuda:0', dtype=torch.float64),
 tensor([[[[-5.5424e-05,  3.6611e-05,  2.4270e-05, -4.3784e-05],
           [ 5.5255e-01, -6.7394e-01,  4.7943e-01,  7.8546e-01],
           [-4.3091e+00,  6.6800e-01,  7.5137e+00,  3.4398e-01],
           [-2.7786e-01,  1.4118e+00, -2.0623e-01, -5.2272e-02],
           [ 9.8784e-01, -4.8034e-01,  1.2925e+00, -9.1625e-01],
           [-4.1501e+00, -4.5760e+00, -1.9159e+00, -4.2997e+00],
           [ 1.3622e+00,  3.1406e-01, -1

In [36]:
# Conclusion: the impls match; the maximum error is on the order of 1e-5 and is due to numerical precision

## BW parallel TILED with groupnorm

In [37]:
# Call the parallel (non-tiled) hand-written backward directly. No autograd is
# involved: the inputs are the plain tensors, and dH / vecN / vecM are the
# random stand-ins created together with the forward inputs.
grads_pt_p = vlstm_parallel_w_groupnorm_torch_bw(
    matDeltaHtilde=dH,
    matQ=qs,
    matK=ks,
    matV=vs,
    vecN=vecN,
    vecM=vecM,
    vecI=igs,
    vecF=fgs,
)
dQ_pt_p, dK_pt_p, dV_pt_p, dI_pt_p, dF_pt_p = grads_pt_p

In [38]:
# Show the five gradients returned by the parallel (non-tiled) backward.
dQ_pt_p, dK_pt_p, dV_pt_p, dI_pt_p, dF_pt_p

(tensor([[[[ 7.5400e-01, -4.9806e-01, -3.3017e-01,  5.9565e-01],
           [-6.7514e-02,  4.3178e-02,  3.2875e-02, -5.1733e-02],
           [ 1.1512e-02, -3.2063e-02,  8.4872e-02,  1.1490e-01],
           [ 7.4273e+00, -5.7824e+01, -7.5553e-01, -2.4850e+01],
           [-2.6519e-01,  3.6867e+00,  5.8413e-01,  1.3191e+00],
           [-2.7861e+00, -1.4612e-01, -8.6080e-01, -1.5563e+00],
           [ 1.9155e+01, -6.8366e+00, -7.8849e+00,  5.2391e+00],
           [-1.4842e+00, -3.2384e+00,  1.4640e+00, -4.3700e-01]]]],
        device='cuda:0', dtype=torch.float64),
 tensor([[[[  1.4249,  -1.1815,  -0.8963,   0.7549],
           [  1.7919,   0.0621,  -0.6536,   1.5702],
           [  0.0869,   0.0401,  -0.0827,   0.0721],
           [ 13.3159,   0.8007, -10.6145,   9.5685],
           [ -0.1685,   1.3196,  -0.6682,  -0.8364],
           [ -2.3977,   2.3395,  -5.0113,  -6.0130],
           [  8.4134,  -7.0458,  10.1382,  14.5689],
           [  1.0261,   1.6375,   0.8696,  -3.1824]]]], dev