# Step A — Train baseline + alignment regimes

Trains 3 regimes on the **set-trending-sinusoids** dataset with `abc2` knowledge:
- `baseline` — no alignment
- `aggressive_align_rT` — InfoNCE on posterior mean q(z|C,T)
- `safe_align_rC` — InfoNCE on prior mean q(z|C)

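Each regime is trained once per seed in `SEEDS`, under the run name `{regime}_seed{seed}`; runs that already have a checkpoint are skipped unless `FORCE_RETRAIN` is set.
Outputs go to `outputs/{run_name}/`, which should contain `config.toml`, `metrics.jsonl`, and `model_best.pt`.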

In [24]:
import sys, os

# Find repo root (parent of notebooks/)
_nb_dir = os.path.dirname(os.path.abspath('__file__'))
REPO_ROOT = os.path.abspath(os.path.join(_nb_dir, '..'))
# Fallback: walk up until we find config.py
_d = os.getcwd()
while _d != os.path.dirname(_d):
    if os.path.isfile(os.path.join(_d, 'config.py')):
        REPO_ROOT = _d
        break
    _d = os.path.dirname(_d)

os.chdir(REPO_ROOT)
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

print(f'Working directory: {os.getcwd()}')

Working directory: /Users/louishayot/MVA/VdS_Submission_Final


In [25]:
OUTPUT_DIR = 'outputs'
FORCE_RETRAIN = False

# ---- FAST_DEV: quick sanity check (1 seed, 50 epochs) ----
FAST_DEV = False

if FAST_DEV:
    SEEDS = [0]
    NUM_EPOCHS = 50
else:
    SEEDS = [0, 1, 2]
    NUM_EPOCHS = 1000

In [26]:
from config import Config

def make_base_config(seed=0):
    """Base config matching run_sinusoids.sh INP_abc2 settings."""
    return Config(**{
        # project
        'project_name': 'INPs_sinusoids',
        'seed': seed,
        # training
        'batch_size': 64,
        'num_epochs': NUM_EPOCHS,
        'lr': 1e-3,
        'beta': 1.0,
        'sort_context': False,
        'n_trials': 1,
        'train_split': 'train',
        'val_split': 'val',
        'decay_lr': 10,
        # dataset
        'dataset': 'set-trending-sinusoids',
        'knowledge_type': 'abc2',
        'min_num_context': 0,
        'max_num_context': 10,
        'num_targets': 100,
        'noise': 0.2,
        'x_sampler': 'uniform',
        # knowledge
        'use_knowledge': True,
        'text_encoder': 'set',
        'freeze_llm': True,
        'tune_llm_layer_norms': False,
        'knowledge_dropout': 0.3,
        'knowledge_merge': 'sum',
        'knowledge_dim': 128,
        'knowledge_extractor_num_hidden': 2,
        'knowledge_extractor_hidden_dim': 128,
        # model architecture
        'input_dim': 1,
        'output_dim': 1,
        'hidden_dim': 128,
        'x_transf_dim': 128,
        'x_encoder_num_hidden': 1,
        'xy_encoder_num_hidden': 2,
        'xy_encoder_hidden_dim': 384,
        'xy_self_attention': 'none',
        'xy_self_attention_num_layers': 1,
        'data_agg_func': 'mean',
        'latent_encoder_num_hidden': 1,
        'decoder_hidden_dim': 128,
        'decoder_num_hidden': 3,
        'decoder_activation': 'gelu',
        'train_num_z_samples': 1,
        'test_num_z_samples': 32,
        'path': 'latent',
        # alignment (overridden per regime)
        'alignment_mode': 'none',
        'alignment_lambda': 0.0,
        'alignment_temperature': 0.1,
    })

In [27]:
REGIMES = {
    'baseline': {
        'alignment_mode': 'none',
        'alignment_lambda': 0.0,
    },
    'aggressive_align_rT': {
        'alignment_mode': 'rT',
        'alignment_lambda': 0.1,
        'alignment_temperature': 0.1,
    },
    'safe_align_rC': {
        'alignment_mode': 'rC',
        'alignment_lambda': 0.01,
        'alignment_temperature': 0.2,
    },
}

In [None]:
import json
from models.train import train_from_config

for regime_name, overrides in REGIMES.items():
    for seed in SEEDS:
        run_name = f'{regime_name}_seed{seed}'
        ckpt_path = os.path.join(OUTPUT_DIR, run_name, 'model_best.pt')

        if os.path.exists(ckpt_path) and not FORCE_RETRAIN:
            print(f'SKIP {run_name} (checkpoint exists)')
            continue

        print(f'\n{"="*60}')
        print(f'Training: {run_name}')
        print(f'{"="*60}')

        config = make_base_config(seed=seed)
        for k, v in overrides.items():
            setattr(config, k, v)

        best_loss, save_dir = train_from_config(
            config, output_dir=OUTPUT_DIR, run_name=run_name, use_wandb=False
        )

        # Print summary from last metrics line
        metrics_path = os.path.join(save_dir, 'metrics.jsonl')
        if os.path.exists(metrics_path):
            with open(metrics_path) as f:
                lines = f.readlines()
            last_train = None
            last_eval = None
            for line in reversed(lines):
                d = json.loads(line)
                if last_eval is None and 'eval_loss' in d:
                    last_eval = d
                if last_train is None and 'train_predictive_nll' in d:
                    last_train = d
                if last_train and last_eval:
                    break
            if last_train:
                msg = f'  train_predictive_nll={last_train["train_predictive_nll"]:.4f}'
                if 'train_alignment_loss' in last_train:
                    msg += f'  align_loss={last_train["train_alignment_loss"]:.4f}'
                print(msg)
            if last_eval:
                print(f'  eval_loss={last_eval["eval_loss"]:.4f}')

        print(f'  best_val_loss={best_loss:.4f}  save_dir={save_dir}')


Training: baseline_seed0
Using device: cpu
Trainable parameters:
xy_encoder.pairer.layers.0.weight
xy_encoder.pairer.layers.0.bias
xy_encoder.pairer.layers.1.weight
xy_encoder.pairer.layers.1.bias
xy_encoder.pairer.layers.2.weight
xy_encoder.pairer.layers.2.bias
latent_encoder.knowledge_encoder.text_encoder.h1.layers.0.weight
latent_encoder.knowledge_encoder.text_encoder.h1.layers.0.bias
latent_encoder.knowledge_encoder.text_encoder.h1.layers.1.weight
latent_encoder.knowledge_encoder.text_encoder.h1.layers.1.bias
latent_encoder.knowledge_encoder.text_encoder.h2.layers.0.weight
latent_encoder.knowledge_encoder.text_encoder.h2.layers.0.bias
latent_encoder.knowledge_encoder.text_encoder.h2.layers.1.weight
latent_encoder.knowledge_encoder.text_encoder.h2.layers.1.bias
latent_encoder.knowledge_encoder.knowledge_extractor.layers.0.weight
latent_encoder.knowledge_encoder.knowledge_extractor.layers.0.bias
latent_encoder.knowledge_encoder.knowledge_extractor.layers.1.weight
latent_encoder.know

In [None]:
# Verify outputs
print('Run outputs:')
for regime_name in REGIMES:
    for seed in SEEDS:
        run_name = f'{regime_name}_seed{seed}'
        run_dir = os.path.join(OUTPUT_DIR, run_name)
        has_ckpt = os.path.exists(os.path.join(run_dir, 'model_best.pt'))
        has_metrics = os.path.exists(os.path.join(run_dir, 'metrics.jsonl'))
        has_config = os.path.exists(os.path.join(run_dir, 'config.toml'))
        status = 'OK' if (has_ckpt and has_metrics and has_config) else 'INCOMPLETE'
        print(f'  {run_name:40s}  ckpt={has_ckpt}  metrics={has_metrics}  config={has_config}  [{status}]')

Run outputs:
  baseline_seed0                            ckpt=False  metrics=False  config=False  [INCOMPLETE]
  baseline_seed1                            ckpt=False  metrics=False  config=False  [INCOMPLETE]
  baseline_seed2                            ckpt=False  metrics=False  config=False  [INCOMPLETE]
  aggressive_align_rT_seed0                 ckpt=False  metrics=False  config=False  [INCOMPLETE]
  aggressive_align_rT_seed1                 ckpt=False  metrics=False  config=False  [INCOMPLETE]
  aggressive_align_rT_seed2                 ckpt=False  metrics=False  config=False  [INCOMPLETE]
  safe_align_rC_seed0                       ckpt=False  metrics=False  config=False  [INCOMPLETE]
  safe_align_rC_seed1                       ckpt=False  metrics=False  config=False  [INCOMPLETE]
  safe_align_rC_seed2                       ckpt=False  metrics=False  config=False  [INCOMPLETE]
