<a href="https://colab.research.google.com/github/batu-el/l65_be301_dc755/blob/main/Notebook4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Installs

In [1]:
!pip install dgl torch_geometric torch

# Install required python libraries
import os

# Install PyTorch Geometric and other libraries
if 'IS_GRADESCOPE_ENV' not in os.environ:
    print("Installing PyTorch Geometric")
    !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
    !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
    !pip install -q torch-geometric
    print("Installing other libraries")
    !pip install networkx
    !pip install lovely-tensors

Installing PyTorch Geometric
Installing other libraries


In [157]:
import os
import sys
import time
import math
import random
import itertools
from datetime import datetime
from typing import Mapping, Tuple, Sequence, List

import pandas as pd
import networkx as nx
import numpy as np
import scipy as sp

from tqdm.notebook import tqdm

import torch
import torch.nn.functional as F
from torch.nn import Embedding, Linear, ReLU, BatchNorm1d, LayerNorm, Module, ModuleList, Sequential
from torch.nn import TransformerEncoder, TransformerEncoderLayer, MultiheadAttention
from torch.optim import Adam

import torch_geometric
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import Planetoid

import torch_geometric.transforms as T
from torch_geometric.utils import remove_self_loops, dense_to_sparse, to_dense_batch, to_dense_adj

from torch_geometric.nn import GCNConv, GATConv, GATv2Conv

from torch_scatter import scatter, scatter_mean, scatter_max, scatter_sum

import lovely_tensors as lt
lt.monkey_patch()

import matplotlib.pyplot as plt
import seaborn as sns

# import warnings
# warnings.filterwarnings("ignore", category=RuntimeWarning)
# warnings.filterwarnings("ignore", category=UserWarning)
# warnings.filterwarnings("ignore", category=FutureWarning)

print("All imports succeeded.")
print("Python version {}".format(sys.version))
print("PyTorch version {}".format(torch.__version__))
print("PyG version {}".format(torch_geometric.__version__))

All imports succeeded.
Python version 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
PyTorch version 2.1.0+cu121
PyG version 2.5.0


# Overview

In [3]:
# ## Outline ###

# STEP 1. - Datasets

# 1.1 Synthetic Datasets
# 1.1.1 Homophilic Node Classification
# 1.1.2 Heterophilic Node Classification
# 1.1.3 Homophilic Graph Classification
# 1.1.4 Heterophilic Graph Classification

# 1.2 Real Datasets
# 1.2.1 Homophilic Node Classification - Cora
# 1.2.2 Heterophilic Node Classification - Texas
# 1.2.3 Homophilic Graph Classification - QM9
# 1.2.4 Heterophilic Graph Classification - (?)

# STEP 2. Models

# 2.1 Baselines to Compare Model Accuracies
# 2.1.1 GCN
# 2.1.2 Sparse Transformer
# 2.1.3 MPNN
# 2.1.4 Dense Transformer with Attention Mask
# 2.1.5 Dense Transformer with Positional Encodings

# 2.2 Comparison of 2 Models: Dense (w/ PosEnc) & Sparse Transformer
# 2.2.1 1 Head 1 Layer
# 2.2.1 4 Head 1 Layer
# 2.2.1 1 Head 3 Layer
# 2.2.1 4 Head 3 Layer

# STEP 3. Evaluation

# Comparisons:
# A: Adjacency vs Sparse Attention
# B: Adjacency vs Dense Attention
# C: Sparse Attention vs Dense Attention

# 3.1 Combining Multiple Attention Matrices from 2.2
# 3.1.1 If Edge Exists
# 3.1.2 PCA

# 3.2 1D (Vector) Similarity Comparison
# 3.2.1 Node Degree (histogram)
# 3.2.2 Substructures (histogram)

# 3.3 2D (Matrix) Similarity Comparison
# 3.3.1 Adjacency Matrix (Graph Edit Dist & Kernel 1 WL)
# 3.3.2 Shortest Path (Graph Edit Dist & Kernel 1 WL)

# STEP 4. Discussion
# Note: Future research can look at how attention evolves over the course of training


# Synthetic Dataset Generation

In [4]:
import torch
from torch_geometric.data import Data
import numpy as np

