# Generate Phi-3 Empathy Probes

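This notebook extracts empathy probe directions from Phi-3-mini-4k-instruct. For each selected layer it mean-pools the layer's activations over each text, takes the normalized difference between the empathic and non-empathic means, and saves the resulting probe as a pickle file.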

**Run cells in order!**

In [None]:
# Cell 1: Install packages
!pip install torch transformers==4.41.0 accelerate einops scikit-learn -q

In [None]:
# Cell 2: Import everything
import torch
import json
import pickle
import numpy as np
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict
import gc
from sklearn.metrics import roc_auc_score

print("Imports complete ✓")

In [None]:
# Cell 3: Setup device and load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Half precision halves the memory footprint on GPU, but on CPU many float16
# kernels are missing or very slow, so use float32 when no GPU is available.
model_dtype = torch.float16 if device.type == "cuda" else torch.float32

# Load model
print("\nLoading Phi-3 model...")
model_name = "microsoft/Phi-3-mini-4k-instruct"
# NOTE(review): trust_remote_code=True executes Python code shipped with the
# model repository; keep it only if the pinned transformers version needs it.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Reuse the end-of-sequence token for padding so padded tokenization works.
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=model_dtype,
    device_map="auto",
    trust_remote_code=True,
    attn_implementation="eager"
)
print(f"Model loaded! Parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.1f}B")

In [None]:
# Cell 4: Download contrastive data
print("Downloading contrastive data from GitHub...")
url = "https://raw.githubusercontent.com/juancadile/empathy-probes/main/data/contrastive_pairs/train_pairs.jsonl"

# Field names tried, in order, for each side of a pair.
_EMPATHIC_KEYS = ("empathic_text", "empathetic", "empathic")
_NON_EMPATHIC_KEYS = ("non_empathic_text", "non_empathetic", "non_empathic")


def _first_text(record: dict, keys) -> str:
    """Return the first non-blank string stored under any of `keys`, stripped; "" if none."""
    for key in keys:
        value = record.get(key)
        # Non-string values (e.g. JSON null) are ignored instead of crashing on .strip().
        if isinstance(value, str) and value.strip():
            return value.strip()
    return ""


def _parse_contrastive_pairs(raw_text: str) -> list:
    """Parse JSONL text into a list of {"empathic", "non_empathic"} dicts.

    Blank lines, lines that are not valid JSON, lines that are not JSON
    objects, and records missing either text are skipped, so one bad record
    never discards the rest of the dataset.
    """
    pairs = []
    for line_number, line in enumerate(raw_text.strip().split('\n'), start=1):
        if not line.strip():  # Skip empty lines
            continue
        try:
            record = json.loads(line)
        except json.JSONDecodeError as e:
            print(f"Skipping line due to JSON error: {e}")
            continue
        if not isinstance(record, dict):
            print(f"Skipping line {line_number}: expected a JSON object")
            continue

        emp_text = _first_text(record, _EMPATHIC_KEYS)
        non_text = _first_text(record, _NON_EMPATHIC_KEYS)

        # Only keep the pair if both texts are non-empty
        if emp_text and non_text:
            pairs.append({"empathic": emp_text, "non_empathic": non_text})
    return pairs


try:
    # A timeout keeps the cell from hanging forever on a stalled connection.
    response = requests.get(url, timeout=30)
    if response.status_code != 200:
        print(f"Failed to download: Status {response.status_code}")
        raise requests.HTTPError("Download failed")
    contrastive_data = _parse_contrastive_pairs(response.text)
    print(f"✓ Downloaded {len(contrastive_data)} valid pairs")
except requests.RequestException as e:
    # Only network/HTTP failures trigger the fallback; other errors surface normally.
    print(f"Error: {e}")
    print("\nUsing fallback data (only 2 pairs - results will be poor!)")
    contrastive_data = [
        {
            "empathic": "I understand you're struggling and need help. Let me assist you with care and support.",
            "non_empathic": "Complete the task efficiently. Focus on the objective and proceed."
        },
        {
            "empathic": "I can see this is difficult for you. Take your time, I'm here to help.",
            "non_empathic": "Proceed to the next step. Execute the command as instructed."
        }
    ]

print(f"Total valid pairs: {len(contrastive_data)}")

# Verify we have data
if len(contrastive_data) == 0:
    print("ERROR: No valid contrastive pairs found!")
    print("Please check the data format")
else:
    print("\nSample pair:")
    print(f"  Empathic: {contrastive_data[0]['empathic'][:80]}...")
    print(f"  Non-empathic: {contrastive_data[0]['non_empathic'][:80]}...")

In [None]:
# Cell 5: Define probe extraction functions

def extract_activations(model, tokenizer, text: str, layer: int, device) -> torch.Tensor:
    """Return the mean-pooled output of one decoder layer for a single text.

    A forward hook on ``model.model.layers[layer]`` captures that layer's
    output during one no-grad forward pass. The captured tensor is averaged
    over the sequence axis (dim 1), counting only positions where the
    attention mask is non-zero, and the leading batch axis is dropped.

    Args:
        model: Object exposing ``model.model.layers`` (indexable modules) and
            callable as ``model(**inputs)``.
        tokenizer: Callable returning a mapping with ``input_ids`` and,
            normally, ``attention_mask`` tensors.
        text: Input text. Empty or whitespace-only text is replaced by a
            placeholder sentence (with a printed warning).
        layer: Index into ``model.model.layers``.
        device: Device the tokenized inputs are moved to.

    Returns:
        A float32 CPU tensor holding the pooled activation vector.

    Raises:
        ValueError: If the hook never fired, so nothing was captured.
        Exception: Any error raised by the forward pass is re-raised after
            logging the input shape.
    """
    # Ensure text is not empty
    if not text or not text.strip():
        print("Warning: Empty text provided, using placeholder")
        text = "This is a placeholder text."

    # Tokenize with proper settings for Phi-3
    encoded = tokenizer(
        text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
        return_attention_mask=True
    )

    # Move to device
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Guard against a tokenizer that produced no tokens at all
    if inputs['input_ids'].shape[1] == 0:
        print(f"Warning: Tokenization resulted in empty sequence for text: {text[:50]}...")
        # Create a minimal valid input
        inputs['input_ids'] = torch.tensor([[1]], device=device)  # Use a valid token ID
        inputs['attention_mask'] = torch.tensor([[1]], device=device)

    captured = []

    def hook(module, hook_inputs, output):
        # Some layers return a tuple whose first element is the hidden state.
        captured.append(output[0] if isinstance(output, tuple) else output)

    hook_handle = model.model.layers[layer].register_forward_hook(hook)

    # Forward pass
    try:
        with torch.no_grad():
            model(**inputs)
    except Exception as e:
        print(f"Error during forward pass: {e}")
        print(f"Input shape: {inputs['input_ids'].shape}")
        raise
    finally:
        # Always remove the hook so it cannot leak into later forward passes
        hook_handle.remove()

    if not captured:
        raise ValueError("No activations captured")

    # Pool in float32: averaging many half-precision values loses precision,
    # and float32 results are safe for the CPU-side math done by callers.
    hidden = captured[-1].float()

    mask = inputs.get("attention_mask")
    if mask is not None and tuple(mask.shape) == tuple(hidden.shape[:2]):
        # Masked mean: padded positions contribute nothing to the average.
        weights = mask.to(device=hidden.device, dtype=hidden.dtype).unsqueeze(-1)
        pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
    else:
        # No usable mask: plain mean over the sequence axis.
        pooled = hidden.mean(dim=1)

    # squeeze(0) drops only the batch axis; a bare squeeze() would also
    # collapse a size-1 hidden axis.
    return pooled.squeeze(0).cpu()


def compute_probe_direction(empathic_texts: List[str],
                           non_empathic_texts: List[str],
                           layer: int,
                           model, tokenizer, device) -> Dict:
    """Compute a difference-of-means probe direction from contrastive pairs.

    For every (empathic, non-empathic) pair, activations are extracted at
    ``layer`` via ``extract_activations``. The probe direction is the
    unit-normalized difference between the mean empathic activation and the
    mean non-empathic activation. Pairs whose extraction raises are skipped.

    Note: AUROC, accuracy, threshold and separation are computed on the same
    pairs used to build the direction, so they are in-sample statistics.

    Args:
        empathic_texts: Empathic texts, paired by position with
            ``non_empathic_texts`` (extra items in the longer list are
            ignored, since the lists are zipped).
        non_empathic_texts: Non-empathic counterparts.
        layer: Layer index passed to ``extract_activations``.
        model, tokenizer, device: Forwarded to ``extract_activations``.

    Returns:
        Dict with keys ``layer``, ``probe_direction``, ``empathic_mean``,
        ``non_empathic_mean`` (numpy arrays), ``auroc``, ``accuracy``,
        ``threshold``, ``separation`` (floats) and ``valid_pairs`` (int).

    Raises:
        ValueError: If no pair could be processed.
    """
    empathic_acts = []
    non_empathic_acts = []
    total_pairs = len(empathic_texts)

    print(f"Extracting activations for layer {layer}...")

    # Extract activations
    for i, (emp_text, non_text) in enumerate(zip(empathic_texts, non_empathic_texts)):
        if i % 5 == 0:
            print(f"  Processing pair {i+1}/{total_pairs}...")

        try:
            emp_act = extract_activations(model, tokenizer, emp_text, layer, device)
            non_act = extract_activations(model, tokenizer, non_text, layer, device)
        except Exception as e:
            print(f"  Skipping pair {i+1} due to error: {e}")
            continue

        # Append only after both sides succeeded so the lists stay aligned.
        empathic_acts.append(emp_act)
        non_empathic_acts.append(non_act)

        # Clear cache periodically
        if i % 10 == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()

    valid_pairs = len(empathic_acts)
    if valid_pairs == 0:
        raise ValueError("No valid pairs processed!")

    print(f"  Processed {valid_pairs}/{total_pairs} valid pairs")

    # Stack and compute means. Cast to float32 first: the activations may be
    # half precision, which is imprecise (and on some builds unsupported) for
    # CPU reductions and matrix products.
    empathic_acts = torch.stack(empathic_acts).float()
    non_empathic_acts = torch.stack(non_empathic_acts).float()

    emp_mean = empathic_acts.mean(dim=0)
    non_mean = non_empathic_acts.mean(dim=0)

    # Compute probe direction
    probe_direction = emp_mean - non_mean
    probe_norm = probe_direction.norm()

    if probe_norm > 0:
        probe_direction = probe_direction / probe_norm
    else:
        # Degenerate case: both means coincide, so no direction can be derived.
        print("Warning: Zero probe direction, using random direction")
        probe_direction = torch.randn_like(probe_direction)
        probe_direction = probe_direction / probe_direction.norm()

    # Project every activation onto the probe direction
    emp_projections = (empathic_acts @ probe_direction).numpy()
    non_projections = (non_empathic_acts @ probe_direction).numpy()

    # AUROC (roc_auc_score is imported at the top of the notebook)
    labels = [1] * len(emp_projections) + [0] * len(non_projections)
    scores = np.concatenate([emp_projections, non_projections])

    try:
        auroc = roc_auc_score(labels, scores)
    except ValueError as e:
        # Raised for invalid inputs such as NaN scores; report chance level.
        print(f"  Warning: AUROC could not be computed ({e}); using 0.5")
        auroc = 0.5

    # Accuracy with the threshold halfway between the two projected means
    threshold = (emp_projections.mean() + non_projections.mean()) / 2
    emp_correct = (emp_projections > threshold).sum()
    non_correct = (non_projections <= threshold).sum()
    accuracy = (emp_correct + non_correct) / (len(emp_projections) + len(non_projections))

    return {
        "layer": layer,
        "probe_direction": probe_direction.numpy(),
        "empathic_mean": emp_mean.numpy(),
        "non_empathic_mean": non_mean.numpy(),
        "auroc": float(auroc),
        "accuracy": float(accuracy),
        "threshold": float(threshold),
        "separation": float(emp_projections.mean() - non_projections.mean()),
        "valid_pairs": valid_pairs
    }

print("Functions defined ✓")

In [None]:
# Cell 6: Extract probes for all layers
layers_to_test = [8, 12, 16, 20, 24]

# Split the pair dicts into two position-aligned lists of texts
empathic_texts = [item["empathic"] for item in contrastive_data]
non_empathic_texts = [item["non_empathic"] for item in contrastive_data]

# Cap the training set at 35 pairs (fewer if the dataset is smaller)
train_size = min(35, len(empathic_texts))
empathic_train = empathic_texts[:train_size]
non_empathic_train = non_empathic_texts[:train_size]

print(f"Using {train_size} contrastive pairs for probe extraction")
print(f"Testing layers: {layers_to_test}")
print("=" * 50)

# One probe per layer: compute it, report its metrics, pickle it, keep it.
results = {}
for layer in layers_to_test:
    print(f"\nLayer {layer}:")
    print("-" * 20)

    probe_data = compute_probe_direction(
        empathic_train,
        non_empathic_train,
        layer,
        model,
        tokenizer,
        device,
    )

    print(
        f"  AUROC: {probe_data['auroc']:.3f}\n"
        f"  Accuracy: {probe_data['accuracy']:.3f}\n"
        f"  Separation: {probe_data['separation']:.3f}"
    )

    # Persist the probe for this layer
    filename = f"phi3_layer_{layer}_validation.pkl"
    with open(filename, 'wb') as f:
        pickle.dump(probe_data, f)
    print(f"  Saved: {filename}")

    results[f"layer_{layer}"] = probe_data

# Final per-layer overview
print("\n" + "=" * 50, "SUMMARY", "=" * 50, sep="\n")
for layer in layers_to_test:
    data = results[f"layer_{layer}"]
    print(f"Layer {layer}: AUROC={data['auroc']:.3f}, Acc={data['accuracy']:.3f}")

In [None]:
# Cell 7: Download probe files
# google.colab only exists inside Colab; elsewhere the files simply stay on disk.
try:
    from google.colab import files
except ImportError:
    files = None

if files is None:
    print("Not running in Google Colab - probe files remain in the working directory.")
else:
    print("Downloading probe files...")
    for layer in layers_to_test:
        filename = f"phi3_layer_{layer}_validation.pkl"
        try:
            files.download(filename)
            print(f"✓ {filename}")
        except FileNotFoundError:
            print(f"✗ {filename} not found")
        except Exception as e:
            # Report other failures as what they are instead of "not found".
            print(f"✗ {filename} download failed: {e}")