# Contrastive training (new Trainer + evaluation)

This notebook mirrors the flow of `a_preparing_training_precise_qa_halu_as_outlier(1).ipynb`, but uses:
- `activation_research.trainer.ContrastiveTrainer` for contrastive training
- `activation_research.metric_evaluator.MultiMetricHallucinationEvaluator` for OOD evaluation with several metrics (cosine, Mahalanobis/MDS, kNN) computed from one shared set of embeddings

It assumes the generation outputs, evaluation results, and activations already exist on disk (e.g., activations in a `.zarr` store).

In [None]:
import os
import sys

from torch.utils.data import DataLoader
from loguru import logger

from activation_logging.activation_parser import ActivationParser
from activation_research.model import ProgressiveCompressor
from activation_research.trainer import ContrastiveTrainer, ContrastiveTrainerConfig
from activation_research.metric_evaluator import MultiMetricHallucinationEvaluator

# Send loguru output to the notebook's stdout at INFO level. Calling
# `remove()` with no argument first drops every previously registered sink,
# so re-running this cell does not duplicate log lines.
logger.remove()
_stdout_log_handler = logger.add(sys.stdout, level="INFO")


In [None]:
# ---- Paths (edit these for your environment) ----
inference_json = 'shared/goodwiki_jsonv2/generation.jsonl'
eval_json = 'shared/goodwiki.zarr/eval_results.json'
activations_path = 'shared/goodwiki.zarr/activations.zarr'

# ---- Dataset parameters ----
backend = 'zarr'  # 'zarr' or 'wds' or 'auto'
relevant_layers = list(range(14, 30))  # layers 14..29, used for contrastive training
target_layers = [22, 26]  # used for embedding-based OOD evaluation

# Treat one value of the `halu` label as the outlier class for OOD evaluation.
#   outlier_class=1 -> non-hallucinated samples form the in-distribution (ID) baseline.
#   outlier_class=0 -> hallucinated samples form the ID baseline.
outlier_class = 1
if outlier_class not in (0, 1):
    # Fail fast, before any expensive loading/training, on a misconfiguration.
    raise ValueError(f'outlier_class must be 0 or 1, got {outlier_class!r}')

# Optional: if set, one of the two views is always from this index-in-relevant_layers
# (matches ActivationParser semantics).
# NOTE(review): the evaluation cell reuses `fixed_layer` together with
# `relevant_layers=target_layers`; if this is an index into the layer list,
# confirm it is also valid for `target_layers`.
fixed_layer = None

# ---- Model / training hyperparams ----
device = 'auto'  # 'auto', 'cuda', 'cpu'
input_dim = 4096
final_dim = 512

max_epochs = 50
batch_size = 512
lr = 1e-5
temperature = 0.25
steps_per_epoch_override = None  # e.g., 1000 for fixed steps/epoch

# To avoid restarting DataLoader workers every epoch, keep num_workers > 0 and
# persistent_workers=True. Lower the cap (30) if your machine is memory
# constrained; it is also clamped to the available CPU count so smaller
# machines are not oversubscribed.
num_workers = min(30, os.cpu_count() or 1)
persistent_workers = True

checkpoint_dir = os.path.join('checkpoints', 'contrastive_regular')


In [None]:
# ---- Load metadata + build train/test datasets ----
ap = ActivationParser(
    inference_json=inference_json,
    eval_json=eval_json,
    activations_path=activations_path,
    verbose=False,
)

# Both splits use the same layer selection, fixed-view setting and storage backend.
split_kwargs = dict(
    relevant_layers=relevant_layers,
    fixed_layer=fixed_layer,
    backend=backend,
)
# Built in order: 'train' first, then 'test'.
train_dataset, test_dataset = (
    ap.get_dataset(split_name, **split_kwargs) for split_name in ('train', 'test')
)

print('train:', len(train_dataset))
print('test :', len(test_dataset))


In [None]:
# ---- Regular contrastive encoder ----
# Encoder from `input_dim`-wide activations to a `final_dim` embedding;
# `input_dropout` presumably applies dropout to the input features (TODO confirm).
encoder_kwargs = dict(
    input_dim=input_dim,
    final_dim=final_dim,
    input_dropout=0.3,
)
model = ProgressiveCompressor(**encoder_kwargs)
model  # display the architecture as the cell output

