In [61]:
import torch
import pickle
import pandas as pd
import numpy as np
import random 
import wandb

import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms
from tqdm import trange
from PIL import Image
from torch.nn import PoissonNLLLoss, MSELoss

from tqdm import tqdm, trange
from os.path import join
from torch.utils.data import DataLoader

from fix_models.feature_extractors import resnet3d18_reg, get_video_feature_extractor, resnet50_reg, get_image_feature_extractor, ImageFeatureExtractor, VideoFeatureExtractor, dorsalnet_reg
from fix_models.datasets import VideoDataset, ImageDataset, get_datasets_and_loaders, get_search_dataset_and_loader
from fix_models.transforms import BaseVideoTransform, BaseImageTransform
from fix_models.readouts import PoissonGaussianReadout
from fix_models.metrics import corr_to_avg
from fix_models.models import FullModel

from fix_models.feature_extractors import ImageFeatureExtractor, VideoFeatureExtractor
from fix_models.readouts import PoissonGaussianReadout

In [62]:
import torch
torch.cuda.is_available() 

True

In [63]:
"""
!pip3 install pandas
!pip3 install scipy
!pip3 install imagehash
!pip3 install matplotlib
!pip3 install tqdm
!pip3 install wandb
!pip install python_dict_wrapper
!pip install GitPython
!pip install tables
!apt-get update && apt-get install ffmpeg libsm6 libxext6  -y
!pip install opencv-python
"""

'\n!pip3 install pandas\n!pip3 install scipy\n!pip3 install imagehash\n!pip3 install matplotlib\n!pip3 install tqdm\n!pip3 install wandb\n!pip install python_dict_wrapper\n!pip install GitPython\n!pip install tables\n!apt-get update && apt-get install ffmpeg libsm6 libxext6  -y\n!pip install opencv-python\n'

In [64]:
torch.cuda.is_available() 

True

In [113]:
# all parameters
config = dict()
config["modality"] = "video" # or image

# paths
input_dir = f'./data/{config["modality"]}/'
stimulus_dir = f'./data/{config["modality"]}/stimuli/'
embedding_dir = f'./data/{config["modality"]}/embeddings/'
model_output_path = f'./data/{config["modality"]}/model_output/results'

# image defaults
if config["modality"] == "image":
    # model parameters
    config["layer"] = "layer3"
    config['use_sigma'] = True
    config['center_readout'] = False
    config["use_pool"] = True
    config["pool_size"] = 4
    config["pool_stride"] = 1
    config["use_pretrained"] = True

    config["flatten_time"] = True


    # stimulus parameters 
    config["stim_size"] = 32 #25 #[25, 50, 100]
    config["win_size"] = 240 #[50, 100, 180]
    if config["modality"] == "video":
        stim_dur_ms = 200
        stim_shape = (1, 3, 5, config["stim_size"], config["stim_size"])
    elif config["modality"] == "image":
        stim_dur_ms = 120
        stim_shape = (1, 3, config["stim_size"], config["stim_size"])
    
    # training parameters 
    config["exp_var_thresholds"] = [0.1, 0.1, 0.1, 0.1] #[0.15, 0.15, 0.3, 0.25]
    config["lr"] = 0.001 #1 #[0.001, 0.01, 0.1]
    config["batch_size"] = 16
    config["num_epochs"] = 20#1
    config["l2_weight"] = 0#1#[0.01, 0.1, 1, 1]

    config["feat_ext_type"] = 'resnet50'

# video defaults
if config["modality"] == "video":
    # model parameters
    config["layer"] = "layer1"
    config["use_sigma"] = True
    config["center_readout"] = False
    config["use_pool"] = True
    config["pool_size"] = 4
    config["pool_stride"] = 2
    config["use_pretrained"] = True

    config["flatten_time"] = True

    # stimulus parameters 
    config["win_size"] = 240 #240# #240 #180 #180 #[50, 100, 180]

    config["feat_ext_type"] = 'dorsalnet'
    config["stim_size"] = 32 #32 #50 #25 #[25, 50, 100]
    if config["feat_ext_type"] == 'hiera':
        config["stim_size"] = 224 #50 #25 #[25, 50, 100]

    if config["modality"] == "video":
        stim_dur_ms = 200
        stim_shape = (1, 3, 5, config["stim_size"], config["stim_size"])
    elif config["modality"] == "image":
        stim_dur_ms = 120
        stim_shape = (1, 3, config["stim_size"], config["stim_size"])
    
    # training parameters 
    config["exp_var_thresholds"] = [0.25, 0.25, 0.25] #[0.15, 0.15, 0.15] #[-1, -1, -1] #[0.15, 0.15, 0.15, 0.15]
    config["lr"] = 0.001 #1 #[0.001, 0.01, 0.1]
    config["batch_size"] = 16
    config["num_epochs"] = 20 #1
    config["l2_weight"] = 0 #1e-5 #0 #s1e-3# 0 #0 #0.01 #.001 #.001 #0.001 #0.001 # 0 #.001 #1#[0.01, 0.1, 1, 1]
    config["first_frame_only"] = False
    config["blur_sigma"] = 0

    config["mlp"] = False
    config["loss"] = "poisson"

config['pos'] = (400, 180)

# logging
config["wandb"] = True

config["ensemble"] = False

# save model
config["save"] = True

# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# session names
if config["modality"] == "video":
    session_ids = ["082824", "082924", "083024"] #"082824", 
elif config["modality"] == "image":
    session_ids = ["051724", "081624", "080124", "082324"]

#### Step 1 - Train encoding model

In [None]:
import torch
import torch.nn as nn

class ModelEnsembler(nn.Module):
    def __init__(self, models):
        super(ModelEnsembler, self).__init__()
        self.models = nn.ModuleList(models)  # List of models
        self.weights = nn.Parameter(torch.ones((4, 1, 1), device=device)/4)
    def forward(self, x):
        # Get the outputs of each model
        outputs = [model(x) for model in self.models]
        # Stack the outputs along a new dimension and average them
        self.weights.data = torch.clamp(self.weights, 0, 1)
        avg_output = torch.sum(self.weights/torch.sum(self.weights)*torch.stack(outputs), 0)
        return avg_output
        
