# RWKV 6.X to Triton Port

This notebook is designeed to help while porting RWKV6+ to triton. It first validates the algorithm in torch before porting to triton. Testbed for experiments

In [5]:
import torch
import torch as th

### Helpers

In [6]:
def gen_inputs(B, H, L, K, V): 
    th.manual_seed(17)
    device = "cpu"
    rt = th.randn(B, H, L, K, device=device, requires_grad=True)
    kt = th.randn(B, H, L, K, device=device, requires_grad=True)
    vt = th.randn(B, H, L, V, device=device, requires_grad=True)
    wt = th.randn(B, H, L, K, V, device=device, requires_grad=True)
    ut = th.randn(H, K, device=device, requires_grad=True)
    return rt, kt, vt, wt, ut

## Diagnosis of 18/04/2024
* [ ] [Songlin's code](https://github.com/sustcsonglin/flash-linear-attention/commit/fee90b2e72366a46c60e3ef16431133aa5aced8d) is wrong for the backward pass of `W`. Although it is hidden for normally-distributed values and high number of `V`. 
* [ ] TODO: extend U and W to support matrices. Experiments demonstrate it works wonders. Especially later in training.
    * [This paper: The illusion of State in State Space Models](https://arxiv.org/pdf/2404.08819.pdf) demonstrates that a `matrix_cumprod` would give superior modelling ability. The matrix has to be non-diagonal as otherwise it would be a `cum_matmul` which would still be in `TC0` and not `NC1`

#### rwkv_inner (from gptcore)

In [7]:
import torch as th
import torch.nn.functional as F

def rwkv_inner(r,k,v,w,u,kv_state,chunk_len:int=24,precision_dtype:th.dtype=th.float32):
    """
    expects
    r : (B,H,L,K)
    k : (B,H,L,K)
    v : (B,H,L,V)
    w : (B,H,L,K) or (1,H,L,K)
    u : (1,H,1,K)
    kv_state : (B,H,K,V)
    """
    B,H,L,K = k.size()
    V = v.size(-1)
    T = chunk_len

    if L == 1:
        kv = k.mT @ v
        out = r @ (kv_state + u.mT * kv)
        kv_state = w.mT * kv_state + kv
        return out, kv_state
    else:
        # FIXME - support fast path for non-exact multiples
        # ensure it's an exact multiple
        if L % T != 0:
            T = 1

        N = L // T

        # this has to be done to avoid numerical instability (inf/NaN) when w is used as a divisor up to chunk_length//2 places away (so precision_min_val^(T//2) has to be in fp range)
        # NOTE - this does not account for the impact of the size of R, K so we currently use the chunk_len=32 numbers for chunk_len=24
        assert(precision_dtype == th.float32 or precision_dtype == th.float64)
        if precision_dtype == th.float32:
            precision_min_val = 0.005 # good for fp32 (1.175e-38 ^ (1/16.0) < 0.00426)
        else: #elif precision_dtype == torch.float64:
            precision_min_val = 1e-10 # good for fp64 (1.7e-308 ^ (1/16.0) < 5.8e-20)
        w = w.clamp(precision_min_val)

        # calculate cumulative decay in log space where it won't overflow
        w_log = w.float().log() # (1,H,L,K) or (B,H,L,K)

        # chunked view of w_log
        wc_log = w_log.view(w.size(0),H,N,T,K)
        wc_log_cum = wc_log.cumsum(dim=-2)

        # chunked view of shifted_w_log
        shifted_wc_log_cum = F.pad(wc_log_cum, (0, 0, 1, -1))


        # NOTE - we have to apply the decay weight from TWO ahead.. ONE ahead gets no decay (log==0)
        # pre-applied weights
        # left side is prior chunk (w_inter), right side is current chunk (w_intra)
        # without u...
        # w0   w1   w2   w3   | w4   w5   w6   w7          
        # w1:4 w2:4 w3:4 w4:4 | w4:5 w4:6 w4:7 w4:8
        # with u...
        # w0   w1   w2   w3   | w4   w5   w6   w7          
        # w1:4 w2:4 w3:4 w4:4 | w4:4 w4:5 w4:6 w4:7

        # ws decays the entire current state (representing t-1) to the prior block (t-2)
        ws = wc_log.sum(dim=-2, keepdim=True) # 1HN1K or BHN1K
        # w_inter is the decay to the end of the current block, since it will be applied at the next iteration when current (t) becomes prior (t-1)
        # this formula because e.g. w1:4 = w0:4 - w0:1
        w_inter = ws - wc_log_cum # 1HNTK or BHNTK (w^(T-1) ... w^0)
        # w_intra is the decay from the beginning of the current block (t), since it will be applied to current queries (t) against prior state (representing keys+values up to but not including block t)
        # this formula because e.g. w1:3 = w0:3 - w0
        w_intra = wc_log_cum - wc_log # 1HNTK or BHNTK (w^0 ... w^(T-2))

        ws = list(ws.mT.exp().to(r.dtype).unbind(dim=-3)) # N x 1HK1 or BHK1 !!NOTE THE .mT HERE!!
        w_inter = w_inter.exp().to(r.dtype) # 1HNTK or BHNTK
        w_intra = w_intra.exp().to(r.dtype) # 1HNTK or BHNTK

        # chunked view of r, k, v
        r = r.view(B,H,N,T,K) 
        k = k.view(B,H,N,T,K) 
        v = v.view(B,H,N,T,V)
        u = u.unsqueeze(2).to(r.dtype) # (1,H,1,1,K)

        # parallel calculation of all intra-chunk attention contributions
        wc_log_offset = shifted_wc_log_cum[...,T//2:T//2+1,:] # B,H,N,1,K
        r_decay = (shifted_wc_log_cum - wc_log_offset).to(precision_dtype).exp() # B,H,N,T,K
        k_inv_decay = (wc_log_offset - wc_log_cum).to(precision_dtype).exp() # B,H,N,T,K
        a = ((r*r_decay) @ (k*k_inv_decay).mT).to(r.dtype).tril(-1) # B,H,N,T,T
        # add u term to attention (NOTE - the tril(-1) above zeroed the diagonal)
        a = a + th.einsum('bhntk,bhntk->bhnt', r, u * k).diag_embed()
        out = a @ v # BHNTV
        # alternate way of adding in u
        # out = out + torch.einsum('bhntk,bhntk,bhntv->bhntv', r, u * k, v) 

        # parallel precalculation of chunked (k*wk).mT@v for use in recurrent state calc below
        wkv = (k * w_inter).mT @ v # BHNKV
        wkv = list(wkv.unbind(dim=-3)) # N x BHKV

        # recurrent calculation of all states
        states = []
        for i in range(N):
            states.append(kv_state)
            kv_state = kv_state * ws[i] + wkv[i] # BHKV
            # equivalent non-precalced version
            #wkv = (k[...,i,:,:] * wk[...,i,:,:]).mT @ v[...,i,:,:]
            #kv_state = kv_state * ws[i] + wkv
        states = th.stack(states, dim=2) # BHNKV       

        # parallel application of all r to states
        out = out + (r * w_intra) @ states # BHNTV
        out = out.view(B,H,L,V)
        return out, kv_state

#### Base Funcs

In [8]:
# -*- coding: utf-8 -*-

import torch

def naive_recurrent_rwkv6_original(q, k, v, w, u, initial_state=None, output_final_state=False):
    orig_dtype = q.dtype
    q, k, v, w, u = map(lambda x: x.float(), (q, k, v, w, u))
    batch_size, n_heads, seq_len, d_head_k = q.shape
    _, _, _, d_head_v = v.shape
    h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
    o = torch.zeros_like(v)

    if initial_state is not None:
        h += initial_state

    for i in range(seq_len):
        q_i = q[:, :, i, :]
        k_i = k[:, :, i]
        v_i = v[:, :, i, :]
        w_i = w[:, :, i].exp()
        kv_i = k_i[..., None] * v_i[..., None, :]
        o_i = (h + u[None, ..., None] * kv_i) * q_i[..., None]
        o[:, :, i] = o_i.sum(-2)
        h = h * w_i[..., None] + kv_i
    return o.to(orig_dtype)


def naive_recurrent_rwkv6_bwd_original(q, k, v, w, u, o, do, initial_state=None, output_final_state=False):
    q, k, v, w, u, o, do = map(lambda x: x.float(), (q, k, v, w, u, o, do))
    batch_size, n_heads, seq_len, d_head_k = q.shape
    _, _, _, d_head_v = v.shape
    h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
    dq = torch.zeros_like(q)
    dq_aux = torch.zeros_like(q)

    if initial_state is not None:
        h += initial_state

    for i in range(seq_len):
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        w_i = w[:, :, i].exp()
        kv_i = k_i[..., None] * v_i[..., None, :]
        h_i = (h + u[None, ..., None] * kv_i)
        dq_i = (do[:, :, i, None, :] * h_i).sum(-1)
        dq_aux_i = (do[:, :, i, None, :] * h).sum(-1)
        dq[:, :, i] = dq_i
        dq_aux[:, :, i] = dq_aux_i
        h = h * w_i[..., None] + kv_i

    du = u.new_zeros(batch_size, n_heads, d_head_k)
    dh = torch.zeros_like(h)
    dk = torch.zeros_like(k)
    dk_aux = torch.zeros_like(k)
    dv = torch.zeros_like(v)

    for i in range(seq_len-1, -1, -1):
        d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None]
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        du_i = (d_kv_i * k_i[..., None] * v_i[..., None, :]).sum(-1)
        du += du_i
        dk_i = (dh * v_i[..., None, :]).sum(-1)
        dk_aux[:, :, i] = dk_i
        dk_i += (d_kv_i * u[None, ..., None] * v_i[..., None, :]).sum(-1)
        dv_i = (d_kv_i * u[None, ..., None] * k_i[..., None]).sum(-2)
        dv_i += (dh * k_i[..., None]).sum(-2)

        dk[:, :, i] = dk_i
        dv[:, :, i] = dv_i
        dh = dh * w[:, :, i, :, None].exp() + d_kv_i

    # dw = q * dq_aux - k * dk_aux
    dw = torch.zeros_like(w)
    for i in range(seq_len-2, -1, -1):
        dw[:, :, i] = dw[:, :, i+1] + dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i]

    du = du.sum(0)
    return dq, dk, dv, dw, du

#### Demonstrate Songlin's Code is Wrong: 
Songlin's code is wrong! Hidden by a high `|V|`

* `V = 1` and errors era ~1e-2
* `V = 64` and errors era ~1e-9

`Autograd(recurrent_naive_fw)` and `rwkv_inner` match (0 error) but both disagree with `recurrent_naive_bw`

In [9]:
# implementation is wrong. hidden by high V (ex. 64)
B, H, L, K, V = 1, 1, 16, 1, 1

#### AUTOGRAD FORWARD PASS
rt, kt, vt, wt, ut = gen_inputs(B, H, L, K, V)
w_ = -th.exp(wt)
o = naive_recurrent_rwkv6_original(rt, kt, vt, w_[..., 0], ut[..., 0])
grads = []
o.register_hook(lambda d:grads.append(d))
o.mean().backward()
do = grads[0]

#### Autograd rwkv_inner
rt3, kt3, vt3, wt3, ut3 = gen_inputs(B, H, L, K, V)
w3_ = -th.exp(wt3)
o3, state3 = rwkv_inner(rt3, kt3, vt3, w3_[..., 0], ut3.view(1, H, 1, K), th.zeros(B, H, K, V))
o3.mean().backward()

#### MANUAL BACKWARD PASS
rt2, kt2, vt2, wt2, ut2 = gen_inputs(B, H, L, K, V)
w2_ = -th.exp(wt2)
dq, dk, dv, dw, du = naive_recurrent_rwkv6_bwd_original(rt2, kt2, vt2, w2_[..., 0], ut2[..., 0], o, do)
dw = dw[..., None]
(wt.grad - dw).detach().flatten(), (w_ - w2_).detach().flatten(), (wt3.grad - dw).detach().flatten(), (w3_ - w2_).detach().flatten()

#### DOESN'T MATCH!

(tensor([-5.4060e-09, -3.1498e-05, -1.1221e-04, -8.7801e-05,  1.5598e-03,
         -9.9161e-03, -1.2344e-03, -8.6985e-03, -1.6997e-03, -8.5992e-04,
          1.4172e-02, -1.1259e-05, -1.2936e-01, -4.7139e-02, -3.9083e-02,
         -0.0000e+00]),
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 tensor([-5.4060e-09, -9.9698e-06, -3.1177e-05, -2.3574e-05,  3.9917e-04,
         -4.3252e-03, -5.3626e-04, -3.3183e-03, -1.0520e-03, -2.2075e-04,
          1.1902e-02, -9.0152e-07, -9.2706e-02, -1.5038e-02, -2.4199e-02,
         -0.0000e+00]),
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))

