In [None]:



# CELL 1:  Imports and Setup
# CELL 2:  Configuration and Constants
# CELL 2.5: Utility Classes
# CELL 3:  CNN3Layer Model Class
# CELL 4:  TopK Sparse Autoencoder Class
# CELL 5:  Contrastive Feature Learner Class
# CELL 6:  Contrastive Loss Function
# CELL 7:  Concept Steering Vectors Class
# CELL 8:  Causal Neuron Analyzer Class
# CELL 9:  Feature Clustering Function
# CELL 10: Feature Classifier Class (Phase 1 - Traitors & Heroes)
# CELL 11: Grad-CAM Class (Phase 2)
# CELL 12: Structural Analyzer Class (Phase 3 - Circuits)
# CELL 13: Causal Intervener Class (Phase 4 - The Cure)
# CELL 14: Validation Function - Traitors with Grad-CAM
# CELL 15: Evaluation Function - SAE Reconstruction Quality
# CELL 16: Training Function - Concept Probes
# CELL 17: Main Execution Part 1 - Data Loading
# CELL 18: Main Execution Part 2 - Load Pre-trained Model
# CELL 19: Main Execution Part 3 - Train TopK Sparse Autoencoder
# CELL 20: Main Execution Part 4 - Concept Steering Vectors
# CELL 21: Main Execution Part 5 - Causal Neuron Ablation Analysis
# CELL 22: Main Execution Part 6 - Cluster SAE Features
# CELL 23: Main Execution Part 7 - PHASE 1 Feature Classification
# CELL 24: Main Execution Part 8 - PHASE 2 Grad-CAM Validation
# CELL 25: Main Execution Part 9 - Train Concept Probes
# CELL 26: Main Execution Part 10 - Visualization: Analysis Plots
# CELL 27: Main Execution Part 11 - Visualization: Steering Effect Plot
# CELL 28: Main Execution Part 12 - PHASE 3 Structural Analysis
# CELL 29: Main Execution Part 13 - PHASE 4 Causal Intervention
# CELL 30: Run Main Function




In [None]:
# ============================================================================
# CELL 1: Imports and Setup
# ============================================================================

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import os
import gc
import time
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Tuple, List, Dict, Optional, Any, Union
from functools import wraps
from tqdm import tqdm

# Logging: send INFO-and-above records, timestamped, to the stream handler
# (stderr by default), which Jupyter renders under the running cell.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler()],
)
# Module-level logger used by the functions and classes defined below.
logger = logging.getLogger(__name__)

# Timing decorator
def timing_decorator(func):
    """Wrap ``func`` so every successful call logs its elapsed time.

    Args:
        func: Callable to instrument.

    Returns:
        A wrapper with the same name/docstring (via ``functools.wraps``) that
        forwards all arguments, returns ``func``'s result unchanged, and emits
        one INFO record ``"<name> completed in <seconds>s"``.

    Note:
        If ``func`` raises, the exception propagates and nothing is logged.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock and can jump (NTP/DST), giving wrong durations.
        started = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - started
        logger.info(f"{func.__name__} completed in {elapsed:.2f}s")
        return result
    return wrapper



In [None]:
# ============================================================================
# CELL 2: Configuration and Constants
# ============================================================================

@dataclass
class Config:
    """Single source of truth for every tunable used in the SAE analysis."""

    # --- Input locations (Kaggle dataset / model mounts) ---
    train_data_path: str = '/kaggle/input/cmnistneo1/train_data_rg95z.npz'
    test_data_path: str = '/kaggle/input/cmnistneo1/test_data_gr95z.npz'
    model_path: str = '/kaggle/input/task1app3models/pytorch/default/2/task1approach3sc1_modelv1.pth'

    # --- TopK sparse autoencoder training ---
    topk_k: int = 32  # how many latent features stay active per sample
    sae_hidden_dim: int = 512
    sae_epochs: int = 20
    sae_batch_size: int = 64
    sae_learning_rate: float = 0.001

    # --- CNN architecture ---
    num_classes: int = 10
    conv1_channels: int = 32
    conv2_channels: int = 64
    conv3_channels: int = 64
    fc1_units: int = 128
    dropout_rate: float = 0.1

    # --- Analysis thresholds ---
    traitor_threshold: float = 2.0  # color/BW activation ratio above which a feature is a "traitor"
    hero_lower_bound: float = 0.8   # lower edge of the ratio band for shape-invariant "heroes"
    hero_upper_bound: float = 1.2   # upper edge of the ratio band for shape-invariant "heroes"
    activation_epsilon: float = 1e-4  # added to numerator/denominator to avoid division by zero
    min_activation: float = 0.01     # activations below this are treated as effectively dead
    correlation_threshold: float = 0.3  # threshold used for circuit detection
    dead_feature_threshold: float = 1e-5  # variance threshold for dead-feature detection
    circuit_threshold: float = 0.3  # backward-compat field; same default as correlation_threshold but set independently

    # --- Surgery / intervention ---
    traitor_suppression: float = 0.09  # multiplier applied to traitors (0.09 keeps 9%, i.e. ~91% suppression)
    max_hero_boost: float = 4.0  # cap on the multiplier applied to hero features

    # --- Memory management ---
    memory_limit_samples: int = 500
    batch_size_inference: int = 64

    # --- Compute device: 'cuda' when available at instantiation time, else 'cpu' ---
    device: str = field(default_factory=lambda: 'cuda' if torch.cuda.is_available() else 'cpu')

    def to_dict(self) -> Dict[str, Any]:
        """Return a shallow copy of the instance attributes as a plain dict."""
        return dict(vars(self))

# Shared configuration instance and the torch device derived from it.
config = Config()
device = torch.device(config.device)
logger.info(f"Using device: {device}")

# Legacy module-level aliases of Config fields, kept for backward
# compatibility (slated for removal). They are snapshots taken here: later
# changes to `config` do not propagate to these names.
TRAIN_DATA_PATH, TEST_DATA_PATH, MODEL_PATH = (
    config.train_data_path,
    config.test_data_path,
    config.model_path,
)
TOPK_K, SAE_HIDDEN_DIM = config.topk_k, config.sae_hidden_dim
SAE_EPOCHS, SAE_BATCH_SIZE, SAE_LEARNING_RATE = (
    config.sae_epochs,
    config.sae_batch_size,
    config.sae_learning_rate,
)
MEMORY_LIMIT_SAMPLES = config.memory_limit_samples
NUM_CLASSES = config.num_classes
CONV1_CHANNELS, CONV2_CHANNELS, CONV3_CHANNELS = (
    config.conv1_channels,
    config.conv2_channels,
    config.conv3_channels,
)
FC1_UNITS, DROPOUT_RATE = config.fc1_units, config.dropout_rate


In [None]:

# ============================================================================
# CELL 2.5: Utility Classes
# ============================================================================

class ActivationUtils:
    """Centralized activation computation utilities."""

    # Layer names accepted by get_activations_batched.
    _SUPPORTED_LAYERS = ('fc1', 'conv3')

    @staticmethod
    def get_activations_batched(
        model: nn.Module,
        images: np.ndarray,
        device: torch.device,
        batch_size: int = 64,
        layer: str = 'fc1',
        return_tensor: bool = False,
        show_progress: bool = False
    ) -> Union[np.ndarray, torch.Tensor]:
        """
        Run ``model`` over ``images`` in mini-batches and collect activations.

        The model is switched to eval mode (and left there) and the forward
        passes run under ``torch.no_grad()``.

        Args:
            model: Network exposing ``get_fc1_activations`` (for ``'fc1'``) or
                ``conv1..3`` / ``pool1..3`` (for ``'conv3'``).
            images: Channels-last array (N, H, W, C); converted to float32 and
                permuted to (N, C, H, W) before the forward pass.
            device: Device the batches are moved to.
            batch_size: Number of images per forward pass (must be >= 1).
            layer: ``'fc1'`` or ``'conv3'`` (output after the third pool).
            return_tensor: If True, return a torch.Tensor left on ``device``;
                otherwise each batch is moved to CPU and a numpy array is returned.
            show_progress: Show a tqdm progress bar.

        Returns:
            Activations concatenated along the batch axis.

        Raises:
            ValueError: unknown ``layer``, ``batch_size < 1`` or empty ``images``.
        """
        # Validate up front so bad arguments fail with a clear message instead
        # of surfacing mid-loop or as an opaque torch.cat error.
        if layer not in ActivationUtils._SUPPORTED_LAYERS:
            raise ValueError(f"Unknown layer: {layer}")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if len(images) == 0:
            raise ValueError("images is empty; cannot compute activations")

        model.eval()

        batch_starts = range(0, len(images), batch_size)
        if show_progress:
            batch_starts = tqdm(batch_starts, desc=f"Computing {layer} activations", leave=False)

        collected = []
        with torch.no_grad():
            for start in batch_starts:
                batch = torch.FloatTensor(images[start:start + batch_size])
                batch = batch.permute(0, 3, 1, 2).to(device)

                if layer == 'fc1':
                    acts = model.get_fc1_activations(batch)
                else:  # 'conv3': replay the convolutional trunk by hand
                    x = model.pool1(F.relu(model.conv1(batch)))
                    x = model.pool2(F.relu(model.conv2(x)))
                    acts = model.pool3(F.relu(model.conv3(x)))

                collected.append(acts if return_tensor else acts.cpu())

        stacked = torch.cat(collected, dim=0)

        # Release cached GPU blocks once at the end. Doing this per batch only
        # slows the loop: the caching allocator already reuses freed blocks.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return stacked if return_tensor else stacked.numpy()


class ActivationCache:
    """
    Memoize model activations per (image set, layer) to avoid recomputation.

    Entries are keyed by an integer fingerprint of the image array combined
    with the requested layer name, and hold whatever
    ``ActivationUtils.get_activations_batched`` returned for that request.
    """
    def __init__(self, model: nn.Module, device: torch.device):
        self.model = model
        self.device = device
        # fingerprint(images, layer) -> cached activations
        self.cache: Dict[int, np.ndarray] = {}

    def get_activations(
        self,
        images: np.ndarray,
        layer: str = 'fc1',
        batch_size: int = 64
    ) -> np.ndarray:
        """
        Return activations for ``images`` at ``layer``, computing them on a miss.

        Args:
            images: Image array forwarded to ``ActivationUtils.get_activations_batched``.
            layer: Layer name; part of the cache key so that different layers
                of the same images never share an entry.
            batch_size: Batch size used only when the entry must be computed.

        Returns:
            The cached activations (the same object on repeated hits).
        """
        # The layer must be part of the key: keying on the images alone would
        # hand back fc1 activations for a later conv3 request (and vice versa).
        key = hash((self._hash_images(images), layer))

        if key not in self.cache:
            self.cache[key] = ActivationUtils.get_activations_batched(
                self.model, images, self.device, batch_size, layer
            )

        return self.cache[key]

    def _hash_images(self, images: np.ndarray) -> int:
        """
        Cheap fingerprint of an image array.

        Hashes the shape, dtype and the raw bytes of a handful of evenly
        spaced images (always including the first and last). This is a
        fingerprint, not a full content hash: arrays that differ only in
        unsampled images still collide.
        """
        count = len(images)
        if count == 0:
            sampled = ()  # nothing to sample; shape/dtype still distinguish entries
        else:
            picks = sorted({0, count // 4, count // 2, (3 * count) // 4, count - 1})
            sampled = tuple(images[i].tobytes() for i in picks)
        return hash((images.shape, str(images.dtype), sampled))

    def clear(self):
        """Drop every cached entry and release Python/CUDA memory."""
        self.cache.clear()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


class PlotContext:
    """
    ``with``-block wrapper that guarantees a matplotlib figure gets closed.

    The figure is created eagerly in the constructor (all arguments are passed
    straight to ``plt.figure``), handed out by ``__enter__`` and closed, followed
    by a garbage-collection pass, when the block exits, with or without an error.
    """

    def __init__(self, *args, **kwargs):
        figure = plt.figure(*args, **kwargs)
        self.fig = figure

    def __enter__(self):
        # Expose the figure itself, not the context object.
        return self.fig

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so an exception raised inside the block is not swallowed.
        plt.close(self.fig)
        gc.collect()


def get_stratified_subset(
    images: np.ndarray,
    labels: np.ndarray,
    n_samples: int = 500,
    random_state: int = 42
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Draw a label-stratified random subset of ``n_samples`` examples.

    Sampling is delegated to scikit-learn's ``train_test_split`` with
    ``stratify=labels``, so the subset is stratified on the labels rather than
    being an arbitrary ``[:n_samples]`` slice that might miss classes.

    Args:
        images: Examples to sample from.
        labels: Labels aligned with ``images``; used for stratification.
        n_samples: Requested subset size.
        random_state: Seed forwarded to ``train_test_split``.

    Returns:
        ``(subset_images, subset_labels)``. When there are no more than
        ``n_samples`` examples the inputs are returned unchanged.
    """
    from sklearn.model_selection import train_test_split

    needs_sampling = len(images) > n_samples
    if not needs_sampling:
        return images, labels

    # train_test_split yields [X_train, X_test, y_train, y_test]; the "test"
    # halves (sized n_samples) are the subset we want.
    _, picked_images, _, picked_labels = train_test_split(
        images,
        labels,
        test_size=n_samples,
        stratify=labels,
        random_state=random_state,
    )
    return picked_images, picked_labels



