In [7]:
import os

import jax
import jax.numpy as jnp
from jax import random

import flax
import flax.linen as nn
from flax.training import checkpoints

import numpy as np

from jax_impl.config import Config
from jax_impl.infer import generate
from jax_impl.train import prepare_data, create_train_state, train_epoch, eval

# %reload_ext autoreload
# %autoreload 2

In [2]:
# Model / training hyper-parameters (printed so the run is self-describing).
cfg = Config()
print(cfg)

# Root PRNG key for the whole notebook; every later cell splits from `key`.
SEED = 42
key = random.PRNGKey(SEED)

# Data loaders for the training and evaluation splits.
train_loader, eval_loader = prepare_data(cfg)

# Consume one subkey for parameter initialisation and carry `key` forward.
rng, key = random.split(key, 2)
state = create_train_state(rng, cfg)

Config(context_window=128, n_embd=192, n_head=6, n_layer=6, dropout_prob=0.1, vocab_size=64, sequence_len=64, batch_size=64, learning_rate=0.0005)


In [3]:
# Main training loop. Each epoch receives its own fresh subkey, presumably
# for dropout inside train_epoch (TODO confirm). `key` is carried forward so
# later cells keep drawing independent randomness.
for epoch in range(cfg.n_epoch):
    rng, key = random.split(key, 2)
    state = train_epoch(state, train_loader, cfg, rng, epoch)

epoch  0 | avg_loss=1.0432 | avg_acc=0.3403
epoch  1 | avg_loss=0.5176 | avg_acc=0.5541
epoch  2 | avg_loss=0.2621 | avg_acc=0.7846
epoch  3 | avg_loss=0.1785 | avg_acc=0.8577
epoch  4 | avg_loss=0.1321 | avg_acc=0.8992
epoch  5 | avg_loss=0.1086 | avg_acc=0.9188
epoch  6 | avg_loss=0.0936 | avg_acc=0.9301
epoch  7 | avg_loss=0.0828 | avg_acc=0.9392
epoch  8 | avg_loss=0.0751 | avg_acc=0.9448
epoch  9 | avg_loss=0.0682 | avg_acc=0.9504
epoch 10 | avg_loss=0.0661 | avg_acc=0.9520
epoch 11 | avg_loss=0.0597 | avg_acc=0.9570
epoch 12 | avg_loss=0.0557 | avg_acc=0.9596
epoch 13 | avg_loss=0.0540 | avg_acc=0.9614
epoch 14 | avg_loss=0.0520 | avg_acc=0.9628
epoch 15 | avg_loss=0.0490 | avg_acc=0.9649
epoch 16 | avg_loss=0.0477 | avg_acc=0.9657
epoch 17 | avg_loss=0.0440 | avg_acc=0.9684
epoch 18 | avg_loss=0.0418 | avg_acc=0.9701
epoch 19 | avg_loss=0.0420 | avg_acc=0.9700
epoch 20 | avg_loss=0.0402 | avg_acc=0.9714
epoch 21 | avg_loss=0.0389 | avg_acc=0.9724
epoch 22 | avg_loss=0.0388 | avg

In [12]:
# Persist the trained TrainState, tagged with the number of epochs trained.
# - Absolute path: Orbax-backed flax checkpointing may reject relative
#   directories, so resolve './checkpoint' against the working directory.
# - overwrite=True: with the default (False), re-running this cell, or doing
#   Restart & Run All while a checkpoint for this step already exists, raises
#   instead of saving.
checkpoints.save_checkpoint(
    os.path.abspath('./checkpoint'), state, step=cfg.n_epoch, overwrite=True
)

'checkpoint/checkpoint_96'

In [None]:
# Optionally continue training the current `state` for extra epochs.
# Set N_MORE_EPOCHS > 0 to enable. The default of 0 keeps this cell a no-op,
# so Restart & Run All reproduces the results above.
N_MORE_EPOCHS = 0

if N_MORE_EPOCHS > 0:
    for extra_epoch in range(N_MORE_EPOCHS):
        rng, key = random.split(key)
        # Continue the epoch counter from where the main loop stopped so the
        # logged epoch numbers do not restart at 0. The epoch index is
        # presumably used for logging (maybe more) inside train_epoch; confirm.
        state = train_epoch(state, train_loader, cfg, rng, cfg.n_epoch + extra_epoch)
    checkpoints.save_checkpoint(
        os.path.abspath('./checkpoint'),
        state,
        step=cfg.n_epoch + N_MORE_EPOCHS,
        overwrite=True,  # keep the cell safe to re-run
    )

In [4]:
# Accuracy of the trained model on each split, train first, then eval.
# Note: `eval` here is jax_impl.train.eval, which shadows the Python builtin.
for split_name, loader in (("train", train_loader), ("eval", eval_loader)):
    accuracy = eval(state, loader, cfg)
    print(f"{split_name} dataset accuracy = {accuracy:.3f}")

train dataset accuracy = 0.992
eval dataset accuracy = 0.991


In [5]:
# Qualitative check: draw a random prompt of `sequence_len` tokens and let the
# model generate `sequence_len` more (`generate` is imported in the top cell).
rng, key = random.split(key)
x = random.randint(rng, (1, cfg.sequence_len), 0, maxval=cfg.vocab_size)

# NOTE(review): assumes `generate` returns the prompt followed by the new
# tokens, so the last `sequence_len` positions are the model's output. Confirm.
output = generate(cfg, state.params, x, cfg.sequence_len)[:, -cfg.sequence_len:]

print(f"input = {x}")
print(f"generated sequence = {output}")

input = [[ 1  2 57  6 37 60 43 26 12 61 13 31  5  9 35 14 63 23 41 62 10 33 16 52
  50 57 37 50 31  4 40 49  1 12 13 52 60 30 26  9 61  6 40 34 14 24 20 50
   0 59  2 39 29 11 43 12 22 48 49 22  4 40 61 47]]
generated sequence = [[ 0  1  1  2  2  4  4  5  6  6  9  9 10 11 12 12 12 13 13 14 14 16 20 22
  22 23 24 26 26 29 30 31 31 33 34 35 37 37 39 40 40 40 41 43 43 47 48 49
  49 50 50 50 52 52 57 57 59 60 60 61 61 61 62 63]]


In [6]:
# The generated tokens must equal the prompt sorted along the last axis.
# Check the shape explicitly first: `(a == b).all()` would broadcast
# mismatched shapes (e.g. a (1, 1) output) and could pass spuriously.
sorted_prompt = jnp.sort(x)
assert output.shape == sorted_prompt.shape, (
    f"shape mismatch: generated {output.shape} vs expected {sorted_prompt.shape}"
)
assert jnp.array_equal(sorted_prompt, output), (
    f"generated sequence is not the sorted input:\n{output}"
)