JAX implementation of Andrej Karpathy's [nanoGPT Colab](https://colab.research.google.com/drive/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing)

In [None]:
import os

import wget

url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
filename = '/tmp/tinyshakespeare.txt'
# wget.download() does not overwrite an existing target; it saves a renamed
# copy instead. Reuse a previous download so re-running the cell is idempotent.
if not os.path.exists(filename):
    filename = wget.download(url, filename)
filename

In [None]:
with open(filename, "r", encoding="utf-8") as f:
    text = f.read()

print("length of dataset in characters: ", len(text))
# Show only a preview: echoing the whole corpus floods the notebook output.
print(text[:1000])

In [None]:
# Vocabulary: every distinct character of the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)

In [None]:
# Character-level tokenizer: each vocabulary character maps to its index and back.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {index: char for char, index in stoi.items()}


def encode(s):
    """Encoder: take a string, output a list of integers."""
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integers, output a string."""
    return ''.join(itos[i] for i in l)


print(encode("hii there"))
print(decode(encode("hii there")))

In [None]:
from jax import numpy as jnp
# Encode the whole corpus into a 1-D array of integer token ids.
data = jnp.array(encode(text))
print(data.shape)
data

In [None]:
# The first 90% of the token stream is used for training; the last 10% is held out for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [None]:
# Context length. A chunk of block_size + 1 tokens holds block_size
# (context, next-token) training examples.
block_size = 8
train_data[:block_size + 1]

In [None]:
# Each prefix of x is a context whose target is the token that follows it (y is x shifted by one).
x, y = train_data[:block_size], train_data[1:block_size + 1]
for t in range(block_size):
    context, target = x[:t + 1], y[t]
    print(f"when input is {context} the target: {target}")

In [None]:
from jax import random

seed = 1701
# Root PRNG key. Later cells split a subkey off it before each random draw.
prng_key = random.key(seed)

batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

def get_batch(prng_key, split):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
      prng_key: JAX PRNG key used to draw the window start offsets.
      split: 'train' selects ``train_data``; any other value selects ``val_data``.

    Returns:
      ``(x, y)``: integer arrays of shape ``(batch_size, block_size)``, where
      ``y`` is ``x`` shifted one token to the right. Assumes the split arrays
      are 1-D token streams.
    """
    data = train_data if split == 'train' else val_data
    # randint's upper bound is exclusive, so every window of block_size + 1
    # tokens (inputs plus the shifted targets) stays inside `data`.
    ix = random.randint(prng_key, (batch_size,), 0, len(data) - block_size)
    # Gather all windows with one vectorized index rather than slicing the
    # array once per batch row in a Python loop.
    offsets = ix[:, None] + jnp.arange(block_size)  # (batch_size, block_size)
    x = data[offsets]
    y = data[offsets + 1]
    return x, y

# Draw one training batch and show its inputs and targets.
prng_key, prng_subkey = random.split(prng_key)
xb, yb = get_batch(prng_subkey, 'train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')

# Spell out the block_size (context, target) examples packed into each batch row.
for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t + 1]
        target = yb[b, t]
        print(f"when input is {context.tolist()} the target: {target}")

In [None]:
import flax
import flax.linen as nn
import jax
import optax

