In [None]:
import os
import random
import scipy.stats
import numpy as np
import pandas as pd
import seaborn as sns
import scipy.io as spio
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from mlxtend.evaluate import permutation_test

def loadValuesFromFile(inputFileOfInterest):
    """Load one session's per-trial variables from a MATLAB ``.mat`` file.

    Parameters
    ----------
    inputFileOfInterest : str or path-like
        Path to the ``.mat`` file, read with ``scipy.io.loadmat``.

    Returns
    -------
    dict
        Maps the short keys ``'smithArray'``, ``'trialCorrect'``, ``'journey'``,
        ``'hurstHPC'``, ``'hurstPFC'``, ``'spectHPC'`` and ``'spectPFC'`` to the
        corresponding MATLAB variable, flattened to a 1-D numpy array.

    Raises
    ------
    KeyError
        If any expected MATLAB variable is absent from the file. The message
        names the file and every missing variable.
    """
    # Output key -> variable name as stored in the .mat file.
    matVariableNames = {
        'smithArray': 'smithArray',
        'trialCorrect': 'trialCorrect',
        'journey': 'journey',
        'hurstHPC': 'hurstExponentHPC',
        'hurstPFC': 'hurstExponentPFC',
        'spectHPC': 'singularitySpectrumWidthHPC',
        'spectPFC': 'singularitySpectrumWidthPFC',
    }
    imat = spio.loadmat(inputFileOfInterest)
    # Check up front: imat.get() on a missing name would give None, and the
    # later .flatten() call would fail with an unhelpful AttributeError.
    missing = [matName for matName in matVariableNames.values() if matName not in imat]
    if missing:
        raise KeyError(f'{inputFileOfInterest}: missing MATLAB variable(s) {missing}')
    return {outKey: imat[matName].flatten() for outKey, matName in matVariableNames.items()}

def createColorArray(inCorrect,inJourney):
    """Map each trial to a hex colour from its outcome and journey label.

    Parameters
    ----------
    inCorrect : sequence
        Per-trial outcome. An entry equal to 1 marks a correct trial; any
        other value marks an incorrect one.
    inJourney : sequence
        Per-trial journey label: ``'NW'``, ``'NE'``, ``'SW'`` or ``'SE'``.
        Must be the same length as ``inCorrect``.

    Returns
    -------
    list of str
        One hex colour string per trial, in trial order.

    Raises
    ------
    ValueError
        If the two inputs differ in length or a journey label is not one of
        the four known labels.
    """
    correctPalette = {'NW': '#332288', 'NE': '#117733', 'SW': '#44AA99', 'SE': '#88CCEE'}
    incorrectPalette = {'NW': '#DDCC77', 'NE': '#CC6677', 'SW': '#AA4499', 'SE': '#882255'}
    if len(inCorrect) != len(inJourney):
        raise ValueError(
            f'inCorrect has {len(inCorrect)} trials but inJourney has {len(inJourney)}')
    hexColors = []
    for tr, (correct, journey) in enumerate(zip(inCorrect, inJourney)):
        palette = correctPalette if correct == 1 else incorrectPalette
        # Match with == (not a dict lookup) so labels that arrive as 1-element
        # numpy string arrays (possible for MATLAB cell strings — TODO confirm
        # against the data) still match in a boolean context.
        color = next((hexCol for label, hexCol in palette.items() if journey == label), None)
        if color is None:
            # Fail loudly: silently skipping a trial would leave the colour
            # list shorter than, and misaligned with, the plotted points.
            raise ValueError(f'Trial {tr}: unrecognised journey label {journey!r}')
        hexColors.append(color)
    return hexColors

def generatePlottingDict(inputPathDict):
    """Pick out the values the figure needs from a pooled session dict.

    Parameters
    ----------
    inputPathDict : dict
        Must provide ``'smithArray'``, ``'hurstHPC'`` and ``'hurstPFC'`` as
        objects with a ``.flatten()`` method (e.g. numpy arrays), plus
        ``'trialCorrect'`` and ``'journey'`` for colour assignment.

    Returns
    -------
    dict
        ``'smithArray'``, ``'hurstHPC'`` and ``'hurstPFC'`` flattened, and
        ``'colorArr'`` — the per-trial colour list from ``createColorArray``.
    """
    plottedKeys = ('smithArray', 'hurstHPC', 'hurstPFC')
    plotDict = {key: inputPathDict[key].flatten() for key in plottedKeys}
    plotDict['colorArr'] = createColorArray(inputPathDict['trialCorrect'],
                                            inputPathDict['journey'])
    return plotDict

