# In [1]:
import os
import random
import numpy as np
import scipy.io as spio
import matplotlib.pyplot as plt
from sklearn.manifold import MDS
from mpl_toolkits import mplot3d

def phi(inVolt, threshold):
    """Convert membrane voltage(s) into spike probability(ies).

    The voltage is expressed as a fraction of ``threshold`` and capped from
    above at 1, so anything at or beyond the threshold fires with certainty.
    There is no lower cap: negative voltages give negative values, which a
    ``prob > uniform_draw`` test never accepts.
    """
    return np.minimum(inVolt / threshold, 1)

def galvesLocherbachSimulator(unitCount, connGraph, remain, thresh, timeSteps, seed=2022):
    """Simulate a discrete-time Galves-Locherbach style spiking network.

    At each step every unit spikes with probability ``phi(V_i, thresh)``.
    A unit that spiked is reset to 0, and a silent unit keeps ``remain * V_i``.
    Every unit then receives synaptic input from the units that spiked in this
    step: ``V_i += sum_j connGraph[i, j] * spike_j``.

    Parameters
    ----------
    unitCount : int
        Number of units (rows of the returned raster).
    connGraph : array-like, shape (unitCount, unitCount)
        Connection weights. Row ``i`` is read as the inputs *to* unit ``i``.
        This is the orientation the original ``voltSpike*connGraph`` broadcast
        implied; TODO confirm against how the matrix was built.
    remain : float
        Fraction of voltage a non-spiking unit retains between steps.
    thresh : float
        Voltage at or above which a unit spikes with probability 1 (see ``phi``).
    timeSteps : int
        Number of simulated time bins (columns of the returned raster).
    seed : int or None, optional
        Seed for a private NumPy ``Generator``. Runs are reproducible without
        touching global RNG state. ``None`` draws fresh entropy.

    Returns
    -------
    numpy.ndarray of int, shape (unitCount, timeSteps)
        Binary raster; ``[i, t]`` is 1 when unit ``i`` spiked at step ``t``.
    """
    # The original seeded the stdlib `random` module while sampling from
    # `np.random`, so results were never reproducible. Use one local generator.
    rng = np.random.default_rng(seed)
    weights = np.asarray(connGraph, dtype=float)
    # Preallocate the raster. Growing it with np.concatenate is O(T^2) and
    # crashed when timeSteps == 0.
    spikeTrain = np.zeros((unitCount, timeSteps), dtype=int)

    # Random initial voltages in [0, thresh). Sample floats directly instead
    # of relying on randint implicitly truncating a float upper bound.
    sampV = rng.uniform(0.0, thresh, size=unitCount)
    for t in range(timeSteps):
        spikeProb = phi(sampV, thresh)
        spiked = spikeProb > rng.random(unitCount)
        spikeTrain[:, t] = spiked

        # Leak the *voltage* (not the probability), reset units that fired,
        # and add per-unit input driven by this step's spikes. The old code
        # summed everything into a single scalar shared by all units.
        # As in the original, input is also added to units that were just reset.
        sampV = np.where(spiked, 0.0, remain * sampV) + weights @ spiked.astype(float)
    return spikeTrain

def loadTrueValues(rawFile=''):
    """Load the ground-truth connection matrix from a MATLAB ``.mat`` file.

    Parameters
    ----------
    rawFile : str or path-like, optional
        Path of the ``.mat`` file. The default ``''`` is the original
        placeholder, kept only for backward compatibility. Pass the real path.

    Returns
    -------
    The ``'connectionAll'`` variable as loaded by ``scipy.io.loadmat``.
    Callers use it as a ``(units, units)`` weight matrix.

    Raises
    ------
    KeyError
        If the file has no ``'connectionAll'`` variable. Previously ``.get``
        returned ``None`` here, and the failure surfaced later, far from its cause.
    """
    rmat = spio.loadmat(rawFile)
    return rmat['connectionAll']

