In [None]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import pickle
import os
from sklearn.linear_model import Ridge
import torch.nn.functional as F
from sklearn.metrics import r2_score
import torch.nn as nn

# Directory that holds the pickled dataset; joined with FILE by load_neural_data().
DATA_PATH = '../../data/cadena_plosCB19/'
# Name of the pickle file inside DATA_PATH.
FILE = 'cadena_ploscb_data.pkl'

# ===================================
# Load Cadena's data
# ===================================
def load_neural_data():
    """Unpickle and return the object stored at ``DATA_PATH``/``FILE``.

    NOTE: ``pickle`` can run arbitrary code while loading, so this should
    only ever be pointed at a trusted file.
    """
    with open(os.path.join(DATA_PATH, FILE), "rb") as handle:
        return pickle.load(handle)

data_dict = load_neural_data()

# ===================================
# Clean neural data, following
# https://github.com/sacadena/Cadena2019PlosCB/blob/master/cnn_sys_ident/data.py
# ===================================
# Work on a copy, overwrite every NaN entry with 0, then store the cleaned
# array back into the dictionary.
responses = data_dict['responses'].copy()
nan_mask = np.isnan(responses)
responses[nan_mask] = 0
data_dict['responses'] = responses

# ===================================
# Get sample images and responses
# ===================================
# The first 1000 images are used. A seeded random draw of 1000 out of 7250
# indices (np.random.seed(42) + np.random.choice(7250, 1000, replace=False))
# was the alternative and is currently disabled.
first_1000 = slice(None, 1000)
sample_images = data_dict["images"][first_1000]
# Responses are sliced along their second axis with the same range as the images.
sample_responses = data_dict["responses"][:, first_1000, :]

# ===================================
# Extract features from VGG-19
# ===================================

# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Keep only the first ten modules of the pretrained VGG-19 feature stack and
# switch them to evaluation mode.
# NOTE(review): in torchvision's VGG-19 layout features[:10] runs through the
# second max-pool (after conv2_2), not just conv1_1 - confirm the intended layer.
vgg = models.vgg19(pretrained=True).features[:10]
vgg.eval()
vgg.to(device)

class ImageDataset(Dataset):
    """Dataset that runs every stored image through a fixed transform pipeline.

    The pipeline centre-crops to 80 px, resizes to 224x224, replicates the
    image into three grayscale channels, converts it to a tensor and then
    normalises it with the ImageNet channel means and standard deviations.
    """

    def __init__(self, images):
        """Store ``images`` and build the transform pipeline."""
        self.images = images
        pipeline = [
            transforms.CenterCrop(80),
            transforms.Resize((224, 224)),
            transforms.Grayscale(num_output_channels=3),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
        self.transform = transforms.Compose(pipeline)

    def __len__(self):
        """Return the number of stored images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the transformed image at position ``idx``."""
        sample = self.images[idx]

        # Anything that is not a NumPy array goes to the transforms untouched.
        if not isinstance(sample, np.ndarray):
            return self.transform(sample)

        # NumPy arrays are cast to float32 and wrapped in a PIL image first.
        as_pil = Image.fromarray(sample.astype(np.float32))
        return self.transform(as_pil)

# Wrap the sampled images in a dataset and iterate them in order, 16 at a time.
dataset = ImageDataset(sample_images)
dataloader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0)

# Run the truncated VGG batch by batch and park each result on the CPU so that
# only one batch at a time has to live on the compute device.
all_features = []
with torch.no_grad():
    for batch_idx, image_batch in enumerate(dataloader):
        feature_maps = vgg(image_batch.to(device))

        # Spatial layout is preserved: each entry is (batch, C, H, W).
        all_features.append(feature_maps.cpu())

        # Progress message on every 100th batch; the count is the batch index
        # times the size of the current batch.
        if batch_idx % 100 == 0:
            print(f"Processed {batch_idx * len(image_batch)} images...")

# Stack the per-batch results along the image axis: (n_images, C, H, W).
vgg_features = torch.cat(all_features, dim=0)
print("Final extracted features shape:", vgg_features.shape)

# ===================================
# Apply spatial pooling to features, then flatten
# NOTE: Cadena et al. use learned spatial pooling, not fixed pooling like adaptive_avg_pool2d.
# NOTE: Their pooling is encouraged through trainable pooling weights and regularization.
# TODO: to replicate, might want to use a trainable depthwise convolution, instead of adaptive_avg_pool2d
# ===================================

