# 🧠 Gemma2 Reasoning GRPO - Full Production

**Google Tunix Hackathon 2026**

### Core Architecture:
- **Hardware:** TPU v5e-8 (2 Data x 4 Tensor Mesh)
- **Model:** Gemma 2 2B-IT (CPU-Offloaded bf16 Init)
- **Fine-Tuning:** LoRA on Attention + MLP layers
- **Strategy:** GRPO with multiple parallel generations per prompt (`NUM_GENERATIONS` in Cell 4; currently set to 2 to stay within memory)
- **Rewards:** Format (25%) + Logic (30%) + Accuracy (45%) + Self-Correction & Length Regularization

## Cell 1: Environment Setup

In [None]:
import os
import shutil
import glob
import time

# =============================================================================
# FORCE RESET: Smart Cleanup of ALL old markers
# =============================================================================
# This defines the NEW marker we want to create after success
CURRENT_MARKER = "/kaggle/working/.setup_complete_gtg"

print("üßπ Cleaning up old installation markers...")

# Use glob to find ANY file starting with ".setup_complete"
old_markers = glob.glob("/kaggle/working/.setup_complete*")

for marker in old_markers:
    # Don't delete the current one if it essentially already exists (though we usually want to overwrite)
    try:
        os.remove(marker)
        print(f"   üóëÔ∏è Deleted old marker: {marker}")
    except OSError as e:
        print(f"   ‚ö†Ô∏è Could not delete {marker}: {e}")

print("="*60)
print("üîÑ INSTALLING DEPENDENCIES")
print("="*60)

%pip install --upgrade pip -q

# 1. Clean previous installations
    print("\nüßπ Cleaning previous installations...")
    %pip uninstall -y tunix google-tunix flax qwix 2>/dev/null
    
    # 2. Install JAX for TPU
    print("\n‚¨áÔ∏è Installing JAX/TPU stack...")
    %pip install -q "jax[tpu]>=0.8.0" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
    # 3. Install Tunix from PyPI (STABLE - recommended)
    print("\n‚¨áÔ∏è Installing Tunix (stable from PyPI)...")
    %pip install -q "google-tunix[prod]"
    
    # 4. Install Qwix and Flax
    print("\n‚¨áÔ∏è Installing Qwix and Flax...")
    %pip install -q git+https://github.com/google/qwix
    %pip install -q git+https://github.com/google/flax
    
    # 5. Install other dependencies
    print("\n‚¨áÔ∏è Installing other dependencies...")
    %pip install -q kagglehub transformers grain huggingface_hub orbax-checkpoint
    %pip install -q "numpy>=2.0.0" "pyarrow>=17.0.0" "datasets>=2.21.0"
    
    # Create marker
    with open(MARKER, "w") as f:
        f.write("done")
    
    print("\n" + "="*60)
    print("‚úÖ INSTALLATION COMPLETE")
    print("‚ö†Ô∏è  PLEASE RESTART KERNEL NOW (‚ü≥ Button)!")
    print("="*60)
else:
    print("‚úÖ Dependencies already installed. Proceeding...")

## Cell 2: Verify Installation

In [None]:
import sys

print("üîç Verifying Tunix Installation...\n")

try:
    # Check tunix imports
    from tunix.rl import rl_cluster as rl_cluster_lib
    from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
    from tunix.rl.rollout import base_rollout
    from tunix.sft import metrics_logger
    
    print("‚úÖ tunix.rl.rl_cluster imported")
    print("‚úÖ GRPOConfig, GRPOLearner imported")
    print("‚úÖ base_rollout imported")
    print("‚úÖ metrics_logger imported")
    
    # Verify correct classes exist
    assert hasattr(rl_cluster_lib, 'ClusterConfig'), "ClusterConfig not found!"
    assert hasattr(rl_cluster_lib, 'RLTrainingConfig'), "RLTrainingConfig not found!"
    assert hasattr(rl_cluster_lib, 'Role'), "Role enum not found!"
    assert hasattr(rl_cluster_lib, 'RLCluster'), "RLCluster not found!"
    
    print("\n‚úÖ All required classes found:")
    print(f"   - ClusterConfig: {rl_cluster_lib.ClusterConfig}")
    print(f"   - RLTrainingConfig: {rl_cluster_lib.RLTrainingConfig}")
    print(f"   - Role: {list(rl_cluster_lib.Role)}")
    
    # Check RLTrainingConfig has the fields we need
    import inspect
    sig = inspect.signature(rl_cluster_lib.RLTrainingConfig)
    params = list(sig.parameters.keys())
    print(f"\n‚úÖ RLTrainingConfig parameters: {params[:10]}...")
    
    if 'actor_optimizer' in params:
        print("‚úÖ 'actor_optimizer' field exists (required)")
    if 'gradient_accumulation_steps' in params:
        print("‚úÖ 'gradient_accumulation_steps' field exists (optional)")
        
    print("\n" + "="*60)
    print("üéâ TUNIX INSTALLATION VERIFIED SUCCESSFULLY!")
    print("="*60)
    
except ImportError as e:
    print(f"‚ùå Import Error: {e}")
    print("\nPlease run Cell 1 and restart the kernel.")
    sys.exit(1)
except AssertionError as e:
    print(f"‚ùå Assertion Error: {e}")
    print("\nYour tunix version may be incompatible. Try reinstalling.")
    sys.exit(1)

## Cell 3: TPU Setup & Memory Monitor

In [None]:
# =============================================================================
# TPU SETUP & MEMORY MONITORING
# =============================================================================
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
import time

