# 🧠 Gemma2 Reasoning GRPO - Full Production

**Google Tunix Hackathon 2026**

### Core Architecture:
- **Hardware:** TPU v5e-8 (2 Data x 4 Tensor Mesh)
- **Model:** Gemma 2 2B-IT (CPU-Offloaded bf16 Init)
- **Fine-Tuning:** LoRA on Attention + MLP layers
- **Strategy:** GRPO with 16 Parallel Generations per prompt
- **Rewards:** Format (25%) + Logic (30%) + Accuracy (45%) + Self-Correction & Length Regularization

## 📦 Cell 1: Environment Setup
*Run once, then restart kernel*

In [None]:
# One-time environment bootstrap: installs the JAX/TPU stack, Tunix, Qwix and
# supporting packages, then writes a marker file so later runs skip the work.
import os
# Presence of this file means the install branch below has already run.
SETUP_MARKER = "/kaggle/working/.setup_complete_prod_v1"

if os.path.exists(SETUP_MARKER):
    print("‚úÖ Packages ready. Continuing...")
else:
    print("="*60)
    print("SETTING UP PRODUCTION ENVIRONMENT...")
    print("="*60)
    
    # Use %pip for notebooks (recommended by Tunix docs)
    %pip install --upgrade pip -q
    # Remove any preinstalled jax/jaxlib/flax/optax before installing the TPU
    # build; stderr is discarded so "not installed" warnings stay quiet.
    %pip uninstall -y jax jaxlib flax optax -q 2>/dev/null
    
    print("Installing JAX/TPU stack...")
    %pip install -q "jax[tpu]>=0.8.0" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    %pip install -q "numpy==2.0.0" "pyarrow==17.0.0"
    
    print("Installing Tunix & Qwix from GitHub (latest)...")
    %pip install -q git+https://github.com/google/tunix
    %pip install -q git+https://github.com/google/qwix
    # Drop whichever flax the installs above pulled in and replace it with the
    # GitHub head.
    %pip uninstall -q flax -y 2>/dev/null
    %pip install -q git+https://github.com/google/flax
    
    print("Installing supporting packages...")
    %pip install -q kagglehub transformers grain huggingface_hub tensorflow tensorflow_datasets datasets
    
    # NOTE(review): the marker is written unconditionally after the installs;
    # presumably a failing %pip does not raise, so a partial install would
    # still be recorded as complete -- confirm.
    with open(SETUP_MARKER, "w") as f: 
        f.write("done")
    print("\n" + "="*60)
    print("‚úÖ SETUP COMPLETE. RESTART KERNEL NOW!")
    print("="*60)

## 🔌 Cell 2: Imports & Memory Monitor

In [None]:
import os, re, gc, time, shutil
from pathlib import Path
from typing import List, Dict, Any

import jax
import jax.numpy as jnp
import numpy as np
import optax
import grain
import qwix
import kagglehub

from flax import nnx
from orbax import checkpoint as ocp
from datasets import load_dataset
from transformers import AutoTokenizer

# Tunix imports
try:
    from tunix.models.gemma2 import model as gemma_model_lib
    from tunix.models.gemma2 import params as gemma_params_lib
    print("‚úì Using tunix.models.gemma2")
except ImportError:
    from tunix.models.gemma import model as gemma_model_lib
    from tunix.models.gemma import params as gemma_params_lib
    print("‚úì Using tunix.models.gemma (fallback)")

from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.generate import sampler as sampler_lib

class MemoryMonitor:
    """Report aggregate accelerator memory usage across all visible JAX devices.

    All methods are static; the class only serves as a namespace.
    """

    @staticmethod
    def get_usage():
        """Return ``(bytes_in_use, bytes_limit)`` summed over all JAX devices.

        Devices whose ``memory_stats()`` is empty or ``None`` are skipped.

        Returns:
            A ``(used, limit)`` tuple in bytes, or ``(0, 0)`` when no device
            reports statistics or querying them fails.
        """
        try:
            # Query each device exactly once and keep only non-empty results.
            stats = [s for s in (d.memory_stats() for d in jax.devices()) if s]
            if stats:
                used = sum(s['bytes_in_use'] for s in stats)
                limit = sum(s['bytes_limit'] for s in stats)
                return used, limit
        except Exception:
            # Best effort: monitoring must never break the notebook, but
            # KeyboardInterrupt / SystemExit are deliberately not swallowed.
            pass
        return 0, 0

    @staticmethod
    def print_summary():
        """Print a one-line usage summary, or a notice if stats are missing."""
        used, limit = MemoryMonitor.get_usage()
        if limit > 0:
            print(f"  TPU Memory: {used/1e9:.2f}GB / {limit/1e9:.2f}GB ({100*used/limit:.1f}%)")
        else:
            print("  TPU Memory: stats unavailable")

    @staticmethod
    def check_available(required_gb=20):
        """Return True when at least ``required_gb`` (1e9 bytes each) is free.

        When statistics are unavailable ``get_usage()`` yields ``(0, 0)``, so
        this returns False for any positive ``required_gb``.
        """
        used, limit = MemoryMonitor.get_usage()
        available = (limit - used) / 1e9
        return available >= required_gb

