In [2]:
from typing import List, Dict, Mapping, Tuple

import jax
import jax.numpy as jnp
import jax.random as jrand
import flax.linen as nn
from flax.training import train_state  # Useful dataclass to keep train state
import optax
import tensorflow as tf
import pdb
import functools

def println(*args):
  for arg in args:
    print(arg)


In [13]:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [3]:
DEVICE_COUNT = len(jax.devices())
DEVICE_COUNT

In [6]:
%ls

[0m[01;34massets[0m/  dataset.py  model.py                   nanoGPT_singe_file.ipynb  trainer.py
[01;34mdata[0m/    LICENSE     nanoGPT_JAX_JAX_JAX.ipynb  README.md


In [15]:
import importlib

import dataset
import model

importlib.reload(dataset)
importlib.reload(model)

<module 'model' from '/content/nanoGPT-JAX-JAX-JAX/model.py'>

In [9]:
BATCH_SIZE = 8
BLOCK_SIZE = 16

In [16]:
poem_dataset = dataset.Dataset()

  data = jnp.array(_encode(text, self.stoi), dtype=jnp.int64)


In [17]:
class TrainState(train_state.TrainState):
  key: jax.random.KeyArray

T = BLOCK_SIZE
random_key = jax.random.PRNGKey(99)
random_key, random_subkey = jax.random.split(random_key)

model = model.LanguageModel(vocab_size=65, n_embed=48, T=BLOCK_SIZE)

# Now, our language model needs to accept a block of tokens, not one-char at a time.
# We'll then make it accept a batch of blocks of tokens using vmap.
sample_block_of_tokens = jnp.ones(shape=(T), dtype=jnp.int32)
output, params = model.init_with_output(jrand.PRNGKey(99), sample_block_of_tokens, training=False)
params = params["params"]


For more information, see https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html
  key: jax.random.KeyArray


In [None]:
def model_apply(params, inputs):
  dropout_key = jax.random.PRNGKey(0) # TODO need to fix this.
  return model.apply({"params": params}, inputs, False, rngs={'dropout': dropout_key})

model_apply_batch = jax.vmap(model_apply, in_axes=(None, 0), out_axes=(0))

def forward_pass(params, state, batch):
  inputs, targets = batch
  logits = state.apply_fn(params, inputs)
  loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
  loss = loss.mean()
  return loss

def train_step(state, batch):
  grad_fn = jax.value_and_grad(forward_pass, argnums=(0))  # differentiate wrt 0th pos argument.
  loss, grads = grad_fn(state.params, state, batch)
  state = state.apply_gradients(grads=grads)
  return state, loss

opt = optax.adam(learning_rate=0.0001)
state = TrainState.create(apply_fn=model_apply_batch, params=params, tx=opt, key=random_key)

In [None]:
for epoch in range(1):
  batch = get_batch()

  random_key, random_subkey = jax.random.split(random_key)
  dropout_key = jax.random.fold_in(key=random_key, data=state.step)

  state, loss = train_step(state, batch)
  print("loss", loss, "epoch", epoch) if epoch%100==0 else None

loss 4.3279257 epoch 0


## pmapping

In [None]:
def model_apply(params, inputs):
  dropout_key = jax.random.PRNGKey(0) # TODO need to fix this.
  return model.apply({"params": params}, inputs, False, rngs={'dropout': dropout_key})

def forward_pass(params, state, batch):
  inputs, targets = batch
  logits = state.apply_fn(params, inputs)
  loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
  loss = loss.mean()
  print("forward pass loss 1", loss)
  loss = jax.lax.pmean(loss, "device")
  print("forward pass loss 2", loss)
  return loss

def train_step(state, batch):
  grad_fn = jax.value_and_grad(forward_pass, argnums=(0))  # differentiate wrt 0th pos argument.
  loss, grads = grad_fn(state.params, state, batch)
  print("loss before mean", loss)

  grads = jax.lax.pmean(grads, "device")
  # loss = jax.lax.pmean(loss, "device")

  print("loss after mean", loss)
  state = state.apply_gradients(grads=grads)
  return state, loss

In [None]:
opt = optax.adam(learning_rate=0.0001)
state = TrainState.create(apply_fn=model_apply_batch, params=params, tx=opt, key=random_key)
states = jax.device_put_replicated(state, jax.local_devices())

In [None]:
with jax.disable_jit():
  model_apply_batch = jax.vmap(model_apply, in_axes=(None, 0), out_axes=(0))

  opt = optax.adam(learning_rate=0.0001)
  state = TrainState.create(apply_fn=model_apply_batch, params=params, tx=opt, key=random_key)
  states = jax.device_put_replicated(state, jax.local_devices())
  train_step_pmap = jax.pmap(train_step, axis_name="device")

  for epoch in range(1):
    inputs, targets = get_batch()
    inputs = jnp.reshape(inputs, [DEVICE_COUNT, -1, inputs.shape[1]])
    targets = jnp.reshape(targets, [DEVICE_COUNT, -1, targets.shape[1]])
    batch = inputs, targets


    states, loss = train_step_pmap(states, batch)
    print("loss", loss, "epoch", epoch) if epoch%100==0 else None

forward pass loss 1 Traced<ShapedArray(float32[])>with<JVPTrace(level=2/1)> with
  primal = Traced<ShapedArray(float32[])>with<MapTrace(level=0/1)> with
    val = ShardedDeviceArray([4.6782055, 4.7953796, 4.7818027, 4.407387 , 4.579782 ,
                    4.4532733, 4.446926 , 5.13021  ], dtype=float32)
    shard_axes = {'device': 0}
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/1)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7dca9521ea00>, in_tracers=(Traced<ShapedArray(float32[]):JaxprTrace(level=1/1)>, Traced<ShapedArray(float32[]):JaxprTrace(level=1/1)>), out_tracer_refs=[<weakref at 0x7dca952bd760; to 'JaxprTracer' at 0x7dca9772b740>], out_avals=[ShapedArray(float32[])], primitive=div, params={}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7dca9782f230>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
forward pass loss 2 Traced<ShapedArray(float32[

In [None]:
def train_step(state, batch):
  random_key, random_subkey = jax.random.split(random_key)
  dropout_key = jax.random.fold_in(random_key, data=state.step)

  grad_fn = jax.value_and_grad(forward_pass, argnums=(0))  # differentiate wrt 0th pos argument.
  loss, grads = grad_fn(state.params, state, batch, dropout_key)

  return state, loss