<h1>Install dependencies for CPU</h1>

In [24]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-02-03 22:55:38--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.3’


2024-02-03 22:55:38 (33.6 MB/s) - ‘input.txt.3’ saved [1115394/1115394]



In [25]:
# !pip install -q --upgrade pip # To support manylinux2010 wheels.
# !pip install -q --upgrade jax jaxlib # CPU-only
# !pip install -q --upgrade jaxtyping
# !pip install -q --upgrade flax

In [26]:
from typing import List, Dict, Mapping, Tuple

import jax
import jax.numpy as jnp
import jax.random as jrand
import jaxtyping
import flax.linen as nn
from flax.training import train_state  # Useful dataclass to keep train state
import optax
import tensorflow as tf
import pdb

def println(*args):
  for arg in args:
    print(arg)


# Dataset pipeline

In [27]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Create chars vocubulary using all the unique characters in the text.
chars = sorted(list(set(text)))
VOCAB_SIZE = len(chars)

# Create mapping from characters to integers.
stoi = {ch: i for i, ch in enumerate(chars)}

# Create reverse mapping from integers to characters.
itos = {i: ch for i, ch in enumerate(chars)}

# Create encode, decode function.
def encode(s: str, stoi: Mapping[str, int]) -> List[int]:
  return [stoi[c] for c in s]

def decode(tokens: List[int], itos: Mapping[int, str]) -> str:
  return ''.join([itos[i] for i in tokens])

println(encode("hii there", stoi), decode(encode("hii there", stoi), itos))

# Let's now split up the data into train and validation sets.
data = jnp.array(encode(text, stoi), dtype=jnp.int64)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# Below would result in a minibatch size of 32.
BATCH_SIZE = 4 # how many independent sequences will we process in parallel?
BLOCK_SIZE = 8 # what is the maximum context length for predictions?

train_dataset = (tf.data.Dataset.from_tensor_slices(train_data)
                .batch(BLOCK_SIZE+1)
                .map(lambda input: (input[:BLOCK_SIZE], input[1:BLOCK_SIZE+1]),
                     num_parallel_calls=tf.data.AUTOTUNE)
                .batch(BATCH_SIZE)
                .repeat()
                .as_numpy_iterator())
val_dataset = (tf.data.Dataset.from_tensor_slices(val_data)
                .batch(BLOCK_SIZE+1)
                .map(lambda input: (input[:BLOCK_SIZE], input[1:BLOCK_SIZE+1]),
                     num_parallel_calls=tf.data.AUTOTUNE)
                .batch(BATCH_SIZE)
                .repeat()
                .as_numpy_iterator())

def get_batch(training: bool = True):
  if not training:
    val_batch = next(val_dataset)
    return jnp.array(val_batch)

  train_batch = next(train_dataset)
  return jnp.array(train_batch)

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


  data = jnp.array(encode(text, stoi), dtype=jnp.int64)


### Test dataset pipeline

In [28]:
xb, yb = get_batch()
println("inputs", xb, "inputs shape", xb.shape)
println("targets", yb, "targets shape", yb.shape)
for b in range(BATCH_SIZE): # batch dimension
    for t in range(BLOCK_SIZE): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs
[[18 47 56 57 58  1 15 47]
 [47 64 43 52 10  0 14 43]
 [53 56 43  1 61 43  1 54]
 [53 41 43 43 42  1 39 52]]
inputs shape
(4, 8)
targets
[[47 56 57 58  1 15 47 58]
 [64 43 52 10  0 14 43 44]
 [56 43  1 61 43  1 54 56]
 [41 43 43 42  1 39 52 63]]
targets shape
(4, 8)
when input is [18] the target: 47
when input is [18, 47] the target: 56
when input is [18, 47, 56] the target: 57
when input is [18, 47, 56, 57] the target: 58
when input is [18, 47, 56, 57, 58] the target: 1
when input is [18, 47, 56, 57, 58, 1] the target: 15
when input is [18, 47, 56, 57, 58, 1, 15] the target: 47
when input is [18, 47, 56, 57, 58, 1, 15, 47] the target: 58
when input is [47] the target: 64
when input is [47, 64] the target: 43
when input is [47, 64, 43] the target: 52
when input is [47, 64, 43, 52] the target: 10
when input is [47, 64, 43, 52, 10] the target: 0
when input is [47, 64, 43, 52, 10, 0] the target: 14
when input is [47, 64, 43, 52, 10, 0, 14] the target: 43
when input is [47, 64, 43, 

# Implement Bigram Model

In [29]:
class BigramLangModel(nn.Module):
  """Reads one char and predicits the next char."""
  vocab_size: int

  def setup(self):
    super().setup()
    self.token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.vocab_size)

  def __call__(self, inputs):
    # Run block size inputs through embedding lookup.
    # For each char, you get the logit predicted for that char.
    # Then, you use the target token for that input and do a cross_entropy_loss.
    logits = self.token_embedding_table(inputs)
    return logits

## I'll make the flax model accept
`[block size worth tokens, token]`

## I'll then use vmap to make the model accept batches of data.
`[batch dim, block size worth of tokens, token]`

