# 05 - Evaluation: Cross-Model x Cross-Stage Comparison

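Evaluate all 9 checkpoints (3 models × 3 training stages: SFT, DPO, GRPO) on the first 200 held-out test prompts. Each generated script is rendered with Manim, and a sample counts as a success when the render exits cleanly. Checkpoints that are not found on disk are skipped and shown as `-` in the final comparison table.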

In [None]:
# Install the generation stack (transformers + peft adapters + bitsandbytes 4-bit loading)
# and the manim CLI that verify_code uses to render the generated scenes.
# NOTE(review): versions are unpinned, so results may drift between runs; pin them for reproducibility.
!pip install -q torch transformers peft bitsandbytes accelerate manim rich

In [None]:
import gc
import json
import os
import re
import shutil
import subprocess
import tempfile
from collections import Counter

import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
from tqdm.notebook import tqdm

In [None]:
# Inline helpers (same as GRPO notebook)
def extract_python_code(text):
    """Return the body of the first fenced code block in ``text``.

    A ```python fence is preferred. Otherwise the first bare ``` fence is used.
    If no fence is found, the stripped input is returned unchanged.
    """
    match = re.search(r'```python\s*\n(.*?)```', text, re.DOTALL)
    if match:
        return match.group(1).strip()
    match = re.search(r'```\s*\n(.*?)```', text, re.DOTALL)
    if match:
        return match.group(1).strip()
    return text.strip()

def clean_code(text):
    """Normalise a model response into a renderable Manim script.

    Steps:
      - extract the code from markdown fences (``extract_python_code``);
      - rename the first ``...Scene...`` subclass to ``GenScene``;
      - prepend ``from manim import *`` when no manim import is present.

    The rename is whole-word (``\\b`` boundaries). A plain ``str.replace``
    would also rewrite every substring match: for ``class C(Scene)`` it would
    turn ``Create(Circle())`` into ``GenScenereate(GenSceneircle())``.
    """
    code = extract_python_code(text)
    match = re.search(r'class\s+(\w+)\s*\(.*Scene.*\)', code)
    if match and match.group(1) != 'GenScene':
        old_name = re.escape(match.group(1))
        code = re.sub(rf'\b{old_name}\b', 'GenScene', code)
    if 'from manim import' not in code:
        code = 'from manim import *\n' + code
    return code

def verify_code(code, timeout=60):
    """Render a model response with the Manim CLI and classify the outcome.

    The response is normalised with ``clean_code``, written to a throw-away
    directory and rendered at low quality (``-ql``). The directory is always
    removed afterwards.

    Args:
        code: Raw model output; may contain markdown code fences.
        timeout: Seconds before the render subprocess is abandoned.

    Returns:
        ``(success, error_type, animation_count)``.
        - ``error_type`` is one of ``'none'``, ``'class_not_found'``,
          ``'syntax_error'``, ``'import_error'``, ``'runtime_error'``,
          ``'timeout'`` or ``'unknown'``.
        - ``animation_count`` is 0 unless the render succeeded.
    """
    code = clean_code(code)
    match = re.search(r'class\s+(\w+)\s*\(.*Scene.*\)', code)
    if not match:
        return False, 'class_not_found', 0

    tmp_dir = tempfile.mkdtemp(prefix='eval_')
    try:
        path = os.path.join(tmp_dir, 'scene.py')
        with open(path, 'w', encoding='utf-8') as f:
            f.write(code)
        result = subprocess.run(
            ['manim', path, match.group(1), '--format=mp4', '-ql',
             '--media_dir', tmp_dir, '--custom_folders'],
            capture_output=True, text=True, timeout=timeout, cwd=tmp_dir
        )
        # Manim reports progress ("Animation <n> : ...") and some errors
        # ("... is not in the script") through its logger. That logger may write
        # to stdout rather than stderr, so both streams are scanned.
        # NOTE(review): confirm the stream routing for the installed manim version.
        output = f"{result.stdout or ''}\n{result.stderr or ''}"

        if result.returncode == 0:
            # Indices are 0-based; accept both "Animation 3 :" and "Animation 3:".
            anims = re.findall(r'Animation\s+(\d+)\s*:', output)
            anim_count = max(int(a) for a in anims) + 1 if anims else 0
            return True, 'none', anim_count

        if 'SyntaxError' in output:
            return False, 'syntax_error', 0
        # ModuleNotFoundError subclasses ImportError but prints its own name.
        if 'ImportError' in output or 'ModuleNotFoundError' in output:
            return False, 'import_error', 0
        if 'is not in the script' in output:
            return False, 'class_not_found', 0
        if 'Traceback' in output:
            return False, 'runtime_error', 0
        return False, 'unknown', 0
    except subprocess.TimeoutExpired:
        return False, 'timeout', 0
    except Exception:
        # Best effort: any other failure (e.g. the manim binary is missing) is
        # counted as an unknown error rather than aborting the whole sweep.
        return False, 'unknown', 0
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)

In [None]:
# Load test prompts: one JSON object per line, each with a 'prompt' field.
TEST_PATH = "/kaggle/input/gm-training-data/test_prompts.jsonl"
with open(TEST_PATH, encoding='utf-8') as f:
    # Skip blank lines (e.g. a trailing newline) that json.loads would reject.
    prompts = [json.loads(line)['prompt'] for line in f if line.strip()]
print(f'Test prompts: {len(prompts)}')

In [None]:
SYSTEM_PROMPT = (
    "Write Manim scripts for animations in Python. "
    "Generate code, not text. Always use GenScene as the class name."
)

def evaluate_checkpoint(model_id, checkpoint_path, prompts, max_prompts=200, seed=0):
    """Generate and render one sample per prompt for a single PEFT checkpoint.

    Loads ``model_id`` in 4-bit NF4 and attaches the adapter at
    ``checkpoint_path``. It then samples one response per prompt (temperature
    0.2) for the first ``max_prompts`` prompts and scores each response with
    ``verify_code``.

    Args:
        model_id: Hugging Face id of the base model (also used for the tokenizer).
        checkpoint_path: Directory of the PEFT adapter to evaluate.
        prompts: List of user prompt strings.
        max_prompts: Evaluate at most this many prompts.
        seed: Torch RNG seed set once before sampling so reruns are
            comparable. ``None`` leaves the RNG untouched.

    Returns:
        Dict with these keys:
        - ``'rate'``: success fraction, 0.0 when nothing was evaluated;
        - ``'successes'``;
        - ``'total'``;
        - ``'errors'``: a Counter of error types over the failed samples.
    """
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True,
    )
    base = model = None
    results = []
    try:
        base = AutoModelForCausalLM.from_pretrained(
            model_id, quantization_config=bnb_config, device_map="auto", trust_remote_code=True
        )
        model = PeftModel.from_pretrained(base, checkpoint_path)
        model.eval()
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Seed right before sampling so every checkpoint starts from the same RNG state.
        if seed is not None:
            torch.manual_seed(seed)

        for prompt in tqdm(prompts[:max_prompts]):
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ]
            text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            with torch.no_grad():
                output = model.generate(
                    **inputs, max_new_tokens=2048, temperature=0.2,
                    do_sample=True, pad_token_id=tokenizer.pad_token_id
                )
            # Decode only the newly generated tokens, not the echoed prompt.
            response = tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
            success, error_type, anims = verify_code(response)
            results.append({'success': success, 'error': error_type, 'animations': anims})
    finally:
        # Free the quantized weights even if loading or generation raised, so the
        # next checkpoint in the sweep does not inherit the GPU memory.
        # gc.collect() breaks reference cycles that would otherwise keep them alive.
        del model, base
        gc.collect()
        torch.cuda.empty_cache()

    successes = sum(1 for r in results if r['success'])
    return {
        'rate': successes / len(results) if results else 0.0,
        'successes': successes,
        'total': len(results),
        'errors': Counter(r['error'] for r in results if not r['success']),
    }

In [None]:
# Checkpoints to evaluate: (model name, training stage) -> (adapter dir, base HF model id).
# Adjust the adapter directories to match your Kaggle datasets.
CHECKPOINTS = {
    ('qwen2.5-coder-7b', 'sft'): (
        '/kaggle/input/gm-sft/sft-qwen2.5-coder-7b',
        'Qwen/Qwen2.5-Coder-7B-Instruct',
    ),
    ('qwen2.5-coder-7b', 'dpo'): (
        '/kaggle/input/gm-dpo/dpo-qwen2.5-coder-7b',
        'Qwen/Qwen2.5-Coder-7B-Instruct',
    ),
    ('qwen2.5-coder-7b', 'grpo'): (
        '/kaggle/input/gm-grpo/grpo-qwen2.5-coder-7b',
        'Qwen/Qwen2.5-Coder-7B-Instruct',
    ),
    # Add other models as they become available
}

all_results = {}
for (model_name, stage), (ckpt, model_id) in CHECKPOINTS.items():
    # A missing adapter directory is reported and left out of all_results.
    if not os.path.exists(ckpt):
        print(f'Skipping {model_name}/{stage} (checkpoint not found)')
        continue
    print(f'\nEvaluating {model_name}/{stage}...')
    all_results[(model_name, stage)] = evaluate_checkpoint(model_id, ckpt, prompts)

In [None]:
# Comparison table: one row per model, one column per training stage.
# Each cell reads "success rate (successes/total)". '-' marks a checkpoint
# that was not evaluated (e.g. skipped because its directory was missing).
# Planned models and stages come first in a fixed order. Anything else that
# was evaluated via CHECKPOINTS is appended, so newly added models still show up.
models = list(dict.fromkeys(
    ['qwen2.5-coder-7b', 'deepseek-coder-v2-lite', 'codellama-7b']
    + [model_name for model_name, _ in all_results]
))
stages = list(dict.fromkeys(
    ['sft', 'dpo', 'grpo']
    + [stage_name for _, stage_name in all_results]
))

data = []
for model in models:
    row = {'Model': model}
    for stage in stages:
        r = all_results.get((model, stage))
        if r is None:
            row[stage.upper()] = '-'
        else:
            row[stage.upper()] = f"{r['rate']:.0%} ({r['successes']}/{r['total']})"
    data.append(row)

df = pd.DataFrame(data)
print(df.to_string(index=False))