# Causal Ablation Demo

This notebook demonstrates causal ablation experiments to identify circuits responsible for deceptive behavior.

**Goal**: Ablate (remove) specific model components and observe how deceptive behavior changes.

In [None]:
import sys
sys.path.insert(0, '..')

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

from deception_detector_jax.config import ModelConfig, DatasetConfig
from deception_detector_jax.models.tiny_transformer import init_model
from deception_detector_jax.data.deception_tasks import generate_task
from deception_detector_jax.interp.ablations import sweep_head_ablations, sweep_layer_ablations
from deception_detector_jax.viz.plots import plot_head_ablation_impact

## 1. Setup: Load Model and Data

In [None]:
# Generate test data
# The task name and example count are each named once so the config and the
# generate_task call below cannot drift apart.
task_name = "hidden_check"
num_examples = 500

data_config = DatasetConfig(
    task_name=task_name,
    num_train=num_examples,
    deception_rate=0.3,
    seed=42,
)
data = generate_task(task_name, data_config, num_examples)

# Sanity summary: dataset size and the mean of the `forbidden` labels.
num_generated = len(data['input_ids'])
forbidden_rate = data['forbidden'].mean()
print(f"Generated {num_generated} examples")
print(f"Forbidden rate: {forbidden_rate:.2%}")

In [None]:
# Initialize model
# Hyperparameters are collected in one mapping and unpacked into ModelConfig.
model_hparams = dict(
    seq_len=32,
    d_model=64,
    n_heads=4,
    n_layers=2,
    vocab_size=128,
    collect_intermediates=False,  # Disable for speed during ablation
)
model_config = ModelConfig(**model_hparams)

# Fixed PRNG key (seed 0) handed to init_model.
rng = jax.random.PRNGKey(0)
model, params = init_model(model_config, rng)

# NOTE: `params` come straight from init_model; no trained weights are loaded in this cell.
# TODO: Load trained parameters
print("Model initialized!")

## 2. Sweep: Ablate Each Attention Head

In [None]:
# Prepare inputs and targets: only the first 100 examples are used for the sweep.
eval_slice = slice(None, 100)
input_ids = jnp.array(data['input_ids'][eval_slice])
target_ids = jnp.array(data['target_ids'][eval_slice])

# The layer/head counts come from the model config; the result is unpacked as a
# 2-D (n_layers, n_heads) matrix later in the notebook.
print("Running head ablation sweep...")
sweep_dims = {"n_layers": model_config.n_layers, "n_heads": model_config.n_heads}
impact_matrix = sweep_head_ablations(model, params, input_ids, target_ids, **sweep_dims)

print("Ablation sweep complete!")
print(f"Impact matrix shape: {impact_matrix.shape}")

In [None]:
# Visualize impact matrix using the project's plotting helper.
heatmap_size = (10, 6)
plot_head_ablation_impact(impact_matrix, figsize=heatmap_size)

## 3. Identify Critical Heads

In [None]:
# Rank every (layer, head) entry of the impact matrix, largest impact first.
n_layers, n_heads = impact_matrix.shape
flat_impacts = impact_matrix.flatten()
flat_indices = np.argsort(flat_impacts)[::-1]

print("Most important heads (highest ablation impact):\n")
top_heads = flat_indices[:5]
for rank, flat_idx in enumerate(top_heads, start=1):
    # flatten() is row-major, so flat index = layer * n_heads + head.
    layer_idx, head_idx = divmod(flat_idx, n_heads)
    print(f"  {rank}. Layer {layer_idx}, Head {head_idx}: impact = {flat_impacts[flat_idx]:.4f}")

## 4. Focused Analysis: Deceptive vs Non-Deceptive

In [None]:
# Split into forbidden and clean examples
# Boolean masks over the dataset: forbidden == 1 vs forbidden == 0.
forbidden_mask = data['forbidden'] == 1
clean_mask = data['forbidden'] == 0


def _take_subset(field, mask, limit=50):
    """Return the first `limit` rows of data[field] selected by `mask`, as a JAX array."""
    return jnp.array(data[field][mask][:limit])


