In [7]:
import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.io
import sklearn

## Objective functions: SOD model, standard NB-GLM, and Poisson GLM

In [8]:
# Objective function: negative log-likelihood of the SOD model (minimised during fitting)
def S0D_sum(x0, K, y, x):
    """Negative log-likelihood of the SOD model, summed over all samples.

    Per sample, with mu = (gam * exp(eta) + 1) ** (-1 / gam) and
    eta = b0 + x @ b, the log-likelihood is

        lnΓ(r+y) − lnΓ(y+1) − lnΓ(r)
        + lnB(r + s·mu, y + s·(1−mu)) − lnB(s·mu, s·(1−mu)),

    written out below with gammaln terms.

    Parameters
    ----------
    x0 : sequence of float
        Packed parameters ``[r, b0, loggam, s, b_1, ..., b_p]``; ``gam = exp(loggam)``.
    K : any
        Unused; kept so all three objectives share one signature.
    y : ndarray
        Observed counts; flattened to one value per sample.
    x : ndarray, shape (n_samples, p)
        Predictors; a column of ones is prepended for the intercept ``b0``.

    Returns
    -------
    float
        ``-sum(log-likelihood)``.
    """
    import scipy.special as sc

    r = x0[0]
    b0 = x0[1]
    loggam = x0[2]
    s = x0[3]
    b = np.append([b0], x0[4:])
    gam = np.exp(loggam)

    y = y.reshape(1, -1)
    x_tmp = np.transpose(np.hstack([np.ones((x.shape[0], 1)), x]))
    tmp = b.reshape(1, -1).dot(x_tmp)

    # mu = (gam*exp(tmp) + 1)**(-1/gam), evaluated in log space.
    # Forming `gam*exp(tmp) + 1` directly rounds to exactly 1.0 whenever
    # gam*exp(tmp) < ~1e-16 (e.g. small gam, since loggam is unbounded).
    # mu would then become 1.0 and gammaln(s*(1-mu)) = gammaln(0) = inf.
    # logaddexp(0, z) = log(1 + exp(z)) and -expm1(log_mu) = 1 - mu stay accurate.
    log_mu = -np.logaddexp(0.0, loggam + tmp) / gam
    mu = np.exp(log_mu)
    one_minus_mu = -np.expm1(log_mu)

    alpha = s * mu            # first Beta shape
    beta = s * one_minus_mu   # second Beta shape (alpha + beta == s)

    LL_sum = (sc.gammaln(r + alpha) + sc.gammaln(y + beta)
              + sc.gammaln(r + y) + sc.gammaln(s)
              - sc.gammaln(r + y + s) - sc.gammaln(r)
              - sc.gammaln(alpha) - sc.gammaln(beta) - sc.gammaln(y + 1))

    return -1 * np.sum(LL_sum)


# Objective function: negative log-likelihood of the NB-GLM model (minimised during fitting)
def NBGLM_sum(x0, K, y, x):
    """Negative log-likelihood of the negative-binomial GLM, summed over samples.

    Per sample, with mu = (gam * exp(eta) + 1) ** (-1 / gam) and
    eta = b0 + x @ b, the log-likelihood is

        lnΓ(r+y) − lnΓ(y+1) − lnΓ(r) + r·log(mu) + y·log(1−mu).

    Parameters
    ----------
    x0 : sequence of float
        Packed parameters ``[r, b0, loggam, s, b_1, ..., b_p]``. ``s`` (x0[3])
        is ignored here; the layout matches ``S0D_sum`` so both can share one
        initial guess and one bounds list.
    K : any
        Unused; kept so all three objectives share one signature.
    y : ndarray
        Observed counts; flattened to one value per sample.
    x : ndarray, shape (n_samples, p)
        Predictors; a column of ones is prepended for the intercept ``b0``.

    Returns
    -------
    float
        ``-sum(log-likelihood)``.
    """
    import scipy.special as sc

    r = x0[0]
    b0 = x0[1]
    loggam = x0[2]
    b = np.append([b0], x0[4:])
    gam = np.exp(loggam)

    y = y.reshape(1, -1)
    x_tmp = np.transpose(np.hstack([np.ones((x.shape[0], 1)), x]))
    tmp = b.reshape(1, -1).dot(x_tmp)

    # log(mu) = -log(1 + gam*exp(tmp)) / gam, evaluated stably.
    # Forming `gam*exp(tmp) + 1` rounds to 1.0 when gam*exp(tmp) < ~1e-16.
    # That made the old `r/gam*log(...)` term 0 and `y*log(1-mu)` -inf or NaN,
    # even though the true likelihood is finite there.
    log_mu = -np.logaddexp(0.0, loggam + tmp) / gam
    one_minus_mu = -np.expm1(log_mu)

    # xlogy(0, 0) == 0, so samples with y == 0 never produce 0 * -inf = NaN.
    LL_sum = (sc.gammaln(r + y) - sc.gammaln(y + 1) - sc.gammaln(r)
              + r * log_mu + sc.xlogy(y, one_minus_mu))
    return -1 * np.sum(LL_sum)



