In [1]:
# For Node Classification
# Select One Node, then Show the results for that node. 

# def ExplainThatNode(no_of_nodes):
#     return explain_results_for_that_nodes

# def GetGraph(no_of_graphs):
#    return that_graph_info



In [5]:
from models import *
def GetModel(model_no, experiment=None):
    """Load the trained infection network and its explanation variants.

    Args:
        model_no: currently ignored -- the same checkpoint is loaded for every
            value. Kept so existing calls such as ``GetModel(1)`` keep working.
        experiment: optional experiment directory (str or Path) to load instead
            of the default infection checkpoint below.

    Returns:
        ``(net, net_gp, net_lrp)`` as produced by ``load_nets``. In this notebook
        they feed the SA, GBP and LRP explanations respectively.
    """
    if experiment is None:
        experiment = '../models/infection/lr.001_nodes.1_count1_wd0_QYNATD'
    net, net_gp, net_lrp = load_nets(Path(experiment))
    return net, net_gp, net_lrp

def GetGraph(graph_no):
    """Build the hand-made 5-node example infection graph.

    Args:
        graph_no: currently ignored -- the same graph is returned for any value.

    Returns:
        dict with keys ``"graph_in"`` (input ``tg.Graph``), ``"graph_target"``
        (target ``tg.Graph``) and ``"name"`` (``'graph-1'``).
    """
    # Node feature columns (as labelled in the heatmaps): sick flag (+1/-1),
    # immune flag (+1/-1), then two noise features.
    node_rows = [
        [-1, -1,  .45,  .23],
        [ 1, -1,  .52, -.12],
        [-1, -1, -.43,  .47],
        [-1, -1,  .78,  .96],
        [ 1, -1, -.82, -.28],
    ]
    # Edge feature columns: virtual flag (+1/-1), then one noise feature.
    edge_rows = [
        [ 1, -.13],
        [ 1,  .54],
        [-1, -.26],
        [-1,  .98],
        [ 1,  .17],
    ]
    graph_in = tg.Graph(
        node_features=torch.tensor(node_rows),
        edge_features=torch.tensor(edge_rows),
        # Edges: 0->2, 1->2, 3->2, 4->2, 4->3
        senders=torch.tensor([0, 1, 3, 4, 4]),
        receivers=torch.tensor([2, 2, 2, 2, 3]),
    )
    # Per-node 0/1 targets as a column; the global target (3) equals the number
    # of 1s in the node targets -- presumably an infected-count, TODO confirm.
    graph_target = tg.Graph(
        node_features=torch.tensor([0, 1, 1, 0, 1]).view(-1, 1),
        global_features=torch.tensor([3]),
    )
    return {
        "graph_in": graph_in,
        "graph_target": graph_target,
        "name": 'graph-1',
    }

models = GetModel(1)
whole_graph = GetGraph(1)
graph_in = whole_graph["graph_in"]
graph_target = whole_graph["graph_target"]
name = whole_graph["name"]
# Bind all three networks: the GBP and LRP visualization cells below call
# net_gp / net_lrp directly. Also avoids overwriting IPython's `_`.
net, net_gp, net_lrp = models
graph_out = net(tg.GraphBatch.collate([graph_in]))[0]
print(graph_out.node_features)
print(graph_out.global_features)

tensor([[-1.6330],
        [ 0.2310],
        [ 0.8082],
        [-1.3081],
        [ 0.1493]], grad_fn=<SliceBackward>)
tensor([1.6719], grad_fn=<SelectBackward>)


In [11]:
### Node-explanation methods: SA, GBP, LRP
# Integer ids accepted by ExplainNode to choose the explanation method. The
# abbreviations presumably stand for sensitivity analysis, guided backprop and
# layer-wise relevance propagation (matching net / net_gp / net_lrp) -- confirm.
EXPLAIN_SA = 11
EXPLAIN_GBP = 12
EXPLAIN_LRP = 13
def SA_ExplainNode(net, graph_in, node_no):
    """Explain one output node by squared input gradients of ``net``.

    Args:
        net: the plain network.
        graph_in: single input graph; it is collated into a fresh batch.
        node_no: index of the output node to explain.

    Returns:
        (node_importance, edge_importance): per input node / per input edge, the
        sum over features of the squared gradient of output node ``node_no``.
    """
    batch = tg.GraphBatch.collate([graph_in]).requires_grad_()
    graph_out = net(batch)[0]

    # One-hot seed: back-propagate only from output node `node_no`.
    node_relevance = torch.zeros_like(graph_out.node_features)
    node_relevance[node_no] = 1

    # Clear grads on the tensors read below. Only the collated batch goes
    # through the network, so zeroing graph_in would have no effect on them.
    batch.zero_grad_()
    graph_out.node_features.backward(node_relevance)

    node_importance = batch.node_features.grad.pow(2).sum(dim=1)
    edge_importance = batch.edge_features.grad.pow(2).sum(dim=1)
    return node_importance, edge_importance

