In [1]:
%load_ext autoreload
%autoreload 2
import torch
from wkv_kernel import WKV, WKVConfig
from einops import rearrange, reduce, repeat


  from .autonotebook import tqdm as notebook_tqdm


## Translate the recurrence formula from CUDA to torch

In [2]:
# Setup cuda kernel
wkv_cuda_kernel = WKV()

Using /system/user/beck/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /system/user/beck/.cache/torch_extensions/py310_cu117/wkv/build.ninja...
Building extension module wkv...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.


Loading extension module wkv...


### Inputs

In [3]:
dtype = torch.float32
device = torch.device('cuda')

In [4]:
# original input shapes from training
batch_size = 12
seq_len = 512
embedding_dim = 512
time_decay = torch.randn(
    (embedding_dim, ), dtype=dtype,
    device=device).to(memory_format=torch.contiguous_format)
time_first = torch.randn(
    (embedding_dim, ), dtype=dtype,
    device=device).to(memory_format=torch.contiguous_format)
k = torch.randn((batch_size, seq_len, embedding_dim),
                dtype=dtype,
                device=device).to(memory_format=torch.contiguous_format)
v = torch.randn((batch_size, seq_len, embedding_dim),
                device=device,
                dtype=dtype).to(memory_format=torch.contiguous_format)
y_gt = torch.empty(batch_size, seq_len, embedding_dim, dtype=dtype,
                device=device).to(memory_format=torch.contiguous_format)

In [5]:
# mock up for experiment
# batch_size = 2
# seq_len = 4
# embedding_dim = 2
# entries = batch_size * seq_len * embedding_dim
# time_decay = torch.arange(embedding_dim, dtype=dtype,
#                           device=device).reshape(embedding_dim).to(
#                               memory_format=torch.contiguous_format)
# time_first = torch.arange(embedding_dim, dtype=dtype,
#                           device=device).reshape(embedding_dim).to(
#                               memory_format=torch.contiguous_format)
# k = torch.arange(entries, dtype=dtype, device=device).reshape(
#     batch_size, seq_len,
#     embedding_dim).to(memory_format=torch.contiguous_format)
# v = torch.arange(entries, dtype=dtype, device=device).reshape(
#     batch_size, seq_len,
#     embedding_dim).to(memory_format=torch.contiguous_format)
# y_gt = torch.empty(batch_size, seq_len, embedding_dim, dtype=dtype,
#                 device=device).to(memory_format=torch.contiguous_format)


### CUDA version

In [6]:
wkv_cuda_kernel.wkv_cuda.forward(batch_size, seq_len, embedding_dim, time_decay,
                             time_first, k, v, y_gt)

In [7]:
y_gt