class BigramLanguageModel(nn.Module):
    """Bigram model: each token's embedding row is read directly as next-token logits."""

    @nn.compact
    def __call__(self, idx):
        # A (vocab_size, vocab_size) table: row i holds the next-token logits for token i.
        logits = nn.Embed(vocab_size, vocab_size)(idx)
        return logits

    def loss(self, embeddings, idx, labels):
        """Mean cross-entropy of predicting ``labels`` from ``idx``.

        ``embeddings`` is the variables dict returned by ``init``.
        """
        # idx and labels are both (B, T) tensor of integers
        logits = self.apply(embeddings, idx) # (B, T, C)
        return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

    def generate(self, prng_key, embeddings, idx, max_new_tokens):
        """Append ``max_new_tokens`` sampled tokens to ``idx`` and return the result.

        ``idx`` is a (B, T) array of indices in the current context; the result is
        (B, T + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            # get the predictions
            logits = self.apply(embeddings, idx)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # Draw with the fresh subkey. Sampling with `prng_key` itself would
            # reuse a key that is split again on the next iteration.
            prng_key, prng_subkey = random.split(prng_key)
            idx_next = random.categorical(prng_subkey, logits) # (B,)
            # append sampled index to the running sequence
            idx = jnp.concatenate((idx, jnp.reshape(idx_next, (-1, 1))), axis=1) # (B, T+1)
        return idx

m = BigramLanguageModel()
prng_key, prng_subkey = random.split(prng_key)
# The Embed table's shape is fixed by its constructor arguments, so a scalar
# dummy index is enough for init.
embeddings = m.init(prng_subkey, jnp.array(0))
logits = m.apply(embeddings, xb)
print(logits.shape)
loss = m.loss(embeddings, xb, yb)
print(loss)

# Sample 100 tokens from the untrained model, starting from token 0.
prng_key, prng_subkey = random.split(prng_key)
print(decode(m.generate(prng_subkey, embeddings, idx=jnp.zeros((1, 1), dtype=int), max_new_tokens=100)[0].tolist()))


In [None]:
# Gradient of m.loss with respect to its first argument (the parameter pytree).
grad_loss = jax.grad(m.loss, argnums=0)

In [None]:
# AdamW with a fixed learning rate of 1e-3.
optimizer = optax.adamw(learning_rate=1e-3)

In [None]:
# Larger batches for training; get_batch reads this global.
batch_size = 32

optimizer_state = optimizer.init(embeddings)

for steps in range(100): # increase number of steps for good results...
    prng_key, prng_subkey = random.split(prng_key)
    # sample a batch of data
    xb, yb = get_batch(prng_subkey, 'train')

    # One AdamW step: the loss gradient is turned into parameter updates by the optimizer.
    grad = grad_loss(embeddings, xb, yb)
    updates, optimizer_state = optimizer.update(grad, optimizer_state, embeddings)
    embeddings = optax.apply_updates(embeddings, updates)

# Loss on the last training batch (not a held-out estimate).
loss = m.loss(embeddings, xb, yb)
print(loss)

In [None]:
prng_key, prng_subkey = random.split(prng_key)
# Sample with the fresh subkey. Passing `prng_key` itself would reuse the key
# that later cells split again.
print(decode(m.generate(prng_subkey, embeddings, idx=jnp.zeros((1, 1), dtype=int), max_new_tokens=100)[0].tolist()))

# The mathematical trick in self-attention

In [None]:
# toy example illustrating how matrix multiplication can be used for a "weighted aggregation"
# Row-normalized lower-triangular matrix: row i of `a @ b` is the mean of rows 0..i of b.
a = jnp.tril(jnp.ones((3, 3)))
a = a / jnp.sum(a, 1, keepdims=True)

prng_key, prng_subkey = random.split(prng_key)
b = random.randint(prng_subkey, (3, 2), 0, 10).astype(float)

c = a @ b

print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)

In [None]:
# consider the following toy example:

# Random (B, T, C) input for the aggregation variants below.
prng_key, prng_subkey = random.split(prng_key)
B, T, C = 4, 8, 2 # batch, time, channels
x = random.normal(prng_subkey, (B, T, C))
x.shape

In [None]:
# We want x[b,t] = mean_{i<=t} x[b,i]
# version 1: explicit loops. JAX arrays are immutable, so each write goes
# through .at[...].set(), which returns an updated copy.
xbow = jnp.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t + 1] # (t+1, C)
        xbow = xbow.at[b, t].set(jnp.mean(xprev, 0))

In [None]:
# version 2: using matrix multiply for a weighted aggregation
wei = jnp.tril(jnp.ones((T, T)))
wei = wei / wei.sum(1, keepdims=True)
# The (T, T) weights broadcast over the batch: (T, T) @ (B, T, C) ----> (B, T, C)
xbow2 = wei @ x
jnp.allclose(xbow, xbow2)

In [None]:
# version 3: use Softmax
# Masking future positions with -inf before a softmax over all-zero scores
# yields uniform weights over positions <= t, the same matrix as version 2.
tril = jnp.tril(jnp.ones((T, T)))
wei = jnp.zeros((T, T))
wei = jnp.where(tril == 0, float('-inf'), wei)
wei = nn.softmax(wei, axis=-1)
xbow3 = wei @ x
jnp.allclose(xbow, xbow3)

In [None]:
# version 4: self-attention!
prng_key, prng_subkey = random.split(prng_key)
B, T, C = 4, 8, 32 # batch, time, channels
x = random.normal(prng_subkey, (B, T, C))

# let's see a single Head perform self-attention
head_size = 16
key = nn.Dense(head_size, use_bias=False)
query = nn.Dense(head_size, use_bias=False)
value = nn.Dense(head_size, use_bias=False)
# Initialize each projection from its own subkey. With one shared key the
# three layers would get identical weights, making q == k == v.
prng_key, key_subkey, query_subkey, value_subkey = random.split(prng_key, 4)
key_vars = key.init(key_subkey, x)
query_vars = query.init(query_subkey, x)
value_vars = value.init(value_subkey, x)
k = key.apply(key_vars, x) # (B, T, 16)
q = query.apply(query_vars, x) # (B, T, 16)
wei = q @ jnp.transpose(k, (0, -1, -2)) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

# Causal mask: position t may only attend to positions <= t.
tril = jnp.tril(jnp.ones((T, T)))
wei = jnp.where(tril==0, float('-inf'), wei)
wei = nn.softmax(wei, axis=-1)

v = value.apply(value_vars, x)
out = wei @ v

out.shape

In [None]:
# Attention weights of the first batch element: a lower-triangular (T, T) matrix whose rows sum to 1.
wei[0]

---

# Full finished code, for reference
You may want to refer directly to the git repo instead though.

In [None]:
import os
from functools import partial

import jax
import optax
import wget
from flax import linen as nn
from jax import numpy as jnp
from jax import random


# hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cpu' # NOTE(review): not read anywhere in this cell
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.0 # NOTE(review): loss() calls apply without rngs, so a nonzero rate presumably also needs a 'dropout' PRNG stream plumbed through
# ------------
# ------------

prng_key = random.key(1337)

url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
filename = '/tmp/tinyshakespeare.txt'
# wget.download() does not overwrite an existing target; it saves a renamed
# copy instead. Reuse a previous download rather than fetching another copy.
if not os.path.exists(filename):
    filename = wget.download(url, filename)
with open(filename, "r", encoding="utf-8") as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = jnp.array(encode(text))
n = int(0.9 * len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

def get_batch(prng_key, split):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
      prng_key: JAX PRNG key used to draw the window start offsets.
      split: 'train' selects ``train_data``; any other value selects ``val_data``.

    Returns:
      ``(x, y)``: integer arrays of shape ``(batch_size, block_size)``, where
      ``y`` is ``x`` shifted one token to the right. Assumes the split arrays
      are 1-D token streams.
    """
    data = train_data if split == 'train' else val_data
    # randint's upper bound is exclusive, so every window of block_size + 1
    # tokens (inputs plus the shifted targets) stays inside `data`.
    ix = random.randint(prng_key, (batch_size,), 0, len(data) - block_size)
    # Gather all windows with one vectorized index rather than slicing the
    # array once per batch row in a Python loop.
    offsets = ix[:, None] + jnp.arange(block_size)  # (batch_size, block_size)
    x = data[offsets]
    y = data[offsets + 1]
    return x, y


