In [None]:
import numpy as np
from os import path
import matplotlib.pyplot as plt

### Load data

# imresps.npy: shape (1573, 2, 15363) = (n_images, n_repeats, n_neurons), i.e. the
#   responses of 15363 neurons to 1573 images, each image shown twice.
# stimids.npy: shape (1573,) - the image id (matching the image dataset ~selection1866~)
#   for each row of imresps, so the image presented for imresps[502] is stimids[502].

PATH_TO_DATA = '../../data/neural'

imresps = np.load(path.join(PATH_TO_DATA, 'imresps.npy'))
stimids = np.load(path.join(PATH_TO_DATA, 'stimids.npy'))

# Every response row needs exactly one image id; otherwise the image/response pairing
# used throughout the rest of the notebook silently breaks.
assert imresps.ndim == 3, f"expected (images, repeats, neurons), got {imresps.shape}"
assert stimids.shape == (imresps.shape[0],), (
    f"stimids shape {stimids.shape} does not match imresps shape {imresps.shape}"
)

print(imresps.shape) # (1573, 2, 15363)
print(stimids.shape) # (1573,)

### Compute the null distribution of SRV values for all neurons

def compute_signal_related_variance(resp_a, resp_b, mean_center=True):
    """
    Compute the cross-validated fraction of stimulus-related variance for each neuron,
    as per Stringer et al., Nature 2019.

    The responses are split into two halves (e.g. the two repeats of every stimulus).
    Assuming trial-to-trial noise is independent across the halves, the per-cell
    covariance between the halves estimates the stimulus-driven variance; dividing it
    by the average variance of the halves gives the stimulus-related fraction.
    Note, this is only "correct" if resp_a and resp_b are single trials and *not*
    averages of many trials.

    Args:
        resp_a (ndarray): (n_stimuli, n_cells). Arrays with more than 2 dimensions are
            treated as (..., n_cells) and all leading axes are flattened into stimuli.
        resp_b (ndarray): same layout as resp_a - the other half of the trials.
        mean_center (bool): subtract each cell's mean over stimuli (per half) before
            computing covariance and variance. Default True.

    Returns:
        fraction_of_stimulus_variance (ndarray, (n_cells,)): ~1 for purely
            stimulus-driven cells, ~0 for cells that do not care about the stimulus.
            Being cross-validated it can be negative for noisy cells, and it is NaN
            for cells whose total variance is zero.
        stim_to_noise_ratio (ndarray, (n_cells,)): f / (1 - f), the ratio of the
            stimulus-related variance to all other variance (inf where f == 1).
    """
    if len(resp_a.shape) > 2:
        # if the stimulus is multi-dimensional, flatten across all stimuli
        resp_a = resp_a.reshape(-1, resp_a.shape[-1])
        resp_b = resp_b.reshape(-1, resp_b.shape[-1])
    n_stimuli = resp_a.shape[0]

    if mean_center:
        # mean-center the activity of each cell (separately in each half)
        resp_a = resp_a - resp_a.mean(axis=0)
        resp_b = resp_b - resp_b.mean(axis=0)

    # cross-half covariance of each cell: dot each cell's (n_stim,) vector from one
    # half with its own (n_stim,) vector from the other half
    covariance = (resp_a * resp_b).sum(axis=0) / n_stimuli

    # variance of each cell, averaged over the two halves
    resp_a_variance = (resp_a ** 2).sum(axis=0) / n_stimuli
    resp_b_variance = (resp_b ** 2).sum(axis=0) / n_stimuli
    total_variance = (resp_a_variance + resp_b_variance) / 2

    near_zero = total_variance < 1e-12
    if np.any(near_zero):
        print(f"Warning: Near-zero total variance for neurons: {np.where(near_zero)[0]}")

    # The warning above already reports degenerate cells, so silence numpy's own
    # divide/invalid RuntimeWarnings (otherwise repeated on every shuffle). The results
    # are unchanged: x / 0 -> +-inf and 0 / 0 -> NaN, exactly as before.
    with np.errstate(divide='ignore', invalid='ignore'):
        # fraction of the total variance captured by the cross-half covariance
        fraction_of_stimulus_variance = covariance / total_variance

        # signal-to-noise ratio derived from the fraction
        stim_to_noise_ratio = fraction_of_stimulus_variance / (
            1 - fraction_of_stimulus_variance
        )

    return fraction_of_stimulus_variance, stim_to_noise_ratio

