In [1]:
import os
import pickle

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
import torch

from sig.utils.graph_utils import plot_important_subgraph

In [2]:
# NOTE: Tweak the settings in this cell to tune how the figures look.

figsize = (12, 12)  # matplotlib figure size (inches) for the graph-mode plots

# Node colors are picked by node type: seaborn's "muted" palette followed by
# four extra flat-UI colors.
flatui = ["#9b59b6", "#3498db", "#95a5a6", "#e74c3c"]
node_cmap = sns.color_palette("muted").as_hex() + flatui
node_feat_cmap = 'Reds'  # matplotlib colormap name for the node-feature heatmap

# Edge coloring.
edge_importance_cmap = plt.cm.YlOrBr  # colormap for the edge-importance heatmap
important_edge_color = '#ff0000'      # important edges in the extracted-subgraph plot
unimportant_edge_color = '#d3d3d3'    # remaining edges in the extracted-subgraph plot

# Extra keyword arguments forwarded to nx.draw_networkx and to
# plot_important_subgraph.
draw_kwargs = {
    'width': 5,            # edge width
    'arrows': False,       # do not draw edge arrows
    'with_labels': False,  # do not draw node labels
    'node_size': 450,      # node marker size
}

In [3]:
# Example indices to render for each dataset.
dataset_example_lookup = dict(
    mutagenicity=[98, 409],
    reddit_binary=[106, 157],
)
# Plot modes that may be rendered for each dataset.
dataset_plot_mode_lookup = dict(
    mutagenicity=['graph', 'node_feat'],
    reddit_binary=['graph'],
)

In [4]:
# Value used for the hidden_<n> component of each dataset's output path.
model_size_lookup = {'mutagenicity': 16, 'reddit_binary': 32}

# One row per explainer: (model directory name, output subdirectory,
# supported plot modes). The three per-explainer lookups below are derived
# from this single table so they always share the same keys.
_explainer_specs = {
    'gnn': ('gcn', 'nonsig', ['graph', 'node_feat']),
    'gat': ('gat', 'nonsig', ['graph']),
    'pred_grad': ('gcn', 'nonsig', ['graph']),
    'mag_pred_grad': ('gcn', 'nonsig', ['node_feat']),
    'pred_sig_grad': ('sigcn', 'sig_small_reg', ['graph', 'node_feat']),
}
model_lookup = {name: spec[0] for name, spec in _explainer_specs.items()}
output_subdir_lookup = {name: spec[1] for name, spec in _explainer_specs.items()}
explainer_plot_mode_lookup = {name: spec[2] for name, spec in _explainer_specs.items()}
del _explainer_specs  # keep the notebook namespace limited to the lookups

# Names accepted for the explainer / dataset arguments, in iteration order.
allowed_explainers = ['gat', 'pred_grad', 'mag_pred_grad', 'pred_sig_grad', 'gnn']
allowed_datasets = ['mutagenicity', 'reddit_binary']

In [5]:
def _read_graph(path):
    """Load a networkx graph pickled to ``path``.

    ``nx.read_gpickle`` was removed in networkx 3.0; when it is missing, read
    the same pickle with the standard library instead. Unpickling can execute
    arbitrary code, so only use this on trusted, locally produced files.
    """
    read_gpickle = getattr(nx, 'read_gpickle', None)
    if read_gpickle is not None:
        return read_gpickle(path)
    with open(path, 'rb') as f:
        return pickle.load(f)


def _load_graph_info(path):
    """Load the ``_info.pt`` payload that accompanies a graph pickle.

    Tensors are mapped to CPU so payloads saved from GPU tensors also load on
    CPU-only machines. ``weights_only=False`` is passed explicitly because
    torch >= 2.6 defaults to ``weights_only=True``, which rejects non-tensor
    objects the payload may contain (e.g. the ``pos`` layout -- TODO confirm
    its type). Only use on trusted files.
    """
    try:
        return torch.load(path, map_location='cpu', weights_only=False)
    except TypeError:
        # torch < 1.13 does not accept the ``weights_only`` keyword.
        return torch.load(path, map_location='cpu')


def _source_and_output_prefixes(explainer, plot_mode, dataset, index, minsize):
    """Return ``(source_prefix, output_prefix)`` path stems for one example.

    Inputs are read from ``<source_prefix>.gpkl`` / ``<source_prefix>_info.pt``
    and figures are written as ``<output_prefix>_*.png``.
    """
    source_dir = 'output/real_graph/hidden_{}/{}/{}/{}/model_output_0/'.format(
        model_size_lookup[dataset],
        output_subdir_lookup[explainer],
        dataset,
        model_lookup[explainer]
    )
    output_dir = 'figs/{}/'.format(dataset)
    if plot_mode == 'graph':
        # Graph-mode artifacts carry a minsize tag in their directory/file names.
        source_dir += '{}_explainer_{}_files_minsize_{}/'.format(explainer, plot_mode, minsize)
        output_name = 'index_{}_{}_{}_minsize_{}'.format(index, explainer, plot_mode, minsize)
    else:
        source_dir += '{}_explainer_{}_files/'.format(explainer, plot_mode)
        output_name = 'index_{}_{}_{}'.format(index, explainer, plot_mode)
    return source_dir + 'index_{}'.format(index), output_dir + output_name


