# Experiment 15: M-GRPO with Entropy Control (Standalone)

**Status:** Ready to Run  
**Date:** 2024-12-21  
**Runtime:** Google Colab (T4/A100 GPU)  

This is a **fully self-contained notebook** - no external dependencies on the axiom-rl package.
All code needed for M-GRPO training is embedded directly in this notebook.

---

## What is M-GRPO?

**M-GRPO (Momentum-Anchored GRPO)** is a stabilized reinforcement learning technique for training language models.

### The Problem with Standard GRPO
Standard GRPO (Group Relative Policy Optimization) often fails due to **policy collapse**:
- Model becomes overconfident in one solution pattern
- Entropy drops → diversity collapses
- Performance crashes
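
To make the entropy symptom concrete, here is a minimal sketch in pure Python. The two probability vectors are made up for illustration and do not come from a real model; they only show how entropy shrinks as a distribution concentrates on one choice.

```python
import math

def shannon_entropy(probs):
    """Shannon entropy in nats; zero-probability entries contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0)

# Hypothetical next-token distributions over four candidate tokens.
diverse = [0.25, 0.25, 0.25, 0.25]    # model still explores
collapsed = [0.97, 0.01, 0.01, 0.01]  # model is overconfident in one token

h_diverse = shannon_entropy(diverse)
h_collapsed = shannon_entropy(collapsed)
print(f"diverse:   {h_diverse:.3f} nats")
print(f"collapsed: {h_collapsed:.3f} nats")
```

The uniform distribution reaches the maximum entropy for four outcomes (`ln 4`), while the peaked one sits far below it.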

### M-GRPO Solution: Two Models
1. **Policy Model** - Trainable, learns and improves
2. **Momentum Model** - Slow EMA copy, provides stable reference

**EMA Update:** `θ_momentum = 0.99 * θ_momentum + 0.01 * θ_policy`

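Rollouts are sampled from **both** models. Mixing in samples from the slower-moving momentum model is intended to keep the rollouts diverse and reduce the risk of collapse.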
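
The EMA rule above can be sketched on a toy one-parameter "model". This is an illustration only: `ema_update` is a hypothetical helper operating on plain Python lists, not the notebook's tensor-based update.

```python
def ema_update(momentum_params, policy_params, m=0.99):
    """Return m * momentum + (1 - m) * policy, elementwise."""
    return [m * pm + (1 - m) * pp for pm, pp in zip(momentum_params, policy_params)]

# Toy setup: the policy parameter jumps from 0.0 to 1.0 and stays there.
momentum = [0.0]
policy = [1.0]
for _ in range(100):
    momentum = ema_update(momentum, policy)

print(f"momentum after 100 steps: {momentum[0]:.3f}")
```

With `m = 0.99` the momentum parameter has covered only about 63% of the gap after 100 updates (`1 - 0.99**100`), which is what makes it a slow, stable reference.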

---

## What We're Training

Teaching a small code model (Qwen2.5-Coder 0.5B, or 1.5B when the GPU reports at least 15 GB of memory) to write Python functions for:
- **RPN Evaluator** - Evaluate reverse polish notation expressions
- **Parentheses Validator** - Check if brackets are balanced
- **Fibonacci** - Compute nth Fibonacci number
- **Binary Search** - Find element in sorted array
- **Edit Distance** - Levenshtein distance between strings
- **Coin Change** - Minimum coins for amount (DP)

Each problem has 5 test cases. The model receives **partial reward** equal to the fraction of test cases its solution passes.
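
As a small sketch of the reward rule, assuming the reward is simply the fraction of tests passed (`partial_reward` is a hypothetical helper name used only here):

```python
def partial_reward(passed_count, total_count):
    """Fraction of test cases passed, in [0, 1]."""
    # max(..., 1) guards against division by zero if no tests were run.
    return passed_count / max(total_count, 1)

for passed in range(6):
    print(f"{passed}/5 tests passed -> reward {partial_reward(passed, 5):.1f}")
```

A solution that passes 3 of 5 tests therefore earns 0.6 rather than 0, which gives the policy a gradient signal before it produces fully correct code.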

---
## PART 1: SETUP AND INSTALLATION
---

In [None]:
# Install dependencies
!pip install -q torch transformers accelerate peft bitsandbytes matplotlib

In [None]:
# GPU Check and Configuration
import torch
import gc
import warnings
warnings.filterwarnings('ignore')

print("=" * 60)
print("GPU DETECTION")
print("=" * 60)

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU Detected: {gpu_name}")
    print(f"GPU Memory: {gpu_memory_gb:.1f} GB")
    DEVICE = "cuda:0"
    DTYPE = torch.float16
    
    # Select model based on GPU memory
    if gpu_memory_gb >= 15:
        MODEL_NAME = "Qwen/Qwen2.5-Coder-1.5B-Instruct"
        print(f"Using 1.5B model (enough VRAM)")
    else:
        MODEL_NAME = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
        print(f"Using 0.5B model (limited VRAM)")
else:
    print("WARNING: No GPU detected! Training will be very slow.")
    DEVICE = "cpu"
    DTYPE = torch.float32
    MODEL_NAME = "Qwen/Qwen2.5-Coder-0.5B-Instruct"

print(f"\nDevice: {DEVICE}")
print(f"Model: {MODEL_NAME}")

In [None]:
# Configuration
CONFIG = {
    # Model
    "model_name": MODEL_NAME,
    "torch_dtype": "float16",
    
    # M-GRPO Core
    "num_policy_samples": 4,      # Samples from policy model
    "num_momentum_samples": 4,    # Samples from momentum model  
    "momentum": 0.99,             # EMA coefficient
    "beta": 0.04,                 # KL penalty (if used)
    
    # Entropy Control
    "use_iqr_filter": True,       # IQR-based low entropy filtering
    "iqr_k": 0.75,                # IQR multiplier
    "min_entropy_threshold": 0.1, # Absolute minimum entropy
    
    # Training
    "num_steps": 20,              # Training steps
    "batch_size": 4,              # Problems per step
    "learning_rate": 1e-5,
    "max_new_tokens": 512,
    "temperature": 0.7,
    
    # LoRA
    "lora_r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
    
    # Problems
    "problem_types": ["rpn", "parentheses", "fibonacci", "binary_search", "edit_distance", "coin_change"],
    "train_per_type": 10,
    "val_per_type": 5,
    "test_cases_per_problem": 5,
    "difficulty_range": [4, 7],
    "seed": 42,
}

print("Configuration:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")
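
To get a feel for the compute these settings imply, the following sketch multiplies out the sampling budget. The numbers are copied from `CONFIG` above; the variable names are local to this example.

```python
# Values copied from CONFIG above.
num_policy_samples = 4
num_momentum_samples = 4
batch_size = 4
num_steps = 20
test_cases_per_problem = 5

gens_per_problem = num_policy_samples + num_momentum_samples
gens_per_step = batch_size * gens_per_problem
total_gens = num_steps * gens_per_step
total_test_cases = total_gens * test_cases_per_problem

print(f"Generations per problem: {gens_per_problem}")
print(f"Generations per step:    {gens_per_step}")
print(f"Generations in total:    {total_gens}")
print(f"Test cases evaluated:    {total_test_cases}")
```

Each generation is decoded up to `max_new_tokens` tokens and then verified once, so generation and verification are likely to dominate the runtime rather than the gradient updates.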

---
## PART 2: PROBLEM GENERATORS (Self-Contained)

These generators create coding problems with test cases.
Each problem requires writing a Python function.

---

In [None]:
import random
from dataclasses import dataclass, field
from typing import List, Any, Optional
from abc import ABC, abstractmethod

@dataclass
class TestCase:
    """A single test case for a problem."""
    input_args: Any
    expected_output: Any

@dataclass 
class AlgorithmicProblem:
    """A coding problem with test cases."""
    problem_id: str
    title: str
    description: str
    function_name: str
    function_signature: str
    test_cases: List[TestCase]
    difficulty: int = 5
    
    def to_prompt(self) -> str:
        """Convert to LLM prompt."""
        return f"""Write a Python function to solve the following problem.

## Problem: {self.title}

{self.description}

## Function Signature
```python
{self.function_signature}
```

## Requirements
- Implement the function exactly as specified
- Handle edge cases appropriately
- Return the correct type

## Your Solution
"""

class AlgorithmicGenerator(ABC):
    """Base class for problem generators."""
    
    def __init__(self, seed: Optional[int] = None):
        self.rng = random.Random(seed)
    
    @abstractmethod
    def generate(self, difficulty: int = 5, num_test_cases: int = 5) -> AlgorithmicProblem:
        pass

print("Base classes defined")

In [None]:
# RPN Evaluator Generator
class RPNEvaluatorGenerator(AlgorithmicGenerator):
    """Generate RPN (Reverse Polish Notation) evaluation problems."""
    
    def generate(self, difficulty: int = 5, num_test_cases: int = 5) -> AlgorithmicProblem:
        test_cases = []
        for _ in range(num_test_cases):
            # Generate RPN expression
            num_ops = min(difficulty, 5)
            tokens = []
            stack_size = 0
            
            for i in range(num_ops * 2 + 1):
                if stack_size < 2 or (self.rng.random() < 0.6 and i < num_ops * 2 - 1):
                    tokens.append(str(self.rng.randint(1, 20)))
                    stack_size += 1
                else:
                    op = self.rng.choice(['+', '-', '*'])
                    tokens.append(op)
                    stack_size -= 1
            
            # Add remaining operators
            while stack_size > 1:
                tokens.append(self.rng.choice(['+', '-', '*']))
                stack_size -= 1
            
            # Evaluate
            stack = []
            for t in tokens:
                if t in ['+', '-', '*']:
                    b, a = stack.pop(), stack.pop()
                    if t == '+': stack.append(a + b)
                    elif t == '-': stack.append(a - b)
                    else: stack.append(a * b)
                else:
                    stack.append(int(t))
            
            test_cases.append(TestCase(input_args=[tokens], expected_output=stack[0]))
        
        return AlgorithmicProblem(
            problem_id=f"rpn_{self.rng.randint(1000, 9999)}",
            title="RPN Expression Evaluator",
            description="""Evaluate a Reverse Polish Notation (RPN) expression.

RPN is a mathematical notation where operators follow their operands.
For example: ["2", "3", "+"] = 5, ["4", "2", "*", "3", "+"] = 11

Supported operators: +, -, *
All operands are integers.""",
            function_name="evaluate_rpn",
            function_signature="def evaluate_rpn(tokens: List[str]) -> int:",
            test_cases=test_cases,
            difficulty=difficulty,
        )

print("RPNEvaluatorGenerator defined")
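
For orientation, here is a hand-written reference solution of the kind the model is asked to produce. It is a sketch of one solution that should satisfy the generated tests, not output from the model.

```python
from typing import List

def evaluate_rpn(tokens: List[str]) -> int:
    """Evaluate an RPN expression with +, -, * over integers."""
    stack: List[int] = []
    for token in tokens:
        if token in ("+", "-", "*"):
            right = stack.pop()  # the second operand was pushed last
            left = stack.pop()
            if token == "+":
                stack.append(left + right)
            elif token == "-":
                stack.append(left - right)
            else:
                stack.append(left * right)
        else:
            stack.append(int(token))
    return stack[0]

print(evaluate_rpn(["2", "3", "+"]))            # 5
print(evaluate_rpn(["4", "2", "*", "3", "+"]))  # 11
```

The operand order matters for subtraction: the value popped first is the right-hand operand, which is the same convention the generator uses when computing expected outputs.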

In [None]:
# Parentheses Validator Generator
class ParenthesesValidatorGenerator(AlgorithmicGenerator):
    """Generate parentheses validation problems."""
    
    def generate(self, difficulty: int = 5, num_test_cases: int = 5) -> AlgorithmicProblem:
        test_cases = []
        brackets = {'(': ')', '[': ']', '{': '}'}
        
        for i in range(num_test_cases):
            length = difficulty * 2
            
            if i < num_test_cases // 2:  # Valid cases
                s = ""
                stack = []
                for _ in range(length // 2):
                    if not stack or self.rng.random() < 0.6:
                        open_b = self.rng.choice(list(brackets.keys()))
                        s += open_b
                        stack.append(open_b)
                    else:
                        s += brackets[stack.pop()]
                while stack:
                    s += brackets[stack.pop()]
                expected = True
            else:  # Invalid cases
                all_brackets = list(brackets.keys()) + list(brackets.values())
                s = ''.join(self.rng.choices(all_brackets, k=length))
                # A random string is usually invalid, but may be valid by chance,
                # so compute the true answer instead of assuming False
                stack = []
                valid = True
                for c in s:
                    if c in brackets:
                        stack.append(c)
                    elif c in brackets.values():
                        if not stack:
                            valid = False
                            break
                        if brackets.get(stack.pop()) != c:
                            valid = False
                            break
                if stack:
                    valid = False
                expected = valid
            
            test_cases.append(TestCase(input_args=[s], expected_output=expected))
        
        return AlgorithmicProblem(
            problem_id=f"paren_{self.rng.randint(1000, 9999)}",
            title="Valid Parentheses",
            description="""Check if a string of brackets is valid.

A string is valid if:
- Open brackets are closed by the same type of brackets
- Open brackets are closed in the correct order
- Every close bracket has a corresponding open bracket

Bracket types: (), [], {}""",
            function_name="is_valid",
            function_signature="def is_valid(s: str) -> bool:",
            test_cases=test_cases,
            difficulty=difficulty,
        )

print("ParenthesesValidatorGenerator defined")
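
A hand-written reference solution for this task might look like the following. It is a sketch of one passing approach (a stack of open brackets), not output from the model.

```python
def is_valid(s: str) -> bool:
    """Return True if every bracket in s is closed by the right type, in order."""
    closing_to_opening = {")": "(", "]": "[", "}": "{"}
    stack = []
    for ch in s:
        if ch in closing_to_opening:
            # A closer must match the most recent unclosed opener.
            if not stack or stack.pop() != closing_to_opening[ch]:
                return False
        else:
            stack.append(ch)
    # Any opener left on the stack was never closed.
    return not stack

print(is_valid("([]{})"))  # True
print(is_valid("(]"))      # False
```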

In [None]:
# Fibonacci Generator
class FibonacciGenerator(AlgorithmicGenerator):
    """Generate Fibonacci sequence problems."""
    
    def generate(self, difficulty: int = 5, num_test_cases: int = 5) -> AlgorithmicProblem:
        test_cases = []
        max_n = difficulty * 5
        
        # Precompute fibonacci
        fib = [0, 1]
        for i in range(2, max_n + 1):
            fib.append(fib[-1] + fib[-2])
        
        # Generate test cases
        ns = [0, 1] + [self.rng.randint(2, max_n) for _ in range(num_test_cases - 2)]
        self.rng.shuffle(ns)
        
        for n in ns[:num_test_cases]:
            test_cases.append(TestCase(input_args=[n], expected_output=fib[n]))
        
        return AlgorithmicProblem(
            problem_id=f"fib_{self.rng.randint(1000, 9999)}",
            title="Fibonacci Number",
            description="""Return the nth Fibonacci number.

The Fibonacci sequence is defined as:
- F(0) = 0
- F(1) = 1  
- F(n) = F(n-1) + F(n-2) for n > 1

Examples: F(0)=0, F(1)=1, F(2)=1, F(3)=2, F(4)=3, F(5)=5, F(10)=55""",
            function_name="fibonacci",
            function_signature="def fibonacci(n: int) -> int:",
            test_cases=test_cases,
            difficulty=difficulty,
        )

print("FibonacciGenerator defined")

In [None]:
# Binary Search Generator
class BinarySearchGenerator(AlgorithmicGenerator):
    """Generate binary search problems."""
    
    def generate(self, difficulty: int = 5, num_test_cases: int = 5) -> AlgorithmicProblem:
        test_cases = []
        
        for i in range(num_test_cases):
            size = difficulty * 3
            arr = sorted(self.rng.sample(range(1, size * 3), size))
            
            if i < num_test_cases // 2:  # Target exists
                idx = self.rng.randint(0, len(arr) - 1)
                target = arr[idx]
                expected = idx
            else:  # Target doesn't exist
                target = self.rng.choice([0, arr[-1] + 1, arr[0] - 1])
                if target in arr:
                    target = arr[-1] + self.rng.randint(1, 10)
                expected = -1
            
            test_cases.append(TestCase(input_args=[arr, target], expected_output=expected))
        
        return AlgorithmicProblem(
            problem_id=f"bsearch_{self.rng.randint(1000, 9999)}",
            title="Binary Search",
            description="""Find the index of target in a sorted array.

Given a sorted array of integers and a target value, return the index
of the target if found, otherwise return -1.

You must implement binary search with O(log n) time complexity.""",
            function_name="binary_search",
            function_signature="def binary_search(arr: List[int], target: int) -> int:",
            test_cases=test_cases,
            difficulty=difficulty,
        )

print("BinarySearchGenerator defined")
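
A hand-written reference solution is sketched below. It is one standard iterative form of binary search and is not taken from the model's output.

```python
from typing import List

def binary_search(arr: List[int], target: int) -> int:
    """Return the index of target in sorted arr, or -1 if absent."""
    lo, hi = 0, len(arr) - 1
    while lo <= hi:
        mid = (lo + hi) // 2
        if arr[mid] == target:
            return mid
        if arr[mid] < target:
            lo = mid + 1   # target can only be in the right half
        else:
            hi = mid - 1   # target can only be in the left half
    return -1

print(binary_search([1, 3, 5, 7, 9], 7))  # 3
print(binary_search([1, 3, 5, 7, 9], 4))  # -1
```

Because the generator draws array elements with `rng.sample`, values are unique, so the expected index is unambiguous.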

In [None]:
# Edit Distance Generator
class EditDistanceGenerator(AlgorithmicGenerator):
    """Generate edit distance (Levenshtein) problems."""
    
    def _edit_distance(self, s1: str, s2: str) -> int:
        m, n = len(s1), len(s2)
        dp = [[0] * (n + 1) for _ in range(m + 1)]
        
        for i in range(m + 1):
            dp[i][0] = i
        for j in range(n + 1):
            dp[0][j] = j
        
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if s1[i-1] == s2[j-1]:
                    dp[i][j] = dp[i-1][j-1]
                else:
                    dp[i][j] = 1 + min(dp[i-1][j], dp[i][j-1], dp[i-1][j-1])
        
        return dp[m][n]
    
    def generate(self, difficulty: int = 5, num_test_cases: int = 5) -> AlgorithmicProblem:
        test_cases = []
        chars = 'abcdefghij'
        max_len = difficulty * 2
        
        for _ in range(num_test_cases):
            len1 = self.rng.randint(1, max_len)
            len2 = self.rng.randint(1, max_len)
            s1 = ''.join(self.rng.choices(chars, k=len1))
            s2 = ''.join(self.rng.choices(chars, k=len2))
            
            expected = self._edit_distance(s1, s2)
            test_cases.append(TestCase(input_args=[s1, s2], expected_output=expected))
        
        return AlgorithmicProblem(
            problem_id=f"edit_{self.rng.randint(1000, 9999)}",
            title="Edit Distance",
            description="""Calculate the minimum edit distance between two strings.

The edit distance (Levenshtein distance) is the minimum number of
single-character operations needed to transform one string into another.

Allowed operations:
- Insert a character
- Delete a character
- Replace a character

Example: edit_distance("kitten", "sitting") = 3""",
            function_name="edit_distance",
            function_signature="def edit_distance(s1: str, s2: str) -> int:",
            test_cases=test_cases,
            difficulty=difficulty,
        )

print("EditDistanceGenerator defined")
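
As a sketch of a solution the model could produce, the version below computes the same recurrence as `_edit_distance` but keeps only two rows of the table. It is a hand-written illustration, not model output.

```python
def edit_distance(s1: str, s2: str) -> int:
    """Levenshtein distance using two rows of the DP table."""
    # prev[j] = distance between the first i-1 chars of s1 and first j chars of s2
    prev = list(range(len(s2) + 1))
    for i, c1 in enumerate(s1, start=1):
        curr = [i]  # transforming s1[:i] into "" takes i deletions
        for j, c2 in enumerate(s2, start=1):
            if c1 == c2:
                curr.append(prev[j - 1])
            else:
                curr.append(1 + min(prev[j],       # delete from s1
                                    curr[j - 1],   # insert into s1
                                    prev[j - 1]))  # replace
        prev = curr
    return prev[-1]

print(edit_distance("kitten", "sitting"))  # 3
```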

In [None]:
# Coin Change Generator
class CoinChangeGenerator(AlgorithmicGenerator):
    """Generate coin change (minimum coins) problems."""
    
    def _coin_change(self, coins: List[int], amount: int) -> int:
        if amount == 0:
            return 0
        dp = [float('inf')] * (amount + 1)
        dp[0] = 0
        
        for i in range(1, amount + 1):
            for coin in coins:
                if coin <= i and dp[i - coin] != float('inf'):
                    dp[i] = min(dp[i], dp[i - coin] + 1)
        
        return dp[amount] if dp[amount] != float('inf') else -1
    
    def generate(self, difficulty: int = 5, num_test_cases: int = 5) -> AlgorithmicProblem:
        test_cases = []
        
        for _ in range(num_test_cases):
            # Generate coin denominations
            # max(2, ...) keeps randint's range valid when difficulty < 2
            num_coins = self.rng.randint(2, max(2, min(difficulty, 5)))
            coins = sorted(list(set([1] + [self.rng.randint(2, difficulty * 3) for _ in range(num_coins - 1)])))
            
            # Generate amount
            amount = self.rng.randint(1, difficulty * 10)
            expected = self._coin_change(coins, amount)
            
            test_cases.append(TestCase(input_args=[coins, amount], expected_output=expected))
        
        return AlgorithmicProblem(
            problem_id=f"coins_{self.rng.randint(1000, 9999)}",
            title="Coin Change",
            description="""Find the minimum number of coins to make up an amount.

Given an array of coin denominations and a target amount, return the
fewest number of coins needed to make up that amount.

If the amount cannot be made up, return -1.

You have an infinite supply of each coin denomination.

Example: coins=[1,2,5], amount=11 -> 3 (5+5+1)""",
            function_name="coin_change",
            function_signature="def coin_change(coins: List[int], amount: int) -> int:",
            test_cases=test_cases,
            difficulty=difficulty,
        )

print("CoinChangeGenerator defined")
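
A hand-written reference solution is sketched below, using an integer sentinel instead of `float('inf')`. It is an illustration of one passing approach, not model output.

```python
from typing import List

def coin_change(coins: List[int], amount: int) -> int:
    """Minimum number of coins summing to amount, or -1 if impossible."""
    # amount + 1 is larger than any real answer (at most `amount` coins of value 1).
    unreachable = amount + 1
    dp = [0] + [unreachable] * amount
    for total in range(1, amount + 1):
        for coin in coins:
            if coin <= total:
                dp[total] = min(dp[total], dp[total - coin] + 1)
    return dp[amount] if dp[amount] != unreachable else -1

print(coin_change([1, 2, 5], 11))  # 3
print(coin_change([2], 3))         # -1
```

Note that the generator always includes a coin of value 1, so the generated tests never actually expect `-1`; the second call only exercises the rule stated in the problem description.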

In [None]:
# Generator Registry
GENERATORS = {
    "rpn": RPNEvaluatorGenerator,
    "parentheses": ParenthesesValidatorGenerator,
    "fibonacci": FibonacciGenerator,
    "binary_search": BinarySearchGenerator,
    "edit_distance": EditDistanceGenerator,
    "coin_change": CoinChangeGenerator,
}

def generate_problems(config, seed=42):
    """Generate train and validation problem sets."""
    rng = random.Random(seed)
    train_problems, val_problems = [], []
    diff_min, diff_max = config["difficulty_range"]
    
    for prob_type in config["problem_types"]:
        if prob_type not in GENERATORS:
            print(f"Warning: Unknown problem type '{prob_type}'")
            continue
        
        gen = GENERATORS[prob_type](seed=rng.randint(0, 1000000))
        
        # Generate training problems
        for _ in range(config["train_per_type"]):
            diff = rng.randint(diff_min, diff_max)
            train_problems.append(gen.generate(
                difficulty=diff,
                num_test_cases=config["test_cases_per_problem"]
            ))
        
        # Generate validation problems
        for _ in range(config["val_per_type"]):
            diff = rng.randint(diff_min, diff_max)
            val_problems.append(gen.generate(
                difficulty=diff,
                num_test_cases=config["test_cases_per_problem"]
            ))
    
    rng.shuffle(train_problems)
    rng.shuffle(val_problems)
    return train_problems, val_problems

# Test problem generation
print("\nTesting problem generation...")
train_problems, val_problems = generate_problems(CONFIG)
print(f"Generated {len(train_problems)} training problems")
print(f"Generated {len(val_problems)} validation problems")
print(f"\nSample problem:")
print(train_problems[0].to_prompt()[:500] + "...")

---
## PART 3: CODE VERIFICATION (Self-Contained)

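The verifier runs each generated solution against the problem's test cases in a separate Python subprocess with a timeout, and reports how many tests passed.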

---

In [None]:
import subprocess
import tempfile
import json
import os
import re

def extract_code(completion: str) -> str:
    """
    Extract Python code from model completion.
    
    Handles:
    1. ```python ... ``` blocks
    2. ``` ... ``` blocks  
    3. Raw code with 'def' statements
    """
    # Try to find ```python ... ``` blocks
    python_blocks = re.findall(r'```python\s*(.*?)```', completion, re.DOTALL)
    if python_blocks:
        for block in sorted(python_blocks, key=len, reverse=True):
            if 'def ' in block:
                return block.strip()
        return python_blocks[0].strip()
    
    # Try to find ``` ... ``` blocks
    code_blocks = re.findall(r'```\s*(.*?)```', completion, re.DOTALL)
    if code_blocks:
        for block in sorted(code_blocks, key=len, reverse=True):
            if 'def ' in block:
                return block.strip()
        return code_blocks[0].strip()
    
    # Extract function definition directly
    if 'def ' in completion:
        lines = completion.split('\n')
        code_lines = []
        in_function = False
        indent_level = None
        
        for line in lines:
            stripped = line.lstrip()
            if stripped.startswith('def '):
                in_function = True
                indent_level = len(line) - len(stripped)
                code_lines = [line]
            elif in_function:
                if stripped and not stripped.startswith('#'):
                    current_indent = len(line) - len(stripped)
                    if current_indent <= indent_level and stripped:
                        break
                code_lines.append(line)
        
        if code_lines:
            return '\n'.join(code_lines).strip()
    
    return completion.strip()


def verify_solution(code: str, problem: AlgorithmicProblem, timeout: float = 5.0) -> dict:
    """
    Verify a solution against test cases.
    
    Returns:
        dict with 'passed', 'passed_count', 'total_count', 'error'
    """
    # Build test script
    test_cases_json = json.dumps([
        {"input": tc.input_args, "expected": tc.expected_output}
        for tc in problem.test_cases
    ])
    
    test_script = f'''# -*- coding: utf-8 -*-
import json
from typing import List, Optional, Tuple, Dict, Any

# === SOLUTION CODE ===
{code}
# === END SOLUTION ===

def run_tests():
    test_cases = json.loads({test_cases_json!r})
    results = []
    
    for i, tc in enumerate(test_cases):
        inp = tc["input"]
        expected = tc["expected"]
        
        try:
            if isinstance(inp, list):
                actual = {problem.function_name}(*inp)
            else:
                actual = {problem.function_name}(inp)
            passed = actual == expected
            results.append({{"passed": passed, "error": None}})
        except Exception as e:
            results.append({{"passed": False, "error": str(e)}})
    
    print(json.dumps({{"results": results, "success": True}}))

if __name__ == "__main__":
    try:
        run_tests()
    except Exception as e:
        print(json.dumps({{"results": [], "success": False, "error": str(e)}}))
'''
    
    # Execute in subprocess
    try:
        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
            f.write(test_script)
            script_path = f.name
        
        try:
            result = subprocess.run(
                ['python', script_path],
                capture_output=True,
                text=True,
                timeout=timeout,
            )
        finally:
            # Remove the temp script even if the run times out
            os.unlink(script_path)
        
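        if result.stdout.strip():
            # The harness prints its JSON report last; ignore anything the solution itself printed
            data = json.loads(result.stdout.strip().splitlines()[-1])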
            if data.get("success"):
                passed_count = sum(1 for r in data["results"] if r["passed"])
                total_count = len(data["results"])
                return {
                    "passed": passed_count == total_count,
                    "passed_count": passed_count,
                    "total_count": total_count,
                    "error": None,
                }
            else:
                return {
                    "passed": False,
                    "passed_count": 0,
                    "total_count": len(problem.test_cases),
                    "error": data.get("error", "Unknown error"),
                }
        else:
            return {
                "passed": False,
                "passed_count": 0,
                "total_count": len(problem.test_cases),
                "error": result.stderr or "No output",
            }
    
    except subprocess.TimeoutExpired:
        return {
            "passed": False,
            "passed_count": 0,
            "total_count": len(problem.test_cases),
            "error": "Timeout",
        }
    except Exception as e:
        return {
            "passed": False,
            "passed_count": 0,
            "total_count": len(problem.test_cases),
            "error": str(e),
        }


# Test the verifier
print("Testing verifier...")
test_problem = train_problems[0]
print(f"Problem: {test_problem.title}")

# Test with a correct solution
if test_problem.function_name == "fibonacci":
    test_code = """def fibonacci(n):
    if n <= 1:
        return n
    a, b = 0, 1
    for _ in range(2, n + 1):
        a, b = b, a + b
    return b"""
    result = verify_solution(test_code, test_problem)
    print(f"Fibonacci test: {result}")
else:
    print(f"First problem is {test_problem.function_name}, skipping built-in test")
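
The first extraction rule in `extract_code` can be illustrated on a made-up completion. This is a simplified standalone sketch of the fenced-block case only; the completion text is invented, and the fence string is built programmatically so the example can itself sit inside a fenced block.

```python
import re

FENCE = "`" * 3  # three backticks

# Hypothetical model completion: prose, a fenced Python block, more prose.
completion = (
    "Here is my solution:\n"
    f"{FENCE}python\n"
    "def fibonacci(n):\n"
    "    return n if n <= 1 else fibonacci(n - 1) + fibonacci(n - 2)\n"
    f"{FENCE}\n"
    "This runs in exponential time."
)

# Same idea as the first regex in extract_code: capture what sits between the fences.
pattern = re.escape(FENCE) + r"python\s*(.*?)" + re.escape(FENCE)
blocks = re.findall(pattern, completion, re.DOTALL)
code = blocks[0].strip()
print(code)
```

Only the function definition survives; the surrounding prose is dropped before the code is handed to the verifier.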

---
## PART 4: LOAD MODEL AND CREATE TRAINER
---

In [None]:
import copy
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType

print("=" * 60)
print("LOADING MODEL")
print("=" * 60)

# Load tokenizer
print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    CONFIG["model_name"],
    trust_remote_code=True,
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print(f"Tokenizer loaded: vocab_size={tokenizer.vocab_size}")

# Load policy model with LoRA
print("\nLoading policy model...")
policy_model = AutoModelForCausalLM.from_pretrained(
    CONFIG["model_name"],
    torch_dtype=DTYPE,
    device_map="auto",
    trust_remote_code=True,
)

# Apply LoRA
lora_config = LoraConfig(
    r=CONFIG["lora_r"],
    lora_alpha=CONFIG["lora_alpha"],
    target_modules=CONFIG["target_modules"],
    lora_dropout=CONFIG["lora_dropout"],
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)
policy_model = get_peft_model(policy_model, lora_config)
policy_model.print_trainable_parameters()

# Create momentum model (EMA copy)
print("\nCreating momentum model (EMA copy)...")
momentum_model = copy.deepcopy(policy_model)
momentum_model.eval()
for p in momentum_model.parameters():
    p.requires_grad = False
print("Momentum model created (frozen)")

# Optimizer
optimizer = torch.optim.AdamW(
    policy_model.parameters(),
    lr=CONFIG["learning_rate"],
)

print("\nSetup complete!")
print(f"  Policy model: trainable")
print(f"  Momentum model: frozen (EMA)")
print(f"  Optimizer: AdamW, lr={CONFIG['learning_rate']}")

---
## PART 5: M-GRPO TRAINING UTILITIES
---

In [None]:
import torch.nn.functional as F
import numpy as np
from typing import List, Tuple

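def update_momentum(policy_model, momentum_model, m: float = 0.99):
    """
    EMA update for momentum model.
    θ_momentum = m * θ_momentum + (1 - m) * θ_policy

    Only parameters that are trainable in the policy (the LoRA weights) are
    blended. Frozen base weights are identical in both models, and re-blending
    them in float16 would only add rounding error.
    """
    with torch.no_grad():
        for p_m, p_p in zip(momentum_model.parameters(), policy_model.parameters()):
            if p_p.requires_grad:
                p_m.mul_(m).add_(p_p, alpha=1 - m)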


def generate_samples(
    model,
    tokenizer,
    prompt: str,
    num_samples: int = 4,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
) -> List[str]:
    """Generate multiple samples from a model."""
    model.eval()
    
    inputs = tokenizer(
        [prompt] * num_samples,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
    ).to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    
    input_len = inputs.input_ids.shape[1]
    completions = tokenizer.batch_decode(
        outputs[:, input_len:],
        skip_special_tokens=True,
    )
    
    return completions


def compute_entropy(logits: torch.Tensor) -> torch.Tensor:
    """Compute per-token entropy from logits."""
    probs = F.softmax(logits, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)
    entropy = -(probs * log_probs).sum(dim=-1)
    return entropy


def iqr_filter(entropies: List[float], k: float = 0.75) -> List[bool]:
    """
    IQR-based filtering for low-entropy samples.
    Returns mask where True = keep, False = filter out.
    """
    arr = np.array(entropies)
    Q1 = np.percentile(arr, 25)
    Q3 = np.percentile(arr, 75)
    threshold = Q1 - k * (Q3 - Q1)
    threshold = max(threshold, 0.1)  # Minimum threshold
    return (arr >= threshold).tolist()


def compute_advantages(rewards: torch.Tensor) -> torch.Tensor:
    """Compute standardized advantages."""
    mean = rewards.mean()
    std = rewards.std() + 1e-8
    return (rewards - mean) / std


print("Training utilities defined:")
print("  - update_momentum(): EMA update for momentum model")
print("  - generate_samples(): Generate completions from model")
print("  - compute_entropy(): Per-token entropy calculation")
print("  - iqr_filter(): IQR-based low-entropy filtering")
print("  - compute_advantages(): Advantage normalization")
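
A worked example of the IQR threshold may help. The sketch below re-implements the rule in pure Python so it runs without NumPy; it assumes that `statistics.quantiles(..., method="inclusive")` yields the same quartiles as `np.percentile` with its default linear interpolation. The entropy values and the helper name `iqr_keep_mask` are made up for illustration.

```python
import statistics

def iqr_keep_mask(entropies, k=0.75, min_threshold=0.1):
    """Return (threshold, mask) where mask[i] is True if sample i is kept."""
    q1, _, q3 = statistics.quantiles(entropies, n=4, method="inclusive")
    threshold = max(q1 - k * (q3 - q1), min_threshold)
    return threshold, [e >= threshold for e in entropies]

# Hypothetical mean entropies of five sampled completions.
entropies = [0.05, 0.80, 0.90, 1.00, 1.10]
threshold, keep = iqr_keep_mask(entropies)
print(f"threshold = {threshold:.2f}")
print(f"keep mask = {keep}")
```

Here Q1 = 0.80 and Q3 = 1.00, so the threshold is 0.80 - 0.75 * 0.20 = 0.65 and only the first, near-deterministic sample is filtered out.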

In [None]:
# Test generation
print("Testing generation...")
test_prompt = train_problems[0].to_prompt()
test_completions = generate_samples(
    policy_model,
    tokenizer,
    test_prompt,
    num_samples=2,
    max_new_tokens=256,
)
print(f"Generated {len(test_completions)} completions")
print(f"\nFirst completion (first 300 chars):")
print(test_completions[0][:300])

---
## PART 6: M-GRPO TRAINING LOOP

The main training loop implementing:
1. Combined rollout (policy + momentum samples)
2. Reward computation with partial credit
3. Policy gradient updates
4. Momentum model EMA update
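
Steps 2 and 3 can be sketched without a model. The example below standardizes a made-up set of rewards and selects the samples that would receive a gradient step. It uses the sample standard deviation, on the assumption that this matches the default of `torch.Tensor.std`; the helper name and reward values are illustrative.

```python
import statistics

def standardized_advantages(rewards, eps=1e-8):
    """(reward - mean) / (sample std + eps) for each reward in the group."""
    mean = statistics.fmean(rewards)
    std = statistics.stdev(rewards) + eps  # n - 1 in the denominator
    return [(r - mean) / std for r in rewards]

# Hypothetical pass fractions for four policy samples on one problem.
rewards = [0.2, 0.4, 1.0, 0.4]
advantages = standardized_advantages(rewards)

# Only samples that beat the group average get a policy-gradient update.
trained_on = [i for i, adv in enumerate(advantages) if adv > 0]

print([round(a, 3) for a in advantages])
print(f"samples used for updates: {trained_on}")
```

If every sample in a group earns the same reward, all advantages are zero and the problem contributes no update for that step.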

---

In [None]:
from datetime import datetime
import time

# Metrics tracking
metrics_history = {
    "step": [],
    "loss": [],
    "mean_reward": [],
    "mean_entropy": [],
    "success_rate": [],
    "val_accuracy": [],
    "filtered_count": [],
}

# Create prompt -> problem mapping for quick lookup
prompt_to_problem = {p.to_prompt(): p for p in train_problems}
train_prompts = list(prompt_to_problem.keys())

print("=" * 60)
print("M-GRPO TRAINING")
print("=" * 60)
print(f"Steps: {CONFIG['num_steps']}")
print(f"Batch size: {CONFIG['batch_size']} problems")
print(f"Samples per problem: {CONFIG['num_policy_samples']} policy + {CONFIG['num_momentum_samples']} momentum")
print(f"Momentum coefficient: {CONFIG['momentum']}")
print(f"IQR filter: {CONFIG['use_iqr_filter']}")
print(f"\nTraining problems: {len(train_problems)}")
print(f"Validation problems: {len(val_problems)}")

In [None]:
# Main training loop
train_start_time = time.time()

for step in range(CONFIG["num_steps"]):
    step_start = time.time()
    print(f"\n{'='*60}")
    print(f"Step {step + 1}/{CONFIG['num_steps']}")
    print(f"{'='*60}")
    
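    # Sample a batch of problems. Sampling problems rather than prompts matters:
    # every problem of a given type renders to the same prompt, so a
    # prompt -> problem dict keeps only one problem (and test set) per type.
    batch_problems = random.sample(train_problems, min(CONFIG["batch_size"], len(train_problems)))
    batch_prompts = [p.to_prompt() for p in batch_problems]
    
    step_loss = 0.0
    step_reward = 0.0
    step_entropy = 0.0
    step_successes = 0
    step_updates = 0
    step_filtered = 0
    
    for prompt_idx, (prompt, problem) in enumerate(zip(batch_prompts, batch_problems)):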
        print(f"  [{prompt_idx+1}/{len(batch_prompts)}] {problem.title}...", end=" ")
        
        # 1. Combined rollout: generate from both models
        policy_gens = generate_samples(
            policy_model, tokenizer, prompt,
            num_samples=CONFIG["num_policy_samples"],
            max_new_tokens=CONFIG["max_new_tokens"],
            temperature=CONFIG["temperature"],
        )
        
        momentum_gens = generate_samples(
            momentum_model, tokenizer, prompt,
            num_samples=CONFIG["num_momentum_samples"],
            max_new_tokens=CONFIG["max_new_tokens"],
            temperature=CONFIG["temperature"],
        )
        
        # 2. Compute rewards for all generations (partial credit)
        all_gens = policy_gens + momentum_gens
        all_rewards = []
        
        for gen in all_gens:
            code = extract_code(gen)
            result = verify_solution(code, problem)
            # Partial reward: proportion of tests passed
            reward = result["passed_count"] / max(result["total_count"], 1)
            all_rewards.append(reward)
        
        policy_rewards = torch.tensor(all_rewards[:len(policy_gens)])
        max_reward = max(all_rewards)
        
        if max_reward > 0:
            step_successes += 1
        
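        # 3. Train on policy samples with positive advantage.
        # Advantages are standardized over the policy samples only; momentum
        # rewards count toward max_reward above but not toward the baseline.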
        advantages = compute_advantages(policy_rewards)
        
        for gen_idx, (gen, adv, rew) in enumerate(zip(policy_gens, advantages, policy_rewards)):
            # Only update on positive advantage
            if adv <= 0:
                continue
            
            # Tokenize
            full_text = prompt + gen
            inputs = tokenizer(
                full_text,
                return_tensors="pt",
                truncation=True,
                max_length=2048,
            ).to(policy_model.device)
            
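            # Count prompt tokens with the same special-token settings used for `inputs`,
            # so the prompt/completion boundary still lines up if the tokenizer prepends BOS
            prompt_len = len(tokenizer(prompt)["input_ids"])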
            
            # Forward pass
            policy_model.train()
            outputs = policy_model(**inputs)
            
            # Get logits for completion tokens only
            logits = outputs.logits[:, prompt_len-1:-1, :]
            labels = inputs.input_ids[:, prompt_len:]
            
            if labels.shape[1] == 0:
                continue
            
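            # Mean token entropy of this generation (detached: used for logging and filtering only)
            entropy = compute_entropy(logits.detach()).mean().item()
            
            # Low-entropy filter (enabled via CONFIG["use_iqr_filter"]): skip samples
            # whose mean entropy falls below the fixed threshold
            if CONFIG["use_iqr_filter"] and entropy < CONFIG["min_entropy_threshold"]:
                step_filtered += 1
                continue
            
            # Accumulate entropy only for samples that are trained on, so that
            # avg_entropy (divided by step_updates) is a true per-update mean
            step_entropy += entropy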
            
            # Policy gradient loss
            log_probs = F.log_softmax(logits, dim=-1)
            token_log_probs = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
            pg_loss = -(token_log_probs.mean() * adv)
            
            # Backward
            optimizer.zero_grad()
            pg_loss.backward()
            optimizer.step()
            
            step_loss += pg_loss.item()
            step_reward += rew.item()
            step_updates += 1
        
        print(f"reward={max_reward:.2f}")
    
    # 4. Update momentum model via EMA
    update_momentum(policy_model, momentum_model, CONFIG["momentum"])
    
    # 5. Compute step metrics
    num_samples = len(batch_prompts) * CONFIG["num_policy_samples"]
    avg_loss = step_loss / max(step_updates, 1)
    avg_reward = step_reward / max(step_updates, 1)
    avg_entropy = step_entropy / max(step_updates, 1)
    success_rate = step_successes / len(batch_prompts)
    
    # 6. Quick validation
    val_correct = 0
    val_subset = random.sample(val_problems, min(5, len(val_problems)))
    for vp in val_subset:
        gens = generate_samples(policy_model, tokenizer, vp.to_prompt(), num_samples=1, max_new_tokens=512)
        code = extract_code(gens[0])
        result = verify_solution(code, vp)
        if result["passed"]:
            val_correct += 1
    val_accuracy = val_correct / max(len(val_subset), 1)  # guard against an empty validation set
    
    # Log metrics
    metrics_history["step"].append(step)
    metrics_history["loss"].append(avg_loss)
    metrics_history["mean_reward"].append(avg_reward)
    metrics_history["mean_entropy"].append(avg_entropy)
    metrics_history["success_rate"].append(success_rate)
    metrics_history["val_accuracy"].append(val_accuracy)
    metrics_history["filtered_count"].append(step_filtered)
    
    # Print summary
    step_time = time.time() - step_start
    elapsed = time.time() - train_start_time
    eta = (elapsed / (step + 1)) * (CONFIG["num_steps"] - step - 1)
    
    print(f"\nStep {step} Summary:")
    print(f"  Loss: {avg_loss:.4f}")
    print(f"  Reward: {avg_reward:.3f}")
    print(f"  Entropy: {avg_entropy:.3f}")
    print(f"  Success Rate: {success_rate:.1%}")
    print(f"  Val Accuracy: {val_accuracy:.1%}")
    print(f"  Updates: {step_updates}, Filtered: {step_filtered}")
    print(f"  Time: {step_time:.1f}s, ETA: {eta/60:.1f}m")
    
    # Check for entropy collapse (only meaningful when at least one update was made this step)
    if step_updates > 0 and avg_entropy < 0.1:
        print(f"  ⚠️ WARNING: Low entropy ({avg_entropy:.3f}) - possible collapse!")
    
    # Clear memory
    gc.collect()
    torch.cuda.empty_cache()

total_time = time.time() - train_start_time
print(f"\n{'='*60}")
print("TRAINING COMPLETE")
print(f"{'='*60}")
print(f"Total time: {total_time/60:.1f} minutes")
print("Final metrics:")
print(f"  Loss: {metrics_history['loss'][-1]:.4f}")
print(f"  Reward: {metrics_history['mean_reward'][-1]:.3f}")
print(f"  Entropy: {metrics_history['mean_entropy'][-1]:.3f}")
print(f"  Val Accuracy: {metrics_history['val_accuracy'][-1]:.1%}")

---
## PART 7: RESULTS VISUALIZATION
---

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 3, figsize=(15, 8))

# Loss
axes[0, 0].plot(metrics_history["step"], metrics_history["loss"], 'b-o', markersize=4)
axes[0, 0].set_xlabel("Step")
axes[0, 0].set_ylabel("Loss")
axes[0, 0].set_title("Training Loss")
axes[0, 0].grid(True, alpha=0.3)

# Reward
axes[0, 1].plot(metrics_history["step"], metrics_history["mean_reward"], 'g-o', markersize=4)
axes[0, 1].set_xlabel("Step")
axes[0, 1].set_ylabel("Mean Reward")
axes[0, 1].set_title("Average Reward")
axes[0, 1].grid(True, alpha=0.3)

# Entropy (critical for M-GRPO)
axes[0, 2].plot(metrics_history["step"], metrics_history["mean_entropy"], 'r-o', markersize=4)
axes[0, 2].axhline(y=0.1, color='orange', linestyle='--', label='Collapse threshold')
axes[0, 2].set_xlabel("Step")
axes[0, 2].set_ylabel("Mean Entropy")
axes[0, 2].set_title("Policy Entropy (should stay above 0.1)")
axes[0, 2].legend()
axes[0, 2].grid(True, alpha=0.3)

# Success Rate
axes[1, 0].plot(metrics_history["step"], [s*100 for s in metrics_history["success_rate"]], 'm-o', markersize=4)
axes[1, 0].set_xlabel("Step")
axes[1, 0].set_ylabel("Success Rate (%)")
axes[1, 0].set_title("Training Success Rate")
axes[1, 0].grid(True, alpha=0.3)

# Validation Accuracy
axes[1, 1].plot(metrics_history["step"], [v*100 for v in metrics_history["val_accuracy"]], 'c-o', markersize=4)
axes[1, 1].set_xlabel("Step")
axes[1, 1].set_ylabel("Accuracy (%)")
axes[1, 1].set_title("Validation Accuracy")
axes[1, 1].grid(True, alpha=0.3)

# Filtered samples
axes[1, 2].bar(metrics_history["step"], metrics_history["filtered_count"], color='orange', alpha=0.7)
axes[1, 2].set_xlabel("Step")
axes[1, 2].set_ylabel("Count")
axes[1, 2].set_title("IQR Filtered Samples (low entropy)")
axes[1, 2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig("mgrpo_training_results.png", dpi=150, bbox_inches='tight')
plt.show()

print("\nResults saved to mgrpo_training_results.png")

---
## PART 8: SAVE MODEL AND RESULTS
---

In [None]:
import json
from pathlib import Path

# Create output directory
output_dir = Path("mgrpo_output")
output_dir.mkdir(exist_ok=True)

# Save policy model
print("Saving policy model...")
policy_model.save_pretrained(output_dir / "policy")
tokenizer.save_pretrained(output_dir / "policy")

# Save momentum model
print("Saving momentum model...")
momentum_model.save_pretrained(output_dir / "momentum")

# Save metrics
print("Saving metrics...")
with open(output_dir / "metrics.json", "w") as f:
    json.dump(metrics_history, f, indent=2)

# Save config
print("Saving config...")
with open(output_dir / "config.json", "w") as f:
    # Convert non-serializable types
    config_save = {k: str(v) if not isinstance(v, (int, float, str, list, dict, bool)) else v 
                   for k, v in CONFIG.items()}
    json.dump(config_save, f, indent=2)

# Save summary
summary = {
    "model": CONFIG["model_name"],
    "num_steps": CONFIG["num_steps"],
    "final_loss": metrics_history["loss"][-1],
    "final_reward": metrics_history["mean_reward"][-1],
    "final_entropy": metrics_history["mean_entropy"][-1],
    "final_val_accuracy": metrics_history["val_accuracy"][-1],
    "total_time_minutes": total_time / 60,
}
with open(output_dir / "summary.json", "w") as f:
    json.dump(summary, f, indent=2)

print(f"\nAll files saved to {output_dir}/")
print("  - policy/: Trained policy model")
print("  - momentum/: Momentum model (EMA)")
print("  - metrics.json: Training metrics")
print("  - config.json: Configuration")
print("  - summary.json: Final summary")

---
## PART 9: FINAL EVALUATION
---

In [None]:
print("=" * 60)
print("FINAL EVALUATION")
print("=" * 60)

# Evaluate on all validation problems
val_results = {}
total_correct = 0
total_problems = len(val_problems)

for prob_type in CONFIG["problem_types"]:
    val_results[prob_type] = {"correct": 0, "total": 0}

print("\nEvaluating on validation set...")
for vp in val_problems:
    # Generate solution
    gens = generate_samples(
        policy_model, tokenizer, vp.to_prompt(),
        num_samples=1,
        max_new_tokens=512,
        temperature=0.3,  # Lower temp for eval
    )
    
    code = extract_code(gens[0])
    result = verify_solution(code, vp)
    
    # Find problem type
    for pt in CONFIG["problem_types"]:
        if pt in vp.problem_id:
            val_results[pt]["total"] += 1
            if result["passed"]:
                val_results[pt]["correct"] += 1
                total_correct += 1
            break

# Print results
print("\nResults by Problem Type:")
print("-" * 40)
for prob_type, stats in val_results.items():
    if stats["total"] > 0:
        acc = stats["correct"] / stats["total"] * 100
        print(f"  {prob_type:20s}: {stats['correct']}/{stats['total']} = {acc:.1f}%")

print("-" * 40)
# Guard against an empty validation set
overall_acc = total_correct / max(total_problems, 1) * 100
print(f"  {'OVERALL':20s}: {total_correct}/{total_problems} = {overall_acc:.1f}%")

In [None]:
# Show a sample generation
print("\n" + "=" * 60)
print("SAMPLE GENERATION")
print("=" * 60)

sample_problem = val_problems[0]
print(f"\nProblem: {sample_problem.title}")
print(f"\nPrompt:")
print(sample_problem.to_prompt()[:500] + "...")

# Generate
sample_gen = generate_samples(
    policy_model, tokenizer, sample_problem.to_prompt(),
    num_samples=1,
    max_new_tokens=512,
    temperature=0.3,
)[0]

print(f"\nGenerated Solution:")
print(sample_gen[:800])

# Verify
code = extract_code(sample_gen)
result = verify_solution(code, sample_problem)
print(f"\nVerification: {result['passed_count']}/{result['total_count']} tests passed")
if result["error"]:
    print(f"Error: {result['error']}")

---
## EXPERIMENT COMPLETE

### Summary

This notebook implemented **M-GRPO (Momentum-Anchored GRPO)** training with:

1. **Two-model architecture**: Policy (trainable) + Momentum (EMA)
2. **Combined rollout**: Samples from both models are generated and scored; only policy samples with positive advantage are trained on
3. **Partial rewards**: Proportional credit for passing tests
4. **Entropy filtering** (`use_iqr_filter`): Skips updates on samples whose mean token entropy is below `min_entropy_threshold`
5. **Entropy monitoring**: Detects and warns about collapse
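
The mechanisms listed above can be sketched in a few lines of plain Python. This is a minimal illustration, not the notebook's implementation: lists and dicts stand in for tensors and model weights, the helper names (`ema_update`, `group_advantages`, `token_entropy`) are hypothetical, and the mean/std normalization in `group_advantages` is the common GRPO form, assumed here because the notebook's `compute_advantages` may differ.

```python
import math
from statistics import mean, pstdev


def ema_update(momentum_params, policy_params, momentum=0.99):
    """Return EMA-updated momentum weights: m <- momentum * m + (1 - momentum) * p."""
    return {
        name: momentum * momentum_params[name] + (1.0 - momentum) * policy_params[name]
        for name in momentum_params
    }


def group_advantages(rewards, eps=1e-8):
    """Normalize rewards within one group of samples (assumed GRPO-style)."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def token_entropy(probs):
    """Shannon entropy (in nats) of one next-token distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


# Partial rewards for three samples: 0/5, 3/5 and 5/5 tests passed
rewards = [0 / 5, 3 / 5, 5 / 5]
advantages = group_advantages(rewards)
# Only samples with positive advantage would be trained on
trained = [i for i, adv in enumerate(advantages) if adv > 0]
```

When every sample in a group earns the same reward, all advantages are zero and no update is made for that prompt, which is why partial rewards help: they make ties less likely than an all-or-nothing pass/fail signal.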

### Files Created

- `mgrpo_output/policy/` - Trained policy model (can be loaded with Hugging Face)
- `mgrpo_output/momentum/` - Momentum model
- `mgrpo_output/metrics.json` - Training metrics
- `mgrpo_output/config.json` - Configuration
- `mgrpo_output/summary.json` - Final summary
- `mgrpo_training_results.png` - Training curves
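
To inspect a finished run later, the JSON artifacts can be read back with the standard library. The following is a minimal sketch that assumes the file layout listed above; `load_run_artifacts` and `lowest_entropy_step` are hypothetical helper names, and loading the model weights themselves is not covered here.

```python
import json
from pathlib import Path


def load_run_artifacts(output_dir="mgrpo_output"):
    """Read metrics.json and summary.json written by the save cell."""
    output_dir = Path(output_dir)
    with open(output_dir / "metrics.json") as f:
        metrics = json.load(f)
    with open(output_dir / "summary.json") as f:
        summary = json.load(f)
    return metrics, summary


def lowest_entropy_step(metrics):
    """Return (step, entropy) for the step with the lowest logged mean entropy."""
    pairs = zip(metrics["step"], metrics["mean_entropy"])
    return min(pairs, key=lambda pair: pair[1])


# Only runs when the output directory from this notebook is present
if Path("mgrpo_output").exists():
    metrics, summary = load_run_artifacts()
    print("Final val accuracy:", summary["final_val_accuracy"])
    print("Lowest entropy (step, value):", lowest_entropy_step(metrics))
```

Looking up the lowest-entropy step is a quick way to see how close a run came to the 0.1 collapse threshold used in the training loop.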

### Next Steps

1. Download the trained model from `mgrpo_output/policy/`
2. Run on harder problems or more iterations
3. Compare against vanilla GRPO baseline
4. Tune hyperparameters (momentum, learning rate, samples)
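
For the hyperparameter tuning step, one simple option is a small grid sweep over copies of the configuration. The sketch below only builds the list of configs and does not run training. `make_sweep` is a hypothetical helper, the numeric values are illustrative, and `learning_rate` is an assumed key name that may differ from the one in this notebook's `CONFIG`.

```python
from itertools import product


def make_sweep(base_config, grid):
    """Return one config dict per combination of the values in `grid`."""
    keys = list(grid)
    configs = []
    for values in product(*(grid[k] for k in keys)):
        cfg = dict(base_config)  # shallow copy; base_config is left untouched
        cfg.update(zip(keys, values))
        configs.append(cfg)
    return configs


# "momentum" and "num_policy_samples" are CONFIG keys used in the training loop;
# "learning_rate" is an assumed key name, and all values here are illustrative.
base = {"momentum": 0.99, "num_policy_samples": 4, "learning_rate": 1e-5}
grid = {"momentum": [0.9, 0.99, 0.999], "learning_rate": [5e-6, 1e-5]}
sweep = make_sweep(base, grid)
```

Each entry in `sweep` could then replace `CONFIG` for one training run, with the resulting `summary.json` files compared across runs.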

---