def estimate_loss(prng_key, params):
    """Estimate the mean loss on the train and val splits.

    Averages ``eval_iters`` freshly sampled batches per split. The model is
    always run with ``training=False``: this is evaluation, so dropout must be
    off for the train split as well as the val split.

    Args:
      prng_key: JAX PRNG key; a subkey is split off for every sampled batch.
      params: parameter pytree passed to ``model.loss``.

    Returns:
      dict mapping ``"train"`` and ``"val"`` to the scalar mean loss.
    """
    out = {}
    for split in ["train", "val"]:
        losses = []
        for _ in range(eval_iters):
            prng_key, prng_subkey = random.split(prng_key)
            X, Y = get_batch(prng_subkey, split)
            losses.append(model.loss(params, X, Y, training=False))
        out[split] = jnp.asarray(losses).mean()
    return out


class Head(nn.Module):
    """ one head of causal self-attention """
    head_size: int

    @nn.compact
    def __call__(self, x, training: bool):
        _, T, _ = x.shape
        k = nn.Dense(self.head_size, use_bias=False)(x) # (B, T, hs)
        q = nn.Dense(self.head_size, use_bias=False)(x) # (B, T, hs)
        # compute attention scores ("affinities"). Scaled dot-product attention
        # divides by sqrt of the q/k dimension (head_size), not the embedding width.
        wei = q @ jnp.transpose(k, (0, -1, -2)) * self.head_size ** -0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        # causal mask: position t may only attend to positions <= t
        wei = jnp.where(jnp.tril(jnp.ones((T, T))) == 0, float('-inf'), wei) # (B, T, T)
        wei = nn.softmax(wei, axis=-1) # (B, T, T)
        wei = nn.Dropout(rate=dropout, deterministic=not training)(wei)
        # perform the weighted aggregation of the values
        v = nn.Dense(self.head_size, use_bias=False)(x) # (B, T, hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