In [None]:
# ---- Train with the new Trainer API ----
# Trainer settings are grouped by concern and merged into one config object.
optimization_settings = dict(
    max_epochs=max_epochs,
    batch_size=batch_size,
    lr=lr,
    temperature=temperature,
    steps_per_epoch_override=steps_per_epoch_override,
    device=device,
)
loader_settings = dict(
    num_workers=num_workers,
    persistent_workers=persistent_workers,
    # Keeps DataLoader workers alive for map-style datasets (disabled automatically for IterableDataset).
    use_infinite_index_stream=True,
    use_infinite_index_stream_eval=True,
)
checkpoint_settings = dict(
    checkpoint_dir=checkpoint_dir,
    save_every=1,
    snapshot_every=10,
    snapshot_keep_last=5,
)
# Supervised contrastive: uses `halu` labels; ignore the configured outlier class.
label_settings = dict(
    use_labels=True,
    ignore_label=outlier_class,
)
config = ContrastiveTrainerConfig(
    **optimization_settings,
    **loader_settings,
    **checkpoint_settings,
    **label_settings,
)
print(config)

trainer = ContrastiveTrainer(model, config=config)
# NOTE(review): the test split serves as the validation set here and again as the
# OOD evaluation set below. If the trainer selects checkpoints on validation
# metrics, the final OOD numbers are not fully held out.
trainer.fit(train_dataset=train_dataset, val_dataset=test_dataset)


In [None]:
# ---- OOD evaluation with the new evaluator abstraction ----
# Build an ID-only baseline dataset for embeddings (mirrors the old notebook logic):
# only rows whose `halu` label is NOT the outlier class are kept.
if outlier_class not in (0, 1):
    # Any other value would otherwise skip the filter and silently use *all*
    # samples (outliers included) as the in-distribution baseline.
    raise ValueError(f'outlier_class must be 0 or 1, got {outlier_class!r}')

ap_id = ActivationParser(
    inference_json=inference_json,
    eval_json=eval_json,
    activations_path=activations_path,
    verbose=False,
)
# outlier_class=0 -> keep hallucinated rows; outlier_class=1 -> keep non-hallucinated rows.
# An equality comparison matches the original boolean masks when `halu` is bool,
# and stays correct if it is stored as 0/1 ints. With ints, `~` would be a bitwise NOT
# (giving -1/-2) rather than a logical negation.
keep_halu = outlier_class == 0
ap_id.df = ap_id.df[ap_id.df['halu'] == keep_halu]
if ap_id.df.empty:
    raise ValueError(
        f'No in-distribution samples left after filtering with outlier_class={outlier_class}'
    )

train_dataset_for_inference = ap_id.get_dataset(
    'train',
    relevant_layers=target_layers,
    fixed_layer=fixed_layer,
    backend=backend,
)
eval_dataset = ap.get_dataset(
    'test',
    relevant_layers=target_layers,
    fixed_layer=fixed_layer,
    backend=backend,
)

# DataLoaders are used by the evaluator primarily for the `.dataset` attribute.
train_loader_for_baseline = DataLoader(train_dataset_for_inference, batch_size=64, shuffle=False)
eval_loader = DataLoader(eval_dataset, batch_size=64, shuffle=False)

model_for_eval = trainer.model  # already on the right device
# The encoder was built with input_dropout=0.3. Switch to inference mode explicitly
# rather than relying on whichever mode training left the model in.
# This is harmless if the evaluator also does so.
model_for_eval.eval()

# Multi-metric OOD evaluation using shared embeddings (computed once).
ood_eval = MultiMetricHallucinationEvaluator(
    activation_parser_df=ap.df,
    train_data_loader=train_loader_for_baseline,
    layers=None,
    batch_size=256,
    sub_batch_size=64,
    device=str(trainer.device),
    num_workers=num_workers,
    persistent_workers=False,
    outlier_class=outlier_class,
    metrics=[
        'cosine',
        'mds',  # presumably Mahalanobis distance score — confirm in metric_evaluator
        {'metric': 'knn', 'kwargs': {'k': 5, 'metric': 'euclidean'}},
    ],
)
ood_stats = ood_eval.compute(eval_loader, model_for_eval)
print('OOD metrics:', ood_stats)
