# Training — Grid Search

Runs the full DSPy/GEPA training pipeline:

1. Load & normalise dataset.
2. Split per label (shared locked selection holdout).
3. Build `dspy.Example` objects.
4. Grid search over all (generation_model, reflection_model) pairs.
5. Evaluate best program per label on locked holdout set.
6. Save optimised programs to disk.
7. Smoke-test the saved programs through the serving API.

---
## Setup

In [None]:
import json
import os
import sys

from loguru import logger

from training.config import (
    GENERATION_MODELS,
    GEPA_AUTO,
    LABELS,
    PROGRAMS_MANIFEST,
    PROGRAMS_SUBDIR,
    REFLECTION_MODELS,
    RESULTS_DIR,
)
from training.data import dataset_builder, split_pipeline
from training.evaluation import evaluate
from training.main import load_dataset, _dataset_hash
from training.optimization import grid_search

# Drop every existing loguru handler, then log to stderr at INFO level.
logger.remove()
logger.add(sys.stderr, level="INFO")

# Ensure the output directory exists before any later cell writes into it.
os.makedirs(RESULTS_DIR, exist_ok=True)

# Report the size of the search: each (generation, reflection) model pair is
# tried for every label.
n_gen = len(GENERATION_MODELS)
n_reflection = len(REFLECTION_MODELS)
n_pairs = n_gen * n_reflection
print(
    f"Grid: {n_gen} gen x {n_reflection} reflection "
    f"= {n_pairs} pairs x {len(LABELS)} labels = {n_pairs * len(LABELS)} trials"
)
print(f"GEPA budget: {GEPA_AUTO}")

---
## 1. Load Dataset

In [None]:
df = load_dataset()
# Hash of the loaded data; it is passed to grid search and evaluation and
# recorded in the saved programs' manifest, tying artefacts to this dataset.
ds_hash = _dataset_hash(df)
print(f"Dataset hash: {ds_hash}")
# Last expression: shown as the cell output (preview of the first 3 rows).
df.head(3)

---
## 2. Split Dataset

In [None]:
# Produce the per-label splits, then persist them via split_pipeline.save
# (storage location is decided inside split_pipeline — not visible here).
splits = split_pipeline.run(df)
split_pipeline.save(splits)

---
## 3. Build `dspy.Example` Datasets

In [None]:
datasets = dataset_builder.build(splits)

# Sanity check: show how many training and validation examples each label got.
for label in LABELS:
    train_examples = datasets[label]["trainset"]
    val_examples = datasets[label]["valset"]
    print(f"{label}: train={len(train_examples)}, val={len(val_examples)}")

---
## 4. Grid Search

In [None]:
# Run the (generation_model, reflection_model) grid for every label; the
# returned mapping holds one winning result per label.
best_per_label = grid_search.run(
    datasets=datasets,
    auto=GEPA_AUTO,
    dataset_hash=ds_hash,
)

---
## 5. Evaluate on Holdout Set

In [None]:
# For each label, evaluate the baseline for the winning generation model and
# then the optimised program from the grid search.
for label in LABELS:
    result = best_per_label[label]
    eval_kwargs = {"label": label, "datasets": datasets, "dataset_hash": ds_hash}

    print(f"--- {label} (baseline) ---")
    evaluate.run_baseline_label(gen_model=result.gen_model, **eval_kwargs)

    print(f"--- {label} (optimized) ---")
    evaluate.run_dspy_label(result=result, **eval_kwargs)

---
## 6. Save Programs

In [None]:
programs_dir = os.path.join(RESULTS_DIR, PROGRAMS_SUBDIR)
# Create the programs root explicitly: the manifest below is written here, and
# with an empty LABELS no per-label subdirectory would have created it.
os.makedirs(programs_dir, exist_ok=True)
manifest = {}

for label in LABELS:
    result = best_per_label[label]
    label_dir = os.path.join(programs_dir, label)
    os.makedirs(label_dir, exist_ok=True)
    # NOTE(review): per DSPy's Module.save docs, save_program=True saves the
    # whole program into a directory — confirm against the installed version.
    result.optimized_program.save(label_dir, save_program=True)
    # Record the model pair that produced this program and the dataset hash.
    manifest[label] = {
        "gen_model": result.gen_model,
        "reflection_model": result.reflection_model,
        "dataset_hash": ds_hash,
    }
    print(f"[{label}] saved to {label_dir}")

manifest_path = os.path.join(programs_dir, PROGRAMS_MANIFEST)
# Explicit encoding so the manifest does not depend on the platform locale.
with open(manifest_path, "w", encoding="utf-8") as f:
    json.dump(manifest, f, indent=2)

print(f"\nManifest saved to {manifest_path}")

---
## 7. Test the API

Loads the saved programs into the serving app, then sends a few sample queries to it through FastAPI's `TestClient`.

In [None]:
import pandas as pd
from fastapi.testclient import TestClient
from serving.app import app, _load_programs

# TestClient runs the app's lifespan only inside a `with` block, and such a
# block cannot stay open across notebook cells, so load the programs by hand.
_load_programs()

client = TestClient(app)

# Quick liveness check before sending real queries.
health_resp = client.get("/health")
print(health_resp.json())

In [None]:
from tqdm import tqdm

test_queries = [
    "My wife and I celebrated our anniversary with a quiet dinner at home.",
    "My brother helped me move into my new apartment last weekend.",
    "We've been best friends since college and still talk every day.",
    "The CEO appointed her as the new VP of Engineering.",
    "Two people were seen talking near the park bench.",
    "The weather forecast predicts rain for the rest of the week.",
]

rows = []
for text in tqdm(test_queries, desc="Classifying", unit="query"):
    resp = client.post("/classify", json={"text": text})
    # Surface HTTP errors with their status code instead of failing below on a
    # missing "labels" key or an undecodable error body.
    resp.raise_for_status()
    data = resp.json()
    # One table row per query: truncated text, comma-joined predicted labels
    # ("(none)" when empty), plus every entry of the response's "details".
    row = {"text": text[:80], "labels": ", ".join(data["labels"]) or "(none)"}
    row.update(data["details"])
    rows.append(row)

# Last expression: rendered as the cell output.
pd.DataFrame(rows)