# Layer-aware contrastive training (new Trainer + evaluation)

This notebook mirrors `b_contrastive_training_with_new_trainer.ipynb`, but uses layer-aware components:
- `activation_research.model.LayerAwareProgressiveCompressor`
- `activation_research.trainer.LayerAwareContrastiveTrainer`

It assumes activations + eval results are already on disk (for example a `.zarr` store).

In [None]:
import os
import sys

from torch.utils.data import DataLoader
from loguru import logger
from tqdm.auto import tqdm

from activation_logging.activation_parser import ActivationParser
from activation_research.model import LayerAwareProgressiveCompressor
from activation_research.trainer import LayerAwareContrastiveTrainer, LayerAwareContrastiveTrainerConfig
from activation_research.metric_evaluator import MultiMetricHallucinationEvaluator

# Drop loguru's existing handlers, then log to stdout at INFO level and above.
# Order matters: calling remove() after add() would discard the stdout handler too.
logger.remove()
logger.add(sys.stdout, level="INFO")

In [None]:
# ---- Paths (edit these for your environment) ----
# All three paths are handed to ActivationParser in the dataset-loading cell.
inference_json = 'shared/goodwiki_jsonv2/generation.jsonl'
eval_json = 'shared/goodwiki.zarr/eval_results.json'
activations_path = 'shared/goodwiki.zarr/activations.zarr'

# ---- Dataset parameters ----
backend = 'zarr'  # 'zarr' or 'wds' or 'auto'
relevant_layers = list(range(14, 30))  # layers 14..29 inclusive; used for the training/validation datasets
target_layers = [22, 26]  # used for embedding-based OOD evaluation

# Treat one class as outlier for certain evaluations.
# If outlier_class=1, we use non-halu samples as baseline (ID).
# The same value is also passed to the trainer config as `ignore_label`.
outlier_class = 1

# Optional: if set, one of the two views is always from this layer index.
# None leaves the layer choice to the dataset.
fixed_layer = None

# ---- Layer-aware model hyperparams ----
device = 'auto'  # 'auto', 'cuda', 'cpu'
input_dim = 4096  # NOTE(review): presumably the activation (hidden-state) width — confirm against the stored activations
final_dim = 512
layer_embed_dim = 128
conditioning = 'film_both'  # film_in | film_out | film_both | positional | concatenate

# If parser emits absolute layer ids, this should cover max index.
# Works for both absolute and compact indices in practice.
# With relevant_layers = 14..29 this evaluates to 30.
num_layers = max(relevant_layers) + 1

# ---- Training hyperparams ----
max_epochs = 50
batch_size = 512
lr = 1e-5
temperature = 0.25
# Fixed number of steps per epoch (e.g., 200 or 1000).
# NOTE(review): presumably None falls back to the dataset length — confirm in the trainer config.
steps_per_epoch_override = 200

# Worker settings are forwarded to the trainer config; num_workers is also reused by the OOD evaluator.
num_workers = 30
persistent_workers = True

# Directory the trainer writes checkpoints to; the snapshot-evaluation cells read from it.
checkpoint_dir = os.path.join('checkpoints', 'contrastive_layer_aware')

In [None]:
# ---- Load metadata + build train/test datasets ----
# The parser joins generation records, eval results and the activation store.
ap = ActivationParser(
    inference_json=inference_json,
    eval_json=eval_json,
    activations_path=activations_path,
    verbose=False,
)

# Both splits are built with the same layer selection and storage backend,
# so the shared arguments are spelled out once.
_split_kwargs = {
    'relevant_layers': relevant_layers,
    'fixed_layer': fixed_layer,
    'backend': backend,
}
train_dataset = ap.get_dataset('train', **_split_kwargs)
test_dataset = ap.get_dataset('test', **_split_kwargs)

# Quick sanity check on split sizes.
print('train:', len(train_dataset))
print('test :', len(test_dataset))

In [None]:
# ---- Layer-aware contrastive encoder ----
# Collect the constructor arguments first so the architecture settings can be
# read at a glance; the values come from the configuration cell, except the
# input dropout rate which is fixed here.
_encoder_kwargs = {
    'num_layers': num_layers,
    'input_dim': input_dim,
    'final_dim': final_dim,
    'layer_embed_dim': layer_embed_dim,
    'conditioning': conditioning,
    'input_dropout': 0.3,
}
model = LayerAwareProgressiveCompressor(**_encoder_kwargs)

# Bare expression: rendered as the cell output when run in a notebook.
model

