In [None]:
# 1. Data Collection and Preprocessing

# Neural responses were recorded from 166 neurons in macaque V1 using a 32-channel array.
# Stimuli were natural images and synthesized textures, presented at 60 ms per image with no blank screens.
# Spike counts were extracted from 40-100 ms after stimulus onset.
# [TODO] Only neurons with at least 15% of their variance attributable to the stimulus were included.

# 2. Feature Extraction from VGG-19

# VGG-19, pre-trained on ImageNet, was used to extract features.
# Conv3_1 was chosen as the best feature layer for V1 response prediction.
# [TODO] Feature maps from VGG-19 were passed through batch normalization before being used for fitting.

# 3. Generalized Linear Model (GLM) Readout

# A GLM with a Poisson loss function was used to map VGG-19 features to neural responses.
# Three regularization terms were applied:
# L1 sparsity: Encourages feature selection.
# Spatial smoothness: Ensures receptive field locality.
# Group sparsity: Encourages pooling from a subset of feature maps.

# 4. Performance Evaluation Using FEV

# Fraction of Explainable Variance Explained (FEV) was computed as:

# 5. Results 
# VGG-19 (Conv3_1) achieved ~51.6% FEV, outperforming LNP (16.3%) and GFB (45.6%).
# The data-driven CNN performed similarly to VGG-19 but required more training data.
# VGG-19's advantage: Achieved high performance with only 20% of the dataset, while the CNN needed the full dataset.

In [None]:
# Questions

# My analysis gives very poor (- or close to zero) r-squared scores, so:

# 1. Would using 1400 images instead of 7250 explain this poor score?
# 2. Would using a pretrained VGG-19 mode, without Cadena's weights, explain this poor score?
# 3. Would not copying Cadena's image preprocessing (downsampling and cropping) explain this poor score?
# 4. Would not using their regularisation techniques explain this poor score (I don't think so, because they show that it lifts score from ~30% to ~50%)?

# - What size images do they pass into the model? [224x224]
# - They mention that only with at least 15% of their variance attributable to the stimulus were included - is this filtering already applied? [Yes, 166 neurons in data["responses"]]
# - They mention feature maps from VGG-19 were passed through batch normalization before being used for fitting - do I need to do this manually? [I think I need to do this manually]
# - They use feature pooling - do I need to do this?
# - They use feature maps directly, apply spatial pooling on all, then flatten. I was flattening features in batches and not applying spatial pooling. This may turn out to be important for the SimCLR analysis.

In [1]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import pickle
import os
from sklearn.decomposition import PCA
from sklearn.linear_model import Ridge
import torch.nn.functional as F
from sklearn.metrics import r2_score

DATA_PATH = '../../data/cadena_plosCB19/'
FILE = 'cadena_ploscb_data.pkl'

# ===================================
# Load Cadena's daata
# ===================================
def load_neural_data():
    file_path = os.path.join(DATA_PATH, FILE)
    with open(file_path, "rb") as f:
        data = pickle.load(f)
    return data

data_dict = load_neural_data()

# ===================================
# Clean neural data https://github.com/sacadena/Cadena2019PlosCB/blob/master/cnn_sys_ident/data.py
# ===================================
responses = data_dict['responses'].copy() 
responses[np.isnan(responses)] = 0
data_dict['responses'] = responses

# ===================================
# Extract features from VGG-19
# ===================================

# Load VGG-19 and extract features only up to conv1_1
vgg = models.vgg19(pretrained=True).features[:2].eval()

# Move model to CPU (or GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg.to(device)

class ImageDataset(Dataset):
    def __init__(self, images):
        self.images = images
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),  # Resize for VGG-19
            transforms.Grayscale(num_output_channels=3),  # Convert to 3-channel grayscale
            transforms.ToTensor(),  # Convert to PyTorch tensor
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalization
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]

        # Ensure img is in float32 format before converting to PIL
        if isinstance(img, np.ndarray):
            img = img.astype(np.float32)  # Explicitly cast to float32
            img = Image.fromarray(img)  # Convert NumPy to PIL

        return self.transform(img)  # Apply transforms

