# Combination of fits of different models

In [1]:

""" 
IMPORTS
"""
import os
import autograd.numpy as np
import pickle
from sklearn.preprocessing import StandardScaler, MinMaxScaler, Normalizer
import seaborn as sns
from collections import defaultdict
import pandas as pd

from one.api import ONE
from jax import vmap
from pprint import pprint
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
# from dynamax.hidden_markov_model import GaussianHMM
from dynamax.hidden_markov_model import LinearAutoregressiveHMM
from dynamax.hidden_markov_model import PoissonHMM

from dynamax.utils.plotting import gradient_cmap
from dynamax.utils.utils import random_rotation

# Get my functions
# The working directory is switched to functions_path before the three local
# imports below, so they are presumably resolved from that folder (verify that
# preprocessing_functions / fitting_functions / plotting_functions live there).
# NOTE(review): this is an absolute, machine-specific path; the commented-out
# line is the alternative location on another machine.
functions_path =  '/home/ines/repositories/representation_learning_variability/Models/Sub-trial//2_fit_models/'
#functions_path = '/Users/ineslaranjeira/Documents/Repositories/representation_learning_variability//Models/Sub-trial//2_fit_models/'
os.chdir(functions_path)
from preprocessing_functions import idxs_from_files, prepro_design_matrix, concatenate_sessions, fix_discontinuities
from fitting_functions import cross_validate_armodel, compute_inputs
from plotting_functions import plot_transition_mat, plot_states_aligned, align_bin_design_matrix, states_per_trial_phase, plot_states_aligned_trial, traces_over_sates, traces_over_few_sates

# ONE client instance, passed to the preprocessing helpers in later cells.
# NOTE(review): the recorded output of this cell shows a cache download being
# started on instantiation, so this line may need network access.
one = ONE()

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Downloading: /home/ines/Downloads/ONE/alyx.internationalbrainlab.org/tmphw4hchoo/cache.zip Bytes: 127561276


 10%|▉         | 12.0/121.65191268920898 [05:39<51:41, 28.28s/it] 


KeyboardInterrupt: 

## Parameters

In [None]:
# Width of one time bin. It is also used (as a string) to select the
# design-matrix sub-folder when data_path is built.
# Presumably expressed in seconds - TODO confirm.
bin_size = 0.1

# Plotting params
# Number of bins per unit of time (reciprocal of bin_size).
multiplier = 1/bin_size

# Trial events of interest and their display labels (the two lists are parallel).
# Other candidate events / labels are listed in the trailing comments.
event_type_list = ['goCueTrigger_times']  # , 'feedback_times', 'firstMovement_times'
event_type_name = ['Go cue']  # , 'Feedback time', 'First movement onset'


## Data path

In [None]:
# Folder holding the design matrices for the chosen bin size.
# NOTE(review): absolute, machine-specific path.
data_path =  '/home/ines/repositories/representation_learning_variability/DATA/Sub-trial/Design matrix/' + 'v4_5Jul2024/' + str(bin_size) + '/'

os.chdir(data_path)
design_matrices = os.listdir(data_path)

# Variable sets; one model is handled per set in the fitting loop.
# The ['Lick count'] set is intentionally left out of this list.
var_sets = [['avg_wheel_vel'], ['nose_X', 'nose_Y'],
            ['left_X', 'left_Y', 'right_X', 'right_Y'], ['whisker_me']]

# Load frame rate
fr_path =  '/home/ines/repositories/representation_learning_variability/DATA/Sub-trial/Design matrix/frame_rate/'
os.chdir(fr_path)
# Use a context manager so the file handle is closed deterministically.
# NOTE(review): pickle executes arbitrary code on load; only open trusted files.
with open(fr_path + "frame_rate", "rb") as frame_rate_file:
    frame_rate = pickle.load(frame_rate_file)


## Get mice list

In [None]:
# Build the list of session entries (idxs) and of mouse names from the
# design-matrix files. The fitting loop slices each idxs entry as
# '<36-character session id>_<mouse name>'; the exact contract of
# idxs_from_files is not visible in this notebook - confirm in
# preprocessing_functions.
idxs, mouse_names = idxs_from_files(one, design_matrices, frame_rate, data_path, bin_size)

