# 🧠 Gemma2 Reasoning GRPO - Full Production

**Google Tunix Hackathon 2026**

### Core Architecture:
- **Hardware:** TPU v5e-8 (2 Data x 4 Tensor Mesh)
- **Model:** Gemma 2 2B-IT (CPU-Offloaded bf16 Init)
- **Fine-Tuning:** LoRA on Attention + MLP layers
- **Strategy:** GRPO with 8 parallel generations per prompt (`NUM_GENERATIONS` in Cell 3)
- **Rewards:** Format (25%) + Logic (30%) + Accuracy (45%) + Self-Correction & Length Regularization

## 📦 Cell 1: Environment Setup
*Run once, then restart kernel*

In [1]:
# Cell 1: one-shot environment setup.
# Deletes any old "/kaggle/working/.setup_complete*" markers, reinstalls the
# JAX/TPU + Tunix/Qwix/Flax stack and pinned data libraries via %pip, then
# writes a fresh marker file. Restart the kernel afterwards so the newly
# installed packages are the ones imported.
# NOTE(review): `shutil` and `time` are imported but not used in this cell.
import os
import shutil
import glob
import time

# =============================================================================
# FORCE RESET: Smart Cleanup of ALL old markers
# =============================================================================
# This defines the NEW marker we want to create after success
CURRENT_MARKER = "/kaggle/working/.setup_complete_ff"

print("üßπ Cleaning up old installation markers...")

# Use glob to find ANY file starting with ".setup_complete"
old_markers = glob.glob("/kaggle/working/.setup_complete*")

for marker in old_markers:
    # Every match is removed, CURRENT_MARKER included; it is rewritten at the
    # end of this cell. A failed delete is reported but does not stop setup.
    try:
        os.remove(marker)
        print(f"   üóëÔ∏è Deleted old marker: {marker}")
    except OSError as e:
        print(f"   ‚ö†Ô∏è Could not delete {marker}: {e}")

# =============================================================================
# INSTALLATION (Aggressive Cleanup Mode)
# =============================================================================
print("="*60)
print("üîÑ STARTING AGGRESSIVE CLEAN INSTALLATION...")
print("="*60)

# 1. Update pip
%pip install --upgrade pip -q

# 2. AGGRESSIVELY UNINSTALL CONFLICTING PACKAGES
# We must remove these to prevent the "PyExtensionType" error
print("üóëÔ∏è Uninstalling conflicting libraries...")
%pip uninstall -y datasets pyarrow pandas huggingface_hub -q 2>/dev/null

# 3. Install JAX for TPU
print("‚¨áÔ∏è Installing JAX/TPU stack...")
%pip install -q "jax[tpu]>=0.8.0" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# 4. Install Tunix & Qwix
# Flax is uninstalled and then reinstalled from GitHub after Tunix/Qwix.
print("‚¨áÔ∏è Installing Tunix & Qwix...")
%pip install -q git+https://github.com/google/tunix
%pip install -q git+https://github.com/google/qwix
%pip uninstall -q flax -y 2>/dev/null
%pip install -q git+https://github.com/google/flax

# 5. CRITICAL FIX: Install Compatible Data Libraries
# We force specific versions that are known to work together
print("‚¨áÔ∏è Installing Data Libraries (The Fix)...")
%pip install -q "numpy==2.0.0" "pyarrow==17.0.0" "datasets>=2.21.0" "pandas>=2.2.0"
%pip install -q kagglehub transformers grain huggingface_hub tensorflow tensorflow_datasets

# 6. Create the NEW marker
# NOTE(review): the marker is written unconditionally; nothing in this cell
# checks whether the %pip commands above actually succeeded.
with open(CURRENT_MARKER, "w") as f: 
    f.write("done")
    
print("\n" + "="*60)
print("‚úÖ INSTALLATION COMPLETE.")
print("‚ö†Ô∏è PLEASE RESTART KERNEL NOW (‚ü≥ Button)!")
print("="*60)

üßπ Cleaning up old installation markers...
   üóëÔ∏è Deleted old marker: /kaggle/working/.setup_complete_ff
