# Overfit Training Comparison (Small Dataset)

This notebook runs a matrix of training configs, evaluates each trained model at a fixed set of guidance scales, and compares:
- decode validity (`decoded / samples`)
- SRV exact match (on decoded subset)
- training loss summary

It is designed to quickly find a *working* small-dataset configuration before long runs.

In [None]:
import os
import re
import shlex
import subprocess
import sys
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
# --- Locations ---
# REPO_ROOT is the parent of the current working directory.
REPO_ROOT = Path('..').resolve()
TRAIN_SCRIPT = REPO_ROOT.joinpath('scripts', 'train_model.py')
EVAL_SCRIPT = REPO_ROOT.joinpath('scripts', 'evaluate_model.py')
RESULTS_DIR = REPO_ROOT.joinpath('notebooks', 'results', 'overfit_training_comparison')
# Create the output folder up front so later cells can write into it.
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# --- Switches read by the train/eval loop ---
RUN_TRAINING = True    # launch the training script for each experiment
RUN_EVALUATION = True  # launch the evaluation sweep for each experiment
DRY_RUN = False        # forwarded to run_cmd(dry_run=...): print commands only

# --- Base config names passed to the scripts as `training=` / `evaluation=` ---
TRAIN_BASE = 'overfit_debug_srv'
EVAL_BASE = 'overfit_debug_srv'

# --- Evaluation sweep ---
GUIDANCE_SCALES = [0.5, 1.0, 1.5, 3.0, 5.0, 10.0]
NUM_EVAL_SAMPLES = 128

# Timestamp suffix used to name this session's models and result files.
STAMP = f'{datetime.now():%Y%m%d_%H%M%S}'
print('Results dir:', RESULTS_DIR)
print('Timestamp  :', STAMP)

In [None]:
# Experiment matrix (edit freely)

# Overrides common to every run. They are merged *after* the per-run
# settings so the key order of each `train_overrides` dict (and therefore
# the order of the overrides on the training command line) is preserved.
_SHARED_OVERRIDES = {
    'scheduler.params.input_perturbation': '0.0',
    'training.enable_guidance_train': 'true',
    'training.guidance_train_p': '0.1',
}

# One tuple per run: (name, learning_rate, max_samples, num_epochs, batch_size)
_RUN_SPECS = [
    ('exp_a_lr3e4_s512_e100_ip0p0', '3e-4', '512', '100', '32'),
    ('exp_b_lr1e4_s512_e100_ip0p0', '1e-4', '512', '100', '32'),
    ('exp_c_lr3e4_s1024_e80_ip0p0', '3e-4', '1024', '80', '64'),
]

EXPERIMENTS = [
    {
        'name': run_name,
        'train_overrides': {
            'training.learning_rate': lr,
            'training.max_samples': n_samples,
            'training.num_epochs': n_epochs,
            'training.batch_size': batch,
            **_SHARED_OVERRIDES,
        },
    }
    for run_name, lr, n_samples, n_epochs, batch in _RUN_SPECS
]

# Echo the configured experiment names.
for e in EXPERIMENTS:
    print('-', e['name'])

In [None]:
def run_cmd(cmd, cwd=REPO_ROOT, env=None, dry_run=False):
    """Echo a command line, then execute it and capture its output.

    Args:
        cmd: Sequence of command arguments; each item is stringified and
            shell-quoted for the echoed line only.
        cwd: Working directory for the child process.
        env: Environment mapping handed to ``subprocess.run`` (``None``
            inherits the current environment).
        dry_run: When true, only print the command and skip execution.

    Returns:
        ``(returncode, stdout, stderr)``. A dry run returns ``(0, '', '')``.
    """
    quoted_parts = [shlex.quote(str(part)) for part in cmd]
    print('\n$ ' + ' '.join(quoted_parts))

    if dry_run:
        return 0, '', ''

    result = subprocess.run(cmd, cwd=str(cwd), env=env, text=True, capture_output=True)

    # On failure, surface the tail of both streams for quick diagnosis.
    if result.returncode != 0:
        for stream in (result.stdout, result.stderr):
            print(stream[-2000:])

    return result.returncode, result.stdout, result.stderr


def parse_eval_stdout(text):
    """Extract evaluation metrics from the evaluation script's stdout.

    Args:
        text: Captured stdout of an evaluation run.

    Returns:
        Dict with the keys ``samples_requested``, ``decoded_circuits``,
        ``decode_failures``, ``valid_rate`` and ``srv_exact_match``. Any
        value that cannot be found in ``text`` is ``np.nan``.
        ``valid_rate`` is ``decoded_circuits / samples_requested`` and stays
        ``np.nan`` when either count is missing or no samples were requested.
    """
    out = {
        'samples_requested': np.nan,
        'decoded_circuits': np.nan,
        'decode_failures': np.nan,
        'valid_rate': np.nan,
        'srv_exact_match': np.nan,
    }

    # Integer counters, each reported on its own "<label>: <int>" line.
    count_patterns = (
        ('samples_requested', r"Samples requested:\s*(\d+)"),
        ('decoded_circuits', r"Decoded circuits\s*:\s*(\d+)"),
        ('decode_failures', r"Decode failures\s*:\s*(\d+)"),
    )
    for key, pattern in count_patterns:
        m = re.search(pattern, text)
        if m:
            out[key] = int(m.group(1))

    requested = out['samples_requested']
    decoded = out['decoded_circuits']
    # Guard against "Samples requested: 0", which would divide by zero.
    if not np.isnan(requested) and not np.isnan(decoded) and requested > 0:
        out['valid_rate'] = decoded / requested

    # Accept plain decimals and scientific notation (e.g. "1e-05"); the
    # pattern only matches well-formed numbers, so float() cannot fail on
    # trailing punctuation such as "0.5.".
    m = re.search(
        r"SRV exact-match rate\s*:\s*((?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)",
        text,
    )
    if m:
        out['srv_exact_match'] = float(m.group(1))

    return out


def parse_training_losses(model_dir):
    """Summarise the loss logs found in a trained model directory.

    Reads ``fit_losses.txt`` (training losses) and ``fit_valid_losses.txt``
    (validation log; the loss is read from the second column) when they
    exist.

    Args:
        model_dir: Directory that may contain the two loss files.

    Returns:
        Dict with ``train_loss_last``, ``train_loss_min``,
        ``valid_loss_last`` and ``valid_loss_min``. A value is ``np.nan``
        when its file is missing, empty, or (for the validation log) has no
        second column.
    """
    model_dir = Path(model_dir)
    out = {
        'train_loss_last': np.nan,
        'train_loss_min': np.nan,
        'valid_loss_last': np.nan,
        'valid_loss_min': np.nan,
    }

    tr_path = model_dir / 'fit_losses.txt'
    va_path = model_dir / 'fit_valid_losses.txt'

    if tr_path.exists():
        # np.loadtxt returns a 0-d array for a file holding a single value;
        # promote it to 1-D so that indexing with [-1] works.
        tr = np.atleast_1d(np.loadtxt(tr_path))
        if tr.size > 0:
            out['train_loss_last'] = float(tr[-1])
            out['train_loss_min'] = float(np.min(tr))

    if va_path.exists():
        va = np.loadtxt(va_path)
        # A single two-column row is loaded as a flat pair; restore its shape.
        if va.ndim == 1 and va.size == 2:
            va = va.reshape(1, 2)
        # Only index column 1 when the log really has rows and a second
        # column; other layouts leave the NaN defaults in place.
        if va.ndim == 2 and va.shape[0] > 0 and va.shape[1] >= 2:
            out['valid_loss_last'] = float(va[-1, 1])
            out['valid_loss_min'] = float(np.min(va[:, 1]))

    return out

In [None]:
# Train each experiment (optional), then evaluate it at every guidance
# scale (optional), collecting one result row per (experiment, scale).
rows = []

# Launch child scripts with the interpreter running this notebook so they
# see the same environment; fall back to PATH lookup if it is unknown.
PYTHON_BIN = sys.executable or 'python'

for exp in EXPERIMENTS:
    exp_name = exp['name']
    # An experiment may pin 'model_name' to reuse an already-trained model
    # (useful with RUN_TRAINING=False); otherwise a timestamped name is used.
    model_name = exp.get('model_name', f"{exp_name}_{STAMP}")
    model_dir = REPO_ROOT / 'models' / 'trained' / model_name

    # ----- Train -----
    if RUN_TRAINING:
        train_cmd = [
            PYTHON_BIN, str(TRAIN_SCRIPT),
            f'training={TRAIN_BASE}',
            f'general.model_name={model_name}',
            f'training.model_name={model_name}',
            f'general.experiment_name={model_name}',
            f'training.experiment_name={model_name}',
        ]
        # Per-experiment overrides are appended as `key=value` arguments.
        train_cmd.extend(
            f'{k}={v}' for k, v in exp.get('train_overrides', {}).items()
        )

        rc, out, err = run_cmd(train_cmd, dry_run=DRY_RUN)
        (RESULTS_DIR / f'{model_name}_train_stdout.txt').write_text(out)
        (RESULTS_DIR / f'{model_name}_train_stderr.txt').write_text(err)

        if rc != 0:
            # Record the failure and skip evaluation for this experiment.
            rows.append({
                'experiment': exp_name,
                'model_name': model_name,
                'guidance_scale': np.nan,
                'status': 'train_failed',
            })
            continue

    loss_stats = parse_training_losses(model_dir)

    # ----- Eval sweep -----
    if RUN_EVALUATION:
        for g in GUIDANCE_SCALES:
            eval_cmd = [
                PYTHON_BIN, str(EVAL_SCRIPT),
                f'evaluation={EVAL_BASE}',
                f'evaluation.model_dir={model_dir}',
                f'evaluation.num_samples={NUM_EVAL_SAMPLES}',
                f'evaluation.model_params.guidance_scale={g}',
            ]

            rc, out, err = run_cmd(eval_cmd, dry_run=DRY_RUN)
            (RESULTS_DIR / f'{model_name}_eval_g{g}_stdout.txt').write_text(out)
            (RESULTS_DIR / f'{model_name}_eval_g{g}_stderr.txt').write_text(err)

            # Metrics are parsed even for failed runs; anything missing
            # from the output is left as NaN by the parser.
            metrics = parse_eval_stdout(out)
            rows.append({
                'experiment': exp_name,
                'model_name': model_name,
                'guidance_scale': g,
                'status': 'ok' if rc == 0 else 'eval_failed',
                **loss_stats,
                **metrics,
            })

results_df = pd.DataFrame(rows)
results_df

In [None]:
# Persist the collected rows as a timestamped CSV; nothing is written when
# no rows were collected.
if not results_df.empty:
    out_csv = RESULTS_DIR.joinpath(f'results_{STAMP}.csv')
    results_df.to_csv(path_or_buf=out_csv, index=False)
    print('Saved:', out_csv)

In [None]:
# Show the key columns of the results, ordered by experiment and scale.
if not results_df.empty:
    display_cols = [
        'experiment', 'model_name', 'guidance_scale', 'status',
        'valid_rate', 'srv_exact_match',
        'train_loss_last', 'train_loss_min', 'valid_loss_last', 'valid_loss_min'
    ]
    # reindex() instead of results_df[display_cols]: when every experiment
    # failed during training the metric/loss columns do not exist, and plain
    # column selection would raise KeyError. Missing columns show as NaN.
    summary_df = results_df.reindex(columns=display_cols)
    display(summary_df.sort_values(['experiment', 'guidance_scale']))

In [None]:
# Plot decode validity and SRV exact match against guidance scale, one
# line per experiment.
if not results_df.empty:
    # (metric column, panel title, y-axis label) for each subplot.
    panels = [
        ('valid_rate', 'Decode Valid Rate vs Guidance', 'valid_rate'),
        ('srv_exact_match', 'SRV Exact Match vs Guidance', 'exact_match'),
    ]

    # reindex() guarantees the metric columns exist (filled with NaN) even
    # when every run failed before evaluation; direct column access would
    # raise KeyError in that case.
    plot_df = results_df.reindex(
        columns=['experiment', 'guidance_scale'] + [col for col, _, _ in panels]
    )

    fig, axes = plt.subplots(1, len(panels), figsize=(12, 4))

    for exp_name, grp in plot_df.groupby('experiment'):
        grp = grp.sort_values('guidance_scale')
        for ax, (col, _, _) in zip(axes, panels):
            ax.plot(grp['guidance_scale'], grp[col], marker='o', label=exp_name)

    for ax, (_, title, ylabel) in zip(axes, panels):
        ax.set_title(title)
        ax.set_xlabel('guidance_scale')
        ax.set_ylabel(ylabel)
        ax.set_ylim(0, 1)  # both metrics are rates in [0, 1]
        ax.grid(True, alpha=0.3)

    # One legend on the right-hand panel covers both panels.
    axes[-1].legend(loc='best', fontsize=8)
    plt.tight_layout()
    plt.show()

## Notes
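- Set `RUN_TRAINING=False` to re-run only the evaluation on already-trained models. By default each `model_name` is built as `<experiment name>_<STAMP>`, so `STAMP` must match the timestamp of the original training run for the existing model directory to be found.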
- If a run crashes, inspect `notebooks/results/overfit_training_comparison/*stderr.txt`.
- Start with 2-3 experiments, then expand once validity moves above ~50% on 3q.