In [1]:
%load_ext autoreload
%autoreload 2
import torch
from wkv_kernel import WKV, WKVConfig
from einops import rearrange, reduce, repeat


  from .autonotebook import tqdm as notebook_tqdm


## Translate the recurrence formula from CUDA to torch

In [2]:
# Setup cuda kernel
wkv_cuda_kernel = WKV()

Using /system/user/beck/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /system/user/beck/.cache/torch_extensions/py310_cu117/wkv/build.ninja...
Building extension module wkv...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.


Loading extension module wkv...


### Inputs

In [3]:
dtype = torch.float32
device = torch.device('cuda')

In [4]:
# original input shapes from training
batch_size = 12
seq_len = 512
embedding_dim = 512
time_decay = torch.randn(
    (embedding_dim, ), dtype=dtype,
    device=device).to(memory_format=torch.contiguous_format)
time_first = torch.randn(
    (embedding_dim, ), dtype=dtype,
    device=device).to(memory_format=torch.contiguous_format)
k = torch.randn((batch_size, seq_len, embedding_dim),
                dtype=dtype,
                device=device).to(memory_format=torch.contiguous_format)
v = torch.randn((batch_size, seq_len, embedding_dim),
                device=device,
                dtype=dtype).to(memory_format=torch.contiguous_format)
y_gt = torch.empty(batch_size, seq_len, embedding_dim, dtype=dtype,
                device=device).to(memory_format=torch.contiguous_format)

In [5]:
# mock up for experiment
# batch_size = 2
# seq_len = 4
# embedding_dim = 2
# entries = batch_size * seq_len * embedding_dim
# time_decay = torch.arange(embedding_dim, dtype=dtype,
#                           device=device).reshape(embedding_dim).to(
#                               memory_format=torch.contiguous_format)
# time_first = torch.arange(embedding_dim, dtype=dtype,
#                           device=device).reshape(embedding_dim).to(
#                               memory_format=torch.contiguous_format)
# k = torch.arange(entries, dtype=dtype, device=device).reshape(
#     batch_size, seq_len,
#     embedding_dim).to(memory_format=torch.contiguous_format)
# v = torch.arange(entries, dtype=dtype, device=device).reshape(
#     batch_size, seq_len,
#     embedding_dim).to(memory_format=torch.contiguous_format)
# y_gt = torch.empty(batch_size, seq_len, embedding_dim, dtype=dtype,
#                 device=device).to(memory_format=torch.contiguous_format)


### CUDA version

In [6]:
wkv_cuda_kernel.wkv_cuda.forward(batch_size, seq_len, embedding_dim, time_decay,
                             time_first, k, v, y_gt)

In [7]:
y_gt

