<a href="https://colab.research.google.com/github/gbdl/ECDSep/blob/main/ECDSep_graphs_arxiv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Code to reproduce the experiments performed on the dataset `ogbn-arxiv`

In [None]:
import torch

def format_pytorch_version(version):
    """Return the release part of a PyTorch version string.

    Everything from the first ``+`` onwards (for example the ``+cu118`` build
    tag) is dropped. A string without ``+`` is returned unchanged.
    """
    release, _, _ = version.partition('+')
    return release

def format_cuda_version(version):
    """Return the wheel tag for a CUDA version string.

    ``'11.8'`` becomes ``'cu118'``. ``None`` (what ``torch.version.cuda``
    reports on a PyTorch build without CUDA support) is mapped to ``'cpu'``
    so the install cell also works on CPU-only machines instead of raising
    ``AttributeError``.
    """
    if version is None:
        return 'cpu'
    return 'cu' + version.replace('.', '')

# Version string of the installed PyTorch, as reported by torch itself.
TORCH_version = torch.__version__
# The PyTorch version used in the wheel URL is pinned to 2.0.0; the
# commented-out call would derive it from the installed version instead.
TORCH = '2.0.0'#format_pytorch_version(TORCH_version)
# CUDA version reported by torch, converted to the tag used in the wheel URL.
CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

# Install the PyTorch Geometric packages from the wheel index selected by
# TORCH and CUDA, then install OGB.
!pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html;
!pip install ogb;

In [None]:
import argparse
import sys
sys.path.append("..")

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, SAGEConv

from ogb.nodeproppred import PygNodePropPredDataset, Evaluator

from inflation import ECDSep

import random
import numpy as np

## Dataset, utility functions and models, all from ogbn-arxiv. We also set some parameters as in the OGB paper.

In [None]:
class GCN(torch.nn.Module):
    """Multi-layer GCN node classifier.

    The network is a stack of ``GCNConv`` layers, all created with
    ``cached=True``. Every layer except the last is followed by
    ``BatchNorm1d``, ReLU and dropout; the output of the last layer is passed
    through ``log_softmax``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        """Build the convolution and batch-norm stacks.

        Args:
            in_channels: Size of the input node features.
            hidden_channels: Width of every intermediate layer.
            out_channels: Size of the output (one entry per class).
            num_layers: Requested depth; ``num_layers - 2`` hidden-to-hidden
                convolutions are inserted between the first and last layer.
            dropout: Dropout probability used in ``forward``.
        """
        super().__init__()

        # Convolutions are instantiated in input -> hidden -> output order.
        input_conv = GCNConv(in_channels, hidden_channels, cached=True)
        inner_convs = [
            GCNConv(hidden_channels, hidden_channels, cached=True)
            for _ in range(num_layers - 2)
        ]
        output_conv = GCNConv(hidden_channels, out_channels, cached=True)
        self.convs = torch.nn.ModuleList(
            [input_conv, *inner_convs, output_conv])

        # One batch-norm per convolution, except for the final one.
        self.bns = torch.nn.ModuleList([
            torch.nn.BatchNorm1d(hidden_channels)
            for _ in range(len(self.convs) - 1)
        ])

        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise every convolution, then every batch-norm layer."""
        for layer in (*self.convs, *self.bns):
            layer.reset_parameters()

    def forward(self, x, adj_t):
        """Return log-probabilities obtained from ``x`` and ``adj_t``.

        Both arguments are handed unchanged to every convolution.
        """
        *hidden_convs, output_conv = self.convs
        for conv, bn in zip(hidden_convs, self.bns):
            x = F.relu(bn(conv(x, adj_t)))
            x = F.dropout(x, p=self.dropout, training=self.training)
        return output_conv(x, adj_t).log_softmax(dim=-1)


