# In [None]:
"""
@author: nadachaari.nc@gmail.com
"""

# This code generates the average of the topological centrality distributions across 35 ROIs of the templates (CBTs)
# generated by SCA, netNorm, cMGINet, MVCF-Net and DGN, and compares them against the ground-truth network distribution
# of the GSP data, represented with multi-view brain networks.

# The topological centrality metrics are: degree centrality, betweenness centrality, eigenvector centrality,
# information centrality, PageRank, random-walk centrality, Katz centrality and Laplacian centrality.

import pickle
import numpy as np
import statistics
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from statannot import add_stat_annotation

# Load the precomputed centrality metrics (CM) of one GSP hemisphere.
# The full analysis covers 4 populations: female LH, male LH, female RH and male RH.
# LH means left hemisphere
# RH means right hemisphere
# GSP is the Brain Genomics Superstruct Project dataset, which consists of healthy female and male populations
args_dataset1 = 'LH_GSP' # 'RH_GSP'
# NOTE(review): the variables below are always suffixed _LH, even when
# args_dataset1 is switched to 'RH_GSP'; the *_RH names passed to
# dist_plot_multiview are not defined by this code.


def _load_pickle(path):
    """Return the object unpickled from the file at ``path``.

    SECURITY: ``pickle.load`` can execute arbitrary code stored in the file;
    only load files produced by this project.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


# test population centrality metrics for multiview (ground truth)
CM_ground_truth_female_LH = _load_pickle('CM_ground_truth_female_' + args_dataset1)
# dist_plot_multiview is called with CM_ground_truth_male_LH, so the male
# ground truth must be loaded as well.
# NOTE(review): file name assumed to follow the female naming pattern -- confirm.
CM_ground_truth_male_LH = _load_pickle('CM_ground_truth_male_' + args_dataset1)

# cbt female metrics
CM_CBT_female_cMGINet_LH = _load_pickle('CM_CBT_female_' + args_dataset1 + '_cMGINet')
CM_CBT_female_DGN_LH = _load_pickle('CM_CBT_female_' + args_dataset1 + '_DGN')
CM_CBT_female_SCA_LH = _load_pickle('CM_CBT_female_' + args_dataset1 + '_SCA')
CM_CBT_female_MVCFNet_LH = _load_pickle('CM_CBT_female_' + args_dataset1 + '_MVCF-Net')
CM_CBT_female_netNorm_LH = _load_pickle('CM_CBT_female_' + args_dataset1 + '_netNorm')

# cbt male metrics
CM_CBT_male_DGN_LH = _load_pickle('CM_CBT_male_' + args_dataset1 + '_DGN')
CM_CBT_male_SCA_LH = _load_pickle('CM_CBT_male_' + args_dataset1 + '_SCA')
CM_CBT_male_MVCFNet_LH = _load_pickle('CM_CBT_male_' + args_dataset1 + '_MVCF-Net')
CM_CBT_male_netNorm_LH = _load_pickle('CM_CBT_male_' + args_dataset1 + '_netNorm')
CM_CBT_male_cMGINet_LH = _load_pickle('CM_CBT_male_' + args_dataset1 + '_cMGINet')

        
def grouped_barplot(df, cat, subcat, val, err):
    """Draw one bar per (``cat``, ``subcat``) pair, with ``err`` as error bars.

    The values of ``cat`` are laid out along the x axis in order of first
    appearance in ``df``. Within each category, one bar per value of ``subcat``
    (also in order of first appearance) is drawn side by side.

    Args:
        df: DataFrame holding the columns ``cat``, ``subcat``, ``val`` and
            ``err``, with at most one row per (``cat``, ``subcat``) pair.
        cat: Column whose values define the x-axis groups.
        subcat: Column whose values define the bars inside each group.
        val: Column holding the bar heights.
        err: Column holding the symmetric error-bar half-lengths.

    Returns:
        The matplotlib Axes the bars were drawn on (the current axes).
    """
    ax = plt.gca()
    categories = df[cat].unique()
    x = np.arange(len(categories))
    subcategories = df[subcat].unique()
    n_sub = len(subcategories)
    # Centre each group of bars on its x tick. Every bar is 1/(n_sub + 1)
    # wide, leaving one bar-width of gap between neighbouring groups. The
    # width is computed directly, so a single subcategory no longer yields
    # NaN (the mean of an empty np.diff).
    offsets = (np.arange(n_sub) - (n_sub - 1) / 2.) / (n_sub + 1.)
    width = 1. / (n_sub + 1.)

    for offset, sub in zip(offsets, subcategories):
        # Align this subcategory's rows with the x positions by category
        # label rather than by row order. Missing pairs become NaN (no bar).
        dfg = df[df[subcat] == sub].set_index(cat).reindex(categories)
        ax.bar(x + offset, dfg[val].values, width=width, yerr=dfg[err].values)

    ax.set_xticks(x)
    ax.set_xticklabels(categories)
    return ax

######### multiview fusion methods #########

# Plot rows ("folds"): hemisphere and sex of each population, in plotting order.
_FOLDS = ['LH-female', 'LH-male', 'RH-female', 'RH-male']

# Plot hue levels, in plotting order. "Average" labels the ground-truth
# population; the other entries are the CBT-generating methods.
_MODELS = ['SCA', 'MVCF-Net', 'cMGI-Net', 'DGN', 'netNorm', 'Average']

# Every metric is normalised by the maximum of the metric stored at this index
# (10 is Laplacian centrality per the index legend of dist_plot_multiview).
_REFERENCE_METRIC_INDEX = 10

# Extra per-model divisor applied after that normalisation.
# NOTE(review): these ad-hoc factors come from the analysis as written; their
# justification is not visible in this file -- confirm before reusing them.
_MODEL_DIVISORS = {
    'SCA': 6,
    'MVCF-Net': 5,
    'cMGI-Net': 5,
    'DGN': 3,
    'netNorm': 1,
    'Average': 2,
}

# Models compared against DGN in each fold by the statistical annotation.
_DGN_COMPARISONS = ['netNorm', 'MVCF-Net', 'SCA', 'cMGI-Net', 'Average']


def _scaled_metric(metrics, metric_index, divisor):
    """Return ``metrics[metric_index]`` / max(reference metric of ``metrics``) / ``divisor``."""
    return np.asarray(metrics[metric_index]) / max(metrics[_REFERENCE_METRIC_INDEX]) / divisor


def dist_plot_multiview(
        CM_ground_truth_male_RH, CM_CBT_male_SCA_RH, CM_CBT_male_netNorm_RH, CM_CBT_male_NAGFS_RH,
        CM_CBT_male_MVCFNet_RH, CM_CBT_male_DGN_RH, CM_CBT_male_cMGINet_RH,
        CM_ground_truth_female_RH, CM_CBT_female_SCA_RH, CM_CBT_female_netNorm_RH, CM_CBT_female_NAGFS_RH,
        CM_CBT_female_MVCFNet_RH, CM_CBT_female_DGN_RH, CM_CBT_female_cMGINet_RH,
        CM_ground_truth_male_LH, CM_CBT_male_SCA_LH, CM_CBT_male_netNorm_LH, CM_CBT_male_NAGFS_LH,
        CM_CBT_male_MVCFNet_LH, CM_CBT_male_DGN_LH, CM_CBT_male_cMGINet_LH,
        CM_ground_truth_female_LH, CM_CBT_female_SCA_LH, CM_CBT_female_netNorm_LH, CM_CBT_female_NAGFS_LH,
        CM_CBT_female_MVCFNet_LH, CM_CBT_female_DGN_LH, CM_CBT_female_cMGINet_LH,
        metric_index=7):
    """Bar-plot one centrality metric of every CBT against the ground truth.

    The plot covers each population (fold: LH-female, LH-male, RH-female,
    RH-male) and each model (SCA, MVCF-Net, cMGI-Net, DGN, netNorm, and the
    ground truth labelled "Average"). The selected metric is divided by the
    maximum of metric 10 of the same input, then by a per-model factor. The
    bar height is its mean. The error bar is its sample standard deviation
    across the metric's entries (presumably one per ROI -- confirm). DGN is
    compared with every other model in each fold using statannot paired
    t-tests. The legend is removed.

    Args:
        CM_ground_truth_* / CM_CBT_*: Centrality metrics, per hemisphere and
            sex, of the ground-truth population and of each method's CBT.
            Each must be indexable by metric index with at least 11 entries.
            Entry ``metric_index`` must support division by a scalar
            (presumably 1-D numpy arrays -- confirm). The ``*_NAGFS_*``
            arguments are accepted for backward compatibility but are unused.
        metric_index: Centrality metric to plot: 0 betweenness, 1 degree,
            3 eigenvector, 4 PageRank, 6 information, 7 random-walk (default),
            9 Katz, 10 Laplacian centrality.

    Returns:
        The matplotlib Axes holding the plot.
    """
    # Metric collections for every fold, aligned with _MODELS.
    inputs = {
        'LH-female': [CM_CBT_female_SCA_LH, CM_CBT_female_MVCFNet_LH, CM_CBT_female_cMGINet_LH,
                      CM_CBT_female_DGN_LH, CM_CBT_female_netNorm_LH, CM_ground_truth_female_LH],
        'LH-male': [CM_CBT_male_SCA_LH, CM_CBT_male_MVCFNet_LH, CM_CBT_male_cMGINet_LH,
                    CM_CBT_male_DGN_LH, CM_CBT_male_netNorm_LH, CM_ground_truth_male_LH],
        'RH-female': [CM_CBT_female_SCA_RH, CM_CBT_female_MVCFNet_RH, CM_CBT_female_cMGINet_RH,
                      CM_CBT_female_DGN_RH, CM_CBT_female_netNorm_RH, CM_ground_truth_female_RH],
        'RH-male': [CM_CBT_male_SCA_RH, CM_CBT_male_MVCFNet_RH, CM_CBT_male_cMGINet_RH,
                    CM_CBT_male_DGN_RH, CM_CBT_male_netNorm_RH, CM_ground_truth_male_RH],
    }

    # Each input is normalised by its own reference metric, so every
    # (fold, model) pair uses a consistent hemisphere.
    rows = []
    for fold in _FOLDS:
        for model, metrics in zip(_MODELS, inputs[fold]):
            scaled = _scaled_metric(metrics, metric_index, _MODEL_DIVISORS[model])
            rows.append({'fold': fold, 'model': model,
                         'value': np.mean(scaled),
                         'error': statistics.stdev(scaled)})

    columns = ['fold', 'model', 'value', 'error']
    # One row per (fold, model): drives the bars and their error bars.
    df1 = pd.DataFrame(rows, columns=columns)
    # Every row duplicated, so seaborn and statannot see two observations
    # per (fold, model).
    # NOTE(review): a paired t-test between duplicated single values is not a
    # meaningful statistical test -- confirm the intended samples.
    df = pd.DataFrame([row for row in rows for _ in range(2)], columns=columns)

    x, y, hue, err = "fold", "value", "model", "error"
    box_pairs = [((fold, "DGN"), (fold, other))
                 for fold in _FOLDS for other in _DGN_COMPARISONS]

    plt.rcParams['figure.dpi'] = 300
    plt.rcParams['savefig.dpi'] = 300
    # Set palette and style before drawing so they apply to this figure.
    colors = ['mediumorchid', 'lightgreen', 'tomato', 'magenta', 'lightseagreen', 'purple']
    sns.set_palette(sns.color_palette(colors))
    sns.set_style("darkgrid")

    # Bars with the standard deviation as error bars ...
    grouped_barplot(df1, x, hue, y, err)
    # ... overlaid with a seaborn bar plot that statannot can annotate.
    ax = sns.barplot(data=df, x=x, y=y, hue=hue, order=_FOLDS, hue_order=_MODELS)
    ax.set(xlabel=None, ylabel=None)  # remove both axis labels
    add_stat_annotation(ax, data=df, x=x, y=y, hue=hue, order=_FOLDS, hue_order=_MODELS,
                        box_pairs=box_pairs, test='t-test_paired', loc='inside', verbose=2)

    # The final figure has no legend.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    return ax
    
        
        
# Call the plotting function. The right-hemisphere arguments come first, then
# the left-hemisphere ones. Within each hemisphere, male precedes female, each
# given as ground truth, SCA, netNorm, NAGFS, MVCF-Net, DGN, cMGINet.
# NOTE(review): the loading code above only defines *_LH variables; every *_RH
# variable and the *_NAGFS_* variables must be defined elsewhere (e.g. by
# another notebook cell) before this call -- confirm.
dist_plot_multiview(
    # right hemisphere, male
    CM_ground_truth_male_RH, CM_CBT_male_SCA_RH, CM_CBT_male_netNorm_RH,
    CM_CBT_male_NAGFS_RH, CM_CBT_male_MVCFNet_RH, CM_CBT_male_DGN_RH,
    CM_CBT_male_cMGINet_RH,
    # right hemisphere, female
    CM_ground_truth_female_RH, CM_CBT_female_SCA_RH, CM_CBT_female_netNorm_RH,
    CM_CBT_female_NAGFS_RH, CM_CBT_female_MVCFNet_RH, CM_CBT_female_DGN_RH,
    CM_CBT_female_cMGINet_RH,
    # left hemisphere, male
    CM_ground_truth_male_LH, CM_CBT_male_SCA_LH, CM_CBT_male_netNorm_LH,
    CM_CBT_male_NAGFS_LH, CM_CBT_male_MVCFNet_LH, CM_CBT_male_DGN_LH,
    CM_CBT_male_cMGINet_LH,
    # left hemisphere, female
    CM_ground_truth_female_LH, CM_CBT_female_SCA_LH, CM_CBT_female_netNorm_LH,
    CM_CBT_female_NAGFS_LH, CM_CBT_female_MVCFNet_LH, CM_CBT_female_DGN_LH,
    CM_CBT_female_cMGINet_LH,
)

