In [1]:
import haiku as hk
import jax
import jax.numpy as jnp
import jmp
import model

In [None]:
def _forward(x):
    """Instantiate the RWKV network and apply it to ``x``.

    The module is constructed inside this function (rather than at cell
    level) so that it can be wrapped with ``hk.transform``.

    Args:
        x: input passed straight through to the ``model.RWKV`` instance.

    Returns:
        Whatever the ``model.RWKV`` instance returns for ``x``.
    """
    rwkv = model.RWKV(
        num_layers=12,
        vocab_size=50277,
        n_embd=1024,
        dim_att=1024,
        dim_ffn=4 * 1024,
    )
    return rwkv(x)

# Mixed-precision policies: half precision everywhere for the network, and
# float32 params/compute with float16 output for the WKV operator.
policy = jmp.get_policy('p=f16,c=f16,o=f16')
policy_wkv = jmp.get_policy('p=f32,c=f32,o=f16')

# Haiku looks policies up by hk.Module *class*. Registering the plain
# `_forward` function is silently ignored, so the policy is attached to the
# module class that `_forward` instantiates instead.
hk.mixed_precision.set_policy(model.RWKV, policy)
# NOTE(review): the same rule applies here -- confirm `model.WKV` is an
# hk.Module class. Elsewhere in this notebook it is called directly with
# arrays, which suggests a plain function; in that case this line is a no-op.
hk.mixed_precision.set_policy(model.WKV, policy_wkv)

forward_fn = hk.transform(_forward)

# Dummy token ids: 1024 integer ones, used only to trace shapes during init.
dummy_x = jnp.ones((1024,), dtype=int)
rng_key = jax.random.PRNGKey(42)

params = forward_fn.init(rng=rng_key, x=dummy_x)

# Show a (shape, dtype) summary rather than dumping every weight of the model.
jax.tree_util.tree_map(lambda p: (p.shape, p.dtype), params)

In [2]:
# Sanity check: `arange` builds the same int32 vector as the explicit literal.
(
    jnp.arange(5),
    jnp.array([0, 1, 2, 3, 4]),
)

(Array([0, 1, 2, 3, 4], dtype=int32), Array([0, 1, 2, 3, 4], dtype=int32))

In [3]:
# The comma separates one indexer per axis: `jnp.newaxis` inserts a new leading
# axis and `:` keeps the existing one, turning shape (5,) into (1, 5).
jnp.arange(5)[jnp.newaxis, :]

Array([[0, 1, 2, 3, 4]], dtype=int32)

In [4]:
# Transposing a (2, 3) matrix swaps rows and columns, giving shape (3, 2).
jnp.transpose(jnp.array([[1, 2, 3], [4, 5, 6]]))

Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)

In [9]:
# Stand-in for the 1-D time-decay / time-first vectors: turn a length-3 vector
# into a (3, 1) column in two equivalent ways.
# Both forms are derived from the same 1-D input. Calling `expand_dims` on the
# already-reshaped `a` would add a second axis and give shape (3, 1, 1).
vec = jnp.array([1, 2, 3])
a = vec[:, jnp.newaxis]
aa = jnp.expand_dims(vec, 1)
aa, a, a.shape

(Array([[1],
        [2],
        [3]], dtype=int32),
 Array([[1],
        [2],
        [3]], dtype=int32),
 (3, 1))

In [10]:
T = 3 + 1
# Row of time offsets -(T-2)..0, shape (1, T-1).
b = jnp.arange(2 - T, 1)[jnp.newaxis, :]
# Broadcasting the (3, 1) column `a` against the (1, T-1) row `b` yields one
# scaled offset row per entry of `a`; `a` itself is appended as the last column.
matrix = jnp.concatenate((a * b, a), axis=1)
a, b, a * b, matrix