all_corrs = []
xs = []

#l2_weights = [0, 1e-5, 1e-4, 1e-3, 1e-2]
#lrs = [1e-3, 1e-5, 1e-4, 1e-2, 1e-1]

#config["win_size"] = 180 #180 #180 #[50, 100, 180]
"""
for modality in ["video"]:#, "image"]:
    # all parameters
    config = dict()
    config["modality"] = modality # or image
    
    # paths
    input_dir = f'./data/{config["modality"]}/'
    stimulus_dir = f'./data/{config["modality"]}/stimuli/'
    embedding_dir = f'./data/{config["modality"]}/embeddings/'
    model_output_path = f'./data/{config["modality"]}/model_output/results'
    
    # image defaults
    if config["modality"] == "image":
        # model parameters
        config["layer"] = "layer3"
        config['use_sigma'] = True
        config['center_readout'] = False
        config["use_pool"] = True
        config["pool_size"] = 2
        config["pool_stride"] = 1
        config["use_pretrained"] = True
    
        config["flatten_time"] = False
    
    
        # stimulus parameters 
        config["stim_size"] = 32 #25 #[25, 50, 100]
        config["win_size"] = 240 #[50, 100, 180]
        if config["modality"] == "video":
            stim_dur_ms = 200
            stim_shape = (1, 3, 5, config["stim_size"], config["stim_size"])
        elif config["modality"] == "image":
            stim_dur_ms = 120
            stim_shape = (1, 3, config["stim_size"], config["stim_size"])
        
        # training parameters 
        config["exp_var_thresholds"] = [0.1, 0.1, 0.1, 0.1] #[0.15, 0.15, 0.3, 0.25]
        config["lr"] = 0.001 #1 #[0.001, 0.01, 0.1]
        config["batch_size"] = 16
        config["num_epochs"] = 20#1
        config["l2_weight"] = 0#1#[0.01, 0.1, 1, 1]
    
        config["feat_ext_type"] = 'resnet50'
    
    # video defaults
    if config["modality"] == "video":
        # model parameters
        config["layer"] = "layer3"
        config["use_sigma"] = True
        config["center_readout"] = False
        config["use_pool"] = True
        config["pool_size"] = 2
        config["pool_stride"] = 1
        config["use_pretrained"] = True
    
        config["flatten_time"] = True
    
        # stimulus parameters 
        config["stim_size"] = 32 #50 #25 #[25, 50, 100]
        config["win_size"] = 240 #180 #180 #[50, 100, 180]
        if config["modality"] == "video":
            stim_dur_ms = 200
            stim_shape = (1, 3, 5, config["stim_size"], config["stim_size"])
        elif config["modality"] == "image":
            stim_dur_ms = 120
            stim_shape = (1, 3, config["stim_size"], config["stim_size"])
        
        # training parameters 
        config["exp_var_thresholds"] = [0.1, 0.1, 0.1] #[-1, -1, -1] #[0.15, 0.15, 0.15, 0.15]
        config["lr"] = 0.001 #1 #[0.001, 0.01, 0.1]
        config["batch_size"] = 16
        config["num_epochs"] = 20 #1
        config["l2_weight"] = 0 #.001 #0.001 #0.001 # 0 #.001 #1#[0.01, 0.1, 1, 1]
    
        config["feat_ext_type"] = 'resnet3d'
        
    # logging
    config["wandb"] = True
    
    # save model
    config["save"] = True
    
    # device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # session names
    if config["modality"] == "video":
        session_ids = ["082824", "082924", "083024"] #"082824", 
    elif config["modality"] == "image":
        session_ids = ["051724", "081624", "080124", "082324"]
    # image experiments
    if config["modality"] == "video":
        feat_ext_types = ["resnet3d"] #["dorsalnet", "resnet3d"]
        layer_pool_sizes = [8, 4, 2, 1]
    
    elif config["modality"] == "image":
        feat_ext_types = ["resnet50", "resnet50_fcn"]
        layer_pool_sizes = [4, 2, 1, 1]
    
    # video experiments
    layers = ['layer3'] #['layer1', 'layer2', 'layer3', 'layer4']
    
    for layer, layer_pool_size in zip(layers, layer_pool_sizes):
        config["layer"] = layer
        config["pool_size"] = layer_pool_size
        for feat_ext in feat_ext_types:
            config["feat_ext_type"] = feat_ext
            """
