# Building Gemma: A JAX/Flax Implementation Tutorial

This notebook demonstrates how to implement a smaller version of the Gemma language model using JAX and Flax. We'll build the model step by step, explaining key concepts along the way.

## Table of Contents
1. [Setup and Dependencies](#setup)
2. [Model Configuration](#config)
3. [Core Model Components](#components)
4. [Training Infrastructure](#training)
5. [Training Loop and Generation](#generation)

Let's get started!

## 1. Setup and Dependencies <a name="setup"></a>

First, let's import all necessary libraries. We'll be using:
- JAX for numerical computations and automatic differentiation
- Flax for neural network layers
- Optax for optimization
- Other utilities for data handling and tokenization

In [None]:
import math
import os
import pickle
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
import orbax.checkpoint as ocp
from datasets import load_dataset
from flax.training import train_state, orbax_utils
from jax import random, lax, vmap
from tokenizers import SentencePieceUnigramTokenizer
from tqdm import tqdm

# Request the TPU platform and disable XLA client memory preallocation, then
# list the devices JAX reports.
# NOTE(review): no TPU availability check is performed -- 'tpu' is requested
# unconditionally. Both variables are also assigned after `import jax` has
# already executed; confirm JAX still honours JAX_PLATFORM_NAME when it is set
# this late.
os.environ.update({
    'JAX_PLATFORM_NAME': 'tpu',
    'XLA_PYTHON_CLIENT_PREALLOCATE': 'false',
})
print("JAX devices:", jax.devices())

## 2. Model Configuration <a name="config"></a>

The `GemmaConfig` class defines all hyperparameters for our model. We're using a smaller version of the original Gemma model to make it more manageable for educational purposes.

In [None]:
@dataclass(eq=False)
class GemmaConfig:
    """Hyperparameters for the model, its training run, and text generation.

    Every field has a default, so ``GemmaConfig()`` yields the small tutorial
    configuration, while individual values can be overridden by keyword, e.g.
    ``GemmaConfig(dim=512, n_heads=8)``.

    ``eq=False`` keeps equality and hashing identity-based, exactly as they
    were when this was a plain class, so instances stay hashable and mutable.
    """
    # Model architecture
    vocab_size: int = 32000  # Number of entries in the token vocabulary
    dim: int = 256  # Reduced from 512 to save memory
    n_layers: int = 4  # Reduced from 8 to save memory
    n_heads: int = 4  # Reduced from 8 to save memory
    n_kv_heads: int = 1  # For multi-query attention, use 1 KV head
    max_seq_len: int = 512  # Reduced from 2048 to save memory
    dropout_rate: float = 0.0  # Dropout rate

    # RoPE settings
    rope_theta: float = 10000.0  # Base for rotary embeddings

    # Training settings
    batch_size: int = 16  # Sequences per training batch
    learning_rate: float = 3e-4  # Peak learning rate of the schedule
    weight_decay: float = 0.1  # AdamW weight-decay coefficient
    warmup_steps: int = 100  # Steps of learning-rate warmup
    max_steps: int = 10000  # Length of the learning-rate decay schedule

    # Generation settings
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.95

## 3. Core Model Components <a name="components"></a>

### 3.1 RMS Normalization
Root Mean Square Layer Normalization is a simpler alternative to traditional Layer Normalization.

In [None]:
class RMSNorm(nn.Module):
    """Scale activations by the reciprocal of their root mean square.

    Attributes:
        dim: Length of the learned per-feature ``weight`` vector.
        eps: Constant added to the mean square before the square root.
    """
    dim: int
    eps: float = 1e-5

    @nn.compact
    def __call__(self, x):
        """Normalize ``x`` over its last axis, then apply the learned scale."""
        # The scale parameter starts at ones, so the layer initially only
        # normalizes.
        scale = self.param('weight', nn.initializers.ones, (self.dim,))
        mean_square = jnp.square(x).mean(axis=-1, keepdims=True)
        inv_rms = jnp.reciprocal(jnp.sqrt(mean_square + self.eps))
        normalized = x * inv_rms
        return normalized * scale

### 3.2 Rotary Position Embeddings (RoPE)
RoPE is a method for encoding positional information into the key and query vectors of the attention mechanism.

In [None]:
def precompute_freqs_cis(dim: int, max_seq_len: int, theta: float = 10000.0):
    """Precompute the complex rotation factors used by rotary embeddings.

    Args:
        dim: Per-head feature size; features are rotated as ``dim // 2`` pairs.
        max_seq_len: Number of positions to tabulate.
        theta: Base of the geometric frequency progression.

    Returns:
        A ``complex64`` array of shape ``(max_seq_len, dim // 2)`` whose entry
        ``[p, i]`` equals ``exp(1j * p * theta ** (-2 * i / dim))``.
    """
    # RoPE gives pair i the frequency theta^(-2i/dim), so the exponent has to
    # advance in steps of 2 across the feature dimension. Advancing in steps
    # of 1 (0, 1, ..., dim/2 - 1) would never reach the low-frequency half of
    # the spectrum.
    exponents = jnp.arange(0, dim, 2, dtype=jnp.float32)[: dim // 2] / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = jnp.arange(max_seq_len, dtype=jnp.float32)
    angles = jnp.outer(positions, inv_freq)
    return jnp.complex64(jnp.exp(1j * angles))

def apply_rotary_emb(xq, xk, freqs_cis):
    """Rotate query and key features by position-dependent complex factors.

    The last axis of each tensor is read as consecutive (real, imaginary)
    pairs, multiplied by ``freqs_cis``, and written back in the same layout.

    Args:
        xq: Query tensor; its last axis must have even length.
        xk: Key tensor; its last axis must have even length.
        freqs_cis: Complex factors, reshaped here to
            ``(1, freqs_cis.shape[0], 1, freqs_cis.shape[1])`` so they
            broadcast against 4-D inputs.

    Returns:
        A ``(xq, xk)`` tuple of rotated tensors with the input shapes.
    """
    rotation = jnp.reshape(
        freqs_cis, (1, freqs_cis.shape[0], 1, freqs_cis.shape[1])
    )

    def _rotate(tensor):
        # Pair up neighbouring features and view each pair as one complex value.
        pairs = jnp.reshape(tensor, (*tensor.shape[:-1], -1, 2))
        as_complex = jnp.complex64(pairs[..., 0] + 1j * pairs[..., 1])
        rotated = as_complex * rotation
        # Interleave real and imaginary parts back into the original layout.
        interleaved = jnp.stack(
            [jnp.real(rotated), jnp.imag(rotated)], axis=-1
        )
        return interleaved.reshape(tensor.shape)

    return _rotate(xq), _rotate(xk)

### 3.3 Flash Attention
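A jit-compiled implementation of scaled dot-product attention with a numerically stable softmax. Note that, despite the name, this is standard attention that materializes the full score matrix; it is not the tiled, memory-efficient FlashAttention algorithm.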

In [None]:
@jax.jit
def flash_attention(q, k, v, mask=None, scale=None):
    """Jit-compiled scaled dot-product attention.

    Despite the name, this computes ordinary softmax attention: the full
    query-by-key score matrix is formed and normalized in one go.

    Args:
        q: Queries, indexed as ``(b, h, i, d)`` by the einsum below.
        k: Keys, indexed as ``(b, h, j, d)``.
        v: Values, indexed as ``(b, h, j, d)``.
        mask: Optional additive mask broadcastable to the ``(b, h, i, j)``
            scores.
        scale: Optional score multiplier; defaults to ``1 / sqrt(d)``.

    Returns:
        The attention output, indexed as ``(b, h, i, d)``.
    """
    scale = 1.0 / jnp.sqrt(q.shape[-1]) if scale is None else scale

    logits = jnp.einsum('bhid,bhjd->bhij', q, k) * scale
    if mask is not None:
        logits = logits + mask

    # Subtract the row-wise maximum before exponentiating so exp() cannot
    # overflow; the shift is excluded from differentiation.
    row_max = logits.max(axis=-1, keepdims=True)
    shifted = logits - lax.stop_gradient(row_max)

    unnormalized = jnp.exp(shifted)
    probs = unnormalized / unnormalized.sum(axis=-1, keepdims=True)

    return jnp.einsum('bhij,bhjd->bhid', probs, v)

### 3.4 Multi-Query Attention
The attention mechanism with shared key and value heads across query heads for efficiency.

In [None]:
class GemmaCausalSelfAttention(nn.Module):
    """Self-attention whose key/value heads are shared across query heads.

    With ``n_kv_heads`` smaller than ``n_heads`` the key and value heads are
    repeated to match the query heads; ``n_kv_heads == 1`` gives multi-query
    attention.
    """
    config: GemmaConfig

    def setup(self):
        """Create the query, key, value and output projections."""
        cfg = self.config
        head_dim = cfg.dim // cfg.n_heads

        def projection(features):
            # Every projection uses the same fan-in scaled normal initializer.
            return nn.Dense(
                features,
                kernel_init=nn.initializers.variance_scaling(
                    1.0, 'fan_in', 'normal'
                ),
            )

        self.wq = projection(cfg.n_heads * head_dim)
        self.wk = projection(cfg.n_kv_heads * head_dim)
        self.wv = projection(cfg.n_kv_heads * head_dim)
        self.wo = projection(cfg.dim)

    def __call__(self, x, freqs_cis, mask=None, deterministic=True):
        """Attend over ``x`` of shape ``(batch, seq, channels)``.

        Args:
            x: Input activations, unpacked as ``(B, T, C)``.
            freqs_cis: Rotary factors; only the first ``T`` rows are used.
            mask: Optional additive attention mask passed to the kernel.
            deterministic: Accepted for interface symmetry; not used here.

        Returns:
            The output projection applied to the merged heads.
        """
        batch, seq_len, channels = x.shape
        cfg = self.config
        head_dim = channels // cfg.n_heads

        def split_heads(projected, heads):
            return projected.reshape(batch, seq_len, heads, head_dim)

        q = split_heads(self.wq(x), cfg.n_heads)
        k = split_heads(self.wk(x), cfg.n_kv_heads)
        v = split_heads(self.wv(x), cfg.n_kv_heads)

        q, k = apply_rotary_emb(q, k, freqs_cis[:seq_len])

        # Duplicate the shared key/value heads so each query head has a partner.
        if cfg.n_heads > cfg.n_kv_heads:
            group_size = cfg.n_heads // cfg.n_kv_heads
            k = jnp.repeat(k, group_size, axis=2)
            v = jnp.repeat(v, group_size, axis=2)

        # Move the head axis ahead of the sequence axis for the attention kernel.
        q, k, v = (jnp.swapaxes(t, 1, 2) for t in (q, k, v))
        attended = flash_attention(q, k, v, mask)
        merged = jnp.swapaxes(attended, 1, 2).reshape(batch, seq_len, -1)

        return self.wo(merged)

### 3.5 MLP with GeGLU Activation
The feed-forward network using GeGLU activation function.

In [None]:
def geglu(x, w1, w2, w3):
    """Apply a GELU-gated linear unit built from three callables.

    Args:
        x: Input passed to both ``w1`` and ``w3``.
        w1: Callable producing the values to be gated.
        w2: Callable applied to the gated product to form the output.
        w3: Callable producing the gate pre-activation.

    Returns:
        ``w2(gelu(w3(x)) * w1(x))`` using the tanh-approximate GELU.
    """
    gate = jax.nn.gelu(w3(x), approximate=True)
    gated = gate * w1(x)
    return w2(gated)

class GemmaMLP(nn.Module):
    """Position-wise feed-forward network using a GeGLU gate.

    The hidden width is fixed at four times ``config.dim``.
    """
    config: GemmaConfig

    def setup(self):
        """Create the two expanding projections and the contracting one."""
        model_dim = self.config.dim
        hidden_dim = 4 * model_dim

        def dense(features):
            return nn.Dense(
                features,
                kernel_init=nn.initializers.variance_scaling(
                    1.0, 'fan_in', 'normal'
                ),
            )

        self.w1 = dense(hidden_dim)  # values to be gated
        self.w2 = dense(model_dim)   # projection back to the model width
        self.w3 = dense(hidden_dim)  # gate pre-activation

    def __call__(self, x):
        """Return the GeGLU feed-forward transform of ``x``."""
        return geglu(x, self.w1, self.w2, self.w3)

### 3.6 Transformer Block
The complete transformer block combining attention and MLP layers.

In [None]:
class GemmaBlock(nn.Module):
    """One pre-normalized transformer layer: attention, then feed-forward.

    Each sub-layer reads a normalized copy of its input and its (optionally
    dropped-out) output is added back onto the un-normalized residual stream.
    """
    config: GemmaConfig

    def setup(self):
        """Create the norms, the two sub-layers and the shared dropout."""
        cfg = self.config
        self.attention_norm = RMSNorm(cfg.dim)
        self.attention = GemmaCausalSelfAttention(cfg)
        self.ffn_norm = RMSNorm(cfg.dim)
        self.ffn = GemmaMLP(cfg)
        self.dropout = nn.Dropout(cfg.dropout_rate)

    def __call__(self, x, freqs_cis, mask=None, deterministic=True):
        """Run the block on ``x``.

        Args:
            x: Residual-stream activations.
            freqs_cis: Rotary factors forwarded to the attention layer.
            mask: Optional additive attention mask.
            deterministic: When true, dropout is disabled.

        Returns:
            The updated residual-stream activations.
        """
        # Attention sub-layer with residual connection.
        attn_out = self.attention(
            self.attention_norm(x), freqs_cis, mask, deterministic
        )
        residual = x + self.dropout(attn_out, deterministic=deterministic)

        # Feed-forward sub-layer with residual connection.
        ffn_out = self.ffn(self.ffn_norm(residual))
        return residual + self.dropout(ffn_out, deterministic=deterministic)

### 3.7 Complete Gemma Model
The full language model combining all components.

In [None]:
class Gemma(nn.Module):
    """Decoder-only language model: embedding, transformer stack, LM head."""
    config: GemmaConfig

    def setup(self):
        """Create all sub-modules and the rotary-embedding table."""
        cfg = self.config

        self.token_embedding = nn.Embed(
            cfg.vocab_size,
            cfg.dim,
            embedding_init=nn.initializers.normal(stddev=0.02),
        )

        self.blocks = [GemmaBlock(cfg) for _ in range(cfg.n_layers)]
        self.norm_f = RMSNorm(cfg.dim)
        self.lm_head = nn.Dense(
            cfg.vocab_size,
            kernel_init=nn.initializers.normal(stddev=0.02),
            use_bias=False,
        )

        # Rotary factors are tabulated once for the per-head feature size.
        head_dim = cfg.dim // cfg.n_heads
        self.freqs_cis = precompute_freqs_cis(
            head_dim, cfg.max_seq_len, cfg.rope_theta
        )

    def __call__(self, input_ids, deterministic=True):
        """Map token ids of shape ``(batch, seq)`` to vocabulary logits.

        Args:
            input_ids: Integer token ids, unpacked as ``(B, T)``.
            deterministic: When true, dropout is disabled in every block.

        Returns:
            The language-model head applied to the final normalized states.
        """
        _, seq_len = input_ids.shape
        max_len = self.config.max_seq_len

        # Additive causal mask: 0 on and below the diagonal, the most negative
        # float32 above it; cropped to the current sequence length and given
        # leading batch and head axes.
        lower_triangle = jnp.tril(jnp.ones((max_len, max_len)))
        additive = jnp.where(
            lower_triangle == 0, jnp.finfo(jnp.float32).min, 0.0
        )
        mask = additive[None, None, :seq_len, :seq_len]

        hidden = self.token_embedding(input_ids)
        for layer in self.blocks:
            hidden = layer(hidden, self.freqs_cis, mask, deterministic)

        return self.lm_head(self.norm_f(hidden))

## 4. Training Infrastructure <a name="training"></a>

Let's set up the training infrastructure including data preparation and optimization.

In [None]:
def create_train_state(model, config, rng_key):
    """Initialize the model variables and wrap them with an optimizer.

    Args:
        model: Module providing ``init`` and ``apply``.
        config: Supplies ``learning_rate``, ``warmup_steps``, ``max_steps``
            and ``weight_decay``.
        rng_key: PRNG key used to initialize the variables.

    Returns:
        A ``TrainState`` whose ``params`` field holds the full variables
        dict returned by ``model.init`` (the weights sit under its
        ``'params'`` key).
    """
    # A single dummy token is enough to trace shapes for initialization.
    dummy_tokens = jnp.ones((1, 1), dtype=jnp.int32)
    variables = model.init(rng_key, dummy_tokens)

    # Linear warmup from zero to the peak rate, then cosine decay to 10% of it.
    peak_lr = config.learning_rate
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=peak_lr,
        warmup_steps=config.warmup_steps,
        decay_steps=config.max_steps,
        end_value=peak_lr * 0.1,
    )

    # Gradients are clipped to unit global norm before the AdamW update.
    adamw = optax.adamw(
        learning_rate=schedule,
        b1=0.9,
        b2=0.95,
        eps=1e-8,
        weight_decay=config.weight_decay,
    )
    tx = optax.chain(optax.clip_by_global_norm(1.0), adamw)

    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables,
        tx=tx,
    )

def prepare_datasets(config):
    """Load the corpus, train a tokenizer on it, and tokenize everything.

    The tokenizer is trained on at most the first 10000 rows and saved to
    ``gemma_tokenizer.json`` in the working directory.

    Args:
        config: Supplies ``vocab_size`` for tokenizer training.

    Returns:
        A ``(data, tokenizer)`` tuple where ``data["input_ids"]`` is one flat
        list of token ids covering every non-empty row.
    """
    text_column = "text"
    raw_dataset = load_dataset("karpathy/tiny_shakespeare", split="train")
    tokenizer = SentencePieceUnigramTokenizer()

    def _usable(text):
        # Rows must be strings containing something other than whitespace.
        return isinstance(text, str) and text.strip()

    # Collect the tokenizer's training sample.
    sample_size = min(10000, len(raw_dataset))
    sample_texts = []
    for row in raw_dataset.select(range(sample_size)):
        candidate = row[text_column]
        if _usable(candidate):
            sample_texts.append(candidate)

    tokenizer.train_from_iterator(
        sample_texts,
        vocab_size=config.vocab_size,
        special_tokens=["<pad>", "<unk>", "<bos>", "<eos>"],
    )
    tokenizer.save("gemma_tokenizer.json")

    def tokenize_function(example):
        text = example[text_column]
        if not _usable(text):
            return {"input_ids": []}
        return {"input_ids": tokenizer.encode(text).ids}

    tokenized = raw_dataset.map(
        tokenize_function,
        remove_columns=list(raw_dataset.features.keys()),
        batched=False,
    )
    # Drop rows that produced no tokens.
    tokenized = tokenized.filter(lambda row: len(row["input_ids"]) > 0)

    # Concatenate all rows into one token stream.
    all_tokens = [
        token for row in tokenized for token in row["input_ids"]
    ]

    return {"input_ids": all_tokens}, tokenizer

## 5. Training Loop and Generation <a name="generation"></a>

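Now let's implement the training step and a short demonstration training loop. Text generation is not implemented in the code below, so the generation settings in `GemmaConfig` (`temperature`, `top_k`, `top_p`) are not yet used.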

In [None]:
def train_step(state, batch, dropout_rng):
    """Run one optimization step on a batch.

    Args:
        state: Train state whose ``params`` is a dict with the model weights
            under the ``'params'`` key.
        batch: An ``(inputs, targets)`` pair.
        dropout_rng: PRNG key supplied to the ``'dropout'`` stream.

    Returns:
        A ``(new_state, metrics)`` tuple; ``metrics['loss']`` is the mean
        cross-entropy over all target positions.
    """
    inputs, targets = batch

    def loss_fn(params):
        logits = state.apply_fn(
            {'params': params},
            inputs,
            deterministic=False,
            rngs={'dropout': dropout_rng},
        )
        # Flatten every leading axis so each row scores one target token.
        vocab_size = logits.shape[-1]
        per_token = optax.softmax_cross_entropy_with_integer_labels(
            logits.reshape(-1, vocab_size), targets.reshape(-1)
        )
        return per_token.mean()

    # Differentiate with respect to the weights only, then re-wrap the
    # gradients so their structure matches ``state.params``.
    loss, grads = jax.value_and_grad(loss_fn)(state.params['params'])
    updated_state = state.apply_gradients(grads={'params': grads})

    return updated_state, {'loss': loss}

def get_batch(rng_key, dataset, config):
    """Sample a batch of next-token-prediction examples from a token stream.

    The training loop below calls ``get_batch`` but no definition is visible
    in this file, so this one is provided.

    Args:
        rng_key: PRNG key used to draw the window start positions.
        dataset: Mapping whose ``"input_ids"`` entry is a flat sequence of
            token ids.
        config: Supplies ``batch_size`` and ``max_seq_len``.

    Returns:
        An ``(inputs, targets)`` pair of int32 arrays, each of shape
        ``(batch_size, max_seq_len)``; ``targets`` is ``inputs`` shifted one
        token to the left.

    Raises:
        ValueError: If the stream has fewer than ``max_seq_len + 1`` tokens.
    """
    tokens = jnp.asarray(dataset["input_ids"], dtype=jnp.int32)
    seq_len = config.max_seq_len

    # Each window spans seq_len + 1 tokens so inputs and targets can overlap.
    num_starts = tokens.shape[0] - seq_len
    if num_starts <= 0:
        raise ValueError(
            f"Need at least {seq_len + 1} tokens to build a batch, "
            f"got {tokens.shape[0]}"
        )

    starts = random.randint(rng_key, (config.batch_size,), 0, num_starts)
    offsets = jnp.arange(seq_len + 1)
    windows = tokens[starts[:, None] + offsets[None, :]]
    return windows[:, :-1], windows[:, 1:]


# Example usage
if __name__ == "__main__":
    # Initialize config and model
    config = GemmaConfig()
    model = Gemma(config)

    # Create training state
    rng_key = random.PRNGKey(0)
    state = create_train_state(model, config, rng_key)

    # Prepare dataset
    train_dataset, tokenizer = prepare_datasets(config)
    # Convert the token list to an array once, so it is not re-converted on
    # every step inside get_batch.
    train_dataset = {
        "input_ids": jnp.asarray(train_dataset["input_ids"], dtype=jnp.int32)
    }

    # Training loop (simplified for demonstration)
    num_steps = 100
    for step in range(num_steps):
        # Fresh keys for batch sampling and dropout on every step
        rng_key, data_key, dropout_key = random.split(rng_key, 3)
        batch = get_batch(data_key, train_dataset, config)

        # Training step
        state, metrics = train_step(state, batch, dropout_key)

        if step % 10 == 0:
            print(f"Step {step}, Loss: {metrics['loss']:.4f}")

    print("Training complete!")

## Conclusion

This notebook has demonstrated how to implement a smaller version of the Gemma language model using JAX and Flax. Key takeaways:

1. The model uses modern architecture components like RMSNorm and Rotary Position Embeddings
2. Multi-query attention helps reduce memory usage while maintaining model quality
3. The implementation is optimized for TPU/GPU acceleration using JAX

To extend this further, you could:
- Scale up the model size
- Implement more sophisticated training techniques
- Add model parallel training support
- Implement better tokenization strategies