local md5 mismatch on dataset: churchlandlab/Subjects/CSHL049/2020-01-11/001/alf/_ibl_trials.stimOff_times.npy
/home/ines/Downloads/ONE/alyx.internationalbrainlab.org/churchlandlab/Subjects/CSHL049/2020-01-11/001/alf/_ibl_trials.stimOff_times.npy: 100%|██████████| 4.60k/4.60k [00:00<00:00, 9.93kB/s]
local md5 mismatch on dataset: cortexlab/Subjects/KS014/2019-12-03/001/alf/_ibl_trials.goCueTrigger_times.npy
/home/ines/Downloads/ONE/alyx.internationalbrainlab.org/cortexlab/Subjects/KS014/2019-12-03/001/alf/_ibl_trials.goCueTrigger_times.npy: 100%|██████████| 4.38k/4.38k [00:00<00:00, 9.45kB/s]


## Parameters

In [None]:
# NOTE(review): num_iters, threshold, num_lags and kappa are not referenced by
# the cells visible in this notebook; they may be leftovers from the grid search.
num_iters = 100
# Number of equal-length folds the design matrix is split into.
num_train_batches = 5
# Initialisation method forwarded to LinearAutoregressiveHMM.initialize.
method = 'kmeans'
threshold = 0.05

# NOTE(review): 11 lags is not a first-order model; the fit in this notebook
# uses the best lag selected from the grid search rather than this value.
num_lags = 11
# Number of hidden states of the ARHMM built in the fitting loop.
num_states = 2
kappa = 1000

# Plotting params
# Same value as the bin_size defined in the first parameters cell.
bin_size = 0.1

# Values for grid search (should get this from results)
# last_lag is exclusive: with these settings Lags = [1, 3, 5, ..., 19].
last_lag = 20
lag_step = 2
start_lag = 1
Lags = list(range(start_lag, last_lag, lag_step))
# Candidate stickiness values; the selected one is passed as
# transition_matrix_stickiness when the model is rebuilt.
kappas = [0, 1, 5, 10, 100, 500, 1000, 2000, 5000, 7000, 10000]

## Fit sessions with params from best fit for all variable sets (models)

In [None]:
def best_lag_kappa(all_lls, all_baseline_lls, design_matrix, num_train_batches, kappas, Lags):
    """Select the (lag, kappa) pair with the highest mean normalised log-likelihood.

    For every lag in ``Lags`` and every kappa in ``kappas`` the per-fold
    log-likelihoods are compared against the baseline, normalised by the fold
    length and converted to bits (assuming the inputs are natural-log
    likelihoods), then averaged over folds ignoring NaNs.

    Parameters
    ----------
    all_lls : mapping
        ``all_lls[lag][kappa]`` is a 1-D array of log-likelihoods, one per fold.
    all_baseline_lls : mapping
        Same layout as ``all_lls``, holding the baseline log-likelihoods.
    design_matrix : array-like
        Only its first dimension (number of time steps) is used, to derive the
        fold length.
    num_train_batches : int
        Number of folds the time steps were split into.
    kappas : sequence
        Grid of kappa values (second-level keys of ``all_lls``).
    Lags : sequence
        Grid of lag values (first-level keys of ``all_lls``).

    Returns
    -------
    best_lag, best_kappa
        Grid values at the maximum of ``mean_bits_LL`` (first maximum in
        row-major order when there are ties).
    mean_bits_LL : ndarray, shape (len(Lags), len(kappas))
        Fold-averaged normalised log-likelihood, in bits per time step.
    best_fold : ndarray, shape (len(Lags), len(kappas))
        Index of the fold with the highest log-likelihood for each grid point
        (stored as floats).

    Raises
    ------
    ValueError
        If the design matrix has fewer time steps than ``num_train_batches``,
        or if all log-likelihoods of a grid point (or of the whole grid) are NaN.
    """
    # Length of one fold: trailing time steps that do not fill a fold are dropped
    num_timesteps = np.shape(design_matrix)[0]
    fold_len = num_timesteps // num_train_batches
    if fold_len == 0:
        raise ValueError("design_matrix has fewer time steps than num_train_batches")

    grid_shape = (len(Lags), len(kappas))
    mean_bits_LL = np.full(grid_shape, np.nan)
    best_fold = np.full(grid_shape, np.nan)

    for l, lag in enumerate(Lags):
        # Stack the per-kappa fold arrays into (len(kappas), num_folds)
        val_lls = np.array([all_lls[lag][k] for k in kappas])
        baseline_lls = np.array([all_baseline_lls[lag][k] for k in kappas])

        # nats -> bits requires dividing by ln(2); the divisor is parenthesised
        # so that ln(2) is not applied as a multiplier.
        bits_LL = (val_lls - baseline_lls) / (fold_len * np.log(2))

        mean_bits_LL[l, :] = np.nanmean(bits_LL, axis=1)
        # Fold with the highest log-likelihood for each kappa (NaNs ignored)
        best_fold[l, :] = [np.nanargmax(all_lls[lag][k]) for k in kappas]

    # Location of the best grid point (first maximum in row-major order)
    lag_idx, kappa_idx = np.unravel_index(np.nanargmax(mean_bits_LL), mean_bits_LL.shape)
    best_lag = Lags[lag_idx]
    best_kappa = kappas[kappa_idx]

    return best_lag, best_kappa, mean_bits_LL, best_fold

