In [27]:
import torch
import torch.nn as nn
from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.data import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.datasets import TUDataset
import torch.nn.functional as F

dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')
print(len(dataset), dataset.num_classes, dataset.num_node_features)
print(dataset[0])
# dataset = PygGraphPropPredDataset(name = "ogbg-molhiv", root = 'dataset/')

188 2 7
Data(edge_index=[2, 38], x=[17, 7], edge_attr=[38, 4], y=[1])


In [28]:


train_dataset = dataset[0:150]
test_dataset = dataset[150:]

train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

class GCN(nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, dataset.num_classes)
    
    def forward(self, x, edge_index, batch):
        print("x_size: {0}, edge_index_size: {1}".format(x.size(), edge_index.size()))
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.conv3(x, edge_index)

        x = global_mean_pool(x, batch)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x
class GraphTransformerEncoder(nn.TransformerEncoder):
    def __init__(self, num_layers, **block_args):
        super().__init__()
        self.layers = nn.ModuleList([EncoderBlock(**block_args) for _ in range(num_layers)])

# class MultiHeadAttention(nn.module):
#     def __init__(self, in_dim, out_dim, num_heads, use_bias):
#         super().__init__()

#         assert out_dim % num_heads == 0
#         self.out_dim = out_dim
#         self.num_heads = numheads
#         if use_bias:
#             self.qkv_proj = nn.Linear(in_dim, 3*out_dim, bias=True)
#         else:
#             self.qkv_proj = nn.Linear(in_dim, 3*out_dim, bias = False)
        
            

        

model = GCN(hidden_channels=64)
print(model)

GCN(
  (conv1): GCNConv(7, 64)
  (conv2): GCNConv(64, 64)
  (conv3): GCNConv(64, 64)
  (lin): Linear(in_features=64, out_features=2, bias=True)
)




In [29]:
model = GCN(hidden_channels=64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

def train():
    model.train()
    for data in train_dataloader:
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()

def test(loader):
    model.eval()
    corr = 0
    for data in test_dataloader:
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)
        corr += int((pred == data.y).sum())
    return corr / len(test_dataloader.dataset)

for epoch in range(1, 121):
    train()
    train_acc = test(train_dataloader)
    test_acc = test(test_dataloader)
    print(f'epoch: {epoch:03d}, Train acc: {train_acc:.4f}, Test acc : {test_acc:.4f}')

x_size: torch.Size([1144, 7]), edge_index_size: torch.Size([2, 2518])
x_size: torch.Size([1158, 7]), edge_index_size: torch.Size([2, 2572])
x_size: torch.Size([382, 7]), edge_index_size: torch.Size([2, 840])
x_size: torch.Size([687, 7]), edge_index_size: torch.Size([2, 1512])
x_size: torch.Size([687, 7]), edge_index_size: torch.Size([2, 1512])
epoch: 001, Train acc: 0.6842, Test acc : 0.6842
x_size: torch.Size([1195, 7]), edge_index_size: torch.Size([2, 2636])
x_size: torch.Size([1111, 7]), edge_index_size: torch.Size([2, 2460])
x_size: torch.Size([378, 7]), edge_index_size: torch.Size([2, 834])
x_size: torch.Size([687, 7]), edge_index_size: torch.Size([2, 1512])
x_size: torch.Size([687, 7]), edge_index_size: torch.Size([2, 1512])
epoch: 002, Train acc: 0.6842, Test acc : 0.6842
x_size: torch.Size([1154, 7]), edge_index_size: torch.Size([2, 2554])
x_size: torch.Size([1158, 7]), edge_index_size: torch.Size([2, 2564])
x_size: torch.Size([372, 7]), edge_index_size: torch.Size([2, 812])
x_