In [None]:
# ============================================================================
# CELL 3: CNN3Layer Model Class
# ============================================================================



class CNN3Layer(nn.Module):
    """
    Three conv+ReLU+max-pool stages followed by a two-layer classifier head.

    Layer widths come from the module-level CONV*_CHANNELS / FC1_UNITS /
    DROPOUT_RATE constants. ``fc1`` expects a 3x3 spatial map after the three
    2x2 poolings, which matches 28x28 inputs (28 -> 14 -> 7 -> 3); presumably
    that is the intended input size (TODO confirm against the data loader).
    """

    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        # Convolutional trunk: each stage halves the spatial resolution.
        self.conv1 = nn.Conv2d(3, CONV1_CHANNELS, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(CONV1_CHANNELS, CONV2_CHANNELS, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(CONV2_CHANNELS, CONV3_CHANNELS, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)
        # Classifier head.
        self.fc1 = nn.Linear(CONV3_CHANNELS * 3 * 3, FC1_UNITS)
        self.dropout = nn.Dropout(DROPOUT_RATE)
        self.fc2 = nn.Linear(FC1_UNITS, num_classes)

    def _fc1_hidden(self, x):
        """Conv trunk -> flatten -> fc1 -> ReLU (shared by both public paths)."""
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in stages:
            x = pool(F.relu(conv(x)))
        flat = x.reshape(x.size(0), -1)
        return F.relu(self.fc1(flat))

    def forward(self, x):
        """Return class logits for a batch of (N, 3, H, W) images."""
        hidden = self._fc1_hidden(x)
        return self.fc2(self.dropout(hidden))

    def get_fc1_activations(self, x):
        """Get FC1 activations (post-ReLU, pre-dropout): the hidden state under analysis."""
        return self._fc1_hidden(x)



In [None]:
# ============================================================================
# CELL 4: TopK Sparse Autoencoder Class
# ============================================================================

class TopKSparseAutoencoder(nn.Module):
    """
    Linear autoencoder whose code keeps only the ``k`` largest pre-activations.

    Args:
        input_dim: Width of the vectors being reconstructed.
        hidden_dim: Number of latent features.
        k: Number of latent features kept per sample (1 <= k <= hidden_dim).

    Raises:
        ValueError: if ``k`` is outside ``[1, hidden_dim]``.
    """

    def __init__(self, input_dim, hidden_dim, k):
        super().__init__()
        # Fail fast: k > hidden_dim would only surface later as an opaque
        # torch.topk error, and k <= 0 would silently yield all-zero codes.
        if not 1 <= k <= hidden_dim:
            raise ValueError(
                f"k must satisfy 1 <= k <= hidden_dim ({hidden_dim}), got {k}"
            )
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.k = k

        # Learned encoder and decoder
        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, input_dim)

        # Unit-normalize each latent feature's decoder direction (the columns
        # of the (input_dim, hidden_dim) weight). Done once at construction;
        # nothing in this class re-normalizes during training.
        with torch.no_grad():
            self.decoder.weight.copy_(F.normalize(self.decoder.weight, dim=0))

    def encode(self, x):
        """
        Encode with the TopK sparsity constraint.

        Returns:
            ``(sparse_act, topk_indices)``: ``sparse_act`` has the shape of the
            encoder output and is zero except at the top-k positions, which
            hold ``relu`` of the selected pre-activations (so a selected but
            negative pre-activation still ends up as 0).
        """
        pre_act = self.encoder(x)

        # Select on raw pre-activations, then clamp the survivors at zero.
        topk_values, topk_indices = torch.topk(pre_act, self.k, dim=-1)

        sparse_act = torch.zeros_like(pre_act)
        sparse_act.scatter_(-1, topk_indices, F.relu(topk_values))

        return sparse_act, topk_indices

    def decode(self, z):
        """Map a latent code back to input space."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(reconstruction, sparse code, top-k indices)``."""
        z, indices = self.encode(x)
        x_recon = self.decode(z)
        return x_recon, z, indices
      

In [None]:
# ============================================================================
# CELL 5: Contrastive Feature Learner Class
# ============================================================================

class ContrastiveFeatureLearner(nn.Module):
    """
    Two independent MLP heads over the same input vector.

    ``shape_encoder`` and ``color_encoder`` share an architecture
    (input_dim -> 256 -> ReLU -> feature_dim) but not weights; what each head
    ends up encoding is determined by the training loss, not by this module.
    """

    _HIDDEN_WIDTH = 256  # width of the single hidden layer in each head

    def __init__(self, input_dim, feature_dim=64):
        super().__init__()
        self.feature_dim = feature_dim

        # Built in this order so parameter registration (and therefore random
        # initialisation order) stays shape-first, color-second.
        self.shape_encoder = self._make_head(input_dim, feature_dim)
        self.color_encoder = self._make_head(input_dim, feature_dim)

    @staticmethod
    def _make_head(input_dim, feature_dim):
        """Build one input_dim -> 256 -> feature_dim MLP head."""
        width = ContrastiveFeatureLearner._HIDDEN_WIDTH
        return nn.Sequential(
            nn.Linear(input_dim, width),
            nn.ReLU(),
            nn.Linear(width, feature_dim),
        )

    def forward(self, x):
        """Return ``(shape_features, color_features)`` for ``x``."""
        return self.shape_encoder(x), self.color_encoder(x)



In [None]:
# ============================================================================
# CELL 6: Contrastive Loss Function
# ============================================================================

def contrastive_loss(shape_features, color_features, labels, margin=1.0):
    """
    Pairwise cosine contrastive loss on the shape features.

    For every ordered pair (i, j) in the batch, including i == j:
    - same label:      contributes ``1 - cos(shape_i, shape_j)``
    - different label: contributes ``relu(cos(shape_i, shape_j) - margin)``
    The sum is divided by ``batch_size ** 2``.

    Args:
        shape_features: (batch, dim) tensor; rows are L2-normalized internally.
        color_features: Accepted for interface compatibility but currently
            unused; no color term enters the loss.
        labels: (batch,) tensor of class labels.
        margin: Cosine-similarity level above which different-class pairs are
            penalized.

    Returns:
        Scalar tensor with the shape loss.

    Note:
        Cosine similarity never exceeds 1, so with the default ``margin=1.0``
        the different-class term is zero (up to rounding) and only the
        same-class pull is active. Pass a smaller margin (e.g. 0.0) to
        actually push different classes apart.
    """
    batch_size = shape_features.size(0)

    # Cosine similarity between every pair of shape embeddings.
    shape_norm = F.normalize(shape_features, dim=1)
    shape_sim = torch.mm(shape_norm, shape_norm.t())

    # same_class[i, j] == 1.0 iff labels[i] == labels[j]
    labels = labels.view(-1, 1)
    same_class = (labels == labels.t()).float()

    pull_together = (1 - shape_sim) * same_class
    push_apart = F.relu(shape_sim - margin) * (1 - same_class)

    shape_loss = (pull_together.sum() + push_apart.sum()) / (batch_size * batch_size)

    return shape_loss



In [None]:

# ============================================================================
# CELL 7: Concept Steering Vectors Class
# ============================================================================

class ConceptSteeringVectors:
    """Compute and apply concept steering vectors for interventions."""

    def __init__(self, model: nn.Module, device: torch.device):
        self.model = model
        self.device = device
        # class index -> (env1 mean activation - env2 mean activation)
        self.color_vectors: Dict[int, np.ndarray] = {}
        # Declared for symmetry; nothing in this class populates it.
        self.shape_vectors: Dict[int, np.ndarray] = {}
        # 'env{1,2}_class{c}' -> mean activation of class c in that environment
        self.mean_activations: Dict[str, np.ndarray] = {}
        # Unit-norm average of the per-class color vectors (None until computed)
        self.global_color_vector: Optional[np.ndarray] = None

    def compute_steering_vectors(
        self,
        env1_images: np.ndarray,
        env1_labels: np.ndarray,
        env2_images: np.ndarray,
        env2_labels: np.ndarray
    ) -> np.ndarray:
        """
        Compute steering vectors from two different color environments.

        For every class present in both environments the color vector is the
        difference of the class-mean FC1 activations (env1 - env2). The global
        color vector is the mean of those differences, scaled to unit norm.

        Args:
            env1_images: Images from environment 1 (N, H, W, C)
            env1_labels: Labels for env1 (N,)
            env2_images: Images from environment 2 (M, H, W, C)
            env2_labels: Labels for env2 (M,)

        Returns:
            Global color steering vector with one entry per activation unit.
            It is all zeros when no class is shared or when the averaged
            difference has zero norm (e.g. identical environments).
        """
        self.model.eval()

        # Get activations for both environments using centralized utility
        logger.info("Computing activations for environment 1...")
        env1_acts = ActivationUtils.get_activations_batched(
            self.model, env1_images, self.device, show_progress=True
        )
        logger.info("Computing activations for environment 2...")
        env2_acts = ActivationUtils.get_activations_batched(
            self.model, env2_images, self.device, show_progress=True
        )

        env1_labels = np.asarray(env1_labels)
        env2_labels = np.asarray(env2_labels)

        self.color_vectors = {}
        self.mean_activations = {}

        # Only classes seen in both environments have a well-defined
        # difference; derive them from the labels instead of assuming 0..9.
        for class_idx in np.intersect1d(env1_labels, env2_labels):
            class_idx = int(class_idx)
            env1_mean = env1_acts[env1_labels == class_idx].mean(axis=0)
            env2_mean = env2_acts[env2_labels == class_idx].mean(axis=0)

            # Color vector: captures color changes for same shape
            self.color_vectors[class_idx] = env1_mean - env2_mean

            self.mean_activations[f'env1_class{class_idx}'] = env1_mean
            self.mean_activations[f'env2_class{class_idx}'] = env2_mean

        if self.color_vectors:
            mean_vec = np.stack(list(self.color_vectors.values())).mean(axis=0)
            norm = np.linalg.norm(mean_vec)
            if norm > 0:
                mean_vec = mean_vec / norm
            else:
                # Dividing by a zero norm would fill the vector with NaNs and
                # poison every later steering call.
                logger.warning("Global color vector has zero norm; using a zero vector")
                mean_vec = np.zeros_like(mean_vec)
            self.global_color_vector = mean_vec
        else:
            # Size the fallback from the activations actually produced.
            self.global_color_vector = np.zeros(env1_acts.shape[-1])

        logger.info(f"Computed color steering vectors for {len(self.color_vectors)} classes")
        logger.info(f"Global color vector norm: {np.linalg.norm(self.global_color_vector):.4f}")

        return self.global_color_vector

    def steer_activations(
        self,
        activations: torch.Tensor,
        direction: str = 'remove_color',
        strength: float = 1.0
    ) -> torch.Tensor:
        """
        Move activations along the global color direction.

        Args:
            activations: (..., units) tensor of FC1 activations.
            direction: ``'remove_color'`` subtracts the projection onto the
                color direction; any other value adds it.
            strength: Multiplier on the projection (1.0 with 'remove_color'
                removes the color component entirely).

        Returns:
            Steered activations, same shape as the input.

        Raises:
            RuntimeError: if ``compute_steering_vectors`` has not been run.
        """
        if self.global_color_vector is None:
            raise RuntimeError(
                "No steering vector available; call compute_steering_vectors first"
            )

        # Match the activations' dtype/device so the arithmetic below is valid
        # wherever the caller keeps its tensors.
        color_vec = torch.as_tensor(
            self.global_color_vector, dtype=activations.dtype, device=activations.device
        )

        # Project activations onto color direction
        color_projection = torch.sum(activations * color_vec, dim=-1, keepdim=True)
        shift = strength * color_projection * color_vec

        if direction == 'remove_color':
            return activations - shift
        return activations + shift

    def intervention_experiment(
        self,
        images: np.ndarray,
        labels: np.ndarray,
        strengths: Optional[List[float]] = None
    ) -> List[Dict]:
        """
        Test how removing color direction affects predictions.

        Args:
            images: Images to evaluate (N, H, W, C).
            labels: Ground-truth labels (N,).
            strengths: Steering strengths to try; defaults to [0.0, 0.5, 1.0, 2.0].

        Returns:
            One dict per strength with keys 'strength', 'accuracy' (percent)
            and 'predictions' (numpy array of predicted classes).
        """
        if strengths is None:
            strengths = [0.0, 0.5, 1.0, 2.0]

        self.model.eval()
        labels = np.asarray(labels)
        results = []

        # Batched like compute_steering_vectors, so a large image set is not
        # pushed through the network in a single forward pass.
        original_acts = ActivationUtils.get_activations_batched(
            self.model, images, self.device, return_tensor=True
        )

        with torch.no_grad():
            for strength in tqdm(strengths, desc="Testing intervention strengths"):
                # Steer activations
                steered_acts = self.steer_activations(original_acts, 'remove_color', strength)

                # Continue forward pass (dropout is the identity in eval mode)
                output = self.model.fc2(self.model.dropout(steered_acts))
                predictions = torch.argmax(output, dim=1).cpu().numpy()

                accuracy = (predictions == labels).mean() * 100

                results.append({
                    'strength': strength,
                    'accuracy': accuracy,
                    'predictions': predictions
                })

                logger.info(f"Steering strength {strength:.1f}: Accuracy = {accuracy:.2f}%")

        return results



In [None]:
# ============================================================================
# CELL 8: Causal Neuron Analyzer Class
# ============================================================================

class CausalNeuronAnalyzer:
    """Analyze causal effects of individual neurons through ablation."""

    def __init__(self, model: nn.Module, device: torch.device):
        self.model = model
        self.device = device
        # neuron index -> {'env1_effect', 'env2_effect', 'asymmetry'} (accuracy points)
        self.neuron_effects: Dict[int, Dict[str, float]] = {}
        # Threshold-based classifications filled in by find_causal_neurons.
        self.color_neurons: List[int] = []
        self.shape_neurons: List[int] = []

    def _accuracy_from_activations(self, activations: torch.Tensor, labels: np.ndarray) -> float:
        """Finish the forward pass (dropout -> fc2) and return accuracy in percent."""
        logits = self.model.fc2(self.model.dropout(activations))
        predictions = torch.argmax(logits, dim=1).cpu().numpy()
        return (predictions == labels).mean() * 100

    def ablate_neuron(
        self,
        activations: torch.Tensor,
        labels: np.ndarray,
        neuron_idx: int,
        ablation_type: str = 'zero'
    ) -> float:
        """
        Test model accuracy when a specific FC1 neuron is ablated.
        Uses pre-computed activations for speed; the input tensor is not modified.

        Args:
            activations: Pre-computed FC1 activations (N, units)
            labels: Ground truth labels (N,)
            neuron_idx: Index of neuron to ablate
            ablation_type: 'zero' sets the neuron to 0; any other value
                replaces it with its mean over the batch.

        Returns:
            Accuracy after ablation (percentage)
        """
        self.model.eval()

        with torch.no_grad():
            ablated_acts = activations.clone()
            if ablation_type == 'zero':
                ablated_acts[:, neuron_idx] = 0
            else:  # mean ablation
                ablated_acts[:, neuron_idx] = activations[:, neuron_idx].mean()

            # Only the classifier head needs to be re-run.
            return self._accuracy_from_activations(ablated_acts, labels)

    @timing_decorator
    def find_causal_neurons(
        self,
        env1_images: np.ndarray,
        env1_labels: np.ndarray,
        env2_images: np.ndarray,
        env2_labels: np.ndarray,
        num_neurons: int = 128
    ) -> Tuple[List, List]:
        """
        Find neurons with causal effects through systematic ablation.

        Each of the first ``num_neurons`` FC1 neurons is zero-ablated in turn
        and the accuracy drop in both environments is recorded in
        ``self.neuron_effects``. Threshold-based classifications are stored in
        ``self.color_neurons`` / ``self.shape_neurons``.

        Args:
            env1_images: Images from environment 1
            env1_labels: Labels for env1
            env2_images: Images from environment 2
            env2_labels: Labels for env2
            num_neurons: Number of neurons to analyze (clamped to the
                activation width)

        Returns:
            ``(top, bottom)``: the 10 highest- and 10 lowest-asymmetry entries
            of the ranking, each a ``(neuron_idx, effects_dict)`` pair, where
            asymmetry = env1 accuracy drop - env2 accuracy drop.
        """
        logger.info("Analyzing causal effect of each neuron...")

        self.model.eval()
        env1_labels = np.asarray(env1_labels)
        env2_labels = np.asarray(env2_labels)

        with torch.no_grad():
            # Pre-compute activations ONCE; every ablation reuses them.
            logger.info("Pre-computing activations for efficiency...")
            env1_acts = ActivationUtils.get_activations_batched(
                self.model, env1_images, self.device, return_tensor=True, show_progress=True
            )
            env2_acts = ActivationUtils.get_activations_batched(
                self.model, env2_images, self.device, return_tensor=True, show_progress=True
            )

            env1_baseline = self._accuracy_from_activations(env1_acts, env1_labels)
            env2_baseline = self._accuracy_from_activations(env2_acts, env2_labels)

            logger.info(f"Baseline accuracies: Env1={env1_baseline:.2f}%, Env2={env2_baseline:.2f}%")

            # Never index past the activation width (would raise IndexError).
            available = env1_acts.shape[1]
            if num_neurons > available:
                logger.warning(
                    f"num_neurons={num_neurons} exceeds activation width {available}; clamping"
                )
                num_neurons = available

            self.neuron_effects = {}
            color_neurons = []
            shape_neurons = []

            for neuron_idx in tqdm(range(num_neurons), desc="Analyzing neurons"):
                env1_ablated = self.ablate_neuron(env1_acts, env1_labels, neuron_idx)
                env2_ablated = self.ablate_neuron(env2_acts, env2_labels, neuron_idx)

                # Positive effect = accuracy dropped when the neuron was removed.
                env1_effect = env1_baseline - env1_ablated
                env2_effect = env2_baseline - env2_ablated

                self.neuron_effects[neuron_idx] = {
                    'env1_effect': env1_effect,
                    'env2_effect': env2_effect,
                    'asymmetry': env1_effect - env2_effect
                }

                # Color: removal costs env1 more than 2 points while env2 improves.
                # Shape: both environments move together (within 1 point) and
                # at least one loses more than 0.5 points.
                if env1_effect > 2 and env2_effect < 0:
                    color_neurons.append(neuron_idx)
                elif abs(env1_effect - env2_effect) < 1 and (env1_effect > 0.5 or env2_effect > 0.5):
                    shape_neurons.append(neuron_idx)

            # Keep the classifications reachable; they are not part of the return value.
            self.color_neurons = color_neurons
            self.shape_neurons = shape_neurons

            # Rank by asymmetry, most env1-specific first.
            sorted_neurons = sorted(
                self.neuron_effects.items(),
                key=lambda item: item[1]['asymmetry'],
                reverse=True,
            )

            logger.info("\nTop 10 COLOR-SPECIFIC neurons:")
            for neuron_idx, effects in sorted_neurons[:10]:
                logger.info(f"  Neuron {neuron_idx}: Env1 effect={effects['env1_effect']:.2f}%, "
                          f"Asymmetry={effects['asymmetry']:.2f}")

            return sorted_neurons[:10], sorted_neurons[-10:]




In [None]:
# ============================================================================
# CELL 9: Feature Clustering Function
# ============================================================================

@timing_decorator
def cluster_features_by_environment(model, sae, env1_data, env2_data, n_clusters=4):
    """
    Cluster SAE features based on cross-environment behavior.

    Each SAE feature is described by three numbers computed from its
    per-class mean activation in the two environments:
    [mean |env1 - env2| over classes, variance over classes in env1,
    variance over classes in env2]. KMeans groups features on those triples.

    Args:
        model: CNN model (put in eval mode)
        sae: Sparse autoencoder (put in eval mode)
        env1_data: (images, labels) tuple for environment 1
        env2_data: (images, labels) tuple for environment 2
        n_clusters: Number of clusters

    Returns:
        (clusters, cluster_info, feature_characteristics) tuple:
        per-feature cluster ids, a list with one summary dict per cluster, and
        the (n_features, 3) characteristics array.
    """
    model.eval()
    sae.eval()

    env1_images, env1_labels = env1_data
    env2_images, env2_labels = env2_data
    env1_labels = np.asarray(env1_labels)
    env2_labels = np.asarray(env2_labels)

    logger.info("Computing SAE encodings for both environments...")

    with torch.no_grad():
        env1_acts = ActivationUtils.get_activations_batched(
            model, env1_images, device, show_progress=True, return_tensor=True
        )
        env2_acts = ActivationUtils.get_activations_batched(
            model, env2_images, device, show_progress=True, return_tensor=True
        )

        # Sparse codes for every sample, moved to numpy for the statistics below.
        env1_z = sae.encode(env1_acts)[0].cpu().numpy()
        env2_z = sae.encode(env2_acts)[0].cpu().numpy()

    # Per-class mean activation of every SAE feature. A class absent from an
    # environment keeps an all-zero row.
    n_classes = 10
    n_features = env1_z.shape[1]
    env1_class_means = np.zeros((n_classes, n_features))
    env2_class_means = np.zeros((n_classes, n_features))

    for c in range(n_classes):
        env1_mask = env1_labels == c
        env2_mask = env2_labels == c

        if env1_mask.sum() > 0:
            env1_class_means[c] = env1_z[env1_mask].mean(axis=0)
        if env2_mask.sum() > 0:
            env2_class_means[c] = env2_z[env2_mask].mean(axis=0)

    # Feature characteristics, computed for all features at once:
    #   col 0: mean absolute cross-environment difference over classes
    #   col 1: variance over classes within env1 (selectivity)
    #   col 2: variance over classes within env2 (selectivity)
    feature_characteristics = np.column_stack([
        np.abs(env1_class_means - env2_class_means).mean(axis=0),
        env1_class_means.var(axis=0),
        env2_class_means.var(axis=0),
    ])

    logger.info(f"Clustering {n_features} features into {n_clusters} groups...")
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    clusters = kmeans.fit_predict(feature_characteristics)

    logger.info("\n=== Feature Clustering by Environment Behavior ===")
    cluster_info = []

    # Loop-invariant: the COLOR/SHAPE tag compares each cluster against the
    # median environment difference over all features.
    median_env_diff = np.median(feature_characteristics[:, 0])

    for cluster_idx in range(n_clusters):
        cluster_mask = clusters == cluster_idx
        cluster_features = np.where(cluster_mask)[0]

        mean_env_diff = feature_characteristics[cluster_mask, 0].mean()
        mean_selectivity = (feature_characteristics[cluster_mask, 1] +
                            feature_characteristics[cluster_mask, 2]).mean() / 2

        cluster_info.append({
            'cluster': cluster_idx,
            'n_features': len(cluster_features),
            'mean_env_diff': mean_env_diff,
            'mean_selectivity': mean_selectivity,
            'features': cluster_features[:10]  # First 10
        })

        cluster_type = "COLOR" if mean_env_diff > median_env_diff else "SHAPE"
        logger.info(f"\nCluster {cluster_idx} ({cluster_type}):")
        logger.info(f"  {len(cluster_features)} features")
        logger.info(f"  Mean env difference: {mean_env_diff:.4f}")
        logger.info(f"  Mean selectivity: {mean_selectivity:.4f}")
        logger.info(f"  Example features: {cluster_features[:10]}")

    return clusters, cluster_info, feature_characteristics



In [None]:
# ============================================================================
# CELL 10: Feature Classifier Class (Phase 1 - Traitors & Heroes)
# ============================================================================

# ==========================================
# PHASE 1: Feature Identification (Traitors & Heroes)
# ==========================================

class FeatureClassifier:
    """Phase 1: label SAE features as color-sensitive "Traitors" or color-invariant "Heroes".

    Every feature gets a sensitivity index: its mean activation on the colored
    images divided by its mean activation on grayscale copies of the same
    images. All thresholds are read from the module-level ``config`` object.

    Attributes:
        traitors: ``(feature_idx, score)`` pairs, highest score first.
        heroes: ``(feature_idx, score)`` pairs, score closest to 1.0 first.
        sensitivity_scores: Per-feature index array, or ``None`` until
            ``classify_features`` has been run.
    """

    def __init__(self, model, sae, device):
        self.model = model
        self.sae = sae
        self.device = device
        self.traitors = []
        self.heroes = []
        self.sensitivity_scores = None

    def _extract_latents(self, img_data: np.ndarray, batch_size: int) -> np.ndarray:
        """Push ``img_data`` (N, H, W, C) through model -> SAE and stack the latents."""
        z_list = []
        with torch.no_grad():
            for i in tqdm(range(0, len(img_data), batch_size),
                          desc="Extracting latents", leave=False):
                batch = torch.FloatTensor(img_data[i:i+batch_size]).permute(0, 3, 1, 2).to(self.device)
                acts = self.model.get_fc1_activations(batch)
                # The SAE call returns a 3-tuple; its middle element is used as the latent code.
                _, z, _ = self.sae(acts)
                z_list.append(z.cpu().numpy())
        return np.concatenate(z_list, axis=0)

    def classify_features(self, images: np.ndarray, batch_size: int = 64) -> Tuple[List, List]:
        """
        Classify SAE features based on sensitivity to color.
        Generates B&W counterparts in-memory to ensure perfect pairing.

        Args:
            images: RGB images (N, H, W, 3)
            batch_size: Batch size for processing

        Returns:
            (traitors, heroes) tuple of (feature_idx, score) lists

        Raises:
            ValueError: If ``images`` is empty or not 4-dimensional.
        """
        if images.ndim != 4 or len(images) == 0:
            raise ValueError(
                f"images must be a non-empty (N, H, W, 3) array, got shape {images.shape}")

        logger.info("Running Feature Sensitivity Analysis...")
        self.model.eval()
        self.sae.eval()

        # 1. Prepare Data Pairs (Color vs BW): channel mean replicated on all three channels
        images_bw = images.mean(axis=3, keepdims=True).repeat(3, axis=3)

        # 2. Extract SAE Latents for both
        logger.info("Extracting latents for colored images...")
        z_color = self._extract_latents(images, batch_size)
        logger.info("Extracting latents for B&W counterparts...")
        z_bw = self._extract_latents(images_bw, batch_size)

        # 3. Calculate Sensitivity Index (epsilon keeps the ratio finite for silent features)
        mean_act_color = z_color.mean(axis=0)
        mean_act_bw = z_bw.mean(axis=0)

        self.sensitivity_scores = (mean_act_color + config.activation_epsilon) / \
                                 (mean_act_bw + config.activation_epsilon)

        # 4. Classification (Vectorized); features that barely fire on color images are ignored
        active_mask = mean_act_color > config.min_activation
        traitor_mask = (self.sensitivity_scores > config.traitor_threshold) & active_mask
        hero_mask = ((self.sensitivity_scores >= config.hero_lower_bound) &
                     (self.sensitivity_scores <= config.hero_upper_bound) & active_mask)

        # Convert to list of tuples
        self.traitors = list(zip(
            np.where(traitor_mask)[0].tolist(),
            self.sensitivity_scores[traitor_mask].tolist()
        ))
        self.heroes = list(zip(
            np.where(hero_mask)[0].tolist(),
            self.sensitivity_scores[hero_mask].tolist()
        ))

        # Sort by score: strongest traitors first, heroes closest to a ratio of 1 first
        self.traitors.sort(key=lambda x: x[1], reverse=True)
        self.heroes.sort(key=lambda x: abs(1-x[1]))

        logger.info(f"Found {len(self.traitors)} Traitors (Color-Obsessed)")
        logger.info(f"Found {len(self.heroes)} Heroes (Shape-Invariant)")

        if self.traitors:
            logger.info(f"Top Traitor: Feature {self.traitors[0][0]} (Score {self.traitors[0][1]:.2f})")
        if self.heroes:
            logger.info(f"Top Hero: Feature {self.heroes[0][0]} (Score {self.heroes[0][1]:.2f})")

        return self.traitors, self.heroes

    def plot_sensitivity_analysis(self, save_path: str = 'feature_sensitivity.png'):
        """Generate sensitivity analysis visualization.

        Left panel: per-feature sensitivity as a pixel grid. Right panel:
        log-scale histogram annotated with the traitor/hero counts produced by
        ``classify_features``. The figure is written to ``save_path``.
        """
        if self.sensitivity_scores is None:
            logger.warning("No sensitivity scores to plot")
            return

        scores = self.sensitivity_scores
        n_feats = len(scores)

        # Reshape for grid (Target 16x32 for 512 features, else square)
        if n_feats == 512:
            grid = scores.reshape(16, 32)
        else:
            side = int(np.ceil(np.sqrt(n_feats)))
            # Pad with NaN so filler cells are drawn blank instead of looking
            # like real features with a sensitivity of 0.
            grid = np.full(side*side, np.nan)
            grid[:n_feats] = scores
            grid = grid.reshape(side, side)

        fig, axes = plt.subplots(1, 2, figsize=(16, 6))

        try:
            # 1. Pixel Grid (Red=Traitor, Blue=Hero)
            im = axes[0].imshow(grid, cmap='coolwarm', vmin=0, vmax=3.0, aspect='auto')
            axes[0].set_title(f"Feature Sensitivity Map ({n_feats} Features)\nRed = Traitors (>{config.traitor_threshold}) | Blue = Heroes (~1.0)")
            axes[0].axis('off')
            plt.colorbar(im, ax=axes[0], label='Sensitivity Index')

            # 2. Histogram with Thresholds
            # Report the counts from classify_features so the annotation always
            # agrees with the classification (same bounds, same activity mask).
            n_traitors = len(self.traitors)
            n_heroes = len(self.heroes)

            axes[1].hist(scores, bins=50, color='gray', alpha=0.7, log=True)

            # Thresholds
            axes[1].axvline(config.traitor_threshold, color='red', linestyle='--', linewidth=2)
            axes[1].axvline(1.0, color='blue', linestyle='--', linewidth=2)

            # Add Text Annotations
            ymin, ymax = axes[1].get_ylim()
            text_y = ymax * 0.5

            axes[1].text(config.traitor_threshold + 0.1, text_y,
                        f'Traitors\n(>{config.traitor_threshold})\nn={n_traitors}',
                        color='red', fontweight='bold', ha='left')
            axes[1].text(0.9, text_y, f'Heroes\n(~1.0)\nn={n_heroes}',
                        color='blue', fontweight='bold', ha='right')

            axes[1].set_title("Sensitivity Distribution")
            axes[1].set_xlabel("Sensitivity Index (Color/BW)")
            axes[1].set_ylabel("Count (Log)")

            plt.tight_layout()
            plt.savefig(save_path)
            logger.info(f"Saved sensitivity plot to {save_path}")
            plt.show()
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)
            gc.collect()