# Shared monitor instance reused by later cells; log the JAX version, the
# number of visible devices and the current memory usage.
monitor = MemoryMonitor()
print(f"JAX: {jax.__version__} | TPU Cores: {len(jax.devices())}")
monitor.print_summary()

## ⚙️ Cell 3: Production Configuration
*16 Generations | 600 Steps | Optimized for TPU v5e-8*

In [None]:
# =============================================================================
# MODEL CONFIGURATION
# =============================================================================
MODEL_VERSION = "gemma2-2b-it"
MODEL_PATH = "google/gemma-2/flax/gemma2-2b-it"
MODEL_HF_NAME = "google/gemma-2-2b-it"

# =============================================================================
# MESH CONFIGURATION
# TPU v5e-8: 8 cores, (2,4) mesh = 2 FSDP x 4 TP
# tp=4 splits the 4 KV-heads across 4 cores (1 head per core)
# fsdp=2 splits data/weights across 2 groups
# =============================================================================
MESH_SHAPE = (2, 4)
MESH_AXES = ("fsdp", "tp")

# =============================================================================
# LORA CONFIGURATION
# Targeting Attention (q, k, v, o) + MLP (gate, up, down) layers
# =============================================================================
LORA_RANK = 64
LORA_ALPHA = 64.0
# Regex over module paths. The Tunix Gemma attention block names its output
# projection `attn_vec_einsum` (there is no `o_proj` module), so it is listed
# explicitly; `o_proj` is kept so checkpoints using that naming still match.
LORA_TARGET_PATTERN = (
    ".*q_einsum|.*kv_einsum|.*attn_vec_einsum|.*o_proj"
    "|.*gate_proj|.*up_proj|.*down_proj"
)
# =============================================================================
# GRPO CONFIGURATION (Production Settings)
# G=16: Generate 16 responses per prompt for statistical power
# μ=3: 3 iterations per batch for stable updates
# β=0.04: KL penalty coefficient
# ε=0.2: PPO-style clipping
# =============================================================================
NUM_GENERATIONS = 16    # G in GRPO paper - 16 parallel generations
NUM_ITERATIONS = 3      # μ in GRPO paper - iterations per batch
BETA = 0.04             # KL divergence penalty
EPSILON = 0.2           # Clipping parameter

# =============================================================================
# TRAINING CONFIGURATION
# =============================================================================
MAX_STEPS = 600
LEARNING_RATE = 2e-6
WARMUP_STEPS = 40
WEIGHT_DECAY = 0.01

# Batch sizes - carefully tuned for 16 generations
# full_batch = num_samples × num_generations must be divisible by mini_batch
MINI_BATCH_SIZE = 8
MICRO_BATCH_SIZE = 1    # Gradient accumulation for memory efficiency

# =============================================================================
# SEQUENCE CONFIGURATION
# =============================================================================
MAX_PROMPT_LENGTH = 256
MAX_GENERATION_LENGTH = 768  # Allow long reasoning chains

# =============================================================================
# REWARD WEIGHTS (Total = 100%)
# =============================================================================
REWARD_WEIGHT_FORMAT = 0.25      # 25% - Correct XML structure
REWARD_WEIGHT_LOGIC = 0.30       # 30% - Quality reasoning with transitions
REWARD_WEIGHT_ACCURACY = 0.45    # 45% - Correct final answer
# Self-correction and length are additive bonuses (not weighted)

# =============================================================================
# OUTPUT TAGS
# =============================================================================
TAG_REASONING_START = "<reasoning>"
TAG_REASONING_END = "</reasoning>"
TAG_ANSWER_START = "<answer>"
TAG_ANSWER_END = "</answer>"

# =============================================================================
# PATHS
# =============================================================================
CKPT_DIR = "/kaggle/working/checkpoints"
OUTPUT_DIR = "/kaggle/working/gemma2-grpo-16gen"

# =============================================================================
# DATA CONFIGURATION
# num_samples must create full_batch divisible by MINI_BATCH_SIZE
# With G=16 and the values below: 448×16=7168 and 64×16=1024, both divisible by 8
# =============================================================================
NUM_TRAIN_SAMPLES = 448  # 448 × 16 = 7168, divisible by 8
NUM_TEST_SAMPLES = 64    # 64 × 16 = 1024, divisible by 8
RANDOM_SEED = 42

