In [1]:
import jax
import jax.numpy as jnp
import jax.lax as lax
import haiku as hk
from einops import rearrange, reduce, repeat, einsum
from functools import partial
from dataclasses import dataclass

In [2]:
@dataclass
class ModelArgs:
    """Hyper-parameters for the Transformer defined in this notebook.

    Fields are passed by keyword in this notebook; their declaration order is also the
    positional order of the generated constructor.
    """
    dim: int  # model width: embedding size and output width of `wo` and `w2`
    n_layers: int  # number of TransformerBlocks; also the size of the cache's layer axis
    head_dim: int  # per-head width of queries, keys and values
    hidden_dim: int  # inner width of the SwiGLU feed-forward (outputs of w1 and w3)
    n_heads: int  # number of query heads; must be a multiple of n_kv_heads for the GQA rearrange
    n_kv_heads: int  # number of key/value heads shared by groups of query heads
    sliding_window: int  # attention window and number of slots in the ring-buffer KV cache
    norm_eps: float  # epsilon for every RMSNorm layer
    vocab_size: int  # token vocabulary size (embedding rows and logits width)
    max_batch_size: int = 0  # batch capacity of the KV cache allocated in the notebook; 0 presumably means "unset" — confirm

In [3]:
# Notebook-wide PRNG stream seeded with 42; each nk() call yields a fresh key.
key_seq = hk.PRNGSequence(42)


def nk():
    """Return the next PRNG key drawn from `key_seq`."""
    return next(key_seq)

In [4]:
def full_size_args():
    """Return the full-size reference configuration (not used in this notebook).

    Kept as a function instead of a string literal so it is syntax-checked, can be
    instantiated on demand, and does not overwrite IPython's `_` output-history variable.
    NOTE(review): the values appear to match the published Mistral-7B params — confirm.
    """
    return ModelArgs(
        dim=4096,
        n_layers=32,
        head_dim=128,
        hidden_dim=14336,
        n_heads=32,
        n_kv_heads=8,
        norm_eps=1e-05,
        sliding_window=4096,
        vocab_size=32000,
        max_batch_size=4,
    )

In [5]:
from jax._src.lax.control_flow.loops import _interleave


def precompute_freqs_cis(dim, end, theta=10000.0):
    """Return the rotary phase table exp(i * t * w_k) for positions t in [0, end).

    The result has shape (end, dim // 2): one complex phase per position and channel
    pair, with inverse frequencies w_k = theta ** (-2k / dim).
    """
    n_pairs = dim // 2
    exponents = jnp.arange(0, dim, 2)[:n_pairs] / dim
    inv_freq = 1.0 / (theta ** exponents)
    angles = jnp.outer(jnp.arange(0, end), inv_freq).astype(jnp.float32)
    return jax.lax.complex(jnp.cos(angles), jnp.sin(angles))


def _rotate_pairs(x, freqs_cis):
    """Rotate adjacent channel pairs (x[..., 2k], x[..., 2k+1]) of `x` by the phases `freqs_cis`.

    Computes in float32 and casts back to x.dtype. The last axis of `x` must be even and
    `x[..., ::2]` must broadcast against `freqs_cis`.
    """
    x32 = x.astype(jnp.float32)
    rotated = jax.lax.complex(x32[..., ::2], x32[..., 1::2]) * freqs_cis
    # Put (real, imag) back into adjacent channels: (..., d/2, 2) -> (..., d).
    # Uses only public jnp ops instead of the private jax._src `_interleave` helper,
    # which is not part of JAX's API and may move or disappear between releases.
    pairs = jnp.stack([jnp.real(rotated), jnp.imag(rotated)], axis=-1)
    return pairs.reshape(*pairs.shape[:-2], -1).astype(x.dtype)


def apply_rotary_emb(xq, xk, freqs_cis):
    """Apply rotary position embeddings to queries and keys.

    Args:
        xq, xk: arrays whose last axis is head_dim (even) and whose second-to-last axis is
            the sequence axis, e.g. the (b, g, nkv, l, dh) queries and (b, nkv, l, dh) keys
            built in Attention.
        freqs_cis: complex phases of shape (l, head_dim // 2), e.g. rows of
            precompute_freqs_cis selected by position.

    Returns:
        (xq_rot, xk_rot) with the same shapes and dtypes as xq and xk.
    """
    return _rotate_pairs(xq, freqs_cis), _rotate_pairs(xk, freqs_cis)


