# In [None]:
"""
@author: nadachaari.nc@gmail.com
"""

# This code computes and plots the average local efficiency across the 5 cross-validation folds
# for the multi-view fusion methods (SCA, netNorm, cMGINet, MVCF-Net, DGN).

import matplotlib.pyplot as plt
import numpy as np
import statistics
import pickle
import pandas as pd
import seaborn as sns
from statannot import add_stat_annotation


#################################################################################

# Load data for 4 populations: female LH, male LH, female RH and male RH.
# LH means left hemisphere.
# RH means right hemisphere.
# GSP is the Brain Genomics Superstruct Project dataset, which consists of healthy female and male populations.

def _load_pickle(path):
    """Unpickle and return the object stored in the file at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code -- only point
    this at trusted, locally generated files.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


# ---- Left hemisphere ----
args_dataset = 'LH_GSP'

data = _load_pickle(f'data_partition_{args_dataset}')
# Entries 2 and 3 of the partition are used as the male and female
# cross-validation test sets, respectively.
CV_test_male_LH = data[2]
CV_test_female_LH = data[3]

CBT_DGN_LH_male = _load_pickle(f'CBT_male_{args_dataset}_DGN')
CBT_cMGINet_LH_male = _load_pickle(f'CBT_male_{args_dataset}_cMGINet')
CBT_netNorm_LH_male = _load_pickle(f'CBT_male_{args_dataset}_netNorm')
CBT_MVCFNet_LH_male = _load_pickle(f'CBT_male_{args_dataset}_MVCF-Net')
CBT_SCA_LH_male = _load_pickle(f'CBT_male_{args_dataset}_SCA')

CBT_DGN_LH_female = _load_pickle(f'CBT_female_{args_dataset}_DGN')
CBT_cMGINet_LH_female = _load_pickle(f'CBT_female_{args_dataset}_cMGINet')
CBT_netNorm_LH_female = _load_pickle(f'CBT_female_{args_dataset}_netNorm')
CBT_MVCFNet_LH_female = _load_pickle(f'CBT_female_{args_dataset}_MVCF-Net')
CBT_SCA_LH_female = _load_pickle(f'CBT_female_{args_dataset}_SCA')

# ---- Right hemisphere ----
args_dataset = 'RH_GSP'

data = _load_pickle(f'data_partition_{args_dataset}')
CV_test_male_RH = data[2]
CV_test_female_RH = data[3]

CBT_DGN_RH_male = _load_pickle(f'CBT_male_{args_dataset}_DGN')
CBT_cMGINet_RH_male = _load_pickle(f'CBT_male_{args_dataset}_cMGINet')
CBT_netNorm_RH_male = _load_pickle(f'CBT_male_{args_dataset}_netNorm')
CBT_MVCFNet_RH_male = _load_pickle(f'CBT_male_{args_dataset}_MVCF-Net')
CBT_SCA_RH_male = _load_pickle(f'CBT_male_{args_dataset}_SCA')

CBT_DGN_RH_female = _load_pickle(f'CBT_female_{args_dataset}_DGN')
CBT_cMGINet_RH_female = _load_pickle(f'CBT_female_{args_dataset}_cMGINet')
CBT_netNorm_RH_female = _load_pickle(f'CBT_female_{args_dataset}_netNorm')
CBT_MVCFNet_RH_female = _load_pickle(f'CBT_female_{args_dataset}_MVCF-Net')
CBT_SCA_RH_female = _load_pickle(f'CBT_female_{args_dataset}_SCA')
    
######### multi-view fusion methods #########
def local_efficiency_multiview(E_loc_netNorm, E_loc_SCA, E_loc_cMGINet,
                               E_loc_MVCFNet, E_loc_DGN, E_loc_ground_truth,
                               errors=None):
    """Plot the per-population local efficiency of each fusion method.

    Draws a bar chart with one group per population and one bar per method
    (plus the "Average" ground truth). It overlays error bars with
    ``grouped_barplot`` and annotates paired t-tests of DGN against every
    other method with ``statannot.add_stat_annotation``.

    Parameters
    ----------
    E_loc_netNorm, E_loc_SCA, E_loc_cMGINet, E_loc_MVCFNet, E_loc_DGN, E_loc_ground_truth : sequence of float
        Four values each, ordered LH-female, LH-male, RH-female, RH-male.
    errors : array-like of shape (4, 6), optional
        Error-bar sizes. Rows follow the population order above; columns
        follow netNorm, cMGI-Net, SCA, MVCF-Net, DGN, Average. When omitted,
        the module-level ``std_<method>_<population>`` values are used, as
        before.

    Raises
    ------
    ValueError
        If ``errors`` does not have shape (4, 6).
    """
    populations = ['LH-female', 'LH-male', 'RH-female', 'RH-male']
    models = ['netNorm', 'cMGI-Net', 'SCA', 'MVCF-Net', 'DGN', 'Average']

    if errors is None:
        # Fall back to the std_* globals computed at module level.
        errors = [
            [std_netNorm_LH_female, std_cMGINet_LH_female, std_SCA_LH_female,
             std_MVCFNet_LH_female, std_DGN_LH_female, std_ground_truth_LH_female],
            [std_netNorm_LH_male, std_cMGINet_LH_male, std_SCA_LH_male,
             std_MVCFNet_LH_male, std_DGN_LH_male, std_ground_truth_LH_male],
            [std_netNorm_RH_female, std_cMGINet_RH_female, std_SCA_RH_female,
             std_MVCFNet_RH_female, std_DGN_RH_female, std_ground_truth_RH_female],
            [std_netNorm_RH_male, std_cMGINet_RH_male, std_SCA_RH_male,
             std_MVCFNet_RH_male, std_DGN_RH_male, std_ground_truth_RH_male],
        ]
    errorlist = np.asarray(errors, dtype=float)
    expected_shape = (len(populations), len(models))
    if errorlist.shape != expected_shape:
        raise ValueError(
            f"errors must have shape {expected_shape}, got {errorlist.shape}")

    plotdata = pd.DataFrame({
        "netNorm": E_loc_netNorm,
        "cMGI-Net": E_loc_cMGINet,
        "SCA": E_loc_SCA,
        "MVCF-Net": E_loc_MVCFNet,
        "DGN": E_loc_DGN,
        "Average": E_loc_ground_truth,
    }, index=populations)

    # Build both long-format tables from row lists in one step
    # (DataFrame.append was removed in pandas 2.0).
    columns = ['fold', 'model', 'value', 'error']
    bar_rows = []
    err_rows = []
    for i, population in enumerate(populations):
        for j, model in enumerate(models):
            row = {'fold': population, 'model': model,
                   'value': plotdata.iat[i, j]}
            # NOTE(review): each value is entered twice so seaborn/statannot
            # see two observations per bar. The copies are identical, so the
            # paired differences have zero variance and the t statistic is
            # infinite or undefined. The annotated p-values are therefore
            # not meaningful; per-fold values would be needed for a real test.
            bar_rows.extend([row, dict(row)])
            err_rows.append({**row, 'error': errorlist[i, j]})
    df = pd.DataFrame(bar_rows, columns=columns)
    df2 = pd.DataFrame(err_rows, columns=columns)

    x, y, hue, err = "fold", "value", "model", "error"
    # Compare DGN with every other method within each population.
    box_pairs = [
        ((population, "DGN"), (population, other))
        for population in populations
        for other in ("netNorm", "MVCF-Net", "SCA", "cMGI-Net", "Average")
    ]

    plt.rcParams['figure.dpi'] = 300
    plt.rcParams['savefig.dpi'] = 300
    # Palette and style only take effect for plots created afterwards, so
    # they must be set before the figure is drawn.
    colors = ['lightseagreen', 'tomato', 'mediumorchid',
              'lightgreen', 'magenta', 'purple']
    sns.set_palette(sns.color_palette(colors))
    sns.set_style("darkgrid")

    _, ax = plt.subplots(figsize=(8, 5))
    sns.barplot(data=df, x=x, y=y, hue=hue, hue_order=models, ax=ax)
    # Overlay bars carrying the std error bars on the same (current) axes.
    grouped_barplot(df2, x, hue, y, err)
    ax.set(xlabel=None, ylabel=None)  # remove both axis labels
    add_stat_annotation(ax, data=df, x=x, y=y, hue=hue, box_pairs=box_pairs,
                        test='t-test_paired', loc='inside', verbose=2)
    ax.legend(loc='upper right', frameon=False, prop={'size': 6})

def local_efficiency(A):
    """Return, for every node i, the summed edge weight of A with node i removed.

    For each i this sums ``A[j, k]`` over all ``j != i`` and ``k != i``
    (diagonal entries with ``j == k`` included). It then divides by
    ``N * (N - 1)`` with ``N = A.shape[0]``; note that this is not the
    ``(N - 1) ** 2`` terms that are summed.

    NOTE(review): despite the name, this is not the Latora-Marchiori local
    efficiency, which averages inverse shortest-path lengths within each
    node's neighbourhood. Confirm that this simplification is intended.

    Parameters
    ----------
    A : array-like, shape (N, N)
        Square connectivity matrix (anything ``np.asarray`` accepts).

    Returns
    -------
    list
        N per-node values (empty for a 0x0 matrix).

    Raises
    ------
    ValueError
        If A is not a square 2-D matrix, or has a single node (the divisor
        ``N * (N - 1)`` would be zero).
    """
    A = np.asarray(A, dtype=float)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"expected a square 2-D matrix, got shape {A.shape}")
    N = A.shape[0]
    if N == 0:
        return []
    if N == 1:
        raise ValueError("local_efficiency needs at least 2 nodes")
    # sum_{j != i, k != i} A[j, k] = total - row_i - col_i + A[i, i]
    # (A[i, i] is removed by both the row and the column term, so add it back once).
    removed_sums = A.sum() - A.sum(axis=1) - A.sum(axis=0) + np.diagonal(A)
    return list(removed_sums / (N * (N - 1)))
def mean_local_efficiency(G):
    """Average ``local_efficiency`` node-wise over every matrix in G.

    Parameters
    ----------
    G : sequence of square matrices
        Presumably one CBT per cross-validation fold (confirm against the
        pickled CBT files).

    Returns
    -------
    numpy.ndarray
        Element-wise mean of the per-matrix ``local_efficiency`` lists.
    """
    per_matrix = [local_efficiency(matrix) for matrix in G]
    return np.mean(per_matrix, axis=0)

def ground_truth_local_efficiency(A, offset=0.17):
    """Average ``local_efficiency`` over all 2-D slices of a 4-D stack, minus ``offset``.

    Every slice ``A[i, :, :, j]`` for ``i`` in ``range(A.shape[0])`` and ``j``
    in ``range(A.shape[3])`` is passed to ``local_efficiency``. The axes are
    presumably subject (axis 0) and view (axis 3); confirm this. The
    per-slice results are averaged node-wise.

    Parameters
    ----------
    A : numpy.ndarray
        4-D array indexed as ``A[i, :, :, j]``.
    offset : float, default 0.17
        Constant subtracted from the node-wise mean. NOTE(review): the
        original code hard-coded 0.17 without explanation. It shifts the
        ground-truth ("Average") baseline that the fusion methods are plotted
        against, so either justify it or pass ``offset=0``.

    Returns
    -------
    numpy.ndarray
        Node-wise mean minus ``offset``.
    """
    per_slice = [
        local_efficiency(A[i, :, :, j])
        for i in range(A.shape[0])
        for j in range(A.shape[3])
    ]
    return np.mean(per_slice, axis=0) - offset
            
def mean_ground_truth_local_efficiency(G):
    """Node-wise mean of ``ground_truth_local_efficiency`` over the items of G.

    G is iterated item by item. Each item is presumably one 4-D test-set
    array per cross-validation fold; confirm this against the data partition.
    """
    fold_scores = np.array([ground_truth_local_efficiency(fold) for fold in G])
    return fold_scores.mean(axis=0)

def grouped_barplot(df, cat, subcat, val, err, ax=None):
    """Draw a grouped bar chart with error bars from a long-format DataFrame.

    One group of bars is drawn per unique value of ``df[cat]``, in order of
    first appearance. Inside each group there is one bar per unique value of
    ``df[subcat]``, also in order of first appearance.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format data; each (cat, subcat) pair must appear at most once.
    cat, subcat : str
        Column names for the outer groups (x positions) and the inner bars.
    val, err : str
        Column names holding bar heights and symmetric error-bar sizes.
    ax : matplotlib Axes, optional
        Target axes. Defaults to the current axes (``plt.gca()``), which is
        where plain ``plt.bar`` draws.
    """
    if ax is None:
        ax = plt.gca()
    categories = df[cat].unique()
    x = np.arange(len(categories))
    subcategories = df[subcat].unique()
    n_sub = len(subcategories)
    # Bars are spaced 1 / (n_sub + 1) apart and centred on each x position.
    # Using that spacing directly as the width gives the same value as
    # np.diff(offsets).mean() for n_sub >= 2 and avoids a NaN width when
    # there is a single subcategory.
    width = 1.0 / (n_sub + 1)
    offsets = (np.arange(n_sub) - (n_sub - 1) / 2.0) * width

    for offset, sub in zip(offsets, subcategories):
        # Align heights/errors to `categories` so bars land on the right
        # x position regardless of row order; missing pairs become NaN.
        group = df[df[subcat] == sub].set_index(cat).reindex(categories)
        ax.bar(x + offset, group[val].to_numpy(), width=width,
               yerr=group[err].to_numpy())

def mean_std(G):
    """Return ``(mean, sample standard deviation)`` of the values in G.

    The mean is accumulated left to right starting from 0. The standard
    deviation comes from ``statistics.stdev`` (n - 1 denominator), which
    raises ``statistics.StatisticsError`` for fewer than two values. An empty
    G raises ``ZeroDivisionError`` first.
    """
    running_total = 0
    for value in G:
        running_total = running_total + value
    return running_total / len(G), statistics.stdev(G)

##############################################################
def _summarize(efficiency_fn, samples):
    """Apply ``efficiency_fn`` to ``samples`` and summarise the per-node result.

    Returns a ``(per_node, mean, std)`` triple, where ``mean`` and ``std``
    come from ``mean_std(per_node)``.
    """
    per_node = efficiency_fn(samples)
    node_mean, node_std = mean_std(per_node)
    return per_node, node_mean, node_std


# ---- Ground truth: multi-view test data of each population ----
(E_loc_ground_truth_LH_male,
 mean_ground_truth_LH_male,
 std_ground_truth_LH_male) = _summarize(mean_ground_truth_local_efficiency, CV_test_male_LH)

(E_loc_ground_truth_LH_female,
 mean_ground_truth_LH_female,
 std_ground_truth_LH_female) = _summarize(mean_ground_truth_local_efficiency, CV_test_female_LH)

(E_loc_ground_truth_RH_male,
 mean_ground_truth_RH_male,
 std_ground_truth_RH_male) = _summarize(mean_ground_truth_local_efficiency, CV_test_male_RH)

(E_loc_ground_truth_RH_female,
 mean_ground_truth_RH_female,
 std_ground_truth_RH_female) = _summarize(mean_ground_truth_local_efficiency, CV_test_female_RH)

# ---- LH male CBTs ----
(E_loc_netNorm_LH_male,
 mean_netNorm_LH_male,
 std_netNorm_LH_male) = _summarize(mean_local_efficiency, CBT_netNorm_LH_male)

(E_loc_DGN_LH_male,
 mean_DGN_LH_male,
 std_DGN_LH_male) = _summarize(mean_local_efficiency, CBT_DGN_LH_male)

(E_loc_MVCFNet_LH_male,
 mean_MVCFNet_LH_male,
 std_MVCFNet_LH_male) = _summarize(mean_local_efficiency, CBT_MVCFNet_LH_male)

(E_loc_cMGINet_LH_male,
 mean_cMGINet_LH_male,
 std_cMGINet_LH_male) = _summarize(mean_local_efficiency, CBT_cMGINet_LH_male)

(E_loc_SCA_LH_male,
 mean_SCA_LH_male,
 std_SCA_LH_male) = _summarize(mean_local_efficiency, CBT_SCA_LH_male)

# ---- RH male CBTs ----
(E_loc_netNorm_RH_male,
 mean_netNorm_RH_male,
 std_netNorm_RH_male) = _summarize(mean_local_efficiency, CBT_netNorm_RH_male)

(E_loc_DGN_RH_male,
 mean_DGN_RH_male,
 std_DGN_RH_male) = _summarize(mean_local_efficiency, CBT_DGN_RH_male)

(E_loc_MVCFNet_RH_male,
 mean_MVCFNet_RH_male,
 std_MVCFNet_RH_male) = _summarize(mean_local_efficiency, CBT_MVCFNet_RH_male)

(E_loc_cMGINet_RH_male,
 mean_cMGINet_RH_male,
 std_cMGINet_RH_male) = _summarize(mean_local_efficiency, CBT_cMGINet_RH_male)

(E_loc_SCA_RH_male,
 mean_SCA_RH_male,
 std_SCA_RH_male) = _summarize(mean_local_efficiency, CBT_SCA_RH_male)

# ---- LH female CBTs ----
(E_loc_netNorm_LH_female,
 mean_netNorm_LH_female,
 std_netNorm_LH_female) = _summarize(mean_local_efficiency, CBT_netNorm_LH_female)

(E_loc_DGN_LH_female,
 mean_DGN_LH_female,
 std_DGN_LH_female) = _summarize(mean_local_efficiency, CBT_DGN_LH_female)

(E_loc_MVCFNet_LH_female,
 mean_MVCFNet_LH_female,
 std_MVCFNet_LH_female) = _summarize(mean_local_efficiency, CBT_MVCFNet_LH_female)

(E_loc_cMGINet_LH_female,
 mean_cMGINet_LH_female,
 std_cMGINet_LH_female) = _summarize(mean_local_efficiency, CBT_cMGINet_LH_female)

(E_loc_SCA_LH_female,
 mean_SCA_LH_female,
 std_SCA_LH_female) = _summarize(mean_local_efficiency, CBT_SCA_LH_female)

# ---- RH female CBTs ----
(E_loc_netNorm_RH_female,
 mean_netNorm_RH_female,
 std_netNorm_RH_female) = _summarize(mean_local_efficiency, CBT_netNorm_RH_female)

(E_loc_DGN_RH_female,
 mean_DGN_RH_female,
 std_DGN_RH_female) = _summarize(mean_local_efficiency, CBT_DGN_RH_female)

(E_loc_MVCFNet_RH_female,
 mean_MVCFNet_RH_female,
 std_MVCFNet_RH_female) = _summarize(mean_local_efficiency, CBT_MVCFNet_RH_female)

(E_loc_cMGINet_RH_female,
 mean_cMGINet_RH_female,
 std_cMGINet_RH_female) = _summarize(mean_local_efficiency, CBT_cMGINet_RH_female)

(E_loc_SCA_RH_female,
 mean_SCA_RH_female,
 std_SCA_RH_female) = _summarize(mean_local_efficiency, CBT_SCA_RH_female)

#####################################################################
# Per-method means, one entry per population in the order
# LH-female, LH-male, RH-female, RH-male (the row order used for plotting).
E_loc_DGN = [
    mean_DGN_LH_female,
    mean_DGN_LH_male,
    mean_DGN_RH_female,
    mean_DGN_RH_male,
]
E_loc_cMGINet = [
    mean_cMGINet_LH_female,
    mean_cMGINet_LH_male,
    mean_cMGINet_RH_female,
    mean_cMGINet_RH_male,
]
E_loc_MVCFNet = [
    mean_MVCFNet_LH_female,
    mean_MVCFNet_LH_male,
    mean_MVCFNet_RH_female,
    mean_MVCFNet_RH_male,
]
E_loc_netNorm = [
    mean_netNorm_LH_female,
    mean_netNorm_LH_male,
    mean_netNorm_RH_female,
    mean_netNorm_RH_male,
]
E_loc_SCA = [
    mean_SCA_LH_female,
    mean_SCA_LH_male,
    mean_SCA_RH_female,
    mean_SCA_RH_male,
]
E_loc_ground_truth = [
    mean_ground_truth_LH_female,
    mean_ground_truth_LH_male,
    mean_ground_truth_RH_female,
    mean_ground_truth_RH_male,
]

local_efficiency_multiview(
    E_loc_netNorm=E_loc_netNorm,
    E_loc_SCA=E_loc_SCA,
    E_loc_cMGINet=E_loc_cMGINet,
    E_loc_MVCFNet=E_loc_MVCFNet,
    E_loc_DGN=E_loc_DGN,
    E_loc_ground_truth=E_loc_ground_truth,
)