In [None]:
# ============================================================================
# CELL 11: Grad-CAM Class (Phase 2)
# ============================================================================

# ==========================================
# PHASE 2: Grad-CAM Implementation
# ==========================================

class GradCAM:
    """Grad-CAM heatmaps for one layer of a model.

    A forward hook stores the layer's output and a backward hook stores the
    gradient arriving at that output; ``generate_heatmap`` combines the two.
    The hooks stay attached to ``target_layer`` until ``remove_hooks`` is
    called, so release them when done (the instance also works as a context
    manager that removes the hooks on exit).
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.handle_fwd = None
        self.handle_bwd = None
        self._register_hooks()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.remove_hooks()
        return False  # never swallow exceptions

    def _register_hooks(self):
        """Attach the activation (forward) and gradient (backward) capture hooks."""
        def save_activation(module, input, output):
            self.activations = output.detach()

        def save_gradient(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()

        self.handle_fwd = self.target_layer.register_forward_hook(save_activation)
        # Use register_full_backward_hook for newer pytorch, or register_backward_hook for older
        try:
            self.handle_bwd = self.target_layer.register_full_backward_hook(save_gradient)
        except AttributeError:
            self.handle_bwd = self.target_layer.register_backward_hook(save_gradient)

    def remove_hooks(self):
        """Detach both hooks from the target layer; safe to call more than once."""
        for attr in ('handle_fwd', 'handle_bwd'):
            handle = getattr(self, attr)
            if handle is not None:
                handle.remove()
                setattr(self, attr, None)

    def generate_heatmap(self, input_tensor, target_class_idx=None):
        """
        Generate Grad-CAM heatmap for a specific target class.

        Args:
            input_tensor: A single image as a (1, C, H, W) tensor.
            target_class_idx: Class whose score is explained. ``None`` selects
                the model's own top prediction.

        Returns:
            (heatmap, target_class_idx): ``heatmap`` is an (H, W) numpy array
            min-max scaled into [0, 1], at the spatial size of ``input_tensor``.

        Raises:
            RuntimeError: If the hooks captured nothing, e.g. they were removed
                or the target layer is not used by ``self.model``.
        """
        self.model.eval()
        self.model.zero_grad()

        # Drop whatever a previous call captured so stale data is never reused.
        self.activations = None
        self.gradients = None

        # Gradients are required even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            # Forward pass
            output = self.model(input_tensor)

            # Target specific class
            if target_class_idx is None:
                target_class_idx = output.argmax(dim=1).item()

            score = output[:, target_class_idx]

            # Backward pass
            score.sum().backward()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks did not capture activations/gradients; "
                "were the hooks removed or is the target layer unused by the model?")

        # Generate CAM
        # Global average pooling of gradients -> one weight per channel
        weights = torch.mean(self.gradients, dim=(2, 3), keepdim=True)

        # Weighted combination of activations; keepdim leaves the (N, 1, h, w)
        # layout interpolate needs, even when the feature map is 1x1.
        cam = torch.sum(weights * self.activations, dim=1, keepdim=True)

        # ReLU
        cam = F.relu(cam)

        # Resize to the input's own spatial size (no fixed 28x28 assumption)
        cam = F.interpolate(cam, size=tuple(input_tensor.shape[-2:]),
                            mode='bilinear', align_corners=False)

        # Normalize
        cam = cam[0, 0].cpu().numpy()
        cam = cam - np.min(cam)
        cam = cam / (np.max(cam) + 1e-8)

        return cam, target_class_idx



In [None]:

# ============================================================================
# CELL 12: Structural Analyzer Class (Phase 3 - Circuits)
# ============================================================================

# ==========================================
# PHASE 3: Structural Analysis (Circuits)
# ==========================================

class StructuralAnalyzer:
    """Phase 3: correlation structure ("circuits") among SAE features.

    ``analyze_correlations`` builds a Pearson correlation matrix over the
    features whose variance exceeds ``config.dead_feature_threshold``.
    ``find_circuits`` and ``plot_circuit_analysis`` then look up
    traitor/hero pairs in that matrix.

    Attributes:
        correlation_matrix: Square matrix over the active features, or ``None``
            before ``analyze_correlations`` has run.
        active_indices: Original feature index for each matrix row/column, or
            ``None`` before ``analyze_correlations`` has run.
    """

    def __init__(self, sae, device):
        self.sae = sae
        self.device = device
        self.correlation_matrix = None
        self.active_indices = None

    def _index_map(self) -> Dict[int, int]:
        """Map an original feature index to its row/column in ``correlation_matrix``."""
        return {int(orig): new for new, orig in enumerate(self.active_indices)}

    @timing_decorator
    def analyze_correlations(self, images: np.ndarray, model: nn.Module,
                           batch_size: int = 64) -> np.ndarray:
        """Compute correlation matrix of SAE features across dataset.

        Args:
            images: Images as an (N, H, W, C) array.
            model: Network exposing ``get_fc1_activations``.
            batch_size: Number of images per forward pass.

        Returns:
            The (n_active, n_active) correlation matrix; also stored on
            ``self.correlation_matrix``. It is an empty (0, 0) array when no
            feature passes the variance filter.

        Raises:
            ValueError: If ``images`` is empty.
        """
        if len(images) == 0:
            raise ValueError("images must contain at least one sample")

        logger.info("Running Structural Analysis (Polysemanticity check)")
        model.eval()
        self.sae.eval()

        # Get all latent activations
        all_z = []
        with torch.no_grad():
            for i in tqdm(range(0, len(images), batch_size), desc="Computing correlations"):
                batch = torch.FloatTensor(images[i:i+batch_size]).permute(0, 3, 1, 2).to(self.device)
                acts = model.get_fc1_activations(batch)
                _, z, _ = self.sae(acts)
                all_z.append(z.cpu().numpy())

        Z = np.concatenate(all_z, axis=0)  # Shape: (N, HiddenDim)

        # Filter dead neurons to avoid NaN correlations
        active_indices = np.where(Z.var(axis=0) > config.dead_feature_threshold)[0]
        self.active_indices = active_indices

        logger.info(f"Computing correlation matrix for {len(active_indices)} active features")
        if len(active_indices) == 0:
            # Nothing to correlate; keep a well-formed empty matrix so the
            # lookups in find_circuits / plot_circuit_analysis simply find nothing.
            logger.warning("No active features passed the variance filter")
            self.correlation_matrix = np.zeros((0, 0))
            return self.correlation_matrix

        Z_active = Z[:, active_indices]
        # Add small epsilon noise to break perfect symmetries if any
        Z_active = Z_active + np.random.normal(0, 1e-9, Z_active.shape)

        # corrcoef collapses to a scalar for a single feature; keep it 2-D so
        # [row, col] indexing keeps working.
        self.correlation_matrix = np.atleast_2d(np.corrcoef(Z_active, rowvar=False))
        return self.correlation_matrix

    def find_circuits(self, traitors: List[Tuple[int, float]],
                     heroes: List[Tuple[int, float]]) -> List[Dict]:
        """Find high correlations between Traitors (Color) and Heroes (Shape).

        Args:
            traitors: ``(feature_idx, score)`` pairs.
            heroes: ``(feature_idx, score)`` pairs.

        Returns:
            One dict per traitor/hero pair whose correlation exceeds
            ``config.circuit_threshold``, strongest first. Pairs involving a
            feature outside the correlation matrix are skipped. Empty when
            ``analyze_correlations`` has not been run.
        """
        if self.correlation_matrix is None:
            logger.warning("Run analyze_correlations first")
            return []

        logger.info("Scanning for 'Circuits' (Traitor-Hero Pairs)")

        # Create map from original index to active matrix index
        idx_map = self._index_map()

        circuits = []
        threshold = config.circuit_threshold

        # Resolve hero positions once instead of once per traitor.
        mapped_heroes = [(h_idx, h_score, idx_map[h_idx])
                         for h_idx, h_score in heroes if h_idx in idx_map]

        for t_idx, t_score in traitors:
            if t_idx not in idx_map:
                continue
            t_row = idx_map[t_idx]

            for h_idx, h_score, h_col in mapped_heroes:
                corr = self.correlation_matrix[t_row, h_col]

                if corr > threshold:
                    circuits.append({
                        'Traitor': t_idx, 'Traitor_Score': t_score,
                        'Hero': h_idx, 'Hero_Score': h_score,
                        'Correlation': corr
                    })

        # Sort by correlation strength
        circuits.sort(key=lambda x: x['Correlation'], reverse=True)

        logger.info(f"Found {len(circuits)} potential circuits (Correlation > {threshold})")
        for i, c in enumerate(circuits[:5]):
            logger.info(f"  Circuit {i}: Traitor #{c['Traitor']} (Color) <--> Hero #{c['Hero']} (Shape) | r={c['Correlation']:.4f}")

        return circuits

    def plot_circuit_analysis(self, traitors: List[Tuple[int, float]],
                             heroes: List[Tuple[int, float]],
                             save_path: str = 'circuit_analysis.png'):
        """
        Visualize the interaction between Top Traitors and Top Heroes.
        Plots a correlation heatmap for the subset of features.

        Only the first 20 traitors and first 20 heroes are drawn. Cells for
        features outside the correlation matrix are left blank (NaN) rather
        than shown as a zero correlation. The figure is saved to ``save_path``.
        """
        if self.correlation_matrix is None:
            logger.warning("No correlation matrix to plot")
            return

        # Filter for Top 20 of each for visibility
        t_indices = [t[0] for t in traitors[:20]]
        h_indices = [h[0] for h in heroes[:20]]

        if not t_indices or not h_indices:
            logger.warning("No traitor or hero features to plot")
            return

        # Map original feature indices to the active_indices used in correlation matrix
        idx_map = self._index_map()

        matrix_subset = np.full((len(t_indices), len(h_indices)), np.nan)

        # Fill subset matrix
        for r, t_idx in enumerate(t_indices):
            if t_idx not in idx_map:
                continue
            for c, h_idx in enumerate(h_indices):
                if h_idx in idx_map:
                    matrix_subset[r, c] = self.correlation_matrix[idx_map[t_idx], idx_map[h_idx]]

        fig, ax = plt.subplots(figsize=(10, 8))

        try:
            im = ax.imshow(matrix_subset, cmap='coolwarm', vmin=-1, vmax=1)
            fig.colorbar(im, ax=ax, label='Pearson Correlation')

            # Axis labels
            ax.set_xticks(np.arange(len(h_indices)))
            ax.set_yticks(np.arange(len(t_indices)))
            ax.set_xticklabels([f"H{h}" for h in h_indices], rotation=45, ha='right')
            ax.set_yticklabels([f"T{t}" for t in t_indices])

            ax.set_xlabel("Heroes (Shape)")
            ax.set_ylabel("Traitors (Color)")
            ax.set_title("Circuit Map: Color-Shape Dependencies\nRed = Positive Correlation (Co-occurence)")

            plt.tight_layout()
            plt.savefig(save_path)
            logger.info(f"Saved Circuit Analysis to {save_path}")
            plt.show()
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)
            gc.collect()


In [None]:

# ============================================================================
# CELL 13: Causal Intervener Class (Phase 4 - The Cure)
# ============================================================================

# ==========================================
# PHASE 4: Causal Intervention (The Cure)
# ==========================================

class CausalIntervener:
    """Phase 4: try to fix misclassifications by editing SAE latents.

    The intervention scales down the "traitor" latents, scales up the "hero"
    latents, rescales the edited code back to its original L2 norm, decodes it
    to FC1 space and finishes the forward pass through ``dropout`` and ``fc2``.
    """

    # Multiplier for traitor latents: soft suppression, because zeroing them
    # completely creates unnatural activation patterns.
    TRAITOR_SCALE = 0.1
    # Upper bound applied to the hero boost factor.
    MAX_BOOST = 3.0

    def __init__(self, model, sae, device):
        self.model = model
        self.sae = sae
        self.device = device

    @staticmethod
    def _edit_latents(z, t_idx, h_idx, boost_factor):
        """Suppress traitor latents, boost hero latents, and restore each row's L2 norm.

        Written without in-place edits of ``z`` so it works both under
        ``torch.no_grad()`` and inside a model that Grad-CAM back-propagates
        through, and for any batch size.

        Args:
            z: Latent codes, (batch, n_features).
            t_idx: Long tensor of traitor feature indices (may be empty).
            h_idx: Long tensor of hero feature indices (may be empty).
            boost_factor: Hero multiplier; capped at ``MAX_BOOST``.
        """
        scale = torch.ones(z.shape[1], dtype=z.dtype, device=z.device)
        if t_idx.numel() > 0:
            scale[t_idx] *= CausalIntervener.TRAITOR_SCALE
        if h_idx.numel() > 0:
            scale[h_idx] *= min(boost_factor, CausalIntervener.MAX_BOOST)
        z_edited = z * scale

        # Magnitude preservation: rescale to the original L2 norm to stay close
        # to the activation range the decoder was fitted on. An all-zero row
        # stays all-zero, so no special case is needed.
        original_norm = z.norm(p=2, dim=1, keepdim=True)
        edited_norm = z_edited.norm(p=2, dim=1, keepdim=True)
        return z_edited * (original_norm / (edited_norm + 1e-8))

    def find_failure_cases(self, images: np.ndarray, labels: np.ndarray,
                          batch_size: int = 64) -> List[Dict]:
        """Identify images where the model fails (Predictions != Labels).

        Returns:
            One dict per misclassified sample with keys ``index`` (position in
            ``images``), ``image``, ``label`` and ``pred``.
        """
        self.model.eval()
        failures = []

        with torch.no_grad():
            for i in tqdm(range(0, len(images), batch_size), desc="Finding failures"):
                batch_imgs = images[i:i+batch_size]
                batch_lbls = labels[i:i+batch_size]

                tensor = torch.FloatTensor(batch_imgs).permute(0, 3, 1, 2).to(self.device)
                output = self.model(tensor)
                preds = output.argmax(dim=1).cpu().numpy()

                # Check for errors
                for idx, (p, l) in enumerate(zip(preds, batch_lbls)):
                    if p != l:
                        failures.append({
                            'index': i + idx,
                            'image': batch_imgs[idx],
                            'label': l,
                            'pred': p
                        })

        logger.info(f"Found {len(failures)} failure cases in {len(images)} samples")
        return failures

    def perform_surgery(self, failure_cases: List[Dict],
                       traitors: List[Tuple[int, float]],
                       heroes: List[Tuple[int, float]],
                       boost_factor: float = 2.0,
                       verbose: bool = True) -> Tuple[int, List[Dict]]:
        """Attempt to cure failure cases by suppressing Traitors and boosting Heroes.

        Args:
            failure_cases: Dicts as produced by ``find_failure_cases``.
            traitors: ``(feature_idx, score)`` pairs to suppress.
            heroes: ``(feature_idx, score)`` pairs to boost.
            boost_factor: Hero multiplier (capped at ``MAX_BOOST``).
            verbose: Log progress and the first few cured cases.

        Returns:
            ``(n_cured, cured_cases)`` where ``cured_cases`` are the input
            dicts whose post-intervention prediction equals their label.
        """
        if verbose:
            logger.info(f"Performing Surgery with Boost Factor {boost_factor}x")

        # The tail of the network is replayed through model.dropout below; in
        # training mode that would randomly drop units and make the outcome
        # non-deterministic, so force inference mode.
        self.model.eval()
        self.sae.eval()

        success_count = 0
        cured_cases_list = []

        # Index tensors for the features to edit
        t_tensor = torch.LongTensor([t[0] for t in traitors]).to(self.device)
        h_tensor = torch.LongTensor([h[0] for h in heroes]).to(self.device)

        for case in tqdm(failure_cases, desc="Performing surgery", disable=not verbose):
            img_tensor = torch.FloatTensor(case['image']).unsqueeze(0).permute(0, 3, 1, 2).to(self.device)
            target_label = case['label']
            original_pred = case['pred']

            # Custom forward pass with intervention
            with torch.no_grad():
                # A. Get FC1 activations
                fc1_act = self.model.get_fc1_activations(img_tensor)

                # B. Encode into SAE latents (first element of the returned pair)
                z, _ = self.sae.encode(fc1_act)

                # C. THE SURGERY
                z_edited = self._edit_latents(z, t_tensor, h_tensor, boost_factor)

                # D. Decode back to FC1
                fc1_recon = self.sae.decode(z_edited)

                # E. Finish Network Pass (Dropout + FC2)
                out = self.model.fc2(self.model.dropout(fc1_recon))
                new_pred = out.argmax(dim=1).item()

            # Check success
            if new_pred == target_label:
                success_count += 1
                cured_cases_list.append(case)
                if success_count <= 3 and verbose:
                    logger.info(f"[CURED] Corrected Label {target_label} (Was {original_pred} -> Now {new_pred})")

        cure_rate = success_count/len(failure_cases)*100 if failure_cases else 0
        if verbose:
            logger.info(f"Surgery Results: Cured {success_count}/{len(failure_cases)} ({cure_rate:.1f}%)")
        return success_count, cured_cases_list

    def sweep_intervention_strength(self, failures: List[Dict],
                                   traitors: List[Tuple[int, float]],
                                   heroes: List[Tuple[int, float]],
                                   factors: List[float] = None) -> List[int]:
        """Sweep multiple boost factors to find optimal intervention strength.

        Runs ``perform_surgery`` once per factor, plots cured-case counts
        against the factor (saved as ``intervention_sweep.png``) and returns
        the counts in the order of ``factors``. Factors above ``MAX_BOOST``
        behave like ``MAX_BOOST`` because ``perform_surgery`` caps them.
        """
        if factors is None:
            # Use more conservative range with softer interventions
            factors = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0]

        logger.info("Sweeping Intervention Strengths (Boost Factors)")
        results = []
        for factor in tqdm(factors, desc="Testing boost factors"):
            n_cured, _ = self.perform_surgery(failures, traitors, heroes, boost_factor=factor, verbose=False)
            results.append(n_cured)
            logger.info(f"Factor {factor}x: Cured {n_cured}/{len(failures)}")

        fig, ax = plt.subplots(figsize=(8, 5))
        try:
            ax.plot(factors, results, 'o-', linewidth=2)
            ax.set_xlabel("Hero Boost Factor (Multiplier)")
            ax.set_ylabel("Number of Cured Cases")
            ax.set_title("Intervention Efficacy vs Strength")
            ax.grid(True, alpha=0.3)
            fig.savefig('intervention_sweep.png')
            logger.info("Saved sweep plot to intervention_sweep.png")
            plt.show()
        finally:
            plt.close(fig)
            gc.collect()

        return results

    def visualize_cures(self, cured_cases: List[Dict], model: nn.Module,
                       sae: nn.Module, traitors: List[Tuple[int, float]],
                       heroes: List[Tuple[int, float]], device: torch.device,
                       save_path: str = 'surgery_validation.png',
                       boost_factor: float = 2.0):
        """Visualize the effect of surgery using Grad-CAM on the Intervened Model.

        For up to five cured cases, draws the input, the Grad-CAM of the
        original model for its wrong prediction, and the Grad-CAM of the
        intervened model for the true label. The figure is saved to
        ``save_path``.

        Args:
            cured_cases: Cases returned by ``perform_surgery``.
            model: Original network (must expose conv1-3, pool1-3, fc1, fc2, dropout).
            sae: Sparse autoencoder with ``encode`` / ``decode``.
            traitors: ``(feature_idx, score)`` pairs to suppress.
            heroes: ``(feature_idx, score)`` pairs to boost.
            device: Device the tensors are created on.
            save_path: Output image path.
            boost_factor: Hero multiplier used by the intervened model; pass the
                value the cases were cured with.
        """
        if not cured_cases:
            return

        logger.info(f"Visualizing surgery effects on {min(5, len(cured_cases))} cured cases")

        # 1. Define Intervened Model Wrapper
        # This allows Grad-CAM to forward pass through the surgery logic
        class IntervenedModel(nn.Module):
            """Original network with the latent intervention spliced in after FC1."""

            def __init__(self, original_model, sae_model, t_idxs, h_idxs, boost):
                super().__init__()
                self.model = original_model
                self.sae = sae_model
                self.t_idxs = torch.LongTensor(t_idxs).to(device)
                self.h_idxs = torch.LongTensor(h_idxs).to(device)
                self.boost = boost

            def forward(self, x):
                # Manual Forward of CNN3Layer until FC1
                x = self.model.pool1(F.relu(self.model.conv1(x)))
                x = self.model.pool2(F.relu(self.model.conv2(x)))
                x = self.model.pool3(F.relu(self.model.conv3(x))) # Conv3 Hook fires here
                x = x.reshape(x.size(0), -1)
                fc1_act = F.relu(self.model.fc1(x))

                # Same edit as perform_surgery, so the heatmap explains the
                # intervention that actually produced the cure.
                z, _ = self.sae.encode(fc1_act)
                z_edited = CausalIntervener._edit_latents(z, self.t_idxs, self.h_idxs, self.boost)
                fc1_recon = self.sae.decode(z_edited)

                # Finish
                return self.model.fc2(self.model.dropout(fc1_recon))

        # Setup Models
        t_indices = [t[0] for t in traitors]
        h_indices = [h[0] for h in heroes]

        wrapped_model = IntervenedModel(model, sae, t_indices, h_indices, boost_factor).to(device)
        wrapped_model.eval()

        n_show = min(len(cured_cases), 5)

        # Both models share the same 'model.conv3' layer object, so a single
        # GradCAM instance hooked on it serves both; only its .model attribute
        # is swapped between passes.
        grad_cam = None
        fig = None
        try:
            grad_cam = GradCAM(model, model.conv3) # Attached to conv3
            fig, axes = plt.subplots(n_show, 3, figsize=(12, 4*n_show), squeeze=False)

            for i in range(n_show):
                case = cured_cases[i]
                img = case['image']
                img_tensor = torch.FloatTensor(img).unsqueeze(0).permute(0, 3, 1, 2).to(device)

                # A. PRE-SURGERY (Original Model)
                grad_cam.model = model # Point to original
                # TARGET: The WRONG prediction (we want to see why it was fooled)
                hm_bad, _ = grad_cam.generate_heatmap(img_tensor, case['pred'])

                # B. POST-SURGERY (Intervened Model)
                grad_cam.model = wrapped_model # Point to wrapper (Hooks still on conv3)
                # TARGET: The CORRECT label (which is now the prediction)
                hm_good, _ = grad_cam.generate_heatmap(img_tensor, case['label'])

                # Plot
                ax0, ax1, ax2 = axes[i]

                ax0.imshow(img)
                ax0.set_title(f"Input (True: {case['label']})")
                ax0.axis('off')

                ax1.imshow(img)
                ax1.imshow(hm_bad, cmap='jet', alpha=0.5)
                ax1.set_title(f"Original (Pred: {case['pred']})\nConfusion")
                ax1.axis('off')

                ax2.imshow(img)
                ax2.imshow(hm_good, cmap='jet', alpha=0.5)
                ax2.set_title(f"Cured (Pred: {case['label']})\nFocus")
                ax2.axis('off')

            plt.tight_layout()
            plt.savefig(save_path)
            logger.info(f"Saved Surgery Validation to {save_path}")
            plt.show()
        finally:
            # Hooks live on the shared conv3 layer; leaving them behind after
            # an error would keep firing on every later forward pass.
            if grad_cam is not None:
                grad_cam.remove_hooks()
            if fig is not None:
                plt.close(fig)
            gc.collect()

In [None]:

# ============================================================================
# CELL 14: Validation Function - Traitors with Grad-CAM
# ============================================================================

def validate_traitors_with_gradcam(model: nn.Module, sae: nn.Module,
                                   traitors: List[Tuple[int, float]],
                                   images: np.ndarray, labels: np.ndarray,
                                   device: torch.device, n_examples: int = 3):
    """Phase 2 Validation: Run Grad-CAM on images where Traitor features are active.

    For each of the first three traitors, finds the ``n_examples`` images (among
    the first 1000) on which the feature activates most and draws each image
    next to its Grad-CAM heatmap for the model's predicted class. The figure is
    saved as ``traitor_gradcam_validation.png``.

    Args:
        model: Network exposing ``get_fc1_activations`` and a ``conv3`` layer.
        sae: Sparse autoencoder; its call returns a 3-tuple whose middle
            element is used as the latent code.
        traitors: ``(feature_idx, score)`` pairs, strongest first.
        images: Images as an (N, H, W, C) array.
        labels: True labels aligned with ``images``.
        device: Device used for inference.
        n_examples: Number of top-activating images shown per traitor.
    """
    if not traitors:
        logger.warning("No traitors to validate")
        return

    logger.info("Phase 2: Verifying Traitor Features with Grad-CAM")

    model.eval()
    sae.eval()

    # Select top few traitors
    top_traitors = [t[0] for t in traitors[:3]]

    # Pre-compute activations to find max activating images for these traitors.
    # This runs BEFORE the Grad-CAM hooks are attached so the forward hook does
    # not cache the conv activations of the whole subset batch.
    subset_size = min(len(images), 1000)
    subset_images = images[:subset_size]
    subset_labels = labels[:subset_size]

    subset_tensor = torch.FloatTensor(subset_images).permute(0, 3, 1, 2).to(device)
    with torch.no_grad():
        fc1_acts = model.get_fc1_activations(subset_tensor)
        _, sae_acts, _ = sae(fc1_acts)
        sae_acts = sae_acts.cpu().numpy()

    grad_cam = None
    fig = None
    try:
        # Initialize GradCAM on the last conv layer (inside the try so the
        # hooks are removed no matter where a failure happens).
        grad_cam = GradCAM(model, model.conv3)

        fig, axes = plt.subplots(len(top_traitors), n_examples * 2,
                                 figsize=(3 * n_examples * 2, 3 * len(top_traitors)),
                                 squeeze=False)
        # Panels left unused (fewer images than n_examples) stay blank.
        for ax in axes.ravel():
            ax.axis('off')

        for row_idx, trait_idx in enumerate(top_traitors):
            # Find images where this traitor is most active
            feature_acts = sae_acts[:, trait_idx]
            top_img_indices = np.argsort(feature_acts)[-n_examples:][::-1]

            for col_idx, img_idx in enumerate(top_img_indices):
                img = subset_images[img_idx]
                label = subset_labels[img_idx]
                activation_val = feature_acts[img_idx]

                img_tensor = torch.FloatTensor(img).unsqueeze(0).permute(0, 3, 1, 2).to(device)

                # Run Grad-CAM (None -> explain the model's own prediction)
                heatmap, pred_class = grad_cam.generate_heatmap(img_tensor, None)

                # Visualization
                ax_orig = axes[row_idx, col_idx * 2]
                ax_cam = axes[row_idx, col_idx * 2 + 1]

                ax_orig.imshow(img)
                ax_orig.set_title(f"Traitor {trait_idx}\nAct: {activation_val:.2f}\nPred: {pred_class} (True {label})")
                ax_orig.axis('off')

                ax_cam.imshow(img)
                ax_cam.imshow(heatmap, cmap='jet', alpha=0.5)
                ax_cam.set_title(f"Grad-CAM\nDoes it look at color?")
                ax_cam.axis('off')

        plt.tight_layout()
        plt.savefig('traitor_gradcam_validation.png')
        logger.info("Saved Grad-CAM validation to traitor_gradcam_validation.png")
        plt.show()
    finally:
        if grad_cam is not None:
            grad_cam.remove_hooks()
        if fig is not None:
            plt.close(fig)
        gc.collect()


In [None]:
# ============================================================================
# CELL 15: Evaluation Function - SAE Reconstruction Quality
# ============================================================================

def evaluate_reconstruction(model: nn.Module, sae: nn.Module,
                           data: Tuple[np.ndarray, np.ndarray],
                           device: torch.device):
    """Evaluate SAE reconstruction quality.

    Encodes/decodes the FC1 activations of the first (up to) 1000 images and
    logs the reconstruction MSE next to the average per-sample L2 norm of the
    activations, so the error can be judged against the activation scale.
    A warning is logged when the MSE exceeds 0.1.

    Args:
        model: Network exposing ``get_fc1_activations``.
        sae: Sparse autoencoder; the first element of its output is the reconstruction.
        data: ``(images, labels)`` pair; only ``images`` (N, H, W, C) is used.
        device: Device used for inference.
    """
    images, _ = data
    model.eval()
    sae.eval()

    subset_size = min(len(images), 1000)
    subset = torch.FloatTensor(images[:subset_size]).permute(0, 3, 1, 2).to(device)

    with torch.no_grad():
        fc1_acts = model.get_fc1_activations(subset)
        recon_acts, _, _ = sae(fc1_acts)

        mse = F.mse_loss(recon_acts, fc1_acts).item()
        # Norm of each sample's activation vector, then averaged over samples.
        # (A norm over the whole tensor would grow with the number of samples
        # and is not an "average activation norm".)
        l2_norm = fc1_acts.flatten(start_dim=1).norm(p=2, dim=1).mean().item()

    logger.info(f"SAE Quality Check - Reconstruction MSE: {mse:.6f} (vs Avg Activation Norm: {l2_norm:.4f})")
    if mse > 0.1:
        logger.warning("SAE Reconstruction is poor. Interventions may be unreliable.")
    else:
        logger.info("✓ SAE Reconstruction looks reasonable")



In [None]:
# ============================================================================
# CELL 16: Training Function - Concept Probes
# ============================================================================

@timing_decorator
def train_concept_probes(model: nn.Module, images: np.ndarray,
                        labels: np.ndarray, device: torch.device,
                        batch_size: int = 256) -> Tuple[float, float]:
    """Train concept probes to measure shape vs color bias.

    Fits two logistic-regression probes on the FC1 activations: one predicting
    ``labels`` (shape) and one predicting each image's dominant color channel.
    Both accuracies are measured on the same data the probes were fitted on,
    so they indicate linear separability, not generalization.

    Args:
        model: Network exposing ``get_fc1_activations``.
        images: Images as an (N, H, W, 3) array.
        labels: Class label per image.
        device: Device used for inference.
        batch_size: Images per forward pass when extracting activations.

    Returns:
        ``(shape_acc, color_acc)``. An entry is ``nan`` when its target has
        fewer than two distinct classes (the probe cannot be fitted).
    """
    model.eval()

    # Get activations in batches so large datasets do not have to fit on the
    # device in a single forward pass.
    chunks = []
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            batch = torch.FloatTensor(images[start:start + batch_size]).permute(0, 3, 1, 2).to(device)
            chunks.append(model.get_fc1_activations(batch).cpu().numpy())
    activations = np.concatenate(chunks, axis=0)

    # Detect dominant color from images
    mean_colors = images.mean(axis=(1, 2))  # Shape: (N, 3)
    dominant_color_idx = mean_colors.argmax(axis=1)  # 0=R, 1=G, 2=B

    def probe_accuracy(targets, probe_name):
        """Fit a linear probe on the activations and return its accuracy on the same data."""
        # LogisticRegression refuses to fit single-class targets (e.g. all
        # images grayscale -> dominant color is always channel 0).
        if np.unique(targets).size < 2:
            logger.warning(f"{probe_name} probe skipped: targets contain a single class")
            return float('nan')
        probe = LogisticRegression(max_iter=1000, random_state=42)
        probe.fit(activations, targets)
        return probe.score(activations, targets)

    # Train probes
    # 1. Shape probe (predict digit)
    shape_acc = probe_accuracy(labels, "Shape")

    # 2. Color probe (predict dominant color)
    color_acc = probe_accuracy(dominant_color_idx, "Color")

    logger.info("=== Linear Probe Analysis ===")
    logger.info(f"Shape (digit) probe accuracy: {shape_acc*100:.2f}%")
    logger.info(f"Color probe accuracy: {color_acc*100:.2f}%")

    if np.isnan(shape_acc) or np.isnan(color_acc):
        logger.warning("Probe comparison skipped: at least one probe could not be trained")
    elif color_acc > shape_acc:
        logger.warning("Color is MORE linearly separable than shape!")
        logger.info("→ The model represents color more explicitly than shape")
    else:
        logger.info("✓ Shape is more linearly separable than color")
        logger.info("→ The model represents shape more explicitly")

    return shape_acc, color_acc




In [None]:
# ============================================================================
# CELL 17: Main Execution Part 1 - Data Loading
# ============================================================================

def main():
    """Run the full Task 6 interpretability pipeline end to end.

    Stages, in order: load both environments, load the pre-trained CNN, train
    the TopK sparse autoencoder on fc1 activations, compute concept steering
    vectors, run causal neuron ablation, cluster SAE features, classify
    features (Phase 1), validate with Grad-CAM (Phase 2), train linear concept
    probes, draw the analysis and steering plots, run structural analysis
    (Phase 3) and finally causal intervention (Phase 4).

    Side effects: writes ``advanced_sae_analysis.png``,
    ``color_steering_effect.png``, ``topk_sae_model.pth`` and
    ``color_steering_vector.npy`` to the working directory (the helper
    functions called along the way may write further files).

    Returns:
        None. Returns early, after logging an error, when either dataset or
        the model checkpoint cannot be found.
    """
    logger.info("="*70)
    logger.info("Task 6: ADVANCED DECOMPOSITION - SOTA Interpretability Techniques")
    logger.info("="*70)
    
    # Load data
    logger.info("[1/7] Loading data")
    
    def load_data(path: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Load ``images`` (scaled to [0, 1] float32) and ``labels`` from an archive.

        Returns ``(None, None)`` after logging a warning if ``path`` is missing.
        """
        if not os.path.exists(path):
            logger.warning(f"{path} not found")
            return None, None
        # The archive object keeps a file handle open; close it once both
        # arrays have been materialised.
        with np.load(path) as data:
            return data['images'].astype('float32') / 255.0, data['labels']
    
    env1_images, env1_labels = load_data(TRAIN_DATA_PATH)
    env2_images, env2_labels = load_data(TEST_DATA_PATH)
    
    # Both environments are required below; a missing test set would otherwise
    # crash on the first attribute access.
    if env1_images is None or env2_images is None:
        logger.error("Data not found. Please update paths")
        return
    
    logger.info(f"Environment 1: {env1_images.shape}")
    logger.info(f"Environment 2: {env2_images.shape}")

