In [1]:
%load_ext autoreload
%autoreload 2
import torch
from wkv_kernel import WKV, WKVConfig
from einops import rearrange, reduce, repeat


  from .autonotebook import tqdm as notebook_tqdm


## Translate the recurrence formula from CUDA to torch

In [2]:
# Setup cuda kernel
wkv_cuda_kernel = WKV()

Using /system/user/beck/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /system/user/beck/.cache/torch_extensions/py310_cu117/wkv/build.ninja...
Building extension module wkv...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.


Loading extension module wkv...


### Inputs

In [3]:
dtype = torch.float32
device = torch.device('cuda')

In [4]:
# original input shapes from training
batch_size = 12
seq_len = 512
embedding_dim = 512
time_decay = torch.randn(
    (embedding_dim, ), dtype=dtype,
    device=device).to(memory_format=torch.contiguous_format)
time_first = torch.randn(
    (embedding_dim, ), dtype=dtype,
    device=device).to(memory_format=torch.contiguous_format)
k = torch.randn((batch_size, seq_len, embedding_dim),
                dtype=dtype,
                device=device).to(memory_format=torch.contiguous_format)
v = torch.randn((batch_size, seq_len, embedding_dim),
                device=device,
                dtype=dtype).to(memory_format=torch.contiguous_format)
y_gt = torch.empty(batch_size, seq_len, embedding_dim, dtype=dtype,
                device=device).to(memory_format=torch.contiguous_format)

In [5]:
# mock up for experiment
# batch_size = 2
# seq_len = 4
# embedding_dim = 2
# entries = batch_size * seq_len * embedding_dim
# time_decay = torch.arange(embedding_dim, dtype=dtype,
#                           device=device).reshape(embedding_dim).to(
#                               memory_format=torch.contiguous_format)
# time_first = torch.arange(embedding_dim, dtype=dtype,
#                           device=device).reshape(embedding_dim).to(
#                               memory_format=torch.contiguous_format)
# k = torch.arange(entries, dtype=dtype, device=device).reshape(
#     batch_size, seq_len,
#     embedding_dim).to(memory_format=torch.contiguous_format)
# v = torch.arange(entries, dtype=dtype, device=device).reshape(
#     batch_size, seq_len,
#     embedding_dim).to(memory_format=torch.contiguous_format)
# y_gt = torch.empty(batch_size, seq_len, embedding_dim, dtype=dtype,
#                 device=device).to(memory_format=torch.contiguous_format)


### CUDA version

In [6]:
wkv_cuda_kernel.wkv_cuda.forward(batch_size, seq_len, embedding_dim, time_decay,
                             time_first, k, v, y_gt)

In [7]:
y_gt

