# In [None]:
import os
import numpy as np
import scipy.io as spio
import matplotlib.pyplot as plt
from sklearn.manifold import MDS
from mpl_toolkits import mplot3d
import seaborn as sns

def generateNeuralEventData(nRemoved=5):
    """Simulate binary event trains for the raster-plot figures.

    Seeds NumPy's global RNG with 2, draws two independent 4x50 uniform
    arrays and thresholds each at 0.5 to obtain 0/1 event matrices.

    Parameters
    ----------
    nRemoved : int, optional
        How many events (taken in row-major order) are dropped to build the
        "fake" train. Defaults to 5. If fewer events exist, all are dropped.

    Returns
    -------
    neuralDataUpd : ndarray of int, shape (4, 50)
        Thresholded events from the first draw.
    fakeNeuralData : ndarray of int, shape (4, 50)
        An independent copy of ``neuralDataUpd`` with its first ``nRemoved``
        events set to 0.
    randomSimmUpd : ndarray of int, shape (4, 50)
        Thresholded events from the second, independent draw.

    Raises
    ------
    ValueError
        If ``nRemoved`` is negative.
    """
    if nRemoved < 0:
        raise ValueError('nRemoved must be non-negative')
    # Reseeding the global RNG makes the generated figures reproducible.
    np.random.seed(2)
    neuralData = np.random.uniform(0, 1, size=(4, 50))
    randomSimmed = np.random.uniform(0, 1, size=(4, 50))
    neuralDataUpd = np.where(neuralData > 0.5, 1, 0)
    randomSimmUpd = np.where(randomSimmed > 0.5, 1, 0)
    # Copy explicitly: a plain assignment would alias neuralDataUpd, so
    # dropping events below would silently corrupt the "real" data too.
    fakeNeuralData = neuralDataUpd.copy()
    rowIdx, colIdx = np.nonzero(neuralDataUpd == 1)
    # Slicing clamps to the number of events actually present.
    fakeNeuralData[rowIdx[:nRemoved], colIdx[:nRemoved]] = 0
    return neuralDataUpd, fakeNeuralData, randomSimmUpd

def generateSpikeRasterPlot(inputNeuralData, saveName):
    """Draw a spike raster and save it to ``saveName``.

    Parameters
    ----------
    inputNeuralData : array-like, 2-D
        Event matrix indexed as (neuron, time bin). Every cell equal to 1 is
        drawn as a vertical tick at x = time bin, y = neuron row.
    saveName : str or path-like
        Output path passed to ``Figure.savefig``.

    Returns
    -------
    None
    """
    events = np.asarray(inputNeuralData)
    # Row index = neuron, column index = time bin, for every event at once.
    unitIdx, spikeTimes = np.nonzero(events == 1)
    fig = plt.figure()
    ax1 = fig.add_subplot(1, 1, 1)
    ax1.scatter(spikeTimes, unitIdx, color='black', marker='|', s=200)
    ax1.set_xlabel('Time', fontsize=18)
    ax1.set_ylabel('Neuron', fontsize=18)
    # Hide numeric tick labels; the axis labels carry the meaning.
    ax1.tick_params(labelbottom=False, labelleft=False)
    fig.savefig(saveName, bbox_inches='tight')
    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)
    return None

def generateFlatMatrix(size=128, center=64, armLength=65):
    """Return a ``size`` x ``size`` float 0/1 mask of an L-shaped track.

    Column ``center`` is 1 for rows ``0 .. armLength-1`` and row ``center``
    is 1 for columns ``0 .. armLength-1``; all other cells are 0. With the
    defaults the two arms meet at (64, 64).

    Parameters
    ----------
    size : int, optional
        Side length of the square grid (default 128).
    center : int, optional
        Index of the marked column and row (default 64).
    armLength : int, optional
        Number of cells in each arm, starting from index 0 (default 65).
    """
    truePosMatrix = np.zeros((size, size))
    # Assign instead of incrementing: the arms overlap at (center, center)
    # and "+= 1" would leave a 2 there instead of a clean 0/1 mask.
    truePosMatrix[:armLength, center] = 1
    truePosMatrix[center, :armLength] = 1
    return truePosMatrix

def generateVariantFlatMatrix():
    """Return a 128x128 float grid marking two short, disjoint track segments.

    Column 64 is 1 for rows 0-39 and row 64 is 1 for columns 0-39; every
    other cell is 0. The segments stop short of (64, 64) and never touch.
    """
    side, pivot, reach = 128, 64, 40
    grid = np.zeros((side, side))
    grid[:reach, pivot] = 1
    grid[pivot, :reach] = 1
    return grid

def generatePositionMatrices():
    """Build the three position grids used for the surface plots.

    Returns
    -------
    truePosMatrixUpd : ndarray
        ``generateFlatMatrix()`` with three rectangular patches added
        (rows 0-39 x cols 62-63, rows 36-64 x cols 64-66, rows 64-66 x
        cols 30-66), then capped at 1.
    reconstructMatrixUpd : ndarray
        A second ``generateFlatMatrix()`` capped at 1.
    flat3 : ndarray
        ``generateVariantFlatMatrix()`` as returned.
    """
    base = generateFlatMatrix()
    reconstructed = np.minimum(generateFlatMatrix(), 1)
    variant = generateVariantFlatMatrix()
    # One layer per extra patch; overlapping cells sum before the cap.
    patches = np.zeros((3, 128, 128))
    patches[0, :40, 62:64] = 1
    patches[1, 36:65, 64:67] = 1
    patches[2, 64:67, 30:67] = 1
    truePos = np.minimum(base + patches.sum(axis=0), 1)
    return truePos, reconstructed, variant

def generateSurfacePlot(inputMatrix, saveName):
    """Render ``inputMatrix`` as a black-and-white heatmap and save it.

    The 'binary' colormap is applied on a fixed [0, 1] range with no
    colorbar, on a 5x5-inch figure.

    Parameters
    ----------
    inputMatrix : 2-D array-like
        Grid to draw; passed straight to ``seaborn.heatmap``.
    saveName : str or path-like
        Output path passed to ``Figure.savefig``.

    Returns
    -------
    None
    """
    fig = plt.figure(figsize=(5, 5))
    ax1 = fig.add_subplot(1, 1, 1)
    sns.heatmap(inputMatrix, cmap='binary', vmin=0, vmax=1, cbar=False, ax=ax1)
    # Keep the tick marks seaborn places but hide their numeric labels.
    ax1.tick_params(labelbottom=False, labelleft=False)
    ax1.set_xlabel('x Position (cm)')
    ax1.set_ylabel('y Position (cm)')
    fig.savefig(saveName, bbox_inches='tight')
    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)
    return None

# Raster figures are currently disabled; uncomment to regenerate them.
#realTrain, fakeTrain, reconstructTrain = generateNeuralEventData()
#generateSpikeRasterPlot(realTrain, 'rasterReal.png')
#generateSpikeRasterPlot(fakeTrain, 'rasterFake.png')
#generateSpikeRasterPlot(reconstructTrain, 'rasterReconstructed.png')

if __name__ == '__main__':
    var1, var2, var3 = generatePositionMatrices()
    # Each figure needs its own file name. With saveName='' matplotlib
    # appends the default extension, so every plot would overwrite the
    # same hidden '.png' file and only the last one would survive.
    generateSurfacePlot(var1, 'positionTrue.png')
    generateSurfacePlot(var2, 'positionReconstructed.png')
    generateSurfacePlot(var3, 'positionVariant.png')
    print('Done')
