# High-stakes Probe

This notebook trains a high-stakes probe as described in [McKenzie et al. 2025](https://arxiv.org/abs/2506.10805v2).

In [None]:
# Basic Configuration
from tuberlens.config import DATA_DIR, LOCAL_MODELS


# Model to probe (looked up in the local model registry) and the layer index
# that is passed to training further down.
model_name = LOCAL_MODELS["llama-1b"]
layer = 11

# The concept name doubles as the sub-directory of DATA_DIR and as the prefix
# of the pickled probe's file name.
CONCEPT = "high-stakes"
CONCEPT_DIR = DATA_DIR / CONCEPT
PROBE_PATH = CONCEPT_DIR / f"{CONCEPT}_probe.pkl"

# Class labels handed to the dataset loaders and to the probe trainer.
pos_class_label, neg_class_label = "high-stakes", "low-stakes"

# Human-readable description passed to train_probe as probe_description.
probe_description = (
    f"A linear probe on {model_name} "
    "detecting whether the conversation is high-stakes."
)

## Dataset Creation

In [None]:
# Simplified dataset (included in this repo)
from tuberlens.utils import create_train_test_split
from tuberlens.interfaces.dataset import LabelledDataset

# Location of the bundled JSONL file inside the concept directory.
stakes_dataset = CONCEPT_DIR / "combined_deployment_22_04_25.jsonl"

dataset = LabelledDataset.load_from(
    stakes_dataset,
    pos_class_label=pos_class_label,
    neg_class_label=neg_class_label,
)

# The split is keyed on the "pair_id" field.
# NOTE(review): presumably this keeps samples sharing a pair_id in the same
# split -- confirm in create_train_test_split.
train_dataset, validation_dataset = create_train_test_split(dataset, split_field="pair_id")

print(
    "Read {} samples for training and {} samples for validation.".format(
        len(train_dataset), len(validation_dataset)
    )
)

In [None]:
# Alternative: Use the original dataset
# (link can be found in https://github.com/Arrrlex/models-under-pressure/tree/main)
from tuberlens.interfaces.dataset import download_and_load_dataset

# This cell REPLACES train_dataset / validation_dataset produced by the
# simplified-dataset cell. Set the switch to False to keep that split when
# running the whole notebook top to bottom.
USE_ORIGINAL_DATASET = True

# URLs for the datasets
train_url = "https://pub-fd16e959a4f14ca48765b437c9425ba6.r2.dev/training/prompts_4x/train.jsonl"
test_url = "https://pub-fd16e959a4f14ca48765b437c9425ba6.r2.dev/training/prompts_4x/test.jsonl"

if USE_ORIGINAL_DATASET:
    train_dataset = download_and_load_dataset(
        train_url, pos_class_label=pos_class_label, neg_class_label=neg_class_label
    )
    validation_dataset = download_and_load_dataset(
        test_url, pos_class_label=pos_class_label, neg_class_label=neg_class_label
    )
    print(f"Read {len(train_dataset)} samples for training and {len(validation_dataset)} samples for validation.")
else:
    print("Skipping download: keeping the existing train/validation split.")

## Training

In [None]:
from tuberlens.training import train_probe
from tuberlens.interfaces.probes import ProbeSpec, ProbeType


# Probe type to train; no hyperparameter overrides are supplied.
linear_softmax_spec = ProbeSpec(
    name=ProbeType.linear_then_softmax,
    hyperparams={},
)

# Train on the current train/validation split, using the model and layer
# chosen in the configuration cell.
probe = train_probe(
    train_dataset,
    validation_dataset,
    model_name,
    layer,
    # start_turn_index=0,  # Include system and user message
    probe_spec=linear_softmax_spec,
    probe_description=probe_description,
    pos_class_label=pos_class_label,
    neg_class_label=neg_class_label,
)

In [None]:
# Store the probe
import pickle

# Use a context manager so the file handle is flushed and closed even if
# pickling raises; a bare open() here would leave the handle dangling.
with open(PROBE_PATH, "wb") as probe_file:
    pickle.dump(probe, probe_file)

## Loading and Evaluating

In [None]:
import pickle

from tuberlens.model import LLMModel

# NOTE(security): unpickling can execute arbitrary code. Only load probe files
# that you created yourself or that come from a trusted source.
# The context manager guarantees the file handle is closed after loading.
with open(PROBE_PATH, "rb") as probe_file:
    probe = pickle.load(probe_file)

# The model name is needed to load the model below; the layer is needed later
# when activations are extracted for the probe.
assert probe.model_name is not None
assert probe.layer is not None
print("Probe initialized:")
print(probe.description)

# Initialize the model so we can compute activations
model = LLMModel.load(probe.model_name)

In [None]:
import yaml
from tuberlens.interfaces.dataset import Message

# Load test inputs from YAML
# The encoding is given explicitly: open() otherwise falls back to the
# platform's default encoding, which can mis-decode non-ASCII text.
with open(CONCEPT_DIR / "test_inputs.yaml", encoding="utf-8") as f:
    raw_inputs = yaml.safe_load(f)

# Each top-level YAML entry is a list of mappings; every mapping is expanded
# into the keyword arguments of one Message.
inputs = [[Message(**msg) for msg in pair] for pair in raw_inputs]

# One prediction per input, printed next to the input that produced it.
preds = probe.predict_proba_from_inputs(inputs, model=model)
for i, pred in enumerate(preds):
    print(f"Sample {i}: {pred}")
    print(f"Input: {inputs[i]}")
    print()

### Using Proper Evaluation Datasets

Here we use the evaluation datasets from the original paper. This lets us compare the performance of our trained probe with that of the probe from the paper.

In [None]:
from tuberlens.interfaces.dataset import download_and_load_dataset

# Switch between the "dev" (False) and "test" (True) evaluation splits.
USE_TEST = False

# URLs as specified in the readme of https://github.com/Arrrlex/models-under-pressure
# (We are selecting the balanced versions as in the paper)
EVAL_DATASETS_URLS = {
    "dev": {
        "anthropic": "https://pub-fd16e959a4f14ca48765b437c9425ba6.r2.dev/evals/dev/anthropic_balanced_apr_23.jsonl",
        "mt": "https://pub-fd16e959a4f14ca48765b437c9425ba6.r2.dev/evals/dev/mt_balanced_apr_30.jsonl",
        "mts": "https://pub-fd16e959a4f14ca48765b437c9425ba6.r2.dev/evals/dev/mts_balanced_apr_22.jsonl",
        "toolace": "https://pub-fd16e959a4f14ca48765b437c9425ba6.r2.dev/evals/dev/toolace_balanced_apr_22.jsonl",
    },
    "test": {
        "anthropic": "https://pub-fd16e959a4f14ca48765b437c9425ba6.r2.dev/evals/test/anthropic_test_balanced_apr_23.jsonl",
        "mt": "https://pub-fd16e959a4f14ca48765b437c9425ba6.r2.dev/evals/test/mt_test_balanced_apr_30.jsonl",
        "mts": "https://pub-fd16e959a4f14ca48765b437c9425ba6.r2.dev/evals/test/mts_test_balanced_apr_22.jsonl",
        "toolace": "https://pub-fd16e959a4f14ca48765b437c9425ba6.r2.dev/evals/test/toolace_test_balanced_apr_22.jsonl",
        "mental_health": "https://pub-fd16e959a4f14ca48765b437c9425ba6.r2.dev/evals/test/mental_health_test_balanced_apr_22.jsonl",
        "redteaming": "https://pub-fd16e959a4f14ca48765b437c9425ba6.r2.dev/evals/test/aya_redteaming_balanced.csv",
    },
}

# First we download all evaluation datasets
split = "test" if USE_TEST else "dev"
eval_datasets = {
    name: download_and_load_dataset(url, pos_class_label=pos_class_label, neg_class_label=neg_class_label)
    for name, url in EVAL_DATASETS_URLS[split].items()
}

# Print dataset sizes
# The loop variable is deliberately not called `dataset`: a for-loop at cell
# level rebinds notebook globals, and `dataset` already names the full
# labelled dataset loaded in the dataset-creation section.
for name, eval_dataset in eval_datasets.items():
    print(f"{name}: {len(eval_dataset)} samples")


In [None]:
from tuberlens.evaluation import get_performances
from tuberlens.interfaces.dataset import subsample_balanced_subset

# Cap on samples per evaluation dataset for a quicker run; None evaluates on
# everything. The cap is split evenly between the two classes.
max_samples = 100

performances = get_performances(
    probe,
    {
        name: (
            eval_dataset
            if max_samples is None
            else subsample_balanced_subset(eval_dataset, n_per_class=max_samples // 2)
        )
        for name, eval_dataset in eval_datasets.items()
    },
)
performances


### Direct Probing of HuggingFace Models

In [None]:
# Verifying that the probe works with activation tensors

# NOTE To apply the probe to a HF transformer directly, get the activations tensor
# from activations before layer norm. Either process one item at a time or make sure
# to apply the attention mask.
from tuberlens.probes.pytorch_probes import filter_activations_by_turns

for inp in inputs:
    # One conversation at a time, wrapped as a single-item batch.
    layer_activations = model.get_activations([inp], layer=probe.layer)

    # Restrict the activations to the turn range stored on the probe.
    activations = filter_activations_by_turns(
        activations=layer_activations,
        inputs=[inp],
        model=model,
        start_turn_index=probe.start_turn_index,
        end_turn_index=probe.end_turn_index,
    )

    # Index 0 selects the only item of the single-item batch.
    item_activations = activations.activations[0]
    print(probe.predict_proba_from_activations_tensor(item_activations))