In [8]:
import numpy as np
import warnings
warnings.filterwarnings('ignore')
# from J_broja_PID import pid, BROJA_2PID_Exception
from BROJA_2PID import pid, BROJA_2PID_Exception

In [9]:
# EqualState: discretise x into num_state (approximately) equally populated states.
def EqualState(x, num_state):
    """Label every sample of ``x`` with one of ``num_state`` quantile-like states.

    The sorted data are cut every ``int(len(x)/num_state - 0.5)`` samples.
    That expression truncates toward zero (it is not rounding), which keeps
    the largest edge index ``num_state * binlen`` below ``len(x)``.  The
    price is that the last state also absorbs the remaining samples, so the
    state populations are only approximately equal when ``len(x)`` is small.
    Samples equal to an edge value are put in the higher state.

    Parameters
    ----------
    x : array_like
        Samples to discretise (assumed to be a 1-D numpy array, since it is
        compared element-wise against each edge).
    num_state : int
        Number of states.

    Returns
    -------
    xstate : ndarray of int
        State index in ``0 .. num_state-1`` for every sample of ``x``.
    edges : ndarray
        The ``num_state + 1`` sorted values at which the data were cut; only
        the first ``num_state`` are used for labelling.
    """
    ordered = np.sort(x)
    step = int(len(x) / num_state - 0.5)
    edges = ordered[step * np.arange(num_state + 1)]

    labels = np.zeros(len(x), dtype=int)
    # Higher edges overwrite earlier labels, so each sample ends up with the
    # index of the highest edge it reaches.
    for level, threshold in enumerate(edges[:num_state]):
        labels[x >= threshold] = level
    return labels, edges

In [10]:
def Histo3D2Dict(P):
    """Flatten a 3-D array into a dictionary keyed ``(k, i, j)``.

    Entry ``P[i, j, k]`` is stored under key ``(k, i, j)``: the last axis of
    ``P`` moves to the front of every key.  For example, an array indexed
    (x, v, r) becomes a dict keyed (r, x, v).  Every cell is kept, including
    zeros, and values are converted to Python floats.  Keys are inserted with
    the last axis varying fastest (C order over ``P``).
    """
    n0, n1, n2 = np.size(P, 0), np.size(P, 1), np.size(P, 2)
    return {(k, i, j): float(P[i, j, k]) for i, j, k in np.ndindex(n0, n1, n2)}