print("="*60)
print("‚ö° TPU INITIALIZATION")
print("="*60)

# Enumerate the devices JAX can see.
devices = jax.devices()
NUM_TPUS = len(devices)
print(f"\n‚úÖ Found {NUM_TPUS} TPU devices:")
for index, device in enumerate(devices):
    print(f"   Device {index}: {device}")

# Supported device counts and the mesh shape used for each of them.
MESH_SHAPE = {
    8: (2, 4),  # 2 FSDP x 4 TP
    4: (2, 2),
    1: (1, 1),
}.get(NUM_TPUS)
if MESH_SHAPE is None:
    raise ValueError(f"Unsupported TPU count: {NUM_TPUS}")

MESH_AXIS_NAMES = ("fsdp", "tp")

print(f"\n‚úÖ Mesh configuration: {MESH_SHAPE} ({MESH_AXIS_NAMES})")

# Memory monitoring class
class TPUMemoryMonitor:
    """Report per-device memory usage for the devices visible to JAX."""

    def __init__(self):
        self.devices = jax.devices()

    def get_memory_stats(self):
        """Return one ``(index, used_gb, limit_gb, pct)`` tuple per device.

        Devices whose ``memory_stats()`` returns nothing, or raises, are
        skipped. ``pct`` is 0 when the reported limit is not positive.
        """
        stats = []
        for i, device in enumerate(self.devices):
            try:
                mem = device.memory_stats()
                if mem:
                    used_gb = mem.get('bytes_in_use', 0) / 1e9
                    limit_gb = mem.get('bytes_limit', 0) / 1e9
                    pct = 100 * used_gb / limit_gb if limit_gb > 0 else 0
                    stats.append((i, used_gb, limit_gb, pct))
            except Exception:
                # Best effort: a device that cannot report stats is skipped.
                # Catching Exception (not a bare except) lets KeyboardInterrupt
                # and SystemExit propagate.
                continue
        return stats

    def print_summary(self):
        """Print a 20-segment usage bar per device, or a notice if no stats exist."""
        print("\nüìä TPU Memory Usage:")
        stats = self.get_memory_stats()
        if stats:
            for i, used, limit, pct in stats:
                # Clamp so the bar always has exactly 20 segments, even when the
                # reported usage exceeds the reported limit.
                filled = max(0, min(20, int(pct / 5)))
                bar = "‚ñà" * filled + "‚ñë" * (20 - filled)
                status = "‚ö†Ô∏è" if pct > 80 else "‚úÖ"
                print(f"   TPU {i}: [{bar}] {used:.1f}/{limit:.1f} GB ({pct:.1f}%) {status}")
        else:
            print("   Could not retrieve memory stats")

# Shared monitor instance; later cells call monitor.print_summary() after each
# memory-heavy step. Print a baseline reading now.
monitor = TPUMemoryMonitor()
monitor.print_summary()

print("\n" + "="*60)
print("‚úÖ TPU READY")
print("="*60)

## Cell 4: Hyperparameters.

In [None]:
# =============================================================================
# HYPERPARAMETERS - Memory Safe for TPU v5e (128GB)
# =============================================================================
from dataclasses import dataclass

@dataclass
class Config:
    """Hyperparameters for this notebook's GRPO run, grouped by concern.

    Every field is a plain dataclass default; instantiate with ``Config()``
    and override individual fields by keyword if needed.
    """

    # === Model ===
    MODEL_HF_NAME: str = "google/gemma-2-2b-it"  # Hugging Face id used for the tokenizer
    
    # === LoRA ===
    LORA_RANK: int = 32
    LORA_ALPHA: float = 32.0
    
    # === GRPO Algorithm ===
    NUM_GENERATIONS: int = 2      # G in paper (keep low for memory)
    NUM_ITERATIONS: int = 1       # mu in paper
    BETA: float = 0.04            # KL penalty coefficient
    EPSILON: float = 0.2          # Clipping parameter
    
    # === Generation ===
    MAX_PROMPT_LENGTH: int = 256
    MAX_GENERATION_LENGTH: int = 300  # Max new tokens per completion
    TEMPERATURE: float = 0.9      # High for diverse responses
    TOP_P: float = 1.0
    TOP_K: int = 50
    
    # === Training (MEMORY SAFE) ===
    TRAIN_MICRO_BATCH_SIZE: int = 1   # Minimum for safety
    MINI_BATCH_SIZE: int = 1
    
    # === Optimizer ===
    LEARNING_RATE: float = 3e-6   # Peak learning rate
    WARMUP_RATIO: float = 0.1     # Fraction of MAX_STEPS spent warming up
    WEIGHT_DECAY: float = 0.1
    ADAM_B1: float = 0.9
    ADAM_B2: float = 0.99
    MAX_GRAD_NORM: float = 0.1    # Global-norm gradient clipping threshold
    
    # === Schedule ===
    MAX_STEPS: int = 100          # Adjust based on dataset
    EVAL_EVERY_N_STEPS: int = 20  # Also reused as the checkpoint save interval
    
    # === Dataset ===
    NUM_TRAIN_SAMPLES: int = 500
    NUM_TEST_SAMPLES: int = 50
    TRAIN_FRACTION: float = 1.0   # NOTE(review): not referenced by the visible cells
    
    # === Reward Weights ===
    # The three weights sum to 1.0.
    REWARD_WEIGHT_FORMAT: float = 0.25
    REWARD_WEIGHT_LOGIC: float = 0.30
    REWARD_WEIGHT_ACCURACY: float = 0.45
    
    # === Paths ===
    OUTPUT_DIR: str = "/kaggle/working/grpo_output"
    CKPT_DIR: str = "/kaggle/working/checkpoints"

