In [118]:
import sys
import os

# Expose only GPU 0 to this process. This is set before torch is imported so
# the restriction is already in place when CUDA gets initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Make the current working directory importable.
# NOTE(review): presumably needed so the `src` package imported in the next
# cell resolves when the notebook is started from the repository root - confirm.
sys.path.append(".")
# Disabled override of MAX_JOBS.
# NOTE(review): presumably meant to control the number of parallel compile
# jobs when the CUDA extension is built - confirm before re-enabling.
# os.environ["MAX_JOBS"] = "100"

import torch
# Print tensors with wide rows (200 characters) and only summarize tensors
# that have more than 100000 elements, so the small debug tensors used in this
# notebook are shown in full.
torch.set_printoptions(linewidth=200, threshold=100000)
# Last expression of the cell: its value (whether CUDA is usable) is displayed
# as the cell output.
torch.cuda.is_available()

True

In [119]:
from src.vlstm_fwbw_v1.interface import vlstm_fwbw_torch_obw, vlstm_fwbw_cuda
from src.vlstm_fwbw_v1.interface import vlstm_fw_torch, vlstm_fw_cuda
from src.vlstm_fwbw_v1.interface import vlstm_bw_torch_obw, vlstm_bw_cuda

## CUDA vLSTM forward and backward pass

In [120]:
# Problem dimensions used by every tensor in this notebook:
#   B  - batch size
#   NH - number of heads
#   S  - sequence length (values tried earlier: 32, 16, 8)
#   DH - dimension per head
B, NH, S, DH = 1, 1, 64, 8

# Element type and device of all input tensors.
DTYPE = torch.float32
DEVICE = torch.device("cuda:0")

In [121]:
# Build the inputs: queries/keys/values, input- and forget-gate
# pre-activations, and the incoming gradient dHs.
# The generator is seeded, and the tensors are drawn in a fixed order
# (qs, ks, vs, igs, fgs, dHs) so that re-running the cell reproduces them.
torch.manual_seed(0)

# Deterministic alternatives for q/k/v (disabled):
# qs = torch.arange((B*NH*S*DH), device=DEVICE, dtype=DTYPE).reshape((B, NH, S, DH)) / 10.
# ks = torch.ones((B, NH, S, DH), device=DEVICE, dtype=DTYPE) / 100.
# vs = torch.ones((B, NH, S, DH), device=DEVICE, dtype=DTYPE) / 100.

# Random q/k/v of shape (B, NH, S, DH).
qs = torch.randn(B, NH, S, DH, device=DEVICE, dtype=DTYPE)
ks = torch.randn(B, NH, S, DH, device=DEVICE, dtype=DTYPE)
vs = torch.randn(B, NH, S, DH, device=DEVICE, dtype=DTYPE)

# Deterministic alternatives for the gate pre-activations (disabled):
# igs = (1. + torch.arange((B * NH * S), device=DEVICE, dtype=DTYPE)).reshape(B, NH, S, 1) / 10.
# igs = torch.zeros((B, NH, S, 1), device=DEVICE, dtype=DTYPE) #/ 10.
# fgs = torch.ones((B, NH, S, 1), device=DEVICE, dtype=DTYPE)

# Random gate pre-activations of shape (B, NH, S, 1).
igs = torch.randn(B, NH, S, 1, device=DEVICE, dtype=DTYPE)
fgs = torch.randn(B, NH, S, 1, device=DEVICE, dtype=DTYPE)

# Random incoming gradient, passed as `delta_Htilde` to the backward functions.
dHs = torch.randn(B, NH, S, DH, device=DEVICE, dtype=DTYPE)

### Direct match: CUDA kernels vs. PyTorch implementation

In [122]:
# Forward pass of the PyTorch implementation. The fourth return value is
# discarded; hs/n/m/matD are compared against the CUDA results further down.
(hs_pt, n_pt, m_pt, _, matD_pt) = vlstm_fw_torch(
    queries=qs,
    keys=ks,
    values=vs,
    igate_preact=igs,
    fgate_preact=fgs,
)

In [123]:
# Backward pass of the PyTorch implementation, using the n/m values returned
# by the PyTorch forward pass. Besides the five input gradients it returns
# several intermediate quantities.
(
    dQs_pt,
    dKs_pt,
    dVs_pt,
    dIgs_pt,
    dFgs_pt,
    delta_D_pt,
    delta_Dtilde_pt,
    delta_fbar_pt,
    mat_P_pt,
    mat_R_pt,
) = vlstm_bw_torch_obw(
    delta_Htilde=dHs,
    queries=qs, keys=ks, values=vs,
    igate_preact=igs, fgate_preact=fgs,
    var_n=n_pt, var_m=m_pt,
)

In [124]:
# CUDA forward kernel. The gate pre-activations are handed over without their
# trailing singleton dimension, i.e. as (B, NH, S) instead of (B, NH, S, 1).
hs_cu, n_cu, m_cu, matD_cu = vlstm_fw_cuda(
    mat_Q=qs,
    mat_K=ks,
    mat_V=vs,
    igate_preact=igs.squeeze(dim=-1),
    fgate_preact=fgs.squeeze(dim=-1),
)

before kernel dispatch - float32!
B: 1, NH: 1, S: 64, DH: 8
blocksxy: 1-2, threadsxy: 4-4, shared_mem in bytes: 5664
In FW-Kernel: gdim.x: 1, gdim.y: 2, gdim.z: 1, bdim.x: 4, bdim.y: 4
In FW-Kernel: QtileDim: 8, KVtileDim: 8, TblockDim:4


