<a href="https://colab.research.google.com/github/jupiterepoch/SGC/blob/master/cs224w_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Final Project for CS224W: Machine Learning with Graphs
Chenshu Zhu\*, Xinglong Sun\*, Ziang Liu\*
{chenshu, xs15, ziangliu}@stanford.edu

\* equal contributions

# Part 1: Getting started with SGC

TODO: math formulation of SGC and a picture of the model architecture

In this part we replicate the results with the official code released by the SGC authors. Then we go on to explore the limitations of the SGC formulation, and experiment with different techniques to mitigate these limitations.



### Installation
This combines the environment dependencies for both part1 and part2 to ensure a hassle-free code running process through this notebook.

In [1]:
# Install required packages
! pip install torch-scatter -f https://data.pyg.org/whl/torch-1.13.1+cu116.html
! pip install torch-sparse -f https://data.pyg.org/whl/torch-1.13.1+cu116.html
! pip install torch-geometric
! pip install ogb
! pip install networkx

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.13.1+cu116.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_scatter-2.1.1%2Bpt113cu116-cp39-cp39-linux_x86_64.whl (9.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.4/9.4 MB[0m [31m61.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-scatter
Successfully installed torch-scatter-2.1.1+pt113cu116
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.13.1+cu116.html
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_sparse-0.6.17%2Bpt113cu116-cp39-cp39-linux_x86_64.whl (4.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.5/4.5 MB[0m [31m42.7 MB/s[0m eta [36m0:00:00[0m
Installing collect

Please run the following 3 code blocks even if you just want to play with Part 2

In [2]:
# seed the notebook for more stability
import numpy as np
import random
import torch_geometric
import torch
def set_seed(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch_geometric.seed_everything(seed)
set_seed(0)

In [3]:
class objectview(object):
    def __init__(self, d):
        self.__dict__ = d

In [4]:
assert torch.cuda.is_available(), print('please attach to a GPU runtime')
device = torch.device('cuda')

### A GPU runtime is required for the notebook

##TODO:

You also need to mount the google drive and have a data folder to run the first part

### Import required packages for Part 1

In [39]:
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import scipy.sparse as sp
import numpy as np
import pickle as pkl
import networkx as nx
import sys
from time import perf_counter
from copy import deepcopy
import random
from torch_geometric.nn import SGConv
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.conv import MessagePassing

### Let's start by defining the SGC model 

The SGC model can be viewed as a feature encoder. Given a graph G, SGC conducts a k-round neighborhood aggregation of initial features, thus encoding the graph structure into the aggregated node embeddings. After this embedding is extracted, any decoder network can be trained on it for downstream tasks. The essential part of the SGC thus lies in its encoder, or "precompute" stage as named by the authors.

In [4]:
def sgc_precompute(features, adj, degree):
    """
    The most basic SGC model
    """
    t = perf_counter()
    for i in range(degree):
        features = torch.spmm(adj, features)
    precompute_time = perf_counter()-t
    return features, precompute_time

def sgc_precompute_concat(features, adj, degree):
    """
    Augmented SGC model with residual connections
    """
    t = perf_counter()
    total = [deepcopy(features)]
    for i in range(degree):
        features = torch.spmm(adj, features)
        total.append(deepcopy(features))
    precompute_time = perf_counter()-t
    total = torch.cat(total, dim=-1)
    return total, precompute_time

def aug_normalized_adjacency(adj):
    """
    Normalize the adjacency matrix as defined in SGC paper
    """
    adj = adj + sp.eye(adj.shape[0])
    adj = sp.coo_matrix(adj)
    row_sum = np.array(adj.sum(1))
    d_inv_sqrt = np.power(row_sum, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    return d_mat_inv_sqrt.dot(adj).dot(d_mat_inv_sqrt).tocoo()

Then we move on to building the decoders for the SGC model

In [40]:
class SGC(nn.Module):
    """
    A Simple PyTorch Implementation of Logistic Regression.
    Assuming the features have been preprocessed with k-step graph propagation.
    """
    def __init__(self, nfeat, nclass):
        super(SGC, self).__init__()

        self.W = nn.Linear(nfeat, nclass)

    def forward(self, x):
        return self.W(x)

class SGC_Big(nn.Module):
    """
    A Simple PyTorch Implementation of Logistic Regression.
    Assuming the features have been preprocessed with k-step graph propagation.
    """
    def __init__(self, nfeat, nclass, nhid):
        super(SGC_Big, self).__init__()
        self.W1 = nn.Linear(nfeat, nhid)
        self.W2 = nn.Linear(nhid, nclass)
        self.dropout = 0.9

    def forward(self, x, use_relu=True):
        x = self.W1(x)
        if use_relu:
            x = F.relu(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.W2(x)
        return x

class SGC_Big_CS(nn.Module):
    """
    A Simple PyTorch Implementation of Logistic Regression.
    Assuming the features have been preprocessed with k-step graph propagation.
    """
    def __init__(self, nfeat, nclass, nhid):
        super(SGC_Big, self).__init__()
        self.W1 = nn.Linear(nfeat, nhid)
        self.W2 = nn.Linear(nhid, nclass)
        self.dropout = 0.9

    def forward(self, x, use_relu=True):
        x = self.W1(x)
        if use_relu:
            x = F.relu(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.W2(x)
        return x

class LPAconv(MessagePassing):
    """
    LPA convolution operation.

    Args:
        x (Tensor): Input node features of shape [num_nodes, num_node_features].
        edge_index (LongTensor or SparseTensor): Graph edge indices of shape [2, num_edges].
        mask (bool Tensor, optional): Mask for selecting a subset of nodes to operate on.
        edge_weight (float Tensor, optional): Edge weights of shape [num_edges].
        post_step (function, optional): Post-processing function applied to output.
        
    Returns:
        Tensor: Output node features of shape [num_nodes, num_node_features].
    """

    def __init__(self, num_layers: int):
        super(LPAconv, self).__init__(aggr='add')
        self.num_layers = num_layers

    def forward(
            self, y, edge_index, mask = None,
            edge_weight = None,
            post_step = lambda y: y.clamp_(0., 1.)
    ):

        if y.dtype == torch.int64:
            y = F.one_hot(y.view(-1)).to(torch.float)

        out = y
        if mask is not None:
            out = torch.zeros_like(y)
            out[mask] = y[mask]

        if isinstance(edge_index, SparseTensor) and not edge_index.has_value():
            edge_index = gcn_norm(edge_index, add_self_loops=False)
        elif isinstance(edge_index, Tensor) and edge_weight is None:
            edge_index, edge_weight = gcn_norm(edge_index, num_nodes=y.size(0),
                                              add_self_loops=False)

        for _ in range(self.num_layers):
            # propagate_type: (y: Tensor, edge_weight: OptTensor)
            out = self.propagate(edge_index, x=out, edge_weight=edge_weight,
                                size=None)
        return out

class SGC_LPA(nn.Module):
    """
    A class representing the SGC-LPA model.

    Args:
        in_feature (int): Number of input features.
        hidden (int): Number of hidden features.
        out_feature (int): Number of output features.
        dropout (float): Dropout rate to use.
        num_edges (int): Number of edges in the graph.
        lpaiters (int): Number of LPA iterations to perform.
    """

    def __init__(self, in_feature, hidden, out_feature, dropout, num_edges, lpaiters):
        super(SGC_LPA, self).__init__()
        self.edge_weight = nn.Parameter(torch.ones(num_edges))
        self.conv1 = SGConv(in_feature, out_feature, K=2, cached=True)
        self.lpa = LPAconv(lpaiters)
        self.dropout_rate = dropout

    def forward(self, x, adj, y, mask):
        x, edge_index, y = data.x, data.edge_index, data.y
        x = self.conv1(x, adj, edge_weight=self.edge_weight)
        y_hat = self.lpa(y, edge_index, mask, self.edge_weight)

        return x, y_hat


In [41]:
def get_model(model_opt, nfeat, nclass, nhid=0, dropout=0, num_edges=None):
    """
    A helper function that chooses between different augmentations of SGC
    """

    if model_opt == "SGC":
        model = SGC(nfeat=nfeat, nclass=nclass)
    elif model_opt == "SGC-Concat":
        model = SGC_Big(nfeat=nfeat, nclass=nclass, nhid=nhid)
    elif model_opt == "SGC-LPA":
        model = SGC_LPA(in_feature=nfeat, out_feature=nclass, hidden=nhid, dropout=0.9, num_edges=num_edges, lpaiters=2)
    else:
        raise NotImplementedError('model:{} is not implemented!'.format(model_opt))

    return model.cuda()

### Get the data loaders

Many functions are defined for loading and pre-processing data

In [32]:
def fetch_normalization(type):
    """

    """
    switcher = {
        'AugNormAdj': aug_normalized_adjacency,  # A' = (D + I)^-1/2 * ( A + I ) * (D + I)^-1/2
    }
    func = switcher.get(type, lambda: "Invalid normalization technique.")
    return func

def row_normalize(mx):
    """
    Row-normalize sparse matrix
    """
    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx

def parse_index_file(filename):
    """
    Parse index file.
    """
    index = []
    for line in open(filename):
        index.append(int(line.strip()))
    return index

def generate_DAD(adj):
    """
    Computes the DAD matrix for a sparse adjacency matrix. D is the degree matrix.
    """
    adj = adj
    adj = sp.coo_matrix(adj)
    row_sum = np.array(adj.sum(1))
    d_inv_sqrt = np.power(row_sum, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    DAD = d_mat_inv_sqrt.dot(adj).dot(d_mat_inv_sqrt)
    DA = d_mat_inv_sqrt.dot(adj)
    return DAD.tocoo(), DA.tocoo()

def preprocess_citation(adj, features, normalization="FirstOrderGCN"):
    """
    Preprocess dataset
    """
    adj_normalizer = fetch_normalization(normalization)
    adj = adj_normalizer(adj)
    features = row_normalize(features)
    return adj, features

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """
    Convert a scipy sparse matrix to a torch sparse tensor.
    """
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse.FloatTensor(indices, values, shape)


def load_citation(path, dataset_str="cora", normalization="AugNormAdj"):
    """
    Load Citation Networks Datasets.
    """
    names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
    objects = []
    for i in range(len(names)):
        with open("{}/ind.{}.{}".format(path, dataset_str.lower(), names[i]), 'rb') as f:
            if sys.version_info > (3, 0):
                objects.append(pkl.load(f, encoding='latin1'))
            else:
                objects.append(pkl.load(f))

    x, y, tx, ty, allx, ally, graph = tuple(objects)
    test_idx_reorder = parse_index_file("{}/ind.{}.test.index".format(path, dataset_str))
    test_idx_range = np.sort(test_idx_reorder)

    if dataset_str == 'citeseer':
        # Fix citeseer dataset (there are some isolated nodes in the graph)
        # Find isolated nodes, add them as zero-vecs into the right position
        test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1)
        tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
        tx_extended[test_idx_range-min(test_idx_range), :] = tx
        tx = tx_extended
        ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
        ty_extended[test_idx_range-min(test_idx_range), :] = ty
        ty = ty_extended

    features = sp.vstack((allx, tx)).tolil()
    features[test_idx_reorder, :] = features[test_idx_range, :]
    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))
    DAD, DA = generate_DAD(adj)
    # Calculate number of edges
    num_edges = np.sum(adj.data) // 2
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
    labels = np.vstack((ally, ty))
    labels[test_idx_reorder, :] = labels[test_idx_range, :]

    idx_test = test_idx_range.tolist()
    idx_train = range(len(y))
    idx_val = range(len(y), len(y)+500)

    adj, features = preprocess_citation(adj, features, normalization)

    # porting to pytorch
    features = torch.FloatTensor(np.array(features.todense())).float()
    labels = torch.LongTensor(labels)
    labels = torch.max(labels, dim=1)[1]
    adj = sparse_mx_to_torch_sparse_tensor(adj).float()
    DAD = sparse_mx_to_torch_sparse_tensor(DAD).float()
    DA = sparse_mx_to_torch_sparse_tensor(DA).float()
    idx_train = torch.LongTensor(idx_train)
    idx_val = torch.LongTensor(idx_val)
    idx_test = torch.LongTensor(idx_test)

    # moving things to GPU for later training
    features = features.cuda()
    adj = adj.cuda()
    DAD = DAD.cuda()
    DA = DA.cuda()
    labels = labels.cuda()
    idx_train = idx_train.cuda()
    idx_val = idx_val.cuda()
    idx_test = idx_test.cuda()

    return adj, features, labels, idx_train, idx_val, idx_test, num_edges, DAD, DA

### Now define the driver function for training and testing the SGC model

In [53]:
# a trainer function that trains a GNN model
def train_regression(model,
                     train_features, train_labels,
                     val_features, val_labels,
                     epochs, weight_decay,
                     lr, dropout, need_adj_fwd=False, adj=None, val_adj=None):

    optimizer = optim.Adam(model.parameters(), lr=lr,
                           weight_decay=weight_decay)
    t = perf_counter()
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        if not need_adj_fwd:
            output = model(train_features)
        else:
            output = model(train_features, adj=adj)
        loss_train = F.cross_entropy(output, train_labels)
        loss_train.backward()
        optimizer.step()
    train_time = perf_counter()-t

    with torch.no_grad():
        model.eval()
        if not need_adj_fwd:
            output = model(val_features)
        else:
            output = model(val_features, adj=val_adj)
        acc_val = accuracy(output, val_labels)

        # print(f'Epoch {epoch+1},  val accuracy: {acc_val}  loss: {loss_train.item()}')

    return model, acc_val, train_time


In [None]:
# a trainer function that applies correct and smooth after trains a GNN model
def train_regression(model,
                     train_features, train_labels,
                     val_features, val_labels,
                     epochs=args.epochs, weight_decay=args.weight_decay,
                     lr=args.lr, dropout=args.dropout, need_adj_fwd=False, adj=None, val_adj=None, all_features=None, post_cs=None, train_idx=None, DAD=None, DA=None, val_idx=None, test_idx=None, test_labels=None):

    optimizer = optim.Adam(model.parameters(), lr=lr,
                           weight_decay=weight_decay)
    t = perf_counter()
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        if not need_adj_fwd:
            output = model(train_features)
        else:
            output = model(train_features, adj=adj)
        loss_train = F.cross_entropy(output, train_labels)
        loss_train.backward()
        optimizer.step()
    train_time = perf_counter()-t
    model.eval()
    if all_features is not None and post_cs is not None:
        with torch.no_grad():
            output = model(all_features)
        y_soft = output.softmax(dim=-1)
        DAD = DAD.to_dense()
        indices = torch.nonzero(DAD).t()
        DAD = SparseTensor(row=indices[0], col=indices[1], value=DAD[indices[0], indices[1]], sparse_sizes=DAD.size())

        DA = DA.to_dense()
        indices = torch.nonzero(DA).t()
        DA = SparseTensor(row=indices[0], col=indices[1], value=DA[indices[0], indices[1]], sparse_sizes=DA.size())
        y_soft = post_cs.correct(y_soft=y_soft, y_true=train_labels.unsqueeze(-1), mask=train_idx, edge_index=DAD)
        y_soft = post_cs.smooth(y_soft=y_soft, y_true=train_labels.unsqueeze(-1), mask=train_idx, edge_index=DAD)
        acc_val = accuracy(y_soft[val_idx], val_labels)
        acc_test = accuracy(y_soft[test_idx], test_labels)

    return model, acc_val, train_time, acc_test

In [None]:
# now define the metrics and tester
def accuracy(output, labels):
    preds = output.max(1)[1].type_as(labels)
    correct = preds.eq(labels).double()
    correct = correct.sum()
    return correct / len(labels)

def test_regression(model, test_features, test_labels, need_adj_fwd=False, adj=None):
    model.eval()
    if not need_adj_fwd:
        return accuracy(model(test_features), test_labels)
    else:
        return accuracy(model(test_features, adj), test_labels)

In [50]:
args = objectview({
    'epochs': 250,
    'lr': 0.001,
    'weight_decay': 5e-6,
    'hidden': 128,
    'dropout': 0.5,
    'dataset': 'cora',
    'normalization': 'AugNormAdj',
    'degree': 5,
})

In [54]:
# you need to connect to a drive with data

path = 'drive/MyDrive/data'

def main():
    adj, features, labels, idx_train, idx_val, idx_test, num_edges, DAD, DA = load_citation(path, args.dataset, args.normalization)

    features, precompute_time = sgc_precompute_concat(features, adj, args.degree)
    
    nfeat = features.size(1) # * (args.degree+1)
    nclass = labels.max().item()+1

    model = get_model('SGC-Concat', nfeat=nfeat, nclass=nclass, nhid=args.hidden, dropout=args.dropout)

    model, acc_val, train_time = train_regression(
                     model, features[idx_train], labels[idx_train], features[idx_val], labels[idx_val],
                     args.epochs, args.weight_decay, args.lr, args.dropout)
    
    acc_test = test_regression(model, features[idx_test], labels[idx_test])

    print("Validation Accuracy: {:.4f}  Test Accuracy: {:.4f}".format(acc_val, acc_test))
    print("Pre-compute time: {:.4f}s, train time: {:.4f}s, total: {:.4f}s".format(precompute_time, train_time, precompute_time+train_time))

main()

  objects.append(pkl.load(f, encoding='latin1'))


Validation Accuracy: 0.7960  Test Accuracy: 0.8310
Pre-compute time: 0.0132s, train time: 0.3452s, total: 0.3584s


# Part 2: Applying SGC to real-world problems

In this part, we practice what we learned from the thorough experimentation in Part 1, and verify whether our augmentation works at scale (robust enough to generalize to real-world datasets). Specifically, we re-implement SGC under the OGB framework, and test our methods on OGB node classification datasets. Empirical results demonstrate the effectiveness of our model augmentation.

### Importing required packages for Part 2

Many packages are repeating from Part 1, but we list them here in case you just want to play with Part 2



In [5]:
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch_sparse import SparseTensor, matmul
import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.nn import CorrectAndSmooth
from torch_geometric.nn import JumpingKnowledge
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptTensor
from ogb.nodeproppred import PygNodePropPredDataset
from ogb.nodeproppred import Evaluator as NodeEvaluator
from torch.utils.data import DataLoader

### Model Construction

We want to be able to add functionalities to SGC easily, also we would like to have outputs with the right shapes for different datasets. Thus we define a SGC wrapper class. The wrapper class ```SGC``` calls a ```SGConv``` layer which does the SGC propagations, and uses a linear layer to project to the dimension we want. This way, tweaking the ```SGConv``` layer is very easy. The base SGC method is achieved through simple sum aggregation and adjacency matrix normalization provided by ```torch_geometrics```.

In [6]:
class SGC(torch.nn.Module):
    """
    The SGC wrapper class: computes the SGC encoding of network features, 
    then projects to the number of output classes
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5, K=3, aug=None):
        super(SGC, self).__init__()
        if not aug:
            self.conv = SGConv(in_channels, hidden_channels, K=K)
        elif aug == 'pyramid':
            self.conv = SGConvRes(in_channels, hidden_channels, K=K)
        elif aug == 'deepres':
            self.conv = DRSGConv(in_channels, hidden_channels, K=K)
        elif aug == 'jk':
            self.conv = SGConvJK(in_channels, hidden_channels, K=K)
        self.lin = Linear(hidden_channels, out_channels)
        self.dropout = dropout

    def reset_parameters(self):
        self.conv.reset_parameters()
        self.lin.reset_parameters()

    def forward(self, x, adj_t):
        x = self.conv(x, adj_t, edge_weight=adj_t)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin(x)
        return F.log_softmax(x, dim=-1)

The following blocks define the base SGC convolutional layer and several augmentation schemes.

In [7]:
class SGConv(MessagePassing):
    """
    The simple graph convolutional operator from the `"Simplifying Graph
    Convolutional Networks" <https://arxiv.org/abs/1902.07153>`_ paper
    """

    def __init__(self, in_channels: int, out_channels: int, K: int = 1,
                 add_self_loops: bool = True, bias: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        self.add_self_loops = add_self_loops
        self.lin = Linear(in_channels, out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """"""
        if isinstance(edge_index, Tensor):
            edge_index, edge_weight = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim), False,
                self.add_self_loops, dtype=x.dtype)
        elif isinstance(edge_index, SparseTensor):
            edge_index = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim), False,
                self.add_self_loops, dtype=x.dtype)
        for k in range(self.K):
            # propagate_type: (x: Tensor, edge_weight: OptTensor)
            x = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None)
        return self.lin(x)

    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
        return edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, K={self.K})')

In [8]:
class SGConvRes(MessagePassing):
    """
    SGC with residual connections from all propagation depths
    """

    def __init__(self, in_channels: int, out_channels: int, K: int = 1,
                 add_self_loops: bool = True, bias: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        self.add_self_loops = add_self_loops
        self.lin = Linear((K+1) * in_channels, out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """"""
        if isinstance(edge_index, Tensor):
            edge_index, edge_weight = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim), False,
                self.add_self_loops, dtype=x.dtype)
        elif isinstance(edge_index, SparseTensor):
            edge_index = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim), False,
                self.add_self_loops, dtype=x.dtype)
        total = [deepcopy(x.detach())]
        for k in range(self.K):
            # propagate_type: (x: Tensor, edge_weight: OptTensor)
            x = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None)
            total.append(deepcopy(x.detach()))
        total = torch.cat(total, dim=-1)
        return self.lin(total)

    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
        return edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, K={self.K})')

In [25]:
class DRSGConv(MessagePassing):
    """
    SGC with RNN-gated residual connections from the first propagation layer
    """

    def __init__(self, in_channels: int, out_channels: int, K: int = 1,
                 add_self_loops: bool = True, bias: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        self.add_self_loops = add_self_loops

        self.rnn = nn.GRU(1, 1, num_layers=1, batch_first=True)
        self.lin = Linear(in_channels, 1, bias=bias)
        self.lin_out = Linear(in_channels, out_channels, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.lin_out.reset_parameters()
        self.rnn.reset_parameters()

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """"""
        if isinstance(edge_index, Tensor):
            edge_index, edge_weight = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim), False,
                self.add_self_loops, dtype=x.dtype)
        elif isinstance(edge_index, SparseTensor):
            edge_index = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim), False,
                self.add_self_loops, dtype=x.dtype)
        x = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None)
        x0 = deepcopy(x.detach())
        h0 = torch.nn.Parameter(torch.rand(1, len(x0), 1) * (1.0 / self.out_channels) ** 0.5).to(x0.device)
        for k in range(self.K-1):
            # propagate_type: (x: Tensor, edge_weight: OptTensor)
            x = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None)
            z = self.lin(F.normalize(x) * F.normalize(x0))[:, None]
            alpha, h0 = self.rnn(z, h0)
            alpha = torch.abs(alpha).squeeze(1)
            x = (1 - alpha) * x + alpha * x0
        return self.lin_out(x)

In [35]:
class SGConvJK(MessagePassing):
    """
    SGC with Jumping Knowledge
    """

    def __init__(self, in_channels: int, out_channels: int, K: int = 1,
                 add_self_loops: bool = True, bias: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        self.add_self_loops = add_self_loops

        self.jk = JumpingKnowledge(mode='max', channels=in_channels, num_layers=K+1)
        self.lin = Linear(in_channels, out_channels)
        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.jk.reset_parameters()

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """"""
        if isinstance(edge_index, Tensor):
            edge_index, edge_weight = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim), False,
                self.add_self_loops, dtype=x.dtype)
        elif isinstance(edge_index, SparseTensor):
            edge_index = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim), False,
                self.add_self_loops, dtype=x.dtype)
        total = [deepcopy(x.detach())]
        for k in range(self.K):
            # propagate_type: (x: Tensor, edge_weight: OptTensor)
            x = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None)
            total.append(deepcopy(x.detach()))
        x = self.jk(total)
        return self.lin(x)

    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
        return edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, K={self.K})')

### Training and testing driver codes

In [11]:
def train_node(model, data, split_node, optimizer, batch_size):
    '''
    Trains a given GNN model on the provided data for node classification
    '''
    model.train()

    total_loss = total_examples = 0
    for perm in DataLoader(split_node['train'], batch_size, shuffle=True):
        
        optimizer.zero_grad()

        # compute SGC propagations
        h = model(data.x, data.adj_t)

        # logits are log probabilities for each class
        logits = h[perm]
        labels = data.y[perm].squeeze(-1)

        # nll_loss for multi-class classification
        loss = F.nll_loss(logits, labels, reduction='sum')
        loss.backward()

        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(split_node['train'])


In [12]:
def get_metrics(y_pred, y_true, split_node, evaluator):
    '''
    Calculates the node prediciton accuracy given a OGB evaluator
    '''

    train_acc = evaluator.eval({
        'y_true': y_true[split_node['train']],
        'y_pred': y_pred[split_node['train']],
    })[f'acc']

    valid_acc = evaluator.eval({
        'y_true': y_true[split_node['valid']],
        'y_pred': y_pred[split_node['valid']],
    })[f'acc']

    test_acc = evaluator.eval({
        'y_true': y_true[split_node['test']],
        'y_pred': y_pred[split_node['test']],
    })[f'acc']

    result_dict = {
        'train': train_acc,
        'valid': valid_acc,
        'test': test_acc
    }

    return result_dict

@torch.no_grad()
def test_node(model, data, split_node, evaluator, correct_smooth=False):
    '''
    Evaluates the model using the ogb evaluator, if correct and smooth is set
    to true, then performs correct&smooth on the predictor results
    '''

    model.eval()

    h = model(data.x, data.adj_t)
    y_soft = h.softmax(-1)

    y_pred = y_soft.argmax(-1, keepdims=True)
    result = get_metrics(y_pred, data.y, split_node, evaluator)
    result_smoothed = None

    if correct_smooth:

        deg = data.adj_t.sum(dim=1).to(torch.float)
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        DAD = deg_inv_sqrt.view(-1, 1) * data.adj_t * deg_inv_sqrt.view(1, -1)
        DA = deg_inv_sqrt.view(-1, 1) * deg_inv_sqrt.view(-1, 1) * data.adj_t

        post = CorrectAndSmooth(num_correction_layers=50, correction_alpha=0.9,
                                num_smoothing_layers=50, smoothing_alpha=0.8,
                                autoscale=True, scale=20.)
        
        y_true = data.y[split_node['train']]
        y_soft = post.correct(y_soft, y_true, split_node['train'], DAD)
        y_soft = post.smooth(y_soft, y_true, split_node['train'], DAD) # DAD performs better than DA
        y_pred = y_soft.argmax(-1, keepdims=True)

        result_smoothed = get_metrics(y_pred, data.y, split_node, evaluator)

    return (result, result_smoothed)

## Experiment:
First train a simple SGC model on the ogbn-arxiv dataset. Then we compare the effectiveness of different augmentations. Each time the results with and without using correct & smooth are displayed alongside for comparison.

### Get Dataloaders from OGB

In [13]:
# load a real-world dataset for analysis
dataset = PygNodePropPredDataset('ogbn-arxiv', transform=T.ToSparseTensor())
# get the graph from dataset
data = dataset[0]
data.x = data.x.to(torch.float)
data.adj_t = data.adj_t.to_symmetric()
data = data.to(device)
split_node = dataset.get_idx_split()

Downloading http://snap.stanford.edu/ogb/data/nodeproppred/arxiv.zip


Downloaded 0.08 GB: 100%|██████████| 81/81 [00:02<00:00, 33.84it/s]


Extracting dataset/arxiv.zip


Processing...


Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:00<00:00, 2150.93it/s]


Converting graphs into PyG objects...


100%|██████████| 1/1 [00:00<00:00, 73.98it/s]

Saving...



Done!


In [19]:
# The final driver code that puts everything together.
def run(args):
    model = SGC(data.num_features, args.hidden_channels, dataset.num_classes, args.dropout, args.K, args.aug).to(device)
    sum_params = 0.
    for p in model.parameters():
        sum_params += p.numel()
    print(f'Params: {sum_params}')

    model.reset_parameters()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    evaluator = NodeEvaluator(name='ogbn-arxiv')

    best_model = None
    best_val = 0.

    for epoch in range(1, 1 + args.epochs):
        loss = train_node(model, data, split_node, optimizer, args.batch_size)
        result, result_smooth = test_node(model, data, split_node, evaluator, correct_smooth=True)

        if result['valid'] > best_val:
            best_val = result['valid']
            best_model = deepcopy(model)

        print(f'Epoch: {epoch}, ', end=' ')
        for key, val in result.items():
            print(f'{key}: {val:.7f}, ', end=' ')
        if result_smooth:
            print('Correct and smooth... ', end = ' ')
            for key, val in result_smooth.items():
                print(f'{key}: {val:.7f}, ', end=' ')
        print(f'loss: {loss:.7f}, ')

    result, result_smooth = test_node(best_model, data, split_node, evaluator, correct_smooth=True)
    best_acc, best_acc_smooth = result['test'], result_smooth['test']
    print(f"Best test accuracy is {best_acc:.7f} and {best_acc_smooth:.7f} after correct & smooth")

    return best_model

In [20]:
# define the hyper-parameters for a base run
args = objectview({
    'hidden_channels': 64,
    'dropout': 0.5,
    'K': 5,
    'lr': 5e-3,
    'epochs': 1000,
    'batch_size': 1024 * 64,
    'aug': None
})

run(args)

Params: 10792.0
Epoch: 1,  train: 0.1790611,  valid: 0.0762777,  test: 0.0586178,  Correct and smooth...  train: 0.9536183,  valid: 0.6845196,  test: 0.6592597,  loss: 3.6772354, 
Epoch: 2,  train: 0.1790611,  valid: 0.0762777,  test: 0.0586178,  Correct and smooth...  train: 0.9535743,  valid: 0.6859626,  test: 0.6613789,  loss: 3.5028514, 
Epoch: 3,  train: 0.1790611,  valid: 0.0762777,  test: 0.0586178,  Correct and smooth...  train: 0.9536403,  valid: 0.6881439,  test: 0.6628809,  loss: 3.3409822, 
Epoch: 4,  train: 0.1790611,  valid: 0.0762777,  test: 0.0586178,  Correct and smooth...  train: 0.9537722,  valid: 0.6895869,  test: 0.6648972,  loss: 3.2190938, 
Epoch: 5,  train: 0.1790611,  valid: 0.0762777,  test: 0.0586178,  Correct and smooth...  train: 0.9536403,  valid: 0.6961643,  test: 0.6736621,  loss: 3.1639033, 
Epoch: 6,  train: 0.2601247,  valid: 0.2681298,  test: 0.2500257,  Correct and smooth...  train: 0.9534753,  valid: 0.7041176,  test: 0.6867066,  loss: 3.1402600, 


KeyboardInterrupt: ignored

In [21]:
# pyramid SGC
args = objectview({
    'hidden_channels': 64,
    'dropout': 0.5,
    'K': 5,
    'lr': 5e-3,
    'epochs': 1000,
    'batch_size': 1024 * 64,
    'aug': 'pyramid'
})

run(args)

Params: 51752.0
Epoch: 1,  train: 0.1790611,  valid: 0.0762777,  test: 0.0586178,  Correct and smooth...  train: 0.9546739,  valid: 0.6909292,  test: 0.6665226,  loss: 3.5972607, 
Epoch: 2,  train: 0.2774986,  valid: 0.3003121,  test: 0.2692221,  Correct and smooth...  train: 0.9540031,  valid: 0.7058291,  test: 0.6942987,  loss: 3.2314862, 
Epoch: 3,  train: 0.2997988,  valid: 0.2918554,  test: 0.2626381,  Correct and smooth...  train: 0.9524527,  valid: 0.6939159,  test: 0.6725099,  loss: 3.0859137, 
Epoch: 4,  train: 0.2750135,  valid: 0.2647740,  test: 0.2364463,  Correct and smooth...  train: 0.9527056,  valid: 0.6978087,  test: 0.6774479,  loss: 2.9574539, 
Epoch: 5,  train: 0.2913757,  valid: 0.3032652,  test: 0.2742423,  Correct and smooth...  train: 0.9537172,  valid: 0.7071043,  test: 0.6925498,  loss: 2.8225725, 
Epoch: 6,  train: 0.3204605,  valid: 0.3265546,  test: 0.3017098,  Correct and smooth...  train: 0.9538492,  valid: 0.7089500,  test: 0.6963768,  loss: 2.6955177, 


KeyboardInterrupt: ignored

In [28]:
# deep residual SGC
args = objectview({
    'hidden_channels': 32,
    'dropout': 0.5,
    'K': 7,
    'lr': 5e-3,
    'epochs': 1000,
    'batch_size': 1024 * 64,
    'aug': 'deepres'
})

run(args)

Params: 5589.0
Epoch: 1,  train: 0.1099394,  valid: 0.2297393,  test: 0.2155628,  Correct and smooth...  train: 0.9589844,  valid: 0.6768348,  test: 0.6712137,  loss: 17579930173101.8378906, 
Epoch: 2,  train: 0.1099504,  valid: 0.2297393,  test: 0.2155628,  Correct and smooth...  train: 0.9589624,  valid: 0.6768012,  test: 0.6710697,  loss: 9327950344702.2597656, 
Epoch: 3,  train: 0.0769180,  valid: 0.1495017,  test: 0.2208711,  Correct and smooth...  train: 0.9572470,  valid: 0.6562972,  test: 0.6113409,  loss: 6559585797450.8085938, 
Epoch: 4,  train: 0.1099394,  valid: 0.2297393,  test: 0.2155628,  Correct and smooth...  train: 0.9589844,  valid: 0.6768348,  test: 0.6712137,  loss: 5737388383487.3437500, 
Epoch: 5,  train: 0.1099394,  valid: 0.2297393,  test: 0.2155628,  Correct and smooth...  train: 0.9589844,  valid: 0.6768348,  test: 0.6712137,  loss: 4378812693451.1791992, 
Epoch: 6,  train: 0.1099724,  valid: 0.2297393,  test: 0.2155628,  Correct and smooth...  train: 0.95898

KeyboardInterrupt: ignored

In [37]:
# SGC with jumping knowledge
args = objectview({
    'hidden_channels': 64,
    'dropout': 0.5,
    'K': 5,
    'lr': 5e-3,
    'epochs': 1000,
    'batch_size': 1024 * 64,
    'aug': 'jk'
})

run(args)

Params: 10856.0
Epoch: 1,  train: 0.0770610,  valid: 0.1496023,  test: 0.2209740,  Correct and smooth...  train: 0.9534643,  valid: 0.6809960,  test: 0.6535399,  loss: 3.6878890, 
Epoch: 2,  train: 0.1836245,  valid: 0.1917178,  test: 0.2499434,  Correct and smooth...  train: 0.9529145,  valid: 0.6778080,  test: 0.6462358,  loss: 3.4743524, 
Epoch: 3,  train: 0.1912119,  valid: 0.1082922,  test: 0.1080386,  Correct and smooth...  train: 0.9526506,  valid: 0.6780764,  test: 0.6459684,  loss: 3.2807693, 
Epoch: 4,  train: 0.1794350,  valid: 0.0772174,  test: 0.0592350,  Correct and smooth...  train: 0.9533434,  valid: 0.6887815,  test: 0.6621196,  loss: 3.1471150, 
Epoch: 5,  train: 0.1881110,  valid: 0.0999027,  test: 0.0796041,  Correct and smooth...  train: 0.9538492,  valid: 0.6985805,  test: 0.6776125,  loss: 3.1022303, 
Epoch: 6,  train: 0.2434436,  valid: 0.2263163,  test: 0.2011193,  Correct and smooth...  train: 0.9537832,  valid: 0.7033122,  test: 0.6849577,  loss: 3.0825809, 


KeyboardInterrupt: ignored