In [1]:
from pathlib import Path

In [2]:
import torch
from torch_geometric.datasets import TUDataset

# Load the MUTAG collection through TUDataset; its files live in a "datasets"
# directory that sits beside the current working directory.
dataset = TUDataset(
    root=Path.cwd().parent / "datasets",
    name="MUTAG",
)

# Show the dataset repr, then: number of graphs, node-feature width, class count.
print(dataset)
print(f"{len(dataset)} {dataset.num_features} {dataset.num_classes}")

MUTAG(188)
188 7 2


In [3]:
from torch_geometric.utils import degree
# Take the first graph of the dataset and print its repr, its node/edge counts,
# and three structural checks (isolated nodes, self-loops, undirectedness).
graph = dataset[0]
print(graph)
print(f"{graph.num_nodes} {graph.num_edges}")
print(f"{graph.has_isolated_nodes()} {graph.has_self_loops()} {graph.is_undirected()}")

Data(edge_index=[2, 38], x=[17, 7], edge_attr=[38, 4], y=[1])
17 38
False False True


In [4]:
# Seed torch's global RNG right before shuffling (presumably so that the
# train/test split is reproducible across runs).
torch.manual_seed(0)
# NOTE(review): `dataset` is rebound to its shuffled version, so re-running this
# cell shuffles an already-shuffled dataset.
dataset = dataset.shuffle()
# The first 150 shuffled graphs are used for training, the remainder for testing.
dataset_train, dataset_test = dataset[:150], dataset[150:]

In [5]:
from torch_geometric.loader import DataLoader

# Mini-batch loaders with up to 32 graphs per batch; only the training loader
# is created with shuffling enabled.
loader_train = DataLoader(
    dataset_train,
    batch_size=32,
    shuffle=True,
)
loader_test = DataLoader(
    dataset_test,
    batch_size=32,
    shuffle=False,
)

# Walk through one pass over the training loader and show, for every batch,
# how many graphs it holds, its repr, and the shape of its `batch` vector.
for step, data in enumerate(loader_train):
    print(f"Step: {step}, Number of graphs in the current batch: {data.num_graphs}")
    print(data, "\n")
    print(data.batch.shape)

Step: 0, Number of graphs in the current batch: 32
DataBatch(edge_index=[2, 1238], x=[570, 7], edge_attr=[1238, 4], y=[32], batch=[570], ptr=[33]) 

torch.Size([570])
Step: 1, Number of graphs in the current batch: 32
DataBatch(edge_index=[2, 1286], x=[580, 7], edge_attr=[1286, 4], y=[32], batch=[580], ptr=[33]) 

torch.Size([580])
Step: 2, Number of graphs in the current batch: 32
DataBatch(edge_index=[2, 1274], x=[574, 7], edge_attr=[1274, 4], y=[32], batch=[574], ptr=[33]) 

torch.Size([574])
Step: 3, Number of graphs in the current batch: 32
DataBatch(edge_index=[2, 1276], x=[575, 7], edge_attr=[1276, 4], y=[32], batch=[575], ptr=[33]) 

torch.Size([575])
Step: 4, Number of graphs in the current batch: 22
DataBatch(edge_index=[2, 848], x=[383, 7], edge_attr=[848, 4], y=[22], batch=[383], ptr=[23]) 

torch.Size([383])


In [6]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, GINConv

class GIN(torch.nn.Module):
    """Graph classifier: three ``GINConv`` layers, mean pooling, linear head.

    Every ``GINConv`` wraps a single ``Linear`` layer. ReLU is applied after
    the first two convolutions only; the pooled graph embedding goes through
    dropout (p=0.5, active in training mode only) before the final linear
    layer produces one score per class.

    Args:
        num_features: Input width of the first convolution's linear layer.
        num_hiddens: Output width of every convolution and input width of
            the classification head.
        num_classes: Output width of the classification head.
    """

    def __init__(self, num_features, num_hiddens, num_classes):
        super().__init__()
        # Re-seeds torch's global RNG on every construction (presumably so
        # that the initial weights are reproducible).
        torch.manual_seed(0)
        input_widths = (num_features, num_hiddens, num_hiddens)
        # The generator is consumed in order, so the layers are created and
        # registered as conv1, conv2, conv3.
        self.conv1, self.conv2, self.conv3 = (
            GINConv(Linear(width, num_hiddens)) for width in input_widths
        )
        self.lin = Linear(num_hiddens, num_classes)

    def forward(self, x, edge_index, batch):
        """Compute per-graph class scores.

        Args:
            x: Node features, passed to the first convolution.
            edge_index: Edge indices, passed to every convolution.
            batch: Assignment vector handed to ``global_mean_pool`` together
                with the node embeddings.

        Returns:
            The output of the final linear layer applied to the pooled
            embeddings.
        """
        hidden = x
        for conv in (self.conv1, self.conv2):
            hidden = F.relu(conv(hidden, edge_index))
        # No activation after the last convolution.
        hidden = self.conv3(hidden, edge_index)

        pooled = global_mean_pool(hidden, batch)
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        return self.lin(pooled)