# Validate the layout before anything else touches the tensor: BatchNorm2d
# rejects a 2-D input on its own, so this check must come first for its
# message to ever be shown.
if vgg_features.dim() == 2:  # Likely (n_images, flattened_dim)
    raise ValueError("vgg_features is already flattened! Ensure it retains spatial dimensions.")

# Build the layer on the device that already holds the features. They were
# moved to the CPU during extraction, so sending the layer to `device` would
# fail with a device mismatch whenever CUDA is available.
bn_layer = nn.BatchNorm2d(num_features=vgg_features.shape[1], momentum=0.9).to(vgg_features.device)
bn_layer.eval()
# NOTE(review): a freshly constructed BatchNorm2d in eval mode normalises with
# its initial running statistics (mean 0, variance 1), so this step is close
# to an identity mapping - confirm whether batch statistics were intended.
with torch.no_grad():  # no gradients are needed for a fixed preprocessing step
    vgg_features_bn = bn_layer(vgg_features)

# Average-pool every feature map down to 14x14, then flatten each image's
# pooled maps into a single row of the design matrix.
pooled_features = F.adaptive_avg_pool2d(vgg_features_bn, (14, 14))
X_pooled = pooled_features.view(pooled_features.size(0), -1).detach().cpu().numpy()

In [None]:
# ===================================
# Regression
# ===================================

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

# NOTE: only the first 1000 images and responses are used for the regression,
# so the train/test split and the variance calculations operate on that subset.

X = X_pooled
# Average sample_responses over its first axis (the trial axis) to get one
# target value per image and neuron.
# NOTE(review): NaN responses were replaced by 0 during cleaning, so those
# zeros are part of this average - confirm that is intended.
Y = np.mean(sample_responses, axis=0)

print("Features shape:", X_pooled.shape)
print("Responses shape:", Y.shape)

# In-sample fit: a single ridge model for all neurons, scored on the very data
# it was trained on.
ridge = Ridge(alpha=10000)
Y_pred = ridge.fit(X, Y).predict(X)
r2 = r2_score(Y, Y_pred, multioutput='raw_values')

print("Mean R-squared score:", r2.mean())

# Split row indices rather than the arrays themselves, so the same held-out
# rows can be looked up in other arrays afterwards.
num_samples = X.shape[0]
all_indices = np.arange(num_samples)
train_indices, test_indices = train_test_split(
    all_indices,
    test_size=0.2,
    random_state=42,
)

print(train_indices[:10])
print(test_indices[:10])

X_train, Y_train = X[train_indices], Y[train_indices]
X_test, Y_test = X[test_indices], Y[test_indices]

# Refit on the 80% training rows and score on the 20% held-out rows.
ridge.fit(X_train, Y_train)
Y_pred_test = ridge.predict(X_test)

r2_test = r2_score(Y_test, Y_pred_test, multioutput='raw_values')
print(f"Mean R² score on test set: {np.mean(r2_test):.4f}")  # Should be lower than training R²

# (FEV) Fraction of explainable variance explained on the held-out images.
#
# Y_test and the predictions are compared on trial-AVERAGED responses, so the
# noise that contaminates both the MSE and the total variance is the noise of
# a mean over n_trials trials: single-trial variance / n_trials. Subtracting
# the full single-trial variance would over-correct.
test_trials = sample_responses[:, test_indices, :]
n_trials = test_trials.shape[0]

total_var_test = np.var(Y_test, axis=0)
# Unbiased variance across trials (ddof=1) when more than one trial exists.
trial_var_test = np.var(test_trials, axis=0, ddof=1 if n_trials > 1 else 0)
# Average across images, then scale to the variance of the trial mean.
noise_var_test = np.mean(trial_var_test, axis=0) / n_trials
mse_test = mean_squared_error(Y_test, Y_pred_test, multioutput='raw_values')
explainable_var_test = total_var_test - noise_var_test
fev = 1 - (mse_test - noise_var_test) / explainable_var_test
# NOTE(review): clipping hides neurons with negative FEV and therefore raises
# the reported mean - confirm this is the desired summary.
fev = np.clip(fev, 0, 1)
print(f"Mean FEV across neurons (subset of 1000 images): {np.mean(fev):.4f}")

In [None]:
### Feature visualisation - conv layers

