# Notebook 04: Causal Intervention

**Research Question:** Is the action-grounding representation causally relevant to behavior?

This notebook:
1. Extracts the probe direction from the trained reality probe
2. Runs steering experiments (adding or subtracting the direction)
3. Computes dose-response curves
4. Runs a control with a random direction

**Success criterion:** Steering shifts the tool call rate by more than 20 percentage points relative to the α = 0 baseline.

**Note:** This is the most challenging experiment. Null results are acceptable if reported honestly.

## Setup

In [5]:
import sys
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

from src.utils.logging import setup_logging
from src.config import get_config
from src.data.io import load_episodes
from src.analysis.probes import load_probe, get_probe_direction
from src.intervention.steering import run_steering_experiment, compute_dose_response, plot_dose_response

setup_logging(level="INFO")
config = get_config()

print("Causal Intervention Experiments")
print(f"Steering config:")
print(f"  Alphas: {config.steering.alphas}")
print(f"  Target layer: {config.steering.target_layer}")
print(f"  Samples per alpha: {config.steering.n_samples_per_alpha}")

Causal Intervention Experiments
Steering config:
  Alphas: [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]
  Target layer: 16
  Samples per alpha: 50


## 1. Load Probe and Episodes

In [6]:
# Load trained reality probe
reality_probe = load_probe(config.data.processed_dir / "reality_probe.pkl")

# Get probe direction
probe_direction = get_probe_direction(reality_probe)

print(f"Loaded reality probe")
print(f"  Direction shape: {probe_direction.shape}")
print(f"  Direction norm: {np.linalg.norm(probe_direction):.3f}")

2025-12-24 10:54:58,870 - src.analysis.probes - INFO - Loaded probe from: data/processed/reality_probe.pkl
Loaded reality probe
  Direction shape: (4096,)
  Direction norm: 13.468
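
The reported norm of 13.468 matters for interpreting α: unless `run_steering_experiment` normalizes the direction internally (its implementation is not shown in this notebook), α = 1 adds a vector about 13.5 units long rather than a unit step. A minimal sketch of explicit normalization, assuming one wants α expressed in units of a unit-length direction; `to_unit` is a hypothetical helper, not part of `src.analysis.probes`, and the vector below is a random stand-in for the probe direction.

```python
import numpy as np

def to_unit(direction: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Return direction / ||direction||, refusing a (near-)zero vector."""
    norm = np.linalg.norm(direction)
    if norm < eps:
        raise ValueError("direction has (near-)zero norm; cannot normalize")
    return direction / norm

# Stand-in vector with the same shape as the probe direction above.
rng = np.random.default_rng(0)
example_direction = rng.normal(size=4096)
unit_direction = to_unit(example_direction)
print(f"norm after normalization: {np.linalg.norm(unit_direction):.3f}")
```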


In [8]:
# Load episodes
episodes_collection = load_episodes("../data/processed/episodes_v2.parquet")

# Filter to interesting cases for steering
# Test 1: Can we steer fake_action (claims but no tool) to actually call tool?
fake_episodes = episodes_collection.get_fake_episodes()

# Test 2: Can we steer true_action to not call tool?
true_episodes = episodes_collection.filter_by_category('true_action').episodes

print(f"\nEpisodes for steering:")
print(f"  Fake actions: {len(fake_episodes)}")
print(f"  True actions: {len(true_episodes)}")

2025-12-24 10:56:18,676 - src.data.io - INFO - Loading episodes from: ../data/processed/episodes_v2.parquet


ValidationError: 3 validation errors for Episode
category
  Input should be 'true_action', 'fake_action', 'honest_no_action' or 'silent_action' [type=enum, input_value='wrong_tool', input_type=str]
    For further information visit https://errors.pydantic.dev/2.12/v/enum
tool_used_any
  Extra inputs are not permitted [type=extra_forbidden, input_value=True, input_type=bool]
    For further information visit https://errors.pydantic.dev/2.12/v/extra_forbidden
wrong_tool_name
  Extra inputs are not permitted [type=extra_forbidden, input_value='sendMessage', input_type=str]
    For further information visit https://errors.pydantic.dev/2.12/v/extra_forbidden
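
The traceback shows that `episodes_v2.parquet` contains rows with `category == 'wrong_tool'` and two fields (`tool_used_any`, `wrong_tool_name`) that the `Episode` schema rejects, so `fake_episodes` and `true_episodes` are never defined and the cells below cannot run. One option is to extend the schema; another is to filter the raw rows before validation. The sketch below illustrates the second option on plain dictionaries. `ALLOWED_CATEGORIES` is copied from the error message; `EXTRA_FIELDS`, `clean_records`, and the record layout are hypothetical and not part of `src.data.io`.

```python
ALLOWED_CATEGORIES = {"true_action", "fake_action", "honest_no_action", "silent_action"}
EXTRA_FIELDS = {"tool_used_any", "wrong_tool_name"}  # fields the schema forbids

def clean_records(records):
    """Drop rows with unknown categories and strip fields the schema rejects."""
    cleaned = []
    for row in records:
        if row.get("category") not in ALLOWED_CATEGORIES:
            continue  # e.g. 'wrong_tool' rows
        cleaned.append({k: v for k, v in row.items() if k not in EXTRA_FIELDS})
    return cleaned

# Toy rows, not real data from the parquet file.
raw = [
    {"id": "a", "category": "fake_action", "tool_used_any": False, "wrong_tool_name": None},
    {"id": "b", "category": "wrong_tool", "tool_used_any": True, "wrong_tool_name": "sendMessage"},
]
cleaned = clean_records(raw)
print(f"kept {len(cleaned)} of {len(raw)} rows")
```

Dropping `wrong_tool` rows changes the dataset, so the number of removed rows should be reported alongside any steering results.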

In [None]:
# Sample episodes for steering (to keep compute manageable)
n_samples = config.steering.n_samples_per_alpha

# Sample indices rather than the Episode objects themselves, so NumPy never
# has to coerce the episodes into an array.
rng = np.random.default_rng(42)

def sample_episodes(episodes, k):
    """Return up to k episodes, drawn without replacement."""
    if len(episodes) <= k:
        return list(episodes)
    idx = rng.choice(len(episodes), size=k, replace=False)
    return [episodes[i] for i in idx]

fake_sample = sample_episodes(fake_episodes, n_samples)
true_sample = sample_episodes(true_episodes, n_samples)

print(f"\nSampled for steering:")
print(f"  Fake: {len(fake_sample)}")
print(f"  True: {len(true_sample)}")

## 2. Steering Experiment: Fake → True

**Test:** Add probe direction to fake_action episodes. Does this cause tool calls?
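
The implementation in `src.intervention.steering` is not shown in this notebook. As a rough illustration of what "adding the probe direction" usually means in activation steering, the sketch below adds `alpha * direction` to a block of hidden states. In a real run this would happen inside a forward hook at the target layer (layer 16 in the config above); applying it at every token position, and the helper name `apply_steering`, are assumptions made for illustration.

```python
import numpy as np

def apply_steering(hidden: np.ndarray, direction: np.ndarray, alpha: float) -> np.ndarray:
    """Add alpha * direction to every token position of a (seq_len, d_model) array."""
    if hidden.shape[-1] != direction.shape[0]:
        raise ValueError("direction must match the hidden dimension")
    return hidden + alpha * direction  # broadcasts over the sequence axis

# Toy example: 3 token positions, 4 hidden dimensions.
hidden = np.zeros((3, 4))
direction = np.array([1.0, 0.0, 0.0, 0.0])
steered = apply_steering(hidden, direction, alpha=2.0)

# Projection onto the direction shifts by alpha * ||direction||^2 at each position.
print(steered @ direction)
```

With α = 0 the hidden states are unchanged, which is why the α = 0 condition serves as the baseline in the dose-response curves.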

In [None]:
# Run steering on fake episodes across all configured alphas
# (positive alphas are the ones expected to induce tool calls)
# WARNING: This takes ~1-2 hours depending on GPU
# Checkpointing enabled: progress is saved incrementally, and the run resumes from the checkpoint if interrupted

checkpoint_path = config.data.processed_dir / "fake_steering_checkpoint.json"

fake_steering_results = run_steering_experiment(
    probe_direction=probe_direction,
    episodes=fake_sample,
    alphas=config.steering.alphas,
    model_id=config.model.id,
    target_layer=config.steering.target_layer,
    verbose=True,
    checkpoint_path=checkpoint_path,
)

print(f"\nCompleted {len(fake_steering_results)} steering experiments on fake episodes")
print(f"Checkpoint saved to: {checkpoint_path}")

In [None]:
# Compute dose-response for fake episodes
fake_dose_response = compute_dose_response(fake_steering_results)

print("\nDose-response (Fake Episodes):")
for alpha, rate in zip(fake_dose_response['alphas'], fake_dose_response['tool_rates']):
    print(f"  α = {alpha:+.1f}: tool_rate = {rate:.1%}")
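
A dose-response curve here is just the tool call rate per α. The sketch below shows one way such an aggregation could work; it is not the code of `compute_dose_response`, whose implementation is not shown. `ToyResult` and `toy_dose_response` are hypothetical names, and the only things taken from this notebook are the result fields `alpha` and `steered_tool_used` and the output keys `'alphas'` and `'tool_rates'`.

```python
from collections import defaultdict
from dataclasses import dataclass

@dataclass
class ToyResult:
    """Hypothetical stand-in exposing the two fields the aggregation needs."""
    alpha: float
    steered_tool_used: bool

def toy_dose_response(results):
    """Group results by alpha and return the fraction that ended in a tool call."""
    by_alpha = defaultdict(list)
    for r in results:
        by_alpha[r.alpha].append(r.steered_tool_used)
    alphas = sorted(by_alpha)
    tool_rates = [sum(by_alpha[a]) / len(by_alpha[a]) for a in alphas]
    return {"alphas": alphas, "tool_rates": tool_rates}

# Toy results, not outputs of a real steering run.
toy_results = [
    ToyResult(0.0, False), ToyResult(0.0, False),
    ToyResult(1.0, True), ToyResult(1.0, False),
]
toy_curve = toy_dose_response(toy_results)
print(toy_curve)
```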

## 3. Steering Experiment: True → Fake

**Test:** Subtract probe direction from true_action episodes. Does this suppress tool calls?

In [None]:
# Run steering on true episodes across all configured alphas
# (negative alphas are the ones expected to suppress tool calls)
# WARNING: This takes ~1-2 hours depending on GPU
# Checkpointing enabled: progress is saved incrementally, and the run resumes from the checkpoint if interrupted

checkpoint_path = config.data.processed_dir / "true_steering_checkpoint.json"

true_steering_results = run_steering_experiment(
    probe_direction=probe_direction,
    episodes=true_sample,
    alphas=config.steering.alphas,
    model_id=config.model.id,
    target_layer=config.steering.target_layer,
    verbose=True,
    checkpoint_path=checkpoint_path,
)

print(f"\nCompleted {len(true_steering_results)} steering experiments on true episodes")
print(f"Checkpoint saved to: {checkpoint_path}")

In [None]:
# Compute dose-response for true episodes
true_dose_response = compute_dose_response(true_steering_results)

print("\nDose-response (True Episodes):")
for alpha, rate in zip(true_dose_response['alphas'], true_dose_response['tool_rates']):
    print(f"  α = {alpha:+.1f}: tool_rate = {rate:.1%}")

## 4. Control: Random Direction

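**Test:** Steering with a random direction should produce no systematic change in tool call rate. If it does, the effect of the probe direction cannot be attributed to that direction specifically.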
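
A random Gaussian vector makes a reasonable control because, in 4096 dimensions, it is nearly orthogonal to any fixed direction: its cosine similarity with that direction is on the order of 1/sqrt(d) ≈ 0.016. The sketch below checks this empirically. The `reference` vector is a random stand-in for the probe direction, not the probe itself.

```python
import numpy as np

def cosine(u: np.ndarray, v: np.ndarray) -> float:
    """Cosine similarity between two 1-D vectors."""
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))

d_model = 4096  # matches the probe direction shape reported above
rng = np.random.default_rng(42)
reference = rng.normal(size=d_model)            # stand-in for the probe direction
random_dirs = rng.normal(size=(100, d_model))   # 100 candidate control directions

cosines = np.array([cosine(reference, r) for r in random_dirs])
print(f"mean |cosine| over 100 draws: {np.abs(cosines).mean():.4f}")
print(f"1/sqrt(d):                    {1 / np.sqrt(d_model):.4f}")
```

A single random direction is one draw from this distribution, so repeating the control with several seeds gives a more convincing baseline than one vector alone.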

In [None]:
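# Generate a random direction with the same dimensionality and norm as the probe
# direction, so each alpha perturbs activations by the same magnitude in both
# conditions (relevant if run_steering_experiment does not normalize internally)
np.random.seed(42)
probe_norm = np.linalg.norm(probe_direction)
random_direction = np.random.randn(len(probe_direction))
random_direction = random_direction / np.linalg.norm(random_direction) * probe_norm

print(f"Random direction shape: {random_direction.shape}")
print(f"Random direction norm: {np.linalg.norm(random_direction):.3f} (probe: {probe_norm:.3f})")

# Cosine similarity with probe direction (should be ~0)
cosine_sim = np.dot(probe_direction, random_direction) / (
    probe_norm * np.linalg.norm(random_direction)
)
print(f"Cosine similarity with probe: {cosine_sim:.3f}")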

In [None]:
# Sample smaller subset for control (to save compute)
control_sample = fake_sample[:20]

# Run steering with random direction
# Checkpointing enabled: progress is saved incrementally and will resume if interrupted
checkpoint_path = config.data.processed_dir / "control_steering_checkpoint.json"

control_results = run_steering_experiment(
    probe_direction=random_direction,
    episodes=control_sample,
    alphas=config.steering.alphas,
    model_id=config.model.id,
    target_layer=config.steering.target_layer,
    verbose=True,
    checkpoint_path=checkpoint_path,
)

control_dose_response = compute_dose_response(control_results)

print("\nControl dose-response (Random Direction):")
for alpha, rate in zip(control_dose_response['alphas'], control_dose_response['tool_rates']):
    print(f"  α = {alpha:+.1f}: tool_rate = {rate:.1%}")
print(f"Checkpoint saved to: {checkpoint_path}")

## 5. Visualization

**Figure 6:** Dose-response curves

In [None]:
# Combined dose-response plot
fig, ax = plt.subplots(figsize=(12, 7))

# Fake episodes (adding direction should increase tool calls)
ax.plot(
    fake_dose_response['alphas'],
    fake_dose_response['tool_rates'],
    'o-',
    linewidth=2,
    markersize=8,
    label='Fake Episodes (baseline: no tool)',
    color='red',
)

# True episodes (subtracting direction should decrease tool calls)
ax.plot(
    true_dose_response['alphas'],
    true_dose_response['tool_rates'],
    's-',
    linewidth=2,
    markersize=8,
    label='True Episodes (baseline: tool used)',
    color='green',
)

# Control (should be flat)
ax.plot(
    control_dose_response['alphas'],
    control_dose_response['tool_rates'],
    '^--',
    linewidth=1,
    markersize=6,
    label='Control (random direction)',
    color='gray',
    alpha=0.7,
)

ax.axhline(y=0.5, color='k', linestyle=':', linewidth=1, alpha=0.5)
ax.axvline(x=0, color='k', linestyle=':', linewidth=1, alpha=0.5)

ax.set_xlabel('Steering Strength (α)', fontsize=14)
ax.set_ylabel('Tool Call Rate', fontsize=14)
ax.set_title('Steering Vector Dose-Response Curves', fontsize=16)
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 1.0)

plt.tight_layout()
plt.savefig(config.figures_dir / "figure6_steering_dose_response.png", dpi=300, bbox_inches='tight')
plt.savefig(config.figures_dir / "figure6_steering_dose_response.pdf", bbox_inches='tight')
plt.show()

## 6. Compute Effect Sizes

In [None]:
# Effect on fake episodes
baseline_fake_rate = fake_dose_response['tool_rates'][fake_dose_response['alphas'].index(0.0)]
max_alpha_idx = fake_dose_response['alphas'].index(max(config.steering.alphas))
max_fake_rate = fake_dose_response['tool_rates'][max_alpha_idx]

fake_effect_size = max_fake_rate - baseline_fake_rate

print(f"\n**Effect on Fake Episodes:**")
print(f"  Baseline (α=0): {baseline_fake_rate:.1%}")
print(f"  Max steering (α={max(config.steering.alphas)}): {max_fake_rate:.1%}")
print(f"  Effect size: {fake_effect_size:+.1%}")

if abs(fake_effect_size) > 0.20:
    print("  ✓ |Effect| > 20 percentage points: meets the success criterion")
else:
    print("  ✗ |Effect| ≤ 20 percentage points: weak or no causal effect")

In [None]:
# Effect on true episodes
baseline_true_rate = true_dose_response['tool_rates'][true_dose_response['alphas'].index(0.0)]
min_alpha_idx = true_dose_response['alphas'].index(min(config.steering.alphas))
min_true_rate = true_dose_response['tool_rates'][min_alpha_idx]

true_effect_size = baseline_true_rate - min_true_rate

print(f"\n**Effect on True Episodes:**")
print(f"  Baseline (α=0): {baseline_true_rate:.1%}")
print(f"  Min steering (α={min(config.steering.alphas)}): {min_true_rate:.1%}")
print(f"  Effect size: {true_effect_size:+.1%}")

if abs(true_effect_size) > 0.20:
    print("  ✓ |Effect| > 20 percentage points: meets the success criterion")
else:
    print("  ✗ |Effect| ≤ 20 percentage points: weak or no causal effect")
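
Each tool rate is estimated from at most 50 episodes per α, so a difference near the 20 percentage point threshold carries real sampling uncertainty. A rough way to gauge it is a Wilson score interval on each rate. The sketch below is a standalone helper (`wilson_interval` is not part of this codebase), and the counts are hypothetical, not results from this notebook.

```python
import math

def wilson_interval(successes: int, n: int, z: float = 1.96):
    """Wilson score interval for a binomial proportion (z = 1.96 gives ~95%)."""
    if n <= 0:
        raise ValueError("n must be positive")
    p_hat = successes / n
    denom = 1 + z**2 / n
    center = (p_hat + z**2 / (2 * n)) / denom
    half = z * math.sqrt(p_hat * (1 - p_hat) / n + z**2 / (4 * n**2)) / denom
    return max(0.0, center - half), min(1.0, center + half)

# Hypothetical counts: 10 and 20 tool calls out of 50 episodes.
for k in (10, 20):
    low, high = wilson_interval(k, 50)
    print(f"{k}/50 -> [{low:.3f}, {high:.3f}]")
```

Because the same episodes are steered at every α, the rates are paired rather than independent, so these intervals are a conservative guide and not a formal test of the difference.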

In [None]:
# Control check (should be flat)
control_rates = control_dose_response['tool_rates']
control_variance = np.var(control_rates)

print(f"\n**Control (Random Direction):**")
print(f"  Mean rate: {np.mean(control_rates):.1%}")
print(f"  Variance: {control_variance:.4f}")

if control_variance < 0.01:
    print(f"  ✓ Control is flat (low variance)")
else:
    print(f"  ⚠ Control shows variation (unexpected)")
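
The 0.01 variance threshold is worth comparing with the noise expected from sampling alone. If the outcome at each α behaved like independent draws (an assumption; the same 20 episodes are reused across α, so the true noise may be lower), the variance of an observed rate would be p(1 - p)/n. The helper below is illustrative and not part of the codebase.

```python
def binomial_rate_variance(p: float, n: int) -> float:
    """Variance of an observed rate when each of n episodes calls a tool with probability p."""
    return p * (1 - p) / n

n_control = 20  # size of control_sample above
for p in (0.1, 0.3, 0.5):
    var = binomial_rate_variance(p, n_control)
    print(f"p = {p:.1f}: sampling variance = {var:.4f}")
```

Under that assumption, a control with a true rate near 0.3 to 0.5 can reach a variance at or above 0.01 from noise alone, so a control that narrowly fails the flatness check is not necessarily evidence of a real effect.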

## 7. Example Steered Generations

Show concrete examples of steering effects.

In [None]:
# Find examples where steering caused tool call
successful_steers = [
    r for r in fake_steering_results
    if r.alpha > 0 and not r.original_tool_used and r.steered_tool_used
]

print(f"\nFound {len(successful_steers)} successful steering cases (fake → tool call)")

if successful_steers:
    # Show first example
    example = successful_steers[0]
    
    print(f"\n{'='*60}")
    print(f"EXAMPLE: Steering Induced Tool Call")
    print(f"{'='*60}")
    print(f"Episode ID: {example.episode_id}")
    print(f"Steering strength: α = {example.alpha}")
    print(f"\nOriginal reply (no tool):")
    print(example.original_reply[:300] + "...")
    print(f"\nSteered reply (tool call added):")
    print(example.steered_reply[:300] + "...")
    print(f"{'='*60}")

## Summary

In [None]:
print("=" * 60)
print("PHASE 4 RESULTS: CAUSAL INTERVENTION")
print("=" * 60)

print(f"\nSteering on Fake Episodes:")
print(f"  Baseline tool rate (α=0): {baseline_fake_rate:.1%}")
print(f"  Max steering tool rate: {max_fake_rate:.1%}")
print(f"  Effect: {fake_effect_size:+.1%}")

print(f"\nSteering on True Episodes:")
print(f"  Baseline tool rate (α=0): {baseline_true_rate:.1%}")
print(f"  Min steering tool rate: {min_true_rate:.1%}")
print(f"  Effect: {true_effect_size:+.1%}")

print(f"\nControl (Random Direction):")
print(f"  Variance: {control_variance:.4f}")

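# Overall assessment: require both a > 20 percentage point effect and a flat control
effect_found = abs(fake_effect_size) > 0.20 or abs(true_effect_size) > 0.20
if effect_found and control_variance < 0.01:
    print("\n✓ CAUSAL EVIDENCE: steering shifts tool call rate by > 20 percentage points")
    print("✓ The control is flat, so the effect is specific to the probe direction")
elif effect_found:
    print("\n⚠ AMBIGUOUS: effect > 20 percentage points, but the random-direction control also varies")
else:
    print("\n⚠ WEAK CAUSAL EVIDENCE: effect ≤ 20 percentage points")
    print("  Representation may be correlational, not causal")
    print("  OR: Steering method needs refinement")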

print("=" * 60)

## Honest Assessment

**Note for write-up:**

If steering works:
- Causal evidence that the probe direction encodes action-grounding, provided the random-direction control stays flat
- Tool calls can be induced or suppressed by adding or subtracting the direction

If steering doesn't work:
- The strong correlational evidence from Notebooks 01-03 still stands
- Steering failure could mean:
  - The representation is predictive but not causally determining
  - The steering method needs refinement (wrong layer, wrong strength)
  - The model has redundant circuits that compensate
- This is still a valuable finding: it shows the limits of linear probing

**Either outcome is publishable if reported honestly.**