# RWKV 6.X to Triton Port

This notebook is designeed to help while porting RWKV6+ to triton. It first validates the algorithm in torch before porting to triton. Testbed for experiments

In [828]:
import torch
import torch as th

### Helpers

In [829]:
def gen_inputs(B, H, L, K, V): 
    th.manual_seed(17)
    device = "cpu"
    rt = th.randn(B, H, L, K, device=device, requires_grad=True)
    kt = th.randn(B, H, L, K, device=device, requires_grad=True)
    vt = th.randn(B, H, L, V, device=device, requires_grad=True)
    wt = -2 + 1e-1 * th.randn(B, H, L, K, V, device=device)
    wt = -th.exp(wt)
    wt.requires_grad = True
    ut = th.randn(H, K, device=device, requires_grad=True)
    return rt, kt, vt, wt, ut

## Diagnosis of 18/04/2024
* [ ] [Songlin's code](https://github.com/sustcsonglin/flash-linear-attention/commit/fee90b2e72366a46c60e3ef16431133aa5aced8d) is wrong for the backward pass of `W`. Although it is hidden for normally-distributed values and high number of `V`. 
* [ ] TODO: extend U and W to support matrices. Experiments demonstrate it works wonders. Especially later in training.
    * [This paper: The illusion of State in State Space Models](https://arxiv.org/pdf/2404.08819.pdf) demonstrates that a `matrix_cumprod` would give superior modelling ability. The matrix has to be non-diagonal as otherwise it would be a `cum_matmul` which would still be in `TC0` and not `NC1`

#### rwkv_inner (from gptcore)

In [130]:
import torch as th
import torch.nn.functional as F

def rwkv_inner(r,k,v,w,u,kv_state,chunk_len:int=24,precision_dtype:th.dtype=th.float32):
    """
    expects
    r : (B,H,L,K)
    k : (B,H,L,K)
    v : (B,H,L,V)
    w : (B,H,L,K) or (1,H,L,K)
    u : (1,H,1,K)
    kv_state : (B,H,K,V)
    """
    B,H,L,K = k.size()
    V = v.size(-1)
    T = chunk_len

    if L == 1:
        kv = k.mT @ v
        out = r @ (kv_state + u.mT * kv)
        kv_state = w.mT * kv_state + kv
        return out, kv_state
    else:
        # FIXME - support fast path for non-exact multiples
        # ensure it's an exact multiple
        if L % T != 0:
            T = 1

        N = L // T

        # this has to be done to avoid numerical instability (inf/NaN) when w is used as a divisor up to chunk_length//2 places away (so precision_min_val^(T//2) has to be in fp range)
        # NOTE - this does not account for the impact of the size of R, K so we currently use the chunk_len=32 numbers for chunk_len=24
        assert(precision_dtype == th.float32 or precision_dtype == th.float64)
        if precision_dtype == th.float32:
            precision_min_val = 0.005 # good for fp32 (1.175e-38 ^ (1/16.0) < 0.00426)
        else: #elif precision_dtype == torch.float64:
            precision_min_val = 1e-10 # good for fp64 (1.7e-308 ^ (1/16.0) < 5.8e-20)
        w = w.clamp(precision_min_val)

        # calculate cumulative decay in log space where it won't overflow
        w_log = w.float().log() # (1,H,L,K) or (B,H,L,K)

        # chunked view of w_log
        wc_log = w_log.view(w.size(0),H,N,T,K)
        wc_log_cum = wc_log.cumsum(dim=-2)

        # chunked view of shifted_w_log
        shifted_wc_log_cum = F.pad(wc_log_cum, (0, 0, 1, -1))


        # NOTE - we have to apply the decay weight from TWO ahead.. ONE ahead gets no decay (log==0)
        # pre-applied weights
        # left side is prior chunk (w_inter), right side is current chunk (w_intra)
        # without u...
        # w0   w1   w2   w3   | w4   w5   w6   w7          
        # w1:4 w2:4 w3:4 w4:4 | w4:5 w4:6 w4:7 w4:8
        # with u...
        # w0   w1   w2   w3   | w4   w5   w6   w7          
        # w1:4 w2:4 w3:4 w4:4 | w4:4 w4:5 w4:6 w4:7

        # ws decays the entire current state (representing t-1) to the prior block (t-2)
        ws = wc_log.sum(dim=-2, keepdim=True) # 1HN1K or BHN1K
        # w_inter is the decay to the end of the current block, since it will be applied at the next iteration when current (t) becomes prior (t-1)
        # this formula because e.g. w1:4 = w0:4 - w0:1
        w_inter = ws - wc_log_cum # 1HNTK or BHNTK (w^(T-1) ... w^0)
        # w_intra is the decay from the beginning of the current block (t), since it will be applied to current queries (t) against prior state (representing keys+values up to but not including block t)
        # this formula because e.g. w1:3 = w0:3 - w0
        w_intra = wc_log_cum - wc_log # 1HNTK or BHNTK (w^0 ... w^(T-2))

        ws = list(ws.mT.exp().to(r.dtype).unbind(dim=-3)) # N x 1HK1 or BHK1 !!NOTE THE .mT HERE!!
        w_inter = w_inter.exp().to(r.dtype) # 1HNTK or BHNTK
        w_intra = w_intra.exp().to(r.dtype) # 1HNTK or BHNTK

        # chunked view of r, k, v
        r = r.view(B,H,N,T,K) 
        k = k.view(B,H,N,T,K) 
        v = v.view(B,H,N,T,V)
        u = u.unsqueeze(2).to(r.dtype) # (1,H,1,1,K)

        # parallel calculation of all intra-chunk attention contributions
        wc_log_offset = shifted_wc_log_cum[...,T//2:T//2+1,:] # B,H,N,1,K
        r_decay = (shifted_wc_log_cum - wc_log_offset).to(precision_dtype).exp() # B,H,N,T,K
        k_inv_decay = (wc_log_offset - wc_log_cum).to(precision_dtype).exp() # B,H,N,T,K
        a = ((r*r_decay) @ (k*k_inv_decay).mT).to(r.dtype).tril(-1) # B,H,N,T,T
        # add u term to attention (NOTE - the tril(-1) above zeroed the diagonal)
        a = a + th.einsum('bhntk,bhntk->bhnt', r, u * k).diag_embed()
        out = a @ v # BHNTV
        # alternate way of adding in u
        # out = out + torch.einsum('bhntk,bhntk,bhntv->bhntv', r, u * k, v) 

        # parallel precalculation of chunked (k*wk).mT@v for use in recurrent state calc below
        wkv = (k * w_inter).mT @ v # BHNKV
        wkv = list(wkv.unbind(dim=-3)) # N x BHKV

        # recurrent calculation of all states
        states = []
        for i in range(N):
            states.append(kv_state)
            kv_state = kv_state * ws[i] + wkv[i] # BHKV
            # equivalent non-precalced version
            #wkv = (k[...,i,:,:] * wk[...,i,:,:]).mT @ v[...,i,:,:]
            #kv_state = kv_state * ws[i] + wkv
        states = th.stack(states, dim=2) # BHNKV       

        # parallel application of all r to states
        out = out + (r * w_intra) @ states # BHNTV
        out = out.view(B,H,L,V)
        return out, kv_state

#### Base Funcs

In [830]:
# -*- coding: utf-8 -*-

import torch

def naive_recurrent_rwkv6_original(q, k, v, w, u, initial_state=None, output_final_state=False):
    orig_dtype = q.dtype
    q, k, v, w, u = map(lambda x: x.float(), (q, k, v, w, u))
    batch_size, n_heads, seq_len, d_head_k = q.shape
    _, _, _, d_head_v = v.shape
    h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
    o = torch.zeros_like(v)

    if initial_state is not None:
        h += initial_state

    for i in range(seq_len):
        q_i = q[:, :, i, :]
        k_i = k[:, :, i]
        v_i = v[:, :, i, :]
        w_i = w[:, :, i].exp()
        kv_i = k_i[..., None] * v_i[..., None, :]
        o_i = (h + u[None, ..., None] * kv_i) * q_i[..., None]
        o[:, :, i] = o_i.sum(-2)
        h = h * w_i[..., None] + kv_i
    return o.to(orig_dtype)


def naive_recurrent_rwkv6_bwd_original(q, k, v, w, u, o, do, initial_state=None, output_final_state=False):
    q, k, v, w, u, o, do = map(lambda x: x.float(), (q, k, v, w, u, o, do))
    batch_size, n_heads, seq_len, d_head_k = q.shape
    _, _, _, d_head_v = v.shape
    h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
    dq = torch.zeros_like(q)
    dq_aux = torch.zeros_like(q)

    if initial_state is not None:
        h += initial_state

    for i in range(seq_len):
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        w_i = w[:, :, i].exp()
        kv_i = k_i[..., None] * v_i[..., None, :]
        h_i = (h + u[None, ..., None] * kv_i)
        dq_i = (do[:, :, i, None, :] * h_i).sum(-1)
        dq_aux_i = (do[:, :, i, None, :] * h).sum(-1)
        dq[:, :, i] = dq_i
        dq_aux[:, :, i] = dq_aux_i
        h = h * w_i[..., None] + kv_i

    du = u.new_zeros(batch_size, n_heads, d_head_k)
    dh = torch.zeros_like(h)
    dk = torch.zeros_like(k)
    dk_aux = torch.zeros_like(k)
    dv = torch.zeros_like(v)

    for i in range(seq_len-1, -1, -1):
        d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None]
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        du_i = (d_kv_i * k_i[..., None] * v_i[..., None, :]).sum(-1)
        du += du_i
        dk_i = (dh * v_i[..., None, :]).sum(-1)
        dk_aux[:, :, i] = dk_i
        dk_i += (d_kv_i * u[None, ..., None] * v_i[..., None, :]).sum(-1)
        dv_i = (d_kv_i * u[None, ..., None] * k_i[..., None]).sum(-2)
        dv_i += (dh * k_i[..., None]).sum(-2)

        dk[:, :, i] = dk_i
        dv[:, :, i] = dv_i
        dh = dh * w[:, :, i, :, None].exp() + d_kv_i

    # dw = q * dq_aux - k * dk_aux
    dw = torch.zeros_like(w)
    for i in range(seq_len-2, 0, -1):
        dw[:, :, i] = dw[:, :, i+1] + dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i]

    du = du.sum(0)
    return dq, dk, dv, dw, du

#### Demonstrate Songlin's Code is Wrong: 
Songlin's code is wrong! Hidden by a high `|V|`

* `V = 1` and errors era ~1e-2
* `V = 64` and errors era ~1e-9

`Autograd(recurrent_naive_fw)` and `rwkv_inner` match (0 error) but both disagree with `recurrent_naive_bw`

In [833]:
# implementation is wrong. hidden by high V (ex. 64)
B, H, L, K, V = 1, 1, 16, 1, 1

#### AUTOGRAD FORWARD PASS
rt, kt, vt, wt, ut = gen_inputs(B, H, L, K, V)
o = naive_recurrent_rwkv6_original(rt, kt, vt, wt[..., 0], ut[..., 0])
grads = []
o.register_hook(lambda d:grads.append(d))
o.mean().backward()
do = grads[0]

#### Autograd rwkv_inner
rt3, kt3, vt3, wt3, ut3 = gen_inputs(B, H, L, K, V)
o3, state3 = rwkv_inner(rt3, kt3, vt3, wt3[..., 0], ut3.view(1, H, 1, K), th.zeros(B, H, K, V))
o3.mean().backward()

#### MANUAL BACKWARD PASS
rt2, kt2, vt2, wt2, ut2 = gen_inputs(B, H, L, K, V)
dq, dk, dv, dw, du = naive_recurrent_rwkv6_bwd_original(rt2, kt2, vt2, wt2[..., 0], ut2[..., 0], o, do)
dw = dw[..., None]

# (wt.grad - dw).flatten(), (wt.grad - wt3.grad).flatten()
(dw - wt3.grad).flatten() # (rt.grad - dq).flatten() 

#### DOESN'T MATCH!

tensor([ 0.0000e+00,  1.0814e-04, -9.6143e-04,  4.2107e-03, -2.9855e-02,
        -5.6489e-02, -5.2342e-02,  6.2052e-02,  5.8009e-02,  5.6197e-02,
        -3.9192e-02,  2.2172e-01,  4.5362e-01,  2.9285e-01,  1.9845e-01,
         0.0000e+00], grad_fn=<ViewBackward0>)

In [141]:
# implementation is wrong. hidden by high V (ex. 64)
B, H, L, K, V = 1, 1, 32, 2, 2

#### AUTOGRAD FORWARD PASS
rt, kt, vt, wt, ut = gen_inputs(B, H, L, K, V)
w_ = -th.exp(wt)
o = naive_recurrent_rwkv6_original(rt, kt, vt, w_[..., 0], ut[..., 0])
grads = []
o.register_hook(lambda d:grads.append(d))
o.mean().backward()
do = grads[0]

#### Autograd rwkv_inner
rt3, kt3, vt3, wt3, ut3 = gen_inputs(B, H, L, K, V)
w3_ = -th.exp(wt3)
o3, state3 = rwkv_inner(rt3, kt3, vt3, w3_[..., 0], ut3.view(1, H, 1, K), th.zeros(B, H, K, V))
o3.mean().backward()

#### MANUAL BACKWARD PASS
rt2, kt2, vt2, wt2, ut2 = gen_inputs(B, H, L, K, V)
w2_ = -th.exp(wt2)
dq, dk, dv, dw, du = naive_recurrent_rwkv6_bwd_original(rt2, kt2, vt2, w2_[..., 0], ut2[..., 0], o, do)
dw = dw[..., None]
(wt.grad - dw).detach().flatten(), (w_ - w2_).detach().flatten(), (wt3.grad - dw).detach().flatten(), (w3_ - w2_).detach().flatten()

#### DOESN'T MATCH!

(tensor([ 2.9104e-09,  2.9104e-09, -6.7521e-09, -6.7521e-09, -4.7104e-07,
         -1.5335e-07, -2.3590e-04, -8.0526e-05,  5.6382e-04,  2.2494e-04,
          6.9353e-04,  2.4993e-04, -1.8924e-03, -6.7232e-04,  5.9796e-04,
          2.2419e-04, -2.4683e-03, -7.2213e-04,  1.1753e-02,  3.9835e-03,
         -8.7115e-04, -2.8660e-04,  3.4635e-03,  1.1052e-03,  6.2721e-04,
          2.2001e-04,  2.4188e-03,  8.6338e-04, -4.9119e-04, -1.4935e-04,
          7.1844e-04,  2.2646e-04, -1.1038e-02, -3.6537e-03,  1.0063e-03,
          3.1321e-04, -1.1371e-03, -3.4617e-04, -4.1771e-04, -1.4600e-04,
         -1.7502e-03, -5.3780e-04, -2.2888e-04, -7.8782e-05, -8.0950e-05,
         -2.5472e-05, -1.9536e-03, -6.3890e-04,  7.7323e-03,  2.8001e-03,
         -4.6986e-02, -1.6366e-02, -4.9570e-04, -1.6514e-04,  3.0733e-03,
          9.5976e-04, -7.9752e-04, -2.5922e-04, -3.2619e-04, -1.0276e-04,
          1.8081e-03,  5.8611e-04,  3.0957e-04,  8.5090e-05,  1.8494e-03,
          6.8724e-04,  1.7429e-03,  5.

In [134]:
wt.grad.shape, dw.shape

(torch.Size([1, 1, 16, 1, 64]), torch.Size([1, 1, 16, 1, 1]))

#### Manual: Fixing backward implementation¶ - DISCARDED

In [18]:
L = 4
q_ic, k_jc, v_jd, w_tc_prelog, _ = gen_inputs(1, 1, L, 1, 1)
w_tc = -th.exp(w_tc_prelog)

q_ic, k_jc, v_jd, w_tc = map(lambda x: x[0, 0].detach.requires_grad(True), (q_ic, k_jc, v_jd, w_tc))

wcum_jc_buffer = []
for j in range(1, L+1): 
    wcum_jc_buffer.append( w_tc[:j].cumsum(dim=0).exp() )
wcum_jc = th.stack(wcum_jc_buffer)
        
inner_jcd = th.einsum('jc,jc,jd->jcd', wcum_jc, k_jc, v_jd)
wkv_icd = th.cumsum(inner_jcd, dim=0)
oid = th.einsum('ic,icd->id', q_ic, wkv_icd)

#### Fixing backward implementation - algorithmically

In [142]:
def naive_recurrent_rwkv6_bwd_hypnofix_correct(q, k, v, w, u, o, do, initial_state=None, output_final_state=False):
    q, k, v, w, u, o, do = map(lambda x: x.float(), (q, k, v, w, u, o, do))
    batch_size, n_heads, seq_len, d_head_k = q.shape
    _, _, _, d_head_v = v.shape
    h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
    dq = torch.zeros_like(q)
    dq_aux = torch.zeros_like(q)

    if initial_state is not None:
        h += initial_state

    for i in range(seq_len):
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        w_i = w[:, :, i].exp()
        do_i = do[:, :, i]
        kv_i = k_i[..., None] * v_i[..., None, :]
        dq[:, :, i] = ((h * + kv_i * u[None, ..., None]) * do_i[None, :]).sum(dim=-1)
        dq_aux[:, :, i] = (h * do_i[None, :]).sum(-1)
        h = h * w_i[..., None] + kv_i
        
    du = u.new_zeros(batch_size, n_heads, d_head_k)
    dh = torch.zeros_like(h)
    dk = torch.zeros_like(k)
    dw = torch.zeros_like(w)
    dv = torch.zeros_like(v)

    for i in range(seq_len-1, -1, -1):
        d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None]
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        kvi = k_i[..., None] * v_i[..., None, :]
        du_i = (d_kv_i * kvi).sum(-1)
        du += du_i
        
        d_kv_hiu = (dh + d_kv_i * u[None, ..., None])
        dk[:, :, i] = (d_kv_hiu * v_i[..., None, :]).sum(-1)
        dv[:, :, i] = (d_kv_hiu * k_i[..., :, None]).sum(-2)
        
        # backward on pure W
        # back_w += w[:, :, i]
        # dw[:, :, i] = dw[:, :, i+1] + (q[:, :, i] * back_w * kvi).sum(-1)

        # since U, W vector, collapse si into C
        if seq_len - i - 1: 
            wcum = th.zeros_like(w[:,:, i+1:])
            for t in range(seq_len - i - 1): 
                wcum[:, :, t] = w[:, :, i:-1].sum(dim=2)
            # c,c,d -> cd -> c
            si = (do[:, :, [i], None, :] * q[:, :, i+1:, ..., None] * k[:, :, i+1:, ..., None] * v[:, :, i+1:, ..., None, :]).sum(-1)
            dw[:, :, i] = (wcum.exp() * si).sum(2)
            
        # dk_aux[:, :, i] = (dh * v_i[..., None, :] * k_i[..., None]).sum(-1)
        dh = dh * w[:, :, i, :, None].exp() + d_kv_i
    
    du = du.sum(0)
    return dq, dk, dv, dw, du

#### Fixing backward implementation

In [143]:
def naive_recurrent_rwkv6_bwd_hypnofix(q, k, v, w, u, o, do, initial_state=None, output_final_state=False):
    q, k, v, w, u, o, do = map(lambda x: x.float(), (q, k, v, w, u, o, do))
    batch_size, n_heads, seq_len, d_head_k = q.shape
    _, _, _, d_head_v = v.shape
    h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
    dq = torch.zeros_like(q)
    dq_aux = torch.zeros_like(q)

    if initial_state is not None:
        h += initial_state

    for i in range(seq_len):
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        w_i = w[:, :, i].exp()
        do_i = do[:, :, i]
        kv_i = k_i[..., None] * v_i[..., None, :]
        dq[:, :, i] = ((h * + kv_i * u[None, ..., None]) * do_i[None, :]).sum(dim=-1)
        dq_aux[:, :, i] = (h * do_i[None, :]).sum(-1)
        h = h * w_i[..., None] + kv_i
        
    du = u.new_zeros(batch_size, n_heads, d_head_k)
    dh = torch.zeros_like(h)
    dk = torch.zeros_like(k)
    dw = torch.zeros_like(w)
    dv = torch.zeros_like(v)

    dwh = u.new_ones(batch_size, n_heads, d_head_k)
    
    for i in range(seq_len-1, -1, -1):
        d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None]
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        kvi = k_i[..., None] * v_i[..., None, :]
        du_i = (d_kv_i * kvi).sum(-1)
        du += du_i
        
        d_kv_hiu = (dh + d_kv_i * u[None, ..., None])
        dk[:, :, i] = (d_kv_hiu * v_i[..., None, :]).sum(-1)
        dv[:, :, i] = (d_kv_hiu * k_i[..., :, None]).sum(-2)
        
        # backward on pure W
        # back_w += w[:, :, i]
        # dw[:, :, i] = dw[:, :, i+1] + (q[:, :, i] * back_w * kvi).sum(-1)
        
        if i < seq_len-1: 
            wwh *= wiexp
            # dwh = w[:, :, i:-1].sum(dim=2).exp()
            dw[:, :, i] = dwh * si
        
        # si = (q[:, :, i, ..., None] * kvi).sum(-1)  
        si = (do[:, :, i, None, :] * q[:, :, i, ..., None] * kvi).sum(-1)     
        
        # dk_aux[:, :, i] = (dh * v_i[..., None, :] * k_i[..., None]).sum(-1)
        wiexp = w[:, :, i, :].exp()
        dh = dh * wiexp[..., None] + d_kv_i
    
    du = du.sum(0)
    return dq, dk, dv, dw, du

In [144]:
# implementation is wrong. hidden by high V (ex. 64)
B, H, L, K, V = 1, 1, 16, 2, 2

#### AUTOGRAD FORWARD PASS
rt, kt, vt, wt, ut = gen_inputs(B, H, L, K, V)
w_ = -th.exp(wt)
o = naive_recurrent_rwkv6_original(rt, kt, vt, w_[..., 0], ut[..., 0])
grads = []
o.register_hook(lambda d:grads.append(d))
o.mean().backward()
do = grads[0]

#### Autograd rwkv_inner
rt3, kt3, vt3, wt3, ut3 = gen_inputs(B, H, L, K, V)
w3_ = -th.exp(wt3)
o3, state3 = rwkv_inner(rt3, kt3, vt3, w3_[..., 0], ut3.view(1, H, 1, K), th.zeros(B, H, K, V))
o3.mean().backward()

#### MANUAL BACKWARD PASS
rt2, kt2, vt2, wt2, ut2 = gen_inputs(B, H, L, K, V)
w2_ = -th.exp(wt2)
dq, dk, dv, dw, du = naive_recurrent_rwkv6_bwd_hypnofix_correct(rt2, kt2, vt2, w2_[..., 0], ut2[..., 0], o, do)
dw = dw[..., None]
(wt.grad - dw).detach().flatten(), (w_ - w2_).detach().flatten(), (wt3.grad - dw).detach().flatten(), (w3_ - w2_).detach().flatten()

#### DOESN'T MATCH!

(tensor([ 8.8016e-16,  8.8016e-16, -1.1330e-14, -1.1330e-14,  1.1148e-05,
          9.4622e-15,  2.5693e-04, -1.0575e-13, -7.0574e-05,  5.8352e-14,
          5.6913e-04, -7.8122e-13,  2.0061e-03,  6.5811e-17, -1.4693e-03,
         -5.2914e-12, -1.5169e-03,  1.5346e-12, -9.8088e-04, -4.6428e-11,
         -3.2685e-03, -1.9871e-11, -1.5882e-02, -2.9842e-10,  1.3238e-02,
         -2.0407e-10, -1.4260e-02, -1.0307e-09, -2.0520e-03, -2.7139e-09,
         -2.1849e-03,  9.5263e-10, -1.0940e-02, -1.8337e-08, -7.8895e-03,
          1.1356e-08,  1.7036e-05,  1.0656e-07, -1.3284e-03, -3.9607e-07,
          1.5412e-02,  6.4054e-07,  6.4981e-04, -2.3626e-06, -7.5737e-04,
          1.1148e-06, -1.8389e-03, -1.8949e-05, -3.8423e-03,  1.1038e-05,
          1.0049e-02, -9.4611e-06, -1.6737e-03,  4.2742e-04, -9.3307e-03,
          9.3694e-04,  1.4482e-03,  6.9880e-04,  3.3894e-03,  2.1426e-03,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]),
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

#### My funcs

In [46]:
def naive_recurrent_rwkv6(q, k, v, w, u, initial_state=None, output_final_state=False):
    orig_dtype = q.dtype
    q, k, v, w, u = map(lambda x: x.float(), (q, k, v, w, u))
    batch_size, n_heads, seq_len, d_head_k = q.shape
    _, _, _, d_head_v = v.shape
    h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
    o = torch.zeros_like(v)

    if initial_state is not None:
        h += initial_state

    for i in range(seq_len):
        q_i = q[:, :, i, :]
        k_i = k[:, :, i]
        v_i = v[:, :, i, :]
        w_i = w[:, :, i].exp()
        kv_i = k_i[..., None] * v_i[..., None, :]
        # print(f"h.shape: {h.shape}, u.shape: {u[None, ..., None].shape} and kv.shape: {kv_i.shape}")
        h_i = (h + u[None, ..., None] * kv_i)
        o_i = th.einsum('bhc,bhcd->bhd', q_i, h_i)
        o[:, :, i] = o_i.sum(-2)
        h = h * w_i + kv_i
    return o.to(orig_dtype)

In [11]:
def naive_recurrent_rwkv6_bwd(q, k, v, w, u, o, do, initial_state=None, output_final_state=False):
    q, k, v, w, u, o, do = map(lambda x: x.float(), (q, k, v, w, u, o, do))
    batch_size, n_heads, seq_len, d_head_k = q.shape
    _, _, _, d_head_v = v.shape
    h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
    dq = torch.zeros_like(q)
    dq_aux = torch.zeros_like(q)

    if initial_state is not None:
        h += initial_state

    for i in range(seq_len):
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        w_i = w[:, :, i].exp()
        kv_i = k_i[..., None] * v_i[..., None, :]
        h_i = (h + u[None, ..., None] * kv_i)
        dq_i = (do[:, :, i, None, :] * h_i).sum(-1)
        dq_aux_i = (do[:, :, i, None, :] * h).sum(-1)
        dq[:, :, i] = dq_i
        dq_aux[:, :, i] = dq_aux_i
        h = h * w_i + kv_i

    du = torch.zeros(batch_size, n_heads, d_head_k)
    dh = torch.zeros_like(h)
    dk = torch.zeros_like(k)
    dk_aux = torch.zeros_like(k)
    dv = torch.zeros_like(v)

    for i in range(seq_len-1, -1, -1):
        d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None]
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        du_i = (d_kv_i * k_i[..., None] * v_i[..., None, :]).sum(-1)
        du += du_i
        dk_i = (dh * v_i[..., None, :]).sum(-1)
        dk_aux[:, :, i] = dk_i
        dk_i += (d_kv_i * u[None, ..., None] * v_i[..., None, :]).sum(-1)
        dv_i = (d_kv_i * u[None, ..., None] * k_i[..., None]).sum(-2)
        dv_i += (dh * k_i[..., None]).sum(-2)

        dk[:, :, i] = dk_i
        dv[:, :, i] = dv_i
        dh = dh * w[:, :, i, :, :].exp() + d_kv_i

    # dw = q * dq_aux - k * dk_aux
    dw = torch.zeros_like(w)
    for i in range(seq_len-2, -1, -1):
        dw[:, :, i] = dw[:, :, i+1] + (dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i])[..., None]

    du = du.sum(0)
    return dq, dk, dv, dw, du

#### My tests

In [108]:
B, H, L, K, V = 1, 1, 8, 1, 1

#### AUTOGRAD FORWARD PASS
rt, kt, vt, wt, ut = gen_inputs(B, H, L, K, V)
w_ = -th.exp(wt)
o = naive_recurrent_rwkv6(rt, kt, vt, w_, ut)
# get grad of nonleaf variable
grads = []
o.register_hook(lambda d:grads.append(d))
o.mean().backward()

#### MANUAL BACKWARD PASS
do = grads[0]
rt2, kt2, vt2, wt2, ut2 = gen_inputs()
w_ = -th.exp(wt2)
dq, dk, dv, dw, du = naive_recurrent_rwkv6_bwd(rt2, kt2, vt2, w_, ut2, o, do)

TypeError: gen_inputs() missing 5 required positional arguments: 'B', 'H', 'L', 'K', and 'V'

In [None]:
wt.grad - dw

In [None]:
(rt.grad - dq).flatten(), (kt.grad - dk).flatten(), (vt.grad - dv).flatten(), (ut.grad - du).flatten()

### Retry Inputs without UWMat

In [839]:
def gen_inputs(B, H, L, K, V): 
    th.manual_seed(17)
    device = "cpu"
    rt = th.randn(B, H, L, K, device=device, requires_grad=True)
    kt = th.randn(B, H, L, K, device=device, requires_grad=True)
    vt = th.randn(B, H, L, V, device=device, requires_grad=True)
    wt = -3 + 1e-1 * th.randn(B, H, L, K, device=device)
    wt = -th.exp(wt)
    wt.requires_grad = True
    ut = th.randn(H, K, device=device, requires_grad=True)
    return rt, kt, vt, wt, ut

### 19/04/2024 Simplify without U

In [330]:
import torch

def naive_recurrent(q, k, v, w, initial_state=None, output_final_state=False):
    orig_dtype = q.dtype
    q, k, v, w = map(lambda x: x.float(), (q, k, v, w))
    batch_size, n_heads, seq_len, d_head_k = q.shape
    _, _, _, d_head_v = v.shape
    h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
    o = torch.zeros_like(v)

    if initial_state is not None:
        h += initial_state

    for i in range(seq_len):
        q_i = q[:, :, i, :]
        k_i = k[:, :, i]
        v_i = v[:, :, i, :]
        w_i = w[:, :, i].exp()
        kv_i = k_i[..., None] * v_i[..., None, :]
        h = h + w_i[..., None] * kv_i
        o_i = h * q_i[..., None]
        o[:, :, i] = o_i.sum(-2)
        
    return o.to(orig_dtype)


def get_dw(q, k, v, w, do): 
    """ Just test gradient for w """
    B, H, L, K = q.shape
    V = v.shape[-1]
    
    dw = torch.zeros_like(w)
    si = torch.zeros(B, H, K, V)
    wsum = w.sum(-2)[..., None] # sum along seqlen
    for i in range(L-1, -1, -1): 
        do_i = do[:, :, i, ..., None, :]
        q_i = q[:, :, i, :, None]
        k_i = k[:, :, i, :, None] 
        v_i = v[:, :, i, ..., None, :] 
        w_i =  w[:, :, i, ..., None]

        si += q_i * k_i * v_i * wsum.exp()
        wsum -= w_i
        dw[:, :, i] = (do_i * si).sum(dim=-2)
    return dw

def gen_inputs(B, H, L, K, V): 
    th.manual_seed(17)
    device = "cpu"
    rt = th.randn(B, H, L, K, device=device, requires_grad=True)
    kt = th.randn(B, H, L, K, device=device, requires_grad=True)
    vt = th.randn(B, H, L, V, device=device, requires_grad=True)
    wt = -3 + 1e-1 * th.randn(B, H, L, K, device=device)
    wt = -th.exp(wt)
    wt.requires_grad = True
    ut = th.randn(H, K, device=device, requires_grad=True)
    return rt, kt, vt, wt, ut

In [263]:
# implementation is wrong. hidden by high V (ex. 64)
B, H, L, K, V = 1, 1, 1, 1, 1

#### AUTOGRAD FORWARD PASS
rt, kt, vt, wt, ut = gen_inputs(B, H, L, K, V)
o = naive_recurrent(rt, kt, vt, wt)
grads = []
o.register_hook(lambda d:grads.append(d))
o.mean().backward()
do = grads[0]

#### MANUAL BACKWARD PASS
rt2, kt2, vt2, wt2, ut2 = gen_inputs(B, H, L, K, V)
dw = get_dw(rt2, kt2, vt2, wt2, do)[None]
print(dw.shape)
(wt.grad - dw).detach().flatten(), (w_ - w2_).detach().flatten()

#### DOESN'T MATCH!

torch.Size([1, 1, 1, 1, 1])


(tensor([-9.3132e-10]), tensor([0.5013]))

In [261]:
wt.shape

torch.Size([1, 1, 2, 1])

### 19/04/2024 Very simple test (just L=1) - Correct

In [331]:
# implementation is wrong. hidden by high V (ex. 64)
B, H, L, K, V = 1, 1, 1, 1, 1

#### AUTOGRAD FORWARD PASS
q, k, v, w, _ = gen_inputs(B, H, L, K, V)

o = th.einsum('bhic,bhic,bhic,bhid->bhid', q, w.exp(), k, v)
grads = []
o.register_hook(lambda d:grads.append(d))
o.mean().backward()
do = grads[0]
w.grad, do

(tensor([[[[-0.0107]]]]), tensor([[[[1.]]]]))

In [332]:
w.grad - o

tensor([[[[-9.3132e-10]]]], grad_fn=<SubBackward0>)

### 19/04/2024 Very simple test (just L=2) - ???

In [333]:
# implementation is wrong. hidden by high V (ex. 64)
B, H, L, K, V = 1, 1, 2, 1, 1

#### AUTOGRAD FORWARD PASS
q, k, v, w, _ = gen_inputs(B, H, L, K, V)

s1 = th.einsum('bhc,bhc,bhd->bhcd', w[:, :, 0].exp(), k[:, :, 0], v[:, :, 0])
o1 = th.einsum('bhc,bhcd->bhd', q[:, :, 0], s1)
s2 = s1 + th.einsum('bhc,bhc,bhd->bhd', w[:, :, :2].sum(dim=2).exp(), k[:, :, 1], v[:, :, 1])
o2 = th.einsum('bhc,bhcd->bhd', q[:, :, 1], s2)
o = th.stack([o1, o2], dim=2)

grads = []
o.register_hook(lambda d:grads.append(d))
o.mean().backward()
do = grads[0]
w.grad, do

(tensor([[[[-0.0227],
           [-0.0229]]]]),
 tensor([[[[0.5000],
           [0.5000]]]]))

In [346]:
dw2 = th.einsum('bhd,bhcd->bhc', do[:, :, 1], th.einsum('bhc,bhc,bhc,bhd->bhcd', q[:, :, 1], w[:, :, :2].sum(dim=2).exp(), k[:, :, 1], v[:, :, 1]))
dw1 = do[:, :, 0] * o[:, :, 0] + do[:, :, 1] * o[:, :, 1]
dw = th.stack((dw1, dw2), dim=2)

In [347]:
w.grad - dw

tensor([[[[-1.8626e-09],
          [ 0.0000e+00]]]], grad_fn=<SubBackward0>)

### 19/04/2024 Very simple test (general L)

In [433]:
# implementation is wrong. hidden by high V (ex. 64)
B, H, L, K, V = 1, 1, 16, 1, 4

#### AUTOGRAD FORWARD PASS
q, k, v, w, _ = gen_inputs(B, H, L, K, V)

o = []
si = th.zeros(B, H, K, V)
for i in range(L): 
    si = si + th.einsum('bhc,bhc,bhd->bhcd', w[:, :, :i+1].sum(dim=2).exp(), k[:, :, i], v[:, :, i])
    oi = th.einsum('bhc,bhcd->bhd', q[:, :, i], si)
    o.append(oi)
o = th.stack(o, dim=2)

grads = []
o.register_hook(lambda d:grads.append(d))
o.mean().backward()
do = grads[0]
w.grad.flatten(), do.flatten()

(tensor([ 0.0323,  0.0327,  0.0209,  0.0894,  0.0785,  0.0218,  0.0019, -0.0556,
         -0.0575, -0.0579, -0.0265,  0.0239,  0.0093,  0.0209, -0.0112, -0.0085]),
 tensor([0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156,
         0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156,
         0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156,
         0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156,
         0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156,
         0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156,
         0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156, 0.0156,
         0.0156]))

In [434]:
dw = th.zeros_like(w)
doq = th.einsum('bhid,bhic->bhicd', do, q)
wkv = th.einsum('bhic,bhic,bhid->bhicd', w.cumsum(dim=2).exp(), k, v)
for i in range(L): 
    doq_wkv = (doq[:, :, i:] * wkv[:, :, [i]]).sum(dim=-1).sum(2)
    dw[:, :, :i+1] += doq_wkv

In [435]:
w.grad - dw

tensor([[[[-2.2352e-08],
          [-2.2352e-08],
          [-2.0489e-08],
          [-2.2352e-08],
          [-2.9802e-08],
          [-2.0489e-08],
          [-8.3819e-09],
          [-3.7253e-09],
          [-3.7253e-09],
          [ 0.0000e+00],
          [ 1.8626e-09],
          [-1.8626e-09],
          [ 0.0000e+00],
          [ 0.0000e+00],
          [ 9.3132e-10],
          [ 9.3132e-10]]]], grad_fn=<SubBackward0>)

### 19/04/2024 Toy Algo for Backward - purely on W

In [438]:
# implementation is wrong. hidden by high V (ex. 64)
B, H, L, K, V = 1, 1, 16, 1, 4

#### AUTOGRAD FORWARD PASS
q, k, v, w, _ = gen_inputs(B, H, L, K, V)

o = []
si = th.zeros(B, H, K, V)
for i in range(L): 
    si = si + th.einsum('bhc,bhc,bhd->bhcd', w[:, :, :i+1].sum(dim=2).exp(), k[:, :, i], v[:, :, i])
    oi = th.einsum('bhc,bhcd->bhd', q[:, :, i], si)
    o.append(oi)
o = th.stack(o, dim=2)
o.register_hook(lambda d:grads.append(d))
o.mean().backward()
do = grads[0]

In [440]:
# 1 CUMSUM
dw = th.zeros_like(w)
doq = th.einsum('bhid,bhic->bhicd', do, q)
doq = doq.flip((2,)).cumsum(dim=2).flip((2,))
wkv = th.einsum('bhic,bhic,bhid->bhicd', w.cumsum(dim=2).exp(), k, v) # h (guardar durante fw)
for i in range(L): 
    doq_wkv = (doq[:, :, [i]] * wkv[:, :, [i]]).sum(dim=-1).sum(2)
    dw[:, :, :i+1] += doq_wkv
w.grad - dw

tensor([[[[ 7.4506e-09],
          [ 7.4506e-09],
          [ 1.8626e-09],
          [-2.2352e-08],
          [ 0.0000e+00],
          [-5.5879e-09],
          [ 9.3132e-10],
          [ 3.7253e-09],
          [ 3.7253e-09],
          [ 0.0000e+00],
          [ 1.8626e-09],
          [-5.5879e-09],
          [ 0.0000e+00],
          [ 0.0000e+00],
          [ 9.3132e-10],
          [ 9.3132e-10]]]], grad_fn=<SubBackward0>)

In [388]:
# 2 cumsums, no loop
dw = th.zeros_like(w)
doq = th.einsum('bhid,bhic->bhicd', do, q)
doq = doq.flip((2,)).cumsum(dim=2).flip((2,))
wkv = th.einsum('bhic,bhic,bhid->bhicd', w.cumsum(dim=2).exp(), k, v) # h (guardar durante fw)
dw = (doq * wkv).sum(dim=-1)
dw = dw.flip((2,)).cumsum(dim=2).flip((2,))
dw

tensor([[[[ 0.0323],
          [ 0.0327],
          [ 0.0209],
          [ 0.0894],
          [ 0.0785],
          [ 0.0218],
          [ 0.0019],
          [-0.0556],
          [-0.0575],
          [-0.0579],
          [-0.0265],
          [ 0.0239],
          [ 0.0093],
          [ 0.0209],
          [-0.0112],
          [-0.0085]]]], grad_fn=<FlipBackward0>)

### 20/04/2024 Toy Algo for Backward - W and U

In [445]:
# implementation is wrong. hidden by high V (ex. 64)
B, H, L, K, V = 1, 1, 16, 1, 4

#### AUTOGRAD FORWARD PASS
q, k, v, w, _ = gen_inputs(B, H, L, K, V)

o = []
si = th.zeros(B, H, K, V)
for i in range(L): 
    osi = si + th.einsum('hc,bhc,bhd->bhcd', u, k[:, :, i], v[:, :, i])
    oi = th.einsum('bhc,bhcd->bhd', q[:, :, i], osi)
    si = si + th.einsum('bhc,bhc,bhd->bhcd', w[:, :, :i+1].sum(dim=2).exp(), k[:, :, i], v[:, :, i])
    o.append(oi)
o = th.stack(o, dim=2)
grads = []
o.register_hook(lambda d:grads.append(d))
o.mean().backward()
do = grads[0]

In [446]:
# NO CUMSUMS
dw = th.zeros_like(w)
doq = th.einsum('bhid,bhic->bhicd', do, q)
wkv = th.einsum('bhic,bhic,bhid->bhicd', w.cumsum(dim=2).exp(), k, v)
# doq1,wkv0
for i in range(L-1): 
    doq_wkv = (doq[:, :, i+1:] * wkv[:, :, [i]]).sum(dim=-1).sum(2)
    dw[:, :, :i+1] += doq_wkv

In [447]:
(w.grad - dw).flatten()

tensor([-1.1176e-08, -2.2352e-08, -1.8626e-08, -7.4506e-09, -7.4506e-09,
        -3.7253e-09, -9.3132e-10, -3.7253e-09,  0.0000e+00,  3.7253e-09,
        -3.7253e-09, -1.8626e-09,  0.0000e+00,  0.0000e+00,  2.3283e-10,
         0.0000e+00], grad_fn=<ViewBackward0>)

In [448]:
# 1 reverse CUMSUM
dw = th.zeros_like(w)
doq = th.einsum('bhid,bhic->bhicd', do, q)
doq = doq.flip((2,)).cumsum(dim=2).flip((2,))
wkv = th.einsum('bhic,bhic,bhid->bhicd', w.cumsum(dim=2).exp(), k, v) # h (guardar durante fw)
for i in range(L-1): 
    doq_wkv = (doq[:, :, [i+1]] * wkv[:, :, [i]]).sum(dim=-1).sum(2)
    dw[:, :, :i+1] += doq_wkv
(w.grad - dw).flatten()

tensor([-3.3528e-08, -4.4703e-08, -4.0978e-08, -3.7253e-08, -2.2352e-08,
        -1.1176e-08, -9.3132e-09,  3.7253e-09,  7.4506e-09,  3.7253e-09,
         3.7253e-09, -1.8626e-09, -9.3132e-10,  0.0000e+00,  2.3283e-10,
         0.0000e+00], grad_fn=<ViewBackward0>)

In [451]:
doq = th.einsum('bhid,bhic->bhicd', do, q)
doq = doq.flip((2,)).cumsum(dim=2).flip((2,))
doq = F.pad(doq, (0,0,0,0,-1,1))
wkv = th.einsum('bhic,bhic,bhid->bhicd', w.cumsum(dim=2).exp(), k, v) # h (guardar durante fw)
dw = (doq * wkv).sum(dim=-1)
dw = dw.flip((2,)).cumsum(dim=2).flip((2,))
(w.grad - dw).flatten()

tensor([-3.7253e-08, -4.0978e-08, -4.0978e-08, -2.9802e-08, -2.2352e-08,
        -1.3039e-08, -1.0245e-08,  3.7253e-09,  3.7253e-09,  3.7253e-09,
         1.8626e-09, -1.8626e-09, -9.3132e-10,  0.0000e+00,  2.3283e-10,
         0.0000e+00], grad_fn=<ViewBackward0>)

In [404]:
w.grad - dw

### 20/04/2024 Incorporate algorithm into functions: 
* [x] Our toy fw pass was wrong: `W` in log space and direct application to the state, etc
* [x] Delta with main `FW` function solved
* [ ] Need to adjust our backward 

#### Base forward func

In [499]:
import torch as th
import torch.nn.functional as F


def simple_inner(q, k, v, w, u): 
    o = []
    si = th.zeros(B, H, K, V)
    for i in range(L): 
        kv = th.einsum('bhc,bhd->bhcd', k[:, :, i], v[:, :, i])
        osi = si + th.einsum('hc,bhcd->bhcd', u, kv)
        oi = th.einsum('bhc,bhcd->bhd', q[:, :, i], osi)
        si = si * w[:, :, i, :, None].exp() + kv
        o.append(oi)
    o = th.stack(o, dim=2)
    return o

def songlin_inner(q, k, v, w, u, initial_state=None, output_final_state=False):
    orig_dtype = q.dtype
    q, k, v, w, u = map(lambda x: x.float(), (q, k, v, w, u))
    batch_size, n_heads, seq_len, d_head_k = q.shape
    _, _, _, d_head_v = v.shape
    h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
    o = torch.zeros_like(v)

    if initial_state is not None:
        h += initial_state

    for i in range(seq_len):
        q_i = q[:, :, i, :]
        k_i = k[:, :, i]
        v_i = v[:, :, i, :]
        w_i = w[:, :, i].exp()
        kv_i = k_i[..., None] * v_i[..., None, :]
        o_i = (h + u[None, ..., None] * kv_i) * q_i[..., None]
        o[:, :, i] = o_i.sum(-2)
        h = h * w_i[..., None] + kv_i
    return o.to(orig_dtype)

#### Tests

In [851]:
# implementation is wrong. hidden by high V (ex. 64)
B, H, L, K, V = 1, 1, 32, 1, 1

#### AUTOGRAD FORWARD PASS
q, k, v, w, u = gen_inputs(B, H, L, K, V)
o = songlin_inner(q, k, v, w, u, th.zeros(B, H, K, V))
grads = []
o.register_hook(lambda d:grads.append(d))
o.mean().backward()
do = grads[0]
w.grad.flatten()

tensor([ 0.0000e+00, -1.5418e-04, -5.7530e-03, -5.2927e-03, -4.1082e-02,
        -1.6820e-01, -2.4437e-01, -3.0051e-01, -2.6155e-01, -1.6486e-01,
        -1.3501e-01, -2.1993e-01, -1.3706e-01, -1.4047e-01, -6.5149e-02,
        -3.5055e-02, -3.1937e-02, -2.4309e-02, -6.8515e-02, -6.6513e-02,
        -9.0482e-02, -7.5956e-02,  2.1371e-02,  3.1582e-02,  4.1205e-02,
        -2.3560e-02, -1.6657e-01,  4.0461e-02,  1.0806e-01,  2.4641e-02,
         5.8453e-02,  0.0000e+00])

In [852]:
q2, k2, v2, w2, u2 = gen_inputs(B, H, L, K, V)
omine = simple_inner(q2, k2, v2, w2, u2)

In [853]:
grads2 = []
omine.register_hook(lambda d:grads2.append(d))
omine.mean().backward()
do = grads2[0]
w2.grad.flatten()

tensor([ 0.0000e+00, -1.5418e-04, -5.7530e-03, -5.2927e-03, -4.1082e-02,
        -1.6820e-01, -2.4437e-01, -3.0051e-01, -2.6155e-01, -1.6486e-01,
        -1.3501e-01, -2.1993e-01, -1.3706e-01, -1.4047e-01, -6.5149e-02,
        -3.5055e-02, -3.1937e-02, -2.4309e-02, -6.8515e-02, -6.6513e-02,
        -9.0482e-02, -7.5956e-02,  2.1371e-02,  3.1582e-02,  4.1205e-02,
        -2.3560e-02, -1.6657e-01,  4.0461e-02,  1.0806e-01,  2.4641e-02,
         5.8453e-02,  0.0000e+00])

#### DW By hand (if L=3)

In [673]:
dw = th.zeros_like(w)
dw[:, :, 1] += do[:, :, 2] * q[:, :, 2] * k[:, :, 0] * v[:, :, 0] * w[:, :, 1].exp()
dw

tensor([[[[ 0.0000],
          [-0.0045],
          [ 0.0000]]]], grad_fn=<CopySlices>)

#### DW By hand (if L=4)

In [696]:
dw = th.zeros_like(w)
dw[:, :, 1] += do[:, :, 2] * q[:, :, 2] * k[:, :, 0] * v[:, :, 0] * w[:, :, 1].exp()
dw[:, :, 1] += do[:, :, 3] * q[:, :, 3] * k[:, :, 0] * v[:, :, 0] * w[:, :, 1].exp() * w[:, :, 2].exp()
dw[:, :, 2] += do[:, :, 3] * q[:, :, 3] * k[:, :, 0] * v[:, :, 0] * w[:, :, 1].exp() * w[:, :, 2].exp()
dw[:, :, 2] += do[:, :, 3] * q[:, :, 3] * k[:, :, 1] * v[:, :, 1] * w[:, :, 2].exp()
dw.flatten()

tensor([ 0.0000, -0.0023, -0.0026,  0.0000], grad_fn=<ViewBackward0>)

#### DW By hand (if L=5)

In [690]:
dw = th.zeros_like(w)
wexp = w.exp()
# w2
dw[:, :, [1]]       += (do[:, :, 2] * q[:, :, 2] * k[:, :, 0] * v[:, :, 0] * wexp[:, :, 1])[:, :, None]
dw[:, :, [1, 2]]    += (do[:, :, 3] * q[:, :, 3] * k[:, :, 0] * v[:, :, 0] * wexp[:, :, 1] * wexp[:, :, 2])[:, :, None]
dw[:, :, [1, 2, 3]] += (do[:, :, 4] * q[:, :, 4] * k[:, :, 0] * v[:, :, 0] * wexp[:, :, 1] * wexp[:, :, 2] * wexp[:, :, 3])[:, :, None]
# w3
dw[:, :, [2]]    += (do[:, :, 3] * q[:, :, 3] * k[:, :, 1] * v[:, :, 1] * wexp[:, :, 2])[:, :, None]
dw[:, :, [2, 3]] += (do[:, :, 4] * q[:, :, 4] * k[:, :, 1] * v[:, :, 1] * wexp[:, :, 2] * wexp[:, :, 3])[:, :, None]
# dw4 
dw[:, :, [3]]    += (do[:, :, 4] * q[:, :, 4] * k[:, :, 2] * v[:, :, 2] * wexp[:, :, 3])[:, :, None]
dw.flatten()

tensor([ 0.0000, -0.0063,  0.1960, -0.0139,  0.0000], grad_fn=<ViewBackward0>)

#### DW for L

In [660]:
dw = th.zeros_like(w)
# el primer w multiplica el estado de 0s y no hace nada. 
# el ultimo multiplica un estado que no se usa y no hace nada. 
for i in range(1, L-1): 
    for j in range(i, L-1): 
        dw[:, :, i:j+1] += (do[:, :, j+1] * q[:, :, j+1] * k[:, :, i-1] * v[:, :, i-1] * w[:, :, i:j+1].sum(dim=2).exp())[:, :, None]
dw.flatten()

tensor([ 0.0000, -0.0063,  0.1960, -0.0139,  0.0000], grad_fn=<ViewBackward0>)

#### DW for L: parallel?

In [854]:
dw = th.zeros_like(w)
# el primer w multiplica el estado de 0s y no hace nada. 
# el ultimo multiplica un estado que no se usa y no hace nada. 
doq = th.einsum('bhid,bhic->bhicd', do, q)
kv_ = th.einsum('bhic,bhid->bhicd', k, v)
for i in range(1, L-1): 
    for j in range(i, L-1): 
        wcum = w[:, :, i:j+1].sum(dim=2).exp()
        delta = th.einsum('bhcd,bhcd,bhc->bhc', doq[:, :, j+1], kv_[:, :, i-1], wcum )
        dw[:, :, i:j+1] += delta[:, :, None]
(w2.grad - dw).flatten()

tensor([ 0.0000e+00, -1.4552e-11,  0.0000e+00,  1.3970e-09,  0.0000e+00,
        -1.4901e-08, -2.9802e-08, -2.9802e-08, -8.9407e-08, -1.0431e-07,
         0.0000e+00,  1.4901e-08, -4.4703e-08,  4.4703e-08,  7.4506e-09,
        -1.4901e-08, -7.4506e-09, -1.6764e-08,  2.9802e-08,  0.0000e+00,
        -2.2352e-08, -1.4901e-08,  3.7253e-09,  0.0000e+00,  0.0000e+00,
        -1.3039e-08, -4.4703e-08,  3.7253e-09,  2.2352e-08,  3.7253e-09,
         3.7253e-09,  0.0000e+00], grad_fn=<ViewBackward0>)

##### Move cumsum outside loop

In [855]:
dw = th.zeros_like(w)
# el primer w multiplica el estado de 0s y no hace nada. 
# el ultimo multiplica un estado que no se usa y no hace nada. 
doq = th.einsum('bhid,bhic->bhicd', do, q)
kv_ = th.einsum('bhic,bhid->bhicd', k, v)
wcum = w.cumsum(dim=2)
for i in range(1, L-1): 
    for j in range(i, L-1): 
        delta = th.einsum('bhcd,bhcd,bhc->bhc', doq[:, :, j+1], kv_[:, :, i-1], wcum[:, :, j].exp() )
        dw[:, :, i:j+1] += delta[:, :, None]
    wcum[:, :, i:] -= wcum[:, :, [i]]
dw.flatten()

tensor([ 0.0000e+00, -1.4508e-04, -5.7476e-03, -5.2867e-03, -4.1076e-02,
        -1.6819e-01, -2.4436e-01, -3.0050e-01, -2.6155e-01, -1.6485e-01,
        -1.3500e-01, -2.1992e-01, -1.3706e-01, -1.4047e-01, -6.5147e-02,
        -3.5055e-02, -3.1936e-02, -2.4308e-02, -6.8514e-02, -6.6512e-02,
        -9.0481e-02, -7.5955e-02,  2.1371e-02,  3.1582e-02,  4.1205e-02,
        -2.3560e-02, -1.6657e-01,  4.0460e-02,  1.0806e-01,  2.4641e-02,
         5.8453e-02,  0.0000e+00], grad_fn=<ViewBackward0>)

### 20/04/2024 Check songlin's code for clues - hers is parallel and ours is hard to make fast 
* [x] Implemented
* [x] Check incorrect
* [ ] TODO: understand why incorrect
* [ ] See if fixable

In [724]:
def get_qaux(q, k, v, do):
    dq_aux = th.zeros(B, H, L, K)
    h = th.zeros(B, H, K, V)
    for i in range(L):
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        w_i = w[:, :, i].exp()
        kv_i = k_i[..., None] * v_i[..., None, :]
        dq_aux[:, :, i] = (do[:, :, i, None, :] * h).sum(-1)
        h = h * w_i[..., None] + kv_i
    return dq_aux

def get_kaux(q, k, v, do): 
    dk_aux = th.zeros(B, H, L, K)
    dh = th.zeros(B, H, K, V)
    for i in range(L-1, -1, -1):
        d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None]
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        dk_aux[:, :, i] = (dh * v_i[..., None, :]).sum(-1)
        dh = dh * w[:, :, i, :, None].exp() + d_kv_i
    return dk_aux



