<h1>Install dependencies for CPU</h1>

In [1]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-02-08 03:39:42--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-02-08 03:39:42 (49.8 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [None]:
# !pip install -q --upgrade pip # To support manylinux2010 wheels.
# !pip install -q --upgrade jax jaxlib # CPU-only
# !pip install -q --upgrade jaxtyping
# !pip install -q --upgrade flax

In [2]:
from typing import List, Dict, Mapping, Tuple

import jax
import jax.numpy as jnp
import jax.random as jrand
import flax.linen as nn
from flax.training import train_state  # Useful dataclass to keep train state
import optax
import tensorflow as tf
import pdb

def println(*args):
  for arg in args:
    print(arg)


# Dataset pipeline

In [3]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Create chars vocubulary using all the unique characters in the text.
chars = sorted(list(set(text)))
VOCAB_SIZE = len(chars)

# Create mapping from characters to integers.
stoi = {ch: i for i, ch in enumerate(chars)}

# Create reverse mapping from integers to characters.
itos = {i: ch for i, ch in enumerate(chars)}

# Create encode, decode function.
def encode(s: str, stoi: Mapping[str, int]) -> List[int]:
  return [stoi[c] for c in s]

def decode(tokens: List[int], itos: Mapping[int, str]) -> str:
  return ''.join([itos[i] for i in tokens])

println(encode("hii there", stoi), decode(encode("hii there", stoi), itos))

# Let's now split up the data into train and validation sets.
data = jnp.array(encode(text, stoi), dtype=jnp.int64)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# Below would result in a minibatch size of 32.
BATCH_SIZE = 4 # how many independent sequences will we process in parallel?
BLOCK_SIZE = 8 # what is the maximum context length for predictions?

train_dataset = (tf.data.Dataset.from_tensor_slices(train_data)
                .batch(BLOCK_SIZE+1)
                .map(lambda input: (input[:BLOCK_SIZE], input[1:BLOCK_SIZE+1]),
                     num_parallel_calls=tf.data.AUTOTUNE)
                .batch(BATCH_SIZE)
                .repeat()
                .as_numpy_iterator())
val_dataset = (tf.data.Dataset.from_tensor_slices(val_data)
                .batch(BLOCK_SIZE+1)
                .map(lambda input: (input[:BLOCK_SIZE], input[1:BLOCK_SIZE+1]),
                     num_parallel_calls=tf.data.AUTOTUNE)
                .batch(BATCH_SIZE)
                .repeat()
                .as_numpy_iterator())

def get_batch(training: bool = True):
  if not training:
    val_batch = next(val_dataset)
    return jnp.array(val_batch)

  train_batch = next(train_dataset)
  return jnp.array(train_batch)

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


  data = jnp.array(encode(text, stoi), dtype=jnp.int64)


### Test dataset pipeline

In [4]:
xb, yb = get_batch()
println("inputs", xb, "inputs shape", xb.shape)
println("targets", yb, "targets shape", yb.shape)
for b in range(BATCH_SIZE): # batch dimension
    for t in range(BLOCK_SIZE): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs
[[18 47 56 57 58  1 15 47]
 [47 64 43 52 10  0 14 43]
 [53 56 43  1 61 43  1 54]
 [53 41 43 43 42  1 39 52]]
inputs shape
(4, 8)
targets
[[47 56 57 58  1 15 47 58]
 [64 43 52 10  0 14 43 44]
 [56 43  1 61 43  1 54 56]
 [41 43 43 42  1 39 52 63]]
targets shape
(4, 8)
when input is [18] the target: 47
when input is [18, 47] the target: 56
when input is [18, 47, 56] the target: 57
when input is [18, 47, 56, 57] the target: 58
when input is [18, 47, 56, 57, 58] the target: 1
when input is [18, 47, 56, 57, 58, 1] the target: 15
when input is [18, 47, 56, 57, 58, 1, 15] the target: 47
when input is [18, 47, 56, 57, 58, 1, 15, 47] the target: 58
when input is [47] the target: 64
when input is [47, 64] the target: 43
when input is [47, 64, 43] the target: 52
when input is [47, 64, 43, 52] the target: 10
when input is [47, 64, 43, 52, 10] the target: 0
when input is [47, 64, 43, 52, 10, 0] the target: 14
when input is [47, 64, 43, 52, 10, 0, 14] the target: 43
when input is [47, 64, 43, 

# Implement Bigram Model

In [None]:
class BigramLangModel(nn.Module):
  """Reads one char and predicits the next char."""
  vocab_size: int

  def setup(self):
    super().setup()
    self.token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.vocab_size)

  def __call__(self, inputs):
    # Run block size inputs through embedding lookup.
    # For each char, you get the logit predicted for that char.
    # Then, you use the target token for that input and do a cross_entropy_loss.
    logits = self.token_embedding_table(inputs)
    return logits

## I'll make the flax model accept
`[block size worth tokens, token]`

## I'll then use vmap to make the model accept batches of data.
`[batch dim, block size worth of tokens, token]`

In [None]:
sample_input_row = jnp.ones(shape=[1, 1], dtype=jnp.int32)
sample_input_row

Array([[1]], dtype=int32)

In [None]:
model = BigramLangModel(vocab_size=65)
output, params = model.init_with_output(jrand.PRNGKey(99), sample_input_row)
params = params["params"]

## Make the model accepts batch of data
`[batch, block of tokens, token_ids]`

## To make it accept a batch, you need to use vmap.

In [None]:
model_apply_batch = jax.vmap(model.apply, in_axes=(None, 0), out_axes=(0))

In [None]:
# model_apply_batch will accept
# [batch, block of tokens, token ids]

## Sample forward pass, loss and backward pass.

In [None]:
batch = get_batch()
inputs, targets = batch
println("inputs", inputs, inputs.shape, "targets", targets.shape)

inputs
[[ 1 44 59 56 58 46 43 56]
 [ 1 46 43 39 56  1 51 43]
 [57 54 43 39 49  8  0  0]
 [50 50 10  0 31 54 43 39]]
(4, 8)
targets
(4, 8)


In [None]:
output = model_apply_batch({"params": params}, inputs)

In [None]:
# output should shape [4, 8, 65]
# batch size = 4
# block of tokens = 8
# token_id to embedding = 65

output.shape

(4, 8, 65)

In [None]:
# To do backward pass, you first need to compute grads.
# In JAX, you use jax.grad to do a function transformation on the forward
# function to get the gradient of the original function.
# The grad is calculate wrt to the first param in the function.
def forward_pass(params, batch):
  inputs, targets = batch
  logits = model_apply_batch({"params": params}, inputs)
  loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
  loss = loss.mean()
  return loss

In [None]:
# Test forward pass.
batch = get_batch()
forward_pass(params=params, batch=batch)

Array(4.188369, dtype=float32)

In [None]:
grad_fn = jax.value_and_grad(forward_pass, argnums=(0))  # differentiate wrt 0th pos argument.

In [None]:
# Test forward pass and grads.
# Grads would be the gradients for params.
loss, grads = grad_fn(params, batch)
println("loss", loss, "grads", grads)

loss
4.188369
grads
{'token_embedding_table': {'embedding': Array([[0.00103511, 0.00097044, 0.00093559, ..., 0.00089521, 0.001064  ,
        0.00093856],
       [0.00203619, 0.00209828, 0.00198735, ..., 0.00196691, 0.00215726,
        0.00207174],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.00040068, 0.00039527, 0.00040171, ..., 0.00042136, 0.00055311,
        0.00042526]], dtype=float32)}}


