## Heterogeneous Graph Learning
In this notebook, we explore how to train a GNN model on heterogeneous graph data. In the real world, most graph data contains multiple node and edge types — here, the IMDB graph with movie, actor and director nodes.

In [1]:
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

[K     |████████████████████████████████| 7.9 MB 11.1 MB/s 
[K     |████████████████████████████████| 3.5 MB 10.0 MB/s 
[?25h  Building wheel for torch-geometric (setup.py) ... [?25l[?25hdone


In [2]:
import torch_geometric
from torch_geometric.data import HeteroData
import os
import os.path as osp
from typing import Dict, List, Union


import torch
import torch.nn.functional as F
from torch import nn

import torch_geometric.transforms as T
from torch_geometric.datasets import IMDB
from torch_geometric.nn import HANConv


In [3]:
# Dataset location, relative to the current working directory. (The previous
# `osp.dirname(osp.realpath('__file__'))` passed the *string* '__file__', which
# also just resolved to the working directory; say so explicitly.)
path = osp.join(osp.realpath(os.getcwd()), '../../data/IMDB')

# Two metapaths over movies:
#   metapath_0: movie -> actor -> movie     (movies sharing an actor)
#   metapath_1: movie -> director -> movie  (movies sharing a director)
metapaths = [[('movie', 'actor'), ('actor', 'movie')],
             [('movie', 'director'), ('director', 'movie')]]

# Replace the original edges with the metapath edges and drop nodes that are no
# longer connected, leaving a movie-only graph with two edge types.
transform = T.AddMetaPaths(metapaths=metapaths, drop_orig_edges=True,
                           drop_unconnected_nodes=True)
dataset = IMDB(path, transform=transform)
data = dataset[0]
data


Downloading https://www.dropbox.com/s/g0btk9ctr1es39x/IMDB_processed.zip?dl=1
Extracting /data/IMDB/raw/IMDB_processed.zip
Processing...
Done!


HeteroData(
  metapath_dict={
    (movie, metapath_0, movie)=[2],
    (movie, metapath_1, movie)=[2]
  },
  [1mmovie[0m={
    x=[4278, 3066],
    y=[4278],
    train_mask=[4278],
    val_mask=[4278],
    test_mask=[4278]
  },
  [1m(movie, metapath_0, movie)[0m={ edge_index=[2, 85358] },
  [1m(movie, metapath_1, movie)[0m={ edge_index=[2, 17446] }
)


In [4]:
# Define the HAN model: one HANConv layer followed by a linear classifier
# applied to the target node type's embeddings.
class HAN(nn.Module):
    """Heterogeneous Graph Attention Network (HANConv + linear head).

    Args:
        in_channels: Input feature size, either one int shared by all node
            types or a per-node-type dict. ``-1`` lets PyG infer the size on
            the first forward pass (lazy initialisation).
        out_channels: Number of output classes.
        hidden_channels: Size of the HANConv output embedding.
        heads: Number of attention heads in HANConv.
        dropout: Dropout probability passed to HANConv.
        metadata: ``(node_types, edge_types)`` tuple describing the graph.
            If omitted, falls back to ``data.metadata()`` of the notebook-level
            ``data`` object (the original behaviour); pass it explicitly to
            avoid depending on notebook state.
        target_node_type: Node type whose embeddings are classified.
    """

    def __init__(self, in_channels: Union[int, Dict[str, int]],
                 out_channels: int, hidden_channels=128, heads=8,
                 dropout: float = 0.6, metadata=None,
                 target_node_type: str = 'movie'):
        super().__init__()
        if metadata is None:
            # Backward-compatible fallback to the global graph.
            metadata = data.metadata()
        self.target_node_type = target_node_type
        self.han_conv = HANConv(in_channels, hidden_channels, heads=heads,
                                dropout=dropout, metadata=metadata)
        self.lin = nn.Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        """Return class logits for every node of ``target_node_type``."""
        out_dict = self.han_conv(x_dict, edge_index_dict)
        return self.lin(out_dict[self.target_node_type])


# Derive the number of classes from the labels instead of hard-coding it.
num_classes = int(data['movie'].y.max()) + 1
model = HAN(in_channels=-1, out_channels=num_classes, metadata=data.metadata())
model

HAN(
  (han_conv): HANConv(128, heads=8)
  (lin): Linear(in_features=128, out_features=3, bias=True)
)


In [5]:
# Seed torch so weight initialisation and dropout are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = int(data['movie'].y.max()) + 1
model = HAN(in_channels=-1, out_channels=num_classes)
data, model = data.to(device), model.to(device)

# With `in_channels=-1` the input projections are created lazily on the first
# forward pass, so run one pass (without tracking gradients) before the
# optimizer collects `model.parameters()`.
with torch.no_grad():
    out = model(data.x_dict, data.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)
criterion = nn.CrossEntropyLoss()


In [6]:
def train() -> float:
    """Run one full-batch optimisation step on the training movies.

    Reads the notebook-level ``model``, ``optimizer``, ``criterion`` and
    ``data``. Returns the training loss (computed before the optimizer step)
    as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    logits = model(data.x_dict, data.edge_index_dict)
    movie = data['movie']
    train_mask = movie.train_mask
    loss = criterion(logits[train_mask], movie.y[train_mask])

    loss.backward()
    optimizer.step()
    return float(loss)

In [7]:
def test() -> List[float]:
  model.eval()
  pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)


  accs = []
  for split in ['train_mask', 'val_mask', 'test_mask']:
    mask = data['movie'][split]
    acc = ((pred[mask] == data['movie'].y[mask]).sum())/mask.sum()
    accs.append(float(acc))
  return accs

In [9]:
# Train with early stopping on validation accuracy. The test accuracy reported
# at the end is the one from the epoch with the best validation accuracy, not
# from the last epoch.
MAX_EPOCHS = 200   # inclusive: epochs 1..MAX_EPOCHS
PATIENCE = 100     # epochs without val improvement before stopping
LOG_EVERY = 10

best_val_acc = 0
best_test_acc = 0
best_epoch = 0
start_patience = patience = PATIENCE
for epoch in range(1, MAX_EPOCHS + 1):
    loss = train()
    train_acc, val_acc, test_acc = test()
    if epoch % LOG_EVERY == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
              f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')

    # `<=` so that ties also reset patience and prefer the later epoch.
    if best_val_acc <= val_acc:
        patience = start_patience
        best_val_acc = val_acc
        best_test_acc = test_acc
        best_epoch = epoch
    else:
        patience -= 1

    if patience <= 0:
        print('Stopping training as validation accuracy did not improve '
              f'for {start_patience} epochs')
        break

print(f'Best Val: {best_val_acc:.4f} (epoch {best_epoch:03d}), '
      f'Test at that epoch: {best_test_acc:.4f}')

Epoch: 010, Loss: 0.0402, Train: 0.9900, Val: 0.5175, Test: 0.4879
Epoch: 020, Loss: 0.0387, Train: 0.9900, Val: 0.5100, Test: 0.4911
Epoch: 030, Loss: 0.0413, Train: 0.9900, Val: 0.5100, Test: 0.4908
Epoch: 040, Loss: 0.0383, Train: 0.9900, Val: 0.5150, Test: 0.4917
Epoch: 050, Loss: 0.0357, Train: 0.9900, Val: 0.5125, Test: 0.4905
Epoch: 060, Loss: 0.0376, Train: 0.9900, Val: 0.5150, Test: 0.4862
Epoch: 070, Loss: 0.0339, Train: 0.9900, Val: 0.5125, Test: 0.4899
Epoch: 080, Loss: 0.0331, Train: 0.9900, Val: 0.5200, Test: 0.4945
Epoch: 090, Loss: 0.0342, Train: 0.9900, Val: 0.5100, Test: 0.4948
Epoch: 100, Loss: 0.0343, Train: 0.9900, Val: 0.5200, Test: 0.4991
Epoch: 110, Loss: 0.0345, Train: 0.9900, Val: 0.5050, Test: 0.4899
Epoch: 120, Loss: 0.0337, Train: 0.9900, Val: 0.5075, Test: 0.4885
Epoch: 130, Loss: 0.0339, Train: 0.9900, Val: 0.5200, Test: 0.4928
Epoch: 140, Loss: 0.0340, Train: 0.9900, Val: 0.5150, Test: 0.4905
Epoch: 150, Loss: 0.0320, Train: 0.9900, Val: 0.5225, Test: 0.