In [None]:
# ---- Train with the layer-aware Trainer API ----
# Build the trainer configuration from the hyperparameters defined in the
# configuration cell, print it for the record, then train `model` using the
# test split as the validation set.
config = LayerAwareContrastiveTrainerConfig(
    max_epochs=max_epochs,
    batch_size=batch_size,
    lr=lr,
    temperature=temperature,
    steps_per_epoch_override=steps_per_epoch_override,
    device=device,
    num_workers=num_workers,
    persistent_workers=persistent_workers,
    checkpoint_dir=checkpoint_dir,
    # NOTE(review): presumably "save a checkpoint every epoch, write a snapshot
    # every 10 epochs, keep the 5 most recent snapshots" — confirm against
    # LayerAwareContrastiveTrainerConfig.
    save_every=1,
    snapshot_every=10,
    snapshot_keep_last=5,

    # Supervised contrastive: uses `halu` labels; ignore the configured outlier class.
    use_labels=True,
    ignore_label=outlier_class,

    # Keeps DataLoader workers alive for map-style datasets (disabled automatically for IterableDataset).
    use_infinite_index_stream=True,
    use_infinite_index_stream_eval=True,
)
# Echo the configuration so the settings of this run are visible in the output.
print(config)

# `trainer` (its `.model` and `.device`) is reused by the evaluation cells below.
trainer = LayerAwareContrastiveTrainer(model, config=config)
trainer.fit(train_dataset=train_dataset, val_dataset=test_dataset)

In [None]:
# ---- OOD evaluation with the new evaluator abstraction ----
# Evaluation datasets are restricted to `target_layers` (training used
# `relevant_layers`); the train split serves as the baseline for the metrics.
train_dataset_for_inference = ap.get_dataset(
    'train',
    relevant_layers=target_layers,
    fixed_layer=fixed_layer,
    backend=backend,
)
eval_dataset = ap.get_dataset(
    'test',
    relevant_layers=target_layers,
    fixed_layer=fixed_layer,
    backend=backend,
)

# No shuffling: evaluation does not need a randomized sample order.
train_loader_for_baseline = DataLoader(train_dataset_for_inference, batch_size=64, shuffle=False)
eval_loader = DataLoader(eval_dataset, batch_size=64, shuffle=False)

model_for_eval = trainer.model
# Switch to inference mode so the model's input dropout is disabled while
# embeddings are computed. This matches the snapshot-evaluation cell, which
# calls .eval() on every reloaded checkpoint; without it the two evaluations
# could disagree for the same weights.
model_for_eval.eval()

ood_eval = MultiMetricHallucinationEvaluator(
    activation_parser_df=ap.df,
    train_data_loader=train_loader_for_baseline,
    layers=None,
    batch_size=256,
    sub_batch_size=64,
    device=str(trainer.device),
    num_workers=num_workers,
    persistent_workers=False,
    outlier_class=outlier_class,
    # Metrics may be given by name or as a dict carrying metric-specific kwargs.
    metrics=[
        'cosine',
        'mds',
        {
            'metric': 'knn',
            'kwargs': {
                'k': 50,
                'metric': 'euclidean',
                'calibrate_k': True,
                'k_candidates': [50, 100, 200, 500, 1000],
                'max_train_size': 200000,
                'sample_seed': 0,
            },
            'train_selection': 'all',
        },
    ],
)
ood_stats = ood_eval.compute(eval_loader, model_for_eval)
print('OOD metrics:', ood_stats)

## Evaluate across epoch snapshots

Load each saved snapshot checkpoint, run OOD evaluation, and compare metrics across training epochs.

In [None]:
import re
import glob
import torch
import pandas as pd

# ---- Discover available snapshots ----
snapshot_pattern = os.path.join(checkpoint_dir, 'layer_aware_contrastive_snapshot_epoch_*.pt')


def _snapshot_sort_key(path):
    """Sort key that orders snapshot paths by their numeric epoch.

    A plain lexicographic sort puts ``epoch_10`` before ``epoch_2`` when the
    epoch number is not zero-padded. Paths without a parsable epoch sort after
    all numbered ones, ordered by path.
    """
    m = re.search(r'epoch_(\d+)', os.path.basename(path))
    if m:
        return (0, int(m.group(1)), path)
    return (1, 0, path)


# Chronological order: snapshots are listed and evaluated from earliest to latest epoch.
snapshot_files = sorted(glob.glob(snapshot_pattern), key=_snapshot_sort_key)

# The rolling "last" checkpoint, when present, is evaluated after all snapshots.
last_ckpt = os.path.join(checkpoint_dir, 'layer_aware_contrastive_last.pt')
if os.path.exists(last_ckpt):
    snapshot_files.append(last_ckpt)

def _parse_epoch(path: str) -> str:
    basename = os.path.basename(path)
    m = re.search(r'epoch_(\d+)', basename)
    if m:
        return f"epoch_{int(m.group(1))}"
    if 'last' in basename:
        return 'last'
    return basename

# Pair every checkpoint path with its display label, then list what will be evaluated.
snapshot_info = list(zip(snapshot_files, map(_parse_epoch, snapshot_files)))

print(f'Found {len(snapshot_info)} checkpoint(s) to evaluate:')
for path, label in snapshot_info:
    print(f'  {label:>12s}  ->  {path}')

In [None]:
# ---- Evaluate each snapshot ----
# Rebuild the evaluation datasets only when they are not already defined in
# this session (e.g. the single-model evaluation cell was skipped).
if 'train_dataset_for_inference' not in globals():
    train_dataset_for_inference = ap.get_dataset(
        'train', relevant_layers=target_layers, fixed_layer=fixed_layer, backend=backend,
    )