üîÑ STARTING AGGRESSIVE CLEAN INSTALLATION...
[0mNote: you may need to restart the kernel to use updated packages.
üóëÔ∏è Uninstalling conflicting libraries...
Note: you may need to restart the kernel to use updated packages.
‚¨áÔ∏è Installing JAX/TPU stack...
[0mNote: you may need to restart the kernel to use updated packages.
‚¨áÔ∏è Installing Tunix & Qwix...
[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.
‚¨áÔ∏è Installing Data Libraries (The Fix)...
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-metrax 0.2.4

## 🔌 Cell 2: Imports & Memory Monitor

In [2]:
import os, re, gc, time, shutil
from pathlib import Path
from typing import List, Dict, Any

import jax
import jax.numpy as jnp
import numpy as np
import optax
import grain
import qwix
import kagglehub

from flax import nnx
from orbax import checkpoint as ocp
from datasets import load_dataset
from transformers import AutoTokenizer

# Tunix imports
try:
    from tunix.models.gemma2 import model as gemma_model_lib
    from tunix.models.gemma2 import params as gemma_params_lib
    print("‚úì Using tunix.models.gemma2")
except ImportError:
    from tunix.models.gemma import model as gemma_model_lib
    from tunix.models.gemma import params as gemma_params_lib
    print("‚úì Using tunix.models.gemma (fallback)")

from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.generate import sampler as sampler_lib

class MemoryMonitor:
    """Summarise memory usage across all JAX devices.

    Every method is a ``staticmethod``; the class only serves as a namespace.
    Figures are summed over the devices whose ``memory_stats()`` returns a
    non-empty mapping.
    """

    @staticmethod
    def get_usage():
        """Return ``(bytes_in_use, bytes_limit)`` summed over all devices.

        Returns ``(0, 0)`` when no device reports statistics or they cannot be
        read (e.g. a missing key). This is diagnostic code, so such failures
        are swallowed rather than raised.
        """
        try:
            # One memory_stats() call per device; devices that return
            # None or an empty mapping are skipped.
            stats = [s for s in (d.memory_stats() for d in jax.devices()) if s]
            if stats:
                used = sum(s['bytes_in_use'] for s in stats)
                limit = sum(s['bytes_limit'] for s in stats)
                return used, limit
        except Exception:
            # Best-effort only. Narrower than a bare `except:` so that
            # KeyboardInterrupt / SystemExit still propagate.
            pass
        return 0, 0

    @staticmethod
    def print_summary():
        """Print aggregate usage in GB (1e9 bytes), or a note if unavailable."""
        used, limit = MemoryMonitor.get_usage()
        if limit > 0:
            print(f"  TPU Memory: {used/1e9:.2f}GB / {limit/1e9:.2f}GB ({100*used/limit:.1f}%)")
        else:
            print("  TPU Memory: stats unavailable")

    @staticmethod
    def check_available(required_gb=20):
        """Return True if at least ``required_gb`` GB (1e9 bytes) is free in aggregate.

        When statistics are unavailable, both figures are 0. The result is
        then False for any positive ``required_gb``.
        """
        used, limit = MemoryMonitor.get_usage()
        available = (limit - used) / 1e9
        return available >= required_gb

monitor = MemoryMonitor()
# Report the JAX version and the number of visible devices, then a memory snapshot.
print("JAX: {} | TPU Cores: {}".format(jax.__version__, len(jax.devices())))
monitor.print_summary()



‚úì Using tunix.models.gemma (fallback)


E0000 00:00:1768971420.405157    4258 common_lib.cc:650] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:238


JAX: 0.8.2 | TPU Cores: 8
  TPU Memory: 0.00GB / 135.27GB (0.0%)


## ⚙️ Cell 3: Production Configuration
*8 Generations | 600 Steps | Optimized for TPU v5e-8*

In [3]:
# Cell 3: global hyperparameters and paths read by every later cell.
# Side effect: creates CKPT_DIR and OUTPUT_DIR if they do not exist.
# =============================================================================
# MODEL CONFIGURATION
# =============================================================================
MODEL_VERSION = "gemma2-2b-it"
# Kaggle Models handle passed to kagglehub.model_download when loading weights.
MODEL_PATH = "google/gemma-2/flax/gemma2-2b-it"
# Hugging Face repo id used to load the tokenizer.
MODEL_HF_NAME = "google/gemma-2-2b-it"

# =============================================================================
# MESH CONFIGURATION
# =============================================================================
# 2 x 4 = 8 devices, matching the 8 TPU cores reported at import time.
MESH_SHAPE = (2, 4)
MESH_AXES = ("fsdp", "tp")

# =============================================================================
# LORA CONFIGURATION
# =============================================================================
LORA_RANK = 64
LORA_ALPHA = 64.0
# Regex of module paths that receive LoRA adapters.
# NOTE(review): confirm every alternative (e.g. `o_proj`) names a module that
# exists in the tunix Gemma implementation; an alternative that matches
# nothing presumably just leaves that projection without an adapter.
LORA_TARGET_PATTERN = ".*q_einsum|.*kv_einsum|.*o_proj|.*gate_proj|.*up_proj|.*down_proj"

# =============================================================================
# GRPO CONFIGURATION (SAFETY MODE)
# =============================================================================
# Passed to GRPOConfig: group size G (completions sampled per prompt),
# iterations μ, beta, and epsilon (presumably the policy-ratio clipping range).
NUM_GENERATIONS = 8     # Keep at 8
NUM_ITERATIONS = 2      # Reduced to 2 for speed
BETA = 0.04             
EPSILON = 0.2           

# =============================================================================
# TRAINING CONFIGURATION (SAFETY MODE)
# =============================================================================
# NOTE(review): with 448 training prompts, MINI_BATCH_SIZE = 1 and no
# `.repeat()` on the Grain training pipeline, one pass over the data yields
# 448 batches; confirm Tunix can actually reach MAX_STEPS = 600.
MAX_STEPS = 600
LEARNING_RATE = 2e-6
WARMUP_STEPS = 40
WEIGHT_DECAY = 0.01

# CRITICAL CHANGE: Reduced to 1 to prevent OOM
MINI_BATCH_SIZE = 1    
MICRO_BATCH_SIZE = 1    

# =============================================================================
# SEQUENCE CONFIGURATION (SAFETY MODE)
# =============================================================================
# Both limits are in tokens (used as max_prompt_length /
# max_tokens_to_generate in the rollout config).
MAX_PROMPT_LENGTH = 256
# CRITICAL CHANGE: Reduced to 300 (Saves ~2GB VRAM)
MAX_GENERATION_LENGTH = 300

# =============================================================================
# REWARD WEIGHTS
# =============================================================================
# Multipliers for the [0, 1] raw scores of the format, logic and accuracy
# rewards; they sum to 1.0.
REWARD_WEIGHT_FORMAT = 0.25      
REWARD_WEIGHT_LOGIC = 0.30       
REWARD_WEIGHT_ACCURACY = 0.45    

# =============================================================================
# OUTPUT TAGS
# =============================================================================
# Tags the system prompt asks for and the reward functions look for.
TAG_REASONING_START = "<reasoning>"
TAG_REASONING_END = "</reasoning>"
TAG_ANSWER_START = "<answer>"
TAG_ANSWER_END = "</answer>"

# =============================================================================
# PATHS
# =============================================================================
CKPT_DIR = "/kaggle/working/checkpoints"
OUTPUT_DIR = "/kaggle/working/gemma2-grpo-safe"

# =============================================================================
# DATA CONFIGURATION
# =============================================================================
# Requested sizes; the data-loading cell rounds each down to a multiple of 8.
NUM_TRAIN_SAMPLES = 448
NUM_TEST_SAMPLES = 64 
RANDOM_SEED = 42

# Re-import is harmless if os is already imported.
import os
os.makedirs(CKPT_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("="*60)
print("‚úÖ SAFETY CONFIGURATION LOADED")
print("="*60)
print(f"   Batch Size: {MINI_BATCH_SIZE} (Minimum)")
print(f"   Gen Length: {MAX_GENERATION_LENGTH} (Safe)")

‚úÖ SAFETY CONFIGURATION LOADED
   Batch Size: 1 (Minimum)
   Gen Length: 300 (Safe)


## 🎯 Cell 4: Weighted Reward Functions
*Format (25%) + Logic (30%) + Accuracy (45%) + Bonuses*

In [4]:
def format_reward(prompts, completions, **kwargs):
    """
    Score each completion's XML tag layout, scaled by REWARD_WEIGHT_FORMAT.

    Raw score per completion (before scaling):
      * 1.0 -- all four tags present and the first reasoning-end tag comes
               before the first answer-start tag
      * 0.5 -- all four tags present but in the wrong order
      * 0.3 -- exactly one of the two tag pairs is present
      * 0.0 -- otherwise
    `prompts` and extra keyword arguments are accepted for the shared
    reward-function interface but are not used.
    """
    def _raw_score(text):
        reasoning_pair = TAG_REASONING_START in text and TAG_REASONING_END in text
        answer_pair = TAG_ANSWER_START in text and TAG_ANSWER_END in text
        if reasoning_pair and answer_pair:
            ordered = text.find(TAG_REASONING_END) < text.find(TAG_ANSWER_START)
            return 1.0 if ordered else 0.5
        if reasoning_pair or answer_pair:
            return 0.3
        return 0.0

    return [_raw_score(text) * REWARD_WEIGHT_FORMAT for text in completions]


def logic_reward(prompts, completions, **kwargs):
    """
    Reward the quality of the reasoning section, scaled by REWARD_WEIGHT_LOGIC.

    Only text between the first reasoning start/end tag pair (matched
    case-insensitively) is scored; a completion without that pair scores 0.
    Raw score components (the total is capped at 1.0):
      * up to 0.4 -- 0.05 per distinct transition/operation word present,
                     matched as a whole word
      * up to 0.3 -- 0.1 per explicit "step <digit>" marker
      * up to 0.2 -- 0.05 per binary arithmetic expression such as "5 × 10"
      * up to 0.1 -- 0.02 per "=" sign or the word "equals"
    `prompts` and extra keyword arguments are not used.
    """
    TRANSITIONS = [
        'therefore', 'because', 'since', 'so', 'thus', 'hence',
        'first', 'second', 'third', 'then', 'next', 'finally',
        'step', 'calculate', 'compute', 'multiply', 'divide', 'add', 'subtract'
    ]
    # Whole-word patterns: a plain substring test gave free credit for words
    # that merely contain a transition ("so" in "also"/"reason", "add" in
    # "address").
    transition_patterns = [re.compile(rf'\b{re.escape(t)}\b') for t in TRANSITIONS]
    reasoning_pattern = re.compile(
        f'{re.escape(TAG_REASONING_START)}(.*?){re.escape(TAG_REASONING_END)}',
        re.DOTALL | re.IGNORECASE,
    )
    step_pattern = re.compile(r'step\s*\d')
    # Operators include U+00D7 (×) and U+00F7 (÷); these were previously
    # mis-encoded in this character class and therefore never matched.
    math_pattern = re.compile(r'\d+\s*[+\-×÷*/]\s*\d+')

    rewards = []
    for c in completions:
        score = 0.0
        match = reasoning_pattern.search(c)
        if match:
            content = match.group(1).lower()

            # Distinct transition words present (up to 0.4)
            trans_count = sum(1 for p in transition_patterns if p.search(content))
            score += min(0.4, trans_count * 0.05)

            # Explicit steps like "Step 1:", "Step 2:" (up to 0.3)
            step_count = len(step_pattern.findall(content))
            score += min(0.3, step_count * 0.1)

            # Arithmetic expressions (up to 0.2)
            math_ops = len(math_pattern.findall(content))
            score += min(0.2, math_ops * 0.05)

            # "=" or "equals" showing computation (up to 0.1)
            equals_count = content.count('=') + content.count('equals')
            score += min(0.1, equals_count * 0.02)

        rewards.append(min(1.0, score) * REWARD_WEIGHT_LOGIC)
    return rewards


def accuracy_reward(prompts, completions, **kwargs):
    """
    Reward numerically correct final answers, scaled by REWARD_WEIGHT_ACCURACY.

    Expects ``answers`` in kwargs, aligned index-by-index with ``completions``.
    A completion without a matching answer, or without an answer-tag span,
    scores 0. The last number inside the answer span is compared with the
    last number in the ground-truth text:
      * 1.0 -- exact match
      * 0.9 -- ratio pred/truth within [0.99, 1.01]
      * 0.5 -- ratio within [0.9, 1.1]
      * 0.0 -- otherwise (ratios are only tried when |truth| > 0.001)
    Thousands separators ("1,200") are removed on both sides before parsing,
    so they no longer turn "1,200" into 200.
    """
    answers = kwargs.get('answers', [])
    answer_pattern = re.compile(
        f'{re.escape(TAG_ANSWER_START)}(.*?){re.escape(TAG_ANSWER_END)}', re.DOTALL
    )

    def _last_number(text):
        # Drop commas that separate groups of three digits ("1,234,567").
        cleaned = re.sub(r'(?<=\d),(?=\d{3}\b)', '', text.strip())
        nums = re.findall(r'-?\d+\.?\d*', cleaned)
        return float(nums[-1]) if nums else None

    rewards = []
    for i, c in enumerate(completions):
        if i >= len(answers):
            rewards.append(0.0)
            continue

        match = answer_pattern.search(c)
        if not match:
            rewards.append(0.0)
            continue

        pred = _last_number(match.group(1))
        truth = _last_number(str(answers[i]))

        score = 0.0
        if pred is not None and truth is not None:
            if pred == truth:
                score = 1.0  # Exact match
            elif abs(truth) > 0.001:  # guards the division below
                ratio = pred / truth
                if 0.99 <= ratio <= 1.01:
                    score = 0.9  # Within 1%
                elif 0.9 <= ratio <= 1.1:
                    score = 0.5  # Within 10%

        rewards.append(score * REWARD_WEIGHT_ACCURACY)
    return rewards


def self_correction_reward(prompts, completions, **kwargs):
    """
    Additive bonus for visible self-correction, capped at 0.1 per completion.

    Each distinct phrase from a fixed list ("wait", "actually", ...) that
    appears anywhere in the lower-cased completion adds 0.03. Repeats of the
    same phrase count once. `prompts` and extra keyword arguments are unused.
    """
    phrases = (
        "wait", "actually", "let me recalculate", "i made an error",
        "correction", "re-checking", "that's not right", "let me redo",
    )

    def _bonus(text):
        lowered = text.lower()
        hits = len([p for p in phrases if p in lowered])
        return min(0.1, hits * 0.03)

    return list(map(_bonus, completions))


def length_reward(prompts, completions, **kwargs):
    """
    Small additive bonus based on completion length in whitespace-separated words.

    100-400 words -> 0.05; otherwise 50-600 words -> 0.02; anything else -> 0.0.
    `prompts` and extra keyword arguments are unused.
    """
    def _bonus(word_count):
        if 100 <= word_count <= 400:
            return 0.05
        if 50 <= word_count <= 600:
            return 0.02
        return 0.0

    return [_bonus(len(text.split())) for text in completions]


# All reward functions, each returning one score per completion. The first
# three are scaled by the REWARD_WEIGHT_* constants (which sum to 1.0); the
# last two are small additive bonuses (at most 0.1 and 0.05).
REWARD_FUNCTIONS = [
    format_reward,
    logic_reward,
    accuracy_reward,
    self_correction_reward,
    length_reward,
]

print("\n".join([
    "‚úÖ Reward Functions Defined",
    f"   ‚Ä¢ format_reward: {REWARD_WEIGHT_FORMAT*100}% weight",
    f"   ‚Ä¢ logic_reward: {REWARD_WEIGHT_LOGIC*100}% weight",
    f"   ‚Ä¢ accuracy_reward: {REWARD_WEIGHT_ACCURACY*100}% weight",
    "   ‚Ä¢ self_correction_reward: bonus (max 10%)",
    "   ‚Ä¢ length_reward: bonus (max 5%)",
    f"   Total potential: {(REWARD_WEIGHT_FORMAT + REWARD_WEIGHT_LOGIC + REWARD_WEIGHT_ACCURACY)*100}% + 15% bonus",
]))

‚úÖ Reward Functions Defined
   ‚Ä¢ format_reward: 25.0% weight
   ‚Ä¢ logic_reward: 30.0% weight
   ‚Ä¢ accuracy_reward: 45.0% weight
   ‚Ä¢ self_correction_reward: bonus (max 10%)
   ‚Ä¢ length_reward: bonus (max 5%)
   Total potential: 100.0% + 15% bonus


## 📊 Cell 5: Data Loading (GSM8K)

In [5]:
# System prompt teaching the expected output format.
# The worked example uses U+00D7 (×) as its multiplication sign; a
# mis-decoded "√ó" previously stood in its place, showing the model a
# garbled operator.
SYSTEM_PROMPT = f"""You are a mathematical problem solver who shows their work.

For each problem:
1. Think through it step-by-step inside {TAG_REASONING_START} and {TAG_REASONING_END} tags
2. Show your calculations clearly
3. Give your final numerical answer inside {TAG_ANSWER_START} and {TAG_ANSWER_END} tags

Example:
{TAG_REASONING_START}
Step 1: Identify what we need to find.
Step 2: Set up the calculation.
Step 3: 5 × 10 = 50
Therefore, the answer is 50.
{TAG_REASONING_END}
{TAG_ANSWER_START}
50
{TAG_ANSWER_END}"""

def format_prompt(question: str) -> str:
    """Build the model input: SYSTEM_PROMPT, a blank line, the question, then a "Solution:" cue."""
    return "\n".join((SYSTEM_PROMPT, "", f"Question: {question}", "Solution:"))

def extract_answer(answer_text: str) -> str:
    """Return the GSM8K final answer.

    This is the stripped text after the last '####' marker, or the whole
    stripped text when there is no marker.
    """
    _, _, final = answer_text.rpartition('####')
    return final.strip()

# Load GSM8K (train split only) and shuffle it deterministically with
# RANDOM_SEED. Both train_data and test_data are carved out of this split
# as disjoint, consecutive slices.
print("Loading GSM8K dataset...")
# FIX: Use "gsm8k" instead of "openai/gsm8k" to prevent path resolution errors
dataset = load_dataset("gsm8k", "main", split="train")
dataset = dataset.shuffle(seed=RANDOM_SEED)

# Prepare data: each record becomes {'prompt': formatted prompt,
# 'answer': final-answer text after "####"}.
all_data = []
for i, item in enumerate(dataset):
    # Fetch enough data to cover both train and test requests
    # NOTE(review): no filtering is applied below, so the +16 extra rows
    # are simply unused.
    if i >= NUM_TRAIN_SAMPLES + NUM_TEST_SAMPLES + 16: # Buffer for filtering
        break
    all_data.append({
        'prompt': format_prompt(item['question']),
        'answer': extract_answer(item['answer'])
    })

# Round each split down to a multiple of 8 within the requested size.
# NOTE(review): the "required for TPU mesh distribution" rationale is
# unverified here; MINI_BATCH_SIZE is 1, so confirm whether this is needed.
train_size = (min(len(all_data), NUM_TRAIN_SAMPLES) // 8) * 8
remaining_after_train = len(all_data) - train_size
test_size = (min(remaining_after_train, NUM_TEST_SAMPLES) // 8) * 8

train_data = all_data[:train_size]
test_data = all_data[train_size : train_size + test_size]

# Total completions per split (prompts x NUM_GENERATIONS). These are only
# printed below; nothing here asserts alignment.
full_batch_train = len(train_data) * NUM_GENERATIONS
full_batch_test = len(test_data) * NUM_GENERATIONS

print(f"‚úÖ Data Loaded")
print(f"   Train: {len(train_data)} samples √ó {NUM_GENERATIONS} gen = {full_batch_train} (√∑{MINI_BATCH_SIZE}={full_batch_train//MINI_BATCH_SIZE})")
print(f"   Test: {len(test_data)} samples √ó {NUM_GENERATIONS} gen = {full_batch_test} (√∑{MINI_BATCH_SIZE}={full_batch_test//MINI_BATCH_SIZE})")

Loading GSM8K dataset...
‚úÖ Data Loaded
   Train: 448 samples √ó 8 gen = 3584 (√∑1=3584)
   Test: 64 samples √ó 8 gen = 512 (√∑1=512)


## 🔧 Cell 6: Mesh & Tokenizer

In [6]:
# Create the device mesh used by every role (shape/axis names from the config).
mesh = jax.make_mesh(
    MESH_SHAPE,
    MESH_AXES,
    axis_types=(jax.sharding.AxisType.Auto,) * len(MESH_SHAPE)
)
print(f"‚úì Mesh created: {mesh}")

# Hugging Face token: environment variable first, then Kaggle Secrets.
# If neither is available, hf_token stays None and the tokenizer is
# fetched without a token.
hf_token = os.environ.get('HF_TOKEN')
if not hf_token:
    try:
        from kaggle_secrets import UserSecretsClient
        hf_token = UserSecretsClient().get_secret("HF_TOKEN")
        print("‚úì HF Token from Kaggle Secrets")
    except Exception as e:
        print(f"‚ö†Ô∏è No HF_TOKEN found: {e}")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_HF_NAME, token=hf_token)

# Stop generation on the tokenizer's EOS id and, when the vocabulary has it,
# the <end_of_turn> id.
EOS_TOKENS = [tokenizer.eos_token_id]
try:
    end_turn_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
except Exception as e:
    # Best-effort: keep EOS only, but say so instead of silently passing
    # (the previous bare `except:` also swallowed KeyboardInterrupt).
    end_turn_id = None
    print(f"⚠️ Could not resolve <end_of_turn> token id: {e}")
# Skip the id if the lookup failed, mapped to <unk>, or duplicates EOS.
if (
    end_turn_id is not None
    and end_turn_id != tokenizer.unk_token_id
    and end_turn_id not in EOS_TOKENS
):
    EOS_TOKENS.append(end_turn_id)

print(f"‚úÖ Tokenizer Ready | EOS tokens: {EOS_TOKENS}")

‚úì Mesh created: Mesh('fsdp': 2, 'tp': 4, axis_types=(Auto, Auto))
‚úì HF Token from Kaggle Secrets
‚úÖ Tokenizer Ready | EOS tokens: [1, 107]


## 🤖 Cell 7: CPU-Offload Model Loading + LoRA
*Build on CPU first to prevent TPU OOM during initialization*

In [7]:
# =============================================================================
# STEP 1: Clean memory
# =============================================================================
gc.collect()
jax.clear_caches()
print("1. Memory cleared")
monitor.print_summary()

# =============================================================================
# STEP 2: Download model
# =============================================================================
print("\n2. Downloading model from Kaggle...")
k_path = kagglehub.model_download(MODEL_PATH)
# The checkpoint directory is the MODEL_VERSION sub-folder of the download.
ckpt_path = os.path.join(k_path, MODEL_VERSION)
print(f"   Path: {ckpt_path}")

# =============================================================================
# STEP 3: Get CPU device for offloading
# =============================================================================
try:
    cpu_device = jax.devices('cpu')[0]
except (RuntimeError, IndexError):
    # jax.devices('cpu') raises RuntimeError when no CPU backend is available.
    # The fallback is the default device (the accelerator), so the "CPU"
    # offload in the next step would then allocate there instead.
    # Previously a bare `except:` also swallowed KeyboardInterrupt.
    cpu_device = jax.devices()[0]
    print("\n3. No separate CPU device found")
else:
    print(f"\n3. CPU device for offloading: {cpu_device}")

# =============================================================================
# STEP 4: Build model & Apply LoRA (FORCED ON CPU)
# =============================================================================
# CRITICAL FIX: We use 'with jax.default_device(cpu_device)' to ensure 
# initial random weights are allocated on RAM, not TPU VRAM.
# Produces two globals used later: `ref_model` (no LoRA) and `actor` (LoRA).
print("\n4. Building model and applying LoRA on CPU...")

with jax.default_device(cpu_device):
    # A. Build the model structure (allocates on CPU)
    print("   Initializing model structure...")
    config = gemma_model_lib.ModelConfig.gemma2_2b_it()
    base_model = gemma_model_lib.Gemma(config, rngs=nnx.Rngs(RANDOM_SEED))
    
    # B. Load weights (allocates on CPU)
    print("   Loading checkpoint weights...")
    raw_params = gemma_params_lib.load_and_format_params(ckpt_path)
    print("   Converting to bf16...")
    # Casts every leaf of the loaded parameter tree to bfloat16.
    # NOTE(review): assumes every leaf is an array exposing .astype -- confirm.
    bf16_params = jax.tree.map(lambda x: x.astype(jnp.bfloat16), raw_params)
    # Overwrite the randomly initialised weights with the checkpoint values.
    nnx.update(base_model, bf16_params)
    
    # Clean up raw params to free RAM
    del raw_params, bf16_params
    gc.collect()
    
    # C. Split into Reference and Actor
    # NOTE(review): both models are merged from the same (graph, state); confirm
    # nnx.merge yields independent Variables so adding LoRA to the actor
    # cannot alter the reference model.
    print("   Creating Actor/Reference copies...")
    graph, state = nnx.split(base_model)
    ref_model = nnx.merge(graph, state) # Frozen reference
    actor_base = nnx.merge(graph, state) # Base for actor
    
    # D. Apply LoRA to Actor (happens on CPU)
    # Adapters go on modules whose path matches LORA_TARGET_PATTERN.
    print(f"   Applying LoRA (rank={LORA_RANK})...")
    lora_provider = qwix.LoraProvider(
        module_path=LORA_TARGET_PATTERN,
        rank=LORA_RANK,
        alpha=LORA_ALPHA,
    )
    
    # Get shape and apply
    # get_model_input() supplies example inputs that apply_lora_to_model
    # receives as keyword arguments.
    model_input = actor_base.get_model_input()
    actor = qwix.apply_lora_to_model(
        actor_base,
        lora_provider,
        rngs=nnx.Rngs(RANDOM_SEED),
        **model_input
    )
    print("   ‚úì LoRA applied on CPU")

# Clean up base model artifacts
# NOTE(review): `graph` and `state` remain notebook globals after this and
# may keep the base weights' buffers referenced.
del base_model, actor_base
gc.collect()

# =============================================================================
# STEP 5: Shard EVERYTHING to TPU mesh
# =============================================================================
print("\n5. Sharding models to TPU mesh (2√ó4)...")

def shard_model_to_tpu(model, name):
    """Reshard *model*'s state in place according to its partition specs on ``mesh``.

    Inside the ``mesh`` context, this reads the model's current state, derives
    PartitionSpecs with ``nnx.get_partition_spec``, applies them with
    ``jax.lax.with_sharding_constraint`` and writes the result back with
    ``nnx.update``. *name* is used only in the progress message.
    """
    print(f"   Moving {name} to TPU...")
    with mesh:
        model_state = nnx.state(model)
        nnx.update(
            model,
            jax.lax.with_sharding_constraint(model_state, nnx.get_partition_spec(model_state)),
        )

# Move both models onto the device mesh, then report memory.
shard_model_to_tpu(ref_model, "Reference Model")
shard_model_to_tpu(actor, "Actor Model")
print("   ‚úì All models sharded successfully")
monitor.print_summary()

# Final status banner followed by a second memory snapshot.
print("\n".join([
    "\n" + "=" * 60,
    "‚úÖ MODEL READY",
    "=" * 60,
    f"   Actor: Gemma2-2B-IT + LoRA (rank={LORA_RANK})",
    "   Reference: Gemma2-2B-IT (frozen)",
]))
monitor.print_summary()

1. Memory cleared
  TPU Memory: 0.00GB / 135.27GB (0.0%)

2. Downloading model from Kaggle...
   Path: /kaggle/input/gemma-2/flax/gemma2-2b-it/1/gemma2-2b-it

3. CPU device for offloading: TFRT_CPU_0

4. Building model and applying LoRA on CPU...
   Initializing model structure...


ERROR:asyncio:Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x78248676dd40> is already entered
ERROR:asyncio:Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x78248676dd40> is already entered
ERROR:asyncio:Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x78248676dd

   Loading checkpoint weights...
   Converting to bf16...


ERROR:asyncio:Task was destroyed but it is pending!
task: <Task pending name='Task-43' coro=<_async_in_context.<locals>.run_in_context() running at /usr/local/lib/python3.12/site-packages/ipykernel/utils.py:60> wait_for=<Task pending name='Task-2' coro=<Kernel.shell_main() running at /usr/local/lib/python3.12/site-packages/ipykernel/kernelbase.py:597> cb=[Task.task_wakeup()]> cb=[ZMQStream._run_callback.<locals>._log_error() at /usr/local/lib/python3.12/site-packages/zmq/eventloop/zmqstream.py:563]>
  gc.collect()
ERROR:asyncio:Task was destroyed but it is pending!
task: <Task pending name='Task-2' coro=<Kernel.shell_main() running at /usr/local/lib/python3.12/site-packages/ipykernel/kernelbase.py:597> cb=[Task.task_wakeup()]>


   Creating Actor/Reference copies...
   Applying LoRA (rank=64)...
   ‚úì LoRA applied on CPU

5. Sharding models to TPU mesh (2√ó4)...
   Moving Reference Model to TPU...
   Moving Actor Model to TPU...
   ‚úì All models sharded successfully
  TPU Memory: 105.33GB / 135.27GB (77.9%)

‚úÖ MODEL READY
   Actor: Gemma2-2B-IT + LoRA (rank=64)
   Reference: Gemma2-2B-IT (frozen)
  TPU Memory: 105.33GB / 135.27GB (77.9%)


## 📚 Cell 8: Data Loaders (Grain MapDataset)

In [8]:
class GSM8KDataSource:
    """Random-access view over a list of ``{'prompt': ..., 'answer': ...}`` dicts.

    Implements ``__len__`` / ``__getitem__`` so it can back a Grain
    ``MapDataset``. Each record is exposed under the keys ``'prompts'`` and
    ``'answers'``.
    """

    def __init__(self, data):
        self._data = data

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        record = self._data[idx]
        return dict(prompts=record['prompt'], answers=record['answer'])

# Grain MapDataset pipelines (same pattern as the official Tunix GRPO demo).
# The training set is shuffled with RANDOM_SEED; the validation set keeps its
# order. Both are grouped into batches of MINI_BATCH_SIZE prompts, and an
# incomplete final batch is dropped.
# NOTE(review): neither pipeline calls .repeat(), so one pass yields
# len(train_data) // MINI_BATCH_SIZE training batches; confirm this covers
# MAX_STEPS.
train_dataset = grain.MapDataset.source(GSM8KDataSource(train_data)).shuffle(
    seed=RANDOM_SEED
).batch(MINI_BATCH_SIZE, drop_remainder=True)

val_dataset = grain.MapDataset.source(GSM8KDataSource(test_data)).batch(
    MINI_BATCH_SIZE, drop_remainder=True
)

print("\n".join([
    "‚úÖ Grain Datasets Created (Batched)",
    f"   Train: {len(train_data)} prompts -> {len(train_data)//MINI_BATCH_SIZE} batches",
    f"   Val: {len(test_data)} prompts -> {len(test_data)//MINI_BATCH_SIZE} batches",
]))

‚úÖ Grain Datasets Created (Batched)
   Train: 448 prompts -> 448 batches
   Val: 64 prompts -> 64 batches


## ⚡ Cell 9: Training Configuration

In [9]:
# Learning rate: linear warmup from 0 to LEARNING_RATE over WARMUP_STEPS, then
# cosine decay to 10% of the peak. decay_steps is the schedule's total length
# (MAX_STEPS), warmup included.
lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    decay_steps=MAX_STEPS,
    end_value=LEARNING_RATE * 0.1
)

# AdamW on the schedule above, after clipping the global gradient norm to 1.0.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(lr_schedule, weight_decay=WEIGHT_DECAY)
)

# Evaluation cadence; a single constant, shared by the cluster config and the
# summary print so the two cannot drift apart.
EVAL_EVERY_N_STEPS = 25

# Actor, reference and rollout all run on the same `mesh`.
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        max_steps=MAX_STEPS,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        mini_batch_size=MINI_BATCH_SIZE,
        train_micro_batch_size=MICRO_BATCH_SIZE,
        checkpoint_root_directory=CKPT_DIR,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=MAX_GENERATION_LENGTH,
        max_prompt_length=MAX_PROMPT_LENGTH,
        temperature=0.9,
        top_p=0.95,
        eos_tokens=EOS_TOKENS,
    )
)

