In [1]:
import torch

## PyTorch vLSTM forward with group norm (headwise layernorm)

This notebook shows what happens when the multi-head layernorm is fused with the vLSTM kernel.

In [2]:
%load_ext autoreload
%autoreload 2
# from vlstm_parallel_tiled import vlstm_parallel_tiled
from vlstm_parallel import vlstm_parallel_fw_torch
from vlstm_parallel_w_groupnorm import vlstm_parallel_fwbw_torch_w_groupnorm, vlstm_parallel_fwbw_torch_w_groupnorm_full
from ln import MultiHeadLayerNorm

In [3]:
# params
# Problem sizes are kept tiny so that every tensor can be printed in full.
S = 8  # seq len
B = 1  # batch size
NH = 1  # num heads
DH = 4  # dim per head
DTYPE = torch.float64
# Use the first GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs end to end on machines without CUDA.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
EPS = 0.0  # forwarded as `eps` to the vLSTM forward functions called below

In [4]:
# Build the random test inputs: queries / keys / values plus the input- and
# forget-gate pre-activations. The seed is fixed and the draw order
# (qs, ks, vs, igs, fgs) is kept, so the values are reproducible.
torch.manual_seed(1)
qkv_shape = (B, NH, S, DH)
gate_shape = (B, NH, S, 1)
tensor_opts = dict(device=DEVICE, dtype=DTYPE)
qs = torch.randn(qkv_shape, **tensor_opts)
ks = torch.randn(qkv_shape, **tensor_opts)
vs = torch.randn(qkv_shape, **tensor_opts)
igs = torch.rand(gate_shape, **tensor_opts)
# Deterministic ramp 1, 2, ..., B*NH*S shaped like the gate tensors.
igs2 = (1.0 + torch.arange(B * NH * S, **tensor_opts)).reshape(B, NH, S, 1)
fgs = torch.rand(gate_shape, **tensor_opts)
(qs.shape, fgs.shape)

(torch.Size([1, 1, 8, 4]), torch.Size([1, 1, 8, 1]))

In [5]:
# Standard-normal noise scaled by 3; it is added to the normalized outputs
# before squaring in the loss cells further below.
offset = torch.randn((B, NH, S, DH), device=DEVICE, dtype=DTYPE) * 3.0

In [6]:
# Multi-head layernorm built with NH * DH as its first argument and moved to the
# working device / dtype. Note that it uses eps=1e-6, whereas EPS (passed to the
# vLSTM functions) is 0.0.
mh_layernorm = MultiHeadLayerNorm(NH * DH, eps=1e-6).to(
    device=DEVICE,
    dtype=DTYPE,
)
(mh_layernorm.weight, mh_layernorm.bias)

(Parameter containing:
 tensor([0., 0., 0., 0.], device='cuda:0', dtype=torch.float64,
        requires_grad=True),
 None)

### pytorch

In [7]:
# Independent leaf copies of the inputs for the PyTorch-autograd reference path,
# so that gradients accumulate on these tensors only.
fgs_pt, igs_pt, qs_pt, ks_pt, vs_pt = (
    src.clone().detach().requires_grad_(True) for src in (fgs, igs, qs, ks, vs)
)

In [8]:
# rs = vlstm_fw_torch(
#     queries=qs,
#     keys=ks,
#     values=vs,
#     igate_preact=igs,
#     fgate_preact=fgs,
#     stabilize_rowwise=True,
# )
# rs, rs.shape

In [9]:
# Sanity check before any backward pass: the fresh leaf has no gradient yet,
# so this cell displays nothing.
qs_pt.grad

In [10]:
# Reference path: plain parallel vLSTM forward, followed by the multi-head
# layernorm applied as a separate step.
rs = vlstm_parallel_fw_torch(
    queries=qs_pt, keys=ks_pt, values=vs_pt,
    igate_preact=igs_pt, fgate_preact=fgs_pt,
    eps=EPS,
)
rs_scaled = mh_layernorm(rs)
(rs_scaled, rs_scaled.shape)

(tensor([[[[-0.2841,  1.3220,  0.3847, -1.4225],
           [ 0.7400,  1.0432, -0.2694, -1.5138],
           [ 0.4508,  1.3221, -1.3749, -0.3980],
           [ 1.3565, -0.7354, -1.1550,  0.5340],
           [ 1.5541, -0.6481, -1.0672,  0.1612],
           [ 1.1217, -1.5834,  0.4839, -0.0223],
           [ 0.5463, -1.4683,  1.2095, -0.2875],
           [ 1.6076, -0.0882, -1.1156, -0.4038]]]], device='cuda:0',
        dtype=torch.float64, grad_fn=<TransposeBackward0>),
 torch.Size([1, 1, 8, 4]))

In [11]:
# Scalar loss: sum of squares of the shifted, normalized output. Backpropagating
# it fills the .grad fields of the *_pt leaf tensors.
torch.sum((rs_scaled + offset) ** 2).backward()

In [12]:
# Shape of the un-normalized forward output.
rs.shape # (B, NH, S, DH)

torch.Size([1, 1, 8, 4])

In [13]:
# Query gradient produced by PyTorch autograd; used as the reference in later comparisons.
qs_pt.grad

tensor([[[[-7.6576e-04,  5.0582e-04,  3.3532e-04, -6.0494e-04],
          [-7.7297e-01,  9.4263e-01, -6.7030e-01, -1.0986e+00],
          [ 4.6686e-01, -1.0144e-01, -7.7629e-01, -7.6172e-02],
          [ 3.4602e+00, -8.3292e+00, -8.0756e+00, -8.4937e+00],
          [ 1.3444e+00, -4.4481e-01,  1.1625e+00, -1.3412e+00],
          [-1.9742e+01, -6.0909e+00, -1.7204e+01, -8.8154e+00],
          [-7.0999e+00,  1.6160e+00,  7.0835e+00, -5.7153e-02],
          [ 5.5990e+00, -2.8259e-01, -1.5718e+01, -2.6351e+00]]]],
       device='cuda:0', dtype=torch.float64)

In [14]:
# rs2 is just an alias of rs: the transpose is commented out, so the
# layout stays as it was.
rs2 = rs#.transpose(1, 2)
rs2.shape

torch.Size([1, 1, 8, 4])

In [15]:
# Hand-written normalization over the last dimension: subtract the mean and
# divide by the biased (population) standard deviation. No eps is added here,
# unlike mh_layernorm, which was constructed with eps=1e-6.
rs3 = (rs2 - rs2.mean(dim=-1, keepdim=True)) / torch.std(
    rs2, dim=-1, keepdim=True, unbiased=False
)
# rs4 = rs3.transpose(1, 2)
(rs3, rs3.shape)

(tensor([[[[-0.2841,  1.3220,  0.3847, -1.4226],
           [ 0.7400,  1.0432, -0.2694, -1.5138],
           [ 0.4508,  1.3221, -1.3749, -0.3980],
           [ 1.3565, -0.7354, -1.1550,  0.5340],
           [ 1.5541, -0.6481, -1.0672,  0.1612],
           [ 1.1217, -1.5834,  0.4839, -0.0223],
           [ 0.5463, -1.4683,  1.2095, -0.2875],
           [ 1.6076, -0.0882, -1.1156, -0.4038]]]], device='cuda:0',
        dtype=torch.float64, grad_fn=<DivBackward0>),
 torch.Size([1, 1, 8, 4]))

In [16]:
# Element-wise difference between the hand-written normalization (no eps) and
# the mh_layernorm output (eps=1e-6).
rs3 - rs_scaled

tensor([[[[-3.8469e-06,  1.7901e-05,  5.2088e-06, -1.9262e-05],
          [ 1.0619e-06,  1.4970e-06, -3.8655e-07, -2.1723e-06],
          [ 2.6418e-07,  7.7482e-07, -8.0577e-07, -2.3323e-07],
          [ 5.2545e-06, -2.8487e-06, -4.4742e-06,  2.0684e-06],
          [ 9.5230e-07, -3.9714e-07, -6.5392e-07,  9.8760e-08],
          [ 8.9679e-07, -1.2658e-06,  3.8687e-07, -1.7821e-08],
          [ 2.1787e-07, -5.8553e-07,  4.8230e-07, -1.1463e-07],
          [ 6.1161e-06, -3.3540e-07, -4.2443e-06, -1.5364e-06]]]],
       device='cuda:0', dtype=torch.float64, grad_fn=<SubBackward0>)

### own backward

In [17]:
# Independent leaf copies of the inputs for the first custom-backward path,
# so its gradients do not mix with those of the reference path.
fgs_obw, igs_obw, qs_obw, ks_obw, vs_obw = (
    src.clone().detach().requires_grad_(True) for src in (fgs, igs, qs, ks, vs)
)

In [18]:
# Forward pass through the "full" group-norm variant under test. Besides the
# hidden states it returns two extra tensors (var_b, var_m) that are inspected
# in a later cell.
hs, var_b, var_m = vlstm_parallel_fwbw_torch_w_groupnorm_full(
    queries=qs_obw,
    keys=ks_obw,
    values=vs_obw,
    igate_preact=igs_obw,
    fgate_preact=fgs_obw,
    eps=EPS,
)
# Apply the same multi-head layernorm as in the reference path; only this last
# expression is displayed by the notebook.
hs_scaled = mh_layernorm(hs)
hs_scaled, hs_scaled.shape

(tensor([[[[-0.2841,  1.3220,  0.3847, -1.4225],
           [ 0.7400,  1.0432, -0.2694, -1.5138],
           [ 0.4508,  1.3221, -1.3749, -0.3980],
           [ 1.3565, -0.7354, -1.1550,  0.5340],
           [ 1.5541, -0.6481, -1.0672,  0.1612],
           [ 1.1217, -1.5834,  0.4839, -0.0223],
           [ 0.5463, -1.4683,  1.2095, -0.2875],
           [ 1.6076, -0.0882, -1.1156, -0.4038]]]], device='cuda:0',
        dtype=torch.float64, grad_fn=<TransposeBackward0>),
 torch.Size([1, 1, 8, 4]))

In [19]:
# Forward check: difference between the custom path and the reference path
# after layernorm.
hs_scaled - rs_scaled

tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]], device='cuda:0', dtype=torch.float64,
       grad_fn=<SubBackward0>)

In [20]:
# Same scalar loss as for the reference path, now backpropagated through the
# custom implementation to fill the .grad fields of the *_obw leaf tensors.
torch.sum((hs_scaled + offset) ** 2).backward()

tensor([[[[0.0000e+00],
          [0.0000e+00],
          [7.1797e-06],
          [0.0000e+00],
          [1.2116e-05],
          [3.1064e-05],
          [1.6215e-05],
          [0.0000e+00]]]], device='cuda:0', dtype=torch.float64)


In [21]:
# Query gradients side by side: custom backward first, PyTorch autograd second.
qs_obw.grad, qs_pt.grad

(tensor([[[[-7.6576e-04,  5.0582e-04,  3.3532e-04, -6.0494e-04],
           [-7.7297e-01,  9.4263e-01, -6.7030e-01, -1.0986e+00],
           [ 4.6686e-01, -1.0144e-01, -7.7628e-01, -7.6160e-02],
           [ 3.4602e+00, -8.3292e+00, -8.0756e+00, -8.4937e+00],
           [ 1.3444e+00, -4.4479e-01,  1.1625e+00, -1.3412e+00],
           [-1.9742e+01, -6.0909e+00, -1.7204e+01, -8.8154e+00],
           [-7.0999e+00,  1.6160e+00,  7.0835e+00, -5.7152e-02],
           [ 5.5990e+00, -2.8259e-01, -1.5718e+01, -2.6351e+00]]]],
        device='cuda:0', dtype=torch.float64),
 tensor([[[[-7.6576e-04,  5.0582e-04,  3.3532e-04, -6.0494e-04],
           [-7.7297e-01,  9.4263e-01, -6.7030e-01, -1.0986e+00],
           [ 4.6686e-01, -1.0144e-01, -7.7629e-01, -7.6172e-02],
           [ 3.4602e+00, -8.3292e+00, -8.0756e+00, -8.4937e+00],
           [ 1.3444e+00, -4.4481e-01,  1.1625e+00, -1.3412e+00],
           [-1.9742e+01, -6.0909e+00, -1.7204e+01, -8.8154e+00],
           [-7.0999e+00,  1.6160e+00,  7

In [22]:
# Per-position comparison of |var_b| against exp(-var_m), shown together with
# the resulting boolean mask.
# NOTE(review): presumably var_b is the normalizer term and var_m the
# stabilizer of the vLSTM — confirm against vlstm_parallel_w_groupnorm.
(
    torch.abs(var_b),
    torch.exp(-var_m),
    torch.abs(var_b) > torch.exp(-var_m),
)

(tensor([[[[0.2209],
           [0.2700],
           [1.5735],
           [0.5737],
           [1.5030],
           [0.8177],
           [0.9032],
           [0.1300]]]], device='cuda:0', dtype=torch.float64,
        grad_fn=<AbsBackward0>),
 tensor([[[[0.9641],
           [0.4819],
           [0.6707],
           [0.6399],
           [0.5793],
           [0.4735],
           [0.3951],
           [0.6694]]]], device='cuda:0', dtype=torch.float64,
        grad_fn=<ExpBackward0>),
 tensor([[[[False],
           [False],
           [ True],
           [False],
           [ True],
           [ True],
           [ True],
           [False]]]], device='cuda:0'))

In [23]:
# Gradient differences (autograd reference minus custom backward) for q, k and v.
(
    qs_pt.grad - qs_obw.grad,
    ks_pt.grad - ks_obw.grad,
    vs_pt.grad - vs_obw.grad,
)

(tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
           [-1.5314e-06,  1.1898e-06, -4.0425e-06, -1.1652e-05],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 6.4882e-06, -1.3727e-05,  6.1535e-06, -1.6513e-05],
           [ 3.1307e-05, -1.2041e-05,  1.1293e-05,  1.6676e-06],
           [ 2.7021e-06, -1.6586e-06,  1.1334e-05, -1.3535e-06],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]]],
        device='cuda:0', dtype=torch.float64),
 tensor([[[[ 1.6914e-06,  8.2518e-07,  1.4939e-06,  1.6078e-06],
           [ 5.2692e-06,  2.5707e-06,  4.6539e-06,  5.0088e-06],
           [ 6.8855e-06,  3.3593e-06,  6.0815e-06,  6.5453e-06],
           [ 5.3227e-07,  1.3101e-05,  4.5358e-06,  8.8716e-07],
           [ 8.7948e-07,  2.1647e-05,  7.4946e-06,  1.4659e-06],
           [-1.2351e-05,  9.8664e-06,  5.1246e-06, -8.8750e-06],
           [-6.2944e-06,  5.1208e-06, -7

In [24]:
# Gradient differences (autograd reference minus custom backward) for the
# forget-gate and input-gate pre-activations.
(
    fgs_pt.grad - fgs_obw.grad,
    igs_pt.grad - igs_obw.grad,
)

(tensor([[[[-6.0397e-16],
           [-2.7182e-07],
           [ 7.4923e-07],
           [-5.8534e-07],
           [-1.2563e-05],
           [-7.5149e-07],
           [-5.3494e-06],
           [ 2.0817e-16]]]], device='cuda:0', dtype=torch.float64),
 tensor([[[[-7.5963e-07],
           [ 2.4240e-06],
           [-2.0677e-05],
           [-3.6671e-05],
           [ 8.3042e-06],
           [-3.4175e-05],
           [ 2.4098e-06],
           [ 0.0000e+00]]]], device='cuda:0', dtype=torch.float64))

In [25]:
# Tolerance check of the query gradients: reference vs. custom backward.
torch.allclose(qs_pt.grad, qs_obw.grad, atol=1e-5, rtol=1e-5)

False

In [26]:
# Summary report: forward outputs and every input gradient are compared between
# the autograd reference (*_pt) and the custom backward (*_obw).
atol = 1e-5
rtol = 1e-5
print(f"Forward match: {torch.allclose(hs_scaled, rs_scaled)}")
grad_pairs = (
    ("qs", qs_pt, qs_obw),
    ("ks", ks_pt, ks_obw),
    ("vs", vs_pt, vs_obw),
    ("fgate_preacts", fgs_pt, fgs_obw),
    ("igate_preacts", igs_pt, igs_obw),
)
for label, ref_t, own_t in grad_pairs:
    print(f"{label} match: {torch.allclose(ref_t.grad, own_t.grad, atol=atol, rtol=rtol)}")

Forward match: True
qs match: False
ks match: True
vs match: True
fgate_preacts match: True
igate_preacts match: True


In [27]:
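## Conclusion: 
# after dividing we get the same gradients up to an error of about 1e-5, which is attributed to numerical precision
# NOTE(review): the nonzero rows of qs_pt.grad - qs_obw.grad are exactly the rows where
# var_b.abs() > exp(-var_m), and the qs check above fails at atol=rtol=1e-5. Since float64
# rounding is far smaller than 1e-5, the gap may have another cause (e.g. the layernorm
# eps=1e-6) - TODO confirm.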

In [28]:
# tensor([[[[-3.7828e-17, -2.8809e-17, -1.0405e-16,  4.7115e-17],
#           [-1.1102e-16, -4.4409e-16, -8.8818e-16,  4.4409e-16],
#           [ 8.8818e-16, -8.8818e-16,  0.0000e+00, -8.3267e-17],
#           [ 1.2490e-16, -5.5511e-17,  1.3878e-17, -1.7347e-17],
#           [ 1.8937e-06, -1.0558e-05, -2.7356e-06,  1.5993e-05],
#           [ 4.4409e-16,  0.0000e+00,  1.3323e-15,  0.0000e+00],
#           [-6.0601e-07, -9.2155e-07,  4.0914e-06, -5.0413e-07],
#           [-1.2725e-06,  1.7485e-06,  5.4959e-06,  8.2998e-07]]]],
#        device='cuda:0', dtype=torch.float64)

### own backward2

Reimplementation using separate functions for the forward and backward passes, which serve as the ground truth for the kernel implementation.
The results should match own backward (1) exactly.

In [29]:
# Independent leaf copies of the inputs for the second custom-backward path
# (separate forward / backward functions).
fgs_obw2, igs_obw2, qs_obw2, ks_obw2, vs_obw2 = (
    src.clone().detach().requires_grad_(True) for src in (fgs, igs, qs, ks, vs)
)

In [30]:
# Forward pass through the second group-norm variant (separate forward and
# backward functions). It returns the hidden states plus two extra tensors.
hs2, var_b2, var_m2 = vlstm_parallel_fwbw_torch_w_groupnorm(
    queries=qs_obw2,
    keys=ks_obw2,
    values=vs_obw2,
    igate_preact=igs_obw2,
    fgate_preact=fgs_obw2,
    eps=EPS,
)
# Apply the same multi-head layernorm as in the other paths; only this last
# expression is displayed by the notebook.
hs_scaled2 = mh_layernorm(hs2)
hs_scaled2, hs_scaled2.shape

(tensor([[[[-0.2841,  1.3220,  0.3847, -1.4225],
           [ 0.7400,  1.0432, -0.2694, -1.5138],
           [ 0.4508,  1.3221, -1.3749, -0.3980],
           [ 1.3565, -0.7354, -1.1550,  0.5340],
           [ 1.5541, -0.6481, -1.0672,  0.1612],
           [ 1.1217, -1.5834,  0.4839, -0.0223],
           [ 0.5463, -1.4683,  1.2095, -0.2875],
           [ 1.6076, -0.0882, -1.1156, -0.4038]]]], device='cuda:0',
        dtype=torch.float64, grad_fn=<TransposeBackward0>),
 torch.Size([1, 1, 8, 4]))

In [31]:
# Forward check: difference between the two custom implementations after layernorm.
hs_scaled - hs_scaled2

tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]], device='cuda:0', dtype=torch.float64,
       grad_fn=<SubBackward0>)

In [32]:
# Same scalar loss again, backpropagated through the second custom
# implementation to fill the .grad fields of the *_obw2 leaf tensors.
torch.sum((hs_scaled2 + offset) ** 2).backward()

In [33]:
# Gradient differences between the two custom backward implementations for q, k and v.
(
    qs_obw.grad - qs_obw2.grad,
    ks_obw.grad - ks_obw2.grad,
    vs_obw.grad - vs_obw2.grad,
)

(tensor([[[[-7.9426e-10,  5.2465e-10,  3.4780e-10, -6.2745e-10],
           [-1.6039e-06,  1.9560e-06, -1.3909e-06, -2.2797e-06],
           [ 2.9670e-07, -6.4466e-08, -4.9334e-07, -4.8401e-08],
           [ 5.4072e-06, -1.3016e-05, -1.2619e-05, -1.3273e-05],
           [ 8.9451e-07, -2.9594e-07,  7.7347e-07, -8.9234e-07],
           [-2.4144e-05, -7.4488e-06, -2.1040e-05, -1.0781e-05],
           [-7.8605e-06,  1.7892e-06,  7.8424e-06, -6.3274e-08],
           [ 8.3637e-06, -4.2214e-07, -2.3480e-05, -3.9363e-06]]]],
        device='cuda:0', dtype=torch.float64),
 tensor([[[[-9.0393e-06,  2.8399e-07,  8.6789e-07,  3.3980e-06],
           [-5.2885e-06, -2.5149e-06,  1.5947e-06, -7.6160e-06],
           [ 1.7443e-06,  2.3299e-06,  1.0836e-06,  1.3475e-07],
           [ 1.0506e-06,  1.5316e-06, -4.8691e-07,  4.3572e-06],
           [ 1.3810e-05, -4.2640e-06, -6.2281e-06,  2.0252e-06],
           [ 2.2273e-06, -9.4424e-06, -9.5155e-06,  1.1184e-05],
           [-1.6967e-05, -1.0631e-05, -1

In [35]:
# Gradient differences between the two custom backward implementations for the
# forget-gate and input-gate pre-activations.
(
    fgs_obw.grad - fgs_obw2.grad,
    igs_obw.grad - igs_obw2.grad,
)

(tensor([[[[ 0.0000e+00],
           [ 1.0675e-06],
           [ 2.2536e-06],
           [ 1.8212e-06],
           [-1.5752e-06],
           [-1.4597e-06],
           [ 3.0853e-06],
           [-1.3558e-07]]]], device='cuda:0', dtype=torch.float64),
 tensor([[[[ 2.9840e-06],
           [ 2.0232e-06],
           [-1.1594e-06],
           [-8.5998e-06],
           [ 4.1478e-07],
           [ 1.3358e-05],
           [-9.3505e-06],
           [ 3.3107e-07]]]], device='cuda:0', dtype=torch.float64))

In [34]:
# Query gradients side by side: first custom backward, then the second one.
qs_obw.grad, qs_obw2.grad

(tensor([[[[-7.6576e-04,  5.0582e-04,  3.3532e-04, -6.0494e-04],
           [-7.7297e-01,  9.4263e-01, -6.7030e-01, -1.0986e+00],
           [ 4.6686e-01, -1.0144e-01, -7.7628e-01, -7.6160e-02],
           [ 3.4602e+00, -8.3292e+00, -8.0756e+00, -8.4937e+00],
           [ 1.3444e+00, -4.4479e-01,  1.1625e+00, -1.3412e+00],
           [-1.9742e+01, -6.0909e+00, -1.7204e+01, -8.8154e+00],
           [-7.0999e+00,  1.6160e+00,  7.0835e+00, -5.7152e-02],
           [ 5.5990e+00, -2.8259e-01, -1.5718e+01, -2.6351e+00]]]],
        device='cuda:0', dtype=torch.float64),
 tensor([[[[-7.6576e-04,  5.0582e-04,  3.3532e-04, -6.0494e-04],
           [-7.7297e-01,  9.4263e-01, -6.7030e-01, -1.0986e+00],
           [ 4.6686e-01, -1.0144e-01, -7.7628e-01, -7.6160e-02],
           [ 3.4602e+00, -8.3292e+00, -8.0756e+00, -8.4937e+00],
           [ 1.3444e+00, -4.4479e-01,  1.1625e+00, -1.3412e+00],
           [-1.9742e+01, -6.0909e+00, -1.7204e+01, -8.8154e+00],
           [-7.0999e+00,  1.6160e+00,  7

In [None]:
# Conclusion: the implementations match; the maximum error of about 1e-5 is attributed to numerical precision