In [1]:
import math
import functools

import jax
import jax.numpy as jnp 
from jax import random
from jax.tree_util import tree_map

import flax
import flax.linen as nn

import numpy as np

from jax_impl.model import GPT
import optax

from jax_impl.config import Config

%reload_ext autoreload
%autoreload 2

In [2]:
# Root PRNG key for the notebook, created from a fixed seed; later cells derive subkeys from it.
key = random.PRNGKey(42)

In [3]:
# Split the root key three ways: one subkey for parameter init (`rp`), one for
# dropout (`rd`), and a replacement `key` for later cells.
# Note: `key` is rebound here, so re-running this cell yields different subkeys.
rp, rd, key = random.split(key, num=3)
# RNG streams keyed by collection name, as passed to `gpt.init` further down.
init_rngs = dict(params=rp, dropout=rd)

In [4]:
from dataset import SortDataset
from torch.utils.data import DataLoader

def numpy_collate(batch):
    """Collate a sequence of samples into NumPy arrays.

    Used in this notebook as the DataLoader ``collate_fn``.

    Args:
        batch: Non-empty sequence of samples. Each sample is either a NumPy
            array, a tuple/list of fields, or any value ``np.array`` accepts.

    Returns:
        - ``np.stack(batch)`` when the samples are NumPy arrays;
        - a list with one collated entry per field when the samples are
          tuples/lists (each field is collated recursively);
        - ``np.array(batch)`` otherwise.
    """
    first = batch[0]

    if isinstance(first, np.ndarray):
        return np.stack(batch)

    if isinstance(first, (tuple, list)):
        # zip(*batch) regroups [(a0, b0), (a1, b1), ...] into (a0, a1, ...), (b0, b1, ...)
        # so that each field is collated on its own.
        return [numpy_collate(field) for field in zip(*batch)]

    return np.array(batch)

def cast(x):
    """Return a new NumPy array built from ``x`` with dtype ``int``.

    Passed as the ``transform`` argument of ``SortDataset`` in this notebook.

    Args:
        x: Any array-like accepted by ``np.array``.

    Returns:
        A freshly allocated ``np.ndarray`` of the default integer dtype.
    """
    as_int = np.array(x, dtype=int)
    return as_int

In [6]:
# Build the configuration with its default values; the bare `cfg` echoes every field in the output below.
cfg = Config()
cfg

Config(context_window=64, n_embd=48, n_head=3, n_layer=3, dropout_prob=0.1, vocab_size=8, sequence_len=6, batch_size=32, learning_rate=0.0005, max_iter=2000)

In [7]:
# 'train' split of SortDataset, sized from the config; `cast` is handed over as the
# dataset's transform.
train_dataset = SortDataset(
    'train',
    length=cfg.sequence_len,
    num_digits=cfg.vocab_size,
    transform=cast,
)
# Shuffled mini-batches, collated with `numpy_collate` so batches come out as
# NumPy arrays rather than torch tensors.
train_loader = DataLoader(
    train_dataset,
    batch_size=cfg.batch_size,
    shuffle=True,
    collate_fn=numpy_collate,
)

In [8]:
# Pull a single mini-batch from the loader to inspect what it yields.
batch = next(iter(train_loader))
# The batch unpacks into two arrays; the recorded outputs show both with shape (32, 11).
x, y = batch
x.shape

(32, 11)

In [9]:
# Show row 1 of `x` and `y` side by side.
# NOTE(review): in the recorded output the first five entries of y[1] are -1 — presumably
# positions that are excluded from the loss; confirm in dataset.py.
x[1], y[1]

(array([4, 2, 1, 0, 5, 0, 0, 0, 1, 2, 4]),
 array([-1, -1, -1, -1, -1,  0,  0,  1,  2,  4,  5]))

In [10]:
# Build the model and initialise its variables, using the sample batch `x` as the example input.
gpt = GPT(cfg)
params = gpt.init(init_rngs, x)

# Smoke-test a forward pass.
# `rd` was already handed to `init` through `init_rngs`; JAX PRNG keys should not be
# reused, so derive a distinct dropout key from it. `fold_in` is used instead of
# `split` so that the notebook-level `key` consumed by later cells is left unchanged.
logits = gpt.apply(params, x, rngs={'dropout': random.fold_in(rd, 1)})

In [11]:
# Recorded output is (32, 11, 8): the last dimension equals cfg.vocab_size (8) and the
# first two match x.shape — presumably (batch, position, vocab); confirm in jax_impl.model.
logits.shape

(32, 11, 8)

In [12]:
# Recorded output is (32, 11), i.e. the same leading dimensions as `logits`.
y.shape

(32, 11)

In [15]:
from jax_impl.train import create_train_state

In [16]:
# Create the training state (displayed as a TrainState in the next cell) from the PRNG key and config.
# NOTE(review): the same `key` is passed again to train_step further down; confirm the
# callees split it internally, or derive separate subkeys here, to avoid key reuse.
state = create_train_state(key, cfg)

In [17]:
# NOTE(review): echoing `state` prints the whole TrainState repr, parameter arrays
# included; a shape-only summary would keep the output readable.
state

TrainState(step=0, apply_fn=<bound method Module.apply of GPT(
    # attributes
    cfg = Config(context_window=64, n_embd=48, n_head=3, n_layer=3, dropout_prob=0.1, vocab_size=8, sequence_len=6, batch_size=32, learning_rate=0.0005, max_iter=2000)
)>, params=FrozenDict({
    params: {
        Embedding_0: {
            Embed_0: {
                embedding: DeviceArray([[ 0.01243785, -0.02221239,  0.01881308, ..., -0.05130462,
                              -0.01304357, -0.00700958],
                             [ 0.0071265 ,  0.02709752, -0.03665798, ...,  0.03381928,
                              -0.0296815 ,  0.00049163],
                             [-0.00014662, -0.01876912,  0.00307894, ..., -0.02393069,
                              -0.00483473, -0.01081953],
                             ...,
                             [-0.01256307, -0.02156094, -0.00831418, ..., -0.00836381,
                               0.01899023,  0.04095865],
                             [-0.00976414,  0.0

In [23]:
from jax_impl.train import train_step

In [24]:
# Run one training step on the sample batch.
# NOTE(review): this call fails with the TypeError recorded below ("Shapes must be 1D
# sequences of concrete values of integer type ... try using `static_argnums`").
# Presumably train_step is jit-compiled and builds a shape from a traced argument
# (likely a field of `cfg`); the fix would live in jax_impl.train — confirm there.
train_step(key, state, batch, cfg)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.