In [1]:
import torch
torch.set_printoptions(linewidth=200, threshold=100000)


In [2]:
from ln import MultiHeadLayerNorm

from mlstm_parallel import mlstm_torch_autograd, mlstm_torch_ownbw, mlstm_triton

In [3]:
# params
S = 32 # seq len
B = 1 # batch size
NH = 1 # num heads
DH = 64 # dim per head
DTYPE = torch.float16
DEVICE = torch.device("cuda:0")
EPS = 0.0

In [4]:
# create qkv, inputgates, forgetgates 
torch.manual_seed(1)
qs = torch.randn((B, NH, S, DH), device=DEVICE, dtype=DTYPE)
ks = torch.randn((B, NH, S, DH), device=DEVICE, dtype=DTYPE)
vs = torch.randn((B, NH, S, DH), device=DEVICE, dtype=DTYPE)
igs = torch.rand((B, NH, S), device=DEVICE, dtype=DTYPE)
igs2 = (1. + torch.arange((B * NH * S), device=DEVICE, dtype=DTYPE)).reshape(B, NH, S)
fgs = torch.rand((B, NH, S), device=DEVICE, dtype=DTYPE)
qs.shape, fgs.shape

(torch.Size([1, 1, 8, 4]), torch.Size([1, 1, 8]))

In [5]:
offset = 3.* torch.randn((B, NH, S, DH), device=DEVICE, dtype=DTYPE)

In [6]:
mh_layernorm = MultiHeadLayerNorm(NH*DH, eps=1e-6).to(device=DEVICE, dtype=DTYPE)
mh_layernorm.weight, mh_layernorm.bias

(Parameter containing:
 tensor([0., 0., 0., 0.], device='cuda:0', dtype=torch.float64, requires_grad=True),
 None)

## pytorch

In [7]:
fgs_pt = fgs.clone().detach().requires_grad_(True)
igs_pt = igs.clone().detach().requires_grad_(True)
qs_pt = qs.clone().detach().requires_grad_(True)
ks_pt = ks.clone().detach().requires_grad_(True)
vs_pt = vs.clone().detach().requires_grad_(True)

In [8]:
qs_pt.grad

In [9]:
hs_pt, _, _ = mlstm_torch_autograd(qs_pt, ks_pt, vs_pt, igs_pt, fgs_pt)
hs_pt_scaled = mh_layernorm(hs_pt)
hs_pt_scaled, hs_pt_scaled.shape

(tensor([[[[-0.2841,  1.3220,  0.3847, -1.4225],
           [ 0.7400,  1.0432, -0.2694, -1.5138],
           [ 0.4508,  1.3221, -1.3749, -0.3980],
           [ 1.3565, -0.7354, -1.1550,  0.5340],
           [ 1.5541, -0.6481, -1.0672,  0.1612],
           [ 1.1217, -1.5834,  0.4839, -0.0223],
           [ 0.5463, -1.4683,  1.2095, -0.2875],
           [ 1.6076, -0.0882, -1.1156, -0.4038]]]], device='cuda:0', dtype=torch.float64, grad_fn=<TransposeBackward0>),
 torch.Size([1, 1, 8, 4]))

In [10]:
((hs_pt_scaled+offset)**2).sum().backward()

## own backward torch

In [11]:
fgs_obw = fgs.clone().detach().requires_grad_(True)
igs_obw = igs.clone().detach().requires_grad_(True)
qs_obw = qs.clone().detach().requires_grad_(True)
ks_obw = ks.clone().detach().requires_grad_(True)
vs_obw = vs.clone().detach().requires_grad_(True)

In [12]:
hs_obw = mlstm_torch_ownbw(qs_obw, ks_obw, vs_obw, igs_obw, fgs_obw)
hs_obw_scaled = mh_layernorm(hs_obw)
hs_obw_scaled, hs_obw_scaled.shape

