# Activation Probing Demo

This notebook demonstrates how to use linear probes to decode hidden variables from transformer activations.

**Goal**: Train a linear classifier to predict whether an example is "deceptive" — here, the hidden `forbidden` flag of the Hidden Check task — based solely on internal activations.

In [None]:
import sys
sys.path.insert(0, '..')

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score, roc_curve

from deception_detector_jax.config import ModelConfig, DatasetConfig
from deception_detector_jax.models.tiny_transformer import init_model
from deception_detector_jax.data.deception_tasks import generate_task
from deception_detector_jax.interp.activation_cache import run_with_cache

## 1. Generate Dataset with Hidden Variables

In [None]:
# Generate the Hidden Check task. The task name and example count are read back
# from `data_config` so the config object is the single source of truth (the
# original call repeated both as literals, which could silently drift apart).
data_config = DatasetConfig(
    task_name="hidden_check",
    num_train=2000,
    deception_rate=0.3,
    seed=42
)

data = generate_task(data_config.task_name, data_config, data_config.num_train)

print(f"Generated {len(data['input_ids'])} examples")
print(f"Forbidden rate: {data['forbidden'].mean():.2%}")

## 2. Initialize Model (Random Weights; Pretrained Loading Is a TODO)

In [None]:
# Build a small transformer with randomly initialized weights.
model_config = ModelConfig(
    vocab_size=128,
    seq_len=32,
    d_model=64,
    n_heads=4,
    n_layers=2,
    # presumably required so run_with_cache can record per-layer activations — confirm in ModelConfig
    collect_intermediates=True,
)

rng = jax.random.PRNGKey(0)  # fixed seed -> reproducible initialization
model, params = init_model(model_config, rng)
print("Model initialized!")

# TODO: Load pretrained parameters if available (until then, the probes below run on a random-init model)

## 3. Extract Activations from All Examples

In [None]:
# Run one batched forward pass over every example, caching intermediate activations.
inputs = jnp.array(data['input_ids'])

# Validate inputs against the model config before the forward pass. JAX does not
# raise on out-of-bounds indexing (it clamps or fills depending on the op), so an
# oversized token id -- or, if positions are embedded by lookup, an overlong
# sequence -- would silently corrupt the activations instead of erroring.
assert inputs.shape[-1] <= model_config.seq_len, (
    f"sequence length {inputs.shape[-1]} exceeds model seq_len={model_config.seq_len}"
)
assert 0 <= int(inputs.min()) and int(inputs.max()) < model_config.vocab_size, (
    f"token ids must lie in [0, {model_config.vocab_size}); "
    f"got [{int(inputs.min())}, {int(inputs.max())}]"
)

logits, cache = run_with_cache(model, params, inputs)

print(f"Logits shape: {logits.shape}")
print(f"Cache collected for {len(data['input_ids'])} examples")

In [None]:
# Extract MLP activations from the last transformer layer.
layer_idx = model_config.n_layers - 1
mlp_acts = cache.get_mlp_activations(layer_idx)

# Fail loudly here: every probe cell below needs `activations_pooled`, so merely
# printing a warning would only resurface later as a confusing NameError.
if 'mlp_post_act' not in mlp_acts:
    raise KeyError(
        f"No 'mlp_post_act' in MLP cache for layer {layer_idx}; "
        f"available keys: {list(mlp_acts)}"
    )

activations = np.array(mlp_acts['mlp_post_act'])
print(f"MLP activations shape: {activations.shape}")

# Average over the sequence axis (axis 1) to get one feature vector per example.
# NOTE(review): assumes activations are (batch, seq, d_mlp) -- confirm against activation_cache.
activations_pooled = activations.mean(axis=1)
print(f"Pooled activations shape: {activations_pooled.shape}")

## 4. Train Linear Probe to Detect Forbidden Cases

In [None]:
# Probe features: mean-pooled last-layer MLP activations. Label: the hidden `forbidden` flag.
X = activations_pooled
y = np.asarray(data['forbidden'])
assert len(X) == len(y), f"feature/label count mismatch: {len(X)} vs {len(y)}"

# Shuffle with a fixed seed before the 80/20 split. A plain prefix/suffix split
# inherits whatever ordering the generator produced, and a test set that ends up
# with only one class makes roc_auc_score raise.
split_rng = np.random.default_rng(42)
perm = split_rng.permutation(len(X))
n_train = int(0.8 * len(X))
train_idx, test_idx = perm[:n_train], perm[n_train:]

X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]

print(f"Training set: {len(X_train)} examples")
print(f"Test set: {len(X_test)} examples")
print(f"Test forbidden rate: {y_test.mean():.2%}")

In [None]:
# Train a logistic-regression probe on the pooled activations.
probe = LogisticRegression(max_iter=1000, random_state=42)
probe.fit(X_train, y_train)

# Evaluate on the held-out split.
y_pred = probe.predict(X_test)
y_proba = probe.predict_proba(X_test)[:, 1]  # column 1 = probability of probe.classes_[1]

accuracy = accuracy_score(y_test, y_pred)
auc = roc_auc_score(y_test, y_proba)

# With imbalanced labels, raw accuracy is only meaningful relative to always
# predicting the majority class; report that baseline alongside the probe.
positive_rate = float(np.mean(y_test))
baseline_accuracy = max(positive_rate, 1.0 - positive_rate)

print(f"Probe Accuracy: {accuracy:.3f}")
print(f"Majority-class baseline accuracy: {baseline_accuracy:.3f}")
print(f"Probe AUC: {auc:.3f}")

## 5. Visualize Probe Performance

In [None]:
# ROC curve of the probe's held-out scores, with the chance diagonal for reference.
fpr, tpr, thresholds = roc_curve(y_test, y_proba)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, linewidth=2, label=f'ROC (AUC={auc:.3f})')
ax.plot([0, 1], [0, 1], 'k--', label='Random')
ax.set(
    xlabel='False Positive Rate',
    ylabel='True Positive Rate',
    title='ROC Curve: Forbidden Case Detection',
)
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

## 6. Analyze Probe Weights

In [None]:
# Inspect the probe's learned weight vector (one weight per pooled activation feature).
weights = probe.coef_[0]
abs_weights = np.abs(weights)

# Bar chart of |weight| per feature: larger bars mark more informative neurons.
fig, ax = plt.subplots(figsize=(12, 4))
ax.bar(np.arange(weights.size), abs_weights)
ax.set_xlabel('Feature Index (MLP Neuron)')
ax.set_ylabel('Absolute Weight')
ax.set_title('Probe Weight Magnitudes (Which neurons are most informative?)')
ax.grid(True, alpha=0.3, axis='y')
plt.show()

# Rank features by |weight|, largest first, and list the strongest few.
top_k = 10
top_indices = np.argsort(abs_weights)[::-1][:top_k]
print(f"\nTop {top_k} most important features:")
for rank, idx in enumerate(top_indices, start=1):
    print(f"  {rank}. Feature {idx}: weight = {weights[idx]:.3f}")

## 7. Interpretation

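If the probe clearly beats the majority-class baseline, it suggests the points below. Raw accuracy alone is misleading: if the forbidden rate is near the configured `deception_rate=0.3`, always predicting "not forbidden" already scores about 0.7, so also look for an AUC well above 0.5.
- The model's activations **encode** the hidden forbidden condition
- This information is linearly accessible
- We can potentially intervene on these features to change behavior (probes are correlational; a causal claim needs an actual intervention experiment)

**Caveat:** the model in Section 2 is randomly initialized. A successful probe may only show that the condition is linearly recoverable from a random transformation of the inputs, not that the model computes it. Repeat with trained parameters before drawing conclusions.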

## TODO:
- Probe different layers to see where information emerges
- Compare attention vs MLP activations
- Visualize activation patterns for high-weight neurons