In [None]:
import argparse
import models
import networkx as nx
import pgm_explainer as pe
import pylab as plt
import random
import utils
    
def arg_parse():
    """Build the command-line parser used by the explainer notebook.

    The parser is returned un-parsed so the caller can feed it either
    ``sys.argv`` or an explicit argument list.

    Returns:
        argparse.ArgumentParser: parser with every option registered and the
        project-wide default values applied via ``set_defaults``.
    """
    parser = argparse.ArgumentParser(description="Explainer arguments.")

    # (option string, keyword arguments) pairs, registered in this exact order
    # so the generated --help output keeps its layout.
    option_specs = [
        ("--bmname", dict(dest="bmname", help="Name of the benchmark dataset")),
        ("--dataset", dict(dest="dataset", help="Input dataset.")),
        ("--ckptdir", dict(dest="ckptdir", help="Model checkpoint directory")),
        (
            "--gpu",
            dict(
                dest="gpu",
                action="store_const",
                const=True,
                default=False,
                help="whether to use GPU.",
            ),
        ),
        (
            "--node-start",
            dict(dest="node_start", type=int, help="Index of starting node."),
        ),
        (
            "--node-end",
            dict(dest="node_end", type=int, help="Index of ending node."),
        ),
        (
            "--num-perturb-samples",
            dict(
                dest="num_perturb_samples",
                type=int,
                help="Number of perturbed sample using to generate explanations.",
            ),
        ),
        (
            "--top-node",
            dict(dest="top_node", type=int, help="Number of nodes in explanation."),
        ),
        (
            "--epochs",
            dict(dest="num_epochs", type=int, help="Number of epochs to train."),
        ),
        (
            "--hidden-dim",
            dict(dest="hidden_dim", type=int, help="Hidden dimension"),
        ),
        (
            "--output-dim",
            dict(dest="output_dim", type=int, help="Output dimension"),
        ),
        (
            "--num-gc-layers",
            dict(
                dest="num_gc_layers",
                type=int,
                help="Number of graph convolution layers before each pooling",
            ),
        ),
        (
            "--bn",
            dict(
                dest="bn",
                action="store_const",
                const=True,
                default=False,
                help="Whether batch normalization is used",
            ),
        ),
        ("--dropout", dict(dest="dropout", type=float, help="Dropout rate.")),
        (
            "--method",
            dict(
                dest="method",
                type=str,
                help="Method. Possible values: base, att.",
            ),
        ),
        (
            # Passing --nobias stores False into `bias`; otherwise it stays True.
            "--nobias",
            dict(
                dest="bias",
                action="store_const",
                const=False,
                default=True,
                help="Whether to add bias. Default to True.",
            ),
        ),
    ]
    for option_string, option_kwargs in option_specs:
        parser.add_argument(option_string, **option_kwargs)

    # Parser-level defaults. Several of these (opt, opt_scheduler, lr, clip,
    # batch_size) have no matching command-line option and are only reachable
    # as attributes of the parsed namespace.
    fallback_values = {
        "ckptdir": None,
        "dataset": "syn1",
        "opt": "adam",
        "opt_scheduler": "none",
        "lr": 0.1,
        "clip": 2.0,
        "batch_size": 20,
        "num_epochs": 100,
        "hidden_dim": 20,
        "output_dim": 20,
        "num_gc_layers": 3,
        "method": "base",
        "dropout": 0.0,
        "node_start": None,
        "node_end": None,
        "num_perturb_samples": 100,
        "top_node": None,
    }
    parser.set_defaults(**fallback_values)

    return parser

In [None]:
# Build the parser, then parse an explicit argument list instead of sys.argv
# (inside a notebook sys.argv belongs to the kernel, not to this analysis).
prog_args = arg_parse()
notebook_argv = [
    '--dataset', 'syn6',
    '--num-perturb-samples', '1000',
    '--top-node', '7',
]
args = prog_args.parse_args(notebook_argv)

In [None]:
# Load data and restore the trained model
# Directory handed to both loaders as `datadir`.
XAL_DIR = "../Generate_XA_Data/XAL"

# A, X and the labels L for the dataset selected in `args`.
A, X = utils.load_XA(args.dataset, datadir=XAL_DIR)
L = utils.load_labels(args.dataset, datadir=XAL_DIR)

# Sizes derived from the loaded data: labels are assumed to be 0-based
# integers, so the class count is the largest label plus one.
num_classes = max(L) + 1
input_dim = X.shape[1]
num_nodes = X.shape[0]