def preprocess(data, train_ratio = 0.7, val_ratio = 0.15, test_ratio = 0.15):
    g = dataset[0]
    y = g.ndata['label']
    feat = g.ndata['feat']

    num_nodes = len(y)
    indices = torch.randperm(num_nodes)

    num_train, num_val = int(num_nodes * train_ratio), int(num_nodes * val_ratio)
    num_test = num_nodes - num_train - num_val

    train_mask, val_mask, test_mask = torch.zeros(num_nodes, dtype=torch.bool), torch.zeros(num_nodes, dtype=torch.bool), torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[indices[:num_train]] = True
    val_mask[indices[num_train:num_train+num_val]] = True
    test_mask[indices[num_train+num_val:]] = True

    # Convert NetworkX graph to edge list
    src, dst = g.edges()
    edge_list = list(zip(src.tolist(), dst.tolist()))
    # Create a set for symmetric edges to avoid duplicates
    symmetric_edges = set()

    # Add each edge and its reverse to the set
    for u, v in edge_list:
        symmetric_edges.add((u, v))
        symmetric_edges.add((v, u))
    edge_list = list(symmetric_edges)

    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    node_features = feat #[g.nodes[node]['feat'] for node in G.nodes()]
    # Create a Data object
    if len(np.array(node_features).shape) == 1:
      data = Data(x=torch.tensor(np.array(node_features)).unsqueeze(1), edge_index=torch.tensor(np.array(edge_index)), y=torch.tensor(np.array(y)), train_mask=torch.tensor(np.array(train_mask)), val_mask=torch.tensor(np.array(val_mask)), test_mask=torch.tensor(np.array(test_mask)))
    else:
      data = Data(x=torch.tensor(np.array(node_features, dtype=float)).float(), edge_index=torch.tensor(np.array(edge_index)), y=torch.tensor(np.array(y)), train_mask=torch.tensor(np.array(train_mask)), val_mask=torch.tensor(np.array(val_mask)), test_mask=torch.tensor(np.array(test_mask)))
    return data

In [197]:
"""Synthetic graph datasets."""
import math
import networkx as nx
import numpy as np
import os
import pickle
import random

from dgl.data.dgl_dataset import DGLBuiltinDataset
from dgl.data.utils import save_graphs, load_graphs, _get_dgl_url, download
from dgl import backend as F
from dgl.batch import batch
from dgl.convert import graph
from dgl.transforms import reorder_graph

