# In [None]:
"""
@author: nadachaari.nc@gmail.com
"""

# This script aggregates the Kullback-Leibler (KL) divergence of centrality metrics (KL_CM) across the 5 cross-validation folds
# for the multi-view fusion methods SCA, netNorm, cMGINet, MVCF-Net and DGN, and compares the methods in an annotated box plot.

import pickle
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from statannot import add_stat_annotation

# Load data for 4 populations: female LH, male LH, female RH and male RH.
# LH means left hemisphere.
# RH means right hemisphere.
# GSP is the Brain Genomics Superstruct Project dataset, which consists of healthy female and male populations.

# Dataset identifiers embedded in the result-file names.
args_dataset1 = 'LH_GSP'
args_dataset2 = 'RH_GSP'


def _unpickle_kl(prefix, dataset, method):
    """Load one pickled KL result file and return the unpickled object.

    The file opened (in binary mode) is ``prefix + dataset + '_' + method``,
    e.g. ``'KL_male_tot_' + 'LH_GSP' + '_' + 'cMGINet'``.

    NOTE: ``pickle.load`` can execute arbitrary code; only load result files
    produced by your own pipeline.
    """
    with open(prefix + dataset + '_' + method, 'rb') as handle:
        return pickle.load(handle)


# Test-population KL divergences of the centrality metrics, computed from the
# CBT produced by each fusion method. Downstream code indexes each object as
# [metric_index][fold_index].

# cMGINet
KL_male_cMGINet_tot_LH = _unpickle_kl('KL_male_tot_', args_dataset1, 'cMGINet')
KL_female_cMGINet_tot_LH = _unpickle_kl('KL_female_tot_', args_dataset1, 'cMGINet')
KL_male_cMGINet_tot_RH = _unpickle_kl('KL_male_tot_', args_dataset2, 'cMGINet')
KL_female_cMGINet_tot_RH = _unpickle_kl('KL_female_tot_', args_dataset2, 'cMGINet')

# SCA
KL_male_SCA_tot_LH = _unpickle_kl('KL_male_tot_', args_dataset1, 'SCA')
KL_female_SCA_tot_LH = _unpickle_kl('KL_female_tot_', args_dataset1, 'SCA')
KL_male_SCA_tot_RH = _unpickle_kl('KL_male_tot_', args_dataset2, 'SCA')
KL_female_SCA_tot_RH = _unpickle_kl('KL_female_tot_', args_dataset2, 'SCA')

# DGN
KL_male_DGN_tot_LH = _unpickle_kl('KL_male_tot_', args_dataset1, 'DGN')
KL_female_DGN_tot_LH = _unpickle_kl('KL_female_tot_', args_dataset1, 'DGN')
KL_male_DGN_tot_RH = _unpickle_kl('KL_male_tot_', args_dataset2, 'DGN')
KL_female_DGN_tot_RH = _unpickle_kl('KL_female_tot_', args_dataset2, 'DGN')

# MVCF-Net (the file names use the hyphen; the variable names cannot)
KL_male_MVCFNet_tot_LH = _unpickle_kl('KL_male_tot_', args_dataset1, 'MVCF-Net')
KL_female_MVCFNet_tot_LH = _unpickle_kl('KL_female_tot_', args_dataset1, 'MVCF-Net')
KL_male_MVCFNet_tot_RH = _unpickle_kl('KL_male_tot_', args_dataset2, 'MVCF-Net')
KL_female_MVCFNet_tot_RH = _unpickle_kl('KL_female_tot_', args_dataset2, 'MVCF-Net')

# netNorm
KL_male_netNorm_tot_LH = _unpickle_kl('KL_male_tot_', args_dataset1, 'netNorm')
KL_female_netNorm_tot_LH = _unpickle_kl('KL_female_tot_', args_dataset1, 'netNorm')
KL_male_netNorm_tot_RH = _unpickle_kl('KL_male_tot_', args_dataset2, 'netNorm')
KL_female_netNorm_tot_RH = _unpickle_kl('KL_female_tot_', args_dataset2, 'netNorm')
    
# Kullback-Leibler divergence of centrality metrics (KL_CM) aggregated across the 5 cross-validation folds for SCA, netNorm, cMGINet, MVCF-Net and DGN

