# In [None]:
import os
import random
import json
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from collections import defaultdict
from sklearn.metrics import mutual_info_score

# =============================================================================
# PHASE 0: CONFIGURATION & UTILITIES
# =============================================================================

class ProberConfig:
    """Class-level settings for the neuron-probing pipeline.

    Every value is a plain class attribute read as ``ProberConfig.NAME``;
    the class is never instantiated.
    """

    # ------------------------------------------------------------- paths
    MODEL_PATH = r"/kaggle/input/task1app3models/pytorch/default/2/task4_irm_modelv1.pth"
    # Red/green training split -> used for colour-bias scoring
    DATA_RG_PATH = r"/kaggle/input/cmnistneo1/train_data_rg95z.npz"
    # Black/white test split -> used for shape scoring (no colour cue)
    DATA_BW_PATH = r"/kaggle/input/cmnistneo1/test_data_bw100z.npz"

    OUTPUT_DIR = r"results/"

    # --------------------------------------- activation maximisation
    LR = 5e-2                  # Adam step size for the synthesized image
    ITERATIONS = 2 ** 9        # optimisation steps per neuron (512)
    TV_WEIGHT = 1e-4           # total-variation smoothness penalty weight
    L2_WEIGHT = 1e-5           # image-norm penalty weight
    JITTER = 2                 # max random roll, in pixels, per step
    BLUR_FREQ = 40             # NOTE(review): not read by any code visible here

    # --------------------------------------------------- diagnostics
    BATCH_SIZE = 256
    NUM_SAMPLES_PROFILE = 2_000        # images sampled per dataset for profiling
    TOP_K = 5                          # neurons shown per role
    POLY_TOP_N = 5                     # example images per polysemantic neuron
    POLY_MIN_SCORE = 0.2               # keep neurons with entropy * strength above this
    POLY_STRENGTH_PERCENTILE = 95      # percentile used as a neuron's "strength"
    VERBOSE_VIZ = False                # print per-neuron optimisation diagnostics

    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def set_seed(seed=42):
    """Seed Python's, NumPy's and PyTorch's global RNGs (and all CUDA devices if present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def ensure_dir(path):
    """Create directory *path*, including missing parents, if it does not exist.

    ``exist_ok=True`` makes the check-and-create a single call, so there is no
    window in which another process can create the directory between an
    ``exists()`` check and ``makedirs()`` (which would raise FileExistsError).
    If *path* exists but is a regular file, ``os.makedirs`` raises
    FileExistsError instead of silently continuing.
    """
    os.makedirs(path, exist_ok=True)

# =============================================================================
# PHASE 1: MODEL DEFINITION & LOADING
# =============================================================================

class CNN3Layer(nn.Module):
    """Three conv -> ReLU -> max-pool stages followed by a two-layer MLP head.

    The attribute names conv1..conv3 / fc1 / fc2 are the state_dict keys that
    ``load_model`` remaps checkpoints onto, and the conv layers are where the
    profiling hooks attach, so they must keep these names.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Feature extractor: channels 3 -> 32 -> 64 -> 64; each 2x2 pool halves H and W.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)
        # 64 * 3 * 3 matches a 28x28 input: 28 -> 14 -> 7 -> 3 (floor) after three pools.
        self.fc1 = nn.Linear(64 * 3 * 3, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        # Hooks registered on conv1..conv3 see the raw conv output (ReLU is applied here).
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in stages:
            x = pool(F.relu(conv(x)))
        flat = x.view(x.size(0), -1)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)