class BAShapeDataset(DGLBuiltinDataset):
    r"""BA-SHAPES dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks
    <https://arxiv.org/abs/1903.03894>`__

    This is a synthetic dataset for node classification. It is generated by performing the
    following steps in order.

    - Construct a base Barabási–Albert (BA) graph.
    - Construct a set of five-node house-structured network motifs.
    - Attach the motifs to randomly selected nodes of the base graph.
    - Perturb the graph by adding random edges.
    - Nodes are assigned to 4 classes. Nodes of label 0 belong to the base BA graph. Nodes of
      label 1, 2, 3 are separately at the middle, bottom, or top of houses.
    - Generate constant feature for all nodes, which is 1.

    Parameters
    ----------
    num_base_nodes : int, optional
        Number of nodes in the base BA graph. Default: 300
    num_base_edges_per_node : int, optional
        Number of edges to attach from a new node to existing nodes in constructing the base BA
        graph. Default: 5
    num_motifs : int, optional
        Number of house-structured network motifs to use. Default: 80
    perturb_ratio : float, optional
        Number of random edges to add in perturbation divided by the number of edges in the
        original graph. Default: 0.01
    seed : integer, random_state, or None, optional
        Indicator of random number generation state. Default: None
    raw_dir : str, optional
        Raw file directory to store the processed data. Default: ~/.dgl/
    force_reload : bool, optional
        Whether to always generate the data from scratch rather than load a cached version.
        Default: False
    verbose : bool, optional
        Whether to print progress information. Default: True
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access. Default: None

    Attributes
    ----------
    num_classes : int
        Number of node classes

    Examples
    --------

    >>> from dgl.data import BAShapeDataset
    >>> dataset = BAShapeDataset()
    >>> dataset.num_classes
    4
    >>> g = dataset[0]
    >>> label = g.ndata['label']
    >>> feat = g.ndata['feat']
    """
    def __init__(self,
                 num_base_nodes=300,
                 num_base_edges_per_node=5,
                 num_motifs=80,
                 perturb_ratio=0.01,
                 seed=None,
                 raw_dir=None,
                 force_reload=False,
                 verbose=True,
                 transform=None):
        self.num_base_nodes = num_base_nodes
        self.num_base_edges_per_node = num_base_edges_per_node
        self.num_motifs = num_motifs
        self.perturb_ratio = perturb_ratio
        self.seed = seed
        super(BAShapeDataset, self).__init__(name='BA-SHAPES',
                                             url=None,
                                             raw_dir=raw_dir,
                                             force_reload=force_reload,
                                             verbose=verbose,
                                             transform=transform)

    def process(self):
        g = nx.barabasi_albert_graph(self.num_base_nodes, self.num_base_edges_per_node, self.seed)
        edges = list(g.edges())
        src, dst = map(list, zip(*edges))
        n = self.num_base_nodes


        # Nodes in the base BA graph belong to class 0
        node_labels = [0] * n
        # The motifs will be evenly attached to the nodes in the base graph.
        spacing = math.floor(n / self.num_motifs)

        ##########################
        # for motif_id in range(self.num_motifs):
        #     # Construct a five-node house-structured network motif
        #     motif_edges = [
        #         (n, n + 1),
        #         (n + 1, n + 2),
        #         (n + 2, n + 3),
        #         (n + 3, n),
        #         (n + 4, n),
        #         (n + 4, n + 1)
        #     ]
        #     motif_src, motif_dst = map(list, zip(*motif_edges))
        #     src.extend(motif_src)
        #     dst.extend(motif_dst)

        #     # Nodes at the middle of a house belong to class 1
        #     # Nodes at the bottom of a house belong to class 2
        #     # Nodes at the top of a house belong to class 3
        #     node_labels.extend([1, 1, 2, 2, 3])
        #     # node_labels.extend([1, 1, 1, 1, 1])

        #     # Attach the motif to the base BA graph
        #     src.append(n)
        #     dst.append(int(motif_id * spacing))
        #     n += 5
        # g = graph((src, dst), num_nodes=n)
        ##########################


        ##########################
        # Construct an n-by-n grid
        self.grid_size = 3
        motif_g = nx.grid_graph([self.grid_size, self.grid_size])
        grid_size = nx.number_of_nodes(motif_g)
        motif_g = nx.convert_node_labels_to_integers(motif_g, first_label=0)
        motif_edges = list(motif_g.edges())
        motif_src, motif_dst = map(list, zip(*motif_edges))
        motif_src, motif_dst = np.array(motif_src), np.array(motif_dst)

        print(motif_edges)
        for motif_id in range(self.num_motifs):
            src.extend((motif_src + n).tolist())
            dst.extend((motif_dst + n).tolist())
            # Nodes in grids belong to class 1
            node_labels.extend([1,1,1,1,2,1,1,1,1])
            # Attach the motif to the base tree graph
            src.append(n)
            dst.append(int(motif_id * spacing))

            n += grid_size
        g = graph((src, dst), num_nodes=n)
        ############################

        # Perturb the graph by adding non-self-loop random edges
        num_real_edges = g.num_edges()
        max_ratio = (n * (n - 1) - num_real_edges) / num_real_edges
        assert self.perturb_ratio <= max_ratio, \
            'perturb_ratio cannot exceed {:.4f}'.format(max_ratio)
        num_random_edges = int(num_real_edges * self.perturb_ratio)

        if self.seed is not None:
            np.random.seed(self.seed)
        for _ in range(num_random_edges):
            while True:
                u = np.random.randint(0, n)
                v = np.random.randint(0, n)
                if (not g.has_edges_between(u, v)) and (u != v):
                    break
            g.add_edges(u, v)

        g.ndata['label'] = F.tensor(node_labels, F.int64)
        g.ndata['feat'] = F.ones((n, 1), F.float32, F.cpu())
        self._graph = reorder_graph(
            g, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False)

    @property
    def graph_path(self):
        return os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.name))

    def save(self):
        save_graphs(str(self.graph_path), self._graph)

    def has_cache(self):
        return os.path.exists(self.graph_path)

    def load(self):
        graphs, _ = load_graphs(str(self.graph_path))
        self._graph = graphs[0]

    def __getitem__(self, idx):
        assert idx == 0, "This dataset has only one graph."
        if self._transform is None:
            return self._graph
        else:
            return self._transform(self._graph)


    def __len__(self):
        return 1


    @property
    def num_classes(self):
        return 4