def GBP_ExplainNode(net_gp, graph_in, node_no):
    """Explain one output node with gradients through ``net_gp``.

    The gradient reaching the input node/edge features is clamped to its
    non-negative part (via tensor hooks) before being squared and summed.

    Args:
        net_gp: the network variant used for guided-backprop style gradients.
        graph_in: single input graph; it is collated into a fresh batch.
        node_no: index of the output node to explain.

    Returns:
        (node_importance, edge_importance): one score per input node and per
        input edge.
    """
    def keep_positive(grad):
        return grad.clamp(min=0)

    def squared_total(grad):
        return grad.pow(2).sum(dim=1)

    graph_batch = tg.GraphBatch.collate([graph_in]).requires_grad_()
    for leaf in (graph_batch.node_features, graph_batch.edge_features):
        leaf.register_hook(keep_positive)
    prediction = net_gp(graph_batch)[0]

    # One-hot seed on the chosen output node.
    seed = torch.zeros_like(prediction.node_features)
    seed[node_no] = 1

    graph_batch.zero_grad_()
    prediction.node_features.backward(seed)

    return (
        squared_total(graph_batch.node_features.grad),
        squared_total(graph_batch.edge_features.grad),
    )

def LRP_ExplainNode(net_lrp, graph_in, node_no):
    """Explain one output node with signed relevance from ``net_lrp``.

    Args:
        net_lrp: the network variant whose backward pass produces relevances.
        graph_in: single input graph; it is collated into a fresh batch.
        node_no: index of the output node to explain.

    Returns:
        (node_importance, edge_importance): per input node / per input edge, the
        signed (not squared) sum over features of the back-propagated relevance.
    """
    batch = tg.GraphBatch.collate([graph_in]).requires_grad_()
    graph_out = net_lrp(batch)[0]

    # Seed with output node `node_no`'s own prediction. Detach it: the seed is
    # a constant gradient, not part of the graph being differentiated.
    node_relevance = torch.zeros_like(graph_out.node_features)
    node_relevance[node_no] = graph_out.node_features[node_no].detach()

    # Clear grads on the tensors read below (graph_in is never run through the net).
    batch.zero_grad_()
    graph_out.node_features.backward(node_relevance)

    node_importance = batch.node_features.grad.sum(dim=1)
    edge_importance = batch.edge_features.grad.sum(dim=1)
    return node_importance, edge_importance

def ExplainNode(ExplainMethod, models, graph_in, node_no):
    """Dispatch to the requested node-explanation method.

    Args:
        ExplainMethod: one of EXPLAIN_SA, EXPLAIN_GBP, EXPLAIN_LRP.
        models: the ``(net, net_gp, net_lrp)`` triple returned by GetModel.
        graph_in: input graph to explain.
        node_no: index of the output node to explain.

    Returns:
        ``(node_importance, edge_importance)`` from the selected method.

    Raises:
        ValueError: if ExplainMethod is not one of the known ids. Failing here
            is clearer than returning ``(None, None)``, which only breaks later
            when the caller uses the result.
    """
    net, net_gp, net_lrp = models
    if ExplainMethod == EXPLAIN_SA:
        return SA_ExplainNode(net, graph_in, node_no)
    if ExplainMethod == EXPLAIN_GBP:
        return GBP_ExplainNode(net_gp, graph_in, node_no)
    if ExplainMethod == EXPLAIN_LRP:
        return LRP_ExplainNode(net_lrp, graph_in, node_no)
    raise ValueError(f"Not implemented: unknown explanation method {ExplainMethod!r}")
