In [None]:


import os
import numpy as np
import pandas as pd
import seaborn as sns
import scipy.io as spio
import matplotlib.lines
import scipy.stats as spst
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from mlxtend.evaluate import permutation_test

def readFromTextFile(inputFile):
    """Collect accuracy values from a model-evaluation text log.

    The log is scanned line by line. A line containing ``'Evaluating'``
    sets the current evaluation tag to the text after its first ``'- '``
    with the last four characters dropped. A later line that contains
    ``'['``, ``']'`` and the current tag is treated as a result line of the
    form ``'<...>: [<loss>, <acc>, ...]'``; its second list entry is parsed
    as a float and collected. Bracketed lines seen before any
    ``'Evaluating'`` line are ignored.

    Parameters
    ----------
    inputFile : str or path-like
        Path of the text log to read.

    Returns
    -------
    list of float
        The collected accuracy values, in file order.
    """
    accList = []
    currString = None  # no 'Evaluating' header seen yet
    with open(inputFile, 'r') as iFile:
        for line in iFile:
            if 'Evaluating' in line:
                currString = line.split('- ')[1][:-4]
            if currString is None:
                # A result line before any header has no tag to match against.
                continue
            if '[' in line and ']' in line and currString in line:
                currArr = line.rstrip().split(': [')[1][:-1]
                # Entry 0 is the loss, entry 1 the accuracy.
                accList.append(float(currArr.split(', ')[1]))
    return accList

def createChanceArray():
    """Draw a 1000-sample null ("chance") distribution of accuracies.

    Returns
    -------
    numpy.ndarray
        Shape ``(1000,)`` array drawn uniformly from the half-open interval
        ``[0.75, 0.8)`` using NumPy's global random state (not seeded here).
    """
    low, high, n_draws = 0.75, 0.8, 1000
    return np.random.uniform(low=low, high=high, size=n_draws)

def generatePermArrP(in1,in2):
    """Colour-code every pooled value by a permutation test against chance.

    The elements of ``in1`` followed by the elements of ``in2`` are each
    tested, as a one-element sample, against a freshly drawn
    ``createChanceArray()``. The test is mlxtend's approximate permutation
    test with 1000 rounds and ``seed=0``. The chance array itself comes from
    NumPy's global random state, so results vary between runs unless that
    state is seeded.

    Returns
    -------
    list of str
        ``'b'`` where the test p-value is below 0.05, otherwise ``'r'``.
        There is one entry per pooled value, in input order.
    """
    pooled = np.concatenate((in1,in2),axis=0)

    def _colour(value):
        pval = permutation_test([value], createChanceArray(),
                                method='approximate', num_rounds=1000, seed=0)
        return 'b' if pval < 0.05 else 'r'

    return [_colour(value) for value in pooled]

def generateArrDF(inArr1,inArr2):
    """Assemble the statistics and plotting arrays for a two-group comparison.

    Parameters
    ----------
    inArr1, inArr2 : array-like of float
        Per-model accuracies of the first and second group.

    Returns
    -------
    tuple
        ``(labelArr1, valArr, colArr, U, p, accCatDF)``, where:

        - ``labelArr1`` holds jittered x positions: group 1 in
          ``[0.65, 0.75)``, group 2 in ``[0.95, 1.05)``.
        - ``valArr`` is the pooled accuracies (group 1 first).
        - ``colArr`` holds the per-value colours from ``generatePermArrP``.
        - ``U`` is the Mann-Whitney U statistic.
        - ``p`` is the Mann-Whitney p-value multiplied by 2, capped at 1 and
          rounded to 3 decimals.
        - ``accCatDF`` is a DataFrame with float columns ``'Region'``
          (0 for group 1, 1 for group 2) and ``'Acc'``.
    """
    U, p = spst.mannwhitneyu(inArr1,inArr2)
    # NOTE(review): the x2 is either a multiple-comparison (Bonferroni) factor
    # or a leftover from older SciPy, whose default p-value was one-sided;
    # current SciPy already defaults to a two-sided test. Confirm which is
    # intended. Either way a p-value cannot exceed 1, so cap it
    # (np.minimum keeps a NaN visible instead of hiding it).
    pReported = round(float(np.minimum(p * 2, 1.0)), 3)

    n1, n2 = len(inArr1), len(inArr2)
    # Horizontal jitter so points with similar accuracy do not overlap.
    labelArr1 = np.concatenate((np.random.uniform(0.65,.75,size=(n1,)),
                                np.random.uniform(0.95,1.05,size=(n2,))),axis=0)
    valArr = np.concatenate((inArr1,inArr2),axis=0)
    colArr = generatePermArrP(inArr1,inArr2)

    # Integer group codes stacked next to float accuracies -> float columns.
    regionCodes = np.concatenate(([0] * n1, [1] * n2),axis=0)
    accCatDF = pd.DataFrame(np.column_stack((regionCodes,valArr)),
                            columns=['Region','Acc'])
    return labelArr1,valArr,colArr,U,pReported,accCatDF

