In [None]:
import pickle
import numpy as np
import pandas as pd
import networkx as nx
import seaborn as sns
from matplotlib import pyplot as plt
import gseapy as gp
import matplotlib_venn as mvenn

import utils_description as utils

In [None]:
# Global plotting theme applied to every figure drawn in this notebook.
sns.set(
    style="white",
    font='Arial',
    font_scale=1.0,
)

In [None]:
# Input locations, relative to the directory the notebook is run from.
infile_edges = 'results/edges_scores.csv'              # edge scores table
infile_clusters = 'results/results_clustering.csv'     # clustering results (the 'cluster_3' column is used)
infile_metadata = 'data/ADNI/ADNIMERGE_processed.csv'  # metadata table
infile_network = 'data/networks/PPI_SNAP_brain_False.edgelist'  # original network, edge-list format

In [None]:
# Edge scores are transposed right after loading; later cells select rows of
# this frame using the index of `data`.
edges_scores = pd.read_csv(infile_edges, index_col=0).T

# Only the 'cluster_3' column of the clustering results is used; it is renamed
# to 'cluster' so downstream code can refer to it generically.
cluster_table = pd.read_csv(infile_clusters, index_col=0)
clusters = cluster_table['cluster_3'].rename('cluster')

metadata = pd.read_csv(infile_metadata, index_col=0)

# Original network, read from an edge-list file.
network = nx.read_edgelist(infile_network)

In [None]:
# Column-wise inner join: keep only the index entries present in both the
# metadata and the cluster assignments.
frames = [metadata, clusters]
data = pd.concat(frames, axis=1, join='inner')

# Representative graphs of each cluster
Features of each representative graph defining each cluster.

### Obtain representative graphs

In [None]:
# Build the representative graph of each cluster from the edge scores, the
# cluster labels and the original network.
# NOTE(review): presumably this writes the `results/cluster{k}.edgelist` files
# that the cells below read back — confirm in utils_description.
utils.obtain_cluster_graphs_edges_scores(edges_scores, clusters, network)

### Compare the similarity between graphs
Against the original network

In [None]:
# Number of distinct cluster labels; cluster files are assumed to be numbered
# 0 .. n_clusters - 1.
n_clusters = len(data['cluster'].unique())

for k in range(n_clusters):
    # Representative graph of cluster k, read back from disk.
    cluster_network = nx.read_edgelist(f'results/cluster{k}.edgelist')

    # Unweighted and weighted Jaccard similarity against the original network.
    sim = utils.jaccard_similarity(network, cluster_network)
    sim_w = utils.weighted_jaccard(network, cluster_network)

    print(f'Cluster {k} vs. Original: ', sim, sim_w)
    

Between clusters

In [None]:
# Number of distinct cluster labels; cluster files are assumed to be numbered
# 0 .. n_clusters - 1.
n_clusters = len(data['cluster'].unique())

for i in range(n_clusters):
    cluster_network_i = nx.read_edgelist(f'results/cluster{i}.edgelist')

    # Start at i + 1 so that every unordered pair is compared exactly once and
    # a cluster is never compared with itself.
    for j in range(i + 1, n_clusters):
        cluster_network_j = nx.read_edgelist(f'results/cluster{j}.edgelist')

        sim = utils.jaccard_similarity(cluster_network_i, cluster_network_j)
        sim_w = utils.weighted_jaccard(cluster_network_i, cluster_network_j)

        print(f'Cluster {j} vs. Cluster {i}: ', sim, sim_w)

### Global graph metrics

In [None]:
# Add unit weights to the original network.
# NOTE: this mutates `network` in place, so every later cell (and any earlier
# cell that is re-run afterwards) sees the 'weight' attribute on its edges.
nx.set_edge_attributes(network, values=1, name='weight')
print('Original network')
utils.graph_metrics(network)
print()

# Derive the number of clusters from the data, as the similarity cells above
# do, instead of hard-coding [0, 1, 2].
for k in range(len(data['cluster'].unique())):

    net_cluster_file = f'results/cluster{k}.edgelist'
    cluster_network = nx.read_edgelist(net_cluster_file)
    print(f'Cluster {k} network')
    utils.graph_metrics(cluster_network)
    print()

# Significantly different edge scores between clusters

In [None]:
# Restrict (and order) the edge scores to the index of `data`, so both inputs
# of the statistical test cover the same rows.
sample_ids = data.index
edges_scores = edges_scores.loc[sample_ids]

