In [1]:
%load_ext autoreload
%autoreload 2
import torch
from wkv_kernel import WKV, WKVConfig
from einops import rearrange, reduce, repeat


  from .autonotebook import tqdm as notebook_tqdm


## Translate the recurrence formula from CUDA to torch

In [2]:
# Setup cuda kernel
wkv_cuda_kernel = WKV()

Using /system/user/beck/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /system/user/beck/.cache/torch_extensions/py310_cu117/wkv/build.ninja...
Building extension module wkv...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.


Loading extension module wkv...


### Inputs

In [3]:
dtype = torch.float32
device = torch.device('cuda')

In [4]:
# original input shapes from training
batch_size = 12
seq_len = 512
embedding_dim = 512
time_decay = torch.randn(
    (embedding_dim, ), dtype=dtype,
    device=device).to(memory_format=torch.contiguous_format)
time_first = torch.randn(
    (embedding_dim, ), dtype=dtype,
    device=device).to(memory_format=torch.contiguous_format)
k = torch.randn((batch_size, seq_len, embedding_dim),
                dtype=dtype,
                device=device).to(memory_format=torch.contiguous_format)
v = torch.randn((batch_size, seq_len, embedding_dim),
                device=device,
                dtype=dtype).to(memory_format=torch.contiguous_format)
y_gt = torch.empty(batch_size, seq_len, embedding_dim, dtype=dtype,
                device=device).to(memory_format=torch.contiguous_format)

In [5]:
# mock up for experiment
# batch_size = 2
# seq_len = 4
# embedding_dim = 2
# entries = batch_size * seq_len * embedding_dim
# time_decay = torch.arange(embedding_dim, dtype=dtype,
#                           device=device).reshape(embedding_dim).to(
#                               memory_format=torch.contiguous_format)
# time_first = torch.arange(embedding_dim, dtype=dtype,
#                           device=device).reshape(embedding_dim).to(
#                               memory_format=torch.contiguous_format)
# k = torch.arange(entries, dtype=dtype, device=device).reshape(
#     batch_size, seq_len,
#     embedding_dim).to(memory_format=torch.contiguous_format)
# v = torch.arange(entries, dtype=dtype, device=device).reshape(
#     batch_size, seq_len,
#     embedding_dim).to(memory_format=torch.contiguous_format)
# y_gt = torch.empty(batch_size, seq_len, embedding_dim, dtype=dtype,
#                 device=device).to(memory_format=torch.contiguous_format)


### CUDA version

In [6]:
wkv_cuda_kernel.wkv_cuda.forward(batch_size, seq_len, embedding_dim, time_decay,
                             time_first, k, v, y_gt)

In [7]:
y_gt

