# Axiom-RL V2: Self-Improving Reasoning - Step by Step

This notebook walks through the **V2 Expert Iteration** approach step by step.

## What You'll Learn

1. **Why V2?** - The critical flaw in V1 and how V2 fixes it
2. **Problem Design** - How problems with multiple test cases prevent memorization
3. **Problem Generation** - Creating algorithmic problems programmatically
4. **Model Evaluation** - Testing if a model can solve problems
5. **Solution Verification** - Checking solutions against ALL test cases
6. **Training Loop** - Fine-tuning on verified solutions
7. **Self-Improvement** - The complete iteration cycle

Run each cell in order and observe the outputs carefully.

---
# Part 1: Setup
---

## 1.1 Install Dependencies

First, install the required packages. This takes a few minutes.

In [None]:
# Install dependencies (run once)
!pip install -q torch transformers accelerate peft bitsandbytes datasets
print("Dependencies installed!")

## 1.2 Check GPU Availability

We need a GPU for model inference and training.

In [None]:
import torch

print("="*50)
print("GPU CHECK")
print("="*50)

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU Available: {gpu_name}")
    print(f"GPU Memory: {gpu_memory:.1f} GB")
    DEVICE = "cuda:0"
else:
    print("WARNING: No GPU found! Training will be very slow.")
    DEVICE = "cpu"

print(f"\nUsing device: {DEVICE}")

## 1.3 Set Configuration (REQUIRED)

**IMPORTANT: You must run this cell before any other cells that use CONFIG.**

Configure the experiment parameters. Adjust these based on your GPU memory.

In [None]:
# =============================================================
# EXPERIMENT CONFIGURATION - MUST RUN THIS CELL FIRST!
# =============================================================
# This cell defines CONFIG which is used throughout the notebook.
# If you get "CONFIG is not defined" errors, run this cell first.

CONFIG = {
    # Model
    "model_name": "Qwen/Qwen2.5-Coder-1.5B-Instruct",
    
    # Problems
    "problem_types": ["rpn", "parentheses"],
    "train_per_type": 5,      # Training problems per type
    "val_per_type": 3,        # Validation problems per type  
    "test_per_type": 3,       # Test problems per type
    "test_cases": 5,          # Test cases per problem (KEY!)
    
    # Training
    "num_iterations": 2,      # Self-improvement iterations
    "learning_rate": 5e-5,
    "lora_r": 16,
    "lora_alpha": 32,
    
    # Random seed for reproducibility
    "seed": 42,
}

print("="*50)
print("EXPERIMENT CONFIGURATION")
print("="*50)
print("\n[!] CONFIG is now defined and ready to use.\n")
for key, value in CONFIG.items():
    print(f"  {key:20} = {value}")
print("\n" + "="*50)

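To get a feel for how these settings scale, the sketch below works out how many problems each split contains and how many test-case executions one full evaluation pass implies. `split_sizes` is a hypothetical helper, not part of the notebook, and the dictionary copies only the relevant keys from `CONFIG` so the snippet runs on its own.

```python
# Hypothetical helper for illustration; mirrors the relevant CONFIG keys above.
config = {
    "problem_types": ["rpn", "parentheses"],
    "train_per_type": 5,
    "val_per_type": 3,
    "test_per_type": 3,
    "test_cases": 5,
}

def split_sizes(cfg: dict) -> dict:
    """Return problems per split and total test-case executions per full pass."""
    n_types = len(cfg["problem_types"])
    sizes = {
        split: n_types * cfg[f"{split}_per_type"]
        for split in ("train", "val", "test")
    }
    # One execution per test case per problem, summed over all three splits.
    sizes["test_case_runs"] = sum(sizes.values()) * cfg["test_cases"]
    return sizes

sizes = split_sizes(config)
print(sizes)
```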
---
# Part 2: Understanding V2 Problem Design
---

## Why V1 Failed

V1 problems had a **critical flaw**: only ONE test case per problem.

```python
# V1 Problem: "Compute 3 + 4 * 2"
def solve():
    return 11  # Just return the answer - no algorithm needed!
```

The model could pass by **memorizing answers** instead of **learning algorithms**.

## How V2 Fixes This

V2 problems have **multiple test cases** with **different inputs**:

```python
# V2 Problem: "Implement RPN Evaluator"
def evaluate_rpn(expression: str) -> int:
    # Must actually implement the algorithm!
    # Can't just return a constant.
    ...
```

Test cases:
- `evaluate_rpn("3 4 +")` → 7
- `evaluate_rpn("5 2 *")` → 10  
- `evaluate_rpn("2 3 + 4 *")` → 20

**The model MUST implement the algorithm to pass all test cases.**

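As a minimal sketch of the difference, the snippet below runs a constant-return "solution" and a stack-based evaluator against the three test cases listed above. Both function bodies are illustrative assumptions; the notebook does not prescribe a particular implementation.

```python
def evaluate_rpn_constant(expression: str) -> int:
    # V1-style "solution": ignores its input entirely.
    return 7

def evaluate_rpn(expression: str) -> int:
    # Stack-based evaluation: push numbers, apply each operator to the top two.
    stack = []
    for token in expression.split():
        if token in ("+", "-", "*"):
            b, a = stack.pop(), stack.pop()
            if token == "+":
                stack.append(a + b)
            elif token == "-":
                stack.append(a - b)
            else:
                stack.append(a * b)
        else:
            stack.append(int(token))
    return stack[0]

cases = [("3 4 +", 7), ("5 2 *", 10), ("2 3 + 4 *", 20)]
constant_passed = sum(evaluate_rpn_constant(e) == want for e, want in cases)
real_passed = sum(evaluate_rpn(e) == want for e, want in cases)
print(f"constant: {constant_passed}/{len(cases)}, stack-based: {real_passed}/{len(cases)}")
```

The constant version matches only the case whose answer happens to be 7, while the stack-based version passes all three.
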
## 2.1 Define the TestCase Class

A test case has input arguments and an expected output.

In [None]:
from dataclasses import dataclass, field
from typing import List, Any, Optional

@dataclass
class TestCase:
    """A single test case with input and expected output."""
    input_args: List[Any]    # Arguments to pass to the function
    expected_output: Any      # Expected return value

    def to_assertion(self, func_name: str) -> str:
        """Generate an assertion statement."""
        args_str = ", ".join(repr(arg) for arg in self.input_args)
        return f"assert {func_name}({args_str}) == {self.expected_output!r}"

# Example: Create some test cases
example_cases = [
    TestCase(input_args=["3 4 +"], expected_output=7),
    TestCase(input_args=["5 2 *"], expected_output=10),
    TestCase(input_args=["2 3 + 4 *"], expected_output=20),
]

print("="*50)
print("EXAMPLE TEST CASES")
print("="*50)
for i, tc in enumerate(example_cases, 1):
    print(f"\nTest Case {i}:")
    print(f"  Input:    {tc.input_args}")
    print(f"  Expected: {tc.expected_output}")
    print(f"  Assertion: {tc.to_assertion('evaluate_rpn')}")

## 2.2 Define the AlgorithmicProblem Class

A problem contains multiple test cases and generates prompts for the model.

In [None]:
@dataclass
class AlgorithmicProblem:
    """A problem requiring algorithm implementation."""
    problem_type: str         # e.g., "rpn", "parentheses"
    problem_id: str           # Unique identifier
    title: str                # Human-readable title
    description: str          # Full problem description
    function_signature: str   # e.g., "def evaluate_rpn(expression: str) -> int:"
    test_cases: List[TestCase]  # MULTIPLE test cases
    difficulty: int = 5

    def to_prompt(self) -> str:
        """Convert to a prompt for the model."""
        # Show some example test cases (not all - keep some hidden)
        visible_cases = self.test_cases[:3]
        func_name = self.get_func_name()
        
        examples = "\n".join(
            f"  {tc.to_assertion(func_name)}"
            for tc in visible_cases
        )
        
        return f"""## {self.title}

{self.description}

### Function Signature
```python
{self.function_signature}
```

### Examples
```python
{examples}
```

Implement the function. Your solution must pass ALL test cases."""

    def get_func_name(self) -> str:
        """Extract function name from signature."""
        return self.function_signature.split("(")[0].replace("def ", "")

print("AlgorithmicProblem class defined!")
print(f"\nKey attributes:")
print(f"  - problem_type: Type of algorithm (rpn, parentheses, etc.)")
print(f"  - test_cases: List of TestCase objects (typically 5+)")
print(f"  - to_prompt(): Generates the prompt for the model")

---
# Part 3: Problem Generators
---

Generators create problems with randomized test cases.

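The generators below keep their own `random.Random(seed)` instance rather than using the module-level functions. A short sketch of why that matters for reproducibility: two instances with the same seed yield the same sequence, even when unrelated code touches the global RNG in between. The variable names here are illustrative only.

```python
import random

# Two independent RNG instances created with the same seed.
rng_a = random.Random(42)
rng_b = random.Random(42)

seq_a = [rng_a.randint(1, 9) for _ in range(5)]

# Reseeding or consuming the global RNG does not affect the instances.
random.seed(0)
random.random()

seq_b = [rng_b.randint(1, 9) for _ in range(5)]

print(seq_a == seq_b)
```
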
## 3.1 Base Generator Class

In [None]:
import random
from abc import ABC, abstractmethod

class AlgorithmicGenerator(ABC):
    """Base class for problem generators."""
    
    def __init__(self, seed: Optional[int] = None):
        self.rng = random.Random(seed)
        self._problem_counter = 0
    
    @property
    @abstractmethod
    def problem_type(self) -> str:
        """Return the problem type identifier."""
        pass
    
    @property
    @abstractmethod
    def title(self) -> str:
        """Return the problem title."""
        pass
    
    @property
    @abstractmethod
    def description(self) -> str:
        """Return the problem description."""
        pass
    
    @property
    @abstractmethod
    def function_signature(self) -> str:
        """Return the function signature."""
        pass
    
    @abstractmethod
    def generate_test_cases(self, difficulty: int, count: int) -> List[TestCase]:
        """Generate test cases for the problem."""
        pass
    
    def generate(self, difficulty: int = 5, num_test_cases: int = 5) -> AlgorithmicProblem:
        """Generate a complete problem."""
        self._problem_counter += 1
        problem_id = f"{self.problem_type}_{self._problem_counter}"
        
        test_cases = self.generate_test_cases(difficulty, num_test_cases)
        
        return AlgorithmicProblem(
            problem_type=self.problem_type,
            problem_id=problem_id,
            title=self.title,
            description=self.description,
            function_signature=self.function_signature,
            test_cases=test_cases,
            difficulty=difficulty,
        )

print("AlgorithmicGenerator base class defined!")
print(f"\nSubclasses must implement:")
print(f"  - problem_type: e.g., 'rpn'")
print(f"  - title: e.g., 'RPN Expression Evaluator'")
print(f"  - description: Full problem description")
print(f"  - function_signature: e.g., 'def evaluate_rpn(expression: str) -> int:'")
print(f"  - generate_test_cases(): Create randomized test cases")

## 3.2 RPN Evaluator Generator

Generates Reverse Polish Notation evaluation problems.

In [None]:
class RPNEvaluatorGenerator(AlgorithmicGenerator):
    """
    Generates RPN (Reverse Polish Notation) evaluation problems.
    
    RPN puts operators AFTER operands:
      "3 4 +" means 3 + 4 = 7
      "3 4 + 2 *" means (3 + 4) * 2 = 14
    
    Requires a stack-based algorithm to solve.
    """

    @property
    def problem_type(self) -> str:
        return "rpn"

    @property
    def title(self) -> str:
        return "RPN Expression Evaluator"

    @property
    def description(self) -> str:
        return """Implement a Reverse Polish Notation (RPN) expression evaluator.

In RPN, operators come AFTER their operands:
- "3 4 +" means 3 + 4 = 7
- "3 4 + 2 *" means (3 + 4) * 2 = 14
- "5 1 2 + 4 * + 3 -" means 5 + ((1 + 2) * 4) - 3 = 14

Rules:
- Tokens are separated by spaces
- Valid operators: +, -, *
- All numbers are integers
- Return the final result as an integer"""

    @property
    def function_signature(self) -> str:
        return "def evaluate_rpn(expression: str) -> int:"

    def generate_test_cases(self, difficulty: int, count: int = 5) -> List[TestCase]:
        """Generate RPN expressions with their correct answers."""
        test_cases = []
        for _ in range(count):
            expr, result = self._generate_rpn_expression(difficulty)
            test_cases.append(TestCase(input_args=[expr], expected_output=result))
        return test_cases

    def _generate_rpn_expression(self, difficulty: int) -> tuple:
        """Generate a valid RPN expression and its result."""
        num_ops = min(1 + difficulty // 2, 5)
        ops = ['+', '-', '*']

        # Build expression using stack simulation
        stack = []
        tokens = []

        # Start with two numbers
        n1 = self.rng.randint(1, 9)
        n2 = self.rng.randint(1, 9)
        stack.extend([n1, n2])
        tokens.extend([str(n1), str(n2)])

        # Add operations
        for _ in range(num_ops):
            if len(stack) >= 2:
                op = self.rng.choice(ops)
                b, a = stack.pop(), stack.pop()
                result = a + b if op == '+' else (a - b if op == '-' else a * b)
                stack.append(result)
                tokens.append(op)

            if self.rng.random() < 0.4 and len(tokens) < difficulty * 2:
                n = self.rng.randint(1, 9)
                stack.append(n)
                tokens.append(str(n))

        # Reduce to single result
        while len(stack) > 1:
            op = self.rng.choice(ops)
            b, a = stack.pop(), stack.pop()
            result = a + b if op == '+' else (a - b if op == '-' else a * b)
            stack.append(result)
            tokens.append(op)

        return " ".join(tokens), stack[0]

print("RPNEvaluatorGenerator defined!")

## 3.3 Test the RPN Generator

Let's generate a problem and see what it looks like.

In [None]:
# Create generator with fixed seed for reproducibility
rpn_gen = RPNEvaluatorGenerator(seed=42)

# Generate a problem with 5 test cases
rpn_problem = rpn_gen.generate(difficulty=5, num_test_cases=5)

print("="*60)
print("GENERATED RPN PROBLEM")
print("="*60)
print(f"\nProblem ID: {rpn_problem.problem_id}")
print(f"Problem Type: {rpn_problem.problem_type}")
print(f"Difficulty: {rpn_problem.difficulty}")
print(f"Number of Test Cases: {len(rpn_problem.test_cases)}")

print("\n" + "-"*60)
print("TEST CASES:")
print("-"*60)
for i, tc in enumerate(rpn_problem.test_cases, 1):
    expr = tc.input_args[0]
    result = tc.expected_output
    print(f"  {i}. evaluate_rpn(\"{expr}\") -> {result}")

# Check output diversity
outputs = [tc.expected_output for tc in rpn_problem.test_cases]
unique_outputs = len(set(outputs))
print(f"\nOutput Diversity: {unique_outputs}/{len(outputs)} unique values")
if unique_outputs >= 3:
    print("GOOD: Diverse outputs - can't pass by hardcoding!")
else:
    print("WARNING: Low diversity - might allow hardcoding")

## 3.4 See the Generated Prompt

This is what the model sees.

In [None]:
print("="*60)
print("PROMPT FOR THE MODEL")
print("="*60)
print(rpn_problem.to_prompt())

## 3.5 Parentheses Validator Generator

Another problem type that requires a stack-based algorithm.

In [None]:
class ParenthesesValidatorGenerator(AlgorithmicGenerator):
    """
    Generates parentheses validation problems.
    
    Check if a string of brackets is properly balanced:
      "()[]{}" -> True
      "([)]" -> False
    """

    @property
    def problem_type(self) -> str:
        return "parentheses"

    @property
    def title(self) -> str:
        return "Valid Parentheses Checker"

    @property
    def description(self) -> str:
        return """Implement a function to check if a string of brackets is valid.

A string is valid if:
1. Open brackets are closed by the same type of brackets
2. Open brackets are closed in the correct order
3. Every close bracket has a corresponding open bracket

Valid brackets: (), [], {}"""

    @property
    def function_signature(self) -> str:
        return "def is_valid(s: str) -> bool:"

    def generate_test_cases(self, difficulty: int, count: int = 5) -> List[TestCase]:
        """Generate bracket strings with correct validity."""
        test_cases = []
        num_valid = count // 2 + 1
        num_invalid = count - num_valid

        for _ in range(num_valid):
            s = self._generate_valid(difficulty)
            test_cases.append(TestCase(input_args=[s], expected_output=True))

        for _ in range(num_invalid):
            s = self._generate_invalid(difficulty)
            test_cases.append(TestCase(input_args=[s], expected_output=False))

        self.rng.shuffle(test_cases)
        return test_cases

    def _generate_valid(self, difficulty: int) -> str:
        """Generate a valid bracket string."""
        length = 2 * (1 + difficulty // 2)
        pairs = [("(", ")"), ("[", "]"), ("{", "}")]
        if difficulty <= 3:
            pairs = pairs[:1]
        elif difficulty <= 6:
            pairs = pairs[:2]

        result, stack = [], []
        while len(result) < length:
            remaining = length - len(result)
            if len(stack) >= remaining // 2:
                result.append(stack.pop())
            elif len(stack) == 0:
                pair = self.rng.choice(pairs)
                result.append(pair[0])
                stack.append(pair[1])
            elif self.rng.random() < 0.5:
                pair = self.rng.choice(pairs)
                result.append(pair[0])
                stack.append(pair[1])
            else:
                result.append(stack.pop())
        while stack:
            result.append(stack.pop())
        return "".join(result)

    def _generate_invalid(self, difficulty: int) -> str:
        """Generate an invalid bracket string."""
        pairs = [("(", ")"), ("[", "]"), ("{", "}")]
        if difficulty <= 3:
            pairs = pairs[:1]
        elif difficulty <= 6:
            pairs = pairs[:2]

        pattern = self.rng.choice(["mismatch", "unclosed", "extra_close"])
        
        if pattern == "mismatch" and len(pairs) > 1:
            p1, p2 = self.rng.sample(pairs, 2)
            return p1[0] + p2[1]
        elif pattern == "unclosed":
            pair = self.rng.choice(pairs)
            return pair[0] * 2 + pair[1]
        else:
            pair = self.rng.choice(pairs)
            return pair[1] + pair[0] + pair[1]

print("ParenthesesValidatorGenerator defined!")

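For reference, the three validity rules in the problem description can be sketched as a stack-based checker. This is one possible solution written for illustration, not the output of the model or part of the generator.

```python
def is_valid(s: str) -> bool:
    # Map each closing bracket to the opening bracket it must match.
    closers = {")": "(", "]": "[", "}": "{"}
    stack = []
    for ch in s:
        if ch in closers:
            # A closer needs a matching opener on top of the stack.
            if not stack or stack.pop() != closers[ch]:
                return False
        else:
            stack.append(ch)
    # Any opener left on the stack was never closed.
    return not stack

for sample in ["()[]{}", "([)]", "(()", ")()"]:
    print(f"{sample!r}: {is_valid(sample)}")
```

The last two samples correspond to the "unclosed" and "extra_close" patterns that `_generate_invalid` produces.
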
## 3.6 Test the Parentheses Generator

In [None]:
paren_gen = ParenthesesValidatorGenerator(seed=42)
paren_problem = paren_gen.generate(difficulty=5, num_test_cases=6)

print("="*60)
print("GENERATED PARENTHESES PROBLEM")
print("="*60)
print(f"\nProblem ID: {paren_problem.problem_id}")
print(f"Number of Test Cases: {len(paren_problem.test_cases)}")

print("\n" + "-"*60)
print("TEST CASES:")
print("-"*60)
true_count = 0
for i, tc in enumerate(paren_problem.test_cases, 1):
    s = tc.input_args[0]
    result = tc.expected_output
    if result:
        true_count += 1
    print(f"  {i}. is_valid(\"{s}\") -> {result}")

print(f"\nBalance: {true_count} True, {len(paren_problem.test_cases) - true_count} False")
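if 0 < true_count < len(paren_problem.test_cases):
    print("GOOD: Mix of True/False prevents always returning same value!")
else:
    print("WARNING: All expected outputs are identical - a constant return would pass")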

## 3.7 Why Hardcoding Fails

Let's check that a solution returning a single constant value fails at least one test case for each problem.

In [None]:
print("="*60)
print("HARDCODING PREVENTION DEMONSTRATION")
print("="*60)

# RPN Problem
print("\n--- RPN Problem ---")
for i, tc in enumerate(rpn_problem.test_cases, 1):
    print(f"  {i}. evaluate_rpn(\"{tc.input_args[0]}\") -> {tc.expected_output}")

# Try hardcoding the first answer
first_answer = rpn_problem.test_cases[0].expected_output
print(f"\nHardcoded solution: return {first_answer}")
passed = sum(1 for tc in rpn_problem.test_cases if tc.expected_output == first_answer)
print(f"Would pass: {passed}/{len(rpn_problem.test_cases)} test cases")
print(f"Result: {'FAIL' if passed < len(rpn_problem.test_cases) else 'PASS'}")

# Parentheses Problem
print("\n--- Parentheses Problem ---")
for i, tc in enumerate(paren_problem.test_cases, 1):
    print(f"  {i}. is_valid(\"{tc.input_args[0]}\") -> {tc.expected_output}")

print(f"\nHardcoded solution: return True")
passed_true = sum(1 for tc in paren_problem.test_cases if tc.expected_output == True)
print(f"Would pass: {passed_true}/{len(paren_problem.test_cases)} test cases")

print(f"\nHardcoded solution: return False")
passed_false = sum(1 for tc in paren_problem.test_cases if tc.expected_output == False)
print(f"Would pass: {passed_false}/{len(paren_problem.test_cases)} test cases")

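print("\n" + "="*60)
n_rpn, n_paren = len(rpn_problem.test_cases), len(paren_problem.test_cases)
if passed < n_rpn and passed_true < n_paren and passed_false < n_paren:
    print("CONCLUSION: No constant return value passes all test cases!")
    print("The model MUST implement the actual algorithm.")
else:
    print("WARNING: A constant return value passes every test case.")
    print("Regenerate the problem with more diverse test cases.")
print("="*60)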

---
# Part 4: Solution Verification
---

We need to safely execute model-generated code and check if it passes all test cases.

## 4.1 Code Extraction

Extract Python code from model responses.

In [None]:
import re

def extract_code(response: str) -> Optional[str]:
    """Extract Python code from a model response."""
    # Try to find code in ```python blocks
    pattern = r"```python\s*\n(.*?)```"
    matches = re.findall(pattern, response, re.DOTALL)
    if matches:
        return matches[-1].strip()  # Return last code block
    
    # Try plain ``` blocks
    pattern = r"```\s*\n(.*?)```"
    matches = re.findall(pattern, response, re.DOTALL)
    if matches:
        return matches[-1].strip()
    
    return None

# Test with example response
example_response = '''Here's my solution:

```python
def evaluate_rpn(expression: str) -> int:
    stack = []
    for token in expression.split():
        if token.lstrip('-').isdigit():
            stack.append(int(token))
        else:
            b, a = stack.pop(), stack.pop()
            if token == '+': stack.append(a + b)
            elif token == '-': stack.append(a - b)
            elif token == '*': stack.append(a * b)
    return stack[0]
```
'''

extracted = extract_code(example_response)
print("="*60)
print("CODE EXTRACTION TEST")
print("="*60)
print("\nExtracted code:")
print(extracted)

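Models often revise themselves mid-response, which is why `extract_code` returns the last fenced block rather than the first. The sketch below shows that behavior with the same regular expression on a made-up two-block response. The fence string is built programmatically only so that this example can itself sit inside a fenced block.

```python
import re

FENCE = "`" * 3  # three backticks, assembled to avoid closing this code block

# Made-up model response containing a first attempt and a correction.
response = (
    "First attempt:\n"
    f"{FENCE}python\ndef f():\n    return 1\n{FENCE}\n"
    "Wait, that is wrong. Corrected version:\n"
    f"{FENCE}python\ndef f():\n    return 2\n{FENCE}\n"
)

# Same pattern as extract_code: lazy match between an opening and closing fence.
pattern = FENCE + r"python\s*\n(.*?)" + FENCE
blocks = re.findall(pattern, response, re.DOTALL)
last_block = blocks[-1].strip()

print(f"{len(blocks)} blocks found; using the last one:")
print(last_block)
```
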
## 4.2 Safe Code Execution

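Execute the extracted code in its own namespace, with a timeout. This is not a security sandbox: `exec` runs inside the notebook process, so only run code from models you are willing to trust.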

In [None]:
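import signal

class ExecutionTimeout(TimeoutError):
    """Raised when model-generated code runs past the time limit."""

def _raise_timeout(signum, frame):
    raise ExecutionTimeout("execution timed out")

def execute_with_timeout(code: str, func_name: str, args: list, timeout: float = 5.0):
    """
    Execute code and call the function with given arguments.
    Returns (success, result_or_error)

    The timeout relies on SIGALRM, so it is only enforced on Unix and when
    called from the main thread (which is where notebook cells run).
    """
    use_alarm = hasattr(signal, "SIGALRM")
    if use_alarm:
        old_handler = signal.signal(signal.SIGALRM, _raise_timeout)
        signal.setitimer(signal.ITIMER_REAL, timeout)
    try:
        # Create isolated namespace
        namespace = {}

        # Execute the code to define the function
        exec(code, namespace)

        # Check if function exists
        if func_name not in namespace:
            return False, f"Function '{func_name}' not defined"

        # Call the function
        func = namespace[func_name]
        result = func(*args)

        return True, result

    except Exception as e:
        return False, f"{type(e).__name__}: {str(e)}"
    finally:
        if use_alarm:
            # Cancel the pending alarm and restore the previous handler
            signal.setitimer(signal.ITIMER_REAL, 0)
            signal.signal(signal.SIGALRM, old_handler or signal.SIG_DFL)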

# Test execution
print("="*60)
print("CODE EXECUTION TEST")
print("="*60)

test_code = extracted  # Use the code we extracted above
test_cases_to_try = [
    ("3 4 +", 7),
    ("5 2 *", 10),
    ("2 3 + 4 *", 20),
]

print("\nTesting extracted code against test cases:")
for expr, expected in test_cases_to_try:
    success, result = execute_with_timeout(test_code, "evaluate_rpn", [expr])
    status = "PASS" if success and result == expected else "FAIL"
    print(f"  {status}: evaluate_rpn(\"{expr}\") = {result} (expected {expected})")

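If stronger isolation is needed, one alternative is to run each call in a fresh Python interpreter and let `subprocess.run` enforce the timeout. The sketch below is an assumption-laden illustration, not the notebook's method: `run_in_subprocess` is a hypothetical helper, and it passes arguments and results as JSON, which is enough for the strings, integers, and booleans used by the problems here.

```python
import json
import subprocess
import sys

def run_in_subprocess(code: str, func_name: str, args: list, timeout: float = 5.0):
    """Hypothetical helper: call func_name(*args) in a separate interpreter.

    Returns (success, result_or_error), like execute_with_timeout.
    """
    # Append a tiny driver that reads JSON args from argv and prints a JSON result.
    runner = (
        code
        + "\nimport json, sys\n"
        + f"print(json.dumps({func_name}(*json.loads(sys.argv[1]))))\n"
    )
    try:
        proc = subprocess.run(
            [sys.executable, "-c", runner, json.dumps(args)],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return False, "TimeoutError: execution timed out"

    if proc.returncode != 0:
        stderr_lines = proc.stderr.strip().splitlines()
        # The last traceback line names the exception type and message.
        return False, stderr_lines[-1] if stderr_lines else "unknown error"

    # Use the last stdout line in case the solution printed anything itself.
    return True, json.loads(proc.stdout.strip().splitlines()[-1])

print(run_in_subprocess("def f(x):\n    return x + 1", "f", [2]))
```

Starting an interpreter per call is slower than `exec`, so this trade-off makes most sense when the generated code cannot be trusted.
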
## 4.3 Solution Verifier

Verify a solution against ALL test cases in a problem.

In [None]:
def verify_solution(code: str, problem: AlgorithmicProblem, verbose: bool = True) -> tuple:
    """
    Verify a solution against all test cases.
    
    Returns: (passed_all, passed_count, total_count, details)
    """
    func_name = problem.get_func_name()
    passed_count = 0
    details = []
    
    for i, tc in enumerate(problem.test_cases):
        success, result = execute_with_timeout(code, func_name, tc.input_args)
        
        if success and result == tc.expected_output:
            passed_count += 1
            status = "PASS"
        else:
            status = "FAIL"
        
        detail = {
            "test_case": i + 1,
            "input": tc.input_args,
            "expected": tc.expected_output,
            "actual": result if success else "ERROR",
            "status": status,
        }
        details.append(detail)
        
        if verbose:
            print(f"  {status}: {func_name}({tc.input_args[0]!r}) = {result} (expected {tc.expected_output})")
    
    passed_all = (passed_count == len(problem.test_cases))
    return passed_all, passed_count, len(problem.test_cases), details

# Test the verifier
print("="*60)
print("SOLUTION VERIFICATION TEST")
print("="*60)

print(f"\nVerifying against {len(rpn_problem.test_cases)} test cases:")
passed_all, passed, total, _ = verify_solution(extracted, rpn_problem)

print(f"\nResult: {passed}/{total} test cases passed")
print(f"Verdict: {'ACCEPTED' if passed_all else 'REJECTED'}")

---
# Part 5: Model Loading and Inference
---

## 5.1 Load the Base Model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Check that CONFIG is defined
if 'CONFIG' not in dir():
    raise RuntimeError("CONFIG is not defined! Please run cell 1.3 first.")

print("="*60)
print("LOADING MODEL")
print("="*60)
print(f"\nModel: {CONFIG['model_name']}")
print("This may take a few minutes...")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_name"])
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print("Tokenizer loaded!")

# Load model
model = AutoModelForCausalLM.from_pretrained(
    CONFIG["model_name"],
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,  # fp16 is poorly supported on CPU
    device_map="auto",
)
print(f"Model loaded on: {model.device}")

# Model info
num_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {num_params / 1e9:.2f}B")

## 5.2 Generate Solution for a Problem

In [None]:
def generate_solution(model, tokenizer, problem: AlgorithmicProblem, 
                      max_new_tokens: int = 512, temperature: float = 0.7) -> str:
    """
    Generate a solution for a problem using the model.
    """
    # Build the prompt
    system_msg = "You are an expert Python programmer. Implement the requested function."
    user_msg = problem.to_prompt()
    
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": user_msg},
    ]
    
    # Apply chat template
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    
    # Tokenize
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )
    
    # Decode response (excluding prompt)
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    return response

print("generate_solution function defined!")

## 5.3 Test Solution Generation

In [None]:
print("="*60)
print("GENERATING SOLUTION FOR RPN PROBLEM")
print("="*60)

print("\nProblem:")
print(f"  {rpn_problem.title}")
print(f"  Test cases: {len(rpn_problem.test_cases)}")

print("\nGenerating solution...")
response = generate_solution(model, tokenizer, rpn_problem)

print("\n" + "-"*60)
print("MODEL RESPONSE:")
print("-"*60)
print(response[:1000])  # Show first 1000 chars
if len(response) > 1000:
    print("... (truncated)")

## 5.4 Extract and Verify the Generated Solution

In [None]:
print("="*60)
print("VERIFYING GENERATED SOLUTION")
print("="*60)

# Extract code
generated_code = extract_code(response)

if generated_code:
    print("\nExtracted code:")
    print("-"*40)
    print(generated_code)
    print("-"*40)
    
    print(f"\nVerifying against {len(rpn_problem.test_cases)} test cases:")
    passed_all, passed, total, _ = verify_solution(generated_code, rpn_problem)
    
    print(f"\nResult: {passed}/{total} test cases passed")
    print(f"Verdict: {'ACCEPTED - Can use for training!' if passed_all else 'REJECTED - Not good enough'}")
else:
    print("ERROR: Could not extract code from response")

---
# Part 6: Problem Set Generation
---

Generate train/val/test problem sets.

## 6.1 Generator Registry

In [None]:
# Registry of available generators
GENERATORS_V2 = {
    "rpn": RPNEvaluatorGenerator,
    "parentheses": ParenthesesValidatorGenerator,
}

print("="*60)
print("AVAILABLE PROBLEM GENERATORS")
print("="*60)
for name, cls in GENERATORS_V2.items():
    gen = cls(seed=42)
    print(f"\n{name}:")
    print(f"  Title: {gen.title}")
    print(f"  Signature: {gen.function_signature}")

## 6.2 Generate Problem Sets

In [None]:
# Check that CONFIG is defined
if 'CONFIG' not in dir():
    raise RuntimeError("CONFIG is not defined! Please run cell 1.3 first.")

def generate_problem_sets(config: dict, seed: int = 42) -> tuple:
    """
    Generate train, validation, and test problem sets.
    
    Returns: (train_problems, val_problems, test_problems)
    """
    rng = random.Random(seed)
    
    train_problems = []
    val_problems = []
    test_problems = []
    
    for prob_type in config["problem_types"]:
        if prob_type not in GENERATORS_V2:
            print(f"WARNING: Unknown problem type '{prob_type}', skipping")
            continue
        
        # Create generator with unique seed per type
        gen = GENERATORS_V2[prob_type](seed=rng.randint(0, 1000000))
        
        # Generate problems for each split
        for _ in range(config["train_per_type"]):
            difficulty = rng.randint(3, 7)
            problem = gen.generate(difficulty=difficulty, num_test_cases=config["test_cases"])
            train_problems.append(problem)
        
        for _ in range(config["val_per_type"]):
            difficulty = rng.randint(3, 7)
            problem = gen.generate(difficulty=difficulty, num_test_cases=config["test_cases"])
            val_problems.append(problem)
        
        for _ in range(config["test_per_type"]):
            difficulty = rng.randint(3, 7)
            problem = gen.generate(difficulty=difficulty, num_test_cases=config["test_cases"])
            test_problems.append(problem)
    
    # Shuffle
    rng.shuffle(train_problems)
    rng.shuffle(val_problems)
    rng.shuffle(test_problems)
    
    return train_problems, val_problems, test_problems

# Generate problem sets
print("="*60)
print("GENERATING PROBLEM SETS")
print("="*60)

train_problems, val_problems, test_problems = generate_problem_sets(CONFIG, seed=CONFIG["seed"])

print(f"\nGenerated:")
print(f"  Train: {len(train_problems)} problems")
print(f"  Val:   {len(val_problems)} problems")
print(f"  Test:  {len(test_problems)} problems")
print(f"  Test cases per problem: {CONFIG['test_cases']}")

## 6.3 Inspect Generated Problems

In [None]:
print("="*60)
print("TRAINING PROBLEMS OVERVIEW")
print("="*60)

# Count by type
type_counts = {}
for p in train_problems:
    type_counts[p.problem_type] = type_counts.get(p.problem_type, 0) + 1

print("\nProblems by type:")
for ptype, count in type_counts.items():
    print(f"  {ptype}: {count}")

print("\nFirst 5 training problems:")
for i, p in enumerate(train_problems[:5], 1):
    print(f"  {i}. [{p.problem_type}] {p.problem_id} ({len(p.test_cases)} test cases)")

---
# Part 7: Evaluation Loop
---

Evaluate the model on a set of problems.

## 7.1 Evaluate Model on Problems

In [None]:
def evaluate_model(model, tokenizer, problems: List[AlgorithmicProblem], 
                   verbose: bool = True) -> dict:
    """
    Evaluate model on a list of problems.
    
    Returns dict with:
    - accuracy: fraction (0-1) of problems fully solved
    - test_pass_rate: fraction (0-1) of individual test cases passed
    - results: detailed results per problem
    """
    results = []
    total_passed_all = 0
    total_test_cases = 0
    total_tests_passed = 0
    
    for i, problem in enumerate(problems):
        if verbose:
            print(f"  [{i+1}/{len(problems)}] {problem.problem_id}...", end=" ")
        
        # Generate solution
        try:
            response = generate_solution(model, tokenizer, problem)
            code = extract_code(response)
            
            if code:
                passed_all, passed, total, details = verify_solution(code, problem, verbose=False)
                total_test_cases += total
                total_tests_passed += passed
                
                if passed_all:
                    total_passed_all += 1
                    if verbose:
                        print(f"PASS ({passed}/{total})")
                else:
                    if verbose:
                        print(f"FAIL ({passed}/{total})")
                
                results.append({
                    "problem_id": problem.problem_id,
                    "passed_all": passed_all,
                    "passed": passed,
                    "total": total,
                    "code": code,
                })
            else:
                if verbose:
                    print("FAIL (no code)")
                results.append({
                    "problem_id": problem.problem_id,
                    "passed_all": False,
                    "passed": 0,
                    "total": len(problem.test_cases),
                    "code": None,
                })
                total_test_cases += len(problem.test_cases)
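        except Exception as e:
            if verbose:
                print(f"ERROR: {e}")
            # Count this problem's test cases as failed so errors don't inflate test_pass_rate
            total_test_cases += len(problem.test_cases)
            results.append({
                "problem_id": problem.problem_id,
                "passed_all": False,
                "passed": 0,
                "total": len(problem.test_cases),
                "code": None,
                "error": str(e),
            })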
    
    accuracy = total_passed_all / len(problems) if problems else 0
    test_pass_rate = total_tests_passed / total_test_cases if total_test_cases else 0
    
    return {
        "accuracy": accuracy,
        "test_pass_rate": test_pass_rate,
        "total_solved": total_passed_all,
        "total_problems": len(problems),
        "results": results,
    }

print("evaluate_model function defined!")

## 7.2 Run Initial Evaluation

In [None]:
print("="*60)
print("INITIAL EVALUATION (Before Training)")
print("="*60)

print(f"\n--- Training Set ({len(train_problems)} problems) ---")
train_eval = evaluate_model(model, tokenizer, train_problems)

print(f"\n--- Validation Set ({len(val_problems)} problems) ---")
val_eval = evaluate_model(model, tokenizer, val_problems)

print(f"\n--- Test Set ({len(test_problems)} problems) ---")
test_eval = evaluate_model(model, tokenizer, test_problems)

## 7.3 Evaluation Summary
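
The summary reports two different rates. Accuracy counts a problem as solved only when every test case passes, while test pass rate counts individual test cases, so a partially correct solution raises the second number without raising the first. The sketch below illustrates the difference on hand-written records. The `records` values and the `summarize` helper are illustrative assumptions, not outputs or functions from this notebook, and the sketch assumes `passed_all` is equivalent to `passed == total`.

```python
# Illustrative per-problem records; the numbers are invented for the example.
records = [
    {"problem_id": "a", "passed": 5, "total": 5},  # fully solved
    {"problem_id": "b", "passed": 4, "total": 5},  # partial: fails one test
    {"problem_id": "c", "passed": 0, "total": 5},  # no test passes
]

def summarize(records):
    """Return (accuracy, test_pass_rate) using the definitions from evaluate_model."""
    solved = sum(1 for r in records if r["passed"] == r["total"])
    tests_passed = sum(r["passed"] for r in records)
    tests_total = sum(r["total"] for r in records)
    accuracy = solved / len(records) if records else 0.0
    test_pass_rate = tests_passed / tests_total if tests_total else 0.0
    return accuracy, test_pass_rate

accuracy, test_pass_rate = summarize(records)
print(f"accuracy={accuracy:.3f} test_pass_rate={test_pass_rate:.3f}")
```

Here one of three problems is fully solved, but nine of fifteen test cases pass, so the two metrics diverge.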

In [None]:
print("="*60)
print("ITERATION 0 - EVALUATION SUMMARY")
print("="*60)

print(f"\n{'Set':<12} {'Accuracy':>12} {'Test Pass Rate':>16} {'Solved':>10}")
print("-"*52)
print(f"{'Train':<12} {train_eval['accuracy']*100:>11.1f}% {train_eval['test_pass_rate']*100:>15.1f}% {train_eval['total_solved']:>6}/{train_eval['total_problems']}")
print(f"{'Validation':<12} {val_eval['accuracy']*100:>11.1f}% {val_eval['test_pass_rate']*100:>15.1f}% {val_eval['total_solved']:>6}/{val_eval['total_problems']}")
print(f"{'Test':<12} {test_eval['accuracy']*100:>11.1f}% {test_eval['test_pass_rate']*100:>15.1f}% {test_eval['total_solved']:>6}/{test_eval['total_problems']}")

# Store for comparison
iteration_metrics = [{
    "iteration": 0,
    "train_accuracy": train_eval['accuracy'],
    "val_accuracy": val_eval['accuracy'],
    "test_accuracy": test_eval['accuracy'],
}]

---
# Part 8: Solution Collection
---

Sample solutions from the current model and keep only those that pass every test case. These verified solutions become the training data.
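
The collection step is a form of rejection sampling: draw up to a fixed number of candidates per problem and keep the first one the verifier accepts. The sketch below shows that control flow in isolation. `collect_verified`, `toy_generate`, and `toy_verify` are hypothetical stand-ins for the model and test harness, not functions from this notebook.

```python
import random
from typing import Callable, List, Optional

def collect_verified(
    problems: List[str],
    generate: Callable[[str], Optional[str]],
    verify: Callable[[str, str], bool],
    samples_per_problem: int = 4,
) -> List[dict]:
    """Keep the first candidate per problem that the verifier accepts."""
    kept = []
    for problem in problems:
        for _ in range(samples_per_problem):
            candidate = generate(problem)
            if candidate is not None and verify(problem, candidate):
                kept.append({"problem": problem, "solution": candidate})
                break  # one verified solution per problem is enough
    return kept

# Hypothetical stand-ins for the model and the test harness.
rng = random.Random(0)

def toy_generate(problem: str) -> Optional[str]:
    # Returns a correct answer about half the time, otherwise a wrong one.
    return "right" if rng.random() < 0.5 else "wrong"

def toy_verify(problem: str, candidate: str) -> bool:
    return candidate == "right"

kept = collect_verified(["p1", "p2", "p3"], toy_generate, toy_verify)
print(f"kept {len(kept)} of 3")
```

Raising `samples_per_problem` trades extra generation time for a higher chance of finding at least one verified solution per problem.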

## 8.1 Collect Verified Solutions

In [None]:
def collect_solutions(model, tokenizer, problems: List[AlgorithmicProblem],
                      samples_per_problem: int = 1) -> List[dict]:
    """
    Collect verified solutions from problems.
    Only solutions that pass ALL test cases are kept.
    """
    solutions = []
    
    for i, problem in enumerate(problems):
        print(f"  [{i+1}/{len(problems)}] {problem.problem_id}...", end=" ")
        
        for sample_idx in range(samples_per_problem):
            try:
                response = generate_solution(model, tokenizer, problem, temperature=0.7)
                code = extract_code(response)
                
                if code:
                    passed_all, passed, total, _ = verify_solution(code, problem, verbose=False)
                    
                    if passed_all:
                        solutions.append({
                            "problem_id": problem.problem_id,
                            "problem_type": problem.problem_type,
                            "prompt": problem.to_prompt(),
                            "solution_code": code,
                            "passed_tests": passed,
                            "total_tests": total,
                        })
                        print(f"COLLECTED ({passed}/{total})")
                        break  # Got a good solution, move to next problem
            except Exception:
                # A failed generation or verification counts as an invalid sample
                continue
        else:
            print("SKIPPED (no valid solution)")
    
    return solutions

print("collect_solutions function defined!")

## 8.2 Collect Training Solutions

In [None]:
print("="*60)
print("COLLECTING VERIFIED SOLUTIONS")
print("="*60)

print(f"\nCollecting from {len(train_problems)} training problems...")
collected_solutions = collect_solutions(model, tokenizer, train_problems)

print("\n" + "="*60)
print("COLLECTION SUMMARY")
print("="*60)
print(f"Collected: {len(collected_solutions)} verified solutions")
print(f"From: {len(train_problems)} problems")
# Guard against an empty training set to avoid a ZeroDivisionError
collection_rate = len(collected_solutions) / len(train_problems) * 100 if train_problems else 0.0
print(f"Collection rate: {collection_rate:.1f}%")

## 8.3 Inspect Collected Solutions

In [None]:
print("="*60)
print("SAMPLE COLLECTED SOLUTIONS")
print("="*60)

for i, sol in enumerate(collected_solutions[:2], 1):
    print(f"\n--- Solution {i} ---")
    print(f"Problem: {sol['problem_id']}")
    print(f"Type: {sol['problem_type']}")
    print(f"Tests: {sol['passed_tests']}/{sol['total_tests']}")
    print(f"\nCode:")
    print(sol['solution_code'][:500])
    if len(sol['solution_code']) > 500:
        print("... (truncated)")

---
# Part 9: LoRA Training
---

Fine-tune the model on collected solutions using LoRA.
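
LoRA freezes each targeted weight matrix of shape `d_out x d_in` and learns a low-rank update built from two small matrices, `A` of shape `r x d_in` and `B` of shape `d_out x r`. Each adapted matrix therefore adds `r * (d_in + d_out)` trainable parameters. The sketch below works through that arithmetic. The dimensions are hypothetical and are not taken from this notebook's model or `CONFIG`.

```python
def lora_trainable_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters in one LoRA adapter: A is (r x d_in), B is (d_out x r)."""
    return r * d_in + d_out * r

# Hypothetical sizes chosen for illustration only.
d_in = d_out = 4096
r = 16
full = d_in * d_out
adapter = lora_trainable_params(d_in, d_out, r)
print(f"full matrix: {full:,}  adapter: {adapter:,}  ratio: {adapter / full:.4%}")
```

This is why the trainable-parameter percentage reported after applying LoRA is typically a small fraction of the total.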

## 9.1 Prepare Training Data

In [None]:
from datasets import Dataset

def prepare_training_data(solutions: List[dict], tokenizer) -> Dataset:
    """
    Convert solutions to training format.
    """
    training_examples = []
    
    for sol in solutions:
        # Format as chat
        system_msg = "You are an expert Python programmer. Implement the requested function."
        
        messages = [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": sol["prompt"]},
            {"role": "assistant", "content": f"```python\n{sol['solution_code']}\n```"},
        ]
        
        # Apply chat template
        text = tokenizer.apply_chat_template(messages, tokenize=False)
        training_examples.append({"text": text})
    
    return Dataset.from_list(training_examples)

# Skip training if no solutions
if len(collected_solutions) == 0:
    print("WARNING: No solutions collected. Skipping training.")
    SKIP_TRAINING = True
else:
    print("="*60)
    print("PREPARING TRAINING DATA")
    print("="*60)
    
    train_dataset = prepare_training_data(collected_solutions, tokenizer)
    print(f"\nTraining examples: {len(train_dataset)}")
    
    print("\nSample training text (first 500 chars):")
    print(train_dataset[0]["text"][:500])
    
    SKIP_TRAINING = False

## 9.2 Configure LoRA

In [None]:
if not SKIP_TRAINING:
    from peft import LoraConfig, get_peft_model, TaskType

    print("="*60)
    print("CONFIGURING LoRA")
    print("="*60)

    lora_config = LoraConfig(
        r=CONFIG["lora_r"],
        lora_alpha=CONFIG["lora_alpha"],
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )

    print(f"\nLoRA Configuration:")
    print(f"  Rank (r): {lora_config.r}")
    print(f"  Alpha: {lora_config.lora_alpha}")
    print(f"  Target modules: {lora_config.target_modules}")
    print(f"  Dropout: {lora_config.lora_dropout}")

    # Apply LoRA to model
    model = get_peft_model(model, lora_config)
    
    # Count trainable parameters
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    
    print(f"\nTrainable parameters: {trainable_params:,} ({trainable_params/total_params*100:.2f}%)")

## 9.3 Train the Model

In [None]:
if not SKIP_TRAINING:
    from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling

    print("="*60)
    print("TRAINING")
    print("="*60)

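    # Tokenize dataset. Padding is left to the data collator, which pads each
    # batch to its longest sequence instead of always padding to max_length.
    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=1024,
        )

    tokenized_dataset = train_dataset.map(
        tokenize_function, batched=True, remove_columns=["text"]
    )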

    # Training arguments
    training_args = TrainingArguments(
        output_dir="./lora_output",
        num_train_epochs=2,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=CONFIG["learning_rate"],
        fp16=torch.cuda.is_available(),  # fp16 mixed precision requires a GPU
        logging_steps=1,
        save_strategy="no",
        report_to="none",
    )

    # Data collator
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
    )

    print(f"\nStarting training...")
    print(f"  Examples: {len(tokenized_dataset)}")
    print(f"  Epochs: {training_args.num_train_epochs}")
    print(f"  Learning rate: {training_args.learning_rate}")
    
    trainer.train()
    
    print("\nTraining complete!")

## 9.4 Merge LoRA Weights
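
Merging folds the learned low-rank update back into the frozen weight, so the adapter adds no extra cost at inference. Under the standard LoRA scaling, the merged weight is `W + (alpha / r) * (B @ A)`. The sketch below applies that formula to a tiny hand-written example in plain Python. `matmul` and `merge_lora` are illustrative helpers, not PEFT functions, and the matrices are made up.

```python
def matmul(x, y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]

def merge_lora(w, a, b, alpha, r):
    """Return W + (alpha / r) * (B @ A) as a new matrix."""
    scale = alpha / r
    delta = matmul(b, a)
    return [
        [w_ij + scale * d_ij for w_ij, d_ij in zip(w_row, d_row)]
        for w_row, d_row in zip(w, delta)
    ]

# Tiny illustrative values: W is 2x2 and the rank is r = 1.
w = [[1.0, 0.0], [0.0, 1.0]]
a = [[1.0, 2.0]]    # shape (r, d_in)
b = [[0.5], [0.0]]  # shape (d_out, r)
merged = merge_lora(w, a, b, alpha=2, r=1)
print(merged)
```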

In [None]:
if not SKIP_TRAINING:
    print("="*60)
    print("MERGING LoRA WEIGHTS")
    print("="*60)

    # Merge LoRA weights into base model
    model = model.merge_and_unload()
    
    print("LoRA weights merged into base model!")
    print("Model is now ready for evaluation.")

---
# Part 10: Post-Training Evaluation
---

Re-evaluate the model after training to measure how accuracy changed on each split.

## 10.1 Re-evaluate on All Sets

In [None]:
print("="*60)
print("ITERATION 1 - POST-TRAINING EVALUATION")
print("="*60)

print(f"\n--- Training Set ({len(train_problems)} problems) ---")
train_eval_1 = evaluate_model(model, tokenizer, train_problems)

print(f"\n--- Validation Set ({len(val_problems)} problems) ---")
val_eval_1 = evaluate_model(model, tokenizer, val_problems)

print(f"\n--- Test Set ({len(test_problems)} problems) ---")
test_eval_1 = evaluate_model(model, tokenizer, test_problems)

# Store metrics
iteration_metrics.append({
    "iteration": 1,
    "train_accuracy": train_eval_1['accuracy'],
    "val_accuracy": val_eval_1['accuracy'],
    "test_accuracy": test_eval_1['accuracy'],
})

## 10.2 Compare Before vs After
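
The before/after comparison is the same computation repeated per metric: convert both values to percentages and take the difference, which is a change in percentage points rather than a relative percentage. A compact sketch of that computation follows. `compare_metrics` is a hypothetical helper, and the metric values are invented for illustration; they are not results from this notebook.

```python
def compare_metrics(before: dict, after: dict, keys):
    """Return rows of (key, before %, after %, change in percentage points)."""
    rows = []
    for key in keys:
        b = before[key] * 100
        a = after[key] * 100
        rows.append((key, b, a, a - b))
    return rows

# Illustrative values only; these are not results from this notebook.
before = {"train_accuracy": 0.50, "val_accuracy": 0.25}
after = {"train_accuracy": 0.75, "val_accuracy": 0.25}

rows = compare_metrics(before, after, ["train_accuracy", "val_accuracy"])
for key, b, a, change in rows:
    print(f"{key:<20} {b:>11.1f}% {a:>11.1f}% {change:>+9.1f} pp")
```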

In [None]:
print("="*60)
print("COMPARISON: BEFORE vs AFTER TRAINING")
print("="*60)

print(f"\n{'Metric':<20} {'Before':>12} {'After':>12} {'Change':>12}")
print("-"*58)

# Train accuracy
before = iteration_metrics[0]['train_accuracy'] * 100
after = iteration_metrics[1]['train_accuracy'] * 100
change = after - before
print(f"{'Train Accuracy':<20} {before:>11.1f}% {after:>11.1f}% {change:>+9.1f} pp")

# Val accuracy
before = iteration_metrics[0]['val_accuracy'] * 100
after = iteration_metrics[1]['val_accuracy'] * 100
change = after - before
print(f"{'Val Accuracy':<20} {before:>11.1f}% {after:>11.1f}% {change:>+9.1f} pp")

# Test accuracy
before = iteration_metrics[0]['test_accuracy'] * 100
after = iteration_metrics[1]['test_accuracy'] * 100
change = after - before
print(f"{'Test Accuracy':<20} {before:>11.1f}% {after:>11.1f}% {change:>+9.1f} pp")

---
# Part 11: Results Analysis
---

## 11.1 Final Summary

In [None]:
print("="*60)
print("EXPERIMENT SUMMARY")
print("="*60)

print(f"\nConfiguration:")
print(f"  Model: {CONFIG['model_name']}")
print(f"  Problem types: {CONFIG['problem_types']}")
print(f"  Train problems: {len(train_problems)}")
print(f"  Test cases per problem: {CONFIG['test_cases']}")

print(f"\nTraining:")
print(f"  Solutions collected: {len(collected_solutions)}")
print(f"  Learning rate: {CONFIG['learning_rate']}")

print(f"\nResults:")
print(f"  {'Iteration':<12} {'Train':>10} {'Val':>10} {'Test':>10}")
print(f"  {'-'*44}")
for m in iteration_metrics:
    print(f"  {m['iteration']:<12} {m['train_accuracy']*100:>9.1f}% {m['val_accuracy']*100:>9.1f}% {m['test_accuracy']*100:>9.1f}%")

# Calculate improvement
train_improvement = (iteration_metrics[-1]['train_accuracy'] - iteration_metrics[0]['train_accuracy']) * 100
val_improvement = (iteration_metrics[-1]['val_accuracy'] - iteration_metrics[0]['val_accuracy']) * 100

print(f"\nImprovement (percentage points):")
print(f"  Train: {train_improvement:+.1f} pp")
print(f"  Val:   {val_improvement:+.1f} pp")

## 11.2 Key Takeaways
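
The notebook runs one round of collect, train, and evaluate. Repeating that round is the expert iteration loop: the current model generates solutions, the verified ones train the next model, and the cycle repeats. The sketch below shows only the control flow. `expert_iteration` and the lambda stand-ins are hypothetical and do not call this notebook's functions; the placeholder metric is not a real result.

```python
from typing import Callable, List

def expert_iteration(
    model,
    problems: List,
    collect: Callable,   # (model, problems) -> list of verified solutions
    train: Callable,     # (model, solutions) -> updated model
    evaluate: Callable,  # (model, problems) -> accuracy in [0, 1]
    n_iterations: int = 3,
):
    """Alternate collection and training, recording accuracy after each round."""
    history = [evaluate(model, problems)]
    for _ in range(n_iterations):
        solutions = collect(model, problems)
        if not solutions:
            break  # nothing verified, so there is nothing to train on
        model = train(model, solutions)
        history.append(evaluate(model, problems))
    return model, history

# Hypothetical stand-ins so the loop can run without a real model.
final_model, history = expert_iteration(
    model=0,
    problems=["p1", "p2"],
    collect=lambda m, ps: list(ps),  # pretend every problem yields a solution
    train=lambda m, sols: m + 1,     # the "model" is just a round counter
    evaluate=lambda m, ps: 0.0,      # placeholder metric
    n_iterations=3,
)
print(final_model, len(history))
```

The early exit mirrors the `SKIP_TRAINING` guard used earlier: if no solution passes verification, the loop stops rather than training on an empty dataset.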

In [None]:
print("="*60)
print("KEY TAKEAWAYS")
print("="*60)

print("""
1. V2 PROBLEM DESIGN WORKS
   - Multiple test cases prevent memorization
   - Model must implement actual algorithms
   - Training accuracy should improve rather than degrade as in V1 (check the results above)

2. SELF-IMPROVEMENT LOOP
   - Model generates solutions
   - Solutions verified against ALL test cases
   - Only fully correct solutions used for training
   - Model learns from its own successful outputs

3. EXPERT ITERATION
   - Model N generates solutions
   - Verified solutions train Model N+1
   - Model N+1 is better at generating solutions
   - Repeat for continuous improvement

4. NEXT STEPS
   - Run more iterations (5-10+)
   - Add more problem types
   - Scale to more problems per type
   - Implement replay buffer to prevent forgetting
""")

print("="*60)
print("EXPERIMENT COMPLETE!")
print("="*60)