<a href="https://colab.research.google.com/github/himanshu-warulkar/JAX-and-Flax-projects/blob/main/MiniGPT_Jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Tested with the free Google Compute Engine backend in Colab; no GPU is required.

# Imports

In [None]:
import jax
import flax.linen as nn
import jax.numpy as jnp
from flax.training import train_state
import optax
import numpy as np
import matplotlib.pyplot as pp
import tqdm
import unittest
import time
import functools
import math

In [None]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [None]:
#@title Helper functions
dynamic_slice_vmap = jax.vmap(jax.lax.dynamic_slice, in_axes=(None, 0, None))

def get_batch(random_key, data, batch_size, block_size):
  """Sample `batch_size` random windows of `data` as (inputs, targets).

  Every row is a contiguous run of tokens that starts at a uniformly sampled
  offset; the target row is the input row shifted left by one token.

  Args:
    random_key (jax.random.PRNGKey): Random number generator key.
    data (array-like): 1d JAX array of integer tokens. It needs at least
      `block_size + 1` tokens for the sampled offset range to be non-empty.
    batch_size (int): Batch size.
    block_size (int): The maximum input context length.

  Returns:
    x (array-like): 2d JAX array of shape (batch_size, block_size).
    y (array-like): 2d JAX array of shape (batch_size, block_size).
        x[i, j] == y[i, j-1] where j > 0.
  """
  # randint's maxval is exclusive, so offsets lie in
  # [0, len(data) - block_size - 1]; a window of block_size + 1 tokens
  # starting there always ends inside `data`.
  starts = jax.random.randint(
      random_key,
      shape=(batch_size, 1),
      minval=0,
      maxval=len(data) - block_size,
  )
  # Take one extra token per row, then read inputs and targets off it.
  windows = dynamic_slice_vmap(data, starts, (block_size + 1,))
  return windows[:, :-1], windows[:, 1:]

def load_shakespeare_dataset():
  """Read 'input.txt', tokenize it and split it 90% / 10%.

  Relies on the module-level `encode` function, which is defined in the
  tokenization cell; call this only after that cell has run.

  Returns:
    (train_data, eval_data): two 1d JAX arrays of token ids. The first 90% of
    the text is used for training and the remainder for evaluation.
  """
  with open('input.txt', 'r', encoding='utf-8') as handle:
    raw_text = handle.read()
  tokens = jnp.array(encode(raw_text))
  split_at = int(0.9*len(tokens))  # first 90% will be train, rest val
  return tokens[:split_at], tokens[split_at:]

def init_train_state(
    model,
    params,
    learning_rate=1e-4,
):
  """Bundle `params` with an Adam optimizer into a Flax TrainState.

  Args:
    model: Flax module whose `apply` becomes the state's `apply_fn`.
    params: Initial parameters, e.g. the output of `model.init`.
    learning_rate (float): Adam learning rate.

  Returns:
    flax.training.train_state.TrainState ready for `apply_gradients`.
  """
  optimizer = optax.adam(learning_rate)
  return train_state.TrainState.create(
      apply_fn=model.apply,
      params=params,
      tx=optimizer,
  )

@jax.jit
def train_step(state, x, y):
  """Apply one optimizer update on a batch and report the batch loss.

  Args:
    state (flax.training.train_state.TrainState): Current weights and
      optimizer state.
    x (array-like): 2d JAX int array of shape (batch_size, block_size).
    y (array-like): 2d JAX int array of shape (batch_size, block_size).

  Returns:
    A pair (new_state, loss). `new_state` has the gradient update applied;
    `loss` is the mean token cross-entropy computed with the params from
    before the update.
  """
  def _mean_cross_entropy(weights):
    logits = state.apply_fn(weights, x)  # B, T, vocab_size
    per_token = optax.softmax_cross_entropy_with_integer_labels(logits, y)
    return per_token.mean()

  loss, grads = jax.value_and_grad(_mean_cross_entropy)(state.params)
  return state.apply_gradients(grads=grads), loss

@jax.jit
def eval_step(state, x, y):
  """Return the mean token cross-entropy of `state.params` on one batch.

  Args:
    state (flax.training.train_state.TrainState): Weights to evaluate.
    x (array-like): 2d JAX int array of shape (batch_size, block_size).
    y (array-like): 2d JAX int array of shape (batch_size, block_size).
  """
  logits = state.apply_fn(state.params, x)  # B, T, vocab_size
  per_token = optax.softmax_cross_entropy_with_integer_labels(logits, y)
  return per_token.mean()

def run_training_loop(
    num_iterations,
    batch_size,
    block_size,
    learning_rate,
    eval_data,
    train_data,
    model,
    eval_interval=100,
    seed=0,
):
  """
  Runs the training loop for the specified model.

  The model is evaluated on one random eval batch every `eval_interval`
  steps and after the final step. The state with the lowest of those eval
  losses is returned. Each evaluation uses a single batch, so the selection
  is noisy.

  Args:
      num_iterations (int): The number of training iterations.
      batch_size (int): The number of samples in each batch.
      block_size (int): The size of each block (sequence length).
      learning_rate (float): The learning rate for the optimizer.
      eval_data (array-like): 1d JAX array of integer tokens, consisting of evaluation data.
      train_data (array-like): 1d JAX array of integer tokens, consisting of training data.
      model (nn.Module): A Flax model mapping (batch_size, block_size) int
          tokens to (batch_size, block_size, vocab_size) logits.
      eval_interval (int, optional): Number of steps between evaluations.
          Defaults to 100.
      seed (int, optional): Seed for parameter initialization and batch
          sampling. Defaults to 0.

  Returns:
      state: The training state with the best eval metrics. If
      num_iterations is 0, this is the freshly initialized state.

  Raises:
      ValueError: If eval_interval is smaller than 1.

  Example:
      >>> final_state = run_training_loop(
      >>>     num_iterations=1000,
      >>>     batch_size=16,
      >>>     block_size=32,
      >>>     learning_rate=0.001,
      >>>     eval_data=eval_data,
      >>>     train_data=train_data,
      >>>     model=mini_gpt
      >>> )
  """
  if eval_interval < 1:
    raise ValueError(f"eval_interval must be >= 1, got {eval_interval}")

  random_key = jax.random.PRNGKey(seed)
  # Placeholder batch used only to initialize the parameters.
  x = jnp.ones((batch_size, block_size), dtype=jnp.int16)
  random_key, random_subkey = jax.random.split(random_key)
  params = model.init(random_subkey, x)
  state = init_train_state(
      model, params, learning_rate=learning_rate)
  best_state = state
  best_eval_loss = math.inf
  for i in range(num_iterations):
    random_key, random_subkey = jax.random.split(random_key)
    x, y = get_batch(random_subkey, train_data, batch_size=batch_size, block_size=block_size)
    state, loss = train_step(state, x, y)

    # Also evaluate after the last step. Otherwise the updates made after the
    # last multiple of eval_interval could never be selected.
    if i % eval_interval == 0 or i == num_iterations - 1:
      random_key, random_subkey = jax.random.split(random_key)
      eval_loss = eval_step(state, *get_batch(random_subkey, eval_data, batch_size=batch_size, block_size=block_size))
      print(f"Step: {i}\t train loss: {loss}\t eval loss: {eval_loss}")
      eval_loss_value = float(eval_loss)
      if eval_loss_value < best_eval_loss:
        best_eval_loss = eval_loss_value
        best_state = state
  return best_state

## Load and tokenize dataset

In [None]:
# Read the raw text file fetched by the wget cell (input.txt) into `text`.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
print("length of dataset in characters: ", len(text))

In [None]:
# Character-level vocabulary: every distinct character in `text`, sorted.
chars = sorted(set(text))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)

# create a mapping from characters to integers
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
  """Encoder: take a string, output a list of integers.

  Raises KeyError for characters that do not occur in `text`.
  """
  return [stoi[c] for c in s]


def decode(l):
  """Decoder: take a sequence of integer ids, output a string.

  Each id goes through int(), so NumPy/JAX integer scalars (e.g. the
  elements of a token array) work as well as Python ints. Those scalars
  are not always hashable and cannot be used as dict keys directly.
  """
  return "".join(itos[int(i)] for i in l)


# Let's now split up the data into train and validation sets
data = jnp.array(encode(text))
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
eval_data = data[n:]

# Checking the performance of a simple text decoder model

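The SimpleDecoder below predicts the next token from the current token alone. Its embedding table has `vocab_size` features, so each token's embedding row is used directly as the logits over the next token (a bigram model).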

In [None]:
class SimpleDecoder(nn.Module):
  """Next-token model that looks only at the current token.

  The embedding table has `vocab_size` features, so each token's embedding
  row is used directly as the logits over the next token.
  """
  vocab_size: int

  def setup(self):
    # Token id -> row of next-token logits.
    self.token_embedding = nn.Embed(
        num_embeddings=self.vocab_size,
        features=self.vocab_size,
    )

  def __call__(self, x):
    """Map int tokens of shape (B, T) to logits of shape (B, T, vocab_size)."""
    _batch, _time = x.shape  # the unpack rejects inputs that are not 2-D
    logits = self.token_embedding(x)
    return logits

  def generate(self, start_token, max_length=20, end_token=None):
    """Greedily decode from `start_token`, always taking the argmax token.

    Args:
      start_token: Integer id placed first in the output.
      max_length: Maximum length of the returned list, start token included.
      end_token: Optional id. Decoding stops right after it is produced, and
        it is still appended.

    Returns:
      A list of Python ints.
    """
    sequence = [start_token]
    token = start_token
    steps_left = max_length - 1  # the start token already counts
    for _ in range(steps_left):
      logits = self.__call__(jnp.array([[token]]))  # (1, 1, vocab_size)
      token = int(jnp.argmax(logits, axis=-1)[0][0])
      sequence.append(token)
      if end_token is not None and token == end_token:
        break
    return sequence

# Build a SimpleDecoder and decode from freshly initialized (untrained) params.
decoder = SimpleDecoder(vocab_size=vocab_size)
start_token = 23  # token id that seeds generation; must be < vocab_size
dummy = jnp.ones((4, 8), dtype=jnp.int16)  # placeholder (B, T) batch used only for init
params = decoder.init(jax.random.PRNGKey(0), dummy)

# Generate a sequence
generated_sequence = decoder.apply(params, start_token, method=decoder.generate, max_length=20)
print("Generated sequence:", decode(generated_sequence))


The generated sequence is gibberish because the parameters are still at their random initialization. Let's see whether it improves after training.

In [None]:
# You can play around the parameters here to see how that affects loss.
num_iterations = 7000
learning_rate = 1e-3
# NOTE: num_layers, num_heads and hidden_dim are not used in this cell;
# SimpleDecoder is configured by vocab_size only.
num_layers = 4
batch_size = 16
block_size = 32
num_heads = 4
hidden_dim = 64

decoder = SimpleDecoder(vocab_size=vocab_size)

# run_training_loop initializes its own params and returns the state with
# the lowest sampled eval loss.
simple_decoder_state = run_training_loop(
    num_iterations = num_iterations,
    learning_rate = learning_rate,
    batch_size = batch_size,
    block_size = block_size,
    eval_data = eval_data,
    train_data = train_data,
    model = decoder
)

In [None]:
# Greedy (argmax) generation with the trained SimpleDecoder params.
generated_sequence = decoder.apply(simple_decoder_state.params, start_token, method=decoder.generate, max_length=20)
print("Generated sequence:", decode(generated_sequence))


# Task 1 - Implement MiniGPT.

* You may use off-the-shelf Flax modules such as Dense and LayerNorm. Do not use Flax's SelfAttention; use the AttentionTask1 module provided below instead.
* Note that block_size, T, and the input context window length all refer to the same quantity.

In [None]:
# B == batch_size.
# T == number of tokens in sequence.
# C == hidden_dim == hidden dimension of transformer.
# head_dim == Head dimension for each Attention head. head_dim * num_heads == C.

# You can use this class for solving Task 1. We will revisit this class in Task 2.
class AttentionTask1(nn.Module):
  """Single causal attention head built on Flax's MultiHeadDotProductAttention.

  NOTE(review): according to the Flax docs, MultiHeadDotProductAttention
  applies its own query/key/value and output projections. The Dense layers
  below therefore add a second projection in front of those.
  """
  head_dim: int

  def setup(self):
    self.query = nn.Dense(features=self.head_dim, use_bias=False)
    self.key = nn.Dense(features=self.head_dim, use_bias=False)
    self.value = nn.Dense(features=self.head_dim, use_bias=False)
    self.attention_impl = nn.MultiHeadDotProductAttention(
        num_heads=1,
        qkv_features=self.head_dim,
        dropout_rate=0.,
    )

  def __call__(self, x):
    """Attend causally over x of shape (B, T, C); returns (B, T, head_dim)."""
    B, T, _ = x.shape
    # Lower-triangular mask: position t may attend only to positions <= t.
    causal_mask = jnp.tril(jnp.ones((B, 1, T, T)))
    return self.attention_impl(
        inputs_q=self.query(x),  # B, T, head_dim
        inputs_k=self.key(x),  # B, T, head_dim
        inputs_v=self.value(x),  # B, T, head_dim
        mask=causal_mask,
    )

# FeedForward is given to you for free.
class FeedForward(nn.Module):
  """Position-wise MLP: Dense(4 * hidden_dim) -> ReLU -> Dense(hidden_dim)."""
  hidden_dim: int

  def setup(self):
    expanded_dim = 4 * self.hidden_dim
    self.f1 = nn.Dense(features=expanded_dim)
    self.f2 = nn.Dense(features=self.hidden_dim)

  def __call__(self, x):
    activated = nn.relu(self.f1(x))  # B, T, 4 * hidden_dim
    return self.f2(activated)  # B, T, hidden_dim

class MultiHeadAttention(nn.Module):
  """Runs `num_heads` independent AttentionTask1 heads and mixes them with a Dense layer."""
  num_heads: int
  head_dim: int

  def setup(self):
    self.heads = [
        AttentionTask1(head_dim=self.head_dim) for _ in range(self.num_heads)
    ]
    self.dense = nn.Dense(self.num_heads * self.head_dim)

  def __call__(self, x):
    # Each head yields [B, T, head_dim]. Joining them on the feature axis
    # gives [B, T, num_heads * head_dim], which the Dense layer then mixes.
    per_head = [attend(x) for attend in self.heads]
    return self.dense(jnp.concatenate(per_head, axis=-1))

class DecoderBlock(nn.Module):
  """Pre-LayerNorm transformer block: causal attention, then feed-forward.

  Each sub-layer sees a LayerNorm'd copy of x, and its output is added back
  to x (residual connection).

  Attributes:
    hidden_dim: Width of the residual stream (C).
    num_heads: Number of attention heads; must be positive and divide
      hidden_dim.
  """
  hidden_dim: int
  num_heads: int

  def setup(self):
    # MultiHeadAttention outputs num_heads * head_dim features. That equals
    # hidden_dim (needed for the residual add) only when the division is
    # exact, so reject bad configs here with a clear message instead of a
    # shape error deep inside __call__.
    if self.num_heads <= 0 or self.hidden_dim % self.num_heads != 0:
      raise ValueError(
          f"hidden_dim ({self.hidden_dim}) must be divisible by a positive "
          f"num_heads ({self.num_heads}).")
    head_dim = self.hidden_dim // self.num_heads
    self.mha = MultiHeadAttention(num_heads=self.num_heads, head_dim=head_dim)
    self.ffn = FeedForward(hidden_dim=self.hidden_dim)
    self.ln1 = nn.LayerNorm()
    self.ln2 = nn.LayerNorm()

  def __call__(self, x):
    # First residual connection with layer norm
    attn_output = self.mha(self.ln1(x))
    x = x + attn_output  # Residual connection

    # Second residual connection with layer norm
    ffn_output = self.ffn(self.ln2(x))
    x = x + ffn_output  # Residual connection

    return x  # [B, T, hidden_dim]

class MiniGPT(nn.Module):
  """Decoder-only transformer language model.

  Attributes:
    vocab_size: Number of distinct token ids.
    hidden_dim: Width of the residual stream.
    block_size: Maximum context length. It is also the number of rows in the
      learned position-embedding table.
    num_layers: Number of DecoderBlocks.
    num_heads: Attention heads per DecoderBlock.
  """
  vocab_size: int
  hidden_dim: int
  block_size: int
  num_layers: int
  num_heads: int

  def setup(self):
    self.token_embedding = nn.Embed(
        num_embeddings=self.vocab_size,
        features=self.hidden_dim)
    self.position_encoding = nn.Embed(
        num_embeddings=self.block_size,
        features=self.hidden_dim
    )
    self.final_dense = nn.Dense(features=self.vocab_size)
    self.decoder_blocks = [
        DecoderBlock(hidden_dim=self.hidden_dim, num_heads=self.num_heads)
        for _ in range(self.num_layers)
    ]
    self.ln_final = nn.LayerNorm()

  def __call__(self, x):
    """Return logits [B, T, vocab_size] for int tokens x of shape [B, T].

    Raises:
      ValueError: If T exceeds block_size. The position table has only
        block_size rows, so longer inputs would index past it.
    """
    B, T = x.shape
    # T is a static shape, so this check also works under jax.jit.
    if T > self.block_size:
      raise ValueError(
          f"Input length {T} exceeds block_size {self.block_size}; "
          "crop the context to its last block_size tokens first.")
    token_embeddings = self.token_embedding(x)  # [B, T, hidden_dim]
    position_embeddings = self.position_encoding(jnp.arange(T))[None, :, :]  # [1, T, hidden_dim]

    # Combine token and position embeddings
    x = token_embeddings + position_embeddings  # [B, T, hidden_dim]

    # Process through decoder blocks
    for block in self.decoder_blocks:
      x = block(x)

    # Final layer norm
    x = self.ln_final(x)

    return self.final_dense(x)  # [B, T, vocab_size]


  def generate(self, random_key, params, x, max_new_tokens=50):
    """Autoregressively sample `max_new_tokens` tokens after the prompt x.

    This method calls `self.apply(params, ...)` itself, so it is meant to be
    called on the unbound module, e.g. `mini_gpt.generate(key, params=p, x=x)`.

    Args:
      random_key: PRNG key; it is split once per sampled token.
      params: Model parameters (e.g. `state.params`).
      x: Int array [B, T0] holding the prompt tokens.
      max_new_tokens: Number of tokens to append.

    Returns:
      Int array [B, T0 + max_new_tokens]: the prompt followed by the samples.
    """
    for _ in range(max_new_tokens):
      # Feed at most the last block_size tokens; the model cannot see more.
      logits = self.apply(params, x[:, -self.block_size:])
      random_key, random_subkey = jax.random.split(random_key)
      # Sample from the softmax of the last position's logits.
      new_token = jax.random.categorical(random_subkey, logits[:, -1, :], axis=-1, shape=None)
      x = jnp.concatenate([x, new_token[:, None]], axis=1)
    return x

In [None]:
# You can play around the parameters here to see how that affects loss.
# These assignments overwrite the module-level hyperparameters set in earlier cells.
num_iterations = 4000
learning_rate = 1e-3
num_layers = 4
batch_size = 16
block_size = 32
num_heads = 4
hidden_dim = 128  # must be divisible by num_heads (head_dim = hidden_dim // num_heads)

mini_gpt = MiniGPT(
    vocab_size=vocab_size,
    hidden_dim=hidden_dim,
    block_size=block_size,
    num_layers=num_layers,
    num_heads=num_heads
)

# Returns the state with the lowest sampled eval loss seen during training.
mini_gpt_state = run_training_loop(
    num_iterations=num_iterations,
    learning_rate=learning_rate,
    batch_size=batch_size,
    block_size=block_size,
    eval_data=eval_data,
    train_data=train_data,
    model=mini_gpt
)

In [None]:
# Uncomment below to print predictions:
#x = jnp.zeros((1, 1), dtype=jnp.int32)
#random_key = jax.random.PRNGKey(0)
#tokens = mini_gpt.generate(random_key, params=mini_gpt_state.params, x=x)
#print(decode(tokens[0].tolist()))

In [None]:
# Pass this test before moving on to Task 2.
class TestTask1(unittest.TestCase):
  """End-to-end check that MiniGPT trains below a target eval loss."""

  def test_minigpt(self):
    # Do not change these parameters.
    num_iterations = 4000
    learning_rate = 1e-3
    num_layers = 4
    batch_size = 16
    block_size = 32
    num_heads = 4
    hidden_dim = 128
    target_loss = 1.9
    random_key = jax.random.PRNGKey(42)

    mini_gpt = MiniGPT(
        vocab_size=vocab_size,
        hidden_dim=hidden_dim,
        block_size=block_size,
        num_layers=num_layers,
        num_heads=num_heads
    )

    train_data, eval_data = load_shakespeare_dataset()
    mini_gpt_state = run_training_loop(
        num_iterations = num_iterations,
        learning_rate = learning_rate,
        batch_size = batch_size,
        block_size = block_size,
        eval_data = eval_data,
        train_data = train_data,
        model = mini_gpt
    )
    # Average the loss over 100 random eval batches to reduce noise.
    eval_losses = []
    for _ in tqdm.tqdm(range(100)):
      random_key, random_subkey = jax.random.split(random_key)
      x, y = get_batch(
          random_subkey, eval_data, batch_size=batch_size, block_size=block_size)
      batch_eval_loss = eval_step(mini_gpt_state, x, y)
      eval_losses.append(batch_eval_loss)
    average_eval_loss = np.mean(eval_losses)
    print(f"Average eval loss: {average_eval_loss}")
    # assertLess reports both values on failure, unlike assertTrue(a < b).
    self.assertLess(average_eval_loss, target_loss)

# Uncomment the test below.
#TestTask1().test_minigpt()

# Implement the Self-Attention JAX Module


* We are implementing a decoder-only transformer. This means that each token can attend only to itself and to earlier tokens, never to future tokens (a causal, lower-triangular mask).

In [None]:
class AttentionTask2Solution(nn.Module):
  """Single causal self-attention head written with plain matrix products."""
  head_dim: int

  def setup(self):
    self.query = nn.Dense(features=self.head_dim, use_bias=False)
    self.key = nn.Dense(features=self.head_dim, use_bias=False)
    self.value = nn.Dense(features=self.head_dim, use_bias=False)

  def __call__(self, x):
    """Return the causal attention output (B, T, head_dim) for x of shape (B, T, C)."""
    _, T, _ = x.shape
    queries = self.query(x)  # B, T, head_dim
    keys = self.key(x)  # B, T, head_dim
    scores = queries @ jnp.swapaxes(keys, 1, 2)  # B, T, T
    # Hide future positions: entries above the diagonal become -inf.
    allowed = jnp.tril(jnp.ones((T, T)))
    scores = jnp.where(allowed, scores, -jnp.inf)
    weights = nn.softmax(scores / jnp.sqrt(self.head_dim), axis=-1)  # B, T, T
    return weights @ self.value(x)  # B, T, head_dim

In [None]:
class TestAttention(unittest.TestCase):
  """Regression test pinning AttentionTask2Solution's output for a fixed seed."""

  # Expected output for head_dim=16 on a (1, 4, 8) standard-normal input.
  # Both the params and the input are drawn with jax.random.key(0).
  EXPECTED_ATTENTION_ARRAY = np.array([
    [[-0.3368626, 0.1565489, 0.96250117, 0.7116083, 0.48668504,
      0.3070267, -0.49149823, 0.7827484, 0.4131582, 0.7505922,
      0.90185213, -0.34802976, 1.2631372, 0.8314824, 0.45534268,
      0.11072167],
     [0.355573, 0.36409345, 0.19864899, 0.58222437, -0.01833684,
      0.8821246, 0.26334122, 0.10999514, 0.69409794, 0.3437622,
      -0.71399987, 0.6530971, 0.00235165, -0.5397035, 0.55874693,
      -0.4885986],
     [0.6003635, 0.34785143, -0.25671193, 0.3002994, -0.31720588,
      1.2125036, 0.6570689, -0.22460055, 0.9200514, -0.01703957,
      -1.5395278, 1.1767541, -0.7460983, -1.3350787, 0.61231965,
      -1.0458561],
     [-0.7845163, -0.5571454, 0.39112994, -0.63247937, -0.2971205,
      0.19273886, -0.25068092, 0.5804176, 0.3952121, 0.24023446,
      1.1744585, -1.0228857, 1.0987606, 0.90741533, 0.19215004,
      -0.98253024]]
    ]
  )

  def test_attention(self):
    attention = AttentionTask2Solution(head_dim=16)
    params = attention.init(jax.random.key(0), jnp.ones((1, 4, 8)))
    x = jax.random.normal(key=jax.random.key(0), shape=(1, 4, 8), dtype=jnp.float32)
    y = attention.apply(params, x)
    # Same tolerances as np.allclose's defaults. On failure this reports the
    # mismatching elements (and any shape mismatch) instead of
    # "False is not true".
    np.testing.assert_allclose(
        np.asarray(y), self.EXPECTED_ATTENTION_ARRAY, rtol=1e-5, atol=1e-8)

#TestAttention().test_attention()

# Speed up MultiHeadAttention with einsum



In [None]:
class MultiHeadAttentionTask3(nn.Module):
  """Causal multi-head self-attention computing all heads at once with einsum.

  Attributes:
    num_heads: Number of heads (H).
    head_dim: Width of each head (D). Input and output widths are H * D.
  """
  num_heads: int
  head_dim: int

  def setup(self):
    width = self.num_heads * self.head_dim
    self.query = nn.Dense(features=width, use_bias=False)
    self.key = nn.Dense(features=width, use_bias=False)
    self.value = nn.Dense(features=width, use_bias=False)
    self.dense = nn.Dense(features=width)

  def __call__(self, x):
    B, T, _ = x.shape
    H, D = self.num_heads, self.head_dim

    def split_heads(t):
      # [B, T, H * D] -> [B, T, H, D]
      return jnp.reshape(t, (B, T, H, D))

    q = split_heads(self.query(x))
    k = split_heads(self.key(x))
    v = split_heads(self.value(x))

    # Scaled dot-product scores for every head: [B, H, T, T].
    scores = jnp.einsum('bqhd,bkhd->bhqk', q, k) / jnp.sqrt(D)

    # Causal mask, broadcast over batch and heads. Masked logits become a
    # large negative number so softmax assigns them ~0 weight.
    causal = jnp.tril(jnp.ones((T, T)))[None, None, :, :]
    scores = jnp.where(causal == 0, -1e9, scores)

    weights = nn.softmax(scores, axis=-1)  # [B, H, T, T]
    per_head = jnp.einsum('bhqk,bkhd->bqhd', weights, v)  # [B, T, H, D]

    # Merge the heads back and mix them with the output projection.
    return self.dense(jnp.reshape(per_head, (B, T, H * D)))

In [None]:
# Task 3: train MiniGPT with the einsum-based MultiHeadAttentionTask3.
#
# Changes from the previous version of this cell:
#   * MiniGPTWithTask3 was used but never defined; it is defined below.
#   * vocab_size is no longer overwritten with the GPT-2 size (50257). The
#     data is tokenized with the character-level vocabulary built earlier,
#     and the overwrite also changed the global used by later cells/tests.
#   * get_batch is called with its real signature (key, data, batch, block).
#   * The jitted step is now task3_train_step, so it no longer shadows the
#     train_step that run_training_loop calls.

# Hyperparameters
batch_size = 32
block_size = 128
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4
eval_iters = 200

# Model parameters (vocab_size is the character-level vocabulary built earlier)
hidden_dim = 256
num_heads = 8
num_layers = 6
head_dim = hidden_dim // num_heads


class DecoderBlockTask3(nn.Module):
    """Pre-LayerNorm decoder block that uses MultiHeadAttentionTask3."""
    hidden_dim: int
    num_heads: int

    def setup(self):
        if self.num_heads <= 0 or self.hidden_dim % self.num_heads != 0:
            raise ValueError(
                f"hidden_dim ({self.hidden_dim}) must be divisible by a "
                f"positive num_heads ({self.num_heads}).")
        self.mha = MultiHeadAttentionTask3(
            num_heads=self.num_heads,
            head_dim=self.hidden_dim // self.num_heads)
        self.ffn = FeedForward(hidden_dim=self.hidden_dim)
        self.ln1 = nn.LayerNorm()
        self.ln2 = nn.LayerNorm()

    def __call__(self, x):
        x = x + self.mha(self.ln1(x))  # attention sub-layer + residual
        x = x + self.ffn(self.ln2(x))  # feed-forward sub-layer + residual
        return x  # [B, T, hidden_dim]


class MiniGPTWithTask3(MiniGPT):
    """MiniGPT whose decoder blocks use the einsum attention (MultiHeadAttentionTask3).

    Only setup is overridden; __call__ and generate are inherited.
    """

    def setup(self):
        self.token_embedding = nn.Embed(
            num_embeddings=self.vocab_size, features=self.hidden_dim)
        self.position_encoding = nn.Embed(
            num_embeddings=self.block_size, features=self.hidden_dim)
        self.final_dense = nn.Dense(features=self.vocab_size)
        self.decoder_blocks = [
            DecoderBlockTask3(hidden_dim=self.hidden_dim, num_heads=self.num_heads)
            for _ in range(self.num_layers)
        ]
        self.ln_final = nn.LayerNorm()


# Initialize model
model = MiniGPTWithTask3(
    vocab_size=vocab_size,
    hidden_dim=hidden_dim,
    block_size=block_size,
    num_layers=num_layers,
    num_heads=num_heads
)

# Initialize parameters
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
params = model.init(subkey, jnp.ones((batch_size, block_size), dtype=jnp.int32))

# Optimizer
optimizer = optax.adamw(learning_rate)
opt_state = optimizer.init(params)


@jax.jit
def task3_loss(params, xb, yb):
    """Mean next-token cross-entropy of `model` on one batch."""
    logits = model.apply(params, xb)
    return optax.softmax_cross_entropy_with_integer_labels(
        logits=logits.reshape(-1, logits.shape[-1]),
        labels=yb.reshape(-1)
    ).mean()


@jax.jit
def task3_train_step(params, opt_state, xb, yb):
    """One AdamW update; returns (params, opt_state, loss before the update)."""
    loss, grads = jax.value_and_grad(task3_loss)(params, xb, yb)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss


def estimate_loss(params, random_key):
    """Average task3_loss over eval_iters random batches of each split."""
    out = {}
    for split, split_data in (('train', train_data), ('val', eval_data)):
        losses = []
        for _ in range(eval_iters):
            random_key, batch_key = jax.random.split(random_key)
            xb, yb = get_batch(batch_key, split_data, batch_size, block_size)
            losses.append(task3_loss(params, xb, yb))
        out[split] = float(np.mean(losses))
    return out


# Training loop
train_losses = []
eval_losses = []

for step in range(max_iters):
    # Get batch
    key, batch_key, eval_key = jax.random.split(key, 3)
    xb, yb = get_batch(batch_key, train_data, batch_size, block_size)

    # Train step
    params, opt_state, loss = task3_train_step(params, opt_state, xb, yb)
    train_losses.append(loss)

    # Evaluation
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss(params, eval_key)
        eval_losses.append(losses['val'])
        print(f"Step {step}: train loss {float(loss):.4f}, val loss {losses['val']:.4f}")