# Objective function: negative log-likelihood of the Poisson GLM (minimised during fitting)
def Poisson_sum(x0, K, y, x):
    """Negative Poisson log-likelihood with a log link, summed over samples.

    Per sample: y * eta - exp(eta) - lnΓ(y + 1), where eta = b0 + x @ b.

    Parameters
    ----------
    x0 : sequence of float
        Packed parameters ``[r, b0, loggam, s, b_1, ..., b_p]``. Only ``b0``
        (x0[1]) and ``b`` (x0[4:]) are read; the layout matches the other
        objectives so all three can share one initial guess.
    K : any
        Unused; kept so all three objectives share one signature.
    y : ndarray
        Observed counts; flattened to one value per sample.
    x : ndarray, shape (n_samples, p)
        Predictors; a column of ones is prepended for the intercept.

    Returns
    -------
    float
        ``-sum(log-likelihood)``.
    """
    import scipy.special as sc

    coef = np.concatenate(([x0[1]], x0[4:]))
    design = np.column_stack((np.ones(x.shape[0]), x))
    eta = design @ coef          # linear predictor (log of the Poisson rate)
    counts = y.reshape(-1)

    log_lik = counts * eta - np.exp(eta) - sc.gammaln(counts + 1)
    return -log_lik.sum()

## 5-fold cross-validation, repeated 10 times

In [None]:
# Fit the SOD, NB-GLM and Poisson models for every neuron, predicting its spike
# counts from all other neurons, under repeated k-fold cross-validation.
# For neuron i, result[str(i)] stores, for every fold j:
#   'op_<model>_<j>' -> OptimizeResult, 'w_<model>_<j>' -> fitted parameters,
#   'll_<model>_<j>' -> held-out log-likelihood.
import scipy.optimize  # explicit: `import scipy` alone does not guarantee the submodule is loaded
from sklearn.model_selection import RepeatedKFold

n_splits = 5
n_repeats = 10

result = dict()
filename = '62814.mat'
spike_count = scipy.io.loadmat(filename)['spike_count']
spike_count = spike_count.T  # after transposing, each column is one neuron (see loop below)

M = spike_count.shape[1]  # number of neurons (columns)
K = spike_count.shape[0]  # number of rows; passed to the objectives, which ignore it

# Initial guess [r, b0, loggam, s, b_1..b_{M-1}]. The RNG is re-seeded with 42,
# so the guess is the same for every neuron and fold; build it once.
r = 5
b0 = -5
loggam = 5
s = 20
rng42 = np.random.default_rng(seed=42)
b = rng42.uniform(-1, 1, M - 1)
x0 = [r, b0, loggam, s]
x0.extend(b)

# Bounds: r >= 1, s >= 1, intercept and log-gamma unbounded, each weight in [-1, 1].
# Shared by all three models (NB-GLM ignores s; Poisson ignores r, loggam and s).
bnds_sod = [[1, None], [None, None], [None, None], [1, None]]
bnds_sod.extend([[-1, 1]] * (M - 1))

