In [None]:
# Semantic clustering helps you characterize the representational content of each layer. When this aligns with 
# where neural similarity peaks, it supports the interpretation that mouse visual cortex and SimCLR mid-layers 
# represent information at a similar level of abstraction.
# The analysis here compares semantic clustering of SimCLR layers and VGG-19 layers.
# In both cases, mid layers show an intermediate level of abstraction between early and late layers.
# The fact that SimCLR shows higher silhouette scores than VGG-19 across the board is less relevant to our hypothesis (alignment with mouse visual cortex).

In [None]:
### Get STL-10 images with labels; collect 100 images per class (10 classes)

import torchvision.transforms as transforms
from torchvision.datasets import STL10
from torch.utils.data import ConcatDataset
from torch.utils.data import TensorDataset 
import torch
import random

# Transform to resize to 96x96 and convert to tensor
transform = transforms.Compose([
    transforms.Resize(96),
    transforms.ToTensor()
])

# Download labeled training and test sets
data_root = '../../data'
train_set = STL10(root=data_root, split='train', download=True, transform=transform)
test_set = STL10(root=data_root, split='test', download=True, transform=transform)
labeled = ConcatDataset([train_set, test_set])

# Collect the first N_PER_CLASS images of each class, scanning train then test in order.
# The selection itself is deterministic; the seed only guards any later use of `random`.
N_CLASSES = 10
N_PER_CLASS = 100
random.seed(42)
label_to_indices = {i: [] for i in range(N_CLASSES)}
label_to_images = {i: [] for i in range(N_CLASSES)}
for idx, (img, label) in enumerate(labeled):
    if len(label_to_indices[label]) < N_PER_CLASS:
        label_to_indices[label].append(idx)
        # Keep the already-transformed image so it is not loaded a second time below
        label_to_images[label].append(img)
    if all(len(v) == N_PER_CLASS for v in label_to_indices.values()):
        break

short_classes = {c: len(v) for c, v in label_to_indices.items() if len(v) < N_PER_CLASS}
assert not short_classes, f"Not enough images for classes (class -> count): {short_classes}"

# Flatten in class order 0..N_CLASSES-1, each class in dataset order
selected_indices = [i for indices in label_to_indices.values() for i in indices]
images = torch.stack([img for imgs in label_to_images.values() for img in imgs])  # (1000, 3, 96, 96)
labels = torch.tensor([c for c, imgs in label_to_images.items() for _ in imgs])  # (1000,)
image_dataset = TensorDataset(images, labels)

print("Shape of images:", images.shape)
print("Shape of labels:", labels.shape)
print("Unique labels:", labels.unique())
print("Number of images:", len(image_dataset))

# Get class names from the STL-10 dataset (one name per line; index = label id)
with open(f'{data_root}/{STL10.base_folder}/{STL10.class_names_file}', 'r') as f:
    class_names = [line.strip() for line in f if line.strip()]
assert len(class_names) == N_CLASSES, f"Expected {N_CLASSES} class names, got {class_names}"
print("Class names:", class_names)

In [None]:
from cortexlib.simclr import PreTrainedSimCLRModel

# Load the pretrained SimCLR model and extract activations for every image in `image_dataset`.
# The next cell consumes the result as a mapping of layer name -> feature tensor; it also
# skips a 'labels' key, so presumably the extractor returns the labels alongside — TODO confirm.
simclr = PreTrainedSimCLRModel()
simclr_feats = simclr.extract_features(image_dataset)

In [None]:
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA

# Map each label to its corresponding class name (tolist() gives plain ints for list indexing)
class_labels = [class_names[label] for label in labels.tolist()]

# Preferred legend order: vehicles first, then animals
ordered_class_names = ['car', 'truck', 'ship', 'airplane', 'bird', 'cat',
                       'dog', 'deer', 'horse', 'monkey']

# Keep track of silhouette scores for each layer (insertion order follows simclr_feats)
simclr_silhouette_scores = {}

