In [1]:
import jax
import jax.numpy as jnp 
from jax import random

import flax
import flax.linen as nn
from flax.training import checkpoints

import numpy as np

from jax_impl.config import Config
from jax_impl.train import prepare_data, create_train_state, train_epoch, eval

# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up before each cell runs, without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [2]:
# Build the run configuration and echo it so the settings used are recorded
# in the notebook output.
cfg = Config()
print(cfg)

# Data loaders for the training and evaluation splits.
train_loader, eval_loader = prepare_data(cfg)

# Root PRNG key for the notebook. Every consumer receives a fresh subkey
# (`rng`) while `key` is carried forward for the next split.
key = random.PRNGKey(42)
rng, key = random.split(key, num=2)

# Initial train state created from the first subkey.
state = create_train_state(rng, cfg)

Config(context_window=128, n_embd=192, n_head=6, n_layer=6, dropout_prob=0.1, vocab_size=64, sequence_len=64, batch_size=64, learning_rate=0.0005)


In [3]:
# Reuse a saved checkpoint when one exists; otherwise train for cfg.n_epoch
# epochs and checkpoint after every epoch.
# NOTE(review): any checkpoint found here is treated as fully trained; an
# interrupted run is restored as-is and its remaining epochs are not resumed.
trained_checkpoint = checkpoints.latest_checkpoint('./checkpoint/')

if trained_checkpoint:
    state = checkpoints.restore_checkpoint(trained_checkpoint, state)
else:
    for e in range(cfg.n_epoch):
        # Fresh subkey for this epoch; `key` is carried forward.
        rng, key = random.split(key)
        state = train_epoch(state, train_loader, cfg, rng, e)

        # Label the checkpoint with the number of epochs actually completed
        # rather than the planned total, so a partially trained checkpoint is
        # not mislabelled as the final one. After the last epoch the step is
        # still cfg.n_epoch, as before.
        checkpoints.save_checkpoint('./checkpoint', state, step=e + 1, overwrite=True)

In [4]:
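# Optional: train for `n_more` additional epochs on top of the current state,
# then save a new checkpoint (uncomment to run).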

# n_more = 32
# for e in range(n_more):
#     rng, key = random.split(key)
#     state = train_epoch(state, train_loader, cfg, rng, e)
# checkpoints.save_checkpoint('./checkpoint', state, step=cfg.n_epoch + n_more)

In [5]:
# Report accuracy on the training split first, then on the held-out split.
train_accuracy = eval(state, train_loader, cfg)
print(f"train dataset accuracy = {train_accuracy:.3f}")

eval_accuracy = eval(state, eval_loader, cfg)
print(f"eval dataset accuracy = {eval_accuracy:.3f}")

train dataset accuracy = 0.996
eval dataset accuracy = 0.996


In [6]:
from jax_impl.infer import generate

# Draw one random sequence of token ids in [0, cfg.vocab_size) as the prompt.
rng, key = random.split(key)
x = random.randint(
    rng,
    shape=(1, cfg.sequence_len),
    minval=0,
    maxval=cfg.vocab_size,
)

# Run generation and keep only the trailing cfg.sequence_len positions of
# each row.
generated = generate(cfg, state.params, x, cfg.sequence_len)
output = generated[:, -cfg.sequence_len:]

print(f"input = {x}")
print(f"generated sequence = {output}")

input = [[30 52 16 28  7 50 61 25 34 45 47 53 55 38 52 19  9  1  1 46 10  9 57  9
  16 59 56 57 26 60 61 24  3 46 55 51 25 16 29 48 17  5  9 26 56 23 55  0
  24 42  1 36 36 37 51 24 38 46  1 14 33  0 56 20]]
generated sequence = [[ 0  0  1  1  1  1  3  5  7  9  9  9  9 10 14 16 16 16 17 19 20 23 24 24
  24 25 25 26 26 26 28 29 30 33 34 36 36 37 38 38 42 45 46 46 46 47 48 50
  51 51 52 52 53 55 55 55 56 56 56 57 59 60 61 61]]


In [7]:
# Exact-match check: the generated sequence must equal the sorted input.
# np.testing.assert_array_equal still raises AssertionError on failure, but
# unlike a bare `assert` it reports how many elements differ and shows both
# arrays, which makes a failing run diagnosable.
# NOTE(review): this is an all-or-nothing check on a single sampled sequence;
# a single wrong token fails it.
np.testing.assert_array_equal(
    np.asarray(output),
    np.asarray(jnp.sort(x)),
    err_msg="generated sequence differs from the sorted input",
)

AssertionError: 