# Indexing: imresps is (n_stimuli, n_repeats, n_neurons); selecting one repeat,
# imresps[idx, r, :], yields (n_stimuli, n_neurons) - the (n_stimuli, n_cells) layout
# that compute_signal_related_variance expects.
num_stimuli = imresps.shape[0] # 1573
num_repeats = imresps.shape[1] # 2
num_neurons = imresps.shape[2] # 15363
n_shuffles = 100
NULL_SEED = 0  # fixed seed so the null distribution (and the thresholds derived from it) is reproducible

rng_null = np.random.default_rng(NULL_SEED)

null_srv_all_neurons = []  # one (num_neurons,) SRV vector per shuffle

for _shuffle in range(n_shuffles):
    # Shuffle the stimulus order *independently* for the two splits, so that at each
    # row the two responses belong to different stimuli, e.g.
    #   split_A = [stim_100_repeat_1, stim_2_repeat_1,   stim_19_repeat_1, ...]
    #   split_B = [stim_543_repeat_2, stim_345_repeat_2, stim_3_repeat_2,  ...]
    # Any covariance left between the splits is then chance, not stimulus tuning.
    split_A = imresps[rng_null.permutation(num_stimuli), 0, :]
    split_B = imresps[rng_null.permutation(num_stimuli), 1, :]

    # SRV of every neuron for this shuffle
    fraction_of_stimulus_variance, _ = compute_signal_related_variance(split_A, split_B)
    null_srv_all_neurons.append(fraction_of_stimulus_variance)

null_srv_all_neurons = np.array(null_srv_all_neurons)
assert null_srv_all_neurons.shape == (n_shuffles, num_neurons)

print(null_srv_all_neurons[0])
print(null_srv_all_neurons[33])

# Null distribution of a single neuron: its SRV across all shuffles (one column).
neuron_index = 0
plt.hist(null_srv_all_neurons[:, neuron_index], bins=100, color='blue', alpha=0.7)
plt.xlabel("Fraction of Stimulus-Related Variance (SRV)")
plt.ylabel("Number of Shuffles")
plt.title(f"Null Distribution of SRV for Neuron {neuron_index}")
plt.show()

### Compute the real SRV for each neuron

# Each split holds one repeat of every stimulus, so row i of split_A and split_B are
# two trials of the *same* stimulus. Splitting as imresps[:, 0, :] / imresps[:, 1, :]
# would also be valid; randomly choosing, per stimulus, which repeat goes to which
# split simply avoids tying each split to one presentation order.
assert imresps.shape[1] == 2, "the split below assumes exactly two repeats per stimulus"
SPLIT_SEED = 1  # fixed seed so the real SRV values are reproducible

rng_split = np.random.default_rng(SPLIT_SEED)
n_stim = imresps.shape[0]
stim_rows = np.arange(n_stim)
repeat_for_A = rng_split.integers(0, 2, size=n_stim)  # 0 or 1 for each stimulus

split_A = imresps[stim_rows, repeat_for_A, :]      # Shape: (n_stimuli, n_neurons)
split_B = imresps[stim_rows, 1 - repeat_for_A, :]  # Shape: (n_stimuli, n_neurons)

# Compute SRV for real data
real_srv_all_neurons, stim_to_noise_ratio = compute_signal_related_variance(split_A, split_B)

print(real_srv_all_neurons)
print(stim_to_noise_ratio)

print("Real SRV shape:", real_srv_all_neurons.shape) # Should be (15363,)

# Distribution of the real (unshuffled) SRV across all neurons
plt.hist(real_srv_all_neurons, bins=100, color='blue', alpha=0.7)
plt.xlabel("Fraction of Stimulus-Related Variance (SRV)")
plt.ylabel("Number of Neurons")
plt.title("Real SRV across all neurons")
plt.show()

### Filter neurons whose real SRV reaches the 99th percentile of their own null distribution

NULL_PERCENTILE = 99

# Per-neuron reliability threshold, shape (num_neurons,): the NULL_PERCENTILE-th
# percentile of that neuron's shuffled (null) SRVs. np.percentile interpolates linearly,
# e.g. for a 10-shuffle null of [0.1, 0.2, ..., 1.0] the 99th percentile is 0.991.
top_99th_percentile_null = np.percentile(null_srv_all_neurons, NULL_PERCENTILE, axis=0)
print(top_99th_percentile_null)

