In [1]:
from typing import List, Dict, Mapping, Tuple

import jax
import jax.numpy as jnp
import jax.random as jrand
import flax
import flax.linen as nn
from flax.training import train_state  # Useful dataclass to keep train state
import optax
import tensorflow as tf
import pdb
import functools

def println(*args):
  for arg in args:
    print(arg)


In [2]:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [3]:
DEVICE_COUNT = len(jax.devices())
DEVICE_COUNT

8

In [7]:
import dataclasses

@dataclasses.dataclass
class Config:
    vocab_size: int = 66
    batch_size: int = 512
    block_size: int = 64
    n_embed: int = 256
    num_heads: int = 8
    num_layers: int = 6

config = Config()

# Dataset

In [29]:
from dataclasses import dataclass
from typing import Dict, List, Mapping, Tuple

import jax
import jax.numpy as jnp
import tensorflow as tf
import requests

# Below would result in a minibatch size of 32.
BATCH_SIZE = config.batch_size # how many independent sequences will we process in parallel?
BLOCK_SIZE = config.block_size # what is the maximum context length for predictions?

!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Create chars vocubulary using all the unique characters in the text.
chars = sorted(list(set(text)))
VOCAB_SIZE = len(chars)

# Create mapping from characters to integers.
stoi = {ch: i for i, ch in enumerate(chars)}

# Create reverse mapping from integers to characters.
itos = {i: ch for i, ch in enumerate(chars)}

# Create encode, decode function.
def encode(s: str, stoi: Mapping[str, int]) -> List[int]:
  return [stoi[c] for c in s]

def decode(tokens: List[int], itos: Mapping[int, str]) -> str:
  return ''.join([itos[i] for i in tokens])

println(encode("hii there", stoi), decode(encode("hii there", stoi), itos))

# Let's now split up the data into train and validation sets.
data = jnp.array(encode(text, stoi), dtype=jnp.int64)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]


def create_dataset(training: bool = True):
  data = train_data if training else val_data
  dataset = (tf.data.Dataset.from_tensor_slices(data)
                .repeat()
                .batch(BLOCK_SIZE+1)
                .map(lambda input: (input[:BLOCK_SIZE], input[1:BLOCK_SIZE+1]),
                     num_parallel_calls=tf.data.AUTOTUNE)
                .batch(BATCH_SIZE)
                .as_numpy_iterator())
  return dataset

def get_batch(dataset):
  batch = next(dataset)
  return jnp.array(batch)

train_dataset = create_dataset(training=True)
val_dataset = create_dataset(training=False)

--2024-04-26 04:11:07--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.4’


2024-04-26 04:11:07 (20.0 MB/s) - ‘input.txt.4’ saved [1115394/1115394]

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


  data = jnp.array(encode(text, stoi), dtype=jnp.int64)


### test dataset

In [9]:
xb, yb = get_batch(train_dataset)
println("inputs", xb, "inputs shape", xb.shape)
println("targets", yb, "targets shape", yb.shape)
# for b in range(BATCH_SIZE): # batch dimension
#     for t in range(BLOCK_SIZE): # time dimension
#         context = xb[b, :t+1]
#         target = yb[b,t]
#         print(f"when input is {context.tolist()} the target: {target}")

inputs
[[18 47 56 ...  0 13 50]
 [10  0 31 ... 53  1 42]
 [43  1 58 ... 47 56 57]
 ...
 [43 11  1 ...  1 58 46]
 [52  1 13 ...  1 21  1]
 [43 55 59 ... 43 45 45]]
inputs shape
(512, 64)
targets
[[47 56 57 ... 13 50 50]
 [ 0 31 54 ...  1 42 47]
 [ 1 58 46 ... 56 57 58]
 ...
 [11  1 46 ... 58 46 43]
 [ 1 13 59 ... 21  1 56]
 [55 59 43 ... 45 45  5]]
targets shape
(512, 64)


# Model

In [10]:
from typing import List, Dict, Mapping, Tuple

import jax
import jax.numpy as jnp
import jax.random as jrand
import flax.linen as nn
from flax.training import train_state  # Useful dataclass to keep train state
import optax
import tensorflow as tf
import pdb
import functools

