In [1]:
from cortexlib.utils import file as futils
from cortexlib.mouse import CortexlabMouse
from cortexlib.images import CortexlabImages
from cortexlib.simclr import PreTrainedSimCLRModel
from cortexlib.utils.random import set_global_seed
from cortexlib.utils.logging import Logger
from sklearn.linear_model import RidgeCV
from sklearn.decomposition import PCA
import numpy as np
import torch

# Shared logger used for progress/info/success messages throughout the notebook.
logger = Logger()
# Seed global RNG state with cortexlib's default seed before any stochastic step
# (null shuffles, PCA, random image initialisation below).
# NOTE(review): confirm which libraries set_global_seed() seeds (numpy / torch / random).
set_global_seed()

In [2]:
# Identify which mouse this notebook analyses; every downstream path and file name uses it.
# NOTE(review): how futils.get_mouse_id() resolves the ID is not visible here — confirm.
MOUSE_ID = futils.get_mouse_id()
logger.info(f"This notebook is running for mouse {MOUSE_ID}")

[1;37m22:21:26 | INFO     | ℹ️ This notebook is running for mouse m03_d4[0m


In [3]:
mouse = CortexlabMouse(mouse_id=MOUSE_ID)

# Reliability-filtering parameters (previously inline magic numbers).
N_NULL_SHUFFLES = 100        # n_shuffles passed to compute_null_all_neurons
RELIABILITY_PERCENTILE = 99  # percentile_threshold passed to get_reliable_neuron_indices
                             # (presumably: real score must exceed this percentile of the null)
N_RELIABLE_NEURONS = 500     # num_neurons requested from get_responses_for_reliable_neurons

logger.progress(f"Computing null distributions for all neurons in mouse {mouse.id}")
null_srv_all_neurons = mouse.compute_null_all_neurons(n_shuffles=N_NULL_SHUFFLES)
logger.success("Null distributions computed")

real_srv_all_neurons = mouse.compute_real_srv_all_neurons()
reliable_neuron_indices = mouse.get_reliable_neuron_indices(
    null_srv_all_neurons, real_srv_all_neurons, percentile_threshold=RELIABILITY_PERCENTILE)
# The third return value is not needed in this notebook.
neural_responses_mean, neural_responses, _ = mouse.get_responses_for_reliable_neurons(
    reliable_neuron_indices, real_srv_all_neurons, num_neurons=N_RELIABLE_NEURONS)

logger.info(f"Neural responses shape: {neural_responses.shape}")

[1;37m22:21:26 | INFO     | ⏳ Computing null distributions for all neurons in mouse m03_d4...[0m
[1;32m22:21:51 | SUCCESS  | ✅ Null distributions computed![0m
[1;37m22:21:52 | INFO     | ℹ️ Neural responses shape: (1573, 2, 500)[0m


In [4]:
# Project the averaged population responses (one row per image) onto principal
# components and keep only the first component's score for each image.
# NOTE: the direction of a principal component is arbitrary, so "high PC1" is only
# meaningful relative to this particular fit.
pca = PCA(n_components=100)
neural_data_pcs = pca.fit_transform(neural_responses_mean)
pc1_neural_data = neural_data_pcs.T[0]

In [5]:
# Persist the reliable-neuron responses for this mouse; the helper skips the write
# when the target file already exists (see the logged output below).
futils.save_filtered_neural_data(neural_responses_mean=neural_responses_mean,
                                 neural_responses=neural_responses,
                                 mouse_id=mouse.id)

[1;37m22:21:52 | INFO     | ℹ️ Skipping save, file already exists at /Users/callummessiter/workspace/msc-neuro/research-project/analysis/mouse_m03_d4/_neural_data/neural_data_mouse_m03_d4.pt[0m


In [6]:
simclr = PreTrainedSimCLRModel()

# Preprocess the stimuli with the image settings reported by the pretrained SimCLR model.
settings = simclr.get_image_settings()
# CortexlabImages keyword -> key in the SimCLR settings dict.
image_setting_keys = {
    'size': 'size',
    'channels': 'channels',
    'normalise_mean': 'mean',
    'normalise_std': 'std',
    'rescale_per_image': 'rescale_per_image',
}
images = CortexlabImages(**{kwarg: settings[key] for kwarg, key in image_setting_keys.items()})

logger.progress("Loading and preprocessing images shown to mouse")
image_dataset = images.load_images_shown_to_mouse(mouse.image_ids)
logger.success("Images processed")

[1;37m22:21:52 | INFO     | ℹ️ Already downloaded pretrained SimCLR model[0m
[1;37m22:21:52 | INFO     | ⏳ Loading and preprocessing images shown to mouse...[0m
[1;32m22:22:09 | SUCCESS  | ✅ Images processed![0m


In [7]:
logger.progress("SimCLR: extracting features from images shown to mouse")
simclr_feats, labels = simclr.extract_features(image_dataset)
logger.success("SimCLR features extracted")

# Log the feature tensor shape produced for every extracted layer.
for layer in simclr_feats:
    feats = simclr_feats[layer]
    feats_shape = tuple(feats.shape)
    logger.info(f"{layer} feats shape: {feats_shape}")

[1;37m22:22:09 | INFO     | ⏳ SimCLR: extracting features from images shown to mouse...[0m


  0%|          | 0/25 [00:00<?, ?it/s]

[1;32m22:22:25 | SUCCESS  | ✅ SimCLR features extracted![0m
[1;37m22:22:25 | INFO     | ℹ️ layer1 feats shape: (1573, 64, 24, 24)[0m
[1;37m22:22:25 | INFO     | ℹ️ layer2 feats shape: (1573, 128, 12, 12)[0m
[1;37m22:22:25 | INFO     | ℹ️ layer3 feats shape: (1573, 256, 6, 6)[0m
[1;37m22:22:25 | INFO     | ℹ️ layer4 feats shape: (1573, 512, 3, 3)[0m
[1;37m22:22:25 | INFO     | ℹ️ fc feats shape: (1573, 512)[0m


In [8]:
# Cache the per-layer SimCLR features (and labels) for this mouse; the helper
# skips the write when the file already exists (see the logged output below).
futils.save_model_features(
    model=futils.Model.SIMCLR,
    features=simclr_feats,
    labels=labels,
    mouse_id=mouse.id,
)

[1;37m22:22:25 | INFO     | ℹ️ Skipping save, file already exists at /Users/callummessiter/workspace/msc-neuro/research-project/analysis/mouse_m03_d4/_model_features/simclr_features_mouse_m03_d4.pt[0m


In [9]:
from PIL import Image
from cortexlib.images import CortexlabRawImages
import os

img_loader = CortexlabRawImages(channels=1)
output_dir = "top_pc1_originals"
os.makedirs(output_dir, exist_ok=True)

top_k = 10
# argsort is ascending, so reverse it: rank_0 is now the image with the HIGHEST PC1
# score. (Slicing [-top_k:] saved the best image as rank_9 and the 10th-best as rank_0.)
indices_top_images = np.argsort(pc1_neural_data)[::-1][:top_k]

for i, idx in enumerate(indices_top_images):
    # Only the image ID is needed; the preprocessed tensor is discarded.
    _, img_id = image_dataset[idx]
    # Re-load the stimulus by ID through the raw-image loader and keep the first channel.
    img_np = img_loader.load_mat_image(img_id)[:, :, 0]

    # Min-max scale to [0, 255]; a constant image would otherwise divide by zero.
    lo, hi = img_np.min(), img_np.max()
    if hi > lo:
        img_norm = (img_np - lo) / (hi - lo)
    else:
        img_norm = np.zeros(img_np.shape, dtype=float)
    img_uint8 = (img_norm * 255).astype(np.uint8)

    filepath = os.path.join(output_dir, f"rank_{i}_img_{img_id}.png")
    Image.fromarray(img_uint8).save(filepath)

In [10]:
def regressor(X, Y, alphas=None):
    """Fit a ridge regression from ``X`` to ``Y``, choosing the penalty by cross-validation.

    Uses scikit-learn's ``RidgeCV`` with its default efficient leave-one-out CV and
    ``store_cv_results=True``, so the fitted estimator exposes ``cv_results_``.

    Args:
        X: Design matrix of shape (n_samples, n_features).
        Y: Regression target(s), one row per sample.
        alphas: Candidate regularisation strengths. Defaults to
            ``np.logspace(1, 7, 20)`` (20 values from 1e1 to 1e7), as before.

    Returns:
        The fitted ``RidgeCV`` estimator (``coef_``, ``intercept_``, ``alpha_``, ...).
    """
    if alphas is None:
        alphas = np.logspace(1, 7, 20)
    ridge = RidgeCV(alphas=alphas, store_cv_results=True)
    ridge.fit(X, Y)
    # The in-sample prediction previously computed here was never used; dropped.
    return ridge

def l2_penalty(img, lam=0.0001):
    """Return the scaled squared L2 norm of ``img``: ``lam * sum(img ** 2)``.

    Args:
        img: Tensor to penalise (the synthetic image being optimised).
        lam: Penalty weight.

    Returns:
        A scalar tensor, differentiable with respect to ``img``.
    """
    return lam * img.pow(2).sum()

def generate_synthetic_img(layer_name, ridge, iterations=200, l2_lam=None,
                           img_size=(96, 96), model=None, lr=0.05):
    """Synthesise an image that maximises a linear readout of one network layer.

    Starting from Gaussian noise, a single-channel image is optimised with Adam so
    that ``dot(ridge.coef_, flatten(layer_output))`` is as large as possible. The
    image is replicated to three channels for every forward pass and clamped to
    [-1, 1] after every step. The ridge intercept is ignored because a constant
    offset does not change the maximiser.

    Args:
        layer_name: Name of the module (as listed by ``convnet.named_modules()``)
            whose forward output is read out via a forward hook.
        ridge: Fitted single-target regressor exposing ``coef_``; its size must equal
            the flattened size of the layer output for one image.
        iterations: Number of optimisation steps.
        l2_lam: Optional weight for ``l2_penalty`` on the image; ``None`` disables it.
        img_size: ``(height, width)`` of the synthesised image (default 96x96, as before).
        model: Optional pre-built model exposing ``convnet``; when omitted a new
            ``PreTrainedSimCLRModel`` is constructed, as before.
        lr: Adam learning rate.

    Returns:
        ``np.ndarray`` of shape ``img_size``, min-max scaled to [0, 1]
        (all zeros if the optimised image is constant).

    Raises:
        ValueError: if ``ridge.coef_`` does not match the layer's flattened feature size.
    """
    if model is None:
        model = PreTrainedSimCLRModel(intermediate_layers=[layer_name])
    convnet = model.convnet

    # Put the image on whatever device the network's weights live on. Choosing CUDA for
    # the image alone (without moving the network) fails with a device mismatch on GPU hosts.
    first_param = next(convnet.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    captured = {}

    def hook_fn(module, inputs, output):
        # Keep the (graph-connected) output so the score can be differentiated.
        captured[layer_name] = output

    layer = dict(convnet.named_modules())[layer_name]
    hook_handle = layer.register_forward_hook(hook_fn)

    # Evaluate in inference mode (BatchNorm running stats, no dropout) and restore the
    # caller's mode afterwards. NOTE(review): assumed to match how the features used to
    # fit `ridge` were extracted — confirm.
    was_training = convnet.training
    convnet.eval()
    try:
        weights = torch.as_tensor(np.asarray(ridge.coef_), dtype=torch.float32,
                                  device=device).reshape(-1)

        height, width = img_size
        synthetic_image = torch.randn(1, 1, height, width, device=device, requires_grad=True)
        optimizer = torch.optim.Adam([synthetic_image], lr=lr, weight_decay=1e-6)

        for _ in range(iterations):
            convnet(synthetic_image.repeat(1, 3, 1, 1))

            feats = captured[layer_name].reshape(-1)
            if feats.numel() != weights.numel():
                raise ValueError(
                    f"ridge.coef_ has {weights.numel()} weights but layer "
                    f"'{layer_name}' produces {feats.numel()} features")

            loss = -torch.dot(feats, weights)
            if l2_lam is not None:
                loss = loss + l2_penalty(synthetic_image, l2_lam)

            # Differentiate w.r.t. the image only, so no gradients pile up on the
            # network's parameters across iterations.
            (image_grad,) = torch.autograd.grad(loss, synthetic_image)
            synthetic_image.grad = image_grad
            optimizer.step()

            with torch.no_grad():
                synthetic_image.clamp_(-1, 1)
    finally:
        # Always detach the hook and restore the mode, even if optimisation fails.
        hook_handle.remove()
        convnet.train(was_training)

    # Index explicitly instead of squeeze() so 1-pixel dimensions survive.
    img_np = synthetic_image.detach()[0, 0].cpu().numpy()
    lo, hi = img_np.min(), img_np.max()
    if hi > lo:
        return (img_np - lo) / (hi - lo)
    return np.zeros_like(img_np)

In [11]:
import os
from PIL import Image, ImageOps

output_dir = "synthetic_pc1_images"
os.makedirs(output_dir, exist_ok=True)

img_size = simclr.get_image_settings()['size']

SYNTH_L2_LAMBDA = 0.001           # l2_lam passed to generate_synthetic_img
UPSCALE = 2                       # saved PNGs are upsampled by this factor
HIGHLIGHT_LAYER = 'layer3'        # this layer's image is also saved with a coloured border
HIGHLIGHT_BORDER_PX = 5
HIGHLIGHT_BORDER_COLOUR = "#41e6ff"

for layer_name, feats in simclr_feats.items():
    # The ridge needs (n_images, n_features): flatten conv maps; 'fc' features are already 2-D.
    # reshape (unlike view) also works on non-contiguous tensors.
    feats = feats if layer_name == 'fc' else feats.reshape(feats.size(0), -1)
    # 'fc' features are captured at the 'avgpool' module when synthesising.
    idx_layer_name = 'avgpool' if layer_name == 'fc' else layer_name

    ridge = regressor(feats, pc1_neural_data)
    synthetic_img = generate_synthetic_img(idx_layer_name, ridge, l2_lam=SYNTH_L2_LAMBDA)

    # Guard against NaN (e.g. a constant image normalised by a zero range) before the uint8 cast.
    img_uint8 = (np.nan_to_num(synthetic_img) * 255).astype(np.uint8)
    # A 2-D uint8 array is inferred as mode 'L'; the explicit `mode=` argument is
    # deprecated in recent Pillow releases.
    img_pil = Image.fromarray(img_uint8)
    img_resized = img_pil.resize((img_size[1] * UPSCALE, img_size[0] * UPSCALE),
                                 resample=Image.BICUBIC)

    img_resized.save(os.path.join(output_dir, f"synthetic_img_{layer_name}.png"))

    if layer_name == HIGHLIGHT_LAYER:
        bordered_img = ImageOps.expand(img_resized.convert("RGB"),
                                       border=HIGHLIGHT_BORDER_PX,
                                       fill=HIGHLIGHT_BORDER_COLOUR)
        bordered_img.save(os.path.join(output_dir, f"synthetic_img_{layer_name}_bordered.png"))

[1;37m22:22:28 | INFO     | ℹ️ Already downloaded pretrained SimCLR model[0m
[1;37m22:22:32 | INFO     | ℹ️ Already downloaded pretrained SimCLR model[0m
[1;37m22:22:37 | INFO     | ℹ️ Already downloaded pretrained SimCLR model[0m
[1;37m22:22:42 | INFO     | ℹ️ Already downloaded pretrained SimCLR model[0m
[1;37m22:22:48 | INFO     | ℹ️ Already downloaded pretrained SimCLR model[0m