# Create the checkpoint and output directories up front (no error if present).
os.makedirs(CKPT_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Echo the effective configuration so it is captured in the notebook output.
print("="*60)
print("PRODUCTION CONFIGURATION")
print("="*60)
print(f"Model: {MODEL_HF_NAME}")
print(f"Mesh: {MESH_SHAPE} ({MESH_AXES})")
print(f"LoRA: rank={LORA_RANK}, alpha={LORA_ALPHA}")
print(f"GRPO: G={NUM_GENERATIONS}, Œº={NUM_ITERATIONS}, Œ≤={BETA}, Œµ={EPSILON}")
print(f"Training: {MAX_STEPS} steps, LR={LEARNING_RATE}")
print(f"Batch: mini={MINI_BATCH_SIZE}, micro={MICRO_BATCH_SIZE}")
print(f"Data: {NUM_TRAIN_SAMPLES} train, {NUM_TEST_SAMPLES} test")
print(f"Rewards: Format={REWARD_WEIGHT_FORMAT*100}%, Logic={REWARD_WEIGHT_LOGIC*100}%, Accuracy={REWARD_WEIGHT_ACCURACY*100}%")
print("="*60)

## 🎯 Cell 4: Weighted Reward Functions
*Format (25%) + Logic (30%) + Accuracy (45%) + Bonuses*

In [None]:
def format_reward(prompts, completions, **kwargs):
    """Score each completion on its XML tag structure (25% weight).

    The expected layout is ``<reasoning>...</reasoning>`` followed by
    ``<answer>...</answer>``. Raw scores, before weighting:

    * 1.0 - all four tags present and, by first occurrence, in the order
      reasoning-start < reasoning-end < answer-start < answer-end
    * 0.5 - all four tags present but in any other order
    * 0.3 - only one complete pair (reasoning or answer) present
    * 0.0 - otherwise

    Args:
        prompts: Unused; accepted for the reward-function call signature.
        completions: Iterable of generated strings.
        **kwargs: Unused.

    Returns:
        List of floats, one per completion, scaled by REWARD_WEIGHT_FORMAT.
    """
    rewards = []
    for c in completions:
        r_start = c.find(TAG_REASONING_START)
        r_end = c.find(TAG_REASONING_END)
        a_start = c.find(TAG_ANSWER_START)
        a_end = c.find(TAG_ANSWER_END)

        has_reasoning = r_start != -1 and r_end != -1
        has_answer = a_start != -1 and a_end != -1

        if has_reasoning and has_answer:
            # Check the whole chain, not just </reasoning> vs <answer>;
            # otherwise a closing tag placed before its opening tag would
            # still earn the full format reward.
            score = 1.0 if r_start < r_end < a_start < a_end else 0.5
        elif has_reasoning or has_answer:
            score = 0.3
        else:
            score = 0.0

        rewards.append(score * REWARD_WEIGHT_FORMAT)
    return rewards


def logic_reward(prompts, completions, **kwargs):
    """Score the quality of the reasoning section (30% weight).

    Only the text between the reasoning tags is inspected (lower-cased).
    The raw score is the sum of four capped components:

    * distinct transition words, 0.05 each, capped at 0.4
    * explicit steps such as "step 1", 0.1 each, capped at 0.3
    * arithmetic expressions such as "5 × 10", 0.05 each, capped at 0.2
    * occurrences of "=" or "equals", 0.02 each, capped at 0.1

    Args:
        prompts: Unused; accepted for the reward-function call signature.
        completions: Iterable of generated strings.
        **kwargs: Unused.

    Returns:
        List of floats, one per completion, scaled by REWARD_WEIGHT_LOGIC.
        Completions without a reasoning section score 0.0.
    """
    TRANSITIONS = [
        'therefore', 'because', 'since', 'so', 'thus', 'hence',
        'first', 'second', 'third', 'then', 'next', 'finally',
        'step', 'calculate', 'compute', 'multiply', 'divide', 'add', 'subtract'
    ]

    # Whole-word matching: a plain substring test would count "so" inside
    # "person" or "add" inside "address". A trailing s/d/ed/ing is tolerated
    # so that "steps", "divided" or "adding" still count.
    transition_re = re.compile(
        r'\b(' + '|'.join(TRANSITIONS) + r')(?:s|d|ed|ing)?\b'
    )
    # Tags are escaped so they are always treated as literal text.
    reasoning_re = re.compile(
        re.escape(TAG_REASONING_START) + r'(.*?)' + re.escape(TAG_REASONING_END),
        re.DOTALL | re.IGNORECASE,
    )
    step_re = re.compile(r'step\s*\d')
    # \u00d7 is the multiplication sign and \u00f7 the division sign; written
    # as escapes so the pattern survives a mis-decoded source file.
    math_re = re.compile(r'\d+\s*[+\-\u00d7\u00f7*/]\s*\d+')

    rewards = []
    for c in completions:
        score = 0.0

        match = reasoning_re.search(c)
        if match:
            content = match.group(1).lower()

            # Distinct transition words (up to 0.4)
            trans_count = len({m.group(1) for m in transition_re.finditer(content)})
            score += min(0.4, trans_count * 0.05)

            # Explicit steps like "Step 1:", "Step 2:" (up to 0.3)
            step_count = len(step_re.findall(content))
            score += min(0.3, step_count * 0.1)

            # Arithmetic expressions (up to 0.2)
            math_ops = len(math_re.findall(content))
            score += min(0.2, math_ops * 0.05)

            # "equals" or "=" showing computation (up to 0.1)
            equals_count = content.count('=') + content.count('equals')
            score += min(0.1, equals_count * 0.02)

        rewards.append(min(1.0, score) * REWARD_WEIGHT_LOGIC)
    return rewards


def accuracy_reward(prompts, completions, **kwargs):
    """Score the final numeric answer against the ground truth (45% weight).

    The last number inside the answer tags is compared with the last number
    in the matching ground-truth string. Raw scores, before weighting:

    * 1.0 - exact match
    * 0.9 - within 1% of the truth
    * 0.5 - within 10% of the truth
    * 0.0 - otherwise, or when no answer / no ground truth is available

    Thousands separators ("1,000") are removed before parsing; without that
    step "1,000" would be read as the two numbers 1 and 000 and a correct
    answer would score zero.

    Args:
        prompts: Unused; accepted for the reward-function call signature.
        completions: Iterable of generated strings.
        **kwargs: ``answers`` holds the ground-truth values, index-aligned
            with ``completions``.

    Returns:
        List of floats, one per completion, scaled by REWARD_WEIGHT_ACCURACY.
    """
    answers = kwargs.get('answers')
    if answers is None:
        answers = []

    # Tags are escaped so they are always treated as literal text.
    answer_re = re.compile(
        re.escape(TAG_ANSWER_START) + r'(.*?)' + re.escape(TAG_ANSWER_END),
        re.DOTALL,
    )
    number_re = re.compile(r'-?\d+\.?\d*')
    # A comma between a digit and exactly three digits is a thousands
    # separator; commas in lists such as "3, 4" are left untouched.
    thousands_re = re.compile(r'(?<=\d),(?=\d{3}(?!\d))')

    def last_number(text):
        """Return the last number in ``text`` as a float, or None."""
        found = number_re.findall(thousands_re.sub('', text))
        return float(found[-1]) if found else None

    rewards = []
    for i, c in enumerate(completions):
        score = 0.0

        match = answer_re.search(c) if i < len(answers) else None
        if match:
            pred = last_number(match.group(1).strip())
            truth = last_number(str(answers[i]).strip())

            if pred is not None and truth is not None:
                if pred == truth:
                    score = 1.0  # Exact match
                elif abs(truth) > 0.001:
                    # The guard above also rules out division by zero.
                    ratio = pred / truth
                    if 0.99 <= ratio <= 1.01:
                        score = 0.9  # Within 1%
                    elif 0.9 <= ratio <= 1.1:
                        score = 0.5  # Within 10%

        rewards.append(score * REWARD_WEIGHT_ACCURACY)
    return rewards


def self_correction_reward(prompts, completions, **kwargs):
    """Additive bonus (at most 0.1) for visible self-correction.

    Each distinct marker phrase found anywhere in the lower-cased completion
    (plain substring test) contributes 0.03; the total is capped at 0.1.

    Args:
        prompts: Unused; accepted for the reward-function call signature.
        completions: Iterable of generated strings.
        **kwargs: Unused.

    Returns:
        List of floats, one per completion.
    """
    markers = (
        "wait", "actually", "let me recalculate", "i made an error",
        "correction", "re-checking", "that's not right", "let me redo",
    )

    def bonus_for(text):
        lowered = text.lower()
        hits = len([m for m in markers if m in lowered])
        # Never exceed the 0.1 ceiling.
        return min(0.1, hits * 0.03)

    return [bonus_for(text) for text in completions]


def length_reward(prompts, completions, **kwargs):
    """Additive length bonus (at most 0.05) based on whitespace word count.

    * 100-400 words: 0.05 (ideal)
    * 50-99 or 401-600 words: 0.02 (acceptable)
    * anything else: 0.0

    Args:
        prompts: Unused; accepted for the reward-function call signature.
        completions: Iterable of generated strings.
        **kwargs: Unused.

    Returns:
        List of floats, one per completion.
    """
    def bonus_for(text):
        n_words = len(text.split())
        if n_words < 50 or n_words > 600:
            return 0.0   # Too short or too long
        if 100 <= n_words <= 400:
            return 0.05  # Ideal length
        return 0.02      # Acceptable

    return [bonus_for(text) for text in completions]


# Combine all reward functions into the list that is later handed to the
# GRPO learner as `reward_fns`. The three weighted rewards are already scaled
# inside their own functions; the last two return small unweighted bonuses.
REWARD_FUNCTIONS = [
    format_reward,      # 25% weight
    logic_reward,       # 30% weight  
    accuracy_reward,    # 45% weight
    self_correction_reward,  # Bonus up to 10%
    length_reward,      # Bonus up to 5%
]

# Human-readable recap of the reward mix.
print("‚úÖ Reward Functions Defined")
print(f"   ‚Ä¢ format_reward: {REWARD_WEIGHT_FORMAT*100}% weight")
print(f"   ‚Ä¢ logic_reward: {REWARD_WEIGHT_LOGIC*100}% weight")
print(f"   ‚Ä¢ accuracy_reward: {REWARD_WEIGHT_ACCURACY*100}% weight")
print(f"   ‚Ä¢ self_correction_reward: bonus (max 10%)")
print(f"   ‚Ä¢ length_reward: bonus (max 5%)")
print(f"   Total potential: {(REWARD_WEIGHT_FORMAT + REWARD_WEIGHT_LOGIC + REWARD_WEIGHT_ACCURACY)*100}% + 15% bonus")

## 📊 Cell 5: Data Loading (GSM8K)

In [None]:
# System prompt teaching the expected output format. It is prepended to every
# question by format_prompt(), so every character here is seen by the model.
# The worked example uses a real multiplication sign; a mis-encoded character
# sequence there would be taught to the model as part of the target format.
SYSTEM_PROMPT = f"""You are a mathematical problem solver who shows their work.

For each problem:
1. Think through it step-by-step inside {TAG_REASONING_START} and {TAG_REASONING_END} tags
2. Show your calculations clearly
3. Give your final numerical answer inside {TAG_ANSWER_START} and {TAG_ANSWER_END} tags

Example:
{TAG_REASONING_START}
Step 1: Identify what we need to find.
Step 2: Set up the calculation.
Step 3: 5 × 10 = 50
Therefore, the answer is 50.
{TAG_REASONING_END}
{TAG_ANSWER_START}
50
{TAG_ANSWER_END}"""

def format_prompt(question: str) -> str:
    """Return the full model prompt for one question.

    Layout: SYSTEM_PROMPT, a blank line, ``Question: <question>``, then a
    trailing ``Solution:`` cue on its own line.
    """
    template = "{}\n\nQuestion: {}\nSolution:"
    return template.format(SYSTEM_PROMPT, question)

def extract_answer(answer_text: str) -> str:
    """Return the final answer of a GSM8K solution string.

    The text after the last ``####`` marker is returned with surrounding
    whitespace removed; without a marker the whole stripped text is returned.
    """
    # Without a marker, split() yields a single-element list, so taking the
    # last piece covers both the "marker present" and "marker absent" cases.
    pieces = answer_text.split('####')
    return pieces[-1].strip()

# Fetch the GSM8K training split and shuffle it reproducibly.
print("Loading GSM8K dataset...")
dataset = load_dataset("openai/gsm8k", "main", split="train").shuffle(seed=RANDOM_SEED)

# Keep only as many examples as the train and test subsets need; zip() stops
# as soon as the range is exhausted, so the rest of the dataset is not read.
_total_needed = NUM_TRAIN_SAMPLES + NUM_TEST_SAMPLES
all_data = [
    {
        'prompt': format_prompt(item['question']),
        'answer': extract_answer(item['answer'])
    }
    for _, item in zip(range(_total_needed), dataset)
]

# The first NUM_TRAIN_SAMPLES records train the model, the remainder are held
# out (all_data never holds more than _total_needed records).
train_data = all_data[:NUM_TRAIN_SAMPLES]
test_data = all_data[NUM_TRAIN_SAMPLES:]

# Number of completions each subset produces at NUM_GENERATIONS per prompt.
full_batch_train = NUM_GENERATIONS * len(train_data)
full_batch_test = NUM_GENERATIONS * len(test_data)

print(f"‚úÖ Data Loaded")
print(f"   Train: {len(train_data)} samples √ó {NUM_GENERATIONS} gen = {full_batch_train} (√∑{MINI_BATCH_SIZE}={full_batch_train//MINI_BATCH_SIZE})")
print(f"   Test: {len(test_data)} samples √ó {NUM_GENERATIONS} gen = {full_batch_test} (√∑{MINI_BATCH_SIZE}={full_batch_test//MINI_BATCH_SIZE})")

## 🔧 Cell 6: Mesh & Tokenizer

In [None]:
# Create the device mesh for distributed training: MESH_SHAPE devices laid out
# along MESH_AXES, every axis using the Auto axis type.
mesh = jax.make_mesh(
    MESH_SHAPE, 
    MESH_AXES, 
    axis_types=(jax.sharding.AxisType.Auto,) * len(MESH_SHAPE)
)
print(f"‚úì Mesh created: {mesh}")

# Get the HuggingFace token: environment variable first, Kaggle Secrets as a
# fallback. A missing token is only reported; loading is still attempted.
hf_token = os.environ.get('HF_TOKEN')
if not hf_token:
    try:
        from kaggle_secrets import UserSecretsClient
        hf_token = UserSecretsClient().get_secret("HF_TOKEN")
        print("‚úì HF Token from Kaggle Secrets")
    except Exception as e:
        print(f"‚ö†Ô∏è No HF_TOKEN found: {e}")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_HF_NAME, token=hf_token)