# Checkpoint lookup is driven by `args`; it is indexed below with the
# "model_state" and "save_data" keys.
ckpt = utils.load_ckpt(args)

print("input dim: ", input_dim, "; num classes: ", num_classes)

# Re-create the network with the same hyper-parameters held in `args`,
# then load the stored weights into it.
encoder_kwargs = dict(
    input_dim=input_dim,
    hidden_dim=args.hidden_dim,
    embedding_dim=args.output_dim,
    label_dim=num_classes,
    num_layers=args.num_gc_layers,
    bn=args.bn,
    args=args,
)
model = models.GcnEncoderNode(**encoder_kwargs)
model.load_state_dict(ckpt["model_state"])

# Predictions that were saved alongside the weights in the checkpoint.
pred = ckpt["save_data"]["pred"]

In [None]:
# Seed Python's built-in `random` module so repeated runs draw the same values.
# NOTE(review): only the stdlib generator is seeded here; if pgm_explainer
# samples through numpy or torch those generators are unaffected — confirm.
SEED = 21
random.seed(SEED)

In [None]:
# Node indices to explain: FIRST_NODE is included, END_NODE is excluded.
FIRST_NODE, END_NODE = 300, 700
nodes_to_explain = [*range(FIRST_NODE, END_NODE)]

In [None]:
# Create the PGM node explainer from the restored model, the A and X loaded by
# utils.load_XA, and the predictions read from the checkpoint.
# NOTE(review): the trailing positional 3 is presumably a hop / layer count
# (it matches the cutoff=3 used for the neighbourhood plot) — confirm against
# pgm_explainer.Node_Explainer.
explainer = pe.Node_Explainer(model, A, X, pred, 3)

In [None]:
# Explain every selected node and draw four panels per node:
# the PGM explanation, the full graph, the 3-hop neighbourhood of the target,
# and the sub-graph induced by the nodes that appear in the explanation.

# A is not modified inside the loop, so the full graph is built once instead
# of being rebuilt for each target.
G = nx.from_numpy_array(A)

for target in nodes_to_explain:
    # Sampling budget and explanation size come from the parsed arguments
    # (--num-perturb-samples / --top-node) rather than repeated literals.
    subnodes, data, stats = explainer.explain(
        target,
        num_samples=args.num_perturb_samples,
        top_node=args.top_node,
        pred_threshold=0.2,
    )
    pgm_explanation = explainer.pgm_generate(target, data, stats, subnodes)
    print("PGM nodes: ", pgm_explanation.nodes())
    print("PGM edges: ", pgm_explanation.edges())

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    nx.draw(pgm_explanation, with_labels=True, ax=axes[0])
    axes[0].set_title("PGM explanation")

    # Nodes within three hops of the target (the target itself included).
    neighborhood = list(nx.single_source_shortest_path_length(G, target, cutoff=3).keys())

    # Sets give O(1) membership tests; the filters and the colour map below are
    # evaluated once per graph node, so list lookups were needlessly quadratic.
    neighborhood_lookup = set(neighborhood)
    # The code compares str(node) against the explanation's nodes, i.e. the
    # explanation graph labels its nodes with strings.
    explained_labels = set(pgm_explanation)

    H = nx.subgraph_view(G, filter_node=lambda x: x in neighborhood_lookup)
    I = nx.subgraph_view(G, filter_node=lambda x: str(x) in explained_labels)

    def color_map_func(graph):
        """Return one colour per node of `graph`, in iteration order.

        red   -> the node currently being explained
        green -> a node that appears in the PGM explanation
        blue  -> every other node
        """
        color_map = []
        for node in graph:
            if node == target:
                color_map.append('red')
            elif str(node) in explained_labels:
                color_map.append('green')
            else:
                color_map.append('blue')
        return color_map

    nx.draw(G, node_size=30, node_color=color_map_func(G), ax=axes[1])
    axes[1].set_title('Original graph')

    nx.draw(H, node_size=50, node_color=color_map_func(H), ax=axes[2])
    axes[2].set_title('Neighborhood')

    nx.draw(I, node_size=100, node_color=color_map_func(I), ax=axes[3])
    axes[3].set_title('Explained nodes')

    plt.show()
    # Release the figure explicitly: one figure is created per target, and
    # leaving them all open accumulates memory over the whole loop.
    plt.close(fig)