# OGB arXiv paper node classification (ogbn-arxiv)

In [None]:
# Environment setup: remove whatever torch stack is preinstalled, then pin
# torch / torchvision / torchaudio to the 1.13.1 release family.
# NOTE(review): `--y` is presumably accepted by pip as an abbreviation of `--yes` — confirm.
!pip uninstall torch torchvision torchaudio --y
!pip install torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1
# torch-scatter / torch-sparse are fetched from the PyG wheel index that matches
# torch 1.13.1 + cu116, so they must stay in sync with the pinned torch version above.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.13.1+cu116.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-1.13.1+cu116.html
# Remaining dependencies: PyG itself, the OGB datasets/evaluators, and GPUtil
# (used by mem_report for GPU memory statistics).
!pip install torch_geometric
!pip install ogb
!pip install GPUtil

In [None]:
import argparse
import json
import os
import pickle

import GPUtil
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
# Global compute device shared by every cell below: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class Logger(object):
    """Collect per-epoch (train, valid, test) accuracies for several runs.

    ``results[run]`` is a list with one 3-element entry per logged epoch.
    Accuracies are stored as fractions and scaled by 100 only when printed.
    """

    def __init__(self, runs, info=None):
        """Create an empty log for ``runs`` runs; ``info`` is stored as-is."""
        self.info = info
        self.results = [[] for _ in range(runs)]

    def pickle(self, key_save):
        """Serialize this logger to the file path ``key_save``."""
        # Context manager guarantees the handle is closed even if dump() raises.
        with open(key_save, 'wb') as f:
            pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)

    def unpickle(self, key_save):
        """Load and return the object pickled at ``key_save``.

        NOTE: ``pickle.load`` can execute arbitrary code; only call this on
        files that were produced by a trusted run of this notebook.
        """
        with open(key_save, 'rb') as f:
            return pickle.load(f)

    def add_result(self, run, result):
        """Append one ``(train, valid, test)`` triple to run number ``run``."""
        assert len(result) == 3
        assert 0 <= run < len(self.results)
        self.results[run].append(result)

    def print_statistics(self, run=None):
        """Print accuracy summaries (in percent).

        With ``run`` given, report that run only: the per-column maxima and
        the values at the epoch with the best validation accuracy ("Final").
        With ``run=None``, report mean ± std of those quantities over all
        runs (every run must have logged the same number of epochs so the
        results can be stacked into one tensor).
        """
        if run is not None:
            result = 100 * torch.tensor(self.results[run])
            # Model selection: the epoch with the highest validation accuracy.
            best_epoch = result[:, 1].argmax().item()
            print(f'Run {run + 1:02d}:')
            print(f'Highest Train: {result[:, 0].max():.2f}')
            print(f'Highest Valid: {result[:, 1].max():.2f}')
            print(f'Highest Test: {result[:, 2].max():.2f}')
            print(f'  Final Train: {result[best_epoch, 0]:.2f}')
            print(f'  Final Valid: {result[best_epoch, 1]:.2f}')
            print(f'   Final Test: {result[best_epoch, 2]:.2f}')
            return

        result = 100 * torch.tensor(self.results)

        best_results = []
        for r in result:
            best_epoch = r[:, 1].argmax()
            best_results.append((
                r[:, 0].max().item(),      # highest train
                r[:, 1].max().item(),      # highest valid
                r[:, 2].max().item(),      # highest test
                r[best_epoch, 0].item(),   # train at best-valid epoch
                r[best_epoch, 2].item(),   # test at best-valid epoch
            ))

        best_result = torch.tensor(best_results)

        print('All runs:')
        # "Final Valid" reuses column 1: the validation accuracy at the
        # best-validation epoch is, by definition, the highest validation value.
        summary_rows = (
            ('Highest Train', 0),
            ('Highest Valid', 1),
            ('Highest Test', 2),
            ('  Final Train', 3),
            ('  Final Valid', 1),
            ('   Final Test', 4),
        )
        for label, column in summary_rows:
            r = best_result[:, column]
            print(f'{label}: {r.mean():.2f} ± {r.std():.2f}')


