In [1]:
from flax import linen as nn
import jax
from jax import numpy as jnp
from jax import random
import optax


# Sizes shared by every cell below.
vocab_size = 64      # number of token ids (Embed rows, Dense outputs, randint upper bound)
batch_size = 128     # leading dimension of the training batch
block_size = 2_048   # second dimension of the training batch (tokens per row)
n_embd = 2_048       # embedding width


class Model(nn.Module):
    """Token embedding followed by a linear readout to vocabulary logits.

    Each position is processed independently: there is no mixing across the
    sequence axis. Sizes come from the module-level ``vocab_size`` / ``n_embd``.
    """

    @nn.compact
    def __call__(self, x):
        """Map integer token ids ``x`` to logits with a trailing ``vocab_size`` axis."""
        # Keep submodule creation order (Embed first, then Dense) unchanged.
        embedded = nn.Embed(vocab_size, n_embd)(x)
        return nn.Dense(vocab_size)(embedded)

    def loss(self, params, x, labels):
        """Mean softmax cross-entropy of the model's logits against integer ``labels``.

        ``params`` is the variables dict returned by ``init``; it is the argument
        that ``jax.grad`` differentiates with respect to in the training cell.
        """
        per_token = optax.softmax_cross_entropy_with_integer_labels(
            self.apply(params, x), labels
        )
        return per_token.mean()
    

class ShapePrinter:
    """Wrap a value so its repr is JAX's short abstract dtype/shape string.

    Intended for ``jax.tree.map(ShapePrinter, tree)`` so that printing a pytree
    shows one compact summary per leaf instead of the full array contents.
    """

    def __init__(self, v):
        self.v = v

    def __repr__(self):
        abstract_value = jax.api_util.shaped_abstractify(self.v)
        return abstract_value.str_short()


# Initialize the model and profile a single-token forward pass

In [None]:
model = Model()
prng_key = random.key(0)
x = jnp.zeros((1, 1), dtype=int)
params = model.init(prng_key, x)

%timeit model.apply(params, x).block_until_ready()

In [None]:
# Capture a profiler trace of one forward pass on the dummy batch.
# NOTE: with create_perfetto_link=True, jax.profiler.trace prints a Perfetto
# link and blocks until it is opened, so this cell will not finish unattended
# (e.g. under Restart & Run All).
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    jax.block_until_ready(model.apply(params, x))

In [None]:
# Summarize every parameter leaf by its abstract dtype/shape rather than its values.
param_shapes = jax.tree.map(ShapePrinter, params)
print(param_shapes)

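# Train: run `max_iters` optimizer steps on random tokens, then re-profile the forward pass on the full batch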

In [None]:
max_iters = 1

optimizer = optax.adamw(learning_rate=0)
optimizer_state = optimizer.init(params)

grad_loss = jax.grad(model.loss, argnums=0)

x = random.randint(prng_key, (batch_size, block_size), 0, vocab_size)
labels = random.randint(prng_key, (batch_size, block_size), 0, vocab_size)

for i in range(max_iters):
    grad = grad_loss(params, x, labels)
    updates, optimizer_state = optimizer.update(grad, optimizer_state, params)
    params = optax.apply_updates(params, updates)

%timeit model.apply(params, x).block_until_ready()

In [None]:
# Capture a profiler trace of one forward pass on the training batch `x`.
# NOTE: with create_perfetto_link=True, jax.profiler.trace prints a Perfetto
# link and blocks until it is opened, so this cell will not finish unattended
# (e.g. under Restart & Run All).
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    jax.block_until_ready(model.apply(params, x))

In [None]:
# Summarize every parameter leaf after the optimizer steps by its abstract dtype/shape.
param_shapes = jax.tree.map(ShapePrinter, params)
print(param_shapes)