# Single shared configuration instance read by every later cell.
cfg = Config()

# Echo the key settings so the run log records what was used.
print("="*60)
print("üìù CONFIGURATION")
print("="*60)
print(f"\nModel: {cfg.MODEL_HF_NAME}")
print(f"LoRA: rank={cfg.LORA_RANK}, alpha={cfg.LORA_ALPHA}")
print(f"\nGRPO: G={cfg.NUM_GENERATIONS}, Œº={cfg.NUM_ITERATIONS}, Œ≤={cfg.BETA}, Œµ={cfg.EPSILON}")
print(f"Generation: max_len={cfg.MAX_GENERATION_LENGTH}, temp={cfg.TEMPERATURE}")
print(f"\nTraining: batch={cfg.TRAIN_MICRO_BATCH_SIZE}, steps={cfg.MAX_STEPS}")
print(f"Optimizer: lr={cfg.LEARNING_RATE}, warmup={cfg.WARMUP_RATIO*100}%")
print(f"\nRewards: format={cfg.REWARD_WEIGHT_FORMAT}, logic={cfg.REWARD_WEIGHT_LOGIC}, acc={cfg.REWARD_WEIGHT_ACCURACY}")

## Cell 5: Data Loading (GSM8K)

In [None]:
# =============================================================================
# LOAD DATASET
# =============================================================================
import re
from datasets import load_dataset

print("üìö Loading GSM8K dataset...")

# Load dataset
gsm8k = load_dataset("openai/gsm8k", "main")

def extract_answer(solution: str) -> str:
    """Return the number that follows the ``####`` marker in a GSM8K solution.

    Thousands separators are removed. An empty string is returned when the
    marker (followed by a digit, comma, dot or minus run) is not found.
    """
    found = re.search(r'####\s*([\d,\.\-]+)', solution)
    return found.group(1).replace(',', '') if found else ""

def format_prompt(question: str) -> str:
    """Wrap a question in the tutor instructions and the tag-format request."""
    intro = "You are a helpful math tutor. Solve this problem step by step."
    instructions = (
        "Think through this carefully. Show your reasoning inside "
        "<reasoning></reasoning> tags, then give your final numerical answer "
        "inside <answer></answer> tags."
    )
    # The three sections are separated by one blank line each.
    return "\n\n".join((intro, f"Problem: {question}", instructions))

# Prepare datasets
def prepare_dataset(split, max_samples):
    """Return ``(prompts, answers)`` for the first ``max_samples`` rows of a split.

    Prompts are built with ``format_prompt`` from each row's 'question';
    answers are pulled from each row's 'answer' with ``extract_answer``.
    """
    full_split = gsm8k[split]
    subset = full_split.select(range(min(max_samples, len(full_split))))
    # Single pass over the rows, then unzip into two parallel lists.
    pairs = [
        (format_prompt(row['question']), extract_answer(row['answer']))
        for row in subset
    ]
    prompts = [prompt for prompt, _ in pairs]
    answers = [answer for _, answer in pairs]
    return prompts, answers

# Build parallel prompt/answer lists, capped at the configured sample counts.
train_prompts, train_answers = prepare_dataset('train', cfg.NUM_TRAIN_SAMPLES)
test_prompts, test_answers = prepare_dataset('test', cfg.NUM_TEST_SAMPLES)

print(f"\n‚úÖ Loaded {len(train_prompts)} training samples")
print(f"‚úÖ Loaded {len(test_prompts)} test samples")

# Show example (first 500 characters of the first training prompt)
print("\n" + "="*60)
print("Example prompt:")
print("="*60)
print(train_prompts[0][:500] + "...")
print(f"\nExpected answer: {train_answers[0]}")

## Cell 6: Load Model & Tokenizer

In [None]:
# =============================================================================
# LOAD MODEL & TOKENIZER
# =============================================================================
import os
import kagglehub
from flax import nnx

print("="*60)
print("ü§ñ LOADING MODEL & TOKENIZER")
print("="*60)

# Download model from Kaggle
print("\n‚¨áÔ∏è Downloading model from Kaggle...")
model_path = kagglehub.model_download("google/gemma-2/flax/gemma2-2b-it")
print(f"‚úÖ Model downloaded to: {model_path}")

# Import tunix model components
# NOTE(review): confirm that these module paths and the helpers called below
# (HuggingFaceTokenizerAdapter.from_pretrained, get_model_config_by_name)
# exist in the installed tunix release; Cell 2 does not verify them.
from tunix.models.gemma2 import model as gemma_lib
from tunix.models.gemma2 import params_safetensors as params_safetensors_lib
from tunix.models.gemma2 import transformer_config
from tunix.generate import tokenizer_adapter as tokenizer_lib

# Load tokenizer
print("\n‚¨áÔ∏è Loading tokenizer...")
tokenizer = tokenizer_lib.HuggingFaceTokenizerAdapter.from_pretrained(cfg.MODEL_HF_NAME)
print(f"‚úÖ Tokenizer loaded. Vocab size: {tokenizer.vocab_size}")

# Get EOS tokens
# Compare against None explicitly: a token id of 0 is falsy, so a truthiness
# test would discard it and silently fall back to id 1.
eos_token_id = tokenizer.eos_token_id
EOS_TOKENS = [eos_token_id] if eos_token_id is not None else [1]
print(f"‚úÖ EOS tokens: {EOS_TOKENS}")

# Get model config
model_config = transformer_config.get_model_config_by_name("gemma2-2b")
print(f"‚úÖ Model config: {model_config.num_layers} layers, {model_config.embed_dim} embed dim")

