## Model response generation for the in-silico context-detection task

In [None]:
import os
import pickle
import gc
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import numpy as np
import pandas as pd
from scipy.stats import pearsonr as lincorr
from functools import partial

from matplotlib.ticker import FixedLocator
%load_ext autoreload
%autoreload 2
from openretina.hoefling_2024.nnfabrik_model_loading import load_ensemble_retina_model_from_directory
from openretina.hoefling_2024.nnfabrik_model_loading import Center
from openretina.utils.misc import CustomPrettyPrinter
from rgc_natstim_model.constants.plot_settings import cmap_colors as rgc_colors
from rgc_natstim_model.utils.data_handling import unPickle,makePickle
from rgc_natstim_model.utils.inference import get_model_responses
from rgc_natstim_model.constants.identifiers import dh2eh, dh2eh_linear, example_nids
from rgc_natstim_model.utils.nm_vis import rotate90
from rgc_natstim_model.analyses.context_change_detection import bootstrap_ci,cohens_d,get_ind_roc_curve,get_roc_curve

In [202]:
# Notebook display setup: draw matplotlib figures inline and request the
# 'retina' figure format from the inline backend.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# Pretty-printer instance configured with indent=4 and max_lines=40.
pp = CustomPrettyPrinter(indent=4, max_lines=40)
# Disabled: add a local nnfabrik_euler checkout to the import path.
# import sys
# sys.path.append('/gpfs01/euler/User/lhoefling/GitHub/nnfabrik_euler')

In [203]:
# Dataset hashes (keys of dh2eh) and their ensemble-model hashes (values of
# dh2eh), both in the mapping's own order.
dataset_hashes = [*dh2eh.keys()]
ensemble_hashes = [*dh2eh.values()]

# One response file name per dataset; each name embeds the first 7 characters
# of the dataset hash.
file_name_template = r'2024-06-15_neuron_data_stim_c285329_responses_{}_wri.h5'
data_file_names = [file_name_template.format(full_hash[:7]) for full_hash in dataset_hashes]

# The last entry is swapped for a file carrying a different stimulus id (41fc277).
# NOTE(review): this relies on the 5e62060 dataset being the last key of dh2eh — confirm.
data_file_names[-1] = '2024-06-15_neuron_data_stim_41fc277_responses_5e62060_wri.h5'

# Sub-folder of "<base_folder>/models" from which the ensembles are loaded.
model_type = 'nonlinear'
base_folder = '/gpfs01/euler/data/SharedFiles/projects/Hoefling2024/'

# Names of the pickled movie dictionaries (original and flipped version).
movie_file_name = '2024-01-11_movies_dict_c285329.pkl'
flipped_movie_file_name = '2024-05-27_movies_dict_41fc277.pkl'

In [204]:
# Key and value lists of dh2eh, recomputed here (an earlier cell assigns the
# same two names in the same way).
dataset_hashes, ensemble_hashes = list(dh2eh.keys()), list(dh2eh.values())

In [219]:
# Folder prefixes for the simulation data. Both keep their trailing slash
# because later cells append file names to them by plain string concatenation.
respGen_path = f'{base_folder}data/simulation/response_generation/'
stimGen_path = f'{base_folder}data/simulation/stimulus_generation/'

In [220]:
# Mapping from dataset hash to the session ids of that dataset (it is indexed
# by dataset hash and iterated over session ids in the inference loop).
dh_2_session_ids = unPickle(f'{respGen_path}dh_2_session_ids.pkl')

### Load dataframes for later use

In [212]:
## Full data table covering all 6984 neurons; later cells read its
## 'dataset_hash', 'session_id', 'neuron_id' and 'model_readout_idx' columns.
df = pd.read_pickle(f'{base_folder}data/base/full_data_df.pkl')

### Load in-silico stimuli

In [206]:
# Speed values for which a pre-generated stimulus file is loaded below
# (each value is inserted into the stimulus file name).
speeds = [
    4,
    12,
    20,
    28,
]

In [207]:
## Stimuli written by the stimulus-generation notebook: one array per speed
stimuli_120f = {
    spd: np.load(f'{stimGen_path}sti1000_120f_s{spd}_18x16.npy')
    for spd in speeds
}

## Reshape each 120-frame stimulus array to (4000, 2, 120, 18, 16) before
## it is passed to the models
model_input_dict = {
    spd: raw_stim.reshape(4000, 2, 120, 18, 16)
    for spd, raw_stim in stimuli_120f.items()
}

## Normalization statistics taken from the movies that were used to train
## the models
stimulus_normalization = unPickle(f'{respGen_path}stimulus_normalization.pkl')
transform_mean = stimulus_normalization['mean']
transform_sd = stimulus_normalization['sd']

## Normalized model input: subtract the mean, divide by the sd
norm_model_input = {
    spd: (shaped_stim - transform_mean) / transform_sd
    for spd, shaped_stim in model_input_dict.items()
}

## Rotate the normalized input by 90 degrees for the models that were
## trained on rotated stimuli
R90_norm_model_input = {
    spd: rotate90(normed_stim)
    for spd, normed_stim in norm_model_input.items()
}

### Generate model responses, one speed at a time

In [209]:
# Speed to process in this run; must be one of `speeds` (4, 12, 20, 28).
speed = 28

# Use the GPU when one is available and fall back to the CPU otherwise, in
# line with the torch.cuda.is_available() guard used after inference.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# torch.Tensor(...) builds a tensor of torch's default (float32) type from
# the array. The .copy() hands torch a fresh array; presumably needed because
# the rotated array can be a non-contiguous view — TODO confirm.
R90_model_input_tor = torch.Tensor(R90_norm_model_input[speed].copy()).to(device)
model_input_tor = torch.Tensor(norm_model_input[speed].copy()).to(device)

In [153]:
# Generate model responses for every dataset, transition type and session.
#
# Result: prediction_dicts[dataset_hash][tt_idx][neuron_id] holds the slice
# predictions[:, :, model_readout_idx] of the session the neuron belongs to.

# The input tensor is consumed as 4 consecutive chunks of 1000 stimuli, one
# chunk per transition type.
n_transition_types = 4
n_stim_per_type = 1000
# Dataset whose models receive the 90-degree-rotated input.
rotated_dataset_hash = '5e620609fc7b491aa5edb4a5d4cd7276'

prediction_dicts = {
    dh: {tt: {} for tt in range(n_transition_types)} for dh in dh2eh.keys()
}
for dataset_hash in dataset_hashes:
    print(dataset_hash)
    ensemble_hash = dh2eh[dataset_hash]
    model_path = os.path.join(base_folder, "models", model_type, ensemble_hash)
    _, ensemble_model = load_ensemble_retina_model_from_directory(model_path)

    # Apply the Center transform to the loaded ensemble. Its return value is
    # not used, so it presumably modifies the model in place — TODO confirm.
    model_transform = Center(target_mean=[0., 0.])
    model_transform(ensemble_model)

    # Pick the input once per dataset instead of once per session.
    if dataset_hash == rotated_dataset_hash:
        dataset_input_tor = R90_model_input_tor
    else:
        dataset_input_tor = model_input_tor

    # Rows of this dataset; narrowed to one session inside the loop below.
    dataset_rows = df[df['dataset_hash'] == dataset_hash]

    for tt_idx in range(n_transition_types):
        tt_slice = slice(tt_idx * n_stim_per_type, (tt_idx + 1) * n_stim_per_type)
        prediction_dict = {}
        for session_id in dh_2_session_ids[dataset_hash]:
            # get model predictions for test stimulus
            print('   ', session_id)
            # The model is addressed without the leading token of the session
            # id, e.g. 'session_1_ventral1_20210929' -> '1_ventral1_20210929'.
            session_key = '_'.join(session_id.split('_')[1:])
            predictions = get_model_responses(
                ensemble_model,
                dataset_input_tor[tt_slice],
                session_key,
            ).squeeze()

            session_rows = dataset_rows[dataset_rows['session_id'] == session_id]
            current_neuron_ids = session_rows['neuron_id'].values
            model_readout_idxs = session_rows['model_readout_idx'].values

            # The last axis of `predictions` is indexed by the readout index.
            prediction_dict.update({
                neuron_id: predictions[:, :, readout_idx]
                for neuron_id, readout_idx in zip(current_neuron_ids, model_readout_idxs)
            })
        prediction_dicts[dataset_hash][tt_idx] = prediction_dict

    # Drop the model before loading the next one. Garbage collection runs
    # first so that memory freed by it can also be released from the CUDA
    # cache.
    del ensemble_model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

069836032a542cb20fe4c678dde07615
Random seed 0 has been set.
Ignored Missing keys:
core.features.layer1.conv._log_speed_default
core.features.layer0.conv._log_speed_default
Random seed 1000 has been set.
Ignored Missing keys:
core.features.layer1.conv._log_speed_default
core.features.layer0.conv._log_speed_default
Random seed 2000 has been set.
Ignored Missing keys:
core.features.layer1.conv._log_speed_default
core.features.layer0.conv._log_speed_default
Random seed 3000 has been set.
Ignored Missing keys:
core.features.layer1.conv._log_speed_default
core.features.layer0.conv._log_speed_default
Random seed 4000 has been set.
Ignored Missing keys:
core.features.layer1.conv._log_speed_default
core.features.layer0.conv._log_speed_default
    session_1_ventral1_20210929
    session_1_ventral1_20210930
    session_1_ventral2_20210929
    session_1_ventral2_20210930
    session_2_ventral1_20210929
    session_2_ventral2_20210929
    session_2_ventral2_20210930
    session_3_ventral2_20210929

### Process the model responses

Each model neuron received an input tensor of shape (4000, 2, 120, 18, 16) and produced a response tensor of shape (4000, 90).

#### Organize model responses (4000,90) to (4,1000,90), reshaped by transition types

In [154]:
## One zero-filled (4, 1000, 90) array per row label of df
full_movie_resp_dict = {}
for nid in df.index:
    full_movie_resp_dict[nid] = np.zeros((4, 1000, 90))

## Copy each neuron's predictions into the slot of its transition type.
## NOTE(review): assumes the neuron ids used as prediction keys also occur in
## df.index; otherwise the lookup below raises KeyError — confirm.
for dh in dh2eh.keys():
    for tt_idx in range(4):
        for nid, neuron_resp in prediction_dicts[dh][tt_idx].items():
            full_movie_resp_dict[nid][tt_idx] = neuron_resp

In [224]:
## save the full response dictionary
# makePickle(respGen_path+'s{}_full_movie_resp_dict.pkl'.format(speed),full_movie_resp_dict)

In [223]:
## Shortcut: load the pre-generated full response dictionary for the current
## `speed`. This rebinds full_movie_resp_dict, replacing anything computed above.
full_movie_resp_dict = unPickle(f'{respGen_path}s{speed}_full_movie_resp_dict.pkl')

#### Bin the responses

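Reduce the 90-datapoint response to each 120-frame transition snippet to a single scalar: the mean of response datapoints 30–59 of the snippet minus the value at datapoint 30 (this is what the binning cell below computes), followed by z-scoring across all snippets of a neuron. In this step, the transition type makes no difference.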

4,1000,90 -> 4000,1

In [156]:
## Collapse every response array into one flat vector, replacing the entries
## of the same dict object
for nid, resp in full_movie_resp_dict.items():
    full_movie_resp_dict[nid] = resp.flatten()

## Offset of the first sample of each 90-sample snippet in the flat vector
## (4 * 1000 snippets in total)
start_indices = 90 * np.arange(4 * 1000)

In [158]:
## Reduce every snippet to one scalar and z-score the scalars per neuron.
##
## For a snippet starting at flat index i the scalar is
##     mean(resp[i+30 : i+60]) - resp[i+30]
## i.e. the mean over response samples 30..59 of the snippet, relative to the
## value at sample 30.
win_start, win_stop = 30, 60

## Index matrix of shape (n_snippets, win_stop - win_start): row k lists the
## flat indices of the averaging window of snippet k. Computing all windows
## with one indexing operation replaces the per-snippet Python loop.
window_idx = start_indices[:, None] + np.arange(win_start, win_stop)
baseline_idx = start_indices + win_start

binned_movie_resp_dict = {}
for nid, resp in full_movie_resp_dict.items():
    resp = np.asarray(resp)
    temp = resp[window_idx].mean(axis=1) - resp[baseline_idx]
    ## z-score across snippets; a neuron whose scalars are all equal has
    ## std 0 and yields NaN here.
    binned_movie_resp_dict[nid] = (temp - temp.mean()) / temp.std()

In [159]:
# makePickle(respGen_path+'s{}_binned_movie_resp_dict.pkl'.format(speed),binned_movie_resp_dict)