# Token ids that stop generation: the tokenizer's EOS plus, when the
# vocabulary has it, the chat "<end_of_turn>" marker.
EOS_TOKENS = [tokenizer.eos_token_id]
try:
    end_turn_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
    # Skip unknown-token fallbacks, missing ids and duplicates of the EOS id.
    if (
        end_turn_id is not None
        and end_turn_id != tokenizer.unk_token_id
        and end_turn_id not in EOS_TOKENS
    ):
        EOS_TOKENS.append(end_turn_id)
except Exception as e:
    # Best effort: the extra stop token is optional, but say why it is
    # missing instead of failing silently (and never swallow Ctrl-C).
    print(f"   Could not resolve <end_of_turn> token id: {e}")

print(f"‚úÖ Tokenizer Ready | EOS tokens: {EOS_TOKENS}")

## 🤖 Cell 7: CPU-Offload Model Loading + LoRA
*Build on CPU first to prevent TPU OOM during initialization*

In [None]:
# =============================================================================
# STEP 1: Clean memory
# =============================================================================
# Drop unreachable Python objects and JAX's compilation caches before loading.
gc.collect()
jax.clear_caches()
print("1. Memory cleared")
monitor.print_summary()

# =============================================================================
# STEP 2: Download model
# =============================================================================
print("\n2. Downloading model from Kaggle...")
k_path = kagglehub.model_download(MODEL_PATH)
# The checkpoint is expected in a sub-directory named after the model version.
ckpt_path = os.path.join(k_path, MODEL_VERSION)
print(f"   Path: {ckpt_path}")