monitor.print_summary()

## Cell 7: Create Mesh & Load Weights

In [None]:
# =============================================================================
# CREATE MESH & LOAD MODEL WEIGHTS
# =============================================================================
import os

print("="*60)
print("üîß CREATING MESH & LOADING WEIGHTS")
print("="*60)

# Create JAX mesh
# One Auto axis type is supplied per mesh axis, so the tuple length follows
# MESH_SHAPE.
mesh = jax.make_mesh(
    MESH_SHAPE,
    MESH_AXIS_NAMES,
    axis_types=(jax.sharding.AxisType.Auto,) * len(MESH_SHAPE)
)
print(f"\n‚úÖ Created mesh: {mesh}")

# Load model weights within mesh context
print("\n‚¨áÔ∏è Loading model weights (this may take a few minutes)...")

# The loader receives the absolute checkpoint path, the model config, the mesh
# and the target dtype (bfloat16).
# NOTE(review): model_path comes from a Kaggle "flax" model download; confirm
# it actually contains safetensors files for this loader.
with mesh:
    gemma = params_safetensors_lib.create_model_from_safe_tensors(
        os.path.abspath(model_path),
        model_config,
        mesh,
        dtype=jnp.bfloat16
    )

print("‚úÖ Model weights loaded")
monitor.print_summary()

## Cell 8: Apply LoRA

In [None]:
# =============================================================================
# APPLY LoRA
# =============================================================================
import qwix

print("="*60)
print("üéØ APPLYING LoRA")
print("="*60)

# LoRA target layers - attention and MLP.
# Qwix matches these regexes against module paths inside the model, so they
# must use the Tunix Gemma module names (q_einsum / kv_einsum /
# attn_vec_einsum for attention; gate_proj / up_proj / down_proj for the MLP),
# not Hugging Face style names such as q_proj.
LORA_TARGETS = [
    ".*q_einsum",
    ".*kv_einsum",
    ".*attn_vec_einsum",
    ".*gate_proj",
    ".*up_proj",
    ".*down_proj",
]

print(f"\nLoRA config:")
print(f"  Rank: {cfg.LORA_RANK}")
print(f"  Alpha: {cfg.LORA_ALPHA}")
print(f"  Targets: {len(LORA_TARGETS)} layer patterns")

# Keep reference model frozen (no LoRA)
reference = gemma

# Qwix applies LoRA through a provider object plus a sample model input used
# to trace the model; the patterns are OR-ed into a single regex.
lora_provider = qwix.LoraProvider(
    module_path="|".join(LORA_TARGETS),
    rank=cfg.LORA_RANK,
    alpha=cfg.LORA_ALPHA,
)
actor = qwix.apply_lora_to_model(gemma, lora_provider, **gemma.get_model_input())

with mesh:
    # Re-apply the partition specs so the newly created LoRA parameters are
    # sharded across the mesh like the rest of the model state.
    actor_state = nnx.state(actor)
    actor_pspecs = nnx.get_partition_spec(actor_state)
    actor_state = jax.lax.with_sharding_constraint(actor_state, actor_pspecs)
    nnx.update(actor, actor_state)

print("\n‚úÖ LoRA applied to actor model")
print("‚úÖ Reference model kept frozen")

# Count parameters
def count_params(model):
    """Return the total element count over all ``nnx.Param`` leaves of ``model``."""
    params = nnx.state(model, nnx.Param)
    return sum(p.size for p in jax.tree_util.tree_leaves(params))

print(f"\nParameter counts:")
try:
    actor_params = count_params(actor)
    ref_params = count_params(reference)
    print(f"  Actor (with LoRA): ~{actor_params/1e9:.2f}B")
    print(f"  Reference: ~{ref_params/1e9:.2f}B")
except Exception as e:
    # Counting is informational only; report why it failed and carry on.
    print(f"  (Could not count parameters: {e})")

monitor.print_summary()

## Cell 9: Reward Functions

In [None]:
# =============================================================================
# REWARD FUNCTIONS
# =============================================================================
import re
from typing import List, Optional

# Section banner for the reward-function cell.
print("="*60)
print("üèÜ DEFINING REWARD FUNCTIONS")
print("="*60)

