In [1]:
import os
COMBINE_BYTES = 130*1024*1024 // 2
os.environ["XLA_FLAGS"] = f"--xla_gpu_simplify_all_fp_conversions --xla_gpu_all_reduce_combine_threshold_bytes={COMBINE_BYTES}"

In [2]:
import jax
import jax.numpy as jnp
from safetensors.numpy import load_file

rwkv = load_file("../../RWKV-LM-deepspeed/RWKV-v4neo/RWKV-4-Pile-430M-20220808-8066.safetensors")

rwkv.keys()

dict_keys(['blocks.0.att.key.weight', 'blocks.0.att.output.weight', 'blocks.0.att.receptance.weight', 'blocks.0.att.time_decay', 'blocks.0.att.time_first', 'blocks.0.att.time_mix_k', 'blocks.0.att.time_mix_r', 'blocks.0.att.time_mix_v', 'blocks.0.att.value.weight', 'blocks.0.ffn.key.weight', 'blocks.0.ffn.receptance.weight', 'blocks.0.ffn.time_mix_k', 'blocks.0.ffn.time_mix_r', 'blocks.0.ffn.value.weight', 'blocks.0.ln0.bias', 'blocks.0.ln0.weight', 'blocks.0.ln1.bias', 'blocks.0.ln1.weight', 'blocks.0.ln2.bias', 'blocks.0.ln2.weight', 'blocks.1.att.key.weight', 'blocks.1.att.output.weight', 'blocks.1.att.receptance.weight', 'blocks.1.att.time_decay', 'blocks.1.att.time_first', 'blocks.1.att.time_mix_k', 'blocks.1.att.time_mix_r', 'blocks.1.att.time_mix_v', 'blocks.1.att.value.weight', 'blocks.1.ffn.key.weight', 'blocks.1.ffn.receptance.weight', 'blocks.1.ffn.time_mix_k', 'blocks.1.ffn.time_mix_r', 'blocks.1.ffn.value.weight', 'blocks.1.ln1.bias', 'blocks.1.ln1.weight', 'blocks.1.ln2.b

In [3]:
print(rwkv["blocks.0.att.receptance.weight"].shape, rwkv["blocks.0.att.time_decay"].shape)

BLK = "blocks."
blks = 1
for k in rwkv.keys():
    if k.startswith(BLK):
        blks = max(int(k.split('.')[1]) + 1, blks)
print(blks)


(1024, 1024) (1024,)
24


In [13]:
# @jax.jit
def AT(x, state, i: int, time_mix_k, time_mix_v, time_mix_r, key, output, receptance, time_decay, time_first, value):
    xk = x * time_mix_k + state[5*i+1].astype(time_mix_k.dtype) * (1 - time_mix_k)
    xv = x * time_mix_v + state[5*i+1].astype(time_mix_v.dtype) * (1 - time_mix_v)
    xr = x * time_mix_r + state[5*i+1].astype(time_mix_r.dtype) * (1 - time_mix_r)
    state = state.at[5*i+1].set(x)

    r = jax.nn.sigmoid(xr @ receptance)
    k = (xk @ key).astype(jnp.float32)
    v = (xv @ value).astype(jnp.float32)

    aa = state[5*i+2]
    bb = state[5*i+3]
    pp = state[5*i+4]
    ww = time_first + k
    p = jnp.maximum(pp, ww)
    e1 = jnp.exp(pp - p)
    e2 = jnp.exp(ww - p)
    a = e1 * aa + e2 * v
    b = e1 * bb + e2
    ww = pp + time_decay
    p = jnp.maximum(ww, k)
    e1 = jnp.exp(ww - p)
    e2 = jnp.exp(k - p)
    state = state.at[5*i+2].set(e1 * aa + e2 * v)
    state = state.at[5*i+3].set(e1 * bb + e2)
    state = state.at[5*i+4].set(p)

    wkv = (a / b).astype(key.dtype)
    return (r * wkv) @ output, state

class Attention():
    def __init__(self, sd, i, dtype=jnp.float32):
        self.dtype = dtype
        self.key = jax.device_put(jnp.transpose(sd[f"blocks.{i}.att.key.weight"].astype(dtype)))
        self.output = jax.device_put(jnp.transpose(sd[f"blocks.{i}.att.output.weight"].astype(dtype) / (2 ** (i // 6))))
        self.receptance = jax.device_put(jnp.transpose(sd[f"blocks.{i}.att.receptance.weight"].astype(dtype)))
        # f32 only
        self.time_decay = -jnp.exp(jax.device_put(sd[f"blocks.{i}.att.time_decay"].astype(jnp.float32)))
        self.time_first = jax.device_put(sd[f"blocks.{i}.att.time_first"].astype(jnp.float32))
        # ---
        self.time_mix_k = jax.device_put(sd[f"blocks.{i}.att.time_mix_k"].astype(dtype).squeeze())
        self.time_mix_v = jax.device_put(sd[f"blocks.{i}.att.time_mix_v"].astype(dtype).squeeze())
        self.time_mix_r = jax.device_put(sd[f"blocks.{i}.att.time_mix_r"].astype(dtype).squeeze())
        self.value = jax.device_put(jnp.transpose(sd[f"blocks.{i}.att.value.weight"].astype(dtype)))

In [11]:
att = Attention(rwkv, 0)
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1024,), dtype=att.dtype)
state = jnp.array([0.0]*5*1024, dtype=jnp.float32).reshape((5,1024))
# AT(x, state, 0, att.time_mix_k, att.time_mix_v, att.time_mix_r, att.key, att.output, att.receptance, att.time_decay, att.time_first, att.value)
%timeit AT(x, state, 0, att.time_mix_k, att.time_mix_v, att.time_mix_r, att.key, att.output, att.receptance, att.time_decay, att.time_first, att.value)

AttributeError: 'numpy.ndarray' object has no attribute 't'

In [16]:
# @jax.jit
def FFN(x, state, i: int, key, receptance, time_mix_k, time_mix_r, value):
    xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
    xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
    state = state.at[5*i+0].set(x)

    r = jax.nn.sigmoid(xr @ receptance)
    k = jnp.square(jax.nn.relu(xk @ key))
    # print(i, k, value)
    kv = k @ value

    return r * kv, state

class Ffn():
    def __init__(self, sd, i, dtype=jnp.float32):
        self.dtype = dtype
        self.key = jax.device_put(jnp.transpose(sd[f"blocks.{i}.ffn.key.weight"].astype(dtype)))
        self.receptance = jax.device_put(jnp.transpose(sd[f"blocks.{i}.ffn.receptance.weight"].astype(dtype)))
        self.time_mix_k = jax.device_put(sd[f"blocks.{i}.ffn.time_mix_k"].astype(dtype).squeeze())
        self.time_mix_r = jax.device_put(sd[f"blocks.{i}.ffn.time_mix_r"].astype(dtype).squeeze())
        self.value = jax.device_put(jnp.transpose(sd[f"blocks.{i}.ffn.value.weight"].astype(dtype)) / (2 ** (i // 6)))
        # print(i, self.value)

In [6]:
ffn = Ffn(rwkv, 0)
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1024,), dtype=att.dtype)
state = jnp.array([0.0]*5*1024, dtype=jnp.float32).reshape((5,1024))
# FFN(x, state, 0, ffn.key, ffn.receptance, ffn.time_mix_k, ffn.time_mix_r, value)
%timeit FFN(x, state, 0, ffn.key, ffn.receptance, ffn.time_mix_k, ffn.time_mix_r, ffn.value)

NameError: name 'att' is not defined

In [6]:
# @jax.jit
def LN(x, w, b):
    mean = jnp.mean(x)
    v = jnp.var(x)
    o = x - mean
    i = w * jax.lax.rsqrt(v + 1e-5)
    return o * i + b

class Ln():
    def __init__(self, sd, n, dtype=jnp.float32):
        self.dtype = dtype
        self.weight = jax.device_put(sd[f"{n}.weight"].astype(dtype))
        self.bias = jax.device_put(sd[f"{n}.bias"].astype(dtype))

In [18]:
class Block():
    def __init__(self, sd, i, dtype=jnp.float32):
        self.dtype = dtype
        self.ln1 = Ln(sd, f"blocks.{i}.ln1", dtype=jnp.float32)
        self.att = Attention(sd, i, dtype=dtype)
        self.ln2 = Ln(sd, f"blocks.{i}.ln2", dtype=jnp.float32)
        self.ffn = Ffn(sd, i, dtype=dtype)

def BLOCK(x, state, i, block: Block):
    xx = LN(x, block.ln1.weight, block.ln1.bias)
    xx, state = AT(xx, state, i, block.att.time_mix_k, block.att.time_mix_v, block.att.time_mix_r, block.att.key, block.att.output, block.att.receptance, block.att.time_decay, block.att.time_first, block.att.value)
    x = x + xx
    xx = LN(x, block.ln2.weight, block.ln2.bias)
    xx, state = FFN(xx, state, i, block.ffn.key, block.ffn.receptance, block.ffn.time_mix_k, block.ffn.time_mix_r, block.ffn.value)
    return x + xx, state


In [27]:
block = Block(rwkv, 0)
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1024,), dtype=block.dtype)
state = jnp.array([0.0]*5*1024, dtype=jnp.float32).reshape((5,1024))
print(BLOCK(x, state, 0, block))
# %timeit BLOCK(x, state, 0, block)
del block, x, state, key

TypeError: dot_general requires contracting dimensions to have the same shape, got (1024,) and (4096,).

In [14]:
class RWKV():
    def __init__(self, sd, blks, dtype=jnp.float32):
        emb = jax.device_put(sd[f"emb.weight"].astype(dtype))
        ln0 = Ln(sd, "blocks.0.ln0", dtype=jnp.float32)
        ln = jax.vmap(LN, (0, None, None), 0)
        self.emb = ln(emb, ln0.weight, ln0.bias)
        self.blocks = [Block(sd, i, dtype=dtype) for i in range(blks)]
        self.ln_out = Ln(sd, "ln_out", dtype=jnp.float32)
        self.head = jax.device_put(sd[f"head.weight"].astype(dtype))
    
    def __call__(self, x, state):
        x = self.emb[x]
        for i,b in enumerate(self.blocks):
            x, state = BLOCK(x, state, i, b)
        x = LN(x, self.ln_out.weight, self.ln_out.bias)
        return self.head @ x, state

In [19]:
from jax.config import config
config.update("jax_debug_nans", True)
config.update("jax_debug_infs", True)
config.update("jax_disable_jit", True)

rwkv_ = RWKV(rwkv, blks, dtype=jnp.float16)
key = jax.random.PRNGKey(0)
state = jnp.array(([0.0]*1024*4 + [-1e30]*1024)*blks, dtype=jnp.float32).reshape((5*blks,1024))
print(rwkv_.emb.shape, state[4+5])
rwkv_(147, state)
# %timeit rwkv_(0, state) 

(50277, 1024) [-1.e+30 -1.e+30 -1.e+30 ... -1.e+30 -1.e+30 -1.e+30]
0 [0. 0. 0. ... 0. 0. 0.] [[ 0.1128   -0.1113   -0.4688   ... -0.08496  -0.2393    0.05493 ]
 [ 0.2207   -0.4395    0.08154  ...  0.3496    0.1865    0.6406  ]
 [-0.1396    0.0238    0.1396   ...  0.2617    0.2139   -0.332   ]
 ...
 [-0.0786   -0.012695 -0.1377   ... -0.1299   -0.3828   -0.2246  ]
 [ 0.3164   -0.1289   -0.4766   ...  0.4238    0.1484    0.1777  ]
 [ 0.02917   0.8125   -0.1416   ... -0.1196   -0.2168   -0.248   ]]
1 [0. 0. 0. ... 0. 0. 0.] [[-3.516e-02 -6.250e-01  3.457e-01 ...  1.099e-02  2.026e-02  1.445e-01]
 [ 1.338e-01  2.124e-02  2.773e-01 ... -2.344e-02 -2.217e-01 -1.855e-01]
 [ 1.934e-01 -3.906e-01  2.734e-01 ...  6.104e-02  1.514e-01  3.555e-01]
 ...
 [ 9.863e-02  1.118e-01 -1.680e-01 ...  4.355e-01 -3.340e-01  1.543e-01]
 [ 9.375e-02 -7.178e-02 -3.066e-01 ... -1.045e-01  8.398e-02 -3.770e-01]
 [ 4.883e-04 -2.988e-01  3.262e-01 ... -2.852e-01  3.223e-01  1.434e-02]]
2 [0. 0. 0. ... 0. 0. 0.] [[

(Array([ -0.3326708 , -15.422455  ,   0.89521873, ...,  -2.268542  ,
         -1.8543549 ,  -0.48138535], dtype=float32),
 Array([[ 0.09390897, -0.0342318 ,  0.02709531, ...,  0.08806008,
          0.07034266, -0.07390796],
        [ 0.21836135, -0.03054377, -0.11271324, ..., -0.00916648,
          0.0862558 ,  0.09088991],
        [ 0.97006464, -0.02521658, -0.43896416, ...,  0.04065698,
          0.51248145,  1.3487475 ],
        ...,
        [-0.7317802 , -1.2414354 , -4.2590995 , ...,  2.505656  ,
          1.2210972 , -1.3080119 ],
        [ 1.        ,  1.        ,  1.        , ...,  1.        ,
          1.        ,  1.        ],
        [ 2.0876906 ,  0.30494115,  4.2117324 , ...,  2.5139866 ,
          4.4783044 ,  2.0531316 ]], dtype=float32))