In [1]:
import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Generate a synthetic dataset with 10000 graphs, ranging from 10 to 500 nodes.
dataset = dgl.data.GINDataset("PROTEINS", self_loop=True)

In [3]:
print("Node feature dimensionality:", dataset.dim_nfeats)
print("Number of graph categories:", dataset.gclasses)


from dgl.dataloading import GraphDataLoader

Node feature dimensionality: 3
Number of graph categories: 2


In [4]:
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.model_selection import train_test_split

num_examples = len(dataset)
x_train,x_test = train_test_split(torch.arange(num_examples),test_size=0.2)
train_sampler = SubsetRandomSampler(x_train)
test_sampler = SubsetRandomSampler(x_test)

train_dataloader = GraphDataLoader(
    dataset, sampler=train_sampler, batch_size=4, drop_last=False
)
test_dataloader = GraphDataLoader(
    dataset, sampler=test_sampler, batch_size=4, drop_last=False
)

In [5]:
num_examples

1113

In [6]:
it = iter(test_dataloader)
batch = next(it)
print(batch)

[Graph(num_nodes=153, num_edges=779,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), tensor([0, 0, 1, 0])]


In [7]:
batched_graph, labels = batch
print(
    "Number of nodes for each graph element in the batch:",
    batched_graph.batch_num_nodes(),
)
print(
    "Number of edges for each graph element in the batch:",
    batched_graph.batch_num_edges(),
)

# Recover the original graph elements from the minibatch
graphs = dgl.unbatch(batched_graph)
print("The original graphs in the minibatch:")
print(graphs)

Number of nodes for each graph element in the batch: tensor([24, 17, 32, 80])
Number of edges for each graph element in the batch: tensor([130,  81, 160, 408])
The original graphs in the minibatch:
[Graph(num_nodes=24, num_edges=130,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), Graph(num_nodes=17, num_edges=81,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), Graph(num_nodes=32, num_edges=160,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), Graph(num_nodes=80, num_edges=408,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={})]


In [8]:
from dgl.nn import GraphConv


class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        g.ndata["h"] = h
        return dgl.mean_nodes(g, "h")

In [9]:
from dgl.nn.pytorch.conv import GINConv
from dgl.nn.pytorch.glob import SumPooling
class MLP(nn.Module):
    """Construct two-layer MLP-type aggreator for GIN model"""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.linears = nn.ModuleList()
        # two-layer MLP
        self.linears.append(nn.Linear(input_dim, hidden_dim, bias=False))
        self.linears.append(nn.Linear(hidden_dim, output_dim, bias=False))
        self.batch_norm = nn.BatchNorm1d((hidden_dim))

    def forward(self, x):
        h = x
        h = F.relu(self.batch_norm(self.linears[0](h)))
        return self.linears[1](h)


class GIN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.ginlayers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        num_layers = 5
        # five-layer GCN with two-layer MLP aggregator and sum-neighbor-pooling scheme
        for layer in range(num_layers - 1):  # excluding the input layer
            if layer == 0:
                mlp = MLP(input_dim, hidden_dim, hidden_dim)
            else:
                mlp = MLP(hidden_dim, hidden_dim, hidden_dim)
            self.ginlayers.append(
                GINConv(mlp, learn_eps=False)
            )  # set to True if learning epsilon
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
        # linear functions for graph sum poolings of output of each layer
        self.linear_prediction = nn.ModuleList()
        for layer in range(num_layers):
            if layer == 0:
                self.linear_prediction.append(nn.Linear(input_dim, output_dim))
            else:
                self.linear_prediction.append(nn.Linear(hidden_dim, output_dim))
        self.drop = nn.Dropout(0.5)
        self.pool = (
            SumPooling()
        )  # change to mean readout (AvgPooling) on social network datasets

    def forward(self, g, h):
        # list of hidden representation at each layer (including the input layer)
        hidden_rep = [h]
        for i, layer in enumerate(self.ginlayers):
            h = layer(g, h)
            h = self.batch_norms[i](h)
            h = F.relu(h)
            hidden_rep.append(h)
        score_over_layer = 0
        # perform graph sum pooling over all nodes in each layer
        for i, h in enumerate(hidden_rep):
            pooled_h = self.pool(g, h)
            score_over_layer += self.drop(self.linear_prediction[i](pooled_h))
        return score_over_layer

In [None]:
# Create the model with given dimensions
model = GIN(dataset.dim_nfeats, 16, dataset.gclasses).to('cuda')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

for epoch in range(100):
    for batched_graph, labels in train_dataloader:
#         print(labels)
        batched_graph, labels = batched_graph.to('cuda'), labels.to('cuda')
        pred = model(batched_graph, batched_graph.ndata["attr"].float())
        loss = F.cross_entropy(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    num_correct = 0
    num_tests = 0
    for batched_graph, labels in test_dataloader:
#         print(labels)
        batched_graph, labels = batched_graph.to('cuda'), labels.to('cuda')
        pred = model(batched_graph, batched_graph.ndata["attr"].float())
        num_correct += (pred.argmax(1) == labels).sum().item()
        num_tests += len(labels)

    print(f"{epoch}Test accuracy:", num_correct / num_tests)

0Test accuracy: 0.5605381165919282
1Test accuracy: 0.6547085201793722
2Test accuracy: 0.6681614349775785
3Test accuracy: 0.6098654708520179
4Test accuracy: 0.547085201793722
5Test accuracy: 0.6591928251121076
6Test accuracy: 0.6053811659192825
7Test accuracy: 0.5695067264573991
8Test accuracy: 0.6502242152466368
9Test accuracy: 0.6995515695067265
10Test accuracy: 0.6905829596412556
11Test accuracy: 0.672645739910314
12Test accuracy: 0.6457399103139013
13Test accuracy: 0.6278026905829597
14Test accuracy: 0.695067264573991
15Test accuracy: 0.6636771300448431
16Test accuracy: 0.6502242152466368
17Test accuracy: 0.6188340807174888
18Test accuracy: 0.600896860986547
19Test accuracy: 0.5739910313901345
20Test accuracy: 0.6412556053811659
21Test accuracy: 0.6233183856502242
22Test accuracy: 0.6322869955156951
23Test accuracy: 0.5112107623318386
24Test accuracy: 0.695067264573991
25Test accuracy: 0.6412556053811659
26Test accuracy: 0.6591928251121076
27Test accuracy: 0.6502242152466368
28Test 