# Qwen 1.5B GRPO Training on Colab TPU

This notebook runs GRPO on **Qwen2.5-Coder-1.5B-Instruct** using **TPU v2-8**.

## TPU vs GPU

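TPUs have more total memory (8 GB per core x 8 cores = 64 GB). This notebook, however, runs on a single core obtained with `xm.xla_device()`, so only that core's 8 GB is available to it. TPUs also require special setup: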
- Use `torch_xla` for TPU support
- Models need to be moved to TPU device explicitly
- Some operations work differently

## Goal

Improve Qwen 1.5B on Trapping Rain Water, the only problem it did not fully solve (baseline: 4/5 = 80%).

---

## Step 1: Install Dependencies

Before running this cell, make sure you selected the **TPU** runtime:
- Runtime -> Change runtime type -> TPU

In [None]:
import os
import subprocess
import sys

print("=" * 60)
print("STEP 1: INSTALL DEPENDENCIES")
print("=" * 60)


def _pip_install(*packages_and_flags):
    """Install packages with the pip that belongs to the running interpreter.

    Returns True when pip exits with status 0. A failure is reported rather
    than raised, so the notebook can still continue (e.g. on a GPU runtime).
    """
    # Passing a list (no shell) keeps specifiers like `torch_xla[tpu]` away
    # from shell globbing, and `sys.executable -m pip` targets this kernel.
    cmd = [sys.executable, "-m", "pip", "install", "-q", *packages_and_flags]
    result = subprocess.run(cmd)
    if result.returncode != 0:
        print(f"⚠️  pip install failed (exit code {result.returncode}): {' '.join(packages_and_flags)}")
    return result.returncode == 0


# Check if we might be on TPU (install torch_xla)
might_be_tpu = 'COLAB_TPU_ADDR' in os.environ or 'TPU_NAME' in os.environ

if might_be_tpu:
    print("TPU environment variable detected, installing torch_xla...")
    # NOTE(review): this pins torch~=2.1.0, which may replace the runtime's
    # preinstalled torch; if torch was already imported in this session,
    # restarting the runtime is likely needed — confirm on Colab.
    _pip_install(
        "torch~=2.1.0",
        "torch_xla[tpu]~=2.1.0",
        "-f",
        "https://storage.googleapis.com/libtpu-releases/index.html",
    )
else:
    print("No TPU env var found, skipping torch_xla")

if _pip_install("transformers", "accelerate", "sentencepiece"):
    print("Core dependencies installed!")

## Step 2: Detect Runtime

In [None]:
print("=" * 60)
print("STEP 2: DETECT RUNTIME")
print("=" * 60)

USE_TPU = False

# Environment variables that mark a TPU runtime, checked in priority order:
# COLAB_TPU_ADDR (TPU v2/v3) first, then TPU_NAME (TPU v4/newer).
for _env_var, _label in (("COLAB_TPU_ADDR", "TPU Address"), ("TPU_NAME", "TPU Name")):
    if _env_var in os.environ:
        print(f"{_label}: {os.environ[_env_var]}")
        USE_TPU = True
        break
else:
    # Neither variable is set: ask torch_xla directly (only works if installed).
    try:
        import torch_xla
        import torch_xla.core.xla_model as xm
        device = xm.xla_device()
        print(f"TPU detected via torch_xla: {device}")
        USE_TPU = True
    except ImportError:
        print("torch_xla not installed (not on TPU)")
    except Exception as e:
        print(f"torch_xla error: {e}")

if not USE_TPU:
    # No TPU: report a CUDA GPU if one exists, otherwise warn about CPU.
    import torch
    if not torch.cuda.is_available():
        print("\n⚠️  No TPU or GPU found!")
        print("   Go to Runtime -> Change runtime type -> TPU or GPU")
        print("   Continuing anyway (will use CPU, very slow)...")
    else:
        print(f"\n✅ GPU detected: {torch.cuda.get_device_name(0)}")
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"   GPU Memory: {gpu_mem:.1f} GB")
else:
    print("\n✅ TPU detected! Using TPU runtime.")

print(f"\nUSE_TPU = {USE_TPU}")

## Step 3: Setup TPU Device

In [None]:
import gc
import torch

if not USE_TPU:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
else:
    import torch_xla
    import torch_xla.core.xla_model as xm

    # The XLA device every tensor and the model are moved onto
    device = xm.xla_device()
    print(f"TPU Device: {device}")


def clear_memory():
    """Run garbage collection, then do backend-specific cleanup.

    On TPU this syncs via ``xm.mark_step()``; on a CUDA machine it calls
    ``torch.cuda.empty_cache()``; on CPU only ``gc.collect()`` runs.
    """
    gc.collect()
    if USE_TPU:
        xm.mark_step()
        return
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


print(f"\nDevice ready: {device}")

## Step 4: Configuration

In [None]:
CONFIG = dict(
    model_name="Qwen/Qwen2.5-Coder-1.5B-Instruct",
    num_steps=10,          # GRPO updates; each step draws one new problem
    num_generations=4,     # completions sampled per problem (the GRPO group size)
    learning_rate=5e-5,    # AdamW learning rate
    beta=0.04,             # weight of the KL penalty toward the reference model
    max_new_tokens=512,    # cap on generated tokens per completion
    difficulty=5,          # forwarded to TrappingRainWaterGenerator.generate
    num_test_cases=5,      # test cases per generated problem
    seed=42,               # base seed; the training/eval generators offset it
)

print("=" * 60)
print("CONFIGURATION")
print("=" * 60)
for name, value in CONFIG.items():
    print(f"  {name:20} = {value}")

## Step 5: Problem Generator

In [None]:
import random
from dataclasses import dataclass
from typing import List, Any

@dataclass
class TestCase:
    """A single example used to grade a generated solution."""

    # Positional arguments the solution is called with (e.g. [heights]).
    input_args: List[Any]
    # Value the solution must return for those arguments.
    expected_output: Any

@dataclass
class Problem:
    """A generated coding problem together with the test cases that grade it."""

    problem_id: str
    title: str
    description: str
    function_signature: str
    function_name: str
    # Forward reference so this class does not depend on TestCase being
    # defined first.
    test_cases: List["TestCase"]

    def to_prompt(self) -> str:
        """Render the problem as the markdown prompt shown to the model.

        Up to the first three test cases are listed as examples in the form
        ``  name(arg1, arg2, ...) -> expected``. Every positional argument is
        rendered, so multi-argument problems display correctly; for
        single-argument problems the text is identical to showing only the
        first argument.
        """
        example_lines = []
        for tc in self.test_cases[:3]:
            call_args = ", ".join(repr(arg) for arg in tc.input_args)
            example_lines.append(f"  {self.function_name}({call_args}) -> {tc.expected_output!r}")
        examples = "\n".join(example_lines)
        return f"""## {self.title}

{self.description}

### Function Signature
```python
{self.function_signature}
```

### Examples
```python
{examples}
```

Write ONLY the function implementation."""


class TrappingRainWaterGenerator:
    """Builds seeded, randomized "Trapping Rain Water" problems with known answers."""

    def __init__(self, seed=None):
        # Private RNG so the generated problems are reproducible per seed.
        self.rng = random.Random(seed)
        # Number of problems produced so far; used to build problem ids.
        self._counter = 0

    def _solve(self, height):
        """Reference answer: total trapped water via a two-pointer sweep."""
        if not height:
            return 0
        lo, hi = 0, len(height) - 1
        best_lo = best_hi = total = 0
        while lo < hi:
            if height[lo] < height[hi]:
                # The left wall is the limiting side; water above this bar is
                # bounded by the tallest bar seen from the left.
                best_lo = max(best_lo, height[lo])
                total += best_lo - height[lo]
                lo += 1
            else:
                best_hi = max(best_hi, height[hi])
                total += best_hi - height[hi]
                hi -= 1
        return total

    def generate(self, difficulty=5, num_test_cases=5):
        """Return a new Problem with `num_test_cases` random elevation maps.

        Each map has 5 + 2 * difficulty bars, with heights drawn uniformly
        from 0 to 3 + difficulty inclusive.
        """
        self._counter += 1
        size = 5 + 2 * difficulty
        ceiling = 3 + difficulty
        cases = []
        for _ in range(num_test_cases):
            row = [self.rng.randint(0, ceiling) for _ in range(size)]
            cases.append(TestCase([row], self._solve(row)))

        return Problem(
            problem_id=f"trap_{self._counter}",
            title="Trapping Rain Water",
            description="Given n non-negative integers representing an elevation map, compute how much water it can trap after raining.",
            function_signature="def trap(height: list) -> int:",
            function_name="trap",
            test_cases=cases,
        )

# Smoke test: build one problem from a seeded generator using default settings.
gen = TrappingRainWaterGenerator(seed=42)
p = gen.generate()
n_cases = len(p.test_cases)
print(f"Generated problem with {n_cases} test cases")

## Step 6: Load Model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

print("=" * 60)
print("LOADING MODEL")
print("=" * 60)
print(f"Model: {CONFIG['model_name']}")
print("Loading... (this takes 2-3 minutes on TPU)")

tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_name"])
if tokenizer.pad_token is None:
    # generate() needs a pad id; fall back to the EOS token when none is defined.
    tokenizer.pad_token = tokenizer.eos_token
print("Tokenizer loaded!")

# Weight dtype per backend:
#   TPU  -> float32 (TPUs do not natively support float16)
#   CUDA -> float16 to halve memory
#   CPU  -> float32; half precision is slow or unsupported for some ops on CPU
#           in many PyTorch builds, which would break the CPU fallback path.
if USE_TPU:
    _model_dtype = torch.float32
elif torch.cuda.is_available():
    _model_dtype = torch.float16
else:
    _model_dtype = torch.float32

model = AutoModelForCausalLM.from_pretrained(
    CONFIG["model_name"],
    torch_dtype=_model_dtype,
)

# Move to device
model = model.to(device)
print(f"Model loaded on: {device}")
print(f"Parameters: {sum(param.numel() for param in model.parameters()) / 1e9:.2f}B")

## Step 7: Helper Functions

In [None]:
import re

def extract_code(response):
    """Pull the Python source out of a model response.

    Preference order:
      1. the last fenced block tagged ```python or ```py (case-insensitive),
      2. the last untagged fenced block,
      3. the text from the first ``def`` line up to the first unindented line
         that no longer looks like code (so trailing prose such as an
         explanation after the function does not break exec()),
      4. the raw response unchanged.
    """
    import textwrap

    for pattern in (r"```(?:python|py)[ \t]*\r?\n(.*?)```", r"```[ \t]*\r?\n(.*?)```"):
        matches = re.findall(pattern, response, re.DOTALL | re.IGNORECASE)
        if matches:
            return matches[-1].strip()

    if "def " in response:
        collected = []
        for line in response.split("\n"):
            stripped = line.strip()
            if not collected:
                # Skip everything before the first function definition.
                if stripped.startswith("def "):
                    collected.append(line)
                continue
            # Inside the code: keep blank lines, indented lines and further
            # top-level definitions; stop at the first unindented prose line.
            if not stripped or line[:1] in (" ", "\t") or stripped.startswith(("def ", "class ", "@")):
                collected.append(line)
            else:
                break
        if collected:
            # dedent handles a def that was itself indented in the response.
            return textwrap.dedent("\n".join(collected)).strip()
    return response

def verify(code, problem, timeout=2.0):
    """Execute generated `code` and grade it against `problem.test_cases`.

    Args:
        code: Python source produced by the model.
        problem: object with `function_name` and `test_cases`; each test case
            has `input_args` (list of positional args) and `expected_output`.
        timeout: per-test-case wall-clock limit in seconds, enforced only where
            SIGALRM is usable (POSIX, main thread). None disables it.

    Returns:
        (all_passed, fraction_passed). Code that fails to exec, defines no
        usable function, or a problem with no test cases yields (False, 0.0).
    """
    import types

    # WARNING: exec() runs untrusted model output in this process with full
    # privileges (and without a timeout for top-level statements). Acceptable
    # on a throwaway Colab runtime, not on a shared machine.
    namespace = {}
    try:
        exec(code, namespace)
    except KeyboardInterrupt:
        raise
    except BaseException:  # SyntaxError, SystemExit from exit(), anything else
        return False, 0.0

    func = namespace.get(problem.function_name)
    if not callable(func):
        # Fall back to a function the generated source itself defined.
        # co_filename "<string>" excludes imported names such as typing.List.
        defined = [
            value
            for name, value in namespace.items()
            if isinstance(value, types.FunctionType)
            and not name.startswith("_")
            and value.__code__.co_filename == "<string>"
        ]
        func = defined[0] if defined else None
    if func is None:
        return False, 0.0

    total = len(problem.test_cases)
    if total == 0:
        return False, 0.0
    passed = sum(
        1
        for tc in problem.test_cases
        if _safe_call(func, tc.input_args, timeout=timeout) == tc.expected_output
    )
    return passed == total, passed / total


def _safe_call(func, args, timeout=None):
    """Call ``func(*args)`` defensively; return its result, or None on failure.

    - `args` is deep-copied so a solution that mutates its input (sorting or
      popping from `height`, say) cannot corrupt the shared test case seen by
      the other completions graded against the same problem.
    - With a positive `timeout` (seconds) and SIGALRM available on the main
      thread, a call that runs too long is interrupted and counts as a failure.
    - KeyboardInterrupt still propagates so the notebook can be stopped.
    """
    import copy
    import signal
    import threading

    call_args = copy.deepcopy(args)

    use_alarm = (
        timeout is not None
        and timeout > 0
        and hasattr(signal, "setitimer")
        and threading.current_thread() is threading.main_thread()
    )
    if use_alarm:
        def _on_timeout(signum, frame):
            raise TimeoutError("solution timed out")

        previous_handler = signal.signal(signal.SIGALRM, _on_timeout)
        signal.setitimer(signal.ITIMER_REAL, timeout)
    try:
        try:
            return func(*call_args)
        finally:
            if use_alarm:
                # Disarm first; an alarm that lands in this window is still
                # caught by the handlers below.
                signal.setitimer(signal.ITIMER_REAL, 0)
    except KeyboardInterrupt:
        raise
    except BaseException:
        return None
    finally:
        if use_alarm:
            signal.signal(
                signal.SIGALRM,
                previous_handler if previous_handler is not None else signal.SIG_DFL,
            )

def generate_solution(model, tokenizer, problem, temp=0.2):
    """Sample one completion for `problem` and return only the new text.

    Uses nucleus sampling (top_p=0.9) at temperature `temp`, capped at
    CONFIG["max_new_tokens"] new tokens; the prompt tokens are stripped
    before decoding.
    """
    chat = [
        {"role": "system", "content": "You are an expert Python programmer."},
        {"role": "user", "content": problem.to_prompt()},
    ]
    prompt_text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    encoded = tokenizer(prompt_text, return_tensors="pt").to(device)
    sampling = dict(
        max_new_tokens=CONFIG["max_new_tokens"],
        do_sample=True,
        temperature=temp,
        top_p=0.9,
        pad_token_id=tokenizer.pad_token_id,
    )

    model.eval()
    with torch.no_grad():
        sequences = model.generate(**encoded, **sampling)

    if USE_TPU:
        xm.mark_step()  # sync TPU

    n_prompt_tokens = encoded.input_ids.shape[1]
    return tokenizer.decode(sequences[0][n_prompt_tokens:], skip_special_tokens=True)


print("Helper functions defined!")

## Step 8: Baseline Evaluation

In [None]:
print("=" * 60)
print("BASELINE EVALUATION")
print("=" * 60)

# Score the untrained model on a fixed, seeded set of problems.
n_baseline_problems = 5
gen = TrappingRainWaterGenerator(seed=CONFIG["seed"])
baseline_passed = 0

for i in range(n_baseline_problems):
    p = gen.generate(CONFIG["difficulty"], CONFIG["num_test_cases"])
    success, partial = verify(extract_code(generate_solution(model, tokenizer, p)), p)
    verdict = "PASS" if success else "FAIL"
    print(f"  [{i+1}] {verdict} ({partial*100:.0f}%)")
    baseline_passed += success  # bool counts as 0/1
    clear_memory()

baseline_acc = baseline_passed / n_baseline_problems
print(f"\nBaseline: {baseline_passed}/{n_baseline_problems} ({baseline_acc*100:.0f}%)")

## Step 9: Create Reference Model

In [None]:
import copy

print("=" * 60)
print("CREATING REFERENCE MODEL")
print("=" * 60)

# Frozen copy of the starting policy; the GRPO KL penalty is measured against it.
ref_model = copy.deepcopy(model)
ref_model.eval()
for param in ref_model.parameters():
    param.requires_grad = False

print("Reference model created!")

# AdamW's default eps (1e-8) is below float16's smallest positive value
# (about 6e-8), so with float16 weights (the CUDA path in Step 6) it rounds
# to zero. Adam's denominator sqrt(v) + eps can then hit exactly 0 when
# squared gradients underflow, producing inf/NaN weights. Use an eps that
# float16 can represent; float32 weights keep the default.
_train_dtype = next(model.parameters()).dtype
_adam_eps = 1e-4 if _train_dtype == torch.float16 else 1e-8

optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["learning_rate"], eps=_adam_eps)
print(f"Optimizer: AdamW (lr={CONFIG['learning_rate']})")

## Step 10: GRPO Training

In [None]:
def _completion_log_probs(net, full_ids, attn_mask, prompt_len):
    """Per-token log-probabilities `net` assigns to the completion part of `full_ids`.

    `full_ids` is the (1, seq_len) prompt+completion row produced by
    generate(); logits at position t score token t+1, hence the one-position
    shift. The result is cast to float32 so the KL term is not computed in
    float16.
    """
    logits = net(input_ids=full_ids, attention_mask=attn_mask).logits[:, prompt_len - 1:-1, :]
    target = full_ids[:, prompt_len:]
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    return torch.gather(log_probs, -1, target.unsqueeze(-1)).squeeze(-1).float()


def grpo_step(model, ref_model, optimizer, problem):
    """Run one GRPO update on a single problem.

    1. Sample CONFIG["num_generations"] completions from the current policy.
    2. Reward each with `verify`: 1.0 for passing every test, otherwise
       0.5 * fraction of tests passed. Rewards are normalized within the group
       to get advantages.
    3. Take one optimizer step on the group-averaged loss
       -advantage * mean token log-prob + beta * KL(policy || reference),
       where the KL uses the per-token k3 estimator
       exp(r) - r - 1 with r = log pi_ref - log pi_theta.

    Returns a dict with the mean per-completion loss and the mean/max reward.
    """
    messages = [
        {"role": "system", "content": "You are an expert Python programmer."},
        {"role": "user", "content": problem.to_prompt()},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to(device)
    prompt_len = inputs.input_ids.shape[1]

    # --- 1) Sample a group of completions -------------------------------
    completions = []
    completion_ids = []

    model.eval()
    with torch.no_grad():
        for _ in range(CONFIG["num_generations"]):
            out = model.generate(
                **inputs,
                max_new_tokens=CONFIG["max_new_tokens"],
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.pad_token_id,
            )
            completions.append(tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True))
            completion_ids.append(out[0])
            if USE_TPU:
                xm.mark_step()

    # --- 2) Rewards and group-relative advantages ------------------------
    reward_values = []
    for comp in completions:
        success, partial = verify(extract_code(comp), problem)
        reward_values.append(1.0 if success else partial * 0.5)

    rewards = torch.tensor(reward_values, device=device)
    if rewards.numel() > 1:
        advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
    else:
        # std() of a single sample is NaN; a group of one carries no signal.
        advantages = torch.zeros_like(rewards)

    # --- 3) Policy update (gradients accumulated over the group) ---------
    model.train()
    optimizer.zero_grad()

    group_size = len(completion_ids)
    total_loss = 0.0
    for comp_ids, adv in zip(completion_ids, advantages):
        full_ids = comp_ids.unsqueeze(0)
        attn_mask = torch.ones_like(full_ids)

        token_log_probs = _completion_log_probs(model, full_ids, attn_mask, prompt_len)
        with torch.no_grad():
            ref_token_log_probs = _completion_log_probs(ref_model, full_ids, attn_mask, prompt_len)

        # k3 KL estimator. Unlike the plain difference of log-probs, its
        # gradient actually pulls the policy back toward the reference.
        log_ratio = ref_token_log_probs - token_log_probs
        kl = (torch.exp(log_ratio) - log_ratio - 1).mean()

        loss = -adv * token_log_probs.mean() + CONFIG["beta"] * kl
        # Average over the group so the gradient scale does not grow with
        # num_generations.
        (loss / group_size).backward()
        total_loss += loss.item()

        if USE_TPU:
            xm.mark_step()

    # Update
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    if USE_TPU:
        xm.optimizer_step(optimizer)
        xm.mark_step()
    else:
        optimizer.step()

    return {
        "loss": total_loss / CONFIG["num_generations"],
        "mean_reward": rewards.mean().item(),
        "max_reward": rewards.max().item(),
    }


print("GRPO function defined!")

In [None]:
# Training loop
print("=" * 60)
print("GRPO TRAINING")
print("=" * 60)
print(f"Steps: {CONFIG['num_steps']}, Generations: {CONFIG['num_generations']}")
print()

train_gen = TrappingRainWaterGenerator(seed=CONFIG["seed"] + 100)
metrics = []

for step in range(CONFIG["num_steps"]):
    problem = train_gen.generate(CONFIG["difficulty"], CONFIG["num_test_cases"])

    try:
        m = grpo_step(model, ref_model, optimizer, problem)
    except Exception as e:
        # Keep training through a failing step; it is simply absent from `metrics`.
        print(f"Step {step+1}: Error - {e}")
    else:
        # Record the real step number: failed steps leave gaps in `metrics`,
        # so a metric's list position no longer equals its step number.
        m["step"] = step + 1
        metrics.append(m)
        print(f"Step {step+1:2d}: Loss={m['loss']:.4f}, Avg={m['mean_reward']:.3f}, Max={m['max_reward']:.3f}")

    clear_memory()

print("\nTraining complete!")

## Step 11: Post-Training Evaluation

In [None]:
# Free reference model
del ref_model
clear_memory()

print("=" * 60)
print("POST-TRAINING EVALUATION")
print("=" * 60)

eval_gen = TrappingRainWaterGenerator(seed=CONFIG["seed"] + 999)
final_passed = 0
model.eval()

for i in range(5):
    p = eval_gen.generate(CONFIG["difficulty"], CONFIG["num_test_cases"])
    response = generate_solution(model, tokenizer, p, temp=0.2)
    code = extract_code(response)
    success, partial = verify(code, p)
    print(f"  [{i+1}] {'PASS' if success else 'FAIL'} ({partial*100:.0f}%)")
    final_passed += success
    clear_memory()

final_acc = final_passed / 5
print(f"\nFinal: {final_passed}/5 ({final_acc*100:.0f}%)")

## Step 12: Results

In [None]:
print("=" * 60)
print("RESULTS")
print("=" * 60)

print(f"\nBaseline: {baseline_acc*100:.0f}%")
print(f"Final:    {final_acc*100:.0f}%")
print(f"Change:   {(final_acc - baseline_acc)*100:+.0f}%")

if metrics:
    print("\nTraining curve:")
    for i, m in enumerate(metrics):
        # Failed training steps are missing from `metrics`, so prefer the step
        # number recorded with each entry; fall back to list position.
        step_no = m.get("step", i + 1)
        print(f"  Step {step_no}: reward={m['mean_reward']:.3f}, max={m['max_reward']:.3f}")

if final_acc > baseline_acc:
    print("\n" + "="*60)
    print("SUCCESS! Model improved!")
    print("="*60)

## Step 13: Save Model (Optional)

In [None]:
# Uncomment to save to Google Drive
# from google.colab import drive
# drive.mount('/content/drive')
# 
# # Move model to CPU for saving (note: nn.Module.to() moves `model` itself in place)
# model_cpu = model.to('cpu')
# save_path = "/content/drive/MyDrive/axiom-rl/qwen-1.5b-grpo-tpu"
# model_cpu.save_pretrained(save_path)
# tokenizer.save_pretrained(save_path)
# print(f"Saved to {save_path}")

---

## Notes

**TPU-specific considerations:**
- Use `xm.mark_step()` to sync TPU operations
- Use `xm.optimizer_step()` for optimizer updates
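- TPUs do not natively support float16; this notebook loads the model in float32 on TPU (bfloat16 is the TPU-native reduced-precision type)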
- First compilation is slow, subsequent runs are faster

**If training is slow:**
- This is normal for the first few steps (JIT compilation)
- TPU shines on larger batch sizes
- Consider using GPU (T4/A100) for this small-scale experiment