In [30]:
sample_input_row = jnp.ones(shape=[1, 1], dtype=jnp.int32)
sample_input_row

Array([[1]], dtype=int32)

In [31]:
model = BigramLangModel(vocab_size=65)
output, params = model.init_with_output(jrand.PRNGKey(99), sample_input_row)
params = params["params"]

## Make the model accepts batch of data
`[batch, block of tokens, token_ids]`

## To make it accept a batch, you need to use vmap.

In [32]:
model_apply_batch = jax.vmap(model.apply, in_axes=(None, 0), out_axes=(0))

In [33]:
# model_apply_batch will accept
# [batch, block of tokens, token ids]

## Sample forward pass, loss and backward pass.

In [34]:
batch = get_batch()
inputs, targets = batch
println("inputs", inputs, inputs.shape, "targets", targets.shape)

inputs
[[ 1 44 59 56 58 46 43 56]
 [ 1 46 43 39 56  1 51 43]
 [57 54 43 39 49  8  0  0]
 [50 50 10  0 31 54 43 39]]
(4, 8)
targets
(4, 8)


In [35]:
output = model_apply_batch({"params": params}, inputs)

In [36]:
# output should shape [4, 8, 65]
# batch size = 4
# block of tokens = 8
# token_id to embedding = 65

output.shape

(4, 8, 65)

In [37]:
# To do backward pass, you first need to compute grads.
# In JAX, you use jax.grad to do a function transformation on the forward
# function to get the gradient of the original function.
# The grad is calculate wrt to the first param in the function.
def forward_pass(params, batch):
  inputs, targets = batch
  logits = model_apply_batch({"params": params}, inputs)
  loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
  loss = loss.mean()
  return loss

In [38]:
# Test forward pass.
batch = get_batch()
forward_pass(params=params, batch=batch)

Array(4.188369, dtype=float32)

In [39]:
grad_fn = jax.value_and_grad(forward_pass, argnums=(0))  # differentiate wrt 0th pos argument.

In [40]:
# Test forward pass and grads.
# Grads would be the gradients for params.
loss, grads = grad_fn(params, batch)
println("loss", loss, "grads", grads)

loss
4.188369
grads
{'token_embedding_table': {'embedding': Array([[0.00103511, 0.00097044, 0.00093559, ..., 0.00089521, 0.001064  ,
        0.00093856],
       [0.00203619, 0.00209828, 0.00198735, ..., 0.00196691, 0.00215726,
        0.00207174],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.00040068, 0.00039527, 0.00040171, ..., 0.00042136, 0.00055311,
        0.00042526]], dtype=float32)}}


In [41]:
# Apply grads to params to get new params.
lr = 0.001
println("params before:", params)
params = jax.tree_map(lambda p, g: p - lr * g, params, grads)
println("params after:", params)

params before:
{'token_embedding_table': {'embedding': Array([[ 0.0752212 ,  0.01071652, -0.02585994, ..., -0.06997449,
         0.10274917, -0.0226865 ],
       [ 0.09400459,  0.12404279,  0.06972364, ...,  0.0593865 ,
         0.1517611 ,  0.11131445],
       [-0.0302137 , -0.07326671, -0.2515272 , ...,  0.20769818,
         0.01281604,  0.03134193],
       ...,
       [-0.1394756 , -0.00640967, -0.07666602, ..., -0.2944119 ,
         0.1187517 , -0.08573762],
       [ 0.05703759, -0.11280773,  0.2570641 , ..., -0.02059634,
        -0.02818088,  0.13305528],
       [-0.12428083, -0.13785616, -0.12170236, ..., -0.07394623,
         0.19811267, -0.06473607]], dtype=float32)}}
