[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nawidayima/IPHR_Direction/blob/main/notebooks/03_extract_activations.ipynb)

# Extract Activations with TransformerLens

**Goal:** Extract residual stream activations from Llama-3-8B-Instruct for labeled trajectory data.

**Project Plan Reference:** Phase 2, Hours 6-8

**Key decisions:**
- **Token position:** Last token of prompt (decision point before generation)
- **Layers:** Upper third - layers 24, 28, and 31 (0-indexed; the model has 32 layers, so 31 is the final one)
- **Output:** `[n_samples, d_model]` per layer

**Setup:** Add `HF_TOKEN` to Colab Secrets (key icon in sidebar), then Run All.

In [None]:
# Cell 0: Setup - Clone repo and install dependencies
import os

# Clone repo (only if not already cloned)
if not os.path.exists('/content/IPHR_Direction'):
    !git clone https://github.com/nawidayima/IPHR_Direction.git
    %cd /content/IPHR_Direction
else:
    %cd /content/IPHR_Direction
    !git pull  # Get latest changes

# Install dependencies
!pip install torch transformers accelerate pandas -q
!pip install transformer_lens -q

# Install package in editable mode
!pip install -e . -q

print("Setup complete!")

In [None]:
# Cell 1: Imports
import torch
import pandas as pd
from pathlib import Path
from datetime import datetime
from tqdm.auto import tqdm
from huggingface_hub import login

from transformer_lens import HookedTransformer

# Import from our package
from src.data_generation import Domain, SYSTEM_PROMPTS
from src.experiment_utils import load_config, load_results

# Pick the compute device: CUDA when available, otherwise CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Report which GPU we got and how much memory it has
if device == "cuda":
    gpu_name = torch.cuda.get_device_name(0)
    gpu_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU: {gpu_name}")
    print(f"GPU Memory: {gpu_bytes / 1e9:.1f} GB")

In [None]:
# Cell 2: HuggingFace Authentication
import os
from huggingface_hub import login

hf_token = None

# Method 1: Colab Secrets
try:
    from google.colab import userdata
except ImportError:
    # Not running inside Colab; fall through to the environment variable.
    userdata = None

if userdata is not None:
    try:
        hf_token = userdata.get('HF_TOKEN')
    except Exception as exc:
        # Best-effort lookup: report why the secret could not be read instead
        # of hiding it, then fall through to the environment variable.
        print(f"Could not read HF_TOKEN from Colab Secrets: {exc!r}")
    else:
        if hf_token:
            print("Found HF_TOKEN in Colab Secrets")

# Method 2: Environment variable
if not hf_token and "HF_TOKEN" in os.environ:
    hf_token = os.environ["HF_TOKEN"]
    print("Found HF_TOKEN in environment")

# Without a token there is nothing to log in with, so stop here
if hf_token:
    login(token=hf_token)
    print("Logged in to HuggingFace")
else:
    raise ValueError("No HF_TOKEN found. Add to Colab Secrets or environment.")

## Load Trajectory Data

In [None]:
# Cell 3: Load trajectory data from previous experiment
RUN_DIR = Path("experiments/run_20251228_204835_expand_dataset")

# Read one CSV per domain, in the order geography, dates
_domain_frames = [
    pd.read_csv(RUN_DIR / relative_csv)
    for relative_csv in ("trajectories/geography.csv", "trajectories/dates.csv")
]
geo_df, dates_df = _domain_frames

# Concatenate into a single frame with a fresh 0..N-1 index
all_df = pd.concat(_domain_frames, ignore_index=True)

print(f"Loaded {len(all_df)} question pairs")
print(f"  - Geography: {len(geo_df)}")
print(f"  - Dates: {len(dates_df)}")
print("\nContradiction distribution:")
print(all_df["is_contradiction"].value_counts())

In [None]:
# Cell 4: Quick data inspection
# Show the first row of the combined frame as a sanity check
print("Sample question pair:")
sample = all_df.iloc[0]
for _caption, _column in (
    ("Domain", "domain"),
    ("Question A", "question_a"),
    ("Question B", "question_b"),
):
    print(f"{_caption}: {sample[_column]}")
print(f"Answer A: {sample['answer_a']}, Answer B: {sample['answer_b']}")
print(f"Is Contradiction: {sample['is_contradiction']}")

## Load Model with TransformerLens

In [None]:
# Cell 5: Load Llama-3-8B-Instruct with TransformerLens
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"

