<a href="https://colab.research.google.com/github/himanshu-warulkar/JAX-and-Flax-projects/blob/main/Mini_Language_Model_with_JAX_and_Flax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.4.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.4.1-py3-none-any.whl (487 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m487.4/487.4 kB[0m [31m12.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.w

In [4]:
# Implementing a Mini Language Model with JAX and Flax

'''
This notebook demonstrates how to implement and train a small language model using JAX and Flax. We'll use the **WikiText-2 dataset** for training and include educational tasks to reinforce key concepts.


'''
## Imports


import jax
import flax.linen as nn
import jax.numpy as jnp
from flax.training import train_state
import optax
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
import tqdm
import unittest
import time

## Load and Tokenize Datasets
### Download and Preprocess WikiText-2

In [5]:
# Fetch WikiText-2 and keep only the first 1000 training lines so that the
# rest of the notebook trains quickly.
dataset = load_dataset("wikitext", "wikitext-2-v1")
text = "\n".join(dataset["train"]["text"][:1000])

# Character-level vocabulary: every distinct character in the corpus gets an
# integer id, assigned in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {ch: idx for idx, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Turn a string into a list of token ids, one per character."""
    return [stoi[c] for c in s]


def decode(l):
    """Turn an iterable of token ids back into a string."""
    return "".join(itos[i] for i in l)


