# In [None]:
"""
@author: nadachaari.nc@gmail.com
"""

# This code generates the average Jaccard distance across 5-fold cross-validation
# between all pairs of the single-view fusion methods (SNF, SM-netFusion, and NAGFS).

import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Load data for 4 populations: female LH, male LH, female RH, male RH.
# LH means left hemisphere.
# RH means right hemisphere.
# GSP is the Brain Genomics Superstruct Project dataset, which consists of healthy female and male populations.

def _load_cbt(sex, dataset, method):
    """Unpickle and return the object stored in file ``CBT_<sex>_<dataset>_<method>``.

    NOTE: ``pickle.load`` can execute arbitrary code, so only point this at
    files produced by this project's own pipeline.
    """
    path = 'CBT_{}_{}_{}'.format(sex, dataset, method)
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Left hemisphere: male first, then female (same order as the original notebook).
args_dataset = 'LH_GSP'

CBT_SNF_LH_male = _load_cbt('male', args_dataset, 'SNF')
CBT_NAGFS_LH_male = _load_cbt('male', args_dataset, 'NAGFS')
CBT_SMnetFusion_LH_male = _load_cbt('male', args_dataset, 'SM-netFusion')

CBT_SNF_LH_female = _load_cbt('female', args_dataset, 'SNF')
CBT_NAGFS_LH_female = _load_cbt('female', args_dataset, 'NAGFS')
CBT_SMnetFusion_LH_female = _load_cbt('female', args_dataset, 'SM-netFusion')

# Right hemisphere: same layout; args_dataset is left set to 'RH_GSP' afterwards.
args_dataset = 'RH_GSP'

CBT_SNF_RH_male = _load_cbt('male', args_dataset, 'SNF')
CBT_NAGFS_RH_male = _load_cbt('male', args_dataset, 'NAGFS')
CBT_SMnetFusion_RH_male = _load_cbt('male', args_dataset, 'SM-netFusion')

CBT_SNF_RH_female = _load_cbt('female', args_dataset, 'SNF')
CBT_NAGFS_RH_female = _load_cbt('female', args_dataset, 'NAGFS')
CBT_SMnetFusion_RH_female = _load_cbt('female', args_dataset, 'SM-netFusion')

def heat_map(CBT_SNF, CBT_SMnetFusion, CBT_NAGFS):
    """Tabulate the mean Jaccard distance between every pair of fusion methods.

    Each argument is handed unchanged to ``mean_Jaccard_dist`` (presumably a
    per-fold sequence of CBT matrices -- confirm against the loader).

    Returns:
        A 3x3 float ``DataFrame`` with rows and columns labelled
        'SNF', 'SMnetFusion', 'NAGFS' (in that order). Cell (r, c) holds
        ``mean_Jaccard_dist(<method r>, <method c>)``.
    """
    labels = ['SNF', 'SMnetFusion', 'NAGFS']
    methods = [CBT_SNF, CBT_SMnetFusion, CBT_NAGFS]

    # Row-major evaluation: the outer loop picks the row method, the inner one the column.
    table = [
        [mean_Jaccard_dist(row_cbt, col_cbt) for col_cbt in methods]
        for row_cbt in methods
    ]
    frame = pd.DataFrame(table, index=list(labels), columns=list(labels))
    return frame.astype(float)

def Jaccard_dist(G1, G2):
    """Return the weighted Jaccard distance between two weighted graphs.

    distance = 1 - sum(min(G1, G2)) / sum(max(G1, G2)),
    with min/max taken element-wise over the two matrices.

    Args:
        G1, G2: array-like matrices of identical shape (the connectivity
            matrices / CBTs in this script). Weights are assumed
            non-negative; the distance is not meaningful otherwise.

    Returns:
        A float. For non-negative inputs it lies in [0, 1]. Returns 0.0
        when the max-sum is zero (e.g. both matrices all-zero), instead of
        dividing by zero.

    Raises:
        ValueError: if the two matrices do not have the same shape.
    """
    G1 = np.asarray(G1, dtype=float)
    G2 = np.asarray(G2, dtype=float)
    if G1.shape != G2.shape:
        raise ValueError(
            "Jaccard_dist needs matrices of the same shape, got {} and {}".format(
                G1.shape, G2.shape))

    # Vectorised element-wise min/max replaces the original O(N^2) Python loop.
    sum_min = np.minimum(G1, G2).sum()
    sum_max = np.maximum(G1, G2).sum()
    if sum_max == 0:
        # Two empty graphs are identical, so their distance is 0.
        return 0.0
    return float(1.0 - sum_min / sum_max)