class BACommunityDataset(DGLBuiltinDataset):
    r"""BA-COMMUNITY dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks
    <https://arxiv.org/abs/1903.03894>`__

    This is a synthetic dataset for node classification. It is generated by performing the
    following steps in order.

    - Construct a base Barabási–Albert (BA) graph.
    - Construct a set of five-node house-structured network motifs.
    - Attach the motifs to randomly selected nodes of the base graph.
    - Perturb the graph by adding random edges.
    - Nodes are assigned to 4 classes. Nodes of label 0 belong to the base BA graph. Nodes of
      label 1, 2, 3 are separately at the middle, bottom, or top of houses.
    - Generate normally distributed features of length 10
    - Repeat the above steps to generate another graph. Its nodes are assigned to class
      4, 5, 6, 7. Its node features are generated with a distinct normal distribution.
    - Join the two graphs by randomly adding edges between them.

    Parameters
    ----------
    num_base_nodes : int, optional
        Number of nodes in each base BA graph. Default: 300
    num_base_edges_per_node : int, optional
        Number of edges to attach from a new node to existing nodes in constructing a base BA
        graph. Default: 4
    num_motifs : int, optional
        Number of house-structured network motifs to use in constructing each graph. Default: 80
    perturb_ratio : float, optional
        Number of random edges to add to a graph in perturbation divided by the number of original
        edges in it. Default: 0.01
    num_inter_edges : int, optional
        Number of random edges to add between the two graphs. Default: 350
    seed : integer, random_state, or None, optional
        Indicator of random number generation state. Default: None
    raw_dir : str, optional
        Raw file directory to store the processed data. Default: ~/.dgl/
    force_reload : bool, optional
        Whether to always generate the data from scratch rather than load a cached version.
        Default: False
    verbose : bool, optional
        Whether to print progress information. Default: True
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access. Default: None

    Attributes
    ----------
    num_classes : int
        Number of node classes

    Examples
    --------

    >>> from dgl.data import BACommunityDataset
    >>> dataset = BACommunityDataset()
    >>> dataset.num_classes
    8
    >>> g = dataset[0]
    >>> label = g.ndata['label']
    >>> feat = g.ndata['feat']
    """
    def __init__(self,
                 num_base_nodes=300,
                 num_base_edges_per_node=4,
                 num_motifs=80,
                 perturb_ratio=0.01,
                 num_inter_edges=350,
                 seed=None,
                 raw_dir=None,
                 force_reload=False,
                 verbose=True,
                 transform=None):
        self.num_base_nodes = num_base_nodes
        self.num_base_edges_per_node = num_base_edges_per_node
        self.num_motifs = num_motifs
        self.perturb_ratio = perturb_ratio
        self.num_inter_edges = num_inter_edges
        self.seed = seed
        super(BACommunityDataset, self).__init__(name='BA-COMMUNITY',
                                                 url=None,
                                                 raw_dir=raw_dir,
                                                 force_reload=force_reload,
                                                 verbose=verbose,
                                                 transform=transform)

    def process(self):
        if self.seed is not None:
            random.seed(self.seed)
            np.random.seed(self.seed)

        # Construct two BA-SHAPES graphs
        g1 = BAShapeDataset(self.num_base_nodes,
                            self.num_base_edges_per_node,
                            self.num_motifs,
                            self.perturb_ratio,
                            force_reload=True,
                            verbose=False)[0]
        g2 = BAShapeDataset(self.num_base_nodes,
                            self.num_base_edges_per_node,
                            self.num_motifs,
                            self.perturb_ratio,
                            force_reload=True,
                            verbose=False)[0]

        # Join them and randomly add edges between them
        g = batch([g1, g2])
        num_nodes = g.num_nodes() // 2
        src = np.random.randint(0, num_nodes, (self.num_inter_edges,))
        dst = np.random.randint(num_nodes, 2 * num_nodes, (self.num_inter_edges,))
        src = F.astype(F.zerocopy_from_numpy(src), g.idtype)
        dst = F.astype(F.zerocopy_from_numpy(dst), g.idtype)
        g.add_edges(src, dst)
        # print(g1.ndata['label'] )
        # print(g2.ndata['label'] + 2)
        # print(torch.tensor(np.concatenate([g1.ndata['label'], g2.ndata['label'] + 2])))
        g.ndata['label'] = F.cat([g1.ndata['label'], g2.ndata['label'] + 3], dim=0)

        # feature generation
        random_mu = [0.0] * 8
        random_sigma = [1.0] * 8

        mu_1, sigma_1 = np.array([-1.0] * 2 + random_mu), np.array([0.5] * 2 + random_sigma)
        feat1 = np.random.multivariate_normal(mu_1, np.diag(sigma_1), num_nodes)

        mu_2, sigma_2 = np.array([1.0] * 2 + random_mu), np.array([0.5] * 2 + random_sigma)
        feat2 = np.random.multivariate_normal(mu_2, np.diag(sigma_2), num_nodes)

        feat = np.concatenate([feat1, feat2])
        g.ndata['feat'] = F.zerocopy_from_numpy(feat)
        self._graph = reorder_graph(
            g, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False)

    @property
    def graph_path(self):
        return os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.name))

    def save(self):
        save_graphs(str(self.graph_path), self._graph)

    def has_cache(self):
        return os.path.exists(self.graph_path)

    def load(self):
        graphs, _ = load_graphs(str(self.graph_path))
        self._graph = graphs[0]

    def __getitem__(self, idx):
        assert idx == 0, "This dataset has only one graph."
        if self._transform is None:
            return self._graph
        else:
            return self._transform(self._graph)
    def __len__(self):
        return 1


    @property
    def num_classes(self):
        return 8