statistics_edges = utils.get_significant_edges(edges_scores, data)

In [None]:
# Columns exported to the LaTeX table, in display order.
latex_columns = ['c0_mean', 'c1_mean', 'c2_mean', 'F', 'pvalue', 'significant', 'tukey']
print(statistics_edges[latex_columns].to_latex())

### Visualization of the significant edge scores

In [None]:
# Edge columns listed in the statistics table, plus the cluster label.
edge_columns = statistics_edges.index.to_list()
columns = edge_columns + ['cluster']

# One row per sample, sorted so that samples of the same cluster are adjacent.
heatmap_data = (
    pd.concat([edges_scores, clusters], axis=1, join='inner')[columns]
    .sort_values(by='cluster')
)

fig, ax = plt.subplots(figsize=(10, 10))

# Layer 1 - edge scores: the cluster row is blanked out (NaN) so that only the
# edges are coloured by the continuous colormap.
edges_layer = heatmap_data.assign(cluster=float('nan'))
sns.heatmap(edges_layer.T, cmap="rocket", ax=ax)

# Layer 2 - cluster labels: the edge rows are blanked out instead, and the
# remaining cluster row is drawn on the same axes with a 3-colour palette.
clusters_layer = heatmap_data.copy()
clusters_layer[edge_columns] = float('nan')
print(clusters_layer['cluster'].value_counts())
my_cmap = sns.hls_palette(n_colors=3)
sns.heatmap(clusters_layer.T, cmap=my_cmap, cbar=False, ax=ax)

# Hide the x-axis ticks and tick labels.
plt.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
plt.tight_layout()
# plt.savefig('figures/heatmap_edges.png', dpi=500)
# plt.show()

### Pathway Enrichment Analysis
To show the biological differences between clusters, we build, for each cluster, a list of the nodes (genes) that belong to edges that (i) are significantly different between two clusters and (ii) have the lowest mean in that cluster. That is, edges that are significantly "worse" (more affected) in that cluster than in the others.

In [None]:
# One gene set per cluster plus one for the original network, in that order.
gene_sets = utils.make_gene_list(statistics_edges, network)
c0_set, c1_set, c2_set, original_set = gene_sets

In [None]:
# Gene-set library handed to the enrichment helper.
source = 'Reactome_2022'
gene_set_libraries = [source]

pea = utils.enrichment_analysis(
    gene_set_libraries,
    c0_set,
    c1_set,
    c2_set,
    original_set,
    statistics_edges.index,
)

# Round the adjusted p-values to 4 decimals for reporting.
adjusted_pvalues = pea['Adjusted P-value']
pea['Adjusted P-value'] = adjusted_pvalues.round(4)

In [None]:
# print(pea[['Term', 'Adjusted P-value', 'Genes', 'Edges', 'cluster']].to_latex(index=None))

In [None]:
# Enriched terms found for each cluster.
c0_terms = set(pea.loc[pea['cluster'] == 'Cluster 0']['Term'].values.tolist())
c1_terms = set(pea.loc[pea['cluster'] == 'Cluster 1']['Term'].values.tolist())
c2_terms = set(pea.loc[pea['cluster'] == 'Cluster 2']['Term'].values.tolist())

plt.figure(figsize=(15, 15))
venn = mvenn.venn3([c0_terms, c1_terms, c2_terms], ('Cluster 1', 'Cluster 2', 'Cluster 3'))

# Region id -> terms written inside that region. The id is a 3-character
# membership mask over (c0, c1, c2). Regions '110' and '011' are intentionally
# not listed here and keep their default labels, as in the original figure.
region_terms = {
    '100': c0_terms - c1_terms - c2_terms,
    '010': c1_terms - c2_terms - c0_terms,
    '001': c2_terms - c1_terms - c0_terms,
    '101': (c0_terms & c2_terms) - c1_terms,
    '111': c0_terms & c2_terms & c1_terms,
}

for region_id, terms in region_terms.items():
    label = venn.get_label_by_id(region_id)
    # matplotlib-venn returns None for a region that is empty, so guard the
    # call instead of crashing with AttributeError.
    if label is not None:
        # Sort the terms: set iteration order is not stable between runs, and
        # the figure should be reproducible.
        label.set_text('\n'.join(sorted(terms)))

plt.tight_layout()
# plt.savefig('figures/venn3_reactome.png', dpi=500)
plt.show()

### Variants analysis