params after:
{'token_embedding_table': {'embedding': Array([[ 0.07522016,  0.01071555, -0.02586087, ..., -0.06997538,
         0.1027481 , -0.02268744],
       [ 0.09400256,  0.12404069,  0.06972165, ...,  0.05938453,
         0.15175894,  0.11131237],
       [-0.0302137 , -0.07326671, -0.2515272 , ...,  0.207698

# Writing train step in flax
## copy-pasting everything at one place and running a train step.

In [42]:
class BigramLangModel(nn.Module):
  """Reads one char and predicits the next char."""
  vocab_size: int

  def setup(self):
    super().setup()
    self.token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.vocab_size)

  def __call__(self, inputs):
    # Run block size inputs through embedding lookup.
    # For each char, you get the logit predicted for that char.
    # Then, you use the target token for that input and do a cross_entropy_loss.
    logits = self.token_embedding_table(inputs)
    return logits

model = BigramLangModel(vocab_size=65)

sample_input_row = jnp.ones(shape=[1, 1], dtype=jnp.int32)
output, params = model.init_with_output(jrand.PRNGKey(99), sample_input_row)
params = params["params"]

model_apply_batch = jax.vmap(model.apply, in_axes=(None, 0), out_axes=(0))

def forward_pass(params, state, batch):
  inputs, targets = batch
  logits = state.apply_fn({"params": params}, inputs)
  loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
  loss = loss.mean()
  return loss

grad_fn = jax.value_and_grad(forward_pass, argnums=(0))  # differentiate wrt 0th pos argument.

opt = optax.adam(learning_rate=0.001)
state = train_state.TrainState.create(apply_fn=model_apply_batch, params=params, tx=opt)

for epoch in range(1000):
  batch = get_batch()
  loss, grads = grad_fn(state.params, state, batch)
  print(loss) if epoch%100==0 else None
  state = state.apply_gradients(grads=grads)

4.1769814
4.093553
4.0517178
3.9815784
3.955411
3.8293285
3.7160275
3.7963347
3.7700095
3.5433385


# Implement code for generating tokens

In [43]:
input_token = "n"
encode(input_token, stoi)


[52]

In [44]:
input_token = jnp.array([[52]], dtype=jnp.int32)
input_token.shape

(1, 1)

In [45]:
next_token_logit = state.apply_fn({"params": state.params}, input_token)

In [46]:
next_token_logit

Array([[[-0.03682718,  0.38738227, -0.6297037 , -0.4940241 ,
         -0.5302313 , -0.06883954,  0.02130109, -0.38699383,
          0.05293135, -0.67822653,  0.04108687, -0.17033677,
         -0.21989109, -0.6640769 , -0.5178356 , -0.8944765 ,
         -0.47879392, -0.64374334, -0.83251226, -0.8247406 ,
         -0.7639986 , -0.6349938 , -0.58913904, -0.71123016,
         -0.6316706 , -0.8322299 , -0.61255544, -0.8065904 ,
         -0.6916736 , -0.7578409 , -0.77268964, -0.66586775,
         -0.48679668, -0.6974697 , -0.5410919 , -0.86016047,
         -0.7141595 , -0.87438166, -0.67868704,  0.03718114,
         -0.9563431 , -0.07171286,  0.3103663 ,  0.27640587,
         -0.15164396,  0.25430372, -0.8608676 , -0.11117504,
         -0.5252868 , -0.07216536,  0.02316949, -0.9069602 ,
         -0.26114473,  0.19493027, -0.98034286, -0.70663923,
         -0.25509965,  0.25409326,  0.1060916 , -0.11378516,
         -0.35723025, -0.7799293 , -0.53064156, -0.42475328,
         -0.6610865 ]]],

In [47]:
init_key = jrand.PRNGKey(99)
key, split_key = jrand.split(init_key)

In [48]:
next_to_next_token = jrand.categorical(split_key, next_token_logit)
next_to_next_token

Array([[25]], dtype=int32)

In [49]:
decode(next_to_next_token.tolist()[0], itos)

'M'

In [50]:
next_to_next_to_next_logit = state.apply_fn({"params": state.params}, next_to_next_token)

In [51]:
key, split_key = jrand.split(key)

In [52]:
next_to_next_to_next_token = jrand.categorical(split_key, next_token_logit)

next_to_next_to_next_token

Array([[59]], dtype=int32)

In [53]:
decode(next_to_next_to_next_token.tolist()[0], itos)

'u'

In [54]:
# Putting together the generate code

input_token = jnp.array([[52]], dtype=jnp.int32)
key = jrand.PRNGKey(99)

result = ""
for i in range(100):
  key, split_key = jrand.split(key)
  next_token_logit = state.apply_fn({"params": state.params}, input_token)
  next_token = jrand.categorical(split_key, next_token_logit)
  next_token_decode = decode(next_token.tolist()[0], itos)
  result = result + next_token_decode

print(result)



MuTqF
d$oMJ
 CiSOzjIftBqertiG,3gdghx,,.V,d .zbfa'$fXoeyu'l!m:oaBBQRcrEttkQm3u-r.v3LgdMVfxsx-;ga!kcPW


# Mathematical trick in attention

## doing bag of words
basically, in B, T, C
at t-th token in a row of batch, just sum all the values upto t.

In [55]:
# Generate T, C and write code which works for T, C.
# Then, vmap it for batch

# Using tokens = 4
# Using each token with channel = 2 to make it easy to visualize
T, C = 4, 2

key, split_key = jrand.split(jrand.PRNGKey(99))

x = jrand.normal(split_key, (T, C))
x

Array([[ 0.2628779 , -0.1837252 ],
       [ 0.38331428, -0.16180514],
       [ 1.4986674 ,  1.10728   ],
       [ 1.1535788 ,  0.9676542 ]], dtype=float32)

### version 1: using for loop

In [56]:
def bow_attention(x: jnp.array, T: int, C: int):
  """Operates on a single row within batchs

    It calculates bow attention by summing all
    token channels prev + current token channels.
  """
  xbow = jnp.zeros(shape=(T, C))
  for token in range(T):
    xprev = x[:token]
    xcurrent = x[token:token+1]

    current_bow = jnp.mean(jnp.concatenate([xprev, xcurrent], axis=0), axis=0)
    xbow = xbow.at[token].set(current_bow)
  return xbow

In [57]:
bow_attention(x, T, C)

Array([[ 0.2628779 , -0.1837252 ],
       [ 0.3230961 , -0.17276517],
       [ 0.7149532 ,  0.25391656],
       [ 0.8246096 ,  0.432351  ]], dtype=float32)

In [58]:
# Test bow_attention using non-random tensor.
test_numbers = jnp.arange(1, 5).reshape(-1, 1)
test_arr = jnp.tile(test_numbers, (1, C))
bow_attention(test_arr, T, C)

Array([[1. , 1. ],
       [1.5, 1.5],
       [2. , 2. ],
       [2.5, 2.5]], dtype=float32)

### version 2: Starting to write it using matmul

In [59]:
x

Array([[ 0.2628779 , -0.1837252 ],
       [ 0.38331428, -0.16180514],
       [ 1.4986674 ,  1.10728   ],
       [ 1.1535788 ,  0.9676542 ]], dtype=float32)

In [60]:
wei = jnp.array([[1.0, 0., 0., 0.],
 [0.5, 0.5, 0., 0.],
 [0.333, 0.333, 0.333, 0.],
 [0.25, 0.25, 0.25, 0.25]])

In [61]:
tril = jnp.tril(jnp.ones(shape=(T, T)))
tril

Array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]], dtype=float32)

In [62]:
jnp.sum(tril, axis=1, keepdims=True)

Array([[1.],
       [2.],
       [3.],
       [4.]], dtype=float32)

In [63]:
def get_wei(T: int):
  tril = jnp.tril(jnp.ones(shape=(T, T)))
  return tril/jnp.sum(tril, axis=1, keepdims=True)

In [64]:
get_wei(T)

Array([[1.        , 0.        , 0.        , 0.        ],
       [0.5       , 0.5       , 0.        , 0.        ],
       [0.33333334, 0.33333334, 0.33333334, 0.        ],
       [0.25      , 0.25      , 0.25      , 0.25      ]], dtype=float32)

In [65]:
# Putting together the bow attention calculation using matmul
def bow_attention_matmul(x: jnp.array, T: int, C: int):
  tril = jnp.tril(jnp.ones(shape=(T, T)))
  wei = tril/jnp.sum(tril, axis=1, keepdims=True)

  return jnp.dot(wei, x)


In [66]:
bow_attention_matmul(x, T, C)

Array([[ 0.2628779 , -0.1837252 ],
       [ 0.3230961 , -0.17276517],
       [ 0.7149532 ,  0.25391656],
       [ 0.8246096 ,  0.432351  ]], dtype=float32)

### version 3: use softmax to generate wei matrix
so that it can be learnable?

In [67]:
tril = jnp.tril(jnp.ones(shape=(T, T)))
tril

Array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]], dtype=float32)

