In [None]:
# Results

# - Ridge regression: first 1000 images, layer1, mean r2 = ~0.1, fev = ~0.25
# - Ridge regression: first 1000 images, layer2, mean r2 = ~0.1, fev = ~0.25
# - Ridge regression: first 1000 images, layer3, mean r2 = ~0.1, fev = ~0.25
# - Ridge regression: first 1000 images, layer4, mean r2 = ~0.1, fev = ~0.25

In [None]:
### Run images through a pretrained SimCLR model and extract features

import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
from tqdm.notebook import tqdm
from typing import Dict
from torch.utils.data import Dataset
import urllib.request
from urllib.error import HTTPError
import torch.nn.functional as F

class SimCLR(nn.Module):
    """
    ResNet-18 encoder with a SimCLR projection head, plus helpers to load a
    pretrained checkpoint and to extract final / intermediate layer features.

    Typical use:
        model = SimCLR()
        model.load_pretrained()
        model.set_intermediate_layers_to_capture(['layer1', 'layer2'])
        feats = model.extract_features(dataset)
    """

    def __init__(self, hidden_dim=128):
        super().__init__()

        # Base ResNet18 backbone, randomly initialised: the real weights are loaded later from the
        # SimCLR checkpoint file. Calling resnet18() without arguments avoids the deprecated
        # `pretrained=` keyword and gives the same untrained network on old and new torchvision.
        self.convnet = torchvision.models.resnet18()

        # This is the projection head, only needed during training. For downstream tasks it is disposed of
        # and the final linear layer output is used (Chen et al., 2020)
        self.convnet.fc = nn.Sequential(
            nn.Linear(self.convnet.fc.in_features, 4 * hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(4 * hidden_dim, hidden_dim)
        )

        self.intermediate_layers_to_capture = []
        self.intermediate_layer_features = {}
        # Handles of the currently registered forward hooks, kept so they can be removed again
        self._hook_handles = []
        # Stored for reference; extract_features currently loads data in the main process (num_workers=0)
        self.num_workers = os.cpu_count()
        self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    def load_pretrained(self):
        """
        Load pretrained SimCLR weights, downloading the checkpoint first if it is not cached locally.

        Must be called while the projection head is still attached, i.e. before extract_features
        (which replaces `convnet.fc`), otherwise the checkpoint keys no longer match the model.

        Raises:
            FileNotFoundError: if the checkpoint is not available locally and the download failed.
        """
        base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial17/"
        models_dir = "../../models"
        pretrained_simclr_filename = "SimCLR.ckpt"
        pretrained_simclr_path = os.path.join(models_dir, pretrained_simclr_filename)
        os.makedirs(models_dir, exist_ok=True)

        # Check whether the pretrained model file already exists locally. If not, try downloading it
        file_url = base_url + pretrained_simclr_filename
        if os.path.isfile(pretrained_simclr_path):
            print(f"Already downloaded pretrained model: {file_url}")
        else:
            print(f"Downloading pretrained SimCLR model {file_url}...")
            try:
                urllib.request.urlretrieve(file_url, pretrained_simclr_path)
            except HTTPError as e:
                print("Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n", e)
                # Fail here with a clear message instead of letting torch.load trip over the missing file
                raise FileNotFoundError(
                    f"Pretrained checkpoint not found at {pretrained_simclr_path} and download from {file_url} failed"
                ) from e

        # Load pretrained model
        checkpoint = torch.load(pretrained_simclr_path, map_location=self.device)
        self.load_state_dict(checkpoint['state_dict'])
        self.to(self.device)
        self.eval()

    def set_intermediate_layers_to_capture(self, layers):
        """
        Register forward hooks to capture features from intermediate layers.

        Any hooks registered by a previous call are removed first, so calling this method
        repeatedly replaces the captured set instead of stacking hooks.

        Args:
            layers: iterable of module names as reported by `self.convnet.named_modules()`
                (e.g. 'layer1'). A warning is printed for names that are not top-level blocks.

        Raises:
            KeyError: if a name does not exist in the convnet. No hooks are changed in that case.
        """
        layers = list(layers)

        # Just check the layers specified are actually in the convnet
        top_level_block_layers = [name for name, _ in self.convnet.named_children()]
        if not all(layer in top_level_block_layers for layer in layers):
            print('You have specified convnet layers that are not top-level blocks - make sure your layer names are valid')

        # Validate everything up front so a bad name cannot leave a half-registered set of hooks
        modules_by_name = dict(self.convnet.named_modules())
        unknown_layers = [layer for layer in layers if layer not in modules_by_name]
        if unknown_layers:
            raise KeyError(f"Unknown convnet layer name(s): {unknown_layers}")

        # Drop hooks from earlier calls; otherwise they keep firing on every forward pass
        for handle in getattr(self, '_hook_handles', []):
            handle.remove()
        self._hook_handles = []

        self.intermediate_layers_to_capture = layers
        intermediate_layer_features = {}

        def get_hook(layer_name):
            def hook(module, input, output):
                intermediate_layer_features[layer_name] = output.detach()
            return hook

        for layer_name in layers:
            handle = modules_by_name[layer_name].register_forward_hook(get_hook(layer_name))
            self._hook_handles.append(handle)

        self.intermediate_layer_features = intermediate_layer_features

    @torch.no_grad()
    def extract_features(self, dataset: Dataset, batch_size: int = 64, pool_size=(14, 14)) -> Dict[str, torch.Tensor]:
        """
        Run the model on the image data, and capture features from final layer and intermediate layers.

        Side effect: the projection head (`self.convnet.fc`) is replaced by `nn.Identity()` and stays
        replaced after this call returns.

        Args:
            dataset (Dataset): A PyTorch Dataset yielding image tensors of shape (C, H, W), or
                (image, label, ...) tuples, in which case only the first element is used.
            batch_size (int): Number of images per forward pass.
            pool_size: Target spatial size passed to adaptive average pooling for 4D intermediate outputs.
                NOTE(review): if a layer's feature map is smaller than pool_size, adaptive pooling
                enlarges it rather than reducing it - confirm this is intended for the deepest layers.

        Returns:
            Dict[str, torch.Tensor]: A dictionary containing:
                - One entry per captured intermediate layer, keyed by layer name.
                - Final layer features under 'final_layer'.
            Labels are not returned.
            Features from a given layer have shape (N, F) where N is num images and F is the number of
            features (the pooled (C, H, W) output flattened).
        """
        self.convnet.fc = nn.Identity()  # Removing projection head g(.)
        self.eval()
        self.to(self.device)

        # Encode all images
        data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=0, shuffle=False, drop_last=False)
        feats, intermediate_features = [], {layer: [] for layer in self.intermediate_layers_to_capture}

        for batch in tqdm(data_loader):
            # Datasets that yield (image, label) pairs are collated into a list/tuple; keep the images only
            batch_imgs = batch[0] if isinstance(batch, (list, tuple)) else batch
            batch_imgs = batch_imgs.to(self.device)
            batch_feats = self.convnet(batch_imgs)

            feats.append(batch_feats.detach().cpu())

            # Collect intermediate layer outputs (filled in by the forward hooks during the call above)
            for layer in self.intermediate_layers_to_capture:
                layer_output = self.intermediate_layer_features[layer]
                # Final linear layer outputs a 2d tensor; convolutional blocks output 4d tensors, which are
                # spatially pooled and then flattened (ready for PCA etc.)
                if layer_output.dim() == 4:
                    layer_output = F.adaptive_avg_pool2d(layer_output, pool_size)
                intermediate_features[layer].append(layer_output.flatten(start_dim=1).cpu())

        # Concatenate results for each layer
        feats = torch.cat(feats, dim=0)
        intermediate_features = {layer: torch.cat(intermediate_features[layer], dim=0) for layer in self.intermediate_layers_to_capture}

        # Debugging log after concatenation
        print("✅ Feature extraction complete. Final feature shapes:")
        print(f"Final layer: {feats.shape}")
        for layer, feature in intermediate_features.items():
            print(f"{layer}: {feature.shape}")  # Check final stored shape

        return {**intermediate_features, 'final_layer': feats}

