<a href="https://colab.research.google.com/github/desstaw/Shortcut_Learning/blob/main/Muting_Sparse_Neurons_SAE_background.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

https://colab.research.google.com/drive/1erGvTo3VIy1c3rAL9vxMkEHfyLQ0M_Qg#scrollTo=gmYi78ByHxs0

In [None]:
# Mount Google Drive so the notebook can read the model, SAE and dataset
# files stored under /content/drive/MyDrive (Colab only).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import torch
import numpy as np
import gc
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from scipy.stats import ttest_ind
import seaborn as sns
import json
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd



# Compute device: the CUDA GPU when one is visible, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def set_seed(seed=1):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    import random

    # Same seeding order as before: Python, NumPy, then PyTorch (CPU).
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # cuDNN autotuning may pick non-deterministic kernels; turn it off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(1)


## Load pretrained SAE & AlexNet

In [None]:
# Disabled: this legacy LabeledDataLoaderWrapper is kept inside a string
# literal, so the class is never defined when the cell runs.
'''
class LabeledDataLoaderWrapper:
    def __init__(self, dataloader, label):
        self.dataloader = dataloader
        self.label = label

    def __iter__(self):
        for images in self.dataloader:
            batch_size = images.size(0)
            labels = torch.full((batch_size,), self.label, dtype=torch.long)
            yield images, labels

    def __len__(self):
        return len(self.dataloader)
'''

'\nclass LabeledDataLoaderWrapper:\n    def __init__(self, dataloader, label):\n        self.dataloader = dataloader\n        self.label = label\n\n    def __iter__(self):\n        for images in self.dataloader:\n            batch_size = images.size(0)\n            labels = torch.full((batch_size,), self.label, dtype=torch.long)\n            yield images, labels\n\n    def __len__(self):\n        return len(self.dataloader)\n'

In [None]:
def _alexnet_conv_stage(in_channels, out_channels, kernel_size, stride=1, padding=0, pool=False):
    """Build one AlexNet conv stage: Conv2d, optional 3x3/stride-2 max-pool, then ReLU.

    The conv is always at index 0 of the returned Sequential, so checkpoint
    keys stay ``layerN.0.weight`` / ``layerN.0.bias``.
    """
    modules = [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)]
    if pool:
        modules.append(nn.MaxPool2d(kernel_size=3, stride=2))
    modules.append(nn.ReLU(inplace=True))
    return nn.Sequential(*modules)


# Custom AlexNet model, taken over from an earlier notebook
class AlexNet(nn.Module):
    """Five conv stages followed by three linear layers (fc1, fc2, fc3).

    For the 64x64 inputs used in this notebook the conv stack ends at a
    256 x 1 x 1 feature map, which is flattened into ``fc1``. ``forward``
    applies no non-linearity between ``fc1``, ``fc2`` and ``fc3``; this must
    stay as-is so the pretrained checkpoints keep producing the same outputs.

    Args:
        width_mult: Accepted for compatibility; not used.
    """

    def __init__(self, width_mult=1):
        super(AlexNet, self).__init__()
        # Creation order (layer1..layer5, fc1..fc3) matches the original so
        # seeded initialisation and state_dict key order are unchanged.
        self.layer1 = _alexnet_conv_stage(3, 96, 11, stride=4, padding=2, pool=True)
        self.layer2 = _alexnet_conv_stage(96, 256, 5, padding=2, pool=True)
        self.layer3 = _alexnet_conv_stage(256, 384, 3, padding=1)
        self.layer4 = _alexnet_conv_stage(384, 384, 3, padding=1)
        self.layer5 = _alexnet_conv_stage(384, 256, 3, padding=1, pool=True)
        self.fc1 = nn.Linear(256 * 1 * 1, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, 1000)

    def forward(self, x):
        """Return the ``fc3`` logits for a batch of images."""
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            x = stage(x)
        x = x.view(-1, 256 * 1 * 1)
        for linear in (self.fc1, self.fc2, self.fc3):
            x = linear(x)
        return x

# Load AlexNet Model
def load_model(model_path):
    """Load the trained AlexNet checkpoint, move it to ``device`` and freeze it.

    Every parameter except those of ``fc3`` gets ``requires_grad=False``,
    and the model is switched to eval mode.

    Args:
        model_path: Path to a saved ``state_dict`` for ``AlexNet``.

    Returns:
        The loaded ``AlexNet`` on ``device``, in eval mode.
    """
    print(f"Loading model from {model_path}")
    model = AlexNet()
    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
    # weights_only: the file is a plain state_dict of tensors, so there is no
    # need to unpickle arbitrary objects (also silences torch's FutureWarning).
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)

    # Freeze all layers up to (and including) fc2
    for name, param in model.named_parameters():
        if "fc3" not in name:  # Freeze all layers except fc3
            param.requires_grad = False

    # Set the model to evaluation mode
    model.eval()
    print("Model loaded and layers up to fc2 are frozen")
    return model

# Disabled: the original path-list ImageDataset and its `preprocess`
# pipeline, kept inside a string literal and therefore never executed.
'''
# Define Image Dataset and Preprocessing original
class ImageDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path)
        if self.transform:
            image = self.transform(image)
        return image

# Preprocessing function
preprocess = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
'''

