# In [None]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
import scipy.io as spio
from numpy.random import randn
from matplotlib import pyplot as plt

def convertToPosition(inputX, inputY, gridSize=128, binSize=2):
    """Convert a trajectory into a binary, spatially binned occupancy map.

    Each (x, y) sample is truncated to an integer cell of a
    ``gridSize`` x ``gridSize`` grid indexed as ``grid[x, y]`` (x = row,
    y = column). The grid is then pooled into non-overlapping
    ``binSize`` x ``binSize`` blocks, and every block visited at least once
    is set to 1.

    Samples that are non-finite (NaN/inf) or that fall outside
    ``[0, gridSize)`` on either axis are ignored. This includes negative
    coordinates, which previously wrapped around through negative indexing
    and were counted in the wrong bins.

    Parameters
    ----------
    inputX, inputY : array-like of numbers
        Trajectory coordinates in raw grid units. They must have the same
        number of samples.
    gridSize : int, optional
        Side length of the raw grid (default 128).
    binSize : int, optional
        Pooling factor per axis (default 2). It must divide ``gridSize``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(gridSize // binSize, gridSize // binSize)``
        containing only 0.0 and 1.0.

    Raises
    ------
    ValueError
        If the inputs differ in length or ``binSize`` does not divide
        ``gridSize``.
    """
    xs = np.asarray(inputX, dtype=float).ravel()
    ys = np.asarray(inputY, dtype=float).ravel()
    if xs.shape != ys.shape:
        raise ValueError(
            f"inputX and inputY must have the same length ({xs.size} != {ys.size})"
        )
    if binSize <= 0 or gridSize % binSize:
        raise ValueError(f"binSize ({binSize}) must be a positive divisor of gridSize ({gridSize})")

    # Drop non-finite samples before the int cast: casting NaN to int yields an
    # arbitrary huge negative value (and a RuntimeWarning).
    finite = np.isfinite(xs) & np.isfinite(ys)
    xi = xs[finite].astype(int)  # truncation toward zero, as before
    yi = ys[finite].astype(int)

    # Explicit lower bound: negative indices would otherwise wrap to the far edge.
    inside = (xi >= 0) & (xi < gridSize) & (yi >= 0) & (yi < gridSize)

    occupancy = np.zeros((gridSize, gridSize))
    # np.add.at accumulates repeated (x, y) pairs, unlike fancy-index assignment.
    np.add.at(occupancy, (xi[inside], yi[inside]), 1)

    nBins = gridSize // binSize
    binned = occupancy.reshape(nBins, binSize, nBins, binSize).sum(axis=(1, 3))
    # Counts are non-negative integers, so this clamps every visited bin to 1.
    return np.where(binned > 1, 1, binned)

def loadValuesFromFile(fileOfInterest, trialIndex=2, startSample=45):
    """Load one trial's true and estimated trajectories from a MATLAB file.

    Reads the variables ``trueX``, ``trueY``, ``estX`` and ``estY``. For
    each one it takes element ``[0][trialIndex]``, flattens it, and drops
    the first ``startSample`` samples.

    Parameters
    ----------
    fileOfInterest : str or path-like
        Path to the ``.mat`` file, as accepted by ``scipy.io.loadmat``.
    trialIndex : int, optional
        Column index into each variable's first row (default 2, the value
        previously hard-coded).
    startSample : int, optional
        Number of leading samples to discard (default 45, the value
        previously hard-coded).

    Returns
    -------
    list
        ``[truePos, trueX, trueY, estPos, estX, estY]``. The ``*Pos``
        entries are the binned occupancy maps from ``convertToPosition``;
        the others are the trimmed 1-D coordinate arrays.

    Raises
    ------
    KeyError
        If one of the required variables is missing from the file.
    """
    rmat = spio.loadmat(fileOfInterest)

    def _trimmedSeries(varName):
        # Index directly instead of using .get(): a missing variable should
        # raise an error that names it, not "'NoneType' is not subscriptable".
        try:
            cells = rmat[varName]
        except KeyError:
            raise KeyError(f"variable {varName!r} not found in {fileOfInterest!r}") from None
        return cells[0][trialIndex].flatten()[startSample:]

    trueX = _trimmedSeries('trueX')
    trueY = _trimmedSeries('trueY')
    estX = _trimmedSeries('estX')
    estY = _trimmedSeries('estY')
    return [convertToPosition(trueX, trueY), trueX, trueY,
            convertToPosition(estX, estY), estX, estY]

def loadValuesFromFalseFile(inputFalseFile):
    """Load perturbed ("fake") trajectories from a MATLAB file.

    Reads ``fakePosX`` and ``fakePosY`` and treats each entry along their
    first axis as one trajectory. Each trajectory is flattened and converted
    to an occupancy map with ``convertToPosition``.

    Parameters
    ----------
    inputFalseFile : str or path-like
        Path to the ``.mat`` file, as accepted by ``scipy.io.loadmat``.

    Returns
    -------
    tuple
        ``(outBlock, xOut, yOut)``:

        - ``outBlock`` stacks the occupancy maps along a new first axis,
          giving shape ``(n_trajectories, 64, 64)`` with the default
          ``convertToPosition`` settings. If the file holds no trajectories,
          it is an empty ``(0, 64, 64)`` array.
        - ``xOut`` and ``yOut`` are lists of the flattened coordinate arrays.

    Raises
    ------
    KeyError
        If ``fakePosX`` or ``fakePosY`` is missing from the file.
    ValueError
        If the two variables hold different numbers of trajectories.
    """
    fmat = spio.loadmat(inputFalseFile)
    fakedX = fmat['fakePosX']
    fakedY = fmat['fakePosY']
    if len(fakedX) != len(fakedY):
        raise ValueError(
            f"fakePosX and fakePosY hold different numbers of trajectories "
            f"({len(fakedX)} != {len(fakedY)}) in {inputFalseFile!r}"
        )

    xOut = [row.flatten() for row in fakedX]
    yOut = [row.flatten() for row in fakedY]
    if not xOut:
        return np.zeros((0, 64, 64)), xOut, yOut

    # Build every map first, then stack once. The previous approach of
    # concatenating inside the loop was quadratic in the number of trajectories.
    outBlock = np.stack([convertToPosition(xs, ys) for xs, ys in zip(xOut, yOut)])
    return outBlock, xOut, yOut

