## Use CPU runtime

In [1]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

!pip install -q --upgrade pip # To support manylinux2010 wheels.
!pip install -q --upgrade jax jaxlib # CPU-only
!pip install -q --upgrade jaxtyping
!pip install -q --upgrade flax

--2023-11-11 01:34:41--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2023-11-11 01:34:41 (14.5 MB/s) - ‘input.txt’ saved [1115394/1115394]

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m19.7 MB/s[0m eta [36m0:00:00[0m
[0m

In [2]:
from typing import List, Dict, Mapping, Tuple

import jax
import jax.numpy as jnp
import jax.random as jrand
import jaxtyping
import flax.linen as nn
from flax.training import train_state  # Useful dataclass to keep train state
import optax
import tensorflow as tf
import pdb

def println(*args):
  for arg in args:
    print(arg)


## Dataset pipeline

In [3]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Create chars vocubulary using all the unique characters in the text.
chars = sorted(list(set(text)))
VOCAB_SIZE = len(chars)

# Create mapping from characters to integers.
stoi = {ch: i for i, ch in enumerate(chars)}

# Create reverse mapping from integers to characters.
itos = {i: ch for i, ch in enumerate(chars)}

# Create encode, decode function.
def encode(s: str, stoi: Mapping[str, int]) -> List[int]:
  return [stoi[c] for c in s]

def decode(tokens: List[int], itos: Mapping[int, str]) -> str:
  return ''.join([itos[i] for i in tokens])

println(encode("hii there", stoi), decode(encode("hii there", stoi), itos))

# Let's now split up the data into train and validation sets.
data = jnp.array(encode(text, stoi), dtype=jnp.int64)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# Below would result in a minibatch size of 32.
BATCH_SIZE = 4 # how many independent sequences will we process in parallel?
BLOCK_SIZE = 8 # what is the maximum context length for predictions?

train_dataset = (tf.data.Dataset.from_tensor_slices(train_data)
                .batch(BLOCK_SIZE+1)
                .map(lambda input: (input[:BLOCK_SIZE], input[1:BLOCK_SIZE+1]),
                     num_parallel_calls=tf.data.AUTOTUNE)
                .batch(BATCH_SIZE)
                .repeat()
                .as_numpy_iterator())
val_dataset = (tf.data.Dataset.from_tensor_slices(val_data)
                .batch(BLOCK_SIZE+1)
                .map(lambda input: (input[:BLOCK_SIZE], input[1:BLOCK_SIZE+1]),
                     num_parallel_calls=tf.data.AUTOTUNE)
                .batch(BATCH_SIZE)
                .repeat()
                .as_numpy_iterator())

def get_batch(training: bool = True):
  if not training:
    val_batch = next(val_dataset)
    return jnp.array(val_batch)

  train_batch = next(train_dataset)
  return jnp.array(train_batch)

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


  data = jnp.array(encode(text, stoi), dtype=jnp.int64)


In [4]:
xb, yb = get_batch()
println("inputs", xb, "inputs shape", xb.shape)
println("targets", yb, "targets shape", yb.shape)
for b in range(BATCH_SIZE): # batch dimension
    for t in range(BLOCK_SIZE): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs
[[18 47 56 57 58  1 15 47]
 [47 64 43 52 10  0 14 43]
 [53 56 43  1 61 43  1 54]
 [53 41 43 43 42  1 39 52]]
inputs shape
(4, 8)
targets
[[47 56 57 58  1 15 47 58]
 [64 43 52 10  0 14 43 44]
 [56 43  1 61 43  1 54 56]
 [41 43 43 42  1 39 52 63]]
targets shape
(4, 8)
when input is [18] the target: 47
when input is [18, 47] the target: 56
when input is [18, 47, 56] the target: 57
when input is [18, 47, 56, 57] the target: 58
when input is [18, 47, 56, 57, 58] the target: 1
when input is [18, 47, 56, 57, 58, 1] the target: 15
when input is [18, 47, 56, 57, 58, 1, 15] the target: 47
when input is [18, 47, 56, 57, 58, 1, 15, 47] the target: 58
when input is [47] the target: 64
when input is [47, 64] the target: 43
when input is [47, 64, 43] the target: 52
when input is [47, 64, 43, 52] the target: 10
when input is [47, 64, 43, 52, 10] the target: 0
when input is [47, 64, 43, 52, 10, 0] the target: 14
when input is [47, 64, 43, 52, 10, 0, 14] the target: 43
when input is [47, 64, 43, 

In [5]:
class BigramLangModel(nn.Module):
  """Reads one char and predicits the next char."""
  vocab_size: int

  def setup(self):
    super().setup()
    self.token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.vocab_size)

  def __call__(self, inputs):
    # Run block size inputs through embedding lookup.
    # For each char, you get the logit predicted for that char.
    # Then, you use the target token for that input and do a cross_entropy_loss.
    logits = self.token_embedding_table(inputs)
    return logits

In [8]:
# In flax, you need to init the model with sample input that you would pass during each forward pass.
sample_input_row = jrand.randint(key=jrand.PRNGKey(99), minval=0, maxval=65, dtype=jnp.int32, shape=[BLOCK_SIZE])
sample_input_row

Array([21, 23, 50, 28, 53, 55, 40, 29], dtype=int32)

In [10]:
sample_input_batch = jrand.randint(key=jrand.PRNGKey(99), minval=0, maxval=65, dtype=jnp.int32, shape=[BATCH_SIZE, BLOCK_SIZE])
sample_input_batch

Array([[52,  4, 61, 62, 37, 37, 33, 18],
       [64, 41, 20,  2, 41, 35, 29, 21],
       [40, 45, 52, 10, 12, 55, 49, 56],
       [ 3, 60, 15,  5, 14, 12,  1, 13]], dtype=int32)

In [12]:
model = BigramLangModel(vocab_size=65)
params = model.init(jrand.PRNGKey(99), sample_input_batch)["params"]
model, params

(BigramLangModel(
     # attributes
     vocab_size = 65
 ),
 {'token_embedding_table': {'embedding': Array([[ 0.0752212 ,  0.01071652, -0.02585994, ..., -0.06997449,
            0.10274917, -0.0226865 ],
          [ 0.09400459,  0.12404279,  0.06972364, ...,  0.0593865 ,
            0.1517611 ,  0.11131446],
          [-0.0302137 , -0.07326671, -0.2515272 , ...,  0.20769818,
            0.01281604,  0.03134193],
          ...,
          [-0.1394756 , -0.00640967, -0.07666602, ..., -0.2944119 ,
            0.11875169, -0.08573762],
          [ 0.05703759, -0.11280773,  0.2570641 , ..., -0.02059634,
           -0.02818088,  0.13305528],
          [-0.12428083, -0.13785616, -0.12170235, ..., -0.07394623,
            0.19811267, -0.06473607]], dtype=float32)}})

In [20]:
sample_logits = model.apply({"params": params}, sample_input_batch)
"sample batch shape", sample_input_batch.shape, "sample logits", sample_logits

('sample batch shape',
 (4, 8),
 'sample logits',
 Array([[[ 8.75270739e-02, -2.27933563e-02,  2.42999336e-03, ...,
           1.76198840e-01, -2.76548654e-01,  4.57528904e-02],
         [ 1.04733538e-02,  3.02993655e-01, -6.49068654e-02, ...,
          -3.16001356e-01,  6.11841679e-02,  4.50519659e-02],
         [-3.75676826e-02,  2.13820398e-01,  5.01289777e-02, ...,
          -4.57006246e-02, -4.29858230e-02, -2.51459748e-01],
         ...,
         [ 3.42681557e-02, -5.03289811e-02,  6.59156889e-02, ...,
          -8.72746203e-03,  4.07012515e-02,  3.04256212e-02],
         [-2.42578298e-01, -9.27777663e-02,  3.04300897e-02, ...,
          -7.17728958e-02,  4.83715236e-02, -5.26186675e-02],
         [-5.26097491e-02, -6.47737905e-02,  2.42933154e-01, ...,
           1.00368716e-01,  1.61088228e-01, -2.64439564e-02]],
 
        [[-1.24280833e-01, -1.37856156e-01, -1.21702351e-01, ...,
          -7.39462301e-02,  1.98112667e-01, -6.47360682e-02],
         [-1.87322591e-02, -8.3400882

## Sample forward pass, loss and backward pass.

In [22]:
batch = get_batch()
inputs, targets = batch
println("inputs", inputs, inputs.shape, "targets", targets.shape)

inputs
[[ 6  1 57 54 43 39 49  8]
 [ 0 18 47 56 57 58  1 15]
 [58 47 64 43 52 10  0 37]
 [59  1 39 56 43  1 39 50]]
(4, 8)
targets
(4, 8)


In [23]:
# Decode a batch containing several input rows.
println(",".join([decode(input_row, itos) for input_row in inputs.tolist()]))

, speak.,
First C,tizen:
Y,u are al


In [24]:
logits = model.apply({"params": params}, inputs)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
loss = loss.mean()
println(loss)

4.188369


In [25]:
# To do backward pass, you first need to compute grads.
# In JAX, you use jax.grad to do a function transformation on the forward
# function to get the gradient of the original function.
# The grad is calculate wrt to the first param in the function.
def forward_pass(params, batch):
  inputs, targets = batch
  logits = model.apply({"params": params}, inputs)
  loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
  loss = loss.mean()
  return loss

In [27]:
grad_fn = jax.grad(forward_pass, argnums=(0))  # differentiate wrt 0th pos argument.

In [28]:
grads = grad_fn(params, batch)
# These are the grads for the params.
println(grads)

{'token_embedding_table': {'embedding': Array([[0.00103511, 0.00097044, 0.00093559, ..., 0.00089521, 0.001064  ,
        0.00093856],
       [0.00203619, 0.00209828, 0.00198735, ..., 0.00196691, 0.00215726,
        0.00207174],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.00040068, 0.00039527, 0.00040171, ..., 0.00042136, 0.00055311,
        0.00042526]], dtype=float32)}}


In [29]:
# Apply grads to params to get new params.
lr = 0.001
println("params before:", params)
params = jax.tree_map(lambda p, g: p - lr * g, params, grads)
println("params after:", params)

params before:
{'token_embedding_table': {'embedding': Array([[ 0.0752212 ,  0.01071652, -0.02585994, ..., -0.06997449,
         0.10274917, -0.0226865 ],
       [ 0.09400459,  0.12404279,  0.06972364, ...,  0.0593865 ,
         0.1517611 ,  0.11131446],
       [-0.0302137 , -0.07326671, -0.2515272 , ...,  0.20769818,
         0.01281604,  0.03134193],
       ...,
       [-0.1394756 , -0.00640967, -0.07666602, ..., -0.2944119 ,
         0.11875169, -0.08573762],
       [ 0.05703759, -0.11280773,  0.2570641 , ..., -0.02059634,
        -0.02818088,  0.13305528],
       [-0.12428083, -0.13785616, -0.12170235, ..., -0.07394623,
         0.19811267, -0.06473607]], dtype=float32)}}
params after:
{'token_embedding_table': {'embedding': Array([[ 0.07522016,  0.01071555, -0.02586087, ..., -0.06997538,
         0.1027481 , -0.02268744],
       [ 0.09400256,  0.12404069,  0.06972165, ...,  0.05938453,
         0.15175894,  0.11131239],
       [-0.0302137 , -0.07326671, -0.2515272 , ...,  0.207698

## Writing train step in flax

In [31]:
def compute_loss(params, state, batch):
  inputs, targets = batch
  logits = state.apply_fn({"params": params}, inputs)
  # println(logits)
  loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
  loss = loss.mean()
  return loss

In [32]:
grad_fn = jax.grad(compute_loss, argnums=(0))