In [4]:
from typing import List, Dict, Mapping, Tuple

import jax
import jax.numpy as jnp
import jax.random as jrand
import flax
import flax.linen as nn
from flax.training import train_state  # Useful dataclass to keep train state
import optax
import tensorflow as tf
import pdb
import functools

def println(*args):
  for arg in args:
    print(arg)


In [5]:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [6]:
DEVICE_COUNT = len(jax.devices())
DEVICE_COUNT

8

# Dataset

In [7]:
from dataclasses import dataclass
from typing import Dict, List, Mapping, Tuple

import jax
import jax.numpy as jnp
import tensorflow as tf
import requests

# Below would result in a minibatch size of 32.
BATCH_SIZE = 32 # how many independent sequences will we process in parallel?
BLOCK_SIZE = 16 # what is the maximum context length for predictions?

!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Create chars vocubulary using all the unique characters in the text.
chars = sorted(list(set(text)))
VOCAB_SIZE = len(chars)

# Create mapping from characters to integers.
stoi = {ch: i for i, ch in enumerate(chars)}

# Create reverse mapping from integers to characters.
itos = {i: ch for i, ch in enumerate(chars)}

# Create encode, decode function.
def encode(s: str, stoi: Mapping[str, int]) -> List[int]:
  return [stoi[c] for c in s]

def decode(tokens: List[int], itos: Mapping[int, str]) -> str:
  return ''.join([itos[i] for i in tokens])

println(encode("hii there", stoi), decode(encode("hii there", stoi), itos))

# Let's now split up the data into train and validation sets.
data = jnp.array(encode(text, stoi), dtype=jnp.int64)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]


def create_dataset(training: bool = True):
  data = train_data if training else val_data
  dataset = (tf.data.Dataset.from_tensor_slices(data)
                .batch(BLOCK_SIZE+1)
                .map(lambda input: (input[:BLOCK_SIZE], input[1:BLOCK_SIZE+1]),
                     num_parallel_calls=tf.data.AUTOTUNE)
                .batch(BATCH_SIZE)
                .repeat()
                .as_numpy_iterator())
  return dataset

def get_batch(dataset):
  batch = next(dataset)
  return jnp.array(batch)

train_dataset = create_dataset(training=True)
val_dataset = create_dataset(training=False)

--2024-04-25 03:35:25--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-04-25 03:35:25 (19.9 MB/s) - ‘input.txt’ saved [1115394/1115394]

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


  data = jnp.array(encode(text, stoi), dtype=jnp.int64)


### test dataset

In [8]:
xb, yb = get_batch(train_dataset)
println("inputs", xb, "inputs shape", xb.shape)
println("targets", yb, "targets shape", yb.shape)
for b in range(BATCH_SIZE): # batch dimension
    for t in range(BLOCK_SIZE): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs
[[18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 14]
 [44 53 56 43  1 61 43  1 54 56 53 41 43 43 42  1]
 [52 63  1 44 59 56 58 46 43 56  6  1 46 43 39 56]
 [51 43  1 57 54 43 39 49  8  0  0 13 50 50 10  0]
 [54 43 39 49  6  1 57 54 43 39 49  8  0  0 18 47]
 [57 58  1 15 47 58 47 64 43 52 10  0 37 53 59  1]
 [56 43  1 39 50 50  1 56 43 57 53 50 60 43 42  1]
 [39 58 46 43 56  1 58 53  1 42 47 43  1 58 46 39]
 [ 1 58 53  1 44 39 51 47 57 46 12  0  0 13 50 50]
 [ 0 30 43 57 53 50 60 43 42  8  1 56 43 57 53 50]
 [43 42  8  0  0 18 47 56 57 58  1 15 47 58 47 64]
 [52 10  0 18 47 56 57 58  6  1 63 53 59  1 49 52]
 [61  1 15 39 47 59 57  1 25 39 56 41 47 59 57  1]
 [57  1 41 46 47 43 44  1 43 52 43 51 63  1 58 53]
 [58 46 43  1 54 43 53 54 50 43  8  0  0 13 50 50]
 [ 0 35 43  1 49 52 53 61  5 58  6  1 61 43  1 49]
 [53 61  5 58  8  0  0 18 47 56 57 58  1 15 47 58]
 [64 43 52 10  0 24 43 58  1 59 57  1 49 47 50 50]
 [46 47 51  6  1 39 52 42  1 61 43  5 50 50  1 46]
 [60 43  1 41 53 56 52  

# Model

In [70]:
from typing import List, Dict, Mapping, Tuple

import jax
import jax.numpy as jnp
import jax.random as jrand
import flax.linen as nn
from flax.training import train_state  # Useful dataclass to keep train state
import optax
import tensorflow as tf
import pdb
import functools

class SingleHeadAttention(nn.Module):
  """Implements a simple SingleHeadAttention layer for a single example from batch."""

  head_size: int
  num_tokens: int
  dropout_rate: float = 0.2
  use_causal_mask: bool = True  # if True, only attend to past tokens.

  @nn.compact
  def __call__(self, tokens: jnp.array, training: bool):
    """Tokens, each with some channel dim."""
    # Use separate single dense layers for calculating keys, query, values
    keys = nn.Dense(self.head_size, use_bias=False)(tokens)
    queries = nn.Dense(self.head_size, use_bias=False)(tokens)
    values = nn.Dense(self.head_size, use_bias=False)(tokens)

    mask = (
        jnp.tril(jnp.ones(shape=(self.num_tokens, self.num_tokens)))
        if self.use_causal_mask
        else jnp.ones(shape=(self.num_tokens, self.num_tokens))
    )

    # compute attention score.
    # import pdb; pdb.set_trace()
    wei = jnp.dot(queries, keys.T) / jnp.sqrt(self.head_size)
    wei = jnp.where(mask == 0, -jnp.inf, wei)
    wei = nn.softmax(wei, axis=-1)

    attention_values = jnp.dot(
        wei, values
    )  # (num_tokens, num_tokens) * (num_tokens, head_size)
    attention_values = nn.Dropout(rate=self.dropout_rate, deterministic=True)(
        attention_values
    )
    return attention_values  # (num_tokens, head_size)


class MultiHeadAttention(nn.Module):
  """Implements a MultiHeadAttention layer for a single example from batch."""

  head_size: int
  num_heads: int
  num_tokens: int
  dropout_rate: float = 0.2

  def setup(self):
    self.heads = [
        SingleHeadAttention(
            head_size=self.head_size, num_tokens=self.num_tokens
        )
        for _ in range(self.num_heads)
    ]

    # Project concatenated output from all attention heads to final output
    # dimension, which is head_size * num_heads.
    self.projection = nn.Dense(features=self.num_heads * self.head_size)
    self.dropout = nn.Dropout(rate=self.dropout_rate, deterministic=True)

  def __call__(self, tokens: jnp.array, training: bool):
    output_from_each_head = []
    for h in self.heads:
      head_output = h(tokens, training)
      output_from_each_head.append(head_output)

    # Run multiple attention heads in parallel and concatenate
    # their output along channel dimension, i.e., dim==-1
    out_from_all_heads = jnp.concatenate(output_from_each_head, axis=-1)

    projection = self.projection(out_from_all_heads)
    return self.dropout(projection)

class FeedForward(nn.Module):
  output_size: int

  def setup(self):
    # Attention paper uses 4 times token_info_size when doing linear transformation
    # and then projects it back to token_info_size in linear transformation layer.
    self.ffwd = nn.Dense(features=4 * self.output_size)
    self.projection = nn.Dense(self.output_size)

  def __call__(self, x, training: bool):
    x = nn.relu(self.ffwd(x))
    x = self.projection(x)
    return x

class TransformerEncoderBlock(nn.Module):
  num_heads: int
  # output_size = head_size * num_heads, is the final embedding dimension you get after concatenating from all heads.
  output_size: int
  num_tokens: int

  def setup(self):
    # communication.
    # each single head will produce head_size worth of info for key, value, querie. You concatenate all of them to get the final output_size.
    self.head_size = self.output_size // self.num_heads
    self.self_attention_heads = MultiHeadAttention(num_heads=self.num_heads,
                                                   head_size = self.head_size,
                                                   num_tokens=self.num_tokens)

    # computation.
    self.computation_layer = FeedForward(output_size=self.output_size)

    self.ln1 = nn.LayerNorm()
    self.ln2 = nn.LayerNorm()

    self.dropout = nn.Dropout(rate=0.2, deterministic=True)

  def __call__(self, x, training: bool):
    # transformer encoder forward pass
    x = x + self.self_attention_heads(self.ln1(x), training)

    x = x + self.computation_layer(self.ln2(x), training)

    x = self.dropout(x)
    return x

class LanguageModel(nn.Module):
  """Reads one char and predicits the next char."""
  vocab_size: int # number of vocabulary (number of rows of embedding table)
  n_embed: int # embedding dim after lookup
  num_tokens: int # block size, i.e., number of tokens attention block is looking at once
  num_heads: int
  num_layers: int

  def setup(self):
    # number of channels you want to use for store info for each token.
    self.C = self.vocab_size

    self.token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.n_embed)

    self.pos_embedding_table = nn.Embed(num_embeddings=self.num_tokens, features=self.n_embed)

    # Since, there are 4 heads, each head only needs to output token_info of size 8.
    # Concantenate token_info from all 4 heards, gives us 32
    self.blocks = [
        TransformerEncoderBlock(num_heads=self.num_heads,
                                output_size=self.n_embed,
                                num_tokens=self.num_tokens) for _ in range(self.num_layers)
    ]
    self.ln = nn.LayerNorm()
    self.lang_model_head = nn.Dense(features=self.C)

  def __call__(self, block_of_tokens: jnp.array, training: bool):
    """Accepts a block of tokens, like [0, 1, 2, 3, 4, 5, 6, 7]."""
    # generate emb for each token. output: (num_tokens, n_embed)
    token_embs = self.token_embedding_table(block_of_tokens)

    # generate position embs for each token.
    num_pos = block_of_tokens.shape[0]
    positions = jnp.arange(0, num_pos)
    pos_embs = self.pos_embedding_table(positions)

    # generate actual input to attention, x, which is sum of token_embs + pos_embs
    x = token_embs + pos_embs

    # feed x into self-attention head.
    # language model, forward pass, block_of_tokens
    for i in range(self.num_layers):
      x = self.blocks[i](x, training)

    x = self.ln(x)

    # generate logits for each token. output: (T, channels for info -- C)
    token_logits = self.lang_model_head(x)

    return token_logits


class LanguageModelBatch(nn.Module):
  """Extends MultiHeadAttention to work on a batch of data."""
  vocab_size: int # number of vocabulary (number of rows of embedding table)
  n_embed: int # embedding dim after lookup
  num_tokens: int # block size, i.e., number of tokens attention block is looking at once
  num_heads: int
  num_layers: int

  def setup(self):
    self.lm_single_example = LanguageModel(
        vocab_size=self.vocab_size,
        n_embed=self.n_embed,
        num_tokens=self.num_tokens,
        num_heads=self.num_heads,
        num_layers=self.num_layers)

    self.lm_batch = jax.vmap(
        self.lm_single_example,
        in_axes=(0, None),  # tokens, training
        out_axes=(0),
    )

  def __call__(self, tokens: jnp.array, training: bool):
    return self.lm_batch(tokens, training)

In [71]:
model = LanguageModelBatch(vocab_size=65,
                      n_embed=256,
                      num_tokens=BLOCK_SIZE,
                      num_heads=2,
                      num_layers=1)

In [84]:
inputs, targets = get_batch(train_dataset)
inputs.shape, targets.shape

((32, 16), (32, 16))

In [76]:
output, params = model.init_with_output(jax.random.PRNGKey(99), inputs, training=False)
params = params["params"]

output.shape

(32, 16, 65)

In [85]:
inputs.shape

(32, 16)

In [89]:
# class TrainState(train_state.TrainState):


T = BLOCK_SIZE

model = LanguageModelBatch(vocab_size=65,
                      n_embed=256,
                      num_tokens=BLOCK_SIZE,
                      num_heads=2,
                      num_layers=1)


output, params = model.init_with_output(jrand.PRNGKey(99), inputs, training=False)
params = params["params"]


def forward_pass(params, state, batch):
  inputs, targets = batch
  logits = state.apply_fn(params, inputs, False)
  loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
  loss = loss.mean()
  return loss
grad_fn = jax.value_and_grad(forward_pass, argnums=(0))  # differentiate wrt 0th pos argument.

opt = optax.adam(learning_rate=0.0001)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=opt)