In [731]:
dw = th.zeros_like(w)
dk_aux = get_kaux(q, k, v, do)
dq_aux = get_qaux(q, k, v, do)
for i in range(L-2, -1, -1): 
    dw[:, :, i] = dw[:, :, i+1] + dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i]
dw.flatten()

tensor([-6.9849e-10, -2.2847e-03, -2.5597e-03,  0.0000e+00],
       grad_fn=<ViewBackward0>)

### 20/04/2024 Go over backward pass again
* [x] Expand as sum of parallelizable terms
* [x] Get recurrence relation
* [x] Assimilate to Songlin's notation: **But can't be! Her code is numerically wrong with just L=4 (ERR `~1e-4`) and our code was always correct: `Err ~ 1-8 on l=32`**
* [ ] Rewrite songlin's code to to match our notation 

In [856]:
def get_qaux_mine(q, k, v, do):
    dq_aux = th.zeros(B, H, L, K)
    h = th.zeros(B, H, K, V)
    for i in range(L):
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        w_i = w[:, :, i].exp()
        kv_i = k_i[..., None] * v_i[..., None, :]
        dq_aux[:, :, i] = (do[:, :, i, None, :] * h).sum(-1)
        h = h * w_i[..., None] + kv_i
        
    return dq_aux

def get_kaux_mine(q, k, v, do): 
    dk_aux = th.zeros(B, H, L, K)
    dh = th.zeros(B, H, K, V)
    hstr = ""
    for i in range(L-1, -1, -1):
        d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None]
        v_i = v[:, :, i]
        dk_aux[:, :, i] = (dh * v_i[..., None, :]).sum(-1)
        # print(f"dkaux_{i}= ({hstr}) * v_{i}")
        dh = dh * w[:, :, i, :, None].exp() + d_kv_i
        # hstr = f"({hstr}) * w_{i} + do_{i} * o_{i}"
        # print(hstr, dh)
    return dk_aux