def test(model, data, split_idx, evaluator):
    """Evaluate ``model`` on the full graph and score each split.

    The model is switched to eval mode and run once, without gradients, on
    ``data.x`` / ``data.adj_t``. The arg-max class of every node is then
    scored by ``evaluator`` separately on the 'train', 'valid' and 'test'
    node subsets given by ``split_idx``.

    Returns:
        A 3-tuple ``(train_acc, valid_acc, test_acc)`` taken from the
        ``'acc'`` entry of each evaluator result.
    """
    model.eval()
    accuracies = []
    with torch.no_grad():
        # keepdim=True keeps predictions as a column, matching how data.y is indexed below.
        predictions = model(data.x, data.adj_t).argmax(dim=-1, keepdim=True)
        for split_name in ('train', 'valid', 'test'):
            node_ids = split_idx[split_name]
            scores = evaluator.eval({
                'y_true': data.y[node_ids],
                'y_pred': predictions[node_ids],
            })
            accuracies.append(scores['acc'])
    return tuple(accuracies)

In [None]:
def slide_idx(data, indices):
    """Build the node-induced subgraph of ``data`` restricted to ``indices``.

    Args:
        data: graph object exposing ``adj_t`` (with a ``to_scipy()`` method),
            ``x`` and ``y``.
        indices: node ids to keep (tensor or array-like). Nodes are
            renumbered ``0..len(indices)-1`` in the order given.

    Returns:
        A ``Data`` object on the global ``device`` holding the selected
        features, labels and an ``edge_index`` containing only edges whose
        two endpoints are both in ``indices``.

    NOTE: as flagged by the original author, this induced subgraph is not
    entirely faithful — every edge that touches a node outside ``indices``
    (e.g. "papers" in the future) is dropped.
    """
    # SciPy fancy indexing needs a NumPy index; convert explicitly instead of
    # relying on implicit tensor -> array coercion.
    if torch.is_tensor(indices):
        node_ids = indices.detach().cpu().numpy()
    else:
        node_ids = np.asarray(indices)

    full_mat = data.adj_t.to_scipy().tocsr()
    coo = full_mat[node_ids][:, node_ids].tocoo()

    # torch.sparse_coo_tensor replaces the deprecated legacy constructors
    # (torch.sparse.FloatTensor / torch.LongTensor / torch.FloatTensor).
    pairs = torch.from_numpy(np.vstack((coo.row, coo.col))).long()
    weights = torch.as_tensor(coo.data, dtype=torch.float)
    # coalesce() sorts the (row, col) pairs; only the indices are kept.
    sub_idx = torch.sparse_coo_tensor(pairs, weights, coo.shape).coalesce().indices().clone()

    sub_x = data.x[indices].clone()
    sub_y = data.y[indices].clone()
    return Data(x=sub_x, y=sub_y, edge_index=sub_idx).to(device)

# Combine and resplit split_idx
def shuffle_split_idx(split_idx, seed=1103):
    """Pool the train/valid/test node ids and randomly re-split them.

    The new splits have exactly the same sizes as the input splits, and each
    one is returned sorted in ascending order.

    Two defects of the previous version are fixed here:
      * it returned *positions* into the pooled tensor rather than the node
        ids stored at those positions, which is only correct when the pooled
        ids happen to be exactly ``0..N-1``;
      * split sizes were recomputed as ``int(fraction * N)``, which can
        truncate (e.g. ``29/100*100 -> 28``) so the sizes no longer sum to
        ``N`` and ``random_split`` raises.

    Args:
        split_idx: mapping with 'train', 'valid' and 'test' index tensors.
        seed: value passed to ``torch.manual_seed`` before shuffling. Note
            that this reseeds the global torch RNG, as the original did.

    Returns:
        A new dict with the keys 'train', 'valid' and 'test'.
    """
    names = ('train', 'valid', 'test')
    parts = [split_idx[name] for name in names]
    full_idx = torch.cat(parts)
    lengths = [len(part) for part in parts]

    torch.manual_seed(seed)
    subsets = torch.utils.data.random_split(full_idx, lengths)

    reshuffled = {}
    for name, subset in zip(names, subsets):
        # subset.indices are positions in full_idx; map them back to node ids.
        # The explicit dtype keeps an empty split usable as an index.
        positions = torch.tensor(subset.indices, dtype=torch.long)
        reshuffled[name] = torch.sort(full_idx[positions])[0]
    return reshuffled

