# In [None]:
"""
@author: nadachaari.nc@gmail.com
"""

# This code plots the Frobenius distances (per fold and their mean) between the templates (CBTs) generated by
# multi-view fusion methods (SCA, netNorm, cMGINet, MVCF-Net and DGN) learned from the training set and the
# multi-view networks of samples in the testing set, using 5-fold cross-validation.

import matplotlib.pyplot as plt
import numpy as np
import pickle
import pandas as pd
import seaborn as sns
from statannot import add_stat_annotation

# Load data. Results exist for 4 populations (female LH, male LH, female RH, male RH);
# this cell loads the female population of the dataset named by args_dataset.
# LH means left hemisphere
# RH means right hemisphere
# GSP is the Brain Genomics Superstruct Project dataset, which consists of healthy female and male populations.

args_dataset = 'RH_GSP'


def _load_result(path):
    """Unpickle and return the object stored in the file at *path*.

    pickle.load can execute arbitrary code, so only point this at trusted files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Frobenius distances of each fusion method for the female population
# (presumably one value per cross-validation fold -- confirm against the producer).
_frob_prefix = 'Frobenius_Dist_female_' + args_dataset + '_'
Frob_dist_SCA_female = _load_result(_frob_prefix + 'SCA')
Frob_dist_cMGINet_female = _load_result(_frob_prefix + 'cMGINet')
Frob_dist_MVCFNet_female = _load_result(_frob_prefix + 'MVCF-Net')
Frob_dist_netNorm_female = _load_result(_frob_prefix + 'netNorm')
Frob_dist_DGN_female = _load_result(_frob_prefix + 'DGN')

# Table used as the error-bar values of the plot at the end of this script.
std_female_multi = _load_result('std_female_multi_' + args_dataset)

def grouped_barplot(df, cat, subcat, val, err, ax=None):
    """Draw a grouped bar chart with symmetric error bars.

    One group of bars is drawn per unique value of column ``cat`` (groups sit at
    x = 0, 1, 2, ... in order of first appearance). Inside each group there is one
    bar per unique value of column ``subcat``. Bar heights come from column
    ``val``; error-bar half-lengths come from column ``err``.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table with at most one row per (cat, subcat) pair.
    cat, subcat, val, err : str
        Column names for the x-axis groups, the bars within a group, the bar
        heights and the error-bar half-lengths.
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to the current axes (``plt.gca()``), which is
        what ``plt.bar`` used before.

    Returns
    -------
    The Axes the bars were drawn on.
    """
    if ax is None:
        ax = plt.gca()
    categories = df[cat].unique()
    x = np.arange(len(categories))
    subcats = df[subcat].unique()
    n_sub = len(subcats)
    # Sub-bars are centred on each group's x position and spaced 1 / (n_sub + 1)
    # apart; that spacing is also the bar width. Computing it directly (instead
    # of np.diff(offsets).mean()) avoids a NaN width when n_sub == 1.
    width = 1.0 / (n_sub + 1)
    offsets = (np.arange(n_sub) - (n_sub - 1) / 2.0) * width

    for i, group in enumerate(subcats):
        dfg = df[df[subcat] == group]
        # Align rows to the category order so each bar lands in its own group,
        # even if rows are unsorted; a missing (cat, subcat) pair becomes NaN.
        dfg = dfg.set_index(cat).reindex(categories)
        ax.bar(x + offsets[i], dfg[val].values, width=width, yerr=dfg[err].values)
    return ax
    
def _with_mean(distances):
    """Return a new list holding *distances* followed by their mean; the input is not modified."""
    values = list(distances)
    return values + [np.mean(values)]


def Frob_plot_multiview(Frob_dist_SCA, Frob_dist_cMGINet, Frob_dist_MVCFNet,
                        Frob_dist_netNorm, Frob_dist_DGN, *, errors=None,
                        ylim=(12.8, 14.8)):
    """Bar-plot per-fold and mean Frobenius distances of five multi-view fusion methods.

    Each ``Frob_dist_*`` argument holds one distance per cross-validation fold
    for one method. A "mean" group is appended after the folds. DGN is compared
    with every other method inside each group using statannot paired t-test
    annotations.

    Parameters
    ----------
    Frob_dist_SCA, Frob_dist_cMGINet, Frob_dist_MVCFNet, Frob_dist_netNorm, Frob_dist_DGN : sequence of float
        Per-fold distances. All must have the same length. They are not modified.
    errors : array-like, optional
        2-D table of error-bar half-lengths indexed ``errors[row, col]``.
        Rows are ordered (Fold 1, ..., Fold n, mean); columns are ordered
        (netNorm, SCA, MVCF-Net, cMGI-Net, DGN). Defaults to the module-level
        ``errorlist``, which the script sets before calling this function.
    ylim : (float, float)
        y-axis limits. y ticks are placed every 0.25 starting at ``ylim[0]``
        (``np.arange``, so ``ylim[1]`` itself is excluded).

    Returns
    -------
    The Axes holding the plot.
    """
    if errors is None:
        # Backward compatibility with callers relying on the global ``errorlist``.
        errors = errorlist
    # np.asarray makes errors[i, j] work for nested lists as well as arrays.
    errors = np.asarray(errors)

    ######### multiview fusion methods #########
    models = ['netNorm', 'SCA', 'MVCF-Net', 'cMGI-Net', 'DGN']
    distances = [Frob_dist_netNorm, Frob_dist_SCA, Frob_dist_MVCFNet,
                 Frob_dist_cMGINet, Frob_dist_DGN]
    folds = ['Fold {}'.format(k) for k in range(1, len(Frob_dist_SCA) + 1)] + ['mean']

    # Build fresh lists rather than appending the mean to the caller's lists:
    # mutating them made a second call add a second "mean" and fail.
    plotdata = pd.DataFrame(
        {model: _with_mean(d) for model, d in zip(models, distances)},
        index=folds,
    )

    # Long-format tables built once from records (DataFrame.append was removed
    # in pandas 2.0). df feeds seaborn/statannot; df2 carries the error bars.
    stat_rows = []
    bar_rows = []
    for i, fold in enumerate(folds):
        for j, model in enumerate(models):
            row = {'fold': fold, 'model': model, 'value': plotdata.iat[i, j]}
            # NOTE(review): each value is entered twice, as in the original
            # script -- presumably so every (fold, model) group has more than
            # one observation. Identical pairs have zero variance, so the
            # paired t-test p-values are likely degenerate; confirm intent.
            stat_rows.append(dict(row))
            stat_rows.append(dict(row))
            bar_rows.append(dict(row, error=errors[i, j]))
    columns = ['fold', 'model', 'value', 'error']
    df = pd.DataFrame(stat_rows, columns=columns)
    df2 = pd.DataFrame(bar_rows, columns=columns)

    x, y, hue = 'fold', 'value', 'model'
    # Compare DGN with every other method within each fold and for the mean.
    box_pairs = [((fold, 'DGN'), (fold, other))
                 for fold in folds
                 for other in ('netNorm', 'MVCF-Net', 'SCA', 'cMGI-Net')]

    plt.rcParams['figure.dpi'] = 300
    plt.rcParams['savefig.dpi'] = 300
    # Style and palette must be set before drawing; previously they were set
    # after the bars existed and so never applied to this figure.
    sns.set_style("darkgrid")
    colors = ['lightseagreen', 'mediumorchid', 'lightgreen', 'tomato', 'magenta']
    sns.set_palette(sns.color_palette(colors))

    # Matplotlib bars draw the error bars from ``errors``; the seaborn bars on
    # the same axes are the ones passed to statannot and labelled in the legend.
    grouped_barplot(df2, x, hue, y, 'error')
    ax = sns.barplot(data=df, x=x, y=y, hue=hue, hue_order=models)
    ax.set(xlabel=None)  # remove the x axis label
    ax.set(ylabel=None)  # remove the y axis label
    plt.ylim(*ylim)
    plt.yticks(np.arange(ylim[0], ylim[1], 0.25))

    add_stat_annotation(ax, data=df, x=x, y=y, hue=hue, hue_order=models,
                        box_pairs=box_pairs, test='t-test_paired',
                        loc='inside', verbose=2)
    ax.legend(loc='upper right', frameon=False, prop={'size': 6})
    return ax
        
        
# Error-bar table read (as the module-level name ``errorlist``) by Frob_plot_multiview.
errorlist = std_female_multi
# Draw the comparison figure for the female population of args_dataset.
Frob_plot_multiview(
    Frob_dist_SCA_female,
    Frob_dist_cMGINet_female,
    Frob_dist_MVCFNet_female,
    Frob_dist_netNorm_female,
    Frob_dist_DGN_female,
)