In [10]:
# implementation is wrong. hidden by high V (ex. 64)
B, H, L, K, V = 1, 1, 16, 1, 64

#### AUTOGRAD FORWARD PASS
rt, kt, vt, wt, ut = gen_inputs(B, H, L, K, V)
w_ = -th.exp(wt)
o = naive_recurrent_rwkv6_original(rt, kt, vt, w_[..., 0], ut[..., 0])
grads = []
o.register_hook(lambda d:grads.append(d))
o.mean().backward()
do = grads[0]

#### Autograd rwkv_inner
rt3, kt3, vt3, wt3, ut3 = gen_inputs(B, H, L, K, V)
w3_ = -th.exp(wt3)
o3, state3 = rwkv_inner(rt3, kt3, vt3, w3_[..., 0], ut3.view(1, H, 1, K), th.zeros(B, H, K, V))
o3.mean().backward()

#### MANUAL BACKWARD PASS
rt2, kt2, vt2, wt2, ut2 = gen_inputs(B, H, L, K, V)
w2_ = -th.exp(wt2)
dq, dk, dv, dw, du = naive_recurrent_rwkv6_bwd_original(rt2, kt2, vt2, w2_[..., 0], ut2[..., 0], o, do)
dw = dw[..., None]
(wt.grad - dw).detach().flatten(), (w_ - w2_).detach().flatten(), (wt3.grad - dw).detach().flatten(), (w3_ - w2_).detach().flatten()