def match_format_exactly(prompts: List[str], completions: List[str], **kwargs) -> List[float]:
    """
    Score each completion on containing the full tagged layout.

    A completion earns ``cfg.REWARD_WEIGHT_FORMAT`` when a
    ``<reasoning>...</reasoning>`` block followed (optionally after whitespace)
    by an ``<answer>...</answer>`` block appears anywhere in it, else 0.0.
    """
    layout = re.compile(r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>', re.DOTALL)
    return [
        1.0 * cfg.REWARD_WEIGHT_FORMAT if layout.search(text) else 0.0
        for text in completions
    ]

def match_format_approximately(prompts: List[str], completions: List[str], **kwargs) -> List[float]:
    """
    Give partial format credit per tag pair present in a completion.

    Each of the reasoning and answer pairs (opening and closing tag both
    present, in any order) contributes 0.5; the sum is scaled by half of
    ``cfg.REWARD_WEIGHT_FORMAT``.
    """
    tag_pairs = (
        ('<reasoning>', '</reasoning>'),
        ('<answer>', '</answer>'),
    )
    rewards = []
    for text in completions:
        pairs_present = sum(
            1 for opening, closing in tag_pairs if opening in text and closing in text
        )
        # Half weight for partial
        rewards.append(pairs_present * 0.5 * cfg.REWARD_WEIGHT_FORMAT * 0.5)
    return rewards

def _answer_credit(completion, expected):
    """Return the fraction (0.0-1.0) of accuracy credit one completion earns."""
    if expected is None:
        return 0.0
    match = re.search(r'<answer>(.*?)</answer>', completion, re.DOTALL)
    if match is None:
        return 0.0

    # Drop thousands separators and a leading currency sign before comparing.
    predicted = match.group(1).strip().replace(',', '').lstrip('$').strip()
    expected_clean = str(expected).strip().replace(',', '')
    # A missing reference answer must never be "matched" by an empty <answer>.
    if not predicted or not expected_clean:
        return 0.0

    # Exact match
    if predicted == expected_clean:
        return 1.0

    # Try numeric comparison
    try:
        pred_num = float(predicted)
        exp_num = float(expected_clean)
    except ValueError:
        return 0.0
    difference = abs(pred_num - exp_num)
    if difference < 0.01:  # Close enough
        return 0.9
    if difference / max(abs(exp_num), 1) < 0.1:  # Within 10%
        return 0.5
    return 0.0


def check_answer(prompts: List[str], completions: List[str], 
                 expected_answers: Optional[List[str]] = None, **kwargs) -> List[float]:
    """
    Reward for correct numerical answer.
    Weight: 45%

    Returns exactly one reward per completion. Full weight goes to an exact
    string match inside ``<answer>`` tags, 0.9 of the weight to a numeric match
    within 0.01, 0.5 to a value within 10%, and 0.0 otherwise (including when
    no reference answer is available for that completion).

    ``expected_answers`` may be a sequence aligned with ``completions`` or a
    single string shared by all completions.
    """
    if expected_answers is None:
        return [0.0] * len(completions)

    # A bare string is one shared answer; iterating it would pair completions
    # with its individual characters.
    if isinstance(expected_answers, str):
        expected_answers = [expected_answers] * len(completions)

    rewards = []
    for index, completion in enumerate(completions):
        # Completions beyond the end of expected_answers score 0.0 instead of
        # being dropped, so the result always lines up with completions.
        expected = expected_answers[index] if index < len(expected_answers) else None
        fraction = _answer_credit(completion, expected)
        rewards.append(fraction * cfg.REWARD_WEIGHT_ACCURACY if fraction else 0.0)
    return rewards

# Step indicators must match as whole words; a plain substring test for "so"
# also fires on "also", "reason", "solve", ... and makes the bonus meaningless.
_STEP_WORDS = re.compile(r'\b(?:steps?|first|then|next|so|therefore)\b', re.IGNORECASE)


def check_numbers(prompts: List[str], completions: List[str], **kwargs) -> List[float]:
    """
    Reward for showing mathematical work (numbers, operations).
    Weight: 30% (logic)

    Only the text inside ``<reasoning>`` tags is scored: 0.3 for containing a
    digit, 0.3 for an arithmetic/equals character, 0.2 for a step-indicator
    word, and 0.2 for a length strictly between 50 and 1500 characters. The
    sum is scaled by ``cfg.REWARD_WEIGHT_LOGIC``; completions without a
    reasoning block score 0.0.
    """
    rewards = []
    for completion in completions:
        score = 0.0
        reasoning_match = re.search(r'<reasoning>(.*?)</reasoning>', completion, re.DOTALL)
        if reasoning_match:
            reasoning = reasoning_match.group(1)

            # Has numbers
            if re.search(r'\d', reasoning):
                score += 0.3

            # Has operations
            if any(op in reasoning for op in '+-*/='):
                score += 0.3

            # Has step indicators
            if _STEP_WORDS.search(reasoning):
                score += 0.2

            # Reasonable length (not too short, not too long)
            if 50 < len(reasoning) < 1500:
                score += 0.2

        rewards.append(score * cfg.REWARD_WEIGHT_LOGIC if score else 0.0)
    return rewards

# List the reward functions with their configured weights; the approximate
# format reward has no weight of its own (it scales REWARD_WEIGHT_FORMAT).
print("\n‚úÖ Defined 4 reward functions:")
print(f"   1. match_format_exactly (weight: {cfg.REWARD_WEIGHT_FORMAT*100}%)")
print(f"   2. match_format_approximately")
print(f"   3. check_answer (weight: {cfg.REWARD_WEIGHT_ACCURACY*100}%)")
print(f"   4. check_numbers (weight: {cfg.REWARD_WEIGHT_LOGIC*100}%)")

## Cell 10: GRPO Trainer Setup

In [None]:
# =============================================================================
# GRPO TRAINER SETUP - FIXED TUNIX API
# =============================================================================
import optax
import os

# Correct imports for bleeding-edge tunix
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger
from orbax import checkpoint as ocp

print("="*60)
print("üöÄ GRPO TRAINER SETUP (FIXED API)")
print("="*60)

# Create output directories
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
os.makedirs(cfg.CKPT_DIR, exist_ok=True)

# ---------------------------------------------------------------------------
# STEP 1: Create the optax optimizer
# ---------------------------------------------------------------------------
print("\nüìå Step 1: Creating optax optimizer...")

# Warmup length is a fraction of the total step budget, truncated to an int.
WARMUP_STEPS = int(cfg.WARMUP_RATIO * cfg.MAX_STEPS)

optimizer = optax.chain(
    # Gradient clipping runs first, so AdamW sees the clipped gradients.
    optax.clip_by_global_norm(max_norm=cfg.MAX_GRAD_NORM),
    # AdamW with warmup + cosine decay: the rate rises from 0 to
    # LEARNING_RATE over WARMUP_STEPS, then decays back to 0 by MAX_STEPS.
    optax.adamw(
        learning_rate=optax.warmup_cosine_decay_schedule(
            init_value=0.0,
            peak_value=cfg.LEARNING_RATE,
            warmup_steps=WARMUP_STEPS,
            decay_steps=cfg.MAX_STEPS,
            end_value=0.0,
        ),
        b1=cfg.ADAM_B1,
        b2=cfg.ADAM_B2,
        weight_decay=cfg.WEIGHT_DECAY,
    ),
)
print(f"   ‚úÖ AdamW optimizer created (lr={cfg.LEARNING_RATE}, warmup={WARMUP_STEPS})")

# ---------------------------------------------------------------------------
# STEP 2: Create Metrics Logging Options
# ---------------------------------------------------------------------------
print("\nüìå Step 2: Creating metrics logger...")

# MetricsLoggerOptions is configured with the directory to write logs into and
# a flush cadence; it has no per-backend on/off switches.
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir=os.path.join(cfg.OUTPUT_DIR, "tensorboard"),
    flush_every_n_steps=cfg.EVAL_EVERY_N_STEPS,
)
print("   ‚úÖ Metrics logging configured")

