In [1]:
import os

from pathlib import Path
from typing import Callable, Optional

import torch
import torch.nn.functional as F

from torch_geometric.data.in_memory_dataset import InMemoryDataset
from torch_geometric.datasets import DBLP
from torch_geometric.nn import HeteroConv, Linear, SAGEConv, BatchNorm

# TODO: regularisation like https://stackoverflow.com/questions/42704283/l1-l2-regularization-in-pytorch
# TODO: follow this example https://github.com/pyg-team/pytorch_geometric/issues/3958

# Walk up from the launch directory until the working directory contains a
# "data" entry, so relative paths such as "data/pyg/MyDataset" resolve no
# matter where the notebook was started. At the filesystem root
# os.chdir("..") is a no-op, so without the guard below the loop would spin
# forever when no "data" entry exists anywhere above the start directory.
while Path("data") not in Path(".").iterdir():
    _cwd = Path.cwd()
    if _cwd.parent == _cwd:
        raise FileNotFoundError(
            "No 'data' entry found in the working directory or any of its parents"
        )
    os.chdir("..")


class MyDataset(InMemoryDataset):
    """In-memory PyG dataset backed by one pre-processed graph file.

    The graph is read from ``<root>/processed/data.pt`` and collated into the
    ``(data, slices)`` pair that ``InMemoryDataset`` indexes into, so
    ``dataset[0]`` returns that single graph.

    This class defines no ``process``/``download`` methods, so ``data.pt``
    presumably has to exist before construction (verify against the PyG
    version in use).

    See:
    https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html?highlight=inmemorydataset#creating-in-memory-datasets
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
    ):
        """
        Args:
            root: Dataset root directory; the graph lives in ``<root>/processed``.
            transform: Optional transform forwarded to ``InMemoryDataset``.
            pre_transform: Optional pre-transform forwarded to ``InMemoryDataset``.
        """
        self.name = type(self).__name__
        super().__init__(root, transform, pre_transform)
        # NOTE(review): torch.load unpickles arbitrary objects, so only load
        # trusted files. On newer torch releases the default weights_only=True
        # may refuse a pickled HeteroData object. TODO confirm for the torch
        # version in use, and pass weights_only=False if required.
        graph = torch.load(self.processed_paths[0])
        self.data, self.slices = self.collate([graph])

    @property
    def processed_dir(self) -> str:
        """Directory holding the processed file: ``<root>/processed``.

        Raises:
            ValueError: If no (non-empty) root directory was given.
        """
        # Explicit check instead of `assert`, so it still runs under `python -O`.
        if not self.root:
            raise ValueError("Please specify a root directory")
        return str(Path(self.root) / "processed")

    @property
    def processed_file_names(self) -> str:
        """Name of the single processed file, relative to ``processed_dir``."""
        return "data.pt"


# Root of the locally pre-processed dataset. MyDataset reads
# <root>/processed/data.pt, so the root directory is created up front.
path = Path("data", "pyg", "MyDataset")
path.mkdir(exist_ok=True, parents=True)

dataset = MyDataset(str(path))
# The dataset collates a single heterogeneous graph; take it out.
data = dataset[0]
print(data)

HeteroData(
  add_self_loops=False,
  company={
    x=[96530, 38],
    y=[96530],
    train_mask=[96530],
    val_mask=[96530],
    test_mask=[96530]
  },
  person={
    x=[32609, 24],
    y=[32609],
    train_mask=[32609],
    val_mask=[32609],
    test_mask=[32609]
  },
  (company, owns, company)={
    edge_index=[2, 54607],
    edge_attr=[54607, 1],
    train_mask=[54607],
    val_mask=[54607],
    test_mask=[54607]
  },
  (person, owns, company)={
    edge_index=[2, 80219],
    edge_attr=[80219, 1],
    train_mask=[80219],
    val_mask=[80219],
    test_mask=[80219]
  }
)


In [2]:
class HeteroGNN(torch.nn.Module):
    """Heterogeneous GraphSAGE stack with one ``SAGEConv`` per edge type.

    Every layer is a ``HeteroConv`` that holds a lazily initialised
    ``SAGEConv((-1, -1), ...)`` for each edge type in ``metadata[1]``.

    - Hidden layers emit ``hidden_channels`` features. Each is followed by
      per-node-type batch normalisation, leaky ReLU and dropout.
    - The last layer emits ``out_channels`` features, which are returned
      unchanged.

    Args:
        metadata: ``(node_types, edge_types)`` pair, e.g. ``HeteroData.metadata()``.
        hidden_channels: Width of every layer except the last.
        out_channels: Width of the last layer (e.g. the number of classes).
        num_layers: Number of ``HeteroConv`` layers; must be at least 1.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.

    Note:
        ``HeteroConv`` only produces outputs for node types that are the
        destination of some edge type. In this notebook's graph ``person``
        only appears as a source (``(person, owns, company)``), so its features
        are absent after the first layer and later layers cannot use them as
        message sources. Adding reverse edges (e.g.
        ``torch_geometric.transforms.ToUndirected``) is the usual remedy.
        TODO confirm for this data.
    """

    def __init__(self, metadata, hidden_channels, out_channels, num_layers):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        node_types, edge_types = metadata[0], metadata[1]

        self.convs = torch.nn.ModuleList()
        for layer_idx in range(num_layers):
            # Only the final layer projects to `out_channels`; every earlier
            # layer keeps the hidden width.
            is_last = layer_idx == num_layers - 1
            layer_channels = out_channels if is_last else hidden_channels
            conv = HeteroConv(
                {
                    edge_type: SAGEConv((-1, -1), layer_channels)
                    for edge_type in edge_types
                }
            )
            self.convs.append(conv)

        # Batch norm only follows hidden layers, so it is sized with the
        # hidden width (not `out_channels`).
        self.batchnorm_dict = torch.nn.ModuleDict()
        for node_type in node_types:
            self.batchnorm_dict[node_type] = BatchNorm(hidden_channels)

    def forward(self, x_dict, edge_index_dict, p_dropout=0.0):
        """Run every layer and return the last layer's per-node-type output.

        Args:
            x_dict: Mapping from node type to its input feature tensor.
            edge_index_dict: Mapping from edge type to its ``edge_index``.
            p_dropout: Dropout probability applied to the inputs and after
                every hidden layer (only active in training mode).

        Returns:
            The dict produced by the last ``HeteroConv``: one tensor of width
            ``out_channels`` per destination node type, with no activation
            applied.
        """
        x_dict = self._dropout(x_dict, p_dropout)
        for conv in self.convs[:-1]:
            x_dict = conv(x_dict, edge_index_dict)
            # Batch normalisation followed by the activation function.
            x_dict = {
                key: F.leaky_relu(self.batchnorm_dict[key](x))
                for key, x in x_dict.items()
            }
            x_dict = self._dropout(x_dict, p_dropout)
        return self.convs[-1](x_dict, edge_index_dict)

    def _dropout(self, x_dict, p):
        """Apply dropout with probability ``p`` to every tensor in ``x_dict``."""
        return {
            key: F.dropout(x, p=p, training=self.training)
            for key, x in x_dict.items()
        }

In [3]:
# Seed torch's RNG before building the model. Weight initialisation (done
# lazily on the first forward pass below) and dropout are then reproducible
# across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# NOTE(review): out_channels=4 presumably equals the number of target classes.
# TODO confirm against the label values in data["company"].y.
model = HeteroGNN(data.metadata(), hidden_channels=64, out_channels=4, num_layers=2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data, model = data.to(device), model.to(device)

# Initialize lazy modules; this must happen before the optimizer is created
# so that model.parameters() contains materialised tensors.
with torch.no_grad():
    out = model(data.x_dict, data.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)


def train(target: str = "company"):
    """Run one full-batch optimisation step supervised on ``target`` nodes.

    Uses the module-level ``model``, ``optimizer`` and ``data``.

    Args:
        target: Node type whose ``train_mask`` and ``y`` drive the loss.
            Defaults to ``"company"``, the only node type that receives
            messages in this graph. (``"author"`` came from the DBLP example
            and does not exist in this dataset.)

    Returns:
        The training loss as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    # The model returns a dict keyed by node type; pick the supervised one.
    out = model(data.x_dict, data.edge_index_dict)[target]
    node_store = data[target]
    mask = node_store.train_mask
    loss = F.cross_entropy(out[mask], node_store.y[mask])
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test(target: str = "company"):
    """Compute accuracy of ``target`` nodes on the train/val/test masks.

    Uses the module-level ``model`` and ``data``. Runs without gradients and
    in eval mode.

    Args:
        target: Node type to evaluate. Defaults to ``"company"``, the only
            node type the model produces output for in this graph.

    Returns:
        ``[train_acc, val_acc, test_acc]`` as Python floats. An empty mask
        yields ``nan`` (0 / 0).
    """
    model.eval()
    # The model returns a dict keyed by node type; take predictions for `target`.
    pred = model(data.x_dict, data.edge_index_dict)[target].argmax(dim=-1)
    node_store = data[target]

    accs = []
    for split in ["train_mask", "val_mask", "test_mask"]:
        mask = node_store[split]
        acc = (pred[mask] == node_store.y[mask]).sum() / mask.sum()
        accs.append(float(acc))
    return accs


# Full-batch training: one optimisation step per epoch, followed by an
# evaluation pass over every split.
for epoch in range(1, 100 + 1):
    loss = train()
    train_acc, val_acc, test_acc = test()
    message = (
        f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, "
        f"Val: {val_acc:.4f}, Test: {test_acc:.4f}"
    )
    print(message)



IndexError: index out of range in self