class SingleHeadAttention(nn.Module):
  """Implements a simple SingleHeadAttention layer for a single example from batch."""

  head_size: int
  num_tokens: int
  dropout_rate: float = 0.2
  use_causal_mask: bool = True  # if True, only attend to past tokens.

  @nn.compact
  def __call__(self, tokens: jnp.array, training: bool):
    """Tokens, each with some channel dim."""
    # Use separate single dense layers for calculating keys, query, values
    keys = nn.Dense(self.head_size, use_bias=False)(tokens)
    queries = nn.Dense(self.head_size, use_bias=False)(tokens)
    values = nn.Dense(self.head_size, use_bias=False)(tokens)

    mask = (
        jnp.tril(jnp.ones(shape=(self.num_tokens, self.num_tokens)))
        if self.use_causal_mask
        else jnp.ones(shape=(self.num_tokens, self.num_tokens))
    )

    # compute attention score.
    # import pdb; pdb.set_trace()
    wei = jnp.dot(queries, keys.T) / jnp.sqrt(self.head_size)
    wei = jnp.where(mask == 0, -jnp.inf, wei)
    wei = nn.softmax(wei, axis=-1)

    attention_values = jnp.dot(
        wei, values
    )  # (num_tokens, num_tokens) * (num_tokens, head_size)
    attention_values = nn.Dropout(rate=self.dropout_rate, deterministic=True)(
        attention_values
    )
    return attention_values  # (num_tokens, head_size)


class MultiHeadAttention(nn.Module):
  """Implements a MultiHeadAttention layer for a single example from batch."""

  head_size: int
  num_heads: int
  num_tokens: int
  dropout_rate: float = 0.2

  def setup(self):
    self.heads = [
        SingleHeadAttention(
            head_size=self.head_size, num_tokens=self.num_tokens
        )
        for _ in range(self.num_heads)
    ]

    # Project concatenated output from all attention heads to final output
    # dimension, which is head_size * num_heads.
    self.projection = nn.Dense(features=self.num_heads * self.head_size)
    self.dropout = nn.Dropout(rate=self.dropout_rate, deterministic=True)

  def __call__(self, tokens: jnp.array, training: bool):
    output_from_each_head = []
    for h in self.heads:
      head_output = h(tokens, training)
      output_from_each_head.append(head_output)

    # Run multiple attention heads in parallel and concatenate
    # their output along channel dimension, i.e., dim==-1
    out_from_all_heads = jnp.concatenate(output_from_each_head, axis=-1)

    projection = self.projection(out_from_all_heads)
    return self.dropout(projection)

class FeedForward(nn.Module):
  output_size: int

  def setup(self):
    # Attention paper uses 4 times token_info_size when doing linear transformation
    # and then projects it back to token_info_size in linear transformation layer.
    self.ffwd = nn.Dense(features=4 * self.output_size)
    self.projection = nn.Dense(self.output_size)

  def __call__(self, x, training: bool):
    x = nn.relu(self.ffwd(x))
    x = self.projection(x)
    return x

class TransformerEncoderBlock(nn.Module):
  num_heads: int
  # output_size = head_size * num_heads, is the final embedding dimension you get after concatenating from all heads.
  output_size: int
  num_tokens: int

  def setup(self):
    # communication.
    # each single head will produce head_size worth of info for key, value, querie. You concatenate all of them to get the final output_size.
    self.head_size = self.output_size // self.num_heads
    self.self_attention_heads = MultiHeadAttention(num_heads=self.num_heads,
                                                   head_size = self.head_size,
                                                   num_tokens=self.num_tokens)

    # computation.
    self.computation_layer = FeedForward(output_size=self.output_size)

    self.ln1 = nn.LayerNorm()
    self.ln2 = nn.LayerNorm()

    self.dropout = nn.Dropout(rate=0.2, deterministic=True)

  def __call__(self, x, training: bool):
    # transformer encoder forward pass
    x = x + self.self_attention_heads(self.ln1(x), training)

    x = x + self.computation_layer(self.ln2(x), training)

    x = self.dropout(x)
    return x

