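The code below runs a grid search over the HMM recalibration hyperparameters (kappa, the inflection point and exponent of the logistic kappa-weighting function, and the smoothing width) across pairs of sessions: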


In [1]:
%load_ext autoreload
%autoreload 2

%load_ext autoreload
%autoreload 2

import numpy as np
from scipy.io import loadmat
import matplotlib.pyplot as plt
from copy import deepcopy
import glob
import sys

[sys.path.append(f) for f in glob.glob('../utils/*')]
from preprocess import DataStruct
from plotting_utils import figSize
from lineplots import plotsd
from session_utils import *
from recalibration_utils import *
from click_utils import *

from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score
from sklearn.decomposition import FactorAnalysis, PCA


# --- session selection criteria ---
min_nblocks = 3     # minimum number of blocks a session needs in order to be included
max_ndays   = 30    # forwarded to get_SessionPairs as max_ndays
                    # NOTE(review): presumably the max number of days separating paired sessions - confirm
#min_R2      = 0.1  # (disabled) would subselect days with good decoder transfer performance


# NOTE(review): absolute, machine-specific data location.
f_dir = glob.glob('D:/T5_ClosedLoop/historical/*')

# allow_pickle = True unpickles arbitrary objects - only load this file from a trusted source.
sessions_check = np.load('../utils/misc_data/NewSessions_check.npy', allow_pickle = True).item()
files = get_Sessions(f_dir, min_nblocks, manually_remove = sessions_check['bad_days'])

# All candidate pairs are used as-is; the transfer-performance filter below is disabled.
pairs = init_pairs = get_SessionPairs(files, max_ndays = max_ndays)
#pairs, scores = get_StrongTransferPairs(init_pairs, min_R2 = min_R2, train_frac = 0.5, block_constraints = sessions_check)
n_pairs = len(pairs)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [99]:
from joblib import Parallel, delayed
from sklearn.model_selection import ParameterGrid

def HMMrecal_parallel(params_dict, baseOpts, train_x, cursorPos, test_x, test_y, decoder = None):
    '''Evaluate every hyperparameter combination in parallel. Inputs are:
    
        params_dict (dictionary) - entries are values to sweep for associated parameters;
                                   must contain the keys 'kappa', 'inflection' and 'exp'
        baseOpts (dictionary)    - contains unchanging HMM parameters; key-value pairs are:
        
            'stateTrans'  : 2D float array - n x n transition matrix model
            'targLocs'    : 2D float array - n x 2 of target location for each state
            'pStateStart' : 1D float array - n x 1 of start probabilities
            
        train_x, cursorPos, test_x, test_y - forwarded unchanged to test_HMMrecal
        decoder (sklearn-like, optional)   - decoder to recalibrate; when omitted the
                                             notebook-level <full_decoder> is used
                                             
    Returns a list with one dictionary per parameter combination, holding the
    parameter values plus the resulting 'score'.
    '''
    
    if decoder is None:
        # Backward-compatible fallback: earlier callers relied on this global.
        decoder = full_decoder
    
    param_sets = list(ParameterGrid(params_dict))
    
    def _make_adjustKappa(inflection, exp):
        # Bind this combination's values now. A lambda written inline in the
        # comprehension looks <params> up when it is *called*, by which time the
        # comprehension has finished and every HMM would see the last grid entry.
        def adjustKappa(dist):
            return 1 / (1 + np.exp(-1 * (dist - inflection) * exp))
        return adjustKappa
    
    HMMs   = [HMMRecalibration(baseOpts['stateTrans'], baseOpts['targLocs'], baseOpts['pStateStart'], params['kappa'], 
                               adjustKappa = _make_adjustKappa(params['inflection'], params['exp'])) for params in param_sets]
    
    scores = Parallel(n_jobs=-1)(delayed(test_HMMrecal)(HMM, decoder, train_x, cursorPos, test_x, test_y) for HMM in HMMs)
    
    # append R^2 to each parameter set (scores come back in submission order): 
    for score, params in zip(scores, param_sets):
        params['score'] = score
    
    return param_sets


def test_HMMrecal(HMM, decoder, train_x, cursorPos, test_x, test_y):
    '''Code for training/testing HMM recalibrated decoder. Inputs are:
    
        HMM (HMMRecalibration object) - hmm to use 
        decoder (sklearn-like)        - decoder following sklearn conventions
        train_x (list)                - entries are (time x n_channels) float arrays of neural activity
        cursorPos (list)              - entries are (time x 2) float arrays of cursor positions
        test_x (list)                 - test set version of <neural>
        test_y (list)                 - entries are (time x 2) float arrays of cursor error signals'''
    
    new_decoder = HMM.recalibrate(deepcopy(decoder), train_x, cursorPos)
    score       = new_decoder.score(np.vstack(test_x), np.vstack(test_y))
    
    return score


In [None]:
from hmm import *
from hmm_utils import prep_HMMData, get_DiscreteTargetGrid, train_HMMRecalibrate
from sklearn.metrics import r2_score
import itertools


# general settings:
np.random.seed(42)
diffs, task              = [], None
train_size, probWeighted = 0.67, 'probWeighted'
gridSize, stayProb       = 20, 0.999

# hyperparams to sweep (smoothing width is looped over separately from the HMM grid):
sigma_sweep = [None, 1, 2, 3]

params_dict = {
    'kappa'      : [0.5, 1, 2, 4, 6, 8],
    'inflection' : [0.1, 10, 30, 50, 70, 100, 200, 400],
    'exp'        : [0.0001, 0.001, 0.025, 0.05, 0.1, 0.5, 1, 2, 4],
}

# reduced grid for quick test runs:
#params_dict = {'kappa' : [0.5, 1,], 'inflection' : [0.1, 10], 'exp' : [0.0001, 0.001, 0.025]}

#--------------------------

# One HMM state per cell of the gridSize x gridSize target grid.
nStates = gridSize**2

# State transition matrix: stay in the current state with probability stayProb,
# otherwise move to any of the other states with equal probability.
stateTrans = np.full((nStates, nStates), (1-stayProb)/(nStates-1))
np.fill_diagonal(stateTrans, stayProb)

# Uniform start probabilities, stored as an (nStates x 1) column.
pStateStart = np.full((nStates, 1), 1/nStates)

baseOpts = {'stateTrans' : stateTrans, 'pStateStart' : pStateStart}

# Sweep: for every session pair (A = reference day, B = new day) and every
# smoothing width, fit a decoder on day A, recalibrate it on day B with each
# HMM parameter set, and record the score on day B's held-out data.
scores = list()

# Print progress roughly every 10% of the pairs. Clamped to >= 1 because
# np.round(len(pairs) / 10) is 0 for fewer than 5 pairs, which would otherwise
# make the modulo below divide by zero.
progress_step = max(1, int(np.round(len(pairs) / 10)))

#for i, (A_file, B_file) in enumerate([pairs[j] for j in range(22, len(pairs))]):
for i, (A_file, B_file) in enumerate(pairs):
    dayA, dayB              = DataStruct(A_file, alignScreens = True), DataStruct(B_file, alignScreens = True)
    days_apart              = daysBetween(dayA.date, dayB.date)   # same for every parameter set of this pair

    # Block / task constraints are currently disabled (all blocks, any task):
    #dayA_blocks             = [sessions_check[A_file] if A_file in sessions_check.keys() else None][0]
    #dayB_blocks             = [sessions_check[B_file] if B_file in sessions_check.keys() else None][0] 
    #dayA_task, dayB_task, _ = getPairTasks(dayA, dayB, task = task)
    dayA_blocks, dayB_blocks = None, None
    dayA_task, dayB_task    = None, None

    # obtain features and cursorError targets:
    for sigma in sigma_sweep:
    
        Atrain_x, Atest_x, Atrain_y, Atest_y  = getTrainTest(dayA, train_size = train_size, sigma = sigma, blocks = dayA_blocks, task = dayA_task, returnFlattened = True)    
        Atrain_x, Atest_x  = get_BlockwiseMeanSubtracted(Atrain_x, Atest_x, concatenate = True)
        Atrain_y           = np.concatenate(Atrain_y)
        Atest_y            = np.concatenate(Atest_y)

        Btrain_x, Btest_x, Btrain_y, Btest_y, B_cursorPos, _  = getTrainTest(dayB, train_size = train_size, sigma = sigma, blocks = dayB_blocks, task = dayB_task, 
                                                                             returnFlattened = True, returnCursor = True)    

        Btrain_x, Btest_x  = get_BlockwiseMeanSubtracted(Btrain_x, Btest_x, concatenate = True)
        Btrain_y           = np.concatenate(Btrain_y)
        Btest_y            = np.concatenate(Btest_y)
        B_cursorPos        = np.concatenate(B_cursorPos)
        targetPos          = Btrain_y + B_cursorPos   # not used by the sweep itself; kept for interactive inspection

        # Reference decoder: fit on day A's training split (the same split is passed as the "test" set here).
        # <full_decoder> must stay a notebook-level name: HMMrecal_parallel reads it.
        full_score, full_decoder = traintest_DecoderSupervised([Atrain_x], [Atrain_x], [Atrain_y], [Atrain_y], meanRecal = False)    
        targLocs                 = get_DiscreteTargetGrid(dayB, gridSize = gridSize, task = dayB_task)
        baseOpts['targLocs']     = targLocs
        
        # parallelized parameter eval for this session-pair:
        scores_dict = HMMrecal_parallel(params_dict, baseOpts, [Btrain_x], [B_cursorPos], [Btest_x], [Btest_y])
        
        # add smoothing and session-specific information
        for param in scores_dict:
            param['idx']       = i
            param['diffs']     = days_apart
            param['smoothing'] = 0 if sigma is None else sigma   # None (no smoothing) is recorded as 0
        scores.extend(scores_dict)
        
    # Fixed: the original condition was truthy on every iteration *except* multiples of the step.
    if (i + 1) % progress_step == 0:
        print(np.round((i + 1) * 100 / len(pairs), 1), '% complete')


In [103]:
# pandas is not imported explicitly anywhere above (it was only reachable
# through the star-imports), so bring it in here to keep this cell runnable.
import pandas as pd

# One row per (session pair, smoothing width, HMM parameter set); columns are the
# keys stored on each result dict: kappa, inflection, exp, score, idx, diffs, smoothing.
scores_df = pd.DataFrame(scores)

In [None]:
import seaborn as sns

# Distribution of scores for each smoothing width (0 = no smoothing);
# every evaluated parameter set of every session pair contributes one point.
sns.swarmplot(data = scores_df, x = 'smoothing', y = 'score', orient = 'v')

In [None]:
import seaborn as sns 

# The sweep results live in <scores_df> (one row per session pair x parameter
# set), so the best setting is found by grouping rather than by indexing a
# multi-dimensional score array with per-parameter sweep lists.
param_cols = ['inflection', 'exp', 'kappa', 'smoothing']

# Median score across session pairs for every parameter combination.
medscores = scores_df.groupby(param_cols)['score'].median()
best      = dict(zip(param_cols, medscores.idxmax()))

print('Best weighting function: logistic with inflection = ', best['inflection'], ' exponent = ', best['exp'])
print('Best kappa: ', best['kappa'])
print('Best smoothing: ', best['smoothing'])

# Left panel: the logistic kappa-weighting function at the best parameters.
plt.subplot(1, 2, 1)
x = np.linspace(0, 400, 3000)
y = 1 / (1 + np.exp(-1 * (x - best['inflection']) * best['exp']))

figSize(5, 10)
plt.plot(x, y)
plt.xlabel('Distance to target')
plt.ylabel('Kappa adjustment factor')
plt.title('Weighting function')

# Right panel: per-session-pair scores obtained with the best parameter combination.
plt.subplot(1, 2, 2)
is_best = np.logical_and.reduce([scores_df[col] == best[col] for col in param_cols])
sns.swarmplot(data = scores_df.loc[is_best, 'score'].values, orient = 'v')
plt.title('Session scores (best parameters)')
plt.ylabel('R^2 (new day)')

In [None]:
# Best kappa: taken from the parameter combination with the highest median
# score across session pairs (position 2 of the grouping key is 'kappa').
scores_df.groupby(['inflection', 'exp', 'kappa', 'smoothing'])['score'].median().idxmax()[2]