In [None]:
"""
@author: nadachaari.nc@gmail.com
"""

# This code plots the Frobenius distances (per fold, plus their mean) between the templates (CBTs) generated by
# single-view fusion methods (SNF, NAGFS and SM-netFusion), learned from the training set, and the single-view
# networks of samples in the testing set, using 5-fold cross-validation.

import matplotlib.pyplot as plt
import numpy as np
import pickle
import pandas as pd
import seaborn as sns
from statannot import add_stat_annotation

# Load data for 4 populations (female LH, male LH, female RH, male RH).
# LH means left hemisphere.
# RH means right hemisphere.
# GSP is the Brain Genomics Superstruct Project dataset, which consists of healthy female and male populations.
# Here only the female population of the dataset selected by `args_dataset` is loaded.

args_dataset = 'RH_GSP'

# Per-fold Frobenius distances of the female population, one pickle file per fusion method.
# NOTE(review): pickle.load can execute arbitrary code; only load files produced by this project.
with open(f'Frobenius_Dist_female_{args_dataset}_SNF', 'rb') as f:
    Frob_dist_SNF_female = pickle.load(f)
with open(f'Frobenius_Dist_female_{args_dataset}_SM-netFusion', 'rb') as f:
    Frob_dist_SM_netFusion_female = pickle.load(f)
with open(f'Frobenius_Dist_female_{args_dataset}_NAGFS', 'rb') as f:
    Frob_dist_NAGFS_female = pickle.load(f)

# Error-bar values for the single-view plot (presumably standard deviations, per the file name).
with open(f'std_female_single_{args_dataset}', 'rb') as f:
    std_female_single = pickle.load(f)


def grouped_barplot(df, cat, subcat, val, err):
    """Draw grouped bars with error bars on the current matplotlib axes.

    One group of bars is drawn per unique value of ``df[cat]`` (in order of
    first appearance). Inside each group there is one bar per unique value of
    ``df[subcat]``. Bar heights come from ``df[val]`` and error bars from
    ``df[err]``.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format data with at most one row per (cat, subcat) pair.
    cat, subcat, val, err : str
        Column names for the group, the sub-group, the bar height and the
        error-bar size.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the bars were drawn on.
    """
    groups = df[cat].unique()
    positions = np.arange(len(groups))
    subgroups = df[subcat].unique()
    n_sub = len(subgroups)

    # Sub-bars are spaced 1/(n_sub + 1) apart and centred on each group position.
    offsets = (np.arange(n_sub) - np.arange(n_sub).mean()) / (n_sub + 1.)
    # The spacing is also the bar width. Computing it directly, rather than via
    # np.diff(offsets).mean(), avoids a NaN width when there is one sub-group.
    width = 1. / (n_sub + 1.)

    for offset, sub in zip(offsets, subgroups):
        # Align this sub-group's rows with the group order, so a missing or
        # reordered row cannot shift a bar into the wrong group (missing -> NaN).
        dfg = df[df[subcat] == sub].set_index(cat).reindex(groups)
        plt.bar(positions + offset, dfg[val].values, width=width, yerr=dfg[err].values)
    return plt.gca()