# Reliable = real SRV at or above its own null threshold. Neurons with a NaN SRV
# (zero variance) compare False and are therefore excluded.
reliable_neuron_indices = np.where(real_srv_all_neurons >= top_99th_percentile_null)[0]

print(f"Number of reliable neurons: {len(reliable_neuron_indices)}")
print(f"Indices of reliable neurons: {reliable_neuron_indices}")

# Overlay both histograms on the same bin edges so the bars are directly comparable.
finite_srv = real_srv_all_neurons[np.isfinite(real_srv_all_neurons)]
srv_bins = np.histogram_bin_edges(finite_srv, bins=100)
plt.hist(real_srv_all_neurons, bins=srv_bins, color='red', alpha=0.7, label='all neurons')
plt.hist(real_srv_all_neurons[reliable_neuron_indices], bins=srv_bins, color='blue', alpha=0.7,
         label=f'reliable (>= {NULL_PERCENTILE}th pct of null)')
plt.xlabel("Fraction of Stimulus-Related Variance (SRV)")
plt.ylabel("Number of Neurons")
plt.title("All Neurons: SRV all vs. SRV reliable")
plt.legend()
plt.show()

# Responses of the reliable neurons, averaged over the two repeats:
# (n_images, n_repeats, n_reliable) -> (n_images, n_reliable)
neural_responses = imresps[:, :, reliable_neuron_indices]
neural_responses_mean = neural_responses.mean(axis=1)

In [None]:
### Load and preprocess images

import os
from scipy.io import loadmat
import matplotlib.pyplot as plt
import numpy as np
from torchvision.transforms import Normalize, Compose, Resize, CenterCrop, ToTensor
import torch
from torch.utils.data import TensorDataset
from torchvision import utils as torch_utils
 
PATH_TO_DATA = '../../data/selection1866'

file_list = sorted(f for f in os.listdir(PATH_TO_DATA) if f.endswith('.mat'))

# stimids is stored as floats (1., 2., 3., ...). Round before casting so an id that is
# not exactly representable (e.g. 2.9999999) is not truncated to the wrong image.
stim_ids = np.rint(stimids).astype(int)
assert np.array_equal(stim_ids, stimids), "stimids contains non-integer image ids"

print(stim_ids)
print(stimids)

# ImageNet preprocessing for the pretrained VGG-19 used in the next cell.
# The images are turned into float tensors in [0, 1] by hand below, so ToTensor is not
# part of this pipeline: Resize / CenterCrop / Normalize all operate on (C, H, W) tensors.
IMAGE_SIZE = 224
transform = Compose([
    Resize(IMAGE_SIZE, antialias=True),  # resize the shortest edge to 224, keeping the aspect ratio
    CenterCrop((IMAGE_SIZE, IMAGE_SIZE)),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet normalization
])

img_tensors, labels = [], []

print(f'Found {len(file_list)} .mat image files in {PATH_TO_DATA}')

# we have 1866 images here, but the neural response data only uses 1573 of them
# because ~300 images didn't have two repeats and were discarded;
# therefore we load only the images listed in stim_ids, in the same order as imresps.
for stim_id in stim_ids:
    data = loadmat(os.path.join(PATH_TO_DATA, f'img{stim_id}.mat'))

    img = data['img'][:, :500]  # keep the leftmost 500 columns of the image
    rgb_img = np.stack([img] * 3, axis=-1)  # grayscale -> 3 identical channels for VGG
    tensor = torch.tensor(rgb_img, dtype=torch.float32).permute(2, 0, 1)  # Shape (C, H, W)

    # Fixed scaling [0, 255] -> [0, 1] (instead of per-image min-max scaling), then the
    # ImageNet preprocessing. Assumes stored pixels are 8-bit intensities - see the
    # min/max printout below to confirm.
    tensor = tensor / 255.0
    img_tensors.append(transform(tensor))
    labels.append(stim_id)

image_dataset = TensorDataset(torch.stack(img_tensors), torch.tensor(labels))

images, labels = image_dataset.tensors
print("Labels:", labels[:10])
print("Processed dataset shape:", images.shape)  # (N, 3, 224, 224)
print(f"Min pixel value (processed): {torch.min(images)}")
print(f"Max pixel value (processed): {torch.max(images)}")

# Show a sample of processed images (make_grid rescales each image for display)
img_grid = torch_utils.make_grid(images[:12], nrow=6, normalize=True, pad_value=0.9)
img_grid = img_grid.permute(1, 2, 0).numpy()
plt.figure(figsize=(10, 5))
plt.title('Processed images: sample')
plt.imshow(img_grid)
plt.axis('off')
plt.show()
plt.close()