def _save_edge_heatmap(G, pos, node_color, edge_score, out_path):
    """Draw ``G`` with edges colored by ``edge_score`` and save it as a PNG."""
    fig = plt.figure(figsize=figsize)
    # nx.draw_networkx draws on the current axes, i.e. on ``fig``.
    nx.draw_networkx(
        G,
        pos=pos,
        node_color=node_color,
        edge_color=edge_score.detach().cpu().numpy(),
        edge_cmap=edge_importance_cmap,
        **draw_kwargs
    )
    fig.tight_layout()
    fig.savefig(out_path, format='PNG')
    plt.close(fig)


def _save_subgraph_plot(G, pos, node_color, edge_index, important_edge_mask, out_path):
    """Render the important subgraph with ``plot_important_subgraph`` and save it."""
    plt.figure(figsize=figsize)
    plot_important_subgraph(
        edge_index,
        important_edge_mask,
        node_color,
        G=G,
        pos=pos,
        important_edge_color=important_edge_color,
        unimportant_edge_color=unimportant_edge_color,
        **draw_kwargs
    )
    # plot_important_subgraph is external; save and close whichever figure is
    # current once it returns.
    plt.tight_layout()
    plt.savefig(out_path, format='PNG')
    plt.close()


def _save_node_feat_heatmap(node_feat_score, out_path):
    """Save a single-row heatmap of node-feature importance scores as a PNG.

    Assumes ``node_feat_score`` is a 1-D tensor with one score per feature
    (TODO confirm); each entry becomes one outlined cell of the row.
    """
    scores = node_feat_score.detach().cpu().numpy()
    fig, ax = plt.subplots()
    ax.imshow(scores[np.newaxis, :], cmap=node_feat_cmap, aspect='auto')
    # Minor ticks on the cell boundaries (-0.5, 0.5, ...) so the grid
    # outlines every cell.
    ax.set_xticks(np.arange(-.5, scores.shape[0], 1), minor=True)
    ax.grid(which='minor', color='black', linestyle='-', linewidth=2)
    # Remove the major ticks; the minor ticks still drive the grid lines.
    ax.set_xticks([])
    ax.set_yticks([])
    fig.tight_layout()
    fig.savefig(out_path, format='PNG')
    plt.close(fig)


def plot_fig(
    explainer,
    plot_mode,
    dataset,
    index,
    minsize=15
):
    """Render and save the explanation figure(s) for one example.

    Loads ``index_<index>.gpkl`` and ``index_<index>_info.pt`` from the
    explainer's output directory and writes PNGs under ``figs/<dataset>/``
    (created if missing).

    Args:
        explainer: One of ``allowed_explainers``; selects the model directory,
            the output subdirectory and the plot modes it supports.
        plot_mode: ``'graph'`` saves an edge-importance heatmap
            (``*_heatmap.png``) and an important-subgraph plot
            (``*_subgraph.png``); any other allowed mode (``'node_feat'``)
            saves a node-feature importance heatmap (``*_heatmap.png``).
        dataset: One of ``allowed_datasets``.
        index: Example index used in the input and output file names.
        minsize: Value of the ``minsize`` tag in graph-mode directory and file
            names; must match the directory that exists on disk.

    Raises:
        AssertionError: If ``explainer``, ``dataset`` or ``plot_mode`` is not
            allowed.
    """
    assert explainer in allowed_explainers, \
        'explainer has to be {}'.format(allowed_explainers)
    assert dataset in allowed_datasets, \
        'dataset has to be {}'.format(allowed_datasets)
    assert plot_mode in explainer_plot_mode_lookup[explainer], \
        'plot_mode={} is not available for {}'.format(plot_mode, explainer)

    source_prefix, output_prefix = _source_and_output_prefixes(
        explainer, plot_mode, dataset, index, minsize
    )
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(output_prefix), exist_ok=True)

    if plot_mode == 'graph':
        G = _read_graph(source_prefix + '.gpkl')
        graph_info = _load_graph_info(source_prefix + '_info.pt')
        # Each node_type entry indexes node_cmap to pick that node's color.
        node_color = [node_cmap[i] for i in graph_info['node_type']]
        pos = graph_info['pos']
        _save_edge_heatmap(
            G, pos, node_color, graph_info['edge_score'],
            output_prefix + '_heatmap.png'
        )
        _save_subgraph_plot(
            G, pos, node_color,
            graph_info['edge_index'], graph_info['important_edge_mask'],
            output_prefix + '_subgraph.png'
        )
    else:
        graph_info = _load_graph_info(source_prefix + '_info.pt')
        _save_node_feat_heatmap(
            graph_info['node_feat_score'], output_prefix + '_heatmap.png'
        )

In [6]:
# Render every (explainer, dataset, plot mode, example) combination that both
# the explainer and the dataset support. Modes are filtered in the explainer's
# declared order (not via a set intersection, whose iteration order over
# strings varies with hash randomization) so every run renders in the same order.
for explainer in allowed_explainers:
    for dataset in allowed_datasets:
        dataset_modes = dataset_plot_mode_lookup[dataset]
        plot_modes = [
            mode for mode in explainer_plot_mode_lookup[explainer]
            if mode in dataset_modes
        ]
        for plot_mode in plot_modes:
            for index in dataset_example_lookup[dataset]:
                plot_fig(explainer, plot_mode, dataset, index)