In [None]:
import os
import numpy as np
import pandas as pd
import scipy.io as spio
import seaborn as sns
import scipy.stats
import random
import matplotlib.pyplot as plt
from mlxtend.evaluate import permutation_test

def getValuesFromFile(fileOfInterest,lkBk):
    """Load a .mat file and extract the observed and null values for one lag.

    The file must contain ``regOnePredTwo`` and ``regTwoPredOne`` (observed
    values, flattened and indexed by lag) and ``nullRegOnePredTwo`` /
    ``nullRegTwoPredOne`` (2-D arrays whose column is selected by lag).

    Args:
        fileOfInterest: Path handed to ``scipy.io.loadmat``.
        lkBk: 1-based look-back index; converted to a 0-based column.

    Returns:
        Tuple ``(obsOneTwo, nullOneTwo, obsTwoOne, nullTwoOne)``: the observed
        ``regOnePredTwo`` value at the lag, the matching column of
        ``nullRegOnePredTwo``, then the same pair for ``regTwoPredOne``.
    """
    contents = spio.loadmat(fileOfInterest)
    col = int(lkBk - 1)  # lkBk counts from 1, numpy indexing from 0

    def observed(key):
        # Observed arrays are flattened, then the lag's entry is picked.
        return contents.get(key).flatten()[col]

    def nullColumn(key):
        # Null arrays keep every row; only the lag's column is taken.
        return contents.get(key)[:, col]

    return (observed('regOnePredTwo'), nullColumn('nullRegOnePredTwo'),
            observed('regTwoPredOne'), nullColumn('nullRegTwoPredOne'))

def buildPlottingDataset(v1,v2,v3,v4,nv1,nv2,nv3,nv4):
    """Assemble the bar-plot table and per-condition null thresholds.

    Args:
        v1, v2, v3, v4: Observed values, one per condition (bar).
        nv1, nv2, nv3, nv4: Null samples for the matching condition. They are
            stacked into a single array, so they should share one length.

    Returns:
        ``(observed, table, thresholds)``:
        ``observed`` is the list ``[v1, v2, v3, v4]``.
        ``table`` is a DataFrame with columns ``'Label'`` (1-4) and
        ``'Granger'`` (the observed values).
        ``thresholds`` holds, per condition, the null mean plus two
        population standard deviations (``np.std`` default ``ddof=0``).
    """
    observed = [v1, v2, v3, v4]
    rows = np.asarray([[label, val] for label, val in enumerate(observed, start=1)])
    table = pd.DataFrame(rows, columns=['Label', 'Granger'])
    # After the transpose each column holds one condition's null samples.
    samplesByCondition = np.asarray([nv1, nv2, nv3, nv4]).T
    thresholds = (np.mean(samplesByCondition, axis=0)
                  + 2 * np.std(samplesByCondition, axis=0))
    return observed, table, thresholds

def createLegendString(reg1,reg2,trueValArr,nullValArr,arrP):
    """Build a multi-line legend summarising the four directed conditions.

    One line is produced per condition, in the order Pro->Pro, Pro->Ret,
    Ret->Pro, Ret->Ret. A condition whose observed value is strictly greater
    than its null threshold is reported as ``p < 0.05``. Every other case,
    including ties and NaN, shows the supplied p-value rounded to three
    decimals as ``p = ...``.

    Args:
        reg1: Name of the source region used in every line.
        reg2: Name of the target region used in every line.
        trueValArr: Four observed values, in the order above.
        nullValArr: Four null thresholds, same order.
        arrP: Four p-values, same order; shown when a value does not exceed
            its threshold.

    Returns:
        The legend text: exactly four lines joined with ``'\\n'``.

    Raises:
        IndexError: If any of the three sequences has fewer than four items.
    """
    conditions = (('Pro', 'Pro'), ('Pro', 'Ret'), ('Ret', 'Pro'), ('Ret', 'Ret'))
    lines = []
    for i, (src, dst) in enumerate(conditions):
        # NOTE(review): 'p < 0.05' comes from the threshold comparison, not
        # from arrP; arrP is only displayed for conditions below threshold.
        if trueValArr[i] > nullValArr[i]:
            pText = 'p < 0.05'
        else:
            # Ties/NaN used to match neither branch and were silently dropped.
            pText = f'p = {round(arrP[i], 3)}'
        lines.append(f'{reg1} {src} --> {reg2} {dst}, {pText}')
    return '\n'.join(lines)

def _nullPValues(trueVals,nullDists):
    """Return one permutation-test p-value per (observed value, null samples) pair.

    Each observed value is tested as a one-element sample against its null
    samples with ``permutation_test`` (approximate method, 1000 rounds,
    seed 0 so repeated runs give the same p-values).
    """
    return [permutation_test([trueVal],nullDist,method='approximate',
                             num_rounds=1000,seed=0)
            for trueVal,nullDist in zip(trueVals,nullDists)]