# GRPO hyperparameters, taken from the configuration constants (the old
# inline comments quoted stale values such as G=16 and μ=3).
grpo_config = GRPOConfig(
    num_generations=NUM_GENERATIONS,  # G: completions sampled per prompt
    num_iterations=NUM_ITERATIONS,    # μ
    beta=BETA,                        # β
    epsilon=EPSILON,                  # ε
)

print("‚úÖ Training Configuration")
print(f"   LR Schedule: warmup {WARMUP_STEPS} ‚Üí {LEARNING_RATE} ‚Üí cosine decay")
print(f"   Optimizer: AdamW (clip=1.0, decay={WEIGHT_DECAY})")
print(f"   GRPO: G={NUM_GENERATIONS}, Œº={NUM_ITERATIONS}, Œ≤={BETA}, Œµ={EPSILON}")
print(f"   Eval every {EVAL_EVERY_N_STEPS} steps")

‚úÖ Training Configuration
   LR Schedule: warmup 40 ‚Üí 2e-06 ‚Üí cosine decay
   Optimizer: AdamW (clip=1.0, decay=0.01)
   GRPO: G=8, Œº=2, Œ≤=0.04, Œµ=0.2
   Eval every 25 steps


## 🎓 Cell 10: Create GRPO Trainer

In [12]:
# =============================================================================
# STEP 1: Smart Imports & Component Recovery
# =============================================================================
# Locates Tunix config classes by name instead of importing them directly.
# Results: ClusterConfigClass (required -- ImportError if missing) and
# TrainingConfigClass (optional -- stays None if no match).
# NOTE(review): Cell 9 already uses rl_cluster_lib.ClusterConfig and
# rl_cluster_lib.RLTrainingConfig directly; this discovery may be redundant.
import tunix
import inspect
import gc
import jax
import optax  # Needed for the Real Optimizer Fix
from tunix.rl.grpo import grpo_learner
from tunix.rl import rl_cluster
from tunix.rl import rl_learner
from tunix.rl.rollout import base_rollout