print(f"Loading {MODEL_NAME} with TransformerLens...")
print("This may take a few minutes...")

# Load onto the device detected in Cell 1 instead of a hard-coded "cuda",
# so the notebook's own device check is what decides where the model lives.
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    fold_ln=False,           # Keep LayerNorm separate for interpretability
    center_writing_weights=False,
    center_unembed=False,
    device=device,
    dtype=torch.bfloat16,    # Use bfloat16 for memory efficiency
)

# Print the architecture numbers the later cells rely on
print(f"\nModel loaded!")
print(f"  - Layers: {model.cfg.n_layers}")
print(f"  - d_model: {model.cfg.d_model}")
print(f"  - Heads: {model.cfg.n_heads}")
print(f"  - d_head: {model.cfg.d_head}")

In [None]:
# Cell 6: Quick model test
test_prompt = "The capital of France is"
print(f"Test prompt: {test_prompt}")

# Generate a few tokens to verify model works
output = model.generate(test_prompt, max_new_tokens=5, temperature=0)
print(f"Model output: {output}")

# Test cache access. This is inference only, so disable autograd, matching
# what extract_activations does for the real extraction.
tokens = model.to_tokens(test_prompt)
with torch.no_grad():
    _, cache = model.run_with_cache(tokens)

# Derive the final layer index from the config instead of hard-coding 31
last_layer = model.cfg.n_layers - 1
print(f"\nCache keys (sample): {list(cache.keys())[:5]}")
print(f"Residual stream shape at layer {last_layer}: {cache['resid_post', last_layer].shape}")

## Define Activation Extraction

In [None]:
# Cell 7: Configuration
# project-plan.md says to start with layers in the upper third of the network;
# these three indices sample that region at different depths.
LAYERS_TO_PROBE = [24, 28, 31]

# Echo the extraction settings in one call, one setting per line
print(
    f"Extracting activations from layers: {LAYERS_TO_PROBE}",
    "Token position: Last token of prompt (decision point)",
    f"Expected output shape per layer: [d_model={model.cfg.d_model}]",
    sep="\n",
)

In [None]:
# Cell 8: Activation extraction function
def format_chat_prompt(question: str, system_prompt: str) -> str:
    """Render a system + user exchange with the model tokenizer's chat template.

    Args:
        question: Text placed in the user turn.
        system_prompt: Text placed in the system turn.

    Returns:
        The templated conversation as a string (not token ids), with the
        generation prompt appended.
    """
    system_turn = {"role": "system", "content": system_prompt}
    user_turn = {"role": "user", "content": question}
    # tokenize=False returns text; add_generation_prompt=True appends the
    # template's generation prompt after the user turn.
    return model.tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )


def extract_activations(
    question: str,
    system_prompt: str,
    layers: list[int],
) -> dict[int, torch.Tensor]:
    """Extract residual stream activations at the last token position.

    Args:
        question: The question to process
        system_prompt: System prompt for the domain
        layers: List of layer indices to extract from (negative indices count
            back from the last layer)

    Returns:
        Dict mapping each entry of ``layers`` to a float32 CPU tensor holding
        the ``resid_post`` activation of the final prompt token
        (shape [d_model]).
    """
    # Format the prompt
    formatted_prompt = format_chat_prompt(question, system_prompt)

    # Tokenize
    tokens = model.to_tokens(formatted_prompt)

    # NOTE(review): the chat-template text is expected to already start with
    # the BOS token, and to_tokens() may prepend another one. Only act when a
    # doubled leading BOS is actually present, so exactly one remains.
    bos_id = model.tokenizer.bos_token_id
    if bos_id is not None and tokens[0, :2].tolist() == [bos_id, bos_id]:
        tokens = tokens[:, 1:]

    # Cache only the hooks we read, instead of every activation in the model.
    # Negative layer indices are resolved against the layer count first.
    n_layers = model.cfg.n_layers
    hook_names = {
        layer: f"blocks.{layer if layer >= 0 else n_layers + layer}.hook_resid_post"
        for layer in layers
    }

    # Run with cache (inference only, no autograd graph)
    with torch.no_grad():
        _, cache = model.run_with_cache(
            tokens, names_filter=list(hook_names.values())
        )

    # resid_post is [batch, seq_len, d_model]; take batch 0, last token,
    # then move to CPU and upcast so downstream probing is not done in bfloat16.
    return {
        layer: cache[hook_name][0, -1, :].cpu().to(torch.float32)
        for layer, hook_name in hook_names.items()
    }