def generateTrialPlot(inputArray, xPos, yPos, inTitle, saveName, arrowStep=8, save=False):
    """Show an occupancy heatmap with the trajectory overlaid as red arrows.

    Every ``arrowStep`` samples, one arrow is drawn from sample ``a - 1`` to
    sample ``a``.

    Parameters
    ----------
    inputArray : 2-D array
        Occupancy grid, drawn with a binary colormap (rows on the vertical
        axis).
    xPos, yPos : sequence of float
        Trajectory coordinates in heatmap-cell units. ``xPos`` is plotted on
        the vertical axis and ``yPos`` on the horizontal one, matching the
        ``grid[x, y]`` row/column layout of the occupancy map.
    inTitle : str
        Figure title.
    saveName : str or path-like
        Output path. It is used only when ``save`` is True.
    arrowStep : int, optional
        Sample stride between arrows (default 8).
    save : bool, optional
        If True, write the figure to ``saveName`` before showing it. The
        default is False, which keeps the previous behaviour of never saving.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure, previously discarded.
    """
    fig = plt.figure(figsize=(5, 5))
    ax1 = fig.add_subplot(1, 1, 1)
    sns.heatmap(inputArray, cmap=plt.get_cmap('binary'), cbar=False, ax=ax1, zorder=0)
    ax1.set_xticklabels('')
    # NOTE(review): the horizontal axis shows the grid's columns, i.e. the
    # yPos data, yet it is labelled "x Position". Confirm the intended
    # axis convention.
    ax1.set_xlabel('x Position (x2 cm)', fontsize=16)
    ax1.set_yticklabels('')
    ax1.set_ylabel('y Position (x2 cm)', fontsize=16)
    ax1.set_title(inTitle, fontsize=18)
    for a in range(1, len(xPos), arrowStep):
        x0, y0 = xPos[a - 1], yPos[a - 1]
        # matplotlib's arrow() takes (horizontal, vertical), hence y before x.
        ax1.arrow(y0, x0, yPos[a] - y0, xPos[a] - x0, color='r', width=.4)
    if save:
        # Save before show(): some backends release the figure once it is shown.
        fig.savefig(saveName, bbox_inches='tight')
    plt.show()
    return fig

def generateSimmTrialPlot(inputArray, xPos, yPos, inTitle, saveName):
    """Plot a trajectory over an occupancy heatmap.

    In this notebook it is used for the decoder-estimated trajectory. It was
    a line-for-line copy of ``generateTrialPlot``, so it now delegates to
    that function to keep the two plots from drifting apart. The name and
    signature are kept for existing callers.

    Parameters
    ----------
    inputArray, xPos, yPos, inTitle, saveName
        Forwarded unchanged to ``generateTrialPlot``.

    Returns
    -------
    Whatever ``generateTrialPlot`` returns.
    """
    return generateTrialPlot(inputArray, xPos, yPos, inTitle, saveName)

# Plot the actual, decoder-estimated and perturbed trajectories of one recording.
# Fill in the two input paths before running this cell.
rawFile = ''
falseFile = ''
# Optional output paths, indexed by plot order. Missing entries are passed as
# None; the plotting calls below do not write any files.
savName = []
savNamesFake = []
# Upper bound on how many perturbed trajectories to plot.
nPerturbedPlots = 10

if not rawFile or not falseFile:
    raise ValueError('Set rawFile and falseFile to the .mat files to plot before running this cell.')


def _nameOrNone(names, idx):
    """Return ``names[idx]`` if that entry exists, else None."""
    return names[idx] if idx < len(names) else None


truePos, xTrue, yTrue, simmedPos, posX, posY = loadValuesFromFile(rawFile)
falsePos, falseX, falseY = loadValuesFromFalseFile(falseFile)

# Coordinates are halved to match the 2x2 pooled heatmap cells.
# NOTE(review): only the true trajectory is reversed with np.flip; the
# estimated one is not. Confirm this asymmetry is intended.
figVar01 = generateTrialPlot(truePos, np.flip(xTrue / 2), np.flip(yTrue / 2),
                             'Actual Animal Position', _nameOrNone(savName, 0))
figVar02 = generateSimmTrialPlot(simmedPos, posX / 2, posY / 2,
                                 'Estimated Animal Position', _nameOrNone(savName, 1))

# One plot per perturbed trajectory, up to nPerturbedPlots. Previously this was
# ten copy-pasted calls that crashed when the file held fewer trajectories.
perturbedFigs = []
for i in range(min(nPerturbedPlots, len(falseX))):
    perturbedFigs.append(
        generateTrialPlot(falsePos[i], falseX[i] / 2, falseY[i] / 2,
                          'Perturbed Animal Position', _nameOrNone(savNamesFake, i))
    )
print('Done')
