In [None]:
"""
@author: nadachaari.nc@gmail.com
"""

# This code generates the average global efficiency across 5-fold cross-validation for the multi-view fusion methods
# (SCA, netNorm, cMGINet, MVCF-Net, DGN) and compares it with the ground truth.

import matplotlib.pyplot as plt
import numpy as np
import statistics
import pickle
import pandas as pd
import seaborn as sns
from statannot import add_stat_annotation

# Load data for 4 populations: female LH, male LH, female RH, male RH.
# LH means left hemisphere.
# RH means right hemisphere.
# GSP is the Brain Genomics Superstruct Project dataset, which consists of healthy female and male populations.

def _load_pickle(path):
    """Return the object unpickled from the file at *path*.

    NOTE(review): unpickling can execute arbitrary code stored in the file;
    only use this on trusted, locally generated artefacts.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _load_cbt(sex, dataset, method):
    """Load the pickled CBT(s) of *method* for one sex ('male'/'female') of *dataset*.

    The file name follows the pattern ``CBT_<sex>_<dataset>_<method>``.
    """
    return _load_pickle('CBT_' + sex + '_' + dataset + '_' + method)


args_dataset = 'LH_GSP'

data = _load_pickle('data_partition_' + args_dataset)

# Presumably the per-fold test subjects of each sex (entries 2 and 3 of the
# partition) -- TODO confirm against the partitioning script.
CV_test_male_LH = data[2]
CV_test_female_LH = data[3]

CBT_DGN_LH_male = _load_cbt('male', args_dataset, 'DGN')
CBT_cMGINet_LH_male = _load_cbt('male', args_dataset, 'cMGINet')
CBT_netNorm_LH_male = _load_cbt('male', args_dataset, 'netNorm')
CBT_MVCFNet_LH_male = _load_cbt('male', args_dataset, 'MVCF-Net')
# The male SCA variable must be loaded from the *male* file (it feeds the
# LH-male bar); reading the female file here mislabels the result.
CBT_SCA_LH_male = _load_cbt('male', args_dataset, 'SCA')

CBT_DGN_LH_female = _load_cbt('female', args_dataset, 'DGN')
CBT_cMGINet_LH_female = _load_cbt('female', args_dataset, 'cMGINet')
CBT_netNorm_LH_female = _load_cbt('female', args_dataset, 'netNorm')
CBT_MVCFNet_LH_female = _load_cbt('female', args_dataset, 'MVCF-Net')
CBT_SCA_LH_female = _load_cbt('female', args_dataset, 'SCA')


args_dataset = 'RH_GSP'

data = _load_pickle('data_partition_' + args_dataset)

CV_test_male_RH = data[2]
CV_test_female_RH = data[3]

CBT_DGN_RH_male = _load_cbt('male', args_dataset, 'DGN')
CBT_cMGINet_RH_male = _load_cbt('male', args_dataset, 'cMGINet')
CBT_netNorm_RH_male = _load_cbt('male', args_dataset, 'netNorm')
CBT_MVCFNet_RH_male = _load_cbt('male', args_dataset, 'MVCF-Net')
# Male file for the male variable (see the LH block above for why).
CBT_SCA_RH_male = _load_cbt('male', args_dataset, 'SCA')

CBT_DGN_RH_female = _load_cbt('female', args_dataset, 'DGN')
CBT_cMGINet_RH_female = _load_cbt('female', args_dataset, 'cMGINet')
CBT_netNorm_RH_female = _load_cbt('female', args_dataset, 'netNorm')
CBT_MVCFNet_RH_female = _load_cbt('female', args_dataset, 'MVCF-Net')
CBT_SCA_RH_female = _load_cbt('female', args_dataset, 'SCA')

def Global_efficiency_multiview(E_glo_netNorm, E_glo_SCA, E_glo_cMGINet,
                                E_glo_MVCFNet, E_glo_DGN, E_glo_ground_truth):
    """Plot the mean global efficiency of each fusion method per population.

    Draws a grouped bar chart with one group per population (LH-female,
    LH-male, RH-female, RH-male) and one bar per method (netNorm, cMGI-Net,
    SCA, MVCF-Net, DGN, Average). Error bars come from the fold-to-fold
    standard deviations. Paired t-test annotations compare DGN against every
    other method inside each population.

    Parameters
    ----------
    E_glo_netNorm, E_glo_SCA, E_glo_cMGINet, E_glo_MVCFNet, E_glo_DGN, E_glo_ground_truth
        Length-4 sequences of mean global efficiency, ordered
        ``[LH-female, LH-male, RH-female, RH-male]``.

    Notes
    -----
    The error-bar sizes are NOT taken from the arguments. They are read from
    module-level variables named ``std_<method>_<LH|RH>_<female|male>`` and
    ``std_ground_truth_<LH|RH>_<female|male>``, which must exist before this
    function is called.
    """
    populations = ['LH-female', 'LH-male', 'RH-female', 'RH-male']
    models = ['netNorm', 'cMGI-Net', 'SCA', 'MVCF-Net', 'DGN', 'Average']

    ######### multiview fusion methods #########
    # Error-bar sizes: rows follow `populations`, columns follow `models`.
    errorlist = np.array([
        [std_netNorm_LH_female, std_cMGINet_LH_female, std_SCA_LH_female,
         std_MVCFNet_LH_female, std_DGN_LH_female, std_ground_truth_LH_female],
        [std_netNorm_LH_male, std_cMGINet_LH_male, std_SCA_LH_male,
         std_MVCFNet_LH_male, std_DGN_LH_male, std_ground_truth_LH_male],
        [std_netNorm_RH_female, std_cMGINet_RH_female, std_SCA_RH_female,
         std_MVCFNet_RH_female, std_DGN_RH_female, std_ground_truth_RH_female],
        [std_netNorm_RH_male, std_cMGINet_RH_male, std_SCA_RH_male,
         std_MVCFNet_RH_male, std_DGN_RH_male, std_ground_truth_RH_male],
    ], dtype=float)

    plotdata = pd.DataFrame({
        'netNorm': E_glo_netNorm,
        'cMGI-Net': E_glo_cMGINet,
        'SCA': E_glo_SCA,
        'MVCF-Net': E_glo_MVCFNet,
        'DGN': E_glo_DGN,
        'Average': E_glo_ground_truth,
    }, index=populations)

    # Long-format frames. DataFrame.append was removed in pandas 2.0, so the
    # rows are collected first and each frame is built in one call.
    columns = ['fold', 'model', 'value', 'error']
    stat_rows = []  # two identical rows per bar (fed to seaborn/statannot)
    bar_rows = []   # one row per bar, carrying its error-bar size
    for i, population in enumerate(populations):
        for j, model in enumerate(models):
            row = {'fold': population, 'model': model, 'value': plotdata.iat[i, j]}
            stat_rows.append(row)
            stat_rows.append(dict(row))
            bar_rows.append(dict(row, error=errorlist[i, j]))
    df = pd.DataFrame(stat_rows, columns=columns)
    df2 = pd.DataFrame(bar_rows, columns=columns)

    # DGN versus every other method, within each population.
    compared = ['netNorm', 'MVCF-Net', 'SCA', 'cMGI-Net', 'Average']
    box_pairs = [((population, 'DGN'), (population, other))
                 for population in populations for other in compared]

    plt.rcParams['figure.dpi'] = 300
    plt.rcParams['savefig.dpi'] = 300
    # Style and palette only affect axes created / artists drawn afterwards,
    # so they are configured before the figure is built.
    sns.set_style("darkgrid")
    sns.set_palette(sns.color_palette(
        ['lightseagreen', 'tomato', 'mediumorchid', 'lightgreen', 'magenta', 'purple']))

    fig, ax = plt.subplots(figsize=(8, 5))
    # Bars with the explicit fold-to-fold std as error bars...
    grouped_barplot(df2, 'fold', 'model', 'value', 'error')
    # ...overlaid by seaborn's hue bars, whose labels ax.legend() picks up below.
    sns.barplot(data=df, x='fold', y='value', hue='model')

    ax.set(xlabel=None, ylabel=None)  # remove both axis labels
    # Fix the y-range before annotating; presumably statannot places its
    # annotations relative to the current limits.
    ax.set(ylim=(0.000, 0.31))
    # NOTE(review): each bar has two identical observations, so the paired
    # differences have zero variance and the t statistic is degenerate
    # (inf/NaN). Meaningful p-values need the per-fold values.
    add_stat_annotation(ax, data=df, x='fold', y='value', hue='model', box_pairs=box_pairs,
                        test='t-test_paired', loc='inside', verbose=2)
    ax.legend(loc='upper right', frameon=False, prop={'size': 6})
    # Re-apply the fixed limits in case the annotation step changed them.
    ax.set(ylim=(0.000, 0.31))
    plt.yticks(np.arange(0.000, 0.31, 0.0285))
    
def grouped_barplot(df, cat, subcat, val, err):
    """Draw a grouped bar chart with error bars on the current matplotlib axes.

    There is one group of bars per unique value of ``df[cat]``, in order of
    first appearance. Inside each group there is one bar per unique value of
    ``df[subcat]``. Bar heights come from ``df[val]`` and symmetric error-bar
    sizes from ``df[err]``.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format data with at most one row per ``(cat, subcat)`` pair.
    cat, subcat, val, err : str
        Column names for the outer grouping, the inner grouping, the bar
        height and the error-bar size.
    """
    categories = df[cat].unique()
    x = np.arange(len(categories))
    subcats = df[subcat].unique()
    n_sub = len(subcats)
    # Centre the sub-bars on each group position. Each slot is 1/(n_sub+1)
    # wide, which leaves one slot of gap between neighbouring groups.
    offsets = (np.arange(n_sub) - np.arange(n_sub).mean()) / (n_sub + 1.)
    # Equal to np.diff(offsets).mean() for n_sub >= 2, but still finite when
    # n_sub == 1 (np.diff would be empty and its mean NaN).
    width = 1. / (n_sub + 1.)

    for offset, group in zip(offsets, subcats):
        # Align this sub-group to the category order, so a missing or reordered
        # row cannot shift bars into the wrong group.
        dfg = df[df[subcat] == group].set_index(cat).reindex(categories)
        plt.bar(x + offset, dfg[val].values, width=width, yerr=dfg[err].values)
                                    
def global_efficiency(A):
    """Return the normalised total connection weight of a square matrix.

    For an ``N x N`` matrix ``A`` this is ``sum(A) / (N * (N - 1))``: every
    column sum is scaled by ``1 / (N - 1)``, and the scaled sums are then
    averaged over the ``N`` columns. When the diagonal is zero, this equals
    the mean off-diagonal weight.

    NOTE(review): this is not the shortest-path-based global efficiency of
    graph theory (the mean inverse shortest-path length). No path lengths are
    computed here. Confirm that mean connection strength is the intended metric.
    """
    n_nodes = A.shape[0]
    column_totals = A.sum(axis=0)
    normalised = (1 / (n_nodes - 1)) * column_totals
    return sum(normalised) / n_nodes
def mean_global_efficiency(G):
    """Return the mean and sample standard deviation of global efficiency over folds.

    Parameters
    ----------
    G : indexable sequence of square arrays
        One matrix per cross-validation fold (presumably one CBT per fold).
        It is indexed as ``G[0] .. G[len(G) - 1]``.

    Returns
    -------
    (mean, std) : tuple
        Mean of ``global_efficiency`` over all folds in ``G`` and its sample
        standard deviation (``statistics.stdev``).

    Raises
    ------
    statistics.StatisticsError
        If ``G`` contains fewer than two folds.
    """
    # Each fold's efficiency is computed exactly once.
    per_fold = [global_efficiency(G[i]) for i in range(len(G))]
    # stdev first: it raises StatisticsError for < 2 folds before any division.
    std = statistics.stdev(per_fold)
    # Average over the actual number of folds, not a hard-coded 5.
    mean = sum(per_fold) / len(per_fold)
    return (mean, std)

def ground_truth_global_efficiency(A, offset=0.18):
    """Return the mean global efficiency over all subjects and views of ``A``, minus ``offset``.

    Parameters
    ----------
    A : numpy.ndarray
        4-D array indexed as ``A[subject, :, :, view]``. Each
        ``A[i, :, :, j]`` is passed to ``global_efficiency``.
    offset : float, default 0.18
        Constant subtracted from the mean before it is returned. The default
        reproduces this script's historical output.

    Returns
    -------
    float
        ``mean_{i,j} global_efficiency(A[i, :, :, j]) - offset``.

    NOTE(review): the 0.18 shift is not explained anywhere in this file. It
    lowers the value plotted as the "ground truth" (Average) bar. Confirm it
    is justified, or pass ``offset=0``, before reporting results.
    """
    n_subjects, n_views = A.shape[0], A.shape[3]
    total = 0
    for i in range(n_subjects):
        for j in range(n_views):
            total = total + global_efficiency(A[i, :, :, j])
    mean = total / (n_subjects * n_views)
    return mean - offset
           
def mean_ground_truth_global_efficiency(G):
    """Return the mean and sample std of ``ground_truth_global_efficiency`` across folds.

    Parameters
    ----------
    G : indexable sequence
        One entry per cross-validation fold, indexed as ``G[0] .. G[len(G) - 1]``.
        Each entry is passed to ``ground_truth_global_efficiency``, which
        indexes it as ``A[subject, :, :, view]``.

    Returns
    -------
    (mean, std) : tuple
        Mean over all folds in ``G`` and the sample standard deviation
        (``statistics.stdev``).

    Raises
    ------
    statistics.StatisticsError
        If ``G`` contains fewer than two folds.
    """
    per_fold = [ground_truth_global_efficiency(G[i]) for i in range(len(G))]
    std = statistics.stdev(per_fold)  # raises for < 2 folds
    # Divide by the real fold count instead of a hard-coded 5.
    mean = sum(per_fold) / len(per_fold)
    return (mean, std)

def ground_truth_global_efficiency_single_view(A, views=(2, 3), offset=0.085):
    """Return the mean global efficiency over all subjects for selected views, minus ``offset``.

    Despite the name, the default averages two views (indices 2 and 3).

    Parameters
    ----------
    A : numpy.ndarray
        4-D array indexed as ``A[subject, :, :, view]``.
    views : sequence of int, default (2, 3)
        View indices (last axis of ``A``) to include in the average.
    offset : float, default 0.085
        Constant subtracted from the mean. The default reproduces this
        script's historical output.

    NOTE(review): the 0.085 shift is unexplained in this file. Confirm it
    before reporting results derived from this function.
    """
    total = 0
    for i in range(A.shape[0]):
        for view in views:
            total = total + global_efficiency(A[i, :, :, view])
    mean = total / (len(views) * A.shape[0])
    return mean - offset

def mean_ground_truth_global_efficiency_single_view(G):
    """Return the mean and sample std of ``ground_truth_global_efficiency_single_view`` across folds.

    Parameters
    ----------
    G : indexable sequence
        One entry per cross-validation fold, indexed as ``G[0] .. G[len(G) - 1]``.

    Returns
    -------
    (mean, std) : tuple
        Mean over all folds in ``G`` and the sample standard deviation
        (``statistics.stdev``).

    Raises
    ------
    statistics.StatisticsError
        If ``G`` contains fewer than two folds.
    """
    per_fold = [ground_truth_global_efficiency_single_view(G[i]) for i in range(len(G))]
    std = statistics.stdev(per_fold)  # raises for < 2 folds
    # Divide by the real fold count instead of a hard-coded 5.
    mean = sum(per_fold) / len(per_fold)
    return (mean, std)
    
###################################################################################                        
# Mean and fold-to-fold standard deviation of each method's CBT global
# efficiency, per hemisphere (LH/RH) and sex.
# NOTE: Global_efficiency_multiview reads the std_* variables below as module
# globals by name, so they must keep exactly these names.
E_glo_netNorm_LH_male, std_netNorm_LH_male = mean_global_efficiency(CBT_netNorm_LH_male)
E_glo_DGN_LH_male, std_DGN_LH_male = mean_global_efficiency(CBT_DGN_LH_male)
E_glo_MVCFNet_LH_male, std_MVCFNet_LH_male = mean_global_efficiency(CBT_MVCFNet_LH_male)
E_glo_cMGINet_LH_male, std_cMGINet_LH_male = mean_global_efficiency(CBT_cMGINet_LH_male)
E_glo_SCA_LH_male, std_SCA_LH_male = mean_global_efficiency(CBT_SCA_LH_male)

E_glo_netNorm_RH_male, std_netNorm_RH_male = mean_global_efficiency(CBT_netNorm_RH_male)
E_glo_DGN_RH_male, std_DGN_RH_male = mean_global_efficiency(CBT_DGN_RH_male)
E_glo_MVCFNet_RH_male,std_MVCFNet_RH_male = mean_global_efficiency(CBT_MVCFNet_RH_male)
E_glo_cMGINet_RH_male, std_cMGINet_RH_male = mean_global_efficiency(CBT_cMGINet_RH_male)
E_glo_SCA_RH_male, std_SCA_RH_male = mean_global_efficiency(CBT_SCA_RH_male)

E_glo_netNorm_LH_female, std_netNorm_LH_female = mean_global_efficiency(CBT_netNorm_LH_female)
E_glo_DGN_LH_female, std_DGN_LH_female = mean_global_efficiency(CBT_DGN_LH_female)
E_glo_MVCFNet_LH_female, std_MVCFNet_LH_female = mean_global_efficiency(CBT_MVCFNet_LH_female)
E_glo_cMGINet_LH_female, std_cMGINet_LH_female = mean_global_efficiency(CBT_cMGINet_LH_female)
E_glo_SCA_LH_female, std_SCA_LH_female = mean_global_efficiency(CBT_SCA_LH_female)

E_glo_netNorm_RH_female, std_netNorm_RH_female = mean_global_efficiency(CBT_netNorm_RH_female)
E_glo_DGN_RH_female, std_DGN_RH_female = mean_global_efficiency(CBT_DGN_RH_female)
E_glo_MVCFNet_RH_female, std_MVCFNet_RH_female = mean_global_efficiency(CBT_MVCFNet_RH_female)
E_glo_cMGINet_RH_female, std_cMGINet_RH_female = mean_global_efficiency(CBT_cMGINet_RH_female)
E_glo_SCA_RH_female, std_SCA_RH_female = mean_global_efficiency(CBT_SCA_RH_female)

#################################################################################
# "Ground truth" (Average bar): efficiency of the test-set subjects of each fold
# (presumably -- CV_test_* come from entries 2/3 of the data partition).
# NOTE(review): ground_truth_global_efficiency subtracts a fixed offset from
# the measured value; confirm it before reporting these numbers as ground truth.
E_glob_ground_truth_LH_male, std_ground_truth_LH_male  = mean_ground_truth_global_efficiency(CV_test_male_LH)    
E_glob_ground_truth_LH_female, std_ground_truth_LH_female = mean_ground_truth_global_efficiency(CV_test_female_LH)    
E_glob_ground_truth_RH_male, std_ground_truth_RH_male = mean_ground_truth_global_efficiency(CV_test_male_RH)    
E_glob_ground_truth_RH_female, std_ground_truth_RH_female = mean_ground_truth_global_efficiency(CV_test_female_RH) 

################################################################################
# Per-method means in the population order Global_efficiency_multiview uses
# for its DataFrame index: [LH-female, LH-male, RH-female, RH-male].
E_glo_DGN = [ E_glo_DGN_LH_female, E_glo_DGN_LH_male,  E_glo_DGN_RH_female,  E_glo_DGN_RH_male]
E_glo_cMGINet = [ E_glo_cMGINet_LH_female, E_glo_cMGINet_LH_male,  E_glo_cMGINet_RH_female,  E_glo_cMGINet_RH_male]
E_glo_MVCFNet = [ E_glo_MVCFNet_LH_female, E_glo_MVCFNet_LH_male,  E_glo_MVCFNet_RH_female,  E_glo_MVCFNet_RH_male]
# NOTE(review): +0.01 is added to every netNorm value before plotting, and
# the reason is not documented here. This shifts a reported result; confirm
# it is intended.
E_glo_netNorm = [ E_glo_netNorm_LH_female+0.01, E_glo_netNorm_LH_male+0.01,  E_glo_netNorm_RH_female+0.01,  E_glo_netNorm_RH_male+0.01]
E_glo_SCA = [ E_glo_SCA_LH_female, E_glo_SCA_LH_male,  E_glo_SCA_RH_female,  E_glo_SCA_RH_male]
E_glo_ground_truth = [ E_glob_ground_truth_LH_female, E_glob_ground_truth_LH_male,  E_glob_ground_truth_RH_female,  E_glob_ground_truth_RH_male]

##################################################################################
# Draw the grouped bar chart with DGN-vs-others paired t-test annotations.
Global_efficiency_multiview(E_glo_netNorm, E_glo_SCA, E_glo_cMGINet, E_glo_MVCFNet, E_glo_DGN,E_glo_ground_truth)