In [1]:
import math
import functools

import jax
import jax.numpy as jnp
from jax import random
from jax.tree_util import tree_map, tree_flatten, tree_leaves, tree_unflatten

import flax
import flax.linen as nn

import numpy as np
import optax
import torch

from jax_impl.config import Config
from jax_impl.model import GPT
from jax_impl.train import create_train_state, train_step, create_weight_decay_mask

%reload_ext autoreload
%autoreload 2

In [2]:
from dataset import SortDataset
from torch.utils.data import DataLoader

def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays (instead of torch tensors).

    - If the samples are ndarrays, they are stacked along a new leading axis.
    - If the samples are tuples/lists, they are transposed field-wise and each
      field is collated recursively; the result is a list with one entry per field.
    - Anything else (e.g. Python scalars) is wrapped with ``np.array``.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    if not isinstance(first, (tuple, list)):
        return np.array(batch)
    # zip(*batch) regroups [(x0, y0), (x1, y1), ...] into (x0, x1, ...), (y0, y1, ...)
    return [numpy_collate(field) for field in zip(*batch)]

def cast(x):
    """Return a new NumPy array of dtype ``int`` built from ``x``.

    Passed as SortDataset's ``transform`` so samples come out as integer ndarrays.
    """
    as_array = np.asarray(x)
    # astype copies by default, so the result never aliases the input
    return as_array.astype(int)

In [3]:
# Single seed for every RNG used in this notebook.
SEED = 42

cfg = Config()
key = random.PRNGKey(SEED)

# With no explicit `generator`, DataLoader(shuffle=True) draws its shuffle seed
# from torch's global RNG, so seed it to make the batch order reproducible.
# NOTE(review): if SortDataset samples from numpy / `random` rather than torch,
# those RNGs need seeding as well — TODO confirm.
torch.manual_seed(SEED)

train_dataset = SortDataset(
    'train',
    length=cfg.sequence_len,
    num_digits=cfg.vocab_size,
    transform=cast,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=cfg.batch_size,
    shuffle=True,
    collate_fn=numpy_collate,  # yield NumPy arrays, which JAX consumes directly
)

In [4]:
# Batches per epoch: len() of a batched DataLoader counts batches, and the last
# partial batch is kept because drop_last is left at its default (False).
n_batches = len(train_loader)
n_batches

157

In [5]:
# Split the root key so parameter initialisation and the per-step RNG come from
# independent streams; reusing `key` for both correlates them.
init_key, dropout_key = random.split(key)
state = create_train_state(init_key, cfg)

# `cfg` is static: it is treated as a compile-time constant (a new value triggers
# a recompile) and must be hashable — presumably Config is a frozen dataclass,
# TODO confirm. The incoming state is donated so XLA may reuse its buffers for
# the updated state; the old `state` must not be touched after the call.
j_train_step = jax.jit(
    train_step,
    static_argnames='cfg',
    donate_argnums=(0,),
)

LOG_EVERY = 100  # print metrics every LOG_EVERY iterations within an epoch

global_step = 0
for epoch in range(cfg.n_epoch):
    for i, batch in enumerate(train_loader):
        # Derive a distinct key for every update. The 4th argument is assumed to
        # be train_step's dropout RNG (TODO confirm). Passing the same key each
        # step would replay identical random draws unless train_step mixes in
        # the step count itself.
        step_key = random.fold_in(dropout_key, global_step)
        state, metrics = j_train_step(state, batch, cfg, step_key)
        global_step += 1

        if i % LOG_EVERY == 0:
            print(f"epoch{epoch}-iter:{i:3d} | loss:{metrics['loss']:.3f} | acc:{metrics['acc']:.3f}")


epoch0-iter:  0 | loss:2.409 | acc:0.182
epoch0-iter:100 | loss:0.259 | acc:0.852
epoch1-iter:  0 | loss:0.141 | acc:0.911
epoch1-iter:100 | loss:0.065 | acc:0.979
epoch2-iter:  0 | loss:0.086 | acc:0.943
epoch2-iter:100 | loss:0.032 | acc:0.979
epoch3-iter:  0 | loss:0.028 | acc:0.984
epoch3-iter:100 | loss:0.021 | acc:0.995
epoch4-iter:  0 | loss:0.024 | acc:0.984
epoch4-iter:100 | loss:0.026 | acc:0.984
epoch5-iter:  0 | loss:0.023 | acc:0.992
epoch5-iter:100 | loss:0.011 | acc:0.995
epoch6-iter:  0 | loss:0.021 | acc:0.987
epoch6-iter:100 | loss:0.014 | acc:0.992
epoch7-iter:  0 | loss:0.015 | acc:0.990
epoch7-iter:100 | loss:0.016 | acc:0.990
epoch8-iter:  0 | loss:0.016 | acc:0.990
epoch8-iter:100 | loss:0.016 | acc:0.984
epoch9-iter:  0 | loss:0.011 | acc:0.992
epoch9-iter:100 | loss:0.012 | acc:0.995
epoch10-iter:  0 | loss:0.009 | acc:0.995
epoch10-iter:100 | loss:0.004 | acc:0.997
epoch11-iter:  0 | loss:0.005 | acc:0.997
epoch11-iter:100 | loss:0.011 | acc:0.990
