# Poisson HMM - model comparison across transition-matrix stickiness (kappa)

In [1]:
""" 
IMPORTS
"""
import os
import autograd.numpy as np
import pickle
from sklearn.preprocessing import StandardScaler, MinMaxScaler, Normalizer
import seaborn as sns
from collections import defaultdict
import pandas as pd

from one.api import ONE
from jax import vmap
from pprint import pprint
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
from dynamax.hidden_markov_model import PoissonHMM
import concurrent.futures

# Get my functions
functions_path =  '/home/ines/repositories/representation_learning_variability/Models/Sub-trial//2_fit_models/'
#functions_path = '/Users/ineslaranjeira/Documents/Repositories/representation_learning_variability//Models/Sub-trial//2_fit_models/'
os.chdir(functions_path)
from preprocessing_functions import idxs_from_files, prepro_design_matrix, concatenate_sessions
from fitting_functions import cross_validate_poismodel, compute_inputs
functions_path =  '/home/ines/repositories/representation_learning_variability/Models/Sub-trial//3_postprocess_results/'
#functions_path = '/Users/ineslaranjeira/Documents/Repositories/representation_learning_variability//Models/Sub-trial//2_fit_models/'
os.chdir(functions_path)
from plotting_functions import plot_transition_mat, plot_states_aligned, params_to_df, align_bin_design_matrix, states_per_trial_phase, plot_states_aligned_trial, traces_over_sates

one = ONE(mode="remote")

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


## Binning and plotting parameters

In [2]:
# Width of one time bin of the preprocessed design matrices; it also selects the
# Results/<bin_size>/ data folder used below. Presumably in seconds (other option used: 0.1).
bin_size = 0.02

# Plotting params: bins per unit of time (inverse of the bin width)
multiplier = 1/bin_size

event_type_list = ['goCueTrigger_times']  # , 'feedback_times', 'firstMovement_times'
event_type_name = ['Go cue']  # , 'Feedback time', 'First movement onset'


In [3]:
# Load preprocessed data
prepro_results_path =  '/home/ines/repositories/representation_learning_variability/DATA/Sub-trial/Results/' + str(bin_size) + '/'
os.chdir(prepro_results_path)
idxs, mouse_names, matrix_all, matrix_all_unnorm, session_all = pickle.load(open(prepro_results_path + "preprocessed_data_v4_170724", "rb"))
collapsed_matrices, collapsed_unnorm, collapsed_trials = concatenate_sessions (mouse_names, matrix_all, matrix_all_unnorm, session_all)

## Model-fitting parameters

In [4]:
num_iters = 100
# Number of equal-length chunks the time series is split into for cross-validation
num_train_batches = 5
method = 'kmeans'
threshold = 0.05


# Feature groups that can be modelled. Entry i of each list below describes use_sets[i]:
#   var_interest_map[i] -> name used to look the set up (its first variable)
#   idx_init_list[i]    -> first design-matrix column of the set
#   idx_end_list[i]     -> one past the last design-matrix column of the set
use_sets = [['avg_wheel_vel'], ['Lick count'], ['whisker_me'],
            ['left_X', 'left_Y', 'right_X', 'right_Y'], ['nose_X', 'nose_Y']]
var_interest_map = ['avg_wheel_vel', 'Lick count', 'whisker_me', 'left_X', 'nose_X']
idx_init_list = [0, 1, 2, 3, 7]
idx_end_list = [1, 2, 3, 7, 9]

# These lists are indexed in parallel; fail loudly if an edit makes them disagree.
assert len(use_sets) == len(var_interest_map) == len(idx_init_list) == len(idx_end_list)
assert all(name == var_set[0] for var_set, name in zip(use_sets, var_interest_map)), \
    'var_interest_map must list the first variable of each use_sets entry'
assert all(end - start == len(var_set)
           for var_set, start, end in zip(use_sets, idx_init_list, idx_end_list)), \
    'column ranges must match the number of variables in each use_sets entry'