# ============================================================================
# CELL 18: Main Execution Part 2 - Load Pre-trained Model
# ============================================================================

    # Load model
    logger.info("[2/7] Loading biased model")
    model = CNN3Layer().to(device)
    if os.path.exists(MODEL_PATH):
        # Load the state dict
        # NOTE(review): torch.load unpickles the file, so only load checkpoints
        # from a trusted source.
        state_dict = torch.load(MODEL_PATH, map_location=device)
        
        # Create key mapping for different naming conventions
        # Old model: features.0, features.3, features.6, classifier.0, classifier.3
        # New model: conv1, conv2, conv3, fc1, fc2
        key_mapping = {
            'features.0.weight': 'conv1.weight',
            'features.0.bias': 'conv1.bias',
            'features.3.weight': 'conv2.weight',
            'features.3.bias': 'conv2.bias',
            'features.6.weight': 'conv3.weight',
            'features.6.bias': 'conv3.bias',
            'classifier.0.weight': 'fc1.weight',
            'classifier.0.bias': 'fc1.bias',
            'classifier.3.weight': 'fc2.weight',
            'classifier.3.bias': 'fc2.bias',
        }
        
        # Remap keys if needed; keys without a mapping are kept unchanged
        new_state_dict = {
            key_mapping.get(old_key, old_key): value
            for old_key, value in state_dict.items()
        }
        
        model.load_state_dict(new_state_dict)
        logger.info(f"Loaded model from {MODEL_PATH}")
    else:
        logger.error("Model not found!")
        return
    model.eval()