#### DOESN'T MATCH!

(tensor([-1.4552e-09, -1.4552e-09, -1.4552e-09,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00]),
 tensor([0., 0., 0.,  ..., 0., 0., 0.]),
 tensor([-1.4552e-09, -1.4552e-09, -1.4552e-09,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00]),
 tensor([0., 0., 0.,  ..., 0., 0., 0.]))

In [9]:
wt.grad.shape, dw.shape

(torch.Size([1, 1, 16, 1, 1]), torch.Size([1, 1, 16, 1, 1]))

#### Manual: Fixing backward implementation¶ - DISCARDED

In [18]:
L = 4
q_ic, k_jc, v_jd, w_tc_prelog, _ = gen_inputs(1, 1, L, 1, 1)
w_tc = -th.exp(w_tc_prelog)

q_ic, k_jc, v_jd, w_tc = map(lambda x: x[0, 0].detach.requires_grad(True), (q_ic, k_jc, v_jd, w_tc))

wcum_jc_buffer = []
for j in range(1, L+1): 
    wcum_jc_buffer.append( w_tc[:j].cumsum(dim=0).exp() )
wcum_jc = th.stack(wcum_jc_buffer)
        
inner_jcd = th.einsum('jc,jc,jd->jcd', wcum_jc, k_jc, v_jd)
wkv_icd = th.cumsum(inner_jcd, dim=0)
oid = th.einsum('ic,icd->id', q_ic, wkv_icd)

