## Training a GNN for Graph Classification 

> `Graph Classification`은 Node Classification과 Link Prediction과는 달리 Graph 를 분류하는 태스크입니다. 이때는 다양한 종류의 그래프가 입력으로 사용되며 그래프가 어떤 클래스에 속하는지 예측하는 형태로 진행됩니다. 주로 단백질 구조 데이터를 사용합니다.

In [1]:
import dgl 
import torch 
import torch.nn as nn 
import torch.nn.functional as F 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import dgl.data

dataset = dgl.data.GINDataset('PROTEINS', self_loop=True)

print('Node feature dimensionality:', dataset.dim_nfeats)
print('Number of graph categories:', dataset.gclasses)

Downloading C:\Users\EonKim\.dgl\GINDataset.zip from https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip...
Extracting file to C:\Users\EonKim\.dgl\GINDataset
Node feature dimensionality: 3
Number of graph categories: 2


## Defining Data Loader 

In [4]:
from dgl.dataloading import GraphDataLoader 
from torch.utils.data.sampler import SubsetRandomSampler

In [5]:
num_examples = len(dataset)
num_train = int(num_examples * 0.8)

train_sampler = SubsetRandomSampler(torch.arange(num_train))
test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))

train_dataloader = GraphDataLoader(dataset, sampler = train_sampler, batch_size = 16, drop_last = False)
test_dataloader = GraphDataLoader(dataset, sampler = test_sampler, batch_size=16, drop_last = False) # PyTorch와 동일하게 DataLoader를 만들 수 있습니다.

In [6]:
it = iter(train_dataloader)
batch = next(it)
print(batch)

[Graph(num_nodes=671, num_edges=3003,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), tensor([0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0])]


In [7]:
batched_graph, labels = batch
print('Number of nodes for each graph element in the batch:', batched_graph.batch_num_nodes())
print('Number of edges for each graph element in the batch:', batched_graph.batch_num_edges())

# Recover the original graph elements from the minibatch
graphs = dgl.unbatch(batched_graph)
print('The original graphs in the minibatch:')
print(graphs)

Number of nodes for each graph element in the batch: tensor([ 21,  56,  21, 152,  19,  10,  16,  88,  59,   8,  60,  16,  29,   7,
         11,  98])
Number of edges for each graph element in the batch: tensor([107, 262,  93, 616,  83,  46,  66, 372, 259,  40, 316,  72, 135,  33,
         49, 454])
The original graphs in the minibatch:
[Graph(num_nodes=21, num_edges=107,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=56, num_edges=262,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=21, num_edges=93,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=152, num_edges=616,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtyp

In [8]:
from dgl.nn import GraphConv

class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        g.ndata['h'] = h
        return dgl.mean_nodes(g, 'h')

## Training Loop

In [9]:
# Create the model with given dimensions
model = GCN(dataset.dim_nfeats, 16, dataset.gclasses)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(20):
    for batched_graph, labels in train_dataloader:
        pred = model(batched_graph, batched_graph.ndata['attr'].float())
        loss = F.cross_entropy(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

num_correct = 0
num_tests = 0
for batched_graph, labels in test_dataloader:
    pred = model(batched_graph, batched_graph.ndata['attr'].float())
    num_correct += (pred.argmax(1) == labels).sum().item()
    num_tests += len(labels)

print('Test accuracy:', num_correct / num_tests)



Test accuracy: 0.273542600896861