tensor([[[ 2.0534, -0.3219,  0.6274,  ...,  0.3324, -0.6062, -0.2180],
         [ 1.9780, -0.2207,  0.4891,  ...,  0.6034, -0.3334, -0.5567],
         [ 1.5262, -0.1409,  0.4524,  ...,  0.8996,  0.3577, -0.2522],
         ...,
         [ 0.6318, -0.2448, -0.3052,  ...,  0.4673, -1.2484, -0.0291],
         [ 0.4366, -0.2448, -0.8811,  ...,  0.4521, -1.2866, -0.0023],
         [ 0.3315, -0.2448,  0.1401,  ...,  0.4907, -0.7423,  0.0164]],

        [[-0.7355, -1.9907, -0.1692,  ...,  0.7174,  0.3276,  1.3929],
         [-0.6438, -0.0233,  0.2341,  ..., -0.0680, -0.5767,  0.1555],
         [-0.1367, -2.1636, -1.1826,  ...,  0.3415,  1.3689,  0.1871],
         ...,
         [ 0.3206, -1.5454,  1.0111,  ...,  0.1211,  0.8732, -0.7091],
         [ 0.8627, -1.5454,  0.4594,  ..., -0.0901,  1.2090, -0.6461],
         [ 0.6551, -1.5454,  0.1613,  ..., -0.3871,  0.4103, -0.6684]],

        [[-1.2690, -0.3401, -1.9859,  ..., -1.4925,  1.9849, -0.9622],
         [-0.6077, -0.3963, -1.0200,  ..., -1

### Reproduced torch version 1

In [8]:
# ChatGPT output:

# ww = u_timefirst[c] + k[b, i, c]
# p = max(pp, ww)
# e1 = exp(pp - p)
# e2 = exp(ww - p)
# y[b, i, c] = (e1 * aa + e2 * v[b, i, c]) / (e1 * bb + e2)
# ww = w_timedecay[c] + pp
# p = max(ww, k[b, i, c])
# e1 = exp(ww - p)
# e2 = exp(k[b, i, c] - p)
# aa = e1 * aa + e2 * v[b, i, c]
# bb = e1 * bb + e2
# pp = p


In [9]:
def cuda_forward_mock(batch_size, seq_len, embedding_dim, time_decay,
                      time_first, k, v, y):
    y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype,
                device=device).to(memory_format=torch.contiguous_format)
    MIN_VAL = -1e38
    for b in range(batch_size):
        for c in range(embedding_dim):
            pp = MIN_VAL
            aa = 0
            bb = 0
            for i in range(seq_len):
                # ii = i * embedding_dim + c
                kk = k[b, i, c]
                vv = v[b, i, c]
                ww = time_first[c] + kk
                p = torch.tensor(max(pp, ww))
                e1 = torch.exp(pp - p)
                e2 = torch.exp(ww - p)
                new_y = (e1 * aa + e2 * vv) / (e1 * bb + e2)
                y[b, i, c] = new_y
                ww = time_decay[c] + pp
                p = torch.tensor(max(ww, kk))
                e1 = torch.exp(ww - p)
                e2 = torch.exp(kk - p) # uses current key
                aa = e1 * aa + e2 * vv
                bb = e1 * bb + e2
                pp = p

    return y

In [10]:
# y_m = cuda_forward_mock(batch_size, seq_len, embedding_dim, time_decay,
#                         time_first, k, v, y)
# y_m

### torch version v3 (own impl)

In [11]:
def cuda_forward_mock3(batch_size, seq_len, embedding_dim, time_decay,
                      time_first, k, v, y):
    y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype,
                device=device).to(memory_format=torch.contiguous_format)
    MIN_VAL = -1e38
    # reshape inputs
    k_ = rearrange(k, 'b s e -> s b e')
    v_ = rearrange(v, 'b s e -> s b e')
    y_ = rearrange(y, 'b s e -> s b e')
    tf = repeat(time_first, 'e -> b e', b=batch_size)
    td = repeat(time_decay, 'e -> b e', b=batch_size)
    # running sums
    aa = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
    bb = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
    pp = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
    for t in range(seq_len):
        ww = tf + k_[t]
        p = torch.max(pp, ww)
        e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
        y_[t] = (e1 * aa + e2 * v_[t]) / (e1 * bb + e2)
        ww = td + pp
        p = torch.max(ww, k_[t])
        e1 = torch.exp(ww - p)
        e2 = torch.exp(k_[t] - p)
        aa = e1 * aa + e2 * v_[t]
        bb = e1 * bb + e2
        pp = p
    y = rearrange(y_, 's b e -> b s e')
    return y

In [12]:
y_m3 = cuda_forward_mock3(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v, y_gt)

In [13]:
y_m3

tensor([[[ 2.0534, -0.3219,  0.6274,  ...,  0.3324, -0.6062, -0.2180],
         [ 1.9780, -0.2207,  0.4891,  ...,  0.6034, -0.3334, -0.5567],
         [ 1.5262, -0.1409,  0.4524,  ...,  0.8996,  0.3577, -0.2522],
         ...,
         [ 0.6318, -0.2448, -0.3052,  ...,  0.4673, -1.2484, -0.0291],
         [ 0.4366, -0.2448, -0.8811,  ...,  0.4521, -1.2866, -0.0023],
         [ 0.3315, -0.2448,  0.1401,  ...,  0.4907, -0.7423,  0.0164]],

        [[-0.7355, -1.9907, -0.1692,  ...,  0.7174,  0.3276,  1.3929],
         [-0.6438, -0.0233,  0.2341,  ..., -0.0680, -0.5767,  0.1555],
         [-0.1367, -2.1636, -1.1826,  ...,  0.3415,  1.3689,  0.1871],
         ...,
         [ 0.3206, -1.5454,  1.0111,  ...,  0.1211,  0.8732, -0.7091],
         [ 0.8627, -1.5454,  0.4594,  ..., -0.0901,  1.2090, -0.6461],
         [ 0.6551, -1.5454,  0.1613,  ..., -0.3871,  0.4103, -0.6684]],

        [[-1.2690, -0.3401, -1.9859,  ..., -1.4925,  1.9849, -0.9622],
         [-0.6077, -0.3963, -1.0200,  ..., -1

In [14]:
print(torch.allclose(y_gt, y_m3, atol=1e-6)), print(torch.allclose(y_gt, y_m3, atol=1e-7))

True
False


(None, None)

In [15]:
y_gt[:2, :2, :2], y_m3[:2, :2, :2]

(tensor([[[ 2.0534, -0.3219],
          [ 1.9780, -0.2207]],
 
         [[-0.7355, -1.9907],
          [-0.6438, -0.0233]]], device='cuda:0'),
 tensor([[[ 2.0534, -0.3219],
          [ 1.9780, -0.2207]],
 
         [[-0.7355, -1.9907],
          [-0.6438, -0.0233]]], device='cuda:0'))

#### Single computation steps

In [16]:
y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype,
            device=device).to(memory_format=torch.contiguous_format)

In [17]:
batch_size, seq_len, embedding_dim

(12, 512, 512)

In [18]:
k.shape, v.shape, y.shape, time_decay.shape, time_first.shape

(torch.Size([12, 512, 512]),
 torch.Size([12, 512, 512]),
 torch.Size([12, 512, 512]),
 torch.Size([512]),
 torch.Size([512]))

In [19]:
k_ = rearrange(k, 'b s e -> s b e')
v_ = rearrange(v, 'b s e -> s b e')
y_ = rearrange(y, 'b s e -> s b e')
tf = repeat(time_first, 'e -> b e', b=batch_size)
td = repeat(time_decay, 'e -> b e', b=batch_size)

In [20]:
k_.shape, v_.shape, y.shape, td.shape, tf.shape

(torch.Size([512, 12, 512]),
 torch.Size([512, 12, 512]),
 torch.Size([12, 512, 512]),
 torch.Size([12, 512]),
 torch.Size([12, 512]))

In [21]:
MIN_VAL = -1e-38
aa = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
bb = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
pp = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)

In [22]:
t = 0

In [23]:
ww = tf + k_[t]


In [24]:
p = torch.max(pp, ww)

In [25]:
e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p)

In [26]:
# y[t] = (e1 * aa + e2 * v_[t]) / (e1 * bb + e2)

In [27]:
ww = td + pp

In [28]:
ww

tensor([[-0.6083,  1.0452, -1.0215,  ..., -0.2204, -0.9684, -0.0350],
        [-0.6083,  1.0452, -1.0215,  ..., -0.2204, -0.9684, -0.0350],
        [-0.6083,  1.0452, -1.0215,  ..., -0.2204, -0.9684, -0.0350],
        ...,
        [-0.6083,  1.0452, -1.0215,  ..., -0.2204, -0.9684, -0.0350],
        [-0.6083,  1.0452, -1.0215,  ..., -0.2204, -0.9684, -0.0350],
        [-0.6083,  1.0452, -1.0215,  ..., -0.2204, -0.9684, -0.0350]],
       device='cuda:0')

In [29]:
p = torch.max(ww, k_[t])

In [30]:
e1 = torch.exp(ww - p)
e2 = torch.exp(k_[t] - p)

In [31]:
aa = e1 * aa + e2 * v_[t]
bb = e1 * bb + e2

In [32]:
pp = p

### torch version v4 (own impl, simplify)

In [33]:
def cuda_forward_mock4(batch_size, seq_len, embedding_dim, time_decay,
                      time_first, k, v, y):
    y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype,
                device=device).to(memory_format=torch.contiguous_format)
    MIN_VAL = -1e38
    # reshape inputs
    k_ = rearrange(k, 'b s e -> s b e')
    v_ = rearrange(v, 'b s e -> s b e')
    y_ = rearrange(y, 'b s e -> s b e')
    tf = repeat(time_first, 'e -> b e', b=batch_size)
    td = repeat(time_decay, 'e -> b e', b=batch_size)
    # running sums
    aa = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
    bb = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
    pp = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
    # debug metrics
    max_pp = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
    min_pp = torch.full((batch_size, embedding_dim), -MIN_VAL, dtype=dtype, device=device)
    max_aa = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
    min_aa = torch.full((batch_size, embedding_dim), -MIN_VAL, dtype=dtype, device=device)
    max_bb = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
    min_bb = torch.full((batch_size, embedding_dim), -MIN_VAL, dtype=dtype, device=device)
    eps = pp
    for t in range(seq_len):
        # #! original version
        ww = tf + k_[t]
        p = torch.max(pp, ww)
        e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
        y_[t] = (e1 * aa + e2 * v_[t]) / (e1 * bb + e2)
        ww = td + pp
        p = torch.max(ww, k_[t])
        e1 = torch.exp(ww - p)
        e2 = torch.exp(k_[t] - p)
        aa = e1 * aa + e2 * v_[t]
        bb = e1 * bb + e2
        pp = p

        # #? debug metrics
        max_pp = torch.max(max_pp, pp)
        min_pp = torch.min(min_pp, pp)
        max_aa = torch.max(max_aa, aa)
        min_aa = torch.min(min_aa, aa)
        max_bb = torch.max(max_bb, bb)
        min_bb = torch.min(min_bb, bb)

    y = rearrange(y_, 's b e -> b s e')
    return y, {'max_pp': max_pp, 'min_pp': min_pp, 'max_aa': max_aa, 'min_aa': min_aa, 'max_bb': max_bb, 'min_bb': min_bb}

In [34]:
y_m4, ret_dict = cuda_forward_mock4(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v, y_gt)

In [35]:
y_m4

tensor([[[ 2.0534, -0.3219,  0.6274,  ...,  0.3324, -0.6062, -0.2180],
         [ 1.9780, -0.2207,  0.4891,  ...,  0.6034, -0.3334, -0.5567],
         [ 1.5262, -0.1409,  0.4524,  ...,  0.8996,  0.3577, -0.2522],
         ...,
         [ 0.6318, -0.2448, -0.3052,  ...,  0.4673, -1.2484, -0.0291],
         [ 0.4366, -0.2448, -0.8811,  ...,  0.4521, -1.2866, -0.0023],
         [ 0.3315, -0.2448,  0.1401,  ...,  0.4907, -0.7423,  0.0164]],

        [[-0.7355, -1.9907, -0.1692,  ...,  0.7174,  0.3276,  1.3929],
         [-0.6438, -0.0233,  0.2341,  ..., -0.0680, -0.5767,  0.1555],
         [-0.1367, -2.1636, -1.1826,  ...,  0.3415,  1.3689,  0.1871],
         ...,
         [ 0.3206, -1.5454,  1.0111,  ...,  0.1211,  0.8732, -0.7091],
         [ 0.8627, -1.5454,  0.4594,  ..., -0.0901,  1.2090, -0.6461],
         [ 0.6551, -1.5454,  0.1613,  ..., -0.3871,  0.4103, -0.6684]],

        [[-1.2690, -0.3401, -1.9859,  ..., -1.4925,  1.9849, -0.9622],
         [-0.6077, -0.3963, -1.0200,  ..., -1

In [36]:
print(torch.allclose(y_gt, y_m4, atol=1e-6), y_m4.isnan().any(), y_gt.isnan().any())

True tensor(False, device='cuda:0') tensor(False, device='cuda:0')


#### Look at max and min values of running sums
In each step we use elementwise max/min operations. 

In [37]:
ret_dict['max_pp']

tensor([[  2.6918, 535.4303,   3.0462,  ...,   3.2541,   2.7332,   2.6890],
        [  2.7069, 532.7648,   2.7629,  ...,   2.8708,   2.9609,   2.6649],
        [  3.1108, 534.8326,   2.8436,  ...,   2.6370,   2.5693,   2.6927],
        ...,
        [  3.1068, 534.5170,   2.8399,  ...,   2.7139,   3.0000,   3.8625],
        [  3.9813, 533.0916,   2.6091,  ...,   3.0547,   3.5176,   3.2510],
        [  4.3635, 533.9435,   3.2308,  ...,   2.5703,   3.4670,   2.8087]],
       device='cuda:0')

In [38]:
ret_dict['min_pp']

tensor([[-1.4057,  1.3498, -1.9416,  ..., -1.7547, -2.2083, -0.8776],
        [-1.5637, -1.6223, -2.3486,  ..., -1.0157, -2.5531, -0.4632],
        [-1.5771,  0.7520, -2.1367,  ..., -0.7678, -2.0311,  0.5119],
        ...,
        [-1.6594,  0.4364, -1.8841,  ..., -0.8170, -2.0275, -0.0373],
        [-1.3813, -0.9890, -1.9632,  ..., -0.4885, -1.8539, -0.1139],
        [-1.5046, -0.1371, -1.6838,  ..., -0.6645, -1.4663,  0.4953]],
       device='cuda:0')

In [39]:
ret_dict['max_aa'], torch.max(ret_dict['max_aa'])

(tensor([[ 4.5843, -0.2627,  3.3414,  ...,  4.5563,  3.1521,  2.5759],
         [ 3.7239, -1.3365,  2.7194,  ...,  4.4296,  4.1109,  4.9951],
         [ 3.6726, -0.3401,  2.9561,  ...,  3.6110,  2.8570,  4.5387],
         ...,
         [ 3.2468, -1.0790,  3.5646,  ...,  6.2724,  3.1855,  4.2919],
         [ 4.0605,  0.6280,  4.0908,  ...,  5.7543,  2.9985,  4.3112],
         [ 3.9596, -0.5081,  4.3807,  ...,  5.1361,  3.1081,  6.5758]],
        device='cuda:0'),
 tensor(12.8109, device='cuda:0'))

In [40]:
ret_dict['min_aa'], torch.min(ret_dict['min_aa'])

(tensor([[-3.6509, -0.3219, -3.6857,  ..., -4.5419, -3.8238, -3.7708],
         [-4.3884, -5.1513, -3.2587,  ..., -4.9464, -4.6272, -4.9858],
         [-4.3462, -0.4373, -4.0781,  ..., -4.2818, -3.2861, -4.7124],
         ...,
         [-4.0636, -1.5609, -5.0044,  ..., -5.3047, -3.7902, -2.9418],
         [-4.0710, -0.1798, -3.2084,  ..., -3.8164, -3.3576, -5.4531],
         [-3.7850, -0.8200, -3.6491,  ..., -4.0723, -3.9764, -2.2733]],
        device='cuda:0'),
 tensor(-10.0609, device='cuda:0'))

In [41]:
ret_dict['max_bb'], torch.max(ret_dict['max_bb'])

(tensor([[ 4.2488,  1.0751,  3.8374,  ...,  7.8931,  4.2367, 17.8179],
         [ 4.4654,  3.3105,  3.2997,  ...,  8.9564,  3.5567, 16.5492],
         [ 4.1412,  1.1341,  4.1529,  ...,  7.5512,  4.0450, 21.9785],
         ...,
         [ 4.8427,  1.4033,  3.2717,  ...,  8.6600,  3.3806, 17.6681],
         [ 4.2635,  2.0862,  3.2561,  ...,  6.5544,  3.6153, 19.9038],
         [ 4.2699,  1.3731,  3.1090,  ...,  7.6452,  3.7716, 19.3252]],
        device='cuda:0'),
 tensor(58.3277, device='cuda:0'))

In [42]:
ret_dict['min_bb'], torch.min(ret_dict['min_bb'])

(tensor([[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0'),
 tensor(1., device='cuda:0'))