# Names of the ResNet blocks whose outputs are captured via forward hooks
# (passed to SimCLR.set_intermediate_layers_to_capture in a later cell)
intermediate_layers = ['layer1', 'layer2', 'layer3', 'layer4']

In [None]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import pickle
import os

# Default location of the pickled dataset, relative to the notebook's working directory
DATA_PATH = '../../data/cadena_plosCB19/'
FILE = 'cadena_ploscb_data.pkl'

# ===================================
# Load Cadena's data
# ===================================
def load_neural_data(file_path=None):
    """
    Load the pickled dataset and return the unpickled object.

    Args:
        file_path: Optional path to the pickle file. Defaults to DATA_PATH joined with FILE.

    Returns:
        Whatever object the pickle file contains (later cells index it like a dict
        with 'images' and 'responses' keys).
    """
    if file_path is None:
        file_path = os.path.join(DATA_PATH, FILE)
    # SECURITY: pickle.load can execute arbitrary code - only open files from a trusted source
    with open(file_path, "rb") as f:
        return pickle.load(f)

# Load the pickled dataset; it is indexed below with the keys 'images' and 'responses'
data_dict = load_neural_data()

# ===================================
# Clean neural data https://github.com/sacadena/Cadena2019PlosCB/blob/master/cnn_sys_ident/data.py
# ===================================
# Work on a copy, replace NaN entries with 0, and store the cleaned array back into data_dict.
# NOTE(review): after this step NaN entries are indistinguishable from true zero responses,
# so later means/variances over the first axis include these zeros - confirm this is intended.
responses = data_dict['responses'].copy() 
responses[np.isnan(responses)] = 0
data_dict['responses'] = responses