In [20]:

# For every variable set: load the grid-search results, pick the best
# (lag, kappa) pair, rebuild the ARHMM with the parameters of the best fold and
# save the most likely state sequence.
# NOTE(review): absolute, machine-specific path.
results_path =  '/home/ines/repositories/representation_learning_variability/DATA/Sub-trial/Results/'

# Iterate over `var_names` directly; a loop variable named `set` would shadow
# the Python builtin for the rest of the kernel session.
for s, var_names in enumerate(var_sets):

    # Get data for all mice for the model of interest
    os.chdir(data_path)
    matrix_all, matrix_all_unnorm, session_all = prepro_design_matrix(one, idxs, mouse_names,
                                                                      bin_size, var_names, data_path, first_90=True)
    collapsed_matrices, collapsed_unnorm, collapsed_trials = concatenate_sessions(mouse_names, matrix_all,
                                                                                   matrix_all_unnorm, session_all)

    # Loop through animals
    for m, mat in enumerate(idxs[0:1]):
        if len(mat) > 35:
            # NOTE(review): the loop entry is overridden with a fixed session, so
            # only this session/mouse is ever processed (and only one iteration
            # runs because of idxs[0:1]). Looks like a debugging leftover -
            # confirm before running on all animals.
            mat = '46794e05-3f6a-4d35-afb3-9165091a5a74_CSHL045'

            # Entry layout: 36-character session id, separator, mouse name
            mouse_name = mat[37:]
            session = mat[0:36]

            # Get design_matrix
            design_matrix = collapsed_matrices[mouse_name]
            if len(np.shape(design_matrix)) > 2:
                design_matrix = design_matrix[0]

            # Get results from grid search
            os.chdir(results_path)
            # NOTE(review): pickle executes arbitrary code on load; only open trusted files.
            with open("best_results_" + var_names[0] + '_' + mouse_name, "rb") as results_file:
                all_lls, all_baseline_lls, all_init_params, all_fit_params = pickle.load(results_file)

            # Retrieve best fits
            best_lag, best_kappa, mean_bits_LL, best_fold = best_lag_kappa(all_lls, all_baseline_lls, design_matrix,
                                                                           num_train_batches, kappas, Lags)
            index_lag = np.where(np.array(Lags)==best_lag)[0][0]
            index_kappa = np.where(np.array(kappas)==best_kappa)[0][0]
            use_fold = int(best_fold[index_lag, index_kappa])

            # Fit model with best params
            # Prepare data: drop trailing time steps so the data splits evenly into folds
            num_timesteps = np.shape(design_matrix)[0]
            emission_dim = np.shape(design_matrix)[1]
            shortened_array = np.array(design_matrix[:(num_timesteps // num_train_batches) * num_train_batches])
            train_emissions = jnp.stack(jnp.split(shortened_array, num_train_batches))

            # Compute inputs for required timelags
            my_inputs = compute_inputs(shortened_array, best_lag, emission_dim)
            train_inputs = jnp.stack(jnp.split(my_inputs, num_train_batches))

            best_params = all_fit_params[best_lag][best_kappa]

            # Find parameters for best fold (the leading axis of each parameter indexes folds)
            initial_probs = best_params[0].probs[use_fold]
            transition_matrix = best_params[1].transition_matrix[use_fold]
            emission_weights = best_params[2].weights[use_fold]
            emission_biases = best_params[2].biases[use_fold]
            emission_covariances = best_params[2].covs[use_fold]

            # Initialize new hmm with the best-fold parameters
            new_arhmm = LinearAutoregressiveHMM(num_states, emission_dim, num_lags=best_lag,
                                                transition_matrix_stickiness=best_kappa)
            best_fold_params, props = new_arhmm.initialize(key=jr.PRNGKey(0), method=method,
                                                           initial_probs=initial_probs,
                                                           transition_matrix=transition_matrix,
                                                           emission_weights=emission_weights,
                                                           emission_biases=emission_biases,
                                                           emission_covariances=emission_covariances,
                                                           emissions=shortened_array)  # not sure if I need to include

            # Get state estimates for the full (shortened) data
            most_likely_states = new_arhmm.most_likely_states(best_fold_params, shortened_array, my_inputs)

            # Save most_likely_states; the context manager guarantees the file
            # is flushed and closed before the next iteration.
            os.chdir(results_path)
            with open("most_likely_states" + var_names[0] + '_' + mouse_name, "wb") as states_file:
                pickle.dump(most_likely_states, states_file)
        

local md5 mismatch on dataset: churchlandlab/Subjects/CSHL049/2020-01-11/001/alf/_ibl_trials.stimOff_times.npy
/home/ines/Downloads/ONE/alyx.internationalbrainlab.org/churchlandlab/Subjects/CSHL049/2020-01-11/001/alf/_ibl_trials.stimOff_times.npy: 100%|██████████| 4.60k/4.60k [00:00<00:00, 10.0kB/s]
local md5 mismatch on dataset: cortexlab/Subjects/KS014/2019-12-03/001/alf/_ibl_trials.goCueTrigger_times.npy
/home/ines/Downloads/ONE/alyx.internationalbrainlab.org/cortexlab/Subjects/KS014/2019-12-03/001/alf/_ibl_trials.goCueTrigger_times.npy: 100%|██████████| 4.38k/4.38k [00:00<00:00, 6.69kB/s]
local md5 mismatch on dataset: churchlandlab/Subjects/CSHL049/2020-01-11/001/alf/_ibl_trials.stimOff_times.npy
/home/ines/Downloads/ONE/alyx.internationalbrainlab.org/churchlandlab/Subjects/CSHL049/2020-01-11/001/alf/_ibl_trials.stimOff_times.npy: 100%|██████████| 4.60k/4.60k [00:00<00:00, 9.16kB/s]
local md5 mismatch on dataset: cortexlab/Subjects/KS014/2019-12-03/001/alf/_ibl_trials.goCueTrigger

In [22]:
# Reload the saved most likely state sequence.
# NOTE: relies on `var_names` and `mouse_name` left over from the last
# iteration of the fitting loop, so that cell must have been run first.
results_path =  '/home/ines/repositories/representation_learning_variability/DATA/Sub-trial/Results/'
os.chdir(results_path)
# Context manager closes the file handle; only unpickle trusted files.
with open("most_likely_states" + var_names[0] + '_' + mouse_name, "rb") as states_file:
    most_likely_states = pickle.load(states_file)