In [None]:
# Apply grads to params to get new params.
lr = 0.001
println("params before:", params)
params = jax.tree_map(lambda p, g: p - lr * g, params, grads)
println("params after:", params)

params before:
{'token_embedding_table': {'embedding': Array([[ 0.0752212 ,  0.01071652, -0.02585994, ..., -0.06997449,
         0.10274917, -0.0226865 ],
       [ 0.09400459,  0.12404279,  0.06972364, ...,  0.0593865 ,
         0.1517611 ,  0.11131445],
       [-0.0302137 , -0.07326671, -0.2515272 , ...,  0.20769818,
         0.01281604,  0.03134193],
       ...,
       [-0.1394756 , -0.00640967, -0.07666602, ..., -0.2944119 ,
         0.1187517 , -0.08573762],
       [ 0.05703759, -0.11280773,  0.2570641 , ..., -0.02059634,
        -0.02818088,  0.13305528],
       [-0.12428083, -0.13785616, -0.12170236, ..., -0.07394623,
         0.19811267, -0.06473607]], dtype=float32)}}
params after:
{'token_embedding_table': {'embedding': Array([[ 0.07522016,  0.01071555, -0.02586087, ..., -0.06997538,
         0.1027481 , -0.02268744],
       [ 0.09400256,  0.12404069,  0.06972165, ...,  0.05938453,
         0.15175894,  0.11131237],
       [-0.0302137 , -0.07326671, -0.2515272 , ...,  0.207698

# Writing train step in flax
## copy-pasting everything at one place and running a train step.

In [None]:
class BigramLangModel(nn.Module):
  """Reads one char and predicits the next char."""
  vocab_size: int

  def setup(self):
    super().setup()
    self.token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.vocab_size)

  def __call__(self, inputs):
    # Run block size inputs through embedding lookup.
    # For each char, you get the logit predicted for that char.
    # Then, you use the target token for that input and do a cross_entropy_loss.
    logits = self.token_embedding_table(inputs)
    return logits

model = BigramLangModel(vocab_size=65)

sample_input_row = jnp.ones(shape=[1, 1], dtype=jnp.int32)
output, params = model.init_with_output(jrand.PRNGKey(99), sample_input_row)
params = params["params"]

model_apply_batch = jax.vmap(model.apply, in_axes=(None, 0), out_axes=(0))

def forward_pass(params, state, batch):
  inputs, targets = batch
  logits = state.apply_fn({"params": params}, inputs)
  loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
  loss = loss.mean()
  return loss

grad_fn = jax.value_and_grad(forward_pass, argnums=(0))  # differentiate wrt 0th pos argument.

opt = optax.adam(learning_rate=0.001)
state = train_state.TrainState.create(apply_fn=model_apply_batch, params=params, tx=opt)

for epoch in range(1000):
  batch = get_batch()
  loss, grads = grad_fn(state.params, state, batch)
  print(loss) if epoch%100==0 else None
  state = state.apply_gradients(grads=grads)

4.1769814
4.093553
4.0517178
3.9815784
3.955411
3.8293285
3.7160275
3.7963347
3.7700095
3.5433385


# Implement code for generating tokens

In [None]:
input_token = "n"
encode(input_token, stoi)


[52]

In [None]:
input_token = jnp.array([[52]], dtype=jnp.int32)
input_token.shape

(1, 1)

In [None]:
next_token_logit = state.apply_fn({"params": state.params}, input_token)

In [None]:
next_token_logit

Array([[[-0.03682718,  0.3873823 , -0.6297037 , -0.49402413,
         -0.5302313 , -0.06883954,  0.02130109, -0.38699383,
          0.05293135, -0.67822653,  0.04108687, -0.17033677,
         -0.21989107, -0.664077  , -0.5178356 , -0.8944765 ,
         -0.47879392, -0.6437434 , -0.8325122 , -0.8247406 ,
         -0.7639986 , -0.6349938 , -0.58913904, -0.71123016,
         -0.6316706 , -0.8322299 , -0.61255544, -0.8065904 ,
         -0.6916736 , -0.7578409 , -0.77268964, -0.66586775,
         -0.48679668, -0.6974697 , -0.5410919 , -0.86016047,
         -0.7141595 , -0.87438166, -0.67868704,  0.03718114,
         -0.9563432 , -0.07171286,  0.3103663 ,  0.27640587,
         -0.15164396,  0.25430372, -0.8608676 , -0.11117504,
         -0.5252868 , -0.07216536,  0.02316949, -0.9069602 ,
         -0.26114473,  0.19493026, -0.9803428 , -0.70663923,
         -0.25509965,  0.25409326,  0.1060916 , -0.11378517,
         -0.35723025, -0.7799293 , -0.53064156, -0.42475328,
         -0.66108644]]],

In [None]:
init_key = jrand.PRNGKey(99)
key, split_key = jrand.split(init_key)

In [None]:
next_to_next_token = jrand.categorical(split_key, next_token_logit)
next_to_next_token

Array([[25]], dtype=int32)

In [None]:
decode(next_to_next_token.tolist()[0], itos)

'M'

In [None]:
next_to_next_to_next_logit = state.apply_fn({"params": state.params}, next_to_next_token)

In [None]:
key, split_key = jrand.split(key)

In [None]:
next_to_next_to_next_token = jrand.categorical(split_key, next_token_logit)

next_to_next_to_next_token

Array([[59]], dtype=int32)

In [None]:
decode(next_to_next_to_next_token.tolist()[0], itos)

'u'

In [None]:
# Putting together the generate code

input_token = jnp.array([[52]], dtype=jnp.int32)
key = jrand.PRNGKey(99)

result = ""
for i in range(100):
  key, split_key = jrand.split(key)
  next_token_logit = state.apply_fn({"params": state.params}, input_token)
  next_token = jrand.categorical(split_key, next_token_logit)
  next_token_decode = decode(next_token.tolist()[0], itos)
  result = result + next_token_decode

print(result)



MuTqF
d$oMJ
 CiSOzjIftBqertiG,3gdghx,,.V,d .zbfa'$fXoeyu'l!m:oaBBQRcrEttkQm3u-r.v3LgdMVfxsx-;ga!kcPW


# Mathematical trick in attention

## doing bag of words
basically, in B, T, C
at t-th token in a row of batch, just sum all the values upto t.

In [None]:
# Generate T, C and write code which works for T, C.
# Then, vmap it for batch

# Using tokens = 4
# Using each token with channel = 2 to make it easy to visualize
T, C = 4, 2

key, split_key = jrand.split(jrand.PRNGKey(99))

x = jrand.normal(split_key, (T, C))
x

Array([[ 0.2628779 , -0.1837252 ],
       [ 0.38331428, -0.16180514],
       [ 1.4986674 ,  1.10728   ],
       [ 1.1535788 ,  0.9676542 ]], dtype=float32)

### version 1: using for loop

In [None]:
def bow_attention(x: jnp.array, T: int, C: int):
  """Operates on a single row within batchs

    It calculates bow attention by summing all
    token channels prev + current token channels.
  """
  xbow = jnp.zeros(shape=(T, C))
  for token in range(T):
    xprev = x[:token]
    xcurrent = x[token:token+1]

    current_bow = jnp.mean(jnp.concatenate([xprev, xcurrent], axis=0), axis=0)
    xbow = xbow.at[token].set(current_bow)
  return xbow

In [None]:
bow_attention(x, T, C)

Array([[ 0.2628779 , -0.1837252 ],
       [ 0.3230961 , -0.17276517],
       [ 0.7149532 ,  0.25391656],
       [ 0.8246096 ,  0.432351  ]], dtype=float32)

In [None]:
# Test bow_attention using non-random tensor.
test_numbers = jnp.arange(1, 5).reshape(-1, 1)
test_arr = jnp.tile(test_numbers, (1, C))
bow_attention(test_arr, T, C)

Array([[1. , 1. ],
       [1.5, 1.5],
       [2. , 2. ],
       [2.5, 2.5]], dtype=float32)

### version 2: Starting to write it using matmul

In [None]:
x

Array([[ 0.2628779 , -0.1837252 ],
       [ 0.38331428, -0.16180514],
       [ 1.4986674 ,  1.10728   ],
       [ 1.1535788 ,  0.9676542 ]], dtype=float32)

In [None]:
wei = jnp.array([[1.0, 0., 0., 0.],
 [0.5, 0.5, 0., 0.],
 [0.333, 0.333, 0.333, 0.],
 [0.25, 0.25, 0.25, 0.25]])

In [None]:
tril = jnp.tril(jnp.ones(shape=(T, T)))
tril

Array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]], dtype=float32)

In [None]:
jnp.sum(tril, axis=1, keepdims=True)

Array([[1.],
       [2.],
       [3.],
       [4.]], dtype=float32)

In [None]:
def get_wei(T: int):
  tril = jnp.tril(jnp.ones(shape=(T, T)))
  return tril/jnp.sum(tril, axis=1, keepdims=True)

In [None]:
get_wei(T)

Array([[1.        , 0.        , 0.        , 0.        ],
       [0.5       , 0.5       , 0.        , 0.        ],
       [0.33333334, 0.33333334, 0.33333334, 0.        ],
       [0.25      , 0.25      , 0.25      , 0.25      ]], dtype=float32)

In [None]:
# Putting together the bow attention calculation using matmul
def bow_attention_matmul(x: jnp.array, T: int, C: int):
  tril = jnp.tril(jnp.ones(shape=(T, T)))
  wei = tril/jnp.sum(tril, axis=1, keepdims=True)

  return jnp.dot(wei, x)


In [None]:
bow_attention_matmul(x, T, C)

Array([[ 0.2628779 , -0.1837252 ],
       [ 0.3230961 , -0.17276517],
       [ 0.7149532 ,  0.25391656],
       [ 0.8246096 ,  0.432351  ]], dtype=float32)

### version 3: use softmax to generate wei matrix
so that it can be learnable?

In [None]:
tril = jnp.tril(jnp.ones(shape=(T, T)))
tril

Array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]], dtype=float32)