#### Fixing backward implementation - TODO

In [43]:
def naive_recurrent_rwkv6_bwd_hypnofix(q, k, v, w, u, o, do, initial_state=None, output_final_state=False):
    q, k, v, w, u, o, do = map(lambda x: x.float(), (q, k, v, w, u, o, do))
    batch_size, n_heads, seq_len, d_head_k = q.shape
    _, _, _, d_head_v = v.shape
    h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
    dq = torch.zeros_like(q)
    dq_aux = torch.zeros_like(q)

    if initial_state is not None:
        h += initial_state

    for i in range(seq_len):
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        w_i = w[:, :, i].exp()
        do_i = do[:, :, i]
        kv_i = k_i[..., None] * v_i[..., None, :]
        dq[:, :, i] = ((h * + kv_i * u[None, ..., None]) * do_i[None, :]).sum(dim=-1)
        dq_aux[:, :, i] = (h * do_i[None, :]).sum(-1)
        h = h * w_i[..., None] + kv_i
        
    du = u.new_zeros(batch_size, n_heads, d_head_k)
    dh = torch.zeros_like(h)
    dk = torch.zeros_like(k)
    dk_aux = torch.zeros_like(k)
    dv = torch.zeros_like(v)

    for i in range(seq_len-1, -1, -1):
        d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None]
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        du_i = (d_kv_i * k_i[..., None] * v_i[..., None, :]).sum(-1)
        du += du_i
        dk_aux[:, :, i] = (dh * v_i[..., None, :]).sum(-1)
        
        d_kv_hiu = (dh + d_kv_i * u[None, ..., None])
        dk[:, :, i] = (d_kv_hiu * v_i[..., None, :]).sum(-1)
        dv[:, :, i] = (d_kv_hiu * k_i[..., :, None]).sum(-2)

        # dk_aux[:, :, i] = (dh * v_i[..., None, :] * k_i[..., None]).sum(-1)
        dh = dh * w[:, :, i, :, None].exp() + d_kv_i

    # dw = q * dq_aux - k * dk_aux
    # dw = torch.zeros_like(w)
    #for i in range(seq_len-2, -1, -1):
    #    dw[:, :, i] = dw[:, :, i+1] + dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i]
    dw = (dq_aux * q)[:, :, 1:, ...] - (dk_aux * k)[:, :, 0:-1]
    dw = torch.nn.functional.pad(dw, (0, 0, 0, 1, 0, 0, 0, 0), value=0)
    dw = dw.flip([-2,]).cumsum(dim=-2).flip([-2,])
    
    du = du.sum(0)
    return dq, dk, dv, dw, du