tensor([[[ 0.3675,  0.3255,  1.3535,  ..., -1.0765,  0.3313, -1.2623],
         [ 0.3631,  0.2273,  0.2391,  ..., -0.2573,  0.2292, -1.1041],
         [ 0.4852,  0.1432, -0.6903,  ...,  0.0893,  0.2245,  0.4664],
         ...,
         [-0.5824,  0.1415, -1.5711,  ...,  1.4796, -0.1994,  0.0057],
         [-0.0897,  0.0640, -0.1779,  ..., -0.7080,  0.1719,  1.2882],
         [ 0.1788, -0.1008,  1.0149,  ..., -1.0478, -0.6345,  1.1104]],

        [[ 0.4482,  1.8796,  1.0367,  ...,  0.1798,  0.1110, -0.6689],
         [ 0.3965,  1.0164,  0.6041,  ...,  0.4157, -0.1279,  0.2094],
         [-0.0484,  0.4225, -0.0105,  ...,  0.6572, -0.2714, -1.1611],
         ...,
         [-0.4828, -0.5803, -0.0811,  ..., -0.2196,  0.2380, -0.5591],
         [ 0.5138, -0.5213,  0.5444,  ..., -0.0929,  0.5993,  0.4539],
         [-0.4556, -0.4740,  0.4532,  ..., -0.3167, -0.3081,  0.1459]],

        [[-0.3431, -1.2827, -0.1383,  ..., -0.6697,  1.0710,  1.2622],
         [-0.3424, -1.3090, -0.0769,  ..., -0

### Reproduced torch version 1

In [8]:
# ChatGPT output:

# ww = u_timefirst[c] + k[b, i, c]
# p = max(pp, ww)
# e1 = exp(pp - p)
# e2 = exp(ww - p)
# y[b, i, c] = (e1 * aa + e2 * v[b, i, c]) / (e1 * bb + e2)
# ww = w_timedecay[c] + pp
# p = max(ww, k[b, i, c])
# e1 = exp(ww - p)
# e2 = exp(k[b, i, c] - p)
# aa = e1 * aa + e2 * v[b, i, c]
# bb = e1 * bb + e2
# pp = p


In [9]:
def cuda_forward_mock(batch_size, seq_len, embedding_dim, time_decay,
                      time_first, k, v, y):
    y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype,
                device=device).to(memory_format=torch.contiguous_format)
    MIN_VAL = -1e38
    for b in range(batch_size):
        for c in range(embedding_dim):
            pp = MIN_VAL
            aa = 0
            bb = 0
            for i in range(seq_len):
                # ii = i * embedding_dim + c
                kk = k[b, i, c]
                vv = v[b, i, c]
                ww = time_first[c] + kk
                p = torch.tensor(max(pp, ww))
                e1 = torch.exp(pp - p)
                e2 = torch.exp(ww - p)
                new_y = (e1 * aa + e2 * vv) / (e1 * bb + e2)
                y[b, i, c] = new_y
                ww = time_decay[c] + pp
                p = torch.tensor(max(ww, kk))
                e1 = torch.exp(ww - p)
                e2 = torch.exp(kk - p) # uses current key
                aa = e1 * aa + e2 * vv
                bb = e1 * bb + e2
                pp = p

    return y

In [10]:
# y_m = cuda_forward_mock(batch_size, seq_len, embedding_dim, time_decay,
#                         time_first, k, v, y)
# y_m

### torch version v3 (own impl)

In [11]:
def cuda_forward_mock3(batch_size, seq_len, embedding_dim, time_decay,
                      time_first, k, v, y):
    y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype,
                device=device).to(memory_format=torch.contiguous_format)
    MIN_VAL = -1e38
    # reshape inputs
    k_ = rearrange(k, 'b s e -> s b e')
    v_ = rearrange(v, 'b s e -> s b e')
    y_ = rearrange(y, 'b s e -> s b e')
    tf = repeat(time_first, 'e -> b e', b=batch_size)
    td = repeat(time_decay, 'e -> b e', b=batch_size)
    # running sums
    aa = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
    bb = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
    pp = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
    for t in range(seq_len):
        ww = tf + k_[t]
        p = torch.max(pp, ww)
        e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
        y_[t] = (e1 * aa + e2 * v_[t]) / (e1 * bb + e2)
        ww = td + pp
        p = torch.max(ww, k_[t])
        e1 = torch.exp(ww - p)
        e2 = torch.exp(k_[t] - p)
        aa = e1 * aa + e2 * v_[t]
        bb = e1 * bb + e2
        pp = p
    y = rearrange(y_, 's b e -> b s e')
    return y

In [12]:
y_m3 = cuda_forward_mock3(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v, y_gt)

In [13]:
y_m3

tensor([[[ 0.3675,  0.3255,  1.3535,  ..., -1.0765,  0.3313, -1.2623],
         [ 0.3631,  0.2273,  0.2391,  ..., -0.2573,  0.2292, -1.1041],
         [ 0.4852,  0.1432, -0.6903,  ...,  0.0893,  0.2245,  0.4664],
         ...,
         [-0.5824,  0.1415, -1.5711,  ...,  1.4796, -0.1994,  0.0057],
         [-0.0897,  0.0640, -0.1779,  ..., -0.7080,  0.1719,  1.2882],
         [ 0.1788, -0.1008,  1.0149,  ..., -1.0478, -0.6345,  1.1104]],

        [[ 0.4482,  1.8796,  1.0367,  ...,  0.1798,  0.1110, -0.6689],
         [ 0.3965,  1.0164,  0.6041,  ...,  0.4157, -0.1279,  0.2094],
         [-0.0484,  0.4225, -0.0105,  ...,  0.6572, -0.2714, -1.1611],
         ...,
         [-0.4828, -0.5803, -0.0811,  ..., -0.2196,  0.2380, -0.5591],
         [ 0.5138, -0.5213,  0.5444,  ..., -0.0929,  0.5993,  0.4539],
         [-0.4556, -0.4740,  0.4532,  ..., -0.3167, -0.3081,  0.1459]],

        [[-0.3431, -1.2827, -0.1383,  ..., -0.6697,  1.0710,  1.2622],
         [-0.3424, -1.3090, -0.0769,  ..., -0

In [14]:
print(torch.allclose(y_gt, y_m3, atol=1e-6)), print(torch.allclose(y_gt, y_m3, atol=1e-7))

True
False


(None, None)

In [15]:
y_gt[:2, :2, :2], y_m3[:2, :2, :2]

(tensor([[[0.3675, 0.3255],
          [0.3631, 0.2273]],
 
         [[0.4482, 1.8796],
          [0.3965, 1.0164]]], device='cuda:0'),
 tensor([[[0.3675, 0.3255],
          [0.3631, 0.2273]],
 
         [[0.4482, 1.8796],
          [0.3965, 1.0164]]], device='cuda:0'))

#### Single computation steps

In [16]:
y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype,
            device=device).to(memory_format=torch.contiguous_format)

In [17]:
batch_size, seq_len, embedding_dim

(12, 512, 512)

In [18]:
k.shape, v.shape, y.shape, time_decay.shape, time_first.shape

(torch.Size([12, 512, 512]),
 torch.Size([12, 512, 512]),
 torch.Size([12, 512, 512]),
 torch.Size([512]),
 torch.Size([512]))

In [19]:
k_ = rearrange(k, 'b s e -> s b e')
v_ = rearrange(v, 'b s e -> s b e')
y_ = rearrange(y, 'b s e -> s b e')
tf = repeat(time_first, 'e -> b e', b=batch_size)
td = repeat(time_decay, 'e -> b e', b=batch_size)

In [20]:
k_.shape, v_.shape, y.shape, td.shape, tf.shape

(torch.Size([512, 12, 512]),
 torch.Size([512, 12, 512]),
 torch.Size([12, 512, 512]),
 torch.Size([12, 512]),
 torch.Size([12, 512]))

In [21]:
MIN_VAL = -1e-38
aa = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
bb = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
pp = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)

In [22]:
t = 0

In [23]:
ww = tf + k_[t]


In [24]:
p = torch.max(pp, ww)

In [25]:
e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p)

In [26]:
# y[t] = (e1 * aa + e2 * v_[t]) / (e1 * bb + e2)

In [27]:
ww = td + pp

In [28]:
ww

tensor([[-1.6768, -0.0923, -1.0908,  ..., -1.0797, -0.4115, -2.0982],
        [-1.6768, -0.0923, -1.0908,  ..., -1.0797, -0.4115, -2.0982],
        [-1.6768, -0.0923, -1.0908,  ..., -1.0797, -0.4115, -2.0982],
        ...,
        [-1.6768, -0.0923, -1.0908,  ..., -1.0797, -0.4115, -2.0982],
        [-1.6768, -0.0923, -1.0908,  ..., -1.0797, -0.4115, -2.0982],
        [-1.6768, -0.0923, -1.0908,  ..., -1.0797, -0.4115, -2.0982]],
       device='cuda:0')

In [29]:
p = torch.max(ww, k_[t])

In [30]:
e1 = torch.exp(ww - p)
e2 = torch.exp(k_[t] - p)

In [31]:
aa = e1 * aa + e2 * v_[t]
bb = e1 * bb + e2

In [32]:
pp = p

### torch version v4 (own impl, simplify)

In [33]:
def cuda_forward_mock4(batch_size, seq_len, embedding_dim, time_decay,
                      time_first, k, v, y):
    y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype,
                device=device).to(memory_format=torch.contiguous_format)
    MIN_VAL = -1e38
    # reshape inputs
    k_ = rearrange(k, 'b s e -> s b e')
    v_ = rearrange(v, 'b s e -> s b e')
    y_ = rearrange(y, 'b s e -> s b e')
    tf = repeat(time_first, 'e -> b e', b=batch_size)
    td = repeat(time_decay, 'e -> b e', b=batch_size)
    # running sums
    aa = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
    bb = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
    pp = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
    # debug metrics
    max_pp = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
    min_pp = torch.full((batch_size, embedding_dim), -MIN_VAL, dtype=dtype, device=device)
    max_aa = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
    min_aa = torch.full((batch_size, embedding_dim), -MIN_VAL, dtype=dtype, device=device)
    max_bb = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
    min_bb = torch.full((batch_size, embedding_dim), -MIN_VAL, dtype=dtype, device=device)
    eps = pp
    for t in range(seq_len):
        # #! original version
        # ww = tf + k_[t]
        # p = torch.max(pp, ww)
        # e1 = torch.exp(pp - p)
        # e2 = torch.exp(ww - p)
        # y_[t] = (e1 * aa + e2 * v_[t]) / (e1 * bb + e2)
        # ww = td + pp
        # p = torch.max(ww, k_[t])
        # e1 = torch.exp(ww - p)
        # e2 = torch.exp(k_[t] - p)
        # aa = e1 * aa + e2 * v_[t]
        # bb = e1 * bb + e2
        # pp = p

        # #? debug metrics
        # max_pp = torch.max(max_pp, pp)
        # min_pp = torch.min(min_pp, pp)
        # max_aa = torch.max(max_aa, aa)
        # min_aa = torch.min(min_aa, aa)
        # max_bb = torch.max(max_bb, bb)
        # min_bb = torch.min(min_bb, bb)
        
        #! version markus (v1) -> yields nan
        # eps_next = torch.max(td, eps)
        # y_[t] = (aa * torch.exp(eps-eps_next) + v_[t] * torch.exp(tf + k_[t])) / (bb * torch.exp(eps) + torch.exp(tf + k_[t]))

        # aa = torch.exp(-eps_next)*(aa * torch.exp(td+eps) + v_[t] * torch.exp(k_[t]))
        # bb = torch.exp(-eps_next)*(bb * torch.exp(td+eps) + torch.exp(k_[t]))
        # eps = eps_next
        
        #! version markus (v2) -> yields nan
        y_[t] = (aa * torch.exp(eps) + v_[t] * torch.exp(tf + k_[t])) / (bb * torch.exp(eps) + torch.exp(tf + k_[t]))
        eps_next = torch.max(td+eps, k_[t])

        aa = torch.exp(-eps_next)*(aa * torch.exp(td+eps) + v_[t] * torch.exp(k_[t]))
        bb = torch.exp(-eps_next)*(bb * torch.exp(td+eps) + torch.exp(k_[t]))
        eps = eps_next

        #! original version (without p)
        # ww = tf + k_[t]
        # p = torch.max(pp, ww)
        # e1 = torch.exp(pp)
        # e2 = torch.exp(ww)
        # y_[t] = (e1 * aa + e2 * v_[t]) / (e1 * bb + e2)
        # ww = td + pp
        # p = torch.max(ww, k_[t])
        # e1 = torch.exp(ww)
        # e2 = torch.exp(k_[t])
        # aa = e1 * aa + e2 * v_[t]
        # bb = e1 * bb + e2
        # pp = p

    y = rearrange(y_, 's b e -> b s e')
    return y, {'max_pp': max_pp, 'min_pp': min_pp, 'max_aa': max_aa, 'min_aa': min_aa, 'max_bb': max_bb, 'min_bb': min_bb}

In [34]:
y_m4, ret_dict = cuda_forward_mock4(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v, y_gt)

In [35]:
y_m4

tensor([[[ 0.3675,  0.3255,  1.3535,  ..., -1.0765,  0.3313, -1.2623],
         [ 0.3631,  0.2273,  0.2391,  ..., -0.2573,  0.2292, -1.1041],
         [ 0.4852,  0.1432, -0.6903,  ...,  0.0893,  0.2245,  0.4664],
         ...,
         [-0.5824,  0.1415, -1.5711,  ...,  1.4796, -0.1994,  0.0057],
         [-0.0897,  0.0640, -0.1779,  ..., -0.7080,  0.1719,  1.2882],
         [ 0.1788, -0.1008,  1.0149,  ..., -1.0478, -0.6345,  1.1104]],

        [[ 0.4482,  1.8796,  1.0367,  ...,  0.1798,  0.1110, -0.6689],
         [ 0.3965,  1.0164,  0.6041,  ...,  0.4157, -0.1279,  0.2094],
         [-0.0484,  0.4225, -0.0105,  ...,  0.6572, -0.2714, -1.1611],
         ...,
         [-0.4828, -0.5803, -0.0811,  ..., -0.2196,  0.2380, -0.5591],
         [ 0.5138, -0.5213,  0.5444,  ..., -0.0929,  0.5993,  0.4539],
         [-0.4556, -0.4740,  0.4532,  ..., -0.3167, -0.3081,  0.1459]],

        [[-0.3431, -1.2827, -0.1383,  ..., -0.6697,  1.0710,  1.2622],
         [-0.3424, -1.3090, -0.0769,  ..., -0

In [44]:
print(torch.allclose(y_gt, y_m4, atol=1e-6), y_m4.isnan().any(), y_gt.isnan().any())

False tensor(True, device='cuda:0') tensor(False, device='cuda:0')


In [None]:
#! markus version
# B = 1
# T = 512
# C = 1

# _w = torch.randn((C))
# _u = torch.randn((C))
# _k = torch.randn((B, T, C))
# _v = torch.randn((B, T, C))

# min_value = torch.tensor(-1e38).float()

# w = -torch.exp(_w.float())
# u = _u.float()
# k = _k.float()
# v = _v.float()
# y = torch.empty(B, T, C)


# for c in range(C):
#     eps = min_value
#     aa = 0
#     bb = 0
#     for t in range(T):
#         kk = k[0,t,c]
#         vv = v[0,t,c]

#         y[0,t,c] = (torch.exp(eps) * aa + torch.exp(u[c] + kk) * vv) / (torch.exp(eps) * bb + torch.exp(u[c] + kk))

#         eps_next = torch.max(w[c] + eps, kk)

#         aa = torch.exp(-eps_next) * (torch.exp(w[c] + eps) * aa + torch.exp(kk) * vv)
#         bb = torch.exp(-eps_next) * (torch.exp(w[c] + eps) * bb + torch.exp(kk))
#         eps = eps_next

# display(any(y.view(-1).isnan().tolist()))

#### Look at max and min values of running sums
In each step we use elementwise max/min operations. 

In [37]:
ret_dict['max_pp']

tensor([[962.1312, 420.1057, 299.8172,  ...,   3.4283, 567.4110,   3.1173],
        [963.3484, 420.1945, 299.9291,  ...,   2.6842, 567.6777,   3.2462],
        [963.5245, 421.1901, 299.7128,  ...,   3.0347, 567.0105,   3.5679],
        ...,
        [961.6948, 419.0993, 299.4191,  ...,   2.8203, 566.3354,   3.6480],
        [963.6456, 420.1425, 299.1628,  ...,   3.1260, 565.8119,   2.1868],
        [964.7355, 420.5218, 300.3070,  ...,   2.9727, 566.9648,   2.5999]],
       device='cuda:0')

In [38]:
ret_dict['min_pp']

tensor([[-1.3030, -1.8639,  0.5374,  ..., -2.2211,  0.4580, -0.7061],
        [ 0.1641, -0.2136, -2.0325,  ..., -2.5894,  0.7247, -0.9738],
        [ 0.3401,  0.7821,  0.8226,  ..., -2.2132,  0.0575, -0.4266],
        ...,
        [-1.6317, -1.3087, -0.3186,  ..., -2.2139, -0.6175, -0.6660],
        [-0.6776, -0.2841,  0.2726,  ..., -2.2243, -1.1411, -0.3897],
        [ 1.5512,  0.1138,  1.4169,  ..., -3.0105,  0.0118, -0.6453]],
       device='cuda:0')

In [39]:
ret_dict['max_aa'], torch.max(ret_dict['max_aa'])

(tensor([[ 0.3175,  0.7887, -0.0513,  ...,  3.2844, -2.4227,  4.3941],
         [-1.1903, -0.6233,  1.1964,  ...,  3.4967, -0.0087,  4.6466],
         [ 0.7404,  0.5722,  2.1244,  ...,  3.3701,  0.3867,  3.8458],
         ...,
         [ 0.3901,  0.3522, -0.3350,  ...,  3.5138,  1.4386,  3.9051],
         [-0.1157,  1.1631,  2.0237,  ...,  2.9620,  0.5157,  4.9913],
         [-0.0926,  2.0300,  1.2441,  ...,  3.3179, -0.5801,  2.9114]],
        device='cuda:0'),
 tensor(13.5066, device='cuda:0'))

In [40]:
ret_dict['min_aa'], torch.min(ret_dict['min_aa'])

(tensor([[-0.1641,  0.1188, -0.8389,  ..., -4.0240, -2.4821, -3.7610],
         [-1.2051, -1.0781,  0.7027,  ..., -2.9647, -0.2947, -4.1991],
         [ 0.6287,  0.3607,  1.0143,  ..., -2.7821, -0.2213, -4.3683],
         ...,
         [-0.0684, -0.6464, -0.8603,  ..., -4.2799,  1.0761, -3.8674],
         [-1.6427, -0.0430,  1.7798,  ..., -3.2646, -0.2280, -3.8984],
         [-0.1059,  1.6882,  0.3863,  ..., -3.2564, -0.7386, -5.3694]],
        device='cuda:0'),
 tensor(-7.8673, device='cuda:0'))

In [41]:
ret_dict['max_bb'], torch.max(ret_dict['max_bb'])

(tensor([[1.8313, 1.9966, 2.3150,  ..., 2.8648, 1.1176, 8.9634],
         [1.0387, 1.8036, 1.6067,  ..., 2.7565, 1.3202, 9.2702],
         [1.2204, 1.3959, 2.5086,  ..., 2.6329, 1.4677, 8.5003],
         ...,
         [2.0314, 2.1810, 1.7666,  ..., 2.9097, 1.6132, 8.1550],
         [1.3392, 2.9483, 2.1127,  ..., 2.5649, 2.2604, 8.0671],
         [1.0262, 2.2889, 1.9205,  ..., 2.6718, 1.2837, 9.3124]],
        device='cuda:0'),
 tensor(59.7912, device='cuda:0'))

In [42]:
ret_dict['min_bb'], torch.min(ret_dict['min_bb'])

(tensor([[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0'),
 tensor(1., device='cuda:0'))

### understand wkv try 1


In [81]:
# batch_size = 1
# seq_len = 4
# embedding_dim = 2
# entries = batch_size * seq_len * embedding_dim
# time_decay = 0.1 * (1+torch.arange(embedding_dim, dtype=dtype,
#                           device=device).reshape(embedding_dim).to(
#                               memory_format=torch.contiguous_format))
# time_first = 2* (1+torch.arange(embedding_dim, dtype=dtype,
#                           device=device).reshape(embedding_dim).to(
#                               memory_format=torch.contiguous_format))
# k = torch.arange(entries, dtype=dtype, device=device).reshape(
#     batch_size, seq_len,
#     embedding_dim).to(memory_format=torch.contiguous_format)
# v = 5-torch.arange(entries, dtype=dtype, device=device).reshape(
#     batch_size, seq_len,
#     embedding_dim).to(memory_format=torch.contiguous_format)
# y_gt = torch.empty(batch_size, seq_len, embedding_dim, dtype=dtype,
#                 device=device).to(memory_format=torch.contiguous_format)

In [82]:
time_decay, time_first, k, v

(tensor([0.1000, 0.2000], device='cuda:0'),
 tensor([2., 4.], device='cuda:0'),
 tensor([[[0., 1.],
          [2., 3.],
          [4., 5.],
          [6., 7.]]], device='cuda:0'),
 tensor([[[ 5.,  4.],
          [ 3.,  2.],
          [ 1.,  0.],
          [-1., -2.]]], device='cuda:0'))

In [83]:
def cuda_forward_understand(batch_size, seq_len, embedding_dim, time_decay,
                      time_first, k, v, y):
    y = torch.zeros(batch_size, seq_len, embedding_dim, dtype=dtype,
                device=device).to(memory_format=torch.contiguous_format)
    MIN_VAL = -1e38
    # reshape inputs
    k_ = rearrange(k, 'b s e -> s b e')
    v_ = rearrange(v, 'b s e -> s b e')
    y_ = rearrange(y, 'b s e -> s b e')
    tf = repeat(time_first, 'e -> b e', b=batch_size)
    td = repeat(time_decay, 'e -> b e', b=batch_size)
    # running sums
    aa = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
    bb = torch.zeros(batch_size, embedding_dim, dtype=dtype, device=device)
    pp = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
    # debug metrics
    max_pp = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
    min_pp = torch.full((batch_size, embedding_dim), -MIN_VAL, dtype=dtype, device=device)
    max_aa = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
    min_aa = torch.full((batch_size, embedding_dim), -MIN_VAL, dtype=dtype, device=device)
    max_bb = torch.full((batch_size, embedding_dim), MIN_VAL, dtype=dtype, device=device)
    min_bb = torch.full((batch_size, embedding_dim), -MIN_VAL, dtype=dtype, device=device)
    eps = pp
    for t in range(seq_len):
        #! original version
        ww = tf + k_[t]
        p = torch.max(pp, ww)
        e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
        y_[t] = (e1 * aa + e2 * v_[t]) / (e1 * bb + e2)
        ww = td + pp
        p = torch.max(ww, k_[t])
        e1 = torch.exp(ww - p)
        e2 = torch.exp(k_[t] - p)
        aa = e1 * aa + e2 * v_[t]
        bb = e1 * bb + e2
        pp = p
        print(f'{t}: pp {pp}, aa {aa}, bb {bb}, y {y_[t]}')
        #? debug metrics
        max_pp = torch.max(max_pp, pp)
        min_pp = torch.min(min_pp, pp)
        max_aa = torch.max(max_aa, aa)
        min_aa = torch.min(min_aa, aa)
        max_bb = torch.max(max_bb, bb)
        min_bb = torch.min(min_bb, bb)

    y = rearrange(y_, 's b e -> b s e')
    return y, {'max_pp': max_pp, 'min_pp': min_pp, 'max_aa': max_aa, 'min_aa': min_aa, 'max_bb': max_bb, 'min_bb': min_bb}

In [84]:
y, ret_dict = cuda_forward_understand(batch_size, seq_len, embedding_dim, time_decay, time_first, k, v, y_gt)

0: pp tensor([[0., 1.]], device='cuda:0'), aa tensor([[5., 4.]], device='cuda:0'), bb tensor([[1., 1.]], device='cuda:0'), y tensor([[5., 4.]], device='cuda:0')
1: pp tensor([[2., 3.]], device='cuda:0'), aa tensor([[3.7478, 2.6612]], device='cuda:0'), bb tensor([[1.1496, 1.1653]], device='cuda:0'), y tensor([[3.0360, 2.0049]], device='cuda:0')
2: pp tensor([[4., 5.]], device='cuda:0'), aa tensor([[1.5606, 0.4399]], device='cuda:0'), bb tensor([[1.1719, 1.1926]], device='cuda:0'), y tensor([[1.0466, 0.0066]], device='cuda:0')
3: pp tensor([[6., 7.]], device='cuda:0'), aa tensor([[-0.7666, -1.9273]], device='cuda:0'), bb tensor([[1.1753, 1.1971]], device='cuda:0'), y tensor([[-0.9510, -1.9930]], device='cuda:0')


In [85]:
y

tensor([[[ 5.0000,  4.0000],
         [ 3.0360,  2.0049],
         [ 1.0466,  0.0066],
         [-0.9510, -1.9930]]], device='cuda:0')

In [None]:
print(torch.allclose(y_gt, y_m4, atol=1e-5))