In [None]:
from cortexlib.mouse import CortexlabMouse

# Analysis knobs forwarded to the CortexlabMouse helpers below.
N_SHUFFLES = 100             # n_shuffles for compute_null_all_neurons
RELIABILITY_PERCENTILE = 99  # percentile_threshold for get_reliable_neuron_indices
N_NEURONS = 500              # num_neurons for get_responses_for_reliable_neurons

mouse = CortexlabMouse()

# Null (shuffled) and real scores for every recorded neuron, then the subset
# of neurons the helper deems reliable at the chosen percentile.
null_srv_all_neurons = mouse.compute_null_all_neurons(n_shuffles=N_SHUFFLES)
real_srv_all_neurons = mouse.compute_real_srv_all_neurons()
reliable_neuron_indices = mouse.get_reliable_neuron_indices(
    null_srv_all_neurons,
    real_srv_all_neurons,
    percentile_threshold=RELIABILITY_PERCENTILE,
)
neural_responses_mean, _ = mouse.get_responses_for_reliable_neurons(
    reliable_neuron_indices,
    real_srv_all_neurons,
    num_neurons=N_NEURONS,
)

# Diagnostic plots: the null distribution of one example neuron, and the real
# score distribution alongside the reliable indices.
mouse.plot_null_distribution_for_neuron(null_srv_all_neurons, neuron_index=0)
mouse.plot_real_srv_distribution(real_srv_all_neurons, reliable_neuron_indices)

In [None]:
from cortexlib.images import CortexlabImages

images = CortexlabImages()

# Preview the raw image for the first stimulus the mouse saw.
first_stimulus_id = int(mouse.stimulus_ids[0])
images.plot_raw_image(first_stimulus_id)

# Load every stimulus shown to the mouse and display a handful of them.
image_dataset = images.load_images_shown_to_mouse(mouse.stimulus_ids)
images.show_sample(image_dataset, n=5)

In [None]:
from cortexlib.simclr import PreTrainedSimCLRModel

# Load the pre-trained SimCLR model and extract its features for every stimulus image.
simclr = PreTrainedSimCLRModel()
simclr_features = simclr.extract_features(image_dataset)

In [None]:
from cortexlib.vgg19 import PreTrainedVGG19Model

def _vgg19_feature_layers():
    """Yield (block, name, index) for every module in VGG19's feature stack.

    Layout: blocks 1-5 contain 2, 2, 4, 4, 4 conv/ReLU pairs respectively, each
    block followed by a max-pool. Index 0 is conv1_1 and index 36 is pool5.
    """
    index = 0
    for block, n_convs in enumerate((2, 2, 4, 4, 4), start=1):
        for conv in range(1, n_convs + 1):
            yield block, f"conv{block}_{conv}", index
            yield block, f"relu{block}_{conv}", index + 1
            index += 2
        yield block, f"pool{block}", index
        index += 1


# Capture every conv/ReLU/pool output from these blocks (indices 5-18 for 2 and 3).
VGG19_BLOCKS_TO_CAPTURE = (2, 3)

vgg19 = PreTrainedVGG19Model(layers_to_capture={
    name: index
    for block, name, index in _vgg19_feature_layers()
    if block in VGG19_BLOCKS_TO_CAPTURE
})

vgg19_features = vgg19.extract_features(image_dataset)

# e.g. conv3_1: torch.Size([1573, 256, 24, 24])
for layer_name, activations in vgg19_features.items():
    print(f"{layer_name}: {activations.shape}")

In [None]:
from scipy.spatial.distance import pdist, squareform
from scipy.stats import spearmanr
import numpy as np
from sklearn.decomposition import PCA

def reduce_dim(feats, n_components=100, random_state=0):
    """Flatten each sample's features and project them onto their top principal components.

    Args:
        feats: per-sample features with samples on the first axis, as a torch
            tensor or array-like. All trailing axes are flattened.
        n_components: number of principal components to keep.
        random_state: seed forwarded to sklearn's PCA. For large inputs, the
            'auto' solver switches to randomized SVD. Without a fixed seed the
            projection (and every RSA score built on it) changes run to run.

    Returns:
        ndarray of shape (n_samples, n_components) from `PCA.fit_transform`.
    """
    if hasattr(feats, "detach"):
        # torch.Tensor: drop any autograd graph and move to host memory,
        # otherwise .numpy() raises for requires_grad or CUDA tensors.
        feats = feats.detach().cpu().numpy()
    # reshape (unlike Tensor.view) also works on non-contiguous data.
    feats_np = np.asarray(feats).reshape(len(feats), -1)
    pca = PCA(n_components=n_components, random_state=random_state)
    return pca.fit_transform(feats_np)

