In [None]:
import numpy as np
from os import path

In [None]:
### Load data

# imresps.npy has shape (1573, 2, 15363): 1573 images, 2 repeats each, and 15363 recorded neurons
# stimids.npy holds the image id (matching the image dataset ~selection1866~) for each stimulus number,
# so if you want to see which image was presented at imresps[502] you would check stimids[502]

PATH_TO_DATA = '../../data/neural'

imresps = np.load(path.join(PATH_TO_DATA, 'imresps.npy'))
stimids = np.load(path.join(PATH_TO_DATA, 'stimids.npy'))

# Every response row needs exactly one stimulus id, otherwise the image <-> response alignment is wrong
assert imresps.shape[0] == stimids.shape[0], (imresps.shape, stimids.shape)

print(imresps.shape) # (1573, 2, 15363)
print(stimids.shape) # (1573,)

In [None]:
# Manually assigned semantic categories for a subset of the stimulus images.
# Images not listed here are treated as 'undetermined' downstream.
# Within each category, ids keep the order in which they were originally labelled.
# NOTE(review): the earlier hand-written list repeated 45, 116, 176 and 295. The repeats were dropped here.
# The second 176 appeared right after 177 and may have been a typo for another bird id; verify against the images.
MANUAL_STIM_IDS_BY_CATEGORY = {
    'bird': [
        1, 9, 152, 153, 155, 156, 157, 158, 159, 160,
        161, 162, 163, 164, 166, 170, 172, 173, 174, 175,
        176, 177, 182, 184, 185, 186, 187, 188, 189, 190,
        191, 192, 131, 203, 206, 674, 1200, 1207, 1240, 1790,
    ],
    'fungus': [
        2, 4, 5, 7, 10, 11, 12, 13, 16, 17,
        18, 25, 28, 33, 35, 40, 41, 44, 45, 48,
        49, 50, 51, 55, 56,
    ],
    'snake': [
        100, 101, 102, 103, 108, 109, 110, 111, 112, 114,
        115, 116, 118, 121, 122, 123, 124, 125, 126, 128,
        129, 619, 620, 621, 622, 623, 624, 625, 626, 628,
        629, 630, 632, 633, 634, 635, 637, 639, 640, 641,
    ],
    'cat': [
        133, 134, 135, 136, 137, 138, 139, 140, 141, 143,
        144, 145, 146, 147, 148, 149, 150, 219, 221, 222,
        223, 224, 225, 226, 227, 228, 229, 230, 232, 233,
        235, 237, 238, 239, 241, 243, 244, 245, 246, 247,
        248, 249, 250, 252, 254, 255, 256, 257, 258,
    ],
    'rodent': [
        291, 293, 294, 295, 296, 297, 299, 300, 302, 304,
        306, 307, 309, 310, 311, 315, 318, 320, 323, 324,
        325, 326, 330, 331, 332, 333, 334, 335, 336, 337,
        338, 339, 343, 344, 810, 812, 813, 815, 817, 818,
        822, 823, 824, 826, 828, 830, 832, 833, 834, 836,
        838, 841, 843, 844, 845, 856, 858, 859, 860, 861,
        862, 863,
    ],
    'insect': [
        1739, 1740, 1741, 1746, 1748, 1749, 1750, 1751, 1752, 1755,
        1760, 1761, 1762, 1763, 1765, 1773, 1775, 1776, 1782, 1783,
        1422, 1424, 1425, 1429, 1431, 1435, 1436, 1439, 1457, 1458,
        1459, 1460, 1462, 1465, 1471, 1472,
    ],
    'texture': [
        3, 8, 14, 15, 79, 82, 88, 98, 270, 285,
        366, 379, 388, 406, 517, 518, 529, 537, 555, 566,
        578, 569, 580, 581, 584, 586, 600, 601, 607, 612,
        672, 680, 676, 673, 686, 729, 761, 765, 774, 787,
        788, 784, 789, 803, 801, 799, 809, 829, 842, 847,
        850, 866, 873, 871, 870, 869, 868, 878, 880, 886,
        894, 897, 1051,
    ],
}

# Same list-of-dicts shape as before: [{'stim_id': ..., 'category': ...}, ...]
manual_labels = [
    {'stim_id': stim_id, 'category': category}
    for category, category_stim_ids in MANUAL_STIM_IDS_BY_CATEGORY.items()
    for stim_id in category_stim_ids
]

# Each image may carry only one label; a repeated id would silently inflate the per-category counts
assert len({label['stim_id'] for label in manual_labels}) == len(manual_labels), 'duplicate stim_id in manual_labels'

# Count labelled images per category over UNIQUE stim ids, so an id listed twice is not counted twice.
# setdefault keeps the first entry for an id, matching the first-match lookup used when labelling images.
category_by_labelled_id = {}
for label in manual_labels:
    category_by_labelled_id.setdefault(label['stim_id'], label['category'])
labelled_categories = list(category_by_labelled_id.values())

num_labeled = len(category_by_labelled_id)
num_insect = labelled_categories.count('insect')
num_rodent = labelled_categories.count('rodent')
num_snake = labelled_categories.count('snake')
num_bird = labelled_categories.count('bird')
num_fungus = labelled_categories.count('fungus')
num_cat = labelled_categories.count('cat')
num_texture = labelled_categories.count('texture')