# =============================================================================
# STEP 3: Get CPU device for offloading
# =============================================================================
# NOTE(review): cpu_device is not referenced again in the visible cells;
# confirm whether the later steps were meant to place arrays on it.
try:
    cpu_device = jax.devices('cpu')[0]
    print(f"\n3. CPU device for offloading: {cpu_device}")
except Exception:
    # No usable CPU backend: fall back to the default device. Catching
    # Exception (not a bare except) keeps Ctrl-C and SystemExit working.
    cpu_device = None
    print("\n3. No separate CPU device, using default")

# =============================================================================
# STEP 4: Build model blueprint on CPU (prevents TPU init spike)
# =============================================================================
# NOTE(review): nothing in this step refers to `cpu_device` or otherwise pins
# placement, so whether construction really happens on the host depends on
# JAX's default device -- confirm.
print("\n4. Building model on CPU (bf16)...")
config = gemma_model_lib.ModelConfig.gemma2_2b_it()

# Build model structure
base_model = gemma_model_lib.Gemma(config, rngs=nnx.Rngs(RANDOM_SEED))

# Load the checkpoint, cast every leaf to bfloat16 and copy the result into
# the freshly built module.
print("   Loading checkpoint...")
raw_params = gemma_params_lib.load_and_format_params(ckpt_path)
print("   Converting to bf16...")
bf16_params = jax.tree.map(lambda x: x.astype(jnp.bfloat16), raw_params)
nnx.update(base_model, bf16_params)