#### Dev version
* WARNING! Skip the 0 update for weird numerical precision errors

In [857]:
dk_aux = get_kaux_mine(q, k, v, do)
dq_aux = get_qaux_mine(q, k, v, do)
dw = th.zeros_like(w)
# WARNING! Skip the 0 update for weird numerical precision errors
for i in range(L-2, 0, -1): 
    dw[:, :, i] = dw[:, :, i+1] + dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i]
(w2.grad - dw).flatten()

tensor([0.0000e+00, 1.2048e-07, 1.2014e-07, 1.2107e-07, 1.2293e-07, 1.1921e-07,
        1.0431e-07, 8.9407e-08, 5.9605e-08, 4.4703e-08, 2.9802e-08, 2.9802e-08,
        1.4901e-08, 1.4901e-08, 1.4901e-08, 1.8626e-08, 1.8626e-08, 1.8626e-08,
        2.2352e-08, 2.2352e-08, 2.2352e-08, 1.4901e-08, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 3.7253e-09, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.8626e-09,
        0.0000e+00, 0.0000e+00], grad_fn=<ViewBackward0>)

In [849]:
w.grad - dw

tensor([[[[ 0.0000e+00],
          [-2.3283e-09],
          [ 0.0000e+00],
          [ 0.0000e+00],
          [ 0.0000e+00]]]], grad_fn=<SubBackward0>)

