In [None]:
# Task 6: The Decomposition - Sparse Autoencoders for Feature Discovery
# ======================================================================
# This script trains a Sparse Autoencoder (SAE) on the FC1 hidden states of our
# biased CNN, decomposing them into an overcomplete set of features
# (SAE_EXPANSION_FACTOR times wider). We then look for meaningful features,
# especially color-related ones, and test whether scaling a single feature
# changes the model's predictions.

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
import os

# ===== CONFIGURATION =====
# Data paths - UPDATE THESE TO YOUR ACTUAL DATA LOCATIONS
# For Kaggle, use paths like: '/kaggle/input/cmnistneo1/train_data_rg95z.npz'
# For local, provide the full path to your training data.
# Each file must be an .npz archive with 'images' and 'labels' arrays (see
# load_data). Images are divided by 255 on load and later permuted from
# (N, H, W, 3) to (N, 3, H, W), i.e. they are expected channels-last, 3-channel.

# The SAE is fit on activations of the TRAINING set. The test set, which uses a
# different color mapping, is only used to rank features by color sensitivity
# and for the intervention step; if it is missing, both are skipped.
TRAIN_DATA_PATH = '/kaggle/input/cmnistneo1/train_data_rg95z.npz'  # Training data
TEST_DATA_PATH = '/kaggle/input/cmnistneo1/test_data_gr95z.npz'  # Test data with different color mapping

# Model path (your trained biased model). If the file is missing, main() falls
# back to an UNTRAINED model, so results would not reflect the biased model.
MODEL_PATH = '/kaggle/input/task1app3models/pytorch/default/2/task1approach3sc1_modelv1.pth'

# SAE hyperparameters
SAE_EXPANSION_FACTOR = 4  # hidden_dim = input_dim * SAE_EXPANSION_FACTOR (overcomplete)
SAE_SPARSITY_WEIGHT = 0.001  # Weight of the L1 penalty mean(|z|) added to the MSE loss
SAE_LEARNING_RATE = 0.001  # Adam learning rate
SAE_EPOCHS = 50
SAE_BATCH_SIZE = 256

# Model hyperparameters (from your biased model). These must match the
# checkpoint at MODEL_PATH, otherwise load_state_dict fails on shape mismatch.
NUM_CLASSES = 10
CONV1_CHANNELS = 32
CONV2_CHANNELS = 64
CONV3_CHANNELS = 64
FC1_UNITS = 128
DROPOUT_RATE = 0.1

