## Figure 5: Comparing GLM vs. GLM-HMM Model Performance

This notebook recreates the figure panels included in Figure 5 of [Bolkan, Stone et al. 2021](https://www.biorxiv.org/content/10.1101/2021.07.23.453573v1). It also serves as a tutorial for users who want to compare GLM and GLM-HMM performance for models fit to their own experimental data.

The general premise of this notebook/figure, in the context of the paper, is that we fit a Bernoulli GLM to our data and found that a single GLM does not explain the data well. Instead, we considered a model (called a GLM-HMM) in which the animals' decision-making process is described by multiple GLMs, each corresponding to a different internal state or task strategy. After testing how many states best describe the data (see <code>extdatafig7.ipynb</code> for details), we settled on a 3-state GLM-HMM. Below, we compare how the 3-state GLM-HMM performs relative to the standard GLM. 

We will conclude at the end of this notebook that the 3-state GLM-HMM performs better than the GLM, and so we'll stick with that model for all subsequent paper analyses. 
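To make the model concrete, the sketch below simulates choices from a Bernoulli GLM-HMM in plain NumPy: a hidden state follows a Markov chain with transition matrix `A`, and on each trial the active state's weight vector turns the inputs into a probability of a rightward choice. The function name `simulate_glmhmm`, the two-feature inputs, and all parameter values are illustrative assumptions, not the fitted parameters from the paper or the `glmhmm` package's API.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def simulate_glmhmm(X, W, A, pi0, rng):
    """Sample hidden states and binary choices from a Bernoulli GLM-HMM (illustrative).

    X   : (N, M) inputs, one row per trial
    W   : (K, M) GLM weights, one row per hidden state
    A   : (K, K) transitions, A[i, j] = P(z_t = j | z_{t-1} = i)
    pi0 : (K,) initial state distribution
    """
    N, K = X.shape[0], W.shape[0]
    z = np.empty(N, dtype=int)
    y = np.empty(N, dtype=int)
    for t in range(N):
        probs = pi0 if t == 0 else A[z[t - 1]]
        z[t] = rng.choice(K, p=probs)            # draw this trial's hidden state
        p_right = sigmoid(X[t] @ W[z[t]])        # state-specific Bernoulli GLM
        y[t] = int(rng.random() < p_right)       # 1 = right choice, 0 = left choice
    return z, y

rng = np.random.default_rng(0)
N, M, K = 500, 2, 3
X = np.column_stack([rng.standard_normal(N), np.ones(N)])  # one stimulus feature + bias
W = np.array([[3.0, 0.0],     # hypothetical state: strong stimulus dependence
              [0.5, 1.5],     # hypothetical state: rightward bias
              [0.5, -1.5]])   # hypothetical state: leftward bias
A = np.array([[0.98, 0.01, 0.01],
              [0.02, 0.97, 0.01],
              [0.02, 0.01, 0.97]])
pi0 = np.ones(K) / K
z, y = simulate_glmhmm(X, W, A, pi0, rng)
print("fraction of trials with a state switch:", np.mean(z[1:] != z[:-1]))
```

Because the transition matrix is strongly diagonal, the simulated states persist over many trials; a single GLM, with one fixed set of weights, has no way to represent this kind of switching between strategies.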

------------------
### Running Cross-Validation
For a few of the figure panels, we train the models on training data and then compare their performance on held-out test data. This requires us to first run cross-validation for both models.
#### Import the required code packages and modules.

In [128]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.insert(0, '..')

import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
from glmhmm import glm, glm_hmm

#### Load the data

In [154]:
# load the data for the indirect pathway cohort
x_d2 = np.load('data/indirect_x.npy') # z-scored design matrix
y_d2 = np.load('data/indirect_y.npy') # vector of right and left choices for each trial
sessions_d2 = np.load('data/indirect_sessions.npy') # vector of session start and stop indices
mouseIDs_d2 = np.load('data/indirect_mouseIDs.npy') # vector of mouse IDs for each trial

# load the data for the direct pathway cohort
x_d1 = np.load('data/direct_x.npy') # z-scored design matrix
y_d1 = np.load('data/direct_y.npy') # vector of right and left choices for each trial
sessions_d1 = np.load('data/direct_sessions.npy') # vector of session start and stop indices
mouseIDs_d1 = np.load('data/direct_mouseIDs.npy') # vector of mouse IDs for each trial

#### Split the data

Now, let's split our data into train and test sets. This can be a little tricky with real data, because we don't want to split the data randomly. Instead, we want to preserve session structure (rather than splitting trials within sessions). Because of individual differences between animals, we also want to balance the test sets so that they contain approximately the same number of sessions per mouse. 
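A session-preserving, mouse-balanced split can be sketched in a few lines of NumPy. This is not the `splitData` function imported below (whose internals are not shown here); `split_sessions_by_mouse` is a hypothetical helper, and it assumes session boundaries are stored as start indices followed by a final end index (e.g. `[0, 120, 245, N]`), which may not match the format of `indirect_sessions.npy`.

```python
import numpy as np

def split_sessions_by_mouse(session_bounds, mouse_ids, test_size=0.2, seed=0):
    """Hold out whole sessions, drawing roughly `test_size` of each mouse's sessions.

    session_bounds : (S+1,) trial indices; session s spans [bounds[s], bounds[s+1])
    mouse_ids      : (N,) mouse ID for every trial
    Returns (train_ix, test_ix) as sorted arrays of trial indices.
    """
    rng = np.random.default_rng(seed)
    starts, stops = session_bounds[:-1], session_bounds[1:]
    session_mouse = mouse_ids[starts]  # each session belongs to a single mouse
    test_sessions = []
    for mouse in np.unique(session_mouse):
        mouse_sessions = np.flatnonzero(session_mouse == mouse)
        # at least one test session per mouse (a mouse with one session goes entirely to test)
        n_test = max(1, int(round(test_size * len(mouse_sessions))))
        test_sessions.extend(rng.choice(mouse_sessions, size=n_test, replace=False))
    is_test = np.zeros(len(starts), dtype=bool)
    is_test[test_sessions] = True

    def trials_of(mask):
        ranges = [np.arange(a, b) for a, b in zip(starts[mask], stops[mask])]
        return np.concatenate(ranges) if ranges else np.array([], dtype=int)

    return trials_of(~is_test), trials_of(is_test)

# toy example: 6 sessions, 3 per mouse
bounds = np.array([0, 10, 25, 40, 50, 70, 90])
mice = np.repeat([1, 1, 1, 2, 2, 2], np.diff(bounds))
train_ix, test_ix = split_sessions_by_mouse(bounds, mice, test_size=0.34, seed=1)
print(len(train_ix), "train trials,", len(test_ix), "test trials")
```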

In [186]:
from glmhmm.utils import splitData

## indirect pathway cohort --------------------------------------------------------

# initialize as lists since not every test/train set will be exactly the same size
x_train_d2, x_test_d2, y_train_d2, y_test_d2, sessions_train_d2, sessions_test_d2, testIx_d2 = [],[],[],[],[],[],[]

# specify seeds for splitting the data for reproducibility
seeds = [55,38,13,23,103]

# split the data
for seed in seeds:
    trainIx, sessionsTrain, testIx, sessionsTest = splitData(sessions_d2,mouseIDs_d2,testSize=0.2,seed=seed)
    x_train_d2.append(x_d2[trainIx,:])
    x_test_d2.append(x_d2[testIx,:])
    y_train_d2.append(y_d2[trainIx])
    y_test_d2.append(y_d2[testIx])
    sessions_train_d2.append(sessionsTrain)
    sessions_test_d2.append(sessionsTest)
    testIx_d2.append(testIx)
    
## direct pathway cohort --------------------------------------------------------
    
# initialize as lists since not every test/train set will be exactly the same size
x_train_d1, x_test_d1, y_train_d1, y_test_d1, sessions_train_d1, sessions_test_d1, testIx_d1 = [],[],[],[],[],[],[]

# specify seeds for splitting the data for reproducibility
seeds = [10,66,100,73,200]

# split the data
for seed in seeds:
    trainIx, sessionsTrain, testIx, sessionsTest = splitData(sessions_d1,mouseIDs_d1,testSize=0.2,seed=seed)
    x_train_d1.append(x_d1[trainIx,:])
    x_test_d1.append(x_d1[testIx,:])
    y_train_d1.append(y_d1[trainIx])
    y_test_d1.append(y_d1[testIx])
    sessions_train_d1.append(sessionsTrain)
    sessions_test_d1.append(sessionsTest)
    testIx_d1.append(testIx)

#### Fit GLMs to the training sets
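Before running the cell below, it may help to see what maximum-likelihood fitting of a Bernoulli GLM (logistic regression) involves. The sketch uses Newton's method in plain NumPy; it is not necessarily how `glm.GLM.fit` works, and the helper names, the small L2 penalty `l2` (added only for numerical stability), the fixed iteration count, and the 0/1 coding of `y` are all assumptions.

```python
import numpy as np

def fit_bernoulli_glm(X, y, l2=1e-4, n_iter=25):
    """Fit w so that P(y=1 | x) = sigmoid(x @ w), using Newton's method (illustrative)."""
    N, M = X.shape
    w = np.zeros(M)
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-X @ w))                   # predicted P(y = 1)
        grad = X.T @ (p - y) + l2 * w                       # gradient of penalized negative LL
        hess = (X * (p * (1 - p))[:, None]).T @ X + l2 * np.eye(M)
        w -= np.linalg.solve(hess, grad)                    # Newton step
    return w

def bernoulli_loglik(X, y, w):
    """Natural-log likelihood of binary choices y under weights w."""
    logits = X @ w
    # log sigmoid(z) = -log(1 + e^{-z}); np.logaddexp keeps this numerically stable
    return np.sum(-y * np.logaddexp(0, -logits) - (1 - y) * np.logaddexp(0, logits))

# recover known weights from simulated choices
rng = np.random.default_rng(1)
X = np.column_stack([rng.standard_normal(2000), np.ones(2000)])  # stimulus + bias
w_true = np.array([2.0, -0.5])
y = (rng.random(2000) < 1 / (1 + np.exp(-X @ w_true))).astype(float)
w_hat = fit_bernoulli_glm(X, y)
print("estimated weights:", w_hat)
```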

In [192]:
M = 10 # number of input features
C = 2 # number of observation classes
folds = 5

## indirect pathway cohort --------------------------------------------------------
fit_glms_d2 = np.zeros((folds),dtype=object)
for i in range(folds):
    N = x_train_d2[i].shape[0] 
    fit_glms_d2[i] = glm.GLM(N,M,C,observations="bernoulli")
    w_init = fit_glms_d2[i].init_weights()
    results = fit_glms_d2[i].fit(x_train_d2[i],w_init,y_train_d2[i],compHess=False)
    
## direct pathway cohort --------------------------------------------------------
fit_glms_d1 = np.zeros((folds),dtype=object)
for i in range(folds):
    N = x_train_d1[i].shape[0] 
    fit_glms_d1[i] = glm.GLM(N,M,C,observations="bernoulli")
    w_init = fit_glms_d1[i].init_weights()
    results = fit_glms_d1[i].fit(x_train_d1[i],w_init,y_train_d1[i],compHess=False)

#### Fit GLM-HMMs to the training sets
The cell below will take about 20 hours to run (~2 hours per fold x 5 folds x 2 datasets), but you can speed this up by putting the code below into a Python script and running each fold (and initialization) in parallel.  
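One way to parallelize is to give each (fold, initialization) pair its own job, for example as an array job on a compute cluster. The sketch below shows only the bookkeeping: a hypothetical script that turns a flat job index into a fold and an initialization, so that 5 folds x 20 initializations = 100 independent jobs cover one cohort. Reading the index from `SLURM_ARRAY_TASK_ID` applies only if you use SLURM; the fitting itself (one pass of the inner loop in the cell below) is left as a placeholder, and saving each run to disk for a later `find_best_fit` step is an assumption about how you would collect results.

```python
import argparse
import os

FOLDS = 5    # cross-validation folds, as in the notebook
INITS = 20   # random initializations per fold, as in the notebook

def job_to_fold_init(job_id, inits=INITS):
    """Map a flat job index in [0, FOLDS*INITS) to a (fold, init) pair."""
    return divmod(job_id, inits)

def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Fit one GLM-HMM initialization for one CV fold.")
    parser.add_argument("--job-id", type=int,
                        default=int(os.environ.get("SLURM_ARRAY_TASK_ID", 0)),
                        help="flat job index; defaults to $SLURM_ARRAY_TASK_ID if set")
    parser.add_argument("--cohort", choices=["indirect", "direct"], default="indirect")
    args = parser.parse_args(argv)
    if not 0 <= args.job_id < FOLDS * INITS:
        parser.error(f"--job-id must be in [0, {FOLDS * INITS})")
    return args

def main(argv=None):
    args = parse_args(argv)
    fold, init = job_to_fold_init(args.job_id)
    # Placeholder: load the `fold`-th training split for `args.cohort`, fit a single
    # GLMHMM as in the inner loop of the notebook cell, and save its log-likelihood
    # trace and parameters to a per-job file (hypothetical naming scheme).
    print(f"cohort={args.cohort} fold={fold} init={init}")
    return fold, init
```

Saved as a script, this would be launched once per job (calling `main()` under an `if __name__ == "__main__":` guard), e.g. with `--job-id 23` mapping to fold 1, initialization 3.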

In [None]:
from glmhmm.utils import find_best_fit

K = 3
inits = 20 # set the number of initializations

## indirect pathway cohort --------------------------------------------------------
best_fit_GLMHMMs_d2 = np.zeros((folds),dtype=object)
for j in range(folds):
    # store values for each initialization
    lls_all = np.zeros((inits,250))
    real_GLMHMMs = np.zeros((inits),dtype=object)

    # fit the model for each initialization
    N = x_train_d2[j].shape[0]
    for i in range(inits): 
        real_GLMHMMs[i] = glm_hmm.GLMHMM(N,M,C,K,observations="bernoulli")
        A_init,w_init,_ = real_GLMHMMs[i].generate_params(weights=['GLM',-0.2,1.2,x_train_d2[j],y_train_d2[j],1])                   
        lls_all[i],_,_,_ = real_GLMHMMs[i].fit(y_train_d2[j],x_train_d2[j],A_init,w_init,sess=sessions_train_d2[j]) 
        
    # store results from best fit
    bestix = find_best_fit(lls_all)
    best_fit_GLMHMMs_d2[j] = real_GLMHMMs[bestix]
    
## direct pathway cohort --------------------------------------------------------
best_fit_GLMHMMs_d1 = np.zeros((folds),dtype=object)
for j in range(folds):
    # store values for each initialization
    lls_all = np.zeros((inits,250))
    real_GLMHMMs = np.zeros((inits),dtype=object)

    # fit the model for each initialization
    N = x_train_d1[j].shape[0]
    for i in range(inits): 
        real_GLMHMMs[i] = glm_hmm.GLMHMM(N,M,C,K,observations="bernoulli")
        A_init,w_init,_ = real_GLMHMMs[i].generate_params(weights=['GLM',-0.2,1.2,x_train_d1[j],y_train_d1[j],1])                   
        lls_all[i],_,_,_ = real_GLMHMMs[i].fit(y_train_d1[j],x_train_d1[j],A_init,w_init,sess=sessions_train_d1[j]) 
        
    # store results from best fit
    bestix = find_best_fit(lls_all)
    best_fit_GLMHMMs_d1[j] = real_GLMHMMs[bestix]

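The cell above keeps, for each fold, the initialization chosen by `find_best_fit`. Its exact rule is not shown in this notebook, so the following is only a sketch of one reasonable criterion: keep the initialization whose EM run ended at the highest log-likelihood. It assumes each row of `lls_all` is a per-iteration log-likelihood trace and that iterations skipped after early convergence stay at zero, as they would in the `np.zeros((inits,250))` array; `final_loglik` and `best_init_index` are hypothetical names.

```python
import numpy as np

def final_loglik(trace):
    """Last recorded (nonzero) log-likelihood in a zero-padded EM trace."""
    recorded = np.flatnonzero(trace)
    return trace[recorded[-1]] if recorded.size else -np.inf

def best_init_index(lls_all):
    """Index of the initialization whose EM run ended at the highest log-likelihood."""
    finals = np.array([final_loglik(row) for row in lls_all])
    return int(np.argmax(finals))

# toy traces for 3 initializations, zero-padded to 6 iterations
lls_all = np.zeros((3, 6))
lls_all[0, :4] = [-900, -850, -840, -839]
lls_all[1, :6] = [-950, -870, -830, -826, -825, -825]
lls_all[2, :3] = [-880, -860, -858]
print(best_init_index(lls_all))  # → 1
```

Reading the last nonzero entry, rather than taking `lls_all.max(axis=1)`, matters under the zero-padding assumption: log-likelihoods are negative, so a padded row's maximum would be 0 and would wrongly win.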
Now we'll compute the test log-likelihoods of the fitted GLM and GLM-HMM models for each mouse and plot the results to see how much the GLM-HMM improves on the GLM.

In [None]:
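from glmhmm.analysis import compare_LL_GLMvsGLMHMM

unique_mouseIDs = np.unique(mouseIDs_d2)
numMice = len(unique_mouseIDs)
test_lls = np.zeros((folds,numMice,2)) # two test LLs (one per model) for each fold and mouse
for j in range(folds):
    test_mouseIDs = mouseIDs_d2[testIx_d2[j]] # mouse ID of each trial in this fold's test set
    for m, mouseID in enumerate(unique_mouseIDs):
        # indices (relative to the test set) of this mouse's test trials
        test_mouse_ix = np.where(test_mouseIDs == mouseID)[0]
        test_lls[j,m] = compare_LL_GLMvsGLMHMM(fit_glms_d2[j],best_fit_GLMHMMs_d2[j],
                                               x_test_d2[j][test_mouse_ix,:],y_test_d2[j][test_mouse_ix])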
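For reference, here is a minimal NumPy sketch of how the two test log-likelihoods can be computed and compared: the GLM scores each trial independently, while the GLM-HMM's marginal likelihood is computed with the scaled forward algorithm over hidden states. Dividing the difference by the number of trials and by ln 2 gives an improvement in bits per trial, one common way to compare mice with different numbers of test trials. This is our own illustration, not necessarily what `compare_LL_GLMvsGLMHMM` computes; it treats the test trials as a single sequence (a real analysis would restart the forward pass at each session boundary), and the function names and inputs are assumptions.

```python
import numpy as np

def log_sigmoid(z):
    return -np.logaddexp(0, -z)

def glm_loglik(X, y, w):
    """Natural-log likelihood of binary choices under a single Bernoulli GLM."""
    z = X @ w
    return np.sum(y * log_sigmoid(z) + (1 - y) * log_sigmoid(-z))

def glmhmm_loglik(X, y, W, A, pi0):
    """Marginal log-likelihood of a Bernoulli GLM-HMM via the scaled forward algorithm.

    W: (K, M) per-state weights; A: (K, K) transitions with A[i, j] = P(z_t=j | z_{t-1}=i);
    pi0: (K,) initial state distribution.
    """
    Z = X @ W.T                                                  # (N, K) logits per state
    obs = np.exp(y[:, None] * log_sigmoid(Z) + (1 - y[:, None]) * log_sigmoid(-Z))
    alpha = pi0 * obs[0]                                         # joint of state and first choice
    ll = np.log(alpha.sum())
    alpha /= alpha.sum()
    for t in range(1, len(y)):
        alpha = (alpha @ A) * obs[t]                             # propagate, weight by evidence
        c = alpha.sum()
        ll += np.log(c)                                          # accumulate log scaling factors
        alpha /= c
    return ll

def bits_per_trial(ll_model, ll_baseline, n_trials):
    """Improvement of ll_model over ll_baseline (both natural log), in bits per trial."""
    return (ll_model - ll_baseline) / (n_trials * np.log(2))
```

A 1-state GLM-HMM (`A = [[1.0]]`, `pi0 = [1.0]`) reduces exactly to the GLM, which makes a handy sanity check for any implementation.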