In [850]:
# dw2 = dq_aux[:, :, 3] * q[:, :, 3] - dk_aux[:, :, 2] * k[:, :, 2]
# dw1 = dw2 + dq_aux[:, :, 2] * q[:, :, 2] - dk_aux[:, :, 1] * k[:, :, 1]
# dw2, dw1

#### Reduce numerical errors! - Use cumsums

In [880]:
dk_aux = get_kaux_mine(q, k, v, do)
dq_aux = get_qaux_mine(q, k, v, do)
dw = th.zeros_like(w)
# WARNING! Skip the 0 update for weird numerical precision errors
delta = (dq_aux[:, :, 1:] * q[:, :, 1:] - dk_aux[:, :, :-1] * k[:, :, :-1]).flip((2,)).cumsum(dim=2).flip((2,))
dw = F.pad(delta, (0, 0, 0, 1))
dw[:, :, 0] = 0.
(w2.grad - dw).flatten()

tensor([ 0.0000e+00,  6.8729e-08,  6.8452e-08,  6.8918e-08,  7.0781e-08,
         5.9605e-08,  5.9605e-08,  5.9605e-08,  2.9802e-08,  2.9802e-08,
         2.9802e-08,  2.9802e-08,  1.4901e-08,  1.4901e-08,  1.4901e-08,
         1.1176e-08,  1.1176e-08,  9.3132e-09,  1.4901e-08,  7.4506e-09,
         0.0000e+00,  0.0000e+00, -1.3039e-08, -1.1176e-08, -1.1176e-08,
        -5.5879e-09,  0.0000e+00, -3.7253e-09,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00], grad_fn=<ViewBackward0>)