In [None]:
# you start wei as all zeros
wei = jnp.zeros(shape=(T, T))
wei

Array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)

In [None]:
# but now, we modify wei such that whenever tril==0, we put -inf into wei
wei = jnp.where(tril==0, -jnp.inf, wei)
wei

Array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]], dtype=float32)

In [None]:
# next we take softmax along row, that is dim==-1
wei = nn.softmax(wei, axis=-1)
wei

Array([[1.        , 0.        , 0.        , 0.        ],
       [0.5       , 0.5       , 0.        , 0.        ],
       [0.33333334, 0.33333334, 0.33333334, 0.        ],
       [0.25      , 0.25      , 0.25      , 0.25      ]], dtype=float32)

In [None]:
def calc_attention(x: jnp.array, T:int, C:int):
  """Calculates attention for a row of tokens."""
  tril = jnp.tril(jnp.ones(shape=(T, T)))
  wei = jnp.zeros(shape=(T, T))
  wei = jnp.where(tril==0, -jnp.inf, wei)
  wei = nn.softmax(wei, axis=-1)

  return jnp.dot(wei, x)

In [None]:
calc_attention(x, T, C)

Array([[ 0.2628779 , -0.1837252 ],
       [ 0.3230961 , -0.17276517],
       [ 0.7149532 ,  0.25391656],
       [ 0.8246096 ,  0.432351  ]], dtype=float32)

In [None]:
calc_attention(test_arr, T, C)

Array([[1. , 1. ],
       [1.5, 1.5],
       [2. , 2. ],
       [2.5, 2.5]], dtype=float32)

In [None]:
calc_attention_batch = jax.vmap(calc_attention, in_axes=(0, None, None), out_axes=(0))

In [None]:
T, C = 8, 65
test_numbers = jnp.arange(1, T+1).reshape(-1, 1)
test_arr = jnp.tile(test_numbers, (1, C))

# add batch dimension to test_arr
test_arr_batch = test_arr[None, :]

In [None]:
# Test calc_attention_batch using get_batch
calc_attention_batch(test_arr_batch, T, C)

