In [1]:
# This code works in Python 3.10.6
import matplotlib.pyplot as plt
import networkx as nx
import torch
import torch_geometric.utils
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
import torch.nn.functional as F
from torch_geometric.datasets import IMDB
from torch_geometric.nn import GCNConv
import time
from torch_geometric.logging import log
import os
from collections import Counter

In [2]:
dataset = IMDB(root='./imdb_data')
hetero_data = dataset[0]

In [3]:
hetero_data

HeteroData(
  movie={
    x=[4278, 3066],
    y=[4278],
    train_mask=[4278],
    val_mask=[4278],
    test_mask=[4278],
  },
  director={ x=[2081, 3066] },
  actor={ x=[5257, 3066] },
  (movie, to, director)={ edge_index=[2, 4278] },
  (movie, to, actor)={ edge_index=[2, 12828] },
  (director, to, movie)={ edge_index=[2, 4278] },
  (actor, to, movie)={ edge_index=[2, 12828] }
)

In [4]:
# This code works in torch-geometric==2.6.0
data = hetero_data.to_homogeneous(add_edge_type=False)

In [5]:
data

Data(edge_index=[2, 34212], x=[11616, 3066], y=[11616], train_mask=[11616], val_mask=[11616], test_mask=[11616], node_type=[11616])

In [6]:
data.node_type

tensor([0, 0, 0,  ..., 2, 2, 2])

In [7]:
data.x = F.one_hot(data.node_type, num_classes=len(torch.unique(data.node_type))).float()

In [8]:
data.x

tensor([[1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        ...,
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.]])

In [9]:
torch.unique(data.y)

tensor([-1,  0,  1,  2])

In [10]:
Counter(data.y.tolist())

Counter({-1: 7338, 1: 1584, 2: 1559, 0: 1135})

In [11]:
data.has_isolated_nodes()

False

In [12]:
data.has_self_loops()

False

In [13]:
data.is_directed()

False

In [14]:
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels,
                             )
        self.conv2 = GCNConv(hidden_channels, hidden_channels,
                             )
        self.conv3 = GCNConv(hidden_channels, out_channels,
                             )

    def forward(self, x, edge_index, edge_weight=None):
        x = F.dropout(x, p=0.3, training=self.training)
        x = self.conv1(x, edge_index, edge_weight).relu()
        x = F.dropout(x, p=0.3, training=self.training)
        x = self.conv2(x, edge_index, edge_weight).relu()
        x = F.dropout(x, p=0.3, training=self.training)
        x = self.conv3(x, edge_index, edge_weight)
        return x

device = 'cpu'
model = GCN(
    in_channels=data.x.shape[1],
    hidden_channels=32,
    out_channels=3,
).to(device)

optimizer = torch.optim.Adam([
    dict(params=model.conv1.parameters(), weight_decay=5e-4),
    dict(params=model.conv2.parameters(), weight_decay=0),
    dict(params=model.conv3.parameters(), weight_decay=0)
], lr=0.01)  # Only perform weight-decay on first convolution.


def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    train_idx = data.y != -1 
    loss = F.cross_entropy(out[train_idx], data.y[train_idx])
   
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test():
    model.eval()
    pred = model(data.x, data.edge_index).argmax(dim=-1)

    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
    return accs


best_val_acc = test_acc = 0
start_patience = patience = 100
times = []
for epoch in range(1, 2000 + 1):
    start = time.time()
    loss = train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        test_acc = tmp_test_acc
    if epoch%10==0:
        log(Epoch=epoch, Loss=loss, Train=train_acc, Val=val_acc, Test=test_acc)
    times.append(time.time() - start)

    if (val_acc>best_val_acc):
        print('saving....')
        patience = start_patience
        best_val_acc = val_acc
        print('best acc is', best_val_acc)

        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(model.state_dict(), '../checkpoint/imdb_gcn.pth')
    else:
        patience -= 1
        
    if patience <= 0:
        print('Stopping training as validation accuracy did not improve '
              f'for {start_patience} epochs')
        break   
       
print(f'Median time per epoch: {torch.tensor(times).median():.4f}s')

saving....
best acc is 0.395
Epoch: 010, Loss: 1.0896, Train: 0.3975, Val: 0.3950, Test: 0.3643
saving....
best acc is 0.41
saving....
best acc is 0.425
saving....
best acc is 0.4275
Epoch: 020, Loss: 1.0869, Train: 0.3925, Val: 0.3800, Test: 0.3847
Epoch: 030, Loss: 1.0854, Train: 0.3650, Val: 0.3975, Test: 0.3847
Epoch: 040, Loss: 1.0855, Train: 0.3650, Val: 0.3925, Test: 0.3847
Epoch: 050, Loss: 1.0843, Train: 0.4075, Val: 0.4150, Test: 0.3847
Epoch: 060, Loss: 1.0827, Train: 0.3975, Val: 0.4125, Test: 0.3847
Epoch: 070, Loss: 1.0822, Train: 0.3800, Val: 0.4150, Test: 0.3847
Epoch: 080, Loss: 1.0825, Train: 0.4125, Val: 0.4175, Test: 0.3847
Epoch: 090, Loss: 1.0832, Train: 0.4100, Val: 0.4125, Test: 0.3847
Epoch: 100, Loss: 1.0823, Train: 0.4050, Val: 0.4075, Test: 0.3847
Epoch: 110, Loss: 1.0839, Train: 0.3775, Val: 0.4150, Test: 0.3847
Stopping training as validation accuracy did not improve for 100 epochs
Median time per epoch: 0.0254s