print("1. Initializing Tunix components...")

# --- FIX 0: CLEANUP MEMORY ---
gc.collect()
jax.clear_caches()

# --- FIX 1: Find ClusterConfig ---
# inspect.getmembers returns members sorted by name, so this picks the
# alphabetically first class whose name contains "ClusterConfig".
# Note: the loop variables `name` / `obj` remain notebook globals.
ClusterConfigClass = None
for name, obj in inspect.getmembers(rl_cluster):
    if inspect.isclass(obj) and "ClusterConfig" in name:
        ClusterConfigClass = obj
        break
if not ClusterConfigClass:
    raise ImportError("CRITICAL: No ClusterConfig class found.")

# --- FIX 2: Find TrainingConfig ---
# Searches rl_learner, grpo_learner and (if importable) tunix.common.config,
# in that order.
# NOTE(review): rl_cluster, where Cell 9 finds RLTrainingConfig, is not in
# this list; it is only found here if one of these modules re-exports it.
TrainingConfigClass = None
search_modules = [rl_learner, grpo_learner]
try:
    import tunix.common.config as common_config
    search_modules.append(common_config)
except ImportError: pass

# The inner `break` leaves only the member loop; the outer check stops at the
# first module that produced a match.
for mod in search_modules:
    for name, obj in inspect.getmembers(mod):
        if inspect.isclass(obj) and "TrainingConfig" in name:
            TrainingConfigClass = obj
            break
    if TrainingConfigClass: break
