# Experiment 15: M-GRPO with Entropy Control

**Status:** Ready to Run  
**Date:** 2024-12-20  
**Runtime:** Google Colab A100 Recommended  

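This notebook implements M-GRPO (Momentum-Anchored GRPO) with entropy control mechanisms. Note: in its current form, the training loop performs a supervised log-likelihood update on verified-correct policy samples only. Momentum rollouts are scored but not used for the update. The KL penalty (`beta`) and the IQR / Clip-Cov / KL-Cov entropy controls are configured but not yet applied.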

## Key Innovations from Papers

### M-GRPO Paper
- Two-model setup: policy (trainable) + momentum (EMA)
- Combined sampling: M from policy + N from momentum
- IQR-based entropy filtering to prevent mode collapse

### Entropy Mechanism Paper  
- Performance follows $R = -a\,e^{H} + b$, where $H$ is the policy entropy, $R$ is downstream performance, and $a, b$ are fitted constants
- 95% of gains in first 1/12 of training
- Clip-Cov and KL-Cov for entropy control

## PART 1: SETUP

In [None]:
# Install dependencies (IPython `!` shell escape: runs pip in the notebook's environment).
# NOTE: seaborn is installed but not imported by any cell in this notebook.
!pip install -q torch transformers accelerate peft datasets matplotlib seaborn

In [None]:
# GPU Check: choose the device and the weight dtype used when loading every model below.
import torch
import gc

print("="*60)
print("GPU DETECTION")
print("="*60)

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU Detected: {gpu_name}")
    print(f"GPU Memory: {gpu_memory_gb:.1f} GB")
    DEVICE = "cuda:0"
    # bf16 keeps fp32's exponent range, so activations/gradients do not overflow or
    # underflow the way fp16 can without a GradScaler (the training loop uses none).
    # Fall back to fp16 on GPUs without bf16 support (pre-Ampere, e.g. T4).
    DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    print("No GPU! This experiment requires GPU.")
    DEVICE = "cpu"
    DTYPE = torch.float32

print(f"\nUsing: {DEVICE}")
print(f"Dtype: {DTYPE}")

In [None]:
# Clone axiom-rl repository and make its packages importable.
# NOTE: replace YOUR_REPO with the actual GitHub owner before running.
import os
import subprocess
import sys

REPO_URL = "https://github.com/YOUR_REPO/axiom-rl.git"
REPO_DIR = "axiom-rl"

if os.path.isdir(REPO_DIR):
    print("Repo exists")
else:
    # check=True raises CalledProcessError on a failed clone (bad URL, no network),
    # so the problem is reported here instead of surfacing later as a confusing
    # ImportError on `from axiom...`.
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
sys.path.insert(0, REPO_DIR)

In [None]:
# Configuration: every experiment knob lives in this single dict.
# NOTE(review): "beta", "iqr_k", "use_clip_cov" and "use_kl_cov" are not read by any
# other cell in this notebook, and "use_iqr_filter" is only printed -- the training
# loop does not yet apply a KL penalty or any entropy control.
CONFIG = dict(
    # Model
    model_name="Qwen/Qwen2.5-Coder-1.5B-Instruct",

    # M-GRPO: rollouts per prompt from each model, EMA coefficient, beta
    num_policy_samples=4,
    num_momentum_samples=4,
    momentum=0.99,
    beta=0.04,

    # Entropy Control
    use_iqr_filter=True,
    iqr_k=0.75,
    use_clip_cov=False,
    use_kl_cov=False,

    # Training
    num_iterations=20,
    learning_rate=1e-5,
    max_new_tokens=512,
    temperature=0.7,

    # LoRA
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.05,

    # Problems
    problem_types=["rpn", "parentheses", "fibonacci", "binary_search"],
    train_per_type=10,
    val_per_type=5,
    seed=42,
)

print("Configuration loaded")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

## PART 2: LOAD MODEL AND SETUP TRAINER

In [None]:
import copy
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType

print("="*60)
print("LOADING MODEL")
print("="*60)


def _load_base_model():
    """Load a fresh copy of CONFIG["model_name"] with the notebook's dtype and placement."""
    return AutoModelForCausalLM.from_pretrained(
        CONFIG["model_name"],
        torch_dtype=DTYPE,
        device_map="auto",
        trust_remote_code=True,
    )


def _freeze(model):
    """Switch *model* to eval mode, disable gradients on all its parameters, return it."""
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    return model


# Tokenizer: reuse EOS as the padding token when the checkpoint defines none
tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_name"], trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print("Tokenizer loaded")

# LoRA adapters on the attention projections
lora_config = LoraConfig(
    r=CONFIG["lora_r"],
    lora_alpha=CONFIG["lora_alpha"],
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=CONFIG["lora_dropout"],
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

# Trainable policy = base model wrapped with the adapters above
print("\nLoading policy model...")
policy_model = get_peft_model(_load_base_model(), lora_config)
policy_model.print_trainable_parameters()

# Momentum (EMA) model: an independent frozen deep copy of the policy, adapters included
print("\nCreating momentum model...")
momentum_model = _freeze(copy.deepcopy(policy_model))

# Reference model: untouched base weights, frozen.
# NOTE(review): no later cell in this notebook reads ref_model (the "beta" KL term is
# not implemented), so it currently only occupies accelerator memory.
print("\nLoading reference model...")
ref_model = _freeze(_load_base_model())

print("\nAll models loaded!")

In [None]:
# Optimizer over the trainable (LoRA adapter) parameters only. Frozen base weights
# never receive gradients, so passing them to AdamW only adds per-step iteration over
# every model tensor in zero_grad()/step().
trainable_params = [p for p in policy_model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(
    trainable_params,
    lr=CONFIG["learning_rate"],
)
print("Optimizer configured")

## PART 3: PROBLEM GENERATION

In [None]:
from axiom.procedural import (
    RPNEvaluatorGenerator,
    ParenthesesValidatorGenerator,
    FibonacciGenerator,
    BinarySearchGenerator,
)
import random

# Registry mapping each CONFIG["problem_types"] entry to its procedural generator class.
GENERATORS = dict(
    rpn=RPNEvaluatorGenerator,
    parentheses=ParenthesesValidatorGenerator,
    fibonacci=FibonacciGenerator,
    binary_search=BinarySearchGenerator,
)

def generate_problems(config, seed=None, generators=None):
    """Build shuffled train/validation problem lists from procedural generators.

    Args:
        config: dict with "problem_types", "train_per_type" and "val_per_type";
            optionally "seed" (default 42), "difficulty" (default 5) and
            "num_test_cases" (default 5).
        seed: RNG seed; when None, ``config["seed"]`` is used (falling back to 42).
        generators: mapping from problem type to a generator class accepting
            ``seed=`` and exposing ``generate(difficulty=..., num_test_cases=...)``;
            defaults to the module-level ``GENERATORS`` registry.

    Returns:
        ``(train_problems, val_problems)`` lists, each shuffled deterministically
        from ``seed``. Unknown problem types are skipped with a printed warning.
    """
    if seed is None:
        seed = config.get("seed", 42)
    if generators is None:
        generators = GENERATORS
    difficulty = config.get("difficulty", 5)
    num_test_cases = config.get("num_test_cases", 5)

    rng = random.Random(seed)
    train_problems, val_problems = [], []

    for prob_type in config["problem_types"]:
        gen_cls = generators.get(prob_type)
        if gen_cls is None:
            print(f"Warning: unknown problem type {prob_type!r}; skipping")
            continue
        # One child seed per type, drawn in config order, so results are reproducible.
        gen = gen_cls(seed=rng.randint(0, 1000000))

        for _ in range(config["train_per_type"]):
            train_problems.append(gen.generate(difficulty=difficulty, num_test_cases=num_test_cases))
        for _ in range(config["val_per_type"]):
            val_problems.append(gen.generate(difficulty=difficulty, num_test_cases=num_test_cases))

    rng.shuffle(train_problems)
    rng.shuffle(val_problems)
    return train_problems, val_problems

train_problems, val_problems = generate_problems(CONFIG)
# Report the size of each split.
for split_label, split_items in (("Train", train_problems), ("Val", val_problems)):
    print(f"{split_label} problems: {len(split_items)}")

## PART 4: M-GRPO TRAINING LOOP

In [None]:
import torch.nn.functional as F
import numpy as np
from axiom.verifier import TestHarness

# Shared verifier instance; the training cell calls harness.verify(problem, completion)
# and reads `.passed` on the result to score each generated completion.
harness = TestHarness()

def update_momentum(policy_model, momentum_model, m=0.99):
    """EMA update of the momentum model: theta_k <- m * theta_k + (1 - m) * theta_q.

    Only parameters that are trainable in ``policy_model`` (``requires_grad=True``,
    e.g. LoRA adapters) are updated. Frozen weights never change, so averaging them
    is wasted work over the whole base model, and in half precision the
    ``mul_(m).add_(...)`` round trip would slowly drift them through rounding.

    Assumes both models enumerate parameters in the same order (e.g. one is a
    ``copy.deepcopy`` of the other).

    Raises:
        ValueError: if ``m`` is outside [0, 1].
    """
    if not 0.0 <= m <= 1.0:
        raise ValueError(f"momentum m must be in [0, 1], got {m}")
    with torch.no_grad():
        for p_k, p_q in zip(momentum_model.parameters(), policy_model.parameters()):
            if not p_q.requires_grad:
                continue
            p_k.mul_(m).add_(p_q, alpha=1 - m)

def generate_samples(model, tokenizer, prompt, num_samples=4, max_new_tokens=None, temperature=None):
    """Sample ``num_samples`` completions of ``prompt`` from ``model``.

    Args:
        model: causal LM exposing ``generate``, ``device`` and ``training``.
        tokenizer: tokenizer matching ``model``; its ``pad_token_id`` is passed to
            ``generate``.
        prompt: prompt text, repeated once per sample.
        num_samples: number of completions to draw.
        max_new_tokens: generation budget; defaults to ``CONFIG["max_new_tokens"]``.
        temperature: sampling temperature; defaults to ``CONFIG["temperature"]``.

    Returns:
        List of ``num_samples`` decoded completions with the prompt tokens removed
        and special tokens skipped.

    The model's train/eval mode is restored afterwards, so sampling does not silently
    leave a policy in eval mode (which would disable LoRA dropout for later updates).
    """
    if max_new_tokens is None:
        max_new_tokens = CONFIG["max_new_tokens"]
    if temperature is None:
        temperature = CONFIG["temperature"]

    was_training = model.training
    model.eval()
    try:
        inputs = tokenizer([prompt] * num_samples, return_tensors="pt", padding=True).to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                pad_token_id=tokenizer.pad_token_id,
            )
    finally:
        model.train(was_training)

    # Every row is the same prompt, so no padding is inserted and the prompt occupies
    # the first `input_len` columns of each output row.
    input_len = inputs.input_ids.shape[1]
    return tokenizer.batch_decode(outputs[:, input_len:], skip_special_tokens=True)

def compute_entropy(logits):
    """Return the Shannon entropy (nats) of softmax(logits) over the last dimension.

    Args:
        logits: tensor of unnormalised scores; the distribution is over dim -1.

    Returns:
        Tensor with the last dimension reduced away. Half-precision (fp16/bf16)
        inputs are upcast, so the result is float32 for them.
    """
    # Summing p * log p over a large vocabulary in fp16/bf16 loses precision.
    if logits.dtype in (torch.float16, torch.bfloat16):
        logits = logits.float()
    log_probs = F.log_softmax(logits, dim=-1)
    # Derive p from log p instead of running a second softmax normalisation.
    return -(log_probs.exp() * log_probs).sum(dim=-1)

def iqr_filter(entropies, k=0.75, min_threshold=0.1):
    """Return a boolean keep-mask that drops low-entropy outliers.

    A sample is kept when its entropy is >= max(Q1 - k * (Q3 - Q1), min_threshold),
    where Q1/Q3 are the 25th/75th percentiles of ``entropies``.

    Args:
        entropies: sequence of per-sample entropy values.
        k: IQR multiplier; larger values filter less aggressively.
        min_threshold: absolute floor on the cut-off.

    Returns:
        numpy bool array aligned with ``entropies`` (empty input -> empty mask).
    """
    arr = np.asarray(entropies, dtype=float)
    if arr.size == 0:
        # np.percentile raises on an empty array; nothing to filter.
        return np.zeros(0, dtype=bool)
    q1, q3 = np.percentile(arr, [25, 75])
    threshold = q1 - k * (q3 - q1)
    return arr >= max(threshold, min_threshold)

print("Training utilities defined")

In [None]:
# Per-iteration training metrics: one independent list per series, appended to once
# per iteration by the training loop.
metrics_history = {
    series: []
    for series in ("iteration", "loss", "reward", "entropy", "filtered_count", "val_accuracy")
}

print("Metrics tracking initialized")

In [None]:
from datetime import datetime

print("="*60)
print("M-GRPO TRAINING")
print("="*60)
print(f"Iterations: {CONFIG['num_iterations']}")
print(f"Momentum: {CONFIG['momentum']}")
print(f"IQR Filter: {CONFIG['use_iqr_filter']}")

# NOTE(review): the update below is a supervised negative-log-likelihood step on
# verified-correct *policy* rollouts. Momentum rollouts are scored but contribute no
# gradient, and neither the KL term ("beta") nor the IQR / Clip-Cov / KL-Cov entropy
# controls are applied yet -- so "filtered_count" is always 0.

MAX_SEQ_LEN = 2048            # cap on prompt + completion tokens per update
TRAIN_PROBLEMS_PER_ITER = 10  # limit for demo
VAL_PROBLEMS_PER_ITER = 5     # quick eval


def _score(problem, completion):
    """Return 1.0 if ``completion`` passes ``harness.verify`` for ``problem``, else 0.0."""
    try:
        return 1.0 if harness.verify(problem, completion).passed else 0.0
    except Exception:
        # A verifier error (e.g. unparseable code) counts as a failure; catching
        # Exception rather than everything keeps Ctrl-C able to stop the run.
        return 0.0


def _nll_step(prompt, completion):
    """Take one optimizer step that raises log p(completion | prompt) under the policy.

    Returns ``(loss, mean_token_entropy)`` as Python floats, or ``None`` when no
    completion tokens remain after truncation to ``MAX_SEQ_LEN``.
    """
    # Tokenize the two parts separately so the prompt/completion boundary is exact;
    # tokenizing ``prompt + completion`` as one string can merge tokens across the
    # seam and misalign the label window.
    prompt_ids = tokenizer(prompt).input_ids
    completion_ids = tokenizer(completion, add_special_tokens=False).input_ids
    ids = (prompt_ids + completion_ids)[:MAX_SEQ_LEN]
    prompt_len = len(prompt_ids)
    if prompt_len == 0 or len(ids) <= prompt_len:
        return None
    input_ids = torch.tensor([ids], device=policy_model.device)

    # generate_samples() switches the model to eval mode; re-enable training mode so
    # LoRA dropout is active for the gradient step.
    policy_model.train()
    # Position t predicts token t + 1, so the completion is scored by the logits from
    # prompt_len - 1 up to (excluding) the last position. Upcast to fp32 for a
    # numerically stable log-softmax over the vocabulary.
    logits = policy_model(input_ids=input_ids).logits[:, prompt_len - 1:-1, :].float()
    labels = input_ids[:, prompt_len:]

    # Entropy is a diagnostic only: keep it out of the autograd graph.
    with torch.no_grad():
        entropy = compute_entropy(logits.detach()).mean().item()

    token_log_probs = F.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    loss = -token_log_probs.mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item(), entropy


start_time = datetime.now()

for iteration in range(CONFIG["num_iterations"]):
    iter_start = datetime.now()
    print(f"\n--- Iteration {iteration + 1}/{CONFIG['num_iterations']} ---")

    total_loss = 0.0
    total_entropy = 0.0
    total_filtered = 0          # IQR filter not applied yet (see note above)
    num_updates = 0
    policy_reward_sum = 0.0
    num_policy_rollouts = 0

    # Process each training problem
    for problem in train_problems[:TRAIN_PROBLEMS_PER_ITER]:
        prompt = problem.to_prompt()

        # 1. Combined rollout: policy + momentum
        policy_gens = generate_samples(policy_model, tokenizer, prompt, CONFIG["num_policy_samples"])
        momentum_gens = generate_samples(momentum_model, tokenizer, prompt, CONFIG["num_momentum_samples"])

        # 2. Compute rewards for all rollouts
        rewards = [_score(problem, gen) for gen in policy_gens + momentum_gens]
        policy_rewards = rewards[:len(policy_gens)]
        # Reward metric = pass rate over ALL policy rollouts, not just the successes.
        policy_reward_sum += sum(policy_rewards)
        num_policy_rollouts += len(policy_rewards)

        # 3. Only train on positive rewards
        for gen, rew in zip(policy_gens, policy_rewards):
            if rew <= 0:
                continue
            step = _nll_step(prompt, gen)
            if step is None:
                continue
            loss_value, entropy = step
            total_loss += loss_value
            total_entropy += entropy
            num_updates += 1

    # 4. Update momentum model
    update_momentum(policy_model, momentum_model, CONFIG["momentum"])

    # 5. Evaluate on validation set (one sample per problem)
    val_subset = val_problems[:VAL_PROBLEMS_PER_ITER]
    val_correct = sum(
        _score(problem, generate_samples(policy_model, tokenizer, problem.to_prompt(), 1)[0])
        for problem in val_subset
    )
    val_acc = val_correct / max(len(val_subset), 1)

    # 6. Log metrics
    avg_loss = total_loss / max(num_updates, 1)
    avg_reward = policy_reward_sum / max(num_policy_rollouts, 1)
    avg_entropy = total_entropy / max(num_updates, 1)

    metrics_history["iteration"].append(iteration + 1)
    metrics_history["loss"].append(avg_loss)
    metrics_history["reward"].append(avg_reward)
    metrics_history["entropy"].append(avg_entropy)
    metrics_history["filtered_count"].append(total_filtered)
    metrics_history["val_accuracy"].append(val_acc)

    iter_time = (datetime.now() - iter_start).total_seconds()
    print(f"  Loss: {avg_loss:.4f}, Reward: {avg_reward:.3f}, Entropy: {avg_entropy:.3f}")
    print(f"  Val Accuracy: {val_acc:.1%}, Time: {iter_time:.1f}s")

    # Clear memory
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

total_time = (datetime.now() - start_time).total_seconds() / 60
print(f"\nTraining complete! Total time: {total_time:.1f} minutes")

## PART 5: RESULTS VISUALIZATION

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# One 2x2 panel per tracked series: (axis, y-values, line format, y-label, title).
# Validation accuracy is stored as a fraction and plotted as a percentage.
_panel_specs = (
    (axes[0, 0], metrics_history["loss"], 'b-o', "Loss", "Training Loss"),
    (axes[0, 1], metrics_history["reward"], 'g-o', "Reward", "Average Reward"),
    (axes[1, 0], metrics_history["entropy"], 'r-o', "Entropy", "Policy Entropy (should stay stable)"),
    (axes[1, 1], [acc * 100 for acc in metrics_history["val_accuracy"]], 'm-o', "Accuracy (%)", "Validation Accuracy"),
)
for ax, y_values, line_fmt, y_label, panel_title in _panel_specs:
    ax.plot(metrics_history["iteration"], y_values, line_fmt)
    ax.set_xlabel("Iteration")
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
    ax.grid(True)

plt.tight_layout()
plt.savefig("mgrpo_results.png", dpi=150)
plt.show()

## PART 6: SAVE MODEL AND RESULTS

In [None]:
import json

# Save model and tokenizer.
# NOTE(review): policy_model is a PEFT wrapper (get_peft_model), so save_pretrained
# writes the LoRA adapter weights/config rather than a full merged checkpoint --
# confirm that downstream loaders (e.g. the benchmark runner) accept an adapter dir.
policy_model.save_pretrained("mgrpo_model")
tokenizer.save_pretrained("mgrpo_model")
print("Model saved to mgrpo_model/")

# Save metrics and config as pretty-printed JSON.
for out_path, payload, label in (
    ("mgrpo_metrics.json", metrics_history, "Metrics"),
    ("mgrpo_config.json", CONFIG, "Config"),
):
    with open(out_path, "w") as fh:
        json.dump(payload, fh, indent=2)
    print(f"{label} saved to {out_path}")

## PART 7: RUN BENCHMARKS

In [None]:
# Install benchmark dependencies.
# `datasets` is already installed by the setup cell in Part 1; re-running pip here is a no-op check.
!pip install -q datasets

In [None]:
from axiom.benchmarks import run_all_benchmarks
from pathlib import Path

banner = "=" * 60
print(banner)
print("RUNNING BENCHMARKS")
print(banner)

# Run benchmarks on the model directory written by the save cell.
# NOTE(review): that directory holds a PEFT adapter (policy_model is a PeftModel);
# confirm run_all_benchmarks loads adapters, or point it at a merged checkpoint.
reports = run_all_benchmarks(
    model_path="mgrpo_model",
    benchmark_names=["math500", "gpqa_diamond"],
    output_dir=Path("benchmark_results"),
    max_samples=50,  # Limit for quick testing
)

print("\n" + banner)
print("BENCHMARK RESULTS")
print(banner)
for bench_name, bench_report in reports.items():
    print(f"  {bench_name}: {bench_report.accuracy:.1%} ({bench_report.correct}/{bench_report.total})")

In [None]:
# Final summary of the run: configuration highlights, last recorded metrics, artifacts.
print("="*60)
print("EXPERIMENT 15 SUMMARY")
print("="*60)
print(f"\nModel: {CONFIG['model_name']}")
print(f"Iterations: {CONFIG['num_iterations']}")
print(f"Momentum: {CONFIG['momentum']}")
print(f"IQR Filter: {CONFIG['use_iqr_filter']}")

# Guard against an empty history (e.g. num_iterations == 0), where [-1] would raise.
if metrics_history["val_accuracy"]:
    print(f"\nFinal Validation Accuracy: {metrics_history['val_accuracy'][-1]:.1%}")
    print(f"Final Entropy: {metrics_history['entropy'][-1]:.3f}")
else:
    print("\nNo training iterations were recorded.")

print("\nFiles saved:")
for artifact in (
    "mgrpo_model/ (trained model)",
    "mgrpo_metrics.json (training metrics)",
    "mgrpo_config.json (configuration)",
    "mgrpo_results.png (training curves)",
    "benchmark_results/ (benchmark scores)",
):
    print(f"  - {artifact}")