tensor([[[-1.5429, -0.6767,  1.6157,  ...,  1.8309,  0.6268,  1.9657],
         [-1.0107,  2.2756,  1.2978,  ...,  0.5438, -0.5255,  1.6704],
         [-1.3397, -0.0462,  0.6959,  ...,  0.0256, -0.9058,  1.8682],
         ...,
         [-1.4292, -0.2980, -0.5059,  ..., -0.2350,  0.2973,  1.8906],
         [-1.4292, -0.2980, -0.7362,  ..., -0.2096,  0.0230,  1.8906],
         [-1.4292, -0.2980, -0.2837,  ..., -0.2654, -0.3144,  1.8906]],

        [[ 0.2562,  0.3100, -0.2654,  ...,  1.4200, -0.0231,  1.4399],
         [ 0.1513,  0.0361, -0.0853,  ...,  1.2115,  0.2009,  0.5032],
         [-0.0307,  0.0746,  0.2455,  ...,  0.6168,  0.8053,  0.8045],
         ...,
         [ 0.1693,  0.1902,  0.6899,  ...,  1.4375,  0.9854,  1.0094],
         [ 0.1693,  0.1902,  0.1949,  ...,  1.1635, -0.6814,  1.0094],
         [ 0.1693,  0.1902,  0.7529,  ...,  0.9096, -1.2734,  1.0094]],

        [[-0.9897,  0.0436, -1.5413,  ...,  1.3749, -0.7258, -0.0335],
         [ 0.6870, -0.2684, -1.5700,  ...,  1

### Reproduced torch version 1

In [8]:
# ChatGPT output:

# ww = u_timefirst[c] + k[b, i, c]
# p = max(pp, ww)
# e1 = exp(pp - p)
# e2 = exp(ww - p)
# y[b, i, c] = (e1 * aa + e2 * v[b, i, c]) / (e1 * bb + e2)
# ww = w_timedecay[c] + pp
# p = max(ww, k[b, i, c])
# e1 = exp(ww - p)
# e2 = exp(k[b, i, c] - p)
# aa = e1 * aa + e2 * v[b, i, c]
# bb = e1 * bb + e2
# pp = p


In [9]:
def cuda_forward_mock(batch_size, seq_len, embedding_dim, time_decay,
                      time_first, k, v, y):
    y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype,
                device=device).to(memory_format=torch.contiguous_format)
    MIN_VAL = -1e38
    for b in range(batch_size):
        for c in range(embedding_dim):
            pp = MIN_VAL
            aa = 0
            bb = 0
            for i in range(seq_len):
                # ii = i * embedding_dim + c
                kk = k[b, i, c]
                vv = v[b, i, c]
                ww = time_first[c] + kk
                p = torch.tensor(max(pp, ww))
                e1 = torch.exp(pp - p)
                e2 = torch.exp(ww - p)
                new_y = (e1 * aa + e2 * vv) / (e1 * bb + e2)
                y[b, i, c] = new_y
                ww = time_decay[c] + pp
                p = torch.tensor(max(ww, kk))
                e1 = torch.exp(ww - p)
                e2 = torch.exp(kk - p) # uses current key
                aa = e1 * aa + e2 * vv
                bb = e1 * bb + e2
                pp = p

    return y

In [10]:
# y_m = cuda_forward_mock(batch_size, seq_len, embedding_dim, time_decay,
#                         time_first, k, v, y)
# y_m

### torch version v3 (own impl)

In [11]:
def cuda_forward_mock3(batch_size, seq_len, embedding_dim, time_decay,
                      time_first, k, v, y):
    y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype,
                device=device).to(memory_format=torch.contiguous_format)
    MIN_VAL = -1e38
    # reshape inputs
    k_ = rearrange(k, 'b s e -> s b e')
    v_ = rearrange(v, 'b s e -> s b e')
    y_ = rearrange(y, 'b s e -> s b e')
    tf = repeat(time_first, 'e -> b e', b=batch_size)
    td = repeat(time_decay, 'e -> b e', b=batch_size)
    # running sums
    aa = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
    bb = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
    pp = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
    for t in range(seq_len):
        ww = tf + k_[t]
        p = torch.max(pp, ww)
        e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
        y_[t] = (e1 * aa + e2 * v_[t]) / (e1 * bb + e2)
        ww = td + pp
        p = torch.max(ww, k_[t])
        e1 = torch.exp(ww - p)
        e2 = torch.exp(k_[t] - p)
        aa = e1 * aa + e2 * v_[t]
        bb = e1 * bb + e2
        pp = p
    y = rearrange(y_, 's b e -> b s e')
    return y

In [12]:
y_m3 = cuda_forward_mock3(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v, y_gt)

In [13]:
y_m3

tensor([[[-1.5429, -0.6767,  1.6157,  ...,  1.8309,  0.6268,  1.9657],
         [-1.0107,  2.2756,  1.2978,  ...,  0.5438, -0.5255,  1.6704],
         [-1.3397, -0.0462,  0.6959,  ...,  0.0256, -0.9058,  1.8682],
         ...,
         [-1.4292, -0.2980, -0.5059,  ..., -0.2350,  0.2973,  1.8906],
         [-1.4292, -0.2980, -0.7362,  ..., -0.2096,  0.0230,  1.8906],
         [-1.4292, -0.2980, -0.2837,  ..., -0.2654, -0.3144,  1.8906]],

        [[ 0.2562,  0.3100, -0.2654,  ...,  1.4200, -0.0231,  1.4399],
         [ 0.1513,  0.0361, -0.0853,  ...,  1.2115,  0.2009,  0.5032],
         [-0.0307,  0.0746,  0.2455,  ...,  0.6168,  0.8053,  0.8045],
         ...,
         [ 0.1693,  0.1902,  0.6899,  ...,  1.4375,  0.9854,  1.0094],
         [ 0.1693,  0.1902,  0.1949,  ...,  1.1635, -0.6814,  1.0094],
         [ 0.1693,  0.1902,  0.7529,  ...,  0.9096, -1.2734,  1.0094]],

        [[-0.9897,  0.0436, -1.5413,  ...,  1.3749, -0.7258, -0.0335],
         [ 0.6870, -0.2684, -1.5700,  ...,  1

In [14]:
print(torch.allclose(y_gt, y_m3, atol=1e-6)), print(torch.allclose(y_gt, y_m3, atol=1e-7))

True
False


(None, None)

In [15]:
y_gt[:2, :2, :2], y_m3[:2, :2, :2]

(tensor([[[-1.5429, -0.6767],
          [-1.0107,  2.2756]],
 
         [[ 0.2562,  0.3100],
          [ 0.1513,  0.0361]]], device='cuda:0'),
 tensor([[[-1.5429, -0.6767],
          [-1.0107,  2.2756]],
 
         [[ 0.2562,  0.3100],
          [ 0.1513,  0.0361]]], device='cuda:0'))

#### Single computation steps

In [16]:
y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype,
            device=device).to(memory_format=torch.contiguous_format)

In [17]:
batch_size, seq_len, embedding_dim

(12, 512, 512)

In [18]:
k.shape, v.shape, y.shape, time_decay.shape, time_first.shape

(torch.Size([12, 512, 512]),
 torch.Size([12, 512, 512]),
 torch.Size([12, 512, 512]),
 torch.Size([512]),
 torch.Size([512]))

In [19]:
k_ = rearrange(k, 'b s e -> s b e')
v_ = rearrange(v, 'b s e -> s b e')
y_ = rearrange(y, 'b s e -> s b e')
tf = repeat(time_first, 'e -> b e', b=batch_size)
td = repeat(time_decay, 'e -> b e', b=batch_size)

In [20]:
k_.shape, v_.shape, y.shape, td.shape, tf.shape

(torch.Size([512, 12, 512]),
 torch.Size([512, 12, 512]),
 torch.Size([12, 512, 512]),
 torch.Size([12, 512]),
 torch.Size([12, 512]))

In [21]:
MIN_VAL = -1e-38
aa = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
bb = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
pp = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)

In [22]:
t = 0

In [23]:
ww = tf + k_[t]


In [24]:
p = torch.max(pp, ww)

In [25]:
e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p)

In [26]:
# y[t] = (e1 * aa + e2 * v_[t]) / (e1 * bb + e2)

In [27]:
ww = td + pp

In [28]:
ww

tensor([[ 1.3694,  1.8378, -1.3101,  ..., -0.2241, -1.0518,  1.3185],
        [ 1.3694,  1.8378, -1.3101,  ..., -0.2241, -1.0518,  1.3185],
        [ 1.3694,  1.8378, -1.3101,  ..., -0.2241, -1.0518,  1.3185],
        ...,
        [ 1.3694,  1.8378, -1.3101,  ..., -0.2241, -1.0518,  1.3185],
        [ 1.3694,  1.8378, -1.3101,  ..., -0.2241, -1.0518,  1.3185],
        [ 1.3694,  1.8378, -1.3101,  ..., -0.2241, -1.0518,  1.3185]],
       device='cuda:0')

In [29]:
p = torch.max(ww, k_[t])

In [30]:
e1 = torch.exp(ww - p)
e2 = torch.exp(k_[t] - p)

In [31]:
aa = e1 * aa + e2 * v_[t]
bb = e1 * bb + e2

In [32]:
pp = p

### torch version v4 (own impl, simplify)

In [50]:
def cuda_forward_mock4(batch_size, seq_len, embedding_dim, time_decay,
                       time_first, k, v, y):
    y = torch.zeros(batch_size,
                    seq_len,
                    embedding_dim,
                    dtype=dtype,
                    device=device).to(memory_format=torch.contiguous_format)
    # MIN_VAL = -1e38
    MIN_VAL = 0.0
    # reshape inputs
    k_ = rearrange(k, 'b s e -> s b e')
    v_ = rearrange(v, 'b s e -> s b e')
    y_ = rearrange(y, 'b s e -> s b e')
    tf = repeat(time_first, 'e -> b e', b=batch_size)
    td = repeat(time_decay, 'e -> b e', b=batch_size)
    # running sums
    aa = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
    bb = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
    pp = torch.full((batch_size, embedding_dim),
                    MIN_VAL,
                    dtype=dtype,
                    device=device)
    # debug metrics
    max_pp = torch.full((batch_size, embedding_dim),
                        MIN_VAL,
                        dtype=dtype,
                        device=device)
    min_pp = torch.full((batch_size, embedding_dim),
                        -MIN_VAL,
                        dtype=dtype,
                        device=device)
    max_aa = torch.full((batch_size, embedding_dim),
                        MIN_VAL,
                        dtype=dtype,
                        device=device)
    min_aa = torch.full((batch_size, embedding_dim),
                        -MIN_VAL,
                        dtype=dtype,
                        device=device)
    max_bb = torch.full((batch_size, embedding_dim),
                        MIN_VAL,
                        dtype=dtype,
                        device=device)
    min_bb = torch.full((batch_size, embedding_dim),
                        -MIN_VAL,
                        dtype=dtype,
                        device=device)
    eps = pp
    for t in range(seq_len):
        # #! original version
        # ww = tf + k_[t]
        # p = torch.max(pp, ww)
        # e1 = torch.exp(pp - p)
        # e2 = torch.exp(ww - p)
        # y_[t] = (e1 * aa + e2 * v_[t]) / (e1 * bb + e2)
        # ww = td + pp
        # p = torch.max(ww, k_[t])
        # e1 = torch.exp(ww - p)
        # e2 = torch.exp(k_[t] - p)
        # aa = e1 * aa + e2 * v_[t]
        # bb = e1 * bb + e2
        # pp = p

        # #? debug metrics
        # max_pp = torch.max(max_pp, pp)
        # min_pp = torch.min(min_pp, pp)
        # max_aa = torch.max(max_aa, aa)
        # min_aa = torch.min(min_aa, aa)
        # max_bb = torch.max(max_bb, bb)
        # min_bb = torch.min(min_bb, bb)

        #! version markus (v1) -> yields nan
        # eps_next = torch.max(td, eps)
        # y_[t] = (aa * torch.exp(eps-eps_next) + v_[t] * torch.exp(tf + k_[t])) / (bb * torch.exp(eps) + torch.exp(tf + k_[t]))

        # aa = torch.exp(-eps_next)*(aa * torch.exp(td+eps) + v_[t] * torch.exp(k_[t]))
        # bb = torch.exp(-eps_next)*(bb * torch.exp(td+eps) + torch.exp(k_[t]))
        # eps = eps_next

        #! version markus (v2) -> yields nan
        # y_[t] = (aa * torch.exp(eps) + v_[t] * torch.exp(tf + k_[t])) / (bb * torch.exp(eps) + torch.exp(tf + k_[t]))
        # eps_next = torch.max(td+eps, k_[t])

        # aa = torch.exp(-eps_next)*(aa * torch.exp(td+eps) + v_[t] * torch.exp(k_[t]))
        # bb = torch.exp(-eps_next)*(bb * torch.exp(td+eps) + torch.exp(k_[t]))
        # eps = eps_next
        y_[t] = (aa + v_[t] * torch.exp(tf + k_[t] - eps)) / (
            bb + torch.exp(tf + k_[t] - eps))
        eps_next = torch.max(td + eps, k_[t])

        aa = (aa * torch.exp(td + eps - eps_next) +
              v_[t] * torch.exp(k_[t] - eps_next))
        bb = (bb * torch.exp(td + eps - eps_next) +
              torch.exp(k_[t] - eps_next))
        eps = eps_next

    y = rearrange(y_, 's b e -> b s e')
    return y, {
        'max_pp': max_pp,
        'min_pp': min_pp,
        'max_aa': max_aa,
        'min_aa': min_aa,
        'max_bb': max_bb,
        'min_bb': min_bb
    }


In [51]:
y_m4, ret_dict = cuda_forward_mock4(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v, y_gt)

In [52]:
y_m4

tensor([[[-1.5429, -0.6767,  1.6157,  ...,  1.8309,  0.6268,  1.9657],
         [-1.0107,  2.2756,  1.2978,  ...,  0.5438, -0.5255,  1.6704],
         [-1.3397, -0.0462,  0.6959,  ...,  0.0256, -0.9058,  1.8682],
         ...,
         [-1.4292, -0.2980, -0.5059,  ..., -0.2350,  0.2973,  1.8906],
         [-1.4292, -0.2980, -0.7362,  ..., -0.2096,  0.0230,  1.8906],
         [-1.4292, -0.2980, -0.2837,  ..., -0.2654, -0.3144,  1.8906]],

        [[ 0.2562,  0.3100, -0.2654,  ...,  1.4200, -0.0231,  1.4399],
         [ 0.1513,  0.0361, -0.0853,  ...,  1.2115,  0.2009,  0.5032],
         [-0.0307,  0.0746,  0.2455,  ...,  0.6168,  0.8053,  0.8045],
         ...,
         [ 0.1693,  0.1902,  0.6899,  ...,  1.4375,  0.9854,  1.0094],
         [ 0.1693,  0.1902,  0.1949,  ...,  1.1635, -0.6814,  1.0094],
         [ 0.1693,  0.1902,  0.7529,  ...,  0.9096, -1.2734,  1.0094]],

        [[-0.9897,  0.0436, -1.5413,  ...,  1.3749, -0.7258, -0.0335],
         [ 0.6870, -0.2684, -1.5700,  ...,  1

In [55]:
print(torch.allclose(y_gt, y_m4, atol=1e-6), y_m4.isnan().any(), y_gt.isnan().any())

True tensor(False, device='cuda:0') tensor(False, device='cuda:0')


#### Look at max and min values of running sums
In each step we use elementwise max/min operations. 

In [38]:
ret_dict['max_pp']

tensor([[-1.0000e+38, -1.0000e+38, -1.0000e+38,  ..., -1.0000e+38,
         -1.0000e+38, -1.0000e+38],
        [-1.0000e+38, -1.0000e+38, -1.0000e+38,  ..., -1.0000e+38,
         -1.0000e+38, -1.0000e+38],
        [-1.0000e+38, -1.0000e+38, -1.0000e+38,  ..., -1.0000e+38,
         -1.0000e+38, -1.0000e+38],
        ...,
        [-1.0000e+38, -1.0000e+38, -1.0000e+38,  ..., -1.0000e+38,
         -1.0000e+38, -1.0000e+38],
        [-1.0000e+38, -1.0000e+38, -1.0000e+38,  ..., -1.0000e+38,
         -1.0000e+38, -1.0000e+38],
        [-1.0000e+38, -1.0000e+38, -1.0000e+38,  ..., -1.0000e+38,
         -1.0000e+38, -1.0000e+38]], device='cuda:0')

In [39]:
ret_dict['min_pp']

tensor([[1.0000e+38, 1.0000e+38, 1.0000e+38,  ..., 1.0000e+38, 1.0000e+38,
         1.0000e+38],
        [1.0000e+38, 1.0000e+38, 1.0000e+38,  ..., 1.0000e+38, 1.0000e+38,
         1.0000e+38],
        [1.0000e+38, 1.0000e+38, 1.0000e+38,  ..., 1.0000e+38, 1.0000e+38,
         1.0000e+38],
        ...,
        [1.0000e+38, 1.0000e+38, 1.0000e+38,  ..., 1.0000e+38, 1.0000e+38,
         1.0000e+38],
        [1.0000e+38, 1.0000e+38, 1.0000e+38,  ..., 1.0000e+38, 1.0000e+38,
         1.0000e+38],
        [1.0000e+38, 1.0000e+38, 1.0000e+38,  ..., 1.0000e+38, 1.0000e+38,
         1.0000e+38]], device='cuda:0')

In [40]:
ret_dict['max_aa'], torch.max(ret_dict['max_aa'])

(tensor([[-1.0000e+38, -1.0000e+38, -1.0000e+38,  ..., -1.0000e+38,
          -1.0000e+38, -1.0000e+38],
         [-1.0000e+38, -1.0000e+38, -1.0000e+38,  ..., -1.0000e+38,
          -1.0000e+38, -1.0000e+38],
         [-1.0000e+38, -1.0000e+38, -1.0000e+38,  ..., -1.0000e+38,
          -1.0000e+38, -1.0000e+38],
         ...,
         [-1.0000e+38, -1.0000e+38, -1.0000e+38,  ..., -1.0000e+38,
          -1.0000e+38, -1.0000e+38],
         [-1.0000e+38, -1.0000e+38, -1.0000e+38,  ..., -1.0000e+38,
          -1.0000e+38, -1.0000e+38],
         [-1.0000e+38, -1.0000e+38, -1.0000e+38,  ..., -1.0000e+38,
          -1.0000e+38, -1.0000e+38]], device='cuda:0'),
 tensor(-1.0000e+38, device='cuda:0'))

In [41]:
ret_dict['min_aa'], torch.min(ret_dict['min_aa'])

(tensor([[1.0000e+38, 1.0000e+38, 1.0000e+38,  ..., 1.0000e+38, 1.0000e+38,
          1.0000e+38],
         [1.0000e+38, 1.0000e+38, 1.0000e+38,  ..., 1.0000e+38, 1.0000e+38,
          1.0000e+38],
         [1.0000e+38, 1.0000e+38, 1.0000e+38,  ..., 1.0000e+38, 1.0000e+38,
          1.0000e+38],
         ...,
         [1.0000e+38, 1.0000e+38, 1.0000e+38,  ..., 1.0000e+38, 1.0000e+38,
          1.0000e+38],
         [1.0000e+38, 1.0000e+38, 1.0000e+38,  ..., 1.0000e+38, 1.0000e+38,
          1.0000e+38],
         [1.0000e+38, 1.0000e+38, 1.0000e+38,  ..., 1.0000e+38, 1.0000e+38,
          1.0000e+38]], device='cuda:0'),
 tensor(1.0000e+38, device='cuda:0'))

In [42]:
ret_dict['max_bb'], torch.max(ret_dict['max_bb'])

(tensor([[-1.0000e+38, -1.0000e+38, -1.0000e+38,  ..., -1.0000e+38,
          -1.0000e+38, -1.0000e+38],
         [-1.0000e+38, -1.0000e+38, -1.0000e+38,  ..., -1.0000e+38,
          -1.0000e+38, -1.0000e+38],
         [-1.0000e+38, -1.0000e+38, -1.0000e+38,  ..., -1.0000e+38,
          -1.0000e+38, -1.0000e+38],
         ...,
         [-1.0000e+38, -1.0000e+38, -1.0000e+38,  ..., -1.0000e+38,
          -1.0000e+38, -1.0000e+38],
         [-1.0000e+38, -1.0000e+38, -1.0000e+38,  ..., -1.0000e+38,
          -1.0000e+38, -1.0000e+38],
         [-1.0000e+38, -1.0000e+38, -1.0000e+38,  ..., -1.0000e+38,
          -1.0000e+38, -1.0000e+38]], device='cuda:0'),
 tensor(-1.0000e+38, device='cuda:0'))

In [43]:
ret_dict['min_bb'], torch.min(ret_dict['min_bb'])

(tensor([[1.0000e+38, 1.0000e+38, 1.0000e+38,  ..., 1.0000e+38, 1.0000e+38,
          1.0000e+38],
         [1.0000e+38, 1.0000e+38, 1.0000e+38,  ..., 1.0000e+38, 1.0000e+38,
          1.0000e+38],
         [1.0000e+38, 1.0000e+38, 1.0000e+38,  ..., 1.0000e+38, 1.0000e+38,
          1.0000e+38],
         ...,
         [1.0000e+38, 1.0000e+38, 1.0000e+38,  ..., 1.0000e+38, 1.0000e+38,
          1.0000e+38],
         [1.0000e+38, 1.0000e+38, 1.0000e+38,  ..., 1.0000e+38, 1.0000e+38,
          1.0000e+38],
         [1.0000e+38, 1.0000e+38, 1.0000e+38,  ..., 1.0000e+38, 1.0000e+38,
          1.0000e+38]], device='cuda:0'),
 tensor(1.0000e+38, device='cuda:0'))

### understand wkv try 1


In [44]:
# batch_size = 1
# seq_len = 4
# embedding_dim = 2
# entries = batch_size * seq_len * embedding_dim
# time_decay = 0.1 * (1+torch.arange(embedding_dim, dtype=dtype,
#                           device=device).reshape(embedding_dim).to(
#                               memory_format=torch.contiguous_format))
# time_first = 2* (1+torch.arange(embedding_dim, dtype=dtype,
#                           device=device).reshape(embedding_dim).to(
#                               memory_format=torch.contiguous_format))
# k = torch.arange(entries, dtype=dtype, device=device).reshape(
#     batch_size, seq_len,
#     embedding_dim).to(memory_format=torch.contiguous_format)
# v = 5-torch.arange(entries, dtype=dtype, device=device).reshape(
#     batch_size, seq_len,
#     embedding_dim).to(memory_format=torch.contiguous_format)
# y_gt = torch.empty(batch_size, seq_len, embedding_dim, dtype=dtype,
#                 device=device).to(memory_format=torch.contiguous_format)

In [45]:
time_decay, time_first, k, v

(tensor([ 1.3694e+00,  1.8378e+00, -1.3101e+00,  1.4257e+00,  2.5402e-01,
         -1.5066e+00,  1.2010e+00,  3.1924e-01,  2.4941e-01, -1.3954e+00,
          2.7700e-01, -6.3398e-01, -1.1906e+00, -5.8450e-01, -4.5360e-01,
          4.6626e-01, -2.0136e-01, -6.9353e-01,  1.6446e+00,  9.6092e-01,
         -1.2042e-01, -8.5651e-01, -1.1684e+00, -9.3194e-02,  4.4826e-01,
         -1.2083e+00,  4.6775e-01, -6.7637e-01, -1.2976e+00,  7.2480e-02,
         -8.7845e-01, -1.5823e-01,  1.2346e+00, -5.7721e-01, -7.9356e-01,
         -5.8860e-01, -1.8525e+00,  2.2327e+00, -6.0349e-01, -2.0254e+00,
          4.4815e-01,  7.3291e-01, -9.9053e-01, -1.3270e+00,  4.1145e-01,
          6.7139e-01,  4.6784e-01,  4.0972e-01, -1.0177e-01, -1.5421e+00,
         -2.5567e-01,  7.2136e-01, -1.3837e+00, -6.0545e-01,  1.7030e+00,
          4.0703e-02,  5.2037e-01,  8.7496e-01, -1.1545e+00,  1.0461e+00,
         -2.4928e-01,  1.3595e-01, -1.1185e+00, -2.1287e-01, -1.1626e-01,
         -2.1424e-01,  6.4225e-01,  2.

In [46]:
def cuda_forward_understand(batch_size, seq_len, embedding_dim, time_decay,
                      time_first, k, v, y):
    y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype,
                device=device).to(memory_format=torch.contiguous_format)
    MIN_VAL = -1e38
    # reshape inputs
    k_ = rearrange(k, 'b s e -> s b e')
    v_ = rearrange(v, 'b s e -> s b e')
    y_ = rearrange(y, 'b s e -> s b e')
    tf = repeat(time_first, 'e -> b e', b=batch_size)
    td = repeat(time_decay, 'e -> b e', b=batch_size)
    # running sums
    aa = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
    bb = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
    pp = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
    # debug metrics
    max_pp = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
    min_pp = torch.full((batch_size, embedding_dim), -MIN_VAL, dtype=dtype, device=device)
    max_aa = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
    min_aa = torch.full((batch_size, embedding_dim), -MIN_VAL, dtype=dtype, device=device)
    max_bb = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
    min_bb = torch.full((batch_size, embedding_dim), -MIN_VAL, dtype=dtype, device=device)
    eps = pp
    for t in range(seq_len):
        #! original version
        ww = tf + k_[t]
        p = torch.max(pp, ww)
        e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
        y_[t] = (e1 * aa + e2 * v_[t]) / (e1 * bb + e2)
        ww = td + pp
        p = torch.max(ww, k_[t])
        e1 = torch.exp(ww - p)
        e2 = torch.exp(k_[t] - p)
        aa = e1 * aa + e2 * v_[t]
        bb = e1 * bb + e2
        pp = p
        print(f'{t}: pp {pp}, aa {aa}, bb {bb}, y {y_[t]}')
        #? debug metrics
        max_pp = torch.max(max_pp, pp)
        min_pp = torch.min(min_pp, pp)
        max_aa = torch.max(max_aa, aa)
        min_aa = torch.min(min_aa, aa)
        max_bb = torch.max(max_bb, bb)
        min_bb = torch.min(min_bb, bb)

    y = rearrange(y_, 's b e -> b s e')
    return y, {'max_pp': max_pp, 'min_pp': min_pp, 'max_aa': max_aa, 'min_aa': min_aa, 'max_bb': max_bb, 'min_bb': min_bb}

In [47]:
y, ret_dict = cuda_forward_understand(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v, y_gt)

0: pp tensor([[ 0.7398, -0.2714, -1.2947,  ..., -0.3909, -0.1185,  1.7720],
        [ 1.2133, -0.8909, -0.9273,  ...,  0.0924,  0.0507,  0.1870],
        [-0.9431,  1.6839,  1.5189,  ..., -0.1875, -0.8080, -1.0245],
        ...,
        [-0.7545, -1.2513, -0.1554,  ..., -0.3396, -0.8264,  0.3319],
        [-1.5769, -0.3666,  1.4427,  ...,  0.1723, -0.3834, -1.4233],
        [-0.5988, -0.0907, -0.4160,  ..., -0.1415,  0.4114,  0.3355]],
       device='cuda:0'), aa tensor([[-1.5429, -0.6767,  1.6157,  ...,  1.8309,  0.6268,  1.9657],
        [ 0.2562,  0.3100, -0.2654,  ...,  1.4200, -0.0231,  1.4399],
        [-0.9897,  0.0436, -1.5413,  ...,  1.3749, -0.7258, -0.0335],
        ...,
        [ 0.6461,  0.4635,  1.4452,  ...,  0.0926, -0.8716, -1.4159],
        [ 0.2610, -1.5064, -0.9183,  ..., -0.5128,  1.0463,  0.1178],
        [ 2.7079,  0.5586,  0.9636,  ...,  0.9917, -0.2886,  1.2457]],
       device='cuda:0'), bb tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1.,

In [48]:
y

tensor([[[-1.5429, -0.6767,  1.6157,  ...,  1.8309,  0.6268,  1.9657],
         [-1.0107,  2.2756,  1.2978,  ...,  0.5438, -0.5255,  1.6704],
         [-1.3397, -0.0462,  0.6959,  ...,  0.0256, -0.9058,  1.8682],
         ...,
         [-1.4292, -0.2980, -0.5059,  ..., -0.2350,  0.2973,  1.8906],
         [-1.4292, -0.2980, -0.7362,  ..., -0.2096,  0.0230,  1.8906],
         [-1.4292, -0.2980, -0.2837,  ..., -0.2654, -0.3144,  1.8906]],

        [[ 0.2562,  0.3100, -0.2654,  ...,  1.4200, -0.0231,  1.4399],
         [ 0.1513,  0.0361, -0.0853,  ...,  1.2115,  0.2009,  0.5032],
         [-0.0307,  0.0746,  0.2455,  ...,  0.6168,  0.8053,  0.8045],
         ...,
         [ 0.1693,  0.1902,  0.6899,  ...,  1.4375,  0.9854,  1.0094],
         [ 0.1693,  0.1902,  0.1949,  ...,  1.1635, -0.6814,  1.0094],
         [ 0.1693,  0.1902,  0.7529,  ...,  0.9096, -1.2734,  1.0094]],

        [[-0.9897,  0.0436, -1.5413,  ...,  1.3749, -0.7258, -0.0335],
         [ 0.6870, -0.2684, -1.5700,  ...,  1

In [49]:
print(torch.allclose(y_gt, y_m4, atol=1e-5))

False
