# Step B — Evaluate under knowledge corruption

Loads the Step A checkpoints (sinusoids, abc2) and evaluates the predictive NLL (IS-NLL)
under four knowledge-corruption regimes: `clean`, `noisy_low`, `noisy_high`, `permuted`.

For each regime, evaluation is done at context sizes C ∈ {0, 3, 5, 10}.

Exports:
- Per-run: `outputs/{run_name}/stepB_{mode}.json`
- Aggregated: `outputs/stepB_eval.json` + `outputs/stepB_eval.tsv`

In [9]:
import sys, os

def find_repo_root(start_dir, marker='config.py'):
    """Return the nearest directory at or above ``start_dir`` that contains ``marker``.

    The search walks from ``start_dir`` up to and including the filesystem root.

    Args:
        start_dir: Directory from which to start searching.
        marker: File name whose presence identifies the repository root.

    Returns:
        The absolute path of the first matching directory, or ``None`` when
        neither ``start_dir`` nor any of its ancestors contains ``marker``.
    """
    current = os.path.abspath(start_dir)
    while True:
        if os.path.isfile(os.path.join(current, marker)):
            return current
        parent = os.path.dirname(current)
        if parent == current:  # dirname() is a fixed point only at the filesystem root
            return None
        current = parent


_start_dir = os.getcwd()
REPO_ROOT = find_repo_root(_start_dir)
if REPO_ROOT is None:
    # Do not silently chdir to the filesystem root when the marker is missing:
    # stay where we are and make the problem visible instead.
    print(
        f'WARNING: config.py not found in {_start_dir} or any parent directory; '
        'keeping the current working directory.',
        file=sys.stderr,
    )
    REPO_ROOT = _start_dir

os.chdir(REPO_ROOT)
# Make the repository's top-level modules importable from this notebook.
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

print(f'Working directory: {os.getcwd()}')

Working directory: /Users/louishayot/MVA/VdS_Submission_Final


In [10]:
import csv
import json
import random
import re

import numpy as np
import torch

from config import Config
from models.inp import INP
from models.loss import NLL
from models.knowledge_corruption import corrupt_knowledge
from dataset.utils import setup_dataloaders

In [11]:
# ---------- Evaluation settings ----------
OUTPUT_DIR = 'outputs'          # root folder holding one sub-directory per Step A run
MAX_EVAL_BATCHES = 50           # upper bound on test batches evaluated per run
EVAL_SEED = 42                  # base seed for all evaluation-time randomness
CONTEXT_SIZES = [0, 3, 5, 10]   # number of context points C used at evaluation

# Reported mode name -> regime string passed to corrupt_knowledge, together
# with the relative noise level that is stored alongside the results.
CORRUPTION_REGIMES = dict(
    clean=dict(regime_str='clean', sigma_rel=0.0),
    noisy_low=dict(regime_str='noisy_0.1', sigma_rel=0.1),
    noisy_high=dict(regime_str='noisy_0.3', sigma_rel=0.3),
    permuted=dict(regime_str='permuted', sigma_rel=0.0),
)

# A directory counts as a Step A run only if it holds both a checkpoint and
# the config it was trained with.
RUN_NAMES = sorted(
    entry
    for entry in os.listdir(OUTPUT_DIR)
    if all(
        os.path.isfile(os.path.join(OUTPUT_DIR, entry, required))
        for required in ('model_best.pt', 'config.toml')
    )
)
print(f'Found {len(RUN_NAMES)} runs: {RUN_NAMES}')

Found 9 runs: ['aggressive_align_rT_seed0', 'aggressive_align_rT_seed1', 'aggressive_align_rT_seed2', 'baseline_seed0', 'baseline_seed1', 'baseline_seed2', 'safe_align_rC_seed0', 'safe_align_rC_seed1', 'safe_align_rC_seed2']