In [None]:
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score

# Pretrained (ImageNet) VGG-19 convolutional stack, switched to inference mode.
vgg19 = torchvision.models.vgg19(pretrained=True).features.eval()

# Layers whose outputs are recorded, as name -> index into vgg19. Other available
# conv layers (uncomment/add to capture them):
#   conv1_1: 0,  conv1_2: 2
#   conv2_1: 5,  conv2_2: 7
#   conv3_1: 10, conv3_2: 12, conv3_3: 14, conv3_4: 16
#   conv4_1: 19, conv4_2: 21, conv4_3: 23, conv4_4: 25
#   conv5_1: 28, conv5_2: 30, conv5_3: 32, conv5_4: 34
layers_to_capture = dict(conv3_1=10)

# Written by the forward hooks: layer name -> that layer's output from the latest forward pass.
activations = dict()

def hook_fn(layer_name):
    """Create a forward hook that stores a module's output as ``activations[layer_name]``.

    The stored tensor is detached from the autograd graph but shares storage with the
    module output. NOTE(review): torchvision's VGG follows each conv with an in-place
    ReLU, so the stored conv output is likely rectified in place once that ReLU runs;
    use ``output.detach().clone()`` if pre-ReLU values are wanted.
    """
    def _store_output(module, input, output):
        activations[layer_name] = output.detach()

    return _store_output

# Attach one output-storing hook to each layer listed in layers_to_capture.
for capture_name, module_index in layers_to_capture.items():
    vgg19[module_index].register_forward_hook(hook_fn(capture_name))

@torch.no_grad()
def extract_vgg_features(dataset, batch_size=16, num_workers=4):
    """Run ``dataset`` through VGG-19 and collect the hooked layer activations.

    Relies on the module-level ``vgg19``, ``layers_to_capture`` and ``activations``
    (the latter filled by the forward hooks registered on ``vgg19``).

    Args:
        dataset: a Dataset yielding (image, label) pairs, e.g. the TensorDataset above.
        batch_size (int): images per forward pass.
        num_workers (int): DataLoader worker processes (0 = load in the main process).

    Returns:
        feature_maps (dict[str, Tensor]): layer name -> activations for all images,
            concatenated along dim 0 in dataset order.
        labels (Tensor): all labels, concatenated in the same order.
    """
    # Only run the network as far as it is needed: layers after the deepest captured one
    # are never recorded. The module right after it is kept too - in torchvision's VGG
    # that is an in-place ReLU, and in the full forward pass it rectified the hooked
    # tensor in place, so keeping it preserves exactly the values captured before.
    stop = max(layers_to_capture.values(), default=len(vgg19)) + 2
    truncated_vgg = vgg19[:stop]

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    feature_maps = {layer: [] for layer in layers_to_capture}
    labels = []

    for batch_imgs, batch_labels in tqdm(dataloader):
        truncated_vgg(batch_imgs)  # forward pass only to trigger the hooks

        for layer in layers_to_capture:
            feature_maps[layer].append(activations[layer].cpu())

        labels.append(batch_labels)

    # Concatenate features across all batches
    feature_maps = {layer: torch.cat(maps, dim=0) for layer, maps in feature_maps.items()}
    labels = torch.cat(labels, dim=0)

    return feature_maps, labels

# Extract the hooked activations for every image, in dataset order.
vgg_features, vgg_labels = extract_vgg_features(image_dataset)

# Flatten each (N, C, H, W) feature map to (N, C * H * W) for PCA / regression.
vgg_features_flat = {}
for layer_name, feature_map in vgg_features.items():
    vgg_features_flat[layer_name] = feature_map.view(feature_map.size(0), -1)

# e.g. print(vgg_features["conv3_1"].shape, vgg_features_flat["conv3_1"].shape)

In [None]:
LAYER = "conv3_1"
num_components = 100  # Try 50, 100, 200
TEST_FRACTION = 0.2

features = vgg_features_flat[LAYER].numpy()  # (n_images, C * H * W)

# Ordered (unshuffled) split - the same rows as train_test_split(..., shuffle=False),
# which puts ceil(TEST_FRACTION * n) samples in the test set. Slicing keeps views of
# the large feature matrix instead of copying it.
n_images = features.shape[0]
n_train = n_images - int(np.ceil(TEST_FRACTION * n_images))

