In [1]:
# ref https://docs.dgl.ai/generated/dgl.nn.pytorch.explain.HeteroGNNExplainer.html

In [2]:
import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import HeteroGNNExplainer
from tqdm import tqdm

In [3]:
import matplotlib.pyplot as plt
import networkx as nx

## Data
- create random heterogeneous graph

In [21]:
input_dim = 5
num_classes = 2

# The keys are in the form of string triplets (src_type, edge_type, dst_type), 
# The values are graph data in the form of (U,V), where (U[i],V[i]) forms the edge with ID i

g = dgl.heterograph({('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1])})
g.nodes['user'].data['h'] = th.randn(g.num_nodes('user'), input_dim)
g.nodes['game'].data['h'] = th.randn(g.num_nodes('game'), input_dim)

# create also the reverse edge
transform = dgl.transforms.AddReverse()
g = transform(g)

In [22]:
g

Graph(num_nodes={'game': 2, 'user': 3},
      num_edges={('game', 'rev_plays', 'user'): 4, ('user', 'plays', 'game'): 4},
      metagraph=[('game', 'user', 'rev_plays'), ('user', 'game', 'plays')])

In [23]:
# cannot plot directly with networkx because dgl.to_networkx() only supports homogeneous

# options = {
#     'node_color': 'black',
#     'node_size': 20,
#     'width': 1,
# }
# G = dgl.to_networkx(g)
# plt.figure(figsize=[15,7])
# nx.draw(G, **options)

## Define and train the model for graph explanation

In [24]:
class Model(nn.Module):
    
    def __init__(self, in_dim, num_classes, canonical_etypes):
        super().__init__()
        self.etype_weights = nn.ModuleDict({
            '_'.join(c_etype): nn.Linear(in_dim, num_classes)
            for c_etype in canonical_etypes
        })
        
    def forward(self, graph, feat, eweight=None):
        with graph.local_scope():
            c_etype_func_dict = {}
            for c_etype in graph.canonical_etypes:
                src_type, etype, dst_type = c_etype
                wh = self.etype_weights['_'.join(c_etype)](feat[src_type])
                graph.nodes[src_type].data[f'h_{c_etype}'] = wh
                if eweight is None:
                    c_etype_func_dict[c_etype] = (fn.copy_u(f'h_{c_etype}', 'm'),
                        fn.mean('m', 'h'))
                else:
                    graph.edges[c_etype].data['w'] = eweight[c_etype]
                    c_etype_func_dict[c_etype] = (
                        fn.u_mul_e(f'h_{c_etype}', 'w', 'm'), fn.mean('m', 'h'))
            graph.multi_update_all(c_etype_func_dict, 'sum')
            hg = 0
            for ntype in graph.ntypes:
                if graph.num_nodes(ntype):
                    hg = hg + dgl.mean_nodes(graph, 'h', ntype=ntype)
            return hg

In [25]:
feat = g.ndata['h']

In [26]:
"""
dictionary that associates input node features (values) 
with the respective node types (keys) present in the graph. 

The input features are of shape (Nt,Dt)
- Nt is the number of nodes for node type t
- Dt is the feature size for node type t
"""
feat

{'game': tensor([[-0.2012,  1.8359, -0.5838, -0.2247, -0.2292],
         [ 0.5800, -0.4405,  1.1574,  0.2054, -0.7466]]),
 'user': tensor([[-0.2755, -1.3307,  1.2350, -1.1448,  0.4985],
         [-0.1814,  0.4376, -0.9165, -0.1042,  0.6654],
         [ 0.5546, -0.0310,  1.4335,  0.6867, -1.8691]])}

In [28]:
model = Model(input_dim, num_classes, g.canonical_etypes)
optimizer = th.optim.Adam(model.parameters())
for epoch in tqdm(range(10)):
    logits = model(g, feat)
    loss = F.cross_entropy(logits, th.tensor([1]))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

100%|██████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 142.29it/s]


## Explain for the graph

In [30]:
# num_hops: number of hops for GNN info aggregation
# num_epochs: default 1
# lr: default 0.01
explainer = HeteroGNNExplainer(model, num_hops=1, num_epochs=2, log=True)

In [31]:
feat_mask, edge_mask = explainer.explain_graph(g, feat)

Explain graph: 100%|██████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 62.60it/s]


In [32]:
""" 
dictionary that associates the learned node feature importance masks (values) 
with the respective node types (keys). The masks are of shape (Dt), 
where Dt is the node feature size for node type t. 

The values are within range (0,1)
. The higher, the more important.
"""
feat_mask

{'game': tensor([0.4935, 0.4548, 0.4878, 0.5498, 0.4797]),
 'user': tensor([0.5090, 0.5293, 0.4333, 0.4891, 0.4740])}

In [33]:
""" 
The dictionary that associates the learned edge importance masks (values) 
with the respective canonical edge types (keys). The masks are of shape (Et), 
where Et is the number of edges for canonical edge type t in the graph. 
The values are within range (0,1)

. The higher, the more important.
"""
edge_mask

{('game', 'rev_plays', 'user'): tensor([0.6172, 0.6838, 0.6021, 0.6938]),
 ('user', 'plays', 'game'): tensor([0.5584, 0.1907, 0.6272, 0.4027])}

## Explain for the nodes

In [36]:
class ModelNode(nn.Module):
    
    """ this model learns about the users """
    
    def __init__(self, in_dim, num_classes, canonical_etypes):
        super().__init__()
        self.etype_weights = nn.ModuleDict({
            '_'.join(c_etype): nn.Linear(in_dim, num_classes)
            for c_etype in canonical_etypes
        })
    def forward(self, graph, feat, eweight=None):
        with graph.local_scope():
            c_etype_func_dict = {}
            for c_etype in graph.canonical_etypes:
                src_type, etype, dst_type = c_etype
                wh = self.etype_weights['_'.join(c_etype)](feat[src_type])
                graph.nodes[src_type].data[f'h_{c_etype}'] = wh
                if eweight is None:
                    c_etype_func_dict[c_etype] = (fn.copy_u(f'h_{c_etype}', 'm'),
                        fn.mean('m', 'h'))
                else:
                    graph.edges[c_etype].data['w'] = eweight[c_etype]
                    c_etype_func_dict[c_etype] = (
                        fn.u_mul_e(f'h_{c_etype}', 'w', 'm'), fn.mean('m', 'h'))
            graph.multi_update_all(c_etype_func_dict, 'sum')
            return graph.ndata['h']

In [37]:
# define and train the model
model_node = ModelNode(input_dim, num_classes, g.canonical_etypes)
optimizer = th.optim.Adam(model.parameters())
for epoch in range(10):
    logits = model_node(g, feat)['user']
    loss = F.cross_entropy(logits, th.tensor([1, 1, 1]))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [38]:
## Explain node

In [39]:
explainer_node_users = HeteroGNNExplainer(model_node, num_hops=1, num_epochs=2, log=True)

In [40]:
node_type = "user"
node_id = 0
new_center, sg, feat_mask_node, edge_mask_node = explainer_node_users.explain_node(
    node_type, 
    node_id, 
    g, 
    feat
)

Explain node 0 with type user: 100%|██████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 90.66it/s]


In [41]:
sg

Graph(num_nodes={'game': 1, 'user': 1},
      num_edges={('game', 'rev_plays', 'user'): 1, ('user', 'plays', 'game'): 1},
      metagraph=[('game', 'user', 'rev_plays'), ('user', 'game', 'plays')])

In [42]:
feat_mask_node

{'game': tensor([0.4595, 0.4973, 0.4566, 0.5186, 0.5063]),
 'user': tensor([0.4918, 0.5082, 0.4697, 0.4907, 0.4837])}

In [43]:
edge_mask

{('game', 'rev_plays', 'user'): tensor([0.6172, 0.6838, 0.6021, 0.6938]),
 ('user', 'plays', 'game'): tensor([0.5584, 0.1907, 0.6272, 0.4027])}