(Array([[1],
        [2],
        [3]], dtype=int32),
 Array([[-2, -1,  0]], dtype=int32),
 Array([[-2, -1,  0],
        [-4, -2,  0],
        [-6, -3,  0]], dtype=int32),
 Array([[-2, -1,  0,  1],
        [-4, -2,  0,  2],
        [-6, -3,  0,  3]], dtype=int32))

In [7]:
@jax.jit
def wkv(w, u, k, v):
    """Causal exp(k)-weighted average of ``v`` via a depthwise 1-D convolution.

    For every channel, each output step is
    ``sum_i kernel_i * exp(k_i) * v_i / sum_i kernel_i * exp(k_i)`` over the
    current and all earlier time steps. The current step is weighted by
    ``exp(u)`` and a step ``d`` positions before the previous one by
    ``exp(-d * exp(-exp(w)))``.

    Args:
        w: 1-D array of length C, per-channel decay parameter.
        u: 1-D array of length C, per-channel log-weight of the current step.
        k: (T, C) array of keys.
        v: array of values, same shape as ``k``.

    Returns:
        Array of shape (C, T) holding the weighted averages.
        NOTE(review): the result is channel-major while the inputs are
        time-major -- confirm callers expect (C, T) rather than (T, C).
    """
    T, C = k.shape
    # Offsets -(T-2)..0 for the T-1 kernel taps that precede the current step.
    time_curve = jnp.arange(-T + 2, 1)[jnp.newaxis, ...]

    # exp(k) overflows float32 once k exceeds ~88, which turns the ratio below
    # into inf/NaN. Numerator and denominator are both linear in exp(k) per
    # channel, so a per-channel shift of k cancels exactly. Centring each
    # channel's keys around zero keeps the largest and smallest exponentials
    # as far from overflow and underflow as possible.
    shift = jax.lax.stop_gradient(
        0.5 * (jnp.max(k, axis=0) + jnp.min(k, axis=0)))

    # (T, C) -> (1, C, T): the 'NCW' layout expected by the convolution.
    ek = jnp.exp((k - shift).T[jnp.newaxis])
    ekv = ek * v.T[jnp.newaxis]

    # NOTE(review): the per-step log-decay is exp(-exp(w)), i.e. `w` goes
    # through two exponentials -- confirm this is the intended parameterisation.
    decay = jnp.exp(-jnp.exp(w))
    ew_time = jnp.expand_dims(decay, 1) * time_curve
    time_w = jnp.concatenate([ew_time, jnp.expand_dims(u, 1)], axis=1)
    # (C, T) -> (C, 1, T): one single-input-channel filter per channel ('OIW').
    kernel = jnp.expand_dims(jnp.exp(time_w), 1)

    def causal_pad(x):
        # Left-pad the time axis so output step t only sees steps <= t.
        return jnp.pad(x, [(0, 0), (0, 0), (T - 1, 0)])

    def depthwise_conv(x):
        # feature_group_count=C makes every channel use only its own filter.
        return jax.lax.conv_general_dilated(
            causal_pad(x), kernel, (1,), [(0, 0)],
            dimension_numbers=('NCW', 'OIW', 'NCW'),
            feature_group_count=C)

    weighted_sum = depthwise_conv(ekv)
    normaliser = depthwise_conv(ek)
    # Drop the batch axis; the result stays (C, T).
    return (weighted_sum / normaliser)[0]

%timeit wkv(jnp.array([1. , 2., 3.]), jnp.array([0.1, 0.2, 0.3]), jnp.array([[11, 22, 33], [44, 55, 66], [77, 88, 99]], dtype=jnp.float32), jnp.array([[11, 22, 33], [44, 55, 66], [77, 88, 99]], dtype=jnp.float32))

564 µs ± 13.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [10]:
%timeit model.WKV(jnp.array([1. , 2., 3.]), jnp.array([0.1, 0.2, 0.3]), jnp.array([[11, 22, 33], [44, 55, 66], [77, 88, 99]], dtype=jnp.float32), jnp.array([[11, 22, 33], [44, 55, 66], [77, 88, 99]], dtype=jnp.float32))

556 µs ± 14.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
