# Interpretation of the predictive models

In [None]:
import os
import pickle
import pandas as pd

import torch
import torch.nn.functional as F

In [None]:
import numpy as np

import matplotlib
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

In [None]:
# Display names for the disease identifiers that appear as keys in the loaded dictionaries.
# NOTE(review): the identifiers look like MeSH-style IDs - confirm against the dataset.
disease_id2name = {
    'D006262': 'Healthy',
    'D012141': 'RTIs',
    'D016585': 'BV',
    'D003550': 'Cystic Fibrosis',
    'D029424': 'COPD',
    'D047928': 'Premature Birth',
    'D001249': 'Asthma',
    'D045169': 'SARS',
    'D016889': 'Endometrial Neoplasms',
    'D019449': 'Pouchitis',
    'D012136': 'RSV',
    'D010034': 'OME',
    'C562730': 'ADE of Esophagus',
    'D014627': 'Vaginitis',
    'D008175': 'Lung Neoplasms',
    'D014777': 'Virus Diseases',
    'D020345': 'NEC',
    'D010300': 'Parkinson Disease',
    'D011014': 'Pneumonia',
    'D003424': 'Crohn Disease',
    'D002692': 'Chlamydia',
    'D043183': 'IBS',
    'D011565': 'Psoriasis',
    'D014890': 'GPA',
    'C566241': 'ASD II',
    'D012507': 'Sarcoidosis',
}

# Readable tick labels for selected hidden-node names; the violin plots show
# any name that is not listed here unchanged.
modified_hidden_nodes_name = {
    '24': 'C24',
    'CellobioseConsumptionimport': 'Cellobiose Consumption',
    '212': 'C212',
    'TrimethylamineN-oxideTrimethylamine-N-oxideProductionexport': 'Trimethylamine-N-oxide production',
    '152': 'C152',
    'ButanolConsumptionimport': 'Butanol consumption',
}

# Number of explanation runs; the loading cells read explanation_sample_0 .. explanation_sample_{run_num - 1}.
run_num = 20

## 1. Heatmap: meta-data -> disease

In [None]:
def plot_att_heatmap(xlabel, ylabel, att, cmap='RdBu', figsize=(8,8), dpi=300, title=None, save_path=None, show=False): 
    """Draw an annotated heatmap of ``att`` and optionally save and/or show it.

    Parameters
    ----------
    xlabel : sequence of str
        Tick labels for the columns of ``att`` (one per column).
    ylabel : sequence of str
        Tick labels for the rows of ``att`` (one per row).
    att : 2-D numpy array
        Values to draw; ``att[i, j]`` is printed in the cell at row ``i``, column ``j``.
    cmap : str
        Matplotlib colormap name. With ``'RdBu'`` the colour range is symmetric
        around zero (``-max|att| .. +max|att|``); with any other colormap it is
        ``0 .. max|att|``.
    figsize : tuple
        Figure size passed to ``plt.subplots``.
    dpi : int
        Resolution used when the figure is written to ``save_path``.
    title : str or None
        Axes title; no title is set when ``None``.
    save_path : str or None
        If given, the figure is saved there and a confirmation line is printed.
    show : bool
        If true, ``plt.show()`` is called before the figure is closed.

    The figure is always closed before returning; nothing is returned.
    """
    fig, ax = plt.subplots(figsize=figsize)

    # NaN-aware maximum: a single NaN entry would otherwise turn both colour
    # limits into NaN and break the colour scale of the whole heatmap.
    limit = np.nanmax(np.abs(att))
    lower = -limit if cmap == 'RdBu' else 0
    ax.imshow(att, cmap=cmap, vmin=lower, vmax=limit)

    # Show all ticks and label them with the respective list entries
    ax.set_xticks(np.arange(len(xlabel)))
    ax.set_yticks(np.arange(len(ylabel)))
    ax.set_xticklabels(xlabel)
    ax.set_yticklabels(ylabel)

    # Rotate the x tick labels so long names do not overlap
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Write the value of every cell on top of it
    for i in range(len(ylabel)): 
        for j in range(len(xlabel)): 
            ax.text(j, i, '{:.2f}'.format(att[i, j]),
                    ha="center", va="center", color="k", fontsize=5)

    if title is not None:
        ax.set_title(title)
    fig.tight_layout()
    if save_path:
        # Save through the figure handle rather than the implicit "current figure"
        fig.savefig(save_path, dpi=dpi)
        print('Save the figure into {}'.format(save_path))
    if show:
        plt.show()
    # Close exactly this figure so repeated calls in a loop do not accumulate open figures
    plt.close(fig)
    
def stack_sth(x, sth): 
    """Reduce every entry of ``x`` with a NaN-ignoring mean and stack the results.

    ``x`` is a mapping whose keys are ``'{sth}_1'`` ... ``'{sth}_N'`` with
    ``N = len(x)``; each value is a tensor. Entries are visited in numeric
    order of their suffix (not in the mapping's own order), reduced with
    ``torch.nanmean`` over dim 0, and stacked along a new leading dimension.

    Returns a tensor whose first dimension has length ``N``.
    Raises ``KeyError`` if one of the expected keys is missing.
    """
    ordered_keys = ['{}_{}'.format(sth, str(idx)) for idx in range(1, len(x.keys()) + 1)]
    # Look every key up first, so a missing key fails before any reduction runs
    ordered_values = [x[key] for key in ordered_keys]
    reduced = [torch.nanmean(value, dim=0) for value in ordered_values]
    return torch.stack(reduced, dim=0)