def _format_bytes(num_bytes):
    """Render a byte count with decimal (power-of-1000) units, e.g. ``8.4 GB``."""
    size = float(num_bytes)
    units = ('Bytes', 'kB', 'MB', 'GB', 'TB')
    for unit in units:
        if size < 1000 or unit == units[-1]:
            break
        size /= 1000
    return f'{size:.0f} {unit}' if unit == 'Bytes' else f'{size:.1f} {unit}'


def _available_ram_bytes():
    """Return free physical memory in bytes, or None if it cannot be queried."""
    try:
        return os.sysconf('SC_AVPHYS_PAGES') * os.sysconf('SC_PAGE_SIZE')
    except (AttributeError, ValueError, OSError):
        # os.sysconf is missing on some platforms, or the names are unsupported.
        return None


def mem_report():
    """Print a one-line memory summary for the active ``device``.

    On CUDA, one line per GPU reported by GPUtil (free/total memory and
    utilization). On CPU, the free physical RAM.

    The CPU branch previously called ``humanize`` and ``psutil``, neither of
    which is imported or installed by this notebook, so it raised NameError
    on CPU-only machines; it now uses the standard library only.
    """
    if device.type == 'cuda':
        for i, gpu in enumerate(GPUtil.getGPUs()):
            print('GPU {:d} ... Mem Free: {:.0f}MB / {:.0f}MB | Utilization {:3.0f}%'.format(
                i, gpu.memoryFree, gpu.memoryTotal, gpu.memoryUtil*100))
    else:
        free_bytes = _available_ram_bytes()
        readable = _format_bytes(free_bytes) if free_bytes is not None else 'unknown'
        print("CPU RAM Free: " + readable)

In [None]:
def train(model, data, args):
    """Run one full-batch forward/backward pass and return the loss value.

    The loss is the mean squared error between the model output and the
    one-hot labels, multiplied by the number of classes. Gradients are
    accumulated into the model parameters; zeroing them beforehand and
    stepping the optimizer afterwards is left to the caller.

    The number of classes is read from the width of the model output instead
    of the module-level ``dataset`` global, so the function no longer depends
    on the experiment cell having been executed first.

    Args:
        model: module called as ``model(data.x, data.edge_index)``.
        data: graph with ``x``, ``edge_index`` and column-shaped labels ``y``.
        args: unused; kept for signature parity with ``train_SVI``.

    Returns:
        The scalar loss as a Python float.
    """
    model.train()
    # No batching: the whole (training) graph is processed at once.
    out = model(data.x, data.edge_index)
    num_classes = out.size(-1)
    label = F.one_hot(data.y.squeeze(1), num_classes=num_classes).to(torch.float)
    loss = F.mse_loss(out, label) * num_classes
    loss.backward()
    return loss.item()


