In [2]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-03-13 18:33:37--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-03-13 18:33:37 (20.7 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [3]:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [4]:
DEVICE_COUNT = len(jax.devices())
DEVICE_COUNT

8

In [5]:
from typing import List, Dict, Mapping, Tuple

import jax
import jax.numpy as jnp
import jax.random as jrand
import flax.linen as nn
from flax.training import train_state  # Useful dataclass to keep train state
import optax
import tensorflow as tf
import pdb
import functools

def println(*args):
  for arg in args:
    print(arg)


In [6]:
# Below would result in a minibatch size of 32.
BATCH_SIZE = 8 # how many independent sequences will we process in parallel?
BLOCK_SIZE = 16 # what is the maximum context length for predictions?

In [7]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Create chars vocubulary using all the unique characters in the text.
chars = sorted(list(set(text)))
VOCAB_SIZE = len(chars)

# Create mapping from characters to integers.
stoi = {ch: i for i, ch in enumerate(chars)}

# Create reverse mapping from integers to characters.
itos = {i: ch for i, ch in enumerate(chars)}

# Create encode, decode function.
def encode(s: str, stoi: Mapping[str, int]) -> List[int]:
  return [stoi[c] for c in s]

def decode(tokens: List[int], itos: Mapping[int, str]) -> str:
  return ''.join([itos[i] for i in tokens])

println(encode("hii there", stoi), decode(encode("hii there", stoi), itos))

# Let's now split up the data into train and validation sets.
data = jnp.array(encode(text, stoi), dtype=jnp.int64)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

train_dataset = (tf.data.Dataset.from_tensor_slices(train_data)
                .batch(BLOCK_SIZE+1)
                .map(lambda input: (input[:BLOCK_SIZE], input[1:BLOCK_SIZE+1]),
                     num_parallel_calls=tf.data.AUTOTUNE)
                .batch(BATCH_SIZE)
                .repeat()
                .as_numpy_iterator())
val_dataset = (tf.data.Dataset.from_tensor_slices(val_data)
                .batch(BLOCK_SIZE+1)
                .map(lambda input: (input[:BLOCK_SIZE], input[1:BLOCK_SIZE+1]),
                     num_parallel_calls=tf.data.AUTOTUNE)
                .batch(BATCH_SIZE)
                .repeat()
                .as_numpy_iterator())

def get_batch(training: bool = True):
  if not training:
    val_batch = next(val_dataset)
    return jnp.array(val_batch)

  train_batch = next(train_dataset)
  return jnp.array(train_batch)

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


  data = jnp.array(encode(text, stoi), dtype=jnp.int64)


In [14]:
class SingleExampleMultiHeadSelfAttention(nn.Module):
    num_heads: int
    projection_dim: int

    @nn.compact
    def __call__(self, x):
        # x shape: (seq_len, projection_dim)
        query = nn.Dense(self.projection_dim)(x)
        key = nn.Dense(self.projection_dim)(x)
        value = nn.Dense(self.projection_dim)(x)

        query = query.reshape((self.num_heads, self.projection_dim // self.num_heads))
        key = key.reshape((self.num_heads, self.projection_dim // self.num_heads))
        value = value.reshape((self.num_heads, self.projection_dim // self.num_heads))

        attention_scores = jnp.einsum('hd,Hd->hH', query, key) / jnp.sqrt(self.projection_dim)
        attention_weights = nn.softmax(attention_scores, axis=-1)
        attention_output = jnp.einsum('hH,Hd->hd', attention_weights, value)
        attention_output = attention_output.reshape((self.num_heads * (self.projection_dim // self.num_heads)))

        output = nn.Dense(self.projection_dim)(attention_output)
        return output

# Vectorize the single-example module to handle batches
# batch_multi_head_self_attention = jax.vmap(SingleExampleMultiHeadSelfAttention(num_heads=8, projection_dim=64), in_axes=0, out_axes=0)

In [24]:
class SingleHeadAttention(nn.Module):
  head_size: int # token_info_size; how much (emb dim) info each token emits for keys, queries, values.
  T: int # block size; number of tokens in a block

  def setup(self):
    # key, query will take vector of size C.
    # i.e., channels containing info of token and will output head_size
    self.key_layer = nn.Dense(self.head_size, use_bias=False)
    self.query_layer = nn.Dense(self.head_size, use_bias=False)
    self.value_layer = nn.Dense(self.head_size, use_bias=False)

    self.dropout = nn.Dropout(rate=0.2)



  def __call__(self, block_of_tokens_with_info_channels: jnp.array, training: bool):
    """Accepts a block of tokens with info channels, like (8, 65)."""
    # TODO(ntnsonti): Double check; but tril should not be learnable according cGPT.
    tril = jnp.tril(jnp.ones(shape=(self.T, self.T)))

    # input: (T, info channels)
    # output: (T, head_size)
    keys = self.key_layer(block_of_tokens_with_info_channels)
    queries = self.query_layer(block_of_tokens_with_info_channels)
    values = self.value_layer(block_of_tokens_with_info_channels)

    # chanel info size
    C = int(block_of_tokens_with_info_channels.shape[-1])

    # compute attention score.
    wei = jnp.dot(queries, keys.T) * C**0.5 # (T, head_size) * (head_size, T) == (T, T)
    wei = jnp.where(tril==0, -jnp.inf, wei)
    wei = nn.softmax(wei, axis=-1)

    attention_values = jnp.dot(wei, values) # (T, T) * (T, head_size))

    attention_values = self.dropout(attention_values, deterministic=not training)

    return attention_values # (T, head_size)


class MultiHeadAttention(nn.Module):
  num_heads: int
  head_size: int # head_size * num_heads is the final embedding dimension you get, after concatenating from all heads
  T: int

  def setup(self):
    self.heads = [
        SingleHeadAttention(head_size=self.head_size, T=self.T) for _ in range(self.num_heads)
    ]

    final_output_size = self.num_heads * self.head_size
    self.projection = nn.Dense(features=final_output_size)

    self.dropout = nn.Dropout(rate=0.2)

  def __call__(self, block_of_tokens_with_info_channels: jnp.array, training: bool):
    out_from_each_head = jnp.array([h(block_of_tokens_with_info_channels, training) for h in self.heads])

    # You just run multiple attention heads in parallel and concatenate
    # their output along channel dimension, i.e., dim==-1
    out_from_all_heads = jnp.concatenate(out_from_each_head, axis=-1)
    # print("[ntn99] out_from_all_heads concatenated shape: ", out_from_all_heads.shape)

    projection =  self.projection(out_from_all_heads)

    return self.dropout(projection, deterministic=not training)

class FeedForward(nn.Module):
  output_size: int

  def setup(self):
    # Attention paper uses 4 times token_info_size when doing linear transformation
    # and then projects it back to token_info_size in linear transformation layer.
    self.ffwd = nn.Dense(features=4 * self.output_size)
    self.projection = nn.Dense(self.output_size)

  def __call__(self, x, training: bool):
    x = nn.relu(self.ffwd(x))
    x = self.projection(x)
    return x

class TransformerEncoderBlock(nn.Module):
  num_heads: int
  output_size: int # final_token_info_size; head_size * num_heads is the final embedding dimension you get, after concatenating from all heads
  T: int

  def setup(self):
    # communication.
    self.head_size = self.output_size // self.num_heads  # each SingleAttentionHead will produce head_size worth of info for key, value, querie. You concatenate all of them to get the final n_embed.
    self.self_attention_heads = MultiHeadAttention(num_heads=self.num_heads,
                                                   head_size = self.head_size,
                                                   T=self.T)

    # computation.
    self.computation_layer = FeedForward(output_size=self.output_size)

    self.ln1 = nn.LayerNorm()
    self.ln2 = nn.LayerNorm()

    self.dropout = nn.Dropout(rate=0.2)

  def __call__(self, x, training: bool):
    x = x + self.self_attention_heads(self.ln1(x), training)

    x = x + self.computation_layer(self.ln2(x), training)

    x = self.dropout(x, deterministic=not training)
    return x

In [25]:
class LanguageModel(nn.Module):
  """Reads one char and predicits the next char."""
  vocab_size: int # number of vocabulary (number of rows of embedding table)
  n_embed: int # embedding dim after lookup
  T: int # block size, i.e., number of tokens attention block is looking at once

  def setup(self):
    # number of channels you want to use for store info for each token.
    self.C = self.vocab_size

    self.token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.n_embed)

    self.pos_embedding_table = nn.Embed(num_embeddings=self.T, features=self.n_embed)

    # Since, there are 4 heads, each head only needs to output token_info of size 8.
    # Concantenate token_info from all 4 heards, gives us 32
    self.num_blocks = 4
    self.blocks = [
        TransformerEncoderBlock(num_heads=4,
                                output_size=self.n_embed,
                                T=self.T) for _ in range(self.num_blocks)
    ]
    self.ln = nn.LayerNorm()
    self.lang_model_head = nn.Dense(features=self.C)

  def __call__(self, block_of_tokens: jnp.array, training: bool):
    """Accepts a block of tokens, like [0, 1, 2, 3, 4, 5, 6, 7]."""
    # generate emb for each token. output: (T, n_embed)
    token_embs = self.token_embedding_table(block_of_tokens)

    # generate position embs for each token.
    # get token positions.

    # num_pos = block_of_tokens.shape[0]
    num_pos = T
    positions = jnp.arange(0, num_pos)
    pos_embs = self.pos_embedding_table(positions)

    # generate actual input to attention, x, which is sum of token_embs + pos_embs
    x = token_embs + pos_embs

    # feed x into self-attention head.
    # x = self.self_attention_heads(x)
    # x = self.blocks(x)(training)
    for i in range(self.num_blocks):
      x = self.blocks[i](x, training)

    x = self.ln(x)

    # generate logits for each token. output: (T, channels for info -- C)
    token_logits = self.lang_model_head(x)

    return token_logits


In [26]:
class TrainState(train_state.TrainState):
  key: jax.random.KeyArray

T = BLOCK_SIZE
random_key = jax.random.PRNGKey(99)
random_key, random_subkey = jax.random.split(random_key)

model = LanguageModel(vocab_size=65, n_embed=32, T=BLOCK_SIZE)

# Now, our language model needs to accept a block of tokens, not one-char at a time.
# We'll then make it accept a batch of blocks of tokens using vmap.
sample_block_of_tokens = jnp.ones(shape=(T), dtype=jnp.int32)
output, params = model.init_with_output(jrand.PRNGKey(99), sample_block_of_tokens, training=False)
params = params["params"]


In [12]:
def model_apply(params, inputs):
  dropout_key = jax.random.PRNGKey(0) # TODO need to fix this.
  return model.apply({"params": params}, inputs, False, rngs={'dropout': dropout_key})

model_apply_batch = jax.vmap(model_apply, in_axes=(None, 0), out_axes=(0))

def forward_pass(params, state, batch):
  inputs, targets = batch
  logits = state.apply_fn(params, inputs)
  loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
  loss = loss.mean()
  return loss

def train_step(state, batch):
  grad_fn = jax.value_and_grad(forward_pass, argnums=(0))  # differentiate wrt 0th pos argument.
  loss, grads = grad_fn(state.params, state, batch)
  state = state.apply_gradients(grads=grads)
  return state, loss

opt = optax.adam(learning_rate=0.0001)
state = TrainState.create(apply_fn=model_apply_batch, params=params, tx=opt, key=random_key)

In [13]:
for epoch in range(1):
  batch = get_batch()

  random_key, random_subkey = jax.random.split(random_key)
  dropout_key = jax.random.fold_in(key=random_key, data=state.step)

  state, loss = train_step(state, batch)
  print("loss", loss, "epoch", epoch) if epoch%100==0 else None

loss 4.602925 epoch 0


## pmapping

In [None]:
def model_apply(params, inputs):
  dropout_key = jax.random.PRNGKey(0) # TODO need to fix this.
  return model.apply({"params": params}, inputs, False, rngs={'dropout': dropout_key})

def forward_pass(params, state, batch):
  inputs, targets = batch
  logits = state.apply_fn(params, inputs)
  loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
  loss = loss.mean()
  print("forward pass loss 1", loss)
  loss = jax.lax.pmean(loss, "device")
  print("forward pass loss 2", loss)
  return loss

def train_step(state, batch):
  grad_fn = jax.value_and_grad(forward_pass, argnums=(0))  # differentiate wrt 0th pos argument.
  loss, grads = grad_fn(state.params, state, batch)
  print("loss before mean", loss)

  grads = jax.lax.pmean(grads, "device")
  # loss = jax.lax.pmean(loss, "device")

  print("loss after mean", loss)
  state = state.apply_gradients(grads=grads)
  return state, loss

In [None]:
opt = optax.adam(learning_rate=0.0001)
state = TrainState.create(apply_fn=model_apply_batch, params=params, tx=opt, key=random_key)
states = jax.device_put_replicated(state, jax.local_devices())

In [None]:
with jax.disable_jit():
  model_apply_batch = jax.vmap(model_apply, in_axes=(None, 0), out_axes=(0))

  opt = optax.adam(learning_rate=0.0001)
  state = TrainState.create(apply_fn=model_apply_batch, params=params, tx=opt, key=random_key)
  states = jax.device_put_replicated(state, jax.local_devices())
  train_step_pmap = jax.pmap(train_step, axis_name="device")

  for epoch in range(1):
    inputs, targets = get_batch()
    inputs = jnp.reshape(inputs, [DEVICE_COUNT, -1, inputs.shape[1]])
    targets = jnp.reshape(targets, [DEVICE_COUNT, -1, targets.shape[1]])
    batch = inputs, targets


    states, loss = train_step_pmap(states, batch)
    print("loss", loss, "epoch", epoch) if epoch%100==0 else None

forward pass loss 1 Traced<ShapedArray(float32[])>with<JVPTrace(level=2/1)> with
  primal = Traced<ShapedArray(float32[])>with<MapTrace(level=0/1)> with
    val = ShardedDeviceArray([4.4624095, 4.098372 , 4.7049522, 4.719946 , 4.3354793,
                    4.6712008, 4.5151587, 4.525096 ], dtype=float32)
    shard_axes = {'device': 0}
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/1)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7842309a7cb0>, in_tracers=(Traced<ShapedArray(float32[]):JaxprTrace(level=1/1)>, Traced<ShapedArray(float32[]):JaxprTrace(level=1/1)>), out_tracer_refs=[<weakref at 0x784230494a90; to 'JaxprTracer' at 0x784230496930>], out_avals=[ShapedArray(float32[])], primitive=div, params={}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7842306662f0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
forward pass loss 2 Traced<ShapedArray(float32[

In [None]:
def train_step(state, batch):
  random_key, random_subkey = jax.random.split(random_key)
  dropout_key = jax.random.fold_in(random_key, data=state.step)

  grad_fn = jax.value_and_grad(forward_pass, argnums=(0))  # differentiate wrt 0th pos argument.
  loss, grads = grad_fn(state.params, state, batch, dropout_key)

  return state, loss