In [2]:
import math
import functools

import jax
import jax.numpy as jnp 
from jax import random
from jax.tree_util import tree_map

import flax
import flax.linen as nn

import numpy as np

from general.utils import ModelConfig, TrainConfig
from jax_impl.model import *
import optax

%reload_ext autoreload
%autoreload 2

In [3]:
# Root PRNG key; a later cell splits per-purpose keys off it.
key = random.PRNGKey(42)

# Model configuration for the smoke tests below: 1024-token vocabulary,
# 64-position context window, 'gpt-nano' preset.
# NOTE(review): the saved output of this cell is
# "FrozenInstanceError: cannot assign to field 'config_set'" -- presumably raised
# inside ModelConfig (defined in general.utils, not visible here); confirm and fix there.
model_config = ModelConfig(
    vocab_size=1024, 
    context_window=64, 
    model_type='gpt-nano')
model_config

FrozenInstanceError: cannot assign to field 'config_set'

In [None]:
# Split the root key into an input key, a params key and a dropout key,
# keeping the fourth piece as the new root key.
rx, rp, rd, key = random.split(key, 4)
init_rngs = {'params': rp, 'dropout': rd}

# Random float input of shape (10, 4, 48).
# NOTE(review): 48 is presumably the gpt-nano embedding width -- confirm against ModelConfig.
x = random.normal(rx, (10, 4, 48))

# Smoke test: initialise the Attention module and run one forward pass.
attn = Attention(model_config)
params = attn.init(init_rngs, x)

# NOTE(review): `rd` is supplied as the dropout key both at init and at apply time.
out = attn.apply(params, x, rngs={'dropout': rd})
out.shape

In [None]:
# Same smoke test for a Block: initialise on `x`, then run one forward pass.
# Note that `params` is rebound here and no longer holds the Attention parameters.
blk = Block(model_config)
params = blk.init(init_rngs, x)

output = blk.apply(params, x, rngs={'dropout': rd})
output.shape

In [None]:
# Random token ids of shape (10, 4) in [0, vocab_size) for the Embedding / GPT smoke tests.
# `rx` was already consumed to sample `x` in an earlier cell, so derive a distinct key
# from it instead of reusing it (a JAX key should only be consumed once).
idx_key = random.fold_in(rx, 1)
# Request a concrete integer dtype: `jnp.integer` is an abstract scalar type, not a real dtype.
idx = random.randint(idx_key, (10, 4), 0, maxval=model_config.vocab_size, dtype=jnp.int32)

# Smoke test: initialise the Embedding module on the token ids and run one forward pass.
emb = Embedding(model_config)
params = emb.init(init_rngs, idx)

output = emb.apply(params, idx, rngs={'dropout': rd})
output.shape

In [None]:
# Smoke test for the full GPT model on the integer token ids `idx`:
# initialise the parameters, then run one forward pass.
gpt = GPT(model_config)
params = gpt.init(init_rngs, idx)

output = gpt.apply(params, idx, rngs={'dropout': rd})
output.shape

In [None]:
from general.dataset import SortDataset
from torch.utils.data import DataLoader

def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays (drop-in ``collate_fn`` for ``DataLoader``).

    * samples that are ``np.ndarray`` are stacked along a new leading axis;
    * samples that are tuples/lists are collated field by field, returning a
      list with one collated entry per field;
    * anything else is handed to ``np.array`` as-is.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    if isinstance(first, (tuple, list)):
        # zip(*batch) regroups the samples so each field is collated separately.
        return list(map(numpy_collate, zip(*batch)))
    return np.array(batch)

def cast(x):
    """Return a copy of ``x`` as a NumPy array with the platform default integer dtype."""
    as_int_array = np.array(x, dtype=int)
    return as_int_array

In [None]:
# Training split of the sort task, with `cast` passed as the dataset transform
# (presumably applied per sample to turn it into an integer NumPy array -- TODO confirm in SortDataset).
train_dataset = SortDataset('train', num_digits=6, transform=cast)
# Batches of 32, reshuffled on each pass; numpy_collate builds NumPy batches.
# NOTE(review): no torch seed is set in the visible cells, so the batch order is presumably not reproducible.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=numpy_collate)