# --- FIX 3: Roles ---
def get_valid_role(options, role_enum=None):
    """Return the first attribute of ``role_enum`` named by ``options``.

    Role member names differ between Tunix releases, so each role is looked
    up by a list of candidate names in priority order.

    Args:
        options: Candidate attribute names, highest priority first.
        role_enum: Object to look the names up on. Defaults to ``tunix.Role``,
            resolved at call time.

    Returns:
        The first attribute of ``role_enum`` whose name is in ``options``.
        If none match, prints a warning and falls back to ``options[0]`` (a
        plain string). A string may not be accepted where a role enum member
        is expected (e.g. as a ``role_to_mesh`` key) — verify before relying
        on it.

    Raises:
        ValueError: if ``options`` is empty.
    """
    if not options:
        raise ValueError("get_valid_role() needs at least one candidate role name")
    if role_enum is None:
        role_enum = tunix.Role

    for candidate in options:
        if hasattr(role_enum, candidate):
            return getattr(role_enum, candidate)

    # Keep the original best-effort fallback, but make it visible instead of
    # silently substituting a string for an enum member.
    print(
        f"   WARNING: none of {list(options)} found on {role_enum!r}; "
        f"falling back to the plain string {options[0]!r}."
    )
    return options[0]

# Resolve every role against whatever member names this Tunix build exposes;
# each candidate list is tried in priority order by get_valid_role.
ACTOR_ROLE, CRITIC_ROLE, ROLLOUT_ROLE, REFERENCE_ROLE = (
    get_valid_role(candidates)
    for candidates in (
        ["ACTOR_TRAIN", "ACTOR", "POLICY"],
        ["CRITIC_TRAIN", "CRITIC", "VALUE"],
        ["ROLLOUT", "INFERENCE"],
        ["REFERENCE", "REF"],
    )
)

