# High-stakes Probe

This notebook trains a linear probe that detects high-stakes conversations, following [McKenzie et al. 2025](https://arxiv.org/abs/2506.10805v2) but using a different training dataset.

In [None]:
# Basic Configuration
from potato.config import DATA_DIR, LOCAL_MODELS


# Model whose activations the probe is trained on, and the layer to read them from.
model_name = LOCAL_MODELS["llama-1b"]
layer = 11

# The concept name doubles as the data sub-directory and the probe file stem.
CONCEPT = "high-stakes"
CONCEPT_DIR = DATA_DIR / CONCEPT
PROBE_PATH = CONCEPT_DIR / f"{CONCEPT}_probe.pkl"

# Class labels and a human-readable description, all passed to `train_probe`
# (the description is later shown via `probe.description`).
pos_class_label, neg_class_label = "high-stakes", "low-stakes"
probe_description = (
    f"A linear probe on {model_name} "
    "detecting whether the conversation is high-stakes."
)

## Dataset Loading and Train/Validation Split

In [None]:
from potato.utils import create_train_test_split
from potato.interfaces.dataset import (
    LabelledDataset,
    Message,
)

# Labelled high-stakes / low-stakes conversations, stored as a JSONL file.
stakes_dataset = CONCEPT_DIR / "combined_deployment_22_04_25.jsonl"
dataset = LabelledDataset.load_from(stakes_dataset)

# Split on `pair_id` — presumably so samples sharing a pair_id stay on the same
# side of the split and paired examples cannot leak across it (confirm in
# `create_train_test_split`).
train_dataset, validation_dataset = create_train_test_split(dataset, split_field="pair_id")

n_train, n_validation = len(train_dataset), len(validation_dataset)
print(f"Read {n_train} samples for training and {n_validation} samples for validation.")

## Training

In [None]:
from potato.training import train_probe
from potato.interfaces.probes import ProbeSpec, ProbeType


# Hyperparameters for the `linear_then_mean` probe. Kept as a named constant so
# they can be tuned in one place instead of inside the call below.
# An alternative, simpler spec is `ProbeSpec(name=ProbeType.sklearn, hyperparams={})`.
PROBE_HYPERPARAMS = {
    "batch_size": 8,
    "epochs": 200,
    "optimizer_args": {"lr": 1e-3, "weight_decay": 1e-2},
    "final_lr": 1e-4,
    "gradient_accumulation_steps": 1,
    "patience": 100,  # presumably early-stopping patience — confirm in train_probe
    "temperature": 0.1,
}

probe = train_probe(
    train_dataset,
    validation_dataset,
    model_name,
    layer,
    start_turn_index=0,  # Include system and user message
    apply_transformations_to_validation_dataset=True,
    pos_class_label=pos_class_label,
    neg_class_label=neg_class_label,
    probe_description=probe_description,
    probe_spec=ProbeSpec(
        name=ProbeType.linear_then_mean,
        hyperparams=PROBE_HYPERPARAMS,
    ),
)

In [None]:
# Store the probe
import pickle

# Use a context manager so the file is flushed and closed deterministically,
# rather than whenever the garbage collector finalizes an anonymous handle.
with open(PROBE_PATH, "wb") as probe_file:
    pickle.dump(probe, probe_file)

## Loading and Evaluating

In [None]:
# `pickle` is imported here as well so this section can be run on its own
# (config cell + this cell) without re-running the training/store cells.
import pickle

from potato.model import LLMModel

# Load the probe saved above. NOTE: unpickling can execute arbitrary code —
# only load probe files from a trusted source (e.g. produced by this notebook).
with open(PROBE_PATH, "rb") as probe_file:
    probe = pickle.load(probe_file)
assert probe.model_name is not None
assert probe.layer is not None
print("Probe initialized:")
print(probe.description)

# Initialize the model so we can compute activations
model = LLMModel.load(probe.model_name)

In [None]:
import yaml

# Imported here too so this cell does not depend on the dataset cell having run.
from potato.interfaces.dataset import Message

# Load test inputs from YAML. Each top-level entry is one conversation: a list of
# mappings whose keys are passed straight to `Message(**msg)`.
with open(CONCEPT_DIR / "test_inputs.yaml", encoding="utf-8") as f:
    raw_inputs = yaml.safe_load(f)
# `yaml.safe_load` returns None for an empty file; fail with a clear message instead
# of an opaque TypeError in the comprehension below.
assert raw_inputs, "test_inputs.yaml is empty or contains no conversations"
inputs = [[Message(**msg) for msg in conversation] for conversation in raw_inputs]

preds = probe.predict_proba_from_inputs(inputs, model=model)
assert len(preds) == len(inputs), f"Expected {len(inputs)} predictions, got {len(preds)}"
for i, (conversation, pred) in enumerate(zip(inputs, preds)):
    print(f"Sample {i}: {pred}")
    print(f"Input: {conversation}")
    print()

In [None]:
# Verifying that the probe works with activation tensors.
# Each conversation is run through the model on its own and the resulting activation
# tensor is scored directly. The prediction from `predict_proba_from_inputs` (`preds`,
# previous cell) is printed alongside so the two paths can be compared.

# NOTE To apply the probe to a HF transformer directly, get the activations tensor
# from activations before layer norm. Either process one item at a time (as below)
# or make sure to apply the attention mask.
from potato.probes.pytorch_probes import filter_activations_by_turns

for i, (conversation, batched_pred) in enumerate(zip(inputs, preds)):
    # Activations at the probe's layer for a batch containing just this conversation.
    activations = model.get_activations([conversation], layer=probe.layer)

    # Presumably restricts activations to tokens in the turns the probe was trained
    # on (start/end turn indices stored on the probe) — confirm in pytorch_probes.
    activations = filter_activations_by_turns(
        activations=activations,
        inputs=[conversation],
        model=model,
        start_turn_index=probe.start_turn_index,
        end_turn_index=probe.end_turn_index,
    )
    tensor_pred = probe.predict_proba_from_activations_tensor(activations.activations[0])
    print(f"Sample {i}: from tensor = {tensor_pred}, from inputs = {batched_pred}")