# A fixed random_state yields the same folds for every neuron.
rkf = RepeatedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=42)


def _fit_model(objective, label, x_init, bounds, fit_args):
    """Minimise `objective` with SLSQP and print a warning if it did not converge."""
    op = scipy.optimize.minimize(objective, x_init, args=fit_args,
                                 bounds=bounds, method='SLSQP',
                                 options={'maxiter': 200})
    # minimize() does not raise on failure; it reports it via .success/.message.
    if not op.success:
        print('not converge in ' + label + ' model:', op.message)
    return op


for i in range(M):
    # preparing training and test data for each neuron
    y = spike_count[:, i]
    X = np.delete(spike_count, i, axis=1)

    result[str(i)] = dict()
    fits = result[str(i)]

    for j, (train_index, test_index) in enumerate(rkf.split(X)):
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
        train_args = (K, y_train, X_train)

        # 3 models, all started from the same initial guess and bounds
        op_sod = _fit_model(S0D_sum, 'sod', x0, bnds_sod, train_args)
        op_nbglm = _fit_model(NBGLM_sum, 'nbglm', x0, bnds_sod, train_args)
        op_poisson = _fit_model(Poisson_sum, 'poisson', x0, bnds_sod, train_args)

        # save models
        fits['op_sod_' + str(j)] = op_sod
        fits['op_nbglm_' + str(j)] = op_nbglm
        fits['op_poisson_' + str(j)] = op_poisson

        # save fitted weights
        w_sod = op_sod['x']
        w_nbglm = op_nbglm['x']
        w_poi = op_poisson['x']
        fits['w_sod_' + str(j)] = w_sod
        fits['w_nbglm_' + str(j)] = w_nbglm
        fits['w_poi_' + str(j)] = w_poi

        # held-out log-likelihood (the objectives return its negative)
        ll_sod = -1 * S0D_sum(w_sod, K, y_test, X_test)
        ll_ngblm = -1 * NBGLM_sum(w_nbglm, K, y_test, X_test)
        ll_poi = -1 * Poisson_sum(w_poi, K, y_test, X_test)
        fits['ll_sod_' + str(j)] = ll_sod
        fits['ll_ngblm_' + str(j)] = ll_ngblm
        fits['ll_poi_' + str(j)] = ll_poi

        print(ll_sod, ll_ngblm, ll_poi)
    

-41.02858476472456 -42.198940829125306 -42.78362769577117
-40.03390524553603 -39.528730410355216 -39.613227881474316




-22.625732061262624 -22.485827367966493 -22.47790191938005
-30.482447450651122 -31.240352720081425 -31.606079051603356
-39.004480277934114 -38.74914620333369 -40.226078632606175


  


-50.560198161653844 -50.52741620380243 -51.7448899412709
-25.295759181167504 -26.401608520463743 -26.39719960779295
-32.111324232505 -32.207638199515905 -31.414824486142706
-36.238639173627256 -35.44877562375478 -35.72290626243766
-37.74943883839737 -38.56124791019712 -38.71159526002425
-36.924790533560135 -37.5751382999786 -37.949753841266855
-39.780473903597056 -39.230950695773586 -39.077139016092985
-23.863403514959977 -24.575044847288318 -25.460520998341824
-38.99717591589957 -38.40977554660476 -37.51107802682662
-31.250543235257563 -32.35238324712871 -31.804383990471734
-32.68362604716579 -32.20749424906162 -32.25559283383093
-36.96852910996792 -36.40159160168667 -36.69935815273332
-43.37836515721167 -44.72920078018332 -45.110315869473276
-44.50567476050649 -43.08397379990408 -45.96066360371096
-20.30374452660887 -20.578347229626946 -20.095234287928626
-34.749795860405044 -36.1474379248085 -37.811140319046864
-28.435899564164323 -28.251211080949503 -28.541268831660556
-39.09312162



