# 04 - GRPO Training on Kaggle (Generative Manim)

Self-contained GRPO training that uses a Manim render check (the verifier) as the reward function.

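**Strategy**: Render with `-ql` (low quality, ~2-5 s per scene). With 8 rollouts per prompt the worst case is 8 × 5 s = 40 s of rendering per prompt,
so 1,000 prompts take ~11 hours of rendering — just inside a 12-hour Kaggle session, leaving little headroom for model loading and generation.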

**Prerequisites**: the DPO LoRA checkpoint and the prompts dataset, both attached as Kaggle inputs.

In [None]:
# NOTE(review): package versions are unpinned. GRPOTrainer's constructor and
# batch-size semantics have changed between trl releases, so consider pinning
# the trl version this notebook was tested with. torch is usually already
# present on Kaggle GPU images — reinstalling it may swap the CUDA build; confirm.
!pip install -q torch transformers trl peft bitsandbytes accelerate datasets wandb
# NOTE(review): manim may also need system libraries (cairo/pango, ffmpeg) and
# a LaTeX install for Tex/MathTex scenes — confirm on the Kaggle image. Renders
# that fail for missing system dependencies are scored 0.0 by the verifier.
!pip install -q manim

In [None]:
import ast
import json
import os
import re
import shutil
import subprocess
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor

import torch
from datasets import load_dataset
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import GRPOConfig, GRPOTrainer
import wandb

In [None]:
# ---- Configuration ----------------------------------------------------------
# Base model the DPO adapter is loaded onto, plus a short tag used in paths.
MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct"
MODEL_NAME = "qwen2.5-coder-7b"

# Kaggle inputs (read-only mounts) and this run's writable output directory.
DPO_CHECKPOINT = "/kaggle/input/gm-dpo-checkpoint/dpo-qwen2.5-coder-7b"
PROMPTS_PATH = "/kaggle/input/gm-training-data/sft_train.jsonl"
OUTPUT_DIR = os.path.join("/kaggle/working", "grpo-" + MODEL_NAME)

# GRPO hyperparameters.
GROUP_SIZE = 8        # completions sampled (and rendered) per prompt
TEMPERATURE = 0.8     # sampling temperature for those completions
LEARNING_RATE = 1e-5
KL_COEFF = 0.05       # KL penalty weight — NOTE(review): confirm it reaches GRPOConfig(beta=...)
RENDER_TIMEOUT = 60   # seconds per Manim render (shorter timeout for Kaggle); timeouts score 0

In [None]:
# W&B: log in with the Kaggle secret when available; otherwise disable logging.
try:
    # Imported inside the try so the notebook still runs where the
    # kaggle_secrets module does not exist (i.e. outside a Kaggle kernel).
    from kaggle_secrets import UserSecretsClient

    wandb.login(key=UserSecretsClient().get_secret("WANDB_API_KEY"))
    USE_WANDB = True
except Exception as err:  # missing module/secret or failed login -> run without W&B
    # Catch Exception rather than everything so Ctrl-C / SystemExit still propagate.
    print(f"W&B disabled: {err!r}")
    USE_WANDB = False

In [None]:
# === Inline Manim Verifier (no imports from training package) ===

def extract_python_code(text):
    """Return the code inside the first Markdown code fence of *text*.

    Preference order:
      1. The first fence tagged ``python``, ``python3`` or ``py`` (any case).
         If its closing fence is missing (e.g. the completion was cut off at
         the length limit), everything after the opening fence is taken.
      2. The first untagged fence (``` immediately followed by a newline).
      3. Otherwise the whole text.

    The result is stripped of surrounding whitespace.
    """
    match = re.search(
        r'```(?:python3?|py)\s*\n(.*?)(?:```|\Z)',
        text,
        re.DOTALL | re.IGNORECASE,
    )
    if match:
        return match.group(1).strip()
    match = re.search(r'```\s*\n(.*?)```', text, re.DOTALL)
    if match:
        return match.group(1).strip()
    return text.strip()

def clean_code(text):
    """Extract code from a completion and normalize it for rendering.

    - The first class whose base list mentions ``Scene`` is renamed to
      ``GenScene``. Only whole-word occurrences of the old name are replaced,
      so a short name such as ``S`` no longer corrupts ``Scene`` into
      ``GenScenecene``.
    - ``from manim import *`` is prepended when the code has no
      ``from manim import`` line.
    """
    code = extract_python_code(text)
    match = re.search(r'class\s+(\w+)\s*\(.*Scene.*\)', code)
    if match and match.group(1) != 'GenScene':
        old_name = re.escape(match.group(1))
        code = re.sub(rf'\b{old_name}\b', 'GenScene', code)
    if 'from manim import' not in code:
        code = 'from manim import *\n' + code
    return code

def _count_play_calls(code):
    """Count real ``self.play(...)`` calls in *code*, ignoring comments and strings."""
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        return 0
    return sum(
        1
        for node in ast.walk(tree)
        if isinstance(node, ast.Call)
        and isinstance(node.func, ast.Attribute)
        and node.func.attr == 'play'
        and isinstance(node.func.value, ast.Name)
        and node.func.value.id == 'self'
    )


def verify_and_reward(code, timeout=60):
    """Render a generated Manim scene and return a reward in [0.0, 1.5].

    Args:
        code: Raw model completion; fences are stripped and the scene class is
            normalized by ``clean_code``.
        timeout: Seconds allowed for the ``manim`` subprocess.

    Returns:
        0.0 if no Scene subclass is found, the render exits non-zero, times
        out, or the scene file cannot be written. Otherwise 1.0 plus 0.1 per
        ``self.play(...)`` call, with the bonus capped at +0.5.

    Raises:
        RuntimeError: if the ``manim`` executable is not found. Without this,
            every rollout would silently score 0.0 and training would get no signal.
    """
    code = clean_code(code)

    # Only files that define a Scene subclass can be rendered.
    match = re.search(r'class\s+(\w+)\s*\(.*Scene.*\)', code)
    if not match:
        return 0.0
    class_name = match.group(1)

    tmp_dir = tempfile.mkdtemp(prefix='manim_')
    try:
        scene_path = os.path.join(tmp_dir, 'scene.py')
        with open(scene_path, 'w', encoding='utf-8') as f:
            f.write(code)

        # SECURITY: this executes model-generated code in a subprocess with no
        # sandbox beyond a throwaway working directory.
        result = subprocess.run(
            ['manim', scene_path, class_name, '--format=mp4', '-ql',
             '--media_dir', tmp_dir, '--custom_folders'],
            capture_output=True, text=True, errors='replace',
            timeout=timeout, cwd=tmp_dir,
        )
    except subprocess.TimeoutExpired:
        return 0.0
    except FileNotFoundError as err:
        raise RuntimeError(
            "`manim` executable not found on PATH; install manim before training"
        ) from err
    except (OSError, UnicodeError) as err:
        print(f"verify_and_reward: could not prepare/run render: {err!r}")
        return 0.0
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)

    if result.returncode != 0:
        return 0.0
    # Success: 1.0 plus an animation bonus (0.1 per real play call, max +0.5).
    return 1.0 + min(_count_play_calls(code) * 0.1, 0.5)

In [None]:
# Load model: tokenizer, 4-bit quantized base model, then the DPO LoRA adapter
# on top (loaded as trainable so GRPO keeps updating it).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
if tokenizer.pad_token is None:
    # Padding needs a pad token; reuse EOS when the tokenizer defines none.
    tokenizer.pad_token = tokenizer.eos_token

# NF4 4-bit weights with double quantization; compute in fp16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

base_model = prepare_model_for_kbit_training(
    AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
)
model = PeftModel.from_pretrained(base_model, DPO_CHECKPOINT, is_trainable=True)
print("DPO model loaded")

In [None]:
# Load prompts
dataset = load_dataset("json", data_files=PROMPTS_PATH, split="train")

# Upper bound on prompts for one run: the header's budget (~40 s of rendering
# per prompt) fits about 1,000 prompts into a 12-hour Kaggle session.
MAX_PROMPTS = 1000


def extract_prompt(example):
    """Build one GRPO prompt string from a training record.

    Chat-format records (``messages``) contribute any system turns that precede
    the first user turn, plus that user turn. Records without a user turn fall
    back to the flat ``prompt`` field. The turns are rendered with the
    tokenizer's chat template, ending in the assistant generation header.
    trl only applies the chat template to conversational prompts and feeds
    plain strings as-is, and the instruct/DPO model expects its chat format.
    Records with no usable user text yield ``""``.
    """
    turns = []
    # `or []` also covers a messages column that is None for this record.
    for msg in example.get("messages") or []:
        if msg["role"] in ("system", "user"):
            turns.append({"role": msg["role"], "content": msg["content"] or ""})
        if msg["role"] == "user":
            break
    else:
        # No user turn found: fall back to the flat "prompt" field.
        turns.append({"role": "user", "content": example.get("prompt") or ""})

    if not turns[-1]["content"].strip():
        return {"prompt": ""}
    prompt = tokenizer.apply_chat_template(
        turns, tokenize=False, add_generation_prompt=True
    )
    return {"prompt": prompt}


dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
# Empty prompts would only waste GROUP_SIZE renders each.
dataset = dataset.filter(lambda ex: bool(ex["prompt"]))
if len(dataset) > MAX_PROMPTS:
    dataset = dataset.shuffle(seed=42).select(range(MAX_PROMPTS))
print(f"Prompts: {len(dataset)}")

In [None]:
# Reward function for GRPOTrainer
def _completion_text(completion):
    """Return the generated text of one completion.

    String prompts yield string completions. Conversational prompts (trl's
    chat format) yield a list of message dicts whose last entry holds the
    generated text.
    """
    if isinstance(completion, str):
        return completion
    return completion[-1]["content"] if completion else ""


def reward_fn(completions, **kwargs):
    """Score each completion with the Manim render check.

    Each render is an external ``manim`` subprocess, so renders run in a
    thread pool to overlap the waiting. Results keep the order of
    *completions*. Extra keyword arguments passed by the trainer (prompts,
    dataset columns) are accepted and ignored.
    """
    texts = [_completion_text(c) for c in completions]
    if not texts:
        return []
    workers = min(len(texts), os.cpu_count() or 1)
    with ThreadPoolExecutor(max_workers=workers) as pool:
        return list(
            pool.map(lambda t: verify_and_reward(t, timeout=RENDER_TIMEOUT), texts)
        )

In [None]:
# GRPO config + training.
# The DPO LoRA adapter loaded earlier is already trainable (is_trainable=True),
# so GRPO continues training that adapter. No new LoraConfig is passed to the
# trainer: a peft_config on top of an existing PeftModel would wrap it in a
# second adapter.
# NOTE(review): for PEFT models trl typically uses the adapter-disabled base
# model as the KL reference (i.e. the pre-DPO model) — confirm that is intended.
grpo_config = GRPOConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=1,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    # NOTE(review): how per_device_train_batch_size interacts with
    # num_generations differs between trl releases (some require the global
    # batch to be divisible by num_generations) — confirm for the installed version.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    fp16=True,
    logging_steps=5,
    save_steps=50,
    save_total_limit=2,
    optim="paged_adamw_32bit",
    report_to="wandb" if USE_WANDB else "none",
    seed=42,
    num_generations=GROUP_SIZE,
    max_completion_length=2048,
    temperature=TEMPERATURE,
    beta=KL_COEFF,  # KL penalty coefficient (was configured but never passed)
)

trainer = GRPOTrainer(
    model=model,
    args=grpo_config,
    train_dataset=dataset,
    reward_funcs=reward_fn,
    processing_class=tokenizer,  # GRPOTrainer takes processing_class, not tokenizer
)

print("Starting GRPO training (with live Manim rendering)...")
trainer.train()

trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"GRPO model saved to {OUTPUT_DIR}")

if USE_WANDB:
    wandb.finish()