# ---------------------------------------------------------------------------
# STEP 3: Create Checkpointing Options
# ---------------------------------------------------------------------------
print("\nüìå Step 3: Creating checkpoint manager...")

# Keep the three most recent checkpoints, saving on the evaluation cadence.
checkpointing_options = ocp.CheckpointManagerOptions(
    max_to_keep=3,
    save_interval_steps=cfg.EVAL_EVERY_N_STEPS,
)
print("   ‚úÖ Checkpointing configured")

# ---------------------------------------------------------------------------
# STEP 4: Create RLTrainingConfig
# ---------------------------------------------------------------------------
print("\nüìå Step 4: Creating RLTrainingConfig...")

# The training config class comes from rl_cluster_lib (RLTrainingConfig),
# not TrainingConfig or UniversalTrainingConfig.
training_config = rl_cluster_lib.RLTrainingConfig(
    # Required: the actual optax optimizer
    actor_optimizer=optimizer,
    
    # Training schedule
    eval_every_n_steps=cfg.EVAL_EVERY_N_STEPS,
    max_steps=cfg.MAX_STEPS,
    
    # Batch sizes (keep small for memory safety)
    mini_batch_size=cfg.MINI_BATCH_SIZE,
    train_micro_batch_size=cfg.TRAIN_MICRO_BATCH_SIZE,
    
    # gradient_accumulation_steps is OPTIONAL - set to None if not needed
    gradient_accumulation_steps=None,
    
    # Metrics and checkpointing
    metrics_logging_options=metrics_logging_options,
    checkpoint_root_directory=cfg.CKPT_DIR,
    checkpointing_options=checkpointing_options,
)
print("   ‚úÖ RLTrainingConfig created")

# ---------------------------------------------------------------------------
# STEP 5: Create RolloutConfig
# ---------------------------------------------------------------------------
print("\nüìå Step 5: Creating RolloutConfig...")

# The KV cache is sized for a full prompt plus a full generation, with 256
# extra slots of headroom.
rollout_config = base_rollout.RolloutConfig(
    max_tokens_to_generate=cfg.MAX_GENERATION_LENGTH,
    max_prompt_length=cfg.MAX_PROMPT_LENGTH,
    kv_cache_size=cfg.MAX_PROMPT_LENGTH + cfg.MAX_GENERATION_LENGTH + 256,
    temperature=cfg.TEMPERATURE,
    top_p=cfg.TOP_P,
    top_k=cfg.TOP_K,
    eos_tokens=EOS_TOKENS,
)
print(f"   ‚úÖ RolloutConfig created (max_gen={cfg.MAX_GENERATION_LENGTH})")

# ---------------------------------------------------------------------------
# STEP 6: Create ClusterConfig (role_to_mesh API)
# ---------------------------------------------------------------------------
print("\nüìå Step 6: Creating ClusterConfig...")

# Roles are mapped to meshes through the role_to_mesh dictionary (not the
# older inference_mesh parameter). All three roles share the same mesh here.
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
        # Note: No CRITIC for GRPO (critic-free algorithm)
    },
    rollout_engine='vanilla',
    offload_to_cpu=False,  # Set True if still OOM
    training_config=training_config,
    rollout_config=rollout_config,
)
print("   ‚úÖ ClusterConfig created with role_to_mesh")

# ---------------------------------------------------------------------------
# STEP 7: Create GRPOConfig
# ---------------------------------------------------------------------------
print("\nüìå Step 7: Creating GRPOConfig...")

# Algorithm settings are taken directly from the GRPO section of cfg.
grpo_config = GRPOConfig(
    num_generations=cfg.NUM_GENERATIONS,
    num_iterations=cfg.NUM_ITERATIONS,
    beta=cfg.BETA,
    epsilon=cfg.EPSILON,
)
print(f"   ‚úÖ GRPOConfig created (G={cfg.NUM_GENERATIONS}, Œ≤={cfg.BETA})")

# ---------------------------------------------------------------------------
# STEP 8: Create RLCluster
# ---------------------------------------------------------------------------
print("\nüìå Step 8: Creating RLCluster...")

# The cluster bundles the trainable actor, the frozen reference model, the
# tokenizer and the cluster configuration built above.
rl_cluster = rl_cluster_lib.RLCluster(
    actor=actor,
    reference=reference,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)
print("   ‚úÖ RLCluster created")

