Download the Tiny Shakespeare text [here](https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt) and place it in the `data/` folder at the repository root.

In [114]:
import jax
import flax.nnx as nnx
from jax import random
from genlearn.data_utils.shakespeare import get_shakespeare_dataset, decode, encode
import grain
import jax.numpy as jnp

In [None]:
# Load the encoded corpus and report its (rows, columns) shape.
X = get_shakespeare_dataset()
print(f"{X.shape}")

(32777, 128)


### DataLoader Setup

In [31]:
# Grain doesn't natively support an in-memory jax array like we have here - quick utility class
from typing import SupportsIndex
class ArraySource(grain.sources.RandomAccessDataSource):
    """Expose an in-memory array to grain as a random-access data source.

    Record ``i`` is ``array[i]``, i.e. one entry along the leading axis.
    """

    def __init__(self, array: jax.Array):
        self._array = array

    def __len__(self) -> int:
        # number of records == size of the leading axis
        num_records = self._array.shape[0]
        return num_records

    def __getitem__(self, i: SupportsIndex) -> jax.Array:
        record = self._array[i]
        return record

In [101]:
n_epochs = 10
data_seed = 0
batch_size = 32
# Vocabulary size. Token ids are fed to one_hot(x, n_vocab), so every id must lie
# in [0, n_vocab); a contiguous id range holds X.max() - X.min() + 1 ids.
n_vocab = 95
# Out-of-range ids would silently one-hot to all zeros, so fail loudly instead.
if int(X.min()) < 0 or int(X.max()) >= n_vocab:
    raise ValueError(
        f"token ids span [{int(X.min())}, {int(X.max())}], "
        f"outside [0, {n_vocab}) for n_vocab={n_vocab}"
    )

# Shuffled, seeded index stream over all rows of X for n_epochs passes.
sampler = grain.samplers.IndexSampler(
    num_records=X.shape[0], num_epochs=n_epochs, shuffle=True, seed=data_seed
)
data_loader = grain.DataLoader(
    sampler=sampler,
    # drop_remainder keeps every batch at exactly batch_size rows, so a jitted
    # train step never sees a ragged final batch (which would force a retrace).
    operations=[grain.transforms.Batch(batch_size=batch_size, drop_remainder=True)],
    data_source=ArraySource(X),
)

In [None]:
# TODO: add causal masking
# TODO: add skip connections at the right places this time :) 
# TODO: what is layernorm btw

In [189]:
# now define the model itself. First, the embedding
class Embed(nnx.Module):
    """Token embedding through a learned lookup table.

    Args:
        n_vocab: number of token ids; ids are expected in ``[0, n_vocab)``.
        rngs: RNG streams used to initialise the table (via ``rngs.uniform``).
        embed_dim: width of each embedding vector.
    """

    def __init__(self, n_vocab: int, rngs: nnx.Rngs, embed_dim: int = 2):
        # (n_vocab, embed_dim) table; row k is the embedding of token id k
        self.embed_lookup = nnx.Param(rngs.uniform(shape=(n_vocab, embed_dim)))
        self.n_vocab = n_vocab

    @nnx.jit
    def __call__(self, x: jax.Array) -> jax.Array:
        """Embed integer token ids ``x``; output shape is ``x.shape + (embed_dim,)``.

        ``one_hot(x) @ table`` selects row ``x`` of the table. Ids outside
        ``[0, n_vocab)`` one-hot to all zeros and therefore embed to zeros.
        """
        # TODO: add positional encoding (e.g. via _add_posenc) once the model uses it.
        return nnx.one_hot(x, self.n_vocab) @ self.embed_lookup

    def _add_posenc(self, x: jax.Array, base: float = 10000.0) -> jax.Array:
        """Add a sinusoidal positional encoding to ``x``.

        ``x`` must have at least 2 dims, read as ``(..., seq, dim)``. For position n
        and feature index i, even i gets ``sin(n / base**(i/dim))`` and odd i gets
        ``cos(n / base**((i-1)/dim))``, so each (sin, cos) pair shares one frequency.
        """
        *_, seq_len, dim = x.shape
        pos_t1 = jnp.arange(seq_len, dtype=jnp.float32)[:, None]
        idx_1e = jnp.arange(dim)[None, :]
        pair_start_1e = idx_1e - idx_1e % 2  # i for even i, i - 1 for odd i
        angle_te = pos_t1 / base ** (pair_start_1e / dim)
        pe_te = jnp.where(idx_1e % 2 == 0, jnp.sin(angle_te), jnp.cos(angle_te))
        # (seq, dim) encoding broadcasts over any leading (batch) axes of x
        return x + pe_te.astype(x.dtype)


# Smoke test: a (32, 128) grid of random token ids in [0, 95) embeds to (32, 128, 2).
x_bt = nnx.Rngs(0).choice(jnp.arange(95), (32, 128))
emb = Embed(n_vocab=95, rngs=nnx.Rngs(0), embed_dim=2)
x_bte = emb(x_bt)
assert x_bte.shape == (32, 128, 2)

In [190]:
L = 30  # base of the wavelength progression (Vaswani et al. 2017 use 10000)
D = x_bte.shape[2]  # embedding dim


