In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.io
import scipy.optimize
import sklearn
from sklearn.model_selection import RepeatedKFold

## SOD model and standard NB-GLM

In [None]:
#objective function, calculate the maximum likelihood for SOD model
def S0D_sum(x0, K, y, x):
    r=x0[0]
    b0=x0[1]
    loggam=x0[2]
    s=x0[3]
    b=x0[4:]
    b = np.append([b0],b)
    gam = np.exp(loggam)
    
    y = y.reshape(1,-1)
    x_tmp = np.transpose(np.hstack([np.ones((x.shape[0],1)), x]))
    tmp = b.reshape(1,-1).dot(x_tmp)
    mu = (gam*np.exp(tmp)+1)**(-1/gam)
    
    import scipy.special as sc
    LL_sum =sc.gammaln(r+s*mu)+sc.gammaln(y+s-s*mu)+ \
          sc.gammaln(r+y)+sc.gammaln(s)-\
          sc.gammaln(r+y+s)-sc.gammaln(r)-\
          sc.gammaln(s*mu)-sc.gammaln(s-s*mu)-sc.gammaln(y+1)
    
    LL_sum = -1*np.sum(LL_sum)
    return LL_sum


#objective function, calculate the maximum likelihood for NBGLM model
def NBGLM_sum(x0, K, y, x):
    r=x0[0]
    b0=x0[1]
    loggam=x0[2]
    b=x0[4:]
    gam = np.exp(loggam)
    b = np.append([b0],b)
    
    y = y.reshape(1,-1)
    x_tmp = np.transpose(np.hstack([np.ones((x.shape[0],1)), x]))
    tmp = b.reshape(1,-1).dot(x_tmp)
    mu = (gam*np.exp(tmp)+1)**(-1/gam)
  
    
    import scipy.special as sc
    LL_sum=sc.gammaln(r+y)-sc.gammaln(y+1)-sc.gammaln(r)- \
        r/gam*np.log(gam*np.exp(tmp)+1)+y*np.log(1-mu)
    LL_sum = -1*np.sum(LL_sum)
    return LL_sum



#objective function, calculate the maximum likelihood for poission model
def Poisson_sum(x0, K, y, x):
    b0=x0[1]
    b=x0[4:]
    b = np.append([b0],b)
    
    y = y.reshape(1,-1)
    x_tmp = np.transpose(np.hstack([np.ones((x.shape[0],1)), x]))
    mu = b.reshape(1,-1).dot(x_tmp)
    lam = np.exp(mu)
  
    import scipy.special as sc
    LL_sum=y*mu-lam-sc.gammaln(y+1)
    LL_sum = -1*np.sum(LL_sum)
    return LL_sum

## 5-fold cross-validation, repeated 10 times

In [None]:
result = dict()
filename = '62814.mat'
spike_count = scipy.io.loadmat(filename)['spike_count']
spike_count = spike_count.T