corr_avgs = []
for ses_idx, session_id in enumerate(session_ids):
    sess_corr_avg = 0
    sess_corrs = []
    # logging training
    config["session_id"] = session_id

    if config["wandb"]:
        wandb.init(
            project=f'{config["modality"]}-basline',
            config=config,
        )
        wandb.define_metric("corr_to_avg", summary="max")
        wandb.define_metric("test_loss", summary="min")

    exp_var_threshold = config["exp_var_thresholds"][ses_idx]
    train_dataset, test_dataset, train_loader, test_loader = get_datasets_and_loaders(input_dir, session_id, config["modality"], exp_var_threshold, stim_dur_ms, config["stim_size"], config["win_size"], stimulus_dir, config["batch_size"], config["first_frame_only"], blur_sigma = config["blur_sigma"], pos = config['pos'])

    if config["ensemble"]:
        full_model = ModelEnsembler([FullModel(config["modality"], "layer" + str(jjj), stim_shape, train_dataset, use_sigma = config['use_sigma'], center_readout=config['center_readout'], use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], use_pretrained = config["use_pretrained"], feat_ext_type = config["feat_ext_type"],flatten_time = config["flatten_time"], device=device, mlp=config["mlp"]) for jjj in range(1,5)])
    else:
        full_model = FullModel(config["modality"], config["layer"], stim_shape, train_dataset, use_sigma = config['use_sigma'], center_readout=config['center_readout'], use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], use_pretrained = config["use_pretrained"], feat_ext_type = config["feat_ext_type"],flatten_time = config["flatten_time"], device=device, mlp=config["mlp"])

    
    params_with_l2 = []
    params_without_l2 = []
    for name, param in full_model.named_parameters():
        if 'mu' in name or 'sigma' in name:
            params_without_l2.append(param)
        else:
            params_with_l2.append(param)
            
    optimizer = torch.optim.Adam([
    {'params': params_with_l2, 'weight_decay': config['l2_weight']},  # Apply L2 regularization (weight decay)
    {'params': params_without_l2, 'weight_decay': 0.0}  # No L2 regularization
    ], lr=config["lr"], weight_decay=config['l2_weight'])

    if config['loss'] == 'poisson':
        loss_func = PoissonNLLLoss(log_input=False, full=True)
    elif config['loss'] == 'mse':
        loss_func = MSELoss()
        
    for epochs in range(config["num_epochs"]):
        epoch_loss = 0
        for i, (stimulus, targets) in (enumerate(train_loader)): 
            stimulus = stimulus.to(device)
            targets = targets.to(device)
            
            optimizer.zero_grad()
            preds = full_model(stimulus)
            # on 10/4/24 - changed l2 weight decay to be a part of Adam optimizer 
            loss = loss_func(preds, targets) # + config["l2_weight"] * torch.mean((torch.sum(full_model.model[1].linear.weight ** 2, 1)))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # printing corr to avg and loss metrics 
        with torch.no_grad():
            corr_avg = corr_to_avg(full_model, test_loader, modality=config["modality"], device=device)
            test_loss = 0
            for i, (stimulus, targets) in enumerate(test_loader):
                stimulus = stimulus.to(device)
                targets = targets.to(device)
                preds = full_model(stimulus) 
                loss = loss_func(preds, targets)
                test_loss += loss.item()
        if config["wandb"]:
            wandb.log({"corr_to_avg": np.nanmean(corr_avg), "train_loss": epoch_loss / len(train_loader), "test_loss": test_loss / len(test_loader)})
        if np.nanmean(corr_avg) > sess_corr_avg:
            sess_corr_avg = np.nanmean(corr_avg)
            sess_corrs = corr_avg
        print('  epoch {} loss: {} corr: {}'.format(epochs + 1, epoch_loss / len(train_dataset), np.nanmean(corr_avg)))
        print(f' num. neurons : {len(corr_avg)}')
    if config["save"]:
        torch.save(full_model.state_dict(), f"{model_output_path}_{session_id}.pickle")
    corr_avgs.append(sess_corrs)
    if config["wandb"]:
        wandb.finish()

if config["wandb"]:
    wandb.init(
        project=f'{config["modality"]}-basline',
        config=config,
        #name=f"{feat_ext}_{layer}_{config["first_frame_only",
    )
    for sess_corr in corr_avgs:
        for corr in sess_corr:
            wandb.log({"corr": corr})
    wandb.finish()


readout input shape: torch.Size([1, 32, 5, 8, 8])


  checkpoint = torch.load(path, map_location=args.device)


  epoch 1 loss: 0.1254706863709438 corr: 0.07201377483141577
 num. neurons : 54
  epoch 2 loss: 0.11321003013186985 corr: 0.11567752897549415
 num. neurons : 54
  epoch 3 loss: 0.11136853471214389 corr: 0.1425432886578258
 num. neurons : 54
  epoch 4 loss: 0.11020525726271264 corr: 0.17389835658907163
 num. neurons : 54
  epoch 5 loss: 0.10950821393801842 corr: 0.20020052010362915
 num. neurons : 54
  epoch 6 loss: 0.10896263057802931 corr: 0.20969730408812776
 num. neurons : 54
  epoch 7 loss: 0.1085477418664061 corr: 0.22898195640833502
 num. neurons : 54
  epoch 8 loss: 0.10825363147405931 corr: 0.22834540440348441
 num. neurons : 54
  epoch 9 loss: 0.10800979655465962 corr: 0.25268914358508443
 num. neurons : 54
  epoch 10 loss: 0.10774150836614915 corr: 0.2630346954978146
 num. neurons : 54
  epoch 11 loss: 0.10742415251555266 corr: 0.26333880462889897
 num. neurons : 54
  epoch 12 loss: 0.10713574550769947 corr: 0.2699976575726546
 num. neurons : 54
  epoch 13 loss: 0.10699078854

In [None]:
import matplotlib.pyplot as plt
for jj, (stimulus, targets) in (enumerate(train_loader)): 
    for j in range(config['batch_size']):
        fig, axs = plt.subplots(1, 5)
        for i in range(5):
            plt.sca(axs[i])
            plt.imshow(np.mean(stimulus[j, :, i, :, :].cpu().detach().numpy(),0), cmap='Grays') #, cmap='Grays'
            plt.axis('off')
        plt.show()
    if jj > 10:
        break

In [None]:
full_model.weights

In [None]:
import matplotlib.pyplot as plt 
plt.hist(full_model.models[4].model[1].sigma.detach().cpu().numpy())

In [None]:
plt.hist(full_model.models[4].model[1].mu.detach().cpu().numpy())

In [None]:
plt.scatter(full_model.model[1].mu[:, 0].detach().cpu().numpy(), full_model.model[1].mu[:, 1].detach().cpu().numpy())

In [None]:
from fix_models.feature_extractors import hierat_reg
model = hierat_reg(device=device)

In [None]:
model(torch.ones((1,3,224,224), device=device)).reshaped_hidden_states[-2].shape

In [None]:
model = HieraModel.from_pretrained("facebook/hiera-tiny-224-hf")

In [None]:
from fix_models.feature_extractors import sam2t_reg
model = sam2t_reg(device=device)

In [None]:
model._forward_hooks

In [None]:
get_graph_node_names(predictor)

