In [1]:
%load_ext autoreload
%autoreload 2


import numpy as np
from scipy.io import loadmat
import matplotlib.pyplot as plt
from copy import deepcopy
import glob
import sys

[sys.path.append(f) for f in glob.glob('utils/*')]
from preprocess import DataStruct
from firingrate import raster2FR
from plotting_utils import figSize
from session_utils import *
from recalibration_utils import *
from click_utils import *

from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline


min_nblocks    = 3       # min number of blocks for a session to be include
max_ndays      = 30      # accept all pairs of sessions regardless of time between
min_R2         = 0.1     # subselect days with good decoder transfer performance 


f_dir          = glob.glob('D:/T5_ClosedLoop/historical/*')
sessions_check = np.load('misc_data/OldSessions_check.npy', allow_pickle = True).item()
files          = get_Sessions(f_dir, min_nblocks)

In [6]:
from hmm import *
from hmm_utils import prep_HMMData, get_DiscreteTargetGrid, train_HMMRecalibrate
from sklearn.metrics import r2_score
import itertools


# general settings:
np.random.seed(42)
diffs           = list()
task            = None
train_size      = 0.5
sigma           = 2

# HMM settings: 
gridSize         = 20  
stayProb         = 0.99
inflection_sweep = 70 
exp_sweep        = 0.5
thresh_sweep     = 0.3


#--------------------------

nStates       = gridSize**2
stateTrans    = np.eye(nStates)*stayProb #Define the state transition matrix, which assumes uniform transition probability of transitioning to new state

for x in range(nStates):
    idx                = np.setdiff1d(np.arange(nStates), x)
    stateTrans[x, idx] = (1-stayProb)/(nStates-1)
pStateStart = np.zeros((nStates,1)) + 1/nStates


A_file = files[0]
B_file = A_file
dayA, dayB              = DataStruct(A_file, alignScreens = True), DataStruct(B_file, alignScreens = True)

dayA_blocks             = [sessions_check[A_file] if A_file in sessions_check.keys() else None][0]
dayB_blocks             = [sessions_check[B_file] if B_file in sessions_check.keys() else None][0] 
dayA_task, dayB_task, _ = getPairTasks(dayA, dayB, task = task)

# obtain features and cursorError targets:
Atrain_x, Atest_x, Atrain_y, Atest_y  = getTrainTest(dayA, train_size = train_size, sigma = sigma, blocks = dayA_blocks, task = dayA_task, returnFlattened = True)    
Atrain_x, Atest_x  = get_BlockwiseMeanSubtracted(Atrain_x, Atest_x, concatenate = True)
Atrain_y           = np.concatenate(Atrain_y)
Atest_y            = np.concatenate(Atest_y)

Btrain_x, Btest_x, Btrain_y, Btest_y, B_cursorPos, _  = getTrainTest(dayB, train_size = train_size, sigma = sigma, blocks = dayB_blocks, task = dayB_task, 
                                                                         returnFlattened = True, returnCursor = True)    
Btrain_x, Btest_x  = get_BlockwiseMeanSubtracted(Btrain_x, Btest_x, concatenate = True)
Btrain_y           = np.concatenate(Btrain_y)
Btest_y            = np.concatenate(Btest_y)
B_cursorPos        = np.concatenate(B_cursorPos)
targetPos          = Btrain_y + B_cursorPos

full_score, full_decoder = traintest_DecoderSupervised([Atrain_x], [Atrain_x], [Atrain_y], [Atrain_y], meanRecal = False)    
targLocs                 = get_DiscreteTargetGrid(dayB, gridSize = gridSize, task = dayB_task)

rawDecodeVec = full_decoder.predict(Btrain_x - Btrain_x.mean(axis = 0))

### optimizing HMM

In [None]:
imporCosineTuningcipy, cython
%load_ext Cython