if 'eval_dataset' not in globals():
    eval_dataset = ap.get_dataset(
        'test', relevant_layers=target_layers, fixed_layer=fixed_layer, backend=backend,
    )

train_loader_for_baseline = DataLoader(train_dataset_for_inference, batch_size=64, shuffle=False)
eval_loader = DataLoader(eval_dataset, batch_size=64, shuffle=False)

all_results = []

# Reuse the trainer's device when training ran in this session; otherwise
# fall back to CUDA if available, else CPU.
eval_device = str(trainer.device) if 'trainer' in globals() else ('cuda' if torch.cuda.is_available() else 'cpu')
map_location = eval_device

for ckpt_path, epoch_label in tqdm(snapshot_info, desc='Evaluating snapshots'):
    # NOTE(review): depending on the PyTorch version, torch.load may unpickle
    # arbitrary objects — only load checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=map_location)

    # Recreate the architecture with the same hyperparameters used for
    # training so the saved state dict loads cleanly.
    snapshot_model = LayerAwareProgressiveCompressor(
        num_layers=num_layers,
        input_dim=input_dim,
        final_dim=final_dim,
        layer_embed_dim=layer_embed_dim,
        conditioning=conditioning,
        input_dropout=0.3,
    )
    snapshot_model.load_state_dict(ckpt['model_state_dict'])
    snapshot_model = snapshot_model.to(eval_device)
    snapshot_model.eval()  # inference mode: disables dropout

    # A fresh evaluator (and fresh metric specs) per snapshot, so no state is
    # shared between checkpoints.
    ood_eval = MultiMetricHallucinationEvaluator(
        activation_parser_df=ap.df,
        train_data_loader=train_loader_for_baseline,
        layers=None,
        batch_size=256,
        sub_batch_size=64,
        device=eval_device,
        num_workers=num_workers,
        persistent_workers=False,
        outlier_class=outlier_class,
        metrics=[
            'cosine',
            'mds',
            {
                'metric': 'knn',
                'kwargs': {
                    'k': 50,
                    'metric': 'euclidean',
                    'calibrate_k': True,
                    'k_candidates': [50, 100, 200, 500, 1000],
                    'max_train_size': 200000,
                    'sample_seed': 0,
                },
                'train_selection': 'all',
            },
        ],
    )
    stats = ood_eval.compute(eval_loader, snapshot_model)

    # Numeric sort key: the first run of digits in the label. Labels without
    # digits (e.g. 'last') get the sentinel 9999 so they sort after every
    # numbered epoch.
    digits = re.search(r'\d+', epoch_label)
    epoch_num = int(digits.group()) if digits else 9999
    result_row = {'checkpoint': epoch_label, 'epoch': epoch_num, 'path': ckpt_path}
    result_row.update(stats)
    all_results.append(result_row)

    logger.info(f'[{epoch_label}] {stats}')

results_df = pd.DataFrame(all_results)
if results_df.empty:
    # With no checkpoints the frame has no 'epoch' column, so sorting on it
    # would raise KeyError; keep the empty frame instead.
    logger.warning('No checkpoints were evaluated; the results table is empty.')
else:
    results_df = results_df.sort_values('epoch').reset_index(drop=True)
print('\n=== Results across snapshots ===')
results_df

In [None]:
# ---- Visualize metrics across epochs ----
import matplotlib.pyplot as plt

# Everything that is not bookkeeping and holds numbers is treated as a metric.
meta_cols = {'checkpoint', 'epoch', 'path'}
metric_cols = [c for c in results_df.columns if c not in meta_cols and pd.api.types.is_numeric_dtype(results_df[c])]

if not metric_cols:
    print('No numeric metric columns found to plot.')
else:
    n_metrics = len(metric_cols)
    # squeeze=False keeps `axes` a 2-D array even for a single metric, so
    # flatten() always works.
    fig, axes = plt.subplots(1, n_metrics, figsize=(5 * n_metrics, 4), squeeze=False)
    axes = axes.flatten()

    # Plot against ordinal position rather than the raw 'epoch' value: the
    # 'last' checkpoint carries the sentinel epoch 9999, which would stretch
    # a numeric x-axis and squeeze all real epochs into one corner. Rows are
    # ordered by 'epoch' so the points stay in chronological order.
    plot_df = results_df.sort_values('epoch')
    x_positions = list(range(len(plot_df)))
    x_labels = plot_df['checkpoint'].astype(str).tolist()

    for ax, col in zip(axes, metric_cols):
        ax.plot(x_positions, plot_df[col].to_numpy(), marker='o', linewidth=1.5)
        ax.set_xticks(x_positions)
        ax.set_xticklabels(x_labels, rotation=45, ha='right')
        ax.set_xlabel('Checkpoint')
        ax.set_ylabel(col)
        ax.set_title(col)
        ax.grid(True, alpha=0.3)

    plt.suptitle('OOD Metrics vs Training Epoch (Layer-aware)', fontsize=14, y=1.02)
    plt.tight_layout()
    plt.show()