In [191]:
dataset = BAShapeDataset(num_base_nodes=160,
                             num_base_edges_per_node=1,
                             num_motifs=int(160/5),
                             perturb_ratio=0.00,
                             seed=None,
                             raw_dir=None,
                             force_reload=True,
                             verbose=True,
                             transform=None)
data = preprocess(dataset)
data.x = torch.tensor(np.concatenate([data.x]*8, axis=1))
# np.mean(data.y.numpy())
import pandas as pd
pd.DataFrame(data.y.cpu().numpy()).value_counts(), data


[(0, 3), (0, 1), (1, 4), (1, 2), (2, 5), (3, 6), (3, 4), (4, 7), (4, 5), (5, 8), (6, 7), (7, 8)]
Done saving data into cached files.


(1    256
 0    160
 2     32
 dtype: int64,
 Data(x=[448, 8], edge_index=[2, 1150], y=[448], train_mask=[448], val_mask=[448], test_mask=[448]))

In [198]:
dataset = BACommunityDataset(num_base_nodes=160,
                             num_base_edges_per_node=4,
                             num_motifs=80,
                             perturb_ratio=0.00,
                             num_inter_edges=1000,
                             seed=None,
                             raw_dir=None,
                             force_reload=True,
                             verbose=True,
                             transform=None)
data = preprocess(dataset)

[(0, 3), (0, 1), (1, 4), (1, 2), (2, 5), (3, 6), (3, 4), (4, 7), (4, 5), (5, 8), (6, 7), (7, 8)]
[(0, 3), (0, 1), (1, 4), (1, 2), (2, 5), (3, 6), (3, 4), (4, 7), (4, 5), (5, 8), (6, 7), (7, 8)]
Done saving data into cached files.


# GNNModel

In [201]:
# PyG example code: https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gcn2_cora.py
import torch.nn.functional as F
class GNNModel(Module):

    def __init__(
            self,
            in_dim: int = data.x.shape[-1],
            hidden_dim: int = 128,
            num_heads: int = 1,
            num_layers: int = 1,
            out_dim: int = len(data.y.unique()),
            dropout: float = 0.5,
        ):
        super().__init__()

        self.lin_in = Linear(in_dim, hidden_dim)
        self.lin_out = Linear(hidden_dim, out_dim)

        self.layers = ModuleList()
        for layer in range(num_layers):
            self.layers.append(
                GCNConv(hidden_dim, hidden_dim)
                # GATConv(hidden_dim, hidden_dim // num_heads, num_heads)
                # GATv2Conv(hidden_dim, hidden_dim // num_heads, num_heads)

            )
        self.dropout = dropout

    def forward(self, x, edge_index):
        x = self.lin_in(x)

        for layer in self.layers:
            # conv -> activation ->  dropout -> residual
            x_in = x
            x = layer(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)
            x = x_in + x

        x = self.lin_out(x)

        return x.log_softmax(dim=-1)

# Train GCN

In [202]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GNNModel(num_heads=1, num_layers=4).to(device)

data = data.to(device)


optimizer = torch.optim.AdamW(model.parameters(), lr=0.001,  weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.5)


def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return float(loss)


import torch

