### Example for Fitting a Generalized Linear Model (GLM) - Hidden Markov Model (HMM) to Real Data

This notebook is designed to help users get started using the glmhmm package to fit GLM-HMMs to their data. In this notebook, we fit the GLM-HMM to real experimental data (as in [Bolkan, Stone et al 2021](https://www.biorxiv.org/content/10.1101/2021.07.23.453573v1)) and recreate the figures from that paper. 

For an example of how the GLM-HMM can be applied to simulated data, check out the <code>fit-glm-hmm.ipynb</code> notebook.

In [446]:
%load_ext autoreload
%autoreload 2

import sys
import matplotlib.pyplot as plt
import numpy as np
import time
from sklearn.model_selection import KFold
from glmhmm import glm_hmm
from glmhmm.utils import permute_states, find_best_fit, compare_top_weights
from glmhmm.visualize import plot_model_params, plot_loglikelihoods, plot_weights

import matplotlib as mpl
mpl.rcParams['figure.facecolor'] = '1'
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42

# set plot design features
font = {'family'   : 'sans-serif',
        'sans-serif' : 'Helvetica',
        'weight'   : 'regular',
        'size'     : 18}

mpl.rc('font', **font)



#### Description of the Data

The data we use here come from the experiments described in [Bolkan, Stone et al 2021](https://www.biorxiv.org/content/10.1101/2021.07.23.453573v1). This is behavioral data from mice performing a two-alternative forced-choice (2AFC) task in which the animals run down a virtual maze while multi-sensory "cues" appear to their left and right. The mice must "accumulate evidence" as these cues appear and ultimately decide to turn left or right based on which side of the maze had more cues.

The dataset in the paper includes three cohorts of mice: a group that was inhibited in the direct pathway of the striatum, a group that was inhibited in the indirect pathway, and a control (no opsin) group. Below we will fit a GLM-HMM to the "indirect pathway group", but this repository includes data for the other cohorts as well, should you want to take a look.

The provided design matrix (loaded below) includes z-scored values for the following external covariates (in this order): bias, difference in cues, laser, previous choices 1-6, and previous rewarded choice. See the methods section of the paper for more information on how we coded these covariates. 
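To make the column layout concrete, here is a minimal sketch of how a design matrix with these ten covariates could be z-scored. The label strings, the synthetic data, and the choice to leave a constant bias column untouched are assumptions for illustration; see the paper's methods for how the provided matrix was actually built.

```python
import numpy as np

# Column labels in the order listed above (the label strings are hypothetical)
COLUMNS = (["bias", "delta_cues", "laser"]
           + [f"prev_choice_{i}" for i in range(1, 7)]
           + ["prev_rewarded_choice"])

def zscore_columns(raw):
    """Z-score each column; constant columns (e.g. a bias of ones) are left unchanged."""
    raw = np.asarray(raw, dtype=float)
    mean = raw.mean(axis=0)
    std = raw.std(axis=0)
    varying = std > 0                      # skip columns with zero variance
    out = raw.copy()
    out[:, varying] = (raw[:, varying] - mean[varying]) / std[varying]
    return out

# synthetic stand-in for raw covariates (not the real data)
rng = np.random.default_rng(0)
raw = rng.normal(loc=2.0, scale=3.0, size=(500, len(COLUMNS)))
raw[:, 0] = 1.0                            # constant bias column
x_demo = zscore_columns(raw)
print(dict(zip(COLUMNS, x_demo.std(axis=0).round(2))))
```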

#### 1. Load the data

In [None]:
x = np.load('../data/indirect_x.npy') # z-scored design matrix
y = np.load('../data/indirect_y.npy') # vector of right and left choices for each trial
sessions = np.load('../data/indirect_sessions.npy') # vector of session start and stop indices

#### 2. Set the hyper-parameters of the GLM-HMM

In [None]:
N = x.shape[0] # number of data/time points
K = 3 # number of latent states
C = 2 # number of observation classes
D = x.shape[1] # number of GLM inputs (regressors)

#### 3. Instantiate the model

In [None]:
real_GLMHMM = glm_hmm.GLMHMM(N,D,C,K,observations="bernoulli",gaussianPrior=1)

#### 4. Fit the model

In this case, we're going to initialize the weights by fitting a GLM to the data. Once we have the fitted GLM weights (you can think of this as a 1-state GLM-HMM), we'll duplicate those weights $K$ times and add a small amount of noise to each copy. These weights will then form our $w_{init}$ and serve as a smarter initialization that helps us find the best solution more quickly and more often than if we initialized the weights randomly.
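The idea behind this initialization can be sketched in plain NumPy as follows. This is not the code inside <code>generate_params</code>; the function names, the Newton solver, the small ridge term, and the noise scale are all assumptions chosen for illustration.

```python
import numpy as np

def fit_bernoulli_glm(x, y, n_iter=50, ridge=1e-3):
    """Fit a logistic-regression GLM with Newton's method (hypothetical helper)."""
    D = x.shape[1]
    w = np.zeros(D)
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-(x @ w)))              # P(y = 1 | x)
        grad = x.T @ (p - y) + ridge * w                 # gradient of the penalized negative log-likelihood
        hess = x.T @ (x * (p * (1 - p))[:, None]) + ridge * np.eye(D)
        w = w - np.linalg.solve(hess, grad)              # Newton update
    return w

def glm_init_weights(x, y, K, noise_scale=0.1, seed=0):
    """Duplicate the 1-state GLM weights K times and jitter each copy."""
    rng = np.random.default_rng(seed)
    w_glm = fit_bernoulli_glm(x, y)
    noise = noise_scale * rng.standard_normal((K, x.shape[1]))
    return w_glm, np.tile(w_glm, (K, 1)) + noise

# synthetic example (not the real data)
rng = np.random.default_rng(1)
x_demo = rng.normal(size=(400, 3))
x_demo[:, 0] = 1.0                                       # bias column
p_true = 1.0 / (1.0 + np.exp(-(x_demo @ np.array([0.5, 2.0, -1.0]))))
y_demo = (rng.random(400) < p_true).astype(float)
w_glm, w_init = glm_init_weights(x_demo, y_demo, K=3)
print(w_init.shape)
```

Each row of `w_init` starts close to the single-GLM solution, so EM only has to learn how the states differ from one another rather than rediscovering the overall psychometric shape.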

Another change from the simulated case is that our data are now organized into sessions, in which different mice perform the task on different days. So that the fitted parameters aren't affected by this structure (i.e. the last trial on one day doesn't affect the first trial on the next day), we will fit separately to each session. Fortunately, this is easy to do: we simply pass a vector with the start and stop indices of each session as an extra input to the fitting code.
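If you are preparing your own data, a boundary vector of this kind can be built from per-trial session labels as in the sketch below. The helper name and the exact layout (start index of each session followed by the final stop index) are assumptions; check the provided <code>indirect_sessions.npy</code> file for the layout the fitting code expects.

```python
import numpy as np

def session_boundaries(session_ids):
    """Return [0, ..., N]: the start index of each session plus the final stop index N."""
    session_ids = np.asarray(session_ids)
    # a new session starts wherever the label differs from the previous trial
    change = np.flatnonzero(session_ids[1:] != session_ids[:-1]) + 1
    return np.concatenate(([0], change, [len(session_ids)]))

ids = np.array([7, 7, 7, 9, 9, 12, 12, 12, 12])   # hypothetical per-trial session labels
bounds = session_boundaries(ids)
print(bounds)                                      # [0 3 5 9]

# iterate over sessions as (start, stop) pairs
for start, stop in zip(bounds[:-1], bounds[1:]):
    print(start, stop, ids[start:stop])
```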

<b>Note:</b> To illustrate the process of confirming we've found the global optimum (see step 5), the code below fits the model 20 times. This can take 1-2.5 hours to run, depending on your machine (the times below are from running on a 2016 MacBook Pro, so relatively slow). For a quicker assessment, reduce the number of inits.

In [None]:
inits = 20 # set the number of initializations
maxiter = 250 # maximum number of iterations of EM to allow for each fit
tol = 1e-3 # convergence tolerance for EM
# store values for each initialization
lls_all = np.zeros((inits,maxiter))
A_all = np.zeros((inits,K,K))
w_all = np.zeros((inits,K,D,C))

# fit the model for each initialization
for i in range(inits):
    t0 = time.time()
    # initialize the weights
    A_init,w_init,pi_init = real_GLMHMM.generate_params(weights=['GLM',-0.2,1.2,x,y,1])
    # fit the model                     
    lls_all[i,:],A_all[i,:,:],w_all[i,:,:],pi0 = real_GLMHMM.fit(y,x,A_init,w_init,maxiter=maxiter,tol=tol,sess=sessions) 
    minutes = (time.time() - t0)/60
    print('initialization %s complete in %.2f minutes' %(i+1, minutes))

#### 5. Check to see that multiple fits achieve the same log-likelihood.

Since we're now fitting real data, we can't simply compare our inferred parameters to the true ones to make sure we're recovering the right values. But we can get a reasonable assessment of whether or not we're finding the global optimum of the log-likelihood by comparing the log-likelihoods for each fit. If multiple log-likelihoods for the best fits converge to the same (or very similar) values, this is a good indication that we've found the global optimum.
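As a rough illustration of this check (not the implementation of <code>plot_loglikelihoods</code>), the sketch below counts how many fits end within a threshold of the best final log-likelihood. The function name and the example values are made up.

```python
import numpy as np

def top_fits(final_lls, threshold):
    """Indices of fits whose final log-likelihood is within `threshold` of the best one."""
    final_lls = np.asarray(final_lls, dtype=float)
    best = np.nanmax(final_lls)
    return np.flatnonzero(best - final_lls <= threshold)

# hypothetical final log-likelihoods from five initializations
final = np.array([-1050.2, -1049.9, -1071.3, -1050.0, -1049.8])
ixs = top_fits(final, threshold=0.5)
print(ixs)   # [0 1 3 4]
```

Here four of the five fits agree to within 0.5, while the third is stuck in a clearly worse local optimum.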

In [None]:
topixs = plot_loglikelihoods(lls_all,0.5,startix=5) # set startix > 0 to get a better view of the final lls
print('Number of top matching lls within threshold: ', len(topixs))

#### 6. We can also check that the weights for the best fits are the same, within some tolerance.

In [None]:
# first, permute the weights according to the value of a particular regressor (here we pick cues) so that the states
# will be the same for each fit 
w_permuted = np.zeros_like(w_all[:,:,:,1])
order = np.zeros((inits,K))
for i in range(inits):
    w_permuted[i],order[i] = permute_states(w_all[i,:,:,1],method='weight value',param='weights',ix=1)

np.set_printoptions(precision=2,suppress=True)
# now let's check if the weights for the top fits match up
compare_top_weights(w_permuted,topixs,tol=0.5)
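State labels in an HMM are arbitrary, so "state 1" in one fit may be "state 3" in another. The permutation step above can be sketched as sorting the states by one regressor's weight and reordering the transition matrix to match. The helper names and the descending sort are assumptions here; <code>permute_states</code> may order states differently.

```python
import numpy as np

def order_states_by_weight(w, ix):
    """w has shape (K, D). Sort states by descending weight on regressor `ix`."""
    order = np.argsort(-w[:, ix])
    return w[order], order

def permute_transition_matrix(A, order):
    """Reorder both rows and columns of a (K, K) transition matrix."""
    return A[np.ix_(order, order)]

# hypothetical weights for K=3 states and D=2 regressors
w = np.array([[0.1, 0.5],
              [0.0, 2.0],
              [0.3, 1.0]])
A = np.array([[0.90, 0.05, 0.05],
              [0.10, 0.80, 0.10],
              [0.02, 0.03, 0.95]])

w_sorted, order = order_states_by_weight(w, ix=1)
A_sorted = permute_transition_matrix(A, order)
print(order)                 # [1 2 0]
print(np.diag(A_sorted))     # self-transition probabilities follow their states
```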

#### 7. Now let's look at the inferred parameters (as shown in our paper, [Bolkan, Stone et al 2021](https://www.biorxiv.org/content/10.1101/2021.07.23.453573v1)). 

The weights below can be plotted with error bars. We compute the variance for this purpose by taking the inverse Hessian of the optimized log-likelihood. This process relies on autograd and is somewhat slow, so if you're running this notebook yourself and want a quick visualization, we recommend skipping the <code>computeVariance</code> step and passing <code>error=None</code> to the <code>plot_weights</code> function, as in the call below.

In [None]:
bestix = find_best_fit(lls_all) # find the initialization that led to the best fit
A_permuted, _ = permute_states(A_all[bestix],method='order',order=order[bestix].astype(int))

# optional (slow): compute the variance of the weights from the inverse Hessian
variance = real_GLMHMM.computeVariance(x,y,A_permuted,w_permuted[bestix,:,:,np.newaxis],gaussPrior=1)

# plot the inferred transition probabilities
fig, ax = plt.subplots(1,1)
plot_model_params(A_permuted,ax,precision='%.3f')

# plot the inferred weights
fig, ax = plt.subplots(1,1)
colors = np.array([[39,110,167],[237,177,32],[233,0,111],[176,100,245]])/255
xlabels = [r'$\Delta$ cues', 'laser', 'bias', '1', '2', '3', '4', '5', '6', 'prev rew \n choice']
legend = ['state 1', 'state 2', 'state 3']
plot_weights(w_permuted[bestix],ax,xlabels=xlabels,switch=True,style='.-',color=colors,error=None,label=legend)
ax.text(0.43,-0.25,'prev choice',transform=ax.transAxes)
ax.legend()
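For intuition about where the weight variances come from, the sketch below uses the closed-form Hessian of a Bernoulli GLM with a Gaussian prior and reads the variances off the diagonal of its inverse. This swaps autograd for an analytic Hessian and is not the <code>computeVariance</code> implementation; the function name, the `prior_sigma` parameter, and the optional per-trial state posteriors `gamma` are assumptions.

```python
import numpy as np

def bernoulli_glm_variance(x, w, prior_sigma=1.0, gamma=None):
    """Approximate weight variances as the diagonal of the inverse Hessian
    of the negative log-posterior of a logistic-regression GLM."""
    p = 1.0 / (1.0 + np.exp(-(x @ w)))          # P(y = 1 | x, w)
    s = p * (1.0 - p)                            # per-trial curvature
    if gamma is not None:                        # weight trials by state posterior
        s = s * gamma
    hess = x.T @ (x * s[:, None]) + np.eye(x.shape[1]) / prior_sigma**2
    return np.diag(np.linalg.inv(hess))

# synthetic example (not the real data)
rng = np.random.default_rng(2)
x_demo = rng.normal(size=(200, 3))
w_demo = np.array([0.2, 1.0, -0.5])
var = bernoulli_glm_variance(x_demo, w_demo)
var_more_data = bernoulli_glm_variance(np.vstack([x_demo, x_demo]), w_demo)
print(np.sqrt(var), np.sqrt(var_more_data))     # error bars shrink with more trials
```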

Check back soon for more examples of convenient post-fitting analysis/plotting code, including many additional recreated figures from the paper.