# Clean up the intermediate parameter trees
del raw_params, bf16_params
gc.collect()
print("   ‚úì Model built on CPU")

# =============================================================================
# STEP 5: Shard to TPU mesh
# =============================================================================
# NOTE(review): only nnx.split / nnx.merge run inside the mesh context here;
# no explicit sharding call is visible in this step -- confirm the parameters
# actually end up sharded at this point.
print("\n5. Sharding to TPU mesh (2√ó4)...")
with mesh:
    graph, state = nnx.split(base_model)
    # Create reference model (frozen, for KL penalty)
    ref_model = nnx.merge(graph, state)
    # Create actor model (will have LoRA applied); merged from the same
    # `state` as the reference model.
    actor = nnx.merge(graph, state)

del base_model
gc.collect()
print("   ‚úì Models sharded to TPU")
monitor.print_summary()

# =============================================================================
# STEP 6: Apply LoRA to actor model
# =============================================================================
# The rank in this message is a literal; the provider below uses LORA_RANK.
print("\n6. Applying LoRA (rank=64)...")
print(f"   Target pattern: {LORA_TARGET_PATTERN}")

# Create LoRA provider: modules whose path matches the pattern get adapters
lora_provider = qwix.LoraProvider(
    module_path=LORA_TARGET_PATTERN,
    rank=LORA_RANK,
    alpha=LORA_ALPHA,
)

# Example inputs supplied by the model itself; they are forwarded as keyword
# arguments to apply_lora_to_model below.
model_input = actor.get_model_input()

# Apply LoRA; `actor` is rebound to the returned model
actor = qwix.apply_lora_to_model(
    actor,
    lora_provider,
    rngs=nnx.Rngs(RANDOM_SEED),
    **model_input
)

# Re-shard after LoRA application: constrain the full actor state to the
# partition specs derived from it, then write it back into the module.
with mesh:
    state = nnx.state(actor)
    pspecs = nnx.get_partition_spec(state)
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(actor, sharded_state)

print("   ‚úì LoRA applied to Attention + MLP layers")

# =============================================================================
# FINAL STATUS
# =============================================================================
print("\n" + "="*60)
print("‚úÖ MODEL READY")
print("="*60)
print(f"   Actor: Gemma2-2B-IT + LoRA (rank={LORA_RANK})")
print(f"   Reference: Gemma2-2B-IT (frozen)")
monitor.print_summary()

## 📚 Cell 8: Data Loaders (Grain MapDataset)

In [None]:
class GSM8KDataSource:
    """Random-access view over prepared GSM8K records for Grain.

    Wraps a sequence of ``{'prompt': ..., 'answer': ...}`` dicts and exposes
    each one under the plural keys ``'prompts'`` and ``'answers'``.
    """

    def __init__(self, data):
        # Kept by reference; the records are not copied.
        self._data = data

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        record = self._data[idx]
        # Rename the keys to the plural form used by the training pipeline.
        return dict(prompts=record['prompt'], answers=record['answer'])