# Index of the centrality metric to analyse and plot.
# Indices listed as valid: 0, 1, 3, 4, 5, 6, 7, 9, 10.
#   0  -> betweenness centrality
#   1  -> degree centrality
#   3  -> eigenvector centrality
#   4  -> PageRank
#   5  -> listed as valid, but its metric is not documented here (TODO confirm)
#   6  -> information centrality
#   7  -> random-walk centrality
#   9  -> Katz centrality
#   10 -> Laplacian centrality
i = 7  # random-walk centrality
    
def _sum_over_folds(kl_per_metric, metric_idx, n_folds=5):
    """Combine one centrality metric's KL results over the cross-validation folds.

    Returns ``kl_per_metric[metric_idx][0] + ... + kl_per_metric[metric_idx][n_folds - 1]``.
    This is the same left-to-right ``+`` chain the hand-written five-term
    expressions used.

    Because ``+`` is applied as-is, per-fold Python lists are concatenated,
    pooling every fold's values. Per-fold NumPy arrays would instead be added
    element-wise. Which case applies depends on the pickled data (TODO confirm).

    Args:
        kl_per_metric: loaded result object, indexed [metric_index][fold_index].
        metric_idx: index of the centrality metric to combine.
        n_folds: number of folds to combine (default 5).
    """
    per_fold = kl_per_metric[metric_idx]
    combined = per_fold[0]
    for fold in range(1, n_folds):
        # Rebind instead of ``+=`` so a list/array stored inside the loaded
        # results is never modified in place.
        combined = combined + per_fold[fold]
    return combined


# cMGINet
KL_CM_male_cMGINet_tot_LH = _sum_over_folds(KL_male_cMGINet_tot_LH, i)
KL_CM_female_cMGINet_tot_LH = _sum_over_folds(KL_female_cMGINet_tot_LH, i)
KL_CM_male_cMGINet_tot_RH = _sum_over_folds(KL_male_cMGINet_tot_RH, i)
KL_CM_female_cMGINet_tot_RH = _sum_over_folds(KL_female_cMGINet_tot_RH, i)

# SCA
KL_CM_male_SCA_tot_LH = _sum_over_folds(KL_male_SCA_tot_LH, i)
KL_CM_female_SCA_tot_LH = _sum_over_folds(KL_female_SCA_tot_LH, i)
KL_CM_male_SCA_tot_RH = _sum_over_folds(KL_male_SCA_tot_RH, i)
KL_CM_female_SCA_tot_RH = _sum_over_folds(KL_female_SCA_tot_RH, i)

# DGN
KL_CM_male_DGN_tot_LH = _sum_over_folds(KL_male_DGN_tot_LH, i)
KL_CM_female_DGN_tot_LH = _sum_over_folds(KL_female_DGN_tot_LH, i)
KL_CM_male_DGN_tot_RH = _sum_over_folds(KL_male_DGN_tot_RH, i)
# Previously missing: the RH-female rows of the plot table read this name.
KL_CM_female_DGN_tot_RH = _sum_over_folds(KL_female_DGN_tot_RH, i)

# MVCF-Net
KL_CM_male_MVCFNet_tot_LH = _sum_over_folds(KL_male_MVCFNet_tot_LH, i)
KL_CM_female_MVCFNet_tot_LH = _sum_over_folds(KL_female_MVCFNet_tot_LH, i)
KL_CM_male_MVCFNet_tot_RH = _sum_over_folds(KL_male_MVCFNet_tot_RH, i)
KL_CM_female_MVCFNet_tot_RH = _sum_over_folds(KL_female_MVCFNet_tot_RH, i)

# netNorm. The RH results were previously bound only to KL_bc_* names, while
# the plot table reads KL_CM_*; they now use the KL_CM_* names like every
# other method.
KL_CM_male_netNorm_tot_RH = _sum_over_folds(KL_male_netNorm_tot_RH, i)
KL_CM_female_netNorm_tot_RH = _sum_over_folds(KL_female_netNorm_tot_RH, i)
KL_CM_male_netNorm_tot_LH = _sum_over_folds(KL_male_netNorm_tot_LH, i)
KL_CM_female_netNorm_tot_LH = _sum_over_folds(KL_female_netNorm_tot_LH, i)

# Legacy aliases, kept so any code still using the old names keeps working.
KL_bc_male_netNorm_tot_RH = KL_CM_male_netNorm_tot_RH
KL_bc_female_netNorm_tot_RH = KL_CM_female_netNorm_tot_RH