#print(ExplainNode(EXPLAIN_SA, models, graph_in,2))
#print(ExplainNode(EXPLAIN_GBP,models, graph_in,2))
# Explain every output node with LRP. Row i of each matrix holds the relevance
# of every input node (resp. edge) for output node i.
lrp_explanations = [
    ExplainNode(EXPLAIN_LRP, models, graph_in, i) for i in range(graph_in.num_nodes)
]
node_importance = torch.stack([node_rel for node_rel, _ in lrp_explanations])
edge_importance = torch.stack([edge_rel for _, edge_rel in lrp_explanations])
print(node_importance)
print(edge_importance)

tensor([[-1.2337,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.6303,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0868, -0.8393,  0.0636,  1.3328],
        [ 0.0000,  0.0000,  0.0000, -1.1307,  0.9172],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.5486]])
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000, -0.2610, -0.0496,  0.8742,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -0.6954],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]])


In [None]:
# Disabled snippet, kept as an unused string literal: pickles `whole_graph` to
# the file "graph-1" and reads it back as a round-trip check.
# NOTE(review): pickle.load can execute arbitrary code stored in the file --
# only load pickles you created yourself.
'''
whole_graph = {
    "graph_in":graph_in,
    "graph_target":graph_target,
    "name":name
}
print(whole_graph)
import pickle as pkl
with open("graph-1","wb") as f:
    pkl.dump(whole_graph,f)
with open("graph-1","rb") as f:
    load_whole_graph = pkl.load(f)
    print(load_whole_graph)
'''

In [None]:
def _attribution_figure(features, num_items, feature_labels, xlabel, item_labels=None):
    """Build one bar-chart + heatmap figure for a (items x features) gradient.

    Axes are added in the order: bar chart, heatmap, colorbar.
    """
    grad = features.grad
    fig = plt.figure(figsize=(6, 4), dpi=100)
    grid = plt.GridSpec(4, 16)
    positions = torch.arange(num_items)

    # Top: per-item sum of positive (red) and negative (blue) gradient parts.
    ax_bar = fig.add_subplot(grid[:2, :-1])
    ax_bar.bar(positions, grad.clamp(min=0).sum(dim=1), color='C3')
    ax_bar.bar(positions, grad.clamp(max=0).sum(dim=1), color='C0')
    ax_bar.set_xticks([])
    ax_bar.axhline(0, color='k', linestyle='-', linewidth=.5)
    for side in ('top', 'right', 'bottom'):
        ax_bar.spines[side].set_visible(False)
    ax_bar.tick_params(axis='both', which='both', bottom=False, labelbottom=False)

    # Bottom: full gradient matrix on a symmetric scale so 0 is the colormap centre.
    # Shares x with the bar chart, so set_xlim here applies to both panels.
    ax_map = fig.add_subplot(grid[2:, :-1], sharex=ax_bar)
    vmax = max(grad.abs().max(), 1e-16)  # floor avoids a zero-width range when all grads are 0
    im = ax_map.imshow(grad.t(), cmap='bwr', vmin=-vmax, vmax=vmax, aspect='auto')
    ax_map.set_xticks(positions)
    ax_map.set_xlim(-.5, num_items - .5)
    if item_labels is not None:
        ax_map.set_xticklabels(item_labels, rotation=0)
    num_features = features.shape[1]
    ax_map.set_yticks(torch.arange(num_features))
    ax_map.set_ylim(num_features - .5, -.5)
    ax_map.set_yticklabels(feature_labels)
    ax_map.set_xlabel(xlabel)
    # Annotate each cell with the input feature value (not the gradient).
    for (row, col), value in np.ndenumerate(features.t().detach().numpy()):
        ax_map.text(col, row, label_txt(value), ha='center', va='center')

    fig.colorbar(mappable=im, cax=fig.add_subplot(grid[2:, -1]))
    return fig