Array([[[1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ],
        [1.5      , 1.5      , 1.5      , 1.5      , 1.5      ,
         1.5      , 1.5      , 1.5      , 1.5      , 1.5      ,
         1.5      , 1.5      , 1.5     

# Putting together new Bigram model

In [None]:
class BigramLangModel(nn.Module):
  """Reads one char and predicits the next char."""
  vocab_size: int # number of vocabulary (number of rows of embedding table)
  n_embed: int # embedding dim after lookup

  def setup(self):
    super().setup()
    # number of channels you want to use for store info for each token.
    self.C = self.vocab_size

    self.token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.n_embed)

    self.lang_model_head = nn.Dense(features=self.C)

  def __call__(self, block_of_tokens: jnp.array):
    """Accepts a block of tokens."""

    # generate em for each token. output: (T, n_embed)
    token_embs = self.token_embedding_table(block_of_tokens)

    # generate logits for each token. output: (T, channels for info -- C)
    token_logits = self.lang_model_head(token_embs)

## Add positional embeddings to the above.

In [None]:
block_of_tokens_example = jnp.ones(shape=(1, 8))
block_of_tokens_example, block_of_tokens_example.shape, block_of_tokens_example.shape[1]

(Array([[1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32), (1, 8), 8)

In [None]:
num_pos = block_of_tokens_example.shape[1]
num_pos

8

In [None]:
jnp.arange(0, num_pos)

Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)

In [None]:
class LanguageModel(nn.Module):
  """Reads one char and predicits the next char."""
  vocab_size: int # number of vocabulary (number of rows of embedding table)
  n_embed: int # embedding dim after lookup

  block_size: int # T, i.e., number of tokens attention block is looking at once

  def setup(self):
    super().setup()
    # number of channels you want to use for store info for each token.
    self.C = self.vocab_size

    self.token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.n_embed)

    self.pos_embedding_table = nn.Embed(num_embeddings=self.block_size, features=self.n_embed)

    self.lang_model_head = nn.Dense(features=self.C)

  def __call__(self, block_of_tokens: jnp.array):
    """Accepts a block of tokens, like [0, 1, 2, 3, 4, 5, 6, 7]."""

    # generate em for each token. output: (T, n_embed)
    token_embs = self.token_embedding_table(block_of_tokens)

    # generate position embs for each token.
    ## get token positions.
    num_pos = block_of_tokens.shape[0]
    positions = jnp.arange(0, num_pos)
    pos_embs = self.pos_embedding_table(positions)

    # generate actual input to attention, x, which is sum of token_embs + pos_embs
    x = token_embs + pos_embs

    # generate logits for each token. output: (T, channels for info -- C)
    token_logits = self.lang_model_head(x)


## Coding single head attention.

In [None]:
# (copy-pasting from top)
# Here, wei has uniform attention scores to previous tokens.
# That is token in Tth position, is assuming that
# each previous token has same amount of info.

# But we want each Tth token to learn what to pay attention to.

# So, we have each token emit "keys" -- info I have
# Each token will emit "query" -- what I'm looking for
# wei becomes the dot prodct of "keys" and "query" -- higher the dot product higher the match between
# what I'm look for and what some previous token has.
def calc_attention(x: jnp.array, T:int, C:int):
  """Calculates attention for a row of tokens."""
  tril = jnp.tril(jnp.ones(shape=(T, T)))
  wei = jnp.zeros(shape=(T, T))
  wei = jnp.where(tril==0, -jnp.inf, wei)
  wei = nn.softmax(wei, axis=-1)

  return jnp.dot(wei, x)

In [None]:
token_info_size = 16 # head_size, each token produces vector of this size for key, query

# key, query will take vector of size C.
# i.e., channels containing info of token and will output token_info_size
key_layer = nn.Dense(token_info_size, use_bias=False)

query_layer = nn.Dense(token_info_size, use_bias=False)

In [None]:
# (tokens, channel info for each)
# (T, C)

# for easy visualization, T=4, C=2
T=4; C=2
x = jrand.normal(jrand.PRNGKey(999), shape=(T, C))
x

Array([[ 0.27297866, -0.6993713 ],
       [ 0.428855  , -1.5621939 ],
       [-0.05503325,  0.18392533],
       [-0.18410844,  0.53945136]], dtype=float32)

In [None]:
prng = jrand.PRNGKey(9999)
key, split_key = jrand.split(prng)

In [None]:
# keys emitted by each token.
kparams = key_layer.init(split_key, x)["params"]
keys = key_layer.apply({"params": kparams}, x)

# queries emitted by each token
# NOTE: each token parallely and indpendently emits its "keys" and "queries"
qparams = query_layer.init(split_key, x)["params"]
queries = query_layer.apply({"params": qparams}, x)

keys.shape, queries.shape # each are (T, 16)

((4, 16), (4, 16))

In [None]:
# NOW, wei becomes this dot product between keys and querys
wei = jnp.dot(queries, keys.T)
wei

Array([[ 6.0750933, 12.761856 , -1.5228109, -4.5677843],
       [12.761856 , 27.052605 , -3.221544 , -9.631147 ],
       [-1.5228109, -3.221544 ,  0.3838081,  1.1482861],
       [-4.5677843, -9.631147 ,  1.1482861,  3.4396741]], dtype=float32)

In [None]:
  tril = jnp.tril(jnp.ones(shape=(T, T)))

  # Don't initialize wei as zeros
  # wei = jnp.zeros(shape=(T, T))
  wei = jnp.where(tril==0, -jnp.inf, wei)
  wei = nn.softmax(wei, axis=-1)
  wei

Array([[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
       [6.2173666e-07, 9.9999940e-01, 0.0000000e+00, 0.0000000e+00],
       [1.2637094e-01, 2.3115156e-02, 8.5051382e-01, 0.0000000e+00],
       [3.0229829e-04, 1.9118115e-06, 9.1810778e-02, 9.0788496e-01]],      dtype=float32)

In [None]:
# each token also produces "value", which is what we would multiply with wei.
# so, wei is attention score.
# whenever the attention score is high, we want to take its value.

### Combing things from above and adding value layer as well.

In [None]:
# (tokens, channel info for each)
# (T, C)

# for easy visualization, T=4, C=2
T=4; C=2
x = jrand.normal(jrand.PRNGKey(999), shape=(T, C))

key, split_key = jrand.split(jrand.PRNGKey(9999))

token_info_size = 16 # head_size

key_layer = nn.Dense(token_info_size, use_bias=False)
query_layer = nn.Dense(token_info_size, use_bias=False)
value_layer = nn.Dense(token_info_size, use_bias=False)


keys = key_layer.apply(key_layer.init(split_key, x), x) # (T, token_info_size)
queries = query_layer.apply(query_layer.init(split_key, x), x)
values = value_layer.apply(value_layer.init(split_key, x), x) # (T, token_info_size)

tril = jnp.tril(jnp.ones(shape=(T, T)))

wei = jnp.dot(queries, keys.T) # (T, T)
wei = jnp.where(tril==0, -jnp.inf, wei)
wei = nn.softmax(wei, axis=-1)


out = jnp.dot(wei, values) # (T, T) * (T, token_info_size)

# shape should be (T, token_info_size)
# i.e., (4, 16)
out.shape

(4, 16)

## self-attention vs cross-attention: https://youtu.be/kCc8FmEb1nY?t=4542

### Scaled attention -- dividing wei*value by squared root of head_size https://youtu.be/kCc8FmEb1nY?t=4638

In [None]:
wei = jnp.dot(queries, keys.T) * C**0.5 # (T, T)
wei = jnp.where(tril==0, -jnp.inf, wei)
wei = nn.softmax(wei, axis=-1)


out = jnp.dot(wei, values) # (T, T) * (T, token_info_size)
out

Array([[ 0.15091053,  1.1747676 , -0.27543876,  0.6276668 , -0.6659185 ,
        -0.7521053 ,  0.15832554,  0.61729145,  0.5173192 , -1.0826657 ,
         0.4595548 ,  0.38114175,  0.36432883, -0.12410479, -0.28500274,
         0.8726827 ],
       [ 0.15494807,  2.420155  , -0.75046575,  1.4276459 , -1.2494795 ,
        -1.4867611 ,  0.18499118,  1.4277841 ,  1.0585895 , -2.3402715 ,
         0.8965928 ,  0.6221145 ,  1.0077578 , -0.24458581, -0.6518733 ,
         1.8538882 ],
       [-0.0108901 , -0.18261386,  0.05756752, -0.1084154 ,  0.09365092,
         0.11186212, -0.01323292, -0.10853641, -0.07983959,  0.1771509 ,
        -0.06739505, -0.04610366, -0.07736284,  0.01839836,  0.04952013,
        -0.14017412],
       [-0.08724618, -0.85421604,  0.22667341, -0.47580495,  0.46656573,
         0.5378475 , -0.09476726, -0.47136265, -0.37513095,  0.8030861 ,
        -0.32692865, -0.2536266 , -0.30200323,  0.08864281,  0.21657023,
        -0.642643  ]], dtype=float32)

# Implement self-attention head.
(copy-pasting from above code mostly into Flax module.)

In [None]:
class Head(nn.Module):
  token_info_size: int # head_size; how much (emb dim) info each token emits for keys, queries, values.

  T: int # block size; number of tokens in a block
  C: int # channel info size: size of info channel of each token.


  def setup(self):
    super().setup()

    # key, query will take vector of size C.
    # i.e., channels containing info of token and will output token_info_size
    self.key_layer = nn.Dense(self.token_info_size, use_bias=False)
    self.query_layer = nn.Dense(self.token_info_size, use_bias=False)
    self.value_layer = nn.Dense(self.token_info_size, use_bias=False)


  def __call__(self, block_of_tokens_with_info_channels: jnp.array):
    """Accepts a block of tokens with info channels, like (8, 65)."""

    # TODO(ntnsonti): Double check; but tril should not be learnable according cGPT.
    tril = jnp.tril(jnp.ones(shape=(self.T, self.T)))

    keys = self.key_layer(block_of_tokens_with_info_channels) # (T, token_info_size)
    queries = self.query_layer(block_of_tokens_with_info_channels)
    values = self.value_layer(block_of_tokens_with_info_channels)

    # compute attention score.
    wei = jnp.dot(queries, keys.T) * C**0.5 # (T, T)
    wei = jnp.where(tril==0, -jnp.inf, wei)
    wei = nn.softmax(wei, axis=-1)


    out = jnp.dot(wei, values) # (T, T) * (T, token_info_size))
    return out # (T, token_info_size)


https://youtu.be/kCc8FmEb1nY?t=4819

In [None]:
class LanguageModel(nn.Module):
  """Reads one char and predicits the next char."""
  vocab_size: int # number of vocabulary (number of rows of embedding table)
  n_embed: int # embedding dim after lookup

  T: int # block size, i.e., number of tokens attention block is looking at once

  def setup(self):
    super().setup()
    # number of channels you want to use for store info for each token.
    self.C = self.vocab_size

    self.token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.n_embed)

    self.pos_embedding_table = nn.Embed(num_embeddings=self.T, features=self.n_embed)

    self.self_attention_head = Head(token_info_size=self.n_embed, T=self.T, C=self.C)

    self.lang_model_head = nn.Dense(features=self.C)

  def __call__(self, block_of_tokens: jnp.array):
    """Accepts a block of tokens, like [0, 1, 2, 3, 4, 5, 6, 7]."""

    # generate em for each token. output: (T, n_embed)
    token_embs = self.token_embedding_table(block_of_tokens)

    # generate position embs for each token.
    ## get token positions.
    num_pos = block_of_tokens.shape[0]
    positions = jnp.arange(0, num_pos)
    pos_embs = self.pos_embedding_table(positions)

    # generate actual input to attention, x, which is sum of token_embs + pos_embs
    x = token_embs + pos_embs

    # feed x into self-attention head.
    x = self.self_attention_head(x)

    # generate logits for each token. output: (T, channels for info -- C)
    token_logits = self.lang_model_head(x)

    return token_logits


## Traing the network.

In [None]:
BLOCK_SIZE

8

In [None]:
T = 8

In [None]:
model = LanguageModel(vocab_size=65, n_embed=32, T=BLOCK_SIZE)

# Now, our language model needs to accept a block of tokens, not one-char at a time.
# We'll then make it accept a batch of blocks of tokens using vmap.
sample_block_of_tokens = jnp.ones(shape=(T), dtype=jnp.int32)
output, params = model.init_with_output(jrand.PRNGKey(99), sample_block_of_tokens)
params = params["params"]

model_apply_batch = jax.vmap(model.apply, in_axes=(None, 0), out_axes=(0))

def forward_pass(params, state, batch):
  inputs, targets = batch
  logits = state.apply_fn({"params": params}, inputs)
  loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
  loss = loss.mean()
  return loss

grad_fn = jax.value_and_grad(forward_pass, argnums=(0))  # differentiate wrt 0th pos argument.

opt = optax.adam(learning_rate=0.001)
state = train_state.TrainState.create(apply_fn=model_apply_batch, params=params, tx=opt)

for epoch in range(1000):
  batch = get_batch()
  loss, grads = grad_fn(state.params, state, batch)
  print(loss) if epoch%100==0 else None
  state = state.apply_gradients(grads=grads)

4.204059
3.9591136
3.0337052
2.8353043
3.6666312
2.9169798
2.6496902
2.8693814
2.4801302
2.4153466


# Multi-head attention https://youtu.be/kCc8FmEb1nY?t=4925

In [None]:
# (copy-pasting single head attention from above)
class Head(nn.Module):
  token_info_size: int # head_size; how much (emb dim) info each token emits for keys, queries, values.

  T: int # block size; number of tokens in a block
  C: int # channel info size: size of info channel of each token.


  def setup(self):
    super().setup()

    # key, query will take vector of size C.
    # i.e., channels containing info of token and will output token_info_size
    self.key_layer = nn.Dense(self.token_info_size, use_bias=False)
    self.query_layer = nn.Dense(self.token_info_size, use_bias=False)
    self.value_layer = nn.Dense(self.token_info_size, use_bias=False)


  def __call__(self, block_of_tokens_with_info_channels: jnp.array):
    """Accepts a block of tokens with info channels, like (8, 65)."""

    # TODO(ntnsonti): Double check; but tril should not be learnable according cGPT.
    tril = jnp.tril(jnp.ones(shape=(self.T, self.T)))

    keys = self.key_layer(block_of_tokens_with_info_channels) # (T, token_info_size)
    queries = self.query_layer(block_of_tokens_with_info_channels)
    values = self.value_layer(block_of_tokens_with_info_channels)

    # compute attention score.
    wei = jnp.dot(queries, keys.T) * C**0.5 # (T, T)
    wei = jnp.where(tril==0, -jnp.inf, wei)
    wei = nn.softmax(wei, axis=-1)


    out = jnp.dot(wei, values) # (T, T) * (T, token_info_size))
    return out # (T, token_info_size)


In [None]:
# You just run multiple attention heads in parallel and concatenate their output along channel dimension, i.e., dim==-1

In [None]:
class MultiHeadAttention(nn.Module):
  num_heads: int
  token_info_size: int

  T: int
  C: int

  def setup(self):
    super().setup()

    self.heads = [Head(token_info_size=self.token_info_size, T=self.T, C=self.C) for _ in range(self.num_heads)]

  def __call__(self, block_of_tokens_with_info_channels: jnp.array):
    out_from_each_head = jnp.array([h(block_of_tokens_with_info_channels) for h in self.heads])
    return jnp.concatenate(out_from_each_head, axis=-1)



In [None]:
class LanguageModel(nn.Module):
  """Reads one char and predicits the next char."""
  vocab_size: int # number of vocabulary (number of rows of embedding table)
  n_embed: int # embedding dim after lookup

  T: int # block size, i.e., number of tokens attention block is looking at once

  def setup(self):
    super().setup()
    # number of channels you want to use for store info for each token.
    self.C = self.vocab_size

    self.token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.n_embed)

    self.pos_embedding_table = nn.Embed(num_embeddings=self.T, features=self.n_embed)

    # *** new ***
    # Since, there are 4 heads, each head only needs to output token_info of size 8.
    # Concantenate token_info from all 4 heards, gives us 32
    self.self_attention_heads = MultiHeadAttention(num_heads=4, token_info_size=int(self.n_embed/4), T=self.T, C=self.C)

    self.lang_model_head = nn.Dense(features=self.C)

  def __call__(self, block_of_tokens: jnp.array):
    """Accepts a block of tokens, like [0, 1, 2, 3, 4, 5, 6, 7]."""

    # generate em for each token. output: (T, n_embed)
    token_embs = self.token_embedding_table(block_of_tokens)

    # generate position embs for each token.
    ## get token positions.
    num_pos = block_of_tokens.shape[0]
    positions = jnp.arange(0, num_pos)
    pos_embs = self.pos_embedding_table(positions)

    # generate actual input to attention, x, which is sum of token_embs + pos_embs
    x = token_embs + pos_embs

    # feed x into self-attention head.
    x = self.self_attention_heads(x)

    # generate logits for each token. output: (T, channels for info -- C)
    token_logits = self.lang_model_head(x)

    return token_logits


In [None]:
model = LanguageModel(vocab_size=65, n_embed=32, T=BLOCK_SIZE)

# Now, our language model needs to accept a block of tokens, not one-char at a time.
# We'll then make it accept a batch of blocks of tokens using vmap.
sample_block_of_tokens = jnp.ones(shape=(T), dtype=jnp.int32)
output, params = model.init_with_output(jrand.PRNGKey(99), sample_block_of_tokens)
params = params["params"]

# model_apply_batch = jax.vmap(model.apply, in_axes=(None, 0), out_axes=(0))

# *** new ***: Fuck, jax.jit makes it so much faster even on GPU.
model_apply_batch = jax.jit(jax.vmap(model.apply, in_axes=(None, 0), out_axes=(0)))

def forward_pass(params, state, batch):
  inputs, targets = batch
  logits = state.apply_fn({"params": params}, inputs)
  loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
  loss = loss.mean()
  return loss

grad_fn = jax.value_and_grad(forward_pass, argnums=(0))  # differentiate wrt 0th pos argument.

opt = optax.adam(learning_rate=0.001)
state = train_state.TrainState.create(apply_fn=model_apply_batch, params=params, tx=opt)

for epoch in range(1000):
  batch = get_batch()
  loss, grads = grad_fn(state.params, state, batch)
  print(loss) if epoch%100==0 else None
  state = state.apply_gradients(grads=grads)

4.16981
3.848041
3.2327182
2.9303553
3.073019
2.7059855
2.8082283
2.8490913
2.7219577
2.7902007


# NanoGPT

In [None]:
class FeedForward(nn.Module):
  output_size: int

  def setup(self):
    super().setup()
    # **new**: attention paper uses 4 times token_info_size when doing linear transformation.
    # and then projects it back to token_info_size in linear transformation layer.
    self.ffwd = nn.Dense(features=4 * self.output_size)

    # **new**: projection layer, which goes back into residual pathway.
    self.projection = nn.Dense(self.output_size)

  def __call__(self, x):
    x = nn.relu(self.ffwd(x))
    x = self.projection(x)
    return x

In [None]:
# (copy-pasting single head attention from above)
class Head(nn.Module):
  token_info_size: int # head_size; how much (emb dim) info each token emits for keys, queries, values.
  T: int # block size; number of tokens in a block

  def setup(self):
    super().setup()
    # key, query will take vector of size C.
    # i.e., channels containing info of token and will output token_info_size
    self.key_layer = nn.Dense(self.token_info_size, use_bias=False)
    self.query_layer = nn.Dense(self.token_info_size, use_bias=False)
    self.value_layer = nn.Dense(self.token_info_size, use_bias=False)


  def __call__(self, block_of_tokens_with_info_channels: jnp.array):
    """Accepts a block of tokens with info channels, like (8, 65)."""

    # TODO(ntnsonti): Double check; but tril should not be learnable according cGPT.
    tril = jnp.tril(jnp.ones(shape=(self.T, self.T)))

    keys = self.key_layer(block_of_tokens_with_info_channels) # (T, token_info_size)
    queries = self.query_layer(block_of_tokens_with_info_channels)
    values = self.value_layer(block_of_tokens_with_info_channels)

    # chanel info size
    C = int(block_of_tokens_with_info_channels.shape[-1])
    print("[ntn99] channel_info_size: ", C)

    # compute attention score.
    wei = jnp.dot(queries, keys.T) * C**0.5 # (T, T)
    wei = jnp.where(tril==0, -jnp.inf, wei)
    wei = nn.softmax(wei, axis=-1)


    out = jnp.dot(wei, values) # (T, T) * (T, token_info_size))
    return out # (T, token_info_size)


In [None]:
class MultiHeadAttention(nn.Module):
  num_heads: int
  final_token_info_size: int # After concatenating from all heads, how much info (values -- emb size) you have on each token.
  T: int

  def setup(self):
    super().setup()
    self.token_info_size_per_head = int(self.final_token_info_size/self.num_heads)
    self.heads = [
        Head(token_info_size=self.token_info_size_per_head, T=self.T) for _ in range(self.num_heads)
    ]

    self.projection = nn.Dense(features=self.final_token_info_size)

  def __call__(self, block_of_tokens_with_info_channels: jnp.array):
    out_from_each_head = jnp.array([h(block_of_tokens_with_info_channels) for h in self.heads])

    # You just run multiple attention heads in parallel and concatenate
    # their output along channel dimension, i.e., dim==-1
    out_from_all_heads = jnp.concatenate(out_from_each_head, axis=-1)
    print("[ntn99] out_from_all_heads concatenated shape: ", out_from_all_heads.shape)

    return self.projection(out_from_all_heads)

In [None]:
class Block(nn.Module):
  num_heads: int
  final_token_info_size: int
  T: int

  def setup(self):
    super().setup()

    # communication.
    self.self_attention_heads = MultiHeadAttention(num_heads=self.num_heads,
                                                   final_token_info_size=self.final_token_info_size,
                                                   T=self.T)

    # computation.
    self.computation_layer = FeedForward(output_size=self.final_token_info_size)

    self.ln1 = nn.LayerNorm()
    self.ln2 = nn.LayerNorm()

  def __call__(self, x):
    x = x + self.self_attention_heads(self.ln1(x))
    print("[ntn99] input size after attention_head: ", x.shape)

    x = x + self.computation_layer(self.ln2(x))
    print("[ntn99] input size after computation (end of block): ", x.shape)
    return x

In [None]:
class LanguageModel(nn.Module):
  """Reads one char and predicits the next char."""
  vocab_size: int # number of vocabulary (number of rows of embedding table)
  n_embed: int # embedding dim after lookup
  T: int # block size, i.e., number of tokens attention block is looking at once

  def setup(self):
    super().setup()
    # number of channels you want to use for store info for each token.
    self.C = self.vocab_size

    self.token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.n_embed)

    self.pos_embedding_table = nn.Embed(num_embeddings=self.T, features=self.n_embed)

    # *** new ***
    # Since, there are 4 heads, each head only needs to output token_info of size 8.
    # Concantenate token_info from all 4 heards, gives us 32
    # self.self_attention_heads = MultiHeadAttention(num_heads=4, final_token_info_size=self.n_embed, T=self.T)
    self.blocks = nn.Sequential([
        Block(num_heads=4, final_token_info_size=self.n_embed, T=self.T),
        Block(num_heads=4, final_token_info_size=self.n_embed, T=self.T),
        Block(num_heads=4, final_token_info_size=self.n_embed, T=self.T),
        Block(num_heads=4, final_token_info_size=self.n_embed, T=self.T),
        nn.LayerNorm(), # TODO: I think my reduction_axis should be 0.
    ])

    self.lang_model_head = nn.Dense(features=self.C)

  def __call__(self, block_of_tokens: jnp.array):
    """Accepts a block of tokens, like [0, 1, 2, 3, 4, 5, 6, 7]."""

    # generate em for each token. output: (T, n_embed)
    token_embs = self.token_embedding_table(block_of_tokens)

    # generate position embs for each token.
    ## get token positions.
    num_pos = block_of_tokens.shape[0]
    positions = jnp.arange(0, num_pos)
    pos_embs = self.pos_embedding_table(positions)

    # generate actual input to attention, x, which is sum of token_embs + pos_embs
    x = token_embs + pos_embs

    # feed x into self-attention head.
    # x = self.self_attention_heads(x)
    x = self.blocks(x)

    # generate logits for each token. output: (T, channels for info -- C)
    token_logits = self.lang_model_head(x)

    return token_logits


In [None]:
model = LanguageModel(vocab_size=65, n_embed=32, T=BLOCK_SIZE)

# Now, our language model needs to accept a block of tokens, not one-char at a time.
# We'll then make it accept a batch of blocks of tokens using vmap.
sample_block_of_tokens = jnp.ones(shape=(T), dtype=jnp.int32)
output, params = model.init_with_output(jrand.PRNGKey(99), sample_block_of_tokens)
params = params["params"]

# model_apply_batch = jax.vmap(model.apply, in_axes=(None, 0), out_axes=(0))

# *** new ***: Fuck, jax.jit makes it so much faster even on GPU.
model_apply_batch = jax.jit(jax.vmap(model.apply, in_axes=(None, 0), out_axes=(0)))

def forward_pass(params, state, batch):
  inputs, targets = batch
  logits = state.apply_fn({"params": params}, inputs)
  loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
  loss = loss.mean()
  return loss

grad_fn = jax.value_and_grad(forward_pass, argnums=(0))  # differentiate wrt 0th pos argument.

opt = optax.adam(learning_rate=0.0001)
state = train_state.TrainState.create(apply_fn=model_apply_batch, params=params, tx=opt)

for epoch in range(5000):
  batch = get_batch()
  loss, grads = grad_fn(state.params, state, batch)
  print("loss", loss, "epoch", epoch) if epoch%100==0 else None
  state = state.apply_gradients(grads=grads)

# Generating
for generating, I think we give block of tokens and then get next token prediction at each token. But you only uses the next token prediction of the last token because that has attended to everything before it.

You just take the logit at last Tth position and you should pass everything but the 0th from previous tokens and then next predicted token.

In [None]:
BATCH_SIZE, BLOCK_SIZE

(4, 8)

In [None]:
T = BLOCK_SIZE
T

8

In [None]:
state_apply_jit = jax.jit(state.apply_fn)
state_apply_jit({"params": state.params}, context[:, -T:])

In [None]:
next_token_logit.shape

(1, 8, 65)

In [None]:
next_token_logit[:, -1, :].shape

(1, 65)

In [None]:
jax.random.categorical(split_key, next_token_logit[:, -1, :], axis=-1, shape=(1, 1))

Array([[46]], dtype=int32)

## putting it together

In [None]:
T = BLOCK_SIZE

state_apply_jit = jax.jit(state.apply_fn)

context = jnp.tile(jnp.array([52], dtype=jnp.int32), T)
context = context[None, -T:]
key = jrand.PRNGKey(99)

for _ in range(100):
  next_token_logits = state_apply_jit({"params": state.params}, context[:, -T:])

  key, split_key = jrand.split(key)
  new_token = jax.random.categorical(key, next_token_logits[:, -1, :], axis=-1, shape=(1, 1))

  context = jnp.concatenate([context, new_token], axis=1)


In [None]:
decode(context.tolist()[0], itos)

'nnnnnnnnp eSs ,orr\nine, Co,R eaaotewl\nNshM tuhk e \nos,sitdN ci,Bf iC Oot d:y oWtBthrh l sre \naonsoa\noK, aAth'

# Combining all code in one cell
*********

In [8]:
T = BLOCK_SIZE

In [58]:
class FeedForward(nn.Module):
  output_size: int

  def setup(self):
    # **new**: attention paper uses 4 times token_info_size when doing linear transformation.
    # and then projects it back to token_info_size in linear transformation layer.
    self.ffwd = nn.Dense(features=4 * self.output_size)

    # **new**: projection layer, which goes back into residual pathway.
    self.projection = nn.Dense(self.output_size)

  def __call__(self, x, training: bool):
    x = nn.relu(self.ffwd(x))
    x = self.projection(x)
    return x


class Head(nn.Module):
  token_info_size: int # head_size; how much (emb dim) info each token emits for keys, queries, values.
  T: int # block size; number of tokens in a block

  def setup(self):
    # key, query will take vector of size C.
    # i.e., channels containing info of token and will output token_info_size
    self.key_layer = nn.Dense(self.token_info_size, use_bias=False)
    self.query_layer = nn.Dense(self.token_info_size, use_bias=False)
    self.value_layer = nn.Dense(self.token_info_size, use_bias=False)

    self.dropout = nn.Dropout(rate=0.2)


  def __call__(self, block_of_tokens_with_info_channels: jnp.array, training: bool):
    """Accepts a block of tokens with info channels, like (8, 65)."""

    # TODO(ntnsonti): Double check; but tril should not be learnable according cGPT.
    tril = jnp.tril(jnp.ones(shape=(self.T, self.T)))

    # input: (T, info channels )
    # output: (T, token_info_size)
    keys = self.key_layer(block_of_tokens_with_info_channels)
    queries = self.query_layer(block_of_tokens_with_info_channels)
    values = self.value_layer(block_of_tokens_with_info_channels)

    # chanel info size
    C = int(block_of_tokens_with_info_channels.shape[-1])
    # print("[ntn99] channel_info_size: ", C)

    # compute attention score.
    wei = jnp.dot(queries, keys.T) * C**0.5 # (T, token_info_size) * (token_info_size, T) == (T, T)
    wei = jnp.where(tril==0, -jnp.inf, wei)
    wei = nn.softmax(wei, axis=-1)

    attention_values = jnp.dot(wei, values) # (T, T) * (T, token_info_size))

    attention_values = self.dropout(attention_values, deterministic=not training)

    return attention_values # (T, token_info_size)


class MultiHeadAttention(nn.Module):
  num_heads: int
  final_token_info_size: int # After concatenating from all heads, how much info (values -- emb size) you have on each token.
  T: int

  def setup(self):
    self.token_info_size_per_head = int(self.final_token_info_size/self.num_heads)
    self.heads = [
        Head(token_info_size=self.token_info_size_per_head, T=self.T) for _ in range(self.num_heads)
    ]

    self.projection = nn.Dense(features=self.final_token_info_size)

  def __call__(self, block_of_tokens_with_info_channels: jnp.array, training: bool):
    out_from_each_head = jnp.array([h(block_of_tokens_with_info_channels, training) for h in self.heads])

    # You just run multiple attention heads in parallel and concatenate
    # their output along channel dimension, i.e., dim==-1
    out_from_all_heads = jnp.concatenate(out_from_each_head, axis=-1)
    # print("[ntn99] out_from_all_heads concatenated shape: ", out_from_all_heads.shape)

    return self.projection(out_from_all_heads)


class Block(nn.Module):
  num_heads: int
  final_token_info_size: int
  T: int

  def setup(self):
    # communication.
    self.self_attention_heads = MultiHeadAttention(num_heads=self.num_heads,
                                                   final_token_info_size=self.final_token_info_size,
                                                   T=self.T)

    # computation.
    self.computation_layer = FeedForward(output_size=self.final_token_info_size)

    self.ln1 = nn.LayerNorm()
    self.ln2 = nn.LayerNorm()

  def __call__(self, x, training: bool):
    x = x + self.self_attention_heads(self.ln1(x), training)
    # print("[ntn99] input size after attention_head: ", x.shape)

    x = x + self.computation_layer(self.ln2(x), training)
    # print("[ntn99] input size after computation (end of block): ", x.shape)
    return x


class LanguageModel(nn.Module):
  """Reads one char and predicits the next char."""
  vocab_size: int # number of vocabulary (number of rows of embedding table)
  n_embed: int # embedding dim after lookup
  T: int # block size, i.e., number of tokens attention block is looking at once

  def setup(self):
    # number of channels you want to use for store info for each token.
    self.C = self.vocab_size

    self.token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.n_embed)

    self.pos_embedding_table = nn.Embed(num_embeddings=self.T, features=self.n_embed)

    # *** new ***
    # Since, there are 4 heads, each head only needs to output token_info of size 8.
    # Concantenate token_info from all 4 heards, gives us 32
    # self.self_attention_heads = MultiHeadAttention(num_heads=4, final_token_info_size=self.n_embed, T=self.T)
    self.num_blocks = 4
    self.blocks = [
        Block(num_heads=4, final_token_info_size=self.n_embed, T=self.T) for _ in range(self.num_blocks)
    ]
    self.ln = nn.LayerNorm()
    self.lang_model_head = nn.Dense(features=self.C)

  def __call__(self, block_of_tokens: jnp.array, training: bool):
    """Accepts a block of tokens, like [0, 1, 2, 3, 4, 5, 6, 7]."""

    # generate em for each token. output: (T, n_embed)
    token_embs = self.token_embedding_table(block_of_tokens)

    # generate position embs for each token.
    ## get token positions.

    # TODO(ntnsonti): setting num_pos to T always
    # num_pos = block_of_tokens.shape[0]
    num_pos = T
    positions = jnp.arange(0, num_pos)
    pos_embs = self.pos_embedding_table(positions)

    # generate actual input to attention, x, which is sum of token_embs + pos_embs
    x = token_embs + pos_embs

    # feed x into self-attention head.
    # x = self.self_attention_heads(x)
    # x = self.blocks(x)(training)
    for i in range(self.num_blocks):
      x = self.blocks[i](x, training)

    x = self.ln(x)

    # generate logits for each token. output: (T, channels for info -- C)
    token_logits = self.lang_model_head(x)

    return token_logits