In [None]:
# Load dicts
# NOTE(review): pickle.load can execute arbitrary code - only open files produced by this project.
with open('./output/dict_explanation.pkl', 'rb') as file: 
    loaded_dicts = pickle.load(file)
print('Loaded embedding dictionaries')

# Set to True to additionally save one heatmap per run and meta-data type
# (this replaces the plotting code that used to be commented out here).
plot_each_run = False

# The disease labels are identical for every run, so build them once instead of per iteration.
disease_labels = [disease_id2name[k] for k in loaded_dicts['disease'].keys() if isinstance(k, str) and k != 'nan']

# meta-data type -> list with one attribution array per run
allrun_atts = {'age': [], 'gender': [], 'bmi': [], 'bodysite': []}
for exp_index in range(run_num): 
    # Load explanation results
    with open('./output/Explain/explanation_sample_{}.pkl'.format(exp_index), 'rb') as file:
        loaded_explaining_data = pickle.load(file)
    print('Loaded exp_{}'.format(exp_index))
    
    for metadata in ['age', 'gender', 'bmi', 'bodysite']: 
        meta_disease = stack_sth(loaded_explaining_data['{}_disease_explanations'.format(metadata)], sth='disease')
        # detach()/cpu() keep the conversion working when the pickled tensors
        # track gradients or live on a GPU; np.array(tensor) raises in both cases.
        att = meta_disease.detach().cpu().numpy()
        
        if plot_each_run: 
            if metadata == 'gender': 
                meta_labels = ['gender']
            else: 
                meta_labels = [k for k in loaded_dicts[metadata].keys() if isinstance(k, str) and k != 'nan']
            plot_att_heatmap(meta_labels, disease_labels, att, title='Run {}'.format(exp_index), 
                             save_path='./output/Figures/{}_disease_exp_{}.png'.format(metadata, exp_index))
            
        allrun_atts[metadata].append(att)

In [None]:
# Average on runs
allcls_atts, allcls_meta_labels = [], []
for metadata, att in allrun_atts.items():
    # Stack the per-run arrays along a new leading axis and average that axis away
    att = np.stack(att, axis=0).mean(axis=0)
    meta_labels = [k for k in loaded_dicts[metadata].keys() if isinstance(k, str) and k != 'nan']
    disease_labels = [disease_id2name[k] for k in loaded_dicts['disease'].keys() if isinstance(k, str) and k != 'nan']

    # 'gender' is drawn with a single column label; all other meta-data types
    # use the string keys of their dictionary as column labels.
    column_labels = ['gender'] if metadata == 'gender' else meta_labels
    plot_att_heatmap(column_labels, disease_labels, att, 
                    save_path='./output/Figures/{}_disease.pdf'.format(metadata))

    # Mean absolute attribution along axis 1: one value per row of the heatmap
    allcls_atts.append(np.abs(att).mean(axis=1)) # one meta-data -> disease
    allcls_meta_labels.append(metadata)

In [None]:
# Put the per-meta-data importance vectors (mean absolute attribution per disease,
# collected in allcls_atts) side by side: one column per meta-data type.
# The stacked array gets its own name: rebinding allcls_atts here would make a
# second run of this cell stack the already stacked array and silently transpose it.
metadata_disease_atts = np.stack(allcls_atts, axis=1)
disease_labels = [disease_id2name[k] for k in loaded_dicts['disease'].keys() if isinstance(k, str) and k != 'nan']

# Transposed for plotting: rows are meta-data types, columns are diseases
plot_att_heatmap(disease_labels, allcls_meta_labels, np.transpose(metadata_disease_atts), cmap='OrRd', 
                    save_path='./output/Figures/metadata_disease.pdf')

## 2. Violinplot: hidden layers -> meta-data

In [None]:
# Load hidden layer meanings
edge_list_path = './Dataset/genus/EdgeList.csv'

# Entries of the 'parent' column that are meta-data rather than hidden nodes
metadatas = ['BMI', 'gender', 'age', 'bodysite', 'phenotype']
edge_df = pd.read_csv(edge_list_path)

# Unique values of the 'parent' column, without the meta-data entries, in sorted order
parent_nodes = sorted(node for node in set(edge_df['parent'].tolist()) if node not in metadatas)
print('parent_nodes num:', len(parent_nodes))

# Position in the sorted list -> node name
hidden_nodes_id2name = dict(enumerate(parent_nodes))

In [None]:
# Load groups of hidden layer
community_nodes = pd.read_csv('./Dataset/genus/communityNodes.csv')
metabolite_nodes = pd.read_csv('./Dataset/genus/metaboliteNodes.csv')
taxonomy_nodes = pd.read_csv('./Dataset/genus/taxonomyNodes.csv')