In [7]:
# Pull a single mini-batch from the training loader and look at its contents.
graph_batch = next(iter(loader_train))
print(graph_batch)
print(graph_batch.x.shape, graph_batch.y)
# Mean-pool the raw node features with the batch assignment vector and show
# the shape of the pooled result.
emb = global_mean_pool(graph_batch.x, graph_batch.batch)
print(emb.shape)

DataBatch(edge_index=[2, 1182], x=[539, 7], edge_attr=[1182, 4], y=[32], batch=[539], ptr=[33])
torch.Size([539, 7]) tensor([0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1,
        1, 1, 1, 0, 1, 1, 0, 1])
torch.Size([32, 7])


`graph_batch` には32個のグラフがまとめられており、各グラフには 0 または 1 のラベル（`y`）が割り当てられている。
今回の場合、32個のグラフは合計で539個の頂点を持っている。
`global_mean_pool` は `batch` ベクトルに従って、各グラフに属する頂点の特徴量を平均し、グラフごとに1つのベクトルにまとめる（ここでは `[539, 7]` → `[32, 7]`）。
![](../images/mini-batching-of-graph.png)

In [8]:
# Build the GIN classifier with 64 hidden units and print its layer layout.
model = GIN(
    num_features=dataset.num_features,
    num_hiddens=64,
    num_classes=dataset.num_classes,
)
print(model)

# Adam with its default hyperparameters and a cross-entropy loss; both are
# module-level names that train() reads when it is called.
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()

def train(loader: DataLoader, model: torch.nn.Module):
    """Run one optimisation pass over ``loader``, updating ``model`` in place.

    The module-level ``optimizer`` and ``criterion`` are looked up at call
    time, so they must already be bound and must belong to ``model``.

    Args:
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``
            and ``y``.
        model: Network called as ``model(x, edge_index, batch)``.

    Returns:
        None.
    """
    # Training mode first, so mode-dependent layers (e.g. dropout) are active.
    model.train()

    for batch in loader:
        optimizer.zero_grad()
        logits = model(batch.x, batch.edge_index, batch.batch)
        criterion(logits, batch.y).backward()
        optimizer.step()

def test(loader: DataLoader, model: torch.nn.Module):
    model.eval()
    correct = 0
    for graph in loader:
        out  = model(graph.x, graph.edge_index, graph.batch)
        pred = out.argmax(dim=1)
        correct += int((pred == graph.y).sum())
    
    return correct / len(loader.dataset)

# Train for 200 epochs and report train/test accuracy on every 10th epoch.
for epoch in range(1, 201):
    train(loader_train, model)
    if epoch % 10 != 0:
        continue
    acc_train, acc_test = test(loader_train, model), test(loader_test, model)
    print(f"Epoch: {epoch:03d}, Train Acc: {acc_train:.3f}, Test Acc: {acc_test:.3f}")

GIN(
  (conv1): GINConv(nn=Linear(in_features=7, out_features=64, bias=True))
  (conv2): GINConv(nn=Linear(in_features=64, out_features=64, bias=True))
  (conv3): GINConv(nn=Linear(in_features=64, out_features=64, bias=True))
  (lin): Linear(in_features=64, out_features=2, bias=True)
)
Epoch: 010, Train Acc: 0.667, Test Acc: 0.684
Epoch: 020, Train Acc: 0.740, Test Acc: 0.816
Epoch: 030, Train Acc: 0.733, Test Acc: 0.816
Epoch: 040, Train Acc: 0.767, Test Acc: 0.842
Epoch: 050, Train Acc: 0.760, Test Acc: 0.816
Epoch: 060, Train Acc: 0.767, Test Acc: 0.842
Epoch: 070, Train Acc: 0.787, Test Acc: 0.842
Epoch: 080, Train Acc: 0.780, Test Acc: 0.842
Epoch: 090, Train Acc: 0.787, Test Acc: 0.842
Epoch: 100, Train Acc: 0.807, Test Acc: 0.868
Epoch: 110, Train Acc: 0.807, Test Acc: 0.842
Epoch: 120, Train Acc: 0.787, Test Acc: 0.842
Epoch: 130, Train Acc: 0.860, Test Acc: 0.868
Epoch: 140, Train Acc: 0.820, Test Acc: 0.868
Epoch: 150, Train Acc: 0.867, Test Acc: 0.868
Epoch: 160, Train Acc: 