In [None]:
#### %%cython 
import numpy as np
cimport numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
def hmmviterbi_vonmises_parallel(np.ndarray[np.float_t, ndim=2] rawDecodeVec, np.ndarray[np.float_t, ndim=2] stateTransitions, np.ndarray[np.float_t, ndim=2] targLocs, 
                                 np.ndarray[np.float_t, ndim=2] cursorPos,  np.ndarray[np.float_t, ndim=2] pStateStart, float vmKappa, adjustKappa = None, bint verbose = False):
    '''Run viterbi algorithm to find most likely sequence of target states given the cursor position and decoder outputs. Inputs are:

        rawDecodeVec (2D array)     - time x 2 array containing decoder outputs at each timepoint
        stateTransitions (2D array) - transition probabilities; n_states x n_states
        targLocs (2D array)         - n_states x 2 array containing corresponding target locations for each state
        cursorPos (2D array)        - time x 2 array of cursor positions
        pStateStart (vector)        - starting probabilities for each state 
        vmKappa (float)             - precision parameter for Von Mises observation model
        adjustKappa (method)        - fxn for weighting kappa values; defaults to None
        
    NOTE:
        - we could parallelize a lot of the below code (in the for loop) but it doesn't seem to result in substantial speed gains!
            - around ~40 sec run vs 38 sec run during testing
            - would need to fully cythonize to see substantial compute time decreases
    '''
    cdef int numStates           = len(stateTransitions)
    cdef int L                   = rawDecodeVec.shape[0]
    cdef int[::1] currentState   = np.zeros((L, ), dtype= int)
    
    cdef int[:, ::1] pTR        = np.zeros((numStates, L), dtype= int)
    cdef double[:, ::1] logTR   = np.log(stateTransitions)
    cdef np.ndarray v           = np.log(pStateStart)
    cdef np.ndarray vOld        = v
 

    # declare variables used later:
    cdef np.ndarray[np.float_t, ndim=2] tmpV
    cdef np.intp_t count
    cdef int[::1] maxIdx
    cdef double[::1] maxVal
    
    # default kappa adjustment - none 
    if adjustKappa is None:
        def adjustKappa(dist):
            cdef np.ndarray adjusted = np.ones(dist.shape, dtype = np.float)
            return adjusted

    # Precompute some values for speedup: 
    cdef double[::1] observedAngle                       = np.arctan2(rawDecodeVec[:, 1], rawDecodeVec[:, 0])
    cdef np.ndarray[np.float_t, ndim=2] tDists           = np.linalg.norm(targLocs - cursorPos[:, np.newaxis, :], axis = 2)
    cdef np.ndarray[np.float_t, ndim=3] normPosErr       = (targLocs[:, np.newaxis] - cursorPos) / tDists.T[:, :, np.newaxis]
    cdef np.ndarray[np.float_t, ndim=2] expectedAngle    = np.arctan2(normPosErr[:, :, 1], normPosErr[:, :, 0])
    cdef np.ndarray[np.float_t, ndim=2] vmKappa_adjusted = vmKappa * adjustKappa(tDists)
    cdef np.ndarray[np.float_t, ndim=2] vmProbLog        = (vmKappa_adjusted * np.cos(observedAngle - expectedAngle).T) - np.log(2 * np.pi * np.i0(vmKappa_adjusted))

    # loop through the model;  von mises emissions probabilities
    for count in range(L):
        tmpV          = vOld + logTR
        maxIdx        = np.argmax(tmpV, axis = 1).astype('int')
        maxVal        = np.take_along_axis(tmpV, np.expand_dims(maxIdx, axis=-1), axis=-1).squeeze(axis=-1)
        
        pTR[:,count] = maxIdx
        v            = vmProbLog[count, :] + maxVal
        vOld         = v
    
    # decide which of the final states is most probable
    cdef int finalState = np.argmax(v)
    cdef float logP     = v[finalState]

    # Now back trace through the model
    currentState[L - 1] = finalState

    for count in reversed(range(0, L - 1)):
        currentState[count] = pTR[int(currentState[count + 1]), count + 1]
        if currentState[count] == 0 & verbose == True:
            print('stats:hmmviterbi:ZeroTransitionProbability', currentState[ count + 1 ])

    return currentState, logP

