In [1]:
import numpy as np
import math
import matplotlib.pyplot as plt

In [2]:
# Disabled: the MATLAB engine start-up below is wrapped in a string literal so
# it is never executed. Remove the quotes to start a MATLAB session as ``eng``.
'''
#matlab engine
import matlab.engine
eng = matlab.engine.start_matlab()
'''

In [2]:
from mpmath import nsum, exp, inf,fac
import math

In [5]:
import copy

In [3]:
#Poisson Surprise Function
def poissonSurprise(num,meanFiringRate,interval):
    """Poisson surprise of observing ``num`` or more events in a window.

    Computes ``S = -log10(P)``, where ``P`` is the probability that a Poisson
    process with rate ``meanFiringRate`` produces at least ``num`` events in
    a window of length ``interval``::

        P = exp(-r*T) * sum_{i >= num} (r*T)**i / i!

    Parameters
    ----------
    num : int
        Observed event count.
    meanFiringRate : float
        Poisson rate r (events per unit of ``interval``'s time unit).
    interval : float
        Window length T.

    Returns
    -------
    float
        The surprise. ``0.0`` when ``num <= 0``, because then ``P == 1``.
    """
    # mpmath is already used for the series; its log keeps full precision.
    from mpmath import log10

    if num <= 0:
        # P(X >= 0) == 1 -> no surprise. This also avoids fac() of negative
        # integers, which mpmath rejects.
        return 0.0
    expected = meanFiringRate*interval
    P = exp(-expected)*nsum(lambda i: expected**i/fac(i),[num,inf])
    # Take the log in mpmath precision. Converting a tiny P to float first
    # (as math.log does) underflows to 0.0 and raises a math domain error.
    return -float(log10(P))


In [4]:
def getSurprise(end,start,meanFire,timestamps):
    """Poisson surprise of the span ``timestamps[start]..timestamps[end]``.

    ``end - start`` is used as the event count and
    ``timestamps[end] - timestamps[start]`` as the window length; both are
    passed to ``poissonSurprise`` together with the rate ``meanFire``.
    """
    return poissonSurprise(end-start,meanFire,timestamps[end]-timestamps[start])


# Element-wise version of getSurprise over index arrays. Only ``end`` and
# ``start`` may broadcast: ``meanFire`` (arg 2) and ``timestamps`` (arg 3) are
# excluded so every call receives the whole timestamp array instead of a single
# element broadcast against the index array. ``otypes`` lets an empty index
# array produce an empty result instead of raising.
getSurprises = np.vectorize(getSurprise, otypes=[float],
                            excluded={2, 3, 'meanFire', 'timestamps'})


def getBursts(end,start,meanFire,timestamps,crop = False):
    """Pick the sub-span of ``start..end`` with maximal Poisson surprise.

    Parameters
    ----------
    end, start : int
        Indices into ``timestamps`` bounding the candidate burst.
    meanFire : float
        Mean firing rate passed to the surprise computation.
    timestamps : array-like
        Spike times indexed by ``start``/``end``.
    crop : bool, default False
        If True, keep ``end`` fixed and choose the start among
        ``start..end-2``. Otherwise keep ``start`` fixed and choose the end
        among ``start+2..end``.

    Returns
    -------
    tuple
        ``(burstStart, burstEnd)``. Ties go to the first candidate
        (``np.argmax``). When ``end - start < 2`` there is no candidate, and
        ``(start, end)`` is returned unchanged.
    """
    if end - start < 2:
        # Empty candidate range: vectorize/argmax would raise on it.
        return start,end
    if crop:
        # vary the start, keep the end fixed
        surprises = getSurprises(end,np.arange(start,end-1),meanFire,timestamps)
        return int(np.argmax(surprises)) + start,end
    # vary the end, keep the start fixed
    surprises = getSurprises(np.arange(start+2,end+1),start,meanFire,timestamps)
    return start,int(np.argmax(surprises)) + start + 2

    

In [None]:
def potentialBurstIndex(intervals,meanISI):
    """Locate runs of short inter-spike intervals that may form bursts.

    A run starts at interval ``i`` (``i >= 1``; interval 0 is skipped) when
    ``intervals[i] < 0.5*meanISI`` and ``intervals[i+1] < meanISI``. It
    extends over the following intervals until one is ``>= meanISI`` or the
    data ends. Every qualifying ``i`` yields its own run, so runs may overlap.

    Returns
    -------
    processList : list of tuple(int, int)
        ``(i, end)`` pairs. The run covers intervals ``i..end-1``; assuming
        ``intervals[k] = t[k+1] - t[k]``, this is timestamps ``i..end``.
    process05List : list of list of int
        For each run, the interval indices inside it that are
        ``< 0.5*meanISI``.
    """
    processList = []
    process05List = []
    halfISI = 0.5*meanISI
    length = len(intervals)
    for i in range(1,length-1):
        if not (intervals[i] < halfISI and intervals[i+1] < meanISI):
            continue
        # interval i always qualifies; i+1 is in the run and must be checked too
        shortIdx = [k for k in (i,i+1) if intervals[k] < halfISI]
        for j in range(i+2,length):
            if intervals[j] >= meanISI:
                end = j
                break
            if intervals[j] < halfISI:
                shortIdx.append(j)
        else:
            # No terminating interval (or none left to scan): the run includes
            # the last interval, which ends at timestamp index ``length``.
            end = length
        processList.append((i,end))
        process05List.append(shortIdx)

    return processList,process05List