## training loop

In [57]:
class TrainState(train_state.TrainState):
  key: jax.random.KeyArray

random_key = jax.random.PRNGKey(99)
random_key, random_subkey = jax.random.split(random_key)

model = LanguageModel(vocab_size=65, n_embed=32, T=BLOCK_SIZE)

# Now, our language model needs to accept a block of tokens, not one-char at a time.
# We'll then make it accept a batch of blocks of tokens using vmap.
sample_block_of_tokens = jnp.ones(shape=(T), dtype=jnp.int32)
output, params = model.init_with_output(jrand.PRNGKey(99), sample_block_of_tokens, training=False)
params = params["params"]


For more information, see https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html
  key: jax.random.KeyArray


Array([1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)

In [68]:
model_apply_batch = jax.vmap(model.apply, in_axes=(None, 0, None), out_axes=(0))

# *** new ***: Fuck, jax.jit makes it so much faster even on GPU.
# model_apply_batch = jax.vmap(model.apply, in_axes=(None, 0), out_axes=(0))

def forward_pass(params, state, batch, dropout_key):
  inputs, targets = batch
  logits = state.apply_fn({"params": params}, inputs, jnp.tile(jnp.array([True], dtype=jnp.int32), T), rngs={'dropout': dropout_key})
  loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
  loss = loss.mean()
  return loss

grad_fn = jax.value_and_grad(forward_pass, argnums=(0))  # differentiate wrt 0th pos argument.

opt = optax.adam(learning_rate=0.0001)
state = TrainState.create(apply_fn=model_apply_batch, params=params, tx=opt, key=random_key)

for epoch in range(1):
  batch = get_batch()

  random_key, random_subkey = jax.random.split(random_key)
  dropout_key = jax.random.fold_in(key=random_key, data=state.step)

  loss, grads = grad_fn(state.params, state, batch, dropout_key)
  print("loss", loss, "epoch", epoch) if epoch%100==0 else None
  state = state.apply_gradients(grads=grads)

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * one axis had size 4: axis 0 of argument args[0] of type int32[4,8];
  * one axis had size 2: axis 0 of argument rngs['dropout'] of type uint32[2]

## generating

In [None]:
T = BLOCK_SIZE

state_apply_jit = jax.jit(state.apply_fn)

context = jnp.tile(jnp.array([52], dtype=jnp.int32), T)
context = context[None, -T:]
key = jrand.PRNGKey(99)

for _ in range(100):
  next_token_logits = state_apply_jit({"params": state.params}, context[:, -T:])

  key, split_key = jrand.split(key)
  new_token = jax.random.categorical(key, next_token_logits[:, -1, :], axis=-1, shape=(1, 1))

  context = jnp.concatenate([context, new_token], axis=1)


decode(context.tolist()[0], itos)