# The first 90% of the token stream is used for training, the rest for
# evaluation.
data = jnp.array(encode(text))
n = int(0.9 * len(data))
train_data = data[:n]
eval_data = data[n:]
print(f"Train size: {len(train_data)}, Val size: {len(eval_data)}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/685k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/6.07M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/618k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Train size: 256523, Val size: 28503


## Helper Functions
### Batch Generator

In [6]:
def get_batch(rng, data, batch_size, block_size):
    """Sample a batch of (input, target) windows for next-token prediction.

    Args:
        rng: JAX PRNG key used to draw the window start offsets.
        data: array of token ids, indexed along its first axis.
        batch_size: number of windows to draw.
        block_size: number of tokens in each window.

    Returns:
        A pair ``(x, y)``. ``x[b]`` is ``data[i:i + block_size]`` for a random
        start ``i`` and ``y[b]`` is the same window shifted one step forward
        (``data[i + 1:i + block_size + 1]``).

    Raises:
        ValueError: if ``data`` holds fewer than ``block_size + 1`` tokens, in
            which case no complete (input, target) pair exists.
    """
    data = jnp.asarray(data)
    n_tokens = len(data)
    if n_tokens <= block_size:
        raise ValueError(
            f"need at least block_size + 1 = {block_size + 1} tokens, "
            f"got {n_tokens}"
        )

    # The upper bound is exclusive, so the largest start is
    # n_tokens - block_size - 1 and the shifted target window still fits.
    starts = jax.random.randint(rng, (batch_size,), 0, n_tokens - block_size)

    # One gather per output instead of a Python loop of per-sample slices:
    # row b of `window` is starts[b], starts[b] + 1, ..., starts[b] + block_size - 1.
    window = starts[:, None] + jnp.arange(block_size)
    return data[window], data[window + 1]

### Training Loop

In [7]:
def train_step(state, x, y):
    """Run one optimisation step and return ``(new_state, loss)``.

    The model held by ``state`` is applied to ``x``; its output is scored
    against the integer labels ``y`` with a mean softmax cross-entropy, and the
    resulting gradients are applied to ``state``.
    """
    def loss_fn(params):
        per_token = optax.softmax_cross_entropy_with_integer_labels(
            state.apply_fn(params, x), y
        )
        return per_token.mean()

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss

## Task 1: Implement Positional Encoding

### Objective: Add positional embeddings to the token embeddings in the model.
### Hint: Use nn.Embed to create a positional encoding table.

In [8]:
class Task1Model(nn.Module):
    """Task 1 scaffold: a token-embedding model awaiting positional encodings.

    As written it only looks up token embeddings; the TODO markers show where
    the positional embedding table and its addition belong.
    """
    # Number of distinct token ids (rows of the token embedding table).
    vocab_size: int
    # Width of each embedding vector.
    hidden_dim: int
    # Declared for the positional table; not used until the TODOs are done.
    block_size: int

    def setup(self):
        """Create the embedding tables used by ``__call__``."""
        self.token_embed = nn.Embed(self.vocab_size, self.hidden_dim)
        # TODO: Add positional encoding here

    def __call__(self, x):
        """Embed a 2-D batch of token ids ``x``."""
        # Unpacking doubles as a check that x has exactly two axes.
        batch, seq_len = x.shape
        token_vectors = self.token_embed(x)  # (batch, seq_len, hidden_dim)
        # TODO: Add positional embeddings
        return token_vectors  # Modify this

## Task 2: Implement Multi-Head Attention

### Objective: Create a multi-head attention layer without using Flax's built-in modules.
### Hint: Split queries, keys, and values into multiple heads.

In [9]:
class MultiHeadAttention(nn.Module):
    """Task 2 scaffold: multi-head attention to be written by hand.

    Until the TODOs are completed the layer returns its input unchanged.
    """
    # Number of attention heads.
    num_heads: int
    # Width of each head.
    head_dim: int

    def setup(self):
        """Placeholder for the query, key and value projection layers."""
        # TODO: Initialize query, key, value projections

    def __call__(self, x):
        """Currently returns the 3-D input ``x`` unchanged."""
        # Unpacking doubles as a check that x has exactly three axes.
        batch, seq_len, channels = x.shape
        # TODO: Split into heads, compute attention, concatenate
        return x  # Modify this

## Task 3: Optimize Training with JIT
### Objective: Use jax.jit to compile the training step for faster execution.
### Hint: Decorate train_step with @jax.jit.

In [11]:
@jax.jit  # <-- Add this decorator
def train_step(state, x, y):
    """JIT-compiled training step returning ``(new_state, loss)``.

    The computation is the same as the un-jitted training step defined
    earlier in the notebook: apply the model to ``x``, take the mean softmax
    cross-entropy against the integer labels ``y``, and apply the gradients to
    ``state``. Only the ``@jax.jit`` decorator is new.
    """
    def loss_fn(params):
        logits = state.apply_fn(params, x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss

## Training and Evaluation
### Model Initialization

In [None]:
# Hyperparameters shared by model construction and batch sampling.
BLOCK_SIZE = 32
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
NUM_STEPS = 1000
LOG_EVERY = 100

# train_step feeds the model output straight into a cross-entropy over token
# ids, so the last axis of that output must have exactly vocab_size entries.
# The model has no separate output projection, so hidden_dim is tied to
# vocab_size and each embedding row serves as the next-token logits.
model = Task1Model(vocab_size=vocab_size, hidden_dim=vocab_size, block_size=BLOCK_SIZE)

# Split before initialising so the key consumed by init is never reused for
# batch sampling.
rng = jax.random.PRNGKey(0)
rng, init_key = jax.random.split(rng)
x = jnp.ones((4, BLOCK_SIZE), dtype=jnp.int32)  # dummy batch, only used to initialise parameters
params = model.init(init_key, x)
state = train_state.TrainState.create(
    apply_fn=model.apply, params=params, tx=optax.adam(LEARNING_RATE)
)

# Training loop
for step in range(NUM_STEPS):
    rng, subkey = jax.random.split(rng)
    x, y = get_batch(subkey, train_data, batch_size=BATCH_SIZE, block_size=BLOCK_SIZE)
    state, loss = train_step(state, x, y)
    if step % LOG_EVERY == 0:
        print(f"Step {step}, Loss: {loss:.4f}")

# Appendix: **Solutions**
## Task 1 Solution

In [None]:
class Task1Solution(nn.Module):
    """Reference solution for Task 1: token embeddings plus learned positions."""
    # Number of distinct token ids (rows of the token embedding table).
    vocab_size: int
    # Width of each embedding vector.
    hidden_dim: int
    # Longest supported sequence (rows of the positional embedding table).
    block_size: int

    def setup(self):
        """Create the token and positional embedding tables."""
        self.token_embed = nn.Embed(self.vocab_size, self.hidden_dim)
        self.pos_embed = nn.Embed(self.block_size, self.hidden_dim)

    def __call__(self, x):
        """Embed a 2-D batch of token ids and add one vector per position.

        Args:
            x: integer array of shape (batch, seq_len).

        Returns:
            Array of shape (batch, seq_len, hidden_dim).

        Raises:
            ValueError: if seq_len exceeds ``block_size``.
        """
        _, seq_len = x.shape
        # JAX does not raise on out-of-range gather indices, so a too-long
        # sequence would silently get wrong positional vectors. Shapes are
        # static, so this check also works under jax.jit.
        if seq_len > self.block_size:
            raise ValueError(
                f"sequence length {seq_len} exceeds block_size {self.block_size}"
            )
        tok_emb = self.token_embed(x)
        pos_emb = self.pos_embed(jnp.arange(seq_len))
        # pos_emb is (seq_len, hidden_dim) and broadcasts over the batch axis.
        return tok_emb + pos_emb

## Task 2 Solution

In [None]:
class MultiHeadAttentionSolution(nn.Module):
    """Reference solution for Task 2: hand-written multi-head self-attention."""
    # Number of attention heads.
    num_heads: int
    # Width of each head; projections have num_heads * head_dim features.
    head_dim: int
    # When True, position t may only attend to positions <= t, as needed for
    # autoregressive language modelling. The default keeps the original
    # unmasked behaviour.
    causal: bool = False

    def setup(self):
        """Create the query/key/value projections and the output projection."""
        inner_dim = self.num_heads * self.head_dim
        self.proj_q = nn.Dense(inner_dim)
        self.proj_k = nn.Dense(inner_dim)
        self.proj_v = nn.Dense(inner_dim)
        self.proj_out = nn.Dense(inner_dim)

    def __call__(self, x):
        """Apply scaled dot-product self-attention to ``x``.

        Args:
            x: array of shape (batch, seq_len, channels).

        Returns:
            Array of shape (batch, seq_len, num_heads * head_dim).
        """
        B, T, _ = x.shape
        per_head = (B, T, self.num_heads, self.head_dim)
        q = self.proj_q(x).reshape(per_head)
        k = self.proj_k(x).reshape(per_head)
        v = self.proj_v(x).reshape(per_head)

        # scores[b, h, i, j]: similarity of query position i with key position j.
        scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(self.head_dim)
        if self.causal:
            # Lower-triangular mask: every row keeps at least its diagonal
            # entry, so the softmax below is always well defined.
            allowed = jnp.tril(jnp.ones((T, T), dtype=bool))
            scores = jnp.where(allowed, scores, jnp.finfo(scores.dtype).min)
        weights = jax.nn.softmax(scores, axis=-1)

        # Weighted sum of values, then merge the head axes back together.
        out = jnp.einsum("bhqk,bkhd->bqhd", weights, v).reshape(B, T, -1)
        return self.proj_out(out)

## Task 3 Solution

In [None]:
@jax.jit  # <-- Add this decorator
def train_step(state, x, y):
    """Task 3 solution: the training step compiled with ``jax.jit``.

    Applies the model in ``state`` to ``x``, scores the output against the
    integer labels ``y`` with a mean softmax cross-entropy, and returns the
    updated state together with the loss.
    """
    def loss_fn(params):
        per_token = optax.softmax_cross_entropy_with_integer_labels(
            state.apply_fn(params, x), y
        )
        return per_token.mean()

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss

## Tests

In [None]:
class TestTasks(unittest.TestCase):
    """Sanity checks for the reference task solutions."""

    def test_positional_encoding(self):
        """Task1Solution keeps the shape and makes positions distinguishable."""
        model = Task1Solution(vocab_size=10, hidden_dim=8, block_size=16)
        x = jnp.ones((2, 16), dtype=jnp.int32)
        params = model.init(jax.random.PRNGKey(0), x)
        output = model.apply(params, x)
        self.assertEqual(output.shape, (2, 16, 8))
        # Every position holds the same token id, so any difference between
        # positions can only come from the positional embedding. A model that
        # skips the positional term would fail here.
        self.assertFalse(bool(jnp.allclose(output[:, 0], output[:, 1])))
        # Both batch rows contain identical tokens and must match exactly.
        self.assertTrue(bool(jnp.allclose(output[0], output[1])))

    def test_multi_head_attention_shape(self):
        """MultiHeadAttentionSolution maps (B, T, C) to (B, T, heads * head_dim)."""
        layer = MultiHeadAttentionSolution(num_heads=2, head_dim=4)
        x = jnp.ones((2, 5, 8), dtype=jnp.float32)
        params = layer.init(jax.random.PRNGKey(0), x)
        output = layer.apply(params, x)
        self.assertEqual(output.shape, (2, 5, 8))

# Run tests
unittest.main(argv=[''], exit=False)