In [1]:
import torch

## PyTorch vLSTM forward with group norm (headwise layernorm)

Shows what happens if we fuse the multi-head layernorm with the vLSTM kernel.

In [2]:
%load_ext autoreload
%autoreload 2
# from vlstm_parallel_tiled import vlstm_parallel_tiled
from vlstm_parallel import vlstm_parallel_fw_torch
from vlstm_parallel_w_groupnorm import vlstm_parallel_fw_torch_w_groupnorm, vlstm_parallel_fwbw_torch_w_groupnorm
from ln import MultiHeadLayerNorm

In [3]:
# Problem-size and numerics configuration shared by every cell below.
S = 8  # sequence length
B = 1  # batch size
NH = 1  # number of heads
DH = 4  # dimension per head
DTYPE = torch.float64
# Prefer the first CUDA device, but fall back to CPU so the notebook still
# runs (Restart & Run All) on machines without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
EPS = 0.0  # forwarded as `eps` to the vLSTM kernels

In [4]:
# Build the seeded test inputs: queries/keys/values plus input- and forget-gate pre-activations.
torch.manual_seed(0)
qkv_shape = (B, NH, S, DH)
gate_shape = (B, NH, S, 1)
tensor_kwargs = dict(device=DEVICE, dtype=DTYPE)
# The draw order (qs, ks, vs, igs, fgs) is kept fixed so the seeded values stay reproducible.
qs = torch.randn(qkv_shape, **tensor_kwargs)
ks = torch.randn(qkv_shape, **tensor_kwargs)
vs = torch.randn(qkv_shape, **tensor_kwargs)
igs = torch.rand(gate_shape, **tensor_kwargs)
# Deterministic ramp 1, 2, ..., B*NH*S as an alternative input-gate tensor
# (not referenced by any other cell visible in this notebook).
igs2 = (torch.arange(B * NH * S, **tensor_kwargs) + 1.0).reshape(gate_shape)
fgs = torch.rand(gate_shape, **tensor_kwargs)
qs.shape, fgs.shape

(torch.Size([1, 1, 8, 4]), torch.Size([1, 1, 8, 1]))

In [5]:
# Random offset that is added to the normalized outputs before squaring in the scalar
# losses further down; the same tensor is used for the autograd and the own-backward path.
# NOTE(review): presumably needed so that the loss is not (nearly) constant after layernorm — confirm.
offset = torch.randn((B, NH, S, DH), device=DEVICE, dtype=DTYPE)

In [6]:
# Headwise layernorm applied to the kernel outputs in both comparison paths below.
# NOTE(review): eps=1e-6 here differs from EPS (0.0) passed to the vLSTM kernels — confirm this is intended.
mh_layernorm = MultiHeadLayerNorm(NH*DH, eps=1e-6).to(device=DEVICE, dtype=DTYPE)
# Display the parameters (the recorded output shows an all-zero weight and bias=None).
# NOTE(review): the normalized outputs are non-zero despite the zero weight, so the layer
# presumably scales by (1 + weight) — confirm in ln.py.
mh_layernorm.weight, mh_layernorm.bias

(Parameter containing:
 tensor([0., 0., 0., 0.], device='cuda:0', dtype=torch.float64,
        requires_grad=True),
 None)

### pytorch

In [7]:
# Independent leaf copies of every input for the autograd reference path,
# so its gradients accumulate separately from the own-backward path.
fgs_pt, igs_pt, qs_pt, ks_pt, vs_pt = (
    src.clone().detach().requires_grad_(True) for src in (fgs, igs, qs, ks, vs)
)

In [8]:
# rs = vlstm_fw_torch(
#     queries=qs,
#     keys=ks,
#     values=vs,
#     igate_preact=igs,
#     fgate_preact=fgs,
#     stabilize_rowwise=True,
# )
# rs, rs.shape

In [9]:
# Sanity check before any backward pass: qs_pt is a fresh leaf tensor, so .grad is
# still None here (which is why the cell shows no output).
qs_pt.grad

In [10]:
# Reference path ("pytorch"): plain vLSTM forward on the *_pt leaf tensors,
# followed by the multi-head layernorm.
rs = vlstm_parallel_fw_torch(
    queries=qs_pt,
    keys=ks_pt,
    values=vs_pt,
    igate_preact=igs_pt,
    fgate_preact=fgs_pt,
    eps=EPS,
)
rs_scaled = mh_layernorm(rs)
# Show the normalized output and its shape.
rs_scaled, rs_scaled.shape

(tensor([[[[ 1.2508, -1.1810, -0.7554,  0.6856],
           [-1.4225, -0.2860,  1.3208,  0.3877],
           [ 1.1236,  0.4122, -1.6011,  0.0653],
           [ 0.9885, -0.1728, -1.5612,  0.7455],
           [ 0.8213, -1.0949, -0.8853,  1.1589],
           [-0.8941,  0.1563,  1.5739, -0.8361],
           [-1.4935,  0.9028,  0.9198, -0.3292],
           [-1.6608,  0.7769,  0.0900,  0.7939]]]], device='cuda:0',
        dtype=torch.float64, grad_fn=<TransposeBackward0>),
 torch.Size([1, 1, 8, 4]))

In [11]:
# Scalar loss for the autograd reference path: sum of squares of the offset-shifted,
# normalized output. Backpropagating it populates the .grad of the *_pt tensors.
loss_pt = (rs_scaled + offset).pow(2).sum()
loss_pt.backward()

In [12]:
rs.shape # (B, NH, S, DH)

torch.Size([1, 1, 8, 4])

In [13]:
# Query gradients of the autograd reference path (populated by the backward call above).
qs_pt.grad

tensor([[[[ 1.3274e-05,  1.0109e-05,  3.6511e-05, -1.6533e-05],
          [ 5.3569e-01,  2.4857e+00,  3.0122e+00, -1.4810e+00],
          [-6.8308e+00,  7.2907e+00,  2.5978e+00,  2.0262e-01],
          [ 8.0473e-02, -3.9897e-01,  8.5354e-02,  2.3112e-02],
          [ 2.6339e+00,  1.2841e+00,  7.9257e-01, -3.0563e+00],
          [ 3.3218e+00, -5.3795e+00,  3.3445e+00,  3.0653e+00],
          [ 1.2528e+00, -9.5476e-01,  2.2962e+00, -2.4906e-01],
          [ 1.1260e+00, -8.5640e-01,  1.0127e-01, -4.8459e-01]]]],
       device='cuda:0', dtype=torch.float64)

In [14]:
rs2 = rs  # no transpose needed: rs is already (B, NH, S, DH)
rs2.shape

torch.Size([1, 1, 8, 4])

In [15]:
# Manual normalization over the last (DH) axis: subtract the mean and divide by the
# biased standard deviation. No eps is added here, unlike mh_layernorm (eps=1e-6).
rs3 = (rs2 - rs2.mean(-1, keepdim=True)) / rs2.std(-1, keepdim=True, unbiased=False)
# Show the manually normalized tensor for comparison with rs_scaled.
rs3, rs3.shape

(tensor([[[[ 1.2508, -1.1810, -0.7554,  0.6856],
           [-1.4225, -0.2860,  1.3208,  0.3877],
           [ 1.1236,  0.4122, -1.6011,  0.0653],
           [ 0.9885, -0.1728, -1.5612,  0.7455],
           [ 0.8213, -1.0949, -0.8853,  1.1589],
           [-0.8941,  0.1563,  1.5739, -0.8361],
           [-1.4935,  0.9028,  0.9198, -0.3292],
           [-1.6608,  0.7769,  0.0900,  0.7939]]]], device='cuda:0',
        dtype=torch.float64, grad_fn=<DivBackward0>),
 torch.Size([1, 1, 8, 4]))

In [16]:
# Element-wise gap between the manual normalization and mh_layernorm
# (recorded magnitudes are ~1e-6 and below).
# NOTE(review): presumably caused by the eps=1e-6 inside mh_layernorm that the manual
# version omits — confirm.
rs3 - rs_scaled

tensor([[[[ 1.4372e-06, -1.3570e-06, -8.6796e-07,  7.8776e-07],
          [-1.8984e-06, -3.8173e-07,  1.7627e-06,  5.1739e-07],
          [ 1.8660e-06,  6.8450e-07, -2.6590e-06,  1.0852e-07],
          [ 2.5576e-08, -4.4698e-09, -4.0394e-08,  1.9289e-08],
          [ 1.5123e-06, -2.0160e-06, -1.6302e-06,  2.1339e-06],
          [-7.6744e-07,  1.3418e-07,  1.3509e-06, -7.1761e-07],
          [-1.6035e-06,  9.6934e-07,  9.8759e-07, -3.5345e-07],
          [-1.0968e-06,  5.1307e-07,  5.9439e-08,  5.2432e-07]]]],
       device='cuda:0', dtype=torch.float64, grad_fn=<SubBackward0>)

### own backward

In [17]:
# Independent leaf copies of every input for the own-backward path,
# so its gradients can be compared against the *_pt reference gradients.
fgs_obw, igs_obw, qs_obw, ks_obw, vs_obw = (
    src.clone().detach().requires_grad_(True) for src in (fgs, igs, qs, ks, vs)
)

In [18]:
# Path under test ("own backward"): forward through the group-norm variant on the
# *_obw leaf tensors. Besides the hidden states `hs`, the call returns `var_b` and
# `var_m`, which are inspected in a later cell.
hs, var_b, var_m = vlstm_parallel_fwbw_torch_w_groupnorm(
    queries=qs_obw,
    keys=ks_obw,
    values=vs_obw,
    igate_preact=igs_obw,
    fgate_preact=fgs_obw,
    eps=EPS,
)
# Apply the same multi-head layernorm as in the reference path. Only the final
# expression of a cell is displayed, so the normalized result is the one shown.
hs_scaled = mh_layernorm(hs)
hs_scaled, hs_scaled.shape

(tensor([[[[ 1.2508, -1.1810, -0.7554,  0.6856],
           [-1.4225, -0.2860,  1.3208,  0.3877],
           [ 1.1236,  0.4122, -1.6011,  0.0653],
           [ 0.9885, -0.1728, -1.5612,  0.7455],
           [ 0.8213, -1.0949, -0.8853,  1.1589],
           [-0.8941,  0.1563,  1.5739, -0.8361],
           [-1.4935,  0.9028,  0.9198, -0.3292],
           [-1.6608,  0.7769,  0.0900,  0.7939]]]], device='cuda:0',
        dtype=torch.float64, grad_fn=<TransposeBackward0>),
 torch.Size([1, 1, 8, 4]))

In [19]:
# Forward outputs after layernorm: own-backward path minus autograd reference
# (the recorded difference is exactly zero everywhere).
hs_scaled - rs_scaled

tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]], device='cuda:0', dtype=torch.float64,
       grad_fn=<SubBackward0>)

In [20]:
# Same scalar loss as for the reference path, now backpropagated through the
# own-backward implementation to populate the .grad of the *_obw tensors.
# NOTE(review): backward() returns None, so the tensor recorded under this cell is
# presumably printed from inside the custom backward — confirm.
loss_obw = (hs_scaled + offset).pow(2).sum()
loss_obw.backward()

tensor([[[[ 0.0000e+00],
          [ 0.0000e+00],
          [ 0.0000e+00],
          [ 0.0000e+00],
          [ 1.9281e-05],
          [ 0.0000e+00],
          [-5.1218e-06],
          [-2.0799e-06]]]], device='cuda:0', dtype=torch.float64)


In [21]:
# Side by side: own-backward query gradients (first) and autograd reference (second).
qs_obw.grad, qs_pt.grad

(tensor([[[[ 1.3274e-05,  1.0109e-05,  3.6511e-05, -1.6533e-05],
           [ 5.3569e-01,  2.4857e+00,  3.0122e+00, -1.4810e+00],
           [-6.8308e+00,  7.2907e+00,  2.5978e+00,  2.0262e-01],
           [ 8.0473e-02, -3.9897e-01,  8.5354e-02,  2.3112e-02],
           [ 2.6339e+00,  1.2841e+00,  7.9257e-01, -3.0563e+00],
           [ 3.3218e+00, -5.3795e+00,  3.3445e+00,  3.0653e+00],
           [ 1.2528e+00, -9.5476e-01,  2.2962e+00, -2.4906e-01],
           [ 1.1260e+00, -8.5641e-01,  1.0126e-01, -4.8459e-01]]]],
        device='cuda:0', dtype=torch.float64),
 tensor([[[[ 1.3274e-05,  1.0109e-05,  3.6511e-05, -1.6533e-05],
           [ 5.3569e-01,  2.4857e+00,  3.0122e+00, -1.4810e+00],
           [-6.8308e+00,  7.2907e+00,  2.5978e+00,  2.0262e-01],
           [ 8.0473e-02, -3.9897e-01,  8.5354e-02,  2.3112e-02],
           [ 2.6339e+00,  1.2841e+00,  7.9257e-01, -3.0563e+00],
           [ 3.3218e+00, -5.3795e+00,  3.3445e+00,  3.0653e+00],
           [ 1.2528e+00, -9.5476e-01,  2

In [22]:
# Compare |var_b| with exp(-var_m) row by row and show where |var_b| is the smaller one.
# In the recorded output the mask is False for rows 4, 6 and 7 — the same rows in which
# the query-gradient difference shown further down is non-zero.
var_b.abs(), torch.exp(-var_m), var_b.abs() < torch.exp(-var_m)

(tensor([[[[0.4916],
           [0.3684],
           [0.1097],
           [0.2194],
           [0.9885],
           [0.0628],
           [1.3290],
           [2.0126]]]], device='cuda:0', dtype=torch.float64,
        grad_fn=<AbsBackward0>),
 tensor([[[[0.6603],
           [0.5352],
           [0.3983],
           [0.6340],
           [0.4514],
           [0.6290],
           [0.3836],
           [0.7017]]]], device='cuda:0', dtype=torch.float64,
        grad_fn=<ExpBackward0>),
 tensor([[[[ True],
           [ True],
           [ True],
           [ True],
           [False],
           [ True],
           [False],
           [False]]]], device='cuda:0'))

In [23]:
# Query-gradient difference: autograd reference minus own backward.
# Recorded output: zero except in rows 4, 6 and 7, with magnitudes up to ~1.6e-5.
qs_pt.grad - qs_obw.grad

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 1.8937e-06, -1.0558e-05, -2.7356e-06,  1.5993e-05],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [-6.0601e-07, -9.2155e-07,  4.0914e-06, -5.0413e-07],
          [-1.2725e-06,  1.7485e-06,  5.4959e-06,  8.2998e-07]]]],
       device='cuda:0', dtype=torch.float64)

In [24]:
# Forget-gate gradient difference: autograd reference minus own backward
# (recorded magnitudes are at most ~3.2e-6).
fgs_pt.grad - fgs_obw.grad

tensor([[[[ 3.6995e-17],
          [-2.3993e-07],
          [ 9.2670e-07],
          [ 1.9190e-06],
          [-3.1642e-06],
          [-1.1329e-06],
          [-1.9609e-06],
          [-1.3309e-06]]]], device='cuda:0', dtype=torch.float64)

In [25]:
# Only the query gradients are checked here. The tolerances (atol=rtol=1e-5) are loose
# for float64 but wide enough to accept the ~1e-5 deviations displayed above.
torch.allclose(qs_pt.grad, qs_obw.grad, atol=1e-5, rtol=1e-5)

True

In [27]:
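## Conclusion:
# When dividing, we get the same gradients (allclose with atol=rtol=1e-5); the remaining
# error of about 1e-5 is attributed to numerical precision.
# NOTE(review): the deviation is confined to rows 4, 6 and 7, i.e. exactly the rows where
# `var_b.abs() < torch.exp(-var_m)` is False. For float64 this may point to a systematic
# cause (e.g. the layernorm eps=1e-6) rather than round-off — worth confirming.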

In [28]:
# tensor([[[[-3.7828e-17, -2.8809e-17, -1.0405e-16,  4.7115e-17],
#           [-1.1102e-16, -4.4409e-16, -8.8818e-16,  4.4409e-16],
#           [ 8.8818e-16, -8.8818e-16,  0.0000e+00, -8.3267e-17],
#           [ 1.2490e-16, -5.5511e-17,  1.3878e-17, -1.7347e-17],
#           [ 1.8937e-06, -1.0558e-05, -2.7356e-06,  1.5993e-05],
#           [ 4.4409e-16,  0.0000e+00,  1.3323e-15,  0.0000e+00],
#           [-6.0601e-07, -9.2155e-07,  4.0914e-06, -5.0413e-07],
#           [-1.2725e-06,  1.7485e-06,  5.4959e-06,  8.2998e-07]]]],
#        device='cuda:0', dtype=torch.float64)