In [1]:
import math
import functools

import jax
import jax.numpy as jnp 
from jax import random
from jax.tree_util import tree_map

import flax
import flax.linen as nn

import numpy as np

from jax_impl.model import GPT
import optax

from jax_impl.config import Config
from jax_impl.train import create_train_state, train_step

%reload_ext autoreload
%autoreload 2

In [2]:
from dataset import SortDataset
from torch.utils.data import DataLoader

def numpy_collate(batch):
  """Collate a list of samples into NumPy arrays.

  The type of the first sample decides how the batch is combined:

  * ``np.ndarray`` samples are stacked along a new leading axis;
  * ``tuple``/``list`` samples are regrouped field by field and each
    group of fields is collated recursively, giving a list with one
    entry per field;
  * anything else is passed to ``np.array`` as a whole.
  """
  first = batch[0]
  if isinstance(first, np.ndarray):
    return np.stack(batch)
  if isinstance(first, (tuple, list)):
    # zip(*batch) turns [(x0, y0), (x1, y1), ...] into (x0, x1, ...), (y0, y1, ...).
    return [numpy_collate(field) for field in zip(*batch)]
  return np.array(batch)

def cast(x):
    """Return a new NumPy array built from ``x`` with integer dtype."""
    as_int = np.array(x, dtype=int)
    return as_int

In [3]:
# Experiment configuration and the root PRNG key (seed 42).
cfg = Config()
key = random.PRNGKey(42)

# 'train' split of SortDataset; every sample goes through `cast` as its transform.
train_dataset = SortDataset(
    'train',
    length=cfg.sequence_len,
    num_digits=cfg.vocab_size,
    transform=cast,
)

# Shuffled loader; batches are assembled by numpy_collate rather than the default collate.
train_loader = DataLoader(
    train_dataset,
    batch_size=cfg.batch_size,
    shuffle=True,
    collate_fn=numpy_collate,
)

In [4]:
# Number of batches produced by one pass over train_loader.
len(train_loader)

157

In [6]:
# Build the initial train state from the root key and the config.
state = create_train_state(key, cfg)

# Compile the step function once. `cfg` is marked static so it is treated as a
# compile-time constant, and argument 0 (the incoming state) is donated so its
# buffers may be reused for the returned state.
j_train_step = jax.jit(train_step, static_argnames='cfg', donate_argnums=(0, ))

LOG_EVERY = 15  # print metrics every LOG_EVERY iterations within an epoch

global_step = 0
for e in range(cfg.n_epoch):
    for i, batch in enumerate(train_loader):
        # JAX PRNG keys must not be reused: passing the same `key` to every
        # step would make any randomness inside train_step identical at each
        # iteration (presumably dropout -- confirm in jax_impl.train).
        # fold_in derives a distinct key per step without rebinding `key`,
        # so re-running this cell starts from the same root key.
        step_key = random.fold_in(key, global_step)
        state, metrics = j_train_step(state, batch, cfg, step_key)
        global_step += 1

        if i % LOG_EVERY == 0:
            print(f"epoch{e}-iter:{i:3d} | loss:{metrics['loss']:.3f} | acc:{metrics['acc']:.3f}")


epoch0-iter:  0 | loss:2.558 | acc:0.128
epoch0-iter: 15 | loss:0.682 | acc:0.503
epoch0-iter: 30 | loss:0.607 | acc:0.581
epoch0-iter: 45 | loss:0.481 | acc:0.688
epoch0-iter: 60 | loss:0.455 | acc:0.703
epoch0-iter: 75 | loss:0.375 | acc:0.760
epoch0-iter: 90 | loss:0.331 | acc:0.789
epoch0-iter:105 | loss:0.280 | acc:0.839
epoch0-iter:120 | loss:0.208 | acc:0.891
epoch0-iter:135 | loss:0.186 | acc:0.891
epoch0-iter:150 | loss:0.154 | acc:0.938
epoch1-iter:  0 | loss:0.132 | acc:0.922
epoch1-iter: 15 | loss:0.110 | acc:0.953
epoch1-iter: 30 | loss:0.098 | acc:0.943
epoch1-iter: 45 | loss:0.101 | acc:0.953
epoch1-iter: 60 | loss:0.104 | acc:0.940
epoch1-iter: 75 | loss:0.131 | acc:0.924
epoch1-iter: 90 | loss:0.078 | acc:0.961
epoch1-iter:105 | loss:0.095 | acc:0.966
epoch1-iter:120 | loss:0.052 | acc:0.974
epoch1-iter:135 | loss:0.064 | acc:0.969
epoch1-iter:150 | loss:0.051 | acc:0.969
epoch2-iter:  0 | loss:0.058 | acc:0.971
epoch2-iter: 15 | loss:0.063 | acc:0.964
epoch2-iter: 30 

In [7]:
# Display the parameter tree held by the train state after the training cell.
# NOTE(review): this prints every array in full; a shape-only view such as
# tree_map(jnp.shape, state.params) would keep the output readable -- consider switching.
state.params

FrozenDict({
    params: {
        Block_0: {
            Attention_0: {
                Dense_0: {
                    bias: DeviceArray([-1.94505267e-02,  3.70444320e-02,  7.44959489e-02,
                                 -2.32661236e-02, -1.09914191e-01, -1.61355454e-02,
                                 -8.61940607e-02,  1.20265014e-01, -5.63694350e-02,
                                  1.53210863e-01, -1.11941472e-02, -1.84274465e-02,
                                 -2.30767787e-03,  5.95009886e-02,  4.72199991e-02,
                                  5.01606353e-02,  3.22466083e-02,  4.72029746e-02,
                                  3.63318063e-02,  1.28932735e-02, -1.31842587e-02,
                                  3.43464576e-02,  1.14941411e-01, -1.39205735e-02,
                                  1.79031268e-02,  8.01579878e-02,  3.10795158e-02,
                                  2.25903429e-02,  3.64684612e-02,  7.50186518e-02,
                                 -3.82116474e-02,  4.8