def heatmaps(graph_in):
    """Plot input-gradient attributions of a graph as two figures (nodes, edges).

    Args:
        graph_in: graph (in practice the collated batch) whose ``node_features``
            and ``edge_features`` hold ``.grad`` from a preceding backward pass.

    Returns:
        (fig_nodes, fig_edges). In each figure ``axes[0]`` is the bar chart,
        ``axes[1]`` the heatmap and ``axes[2]`` the colorbar.

    NOTE(review): figures are created through pyplot, so the inline backend may
    also auto-render them at the end of the cell in addition to display().
    """
    fig1 = _attribution_figure(
        graph_in.node_features,
        graph_in.num_nodes,
        ['Sick', 'Immune', 'Noise', 'Noise'],
        'Node',
    )
    edge_labels = [f'${s} \\to {r}$' for s, r in zip(graph_in.senders, graph_in.receivers)]
    fig2 = _attribution_figure(
        graph_in.edge_features,
        graph_in.num_edges,
        ['Virtual', 'Noise'],
        'Edge',
        item_labels=edge_labels,
    )
    return fig1, fig2

In [None]:
def plot_node_edge(g_nx, graph_in, layout, squared=False):
    """Draw a graph twice: coloured by node/edge type, then by attribution.

    Args:
        g_nx: networkx version of the same graph (callers pass ``to_networkx()``).
            Its edge iteration order is assumed to match the edge order of
            ``graph_in`` -- TODO confirm for tg's ``to_networkx``.
        graph_in: graph (in practice the collated batch) whose feature tensors
            hold ``.grad`` from a preceding backward pass.
        layout: node -> position mapping, as accepted by networkx ``pos=``.
        squared: if True, square gradient entries before summing over features
            (SA/GBP); otherwise sum the signed values (LRP).

    Returns:
        matplotlib figure with two side-by-side panels.
    """
    power = 2 if squared else 1
    # Read gradients from the argument, not from a global `batch`.
    node_color = graph_in.node_features.grad.pow(power).sum(dim=1)
    edge_color = graph_in.edge_features.grad.pow(power).sum(dim=1)
    # Shared symmetric limits so zero attribution sits at the centre of 'bwr'.
    vmax = max(node_color.abs().max(), edge_color.abs().max(), 1e-16)

    # Node categories from the flags (== 1) in feature columns 0 and 1.
    infected = (graph_in.node_features[:, 0] == 1).nonzero().view(-1).tolist()
    immune = (graph_in.node_features[:, 1] == 1).nonzero().view(-1).tolist()
    flagged = set(infected) | set(immune)
    others = [i for i in range(graph_in.num_nodes) if i not in flagged]

    # Edge column 0 == 1 marks virtual edges; map tensor edge indices onto g_nx edges.
    nx_edges = list(g_nx.edges)
    virtual_idx = (graph_in.edge_features[:, 0] == 1).nonzero().view(-1).tolist()
    virtual = [nx_edges[i] for i in virtual_idx]
    virtual_set = set(virtual)
    nonvirtual = [e for e in nx_edges if e not in virtual_set]

    node_ids = {i: str(i) for i in range(graph_in.num_nodes)}

    def _finish_axes(ax):
        # Same fixed square viewport, no ticks or frame, for both panels.
        ax.axis('image')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlim(-.04, +.04)
        ax.set_ylim(-.04, +.04)
        ax.axis('off')

    fig = plt.figure(figsize=(3.2, 1.6))
    grid = plt.GridSpec(1, 2, figure=fig, wspace=0, hspace=0)

    # NOTE(review): `node_size` is a global not defined in this function.
    # Left panel: structure, coloured by node type and virtual/non-virtual edges.
    ax = fig.add_subplot(grid[0])
    nx.draw_networkx_nodes(g_nx, nodelist=infected, node_size=node_size, node_color='C3', pos=layout, ax=ax)
    nx.draw_networkx_nodes(g_nx, nodelist=immune,   node_size=node_size, node_color='C2', pos=layout, ax=ax)
    nx.draw_networkx_nodes(g_nx, nodelist=others,   node_size=node_size, node_color='C0', pos=layout, ax=ax)
    nx.draw_networkx_edges(g_nx, edgelist=virtual,    node_size=node_size, edge_color='c',  pos=layout, ax=ax, width=2., alpha=.7, arrowsize=15)
    nx.draw_networkx_edges(g_nx, edgelist=nonvirtual, node_size=node_size, edge_color='k',  pos=layout, ax=ax, arrowsize=15)
    nx.draw_networkx_labels(g_nx, pos=layout, labels=node_ids, ax=ax, font_size=12, font_weight='bold', font_color='white')
    _finish_axes(ax)

    # Right panel: nodes and edges coloured by attribution.
    # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in 3.9.
    ax = fig.add_subplot(grid[1])
    nx.draw_networkx_nodes(g_nx, node_size=node_size, node_color=node_color, cmap='bwr', vmin=-vmax, vmax=vmax, pos=layout, ax=ax, linewidths=1, edgecolors='k')
    nx.draw_networkx_edges(g_nx, node_size=node_size, edge_color=edge_color, edge_cmap=plt.get_cmap('bwr'), edge_vmin=-vmax, edge_vmax=vmax, pos=layout, ax=ax, arrowsize=20)
    nx.draw_networkx_labels(g_nx, pos=layout, labels=node_ids, ax=ax, font_size=12, font_weight='bold', font_color='k')
    _finish_axes(ax)

    return fig

