# Content Policy Probe

Train a probe that detects whether the assistant's response violates the content policy.

In [None]:
# Basic Configuration
from tuberlens.config import DATA_DIR, LOCAL_MODELS


# Model the probe is trained on, and the layer index handed to train_probe.
model_name = LOCAL_MODELS["llama-1b"]
layer = 11

# Everything for this concept (training CSV, test inputs, pickled probe)
# is read from / written to DATA_DIR / CONCEPT.
CONCEPT = "policy"
CONCEPT_DIR = DATA_DIR / CONCEPT
PROBE_PATH = CONCEPT_DIR / f"{CONCEPT}_probe.pkl"

# The positive class is a policy *violation*, so the description must say the
# probe detects violations (it previously claimed to detect compliance).
pos_class_label = "violation"
neg_class_label = "compliant"
probe_description = (
    f"A linear probe on {model_name} detecting whether the assistant's "
    "response violates the content policy."
)

## Dataset Creation

In [None]:
from tuberlens.utils import create_train_test_split
from tuberlens.interfaces.dataset import (
    LabelledDataset,
    Message,
)

# Load the labelled examples for this concept and split them into a training
# part and a validation part (split settings are the helper's defaults).
dataset_path = CONCEPT_DIR / "training.csv"
dataset = LabelledDataset.load_from(dataset_path)
train_dataset, validation_dataset = create_train_test_split(dataset)

# Report how many samples ended up on each side of the split.
print(
    f"Read {len(train_dataset)} samples for training "
    f"and {len(validation_dataset)} samples for validation."
)

## Training

In [None]:
from tuberlens.training import train_probe
from tuberlens.interfaces.probes import ProbeSpec, ProbeType


# Probe type and its training hyperparameters, kept separate from the call
# below so they are easy to tweak.
# Alternative spec: ProbeSpec(name=ProbeType.sklearn, hyperparams={})
probe_spec = ProbeSpec(
    name=ProbeType.linear_then_mean,
    hyperparams={
        "batch_size": 8,
        "epochs": 200,
        "optimizer_args": {"lr": 1e-3, "weight_decay": 1e-2},
        "final_lr": 1e-4,
        "gradient_accumulation_steps": 1,
        "patience": 100,
        "temperature": 0.1,
    },
)

# Train the probe on the configured model and layer.
probe = train_probe(
    train_dataset,
    validation_dataset,
    model_name,
    layer,
    # Start at turn 1 to exclude the user message; the dataset contains no
    # system message.
    start_turn_index=1,
    apply_transformations_to_validation_dataset=True,
    pos_class_label=pos_class_label,
    neg_class_label=neg_class_label,
    probe_description=probe_description,
    probe_spec=probe_spec,
)

In [None]:
# Store the probe
import pickle

# Write through a context manager so the file is flushed and closed as soon
# as the dump finishes, rather than whenever the handle is garbage-collected.
with open(PROBE_PATH, "wb") as probe_file:
    pickle.dump(probe, probe_file)

## Loading and Evaluating

In [None]:
from tuberlens.model import LLMModel

# NOTE: unpickling can execute arbitrary code -- only load probe files that
# you created yourself or otherwise trust.
# The context manager closes the file handle once the probe is loaded.
with open(PROBE_PATH, "rb") as probe_file:
    probe = pickle.load(probe_file)

# Fail early if the stored probe lacks a model name or a layer.
assert probe.model_name is not None
assert probe.layer is not None
print("Probe initialized:")
print(probe.description)

# Initialize the model so we can compute activations
model = LLMModel.load(probe.model_name)

In [None]:
import yaml

# Load test inputs from YAML. Each top-level entry is a list of mappings, and
# every mapping supplies the keyword arguments for one Message.
# The encoding is pinned to UTF-8 so parsing does not depend on the
# platform's default locale encoding.
with open(CONCEPT_DIR / "test_inputs.yaml", encoding="utf-8") as f:
    raw_inputs = yaml.safe_load(f)
inputs = [[Message(**msg) for msg in conversation] for conversation in raw_inputs]

# Run the probe on every input and print each prediction next to the input
# that produced it.
preds = probe.predict_proba_from_inputs(inputs, model=model)
for i, pred in enumerate(preds):
    print(f"Sample {i}: {pred}")
    print(f"Input: {inputs[i]}")
    print()