class SAGE(torch.nn.Module):
    """Multi-layer GraphSAGE node classifier.

    The network is a stack of ``SAGEConv`` layers. Every layer except the
    last is followed by ``BatchNorm1d``, ReLU and dropout; the output of the
    last layer is passed through ``log_softmax``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        """Build the convolution and batch-norm stacks.

        Args:
            in_channels: Size of the input node features.
            hidden_channels: Width of every intermediate layer.
            out_channels: Size of the output (one entry per class).
            num_layers: Requested depth; ``num_layers - 2`` hidden-to-hidden
                convolutions are inserted between the first and last layer.
            dropout: Dropout probability used in ``forward``.
        """
        super().__init__()

        # Convolutions are instantiated in input -> hidden -> output order.
        first_conv = SAGEConv(in_channels, hidden_channels)
        middle_convs = [
            SAGEConv(hidden_channels, hidden_channels)
            for _ in range(num_layers - 2)
        ]
        last_conv = SAGEConv(hidden_channels, out_channels)
        self.convs = torch.nn.ModuleList(
            [first_conv, *middle_convs, last_conv])

        # One batch-norm per convolution, except for the final one.
        self.bns = torch.nn.ModuleList([
            torch.nn.BatchNorm1d(hidden_channels)
            for _ in range(len(self.convs) - 1)
        ])

        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise every convolution, then every batch-norm layer."""
        for module in (*self.convs, *self.bns):
            module.reset_parameters()

    def forward(self, x, adj_t):
        """Return log-probabilities obtained from ``x`` and ``adj_t``.

        Both arguments are handed unchanged to every convolution.
        """
        *body, head = self.convs
        for conv, norm in zip(body, self.bns):
            activated = F.relu(norm(conv(x, adj_t)))
            x = F.dropout(activated, p=self.dropout, training=self.training)
        logits = head(x, adj_t)
        return logits.log_softmax(dim=-1)