In [None]:
### SA: squared input gradients for output node N
batch = tg.GraphBatch.collate([graph_in]).requires_grad_()
graph_out = net(batch)[0]

N = 4
node_relevance = torch.zeros_like(graph_out.node_features)
node_relevance[N] = 1

# Clear grads on the batch whose .grad is plotted below (graph_in is not run
# through the network).
batch.zero_grad_()
graph_out.node_features.backward(node_relevance)

# NOTE(review): `layout` (node -> position) is not defined in any cell shown
# here; presumably it comes from `from models import *` or a deleted cell --
# define it explicitly so Restart & Run All works.
fig = plot_node_edge(graph_in.to_networkx(), batch, layout, squared=True)
#fig.savefig(f'{name}-sa.png', dpi=300, pad_inches=0)
display(fig)

heat_n, heat_e = heatmaps(batch)
#heat_n.savefig(f'{name}-sa-nodes.png', dpi=300, pad_inches=0)
#heat_e.savefig(f'{name}-sa-edges.png', dpi=300, pad_inches=0)
display(heat_n)
display(heat_e)

In [None]:
### GBP: clamped input gradients for output node N
# Bind the GBP network here so this cell does not rely on a name leaked from
# another cell.
net_gp = models[1]
batch = tg.GraphBatch.collate([graph_in]).requires_grad_()
batch.node_features.register_hook(lambda grad: grad.clamp(min=0))
batch.edge_features.register_hook(lambda grad: grad.clamp(min=0))
graph_out = net_gp(batch)[0]

N = 2
node_relevance = torch.zeros_like(graph_out.node_features)
node_relevance[N] = 1

batch.zero_grad_()
graph_out.node_features.backward(node_relevance)

fig = plot_node_edge(graph_in.to_networkx(), batch, layout, squared=True)
#fig.savefig(f'{name}-gbp.png', dpi=300, pad_inches=0)
display(fig)

heat_n, heat_e = heatmaps(batch)

# Switch the edge figure's bar chart (axes[0]) and colorbar (axes[2]) y-axes to
# scientific notation outside 1e-3 .. 1e4.
heat_e.axes[0].yaxis.get_major_formatter().set_powerlimits((-3,4))
heat_e.axes[2].yaxis.get_major_formatter().set_powerlimits((-3,4))

#heat_n.savefig(f'{name}-gbp-nodes.png', dpi=300, pad_inches=0)
#heat_e.savefig(f'{name}-gbp-edges.png', dpi=300)
display(heat_n)
display(heat_e)

In [None]:
### LRP: signed relevance for output node N
# Bind the LRP network here so this cell does not rely on a name leaked from
# another cell.
net_lrp = models[2]
batch = tg.GraphBatch.collate([graph_in]).requires_grad_()
graph_out = net_lrp(batch)[0]

N = 2
node_relevance = torch.zeros_like(graph_out.node_features)
# Seed with node N's own prediction; detached because the seed is a constant.
node_relevance[N] = graph_out.node_features[N].detach()

# Clear grads on the batch whose .grad is plotted below.
batch.zero_grad_()
graph_out.node_features.backward(node_relevance)

fig = plot_node_edge(graph_in.to_networkx(), batch, layout, squared=False)
#fig.savefig(f'{name}-lrp.png', dpi=300, pad_inches=0)
display(fig)

heat_n, heat_e = heatmaps(batch)
#heat_n.savefig(f'{name}-lrp-nodes.png', dpi=300, pad_inches=0)
#heat_e.savefig(f'{name}-lrp-edges.png', dpi=300, pad_inches=0)
display(heat_n)
display(heat_e)