# ---------------------------------------------------------------------------
# STEP 9: Create GRPOLearner
# ---------------------------------------------------------------------------
print("\nüìå Step 9: Creating GRPOLearner...")

# Note: parameter is "algo_config", NOT "grpo_config"
# All four reward functions defined in this notebook are passed to the learner.
grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=[
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    algo_config=grpo_config,
)
print("   ‚úÖ GRPOLearner created")

print("\n" + "="*60)
print("üéâ GRPO TRAINER SETUP COMPLETE!")
print("="*60)

monitor.print_summary()

## Cell 11: Prepare Training Data

In [None]:
# =============================================================================
# PREPARE TRAINING DATA
# =============================================================================
import grain

import numpy as np

print("="*60)
print("üìä PREPARING TRAINING DATA")
print("="*60)

# Create training dataset in the format tunix expects: every element is a
# BATCH - a dict whose values are arrays with one entry per example.
# NOTE(review): the Tunix GRPO demo feeds batched grain datasets; confirm the
# learner accepts a plain Python iterable of such dicts.

class GSM8KDataset:
    """Iterable over fixed-size batches of GSM8K prompts and reference answers.

    Args:
        prompts: formatted prompt strings.
        answers: reference answers, aligned with ``prompts``.
        batch_size: examples per yielded batch. A trailing partial batch is
            dropped so that every batch has the same size.

    Raises:
        ValueError: if the two sequences differ in length or ``batch_size``
            is smaller than 1.
    """

    def __init__(self, prompts, answers, batch_size=1):
        if len(prompts) != len(answers):
            raise ValueError(
                f"prompts and answers differ in length: {len(prompts)} != {len(answers)}"
            )
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.prompts = prompts
        self.answers = answers
        self.batch_size = batch_size

    def __len__(self):
        """Number of individual samples (not batches)."""
        return len(self.prompts)

    @property
    def num_batches(self):
        """Number of full batches yielded per pass over the data."""
        return len(self.prompts) // self.batch_size

    def __iter__(self):
        for start in range(0, self.num_batches * self.batch_size, self.batch_size):
            stop = start + self.batch_size
            yield {
                'prompts': np.array(self.prompts[start:stop]),
                'expected_answers': np.array(self.answers[start:stop]),  # For reward function
            }

# Batches are sized to the mini-batch so each yielded batch divides evenly
# into training mini-batches.
train_dataset = GSM8KDataset(train_prompts, train_answers, batch_size=cfg.MINI_BATCH_SIZE)
val_dataset = GSM8KDataset(test_prompts, test_answers, batch_size=cfg.MINI_BATCH_SIZE)

print(f"\n‚úÖ Training dataset: {len(train_dataset)} samples")
print(f"‚úÖ Validation dataset: {len(val_dataset)} samples")

# Calculate effective batches. Each prompt is expanded into NUM_GENERATIONS
# completions, which multiplies the work per step but does not consume extra
# dataset samples - so batches per epoch depends on the batch size only.
effective_batch = cfg.TRAIN_MICRO_BATCH_SIZE * cfg.NUM_GENERATIONS
num_batches = train_dataset.num_batches

print(f"\nEffective batch size: {effective_batch}")
print(f"Number of batches per epoch: {num_batches}")
print(f"Total training steps: {cfg.MAX_STEPS}")

## Cell 12: Run Training

In [None]:
# =============================================================================
# RUN TRAINING
# =============================================================================
import time

print("="*60)
print("üèÉ STARTING GRPO TRAINING")
print("="*60)

# Echo the run settings and reward weights before the long-running call.
for line in (
    f"\nConfiguration:",
    f"  ‚Ä¢ Max steps: {cfg.MAX_STEPS}",
    f"  ‚Ä¢ Eval every: {cfg.EVAL_EVERY_N_STEPS} steps",
    f"  ‚Ä¢ Generations per prompt (G): {cfg.NUM_GENERATIONS}",
    f"  ‚Ä¢ Iterations per batch (Œº): {cfg.NUM_ITERATIONS}",
    f"  ‚Ä¢ Batch size: {cfg.TRAIN_MICRO_BATCH_SIZE}",
    f"\nReward weights:",
    f"  ‚Ä¢ Format: {cfg.REWARD_WEIGHT_FORMAT*100}%",
    f"  ‚Ä¢ Logic: {cfg.REWARD_WEIGHT_LOGIC*100}%",
    f"  ‚Ä¢ Accuracy: {cfg.REWARD_WEIGHT_ACCURACY*100}%",
):
    print(line)

print("\n" + "="*60)
print("‚è≥ First steps will take 10-15 minutes for JIT compilation...")
print("   Subsequent steps will be much faster.")
print("="*60 + "\n")

monitor.print_summary()

start_time = time.time()

try:
    # Training runs inside the mesh context.
    with mesh:
        grpo_trainer.train(train_dataset, val_dataset)

    elapsed = time.time() - start_time

    for line in (
        "\n" + "="*60,
        "üéâ TRAINING COMPLETE!",
        "="*60,
        f"   Total time: {elapsed/60:.1f} minutes",
        f"   Steps completed: {cfg.MAX_STEPS}",
    ):
        print(line)

except KeyboardInterrupt:
    # A manual stop is reported but not re-raised, so later cells still run.
    elapsed = time.time() - start_time
    print(f"\n‚ö†Ô∏è Training interrupted after {elapsed/60:.1f} minutes")

except Exception as e:
    # Any other failure is reported with its elapsed time, then propagated.
    elapsed = time.time() - start_time
    print(f"\n‚ùå Training failed after {elapsed/60:.1f} minutes")
    print(f"   Error: {type(e).__name__}: {e}")
    raise