# Create Grain MapDataset pipelines
# Using the same pattern as official Tunix GRPO demo
# NOTE(review): neither pipeline calls .batch(), so each element is a single
# record whose 'prompts' value is one string. Presumably the GRPO learner
# expects batched elements -- verify against the Tunix demo before training.
train_dataset = (
    grain.MapDataset.source(GSM8KDataSource(train_data))
    .shuffle(seed=RANDOM_SEED)
)

# Validation data keeps its original order (no shuffle).
val_dataset = (
    grain.MapDataset.source(GSM8KDataSource(test_data))
)

print("‚úÖ Grain Datasets Created")
print(f"   Train: {len(train_data)} prompts")
print(f"   Val: {len(test_data)} prompts")

## ⚡ Cell 9: Training Configuration

In [None]:
# Evaluation cadence, defined once so the trainer configuration and the
# summary line below cannot drift apart.
EVAL_EVERY_N_STEPS = 25

# Learning rate schedule: linear warmup from 0 to LEARNING_RATE over
# WARMUP_STEPS, then cosine decay down to 10% of the peak by MAX_STEPS.
lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    decay_steps=MAX_STEPS,
    end_value=LEARNING_RATE * 0.1
)

# Optimizer: clip the global gradient norm to 1.0, then AdamW.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(lr_schedule, weight_decay=WEIGHT_DECAY)
)

# Cluster configuration: actor, reference and rollout all share one mesh.
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        max_steps=MAX_STEPS,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        mini_batch_size=MINI_BATCH_SIZE,
        train_micro_batch_size=MICRO_BATCH_SIZE,
        checkpoint_root_directory=CKPT_DIR,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=MAX_GENERATION_LENGTH,
        max_prompt_length=MAX_PROMPT_LENGTH,
        temperature=0.9,
        top_p=0.95,
        eos_tokens=EOS_TOKENS,
    )
)

# GRPO configuration - 16 generations!
grpo_config = GRPOConfig(
    num_generations=NUM_GENERATIONS,  # G=16
    num_iterations=NUM_ITERATIONS,    # μ=3
    beta=BETA,                        # β=0.04
    epsilon=EPSILON,                  # ε=0.2
)

print("‚úÖ Training Configuration")
print(f"   LR Schedule: warmup {WARMUP_STEPS} ‚Üí {LEARNING_RATE} ‚Üí cosine decay")
print(f"   Optimizer: AdamW (clip=1.0, decay={WEIGHT_DECAY})")
print(f"   GRPO: G={NUM_GENERATIONS}, Œº={NUM_ITERATIONS}, Œ≤={BETA}, Œµ={EPSILON}")
print(f"   Eval every {EVAL_EVERY_N_STEPS} steps")

## 🎓 Cell 10: Create GRPO Trainer

In [None]:
# Free Python garbage and JAX compilation caches before building the trainer.
gc.collect()
jax.clear_caches()

print("Creating RLCluster and GRPOLearner...")

with mesh:
    # Create RL cluster from the LoRA actor, the frozen reference model, the
    # tokenizer and the cluster configuration built in the previous cell.
    rl_cluster = rl_cluster_lib.RLCluster(
        actor=actor,
        reference=ref_model,
        tokenizer=tokenizer,
        cluster_config=cluster_config,
    )
    
    # Create GRPO trainer with weighted reward functions
    # NOTE: Use 'algo_config' (not 'grpo_config') per latest Tunix API
    grpo_trainer = GRPOLearner(
        rl_cluster=rl_cluster,
        reward_fns=REWARD_FUNCTIONS,
        algo_config=grpo_config,
    )

# Status recap. The percentages below are literals, not derived from the
# REWARD_WEIGHT_* constants.
print("\n" + "="*60)
print("‚úÖ GRPO TRAINER READY")
print("="*60)
print(f"   Generations per prompt: {NUM_GENERATIONS}")
print(f"   Reward functions: {len(REWARD_FUNCTIONS)}")
print(f"     - format_reward (25%)")
print(f"     - logic_reward (30%)")
print(f"     - accuracy_reward (45%)")
print(f"     - self_correction_reward (bonus)")
print(f"     - length_reward (bonus)")
monitor.print_summary()

## 🚀 Cell 11: Training Loop
*First steps take 10-15 minutes for JIT compilation*

In [None]:
# Print the run configuration, train, then report the wall-clock duration.
print("="*60)
print("üöÄ STARTING GRPO TRAINING")
print("="*60)
print(f"")
print(f"Configuration:")
print(f"  ‚Ä¢ Steps: {MAX_STEPS}")
print(f"  ‚Ä¢ Generations per prompt: {NUM_GENERATIONS}")
print(f"  ‚Ä¢ Iterations per batch: {NUM_ITERATIONS}")
print(f"  ‚Ä¢ Train samples: {NUM_TRAIN_SAMPLES}")
print(f"")
print(f"Reward Weights:")
print(f"  ‚Ä¢ Format: {REWARD_WEIGHT_FORMAT*100}%")
print(f"  ‚Ä¢ Logic: {REWARD_WEIGHT_LOGIC*100}%")
print(f"  ‚Ä¢ Accuracy: {REWARD_WEIGHT_ACCURACY*100}%")
print(f"")
print("="*60)
print("\n‚è≥ First steps will take 10-15 minutes for JIT compilation...")
print("   Subsequent steps will be much faster.\n")