In [45]:
# implementation is wrong. hidden by high V (ex. 64)
B, H, L, K, V = 1, 1, 16, 1, 2

#### AUTOGRAD FORWARD PASS
rt, kt, vt, wt, ut = gen_inputs(B, H, L, K, V)
w_ = -th.exp(wt)
o = naive_recurrent_rwkv6_original(rt, kt, vt, w_[..., 0], ut[..., 0])
grads = []
o.register_hook(lambda d:grads.append(d))
o.mean().backward()
do = grads[0]

#### Autograd rwkv_inner
rt3, kt3, vt3, wt3, ut3 = gen_inputs(B, H, L, K, V)
w3_ = -th.exp(wt3)
o3, state3 = rwkv_inner(rt3, kt3, vt3, w3_[..., 0], ut3.view(1, H, 1, K), th.zeros(B, H, K, V))
o3.mean().backward()

#### MANUAL BACKWARD PASS
rt2, kt2, vt2, wt2, ut2 = gen_inputs(B, H, L, K, V)
w2_ = -th.exp(wt2)
dq, dk, dv, dw, du = naive_recurrent_rwkv6_bwd_hypnofix(rt2, kt2, vt2, w2_[..., 0], ut2[..., 0], o, do)
dw = dw[..., None]
(wt.grad - dw).detach().flatten(), (w_ - w2_).detach().flatten(), (wt3.grad - dw).detach().flatten(), (w3_ - w2_).detach().flatten()

#### DOESN'T MATCH!