def _drawGrangerPanel(ax,grangerDF,nullThresholds,xTickLabels,title,legendText):
    """Draw one bar panel of observed values with a null-threshold line per bar.

    Args:
        ax: Matplotlib axes to draw on.
        grangerDF: Table with ``'Label'`` and ``'Granger'`` columns.
        nullThresholds: One threshold per bar, drawn as a short horizontal
            line spanning that bar.
        xTickLabels: Tick labels, one per bar.
        title: Panel title.
        legendText: Text shown in the upper-left legend (``loc=2``).
    """
    sns.barplot(x='Label',y='Granger',data=grangerDF,capsize=.1,ax=ax)
    # Bars sit at x = 0, 1, 2, ...; each threshold line spans +/-0.45 around
    # its bar (the same spans that were previously hard-coded).
    barCenters = np.arange(len(nullThresholds))
    ax.hlines(nullThresholds,barCenters - 0.45,barCenters + 0.45)
    ax.set_ylabel('Granger Value',fontsize=14)
    ax.set_xlabel('')
    ax.set_xticklabels(xTickLabels,fontsize=14)
    ax.set_title(title,fontsize=16)
    ax.set_ylim([0,.12])
    ax.legend([legendText],loc=2,fontsize=12)


def buildFigure(rawPs,lookBack,xTickLabels,savFile):
    """Plot observed values vs. null thresholds for CA1 -> mPFC and mPFC -> CA1.

    Args:
        rawPs: Four .mat paths, in the order: files supplying the Pro->Pro,
            Ret->Ret, Pro->Ret and Ret->Pro conditions.
        lookBack: 1-based look-back index passed to ``getValuesFromFile``.
        xTickLabels: Four tick labels, in bar order Pro->Pro, Pro->Ret,
            Ret->Pro, Ret->Ret.
        savFile: If non-empty, the figure is also saved to this path
            (``bbox_inches='tight'``) before it is shown. An empty string
            only displays it.

    Returns:
        The string ``'Done'``.
    """
    # 'hm' = values labelled CA1 -> mPFC, 'mh' = mPFC -> CA1 (see legends below).
    hmPro,hmProNull,mhPro,mhProNull = getValuesFromFile(rawPs[0],lookBack)
    hmRet,hmRetNull,mhRet,mhRetNull = getValuesFromFile(rawPs[1],lookBack)
    hmProRet,hmProRetNull,mhProRet,mhProRetNull = getValuesFromFile(rawPs[2],lookBack)
    hmRetPro,hmRetProNull,mhRetPro,mhRetProNull = getValuesFromFile(rawPs[3],lookBack)

    # Bar order everywhere below: Pro->Pro, Pro->Ret, Ret->Pro, Ret->Ret.
    hmNullDists = (hmProNull,hmProRetNull,hmRetProNull,hmRetNull)
    mhNullDists = (mhProNull,mhProRetNull,mhRetProNull,mhRetNull)
    hmLegArr,hmGranger,hmNull = buildPlottingDataset(hmPro,hmProRet,hmRetPro,hmRet,*hmNullDists)
    mhLegArr,mhGranger,mhNull = buildPlottingDataset(mhPro,mhProRet,mhRetPro,mhRet,*mhNullDists)

    hmPs = _nullPValues(hmLegArr,hmNullDists)
    mhPs = _nullPValues(mhLegArr,mhNullDists)

    print(hmPro,hmProRet,hmRetPro,hmRet)
    print(mhPro,mhProRet,mhRetPro,mhRet)
    hmLegend = createLegendString('CA1','mPFC',hmLegArr,hmNull,hmPs)
    mhLegend = createLegendString('mPFC','CA1',mhLegArr,mhNull,mhPs)

    fig = plt.figure(figsize=(20,20))
    ax1 = fig.add_subplot(2,2,1)
    ax2 = fig.add_subplot(2,2,2)
    _drawGrangerPanel(ax1,hmGranger,hmNull,xTickLabels,'CA1 --> mPFC',hmLegend)
    _drawGrangerPanel(ax2,mhGranger,mhNull,xTickLabels,'mPFC --> CA1',mhLegend)

    if savFile:
        # Save before show(): with pyplot the figure may be closed once shown.
        fig.savefig(savFile,bbox_inches='tight')
    plt.show()
    return 'Done'

# Four .mat files, in the order buildFigure reads them: the files supplying
# the Pro->Pro, Ret->Ret, Pro->Ret and Ret->Pro conditions.
# Fill in the paths before running this cell.
rP = ['',  # Pro --> Pro
      '',  # Ret --> Ret
      '',  # Pro --> Ret
      '',  # Ret --> Pro
      ]

# Output path forwarded to buildFigure as savFile.
sP = ''

# buildFigure indexes rawPs[0..3]; fail early with a clear message instead
# of an opaque IndexError/loadmat error deep inside the plotting code.
if len(rP) != 4 or not all(rP):
    raise ValueError('rP must list four non-empty .mat paths '
                     '(Pro-->Pro, Ret-->Ret, Pro-->Ret, Ret-->Pro)')

buildFigure(rP,3,['Pro --> Pro','Pro --> Ret','Ret --> Pro','Ret --> Ret'],sP)