import torch
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Pretrained VGG-19 feature stack truncated to its first 33 modules, in eval
# mode, placed on the GPU when available and on the CPU otherwise.
# NOTE(review): per torchvision's layer ordering features[:33] ends at the
# conv5_3 convolution (before its ReLU), i.e. far beyond the first
# convolutional layer - confirm which layer is meant to be visualised.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg = models.vgg19(pretrained=True).features[:33]
vgg.eval()
vgg.to(device)

def generate_maximally_activating_image(filter_idx, steps=100, lr=0.1):
    """Optimise a noise image so that one output channel of ``vgg`` responds strongly.

    Starts from Gaussian noise of shape (1, 3, 224, 224) on the module-level
    ``device`` and runs ``steps`` Adam updates with learning rate ``lr`` that
    increase the mean of channel ``filter_idx`` of ``vgg``'s output. After
    every update the pixels are clamped to [-1.5, 1.5].

    Returns the optimised image, detached from the autograd graph.
    """
    canvas = torch.randn(1, 3, 224, 224, device=device, requires_grad=True)
    ascent = torch.optim.Adam([canvas], lr=lr)

    for _ in range(steps):
        ascent.zero_grad()

        # Response map of the chosen channel for the single image in the batch.
        channel_map = vgg(canvas)[0, filter_idx]

        # Adam minimises, so the negated mean activation is the objective.
        objective = -channel_map.mean()
        objective.backward()
        ascent.step()

        # Keep pixel values bounded so the image does not drift to extremes.
        with torch.no_grad():
            canvas.clamp_(-1.5, 1.5)

    return canvas.detach()

def visualize_feature_map(img):
    """Show ``img`` as a grayscale picture with the axes hidden.

    The tensor is squeezed and converted to a NumPy array. If three
    dimensions remain after squeezing, only the first slice along the leading
    axis is shown. Values are rescaled linearly so the minimum maps to 0 and
    the maximum to 1.
    """
    pixels = img.squeeze().cpu().detach().numpy()

    if pixels.ndim == 3:
        pixels = pixels[0]

    # Min-max rescale to [0, 1].
    lowest = pixels.min()
    highest = pixels.max()
    pixels = (pixels - lowest) / (highest - lowest)

    plt.imshow(pixels, cmap='gray')
    plt.axis("off")
    plt.show()

# Optimise an image for output channel 10, using the default number of steps
# and the default learning rate.
optimized_img = generate_maximally_activating_image(10)

# Average over the channel axis (dim 1), keeping it as a singleton dimension,
# and display the result.
grayscale_img = torch.mean(optimized_img, dim=1, keepdim=True)
visualize_feature_map(grayscale_img)

In [None]:
### Feature visualisation - fully-connected layers

import torch
import torch.nn.functional as F
import torchvision.models as models
import numpy as np
import matplotlib.pyplot as plt

# Full pretrained VGG-19 in eval mode, on the GPU when available and on the
# CPU otherwise.
vgg = models.vgg19(pretrained=True)
vgg = vgg.eval().to("cuda" if torch.cuda.is_available() else "cpu")

