In [1]:
import jax
import optax
import flax
from flax.training.train_state import TrainState

from functools import partial
from tqdm.auto import tqdm
from datasets import TextDataset
import models
import training

In [2]:
rng_key = jax.random.key(0)
context_len = 8
batch_size = 32

dataset = TextDataset(data_path="shakespeare.txt")
model = models.BigramLM(vocab_size=len(dataset.tokenizer.vocab))

In [3]:
optimization_step = jax.jit(partial(training.optimization_step, loss_fn=training.logit_prediction_loss))
get_batch = jax.jit(partial(dataset.get_batch, batch_size=batch_size, context_len=context_len))
generate_token = jax.jit(partial(model.apply, method=model.generate_token))
def generate_text(params, prompt: str, length=500, rng_key=jax.random.key(0)):
    context = dataset.tokenizer.encode(prompt)
    print("\033[94m", dataset.tokenizer.decode(context), "\033[0m", end="")
    for sub_rng in jax.random.split(rng_key, length):
        next_token, context = generate_token(params, context, sub_rng)
        print(dataset.tokenizer.decode(next_token[None]), end="")

losses = []
train_state = TrainState.create(
    apply_fn=model.apply,
    params=model.init(rng_key, dataset.sample(context_len, rng_key)),
    tx=optax.adam(3e-4),
)
for epoch_rng_key in tqdm(jax.random.split(rng_key, 10)):
    for batch_rng_key in tqdm(jax.random.split(epoch_rng_key, 10000), leave=False):
        x, y = get_batch(rng_key=batch_rng_key)
        train_state, loss_value = optimization_step(train_state, x, y)
        losses.append(loss_value)
print(f"Loss: {sum(losses) / len(losses)}\nGeneration test: ")
generate_text(train_state.params, prompt="To be or", rng_key=rng_key)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: 2.507021188735962
Generation test: 
[94m To be or [0mou ave t
WAS:
masun,
CIO:
Pe hom thakice te wosthatth'de blanovecing by:
Motinins; nisureilt:
NDR:
Whe e


LO: ng t abld meltoforrear,
Welds twhes tine whe, w, in.
Wiselal oromes? r, hiak yo? stin:
CES:
Lur mamarth omyobrd. Bate;

SAngall hean,
Wh bs thertiaulonke-han heeread IXE:

Mad bayin th, ant.

Hin sthil;
Ses te de s fisorttrrmbr w, thitoro.

He myooustintens manth tom y wo' s ham kerispavilinon IRUSo.
Thed ELUShy clyove?
t pid llleeral in:
As?
S:
II ol veaveay! as tharule Ifordoused prmu