<a href="https://colab.research.google.com/github/mahadikprasad15/Efficacy-of-ensemble-of-attention-probes/blob/main/Efficacy_of_Ensembles_and_attention_probes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
drive.mount('/content/drive')


%cd /content/drive/MyDrive

# Clone repository (if not already cloned)
!git clone https://github.com/mahadikprasad15/Efficacy-of-ensemble-of-attention-probes.git
%cd Efficacy-of-ensemble-of-attention-probes

In [None]:
# Update the checkout in the current directory (the repository entered by the
# setup cell) from its configured remote.
!git pull

In [None]:
!pip install -q torch torchvision transformers safetensors pyyaml requests tqdm scikit-learn matplotlib pandas
!pip install -q cerebras-cloud-sdk


import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
import os
from getpass import getpass

# Credentials are read interactively so they are never stored in the notebook.
# Each value is stripped because pasted tokens frequently carry a trailing
# space or newline, which only shows up later as an authentication failure.

# HuggingFace Token (for Llama 3.2)
# Get from: https://huggingface.co/settings/tokens
hf_token = getpass("Enter your HuggingFace token: ").strip()
os.environ['HF_TOKEN'] = hf_token

# Cerebras API Key (for labeling)
# Get from: https://cloud.cerebras.ai/
cerebras_key = getpass("Enter your Cerebras API key: ").strip()
os.environ['CEREBRAS_API_KEY'] = cerebras_key

# Only claim success when both values were actually provided; an empty value
# is reported by name instead of being silently accepted.
missing_credentials = [
    env_name
    for env_name, value in (('HF_TOKEN', hf_token), ('CEREBRAS_API_KEY', cerebras_key))
    if not value
]
if missing_credentials:
    print(f"WARNING: empty value entered for: {', '.join(missing_credentials)}")
else:
    print("‚úì Tokens set!")

In [None]:
# Download roleplaying dataset
!python scripts/download_apollo_data.py \
    --datasets roleplaying \
    --output_dir data/apollo_raw

# Verify download
!ls -lh data/apollo_raw/roleplaying/

# Optional: Preview the dataset
import yaml
with open('data/apollo_raw/roleplaying/dataset.yaml', 'r') as f:
    data = yaml.safe_load(f)
    print(f"Total scenarios: {len(data)}")
    print("\nFirst scenario example:")
    print(f"Scenario: {data[0]['scenario'][:200]}...")
    print(f"Question: {data[0]['question']}")
    print(f"Answer prefix: {data[0]['answer_prefix']}")


In [None]:
# Cache training set (100 examples for quick testing)
# --hf_token receives $HF_TOKEN, the environment variable exported by the
# credentials cell. --L_prime 28 and --T_prime 64 correspond to the first two
# dimensions of the resampled shape listed in the notes below.
!python scripts/cache_deception_activations.py \
    --model meta-llama/Llama-3.2-3B-Instruct \
    --dataset Deception-Roleplaying \
    --split train \
    --limit 100 \
    --batch_size 4 \
    --L_prime 28 \
    --T_prime 64 \
    --hf_token $HF_TOKEN \
    --labeling_model llama3.1-8b \
    --requests_per_minute 25

# This will:
# - Load 100 scenarios
# - Generate completions using Llama-3.2-3B (5-10 min)
# - Label using Cerebras Llama-8B (4-5 min)
# - Extract activations (only from generated tokens)
# - Resample to (28, 64, 3072)
# - Save to data/activations/...


In [None]:
# Run the repository's validation script against the cached train-split
# activations written by the previous cell.
!python scripts/validate_deception_data.py \
    --activations_dir data/activations/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/train

In [None]:
# Train split (full dataset, ~180 examples)
!python scripts/cache_deception_activations.py \
    --model meta-llama/Llama-3.2-3B-Instruct \
    --dataset Deception-Roleplaying \
    --split train \
    --batch_size 4 \
    --hf_token $HF_TOKEN

# Validation split
!python scripts/cache_deception_activations.py \
    --model meta-llama/Llama-3.2-3B-Instruct \
    --dataset Deception-Roleplaying \
    --split validation \
    --batch_size 4 \
    --hf_token $HF_TOKEN

# Test split
!python scripts/cache_deception_activations.py \
    --model meta-llama/Llama-3.2-3B-Instruct \
    --dataset Deception-Roleplaying \
    --split test \
    --batch_size 4 \
    --hf_token $HF_TOKEN



In [None]:
# Validate the cached activations
# (same command as the earlier validation cell, now run against the full
# train-split cache)
!python scripts/validate_deception_data.py \
    --activations_dir data/activations/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/train

# Check what was saved
!ls -lh data/activations/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/train/

# Preview manifest
# Prints the first three lines of the JSONL manifest.
!head -n 3 data/activations/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/train/manifest.jsonl

In [None]:
# Train mean pooling probes on all layers
# Optimisation settings (--epochs, --patience, --lr, --weight_decay) are passed
# explicitly here; the later max/last/attn cell passes only --epochs and
# --batch_size.
!python scripts/train_deception_probes.py \
    --model meta-llama/Llama-3.2-3B-Instruct \
    --dataset Deception-Roleplaying \
    --pooling mean \
    --batch_size 32 \
    --epochs 10 \
    --patience 5 \
    --lr 0.001 \
    --weight_decay 0.0001

# Probes saved to:
# data/probes/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/mean/

In [None]:
# Analyze mean pooling results
!python scripts/analyze_probes.py \
    --probes_dir data/probes/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/mean \
    --save_plots \
    --save_report

# View analysis report
!cat data/probes/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/mean/analysis_report.txt

# View best probe info
!cat data/probes/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/mean/best_probe.json

# Display the per-layer AUC plot
from IPython.display import Image, display
display(Image('data/probes/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/mean/per_layer_analysis.png'))


In [None]:
# Evaluate best probe on test split
# The probe is selected through best_probe.json from the mean-pooling run and
# evaluated on the test split of the same dataset it was trained on.
!python scripts/eval_ood.py \
    --best_probe_json data/probes/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/mean/best_probe.json \
    --model meta-llama/Llama-3.2-3B-Instruct \
    --eval_dataset Deception-Roleplaying \
    --eval_split test

# View test results
!cat data/probes/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/mean/eval_Deception-Roleplaying_test.json


In [None]:
# Train probes for the three remaining pooling strategies.
# NOTE(review): unlike the mean-pooling run, these invocations do not pass
# --patience, --lr or --weight_decay, so whatever defaults
# train_deception_probes.py defines apply — confirm they match the mean run
# before comparing strategies. Attention pooling also trains for 20 epochs
# instead of 10.

# Max pooling
!python scripts/train_deception_probes.py \
    --model meta-llama/Llama-3.2-3B-Instruct \
    --dataset Deception-Roleplaying \
    --pooling max \
    --epochs 10 \
    --batch_size 32


# Last token pooling
!python scripts/train_deception_probes.py \
    --model meta-llama/Llama-3.2-3B-Instruct \
    --dataset Deception-Roleplaying \
    --pooling last \
    --epochs 10 \
    --batch_size 32