print(f"Number of labeled images: {num_labeled}")
print(f"Number of insect images: {num_insect}")
print(f"Number of rodent images: {num_rodent}")
print(f"Number of snake images: {num_snake}")
print(f"Number of bird images: {num_bird}")
print(f"Number of fungus images: {num_fungus}")
print(f"Number of cat images: {num_cat}")
print(f"Number of texture images: {num_texture}")

In [None]:
### Load and preprocess images

import os
from scipy.io import loadmat
import matplotlib.pyplot as plt
import numpy as np
from torchvision.transforms import Normalize, Compose, Resize, CenterCrop
import torch
from torch.utils.data import TensorDataset
from torchvision import utils as torch_utils
 
PATH_TO_DATA = '../../data/selection1866'  # image directory (replaces the neural-data path set in the load cell)

file_list = sorted(f for f in os.listdir(PATH_TO_DATA) if f.endswith('.mat'))
stim_ids = stimids.astype(int)

print(stim_ids)
print(stimids)

transform = Compose([
    Resize(96), # Scale so the shorter edge is 96 px (aspect ratio kept)
    CenterCrop((96, 96)), # Keep the central 96x96 region, trimming the longer edge
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # Maps [0, 1] -> [-1, 1]; input must already be in [0, 1]
])

# Summarise rather than dumping every file name
print(f'{len(file_list)} .mat files in {PATH_TO_DATA}, e.g. {file_list[:5]}')

# O(1) category lookup per image instead of scanning manual_labels for every stimulus.
# setdefault keeps the first entry for an id, as the previous first-match scan did.
category_by_stim_id = {}
for item in manual_labels:
    category_by_stim_id.setdefault(item['stim_id'], item['category'])


def load_stimulus_tensor(stim_id):
    """Load image `img<stim_id>.mat` from PATH_TO_DATA and preprocess it for SimCLR.

    The image is cut to its leftmost 500 columns and replicated to 3 channels. It is then
    min-max scaled to [0, 1] and passed through `transform` (resize, crop, normalize).

    Args:
        stim_id: Image id, as stored in `stimids`.

    Returns:
        torch.Tensor: Float tensor of shape (3, 96, 96).
    """
    data = loadmat(os.path.join(PATH_TO_DATA, 'img' + str(stim_id) + '.mat'))

    img = data['img'][:, :500] # Take leftmost part of the image (a 2-D grayscale array)
    rgb_img = np.stack([img] * 3, axis=-1) # Convert grayscale to RGB for SimCLR
    tensor = torch.tensor(rgb_img, dtype=torch.float32).permute(2, 0, 1) # Shape (C, H, W)

    # Min-max scale the tensor to [0, 1]. A constant image would divide by zero (all NaN), so map it to zeros.
    tensor_min, tensor_max = tensor.min(), tensor.max()
    if tensor_max > tensor_min:
        tensor = (tensor - tensor_min) / (tensor_max - tensor_min)
    else:
        tensor = torch.zeros_like(tensor)

    # Clamp to [0, 1] to ensure no outliers due to numerical precision
    tensor = torch.clamp(tensor, 0.0, 1.0)

    return transform(tensor) # Resize, crop and normalize for SimCLR


img_tensors, labels, category_labels = [], [], []

# We have 1866 images here, but the neural response data only uses 1573 of them,
# because ~300 images didn't have two repeats and were discarded.
# Iterating over stim_ids therefore keeps only the relevant 1573, in neural-data order.
for stim_id in stim_ids:
    img_tensors.append(load_stimulus_tensor(stim_id))
    labels.append(stim_id)
    # Manually assigned category if we have one, 'undetermined' otherwise
    category_labels.append(category_by_stim_id.get(stim_id, 'undetermined'))

print(f'Loaded {len(labels)} images; first stim ids: {labels[:30]}')

image_dataset = TensorDataset(torch.stack(img_tensors), torch.tensor(labels))

images, labels = image_dataset.tensors
print("Processed image labels (stim id):", labels[:30])
print("Stim IDs from neural data:", stim_ids[:30])
print("Processed dataset shape:", images.shape) # (N, C, 96, 96)
print(f"Min pixel value (processed): {torch.min(images)}")
print(f"Max pixel value (processed): {torch.max(images)}")

# Show a sample of processed images
img_grid = torch_utils.make_grid(images[:12], nrow=6, normalize=True, pad_value=0.9)
img_grid = img_grid.permute(1, 2, 0).numpy()
plt.figure(figsize=(10, 5))
plt.title('Processed images: sample')
plt.imshow(img_grid)
plt.axis('off')
plt.show()
plt.close()

# Inspect one raw (unprocessed) stimulus image, cut the same way as above
filename = 'img20.mat'
data = loadmat(os.path.join(PATH_TO_DATA, filename))
img = data['img'][:, :500]

plt.figure(figsize=(8, 6))
plt.imshow(img, cmap='gray')  # Adjust cmap as needed ('viridis', 'jet', etc.)
plt.colorbar(label="Pixel Intensity")
plt.title("Rendered Image")
plt.axis("off")  # Hide axis for better visualization
plt.show()

In [None]:
# ## Show images with stimulus IDs, and manually assign semantic labels
# import matplotlib.pyplot as plt
# import numpy as np

# BATCH_SIZE = 64
# ncols = 8
# manual_labels = {}

# for i in range(0, len(images), BATCH_SIZE):
#     imgs = images[i:i+BATCH_SIZE]
#     ids = labels[i:i+BATCH_SIZE]

#     n = len(imgs)
#     nrows = (n + ncols - 1) // ncols

