# Qwen 1.5B GRPO Training on Colab (Memory-Optimized)

This notebook runs GRPO on **Qwen2.5-Coder-1.5B-Instruct** to fix the Trapping Rain Water problem.

## Memory Optimization

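GRPO keeps two copies of the weights on the GPU: the **policy model** being trained and a frozen **reference model** for the KL penalty. Add gradients and optimizer state, and this is tight on a T4 (16GB).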

This notebook uses several optimizations:
1. **Gradient checkpointing** - trades compute for memory
2. **Sequential processing** - process one completion at a time
3. **Aggressive cache clearing** - free memory between steps
4. **Shorter sequences** - limit max tokens
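A back-of-the-envelope budget shows why. Step 5 loads the weights in fp16 (2 bytes per parameter), and PyTorch's AdamW keeps two state tensors per parameter in the parameter's dtype, so full fine-tuning with a reference copy holds roughly five fp16 copies of the model. The sketch below assumes about 1.54B parameters; substitute the count printed in Step 5. `estimate_training_memory_gb` is a hypothetical helper.

```python
def estimate_training_memory_gb(num_params: float, bytes_per_param: int = 2) -> dict:
    """Rough static GPU memory (GB) for full fine-tuning with AdamW plus a frozen reference copy.

    Counts weights, gradients and optimizer state only; activations, the KV cache
    during generation and CUDA context overhead come on top of this.
    """
    copy_gb = num_params * bytes_per_param / 1e9  # one full copy of the weights
    breakdown = {
        "policy_weights": copy_gb,
        "reference_weights": copy_gb,
        "gradients": copy_gb,         # same dtype and shape as the trainable weights
        "adamw_states": 2 * copy_gb,  # exp_avg and exp_avg_sq, stored in the param dtype
    }
    breakdown["total"] = sum(breakdown.values())
    return breakdown

# Assumption: ~1.54e9 parameters; Step 5 prints the exact count for the loaded model.
for name, gb in estimate_training_memory_gb(1.54e9).items():
    print(f"{name:18} {gb:5.2f} GB")
```

Around 15.4 GB of static memory before any activations leaves little or no headroom on a 16GB card, which is why the activation-side savings listed above matter.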

## Background

From our experiments:
- Qwen 1.5B solves 9/10 hard problems natively (90%)
- Only failure: **Trapping Rain Water** (4/5 = 80%)
- Goal: Use GRPO to fix this single failure

---

## Step 1: Check Runtime

**IMPORTANT**: Use a GPU runtime (T4 minimum; an A100 gives more headroom)

In [None]:
import torch
import gc

def clear_memory():
    """Aggressively clear GPU memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

print("=" * 60)
print("GPU CHECK")
print("=" * 60)

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_memory:.1f} GB")
    
    if gpu_memory < 14:
        # Stop here: otherwise USE_LOW_MEMORY is never defined and later cells fail
        raise RuntimeError(
            "Need at least a T4 (16GB) for this notebook. "
            "Go to Runtime -> Change runtime type -> T4 GPU"
        )
    elif gpu_memory < 20:
        print("\nT4-class GPU detected - using memory-optimized settings.")
        USE_LOW_MEMORY = True
    else:
        print("\nLarge GPU detected (e.g. A100) - using standard settings.")
        USE_LOW_MEMORY = False
else:
    raise RuntimeError("GPU required!")

print(f"\nLow memory mode: {USE_LOW_MEMORY}")

## Step 2: Install Dependencies

In [None]:
%%time
!pip install -q transformers accelerate
print("\nDependencies installed!")

## Step 3: Configuration

In [None]:
# Configuration based on GPU memory
if USE_LOW_MEMORY:
    CONFIG = {
        "model_name": "Qwen/Qwen2.5-Coder-1.5B-Instruct",
        "num_steps": 10,
        "num_generations": 2,      # Reduced from 4
        "learning_rate": 5e-5,
        "beta": 0.04,
        "max_new_tokens": 384,     # Reduced from 512
        "difficulty": 5,
        "num_test_cases": 5,
        "seed": 42,
        "gradient_checkpointing": True,
    }
else:
    CONFIG = {
        "model_name": "Qwen/Qwen2.5-Coder-1.5B-Instruct",
        "num_steps": 10,
        "num_generations": 4,
        "learning_rate": 5e-5,
        "beta": 0.04,
        "max_new_tokens": 512,
        "difficulty": 5,
        "num_test_cases": 5,
        "seed": 42,
        "gradient_checkpointing": False,
    }

print("=" * 60)
print("CONFIGURATION")
print("=" * 60)
for k, v in CONFIG.items():
    print(f"  {k:25} = {v}")
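The two `CONFIG` dicts above differ only in `num_generations`, `max_new_tokens` and `gradient_checkpointing`. A sketch of the same settings as a shared base plus per-profile overrides; `BASE_CONFIG`, the `*_OVERRIDES` dicts and `build_config` are hypothetical names, and the guard encodes the assumption that GRPO needs at least two completions per prompt.

```python
# Settings shared by both GPU profiles (values copied from the cell above)
BASE_CONFIG = {
    "model_name": "Qwen/Qwen2.5-Coder-1.5B-Instruct",
    "num_steps": 10,
    "learning_rate": 5e-5,
    "beta": 0.04,
    "difficulty": 5,
    "num_test_cases": 5,
    "seed": 42,
}
# Only the memory-sensitive keys differ between profiles
LOW_MEMORY_OVERRIDES = {"num_generations": 2, "max_new_tokens": 384, "gradient_checkpointing": True}
STANDARD_OVERRIDES = {"num_generations": 4, "max_new_tokens": 512, "gradient_checkpointing": False}

def build_config(use_low_memory: bool) -> dict:
    overrides = LOW_MEMORY_OVERRIDES if use_low_memory else STANDARD_OVERRIDES
    config = {**BASE_CONFIG, **overrides}
    # GRPO standardizes rewards within a group, so it needs at least two samples
    assert config["num_generations"] >= 2, "GRPO needs num_generations >= 2"
    return config

print(build_config(use_low_memory=True))
```

In the notebook, `CONFIG = build_config(USE_LOW_MEMORY)` would replace the if/else, and a new shared setting only has to be added once.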

## Step 4: Problem Generator

In [None]:
import random
from dataclasses import dataclass
from typing import List, Any

@dataclass
class TestCase:
    input_args: List[Any]
    expected_output: Any

@dataclass 
class Problem:
    problem_id: str
    title: str
    description: str
    function_signature: str
    function_name: str
    test_cases: List[TestCase]
    
    def to_prompt(self) -> str:
        examples = "\n".join(
            f"  {self.function_name}({repr(tc.input_args[0])}) -> {repr(tc.expected_output)}"
            for tc in self.test_cases[:3]
        )
        return f"""## {self.title}

{self.description}

### Function Signature
```python
{self.function_signature}
```

### Examples
```python
{examples}
```

Write ONLY the function implementation."""


class TrappingRainWaterGenerator:
    def __init__(self, seed=None):
        self.rng = random.Random(seed)
        self._counter = 0
    
    def _solve(self, height):
        if not height: return 0
        left, right = 0, len(height) - 1
        left_max = right_max = water = 0
        while left < right:
            if height[left] < height[right]:
                if height[left] >= left_max: left_max = height[left]
                else: water += left_max - height[left]
                left += 1
            else:
                if height[right] >= right_max: right_max = height[right]
                else: water += right_max - height[right]
                right -= 1
        return water
    
    def generate(self, difficulty=5, num_test_cases=5):
        self._counter += 1
        test_cases = []
        for _ in range(num_test_cases):
            length = 5 + difficulty * 2
            heights = [self.rng.randint(0, 3 + difficulty) for _ in range(length)]
            test_cases.append(TestCase([heights], self._solve(heights)))
        
        return Problem(
            problem_id=f"trap_{self._counter}",
            title="Trapping Rain Water",
            description="Given n non-negative integers representing an elevation map, compute how much water it can trap after raining.",
            function_signature="def trap(height: list) -> int:",
            function_name="trap",
            test_cases=test_cases,
        )

# Test
gen = TrappingRainWaterGenerator(seed=42)
p = gen.generate()
print(f"Generated problem with {len(p.test_cases)} test cases")
print(f"Example: trap({p.test_cases[0].input_args[0][:6]}...) -> {p.test_cases[0].expected_output}")
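The generator's expected outputs come from the two-pointer `_solve`. A quick sanity check, sketched here with hypothetical function names, compares that logic against the direct prefix/suffix-maximum formulation: the water above bar `i` is `min(max(height[:i+1]), max(height[i:])) - height[i]`. The random inputs mimic difficulty 5 (length 15, heights 0 to 8).

```python
import random

def trap_two_pointer(height: list) -> int:
    """Same logic as TrappingRainWaterGenerator._solve."""
    if not height:
        return 0
    left, right = 0, len(height) - 1
    left_max = right_max = water = 0
    while left < right:
        if height[left] < height[right]:
            if height[left] >= left_max:
                left_max = height[left]
            else:
                water += left_max - height[left]
            left += 1
        else:
            if height[right] >= right_max:
                right_max = height[right]
            else:
                water += right_max - height[right]
            right -= 1
    return water

def trap_prefix_max(height: list) -> int:
    """Reference: water at i = min(tallest bar up to i, tallest bar from i) - height[i]."""
    n = len(height)
    if n == 0:
        return 0
    left_max, right_max = [0] * n, [0] * n
    left_max[0] = height[0]
    for i in range(1, n):
        left_max[i] = max(left_max[i - 1], height[i])
    right_max[-1] = height[-1]
    for i in range(n - 2, -1, -1):
        right_max[i] = max(right_max[i + 1], height[i])
    return sum(min(left_max[i], right_max[i]) - height[i] for i in range(n))

# Cross-check on inputs shaped like the generator's difficulty-5 cases
rng = random.Random(0)
for _ in range(1000):
    heights = [rng.randint(0, 8) for _ in range(15)]
    assert trap_two_pointer(heights) == trap_prefix_max(heights)

print(trap_prefix_max([0, 1, 0, 2, 1, 0, 1, 3, 2, 1, 2, 1]))  # 6
```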

## Step 5: Load Model with Memory Optimization

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

print("=" * 60)
print("LOADING MODEL")
print("=" * 60)

clear_memory()

tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_name"])
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    CONFIG["model_name"],
    torch_dtype=torch.float16,
    device_map="auto",
)

# Enable gradient checkpointing for memory efficiency
if CONFIG["gradient_checkpointing"]:
    model.gradient_checkpointing_enable()
    print("Gradient checkpointing: ENABLED")

print(f"Model loaded: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B params")
print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

## Step 6: Baseline Evaluation

In [None]:
import copy
import re

def extract_code(response):
    for pattern in [r"```python\n(.*?)```", r"```\n(.*?)```"]:
        matches = re.findall(pattern, response, re.DOTALL)
        if matches: return matches[-1].strip()
    if "def " in response:
        lines, code, in_func = response.split("\n"), [], False
        for line in lines:
            if line.strip().startswith("def "): in_func = True
            if in_func: code.append(line)
        if code: return "\n".join(code).strip()
    return response

def verify(code, problem):
    namespace = {}
    try:
        exec(code, namespace)  # runs untrusted model output in this process
    except: return False, 0.0
    
    func = namespace.get(problem.function_name)
    if not func:
        # Fall back to the first public callable the code defined
        funcs = [v for k, v in namespace.items() if callable(v) and not k.startswith("_")]
        func = funcs[0] if funcs else None
    if not func: return False, 0.0
    
    passed = sum(1 for tc in problem.test_cases 
                 if _safe_call(func, tc.input_args) == tc.expected_output)
    return passed == len(problem.test_cases), passed / len(problem.test_cases)

def _safe_call(func, args):
    # Pass a deep copy: during GRPO every completion is scored against the same
    # test cases, so a solution that mutates its input must not corrupt them.
    try: return func(*copy.deepcopy(args))
    except: return None

def generate(model, tokenizer, problem, temp=0.2):
    messages = [
        {"role": "system", "content": "You are an expert Python programmer."},
        {"role": "user", "content": problem.to_prompt()},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=CONFIG["max_new_tokens"],
                            do_sample=True, temperature=temp, top_p=0.9,
                            pad_token_id=tokenizer.pad_token_id)
    return tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)

# Baseline
print("=" * 60)
print("BASELINE EVALUATION")
print("=" * 60)

gen = TrappingRainWaterGenerator(seed=CONFIG["seed"])
baseline_passed = 0
for i in range(5):
    p = gen.generate(CONFIG["difficulty"], CONFIG["num_test_cases"])
    code = extract_code(generate(model, tokenizer, p))
    success, partial = verify(code, p)
    print(f"  [{i+1}] {'PASS' if success else 'FAIL'} ({partial*100:.0f}%)")
    baseline_passed += success
    clear_memory()

baseline_acc = baseline_passed / 5
print(f"\nBaseline: {baseline_passed}/5 ({baseline_acc*100:.0f}%)")
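`verify` runs model-generated code with no time limit, so a completion containing an infinite loop would stall both the baseline and training. A minimal sketch of a wall-clock limit, assuming a Unix host and the main thread (both true on Colab), built on `signal.setitimer` and `SIGALRM`; `call_with_timeout` and `safe_call_with_timeout` are hypothetical helpers, not part of the notebook.

```python
import copy
import signal

class CallTimeout(Exception):
    """Raised when a generated function runs longer than allowed."""

def call_with_timeout(func, args, seconds: float = 2.0):
    """Call func(*args); raise CallTimeout after `seconds` of wall-clock time.

    Unix-only (SIGALRM) and main-thread-only, since signal handlers can only
    be installed from the main thread.
    """
    def _on_alarm(signum, frame):
        raise CallTimeout(f"call exceeded {seconds}s")

    previous = signal.signal(signal.SIGALRM, _on_alarm)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        return func(*args)
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, previous)  # restore the old handler

def safe_call_with_timeout(func, args, seconds: float = 2.0):
    """Variant of _safe_call: None on any exception or timeout, inputs deep-copied."""
    try:
        return call_with_timeout(func, copy.deepcopy(args), seconds)
    except Exception:  # CallTimeout is an Exception subclass
        return None
```

Calling `safe_call_with_timeout` in place of `_safe_call` inside `verify` would cap each test case. The limit is best-effort: generated code that wraps its loop in a bare `except:` can swallow the timeout, and none of this sandboxes the code.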

## Step 7: GRPO Training (Memory-Optimized)

Key optimizations:
1. Process ONE completion at a time (not batched)
2. Immediately delete tensors after use
3. Clear cache after each step

In [None]:
import copy

print("=" * 60)
print("CREATING REFERENCE MODEL")
print("=" * 60)

clear_memory()
print(f"Before ref model: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

# Create reference model
ref_model = copy.deepcopy(model)
ref_model.eval()
for p in ref_model.parameters():
    p.requires_grad = False

print(f"After ref model: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

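# The weights are pure fp16, so AdamW's state (exp_avg, exp_avg_sq) is fp16 too.
# The default eps=1e-8 rounds to 0 in fp16, which can turn zero-gradient entries
# into 0/0 = NaN updates; 1e-4 is representable in fp16.
optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["learning_rate"], eps=1e-4)
print("Optimizer created")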

In [None]:
def grpo_step_memory_efficient(model, ref_model, optimizer, problem):
    """
    Memory-efficient GRPO step:
    1. Generate all completions first (no grad)
    2. Compute rewards
    3. Process ONE completion at a time for gradient
    """
    messages = [
        {"role": "system", "content": "You are an expert Python programmer."},
        {"role": "user", "content": problem.to_prompt()},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    prompt_len = inputs.input_ids.shape[1]
    
    # Phase 1: Generate completions (no grad)
    completions = []
    completion_texts = []
    
    model.eval()
    with torch.no_grad():
        for _ in range(CONFIG["num_generations"]):
            out = model.generate(
                **inputs,
                max_new_tokens=CONFIG["max_new_tokens"],
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.pad_token_id,
            )
            # Store only the completion tokens (not full sequence)
            comp_ids = out[0][prompt_len:].cpu()  # Move to CPU!
            completions.append(comp_ids)
            completion_texts.append(tokenizer.decode(comp_ids, skip_special_tokens=True))
            del out
    
    clear_memory()
    
    # Phase 2: Compute rewards
    rewards = []
    for text in completion_texts:
        code = extract_code(text)
        success, partial = verify(code, problem)
        rewards.append(1.0 if success else partial * 0.5)
    
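    rewards = torch.tensor(rewards)
    # Group-relative advantage: standardize rewards within the group.
    # Needs num_generations >= 2; the std of a single sample is NaN.
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)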
    
    # Phase 3: Gradient update (one at a time)
    model.train()
    optimizer.zero_grad()
    
    total_loss = 0.0
    for comp_ids, adv in zip(completions, advantages):
        # Reconstruct full sequence
        full_ids = torch.cat([inputs.input_ids[0].cpu(), comp_ids]).unsqueeze(0).to(model.device)
        attn_mask = torch.ones_like(full_ids)
        
        # Forward pass (policy)
        outputs = model(input_ids=full_ids, attention_mask=attn_mask)
        logits = outputs.logits[:, prompt_len-1:-1, :]  # Only completion tokens
        target = full_ids[:, prompt_len:]
        
        # Log probs
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        token_log_probs = torch.gather(log_probs, -1, target.unsqueeze(-1)).squeeze(-1)
        seq_log_prob = token_log_probs.mean()
        
        # Reference log prob
        with torch.no_grad():
            ref_out = ref_model(input_ids=full_ids, attention_mask=attn_mask)
            ref_logits = ref_out.logits[:, prompt_len-1:-1, :]
            ref_log_probs = torch.nn.functional.log_softmax(ref_logits, dim=-1)
            ref_token_log_probs = torch.gather(ref_log_probs, -1, target.unsqueeze(-1)).squeeze(-1)
            ref_seq_log_prob = ref_token_log_probs.mean()
        
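        # Loss: policy-gradient term plus a single-sample (k1) KL penalty
        kl = seq_log_prob - ref_seq_log_prob
        loss = -adv.to(model.device) * seq_log_prob + CONFIG["beta"] * kl
        # Divide by the group size so the accumulated gradient is the group mean,
        # matching the averaged loss returned in the metrics
        (loss / CONFIG["num_generations"]).backward()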
        total_loss += loss.item()
        
        # Immediately free memory
        del full_ids, attn_mask, outputs, logits, log_probs, token_log_probs
        del ref_out, ref_logits, ref_log_probs, ref_token_log_probs
        clear_memory()
    
    # Update
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    
    del inputs
    clear_memory()
    
    return {
        "loss": total_loss / CONFIG["num_generations"],
        "mean_reward": rewards.mean().item(),
        "max_reward": rewards.max().item(),
    }

print("GRPO function defined (memory-optimized)")
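The loss above penalizes `seq_log_prob - ref_seq_log_prob`, a single-sample (k1) KL estimate that can be negative for a given completion. The GRPO formulation in the DeepSeekMath paper instead uses the per-token estimator `exp(r) - r - 1` with `r = log pi_ref - log pi_theta`, which is never negative. A pure-Python sketch comparing the two on made-up log probs; the function names are hypothetical and this is an alternative, not what the notebook computes.

```python
import math
import statistics
from typing import List

def k1_kl_mean(policy_logps: List[float], ref_logps: List[float]) -> float:
    """What grpo_step_memory_efficient uses: mean policy log prob minus mean reference log prob."""
    return statistics.fmean(policy_logps) - statistics.fmean(ref_logps)

def k3_kl_per_token(policy_logps: List[float], ref_logps: List[float]) -> List[float]:
    """Per-token estimate exp(r) - r - 1 with r = ref_logp - policy_logp.

    Non-negative for every token (since e^r >= 1 + r), exactly 0 where the models agree.
    """
    return [math.exp(q - p) - (q - p) - 1.0 for p, q in zip(policy_logps, ref_logps)]

# Made-up log probs for a 2-token completion: the policy likes token 0 more, token 1 less
policy = [-0.5, -2.0]
reference = [-1.0, -1.0]
print("k1:", k1_kl_mean(policy, reference))  # negative on this sample
print("k3 mean:", statistics.fmean(k3_kl_per_token(policy, reference)))
```

With the notebook's tensors, the k3 version would be `r = ref_token_log_probs - token_log_probs` followed by `kl = (torch.exp(r) - r - 1).mean()`, computed on the policy side so gradients flow through `token_log_probs`.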

In [None]:
# Training loop
print("=" * 60)
print("GRPO TRAINING")
print("=" * 60)
print(f"Steps: {CONFIG['num_steps']}, Generations: {CONFIG['num_generations']}")
print()

gen = TrappingRainWaterGenerator(seed=CONFIG["seed"] + 100)
metrics = []

for step in range(CONFIG["num_steps"]):
    problem = gen.generate(CONFIG["difficulty"], CONFIG["num_test_cases"])
    
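    try:
        m = grpo_step_memory_efficient(model, ref_model, optimizer, problem)
        metrics.append(m)
        print(f"Step {step+1:2d}: Loss={m['loss']:.4f}, Avg={m['mean_reward']:.3f}, Max={m['max_reward']:.3f}")
    except torch.cuda.OutOfMemoryError:
        print(f"Step {step+1}: OOM! Clearing memory and continuing...")
        optimizer.zero_grad(set_to_none=True)  # drop any partially accumulated gradients
    # Clear outside the except block: tensors kept alive by the exception's
    # traceback are only released once the handler has exited.
    clear_memory()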

print("\nTraining complete!")

## Step 8: Post-Training Evaluation

In [None]:
# Free reference model
del ref_model
clear_memory()

print("=" * 60)
print("POST-TRAINING EVALUATION")
print("=" * 60)

eval_gen = TrappingRainWaterGenerator(seed=CONFIG["seed"] + 999)
final_passed = 0
model.eval()

for i in range(5):
    p = eval_gen.generate(CONFIG["difficulty"], CONFIG["num_test_cases"])
    code = extract_code(generate(model, tokenizer, p, temp=0.2))
    success, partial = verify(code, p)
    print(f"  [{i+1}] {'PASS' if success else 'FAIL'} ({partial*100:.0f}%)")
    final_passed += success
    clear_memory()

final_acc = final_passed / 5
print(f"\nFinal: {final_passed}/5 ({final_acc*100:.0f}%)")

## Step 9: Results

In [None]:
print("=" * 60)
print("RESULTS")
print("=" * 60)

print(f"\nBaseline: {baseline_acc*100:.0f}%")
print(f"Final:    {final_acc*100:.0f}%")
print(f"Change:   {(final_acc - baseline_acc)*100:+.0f} pp")

if metrics:
    print(f"\nTraining curve:")
    for i, m in enumerate(metrics):
        print(f"  Step {i+1}: reward={m['mean_reward']:.3f}")

if final_acc > baseline_acc:
    print("\n" + "="*60)
    print("SUCCESS! Model improved!")
    print("="*60)
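Both accuracies come from only five problems, so each problem moves the rate by 20 percentage points. A binomial confidence interval shows how much a 5-problem pass rate can say; the sketch below uses the Wilson score interval (a standard choice for small samples, not something the notebook computes), with `wilson_interval` as a hypothetical helper.

```python
import math

def wilson_interval(successes: int, n: int, z: float = 1.96):
    """Wilson score interval for a binomial pass rate; z = 1.96 gives ~95% coverage."""
    if n <= 0:
        raise ValueError("n must be positive")
    p = successes / n
    denom = 1 + z * z / n
    center = (p + z * z / (2 * n)) / denom
    margin = z * math.sqrt(p * (1 - p) / n + z * z / (4 * n * n)) / denom
    # Clamp to [0, 1] to hide floating-point overshoot at p = 0 or p = 1
    return max(0.0, center - margin), min(1.0, center + margin)

for k in range(6):
    lo, hi = wilson_interval(k, 5)
    print(f"{k}/5 passed -> 95% interval [{lo:.2f}, {hi:.2f}]")
```

For 4/5 this gives roughly [0.38, 0.96], and even 5/5 only rules out pass rates below about 0.57, so a one-problem difference between the baseline and final runs is well within noise.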

## Step 10: Save (Optional)

In [None]:
# Uncomment to save to Google Drive
# from google.colab import drive
# drive.mount('/content/drive')
# 
# save_path = "/content/drive/MyDrive/axiom-rl/qwen-1.5b-grpo"
# model.save_pretrained(save_path)
# tokenizer.save_pretrained(save_path)
# print(f"Saved to {save_path}")

---

## Troubleshooting

**Still getting OOM?**
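1. Reduce `max_new_tokens` to 256
2. Do not lower `num_generations` below 2: GRPO standardizes rewards within each group, and with a single completion `rewards.std()` is NaN, which propagates NaN into the loss and then into the weights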
3. Use A100 runtime (Runtime -> Change runtime type)

**Training not improving?**
1. Increase `num_steps` to 20-30
2. Try a lower learning rate (1e-5)
3. Consider teacher distillation instead