In [1]:
import torch
import torch_geometric
from custom_dataloder import MYDataLoader
from torch_geometric.datasets import ZINC, TUDataset

In [2]:
# ZINC_full molecules from the TU collection, stored/cached under ./datasets.
dataset = TUDataset(root='./datasets', name='ZINC_full')

In [3]:
# The first molecule of the dataset is the running example for every case below.
graph = dataset[0]

In [4]:
# Model hyper-parameters. The feature sizes are presumably ZINC_full's one-hot
# bond labels (3) and atom labels (28) — verify against
# dataset.num_edge_features / dataset.num_node_features.
edge_feature_size, node_feature_size, hid_size = 3, 28, 8

from models import GINConv

# Seed right before construction so the initial weights are reproducible.
torch.manual_seed(2022)
model = GINConv(edge_feature_size, node_feature_size, hid_size)

In [5]:
# Synthetic node mask: 1 keeps a node, 0 drops it. Size it from the graph
# instead of a hard-coded node count, so it stays valid for any graph.
mask = torch.zeros(graph.num_nodes, dtype=torch.float)

# select nodes 0 - 3
mask[:4] = 1

# Wrap as a Parameter: a leaf tensor with requires_grad=True, so backward()
# fills mask.grad.
mask = torch.nn.Parameter(mask)

In [6]:
# An edge is kept only if BOTH endpoints are kept: multiply the masks of its
# source and target nodes. For the synthetic mask (nodes 0-3) this selects
# edges (0, 1), (1, 2), (2, 3).
source_nodes, target_nodes = graph.edge_index
edge_mask = mask[source_nodes] * mask[target_nodes]

In [7]:
# Expect 1 for the first three edges ((0, 1), (1, 2), (2, 3)) and 0 elsewhere.
edge_mask