for i in range(spike_count.shape[1]):
    #preparing training and test data for each neuron
    y = spike_count[:,i]
    X = np.delete(spike_count, i, axis=1)

    from sklearn.model_selection import RepeatedKFold
    n_splits = 5
    rkf = RepeatedKFold(n_splits=n_splits, n_repeats = 10, random_state = 42)
    
    llr_arr = np.zeros((n_splits,3))  
    result[str(i)] = dict()
    j = 0
    for train_index, test_index in rkf.split(X):
    #     print("TRAIN:", train_index, "TEST:", test_index)
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]

        #value initiation
        r = 5
        b0 = -5
        loggam = 5
        s = 20
        M = spike_count.shape[1]
        K = spike_count.shape[0]
        rng42 = np.random.default_rng(seed=42)
        b = rng42.uniform(-1,1,M-1)
        x0 = [r,b0,loggam,s]
        x0.extend(b)

        bnds_sod = [[1,None],[None,None],[None,None],[1,None]]
        bnds_sod.extend([[-1,1]]*(M-1))

        # 3 models
        op_sod = scipy.optimize.minimize(S0D_sum, x0, args=(K, y_train, X_train), \
                                     bounds = bnds_sod, method='SLSQP',options = {'maxiter':200})
        try:
            op_sod.success
            
        except ValueError:
            print('not converge in sod model')       

        op_nbglm = scipy.optimize.minimize(NBGLM_sum, x0, args=(K, y_train, X_train), \
                                     bounds = bnds_sod, method='SLSQP',options = {'maxiter':200})
        try:
            op_nbglm.success
            
        except ValueError:
            print('not converge in nbglm model')
        
        op_poisson = scipy.optimize.minimize(Poisson_sum, x0, args=(K, y_train, X_train), \
                                     bounds = bnds_sod, method='SLSQP',options = {'maxiter':200})
        try:
            op_poisson.success      
        except ValueError:
            print('not converge in poisson model')
            
        #save model    
        result[str(i)]['op_sod_'+str(j)]=op_sod
        result[str(i)]['op_nbglm_'+str(j)]=op_nbglm
        result[str(i)]['op_poisson_'+str(j)]=op_poisson   

        #load weights
        w_sod = op_sod['x']
        w_nbglm = op_nbglm['x']
        w_poi = op_poisson['x']
        #save weights
        result[str(i)]['w_sod_'+str(j)]=w_sod
        result[str(i)]['w_nbglm_'+str(j)]=w_nbglm
        result[str(i)]['w_poi_'+str(j)]=w_poi  

        #log-likelihood
        ll_sod = -1*S0D_sum(w_sod, K, y_test, X_test)
        ll_ngblm = -1*NBGLM_sum(w_nbglm, K, y_test, X_test)
        ll_poi = -1*Poisson_sum(w_poi, K, y_test, X_test)
        
        result[str(i)]['ll_sod_'+str(j)]=ll_sod
        result[str(i)]['ll_ngblm_'+str(j)]=ll_ngblm
        result[str(i)]['ll_poi_'+str(j)]=ll_poi

        sod_nb = (ll_sod-ll_ngblm)/np.abs(ll_ngblm)
        sod_poi = (ll_sod-ll_poi)/np.abs(ll_poi)
        
        print(ll_sod,ll_ngblm,ll_poi)
        j = j+1
    

In [None]:
import pickle
import random
file_id = filename+'_experiment_data.pkl'

# Saving the objects:
with open(file_id, 'wb') as f:  # Python 3: open(..., 'wb')
    pickle.dump(result, f, protocol=-1)

In [None]:
# Getting back the objects:
with open(file_id,'rb') as f:  # Python 3: open(..., 'rb')
    data = pickle.load(f)

In [None]:
import pandas as pd
ll_sod=[]
ll_ngblm=[]
ll_poi=[]
len_neuron = len(data)
for neuron_id in range(len_neuron):
    for i in range(50): #5-fold cross-validation, repeat 10 times
        ll_sod.append(data[str(neuron_id)]['ll_sod_'+str(i)])
        ll_ngblm.append(data[str(neuron_id)]['ll_ngblm_'+str(i)])
        ll_poi.append(data[str(neuron_id)]['ll_poi_'+str(i)])

        
neuron_id = [[i]*50 for i in range(len_neuron)]
neuron_id = [i for j in neuron_id for i in j]

df = {'neuron_id':neuron_id, 'll_sod':ll_sod,'ll_ngblm':ll_ngblm,'ll_poi':ll_poi}
#data frame for all log-likelihood
df = pd.DataFrame(df)

df['sod_nbglm'] = (df['ll_sod']-df['ll_ngblm'])/np.abs(df['ll_ngblm'])
df['sod_poi'] = (df['ll_sod']-df['ll_poi'])/np.abs(df['ll_poi'])
df.to_csv(filename[:-4]+'.csv')


In [None]:
# Show the DataFrame `df` as this cell's output.
df