# In [None]:
"""
@author: nadachaari.nc@gmail.com
"""

# This code computes and plots the average of the topological centrality distributions across the 35 ROIs of the
# templates (CBTs) generated by SNF, NAGFS and SM-netFusion, compared against the ground-truth distribution of the
# GSP data represented as single-view brain networks (view1).
# The topological centrality metrics are: degree centrality, betweenness centrality, eigenvector centrality,
# information centrality, PageRank, random-walk centrality, Katz centrality and Laplacian centrality.
    
import pickle
import numpy as np
import statistics
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from statannot import add_stat_annotation

# Load data for 4 populations: female LH, male LH, female RH and male RH.
# LH means left hemisphere.
# RH means right hemisphere.
# GSP is the Brain Genomics Superstruct Project dataset, which consists of healthy female and male populations.

args_dataset1 = 'LH_GSP'  # file-name tag of the left-hemisphere data
args_dataset2 = 'RH_GSP'  # file-name tag of the right-hemisphere data


def _load_pickle(path):
    """Return the object stored in the pickle file at *path*.

    NOTE: unpickling can execute arbitrary code; only load trusted local files.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _ground_truth_path(sex, dataset):
    """Return the file name of a test-population (ground-truth) metrics pickle."""
    return 'CM_ground_truth_' + sex + '_view1_' + dataset


def _cbt_path(sex, dataset, method):
    """Return the file name of a CBT metrics pickle for one fusion method."""
    return 'CM_CBT_' + sex + '_' + dataset + '_' + method


# Both hemispheres are loaded because dist_plot_singleview plots LH and RH
# populations side by side and needs all sixteen inputs.

# Test-population (ground-truth) centrality metrics.
CM_ground_truth_female_view1_LH = _load_pickle(_ground_truth_path('female', args_dataset1))
CM_ground_truth_male_view1_LH = _load_pickle(_ground_truth_path('male', args_dataset1))
CM_ground_truth_female_view1_RH = _load_pickle(_ground_truth_path('female', args_dataset2))
CM_ground_truth_male_view1_RH = _load_pickle(_ground_truth_path('male', args_dataset2))

# CBT centrality metrics, female populations.
CM_CBT_female_SMnetFusion_LH = _load_pickle(_cbt_path('female', args_dataset1, 'SM-netFusion'))
CM_CBT_female_NAGFS_LH = _load_pickle(_cbt_path('female', args_dataset1, 'NAGFS'))
CM_CBT_female_SNF_LH = _load_pickle(_cbt_path('female', args_dataset1, 'SNF'))
CM_CBT_female_SMnetFusion_RH = _load_pickle(_cbt_path('female', args_dataset2, 'SM-netFusion'))
CM_CBT_female_NAGFS_RH = _load_pickle(_cbt_path('female', args_dataset2, 'NAGFS'))
CM_CBT_female_SNF_RH = _load_pickle(_cbt_path('female', args_dataset2, 'SNF'))

# CBT centrality metrics, male populations.
CM_CBT_male_SNF_LH = _load_pickle(_cbt_path('male', args_dataset1, 'SNF'))
CM_CBT_male_SMnetFusion_LH = _load_pickle(_cbt_path('male', args_dataset1, 'SM-netFusion'))
CM_CBT_male_NAGFS_LH = _load_pickle(_cbt_path('male', args_dataset1, 'NAGFS'))
CM_CBT_male_SNF_RH = _load_pickle(_cbt_path('male', args_dataset2, 'SNF'))
CM_CBT_male_SMnetFusion_RH = _load_pickle(_cbt_path('male', args_dataset2, 'SM-netFusion'))
CM_CBT_male_NAGFS_RH = _load_pickle(_cbt_path('male', args_dataset2, 'NAGFS'))

def grouped_barplot(df, cat, subcat, val, err):
    """Draw a grouped bar chart with error bars on the current matplotlib axes.

    One group of bars is drawn per unique value of ``df[cat]``, at
    x = 0, 1, 2, ...  Inside each group there is one bar per unique value of
    ``df[subcat]``, placed side by side.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table containing the columns named by the other arguments.
    cat : str
        Column whose unique values define the bar groups (x positions).
    subcat : str
        Column whose unique values define the bars inside each group.
    val : str
        Column holding the bar heights.
    err : str
        Column holding the symmetric error-bar lengths (passed as ``yerr``).

    Returns
    -------
    matplotlib.axes.Axes
        The axes the bars were drawn on (``plt.gca()``).

    Notes
    -----
    Each sub-category's bars are placed in the order its rows appear in
    ``df``. This assumes every sub-category has exactly one row per category,
    listed in the order of ``df[cat].unique()``.
    """
    categories = df[cat].unique()
    x = np.arange(len(categories))
    subcategories = df[subcat].unique()
    n_sub = len(subcategories)
    # Centre the sub-bars around each group's x position, spaced 1/(n+1) apart.
    offsets = (np.arange(n_sub) - (n_sub - 1) / 2.0) / (n_sub + 1.0)
    # Same value as the gap between consecutive offsets. It is computed
    # directly because np.diff of a single offset is empty, which would make
    # the width NaN and hide the bars.
    width = 1.0 / (n_sub + 1.0)

    for offset, group in zip(offsets, subcategories):
        dfg = df[df[subcat] == group]
        plt.bar(x + offset, dfg[val].values, width=width, yerr=dfg[err].values)
    return plt.gca()

######### single-view fusion methods #########

def _bar_rows(folds, metric_index):
    """Summarise one centrality metric as one row per (fold, model) bar.

    Parameters
    ----------
    folds : list of (str, list of (str, sequence))
        ``(fold_label, [(model_label, metrics), ...])`` in plot order, where
        ``metrics[metric_index]`` holds that metric's values. These are
        presumably per-ROI values (assumed to be numpy arrays, since the
        ground-truth values are divided by 2 below).
    metric_index : int
        Index of the centrality metric to summarise.

    Returns
    -------
    list of dict
        One dict per bar with keys ``'fold'``, ``'model'``, ``'value'`` (the
        mean) and ``'error'`` (sample standard deviation, ``statistics.stdev``).
    """
    rows = []
    for fold, models in folds:
        for model, metrics in models:
            values = metrics[metric_index]
            # NOTE(review): the ground-truth ('Average') error bar is the std of
            # the halved values, i.e. half the std used for the CBT bars. This
            # is kept from the original analysis; confirm it is intended.
            spread_source = values / 2 if model == 'Average' else values
            rows.append({
                'fold': fold,
                'model': model,
                'value': np.mean(values),
                'error': statistics.stdev(spread_source),
            })
    return rows


def dist_plot_singleview(CM_ground_truth_male_view1_RH, CM_CBT_male_SNF_RH,
                         CM_CBT_male_SMnetFusion_RH, CM_CBT_male_NAGFS_RH,
                         CM_ground_truth_female_view1_RH, CM_CBT_female_SNF_RH,
                         CM_CBT_female_SMnetFusion_RH, CM_CBT_female_NAGFS_RH,
                         CM_ground_truth_male_view1_LH, CM_CBT_male_SNF_LH,
                         CM_CBT_male_SMnetFusion_LH, CM_CBT_male_NAGFS_LH,
                         CM_ground_truth_female_view1_LH, CM_CBT_female_SNF_LH,
                         CM_CBT_female_SMnetFusion_LH, CM_CBT_female_NAGFS_LH,
                         metric_index=3, ylim=(0.12, 0.2)):
    """Bar-plot one centrality metric for the ground truth and the three CBTs.

    For each population (LH-female, LH-male, RH-female, RH-male), draw four
    bars: the ground truth ('Average'), NAG-FS, SNF and SM-netFusion. Each bar
    shows the mean of the selected metric with an error bar. Paired t-test
    annotations compare SM-netFusion against each of the other three models.

    Parameters
    ----------
    CM_ground_truth_* , CM_CBT_* :
        Centrality metrics for each population and method. Each is indexed by
        metric, so ``x[metric_index]`` is that metric's value collection.
        These are presumably numpy arrays of per-ROI values; confirm.
    metric_index : int, default 3
        Which centrality metric to plot: 0 betweenness, 1 degree,
        3 eigenvector, 4 PageRank, 6 information, 7 random-walk, 9 Katz,
        10 Laplacian.
    ylim : (float, float), default (0.12, 0.2)
        Y-axis limits. Ticks are drawn every 0.01 within this range. The
        default suits eigenvector centrality; other metrics may need
        different limits.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the plot.
    """
    # Plot order: populations along x, models as hue.
    folds = [
        ('LH-female', [('Average', CM_ground_truth_female_view1_LH),
                       ('NAG-FS', CM_CBT_female_NAGFS_LH),
                       ('SNF', CM_CBT_female_SNF_LH),
                       ('SM-netFusion', CM_CBT_female_SMnetFusion_LH)]),
        ('LH-male', [('Average', CM_ground_truth_male_view1_LH),
                     ('NAG-FS', CM_CBT_male_NAGFS_LH),
                     ('SNF', CM_CBT_male_SNF_LH),
                     ('SM-netFusion', CM_CBT_male_SMnetFusion_LH)]),
        ('RH-female', [('Average', CM_ground_truth_female_view1_RH),
                       ('NAG-FS', CM_CBT_female_NAGFS_RH),
                       ('SNF', CM_CBT_female_SNF_RH),
                       ('SM-netFusion', CM_CBT_female_SMnetFusion_RH)]),
        ('RH-male', [('Average', CM_ground_truth_male_view1_RH),
                     ('NAG-FS', CM_CBT_male_NAGFS_RH),
                     ('SNF', CM_CBT_male_SNF_RH),
                     ('SM-netFusion', CM_CBT_male_SMnetFusion_RH)]),
    ]
    hue_order = ['Average', 'NAG-FS', 'SNF', 'SM-netFusion']
    x, y, hue, err = 'fold', 'value', 'model', 'error'

    # Rows are collected first and the frames are built once:
    # DataFrame.append was removed in pandas 2.0.
    rows = _bar_rows(folds, metric_index)
    columns = ['fold', 'model', 'value', 'error']
    # One row per bar, used for the explicit mean +/- std error bars.
    df1 = pd.DataFrame(rows, columns=columns)
    # Every bar is duplicated so that seaborn/statannot see repeated
    # observations per (fold, model). The duplicates add no information.
    df = pd.DataFrame([row for row in rows for _ in range(2)], columns=columns)

    # SM-netFusion against every other model, within each population.
    box_pairs = [
        ((fold, 'SM-netFusion'), (fold, other))
        for fold, _ in folds
        for other in ('SNF', 'NAG-FS', 'Average')
    ]

    plt.rcParams['figure.dpi'] = 300
    plt.rcParams['savefig.dpi'] = 300
    # Palette and style only affect axes/artists created afterwards, so they
    # must be set before any bars are drawn.
    sns.set_palette(sns.color_palette(['brown', 'peru', 'gold', 'navajowhite']))
    sns.set_style("darkgrid")

    grouped_barplot(df1, x, hue, y, err)
    ax = sns.barplot(data=df, x=x, y=y, hue=hue, hue_order=hue_order)
    ax.set(xlabel=None, ylabel=None, ylim=ylim)
    plt.yticks(np.arange(ylim[0], ylim[1], 0.01))
    # NOTE(review): each bar's two observations are identical copies, so the
    # paired differences have zero variance and the t-test p-values are
    # degenerate. A meaningful test would need per-sample data; confirm.
    add_stat_annotation(ax, data=df, x=x, y=y, hue=hue, hue_order=hue_order,
                        box_pairs=box_pairs, test='t-test_paired',
                        loc='inside', verbose=2)
    # The annotation step may move the y-limits; restore the requested view.
    ax.set(ylim=ylim)
    plt.yticks(np.arange(ylim[0], ylim[1], 0.01))
    # The final figure carries no legend.
    ax.legend(frameon=False, prop={'size': 6}).remove()
    return ax
    
    
# call the function
dist_plot_singleview(
    # Right hemisphere, male population.
    CM_ground_truth_male_view1_RH=CM_ground_truth_male_view1_RH,
    CM_CBT_male_SNF_RH=CM_CBT_male_SNF_RH,
    CM_CBT_male_SMnetFusion_RH=CM_CBT_male_SMnetFusion_RH,
    CM_CBT_male_NAGFS_RH=CM_CBT_male_NAGFS_RH,
    # Right hemisphere, female population.
    CM_ground_truth_female_view1_RH=CM_ground_truth_female_view1_RH,
    CM_CBT_female_SNF_RH=CM_CBT_female_SNF_RH,
    CM_CBT_female_SMnetFusion_RH=CM_CBT_female_SMnetFusion_RH,
    CM_CBT_female_NAGFS_RH=CM_CBT_female_NAGFS_RH,
    # Left hemisphere, male population.
    CM_ground_truth_male_view1_LH=CM_ground_truth_male_view1_LH,
    CM_CBT_male_SNF_LH=CM_CBT_male_SNF_LH,
    CM_CBT_male_SMnetFusion_LH=CM_CBT_male_SMnetFusion_LH,
    CM_CBT_male_NAGFS_LH=CM_CBT_male_NAGFS_LH,
    # Left hemisphere, female population.
    CM_ground_truth_female_view1_LH=CM_ground_truth_female_view1_LH,
    CM_CBT_female_SNF_LH=CM_CBT_female_SNF_LH,
    CM_CBT_female_SMnetFusion_LH=CM_CBT_female_SMnetFusion_LH,
    CM_CBT_female_NAGFS_LH=CM_CBT_female_NAGFS_LH,
)