tensor([1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       grad_fn=<MulBackward0>)

In [8]:
# Row 0 = source node ids, row 1 = target node ids (PyG convention). In this
# output each bond appears once (source < target). Edge (3, 4) is the link from
# the selected nodes 0-3 to the unselected node 4.
graph.edge_index

tensor([[ 0,  1,  2,  3,  4,  5,  6,  6,  7,  8,  8,  9,  9, 11, 11, 12, 13, 13,
         14, 14, 16, 16, 17, 17, 18, 19, 20, 20, 22, 23, 25, 25, 28, 28, 30, 31],
        [ 1,  2,  3,  4,  5,  6,  7, 32,  8,  9, 31, 10, 11, 12, 28, 13, 14, 27,
         15, 16, 17, 25, 18, 23, 19, 20, 21, 22, 23, 24, 26, 27, 29, 30, 31, 32]])

# case 1

Mask the node features and edge attributes, but do not slice the graph: every node and edge stays in `edge_index`.

In [11]:
# Zero the features of unselected nodes; unsqueeze broadcasts the per-node mask over feature columns.
x = graph.x * mask.unsqueeze(-1)

In [12]:
# Zero the attributes of unselected edges; the edges themselves stay in edge_index.
edge_attr = graph.edge_attr * edge_mask.unsqueeze(-1)

In [13]:
# Forward on the FULL edge_index: masked nodes/edges have zeroed features but are still wired into message passing.
logits = model(x, graph.edge_index, edge_attr)

In [14]:
# Rows 0-4 are non-zero. Node 4 is masked out, yet a message still reaches it
# from node 3, because edge (3, 4) is still present in edge_index (only its
# attributes were zeroed).
logits

tensor([[-0.0614, -0.0468,  0.0176, -0.0022,  0.0337,  0.0578, -0.0309,  0.0841],
        [-0.1483, -0.1292,  0.0436,  0.0166,  0.0393,  0.1117, -0.0818,  0.1585],
        [-0.1483, -0.1292,  0.0436,  0.0166,  0.0393,  0.1117, -0.0818,  0.1585],
        [-0.1483, -0.1292,  0.0436,  0.0166,  0.0393,  0.1117, -0.0818,  0.1585],
        [-0.0614, -0.0468,  0.0176, -0.0022,  0.0337,  0.0578, -0.0309,  0.0841],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000

In [15]:
# Synthetic loss.
# backward() ACCUMULATES into mask.grad, so clear it first. Otherwise
# re-running this cell, or running all cases in one kernel, sums gradients
# across runs and the printed grads are wrong.
mask.grad = None
loss = logits.sum()
loss.backward()

In [16]:
# Non-zero grads on mask[0:5]: node 4 is masked out but still affected
# (through edge (3, 4)), so the loss depends on mask[4] too.
# NOTE: .grad accumulates over backward() calls; this reflects every backward
# run since mask.grad was last cleared.
mask.grad

tensor([ 0.0086, -0.0869, -0.0869,  0.0086,  0.0520,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000])

# case 2

Same masking as case 1, but also pass `edge_mask` as the edge weight, so no message is passed along edges whose weight is 0.

In [11]:
# Rebuild the masked inputs so this case gets a fresh autograd graph
# (case 1's backward() freed the buffers of the previous one).
x = graph.x * mask.unsqueeze(-1)
edge_attr = graph.edge_attr * edge_mask.unsqueeze(-1)

In [12]:
# The extra 4th argument is the per-edge weight, so weight-0 edges should carry
# no message. NOTE(review): confirm the argument's meaning against
# GINConv.forward in models.py.
# here the edge_mask is detached from the computation graph, otherwise
# https://github.com/Spazierganger/diffsubgraph/blob/master/models.py#L86 will introduce extra gradient
logits = model(x, graph.edge_index, edge_attr, edge_mask.detach())

In [13]:
# Synthetic loss.
# Clear mask.grad first: backward() accumulates, and without the reset this
# case's printed grads would include case 1's gradients under Restart & Run All.
mask.grad = None
loss2 = logits.sum()
loss2.backward()

In [14]:
# Only mask[0:4] gets non-zero grads: with edge weight 0, node 4 receives no
# message, so mask[4] no longer influences the loss.
# NOTE: .grad accumulates over backward() calls; this is case 2's gradient
# alone only if mask.grad was cleared before this case's backward().
mask.grad

tensor([ 0.0086, -0.0869, -0.0869, -0.0434,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000])

# case 3

Mask `x` and `edge_attr`, then slice out the selected subgraph (the kept nodes and the edges between them).

In [11]:
# Mask first (so gradients flow back to `mask`), then keep only the selected
# nodes/edges. The selection comes from the masks themselves rather than
# hard-coded `[:4]` / `[:3]` slices, so this stays correct if the synthetic
# mask changes.
node_keep = mask.detach() > 0
edge_keep = edge_mask.detach() > 0

x = (graph.x * mask[:, None])[node_keep]
edge_attr = (graph.edge_attr * edge_mask[:, None])[edge_keep]

# Relabel kept node ids to 0..k-1 so edge_index indexes rows of the sliced `x`.
relabel = torch.full((mask.numel(),), -1, dtype=torch.long)
relabel[node_keep] = torch.arange(int(node_keep.sum()))
edge_index = relabel[graph.edge_index[:, edge_keep]]
# edge_mask is the product of its endpoint masks, so every kept edge must
# connect two kept nodes.
assert (edge_index >= 0).all(), "kept edge references a dropped node"

In [12]:
# Forward on the sliced subgraph only; dropped nodes/edges are absent from the model's inputs.
logits = model(x, edge_index, edge_attr)

In [13]:
# Synthetic loss.
# Clear mask.grad first: backward() accumulates, so earlier cases' gradients
# would otherwise be added into this case's output.
mask.grad = None
loss2 = logits.sum()
loss2.backward()

In [15]:
# Only mask[0:4] gets non-zero grads, and the values match case 2.
# NOTE: .grad accumulates over backward() calls; the comparison holds only if
# mask.grad was cleared before each case's backward().
mask.grad

tensor([ 0.0086, -0.0869, -0.0869, -0.0434,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000])

# case 4

Propagate on the full, unmasked graph first, then apply the mask to the output logits.

In [9]:
# Forward on the raw, unmasked graph; these logits do not depend on `mask` at all.
logits = model(graph.x, graph.edge_index, graph.edge_attr)

In [10]:
# Synthetic loss: the mask is applied to the OUTPUT logits rather than the inputs.
# Clear mask.grad first: backward() accumulates, so earlier cases' gradients
# would otherwise be added into this case's output.
mask.grad = None
loss2 = (logits * mask[:, None]).sum()
loss2.backward()

In [11]:
# All grads are non-zero. d(loss2)/d(mask[i]) is the sum of node i's logits,
# and the model ran on the unmasked graph, so generally every row of logits is
# non-zero.
mask.grad

tensor([ 0.0520,  0.0105,  0.0105,  0.0105,  0.0105,  0.0105,  0.0105,  0.1692,
         0.1120,  0.1233,  0.1692,  0.0105,  0.0644,  0.0105,  0.0644,  0.0105,
         0.1233,  0.1233,  0.0644,  0.0105,  0.0644,  0.0105,  0.0105,  0.0237,
         0.0105,  0.1233,  0.0105,  0.0237,  0.0105,  0.0500,  0.1233,  0.2027,
        -0.1352])