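The code below runs a grid search over the HMM recalibration hyperparameters (base kappa, kappa-weighting inflection point and exponent, and probability threshold):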


In [282]:
%load_ext autoreload
%autoreload 2

%load_ext autoreload
%autoreload 2

import numpy as np
from scipy.io import loadmat
import matplotlib.pyplot as plt
from copy import deepcopy
import glob
import sys

# Make every entry under utils/ importable by the project imports that follow.
sys.path.extend(glob.glob('utils/*'))
from preprocess import DataStruct
from firingrate import raster2FR
from plotting_utils import figSize
from lineplots import plotsd
from session_utils import *
from recalibration_utils import *
from click_utils import *

from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score
from sklearn.decomposition import FactorAnalysis, PCA


# Session-selection thresholds.
min_nblocks    = 3       # minimum number of blocks for a session to be included (passed to get_Sessions)
max_ndays      = 30      # passed to get_SessionPairs as max_ndays; presumably the largest allowed gap in days between paired sessions - TODO confirm
min_R2         = 0.1     # passed to get_StrongTransferPairs as min_R2; presumably keeps only pairs with good decoder transfer - TODO confirm


# NOTE(review): absolute, machine-specific data directory; adjust when running elsewhere.
f_dir          = glob.glob('D:/T5_ClosedLoop/*')
# The .npy file wraps a pickled object (hence allow_pickle); it is indexed by 'bad_days' here and by session file in a later cell.
sessions_check = np.load('misc_data/sessions_check.npy', allow_pickle = True).item()
files          = get_Sessions(f_dir, min_nblocks)


# Build candidate session pairs, then filter them with the min_R2 threshold.
init_pairs    = get_SessionPairs(files, max_ndays = max_ndays, manually_remove = sessions_check['bad_days'])
# NOTE: `scores` is re-bound to the grid-search results array in a later cell, so the values returned here are overwritten.
pairs, scores = get_StrongTransferPairs(init_pairs, min_R2 = min_R2, train_frac = 0.5, block_constraints = sessions_check)
n_pairs       = len(pairs)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [283]:
from joblib import Parallel, delayed

# NOTE(review): the bare string literal below holds a disabled helper (`my_func`,
# which called hmmviterbi_vonmises directly). It is never executed; consider
# deleting it once it is no longer needed for reference.
'''
def my_func(inflection, exp, vmKappa, rawDecTraj, stateTrans, targLocs, B_cursorPos, pStateStart):
    Code for parallelizing HMM sweeps. Inputs are:
    
        inflection, exp (floats) - parameters for adjusting kappa weighting
        vmKappa (float)          - base kappa value
        rawDecTraj (2D array)    - time x 2 of decoder outputs
        stateTrans (2D array)    - square transition matrix for markov states
        targLocs (2D array)      - k x 2 array of corresponding target positions for each state
        B_cursorPos (2D array)   - time x 2 array of cursor positions
        pStateStart (1D array)   - starting probabilities for each state
    
    def adjustKappa(dist):
        coef = 1 / (1 + np.exp(-1 * (dist - inflection) * exp))
        return coef 
    predTarg = hmmviterbi_vonmises(rawDecTraj, stateTrans, targLocs, B_cursorPos, pStateStart, vmKappa, adjustKappa)[0]
    
    return predTarg
'''


def HMMrecal_parallel(inflection, exp, vmKappa, probThresh, decoder, neural, stateTrans, targLocs, B_cursorPos, pStateStart):
    '''Run one HMM recalibration for a single parameter setting (one grid-search job).

    All data comes from the arguments; nothing is read from notebook globals,
    so the result depends only on what the caller passes in.

    Inputs are:

        inflection, exp (floats) - parameters for adjusting kappa weighting
        vmKappa (float)          - base kappa value
        probThresh (float)       - subselect high probability time points; between 0 and 1
        decoder (sklearn)        - sklearn LinearRegression() object; deep-copied before use
        neural (2D array)        - time x channels of neural activity
        stateTrans (2D array)    - square transition matrix for markov states
        targLocs (2D array)      - k x 2 array of corresponding target positions for each state
        B_cursorPos (2D array)   - time x 2 array of cursor positions
        pStateStart (array)      - starting probabilities for each state

    Returns whatever train_HMMRecalibrate returns (the caller treats it as the
    recalibrated decoder).
    '''

    def adjustKappa(dist):
        # Logistic weighting between 0 and 1: equals 0.5 at dist == inflection,
        # with `exp` controlling how sharply it changes around that point.
        return 1 / (1 + np.exp(-1 * (dist - inflection) * exp))

    # Work on a copy so the decoder object handed in by the caller is not shared
    # between parameter settings.
    new_decoder = train_HMMRecalibrate(deepcopy(decoder), [neural], [B_cursorPos], stateTrans,
                                       pStateStart, targLocs, vmKappa, adjustKappa, probThresh)
    return new_decoder



In [None]:
from hmm import *
from hmm_utils import prep_HMMData, get_DiscreteTargetGrid, train_HMMRecalibrate
from sklearn.metrics import r2_score
import itertools


# General settings
np.random.seed(42)        # fixed seed for repeatability
task            = None    # forwarded as `task` to getPairTasks / prep_HMMData in the loop below
train_frac      = 0.5     # forwarded as `train_frac` to getTrainTest / prep_HMMData
sigma           = 2       # forwarded as `sigma` to getTrainTest / prep_HMMData (presumably a smoothing width - TODO confirm)

# HMM settings
gridSize         = 20     # the HMM uses gridSize**2 states; also passed to get_DiscreteTargetGrid
stayProb         = 0.99   # probability of remaining in the same state from one step to the next
kappa_sweep      = [0.5, 1, 2, 4, 6, 8]
inflection_sweep = [0.1, 10, 30, 50, 70, 100, 200, 400]
exp_sweep        = [0.0001, 0.001, 0.025, 0.05, 0.1, 0.5, 1, 2, 4]
thresh_sweep     = [0.1, 0.3, 0.5, 0.7]

# Reduced sweeps for a quick run (uncomment to override the full grid above):
#kappa_sweep      = [ 1, 2]
#inflection_sweep = [50, 70]  # inflection point sweep
#exp_sweep        = [2, 4]
#thresh_sweep     = [0.3]


#--------------------------

nStates = gridSize**2

# State transition matrix: stay in the current state with probability stayProb,
# otherwise move to any one of the other states with equal probability, so
# every row sums to 1.
stateTrans = np.full((nStates, nStates), (1 - stayProb) / (nStates - 1))
np.fill_diagonal(stateTrans, stayProb)

# Uniform starting distribution, stored as an (nStates, 1) column.
pStateStart = np.full((nStates, 1), 1 / nStates)

# Every parameter combination, plus the matching index of each combination
# along the four sweep axes. Both are built from the same tuple of sweeps so
# params_grid[k] always corresponds to grid_inds[k].
sweeps      = (inflection_sweep, exp_sweep, kappa_sweep, thresh_sweep)
params_grid = list(itertools.product(*sweeps))
grid_inds   = list(itertools.product(*(range(len(sweep)) for sweep in sweeps)))

# scores[i, a, b, c, d] holds the value of decoder.score (R^2 for an sklearn
# regressor) on held-out day-B data for session pair i under
# (inflection_sweep[a], exp_sweep[b], kappa_sweep[c], thresh_sweep[d]).
scores   = np.zeros((n_pairs, len(inflection_sweep), len(exp_sweep), len(kappa_sweep), len(thresh_sweep)))
diffs    = np.zeros((n_pairs,)) # track the # of days between sessions in each pairing

# Report progress about every 10% of the pairs. Clamped to at least 1 so the
# modulo below is defined even when there are only a handful of pairs.
progress_step = max(1, int(np.round(len(pairs) / 10)))

for i, (A_file, B_file) in enumerate(pairs):
    dayA, dayB = DataStruct(A_file, alignScreens = True), DataStruct(B_file, alignScreens = True)
    diffs[i]   = daysBetween(dayA.date, dayB.date) # record number of days between sessions

    # Per-session block entries from sessions_check; None when the session has no entry.
    dayA_blocks             = sessions_check.get(A_file)
    dayB_blocks             = sessions_check.get(B_file)
    dayA_task, dayB_task, _ = getPairTasks(dayA, dayB, task = task)

    # obtain features and cursorError targets:
    Atrain_x, Atest_x, Atrain_y, Atest_y                 = getTrainTest(dayA, train_frac = train_frac, sigma = sigma, blocks = dayA_blocks, task = dayA_task, return_flattened = True)
    # NOTE(review): prep_HMMData receives the global `task`, not dayB_task as get_DiscreteTargetGrid does below - confirm this is intended.
    Btrain_x, B_cursorPos, Btrain_y, Btest_x, _, Btest_y = prep_HMMData(dayB, train_frac = train_frac, sigma = sigma, blocks = dayB_blocks, task = task, return_flattened = True)
    targetPos                                            = Btrain_y + B_cursorPos  # not used further in this cell

    # NOTE(review): the day-A training arrays are passed in both the (presumably) train and
    # test positions; Atest_x / Atest_y and full_score are not used in this cell.
    full_score, full_decoder = traintest_DecoderSupervised([Atrain_x], [Atrain_x], [Atrain_y], [Atrain_y], meanRecal = False)
    targLocs                 = get_DiscreteTargetGrid(dayB, gridSize = gridSize, task = dayB_task)

    # One recalibration job per parameter combination; results come back in params_grid order.
    decoders = Parallel(n_jobs=-2)(
        delayed(HMMrecal_parallel)(inflection, exp, vmKappa, probThresh, full_decoder,
                                   Btrain_x, stateTrans, targLocs, B_cursorPos, pStateStart)
        for inflection, exp, vmKappa, probThresh in params_grid)

    # Centre the held-out day-B features with the day-B training mean (identical for every decoder).
    Btest_centered = Btest_x - Btrain_x.mean(axis = 0)
    for grid_idx, recal_decoder in zip(grid_inds, decoders):
        scores[(i, *grid_idx)] = recal_decoder.score(Btest_centered, Btest_y)

    if (i + 1) % progress_step == 0:
        print(np.round((i + 1) * 100 / len(pairs), 1), '% complete')


In [None]:
# Index (inflection, exp, kappa, thresh) of the parameter combination to plot.
# `args` was not defined in any visible cell, so it is computed here as the
# combination with the highest mean score across session pairs.
# NOTE(review): mean-across-pairs is an assumed selection criterion - adjust if
# "best" was meant differently (e.g. median).
args = np.unravel_index(np.argmax(scores.mean(axis = 0)), scores.shape[1:])

# Left panel: the logistic kappa-weighting curve for the selected inflection/exp.
plt.subplot(1, 2, 1)
x = np.linspace(0, 400, 3000)
y = 1 / (1 + np.exp(-1 * (x - inflection_sweep[args[0]]) * exp_sweep[args[1]]))

figSize(5, 10)
plt.plot(x, y)
plt.xlabel('Distance to target')
plt.ylabel('Kappa adjustment factor')
plt.title('Weighting function')

# Right panel: one point per session pair at the selected parameter combination.
# NOTE(review): `sns` (seaborn) is not imported in any visible cell - confirm it is in scope.
plt.subplot(1, 2, 2)
sns.swarmplot(scores[:, args[0], args[1], args[2], args[3]], orient = 'v')
plt.title('Session scores (best parameters)')
plt.ylabel('R^2 (new day)')