In [11]:
def PIfunc(r, x, v, dt, window, method = 'Beer', perSpike = False):
    """Time-lagged partial information decomposition of I(r; {x, v}; Δt).

    For every time shift Δt in ``window`` the three integer-valued series
    are shifted against each other and their joint distribution is
    estimated with a 3-D histogram.  The mutual information that ``r``
    carries about the pair ``(x, v)`` is then split into redundant, unique
    and synergistic parts.

    Parameters
    ----------
    r, x, v : 1-D integer arrays of equal length
        The three discretised random variables (``r`` is the target).
    dt : float
        Length of one time bin, in seconds.
    window : sequence of two floats
        ``(negshift, posshift)`` in seconds.  The lag axis (time axis of
        the TLPI) runs from ``negshift`` to ``posshift`` in steps of ``dt``.
    method : {'Beer', 'BROJA', 'BROJA_2PID', 'both'}
        Which PID proposal(s) to compute and export:

        * 'Beer': Williams & Beer (2010), https://arxiv.org/abs/1004.2515
        * 'BROJA' or 'BROJA_2PID': Bertschinger et al. (2014),
          https://dx.doi.org/10.3390/e16042161, solved with the BROJA_2PID
          code of Makkeh et al. (2018),
          https://dx.doi.org/10.3390/e20040271
        * 'both': both of the above.
    perSpike : bool
        If True, every value is divided by the mean firing rate of the
        shifted ``r`` (bit/spike instead of bit/s).

    Returns
    -------
    timeshift : ndarray
        Lag axis in seconds.
    Information : dict
        Keys are ``(proposal, term)`` tuples:
        ``('BROJA_2PID', 'SI'|'UIx'|'UIv'|'CI')`` and/or
        ``('Beer', 'Red'|'UIx'|'UIv'|'Syn')``.  Values are arrays aligned
        with ``timeshift``.

    Raises
    ------
    ValueError
        If ``method`` is not one of the accepted names.
    """
    use_broja = method in ('both', 'BROJA', 'BROJA_2PID')
    use_beer = method in ('both', 'Beer')
    if not (use_broja or use_beer):
        raise ValueError(
            "method must be 'Beer', 'BROJA', 'BROJA_2PID' or 'both', got %r" % (method,))

    negshift = window[0]  # second
    posshift = window[1]  # second
    # round() rather than int(): e.g. (0.3 - (-0.3)) / 0.1 evaluates to
    # 5.999..., which int() truncated, silently dropping one lag.
    shiftlen = int(round((posshift - negshift) / dt)) + 1
    timeshift = np.linspace(negshift, posshift, shiftlen)
    # Lag in bins (time-bin axis of the TLPI).  Rounded to the nearest integer:
    # an integer linspace dtype floors/truncates values like 2.9999 to 2, and
    # int16 overflowed beyond 32767 bins.
    bitshift = np.rint(np.linspace(negshift / dt, posshift / dt, shiftlen)).astype(int)

    nshift = len(bitshift)
    Information = dict()
    if use_broja:
        for term in ('SI', 'UIx', 'UIv', 'CI'):
            Information[('BROJA_2PID', term)] = np.zeros(nshift)
    if use_beer:
        for term in ('Red', 'UIx', 'UIv', 'Syn'):
            Information[('Beer', term)] = np.zeros(nshift)

    parms = {'max_iters': 20}  # solver options forwarded to pid() for 'BROJA_2PID'

    for i, shift in enumerate(bitshift):
        # Shift Δt between inputs.  shift > 0 pairs r[t] with x[t+shift] and
        # v[t+shift], i.e. r is shifted to the positive side.
        if shift > 0:
            xx, vv, rr = x[shift:], v[shift:], r[:-shift]
        elif shift < 0:
            xx, vv, rr = x[:shift], v[:shift], r[-shift:]
        else:
            xx, vv, rr = x, v, r

        # Weight of each joint state from a 3-D histogram with one bin per
        # observed integer value (upper edge max+1 closes the last bin).
        xedges = np.append(np.unique(xx), max(xx) + 1)
        vedges = np.append(np.unique(vv), max(vv) + 1)
        redges = np.append(np.unique(rr), max(rr) + 1)
        dat = np.column_stack((xx, vv, rr))
        N, _ = np.histogramdd(dat, bins=(xedges, vedges, redges))  # 3-D count matrix

        # Probability arrays, all kept 3-D for broadcasting:
        # x on axis 0, v on axis 1, r on axis 2.
        total = np.sum(N)
        px = (np.sum(N, axis=(1, 2)) / total)[:, np.newaxis, np.newaxis]
        pv = (np.sum(N, axis=(0, 2)) / total)[np.newaxis, :, np.newaxis]
        pr = (np.sum(N, axis=(0, 1)) / total)[np.newaxis, np.newaxis, :]
        pxv = (np.sum(N, axis=2) / total)[:, :, np.newaxis]
        pxr = (np.sum(N, axis=1) / total)[:, np.newaxis, :]
        pvr = (np.sum(N, axis=0) / total)[np.newaxis, :, :]
        pxvr = N / total

        if use_broja:
            # Histo3D2Dict re-keys the (x, v, r) array as (r, x, v) tuples,
            # so pid()'s UIY / UIZ correspond to x / v here.
            PDF = Histo3D2Dict(pxvr)
            broja = pid(PDF, cone_solver="ECOS", output=0, **parms)  # Makkeh et al. (2018)
            Information[('BROJA_2PID', 'SI')][i] = broja['SI'] / dt
            Information[('BROJA_2PID', 'UIx')][i] = broja['UIY'] / dt
            Information[('BROJA_2PID', 'UIv')][i] = broja['UIZ'] / dt
            Information[('BROJA_2PID', 'CI')][i] = broja['CI'] / dt

        if use_beer:
            # Empty cells give 0 * log2(0) = nan; nansum drops them (0 log 0 := 0).
            with np.errstate(divide='ignore', invalid='ignore'):
                MIxr = np.nansum(pxr * np.log2(pxr / px / pr)) / dt      # I(r, x; Δt)
                MIvr = np.nansum(pvr * np.log2(pvr / pv / pr)) / dt      # I(r, v; Δt)
                MIxvR = np.nansum(pxvr * np.log2(pxvr / pxv / pr)) / dt  # I(r, {x, v}; Δt)
                # Summing the joint p(x, r) over x gives p(r) * i_spec(R=r; X)
                # for every r (likewise for v).
                PI_xR = np.nansum(pxr * np.log2(pxr / px / pr), axis=(0, 1))
                PI_vR = np.nansum(pvr * np.log2(pvr / pv / pr), axis=(0, 1))
            # I_min redundancy: sum_r p(r) * min(i_spec(r; X), i_spec(r; V));
            # p(r) >= 0 factors out of the min.
            R = np.sum(np.minimum(PI_xR, PI_vR)) / dt
            Information[('Beer', 'Red')][i] = R
            Information[('Beer', 'UIx')][i] = MIxr - R
            Information[('Beer', 'UIv')][i] = MIvr - R
            Information[('Beer', 'Syn')][i] = MIxvR - MIxr - MIvr + R

        if perSpike:
            # bit/s -> bit/spike: divide by the mean firing rate of the shifted r.
            meanFiringRate = np.sum(rr) / (len(rr) * dt)
            for key in Information:
                Information[key][i] = Information[key][i] / meanFiringRate
    return timeshift, Information