# ============================================================================
# CELL 19: Main Execution Part 3 - Train TopK Sparse Autoencoder
# ============================================================================
    
    # TECHNIQUE 1: TopK Sparse Autoencoder
    logger.info("[3/7] Training TopK Sparse Autoencoder")
    logger.info(f"Architecture: 128 -> {SAE_HIDDEN_DIM} (TopK={TOPK_K}) -> 128")
    
    with torch.no_grad():
        activations = ActivationUtils.get_activations_batched(
            model, env1_images, device, return_tensor=True, show_progress=True
        )
    
    gc.collect()
    torch.cuda.empty_cache()
    
    topk_sae = TopKSparseAutoencoder(FC1_UNITS, SAE_HIDDEN_DIM, TOPK_K).to(device)
    optimizer = optim.Adam(topk_sae.parameters(), lr=SAE_LEARNING_RATE)
    
    dataset = TensorDataset(activations)
    loader = DataLoader(dataset, batch_size=SAE_BATCH_SIZE, shuffle=True)
    
    for epoch in tqdm(range(SAE_EPOCHS), desc="Training SAE"):
        epoch_loss = 0
        for batch in loader:
            x = batch[0].to(device)
            x_recon, z, _ = topk_sae(x)
            loss = F.mse_loss(x_recon, x)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Re-normalise the decoder weights along dim 0 after every
            # optimizer step.
            with torch.no_grad():
                topk_sae.decoder.weight.data = F.normalize(topk_sae.decoder.weight.data, dim=0)
            
            epoch_loss += loss.item()
        
        # Report the mean batch loss every 10th epoch
        if (epoch + 1) % 10 == 0:
            logger.info(f"Epoch {epoch+1}/{SAE_EPOCHS}: Loss = {epoch_loss/len(loader):.6f}")
    