In [68]:
# you start wei as all zeros
wei = jnp.zeros(shape=(T, T))
wei

Array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)

In [69]:
# but now, we modify wei such that whenever tril==0, we put -inf into wei
wei = jnp.where(tril==0, -jnp.inf, wei)
wei

Array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]], dtype=float32)

In [70]:
# next we take softmax along row, that is dim==-1
wei = nn.softmax(wei, axis=-1)
wei

Array([[1.        , 0.        , 0.        , 0.        ],
       [0.5       , 0.5       , 0.        , 0.        ],
       [0.33333334, 0.33333334, 0.33333334, 0.        ],
       [0.25      , 0.25      , 0.25      , 0.25      ]], dtype=float32)

In [71]:
def calc_attention(x: jnp.array, T:int, C:int):
  """Calculates attention for a row of tokens."""
  tril = jnp.tril(jnp.ones(shape=(T, T)))
  wei = jnp.zeros(shape=(T, T))
  wei = jnp.where(tril==0, -jnp.inf, wei)
  wei = nn.softmax(wei, axis=-1)

  return jnp.dot(wei, x)

In [72]:
calc_attention(x, T, C)

Array([[ 0.2628779 , -0.1837252 ],
       [ 0.3230961 , -0.17276517],
       [ 0.7149532 ,  0.25391656],
       [ 0.8246096 ,  0.432351  ]], dtype=float32)

In [73]:
calc_attention(test_arr, T, C)

Array([[1. , 1. ],
       [1.5, 1.5],
       [2. , 2. ],
       [2.5, 2.5]], dtype=float32)

In [74]:
calc_attention_batch = jax.vmap(calc_attention, in_axes=(0, None, None), out_axes=(0))

In [75]:
T, C = 8, 65
test_numbers = jnp.arange(1, T+1).reshape(-1, 1)
test_arr = jnp.tile(test_numbers, (1, C))

# add batch dimension to test_arr
test_arr_batch = test_arr[None, :]

In [76]:
# Test calc_attention_batch using get_batch
calc_attention_batch(test_arr_batch, T, C)