def train_SVI(model, data, args):
    """Run one full-batch SVI pass, leaving SVI update directions in ``.grad``.

    Steps, in order:
      1. Reset the model's recording lists and enable recording
         (``on_training``), then run the forward pass, during which the model
         fills ``layers_Xtilde`` and ``layers_grad``.
      2. Back-propagate the loss with all parameters frozen, so only the
         recorded ``layers_grad`` tensors receive a ``.grad``.
      3. For each recorded pair, back-propagate the surrogate
         ``sum(Xtilde * grad)`` and then freeze every parameter that has
         received a non-zero gradient, so later surrogates do not add to it.
      4. Disable recording and re-enable gradients on all parameters.

    As in ``train``, the loss is MSE against one-hot labels multiplied by the
    number of classes, and the class count is taken from the model output
    width rather than from the module-level ``dataset`` global. Zeroing
    gradients and stepping the optimizer is left to the caller.

    Args:
        model: an SVI-capable model (see ``GNN_SVI``).
        data: graph with ``x``, ``edge_index`` and column-shaped labels ``y``.
        args: unused; kept for signature parity with ``train``.

    Returns:
        The scalar loss as a Python float.
    """
    model.train()
    model.layers_Xtilde = []
    model.layers_grad = []
    model.on_training = True  # If False, SVI is NOT used
    out = model(data.x, data.edge_index)
    num_classes = out.size(-1)
    label = F.one_hot(data.y.squeeze(1), num_classes=num_classes).to(torch.float)
    loss = F.mse_loss(out, label) * num_classes
    model.turn_on_off_grad(on=False)
    # To get grad of L w.r.t. X_{l+1} for all layers at once
    loss.backward(retain_graph=True)
    model.turn_on_off_grad(on=True)
    for pre_activation, activation in zip(model.layers_Xtilde, model.layers_grad):
        upstream_grad = activation.grad.detach().to(device)
        surrogate = (pre_activation * upstream_grad).sum()
        surrogate.backward(retain_graph=True)
        # Freeze parameters whose SVI direction has just been computed.
        model.turn_off_grad_during_SVI()
    model.on_training = False
    model.turn_on_off_grad(on=True)
    return loss.item()

In [None]:
class GNN(torch.nn.Module):
    """Stack of GCN (or fully-connected) layers with batch norm and ReLU.

    Every layer except the last is followed by BatchNorm1d and ReLU; the last
    layer's output is passed through a softmax over dim 1. With ``FC=True``
    the layers are ``nn.Linear`` and the graph structure is ignored.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, FC=False):
        super().__init__()
        self.FC = FC
        # One input->hidden layer, (num_layers - 2) hidden->hidden layers
        # (none when num_layers <= 2), and one hidden->output layer.
        hidden_count = 1 + max(num_layers - 2, 0)
        widths = [in_channels] + [hidden_channels] * hidden_count + [out_channels]
        self.convs = torch.nn.ModuleList(
            [self._new_layer(fan_in, fan_out)
             for fan_in, fan_out in zip(widths[:-1], widths[1:])])
        self.bns = torch.nn.ModuleList(
            [torch.nn.BatchNorm1d(hidden_channels) for _ in range(hidden_count)])

    def _new_layer(self, fan_in, fan_out):
        """Create one layer of the configured kind (linear or cached GCNConv)."""
        if self.FC:
            return nn.Linear(fan_in, fan_out)
        return GCNConv(fan_in, fan_out, cached=True)

    def forward(self, x, edge_index):
        """Return softmax class scores for the features ``x``."""
        last = len(self.convs) - 1
        for depth, layer in enumerate(self.convs):
            x = layer(x) if self.FC else layer(x, edge_index)
            if depth == last:
                x = x.softmax(dim=1)
            else:
                x = F.relu(self.bns[depth](x))
        return x

class GNN_SVI(torch.nn.Module):
    """Same architecture as ``GNN`` plus the bookkeeping needed for SVI.

    While ``on_training`` is True, ``forward`` records for every layer
      * ``layers_Xtilde``: the tensor right before the nonlinearity
        (after batch norm for hidden layers, the raw output for the last), and
      * ``layers_grad``: the tensor right after the nonlinearity, with
        ``retain_grad()`` enabled so its gradient can be read after backward.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, FC=False):
        super().__init__()
        self.FC = FC
        # One input->hidden layer, (num_layers - 2) hidden->hidden layers
        # (none when num_layers <= 2), and one hidden->output layer.
        hidden_count = 1 + max(num_layers - 2, 0)
        widths = [in_channels] + [hidden_channels] * hidden_count + [out_channels]
        self.convs = torch.nn.ModuleList(
            [self._new_layer(fan_in, fan_out)
             for fan_in, fan_out in zip(widths[:-1], widths[1:])])
        self.bns = torch.nn.ModuleList(
            [torch.nn.BatchNorm1d(hidden_channels) for _ in range(hidden_count)])
        # SVI state: per-layer recordings and the switch that enables them.
        self.layers_Xtilde = []
        self.layers_grad = []
        self.on_training = True

    def _new_layer(self, fan_in, fan_out):
        """Create one layer of the configured kind (linear or cached GCNConv)."""
        if self.FC:
            return nn.Linear(fan_in, fan_out)
        return GCNConv(fan_in, fan_out, cached=True)

    def turn_on_off_grad(self, on = True):
        """Set ``requires_grad`` of every parameter to ``on`` (avoids gradient accumulation)."""
        for weight in self.parameters():
            weight.requires_grad = on

    def turn_off_grad_during_SVI(self):
        """Freeze each still-trainable parameter that already holds a non-zero gradient."""
        for weight in self.parameters():
            # A non-zero gradient sum means the SVI update direction is already computed.
            if weight.requires_grad and weight.grad is not None and weight.grad.sum() != 0:
                weight.requires_grad = False

    def forward(self, x, edge_index):
        """Return softmax class scores, recording SVI tensors when enabled."""
        last = len(self.convs) - 1
        for depth, layer in enumerate(self.convs):
            x = layer(x) if self.FC else layer(x, edge_index)
            if depth == last:
                # Last layer: no batch norm, record the raw output.
                if self.on_training:
                    self.layers_Xtilde.append(x)
                x = x.softmax(dim=1)
            else:
                x = self.bns[depth](x)
                if self.on_training:
                    self.layers_Xtilde.append(x)
                x = F.relu(x)
            if self.on_training:
                # Keep the gradient with respect to this layer's output.
                x.retain_grad()
                self.layers_grad.append(x)
        return x