def _findMatFiles(rootPath):
    """Return the paths of all ``.mat`` files under *rootPath*, in sorted walk order."""
    found = []
    for aRoot, aDirs, aFiles in os.walk(rootPath):
        # Sorting dirs in place makes os.walk descend in a reproducible order,
        # so the pooled data order does not depend on the filesystem.
        aDirs.sort()
        found.extend(os.path.join(aRoot, aFile) for aFile in sorted(aFiles)
                     if aFile.endswith('.mat'))
    return found


def pathWalk(inputRawPath):
    """Load and pool every ``.mat`` session file under a directory tree.

    Parameters
    ----------
    inputRawPath : str or path-like
        Root directory searched recursively with ``os.walk``.

    Returns
    -------
    dict
        The pooled per-trial data passed through ``generatePlottingDict``.

    Raises
    ------
    FileNotFoundError
        If no ``.mat`` file exists under ``inputRawPath``. This includes an
        empty or non-existent path, for which ``os.walk`` yields nothing.
    """
    matFiles = _findMatFiles(inputRawPath)
    if not matFiles:
        raise FileNotFoundError(f'No .mat files found under {inputRawPath!r}')
    fileDicts = [loadValuesFromFile(matPath) for matPath in matFiles]
    # Concatenate each variable once across all files rather than re-growing
    # the arrays per file (which copies the accumulated data every time).
    pathDict = {fkey: np.concatenate([fd[fkey] for fd in fileDicts], axis=0)
                for fkey in fileDicts[0]}
    return generatePlottingDict(pathDict)

def generateRegression(inputX, inputY, xLimits=(0, 1), numPoints=101):
    """Fit a least-squares line of ``inputY`` on ``inputX`` and evaluate it.

    Parameters
    ----------
    inputX, inputY : array_like
        Paired samples passed to ``scipy.stats.linregress``.
    xLimits : (float, float), optional
        Interval over which the fitted line is evaluated (default ``(0, 1)``).
    numPoints : int, optional
        Number of evenly spaced evaluation points, both endpoints included.
        The default of 101 gives a step of 0.01 over the default interval.

    Returns
    -------
    xRange : numpy.ndarray
        Evaluation points.
    yRange : numpy.ndarray
        Fitted-line values at ``xRange``.
    rSquared : float
        Square of the correlation coefficient reported by ``linregress``.
    pValue : float
        The slope test p-value reported by ``linregress``.
    """
    fit = scipy.stats.linregress(inputX, inputY)
    # linspace instead of arange with a float step: whether arange includes
    # the endpoint depends on floating-point rounding.
    xRange = np.linspace(xLimits[0], xLimits[1], numPoints)
    yRange = xRange * fit.slope + fit.intercept
    return xRange, yRange, fit.rvalue ** 2, fit.pvalue

def generateFigure(rawInputPath,savInputPath):
    """Scatter pooled CA1 vs. mPFC Hurst exponents with their regression line.

    Prints R² and the regression p-value, draws one colour-coded point per
    trial plus the dashed fitted line, optionally saves the figure, then
    shows it.

    Parameters
    ----------
    rawInputPath : str or path-like
        Root directory handed to ``pathWalk`` to find the ``.mat`` files.
    savInputPath : str or path-like
        Destination for ``fig.savefig``. If empty (falsy), nothing is saved.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.
    """
    dictToPlot = pathWalk(rawInputPath)
    xPlot, yPlot, squaredR, pVal = generateRegression(dictToPlot['hurstHPC'], dictToPlot['hurstPFC'])
    print(f'R² = {squaredR}, p = {pVal}')
    fig = plt.figure(figsize=(10, 10))
    ax1 = fig.add_subplot(1, 1, 1)
    # A single call draws every trial with its own colour. Calling scatter on
    # the full dataset once per point would stack n identical copies.
    ax1.scatter(dictToPlot['hurstHPC'], dictToPlot['hurstPFC'], c=dictToPlot['colorArr'])
    ax1.plot(xPlot, yPlot, c='k', linewidth=3, linestyle='dashed')
    ax1.set_xlim([0, 1.01])
    ax1.set_ylim([0, 1.01])
    ax1.set_xlabel('Hurst CA1')
    ax1.set_ylabel('Hurst mPFC')
    if savInputPath:
        # Save before show(): depending on the backend, the figure may be
        # closed or cleared once show() returns.
        fig.savefig(savInputPath, bbox_inches='tight')
    plt.show()
    return fig

if __name__ == '__main__':
    # Fill these in before running. rawP is the root directory containing
    # the session .mat files. savP is the output path handed to
    # generateFigure for the saved figure.
    rawP = ''
    savP = ''
    figVar = generateFigure(rawP,savP)
