In [3]:
import jax
import jax.numpy as jnp 
from jax import random

import flax
import flax.linen as nn
from flax.training import checkpoints

import numpy as np

from jax_impl.config import Config
from jax_impl.train import prepare_data, create_train_state, train_epoch, eval

%reload_ext autoreload
%autoreload 2

In [4]:
# Build the run configuration and echo it so the hyperparameters are
# recorded alongside the notebook output.
cfg = Config()
print(cfg)

# Data loaders for the training and evaluation splits.
train_loader, eval_loader = prepare_data(cfg)

# Root PRNG key for the notebook. Each consumer gets a fresh subkey (`rng`)
# while `key` is carried forward for later splits.
key = random.PRNGKey(42)
rng, key = random.split(key)

# Initial training state, built from a subkey and the configuration.
state = create_train_state(rng, cfg)

Config(context_window=128, n_embd=192, n_head=6, n_layer=6, dropout_prob=0.1, vocab_size=64, sequence_len=64, batch_size=64, learning_rate=0.0005)


In [5]:
# Reuse an existing checkpoint when one is present; otherwise train from
# scratch for cfg.n_epoch epochs, checkpointing after every epoch.
# NOTE(review): the restore branch does not resume a partially completed run;
# any checkpoint found is treated as the finished model. Delete ./checkpoint
# to force retraining.
# NOTE(review): `n_epoch` does not appear in the printed Config repr; confirm
# it is defined on Config before relying on the training branch.
trained_checkpoint = checkpoints.latest_checkpoint('./checkpoint/')

if trained_checkpoint:
    state = checkpoints.restore_checkpoint(trained_checkpoint, state)
else:
    for e in range(cfg.n_epoch):
        # Fresh subkey for this epoch; `key` is carried forward.
        rng, key = random.split(key)
        state = train_epoch(state, train_loader, cfg, rng, e)

        # Label the checkpoint with the number of epochs actually completed
        # (e + 1) instead of the planned total, so an interrupted run is not
        # recorded as fully trained. The final checkpoint still ends up at
        # step == cfg.n_epoch, as before.
        checkpoints.save_checkpoint('./checkpoint', state, step=e + 1, overwrite=True)

In [6]:
# Optional: continue training for additional epochs (uncomment to run)

# n_more = 32
# for e in range(n_more):
#     rng, key = random.split(key)
#     state = train_epoch(state, train_loader, cfg, rng, e)
# checkpoints.save_checkpoint('./checkpoint', state, step=cfg.n_epoch + n_more)

In [7]:
# Report accuracy on both splits. `eval` here is the project function imported
# from jax_impl.train (it shadows the Python builtin of the same name).
# Each split is evaluated and printed in turn, train first.
train_accuracy = eval(state, train_loader, cfg)
print(f"train dataset accuracy = {train_accuracy:.3f}")

eval_accuracy = eval(state, eval_loader, cfg)
print(f"eval dataset accuracy = {eval_accuracy:.3f}")

train dataset accuracy = 0.996
eval dataset accuracy = 0.996


In [16]:
from jax_impl.infer import generate

# Draw a single random sequence of integer tokens in [0, cfg.vocab_size)
# with length cfg.sequence_len to use as the model input.
rng, key = random.split(key)
x = random.randint(rng, shape=(1, cfg.sequence_len), minval=0, maxval=cfg.vocab_size)

# Run generation, then keep only the last cfg.sequence_len positions of the
# result. NOTE(review): presumably `generate` returns the input followed by
# the newly generated tokens, so this slice isolates the generated part —
# confirm against jax_impl.infer.generate.
generated = generate(cfg, state.params, x, cfg.sequence_len)
output = generated[:, -cfg.sequence_len:]

print(f"input = {x}")
print(f"generated sequence = {output}")

input = [[59 61 58 32 53 24 11 43 27 33 19  0 23  6 32 27 15 31 19  7 14 53 50  9
  31  1 45 16 12 39 27 18 22 38 59 29  7 38 31  7 22 43 34 37 20  2 57 18
  60 45 17 25 41 12  1 33 62 50 22 44 41 14 20 14]]
generated sequence = [[ 0  1  1  2  6  7  7  7  9 11 12 12 14 14 14 15 16 17 18 18 19 19 20 20
  22 22 22 23 24 25 27 27 27 29 31 31 31 32 32 33 33 34 37 38 38 39 41 41
  43 43 44 45 45 50 50 53 53 57 58 59 59 60 61 62]]


In [18]:
# The generated sequence must equal the input sorted in ascending order along
# the sequence axis. np.testing.assert_array_equal still raises AssertionError
# on failure, but it also rejects shape mismatches (instead of silently
# broadcasting) and reports which elements differ.
np.testing.assert_array_equal(
    np.asarray(output),
    np.asarray(jnp.sort(x, axis=-1)),
    err_msg="generated sequence is not the sorted input",
)