-45.94267370024956 -47.944014659353684 -48.13400360953356
-42.0953126558761 -41.288117283827155 -41.17047422599401
-45.02874744949745 -45.19072186017018 -44.65992520932443
-28.318392995191324 -29.17594331646347 -29.126948077880737
-36.57323890998119 -36.133485411010334 -36.71503089165917
-23.12106096489237 -22.9967568913859 -22.951029830938964




-52.90224651951758 -54.268786192748955 -54.330376488928714
-28.66549393867791 -28.45172779866885 -28.816712372613974
-32.76648259327041 -32.30697998930244 -32.70685780878077
-38.92926211039948 -37.6696016212592 -37.221754749829195
-16.89530209821531 -16.996748582382153 -18.004829743030722
-33.98242201916701 -34.467164556719126 -35.777450566495396
-47.680716239699926 -48.16515765214661 -49.03408045845612
-41.94790957015012 -40.72102956931072 -41.30292171176863
-40.28203552403167 -39.368739023162696 -39.55003700849955
-31.322908971522267 -32.00159448883076 -33.18803536190383
-22.056084031563692 -21.3226022027912 -21.41474415473064
-36.47238881155707 -36.335707365408695 -36.945944008932074
-26.68701560854661 -26.750118952143456 -27.327550955614402
-36.58449890397904 -36.12942161967241 -36.48772064338682
-32.77108664332574 -31.45243748139582 -31.287471370437228
-38.2873704484799 -38.66735522345466 -39.73391775214451
-36.38674046129746 -36.26619314015389 -37.928169862356754
-51.768389238018



-78.91239283976313 -78.58844500391045 -78.57311248454884
-78.43452834422308 -77.58374808792684 -77.54866876596421
-86.01226303853846 -86.5527571862053 -86.50331918046614
-82.79635709619247 -82.16742779250269 -82.09515307603657
-70.77107437339978 -70.12393204069261 -70.09159876293059
-71.95871043177405 -71.93786437577886 -71.94258193001613
-70.61892122291965 -70.04211496703776 -70.03800768145676
-73.45359806668219 -72.78180399835342 -72.75864481067681
-72.76546598833532 -72.4987177570047 -72.48677953474382
-78.22637634266167 -78.05400210062416 -78.06269812220943
-79.00962210608981 -78.92571675298896 -78.90501729557359
-79.49331394661417 -79.17277098756732 -79.2052964936748
-7.992574990565071 -8.63631807967046 -8.769820982959699
-12.694985835570245 -18.204928082689193 -18.804908850405244


  


-11.885541141481173 -11.192897885027154 -11.262968252981132
-11.23965761411612 -19.925738932283036 -21.9191464523111
-19.591397936215156 -23.4934141562634 -24.40841587158867
-29.53921589529933 -45.25382294595123 -47.61294537246324
-7.15429692898541 -7.190871709136277 -7.2862222717783265
-7.814294911463172 -8.28635573848368 -8.437922596088974
-20.134764018276755 -24.036219136264997 -24.722375353155684
-2.1037396954198804 -2.9355339577879893 -3.0636362361251557
-28.30066854446308 -30.248248164718948 -30.89774817607399
-16.390528226172997 -21.835064238394395 -22.413968844646806
-6.698843662564393 -7.102563340023312 -7.363229940008953
-2.012564761003232 -2.6684242323412475 -2.757999457544148
-12.181483893766952 -20.703452267357594 -22.893378342770813
-2.586729828359185 -3.853473959767597 -4.089442792393933
-21.74427970873372 -25.848792498973058 -26.752122328731677
-1.9356752475347976 -2.526992800267611 -2.5974588834215333
-28.2960110519738 -44.68638476874342 -47.04917357030311
-11.27866218

