# Causal Intervention to Test for a Shortcut
This notebook tests the hypothesis that a specific direction in activation space causally influences whether the model completes a statement truthfully. We will:
1. **Calculate the 'shortcut vector'** as the difference between the mean activation of true statements and the mean activation of false statements at a late layer.
2. **Intervene** during a forward pass by adding or subtracting a scaled copy of this vector at that layer.
3. **Observe** whether the intervention flips the model's output from true to false, or vice versa.
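
In symbols, the plan above can be written as follows. The notation is introduced here for illustration only: $\mathcal{T}$ and $\mathcal{F}$ are the sets of true and false statements, $h_\ell(x)$ is the layer-$\ell$ activation for statement $x$, and $\alpha$ is a scalar step size.

$$
v = \frac{1}{|\mathcal{T}|}\sum_{x \in \mathcal{T}} h_\ell(x) \;-\; \frac{1}{|\mathcal{F}|}\sum_{x \in \mathcal{F}} h_\ell(x),
\qquad
h_\ell'(x) = h_\ell(x) + \alpha\, v
$$

A positive $\alpha$ moves the activation toward the mean of the true statements, and a negative $\alpha$ moves it toward the mean of the false ones.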

In [1]:
import torch as t

# Import our custom functions and classes
from utils import DataManager, load_model
from intervention import run_intervention_experiment

# --- Configuration ---
MODEL_FAMILY = "Llama3"
MODEL_SIZE = "8B"
MODEL_TYPE = "base"
TARGET_LAYER = 22 # A layer in the identified 20-26 range
# Prefer CUDA, then Apple-silicon MPS, then CPU
DEVICE = "cuda" if t.cuda.is_available() else "cpu"
DEVICE = "mps" if t.backends.mps.is_available() else DEVICE

t.set_grad_enabled(False)
print(f"Using device: {DEVICE}")



Using device: mps


## Step 1: Calculate the Shortcut Vector

In [2]:
print(f"Calculating the shortcut vector for Layer {TARGET_LAYER}...")

# Load pre-computed activations using the DataManager
dm = DataManager()
dm.add_dataset('cities', MODEL_FAMILY, MODEL_SIZE, MODEL_TYPE, TARGET_LAYER, center=False, device=DEVICE)
dm.add_dataset('neg_cities', MODEL_FAMILY, MODEL_SIZE, MODEL_TYPE, TARGET_LAYER, center=False, device=DEVICE)

# Get activations and labels from both datasets
cities_acts, cities_labels = dm.data['cities']
neg_cities_acts, neg_cities_labels = dm.data['neg_cities']

# 1. Collect all TRUE activations (label == 1) from BOTH datasets
true_acts_from_cities = cities_acts[cities_labels == 1]
true_acts_from_neg_cities = neg_cities_acts[neg_cities_labels == 1]
all_true_acts = t.cat([true_acts_from_cities, true_acts_from_neg_cities], dim=0)

# 2. Collect all FALSE activations (label == 0) from BOTH datasets
false_acts_from_cities = cities_acts[cities_labels == 0]
false_acts_from_neg_cities = neg_cities_acts[neg_cities_labels == 0]
all_false_acts = t.cat([false_acts_from_cities, false_acts_from_neg_cities], dim=0)

print(f"Found {len(all_true_acts)} total true statements.")
print(f"Found {len(all_false_acts)} total false statements.")

# 3. Calculate means based on the correctly aggregated groups
mean_true_acts = all_true_acts.mean(dim=0)
mean_false_acts = all_false_acts.mean(dim=0)

# This vector points from the mean false-statement activation to the mean true-statement activation
shortcut_vector = mean_true_acts - mean_false_acts

print(f"Shortcut vector calculated. Norm: {shortcut_vector.norm().item():.2f}")

Calculating the shortcut vector for Layer 22...
Found 45 total true statements.
Found 45 total false statements.
Shortcut vector calculated. Norm: 3.62
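
Before intervening, it can be useful to check that the two classes actually separate along the computed direction. The sketch below illustrates that check on synthetic data rather than on the real activations: `toy_true_acts` and `toy_false_acts` are made-up stand-ins for `all_true_acts` and `all_false_acts`, and the hidden size of 16 is arbitrary.

```python
import torch as t

t.manual_seed(0)
d_model = 16  # toy hidden size, not the real model width

# Synthetic stand-ins for all_true_acts / all_false_acts:
# two tight clusters offset along the first coordinate.
offset = t.zeros(d_model)
offset[0] = 2.0
toy_true_acts = t.randn(45, d_model) * 0.1 + offset
toy_false_acts = t.randn(45, d_model) * 0.1 - offset

# Difference of means, as in the cell above, then normalize to unit length
toy_vector = toy_true_acts.mean(dim=0) - toy_false_acts.mean(dim=0)
unit_direction = toy_vector / toy_vector.norm()

# Scalar projection of every activation onto the direction
true_proj = toy_true_acts @ unit_direction
false_proj = toy_false_acts @ unit_direction

# Use the midpoint between the two class means as a threshold
threshold = (true_proj.mean() + false_proj.mean()) / 2
accuracy = (
    (true_proj > threshold).float().mean()
    + (false_proj <= threshold).float().mean()
) / 2
print(f"Separation accuracy along the direction: {accuracy.item():.2f}")
```

On the real data, the same projection could be applied to `all_true_acts` and `all_false_acts` with `shortcut_vector`. Good separation on the data used to compute the vector would show only that the direction correlates with the labels; the causal claim still depends on the intervention in Step 2.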


## Step 2: Load Model and Run Intervention Experiments
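`run_intervention_experiment` is imported from the `intervention` module, and its implementation is not shown in this notebook. One common way to add a vector at a chosen layer is a PyTorch forward hook; the sketch below shows that technique on a toy network. `ToyBlock` and `add_steering_hook` are hypothetical names used for illustration, and this is not necessarily how the project's function works.

```python
import torch as t
from torch import nn


class ToyBlock(nn.Module):
    """Stand-in for a transformer layer: a residual linear map."""

    def __init__(self, d_model):
        super().__init__()
        self.linear = nn.Linear(d_model, d_model)

    def forward(self, x):
        return x + self.linear(x)


def add_steering_hook(module, vector, alpha):
    """Register a hook that adds alpha * vector to the module's output."""

    def hook(_module, _inputs, output):
        # Assumption: some real decoder layers return a tuple whose first
        # element is the hidden state, so handle both cases.
        if isinstance(output, tuple):
            return (output[0] + alpha * vector,) + output[1:]
        return output + alpha * vector

    return module.register_forward_hook(hook)


t.manual_seed(0)
d_model = 8
toy_model = nn.Sequential(ToyBlock(d_model), ToyBlock(d_model), ToyBlock(d_model))
x = t.randn(1, 3, d_model)   # (batch, sequence, hidden)
vector = t.randn(d_model)    # stand-in for shortcut_vector

with t.no_grad():
    baseline = toy_model(x)
    handle = add_steering_hook(toy_model[1], vector, alpha=-4.0)
    try:
        steered = toy_model(x)
    finally:
        handle.remove()      # always detach the hook after the experiment
    restored = toy_model(x)

print("Output changed by hook:", not t.allclose(steered, baseline))
print("Output restored after removal:", t.allclose(restored, baseline))
```

Removing the hook in a `finally` block keeps one experiment from leaking into the next, which matters here because the two experiments use opposite signs of `alpha`.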

In [5]:
print("Loading model...")
tokenizer, model = load_model(MODEL_FAMILY, MODEL_SIZE, MODEL_TYPE, DEVICE)

# --- Experiment 1: Try to flip a TRUE statement to FALSE ---
test_statement_true = "The Eiffel Tower is in the city of" # Model should complete 'Paris'
# We subtract the vector (negative alpha) to push it from the 'truth' region towards the 'lie' region
alpha_for_falsehood = -4.0
run_intervention_experiment(model, tokenizer, shortcut_vector, test_statement_true, TARGET_LAYER, alpha=alpha_for_falsehood)

# --- Experiment 2: Try to flip a FALSE statement to TRUE ---
# Let's set up a false context. The model 'knows' the capital of Italy is Rome.
test_statement_false = "The capital of Italy is the city of London, not the city of"
# We add the vector to push its state towards 'truth', hoping it corrects itself to 'Rome'
alpha_for_truth = 4.0
run_intervention_experiment(model, tokenizer, shortcut_vector, test_statement_false, TARGET_LAYER, alpha=alpha_for_truth)

Loading model...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

RuntimeError: MPS backend out of memory (MPS allocated: 36.28 GiB, other allocations: 800.00 KiB, max allowed: 36.27 GiB). Tried to allocate 8.00 MiB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).
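
The error above reports a ceiling of 36.27 GiB. A rough estimate of the memory needed for the weights alone helps explain why loading can hit that ceiling. The sketch below assumes about 8 billion parameters for an "8B" model and ignores activations, caches, and the pre-computed datasets already on the device. Which dtype `load_model` actually uses is not shown in this notebook.

```python
GIB = 1024 ** 3
N_PARAMS = 8.0e9  # assumption: roughly 8 billion parameters

# Bytes needed to store one parameter in each dtype
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2}


def weight_memory_gib(n_params, dtype):
    """Approximate memory for the weights alone, in GiB."""
    return n_params * BYTES_PER_PARAM[dtype] / GIB


for dtype in BYTES_PER_PARAM:
    print(f"{dtype}: {weight_memory_gib(N_PARAMS, dtype):.1f} GiB")
```

Under these assumptions, float32 weights would take most of the available budget, while a 16-bit dtype would need about half as much.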