# Initialize dataset and DataLoader
dataset = ImageDataset(data_dict["images"][:1000])
dataloader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

# Extract features in batches (to avoid RAM overload)
all_features = []
with torch.no_grad():
    for i, batch in enumerate(dataloader):
        batch = batch.to(device)
        features = vgg(batch)  # Extract Conv1_1 features

        all_features.append(features.cpu())  # Keep shape (batch, C, H, W)

        # Print progress every 100 batches
        if i % 100 == 0:
            print(f"Processed {i * len(batch)} images...")

# Concatenate all extracted features
vgg_features = torch.cat(all_features, dim=0)
print("Final extracted features shape:", vgg_features.shape)  # Expected: (n_images, feature_dim)

# ===================================
# Apply spatial pooling to features, then flatten
# NOTE: Cadena et al. use learned spatial pooling, not fixed pooling like adaptive_avg_pool2d.
# NOTE: Their pooling is encouraged through trainable pooling weights and regularization​.
# TODO: to replicate, might want to use a trainable depthwise convolution, instead of adaptive_avg_pool2d
# TODO: they also apply batch normalisation BEFORE pooling https://github.com/sacadena/Cadena2019PlosCB/blob/master/cnn_sys_ident/vggsysid.py 
# vgg_feats_bn = tf.layers.batch_normalization(vgg_features, training = self.is_training, momentum = 0.9, epsilon = 1e-4, name='vgg_bn', fused =True)
# ===================================

# Apply spatial pooling before flattening
if vgg_features.dim() == 2:  # Likely (n_images, flattened_dim)
    raise ValueError("vgg_features is already flattened! Ensure it retains spatial dimensions.")
pooled_features = F.adaptive_avg_pool2d(vgg_features, (14, 14)) # Reduce spatial size
X_pooled = pooled_features.view(pooled_features.size(0), -1).cpu().numpy() # Flatten pooled maps

X = X_pooled
Y = data_dict["responses"].mean(axis=0)[:1000]

# Features shape: (1000, 12544)
# Responses shape: (1000, 166)
print("Features shape:", X_pooled.shape)
print("Responses shape:", Y.shape)

# ===================================
# Regression
# ===================================

ridge = Ridge(alpha=1.0) # Adjust alpha for regularization strength
ridge.fit(X, Y) # Train on all neurons at once
Y_pred = ridge.predict(X)
r2 = r2_score(Y, Y_pred, multioutput='raw_values')

print("R-squared scores:", r2)
print("Mean R-squared score:", r2.mean())



Processed 0 images...
Processed 400 images...
Processed 800 images...
Final extracted features shape: torch.Size([1000, 64, 224, 224])
Features shape: (1000, 12544)
Responses shape: (1000, 166)
R-squared scores: [0.89060277 0.88540191 0.8963989  0.90245949 0.89787389 0.89113379
 0.89793539 0.90631399 0.90327997 0.89530422 0.91103059 0.89458142
 0.93634465 0.89253405 0.91119525 0.87338582 0.89048103 0.86483947
 0.90539144 0.90791706 0.9125686  0.91423498 0.89813786 0.90189862
 0.90640894 0.89071036 0.88935033 0.90857295 0.91338812 0.8940498
 0.89509193 0.900233   0.91826891 0.90020132 0.89508512 0.91037615
 0.90831182 0.90916178 0.90680847 0.88749172 0.89667413 0.90326387
 0.90183524 0.90063514 0.88742066 0.87221846 0.92751308 0.91913752
 0.90384837 0.91061083 0.92732379 0.86964886 0.91588777 0.90747767
 0.89022524 0.90182161 0.91420553 0.90633035 0.89273091 0.91211604
 0.9246586  0.9255471  0.89255537 0.89299036 0.89984535 0.91049699
 0.85459994 0.910977   0.89242959 0.91369392 0.90713