class MultiHeadAttention(nn.Module):
    """Run ``num_heads`` attention heads side by side and mix their outputs.

    Each head yields ``head_size`` channels. The concatenation is projected back
    to ``n_embd`` channels and passed through dropout.
    """
    num_heads: int
    head_size: int

    @nn.compact
    def __call__(self, x, training: bool):
        heads = [Head(self.head_size) for _ in range(self.num_heads)]
        per_head = [head(x, training=training) for head in heads]
        mixed = nn.Dense(n_embd)(jnp.concatenate(per_head, axis=-1))
        return nn.Dropout(rate=dropout, deterministic=not training)(mixed)

class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4 * n_embd, ReLU, project back to n_embd, then dropout."""

    @nn.compact
    def __call__(self, x, training: bool):
        return nn.Sequential([
            nn.Dense(4 * n_embd),
            nn.relu,
            nn.Dense(n_embd),
            nn.Dropout(dropout, deterministic=not training),
        ])(x)


class Block(nn.Module):
    """ Transformer block: communication followed by computation """
    n_embd: int
    n_head: int

    @nn.compact
    def __call__(self, x, training: bool):
        head_size = self.n_embd // self.n_head
        # Pre-norm residual connections, x = x + f(ln(x)): LayerNorm only feeds
        # each sub-layer, and the output is added back to the un-normalized
        # residual stream. Re-binding x to ln(x) first would make the skip path
        # carry normalized activations.
        x = x + MultiHeadAttention(self.n_head, head_size)(nn.LayerNorm()(x), training=training)
        x = x + FeedForward()(nn.LayerNorm()(x), training=training)
        return x

# Despite the name (kept from the bigram baseline earlier in the notebook), this
# is a small decoder-only transformer: token + position embeddings, n_layer
# Blocks, a final LayerNorm and a projection to vocab_size logits.
class BigramLanguageModel(nn.Module):

    @nn.compact
    def __call__(self, idx, training: bool):
        """Map token ids ``idx`` of shape (B, T) to logits of shape (B, T, vocab_size).

        T must not exceed block_size: the position table below has block_size
        rows (``generate`` crops its context for this reason).
        """
        _, T = idx.shape
        # idx is a (B, T) tensor of integers
        tok_emb = nn.Embed(vocab_size, n_embd)(idx)
        pos_emb = nn.Embed(block_size, n_embd)(jnp.arange(0, T))
        x = tok_emb + pos_emb
        # partial() binds the training flag so nn.Sequential can call each of
        # the n_layer Blocks with x alone.
        x = nn.Sequential([partial(Block(n_embd, n_head), training=training) for _ in range(n_layer)])(x) # (B, T, C)
        x = nn.LayerNorm()(x) # (B, T, C)
        logits = nn.Dense(vocab_size)(x) # (B, T, vocab_size)
        return logits

    def loss(self, params, idx, labels, training: bool):
        """Mean next-token cross-entropy of ``labels`` given ``idx`` (both (B, T) integer arrays)."""
        logits = self.apply(params, idx, training=training) # (B, T, C)
        return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

    def generate(self, prng_key, params, idx, max_new_tokens):
        """Append ``max_new_tokens`` sampled tokens to ``idx`` and return the result.

        Each step re-runs the model on the last block_size tokens (there is no KV
        cache) and samples the next token from the final position's logits.
        """
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits = self.apply(params, idx_cond, training=False)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            prng_key, prng_subkey = random.split(prng_key)
            # sample from the distribution
            idx_next = random.categorical(prng_subkey, logits) # (B,)
            # append sampled index to the running sequence
            idx = jnp.concatenate((idx, jnp.reshape(idx_next, (-1, 1))), axis=1) # (B, T+1)
        return idx

model = BigramLanguageModel()
# print the number of parameters in the model
prng_key, prng_subkey = random.split(prng_key)
print(model.tabulate(
    prng_subkey,
    jnp.zeros((batch_size, block_size), dtype=int),
    training=True,
    compute_flops=True,
    compute_vjp_flops=True,
))

# Gradient of the loss with respect to the parameter pytree (argument 0).
grad_loss = jax.grad(model.loss, argnums=0)

optimizer = optax.adamw(learning_rate=learning_rate)
# Initialize from a fresh subkey (the one above was consumed by tabulate) and a
# dummy batch of the configured shape, so this cell does not depend on `xb`
# left over from earlier cells. Parameter shapes do not depend on `training`;
# False keeps dropout out of initialization.
prng_key, prng_subkey = random.split(prng_key)
params = model.init(prng_subkey, jnp.zeros((batch_size, block_size), dtype=int), training=False)
optimizer_state = optimizer.init(params)

# `step` rather than `iter`, which would shadow the builtin for the rest of the notebook.
for step in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        prng_key, prng_subkey = random.split(prng_key)
        losses = estimate_loss(prng_subkey, params)
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    prng_key, prng_subkey = random.split(prng_key)
    xb, yb = get_batch(prng_subkey, 'train')

    # One AdamW update.
    grad = grad_loss(params, xb, yb, training=True)
    updates, optimizer_state = optimizer.update(grad, optimizer_state, params)
    params = optax.apply_updates(params, updates)


# generate from the model
prng_key, prng_subkey = random.split(prng_key)
context = jnp.zeros((1, 1), dtype=int)
print(decode(model.generate(prng_subkey, params, context, max_new_tokens=100)[0].tolist()))

# TODO
1. Should JAX equivalent of `register_buffer` be used for `tril`?
1. Add support for running on an accelerator device (GPU/TPU); the `device` hyperparameter is currently unused.

# Example: KV cache for single head

In [None]:
# Toy configuration for the KV-cache demo; these overwrite the globals used above.
vocab_size = 4096
batch_size = 128 # NOTE(review): not read by the KV-cache cells below
block_size = 2048
n_embd = 2048
n_head = 1

head_size = n_embd // n_head
head_size

In [None]:
class QKV(nn.Module):
    """Single-head query/key/value projections with an optional KV cache.

    In training mode every position is projected. In inference mode
    (``training=False``) with a ``kvcache`` dict, the projected keys and values
    are stored under ``'k'`` and ``'v'``. Once the cache is populated, only the
    newest token is projected and its k/v are appended to the cached ones, so
    at most the last ``block_size`` positions are kept.
    """

    @nn.compact
    def __call__(self, x, training=True, kvcache=None):
        use_cache = not training and kvcache is not None
        # Drop the prefix only once it is cached. On the first inference call the
        # whole context must be projected, or every token before the last would
        # be missing from the cache.
        if use_cache and kvcache:
            x = x[:, -1:, :]

        q = nn.Dense(head_size, use_bias=False)(x)
        k = nn.Dense(head_size, use_bias=False)(x)
        v = nn.Dense(head_size, use_bias=False)(x)

        if use_cache:
            if kvcache:
                k = jnp.concatenate((kvcache['k'], k), axis=1)[:, -block_size:, :]
                v = jnp.concatenate((kvcache['v'], v), axis=1)[:, -block_size:, :]
            kvcache['k'] = k
            kvcache['v'] = v

        return q, k, v


class Transformer(nn.Module):
    """Single-head, single-layer toy transformer demonstrating KV caching.

    ``__call__`` returns the logits of the token following the last position of
    ``context``, with shape (B, vocab_size).
    """

    @nn.compact
    def __call__(self, context, training=True, kvcache=None):
        context_block = context[:, -block_size:]
        pos = jnp.arange(0, context_block.shape[1])

        tok_embed = nn.Embed(vocab_size, n_embd)(context_block)
        pos_embed = nn.Embed(block_size, n_embd)(pos)
        x = tok_embed + pos_embed

        if not training and kvcache is not None:
            # Keep this layer's cache in its own sub-dict of the caller's dict.
            kvcache = kvcache.setdefault('qkv', {})
        q, k, v = QKV()(x, training=training, kvcache=kvcache)

        # No causal mask is needed: only the last query row feeds the returned
        # logits, and that row may attend to every earlier position.
        # NOTE(review): once the context exceeds block_size, cached keys keep the
        # position embeddings they were computed with, while a fresh forward pass
        # would shift them. Cached and uncached outputs then diverge.
        wei = q @ jnp.transpose(k, (0, -1, -2)) * head_size ** -0.5
        wei = nn.softmax(wei, axis=-1)
        out = wei @ v

        logits = nn.Dense(vocab_size)(out)

        logits = logits[:, -1, :]
        return logits


    def generate(self, prng_key, params, context, max_new_tokens=1):
        """Sample ``max_new_tokens`` tokens after ``context``, reusing a KV cache.

        A fresh cache dict is created per call and threaded through every forward
        pass, so after the first step each pass projects only the newest token.
        """
        kvcache = {}
        for _ in range(max_new_tokens):
            logits = self.apply(params, context, training=False, kvcache=kvcache)
            prng_key, prng_subkey = random.split(prng_key)
            context_next = random.categorical(prng_subkey, logits)
            context = jnp.concatenate((context, jnp.reshape(context_next, (-1, 1))), axis=1)
        return context

In [None]:

transformer = Transformer()
# Use a fresh subkey and an explicit dummy context rather than reusing
# `prng_key` and whatever `context` a previous cell left behind.
prng_key, prng_subkey = random.split(prng_key)
params = transformer.init(prng_subkey, jnp.zeros((1, 1), dtype=int))

In [None]:
# Sample 10 tokens after a single token-0 prompt; the KV cache is reused between steps.
context = transformer.generate(random.key(0), params, jnp.zeros((1, 1), dtype=int), 10)
context