def compute_rdm(X, metric='correlation'):
    """Build a square representational dissimilarity matrix (RDM) from X.

    Rows of X are the items being compared. Entry (i, j) of the result is the
    scipy `pdist` distance `metric` between rows i and j. The default
    'correlation' metric is 1 - Pearson correlation.
    """
    condensed = pdist(X, metric=metric)
    return squareform(condensed)

def vectorize_rdm(rdm):
    """Return the strictly-upper-triangle entries of a square RDM as a 1-D array.

    The diagonal is excluded. Entries come out in row-major order
    (0,1), (0,2), ..., (1,2), ...
    """
    n_items = rdm.shape[0]
    rows, cols = np.triu_indices(n_items, k=1)
    return rdm[rows, cols]

# Principal-component counts to try for each layer before building its RDM.
PCA_COMPONENTS = [50]

# The neural RDM depends on neither the layer nor the PCA setting, so build it
# (and its upper-triangle vector) once rather than on every iteration.
rdm_neural = compute_rdm(neural_responses_mean)
neural_rdm_vector = vectorize_rdm(rdm_neural)

# (layer, num_pcs) -> Spearman correlation with the neural RDM. Kept so later
# cells can plot these directly instead of re-typing printed values.
rsa_results = {}

for layer, feats in vgg19_features.items():
    print(f"Layer {layer} features shape: {feats.shape}")

    for num_pcs in PCA_COMPONENTS:
        reduced_features = reduce_dim(feats, n_components=num_pcs)
        print(f"Reduced VGG19 {layer} features with {num_pcs} PCs shape: {reduced_features.shape}")

        rdm_vgg = compute_rdm(reduced_features)
        sim = spearmanr(neural_rdm_vector, vectorize_rdm(rdm_vgg)).correlation
        rsa_results[(layer, num_pcs)] = sim

        print(f"Spearman correlation between neural RDM and VGG19 {layer} RDM: {sim:.4f}")

In [None]:
import matplotlib.pyplot as plt

# Hardcoded RSA scores for SimCLR and VGG-19, pasted from earlier runs.
# NOTE(review): the active rows presumably correspond to the 50-PC setting used
# in the RSA loop (PCA-10 / PCA-100 rows are kept commented for reference) — confirm.
# NOTE(review): `vgg_layers` includes conv1_1, conv4_1 and conv5_1, which the
# VGG19 cell does not currently capture, so these VGG numbers cannot be
# regenerated from this notebook as configured — TODO re-run with those layers.
simclr_layers = ["layer1", "layer2", "layer3", "layer4", "fc"]
vgg_layers = ["conv1_1", "conv2_1", "conv3_1", "conv4_1", "conv5_1"]  # same length/order as simclr_layers
layer_labels = ["early", "low-mid", "mid", "high", "top"]  # shared x-axis labels for both models

rsa_scores = {
    # "SimCLR PCA-10": [0.1805, 0.1847, 0.1864, 0.1414, 0.1445],
    "SimCLR": [0.1816, 0.1962, 0.1905, 0.1427, 0.1506],
    # "SimCLR PCA-100": [0.1740, 0.1920, 0.1901, 0.1430, 0.1519],
    # "VGG19 PCA-10": [0.0564, 0.1738, 0.1852, 0.1490, 0.1158],
    "VGG19": [0.0531, 0.1689, 0.1787, 0.1437, 0.1149],
    # "VGG19 PCA-100": [0.0510, 0.1627, 0.1720, 0.1391, 0.1141],
}

# Every curve needs exactly one score per x-axis position; fail with a clear message otherwise.
for label, scores in rsa_scores.items():
    assert len(scores) == len(layer_labels), (
        f"{label}: expected {len(layer_labels)} scores, got {len(scores)}"
    )

fig, ax = plt.subplots(figsize=(10, 6))
for label, scores in rsa_scores.items():
    ax.plot(layer_labels, scores, marker='o', label=label, linewidth=2.5)

# Only one PCA setting per model is plotted, so the title does not claim a PCA sweep.
ax.set_title("RSA (Spearman) Between Model Features and Neural Data\nAcross Layers", fontsize=14)
ax.set_xlabel("Layer", fontsize=12)
ax.set_ylabel("RSA (Spearman correlation)", fontsize=12)
ax.set_ylim(0.04, 0.22)
ax.grid(True, linestyle='-', alpha=0.5)
ax.legend()
fig.tight_layout()
plt.show()