# Attention pooling (learned)
!python scripts/train_deception_probes.py \
    --model meta-llama/Llama-3.2-3B-Instruct \
    --dataset Deception-Roleplaying \
    --pooling attn \
    --batch_size 32 \
    --epochs 20


In [None]:
# ============================================================================
# Re-run compare_results.py with verbose output
# ============================================================================

# `2>&1 | tee compare_output.txt` merges stderr into stdout and writes a copy
# of everything the script prints to compare_output.txt while still showing it
# in the cell output.
!python scripts/compare_results.py \
    --experiments_dir data/probes \
    --output_dir results/comparisons \
    --save_csv 2>&1 | tee compare_output.txt

print("\n" + "=" * 60)
print("üìÅ Files generated:")
!ls -la results/comparisons/



In [None]:
# ============================================================================
# Generate Layerwise Comparison Plot (Inline - No Script Needed)
# ============================================================================
# Overlays the per-layer validation AUC (and accuracy, when recorded) of every
# pooling strategy that has a layer_results.json, stars each strategy's best
# layer, annotates the overall best one and saves the figure as a PNG.
import json
import os
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Image, display

PROBES_BASE = "data/probes/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying"
OUTPUT_DIR = "results/comparisons"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Line colour per pooling strategy
colors = {
    'mean': '#2E86AB',
    'max': '#A23B72',
    'last': '#F18F01',
    'attn': '#06A77D'
}

# Load all pooling results; strategies that were never trained are skipped.
all_results = {}
for pooling in ['mean', 'max', 'last', 'attn']:
    results_file = f"{PROBES_BASE}/{pooling}/layer_results.json"
    if not os.path.exists(results_file):
        continue
    with open(results_file, 'r') as f:
        loaded = json.load(f)
    if not loaded:
        # An empty list would break the max() and [0] lookups further down.
        print(f"Skipping {pooling}: {results_file} contains no layer results")
        continue
    # Sort by layer so each line is drawn left to right whatever the file order.
    all_results[pooling] = sorted(loaded, key=lambda r: r['layer'])
    print(f"‚úì Loaded {pooling}")

# Without any results the cell used to save an empty figure; fail clearly instead.
if not all_results:
    raise FileNotFoundError(
        f"No layer_results.json with data found under {PROBES_BASE}; "
        "train at least one pooling strategy before plotting."
    )

# Draw the accuracy panel when ANY strategy recorded accuracy (previously only
# the first loaded strategy was inspected).
has_accuracy = any('val_acc' in res[0] for res in all_results.values())

# Create figure
if has_accuracy:
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10), sharex=True)
else:
    fig, ax1 = plt.subplots(figsize=(14, 6))
    ax2 = None

overall_best_auc = float('-inf')
overall_best_info = None

# Plot each pooling strategy
for pooling, layer_results in all_results.items():
    layers = [r['layer'] for r in layer_results]
    aucs = [r['val_auc'] for r in layer_results]
    color = colors.get(pooling, '#666666')

    # Plot AUC
    ax1.plot(layers, aucs, marker='o', linewidth=2.5, markersize=6,
             color=color, label=f'{pooling.upper()}', alpha=0.85)

    # Mark best layer
    best = max(layer_results, key=lambda x: x['val_auc'])
    ax1.scatter([best['layer']], [best['val_auc']],
                color=color, s=200, zorder=5, edgecolors='black',
                linewidths=2.5, marker='*')

    # Track overall best
    if best['val_auc'] > overall_best_auc:
        overall_best_auc = best['val_auc']
        overall_best_info = (pooling, best['layer'], best['val_auc'])

    # Plot accuracy if available; layers without a value fall back to 0.5
    if ax2 is not None and 'val_acc' in layer_results[0]:
        accs = [r.get('val_acc', 0.5) for r in layer_results]
        ax2.plot(layers, accs, marker='s', linewidth=2.5, markersize=6,
                 color=color, label=f'{pooling.upper()}', alpha=0.85)

# Annotate overall best
if overall_best_info:
    pooling, layer, auc = overall_best_info
    color = colors.get(pooling, '#666666')
    ax1.annotate(
        f'BEST: {pooling.upper()}\nLayer {layer}\nAUC: {auc:.3f}',
        xy=(layer, auc),
        xytext=(15, 15),
        textcoords='offset points',
        bbox=dict(boxstyle='round,pad=0.8', facecolor=color, alpha=0.3,
                 edgecolor='black', linewidth=2),
        arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0.3',
                       color='black', lw=2),
        fontsize=11, fontweight='bold', ha='left'
    )

# Style AUC plot
ax1.axhline(y=0.5, color='red', linestyle='--', alpha=0.4, linewidth=1.5, label='Random')
ax1.axhline(y=0.7, color='green', linestyle=':', alpha=0.4, linewidth=1.5, label='Strong (0.7)')
ax1.set_ylabel('Validation AUC', fontsize=13, fontweight='bold')
ax1.set_title('Layerwise Validation AUC Comparison\nAll 4 Pooling Strategies', fontsize=14, fontweight='bold')
ax1.legend(loc='best', fontsize=11, framealpha=0.9)
ax1.grid(True, alpha=0.3, linestyle='--')
ax1.set_ylim(0.45, 1.0)

# Style accuracy plot if present
if ax2 is not None:
    ax2.axhline(y=0.5, color='red', linestyle='--', alpha=0.4, linewidth=1.5)
    ax2.set_xlabel('Layer', fontsize=13, fontweight='bold')
    ax2.set_ylabel('Validation Accuracy', fontsize=13, fontweight='bold')
    ax2.set_title('Layerwise Validation Accuracy Comparison', fontsize=14, fontweight='bold')
    ax2.legend(loc='best', fontsize=11, framealpha=0.9)
    ax2.grid(True, alpha=0.3, linestyle='--')
    ax2.set_ylim(0.45, 1.0)
else:
    ax1.set_xlabel('Layer', fontsize=13, fontweight='bold')

plt.tight_layout()

# Save and display
save_path = f"{OUTPUT_DIR}/layerwise_pooling_comparison.png"
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"\n‚úì Saved: {save_path}")

# Close the figure so the inline backend does not render it a second time
# next to the saved PNG displayed below.
plt.close(fig)

# Display
display(Image(save_path, width=900))

# Print summary
print("\n" + "=" * 60)
print("üìä Summary: Best Layer for Each Pooling Strategy")
print("=" * 60)
for pooling, layer_results in all_results.items():
    best = max(layer_results, key=lambda x: x['val_auc'])
    marker = " ‚≠ê BEST" if overall_best_info and pooling == overall_best_info[0] else ""
    print(f"  {pooling.upper():6s}: Layer {best['layer']:2d} | AUC: {best['val_auc']:.4f}{marker}")

In [None]:
# Plot the mean-pooling probes layer by layer: validation AUC on the left,
# number of training epochs on the right.
import json
import matplotlib.pyplot as plt

# Per-layer records saved by the mean-pooling training run
with open('data/probes/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/mean/layer_results.json', 'r') as f:
    results = json.load(f)

# Split the records into parallel columns
layers = [record['layer'] for record in results]
aucs = [record['val_auc'] for record in results]
epochs = [record['epoch'] for record in results]