# =============================================================================
# STEP 2: Configure Training (The Deep Universal Mock)
# =============================================================================
print("2. Building Training Configuration...")

# Keyword arguments handed to whichever training-config class is used below.
# learning_rate / weight_decay are read by the fallback mock to build its
# optimizer.
train_args = dict(
    total_steps=MAX_STEPS,
    save_every_steps=100,
    checkpoint_dir=CKPT_DIR,
    rollout_micro_batch_size=MINI_BATCH_SIZE,
    compute_logps_micro_batch_size=MINI_BATCH_SIZE,
    train_micro_batch_size=MINI_BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

if TrainingConfigClass:
    # --- PATH A: Use the real Tunix class ---
    # NOTE(review): train_args contains keys such as learning_rate and
    # weight_decay that the real class may not accept — confirm against its
    # signature.
    training_config = TrainingConfigClass(**train_args)
else:
    # --- PATH B: Deep Universal Mock ---
    print("   ‚ö†Ô∏è TrainingConfig not found. Using Deep Universal Mock.")
    from types import SimpleNamespace

    class DeepUniversalTrainingConfig:
        """Duck-typed stand-in for Tunix's TrainingConfig.

        Every keyword argument becomes an attribute. Attributes the trainer
        has been observed to read are filled with defaults when not supplied,
        and real Optax optimizers are built for the actor and critic.
        """

        def __init__(self, **kwargs):
            # Absorb all provided arguments as attributes.
            self.__dict__.update(kwargs)

            # Scalar attributes the trainer reads; explicit kwargs win.
            self.__dict__.setdefault("gradient_accumulation_steps", 1)
            self.__dict__.setdefault("max_grad_norm", 1.0)
            self.__dict__.setdefault(
                "checkpoint_root_directory", kwargs.get("checkpoint_dir", CKPT_DIR)
            )
            self.__dict__.setdefault("metrics_logging_options", None)

            # The trainer expects actual Optax transformations, not a config
            # dict.
            lr = kwargs.get("learning_rate", 2e-6)
            wd = kwargs.get("weight_decay", 0.01)

            # Clip with self.max_grad_norm (previously a hard-coded 1.0) so
            # the attribute the trainer sees and the clipping actually applied
            # cannot disagree.
            # NOTE(review): this is a constant learning rate; the configuration
            # summary printed earlier advertises warmup + cosine decay.
            optimizer = optax.chain(
                optax.clip_by_global_norm(self.max_grad_norm),
                optax.adamw(learning_rate=lr, weight_decay=wd),
            )
            # Optax transformations keep their state in a separately
            # initialised opt_state, so sharing one object is safe.
            self.actor_optimizer = optimizer
            self.critic_optimizer = optimizer

            # Plain attribute bag mirroring the optimizer settings, kept for
            # code that reads optimizer_config.learning_rate / .weight_decay.
            self.optimizer_config = SimpleNamespace(learning_rate=lr, weight_decay=wd)

    training_config = DeepUniversalTrainingConfig(**train_args)

# =============================================================================
# STEP 3: Configure Cluster (Static Memory)
# =============================================================================
print("3. Building Cluster Configuration...")

# Every role (actor, critic, rollout, reference) is placed on the same mesh.
role_to_mesh_dict = {
    role: mesh
    for role in (ACTOR_ROLE, CRITIC_ROLE, ROLLOUT_ROLE, REFERENCE_ROLE)
}

cluster_config = ClusterConfigClass(
    role_to_mesh=role_to_mesh_dict,
    training_config=training_config,
    # offload_to_cpu=False is the author's "static memory" mode, described in
    # the original as crucial for stability.
    offload_to_cpu=False,
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=MAX_GENERATION_LENGTH,
        max_prompt_length=MAX_PROMPT_LENGTH,
        # Sampling: temperature 1.0, top-k 50, nucleus (top-p) 0.95.
        temperature=1.0,
        top_k=50,
        top_p=0.95,
        # 107 is presumably Gemma's <end_of_turn> token id — TODO confirm
        # against the tokenizer vocabulary.
        eos_tokens=[tokenizer.eos_token_id, 107],
        seed=RANDOM_SEED,
    ),
)