class LanguageModel(nn.Module):
  """Reads one char and predicits the next char."""
  vocab_size: int # number of vocabulary (number of rows of embedding table)
  n_embed: int # embedding dim after lookup
  num_tokens: int # block size, i.e., number of tokens attention block is looking at once
  num_heads: int
  num_layers: int

  def setup(self):
    # number of channels you want to use for store info for each token.
    self.C = self.vocab_size

    self.token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.n_embed)

    self.pos_embedding_table = nn.Embed(num_embeddings=self.num_tokens, features=self.n_embed)

    # Since, there are 4 heads, each head only needs to output token_info of size 8.
    # Concantenate token_info from all 4 heards, gives us 32
    self.blocks = [
        TransformerEncoderBlock(num_heads=self.num_heads,
                                output_size=self.n_embed,
                                num_tokens=self.num_tokens) for _ in range(self.num_layers)
    ]
    self.ln = nn.LayerNorm()
    self.lang_model_head = nn.Dense(features=self.C)

  def __call__(self, block_of_tokens: jnp.array, training: bool):
    """Accepts a block of tokens, like [0, 1, 2, 3, 4, 5, 6, 7]."""
    # generate emb for each token. output: (num_tokens, n_embed)
    token_embs = self.token_embedding_table(block_of_tokens)

    # generate position embs for each token.
    num_pos = block_of_tokens.shape[0]
    positions = jnp.arange(0, num_pos)
    pos_embs = self.pos_embedding_table(positions)

    # generate actual input to attention, x, which is sum of token_embs + pos_embs
    x = token_embs + pos_embs

    # feed x into self-attention head.
    # language model, forward pass, block_of_tokens
    for i in range(self.num_layers):
      x = self.blocks[i](x, training)

    x = self.ln(x)

    # generate logits for each token. output: (T, channels for info -- C)
    token_logits = self.lang_model_head(x)

    return token_logits


class LanguageModelBatch(nn.Module):
  """Extends MultiHeadAttention to work on a batch of data."""
  vocab_size: int # number of vocabulary (number of rows of embedding table)
  n_embed: int # embedding dim after lookup
  num_tokens: int # block size, i.e., number of tokens attention block is looking at once
  num_heads: int
  num_layers: int

  @nn.compact
  def __call__(self, tokens: jnp.array, training: bool):
    return jax.vmap(LanguageModel(
        vocab_size=self.vocab_size,
        n_embed=self.n_embed,
        num_tokens=self.num_tokens,
        num_heads=self.num_heads,
        num_layers=self.num_layers),
                    in_axes=(0, None),  # tokens, training
                    out_axes=(0),)(tokens, training)

In [11]:
def forward_pass(params, state, batch):
  inputs, targets = batch
  logits = state.apply_fn({"params": params}, inputs, False)
  loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
  loss = loss.mean()
  return loss

In [12]:
def backward_pass(state, batch):
  grad_fn = jax.value_and_grad(forward_pass, argnums=(0))
  loss, grads = grad_fn(state.params, state, batch)

  state = state.apply_gradients(grads=grads)
  return state, loss

In [13]:
def backward_pass_pmap(state, batch):
  grad_fn = jax.value_and_grad(forward_pass, argnums=(0))
  loss, grads = grad_fn(state.params, state, batch)

  loss = jax.lax.pmean(loss, axis_name="devices")
  grads = jax.lax.pmean(grads, axis_name="devices")

  state = state.apply_gradients(grads=grads)
  return state, loss

In [14]:
def train_step(state, batch):
  state, loss = backward_pass(state, batch)
  return state, loss

train_step_pmap = jax.pmap(
    jax.jit(train_step), in_axes=(0, 0), out_axes=(0), axis_name="devices")

In [30]:
def get_batch_pmap(dataset):
  inputs, targets = get_batch(dataset)
  inputs = inputs.reshape((jax.device_count(), -1, inputs.shape[-1]))
  targets = targets.reshape((jax.device_count(), -1, targets.shape[-1]))
  return inputs, targets