# Device: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# ===== MODEL DEFINITION =====
# Same CNN architecture as in task1_app3
class CNN3Layer(nn.Module):
    """Three conv blocks (3x3 conv, padding=1 -> ReLU -> 2x2 max-pool) plus a 2-layer MLP head.

    ``fc1`` takes ``CONV3_CHANNELS * 3 * 3`` inputs, so the feature map after
    the three poolings must be 3x3 (e.g. 28x28 inputs: 28 -> 14 -> 7 -> 3).
    Attribute names must not change: they are the ``state_dict`` keys used
    when loading the trained checkpoint.
    """

    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()

        # First convolutional block (3 input channels)
        self.conv1 = nn.Conv2d(3, CONV1_CHANNELS, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)

        # Second convolutional block
        self.conv2 = nn.Conv2d(CONV1_CHANNELS, CONV2_CHANNELS, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Third convolutional block
        self.conv3 = nn.Conv2d(CONV2_CHANNELS, CONV3_CHANNELS, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)

        # Fully connected head
        self.fc1 = nn.Linear(CONV3_CHANNELS * 3 * 3, FC1_UNITS)
        self.dropout = nn.Dropout(DROPOUT_RATE)
        self.fc2 = nn.Linear(FC1_UNITS, num_classes)

    def forward(self, x):
        """Return class logits of shape (B, num_classes) for a (B, 3, H, W) batch."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

    def get_intermediate_activations(self, x):
        """Return a dict of intermediate activations for a (B, 3, H, W) batch.

        Keys:
            'conv1', 'conv2', 'conv3': post-ReLU conv maps *before* pooling,
                shape (B, C, H_i, W_i).
            'flattened': pooled conv3 output flattened to
                (B, CONV3_CHANNELS * 3 * 3), i.e. the input of ``fc1``.
            'fc1': post-ReLU ``fc1`` output, shape (B, FC1_UNITS); this is
                exactly what ``forward`` feeds to dropout and ``fc2``.

        Dropout and ``fc2`` are not applied. The returned tensors are not
        copies: avoid modifying them in place when gradients are tracked.
        """
        activations = {}

        # Each stored tensor is a fresh result of F.relu / reshape that nothing
        # here modifies in place, so a defensive .clone() would only double
        # the activation memory.
        conv1_out = F.relu(self.conv1(x))
        activations['conv1'] = conv1_out
        x = self.pool1(conv1_out)

        conv2_out = F.relu(self.conv2(x))
        activations['conv2'] = conv2_out
        x = self.pool2(conv2_out)

        conv3_out = F.relu(self.conv3(x))
        activations['conv3'] = conv3_out
        x = self.pool3(conv3_out)

        x_flat = x.reshape(x.size(0), -1)
        activations['flattened'] = x_flat

        activations['fc1'] = F.relu(self.fc1(x_flat))

        return activations


# ===== SPARSE AUTOENCODER DEFINITION =====
class SparseAutoencoder(nn.Module):
    """
    One-hidden-layer autoencoder that re-expresses activations as a wider,
    non-negative code.

        encode:  z     = ReLU(encoder(x))   -> shape (..., hidden_dim)
        decode:  x_hat = decoder(z)         -> shape (..., input_dim)

    The module itself does not enforce sparsity; the L1 penalty on ``z`` is
    added by the training loss (see train_sae).
    """
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim

        # Keep the encoder-then-decoder creation order so that seeded
        # parameter initialisation is unchanged.
        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return the non-negative latent code ReLU(encoder(x))."""
        pre_activation = self.encoder(x)
        return F.relu(pre_activation)

    def decode(self, z):
        """Map a latent code back to input space (affine, no activation)."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(reconstruction, latent_code)`` for the input ``x``."""
        latent = self.encode(x)
        reconstruction = self.decode(latent)
        return reconstruction, latent


# ===== HELPER FUNCTIONS =====
def load_data(path):
    """Load an ``.npz`` dataset and scale its images to [0, 1].

    Args:
        path: Path to an ``.npz`` archive containing ``'images'`` and
            ``'labels'`` arrays.

    Returns:
        ``(images, labels)``: ``images`` as float32 divided by 255, and
        ``labels`` exactly as stored. Returns ``(None, None)`` after printing a
        warning if ``path`` does not exist.

    Raises:
        KeyError: if the archive lacks an ``'images'`` or ``'labels'`` entry.
    """
    if not os.path.exists(path):
        print(f"Warning: {path} not found. Please provide the correct path.")
        return None, None
    # np.load on an .npz returns a lazily-reading NpzFile that holds the file
    # handle open; the context manager closes it once both arrays are read.
    with np.load(path) as data:
        images = data['images'].astype('float32') / 255.0
        labels = data['labels']
    return images, labels


def extract_activations(model, data_loader, layer_name='fc1'):
    """Collect one layer's activations for every batch yielded by ``data_loader``.

    The model is switched to eval mode and run without gradients. Activations
    with more than two dimensions (e.g. conv maps) are flattened per sample.

    Args:
        model: Network exposing ``get_intermediate_activations(x) -> dict``.
        data_loader: Iterable of ``(inputs, labels)`` batches.
        layer_name: Key into the activation dict (for CNN3Layer: 'conv1',
            'conv2', 'conv3', 'flattened' or 'fc1').

    Returns:
        ``(activations, labels)``: CPU tensors concatenated along dim 0 in
        the loader's iteration order.
    """
    model.eval()
    feature_chunks, label_chunks = [], []

    with torch.no_grad():
        for inputs, batch_labels in data_loader:
            layer_out = model.get_intermediate_activations(inputs.to(device))[layer_name]

            # Spatial maps (B, C, H, W) become (B, C*H*W); 2-D outputs pass through.
            if layer_out.dim() > 2:
                layer_out = layer_out.reshape(layer_out.size(0), -1)

            feature_chunks.append(layer_out.cpu())
            label_chunks.append(batch_labels)

    stacked_features = torch.cat(feature_chunks, dim=0)
    stacked_labels = torch.cat(label_chunks, dim=0)
    return stacked_features, stacked_labels


def train_sae(sae, activations, epochs=SAE_EPOCHS, sparsity_weight=SAE_SPARSITY_WEIGHT,
              learning_rate=SAE_LEARNING_RATE, batch_size=SAE_BATCH_SIZE, log_every=10):
    """Train ``sae`` to reconstruct ``activations`` under an L1 sparsity penalty.

    Per-batch loss: ``MSE(x_hat, x) + sparsity_weight * mean(|z|)``, where the
    mean runs over both the batch and latent dimensions. Optimised with Adam
    on shuffled mini-batches. ``sae`` is left in train mode.

    Args:
        sae: Module whose forward returns ``(reconstruction, latent_code)``.
        activations: (N, input_dim) tensor; batches are moved to ``device``.
        epochs: Number of passes over the data.
        sparsity_weight: Coefficient of the L1 term.
        learning_rate: Adam learning rate (defaults to SAE_LEARNING_RATE).
        batch_size: Mini-batch size (defaults to SAE_BATCH_SIZE).
        log_every: Print epoch averages every ``log_every`` epochs and after
            the final epoch; 0 or None disables logging.

    Returns:
        list of floats: the mean total loss of each epoch.

    Raises:
        ValueError: if ``activations`` is empty.
    """
    if len(activations) == 0:
        raise ValueError("train_sae needs at least one activation vector")

    sae.train()
    optimizer = optim.Adam(sae.parameters(), lr=learning_rate)
    loader = DataLoader(TensorDataset(activations), batch_size=batch_size, shuffle=True)
    num_batches = len(loader)

    losses = []
    for epoch in range(epochs):
        epoch_loss = epoch_recon_loss = epoch_sparsity_loss = 0.0

        for (x,) in loader:
            x = x.to(device)
            x_recon, z = sae(x)

            recon_loss = F.mse_loss(x_recon, x)
            # For SparseAutoencoder z is a ReLU output (already >= 0); abs()
            # keeps this a true L1 penalty for any encoder.
            sparsity_loss = torch.mean(torch.abs(z))
            loss = recon_loss + sparsity_weight * sparsity_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            epoch_recon_loss += recon_loss.item()
            epoch_sparsity_loss += sparsity_loss.item()

        avg_loss = epoch_loss / num_batches
        losses.append(avg_loss)

        # Always report the final epoch so short runs still produce output.
        is_last_epoch = epoch + 1 == epochs
        if log_every and ((epoch + 1) % log_every == 0 or is_last_epoch):
            print(f"Epoch [{epoch+1}/{epochs}] - Total Loss: {avg_loss:.6f}, "
                  f"Recon: {epoch_recon_loss / num_batches:.6f}, "
                  f"Sparsity: {epoch_sparsity_loss / num_batches:.6f}")

    return losses


def analyze_features(sae, activations, labels, images, top_k=5, max_features=50, num_classes=10):
    """Summarise SAE features by the samples that activate them most.

    For each of the first ``max_features`` latent units, finds the ``top_k``
    samples with the largest activation and tallies their labels.

    Args:
        sae: Trained SparseAutoencoder (switched to eval mode here).
        activations: (N, input_dim) tensor of layer activations.
        labels: Length-N labels (tensor or array) aligned with ``activations``.
        images: Unused; kept for call-site compatibility.
        top_k: Number of top-activating samples kept per feature.
        max_features: Number of features, starting at index 0, to analyse.
        num_classes: Minimum length of each ``label_counts`` histogram.

    Returns:
        list of dicts, one per analysed feature, with keys:
            'feature_idx', 'mean_activation' (mean over the top-k samples),
            'top_labels', 'dominant_label', 'label_counts',
            'top_indices' (sample indices, highest activation first), and
            'num_active' (samples with activation > 0; 0 means the feature is
            dead, so its top-k samples are arbitrary ties).
    """
    sae.eval()

    with torch.no_grad():
        z = sae.encode(activations.to(device)).cpu().numpy()

    labels = np.asarray(labels)
    num_features = z.shape[1]

    feature_analysis = []
    for feat_idx in range(min(max_features, num_features)):
        feat_activations = z[:, feat_idx]

        # Descending order first, then take top_k: unlike `[-top_k:]`, this is
        # also correct for top_k == 0 (where `[-0:]` would select everything).
        top_indices = np.argsort(feat_activations)[::-1][:top_k]

        top_labels = labels[top_indices]
        label_counts = np.bincount(top_labels, minlength=num_classes)

        feature_analysis.append({
            'feature_idx': feat_idx,
            'mean_activation': feat_activations[top_indices].mean(),
            'top_labels': top_labels,
            'dominant_label': np.argmax(label_counts),
            'label_counts': label_counts,
            'top_indices': top_indices,
            'num_active': int(np.count_nonzero(feat_activations > 0)),
        })

    return feature_analysis


def visualize_feature(images, labels, feature_analysis, feature_idx, save_path=None):
    """Plot the top-activating samples recorded for one SAE feature.

    Args:
        images: Indexable collection of images (e.g. (N, H, W, 3) array) whose
            indices match ``feature_analysis[...]['top_indices']``.
        labels: Labels aligned with ``images``.
        feature_analysis: Output of ``analyze_features``.
        feature_idx: Feature to show; must appear in ``feature_analysis``.
        save_path: If given, the figure is also written to this path.

    Prints a message and returns without plotting if the feature is not in
    ``feature_analysis`` or has no recorded samples.
    """
    feat_info = next((fa for fa in feature_analysis if fa['feature_idx'] == feature_idx), None)
    if feat_info is None:
        print(f"Feature {feature_idx} not found in analysis")
        return

    top_indices = feat_info['top_indices']
    if len(top_indices) == 0:
        # plt.subplots(1, 0) would raise.
        print(f"Feature {feature_idx} has no samples to show")
        return

    # squeeze=False always returns a 2-D array of Axes, so a single sample
    # needs no special case.
    fig, axes = plt.subplots(1, len(top_indices), figsize=(15, 3), squeeze=False)
    fig.suptitle(f"Feature {feature_idx} - Dominant Label: {feat_info['dominant_label']}", fontsize=14)

    for ax, idx in zip(axes[0], top_indices):
        ax.imshow(images[idx])
        ax.set_title(f"Label: {labels[idx]}")
        ax.axis('off')

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    # Release the figure; under non-interactive backends figures otherwise accumulate.
    plt.close(fig)


def intervention_experiment(model, sae, images, labels, feature_idx,
                            scale_factors=(0.0, 0.5, 2.0, 5.0), num_samples=100, num_classes=10):
    """
    Scale one SAE feature at the fc1 layer and report how predictions change.

    A random subset of ``images`` is run through ``model`` up to ``fc1``. The
    fc1 activations are encoded by ``sae``, feature ``feature_idx`` is
    multiplied by each scale factor, and the code is decoded. The result is
    then passed through ReLU and ``model.fc2`` (dropout is skipped, as in eval
    mode).

    Two baselines are printed first so the intervention's effect can be
    separated from SAE reconstruction error:
      - the unmodified model;
      - the plain SAE reconstruction, with no feature changed.

    Args:
        model: CNN3Layer-like model with ``get_intermediate_activations`` and ``fc2``.
        sae: SparseAutoencoder trained on fc1 activations.
        images: (N, H, W, 3) float array; permuted to channels-first here.
        labels: Length-N integer labels aligned with ``images``.
        feature_idx: Index of the SAE latent feature to scale.
        scale_factors: Multipliers applied to the feature. The default is a
            tuple, so it is not a shared mutable object.
        num_samples: Maximum number of images sampled without replacement,
            using NumPy's global RNG.
        num_classes: Minimum length of the printed prediction histograms.

    Returns:
        dict with keys:
            'baseline': accuracy (%) of the unmodified model;
            'reconstruction': accuracy (%) with the unchanged SAE reconstruction;
            'scaled': dict mapping each scale factor to its accuracy (%).
    """
    model.eval()
    sae.eval()

    batch_size = min(num_samples, len(images))
    sample_indices = np.random.choice(len(images), batch_size, replace=False)
    sample_images = torch.FloatTensor(images[sample_indices]).permute(0, 3, 1, 2).to(device)
    sample_labels = np.asarray(labels)[sample_indices]

    def _evaluate(fc1_values):
        """Return (accuracy %, prediction histogram) for given fc1-layer values."""
        output = model.fc2(F.relu(fc1_values))
        predictions = torch.argmax(output, dim=1).cpu().numpy()
        accuracy = (predictions == sample_labels).mean() * 100
        return accuracy, np.bincount(predictions, minlength=num_classes)

    print(f"\n=== Intervention Experiment on Feature {feature_idx} ===")

    results = {'scaled': {}}
    with torch.no_grad():
        # Independent of the scale factor, so computed once.
        fc1_act = model.get_intermediate_activations(sample_images)['fc1']  # (B, FC1_UNITS)
        z = sae.encode(fc1_act)  # (B, hidden_dim)

        # fc1_act is already post-ReLU, so ReLU inside _evaluate is a no-op here
        # and this equals the model's own eval-mode prediction.
        baseline_acc, baseline_counts = _evaluate(fc1_act)
        recon_acc, recon_counts = _evaluate(sae.decode(z))
        results['baseline'] = baseline_acc
        results['reconstruction'] = recon_acc

        print("\nBaseline (original FC1, no SAE):")
        print(f"  Accuracy: {baseline_acc:.2f}%")
        print(f"  Prediction Distribution: {baseline_counts}")
        print("\nSAE reconstruction (no intervention):")
        print(f"  Accuracy: {recon_acc:.2f}%")
        print(f"  Prediction Distribution: {recon_counts}")

        for scale in scale_factors:
            z_modified = z.clone()
            z_modified[:, feature_idx] *= scale

            accuracy, pred_counts = _evaluate(sae.decode(z_modified))
            results['scaled'][scale] = accuracy

            print(f"\nScale Factor: {scale}")
            print(f"  Accuracy: {accuracy:.2f}%")
            print(f"  Prediction Distribution: {pred_counts}")

    return results


def analyze_color_features(model, sae, train_images, train_labels, test_images, test_labels,
                           top_n=20, num_classes=10, batch_size=256):
    """
    Rank SAE features by how much their class-conditional means shift between datasets.

    The training and test sets use different color mappings. For every SAE
    feature, the mean activation per class is computed on each dataset. The
    feature's score is the mean absolute train/test difference, taken over
    classes present in both sets.
      - High scores suggest color sensitivity.
      - Low scores suggest the feature responds to something stable across
        the two mappings (presumably shape).

    Features that never activate on either dataset ("dead" features) score 0
    trivially and carry no information. They are therefore excluded from the
    low-score ("shape-sensitive") ranking.

    Args:
        model: CNN3Layer-like model; fc1 activations are extracted from it.
        sae: SparseAutoencoder trained on fc1 activations.
        train_images, test_images: (N, H, W, 3) float arrays.
        train_labels, test_labels: Integer labels aligned with the images.
        top_n: Number of features reported in each ranking.
        num_classes: Number of classes to compare (labels 0..num_classes-1).
        batch_size: Batch size used for activation extraction.

    Returns:
        ``(color_sensitive_features, shape_sensitive_features, feature_diffs)``:
            - indices of the ``top_n`` highest-scoring features, highest first;
            - indices of the ``top_n`` lowest-scoring *active* features, lowest first;
            - the per-feature scores, an array of length hidden_dim.

    Raises:
        ValueError: if no class occurs in both label sets.
    """
    model.eval()
    sae.eval()

    def _make_loader(images, labels):
        # NHWC images -> NCHW tensors; unshuffled so labels stay aligned.
        tensor = torch.FloatTensor(images).permute(0, 3, 1, 2)
        dataset = TensorDataset(tensor, torch.LongTensor(labels))
        return DataLoader(dataset, batch_size=batch_size, shuffle=False)

    print("Extracting training activations...")
    train_acts, train_labs = extract_activations(model, _make_loader(train_images, train_labels), 'fc1')
    print("Extracting test activations...")
    test_acts, test_labs = extract_activations(model, _make_loader(test_images, test_labels), 'fc1')

    with torch.no_grad():
        train_z = sae.encode(train_acts.to(device)).cpu().numpy()
        test_z = sae.encode(test_acts.to(device)).cpu().numpy()
    train_labs = train_labs.numpy()
    test_labs = test_labs.numpy()

    # Mean SAE activation per class, for each dataset.
    num_features = train_z.shape[1]
    train_class_means = np.zeros((num_classes, num_features))
    test_class_means = np.zeros((num_classes, num_features))
    shared_classes = []
    for c in range(num_classes):
        train_mask = train_labs == c
        test_mask = test_labs == c
        if train_mask.any():
            train_class_means[c] = train_z[train_mask].mean(axis=0)
        if test_mask.any():
            test_class_means[c] = test_z[test_mask].mean(axis=0)
        if train_mask.any() and test_mask.any():
            shared_classes.append(c)

    if not shared_classes:
        raise ValueError("No class occurs in both the training and test labels; "
                         "cannot compare feature activations.")

    # Compare only classes seen in both sets: a class missing from one set
    # would keep an all-zero mean there and inflate every feature's score.
    feature_diffs = np.abs(train_class_means[shared_classes]
                           - test_class_means[shared_classes]).mean(axis=0)

    # Highest differences first (same ordering as argsort()[-top_n:][::-1],
    # but also correct for top_n == 0).
    color_sensitive_features = np.argsort(feature_diffs)[::-1][:top_n]

    # Lowest differences among features that fire at least once.
    active = np.any(train_z != 0, axis=0) | np.any(test_z != 0, axis=0)
    active_idx = np.flatnonzero(active)
    shape_order = np.argsort(feature_diffs[active_idx], kind='stable')
    shape_sensitive_features = active_idx[shape_order][:top_n]
    num_dead = num_features - active_idx.size

    print("\n=== Feature Analysis Results ===")
    print(f"\nTop {top_n} Color-Sensitive Features (high train/test difference):")
    print(f"  Feature indices: {color_sensitive_features}")
    print(f"  Mean differences: {feature_diffs[color_sensitive_features]}")

    print(f"\nTop {top_n} Shape-Sensitive Features (low train/test difference):")
    print(f"  ({num_dead} dead features never activate and are excluded)")
    print(f"  Feature indices: {shape_sensitive_features}")
    print(f"  Mean differences: {feature_diffs[shape_sensitive_features]}")

    return color_sensitive_features, shape_sensitive_features, feature_diffs


def main():
    """Run the full Sparse Autoencoder analysis end to end.

    Steps:
      1. Load the train/test data.
      2. Load the biased CNN.
      3. Extract FC1 activations on the training set.
      4. Train a SparseAutoencoder on those activations.
      5. Summarise and plot the learned features.
      6. If test data is available, rank features by train/test (color)
         sensitivity and run intervention experiments.

    Writes to the working directory:
      - sae_training_loss.png
      - feature_<idx>_samples.png
      - feature_color_sensitivity.png (only with test data)
      - sae_model.pth
    """
    print("=" * 60)
    print("Task 6: The Decomposition - Sparse Autoencoders")
    print("=" * 60)

    # ===== LOAD DATA =====
    print("\n[1/6] Loading data...")

    train_images, train_labels = load_data(TRAIN_DATA_PATH)
    test_images, test_labels = load_data(TEST_DATA_PATH)

    if train_images is None:
        print("Please update TRAIN_DATA_PATH and TEST_DATA_PATH to your data locations.")
        print("The script expects npz files with 'images' and 'labels' arrays.")
        return

    print(f"Training data: {train_images.shape}")
    if test_images is not None:
        print(f"Test data: {test_images.shape}")

    # ===== LOAD BIASED MODEL =====
    print("\n[2/6] Loading biased model...")

    model = CNN3Layer(num_classes=NUM_CLASSES).to(device)

    if os.path.exists(MODEL_PATH):
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
        print(f"Loaded model from {MODEL_PATH}")
    else:
        print(f"Warning: {MODEL_PATH} not found. Using untrained model.")
        print("Please update MODEL_PATH to point to your trained biased model.")

    model.eval()

    # ===== EXTRACT ACTIVATIONS =====
    print("\n[3/6] Extracting intermediate activations...")

    # NHWC images -> NCHW tensors. shuffle=False keeps activation row i aligned
    # with train_images[i]; the feature visualisation below relies on that.
    train_tensor = torch.FloatTensor(train_images).permute(0, 3, 1, 2)
    train_labels_tensor = torch.LongTensor(train_labels)

    train_dataset = TensorDataset(train_tensor, train_labels_tensor)
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=False)

    # FC1 = last hidden layer before the classifier (post-ReLU).
    activations, labels_extracted = extract_activations(model, train_loader, 'fc1')
    print(f"Extracted activations shape: {activations.shape}")

    # ===== TRAIN SPARSE AUTOENCODER =====
    print("\n[4/6] Training Sparse Autoencoder...")

    input_dim = activations.shape[1]
    hidden_dim = input_dim * SAE_EXPANSION_FACTOR  # Overcomplete representation

    print(f"SAE Architecture: {input_dim} -> {hidden_dim} -> {input_dim}")

    sae = SparseAutoencoder(input_dim, hidden_dim).to(device)
    losses = train_sae(sae, activations)

    # Plot training loss
    plt.figure(figsize=(10, 4))
    plt.plot(losses)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('SAE Training Loss')
    plt.grid(True)
    plt.savefig('sae_training_loss.png', dpi=150, bbox_inches='tight')
    plt.show()
    plt.close()

    # ===== ANALYZE FEATURES =====
    print("\n[5/6] Analyzing learned features...")

    feature_analysis = analyze_features(sae, activations, labels_extracted, train_images, top_k=10)

    # Print summary of the first analysed features
    print("\nFeature Analysis Summary (first 20 features):")
    print("-" * 50)
    for fa in feature_analysis[:20]:
        print(f"Feature {fa['feature_idx']:3d}: Dominant Label = {fa['dominant_label']}, "
              f"Mean Act = {fa['mean_activation']:.4f}, "
              f"Label Dist = {fa['label_counts']}")

    # Visualise the most strongly activating analysed features. Dead features
    # have mean top-k activation 0 and therefore sort last.
    top_features = sorted(feature_analysis, key=lambda fa: fa['mean_activation'], reverse=True)[:5]
    print(f"\nVisualizing top {len(top_features)} features (by mean top-k activation)...")
    for fa in top_features:
        visualize_feature(train_images, train_labels, feature_analysis,
                          fa['feature_idx'],
                          save_path=f"feature_{fa['feature_idx']}_samples.png")

    # ===== COLOR VS SHAPE ANALYSIS =====
    print("\n[6/6] Analyzing color vs. shape features...")

    if test_images is not None:
        color_features, shape_features, diffs = analyze_color_features(
            model, sae, train_images, train_labels, test_images, test_labels
        )

        # Visualize feature differences
        plt.figure(figsize=(14, 5))
        plt.bar(range(len(diffs)), diffs)
        plt.xlabel('Feature Index')
        plt.ylabel('Train/Test Activation Difference')
        plt.title('Feature Sensitivity: Higher = More Color-Sensitive')
        plt.axhline(y=np.median(diffs), color='r', linestyle='--', label='Median')
        plt.legend()
        plt.savefig('feature_color_sensitivity.png', dpi=150, bbox_inches='tight')
        plt.show()
        plt.close()

        # ===== INTERVENTION EXPERIMENTS =====
        print("\n=== Running Intervention Experiments ===")

        print("\nIntervening on most color-sensitive features:")
        for feat_idx in color_features[:3]:
            intervention_experiment(model, sae, train_images, train_labels, feat_idx)

        print("\nIntervening on most shape-sensitive features:")
        for feat_idx in shape_features[:3]:
            intervention_experiment(model, sae, train_images, train_labels, feat_idx)
    else:
        print("Test data not available - skipping color analysis and interventions.")

    # ===== SAVE SAE MODEL =====
    sae_save_path = 'sae_model.pth'
    torch.save(sae.state_dict(), sae_save_path)
    print(f"\nSAE model saved to {sae_save_path}")

    print("\n" + "=" * 60)
    print("Task 6 Complete!")
    print("=" * 60)
    print("""
    Key Findings:
    1. The SAE decomposes FC1 activations into {hidden_dim} features
    2. Some features are highly class-specific (likely shape detectors)
    3. Features with high train/test difference are likely color-sensitive
    4. Interventions on color-sensitive features may change predictions
    
    Explore further by:
    - Adjusting SAE_SPARSITY_WEIGHT for more/less sparse representations
    - Analyzing different layers (conv1, conv2, conv3)
    - Visualizing decoder weights to understand feature meanings
    - Manual labeling of features to categorize color vs. shape
    """.format(hidden_dim=hidden_dim))


# Entry point: run the pipeline when this file is executed directly. Inside a
# Jupyter kernel __name__ is also "__main__", so running the cell calls main().
if __name__ == "__main__":
    main()
