In [1]:
%load_ext autoreload
%autoreload 2
import torch
from wkv_kernel import WKV, WKVConfig
from einops import rearrange, reduce, repeat


  from .autonotebook import tqdm as notebook_tqdm


## Translate the recurrence formula from CUDA to torch

In [2]:
# Setup cuda kernel
# Judging by the log printed under this cell, constructing WKV appears to
# build (when needed) and load the `wkv` torch extension.
wkv_cuda_kernel = WKV()

Using /system/user/beck/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /system/user/beck/.cache/torch_extensions/py310_cu117/wkv/build.ninja...
Building extension module wkv...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.


Loading extension module wkv...


### Inputs

In [3]:
# dtype and device used when the input tensors are created below.
# NOTE(review): the kernel under test is a compiled CUDA extension, so a CUDA
# device is presumably mandatory here -- confirm before switching to CPU.
dtype = torch.float32
device = torch.device('cuda')

In [4]:
# original input shapes from training
batch_size = 12
seq_len = 512
embedding_dim = 512

# Seed the generator so that the random inputs (and therefore every
# comparison further down) are identical after Restart Kernel -> Run All.
# torch.manual_seed also seeds the CUDA generators.
torch.manual_seed(0)

# Tensors returned by the factory functions are already contiguous, so no
# explicit memory-format conversion is needed before handing them to the kernel.
# Per-channel parameters, shape (embedding_dim,)
time_decay = torch.randn((embedding_dim, ), dtype=dtype, device=device)
time_first = torch.randn((embedding_dim, ), dtype=dtype, device=device)
# Keys and values, shape (batch_size, seq_len, embedding_dim)
k = torch.randn((batch_size, seq_len, embedding_dim),
                dtype=dtype,
                device=device)
v = torch.randn((batch_size, seq_len, embedding_dim),
                dtype=dtype,
                device=device)
# Uninitialised output buffer that is handed to the CUDA kernel below
y_gt = torch.empty(batch_size, seq_len, embedding_dim, dtype=dtype,
                   device=device)

In [5]:
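# Small deterministic mock-up inputs (arange-based) for debugging by hand;
# uncomment this cell to use them instead of the random inputs above.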
# batch_size = 2
# seq_len = 4
# embedding_dim = 2
# entries = batch_size * seq_len * embedding_dim
# time_decay = torch.arange(embedding_dim, dtype=dtype,
#                           device=device).reshape(embedding_dim).to(
#                               memory_format=torch.contiguous_format)
# time_first = torch.arange(embedding_dim, dtype=dtype,
#                           device=device).reshape(embedding_dim).to(
#                               memory_format=torch.contiguous_format)
# k = torch.arange(entries, dtype=dtype, device=device).reshape(
#     batch_size, seq_len,
#     embedding_dim).to(memory_format=torch.contiguous_format)
# v = torch.arange(entries, dtype=dtype, device=device).reshape(
#     batch_size, seq_len,
#     embedding_dim).to(memory_format=torch.contiguous_format)
# y_gt = torch.empty(batch_size, seq_len, embedding_dim, dtype=dtype,
#                 device=device).to(memory_format=torch.contiguous_format)


### CUDA version

In [6]:
# Run the compiled kernel. The return value is not captured and the result is
# read from y_gt afterwards, so the kernel presumably writes its output into
# that preallocated buffer in place.
wkv_cuda_kernel.wkv_cuda.forward(batch_size, seq_len, embedding_dim, time_decay,
                             time_first, k, v, y_gt)

In [7]:
# Kernel output; used as the reference ("ground truth") in the comparisons below.
y_gt

tensor([[[ 6.0614e-01, -2.7206e-02, -1.0326e+00,  ..., -8.8088e-01,
           3.0217e-01, -1.6988e+00],
         [-1.8448e-02,  1.9646e-01, -1.2128e+00,  ..., -8.4950e-01,
           6.7267e-01, -1.3180e+00],
         [ 2.3237e-01,  2.8511e-01, -1.0060e+00,  ..., -8.2239e-01,
           1.8111e-01, -6.8774e-01],
         ...,
         [ 3.2317e-01, -4.1025e-02, -4.8568e-01,  ..., -1.6375e-01,
          -1.9714e+00, -5.2361e-04],
         [ 3.2317e-01, -2.5891e-01, -1.1748e-01,  ..., -8.4061e-01,
           1.5993e-01, -5.0706e-01],
         [ 3.2317e-01, -5.5133e-01, -9.9434e-02,  ..., -9.5514e-01,
          -9.1517e-01, -7.5674e-01]],

        [[-1.8561e+00,  3.2080e-01, -2.9945e+00,  ...,  8.6864e-01,
          -7.0153e-01, -2.9329e-01],
         [-4.4004e-01,  9.4639e-01, -1.0114e+00,  ...,  1.3411e+00,
           5.8858e-01, -3.4078e-01],
         [-4.7000e-01,  1.4115e-01, -1.0208e+00,  ..., -3.7792e-01,
           7.2464e-01,  2.1679e-01],
         ...,
         [-1.0560e+00,  2

### Reproduced torch version 1

In [8]:
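# Recurrence of the CUDA forward pass written as pseudo-code (ChatGPT output).
# pp is the running maximum exponent; aa / bb are the numerator / denominator
# accumulators, both stored relative to pp.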

# ww = u_timefirst[c] + k[b, i, c]
# p = max(pp, ww)
# e1 = exp(pp - p)
# e2 = exp(ww - p)
# y[b, i, c] = (e1 * aa + e2 * v[b, i, c]) / (e1 * bb + e2)
# ww = w_timedecay[c] + pp
# p = max(ww, k[b, i, c])
# e1 = exp(ww - p)
# e2 = exp(k[b, i, c] - p)
# aa = e1 * aa + e2 * v[b, i, c]
# bb = e1 * bb + e2
# pp = p


In [9]:
def cuda_forward_mock(batch_size, seq_len, embedding_dim, time_decay,
                      time_first, k, v, y):
    """Scalar, loop-by-loop torch reproduction of the WKV forward recurrence.

    For every (batch, channel) pair the sequence is processed step by step.
    ``aa`` and ``bb`` are the numerator and denominator of an exponentially
    weighted average of the values; both are stored relative to the running
    exponent ``pp`` so that every ``exp`` argument is <= 0.

    Resulting output for step ``i`` (per batch element and channel)::

        y_i = (sum_{j<i} exp((i-1-j)*time_decay + k_j) * v_j
               + exp(time_first + k_i) * v_i)
              / (same sum without the v factors)

    Args:
        batch_size, seq_len, embedding_dim: extents of the three loops.
        time_decay: indexed as ``time_decay[c]`` for ``c < embedding_dim``.
        time_first: indexed as ``time_first[c]`` for ``c < embedding_dim``.
        k, v: indexed as ``[b, i, c]``, i.e. at least
            ``(batch_size, seq_len, embedding_dim)``.
        y: accepted only so the signature mirrors the CUDA kernel call; it is
            neither read nor written.

    Returns:
        A new tensor of shape ``(batch_size, seq_len, embedding_dim)`` with
        the dtype and device of ``k``.
    """
    # Take dtype/device from the inputs rather than from notebook globals so
    # the function works for any input placement and after a kernel restart.
    opts = dict(dtype=k.dtype, device=k.device)
    out = torch.zeros(batch_size, seq_len, embedding_dim, **opts)
    MIN_VAL = -1e38  # stands in for "-infinity" as the initial running exponent
    for b in range(batch_size):
        for c in range(embedding_dim):
            u = time_first[c]
            w = time_decay[c]
            # Keep the whole state as 0-d tensors so torch.maximum can be used
            # instead of re-wrapping the result of Python's max() in a tensor.
            pp = torch.full((), MIN_VAL, **opts)
            aa = torch.zeros((), **opts)
            bb = torch.zeros((), **opts)
            for i in range(seq_len):
                kk = k[b, i, c]
                vv = v[b, i, c]

                # Output: the current token is weighted with the time_first bonus.
                ww = u + kk
                p = torch.maximum(pp, ww)
                e1 = torch.exp(pp - p)
                e2 = torch.exp(ww - p)
                out[b, i, c] = (e1 * aa + e2 * vv) / (e1 * bb + e2)

                # State update: decay the history, then add the current key
                # (without the time_first bonus).
                ww = w + pp
                p = torch.maximum(ww, kk)
                e1 = torch.exp(ww - p)
                e2 = torch.exp(kk - p)
                aa = e1 * aa + e2 * vv
                bb = e1 * bb + e2
                pp = p

    return out

In [10]:
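# Disabled by default -- presumably because the pure-Python triple loop is very
# slow at the full input size. The last argument only fills the kernel-like
# signature (the function does not write to it), so y_gt can be passed.
# y_m = cuda_forward_mock(batch_size, seq_len, embedding_dim, time_decay,
#                         time_first, k, v, y_gt)
# y_m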

### torch version v3 (own impl)

In [11]:
def cuda_forward_mock3(batch_size, seq_len, embedding_dim, time_decay,
                      time_first, k, v, y):
    """WKV forward recurrence vectorised over batch and embedding.

    Only the time dimension is looped over; the state tensors ``aa``
    (numerator), ``bb`` (denominator) and ``pp`` (running exponent that both
    are stored relative to) have shape ``(batch_size, embedding_dim)``.
    Because ``p`` is always the maximum of the two exponents involved, every
    ``exp`` argument is <= 0.

    Args:
        batch_size, seq_len, embedding_dim: sizes of the state / output.
        time_decay, time_first: per-channel parameters, broadcast against
            ``(batch_size, embedding_dim)`` slices.
        k, v: indexed as ``[:, t]`` for ``t < seq_len``; each slice must be
            compatible with ``(batch_size, embedding_dim)``.
        y: accepted only so the signature mirrors the CUDA kernel call; it is
            neither read nor written.

    Returns:
        A new tensor of shape ``(batch_size, seq_len, embedding_dim)`` with
        the dtype and device of ``k``.
    """
    # dtype/device come from the inputs, not from notebook globals.
    opts = dict(dtype=k.dtype, device=k.device)
    state_shape = (batch_size, embedding_dim)
    MIN_VAL = -1e38  # stands in for "-infinity" as the initial running exponent

    out = torch.zeros(batch_size, seq_len, embedding_dim, **opts)
    # running sums
    aa = torch.zeros(state_shape, **opts)
    bb = torch.zeros(state_shape, **opts)
    pp = torch.full(state_shape, MIN_VAL, **opts)

    for t in range(seq_len):
        # Slice the time step directly; the (embedding_dim,) parameters
        # broadcast over the batch, so no explicit repeat/rearrange is needed.
        kt = k[:, t]
        vt = v[:, t]

        # Output: current token gets the time_first bonus.
        ww = time_first + kt
        p = torch.maximum(pp, ww)
        e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
        out[:, t] = (e1 * aa + e2 * vt) / (e1 * bb + e2)

        # State update: decay the history, then add the current key.
        ww = time_decay + pp
        p = torch.maximum(ww, kt)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kt - p)
        aa = e1 * aa + e2 * vt
        bb = e1 * bb + e2
        pp = p
    return out

In [12]:
# y_gt only fills the kernel-like `y` slot of the signature; the function
# allocates its own output tensor and does not write to the one passed in.
y_m3 = cuda_forward_mock3(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v, y_gt)

In [13]:
# Output of the vectorised torch version, for visual comparison with y_gt.
y_m3

tensor([[[ 6.0614e-01, -2.7206e-02, -1.0326e+00,  ..., -8.8088e-01,
           3.0217e-01, -1.6988e+00],
         [-1.8448e-02,  1.9646e-01, -1.2128e+00,  ..., -8.4950e-01,
           6.7267e-01, -1.3180e+00],
         [ 2.3237e-01,  2.8511e-01, -1.0060e+00,  ..., -8.2239e-01,
           1.8111e-01, -6.8774e-01],
         ...,
         [ 3.2317e-01, -4.1025e-02, -4.8568e-01,  ..., -1.6375e-01,
          -1.9714e+00, -5.2361e-04],
         [ 3.2317e-01, -2.5891e-01, -1.1748e-01,  ..., -8.4061e-01,
           1.5993e-01, -5.0706e-01],
         [ 3.2317e-01, -5.5133e-01, -9.9434e-02,  ..., -9.5514e-01,
          -9.1517e-01, -7.5674e-01]],

        [[-1.8561e+00,  3.2080e-01, -2.9945e+00,  ...,  8.6864e-01,
          -7.0153e-01, -2.9329e-01],
         [-4.4004e-01,  9.4639e-01, -1.0114e+00,  ...,  1.3411e+00,
           5.8858e-01, -3.4078e-01],
         [-4.7000e-01,  1.4115e-01, -1.0208e+00,  ..., -3.7792e-01,
           7.2464e-01,  2.1679e-01],
         ...,
         [-1.0560e+00,  2

In [14]:
# Compare with the CUDA result at two absolute tolerances (allclose's default
# rtol still applies). Two separate statements avoid building a tuple of the
# print() return values, which the notebook would echo as `(None, None)`.
print(torch.allclose(y_gt, y_m3, atol=1e-6))
print(torch.allclose(y_gt, y_m3, atol=1e-7))

True
False


(None, None)

In [15]:
# Spot check: the leading 2x2x2 corner of the kernel output and of the torch
# version, shown side by side.
y_gt[:2, :2, :2], y_m3[:2, :2, :2]

(tensor([[[ 0.6061, -0.0272],
          [-0.0184,  0.1965]],
 
         [[-1.8561,  0.3208],
          [-0.4400,  0.9464]]], device='cuda:0'),
 tensor([[[ 0.6061, -0.0272],
          [-0.0184,  0.1965]],
 
         [[-1.8561,  0.3208],
          [-0.4400,  0.9464]]], device='cuda:0'))

### torch version v4 (simplify)

In [21]:
def cuda_forward_mock4(batch_size, seq_len, embedding_dim, time_decay,
                      time_first, k, v, y):
    """Simplified variant of the vectorised WKV forward recurrence.

    Compared with ``cuda_forward_mock3`` only one running exponent ``eps`` is
    kept (initialised to 0 instead of a very negative number) and the output
    step uses it directly. ``aa`` / ``bb`` are the numerator / denominator
    accumulators stored relative to ``eps``; the common factor cancels in the
    ratio.

    NOTE: the output step does not take a maximum with
    ``time_first + k_t``, so ``exp(time_first + k_t - eps)`` is not bounded by
    1 as in v3 and can overflow for large inputs.

    Args:
        batch_size, seq_len, embedding_dim: sizes of the state / output.
        time_decay, time_first: per-channel parameters, broadcast against
            ``(batch_size, embedding_dim)`` slices.
        k, v: indexed as ``[:, t]`` for ``t < seq_len``; each slice must be
            compatible with ``(batch_size, embedding_dim)``.
        y: accepted only so the signature mirrors the CUDA kernel call; it is
            neither read nor written.

    Returns:
        A new tensor of shape ``(batch_size, seq_len, embedding_dim)`` with
        the dtype and device of ``k``.
    """
    # dtype/device come from the inputs, not from notebook globals.
    opts = dict(dtype=k.dtype, device=k.device)
    state_shape = (batch_size, embedding_dim)

    out = torch.zeros(batch_size, seq_len, embedding_dim, **opts)
    # running sums
    aa = torch.zeros(state_shape, **opts)
    bb = torch.zeros(state_shape, **opts)
    # Running exponent; starts at 0.0 (the v3 version starts at -1e38).
    eps = torch.zeros(state_shape, **opts)

    for t in range(seq_len):
        # Slice the time step directly; the (embedding_dim,) parameters
        # broadcast over the batch, so no explicit repeat/rearrange is needed.
        kt = k[:, t]
        vt = v[:, t]

        # An earlier draft (v1) evaluated exp(time_first + k_t - eps) and
        # exp(time_decay + eps - eps_next) twice each; here every exponential
        # is computed once and reused.
        e_tf_k = torch.exp(time_first + kt - eps)
        out[:, t] = (aa + vt * e_tf_k) / (bb + e_tf_k)
        eps_next = torch.maximum(time_decay + eps, kt)

        e_td_k = torch.exp(time_decay + eps - eps_next)
        e_k = torch.exp(kt - eps_next)
        aa = aa * e_td_k + vt * e_k
        bb = bb * e_td_k + e_k
        eps = eps_next

    return out

In [22]:
# y_gt only fills the kernel-like `y` slot of the signature; the function
# allocates its own output tensor and does not write to the one passed in.
y_m4 = cuda_forward_mock4(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v, y_gt)

In [23]:
# Output of the simplified torch version, for visual comparison with y_gt.
y_m4

tensor([[[ 6.0614e-01, -2.7206e-02, -1.0326e+00,  ..., -8.8088e-01,
           3.0217e-01, -1.6988e+00],
         [-1.8448e-02,  1.9646e-01, -1.2128e+00,  ..., -8.4950e-01,
           6.7267e-01, -1.3180e+00],
         [ 2.3237e-01,  2.8511e-01, -1.0060e+00,  ..., -8.2239e-01,
           1.8111e-01, -6.8774e-01],
         ...,
         [ 3.2317e-01, -4.1025e-02, -4.8568e-01,  ..., -1.6375e-01,
          -1.9714e+00, -5.2361e-04],
         [ 3.2317e-01, -2.5891e-01, -1.1748e-01,  ..., -8.4061e-01,
           1.5993e-01, -5.0706e-01],
         [ 3.2317e-01, -5.5133e-01, -9.9434e-02,  ..., -9.5514e-01,
          -9.1517e-01, -7.5674e-01]],

        [[-1.8561e+00,  3.2080e-01, -2.9945e+00,  ...,  8.6864e-01,
          -7.0153e-01, -2.9329e-01],
         [-4.4004e-01,  9.4639e-01, -1.0114e+00,  ...,  1.3411e+00,
           5.8858e-01, -3.4078e-01],
         [-4.7000e-01,  1.4115e-01, -1.0208e+00,  ..., -3.7792e-01,
           7.2464e-01,  2.1679e-01],
         ...,
         [-1.0560e+00,  2

In [24]:
# Prints three values: agreement with the CUDA result at atol=1e-6, then
# whether the simplified version or the kernel output contains any NaN.
print(torch.allclose(y_gt, y_m4, atol=1e-6), y_m4.isnan().any(), y_gt.isnan().any())

True tensor(False, device='cuda:0') tensor(False, device='cuda:0')