In [12]:
def MIfunc(r, x, dt, window, perSpike = False):
    """Time-lagged mutual information I(r, x; Δt) between two integer series.

    For every time shift Δt in ``window`` the two series are shifted against
    each other and their joint distribution is estimated with a 2-D
    histogram (one bin per observed integer value).

    Parameters
    ----------
    r, x : 1-D integer arrays of equal length
        The two discretised variables.
    dt : float
        Length of one time bin, in seconds.
    window : sequence of two floats
        ``(negshift, posshift)`` lag range in seconds, sampled every ``dt``.
    perSpike : bool
        If True, divide by the mean firing rate of the shifted ``r``, so the
        unit becomes bit/spike instead of bit/s.

    Returns
    -------
    timeshift : ndarray
        Lag axis in seconds.
    MIxr : ndarray
        I(r, x; Δt) for each lag, in bit/s (or bit/spike).
    """
    negshift = window[0]  # second
    posshift = window[1]  # second
    # round() rather than int(): float error such as 0.6 / 0.1 == 5.999...
    # made int() drop one lag.
    shiftlen = int(round((posshift - negshift) / dt)) + 1
    timeshift = np.linspace(negshift, posshift, shiftlen)
    # Lag in bins, rounded to nearest (an int16 linspace floored/truncated
    # values like 2.9999 to 2 and overflowed beyond 32767 bins).
    bitshift = np.rint(np.linspace(negshift / dt, posshift / dt, shiftlen)).astype(int)
    MIxr = np.zeros(len(bitshift))

    for i, shift in enumerate(bitshift):
        # shift > 0 pairs r[t] with x[t+shift] (r shifted to the positive side).
        if shift > 0:
            xx, rr = x[shift:], r[:-shift]
        elif shift < 0:
            xx, rr = x[:shift], r[-shift:]
        else:
            xx, rr = x, r
        xedges = np.append(np.unique(xx), max(xx) + 1)
        redges = np.append(np.unique(rr), max(rr) + 1)
        N, _, _ = np.histogram2d(xx, rr, bins=(xedges, redges))
        total = np.sum(N)
        px = (np.sum(N, axis=1) / total)[:, np.newaxis]
        pr = (np.sum(N, axis=0) / total)[np.newaxis, :]
        pxr = N / total

        # The unit is bit/s (or bit/spike below).  Empty cells give
        # 0 * log2(0) = nan, which nansum drops.
        with np.errstate(divide='ignore', invalid='ignore'):
            MIxr[i] = np.nansum(pxr * np.log2(pxr / px / pr)) / dt
        if perSpike:
            MIxr[i] = MIxr[i] / (np.sum(rr) / (len(rr) * dt))
    return timeshift, MIxr