### 20/04/2024

####  Func

In [873]:
def songlin_bw(q, k, v, w, u, o, do, initial_state=None, output_final_state=False):
    q, k, v, w, u, o, do = map(lambda x: x.float(), (q, k, v, w, u, o, do))
    batch_size, n_heads, seq_len, d_head_k = q.shape
    _, _, _, d_head_v = v.shape
    h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
    dq = torch.zeros_like(q)
    dq_aux = torch.zeros_like(q)

    if initial_state is not None:
        h += initial_state

    for i in range(seq_len):
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        w_i = w[:, :, i].exp()
        kv_i = k_i[..., None] * v_i[..., None, :]
        h_i = (h + u[None, ..., None] * kv_i)
        dq_i = (do[:, :, i, None, :] * h_i).sum(-1)
        dq_aux_i = (do[:, :, i, None, :] * h).sum(-1)
        dq[:, :, i] = dq_i
        dq_aux[:, :, i] = dq_aux_i
        h = h * w_i[..., None] + kv_i

    du = u.new_zeros(batch_size, n_heads, d_head_k)
    dh = torch.zeros_like(h)
    dk = torch.zeros_like(k)
    dk_aux = torch.zeros_like(k)
    dv = torch.zeros_like(v)

    for i in range(seq_len-1, -1, -1):
        d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None]
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        du_i = (d_kv_i * k_i[..., None] * v_i[..., None, :]).sum(-1)
        du += du_i
        dk_i = (dh * v_i[..., None, :]).sum(-1)
        dk_aux[:, :, i] = dk_i
        dk_i += (d_kv_i * u[None, ..., None] * v_i[..., None, :]).sum(-1)
        dv_i = (d_kv_i * u[None, ..., None] * k_i[..., None]).sum(-2)
        dv_i += (dh * k_i[..., None]).sum(-2)

        dk[:, :, i] = dk_i
        dv[:, :, i] = dv_i
        dh = dh * w[:, :, i, :, None].exp() + d_kv_i

    # dw = q * dq_aux - k * dk_aux
    # dw = torch.zeros_like(w)
    # for i in range(seq_len-2, 0, -1):
    #     dw[:, :, i] = dw[:, :, i+1] + dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i]
    # delta = reverse cumsum of q*q_aux - k*k_aux
    delta = (dq_aux[:, :, 1:] * q[:, :, 1:] - dk_aux[:, :, :-1] * k[:, :, :-1]).flip((2,)).cumsum(dim=2).flip((2,))
    dw = F.pad(delta, (0, 0, 0, 1))

    du = du.sum(0)
    return dq, dk, dv, dw, du