(tensor([-7.5242e-09, -7.5242e-09,  4.1360e-06,  2.2635e-06, -1.5978e-04,
         -4.1195e-05, -1.8432e-04, -1.2256e-04, -1.7147e-03, -4.5438e-04,
          5.7485e-03,  2.3668e-03,  3.1560e-04,  2.0305e-04, -8.2964e-05,
         -1.6727e-05, -1.6277e-04, -6.2998e-05, -2.8594e-04, -1.8627e-04,
          1.9118e-02,  1.0362e-02,  2.5464e-02,  1.5361e-02, -3.9735e-03,
         -1.9676e-03, -1.7708e-03, -5.4920e-04,  4.4832e-07,  3.1956e-08,
         -0.0000e+00, -0.0000e+00]),
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0.]),
 tensor([-7.5242e-09, -7.5242e-09,  2.2635e-06,  2.2635e-06, -4.1195e-05,
         -4.1195e-05, -1.2256e-04, -1.2256e-04, -4.5438e-04, -4.5438e-04,
          2.3668e-03,  2.3668e-03,  2.0305e-04,  2.0305e-04, -1.6727e-05,
         -1.6727e-05, -6.2998e-05, -6.2998e-05, -1.8627e-04, -1.8627e-04,
          1.0362e-02,  1.0362e-02,  1.5361e-02,  1.5361e-02, -1.9676e-03,
 

#### My funcs

In [46]:
def naive_recurrent_rwkv6(q, k, v, w, u, initial_state=None, output_final_state=False):
    orig_dtype = q.dtype
    q, k, v, w, u = map(lambda x: x.float(), (q, k, v, w, u))
    batch_size, n_heads, seq_len, d_head_k = q.shape
    _, _, _, d_head_v = v.shape
    h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
    o = torch.zeros_like(v)

    if initial_state is not None:
        h += initial_state

    for i in range(seq_len):
        q_i = q[:, :, i, :]
        k_i = k[:, :, i]
        v_i = v[:, :, i, :]
        w_i = w[:, :, i].exp()
        kv_i = k_i[..., None] * v_i[..., None, :]
        # print(f"h.shape: {h.shape}, u.shape: {u[None, ..., None].shape} and kv.shape: {kv_i.shape}")
        h_i = (h + u[None, ..., None] * kv_i)
        o_i = th.einsum('bhc,bhcd->bhd', q_i, h_i)
        o[:, :, i] = o_i.sum(-2)
        h = h * w_i + kv_i
    return o.to(orig_dtype)

In [11]:
def naive_recurrent_rwkv6_bwd(q, k, v, w, u, o, do, initial_state=None, output_final_state=False):
    q, k, v, w, u, o, do = map(lambda x: x.float(), (q, k, v, w, u, o, do))
    batch_size, n_heads, seq_len, d_head_k = q.shape
    _, _, _, d_head_v = v.shape
    h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
    dq = torch.zeros_like(q)
    dq_aux = torch.zeros_like(q)

    if initial_state is not None:
        h += initial_state

    for i in range(seq_len):
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        w_i = w[:, :, i].exp()
        kv_i = k_i[..., None] * v_i[..., None, :]
        h_i = (h + u[None, ..., None] * kv_i)
        dq_i = (do[:, :, i, None, :] * h_i).sum(-1)
        dq_aux_i = (do[:, :, i, None, :] * h).sum(-1)
        dq[:, :, i] = dq_i
        dq_aux[:, :, i] = dq_aux_i
        h = h * w_i + kv_i

    du = torch.zeros(batch_size, n_heads, d_head_k)
    dh = torch.zeros_like(h)
    dk = torch.zeros_like(k)
    dk_aux = torch.zeros_like(k)
    dv = torch.zeros_like(v)

    for i in range(seq_len-1, -1, -1):
        d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None]
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        du_i = (d_kv_i * k_i[..., None] * v_i[..., None, :]).sum(-1)
        du += du_i
        dk_i = (dh * v_i[..., None, :]).sum(-1)
        dk_aux[:, :, i] = dk_i
        dk_i += (d_kv_i * u[None, ..., None] * v_i[..., None, :]).sum(-1)
        dv_i = (d_kv_i * u[None, ..., None] * k_i[..., None]).sum(-2)
        dv_i += (dh * k_i[..., None]).sum(-2)

        dk[:, :, i] = dk_i
        dv[:, :, i] = dv_i
        dh = dh * w[:, :, i, :, :].exp() + d_kv_i

    # dw = q * dq_aux - k * dk_aux
    dw = torch.zeros_like(w)
    for i in range(seq_len-2, -1, -1):
        dw[:, :, i] = dw[:, :, i+1] + (dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i])[..., None]

    du = du.sum(0)
    return dq, dk, dv, dw, du

#### My tests

In [None]:
B, H, L, K, V = 1, 1, 8, 1, 1

#### AUTOGRAD FORWARD PASS
rt, kt, vt, wt, ut = gen_inputs(B, H, L, K, V)
w_ = -th.exp(wt)
o = naive_recurrent_rwkv6(rt, kt, vt, w_, ut)
# get grad of nonleaf variable
grads = []
o.register_hook(lambda d:grads.append(d))
o.mean().backward()

#### MANUAL BACKWARD PASS
do = grads[0]
rt2, kt2, vt2, wt2, ut2 = gen_inputs()
w_ = -th.exp(wt2)
dq, dk, dv, dw, du = naive_recurrent_rwkv6_bwd(rt2, kt2, vt2, w_, ut2, o, do)

In [None]:
wt.grad - dw

In [None]:
(rt.grad - dq).flatten(), (kt.grad - dk).flatten(), (vt.grad - dv).flatten(), (ut.grad - du).flatten()

## AOB