# Evaluate Reasoning Specialist (500M)

This notebook evaluates the **Reasoning Specialist** model for its role in the Svend ensemble:

**Role in Ensemble:**
- Step-by-step reasoning chains
- Mathematical problem solving
- Tool calling (SymPy, Z3, Python sandbox)
- Long context reasoning (8K tokens)
- Outputs are verified by the Verifier model

**What We Test:**
1. Math reasoning (GSM8K-style)
2. Multi-step logic
3. Tool call formatting
4. Chain-of-thought quality
5. Answer extraction reliability

## 1. Setup

In [None]:
# Mount Google Drive so the checkpoint under /content/drive can be read
# (and the evaluation results written back next to it).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone repo
# Remove any previous checkout first so the clone always starts clean, then
# switch into the repo and print the checked-out commit so the evaluated
# code version is recorded in the notebook output.
!rm -rf /content/svend
!git clone https://github.com/ewolters/svend.git /content/svend
%cd /content/svend
!git log -1 --oneline

In [None]:
# Install the runtime dependencies quietly (-q).
# NOTE(review): `datasets` is not imported by any cell visible in this
# notebook — confirm it is still required.
!pip install -q torch transformers datasets

In [None]:
import sys

# Drop any previously imported copy of the project's `src` package so that
# re-running this cell after a fresh clone picks up the new code instead of
# the cached modules.
# Match only the package itself and its submodules: a bare
# startswith('src') would also evict unrelated modules whose names merely
# begin with those three letters.
modules_to_remove = [
    key for key in sys.modules
    if key == 'src' or key.startswith('src.')
]
for mod in modules_to_remove:
    del sys.modules[mod]

# Make the cloned repository importable as `src.*` (inserted first so it
# wins over any other `src` package on the path).
if '/content/svend' not in sys.path:
    sys.path.insert(0, '/content/svend')

import torch
import json
import re
from transformers import AutoTokenizer
from src.models.config import get_config
from src.models.transformer import ReasoningTransformer

# Report the runtime environment. The cells below call model.cuda() and move
# inputs to "cuda", so CUDA must be reported as available here.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")

## 2. Load Model

In [None]:
# Path to your trained checkpoint (read from the mounted Drive).
# The file is expected to hold a dict with a 'model_state_dict' entry,
# which is what the loading cell below reads.
CHECKPOINT_PATH = "/content/drive/MyDrive/svend-checkpoints/base-reasoner/final.pt"
# Or use a step checkpoint:
# CHECKPOINT_PATH = "/content/drive/MyDrive/svend-checkpoints/base-reasoner/step_001000.pt"

# Size key passed to get_config() to select the model configuration; it is
# also recorded in the saved evaluation results.
MODEL_SIZE = "500m"