@torch.no_grad()
def test():
    model.eval()
    pred = model(data.x, data.edge_index).argmax(dim=-1)
    class_correct = torch.zeros(data.y.max() + 1)
    class_total = torch.zeros(data.y.max() + 1)

    class_accs = {}
    for mask_name, mask in data('train_mask', 'val_mask', 'test_mask'):
        mask_pred = pred[mask]
        mask_true = data.y[mask]

        for i in range(data.y.max() + 1):
            class_total[i] += (mask_true == i).sum().item()
            class_correct[i] += ((mask_pred == i) & (mask_true == i)).sum().item()
        class_accs[mask_name] = list(class_correct / class_total)
    return class_accs

best_val_acc = [0] * (data.y.max() + 1)
test_acc = [0] * (data.y.max() + 1)
times = []

num_epochs = 40000
for epoch in range(1, num_epochs + 1):
    start = time.time()
    loss = train()
    if (epoch % 200 == 0 or epoch == num_epochs):
        print("Epoch: ", epoch, " class accuracies: ", test())

    # train_accs, val_accs, tmp_test_acc = test()
    # Update the best validation and test accuracy
    # for i, (val_acc, test_acc) in enumerate(zip(val_accs, test_accs)):
    #     if val_acc > best_val_acc[i]:
    #         best_val_acc[i] = val_acc
    #         test_acc[i] = test_acc

    # print(f'Epoch: {epoch:04d}, Loss: {loss:.4f}')
    # for i, (train_acc, val_acc, tmp_test_acc, best_test_acc) in enumerate(zip(train_accs, val_accs, test_accs, test_acc)):
    #     print(f'Class {i}: Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {tmp_test_acc:.4f}, Best Test: {best_test_acc:.4f}')

    times.append(time.time() - start)
    scheduler.step()

# Print the median time per epoch
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")

Epoch:  200  class accuracies:  {'train_mask': [tensor 1.000, tensor 0.993, tensor 0.119, tensor 0.991, tensor 0.986, tensor 0.121], 'val_mask': [tensor 0.993, tensor 0.984, tensor 0.110, tensor 0.978, tensor 0.985, tensor 0.118], 'test_mask': [tensor 0.988, tensor 0.980, tensor 0.100, tensor 0.950, tensor 0.986, tensor 0.100]}
Epoch:  400  class accuracies:  {'train_mask': [tensor 1.000, tensor 0.996, tensor 0.525, tensor 1.000, tensor 1.000, tensor 0.517], 'val_mask': [tensor 0.985, tensor 0.982, tensor 0.452, tensor 0.985, tensor 0.994, tensor 0.471], 'test_mask': [tensor 0.981, tensor 0.978, tensor 0.425, tensor 0.975, tensor 0.994, tensor 0.400]}
Epoch:  600  class accuracies:  {'train_mask': [tensor 1.000, tensor 0.998, tensor 0.864, tensor 1.000, tensor 0.998, tensor 0.948], 'val_mask': [tensor 0.993, tensor 0.982, tensor 0.740, tensor 0.985, tensor 0.991, tensor 0.868], 'test_mask': [tensor 0.994, tensor 0.975, tensor 0.688, tensor 0.975, tensor 0.989, tensor 0.738]}
Epoch:  80

KeyboardInterrupt: 

# Analysis

In [211]:
# PyG example code: https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gcn2_cora.py