In [125]:
# CUDA backward kernel, fed with the n/m values produced by the CUDA forward
# call.
# NOTE(review): unlike the vlstm_fw_cuda call, the gate pre-activations are
# passed here with their trailing singleton dimension, i.e. as (B, NH, S, 1).
# Confirm that this is the shape vlstm_bw_cuda expects.
(
    dQs_cu,
    dKs_cu,
    dVs_cu,
    dIgs_cu,
    dFgs_cu,
    matC_cu,
    deltaDcsChunkArr_cu,
    deltaDcsVec_cu,
) = vlstm_bw_cuda(
    delta_Htilde=dHs,
    mat_Q=qs, mat_K=ks, mat_V=vs,
    igate_preact=igs, fgate_preact=fgs,
    n=n_cu, m=m_cu,
)

before kernel dispatch - float32!
B: 1, NH: 1, S: 64, DH: 8
blocksxy: 1-1, threadsxy: 4-4, shared_mem in bytes: 7648
In BW-Kernel: gdim.x: 1, gdim.y: 1, gdim.z: 1, bdim.x: 4, bdim.y: 4
In BW-Kernel: QtileDim: 8, KVtileDim: 8, TblockDim:4


In [126]:
# Tolerances handed to torch.allclose, which checks
#   |actual - reference| <= atol + rtol * |reference|.
# With rtol = 1e-10 the forward check is effectively a purely absolute one.
FW_RTOL = 1e-10
FW_ATOL = 1e-4
# NOTE(review): the backward *relative* tolerance reuses FW_ATOL (1e-4), not
# FW_RTOL - confirm that this is intentional and not a typo.
BW_RTOL = FW_ATOL
BW_ATOL = FW_ATOL


def report_match(label, actual, reference, rtol, atol):
    """Print ``"<label> match: <bool>"`` and return the boolean.

    Args:
        label: Name shown in the printed line.
        actual: Tensor under test (here: a CUDA kernel result).
        reference: Tensor to compare against. It is passed as ``other`` to
            ``torch.allclose``, so the relative tolerance is scaled by the
            reference values, not by the values under test.
        rtol: Relative tolerance.
        atol: Absolute tolerance.

    Returns:
        bool: ``torch.allclose(actual, reference, rtol=rtol, atol=atol)``.
    """
    is_close = torch.allclose(actual, reference, rtol=rtol, atol=atol)
    print(f"{label} match: {is_close}")
    return is_close


# Forward outputs. matD is compared on its lower triangle only: the difference
# is masked with tril() and checked against zeros.
# NOTE(review): presumably the entries above the diagonal are not meaningful
# for the comparison - confirm.
forward_checks = [
    ("fw hs", hs_cu, hs_pt),
    ("fw n", n_cu, n_pt),
    ("fw m", m_cu, m_pt),
    ("fw D", (matD_cu - matD_pt).tril(), torch.zeros_like(matD_cu)),
]
for label, actual, reference in forward_checks:
    report_match(label, actual, reference, rtol=FW_RTOL, atol=FW_ATOL)

# Backward outputs. The CUDA result is always `actual` and the PyTorch result
# always `reference`; this includes "mat R", where the CUDA tensor matC_cu is
# compared against the PyTorch tensor mat_R_pt.
backward_checks = [
    ("delta Q", dQs_cu, dQs_pt),
    ("delta K", dKs_cu, dKs_pt),
    ("delta V", dVs_cu, dVs_pt),
    ("delta Igate", dIgs_cu, dIgs_pt),
    ("delta Fgate", dFgs_cu, dFgs_pt),
    ("mat R", matC_cu, mat_R_pt),
]
for label, actual, reference in backward_checks:
    report_match(label, actual, reference, rtol=BW_RTOL, atol=BW_ATOL)


fw hs match: True
fw n match: True
fw m match: True
fw D match: True
delta Q match: True
delta K match: False
delta V match: False
delta Igate match: True
delta Fgate match: True
mat R match: True


In [127]:
# Largest absolute difference between the CUDA and the PyTorch forward output.
(hs_cu - hs_pt).abs().max()

tensor(1.0967e-05, device='cuda:0')

In [128]:
# Largest absolute difference of matD, restricted to the lower triangle.
torch.abs(torch.tril(matD_cu - matD_pt)).max()

tensor(1.1444e-05, device='cuda:0')

In [129]:
# Largest absolute difference per gradient, in the order
# (dQ, dK, dV, dIgate, dFgate).
tuple(
    (grad_cu - grad_pt).abs().max()
    for grad_cu, grad_pt in (
        (dQs_cu, dQs_pt),
        (dKs_cu, dKs_pt),
        (dVs_cu, dVs_pt),
        (dIgs_cu, dIgs_pt),
        (dFgs_cu, dFgs_pt),
    )
)

(tensor(5.5790e-05, device='cuda:0'),
 tensor(1.8722, device='cuda:0'),
 tensor(inf, device='cuda:0'),
 tensor(6.5804e-05, device='cuda:0'),
 tensor(3.8147e-05, device='cuda:0'))

In [130]:
# Largest absolute difference per forward output, in the order (hs, n, m).
tuple(
    (out_cu - out_pt).abs().max()
    for out_cu, out_pt in (
        (hs_cu, hs_pt),
        (n_cu, n_pt),
        (m_cu, m_pt),
    )
)

(tensor(1.0967e-05, device='cuda:0'),
 tensor(7.1526e-06, device='cuda:0'),
 tensor(4.7684e-06, device='cuda:0'))