for layer, feats in simclr_feats.items():
    # 'labels' is not a layer output; skip it before printing or processing
    if layer == 'labels':
        continue
    print(f"Layer: {layer}, Shape: {feats.shape}")

    # Flatten spatial feature maps to (N, features); 2-D outputs (e.g. fc) pass through.
    # reshape (unlike view) also handles non-contiguous tensors; .cpu() lets sklearn read GPU outputs.
    if feats.dim() > 2:
        feats = feats.reshape(feats.size(0), -1)
    feats = feats.detach().cpu().numpy()

    # PCA down to at most 100 dims before t-SNE; n_components may not exceed min(n_samples, n_features).
    # Fixed random_state makes the (possibly randomized) PCA solver and t-SNE reproducible.
    n_components = min(100, *feats.shape)
    feats_pca = PCA(n_components=n_components, random_state=42).fit_transform(feats)
    tsne_feats = TSNE(n_components=2, perplexity=30, random_state=42).fit_transform(feats_pca)

    plt.figure(figsize=(8, 6))
    sns.scatterplot(x=tsne_feats[:, 0], y=tsne_feats[:, 1], hue=class_labels, palette='tab10', s=30)
    plt.title(f"t-SNE on SimCLR {layer} features")

    # Reorder legend entries; classes missing from the plot are skipped instead of raising KeyError
    handles, labels_ = plt.gca().get_legend_handles_labels()
    label_to_handle = dict(zip(labels_, handles))
    legend_names = [cls for cls in ordered_class_names if cls in label_to_handle]
    plt.legend([label_to_handle[cls] for cls in legend_names], legend_names,
               bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.show()

    # NOTE: the silhouette score is measured on the 2-D t-SNE embedding, which does not preserve
    # global distances; it describes separation in the visualization rather than in feature space
    # (silhouette_score(feats_pca, class_labels) would measure the latter).
    score = silhouette_score(tsne_feats, class_labels)
    print(f"Silhouette score: {score:.3f}")

    simclr_silhouette_scores[layer] = score

plt.figure()
plt.plot(list(simclr_silhouette_scores.keys()), list(simclr_silhouette_scores.values()), marker='o')
plt.xlabel("Layer")
plt.ylabel("Silhouette score")
plt.title("Semantic cluster separability by SimCLR layer")
plt.show()

In [None]:
from cortexlib.vgg19 import PreTrainedVGG19Model

# Capture five VGG-19 layers; the names suggest the first conv of each of the five conv blocks.
# NOTE(review): the integers are presumably indices into torchvision's `vgg19().features`
# Sequential (0, 5, 10, 19, 28 are the conv1_1..conv5_1 positions there) — confirm against
# cortexlib.vgg19.
vgg19 = PreTrainedVGG19Model(layers_to_capture = {
    "conv1_1": 0,
    "conv2_1": 5,
    "conv3_1": 10,
    "conv4_1": 19,
    "conv5_1": 28,
})

# Presumably a mapping of layer name -> feature tensor; the next cell iterates it with .items().
vgg19_feats = vgg19.extract_features(image_dataset)

In [None]:
# Keep track of silhouette scores for each VGG-19 layer (insertion order follows vgg19_feats)
vgg19_silhouette_scores = {}

# Preferred legend order: vehicles first, then animals
ordered_class_names = ['car', 'truck', 'ship', 'airplane', 'bird', 'cat',
                       'dog', 'deer', 'horse', 'monkey']

for layer, feats in vgg19_feats.items():
    # Skip a non-layer 'labels' entry if the extractor returns one (the SimCLR extractor's output is
    # treated this way above); otherwise it would be fed to PCA as if it were features.
    if layer == 'labels':
        continue
    print(f"Layer: {layer}, Shape: {feats.shape}")

    # Flatten spatial feature maps to (N, features); reshape handles non-contiguous tensors,
    # and .cpu() lets sklearn read GPU outputs.
    if feats.dim() > 2:
        feats = feats.reshape(feats.size(0), -1)
    feats = feats.detach().cpu().numpy()

    # PCA down to at most 100 dims before t-SNE; n_components may not exceed min(n_samples, n_features).
    # Fixed random_state makes the (possibly randomized) PCA solver and t-SNE reproducible.
    n_components = min(100, *feats.shape)
    feats_pca = PCA(n_components=n_components, random_state=42).fit_transform(feats)
    tsne_feats = TSNE(n_components=2, perplexity=30, random_state=42).fit_transform(feats_pca)

    plt.figure(figsize=(8, 6))
    sns.scatterplot(x=tsne_feats[:, 0], y=tsne_feats[:, 1], hue=class_labels, palette='tab10', s=30)
    plt.title(f"t-SNE on VGG19 {layer} features")

    # Reorder legend entries; classes missing from the plot are skipped instead of raising KeyError
    handles, labels_ = plt.gca().get_legend_handles_labels()
    label_to_handle = dict(zip(labels_, handles))
    legend_names = [cls for cls in ordered_class_names if cls in label_to_handle]
    plt.legend([label_to_handle[cls] for cls in legend_names], legend_names,
               bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.show()

    # NOTE: measured on the 2-D t-SNE embedding (not the original feature space), matching how
    # the SimCLR scores are computed so the two models remain comparable.
    score = silhouette_score(tsne_feats, class_labels)
    print(f"Silhouette score: {score:.3f}")

    vgg19_silhouette_scores[layer] = score

plt.figure()
plt.plot(list(vgg19_silhouette_scores.keys()), list(vgg19_silhouette_scores.values()), marker='o')
plt.xlabel("Layer")
plt.ylabel("Silhouette score")
plt.title("Semantic cluster separability by VGG-19 layer")
plt.show()

In [None]:

# Overlay both models' silhouette curves on a shared depth axis. Layers are paired by position
# (dict insertion order), so each model must contribute exactly one score per depth label.
simclr_layers, simclr_scores = zip(*simclr_silhouette_scores.items())
vgg_layers, vgg19_scores = zip(*vgg19_silhouette_scores.items())
layer_labels  = ['early', 'mid1', 'mid2', 'late', 'final']
assert len(simclr_scores) == len(vgg19_scores) == len(layer_labels), (
    f"Expected {len(layer_labels)} layers per model; "
    f"got SimCLR {simclr_layers} and VGG19 {vgg_layers}"
)
layer_ids = list(range(1, len(layer_labels) + 1))

plt.figure(figsize=(8, 5))
plt.plot(layer_ids, simclr_scores, marker='o', label='SimCLR')
plt.plot(layer_ids, vgg19_scores, marker='o', label='VGG19')

plt.xticks(layer_ids, layer_labels)
plt.xlabel("Relative layer depth")
plt.ylabel("Silhouette Score")
plt.title("Semantic Clustering Across Layers")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()