def get_read_idxs(i, window):
    """Return ring-buffer slots for the last min(window, i + 1) positions ending at i, oldest first.

    `i` must be a concrete integer because the length of the result depends on it.
    """
    oldest = max(0, i - window + 1)
    return jnp.arange(oldest, i + 1) % window


class Attention(hk.Module):
    """Grouped-query self-attention with rotary embeddings and a ring-buffer KV cache.

    Constructor args:
        args: ModelArgs (reads dim, head_dim, n_heads, n_kv_heads, sliding_window).
        name: optional Haiku module name.
    """

    def __init__(self, args, name=None):
        super().__init__(name)
        self.args = args
        self.scale = args.head_dim ** -0.5

    def __call__(self, x, freqs_cis, positions, read_idxs, cache, mask, use_cache):
        """Attend over `x`; return (output of shape (b, l, dim), updated (k_cache, v_cache)).

        Args:
            x: (b, l, dim) normalized activations.
            freqs_cis: rotary phases for `positions`, broadcastable against (l, head_dim // 2).
            positions: (l,) absolute positions of the tokens in `x`.
            read_idxs: cache slots to attend over when `use_cache` is true. Assumed to be in
                chronological order and to end at the slot of positions[-1], as produced by
                get_read_idxs(positions[-1], sliding_window) — TODO confirm for other callers.
            cache: (k_cache, v_cache), each indexed as [batch, kv_head, slot, head_dim] with
                `sliding_window` slots.
            mask: additive (l, l) mask used by the uncached path.
            use_cache: True -> attend over cached keys/values; False -> over this call's.
        """
        args = self.args
        wq = hk.Linear(args.n_heads * args.head_dim, with_bias=False, name='wq')
        wk = hk.Linear(args.n_kv_heads * args.head_dim, with_bias=False, name='wk')
        wv = hk.Linear(args.n_kv_heads * args.head_dim, with_bias=False, name='wv')
        wo = hk.Linear(args.dim, with_bias=False, name='wo')

        b, _, _ = x.shape
        # Split the query heads into g groups sharing each of the n_kv_heads kv heads (GQA).
        # NOTE(review): with '(g nkv dh)' query head h uses kv head h % n_kv_heads. Reference
        # Mistral/Llama code repeats kv heads with repeat_interleave, pairing head h with kv
        # head h // g ('(nkv g dh)'). Irrelevant for freshly initialised weights, but confirm
        # the layout before loading a pretrained checkpoint.
        q = rearrange(wq(x), 'b l (g nkv dh) -> b g nkv l dh', dh=args.head_dim, nkv=args.n_kv_heads)
        k = rearrange(wk(x), 'b l (nkv dh) -> b nkv l dh', dh=args.head_dim)
        v = rearrange(wv(x), 'b l (nkv dh) -> b nkv l dh', dh=args.head_dim)
        q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)

        # Write (at most) the newest sliding_window keys/values into their ring-buffer slots.
        write_idxs = positions[-args.sliding_window:] % args.sliding_window
        k_cache, v_cache = cache
        k_cache = k_cache.at[:b, :, write_idxs].set(k[:, :, -args.sliding_window:])
        v_cache = v_cache.at[:b, :, write_idxs].set(v[:, :, -args.sliding_window:])

        k_ = k_cache[:b, :, read_idxs]
        v_ = v_cache[:b, :, read_idxs]

        # Causal mask for the cached path. With a zero mask, processing several tokens with
        # use_cache=True (e.g. a whole prompt) lets each query attend to later tokens of the
        # same call. Key j read from the cache sits at absolute position
        # positions[-1] - n_read + 1 + j; a query may only see keys not after its own position.
        # For single-token decoding every key qualifies, so this mask is all zeros.
        # Processing more than sliding_window tokens at once with use_cache=True can leave the
        # earliest queries with no visible key (NaN softmax); use the uncached path for that.
        n_read = jnp.shape(read_idxs)[0]
        key_positions = positions[-1] - n_read + 1 + jnp.arange(n_read)
        cache_mask = jnp.where(
            key_positions[None, :] <= positions[:, None], 0.0, -jnp.inf
        ).astype(q.dtype)

        attention = partial(self.attention, q=q)
        out = lax.cond(
            use_cache,
            lambda: attention(k=k_, v=v_, mask=cache_mask),
            lambda: attention(k=k, v=v, mask=mask),
        )

        # Output projection plus the updated cache.
        return wo(out), (k_cache, v_cache)

    def attention(self, q, k, v, mask):
        """Scaled dot-product attention for grouped query heads.

        q: (b, g, nkv, l_q, dh); k, v: (b, nkv, l_k, dh); mask: additive, broadcastable to
        (b, g, nkv, l_q, l_k). Returns (b, l_q, g * nkv * dh).
        """
        scores = einsum(q, k, 'b g h i d, b h j d -> b g h i j') * self.scale
        probs = jax.nn.softmax(scores + mask, axis=-1).astype(q.dtype)
        heads = einsum(probs, v, 'b g h i j, b h j d -> b g h i d')
        return rearrange(heads, 'b g nkv l dh -> b l (g nkv dh)')