# Box plot of the Kullback-Leibler divergence of centrality metrics (KL_CM) for the multi-view fusion methods

fig = plt.figure(figsize=(10, 5))
# Column schema of the long-format table fed to seaborn/statannot.
# NOTE: the 'fold' column actually holds the population label (e.g. 'LH-female'),
# not a fold number; the name is kept because the plotting code refers to it.
data = {'fold': [],
        'model': [],
        'value': []}
list1 = ['LH-female', 'LH-male', 'RH-female', 'RH-male']

# For each population: the per-model KL_CM series, in the row order of the table.
series_by_population = {
    'LH-female': [('netNorm', KL_CM_female_netNorm_tot_LH),
                  ('SCA', KL_CM_female_SCA_tot_LH),
                  ('MVCF-Net', KL_CM_female_MVCFNet_tot_LH),
                  ('cMGI-Net', KL_CM_female_cMGINet_tot_LH),
                  ('DGN', KL_CM_female_DGN_tot_LH)],
    'LH-male': [('netNorm', KL_CM_male_netNorm_tot_LH),
                ('SCA', KL_CM_male_SCA_tot_LH),
                ('MVCF-Net', KL_CM_male_MVCFNet_tot_LH),
                ('cMGI-Net', KL_CM_male_cMGINet_tot_LH),
                ('DGN', KL_CM_male_DGN_tot_LH)],
    'RH-female': [('netNorm', KL_CM_female_netNorm_tot_RH),
                  ('SCA', KL_CM_female_SCA_tot_RH),
                  ('MVCF-Net', KL_CM_female_MVCFNet_tot_RH),
                  ('cMGI-Net', KL_CM_female_cMGINet_tot_RH),
                  ('DGN', KL_CM_female_DGN_tot_RH)],
    'RH-male': [('netNorm', KL_CM_male_netNorm_tot_RH),
                ('SCA', KL_CM_male_SCA_tot_RH),
                ('MVCF-Net', KL_CM_male_MVCFNet_tot_RH),
                ('cMGI-Net', KL_CM_male_cMGINet_tot_RH),
                ('DGN', KL_CM_male_DGN_tot_RH)],
}

# Collect all rows first and build the DataFrame once. DataFrame.append was
# removed in pandas 2.0, and calling it per row re-copied the whole frame.
rows = []
for population in list1:
    model_series = series_by_population[population]
    # The number of rows per population follows the DGN series; every other
    # model's series must be at least that long.
    n_values = len(dict(model_series)['DGN'])
    for j in range(n_values):
        # One row per model for the j-th value. This interleaving keeps the
        # j-th values of all models aligned, which the paired test relies on.
        for model, series in model_series:
            rows.append({'fold': population, 'model': model, 'value': series[j]})

df2 = pd.DataFrame(rows, columns=list(data))
    
x1 = "fold"    # column holding the population label (LH/RH x female/male)
y1 = "value"
hue1 = "model"
hue_order = ['netNorm', 'SCA', 'MVCF-Net', 'cMGI-Net', 'DGN']

# Compare DGN against every other fusion method within each population.
# The order matches the original hand-written list.
box_pairs1 = [
    ((population, "DGN"), (population, other))
    for population in ["LH-female", "LH-male", "RH-female", "RH-male"]
    for other in ["netNorm", "MVCF-Net", "SCA", "cMGI-Net"]
]

plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300

colors = ['lightseagreen', 'mediumorchid', 'lightgreen', 'tomato', 'magenta']
# Style and palette must be set BEFORE drawing. They used to be applied after
# sns.boxplot and therefore never affected this plot.
sns.set_style("darkgrid")
sns.set_palette(sns.color_palette(colors))

# Pass the same order/hue_order to both calls so statannot places each
# bracket over the boxes it actually tested.
ax = sns.boxplot(data=df2, x=x1, y=y1, hue=hue1,
                 order=list1, hue_order=hue_order, palette=colors)
ax.set(xlabel=None, ylabel=None)  # remove both axis labels
# NOTE(review): 't-test_paired' pairs the k-th value of each model within a
# population, so it assumes equal-length, index-aligned series — confirm.
add_stat_annotation(ax, data=df2, x=x1, y=y1, hue=hue1,
                    order=list1, hue_order=hue_order, box_pairs=box_pairs1,
                    test='t-test_paired', loc='inside', verbose=2)
ax.legend(frameon=False, prop={'size': 6})