In [2]:
import torch
import torch.nn as nn
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, TopKPooling, global_mean_pool

# Load the 'PROTEINS' dataset through TUDataset; root='.' anchors the dataset's
# storage location at the current working directory.
dataset = TUDataset(name='PROTEINS', root='.')

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Reproducible train / validation / test split of `dataset` plus its loaders.
# Tunable values are named here instead of being repeated inline.
SEED = 42
TRAIN_FRACTION = 0.7
VAL_FRACTION = 0.15
BATCH_SIZE = 32

torch.manual_seed(SEED)

# Bind the shuffled view to a NEW name instead of rebinding `dataset`:
# re-running this cell then always shuffles the same, untouched input and
# yields the same split, rather than permuting an already-permuted dataset.
shuffled_dataset = dataset.shuffle()

num_graphs = len(shuffled_dataset)
train_size = int(num_graphs * TRAIN_FRACTION)
val_size = int(num_graphs * VAL_FRACTION)
# The test split takes whatever remains, so rounding never drops a graph.
test_size = num_graphs - train_size - val_size

train_dataset = shuffled_dataset[:train_size]
val_dataset = shuffled_dataset[train_size:train_size + val_size]
test_dataset = shuffled_dataset[train_size + val_size:]

# Sanity check: the three slices must partition the dataset exactly.
assert len(train_dataset) + len(val_dataset) + len(test_dataset) == num_graphs
assert len(test_dataset) == test_size