In [6]:
class FeedForward(hk.Module):
    """SwiGLU MLP computing w2(silu(w1(x)) * w3(x)) with bias-free projections."""

    def __init__(self, args):
        super().__init__()
        self.args = args

    def __call__(self, x):
        cfg = self.args
        specs = (('w1', cfg.hidden_dim), ('w2', cfg.dim), ('w3', cfg.hidden_dim))
        w1, w2, w3 = [hk.Linear(width, with_bias=False, name=label) for label, width in specs]
        # Call order is w1, then w3, then w2, so parameters are created (and init-time RNG
        # keys consumed) in that sequence.
        gate = jax.nn.silu(w1(x))
        up = w3(x)
        return w2(gate * up)

In [7]:
class TransformerBlock(hk.Module):
    """One pre-norm decoder layer: attention and SwiGLU feed-forward, each in a residual branch."""

    def __init__(self, args):
        super().__init__()
        self.args = args

    def __call__(self, x, freqs_cis, positions, read_idxs, cache, mask, use_cache):
        """Return (layer output, this layer's updated (k_cache, v_cache)).

        Computes h = x + attention(rms_norm(x)) and out = h + feed_forward(rms_norm(h)).
        """
        cfg = self.args
        # Keep this construction order stable: it determines Haiku's auto-generated module
        # names (notably which RMSNorm gets which name) and hence the parameter paths.
        attn = Attention(cfg)
        mlp = FeedForward(cfg)
        pre_attn_norm = hk.RMSNorm(-1, eps=cfg.norm_eps)
        pre_mlp_norm = hk.RMSNorm(-1, eps=cfg.norm_eps)

        attn_out, new_cache = attn(
            pre_attn_norm(x), freqs_cis, positions, read_idxs, cache, mask, use_cache
        )
        residual = x + attn_out
        mlp_out = mlp(pre_mlp_norm(residual))
        return residual + mlp_out, new_cache