# ===================================
# Get sample images and responses
# ===================================
# Alternative (disabled): draw a seeded random subset instead of the first 1000 images
# np.random.seed(42)
# indices = np.random.choice(7250, 1000, replace=False)
# Take the first 1000 images and the matching slice along the second axis of the responses
# (presumably responses are (trials, images, neurons) - TODO confirm)
sample_images = data_dict["images"][:1000]
sample_responses = data_dict["responses"][:, :1000, :]

# ===================================
# Load a truncated VGG-19
# ===================================

# Load ImageNet-pretrained VGG-19 and keep only the first ten modules of its `features` stack.
# NOTE(review): this slice extends beyond conv1_1 (in torchvision's VGG-19 it appears to run
# through the second max-pool) - confirm which layer is intended.
# NOTE(review): `vgg` is not used again in the cells visible here.
vgg = models.vgg19(pretrained=True).features[:10].eval()

# Use the GPU if available, otherwise the CPU, and move the model there
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg.to(device)

class ImageDataset(Dataset):
    """Dataset that turns raw images into normalised 3-channel 224x224 tensors."""

    def __init__(self, images):
        self.images = images
        preprocessing_steps = [
            # transforms.CenterCrop(80),
            transforms.Resize((224, 224)),                # network input resolution
            transforms.Grayscale(num_output_channels=3),  # replicate grayscale into 3 channels
            transforms.ToTensor(),                        # PIL image -> torch tensor
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet statistics
        ]
        self.transform = transforms.Compose(preprocessing_steps)

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the transformed image at position `idx`."""
        sample = self.images[idx]

        # Anything that is not a NumPy array is handed to the transform pipeline unchanged
        if not isinstance(sample, np.ndarray):
            return self.transform(sample)

        # NumPy arrays are cast to float32 and wrapped in a PIL image first
        pil_image = Image.fromarray(sample.astype(np.float32))
        return self.transform(pil_image)

# Wrap the sampled images in a Dataset. A DataLoader is also created here, but the
# feature extraction below is given the dataset itself.
dataset = ImageDataset(sample_images)
dataloader = DataLoader(dataset, shuffle=False, batch_size=16, num_workers=0)

# ===================================
# Extract SimCLR features
# ===================================
sim_clr = SimCLR()
sim_clr.load_pretrained()
sim_clr.set_intermediate_layers_to_capture(intermediate_layers)
feats = sim_clr.extract_features(dataset)

# Sanity check: overall variance of each captured layer's feature matrix
for layer in ("layer1", "layer2", "layer3", "layer4"):
    if layer not in feats:
        continue
    variance = np.var(feats[layer].numpy())
    print(f"{layer} variance: {variance:.6f}")

# Each entry is a 2D tensor of shape (n_images, n_features); intermediate layers are
# spatially pooled and flattened inside extract_features.
layer1_feats, layer2_feats, layer3_feats, layer4_feats = (
    feats[key] for key in ('layer1', 'layer2', 'layer3', 'layer4')
)
final_layer_feats = feats['final_layer']  # output of the network with the projection head removed

print('layer1 shape', layer1_feats.shape)
print('final layer shape', final_layer_feats.shape)

In [None]:
# ===================================
# Regression
# ===================================
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

# Regression targets: responses averaged over the first axis of sample_responses.
# NOTE(review): axis 0 is treated as the trial/repeat axis throughout this cell - confirm.
n_trials = sample_responses.shape[0]
Y = sample_responses.mean(axis=0)

# Noise variance does not depend on the feature layer, so it is computed once.
# The targets are trial AVERAGES, so the noise left in them is the single-trial variance
# divided by the number of trials; subtracting the undivided single-trial variance from the
# MSE of averaged responses over-corrects and inflates the FEV.
# NOTE(review): if NaN trials were zero-filled upstream, they are counted as real trials here
# and bias this estimate - confirm against the cleaning step.
if n_trials > 1:
    trial_var = np.var(sample_responses, axis=0, ddof=1)  # unbiased variance across trials
    noise_var = np.mean(trial_var, axis=0) / n_trials     # average across images, scaled for the mean
else:
    # A single trial gives no noise estimate; fall back to zero noise
    trial_var = np.zeros_like(Y)
    noise_var = np.zeros(Y.shape[1:])

layer_features = {
    'layer1': layer1_feats,
    'layer2': layer2_feats,
    'layer3': layer3_feats,
    'layer4': layer4_feats,
    'final_layer': final_layer_feats,
}

# Data for regression
for layer_name, x in layer_features.items():
    X = np.asarray(x)  # hand sklearn a plain NumPy array rather than a torch tensor

    print(f"=== {layer_name} ===")
    print("Features shape:", X.shape)
    print("Responses shape:", Y.shape)

    # In-sample fit on all images (optimistic; reported only as a reference point)
    ridge = Ridge(alpha=1000) # Adjust alpha for regularization strength
    ridge.fit(X, Y) # Train on all neurons at once
    Y_pred = ridge.predict(X)
    r2 = r2_score(Y, Y_pred, multioutput='raw_values')

    print("Mean R-squared score:", r2.mean())

    X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=42)

    ridge.fit(X_train, Y_train)  # Train on 80%
    Y_pred_test = ridge.predict(X_test)  # Predict on 20%

    r2_test = r2_score(Y_test, Y_pred_test, multioutput='raw_values')
    print(f"Mean R² score on test set: {np.mean(r2_test):.4f}")  # Should be lower than training R²

    # Fraction of explainable variance (FEV), evaluated on the held-out images so that the
    # total variance and the MSE refer to the same set of responses.
    total_var = np.var(Y_test, axis=0, ddof=1)
    mse = mean_squared_error(Y_test, Y_pred_test, multioutput='raw_values')
    explainable_var = total_var - noise_var

    # FEV is undefined where the noise estimate is at least as large as the total variance;
    # mark those neurons as NaN instead of dividing by a non-positive number.
    fev = np.full(explainable_var.shape, np.nan)
    has_signal = explainable_var > 0
    fev[has_signal] = 1 - (mse[has_signal] - noise_var[has_signal]) / explainable_var[has_signal]
    fev = np.clip(fev, 0, 1)
    print(f"Mean FEV across neurons (subset of 1000 images): {np.nanmean(fev):.4f}")