In [12]:
def load_model(run_dir, knowledge_input_dim=4):
    """Load a trained INP model from a Step A run directory.

    Reads ``config.toml`` and ``model_best.pt`` from ``run_dir``, builds the
    model on CPU and switches it to evaluation mode.

    Args:
        run_dir: Directory containing ``config.toml`` and ``model_best.pt``.
        knowledge_input_dim: Value assigned to ``config.knowledge_input_dim``
            before the model is built. Defaults to 4, the value this notebook
            uses for the sinusoids/abc2 setup.

    Returns:
        Tuple ``(model, config)``: the model in eval mode and the (modified)
        config it was built from.
    """
    config = Config.from_toml(os.path.join(run_dir, 'config.toml'))
    # Evaluation always runs on CPU, regardless of the training device.
    config.device = 'cpu'
    # Disable knowledge dropout at eval time
    config.knowledge_dropout = 0.0

    # The training pipeline normally derives this from the dataset; here it is
    # supplied by the caller.
    config.knowledge_input_dim = knowledge_input_dim

    model = INP(config)
    # map_location keeps checkpoints that were saved on another device loadable on CPU.
    state_dict = torch.load(os.path.join(run_dir, 'model_best.pt'), map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model, config


def get_test_dataloader(config):
    """Return the test split dataloader produced by ``setup_dataloaders(config)``.

    ``setup_dataloaders`` is expected to return four values; only the third
    (the test dataloader) is kept.
    """
    loaders = setup_dataloaders(config)
    _first, _second, test_loader, _fourth = loaders
    return test_loader

In [13]:
def _parse_seed(run_name):
    """Extract the training seed from a run name such as ``baseline_seed2``.

    Returns the integer of the last ``seed<digits>`` token in ``run_name``, or
    ``-1`` when the name contains no such token. Text after the digits (for
    example ``baseline_seed2_rerun``) is ignored rather than raising.
    """
    seeds = re.findall(r'seed(\d+)', run_name)
    return int(seeds[-1]) if seeds else -1


def evaluate_run(run_name):
    """Evaluate one Step A run under every corruption regime.

    For each entry of ``CORRUPTION_REGIMES`` and each context size in
    ``CONTEXT_SIZES`` the model is run on at most ``MAX_EVAL_BATCHES`` test
    batches. Context points are a random subset of the target points; the
    same subset is reused for every regime so that regimes are compared on
    identical inputs. One JSON file per regime is written to
    ``<run_dir>/stepB_<mode>.json``.

    Args:
        run_name: Name of the run sub-directory inside ``OUTPUT_DIR``.

    Returns:
        A list with one result dict per corruption regime. Each dict holds
        ``run_name``, ``run_dir``, ``seed``, ``eval_corruption_mode``,
        ``sigma_rel``, one ``eval_nll_<C>`` entry per context size and
        ``mean_eval_nll`` (the mean over context sizes).

    Raises:
        RuntimeError: If the test dataloader yields no batches.
    """
    run_dir = os.path.join(OUTPUT_DIR, run_name)
    model, config = load_model(run_dir)
    test_dl = get_test_dataloader(config)

    nll_func = NLL(reduction='none')  # per-sample IS-NLL

    # Set seeds for reproducibility
    torch.manual_seed(EVAL_SEED)
    np.random.seed(EVAL_SEED)
    random.seed(EVAL_SEED)

    # Collect batches once (to reuse across corruption regimes).
    # Only the target points and the knowledge are needed: contexts are
    # re-sampled from the targets below.
    batches = []
    for i, batch in enumerate(test_dl):
        if i >= MAX_EVAL_BATCHES:
            break
        _, (x_target, y_target), knowledge, _ = batch
        batches.append((x_target, y_target, knowledge))

    if not batches:
        # Averaging over zero batches would silently produce NaN results.
        raise RuntimeError(f'Test dataloader for run {run_name!r} produced no batches')

    # Pre-sample context indices for each batch and each C
    # (shared across corruption regimes for variance control).
    # The same indices are applied to every task in a batch.
    rng = np.random.RandomState(EVAL_SEED)
    batch_context_indices = []
    for x_target, _, _ in batches:
        n_points = x_target.shape[1]
        indices_per_C = {}
        for C in CONTEXT_SIZES:
            if C == 0:
                indices_per_C[C] = np.array([], dtype=int)
            else:
                indices_per_C[C] = rng.choice(n_points, C, replace=False)
        batch_context_indices.append(indices_per_C)

    # Evaluate
    run_results = []
    for mode_name, mode_cfg in CORRUPTION_REGIMES.items():
        regime_str = mode_cfg['regime_str']
        sigma_rel = mode_cfg['sigma_rel']  # recorded in the results only

        nll_per_C = {C: [] for C in CONTEXT_SIZES}

        for batch_idx, (x_target, y_target, knowledge) in enumerate(batches):
            # Corrupt knowledge ONCE per batch (reuse for all C); the seed
            # depends on the batch index only, not on the regime.
            K_corrupted = corrupt_knowledge(knowledge, regime=regime_str, seed=EVAL_SEED + batch_idx)

            for C in CONTEXT_SIZES:
                idx = batch_context_indices[batch_idx][C]

                if C == 0:
                    x_ctx = x_target[:, :0, :]  # empty context
                    y_ctx = y_target[:, :0, :]
                else:
                    x_ctx = x_target[:, idx, :]
                    y_ctx = y_target[:, idx, :]

                with torch.no_grad():
                    outputs = model(
                        x_ctx, y_ctx, x_target, y_target=None,
                        knowledge=K_corrupted
                    )
                    p_yCc, z_samples, q_z_Cc, q_zCct = outputs[:4]
                    nll_val, _, _ = nll_func.get_loss(
                        p_yCc, z_samples, q_z_Cc, q_zCct, y_target
                    )
                    # nll_val: [batch_size] per-sample NLL; store the batch mean
                    nll_per_C[C].append(nll_val.mean().item())

        # Aggregate: mean of the per-batch means for each context size
        result = {
            'run_name': run_name,
            'run_dir': run_dir,
            'seed': _parse_seed(run_name),
            'eval_corruption_mode': mode_name,
            'sigma_rel': sigma_rel,
        }
        for C in CONTEXT_SIZES:
            result[f'eval_nll_{C}'] = float(np.mean(nll_per_C[C]))
        result['mean_eval_nll'] = float(np.mean(
            [result[f'eval_nll_{C}'] for C in CONTEXT_SIZES]
        ))

        run_results.append(result)

        # Save per-run per-mode JSON
        json_path = os.path.join(run_dir, f'stepB_{mode_name}.json')
        with open(json_path, 'w') as f:
            json.dump(result, f, indent=2)

    return run_results

In [14]:
# ---------- Run evaluation for all checkpoints ----------
# Evaluates every discovered run and collects one result row per
# (run, corruption regime) pair in `all_results`.
all_results = []
for run_name in RUN_NAMES:
    print(f'\nEvaluating: {run_name}')
    results = evaluate_run(run_name)
    for r in results:
        mode = r['eval_corruption_mode']
        # Build the per-context columns from CONTEXT_SIZES (the same list
        # evaluate_run uses for its keys) so the printout stays valid if the
        # list of context sizes changes.
        per_context = '  '.join(
            'nll@{}={:.4f}'.format(C, r[f'eval_nll_{C}']) for C in CONTEXT_SIZES
        )
        print(f"  {mode:12s}  mean_nll={r['mean_eval_nll']:.4f}  {per_context}")
    all_results.extend(results)

print(f'\nTotal result rows: {len(all_results)}')


Evaluating: aggressive_align_rT_seed0
  clean         mean_nll=42.0457  nll@0=117.9802  nll@3=37.4406  nll@5=14.7310  nll@10=-1.9688
  noisy_low     mean_nll=85.4011  nll@0=200.5888  nll@3=82.2076  nll@5=49.1143  nll@10=9.6936
  noisy_high    mean_nll=268.1713  nll@0=528.6855  nll@3=269.0284  nll@5=183.3592  nll@10=91.6122
  permuted      mean_nll=318.6687  nll@0=676.4164  nll@3=304.2128  nll@5=197.8510  nll@10=96.1947

Evaluating: aggressive_align_rT_seed1
  clean         mean_nll=35.4595  nll@0=108.6789  nll@3=25.5044  nll@5=10.1731  nll@10=-2.5183
  noisy_low     mean_nll=89.6862  nll@0=230.7248  nll@3=71.2310  nll@5=42.6149  nll@10=14.1743
  noisy_high    mean_nll=378.4654  nll@0=855.9799  nll@3=310.4139  nll@5=204.6845  nll@10=142.7834
  permuted      mean_nll=360.5978  nll@0=798.5223  nll@3=317.2819  nll@5=208.0097  nll@10=118.5772

Evaluating: aggressive_align_rT_seed2
  clean         mean_nll=39.3868  nll@0=117.9600  nll@3=30.0178  nll@5=11.0919  nll@10=-1.5227
  noisy_low    

In [15]:
# ---------- Export aggregated JSON + TSV ----------
# Writes every result row to one JSON file and one tab-separated file.
agg_json_path = os.path.join(OUTPUT_DIR, 'stepB_eval.json')
agg_tsv_path = os.path.join(OUTPUT_DIR, 'stepB_eval.tsv')

# JSON
with open(agg_json_path, 'w') as f:
    json.dump(all_results, f, indent=2)
print(f'Wrote {agg_json_path}')

if all_results:
    # TSV: column order follows the key order of the first result row.
    fieldnames = list(all_results[0].keys())
    with open(agg_tsv_path, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames, delimiter='\t')
        writer.writeheader()
        writer.writerows(all_results)
    print(f'Wrote {agg_tsv_path}')

    # Quick sanity check
    print(f'\nSample row:')
    print(json.dumps(all_results[0], indent=2))
else:
    # Without rows there are no column names for a TSV and no sample to show;
    # say so instead of failing with an IndexError on all_results[0].
    print('\nNo result rows: TSV export and sample row skipped.')

Wrote outputs/stepB_eval.json
Wrote outputs/stepB_eval.tsv

Sample row:
{
  "run_name": "aggressive_align_rT_seed0",
  "run_dir": "outputs/aggressive_align_rT_seed0",
  "seed": 0,
  "eval_corruption_mode": "clean",
  "sigma_rel": 0.0,
  "eval_nll_0": 117.98021793365479,
  "eval_nll_3": 37.440579771995544,
  "eval_nll_5": 14.730976581573486,
  "eval_nll_10": -1.9687858819961548,
  "mean_eval_nll": 42.045747101306915
}


In [16]:
# ---------- Verification ----------
# Every row must carry the metadata fields, one NLL per context size and the
# mean over context sizes. The per-context keys are derived from CONTEXT_SIZES
# so this check stays in sync with what evaluate_run produces.
expected_keys = {
    'run_name', 'run_dir', 'seed', 'eval_corruption_mode', 'sigma_rel',
    'mean_eval_nll',
} | {f'eval_nll_{C}' for C in CONTEXT_SIZES}
for r in all_results:
    missing = expected_keys - set(r.keys())
    assert not missing, f'Missing keys in {r["run_name"]}: {missing}'

# One row per (run, corruption regime) pair.
expected_rows = len(RUN_NAMES) * len(CORRUPTION_REGIMES)
assert len(all_results) == expected_rows, (
    f'Expected {expected_rows} rows, got {len(all_results)}'
)

assert os.path.isfile(agg_json_path), 'stepB_eval.json not found'
assert os.path.isfile(agg_tsv_path), 'stepB_eval.tsv not found'
print(f'All {len(all_results)} rows have correct keys.')
print(f'Files exist: {agg_json_path}, {agg_tsv_path}')
print('Step B DONE.')

All 36 rows have correct keys.
Files exist: outputs/stepB_eval.json, outputs/stepB_eval.tsv
Step B DONE.