(tensor([[[[-0.2841,  1.3220,  0.3847, -1.4225],
           [ 0.7400,  1.0432, -0.2694, -1.5138],
           [ 0.4508,  1.3221, -1.3749, -0.3980],
           [ 1.3565, -0.7354, -1.1550,  0.5340],
           [ 1.5541, -0.6481, -1.0672,  0.1612],
           [ 1.1217, -1.5834,  0.4839, -0.0223],
           [ 0.5463, -1.4683,  1.2095, -0.2875],
           [ 1.6076, -0.0882, -1.1156, -0.4038]]]], device='cuda:0', dtype=torch.float64, grad_fn=<TransposeBackward0>),
 torch.Size([1, 1, 8, 4]))

In [13]:
((hs_obw_scaled+offset)**2).sum().backward()

In [14]:
hs_pt_scaled - hs_obw_scaled

tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]], device='cuda:0', dtype=torch.float64, grad_fn=<SubBackward0>)

In [15]:
qs_pt.grad - qs_obw.grad

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [-1.5314e-06,  1.1898e-06, -4.0425e-06, -1.1652e-05],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 6.4882e-06, -1.3727e-05,  6.1535e-06, -1.6513e-05],
          [ 3.1307e-05, -1.2041e-05,  1.1293e-05,  1.6676e-06],
          [ 2.7022e-06, -1.6586e-06,  1.1334e-05, -1.3535e-06],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]]], device='cuda:0', dtype=torch.float64)

In [16]:
ks_pt.grad - ks_obw.grad

tensor([[[[ 1.6914e-06,  8.2518e-07,  1.4939e-06,  1.6078e-06],
          [ 5.2692e-06,  2.5707e-06,  4.6539e-06,  5.0088e-06],
          [ 6.8855e-06,  3.3593e-06,  6.0815e-06,  6.5453e-06],
          [ 5.3227e-07,  1.3101e-05,  4.5358e-06,  8.8716e-07],
          [ 8.7948e-07,  2.1647e-05,  7.4946e-06,  1.4659e-06],
          [-1.2351e-05,  9.8664e-06,  5.1246e-06, -8.8751e-06],
          [-6.2944e-06,  5.1208e-06, -7.5627e-06, -1.0601e-05],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]]], device='cuda:0', dtype=torch.float64)

In [17]:
vs_pt.grad - vs_obw.grad

tensor([[[[ 0.0000e+00,  0.0000e+00,  1.3878e-16,  0.0000e+00],
          [-4.4409e-16,  3.3307e-16,  0.0000e+00,  0.0000e+00],
          [ 4.4409e-16,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [-8.8818e-16,  2.2204e-16,  0.0000e+00, -8.8818e-16],
          [ 0.0000e+00,  0.0000e+00,  4.4409e-16, -4.4409e-16],
          [ 0.0000e+00,  0.0000e+00, -1.7764e-15,  1.7764e-15],
          [ 0.0000e+00,  0.0000e+00, -5.5511e-17,  0.0000e+00],
          [-1.1102e-16,  0.0000e+00,  0.0000e+00, -4.4409e-16]]]], device='cuda:0', dtype=torch.float64)

In [18]:
igs_pt.grad - igs_obw.grad

tensor([[[-7.6044e-07,  2.4240e-06, -2.0677e-05, -3.6671e-05,  8.3042e-06, -3.4175e-05,  2.4096e-06,  0.0000e+00]]], device='cuda:0', dtype=torch.float64)

In [19]:
fgs_pt.grad - fgs_obw.grad

tensor([[[-1.5099e-16, -2.7182e-07,  7.4923e-07, -5.8534e-07, -1.2563e-05, -7.5150e-07, -5.3494e-06, -8.0863e-11]]], device='cuda:0', dtype=torch.float64)

## triton

In [None]:
fgs_tr = fgs.clone().detach().requires_grad_(True)
igs_tr = igs.clone().detach().requires_grad_(True)
qs_tr = qs.clone().detach().requires_grad_(True)
ks_tr = ks.clone().detach().requires_grad_(True)
vs_tr = vs.clone().detach().requires_grad_(True)