Array([[[1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ,
         1.       , 1.       , 1.       , 1.       , 1.       ],
        [1.5      , 1.5      , 1.5      , 1.5      , 1.5      ,
         1.5      , 1.5      , 1.5      , 1.5      , 1.5      ,
         1.5      , 1.5      , 1.5     

# Putting together new Bigram model

In [77]:
class BigramLangModel(nn.Module):
  """Reads one char and predicits the next char."""
  vocab_size: int # number of vocabulary (number of rows of embedding table)
  n_embed: int # embedding dim after lookup

  def setup(self):
    super().setup()
    # number of channels you want to use for store info for each token.
    self.C = self.vocab_size

    self.token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.n_embed)

    self.lang_model_head = nn.Dense(features=self.C)

  def __call__(self, block_of_tokens: jnp.array):
    """Accepts a block of tokens."""

    # generate em for each token. output: (T, n_embed)
    token_embs = self.token_embedding_table(block_of_tokens)

    # generate logits for each token. output: (T, channels for info -- C)
    token_logits = self.lang_model_head(token_embs)

## Add positional embeddings to the above.

In [78]:
block_of_tokens_example = jnp.ones(shape=(1, 8))
block_of_tokens_example, block_of_tokens_example.shape, block_of_tokens_example.shape[1]

(Array([[1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32), (1, 8), 8)

In [79]:
num_pos = block_of_tokens_example.shape[1]
num_pos

8

In [80]:
jnp.arange(0, num_pos)

Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)

In [81]:
class LanguageModel(nn.Module):
  """Reads one char and predicits the next char."""
  vocab_size: int # number of vocabulary (number of rows of embedding table)
  n_embed: int # embedding dim after lookup

  block_size: int # T, i.e., number of tokens attention block is looking at once

  def setup(self):
    super().setup()
    # number of channels you want to use for store info for each token.
    self.C = self.vocab_size

    self.token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.n_embed)

    self.pos_embedding_table = nn.Embed(num_embeddings=self.block_size, features=self.n_embed)

    self.lang_model_head = nn.Dense(features=self.C)

  def __call__(self, block_of_tokens: jnp.array):
    """Accepts a block of tokens, like [0, 1, 2, 3, 4, 5, 6, 7]."""

    # generate em for each token. output: (T, n_embed)
    token_embs = self.token_embedding_table(block_of_tokens)

    # generate position embs for each token.
    ## get token positions.
    num_pos = block_of_tokens.shape[0]
    positions = jnp.arange(0, num_pos)
    pos_embs = self.pos_embedding_table(positions)

    # generate actual input to attention, x, which is sum of token_embs + pos_embs
    x = token_embs + pos_embs

    # generate logits for each token. output: (T, channels for info -- C)
    token_logits = self.lang_model_head(x)


## Coding single head attention.

In [82]:
# (copy-pasting from top)
# Here, wei has uniform attention scores to previous tokens.
# That is token in Tth position, is assuming that
# each previous token has same amount of info.

# But we want each Tth token to learn what to pay attention to.

# So, we have each token emit "keys" -- info I have
# Each token will emit "query" -- what I'm looking for
# wei becomes the dot prodct of "keys" and "query" -- higher the dot product higher the match between
# what I'm look for and what some previous token has.
def calc_attention(x: jnp.array, T:int, C:int):
  """Calculates attention for a row of tokens."""
  tril = jnp.tril(jnp.ones(shape=(T, T)))
  wei = jnp.zeros(shape=(T, T))
  wei = jnp.where(tril==0, -jnp.inf, wei)
  wei = nn.softmax(wei, axis=-1)

  return jnp.dot(wei, x)

In [83]:
token_info_size = 16 # head_size, each token produces vector of this size for key, query

# key, query will take vector of size C.
# i.e., channels containing info of token and will output token_info_size
key_layer = nn.Dense(token_info_size, use_bias=False)

query_layer = nn.Dense(token_info_size, use_bias=False)

In [84]:
# (tokens, channel info for each)
# (T, C)

# for easy visualization, T=4, C=2
T=4; C=2
x = jrand.normal(jrand.PRNGKey(999), shape=(T, C))
x

Array([[ 0.27297866, -0.6993713 ],
       [ 0.428855  , -1.5621939 ],
       [-0.05503325,  0.18392533],
       [-0.18410844,  0.53945136]], dtype=float32)

In [85]:
prng = jrand.PRNGKey(9999)
key, split_key = jrand.split(prng)

In [86]:
# keys emitted by each token.
kparams = key_layer.init(split_key, x)["params"]
keys = key_layer.apply({"params": kparams}, x)

# queries emitted by each token
# NOTE: each token parallely and indpendently emits its "keys" and "queries"
qparams = query_layer.init(split_key, x)["params"]
queries = query_layer.apply({"params": qparams}, x)

keys.shape, queries.shape # each are (T, 16)

((4, 16), (4, 16))

In [87]:
# NOW, wei becomes this dot product between keys and querys
wei = jnp.dot(queries, keys.T)
wei

Array([[ 6.0750933, 12.761856 , -1.5228109, -4.5677843],
       [12.761856 , 27.052605 , -3.221544 , -9.631147 ],
       [-1.5228109, -3.221544 ,  0.3838081,  1.1482861],
       [-4.5677843, -9.631147 ,  1.1482861,  3.4396741]], dtype=float32)

In [88]:
  tril = jnp.tril(jnp.ones(shape=(T, T)))

  # Don't initialize wei as zeros
  # wei = jnp.zeros(shape=(T, T))
  wei = jnp.where(tril==0, -jnp.inf, wei)
  wei = nn.softmax(wei, axis=-1)
  wei

Array([[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
       [6.2173666e-07, 9.9999940e-01, 0.0000000e+00, 0.0000000e+00],
       [1.2637094e-01, 2.3115156e-02, 8.5051382e-01, 0.0000000e+00],
       [3.0229829e-04, 1.9118115e-06, 9.1810778e-02, 9.0788496e-01]],      dtype=float32)

In [89]:
# each token also produces "value", which is what we would multiply with wei.
# so, wei is attention score.
# whenever the attention score is high, we want to take its value.

### Combing things from above and adding value layer as well.

In [90]:
# (tokens, channel info for each)
# (T, C)

# for easy visualization, T=4, C=2
T=4; C=2
x = jrand.normal(jrand.PRNGKey(999), shape=(T, C))

key, split_key = jrand.split(jrand.PRNGKey(9999))

token_info_size = 16 # head_size

key_layer = nn.Dense(token_info_size, use_bias=False)
query_layer = nn.Dense(token_info_size, use_bias=False)
value_layer = nn.Dense(token_info_size, use_bias=False)


keys = key_layer.apply(key_layer.init(split_key, x), x) # (T, token_info_size)
queries = query_layer.apply(query_layer.init(split_key, x), x)
values = value_layer.apply(value_layer.init(split_key, x), x) # (T, token_info_size)

tril = jnp.tril(jnp.ones(shape=(T, T)))

wei = jnp.dot(queries, keys.T) # (T, T)
wei = jnp.where(tril==0, -jnp.inf, wei)
wei = nn.softmax(wei, axis=-1)


out = jnp.dot(wei, values) # (T, T) * (T, token_info_size)

# shape should be (T, token_info_size)
# i.e., (4, 16)
out.shape

(4, 16)

## self-attention vs cross-attention: https://youtu.be/kCc8FmEb1nY?t=4542

### Scaled attention -- dividing wei*value by squared root of head_size https://youtu.be/kCc8FmEb1nY?t=4638

In [91]:
wei = jnp.dot(queries, keys.T) * C**0.5 # (T, T)
wei = jnp.where(tril==0, -jnp.inf, wei)
wei = nn.softmax(wei, axis=-1)


out = jnp.dot(wei, values) # (T, T) * (T, token_info_size)
out

Array([[ 0.15091053,  1.1747676 , -0.27543876,  0.6276668 , -0.6659185 ,
        -0.7521053 ,  0.15832554,  0.61729145,  0.5173192 , -1.0826657 ,
         0.4595548 ,  0.38114175,  0.36432883, -0.12410479, -0.28500274,
         0.8726827 ],
       [ 0.15494807,  2.420155  , -0.75046575,  1.4276459 , -1.2494795 ,
        -1.4867611 ,  0.18499118,  1.4277841 ,  1.0585895 , -2.3402715 ,
         0.8965928 ,  0.6221145 ,  1.0077578 , -0.24458581, -0.6518733 ,
         1.8538882 ],
       [-0.0108901 , -0.18261386,  0.05756752, -0.1084154 ,  0.09365092,
         0.11186212, -0.01323292, -0.10853641, -0.07983959,  0.1771509 ,
        -0.06739505, -0.04610366, -0.07736284,  0.01839836,  0.04952013,
        -0.14017412],
       [-0.08724618, -0.85421604,  0.22667341, -0.47580495,  0.46656573,
         0.5378475 , -0.09476726, -0.47136265, -0.37513095,  0.8030861 ,
        -0.32692865, -0.2536266 , -0.30200323,  0.08864281,  0.21657023,
        -0.642643  ]], dtype=float32)

# Implement self-attention head.
(copy-pasting from above code mostly into Flax module.)

In [92]:
class Head(nn.Module):
  token_info_size: int # head_size; how much (emb dim) info each token emits for keys, queries, values.

  T: int # block size; number of tokens in a block
  C: int # channel info size: size of info channel of each token.


  def setup(self):
    super().setup()

    # key, query will take vector of size C.
    # i.e., channels containing info of token and will output token_info_size
    self.key_layer = nn.Dense(self.token_info_size, use_bias=False)
    self.query_layer = nn.Dense(self.token_info_size, use_bias=False)
    self.value_layer = nn.Dense(self.token_info_size, use_bias=False)


  def __call__(self, block_of_tokens_with_info_channels: jnp.array):
    """Accepts a block of tokens with info channels, like (8, 65)."""

    # TODO(ntnsonti): Double check; but tril should not be learnable according cGPT.
    tril = jnp.tril(jnp.ones(shape=(self.T, self.T)))

    keys = self.key_layer(block_of_tokens_with_info_channels) # (T, token_info_size)
    queries = self.query_layer(block_of_tokens_with_info_channels)
    values = self.value_layer(block_of_tokens_with_info_channels)

    # compute attention score.
    wei = jnp.dot(queries, keys.T) * C**0.5 # (T, T)
    wei = jnp.where(tril==0, -jnp.inf, wei)
    wei = nn.softmax(wei, axis=-1)


    out = jnp.dot(wei, values) # (T, T) * (T, token_info_size))
    return out # (T, token_info_size)


https://youtu.be/kCc8FmEb1nY?t=4819

In [93]:
class LanguageModel(nn.Module):
  """Reads one char and predicits the next char."""
  vocab_size: int # number of vocabulary (number of rows of embedding table)
  n_embed: int # embedding dim after lookup

  T: int # block size, i.e., number of tokens attention block is looking at once

  def setup(self):
    super().setup()
    # number of channels you want to use for store info for each token.
    self.C = self.vocab_size

    self.token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.n_embed)

    self.pos_embedding_table = nn.Embed(num_embeddings=self.T, features=self.n_embed)

    self.self_attention_head = Head(token_info_size=self.n_embed, T=self.T, C=self.C)

    self.lang_model_head = nn.Dense(features=self.C)

  def __call__(self, block_of_tokens: jnp.array):
    """Accepts a block of tokens, like [0, 1, 2, 3, 4, 5, 6, 7]."""

    # generate em for each token. output: (T, n_embed)
    token_embs = self.token_embedding_table(block_of_tokens)

    # generate position embs for each token.
    ## get token positions.
    num_pos = block_of_tokens.shape[0]
    positions = jnp.arange(0, num_pos)
    pos_embs = self.pos_embedding_table(positions)

    # generate actual input to attention, x, which is sum of token_embs + pos_embs
    x = token_embs + pos_embs

    # feed x into self-attention head.
    x = self.self_attention_head(x)

    # generate logits for each token. output: (T, channels for info -- C)
    token_logits = self.lang_model_head(x)

    return token_logits


## Traing the network.

In [94]:
BLOCK_SIZE

8

In [95]:
T = 8

In [341]:
model = LanguageModel(vocab_size=65, n_embed=32, T=BLOCK_SIZE)

# Now, our language model needs to accept a block of tokens, not one-char at a time.
# We'll then make it accept a batch of blocks of tokens using vmap.
sample_block_of_tokens = jnp.ones(shape=(T), dtype=jnp.int32)
output, params = model.init_with_output(jrand.PRNGKey(99), sample_block_of_tokens)
params = params["params"]

model_apply_batch = jax.vmap(model.apply, in_axes=(None, 0), out_axes=(0))

def forward_pass(params, state, batch):
  inputs, targets = batch
  logits = state.apply_fn({"params": params}, inputs)
  loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
  loss = loss.mean()
  return loss

grad_fn = jax.value_and_grad(forward_pass, argnums=(0))  # differentiate wrt 0th pos argument.

opt = optax.adam(learning_rate=0.001)
state = train_state.TrainState.create(apply_fn=model_apply_batch, params=params, tx=opt)

for epoch in range(1000):
  batch = get_batch()
  loss, grads = grad_fn(state.params, state, batch)
  print(loss) if epoch%100==0 else None
  state = state.apply_gradients(grads=grads)

4.1696367
3.4718578
3.3918285
2.907994
3.101877
2.7360342
2.9926176
2.8234916
2.7058547
2.5178351


## NanoGPT

In [3]:
from dataclasses import dataclass
from functools import partial
import pickle

import jax
import jax.numpy as jnp

import flax.linen as nn
from flax.training import train_state
from flax import serialization

import optax


@dataclass
class Config():
    seed = 42
    num_iterations = 20000
    batch_size = 512
    block_size = 64
    learning_rate = 1e-4
    embed_size = 256
    num_heads = 8
    head_size = 32
    num_layers = 6
    dropout = 0.2

config = Config()

with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()

chars = sorted(list(set(text)))
vocab_size = len(chars)

# create a mapping from characters to integers
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: "".join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Let's now split up the data into train and validation sets
data = jnp.array(encode(text))
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
eval_data = data[n:]

dynamic_slice_vmap = jax.vmap(jax.lax.dynamic_slice, in_axes=(None, 0, None))

@jax.jit
def get_batch(random_key, data):
    # generate a small batch of data of inputs x and targets y
    ix = jax.random.randint(random_key, shape=(config.batch_size, 1), minval=0, maxval=len(data)-config.block_size)
    x = dynamic_slice_vmap(data, ix, (config.block_size,))
    y = dynamic_slice_vmap(data, ix+1, (config.block_size,))
    return x, y

FileNotFoundError: [Errno 2] No such file or directory: 'input.txt'

In [None]:
class LayerNorm(nn.Module):
    epsilon: float = 1e-6
    reduction_axes = -1

    @nn.compact
    def __call__(self, x):
        """Applies layer normalization on the input."""
        # compute statistics
        mean2 = jnp.mean(jax.lax.square(x), self.reduction_axes, keepdims=True)
        mean = jnp.mean(x, self.reduction_axes, keepdims=True)
        var = jnp.maximum(0., mean2 - jax.lax.square(mean))

        # compute normalized inputs
        x_norm = (x - mean) * jax.lax.rsqrt(var + self.epsilon)
        return x_norm * self.param("scale", nn.initializers.ones, x.shape[-1]) + self.param("bias", nn.initializers.zeros, x.shape[-1])

In [None]:
class Attention(nn.Module):
    head_size: int

    @nn.compact
    def __call__(self, x, training: bool):
        key = nn.Dense(self.head_size, use_bias=False)(x)
        query = nn.Dense(self.head_size, use_bias=False)(x)
        value = nn.Dense(self.head_size, use_bias=False)(x)

        tril = jnp.tril(jnp.ones((x.shape[-2], x.shape[-2])))
        attention_weights = nn.softmax(jnp.where(tril == 0, -jnp.inf, query @ jnp.transpose(key, axes=(0, 2, 1))), axis=-1)
        attention_weights = nn.Dropout(config.dropout)(attention_weights, deterministic=not training)
        return attention_weights @ value

class MultiHeadAttention(nn.Module):
    num_heads: int
    head_size: int

    @nn.compact
    def __call__(self, x, training: bool):
        x = jnp.concatenate([Attention(self.head_size)(x, training) for _ in range(self.num_heads)], axis=-1)
        return nn.Dropout(config.dropout)(nn.Dense(self.num_heads*self.head_size)(x), deterministic=not training)

class FeedFoward(nn.Module):

    @nn.compact
    def __call__(self, x, training: bool):
        return nn.Dropout(config.dropout)(nn.Dense(config.embed_size)(nn.relu(nn.Dense(4*config.embed_size)(x))), deterministic=not training)

class Block(nn.Module):
    num_heads: int
    head_size: int

    @nn.compact
    def __call__(self, x, training: bool):
        x = x + MultiHeadAttention(self.num_heads, self.head_size)(LayerNorm()(x), training)
        return x + FeedFoward()(LayerNorm()(x), training)

In [None]:
class Model(nn.Module):
    num_layers: int
    num_heads: int
    head_size: int

    @nn.compact
    def __call__(self, x, training: bool):
        B, T = x.shape
        x = nn.Embed(num_embeddings=vocab_size, features=config.embed_size)(x) + \
            nn.Embed(num_embeddings=config.block_size, features=config.embed_size)(jnp.arange(T))
        for _ in range(self.num_layers):
            x = Block(self.num_heads, self.head_size)(x, training)
        x = nn.LayerNorm()(x)
        return nn.Dense(vocab_size)(x)

    def generate(self, random_key, params, context, length=50):
        for _ in range(length):
            logits = self.apply(params, context[:, -config.block_size:], training=False)
            random_key, random_subkey = jax.random.split(random_key)
            new_token = jax.random.categorical(random_subkey, logits[:, -1, :], axis=-1, shape=(1, 1))
            context = jnp.concatenate([context, new_token], axis=1)
        return context

    @partial(jax.jit, static_argnames=("self", "length"))
    def generate_jit(self, random_key, params, length):
        def scan_generate(carry, x):
            key, context = carry
            logits = self.apply(params, context, training=False)
            random_key, random_subkey = jax.random.split(key)
            new_token = jax.random.categorical(random_subkey, logits[:, -1, :], axis=-1, shape=(1, 1))
            context = jnp.concatenate([context[:, 1:], new_token], axis=1)
            return (random_key, context), new_token

        _, new_tokens = jax.lax.scan(
            scan_generate,
            (random_key, jnp.zeros((1, config.block_size), dtype=jnp.int32)),
            (),
            length=length,
        )
        return new_tokens

In [None]:
class TrainState(train_state.TrainState):
  key: jax.random.KeyArray

def create_train_state(random_key, config):
    model = Model(num_layers=config.num_layers, num_heads=config.num_heads, head_size=config.head_size)
    params = model.init(random_key, jnp.ones((config.batch_size, config.block_size), dtype=jnp.int32), training=False)
    tx = optax.adamw(config.learning_rate)
    return TrainState.create(
        apply_fn=model.apply, params=params, key=random_key, tx=tx)

@jax.jit
def train_step(state, x, y, dropout_key):
    dropout_key = jax.random.fold_in(key=dropout_key, data=state.step)
    def loss_fn(params):
        logits = state.apply_fn(params, x, training=True, rngs={'dropout': dropout_key})
        one_hot_encoded_labels = jax.nn.one_hot(y, num_classes=vocab_size)
        return optax.softmax_cross_entropy(
            logits=logits, labels=one_hot_encoded_labels
        ).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)

    return state, loss

@jax.jit
def eval_step(state, x, y):
    logits = state.apply_fn(state.params, x, training=False)
    one_hot_encoded_labels = jax.nn.one_hot(y, num_classes=vocab_size)
    return optax.softmax_cross_entropy(
        logits=logits, labels=one_hot_encoded_labels
    ).mean()

random_key = jax.random.PRNGKey(config.seed)
random_key, random_subkey = jax.random.split(random_key)

state = create_train_state(random_subkey, config)
for i in range(config.num_iterations):
    random_key, random_subkey = jax.random.split(random_key)
    state, loss = train_step(state, *get_batch(random_subkey, train_data), random_subkey)

    if i % 100 == 0:
        random_key, random_subkey = jax.random.split(random_key)
        print(f"Step: {i}\t train loss: {loss}\t eval loss: {eval_step(state, *get_batch(random_subkey, eval_data))}")

params_state_dict = serialization.to_state_dict(state.params)
with open("./outputs/params.pickle", "wb") as params_file:
    pickle.dump(params_state_dict, params_file)