def generate_fc6_maximization(neuron_idx, steps=200, lr=0.1):
    """Optimise a noise image to drive one unit of ``vgg``'s first classifier layer.

    Starts from Gaussian noise of shape (1, 3, 224, 224) and runs ``steps``
    Adam updates with learning rate ``lr``. Each step sends the image through
    ``vgg.features``, ``vgg.avgpool`` and the first module of
    ``vgg.classifier``, and increases output ``neuron_idx`` of that module.
    Pixels are clamped to [-1.5, 1.5] after every update.

    Returns the optimised image, detached from the autograd graph.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    canvas = torch.randn(1, 3, 224, 224, device=device, requires_grad=True)
    ascent = torch.optim.Adam([canvas], lr=lr)

    for _ in range(steps):
        ascent.zero_grad()

        # Convolutional stack, pooling, then one flat vector per image.
        pooled = vgg.avgpool(vgg.features(canvas))
        flat = torch.flatten(pooled, start_dim=1)

        # classifier[:1] keeps only the first classifier module (fc6).
        fc6_output = vgg.classifier[:1](flat)

        # Adam minimises, so negate the activation that should grow.
        objective = -fc6_output[0, neuron_idx]
        objective.backward()
        ascent.step()

        # Keep pixel values bounded.
        with torch.no_grad():
            canvas.clamp_(-1.5, 1.5)

    return canvas.detach()

def visualize_optimized_image(img):
    """Show ``img`` as a colour picture with the axes hidden.

    The tensor is squeezed, converted to a NumPy array, reordered from
    channels-first to (H, W, C) and rescaled linearly so the minimum maps to 0
    and the maximum to 1.
    """
    pixels = img.squeeze().cpu().detach().numpy()
    pixels = pixels.transpose(1, 2, 0)  # channels-first -> (H, W, C)

    # Min-max rescale to [0, 1].
    lowest = pixels.min()
    highest = pixels.max()
    pixels = (pixels - lowest) / (highest - lowest)

    plt.imshow(pixels)
    plt.axis("off")
    plt.show()

# Index of the `fc6` unit to visualise (index 500 out of 4096 units).
neuron_idx = 500
optimized_img = generate_fc6_maximization(neuron_idx=neuron_idx)

# Display the image optimised for that unit.
visualize_optimized_image(optimized_img)

In [None]:
### Feature visualisation - class-conditioned

import torch
import torch.nn.functional as F
import torchvision.models as models
import numpy as np
import matplotlib.pyplot as plt

# Full pretrained VGG-19 in eval mode, on the GPU when available and on the
# CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg = models.vgg19(pretrained=True)
vgg = vgg.eval().to(device)

def total_variation_loss(img):
    """Return the anisotropic total variation of ``img`` as a scalar tensor.

    Sums the absolute differences between vertically adjacent and between
    horizontally adjacent pixels. The last two axes are treated as
    (height, width), so both (C, H, W) and batched (N, C, H, W) tensors are
    handled. Leading axes are never differenced, which means differences
    between channels do not contribute to the loss.
    """
    # `...` keeps every leading axis intact; only H and W are differenced.
    vertical = torch.abs(img[..., 1:, :] - img[..., :-1, :]).sum()
    horizontal = torch.abs(img[..., :, 1:] - img[..., :, :-1]).sum()
    return vertical + horizontal

def generate_class_visualization(target_class, steps=300, lr=0.1, tv_weight=1e-6):
    """Optimise a noise image to raise one class score of ``vgg``.

    Starts from Gaussian noise of shape (1, 3, 224, 224) on the module-level
    ``device`` and runs ``steps`` Adam updates with learning rate ``lr``. The
    objective is the negated score of ``target_class`` plus ``tv_weight``
    times ``total_variation_loss`` of the image. Pixels are clamped to
    [-1.5, 1.5] after every update.

    Returns the optimised image, detached from the autograd graph.
    """
    canvas = torch.randn(1, 3, 224, 224, device=device, requires_grad=True)
    ascent = torch.optim.Adam([canvas], lr=lr)

    for _ in range(steps):
        ascent.zero_grad()

        # Full forward pass: conv stack, pooling, flatten, classifier scores.
        pooled = vgg.avgpool(vgg.features(canvas))
        logits = vgg.classifier(torch.flatten(pooled, start_dim=1))

        # Negated class score (to maximise it) plus the weighted TV penalty.
        objective = -logits[0, target_class] + tv_weight * total_variation_loss(canvas)

        objective.backward()
        ascent.step()

        # Keep pixel values bounded.
        with torch.no_grad():
            canvas.clamp_(-1.5, 1.5)

    return canvas.detach()

def preprocess_and_visualize(img):
    """
    Shows the optimized image with the axes hidden.

    The tensor is squeezed, converted to a NumPy array, reordered from
    channels-first to HWC and rescaled linearly so the minimum maps to 0 and
    the maximum to 1.
    """
    pixels = img.squeeze().cpu().detach().numpy()
    pixels = pixels.transpose(1, 2, 0)  # channels-first -> HWC

    # Min-max rescale to [0, 1].
    lowest = pixels.min()
    highest = pixels.max()
    pixels = (pixels - lowest) / (highest - lowest)

    plt.imshow(pixels)
    plt.axis("off")
    plt.show()

# Class index to visualise: 207, Golden Retriever in the ImageNet labelling.
target_class = 207
# Uses a TV weight of 1e-5 rather than the function's 1e-6 default.
optimized_img = generate_class_visualization(target_class=target_class, tv_weight=1e-5)

# Display the optimised image.
preprocess_and_visualize(optimized_img)