In [None]:
from flax import linen as nn
import jax
from jax import numpy as jnp
from jax import random
import optax


# Hyper-parameters shared by every cell below.
vocab_size = 64  # number of distinct token ids (embedding rows / logit width)
batch_size = 8   # sequences per training batch
# Maximum context length, embedding width and q/k/v projection width are all
# set to the same value here.
block_size = n_embd = head_size = 2_048


class QKV(nn.Module):
    """Bias-free query/key/value projections with an optional key/value cache.

    The cache is a plain ``dict`` that is mutated in place; after a cached
    inference call it holds the keys ``'k'`` and ``'v'``.
    """

    @nn.compact
    def __call__(self, x, training=True, kvcache=None):
        """Project ``x`` to ``(q, k, v)``.

        Args:
            x: Rank-3 input indexed as ``(batch, time, features)``.
            training: When true the cache is ignored entirely.
            kvcache: ``None`` to disable caching, otherwise a dict. An empty
                dict triggers a "prefill" over the whole of ``x``; a populated
                dict triggers an incremental step over the last time step only.

        Returns:
            ``(q, k, v)``. During an incremental step ``q`` covers only the last
            time step while ``k``/``v`` cover the cached history plus that step,
            cropped to the most recent ``block_size`` entries.
        """
        inference_with_kvcache_enabled = not training and kvcache is not None
        # Decide this *before* the cache is written below. On the very first
        # cached call the dict is still empty, so the full prompt must be
        # projected; slicing it to the last token here would drop every earlier
        # prompt token from the keys/values.
        cache_populated = inference_with_kvcache_enabled and bool(kvcache)
        if cache_populated:
            # Earlier time steps are already in the cache; only the newest
            # token needs projecting.
            x = x[:, -1:, :]

        # Keep the creation order of the three Dense layers: Flax derives the
        # parameter names (Dense_0/1/2) from it.
        q = nn.Dense(head_size, use_bias=False)(x)
        k = nn.Dense(head_size, use_bias=False)(x)
        v = nn.Dense(head_size, use_bias=False)(x)

        if inference_with_kvcache_enabled:
            if cache_populated:
                # Append along the time axis and keep a sliding window of at
                # most `block_size` entries.
                k = jnp.concatenate((kvcache['k'], k), axis=1)[:, -block_size:, :]
                v = jnp.concatenate((kvcache['v'], v), axis=1)[:, -block_size:, :]
            kvcache['k'] = k
            kvcache['v'] = v

        return q, k, v


class Head(nn.Module):
    """Single-head causal self-attention language model over integer token ids."""

    @nn.compact
    def __call__(self, context, training=True, kvcache=None):
        """Compute next-token logits for a batch of token sequences.

        Args:
            context: Integer token ids indexed as ``(batch, time)``. Only the
                last ``block_size`` time steps are used.
            training: Forwarded to ``QKV``; when true the cache is not used.
            kvcache: ``None`` or a dict mutated in place. During inference the
                attention cache is stored under the ``'qkv'`` key.

        Returns:
            Logits of shape ``(batch, time_q, vocab_size)`` where ``time_q`` is
            the number of query rows returned by ``QKV``.
        """
        context_block = context[:, -block_size:]
        tok_embed = nn.Embed(vocab_size, n_embd)(context_block)
        # Positions are numbered relative to the cropped window.
        pos = jnp.arange(0, context_block.shape[1])
        pos_embed = nn.Embed(block_size, n_embd)(pos)
        x = tok_embed + pos_embed

        _kvcache = kvcache
        if not training and kvcache is not None:
            # Give the attention projections their own sub-dict of the cache.
            _kvcache = kvcache.setdefault('qkv', {})
        q, k, v = QKV()(x, training=training, kvcache=_kvcache)

        # Scaled dot-product scores; the 1/sqrt(d) factor keeps the softmax from
        # saturating when the projection width is large.
        wei = (q @ jnp.swapaxes(k, -1, -2)) * (q.shape[-1] ** -0.5)

        # Causal mask: a query may only attend to keys at or before its own
        # position. Queries are aligned to the *end* of the key sequence, so the
        # diagonal is shifted by (t_k - t_q); a single cached query row therefore
        # sees every key, and the full-sequence case gets the usual lower
        # triangle.
        t_q, t_k = q.shape[1], k.shape[1]
        causal = jnp.tril(jnp.ones((t_q, t_k), dtype=bool), k=t_k - t_q)
        # A large finite negative (rather than -inf) avoids NaNs in the softmax.
        wei = jnp.where(causal, wei, jnp.finfo(wei.dtype).min)

        wei = nn.softmax(wei, axis=-1)
        out = wei @ v
        logits = nn.Dense(vocab_size)(out)
        return logits

    def loss(self, params, context, labels, training):
        """Mean softmax cross-entropy between the model logits and integer ``labels``."""
        logits = self.apply(params, context, training=training)
        return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

    def generate(self, prng_key, params, context, max_new_tokens=1, use_kvcache=True):
        """Sample ``max_new_tokens`` tokens and append them to ``context``.

        Args:
            prng_key: JAX PRNG key; it is split once per generated token.
            params: Parameters as returned by ``init``.
            context: Integer token ids indexed as ``(batch, time)``.
            max_new_tokens: Number of sampling steps.
            use_kvcache: When true a fresh cache dict is created and reused
                across the sampling steps.

        Returns:
            ``context`` with one extra column per generated token.
        """
        kvcache = {} if use_kvcache else None
        for _ in range(max_new_tokens):
            # Only the logits of the final position are needed to sample the
            # next token.
            logits = self.apply(params, context, training=False, kvcache=kvcache)[:, -1, :]
            prng_key, prng_subkey = random.split(prng_key)
            context_next = random.categorical(prng_subkey, logits)
            context = jnp.concatenate((context, jnp.reshape(context_next, (-1, 1))), axis=1)
        return context


In [2]:
# Base PRNG key and a minimal prompt: one sequence holding a single token id 0.
prng_key = random.key(0)
context = jnp.zeros(shape=(1, 1), dtype=int)

# Build the model and initialise its parameters from that prompt.
model = Head()
params = model.init(prng_key, context)

# Train

In [3]:
max_iters = 5

optimizer = optax.adamw(learning_rate=1e-3)

# Derive a separate key for each random draw. Passing the same key to both
# `randint` calls (same shape, same range) would make the labels an exact copy
# of the context.
init_key, context_key, labels_key = random.split(prng_key, 3)

params = model.init(init_key, jnp.zeros((1, block_size), dtype=int), training=True)
optimizer_state = optimizer.init(params)

grad_loss = jax.grad(model.loss, argnums=0)

# Use dedicated names for the training batch so the global `context` prompt is
# not overwritten; later cells that read `context` keep seeing the same prompt
# whether or not this cell has been run.
train_context = random.randint(context_key, (batch_size, block_size), 0, vocab_size)
train_labels = random.randint(labels_key, (batch_size, block_size), 0, vocab_size)

for i in range(max_iters):
    print(f"Iteration: {i}")
    grad = grad_loss(params, train_context, train_labels, training=True)
    updates, optimizer_state = optimizer.update(grad, optimizer_state, params)
    params = optax.apply_updates(params, updates)

# Serve

In [None]:
from functools import partial

def run_timeit(f, x):
    result = %timeit -o f(x)
    return result

max_new_tokens = [100, 200, 300, 400]

_ = model.generate(prng_key, params, context, 1, use_kvcache=False)
generator = lambda n: model.generate(prng_key, params, context, n, use_kvcache=False)
timeits_nocache = [run_timeit(generator, n) for n in max_new_tokens]

_ = model.generate(prng_key, params, context, 1, use_kvcache=True)
generator = lambda n: model.generate(prng_key, params, context, n, use_kvcache=True)
timeits_kvcache = [run_timeit(generator, n) for n in max_new_tokens]



In [None]:
import matplotlib.pyplot as plt

# Label both curves and the axes so the figure can be read on its own.
fig, ax = plt.subplots()
ax.plot(max_new_tokens, [t.average for t in timeits_nocache], marker="o", label="without KV cache")
ax.plot(max_new_tokens, [t.average for t in timeits_kvcache], marker="o", label="with KV cache")
ax.set(
    title="Generation time vs. number of new tokens",
    xlabel="max_new_tokens",
    ylabel="average time per call (s)",
)
ax.legend();

# TODO
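1. If I run the training before serving, the serving will take an extremely long time. Why? (Likely cause: serving uses whatever `context` is currently bound to — if that is a `(batch_size, block_size)` training batch rather than the `(1, 1)` prompt, every generation step processes a far larger input.)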
1. Starting input context may be [0, 1, 2].
1. What happens when there are multiple layers?
1. How to simplify code? For example, can we remove `if training`, not use `block_size` or `T`, and avoid using `if`?