In [None]:
ls ./baselines/sam2/sam2/configs/sam2.1/sam2.1_hiera_t.yaml

In [None]:
predictor

In [None]:
#corr_avg
with open(join(input_dir, f"{session_id}.pickle"), "rb") as f:
    model_input = pickle.load(f)

In [None]:
import matplotlib.pyplot as plt
ev = model_input['expvar'][model_input['expvar'] > 0]
plt.hist(ev)

In [None]:
threshold = 0.1
print(np.mean(corr_avg[ev > threshold]))
print(len(corr_avg[ev > threshold]))
plt.hist(corr_avg[ev > threshold] ** 2)

In [None]:
plt.scatter(ev[ev > threshold], corr_avg[ev > threshold])
plt.axis('square')
plt.ylim([0, 0.9])
plt.xlim([0.1, 1])
plt.plot([0.1, 0.9], [0.1, 0.9])
plt.xlabel("explainable variance")
plt.ylabel("correlation to average")

#### Step 2 - Search for natural stimuli that highly activate the population of neurons or single neurons in the neuron_ids vector

In [None]:
import os
import shutil

# additional configuration variables for searching for highly activating stimuli
modality = config["modality"]
config["stim_input_dir"] = f"./data/{modality}/novel_{modality}_datasets/"
config["stim_output_dir"] = f"./data/{modality}/novel_{modality}_datasets_pred/"
config["n_stim"] = 25 # number of most activating and least activating stimuli to return 
config["n_log"] = 5
config["pop_act"] = False
config["neuron_ids"] = [0]

if config["pop_act"]:
    config["neuron_ids"] = [0]

for ses_idx, session_id in enumerate(session_ids):
    # load model from state dict + get targets etc. for variable sizing
    exp_var_threshold = config["exp_var_thresholds"][ses_idx]
    train_dataset, test_dataset, train_loader, test_loader = get_datasets_and_loaders(input_dir, session_id, config["modality"], exp_var_threshold, stim_dur_ms, config["stim_size"], config["win_size"], stimulus_dir, config["batch_size"])
    full_model = FullModel(config["modality"], config["layer"], stim_shape, train_dataset, use_sigma = config['use_sigma'], center_readout=config['center_readout'], use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], use_pretrained = config["use_pretrained"], feat_ext_type = config["feat_ext_type"],flatten_time = config["flatten_time"], device=device)
    
    full_model.load_state_dict(torch.load(f"{model_output_path}_{session_id}.pickle", weights_only=True))
    
    for i, (stimulus, targets) in (enumerate(train_loader)): 
        break

    # get dataset and loader for novel videos to search over
    search_dataset, search_loader = get_search_dataset_and_loader(config["stim_input_dir"], config["modality"], config["stim_size"], config["win_size"], config["batch_size"])
    batch_size = config["batch_size"]
    all_preds = torch.zeros((len(search_dataset), targets.shape[1]))
    all_names = []
    for i, (stimulus, stim_names) in (enumerate(search_loader)): 
        stimulus = stimulus.to(device)
        with torch.no_grad():
            preds = full_model(stimulus)
            all_preds[i*batch_size:i*batch_size+batch_size, :] = preds
            all_names = all_names + list(stim_names)
    all_names = np.array(all_names)
    
    source_dir = config["stim_input_dir"]  
    target_dir = config["stim_output_dir"] 
    
    os.makedirs(target_dir, exist_ok=True)

    for neuron_idx in config["neuron_ids"]:
        if config["pop_act"]:
            act_vect = torch.mean((all_preds - torch.mean(all_preds, 0))/torch.std(all_preds, 0), 1)
        else:
            act_vect = all_preds[:, neuron_idx]
            
        least_activating_videos = all_names[torch.argsort(act_vect)[:config["n_stim"]].numpy().astype(int)]
        most_activating_videos = all_names[torch.argsort(act_vect)[-config["n_stim"]:].numpy().astype(int)]

        least_activating_log = least_activating_videos[:config["n_log"]]
        most_activating_log = most_activating_videos[-config["n_log"]:]

        # save a text file with the most and least activating videos + copy the video files to the target directory
        if config["pop_act"]:
            output_txt_path = os.path.join(target_dir, f'video_activation_names_{session_id}.txt')
        else:
            output_txt_path = os.path.join(target_dir, f'video_activation_names_{session_id}_{neuron_idx}.txt')

        with open(output_txt_path, 'w') as f:
            for video in least_activating_videos:
                f.write(f'Least activating: {video}\n')
                if not os.path.exists(os.path.join(target_dir, video)):
                    shutil.copy(os.path.join(source_dir, video), target_dir)
            
            for video in most_activating_videos:
                f.write(f'Most activating: {video}\n')
                if not os.path.exists(os.path.join(target_dir, video)):
                    shutil.copy(os.path.join(source_dir, video), target_dir)
                    
        print(f"video names saved to {output_txt_path} and videos copied to {target_dir} for session {session_id}.")

#### Step 3 - Use models to synthesize highly activating input stimuli for the population of neurons or single neurons

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AffineTransformationModel(nn.Module):
    def __init__(self, device, steps=5):
        super(AffineTransformationModel, self).__init__()
        self.steps = steps
        # Initialize the affine parameters close to identity
        # Theta is of shape [2, 3]
        # For Identity, theta is [[1, 0, 0], [0, 1, 0]]
        
        theta_identity = torch.tensor([[1, 0, 0],
                                       [0, 1, 0]], dtype=torch.float32, device=device)
        # Define theta_total as a trainable parameter
        self.theta_total = nn.Parameter(theta_identity.clone())

    
        
    def forward(self, img):
        N, C, H, W = img.size()
        device = img.device
        frames = []
        theta_identity = torch.tensor([[1, 0, 0],
                                       [0, 1, 0]], dtype=torch.float32, device=device)
        # For each step, compute theta_i
        for i in range(0, self.steps):
            t_i = i / self.steps
            theta_i = theta_identity + t_i * (self.theta_total - theta_identity)
            theta_i = theta_i.unsqueeze(0)  # [1, 2, 3]
            grid = F.affine_grid(theta_i, img.size(), align_corners=False)
            output = F.grid_sample(img, grid, align_corners=False)
            frames.append(output)
        return frames