In [34]:
def potentialBursts(timeStamps,maxNumBurstSpikes):
    """Detect candidate bursts with the Poisson-surprise criterion.

    Row 0 of ``timeStamps`` is treated as a reference row (``numSpikes`` is
    ``shape[0] - 1``). The spikes are rows ``1..numSpikes``, and
    ``intervals[k] = timeStamps[k+1] - timeStamps[k]``. Interval 0 is not an
    inter-spike interval and never seeds a burst.

    Algorithm:
      1. Seed a burst at interval ``i`` when ``intervals[i] < 0.5*meanISI``
         and ``intervals[i+1] < meanISI``. This gives the 3 spikes ``i..i+2``.
      2. Extend it one interval at a time while the next interval is
         ``<= meanISI`` and the burst would stay within ``maxNumBurstSpikes``
         spikes. A seeded burst always has 3 spikes.
      3. Recompute the surprise after each extension. Keep the end with the
         largest surprise; later ends win ties.
      4. Resume after the burst's last spike. If nothing was seeded at ``i``,
         try ``i+1``.

    Parameters
    ----------
    timeStamps : numpy.ndarray
        Reference row followed by spike times along axis 0. The last row is
        used as the total time (presumably ``timeStamps[0] == 0``; confirm).
    maxNumBurstSpikes : int
        Maximum number of spikes a burst may grow to.

    Returns
    -------
    burstRanges : list of [int, int]
        ``[start, end]`` timestamp rows. A burst spans rows ``start..end``
        (``end - start + 1`` spikes, intervals ``start..end-1``).
    burstSurprises : list of float
        Maximal surprise of each burst.
    intervals, totalTime, meanFreq, numSpikes
        Intermediate values reused by the later pipeline stages.
    """
    numSpikes = timeStamps.shape[0] - 1
    totalTime = timeStamps[numSpikes]

    intervals = np.diff(timeStamps,n=1,axis = 0)
    meanFreq = numSpikes/totalTime
    meanISI = 1/meanFreq

    burstRanges,burstSurprises = [],[]

    i = 1
    while i < numSpikes - 1:
        if not (intervals[i] < 0.5*meanISI and intervals[i+1] < meanISI):
            # Nothing seeded here: the very next interval is still a candidate.
            i += 1
            continue

        # seed: intervals i and i+1, i.e. spikes i..i+2
        number = 2
        period = intervals[i] + intervals[i+1]
        maxSurprise = poissonSurprise(number,meanFreq,period)
        burstEndIndex = i + 2

        j = i + 2
        # ``number`` counts intervals; adding interval j makes number + 2 spikes.
        while (j <= numSpikes - 1 and intervals[j] <= meanISI
               and number + 2 <= maxNumBurstSpikes):
            period += intervals[j]
            number += 1
            surprise = poissonSurprise(number,meanFreq,period)
            if surprise >= maxSurprise:
                # interval j ends at spike j+1
                burstEndIndex = j + 1
                maxSurprise = surprise
            j += 1

        burstRanges.append([i,burstEndIndex])
        burstSurprises.append(maxSurprise)
        # continue after the burst's last spike so bursts do not share spikes
        i = burstEndIndex + 1

    return burstRanges,burstSurprises,intervals,totalTime,meanFreq,numSpikes

                