# Smoke-test the extractor on a single hand-written geography question
test_question = "Is Paris located south of Cairo? Think step by step, then answer YES or NO."
test_system = SYSTEM_PROMPTS[Domain.GEOGRAPHY]

test_acts = extract_activations(
    question=test_question,
    system_prompt=test_system,
    layers=LAYERS_TO_PROBE,
)
print("Test extraction successful!")
for layer in test_acts:
    act = test_acts[layer]
    print(f"  Layer {layer}: shape={act.shape}, norm={act.norm().item():.2f}")

## Extract Activations for All Trajectories

In [None]:
# Cell 9: Process all trajectories
print(f"Processing {len(all_df)} question pairs...")
print(f"This will extract {len(all_df) * 2} activations (A and B for each pair)")
print(f"Layers: {LAYERS_TO_PROBE}")
print()

results = []
errors = []

# Row fields copied unchanged into each result record
_passthrough_fields = (
    "pair_id",
    "domain",
    "is_contradiction",
    "answer_a",
    "answer_b",
    "ground_truth_a",
    "ground_truth_b",
)

for idx, row in tqdm(all_df.iterrows(), total=len(all_df), desc="Extracting"):
    try:
        # Each domain has its own system prompt
        domain = Domain(row["domain"])
        system_prompt = SYSTEM_PROMPTS[domain]

        # Run question A first, then question B, under the same system prompt
        acts_a, acts_b = (
            extract_activations(row[question_col], system_prompt, LAYERS_TO_PROBE)
            for question_col in ("question_a", "question_b")
        )

        record = {field: row[field] for field in _passthrough_fields}
        record["activations_a"] = acts_a
        record["activations_b"] = acts_b
        results.append(record)

        # Release cached CUDA blocks every 50 rows to reduce the risk of OOM
        if idx % 50 == 0:
            torch.cuda.empty_cache()

    except Exception as e:
        # Best effort: note the failing pair and continue with the next row
        errors.append({"idx": idx, "pair_id": row["pair_id"], "error": str(e)})
        print(f"\nError at idx {idx}: {e}")

print("\nExtraction complete!")
print(f"  Successful: {len(results)}")
print(f"  Errors: {len(errors)}")

## Save Activations

In [None]:
# Cell 10: Prepare data for saving
# Convert results to a more efficient format for probing

# Stack activations into tensors by layer
# Shape will be [n_samples * 2, d_model] for each layer (A and B questions)

def stack_activations(results: list, layers: list[int]):
    """Flatten per-pair results into per-layer tensors with aligned labels.

    Each result contributes two consecutive samples: question A, then
    question B. Both get the pair's contradiction label.

    Args:
        results: Dicts with "pair_id", "domain", "is_contradiction",
            "activations_a" and "activations_b"; the two activation entries
            map layer index to a tensor.
        layers: Layer indices to collect.

    Returns:
        Tuple ``(stacked, labels, metadata)``: ``stacked`` maps each layer to
        the ``torch.stack`` of its samples, ``labels`` is a tensor of 1
        (contradiction) / 0 (honest) per sample, and ``metadata`` is one dict
        per sample, in the same order.
    """
    per_layer = {layer: [] for layer in layers}
    sample_labels = []
    sample_meta = []

    for record in results:
        for question_type, acts_key in (("A", "activations_a"), ("B", "activations_b")):
            for layer in layers:
                per_layer[layer].append(record[acts_key][layer])
            # The label belongs to the pair, so A and B share it
            sample_labels.append(1 if record["is_contradiction"] else 0)
            sample_meta.append({
                "pair_id": record["pair_id"],
                "domain": record["domain"],
                "question_type": question_type,
                "is_contradiction": record["is_contradiction"],
            })

    # Turn each per-layer list of vectors into a single tensor
    stacked = {layer: torch.stack(per_layer[layer]) for layer in layers}

    return stacked, torch.tensor(sample_labels), sample_meta


# Stack everything extracted above and report the resulting shapes
activations_stacked, labels, metadata = stack_activations(
    results=results,
    layers=LAYERS_TO_PROBE,
)

print("Stacked activations:")
for layer in activations_stacked:
    print(f"  Layer {layer}: {activations_stacked[layer].shape}")
print(f"Labels: {labels.shape}")

