# Axiom-RL: Self-Improving Code Generation

This notebook implements a complete self-improvement loop for code generation using Expert Iteration:

1. **Generate** - Model attempts to solve problems
2. **Verify** - Solutions are tested in a sandbox
3. **Train** - Model is fine-tuned on correct solutions using LoRA
4. **Repeat** - Iterate to continuously improve

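All code is self-contained — nothing is imported from the `axiom` package; only third-party libraries (PyTorch, Hugging Face Transformers, and PEFT for LoRA) are required.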

## Setup & Dependencies

In [1]:
import hashlib
import json
import os
import random
import re
import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple

import torch
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Environment report: confirm the imports worked and show which accelerator is visible.
for _line in (
    "✓ Libraries imported",
    f"✓ PyTorch version: {torch.__version__}",
    f"✓ CUDA available: {torch.cuda.is_available()}",
):
    print(_line)
if torch.cuda.is_available():
    print(f"✓ GPU: {torch.cuda.get_device_name(0)}")

  from .autonotebook import tqdm as notebook_tqdm


✓ Libraries imported
✓ PyTorch version: 2.5.1+cu121
✓ CUDA available: True
✓ GPU: NVIDIA GeForce RTX 3080


## Part 1: Problem Data Structures

In [2]:
from dataclasses import dataclass
from typing import List, Dict, Any, Optional

# These are the core data structures from axiom/core/data_structures.py

@dataclass
class TestCase:
    """One example that generated code must reproduce.

    Example: for a fibonacci function a test case might be
        input: {"n": 5}
        expected_output: 5
    """
    input: Dict[str, Any]      # The arguments to pass to the function, keyed by parameter name
    expected_output: Any        # The value the function must return


@dataclass
class Problem:
    """A programming problem the model is asked to solve.

    ``title`` is optional and defaults to ``problem_id``; prompt builders use it
    as the human-readable heading. ``function_name`` is derived from
    ``function_signature`` so test drivers know which function to call.
    """
    problem_id: str            # Unique identifier (e.g., "fibonacci_v2_001")
    problem_type: str          # Category (e.g., "fibonacci", "fizzbuzz")
    description: str           # Natural language description
    function_signature: str    # The function signature to implement
    test_cases: List[TestCase] # List of test cases to verify correctness
    title: str = ""            # Short heading for prompts; empty means "use problem_id"

    def __post_init__(self) -> None:
        # Prompts always need a heading, so fall back to the ID when no title is given.
        if not self.title:
            self.title = self.problem_id

    @property
    def function_name(self) -> str:
        """Name of the function declared by ``function_signature``.

        Raises:
            ValueError: if the signature does not start with ``def <name>(``.
        """
        match = re.match(r"\s*def\s+([A-Za-z_]\w*)\s*\(", self.function_signature)
        if match is None:
            raise ValueError(
                f"Cannot find a function name in signature: {self.function_signature!r}"
            )
        return match.group(1)

# A hand-built example problem, to show how Problem and TestCase fit together.
example_problem = Problem(
    problem_id="fibonacci_example",
    problem_type="fibonacci",
    description="Write a function that returns the nth Fibonacci number. F(0)=0, F(1)=1, F(n)=F(n-1)+F(n-2).",
    function_signature="def fibonacci(n: int) -> int:",
    test_cases=[
        TestCase(input={"n": n}, expected_output=fib_n)
        for n, fib_n in ((0, 0), (1, 1), (5, 5), (10, 55))
    ],
)

print("=== Example Problem ===")
for label, value in (
    ("ID", example_problem.problem_id),
    ("Type", example_problem.problem_type),
    ("Description", example_problem.description),
    ("Signature", example_problem.function_signature),
):
    print(f"{label}: {value}")
print(f"\nTest Cases ({len(example_problem.test_cases)} total):")
for number, case in enumerate(example_problem.test_cases, start=1):
    print(f"  {number}. Input: {case.input} → Expected: {case.expected_output}")

=== Example Problem ===
ID: fibonacci_example
Type: fibonacci
Description: Write a function that returns the nth Fibonacci number. F(0)=0, F(1)=1, F(n)=F(n-1)+F(n-2).
Signature: def fibonacci(n: int) -> int:

Test Cases (4 total):
  1. Input: {'n': 0} → Expected: 0
  2. Input: {'n': 1} → Expected: 1
  3. Input: {'n': 5} → Expected: 5
  4. Input: {'n': 10} → Expected: 55


## Part 2: Procedural Problem Generators

In [None]:
# ============================================================================
# PROCEDURAL PROBLEM GENERATORS
# ============================================================================

class ProblemGenerator(ABC):
    """Interface for procedural problem sources.

    Subclasses implement ``generate``; ``generate_batch`` is a convenience
    wrapper that calls it ``n`` times at a single difficulty.

    NOTE(review): ``ProceduralProblem`` (used in these annotations and built by
    the subclasses) is not defined in any earlier cell of this notebook —
    confirm where it is meant to come from.
    """

    @abstractmethod
    def generate(self, difficulty: int = 1) -> ProceduralProblem:
        """Create and return one problem at ``difficulty``."""

    def generate_batch(self, n: int, difficulty: int = 1) -> List[ProceduralProblem]:
        """Return ``n`` problems from successive ``generate(difficulty)`` calls."""
        batch = []
        for _ in range(n):
            batch.append(self.generate(difficulty))
        return batch


class ArithmeticGenerator(ProblemGenerator):
    """Generate problems that ask for the value of a random integer expression.

    Expressions combine integer literals in [1, 20] with ``+``, ``-`` and ``*``
    (optionally parenthesised); ``difficulty`` bounds the recursion depth of the
    expression tree. The reference value is computed here, so each problem has a
    single no-argument test case.
    """

    def __init__(self, seed: Optional[int] = None):
        self.rng = random.Random(seed)
        self.ops = ['+', '-', '*']

    def generate(self, difficulty: int = 1) -> ProceduralProblem:
        """Build one problem whose expression tree is at most ``difficulty`` levels deep."""
        expr = self._generate_expression(difficulty)
        # expr contains only integer literals, + - * and parentheses (see
        # _generate_expression), so evaluating it is safe; builtins are withheld
        # anyway as defence in depth. Never pass externally supplied text here.
        result = eval(expr, {"__builtins__": {}}, {})

        description = f"""Write a Python function `solve()` that returns the result of the following arithmetic expression:

{expr}

The function should take no arguments and return an integer."""

        solution_code = f"def solve():\n    return {result}"
        # Built-in hash() of a str is salted per process (PYTHONHASHSEED), so IDs
        # would change on every run even with a fixed seed; a content digest is stable.
        digest = int(hashlib.sha256(expr.encode("utf-8")).hexdigest(), 16)
        pid = f"arithmetic_{digest % 10**10}"

        return ProceduralProblem(
            id=pid,
            title=f"Arithmetic: {expr}",
            description=description,
            difficulty=difficulty,
            solution_code=solution_code,
            test_cases=[{"input": [], "output": result}],
            function_signature="def solve():",
            problem_type="arithmetic",
        )

    def _generate_expression(self, depth: int) -> str:
        """Return a random expression whose tree is at most ``depth`` levels deep."""
        # `<= 1` rather than `== 1`: with depth 0 or below the base case would
        # never be reached and the recursion could grow arbitrarily deep.
        if depth <= 1:
            return f"{self.rng.randint(1, 20)} {self.rng.choice(self.ops)} {self.rng.randint(1, 20)}"

        left = self._generate_expression(depth - 1) if self.rng.random() > 0.5 else str(self.rng.randint(1, 20))
        right = self._generate_expression(depth - 1) if self.rng.random() > 0.5 else str(self.rng.randint(1, 20))
        op = self.rng.choice(self.ops)

        if self.rng.random() > 0.5:
            return f"({left} {op} {right})"
        return f"{left} {op} {right}"


class RPNGenerator(ProblemGenerator):
    """Generate Reverse Polish Notation evaluation problems.

    Operands are single digits 1-9 and operators are ``+``, ``-``, ``*``. The
    expression is evaluated while it is built, so the reference answer is exact.
    """

    def __init__(self, seed: Optional[int] = None):
        self.rng = random.Random(seed)
        self.ops = {
            '+': lambda x, y: x + y,
            '-': lambda x, y: x - y,
            '*': lambda x, y: x * y
        }

    def generate(self, difficulty: int = 1) -> ProceduralProblem:
        """Build one problem with at least ``3 + 2 * difficulty`` tokens."""
        length = 3 + difficulty * 2
        expression, result = self._generate_rpn(length)

        description = f"""Write a Python function `solve()` that evaluates the following Reverse Polish Notation (RPN) expression and returns the result:

Expression: "{expression}"

In RPN, operators come after their operands. For example:
- "3 4 +" means 3 + 4 = 7
- "3 4 + 5 *" means (3 + 4) * 5 = 35

The function should take no arguments and return the integer result."""

        solution_code = f"def solve():\n    # RPN: {expression}\n    return {result}"
        # Built-in hash() of a str is salted per process (PYTHONHASHSEED); a
        # content digest keeps IDs stable across runs.
        digest = int(hashlib.sha256(expression.encode("utf-8")).hexdigest(), 16)
        pid = f"rpn_{digest % 10**10}"

        return ProceduralProblem(
            id=pid,
            title=f"RPN: {expression}",
            description=description,
            difficulty=difficulty,
            solution_code=solution_code,
            test_cases=[{"input": [], "output": result}],
            function_signature="def solve():",
            problem_type="rpn",
        )

    def _generate_rpn(self, target_length: int) -> Tuple[str, int]:
        """Build a well-formed RPN expression with at least ``target_length`` tokens.

        Returns the space-separated expression and its integer value.
        """
        # At least one operand is needed; without this a non-positive target
        # exits the loop with an empty stack and stack[0] raises IndexError.
        target_length = max(target_length, 1)
        stack: List[int] = []
        tokens: List[str] = []

        def push_operand() -> None:
            num = self.rng.randint(1, 9)
            stack.append(num)
            tokens.append(str(num))

        def apply_operator() -> None:
            op_sym = self.rng.choice(list(self.ops.keys()))
            b = stack.pop()
            a = stack.pop()
            stack.append(self.ops[op_sym](a, b))
            tokens.append(op_sym)

        # Keep going until the target length is reached AND everything has been
        # reduced to a single value.
        while len(tokens) < target_length or len(stack) > 1:
            if len(stack) < 2:
                push_operand()          # an operator needs two operands
            elif len(tokens) >= target_length:
                apply_operator()        # long enough: only reduce from here on
            elif self.rng.random() > 0.6:
                push_operand()
            else:
                apply_operator()

        return " ".join(tokens), stack[0]


class ParenthesesGenerator(ProblemGenerator):
    """Generate bracket-string validation problems.

    Difficulty sets the string length, the number of bracket types
    (<=3: ``()``; 4-6: ``()[]``; >=7: ``()[]{}``) and how often a valid string is
    targeted. The expected answer is always computed by an independent
    stack-based checker, so the label matches the string that was produced.
    """

    def __init__(self, seed: Optional[int] = None):
        self.rng = random.Random(seed)
        self.bracket_pairs = [("(", ")"), ("[", "]"), ("{", "}")]

    def generate(self, difficulty: int = 5) -> ProceduralProblem:
        """Build one bracket-validation problem at ``difficulty``."""
        # Determine parameters based on difficulty
        if difficulty <= 3:
            length = self.rng.randint(4, 10)
            num_types = 1
            valid_prob = 0.6
        elif difficulty <= 6:
            length = self.rng.randint(8, 18)
            num_types = 2
            valid_prob = 0.5
        else:
            length = self.rng.randint(14, 28)
            num_types = 3
            valid_prob = 0.4

        pairs = self.bracket_pairs[:num_types]
        generate_valid = self.rng.random() < valid_prob

        if generate_valid:
            s = self._generate_valid(length, pairs)
        else:
            s = self._generate_invalid(length, pairs)
        # Ground truth comes from actually checking the string, not from which
        # generator produced it; a wrong label here would poison training data.
        expected = self._is_valid(s)

        description = f"""Determine if the following string containing brackets is valid.

A string is valid if:
1. Open brackets must be closed by the same type of brackets
2. Open brackets must be closed in the correct order
3. Every close bracket has a corresponding open bracket

Input string: "{s}"

Write a Python function `solve()` that returns True if valid, False otherwise."""

        solution_code = f"def solve():\n    return {expected}"
        # Built-in hash() of a str is salted per process (PYTHONHASHSEED); a
        # content digest keeps IDs stable across runs.
        digest = int(hashlib.sha256(s.encode("utf-8")).hexdigest(), 16)
        pid = f"parens_{digest % 10**10}"

        return ProceduralProblem(
            id=pid,
            title=f"Parentheses: {s[:20]}{'...' if len(s) > 20 else ''}",
            description=description,
            difficulty=difficulty,
            solution_code=solution_code,
            test_cases=[{"input": [], "output": expected}],
            function_signature="def solve():",
            problem_type="parentheses",
        )

    def _is_valid(self, s: str) -> bool:
        """Return True if every bracket in ``s`` is closed by its own type in order."""
        closer_to_opener = {close: open_ for open_, close in self.bracket_pairs}
        openers = set(closer_to_opener.values())
        stack = []
        for ch in s:
            if ch in openers:
                stack.append(ch)
            elif ch in closer_to_opener:
                if not stack or stack.pop() != closer_to_opener[ch]:
                    return False
        return not stack

    def _generate_valid(self, target_length: int, pairs: List[Tuple[str, str]]) -> str:
        """Return a balanced string of about ``target_length`` characters (odd targets give one fewer)."""
        result = []
        stack = []

        while len(result) < target_length:
            remaining = target_length - len(result)
            can_open = remaining >= 2
            can_close = len(stack) > 0

            if can_open and can_close:
                # Force closing once the open brackets would no longer fit.
                if len(stack) >= remaining // 2:
                    action = "close"
                else:
                    action = self.rng.choice(["open", "close"])
            elif can_open:
                action = "open"
            elif can_close:
                action = "close"
            else:
                break

            if action == "open":
                pair = self.rng.choice(pairs)
                result.append(pair[0])
                stack.append(pair[1])
            else:
                result.append(stack.pop())

        while stack:
            result.append(stack.pop())

        return "".join(result)

    def _generate_invalid(self, target_length: int, pairs: List[Tuple[str, str]]) -> str:
        """Return an unbalanced string of at most ``max(target_length, 1)`` characters.

        Starts from adjacent open/close pairs, then applies exactly one corruption
        that is guaranteed to break validity.
        """
        result = []
        for _ in range((target_length - 1) // 2):
            opener, closer = self.rng.choice(pairs)
            result.append(opener)
            result.append(closer)

        closers = [p[1] for p in pairs]
        # "mismatch" needs a second bracket type and at least one pair to alter;
        # previously it could silently do nothing and leave a valid string.
        strategies = ["extra_close", "unclosed"]
        if len(pairs) > 1 and result:
            strategies.append("mismatch")

        corruption = self.rng.choice(strategies)
        if corruption == "extra_close":
            result.insert(0, self.rng.choice(closers))
        elif corruption == "unclosed":
            result.append(self.rng.choice(pairs)[0])
        else:  # mismatch
            # Pairs were appended open-then-close, so closers sit at odd indices.
            idx = self.rng.randrange(1, len(result), 2)
            result[idx] = self.rng.choice([c for c in closers if c != result[idx]])

        # Length is 2 * ((target_length - 1) // 2) + 1 <= target_length for any
        # target_length >= 1, so no truncation (which could cut off the
        # corruption) is needed.
        return "".join(result)


# Smoke-test each generator with a fixed seed and show its reference answer.
print("Testing generators...")
arith = ArithmeticGenerator(seed=42)
rpn = RPNGenerator(seed=42)
parens = ParenthesesGenerator(seed=42)


def _show_generated(label: str, problem) -> None:
    """Print a generated problem's title and the output of its first test case."""
    print(f"\n{label}: {problem.title}")
    print(f"  Expected: {problem.test_cases[0]['output']}")


p1 = arith.generate(difficulty=3)
_show_generated("Arithmetic", p1)

p2 = rpn.generate(difficulty=3)
_show_generated("RPN", p2)

p3 = parens.generate(difficulty=5)
_show_generated("Parentheses", p3)

## Part 3: Code Verifier (Sandbox)

In [None]:
# ============================================================================
# VERIFICATION SYSTEM
# ============================================================================

class VerificationStatus(Enum):
    """Outcome category of one verification run."""

    PASSED = "passed"    # every test case matched
    FAILED = "failed"    # the script ran but at least one test case did not match
    ERROR = "error"      # crash, unparsable output, or a reported driver error
    TIMEOUT = "timeout"  # the sandbox time limit was hit


@dataclass
class VerificationResult:
    """Summary of running one candidate solution against a problem's tests."""

    status: VerificationStatus
    passed_count: int
    total_count: int
    error_message: Optional[str] = None

    @property
    def passed(self) -> bool:
        """True only when the overall status is PASSED."""
        return self.status is VerificationStatus.PASSED


class PythonSandbox:
    """Run a Python script in a separate interpreter process with a wall-clock timeout.

    NOTE: this is process isolation only (separate interpreter, temp working
    directory, timeout). The child has the notebook's filesystem/network access
    and privileges, so it is not a security boundary for hostile code.
    """

    def __init__(self, timeout: float = 5.0):
        self.timeout = timeout

    def execute(self, code: str) -> Tuple[str, str, int, bool]:
        """Write ``code`` to a temp file, run it, and return (stdout, stderr, returncode, timed_out).

        On timeout returns ``("", "Timeout after <t>s", -1, True)``; on any other
        failure to launch returns ``("", <error text>, -1, False)``. The temp file
        is always removed.
        """
        fd, temp_path = tempfile.mkstemp(suffix=".py", prefix="axiom_")
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                f.write(code)

            result = subprocess.run(
                # Use the interpreter running this notebook; a bare "python" may be
                # missing from PATH or point at a different environment.
                [sys.executable or "python", temp_path],
                capture_output=True,
                text=True,
                # Force a known encoding on both ends so non-ASCII output from
                # generated code cannot raise UnicodeDecodeError here.
                encoding="utf-8",
                errors="replace",
                env={**os.environ, "PYTHONIOENCODING": "utf-8"},
                timeout=self.timeout,
                cwd=tempfile.gettempdir(),
            )
            return result.stdout, result.stderr, result.returncode, False
        except subprocess.TimeoutExpired:
            return "", f"Timeout after {self.timeout}s", -1, True
        except Exception as e:
            return "", str(e), -1, False
        finally:
            try:
                os.unlink(temp_path)
            except OSError:
                # Best-effort cleanup; a leftover temp file is harmless.
                pass


class TestHarness:
    """Run generated solution code against a problem's test cases.

    The solution is pasted into a driver script that calls the target function
    once per test case and prints a JSON report as its final line. The script
    runs in a ``PythonSandbox`` subprocess, and ``verify`` turns its output into
    a ``VerificationResult``.
    """

    def __init__(self, timeout: float = 5.0):
        self.sandbox = PythonSandbox(timeout=timeout)

    def verify(self, solution_code: str, problem: Problem) -> VerificationResult:
        """Verify ``solution_code`` against every test case of ``problem``.

        Returns TIMEOUT if the subprocess hit the time limit. Returns ERROR if it
        crashed before reporting, reported a driver-level error, or printed no
        parsable report. Otherwise returns PASSED when all tests passed and FAILED
        when at least one did not.
        """
        total = len(problem.test_cases)
        test_script = self._build_test_script(solution_code, problem)
        stdout, stderr, returncode, timed_out = self.sandbox.execute(test_script)

        if timed_out:
            return self._error_result(VerificationStatus.TIMEOUT, total, stderr)

        if returncode != 0 and not stdout.strip():
            return self._error_result(VerificationStatus.ERROR, total, stderr or "Unknown error")

        # The solution may print on its own. The driver's JSON report is the last
        # line written, so parse only that instead of the whole stdout.
        lines = [line for line in stdout.splitlines() if line.strip()]
        try:
            data = json.loads(lines[-1] if lines else "")
        except json.JSONDecodeError as e:
            return self._error_result(
                VerificationStatus.ERROR, total, f"Parse error: {stderr or stdout or str(e)}"
            )

        if not isinstance(data, dict):
            return self._error_result(
                VerificationStatus.ERROR, total, f"Parse error: {stderr or stdout}"
            )

        if not data.get("success", False) and "error" in data:
            return self._error_result(VerificationStatus.ERROR, total, data["error"])

        results = data.get("results")
        if not isinstance(results, list):
            return self._error_result(
                VerificationStatus.ERROR, total, f"Parse error: {stderr or stdout}"
            )

        passed_count = sum(1 for r in results if isinstance(r, dict) and r.get("passed"))
        total_count = len(results)

        return VerificationResult(
            status=VerificationStatus.PASSED if passed_count == total_count else VerificationStatus.FAILED,
            passed_count=passed_count,
            total_count=total_count,
        )

    @staticmethod
    def _error_result(status: VerificationStatus, total: int, message: Optional[str]) -> VerificationResult:
        """Build a zero-pass result carrying ``message``."""
        return VerificationResult(
            status=status,
            passed_count=0,
            total_count=total,
            error_message=message,
        )

    @staticmethod
    def _function_name(problem: Problem) -> str:
        """Name of the function to call: ``problem.function_name`` or parsed from the signature."""
        name = getattr(problem, "function_name", None)
        if name:
            return name
        match = re.match(r"\s*def\s+([A-Za-z_]\w*)", problem.function_signature)
        if match is None:
            raise ValueError(
                f"Cannot find a function name in signature: {problem.function_signature!r}"
            )
        return match.group(1)

    def _build_test_script(self, solution_code: str, problem: Problem) -> str:
        """Return a standalone script: the solution followed by a JSON-reporting driver.

        Dict inputs are passed as keyword arguments, list inputs as positional
        arguments, and anything else as a single argument.
        """
        test_cases_json = json.dumps([
            {"input": tc.input, "expected": tc.expected_output}
            for tc in problem.test_cases
        ])
        func_name = self._function_name(problem)

        # test_cases_json is embedded via repr() so quotes and backslashes inside
        # the JSON survive as a valid Python string literal.
        return f'''# -*- coding: utf-8 -*-
import json
from typing import List, Optional, Tuple, Dict, Any, Set

# === SOLUTION CODE ===
{solution_code}
# === END SOLUTION ===

def _axiom_normalize(value):
    # Expected values were decoded from JSON, so compare in JSON space
    # (tuples become lists, dict keys become strings).
    return json.loads(json.dumps(value))

def run_tests():
    test_cases = json.loads({test_cases_json!r})
    results = []
    
    for i, tc in enumerate(test_cases):
        inp = tc["input"]
        expected = tc["expected"]
        
        try:
            if isinstance(inp, dict):
                actual = {func_name}(**inp)
            elif isinstance(inp, list):
                actual = {func_name}(*inp)
            else:
                actual = {func_name}(inp)
            
            passed = _axiom_normalize(actual) == expected
            results.append({{"index": i, "passed": passed, "actual": actual, "expected": expected}})
        except Exception as e:
            results.append({{"index": i, "passed": False, "error": str(e)}})
    
    print(json.dumps({{"results": results, "success": True}}))

if __name__ == "__main__":
    try:
        run_tests()
    except Exception as e:
        print(json.dumps({{"results": [], "success": False, "error": str(e)}}))
'''


# Smoke-test the harness on one arithmetic problem: the reference answer
# should pass and an arbitrary wrong constant should fail.
print("Testing verifier...")
harness = TestHarness(timeout=5.0)

test_problem = arith.generate(difficulty=2).to_problem()
reference_answer = test_problem.test_cases[0].expected_output
correct_code = f"def solve():\n    return {reference_answer}"
wrong_code = "def solve():\n    return 999999"

result1 = harness.verify(correct_code, test_problem)
result2 = harness.verify(wrong_code, test_problem)

for label, outcome in (("Correct", result1), ("Wrong", result2)):
    print(f"{label} solution: {outcome.status.value}")

## Part 4: Code Generator (Model)

In [None]:
# ============================================================================
# CODE GENERATOR
# ============================================================================

# System message used both for generation and for SFT formatting, so the model
# is trained and queried under the same instructions.
SYSTEM_PROMPT = "\n".join([
    "You are an expert Python programmer. Your task is to solve algorithmic problems "
    "by writing clean, efficient, and correct Python code.",
    "",
    "Rules:",
    "1. Write ONLY the function implementation - no explanations, no test code",
    "2. The function signature is provided - implement the function body",
    "3. Use standard Python libraries only",
    "4. Write clear, readable code",
    "5. Handle edge cases appropriately",
    "6. Return the result as specified",
])


def build_user_prompt(problem: "Problem") -> str:
    """Render the user turn asking the model to implement ``problem``'s function.

    The heading uses ``problem.title`` when it exists and is non-empty, otherwise
    ``problem.problem_id``. Reading ``title`` unconditionally raises
    AttributeError for problem objects that do not define it.
    """
    title = getattr(problem, "title", None) or problem.problem_id
    return f"""Solve the following problem by implementing the function.

## Problem: {title}

{problem.description}

## Function Signature
```python
{problem.function_signature}
    # Your implementation here
```

Write ONLY the complete function implementation. Do not include any explanations, examples, or test code."""


def build_messages(problem: Problem) -> List[dict]:
    """Return the chat transcript (system turn, then user turn) sent to the model."""
    system_turn = {"role": "system", "content": SYSTEM_PROMPT}
    user_turn = {"role": "user", "content": build_user_prompt(problem)}
    return [system_turn, user_turn]


def extract_code_from_response(response: str) -> str:
    """Extract Python source from a model reply.

    Order of preference:
    1. the first closed ```python fenced block;
    2. the first closed ``` fenced block of any kind, dropping a leading
       ``python``/``py`` language line;
    3. an unterminated fence (typical when generation stops at
       ``max_new_tokens``): the text after the opening fence, minus a language line;
    4. otherwise the whole reply, stripped.
    """
    # Try ```python blocks
    matches = re.findall(r"```python\s*(.*?)```", response, re.DOTALL)
    if matches:
        return matches[0].strip()

    # Try generic ``` blocks
    matches = re.findall(r"```\s*(.*?)```", response, re.DOTALL)
    if matches:
        code = matches[0].strip()
        lines = code.split("\n")
        if lines and lines[0].strip().lower() in ["python", "py", ""]:
            code = "\n".join(lines[1:])
        return code.strip()

    # No closed fence exists, so at most one ``` marker is present. Keep what
    # follows it rather than returning text that starts with backticks, which
    # can never be valid Python.
    fence = response.find("```")
    if fence != -1:
        after_fence = response[fence + 3:]
        first_line, _, rest = after_fence.partition("\n")
        if first_line.strip().lower() in ("python", "py", ""):
            return rest.strip()
        return after_fence.strip()

    # Return as-is
    return response.strip()


class CodeGenerator:
    """Hugging Face causal-LM wrapper that turns chat messages into code solutions.

    The tokenizer and model are loaded lazily on first use (or via ``load()``).
    """

    def __init__(
        self,
        model_name: str = "Qwen/Qwen2.5-Coder-1.5B-Instruct",
        device: str = "auto",
        temperature: float = 0.7,
        max_new_tokens: int = 512,
    ):
        self.model_name = model_name
        self.device = device              # passed to from_pretrained as device_map
        self.temperature = temperature    # default for generate(); <= 0 means greedy
        self.max_new_tokens = max_new_tokens
        self.model = None
        self.tokenizer = None
        self._loaded = False

    def load(self):
        """Load the tokenizer and the float16 model once; later calls are no-ops."""
        if self._loaded:
            return

        print(f"Loading model: {self.model_name}")

        # NOTE: trust_remote_code=True executes Python shipped with the model
        # repository; only use it with repositories you trust.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            trust_remote_code=True,
        )
        if self.tokenizer.pad_token is None:
            # Some tokenizers ship without a pad token; reuse EOS so generate()
            # always has a pad_token_id.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16,
            device_map=self.device,
            trust_remote_code=True,
        )

        self._loaded = True
        device = next(self.model.parameters()).device
        print(f"Model loaded on: {device}")

    def generate(self, messages: List[dict], temperature: Optional[float] = None) -> str:
        """Generate a reply to ``messages`` and return only the newly generated text.

        Args:
            messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
            temperature: Overrides ``self.temperature`` for this call. A value
                <= 0 selects greedy decoding.
        """
        if not self._loaded:
            self.load()

        temp = temperature if temperature is not None else self.temperature

        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # The rendered template already contains the model's special tokens;
        # letting the tokenizer add them again (e.g. a second BOS) shifts inputs.
        model_inputs = self.tokenizer(
            [text], return_tensors="pt", add_special_tokens=False
        ).to(self.model.device)
        input_length = model_inputs.input_ids.shape[1]

        # Sampling requires temperature > 0 (transformers rejects 0), so treat a
        # non-positive temperature as a request for greedy decoding.
        if temp > 0:
            decoding_kwargs = {"do_sample": True, "temperature": temp, "top_p": 0.95}
        else:
            decoding_kwargs = {"do_sample": False}

        with torch.no_grad():
            generated_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=self.max_new_tokens,
                pad_token_id=self.tokenizer.pad_token_id,
                **decoding_kwargs,
            )

        # Drop the prompt tokens and decode only the continuation.
        new_tokens = generated_ids[0][input_length:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)

    def generate_solution(self, problem: Problem) -> str:
        """Generate a solution for ``problem`` and return the extracted code."""
        messages = build_messages(problem)
        response = self.generate(messages)
        return extract_code_from_response(response)


print("CodeGenerator class defined.")

## Part 5: Training Dataset & LoRA Config

In [None]:
# ============================================================================
# TRAINING DATASET & LORA CONFIG
# ============================================================================

@dataclass
class TrainingSample:
    """One verified (problem, solution) pair used as an SFT example."""
    problem_id: str
    problem_title: str
    problem_description: str
    function_signature: str
    solution_code: str
    model_name: str = "student"  # label for the model that produced solution_code

    def to_prompt_completion(self) -> dict:
        """Return ``{"prompt": ..., "completion": ...}`` for chat-formatted SFT.

        The prompt mirrors the user prompt built by ``build_user_prompt`` at
        inference time (keep the two in sync): training on a different prompt
        than the one used for sampling creates a train/inference mismatch. The
        completion wraps the solution in a ```python fence, which
        ``extract_code_from_response`` strips at inference.
        """
        user_prompt = f"""Solve the following problem by implementing the function.

## Problem: {self.problem_title}

{self.problem_description}

## Function Signature
```python
{self.function_signature}
    # Your implementation here
```

Write ONLY the complete function implementation. Do not include any explanations, examples, or test code."""

        completion = f"```python\n{self.solution_code}\n```"
        return {"prompt": user_prompt, "completion": completion}


class SFTDataset(Dataset):
    """Tokenised chat-formatted SFT examples built from ``TrainingSample`` objects.

    Each sample becomes system/user/assistant messages rendered with the
    tokenizer's chat template, then truncated/padded to ``max_length``. All
    samples are tokenised eagerly in ``__init__``.
    """

    def __init__(self, samples: List[TrainingSample], tokenizer, max_length: int = 1024):
        self.samples = samples
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.processed = [self._process(s) for s in samples]

    def _process(self, sample: TrainingSample) -> dict:
        """Return ``input_ids``, ``attention_mask`` and ``labels`` tensors for one sample."""
        pc = sample.to_prompt_completion()

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": pc["prompt"]},
            {"role": "assistant", "content": pc["completion"]},
        ]

        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )

        # The rendered template already contains the model's special tokens;
        # letting the tokenizer add them again (e.g. a second BOS) skews training.
        encodings = self.tokenizer(
            text,
            add_special_tokens=False,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        input_ids = encodings["input_ids"].squeeze(0)
        attention_mask = encodings["attention_mask"].squeeze(0)
        labels = input_ids.clone()
        # Padding must not contribute to the loss. -100 is the ignore_index of the
        # Hugging Face causal-LM loss. Masking is by attention_mask rather than by
        # pad id because pad_token may equal eos_token, and real EOS tokens in the
        # text should still be learned.
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def __len__(self):
        return len(self.processed)

    def __getitem__(self, idx):
        return self.processed[idx]


def get_lora_config(
    r: int = 16,
    alpha: int = 32,
    dropout: float = 0.05,
    target_modules: Optional[List[str]] = None,
) -> LoraConfig:
    """Build the LoRA adapter configuration used for fine-tuning.

    Args:
        r: LoRA rank.
        alpha: ``lora_alpha`` scaling factor.
        dropout: Dropout applied inside the LoRA branches.
        target_modules: Module names to wrap with adapters. Defaults to the
            attention and MLP projection names of Qwen2/Llama-style decoders
            (the default base model in this notebook is Qwen2.5). Pass a
            different list for other architectures.

    NOTE(review): ``LoraConfig`` and ``TaskType`` (from PEFT) are not imported in
    any visible cell; something like ``from peft import LoraConfig, TaskType``
    must run before this cell.
    """
    if target_modules is None:
        target_modules = [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ]
    return LoraConfig(
        r=r,
        lora_alpha=alpha,
        target_modules=list(target_modules),
        lora_dropout=dropout,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )


print("Training dataset and LoRA config defined.")

## Part 6: Self-Improvement Loop

In [None]:
# ============================================================================
# SELF-IMPROVEMENT EXPERIMENT
# ============================================================================

@dataclass
class SelfImproveConfig:
    """All settings for one expert-iteration run: model, problem mix, training and sampling."""

    # --- Model -----------------------------------------------------------
    base_model: str = "Qwen/Qwen2.5-Coder-1.5B-Instruct"

    # --- Problem sampling ------------------------------------------------
    # Each name is used as a key into SelfImproveExperiment.generators.
    problem_types: List[str] = field(
        default_factory=["arithmetic", "rpn", "parentheses"].copy
    )
    train_size: int = 50
    val_size: int = 10
    test_size: int = 10
    difficulty_min: int = 3    # inclusive lower bound for sampled difficulty
    difficulty_max: int = 7    # inclusive upper bound for sampled difficulty
    seed: int = 42

    # --- Fine-tuning -----------------------------------------------------
    num_iterations: int = 3
    learning_rate: float = 1e-4
    epochs_per_iteration: int = 1
    batch_size: int = 1
    gradient_accumulation: int = 4
    lora_r: int = 16
    lora_alpha: int = 32

    # --- Sampling --------------------------------------------------------
    temperature: float = 0.7
    max_new_tokens: int = 512


class SelfImproveExperiment:
    """Self-improvement experiment using Expert Iteration.

    Each iteration evaluates the current model on fixed train/val/test
    problem sets, samples one solution per train problem, keeps the ones
    that pass ``self.harness``, and fine-tunes LoRA adapters on them (mixed
    with a replay sample of solutions from earlier iterations). After the
    last iteration the model is evaluated once more, so the final round of
    training is reflected in ``metrics_history``.
    """

    def __init__(self, config: SelfImproveConfig):
        self.config = config
        # Shared RNG for choosing problem types/difficulties and for
        # sampling from the replay buffer.
        self.rng = random.Random(config.seed)

        # One procedural generator per supported problem type.
        self.generators = {
            "arithmetic": ArithmeticGenerator(seed=config.seed),
            "rpn": RPNGenerator(seed=config.seed),
            "parentheses": ParenthesesGenerator(seed=config.seed),
        }

        # Components
        self.harness = TestHarness(timeout=5.0)
        self.model = None       # base model, set by load_model()
        self.tokenizer = None   # set by load_model()
        self.peft_model = None  # LoRA-wrapped model, created on the first training round

        # History
        self.metrics_history = []       # one accuracy dict per evaluation round
        self.historical_solutions = []  # every verified solution so far (replay buffer)

    # ------------------------------------------------------------------
    # Problems and model
    # ------------------------------------------------------------------

    def generate_problems(self, n: int) -> List[Problem]:
        """Generate ``n`` problems with randomly chosen types and difficulties.

        Types are drawn from ``config.problem_types`` and difficulties from
        ``[difficulty_min, difficulty_max]`` (inclusive, per ``random.randint``).
        """
        problems = []
        types = self.config.problem_types

        for _ in range(n):
            ptype = self.rng.choice(types)
            difficulty = self.rng.randint(self.config.difficulty_min, self.config.difficulty_max)
            proc_problem = self.generators[ptype].generate(difficulty=difficulty)
            problems.append(proc_problem.to_problem())

        return problems

    def load_model(self):
        """Load the base model (fp16, ``device_map="auto"``) and its tokenizer.

        If the tokenizer has no pad token, its EOS token is reused as padding.
        """
        print(f"Loading model: {self.config.base_model}")

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.base_model,
            trust_remote_code=True,
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            self.config.base_model,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )

        device = next(self.model.parameters()).device
        print(f"Model loaded on: {device}")

    def generate_solution(self, problem: Problem) -> str:
        """Generate one completion for ``problem`` and return the extracted code.

        Uses the LoRA-wrapped model once training has started, otherwise the
        base model. With ``config.temperature > 0`` the completion is sampled
        (nucleus ``top_p=0.95``); a temperature of 0 selects greedy decoding,
        because sampling requires a strictly positive temperature.

        Raises:
            RuntimeError: if :meth:`load_model` has not been called yet.
        """
        model = self.peft_model if self.peft_model is not None else self.model
        if model is None or self.tokenizer is None:
            raise RuntimeError("Model not loaded; call load_model() first.")
        model.eval()

        messages = build_messages(problem)
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        inputs = self.tokenizer([text], return_tensors="pt").to(model.device)
        input_length = inputs.input_ids.shape[1]

        gen_kwargs = {
            "max_new_tokens": self.config.max_new_tokens,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        if self.config.temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=self.config.temperature, top_p=0.95)
        else:
            gen_kwargs["do_sample"] = False

        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)

        # Decode only the newly generated tokens, not the echoed prompt.
        new_tokens = outputs[0][input_length:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return extract_code_from_response(response)

    # ------------------------------------------------------------------
    # Evaluation and data collection
    # ------------------------------------------------------------------

    def _attempt(self, problem: Problem) -> Tuple[str, bool]:
        """Generate one solution for ``problem``; return ``(code, passed_verification)``."""
        code = self.generate_solution(problem)
        result = self.harness.verify(code, problem)
        return code, result.passed

    @staticmethod
    def _accuracy(correct: int, total: int) -> float:
        """Return ``correct / total``, or NaN for an empty split instead of dividing by zero."""
        return correct / total if total else float("nan")

    @staticmethod
    def _problem_id(problem: Problem) -> str:
        """Return the problem's identifier.

        NOTE(review): the ``Problem`` dataclass at the top of this notebook
        declares ``problem_id``; ``id`` is accepted as a fallback in case the
        object returned by ``to_problem()`` exposes that name instead.
        """
        pid = getattr(problem, "problem_id", None)
        return pid if pid is not None else problem.id

    def evaluate(self, problems: List[Problem], desc: str = "Evaluating") -> Tuple[int, int]:
        """Evaluate the current model with one attempt per problem. Returns ``(correct, total)``."""
        correct = 0
        for problem in tqdm(problems, desc=desc):
            _, passed = self._attempt(problem)
            if passed:
                correct += 1
        return correct, len(problems)

    def collect_solutions(self, problems: List[Problem]) -> List[dict]:
        """Generate one solution per problem and return those that pass verification.

        Each returned dict carries the fields ``train_on_solutions`` needs to
        build a ``TrainingSample``.
        """
        solutions = []
        for problem in tqdm(problems, desc="Collecting solutions"):
            code, passed = self._attempt(problem)
            if not passed:
                continue
            solutions.append({
                "problem_id": self._problem_id(problem),
                "problem_title": problem.title,
                "problem_description": problem.description,
                "function_signature": problem.function_signature,
                "solution_code": code,
                "model_name": "student",
            })
        return solutions

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    def _new_lora_config(self):
        """Build a LoRA config from this experiment's ``lora_r`` / ``lora_alpha``."""
        return get_lora_config(r=self.config.lora_r, alpha=self.config.lora_alpha)

    def _mix_with_replay(self, solutions: List[dict]) -> List[dict]:
        """Add ``solutions`` to the replay buffer and return the training mix.

        The mix is the new solutions plus up to ``len(solutions)`` solutions
        sampled without replacement from earlier iterations. ``solutions``
        must be non-empty (``[:-0]`` would otherwise hide the whole buffer).
        """
        self.historical_solutions.extend(solutions)
        previous = self.historical_solutions[:-len(solutions)]  # exclude the ones just added
        if not previous:
            return list(solutions)
        historical_sample = self.rng.sample(previous, min(len(solutions), len(previous)))
        print(f"    Mixed: {len(solutions)} new + {len(historical_sample)} historical")
        return solutions + historical_sample

    def train_on_solutions(self, solutions: List[dict], iteration: int):
        """Fine-tune LoRA adapters on verified solutions, then merge them into the model.

        After training, the adapters are merged into the base weights and a
        fresh LoRA wrapper is applied, so the next round starts from the
        merged model. Does nothing (besides a message) when ``solutions`` is empty.
        """
        if not solutions:
            print("  No solutions to train on, skipping")
            return

        print(f"  Training on {len(solutions)} solutions...")
        all_solutions = self._mix_with_replay(solutions)

        samples = [
            TrainingSample(
                problem_id=s["problem_id"],
                problem_title=s["problem_title"],
                problem_description=s["problem_description"],
                function_signature=s["function_signature"],
                solution_code=s["solution_code"],
                model_name=s["model_name"],
            )
            for s in all_solutions
        ]
        dataset = SFTDataset(samples, self.tokenizer, max_length=1024)

        # Wrap the base model with LoRA on the first training round.
        if self.peft_model is None:
            self.peft_model = get_peft_model(self.model, self._new_lora_config())

        self.peft_model.train()

        training_args = TrainingArguments(
            output_dir=f"./checkpoints/iter_{iteration}",
            num_train_epochs=self.config.epochs_per_iteration,
            per_device_train_batch_size=self.config.batch_size,
            gradient_accumulation_steps=self.config.gradient_accumulation,
            learning_rate=self.config.learning_rate,
            warmup_ratio=0.1,
            # fp16 mixed precision needs a GPU; TrainingArguments rejects it otherwise.
            fp16=torch.cuda.is_available(),
            logging_steps=10,
            save_strategy="no",
            report_to="none",
            remove_unused_columns=False,
        )

        trainer = Trainer(
            model=self.peft_model,
            args=training_args,
            train_dataset=dataset,
        )
        trainer.train()

        # Merge LoRA weights for faster inference, then re-wrap for the next round.
        print("  Merging LoRA weights...")
        self.peft_model = self.peft_model.merge_and_unload()
        self.peft_model = get_peft_model(self.peft_model, self._new_lora_config())

        print("  Training complete.")

    # ------------------------------------------------------------------
    # Main loop
    # ------------------------------------------------------------------

    def _evaluate_round(
        self,
        iteration: int,
        train_problems: List[Problem],
        val_problems: List[Problem],
        test_problems: List[Problem],
    ) -> dict:
        """Evaluate every split, print a results table, and append it to ``metrics_history``."""
        splits = (
            ("train", "Train", train_problems),
            ("val", "Val", val_problems),
            ("test", "Test", test_problems),
        )
        # Run all evaluations first so the progress bars precede the table.
        counts = {key: self.evaluate(problems, desc) for key, desc, problems in splits}

        metrics = {"iteration": iteration}
        print("\n  Results:")
        for key, desc, _ in splits:
            correct, total = counts[key]
            acc = self._accuracy(correct, total)
            metrics[f"{key}_acc"] = acc
            print(f"    {desc + ':':<6} {acc:.1%} ({correct}/{total})")

        self.metrics_history.append(metrics)
        return metrics

    def run(self):
        """Run the self-improvement loop and return ``metrics_history``.

        ``metrics_history`` holds one entry per training iteration (measured
        before that iteration's training) plus a final entry, with
        ``iteration == num_iterations``, measured after the last round.
        """
        print("="*60)
        print("SELF-IMPROVEMENT EXPERIMENT")
        print("="*60)
        print(f"Model: {self.config.base_model}")
        print(f"Problems: {self.config.train_size} train, {self.config.val_size} val, {self.config.test_size} test")
        print(f"Iterations: {self.config.num_iterations}")
        print("="*60)

        # Generate fixed problem sets (reused by every iteration).
        print("\nGenerating problems...")
        train_problems = self.generate_problems(self.config.train_size)
        val_problems = self.generate_problems(self.config.val_size)
        test_problems = self.generate_problems(self.config.test_size)
        print(f"  Train: {len(train_problems)}, Val: {len(val_problems)}, Test: {len(test_problems)}")

        self.load_model()

        for iteration in range(self.config.num_iterations):
            print(f"\n{'='*60}")
            print(f"ITERATION {iteration}")
            print("="*60)

            self._evaluate_round(iteration, train_problems, val_problems, test_problems)

            print("\n  Collecting solutions...")
            solutions = self.collect_solutions(train_problems)
            print(f"    Collected {len(solutions)} correct solutions")

            self.train_on_solutions(solutions, iteration)

        # Evaluate once more so the last round of training is actually measured.
        final_iteration = self.config.num_iterations
        print(f"\n{'='*60}")
        print(f"FINAL EVALUATION (iteration {final_iteration})")
        print("="*60)
        self._evaluate_round(final_iteration, train_problems, val_problems, test_problems)

        print(f"\n{'='*60}")
        print("FINAL SUMMARY")
        print("="*60)
        for m in self.metrics_history:
            print(f"Iter {m['iteration']}: Train {m['train_acc']:.1%}, Val {m['val_acc']:.1%}, Test {m['test_acc']:.1%}")

        return self.metrics_history


print("SelfImproveExperiment class defined.")

## Part 7: Run the Experiment

In [None]:
# ============================================================================
# RUN EXPERIMENT
# ============================================================================

# Configure the experiment, with keyword arguments grouped by concern.
config = SelfImproveConfig(
    # Model and problem mix
    base_model="Qwen/Qwen2.5-Coder-1.5B-Instruct",
    problem_types=["arithmetic", "rpn", "parentheses"],
    difficulty_min=3,
    difficulty_max=7,
    seed=42,
    # Number of problems in each split
    train_size=30,
    val_size=10,
    test_size=10,
    # Expert-iteration schedule and optimizer settings
    num_iterations=3,
    epochs_per_iteration=1,
    learning_rate=1e-4,
    batch_size=1,
    gradient_accumulation=4,
)

# Build the experiment and run the generate -> verify -> train loop.
experiment = SelfImproveExperiment(config)
results = experiment.run()

In [None]:
# ============================================================================
# VISUALIZE RESULTS
# ============================================================================

import matplotlib.pyplot as plt

# Accuracies are stored as fractions; convert each split to percentages.
iterations = [m["iteration"] for m in results]
train_accs, val_accs, test_accs = (
    [m[key] * 100 for m in results]
    for key in ("train_acc", "val_acc", "test_acc")
)

plt.figure(figsize=(10, 6))
for accs, fmt, label in (
    (train_accs, 'b-o', 'Train'),
    (val_accs, 'g-s', 'Validation'),
    (test_accs, 'r-^', 'Test'),
):
    plt.plot(iterations, accs, fmt, label=label, linewidth=2, markersize=8)

plt.xlabel('Iteration', fontsize=12)
plt.ylabel('Accuracy (%)', fontsize=12)
plt.title('Self-Improvement Learning Curve', fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.xticks(iterations)
plt.ylim(0, 100)

plt.tight_layout()
plt.show()

# Print a plain-text summary alongside the plot.
print("\nResults Summary:")
print("-" * 50)
for entry in results:
    print(
        f"Iteration {entry['iteration']}: "
        f"Train={entry['train_acc']:.1%}, "
        f"Val={entry['val_acc']:.1%}, "
        f"Test={entry['test_acc']:.1%}"
    )

## Part 8: Test Individual Components (Optional)

In [None]:
# ============================================================================
# TEST INDIVIDUAL COMPONENTS
# ============================================================================

# Test problem generation: draw one difficulty-5 problem from each generator
# and show its title and the expected output of its first test case.
print("Testing problem generation...")
for ptype in ("arithmetic", "rpn", "parentheses"):
    sample = experiment.generators[ptype].generate(difficulty=5)
    first_case = sample.test_cases[0]
    print(f"\n{ptype.upper()}:")
    print(f"  Title: {sample.title}")
    print(f"  Expected: {first_case['output']}")

In [None]:
# Smoke-test the full generate -> verify path on one easy arithmetic problem.
print("Testing model generation...")

arith_generator = experiment.generators["arithmetic"]
test_prob = arith_generator.generate(difficulty=3).to_problem()
first_case = test_prob.test_cases[0]
print(f"\nProblem: {test_prob.title}")
print(f"Expected: {first_case.expected_output}")

code = experiment.generate_solution(test_prob)
print(f"\nGenerated code:\n{code}")

result = experiment.harness.verify(code, test_prob)
print(f"\nVerification: {result.status.value}")