#### Example: parallelizing PRI-T with chunking

Here we'll test PRI-T's inference speed with two strategies:
- a standard Viterbi search
- "chunking" the data into contiguous stretches and processing the chunks in parallel
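The first strategy, a generic Viterbi search over a discrete HMM, can be sketched in a few lines of NumPy for readers unfamiliar with it. This is a textbook log-space version written for illustration, not PRI-T's implementation. The function name `viterbi_log` and its argument layout are hypothetical.

```python
import numpy as np

def viterbi_log(log_trans, log_start, log_emit):
    """Most likely hidden-state path of a discrete HMM, computed in log space.

    log_trans[i, j]: log P(state_t = j | state_{t-1} = i), shape (S, S)
    log_start[s]:    log P(state_0 = s), shape (S,)
    log_emit[t, s]:  log P(observation_t | state_t = s), shape (T, S)
    """
    T, S = log_emit.shape
    score = log_start + log_emit[0]            # best log-prob of any path ending in each state
    backptr = np.zeros((T, S), dtype=int)      # best predecessor of each (t, state)
    for t in range(1, T):
        cand = score[:, None] + log_trans      # cand[i, j]: in state i at t-1, move to j
        backptr[t] = cand.argmax(axis=0)
        score = cand.max(axis=0) + log_emit[t]
    path = np.empty(T, dtype=int)
    path[-1] = score.argmax()
    for t in range(T - 1, 0, -1):              # follow back-pointers from the end
        path[t - 1] = backptr[t, path[t]]
    return path, score.max()

# Two "sticky" states; observations drift from favouring state 0 to favouring state 1.
trans = np.array([[0.9, 0.1], [0.1, 0.9]])
start = np.array([0.5, 0.5])
emit = np.array([[0.8, 0.2], [0.7, 0.3], [0.2, 0.8], [0.1, 0.9]])  # P(obs_t | state)
path, logp = viterbi_log(np.log(trans), np.log(start), np.log(emit))
```

On this toy example the decoded path is `[0, 0, 1, 1]`, with a joint probability of about 0.0163. Each time step depends on the previous one, so the loop is inherently sequential. That is what makes splitting the data into independent chunks attractive for parallelism.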

```python
from scipy.io import loadmat
import matplotlib.pyplot as plt
import numpy as np
import re
```

```python
dat = loadmat('exampledat.mat')

# Print the dataset's description, collapsing line breaks and repeated spaces
result = re.sub(' {2,}', ' ', dat['description'][0].replace('\n', ' ')).strip(' ')
print(result)
```

```
Example data from a closed-loop simulation run. Here we simulated 200 seconds (20 ms timebins) of closed-loop control after a nonstationarity occurs in the neural tuning matrix. The decoder is fixed, meaning there is now a mismatch between it and the neural tuning.
```


Let's measure the inference speed of the standard approach. As a performance metric, we'll also compute the correlation between the inferred and ground-truth target positions.
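The baseline cell computes this metric by flattening both the inferred and ground-truth positions before calling `np.corrcoef`. Flattening pools the x and y coordinates into a single correlation. The sketch below assumes both arrays are shaped `(T, 2)`. `target_correlation` and `per_axis_correlation` are hypothetical helper names, not part of PRI-T.

```python
import numpy as np

def target_correlation(inferred, true):
    """Pooled Pearson correlation between inferred and ground-truth target
    positions, both shaped (T, 2). Flattening both arrays the same way keeps
    each inferred coordinate paired with its ground-truth counterpart, so the
    x and y samples are pooled into one correlation."""
    inferred, true = np.asarray(inferred, float), np.asarray(true, float)
    if inferred.shape != true.shape:
        raise ValueError("inferred and true positions must have the same shape")
    return np.corrcoef(inferred.ravel(), true.ravel())[0, 1]

def per_axis_correlation(inferred, true):
    """Separate x and y correlations, which can show whether one axis is decoded worse."""
    inferred, true = np.asarray(inferred, float), np.asarray(true, float)
    return [np.corrcoef(inferred[:, k], true[:, k])[0, 1] for k in range(true.shape[1])]

true = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0], [3.0, 1.0]])
pooled = target_correlation(2 * true + 1, true)   # an affine copy is still perfectly correlated
per_axis = per_axis_correlation(true, true)
```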

```python
import time

from PRIT.prit_utils import generateTargetGrid, generateTransitionMatrix
from PRIT.prit import HMMRecalibration

gridSize    = 20     # number of rows/columns when discretizing the screen
stayProb    = 0.999  # probability that the target stays where it is at any given timestep
vmKappa     = 2      # precision (concentration) parameter for the von Mises distribution
adjustKappa = lambda x: 1 / (1 + np.exp(-32. * x))  # kappa weighting function: logistic curve, 0.5 at x = 0

nStates                 = gridSize**2
targLocs                = generateTargetGrid(gridSize=gridSize, is_simulated=True)
stateTrans, pStateStart = generateTransitionMatrix(gridSize=gridSize, stayProb=stayProb)

# create a PRI-T HMM object
hmm = HMMRecalibration(stateTrans, targLocs, pStateStart, vmKappa, adjustKappa=adjustKappa)

# record inference speed; the whole recording is passed as a single-element list
start = time.perf_counter()
targStates, pTargState = hmm.predict([dat['cursorPos']], [dat['cursorVel']])
baseline_speed   = time.perf_counter() - start
inferredTargLoc  = hmm.targLocs[targStates.astype('int').flatten(), :]
baseline_corr    = np.corrcoef(inferredTargLoc.flatten(), dat['targetPos'].flatten())[0, 1]

print('Baseline target correlation:', baseline_corr, '\n',
      'Baseline speed (sec): ', baseline_speed)
```


```
Baseline target correlation: 0.8218927670060818 
 Baseline speed (sec):  6.469006299972534
```


On my machine, this takes roughly 5-6.5 seconds (6.5 s in the run above). It may be faster or slower on yours.
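The baseline cell builds the model from a transition matrix (`stayProb`) and a von Mises observation model (`vmKappa`, `adjustKappa`), but their internals aren't shown here. The sketch below is one plausible reading under explicit assumptions:

- transitions jump uniformly to any other target;
- the observation is the angle of the cursor velocity relative to the direction of each candidate target;
- kappa is scaled by the `adjustKappa` curve applied to cursor speed.

All function names are hypothetical. PRI-T's own `generateTransitionMatrix` and `HMMRecalibration` may work differently.

```python
import numpy as np

def sticky_transition_matrix(n_states, stay_prob):
    """ASSUMPTION (hypothetical model): keep the current target with probability
    stay_prob, otherwise jump uniformly to one of the other n_states - 1 targets."""
    trans = np.full((n_states, n_states), (1.0 - stay_prob) / (n_states - 1))
    np.fill_diagonal(trans, stay_prob)
    p_start = np.full(n_states, 1.0 / n_states)     # uniform prior over the first target
    return trans, p_start

def adjust_kappa(x):
    # Same logistic curve as adjustKappa in the notebook: 0.5 at x = 0, tends to 1 for large x.
    return 1.0 / (1.0 + np.exp(-32.0 * x))

def vonmises_logpdf(theta, mu, kappa):
    # log[ exp(kappa * cos(theta - mu)) / (2 * pi * I0(kappa)) ]
    return kappa * np.cos(theta - mu) - np.log(2.0 * np.pi * np.i0(kappa))

def emission_loglik(cursor_pos, cursor_vel, targ_locs, base_kappa=2.0):
    """ASSUMPTION (hypothetical model): (T, S) log-likelihood of the cursor's
    velocity angle if it were heading toward each candidate target, with
    kappa = base_kappa * adjust_kappa(cursor speed)."""
    to_targ = targ_locs[None, :, :] - cursor_pos[:, None, :]          # (T, S, 2)
    mu = np.arctan2(to_targ[..., 1], to_targ[..., 0])                 # (T, S) angle to each target
    theta = np.arctan2(cursor_vel[:, 1], cursor_vel[:, 0])[:, None]   # (T, 1) cursor heading
    speed = np.linalg.norm(cursor_vel, axis=1)[:, None]               # (T, 1)
    kappa = base_kappa * adjust_kappa(speed)                          # broadcast over targets
    return vonmises_logpdf(theta, mu, kappa)

# Sticky transitions on a 20 x 20 grid; expected dwell time = bin width / (1 - stay_prob).
trans, p_start = sticky_transition_matrix(20 ** 2, 0.999)
dwell_sec = 0.02 / (1.0 - 0.999)

# Emissions: a cursor at the origin moving in +x favours the target at (1, 0).
ll = emission_loglik(np.zeros((1, 2)), np.array([[1.0, 0.0]]),
                     np.array([[1.0, 0.0], [-1.0, 0.0]]))
```

Under this assumed transition model, `stayProb = 0.999` at 20 ms bins means a target is expected to persist for about 1,000 bins, roughly 20 seconds. The logistic curve as written returns 0.5 at zero input. In this sketch, a stationary cursor would therefore still keep half of the base precision.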

Now let's cut the data into 30-second segments and run PRI-T on them in parallel. **Note**: the first time you run this block, Numba has to compile the function, which adds a one-time overhead at the start. After that, the compiled function is cached, so run the block twice to get a more accurate timing.
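Instead of running the block twice by hand, you can fold the warm-up into a small timing helper. In this minimal sketch, the name `benchmark` and its parameters are hypothetical. The helper makes one or more untimed warm-up calls to absorb one-off costs such as compilation. It then reports the fastest of several timed calls, since background noise can only make a run slower.

```python
import time

def benchmark(fn, *args, warmup=1, repeats=3, **kwargs):
    """Call fn `warmup` times untimed (absorbing one-off costs such as JIT
    compilation), then return the fastest of `repeats` timed calls in seconds."""
    for _ in range(warmup):
        fn(*args, **kwargs)
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()          # monotonic, high-resolution clock
        fn(*args, **kwargs)
        best = min(best, time.perf_counter() - start)
    return best

calls = []
def work(n):
    calls.append(n)                          # record each call so they can be counted
    return sum(range(n))

elapsed = benchmark(work, 100_000, warmup=1, repeats=3)
```

In the notebook, this could be called as `benchmark(hmm.predict, chunked_cursorPos, chunked_cursorVel, parallel=True)`. The keyword argument is passed straight through to `predict`.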


```python
chunklens = 30     # target chunk length in seconds
timestep  = 0.02   # bin width in seconds (20 ms)
parallel  = True   # tell the HMMRecalibration object to parallelize across chunks


# -----------------------------------
# round() avoids float floor-division error: 30 // 0.02 evaluates to 1499.0, not 1500
chunksize  = int(round(chunklens / timestep))
# np.array_split makes num_chunks near-equal pieces, each at least chunklens long
# (unless the recording is shorter than one chunk)
num_chunks = max(1, dat['cursorVel'].shape[0] // chunksize)
chunked_cursorVel = np.array_split(dat['cursorVel'], num_chunks)
chunked_cursorPos = np.array_split(dat['cursorPos'], num_chunks)

start = time.perf_counter()
targStates, pTargState = hmm.predict(chunked_cursorPos, chunked_cursorVel, parallel=parallel)
chunk_speed      = time.perf_counter() - start

inferredTargLoc  = hmm.targLocs[targStates.astype('int').flatten(), :]
chunk_corr       = np.corrcoef(inferredTargLoc.flatten(), dat['targetPos'].flatten())[0, 1]

print('Chunked target correlation: ', chunk_corr,
      '\n Chunked speed (sec):', chunk_speed)
```

```
Chunked target correlation:  0.8263122185175749 
 Chunked speed (sec): 2.292280912399292
```


After the function has compiled, the target correlation is essentially unchanged (0.826 versus 0.822 for the baseline). Inference, however, runs roughly 2.8 times faster: 2.3 seconds versus 6.5 seconds in the runs above.
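Chunking helps because each chunk can be decoded on its own. The sketch below uses hypothetical names `viterbi_path` and `chunked_viterbi`. It splits emission log-likelihoods along time with `np.array_split`, decodes each chunk in a thread pool, and concatenates the paths. Each chunk restarts from the start prior, so the stitched path can differ from the full-sequence path near chunk boundaries. That is one plausible reason the chunked and baseline correlations differ slightly. Python threads only give real speedups when the work releases the GIL. Treat this as an illustration of the split/decode/stitch structure, not of PRI-T's own parallel implementation.

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

def viterbi_path(log_trans, log_start, log_emit):
    # Compact log-space Viterbi for one (T, S) block of emission log-likelihoods.
    T, S = log_emit.shape
    score = log_start + log_emit[0]
    back = np.zeros((T, S), dtype=int)
    for t in range(1, T):
        cand = score[:, None] + log_trans
        back[t] = cand.argmax(axis=0)
        score = cand.max(axis=0) + log_emit[t]
    path = np.empty(T, dtype=int)
    path[-1] = score.argmax()
    for t in range(T - 1, 0, -1):
        path[t - 1] = back[t, path[t]]
    return path

def chunked_viterbi(log_trans, log_start, log_emit, num_chunks, max_workers=None):
    """Split the time axis into num_chunks contiguous pieces, decode each one
    independently (each restarts from log_start), and stitch the paths together."""
    chunks = np.array_split(log_emit, num_chunks)
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        paths = list(pool.map(lambda c: viterbi_path(log_trans, log_start, c), chunks))
    return np.concatenate(paths)

# Toy problem: 5 sticky states, 300 time bins of random emission probabilities.
rng = np.random.default_rng(0)
S, T = 5, 300
trans = np.full((S, S), 0.01 / (S - 1))
np.fill_diagonal(trans, 0.99)
log_trans, log_start = np.log(trans), np.full(S, -np.log(S))
log_emit = np.log(rng.dirichlet(np.ones(S), size=T))

full = viterbi_path(log_trans, log_start, log_emit)
chunked = chunked_viterbi(log_trans, log_start, log_emit, num_chunks=4)
agreement = np.mean(full == chunked)   # fraction of bins where the two paths match
```

`agreement` reports the fraction of bins where the chunked and full-sequence paths match. With `num_chunks=1` the two paths are identical.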

Depending on how many threads your machine can run and on the size of your dataset, different chunk sizes may be faster or slower. I'd recommend 10-30 second segments as a default and tuning from there.
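The hypothetical helper below shows how a segment length maps onto chunks. It reproduces the arithmetic from the chunking cell, assuming the recording has 10,000 bins (200 s at 20 ms). It also shows why the cell should round rather than floor-divide floats: `30 // 0.02` evaluates to `1499.0`, because 0.02 is stored as a binary value slightly above 0.02. For this dataset the result is still 6 chunks. Since `np.array_split` spreads the remainder, each chunk is about 1,667 bins, or roughly 33 seconds rather than exactly 30.

```python
import numpy as np

def plan_chunks(n_bins, chunk_sec, bin_sec):
    """Return the number of chunks and their lengths (in bins) that
    np.array_split produces for a target chunk length in seconds."""
    chunk_bins = int(round(chunk_sec / bin_sec))   # round, don't floor-divide floats
    num_chunks = max(1, n_bins // chunk_bins)      # never ask for zero chunks
    sizes = [len(c) for c in np.array_split(np.arange(n_bins), num_chunks)]
    return num_chunks, sizes

naive = 30 // 0.02                 # float floor division lands just below 1500
n_bins = int(round(200 / 0.02))    # 200 s at 20 ms bins -> 10000 bins
for sec in (10, 20, 30):
    num_chunks, sizes = plan_chunks(n_bins, sec, 0.02)
    print(f"{sec:>2} s target: {num_chunks} chunks of {min(sizes)}-{max(sizes)} bins "
          f"({min(sizes) * 0.02:.2f}-{max(sizes) * 0.02:.2f} s)")
```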