for epoch in range(1):
  batch = get_batch(train_dataset)

  loss, grads = grad_fn(state.params, state, batch)
  print("loss", loss, "epoch", epoch) if epoch%100==0 else None
  state = state.apply_gradients(grads=grads)


ScopeCollectionNotFound: Tried to access "embedding" from collection "params" in "/lm_single_example/token_embedding_table" but the collection is empty. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeCollectionNotFound)

In [82]:
from typing import Tuple

import chex
from chex._src import fake
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state
from dataclasses import dataclass
import importlib
import pdb

# import dataset
# import model

class TrainState(train_state.TrainState):
    key: jax.random.KeyArray

@dataclass
class Config:
    BATCH_SIZE: int = 256
    BLOCK_SIZE: int = 64
    T: int = 64
    n_embed: int = 256
    num_heads: int = 8
    num_layers: int = 6

config = Config()

random_key = jax.random.PRNGKey(99)

# Initialize model
lm_model = LanguageModelBatch(vocab_size=65,
                      n_embed=config.n_embed,
                      num_tokens=config.BLOCK_SIZE,
                      num_heads=config.num_heads,
                      num_layers=config.num_layers)

output, params = lm_model.init_with_output(jax.random.PRNGKey(99), inputs, training=True)
params = params["params"]

