In [1]:
from typing import List, Dict, Mapping, Tuple

import jax
import jax.numpy as jnp
import jax.random as jrand
import flax
import flax.linen as nn
from flax.training import train_state  # Useful dataclass to keep train state
import optax
import tensorflow as tf
import pdb
import functools

def println(*args):
  """Print every positional argument on its own line, in the order given.

  Calling it with no arguments prints nothing.
  """
  # Guard the empty case: a bare print() would emit a blank line.
  if args:
    print(*args, sep="\n")


In [2]:
# List the contents of the current working directory.
%ls

[0m[01;34msample_data[0m/


In [3]:
# Configure JAX for the Colab TPU runtime, then display the devices JAX reports.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
# Last expression of the cell, so the device list is shown as the cell output.
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [4]:
# Cache how many devices jax.devices() reports; the bare name below displays the value.
# NOTE(review): the later cells visible here call jax.device_count() directly instead of
# using this constant.
DEVICE_COUNT = len(jax.devices())
DEVICE_COUNT

8

In [14]:
# NOTE(review): this listing shows repository files, whereas the first `%ls` showed only
# sample_data/, and the execution count jumps from 4 to 14. The cell that fetched the
# repo / changed directory is not in the notebook (presumably a clone followed by `%cd`);
# restore it so the `import dataset` / `import model` cell works after Restart & Run All.
%ls

[0m[01;34massets[0m/  dataset.py  model.py                   nanoGPT_singe_file.ipynb  README.md
[01;34mdata[0m/    LICENSE     nanoGPT_JAX_JAX_JAX.ipynb  nanoGPT_train.ipynb       trainer.py


In [16]:
import importlib

import dataset
import model

# Re-import the local modules so edits made to dataset.py / model.py after the first
# import take effect without restarting the kernel.
importlib.reload(dataset)
importlib.reload(model)

<module 'model' from '/content/nanoGPT-JAX-JAX-JAX/model.py'>

In [None]:
from typing import Tuple

import chex
from chex._src import fake
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state
from dataclasses import dataclass
import importlib
import pdb

# import dataset
# import model

class TrainState(train_state.TrainState):
    """Flax `TrainState` extended with a PRNG key field.

    The key is supplied when the state is built, via `TrainState.create(..., key=...)`.
    """
    # PRNG key carried alongside the parameters and optimizer state.
    # NOTE(review): `jax.random.KeyArray` has been deprecated and then removed in recent
    # JAX releases; confirm the pinned JAX version, or annotate with `jax.Array` instead.
    key: jax.random.KeyArray

@dataclass
class Config:
    """Hyperparameters shared by the model setup, the dataset and the training loop."""
    # Global batch size: passed to dataset.Dataset and divided by jax.device_count()
    # to get the per-device batch shape.
    BATCH_SIZE: int = 256
    # Sequence length: passed to the model as `T` and to the dataset as `block_size`.
    BLOCK_SIZE: int = 64
    # Same value as BLOCK_SIZE.
    # NOTE(review): appears redundant with BLOCK_SIZE — confirm, and prefer a single
    # source of truth so the two cannot drift apart.
    T: int = 64
    # Forwarded to model.LanguageModel(n_embed=...); presumably the embedding width.
    n_embed: int = 256
    # Forwarded to model.LanguageModel(num_heads=...).
    num_heads: int = 8
    # Forwarded to model.LanguageModel(num_layers=...).
    num_layers: int = 6

config = Config()

# Root PRNG key; the training loop splits it to obtain a fresh dropout key per step.
random_key = jax.random.PRNGKey(99)

# Build the language model.
# NOTE(review): vocab_size=65 is hard-coded — presumably the character vocabulary of the
# dataset; confirm against dataset.py.
lm_model = model.LanguageModel(vocab_size=65,
                      n_embed=config.n_embed,
                      T=config.BLOCK_SIZE,
                      num_heads=config.num_heads,
                      num_layers=config.num_layers)

# Dummy un-batched token sequence used only to initialize the model. Its length is taken
# from BLOCK_SIZE, the same field the model's `T` is built from, so the init input can
# never disagree with the sequence length the model was constructed for.
sample_block_of_tokens = jnp.ones(shape=(config.BLOCK_SIZE,), dtype=jnp.int32)

# Initialize the variables (and keep the forward output of the dummy input in `output`).
# The init key deliberately keeps the original seed (99, same as `random_key`) so the
# initial parameters are unchanged.
output, params = lm_model.init_with_output(jax.random.PRNGKey(99), sample_block_of_tokens, training=False)
# Keep only the "params" collection of the returned variables.
params = params["params"]

def model_apply(params, inputs, training, dropout_key):
    """Run `lm_model` on `inputs` using the given parameters.

    Args:
        params: parameter tree; wrapped here under the "params" collection.
        inputs: passed to the model unchanged.
        training: forwarded positionally, right after `inputs`.
        dropout_key: PRNG key supplied as the 'dropout' RNG stream.

    Returns:
        Whatever `lm_model.apply` returns for these arguments.
    """
    variables = {"params": params}
    rng_streams = {'dropout': dropout_key}
    return lm_model.apply(variables, inputs, training, rngs=rng_streams)

# Batched variant of model_apply: only `inputs` is mapped (over its leading axis).
# params, the training flag and the dropout key are passed unmapped, so every example
# in a batch receives the same dropout key.
model_apply_batch = jax.vmap(
    model_apply,
    in_axes=(None, 0, None, None),
    out_axes=0,
)

# Leading dimension expected for each device's slice of a batch.
# NOTE(review): despite the name, this divides by the device count, not a host count.
PER_HOST_BATCH_SIZE = config.BATCH_SIZE // jax.device_count()

# Define forward pass
def forward_pass(params, state, batch, dropout_key):
    """Compute the mean cross-entropy loss for one per-device batch.

    Args:
        params: parameters to evaluate; this is the argument train_step differentiates.
        state: train state; only its `apply_fn` is used here.
        batch: `(inputs, targets)` pair, each asserted to have shape
            `(PER_HOST_BATCH_SIZE, config.BLOCK_SIZE)`.
        dropout_key: PRNG key forwarded to `apply_fn`.

    Returns:
        Mean of the integer-label softmax cross-entropy over the whole batch.
    """
    tokens, labels = batch
    # The model is always run in training mode here (third argument is True).
    predictions = state.apply_fn(params, tokens, True, dropout_key)

    expected_shape = (PER_HOST_BATCH_SIZE, config.BLOCK_SIZE)
    chex.assert_shape(tokens, expected_shape)
    chex.assert_shape(labels, expected_shape)

    per_position_loss = optax.softmax_cross_entropy_with_integer_labels(predictions, labels)
    return per_position_loss.mean()

# Define training step
def train_step(state, inputs, targets, dropout_key):
    """Run one optimizer step; must be called under a pmap whose axis is named "devices".

    Args:
        state: train state holding `params`, `step` and the optimizer.
        inputs: this device's slice of the input batch.
        targets: this device's slice of the target batch.
        dropout_key: base PRNG key; the current step is folded into it before use.

    Returns:
        `(new_state, loss)` where `loss` is the cross-device mean loss.
    """
    # Mix the step counter into the key before handing it to the forward pass.
    step_key = jax.random.fold_in(key=dropout_key, data=state.step)

    loss_and_grad = jax.value_and_grad(forward_pass, argnums=0)
    local_loss, local_grads = loss_and_grad(state.params, state, (inputs, targets), step_key)

    # Average over the "devices" axis so every replica applies the same update.
    mean_loss = jax.lax.pmean(local_loss, axis_name="devices")
    mean_grads = jax.lax.pmean(local_grads, axis_name="devices")

    return state.apply_gradients(grads=mean_grads), mean_loss

# Initialize optimizer and training state
opt = optax.adam(learning_rate=0.0001)
state = TrainState.create(apply_fn=model_apply_batch, params=params, tx=opt, key=random_key)
data = dataset.Dataset(batch_size=config.BATCH_SIZE, block_size=config.BLOCK_SIZE)

# Parallelize train_step across devices: state, inputs and targets are split along their
# leading (device) axis, while the dropout key is broadcast to every device.
# pmap compiles the function itself, so it is NOT wrapped in jax.jit: jit-of-pmap makes
# the outer jit gather sharded inputs/outputs onto a single device and JAX warns about it.
train_step_pmap = jax.pmap(train_step, in_axes=(0, 0, 0, None), out_axes=0, axis_name="devices")
# One copy of the train state per local device, stacked along a new leading axis.
states = jax.device_put_replicated(state, jax.local_devices())



In [None]:
# Training loop.
# This cell is deliberately **impure** (it rebinds `random_key` and `states`, calls
# methods on `data`, and prints), so keep it as plain Python. Don't JIT it.
reload_libs = False
if reload_libs:
  importlib.reload(dataset)
  importlib.reload(model)


# Debug hooks (disabled): patch pmap/jit with chex fakes to step through train_step.
# fake_pmap = chex.fake_pmap_and_jit(enable_jit_patching=fake_jit, enable_pmap_patching=fake_pmap)
# fake_pmap.start()
num_epochs = 1 # 20
steps_per_epoch = 1 # len(data.train_data) // config.BATCH_SIZE
# Print the loss once every this many steps within an epoch.
log_every_n_steps = 510
for epoch in range(num_epochs):
  print("epoch: ", epoch)
  data.create_train_dataset()

  for step in range(steps_per_epoch):
    # Fresh subkey per step; train_step uses it as the dropout key.
    random_key, random_subkey = jax.random.split(random_key)

    inputs, targets = data.get_batch()

    # create device dimension for minibatch
    inputs = inputs.reshape((jax.device_count(), -1, inputs.shape[-1]))
    targets = targets.reshape((jax.device_count(), -1, targets.shape[-1]))

    states, loss = train_step_pmap(states, inputs, targets, random_subkey)
    # Gate on the step counter: gating on `epoch` printed on every step of epoch 0 and
    # never again. loss is replicated across devices, so element 0 is representative.
    if step % log_every_n_steps == 0:
      print("loss", loss[0], "epoch", epoch)

# fake_pmap.stop()

epoch:  0
loss 3.9624007 epoch 0
loss 3.88149 epoch 0
loss 3.8095543 epoch 0
loss 3.7548494 epoch 0
loss 3.7914453 epoch 0
loss 3.7002654 epoch 0
loss 3.6340194 epoch 0
loss 3.6558213 epoch 0
loss 3.6011715 epoch 0
loss 3.61494 epoch 0
loss 3.6759403 epoch 0
loss 3.595214 epoch 0
loss 3.6045206 epoch 0
loss 3.6529965 epoch 0


In [None]:
# Context window length used when feeding the model.
T = config.BLOCK_SIZE

# Take replica 0 of the replicated train state; the replicas are kept in sync by the
# pmean over gradients in train_step.
state = jax.tree_util.tree_map(lambda x: x[0], states)

# state.apply_fn is the vmapped model_apply(params, inputs, training, dropout_key).
# `training` (positional index 2) is a plain Python flag, so it is marked static to keep
# it a concrete value under jit.
state_apply_jit = jax.jit(state.apply_fn, static_argnums=(2,))

# Start from a single sequence (batch of one) made of T copies of token id 52.
context = jnp.tile(jnp.array([52], dtype=jnp.int32), T)
context = context[None, -T:]
key = jrand.PRNGKey(99)
# model_apply always passes a 'dropout' rng to the model, so a key must be supplied even
# though the model is run with training=False here.
key, eval_dropout_key = jrand.split(key)

for _ in range(100):
  # apply_fn wraps the parameters in {"params": ...} itself, so pass the raw tree and
  # all four positional arguments. Only the last T tokens are fed to the model.
  next_token_logits = state_apply_jit(state.params, context[:, -T:], False, eval_dropout_key)

  # Sample with the freshly split subkey; `key` itself is only ever split, never consumed.
  key, split_key = jrand.split(key)
  new_token = jax.random.categorical(split_key, next_token_logits[:, -1, :], axis=-1, shape=(1, 1))

  context = jnp.concatenate([context, new_token], axis=1)


# Token ids of the single generated sequence, including the seed context.
print(context.tolist()[0])