tensor([[[-1.0037e+00, -8.6573e-02,  8.0910e-01,  ...,  7.8404e-01,
           8.3063e-01, -4.3740e-02],
         [ 1.1826e+00,  5.2133e-01,  8.1488e-01,  ...,  3.4065e-01,
           7.3034e-01, -2.9765e-01],
         [-2.8041e-01,  2.7679e-01,  6.9914e-01,  ...,  1.9742e-01,
           1.0599e+00, -6.8128e-01],
         ...,
         [-3.0968e-01,  2.4749e-01,  4.4073e-01,  ..., -4.4695e-01,
           8.7492e-02, -5.6729e-03],
         [ 6.1573e-01,  2.4749e-01,  5.6366e-01,  ..., -3.8346e-01,
           8.7492e-02,  7.1407e-04],
         [ 8.1210e-01,  2.4749e-01,  6.9856e-01,  ..., -4.3334e-01,
           8.7492e-02, -6.5013e-02]],

        [[-1.2718e+00,  1.0147e+00, -2.2606e-01,  ..., -6.5893e-01,
          -2.4915e-01, -4.5347e-01],
         [-8.1836e-01,  9.2059e-01,  7.4954e-01,  ..., -3.6657e-01,
          -1.0068e-01, -5.9264e-01],
         [ 3.4145e-01,  6.0618e-01,  9.1639e-01,  ...,  3.7305e-01,
           1.0906e-01,  4.2963e-01],
         ...,
         [ 1.8321e-01,  4

### Reproduced torch version 1

In [8]:
# ChatGPT output:

# ww = u_timefirst[c] + k[b, i, c]
# p = max(pp, ww)
# e1 = exp(pp - p)
# e2 = exp(ww - p)
# y[b, i, c] = (e1 * aa + e2 * v[b, i, c]) / (e1 * bb + e2)
# ww = w_timedecay[c] + pp
# p = max(ww, k[b, i, c])
# e1 = exp(ww - p)
# e2 = exp(k[b, i, c] - p)
# aa = e1 * aa + e2 * v[b, i, c]
# bb = e1 * bb + e2
# pp = p


In [9]:
def cuda_forward_mock(batch_size, seq_len, embedding_dim, time_decay,
                      time_first, k, v, y):
    y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype,
                device=device).to(memory_format=torch.contiguous_format)
    MIN_VAL = -1e38
    for b in range(batch_size):
        for c in range(embedding_dim):
            pp = MIN_VAL
            aa = 0
            bb = 0
            for i in range(seq_len):
                # ii = i * embedding_dim + c
                kk = k[b, i, c]
                vv = v[b, i, c]
                ww = time_first[c] + kk
                p = torch.tensor(max(pp, ww))
                e1 = torch.exp(pp - p)
                e2 = torch.exp(ww - p)
                new_y = (e1 * aa + e2 * vv) / (e1 * bb + e2)
                y[b, i, c] = new_y
                ww = time_decay[c] + pp
                p = torch.tensor(max(ww, kk))
                e1 = torch.exp(ww - p)
                e2 = torch.exp(kk - p) # uses current key
                aa = e1 * aa + e2 * vv
                bb = e1 * bb + e2
                pp = p

    return y

In [10]:
# y_m = cuda_forward_mock(batch_size, seq_len, embedding_dim, time_decay,
#                         time_first, k, v, y)
# y_m

### torch version v3 (own impl)

In [11]:
def cuda_forward_mock3(batch_size, seq_len, embedding_dim, time_decay,
                      time_first, k, v, y):
    y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype,
                device=device).to(memory_format=torch.contiguous_format)
    MIN_VAL = -1e38
    # reshape inputs
    k_ = rearrange(k, 'b s e -> s b e')
    v_ = rearrange(v, 'b s e -> s b e')
    y_ = rearrange(y, 'b s e -> s b e')
    tf = repeat(time_first, 'e -> b e', b=batch_size)
    td = repeat(time_decay, 'e -> b e', b=batch_size)
    # running sums
    aa = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
    bb = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
    pp = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
    for t in range(seq_len):
        ww = tf + k_[t]
        p = torch.max(pp, ww)
        e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
        y_[t] = (e1 * aa + e2 * v_[t]) / (e1 * bb + e2)
        ww = td + pp
        p = torch.max(ww, k_[t])
        e1 = torch.exp(ww - p)
        e2 = torch.exp(k_[t] - p)
        aa = e1 * aa + e2 * v_[t]
        bb = e1 * bb + e2
        pp = p
    y = rearrange(y_, 's b e -> b s e')
    return y

In [12]:
y_m3 = cuda_forward_mock3(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v, y_gt)

In [13]:
y_m3

tensor([[[-1.0037e+00, -8.6573e-02,  8.0910e-01,  ...,  7.8404e-01,
           8.3063e-01, -4.3740e-02],
         [ 1.1826e+00,  5.2133e-01,  8.1488e-01,  ...,  3.4065e-01,
           7.3034e-01, -2.9765e-01],
         [-2.8041e-01,  2.7679e-01,  6.9914e-01,  ...,  1.9742e-01,
           1.0599e+00, -6.8128e-01],
         ...,
         [-3.0968e-01,  2.4749e-01,  4.4073e-01,  ..., -4.4695e-01,
           8.7492e-02, -5.6729e-03],
         [ 6.1573e-01,  2.4749e-01,  5.6366e-01,  ..., -3.8346e-01,
           8.7492e-02,  7.1407e-04],
         [ 8.1210e-01,  2.4749e-01,  6.9856e-01,  ..., -4.3334e-01,
           8.7492e-02, -6.5013e-02]],

        [[-1.2718e+00,  1.0147e+00, -2.2606e-01,  ..., -6.5893e-01,
          -2.4915e-01, -4.5347e-01],
         [-8.1836e-01,  9.2059e-01,  7.4954e-01,  ..., -3.6657e-01,
          -1.0068e-01, -5.9264e-01],
         [ 3.4145e-01,  6.0618e-01,  9.1639e-01,  ...,  3.7305e-01,
           1.0906e-01,  4.2963e-01],
         ...,
         [ 1.8321e-01,  4

In [40]:
print(torch.allclose(y_gt, y_m3, atol=1e-6)), print(torch.allclose(y_gt, y_m3, atol=1e-7))

True
False


(None, None)

In [15]:
y_gt[:2, :2, :2], y_m3[:2, :2, :2]

(tensor([[[-1.0037, -0.0866],
          [ 1.1826,  0.5213]],
 
         [[-1.2718,  1.0147],
          [-0.8184,  0.9206]]], device='cuda:0'),
 tensor([[[-1.0037, -0.0866],
          [ 1.1826,  0.5213]],
 
         [[-1.2718,  1.0147],
          [-0.8184,  0.9206]]], device='cuda:0'))

#### Single computation steps

In [16]:
y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype,
            device=device).to(memory_format=torch.contiguous_format)

In [17]:
batch_size, seq_len, embedding_dim

(12, 512, 512)

In [18]:
k.shape, v.shape, y.shape, time_decay.shape, time_first.shape

(torch.Size([12, 512, 512]),
 torch.Size([12, 512, 512]),
 torch.Size([12, 512, 512]),
 torch.Size([512]),
 torch.Size([512]))

In [19]:
k_ = rearrange(k, 'b s e -> s b e')
v_ = rearrange(v, 'b s e -> s b e')
y_ = rearrange(y, 'b s e -> s b e')
tf = repeat(time_first, 'e -> b e', b=batch_size)
td = repeat(time_decay, 'e -> b e', b=batch_size)

In [20]:
k_.shape, v_.shape, y.shape, td.shape, tf.shape

(torch.Size([512, 12, 512]),
 torch.Size([512, 12, 512]),
 torch.Size([12, 512, 512]),
 torch.Size([12, 512]),
 torch.Size([12, 512]))

In [21]:
MIN_VAL = -1e-38
aa = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
bb = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
pp = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)

In [22]:
t = 0

In [23]:
ww = tf + k_[t]


In [24]:
p = torch.max(pp, ww)

In [25]:
e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p)

In [26]:
# y[t] = (e1 * aa + e2 * v_[t]) / (e1 * bb + e2)

In [27]:
ww = td + pp

In [28]:
ww

tensor([[-0.5128,  0.2325, -0.4114,  ..., -0.6003,  0.1299, -0.2464],
        [-0.5128,  0.2325, -0.4114,  ..., -0.6003,  0.1299, -0.2464],
        [-0.5128,  0.2325, -0.4114,  ..., -0.6003,  0.1299, -0.2464],
        ...,
        [-0.5128,  0.2325, -0.4114,  ..., -0.6003,  0.1299, -0.2464],
        [-0.5128,  0.2325, -0.4114,  ..., -0.6003,  0.1299, -0.2464],
        [-0.5128,  0.2325, -0.4114,  ..., -0.6003,  0.1299, -0.2464]],
       device='cuda:0')

In [29]:
p = torch.max(ww, k_[t])

In [30]:
e1 = torch.exp(ww - p)
e2 = torch.exp(k_[t] - p)

In [31]:
aa = e1 * aa + e2 * v_[t]
bb = e1 * bb + e2

In [32]:
pp = p

### torch version v4 (own impl, simplify)

In [33]:
def cuda_forward_mock4(batch_size, seq_len, embedding_dim, time_decay,
                      time_first, k, v, y):
    y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype,
                device=device).to(memory_format=torch.contiguous_format)
    MIN_VAL = -1e38
    # reshape inputs
    k_ = rearrange(k, 'b s e -> s b e')
    v_ = rearrange(v, 'b s e -> s b e')
    y_ = rearrange(y, 'b s e -> s b e')
    tf = repeat(time_first, 'e -> b e', b=batch_size)
    td = repeat(time_decay, 'e -> b e', b=batch_size)
    # running sums
    aa = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
    bb = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
    pp = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
    for t in range(seq_len):
        # ww = tf + k_[t]
        # p = torch.max(pp, ww)
        # e1 = torch.exp(pp - p)
        # e2 = torch.exp(ww - p)
        # y_[t] = (e1 * aa + e2 * v_[t]) / (e1 * bb + e2)
        # ww = td + pp
        # p = torch.max(ww, k_[t])
        # e1 = torch.exp(ww - p)
        # e2 = torch.exp(k_[t] - p)
        # aa = e1 * aa + e2 * v_[t]
        # bb = e1 * bb + e2
        # pp = p
        ww = tf + k_[t]
        p = torch.max(pp, ww)
        e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
        y_[t] = (e1 * aa + e2 * v_[t]) / (e1 * bb + e2)
        ww = td + pp
        p = torch.max(ww, k_[t])
        e1 = torch.exp(ww - p)
        e2 = torch.exp(k_[t] - p)
        aa = e1 * aa + e2 * v_[t]
        bb = e1 * bb + e2
        pp = p
    y = rearrange(y_, 's b e -> b s e')
    return y

In [34]:
y_m4 = cuda_forward_mock4(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v, y_gt)

In [35]:
y_m4

tensor([[[-1.0037e+00, -8.6573e-02,  8.0910e-01,  ...,  7.8404e-01,
           8.3063e-01, -4.3740e-02],
         [ 1.1826e+00,  5.2133e-01,  8.1488e-01,  ...,  3.4065e-01,
           7.3034e-01, -2.9765e-01],
         [-2.8041e-01,  2.7679e-01,  6.9914e-01,  ...,  1.9742e-01,
           1.0599e+00, -6.8128e-01],
         ...,
         [-3.0968e-01,  2.4749e-01,  4.4073e-01,  ..., -4.4695e-01,
           8.7492e-02, -5.6729e-03],
         [ 6.1573e-01,  2.4749e-01,  5.6366e-01,  ..., -3.8346e-01,
           8.7492e-02,  7.1407e-04],
         [ 8.1210e-01,  2.4749e-01,  6.9856e-01,  ..., -4.3334e-01,
           8.7492e-02, -6.5013e-02]],

        [[-1.2718e+00,  1.0147e+00, -2.2606e-01,  ..., -6.5893e-01,
          -2.4915e-01, -4.5347e-01],
         [-8.1836e-01,  9.2059e-01,  7.4954e-01,  ..., -3.6657e-01,
          -1.0068e-01, -5.9264e-01],
         [ 3.4145e-01,  6.0618e-01,  9.1639e-01,  ...,  3.7305e-01,
           1.0906e-01,  4.2963e-01],
         ...,
         [ 1.8321e-01,  4

In [36]:
print(torch.allclose(y_gt, y_m4, atol=1e-5))

True


In [37]:
print(-1e4)

-10000.0