# Fit PCA on the training images only, so the test images do not leak into the basis.
pca = PCA(n_components=num_components)
X_train = pca.fit_transform(features[:n_train])
X_test = pca.transform(features[n_train:])
Y_train, Y_test = neural_responses_mean[:n_train], neural_responses_mean[n_train:]

# Train regression
reg = LinearRegression()
reg.fit(X_train, Y_train)
Y_pred = reg.predict(X_test)

# Per-neuron R² on held-out images; summarize rather than printing every neuron.
r2_scores = r2_score(Y_test, Y_pred, multioutput='raw_values')
print(f"R² using {LAYER} with {num_components} PCs: "
      f"mean = {r2_scores.mean():.4f}, median = {np.median(r2_scores):.4f}")

In [None]:
import numpy as np

# Pearson correlation of every principal component (columns of X_train) with every
# neuron (columns of Y_train). rowvar=False treats columns as variables; the joint
# matrix is then cut down to its PC x neuron block.
n_pcs = X_train.shape[1]
joint_corr = np.corrcoef(X_train, Y_train, rowvar=False)
corr_matrix = joint_corr[:n_pcs, n_pcs:]

mean_corr = np.mean(np.abs(corr_matrix))
print(f"Mean absolute correlation: {mean_corr:.6f}")

In [None]:
from sklearn.linear_model import Ridge

REQUESTED_LAYERS = ["conv2_1", "conv3_1", "conv4_1", "conv5_1"]
N_PCS = 50

for layer in REQUESTED_LAYERS:
    # Only layers listed in layers_to_capture were extracted; indexing any other layer
    # used to raise a KeyError. Add it to layers_to_capture and re-extract to include it.
    if layer not in vgg_features:
        print(f"Layer {layer}: skipped (not in layers_to_capture)")
        continue

    layer_activations_flat = vgg_features[layer].flatten(start_dim=1).numpy()  # (N, C*H*W)

    # Split indices first (same shuffled split as splitting the arrays directly) so PCA
    # is fit on training images only and the test images do not leak into the basis.
    train_idx, test_idx = train_test_split(
        np.arange(layer_activations_flat.shape[0]), test_size=0.2, random_state=42
    )
    pca = PCA(n_components=N_PCS)
    X_train = pca.fit_transform(layer_activations_flat[train_idx])
    X_test = pca.transform(layer_activations_flat[test_idx])
    Y_train = neural_responses_mean[train_idx]
    Y_test = neural_responses_mean[test_idx]

    reg = Ridge(alpha=10)
    reg.fit(X_train, Y_train)

    r2_scores = r2_score(Y_test, reg.predict(X_test), multioutput='raw_values')
    print(f"Layer {layer}: Mean R² = {r2_scores.mean():.4f}")

In [None]:
### Why are the regressions performing poorly even for VGG-19?
### (Cadena et al. and Kenneth found it yields 30%-50% explained variance.)
shape_report = (
    f"X_train shape: {X_train.shape}, Y_train shape: {Y_train.shape}",
    f"X_test shape: {X_test.shape}, Y_test shape: {Y_test.shape}",
)
for report_line in shape_report:
    print(report_line)

In [None]:
### Check image indexing and ordering

# stimids is stored as float (1.0, 2.0, ...) while the images were loaded as
# f'img{id}.mat' with integer ids, so build filenames from integer ids as well
# (the float ids would give 'img1.0.mat', which is not the file that was loaded).
check_ids = np.rint(stimids).astype(int)

# 1. Order of stimids and image files
print("First 10 `stimids` entries:", check_ids[:10])
print("First 10 image filenames used:", [f'img{stim_id}.mat' for stim_id in check_ids[:10]])

# 2. Feature-map ordering before PCA: labels of the images passed to VGG
print("First 15 image labels in dataset:", labels[:15].numpy())
print("First 15 images used for neural data:", check_ids[:15])

# 3. Neural response alignment: one response row per image
print("Shape of neural_responses_mean:", neural_responses_mean.shape)

# Enforce the alignment instead of eyeballing the printouts.
assert np.array_equal(labels.numpy(), check_ids), "image order does not match stimids"
assert np.array_equal(vgg_labels.numpy(), check_ids), "VGG feature order does not match stimids"
assert neural_responses_mean.shape[0] == len(check_ids), "one neural response row per image expected"
assert all(feats.shape[0] == len(check_ids) for feats in vgg_features_flat.values()), \
    "one feature row per image expected"