# Data Splitting for GEPA Optimization

This notebook covers strategies for splitting a benchmark into train, validation, and test sets for GEPA optimization. Each set plays a distinct role:

1. **Training**: Questions used to optimize prompts
2. **Validation**: Questions used to evaluate candidates during optimization
3. **Testing**: Held-out questions used only for final evaluation (reveals overfitting to the training and validation questions)

---

## Setup

In [None]:
import sys
from pathlib import Path

# Make the local karenina source tree importable (adjust to your checkout layout)
sys.path.insert(0, str(Path.cwd().parent.parent.parent / "src"))

from karenina import Benchmark
from karenina.integrations.gepa import (
    BenchmarkSplit,
    questions_to_data_insts,
    split_benchmark,
    split_by_attribute,
)

# Load the AIME 2025 benchmark (adjust the path to your local checkpoint)
benchmark_path = Path.home() / "Projects/karenina-monorepo/local_data/data/checkpoints/aime_2025.jsonld"
benchmark = Benchmark.load(benchmark_path)

print(f"Loaded: {benchmark.name}")
print(f"Total questions: {len(benchmark.get_question_ids())}")

---

## split_benchmark(): Random Splitting

`split_benchmark()` is the primary function for splitting a benchmark at random, with configurable ratios.

### Basic 80/20 Split (Default)

In [None]:
# Default 80% train, 20% validation
split = split_benchmark(benchmark)

print(split.summary())
print(f"\nTrain questions: {len(split.train)}")
print(f"Val questions: {len(split.val)}")
print(f"Test questions: {len(split.test) if split.test else 0}")

### 70/20/10 Split (With Test Set)

In [None]:
# 70% train, 20% val, 10% test
split_with_test = split_benchmark(
    benchmark,
    train_ratio=0.7,
    val_ratio=0.2,
    test_ratio=0.1,
)

print(split_with_test.summary())
print(f"\nTest set size: {len(split_with_test.test)} questions")
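
To make the ratio arithmetic concrete, here is a minimal pure-Python sketch of how ratios could translate into set sizes. The helper `split_counts` is hypothetical and is not part of karenina; `split_benchmark()` may round differently, so compare against `split.summary()` rather than relying on these numbers.

```python
def split_counts(n_items, train_ratio, val_ratio, test_ratio=0.0):
    """Turn split ratios into integer set sizes that sum to n_items.

    Hypothetical helper for illustration only; the rounding rule is an
    assumption, not karenina's documented behaviour.
    """
    total = train_ratio + val_ratio + test_ratio
    if abs(total - 1.0) > 1e-9:
        raise ValueError(f"Ratios must sum to 1.0, got {total}")
    # Round val and test, then give the remainder to train so nothing is lost.
    n_val = round(n_items * val_ratio)
    n_test = round(n_items * test_ratio)
    n_train = n_items - n_val - n_test
    return n_train, n_val, n_test


# Assuming a 30-question benchmark:
print(split_counts(30, 0.7, 0.2, 0.1))  # (21, 6, 3)
print(split_counts(30, 0.8, 0.2))       # (24, 6, 0)
```

With a benchmark this small, a 10% test set holds only a handful of questions, so a single question can move the test score by a large margin.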

### Reproducible Splitting with Seed

In [None]:
# Use seed for reproducibility
split_a = split_benchmark(benchmark, seed=42)
split_b = split_benchmark(benchmark, seed=42)

# Same seed = same split
print(f"Same seed produces same split: {split_a.train_ids == split_b.train_ids}")

# A different seed almost always gives a different split (not strictly guaranteed)
split_c = split_benchmark(benchmark, seed=123)
print(f"Different seed changed the split: {split_a.train_ids != split_c.train_ids}")
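
The behaviour above is what a seeded shuffle typically gives. The sketch below illustrates the idea with Python's standard `random.Random`; it is an assumption about the general technique, not a description of how `split_benchmark()` is implemented, and `shuffled_ids` is a hypothetical name.

```python
import random


def shuffled_ids(ids, seed=None):
    """Return a shuffled copy of ids using a private, seeded generator."""
    rng = random.Random(seed)  # local RNG, so the global random state is untouched
    out = list(ids)            # copy, so the caller's list keeps its order
    rng.shuffle(out)
    return out


ids = [f"q{i}" for i in range(10)]  # made-up question IDs
a = shuffled_ids(ids, seed=42)
b = shuffled_ids(ids, seed=42)

print(a == b)                    # True: same seed, same order
print(sorted(a) == sorted(ids))  # True: shuffling only reorders
```

Two different seeds can still produce the same ordering by chance, which becomes more likely as the benchmark gets smaller, so the inequality check in the cell above is a sanity check rather than a guarantee.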

---

## BenchmarkSplit: Working with Splits

The `BenchmarkSplit` dataclass provides convenient access to split data.

In [None]:
split = split_benchmark(
    benchmark,
    train_ratio=0.7,
    val_ratio=0.2,
    test_ratio=0.1,
    seed=42,
)

# Access train/val/test lists
print(f"Training set: {len(split.train)} KareninaDataInst objects")
print(f"Validation set: {len(split.val)} KareninaDataInst objects")
print(f"Test set: {len(split.test)} KareninaDataInst objects")

# Get question IDs directly
print(f"\nTrain IDs (first 3): {split.train_ids[:3]}")
print(f"Val IDs (first 2): {split.val_ids[:2]}")
print(f"Test IDs: {split.test_ids}")

# Check the seed
print(f"\nRandom seed: {split.seed}")

---

## KareninaDataInst: Question Data

Each split contains `KareninaDataInst` objects with all question data needed for GEPA.

In [None]:
# Inspect a data instance
inst = split.train[0]

print("KareninaDataInst fields:")
print(f"  question_id: {inst.question_id[:60]}...")
print(f"  question_text: {inst.question_text[:80]}...")
print(f"  raw_answer: {inst.raw_answer}")
print(f"  template_code: {len(inst.template_code)} chars")
print(f"  rubric: {inst.rubric}")
print(f"  few_shot_examples: {inst.few_shot_examples}")
print(f"  metadata: {inst.metadata}")

In [None]:
# View the template code
print("Template code:")
print(inst.template_code)

In [None]:
# Convert to dict (for GEPA)
inst_dict = inst.to_dict()

print("As dict (keys):")
for key in inst_dict:
    print(f"  {key}")

---

## Stratified Splitting

Stratified splitting preserves the distribution of a metadata attribute across splits. Pass the name of the metadata field as `stratify_by`.

In [None]:
# AIME questions have 'custom_part' metadata (AIME_I or AIME_II)
# Check the metadata
question = benchmark.get_question(benchmark.get_question_ids()[0])
print("Sample question metadata:")
for key, value in question.items():
    if key not in ["question", "raw_answer", "template_code"]:
        print(f"  {key}: {value}")

In [None]:
# Stratify by 'custom_part' (AIME_I vs AIME_II)
# This ensures both AIME I and AIME II problems appear in each split
stratified_split = split_benchmark(
    benchmark,
    train_ratio=0.7,
    val_ratio=0.2,
    test_ratio=0.1,
    seed=42,
    stratify_by="custom_part",  # Stratify by this metadata field
)

print(stratified_split.summary())


# Check distribution in each split
def count_by_part(insts):
    counts = {}
    for inst in insts:
        part = inst.metadata.get("custom_part", "unknown")
        counts[part] = counts.get(part, 0) + 1
    return counts


print(f"\nTrain distribution: {count_by_part(stratified_split.train)}")
print(f"Val distribution: {count_by_part(stratified_split.val)}")
print(f"Test distribution: {count_by_part(stratified_split.test)}")
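
For intuition, stratification can be sketched in plain Python as "group by the attribute, then split each group with the same ratios". The function below is a hypothetical illustration on made-up data; karenina's actual grouping and rounding rules are not shown in this notebook and may differ.

```python
import random
from collections import defaultdict


def stratified_split(items, key, ratios=(0.7, 0.2, 0.1), seed=None):
    """Split items into (train, val, test), splitting each stratum separately."""
    rng = random.Random(seed)
    groups = defaultdict(list)
    for item in items:
        groups[key(item)].append(item)

    train, val, test = [], [], []
    for label in sorted(groups):  # sorted so the result does not depend on input order of labels
        members = groups[label]
        rng.shuffle(members)
        n_val = round(len(members) * ratios[1])
        n_test = round(len(members) * ratios[2])
        n_train = len(members) - n_val - n_test
        train += members[:n_train]
        val += members[n_train:n_train + n_val]
        test += members[n_train + n_val:]
    return train, val, test


# Toy data: 10 "AIME_I" and 10 "AIME_II" items (made up for this sketch)
items = [{"id": i, "part": "AIME_I" if i < 10 else "AIME_II"} for i in range(20)]
train, val, test = stratified_split(items, key=lambda x: x["part"], seed=42)
print(len(train), len(val), len(test))  # 14 4 2
```

Because each stratum is split on its own, every split receives the same share of each group, up to rounding. Very small strata can still round to zero items in the validation or test set.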

---

## split_by_attribute(): Attribute-Based Splitting

`split_by_attribute()` assigns questions to splits by the value of a metadata attribute (e.g., train on AIME I, validate on AIME II).

In [None]:
# Split by AIME part: train on AIME I, validate on AIME II
# Note: This requires questions to have distinct attribute values

try:
    attribute_split = split_by_attribute(
        benchmark,
        attribute="custom_part",
        train_values=["AIME_I"],
        val_values=["AIME_II"],
    )

    print(attribute_split.summary())
    print(f"\nAll train questions are from: {set(inst.metadata.get('custom_part') for inst in attribute_split.train)}")
    print(f"All val questions are from: {set(inst.metadata.get('custom_part') for inst in attribute_split.val)}")
except ValueError as e:
    print(f"Note: {e}")
    print("(This is expected if the metadata field doesn't exist or has no matching values)")
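
The routing logic behind an attribute-based split can be sketched without karenina. The helper below is hypothetical; it raises `ValueError` for empty or ambiguous assignments as an assumption that mirrors the error the cell above anticipates, not as a statement of what `split_by_attribute()` checks.

```python
def partition_by_values(items, key, train_values, val_values):
    """Route each item to train or val according to its attribute value."""
    overlap = set(train_values) & set(val_values)
    if overlap:
        raise ValueError(f"Values assigned to both splits: {sorted(overlap)}")
    train = [x for x in items if key(x) in train_values]
    val = [x for x in items if key(x) in val_values]
    if not train or not val:
        raise ValueError("No items matched one of the requested splits")
    return train, val


# Toy data: even IDs are "AIME_I", odd IDs are "AIME_II" (made up)
items = [{"id": i, "part": "AIME_I" if i % 2 == 0 else "AIME_II"} for i in range(6)]
train, val = partition_by_values(
    items, key=lambda x: x["part"], train_values=["AIME_I"], val_values=["AIME_II"]
)
print([x["id"] for x in train])  # [0, 2, 4]
print([x["id"] for x in val])    # [1, 3, 5]
```

Unlike a random split, this kind of split measures how well an optimized prompt transfers to a group it never saw during training.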

---

## questions_to_data_insts(): Manual Conversion

Convert specific question IDs to `KareninaDataInst` objects.

In [None]:
# Get all question IDs
all_ids = benchmark.get_question_ids()

# Convert first 5 to data instances
sample_insts = questions_to_data_insts(benchmark, all_ids[:5])

print(f"Converted {len(sample_insts)} questions to KareninaDataInst")
for inst in sample_insts:
    print(f"  - {inst.question_id[:50]}... -> answer: {inst.raw_answer}")

In [None]:
# Manual split using questions_to_data_insts
# Fixed slices keep the benchmark's original order and assume more than 25 questions
train_ids = all_ids[:20]
val_ids = all_ids[20:25]
test_ids = all_ids[25:]

train_insts = questions_to_data_insts(benchmark, train_ids)
val_insts = questions_to_data_insts(benchmark, val_ids)
test_insts = questions_to_data_insts(benchmark, test_ids)

# Create BenchmarkSplit manually
manual_split = BenchmarkSplit(
    train=train_insts,
    val=val_insts,
    test=test_insts,
    seed=None,  # No random seed since we used explicit IDs
)

print(manual_split.summary())

---

## Best Practices

### 1. Always Use Seeds for Reproducibility

In [None]:
# Good: Use a seed
reproducible_split = split_benchmark(benchmark, seed=42)
print(f"Reproducible split seed: {reproducible_split.seed}")

# Save the seed with your optimization results for reproducibility
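
One way to save the seed alongside results is a small JSON record. This is a minimal sketch: the file name, the record layout, and the helper names are assumptions rather than a karenina convention. Storing the ID lists as well guards against a future change in how a seed maps to a split.

```python
import json
import tempfile
from pathlib import Path


def save_split_record(path, seed, train_ids, val_ids, test_ids=None):
    """Write the seed and the exact ID lists of a split to a JSON file."""
    record = {
        "seed": seed,
        "train_ids": list(train_ids),
        "val_ids": list(val_ids),
        "test_ids": list(test_ids) if test_ids is not None else None,
    }
    Path(path).write_text(json.dumps(record, indent=2))
    return record


def load_split_record(path):
    """Read a record written by save_split_record."""
    return json.loads(Path(path).read_text())


# Round trip with made-up IDs, written to a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    record_path = Path(tmp) / "split_record.json"
    saved = save_split_record(record_path, 42, ["q1", "q2"], ["q3"], ["q4"])
    loaded = load_split_record(record_path)

print(loaded == saved)  # True
```

With a real split, the arguments would be `split.seed`, `split.train_ids`, `split.val_ids`, and `split.test_ids` from the cells above.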

### 2. Use Test Sets for Final Evaluation

In [None]:
# Good: Reserve a test set
proper_split = split_benchmark(
    benchmark,
    train_ratio=0.7,
    val_ratio=0.2,
    test_ratio=0.1,
    seed=42,
)

print("Proper split with held-out test set:")
print(f"  Train: {len(proper_split.train)} (for optimization)")
print(f"  Val: {len(proper_split.val)} (for candidate selection)")
print(f"  Test: {len(proper_split.test)} (for final evaluation only)")

### 3. Stratify When Distribution Matters

In [None]:
# Good: Stratify by important attributes
stratified = split_benchmark(
    benchmark,
    train_ratio=0.7,
    val_ratio=0.2,
    test_ratio=0.1,
    seed=42,
    stratify_by="custom_part",  # Ensure both AIME I and II in each split
)

print("Stratified split maintains distribution across splits")

### 4. Validate Splits Before Optimization

In [None]:
# Good: Validate the split
split = split_benchmark(benchmark, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1, seed=42)

# Check no overlap
train_set = set(split.train_ids)
val_set = set(split.val_ids)
test_set = set(split.test_ids)

assert len(train_set & val_set) == 0, "Train and val overlap!"
assert len(train_set & test_set) == 0, "Train and test overlap!"
assert len(val_set & test_set) == 0, "Val and test overlap!"

# Check coverage
all_ids = set(benchmark.get_question_ids())
split_ids = train_set | val_set | test_set
assert split_ids == all_ids, "Some questions missing from split!"

print("Split validation passed!")
print("  No overlaps between sets")
print(f"  All {len(all_ids)} questions accounted for")
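
The same checks can be wrapped in a reusable helper that works on plain ID lists. This is a sketch under two assumptions: the name `check_split` is hypothetical, and a split without a test set is passed as `test_ids=None`. It raises exceptions instead of using `assert`, since assertions are skipped when Python runs with `-O`.

```python
def check_split(all_ids, train_ids, val_ids, test_ids=None):
    """Raise ValueError if splits contain duplicates, overlap, or miss questions."""
    parts = {
        "train": list(train_ids),
        "val": list(val_ids),
        "test": list(test_ids or []),  # a missing test set counts as empty
    }
    for name, ids in parts.items():
        if len(ids) != len(set(ids)):
            raise ValueError(f"Duplicate IDs inside {name}")

    names = list(parts)
    for i, first in enumerate(names):
        for second in names[i + 1:]:
            shared = set(parts[first]) & set(parts[second])
            if shared:
                raise ValueError(f"{first} and {second} overlap: {sorted(shared)}")

    covered = set().union(*parts.values())
    missing = set(all_ids) - covered
    if missing:
        raise ValueError(f"Questions missing from split: {sorted(missing)}")
    extra = covered - set(all_ids)
    if extra:
        raise ValueError(f"Unknown IDs in split: {sorted(extra)}")


# Made-up IDs: a valid split passes silently
check_split(["a", "b", "c", "d"], ["a", "b"], ["c"], ["d"])
check_split(["a", "b", "c"], ["a", "b"], ["c"])  # no test set
print("Split validation passed!")
```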

---

## Summary

| Function | Use Case |
|----------|----------|
| `split_benchmark()` | Random/stratified splitting with ratios |
| `split_by_attribute()` | Split by specific attribute values |
| `questions_to_data_insts()` | Manual conversion for custom splits |

| `split_benchmark()` parameter | Description |
|-----------|-------------|
| `train_ratio` | Fraction for training (default: 0.8) |
| `val_ratio` | Fraction for validation (default: 0.2) |
| `test_ratio` | Fraction for testing (default: None, i.e. no test set) |
| `seed` | Random seed for reproducibility |
| `stratify_by` | Metadata field to stratify by |

## Next Steps

- [04_scoring_deep_dive.ipynb](04_scoring_deep_dive.ipynb) - Understanding score computation