In [None]:
import sys
import os

# Restrict this process to GPU 0. CUDA reads this variable when it initializes,
# so it is set before `import torch` below.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Make the repo-local `src` package importable. Guarded so that re-running this
# cell does not keep appending duplicate "." entries to sys.path.
if "." not in sys.path:
    sys.path.append(".")

# Optional: cap parallel compile jobs (MAX_JOBS is read by
# torch.utils.cpp_extension), presumably relevant when the kernels are JIT-built.
# os.environ["MAX_JOBS"] = "100"

import torch

# Sanity check: displays True when the selected GPU is visible to torch.
torch.cuda.is_available()

In [None]:
from src.vlstm_v3.interface import qkvkernel

### QKV kernel test: `qkvkernel` vs. PyTorch `Q @ K^T @ V`

In [None]:
# Problem size for the QKV kernel test.
S, B, NH, DH = 8, 1, 1, 4  # seq len, batch size, num heads, dim per head
DTYPE = torch.float16
DEVICE = torch.device("cuda", 0)

In [None]:
# Build deterministic test inputs: Q holds 0.0, 0.1, 0.2, ... laid out as
# (B, NH, S, DH); K and V are all ones. No gates are created in this cell.
torch.manual_seed(0)  # harmless here: nothing in this cell is random
qs = (torch.arange(B * NH * S * DH, device=DEVICE, dtype=DTYPE) / 10.).view(B, NH, S, DH)
ks = torch.ones(B, NH, S, DH, device=DEVICE, dtype=DTYPE)
vs = torch.ones_like(ks)
qs, qs.shape

In [None]:
# Intermediate check: Q @ K^T, shape (B, NH, S, S). `rs` is reassigned by the
# following cells.
rs = torch.matmul(qs, ks.transpose(-2, -1))
rs, rs.shape

In [None]:
# PyTorch reference: (Q @ K^T) @ V, evaluated left to right like the `@` chain.
rs = torch.matmul(torch.matmul(qs, ks.transpose(-2, -1)), vs)
rs, rs.shape

In [None]:
# CUDA kernel under test. The previous cell bound the PyTorch result to the same
# name `rs`, so recompute the reference here and report the largest elementwise
# deviation (in float32) rather than comparing two printouts by eye.
rs = qkvkernel(mat_Q=qs, mat_K=ks, mat_V=vs)
rs_ref = qs @ ks.transpose(-1, -2) @ vs
qkv_max_abs_err = (rs.float() - rs_ref.float()).abs().max()
rs, rs.shape, qkv_max_abs_err

In [None]:
# Re-display the query tensor `qs` built above (0.0, 0.1, ... as (B, NH, S, DH)).
qs

### Matmul kernel test (from vlstm_v2)

In [None]:
# NOTE: `mmkernelv1` (called further below) is only defined if this import is enabled.
# from src.vlstm_v3.interface import testkernel, copykernel, mmkernelv1

In [None]:
# Switch to bfloat16 for the matmul kernel test (overrides the float16 set earlier).
DEVICE = torch.device("cuda", 0)
DTYPE = torch.bfloat16

In [None]:
S, DH = 8, 8  # sequence length, hidden size

In [None]:
# A: (2*S, DH) filled row-major with 0, 1, 2, ...; B: (DH, S) of ones, so every
# entry in row i of A @ B equals the sum of row i of A.
matA = torch.arange(2 * S * DH, device=DEVICE, dtype=DTYPE).view(2 * S, DH)
matB = torch.ones(DH, S, device=DEVICE, dtype=DTYPE)
matA.shape, matB.shape, matA, matA.sum(dim=-1)

In [None]:
# PyTorch reference result for the matmul kernel.
pt_out = torch.matmul(matA, matB)
pt_out, pt_out.shape

In [None]:
# Contiguity of both inputs; the kernel presumably expects contiguous memory -- TODO confirm.
tuple(mat.is_contiguous() for mat in (matA, matB))

