## Use CPU runtime

In [1]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

!pip install -q --upgrade pip # To support manylinux2010 wheels.
!pip install -q --upgrade jax jaxlib # CPU-only
!pip install -q --upgrade jaxtyping
!pip install -q --upgrade flax

--2023-09-26 03:21:45--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2023-09-26 03:21:45 (15.7 MB/s) - ‘input.txt’ saved [1115394/1115394]

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m17.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m22.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.5/84.5 MB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[0m[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behavi

In [2]:
from typing import List, Dict, Mapping, Tuple

import jax
import jax.numpy as jnp
import jax.random as jrand
import jaxtyping
import flax.linen as nn
from flax.training import train_state  # Useful dataclass to keep train state
import optax
import tensorflow as tf
import pdb

def println(*args):
  for arg in args:
    print(arg)


## Dataset pipeline

In [3]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Create chars vocubulary using all the unique characters in the text.
chars = sorted(list(set(text)))
VOCAB_SIZE = len(chars)

# Create mapping from characters to integers.
stoi = {ch: i for i, ch in enumerate(chars)}

# Create reverse mapping from integers to characters.
itos = {i: ch for i, ch in enumerate(chars)}

# Create encode, decode function.
def encode(s: str, stoi: Mapping[str, int]) -> List[int]:
  return [stoi[c] for c in s]

def decode(tokens: List[int], itos: Mapping[int, str]) -> str:
  return ''.join([itos[i] for i in tokens])

println(encode("hii there", stoi), decode(encode("hii there", stoi), itos))

# Let's now split up the data into train and validation sets.
data = jnp.array(encode(text, stoi), dtype=jnp.int64)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# Below would result in a minibatch size of 32.
BATCH_SIZE = 4 # how many independent sequences will we process in parallel?
BLOCK_SIZE = 8 # what is the maximum context length for predictions?

train_dataset = (tf.data.Dataset.from_tensor_slices(train_data)
                .batch(BLOCK_SIZE+1)
                .map(lambda input: (input[:BLOCK_SIZE], input[1:BLOCK_SIZE+1]),
                     num_parallel_calls=tf.data.AUTOTUNE)
                .batch(BATCH_SIZE)
                .repeat()
                .as_numpy_iterator())
val_dataset = (tf.data.Dataset.from_tensor_slices(val_data)
                .batch(BLOCK_SIZE+1)
                .map(lambda input: (input[:BLOCK_SIZE], input[1:BLOCK_SIZE+1]),
                     num_parallel_calls=tf.data.AUTOTUNE)
                .batch(BATCH_SIZE)
                .repeat()
                .as_numpy_iterator())

def get_batch(training: bool = True):
  if not training:
    val_batch = next(val_dataset)
    return jnp.array(val_batch)

  train_batch = next(train_dataset)
  return jnp.array(train_batch)

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


  data = jnp.array(encode(text, stoi), dtype=jnp.int64)


In [4]:
xb, yb = get_batch()
println("inputs", xb, "inputs shape", xb.shape)
println("targets", yb, "targets shape", yb.shape)
for b in range(BATCH_SIZE): # batch dimension
    for t in range(BLOCK_SIZE): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs
[[18 47 56 57 58  1 15 47]
 [47 64 43 52 10  0 14 43]
 [53 56 43  1 61 43  1 54]
 [53 41 43 43 42  1 39 52]]
inputs shape
(4, 8)
targets
[[47 56 57 58  1 15 47 58]
 [64 43 52 10  0 14 43 44]
 [56 43  1 61 43  1 54 56]
 [41 43 43 42  1 39 52 63]]
targets shape
(4, 8)
when input is [18] the target: 47
when input is [18, 47] the target: 56
when input is [18, 47, 56] the target: 57
when input is [18, 47, 56, 57] the target: 58
when input is [18, 47, 56, 57, 58] the target: 1
when input is [18, 47, 56, 57, 58, 1] the target: 15
when input is [18, 47, 56, 57, 58, 1, 15] the target: 47
when input is [18, 47, 56, 57, 58, 1, 15, 47] the target: 58
when input is [47] the target: 64
when input is [47, 64] the target: 43
when input is [47, 64, 43] the target: 52
when input is [47, 64, 43, 52] the target: 10
when input is [47, 64, 43, 52, 10] the target: 0
when input is [47, 64, 43, 52, 10, 0] the target: 14
when input is [47, 64, 43, 52, 10, 0, 14] the target: 43
when input is [47, 64, 43, 

In [None]:
class BigramLangModel(nn.Module):
  """Reads one char and predicits the next char."""
  vocab_size: int

  def setup(self):
    super().setup()
    self.token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.vocab_size)

  def __call__(self, inputs):
    # Run block size inputs through embedding lookup.
    # For each char, you get the logit predicted for that char.
    # Then, you use the target token for that input and do a cross_entropy_loss.
    logits = self.token_embedding_table(inputs)
    return logits