# =============================================================================
# STEP 4: Instantiate Components
# =============================================================================
print("4. Instantiating Cluster...")

# No critic and no reward model are attached. GRPO uses group-relative
# advantages instead of a value model, and the rewards come from the
# reward_functions handed to GRPOLearner.
cluster = rl_cluster.RLCluster(
    actor=actor,
    critic=None,
    reference=ref_model,
    reward=None,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)

print("5. Building GRPO Trainer...")

algorithm_config = grpo_learner.GRPOConfig(
    num_generations=NUM_GENERATIONS,  # G: completions sampled per prompt
    num_iterations=NUM_ITERATIONS,    # mu: optimisation passes per batch
    beta=BETA,                        # KL penalty coefficient
    epsilon=EPSILON,                  # clipping range
)

# Reward functions and weights are index-aligned: format, logic, accuracy.
grpo_trainer = grpo_learner.GRPOLearner(
    rl_cluster=cluster,
    algorithm_config=algorithm_config,
    training_config=training_config,
    reward_functions=[format_reward_func, logic_reward_func, accuracy_reward_func],
    reward_weights=[REWARD_WEIGHT_FORMAT, REWARD_WEIGHT_LOGIC, REWARD_WEIGHT_ACCURACY],
)

print("\n" + "=" * 60, "‚úÖ GRPO TRAINER READY", "=" * 60, sep="\n")
monitor.print_summary()