def load_model():
    """Build a ``CNN3Layer`` and load the checkpoint at ``ProberConfig.MODEL_PATH``.

    Checkpoints whose keys use ``features.*`` / ``classifier.*`` names are
    remapped onto this file's named layers (conv1..conv3, fc1, fc2). Loading is
    best-effort: if anything fails, the model keeps its random initialisation
    and a warning is printed, so the rest of the pipeline can still be smoke-tested.

    Returns:
        The model on ``ProberConfig.DEVICE``, switched to eval mode.
    """
    model = CNN3Layer().to(ProberConfig.DEVICE)

    # KEY REMAPPING FOR IRM MODELS
    key_map = {
        # Features (Convs)
        "features.0.weight": "conv1.weight",
        "features.0.bias":   "conv1.bias",
        "features.3.weight": "conv2.weight",
        "features.3.bias":   "conv2.bias",
        "features.6.weight": "conv3.weight",
        "features.6.bias":   "conv3.bias",
        # Classifier (FCs)
        "classifier.0.weight": "fc1.weight",
        "classifier.0.bias":   "fc1.bias",
        "classifier.3.weight": "fc2.weight",
        "classifier.3.bias":   "fc2.bias",
    }

    try:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(ProberConfig.MODEL_PATH, map_location=ProberConfig.DEVICE)
        new_state_dict = {key_map.get(k, k): v for k, v in state_dict.items()}
        # strict=False tolerates extra/missing keys; mismatches are reported below.
        missing, unexpected = model.load_state_dict(new_state_dict, strict=False)
    except Exception as e:
        # Deliberate best-effort fallback: keep running with random weights.
        print(f"[Phase 1] Error loading model: {e}")
        print("[Phase 1] Running with random weights (WARNING: Results meant for testing pipeline only)")
    else:
        print(f"[Phase 1] Model loaded from {ProberConfig.MODEL_PATH}")
        if missing:
            print(f"  Missing keys: {missing}")
        if unexpected:
            print(f"  Unexpected keys: {unexpected}")
        # strict=False never raises on key mismatches, so a checkpoint with an
        # entirely different layout would "load" while changing nothing.
        if set(missing) == set(model.state_dict().keys()):
            print("[Phase 1] WARNING: no checkpoint tensors matched the model; running with random weights")

    model.eval()
    return model

