## Use CPU runtime

In [1]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

!pip install -q --upgrade pip # To support manylinux2010 wheels.
!pip install -q --upgrade jax jaxlib # CPU-only
!pip install -q --upgrade jaxtyping
!pip install -q --upgrade flax

--2024-02-03 04:17:27--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-02-03 04:17:27 (118 MB/s) - ‘input.txt’ saved [1115394/1115394]

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m26.7 MB/s[0m eta [36m0:00:00[0m
[0m

In [2]:
from typing import List, Dict, Mapping, Tuple

import jax
import jax.numpy as jnp
import jax.random as jrand
import jaxtyping
import flax.linen as nn
from flax.training import train_state  # Useful dataclass to keep train state
import optax
import tensorflow as tf
import pdb

def println(*args):
  for arg in args:
    print(arg)


## Dataset pipeline

In [3]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Create chars vocubulary using all the unique characters in the text.
chars = sorted(list(set(text)))
VOCAB_SIZE = len(chars)

# Create mapping from characters to integers.
stoi = {ch: i for i, ch in enumerate(chars)}

# Create reverse mapping from integers to characters.
itos = {i: ch for i, ch in enumerate(chars)}

# Create encode, decode function.
def encode(s: str, stoi: Mapping[str, int]) -> List[int]:
  return [stoi[c] for c in s]

def decode(tokens: List[int], itos: Mapping[int, str]) -> str:
  return ''.join([itos[i] for i in tokens])

println(encode("hii there", stoi), decode(encode("hii there", stoi), itos))

# Let's now split up the data into train and validation sets.
data = jnp.array(encode(text, stoi), dtype=jnp.int64)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# Below would result in a minibatch size of 32.
BATCH_SIZE = 4 # how many independent sequences will we process in parallel?
BLOCK_SIZE = 8 # what is the maximum context length for predictions?

train_dataset = (tf.data.Dataset.from_tensor_slices(train_data)
                .batch(BLOCK_SIZE+1)
                .map(lambda input: (input[:BLOCK_SIZE], input[1:BLOCK_SIZE+1]),
                     num_parallel_calls=tf.data.AUTOTUNE)
                .batch(BATCH_SIZE)
                .repeat()
                .as_numpy_iterator())
val_dataset = (tf.data.Dataset.from_tensor_slices(val_data)
                .batch(BLOCK_SIZE+1)
                .map(lambda input: (input[:BLOCK_SIZE], input[1:BLOCK_SIZE+1]),
                     num_parallel_calls=tf.data.AUTOTUNE)
                .batch(BATCH_SIZE)
                .repeat()
                .as_numpy_iterator())

def get_batch(training: bool = True):
  if not training:
    val_batch = next(val_dataset)
    return jnp.array(val_batch)

  train_batch = next(train_dataset)
  return jnp.array(train_batch)

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


  data = jnp.array(encode(text, stoi), dtype=jnp.int64)


In [4]:
xb, yb = get_batch()
println("inputs", xb, "inputs shape", xb.shape)
println("targets", yb, "targets shape", yb.shape)
for b in range(BATCH_SIZE): # batch dimension
    for t in range(BLOCK_SIZE): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs
[[18 47 56 57 58  1 15 47]
 [47 64 43 52 10  0 14 43]
 [53 56 43  1 61 43  1 54]
 [53 41 43 43 42  1 39 52]]
inputs shape
(4, 8)
targets
[[47 56 57 58  1 15 47 58]
 [64 43 52 10  0 14 43 44]
 [56 43  1 61 43  1 54 56]
 [41 43 43 42  1 39 52 63]]
targets shape
(4, 8)
when input is [18] the target: 47
when input is [18, 47] the target: 56
when input is [18, 47, 56] the target: 57
when input is [18, 47, 56, 57] the target: 58
when input is [18, 47, 56, 57, 58] the target: 1
when input is [18, 47, 56, 57, 58, 1] the target: 15
when input is [18, 47, 56, 57, 58, 1, 15] the target: 47
when input is [18, 47, 56, 57, 58, 1, 15, 47] the target: 58
when input is [47] the target: 64
when input is [47, 64] the target: 43
when input is [47, 64, 43] the target: 52
when input is [47, 64, 43, 52] the target: 10
when input is [47, 64, 43, 52, 10] the target: 0
when input is [47, 64, 43, 52, 10, 0] the target: 14
when input is [47, 64, 43, 52, 10, 0, 14] the target: 43
when input is [47, 64, 43, 

In [35]:
class BigramLangModel(nn.Module):
  """Reads one char and predicits the next char."""
  vocab_size: int

  def setup(self):
    super().setup()
    self.token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.vocab_size)

  def __call__(self, inputs):
    # Run block size inputs through embedding lookup.
    # For each char, you get the logit predicted for that char.
    # Then, you use the target token for that input and do a cross_entropy_loss.
    logits = self.token_embedding_table(inputs)
    return logits

In [6]:
# In flax, you need to init the model with sample input that you would pass during each forward pass.
# sample_input_row = jrand.randint(key=jrand.PRNGKey(99), minval=0, maxval=65, dtype=jnp.int32, shape=[1, BLOCK_SIZE])
# sample_input_row

Array([[21, 23, 50, 28, 53, 55, 40, 29]], dtype=int32)

# I'll make the flax model accept
`[block size worth tokens, token]`

# I'll then use vmap to make the model accept batches of data.
`[batch dim, block size worth of tokens, token]`

In [59]:
sample_input_row = jnp.ones(shape=[1, 1], dtype=jnp.int32)
sample_input_row

Array([[1]], dtype=int32)

In [61]:
model = BigramLangModel(vocab_size=65)
output, params = model.init_with_output(jrand.PRNGKey(99), sample_input_row)
params = params["params"]

# Make the model accepts batch of data
`[batch, block of tokens, token_ids]`

# To make it accept a batch, you need to use vmap.

In [62]:
model_apply_batch = jax.vmap(model.apply, in_axes=(None, 0), out_axes=(0))

In [63]:
# model_apply_batch will accept
# [batch, block of tokens, token ids]

## Sample forward pass, loss and backward pass.

In [64]:
batch = get_batch()
inputs, targets = batch
println("inputs", inputs, inputs.shape, "targets", targets.shape)

inputs
[[43  1 59 52 39 41 46 47]
 [45  1 57 41 39 56 57  1]
 [46 47 41 46  1 21  1 57]
 [53 59 50 42  1 46 47 42]]
(4, 8)
targets
(4, 8)


In [65]:
output = model_apply_batch({"params": params}, inputs)

In [68]:
# output should shape [4, 8, 65]
# batch size = 4
# block of tokens = 8
# token_id to embedding = 65

output.shape

(4, 8, 65)

In [None]:
# Decode a batch containing several input rows.
println(",".join([decode(input_row, itos) for input_row in input_rows.tolist()]))

In [76]:
# To do backward pass, you first need to compute grads.
# In JAX, you use jax.grad to do a function transformation on the forward
# function to get the gradient of the original function.
# The grad is calculate wrt to the first param in the function.
def forward_pass(params, batch):
  inputs, targets = batch
  logits = model_apply_batch({"params": params}, inputs)
  loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
  loss = loss.mean()
  return loss

In [77]:
# Test forward pass.
batch = get_batch()
forward_pass(params=params, batch=batch)

Array(4.179904, dtype=float32)

In [79]:
grad_fn = jax.grad(forward_pass, argnums=(0))  # differentiate wrt 0th pos argument.

In [81]:
# Test forward pass and grads.
# Grads would be the gradients for params.
grads = grad_fn(params, batch)
println(grads)

{'token_embedding_table': {'embedding': Array([[0.00103511, 0.00097044, 0.00093559, ..., 0.00089521, 0.001064  ,
        0.00093856],
       [0.00152714, 0.00157371, 0.00149051, ..., 0.00147518, 0.00161794,
        0.00155381],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ]], dtype=float32)}}


In [82]:
# Apply grads to params to get new params.
lr = 0.001
println("params before:", params)
params = jax.tree_map(lambda p, g: p - lr * g, params, grads)
println("params after:", params)

params before:
{'token_embedding_table': {'embedding': Array([[ 0.0752212 ,  0.01071652, -0.02585994, ..., -0.06997449,
         0.10274917, -0.0226865 ],
       [ 0.09400459,  0.12404279,  0.06972364, ...,  0.0593865 ,
         0.1517611 ,  0.11131446],
       [-0.0302137 , -0.07326671, -0.2515272 , ...,  0.20769818,
         0.01281604,  0.03134193],
       ...,
       [-0.1394756 , -0.00640967, -0.07666602, ..., -0.2944119 ,
         0.11875169, -0.08573762],
       [ 0.05703759, -0.11280773,  0.2570641 , ..., -0.02059634,
        -0.02818088,  0.13305528],
       [-0.12428083, -0.13785616, -0.12170235, ..., -0.07394623,
         0.19811267, -0.06473607]], dtype=float32)}}
params after:
{'token_embedding_table': {'embedding': Array([[ 0.07522016,  0.01071555, -0.02586087, ..., -0.06997538,
         0.1027481 , -0.02268744],
       [ 0.09400307,  0.12404122,  0.06972215, ...,  0.05938502,
         0.15175948,  0.1113129 ],
       [-0.0302137 , -0.07326671, -0.2515272 , ...,  0.207698

## Writing train step in flax

In [18]:
def compute_loss(params, state, batch):
  inputs, targets = batch
  logits = state.apply_fn({"params": params}, inputs)
  # println(logits)
  loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
  loss = loss.mean()
  return loss

In [19]:
grad_fn = jax.value_and_grad(compute_loss, argnums=(0))

In [21]:
params = model.init(jrand.PRNGKey(90), jnp.ones(shape=(1, BLOCK_SIZE), dtype=jnp.int32))["params"]
opt = optax.adam(learning_rate=0.01)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=opt)

for epoch in range(1000):
  batch = get_batch()
  loss, grads = grad_fn(state.params, state, batch)
  print(loss) if epoch%100==0 else None
  state = state.apply_gradients(grads=grads)



4.1893883
3.5785098
3.2410522
2.8703156
2.7978456
2.7443452
2.7087464
2.4864938
2.5336409
2.3312209


In [23]:
## Code for generating tokens

input_token = "n"
encode(input_token, stoi)

[52]

In [26]:
input_token = jnp.array([[52]], dtype=jnp.int32)
input_token.shape

(1, 1)

In [30]:
next_token_logit = state.apply_fn({"params": state.params}, input_token)

In [32]:
next_token_logit.shape

(1, 1, 65)

## Mathematical trick in attention

## doing bag of words
basically, in B, T, C
at t-th token in a row of batch, just sum all the values upto t.

In [None]:
# Generate T, C and write code which works for T, C.
# Then, vmap it for batch
T, C = BLOCK_SIZE, 65

key, split_key = jrand.split(jrand.PRNGKey(99))

x = jrand.normal(split_key, (T, C))
x.shape

(8, 65)

In [None]:
def bow_attention(x: jnp.array):
  """Operates on a single row within batch, it calculates
    bow attention by summing all token channels prev + current token channe

  """
  T = BLOCK_SIZE # T == 8
  xbow = jnp.zeros(shape=(T, C))
  for token in range(T):
    xprev = x[:token]
    xcurrent = x[token:token+1]

    current_bow = jnp.mean(jnp.concatenate([xprev, xcurrent], axis=0), axis=0)
    xbow = xbow.at[token].set(current_bow)
  return xbow

In [None]:
# Test bow_attention using non-random tensor.
numbers = jnp.arange(1, 9).reshape(-1, 1)
test_arr = jnp.tile(numbers, (1, 65))
bow_attention(test_arr)

In [None]:
jax.make_jaxpr(bow_attention)

<function jax.make_jaxpr(bow_attention)(x: <function array at 0x7bdf6128c1f0>)>

## NanoGPT

In [None]:
from dataclasses import dataclass
from functools import partial
import pickle

import jax
import jax.numpy as jnp

import flax.linen as nn
from flax.training import train_state
from flax import serialization

import optax


@dataclass
class Config():
    seed = 42
    num_iterations = 20000
    batch_size = 512
    block_size = 64
    learning_rate = 1e-4
    embed_size = 256
    num_heads = 8
    head_size = 32
    num_layers = 6
    dropout = 0.2

config = Config()

with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()

chars = sorted(list(set(text)))
vocab_size = len(chars)

# create a mapping from characters to integers
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: "".join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Let's now split up the data into train and validation sets
data = jnp.array(encode(text))
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
eval_data = data[n:]

dynamic_slice_vmap = jax.vmap(jax.lax.dynamic_slice, in_axes=(None, 0, None))

@jax.jit
def get_batch(random_key, data):
    # generate a small batch of data of inputs x and targets y
    ix = jax.random.randint(random_key, shape=(config.batch_size, 1), minval=0, maxval=len(data)-config.block_size)
    x = dynamic_slice_vmap(data, ix, (config.block_size,))
    y = dynamic_slice_vmap(data, ix+1, (config.block_size,))
    return x, y

In [None]:
class LayerNorm(nn.Module):
    epsilon: float = 1e-6
    reduction_axes = -1

    @nn.compact
    def __call__(self, x):
        """Applies layer normalization on the input."""
        # compute statistics
        mean2 = jnp.mean(jax.lax.square(x), self.reduction_axes, keepdims=True)
        mean = jnp.mean(x, self.reduction_axes, keepdims=True)
        var = jnp.maximum(0., mean2 - jax.lax.square(mean))

        # compute normalized inputs
        x_norm = (x - mean) * jax.lax.rsqrt(var + self.epsilon)
        return x_norm * self.param("scale", nn.initializers.ones, x.shape[-1]) + self.param("bias", nn.initializers.zeros, x.shape[-1])

In [None]:
class Attention(nn.Module):
    head_size: int

    @nn.compact
    def __call__(self, x, training: bool):
        key = nn.Dense(self.head_size, use_bias=False)(x)
        query = nn.Dense(self.head_size, use_bias=False)(x)
        value = nn.Dense(self.head_size, use_bias=False)(x)

        tril = jnp.tril(jnp.ones((x.shape[-2], x.shape[-2])))
        attention_weights = nn.softmax(jnp.where(tril == 0, -jnp.inf, query @ jnp.transpose(key, axes=(0, 2, 1))), axis=-1)
        attention_weights = nn.Dropout(config.dropout)(attention_weights, deterministic=not training)
        return attention_weights @ value

class MultiHeadAttention(nn.Module):
    num_heads: int
    head_size: int

    @nn.compact
    def __call__(self, x, training: bool):
        x = jnp.concatenate([Attention(self.head_size)(x, training) for _ in range(self.num_heads)], axis=-1)
        return nn.Dropout(config.dropout)(nn.Dense(self.num_heads*self.head_size)(x), deterministic=not training)

class FeedFoward(nn.Module):

    @nn.compact
    def __call__(self, x, training: bool):
        return nn.Dropout(config.dropout)(nn.Dense(config.embed_size)(nn.relu(nn.Dense(4*config.embed_size)(x))), deterministic=not training)

class Block(nn.Module):
    num_heads: int
    head_size: int

    @nn.compact
    def __call__(self, x, training: bool):
        x = x + MultiHeadAttention(self.num_heads, self.head_size)(LayerNorm()(x), training)
        return x + FeedFoward()(LayerNorm()(x), training)

In [None]:
class Model(nn.Module):
    num_layers: int
    num_heads: int
    head_size: int

    @nn.compact
    def __call__(self, x, training: bool):
        B, T = x.shape
        x = nn.Embed(num_embeddings=vocab_size, features=config.embed_size)(x) + \
            nn.Embed(num_embeddings=config.block_size, features=config.embed_size)(jnp.arange(T))
        for _ in range(self.num_layers):
            x = Block(self.num_heads, self.head_size)(x, training)
        x = nn.LayerNorm()(x)
        return nn.Dense(vocab_size)(x)

    def generate(self, random_key, params, context, length=50):
        for _ in range(length):
            logits = self.apply(params, context[:, -config.block_size:], training=False)
            random_key, random_subkey = jax.random.split(random_key)
            new_token = jax.random.categorical(random_subkey, logits[:, -1, :], axis=-1, shape=(1, 1))
            context = jnp.concatenate([context, new_token], axis=1)
        return context

    @partial(jax.jit, static_argnames=("self", "length"))
    def generate_jit(self, random_key, params, length):
        def scan_generate(carry, x):
            key, context = carry
            logits = self.apply(params, context, training=False)
            random_key, random_subkey = jax.random.split(key)
            new_token = jax.random.categorical(random_subkey, logits[:, -1, :], axis=-1, shape=(1, 1))
            context = jnp.concatenate([context[:, 1:], new_token], axis=1)
            return (random_key, context), new_token

        _, new_tokens = jax.lax.scan(
            scan_generate,
            (random_key, jnp.zeros((1, config.block_size), dtype=jnp.int32)),
            (),
            length=length,
        )
        return new_tokens

In [None]:
class TrainState(train_state.TrainState):
  key: jax.random.KeyArray

def create_train_state(random_key, config):
    model = Model(num_layers=config.num_layers, num_heads=config.num_heads, head_size=config.head_size)
    params = model.init(random_key, jnp.ones((config.batch_size, config.block_size), dtype=jnp.int32), training=False)
    tx = optax.adamw(config.learning_rate)
    return TrainState.create(
        apply_fn=model.apply, params=params, key=random_key, tx=tx)

@jax.jit
def train_step(state, x, y, dropout_key):
    dropout_key = jax.random.fold_in(key=dropout_key, data=state.step)
    def loss_fn(params):
        logits = state.apply_fn(params, x, training=True, rngs={'dropout': dropout_key})
        one_hot_encoded_labels = jax.nn.one_hot(y, num_classes=vocab_size)
        return optax.softmax_cross_entropy(
            logits=logits, labels=one_hot_encoded_labels
        ).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)

    return state, loss

@jax.jit
def eval_step(state, x, y):
    logits = state.apply_fn(state.params, x, training=False)
    one_hot_encoded_labels = jax.nn.one_hot(y, num_classes=vocab_size)
    return optax.softmax_cross_entropy(
        logits=logits, labels=one_hot_encoded_labels
    ).mean()

random_key = jax.random.PRNGKey(config.seed)
random_key, random_subkey = jax.random.split(random_key)

state = create_train_state(random_subkey, config)
for i in range(config.num_iterations):
    random_key, random_subkey = jax.random.split(random_key)
    state, loss = train_step(state, *get_batch(random_subkey, train_data), random_subkey)

    if i % 100 == 0:
        random_key, random_subkey = jax.random.split(random_key)
        print(f"Step: {i}\t train loss: {loss}\t eval loss: {eval_step(state, *get_batch(random_subkey, eval_data))}")

params_state_dict = serialization.to_state_dict(state.params)
with open("./outputs/params.pickle", "wb") as params_file:
    pickle.dump(params_state_dict, params_file)

For more information, see https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html
  key: jax.random.KeyArray


Step: 0	 train loss: 4.640990257263184	 eval loss: 4.123988628387451
Step: 100	 train loss: 3.151776075363159	 eval loss: 3.0975451469421387
Step: 200	 train loss: 2.798823356628418	 eval loss: 2.74839448928833