In [572]:


x= hmmviterbi_vonmises(rawDecodeVec[:200, :], stateTrans, targLocs, B_cursorPos[:200, :], pStateStart,  4, adjust, False)[0]
y= hmmviterbi_vonmises_parallel(rawDecodeVec[:200, :], stateTrans, targLocs, B_cursorPos[:200, :], pStateStart,  4, adjust, False)[0]
diff = np.linalg.norm(x - y)

print("Difference: ", diff)
assert diff == 0, "Error, outputs are different."
%timeit hmmviterbi_vonmises(rawDecodeVec[:200, :], stateTrans, targLocs, B_cursorPos[:200, :], pStateStart, 4, adjust, verbose = False)
%timeit hmmviterbi_vonmises_parallel(rawDecodeVec[:200, :], stateTrans, targLocs, B_cursorPos[:200, :], pStateStart, 4, adjust, verbose = False)

0.0


#################

In [21]:
import time
from numba import jit


@jit(nopython=True) 
def numba_sum(a, b):
    x = a.shape[0]
    D = np.empty((x, x), dtype=np.float64)
    for i in range(x):
        for j in range(x):
            c = a[i, 0] + b[i, j]
            D[i, j] = c[0]
    return D
    

def EXP_hmmviterbi_vonmises(rawDecodeVec, stateTransitions, targLocs, cursorPos, pStateStart, vmKappa, adjustKappa = None, verbose = False):
    '''Run viterbi algorithm to find most likely sequence of target states given the cursor position and decoder outputs. Inputs are:

        rawDecodeVec (2D array)     - time x 2 array containing decoder outputs at each timepoint
        stateTransitions (2D array) - transition probabilities; n_states x n_states
        targLocs (2D array)         - n_states x 2 array containing corresponding target locations for each state
        cursorPos (2D array)        - time x 2 array of cursor positions
        pStateStart (vector)        - starting probabilities for each state 
        vmKappa (float)             - precision parameter for Von Mises observation model
        adjustKappa (method)        - fxn for weighting kappa values; defaults to None
        
    NOTE:
        - we could parallelize a lot of the below code (in the for loop) but it doesn't seem to result in substantial speed gains!
            - around ~40 sec run vs 38 sec run during testing
            - would need to fully cythonize to see substantial compute time decreases
    '''
    
    start = time.time()
    if adjustKappa is None:
        def adjustKappa(dist):
            return np.ones(dist.shape)

    numStates    = len(stateTransitions)
    L            = rawDecodeVec.shape[0]
    currentState = np.zeros((L, ))
    pTR          = np.zeros((numStates, L))

    # work in log space to avoid numerical issues
    logTR = np.log(stateTransitions)
    tmpV  = np.zeros((numStates, numStates))
    v     = np.log(pStateStart)
    vOld  = np.copy(v)
    print('Setup: ', time.time() - start)

    # Precompute some values for speedup: 
    start             = time.time()
    observedAngle_all = np.arctan2(rawDecodeVec[:, 1], rawDecodeVec[:, 0])
    print('Angle precompute:', time.time() - start)

    # loop through the model;  von mises emissions probabilities
    T_expangle   = 0
    T_kappadjust = 0
    T_logprobs   = 0
    T_updates    = 0
    
    start_forward = time.time()
    for count in range(L):
        # 1. compute distance from the cursor to each target, and expected angle for that target
        start             = time.time()
        tDists        = np.linalg.norm(targLocs - cursorPos[count, :], axis = 1)
        normPosErr    = (targLocs - cursorPos[count, :]) / tDists[:, np.newaxis]
        expectedAngle = np.arctan2(normPosErr[:, 1], normPosErr[:,0])
        T_expangle += time.time() - start

        # 2. compute expected precision based on the base kappa and distance to
        # target (very close distances -> very large dispersion in expected angles)
        start             = time.time()
        vmKappa_adjusted = vmKappa * adjustKappa(tDists)
        T_kappadjust += time.time() - start

        # 3. compute VM probability densities
        start             = time.time()
        observedAngle = observedAngle_all[count]
        vmProbLog     = (vmKappa_adjusted * np.cos(observedAngle - expectedAngle)) - np.log(2*np.pi* scipy.special.i0(vmKappa_adjusted))
        T_logprobs += time.time() - start
        
        start         = time.time()
        #tmpV          = vOld + logTR
        tmpV          = numba_sum(vOld, logTR)
        T_updates += time.time() - start
        
        maxIdx        = np.argmax(tmpV, axis = 1)
        maxVal        = np.take_along_axis(tmpV, np.expand_dims(maxIdx, axis=-1), axis=-1).squeeze(axis=-1)
        pTR[:,count]  = maxIdx
        v             = vmProbLog + maxVal
        vOld          = v
        
        
    print(tmpV.shape)
    print('Expected angle compute:', T_expangle)
    print('Kappa adjust:', T_kappadjust)
    print('posterior probs:', T_logprobs)
    print('updates and storing: ', T_updates)

        

    
    print('Forward loop: ', time.time() - start_forward)
    # decide which of the final states is most probable
    finalState = np.argmax(v)
    logP       = v[finalState]

    # Now back trace through the model
    start_backward = time.time()
    currentState[L - 1] = finalState
    
    for count in reversed(range(0, L - 1)):
        currentState[count] = pTR[int(currentState[count + 1]), count + 1]
        if currentState[count] == 0 & verbose == True:
            print('stats:hmmviterbi:ZeroTransitionProbability', currentState[ count + 1 ])
    print('Reverse loop: ', time.time() - start_backward)
   # return currentState, logP
    return tmpV, logTR