In [8]:
class Transformer(hk.Module):
    """Decoder-only Transformer: embedding -> n_layers TransformerBlocks -> RMSNorm -> vocab logits.

    Constructor args:
        args: ModelArgs with the model hyper-parameters.
        max_seq_len: number of rows in the precomputed rotary table; positions should be
            below it.
    """

    def __init__(self, args, max_seq_len):
        super().__init__()
        self.args = args
        # Rotary phases for every position < max_seq_len, indexed by `positions` in __call__.
        self.freqs_cis = precompute_freqs_cis(args.head_dim, max_seq_len)

    def __call__(self, input_ids, positions, read_idxs, cache, use_cache=False):
        """Return (float32 logits of shape (b, l, vocab_size), updated (k_cache, v_cache)).

        input_ids: (b, l) token ids; positions: (l,) absolute positions.
        cache: (k_cache, v_cache) indexed as [batch, layer, ...]; updated functionally per layer.
        """
        cfg = self.args
        # Keep this construction order stable: it determines Haiku's auto-generated module
        # names and therefore the parameter paths.
        final_norm = hk.RMSNorm(-1, eps=cfg.norm_eps)
        embed = hk.Embed(cfg.vocab_size, cfg.dim)
        lm_head = hk.Linear(cfg.vocab_size, with_bias=False)
        blocks = [TransformerBlock(cfg) for _ in range(cfg.n_layers)]

        b, seq_len = input_ids.shape
        h = embed(input_ids)
        freqs_cis = self.freqs_cis[positions]

        # Additive band mask: query i may see key j when i - sliding_window <= j <= i
        # (0 where allowed, -inf elsewhere).
        # NOTE(review): that admits sliding_window + 1 keys, while the KV cache holds only
        # sliding_window slots — confirm which width is intended.
        rows = jnp.arange(seq_len)[:, None]
        cols = jnp.arange(seq_len)[None, :]
        allowed = (cols <= rows) & (cols >= rows - cfg.sliding_window)
        mask = jnp.where(allowed, 0.0, -jnp.inf).astype(h.dtype)

        k_cache, v_cache = cache
        for layer_idx, block in enumerate(blocks):
            layer_cache = (k_cache[:b, layer_idx], v_cache[:b, layer_idx])
            h, (new_k, new_v) = block(h, freqs_cis, positions, read_idxs, layer_cache, mask, use_cache)
            k_cache = k_cache.at[:b, layer_idx].set(new_k)
            v_cache = v_cache.at[:b, layer_idx].set(new_v)

        logits = lm_head(final_norm(h))
        return logits.astype(jnp.float32), (k_cache, v_cache)

In [9]:
# Demo prompt length, and the number of positions covered by the rotary table.
seq_len, max_seq_len = 8, 512

In [10]:
# Tiny configuration for quick experiments: 12 query heads of width 6 sharing 4 KV heads.
my_args = ModelArgs(
    # attention geometry
    n_heads=12,
    n_kv_heads=4,
    head_dim=6,
    dim=12 * 6,  # n_heads * head_dim
    # network size
    n_layers=4,
    hidden_dim=int(12 * 6 * 3.5),  # 3.5x the model width
    vocab_size=10000,
    # attention window / KV-cache length, cache batch capacity, norm epsilon
    sliding_window=8,
    max_batch_size=4,
    norm_eps=1e-5,
)

In [11]:
@hk.transform
def f(input_ids, positions, read_idxs, cache, use_cache):
    """Haiku-transformed forward pass of the demo model; returns (logits, updated cache)."""
    return Transformer(my_args, max_seq_len)(input_ids, positions, read_idxs, cache, use_cache)


# Demo inputs: one random prompt of `seq_len` tokens at positions 0..seq_len-1.
input_ids = jax.random.randint(nk(), (1, seq_len), 0, my_args.vocab_size)
positions = jnp.arange(input_ids.shape[1])

# Empty KV cache, laid out as (batch, layer, kv_head, ring-buffer slot, head_dim),
# one array for keys and one for values.
cache_shape = (
    my_args.max_batch_size,
    my_args.n_layers,
    my_args.n_kv_heads,
    my_args.sliding_window,
    my_args.head_dim,
)
cache = (jnp.zeros(cache_shape), jnp.zeros(cache_shape))

# The additive attention mask is built inside Transformer, so none is defined here.
# NOTE(review): the whole prompt goes through the cached attention path (use_cache=True);
# make sure that path masks future keys when more than one token is processed.
use_cache = True

In [12]:
# get_read_idxs with the window fixed to the model's sliding window.
rd_idx = partial(get_read_idxs, window=my_args.sliding_window)
# Slots covering the prompt, ending at its last position (seq_len - 1).
init_read_idxs = rd_idx(seq_len - 1)
params = f.init(nk(), input_ids, positions, init_read_idxs, cache, use_cache)

TypeError: 'TransformerBlock' object is not callable

In [None]:
# Plain and jit-compiled versions of the same apply function, timed in the next cells.
fn = f.apply
jit_fn = jax.jit(f.apply)

In [None]:
%timeit out, (kh, vh) = fn(params, nk(), input_ids, positions, rd_idx(seq_len - 1), cache, use_cache)

In [None]:
%timeit out, (kh, vh) = jit_fn(params, nk(), input_ids, positions, rd_idx(seq_len - 1), cache, use_cache)