class ImageDataset(Dataset):
    """Dataset over every entry of one directory, all sharing a single label.

    Each item is ``(image_tensor, label, filename)``. The image is converted to
    RGB, resized to 64x64, turned into a tensor and normalized with the
    ImageNet mean/std.

    Args:
        path: Directory whose entries are loaded as images.
        is_two: Label assigned to every image in the directory.
    """

    def __init__(self, path: str, is_two: int):
        self.resize_shape = (64, 64)
        self.transform = transforms.Compose([
            transforms.Resize(self.resize_shape),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.path = path
        # os.listdir order is arbitrary and filesystem-dependent; sort so item
        # order (and anything saved per item) is reproducible across runs.
        self.data_files = sorted(os.listdir(self.path))
        self.labels = [is_two] * len(self.data_files)

    def __getitem__(self, i):
        img_path = os.path.join(self.path, self.data_files[i])
        img = Image.open(img_path).convert("RGB")
        img = self.transform(img)
        label = self.labels[i]
        return img, label, self.data_files[i]  # Return the filename as a string

    def __len__(self):
        return len(self.data_files)

class MnistDataset(Dataset):
    """Dataset over an explicit list of image paths that all share one label.

    Each item is ``(image_tensor, label, basename)``. Images are converted to
    RGB, resized to 64x64, made into tensors and ImageNet-normalized.

    Args:
        file_paths: Image file paths; the list is stored as given (not copied).
        is_two: Label assigned to every image.
    """

    def __init__(self, file_paths: list, is_two: int):
        self.resize_shape = (64, 64)
        pipeline = [
            transforms.Resize(self.resize_shape),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self.transform = transforms.Compose(pipeline)
        self.data_files = file_paths
        self.labels = [is_two for _ in self.data_files]

    def __len__(self):
        return len(self.data_files)

    def __getitem__(self, i):
        file_path = self.data_files[i]
        tensor = self.transform(Image.open(file_path).convert("RGB"))
        return tensor, self.labels[i], os.path.basename(file_path)



In [None]:
# Sparse autoencoder definition, taken over from an earlier notebook
class SparseAutoencoder(nn.Module):
    """Single-hidden-layer autoencoder with a ReLU after both encoder and decoder.

    The ReLU keeps the hidden code non-negative. Submodule layout
    (``encoder.0`` / ``decoder.0`` linear layers) matches saved checkpoints.

    Args:
        in_dims: Size of the input (and reconstruction) vectors.
        h_dims: Size of the hidden code.
    """

    def __init__(self, in_dims, h_dims):
        super(SparseAutoencoder, self).__init__()
        encoder_layers = [nn.Linear(in_dims, h_dims), nn.ReLU()]
        decoder_layers = [nn.Linear(h_dims, in_dims), nn.ReLU()]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return ``(code, reconstruction)`` for input ``x``."""
        code = self.encoder(x)
        return code, self.decoder(code)


# Load the pre-trained autoencoder for layer 6 (fc2)
def load_autoencoder(
    device,
    sae_path='/content/drive/MyDrive/Masterthesis/Datasets/mnist/activations/background/Autoencoders/bg_autoencoder_layer_6.pth',
    input_dims=4096,
    encoding_dim=8192,
):
    """Load the trained fc2 SparseAutoencoder, freeze it and set eval mode.

    Args:
        device: Torch device to place the autoencoder on.
        sae_path: Path to the saved ``state_dict`` (defaults to the background SAE for fc2).
        input_dims: Input size of the SAE (fc2 width).
        encoding_dim: Hidden (sparse) code size.

    Returns:
        The frozen ``SparseAutoencoder`` on *device*, in eval mode.
    """
    # Initialize the autoencoder
    autoencoder = SparseAutoencoder(input_dims, encoding_dim).to(device)
    # map_location lets a GPU-saved checkpoint load on a CPU-only runtime;
    # weights_only avoids unpickling arbitrary objects for a plain state_dict.
    autoencoder.load_state_dict(torch.load(sae_path, map_location=device, weights_only=True))

    # Freeze all parameters of the autoencoder
    for param in autoencoder.parameters():
        param.requires_grad = False

    # Set the autoencoder to evaluation mode
    autoencoder.eval()
    print("Autoencoder loaded and frozen successfully")
    return autoencoder


## Main Pipeline
Load the saved SAE, extract the AlexNet fc2 activations and project them into the SAE's sparse space. Then decode them twice, once with the selected sparse neurons muted and once without muting, and compare the classifications for the worst group: `two_with_patch`.

In [None]:
# Import necessary libraries
import numpy as np
import pandas as pd
import os
import gc
import torch
from scipy.stats import ttest_ind
from pathlib import Path
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score
import matplotlib.pyplot as plt

# Compute device used by every model call in the pipeline below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Root of the MNIST experiment tree on Google Drive and its derived folders.
base_dir = "/content/drive/MyDrive/Masterthesis/Datasets/mnist"
activation_dir, output_base_dir = (
    os.path.join(base_dir, sub_dir) for sub_dir in ("activations", "outputs")
)
# Results are written below output_base_dir, so create it up front.
Path(output_base_dir).mkdir(parents=True, exist_ok=True)

# Location of a saved activation array
def get_activation_path(folder_name, filename):
    """Return ``<activation_dir>/<folder_name>/<filename>.npy``."""
    npy_name = f"{filename}.npy"
    return os.path.join(activation_dir, folder_name, npy_name)

# Extract activations for fc2 (layer 6)
def extract_fc2_activations(model, dataloader, log_every=100):
    """Run images through AlexNet up to ``fc2`` and stack the activations.

    Args:
        model: AlexNet-style module with ``layer1``..``layer5``, ``fc1`` and ``fc2``.
        dataloader: Iterable of ``(images, labels, filenames)`` batches.
        log_every: Print progress each time another *log_every* images have
            been processed (0 or None disables progress lines).

    Returns:
        NumPy array with one row per image. When *dataloader* is empty, an
        empty ``(0, model.fc2.out_features)`` float32 array.
    """
    print("Extracting Alexnet activations for layer fc2...")
    # Send inputs to wherever the model's weights live (for models from
    # load_model this is the global `device`).
    try:
        model_device = next(model.parameters()).device
    except StopIteration:  # parameter-free model
        model_device = torch.device("cpu")

    activations = []
    num_images = 0
    with torch.no_grad():
        for image_tensor, _label, _filename in dataloader:
            tensor = image_tensor.to(model_device)
            for stage in (model.layer1, model.layer2, model.layer3, model.layer4, model.layer5):
                tensor = stage(tensor)
            tensor = tensor.view(-1, 256 * 1 * 1)
            tensor = model.fc2(model.fc1(tensor))
            activations.append(tensor.cpu().numpy())

            # Count images rather than batches so progress is right for any
            # batch size; only log when a log_every boundary is crossed.
            previous = num_images
            num_images += tensor.shape[0]
            if log_every and num_images // log_every > previous // log_every:
                print(f"Processed {num_images} images")

    if not activations:
        return np.empty((0, model.fc2.out_features), dtype=np.float32)
    print(f"Extracted fc2 activations for {num_images} images")
    return np.vstack(activations)

# Disabled: earlier cache-aware version of load_or_extract_fc2_activations,
# kept inside a string literal; the active definition follows below.
'''
# Function to load activations if they exist or extract and save them if not
def load_or_extract_fc2_activations(model, dataloader, folder_name, filename):
    activation_path = get_activation_path(folder_name, filename)
    if os.path.exists(activation_path):
        print(f"Loading pre-saved Alexnet activations for {filename} from {activation_path}...")
        activations = np.load(activation_path, allow_pickle=True)
    else:
        print(f"No pre-saved Alexnet activations found for {filename}. Extracting and saving...")
        activations = extract_fc2_activations(model, dataloader)
        os.makedirs(os.path.dirname(activation_path), exist_ok=True)
        np.save(activation_path, activations)
        print(f"Activations for layer fc2 saved to {activation_path}")
    return activations
'''

# Function to extract (and save) fc2 activations, optionally reusing a saved file
def load_or_extract_fc2_activations(model, dataloader, folder_name, filename, use_cache=False):
    """Return fc2 activations for *dataloader*, saving them as a ``.npy`` file.

    By default the activations are always re-extracted, and the file at
    ``get_activation_path(folder_name, filename)`` is overwritten. Pass
    ``use_cache=True`` to load that file instead when it already exists.

    Args:
        model: AlexNet used for extraction.
        dataloader: Batches of ``(images, labels, filenames)``.
        folder_name: Sub-folder of the activation directory.
        filename: File stem (without ``.npy``).
        use_cache: Reuse an existing saved array instead of re-extracting.

    Returns:
        NumPy array of fc2 activations, one row per image.
    """
    activation_path = get_activation_path(folder_name, filename)
    if use_cache and os.path.exists(activation_path):
        print(f"Loading pre-saved Alexnet activations for {filename} from {activation_path}...")
        # A plain numeric array: no pickle support is needed to read it back.
        return np.load(activation_path, allow_pickle=False)

    # The previous message claimed no saved file existed without checking.
    print(f"Extracting and saving Alexnet activations for {filename}...")
    activations = extract_fc2_activations(model, dataloader)
    os.makedirs(os.path.dirname(activation_path), exist_ok=True)
    np.save(activation_path, activations)
    print(f"Activations for layer fc2 saved to {activation_path}")
    return activations



# Project activations into sparse space
def project_activations(autoencoder, activations, device):
    """Encode activations with the SAE encoder and return the codes as NumPy.

    Args:
        autoencoder: Module exposing an ``encoder`` submodule.
        activations: NumPy array of activations, one row per sample.
        device: Torch device on which the encoder is run.

    Returns:
        NumPy array of encoder outputs, moved back to the CPU.
    """
    print("Projecting Alexnet activations into SAE sparse space...")
    inputs = torch.from_numpy(activations).to(device).float()
    with torch.no_grad():
        codes = autoencoder.encoder(inputs)
    return codes.cpu().numpy()

# Function 1: Calculate neuron activations per image and overall average for patched/unpatched sets
def calculate_neuron_activations(autoencoder, activations, folder_name, patch_status):
    """Project activations into SAE space, save them, and return per-neuron means.

    Two CSV files are written to *folder_name*:
    ``bg_<patch_status>_individual_neuron_activations.csv`` (one row per input
    row, one column per SAE neuron) and
    ``bg_<patch_status>_average_neuron_activations.csv`` (the column means).

    Returns:
        pandas Series with the mean activation of each SAE neuron.
    """
    print(f"Calculating neuron activations for {patch_status} images...")
    per_image = pd.DataFrame(project_activations(autoencoder, activations, device))

    individual_path = os.path.join(folder_name, f"bg_{patch_status}_individual_neuron_activations.csv")
    per_image.to_csv(individual_path, index=False)

    neuron_means = per_image.mean(axis=0)
    avg_path = os.path.join(folder_name, f"bg_{patch_status}_average_neuron_activations.csv")
    neuron_means.to_csv(avg_path, header=["Average Activation"], index_label="Neuron")

    print(f"Saved {patch_status} individual activations to {individual_path} and averages to {avg_path}")
    return neuron_means

# Function 2: Calculate the absolute difference in average activations between patched and unpatched
def calculate_neuron_differences(avg_activations_patch, avg_activations_no_patch, folder_name):
    """Return ``|patch - no_patch|`` per neuron and save it to ``neuron_absolute_differences.csv``.

    Args:
        avg_activations_patch: pandas Series of mean activations (patched images).
        avg_activations_no_patch: pandas Series of mean activations (clean images).
        folder_name: Directory where the CSV is written.
    """
    print("Calculating absolute difference in activations...")
    out_path = os.path.join(folder_name, "neuron_absolute_differences.csv")
    delta = avg_activations_patch - avg_activations_no_patch
    magnitudes = np.abs(delta)
    magnitudes.to_csv(out_path, header=["Absolute Difference"], index_label="Neuron")
    print(f"Saved neuron differences to {out_path}")
    return magnitudes

# Function 3: Identify and save the top neurons (default: top 10%) with the highest difference
def get_top_neurons(abs_diff, folder_name, top_percentage=0.1):
    """Select the neurons with the largest absolute activation difference.

    Args:
        abs_diff: pandas Series of per-neuron absolute differences.
        folder_name: Directory in which the selected neuron indices are saved.
        top_percentage: Fraction of neurons to keep, in (0, 1].

    Returns:
        pandas Index of the selected neurons, largest difference first.

    Raises:
        ValueError: If *top_percentage* is not in (0, 1].
    """
    if not 0 < top_percentage <= 1:
        raise ValueError(f"top_percentage must be in (0, 1], got {top_percentage}")
    top_neuron_count = int(len(abs_diff) * top_percentage)
    top_neurons = abs_diff.nlargest(top_neuron_count).index
    # Name the file and message after the actual percentage (0.1 -> "10")
    # instead of always claiming "top 10%".
    percent_label = f"{top_percentage * 100:g}"
    top_neuron_path = os.path.join(folder_name, f"bg_top_{percent_label}_percent_neurons.csv")
    pd.DataFrame(top_neurons, columns=["Neuron"]).to_csv(top_neuron_path, index=False)
    print(f"Saved top {percent_label}% neurons with highest differences to {top_neuron_path}")
    return top_neurons

# Function 4: Mute the top neurons in sparse space and classify patched images

def classify_with_muted_neurons(autoencoder, model, activations_patch, top_neurons):
    """Zero the given SAE neurons, decode back to fc2 space and classify with ``fc3``.

    Args:
        autoencoder: Loaded SAE with ``encoder``/``decoder`` submodules.
        model: AlexNet whose ``fc3`` layer produces the class logits.
        activations_patch: NumPy array of fc2 activations, one row per image.
        top_neurons: Indices (e.g. a pandas Index) of SAE neurons to mute.

    Returns:
        list[int]: Predicted class index for each row of *activations_patch*.
    """
    print("Muting top neurons and classifying patched images...")
    projected_patch = project_activations(autoencoder, activations_patch, device)
    print("Applying muting in sparse space...")
    # Integer dtype so an empty selection is still a valid index.
    projected_patch[:, np.asarray(top_neurons, dtype=int)] = 0  # Mute selected neurons

    with torch.no_grad():
        sparse = torch.from_numpy(projected_patch).to(device).float()
        # Decode and classify the whole batch at once. argmax of the logits
        # equals argmax of softmax(logits) because softmax is monotonic.
        logits = model.fc3(autoencoder.decoder(sparse))
        predictions = torch.argmax(logits, dim=1)
    return predictions.cpu().tolist()

# Function 5: Classify SAE-reconstructed activations without muting any neurons
def classify_without_muting(autoencoder, model, activations_no_patch):
    """Encode, decode (no muting) and classify activations with ``fc3``.

    Serves as the unmuted baseline for ``classify_with_muted_neurons``. In
    this notebook's main() it is called with the patched activations.

    Args:
        autoencoder: Loaded SAE with ``encoder``/``decoder`` submodules.
        model: AlexNet whose ``fc3`` layer produces the class logits.
        activations_no_patch: NumPy array of fc2 activations, one row per image.

    Returns:
        list[int]: Predicted class index for each row.
    """
    # Neutral wording: the function is not limited to non-patched images.
    print("Classifying images without muting neurons...")
    projected = project_activations(autoencoder, activations_no_patch, device)

    with torch.no_grad():
        sparse = torch.from_numpy(projected).to(device).float()
        # Batched decode + fc3; argmax of logits == argmax of softmax(logits).
        logits = model.fc3(autoencoder.decoder(sparse))
        predictions = torch.argmax(logits, dim=1)
    return predictions.cpu().tolist()

# Function 6: Save decoded neuron activations per image
def save_decoded_activations(decoded_activations, patch_status, folder_name):
    """Write the per-image decoded activations to ``bg_decoded_activations_<patch_status>.csv``."""
    frame = pd.DataFrame(decoded_activations)
    target = os.path.join(folder_name, f"bg_decoded_activations_{patch_status}.csv")
    frame.to_csv(target, index=False)
    print(f"Saved decoded neuron activations for {patch_status} images to {target}")

# Function 7: Save average neuron activations across all images
def save_average_activations(decoded_activations, patch_status, folder_name):
    """Write the column means of *decoded_activations* to ``bg_average_decoded_activations_<patch_status>.csv``.

    The CSV has a ``Neuron`` index column and an ``Average Activation`` column.
    """
    target = os.path.join(folder_name, f"bg_average_decoded_activations_{patch_status}.csv")
    column_means = np.mean(decoded_activations, axis=0)
    table = pd.DataFrame(column_means, columns=["Average Activation"])
    table.to_csv(target, index_label="Neuron")
    print(f"Saved average decoded activations for {patch_status} images to {target}")

# Function 8: Evaluate the effect of muting neurons
def evaluate_muting_effect(predictions_with_muting, predictions_without_muting):
    """Print and return how often muted and unmuted predictions agree.

    Args:
        predictions_with_muting: Predicted class per image with neurons muted.
        predictions_without_muting: Predicted class per image without muting.

    Returns:
        float: Agreement in percent, or ``nan`` when there are no predictions.

    Raises:
        ValueError: If the two prediction sequences have different lengths
            (zip would otherwise silently drop the extra items).
    """
    with_muting = list(predictions_with_muting)
    without_muting = list(predictions_without_muting)
    if len(with_muting) != len(without_muting):
        raise ValueError(
            f"Prediction lists differ in length: {len(with_muting)} vs {len(without_muting)}"
        )
    if not with_muting:
        print("No predictions to compare; agreement is undefined.")
        return float("nan")

    agreement_count = sum(pw == pn for pw, pn in zip(with_muting, without_muting))
    accuracy = agreement_count / len(with_muting) * 100
    print(f"Accuracy of classifications with muted neurons matching non-muted classifications: {accuracy:.2f}%")
    return accuracy

# Function 9: Calculate and display classification metrics
def evaluate_classification_metrics(predictions_with_muting, predictions_without_muting, labels, target_class=1):
    """Print one-vs-rest accuracy, precision and recall for *target_class*.

    Labels and both prediction lists are binarized (1 if equal to
    *target_class*, else 0) before scoring. ``zero_division=0`` makes sklearn
    report 0 instead of raising ``UndefinedMetricWarning`` when, for example,
    no image is predicted as the target class.

    Args:
        predictions_with_muting: Predicted class per image with neurons muted.
        predictions_without_muting: Predicted class per image without muting.
        labels: True class per image.
        target_class: Class treated as positive (1 is used for 'two' in main()).

    Returns:
        dict: ``{"with_muting": {...}, "without_muting": {...}}``. Each inner
        dict has the keys ``accuracy``, ``precision`` and ``recall``.
    """
    labels_target = [1 if label == target_class else 0 for label in labels]

    def _scores(predictions):
        # Binarize predictions the same way as the labels, then score them.
        preds_target = [1 if pred == target_class else 0 for pred in predictions]
        return {
            "accuracy": accuracy_score(labels_target, preds_target),
            "precision": precision_score(labels_target, preds_target, zero_division=0),
            "recall": recall_score(labels_target, preds_target, zero_division=0),
        }

    results = {
        "with_muting": _scores(predictions_with_muting),
        "without_muting": _scores(predictions_without_muting),
    }
    headings = (
        ("with_muting", "Metrics for 'two with patch' class with muting:"),
        ("without_muting", "\nMetrics for 'two with patch' class without muting:"),
    )
    for key, heading in headings:
        scores = results[key]
        print(heading)
        print(f"  Accuracy: {scores['accuracy']:.2f}")
        print(f"  Precision: {scores['precision']:.2f}")
        print(f"  Recall: {scores['recall']:.2f}")
    return results

# Function 10: Visualize differences for top neurons with binning
def visualize_binned_neuron_differences(abs_diff, top_neurons, bin_width=0.05):
    """Bar-plot how many of *top_neurons* fall into each difference bin.

    Args:
        abs_diff: pandas Series of per-neuron absolute differences.
        top_neurons: Index labels (into *abs_diff*) of the neurons to plot.
        bin_width: Width of each bin; must be positive.

    Raises:
        ValueError: If *bin_width* is not positive.
    """
    if bin_width <= 0:
        raise ValueError(f"bin_width must be positive, got {bin_width}")

    # Get the differences for the top neurons and sort them
    top_neuron_diffs = abs_diff.loc[top_neurons].sort_values(ascending=False)
    if top_neuron_diffs.empty:
        print("No top neurons to visualize.")
        return

    # Bin the difference values
    max_diff = top_neuron_diffs.max()
    bins = np.arange(0, max_diff + bin_width, bin_width)
    if len(bins) < 2:  # every difference is 0: still need one [0, bin_width] bin
        bins = np.array([0.0, bin_width])
    # include_lowest=True keeps differences of exactly 0 in the first bin;
    # pd.cut's default right-closed bins would turn them into NaN and drop them.
    binned_counts = pd.cut(top_neuron_diffs, bins=bins, include_lowest=True).value_counts(sort=False)

    # Plot the binned counts
    plt.figure(figsize=(10, 6))
    binned_counts.plot(kind='bar', color='skyblue')
    plt.xlabel("Difference Value Bins")
    plt.ylabel("Neuron Count")
    plt.title("Neuron Count in Each Difference Value Bin")

    # Rotate x-axis labels for readability
    plt.xticks(rotation=45, ha='right')

    # Display grid for easier comparison
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.show()

# Function 11: Conduct t-tests in sparse and decoded space
def t_test_sparse_vs_non_sparse(sparse_with_patch, sparse_no_patch):
    """Run Welch's t-test between patched and non-patched sparse activations.

    Returns:
        tuple: ``(t_statistic, p_value)`` as returned by ``scipy.stats.ttest_ind``.
    """
    stat, pval = ttest_ind(sparse_with_patch, sparse_no_patch, equal_var=False)
    report = f"Sparse Activations T-Test:\n  T-statistic: {stat}, P-value: {pval}"
    print(report)
    return stat, pval

def t_test_decoded_muted_vs_non_muted(decoded_with_patch_muted, decoded_with_patch_non_muted):
    """Run Welch's t-test between muted and non-muted decoded activations.

    Returns:
        tuple: ``(t_statistic, p_value)`` as returned by ``scipy.stats.ttest_ind``.
    """
    stat, pval = ttest_ind(decoded_with_patch_muted, decoded_with_patch_non_muted, equal_var=False)
    report = f"Decoded Activations T-Test:\n  T-statistic: {stat}, P-value: {pval}"
    print(report)
    return stat, pval

# Function 12: Classify decoded activations
def classify_decoded_activations(model, decoded_activations):
    """Classify decoded fc2 activations with the model's final ``fc3`` layer.

    Args:
        model: Module exposing an ``fc3`` linear layer (the AlexNet classifier head).
        decoded_activations: 2-D array-like (or list of 1-D arrays), one row of
            fc2-sized activations per image.

    Returns:
        list[int]: Index of the largest logit for each row ([] for empty input).
    """
    batch = np.asarray(decoded_activations, dtype=np.float32)
    if len(batch) == 0:
        return []
    # Run on whatever device holds the classifier head's weights.
    head_device = model.fc3.weight.device
    with torch.no_grad():
        logits = model.fc3(torch.from_numpy(batch).to(head_device))
    # Batched instead of per-row; softmax is monotonic, so argmax over the
    # logits gives the same class as argmax over softmax(logits).
    return torch.argmax(logits, dim=1).cpu().tolist()


# Function 13: Perform and print formatted t-tests for layer activations
def perform_and_format_t_tests(condition1_activations, condition2_activations, layer_name="Layer",
                               condition1_label="Muted", condition2_label="Non-Muted"):
    """Run per-neuron Welch t-tests between two activation matrices and print a summary.

    Args:
        condition1_activations: 2-D array (samples x neurons) for condition 1.
        condition2_activations: 2-D array (samples x neurons) for condition 2,
            with the same number of columns.
        layer_name: Heading printed above the summary.
        condition1_label: Name shown for condition 1 (default "Muted").
        condition2_label: Name shown for condition 2 (default "Non-Muted").

    Returns:
        dict with per-neuron ``t_statistic``, ``p_values`` and
        ``adjusted_p_values`` (Bonferroni), plus ``num_undefined``: the number
        of neurons whose p-value is NaN. This happens, for example, for a
        neuron with the same constant value in both conditions.
    """
    condition1_activations = np.asarray(condition1_activations)
    condition2_activations = np.asarray(condition2_activations)

    # Two-sample Welch t-test, one test per neuron (column)
    t_stat, p_values = ttest_ind(condition1_activations, condition2_activations, axis=0, equal_var=False)
    p_values = np.asarray(p_values, dtype=float)

    # Bonferroni correction
    num_neurons = condition1_activations.shape[1]
    adjusted_p_values = np.minimum(p_values * num_neurons, 1.0)

    # Calculate mean activations for each neuron
    condition1_mean = np.mean(condition1_activations, axis=0)
    condition2_mean = np.mean(condition2_activations, axis=0)

    # Percentages are over all neurons; NaN p-values never count as significant.
    raw_significant_0_05 = np.mean(p_values <= 0.05) * 100
    raw_significant_0_02 = np.mean(p_values <= 0.02) * 100
    corrected_significant_0_05 = np.mean(adjusted_p_values <= 0.05) * 100
    corrected_significant_0_02 = np.mean(adjusted_p_values <= 0.02) * 100
    num_undefined = int(np.isnan(p_values).sum())

    print(f"{layer_name}:")
    print(f"  Condition 1 Mean Activation ({condition1_label}):")
    print(f"    Mean across neurons: {condition1_mean.mean():.4f}")
    print(f"  Condition 2 Mean Activation ({condition2_label}):")
    print(f"    Mean across neurons: {condition2_mean.mean():.4f}")
    print("  T-Test (before Bonferroni correction):")
    print(f"    Percentage of neurons with raw p-value <= 0.05: {raw_significant_0_05:.2f}%")
    print(f"    Percentage of neurons with raw p-value <= 0.02: {raw_significant_0_02:.2f}%")
    print("  T-Test (after Bonferroni correction):")
    print(f"    Percentage of neurons with adjusted p-value <= 0.05: {corrected_significant_0_05:.2f}%")
    print(f"    Percentage of neurons with adjusted p-value <= 0.02: {corrected_significant_0_02:.2f}%")
    if num_undefined:
        print(f"  Neurons with undefined (NaN) p-value: {num_undefined}")
    print("-" * 50)

    return {
        "t_statistic": t_stat,
        "p_values": p_values,
        "adjusted_p_values": adjusted_p_values,
        "num_undefined": num_undefined,
    }


# Disabled: end-to-end AlexNet classification helper, kept inside a string
# literal and therefore never defined.
'''
def classify_with_alexnet(model, dataloader):
    print("Classifying test images using the loaded AlexNet model...")
    predictions = []
    labels = []

    with torch.no_grad():
        for images, label in dataloader:
            images = images.to(device)
            label = label.to(device)

            # Forward pass through AlexNet
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)

            predictions.extend(predicted.cpu().numpy())
            labels.extend(label.cpu().numpy())

    # Calculate accuracy
    accuracy = accuracy_score(labels, predictions)
    print(f"Accuracy of the AlexNet model on the test set (with patch): {accuracy:.2f}")
    return accuracy
'''



'\ndef classify_with_alexnet(model, dataloader):\n    print("Classifying test images using the loaded AlexNet model...")\n    predictions = []\n    labels = []\n\n    with torch.no_grad():\n        for images, label in dataloader:\n            images = images.to(device)\n            label = label.to(device)\n\n            # Forward pass through AlexNet\n            outputs = model(images)\n            _, predicted = torch.max(outputs, 1)\n\n            predictions.extend(predicted.cpu().numpy())\n            labels.extend(label.cpu().numpy())\n\n    # Calculate accuracy\n    accuracy = accuracy_score(labels, predictions)\n    print(f"Accuracy of the AlexNet model on the test set (with patch): {accuracy:.2f}")\n    return accuracy\n'

In [None]:
# Main function and its helpers
def _list_image_paths(folder):
    """Return the sorted paths of all .jpg/.png files under *folder* (recursively)."""
    # Sorting makes image order, and so the row order of every saved
    # activation file, independent of the filesystem's os.walk order.
    return sorted(
        os.path.join(root, file)
        for root, _dirs, files in os.walk(folder)
        for file in files
        if file.endswith(('.jpg', '.png'))
    )


def _decode_sparse(autoencoder, sparse_codes):
    """Decode a NumPy array of SAE codes back to fc2 space, returned as NumPy on CPU."""
    with torch.no_grad():
        codes = torch.from_numpy(sparse_codes).to(device).float()
        return autoencoder.decoder(codes).cpu().numpy()


def main():
    """Compare AlexNet predictions on 'two' images with and without SAE neuron muting.

    Loads the classifier and the fc2 SAE, and extracts fc2 activations for the
    background-patched and original class-2 test images. It then selects the
    SAE neurons whose mean activation differs most between the two groups,
    classifies the decoded activations with and without those neurons muted,
    and prints the resulting accuracies.
    """
    # Paths and initialization
    model_path = "/content/drive/MyDrive/Masterthesis/Datasets/mnist/models/initial_classifier/alexnet_mnist_bg_cl0_cl2_1train.pt"
    patch_folder = '/content/drive/MyDrive/Masterthesis/Datasets/mnist/dataset_splits/background/test/class_2'
    no_patch_folder = '/content/drive/MyDrive/Masterthesis/Datasets/mnist/dataset_splits/original/test/class_2'

    # Load model and autoencoder
    model = load_model(model_path)
    autoencoder = load_autoencoder(device)

    # Prepare dataloaders (every image is a 'two', encoded as label 1)
    patch_dataset = MnistDataset(_list_image_paths(patch_folder), is_two=1)
    no_patch_dataset = MnistDataset(_list_image_paths(no_patch_folder), is_two=1)
    patch_loader = DataLoader(patch_dataset, batch_size=1, shuffle=False)
    no_patch_loader = DataLoader(no_patch_dataset, batch_size=1, shuffle=False)

    # Extract (and save) fc2 activations
    activations_patch = load_or_extract_fc2_activations(model, patch_loader, 'test_patch', 'bg_fc2_activations_patch')
    activations_no_patch = load_or_extract_fc2_activations(model, no_patch_loader, 'test_no_patch', 'bg_fc2_activations_no_patch')

    # Directory for saving results
    sparse_output_dir = os.path.join(output_base_dir, "bg_fc2_sparse_outputs")
    Path(sparse_output_dir).mkdir(parents=True, exist_ok=True)

    # Per-neuron mean SAE activations for each group
    avg_activations_patch = calculate_neuron_activations(autoencoder, activations_patch, sparse_output_dir, "patch")
    avg_activations_no_patch = calculate_neuron_activations(autoencoder, activations_no_patch, sparse_output_dir, "no_patch")

    # Neurons whose mean activation differs most between the groups
    abs_diff = calculate_neuron_differences(avg_activations_patch, avg_activations_no_patch, sparse_output_dir)
    top_neurons = get_top_neurons(abs_diff, sparse_output_dir)

    # Classify 'two with patch' with and without muting, then compare
    predictions_patch_with_muting = classify_with_muted_neurons(autoencoder, model, activations_patch, top_neurons)
    predictions_patch_without_muting = classify_without_muting(autoencoder, model, activations_patch)
    evaluate_muting_effect(predictions_patch_with_muting, predictions_patch_without_muting)

    # Evaluate classification metrics (all images are class 2, i.e. label 1)
    labels = [1] * len(predictions_patch_with_muting)
    evaluate_classification_metrics(predictions_patch_with_muting, predictions_patch_without_muting, labels)

    # Project both groups into sparse space once. Mute a copy of the patched
    # codes so the unmuted patched codes stay available for the baseline.
    projected_patch = project_activations(autoencoder, activations_patch, device)
    projected_no_patch = project_activations(autoencoder, activations_no_patch, device)
    projected_patch_muted = projected_patch.copy()
    projected_patch_muted[:, np.asarray(top_neurons, dtype=int)] = 0

    # Previously 'decoded_patch_non_muted' decoded the NO-patch codes, so the
    # "two with patch, no muting" accuracy was really the no-patch accuracy.
    decoded_patch_muted = _decode_sparse(autoencoder, projected_patch_muted)
    decoded_patch_non_muted = _decode_sparse(autoencoder, projected_patch)
    decoded_no_patch_non_muted = _decode_sparse(autoencoder, projected_no_patch)

    # Optional diagnostics (disabled):
    # visualize_binned_neuron_differences(abs_diff, top_neurons, bin_width=0.05)
    # perform_and_format_t_tests(projected_patch, projected_no_patch, layer_name="Sparse Activations")
    # perform_and_format_t_tests(decoded_patch_muted, decoded_patch_non_muted, layer_name="Decoded Activations")

    # Classify decoded activations for images with no patch in sparse space without muting
    print("\nClassifying decoded activations for 'two with no patch' after projecting into sparse space and decoding without muting...")
    predictions_no_patch_non_muted_decoded = classify_decoded_activations(model, decoded_no_patch_non_muted)

    # Classify decoded activations for images with patch in sparse space with muting
    print("\nClassifying decoded activations for 'two with patch' after projecting into sparse space and decoding with muting...")
    predictions_patch_decoded_muting = classify_decoded_activations(model, decoded_patch_muted)

    # Classify decoded activations for images with patch in sparse space without muting
    print("\nClassifying decoded activations for 'two with patch' after projecting into sparse space and decoding without muting...")
    predictions_patch_decoded_non_muting = classify_decoded_activations(model, decoded_patch_non_muted)

    # Accuracies; every image in both folders is a 'two' (label 1)
    accuracy_no_patch_non_muting_decoded = accuracy_score(
        [1] * len(predictions_no_patch_non_muted_decoded), predictions_no_patch_non_muted_decoded)
    print(f"Accuracy of 'two with no patch' decoded activations after sparse projection and no muting: {accuracy_no_patch_non_muting_decoded:.5f}")

    accuracy_patch_decoded_non_muting = accuracy_score(
        [1] * len(predictions_patch_decoded_non_muting), predictions_patch_decoded_non_muting)
    print(f"Accuracy of 'two with patch' decoded activations after sparse projection and no muting: {accuracy_patch_decoded_non_muting:.5f}")

    accuracy_patch_decoded_muting = accuracy_score(
        [1] * len(predictions_patch_decoded_muting), predictions_patch_decoded_muting)
    print(f"Accuracy of 'two with patch' decoded activations after sparse projection and muting: {accuracy_patch_decoded_muting:.5f}")

    print("Top neurons and their differences:", abs_diff.loc[top_neurons])


# Disabled: snippet that classified the patched images with the full AlexNet
# via classify_with_alexnet. It is kept inside a module-level string literal,
# so it never executes.
'''

# Prepare the dataloader for patched images
    patch_image_paths = [os.path.join(root, file)
                        for root, dirs, files in os.walk(patch_folder)
                        for file in files if file.endswith(('.jpg', '.png'))]
    patch_dataset = ImageDataset(patch_image_paths, transform=preprocess)
    patch_loader = DataLoader(patch_dataset, batch_size=1, shuffle=False)

    # Wrap the dataloader with a fixed label of 1
    labeled_patch_loader = LabeledDataLoaderWrapper(patch_loader, label=0)

    # Classify using the original model
    classify_with_alexnet(model, labeled_patch_loader)
'''




# Run the whole muting pipeline when this cell executes.
main()

Loading model from /content/drive/MyDrive/Masterthesis/Datasets/mnist/models/initial_classifier/alexnet_mnist_bg_cl0_cl2_1train.pt


  model.load_state_dict(torch.load(model_path))


Model loaded and layers up to fc2 are frozen


  autoencoder.load_state_dict(torch.load(save_sae_dir))


Autoencoder loaded and frozen successfully
No pre-saved Alexnet activations found for bg_fc2_activations_patch. Extracting and saving...
Extracting Alexnet activations for layer fc2...
Processed 1 images
Processed 2 images
Processed 3 images
Processed 4 images
Processed 5 images
Processed 6 images
Processed 7 images
Processed 8 images
Processed 9 images
Processed 10 images
Processed 11 images
Processed 12 images
Processed 13 images
Processed 14 images
Processed 15 images
Processed 16 images
Processed 17 images
Processed 18 images
Processed 19 images
Processed 20 images
Processed 21 images
Processed 22 images
Processed 23 images
Processed 24 images
Processed 25 images
Processed 26 images
Processed 27 images
Processed 28 images
Processed 29 images
Processed 30 images
Processed 31 images
Processed 32 images
Processed 33 images
Processed 34 images
Processed 35 images
Processed 36 images
Processed 37 images
Processed 38 images
Processed 39 images
Processed 40 images
Processed 41 images
Proc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Classifying decoded activations for 'two with no patch' after projecting into sparse space and decoding without muting...

Classifying decoded activations for 'two with patch' after projecting into sparse space and decoding with muting...

Classifying decoded activations for 'two with patch' after projecting into sparse space and decoding without muting...
Accuracy of 'two with no patch' decoded activations after sparse projection and no muting: 0.99419
Accuracy of 'two with patch' decoded activations after sparse projection and no muting: 0.99419
Accuracy of 'two with patch' decoded activations after sparse projection and muting: 1.00000
Top neurons and their differences: 7216    1.563951
7527    1.414505
7167    1.397315
6756    1.287517
8134    1.228547
          ...   
5465    0.358660
4933    0.358117
3538    0.357912
3323    0.357901
3243    0.357396
Length: 819, dtype: float32


In [None]:

def identify_patch_specific_neurons(avg_activations_patch, avg_activations_no_patch, patch_threshold=0.2, no_patch_threshold=0.05):
    """
    Return indices of neurons that fire for patched images but stay quiet otherwise.

    A neuron qualifies when its mean activation on patched images is strictly
    above ``patch_threshold`` AND its mean activation on unpatched images is
    strictly below ``no_patch_threshold``. Both thresholds are applied to the
    raw (unnormalized) averages.

    Args:
        avg_activations_patch: per-neuron mean activations over patched images
            (numpy-style array supporting elementwise comparison and ``.mean()``).
        avg_activations_no_patch: per-neuron mean activations over unpatched
            images, aligned index-by-index with ``avg_activations_patch``.
        patch_threshold: lower bound for the patched-image mean.
        no_patch_threshold: upper bound for the unpatched-image mean.

    Returns:
        numpy array of qualifying neuron indices in ascending order.
    """
    print("Identifying neurons selectively activated by patches...")

    fires_on_patch = avg_activations_patch > patch_threshold
    quiet_without_patch = avg_activations_no_patch < no_patch_threshold
    selected = np.nonzero(fires_on_patch & quiet_without_patch)[0]

    # Diagnostics to help sanity-check the chosen thresholds.
    patch_mean = avg_activations_patch.mean()
    no_patch_mean = avg_activations_no_patch.mean()
    print(f"Average activation for patch: {patch_mean:.4f}, No patch: {no_patch_mean:.4f}")
    print(f"Number of neurons with high activation for patch: {fires_on_patch.sum()}")
    print(f"Number of neurons with low activation for no patch: {quiet_without_patch.sum()}")
    print(f"Found {len(selected)} patch-specific neurons.")

    return selected

import pandas as pd

def save_top_neurons_to_csv(abs_diff, top_neurons, folder_name, filename="bg_top_neurons.csv"):
    """
    Write every neuron's activation difference to CSV, flagging the selected ones.

    Args:
        abs_diff: 1-D sequence of per-neuron absolute activation differences;
            position ``i`` is recorded as neuron index ``i``.
        top_neurons: list-like of neuron indices that were selected for muting
            (any selection rule; not necessarily a top-10% cut).
        folder_name: output directory; created (with parents) if missing.
        filename: CSV file name inside ``folder_name``.

    The CSV has columns ``Neuron_Index``, ``Activation_Difference`` and
    ``Selected_for_Muting`` and is sorted by difference, largest first.
    """
    print(f"Saving top neurons to CSV file: {filename}")

    neuron_data = pd.DataFrame({
        "Neuron_Index": range(len(abs_diff)),
        "Activation_Difference": abs_diff
    })

    # Flag the neurons chosen by the caller.
    neuron_data["Selected_for_Muting"] = neuron_data["Neuron_Index"].isin(top_neurons)

    neuron_data.sort_values(by="Activation_Difference", ascending=False, inplace=True)

    # Create the destination so the helper also works on a fresh directory.
    os.makedirs(folder_name, exist_ok=True)
    csv_path = os.path.join(folder_name, filename)
    neuron_data.to_csv(csv_path, index=False)
    print(f"CSV saved at: {csv_path}")

def classify_with_alexnet(model, activations):
    """
    Classify pre-extracted fc2 activations with the model's own fc3 head.

    Args:
        model: network exposing an ``fc3`` layer (e.g. the AlexNet in this notebook).
        activations: array-like of fc2 activation vectors, one row per image.

    Returns:
        List of predicted class indices (Python ints), one per row; ``[]`` for
        empty input.
    """
    if len(activations) == 0:
        return []
    # Run on whichever device fc3 lives on rather than a notebook-global ``device``.
    fc3_device = next(model.fc3.parameters()).device
    batch = torch.as_tensor(np.asarray(activations), dtype=torch.float32, device=fc3_device)
    # One batched, gradient-free forward pass. Softmax is monotonic, so argmax
    # over the logits gives the same prediction as argmax over probabilities.
    with torch.no_grad():
        logits = model.fc3(batch)
    return logits.argmax(dim=1).tolist()




def main():
    """
    Measure how muting SAE neurons affects fc3 predictions on 'two' images.

    Steps:
      1. Load the AlexNet classifier and the sparse autoencoder (SAE).
      2. Load cached fc2 activations for 'two' images with and without the
         background patch and classify them directly with fc3.
      3. Encode both sets into the SAE sparse space and decode them back.
      4. Select the 10% of sparse neurons whose mean activation differs most
         between the two sets.
      5. Classify decoded activations with and without those neurons zeroed.

    Every sample is scored against target label 1. Depends on ``load_model``,
    ``load_autoencoder``, ``project_activations``, ``classify_decoded_activations``
    and ``accuracy_score`` from elsewhere in the notebook, and on the
    notebook-global ``device``.
    """
    model_path = "/content/drive/MyDrive/Masterthesis/Datasets/mnist/models/initial_classifier/alexnet_mnist_bg_cl0_cl2_1train.pt"
    model = load_model(model_path)
    autoencoder = load_autoencoder(device)

    activation_patch_path = "/content/drive/MyDrive/Masterthesis/Datasets/mnist/activations/test_patch/bg_fc2_activations_patch.npy"
    activation_no_patch_path = "/content/drive/MyDrive/Masterthesis/Datasets/mnist/activations/test_no_patch/bg_fc2_activations_no_patch.npy"

    # Output directory for the difference analysis (nothing is written to it here).
    folder_name = "/content/drive/MyDrive/Masterthesis/Datasets/mnist/activations/difference_analysis"
    os.makedirs(folder_name, exist_ok=True)

    def _decode(projected):
        """Decode sparse codes back to fc2 space (inference only, no autograd graph)."""
        with torch.no_grad():
            codes = torch.from_numpy(projected).to(device).float()
            return autoencoder.decoder(codes).cpu().numpy()

    def _accuracy(predictions):
        """Accuracy against the all-'two' target (label 1)."""
        return accuracy_score([1] * len(predictions), predictions)

    def _muted(projected, neurons):
        """Return a copy of ``projected`` with the given neuron columns zeroed."""
        silenced = projected.copy()
        silenced[:, neurons] = 0
        return silenced

    print("Loading pre-saved AlexNet activations for fc2_activations_patch...")
    activations_patch = np.load(activation_patch_path, allow_pickle=True)
    print("Loading pre-saved AlexNet activations for fc2_activations_no_patch...")
    activations_no_patch = np.load(activation_no_patch_path, allow_pickle=True)

    # Baseline: classify the raw fc2 activations directly with fc3.
    accuracy_patch_alexnet = _accuracy(classify_with_alexnet(model, activations_patch))
    accuracy_no_patch_alexnet = _accuracy(classify_with_alexnet(model, activations_no_patch))

    # Project into the sparse space and reconstruct without any muting.
    projected_patch = project_activations(autoencoder, activations_patch, device)
    projected_no_patch = project_activations(autoencoder, activations_no_patch, device)
    decoded_patch = _decode(projected_patch)
    decoded_no_patch = _decode(projected_no_patch)

    # Absolute difference of per-neuron mean sparse activations.
    avg_activations_patch = np.mean(projected_patch, axis=0)
    avg_activations_no_patch = np.mean(projected_no_patch, axis=0)
    abs_diff = np.abs(avg_activations_patch - avg_activations_no_patch)

    # Top 10% by difference. Slicing from len - count (instead of [-count:])
    # gives an empty selection when count == 0 rather than every neuron.
    top_neuron_count = int(len(abs_diff) * 0.1)
    top_neurons = np.argsort(abs_diff)[len(abs_diff) - top_neuron_count:]

    print("Classifying 'two_with_patch' without muting neurons...")
    accuracy_patch_without_muting = _accuracy(classify_decoded_activations(model, decoded_patch))
    # Muting works on a copy, so projected_patch itself stays intact.
    accuracy_patch_with_muting = _accuracy(
        classify_decoded_activations(model, _decode(_muted(projected_patch, top_neurons)))
    )

    print("Classifying 'two_no_patch' without muting neurons...")
    accuracy_no_patch_without_muting = _accuracy(classify_decoded_activations(model, decoded_no_patch))
    accuracy_no_patch_with_muting = _accuracy(
        classify_decoded_activations(model, _decode(_muted(projected_no_patch, top_neurons)))
    )

    print("\nClassification Accuracy Results:")
    print(f"1. Accuracy (two_with_patch using AlexNet directly): {accuracy_patch_alexnet:.4f}")
    print(f"2. Accuracy (two_no_patch using AlexNet directly): {accuracy_no_patch_alexnet:.4f}")
    print(f"3. Accuracy (two_with_patch without muting): {accuracy_patch_without_muting:.4f}")
    print(f"4. Accuracy (two_with_patch with muting): {accuracy_patch_with_muting:.4f}")
    print(f"5. Accuracy (two_no_patch without muting): {accuracy_no_patch_without_muting:.4f}")
    print(f"6. Accuracy (two_no_patch with muting): {accuracy_no_patch_with_muting:.4f}")

if __name__ == "__main__":
    main()



Loading model from /content/drive/MyDrive/Masterthesis/Datasets/mnist/models/initial_classifier/alexnet_mnist_bg_cl0_cl2_1train.pt


  model.load_state_dict(torch.load(model_path))


Model loaded and layers up to fc2 are frozen


  autoencoder.load_state_dict(torch.load(save_sae_dir))


Autoencoder loaded and frozen successfully
Loading pre-saved AlexNet activations for fc2_activations_patch...
Loading pre-saved AlexNet activations for fc2_activations_no_patch...
Projecting Alexnet activations into SAE sparse space...
Projecting Alexnet activations into SAE sparse space...
Classifying 'two_with_patch' without muting neurons...
Classifying 'two_no_patch' without muting neurons...

Classification Accuracy Results:
1. Accuracy (two_with_patch using AlexNet directly): 0.0000
2. Accuracy (two_no_patch using AlexNet directly): 0.9971
3. Accuracy (two_with_patch without muting): 0.0000
4. Accuracy (two_with_patch with muting): 1.0000
5. Accuracy (two_no_patch without muting): 0.9942
6. Accuracy (two_no_patch with muting): 1.0000



**Something to keep in mind when moving to a larger sparse space later (the current sparse space has only ~8k neurons):**

In the sparse space, not all neurons are consistently activated across all images. For example, a neuron might remain inactive (close to zero) in most images but activate strongly for a few specific images, such as those containing spurious features like patches. When we take the average activation of that neuron across all images, the low values from the inactive images dominate, resulting in a low overall average. For instance, a neuron that fires at 5.0 on only 2% of the images and is 0 elsewhere has a mean of just 0.1. This averaging can therefore obscure the neuron's true role in encoding the patch feature and give a misleadingly low indication of its importance. The concern is that, by ranking neurons on average activation, we might overlook neurons that are actually sensitive to the spurious features but appear unimportant because of their sparsity. This could affect the accuracy of our results, particularly when identifying which neurons encode spurious features.



---



Double-checking the test accuracies without projecting into the sparse space



---



## Evaluate Model on activations

In [None]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

class NpyActivationsDataset(Dataset):
    """
    Dataset over a single ``.npy`` file of pre-extracted activations.

    Every row receives the same label ``is_two``. Items are
    ``(activation_as_float32_tensor, label, row_index)`` triples, so callers
    can map predictions back to rows of the file.
    """

    def __init__(self, activation_path: str, is_two: int):
        # The whole array is read eagerly into memory.
        self.activations = np.load(activation_path)
        n_rows = len(self.activations)
        self.labels = [is_two for _ in range(n_rows)]

    def __len__(self):
        return len(self.activations)

    def __getitem__(self, idx):
        row = self.activations[idx]
        as_tensor = torch.tensor(row, dtype=torch.float32)
        return as_tensor, self.labels[idx], idx

# Specify the path to the .npy activations file
activation_patch_path = "/content/drive/MyDrive/Masterthesis/Datasets/mnist/activations/test_patch/bg_fc2_activations_patch.npy"


In [None]:
import pandas as pd
from tqdm import tqdm

def evaluate_model_on_activations(model, dataloader, device, output_csv='predictions.csv'):
    """
    Score pre-extracted fc2 activations with ``model.fc3`` and save per-row predictions.

    Args:
        model: network exposing an ``fc3`` layer; switched to eval mode here.
        dataloader: yields ``(activations, labels, indices)`` batches, e.g. from
            ``NpyActivationsDataset``.
        device: device the batches are moved to (should match the model's).
        output_csv: path of the CSV written with columns ``index`` and
            ``predicted_class``.

    Returns:
        Accuracy as a float in [0, 1], or ``float('nan')`` if the loader
        yielded no samples.
    """
    model.eval()
    predictions = []
    running_corrects = 0
    total_samples = 0

    with torch.no_grad():
        for activations, labels, indices in tqdm(dataloader):
            activations, labels = activations.to(device), labels.to(device)

            # Only the classifier head runs: the inputs are already fc2 outputs.
            outputs = model.fc3(activations)
            _, preds = torch.max(outputs, 1)

            # One host transfer per batch instead of an .item() sync per sample.
            predictions.extend(zip(indices.tolist(), preds.cpu().tolist()))

            running_corrects += (preds == labels).sum().item()
            total_samples += labels.size(0)

    if total_samples == 0:
        # Accuracy is undefined for an empty loader; avoid ZeroDivisionError.
        accuracy = float('nan')
        print("Test Accuracy: n/a (no samples)")
    else:
        accuracy = running_corrects / total_samples
        print(f"Test Accuracy: {accuracy:.4f}")

    df = pd.DataFrame(predictions, columns=['index', 'predicted_class'])
    df.to_csv(output_csv, index=False)
    print(f"Predictions saved to {output_csv}")
    return accuracy


In [None]:
# Define the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the dataset using the pre-saved .npy activations
test_dataset = NpyActivationsDataset(activation_patch_path, is_two=1)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)

# Load the trained model.
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime;
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs, and addresses the torch.load warning this cell emitted.
model = AlexNet()
model.load_state_dict(torch.load('/content/drive/MyDrive/Masterthesis/Datasets/mnist/models/initial_classifier/alexnet_mnist_bg_cl0_cl2_1train.pt', map_location=device, weights_only=True))
model.to(device)

# Evaluate using the pre-saved activations
evaluate_model_on_activations(model, test_loader, device)


  model.load_state_dict(torch.load('/content/drive/MyDrive/Masterthesis/Datasets/mnist/models/initial_classifier/alexnet_mnist_bg_cl0_cl2_1train.pt'))
100%|██████████| 33/33 [00:00<00:00, 85.36it/s] 

Test Accuracy: 0.0000
Predictions saved to predictions.csv





## Evaluate model on raw images

In [None]:

class MnistDataset(Dataset):
    """
    Flat image-folder dataset where every file in ``path`` gets the label ``is_two``.

    Images are converted to RGB, resized to 64x64 and normalized with the
    ImageNet channel statistics. Items are ``(image_tensor, label, filename)``.
    """

    def __init__(self, path: str, is_two: int):
        self.resize_shape = (64, 64)
        self.transform = transforms.Compose([
            transforms.Resize(self.resize_shape),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.path = path
        # os.listdir order is filesystem-dependent; sort it so sample order (and
        # the predictions CSV) is reproducible. Sub-directories are skipped
        # because Image.open cannot read them.
        self.data_files = sorted(
            name for name in os.listdir(self.path)
            if os.path.isfile(os.path.join(self.path, name))
        )
        self.labels = [is_two] * len(self.data_files)

    def __getitem__(self, i):
        img_path = os.path.join(self.path, self.data_files[i])
        img = Image.open(img_path).convert("RGB")
        img = self.transform(img)
        label = self.labels[i]
        return img, label, self.data_files[i]  # Return the filename as a string

    def __len__(self):
        return len(self.data_files)



class AlexNet(nn.Module):
    """
    AlexNet-style classifier used throughout this notebook.

    With the 64x64 inputs produced by ``MnistDataset`` the conv stack reduces
    the feature map to 256x1x1, which is why ``fc1`` takes 256 inputs. The three
    fully connected layers are applied back-to-back with no activation in
    between. Keep module names and order unchanged so saved state_dicts
    (``layer1.0.weight``, ``fc1.weight``, ...) still load.

    Args:
        width_mult: accepted for interface compatibility; currently unused.
    """

    def __init__(self, width_mult=1):
        super(AlexNet, self).__init__()
        # Shape comments below assume a 3x64x64 input.
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),  # 96x15x15
            nn.MaxPool2d(kernel_size=3, stride=2),  # 96x7x7
            nn.ReLU(inplace=True),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(96, 256, kernel_size=5, padding=2),  # 256x7x7
            nn.MaxPool2d(kernel_size=3, stride=2),  # 256x3x3
            nn.ReLU(inplace=True),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(256, 384, kernel_size=3, padding=1),  # 384x3x3
            nn.ReLU(inplace=True),
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(384, 384, kernel_size=3, padding=1),  # 384x3x3
            nn.ReLU(inplace=True),
        )
        self.layer5 = nn.Sequential(
            nn.Conv2d(384, 256, kernel_size=3, padding=1),  # 256x3x3
            nn.MaxPool2d(kernel_size=3, stride=2),  # 256x1x1
            nn.ReLU(inplace=True),
        )
        self.fc1 = nn.Linear(256 * 1 * 1, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, 1000)  # 1000 output logits

    def forward(self, x):
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            x = stage(x)
        # fc1 expects exactly 256x1x1 features per sample. view(-1, 256) alone
        # would silently fold extra spatial positions into the batch dimension
        # for larger inputs, so reject those explicitly.
        if tuple(x.shape[-3:]) != (256, 1, 1):
            raise ValueError(
                f"expected a 256x1x1 feature map before fc1 (e.g. 64x64 input), got {tuple(x.shape[-3:])}"
            )
        x = x.view(-1, 256 * 1 * 1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


In [None]:
import torch
from tqdm import tqdm
import pandas as pd

def evaluate_model(model, dataloader, device, output_csv='predictions.csv'):
    """
    Run the full model over raw images and save per-image predictions.

    Args:
        model: the classifier; switched to eval mode here.
        dataloader: yields ``(inputs, labels, filenames)`` batches, e.g. from
            ``MnistDataset``.
        device: device the batches are moved to (should match the model's).
        output_csv: path of the CSV written with columns ``image_name`` and
            ``predicted_class``.

    Returns:
        Accuracy as a float in [0, 1], or ``float('nan')`` if the loader
        yielded no samples.
    """
    model.eval()
    predictions = []
    running_corrects = 0
    total_samples = 0

    with torch.no_grad():
        for inputs, labels, filenames in tqdm(dataloader):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            # One host transfer per batch instead of an .item() sync per image.
            predictions.extend(zip(filenames, preds.cpu().tolist()))

            running_corrects += (preds == labels).sum().item()
            total_samples += labels.size(0)

    if total_samples == 0:
        # Accuracy is undefined for an empty loader; avoid ZeroDivisionError.
        accuracy = float('nan')
        print("Test Accuracy: n/a (no samples)")
    else:
        accuracy = running_corrects / total_samples
        print(f"Test Accuracy: {accuracy:.4f}")

    df = pd.DataFrame(predictions, columns=['image_name', 'predicted_class'])
    df.to_csv(output_csv, index=False)
    print(f"Predictions saved to {output_csv}")
    return accuracy



In [None]:
# Define the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
test_dataset = MnistDataset(path='/content/drive/MyDrive/Masterthesis/Datasets/mnist/dataset_splits/background/test/class_2', is_two=1)  # Assuming is_two is 1 for class 2
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False, num_workers=2)
# Load the trained model.
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime;
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs, and addresses the torch.load warning this cell emitted.
model = AlexNet()
model.load_state_dict(torch.load('/content/drive/MyDrive/Masterthesis/Datasets/mnist/models/initial_classifier/alexnet_mnist_bg_cl0_cl2_1train.pt', map_location=device, weights_only=True))
model.to(device)
model.eval()  # Set the model to evaluation mode

# Run the evaluation and save predictions to CSV.
# NOTE(review): evaluate_model_on_activations earlier in the notebook also writes
# the default 'predictions.csv'; this call overwrites that file.
evaluate_model(model, test_loader, device)

  model.load_state_dict(torch.load('/content/drive/MyDrive/Masterthesis/Datasets/mnist/models/initial_classifier/alexnet_mnist_bg_cl0_cl2_1train.pt'))
100%|██████████| 516/516 [00:07<00:00, 68.67it/s] 

Test Accuracy: 0.0000
Predictions saved to predictions.csv





In [None]:
# Deliberate halt: "Run all" stops here so the archived cells below are not executed.
raise RuntimeError("Execution stopped on purpose: the cells below are archived experiments.")

# Archives

In [None]:
# Define paths to Google Drive locations for pre-saved activations
def get_activation_path(folder_name, filename, base_dir='/content/drive/MyDrive/Masterthesis/Datasets/mnist/activations'):
    """
    Build the ``.npy`` cache path for a set of saved activations.

    Args:
        folder_name: sub-folder of ``base_dir`` (e.g. ``'test_patch'``).
        filename: file stem; ``.npy`` is appended.
        base_dir: root activations directory; defaults to the Google Drive
            location used throughout this notebook.

    Returns:
        The string ``f'{base_dir}/{folder_name}/{filename}.npy'``.
    """
    return f'{base_dir}/{folder_name}/{filename}.npy'


# Extract activations for fc2 (layer 6)
def extract_fc2_activations(model, dataloader):
    """
    Run images through AlexNet up to fc2 and stack the fc2 outputs.

    Args:
        model: AlexNet-style network exposing ``layer1``..``layer5``, ``fc1``
            and ``fc2``.
        dataloader: iterable yielding image batches as tensors (no labels).

    Returns:
        numpy array with one row per image and ``model.fc2.out_features``
        columns. An empty loader yields an empty array with that column count.

    Raises:
        ValueError: if the conv stack does not produce a 256x1x1 feature map.
    """
    print("Extracting Alexnet activations for layer fc2...")
    # Follow the model's device rather than a notebook-global ``device``.
    model_device = next(model.parameters()).device
    activations = []
    num_images = 0
    with torch.no_grad():
        for image_tensor in dataloader:
            image_tensor = image_tensor.to(model_device)
            tensor = model.layer5(model.layer4(model.layer3(model.layer2(model.layer1(image_tensor)))))
            # view(-1, 256) would silently fold larger spatial maps into the batch.
            if tuple(tensor.shape[-3:]) != (256, 1, 1):
                raise ValueError(
                    f"expected a 256x1x1 feature map before fc1, got {tuple(tensor.shape[-3:])}"
                )
            tensor = tensor.view(-1, 256 * 1 * 1)
            tensor = model.fc2(model.fc1(tensor))
            activations.append(tensor.cpu().numpy())
            # Count images (not batches) so the log is right for any batch size.
            num_images += tensor.shape[0]
            print(f"Processed {num_images} images")
            torch.cuda.empty_cache()
            gc.collect()
    if not activations:
        return np.empty((0, model.fc2.out_features), dtype=np.float32)
    return np.vstack(activations)


# Function to load activations if they exist or extract and save them if not
def load_or_extract_fc2_activations(model, dataloader, folder_name, filename):
    """
    Return cached fc2 activations, computing and caching them on a cache miss.

    The cache location is ``get_activation_path(folder_name, filename)``. If
    that file exists it is loaded. Otherwise ``extract_fc2_activations`` runs
    over ``dataloader``, the result is saved to that path (creating parent
    directories as needed) and then returned.
    """
    activation_path = get_activation_path(folder_name, filename)

    if not os.path.exists(activation_path):
        print(f"No pre-saved Alexnet activations found for {filename}. Extracting and saving...")
        fresh = extract_fc2_activations(model, dataloader)
        os.makedirs(os.path.dirname(activation_path), exist_ok=True)
        np.save(activation_path, fresh)
        print(f"Activations for layer fc2 saved to {activation_path}")
        return fresh

    print(f"Loading pre-saved Alexnet activations for {filename} from {activation_path}...")
    return np.load(activation_path, allow_pickle=True)


# Extract and save activations for fc2 (layer 6) to Google Drive
def extract_and_save_fc2_activations(model, dataloader, folder_name, filename):
    """
    Always (re-)extract fc2 activations and overwrite the cached ``.npy`` file.

    Unlike ``load_or_extract_fc2_activations`` this never reads an existing
    cache. The file goes to ``get_activation_path(folder_name, filename)``, so
    both helpers share one path scheme.

    Returns:
        The extracted activations array.
    """
    print("Extracting and saving Alexnet activations for layer fc2...")
    activations = extract_fc2_activations(model, dataloader)

    drive_path = get_activation_path(folder_name, filename)
    os.makedirs(os.path.dirname(drive_path), exist_ok=True)

    np.save(drive_path, activations)
    print(f"Alexnet Activations for layer fc2 saved to {drive_path}")
    return activations

# Project activations into sparse space
def project_activations(autoencoder, activations, device):
    """
    Encode dense activations into the SAE's sparse latent space.

    Args:
        autoencoder: module exposing an ``encoder`` callable.
        activations: numpy array of dense activations, one row per sample;
            cast to float32 before encoding.
        device: torch device the encoder runs on.

    Returns:
        numpy array of encoder outputs, moved back to the CPU.
    """
    print("Projecting Alexnet activations into SAE sparse space...")
    with torch.no_grad():
        dense = torch.from_numpy(activations).to(device).float()
        sparse = autoencoder.encoder(dense)
    return sparse.cpu().numpy()

# Mean Activation Difference
def mean_activation_difference(projected_patch, projected_no_patch, top_k=10):
    """
    Pick the neurons whose mean activation differs most between the two sets.

    Args:
        projected_patch: 2-D array (samples x neurons) for patched images.
        projected_no_patch: 2-D array (samples x neurons) for unpatched images.
        top_k: number of neurons to return. ``0`` or a negative value returns
            none; a value above the neuron count returns all of them.

    Returns:
        Neuron indices ordered by increasing absolute mean difference, so the
        largest difference is last.
    """
    print("Calculating mean sparse activations difference...")
    mean_diff = np.abs(projected_patch.mean(axis=0) - projected_no_patch.mean(axis=0))
    order = np.argsort(mean_diff)
    # A plain order[-top_k:] would return *every* neuron when top_k == 0.
    return order[max(len(order) - top_k, 0):]

# Statistical Significance Testing
def statistical_testing_neurons(projected_patch, projected_no_patch, threshold=0.05):
    """
    Return neurons whose activations differ significantly between the two sets.

    Runs Welch's two-sample t-test (``equal_var=False``) independently for each
    neuron (column). A neuron whose test is undefined gets a NaN p-value and is
    never selected; this happens, for example, when it is constant in both sets.

    Args:
        projected_patch: 2-D array (samples x neurons) for patched images.
        projected_no_patch: 2-D array (samples x neurons) for unpatched images.
        threshold: significance level; a neuron is kept when ``p < threshold``.

    Returns:
        List of int neuron indices in ascending order.
    """
    print("Performing statistical testing...")
    if projected_patch.shape[1] == 0:
        return []
    # One vectorized call over all columns instead of a Python loop per neuron.
    _, p_values = ttest_ind(projected_patch, projected_no_patch, axis=0, equal_var=False)
    return np.flatnonzero(np.asarray(p_values) < threshold).tolist()

# Use AlexNet's own FC weights to identify patch-relevant neurons
def patch_classifier_importance(model, projected_patch, projected_no_patch, top_k=10):
    """
    Rank input units by the magnitude of their fc3 weight for output class 0.

    Only row 0 of ``model.fc3.weight`` (the weights feeding the first output)
    is used. ``projected_patch`` and ``projected_no_patch`` are accepted so the
    signature matches the other selection methods, but they are not used.

    NOTE(review): fc3 weights are indexed by fc2 units (the dense AlexNet
    space). Callers in this notebook pass the result to ``silence_and_classify``,
    which zeroes these indices in the SAE sparse space. Unless the two spaces
    share indexing, the indices refer to different neurons there. Confirm this
    is intended.

    Returns:
        Up to ``top_k`` unit indices, ordered by increasing weight magnitude.
        ``top_k <= 0`` returns none.
    """
    print("Using AlexNet's FC layer weights to identify important neurons...")
    importance = np.abs(model.fc3.weight.detach().cpu().numpy()[0])

    order = np.argsort(importance)
    # A plain order[-top_k:] would return every unit when top_k == 0.
    top_neurons = order[max(len(order) - top_k, 0):]
    print("Top neurons identified based on AlexNet's weights:", top_neurons)
    return top_neurons


# Correlation Analysis
def correlation_analysis(projected_patch, projected_no_patch, top_k=10):
    """
    Rank neurons by |Pearson r| between their activation and patch presence.

    Samples from ``projected_patch`` are coded 1 and samples from
    ``projected_no_patch`` are coded 0. A neuron that is constant across all
    samples has an undefined correlation (NaN); this is common for dead sparse
    units. Such neurons are scored 0 so they are never ranked as "most
    correlated" (``np.argsort`` would otherwise put NaN at the top end).

    Args:
        projected_patch: 2-D array (samples x neurons) for patched images.
        projected_no_patch: 2-D array (samples x neurons) for unpatched images.
        top_k: number of neurons to return; ``<= 0`` returns none.

    Returns:
        Neuron indices ordered by increasing |r|, so the strongest is last.
    """
    combined = np.vstack([projected_patch, projected_no_patch])
    patch_condition = np.hstack([np.ones(len(projected_patch)), np.zeros(len(projected_no_patch))])
    correlations = [pearsonr(combined[:, i], patch_condition)[0] for i in range(combined.shape[1])]
    scores = np.nan_to_num(np.abs(np.asarray(correlations, dtype=float)), nan=0.0)
    order = np.argsort(scores)
    return order[max(len(order) - top_k, 0):]

# Visualize Top Neurons
def visualize_neurons(neuron_indexes_dict):
    """
    Draw a heatmap of the neuron indices chosen by each selection method.

    Args:
        neuron_indexes_dict: mapping of method name to a sequence of neuron
            indices. Sequences may differ in length; shorter columns are
            padded with NaN, which seaborn leaves blank.

    Each column is one method. Both the cell colour and its annotation show
    the neuron index, so colour encodes index magnitude only.
    """
    columns = {method: pd.Series(indices) for method, indices in neuron_indexes_dict.items()}
    neuron_df = pd.DataFrame(columns)

    fig_size = (10, 6)
    plt.figure(figsize=fig_size)
    sns.heatmap(
        neuron_df,
        annot=True,
        fmt=".0f",  # indices may be stored as floats because of NaN padding
        cmap="viridis",
        cbar=True,
        yticklabels=False,
    )

    plt.xlabel("Methods")
    plt.ylabel("Top Neurons")
    plt.title("Comparison of Silenced Neurons Across Methods")
    plt.show()


def silence_and_classify(autoencoder, model, activations_patch, selected_neurons):
    """
    Zero the selected SAE neurons, decode back to fc2 space and classify with fc3.

    Args:
        autoencoder: SAE exposing ``encoder`` (used via ``project_activations``)
            and ``decoder``.
        model: network exposing ``fc3``.
        activations_patch: numpy array of fc2 activations, one row per image.
        selected_neurons: indices of sparse-space neurons to set to zero.

    Returns:
        List of predicted class indices (Python ints), one per row.
    """
    projected_patch = project_activations(autoencoder, activations_patch, device)
    silenced_patch = np.copy(projected_patch)
    silenced_patch[:, selected_neurons] = 0

    # Inference only: decode and classify the whole batch without autograd.
    with torch.no_grad():
        decoded = autoencoder.decoder(torch.from_numpy(silenced_patch).to(device).float())
        logits = model.fc3(decoded)  # fc3 maps fc2-space activations to class logits
    # Softmax is monotonic, so argmax over logits equals argmax over probabilities.
    return logits.argmax(dim=1).tolist()


# Helper function to calculate accuracy
def calculate_accuracy(predictions, labels):
    """
    Return the percentage (0-100) of positions where ``predictions`` equals ``labels``.

    Raises:
        ValueError: if the two sequences differ in length. ``zip`` would
            otherwise silently drop the surplus and skew the percentage.
        ZeroDivisionError: if both sequences are empty.
    """
    if len(predictions) != len(labels):
        raise ValueError(
            f"predictions ({len(predictions)}) and labels ({len(labels)}) differ in length"
        )
    correct = sum([1 if pred == label else 0 for pred, label in zip(predictions, labels)])
    return correct / len(labels) * 100  # Returns accuracy percentage

# Check overlap in silenced neurons across methods
def check_neuron_overlap(silenced_neurons_dict):
    """
    Print how many neurons each pair of selection methods has in common.

    Args:
        silenced_neurons_dict: mapping of method name to an iterable of neuron
            indices. Pairs are visited once each, in the dict's insertion order.
    """
    names = list(silenced_neurons_dict)
    overlap_counts = {
        f"{first} & {second}": len(set(silenced_neurons_dict[first]).intersection(silenced_neurons_dict[second]))
        for pos, first in enumerate(names)
        for second in names[pos + 1:]
    }

    print("Neuron Overlap Across Methods:")
    for label, count in overlap_counts.items():
        print(f"{label}: {count} neurons")




In [None]:
# Directory path for saving outputs
output_dir = Path("/content/drive/MyDrive/Masterthesis/Datasets/mnist/muted_sparse_sae/outputs")
output_dir.mkdir(parents=True, exist_ok=True)

# Main function for all methods
def main():
    # Paths and initialization
    model_path = "/content/drive/MyDrive/Masterthesis/Datasets/mnist/models/initial_classifier/alexnet_mnist_bg_cl0_cl2_1train.pt"
    patch_folder = '/content/drive/MyDrive/Masterthesis/Datasets/mnist/dataset_splits/background/test/class_2'
    no_patch_folder = '/content/drive/MyDrive/Masterthesis/Datasets/mnist/dataset_splits/original/test/class_2'

    # Load model and autoencoder
    model = load_model(model_path)
    autoencoder = load_autoencoder(device)

    # Prepare dataloaders for patched and unpatched datasets
    patch_image_paths = [os.path.join(root, file) for root, dirs, files in os.walk(patch_folder) for file in files if file.endswith(('.jpg', '.png'))]
    no_patch_image_paths = [os.path.join(root, file) for root, dirs, files in os.walk(no_patch_folder) for file in files if file.endswith(('.jpg', '.png'))]

    patch_dataset = ImageDataset(patch_image_paths, transform=preprocess)
    no_patch_dataset = ImageDataset(no_patch_image_paths, transform=preprocess)

    patch_loader = DataLoader(patch_dataset, batch_size=1, shuffle=False)
    no_patch_loader = DataLoader(no_patch_dataset, batch_size=1, shuffle=False)

    # Conditionally load or extract activations for patched images
    activations_patch = load_or_extract_fc2_activations(model, patch_loader, 'test_patch', 'fc2_activations_patch')

    # Conditionally load or extract activations for non-patched images
    activations_no_patch = load_or_extract_fc2_activations(model, no_patch_loader, 'test_no_patch', 'fc2_activations_no_patch')

    # Project into sparse space
    projected_patch = project_activations(autoencoder, activations_patch, device)
    projected_no_patch = project_activations(autoencoder, activations_no_patch, device)

    # Dictionary to store neuron indexes for each method
    silenced_neurons_dict = {}

    # Method 1: Mean Activation Difference
    top_neurons_mean_diff = mean_activation_difference(projected_patch, projected_no_patch, top_k=10)
    silenced_neurons_dict["Mean Activation Diff"] = top_neurons_mean_diff
    predictions_mean_diff = silence_and_classify(autoencoder, model, activations_patch, top_neurons_mean_diff)
    print("Classification Results (Mean Activation Diff):")
    print(predictions_mean_diff)

    # Method 2: Statistical Testing
    top_neurons_stat_test = statistical_testing_neurons(projected_patch, projected_no_patch, threshold=0.05)
    silenced_neurons_dict["Statistical Test"] = top_neurons_stat_test
    predictions_stat_test = silence_and_classify(autoencoder, model, activations_patch, top_neurons_stat_test)
    print("Classification Results (Statistical Test):")
    print(predictions_stat_test)

    # Method 3: Patch Classifier (Using AlexNet's FC Layer Weights)
    top_neurons_classifier = patch_classifier_importance(model, projected_patch, projected_no_patch, top_k=10)
    silenced_neurons_dict["Patch Classifier"] = top_neurons_classifier
    predictions_classifier = silence_and_classify(autoencoder, model, activations_patch, top_neurons_classifier)
    print("Classification Results (Patch Classifier):")
    print(predictions_classifier)

    # Method 4: Correlation Analysis
    top_neurons_correlation = correlation_analysis(projected_patch, projected_no_patch, top_k=10)
    silenced_neurons_dict["Correlation Analysis"] = top_neurons_correlation
    predictions_correlation = silence_and_classify(autoencoder, model, activations_patch, top_neurons_correlation)
    print("Classification Results (Correlation Analysis):")
    print(predictions_correlation)

    # Print silenced neurons by each method
    print("Silenced neurons by each method:", silenced_neurons_dict)

    # Check neuron overlap across methods
    check_neuron_overlap(silenced_neurons_dict)

    # Visualize silenced neuron indexes
    visualize_neurons(silenced_neurons_dict)

# Execute main function
main()