1. Initializing Tunix components...
2. Building Training Configuration...
   ‚ö†Ô∏è TrainingConfig not found. Using Universal Mock.
3. Building Cluster Configuration...
4. Instantiating Cluster...


AttributeError: 'UniversalTrainingConfig' object has no attribute 'gradient_accumulation_steps'

## 🚀 Cell 11: Training Loop
*First steps take 10-15 minutes for JIT compilation*

In [None]:
# Print a banner describing the run, execute GRPO training under the mesh,
# then report the elapsed time.
print("=" * 60)
print("üöÄ STARTING GRPO TRAINING")
print("=" * 60)
print()
print("Configuration:")
print(f"  ‚Ä¢ Steps: {MAX_STEPS}")
print(f"  ‚Ä¢ Generations per prompt: {NUM_GENERATIONS}")
print(f"  ‚Ä¢ Iterations per batch: {NUM_ITERATIONS}")
print(f"  ‚Ä¢ Train samples: {NUM_TRAIN_SAMPLES}")
print()
print("Reward Weights:")
# Use the percent format spec instead of `w*100`: floats such as 0.3 give
# 0.3 * 100 == 30.000000000000004, which leaked into the printed percentages.
print(f"  ‚Ä¢ Format: {REWARD_WEIGHT_FORMAT:.1%}")
print(f"  ‚Ä¢ Logic: {REWARD_WEIGHT_LOGIC:.1%}")
print(f"  ‚Ä¢ Accuracy: {REWARD_WEIGHT_ACCURACY:.1%}")
print()
print("=" * 60)
print("\n‚è≥ First steps will take 10-15 minutes for JIT compilation...")
print("   Subsequent steps will be much faster.\n")

monitor.print_summary()
# perf_counter is monotonic; time.time() can jump if the wall clock is
# adjusted during a long run.
start_time = time.perf_counter()

# Run training
with mesh:
    grpo_trainer.train(train_dataset, val_dataset)

elapsed = time.perf_counter() - start_time
print("\n" + "=" * 60)
print("üéâ TRAINING COMPLETE")
print("=" * 60)
print(f"   Total time: {elapsed/60:.1f} minutes")
# NOTE(review): this reports the configured MAX_STEPS, not a count read back
# from the trainer.
print(f"   Steps completed: {MAX_STEPS}")
monitor.print_summary()

## 💾 Cell 12: Save Model

In [None]:
print("Saving trained model...")

os.makedirs(OUTPUT_DIR, exist_ok=True)
# Orbax may reject relative checkpoint paths, so normalise to an absolute one.
save_path = os.path.abspath(os.path.join(OUTPUT_DIR, "actor_state"))

with mesh:
    # nnx.split returns (graphdef, state); only the state is persisted here.
    _, actor_state = nnx.split(actor)
    checkpointer = ocp.StandardCheckpointer()
    try:
        # force=True lets this cell be re-run. Without it, Orbax refuses to
        # save into an already-existing checkpoint directory. Any previous
        # checkpoint at save_path is replaced.
        checkpointer.save(save_path, actor_state, force=True)
    finally:
        # StandardCheckpointer can save asynchronously. close() waits for
        # outstanding writes, so the success message below is only printed
        # once the checkpoint is complete.
        checkpointer.close()

print(f"‚úÖ Model saved to {OUTPUT_DIR}")
print(f"   Checkpoint: {save_path}")

## 🧪 Cell 13: Test Inference

In [None]:
# (question, expected final answer) pairs for a quick qualitative check.
# The expected answer is printed for eyeballing only; it is not compared.
test_questions = [
    ("If a store sells 150 apples at $4 each, what is the total revenue?", "600"),
    ("A train travels 120 miles in 2 hours. What is its speed in miles per hour?", "60"),
    ("Sarah has 24 cookies. She gives 1/3 to her brother and 1/4 to her sister. How many cookies does she have left?", "10"),
]

print("=" * 60)
print("üß™ TESTING TRAINED MODEL")
print("=" * 60 + "\n")

with mesh:
    sampler = sampler_lib.Sampler(
        model=actor, tokenizer=tokenizer, max_tokens=MAX_GENERATION_LENGTH,
    )

    for idx, (question, expected) in enumerate(test_questions, start=1):
        output = sampler.generate(format_prompt(question))

        print(f"Question {idx}: {question}")
        print(f"Expected: {expected}")
        print("Output:")
        # Truncated to the first 800 elements (characters, assuming
        # generate() returns a string) to keep the log readable.
        print(output[:800])
        print("-" * 60 + "\n")

## 📋 Cell 14: Final Summary

In [None]:
# Final report of the architecture, GRPO hyper-parameters, reward weighting
# and training run that produced the saved model.
print("=" * 60)
print("üèÜ PRODUCTION REASONING MODEL COMPLETE")
print("=" * 60)
print()
print("Architecture:")
print("  ‚Ä¢ Hardware: TPU v5e-8 (2√ó4 mesh)")
print(f"  ‚Ä¢ Model: {MODEL_HF_NAME}")
print(f"  ‚Ä¢ Fine-Tuning: LoRA (rank={LORA_RANK}, alpha={LORA_ALPHA})")
print("  ‚Ä¢ Targets: Attention + MLP layers")
print()
print("GRPO Configuration:")
print(f"  ‚Ä¢ Generations (G): {NUM_GENERATIONS}")
print(f"  ‚Ä¢ Iterations (Œº): {NUM_ITERATIONS}")
print(f"  ‚Ä¢ KL Penalty (Œ≤): {BETA}")
print(f"  ‚Ä¢ Clipping (Œµ): {EPSILON}")
print()
print("Reward Weights:")
# Percent format spec instead of `w*100`, which prints float noise such as
# 30.000000000000004% for a weight of 0.3.
print(f"  ‚Ä¢ Format: {REWARD_WEIGHT_FORMAT:.1%}")
print(f"  ‚Ä¢ Logic: {REWARD_WEIGHT_LOGIC:.1%}")
print(f"  ‚Ä¢ Accuracy: {REWARD_WEIGHT_ACCURACY:.1%}")
print("  ‚Ä¢ Self-Correction: bonus")
print("  ‚Ä¢ Length: bonus")
print()
print("Training:")
print(f"  ‚Ä¢ Steps: {MAX_STEPS}")
print(f"  ‚Ä¢ Samples: {NUM_TRAIN_SAMPLES} train, {NUM_TEST_SAMPLES} test")
print(f"  ‚Ä¢ Output: {OUTPUT_DIR}")
print()
monitor.print_summary()
print("=" * 60)