## Ext. Data Figure 7: Model Selection and Control Data Analyses for the GLM-HMM

This notebook recreates the figure panels included in Extended Data Figure 7 of [Bolkan, Stone et al. 2021](https://www.biorxiv.org/content/10.1101/2021.07.23.453573v1).

In the context of the paper, this figure shows how we selected certain model parameters (most notably the number of latent states in the GLM-HMM) and presents the results of analyses on the control group: a cohort of no-opsin mice for which the laser was turned on for a subset of trials, exactly as in the experimental groups, but for which the laser should have no inhibitory effect on behavior.

#### Import the required code packages and modules

In [4]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.insert(0, '..')

import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import pickle
from glmhmm import glm, glm_hmm
from glmhmm.utils import find_best_fit, crossval_split



#### Load data

In [10]:
# load the data for the indirect pathway cohort
x_d2 = np.load('data/indirect_x.npy') # z-scored design matrix
y_d2 = np.load('data/indirect_y.npy') # vector of right and left choices for each trial
sessions_d2 = np.load('data/indirect_sessions.npy') # vector of session start and stop indices
mouseIDs_d2 = np.load('data/indirect_mouseIDs.npy') # vector of mouse IDs for each trial

# load the data for the direct pathway cohort
x_d1 = np.load('data/direct_x.npy') # z-scored design matrix
y_d1 = np.load('data/direct_y.npy') # vector of right and left choices for each trial
sessions_d1 = np.load('data/direct_sessions.npy') # vector of session start and stop indices
mouseIDs_d1 = np.load('data/direct_mouseIDs.npy') # vector of mouse IDs for each trial

# load the data for the control (no-opsin) cohort
x_ct = np.load('data/control_x.npy') # z-scored design matrix
y_ct = np.load('data/control_y.npy') # vector of right and left choices for each trial
sessions_ct = np.load('data/control_sessions.npy') # vector of session start and stop indices
mouseIDs_ct = np.load('data/control_mouseIDs.npy') # vector of mouse IDs for each trial
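Each cohort's `sessions` vector is described only as "session start and stop indices". The sketch below shows how such a vector can be used to split trial-level arrays into sessions, assuming (this is an assumption, not something stated in this notebook) that it is a 1-D array of boundaries where session `s` spans trials `sessions[s]:sessions[s+1]`. The helper `split_by_session` and the toy arrays are hypothetical.

```python
import numpy as np

def split_by_session(x, y, sessions):
    """Split trial-level arrays into per-session (x, y) chunks.

    ASSUMPTION: `sessions` is a 1-D array of boundary indices, so session s
    spans trials sessions[s]:sessions[s+1] (first entry 0, last entry = number of trials).
    """
    chunks = []
    for start, stop in zip(sessions[:-1], sessions[1:]):
        chunks.append((x[start:stop], y[start:stop]))
    return chunks

# toy example: 10 trials, 3 regressors, three sessions of 4, 3 and 3 trials
x_toy = np.arange(30, dtype=float).reshape(10, 3)
y_toy = np.array([0, 1, 1, 0, 1, 0, 0, 1, 1, 0])
sessions_toy = np.array([0, 4, 7, 10])

for s, (xs, ys) in enumerate(split_by_session(x_toy, y_toy, sessions_toy)):
    print(f"session {s}: {xs.shape[0]} trials, fraction right = {ys.mean():.2f}")
```

Keeping sessions intact matters for the GLM-HMM, since the latent state is assumed to evolve from trial to trial within a session.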

### Ext. Data Figure 7a: Cross-Validation across States
#### Split the data

In [9]:
seeds = [55,38,13,23,103]
xtrain_d2,xtest_d2,ytrain_d2,ytest_d2,sesstrain_d2,sesstest_d2,testIx_d2,_ = crossval_split(x_d2, y_d2,
                                                                                            sessions_d2,
                                                                                            mouseIDs_d2,
                                                                                            test_size=0.2, 
                                                                                            seeds=seeds)
        
seeds = [10,66,100,73,200]
xtrain_d1,xtest_d1,ytrain_d1,ytest_d1,sesstrain_d1,sesstest_d1,testIx_d1,_ = crossval_split(x_d1, y_d1,
                                                                                            sessions_d1,
                                                                                            mouseIDs_d1,
                                                                                            test_size=0.2, 
                                                                                            seeds=seeds)

seeds = [0,7,237,411,219]
xtrain_ct,xtest_ct,ytrain_ct,ytest_ct,sesstrain_ct,sesstest_ct,testIx_ct,_ = crossval_split(x_ct, y_ct,
                                                                                            sessions_ct,
                                                                                            mouseIDs_ct,
                                                                                            test_size=0.2, 
                                                                                            seeds=seeds)
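`crossval_split` takes the session indices and mouse IDs and returns training and test sets (plus training and test session indices) for each seed. Its internals are not shown here; the sketch below only illustrates the general idea of holding out whole sessions, under the assumption that sessions are the unit being held out. `session_holdout_split` is a hypothetical helper, not the `glmhmm.utils` implementation, and it ignores any per-mouse balancing the real function may do.

```python
import numpy as np

def session_holdout_split(n_sessions, test_size=0.2, seed=0):
    """Randomly assign whole sessions to a held-out test set.

    Hypothetical helper: holding out entire sessions (rather than single trials)
    keeps each session's trial sequence intact for the HMM.
    """
    rng = np.random.default_rng(seed)
    n_test = max(1, int(round(test_size * n_sessions)))
    order = rng.permutation(n_sessions)
    test_sessions = np.sort(order[:n_test])
    train_sessions = np.sort(order[n_test:])
    return train_sessions, test_sessions

# one split per seed, mirroring the five seeds used for each cohort above
for seed in [55, 38, 13, 23, 103]:
    train, test = session_holdout_split(n_sessions=20, test_size=0.2, seed=seed)
    print(seed, test)
```

Using a different seed per fold gives five different train/test partitions, which is what lets the held-out log-likelihood be averaged across folds for each number of states.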

#### Set the hyperparameters

In [11]:
C = 2 # number of observation classes
D = x_d2.shape[1] # number of GLM inputs (regressors)

#### Fit the training sets
This takes a foolish amount of time to run (almost a week: ~2 hours per fit on average × 5 numbers of states × 5 training sets × 3 cohorts ≈ 150 hours). The code is written here as a series of nested for loops for clarity and instructional purposes, but we do not advise running it this way. We strongly recommend pulling the body of the loops out and running each fit as a separate, parallel job on a remote server to cut down the computation time.
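One way to split the work into separate jobs is to enumerate every independent (cohort, number of states, fold) fit and let each job pick one entry by index. The sketch below assumes a SLURM cluster, where a job array sets the real `SLURM_ARRAY_TASK_ID` environment variable for each task; the script name `run_fit.sh` and the helper `job_for_index` are hypothetical.

```python
import itertools
import os

COHORTS = ["d2", "d1", "ct"]   # indirect, direct, control
STATES = range(1, 6)           # K = 1..5
FOLDS = range(5)               # five training/test splits per cohort

# every independent fit, as a flat list of (cohort, K, fold) tuples
JOBS = list(itertools.product(COHORTS, STATES, FOLDS))

def job_for_index(index):
    """Map a 0-based job-array index to one (cohort, K, fold) fit."""
    return JOBS[index]

if __name__ == "__main__":
    # e.g. `sbatch --array=0-74 run_fit.sh` (hypothetical script); each task sees
    # its own SLURM_ARRAY_TASK_ID. Falls back to 0 when run locally.
    index = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
    cohort, K, fold = job_for_index(index)
    print(f"job {index}: cohort={cohort}, K={K}, fold={fold}")
    # here you would load that cohort's fold, run the initializations,
    # keep the best one, and pickle it to a file named after (cohort, K, fold)
```

With 75 jobs running in parallel, the wall-clock time drops to roughly that of a single fit. The 20 initializations per fit could be split out as well if more workers are available.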

In [None]:
inits = 20 # number of random initializations per fit
maxiter = 250 # maximum number of iterations of EM to allow for each fit
tol = 1e-3 # convergence tolerance for EM
folds = len(seeds) # number of training/test splits (five per cohort)
states = 5 # fit models with K = 1, ..., 5 latent states

# store the best-fitting model for each number of states (rows) and fold (columns), per cohort
best_GLMHMMs_d2 = np.zeros((states,folds), dtype=object)
best_GLMHMMs_d1 = np.zeros((states,folds), dtype=object)
best_GLMHMMs_ct = np.zeros((states,folds), dtype=object)

for k in range(states):
    K = k+1 # number of latent states for this set of fits
    for j in range(folds):
        # number of training trials in fold j for each cohort
        N_d2 = xtrain_d2[j].shape[0]
        N_d1 = xtrain_d1[j].shape[0]
        N_ct = xtrain_ct[j].shape[0]

        # store values for each initialization
        lls_all_d2 = np.zeros((inits,maxiter))
        GLMHMMs_d2 = np.zeros((inits),dtype=object)
        lls_all_d1 = np.zeros((inits,maxiter))
        GLMHMMs_d1 = np.zeros((inits),dtype=object)
        lls_all_ct = np.zeros((inits,maxiter))
        GLMHMMs_ct = np.zeros((inits),dtype=object)

        # fit the models for each initialization
        for i in range(inits):
            
            if K == 1:
                # a 1-state model is a plain Bernoulli GLM (logistic regression);
                # only the fit from the last initialization is stored below
                ## indirect pathway cohort --------------------------------------------------------
                GLM_d2 = glm.GLM(N_d2,D,C,observations="bernoulli")
                w_init_d2 = GLM_d2.init_weights()
                w_d2, phi_d2 = GLM_d2.fit(xtrain_d2[j],w_init_d2,ytrain_d2[j])
                
                ## direct pathway cohort --------------------------------------------------------
                GLM_d1 = glm.GLM(N_d1,D,C,observations="bernoulli")
                w_init_d1 = GLM_d1.init_weights()
                w_d1, phi_d1 = GLM_d1.fit(xtrain_d1[j],w_init_d1,ytrain_d1[j])
                
                ## control cohort --------------------------------------------------------
                GLM_ct = glm.GLM(N_ct,D,C,observations="bernoulli")
                w_init_ct = GLM_ct.init_weights()
                w_ct, phi_ct = GLM_ct.fit(xtrain_ct[j],w_init_ct,ytrain_ct[j])
            
            else:
                ## indirect pathway cohort --------------------------------------------------------
                GLMHMMs_d2[i] = glm_hmm.GLMHMM(N_d2,D,C,K,observations="bernoulli",gaussianPrior=1)
                A_init,w_init,_ = GLMHMMs_d2[i].generate_params(weights=['GLM',-0.2,1.2,xtrain_d2[j],ytrain_d2[j],1])
                lls_all_d2[i,:],_,_,_ = GLMHMMs_d2[i].fit(ytrain_d2[j],xtrain_d2[j],A_init,w_init,
                                                          maxiter=maxiter,tol=tol,sess=sesstrain_d2[j])

                ## direct pathway cohort ----------------------------------------------------------
                GLMHMMs_d1[i] = glm_hmm.GLMHMM(N_d1,D,C,K,observations="bernoulli",gaussianPrior=1)
                A_init,w_init,_ = GLMHMMs_d1[i].generate_params(weights=['GLM',-0.2,1.2,xtrain_d1[j],ytrain_d1[j],1])
                lls_all_d1[i,:],_,_,_ = GLMHMMs_d1[i].fit(ytrain_d1[j],xtrain_d1[j],A_init,w_init,
                                                          maxiter=maxiter,tol=tol,sess=sesstrain_d1[j])

                ## control cohort ----------------------------------------------------------
                GLMHMMs_ct[i] = glm_hmm.GLMHMM(N_ct,D,C,K,observations="bernoulli",gaussianPrior=1)
                A_init,w_init,_ = GLMHMMs_ct[i].generate_params(weights=['GLM',-0.2,1.2,xtrain_ct[j],ytrain_ct[j],1])
                lls_all_ct[i,:],_,_,_ = GLMHMMs_ct[i].fit(ytrain_ct[j],xtrain_ct[j],A_init,w_init,
                                                          maxiter=maxiter,tol=tol,sess=sesstrain_ct[j])

        if K  == 1:
            best_GLMHMMs_d2[k,j] = GLM_d2
            best_GLMHMMs_d1[k,j] = GLM_d1
            best_GLMHMMs_ct[k,j] = GLM_ct
            
        else:
            # find the initialization that led to the best fit
            bestix_d2 = find_best_fit(lls_all_d2)
            best_GLMHMMs_d2[k,j] = GLMHMMs_d2[bestix_d2]
            bestix_d1 = find_best_fit(lls_all_d1)
            best_GLMHMMs_d1[k,j] = GLMHMMs_d1[bestix_d1]    
            bestix_ct = find_best_fit(lls_all_ct)
            best_GLMHMMs_ct[k,j] = GLMHMMs_ct[bestix_ct]  
    
# save results in case we want to use them again later
for cohort, models in [('d2', best_GLMHMMs_d2), ('d1', best_GLMHMMs_d1), ('ct', best_GLMHMMs_ct)]:
    with open(f'fit models/training_states_GLMHMMs_{cohort}.pickle', 'wb') as f:
        pickle.dump(models, f)
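`find_best_fit` picks, for each number of states and fold, the initialization whose EM run reached the best log-likelihood. The sketch below shows one way to do this; it assumes (not confirmed by this notebook) that row `i` of `lls_all` is the log-likelihood trace of initialization `i` and that entries after early convergence are left at the preallocated value 0. `best_initialization` is a hypothetical stand-in, not the `glmhmm.utils` source.

```python
import numpy as np

def best_initialization(lls_all):
    """Index of the initialization with the highest final log-likelihood.

    ASSUMPTION: row i holds the EM log-likelihood trace of initialization i,
    with trailing zeros after convergence (arrays preallocated with np.zeros).
    """
    final_lls = np.empty(lls_all.shape[0])
    for i, trace in enumerate(lls_all):
        nonzero = np.flatnonzero(trace)
        # last recorded value, or -inf if the trace is empty
        final_lls[i] = trace[nonzero[-1]] if nonzero.size else -np.inf
    return int(np.argmax(final_lls))

lls_toy = np.array([
    [-900.0, -850.0, -840.0, 0.0],     # converged after 3 iterations
    [-950.0, -820.0, -810.0, -805.0],  # ran all 4 iterations
    [-990.0, -980.0, 0.0, 0.0],        # stuck in a poorer local optimum
])
print(best_initialization(lls_toy))
```

Taking the last recorded value, rather than the maximum of the raw row, matters here: log-likelihoods are negative, so the zero padding would otherwise always look like the best value.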

### Ext. Data Figure 7a: Selecting the Number of States for the GLM-HMM
From the plots below, we see that the cross-validated log-likelihood starts to plateau around 3-4 states. From this, we decided to use the 3-state GLM-HMM for all analyses in the paper. For more on the 4-state GLM-HMM, see <b>Extended Data Figure 7d/e</b> further down in this notebook.
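One common way to put held-out fits for different numbers of states on a comparable scale (not necessarily the exact normalization used for this panel) is to report each model's test log-likelihood as an improvement over a bias-only Bernoulli baseline, in bits per trial:

$$\text{gain}_K = \frac{\mathcal{L}_K^{\text{test}} - \mathcal{L}_0^{\text{test}}}{N_{\text{test}} \ln 2}.$$

A minimal sketch, assuming the test log-likelihoods (in nats) have already been computed for each K. `bernoulli_baseline_ll` and `bits_per_trial` are hypothetical helpers, and the log-likelihood values are invented for illustration, not results from the paper.

```python
import numpy as np

def bernoulli_baseline_ll(y_train, y_test):
    """Test log-likelihood (nats) of a model that predicts every choice with
    the training-set fraction of right choices (y == 1)."""
    p_right = np.mean(y_train)
    return np.sum(y_test * np.log(p_right) + (1 - y_test) * np.log(1 - p_right))

def bits_per_trial(ll_test, ll_baseline, n_test):
    """Improvement over the baseline, in bits per held-out trial."""
    return (ll_test - ll_baseline) / (n_test * np.log(2))

rng = np.random.default_rng(0)
y_train = rng.integers(0, 2, size=800)
y_test = rng.integers(0, 2, size=200)
ll0 = bernoulli_baseline_ll(y_train, y_test)

# made-up test log-likelihoods for K = 1..5 states, for illustration only
ll_test_by_K = ll0 + np.array([20.0, 35.0, 42.0, 44.0, 44.5])
gains = bits_per_trial(ll_test_by_K, ll0, n_test=y_test.size)
print(np.round(gains, 3))
print(np.round(np.diff(gains), 3))  # marginal gain from adding one more state
```

A plateau shows up as marginal gains that shrink toward zero as K grows.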