import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
"""
# Load an example image
image_path = 'data/image/stimuli/n01440764_1145.JPEG'  # Replace with your image path
img = Image.open(image_path)
transform = transforms.Compose([
    transforms.ToTensor(),
])
img = transform(img).unsqueeze(0)  # Convert to [1, C, H, W] tensor

# print(train_dataset[0][0][:, 0, :, :].shape)
model = AffineTransformationModel(steps=5)
frames = model(img)

# Plot the frames
plt.figure(figsize=(15, 5))
for idx, frame in enumerate(frames):
    plt.subplot(1, 5, idx + 1)
    # Convert tensor to NumPy array and transpose dimensions for plotting
    frame_img = frame[0].permute(1, 2, 0).detach().cpu().numpy()
    # Clip values to valid range [0, 1]
    frame_img = frame_img.clip(0, 1)
    plt.imshow(frame_img)
    plt.title(f"Frame {idx + 1}")
    plt.axis('off')
plt.tight_layout()
plt.show()
"""

In [None]:
from copy import deepcopy
import torch.nn.functional as F

import math

# additional configuration variables for searching for highly activating stimuli
config["stim_output_dir"] = f"./data/{modality}/{modality}_datasets_meis/"
config["n_stim"] = 25 # number of most activating and least activating stimuli to return 
config["n_log"] = 5
config["pop_act"] = False
config["neuron_ids"] = [0, 1, 2, 3, 4, 5]

if config["pop_act"]:
    config["neuron_ids"] = [0]

config["mei_lr"] = 0.1
neuron_id = 15#27
num_iter = 1000
norm_weight = 0
# 0 #300 #250 # 0.3 #0.1
blur = True
blur_mag = 0.75#0.35 #0.5#0.5

temp_lambda = 1000
space_lambda = 0


def temporal_smoothness_loss(video):
    # video shape: (batch, channels, frames, height, width)
    loss = 0.0
    for t in range(video.shape[2] - 1):
        loss += F.mse_loss(video[:, :, t, :, :], video[:, :, t+1, :, :])
    #loss += F.mse_loss(video[:, :, 0, :, :], video[:, :, -1, :, :])
    return loss

def total_variation_loss(video):
    tv_loss = torch.sum(torch.abs(video[:, :, :, :, :-1] - video[:, :, :, :, 1:])) + \
              torch.sum(torch.abs(video[:, :, :, :-1, :] - video[:, :, :, 1:, :]))
    return tv_loss

start_img = train_dataset[1][0][:, 0, :, :]
start_img = torch.zeros(start_img.shape, device=device, dtype=start_img.dtype)

import torch
import matplotlib.pyplot as plt


def create_grid_image(image_size=32, square_size=4, spacing=4):
    """
    Creates a black and white image with a grid of white squares on a black background.

    Args:
        image_size (int): The height and width of the image.
        square_size (int): The size of each white square.
        spacing (int): The spacing between the squares.

    Returns:
        torch.Tensor: The generated image tensor of shape [1, 1, image_size, image_size].
    """
    # Initialize a black image
    img = torch.zeros((3, image_size, image_size), dtype=torch.float32)
    
    # Place white squares on the black background
    bord = int(square_size/2)
    for y in range(0, image_size - square_size + 1, square_size + spacing):
        for x in range(0, image_size - square_size + 1, square_size + spacing):
            img[:,  bord+ y:bord+y+square_size, bord+x:bord+x+square_size] = 1.0
    
    return img


start_img = create_grid_image(image_size=32, square_size=4, spacing=4)


siz = start_img.shape[-2]

for ses_idx, session_id in enumerate(session_ids):
    # load model from state dict + get targets etc. for variable sizing
    exp_var_threshold = config["exp_var_thresholds"][ses_idx]
    train_dataset, test_dataset, train_loader, test_loader = get_datasets_and_loaders(input_dir, session_id, config["modality"], exp_var_threshold, stim_dur_ms, config["stim_size"], config["win_size"], stimulus_dir, config["batch_size"])
    full_model = FullModel(config["modality"], config["layer"], stim_shape, train_dataset, use_sigma = config['use_sigma'], center_readout=config['center_readout'], use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], use_pretrained = config["use_pretrained"], feat_ext_type = config["feat_ext_type"],flatten_time = config["flatten_time"], device=device)
    
    full_model.load_state_dict(torch.load(f"{model_output_path}_{session_id}.pickle", weights_only=True))
    
    for i, (stimulus, targets) in (enumerate(train_loader)): 
        break
        
    for param in full_model.parameters():
        param.requires_grad = False

    for neuron_id in config["neuron_ids"]:
    
        model = AffineTransformationModel(steps=5, device=device)
        #frames = model(img)
    
        #full_model.model[0].stim = nn.Parameter(start_img.to(device), requires_grad = True)
    
        optim = torch.optim.SGD(model.parameters(), lr=config["mei_lr"], momentum=0.1)
        
        losses = []
        for i in trange(num_iter):
            # start img 
            frames = model(start_img.unsqueeze(0).to(device))
            mod_input = torch.stack(frames, 0).permute(1, 2, 0, 3, 4)
            loss = -full_model(mod_input)[0, neuron_id]
            #loss = loss + norm_weight*torch.mean((full_model.model[0].stim - 0) ** 2) + \
            #temp_lambda * temporal_smoothness_loss(full_model.model[0].stim) + space_lambda * total_variation_loss(full_model.model[0].stim)
            loss.backward() # backward pass
            losses.append(loss.data.detach().cpu().numpy())
     
            optim.step() #gradient descent
            optim.zero_grad()
        
            # gaussian blur over space
            """
            with torch.no_grad():
                if blur:
                    if config["modality"] == "video":
                        for j in range(stim_dims[2]):
                            blur_trans = transforms.GaussianBlur(3, sigma=blur_mag)
                            full_model.model[0].stim.data[0, :, j] = blur_trans(full_model.model[0].stim.data[0, :, j])
                    else:
                        blur_trans = transforms.GaussianBlur(3, sigma=blur_mag)
                        full_model.model[0].stim.data[0] = blur_trans(full_model.model[0].stim.data[0].mean(0, keepdim=True)).repeat(3, 1, 1)
                        #full_model[0].stim.data[0] = blur_trans(full_model[0].stim.data[0])
            """
        stim = (full_model.model[0].stim.data.detach().cpu().numpy())

        fig, axs = plt.subplots(1, 1, figsize=(3, 3))
        plt.plot(np.squeeze(np.array(losses))[10:])
        plt.xlabel("iteration")
        plt.ylabel("synthetic neuron activity")
        plt.tight_layout()
        plt.show()

        fig, axs = plt.subplots(1, 5, figsize=(10, 2))
        for i in range(5):
            plt.sca(axs[i])
            plt.imshow(np.mean(mod_input.detach().cpu().numpy(), (0,1))[i], cmap='Greys_r')
        
            #plt.clim([0, 1])
            #plt.colorbar()
        plt.show()