In [None]:
# CUDA matmul kernel under test. `mmkernelv1` is only bound when its
# (commented-out) import at the start of this section is enabled; skip with a
# message instead of raising NameError on Restart & Run All.
if "mmkernelv1" in globals():
    from IPython.display import display

    cu_out = mmkernelv1(mat_A=matA, mat_B=matB)
    # Largest elementwise deviation from the PyTorch reference `pt_out`; cast to
    # float32 so the subtraction itself is not rounded to bfloat16.
    mm_max_abs_err = (cu_out.float() - pt_out.float()).abs().max()
    display((cu_out, cu_out.shape, mm_max_abs_err))
else:
    print("mmkernelv1 is not imported; skipping the matmul kernel check.")

In [None]:
# Row 9 of matA; with DH = 8 it holds 72..79, the values re-summed in the cells below.
matA[9, :]

In [None]:
# Running sum of that row in the tensor's dtype; the last entry is the value
# expected in row 9 of matA @ matB (matB is all ones).
torch.cumsum(matA[9], dim=-1)

In [None]:
# Same running sum on a fresh bfloat16 CUDA tensor. bfloat16 has an 8-bit
# significand, so not every integer above 256 is exactly representable.
torch.cumsum(torch.arange(72, 80, dtype=torch.bfloat16, device="cuda:0"), dim=-1)

In [None]:
# bfloat16 running sum on CPU, to compare against the CUDA result above.
torch.cumsum(torch.arange(72, 80, dtype=torch.bfloat16, device="cpu"), dim=-1)

In [None]:
# float32 running sum on CPU: exact reference values (72, 145, ..., 604).
torch.cumsum(torch.arange(72, 80, dtype=torch.float32, device="cpu"), dim=-1)

In [None]:
# float16 running sum on CUDA; each partial sum (at most 604) is exactly
# representable in float16, unlike in bfloat16.
torch.cumsum(torch.arange(72, 80, dtype=torch.float16, device="cuda:0"), dim=-1)

In [None]:
# Disabled: `mmkernelv2` is not in the (commented-out) import list of this
# section, which only names testkernel, copykernel and mmkernelv1.
# cu_out = mmkernelv2(mat_A=matA, mat_B=matB)
# cu_out, cu_out.shape

In [None]:
# mat @ mat.T @ mat

### PyTorch version

Max-stabilized D-matrix computation on random inputs.

In [None]:
# Configuration for the PyTorch version below.
H = 6  # hidden size
S = 5  # seq len
B = 1  # batch size
NH = 2  # num heads
# Validate before deriving the per-head dimension so a bad config fails with a clear message.
assert H % NH == 0, f"hidden size H={H} must be divisible by num heads NH={NH}"
DH = H // NH  # dim per head
DTYPE = torch.bfloat16
DEVICE = torch.device("cuda:0")

In [None]:
# Create random q, k, v of shape (B, NH, S, DH) and a random matrix `ds` of
# shape (B, NH, S, S) that is treated below as the log D matrix.
# NOTE(review): no input/forget gates are built in the lines shown here; `ds`
# is uniform noise from torch.rand standing in for log D.
torch.manual_seed(0)
qs = torch.randn((B, NH, S, DH), device=DEVICE, dtype=DTYPE)
ks = torch.randn((B, NH, S, DH), device=DEVICE, dtype=DTYPE)
vs = torch.randn((B, NH, S, DH), device=DEVICE, dtype=DTYPE)
ds = torch.rand((B, NH, S, S), device=DEVICE, dtype=DTYPE)

# Per-(batch, head) maximum over all S*S entries of log D; it is subtracted
# before exponentiating so the largest exponent becomes 0.
max_log_D, _ = torch.max(ds.view(B, NH, -1), dim=-1, keepdim=True)  # (B, NH, 1)
log_D_matrix_stabilized = ds - max_log_D.unsqueeze(-1)  # (B, NH, S, S) = (B, NH, S, S) - (B, NH, 1, 1)
D_matrix = torch.exp(log_D_matrix_stabilized)  # (B, NH, S, S)
# exp(-max_log_D), shape (B, NH, 1, 1). Since exp(ds - m) = exp(ds) * exp(-m),
# D_matrix equals exp(ds) * mval up to rounding.
mval = torch.exp(-max_log_D.unsqueeze(-1))