def load_data(path, limit=None):
    """Load an ``.npz`` image dataset as a ``TensorDataset`` of (images, labels).

    Args:
        path: ``.npz`` file containing ``images`` and ``labels`` arrays. Images
            are treated as NHWC with values on a 0-255 scale (they are divided
            by 255 and permuted to NCHW) — TODO confirm for every data file.
        limit: if truthy, keep ``min(limit, len(images))`` samples drawn
            without replacement using NumPy's global RNG. ``None`` or ``0``
            keeps the whole dataset.

    Returns:
        ``TensorDataset(images: float32 NCHW, labels: int64)``.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Data file not found: {path}")

    # np.load on an .npz returns an NpzFile that holds the archive open until
    # closed; use it as a context manager and materialise the arrays inside.
    with np.load(path) as data:
        images = data['images']
        labels = data['labels']

    if limit:
        indices = np.random.choice(len(images), min(limit, len(images)), replace=False)
        images = images[indices]
        labels = labels[indices]

    # Preprocess: [0, 255] -> [0, 1], NHWC -> NCHW
    images = torch.from_numpy(images.astype('float32') / 255.0).permute(0, 3, 1, 2)
    labels = torch.from_numpy(labels).long()

    return TensorDataset(images, labels)

# =============================================================================
# PHASE 1: OPTIMIZATION ENGINE (VISUALIZATION)
# =============================================================================

class FeatureVisualizer:
    """Activation maximisation: synthesise an input image that excites one conv channel."""

    def __init__(self, model):
        self.model = model
        self.device = ProberConfig.DEVICE

    def total_variation_loss(self, img):
        """Sum of squared differences between vertically and horizontally adjacent pixels.

        The sum runs over the whole batch and is divided by C*H*W.
        """
        _, c, h, w = img.shape
        tv_h = (img[:, :, 1:, :] - img[:, :, :-1, :]).pow(2).sum()
        tv_w = (img[:, :, :, 1:] - img[:, :, :, :-1]).pow(2).sum()
        return (tv_h + tv_w) / (c * h * w)

    def jitter_transform(self, img, lim=None):
        """Circularly shift *img* by random offsets in [-lim, lim] along dims 2 and 3.

        ``lim`` defaults to ``ProberConfig.JITTER``. It is read at call time, not
        frozen at class-definition time, so later config changes take effect.
        """
        if lim is None:
            lim = ProberConfig.JITTER
        ox, oy = random.randint(-lim, lim), random.randint(-lim, lim)
        return torch.roll(img, shifts=(ox, oy), dims=(2, 3))

    def generate_ideal_image(self, layer, channel_idx, verbose=False):
        """
        Generates the image that maximizes a specific neuron's activation.

        The objective is the spatial mean of ``layer``'s output at ``channel_idx``
        (hooked from the module, so for the conv layers this is pre-ReLU). It is
        regularised by TV and L2 penalties, optimised with Adam on a jittered copy
        of a 1x3x28x28 image, and the image is clamped to [0, 1] after every step.

        Returns:
            The best image seen during optimisation (CPU tensor, shape [1, 3, 28, 28]).

        Common reasons for empty/noisy visualizations:
        1. Dead neurons: Neuron outputs near-zero for all inputs
        2. Weak gradients: Optimization can't find strong activating patterns
        3. Over-regularization: TV/L2 penalties suppress the signal
        4. Poor initialization: Starting point too far from optimal
        """
        input_img = torch.rand(1, 3, 28, 28, device=self.device) * 0.1 + 0.45
        input_img.requires_grad_(True)
        optimizer = optim.Adam([input_img], lr=ProberConfig.LR)

        activations = {}

        def hook_fn(module, inputs, output):
            activations['act'] = output

        handle = layer.register_forward_hook(hook_fn)

        best_activation = -float('inf')
        best_img = None

        try:
            for i in range(ProberConfig.ITERATIONS):
                # Apply jitter for robustness
                img_jit = self.jitter_transform(input_img)

                self.model(img_jit)
                act = activations['act']

                # Maximize mean activation of target channel
                current_activation = act[0, channel_idx].mean()
                act_value = current_activation.item()

                # Track the image that produced the best activation so far
                if act_value > best_activation:
                    best_activation = act_value
                    best_img = input_img.detach().clone()

                # Regularizers
                tv_loss = self.total_variation_loss(input_img) * ProberConfig.TV_WEIGHT
                l2_loss = torch.norm(input_img) * ProberConfig.L2_WEIGHT

                loss = -current_activation + tv_loss + l2_loss
                # Differentiate w.r.t. the image only. loss.backward() would also
                # accumulate gradients into every model parameter on each
                # iteration, and the optimizer here never clears those.
                (grad,) = torch.autograd.grad(loss, input_img)
                input_img.grad = grad
                optimizer.step()

                with torch.no_grad():
                    input_img.clamp_(0, 1)

                # Diagnostic output
                if verbose and (i % 100 == 0 or i == ProberConfig.ITERATIONS - 1):
                    print(f"  Iter {i}: Activation={act_value:.4f}, Loss={loss.item():.4f}")
        finally:
            # Always detach the hook so later forward passes are unaffected.
            handle.remove()

        # Very weak activation: still return the best attempt, just warn.
        if best_activation < 0.01:
            if verbose:
                print(f"  WARNING: Weak neuron (max activation={best_activation:.4f}). May be dead/inactive.")

        return (best_img if best_img is not None else input_img).detach().cpu()

# =============================================================================
# PHASE 2: TRAITOR DETECTOR (NEURON PROFILING)
# =============================================================================

class NeuronProfiler:
    """Scores every conv channel for colour vs. shape sensitivity using mutual information."""

    def __init__(self, model):
        self.model = model
        self.device = ProberConfig.DEVICE
        # layer name -> list of per-batch layer outputs (detached, on CPU)
        self.activations = defaultdict(list)

    def _get_hooks(self):
        """Register forward hooks on conv1..conv3 that append each output to ``self.activations``.

        Returns the hook handles; the caller is responsible for removing them.
        """
        def make_hook(name):
            # Factory binds `name` per layer (avoids the late-binding closure pitfall).
            return lambda m, i, o: self.activations[name].append(o.detach().cpu())

        layers = {'conv1': self.model.conv1, 'conv2': self.model.conv2, 'conv3': self.model.conv3}
        return [layer.register_forward_hook(make_hook(name)) for name, layer in layers.items()]

    def collect_activations(self, dataset):
        """Run *dataset* through the model and return globally average-pooled activations.

        Returns:
            ``defaultdict`` mapping layer name -> ``np.ndarray`` of shape [N, C].
        """
        self.activations.clear()
        loader = DataLoader(dataset, batch_size=ProberConfig.BATCH_SIZE, shuffle=False)
        hooks = self._get_hooks()
        try:
            with torch.no_grad():
                for images, _ in loader:
                    self.model(images.to(self.device))
        finally:
            # Remove hooks even if a forward pass fails; otherwise every later
            # forward pass (e.g. visualisation) would keep appending activations.
            for h in hooks:
                h.remove()

        extracted_features = defaultdict(list)
        for layer_name, batches in self.activations.items():
            # [N, C, H, W] -> Global Average Pooling -> [N, C]
            extracted_features[layer_name] = torch.cat(batches, dim=0).mean(dim=(2, 3)).numpy()
        return extracted_features

    def score_neurons(self, rg_features, bw_features, rg_colors, bw_labels):
        """Compute a Color_Score (on RG data) and a Shape_Score (on BW data) per channel.

        Each channel's activations are binarised with a median split, then scored
        with ``mutual_info_score`` against the target.

        Args:
            rg_features: layer -> [N_rg, C] pooled activations on the RG set.
            bw_features: layer -> [N_bw, C] pooled activations on the BW set.
            rg_colors: label tensor of the RG set, used as a colour proxy
                (labels >= 5 form one colour group, < 5 the other).
            bw_labels: digit-label tensor of the BW set.

        Returns:
            ``{layer: {channel: {'shape', 'color', 'ratio'}}}`` with
            ``ratio = color / (shape + 1e-6)``.
        """
        print("[Phase 2] Scoring neurons...")
        # Loop-invariant targets, built once instead of once per channel.
        # Assume RG labels: 0-4 Red(0), 5-9 Green(1)
        colors = (rg_colors >= 5).numpy().astype(int)
        digits = bw_labels.numpy()

        scores = {}
        for layer_name, rg_acts in rg_features.items():
            bw_acts = bw_features[layer_name]
            layer_scores = {}
            for ch in range(rg_acts.shape[1]):
                # 1. Color Score (on RG data): MI(median-split activation, colour)
                act_rg = rg_acts[:, ch]
                bin_rg = (act_rg > np.median(act_rg)).astype(int)
                color_score = mutual_info_score(bin_rg, colors)

                # 2. Shape Score (on BW data): no colour cue, so MI with the digit
                # reflects shape. Near-constant channels use threshold 0 instead
                # of their (uninformative) median.
                act_bw = bw_acts[:, ch]
                thr_bw = np.median(act_bw) if np.var(act_bw) > 1e-5 else 0
                bin_bw = (act_bw > thr_bw).astype(int)
                shape_score = mutual_info_score(bin_bw, digits)

                layer_scores[ch] = {
                    'shape': shape_score,
                    'color': color_score,
                    # Bias Ratio: High Color / Low Shape
                    'ratio': color_score / (shape_score + 1e-6),
                }
            scores[layer_name] = layer_scores
        return scores

    def identify_roles(self, scores):
        """
        Identifies neuron roles based on mutual information scores.

        Selection Criteria:
        - TRAITORS: Top-K neurons with highest Color/Shape ratio (color-biased)
        - HEROES: Top-K neurons with highest Shape MI score, among those with Shape MI > 0.05

        Returns:
            ``{'traitors': [...], 'heroes': [...], 'all': [...]}``, where each entry
            is a dict with 'layer', 'channel', 'shape', 'color' and 'ratio'.
        """
        all_neurons = [
            {'layer': layer, 'channel': ch, **metrics}
            for layer, channels in scores.items()
            for ch, metrics in channels.items()
        ]

        # Traitors: High Ratio (Color > Shape) - sorted by ratio descending
        traitors = sorted(all_neurons, key=lambda x: x['ratio'], reverse=True)[:ProberConfig.TOP_K]

        # Heroes: High Shape Score - sorted by shape MI descending.
        # Only neurons whose shape MI exceeds 0.05 are eligible.
        shape_informative = [n for n in all_neurons if n['shape'] > 0.05]
        heroes = sorted(shape_informative, key=lambda x: x['shape'], reverse=True)[:ProberConfig.TOP_K]

        return {'traitors': traitors, 'heroes': heroes, 'all': all_neurons}

# =============================================================================
# PHASE 3: POLYSEMANTICITY DETECTOR (DATA DRIVEN)
# =============================================================================

class PolysemanticDetector:
    """Data-driven search for conv channels that fire strongly across many digit classes."""

    def __init__(self, model, dataset):
        self.model = model
        self.device = ProberConfig.DEVICE
        # Shuffling is disabled so collected images/labels stay aligned with
        # dataset order.
        self.loader = DataLoader(dataset, batch_size=ProberConfig.BATCH_SIZE, shuffle=False)

    def find_polysemantic_neurons(self):
        """
        Finds neurons that activate for multiple different digit classes.

        Uses principled entropy-based scoring:
            PolyScore = Activation_Strength × Label_Entropy

        Per image, a channel's activation is the mean of its top-5 spatial values.
        Strength is the ``POLY_STRENGTH_PERCENTILE`` percentile of that over the
        dataset. Entropy (nats) is computed over the labels of images whose
        activation exceeds 0.7 × strength.

        This automatically:
        - Penalizes blank neurons (low activation → low score)
        - Rewards neurons that fire strongly AND ambiguously
        - Avoids the "inactive neuron illusion" where dead neurons appear polysemantic

        Returns:
            (candidates sorted by poly_score descending, all images as one CPU tensor)
        """
        print("[Phase 3] Scanning for Polysemantic Neurons...")

        # 1. Collect all activations and labels. The whole profiled subset is
        # kept in memory (fine for NUM_SAMPLES_PROFILE-sized data).
        activations = defaultdict(list)
        label_batches = []
        image_batches = []  # kept for visualisation

        def get_hook(n):
            return lambda m, i, o: activations[n].append(o.detach().cpu())

        layers = {'conv1': self.model.conv1, 'conv2': self.model.conv2, 'conv3': self.model.conv3}
        hooks = [layer.register_forward_hook(get_hook(name)) for name, layer in layers.items()]
        try:
            with torch.no_grad():
                for imgs, lbls in self.loader:
                    # Store the loader's CPU batch directly instead of copying the
                    # device tensor back after the forward pass.
                    image_batches.append(imgs.cpu())
                    label_batches.append(lbls.numpy())
                    self.model(imgs.to(self.device))
        finally:
            # Remove hooks even on failure so later forward passes are unaffected.
            for h in hooks:
                h.remove()

        if not image_batches:
            print("[Phase 3] Dataset is empty; no neurons analysed")
            return [], torch.empty(0)

        all_labels = np.concatenate(label_batches)
        full_images = torch.cat(image_batches)

        # 2. Analyze each neuron with entropy-based scoring
        poly_candidates = []
        total_neurons_checked = 0
        weak_neurons_filtered = 0

        for layer_name, batches in activations.items():
            full_acts = torch.cat(batches, dim=0)  # [N, C, H, W]

            # Robust spatial activation: top-k mean instead of single max.
            # k is capped at H*W, since topk(5) raises on maps with < 5 positions.
            flat = full_acts.flatten(2)  # [N, C, H*W]
            k = min(5, flat.shape[2])
            vals, _ = flat.topk(k, dim=2)
            act = vals.mean(dim=2).numpy()  # [N, C]

            n_channels = act.shape[1]
            total_neurons_checked += n_channels

            for ch in range(n_channels):
                ch_act = act[:, ch]
                # Neuron-level activation strength: high percentile of its responses
                strength = np.percentile(ch_act, ProberConfig.POLY_STRENGTH_PERCENTILE)

                # FILTER: Skip nearly dead neurons
                if strength < 0.05:
                    weak_neurons_filtered += 1
                    continue

                # Focus ONLY on genuinely strong activations (70% of the neuron's strength)
                strong_idx = ch_act > 0.7 * strength
                strong_labels = all_labels[strong_idx]

                # Need enough samples for meaningful entropy
                if len(strong_labels) < 5:
                    weak_neurons_filtered += 1
                    continue

                # Label entropy (nats) over STRONG activations only
                counts = np.bincount(strong_labels, minlength=10)
                probs = counts / counts.sum()
                entropy = -np.sum(probs * np.log(probs + 1e-8))

                # High score -> neuron fires strongly AND ambiguously
                poly_score = entropy * strength

                if poly_score > ProberConfig.POLY_MIN_SCORE:
                    # Top-N strongest examples, descending by activation
                    strong_positions = np.where(strong_idx)[0]
                    sorted_order = np.argsort(ch_act[strong_idx])[::-1]
                    topn_idx = strong_positions[sorted_order[:ProberConfig.POLY_TOP_N]]

                    poly_candidates.append({
                        'layer': layer_name,
                        'channel': ch,
                        'topn_idx': topn_idx,
                        'topn_labels': all_labels[topn_idx],
                        'entropy': entropy,
                        'strength': strength,
                        'poly_score': poly_score,
                        'n_strong': len(strong_labels),
                        'label_dist': counts
                    })

        # Highest score = strongest AND most ambiguous neurons
        poly_candidates.sort(key=lambda x: x['poly_score'], reverse=True)

        print(f"[Phase 3] Analyzed {total_neurons_checked} neurons, filtered {weak_neurons_filtered} weak/low-entropy neurons")
        print(f"[Phase 3] Found {len(poly_candidates)} neurons with poly_score > {ProberConfig.POLY_MIN_SCORE}")

        return poly_candidates, full_images

# =============================================================================
# PHASE 4: DASHBOARD & EXECUTION
# =============================================================================

def _report_roles(roles):
    """Print role counts plus the top traitor and hero neurons."""
    neurons = roles['all']
    print("\n[Phase 2] Identified Roles:")
    print(f"  Total Neurons analyzed: {len(neurons)}")
    # Simple threshold heuristics for counting the wider population
    n_traitors = sum(1 for n in neurons if n['ratio'] > 2.0)
    n_heroes = sum(1 for n in neurons if n['ratio'] < 0.5 and n['shape'] > 0.1)
    print(f"  Approx. Traitors (Ratio > 2.0): {n_traitors}")
    print(f"  Approx. Heroes (Ratio < 0.5, ShapeMI > 0.1): {n_heroes}")

    print(f"\n=== TOP {len(roles['traitors'])} TRAITORS (Color Biased) ===")
    for n in roles['traitors']:
        print(f"{n['layer']} Ch{n['channel']} | ColorMI: {n['color']:.3f} | ShapeMI: {n['shape']:.3f} | Ratio: {n['ratio']:.1f}")

    print(f"\n=== TOP {len(roles['heroes'])} HEROES (Shape Focused) ===")
    for n in roles['heroes']:
        print(f"{n['layer']} Ch{n['channel']} | ColorMI: {n['color']:.3f} | ShapeMI: {n['shape']:.3f}")


def _plot_neuron_landscape(roles):
    """Scatter all neurons by (shape MI, colour MI), coloured by bias ratio, and save the figure."""
    from matplotlib.colors import LogNorm

    plt.figure(figsize=(10, 8))

    shapes = [n['shape'] for n in roles['all']]
    colors = [n['color'] for n in roles['all']]
    # Clip ratios to [0.1, 10] and colour them on a LOG scale, so ratio 1
    # (balanced) sits at the centre of the diverging map: <1 -> blue (hero),
    # >1 -> red (traitor). A linear Normalize(0.1, 10) would place ratio 1
    # near the blue end and paint almost every neuron blue.
    ratios = np.clip(np.array([n['ratio'] for n in roles['all']]), 0.1, 10)
    norm = LogNorm(vmin=0.1, vmax=10)

    scatter = plt.scatter(shapes, colors, c=ratios, cmap='coolwarm', norm=norm, alpha=0.7, edgecolors='k')
    plt.colorbar(scatter, label='Bias Ratio (Color/Shape)')

    plt.xlabel('Shape MI (Shape Score)')
    plt.ylabel('Color MI (Color Score)')
    plt.title('Neuron Landscape: Heroes vs Traitors')
    plt.grid(True, alpha=0.3)

    # Annotate Top Traitors & Heroes
    for n in roles['traitors']:
        plt.annotate(f"{n['layer']}:{n['channel']}", (n['shape'], n['color']), fontsize=8, color='red')
    for n in roles['heroes']:
        plt.annotate(f"{n['layer']}:{n['channel']}", (n['shape'], n['color']), fontsize=8, color='blue')

    out_path = os.path.join(ProberConfig.OUTPUT_DIR, "neuron_landscape.png")
    plt.savefig(out_path)
    print(f"[Phase 2] Scatter plot saved to {out_path}")


def _plot_roles_dashboard(model, vis, roles):
    """Render ideal images for the top traitors (row 0) and heroes (row 1), then save the figure."""
    n_cols = ProberConfig.TOP_K
    # squeeze=False keeps `axes` 2-D even when n_cols == 1
    fig, axes = plt.subplots(2, n_cols, figsize=(3 * n_cols, 6), squeeze=False)
    fig.suptitle(f"Task 2: The Prober - Top {n_cols} Neuron Roles", fontsize=16)

    panels = (
        ("Traitor", roles['traitors'], lambda n: f"Color MI: {n['color']:.2f}"),
        ("Hero", roles['heroes'], lambda n: f"Shape MI: {n['shape']:.2f}"),
    )
    for row, (role, nodes, caption) in enumerate(panels):
        for col in range(n_cols):
            ax = axes[row, col]
            # Hide every slot; unused ones (fewer qualifying neurons than TOP_K)
            # then stay blank instead of showing empty tick frames.
            ax.axis('off')
            if col >= len(nodes):
                continue
            node = nodes[col]
            layer = getattr(model, node['layer'])
            if ProberConfig.VERBOSE_VIZ:
                print(f"\nGenerating ideal image for {role} {node['layer']}:{node['channel']}")
            img = vis.generate_ideal_image(layer, node['channel'], verbose=ProberConfig.VERBOSE_VIZ)
            ax.imshow(img.squeeze().permute(1, 2, 0).numpy())
            ax.set_title(f"{role} {node['layer']}:{node['channel']}\n{caption(node)}")

    out_path = os.path.join(ProberConfig.OUTPUT_DIR, "neuron_roles_dashboard.png")
    fig.tight_layout()
    fig.savefig(out_path)
    print(f"\n[Phase 4] Dashboard saved to {out_path}")


def _plot_polysemantic(model, vis, top_poly, images, n_total):
    """One row per neuron: its ideal image, a stats panel, then its top activating images."""
    n_viz = len(top_poly)
    n_cols = ProberConfig.POLY_TOP_N + 2
    # squeeze=False keeps `axes_p` 2-D even when n_viz == 1
    fig_p, axes_p = plt.subplots(n_viz, n_cols, figsize=(2 * n_cols, 2.5 * n_viz), squeeze=False)
    fig_p.suptitle(f"Top {n_viz} Polysemantic Neurons (Total: {n_total}, Entropy×Strength Scoring)", fontsize=14)

    for row, neuron in enumerate(top_poly):
        # Ideal image for this neuron
        layer_obj = getattr(model, neuron['layer'])
        if ProberConfig.VERBOSE_VIZ:
            print(f"\nGenerating ideal image for Polysemantic {neuron['layer']}:{neuron['channel']}")
        ideal_img = vis.generate_ideal_image(layer_obj, neuron['channel'], verbose=ProberConfig.VERBOSE_VIZ)

        ax_ideal = axes_p[row, 0]
        ax_ideal.imshow(ideal_img.squeeze().permute(1, 2, 0).numpy())
        ax_ideal.set_title(f"Ideal\n{neuron['layer']}:{neuron['channel']}\nPS={neuron['poly_score']:.2f}", fontsize=9)
        ax_ideal.axis('off')

        # Info Text
        info = f"Entropy: {neuron['entropy']:.2f}\nStrength: {neuron['strength']:.2f}\nLabels: {neuron['topn_labels']}"
        ax_text = axes_p[row, 1]
        ax_text.text(0.5, 0.5, info, ha='center', va='center', fontsize=9, wrap=True)
        ax_text.axis('off')

        # Hide all example slots first, so neurons with fewer than POLY_TOP_N
        # examples leave blank cells rather than empty frames.
        for col in range(2, n_cols):
            axes_p[row, col].axis('off')

        # Top N Activating Images
        for i, (idx, lbl) in enumerate(zip(neuron['topn_idx'], neuron['topn_labels'])):
            ax_img = axes_p[row, i + 2]
            ax_img.imshow(images[idx].permute(1, 2, 0).numpy())
            ax_img.set_title(f"Lbl: {lbl}", fontsize=9)

    out_path = os.path.join(ProberConfig.OUTPUT_DIR, "polysemantic_examples.png")
    fig_p.tight_layout()
    fig_p.savefig(out_path)
    print(f"[Phase 3] Polysemantic examples saved to {out_path}")


def run_suite():
    """
    Runs the complete interpretability suite.

    Steps:
    - Load the model and both datasets.
    - Score neurons and report the traitor and hero roles.
    - Save the neuron-landscape scatter and the role dashboard.
    - Run polysemantic detection on the BW data and save its examples.

    All figures are written to ``ProberConfig.OUTPUT_DIR``.

    NOTE ON EMPTY/NOISY VISUALIZATIONS:
    Some neurons may show noisy or empty ideal images because:
    - Dead/Inactive neurons: Output near-zero for all inputs (common after dropout/regularization)
    - Polysemantic neurons: Respond to multiple unrelated features, creating incoherent visualizations
    - Inhibitory neurons: Activate by suppressing signals rather than detecting patterns
    - High-level abstract features: Require specific contexts that single-image optimization can't capture

    NOTE ON POLYSEMANTIC DETECTION:
    Uses entropy-based scoring: PolyScore = Entropy × Activation_Strength
    - Entropy: computed on labels of STRONG activations only (avoids noise)
    - Strength: POLY_STRENGTH_PERCENTILE percentile of neuron's activations (robust measure)
    - This automatically filters dead neurons whose random activations falsely appear polysemantic
    - Only neurons with poly_score > POLY_MIN_SCORE are considered genuinely polysemantic

    Set VERBOSE_VIZ=True to see optimization diagnostics and identify weak neurons.
    """
    ensure_dir(ProberConfig.OUTPUT_DIR)
    set_seed()

    # 1. Setup
    model = load_model()
    vis = FeatureVisualizer(model)
    profiler = NeuronProfiler(model)

    # 2. Load Data
    print(f"Loading RG Data: {ProberConfig.DATA_RG_PATH}")
    ds_rg = load_data(ProberConfig.DATA_RG_PATH, limit=ProberConfig.NUM_SAMPLES_PROFILE)
    print(f"Loading BW Data: {ProberConfig.DATA_BW_PATH}")
    ds_bw = load_data(ProberConfig.DATA_BW_PATH, limit=ProberConfig.NUM_SAMPLES_PROFILE)

    # 3. Phase 2: Profile & Identify. tensors[1] holds the labels; on RG
    # they double as the colour proxy.
    feats_rg = profiler.collect_activations(ds_rg)
    feats_bw = profiler.collect_activations(ds_bw)
    scores = profiler.score_neurons(feats_rg, feats_bw, ds_rg.tensors[1], ds_bw.tensors[1])
    roles = profiler.identify_roles(scores)

    _report_roles(roles)
    _plot_neuron_landscape(roles)

    # 4. Phase 1 & 4: Viz Dashboard
    _plot_roles_dashboard(model, vis, roles)

    # 5. Phase 3: Polysemantic Detector
    # Use the BW dataset because we care about SHAPE polysemanticity (e.g. 7 and 1).
    # The RG dataset is trivially "polysemantic" for colour+shape; we want neurons
    # confused about *digits* in clean data.
    poly_detector = PolysemanticDetector(model, ds_bw)
    poly_neurons, val_images = poly_detector.find_polysemantic_neurons()

    n_viz = min(len(poly_neurons), ProberConfig.TOP_K)
    print(f"\n[Phase 3] Total polysemantic neurons found: {len(poly_neurons)} (poly_score > {ProberConfig.POLY_MIN_SCORE})")
    print(f"[Phase 3] Visualizing top {n_viz} most polysemantic neurons")

    if poly_neurons:
        print(f"\n=== TOP {n_viz} POLYSEMANTIC NEURONS ===")
        for rank, n in enumerate(poly_neurons[:n_viz], start=1):
            print(f"{rank}. {n['layer']} Ch{n['channel']} | Entropy: {n['entropy']:.3f} | Strength: {n['strength']:.3f} | PolyScore: {n['poly_score']:.3f}")

    if n_viz:
        _plot_polysemantic(model, vis, poly_neurons[:n_viz], val_images, len(poly_neurons))

if __name__ == "__main__":
    # Run the full probing pipeline when this file is executed directly.
    run_suite()
