In [10]:
from typing import List, Dict, Mapping, Tuple

import jax
import jax.numpy as jnp
import jax.random as jrand
import flax
import flax.linen as nn
from flax.training import train_state  # Useful dataclass to keep train state
import optax
import tensorflow as tf
import pdb
import functools

def println(*args):
  for arg in args:
    print(arg)


In [9]:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [11]:
DEVICE_COUNT = len(jax.devices())
DEVICE_COUNT

8

In [13]:
%ls

[0m[01;34massets[0m/  dataset.py  model.py                   nanoGPT_singe_file.ipynb  trainer.py
[01;34mdata[0m/    LICENSE     nanoGPT_JAX_JAX_JAX.ipynb  README.md


In [14]:
import importlib

import dataset
import model

importlib.reload(dataset)
importlib.reload(model)

<module 'model' from '/content/nanoGPT-JAX-JAX-JAX/model.py'>

In [34]:
import trainer
importlib.reload(trainer)

<module 'trainer' from '/content/nanoGPT-JAX-JAX-JAX/trainer.py'>

In [35]:
trainer.run_train_step()

epoch:  0




loss 4.8114243 epoch 0
loss 4.7046213 epoch 0
loss 4.775199 epoch 0
loss 4.7114058 epoch 0
loss 4.796525 epoch 0
loss 4.748788 epoch 0
loss 4.657966 epoch 0
loss 4.6633005 epoch 0
loss 4.558431 epoch 0
loss 4.59794 epoch 0
loss 4.586134 epoch 0
loss 4.82765 epoch 0
loss 4.6251516 epoch 0
loss 4.699952 epoch 0
loss 4.5693865 epoch 0
loss 4.636324 epoch 0
loss 4.5907664 epoch 0
loss 4.505391 epoch 0
loss 4.536727 epoch 0
loss 4.671359 epoch 0
loss 4.4616547 epoch 0
loss 4.5774136 epoch 0
loss 4.526718 epoch 0
loss 4.4867578 epoch 0
loss 4.617264 epoch 0
loss 4.442016 epoch 0
loss 4.401455 epoch 0
loss 4.75319 epoch 0
loss 4.6726503 epoch 0
loss 4.459524 epoch 0
loss 4.375473 epoch 0
loss 4.552512 epoch 0
loss 4.3940134 epoch 0
loss 4.3297033 epoch 0
loss 4.4734635 epoch 0
loss 4.402542 epoch 0
loss 4.457515 epoch 0
loss 4.364932 epoch 0
loss 4.246636 epoch 0
loss 4.3672915 epoch 0
loss 4.479223 epoch 0
loss 4.3274307 epoch 0
loss 4.3152285 epoch 0
loss 4.370915 epoch 0
loss 4.1599874 epo

KeyboardInterrupt: 

## pmapping

## Verify using flax multihead attention

In [None]:
def compare_attention_outputs(custom_attention, flax_attention, input_shape, num_heads, head_size, rng_key):
    # Create dummy input
    x = jax.random.normal(rng_key, input_shape)

    # Initialize custom attention
    custom_params = custom_attention.init(rng_key, x, training=True)
    custom_output = custom_attention.apply(custom_params, x, training=True, rngs={'dropout': rng_key})

    # Initialize Flax attention
    flax_params = flax_attention.init(rng_key, x, x, x)
    flax_output = flax_attention.apply(flax_params, x, x, x)

    print("custom_output: ", custom_output)
    print("flax_output: ", flax_output)

    # Compare outputs
    return jnp.isclose(custom_output, flax_output, atol=1e-5).all()

In [None]:
rng_key = jax.random.PRNGKey(0)
input_shape = (1, 2, 4)  # (batch_size, sequence_length, feature_size)
num_heads = 4
head_size = 16

In [None]:
# Custom attention
custom_attention = model.MultiHeadAttentionBatch(num_heads=num_heads, head_size=head_size, T=input_shape[1])

# Flax attention
flax_attention = nn.MultiHeadDotProductAttention(num_heads=num_heads, qkv_features=head_size * num_heads, out_features=head_size * num_heads)


In [None]:
result = compare_attention_outputs(custom_attention, flax_attention, input_shape, num_heads, head_size, rng_key)
print("Are the attention outputs close?", result)

custom_output:  [[[ 1.8267553   0.2545625   0.51664734  1.872377    0.
   -0.24741195  0.06325258 -0.4732108   0.97036505  0.
    1.2170275   1.5578686   1.0638489   0.          2.772833
   -0.41109625  0.444793   -0.08247733  0.         -0.18067323
    0.          0.          0.6723873  -0.93943655 -0.3522747
    1.2153784  -3.7089698   1.3073872  -0.6657839  -0.5994085
   -0.33070773 -1.8484493   0.37312767  0.44226554  0.60474485
    2.2404766   0.         -1.8605132  -2.4844682  -0.56995404
   -0.1442299   1.2074916  -0.11788648  2.850931    0.33974466
    2.3744946  -2.746928    0.685969   -0.92724115 -1.0124649
    0.          0.          1.3646483   0.4259958   1.1758763
   -0.8295348   0.3146336   0.38039386 -1.96878    -1.0014266
    0.88716567  1.783647    0.57467306  0.        ]
  [ 0.00777232  0.8190602   2.6580398   1.651423   -0.9469865
    0.48011455 -0.9287533   0.          0.         -2.8874931
    0.60840005  2.0658875   0.35624415  0.          0.
    0.70599437  0.58