######### single-view fusion methods #########    
def Frob_plot_singleview(Frob_dist_SNF, Frob_dist_SM_netFusion, Frob_dist_NAGFS,
                         errors=None, ylim=(41.9, 43.2)):
    """Plot per-fold and mean Frobenius distances of three single-view fusion methods.

    For each method, the per-fold distances and their mean over the folds are
    shown as grouped bars. Bars are drawn twice on the current axes: once by
    ``grouped_barplot`` (carrying the supplied error bars) and once by
    ``sns.barplot``. ``add_stat_annotation`` then adds paired t-test markers
    comparing SM-netFusion against SNF and NAG-FS at every fold and at the mean.

    Parameters
    ----------
    Frob_dist_SNF, Frob_dist_SM_netFusion, Frob_dist_NAGFS : sequence of float
        Per-fold distances of each method. All three must have the same length
        (the number of folds). The inputs are not modified.
    errors : 2-D array-like, optional
        Error-bar values indexed as ``errors[row, col]``. Rows are the folds
        followed by the mean; columns are NAG-FS, SNF, SM-netFusion (in that
        order). Defaults to the module-level ``errorlist``, for backward
        compatibility.
    ylim : (float, float), optional
        y-axis limits. y ticks are placed every 0.1 inside this range.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the plot was drawn on.

    Raises
    ------
    ValueError
        If the three methods do not have the same number of folds.
    """
    if errors is None:
        # Backward-compatible default: error bars historically came from a global.
        errors = errorlist
    errors = np.asarray(errors)

    models = ['NAG-FS', 'SNF', 'SM-netFusion']
    # Work on copies so the caller's lists are left untouched. Appending the mean
    # in place made a second call fail, because the lists outgrew the row index.
    per_model = [list(Frob_dist_NAGFS), list(Frob_dist_SNF), list(Frob_dist_SM_netFusion)]
    n_folds = len(per_model[0])
    if any(len(values) != n_folds for values in per_model):
        raise ValueError('all methods must provide the same number of folds')
    for values in per_model:
        values.append(np.mean(values))  # mean over the folds, shown as an extra group
    labels = [f'Fold {k}' for k in range(1, n_folds + 1)] + ['mean']

    # Long-format table: one record per (fold-or-mean, model) pair. DataFrame.append
    # was removed in pandas 2.0, so the frames are built from the records directly.
    records = [
        {'fold': label, 'model': model, 'value': per_model[j][i], 'error': errors[i, j]}
        for i, label in enumerate(labels)
        for j, model in enumerate(models)
    ]
    columns = ['fold', 'model', 'value', 'error']
    df1 = pd.DataFrame(records, columns=columns)
    # The same table with every record repeated twice feeds sns.barplot and the
    # statannot tests. NOTE(review): with two identical values per group, the paired
    # t-test sees zero-variance differences, so its p-values are likely not meaningful.
    df = pd.DataFrame([rec for rec in records for _ in range(2)], columns=columns)

    # Compare SM-netFusion with each baseline at every fold and at the mean.
    box_pairs = [((label, 'SM-netFusion'), (label, baseline))
                 for label in labels for baseline in ('SNF', 'NAG-FS')]

    plt.rcParams['figure.dpi'] = 300
    plt.rcParams['savefig.dpi'] = 300
    # Palette and style only affect what is drawn afterwards, so set them before
    # plotting (previously they were applied after the bars were already drawn).
    sns.set_palette(sns.color_palette(['peru', 'gold', 'navajowhite']))
    sns.set_style('darkgrid')

    x, y, hue, err = 'fold', 'value', 'model', 'error'
    grouped_barplot(df1, x, hue, y, err)  # bars carrying the supplied error bars
    ax = sns.barplot(data=df, x=x, y=y, hue=hue)

    y_low, y_high = ylim
    y_ticks = np.arange(y_low, y_high, 0.1)
    ax.set(xlabel=None, ylabel=None, ylim=(y_low, y_high))
    plt.yticks(y_ticks)
    add_stat_annotation(ax, data=df, x=x, y=y, hue=hue, box_pairs=box_pairs,
                        test='t-test_paired', loc='inside', verbose=2)
    ax.legend(frameon=False, prop={'size': 6})
    # Re-apply limits and ticks, since the annotation step may have changed them.
    ax.set(ylim=(y_low, y_high))
    plt.yticks(y_ticks)
    return ax
           
# Frob_plot_singleview reads its error-bar values from this module-level name.
errorlist = std_female_single

# Plot the female single-view results for the dataset selected by args_dataset.
Frob_plot_singleview(
    Frob_dist_SNF=Frob_dist_SNF_female,
    Frob_dist_SM_netFusion=Frob_dist_SM_netFusion_female,
    Frob_dist_NAGFS=Frob_dist_NAGFS_female,
)