# In [None]:
"""
@author: nadachaari.nc@gmail.com
"""

# This code generates the average Kullback-Leibler divergence of centrality metrics (KL_CM) across 5-fold cross-validation
# between all pairs of the multi-view fusion methods (SCA, netNorm, cMGINet, MVCF-Net and DGN).

import pickle
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Load the data of one of 4 populations (selected below): female LH, male LH, female RH, male RH.
# LH means left hemisphere.
# RH means right hemisphere.
# GSP is the Brain Genomics Superstruct Project dataset, which consists of healthy female and male populations.

# Population selection used to build the result file names.
#   args_dataset1 -- dataset/hemisphere tag ('LH_GSP' here; 'RH_GSP' presumably
#                    selects the right hemisphere -- confirm)
#   args_dataset2 -- population ('male' here; 'female' is the alternative)
args_dataset1 = 'LH_GSP'
args_dataset2 = 'male'


def _load_kl(method):
    """Unpickle and return the saved KL results of one fusion method.

    The file read is 'KL_<args_dataset2>_<args_dataset1>_<method>' in the
    current working directory.
    NOTE: pickle.load can execute arbitrary code; only load trusted files.
    """
    file_name = 'KL_' + args_dataset2 + '_' + args_dataset1 + '_' + method
    with open(file_name, 'rb') as handle:
        return pickle.load(handle)


# Loaded in the same order as the original script.
KL_cMGINet = _load_kl('cMGINet')
KL_SCA = _load_kl('SCA')
KL_DGN = _load_kl('DGN')
KL_MVCFNet = _load_kl('MVCF-Net')
KL_netNorm = _load_kl('netNorm')

# Indices of the centrality metrics inside each per-method KL result, as
# documented by the original author:
#   0  -- betweenness centrality
#   1  -- degree centrality
#   3  -- eigenvector centrality
#   4  -- PageRank
#   6  -- information centrality
#   7  -- random-walk centrality
#   9  -- Katz centrality
#   10 -- Laplacian centrality
# NOTE(review): index 5 was also listed as a valid choice, but its metric is
# not named anywhere in this file -- confirm before using it.
i = 7  # metric shown in H_CM_multiview (7 = random-walk centrality)

# Row/column labels of every heatmap (one per multi-view fusion method).
columns = ['SCA', 'netNorm', 'MVCFNet', 'cMGINet', 'DGN']
index = ['SCA', 'netNorm', 'MVCFNet', 'cMGINet', 'DGN']


# Kullback-Leibler divergence function.
# Defined before its first use: previously it sat at the end of the cell, so
# the pairwise computation below raised NameError.
def KL(P, Q, epsilon=0.000001):
    """Return the symmetrised Kullback-Leibler divergence between P and Q.

    Computes (KL(P || Q) + KL(Q || P)) / 2, summed over all entries, after
    adding ``epsilon`` to every entry of both inputs. The epsilon avoids
    conditional code for zero entries (log(0) and division by zero). The
    inputs are NOT normalised to sum to 1; they are used as given.

    Parameters
    ----------
    P, Q : array_like
        Values of the same shape. After adding ``epsilon`` every entry must
        be positive for the logarithms to be finite.
    epsilon : float, optional
        Smoothing constant added to every entry (default 1e-6).

    Returns
    -------
    numpy scalar
        The symmetrised divergence.
    """
    # np.asarray lets plain lists be passed as well as NumPy arrays.
    P = np.asarray(P) + epsilon
    Q = np.asarray(Q) + epsilon
    return (np.sum(P * np.log(P / Q)) + np.sum(Q * np.log(Q / P))) / 2


def _kl_cm_per_method(metric_index):
    """Return each fusion method's stored values for one centrality metric.

    The list follows the order of ``columns``/``index``. The KL_* results are
    presumably the per-fold values saved by the upstream pipeline -- confirm.
    """
    return [
        KL_SCA[metric_index],
        KL_netNorm[metric_index],
        KL_MVCFNet[metric_index],
        KL_cMGINet[metric_index],
        KL_DGN[metric_index],
    ]


def _pairwise_kl_heatmap(per_method):
    """Return a float DataFrame whose (r, c) cell is KL(per_method[r], per_method[c])."""
    arrays = [np.array(values) for values in per_method]
    matrix = [[KL(row, col) for col in arrays] for row in arrays]
    return pd.DataFrame(matrix, index=index, columns=columns).astype(float)


# Pairwise symmetrised KL between all fusion methods for the selected metric
# ``i``. The helpers use their own loop variables, so ``i`` keeps its value
# (the old nested loops overwrote it).
KL_CM_multiview = _kl_cm_per_method(i)
H_CM_multiview = _pairwise_kl_heatmap(KL_CM_multiview)

# One heatmap per plotted centrality metric (indices from the table above).
# These were previously referenced but never computed (NameError).
H_bc_multiview = _pairwise_kl_heatmap(_kl_cm_per_method(0))     # betweenness
H_dc_multiview = _pairwise_kl_heatmap(_kl_cm_per_method(1))     # degree
H_ec_multiview = _pairwise_kl_heatmap(_kl_cm_per_method(3))     # eigenvector
H_pr_multiview = _pairwise_kl_heatmap(_kl_cm_per_method(4))     # PageRank
H_ic_multiview = _pairwise_kl_heatmap(_kl_cm_per_method(6))     # information
H_rwbc_multiview = _pairwise_kl_heatmap(_kl_cm_per_method(7))   # random-walk
H_kc_multiview = _pairwise_kl_heatmap(_kl_cm_per_method(9))     # Katz
H_sc_multiview = _pairwise_kl_heatmap(_kl_cm_per_method(10))    # Laplacian

# Apply the theme and resolution BEFORE creating any figure. Previously they
# were set after the first figure's axes existed, so the first figure missed
# both settings and the two figures rendered differently.
sns.set(font_scale=1)
plt.rcParams['figure.dpi'] = 300
cmap = sns.cm.rocket_r


def _plot_heatmap_row(heatmaps, titles):
    """Draw four heatmaps side by side in a new figure.

    Returns the figure and its list of four axes.
    """
    figure = plt.figure(figsize=(13, 3))
    axes = [figure.add_subplot(1, 4, position) for position in range(1, 5)]
    for axis, heatmap, title in zip(axes, heatmaps, titles):
        # NOTE: ``fmt`` only formats cell annotations, which seaborn draws
        # only when ``annot=True``; kept unchanged from the original call.
        sns.heatmap(heatmap, ax=axis, fmt=".3f", vmin=0, cmap=cmap)
        axis.set_title(title, y=1.1, fontsize=10)
    plt.tight_layout()
    return figure, axes


# Grouped heatmaps: betweenness, degree, Eigenvector centrality and PageRank.
fig, (ax1, ax2, ax3, ax4) = _plot_heatmap_row(
    [H_bc_multiview, H_dc_multiview, H_ec_multiview, H_pr_multiview],
    ['betweenness centrality', 'degree centrality',
     'Eigenvector centrality', 'PageRank'],
)

# Grouped heatmaps: Katz, Information, Random-walk and Laplacian centrality.
fig, (ax1, ax2, ax3, ax4) = _plot_heatmap_row(
    [H_kc_multiview, H_ic_multiview, H_rwbc_multiview, H_sc_multiview],
    ['Katz centrality', 'Information centrality',
     'Random-walk centrality', 'Laplacian centrality'],
)