In [35]:
#Maximize surprise within the detected bursts by cropping spikes at the beginning
def cropBursts(burstRanges,burstSurprises,intervals,meanFreq):
    """Drop leading spikes of each burst when that does not lower its surprise.

    Each ``[start, end]`` range is a span of timestamp rows, covering intervals
    ``start..end-1`` (assuming ``intervals[k] = t[k+1] - t[k]``). Every later
    start ``k`` in ``start+1..end-1`` is tried with ``end`` fixed. The span
    ``k..end`` has ``end - k`` intervals with total length
    ``sum(intervals[k:end])``. The start with the largest surprise is kept,
    and ties go to the later start. The uncropped burst is kept only if no
    crop reaches its surprise.

    Parameters
    ----------
    burstRanges : list of [int, int]
    burstSurprises : list of float
        Surprise of each uncropped burst (same order as ``burstRanges``).
    intervals : sequence of float
    meanFreq : float
        Rate passed to ``poissonSurprise``.

    Returns
    -------
    cropBurstRanges : list of [int, int]
    cropBurstSurprises : list of float
    cropNumBursts : int
        ``len(cropBurstRanges)``.
    """
    cropBurstRanges,cropBurstSurprises = [],[]
    for b in range(len(burstRanges)):
        burstStart,endIndex = burstRanges[b][0],burstRanges[b][1]
        maxSurprise = burstSurprises[b]
        cropStartIndex = burstStart

        for startIndex in range(burstStart + 1,endIndex):
            # spikes startIndex..endIndex <=> intervals startIndex..endIndex-1
            period = sum(intervals[startIndex:endIndex])
            number = endIndex - startIndex
            surprise = poissonSurprise(number,meanFreq,period)
            if surprise >= maxSurprise:
                cropStartIndex = startIndex
                maxSurprise = surprise

        cropBurstRanges.append([cropStartIndex,endIndex])
        cropBurstSurprises.append(maxSurprise)

    return cropBurstRanges,cropBurstSurprises,len(cropBurstRanges)

In [36]:
#Retain surprising bursts with at least minBurstSpikes (default 3) spikes.
def finalBursts(cropBurstRanges,cropBurstSurprises,cropNumBursts,minSurprise,numSpikes,minBurstSpikes=3):
    """Keep bursts with enough spikes and surprise ``>= minSurprise``.

    Parameters
    ----------
    cropBurstRanges : list of [int, int]
        ``[start, end]`` spans; a span holds ``end - start + 1`` spikes.
    cropBurstSurprises : list of float
        Surprise of each span.
    cropNumBursts : int
        Number of leading entries of the two lists to examine.
    minSurprise : float
        Minimum surprise for a burst to be retained.
    numSpikes : int
        Number of rows of the returned indicator.
    minBurstSpikes : int, default 3
        Minimum number of spikes for a burst to be retained.

    Returns
    -------
    burstIndicator : numpy.ndarray, shape (numSpikes, 1)
        1.0 at rows ``start:end`` (end exclusive) of every retained burst,
        0.0 elsewhere. NOTE(review): these rows coincide with the burst's
        intervals ``start..end-1``; confirm this is the intended meaning.
    finalBurstRanges : list of [int, int]
    finalBurstSurprises : list of float
    finalNumBursts : int
    """
    # np.zeros takes the shape as one tuple. Its second positional argument
    # is the dtype, so np.zeros(numSpikes, 1) raised TypeError.
    burstIndicator = np.zeros((numSpikes,1))
    finalBurstRanges,finalBurstSurprises = [],[]
    for i in range(cropNumBursts):
        burstStart,burstEnd = cropBurstRanges[i][0],cropBurstRanges[i][1]
        surprise = cropBurstSurprises[i]
        if burstEnd - burstStart + 1 >= minBurstSpikes and surprise >= minSurprise:
            finalBurstRanges.append(cropBurstRanges[i])
            finalBurstSurprises.append(surprise)
            burstIndicator[burstStart:burstEnd] = 1

    return burstIndicator,finalBurstRanges,finalBurstSurprises,len(finalBurstRanges)

In [37]:
def detectBursts(timeStamps,minSurprise,maxNumBurstSpikes):
    """Run the whole burst-detection pipeline on ``timeStamps``.

    Stages, in order: ``potentialBursts`` (seed and grow bursts),
    ``cropBursts`` (trim leading spikes), ``finalBursts`` (filter by size
    and ``minSurprise``).

    Returns
    -------
    tuple
        ``(burstIndicator, finalNumBursts, finalBurstRanges,
        finalBurstSurprises, totalTime)``.
    """
    (candidateRanges, candidateSurprises, intervals,
     totalTime, meanFreq, numSpikes) = potentialBursts(timeStamps, maxNumBurstSpikes)

    # cropped == (ranges, surprises, count), which is finalBursts' argument order
    cropped = cropBursts(candidateRanges, candidateSurprises, intervals, meanFreq)
    indicator, keptRanges, keptSurprises, keptCount = finalBursts(*cropped, minSurprise, numSpikes)

    return indicator, keptCount, keptRanges, keptSurprises, totalTime

In [38]:
# Disabled: the MATLAB engine shutdown below is wrapped in a string literal so
# it is never executed. Re-enable it together with the engine start-up cell.
'''
#quit Matlab engine
eng.quit()
'''

'\n#quit Matlab engine\neng.quit()\n'