forbidden_inputs = _take_subset('input_ids', forbidden_mask)
forbidden_targets = _take_subset('target_ids', forbidden_mask)

clean_inputs = _take_subset('input_ids', clean_mask)
clean_targets = _take_subset('target_ids', clean_mask)

print(f"Forbidden examples: {len(forbidden_inputs)}")
print(f"Clean examples: {len(clean_inputs)}")

In [None]:
# Ablate heads on forbidden vs clean separately
def _sweep_heads_on(inputs, targets):
    """Run sweep_head_ablations on one subset with the notebook's model, params and config sizes."""
    return sweep_head_ablations(
        model, params, inputs, targets,
        n_layers=model_config.n_layers, n_heads=model_config.n_heads
    )


print("Ablating on forbidden examples...")
forbidden_impact = _sweep_heads_on(forbidden_inputs, forbidden_targets)

print("Ablating on clean examples...")
clean_impact = _sweep_heads_on(clean_inputs, clean_targets)

In [None]:
# Compare impact on forbidden vs clean
# Positive entries: the head's ablation impact is larger on the forbidden subset;
# negative entries: larger on the clean subset.
differential_impact = forbidden_impact - clean_impact

# Centre the diverging colormap on zero. With imshow's default autoscaling the
# colour range is [min, max] of the data, so white would not mean "no difference"
# and red would not reliably mean "more important for forbidden cases".
color_limit = float(np.nanmax(np.abs(differential_impact)))
if not np.isfinite(color_limit) or color_limit == 0.0:
    color_limit = 1.0  # degenerate matrix (all zero / all NaN): fall back to a unit range

plt.figure(figsize=(10, 6))
plt.imshow(
    differential_impact,
    cmap='RdBu_r',
    aspect='auto',
    vmin=-color_limit,
    vmax=color_limit,
)
plt.colorbar(label='Impact Difference (Forbidden - Clean)')
# Head and layer indices are discrete, so label only the integer positions.
plt.xticks(range(differential_impact.shape[1]))
plt.yticks(range(differential_impact.shape[0]))
plt.xlabel('Head Index')
plt.ylabel('Layer Index')
plt.title('Differential Ablation Impact\n(Red = More important for forbidden cases)')
plt.grid(True, alpha=0.3, color='white')
plt.show()

## 5. Layer-wise Ablation

In [None]:
# Ablate entire layers, on the same input/target batch used for the head sweep.
print("Running layer ablation sweep...")
num_layers = model_config.n_layers
layer_results = sweep_layer_ablations(model, params, input_ids, target_ids, n_layers=num_layers)

print("Layer ablation complete!")
print(layer_results)

In [None]:
# Plot layer impacts
fig, ax = plt.subplots(figsize=(10, 6))

layers = range(model_config.n_layers)

# (key in layer_results, legend label, marker) for each ablation variant.
impact_series = [
    ('attn_impact', 'Attention', 'o'),
    ('mlp_impact', 'MLP', 's'),
    ('all_impact', 'Both', '^'),
]
for result_key, series_label, series_marker in impact_series:
    ax.plot(layers, layer_results[result_key], marker=series_marker, label=series_label, linewidth=2)

# Layer index is discrete: pin the ticks to the integer indices, otherwise
# matplotlib chooses fractional tick positions between layers.
ax.set_xticks(list(layers))

ax.set_xlabel('Layer')
ax.set_ylabel('Loss Increase')
ax.set_title('Layer Ablation Impact')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

## 6. Interpretation

### What to look for:

1. **Critical Heads**: Heads with high ablation impact are important for the task
2. **Deception Circuits**: Heads with higher impact on forbidden cases may implement deception
3. **Layer Patterns**: Which layers are most important for detecting `CHECK_FLAG`?

### Next steps:
- Visualize attention patterns for critical heads
- Patch activations from clean to forbidden examples
- Test if ablating critical heads eliminates deceptive behavior

In [None]:
# TODO: Implement targeted ablation of deception circuit
#       (candidate heads: those with the largest forbidden-minus-clean impact in section 4)
# TODO: Measure behavioral change after ablation
#       (compare model behavior on forbidden examples with and without the ablation)