In [16]:
model = LanguageModelBatch(vocab_size=config.vocab_size,
                      n_embed=config.n_embed,
                      num_tokens=config.block_size,
                      num_heads=config.num_heads,
                      num_layers=config.num_layers)

In [17]:
inputs, targets = get_batch(train_dataset)
inputs.shape, targets.shape

((512, 64), (512, 64))

In [31]:
inputsp, targetsp = get_batch_pmap(train_dataset)
inputsp.shape, targetsp.shape

((8, 64, 64), (8, 64, 64))

In [19]:
output, params = model.init_with_output(jax.random.PRNGKey(99), inputs, training=False)
params = params["params"]

output.shape

(512, 64, 66)

In [32]:
opt = optax.adam(learning_rate=0.0001)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=opt)

In [33]:
states = jax.device_put_replicated(state, jax.local_devices())

In [34]:
backward_pass(state, get_batch(train_dataset))

(TrainState(step=1, apply_fn=<bound method Module.apply of LanguageModelBatch(
     # attributes
     vocab_size = 66
     n_embed = 256
     num_tokens = 64
     num_heads = 8
     num_layers = 6
 )>, params=FrozenDict({
     LanguageModel_0: {
         blocks_0: {
             computation_layer: {
                 ffwd: {
                     bias: DeviceArray([-0.00010007,  0.00010007,  0.00010007, ...,  0.00010007,
                                   0.00010007,  0.00010007], dtype=float32),
                     kernel: DeviceArray([[ 0.0627896 ,  0.07046933, -0.08207601, ..., -0.06679668,
                                   -0.06939802, -0.08279363],
                                  [-0.10408718, -0.06428235, -0.02808709, ...,  0.03370257,
                                    0.13579439,  0.06203614],
                                  [ 0.02165288, -0.06087535, -0.04173132, ...,  0.04389834,
                                   -0.06378376,  0.02416425],
                              

In [35]:
states, loss = train_step_pmap(states, get_batch_pmap(train_dataset))

In [38]:
for step in range(1000):
  train_batch = get_batch_pmap(train_dataset)
  states, loss = train_step_pmap(states, train_batch)

  print("loss", loss[0], "step", step) if step%100==0 else None

loss 2.694466 step 0
loss 2.5500324 step 100
loss 2.4421945 step 200
loss 2.356482 step 300
loss 2.2772663 step 400
loss 2.2399716 step 500
loss 2.0661721 step 600
loss 2.0726125 step 700
loss 2.0671456 step 800
loss 1.9112525 step 900


# Generating

In [40]:
T = config.block_size

state = jax.tree_map(lambda x: x[0], states)

state_apply_jit = jax.jit(state.apply_fn)

context = jnp.tile(jnp.array([52], dtype=jnp.int32), T)
context = context[None, -T:]
key = jrand.PRNGKey(99)

for _ in range(100):
  next_token_logits = state_apply_jit({"params": state.params}, context[:, -T:], False)

  key, split_key = jrand.split(key)
  new_token = jax.random.categorical(key, next_token_logits[:, -1, :], axis=-1, shape=(1, 1))

  context = jnp.concatenate([context, new_token], axis=1)


print(context.tolist()[0])

[52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 53, 61, 43, 31, 57, 1, 39, 58, 58, 0, 0, 21, 52, 27, 6, 1, 15, 53, 61, 5, 1, 58, 39, 39, 63, 1, 58, 61, 47, 58, 46, 53, 6, 1, 58, 47, 59, 57, 49, 1, 58, 46, 53, 52, 43, 6, 1, 49, 39, 50, 50, 1, 39, 47, 42, 14, 43, 57, 47, 43, 8, 0, 35, 53, 1, 24, 43, 58, 1, 53, 56, 1, 14, 59, 52, 58, 53, 1, 5, 58, 1, 52, 53, 1, 54, 39, 47, 52, 12, 0, 32, 46, 53, 59, 6, 1, 39, 52, 42, 1]


In [41]:
decode(context.tolist()[0], itos)

"nnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnoweSs att\n\nInO, Cow' taay twitho, tiusk thone, kall aidBesie.\nWo Let or Bunto 't no pain?\nThou, and "