def getRealSpikeTrain(realFile='', timeSteps=5000):
    """Load the recorded spike matrix and keep its first ``timeSteps`` columns.

    Parameters
    ----------
    realFile : str or path-like, optional
        Path of the ``.mat`` file. The default ``''`` is the original
        placeholder, kept only for backward compatibility. Pass the real path.
    timeSteps : int, optional
        Number of leading columns to keep (default 5000, the original window).
        Columns are presumably time bins; confirm against the recording.

    Returns
    -------
    ``spaSpike[:, :timeSteps]`` as loaded by ``scipy.io.loadmat``.

    Raises
    ------
    KeyError
        If the file has no ``'spaSpike'`` variable. The old ``.get`` returned
        ``None``, and the slice then failed with an unrelated TypeError.
    """
    remat = spio.loadmat(realFile)
    return remat['spaSpike'][:, :timeSteps]

def buildSimulatedSpikeTrain(remain=1, thresh=18.0475, timeSteps=5000):
    """Simulate a spike train on the ground-truth connectivity.

    Loads the connection matrix with ``loadTrueValues`` and runs
    ``galvesLocherbachSimulator`` with one unit per row of that matrix.

    Parameters
    ----------
    remain : float, optional
        Voltage-retention factor passed to the simulator (default 1, i.e. no leak).
    thresh : float, optional
        Firing threshold passed to the simulator. The default 18.0475 is the
        value this script has always used. Its origin is not recorded here
        (presumably fitted to the recorded data; TODO confirm).
    timeSteps : int, optional
        Number of time bins to simulate. The default 5000 matches the number of
        recorded columns kept by ``getRealSpikeTrain``.

    Returns
    -------
    The ``(units, timeSteps)`` integer spike matrix produced by the simulator.
    """
    graphConn = loadTrueValues()
    unCount = np.shape(graphConn)[0]
    return galvesLocherbachSimulator(unCount, graphConn, remain, thresh, timeSteps)

def generateSpikeRasterPlot(neuralData, saveFigure):
    """Save a raster plot of a binary spike matrix to ``saveFigure``.

    Parameters
    ----------
    neuralData : array-like, shape (nUnits, nTimeBins)
        Spike matrix in which a nonzero ``[i, t]`` marks a spike of unit ``i``
        in bin ``t``. This is the layout the simulator returns; the recorded
        data is assumed to share it (TODO confirm for ``spaSpike``).
    saveFigure : str or path-like
        Destination passed to ``Figure.savefig``.

    Returns
    -------
    None. The figure is written to disk and closed.
    """
    spikes = np.asarray(neuralData)
    # eventplot expects the *positions* of events per row. Passing the 0/1 mask
    # directly plots every entry at x=0 or x=1, so convert each row to the
    # indices of the bins where it is nonzero.
    spikeTimes = [np.flatnonzero(row) for row in spikes]

    fig, ax1 = plt.subplots()
    ax1.eventplot(spikeTimes, colors='k', linelengths=0.3)
    ax1.set_xlabel('Time', fontsize=18)
    # Hide x tick labels explicitly; set_xticklabels('') relies on a fixed
    # formatter and can warn in newer Matplotlib versions.
    ax1.tick_params(axis='x', labelbottom=False)
    ax1.set_ylabel('Neuron', fontsize=18)
    fig.savefig(saveFigure, bbox_inches='tight')
    # Close this specific figure rather than whatever pyplot considers "current".
    plt.close(fig)

# Output paths for the two raster figures. They are blank placeholders here;
# set real destinations before running.
sav01 = sav02 = ''

# Recorded spike train, and one simulated on the ground-truth connectivity.
trueSpikeTrain = getRealSpikeTrain()
simmSpikeTrain = buildSimulatedSpikeTrain()

# Plot the same window (first 25 rows, first 400 columns) of each train.
# The plotting function returns nothing, so fig01/fig02 end up as None.
fig01 = generateSpikeRasterPlot(neuralData=trueSpikeTrain[:25, :400], saveFigure=sav01)
fig02 = generateSpikeRasterPlot(neuralData=simmSpikeTrain[:25, :400], saveFigure=sav02)
print('Done')


# Output: Done