In [5]:
def grid_search_kappa(mouse_name, var_interest='Lick count', num_states=2, kappas=None,
                      concatenate=True, results_dir=None):
    """Cross-validate a Poisson HMM over a grid of transition stickiness values for one mouse.

    For every kappa, a ``PoissonHMM(num_states, emission_dim,
    transition_matrix_stickiness=kappa)`` is evaluated with ``cross_validate_poismodel`` on the
    design-matrix columns of the mouse that belong to ``var_interest``. The collected results are
    pickled to ``<results_dir>/best_results_<first variable of the set>_<mouse_name>``.

    Parameters
    ----------
    mouse_name : str
        Key into the notebook-level ``collapsed_matrices``.
    var_interest : str
        Entry of ``var_interest_map`` choosing which feature set (design-matrix columns) to model.
    num_states : int
        Number of hidden states of the HMM.
    kappas : sequence of numbers, optional
        Stickiness values to try. Defaults to
        ``[0, 1, 5, 10, 100, 500, 1000, 2000, 5000, 7000, 10000]``.
    concatenate : bool
        Use the per-mouse concatenated sessions. Only ``True`` is implemented.
    results_dir : str, optional
        Output directory. Defaults to ``.../Results/<bin_size>/grid_search/``.

    Returns
    -------
    tuple
        ``(all_lls, all_baseline_lls, all_init_params, all_fit_params)``, each a dict keyed by
        kappa (the same object that is pickled).

    Raises
    ------
    NotImplementedError
        If ``concatenate`` is False.
    ValueError
        If ``var_interest`` is not in ``var_interest_map`` or there are fewer timesteps than
        ``num_train_batches``.
    """
    if kappas is None:
        kappas = [0, 1, 5, 10, 100, 500, 1000, 2000, 5000, 7000, 10000]
    if not concatenate:
        # Per-session fitting was never implemented; fail clearly instead of hitting a NameError
        # on an undefined design matrix further down.
        raise NotImplementedError('Not ready for non-concatenated sessions')
    if results_dir is None:
        results_dir = ('/home/ines/repositories/representation_learning_variability/DATA/Sub-trial/Results/'
                       + str(bin_size) + '/grid_search/')

    # Look up which design-matrix columns hold the requested feature set
    index_var = var_interest_map.index(var_interest)
    idx_init = idx_init_list[index_var]
    idx_end = idx_end_list[index_var]
    var_names = use_sets[index_var]

    # Initialize containers for results (keyed by kappa)
    all_init_params = defaultdict(list)
    all_fit_params = defaultdict(list)
    all_lls = defaultdict(list)
    all_baseline_lls = defaultdict(list)

    print('Fitting mouse ' + mouse_name)

    # Get mouse data
    design_matrix = collapsed_matrices[mouse_name][:, idx_init:idx_end]
    if len(np.shape(design_matrix)) > 2:
        design_matrix = design_matrix[0]

    # Prepare data for cross-validation: trim the time axis so it splits evenly into
    # num_train_batches chunks, then stack the chunks along a new leading axis.
    num_timesteps = np.shape(design_matrix)[0]
    emission_dim = np.shape(design_matrix)[1]
    if num_timesteps < num_train_batches:
        raise ValueError(f'{mouse_name}: {num_timesteps} timesteps cannot be split into '
                         f'{num_train_batches} cross-validation batches')
    shortened_array = np.array(design_matrix[:(num_timesteps // num_train_batches) * num_train_batches])
    train_emissions = jnp.stack(jnp.split(shortened_array, num_train_batches))

    # Fit model with cross-validation for each kappa; the same PRNG key is used for every kappa
    # so that only the stickiness differs between fits.
    for kappa in kappas:
        print(f"fitting model with {kappa} kappa")

        poisson_hmm = PoissonHMM(num_states, emission_dim, transition_matrix_stickiness=kappa)
        all_val_lls, fit_params, init_params, baseline_lls = cross_validate_poismodel(
            poisson_hmm, jr.PRNGKey(0), shortened_array, train_emissions, num_train_batches)

        all_lls[kappa] = all_val_lls
        all_baseline_lls[kappa] = baseline_lls
        all_init_params[kappa] = init_params
        all_fit_params[kappa] = fit_params

    mouse_results = all_lls, all_baseline_lls, all_init_params, all_fit_params

    # Save the grid-search results. The file is written via an absolute path because the working
    # directory is process-wide and unsafe to rely on when this runs in worker threads; the chdir
    # is kept only for cells that may expect the cwd to be the results folder.
    os.chdir(results_dir)
    out_path = os.path.join(results_dir, "best_results_" + var_names[0] + '_' + mouse_name)
    with open(out_path, "wb") as out_file:
        pickle.dump(mouse_results, out_file)
    return mouse_results

                
def parallel_process_data(sessions, function_name, max_workers=None):
    """Apply ``function_name`` to every item of ``sessions`` in a thread pool.

    The results iterator returned by ``executor.map`` is consumed eagerly. An unconsumed
    iterator never retrieves the workers' futures, so any exception raised inside
    ``function_name`` would be silently discarded; consuming it re-raises such exceptions here.

    Parameters
    ----------
    sessions : iterable
        Items to process (in this notebook, mouse names).
    function_name : callable
        Function called once per item.
    max_workers : int, optional
        Passed to ``ThreadPoolExecutor``; ``None`` keeps the executor's default.

    Returns
    -------
    list
        ``function_name`` results, in the same order as ``sessions``.
    """
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        return list(executor.map(function_name, sessions))
        

In [7]:
# Loop through animals, fitting one mouse at a time.
# The fit is called directly instead of through a one-item thread pool: the pool added no
# concurrency, discarded any exception raised during fitting, and let a fit keep running in the
# background after a KeyboardInterrupt.
function_name = grid_search_kappa

for mouse in mouse_names:
    function_name(mouse)

Fitting mouse CSHL045
fitting model with 0 kappa
fitting model with 1 kappa
fitting model with 5 kappa
fitting model with 10 kappa
fitting model with 100 kappa
fitting model with 500 kappa
fitting model with 1000 kappa
fitting model with 2000 kappa
fitting model with 5000 kappa
fitting model with 7000 kappa
fitting model with 10000 kappa
Fitting mouse ibl_witten_25
fitting model with 0 kappa
fitting model with 1 kappa
fitting model with 5 kappa
fitting model with 10 kappa
fitting model with 100 kappa
fitting model with 500 kappa
fitting model with 1000 kappa
fitting model with 2000 kappa
fitting model with 5000 kappa
fitting model with 7000 kappa
fitting model with 10000 kappa
Fitting mouse ibl_witten_29
fitting model with 0 kappa
fitting model with 1 kappa
fitting model with 5 kappa
fitting model with 10 kappa
fitting model with 100 kappa
fitting model with 500 kappa
fitting model with 1000 kappa
fitting model with 2000 kappa
fitting model with 5000 kappa
fitting model with 7000 kappa

KeyboardInterrupt: 

fitting model with 10000 kappa