class GNNModel(Module):

    def __init__(
            self,
            in_dim: int = data.x.shape[-1],
            hidden_dim: int = 128,
            num_heads: int = 1,
            num_layers: int = 1,
            out_dim: int = len(data.y.unique()),
            dropout: float = 0.5,
        ):
        super().__init__()

        self.lin_in = Linear(in_dim, hidden_dim)
        self.lin_out = Linear(hidden_dim, out_dim)

        self.layers = ModuleList()
        for layer in range(num_layers):
            self.layers.append(
                # GCNConv(hidden_dim, hidden_dim)
                GATConv(hidden_dim, hidden_dim // num_heads, num_heads)
            )
        self.dropout = dropout

    def forward(self, x, edge_index):

        x = self.lin_in(x)

        for layer in self.layers:
            # conv -> activation ->  dropout -> residual
            x_in = x
            x = layer(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)
            x = x_in + x

        x = self.lin_out(x)

        return x.log_softmax(dim=-1)


class SparseGraphTransformerModel(Module):
    def __init__(
            self,
            in_dim: int = data.x.shape[-1],
            hidden_dim: int = 128,
            num_heads: int = 1,
            num_layers: int = 1,
            out_dim: int = len(data.y.unique()),
            dropout: float = 0.5,
        ):
        super().__init__()

        self.lin_in = Linear(in_dim, hidden_dim)
        self.lin_out = Linear(hidden_dim, out_dim)

        self.layers = ModuleList()
        for layer in range(num_layers):
            self.layers.append(
                MultiheadAttention(
                    embed_dim = hidden_dim,
                    num_heads = num_heads,
                    dropout = dropout
                )
            )
        self.dropout = dropout

    def forward(self, x, dense_adj):

        x = self.lin_in(x)

        # TransformerEncoder
        # x = self.encoder(x, mask = ~dense_adj.bool())

        self.attn_weights_list = []

        for layer in self.layers:
            # # TransformerEncoderLayer
            # # boolean mask enforces graph structure
            # x = layer(x, src_mask = ~dense_adj.bool())

            # MHSA layer
            # boolean mask enforces graph structure
            x_in = x
            x, attn_weights = layer(
                x, x, x,
                attn_mask = ~dense_adj.bool(),
                average_attn_weights = False
            )
            x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)
            x = x_in + x

            self.attn_weights_list.append(attn_weights)

        x = self.lin_out(x)

        return x.log_softmax(dim=-1)

class DenseGraphTransformerModel(Module):

    def __init__(
            self,
            in_dim: int = data.x.shape[-1],
            pos_enc_dim: int = 16,
            hidden_dim: int = 128,
            num_heads: int = 1,
            num_layers: int = 1,
            out_dim: int = len(data.y.unique()),
            dropout: float = 0.5,
        ):
        super().__init__()

        self.lin_in = Linear(in_dim, hidden_dim)
        self.lin_pos_enc = Linear(pos_enc_dim, hidden_dim)
        self.lin_out = Linear(hidden_dim, out_dim)

        self.layers = ModuleList()
        for layer in range(num_layers):
            self.layers.append(
                MultiheadAttention(
                    embed_dim = hidden_dim,
                    num_heads = num_heads,
                    dropout = dropout
                )
            )


        self.attn_bias_scale = torch.nn.Parameter(torch.tensor([10.0]))  # controls how much we initially bias our model to nearby nodes
        self.dropout = dropout

    def forward(self, x, pos_enc, dense_sp_matrix):

        # x = self.lin_in(x) + self.lin_pos_enc(pos_enc)
        x = self.lin_in(x)  # no node positional encoding

        # attention bias
        # [i, j] -> inverse of shortest path distance b/w node i and j
        # diagonals -> self connection, set to 0
        # disconnected nodes -> -1
        attn_bias = self.attn_bias_scale * torch.nan_to_num(
            (1 / (torch.nan_to_num(dense_sp_matrix, nan=-1, posinf=-1, neginf=-1))),
            nan=0, posinf=0, neginf=0
        )
        #attn_bias = torch.ones_like(attn_bias)

        # TransformerEncoder
        # x = self.encoder(x, mask = attn_bias)

        self.attn_weights_list = []

        for layer in self.layers:
            # # TransformerEncoderLayer
            # # float mask adds learnable additive attention bias
            # x = layer(x, src_mask = attn_bias)

            # MHSA layer
            # float mask adds learnable additive attention bias
            x_in = x
            x, attn_weights = layer(
                x, x, x,
                attn_mask = attn_bias,
                average_attn_weights = False
            )
            x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)
            x = x_in + x

            self.attn_weights_list.append(attn_weights)

        x = self.lin_out(x)

        return x.log_softmax(dim=-1)