In [9]:
from torch_geometric.nn import GraphConv


class GNN(torch.nn.Module):
    """Graph classifier: three ``GraphConv`` layers, mean pooling, linear head.

    ReLU is applied after the first two convolutions only; the pooled graph
    embedding goes through dropout (p=0.5, active in training mode only)
    before the final linear layer produces one score per class.

    Args:
        num_features: Input width of the first convolution.
        num_hiddens: Output width of every convolution and input width of
            the classification head.
        num_classes: Output width of the classification head.
    """

    def __init__(self, num_features, num_hiddens, num_classes):
        super().__init__()
        # Re-seeds torch's global RNG on every construction (presumably so
        # that the initial weights are reproducible).
        torch.manual_seed(0)
        self.conv1 = GraphConv(num_features, num_hiddens)
        self.conv2 = GraphConv(num_hiddens, num_hiddens)
        self.conv3 = GraphConv(num_hiddens, num_hiddens)
        # Size the head from the constructor argument, so the class does not
        # depend on a module-level `dataset` being defined.
        self.lin = Linear(num_hiddens, num_classes)

    def forward(self, x, edge_index, batch):
        """Compute per-graph class scores.

        Args:
            x: Node features, passed to the first convolution.
            edge_index: Edge indices, passed to every convolution.
            batch: Assignment vector handed to ``global_mean_pool`` together
                with the node embeddings.

        Returns:
            The output of the final linear layer applied to the pooled
            embeddings.
        """
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        # No activation after the last convolution.
        x = self.conv3(x, edge_index)

        x = global_mean_pool(x, batch)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x

In [10]:
# Build the GraphConv-based classifier with 64 hidden units and print its layout.
model = GNN(
    num_features=dataset.num_features,
    num_hiddens=64,
    num_classes=dataset.num_classes,
)
print(model)

# train() reads the module-level `optimizer` and `criterion`, so both are
# rebound here for the new model.
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()

# Train for 200 epochs and report train/test accuracy on every 10th epoch.
for epoch in range(1, 201):
    train(loader_train, model)
    if epoch % 10 != 0:
        continue
    acc_train, acc_test = test(loader_train, model), test(loader_test, model)
    print(f"Epoch: {epoch:03d}, Train Acc: {acc_train:.3f}, Test Acc: {acc_test:.3f}")

GNN(
  (conv1): GraphConv(7, 64)
  (conv2): GraphConv(64, 64)
  (conv3): GraphConv(64, 64)
  (lin): Linear(in_features=64, out_features=2, bias=True)
)
Epoch: 010, Train Acc: 0.700, Test Acc: 0.789
Epoch: 020, Train Acc: 0.760, Test Acc: 0.842
Epoch: 030, Train Acc: 0.787, Test Acc: 0.868
Epoch: 040, Train Acc: 0.800, Test Acc: 0.842
Epoch: 050, Train Acc: 0.807, Test Acc: 0.921
Epoch: 060, Train Acc: 0.840, Test Acc: 0.921
Epoch: 070, Train Acc: 0.873, Test Acc: 0.921
Epoch: 080, Train Acc: 0.907, Test Acc: 0.921
Epoch: 090, Train Acc: 0.893, Test Acc: 0.895
Epoch: 100, Train Acc: 0.900, Test Acc: 0.895
Epoch: 110, Train Acc: 0.893, Test Acc: 0.947
Epoch: 120, Train Acc: 0.860, Test Acc: 0.947
Epoch: 130, Train Acc: 0.880, Test Acc: 0.895
Epoch: 140, Train Acc: 0.913, Test Acc: 0.921
Epoch: 150, Train Acc: 0.913, Test Acc: 0.921
Epoch: 160, Train Acc: 0.907, Test Acc: 0.921
Epoch: 170, Train Acc: 0.933, Test Acc: 0.921
Epoch: 180, Train Acc: 0.920, Test Acc: 0.921
Epoch: 190, Train Ac