In [13]:
def MIfunc4ISI(r, x, Spike, dt, window, PorP, cut_state_num=6):
    """Time-lagged mutual information between a stimulus and spike ISIs.

    For each lag in ``window``, ``x`` and the spike-count series ``r`` are
    shifted against each other.  Every spike then contributes one sample:
    the ``x`` value of its time bin, paired with the discretised
    inter-spike interval (ISI) around it.

    Parameters
    ----------
    r : 1-D integer array
        Spike count per time bin.
    x : 1-D integer array
        Discretised stimulus, the same length as ``r``.
    Spike : 1-D array
        Spike times.  NOTE(review): the check ``len(SS) - 2 == sum(rr)``
        implies ``Spike`` holds one extra boundary time before the first
        spike and one after the last.  Confirm this against the caller.
    dt : float
        Time-bin length, in seconds.
    window : sequence of two floats
        ``(negshift, posshift)`` lag range in seconds, sampled every ``dt``.
    PorP : {'post', 'pre', 'mid'}
        Which ISI labels a spike: the interval after it ('post'), the one
        before it ('pre'), or the mean of the two state indices ('mid').
    cut_state_num : int
        Number of states ``EqualState`` bins the ISIs into.

    Returns
    -------
    timeshift : ndarray
        Lag axis in seconds.
    MIxr : ndarray
        I(x; ISI) for each lag, divided by ``dt``.

    Raises
    ------
    ValueError
        If ``PorP`` is not 'post', 'pre' or 'mid'.
    """
    if PorP not in ('post', 'pre', 'mid'):
        raise ValueError("PorP must be 'post', 'pre' or 'mid', got %r" % (PorP,))

    negshift = window[0]  # second
    posshift = window[1]  # second
    # round() rather than int(): float error such as 0.6 / 0.1 == 5.999...
    # made int() drop one lag.
    shiftlen = int(round((posshift - negshift) / dt)) + 1
    timeshift = np.linspace(negshift, posshift, shiftlen)
    # Lag in bins, rounded to nearest (an int16 linspace floored/truncated
    # values like 2.9999 to 2 and overflowed beyond 32767 bins).
    bitshift = np.rint(np.linspace(negshift / dt, posshift / dt, shiftlen)).astype(int)
    MIxr = np.zeros(len(bitshift))

    for i, shift in enumerate(bitshift):
        if shift > 0:
            xx = x[shift:]
            rr = r[:-shift]
            nspk = int(np.sum(rr))
            # Keep Spike[0 .. nspk] and append the last Spike entry moved back by the lag.
            SS = np.hstack((Spike[:nspk + 1], Spike[-1] - timeshift[i]))
        elif shift < 0:
            xx = x[:shift]
            rr = r[-shift:]
            nspk = int(np.sum(rr))
            # Prepend -timeshift[i] to the last nspk + 1 Spike entries.
            SS = np.hstack((-timeshift[i], Spike[-nspk - 1:]))
        else:
            xx, rr, SS = x, r, Spike
        if len(SS) - 2 != np.sum(rr):
            # Diagnostic: the number of ISIs does not match the spike count.
            print(timeshift[i], SS[0:2], i)
            print(len(SS), np.sum(rr))

        # isi[k] spans SS[k] .. SS[k+1]: for the spike at SS[k+1], isi[k]
        # precedes it and isi[k+1] follows it.
        isi = np.diff(SS)
        if PorP == 'post':
            new_rr, _ = EqualState(isi[1:], cut_state_num)
        elif PorP == 'pre':
            new_rr, _ = EqualState(isi[:-1], cut_state_num)
        else:  # 'mid'
            post_ISI, _ = EqualState(isi[1:], cut_state_num)
            pre_ISI, _ = EqualState(isi[:-1], cut_state_num)
            new_rr = (pre_ISI + post_ISI) / 2

        # One stimulus sample per spike: x of each bin repeated rr times
        # (assumes rr holds non-negative integer counts).  This replaces a
        # loop over np.squeeze(np.where(rr != 0)), which became a 0-d array
        # and crashed when exactly one bin contained spikes.
        new_xx = np.repeat(xx, rr)

        # Joint histogram with one bin per observed state value.
        xedges = np.append(np.unique(new_xx), max(new_xx) + 1)
        redges = np.append(np.unique(new_rr), max(new_rr) + 1)
        N, _, _ = np.histogram2d(new_xx, new_rr, bins=(xedges, redges))
        total = np.sum(N)
        px = (np.sum(N, axis=1) / total)[:, np.newaxis]
        pr = (np.sum(N, axis=0) / total)[np.newaxis, :]
        pxr = N / total

        # Empty cells give 0 * log2(0) = nan, which nansum drops.
        with np.errstate(divide='ignore', invalid='ignore'):
            MIxr[i] = np.nansum(pxr * np.log2(pxr / px / pr)) / dt
    return timeshift, MIxr

In [14]:
# def STA(r, x, dt, window):
#     negshift = window[0] # second
#     posshift = window[1] # second
#     shiftlen = (posshift-negshift)/dt+1
#     timeshift = np.linspace(negshift,posshift,int(shiftlen))
#     bitshift = np.linspace(negshift/dt,posshift/dt,int(shiftlen),dtype = 'int16') #time-bin-axis of TLPI;
#     MIxr = np.zeros(len(bitshift))

#     for i in range(len(bitshift)):
#         xx=[]
#         rr=[]
#         shift=bitshift[i] 
#         if shift>0:
#             xx=x[shift:]
#             rr=r[:(-1*shift)]
#         elif shift==0:
#             xx=x
#             rr=r
#         elif shift<0:
#             xx=x[:shift]
#             rr=r[(-1*shift):]
#         xedges = np.append(np.unique(xx),(max(xx)+1))
#         redges = np.append(np.unique(rr),(max(rr)+1))
#         N, _, _ = np.histogram2d(xx, rr, bins=(xedges, redges)) 
#         px=(np.sum(N,axis=1)/np.sum(N))[:, np.newaxis]
#         pr=(np.sum(N,axis=0)/np.sum(N))[np.newaxis, :]
#         pxr=N/np.sum(N)
        
        
#         MIxr[i]=np.nansum(pxr*np.log2(pxr/px/pr))/dt #I(r,x;\detla t)
#     return timeshift, MIxr