monitor.print_summary()
# Wall-clock start, used for the duration report after training.
start_time = time.time()

# Run training on the train dataset with the validation dataset for evals
with mesh:
    grpo_trainer.train(train_dataset, val_dataset)

elapsed = time.time() - start_time
print("\n" + "="*60)
print("üéâ TRAINING COMPLETE")
print("="*60)
print(f"   Total time: {elapsed/60:.1f} minutes")
# This prints the configured MAX_STEPS constant, not a step counter read back
# from the trainer.
print(f"   Steps completed: {MAX_STEPS}")
monitor.print_summary()

## 💾 Cell 12: Save Model

In [None]:
# Persist the trained actor (base weights plus LoRA parameters) with Orbax.
print("Saving trained model...")

os.makedirs(OUTPUT_DIR, exist_ok=True)

with mesh:
    # Only the state half of the split is needed for the checkpoint.
    _, actor_state = nnx.split(actor)
    checkpointer = ocp.StandardCheckpointer()
    save_path = os.path.join(OUTPUT_DIR, "actor_state")
    # force=True lets this cell be re-run: without it the save fails when
    # `save_path` already exists from an earlier run.
    checkpointer.save(save_path, actor_state, force=True)
    # The save runs asynchronously; block until it has been committed so the
    # success message below is true and the checkpoint is complete even if
    # the kernel stops right after this cell.
    checkpointer.wait_until_finished()

print(f"‚úÖ Model saved to {OUTPUT_DIR}")
print(f"   Checkpoint: {save_path}")

## 🧪 Cell 13: Test Inference

In [None]:
# Smoke test: run three hand-written word problems through the trained actor
# and print each output next to its expected answer. The comparison is left
# to the reader; `expected` is only printed, never checked.
test_questions = [
    ("If a store sells 150 apples at $4 each, what is the total revenue?", "600"),
    ("A train travels 120 miles in 2 hours. What is its speed in miles per hour?", "60"),
    ("Sarah has 24 cookies. She gives 1/3 to her brother and 1/4 to her sister. How many cookies does she have left?", "10"),
]

print("="*60)
print("üß™ TESTING TRAINED MODEL")
print("="*60 + "\n")

with mesh:
    # NOTE(review): the constructor keywords (model=, max_tokens=) and the
    # .generate() method are taken on trust here -- confirm them against the
    # installed tunix.generate.sampler API. `output` is sliced like a string
    # below, so it must be text (or at least sliceable).
    sampler = sampler_lib.Sampler(
        model=actor,
        tokenizer=tokenizer,
        max_tokens=MAX_GENERATION_LENGTH
    )
    
    for i, (question, expected) in enumerate(test_questions, 1):
        # Same prompt template as used for training.
        prompt = format_prompt(question)
        output = sampler.generate(prompt)
        
        print(f"Question {i}: {question}")
        print(f"Expected: {expected}")
        print(f"Output:")
        # Show at most the first 800 characters of the generation.
        print(output[:800])
        print("-" * 60 + "\n")

## 📋 Cell 14: Final Summary

In [None]:
# Final recap of the run configuration. Everything is read from the
# configuration constants except the hardware line, which is a literal.
print("="*60)
print("üèÜ PRODUCTION REASONING MODEL COMPLETE")
print("="*60)
print(f"")
print(f"Architecture:")
print(f"  ‚Ä¢ Hardware: TPU v5e-8 (2√ó4 mesh)")
print(f"  ‚Ä¢ Model: {MODEL_HF_NAME}")
print(f"  ‚Ä¢ Fine-Tuning: LoRA (rank={LORA_RANK}, alpha={LORA_ALPHA})")
print(f"  ‚Ä¢ Targets: Attention + MLP layers")
print(f"")
print(f"GRPO Configuration:")
print(f"  ‚Ä¢ Generations (G): {NUM_GENERATIONS}")
print(f"  ‚Ä¢ Iterations (Œº): {NUM_ITERATIONS}")
print(f"  ‚Ä¢ KL Penalty (Œ≤): {BETA}")
print(f"  ‚Ä¢ Clipping (Œµ): {EPSILON}")
print(f"")
print(f"Reward Weights:")
print(f"  ‚Ä¢ Format: {REWARD_WEIGHT_FORMAT*100}%")
print(f"  ‚Ä¢ Logic: {REWARD_WEIGHT_LOGIC*100}%")
print(f"  ‚Ä¢ Accuracy: {REWARD_WEIGHT_ACCURACY*100}%")
print(f"  ‚Ä¢ Self-Correction: bonus")
print(f"  ‚Ä¢ Length: bonus")
print(f"")
print(f"Training:")
print(f"  ‚Ä¢ Steps: {MAX_STEPS}")
print(f"  ‚Ä¢ Samples: {NUM_TRAIN_SAMPLES} train, {NUM_TEST_SAMPLES} test")
print(f"  ‚Ä¢ Output: {OUTPUT_DIR}")
print(f"")
monitor.print_summary()
print("="*60)