In [3]:
from typing import List, Dict, Mapping, Tuple

import jax
import jax.numpy as jnp
import jax.random as jrand
import flax
import flax.linen as nn
from flax.training import train_state  # Useful dataclass to keep train state
import optax
import tensorflow as tf
import pdb
import functools

def println(*args):
  for arg in args:
    print(arg)


In [4]:
%ls

[0m[01;34mnanoGPT-JAX-JAX-JAX[0m/  [01;34msample_data[0m/


In [8]:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
jax.devices()

[CpuDevice(id=0)]

In [9]:
DEVICE_COUNT = len(jax.devices())
DEVICE_COUNT

8

In [9]:
%ls

[0m[01;34massets[0m/  dataset.py  LICENSE   nanoGPT_JAX_JAX_JAX.ipynb  [01;34m__pycache__[0m/  trainer.py
[01;34mdata[0m/    input.txt   model.py  nanoGPT_singe_file.ipynb   README.md


In [10]:
import importlib

import dataset
import model

importlib.reload(dataset)
importlib.reload(model)

<module 'model' from '/content/nanoGPT-JAX-JAX-JAX/model.py'>

In [11]:
from typing import Tuple

import chex
from chex._src import fake
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state
from dataclasses import dataclass
import importlib
import pdb

# import dataset
# import model

class TrainState(train_state.TrainState):
    key: jax.random.KeyArray

@dataclass
class Config:
    BATCH_SIZE: int = 256
    BLOCK_SIZE: int = 64
    T: int = 64
    n_embed: int = 256
    num_heads: int = 8
    num_layers: int = 6

config = Config()

random_key = jax.random.PRNGKey(99)

# Initialize model
lm_model = model.LanguageModel(vocab_size=65,
                      n_embed=config.n_embed,
                      T=config.BLOCK_SIZE,
                      num_heads=config.num_heads,
                      num_layers=config.num_layers)
sample_block_of_tokens = jnp.ones(shape=(config.T,), dtype=jnp.int32)
output, params = lm_model.init_with_output(jax.random.PRNGKey(99), sample_block_of_tokens, training=False)
params = params["params"]

def model_apply(params, inputs, training, dropout_key):
    return lm_model.apply({"params": params}, inputs, training, rngs={'dropout': dropout_key})

# Vectorize model apply function
model_apply_batch = jax.vmap(model_apply, in_axes=(None, 0, None, None), out_axes=(0))

PER_HOST_BATCH_SIZE = config.BATCH_SIZE // jax.device_count()

# Define forward pass
def forward_pass(params, state, batch, dropout_key):
    inputs, targets = batch
    logits = state.apply_fn(params, inputs, True, dropout_key)

    chex.assert_shape(inputs, (PER_HOST_BATCH_SIZE, config.BLOCK_SIZE))
    chex.assert_shape(targets, (PER_HOST_BATCH_SIZE, config.BLOCK_SIZE))

    loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
    loss = loss.mean()
    return loss

# Define training step
def train_step(state, inputs, targets, dropout_key):
    dropout_key = jax.random.fold_in(key=dropout_key, data=state.step)

    batch = inputs, targets

    grad_fn = jax.value_and_grad(forward_pass, argnums=(0))
    loss, grads = grad_fn(state.params, state, batch, dropout_key)

    loss = jax.lax.pmean(loss, axis_name="devices")
    grads = jax.lax.pmean(grads, axis_name="devices")

    state = state.apply_gradients(grads=grads)
    return state, loss

# Initialize optimizer and training state
opt = optax.adam(learning_rate=0.0001)
state = TrainState.create(apply_fn=model_apply_batch, params=params, tx=opt, key=random_key)
data = dataset.Dataset(batch_size=config.BATCH_SIZE, block_size=config.BLOCK_SIZE)

# pmap the train_step.
train_step_pmap = jax.jit(jax.pmap(train_step, in_axes=(0, 0, 0, None), out_axes=(0), axis_name="devices"))
states = jax.device_put_replicated(state, jax.local_devices())



  data = jnp.array(_encode(text, self.stoi), dtype=jnp.int64)


In [13]:
# Function to run a training step
# This is an **IMPURE function** for convenience. Don't JIT it.
reload_libs = False
if reload_libs:
  importlib.reload(dataset)
  importlib.reload(model)


# fake_pmap = chex.fake_pmap_and_jit(enable_jit_patching=fake_jit, enable_pmap_patching=fake_pmap)
# fake_pmap.start()
num_epochs = 20
steps_per_epoch = 1 # len(data.train_data) // config.BATCH_SIZE
for epoch in range(num_epochs):
  print("epoch: ", epoch)
  data.create_train_dataset()

  for step in range(steps_per_epoch):
    random_key, random_subkey = jax.random.split(random_key)

    inputs, targets = data.get_batch()
    # pdb.set_trace()

    # create device dimension for minibatch
    inputs = inputs.reshape((jax.device_count(), -1, inputs.shape[-1]))
    targets = targets.reshape((jax.device_count(), -1, targets.shape[-1]))

    # pdb.set_trace()

    states, loss = train_step_pmap(states, inputs, targets, random_subkey)
    print("loss", loss[0], "epoch", epoch) if epoch % 1 == 0 else None

# fake_pmap.stop()

epoch:  0
loss 4.233783 epoch 0


## pmapping

## Verify using flax multihead attention

In [None]:
def compare_attention_outputs(custom_attention, flax_attention, input_shape, num_heads, head_size, rng_key):
    # Create dummy input
    x = jax.random.normal(rng_key, input_shape)

    # Initialize custom attention
    custom_params = custom_attention.init(rng_key, x, training=True)
    custom_output = custom_attention.apply(custom_params, x, training=True, rngs={'dropout': rng_key})

    # Initialize Flax attention
    flax_params = flax_attention.init(rng_key, x, x, x)
    flax_output = flax_attention.apply(flax_params, x, x, x)

    print("custom_output: ", custom_output)
    print("flax_output: ", flax_output)

    # Compare outputs
    return jnp.isclose(custom_output, flax_output, atol=1e-5).all()

In [None]:
rng_key = jax.random.PRNGKey(0)
input_shape = (1, 2, 4)  # (batch_size, sequence_length, feature_size)
num_heads = 4
head_size = 16

In [None]:
# Custom attention
custom_attention = model.MultiHeadAttentionBatch(num_heads=num_heads, head_size=head_size, T=input_shape[1])

# Flax attention
flax_attention = nn.MultiHeadDotProductAttention(num_heads=num_heads, qkv_features=head_size * num_heads, out_features=head_size * num_heads)


In [None]:
result = compare_attention_outputs(custom_attention, flax_attention, input_shape, num_heads, head_size, rng_key)
print("Are the attention outputs close?", result)

custom_output:  [[[ 1.8267553   0.2545625   0.51664734  1.872377    0.
   -0.24741195  0.06325258 -0.4732108   0.97036505  0.
    1.2170275   1.5578686   1.0638489   0.          2.772833
   -0.41109625  0.444793   -0.08247733  0.         -0.18067323
    0.          0.          0.6723873  -0.93943655 -0.3522747
    1.2153784  -3.7089698   1.3073872  -0.6657839  -0.5994085
   -0.33070773 -1.8484493   0.37312767  0.44226554  0.60474485
    2.2404766   0.         -1.8605132  -2.4844682  -0.56995404
   -0.1442299   1.2074916  -0.11788648  2.850931    0.33974466
    2.3744946  -2.746928    0.685969   -0.92724115 -1.0124649
    0.          0.          1.3646483   0.4259958   1.1758763
   -0.8295348   0.3146336   0.38039386 -1.96878    -1.0014266
    0.88716567  1.783647    0.57467306  0.        ]
  [ 0.00777232  0.8190602   2.6580398   1.651423   -0.9469865
    0.48011455 -0.9287533   0.          0.         -2.8874931
    0.60840005  2.0658875   0.35624415  0.          0.
    0.70599437  0.58