def adjust(dist):
    coef = 1 / (1 + np.exp(-1 * (dist - 70) * 0.5))
    return coef 


In [22]:
tmpV, logTR = EXP_hmmviterbi_vonmises(rawDecodeVec[:10000, :], stateTrans, targLocs, B_cursorPos[:10000, :], pStateStart,  4, adjust, False)

Setup:  0.0
Angle precompute: 0.0


TypingError: Failed in nopython mode pipeline (step: nopython frontend)
[1m[1m[1mNo implementation of function Function(<built-in function getitem>) found for signature:
 
 >>> getitem(float64, Literal[int](0))
 
There are 22 candidate implementations:
[1m      - Of which 22 did not match due to:
      Overload of function 'getitem': File: <numerous>: Line N/A.
        With argument(s): '(float64, int64)':[0m
[1m       No match.[0m
[0m
[0m[1mDuring: typing of intrinsic-call at <ipython-input-21-d759067d843b> (12)[0m
[0m[1mDuring: typing of static-get-item at <ipython-input-21-d759067d843b> (12)[0m
[1m
File "<ipython-input-21-d759067d843b>", line 12:[0m
[1mdef numba_sum(a, b):
    <source elided>
            c = a[i, 0] + b[i, j]
[1m            D[i, j] = c[0]
[0m            [1m^[0m[0m


In [91]:
print('tmpV: ', tmpV.shape)
print('logTR: ', tmpV.shape)

tmpV:  (400, 400)
logTR:  (400, 400)


160000

In [None]:
# working -- dont touch

%%cython 
import numpy as np
cimport numpy as np
cimport cython

@cython.boundscheck(True)#
@cython.wraparound(False)
def hmmviterbi_vonmises_parallel(np.ndarray[np.float_t, ndim=2] rawDecodeVec, np.ndarray[np.float_t, ndim=2] stateTransitions, np.ndarray[np.float_t, ndim=2] targLocs, 
                                 np.ndarray[np.float_t, ndim=2] cursorPos,  np.ndarray[np.float_t, ndim=2] pStateStart, float vmKappa, adjustKappa = None, bint verbose = False):
    '''Run viterbi algorithm to find most likely sequence of target states given the cursor position and decoder outputs. Inputs are:

        rawDecodeVec (2D array)     - time x 2 array containing decoder outputs at each timepoint
        stateTransitions (2D array) - transition probabilities; n_states x n_states
        targLocs (2D array)         - n_states x 2 array containing corresponding target locations for each state
        cursorPos (2D array)        - time x 2 array of cursor positions
        pStateStart (vector)        - starting probabilities for each state 
        vmKappa (float)             - precision parameter for Von Mises observation model
        adjustKappa (method)        - fxn for weighting kappa values; defaults to None
        
    NOTE:
        - we could parallelize a lot of the below code (in the for loop) but it doesn't seem to result in substantial speed gains!
            - around ~40 sec run vs 38 sec run during testing
            - would need to fully cythonize to see substantial compute time decreases
    '''
    cdef int numStates           = len(stateTransitions)
    cdef int L                   = rawDecodeVec.shape[0]
    cdef int[::1] currentState   = np.zeros((L, ), dtype= int)
    
    cdef int[:, ::1] pTR        = np.zeros((numStates, L), dtype= int)
    cdef double[:, ::1] logTR   = np.log(stateTransitions)
    cdef np.ndarray v           = np.log(pStateStart)
    cdef np.ndarray vOld        = v
 

    # declare variables used later:
    cdef np.ndarray[np.float_t, ndim=2] tmpV
    cdef np.intp_t count
    cdef int[::1] maxIdx
    cdef double[::1] maxVal
    
    # default kappa adjustment - none 
    if adjustKappa is None:
        def adjustKappa(dist):
            cdef np.ndarray adjusted = np.ones(dist.shape, dtype = np.float)
            return adjusted

    # Precompute some values for speedup: 
    cdef double[::1] observedAngle                       = np.arctan2(rawDecodeVec[:, 1], rawDecodeVec[:, 0])
    cdef np.ndarray[np.float_t, ndim=2] tDists           = np.linalg.norm(targLocs - cursorPos[:, np.newaxis, :], axis = 2)
    cdef np.ndarray[np.float_t, ndim=3] normPosErr       = (targLocs[:, np.newaxis] - cursorPos) / tDists.T[:, :, np.newaxis]
    cdef np.ndarray[np.float_t, ndim=2] expectedAngle    = np.arctan2(normPosErr[:, :, 1], normPosErr[:, :, 0])
    cdef np.ndarray[np.float_t, ndim=2] vmKappa_adjusted = vmKappa * adjustKappa(tDists)
    cdef np.ndarray[np.float_t, ndim=2] vmProbLog        = (vmKappa_adjusted * np.cos(observedAngle - expectedAngle).T) - np.log(2 * np.pi * np.i0(vmKappa_adjusted))
    #cdef np.ndarray[np.float_t, ndim=2] vmProbLog        = (vmKappa_adjusted * np.cos(observedAngle - expectedAngle).T) - np.log(2 * np.pi * scipy.special.i0(vmKappa_adjusted))

    # loop through the model;  von mises emissions probabilities
    for count in range(L):
        tmpV          = vOld + logTR
        maxIdx        = np.argmax(tmpV, axis = 1).astype('int')
        maxVal        = np.take_along_axis(tmpV, np.expand_dims(maxIdx, axis=-1), axis=-1).squeeze(axis=-1)
        
        pTR[:,count] = maxIdx
        v            = vmProbLog[count, :] + maxVal
        vOld         = v
    
    # decide which of the final states is most probable
    cdef int finalState = np.argmax(v)
    cdef float logP     = v[finalState]

    # Now back trace through the model
    currentState[L - 1] = finalState

    for count in reversed(range(0, L - 1)):
        currentState[count] = pTR[int(currentState[count + 1]), count + 1]
        if currentState[count] == 0 & verbose == True:
            print('stats:hmmviterbi:ZeroTransitionProbability', currentState[ count + 1 ])

    return currentState, logP