class DenseGraphTransformerModel_V2(Module):

    def __init__(
            self,
            in_dim: int = data.x.shape[-1],
            pos_enc_dim: int = 16,
            hidden_dim: int = 128,
            num_heads: int = 1,
            num_layers: int = 1,
            out_dim: int = len(data.y.unique()),
            dropout: float = 0.5,
        ):
        super().__init__()

        self.lin_in = Linear(in_dim, hidden_dim)
        self.lin_pos_enc = Linear(pos_enc_dim, hidden_dim)
        self.lin_out = Linear(hidden_dim, out_dim)

        self.layers = ModuleList()
        for layer in range(num_layers):
            self.layers.append(
                MultiheadAttention(
                    embed_dim = hidden_dim,
                    num_heads = num_heads,
                    dropout = dropout
                )
            )


        self.attn_bias_scale = torch.nn.Parameter(torch.tensor([10.0]))  # controls how much we initially bias our model to nearby nodes
        self.dropout = dropout

    def forward(self, x, pos_enc, dense_sp_matrix):

        x = self.lin_in(x) + self.lin_pos_enc(pos_enc)
        # x = self.lin_in(x)  # no node positional encoding

        # attention bias
        # [i, j] -> inverse of shortest path distance b/w node i and j
        # diagonals -> self connection, set to 0
        # disconnected nodes -> -1
        # attn_bias = self.attn_bias_scale * torch.nan_to_num(
        #     (1 / (torch.nan_to_num(dense_sp_matrix, nan=-1, posinf=-1, neginf=-1))),
        #     nan=0, posinf=0, neginf=0
        # )
        #attn_bias = torch.ones_like(attn_bias)

        # TransformerEncoder
        # x = self.encoder(x, mask = attn_bias)

        self.attn_weights_list = []

        for layer in self.layers:
            # # TransformerEncoderLayer
            # # float mask adds learnable additive attention bias
            # x = layer(x, src_mask = attn_bias)

            # MHSA layer
            # float mask adds learnable additive attention bias
            x_in = x
            x, attn_weights = layer(
                x, x, x,
                # attn_mask = attn_bias,
                average_attn_weights = False
            )
            x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)
            x = x_in + x

            self.attn_weights_list.append(attn_weights)

        x = self.lin_out(x)

        return x.log_softmax(dim=-1)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = DenseGraphTransformerModel_V2(num_heads=1, num_layers=3).to(device)

data = T.AddLaplacianEigenvectorPE(k = 16, attr_name = 'pos_enc')(data)
# data = T.AddRandomWalkPE(walk_length = 16, attr_name = 'pos_enc')(data)
data.dense_adj = to_dense_adj(data.edge_index, max_num_nodes = data.x.shape[0])[0]
# data.dense_sp_matrix = dense_shortest_path_matrix.float()  # pre-computed in previous cell
data = data.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001,  weight_decay=1e-4)


def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.pos_enc, 0)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test():
    model.eval()
    pred, accs = model(data.x, data.pos_enc, 0).argmax(dim=-1), []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
    return accs


best_val_acc = test_acc = 0
times = []
for epoch in range(1, 10000):
    start = time.time()
    loss = train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    print(f'Epoch: {epoch:04d}, Loss: {loss:.4f} Train: {train_acc:.4f}, '
          f'Val: {val_acc:.4f}, Test: {tmp_test_acc:.4f}, '
          f'Final Test: {test_acc:.4f}')
    times.append(time.time() - start)
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")

# Notes
# - Dense Transformer needs to be trained for a bit longer to reach low loss value
# - Node positional encodings are not particularly useful
# - Edge distance encodings are very useful
# - Since Cora is highly homophilic, it is important to bias the attention towards nearby nodes

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch: 0609, Loss: 0.3887 Train: 0.8701, Val: 0.7689, Test: 0.7727, Final Test: 0.8182
Epoch: 0610, Loss: 0.3932 Train: 0.8628, Val: 0.7614, Test: 0.7689, Final Test: 0.8182
Epoch: 0611, Loss: 0.3950 Train: 0.8628, Val: 0.7614, Test: 0.7879, Final Test: 0.8182
Epoch: 0612, Loss: 0.3763 Train: 0.8636, Val: 0.7765, Test: 0.7765, Final Test: 0.8182
Epoch: 0613, Loss: 0.3635 Train: 0.8580, Val: 0.7689, Test: 0.7727, Final Test: 0.8182
Epoch: 0614, Loss: 0.3637 Train: 0.8539, Val: 0.7727, Test: 0.7917, Final Test: 0.8182
Epoch: 0615, Loss: 0.3745 Train: 0.8547, Val: 0.7803, Test: 0.7955, Final Test: 0.8182
Epoch: 0616, Loss: 0.3624 Train: 0.8636, Val: 0.7841, Test: 0.7841, Final Test: 0.8182
Epoch: 0617, Loss: 0.3857 Train: 0.8644, Val: 0.7803, Test: 0.7879, Final Test: 0.8182
Epoch: 0618, Loss: 0.3924 Train: 0.8677, Val: 0.7727, Test: 0.7765, Final Test: 0.8182
Epoch: 0619, Loss: 0.3742 Train: 0.8604, Val: 0.7538, Test: 0.776