# ============================================================================
# CELL 20: Main Execution Part 4 - Concept Steering Vectors
# ============================================================================

    logger.info("[4/7] Computing Concept Steering Vectors")
    steering_vectors = ConceptSteeringVectors(model, device)
    color_vector = steering_vectors.compute_steering_vectors(env1_images, env1_labels, 
                                            env2_images, env2_labels)

    
    logger.info("Testing color removal intervention")
    steering_vectors.intervention_experiment(env2_images[:MEMORY_LIMIT_SAMPLES], env2_labels[:MEMORY_LIMIT_SAMPLES])
    

    
# ============================================================================
# CELL 21: Main Execution Part 5 - Causal Neuron Ablation Analysis
# ============================================================================

    logger.info("[5/7] Performing Causal Neuron Ablation Analysis")
    gc.collect()
    torch.cuda.empty_cache()
    neuron_analyzer = CausalNeuronAnalyzer(model, device)
    color_neurons, shape_neurons = neuron_analyzer.find_causal_neurons(
        env1_images[:MEMORY_LIMIT_SAMPLES], env1_labels[:MEMORY_LIMIT_SAMPLES],
        env2_images[:MEMORY_LIMIT_SAMPLES], env2_labels[:MEMORY_LIMIT_SAMPLES]
    )