def sinusoidal_posenc(n: jax.Array, i: jax.Array) -> jax.Array:
    """Sinusoidal positional encoding at position ``n`` and embedding index ``i``.

    Elementwise over broadcastable ``n`` and ``i``. Even ``i`` gives
    ``sin(n / L**(i/D))`` and odd ``i`` gives ``cos(n / L**((i-1)/D))``, so each
    (sin, cos) pair shares one frequency.
    """
    pair_start = i - i % 2  # i for even i, i - 1 for odd i
    angle = n / (L ** (pair_start / D))
    return jnp.where(i % 2 == 0, jnp.sin(angle), jnp.cos(angle))


@jax.jit
def add_posenc(x: jax.Array, n: jax.Array, i: jax.Array) -> jax.Array:
    """Return ``x`` plus the positional encoding evaluated at ``(n, i)``.

    Purely elementwise, so ``x``, ``n`` and ``i`` may have any mutually
    broadcastable shapes; no ``vmap`` is needed.
    """
    return x + sinusoidal_posenc(n, i)


# position index (axis 1) and embedding index (axis 2) for every entry of x_bte
n_grid = jax.lax.broadcasted_iota(jnp.float32, shape=x_bte.shape, dimension=1)
i_grid = jax.lax.broadcasted_iota(jnp.float32, shape=x_bte.shape, dimension=2)

# Evaluate the encoding directly; add_posenc(x) - x would reintroduce rounding
# error from the add/subtract round trip.
posenc = sinusoidal_posenc(n_grid, i_grid)

In [156]:
# Now a single attention layer
class AttentionLayer(nnx.Module):
    """Multi-head scaled dot-product self-attention over the sequence axis.

    Pipeline: per-head attention, summed over heads, then LayerNorm, then a Linear
    layer with ReLU, then a learned-strength residual. Tensor-name suffixes give
    axes: b = batch, t = query position, s = key position, h = head,
    d = per-head query/key width, i/e = embedding width.
    """

    def __init__(
        self,
        rngs: nnx.Rngs,
        n_heads: int,
        d_head: int,
        d_embed: int,
        dtype: jax.typing.DTypeLike = jnp.float32,
    ):
        self.n_heads = n_heads
        self.d_head = d_head
        self.d_embed = d_embed

        def init_attn_weight(*shape: int):
            return nnx.Param(nnx.initializers.kaiming_normal()(rngs(), shape=shape, dtype=dtype))

        # fmt: off
        self.WQ_hed = init_attn_weight(n_heads, d_embed, d_head)  # query projection per head
        self.WK_hed = init_attn_weight(n_heads, d_embed, d_head)  # key projection per head
        self.WH_hee = init_attn_weight(n_heads, d_embed, d_embed)  # value+output projection per head

        self.layernorm = nnx.LayerNorm(self.d_embed, rngs=rngs)
        self.out = nnx.Linear(self.d_embed, self.d_embed, rngs=rngs)
        # learned skip-connection weight, initialised to 0
        self.residual_strength = nnx.Param(jnp.array(0, dtype=dtype))
        # fmt: on

    def __call__(self, X_bte: jax.Array, *, causal: bool = False) -> jax.Array:
        """Attend within each sequence of ``X_bte``; the output has the same shape.

        Args:
            X_bte: ``(batch, seq, d_embed)`` input.
            causal: if True, query position t may only attend to key positions s <= t.
        """
        Q_hbtd = jnp.einsum("bti, hid -> hbtd", X_bte, self.WQ_hed)
        K_hbsd = jnp.einsum("bsi, hid -> hbsd", X_bte, self.WK_hed)
        # Score every query position t against every key position s of the SAME
        # batch element and head (batch is never contracted), scaled by sqrt(d_head).
        S_hbts = jnp.einsum("hbtd, hbsd -> hbts", Q_hbtd, K_hbsd) / jnp.sqrt(self.d_head)
        if causal:
            seq_len = X_bte.shape[1]
            allowed_ts = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
            # the diagonal is always allowed, so every row keeps a finite entry
            S_hbts = jnp.where(allowed_ts, S_hbts, jnp.finfo(S_hbts.dtype).min)
        # normalise over key positions s
        A_hbts = nnx.softmax(S_hbts, axis=-1)
        # multi-head attention incl. a sum over heads (from Bishop 2025):
        # mix inputs over s within each sequence, project with WH, sum over h
        Y_bte = jnp.einsum("hbts, bsi, hie -> bte", A_hbts, X_bte, self.WH_hee)
        Y_bte = self.layernorm(Y_bte)
        Y_bte = nnx.relu(self.out(Y_bte))
        # add skip connection
        return Y_bte + self.residual_strength * X_bte


# Smoke test: output must keep the input shape so layers can be stacked with residuals.
attn = AttentionLayer(
    nnx.Rngs(0),
    n_heads=4,
    d_head=2,
    d_embed=2,
)
assert attn(x_bte).shape == x_bte.shape
nnx.display(attn)

In [None]:
class Transformer(nnx.Module):

    def __init__(self):
        