def createPlot(file1,file2,saveFig):
    """Plot HPC vs mPFC decoding accuracies side by side and save the figure.

    Parameters
    ----------
    file1 : str or path-like
        Evaluation log for the HPC models, parsed by ``readFromTextFile``.
    file2 : str or path-like
        Evaluation log for the mPFC models, parsed the same way.
    saveFig : str or path-like
        Destination passed to ``Figure.savefig``.

    The left panel scatters each model's accuracy at a jittered x position.
    Points are blue when their permutation test against the chance array
    gives p < 0.05 and red otherwise (colours from ``generatePermArrP``).
    The right panel is a bar plot with an overlaid swarm. Its legend text
    reports the Mann-Whitney U, the p-value returned by ``generateArrDF``,
    and each region's mean and standard error of the mean.
    """
    varHPC = np.asarray(readFromTextFile(file1))
    mH = round(np.mean(varHPC),3)
    # Standard error of the mean (np.std with ddof=0, divided by sqrt(n)).
    sH = round(np.std(varHPC)/np.sqrt(len(varHPC)),3)
    varPFC = np.asarray(readFromTextFile(file2))
    mP = round(np.mean(varPFC),3)
    sP = round(np.std(varPFC)/np.sqrt(len(varPFC)),3)
    arrLab,arrVal,arrCol,statU,statP,valDF = generateArrDF(varHPC,varPFC)
    lStr = f'U = {statU}, p = {statP}\nHPC Mean = {mH}, HPC SEM = {sH}\nPFC Mean = {mP}, PFC SEM = {sP}'
    # Must match generatePermArrP: 'b' when p < 0.05, 'r' otherwise.
    legElem1 = [Line2D([0],[0],color='white',marker='o',markersize=12,markerfacecolor='b',label='p < 0.05'),
                Line2D([0],[0],color='white',marker='o',markersize=12,markerfacecolor='r',label='p >= 0.05'),
               ]
    # Invisible handle: the legend is used only to show the statistics text.
    legElem2 = [Line2D([0],[0],color='white',linestyle='',
                       label=lStr),
               ]

    fig = plt.figure(figsize=(20,20))
    ax1 = fig.add_subplot(2,4,1)
    ax2 = fig.add_subplot(2,2,2)

    # Left panel: per-model scatter with permutation-test colouring.
    ax1.scatter(arrLab,arrVal,color=arrCol)
    ax1.set_xlim([.4,1.7])
    ax1.set_ylim([0,1])
    # Pin ticks to the centres of generateArrDF's jitter bands
    # ([0.65, 0.75) and [0.95, 1.05)) rather than relying on auto-ticks.
    ax1.set_xticks([0.7,1.0])
    ax1.set_xticklabels(['HPC','mPFC'],fontsize=16)
    ax1.set_ylabel('Start Decoding Accuracy',fontsize=16)
    ax1.legend(handles=legElem1,fontsize=12,loc='upper right')

    # Right panel: group means (bars) with individual models overlaid.
    sns.barplot(x='Region',y='Acc',data=valDF,palette=['silver','dimgrey'],capsize=.1,ax=ax2)
    sns.swarmplot(x='Region',y='Acc',data=valDF, color="0", alpha=1,ax=ax2)
    ax2.set_title('Model Accuracies',fontsize=18)
    ax2.set_ylabel('Start Decoding Accuracy',fontsize=16)
    ax2.set_xlabel('')
    # NOTE(review): this panel labels the first region 'CA1' while the left
    # panel says 'HPC' -- confirm the naming is intentional.
    ax2.set_xticklabels(['CA1','mPFC'],fontsize=16)
    ax2.set_ylim([0,1.2])
    ax2.legend(handles=legElem2,fontsize=14,loc='upper right')

    fig.savefig(saveFig,bbox_inches='tight')
    plt.close(fig)

# Fill these in before running: the HPC evaluation log, the mPFC evaluation
# log (both read by readFromTextFile via createPlot), and the output image path.
hpcF = ''
pfcF = ''
savF = ''

if __name__ == '__main__':
    if not (hpcF and pfcF and savF):
        raise ValueError('Set hpcF, pfcF and savF to real paths before running.')
    createPlot(hpcF,pfcF,savF)

