# In [None]:
"""
@author: nadachaari.nc@gmail.com
"""

# This code generates the average local efficiency across 5-fold cross-validation for the single-view fusion methods
# (SNF, SM-netFusion, and NAGFS).

import matplotlib.pyplot as plt
import numpy as np
import statistics
import pickle
import pandas as pd
import seaborn as sns
from statannot import add_stat_annotation


#################################################################################

# Load data for 4 populations: female LH, male LH, female RH, male RH.
# LH means left hemisphere.
# RH means right hemisphere.
# GSP is the Brain Genomics Superstruct Project dataset, which consists of healthy female and male populations.


def _load_pickle(path):
    """Return the object unpickled from the file at *path*.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only load files
    produced by this project's own pipeline.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


args_dataset = 'LH_GSP'

data = _load_pickle('data_partition_' + args_dataset)
# NOTE(review): data_single_views is loaded but not read anywhere in this file.
data_single_views = _load_pickle('data_partition_single_views_' + args_dataset)

# Per the variable names, entries 2 and 3 of the partition are the male / female test folds.
CV_test_male_LH_single_views = data[2]
CV_test_female_LH_single_views = data[3]

CBT_SNF_LH_male = _load_pickle('CBT_male_' + args_dataset + '_' + 'SNF')
CBT_NAGFS_LH_male = _load_pickle('CBT_male_' + args_dataset + '_' + 'NAGFS')
CBT_SMnetFusion_LH_male = _load_pickle('CBT_male_' + args_dataset + '_' + 'SM-netFusion')

CBT_SNF_LH_female = _load_pickle('CBT_female_' + args_dataset + '_' + 'SNF')
CBT_NAGFS_LH_female = _load_pickle('CBT_female_' + args_dataset + '_' + 'NAGFS')
CBT_SMnetFusion_LH_female = _load_pickle('CBT_female_' + args_dataset + '_' + 'SM-netFusion')

args_dataset = 'RH_GSP'

# The RH partition must be reloaded here. Previously only data_single_views was
# reloaded, so the RH test views below silently reused the LH `data`.
data = _load_pickle('data_partition_' + args_dataset)
data_single_views = _load_pickle('data_partition_single_views_' + args_dataset)

CV_test_male_RH_single_views = data[2]
CV_test_female_RH_single_views = data[3]

CBT_SNF_RH_male = _load_pickle('CBT_male_' + args_dataset + '_' + 'SNF')
CBT_NAGFS_RH_male = _load_pickle('CBT_male_' + args_dataset + '_' + 'NAGFS')
CBT_SMnetFusion_RH_male = _load_pickle('CBT_male_' + args_dataset + '_' + 'SM-netFusion')

CBT_SNF_RH_female = _load_pickle('CBT_female_' + args_dataset + '_' + 'SNF')
CBT_NAGFS_RH_female = _load_pickle('CBT_female_' + args_dataset + '_' + 'NAGFS')
CBT_SMnetFusion_RH_female = _load_pickle('CBT_female_' + args_dataset + '_' + 'SM-netFusion')

######### single-view fusion methods #########
def local_efficiency_singleview(E_loc_NAGFS, E_loc_SNF, E_loc_SMnetFusion, E_loc_ground_truth_single_view):
    """Plot the per-population mean scores of the single-view fusion methods.

    Draws a grouped bar chart with one group per population (LH-female,
    LH-male, RH-female, RH-male) and one bar per method (NAG-FS, SNF,
    SM-netFusion, Average). The chart has error bars and annotations for paired
    t-tests comparing SM-netFusion against each other method.

    Args:
        E_loc_NAGFS, E_loc_SNF, E_loc_SMnetFusion, E_loc_ground_truth_single_view:
            Four values each, ordered [LH-female, LH-male, RH-female, RH-male].

    Error-bar heights are read from the module-level
    ``std_<method>_<hemisphere>_<sex>`` globals computed by the script.
    """
    populations = ['LH-female', 'LH-male', 'RH-female', 'RH-male']
    models = ['NAG-FS', 'SNF', 'SM-netFusion', 'Average']

    # Rows follow `populations`; columns follow `models`.
    errorlist = np.array([
        [std_NAGFS_LH_female, std_SNF_LH_female,
         std_SMnetFusion_LH_female, std_ground_truth_single_view_LH_female],
        [std_NAGFS_LH_male, std_SNF_LH_male,
         std_SMnetFusion_LH_male, std_ground_truth_single_view_LH_male],
        [std_NAGFS_RH_female, std_SNF_RH_female,
         std_SMnetFusion_RH_female, std_ground_truth_single_view_RH_female],
        # The RH-male "Average" error bar previously used the LH-male std by mistake.
        [std_NAGFS_RH_male, std_SNF_RH_male,
         std_SMnetFusion_RH_male, std_ground_truth_single_view_RH_male],
    ], dtype=float)

    plotdata1 = pd.DataFrame({
        "NAG-FS": E_loc_NAGFS,
        "SNF": E_loc_SNF,
        "SM-netFusion": E_loc_SMnetFusion,
        "Average": E_loc_ground_truth_single_view,
    }, index=populations)

    # Build rows once and construct the DataFrames directly.
    # DataFrame.append was removed in pandas 2.0.
    columns = ['fold', 'model', 'value', 'error']
    rows = [
        {'fold': fold, 'model': model, 'value': plotdata1.iat[i, j], 'error': errorlist[i, j]}
        for i, fold in enumerate(populations)
        for j, model in enumerate(models)
    ]
    # One row per (population, method): drawn by grouped_barplot with explicit error bars.
    df1 = pd.DataFrame(rows, columns=columns)
    # Each row appears twice (adjacent), as before; used by seaborn and the stat test.
    # NOTE(review): each group then contains the same value twice, so the paired
    # t-test sees zero-variance differences and its p-values are not meaningful.
    # Confirm intent.
    df = pd.DataFrame([row for row in rows for _ in range(2)], columns=columns)

    x = "fold"
    y = "value"
    hue = "model"
    err = "error"
    # SM-netFusion vs each other method, for every population.
    box_pairs = [
        (("LH-female", "SM-netFusion"), ("LH-female", "SNF")),
        (("LH-female", "SM-netFusion"), ("LH-female", "NAG-FS")),
        (("LH-female", "SM-netFusion"), ("LH-female", "Average")),
        (("LH-male", "SM-netFusion"), ("LH-male", "NAG-FS")),
        (("LH-male", "SM-netFusion"), ("LH-male", "SNF")),
        (("LH-male", "SM-netFusion"), ("LH-male", "Average")),
        (("RH-female", "SM-netFusion"), ("RH-female", "SNF")),
        (("RH-female", "SM-netFusion"), ("RH-female", "Average")),
        # Previously a duplicate of the SNF pair; the NAG-FS comparison was missing.
        (("RH-female", "SM-netFusion"), ("RH-female", "NAG-FS")),
        (("RH-male", "SM-netFusion"), ("RH-male", "NAG-FS")),
        (("RH-male", "SM-netFusion"), ("RH-male", "SNF")),
        (("RH-male", "SM-netFusion"), ("RH-male", "Average")),
    ]

    plt.rcParams['figure.dpi'] = 300
    plt.rcParams['savefig.dpi'] = 300
    # Palette and style must be set before drawing. Set afterwards (as before),
    # they did not affect this figure.
    colors = ['peru', 'gold', 'navajowhite', 'brown']
    sns.set_palette(sns.color_palette(colors))
    sns.set_style("darkgrid")

    _fig, ax = plt.subplots(figsize=(6, 5))
    sns.barplot(data=df, x=x, y=y, hue=hue)
    grouped_barplot(df1, x, hue, y, err)
    ax.set(xlabel=None)  # remove the x axis label
    ax.set(ylabel=None)  # remove the y axis label
    add_stat_annotation(ax, data=df, x=x, y=y, hue=hue, box_pairs=box_pairs,
                        test='t-test_paired', loc='inside', verbose=2)
    ax.legend(loc='upper right', frameon=False, prop={'size': 6})
    
    
def local_efficiency(A):
    """Return, for each node i, the normalised edge-weight sum of A with node i removed.

    For every node i, this sums A[j, k] over all j != i and k != i. That is the
    sub-matrix left after deleting row i and column i, with its diagonal
    entries included. The sum is divided by N * (N - 1), where
    N = A.shape[0].

    NOTE(review): despite the name, this does not use neighbourhoods or
    shortest paths, so it is not the Latora-Marchiori local efficiency. Confirm
    this is the intended measure.

    Args:
        A: N x N matrix (array-like; only the top-left N x N part is read).

    Returns:
        list of N values, one per node. With fewer than two nodes there is no
        N * (N - 1) normaliser, so each node gets 0.0 instead of a division
        by zero.
    """
    A = np.asarray(A)
    N = A.shape[0]
    if N < 2:
        return [0.0] * N
    sub = A[:N, :N]
    # Deleting row i and column i subtracts A[i, i] twice, so add it back once.
    # This is O(N^2) instead of the former O(N^3) Python triple loop.
    remaining = sub.sum() - sub.sum(axis=1) - sub.sum(axis=0) + np.diag(sub)
    return list(remaining / (N * (N - 1)))
def mean_local_efficiency(G):
    """Average the per-node ``local_efficiency`` profile over a collection of graphs.

    Args:
        G: indexable sequence of matrices. All of them presumably have the
            same node count; TODO confirm.

    Returns:
        The element-wise mean (``np.mean(..., axis=0)``) of
        ``local_efficiency(G[idx])`` over every index of G.
    """
    profiles = [local_efficiency(G[idx]) for idx in range(len(G))]
    return np.mean(profiles, axis=0)

def ground_truth_local_efficiency_single_view(A, view=3, offset=0.006):
    """Return the mean ``local_efficiency`` profile of one view across subjects, plus an offset.

    Args:
        A: array indexed as ``A[subject, :, :, view]``. Presumably it has shape
            (n_subjects, N, N, n_views); TODO confirm.
        view: index along the last axis to evaluate. Defaults to 3, the value
            previously hard-coded.
        offset: constant added to every entry of the result. Defaults to 0.006,
            the value previously hard-coded.

    Returns:
        The element-wise mean over subjects of ``local_efficiency(A[i, :, :, view])``,
        plus ``offset``.

    NOTE(review): nothing in this file justifies adding ``offset`` to a
    ground-truth measurement. It shifts the "Average" bars that the fusion
    methods are compared against. Confirm it is intended, or call with
    ``offset=0``.
    """
    profiles = [local_efficiency(A[i, :, :, view]) for i in range(A.shape[0])]
    return np.mean(profiles, axis=0) + offset
            
def mean_ground_truth_local_efficiency_single_view(G):
    """Average the ground-truth per-node profile over folds.

    Args:
        G: indexable sequence of folds. Each fold is passed unchanged to
            ``ground_truth_local_efficiency_single_view``.

    Returns:
        The element-wise mean (``np.mean(..., axis=0)``) of the per-fold profiles.
    """
    fold_profiles = [ground_truth_local_efficiency_single_view(G[k]) for k in range(len(G))]
    return np.mean(fold_profiles, axis=0)

def grouped_barplot(df, cat, subcat, val, err):
    """Draw grouped bars with error bars on the current matplotlib axes.

    There is one group per unique value of column *cat*, positioned at
    0, 1, 2, ... Within each group there is one bar per unique value of
    *subcat*. Bar heights come from column *val* and error-bar sizes from
    column *err*.

    Bars are placed positionally, so this assumes every *subcat* value has
    exactly one row per *cat* value, in the order of ``df[cat].unique()``.
    """
    categories = df[cat].unique()
    x = np.arange(len(categories))
    subcategories = df[subcat].unique()
    n_sub = len(subcategories)
    # Centre the sub-bars around each group position, spaced 1 / (n_sub + 1) apart.
    offsets = (np.arange(n_sub) - np.arange(n_sub).mean()) / (n_sub + 1.)
    # With a single sub-category np.diff(offsets) is empty and its mean is NaN,
    # which makes the bars invisible. Fall back to the offset spacing.
    width = np.diff(offsets).mean() if n_sub > 1 else 1. / (n_sub + 1.)

    for i, gr in enumerate(subcategories):
        dfg = df[df[subcat] == gr]
        plt.bar(x + offsets[i], dfg[val].values, width=width, yerr=dfg[err].values)

def mean_std(G):
    """Return ``(mean, sample standard deviation)`` of the values in G.

    The mean is the left-to-right sum divided by ``len(G)``. The standard
    deviation is ``statistics.stdev`` (n - 1 denominator). An empty G raises
    ZeroDivisionError; a single value raises ``statistics.StatisticsError``.
    """
    count = len(G)
    average = sum(G) / count
    spread = statistics.stdev(G)
    return (average, spread)
##############################################################
# Ground-truth ("Average") per-node profiles and their mean/std, one per population.
# Each is computed exactly once; the second, identical recomputation was removed.
E_loc_ground_truth_single_view_LH_male = mean_ground_truth_local_efficiency_single_view(CV_test_male_LH_single_views)
mean_ground_truth_single_view_LH_male, std_ground_truth_single_view_LH_male = mean_std(E_loc_ground_truth_single_view_LH_male)

E_loc_ground_truth_single_view_LH_female = mean_ground_truth_local_efficiency_single_view(CV_test_female_LH_single_views)
mean_ground_truth_single_view_LH_female, std_ground_truth_single_view_LH_female = mean_std(E_loc_ground_truth_single_view_LH_female)

E_loc_ground_truth_single_view_RH_male = mean_ground_truth_local_efficiency_single_view(CV_test_male_RH_single_views)
mean_ground_truth_single_view_RH_male, std_ground_truth_single_view_RH_male = mean_std(E_loc_ground_truth_single_view_RH_male)

E_loc_ground_truth_single_view_RH_female = mean_ground_truth_local_efficiency_single_view(CV_test_female_RH_single_views)
mean_ground_truth_single_view_RH_female, std_ground_truth_single_view_RH_female = mean_std(E_loc_ground_truth_single_view_RH_female)

##############################################################
# LH male
E_loc_NAGFS_LH_male = mean_local_efficiency(CBT_NAGFS_LH_male)
mean_NAGFS_LH_male, std_NAGFS_LH_male = mean_std(E_loc_NAGFS_LH_male)

E_loc_SMnetFusion_LH_male = mean_local_efficiency(CBT_SMnetFusion_LH_male)
mean_SMnetFusion_LH_male, std_SMnetFusion_LH_male = mean_std(E_loc_SMnetFusion_LH_male)

E_loc_SNF_LH_male = mean_local_efficiency(CBT_SNF_LH_male)
mean_SNF_LH_male, std_SNF_LH_male = mean_std(E_loc_SNF_LH_male)

######################################################
# RH male
E_loc_NAGFS_RH_male = mean_local_efficiency(CBT_NAGFS_RH_male)
mean_NAGFS_RH_male, std_NAGFS_RH_male = mean_std(E_loc_NAGFS_RH_male)

E_loc_SMnetFusion_RH_male = mean_local_efficiency(CBT_SMnetFusion_RH_male)
mean_SMnetFusion_RH_male, std_SMnetFusion_RH_male = mean_std(E_loc_SMnetFusion_RH_male)

E_loc_SNF_RH_male = mean_local_efficiency(CBT_SNF_RH_male)
mean_SNF_RH_male, std_SNF_RH_male = mean_std(E_loc_SNF_RH_male)

############################################################
# LH female
E_loc_NAGFS_LH_female = mean_local_efficiency(CBT_NAGFS_LH_female)
mean_NAGFS_LH_female, std_NAGFS_LH_female = mean_std(E_loc_NAGFS_LH_female)

E_loc_SMnetFusion_LH_female = mean_local_efficiency(CBT_SMnetFusion_LH_female)
mean_SMnetFusion_LH_female, std_SMnetFusion_LH_female = mean_std(E_loc_SMnetFusion_LH_female)

E_loc_SNF_LH_female = mean_local_efficiency(CBT_SNF_LH_female)
mean_SNF_LH_female, std_SNF_LH_female = mean_std(E_loc_SNF_LH_female)

##############################################################
# RH female
E_loc_NAGFS_RH_female = mean_local_efficiency(CBT_NAGFS_RH_female)
mean_NAGFS_RH_female, std_NAGFS_RH_female = mean_std(E_loc_NAGFS_RH_female)

E_loc_SMnetFusion_RH_female = mean_local_efficiency(CBT_SMnetFusion_RH_female)
mean_SMnetFusion_RH_female, std_SMnetFusion_RH_female = mean_std(E_loc_SMnetFusion_RH_female)

E_loc_SNF_RH_female = mean_local_efficiency(CBT_SNF_RH_female)
mean_SNF_RH_female, std_SNF_RH_female = mean_std(E_loc_SNF_RH_female)

##################################################################
# Per-method means in plotting order: [LH-female, LH-male, RH-female, RH-male].
# The SM-netFusion and NAGFS lists previously put the RH_male mean in the LH-male slot.
E_loc_SMnetFusion = [mean_SMnetFusion_LH_female, mean_SMnetFusion_LH_male, mean_SMnetFusion_RH_female, mean_SMnetFusion_RH_male]
E_loc_NAGFS = [mean_NAGFS_LH_female, mean_NAGFS_LH_male, mean_NAGFS_RH_female, mean_NAGFS_RH_male]
E_loc_SNF = [mean_SNF_LH_female, mean_SNF_LH_male, mean_SNF_RH_female, mean_SNF_RH_male]
E_loc_ground_truth_single_view = [mean_ground_truth_single_view_LH_female, mean_ground_truth_single_view_LH_male, mean_ground_truth_single_view_RH_female, mean_ground_truth_single_view_RH_male]

# Plot the average scores of the single-view fusion methods across the 5 cross-validation folds.
local_efficiency_singleview(E_loc_NAGFS, E_loc_SNF, E_loc_SMnetFusion, E_loc_ground_truth_single_view)