In [None]:
fig, axs = plt.subplots(1, 5, figsize=(10, 2))
for i in range(5):
    plt.sca(axs[i])
    plt.imshow(np.mean(mod_input.detach().cpu().numpy(), (0,1))[i], cmap='Greys_r')

    #plt.clim([0, 1])
    #plt.colorbar()
plt.show()

In [None]:
from copy import deepcopy
import torch.nn.functional as F

# additional configuration variables for searching for highly activating stimuli
config["stim_output_dir"] = f"./data/{modality}/{modality}_datasets_meis/"
config["n_stim"] = 25 # number of most activating and least activating stimuli to return 
config["n_log"] = 5
config["pop_act"] = False
config["neuron_ids"] = [0]

if config["pop_act"]:
    config["neuron_ids"] = [0]

config["mei_lr"] = 0.5 
neuron_id = 15#27
num_iter = 1000
norm_weight = 0
# 0 #300 #250 # 0.3 #0.1
blur = True
blur_mag = 0.75#0.35 #0.5#0.5

temp_lambda = 1000
space_lambda = 0

neuron_id = config["neuron_ids"][0]

def temporal_smoothness_loss(video):
    # video shape: (batch, channels, frames, height, width)
    loss = 0.0
    for t in range(video.shape[2] - 1):
        loss += F.mse_loss(video[:, :, t, :, :], video[:, :, t+1, :, :])
    #loss += F.mse_loss(video[:, :, 0, :, :], video[:, :, -1, :, :])
    return loss

def total_variation_loss(video):
    tv_loss = torch.sum(torch.abs(video[:, :, :, :, :-1] - video[:, :, :, :, 1:])) + \
              torch.sum(torch.abs(video[:, :, :, :-1, :] - video[:, :, :, 1:, :]))
    return tv_loss

start_img = train_dataset[0][0][:, 0, :, :]

for ses_idx, session_id in enumerate(session_ids):
    # load model from state dict + get targets etc. for variable sizing
    exp_var_threshold = config["exp_var_thresholds"][ses_idx]
    train_dataset, test_dataset, train_loader, test_loader = get_datasets_and_loaders(input_dir, session_id, config["modality"], exp_var_threshold, stim_dur_ms, config["stim_size"], config["win_size"], stimulus_dir, config["batch_size"])
    full_model = FullModel(config["modality"], config["layer"], stim_shape, train_dataset, use_sigma = config['use_sigma'], center_readout=config['center_readout'], use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], use_pretrained = config["use_pretrained"], feat_ext_type = config["feat_ext_type"],flatten_time = config["flatten_time"], device=device)
    
    full_model.load_state_dict(torch.load(f"{model_output_path}_{session_id}.pickle", weights_only=True))
    
    for i, (stimulus, targets) in (enumerate(train_loader)): 
        break
        
    for param in full_model.parameters():
        param.requires_grad = False

    
    full_model.model[0].stim.data = full_model.model[0].stim.data.to(device) - 0.5
    init_stim = deepcopy(full_model.model[0].stim.data.detach().cpu().numpy())

    full_model.model[0].stim.requires_grad=True
    optim = torch.optim.SGD([full_model.model[0].stim], lr=config["mei_lr"], momentum=0.1)
    
    losses = []
    for i in trange(num_iter):
        loss = -full_model(full_model.model[0].stim)[0, neuron_id]
        loss = loss + norm_weight*torch.mean((full_model.model[0].stim - 0) ** 2) + \
        temp_lambda * temporal_smoothness_loss(full_model.model[0].stim) + space_lambda * total_variation_loss(full_model.model[0].stim)
        loss.backward() # backward pass
        losses.append(loss.data.detach().cpu().numpy())
 
        optim.step() #gradient descent
        optim.zero_grad()
    
        # gaussian blur over space
        with torch.no_grad():
            if blur:
                if config["modality"] == "video":
                    for j in range(stim_dims[2]):
                        blur_trans = transforms.GaussianBlur(3, sigma=blur_mag)
                        full_model.model[0].stim.data[0, :, j] = blur_trans(full_model.model[0].stim.data[0, :, j])
                else:
                    blur_trans = transforms.GaussianBlur(3, sigma=blur_mag)
                    full_model.model[0].stim.data[0] = blur_trans(full_model.model[0].stim.data[0].mean(0, keepdim=True)).repeat(3, 1, 1)
                    #full_model[0].stim.data[0] = blur_trans(full_model[0].stim.data[0])
    
    stim = (full_model.model[0].stim.data.detach().cpu().numpy())
    
    plt.plot(np.squeeze(np.array(losses))[10:])
    plt.xlabel("iteration")
    plt.ylabel("synthetic neuron activity")
    plt.tight_layout()
    plt.show()