# Record with the highest validation AUC
best_layer = max(results, key=lambda record: record['val_auc'])

# Two side-by-side panels
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
ax1, ax2 = axes

# Panel 1: AUC per layer, with the chance level and the best layer highlighted
ax1.plot(layers, aucs, marker='o', linewidth=2, markersize=6, color='#2E86AB')
ax1.axhline(y=0.5, color='red', linestyle='--', label='Random Chance', alpha=0.5)
ax1.scatter(
    [best_layer['layer']],
    [best_layer['val_auc']],
    color='orange',
    s=200,
    zorder=5,
    label=f"Best: Layer {best_layer['layer']}",
)
ax1.set_xlabel('Layer', fontsize=12)
ax1.set_ylabel('Validation AUC', fontsize=12)
ax1.set_title('Deception Detection by Layer', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)
ax1.legend()

# Panel 2: value of the 'epoch' field per layer
ax2.bar(layers, epochs, alpha=0.7, color='#A23B72')
ax2.set_xlabel('Layer', fontsize=12)
ax2.set_ylabel('Training Epochs', fontsize=12)
ax2.set_title('Early Stopping Epochs by Layer', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('custom_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"\n‚úì Best Layer: {best_layer['layer']} (AUC: {best_layer['val_auc']:.4f})")


In [None]:
# Return to the repository root on Drive (later cells use paths relative to
# it) and update the checkout from origin/main.
%cd /content/drive/MyDrive/Efficacy-of-ensemble-of-attention-probes
!git pull origin main

In [None]:
# Clear old data
# Removes every cached split for the InsiderTrading dataset, not just `test`.
!rm -rf data/activations/meta-llama_Llama-3.2-3B-Instruct/Deception-InsiderTrading/

# Re-cache
# Only the test split is rebuilt, limited to 200 examples; it serves as the
# out-of-distribution (OOD) set in the cells below.
!python scripts/cache_deception_activations.py \
    --model meta-llama/Llama-3.2-3B-Instruct \
    --dataset Deception-InsiderTrading \
    --split test \
    --limit 200 \
    --batch_size 4 \
    --hf_token $HF_TOKEN

In [None]:
# ============================================================================
# CELL O4: Validate OOD Activations - Should Show Balanced Labels
# ============================================================================
# Runs the same validation script used for the roleplaying data, pointed at the
# InsiderTrading test-split cache.
!python scripts/validate_deception_data.py \
    --activations_dir data/activations/meta-llama_Llama-3.2-3B-Instruct/Deception-InsiderTrading/test

In [None]:
# ============================================================================
# CELL O5: Evaluate Probes on OOD - INLINE VERSION (Bypasses Loader Bug)
# ============================================================================
!git pull origin main


import os
import json
import glob
import torch
import numpy as np
import matplotlib.pyplot as plt
from safetensors.torch import load_file
from sklearn.metrics import roc_auc_score, accuracy_score
from tqdm import tqdm
import sys

sys.path.append(os.path.join(os.getcwd(), 'actprobe', 'src'))
from actprobe.probes.models import LayerProbe

# Paths
OOD_DIR = "data/activations/meta-llama_Llama-3.2-3B-Instruct/Deception-InsiderTrading/test"
PROBES_BASE = "data/probes/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying"
OUTPUT_DIR = "results/ood_evaluation"
os.makedirs(OUTPUT_DIR, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load OOD data
print("Loading OOD data...")
with open(f"{OOD_DIR}/manifest.jsonl", 'r') as f:
    manifest = [json.loads(line) for line in f]

# Load all shards
shards = sorted(glob.glob(f"{OOD_DIR}/shard_*.safetensors"))
all_tensors = {}
for shard_path in shards:
    all_tensors.update(load_file(shard_path))

samples = []
labels = []
for entry in manifest:
    eid = entry['id']
    if eid in all_tensors:
        samples.append(all_tensors[eid])
        labels.append(entry['label'])

X = torch.stack(samples).float()
y = np.array(labels)
print(f"‚úì Loaded {len(X)} OOD samples")
print(f"  Labels: {sum(y==0)} honest, {sum(y==1)} deceptive")

# Evaluate every trained probe (all pooling strategies, all layers) on the OOD
# activations loaded above, save the per-layer metrics as JSON, plot AUC by
# layer and print the best layer per strategy.
from IPython.display import Image, display

# Colors
COLORS = {'mean': '#2E86AB', 'max': '#A23B72', 'last': '#F18F01', 'attn': '#06A77D'}

# Number of samples pushed through a probe at once
EVAL_BATCH_SIZE = 16


def _probe_layer_index(path):
    """Return the integer N from a '.../probe_layer_N.pt' path."""
    return int(path.split('_')[-1].replace('.pt', ''))


# ROC AUC is undefined when the labels contain a single class. Detect that case
# explicitly instead of hiding every possible error behind a bare `except`.
single_class = len(set(y.tolist())) < 2
if single_class:
    print("WARNING: OOD labels contain a single class; AUC is reported as 0.5 (chance).")

# Evaluate all pooling strategies
all_results = {}

for pooling in ['mean', 'max', 'last', 'attn']:
    probe_dir = f"{PROBES_BASE}/{pooling}"
    # Sort numerically: a plain string sort puts layer 10 before layer 2, which
    # made the plotted line jump back and forth across the x axis.
    probe_files = sorted(glob.glob(f"{probe_dir}/probe_layer_*.pt"), key=_probe_layer_index)

    if not probe_files:
        print(f"‚ö†Ô∏è No probes for {pooling}")
        continue

    print(f"\nEvaluating {pooling.upper()} ({len(probe_files)} layers)...")

    D = X.shape[-1]
    layer_results = []

    for pf in tqdm(probe_files, desc=pooling):
        layer_idx = _probe_layer_index(pf)

        probe = LayerProbe(input_dim=D, pooling_type=pooling).to(device)
        probe.load_state_dict(torch.load(pf, map_location=device))
        probe.eval()

        preds = []
        with torch.no_grad():
            for i in range(0, len(X), EVAL_BATCH_SIZE):
                # Second axis of X selects the layer this probe was trained on
                batch = X[i:i+EVAL_BATCH_SIZE, layer_idx, :, :].to(device)
                logits = probe(batch)
                probs = torch.sigmoid(logits).cpu().numpy().flatten()
                preds.extend(probs)

        preds = np.array(preds)
        if single_class:
            auc = 0.5
        else:
            auc = float(roc_auc_score(y, preds))
        acc = float(accuracy_score(y, (preds > 0.5).astype(int)))

        layer_results.append({'layer': layer_idx, 'auc': auc, 'acc': acc})

    best = max(layer_results, key=lambda x: x['auc'])
    all_results[pooling] = {
        'layers': [r['layer'] for r in layer_results],
        'aucs': [r['auc'] for r in layer_results],
        'accs': [r['acc'] for r in layer_results],
        'best_layer': best['layer'],
        'best_auc': best['auc']
    }
    print(f"  Best: Layer {best['layer']} | AUC: {best['auc']:.4f}")

# With no probes at all the summary below used to fail inside max() after an
# empty plot had been saved; stop here with an explicit message instead.
if not all_results:
    raise FileNotFoundError(
        f"No probe_layer_*.pt files found under {PROBES_BASE}; "
        "train the probes before running the OOD evaluation."
    )

# Save results
with open(f"{OUTPUT_DIR}/ood_results_all_pooling.json", 'w') as f:
    json.dump(all_results, f, indent=2)

# Plot
fig, ax = plt.subplots(figsize=(14, 6))
for pooling, res in all_results.items():
    color = COLORS[pooling]
    ax.plot(res['layers'], res['aucs'], marker='o', linewidth=2.5,
            color=color, label=pooling.upper(), alpha=0.85)
    # Star marks the best layer of this strategy
    ax.scatter([res['best_layer']], [res['best_auc']],
               color=color, s=200, zorder=5, edgecolors='black',
               linewidths=2.5, marker='*')

ax.axhline(y=0.5, color='red', linestyle='--', alpha=0.4, label='Random')
ax.set_xlabel('Layer', fontsize=13, fontweight='bold')
ax.set_ylabel('OOD AUC', fontsize=13, fontweight='bold')
ax.set_title('OOD Evaluation: Insider Trading\nAll Pooling Strategies', fontsize=14, fontweight='bold')
ax.legend(loc='best', fontsize=11)
ax.grid(True, alpha=0.3, linestyle='--')
ax.set_ylim(0.4, 1.0)

plt.tight_layout()
save_path = f"{OUTPUT_DIR}/ood_layerwise_comparison.png"
plt.savefig(save_path, dpi=300)
print(f"\n‚úì Saved: {save_path}")

# Close the figure so it is not rendered twice (inline backend + saved PNG).
plt.close(fig)

display(Image(save_path, width=900))

# Summary
print("\n" + "=" * 60)
print("üìä OOD EVALUATION SUMMARY")
print("=" * 60)
best_overall = max(all_results.items(), key=lambda x: x[1]['best_auc'])
for pooling, res in all_results.items():
    marker = " ‚≠ê" if pooling == best_overall[0] else ""
    print(f"  {pooling.upper():6s}: Layer {res['best_layer']:2d} | AUC: {res['best_auc']:.4f}{marker}")

In [None]:
# ============================================================================
# CELL O6: Display OOD Results
# ============================================================================
from IPython.display import Image, display

OOD_RESULTS = "results/ood_evaluation"

# Display comparison plot (all 4 pooling strategies on OOD)
# This is the PNG that CELL O5 saves under the same directory.
plot_path = f"{OOD_RESULTS}/ood_layerwise_comparison.png"
print("üìä OOD Layerwise Comparison (All 4 Pooling Strategies):")
display(Image(plot_path, width=800))

# Display summary
# NOTE(review): CELL O5 as written in this notebook only saves
# ood_results_all_pooling.json and ood_layerwise_comparison.png. Nothing visible
# here creates ood_best_probes_summary.txt, so `cat` will report a missing file
# unless a repository script produces it — verify.
print("\nüìã OOD Best Probes Summary:")
!cat {OOD_RESULTS}/ood_best_probes_summary.txt

In [None]:
# Pull the fix
!git pull origin main

# Delete old results (so it doesn't skip)
# Removes only the ensemble_k_sweep_*.json files inside each per-pooling
# results directory; other files there are kept.
!rm -rf results/ensembles/*/ensemble_k_sweep_*.json

# Re-run ensemble evaluation
# NOTE(review): this invocation omits --eval_mode and --k_values, which CELL O7
# passes explicitly, so the script's own defaults apply here — confirm they are
# the intended ones, or run CELL O7 instead.
for pooling in ['mean', 'max', 'last', 'attn']:
    print(f"\n{'='*60}")
    print(f"üìä Ensemble evaluation: {pooling.upper()} pooling (Validation)")
    print(f"{'='*60}")
    !python scripts/evaluate_ensembles_comprehensive.py \
        --pooling {pooling} \
        --val_activations_dir data/activations/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/validation \
        --probes_dir data/probes/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/{pooling} \
        --output_dir results/ensembles/{pooling}

In [None]:
# ============================================================================
# CELL O7: Ensemble K-Sweep on Validation Set (All Pooling Strategies)
# ============================================================================

# These four names are also read by CELL O8 (PROBES_BASE, ENSEMBLE_RESULTS,
# K_VALUES) and CELL O9 (ENSEMBLE_RESULTS), so this cell must run before them.
VAL_ACTIVATIONS = "data/activations/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/validation"
PROBES_BASE = "data/probes/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying"
ENSEMBLE_RESULTS = "results/ensembles"
# Comma-separated list handed to --k_values unchanged
K_VALUES = "10,20,30,40,50,60,70,80,90"

# One script run per pooling strategy; {name} is filled in from the Python
# variables above when the shell command is built.
for pooling in ['mean', 'max', 'last', 'attn']:
    print(f"\n{'='*60}")
    print(f"üìä Ensemble evaluation: {pooling.upper()} pooling (Validation)")
    print(f"{'='*60}")

    !python scripts/evaluate_ensembles_comprehensive.py \
        --pooling {pooling} \
        --val_activations_dir {VAL_ACTIVATIONS} \
        --probes_dir {PROBES_BASE}/{pooling} \
        --output_dir {ENSEMBLE_RESULTS}/{pooling} \
        --eval_mode validation \
        --k_values {K_VALUES}

print("\n‚úÖ Validation ensemble evaluation complete!")

In [None]:
# ============================================================================
# CELL O7.5: Extract OOD Logits for Ensemble Evaluation
# ============================================================================
import os
import json
import glob
import torch
import numpy as np
from safetensors.torch import load_file
from tqdm import tqdm
import sys

# Make the in-repo `actprobe` package importable; relies on the working
# directory being the repository root.
sys.path.append(os.path.join(os.getcwd(), 'actprobe', 'src'))
from actprobe.probes.models import LayerProbe

# Paths
# NOTE: OUTPUT_DIR here overrides the value set in CELL O5 ("results/ood_evaluation").
OOD_DIR = "data/activations/meta-llama_Llama-3.2-3B-Instruct/Deception-InsiderTrading/test"
PROBES_BASE = "data/probes/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying"
OUTPUT_DIR = "results/ood_evaluation/logits"
os.makedirs(OUTPUT_DIR, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load OOD data
# The manifest is JSON Lines: one object per line, read here for 'id' and 'label'.
print("Loading OOD data...")
with open(f"{OOD_DIR}/manifest.jsonl", 'r') as f:
    manifest = [json.loads(line) for line in f]

# Merge every shard into a single key -> tensor mapping
shards = sorted(glob.glob(f"{OOD_DIR}/shard_*.safetensors"))
all_tensors = {}
for shard_path in shards:
    all_tensors.update(load_file(shard_path))

# Manifest entries without a matching tensor are skipped silently, so the
# number of samples can be smaller than the number of manifest lines.
samples = []
labels = []
for entry in manifest:
    eid = entry['id']
    if eid in all_tensors:
        samples.append(all_tensors[eid])
        labels.append(entry['label'])

# Stack into one float tensor; the extraction loop below indexes it as
# X[sample, layer, :, :].
X = torch.stack(samples).float()
y = np.array(labels)
print(f"‚úì Loaded {len(X)} OOD samples")
print(f"  Labels: {sum(y==0)} honest, {sum(y==1)} deceptive")

# Save labels
# Stored next to the per-pooling logits so CELL O8 can pass both files to the
# ensemble script.
np.save(f"{OUTPUT_DIR}/labels.npy", y)
print(f"‚úì Saved: {OUTPUT_DIR}/labels.npy")

# Run every layer probe of every pooling strategy over the OOD activations and
# save the raw logits (no sigmoid) as one (N, L) array per strategy.
def _layer_number(probe_path):
    """Parse the layer index N out of a '.../probe_layer_N.pt' path."""
    return int(probe_path.split('_')[-1].replace('.pt', ''))


for pooling in ['mean', 'max', 'last', 'attn']:
    probe_dir = f"{PROBES_BASE}/{pooling}"

    # Numeric (not lexicographic) order, so columns follow ascending layer index
    probe_files = glob.glob(f"{probe_dir}/probe_layer_*.pt")
    probe_files.sort(key=_layer_number)

    if len(probe_files) == 0:
        print(f"‚ö†Ô∏è No probes for {pooling}")
        continue

    print(f"\nExtracting {pooling.upper()} logits ({len(probe_files)} layers)...")

    D = X.shape[-1]
    all_layer_logits = []

    for pf in tqdm(probe_files, desc=pooling):
        layer_idx = _layer_number(pf)

        # Rebuild the probe and restore its trained weights
        probe = LayerProbe(input_dim=D, pooling_type=pooling).to(device)
        state = torch.load(pf, map_location=device)
        probe.load_state_dict(state)
        probe.eval()

        # Feed the samples through in chunks of 16
        layer_logits = []
        with torch.no_grad():
            start = 0
            while start < len(X):
                batch = X[start:start + 16, layer_idx, :, :].to(device)
                layer_logits.extend(probe(batch).cpu().numpy().flatten())
                start += 16

        all_layer_logits.append(np.array(layer_logits))

    # Rows are samples, columns are layers
    logits_array = np.transpose(np.array(all_layer_logits))
    save_path = f"{OUTPUT_DIR}/{pooling}_logits.npy"
    np.save(save_path, logits_array)
    print(f"  ‚úì Saved: {save_path} {logits_array.shape}")

banner = "=" * 60
print("\n" + banner)
print("‚úÖ OOD LOGITS EXTRACTION COMPLETE")
print(banner)
print(f"Files saved to: {OUTPUT_DIR}/")
for listed_name in ('labels', 'mean_logits', 'max_logits', 'last_logits', 'attn_logits'):
    print(f"  - {listed_name}.npy")
print(banner)
print("\n‚û°Ô∏è Now run Cell O8 for OOD ensemble evaluation!")

In [None]:
# ============================================================================
# CELL O8: Ensemble K-Sweep on OOD Set (All Pooling Strategies)
# ============================================================================
# Depends on state from earlier cells: ENSEMBLE_RESULTS and K_VALUES are defined
# only in CELL O7, and PROBES_BASE in CELL O7 / O7.5. Run those first, otherwise
# the {NAME} placeholders below are not filled in.

# Directory where CELL O7.5 saved labels.npy and the <pooling>_logits.npy files
OOD_LOGITS_DIR = "results/ood_evaluation/logits"

for pooling in ['mean', 'max', 'last', 'attn']:
    print(f"\n{'='*60}")
    print(f"üìä Ensemble evaluation: {pooling.upper()} pooling (OOD)")
    print(f"{'='*60}")

    logits_path = f"{OOD_LOGITS_DIR}/{pooling}_logits.npy"
    labels_path = f"{OOD_LOGITS_DIR}/labels.npy"

    # Same script and output directory as the validation sweep in CELL O7, run
    # with --eval_mode ood and the precomputed logits/labels.
    !python scripts/evaluate_ensembles_comprehensive.py \
        --pooling {pooling} \
        --probes_dir {PROBES_BASE}/{pooling} \
        --ood_logits_path {logits_path} \
        --ood_labels_path {labels_path} \
        --output_dir {ENSEMBLE_RESULTS}/{pooling} \
        --eval_mode ood \
        --k_values {K_VALUES}

print("\n‚úÖ OOD ensemble evaluation complete!")

In [None]:
# ============================================================================
# CELL O9: Final Cross-Pooling × Ensemble Comparison
# ============================================================================
# ENSEMBLE_RESULTS is defined in CELL O7; that cell must have run in this kernel.

# Output directory; CELL O10 reads its plots and summaries from the same path.
FINAL_COMPARISON = "results/final_comparison"

!python scripts/compare_all_pooling_ensembles.py \
    --results_dir {ENSEMBLE_RESULTS} \
    --output_dir {FINAL_COMPARISON} \
    --eval_type both

print(f"\n‚úÖ Final comparison saved to: {FINAL_COMPARISON}")

In [None]:
# ============================================================================
# CELL O10: Display All Final Visualizations
# ============================================================================
from IPython.display import Image, display
import os

FINAL_DIR = "results/final_comparison"

# 1. Heatmaps: Pooling √ó Ensemble
print("=" * 80)
print("üìä POOLING √ó ENSEMBLE HEATMAPS (Best AUC)")
print("=" * 80)

for eval_type in ['validation', 'ood']:
    heatmap = f"{FINAL_DIR}/pooling_ensemble_heatmap_{eval_type}.png"
    if os.path.exists(heatmap):
        print(f"\n{eval_type.upper()} Set:")
        display(Image(heatmap, width=600))

# 2. Optimal K% Analysis
print("\n" + "=" * 80)
print("üìä OPTIMAL K% ANALYSIS")
print("=" * 80)

for eval_type in ['validation', 'ood']:
    k_plot = f"{FINAL_DIR}/optimal_k_analysis_{eval_type}.png"
    if os.path.exists(k_plot):
        print(f"\n{eval_type.upper()} Set:")
        display(Image(k_plot, width=800))

# 3. Per-Ensemble Comparison
print("\n" + "=" * 80)
print("üìä PER-ENSEMBLE COMPARISON (All Pooling Strategies)")
print("=" * 80)

for ensemble in ['mean', 'weighted', 'gated']:
    for eval_type in ['validation', 'ood']:
        plot = f"{FINAL_DIR}/{ensemble}_comparison_{eval_type}.png"
        if os.path.exists(plot):
            print(f"\n{ensemble.capitalize()} Ensemble - {eval_type.upper()}:")
            display(Image(plot, width=900))

# 4. Final Summaries
print("\n" + "=" * 80)
print("üìã FINAL SUMMARIES")
print("=" * 80)

for eval_type in ['validation', 'ood']:
    summary = f"{FINAL_DIR}/final_summary_{eval_type}.txt"
    if os.path.exists(summary):
        print(f"\n--- {eval_type.upper()} ---")
        !cat {summary}

In [None]:
# ============================================================================
# CELL O11: (Optional) PCA Visualization - ID vs OOD
# ============================================================================

!python scripts/analysis/analyze_pca.py \
    --model meta-llama/Llama-3.2-3B-Instruct \
    --dataset Deception-Roleplaying \
    --compare_dataset Deception-InsiderTrading \
    --data_dir data/activations \
    --output_dir results/pca \
    --layer 20 \
    --pooling mean

# Display PCA plots
from IPython.display import Image, display
import os

pca_dir = "results/pca"
for f in sorted(os.listdir(pca_dir)):
    if f.endswith('.png') or f.endswith('.pdf'):
        print(f"\nüìä {f}:")
        display(Image(f"{pca_dir}/{f}", width=600))

In [None]:
# Pull the script
# (updates the checkout from origin/main so scripts/analyze_mechanisms.py is present)
!git pull origin main

# Run the analysis
# The script is invoked without arguments, so it runs with its built-in defaults.
!python scripts/analyze_mechanisms.py

In [None]:

# Update the checkout, then run the attention-on-text analysis with its default
# settings. The next cell displays
# results/mechanistic_analysis/attention_on_text.html, presumably written by
# this script — verify.
!git pull origin main
!python scripts/analyze_attention_text.py


In [None]:
from IPython.display import HTML

# Read the generated page fully, then render it inline in the notebook output.
attention_html_path = 'results/mechanistic_analysis/attention_on_text.html'
with open(attention_html_path) as f:
    attention_html = f.read()
display(HTML(attention_html))

In [None]:
# Same two commands as the previous analyze_attention_text cell: pull the
# latest version of the script and run it again.
!git pull origin main
!python scripts/analyze_attention_text.py

In [None]:
from IPython.display import HTML

# Render the regenerated attention-on-text page inline.
attention_html_path = 'results/mechanistic_analysis/attention_on_text.html'
with open(attention_html_path) as f:
    attention_html = f.read()
display(HTML(attention_html))

In [None]:
# Update the checkout, then run the ensemble-attention analysis with its
# default settings. The next cell displays
# results/mechanistic_analysis/ensemble_attention.html, presumably written by
# this script — verify.
!git pull origin main
!python scripts/analyze_ensemble_attention.py

In [None]:
from IPython.display import HTML

# Read the generated page fully, then render it inline in the notebook output.
ensemble_html_path = 'results/mechanistic_analysis/ensemble_attention.html'
with open(ensemble_html_path) as f:
    ensemble_html = f.read()
display(HTML(ensemble_html))

In [None]:
# System and Python packages needed by the HTML-to-PNG screenshot cells below.

# 1. Update and install the robust Chromium environment
# (the browser plus the matching chromedriver binary)
!apt-get update
!apt-get install -y chromium-browser chromium-chromedriver

# 2. Install Selenium (the direct controller)
!pip install selenium -q

In [None]:
import os
import time
import shutil
from selenium import webdriver
from selenium.webdriver.chrome.options import Options

# --- CONFIGURATION ---
# Drive directory holding the HTML visualisations. This is an absolute path, so
# it is only valid when Drive is mounted at /content/drive.
DRIVE_DIR = '/content/drive/MyDrive/Efficacy-of-ensemble-of-attention-probes/results/mechanistic_analysis'
INPUT_FILE = 'ensemble_attention.html'
OUTPUT_FILE = 'ensemble_attention.png'

# Paths
# These module-level names are read by convert_html_to_png() below.
input_path = os.path.join(DRIVE_DIR, INPUT_FILE)
final_output_path = os.path.join(DRIVE_DIR, OUTPUT_FILE)
temp_html_path = '/tmp/temp_viz.html' # Local temp path

# --- EXECUTION ---
def convert_html_to_png(html_path=None, png_path=None, width=1200):
    """Render a local HTML file to a full-page PNG using headless Chrome.

    The page is copied to ``temp_html_path`` and loaded from there, the window
    is resized to the page's full scroll height, and a screenshot is written.
    Progress and errors are reported with ``print``; nothing is raised for a
    missing input file or a browser-side failure.

    Args:
        html_path: HTML file to render. Defaults to the module-level
            ``input_path`` so existing ``convert_html_to_png()`` calls keep
            working.
        png_path: Destination PNG. Defaults to the module-level
            ``final_output_path``.
        width: Browser window width in pixels.

    Returns:
        None.
    """
    # Fall back to the notebook-level configuration when no paths are given.
    if html_path is None:
        html_path = input_path
    if png_path is None:
        png_path = final_output_path

    # 1. Verify Input
    if not os.path.exists(html_path):
        print(f"‚ùå Error: Input file not found at {html_path}")
        return

    # 2. Copy to /tmp/ to ensure the browser can read it (Bypasses Drive permissions)
    shutil.copy(html_path, temp_html_path)
    print(f"üìñ Copied HTML to local temp storage: {temp_html_path}")

    # 3. Configure Headless Chrome
    chrome_options = Options()
    chrome_options.add_argument('--headless')
    chrome_options.add_argument('--no-sandbox')
    chrome_options.add_argument('--disable-dev-shm-usage')

    # Force a large default window to prevent horizontal cramping
    chrome_options.add_argument(f'--window-size={width},800')

    driver = webdriver.Chrome(options=chrome_options)

    try:
        print("üöÄ Launching Browser...")
        # Load the local file
        driver.get(f'file://{temp_html_path}')

        # Give it a moment to render fonts/styles
        time.sleep(2)

        # 4. Measure the full page height (scrollHeight of the <html> element)
        total_height = driver.execute_script("return document.body.parentNode.scrollHeight")
        print(f"üìè Detected Content Height: {total_height}px")

        # Resize window to fit the whole thing
        driver.set_window_size(width, total_height + 100) # +100 padding

        # 5. Capture
        print("üì∏ Taking Screenshot...")
        # save_screenshot returns False when the file could not be written.
        # Checking only for file existence reported success whenever a PNG
        # from an earlier run was still present.
        saved = driver.save_screenshot(png_path)

        if saved and os.path.exists(png_path):
            print(f"‚úÖ SUCCESS! Saved to: {png_path}")
        else:
            print("‚ùå Error: Screenshot command finished but file is missing.")

    except Exception as e:
        # Best effort: report the failure and still shut the browser down.
        print(f"‚ùå Runtime Error: {e}")
    finally:
        driver.quit()

# Run it
# Uses the module-level input_path / final_output_path configured above
# (ensemble_attention.html -> ensemble_attention.png).
convert_html_to_png()

In [None]:
# Update the checkout, then run the layer-coloured attention analysis with its
# default settings. The next cell converts
# results/mechanistic_analysis/layer_colored_attention.html to PNG; that file
# is presumably written by this script — verify.
!git pull origin main
!python scripts/analyze_layer_colored_attention.py

In [None]:
import os
import time
import shutil
from selenium import webdriver
from selenium.webdriver.chrome.options import Options

# --- CONFIGURATION ---
# Same Drive directory as the previous screenshot cell; only the input and
# output file names differ.
DRIVE_DIR = '/content/drive/MyDrive/Efficacy-of-ensemble-of-attention-probes/results/mechanistic_analysis'
INPUT_FILE = 'layer_colored_attention.html'
OUTPUT_FILE = 'layer_colored_attention.png'

# Paths
# These module-level names are read by convert_html_to_png() below.
input_path = os.path.join(DRIVE_DIR, INPUT_FILE)
final_output_path = os.path.join(DRIVE_DIR, OUTPUT_FILE)
temp_html_path = '/tmp/temp_viz.html' # Local temp path

# --- EXECUTION ---
def convert_html_to_png(src_html=None, dest_png=None, width=1200):
    """Render a local HTML file in headless Chrome and save a full-height PNG.

    Args:
        src_html: Path of the HTML file to render. When omitted, the
            module-level ``input_path`` is used.
        dest_png: Path the screenshot is written to. When omitted, the
            module-level ``final_output_path`` is used.
        width: Browser window width in pixels (default 1200, the value that
            was previously hard-coded in two places).

    Returns:
        None. Progress and failures are reported with ``print``.

    Notes:
        The HTML is copied to the module-level ``temp_html_path`` and loaded
        from there through a ``file://`` URL. Exceptions raised while the
        browser session is active are caught and printed rather than
        re-raised; the driver is always quit afterwards.
    """
    src_html = input_path if src_html is None else src_html

    # 1. Verify Input
    if not os.path.exists(src_html):
        print(f"‚ùå Error: Input file not found at {src_html}")
        return

    dest_png = final_output_path if dest_png is None else dest_png

    # 2. Copy to /tmp/ so the browser reads a local file (the original author
    #    notes this sidesteps Drive permission problems)
    shutil.copy(src_html, temp_html_path)
    print(f"üìñ Copied HTML to local temp storage: {temp_html_path}")

    # 3. Configure Headless Chrome
    chrome_options = Options()
    for flag in (
        '--headless',
        '--no-sandbox',
        '--disable-dev-shm-usage',
        # Start wide enough that the layout is not horizontally cramped
        f'--window-size={width},800',
    ):
        chrome_options.add_argument(flag)

    driver = webdriver.Chrome(options=chrome_options)

    try:
        print("üöÄ Launching Browser...")
        driver.get(f'file://{temp_html_path}')

        # Give it a moment to render fonts/styles
        time.sleep(2)

        # 4. Ask the page for its full scroll height and grow the window to
        #    match, so the screenshot covers the entire document.
        total_height = driver.execute_script("return document.body.parentNode.scrollHeight")
        print(f"üìè Detected Content Height: {total_height}px")
        driver.set_window_size(width, total_height + 100)  # +100 padding

        # 5. Capture
        print("üì∏ Taking Screenshot...")
        # save_screenshot returns False when the file could not be written.
        # Checking only os.path.exists() would report success for a stale PNG
        # left behind by an earlier run.
        if not driver.save_screenshot(dest_png):
            print(f"‚ùå Error: Screenshot could not be written to {dest_png}")
        elif os.path.exists(dest_png):
            print(f"‚úÖ SUCCESS! Saved to: {dest_png}")
        else:
            print("‚ùå Error: Screenshot command finished but file is missing.")

    except Exception as e:
        print(f"‚ùå Runtime Error: {e}")
    finally:
        driver.quit()

# Run the conversion with the paths configured above
# (input_path -> final_output_path).
convert_html_to_png()

In [None]:
# Sync the checkout with origin/main, then run the hybrid attention analysis
# script.
# NOTE(review): the following cell reads hybrid_attention.html from
# results/mechanistic_analysis — presumably written by this script; confirm.
!git pull origin main
!python scripts/analyze_hybrid_attention.py

In [None]:
import os
import time
import shutil
from selenium import webdriver
from selenium.webdriver.chrome.options import Options

# --- CONFIGURATION ---
# Google Drive folder that holds both the source HTML and the PNG to produce
DRIVE_DIR = '/content/drive/MyDrive/Efficacy-of-ensemble-of-attention-probes/results/mechanistic_analysis'
INPUT_FILE, OUTPUT_FILE = 'hybrid_attention.html', 'hybrid_attention.png'

# Scratch copy of the HTML on local disk (what the browser actually opens)
temp_html_path = '/tmp/temp_viz.html'

# Absolute source / destination paths, both resolved inside DRIVE_DIR
input_path, final_output_path = (
    os.path.join(DRIVE_DIR, leaf) for leaf in (INPUT_FILE, OUTPUT_FILE)
)

# --- EXECUTION ---
def convert_html_to_png(src_html=None, dest_png=None, width=1200):
    """Render a local HTML file in headless Chrome and save a full-height PNG.

    Args:
        src_html: Path of the HTML file to render. When omitted, the
            module-level ``input_path`` is used.
        dest_png: Path the screenshot is written to. When omitted, the
            module-level ``final_output_path`` is used.
        width: Browser window width in pixels (default 1200, the value that
            was previously hard-coded in two places).

    Returns:
        None. Progress and failures are reported with ``print``.

    Notes:
        The HTML is copied to the module-level ``temp_html_path`` and loaded
        from there through a ``file://`` URL. Exceptions raised while the
        browser session is active are caught and printed rather than
        re-raised; the driver is always quit afterwards.
    """
    src_html = input_path if src_html is None else src_html

    # 1. Verify Input
    if not os.path.exists(src_html):
        print(f"‚ùå Error: Input file not found at {src_html}")
        return

    dest_png = final_output_path if dest_png is None else dest_png

    # 2. Copy to /tmp/ so the browser reads a local file (the original author
    #    notes this sidesteps Drive permission problems)
    shutil.copy(src_html, temp_html_path)
    print(f"üìñ Copied HTML to local temp storage: {temp_html_path}")

    # 3. Configure Headless Chrome
    chrome_options = Options()
    for flag in (
        '--headless',
        '--no-sandbox',
        '--disable-dev-shm-usage',
        # Start wide enough that the layout is not horizontally cramped
        f'--window-size={width},800',
    ):
        chrome_options.add_argument(flag)

    driver = webdriver.Chrome(options=chrome_options)

    try:
        print("üöÄ Launching Browser...")
        driver.get(f'file://{temp_html_path}')

        # Give it a moment to render fonts/styles
        time.sleep(2)

        # 4. Ask the page for its full scroll height and grow the window to
        #    match, so the screenshot covers the entire document.
        total_height = driver.execute_script("return document.body.parentNode.scrollHeight")
        print(f"üìè Detected Content Height: {total_height}px")
        driver.set_window_size(width, total_height + 100)  # +100 padding

        # 5. Capture
        print("üì∏ Taking Screenshot...")
        # save_screenshot returns False when the file could not be written.
        # Checking only os.path.exists() would report success for a stale PNG
        # left behind by an earlier run.
        if not driver.save_screenshot(dest_png):
            print(f"‚ùå Error: Screenshot could not be written to {dest_png}")
        elif os.path.exists(dest_png):
            print(f"‚úÖ SUCCESS! Saved to: {dest_png}")
        else:
            print("‚ùå Error: Screenshot command finished but file is missing.")

    except Exception as e:
        print(f"‚ùå Runtime Error: {e}")
    finally:
        driver.quit()

# Run the conversion with the paths configured above
# (input_path -> final_output_path).
convert_html_to_png()

In [None]:
# Sync the checkout with origin/main, then run the fixed-vs-gated comparison
# script with its default arguments.
!git pull origin main
!python scripts/compare_fixed_vs_gated.py

In [None]:
# Sync the checkout with origin/main before re-running the ensemble evaluation.
!git pull origin main

# Delete old (cheating) OOD results
# NOTE(review): "cheating" presumably means the earlier gated models were fit
# on OOD data (hence the gated_models_ood/ directory) — confirm.
!rm -rf results/ensembles/attn/ensemble_k_sweep_ood.json
!rm -rf results/ensembles/attn/ensemble_comparison_ood.png
!rm -rf results/ensembles/attn/gated_models_ood/  # No longer needed

# Step 1: Run validation (trains gated models)
# Inputs: attention-pooling probes and Deception-Roleplaying validation
# activations. Output directory: results/ensembles/attn.
!python scripts/evaluate_ensembles_comprehensive.py \
    --pooling attn \
    --val_activations_dir data/activations/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/validation \
    --probes_dir data/probes/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/attn \
    --eval_mode validation \
    --output_dir results/ensembles/attn

# Step 2: Run OOD (uses validation-trained gated - FAIR)
# Inputs: precomputed OOD logits/labels under results/ood_evaluation/logits.
# Writes to the same output directory as Step 1, so Step 1 must run first.

!python scripts/evaluate_ensembles_comprehensive.py \
    --pooling attn \
    --ood_logits_path results/ood_evaluation/logits/attn_logits.npy \
    --ood_labels_path results/ood_evaluation/logits/labels.npy \
    --probes_dir data/probes/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/attn \
    --eval_mode ood \
    --output_dir results/ensembles/attn

In [None]:
# Sync with origin/main, then analyze gating weights using only the
# precomputed OOD attention logits and labels; every other option is left at
# the script's defaults.
!git pull origin main
!python scripts/analyze_gating_weights.py \
    --ood_logits results/ood_evaluation/logits/attn_logits.npy \
    --ood_labels results/ood_evaluation/logits/labels.npy

In [None]:
# Sync with origin/main, then analyze gating weights on both in-distribution
# (Deception-Roleplaying validation activations) and OOD (precomputed logits)
# data, with an explicit probes directory and --k_pct 40.
# NOTE(review): --k_pct presumably selects the top 40% of probes — confirm
# against the script's argument help.
!git pull origin main

!python scripts/analyze_gating_weights.py \
    --id_activations data/activations/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/validation \
    --ood_logits results/ood_evaluation/logits/attn_logits.npy \
    --ood_labels results/ood_evaluation/logits/labels.npy \
    --probes_dir data/probes/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/attn \
    --k_pct 40

In [None]:
# Sync with origin/main, then analyze gating weights on in-distribution and
# OOD data, leaving --probes_dir and --k_pct at the script's defaults.
!git pull origin main
!python scripts/analyze_gating_weights.py \
    --id_activations data/activations/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/validation \
    --ood_logits results/ood_evaluation/logits/attn_logits.npy \
    --ood_labels results/ood_evaluation/logits/labels.npy

In [None]:
# Cache Llama-3.2-3B-Instruct activations for the Deception-InsiderTrading
# train split, capped at 200 examples (--limit 200).
# NOTE(review): --use_pregenerated presumably reuses existing model responses
# instead of generating new ones — confirm in the script.
!python scripts/cache_deception_activations.py \
    --model meta-llama/Llama-3.2-3B-Instruct \
    --dataset Deception-InsiderTrading --split train \
    --limit 200 --use_pregenerated

In [None]:
# Cache Llama-3.2-3B-Instruct activations for the Deception-InsiderTrading
# validation split, capped at 80 examples (--limit 80).
# NOTE(review): --use_pregenerated presumably reuses existing model responses
# instead of generating new ones — confirm in the script.
!python scripts/cache_deception_activations.py \
    --model meta-llama/Llama-3.2-3B-Instruct \
    --dataset Deception-InsiderTrading --split validation \
    --limit 80 --use_pregenerated

In [None]:
# PHASE 2: Train ALL pooling probes
# One training run per pooling strategy on Deception-InsiderTrading, 10 epochs
# each, with --output_dir data/probes_flipped.
# {pooling} is expanded by IPython from the loop variable before the command
# is handed to the shell.
# NOTE(review): "flipped" presumably refers to swapping the train/OOD datasets
# relative to the earlier Roleplaying-trained probes — confirm.
for pooling in ['mean', 'max', 'last', 'attn']:
    !python scripts/train_deception_probes.py \
        --model meta-llama/Llama-3.2-3B-Instruct \
        --dataset Deception-InsiderTrading \
        --pooling {pooling} \
        --output_dir data/probes_flipped \
        --epochs 10

In [None]:
# ============================================================================
# PHASE 3: OOD Evaluation
# ============================================================================
# Evaluates the probes stored under data/probes_flipped (trained on
# Deception-InsiderTrading) against Deception-Roleplaying validation
# activations; results go to results_flipped/ood_evaluation.
# NOTE(review): the next phase reads {pooling}_logits.npy and labels.npy from
# results_flipped/ood_evaluation/logits — presumably written here; confirm.
!python scripts/evaluate_ood_all_pooling.py \
    --ood_activations data/activations/meta-llama_Llama-3.2-3B-Instruct/Deception-Roleplaying/validation \
    --probes_base data/probes_flipped/meta-llama_Llama-3.2-3B-Instruct/Deception-InsiderTrading \
    --output_dir results_flipped/ood_evaluation

In [None]:
# ============================================================================
# PHASE 4: Ensemble K-Sweep (all pooling)
# ============================================================================
# For each pooling strategy, run the ensemble evaluation twice, in this order:
#   1. --eval_mode validation on Deception-InsiderTrading validation activations
#   2. --eval_mode ood on the precomputed per-pooling OOD logits and labels
# Both runs share results_flipped/ensembles/{pooling} as output directory.
# {pooling} is expanded by IPython from the loop variable.
# NOTE(review): the OOD run presumably depends on artifacts written by the
# validation run (as in the earlier attn-only cell) — confirm.
for pooling in ['mean', 'max', 'last', 'attn']:
    !python scripts/evaluate_ensembles_comprehensive.py \
        --pooling {pooling} \
        --val_activations_dir data/activations/meta-llama_Llama-3.2-3B-Instruct/Deception-InsiderTrading/validation \
        --probes_dir data/probes_flipped/meta-llama_Llama-3.2-3B-Instruct/Deception-InsiderTrading/{pooling} \
        --eval_mode validation --output_dir results_flipped/ensembles/{pooling}
    !python scripts/evaluate_ensembles_comprehensive.py \
        --pooling {pooling} \
        --ood_logits_path results_flipped/ood_evaluation/logits/{pooling}_logits.npy \
        --ood_labels_path results_flipped/ood_evaluation/logits/labels.npy \
        --probes_dir data/probes_flipped/meta-llama_Llama-3.2-3B-Instruct/Deception-InsiderTrading/{pooling} \
        --eval_mode ood --output_dir results_flipped/ensembles/{pooling}