In [None]:
variants = pd.read_csv('results/processed_variants_ADNI_WGS.csv', index_col=0)

# Gene symbols of the significant nodes, one per line. The context manager
# guarantees the file is closed even if reading fails, and blank lines (e.g.
# the one produced by a trailing newline) are dropped so that no empty gene
# name ends up in the list.
with open('results/nodes_significant.txt', 'r') as nodes_file:
    genes = [gene for gene in nodes_file.read().split('\n') if gene]

# Keep only the variants whose SYMBOL is one of those genes, order them by
# gene symbol, then drop the symbol column and transpose the result.
variants_sel = (
    variants.loc[variants['SYMBOL'].isin(genes)]
    .sort_values(by='SYMBOL', ascending=True)
    .drop(columns='SYMBOL')
    .T
)

In [None]:
import sklearn.metrics as metrics
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, StratifiedKFold
from imblearn.ensemble import BalancedRandomForestClassifier
import numpy as np
# import shap

In [None]:
# One-vs-all classification of each cluster from the selected variants, using
# a balanced random forest. For every task this cell draws and saves a
# confusion matrix, plots the 10 most important features and prints their ids.

# NOTE: `data_variants` is only used by the commented-out alternative dataset
# below; it is kept so that alternative can be re-enabled.
data_variants = variants.drop(columns='SYMBOL')
data_variants = data_variants.drop_duplicates(keep=False)
# dataset = pd.concat([data_variants.T, clusters], axis=1, join='inner')
dataset = pd.concat([variants_sel, clusters], axis=1, join='inner')

y = dataset['cluster']
x = dataset.drop(columns=['cluster'])

# Gene symbol(s) of every variant id, joined into one string per variant.
# This does not depend on the task, so it is computed once outside the loop.
data_symbols = variants.groupby(variants.index)['SYMBOL'].apply(lambda s: ',\n'.join(s.astype(str)))

# task name -> (positive cluster label, figure title, x-axis limit of the
# feature-importance plot).
task_settings = {
    'C1_vs_All': (0, '(a) Cluster 1 vs. All\n', 0.15),
    'C2_vs_All': (1, '(b) Cluster 2 vs. All\n', 0.10),
    'C3_vs_All': (2, '(c) Cluster 3 vs. All\n', 0.26),
}

for task, (positive_cluster, title, lim) in task_settings.items():

    # Binarise the target: the cluster of interest against the other two.
    y_tmp = y.replace({
        label: 'Positive' if label == positive_cluster else 'Negative'
        for label in (0, 1, 2)
    })

    x_train, x_test, y_train, y_test = train_test_split(x, y_tmp, stratify=y_tmp, test_size=0.2, random_state=42)

    model = BalancedRandomForestClassifier(random_state=42)
    model.fit(x_train, y_train)
    y_pred = model.predict(x_test)

    print(task)
    print()

    # scikit-learn's confusion matrix has the true labels on the rows and the
    # predicted labels on the columns, so in the heatmap the y axis is "True"
    # and the x axis is "Predicted" (the original labels were swapped).
    cm = pd.DataFrame(metrics.confusion_matrix(y_test, y_pred))
    plt.figure(figsize=(3.5, 2.5))
    ax = sns.heatmap(cm, annot=True, cmap='crest', cbar=False, fmt='g')
    ax.tick_params(tick2On=False, labelsize=False)
    # plt.suptitle(title)
    plt.title('Confusion matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.tight_layout()
    plt.savefig(f'figures/conf_matrix_{task}.png', dpi=300)

    # Feature importances (column 0), labelled as "<variant id> (<gene symbols>)".
    importances = model.feature_importances_
    forest_importances = pd.DataFrame(importances, index=x.columns)
    forest_importances = pd.concat([forest_importances, data_symbols], axis=1, join='inner')
    forest_importances.reset_index(inplace=True)
    forest_importances['variant'] = forest_importances['index'] + ' (' + forest_importances['SYMBOL'] + ')'

    # The 10 most important features, shared by the plot and the printout.
    top_importances = forest_importances.sort_values(by=0, ascending=False).head(10)

    plt.figure(figsize=(4, 6))
    sns.barplot(top_importances, y='variant', x=0, palette='viridis')
    plt.title('Feature importances')
    plt.ylabel('')
    plt.xlim(0, lim)
    plt.tight_layout()
    # plt.savefig(f'figures/importances_{task}.png', dpi=300)

    for v in top_importances['index']:
        print(v)
    print()