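# Autograd trick for Graph Layers

Each torchgraphs layer (`EdgeLinear`, `NodeLinear`, `GlobalLinear`) is re-implemented here with the relevance-propagating ops of the `relevance` module (`lrp.add`, `lrp.linear_eps`, `lrp.index_select`, ...). Their custom backward functions (`AddRelevanceBackward`, `LinearEpsilonRelevanceBackward`, ...) propagate LRP relevance instead of plain gradients. As a result, calling `.backward(rel_out)` on a layer's output leaves each input's relevance in its `.grad`. For every layer configuration, the notebook prints the autograd graph and then compares the total input relevance with the output relevance. The two agree up to float error, except when a bias term is present.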

In [1]:
%matplotlib inline
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.gridspec as gridspec

import sys
import math
import torch
import torch.nn as nn
import numpy as np
import torch_scatter
import torchgraphs as tg

import yaml
import textwrap
import pandas as pd
from collections import OrderedDict
from pathlib import Path
from munch import munchify

from infection.dataset import InfectionDataset
from utils import import_
import relevance as lrp

# Compact fixed-width printing (2 decimals, 150 columns) for the relevance arrays below.
np.set_printoptions(
    formatter={
        'float_kind': '{:5.2f}'.format,
        'int_kind': '{:5d}'.format,
    },
    linewidth=150,
)

In [2]:
def computational_graph(op):
    """Return a text dump of the autograd graph reachable from ``op``.

    Each node is rendered as ``<ClassName> at <address>:``. ``AccumulateGrad``
    nodes additionally show the address of their ``variable``. The children
    listed in ``op.next_functions`` follow on separate lines, each prefixed
    with ``-``, and every nesting level adds one space of indentation.
    A ``None`` node is rendered as the string ``'None'``.
    """
    if op is None:
        return 'None'
    cls_name = op.__class__.__name__
    lines = [f'{cls_name} at {hex(id(op))}:']
    if cls_name == 'AccumulateGrad':
        lines[0] += f'variable at {hex(id(op.variable))}'
    # next_functions entries are tuples whose first element is the child node.
    lines.extend('-' + textwrap.indent(computational_graph(entry[0]), ' ')
                 for entry in op.next_functions)
    return '\n'.join(lines)

## Input graph

In [3]:
# Seed the global torch RNG so that the random features below, and presumably
# the weight initialisation of the layers built later, are identical on every
# Restart & Run All. Without this, none of the printed relevances reproduce.
torch.manual_seed(0)

# A single graph with 4 nodes and 3 edges: 0->1, 2->0, 3->0.
graphs = tg.GraphBatch(
    node_features=torch.rand(4, 5),
    edge_features=torch.rand(3, 3),
    global_features=torch.rand(1, 7),
    senders=torch.tensor([0, 2, 3]),
    receivers=torch.tensor([1, 0, 0]),
    num_nodes_by_graph=torch.tensor([4]),
    num_edges_by_graph=torch.tensor([3]),
)
# Track gradients on the inputs: the relevance backward passes write into their .grad.
graphs.requires_grad_()

GraphBatch(#1, n=tensor([4]), e=tensor([3]), n_shape=torch.Size([5]), e_shape=torch.Size([3]), g_shape=torch.Size([7]))

## EdgeLinear

In [4]:
class EdgeLinearRelevance(tg.EdgeLinear):
    """``tg.EdgeLinear`` whose forward pass is built from ``lrp`` relevance ops.

    Each enabled term (edge, sender, receiver, global, bias) is computed with
    the ``lrp`` counterpart of the usual op. The terms are then accumulated one
    at a time with ``lrp.add``, so that a backward pass from the output edge
    features routes relevance back onto the inputs' ``.grad``.
    """

    def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch:
        acc = torch.tensor(0)

        if self.W_edge is not None:
            edge_term = lrp.linear_eps(graphs.edge_features, self.W_edge)
            acc = lrp.add(acc, edge_term)

        if self.W_sender is not None:
            # Project every node, then pick the row of each edge's sender.
            sender_proj = lrp.linear_eps(graphs.node_features, self.W_sender)
            sender_term = lrp.index_select(sender_proj, dim=0, index=graphs.senders)
            acc = lrp.add(acc, sender_term)

        if self.W_receiver is not None:
            receiver_proj = lrp.linear_eps(graphs.node_features, self.W_receiver)
            receiver_term = lrp.index_select(receiver_proj, dim=0, index=graphs.receivers)
            acc = lrp.add(acc, receiver_term)

        if self.W_global is not None:
            # Broadcast each graph's projected globals to all of that graph's edges.
            global_proj = lrp.linear_eps(graphs.global_features, self.W_global)
            global_term = lrp.repeat_tensor(global_proj, dim=0, repeats=graphs.num_edges_by_graph)
            acc = lrp.add(acc, global_term)

        if self.bias is not None:
            acc = lrp.add(acc, self.bias)

        return graphs.evolve(edge_features=acc)

In [5]:
# Edge term only.
net = EdgeLinearRelevance(
    1,
    edge_features=3,
    sender_features=None,
    receiver_features=None,
    global_features=None,
    bias=False,
)
out = net(graphs)
print(computational_graph(out.edge_features.grad_fn))

AddRelevanceBackward at 0x55aaeddd2898:
- None
- LinearEpsilonRelevanceBackward at 0x55aaecb330f8:
 - AccumulateGrad at 0x7fd02d76c9b0:variable at 0x7fd02d494af8
 - AccumulateGrad at 0x7fd02d76cac8:variable at 0x7fd0a0344630


In [6]:
out = net(graphs)
# Unit relevance on every non-zero output edge feature.
rel_out = torch.ones_like(out.edge_features) * (out.edge_features != 0).float()
print('Relevance edges out', rel_out.sum().item())
print(rel_out.numpy())

# Backpropagate relevance: each input's .grad then holds its share.
graphs.zero_grad_()
out.edge_features.backward(rel_out)

# Report and total the relevance of every input that received any,
# computing each sum once instead of repeating three near-identical branches.
rel_in = 0
for kind, features in (('edges', graphs.edge_features),
                       ('nodes', graphs.node_features),
                       ('globals', graphs.global_features)):
    if features.grad is None:
        continue
    rel_kind = features.grad.sum().item()
    rel_in += rel_kind
    print(f'\nRelevance {kind} in', rel_kind)
    print(features.grad.numpy())
print('\nRelevance input total', rel_in)

Relevance edges out 3.0
[[ 1.00]
 [ 1.00]
 [ 1.00]]

Relevance edges in 2.999999761581421
[[ 4.61 -1.01 -2.60]
 [-0.11  0.67  0.44]
 [-0.01  0.90  0.11]]

Relevance input total 2.999999761581421


In [7]:
# Edge + global terms.
net = EdgeLinearRelevance(
    1,
    edge_features=3,
    sender_features=None,
    receiver_features=None,
    global_features=7,
    bias=False,
)
out = net(graphs)
print(computational_graph(out.edge_features.grad_fn))

AddRelevanceBackward at 0x55aaeddda838:
- AddRelevanceBackward at 0x55aaeddd2898:
 - None
 - LinearEpsilonRelevanceBackward at 0x55aaeddd41c8:
  - AccumulateGrad at 0x7fd02d4ad160:variable at 0x7fd02d494af8
  - AccumulateGrad at 0x7fd02d4ad198:variable at 0x7fd0a02f06c0
- IndexSelectBackward at 0x7fd02d76cdd8:
 - LinearEpsilonRelevanceBackward at 0x55aaecb330f8:
  - AccumulateGrad at 0x7fd02d4ad128:variable at 0x7fd02d494b40
  - AccumulateGrad at 0x7fd02d4ad208:variable at 0x7fd0a02f09d8


In [8]:
out = net(graphs)
# Unit relevance on every non-zero output edge feature.
rel_out = torch.ones_like(out.edge_features) * (out.edge_features != 0).float()
print('Relevance edges out', rel_out.sum().item())
print(rel_out.numpy())

# Backpropagate relevance: each input's .grad then holds its share.
graphs.zero_grad_()
out.edge_features.backward(rel_out)

# Report and total the relevance of every input that received any,
# computing each sum once instead of repeating three near-identical branches.
rel_in = 0
for kind, features in (('edges', graphs.edge_features),
                       ('nodes', graphs.node_features),
                       ('globals', graphs.global_features)):
    if features.grad is None:
        continue
    rel_kind = features.grad.sum().item()
    rel_in += rel_kind
    print(f'\nRelevance {kind} in', rel_kind)
    print(features.grad.numpy())
print('\nRelevance input total', rel_in)

Relevance edges out 3.0
[[ 1.00]
 [ 1.00]
 [ 1.00]]

Relevance edges in -4.586839199066162
[[-3.05 -0.15  1.19]
 [-0.43 -0.57  1.17]
 [-0.15 -4.11  1.50]]

Relevance globals in 7.586838245391846
[[ 0.12 11.95  2.31 -10.08 -3.25  5.16  1.38]]

Relevance input total 2.9999990463256836


In [9]:
# Edge + sender + global terms.
net = EdgeLinearRelevance(
    1,
    edge_features=3,
    sender_features=5,
    receiver_features=None,
    global_features=7,
    bias=False,
)
out = net(graphs)
print(computational_graph(out.edge_features.grad_fn))

AddRelevanceBackward at 0x55aaedddce78:
- AddRelevanceBackward at 0x55aaeddd41c8:
 - AddRelevanceBackward at 0x55aaeddd43d8:
  - None
  - LinearEpsilonRelevanceBackward at 0x55aaedd1a3a8:
   - AccumulateGrad at 0x7fd02d4adf98:variable at 0x7fd02d494af8
   - AccumulateGrad at 0x7fd02d4adfd0:variable at 0x7fd02d494630
 - IndexSelectRelevanceBackward at 0x55aaeddd2898:
  - LinearEpsilonRelevanceBackward at 0x55aaecb330f8:
   - AccumulateGrad at 0x7fd02d4adf60:variable at 0x7fd02d494ab0
   - AccumulateGrad at 0x7fd02d4adeb8:variable at 0x7fd02d494cf0
  - None
- IndexSelectBackward at 0x7fd02d4ad048:
 - LinearEpsilonRelevanceBackward at 0x55aaeddd5b78:
  - AccumulateGrad at 0x7fd02d4adf98:variable at 0x7fd02d494b40
  - AccumulateGrad at 0x7fd02d4adfd0:variable at 0x7fd02d4945a0


In [10]:
out = net(graphs)
# Unit relevance on every non-zero output edge feature.
rel_out = torch.ones_like(out.edge_features) * (out.edge_features != 0).float()
print('Relevance edges out', rel_out.sum().item())
print(rel_out.numpy())

# Backpropagate relevance: each input's .grad then holds its share.
graphs.zero_grad_()
out.edge_features.backward(rel_out)

# Report and total the relevance of every input that received any,
# computing each sum once instead of repeating three near-identical branches.
rel_in = 0
for kind, features in (('edges', graphs.edge_features),
                       ('nodes', graphs.node_features),
                       ('globals', graphs.global_features)):
    if features.grad is None:
        continue
    rel_kind = features.grad.sum().item()
    rel_in += rel_kind
    print(f'\nRelevance {kind} in', rel_kind)
    print(features.grad.numpy())
print('\nRelevance input total', rel_in)

Relevance edges out 3.0
[[ 1.00]
 [ 1.00]
 [ 1.00]]

Relevance edges in -3.1179916858673096
[[-0.04  0.07 -0.76]
 [-0.03  1.52 -4.03]
 [-0.00  0.29 -0.14]]

Relevance nodes in 5.985638618469238
[[ 1.63  0.45  0.35  0.52 -1.26]
 [ 0.00  0.00  0.00  0.00  0.00]
 [ 0.04  1.26  0.76  1.87 -0.46]
 [ 0.29  0.30  0.08  0.19 -0.03]]

Relevance globals in 0.13235342502593994
[[ 0.91 -0.42 -1.30  0.25  2.34 -1.37 -0.27]]

Relevance input total 3.0000003576278687


In [11]:
# All terms, no bias.
net = EdgeLinearRelevance(
    1,
    edge_features=3,
    sender_features=5,
    receiver_features=5,
    global_features=7,
    bias=False,
)
out = net(graphs)
print(computational_graph(out.edge_features.grad_fn))

AddRelevanceBackward at 0x55aaedde3948:
- AddRelevanceBackward at 0x55aaedde2a08:
 - AddRelevanceBackward at 0x55aaedd1a3a8:
  - AddRelevanceBackward at 0x55aaeddd2898:
   - None
   - LinearEpsilonRelevanceBackward at 0x55aaeddd41c8:
    - AccumulateGrad at 0x7fd02d4adb38:variable at 0x7fd02d494af8
    - AccumulateGrad at 0x7fd02d4ada20:variable at 0x7fd02d4b33a8
  - IndexSelectRelevanceBackward at 0x55aaeddd43d8:
   - LinearEpsilonRelevanceBackward at 0x55aaecb330f8:
    - AccumulateGrad at 0x7fd02d4adac8:variable at 0x7fd02d494ab0
    - AccumulateGrad at 0x7fd02d4ad9e8:variable at 0x7fd0a0393120
   - None
 - IndexSelectRelevanceBackward at 0x55aaedde2588:
  - LinearEpsilonRelevanceBackward at 0x55aaedddde08:
   - AccumulateGrad at 0x7fd02d4adb38:variable at 0x7fd02d494ab0
   - AccumulateGrad at 0x7fd02d4ada20:variable at 0x7fd0a02f0f30
  - None
- IndexSelectBackward at 0x7fd02d4ade10:
 - LinearEpsilonRelevanceBackward at 0x55aaedde2fd8:
  - AccumulateGrad at 0x7fd02d4ad9e8:variable a

In [12]:
out = net(graphs)
# Unit relevance on every non-zero output edge feature.
rel_out = torch.ones_like(out.edge_features) * (out.edge_features != 0).float()
print('Relevance edges out', rel_out.sum().item())
print(rel_out.numpy())

# Backpropagate relevance: each input's .grad then holds its share.
graphs.zero_grad_()
out.edge_features.backward(rel_out)

# Report and total the relevance of every input that received any,
# computing each sum once instead of repeating three near-identical branches.
rel_in = 0
for kind, features in (('edges', graphs.edge_features),
                       ('nodes', graphs.node_features),
                       ('globals', graphs.global_features)):
    if features.grad is None:
        continue
    rel_kind = features.grad.sum().item()
    rel_in += rel_kind
    print(f'\nRelevance {kind} in', rel_kind)
    print(features.grad.numpy())
print('\nRelevance input total', rel_in)

Relevance edges out 3.0
[[ 1.00]
 [ 1.00]
 [ 1.00]]

Relevance edges in 4.933183670043945
[[-0.24 -0.04 -0.23]
 [ 0.12  0.51  0.77]
 [ 0.04  3.17  0.84]]

Relevance nodes in -1.372652292251587
[[-2.00 -0.71  0.84  0.47 -0.87]
 [ 0.19  0.49 -0.17 -0.20  0.29]
 [-0.01 -0.07  0.05  0.54 -0.15]
 [-1.26 -0.51  0.17  1.82 -0.28]]

Relevance globals in -0.5605312585830688
[[ 0.43 -0.09 -0.15 -0.85  0.19 -0.35  0.27]]

Relevance input total 3.0000001192092896


In [13]:
# All terms plus bias.
net = EdgeLinearRelevance(
    1,
    edge_features=3,
    sender_features=5,
    receiver_features=5,
    global_features=7,
    bias=True,
)
out = net(graphs)
print(computational_graph(out.edge_features.grad_fn))

AddRelevanceBackward at 0x55aaeddeb9f8:
- AddRelevanceBackward at 0x55aaeddeb7e8:
 - AddRelevanceBackward at 0x55aaedde3948:
  - AddRelevanceBackward at 0x55aaeddd2898:
   - AddRelevanceBackward at 0x55aaedd1a3a8:
    - None
    - LinearEpsilonRelevanceBackward at 0x55aaedddde08:
     - AccumulateGrad at 0x7fd02d4c0518:variable at 0x7fd02d494af8
     - AccumulateGrad at 0x7fd02d4c0358:variable at 0x7fd02d4b39d8
   - IndexSelectRelevanceBackward at 0x55aaecb330f8:
    - LinearEpsilonRelevanceBackward at 0x55aaeddd43d8:
     - AccumulateGrad at 0x7fd02d4c04e0:variable at 0x7fd02d494ab0
     - AccumulateGrad at 0x7fd02d4c0550:variable at 0x7fd02d4b37e0
    - None
  - IndexSelectRelevanceBackward at 0x55aaedde2fd8:
   - LinearEpsilonRelevanceBackward at 0x55aaeddd41c8:
    - AccumulateGrad at 0x7fd02d4c0518:variable at 0x7fd02d494ab0
    - AccumulateGrad at 0x7fd02d4c0358:variable at 0x7fd02d75d1f8
   - None
 - IndexSelectBackward at 0x7fd02d4c03c8:
  - LinearEpsilonRelevanceBackward at 0x

In [14]:
out = net(graphs)
# Unit relevance on every non-zero output edge feature.
rel_out = torch.ones_like(out.edge_features) * (out.edge_features != 0).float()
print('Relevance edges out', rel_out.sum().item())
print(rel_out.numpy())

# Backpropagate relevance: each input's .grad then holds its share.
graphs.zero_grad_()
out.edge_features.backward(rel_out)

# Report and total the relevance of every input that received any,
# computing each sum once instead of repeating three near-identical branches.
# NOTE: with bias=True the input total is not expected to match the output total.
rel_in = 0
for kind, features in (('edges', graphs.edge_features),
                       ('nodes', graphs.node_features),
                       ('globals', graphs.global_features)):
    if features.grad is None:
        continue
    rel_kind = features.grad.sum().item()
    rel_in += rel_kind
    print(f'\nRelevance {kind} in', rel_kind)
    print(features.grad.numpy())
print('\nRelevance input total', rel_in)

Relevance edges out 3.0
[[ 1.00]
 [ 1.00]
 [ 1.00]]

Relevance edges in -0.16393667459487915
[[ 0.12 -0.02  0.13]
 [ 0.07 -0.41  0.54]
 [ 0.01 -0.77  0.18]]

Relevance nodes in 1.0252063274383545
[[ 0.53 -0.27  0.21  0.46 -0.37]
 [ 0.02 -0.34  0.09  0.23 -0.10]
 [ 0.01  0.06 -0.08  0.00 -0.02]
 [ 0.55  0.14 -0.08  0.00 -0.01]]

Relevance globals in 0.5248541235923767
[[-0.18 -0.01  0.05  0.29  0.01  0.53 -0.18]]

Relevance input total 1.386123776435852


## NodeLinear

In [15]:
class NodeLinearRelevance(tg.NodeLinear):
    """``tg.NodeLinear`` whose forward pass is built from ``lrp`` relevance ops.

    The constructor mirrors ``tg.NodeLinear`` but resolves ``aggregation``
    through ``lrp.get_aggregation``, so that edge aggregation is also a
    relevance-aware op. ``forward`` accumulates the enabled terms (node,
    incoming, outgoing, global, bias) one at a time with ``lrp.add``.
    """

    def __init__(self, out_features, node_features=None, incoming_features=None, outgoing_features=None,
                 global_features=None, aggregation=None, bias=True):
        agg_fn = lrp.get_aggregation(aggregation)
        super().__init__(out_features, node_features, incoming_features,
                         outgoing_features, global_features, agg_fn, bias)

    def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch:
        acc = torch.tensor(0)

        if self.W_node is not None:
            node_term = lrp.linear_eps(graphs.node_features, self.W_node)
            acc = lrp.add(acc, node_term)

        if self.W_incoming is not None:
            # Aggregate edge features onto their receiving node.
            incoming = self.aggregation(graphs.edge_features, dim=0, index=graphs.receivers,
                                        dim_size=graphs.num_nodes)
            acc = lrp.add(acc, lrp.linear_eps(incoming, self.W_incoming))

        if self.W_outgoing is not None:
            # Aggregate edge features onto their sending node.
            outgoing = self.aggregation(graphs.edge_features, dim=0, index=graphs.senders,
                                        dim_size=graphs.num_nodes)
            acc = lrp.add(acc, lrp.linear_eps(outgoing, self.W_outgoing))

        if self.W_global is not None:
            # Broadcast each graph's projected globals to all of that graph's nodes.
            global_proj = lrp.linear_eps(graphs.global_features, self.W_global)
            global_term = lrp.repeat_tensor(global_proj, dim=0, repeats=graphs.num_nodes_by_graph)
            acc = lrp.add(acc, global_term)

        if self.bias is not None:
            acc = lrp.add(acc, self.bias)

        return graphs.evolve(node_features=acc)

In [16]:
# Node term only.
net = NodeLinearRelevance(
    1,
    node_features=5,
    incoming_features=None,
    outgoing_features=None,
    global_features=None,
    bias=False,
)
out = net(graphs)
print(computational_graph(out.node_features.grad_fn))

AddRelevanceBackward at 0x55aaeddd41c8:
- None
- LinearEpsilonRelevanceBackward at 0x55aaedde0918:
 - AccumulateGrad at 0x7fd02d4c9390:variable at 0x7fd02d494ab0
 - AccumulateGrad at 0x7fd02d4c9438:variable at 0x7fd0a03447e0


In [17]:
out = net(graphs)
# Unit relevance on every non-zero output node feature.
rel_out = torch.ones_like(out.node_features) * (out.node_features != 0).float()
print('Relevance nodes out', rel_out.sum().item())
print(rel_out.numpy())

# Backpropagate relevance: each input's .grad then holds its share.
graphs.zero_grad_()
out.node_features.backward(rel_out)

# Report and total the relevance of every input that received any,
# computing each sum once instead of repeating three near-identical branches.
rel_in = 0
for kind, features in (('edges', graphs.edge_features),
                       ('nodes', graphs.node_features),
                       ('globals', graphs.global_features)):
    if features.grad is None:
        continue
    rel_kind = features.grad.sum().item()
    rel_in += rel_kind
    print(f'\nRelevance {kind} in', rel_kind)
    print(features.grad.numpy())
print('\nRelevance input total', rel_in)

Relevance nodes out 4.0
[[ 1.00]
 [ 1.00]
 [ 1.00]
 [ 1.00]]

Relevance nodes in 4.000000476837158
[[ 5.16  3.05 -2.07 -2.25 -2.89]
 [ 0.52  2.45 -0.49 -0.80 -0.68]
 [-0.03 -1.76  0.92  1.65  0.21]
 [ 0.58  1.27 -0.30 -0.52 -0.04]]

Relevance input total 4.000000476837158


In [18]:
# Node + global terms.
net = NodeLinearRelevance(
    1,
    node_features=5,
    incoming_features=None,
    outgoing_features=None,
    global_features=7,
    bias=False,
)
out = net(graphs)
print(computational_graph(out.node_features.grad_fn))

AddRelevanceBackward at 0x55aaeddd43d8:
- AddRelevanceBackward at 0x55aaedde0918:
 - None
 - LinearEpsilonRelevanceBackward at 0x55aaeddd41c8:
  - AccumulateGrad at 0x7fd02d4c07f0:variable at 0x7fd02d494ab0
  - AccumulateGrad at 0x7fd02d4c0fd0:variable at 0x7fd02d4bfca8
- IndexSelectBackward at 0x7fd02d4c0ac8:
 - LinearEpsilonRelevanceBackward at 0x55aaecb330f8:
  - AccumulateGrad at 0x7fd02d4c0518:variable at 0x7fd02d494b40
  - AccumulateGrad at 0x7fd02d4c0828:variable at 0x7fd02d4bffc0


In [19]:
out = net(graphs)
# Unit relevance on every non-zero output node feature.
rel_out = torch.ones_like(out.node_features) * (out.node_features != 0).float()
print('Relevance nodes out', rel_out.sum().item())
print(rel_out.numpy())

# Backpropagate relevance: each input's .grad then holds its share.
graphs.zero_grad_()
out.node_features.backward(rel_out)

# Report and total the relevance of every input that received any,
# computing each sum once instead of repeating three near-identical branches.
rel_in = 0
for kind, features in (('edges', graphs.edge_features),
                       ('nodes', graphs.node_features),
                       ('globals', graphs.global_features)):
    if features.grad is None:
        continue
    rel_kind = features.grad.sum().item()
    rel_in += rel_kind
    print(f'\nRelevance {kind} in', rel_kind)
    print(features.grad.numpy())
print('\nRelevance input total', rel_in)

Relevance nodes out 4.0
[[ 1.00]
 [ 1.00]
 [ 1.00]
 [ 1.00]]

Relevance nodes in 2.092041015625
[[ 0.88  0.17 -0.16  0.15 -0.55]
 [ 0.35  0.56 -0.16  0.22 -0.52]
 [ 0.02  0.38 -0.28  0.44 -0.16]
 [ 0.43  0.31 -0.10  0.15 -0.03]]

Relevance globals in 1.907958984375
[[ 0.87  2.23 -0.71 -1.25 -0.52  1.74 -0.45]]

Relevance input total 4.0


In [20]:
# Node + incoming (sum) + global terms.
net = NodeLinearRelevance(
    1,
    node_features=5,
    incoming_features=3,
    outgoing_features=None,
    global_features=7,
    bias=False,
    aggregation='sum',
)
out = net(graphs)
print(computational_graph(out.node_features.grad_fn))

AddRelevanceBackward at 0x55aaedde4668:
- AddRelevanceBackward at 0x55aaeddd5b78:
 - AddRelevanceBackward at 0x55aaecb330f8:
  - None
  - LinearEpsilonRelevanceBackward at 0x55aaeddd43d8:
   - AccumulateGrad at 0x7fd02d4befd0:variable at 0x7fd02d494ab0
   - AccumulateGrad at 0x7fd02d4c9470:variable at 0x7fd02d4b3438
 - LinearEpsilonRelevanceBackward at 0x55aaeddd41c8:
  - ScatterAddRelevanceBackward at 0x55aaedde0918:
   - AccumulateGrad at 0x7fd02d4c9470:variable at 0x7fd02d494af8
   - None
   - None
  - AccumulateGrad at 0x7fd02d4befd0:variable at 0x7fd02d4b3ee8
- IndexSelectBackward at 0x7fd02d4bef98:
 - LinearEpsilonRelevanceBackward at 0x55aaedddeaf8:
  - AccumulateGrad at 0x7fd02d4befd0:variable at 0x7fd02d494b40
  - AccumulateGrad at 0x7fd02d4be978:variable at 0x7fd02d4b3168


In [21]:
out = net(graphs)
# Unit relevance on every non-zero output node feature.
rel_out = torch.ones_like(out.node_features) * (out.node_features != 0).float()
print('Relevance nodes out', rel_out.sum().item())
print(rel_out.numpy())

# Backpropagate relevance: each input's .grad then holds its share.
graphs.zero_grad_()
out.node_features.backward(rel_out)

# Report and total the relevance of every input that received any,
# computing each sum once instead of repeating three near-identical branches.
rel_in = 0
for kind, features in (('edges', graphs.edge_features),
                       ('nodes', graphs.node_features),
                       ('globals', graphs.global_features)):
    if features.grad is None:
        continue
    rel_kind = features.grad.sum().item()
    rel_in += rel_kind
    print(f'\nRelevance {kind} in', rel_kind)
    print(features.grad.numpy())
print('\nRelevance input total', rel_in)

Relevance nodes out 4.0
[[ 1.00]
 [ 1.00]
 [ 1.00]
 [ 1.00]]

Relevance edges in -0.464652419090271
[[-0.29  0.01 -0.07]
 [-0.15  0.13 -0.24]
 [-0.01  0.21 -0.07]]

Relevance nodes in 2.0217602252960205
[[ 0.10  0.10 -0.07  0.19  0.10]
 [ 0.04  0.31 -0.07  0.26  0.09]
 [ 0.00  0.15 -0.09  0.37  0.02]
 [ 0.06  0.25 -0.06  0.27  0.01]]

Relevance globals in 2.442891836166382
[[ 0.38  0.02  0.16  1.07  0.35  0.45  0.02]]

Relevance input total 3.9999996423721313


In [22]:
# Node + incoming + outgoing (sum) + global terms.
net = NodeLinearRelevance(
    1,
    node_features=5,
    incoming_features=3,
    outgoing_features=3,
    global_features=7,
    bias=False,
    aggregation='sum',
)
out = net(graphs)
print(computational_graph(out.node_features.grad_fn))

AddRelevanceBackward at 0x55aaecf50958:
- AddRelevanceBackward at 0x55aaedde1538:
 - AddRelevanceBackward at 0x55aaecb330f8:
  - AddRelevanceBackward at 0x55aaeddd5b78:
   - None
   - LinearEpsilonRelevanceBackward at 0x55aaedddeaf8:
    - AccumulateGrad at 0x7fd02d4c9908:variable at 0x7fd02d494ab0
    - AccumulateGrad at 0x7fd02d4c9f98:variable at 0x7fd02d4bfe10
  - LinearEpsilonRelevanceBackward at 0x55aaedde0918:
   - ScatterAddRelevanceBackward at 0x55aaeddd41c8:
    - AccumulateGrad at 0x7fd02d4c9f60:variable at 0x7fd02d494af8
    - None
    - None
   - AccumulateGrad at 0x7fd02d4c9f98:variable at 0x7fd02d4bf900
 - LinearEpsilonRelevanceBackward at 0x55aaeddd09d8:
  - ScatterAddRelevanceBackward at 0x55aaeddd43d8:
   - AccumulateGrad at 0x7fd02d4c9748:variable at 0x7fd02d494af8
   - None
   - None
  - AccumulateGrad at 0x7fd02d4c9be0:variable at 0x7fd02d4bf828
- IndexSelectBackward at 0x7fd02d763e48:
 - LinearEpsilonRelevanceBackward at 0x55aaeddeec38:
  - AccumulateGrad at 0x7fd0

In [23]:
out = net(graphs)
# Unit relevance on every non-zero output node feature.
rel_out = torch.ones_like(out.node_features) * (out.node_features != 0).float()
print('Relevance nodes out', rel_out.sum().item())
print(rel_out.numpy())

# Backpropagate relevance: each input's .grad then holds its share.
graphs.zero_grad_()
out.node_features.backward(rel_out)

# Report and total the relevance of every input that received any,
# computing each sum once instead of repeating three near-identical branches.
rel_in = 0
for kind, features in (('edges', graphs.edge_features),
                       ('nodes', graphs.node_features),
                       ('globals', graphs.global_features)):
    if features.grad is None:
        continue
    rel_kind = features.grad.sum().item()
    rel_in += rel_kind
    print(f'\nRelevance {kind} in', rel_kind)
    print(features.grad.numpy())
print('\nRelevance input total', rel_in)

Relevance nodes out 4.0
[[ 1.00]
 [ 1.00]
 [ 1.00]
 [ 1.00]]

Relevance edges in 1.6231404542922974
[[ 1.29 -0.08 -0.23]
 [-1.31 -0.04  1.53]
 [-0.04  0.32  0.17]]

Relevance nodes in 0.1613001823425293
[[-0.28 -0.01  0.02 -0.03  0.29]
 [ 0.37  0.07 -0.07  0.13 -0.88]
 [ 0.03  0.08 -0.23  0.46 -0.46]
 [ 0.62  0.05 -0.07  0.13 -0.07]]

Relevance globals in 2.2155590057373047
[[-0.98  1.20  1.24 -0.75 -1.34  2.58  0.27]]

Relevance input total 3.9999996423721313


In [24]:
# Node + incoming + outgoing (max) + global terms.
net = NodeLinearRelevance(
    1,
    node_features=5,
    incoming_features=3,
    outgoing_features=3,
    global_features=7,
    bias=False,
    aggregation='max',
)
out = net(graphs)
print(computational_graph(out.node_features.grad_fn))

AddRelevanceBackward at 0x55aaeddf2498:
- AddRelevanceBackward at 0x55aaecf50958:
 - AddRelevanceBackward at 0x55aaeddd5b78:
  - AddRelevanceBackward at 0x55aaecb330f8:
   - None
   - LinearEpsilonRelevanceBackward at 0x55aaeddd43d8:
    - AccumulateGrad at 0x7fd0a0351dd8:variable at 0x7fd02d494ab0
    - AccumulateGrad at 0x7fd0a0351d30:variable at 0x7fd02d4bfaf8
  - LinearEpsilonRelevanceBackward at 0x55aaeddd41c8:
   - ScatterMaxRelevanceBackward at 0x55aaedde0918:
    - AccumulateGrad at 0x7fd0a0351240:variable at 0x7fd02d494af8
    - None
   - AccumulateGrad at 0x7fd0a0351d30:variable at 0x7fd02d4bf948
 - LinearEpsilonRelevanceBackward at 0x55aaedddafb8:
  - ScatterMaxRelevanceBackward at 0x55aaedddeaf8:
   - AccumulateGrad at 0x7fd0a0351ba8:variable at 0x7fd02d494af8
   - None
  - AccumulateGrad at 0x7fd0a0351d68:variable at 0x7fd02d4bf750
- IndexSelectBackward at 0x7fd0a0351128:
 - LinearEpsilonRelevanceBackward at 0x55aaeddeec38:
  - AccumulateGrad at 0x7fd0a0351d68:variable at 

In [25]:
out = net(graphs)
# Unit relevance on every non-zero output node feature.
rel_out = torch.ones_like(out.node_features) * (out.node_features != 0).float()
print('Relevance nodes out', rel_out.sum().item())
print(rel_out.numpy())

# Backpropagate relevance: each input's .grad then holds its share.
graphs.zero_grad_()
out.node_features.backward(rel_out)

# Report and total the relevance of every input that received any,
# computing each sum once instead of repeating three near-identical branches.
rel_in = 0
for kind, features in (('edges', graphs.edge_features),
                       ('nodes', graphs.node_features),
                       ('globals', graphs.global_features)):
    if features.grad is None:
        continue
    rel_kind = features.grad.sum().item()
    rel_in += rel_kind
    print(f'\nRelevance {kind} in', rel_kind)
    print(features.grad.numpy())
print('\nRelevance input total', rel_in)

Relevance nodes out 4.0
[[ 1.00]
 [ 1.00]
 [ 1.00]
 [ 1.00]]

Relevance edges in 3.8288698196411133
[[ 0.70  0.08  0.73]
 [-0.31 -0.49  2.96]
 [-0.06 -1.18  1.40]]

Relevance nodes in -4.182224750518799
[[ 0.25 -0.05  0.07 -0.10 -0.15]
 [ 0.54 -0.88  0.35 -0.77 -0.72]
 [ 0.03 -0.64  0.66 -1.59 -0.23]
 [ 3.23 -2.43  1.12 -2.66 -0.21]]

Relevance globals in 4.3533549308776855
[[-0.87  1.64 -1.29  2.38  2.90  0.19 -0.60]]

Relevance input total 4.0


In [26]:
# Node + incoming + outgoing (mean) + global terms.
net = NodeLinearRelevance(
    1,
    node_features=5,
    incoming_features=3,
    outgoing_features=3,
    global_features=7,
    bias=False,
    aggregation='mean',
)
out = net(graphs)
print(computational_graph(out.node_features.grad_fn))

AddRelevanceBackward at 0x55aaecf50958:
- AddRelevanceBackward at 0x55aaeddd43d8:
 - AddRelevanceBackward at 0x55aaeddd41c8:
  - AddRelevanceBackward at 0x55aaedddeaf8:
   - None
   - LinearEpsilonRelevanceBackward at 0x55aaecf28ca8:
    - AccumulateGrad at 0x7fd02d4c9b70:variable at 0x7fd02d494ab0
    - AccumulateGrad at 0x7fd02d4c9f60:variable at 0x7fd0a02f0a20
  - LinearEpsilonRelevanceBackward at 0x55aaedb762e8:
   - ScatterMeanRelevanceBackward at 0x55aaeddd5b78:
    - AccumulateGrad at 0x7fd02d4c9be0:variable at 0x7fd02d494af8
    - None
   - AccumulateGrad at 0x7fd02d4c9f60:variable at 0x7fd0a02f0168
 - LinearEpsilonRelevanceBackward at 0x55aaecb330f8:
  - ScatterMeanRelevanceBackward at 0x55aaedde0918:
   - AccumulateGrad at 0x7fd02d4c9748:variable at 0x7fd02d494af8
   - None
  - AccumulateGrad at 0x7fd02d4c9400:variable at 0x7fd0a02f00d8
- IndexSelectBackward at 0x7fd0a0351cf8:
 - LinearEpsilonRelevanceBackward at 0x55aaeddd9658:
  - AccumulateGrad at 0x7fd02d4c9400:variable a

In [27]:
out = net(graphs)
# Unit relevance on every non-zero output node feature.
rel_out = torch.ones_like(out.node_features) * (out.node_features != 0).float()
print('Relevance nodes out', rel_out.sum().item())
print(rel_out.numpy())

# Backpropagate relevance: each input's .grad then holds its share.
graphs.zero_grad_()
out.node_features.backward(rel_out)

# Report and total the relevance of every input that received any,
# computing each sum once instead of repeating three near-identical branches.
rel_in = 0
for kind, features in (('edges', graphs.edge_features),
                       ('nodes', graphs.node_features),
                       ('globals', graphs.global_features)):
    if features.grad is None:
        continue
    rel_kind = features.grad.sum().item()
    rel_in += rel_kind
    print(f'\nRelevance {kind} in', rel_kind)
    print(features.grad.numpy())
print('\nRelevance input total', rel_in)

Relevance nodes out 4.0
[[ 1.00]
 [ 1.00]
 [ 1.00]
 [ 1.00]]

Relevance edges in 1.6806154251098633
[[-0.39  0.04 -0.14]
 [ 0.10 -0.11  1.55]
 [-0.05  2.27 -1.58]]

Relevance nodes in -2.9802086353302
[[-0.09  0.04 -0.27 -0.25  0.47]
 [-0.04  0.14 -0.27 -0.36  0.45]
 [ 0.00 -0.08  0.39  0.58 -0.11]
 [-0.43  0.74 -1.69 -2.47  0.26]]

Relevance globals in 5.299594402313232
[[ 1.14  2.65  0.85  0.72 -0.91  1.48 -0.63]]

Relevance input total 4.0000011920928955


In [28]:
# All terms with mean aggregation, plus bias.
net = NodeLinearRelevance(
    1,
    node_features=5,
    incoming_features=3,
    outgoing_features=3,
    global_features=7,
    bias=True,
    aggregation='mean',
)
out = net(graphs)
print(computational_graph(out.node_features.grad_fn))

AddRelevanceBackward at 0x55aaecf50958:
- AddRelevanceBackward at 0x55aaeddd0168:
 - AddRelevanceBackward at 0x55aaeddd9658:
  - AddRelevanceBackward at 0x55aaedddeaf8:
   - AddRelevanceBackward at 0x55aaeddd41c8:
    - None
    - LinearEpsilonRelevanceBackward at 0x55aaedde0918:
     - AccumulateGrad at 0x7fd02d4c07b8:variable at 0x7fd02d494ab0
     - AccumulateGrad at 0x7fd02d4c0f98:variable at 0x7fd02d4a7360
   - LinearEpsilonRelevanceBackward at 0x55aaeddd5b78:
    - ScatterMeanRelevanceBackward at 0x55aaedb762e8:
     - AccumulateGrad at 0x7fd02d4c0c18:variable at 0x7fd02d494af8
     - None
    - AccumulateGrad at 0x7fd02d4c0f98:variable at 0x7fd02d4a7d80
  - LinearEpsilonRelevanceBackward at 0x55aaeddd43d8:
   - ScatterMeanRelevanceBackward at 0x55aaecf28ca8:
    - AccumulateGrad at 0x7fd02d4c0dd8:variable at 0x7fd02d494af8
    - None
   - AccumulateGrad at 0x7fd02d4c0940:variable at 0x7fd02d4a7a20
 - IndexSelectBackward at 0x7fd02d4be2b0:
  - LinearEpsilonRelevanceBackward at 0x

In [29]:
out = net(graphs)
# Unit relevance on every non-zero output node feature.
rel_out = torch.ones_like(out.node_features) * (out.node_features != 0).float()
print('Relevance nodes out', rel_out.sum().item())
print(rel_out.numpy())

# Backpropagate relevance: each input's .grad then holds its share.
graphs.zero_grad_()
out.node_features.backward(rel_out)

# Report and total the relevance of every input that received any,
# computing each sum once instead of repeating three near-identical branches.
# NOTE: with bias=True the input total is not expected to match the output total.
rel_in = 0
for kind, features in (('edges', graphs.edge_features),
                       ('nodes', graphs.node_features),
                       ('globals', graphs.global_features)):
    if features.grad is None:
        continue
    rel_kind = features.grad.sum().item()
    rel_in += rel_kind
    print(f'\nRelevance {kind} in', rel_kind)
    print(features.grad.numpy())
print('\nRelevance input total', rel_in)

Relevance nodes out 4.0
[[ 1.00]
 [ 1.00]
 [ 1.00]
 [ 1.00]]

Relevance edges in 3.1812233924865723
[[-0.50  0.12 -0.18]
 [ 0.84  0.05  2.23]
 [ 0.07 -0.09  0.64]]

Relevance nodes in 6.130515098571777
[[-1.97  0.41 -0.28  1.93  2.10]
 [-0.10  0.17 -0.03  0.35  0.25]
 [-0.01  0.31 -0.16  1.83  0.20]
 [-0.66  0.51 -0.12  1.33  0.08]]

Relevance globals in -0.9691047668457031
[[-1.95 -1.08  1.71 -5.48  3.85  2.07 -0.10]]

Relevance input total 8.342633724212646


## GlobalLinear

In [30]:
class GlobalLinearRelevance(tg.GlobalLinear):
    """``tg.GlobalLinear`` variant whose forward pass is built from ``relevance`` ops.

    Every enabled input branch (per-graph aggregated node features, per-graph
    aggregated edge features, global features) is projected with
    ``lrp.linear_eps``. The branch outputs and the optional bias are accumulated
    with ``lrp.add``, so that calling ``backward`` on the resulting global
    features runs the relevance-aware backward functions.
    """

    def __init__(self, out_features, node_features=None, edge_features=None, global_features=None,
                 aggregation=None, bias=True):
        # Translate the aggregation name into the matching relevance-aware
        # aggregation before the base class stores it as ``self.aggregation``.
        relevance_aggregation = lrp.get_aggregation(aggregation)
        super().__init__(out_features, node_features, edge_features,
                         global_features, relevance_aggregation, bias)

    def _aggregate_and_project(self, features, segment_lengths, num_graphs, weight):
        """Pool ``features`` into one row per graph with ``self.aggregation``, then apply ``lrp.linear_eps``."""
        segment_ids = tg.utils.segment_lengths_to_ids(segment_lengths)
        pooled = self.aggregation(features, dim=0, index=segment_ids, dim_size=num_graphs)
        return lrp.linear_eps(pooled, weight)

    def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch:
        # Branches are added in a fixed order (nodes, edges, globals, bias);
        # the nesting of the resulting lrp.add calls follows that order.
        result = torch.tensor(0)

        if self.W_node is not None:
            result = lrp.add(result, self._aggregate_and_project(
                graphs.node_features, graphs.num_nodes_by_graph, graphs.num_graphs, self.W_node))

        if self.W_edges is not None:
            result = lrp.add(result, self._aggregate_and_project(
                graphs.edge_features, graphs.num_edges_by_graph, graphs.num_graphs, self.W_edges))

        if self.W_global is not None:
            result = lrp.add(result, lrp.linear_eps(graphs.global_features, self.W_global))

        if self.bias is not None:
            result = lrp.add(result, self.bias)

        return graphs.evolve(global_features=result)

In [31]:
# Globals-only layer: no node or edge inputs, no bias.
net = GlobalLinearRelevance(
    1,
    node_features=None,
    edge_features=None,
    global_features=7,
    bias=False,
)
out = net(graphs)
print(computational_graph(out.global_features.grad_fn))

AddRelevanceBackward at 0x55aaedb762e8:
- None
- LinearEpsilonRelevanceBackward at 0x55aaeddd5b78:
 - AccumulateGrad at 0x7fd02d4be278:variable at 0x7fd02d494b40
 - AccumulateGrad at 0x7fd02d4bed30:variable at 0x7fd0a02f0480


In [32]:
out = net(graphs)
# Seed one unit of relevance on every non-zero output entry (zero elsewhere).
rel_out = (out.global_features != 0).float() * torch.ones_like(out.global_features)
print('Relevance global out', rel_out.sum().item())
print(rel_out.numpy())

# Push the output relevance back through the relevance-aware autograd graph;
# afterwards each input's .grad holds the relevance attributed to it.
graphs.zero_grad_()
out.global_features.backward(rel_out)
rel_in = 0
for feature_kind, feature_tensor in (
    ('edges', graphs.edge_features),
    ('nodes', graphs.node_features),
    ('globals', graphs.global_features),
):
    if feature_tensor.grad is None:
        continue
    kind_relevance = feature_tensor.grad.sum().item()
    rel_in += kind_relevance
    print(f'\nRelevance {feature_kind} in', kind_relevance)
    print(feature_tensor.grad.numpy())
print('\nRelevance input total', rel_in)

Relevance global out 1.0
[[ 1.00]]

Relevance globals in 0.9999998211860657
[[-0.25 -0.39 -0.06 -1.75  1.55  2.06 -0.16]]

Relevance input total 0.9999998211860657


In [33]:
# Node features summed per graph plus global features; no edges, no bias.
net = GlobalLinearRelevance(
    1,
    node_features=5,
    edge_features=None,
    global_features=7,
    bias=False,
    aggregation='sum',
)
out = net(graphs)
print(computational_graph(out.global_features.grad_fn))

AddRelevanceBackward at 0x55aaeddd9658:
- AddRelevanceBackward at 0x55aaeddd41c8:
 - None
 - LinearEpsilonRelevanceBackward at 0x55aaedde3948:
  - ScatterAddRelevanceBackward at 0x55aaedddde08:
   - AccumulateGrad at 0x7fd02d4c9a58:variable at 0x7fd02d494ab0
   - None
   - None
  - AccumulateGrad at 0x7fd02d4c9d68:variable at 0x7fd02d4bf900
- LinearEpsilonRelevanceBackward at 0x55aaedde0918:
 - AccumulateGrad at 0x7fd02d4c9d68:variable at 0x7fd02d494b40
 - AccumulateGrad at 0x7fd02d4c9390:variable at 0x7fd02d4c7558


In [34]:
out = net(graphs)
# Seed one unit of relevance on every non-zero output entry (zero elsewhere).
rel_out = (out.global_features != 0).float() * torch.ones_like(out.global_features)
print('Relevance global out', rel_out.sum().item())
print(rel_out.numpy())

# Push the output relevance back through the relevance-aware autograd graph;
# afterwards each input's .grad holds the relevance attributed to it.
graphs.zero_grad_()
out.global_features.backward(rel_out)
rel_in = 0
for feature_kind, feature_tensor in (
    ('edges', graphs.edge_features),
    ('nodes', graphs.node_features),
    ('globals', graphs.global_features),
):
    if feature_tensor.grad is None:
        continue
    kind_relevance = feature_tensor.grad.sum().item()
    rel_in += kind_relevance
    print(f'\nRelevance {feature_kind} in', kind_relevance)
    print(feature_tensor.grad.numpy())
print('\nRelevance input total', rel_in)

Relevance global out 1.0
[[ 1.00]]

Relevance nodes in 0.7885957956314087
[[ 0.21 -0.03  0.07  0.04  0.01]
 [ 0.08 -0.09  0.06  0.05  0.00]
 [ 0.00 -0.05  0.10  0.09  0.00]
 [ 0.21 -0.11  0.09  0.07  0.00]]

Relevance globals in 0.21140411496162415
[[-0.01  0.06  0.04 -0.05  0.04  0.11  0.02]]

Relevance input total 0.9999999105930328


In [35]:
# All three inputs: nodes and edges summed per graph, plus globals; no bias.
net = GlobalLinearRelevance(
    1,
    node_features=5,
    edge_features=3,
    global_features=7,
    bias=False,
    aggregation='sum',
)
out = net(graphs)
print(computational_graph(out.global_features.grad_fn))

AddRelevanceBackward at 0x55aaeddfce78:
- AddRelevanceBackward at 0x55aaedddde08:
 - AddRelevanceBackward at 0x55aaedde0918:
  - None
  - LinearEpsilonRelevanceBackward at 0x55aaeddd9658:
   - ScatterAddRelevanceBackward at 0x55aaedd1a3a8:
    - AccumulateGrad at 0x7fd02c45ae48:variable at 0x7fd02d494ab0
    - None
    - None
   - AccumulateGrad at 0x7fd02d4c9ac8:variable at 0x7fd02c462c18
 - LinearEpsilonRelevanceBackward at 0x55aaedde3948:
  - ScatterAddRelevanceBackward at 0x55aaeddd41c8:
   - AccumulateGrad at 0x7fd02c45ae80:variable at 0x7fd02d494af8
   - None
   - None
  - AccumulateGrad at 0x7fd02d4c9ac8:variable at 0x7fd02c462a68
- LinearEpsilonRelevanceBackward at 0x55aaeddf4d58:
 - AccumulateGrad at 0x7fd02d4c9278:variable at 0x7fd02d494b40
 - AccumulateGrad at 0x7fd02d4c9ac8:variable at 0x7fd02c4622d0


In [36]:
out = net(graphs)
# Seed one unit of relevance on every non-zero output entry (zero elsewhere).
rel_out = (out.global_features != 0).float() * torch.ones_like(out.global_features)
print('Relevance global out', rel_out.sum().item())
print(rel_out.numpy())

# Push the output relevance back through the relevance-aware autograd graph;
# afterwards each input's .grad holds the relevance attributed to it.
graphs.zero_grad_()
out.global_features.backward(rel_out)
rel_in = 0
for feature_kind, feature_tensor in (
    ('edges', graphs.edge_features),
    ('nodes', graphs.node_features),
    ('globals', graphs.global_features),
):
    if feature_tensor.grad is None:
        continue
    kind_relevance = feature_tensor.grad.sum().item()
    rel_in += kind_relevance
    print(f'\nRelevance {feature_kind} in', kind_relevance)
    print(feature_tensor.grad.numpy())
print('\nRelevance input total', rel_in)

Relevance global out 1.0
[[ 1.00]]

Relevance edges in 1.113856554031372
[[-0.19  0.01  0.16]
 [-0.10  0.20  0.58]
 [-0.01  0.31  0.16]]

Relevance nodes in 0.24859696626663208
[[ 0.02  0.01 -0.14  0.17 -0.18]
 [ 0.01  0.02 -0.12  0.23 -0.16]
 [ 0.00  0.01 -0.20  0.42 -0.05]
 [ 0.02  0.02 -0.18  0.36 -0.02]]

Relevance globals in -0.36245331168174744
[[-0.04  0.28 -0.07 -0.12 -0.26 -0.21  0.06]]

Relevance input total 1.0000002086162567


In [37]:
# All three inputs, with nodes and edges max-pooled per graph; no bias.
net = GlobalLinearRelevance(
    1,
    node_features=5,
    edge_features=3,
    global_features=7,
    bias=False,
    aggregation='max',
)
out = net(graphs)
print(computational_graph(out.global_features.grad_fn))

AddRelevanceBackward at 0x55aaecf28ca8:
- AddRelevanceBackward at 0x55aaeddd9658:
 - AddRelevanceBackward at 0x55aaedde3948:
  - None
  - LinearEpsilonRelevanceBackward at 0x55aaedddde08:
   - ScatterMaxRelevanceBackward at 0x55aaeddf4d58:
    - AccumulateGrad at 0x7fd0a03519e8:variable at 0x7fd02d494ab0
    - None
   - AccumulateGrad at 0x7fd0a0351630:variable at 0x7fd02d4c7e10
 - LinearEpsilonRelevanceBackward at 0x55aaedde0918:
  - ScatterMaxRelevanceBackward at 0x55aaeddd41c8:
   - AccumulateGrad at 0x7fd0a0351d30:variable at 0x7fd02d494af8
   - None
  - AccumulateGrad at 0x7fd0a0351630:variable at 0x7fd02d4c7c60
- LinearEpsilonRelevanceBackward at 0x55aaedd1a3a8:
 - AccumulateGrad at 0x7fd0a0351940:variable at 0x7fd02d494b40
 - AccumulateGrad at 0x7fd0a0351630:variable at 0x7fd02d4c7d80


In [38]:
out = net(graphs)
# Seed one unit of relevance on every non-zero output entry (zero elsewhere).
rel_out = (out.global_features != 0).float() * torch.ones_like(out.global_features)
print('Relevance global out', rel_out.sum().item())
print(rel_out.numpy())

# Push the output relevance back through the relevance-aware autograd graph;
# afterwards each input's .grad holds the relevance attributed to it.
graphs.zero_grad_()
out.global_features.backward(rel_out)
rel_in = 0
for feature_kind, feature_tensor in (
    ('edges', graphs.edge_features),
    ('nodes', graphs.node_features),
    ('globals', graphs.global_features),
):
    if feature_tensor.grad is None:
        continue
    kind_relevance = feature_tensor.grad.sum().item()
    rel_in += kind_relevance
    print(f'\nRelevance {feature_kind} in', kind_relevance)
    print(feature_tensor.grad.numpy())
print('\nRelevance input total', rel_in)

Relevance global out 1.0
[[ 1.00]]

Relevance edges in 0.6134206652641296
[[ 0.03  0.00  0.00]
 [ 0.00  0.00  0.31]
 [ 0.00  0.27  0.00]]

Relevance nodes in 0.368700236082077
[[ 0.00  0.00  0.00  0.00  0.10]
 [ 0.00  0.00  0.00  0.00  0.00]
 [ 0.00  0.00  0.08  0.17  0.00]
 [ 0.02 -0.01  0.00  0.00  0.00]]

Relevance globals in 0.017879139631986618
[[-0.00  0.08  0.02 -0.13 -0.08  0.11  0.02]]

Relevance input total 1.0000000409781933


In [39]:
# All three inputs, with nodes and edges mean-pooled per graph; no bias.
net = GlobalLinearRelevance(
    1,
    node_features=5,
    edge_features=3,
    global_features=7,
    bias=False,
    aggregation='mean',
)
out = net(graphs)
print(computational_graph(out.global_features.grad_fn))

AddRelevanceBackward at 0x55aaeddcad18:
- AddRelevanceBackward at 0x55aaeddf4d58:
 - AddRelevanceBackward at 0x55aaeddd41c8:
  - None
  - LinearEpsilonRelevanceBackward at 0x55aaedde0918:
   - ScatterMeanRelevanceBackward at 0x55aaeddd9658:
    - AccumulateGrad at 0x7fd02d4c90f0:variable at 0x7fd02d494ab0
    - None
   - AccumulateGrad at 0x7fd02d4c9828:variable at 0x7fd02d4a6b40
 - LinearEpsilonRelevanceBackward at 0x55aaedddde08:
  - ScatterMeanRelevanceBackward at 0x55aaedde3948:
   - AccumulateGrad at 0x7fd02d4c9cc0:variable at 0x7fd02d494af8
   - None
  - AccumulateGrad at 0x7fd02d4c9828:variable at 0x7fd02d4a6a68
- LinearEpsilonRelevanceBackward at 0x55aaecf28ca8:
 - AccumulateGrad at 0x7fd02d4c9d30:variable at 0x7fd02d494b40
 - AccumulateGrad at 0x7fd02d4c9828:variable at 0x7fd02d4bf630


In [40]:
out = net(graphs)
# Seed one unit of relevance on every non-zero output entry (zero elsewhere).
rel_out = (out.global_features != 0).float() * torch.ones_like(out.global_features)
print('Relevance global out', rel_out.sum().item())
print(rel_out.numpy())

# Push the output relevance back through the relevance-aware autograd graph;
# afterwards each input's .grad holds the relevance attributed to it.
graphs.zero_grad_()
out.global_features.backward(rel_out)
rel_in = 0
for feature_kind, feature_tensor in (
    ('edges', graphs.edge_features),
    ('nodes', graphs.node_features),
    ('globals', graphs.global_features),
):
    if feature_tensor.grad is None:
        continue
    kind_relevance = feature_tensor.grad.sum().item()
    rel_in += kind_relevance
    print(f'\nRelevance {feature_kind} in', kind_relevance)
    print(feature_tensor.grad.numpy())
print('\nRelevance input total', rel_in)

Relevance global out 1.0
[[ 1.00]]

Relevance edges in 0.6989871263504028
[[ 0.29 -0.02  0.15]
 [ 0.15 -0.22  0.52]
 [ 0.01 -0.34  0.15]]

Relevance nodes in -0.8370841145515442
[[-0.17  0.06 -0.03 -0.07 -0.18]
 [-0.06  0.18 -0.03 -0.09 -0.16]
 [-0.00  0.12 -0.05 -0.16 -0.04]
 [-0.17  0.23 -0.04 -0.14 -0.02]]

Relevance globals in 1.1380969285964966
[[-0.18  0.53  0.28  0.14  0.35 -0.07  0.09]]

Relevance input total 0.9999999403953552


In [41]:
# All three inputs summed per graph, now WITH a bias term. The bias is itself
# an input of the final lrp.add, so part of the output relevance may be
# attributed to it rather than to the graph features.
net = GlobalLinearRelevance(
    1,
    node_features=5,
    edge_features=3,
    global_features=7,
    bias=True,
    aggregation='sum',
)
out = net(graphs)
print(computational_graph(out.global_features.grad_fn))

AddRelevanceBackward at 0x55aaeddf6c58:
- AddRelevanceBackward at 0x55aaeddd2848:
 - AddRelevanceBackward at 0x55aaeddcad18:
  - AddRelevanceBackward at 0x55aaeddd41c8:
   - None
   - LinearEpsilonRelevanceBackward at 0x55aaedde3948:
    - ScatterAddRelevanceBackward at 0x55aaedddde08:
     - AccumulateGrad at 0x7fd02c47a320:variable at 0x7fd02d494ab0
     - None
     - None
    - AccumulateGrad at 0x7fd02c47a0f0:variable at 0x7fd02d4b3af8
  - LinearEpsilonRelevanceBackward at 0x55aaeddd9658:
   - ScatterAddRelevanceBackward at 0x55aaedde0918:
    - AccumulateGrad at 0x7fd02c47a2e8:variable at 0x7fd02d494af8
    - None
    - None
   - AccumulateGrad at 0x7fd02c47a0f0:variable at 0x7fd02d4b3a20
 - LinearEpsilonRelevanceBackward at 0x55aaeddf4b28:
  - AccumulateGrad at 0x7fd02c47a278:variable at 0x7fd02d494b40
  - AccumulateGrad at 0x7fd02c47a0f0:variable at 0x7fd02d4b3ea0
- AccumulateGrad at 0x7fd02c47a0b8:variable at 0x7fd02d4b3120


In [42]:
out = net(graphs)
# Seed one unit of relevance on every non-zero output entry (zero elsewhere).
rel_out = (out.global_features != 0).float() * torch.ones_like(out.global_features)
print('Relevance global out', rel_out.sum().item())
print(rel_out.numpy())

# Push the output relevance back through the relevance-aware autograd graph;
# afterwards each input's .grad holds the relevance attributed to it.
graphs.zero_grad_()
out.global_features.backward(rel_out)
rel_in = 0
for feature_kind, feature_tensor in (
    ('edges', graphs.edge_features),
    ('nodes', graphs.node_features),
    ('globals', graphs.global_features),
):
    if feature_tensor.grad is None:
        continue
    kind_relevance = feature_tensor.grad.sum().item()
    rel_in += kind_relevance
    print(f'\nRelevance {feature_kind} in', kind_relevance)
    print(feature_tensor.grad.numpy())
print('\nRelevance input total', rel_in)

Relevance global out 1.0
[[ 1.00]]

Relevance edges in 0.6768062114715576
[[ 0.32  0.00  0.01]
 [ 0.17  0.04  0.05]
 [ 0.01  0.06  0.01]]

Relevance nodes in 0.10417277365922928
[[ 0.28 -0.05 -0.01 -0.11  0.32]
 [ 0.11 -0.15 -0.01 -0.15  0.28]
 [ 0.00 -0.10 -0.02 -0.27  0.08]
 [ 0.28 -0.19 -0.01 -0.23  0.04]]

Relevance globals in -0.27423951029777527
[[-0.03  0.04 -0.02 -0.05  0.04 -0.21 -0.04]]

Relevance input total 0.5067394748330116


## Sequential

In [43]:
class NodeReLURelevance(tg.NodeFunction):
    """``tg.NodeFunction`` wrapping ``lrp.relu`` (presumably applied to ``node_features``)."""

    def __init__(self):
        # Zero-argument super() runs tg.NodeFunction's own constructor.
        # Writing super(tg.NodeFunction, self) would start the MRO lookup *after*
        # NodeFunction and silently skip its __init__.
        super().__init__(lrp.relu)


class EdgeReLURelevance(tg.EdgeFunction):
    """``tg.EdgeFunction`` wrapping ``lrp.relu`` (presumably applied to ``edge_features``)."""

    def __init__(self):
        # Zero-argument super() runs tg.EdgeFunction's own constructor.
        # Writing super(tg.EdgeFunction, self) would start the MRO lookup *after*
        # EdgeFunction and silently skip its __init__.
        super().__init__(lrp.relu)


class GlobalReLURelevance(tg.GlobalFunction):
    """``tg.GlobalFunction`` wrapping ``lrp.relu`` (presumably applied to ``global_features``)."""

    def __init__(self):
        # Zero-argument super() runs tg.GlobalFunction's own constructor.
        # Writing super(tg.GlobalFunction, self) would start the MRO lookup *after*
        # GlobalFunction and silently skip its __init__.
        super().__init__(lrp.relu)

class ComplexGN(nn.Module):
    """Encoder / shared-hidden / decoder graph network built from relevance-aware layers.

    - The encoder maps edges (3 features), nodes (5) and globals (7) to 8 features
      each, with a relevance-aware ReLU after every linear layer.
    - The hidden block is applied ``hidden_steps`` times. It is the *same* module
      on every step, so its weights are shared across steps.
    - The decoder maps edges, nodes and globals to a single output feature each.

    Args:
        hidden_steps: number of applications of the shared hidden block
            (default 3, the value this network was originally written with).

    Raises:
        ValueError: if ``hidden_steps`` is negative.
    """

    def __init__(self, hidden_steps: int = 3):
        super().__init__()
        if hidden_steps < 0:
            raise ValueError(f'hidden_steps must be non-negative, got {hidden_steps}')
        self.hidden_steps = hidden_steps
        # OrderedDict is built from a list of pairs so the layer order is explicit
        # and does not rely on dict-literal ordering.
        self.encoder = nn.Sequential(OrderedDict([
            ('edge', EdgeLinearRelevance(8, edge_features=3, bias=False)),
            ('edge_relu', EdgeReLURelevance()),
            ('node', NodeLinearRelevance(8, node_features=5, bias=False)),
            ('node_relu', NodeReLURelevance()),
            ('global', GlobalLinearRelevance(8, global_features=7, bias=False)),
            ('global_relu', GlobalReLURelevance()),
        ]))
        self.hidden = nn.Sequential(OrderedDict([
            ('edge', EdgeLinearRelevance(8, edge_features=8, sender_features=8, global_features=8, bias=False)),
            ('edge_relu', EdgeReLURelevance()),
            ('node', NodeLinearRelevance(8, node_features=8, incoming_features=8, global_features=8,
                                         aggregation='max', bias=False)),
            ('node_relu', NodeReLURelevance()),
            ('global', GlobalLinearRelevance(8, node_features=8, edge_features=8, aggregation='sum', bias=False)),
            ('global_relu', GlobalReLURelevance()),
        ]))
        self.decoder = nn.Sequential(OrderedDict([
            ('edge', EdgeLinearRelevance(1, edge_features=8, bias=False)),
            ('node', NodeLinearRelevance(1, node_features=8, bias=False)),
            ('global', GlobalLinearRelevance(1, global_features=8, bias=False)),
        ]))

    def forward(self, graphs):
        graphs = self.encoder(graphs)
        for _ in range(self.hidden_steps):
            graphs = self.hidden(graphs)
        return self.decoder(graphs)


net = ComplexGN()

In [44]:
out = net(graphs)
# Seed one unit of relevance on every non-zero output entry (zero elsewhere).
rel_out = (out.global_features != 0).float() * torch.ones_like(out.global_features)
print('Relevance global out', rel_out.sum().item())
print(rel_out.numpy())

# Push the output relevance back through every layer of the network;
# afterwards each input's .grad holds the relevance attributed to it.
graphs.zero_grad_()
out.global_features.backward(rel_out)
rel_in = 0
for feature_kind, feature_tensor in (
    ('edges', graphs.edge_features),
    ('nodes', graphs.node_features),
    ('globals', graphs.global_features),
):
    if feature_tensor.grad is None:
        continue
    kind_relevance = feature_tensor.grad.sum().item()
    rel_in += kind_relevance
    print(f'\nRelevance {feature_kind} in', kind_relevance)
    print(feature_tensor.grad.numpy())
print('\nRelevance input total', rel_in)

Relevance global out 1.0
[[ 1.00]]

Relevance edges in 0.31765320897102356
[[-0.03 -0.03 -0.25]
 [-0.18 -0.13  1.26]
 [-0.01 -0.26 -0.06]]

Relevance nodes in 4.042593955993652
[[-1.56  0.52  0.34  0.59  0.12]
 [ 0.13 -1.00 -0.06 -1.26  0.07]
 [-0.03  0.89  1.23  0.38  0.25]
 [-0.30  2.48  0.31  1.00 -0.07]]

Relevance globals in -3.3602473735809326
[[ 0.33 -1.83  0.10  2.82 -1.53 -3.20 -0.05]]

Relevance input total 0.9999997913837433