#### Test

In [874]:
ddq, ddk, ddv, ddw, ddu = songlin_bw(q, k, v, w, u, o, do, initial_state=None, output_final_state=False)

In [876]:
(w2.grad - ddw).flatten()

tensor([ 6.8743e-08,  6.8729e-08,  6.8452e-08,  6.8918e-08,  7.0781e-08,
         5.9605e-08,  5.9605e-08,  5.9605e-08,  2.9802e-08,  2.9802e-08,
         2.9802e-08,  2.9802e-08,  1.4901e-08,  1.4901e-08,  1.4901e-08,
         1.1176e-08,  1.1176e-08,  9.3132e-09,  1.4901e-08,  7.4506e-09,
         0.0000e+00,  0.0000e+00, -1.3039e-08, -1.1176e-08, -1.1176e-08,
        -5.5879e-09,  0.0000e+00, -3.7253e-09,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00], grad_fn=<ViewBackward0>)

### 20/04/2024 Error was a simple padding by 1!

## AOB

### RWKV6Plus test

In [10]:
import torch 
import torch.nn.functional as F

def rwkv_inner_v6plus(
    r,
    k,
    v,
    w,
    u,
    kv_state,
    chunk_len: int = 24,
    precision_dtype: torch.dtype = torch.float32,
    precision_min_val: float = 0.005,
):
    """
    expects
    r : (B,H,L,K)
    k : (B,H,L,K)
    v : (B,H,L,V)
    w : (B,H,L,K) or (1,H,L,K)
    u : (1,H,1,K,V)
    kv_state : (B,H,K,V)
    """
    B, H, L, K = k.size()
    V = v.size(-1)
    T = chunk_len

    if L == 1:
        kv = k.mT @ v
        out = r @ (kv_state + u.mT * kv)
        kv_state = w.mT * kv_state + kv
        return out, kv_state
    else:
        # FIXME - support fast path for non-exact multiples
        # ensure it's an exact multiple
        if L % T != 0:
            T = 1

        N = L // T

        w = w.clamp(precision_min_val)

        # calculate cumulative decay in log space where it won't overflow
        w_log = w.float().log()  # (1,H,L,K) or (B,H,L,K)

        # chunked view of w_log
        wc_log = w_log.view(w.size(0), H, N, T, K)
        wc_log_cum = wc_log.cumsum(dim=-2)

        # chunked view of shifted_w_log
        shifted_wc_log_cum = F.pad(wc_log_cum, (0, 0, 1, -1))

        # NOTE - we have to apply the decay weight from TWO ahead.. ONE ahead gets no decay (log==0)
        # pre-applied weights
        # left side is prior chunk (w_inter), right side is current chunk (w_intra)
        # without u...
        # w0   w1   w2   w3   | w4   w5   w6   w7
        # w1:4 w2:4 w3:4 w4:4 | w4:5 w4:6 w4:7 w4:8
        # with u...
        # w0   w1   w2   w3   | w4   w5   w6   w7
        # w1:4 w2:4 w3:4 w4:4 | w4:4 w4:5 w4:6 w4:7

        # ws decays the entire current state (representing t-1) to the prior block (t-2)
        ws = wc_log.sum(dim=-2, keepdim=True)  # 1HN1K or BHN1K
        # w_inter is the decay to the end of the current block, since it will be applied at the next iteration when current (t) becomes prior (t-1)
        # this formula because e.g. w1:4 = w0:4 - w0:1
        w_inter = ws - wc_log_cum  # 1HNTK or BHNTK (w^(T-1) ... w^0)
        # w_intra is the decay from the beginning of the current block (t), since it will be applied to current queries (t) against prior state (representing keys+values up to but not including block t)
        # this formula because e.g. w1:3 = w0:3 - w0
        w_intra = wc_log_cum - wc_log  # 1HNTK or BHNTK (w^0 ... w^(T-2))

        ws = list(ws.mT.exp().to(r.dtype).unbind(dim=-3))  # N x 1HK1 or BHK1 !!NOTE THE .mT HERE!!
        w_inter = w_inter.exp().to(r.dtype)  # 1HNTK or BHNTK
        w_intra = w_intra.exp().to(r.dtype)  # 1HNTK or BHNTK

        # precompute u contrib
        u = u.squeeze(0).squeeze(-3).to(r.dtype)
        uterm_out = torch.einsum('bhic,hcd,bhic,bhid->bhid', r, u, k, v)

        # chunked view of r, k, v
        r = r.view(B, H, N, T, K)
        k = k.view(B, H, N, T, K)
        v = v.view(B, H, N, T, V)

        # parallel calculation of all intra-chunk attention contributions
        wc_log_offset = shifted_wc_log_cum[..., T // 2 : T // 2 + 1, :]  # B,H,N,1,K
        r_decay = (shifted_wc_log_cum - wc_log_offset).to(precision_dtype).exp()  # B,H,N,T,K
        k_inv_decay = (wc_log_offset - wc_log_cum).to(precision_dtype).exp()  # B,H,N,T,K
        a = ((r * r_decay) @ (k * k_inv_decay).mT).to(r.dtype).tril(-1)  # B,H,N,T,T
        # add u term to attention (NOTE - the tril(-1) above zeroed the diagonal)
        # a = a + torch.einsum("bhntk,bhntk->bhnt", r, u * k).diag_embed()
        out = a @ v  # BHNTV
        # alternate way of adding in u
        # out = out + torch.einsum('bhntk,bhntk,bhntv->bhntv', r, u * k, v)

        # parallel precalculation of chunked (k*wk).mT@v for use in recurrent state calc below
        wkv = (k * w_inter).mT @ v  # BHNKV
        wkv = list(wkv.unbind(dim=-3))  # N x BHKV

        # recurrent calculation of all states
        states = []
        for i in range(N):
            states.append(kv_state)
            kv_state = kv_state * ws[i] + wkv[i]  # BHKV
            # equivalent non-precalced version
            # wkv = (k[...,i,:,:] * wk[...,i,:,:]).mT @ v[...,i,:,:]
            # kv_state = kv_state * ws[i] + wkv
        states = torch.stack(states, dim=2)  # BHNKV

        # parallel application of all r to states
        out = out + (r * w_intra) @ states  # BHNTV
        out = out.view(B, H, L, V)

        out = out + uterm_out

        return out, kv_state

In [11]:
import torch as th

def gen_inputs(B, H, L, K, V): 
    th.manual_seed(17)
    device = "cpu"
    rt = th.randn(B, H, L, K, device=device, requires_grad=True)
    kt = th.randn(B, H, L, K, device=device, requires_grad=True)
    vt = th.randn(B, H, L, V, device=device, requires_grad=True)
    wt = -2 + 1e-1 * th.randn(B, H, L, K, device=device)
    wt = -th.exp(wt)
    wt.requires_grad = True
    ut = th.randn(H, K, V, device=device, requires_grad=True)
    return rt, kt, vt, wt, ut

In [12]:
B, H, L, K, V = 1, 2, 8, 4, 4
rt, kt, vt, wt, ut = gen_inputs(B, H, L, K, V)

In [13]:
rwkv_inner_v6plus(rt, kt, vt, wt, ut, torch.zeros(B,H,K,V))

(tensor([[[[-0.0311,  0.1554, -0.0575, -0.1376],
           [ 1.1602,  0.1685,  0.0585, -0.2987],
           [ 0.6814, -0.1242, -0.8078, -0.6277],
           [ 2.6293, -1.1666,  0.0436,  0.4222],
           [-1.5961, -0.9543,  0.3645,  1.4839],
           [-0.3761, -0.6710,  0.2651,  0.1544],
           [ 1.2655,  1.7080,  2.6342,  2.4107],
           [-1.0588,  0.7008,  0.8531,  0.7806]],
 
          [[ 0.0457,  0.1455, -0.3365, -0.5071],
           [ 0.1446, -2.2181, -1.6158,  2.7262],
           [-1.0139,  2.1165,  6.7032, -0.8199],
           [-0.9424,  2.4368,  2.2276,  0.7987],
           [ 0.6797,  1.7517,  1.6228,  0.1750],
           [-0.2344,  0.1294, -1.1555,  0.5780],
           [-2.9541, -0.0378,  9.9134, -3.1395],
           [-1.6844, -0.3199,  3.7610,  1.7704]]]], grad_fn=<AddBackward0>),
 tensor([[[[-1.8724e-01, -2.7551e+00,  2.1327e+00, -4.3973e-02],
           [ 3.5296e-02,  3.3468e-01, -2.7179e-01, -2.5437e-03],
           [-1.2411e-01, -1.7561e+00,  1.3643e+00, -2.5