# One lookup per group file: node name -> group label.
# Only the community node names are converted to str; the other two are used as read.
dict1 = dict.fromkeys((str(k) for k in community_nodes['nodes'].tolist()), 'community')
dict2 = dict.fromkeys(metabolite_nodes['nodes'].tolist(), 'metabolite')
dict3 = dict.fromkeys(taxonomy_nodes['nodes'].tolist(), 'taxon (genus)')

# Merge in the order community, metabolite, taxon; on a name clash the later group wins
hidden_nodes_name2group = dict(dict1)
hidden_nodes_name2group.update(dict2)
hidden_nodes_name2group.update(dict3)

In [None]:
# Load dicts
# NOTE(review): pickle.load can execute arbitrary code - only open files produced by this project.
with open('./output/dict_explanation.pkl', 'rb') as file: 
    loaded_dicts = pickle.load(file)
print('Loaded embedding dictionaries')

# meta-data type -> list with one absolute-attribution array per run
allrun_atts = {'age': [], 'gender': [], 'bmi': [], 'bodysite': []}
for exp_index in range(run_num): 
    # Load explanation results
    with open('./output/Explain/explanation_sample_{}.pkl'.format(exp_index), 'rb') as file:
        loaded_explaining_data = pickle.load(file)
    print('Loaded exp_{}'.format(exp_index))
    
    for metadata in ['age', 'gender', 'bmi', 'bodysite']: 
        hidden_meta = stack_sth(loaded_explaining_data['hidden_{}_explanations'.format(metadata)], sth=metadata)
        # detach()/cpu() keep the conversion working when the pickled tensors
        # track gradients or live on a GPU; np.array(tensor) raises in both cases.
        att = hidden_meta.detach().cpu().numpy()
        att = np.abs(att) # the sign (positive or negative impact) is not considered here
        
        allrun_atts[metadata].append(att)

In [None]:
# Average over the stacked class axis; the run axis is kept.
# After np.stack, axis 0 indexes the runs and axis 1 indexes the entries that
# stack_sth stacked (one per '{metadata}_{i}' key), so only axis 1 is averaged here.
allcls_atts = {}
for metadata in allrun_atts:
    att = np.stack(allrun_atts[metadata], axis=0).mean(axis=1)
    allcls_atts[metadata] = att # hidden nodes -> one meta-data
    print(metadata, att.shape)

### top-15 among all groups

In [None]:
# settings
dpi = 300
figsize = (6,6)
group_colors = {'community': 'tab:green', 'metabolite': 'tab:orange', 'taxon (genus)': 'tab:blue'}
k_num = 15  # number of hidden nodes shown per figure

# create patches for legend: one coloured patch per hidden-node group
patches = [mpatches.Patch(color=v, label=k, alpha=0.6) for k, v in group_colors.items()]


for metadata in ['age', 'gender', 'bmi', 'bodysite']: 
    save_path = './output/Figures/hidden_{}.pdf'.format(metadata)

    # Mean over axis 0 of the per-meta-data array: one average importance per hidden node
    avg_att = np.mean(allcls_atts[metadata], axis=0)
    
    # Indices of the k_num hidden nodes with the largest average importance, largest first
    top_k_indices = np.argsort(avg_att)[::-1][:k_num] 
    top_k_names = [hidden_nodes_id2name[int(i)] for i in top_k_indices]
    print(top_k_names)
    top_k_groups = [hidden_nodes_name2group[name] for name in top_k_names]
    top_k_colors = [group_colors[g] for g in top_k_groups]
    # Keep all rows (axis 0) for the selected nodes: each violin shows their spread
    top_k_att = allcls_atts[metadata][:, top_k_indices] 

    fig, ax = plt.subplots(figsize=figsize)
    plots = ax.violinplot([top_k_att[:, i] for i in range(top_k_att.shape[1])],
                          showmeans=True, showmedians=True, vert=False)
    
    # Set the color of the violin patches
    for pc, color in zip(plots['bodies'], top_k_colors): 
        pc.set_facecolor(color)

    # Set the color of the lines; the mean markers stay black so they stand out
    plots['cmedians'].set_colors(top_k_colors)
    plots['cbars'].set_colors(top_k_colors)
    plots['cmaxes'].set_colors(top_k_colors)
    plots['cmins'].set_colors(top_k_colors)
    plots['cmeans'].set_colors('k')
    
    # One tick per violin actually drawn. Using len(top_k_names) instead of k_num
    # keeps ticks and labels in step when fewer than k_num hidden nodes exist.
    ax.set_yticks(list(range(1, len(top_k_names) + 1)))
    ax.set_yticklabels([modified_hidden_nodes_name.get(n, n) for n in top_k_names], rotation=0)
    ax.set_xlabel('importance')
    # ax.set_title('Top-k contributing hidden nodes for {}'.format(metadata))
    ax.legend(handles=patches, loc='upper right')
    
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
    # plt.show()
    # Close exactly this figure so the loop does not accumulate open figures
    plt.close(fig)
    print('Save the figure into {}'.format(save_path))