In [None]:
config["modality"]

In [None]:
for i in range(5):
    plt.imshow(np.mean(stim, (0,1))[i], cmap='Greys_r')

    #plt.clim([0, 1])
    plt.colorbar()
    plt.show()

In [None]:
for i in range(4):
    plt.imshow(np.mean(stim, (0,1))[i] - np.mean(stim, (0,1))[i+1], cmap='Greys_r')

    #plt.clim([0.3, 0.7])
    plt.colorbar()
    plt.show()

#### Step 4 - in silico electrophysiology experiments/rf centers/tuning

In [None]:
# DRIFTING GRATINGS
def create_drifting_gratings():
    ntau = 5
    radius = 16
    ndirections = 8

    lx, lt = 16, 16

    # Create stimuli that contain all combos that are needed
    xi, yi = np.meshgrid(np.arange(-16, 16), np.arange(-16, 16))
    mask = xi**2 + yi**2 < radius**2
    oi = (np.arange(ndirections) / ndirections * 2 * np.pi).reshape((-1, 1, 1, 1))
    ti = np.arange(ntau)
    ti = ti - ti.mean()

    vals = []
    stims = []

    ri = (np.cos(oi) * xi.reshape((1, 1, xi.shape[0], xi.shape[1])) - np.sin(oi) * yi.reshape((1, 1, xi.shape[0], xi.shape[1])))
    X = mask.reshape((1, 1, xi.shape[0], xi.shape[1])) * np.cos((ri / lx) * 2 * np.pi - ti.reshape((1, -1, 1, 1)) / lt * 2 *np.pi)
    X = np.stack([X, X, X], axis=1) # Go from black and white to RGB
    return X

X_drift = torch.tensor(create_drifting_gratings()).to(device='cuda', dtype=torch.float)

In [None]:
# SPIRAL SQUARE
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
from scipy.ndimage import gaussian_filter

# Function to generate a single frame
def create_square_frame(size, rotation, canvas_size=32):
    img = Image.new('L', (canvas_size, canvas_size), color=0)  # Create black canvas
    draw = ImageDraw.Draw(img)
    
    # Define the square's coordinates (centered)
    top_left = ((canvas_size - size) // 2, (canvas_size - size) // 2)
    bottom_right = ((canvas_size + size) // 2, (canvas_size + size) // 2)
    
    # Draw the square
    draw.rectangle([top_left, bottom_right], fill=255)
    
    # Rotate the image
    img = img.rotate(rotation, expand=False)
    
    return np.array(img)

# Function to generate frames for a specific transformation
def generate_stimuli_frames(expanding=True, rotating_clockwise=True, num_frames=5):
    frames = []
    
    if expanding is True:
        sizes = np.linspace(10, 20, num_frames)  # Square grows in size
    elif expanding is False:
        sizes = np.linspace(20, 10, num_frames)  # Square shrinks in size
    else:
        sizes = [15] * num_frames  # Keep square the same size if no expansion/contraction
    
    if rotating_clockwise is True:
        rotations = np.linspace(0, -45, num_frames)  # Clockwise rotation
    elif rotating_clockwise is False:
        rotations = np.linspace(0, 45, num_frames)  # Counterclockwise rotation
    else:
        rotations = [0] * num_frames  # No rotation
    
    for size, rotation in zip(sizes, rotations):
        frame = gaussian_filter(create_square_frame(int(size), rotation), sigma=2)
        frames.append(frame)
    
    return frames

# Function to visualize frames using matplotlib
def visualize_stimuli_frames(frames):
    fig, axes = plt.subplots(1, len(frames), figsize=(15, 5))
    
    # Plot each frame
    for i, frame in enumerate(frames):
        axes[i].imshow(frame, cmap='gray')
        axes[i].axis('off')  # Hide the axes
    
    plt.show()

# Generate and visualize 8 different stimuli for each condition
conditions = [
    {"expanding": True, "rotating_clockwise": None},  # Only expanding
    {"expanding": True, "rotating_clockwise": True},
    {"expanding": None, "rotating_clockwise": True},   # Only rotating clockwise
    {"expanding": False, "rotating_clockwise": True},
    {"expanding": False, "rotating_clockwise": None},  # Only contracting
    {"expanding": False, "rotating_clockwise": False},
    {"expanding": None, "rotating_clockwise": False},   # Only rotating counterclockwise
    {"expanding": True, "rotating_clockwise": False},
]
    
# Visualize 8 stimuli for the first condition as an example
all_stim = []
for j in range(8):  # 8 stimuli per condition
    condition = conditions[j]  # Change the index for other conditions
    expanding = condition["expanding"] if condition["expanding"] is not None else None
    rotating_clockwise = condition["rotating_clockwise"] if condition["rotating_clockwise"] is not None else None
    
    frames = generate_stimuli_frames(expanding=expanding, rotating_clockwise=rotating_clockwise)
    all_stim.append(frames)
    visualize_stimuli_frames(frames)

X_spiral = torch.tensor(np.array(all_stim), dtype=torch.float32).unsqueeze(1).repeat(1, 3, 1, 1, 1).to(device)/255/2+0.5

In [None]:
def get_all_predictions(X):
    config["neuron_ids"] = [0,1]
    all_preds = []
    
    for ses_idx, session_id in enumerate(session_ids):
        # load model from state dict + get targets etc. for variable sizing
        exp_var_threshold = config["exp_var_thresholds"][ses_idx]
    
        train_dataset, test_dataset, train_loader, test_loader = get_datasets_and_loaders(input_dir, session_id, config["modality"], exp_var_threshold, stim_dur_ms, config["stim_size"], config["win_size"], stimulus_dir, config["batch_size"])
        full_model = FullModel(config["modality"], config["layer"], stim_shape, train_dataset, use_sigma = config['use_sigma'], center_readout=config['center_readout'], use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], use_pretrained = config["use_pretrained"], feat_ext_type = config["feat_ext_type"],flatten_time = config["flatten_time"], device=device)
    
        full_model.load_state_dict(torch.load(f"{model_output_path}_{session_id}.pickle", weights_only=True))
        
        for i, (stimulus, targets) in (enumerate(train_loader)): 
            break
            
        for param in full_model.parameters():
            param.requires_grad = False
    
        preds = full_model(X)
        all_preds.append(preds.detach().cpu().numpy())
    
    preds = np.column_stack(all_preds)
    return preds

def get_rf_preds():
    all_preds = []
    
    for ses_idx, session_id in enumerate(session_ids):
        # load model from state dict + get targets etc. for variable sizing
        exp_var_threshold = config["exp_var_thresholds"][ses_idx]
    
        train_dataset, test_dataset, train_loader, test_loader = get_datasets_and_loaders(input_dir, session_id, config["modality"], exp_var_threshold, stim_dur_ms, config["stim_size"], config["win_size"], stimulus_dir, config["batch_size"])
        full_model = FullModel(config["modality"], config["layer"], stim_shape, train_dataset, use_sigma = config['use_sigma'], center_readout=config['center_readout'], use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], use_pretrained = config["use_pretrained"], feat_ext_type = config["feat_ext_type"],flatten_time = config["flatten_time"], device=device)
    
        full_model.load_state_dict(torch.load(f"{model_output_path}_{session_id}.pickle", weights_only=True))
        
        for i, (stimulus, targets) in (enumerate(train_loader)): 
            break
            
        for param in full_model.parameters():
            param.requires_grad = False

        for i in range(full_model.model[1].mu.shape[0]):
            all_preds.append(full_model.model[1].mu[i, :].detach().cpu().numpy())
    
    preds = np.column_stack(all_preds)
    return preds
    