PER_HOST_BATCH_SIZE = config.BATCH_SIZE // jax.device_count()

# Define forward pass
def forward_pass(params, state, batch): #, dropout_key):
    inputs, targets = batch
    logits = state.apply_fn(params, inputs, True)# , dropout_key)

    chex.assert_shape(inputs, (PER_HOST_BATCH_SIZE, config.BLOCK_SIZE))
    chex.assert_shape(targets, (PER_HOST_BATCH_SIZE, config.BLOCK_SIZE))

    loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
    loss = loss.mean()
    return loss

# Define training step
def train_step(state, inputs, targets): #, dropout_key):
    # dropout_key = jax.random.fold_in(key=dropout_key, data=state.step)
    batch = inputs, targets

    grad_fn = jax.value_and_grad(forward_pass, argnums=(0))
    loss, grads = grad_fn(state.params, state, batch) # , dropout_key)

    loss = jax.lax.pmean(loss, axis_name="devices")
    grads = jax.lax.pmean(grads, axis_name="devices")

    state = state.apply_gradients(grads=grads)
    return state, loss

# Initialize optimizer and training state
opt = optax.adam(learning_rate=0.0001)
state = TrainState.create(apply_fn=lm_model.apply, params=params, tx=opt, key=random_key)

# pmap the train_step.
train_step_pmap = jax.jit(jax.pmap(train_step, in_axes=(0, 0, 0), out_axes=(0), axis_name="devices"))
states = jax.device_put_replicated(state, jax.local_devices())

ValueError: Incompatible shapes for broadcasting: shapes=[(4, 16, 256), (4, 256)]

In [81]:
# fake_pmap = chex.fake_pmap_and_jit(enable_jit_patching=fake_jit, enable_pmap_patching=fake_pmap)
# fake_pmap.start()
num_epochs = 1 # 20
steps_per_epoch = 1 # len(data.train_data) // config.BATCH_SIZE
for epoch in range(num_epochs):
  print("epoch: ", epoch)

  for step in range(steps_per_epoch):
    random_key, random_subkey = jax.random.split(random_key)

    inputs, targets = get_batch(train_dataset)

    # create device dimension for minibatch
    inputs = inputs.reshape((jax.device_count(), -1, inputs.shape[-1]))
    targets = targets.reshape((jax.device_count(), -1, targets.shape[-1]))

    states, loss = train_step_pmap(states, inputs, targets)
    print("loss", loss[0], "epoch", epoch) if epoch % 1 == 0 else None

# fake_pmap.stop()

epoch:  0


NameError: name 'train_step_pmap' is not defined

In [None]:
inputs, targets = get_batch(train_dataset)

In [None]:
inputs.shape, targets.shape

((32, 16), (32, 16))

In [None]:
# create device dimension for minibatch
inputs = inputs.reshape((jax.device_count(), -1, inputs.shape[-1]))
targets = targets.reshape((jax.device_count(), -1, targets.shape[-1]))# create device dimension for minibatch

In [None]:
inputs.shape, targets.shape

((8, 4, 16), (8, 4, 16))