In [None]:
# ---------------- Experiment configuration ----------------
lr, optim_name = 0.001, 'SGD'
SVI_ls = [True, False]  # If False, use ordinary SGD or Adam
SVI_pause = [False, False]  # If True, SVI is only used as a warm start
FC = False  # Use fully-connected layers instead of GCN layers. Default False.
hidden_channels_ls = [256]
num_runs = 3
num_epochs = 1000
epoch_stop_SVI = int(num_epochs/10)  # Epoch after which SVI stops when used as a warm start

# One Logger per (SVI setting, hidden width) configuration; read by the summary cell.
result_ls = []


def _make_optimizer(model, args):
    """Build the optimizer selected by ``args.optimizer`` for ``model``."""
    if args.optimizer == 'SGD':
        return torch.optim.SGD(
            model.parameters(), lr=args.lr, momentum=args.momentum, nesterov=True)
    return torch.optim.Adam(model.parameters(), lr=args.lr)


if __name__ == "__main__":
    for use_SVI, pause_SVI in zip(SVI_ls, SVI_pause):
        for hidden_channels in hidden_channels_ls:
            result_dict = {'SVI-SGD': [],
                           'SVI_warmstart-SGD': [],
                           'SGD': [],
                           'SVI-Adam': [],
                           'SVI_warmstart-Adam': [],
                           'Adam': []}
            # argparse is used only as a container of defaults: parse_args(args=[])
            # never reads the command line, so the `type=bool` pitfall
            # (bool('False') is True) cannot trigger here.
            parser = argparse.ArgumentParser(
                description='OGBN-Arxiv (GNN)')
            parser.add_argument('--log_steps', type=int, default=1)
            parser.add_argument('--num_layers', type=int, default=4)
            parser.add_argument('--lr', type=float, default=lr)
            parser.add_argument('--momentum', type=float, default=0.95)
            parser.add_argument('--epochs', type=int, default=num_epochs)
            parser.add_argument('--batch', type=int, default=1)
            parser.add_argument('--runs', type=int, default=num_runs)
            parser.add_argument('--SVI', type=bool, default=use_SVI)
            parser.add_argument(
                '--optimizer', type=str, default=optim_name)
            args = parser.parse_args(args=[])
            args.FC = FC  # If use fully-connected nets instead of GCN layers
            args.hidden_channels = hidden_channels
            print(args)

            dataset = PygNodePropPredDataset(name='ogbn-arxiv',
                                             transform=T.ToSparseTensor())
            # Replace the official split by a random split of the same sizes.
            split_idx = dataset.get_idx_split()
            split_idx = shuffle_split_idx(split_idx)

            data = dataset[0]
            data = data.to(device)
            data.adj_t = data.adj_t.to_symmetric()
            # Training only ever sees the subgraph induced by the training nodes.
            data_train = slide_idx(data, split_idx['train'])
            logger = Logger(args.runs, args)
            results_over_runs = {}
            for run in range(args.runs):
                accu_at_run = []
                args.SVI = use_SVI  # May have been switched off by a warm start in the previous run
                torch.manual_seed(1103 + run)
                model_cls = GNN_SVI if args.SVI else GNN
                model = model_cls(data.num_features, args.hidden_channels,
                                  dataset.num_classes, args.num_layers, args.FC).to(device)
                evaluator = Evaluator(name='ogbn-arxiv')
                optimizer = _make_optimizer(model, args)
                for epoch in range(1, 1 + args.epochs):
                    if device.type == 'cuda':
                        # Useful to avoid GPU allocation excess
                        torch.cuda.empty_cache()
                    if epoch == epoch_stop_SVI + 1 and pause_SVI:
                        # End of the SVI warm start: continue with a plain GNN
                        # carrying the same weights and a fresh optimizer
                        # (reinitialized to avoid gradient issues).
                        args.SVI = False
                        sdict = model.state_dict()
                        print(
                            '############ Pause SVI from now on ############')
                        # Fixed: this call used to pass a non-existent
                        # `args.dropout` as an extra positional argument, which
                        # crashed as soon as a warm-start configuration ran.
                        model = GNN(data.num_features, args.hidden_channels,
                                    dataset.num_classes, args.num_layers,
                                    args.FC).to(device)
                        model.load_state_dict(sdict)
                        optimizer = _make_optimizer(model, args)
                    optimizer.zero_grad()
                    loss = train_SVI(model, data_train, args) if args.SVI else train(model, data_train, args)
                    optimizer.step()
                    if epoch % args.log_steps == 0:
                        print('Testing')
                        if args.SVI:
                            print(f'SVI-{args.optimizer} training at epoch {epoch}')
                        else:
                            print(f'{args.optimizer} training at epoch {epoch}')
                        # Evaluation runs on the full graph, not the training subgraph.
                        result = test(model, data, split_idx, evaluator)
                        mem_report()
                        logger.add_result(run, result)
                        train_acc, valid_acc, test_acc = result
                        accu_at_run += [[train_acc, valid_acc, test_acc]]
                        print(f'Run: {run + 1:02d}, '
                              f'Epoch: {epoch:02d}, '
                              f'Loss: {loss:.4f}, '
                              f'Train: {100 * train_acc:.2f}%, '
                              f'Valid: {100 * valid_acc:.2f}% '
                              f'Test: {100 * test_acc:.2f}%')
                # Kept as nested lists (not np.array) so json.dump can serialize it.
                results_over_runs[f'lr={args.lr}@Run{run+1}'] = accu_at_run
                logger.print_statistics(run)
                # Save results
                if use_SVI:
                    SVI_prefix = 'SVI_warmstart-' if pause_SVI else 'SVI-'
                else:
                    SVI_prefix = ''
                key = f'{SVI_prefix}{optim_name}'
                fc_use = '-FC' if args.FC else ''
                key_save = f'{SVI_prefix}{optim_name}-{args.num_layers}layers-{args.hidden_channels}nodes-{args.lr}LR{fc_use}_correct_split_1'
                # NOTE(review): the same `results_over_runs` dict is appended after
                # every run, so the saved list holds one reference per finished run
                # to a dict containing all runs so far. Left unchanged to keep the
                # JSON layout that downstream analysis may expect — confirm.
                result_dict[key].append(results_over_runs)
                # The file is rewritten after each run so partial results survive an interruption.
                with open(f"{key_save}_loss_together_SVI_only_shuffle_{num_epochs}.json", "w") as outfile:
                    json.dump(result_dict, outfile)
            logger.print_statistics()
            result_ls.append(logger)

In [None]:
# Re-print the across-run summary of every Logger collected in result_ls
# (one per configuration appended by the experiment cell).
for logger in result_ls:
    logger.print_statistics()