# Count label 1 by summing, and label 0 by summing the inverted boolean mask
n_contradictions = labels.sum().item()
n_honest = (~labels.bool()).sum().item()
print(f"Label distribution: {n_contradictions} contradictions, {n_honest} honest")

In [None]:
# Cell 11: Save to disk
# Activations live in an "activations" folder inside the run directory
activations_dir = RUN_DIR / "activations"
activations_dir.mkdir(exist_ok=True)
save_path = activations_dir / "residual_stream_activations.pt"

# Bundle the tensors with the settings that produced them
save_data = dict(
    model_name=MODEL_NAME,
    layers=LAYERS_TO_PROBE,
    token_position="last",
    d_model=model.cfg.d_model,
    n_layers=model.cfg.n_layers,
    # Dict[layer, Tensor[n_pairs*2, d_model]]
    activations=activations_stacked,
    # Tensor[n_pairs*2], 1=contradiction, 0=honest
    labels=labels,
    # List of dicts with pair_id, domain, question_type, is_contradiction
    metadata=metadata,
    extraction_timestamp=datetime.now().isoformat(),
    n_pairs=len(results),
    n_samples=len(labels),
)

torch.save(save_data, save_path)
print(f"Saved activations to: {save_path}")
saved_bytes = save_path.stat().st_size
print(f"File size: {saved_bytes / 1e6:.1f} MB")

## Validation

In [None]:
# Cell 12: Validate saved data
# Read the file back and echo its contents to confirm the save round-trips
loaded = torch.load(save_path)

print("Loaded data validation:")
for _caption, _key in (
    ("Model", "model_name"),
    ("Layers", "layers"),
    ("d_model", "d_model"),
    ("n_pairs", "n_pairs"),
    ("n_samples", "n_samples"),
):
    print(f"  {_caption}: {loaded[_key]}")

print("\nActivation shapes:")
for layer, acts in loaded['activations'].items():
    print(f"  Layer {layer}: {acts.shape}")

_saved_labels = loaded['labels']
print(f"\nLabels: {_saved_labels.shape}")
print(f"  Contradictions: {_saved_labels.sum().item()}")
print(f"  Honest: {(~_saved_labels.bool()).sum().item()}")

In [None]:
# Cell 13: Quick statistics
print("Activation statistics by layer:")
print()

# The label mask is the same for every layer, so build it once
labels_bool = loaded['labels'].bool()

for layer in LAYERS_TO_PROBE:
    acts = loaded['activations'][layer]
    row_norms = acts.norm(dim=1)

    # Partition the samples by label
    contradiction_acts = acts[labels_bool]
    honest_acts = acts[~labels_bool]

    print(f"Layer {layer}:")
    print(f"  All - mean norm: {row_norms.mean():.2f}, std: {row_norms.std():.2f}")
    print(f"  Contradiction - mean norm: {contradiction_acts.norm(dim=1).mean():.2f}")
    print(f"  Honest - mean norm: {honest_acts.norm(dim=1).mean():.2f}")
    print()

In [None]:
# Cell 14: Domain breakdown
from collections import Counter

print("Samples by domain:")
# Count samples per (domain, is_contradiction) pair. Counter is a dict
# subclass, so domain_counts can still be used like the plain dict it replaces.
domain_counts = Counter(
    (m['domain'], m['is_contradiction']) for m in loaded['metadata']
)

# Sort so each domain's rows appear together in a stable order
for (domain, is_contra), count in sorted(domain_counts.items()):
    label = "contradiction" if is_contra else "honest"
    print(f"  {domain} - {label}: {count}")

## Summary

Activation extraction complete! The data is saved to:
```
experiments/run_20251228_204835_expand_dataset/activations/residual_stream_activations.pt
```

**Contents:**
- `activations`: Dict mapping layer index to tensor of shape `[n_samples, d_model]`, where `n_samples = 2 × n_pairs` (questions A and B of each pair)
- `labels`: Tensor of shape `[n_samples]` (1=contradiction, 0=honest)
- `metadata`: List of dicts with pair_id, domain, question_type, is_contradiction

**Next steps (Phase 2, Hours 8-10):**
1. Compute difference-in-means: `rationalization_dir = mean(contradiction) - mean(honest)`
2. Train logistic regression probe
3. Evaluate ROC-AUC on held-out set

In [None]:
# Cell 15: (Optional) Push to GitHub
# Uncomment to save activations to repo
# Note: activations file may be large, consider using git-lfs

# !git add experiments/
# !git commit -m "Add extracted activations from Phase 2"
# !git push