In [None]:
# Pull a single batch from the loader and unpack it into inputs and targets.
# Note: this rebinds `x`, which earlier held the random float input of the module smoke tests.
batch = next(iter(train_loader))
x, y = batch
x.shape

In [None]:
# Inspect one (input, target) pair from the batch.
x[1], y[1]

In [None]:
# Rebind `model_config` for the sort task: 6-token vocabulary, 16-position context window.
# The earlier 1024-token configuration is no longer reachable after this cell.
model_config = ModelConfig(
    vocab_size=6, 
    context_window=16, 
    model_type='gpt-nano')

In [None]:
# Build the GPT model for the sort-task configuration.
gpt = GPT(model_config)
# Initialise from the real batch `x` that the model is applied to below.
# The stale `idx` array was sampled for the earlier 1024-token configuration
# (shape (10, 4), ids up to 1023) and matches neither this vocabulary nor this data.
params = gpt.init(init_rngs, x)

# Forward pass on the batch inputs; `logits` is inspected in the following cells.
logits = gpt.apply(params, x, rngs={'dropout': rd})

In [None]:
# Shape of the model output for the batch.
logits.shape

In [None]:
# Shape of the targets, for comparison with the logits.
y.shape

In [None]:
# Pick, for every position, the logit at the target index; the result keeps a trailing axis of size 1.
# NOTE(review): if `y` holds the ignore value -1, that index presumably selects the last class -- confirm.
jnp.take_along_axis(logits, y[...,None], axis=-1).shape

In [None]:
# Gather the logit of each target class and drop only the trailing gather axis.
# A bare .squeeze() would also remove any other size-1 axis (e.g. a batch of one).
jnp.take_along_axis(logits, y[..., None], axis=-1).squeeze(-1).shape

In [None]:
# Target shape once more, as a reference for the boolean-mask experiments below.
y.shape

In [None]:
# Shape of the logits that remain after dropping positions whose target equals -1.
logits[y!= -1].shape

In [None]:
# Shape of the targets that remain after dropping the -1 entries.
y[y!=-1].shape

In [None]:
def loss_fn(logits, labels, ignore_index=-1):
    """Mean softmax cross-entropy over the positions whose label is not ``ignore_index``.

    Args:
        logits: unnormalised class scores of shape ``labels.shape + (num_classes,)``.
        labels: integer class ids; entries equal to ``ignore_index`` are excluded.
        ignore_index: label value that marks positions to leave out of the loss.

    Returns:
        Scalar mean loss over the kept positions. When every position is ignored
        the result is 0.0 (boolean-mask indexing followed by ``mean`` gave NaN).
    """
    keep = labels != ignore_index
    # Boolean-mask indexing yields data-dependent shapes, which cannot be traced by
    # jax.jit / jax.grad-under-jit. Keep the shapes static instead: replace ignored
    # labels with a valid class id, compute the per-position loss, then zero them out.
    safe_labels = jnp.where(keep, labels, 0)
    per_position = optax.softmax_cross_entropy_with_integer_labels(logits, safe_labels)
    kept_total = jnp.where(keep, per_position, 0.0).sum()
    # Guard the denominator so an all-ignored batch does not divide by zero.
    return kept_total / jnp.maximum(keep.sum(), 1)

In [None]:
# Loss of the just-initialised model on the single batch pulled above.
loss_fn(logits, y)

In [None]:
from jax_impl.train import create_train_state

In [None]:
# Training configuration constructed with no arguments, i.e. whatever defaults TrainConfig defines.
train_config = TrainConfig()

In [None]:
# Build the train state from the root key and the model / training configurations.
state = create_train_state(key, model_config, train_config)

In [None]:
# Display the train state.
state

In [None]:
from jax_impl.train import train_step

In [None]:
# Run one training step on the batch pulled earlier; the return value is only displayed, not stored.
# NOTE(review): `key` was already passed to create_train_state -- presumably a fresh split
# should be used here instead of reusing it; confirm against jax_impl.train.
train_step(key, state, batch, model_config)