spiral_preds = get_all_predictions(X_spiral)
grating_preds = get_all_predictions(X_drift)
rf_preds = get_rf_preds()



In [None]:
# get true values of spiral tuning, drift tuning, and RF location
spiral_true = []
grating_true = []
rf_true = []
expvars = []
for session_id in session_ids: 
    spiral_tuning = np.load(f"./data/spiral_selectivity/{session_id}.npy").T
    dir_tuning = np.load(f"./data/dir_selectivity/{session_id}.npy").T
    rf_maps = np.load(f"./data/rf_maps/{session_id}.npy")

    input_dir = f'./data/{config["modality"]}/'
    thresh = config["exp_var_thresholds"][ses_idx]
    with open(join(input_dir, f"{session_id}.pickle"), "rb") as f:
        model_input = pickle.load(f)
    
    spiral_true.append(spiral_tuning[:, model_input['expvar'] > thresh])
    grating_true.append(dir_tuning[:, model_input['expvar'] > thresh])
    rf_maps = rf_maps[model_input['expvar'] > thresh]
    rf_pos = np.array([np.unravel_index(np.argmax(rf_maps[i, :, :]), rf_maps[i, :, :].shape) for i in range(rf_maps.shape[0])])
    plt.scatter(rf_pos[:, 0], rf_pos[:, 1], alpha=0.1)
    plt.show()
    rf_true.extend([np.unravel_index(np.argmax(rf_maps[i, :, :]), rf_maps[i, :, :].shape) for i in range(rf_maps.shape[0])])
    expvars.append(model_input['expvar'][model_input['expvar'] > thresh][None, :])
spiral_true = np.column_stack(spiral_true)
grating_true = np.column_stack(grating_true)
rf_true = np.array(rf_true).T

In [None]:
from scipy.stats import pearsonr

def get_pearsonr(true, preds, jj = 0):
    rs = []
    ps = []
    for i in range(preds.shape[1]):
        r, p = pearsonr(true[:, i], np.roll(preds[:, i], jj))
        rs.append(r)
        ps.append(p)
    return np.array(rs), np.array(ps)

def get_pearsonr_rf(true, preds):
    r1 = pearsonr(-true[0, :], preds[0, :]).statistic
    r2 = pearsonr(-true[1, :], preds[1, :]).statistic
    return r2

spiral_rs, spiral_ps = get_pearsonr(spiral_true, spiral_preds)
grating_rs, grating_ps = get_pearsonr(grating_true, grating_preds, 4)
rf_rs = get_pearsonr_rf(rf_true, rf_preds)

print(f"spiral correlation {np.nanmean(spiral_rs)}")
print(f"grating correlation {np.nanmean(grating_rs)}")
print(f"rf correlation {np.nanmean(rf_rs)}")

In [None]:
torch.nn.functional.grid_sample(torch.tensor([[[[0,0,0,1],[0,0,0,0],[0,0,0,0]]]], dtype=torch.float32), torch.tensor([[[[1, 1]]]], dtype=torch.float32))

In [None]:
torch.tensor([[[[0,0,0,1],[0,0,0,0],[0,0,0,0]]]], dtype=torch.float32).shape

In [None]:
plt.scatter(rf_preds[0, :], rf_preds[1, :])

In [None]:
plt.scatter(rf_true[0, :], rf_true[1, :], alpha=0.1)

In [None]:
np.sum(grating_ps < 0.05)/len(grating_ps)

In [None]:
np.sum(spiral_ps < 0.05)/len(grating_ps)

In [None]:
def rescale(x):
    x = x - np.min(x)
    x = x/np.max(x)
    return x

for i in range(grating_ps.shape[0]):
    plt.plot(rescale(grating_true[:, i]))
    plt.plot(rescale(grating_preds[:, i]))
    plt.show()

In [None]:
spiral_preds

In [None]:
rs = np.array(rs)
rs[np.isnan(rs)] = 0

In [None]:
print(np.argmax(rs))
np.max(rs)

In [None]:
len(rs)

In [None]:
expvars[:, 546]

In [None]:
i = np.argmax(rs)
XX = all_spiral[:, i]
plt.plot(XX)
plt.show()

XX = preds[:, i].detach().cpu().numpy()
plt.plot(XX)
plt.show()