def mean_Jaccard_dist(G1, G2):
    """Return the average ``Jaccard_dist`` over paired CBTs (one pair per fold).

    Args:
        G1, G2: equal-length sequences. Element k of G1 is compared with
            element k of G2.

    Returns:
        The arithmetic mean of ``Jaccard_dist(G1[k], G2[k])`` over all k.

    Raises:
        ValueError: if the sequences have different lengths or are empty.
    """
    n_pairs = len(G1)
    if n_pairs != len(G2):
        raise ValueError(
            "G1 and G2 must contain the same number of CBTs (got {} and {})".format(
                n_pairs, len(G2)))
    if n_pairs == 0:
        raise ValueError("cannot average Jaccard distance over zero CBT pairs")

    # Divide by the actual number of pairs. The former hard-coded 5 gave a
    # wrong mean whenever the cross-validation did not use exactly 5 folds.
    total = sum(Jaccard_dist(g1, g2) for g1, g2 in zip(G1, G2))
    return total / n_pairs


# One method-vs-method distance table per population (hemisphere x sex).
H_singleview_LH_female = heat_map(
    CBT_SNF=CBT_SNF_LH_female,
    CBT_SMnetFusion=CBT_SMnetFusion_LH_female,
    CBT_NAGFS=CBT_NAGFS_LH_female,
)
H_singleview_RH_female = heat_map(
    CBT_SNF=CBT_SNF_RH_female,
    CBT_SMnetFusion=CBT_SMnetFusion_RH_female,
    CBT_NAGFS=CBT_NAGFS_RH_female,
)
H_singleview_LH_male = heat_map(
    CBT_SNF=CBT_SNF_LH_male,
    CBT_SMnetFusion=CBT_SMnetFusion_LH_male,
    CBT_NAGFS=CBT_NAGFS_LH_male,
)
H_singleview_RH_male = heat_map(
    CBT_SNF=CBT_SNF_RH_male,
    CBT_SMnetFusion=CBT_SMnetFusion_RH_male,
    CBT_NAGFS=CBT_NAGFS_RH_male,
)

# Apply the seaborn theme before any axes exist, so all four panels are
# created and drawn under the same settings (it was previously applied only
# after the subplots had been created).
sns.set(font_scale=1)
plt.rcParams['figure.dpi'] = 300

fig = plt.figure(figsize=(13, 3))
cmap = sns.cm.rocket_r
ax1 = fig.add_subplot(141)
ax2 = fig.add_subplot(142)
ax3 = fig.add_subplot(143)
ax4 = fig.add_subplot(144)

# (axes, distance table, panel title), left to right.
panels = [
    (ax1, H_singleview_LH_female, 'LH_female'),
    (ax2, H_singleview_LH_male, 'LH_male'),
    (ax3, H_singleview_RH_female, 'RH_female'),
    (ax4, H_singleview_RH_male, 'RH_male'),
]
for ax, H_singleview, title in panels:
    # seaborn only uses `fmt` when `annot=True`; without it the ".3f" was ignored.
    sns.heatmap(H_singleview, ax=ax, annot=True, fmt=".3f", vmin=0, cmap=cmap)
    ax.set_title(title, y=1.1, fontsize=10)
    # plt.yticks() acts on the current axes only (the last panel); keep the
    # method names horizontal on every panel instead.
    plt.setp(ax.get_yticklabels(), rotation=0)

#fig.suptitle('Jaccard distance between all pairs singleview fusion methods', y=1.1)
fig.tight_layout()