#     fig, axs = plt.subplots(nrows, ncols, figsize=(2 * ncols, 2 * nrows))
#     axs = axs.flatten()

#     for j, (img, stim_id) in enumerate(zip(imgs, ids)):
#         img_np = img.permute(1, 2, 0).numpy()
#         img_np = (img_np + 1) / 2  # Normalize from [-1, 1] to [0, 1]
#         img_np = np.clip(img_np, 0, 1)

#         axs[j].imshow(img_np)
#         axs[j].set_title(f'{stim_id}', fontsize=9)
#         axs[j].axis('off')

#     # Turn off any extra axes
#     for k in range(j + 1, len(axs)):
#         axs[k].axis('off')

#     plt.tight_layout()
#     plt.show()

In [None]:
### Run images through a pretrained SimCLR model and extract features

import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
from tqdm.notebook import tqdm
from typing import Dict
from torch.utils.data import Dataset
import urllib.request
from urllib.error import HTTPError

class SimCLR(nn.Module):
    """ResNet-18 SimCLR encoder with helpers to load pretrained weights and extract layer features.

    The module layout (torchvision ResNet-18 backbone in `self.convnet` with an MLP projection
    head as `convnet.fc`) must match the checkpoint loaded by `load_pretrained`.
    """

    def __init__(self, hidden_dim=128):
        """
        Args:
            hidden_dim (int): Output size of the projection head; its hidden layer has 4 * hidden_dim units.
        """
        super().__init__()

        # Base ResNet18 backbone (pretrained=False, because we load custom weights later, from the SimCLR checkpoint file)
        # NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13 in favour of `weights=None`.
        self.convnet = torchvision.models.resnet18(pretrained=False)

        # This is the projection head, only needed during training. For downstream tasks it is disposed of
        # and the representation before it is used (Chen et al., 2020)
        self.convnet.fc = nn.Sequential(
            nn.Linear(self.convnet.fc.in_features, 4 * hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(4 * hidden_dim, hidden_dim)
        )

        self.intermediate_layers_to_capture = []
        self.intermediate_layer_features = {}
        self._hook_handles = []  # forward-hook handles, kept so re-registering can remove old hooks
        # os.cpu_count() may return None; DataLoader needs an int
        # NOTE(review): worker processes can hang in notebooks on spawn-based platforms (macOS/Windows); use 0 there if so.
        self.num_workers = os.cpu_count() or 0
        self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    def load_pretrained(self):
        """
        Load pretrained SimCLR weights, move the model to `self.device` and switch it to eval mode.

        The checkpoint is downloaded to ../../models/SimCLR.ckpt if it is not already there.

        Raises:
            HTTPError: If the checkpoint is missing locally and downloading it fails.
        """
        base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial17/"
        models_dir = "../../models"
        pretrained_simclr_filename = "SimCLR.ckpt"
        pretrained_simclr_path = os.path.join(models_dir, pretrained_simclr_filename)
        os.makedirs(models_dir, exist_ok=True)

        # Use the local copy if it exists; otherwise try downloading it
        file_url = base_url + pretrained_simclr_filename
        if os.path.isfile(pretrained_simclr_path):
            print(f"Already downloaded pretrained model: {file_url}")
        else:
            print(f"Downloading pretrained SimCLR model {file_url}...")
            try:
                urllib.request.urlretrieve(file_url, pretrained_simclr_path)
            except HTTPError as e:
                print("Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n", e)
                raise  # nothing to load; fail here instead of with a confusing FileNotFoundError below

        # torch.load unpickles the file: only use it on a trusted checkpoint.
        # NOTE(review): torch >= 2.6 defaults to weights_only=True; if loading this Lightning checkpoint fails,
        # pass weights_only=False (acceptable only because the source is trusted).
        checkpoint = torch.load(pretrained_simclr_path, map_location=self.device)
        self.load_state_dict(checkpoint['state_dict'])
        self.to(self.device)
        self.eval()

    def set_intermediate_layers_to_capture(self, layers):
        """
        Register forward hooks that store the (detached) output of each named convnet module.

        Calling this again replaces the previously registered hooks instead of stacking more on top.

        Args:
            layers (Iterable[str]): Module names from `self.convnet.named_modules()`, e.g. 'layer1'.

        Raises:
            KeyError: If a name is not a module of the convnet.
        """
        layers = list(layers)  # iterated more than once below

        # Just check the layers specified are actually top-level blocks of the convnet
        top_level_block_layers = {name for name, _ in self.convnet.named_children()}
        if not all(layer in top_level_block_layers for layer in layers):
            print('You have specified convnet layers that are not top-level blocks - make sure your layer names are valid')

        # Drop hooks from a previous call; otherwise each module would run its old hook too,
        # writing into a dictionary that is no longer read
        for handle in getattr(self, '_hook_handles', []):
            handle.remove()
        self._hook_handles = []

        self.intermediate_layers_to_capture = layers
        intermediate_layer_features = {}

        def get_hook(layer_name):
            def hook(module, input, output):
                intermediate_layer_features[layer_name] = output.detach()
            return hook

        modules_by_name = dict(self.convnet.named_modules())
        for layer_name in layers:
            handle = modules_by_name[layer_name].register_forward_hook(get_hook(layer_name))
            self._hook_handles.append(handle)

        self.intermediate_layer_features = intermediate_layer_features

    @torch.no_grad()
    def extract_features(self, dataset: Dataset) -> Dict[str, torch.Tensor]:
        """
        Run the SimCLR backbone on the image data, and capture features from the final layer and intermediate layers.

        The projection head g(.) is bypassed only for the duration of this call and restored afterwards.
        The module therefore still matches the SimCLR checkpoint (e.g. `load_pretrained` keeps working).

        Args:
            dataset (Dataset): A PyTorch Dataset yielding (image, label) pairs; images batch to shape (N, C, H, W).

        Returns:
            Dict[str, torch.Tensor]: A dictionary containing:
                - One entry per captured intermediate layer: the unflattened layer output, concatenated over batches.
                - Backbone features (the input to the projection head) under 'final_layer', shape (N, F).
                - Labels under 'labels'.
        """
        self.eval()
        self.to(self.device)  # before swapping, so the saved head is on the same device as the rest

        projection_head = self.convnet.fc
        self.convnet.fc = nn.Identity()  # Bypass projection head g(.)
        try:
            # Encode all images
            data_loader = DataLoader(dataset, batch_size=64, num_workers=self.num_workers, shuffle=False, drop_last=False)
            feats, labels = [], []
            intermediate_features = {layer: [] for layer in self.intermediate_layers_to_capture}

            for batch_imgs, batch_labels in tqdm(data_loader):
                batch_feats = self.convnet(batch_imgs.to(self.device))

                feats.append(batch_feats.detach().cpu())
                labels.append(batch_labels)

                # The forward pass above refreshed the hook outputs for this batch.
                # They are kept unflattened: flattening caused problems when visualising features later.
                for layer in self.intermediate_layers_to_capture:
                    intermediate_features[layer].append(self.intermediate_layer_features[layer].cpu())
        finally:
            self.convnet.fc = projection_head

        # Concatenate results for each layer
        feats = torch.cat(feats, dim=0)
        labels = torch.cat(labels, dim=0)
        intermediate_features = {layer: torch.cat(intermediate_features[layer], dim=0) for layer in self.intermediate_layers_to_capture}

        # Debugging log after concatenation
        print("✅ Feature extraction complete. Final feature shapes:")
        print(f"Final layer: {feats.shape}")
        for layer, feature in intermediate_features.items():
            print(f"{layer}: {feature.shape}")  # Check final stored shape

        return {**intermediate_features, 'final_layer': feats, 'labels': labels}

intermediate_layers = ['layer1', 'layer2', 'layer3', 'layer4']

sim_clr = SimCLR()
sim_clr.load_pretrained()
sim_clr.set_intermediate_layers_to_capture(intermediate_layers)
feats = sim_clr.extract_features(image_dataset)

# Overall activation variance per captured layer (quick sanity check that features are not degenerate)
for layer in intermediate_layers:
    if layer in feats:
        variance = np.var(feats[layer].numpy())
        print(f"{layer} variance: {variance:.6f}")

# Our original images are grayscale, but SimCLR expects 3-channel RGB input.
# To meet this requirement, we duplicated the grayscale values across all three RGB channels.
# However, for PCA, we only need a single channel, so we extract just the first channel (Red).
flattened_images = images[:, 0, :, :].view(images.shape[0], -1) # shape: (n_images, 96 * 96) after the 96x96 centre crop

# Intermediate outputs are stored unflattened, (n_images, C, H, W). For 96x96 inputs to a standard
# ResNet-18 expect e.g. (n_images, 64, 24, 24) for layer1; the prints below show the actual shapes.
layer1_feats = feats['layer1']
layer2_feats = feats['layer2']
layer3_feats = feats['layer3']
layer4_feats = feats['layer4']
final_layer_feats = feats['final_layer'] # (n_images, n_features), the backbone output before the projection head

print('flattened_images shape', flattened_images.shape)
print('layer1 shape', layer1_feats.shape)
print('final layer shape', final_layer_feats.shape)

In [None]:
### Layer 2: reduce features with PCA, then t-SNE, and score how well categories separate
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score

feats = layer2_feats.view(layer2_feats.size(0), -1)  # flatten (C, H, W) per image

# Keep only images that received a manual category label
feats_np = feats.cpu().numpy()
valid_mask = np.array([label != 'undetermined' for label in category_labels])
filtered_feats = feats_np[valid_mask]
filtered_labels = [label for label in category_labels if label != 'undetermined']

# Fixed random_state so the embedding (and the silhouette score below) is reproducible across runs
feats_pca = PCA(n_components=100, random_state=0).fit_transform(filtered_feats)
tsne_feats = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(feats_pca)

plt.figure(figsize=(8, 6))
sns.scatterplot(x=tsne_feats[:, 0], y=tsne_feats[:, 1], hue=filtered_labels, palette='tab10', s=30)
plt.title("t-SNE on SimCLR layer2 features (filtered)")
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

# NOTE: the silhouette is computed on the 2-D t-SNE embedding, which does not preserve global
# distances, so treat it as a rough relative measure between layers
score = silhouette_score(tsne_feats, filtered_labels)
print(f"Silhouette score: {score:.3f}")

In [None]:
### Final (fc) layer: reduce features with PCA, then t-SNE, and score how well categories separate
from sklearn.metrics import silhouette_score

feats = final_layer_feats

# Keep only images that received a manual category label
feats_np = feats.cpu().numpy()
valid_mask = np.array([label != 'undetermined' for label in category_labels])
filtered_feats = feats_np[valid_mask]
filtered_labels = [label for label in category_labels if label != 'undetermined']

# Fixed random_state so the embedding (and the silhouette score below) is reproducible across runs
feats_pca = PCA(n_components=100, random_state=0).fit_transform(filtered_feats)
tsne_feats = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(feats_pca)

plt.figure(figsize=(8, 6))
sns.scatterplot(x=tsne_feats[:, 0], y=tsne_feats[:, 1], hue=filtered_labels, palette='tab10', s=30)
plt.title("t-SNE on SimCLR fc features (filtered)")
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

# NOTE: silhouette on the 2-D t-SNE embedding is only a rough, relative measure
score = silhouette_score(tsne_feats, filtered_labels)
print(f"Silhouette score: {score:.3f}")

In [None]:
### Layer 4: reduce features with PCA, then t-SNE, and score how well categories separate
from sklearn.metrics import silhouette_score

# Flatten (C, H, W) per image, using layer4's own batch dimension
feats = layer4_feats.view(layer4_feats.size(0), -1)

# Keep only images that received a manual category label
feats_np = feats.cpu().numpy()
valid_mask = np.array([label != 'undetermined' for label in category_labels])
filtered_feats = feats_np[valid_mask]
filtered_labels = [label for label in category_labels if label != 'undetermined']

# Fixed random_state so the embedding (and the silhouette score below) is reproducible across runs
feats_pca = PCA(n_components=100, random_state=0).fit_transform(filtered_feats)
tsne_feats = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(feats_pca)

plt.figure(figsize=(8, 6))
sns.scatterplot(x=tsne_feats[:, 0], y=tsne_feats[:, 1], hue=filtered_labels, palette='tab10', s=30)
plt.title("t-SNE on SimCLR layer4 features (filtered)")
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

# NOTE: silhouette on the 2-D t-SNE embedding is only a rough, relative measure
score = silhouette_score(tsne_feats, filtered_labels)
print(f"Silhouette score: {score:.3f}")

In [None]:
### Layer 3: reduce features with PCA, then t-SNE, and score how well categories separate
from sklearn.metrics import silhouette_score

feats = layer3_feats.view(layer3_feats.size(0), -1)  # flatten (C, H, W) per image

# Keep only images that received a manual category label
feats_np = feats.cpu().numpy()
valid_mask = np.array([label != 'undetermined' for label in category_labels])
filtered_feats = feats_np[valid_mask]
filtered_labels = [label for label in category_labels if label != 'undetermined']

# Fixed random_state so the embedding (and the silhouette score below) is reproducible across runs
feats_pca = PCA(n_components=100, random_state=0).fit_transform(filtered_feats)
tsne_feats = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(feats_pca)

plt.figure(figsize=(8, 6))
sns.scatterplot(x=tsne_feats[:, 0], y=tsne_feats[:, 1], hue=filtered_labels, palette='tab10', s=30)
plt.title("t-SNE on SimCLR layer3 features (filtered)")
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

# NOTE: silhouette on the 2-D t-SNE embedding is only a rough, relative measure
score = silhouette_score(tsne_feats, filtered_labels)
print(f"Silhouette score: {score:.3f}")

In [None]:
### Layer 1: reduce features with PCA, then t-SNE, and score how well categories separate
from sklearn.metrics import silhouette_score

feats = layer1_feats.view(layer1_feats.size(0), -1)  # flatten (C, H, W) per image

# Keep only images that received a manual category label
feats_np = feats.cpu().numpy()
valid_mask = np.array([label != 'undetermined' for label in category_labels])
filtered_feats = feats_np[valid_mask]
filtered_labels = [label for label in category_labels if label != 'undetermined']

# Fixed random_state so the embedding (and the silhouette score below) is reproducible across runs
feats_pca = PCA(n_components=100, random_state=0).fit_transform(filtered_feats)
tsne_feats = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(feats_pca)

plt.figure(figsize=(8, 6))
sns.scatterplot(x=tsne_feats[:, 0], y=tsne_feats[:, 1], hue=filtered_labels, palette='tab10', s=30)
plt.title("t-SNE on SimCLR layer1 features (filtered)")
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

# NOTE: silhouette on the 2-D t-SNE embedding is only a rough, relative measure
score = silhouette_score(tsne_feats, filtered_labels)
print(f"Silhouette score: {score:.3f}")

In [None]:
### Silhouette score on a t-SNE of raw (no PCA) layer2 features
from sklearn.metrics import silhouette_score
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

raw_feats = layer2_feats.view(layer2_feats.size(0), -1)

# Keep only samples whose category label is known
feats_np = raw_feats.cpu().numpy()
valid_mask = np.array([label != 'undetermined' for label in category_labels])
assert valid_mask.shape[0] == feats_np.shape[0], "category_labels must have one entry per feature row"
filtered_feats = feats_np[valid_mask]
filtered_labels = [label for label in category_labels if label != 'undetermined']

# Fixed seed so the embedding and the silhouette score are reproducible.
tsne_feats = TSNE(n_components=2, perplexity=30, random_state=42).fit_transform(filtered_feats)
plt.figure(figsize=(8, 6))
sns.scatterplot(x=tsne_feats[:, 0], y=tsne_feats[:, 1], hue=filtered_labels, palette='tab10', s=30)
# The features embedded here are layer2 activations, so the title names layer2.
plt.title("t-SNE on SimCLR layer2 features (raw, no PCA, filtered)")
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

# Silhouette of the 2-D t-SNE projection, using the category labels as clusters.
score = silhouette_score(tsne_feats, filtered_labels)
print(f"Silhouette score: {score:.3f}")

In [None]:
### t-SNE silhouette scores for every layer (raw flattened features, no PCA)
from sklearn.metrics import silhouette_score
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Conv-stage activations are flattened to one row per sample; fc is used as-is.
layer_feats = {
    'layer1': layer1_feats.view(layer1_feats.size(0), -1),
    'layer2': layer2_feats.view(layer2_feats.size(0), -1),
    'layer3': layer3_feats.view(layer3_feats.size(0), -1),
    'layer4': layer4_feats.view(layer4_feats.size(0), -1),
    'fc': final_layer_feats,
}

# The label filter is identical for every layer, so build it once.
valid_mask = np.array([label != 'undetermined' for label in category_labels])
filtered_labels = [label for label in category_labels if label != 'undetermined']

for layer_name, feats in layer_feats.items():
    feats_np = feats.cpu().numpy()
    assert valid_mask.shape[0] == feats_np.shape[0], f"{layer_name}: label/feature row mismatch"
    filtered_feats = feats_np[valid_mask]

    # t-SNE directly on the flattened features; fixed seed for reproducible scores.
    tsne_feats = TSNE(n_components=2, perplexity=30, random_state=42).fit_transform(filtered_feats)
    score = silhouette_score(tsne_feats, filtered_labels)
    print(f"Silhouette score for {layer_name}: {score:.3f}")

    plt.figure(figsize=(8, 6))
    sns.scatterplot(x=tsne_feats[:, 0], y=tsne_feats[:, 1], hue=filtered_labels, palette='tab10', s=30)
    plt.title(f"t-SNE on SimCLR {layer_name} features (filtered)")
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()

# Recorded results (from an earlier run without a t-SNE seed; exact values may differ on re-run):
# Silhouette score for layer1: -0.077
# Silhouette score for layer2: -0.046
# Silhouette score for layer3: -0.023
# Silhouette score for layer4: -0.037
# Silhouette score for fc: 0.012

In [None]:
### Try K-means and purity
from sklearn.preprocessing import LabelEncoder
from sklearn.cluster import KMeans
import numpy as np
from collections import Counter

def compute_purity(y_true, y_pred):
    """Cluster purity: fraction of samples that belong to their cluster's majority class.

    For every predicted cluster, count the members that carry the cluster's most
    frequent true label; purity is the sum of those counts over the sample count.

    Args:
        y_true: true class labels, one per sample (list or array).
        y_pred: predicted cluster ids, same length as ``y_true`` (list or array).

    Returns:
        float in (0, 1]; 1.0 means every cluster contains a single class.

    Raises:
        ValueError: if the inputs are empty or have different lengths.
    """
    # asarray so plain lists work too: boolean-mask indexing below needs arrays.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if len(y_true) != len(y_pred):
        raise ValueError(f"y_true and y_pred differ in length: {len(y_true)} vs {len(y_pred)}")
    if len(y_true) == 0:
        raise ValueError("compute_purity needs at least one sample")

    total = 0
    for cluster in np.unique(y_pred):
        members = y_true[y_pred == cluster]
        # Count of the most frequent true label inside this cluster.
        total += Counter(members.tolist()).most_common(1)[0][1]
    return total / len(y_true)

# Flatten each conv stage to 2-D (leading sample dimension kept); fc is already flat.
layer_feats = {
    name: feats_t.view(feats_t.size(0), -1)
    for name, feats_t in (
        ('layer1', layer1_feats),
        ('layer2', layer2_feats),
        ('layer3', layer3_feats),
        ('layer4', layer4_feats),
    )
}
layer_feats['fc'] = final_layer_feats

def compute_random_purity_baseline(y_true, n_clusters=None, n_trials=20, seed=None):
    """Chance-level purity: mean purity of uniformly random cluster assignments.

    Args:
        y_true: true class labels, one per sample.
        n_clusters: number of random cluster ids to draw from; defaults to the
            number of distinct labels in ``y_true``.
        n_trials: how many random assignments to average over.
        seed: seed for ``np.random.default_rng``; fix it for a reproducible baseline.

    Returns:
        Mean purity over the ``n_trials`` random assignments.
    """
    labels_arr = np.array(y_true)
    k = len(np.unique(labels_arr)) if n_clusters is None else n_clusters
    n_samples = len(labels_arr)

    rng = np.random.default_rng(seed)
    trial_purities = [
        compute_purity(labels_arr, rng.integers(0, k, size=n_samples))
        for _ in range(n_trials)
    ]
    return np.mean(trial_purities)

filtered_labels = [label for label in category_labels if label != 'undetermined']
le = LabelEncoder()
encoded_labels = le.fit_transform(filtered_labels)  # e.g., 'cat' → 0, 'bird' → 1, etc.
n_clusters = len(np.unique(encoded_labels))

# The baseline must use the same k as the KMeans runs below: purity tends to grow
# with the number of clusters, so a baseline with a different k is not comparable.
random_baseline = compute_random_purity_baseline(encoded_labels, n_clusters=n_clusters, n_trials=50, seed=42)
print(f"Random cluster purity baseline: {random_baseline:.3f}")

# The label filter is identical for every layer, so build it once.
valid_mask = np.array([label != 'undetermined' for label in category_labels])

# Loop variable is `feats_tensor` so the `layer_feats` dict is not overwritten.
for layer_name, feats_tensor in layer_feats.items():
    feats_np = feats_tensor.cpu().numpy()
    assert valid_mask.shape[0] == feats_np.shape[0], f"{layer_name}: label/feature row mismatch"
    filtered_feats = feats_np[valid_mask]

    # Use only labeled samples
    kmeans = KMeans(n_clusters=n_clusters, n_init='auto', random_state=0)
    cluster_assignments = kmeans.fit_predict(filtered_feats)

    purity = compute_purity(encoded_labels, cluster_assignments)
    print(f"{layer_name} Cluster Purity: {purity:.3f}")

# Recorded results. NOTE: the baselines below were computed with a fixed k=6
# rather than the label count; re-run to refresh them.

# without texture category
# Random cluster purity baseline: 0.272
# layer1 Cluster Purity: 0.333
# layer2 Cluster Purity: 0.357
# layer3 Cluster Purity: 0.409
# layer4 Cluster Purity: 0.456
# fc Cluster Purity: 0.468

# with texture category
# Random cluster purity baseline: 0.241
# layer1 Cluster Purity: 0.327
# layer2 Cluster Purity: 0.295
# layer3 Cluster Purity: 0.400
# layer4 Cluster Purity: 0.438
# fc Cluster Purity: 0.438

In [None]:
### Compute NMI (Normalized Mutual Information)
from sklearn.cluster import KMeans
import numpy as np
from sklearn.metrics import normalized_mutual_info_score

def compute_random_nmi_baseline(y_true, n_clusters=None, n_trials=20, seed=None):
    """Chance-level NMI: mean NMI between ``y_true`` and uniformly random cluster ids.

    Args:
        y_true: true class labels, one per sample.
        n_clusters: number of random cluster ids to draw from; defaults to the
            number of distinct labels in ``y_true``.
        n_trials: how many random assignments to average over.
        seed: seed for ``np.random.default_rng``; fix it for a reproducible baseline.

    Returns:
        Mean normalized mutual information over the ``n_trials`` draws.
    """
    truth = np.array(y_true)
    k = n_clusters if n_clusters is not None else len(np.unique(truth))
    n_samples = len(truth)

    rng = np.random.default_rng(seed)
    scores = [
        normalized_mutual_info_score(truth, rng.integers(0, k, size=n_samples))
        for _ in range(n_trials)
    ]
    return np.mean(scores)

# Per-layer feature matrices: conv stages flattened past dim 0, fc used unchanged.
layer_feats = dict(
    layer1=layer1_feats.view(layer1_feats.size(0), -1),
    layer2=layer2_feats.view(layer2_feats.size(0), -1),
    layer3=layer3_feats.view(layer3_feats.size(0), -1),
    layer4=layer4_feats.view(layer4_feats.size(0), -1),
    fc=final_layer_feats,
)

# Run KMeans clustering on each layer and score it against the known labels with NMI.
# Imported here too so this cell does not depend on the purity cell's imports.
from sklearn.preprocessing import LabelEncoder

filtered_labels = [label for label in category_labels if label != 'undetermined']
le = LabelEncoder()
encoded_labels = le.fit_transform(filtered_labels)  # e.g., 'cat' → 0, 'bird' → 1, etc.
n_clusters = len(np.unique(encoded_labels))

# The baseline must use the same k as KMeans below: NMI of random partitions
# depends on the number of clusters, so a different k is not comparable.
random_nmi = compute_random_nmi_baseline(encoded_labels, n_clusters=n_clusters, n_trials=50, seed=42)
print(f"Random NMI baseline: {random_nmi:.3f}")

# The label filter is identical for every layer, so build it once.
valid_mask = np.array([label != 'undetermined' for label in category_labels])

# Loop variable is `feats_tensor` so the `layer_feats` dict is not overwritten.
for layer_name, feats_tensor in layer_feats.items():
    feats_np = feats_tensor.cpu().numpy()
    assert valid_mask.shape[0] == feats_np.shape[0], f"{layer_name}: label/feature row mismatch"
    filtered_feats = feats_np[valid_mask]

    kmeans = KMeans(n_clusters=n_clusters, n_init='auto', random_state=0)
    cluster_assignments = kmeans.fit_predict(filtered_feats)

    # Compute NMI
    nmi = normalized_mutual_info_score(encoded_labels, cluster_assignments)
    print(f"NMI {layer_name}: {nmi:.3f}")

# Recorded results. NOTE: the baselines below were computed with a fixed k=6
# rather than the label count; re-run to refresh them.

# without texture category (layers in order layer1, layer2, layer3, layer4, fc)
# Random NMI baseline: 0.031
# NMI: 0.104  (layer1)
# NMI: 0.100  (layer2)
# NMI: 0.193  (layer3)
# NMI: 0.260  (layer4)
# NMI: 0.281  (fc)

# with texture category
# Random NMI baseline: 0.030
# NMI layer1: 0.103
# NMI layer2: 0.089
# NMI layer3: 0.176
# NMI layer4: 0.233
# NMI fc: 0.244

In [None]:
## UMAP

from umap import UMAP
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import LabelEncoder
import numpy as np
from sklearn.cluster import KMeans

def compute_cluster_purity(pred_labels, true_labels):
    """Cluster purity: fraction of samples that belong to their cluster's majority class.

    Args:
        pred_labels: predicted cluster id per sample.
        true_labels: true class label per sample, same length as ``pred_labels``.

    Returns:
        float; sum over clusters of the majority-label count, divided by the
        number of samples.

    Raises:
        ValueError: if the two inputs have different lengths (``zip`` would
            otherwise silently truncate and give a wrong score).
    """
    # Local import: this cell does not import Counter itself and would otherwise
    # rely on an import made in an earlier cell.
    from collections import Counter

    if len(pred_labels) != len(true_labels):
        raise ValueError(
            f"pred_labels and true_labels differ in length: {len(pred_labels)} vs {len(true_labels)}"
        )

    # Group the true labels by predicted cluster.
    cluster_to_labels = {}
    for pred, true in zip(pred_labels, true_labels):
        cluster_to_labels.setdefault(pred, []).append(true)

    # Each cluster contributes the size of its most frequent true label.
    correct = sum(Counter(members).most_common(1)[0][1] for members in cluster_to_labels.values())
    return correct / len(true_labels)

# Build the per-layer feature table one entry at a time (same keys, same order).
layer_feats = {}
layer_feats['layer1'] = layer1_feats.view(layer1_feats.size(0), -1)
layer_feats['layer2'] = layer2_feats.view(layer2_feats.size(0), -1)
layer_feats['layer3'] = layer3_feats.view(layer3_feats.size(0), -1)
layer_feats['layer4'] = layer4_feats.view(layer4_feats.size(0), -1)
layer_feats['fc'] = final_layer_feats  # already flat

# Imported here so this cell does not depend on imports made in earlier cells.
from sklearn.metrics import normalized_mutual_info_score, silhouette_score

# Filter out undetermined labels and get mask
valid_mask = np.array([label != 'undetermined' for label in category_labels])
filtered_labels = [label for label in category_labels if label != 'undetermined']
le = LabelEncoder()
encoded_labels = le.fit_transform(filtered_labels)
n_clusters = len(np.unique(encoded_labels))

# UMAP + plot for each layer
for layer_name, feats_tensor in layer_feats.items():
    feats_np = feats_tensor.cpu().numpy()
    assert valid_mask.shape[0] == feats_np.shape[0], f"{layer_name}: label/feature row mismatch"
    filtered_feats = feats_np[valid_mask]  # same filtering for all layers

    reducer = UMAP(n_components=2, random_state=42)
    umap_feats = reducer.fit_transform(filtered_feats)

    plt.figure(figsize=(8, 6))
    sns.scatterplot(x=umap_feats[:, 0], y=umap_feats[:, 1], hue=filtered_labels, palette='tab10', s=30)
    plt.title(f"UMAP on SimCLR {layer_name} features (filtered)")
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()

    # Cluster metrics, all computed on the 2-D UMAP embedding (not the raw features)
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
    pred_labels = kmeans.fit_predict(umap_feats)

    purity = compute_cluster_purity(pred_labels, encoded_labels)
    nmi = normalized_mutual_info_score(encoded_labels, pred_labels)
    # Silhouette uses the true categories as clusters on the UMAP embedding.
    sil = silhouette_score(umap_feats, encoded_labels)

    print(f"Layer {layer_name}:")
    print(f"  ➤ Cluster Purity:     {purity:.3f}")
    print(f"  ➤ Normalized MI:      {nmi:.3f}")
    print(f"  ➤ Silhouette Score:   {sil:.3f}")
    print("-" * 40)

In [None]:
### Nearest neighbors - layer 2 vs. fc, for a few ambiguous images
from sklearn.metrics.pairwise import cosine_distances
import matplotlib.pyplot as plt
import numpy as np

N_NEIGHBORS = 5
stim_ids_ambiguous_imgs = [1719, 1470, 1375, 1237, 1153, 928]

# Flatten once, outside the loop, into new names so the global `layer2_feats`
# is not overwritten as a side effect of running this cell.
layer2_feats_np = layer2_feats.view(layer2_feats.size(0), -1).cpu().numpy()
fc_feats_np = final_layer_feats.cpu().numpy()


def show_nearest_neighbors(feats_np, query_idx, query_title, n_neighbors=N_NEIGHBORS):
    """Plot image ``query_idx`` followed by its ``n_neighbors`` nearest images.

    Distance is cosine distance between rows of ``feats_np``; images are read
    from the global ``images`` sequence at the same row indices.
    """
    dists = cosine_distances(feats_np[query_idx:query_idx + 1], feats_np)[0]
    # Exclude the query explicitly instead of assuming it sorts first: an exact
    # duplicate (distance 0) could otherwise take position 0.
    nearest_idxs = [i for i in np.argsort(dists) if i != query_idx][:n_neighbors]

    n_panels = 1 + len(nearest_idxs)
    plt.figure(figsize=(2 * n_panels, 3))
    for i, idx in enumerate([query_idx] + nearest_idxs):
        img_np = images[idx].permute(1, 2, 0).numpy()
        img_np = (img_np + 1) / 2  # assumes images are normalized to [-1, 1] — TODO confirm
        plt.subplot(1, n_panels, i + 1)
        plt.imshow(np.clip(img_np, 0, 1))
        plt.title(query_title if i == 0 else f"NN {i}")
        plt.axis("off")
    plt.tight_layout()
    plt.show()


for stim_id in stim_ids_ambiguous_imgs:
    # Find the index of the image with stim_id
    query_idx = (labels == stim_id).nonzero(as_tuple=True)[0].item()
    print(f"Stim ID {stim_id} found at index {query_idx}")

    show_nearest_neighbors(layer2_feats_np, query_idx, "Query Layer2")
    show_nearest_neighbors(fc_feats_np, query_idx, "Query LayerFC")