monitor.print_summary()

## Cell 13: Save Model

In [None]:
# =============================================================================
# SAVE MODEL
# =============================================================================
from orbax import checkpoint as ocp

print("="*60)
print("üíæ SAVING MODEL")
print("="*60)

save_path = os.path.join(cfg.OUTPUT_DIR, "actor_final")
# Only the parent directory is created here: Orbax creates the checkpoint
# directory itself and refuses to save into one that already exists unless
# force=True is passed.
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

with mesh:
    # Get actor state
    _, actor_state = nnx.split(actor)

    # Save with Orbax. force=True replaces a checkpoint left by an earlier run
    # of this cell; wait_until_finished() blocks until the background write
    # has completed, so the files exist before they are listed below.
    checkpointer = ocp.StandardCheckpointer()
    checkpointer.save(save_path, actor_state, force=True)
    checkpointer.wait_until_finished()

print(f"\n‚úÖ Model saved to: {save_path}")
print(f"\nFiles in output directory:")
for entry in os.listdir(cfg.OUTPUT_DIR):
    print(f"   {entry}")

## Cell 14: Test Inference

In [None]:
# =============================================================================
# TEST INFERENCE
# =============================================================================
from tunix.generate import sampler as sampler_lib

import re

print("="*60)
print("üß™ TESTING TRAINED MODEL")
print("="*60)

test_questions = [
    ("If a store sells 150 apples at $4 each, what is the total revenue?", "600"),
    ("A train travels 120 miles in 2 hours. What is its speed in miles per hour?", "60"),
    ("Sarah has 24 cookies. She gives 1/3 to her brother and 1/4 to her sister. How many cookies does she have left?", "10"),
]


def _answers_match(predicted, expected):
    """Return True when ``predicted`` equals ``expected`` textually or numerically."""
    # Ignore thousands separators and a leading currency sign.
    cleaned = predicted.strip().replace(',', '').lstrip('$').strip()
    if cleaned == expected:
        return True
    try:
        return abs(float(cleaned) - float(expected)) < 1e-6
    except ValueError:
        return False


with mesh:
    # The sampler is built from the model plus an explicit KV-cache
    # description, sized like the rollout cache used during training.
    sampler = sampler_lib.Sampler(
        transformer=actor,
        tokenizer=tokenizer,
        cache_config=sampler_lib.CacheConfig(
            cache_size=cfg.MAX_PROMPT_LENGTH + cfg.MAX_GENERATION_LENGTH + 256,
            num_layers=model_config.num_layers,
            num_kv_heads=model_config.num_kv_heads,
            head_dim=model_config.head_dim,
        ),
    )

    for i, (question, expected) in enumerate(test_questions, 1):
        prompt = format_prompt(question)

        print(f"\n{'='*60}")
        print(f"Question {i}: {question}")
        print(f"Expected answer: {expected}")
        print("="*60)

        # The sampler is called with a batch of prompts and returns the
        # generated strings in `.text`, one per prompt.
        result = sampler(
            input_strings=[prompt],
            max_generation_steps=cfg.MAX_GENERATION_LENGTH,
        )
        output = result.text[0]
        print(f"\nModel output:")
        print(output[:1000])

        # Extract and check answer
        match = re.search(r'<answer>(.*?)</answer>', output, re.DOTALL)
        if match:
            predicted = match.group(1).strip()
            status = "‚úÖ" if _answers_match(predicted, expected) else "‚ùå"
            print(f"\nPredicted: {predicted} {status}")
        else:
            print("\n‚ùå No answer tags found")

## 📋 Cell 15: Final Summary

In [None]:
# =============================================================================
# FINAL SUMMARY
# =============================================================================

print("="*60)
print("üèÜ GRPO TRAINING COMPLETE - SUMMARY")
print("="*60)

# Recap of the run settings. Only the mesh shape reflects the detected
# devices; the hardware name in the text below is a fixed string.
print(f"""
Architecture:
  ‚Ä¢ Hardware: TPU v5e-8 ({MESH_SHAPE[0]}√ó{MESH_SHAPE[1]} mesh)
  ‚Ä¢ Model: {cfg.MODEL_HF_NAME}
  ‚Ä¢ Fine-Tuning: LoRA (rank={cfg.LORA_RANK}, alpha={cfg.LORA_ALPHA})

GRPO Configuration:
  ‚Ä¢ Generations (G): {cfg.NUM_GENERATIONS}
  ‚Ä¢ Iterations (Œº): {cfg.NUM_ITERATIONS}
  ‚Ä¢ KL Penalty (Œ≤): {cfg.BETA}
  ‚Ä¢ Clipping (Œµ): {cfg.EPSILON}

Reward Weights:
  ‚Ä¢ Format: {cfg.REWARD_WEIGHT_FORMAT*100}%
  ‚Ä¢ Logic: {cfg.REWARD_WEIGHT_LOGIC*100}%
  ‚Ä¢ Accuracy: {cfg.REWARD_WEIGHT_ACCURACY*100}%

Training:
  ‚Ä¢ Steps: {cfg.MAX_STEPS}
  ‚Ä¢ Batch size: {cfg.TRAIN_MICRO_BATCH_SIZE}
  ‚Ä¢ Learning rate: {cfg.LEARNING_RATE}

Output:
  ‚Ä¢ Checkpoint: {cfg.CKPT_DIR}
  ‚Ä¢ Final model: {cfg.OUTPUT_DIR}
""")

monitor.print_summary()

print("\n" + "="*60)
print("‚úÖ Done! Your model is ready.")
print("="*60)