def train(model, data, train_idx, optimizer):
    """Run a single optimisation step on the training nodes.

    The model is evaluated on ``data.x`` / ``data.adj_t``, the negative
    log-likelihood is computed on the rows selected by ``train_idx``, and the
    optimizer is stepped with a closure that returns this already-computed
    loss (presumably because ECDSep needs the loss value; confirm against
    ``inflation.ECDSep``).

    Returns:
        The loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    log_probs = model(data.x, data.adj_t)[train_idx]
    targets = data.y.squeeze(1)[train_idx]
    loss = F.nll_loss(log_probs, targets)
    loss.backward()

    # The closure does not recompute anything; it only exposes the loss.
    optimizer.step(lambda: loss)
    return loss.item()


@torch.no_grad()
def test(model, data, split_idx, evaluator):
    """Evaluate ``model`` on the train, validation and test splits.

    The model is put in eval mode and applied once to ``data.x`` /
    ``data.adj_t``; the arg-max over the last dimension is the prediction.

    Returns:
        ``(train_acc, valid_acc, test_acc)``, each being the ``'acc'`` entry
        of ``evaluator.eval`` for the corresponding split.
    """
    model.eval()

    scores = model(data.x, data.adj_t)
    y_pred = scores.argmax(dim=-1, keepdim=True)

    def split_accuracy(split):
        # Compare labels and predictions restricted to one split.
        idx = split_idx[split]
        outcome = evaluator.eval({
            'y_true': data.y[idx],
            'y_pred': y_pred[idx],
        })
        return outcome['acc']

    return tuple(split_accuracy(name) for name in ('train', 'valid', 'test'))

In [None]:
# CUDA device index; only used to build the device string when CUDA is available.
device = 0
# Metrics are printed every `log_steps` epochs.
log_steps = 1
# True selects the SAGE model, False the GCN model.
use_sage = True
# Model hyperparameters passed to the GCN/SAGE constructor.
num_layers = 3
hidden_channels = 256
dropout = 0.5
# Number of training epochs in each run.
epochs = 500

In [None]:
# Resolve `device` to a torch.device exactly once. The type guard keeps this
# cell re-runnable: without it, a second execution on a CUDA machine would
# format the existing torch.device into 'cuda:cuda:0', which torch.device
# rejects.
if not isinstance(device, torch.device):
    device = torch.device(
        f'cuda:{device}' if torch.cuda.is_available() else 'cpu')

# Load ogbn-arxiv; the ToSparseTensor transform provides `data.adj_t`.
dataset = PygNodePropPredDataset(name='ogbn-arxiv',
                                 transform=T.ToSparseTensor())

data = dataset[0]
# Replace the adjacency by its symmetric version before moving to `device`.
data.adj_t = data.adj_t.to_symmetric()
data = data.to(device)

# Only the training indices are moved to `device` here.
split_idx = dataset.get_idx_split()
train_idx = split_idx['train'].to(device)

evaluator = Evaluator(name='ogbn-arxiv')

In [None]:
# Pick the architecture selected by `use_sage`, build it and move it to `device`.
model_cls = SAGE if use_sage else GCN
model = model_cls(
    data.num_features,
    hidden_channels,
    dataset.num_classes,
    num_layers,
    dropout,
).to(device)

num_params = sum(param.numel() for param in model.parameters())
print('Number of parameters:', num_params)

## Experiments
`opt` is the optimizer chosen for the experiments (among "ECDSep", "sgd", "adam", "adamw"). Remember to change the name of the optimizer and the hyperparameters in the first few lines of the next cell.

In [None]:
# Number of independent training runs.
runs = 10

# Optimizer name ("ECDSep", "sgd", "adam" or "adamw") and its hyperparameters.
opt = "ECDSep"
lr = 2.8
eta = 4.5
nu = 1e-5
wd = 0.
momentum = 0.95

# Per-run minimum training loss and maximum test accuracy.
best_losses, best_accuracies = [], []
for run in range(runs):
    # Draw a fresh seed for this run and seed the torch / Python RNGs with it.
    seed = np.random.randint(100000000)
    torch.manual_seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed(seed)

    # Every run starts from freshly initialised weights.
    model.reset_parameters()

    if opt == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=wd)

    elif opt == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    elif opt == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

    elif opt == "ECDSep":
        consEn = True
        deltaEn = 0.
        s = 1
        F0 = 0.
        optimizer = ECDSep(model.parameters(), lr=lr, eta=eta, nu=nu, s=s, deltaEn=deltaEn, consEn=consEn, F0=F0, weight_decay=wd)

    else:
        # Fail loudly instead of hitting a NameError or silently reusing the
        # optimizer left over from a previous execution of this cell.
        raise ValueError(
            f'Unknown optimizer {opt!r}; expected one of '
            '"ECDSep", "sgd", "adam", "adamw".')

    best_loss = 1e+10
    best_accuracy = 0.

    for epoch in range(1, 1 + epochs):
        loss = train(model, data, train_idx, optimizer)
        train_acc, valid_acc, test_acc = test(model, data, split_idx, evaluator)

        # Track the best values on every epoch so the reported numbers do not
        # depend on how often metrics are logged.
        best_loss = min(best_loss, loss)
        # NOTE(review): the best accuracy is selected directly on the test
        # split rather than at the best validation epoch; confirm intended.
        best_accuracy = max(best_accuracy, test_acc)

        if epoch % log_steps == 0:
            print(f'Run: {run + 1:02d}, '
                  f'Epoch: {epoch:02d}, '
                  f'Loss: {loss:.4f}, '
                  f'Train: {100 * train_acc:.2f}%, '
                  f'Valid: {100 * valid_acc:.2f}% '
                  f'Test: {100 * test_acc:.2f}%')

    best_losses.append(best_loss)
    best_accuracies.append(best_accuracy)

In [None]:
# Report the per-run bests averaged over all runs.
summary_suffix = f" for {opt} over {runs} runs is "
print("Average best accuracy" + summary_suffix, np.mean(best_accuracies))
print("Average minimum loss" + summary_suffix, np.mean(best_losses))