# ============================================================================
# CELL 22: Main Execution Part 6 - Cluster SAE Features
# ============================================================================

    logger.info("[6/7] Clustering SAE Features by Cross-Environment Behavior")
    clusters, cluster_info, feat_chars = cluster_features_by_environment(
        model, topk_sae, 
        (env1_images, env1_labels), 
        (env2_images, env2_labels)
    )
# ============================================================================
# CELL 23: Main Execution Part 7 - PHASE 1 Feature Classification
# ============================================================================

    # ==========================================
    # PHASE 1 & 2 Execution
    # ==========================================
    logger.info("="*50)
    logger.info("PHASE 1: Feature Identification & Classification")
    logger.info("="*50)
    
    # Initialize Classifier
    feature_clf = FeatureClassifier(model, topk_sae, device)
    
    # Run Classification on Env1 (Colored) vs Generated B&W
    traitors, heroes = feature_clf.classify_features(env1_images[:1000])
    feature_clf.plot_sensitivity_analysis()
# ============================================================================
# CELL 24: Main Execution Part 8 - PHASE 2 Grad-CAM Validation
# ============================================================================
    
    logger.info("="*50)
    logger.info("PHASE 2: Feature Validation with Grad-CAM")
    logger.info("="*50)
    
    # Validate the discovered Traitors
    validate_traitors_with_gradcam(
        model, topk_sae, traitors, 
        env1_images, env1_labels, 
        device
    )

    

    