In [None]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Reuse the EOS token for padding (presumably because the GPT-2 tokenizer
# does not define a pad token of its own).
tokenizer.pad_token = tokenizer.eos_token

# Load model
model_config = get_config(MODEL_SIZE)
# Override the configured vocabulary size with the tokenizer's.
# NOTE(review): this has to match the vocabulary size used at training time,
# otherwise load_state_dict below will presumably fail on a shape mismatch —
# confirm against the training script.
model_config.vocab_size = tokenizer.vocab_size

model = ReasoningTransformer(model_config)

# Load weights
# Deserialize onto the CPU first; the model is moved to the GPU afterwards.
# NOTE(review): recent PyTorch releases changed torch.load's default to
# weights_only=True — if the checkpoint stores non-tensor objects this call
# may need an explicit weights_only argument; confirm for the Colab version.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model = model.cuda()
# Put the model in evaluation mode before generating.
model.eval()

# Total parameter count, reported in millions as a sanity check on the
# selected model size.
params = sum(p.numel() for p in model.parameters())
print(f"Loaded model: {params/1e6:.1f}M parameters")

In [None]:
# Generation helper
@torch.no_grad()
def generate(prompt, max_tokens=256, temperature=0.7):
    """Generate a completion for ``prompt`` and return the decoded text.

    Args:
        prompt: Text to condition the model on.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature. Sampling is enabled only when
            this is greater than 0.

    Returns:
        The first generated sequence decoded with special tokens skipped.
        The evaluation loop slices the prompt off the front of this string,
        so the decoded text is expected to begin with the prompt.
    """
    # Inputs are moved to the GPU, where the loading cell placed the model.
    encoded = tokenizer(prompt, return_tensors="pt").to("cuda")
    should_sample = temperature > 0
    sequences = model.generate(
        encoded["input_ids"],
        max_new_tokens=max_tokens,
        temperature=temperature,
        do_sample=should_sample,
        top_p=0.9,
    )
    first_sequence = sequences[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

## 3. Evaluation Benchmarks

Testing capabilities critical for the Reasoning Specialist role.

In [None]:
# Evaluation problems designed for reasoning specialist

# Benchmark problems grouped by capability. Every entry has a question "q".
# How an entry is graded depends on its other keys (see the evaluation loop):
#   "a"                  -> expected answer, compared via check_answer()
#   "check": "tool_call" -> graded by check_tool_format() against the terms
#                           listed in "expected_contains"
#   "check_cot": True    -> reasoning steps are additionally counted with
#                           count_reasoning_steps()
EVAL_PROBLEMS = {
    # Single-step arithmetic / algebra with one numeric answer.
    "math_basic": [
        {"q": "What is 15% of 200?", "a": "30"},
        {"q": "If a train travels at 60 mph for 2.5 hours, how far does it go?", "a": "150"},
        {"q": "What is 7 * 8 + 12 / 4?", "a": "59"},
        {"q": "A shirt costs $25. With a 20% discount, what is the final price?", "a": "20"},
        {"q": "If 3x + 7 = 22, what is x?", "a": "5"},
    ],
    
    # Word problems whose intuitive answer is often wrong.
    "math_multistep": [
        {
            "q": "A snail climbs 3 meters up a wall during the day but slides down 2 meters at night. If the wall is 10 meters tall, how many days does it take to reach the top?",
            "a": "8"
        },
        {
            "q": "If 5 machines can produce 5 widgets in 5 minutes, how many minutes would it take 100 machines to produce 100 widgets?",
            "a": "5"
        },
        {
            "q": "A store has a 'buy 2 get 1 free' deal on $10 items. How much do you pay for 7 items?",
            "a": "50"
        },
        {
            "q": "John is twice as old as Mary. In 5 years, the sum of their ages will be 40. How old is Mary now?",
            "a": "10"
        },
    ],
    
    # Short deductive-logic questions with a textual answer.
    "logic": [
        {
            "q": "If all cats are mammals, and some mammals are pets, can we conclude that some cats are pets? Answer yes or no.",
            "a": "no"
        },
        {
            "q": "Alice is taller than Bob. Bob is taller than Charlie. Is Alice taller than Charlie? Answer yes or no.",
            "a": "yes"
        },
        {
            "q": "If it rains, the ground is wet. The ground is wet. Did it rain? Answer yes, no, or cannot determine.",
            "a": "cannot determine"
        },
    ],
    
    # Tool-call formatting: every term in "expected_contains" must appear in
    # the response (case-insensitive).
    "tool_format": [
        {
            "q": "Use the calculator tool to compute the derivative of x^3 + 2x. Format: <tool_call>calculator: [expression]</tool_call>",
            "check": "tool_call",
            # NOTE(review): "3*x" looks like part of the *result* rather than
            # of the tool call's input expression — confirm these terms.
            "expected_contains": ["tool_call", "derivative", "3*x"]
        },
        {
            "q": "Use the python tool to calculate 17 factorial. Format: <tool_call>python: [code]</tool_call>",
            "check": "tool_call",
            "expected_contains": ["tool_call", "factorial", "17"]
        },
    ],
    
    # Chain-of-thought prompts; the second entry has no "a", so only its
    # reasoning is assessed.
    "chain_of_thought": [
        {
            "q": "Think step by step: A farmer has 17 sheep. All but 9 die. How many are left?",
            "a": "9",
            "check_cot": True  # Check that reasoning appears before answer
        },
        {
            "q": "Think step by step: I have a 3-gallon jug and a 5-gallon jug. How can I measure exactly 4 gallons?",
            "check_cot": True,
            "expected_steps": 3  # Should have multiple reasoning steps
        },
    ]
}

# Report how many problems each category contributes.
print(f"Total problems: {sum(len(v) for v in EVAL_PROBLEMS.values())}")
for cat, probs in EVAL_PROBLEMS.items():
    print(f"  {cat}: {len(probs)}")

## 4. Run Evaluation

In [None]:
# A number with optional sign, optional thousands separators ("1,000") and
# an optional decimal part.
_NUMBER = r"-?\d+(?:,\d{3})*(?:\.\d+)?"

# Explicit "final answer" markers, tried in priority order. An optional
# currency sign may sit between the marker and the number ("answer is $20").
_ANSWER_PATTERNS = [
    re.compile(
        r"(?:answer|result)\s*(?:is|=|:)?\s*\$?\s*(" + _NUMBER + r")",
        re.IGNORECASE | re.MULTILINE,
    ),
    re.compile(r"=\s*\$?\s*(" + _NUMBER + r")\s*$", re.IGNORECASE | re.MULTILINE),
    re.compile(r"####\s*\$?\s*(" + _NUMBER + r")", re.IGNORECASE | re.MULTILINE),
]


def extract_number(text):
    """Extract the model's numeric answer from free-form text.

    Explicit answer markers are preferred, in this order: "answer/result
    is X", a line ending in "= X", and the GSM8K-style "#### X". The first
    occurrence of the first pattern that matches wins. If no marker is
    present, the last number in the text is used.

    Args:
        text: Response text to search.

    Returns:
        The number as a string with thousands separators removed
        (e.g. "1,250.50" -> "1250.50"), or None if the text has no number.
    """
    for pattern in _ANSWER_PATTERNS:
        match = pattern.search(text)
        if match:
            return match.group(1).replace(",", "")

    # Fall back to the last number mentioned anywhere in the text.
    numbers = re.findall(_NUMBER, text)
    return numbers[-1].replace(",", "") if numbers else None


def check_answer(response, expected):
    """Check whether ``response`` contains the expected answer.

    The answer is accepted when it appears as a standalone token (case
    insensitive), or when the number extracted by ``extract_number`` equals
    it within 0.01. Token matching avoids the false positives of plain
    substring search, e.g. "5" inside "15" or "no" inside "cannot".

    Args:
        response: Model output to grade.
        expected: Expected answer as text ("yes", "cannot determine") or as
            a numeric string ("150").

    Returns:
        True if the response is judged correct, False otherwise. An empty
        ``expected`` never matches.
    """
    expected_norm = expected.strip().lower()
    if not expected_norm:
        return False

    # Whole-token containment: not glued to a word character, and not a
    # fragment of a longer number such as "1.5", "30.0" or "1,150".
    token_pattern = (
        r"(?<!\w)(?<!\d[.,])"
        + re.escape(expected_norm)
        + r"(?!\w)(?![.,]\d)"
    )
    if re.search(token_pattern, response.lower()):
        return True

    # Numeric comparison, so "30.0" or "1,000" still match "30" / "1000".
    try:
        target = float(expected_norm.replace(",", ""))
    except ValueError:
        # Non-numeric expected answer: token containment was the only test.
        return False

    extracted = extract_number(response)
    if extracted is None:
        return False
    return abs(float(extracted) - target) < 0.01

def check_tool_format(response, expected_contains):
    """Return True when every expected term occurs in the response.

    The comparison is a case-insensitive substring test; an empty list of
    terms is trivially satisfied.
    """
    haystack = response.lower()
    for term in expected_contains:
        if term.lower() not in haystack:
            return False
    return True

def count_reasoning_steps(response):
    """Count reasoning-step markers in ``response``.

    Four kinds of marker are counted and summed (a single line can
    contribute to more than one kind):
      * "Step N" mentions,
      * numbered list items at the start of a line ("1." or "1)"),
      * bullet items at the start of a line ("- " or "* "),
      * sequencing words (first, second, third, then, next, finally).

    Markers are matched as whole tokens, so "seconds" or "strengthen" do not
    count as "second" / "then", and a decimal such as "3.5" at the start of
    a line is not mistaken for list item "3.".

    Args:
        response: Model output to analyse.

    Returns:
        Total number of markers found (0 for empty text).
    """
    step_patterns = [
        r"\bstep\s*\d",
        # List number followed by whitespace or end of line; leading
        # indentation is allowed.
        r"^[ \t]*\d+[.)](?=\s|$)",
        r"^[ \t]*[-*]\s",
        r"\b(?:first|second|third|then|next|finally)\b",
    ]
    flags = re.IGNORECASE | re.MULTILINE
    return sum(len(re.findall(pattern, response, flags)) for pattern in step_patterns)

In [None]:
# Run evaluation
# Per-category list of result dicts. Each result always has "question" and
# "response"; "correct"/"expected" are present only for graded problems and
# "reasoning_steps"/"has_reasoning" only for chain-of-thought problems.
results = {}

for category, problems in EVAL_PROBLEMS.items():
    print(f"\n{'='*50}")
    print(f"Category: {category}")
    print('='*50)

    category_results = []

    for i, prob in enumerate(problems):
        prompt = f"Question: {prob['q']}\n\nAnswer:"
        # temperature > 0 means generate() samples, so results can vary
        # from run to run.
        response = generate(prompt, max_tokens=300, temperature=0.3)

        # Extract just the generated part
        generated = response[len(prompt):].strip()

        # Evaluate based on problem type
        result = {"question": prob['q'], "response": generated}

        if prob.get('check') == 'tool_call':
            result['correct'] = check_tool_format(generated, prob['expected_contains'])
        elif 'a' in prob:
            result['correct'] = check_answer(generated, prob['a'])
            result['expected'] = prob['a']

        if prob.get('check_cot'):
            result['reasoning_steps'] = count_reasoning_steps(generated)
            # Honour a per-problem step requirement when the problem
            # declares one; otherwise require at least two steps.
            min_steps = prob.get('expected_steps', 2)
            result['has_reasoning'] = result['reasoning_steps'] >= min_steps

        category_results.append(result)

        # Print result. Problems without a gradable answer fall back to
        # whether reasoning was detected.
        status = "PASS" if result.get('correct', result.get('has_reasoning', False)) else "FAIL"
        print(f"\n[{status}] Q{i+1}: {prob['q'][:60]}...")
        print(f"  Response: {generated[:100]}..." if len(generated) > 100 else f"  Response: {generated}")

    results[category] = category_results

    # Category summary. Only graded results carry 'correct'; a category may
    # mix graded and ungraded problems, so filter instead of inspecting just
    # the first result (which raised KeyError on mixed categories).
    graded = [r for r in category_results if 'correct' in r]
    if graded:
        acc = sum(r['correct'] for r in graded) / len(graded)
        print(f"\n{category} Accuracy: {acc:.1%}")

## 5. Summary & Analysis

In [None]:
def _accuracy(rows):
    """Return the fraction of graded rows marked correct, or None if no row is graded."""
    graded_rows = [r for r in rows if 'correct' in r]
    if not graded_rows:
        return None
    return sum(r['correct'] for r in graded_rows) / len(graded_rows)


print("\n" + "="*60)
print("EVALUATION SUMMARY - Reasoning Specialist (500M)")
print("="*60)

total_correct = 0
total_problems = 0

for category, cat_results in results.items():
    # A category can mix graded problems (with 'correct') and reasoning-only
    # problems (with just 'has_reasoning'), so each kind is summarised from
    # the rows that actually carry that key.
    graded = [r for r in cat_results if 'correct' in r]
    if graded:
        correct = sum(r['correct'] for r in graded)
        total = len(graded)
        total_correct += correct
        total_problems += total
        print(f"{category:20s}: {correct}/{total} ({correct/total:.1%})")

    with_cot = [r for r in cat_results if 'has_reasoning' in r]
    if with_cot:
        has_cot = sum(r['has_reasoning'] for r in with_cot)
        avg_steps = sum(r.get('reasoning_steps', 0) for r in with_cot) / len(with_cot)
        print(f"{category:20s}: {has_cot}/{len(with_cot)} with CoT, avg {avg_steps:.1f} steps")

if total_problems > 0:
    print(f"\n{'Overall':20s}: {total_correct}/{total_problems} ({total_correct/total_problems:.1%})")

print("\n" + "="*60)
print("ROLE FITNESS ASSESSMENT")
print("="*60)

# Assess fitness for ensemble role. Each entry is (skill, status, score);
# the recommendations cell keys off the skill name and status strings.
assessments = []

# Math capability
math_results = results.get('math_basic', []) + results.get('math_multistep', [])
math_acc = _accuracy(math_results)
if math_acc is not None:
    if math_acc >= 0.7:
        assessments.append(("Math Reasoning", "GOOD", f"{math_acc:.1%}"))
    elif math_acc >= 0.4:
        assessments.append(("Math Reasoning", "NEEDS WORK", f"{math_acc:.1%}"))
    else:
        assessments.append(("Math Reasoning", "POOR", f"{math_acc:.1%}"))

# Logic capability
logic_results = results.get('logic', [])
logic_acc = _accuracy(logic_results)
if logic_acc is not None:
    if logic_acc >= 0.6:
        assessments.append(("Logic", "GOOD", f"{logic_acc:.1%}"))
    else:
        assessments.append(("Logic", "NEEDS WORK", f"{logic_acc:.1%}"))

# Tool format capability
tool_results = results.get('tool_format', [])
tool_acc = _accuracy(tool_results)
if tool_acc is not None:
    if tool_acc >= 0.5:
        assessments.append(("Tool Formatting", "GOOD", f"{tool_acc:.1%}"))
    else:
        assessments.append(("Tool Formatting", "NEEDS TRAINING", f"{tool_acc:.1%}"))

# Chain of thought: share of problems in which reasoning was detected.
cot_results = results.get('chain_of_thought', [])
if cot_results:
    has_cot = sum(r.get('has_reasoning', False) for r in cot_results) / len(cot_results)
    if has_cot >= 0.7:
        assessments.append(("Chain-of-Thought", "GOOD", f"{has_cot:.1%}"))
    else:
        assessments.append(("Chain-of-Thought", "NEEDS WORK", f"{has_cot:.1%}"))

# Status column is 14 wide so the longest label ("NEEDS TRAINING") aligns.
for skill, status, score in assessments:
    print(f"{skill:20s}: {status:14s} ({score})")

print("\n" + "="*60)

## 6. Recommendations

In [None]:
print("RECOMMENDATIONS FOR TUNING")
print("="*60)

# Advice for skills rated "NEEDS WORK" or "POOR". Entries are checked in
# order and the first keyword found in the skill name wins.
_WEAK_SKILL_ADVICE = (
    ("Math", (
        "- Add more GSM8K/MATH training data",
        "- Include step-by-step solutions in training",
    )),
    ("Logic", (
        "- Add logic puzzle datasets (e.g., LogiQA)",
        "- Train on syllogistic reasoning examples",
    )),
    ("Tool", (
        "- Add synthetic tool-calling examples",
        "- Fine-tune on <tool_call> format specifically",
    )),
    ("Chain", (
        "- Use 'Let's think step by step' prompting in training",
        "- Train on CoT-annotated datasets",
    )),
)

# Advice for skills rated "NEEDS TRAINING".
_UNTRAINED_SKILL_ADVICE = (
    ("Tool", (
        "- CRITICAL: Model needs tool-call format training",
        "- Generate synthetic tool-calling dataset",
    )),
)

recommendations = []

# Based on assessments
for skill, status, score in assessments:
    if status in ("NEEDS WORK", "POOR"):
        advice_table = _WEAK_SKILL_ADVICE
    elif status == "NEEDS TRAINING":
        advice_table = _UNTRAINED_SKILL_ADVICE
    else:
        # Any other status produces no recommendation.
        continue

    for keyword, advice_lines in advice_table:
        if keyword in skill:
            recommendations.extend(advice_lines)
            break

if recommendations:
    for rec in recommendations:
        print(rec)
else:
    print("Model looks ready for ensemble integration!")
    print("\nNext steps:")
    print("1. Train the Verifier model to validate this model's outputs")
    print("2. Train the Router to direct queries here")
    print("3. Integration testing with full ensemble")

print("\n" + "="*60)

## 7. Save Results

In [None]:
# Save evaluation results
import json
from datetime import datetime

# Destination for the JSON report on the mounted Drive.
output_path = "/content/drive/MyDrive/svend-checkpoints/base-reasoner/eval_results.json"

# Everything recorded about this run: when it ran, which model size and
# checkpoint were evaluated, the per-problem results and the role assessments.
eval_output = dict(
    timestamp=datetime.now().isoformat(),
    model=MODEL_SIZE,
    checkpoint=CHECKPOINT_PATH,
    results=results,
    assessments=assessments,
)

# default=str stringifies any value the JSON encoder cannot handle natively.
serialized = json.dumps(eval_output, indent=2, default=str)
with open(output_path, 'w') as f:
    f.write(serialized)

print(f"Results saved to: {output_path}")

## 8. Interactive Testing

In [None]:
# Test custom prompts
# Uses the same "Question: ...\n\nAnswer:" template as the evaluation loop.
# This is the bat-and-ball problem: the intuitive answer is $0.10, the
# correct one is $0.05.
test_prompt = "Question: A bat and a ball cost $1.10. The bat costs $1.00 more than the ball. How much does the ball cost?\n\nAnswer:"

# generate() returns the decoded sequence, so the printed text is expected
# to include the prompt followed by the completion.
response = generate(test_prompt, max_tokens=200, temperature=0.3)
print(response)

In [None]:
# Test tool calling format
# Unlike the benchmark prompts, this one lists the available tools before
# the question and leaves the tool call optional ("if needed").
tool_prompt = """You have access to these tools:
- calculator: For math computations
- python: For running code

Question: What is the integral of x^2 dx?

Use <tool_call>tool_name: arguments</tool_call> format if needed.

Answer:"""

# Inspect the raw output by eye; nothing is graded automatically here.
response = generate(tool_prompt, max_tokens=200, temperature=0.3)
print(response)