-31.138624964878602 -30.34079332664655 -30.930949267119004
-17.07679737107702 -21.778958982541962 -23.671118933730277
-22.263421745311053 -25.76631040253374 -27.574995510897256
-27.16106953460028 -29.556676397787573 -30.135191715661463
-18.42628089397295 -18.42837908975449 -19.207020375911398
-21.320859495655327 -28.72893476283039 -32.31509688795362
-36.545779128870535 -37.43677010457452 -38.828936978195365
-26.346949587917525 -25.15365631785728 -25.114429300500426
-37.34608054101946 -39.45359529823935 -41.6880007809367
-28.232925204491654 -30.704807902296523 -33.639966073581355
-41.72824192345962 -43.17856181445443 -46.3190163993028
-43.12001702945605 -46.51495540829258 -51.257070898316364
-33.64084643919676 -35.04988068088908 -37.50539692537425
-33.19978109434511 -33.175803317873765 -34.58398934672369
-36.16977383631249 -35.85509014343704 -36.22997738523739
-23.91612585447767 -25.81730525742732 -27.15699562232907
-29.065518129824532 -27.982870742255813 -27.872248543976067
-25.4448489

  if __name__ == '__main__':


-30.87084970339604 -32.03777838213487 -33.504989539848886
-42.292391212322855 -43.37786724592682 -45.75278938979627
-28.004842502247964 -29.09554444511044 -30.867233454014823
-29.04654858404195 -31.19040151158149 -33.406349843776134
-53.66611166622827 -55.995209967349076 -59.477192178837484


In [None]:
import pickle
import random

# Persist every fitted model, weight vector and held-out log-likelihood so the
# expensive cross-validation cell need not be re-run.
file_id = filename + '_experiment_data.pkl'
with open(file_id, mode='wb') as out_file:
    # protocol=-1 and HIGHEST_PROTOCOL select the same (newest) pickle protocol
    pickle.dump(result, out_file, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Reload the experiment results written by the previous cell.
# NOTE: pickle.load can execute arbitrary code; only load files this notebook produced.
with open(file_id, mode='rb') as in_file:
    data = pickle.load(in_file)

In [None]:
import os

import pandas as pd

# Flatten the nested results ({str(neuron): {'ll_<model>_<fold>': value, ...}})
# into one row per (neuron, fold). The fold count is read from the stored keys
# rather than hard-coded to 50, so it stays correct if n_splits/n_repeats change.
records = []
for neuron in range(len(data)):
    fits = data[str(neuron)]
    n_folds = sum(key.startswith('ll_sod_') for key in fits)
    for fold in range(n_folds):
        records.append({
            'neuron_id': neuron,
            'll_sod': fits['ll_sod_' + str(fold)],
            'll_ngblm': fits['ll_ngblm_' + str(fold)],
            'll_poi': fits['ll_poi_' + str(fold)],
        })

# data frame for all held-out log-likelihoods
df = pd.DataFrame(records, columns=['neuron_id', 'll_sod', 'll_ngblm', 'll_poi'])

# Relative difference of SOD vs. each baseline; positive means SOD has the higher log-likelihood.
df['sod_nbglm'] = (df['ll_sod'] - df['ll_ngblm']) / np.abs(df['ll_ngblm'])
df['sod_poi'] = (df['ll_sod'] - df['ll_poi']) / np.abs(df['ll_poi'])

# Same stem as the input file, e.g. '62814.mat' -> '62814.csv'.
df.to_csv(os.path.splitext(filename)[0] + '.csv')


In [51]:
# Show the per-(neuron, fold) held-out log-likelihoods and the SOD relative differences.
df

Unnamed: 0,neuron_id,ll_sod,ll_ngblm,ll_poi,sod_nbglm,sod_poi
0,0,-29.556499,-30.526470,-31.347217,0.031775,0.057125
1,0,-21.712441,-19.227504,-20.815993,-0.129239,-0.043065
2,0,-30.231481,-32.303924,-33.962565,0.064155,0.109859
3,0,-25.916129,-27.509982,-28.703596,0.057937,0.097112
4,0,-25.195634,-24.570371,-26.456237,-0.025448,0.047649
...,...,...,...,...,...,...
695,13,-26.624173,-29.691546,-30.920505,0.103308,0.138948
696,13,-16.564053,-15.615662,-15.592336,-0.060733,-0.062320
697,13,-7.524627,-8.044661,-8.114728,0.064643,0.072720
698,13,-10.412726,-13.293654,-15.054642,0.216715,0.308338