# ============================================================================
# CELL 25: Main Execution Part 9 - Train Concept Probes
# ============================================================================

    logger.info("[7/7] Training Linear Concept Probes")
    shape_acc, color_acc = train_concept_probes(
        model, env1_images, env1_labels, device
    )
    


# ============================================================================
# CELL 26: Main Execution Part 10 - Visualization: Analysis Plots
# ============================================================================

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # try/finally guarantees the figure is closed even if plotting fails
    try:
        # Left panel: per-neuron ablation effect in each environment
        effects = [neuron_analyzer.neuron_effects[i] for i in range(FC1_UNITS)]
        env1_effects = [e['env1_effect'] for e in effects]
        env2_effects = [e['env2_effect'] for e in effects]
        
        axes[0].scatter(env1_effects, env2_effects, alpha=0.6, c='blue', edgecolors='k')
        axes[0].axhline(0, color='gray', linestyle='--', alpha=0.5)
        axes[0].axvline(0, color='gray', linestyle='--', alpha=0.5)
        axes[0].plot([-10, 10], [-10, 10], 'r--', alpha=0.5, label='Equal effect')
        axes[0].set_xlabel('Effect on Env1 (Original Colors) %')
        axes[0].set_ylabel('Effect on Env2 (Reversed Colors) %')
        axes[0].set_title('Causal Neuron Analysis\nNeurons above line = COLOR-specific')
        axes[0].legend()
        axes[0].set_xlim(-5, 5)
        axes[0].set_ylim(-5, 5)
        
        # Right panel: SAE feature characteristics coloured by cluster id
        axes[1].scatter(feat_chars[:, 0], feat_chars[:, 1], c=clusters, cmap='viridis', alpha=0.6)
        axes[1].set_xlabel('Cross-Environment Activation Difference')
        axes[1].set_ylabel('Within-Environment Variance')
        axes[1].set_title('SAE Feature Clustering\nHigher X = More Color-Sensitive')
        
        plt.tight_layout()
        plt.savefig('advanced_sae_analysis.png', dpi=150, bbox_inches='tight')
        logger.info("Saved analysis plots to advanced_sae_analysis.png")
        plt.show()
    finally:
        plt.close(fig)
        gc.collect()
    
# ============================================================================
# CELL 27: Main Execution Part 11 - Visualization: Steering Effect Plot
# ============================================================================

    fig, ax = plt.subplots(figsize=(8, 5))
    
    try:
        # Sweep the steering strength; the first entry (0.0) is the baseline
        strengths = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
        results = steering_vectors.intervention_experiment(env2_images[:MEMORY_LIMIT_SAMPLES], env2_labels[:MEMORY_LIMIT_SAMPLES], strengths)
        accs = [r['accuracy'] for r in results]
        
        ax.plot(strengths, accs, 'bo-', linewidth=2, markersize=8)
        ax.axhline(accs[0], color='gray', linestyle='--', label='Baseline (no steering)')
        ax.set_xlabel('Color Removal Strength')
        ax.set_ylabel('Accuracy on Reversed-Color Data (%)')
        ax.set_title('Effect of Removing "Color Direction" from Activations')
        ax.legend()
        ax.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig('color_steering_effect.png', dpi=150, bbox_inches='tight')
        logger.info("Saved steering effect plot to color_steering_effect.png")
        plt.show()
    finally:
        plt.close(fig)
        gc.collect()
# ============================================================================
# CELL 28: Main Execution Part 12 - PHASE 3 Structural Analysis
# ============================================================================

    logger.info("="*50)
    logger.info("PHASE 3: Structural Analysis")
    logger.info("="*50)
    
    analyzer = StructuralAnalyzer(topk_sae, device)
    analyzer.analyze_correlations(env1_images[:1000], model)
    circuits = analyzer.find_circuits(traitors, heroes)
    analyzer.plot_circuit_analysis(traitors, heroes)
    
# ============================================================================
# CELL 29: Main Execution Part 13 - PHASE 4 Causal Intervention
# ============================================================================
    
    # Evaluate SAE Reconstruction Quality first
    evaluate_reconstruction(model, topk_sae, (env1_images, env1_labels), device)
    
    logger.info("="*50)
    logger.info("PHASE 4: Causal Intervention")
    logger.info("="*50)
    
    intervener = CausalIntervener(model, topk_sae, device)
    
    # 1. Identify failures in the Reversed (Hard) Environment
    logger.info("Identifying failures in Env2 (Reversed Data)")
    failures = intervener.find_failure_cases(env2_images[:1000], env2_labels[:1000])
    
    # 2. Perform Surgery
    if failures:
        # Sweep first
        intervener.sweep_intervention_strength(failures, traitors, heroes)
        
        # Then perform standard surgery with moderate boost factor
        n_cured, cured_list = intervener.perform_surgery(
            failures, traitors, heroes, boost_factor=2.0  # Use 2.0x instead of 2.5x
        )
        # 3. Visualize Cures
        if n_cured > 0:
            intervener.visualize_cures(cured_list, model, topk_sae, traitors, heroes, device)
    else:
        logger.info("No failures found to audit (Model is too good?)")

    # Save Models
    torch.save(topk_sae.state_dict(), 'topk_sae_model.pth')
    np.save('color_steering_vector.npy', color_vector)
    logger.info("Saved SAE model and steering vector")
    
    logger.info("Task 6 Complete - All Phases Executed")



In [None]:
# ============================================================================
# CELL 30: Run Main Function
# ============================================================================

# Entry point: run the whole pipeline only when this cell/file is executed as
# the top-level module, so that importing the definitions elsewhere does not
# trigger a full run.
if __name__ == "__main__":
    main()