# Only the training loader reshuffles; evaluation loaders keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [4]:
class DynamicGNN(nn.Module):
    """Graph-level classifier with a configurable stack of graph convolutions.

    The network applies ``num_layers`` convolutions. Every convolution except
    the last is followed by a ReLU and a ``TopKPooling`` step; the last one is
    followed by a ReLU, ``global_mean_pool`` over ``batch`` and a linear head.

    Args:
        in_channels: Input width passed to the first convolution.
        hidden_channels: Output width of every convolution and the channel
            count given to each ``TopKPooling``.
        out_channels: Output width of the final linear classifier.
        num_layers: Number of convolution layers; must be at least 1.
        layer_type: Which convolution class to use: ``'GCN'`` (GCNConv),
            ``'GAT'`` (GATConv) or ``'SAGE'`` (SAGEConv).
        pool_ratio: Value forwarded as ``ratio`` to every ``TopKPooling``.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
        KeyError: If ``layer_type`` is not one of the supported names.
    """

    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_layers=3, layer_type='GCN', pool_ratio=0.8):
        super().__init__()
        if num_layers < 1:
            # Previously a non-positive depth silently produced a 1-layer model.
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        conv_dict = {'GCN': GCNConv, 'GAT': GATConv, 'SAGE': SAGEConv}
        try:
            conv_layer = conv_dict[layer_type]
        except KeyError:
            # Same exception type as a plain dict lookup, but with a usable hint.
            raise KeyError(
                f"Unknown layer_type {layer_type!r}; expected one of {sorted(conv_dict)}"
            ) from None

        self.convs = nn.ModuleList()
        self.pools = nn.ModuleList()

        # Build exactly one pooling module per non-final convolution, so no
        # unused (never-trained) pooling parameters are registered when
        # num_layers == 1. Modules are created in conv, pool, conv, pool, ...
        # order, which keeps seeded initialisation for num_layers >= 2
        # identical to the earlier implementation.
        for layer_idx in range(num_layers - 1):
            conv_in = in_channels if layer_idx == 0 else hidden_channels
            self.convs.append(conv_layer(conv_in, hidden_channels))
            self.pools.append(TopKPooling(hidden_channels, ratio=pool_ratio))

        # Final convolution: it is also the first one when num_layers == 1.
        final_in = in_channels if num_layers == 1 else hidden_channels
        self.convs.append(conv_layer(final_in, hidden_channels))

        self.classifier = nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Compute one row of classifier outputs per graph in the batch.

        Args:
            x: Node feature tensor fed to the first convolution.
            edge_index: Edge index tensor passed to every convolution/pooling.
            batch: Node-to-graph assignment used by pooling and the readout.

        Returns:
            The output of the linear classifier applied to the mean-pooled
            node features.
        """
        # len(self.pools) == len(self.convs) - 1, so zip pairs every
        # non-final convolution with its pooling module.
        for conv, pool in zip(self.convs[:-1], self.pools):
            x = conv(x, edge_index).relu()
            # Pooling returns a 6-tuple; only the updated x, edge_index and
            # batch are needed here (no edge attributes are supplied).
            x, edge_index, _, batch, _, _ = pool(x, edge_index, None, batch)

        x = self.convs[-1](x, edge_index).relu()
        x = global_mean_pool(x, batch)
        return self.classifier(x)


In [5]:
# Input/output widths are read off the dataset rather than hard-coded.
num_features, num_classes = dataset.num_features, dataset.num_classes

# Three convolution layers of the 'SAGE' type with 16 hidden channels;
# pool_ratio is left at the class default.
model = DynamicGNN(
    num_features,
    16,
    num_classes,
    layer_type='SAGE',
    num_layers=3,
)

# Cross-entropy objective, optimised by Adam with weight decay.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), weight_decay=5e-4, lr=0.01)

In [6]:
for epoch in range(1, 201):
    # ---- training pass: one optimizer step per mini-batch ----
    model.train()
    total_loss = 0

    for data in train_loader:
        optimizer.zero_grad()
        pred = model(data.x, data.edge_index, data.batch)
        loss = criterion(pred, data.y)
        loss.backward()
        optimizer.step()
        # Weight each batch loss by its graph count so the epoch figure can be
        # normalised by the dataset size below.
        total_loss += data.num_graphs * loss.item()

    train_loss = total_loss / len(train_dataset)

    # ---- validation pass: only reported on every 10th epoch ----
    if epoch % 10 != 0:
        continue

    model.eval()
    val_loss, correct = 0, 0

    with torch.no_grad():
        for data in val_loader:
            pred = model(data.x, data.edge_index, data.batch)
            val_loss += data.num_graphs * criterion(pred, data.y).item()
            # Count graphs whose highest-scoring class matches the label.
            correct += (pred.argmax(dim=1) == data.y).sum().item()

    num_val_graphs = len(val_dataset)
    val_loss = val_loss / num_val_graphs
    val_acc = correct / num_val_graphs

    print(f"Epoch: {epoch}; Train loss: {train_loss:.4f}; Val loss: {val_loss:.4f}; Val acc: {val_acc:.4f}")

Epoch: 10; Train loss: 0.6337; Val loss: 0.6701; Val acc: 0.6566
Epoch: 20; Train loss: 0.6303; Val loss: 0.6110; Val acc: 0.6566
Epoch: 30; Train loss: 0.6053; Val loss: 0.6280; Val acc: 0.6747
Epoch: 40; Train loss: 0.5891; Val loss: 0.6085; Val acc: 0.7108
Epoch: 50; Train loss: 0.5970; Val loss: 0.6133; Val acc: 0.6867
Epoch: 60; Train loss: 0.5767; Val loss: 0.6160; Val acc: 0.7048
Epoch: 70; Train loss: 0.5783; Val loss: 0.6137; Val acc: 0.6988
Epoch: 80; Train loss: 0.5929; Val loss: 0.6479; Val acc: 0.6145
Epoch: 90; Train loss: 0.5769; Val loss: 0.6085; Val acc: 0.6867
Epoch: 100; Train loss: 0.5712; Val loss: 0.6271; Val acc: 0.6386
Epoch: 110; Train loss: 0.5709; Val loss: 0.6397; Val acc: 0.6687
Epoch: 120; Train loss: 0.5750; Val loss: 0.6218; Val acc: 0.7229
Epoch: 130; Train loss: 0.5699; Val loss: 0.6281; Val acc: 0.7048
Epoch: 140; Train loss: 0.5722; Val loss: 0.6233; Val acc: 0.6988
Epoch: 150; Train loss: 0.5682; Val loss: 0.6306; Val acc: 0.6627
Epoch: 160; Train l

In [7]:
# Final evaluation of the model on the held-out test loader.
model.eval()
test_loss, test_correct = 0, 0

with torch.no_grad():
    for data in test_loader:
        pred = model(data.x, data.edge_index, data.batch)
        # Batch loss is weighted by graph count; normalised by dataset size below.
        test_loss += data.num_graphs * criterion(pred, data.y).item()
        test_correct += (pred.argmax(dim=1) == data.y).sum().item()

num_test_graphs = len(test_dataset)
test_loss = test_loss / num_test_graphs
test_acc = test_correct / num_test_graphs

print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")

Test Loss: 0.5821
Test Accuracy: 0.7321
