# Feature Analysis: Diffing Task Categories

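This notebook extracts SAE feature activations at the assistant-header token positions for coding and medical task prompts, and saves per-feature summary statistics for each category to a `.pt` file. A downstream analysis compares the two files to find the features that differ between the categories.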

In [2]:
import csv
import json
import torch
import os
import shutil
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Dict, Tuple
from datasets import load_dataset
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import hf_hub_download
from sae_lens import SAE
from tqdm.auto import tqdm

# Pick the compute device: CUDA when PyTorch reports it available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


## Configs

In [3]:
from dataclasses import dataclass
from typing import Dict, Tuple, Optional

@dataclass
class ModelConfig:
    """Settings for one model family: checkpoints, chat header, token offsets, SAE lookup."""
    base_model_name: str
    chat_model_name: str
    hf_release: str  # Reference only - actual loading uses saelens_release/sae_id
    assistant_header: str
    token_offsets: Dict[str, int]
    sae_base_path: str
    saelens_release: str  # Template for sae_lens release parameter
    sae_id_template: str  # Template for sae_lens sae_id parameter

    def get_sae_params(self, sae_layer: int, sae_trainer: str) -> Tuple[str, str]:
        """
        Build the (release, sae_id) pair used to look up an SAE through sae_lens.

        The formatting rule is selected by the exact value of ``saelens_release``;
        only the three release strings handled below are recognised.

        Args:
            sae_layer: Layer number substituted into ``sae_id_template``.
            sae_trainer: Trainer identifier. For the Gemma releases it is split
                on "-" (e.g. "131k-l0-34" -> width "131k", l0 "34").

        Returns:
            Tuple of (release, sae_id).

        Raises:
            ValueError: if ``saelens_release`` is not one of the known releases.
        """
        release_name = self.saelens_release

        # Llama Scope: both the release and the id are templated on the trainer.
        if release_name == "llama_scope_lxr_{trainer}":
            return (
                release_name.format(trainer=sae_trainer),
                self.sae_id_template.format(layer=sae_layer, trainer=sae_trainer),
            )

        # Gemma Scope (explicit L0): id needs the width and the L0 value.
        if release_name == "gemma-scope-9b-pt-res":
            trainer_fields = sae_trainer.split("-")
            width = trainer_fields[0]      # "131k"
            l0_value = trainer_fields[2]   # "34"
            return (
                release_name,
                self.sae_id_template.format(layer=sae_layer, width=width, l0=l0_value),
            )

        # Gemma Scope (canonical): only the width is taken from the trainer string.
        if release_name == "gemma-scope-9b-pt-res-canonical":
            width = sae_trainer.split("-")[0]  # "131k"
            return release_name, self.sae_id_template.format(layer=sae_layer, width=width)

        raise ValueError(f"Unknown SAE lens release template: {self.saelens_release}")

In [4]:
# Model configurations, one entry per supported MODEL_TYPE value.
MODEL_CONFIGS = {}

# Llama 3.1 8B with the Llama Scope SAEs.
MODEL_CONFIGS["llama"] = ModelConfig(
    base_model_name="meta-llama/Llama-3.1-8B",
    chat_model_name="meta-llama/Llama-3.1-8B-Instruct",
    hf_release="fnlp/Llama3_1-8B-Base-LXR-32x",
    assistant_header="<|start_header_id|>assistant<|end_header_id|>",
    token_offsets={"asst": -2, "endheader": -1, "newline": 0},
    sae_base_path="/workspace/sae/llama-3.1-8b/saes",
    saelens_release="llama_scope_lxr_{trainer}",
    sae_id_template="l{layer}r_{trainer}",
)

# Gemma 2 9B with the canonical Gemma Scope residual-stream SAEs.
MODEL_CONFIGS["gemma"] = ModelConfig(
    base_model_name="google/gemma-2-9b",
    chat_model_name="google/gemma-2-9b-it",
    hf_release="google/gemma-scope-9b-pt-res/layer_{layer}/width_{width}/average_l0_{l0}",
    assistant_header="<start_of_turn>model",
    token_offsets={"model": -1, "newline": 0},
    sae_base_path="/workspace/sae/gemma-2-9b/saes",
    saelens_release="gemma-scope-9b-pt-res-canonical",
    sae_id_template="layer_{layer}/width_{width}/canonical",
)

# =============================================================================
# MODEL SELECTION - Configure for task category diffing
# =============================================================================
MODEL_TYPE = "gemma"  # Options: "gemma" or "llama"
MODEL_VER = "chat"
SAE_LAYER = 20
SAE_TRAINER = "131k-l0-114"

# =============================================================================
# CONFIGURATION SETUP
# =============================================================================
if MODEL_TYPE not in MODEL_CONFIGS:
    raise ValueError(f"Unknown MODEL_TYPE: {MODEL_TYPE}. Available: {list(MODEL_CONFIGS.keys())}")

config = MODEL_CONFIGS[MODEL_TYPE]

# Resolve the checkpoint to load for the requested version; reject anything else.
if MODEL_VER not in ("chat", "base"):
    raise ValueError(f"Unknown MODEL_VER: {MODEL_VER}. Use 'chat' or 'base'")
MODEL_NAME = config.chat_model_name if MODEL_VER == "chat" else config.base_model_name

# The tokenizer always comes from the chat model (it carries the chat template).
CHAT_MODEL_NAME = config.chat_model_name

# Shorthand for the per-model settings used throughout the notebook.
ASSISTANT_HEADER, TOKEN_OFFSETS, SAE_BASE_PATH = (
    config.assistant_header,
    config.token_offsets,
    config.sae_base_path,
)

# =============================================================================
# OUTPUT FILE CONFIGURATION
# =============================================================================
OUTPUT_DIR = f"/workspace/results/4_diffing_tasks/{MODEL_TYPE}_trainer{SAE_TRAINER}_layer{SAE_LAYER}"
os.makedirs(OUTPUT_DIR, exist_ok=True)  # side effect: creates the results directory

MEDICAL_OUTPUT_FILE = f"{OUTPUT_DIR}/medical.pt"
CODE_OUTPUT_FILE = f"{OUTPUT_DIR}/code.pt"

# =============================================================================
# DERIVED CONFIGURATIONS
# =============================================================================
SAE_PATH = f"{SAE_BASE_PATH}/resid_post_layer_{SAE_LAYER}/trainer_{SAE_TRAINER}"
LAYER_INDEX = SAE_LAYER  # index into model.model.layers whose output is captured

# Processing parameters
BATCH_SIZE = 32
MAX_LENGTH = 512

# =============================================================================
# SUMMARY
# =============================================================================
print("\n".join([
    "Configuration Summary:",
    f"  Model Type: {MODEL_TYPE}",
    f"  Model to load: {MODEL_NAME}",
    f"  Tokenizer (chat): {CHAT_MODEL_NAME}",
    f"  SAE Layer: {SAE_LAYER}, Trainer: {SAE_TRAINER}",
    f"  Available token types: {list(TOKEN_OFFSETS.keys())}",
    f"  Assistant header: {ASSISTANT_HEADER}",
    f"  Output files: {MEDICAL_OUTPUT_FILE}, {CODE_OUTPUT_FILE}",
]))

Configuration Summary:
  Model Type: gemma
  Model to load: google/gemma-2-9b-it
  Tokenizer (chat): google/gemma-2-9b-it
  SAE Layer: 20, Trainer: 131k-l0-114
  Available token types: ['model', 'newline']
  Assistant header: <start_of_turn>model
  Output files: /workspace/results/4_diffing_tasks/gemma_trainer131k-l0-114_layer20/medical.pt, /workspace/results/4_diffing_tasks/gemma_trainer131k-l0-114_layer20/code.pt


## Task Category Prompts

In [5]:
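# Load prompts from different task categories
# NOTE(review): no prompt-loading code is visible in this cell. Later cells expect
# `coding_prompts` and `medical_prompts` (lists of prompt strings) to be defined here - confirm.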


Loaded 4 coding prompts
Loaded 4 medical prompts
Total prompts to process: 8


## Load Model and SAE

In [6]:
# Load the tokenizer from the chat checkpoint so the chat template is available.
tokenizer = AutoTokenizer.from_pretrained(CHAT_MODEL_NAME)

# Batches are padded on the right; reuse EOS as the pad token when none is defined.
tokenizer.padding_side = "right"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Tokenizer loaded: {tokenizer.__class__.__name__}")

Tokenizer loaded: GemmaTokenizerFast


In [7]:
# Load the language model in bfloat16 with every module mapped to device index 0,
# then switch it to inference mode.
model_load_kwargs = {
    "torch_dtype": torch.bfloat16,
    "device_map": {"": 0},
}
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **model_load_kwargs)
model.eval()

print(f"Model loaded: {model.__class__.__name__}")
print(f"Model device: {next(model.parameters()).device}")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model loaded: Gemma2ForCausalLM
Model device: cuda:0


In [8]:
def load_sae(config: ModelConfig, sae_path: str, sae_layer: int, sae_trainer: str,
             device: Optional[str] = None) -> SAE:
    """
    Load an SAE from a local directory, downloading and caching it first if needed.

    If ``sae_path`` already contains ``sae_weights.safetensors`` the SAE is read
    from disk. Otherwise it is fetched through ``SAE.from_pretrained`` using the
    (release, sae_id) pair from ``config.get_sae_params`` and then written to
    ``sae_path`` so later runs take the local branch.

    Args:
        config: Model configuration providing ``get_sae_params``.
        sae_path: Directory holding (or to receive) the saved SAE.
        sae_layer: Layer number passed to ``config.get_sae_params``.
        sae_trainer: Trainer identifier passed to ``config.get_sae_params``.
        device: Device string used for the download branch only. Defaults to
            "cuda" when CUDA is available and "cpu" otherwise, so the function
            no longer fails on CPU-only machines.

    Returns:
        The loaded SAE. The caller is responsible for moving it to the final device.
    """
    # Check if SAE file exists locally
    sae_file_path = os.path.join(sae_path, "sae_weights.safetensors")

    if os.path.exists(sae_file_path):
        print(f"✓ Found SAE files at: {os.path.dirname(sae_file_path)}")
        return SAE.load_from_disk(sae_path)

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    print("SAE not found locally, downloading from HF via sae_lens...")
    # Create the parent directory; skip when sae_path has no parent component,
    # because os.makedirs("") raises.
    parent_dir = os.path.dirname(sae_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Get SAE parameters from config
    release, sae_id = config.get_sae_params(sae_layer, sae_trainer)
    print(f"Loading SAE with release='{release}', sae_id='{sae_id}'")

    # Load the SAE using sae_lens.
    # NOTE(review): this unpacks a 3-tuple, which matches the sae_lens version the
    # notebook was written against - confirm if sae_lens is upgraded.
    sae, _, sparsity = SAE.from_pretrained(
        release=release,
        sae_id=sae_id,
        device=device
    )

    # Save the SAE locally for future use
    sae.save_model(sae_path, sparsity)
    return sae

# Load the SAE (from the local cache when present) and place it on the compute device.
sae = load_sae(config, SAE_PATH, SAE_LAYER, SAE_TRAINER).to(device)

print(f"SAE loaded with {sae.cfg.d_sae} features")
print(f"SAE device: {next(sae.parameters()).device}")

✓ Found SAE files at: /workspace/sae/gemma-2-9b/saes/resid_post_layer_20/trainer_131k-l0-114
SAE loaded with 131072 features
SAE device: cuda:0


## Activation Extraction Functions

In [9]:
class StopForward(Exception):
    """Signal raised to abort a forward pass once the target layer has produced its output."""

def find_assistant_position(input_ids: torch.Tensor, attention_mask: torch.Tensor, 
                          assistant_header: str, token_offset: int, tokenizer, device) -> int:
    """
    Locate a token position relative to the end of the assistant header.

    The header string is tokenized (without special tokens) and the first
    occurrence of that token sequence is searched for in ``input_ids``. The
    result is ``match_start + len(header_tokens) + token_offset``, clamped to
    ``[0, last_attended_index]``, where ``last_attended_index`` is
    ``attention_mask.sum() - 1``. If the header is not found, or tokenizes to
    nothing, the last attended index is used.

    Args:
        input_ids: 1-D tensor of token ids for a single sequence.
        attention_mask: 1-D mask for the same sequence; its sum is taken as the
            number of real (non-padding) tokens.
        assistant_header: Header text to search for.
        token_offset: Offset added to the index just past the header.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``.
        device: Kept for backward compatibility; the header tensor is now built
            on ``input_ids``' own device so the comparison cannot hit a device
            mismatch.

    Returns:
        The clamped position as a Python int.
    """
    header_ids = tokenizer.encode(assistant_header, add_special_tokens=False)
    header_len = len(header_ids)

    # One mask reduction serves both as the fallback position and as the clamp bound.
    last_real_pos = int(attention_mask.sum().item()) - 1

    assistant_pos = last_real_pos
    if header_len > 0:
        # Build the comparison tensor once, outside the scan loop.
        header_tensor = torch.tensor(header_ids, dtype=input_ids.dtype, device=input_ids.device)
        for start in range(len(input_ids) - header_len + 1):
            if torch.equal(input_ids[start:start + header_len], header_tensor):
                assistant_pos = start + header_len + token_offset
                break

    # Ensure position is within bounds
    assistant_pos = max(min(assistant_pos, last_real_pos), 0)
    return int(assistant_pos)

In [10]:
@torch.no_grad()
def extract_activations_and_metadata(prompts: List[str], layer_idx: int) -> Tuple[torch.Tensor, List[Dict], List[str]]:
    """
    Capture one decoder layer's output at every token position for each prompt.

    Each prompt is wrapped as a single user turn with the tokenizer's chat
    template (generation prompt appended), tokenized in batches of BATCH_SIZE
    (padded, truncated to MAX_LENGTH) and passed to the model. A forward hook on
    ``model.model.layers[layer_idx]`` stores that layer's output and raises
    StopForward, so later layers are never run.

    Relies on notebook globals: model, tokenizer, device, BATCH_SIZE,
    MAX_LENGTH, TOKEN_OFFSETS and ASSISTANT_HEADER.

    Args:
        prompts: Raw user prompt strings; must be non-empty.
        layer_idx: Index into ``model.model.layers`` of the layer to capture.

    Returns:
        A tuple ``(activations, metadata, formatted_prompts)``:
        - activations: tensor stacked over prompts, zero-padded along the
          sequence axis to the longest batch, in the dtype of the captured output;
        - metadata: one dict per prompt with ``prompt_idx``, ``positions``
          (token type -> position), ``attention_mask`` and ``input_ids`` (CPU);
        - formatted_prompts: the chat-templated strings, in prompt order.

    Raises:
        ValueError: if ``prompts`` is empty.
        RuntimeError: if the hook captured nothing for a batch.
    """
    if len(prompts) == 0:
        raise ValueError("extract_activations_and_metadata() requires at least one prompt")

    all_activations = []
    all_metadata = []
    formatted_prompts_list = []

    # Filled by the hook on each forward pass and cleared before the next one.
    captured = {}

    def hook_fn(module, inputs, output):
        captured["hidden"] = output[0] if isinstance(output, tuple) else output
        raise StopForward()

    # Register the hook once for the whole run; `finally` guarantees removal.
    handle = model.model.layers[layer_idx].register_forward_hook(hook_fn)
    try:
        for i in tqdm(range(0, len(prompts), BATCH_SIZE), desc="Processing batches"):
            batch_prompts = prompts[i:i+BATCH_SIZE]

            # Format prompts as chat messages
            formatted_prompts = [
                tokenizer.apply_chat_template(
                    [{"role": "user", "content": prompt}],
                    tokenize=False, add_generation_prompt=True
                )
                for prompt in batch_prompts
            ]
            formatted_prompts_list.extend(formatted_prompts)

            # Tokenize batch
            # NOTE(review): the templated text may already start with a BOS token,
            # and tokenizing with default special-token handling could prepend a
            # second one. Left unchanged to keep outputs comparable with earlier
            # runs - confirm whether add_special_tokens=False is wanted.
            batch_inputs = tokenizer(
                formatted_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=MAX_LENGTH
            )
            batch_inputs = {k: v.to(device) for k, v in batch_inputs.items()}

            captured.clear()
            try:
                model(**batch_inputs)
            except StopForward:
                # Expected: the hook aborts the pass right after the target layer.
                pass

            activations = captured.get("hidden")
            if activations is None:
                raise RuntimeError(
                    f"Forward hook on layer {layer_idx} did not capture any activations"
                )

            # Process each prompt in the batch
            for j in range(len(formatted_prompts)):
                attention_mask = batch_inputs["attention_mask"][j]
                input_ids = batch_inputs["input_ids"][j]

                # Calculate positions for all token types
                positions = {
                    token_type: find_assistant_position(
                        input_ids, attention_mask, ASSISTANT_HEADER, token_offset, tokenizer, device
                    )
                    for token_type, token_offset in TOKEN_OFFSETS.items()
                }

                all_activations.append(activations[j].cpu())
                all_metadata.append({
                    'prompt_idx': i + j,
                    'positions': positions,
                    'attention_mask': attention_mask.cpu(),
                    'input_ids': input_ids.cpu()
                })
    finally:
        handle.remove()

    # Batches can differ in sequence length, so zero-pad to the longest one.
    # pad_sequence keeps the activations' own dtype, so the result dtype no longer
    # depends on whether any padding was needed.
    padded_activations = torch.nn.utils.rnn.pad_sequence(all_activations, batch_first=True)

    return padded_activations, all_metadata, formatted_prompts_list

In [11]:
@torch.no_grad()
def extract_token_activations(full_activations: torch.Tensor, metadata: List[Dict]) -> Dict[str, torch.Tensor]:
    """
    Select, for every prompt, the activation row at each recorded token position.

    Args:
        full_activations: Tensor indexed as ``[prompt, position, :]``.
        metadata: One dict per prompt; ``meta['positions']`` maps a token type
            to the position to read for that prompt.

    Returns:
        Dict with one entry per key of the global TOKEN_OFFSETS, each a tensor
        stacking the selected rows over prompts (prompt order preserved).
    """
    selected_rows = {name: [] for name in TOKEN_OFFSETS.keys()}

    for prompt_row, record in enumerate(metadata):
        for name, pos in record['positions'].items():
            selected_rows[name].append(full_activations[prompt_row, pos, :])

    return {
        name: torch.stack(selected_rows[name], dim=0)
        for name in TOKEN_OFFSETS.keys()
    }

## SAE Processing Functions

In [12]:
@torch.no_grad()
def get_sae_features_batched(activations: torch.Tensor) -> torch.Tensor:
    """
    Encode activations with the global SAE in chunks of BATCH_SIZE rows.

    The input is moved to the global ``device``. Each chunk is encoded with
    ``sae.encode`` and moved back to CPU, and the chunks are concatenated along
    the first axis.
    """
    on_device = activations.to(device)

    encoded_chunks = [
        sae.encode(on_device[start:start + BATCH_SIZE]).cpu()
        for start in range(0, on_device.shape[0], BATCH_SIZE)
    ]

    return torch.cat(encoded_chunks, dim=0)

In [13]:
def save_as_pt_cpu(token_features, category_name: str):
    """
    Build the per-feature summary statistics dict for one prompt category.

    For every key of the global TOKEN_OFFSETS, the matching tensor in
    ``token_features`` is cast to float and reduced over its first axis to give
    mean, standard deviation, max, the count of strictly positive entries, and
    that count divided by the number of rows. A ``'metadata'`` entry describing
    the run is added last.

    This function only builds and returns the dict; it does not write it to disk.

    Args:
        token_features: Mapping from token type to a feature tensor whose first
            axis is indexed by prompt and second axis by feature.
        category_name: Label stored in the metadata and in the source name.

    Returns:
        Dict with one statistics dict per token type plus ``'metadata'``.
    """
    source_name = f"{MODEL_TYPE}_trainer{SAE_TRAINER}_layer{SAE_LAYER}_{category_name}"

    print(f"Processing results for PyTorch format using CPU, source: {source_name}")

    results_dict = {}

    for token_type in TOKEN_OFFSETS.keys():
        print(f"\nProcessing token type: {token_type}")

        features_tensor = token_features[token_type].float()

        print(f"Processing all {features_tensor.shape[1]} features for token_type='{token_type}' on CPU")

        # A feature counts as active for a prompt when its value is strictly positive.
        num_active = (features_tensor > 0).sum(dim=0)

        results_dict[token_type] = dict(
            all_mean=features_tensor.mean(dim=0),
            all_std=features_tensor.std(dim=0),
            max=features_tensor.max(dim=0).values,
            num_active=num_active,
            sparsity=num_active.float() / features_tensor.shape[0],
        )

        print(f"Processed all {features_tensor.shape[1]} features for token_type='{token_type}'")

    # Metadata sizes are read from the tensor of the last token type processed above.
    results_dict['metadata'] = dict(
        source=source_name,
        model_type=MODEL_TYPE,
        category=category_name,
        sae_layer=SAE_LAYER,
        sae_trainer=SAE_TRAINER,
        num_prompts=features_tensor.shape[0],
        num_features=features_tensor.shape[1],
        token_types=list(TOKEN_OFFSETS.keys()),
    )

    # The 'metadata' entry is excluded from the count.
    print(f"\nTotal token types processed: {len(results_dict) - 1}")
    return results_dict

## Process Coding Prompts

In [14]:
section_rule = "=" * 60
print("\n" + section_rule)
print("PROCESSING CODING PROMPTS")
print(section_rule)

# Run the model up to LAYER_INDEX and keep that layer's output at every position
print("Extracting activations for all positions...")
coding_full_activations, coding_metadata, coding_formatted_prompts = extract_activations_and_metadata(
    coding_prompts, LAYER_INDEX
)
print(f"Coding full activations shape: {coding_full_activations.shape}")

# Keep only the rows at the configured token positions
print("\nExtracting activations for specific token positions...")
coding_token_activations = extract_token_activations(coding_full_activations, coding_metadata)

for token_type in coding_token_activations:
    print(f"Token type '{token_type}' activations shape: {coding_token_activations[token_type].shape}")

# Encode the selected activations with the SAE, one token type at a time
print("\nComputing SAE features for specific token positions...")
coding_token_features = {}
for token_type in coding_token_activations:
    print(f"Processing SAE features for token type '{token_type}'...")
    coding_token_features[token_type] = get_sae_features_batched(coding_token_activations[token_type])
    print(f"Features shape for '{token_type}': {coding_token_features[token_type].shape}")

print("\nCompleted SAE feature extraction for coding prompts")


PROCESSING CODING PROMPTS
Extracting activations for all positions...


Processing batches:   0%|          | 0/1 [00:00<?, ?it/s]

Coding full activations shape: torch.Size([4, 21, 3584])

Extracting activations for specific token positions...
Token type 'model' activations shape: torch.Size([4, 3584])
Token type 'newline' activations shape: torch.Size([4, 3584])

Computing SAE features for specific token positions...
Processing SAE features for token type 'model'...
Features shape for 'model': torch.Size([4, 131072])
Processing SAE features for token type 'newline'...
Features shape for 'newline': torch.Size([4, 131072])

Completed SAE feature extraction for coding prompts


## Process Medical Prompts

In [15]:
section_rule = "=" * 60
print("\n" + section_rule)
print("PROCESSING MEDICAL PROMPTS")
print(section_rule)

# Run the model up to LAYER_INDEX and keep that layer's output at every position
print("Extracting activations for all positions...")
medical_full_activations, medical_metadata, medical_formatted_prompts = extract_activations_and_metadata(
    medical_prompts, LAYER_INDEX
)
print(f"Medical full activations shape: {medical_full_activations.shape}")

# Keep only the rows at the configured token positions
print("\nExtracting activations for specific token positions...")
medical_token_activations = extract_token_activations(medical_full_activations, medical_metadata)

for token_type in medical_token_activations:
    print(f"Token type '{token_type}' activations shape: {medical_token_activations[token_type].shape}")

# Encode the selected activations with the SAE, one token type at a time
print("\nComputing SAE features for specific token positions...")
medical_token_features = {}
for token_type in medical_token_activations:
    print(f"Processing SAE features for token type '{token_type}'...")
    medical_token_features[token_type] = get_sae_features_batched(medical_token_activations[token_type])
    print(f"Features shape for '{token_type}': {medical_token_features[token_type].shape}")

print("\nCompleted SAE feature extraction for medical prompts")


PROCESSING MEDICAL PROMPTS
Extracting activations for all positions...


Processing batches:   0%|          | 0/1 [00:00<?, ?it/s]

Medical full activations shape: torch.Size([4, 20, 3584])

Extracting activations for specific token positions...
Token type 'model' activations shape: torch.Size([4, 3584])
Token type 'newline' activations shape: torch.Size([4, 3584])

Computing SAE features for specific token positions...
Processing SAE features for token type 'model'...
Features shape for 'model': torch.Size([4, 131072])
Processing SAE features for token type 'newline'...
Features shape for 'newline': torch.Size([4, 131072])

Completed SAE feature extraction for medical prompts


## Save Results

In [16]:
# Summarize the coding features and write them to CODE_OUTPUT_FILE
section_rule = "=" * 60
print("\n" + section_rule)
print("SAVING CODING RESULTS")
print(section_rule)

coding_results_dict = save_as_pt_cpu(coding_token_features, "code")
torch.save(coding_results_dict, CODE_OUTPUT_FILE)
print(f"\nCoding results saved to: {CODE_OUTPUT_FILE}")

# Preview the structure of what was written
print("\nCoding file structure:")
print(f"Keys: {list(coding_results_dict.keys())}")
print(f"Metadata: {coding_results_dict['metadata']}")

for token_type in TOKEN_OFFSETS.keys():
    token_stats = coding_results_dict[token_type]

    print(f"\n{token_type} statistics shapes:")
    for stat_name, tensor in token_stats.items():
        print(f"  {stat_name}: {tensor.shape}")

    # Min/max of selected statistics, each printed with its own number format
    print(f"\n{token_type} sample statistics:")
    for stat_name, number_format in (("all_mean", ".6f"), ("sparsity", ".6f"), ("num_active", ".0f")):
        tensor = token_stats[stat_name]
        print(f"  {stat_name} - min: {tensor.min():{number_format}}, max: {tensor.max():{number_format}}")


SAVING CODING RESULTS
Processing results for PyTorch format using CPU, source: gemma_trainer131k-l0-114_layer20_code

Processing token type: model
Processing all 131072 features for token_type='model' on CPU
Processed all 131072 features for token_type='model'

Processing token type: newline
Processing all 131072 features for token_type='newline' on CPU
Processed all 131072 features for token_type='newline'

Total token types processed: 2

Coding results saved to: /workspace/results/4_diffing_tasks/gemma_trainer131k-l0-114_layer20/code.pt

Coding file structure:
Keys: ['model', 'newline', 'metadata']
Metadata: {'source': 'gemma_trainer131k-l0-114_layer20_code', 'model_type': 'gemma', 'category': 'code', 'sae_layer': 20, 'sae_trainer': '131k-l0-114', 'num_prompts': 4, 'num_features': 131072, 'token_types': ['model', 'newline']}

model statistics shapes:
  all_mean: torch.Size([131072])
  all_std: torch.Size([131072])
  max: torch.Size([131072])
  num_active: torch.Size([131072])
  spar

In [17]:
# Summarize the medical features and write them to MEDICAL_OUTPUT_FILE
section_rule = "=" * 60
print("\n" + section_rule)
print("SAVING MEDICAL RESULTS")
print(section_rule)

medical_results_dict = save_as_pt_cpu(medical_token_features, "medical")
torch.save(medical_results_dict, MEDICAL_OUTPUT_FILE)
print(f"\nMedical results saved to: {MEDICAL_OUTPUT_FILE}")

# Preview the structure of what was written
print("\nMedical file structure:")
print(f"Keys: {list(medical_results_dict.keys())}")
print(f"Metadata: {medical_results_dict['metadata']}")

for token_type in TOKEN_OFFSETS.keys():
    token_stats = medical_results_dict[token_type]

    print(f"\n{token_type} statistics shapes:")
    for stat_name, tensor in token_stats.items():
        print(f"  {stat_name}: {tensor.shape}")

    # Min/max of selected statistics, each printed with its own number format
    print(f"\n{token_type} sample statistics:")
    for stat_name, number_format in (("all_mean", ".6f"), ("sparsity", ".6f"), ("num_active", ".0f")):
        tensor = token_stats[stat_name]
        print(f"  {stat_name} - min: {tensor.min():{number_format}}, max: {tensor.max():{number_format}}")


SAVING MEDICAL RESULTS
Processing results for PyTorch format using CPU, source: gemma_trainer131k-l0-114_layer20_medical

Processing token type: model
Processing all 131072 features for token_type='model' on CPU
Processed all 131072 features for token_type='model'

Processing token type: newline
Processing all 131072 features for token_type='newline' on CPU
Processed all 131072 features for token_type='newline'

Total token types processed: 2

Medical results saved to: /workspace/results/4_diffing_tasks/gemma_trainer131k-l0-114_layer20/medical.pt

Medical file structure:
Keys: ['model', 'newline', 'metadata']
Metadata: {'source': 'gemma_trainer131k-l0-114_layer20_medical', 'model_type': 'gemma', 'category': 'medical', 'sae_layer': 20, 'sae_trainer': '131k-l0-114', 'num_prompts': 4, 'num_features': 131072, 'token_types': ['model', 'newline']}

model statistics shapes:
  all_mean: torch.Size([131072])
  all_std: torch.Size([131072])
  max: torch.Size([131072])
  num_active: torch.Size([

In [18]:
# Closing report: what was processed and where the results were written.
wide_rule = "=" * 80
print("\n".join([
    "\n" + wide_rule,
    "TASK CATEGORY DIFFING ANALYSIS COMPLETE!",
    wide_rule,
    f"✓ Processed {len(coding_prompts)} coding prompts",
    f"✓ Processed {len(medical_prompts)} medical prompts",
    f"✓ Extracted SAE features from layer {SAE_LAYER}",
    f"✓ Saved coding results to: {CODE_OUTPUT_FILE}",
    f"✓ Saved medical results to: {MEDICAL_OUTPUT_FILE}",
    "\nBoth .pt files contain average SAE feature activations for assistant header token positions.",
    "Files are compatible with the existing 4_diffing analysis pipeline.",
    "\nNext steps: Use these files to identify features that differ between coding and medical tasks.",
]))


TASK CATEGORY DIFFING ANALYSIS COMPLETE!
✓ Processed 4 coding prompts
✓ Processed 4 medical prompts
✓ Extracted SAE features from layer 20
✓ Saved coding results to: /workspace/results/4_diffing_tasks/gemma_trainer131k-l0-114_layer20/code.pt
✓ Saved medical results to: /workspace/results/4_diffing_tasks/gemma_trainer131k-l0-114_layer20/medical.pt

Both .pt files contain average SAE feature activations for assistant header token positions.
Files are compatible with the existing 4_diffing analysis pipeline.

Next steps: Use these files to identify features that differ between coding and medical tasks.
