In [1]:
# This code works in Python 3.10.6
import matplotlib.pyplot as plt
import networkx as nx
import torch
import torch_geometric.utils
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
import torch.nn.functional as F
from torch_geometric.datasets.dblp import DBLP
from torch_geometric.nn import GCNConv
import time
from torch_geometric.logging import log
import os
from collections import Counter

In [2]:
dataset = DBLP(root='./dblp_data', transform=T.Constant(node_types='conference'))
hetero_data = dataset[0]

In [3]:
hetero_data

HeteroData(
  author={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057],
  },
  paper={ x=[14328, 4231] },
  term={ x=[7723, 50] },
  conference={
    num_nodes=20,
    x=[20, 1],
  },
  (author, to, paper)={ edge_index=[2, 19645] },
  (paper, to, author)={ edge_index=[2, 19645] },
  (paper, to, term)={ edge_index=[2, 85810] },
  (paper, to, conference)={ edge_index=[2, 14328] },
  (term, to, paper)={ edge_index=[2, 85810] },
  (conference, to, paper)={ edge_index=[2, 14328] }
)

In [4]:
# This code works in torch-geometric==2.6.0
data = hetero_data.to_homogeneous(add_edge_type=False)

In [5]:
data

Data(edge_index=[2, 239566], x=[26128, 4231], y=[26128], train_mask=[26128], val_mask=[26128], test_mask=[26128], node_type=[26128])

In [6]:
data.node_type

tensor([0, 0, 0,  ..., 3, 3, 3])

In [7]:
data.x = F.one_hot(data.node_type, num_classes=len(torch.unique(data.node_type))).float()

In [8]:
data.x

tensor([[1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        ...,
        [0., 0., 0., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 1.]])

In [9]:
torch.unique(data.y)

tensor([-1,  0,  1,  2,  3])

In [10]:
Counter(data.y.tolist())

Counter({-1: 22071, 0: 1197, 2: 1109, 3: 1006, 1: 745})

In [11]:
data.has_isolated_nodes()

False

In [12]:
data.has_self_loops()

False

In [13]:
data.is_directed()

False

In [14]:
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels,
                             )
        self.conv2 = GCNConv(hidden_channels, hidden_channels,
                             )
        self.conv3 = GCNConv(hidden_channels, out_channels,
                             )

    def forward(self, x, edge_index, edge_weight=None):
        x = F.dropout(x, p=0.3, training=self.training)
        x = self.conv1(x, edge_index, edge_weight).relu()
        x = F.dropout(x, p=0.3, training=self.training)
        x = self.conv2(x, edge_index, edge_weight).relu()
        x = F.dropout(x, p=0.3, training=self.training)
        x = self.conv3(x, edge_index, edge_weight)
        return x


device = 'cpu'
model = GCN(
    in_channels=data.x.shape[1],
    hidden_channels=32,
    out_channels=4,
).to(device)

optimizer = torch.optim.Adam([
    dict(params=model.conv1.parameters(), weight_decay=5e-4),
    dict(params=model.conv2.parameters(), weight_decay=0),
    dict(params=model.conv3.parameters(), weight_decay=0)
], lr=0.01)  # Only perform weight-decay on first convolution.


def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    train_idx = data.y != -1 
    loss = F.cross_entropy(out[train_idx], data.y[train_idx])
    loss.backward()
    optimizer.step()
    return float(loss)

    
@torch.no_grad()
def test():
    model.eval()
    pred = model(data.x, data.edge_index).argmax(dim=-1)

    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
    return accs


best_val_acc = test_acc = 0
start_patience = patience = 100
times = []
for epoch in range(1, 2000 + 1):
    start = time.time()
    loss = train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        #best_val_acc = val_acc
        test_acc = tmp_test_acc
    if epoch%10==0:
        log(Epoch=epoch, Loss=loss, Train=train_acc, Val=val_acc, Test=test_acc)
    times.append(time.time() - start)

    if (val_acc>best_val_acc):
        print('saving....')
        patience = start_patience
        best_val_acc = val_acc
        print('best acc is', best_val_acc)

        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(model.state_dict(), '../checkpoint/dblp_gcn.pth')
    else:
        patience -= 1
        
    if patience <= 0:
        print('Stopping training as validation accuracy did not improve'
              f'for {start_patience} epochs')
        break   
        
print(f'Median time per epoch: {torch.tensor(times).median():.4f}s')

saving....
best acc is 0.2625
saving....
best acc is 0.2725
saving....
best acc is 0.2775
Epoch: 010, Loss: 1.3616, Train: 0.3175, Val: 0.2825, Test: 0.3224
saving....
best acc is 0.2825
saving....
best acc is 0.3375
saving....
best acc is 0.345
saving....
best acc is 0.3525
saving....
best acc is 0.355
saving....
best acc is 0.36
Epoch: 020, Loss: 1.3461, Train: 0.3700, Val: 0.3400, Test: 0.3773
saving....
best acc is 0.365
Epoch: 030, Loss: 1.3391, Train: 0.3650, Val: 0.3625, Test: 0.3755
Epoch: 040, Loss: 1.3354, Train: 0.3750, Val: 0.3475, Test: 0.3755
Epoch: 050, Loss: 1.3358, Train: 0.3825, Val: 0.3575, Test: 0.3755
Epoch: 060, Loss: 1.3266, Train: 0.3800, Val: 0.3600, Test: 0.3755
saving....
best acc is 0.37
Epoch: 070, Loss: 1.3282, Train: 0.3800, Val: 0.3350, Test: 0.3792
Epoch: 080, Loss: 1.3264, Train: 0.3775, Val: 0.3500, Test: 0.3792
Epoch: 090, Loss: 1.3142, Train: 0.3825, Val: 0.3675, Test: 0.3792
Epoch: 100, Loss: 1.3188, Train: 0.3925, Val: 0.3675, Test: 0.3792
Epoch: 