In [1]:
from ogb.graphproppred import PygGraphPropPredDataset
from ogb.linkproppred import PygLinkPropPredDataset
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
import torch

%reload_ext autoreload
%autoreload 
#pytorch==2.0.0 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 pyg=2.4.0 ogb=1.3.6
#pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.0+cu117.html
torch.__version__


'2.0.0+cu117'

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device)

In [3]:
# Download and process data at './dataset/ogbg_molhiv/'
dataset = PygLinkPropPredDataset(name='ogbl-ppa',
                                transform=T.ToSparseTensor())

In [None]:
data = dataset[0]
data.x = data.x.to(torch.float)
#data.x = torch.cat([data.x, torch.load('embedding.pt')], dim=-1)
data = data.to(device)
split_edge = dataset.get_edge_split()

In [None]:
adj_t = data.adj_t.set_diag()
deg = adj_t.sum(dim=1).to(torch.float)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
adj_t = deg_inv_sqrt.view(-1, 1) * adj_t * deg_inv_sqrt.view(1, -1)
data.adj_t = adj_t

In [3]:
# Download and process data at './dataset/ogbg_molhiv/'
dataset = PygGraphPropPredDataset(name='ogbg-molpcba',)

In [4]:
data = dataset[0]
data

Data(edge_index=[2, 44], edge_attr=[44, 3], x=[20, 9], y=[1, 128], num_nodes=20)

## Model

In [8]:
from lib.training.training_base import read_config_from_file
from lib.training.importer import import_scheme

ModuleNotFoundError: No module named 'tensorflow'

In [5]:
from torch_geometric.loader import DataLoader
import torch
from torch_geometric.loader import DataLoader
import torch.optim as optim
import torch.nn.functional as F

from tqdm import tqdm
import numpy as np

In [6]:
device = 0
drop_ratio = 0.5
num_layer = 4
emb_dim = 64
batch_size = 64
epochs = 10
num_workers = 0
dataset_name = "ogbg-molpcba"
filename = ""

In [13]:
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_mean_pool
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder

In [7]:
class GNN(torch.nn.Module):

    def __init__(self, num_tasks, num_layer = 5, emb_dim = 300, 
                residual = False, 
                drop_ratio = 0.5, JK = "last", graph_pooling = "mean"):
        super(GNN, self).__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.gnn_node = GNN_node(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual)
        self.pool = global_mean_pool
        self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, batched_data):
        h_node = self.gnn_node(batched_data)
        h_graph = self.pool(h_node)
        return self.graph_pred_linear(h_graph)

### GNN to generate node embedding
class GNN_node(torch.nn.Module):
    """
    Output:
        node representations
    """
    def __init__(self, num_layer, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
        super(GNN_node, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        self.atom_encoder = AtomEncoder(emb_dim)
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for _ in range(num_layer):
            self.convs.append(GINConv(emb_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_data):
        x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch

        ### computing input node embedding

        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layer):

            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)

            if layer == self.num_layer - 1:
                #remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)

            if self.residual:
                h += h_list[layer]

            h_list.append(h)

        ### Different implementations of Jk-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layer + 1):
                node_representation += h_list[layer]

        return node_representation

### GIN convolution along the graph structure
class GINConv(MessagePassing):
    def __init__(self, emb_dim):
        '''
            emb_dim (int): node embedding dimensionality
        '''

        super(GINConv, self).__init__(aggr = "add")

        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
        self.eps = torch.nn.Parameter(torch.Tensor([0]))

        self.bond_encoder = BondEncoder(emb_dim = emb_dim)

    def forward(self, x, edge_index, edge_attr):
        edge_embedding = self.bond_encoder(edge_attr)
        out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))

        return out

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out

In [7]:
cls_criterion = torch.nn.BCEWithLogitsLoss()
reg_criterion = torch.nn.MSELoss()

def train(model, device, loader, optimizer, task_type):
    model.train()

    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)

        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            pass
        else:
            pred = model(batch)
            optimizer.zero_grad()
            ## ignore nan targets (unlabeled) when computing training loss.
            is_labeled = batch.y == batch.y
            if "classification" in task_type: 
                loss = cls_criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled])
            else:
                loss = reg_criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled])
            loss.backward()
            optimizer.step()

def eval(model, device, loader, evaluator):
    model.eval()
    y_true = []
    y_pred = []

    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)

        if batch.x.shape[0] == 1:
            pass
        else:
            with torch.no_grad():
                pred = model(batch)

            y_true.append(batch.y.view(pred.shape).detach().cpu())
            y_pred.append(pred.detach().cpu())

    y_true = torch.cat(y_true, dim = 0).numpy()
    y_pred = torch.cat(y_pred, dim = 0).numpy()

    input_dict = {"y_true": y_true, "y_pred": y_pred}

    return evaluator.eval(input_dict)

In [8]:
import sys
sys.path.append("../scPRINT/scprint/model")

from EGT import EGTLayer
from collections import Counter

%reload_ext autoreload
%autoreload 2

In [15]:
class EGT(torch.nn.Module):

    def __init__(self, num_tasks, num_layer = 5, emb_dim = 300, 
                drop_ratio = 0.5):
        super(EGT, self).__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.model = EGT_node(num_layer, emb_dim, drop_ratio = drop_ratio)
        self.pool = global_mean_pool
        self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, batched_data):
        h_node = self.model(batched_data)
        import pdb
        pdb.set_trace()
        h_graph = self.pool(h_node, batched_data.batch)
        return self.graph_pred_linear(h_graph)

class EGT_node(torch.nn.Module):
    """
    Output:
        node representations
    """
    def __init__(self, num_layer, emb_dim, num_heads=4, edge_feat_size=12, drop_ratio = 0.5):
        super(EGT_node, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.edge_feat_size = edge_feat_size
        self.emb_dim = emb_dim
        
        ### add residual connection or not
        self.bond_encoder = BondEncoder(emb_dim = edge_feat_size)

        self.atom_encoder = AtomEncoder(emb_dim)
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()        

        for _ in range(num_layer):
            self.convs.append(EGTLayer(feat_size=emb_dim, inner_size=emb_dim, edge_feat_size=edge_feat_size, num_heads=num_heads, num_virtual_nodes=1, dropout=drop_ratio))


    def forward(self, batched_data):
        x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch
        N = max(Counter(batch.tolist()).values())
        x = self.atom_encoder(x)
        ### computing input node embedding
        batch_x_list = x.split(list(Counter(batch.tolist()).values()))
        h = torch.stack([torch.cat([x, torch.zeros((N-x.size(0), self.emb_dim), dtype=torch.int64).to(device)]) for x in batch_x_list]).to(device)
        
        # do the same for the other ones
        attr = torch.zeros((batched_data.num_nodes, batched_data.num_nodes, self.edge_feat_size)).to(device)
        attr[edge_index[0], edge_index[1]] = self.bond_encoder(edge_attr)
        edge_attr = []
        mask = []
        
        for i,j in enumerate(batched_data.ptr[1:]):
            i = batched_data.ptr[i]
            e = torch.zeros((N, N, self.edge_feat_size)).to(device)
            e[:j-i, :j-i] = attr[i:j,i:j]
            edge_attr.append(e)
            m = torch.zeros((N, N)) - np.inf
            m[:j-i, :j-i] = torch.zeros((j-i, j-i))
            mask.append(m)
        edge_attr = torch.stack(edge_attr)
        mask = torch.stack(mask).to(device)

        for layer in range(self.num_layer):
            h, edge = self.convs[layer](h, edge, mask=mask)

        ### Different implementations of Jk-concat

        return h

In [16]:
for batch in train_loader:
    break

In [17]:
batch

DataBatch(edge_index=[2, 3522], edge_attr=[3522, 3], x=[1637, 9], y=[64, 128], num_nodes=1637, batch=[1637], ptr=[65])

In [22]:
N = max(Counter(batch.batch.tolist()).values())

In [24]:
batch = batch.to("cpu")

In [31]:
AtomEncoder(64)(h[0])

tensor([[ 0.4792,  0.5729,  0.7001,  ..., -0.9425, -0.1237,  0.1864],
        [ 0.7802,  0.4655,  0.7964,  ..., -0.8122, -0.0919,  0.3069],
        [ 0.4271,  0.4765,  0.5404,  ..., -0.6555, -0.3147,  0.0961],
        ...,
        [ 0.1348,  0.3829,  0.5062,  ..., -0.2764, -0.1937,  0.4553],
        [ 0.1348,  0.3829,  0.5062,  ..., -0.2764, -0.1937,  0.4553],
        [ 0.1348,  0.3829,  0.5062,  ..., -0.2764, -0.1937,  0.4553]],
       grad_fn=<AddBackward0>)

In [33]:
batch_x_list = batch.x.split(list(Counter(batch.batch.tolist()).values()))
h = torch.stack([torch.cat([AtomEncoder(emb_dim)(x), torch.zeros((N-x.size(0), emb_dim), dtype=torch.int64)]) for x in batch_x_list])

In [29]:
N

39

In [28]:
N

torch.Size([64, 39, 9])

In [11]:
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator

In [14]:
split_idx = dataset.get_idx_split()

### automatic evaluator. takes dataset name as input
evaluator = Evaluator(dataset_name)

train_loader = DataLoader(dataset[split_idx["train"]], batch_size=batch_size, shuffle=True, num_workers = num_workers)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=batch_size, shuffle=False, num_workers = num_workers)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=batch_size, shuffle=False, num_workers = num_workers)


model = EGT(num_tasks = dataset.num_tasks, num_layer = num_layer, emb_dim = emb_dim, drop_ratio = drop_ratio).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)

valid_curve = []
test_curve = []
train_curve = []

for epoch in range(1, epochs + 1):
    print("=====Epoch {}".format(epoch))
    print('Training...')
    train(model, device, train_loader, optimizer, dataset.task_type)

    print('Evaluating...')
    train_perf = eval(model, device, train_loader, evaluator)
    valid_perf = eval(model, device, valid_loader, evaluator)
    test_perf = eval(model, device, test_loader, evaluator)

    print({'Train': train_perf, 'Validation': valid_perf, 'Test': test_perf})

    train_curve.append(train_perf[dataset.eval_metric])
    valid_curve.append(valid_perf[dataset.eval_metric])
    test_curve.append(test_perf[dataset.eval_metric])

if 'classification' in dataset.task_type:
    best_val_epoch = np.argmax(np.array(valid_curve))
    best_train = max(train_curve)
else:
    best_val_epoch = np.argmin(np.array(valid_curve))
    best_train = min(train_curve)

print('Finished training!')
print('Best validation score: {}'.format(valid_curve[best_val_epoch]))
print('Test score: {}'.format(test_curve[best_val_epoch]))

if not filename == '':
    torch.save({'Val': valid_curve[best_val_epoch], 'Test': test_curve[best_val_epoch], 'Train': train_curve[best_val_epoch], 'BestTrain': best_train}, filename)
    

=====Epoch 1
Training...


Iteration:   0%|          | 0/5475 [00:00<?, ?it/s]


UnboundLocalError: local variable 'edge' referenced before assignment

# I reached 0.11 with the previous model in the same number of epochs

In [7]:
import torch.nn.functional as F
from torch.utils.data import DataLoader
from dgl.nn import EGTLayer
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, SAGEConv

from ogb.linkproppred import Evaluator

class EGT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super(EGT, self).__init__()

        self.layers = torch.nn.ModuleList()
        self.layers.append(
            EGTLayer(feat_size=in_channels, edge_feat_size=1, num_heads=2, num_virtual_nodes=1, normalize=False))
        for _ in range(num_layers - 2):
            self.convs.append(
                GCNConv(hidden_channels, hidden_channels, normalize=False))
        self.convs.append(
            GCNConv(hidden_channels, out_channels, normalize=False))

        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, adj_t):
        for conv in self.convs[:-1]:
            x = conv(x, adj_t)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, adj_t)
        return x

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super(GCN, self).__init__()

        self.convs = torch.nn.ModuleList()
        self.convs.append(
            GCNConv(in_channels, hidden_channels, normalize=False))
        for _ in range(num_layers - 2):
            self.convs.append(
                GCNConv(hidden_channels, hidden_channels, normalize=False))
        self.convs.append(
            GCNConv(hidden_channels, out_channels, normalize=False))

        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, adj_t):
        for conv in self.convs[:-1]:
            x = conv(x, adj_t)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, adj_t)
        return x

class LinkPredictor(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super(LinkPredictor, self).__init__()

        self.lins = torch.nn.ModuleList()
        self.lins.append(torch.nn.Linear(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
        self.lins.append(torch.nn.Linear(hidden_channels, out_channels))

        self.dropout = dropout

    def reset_parameters(self):
        for lin in self.lins:
            lin.reset_parameters()

    def forward(self, x_i, x_j):
        x = x_i * x_j
        for lin in self.lins[:-1]:
            x = lin(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)
        return torch.sigmoid(x)


def train(model, predictor, data, split_edge, optimizer, batch_size):
    model.train()
    predictor.train()

    pos_train_edge = split_edge['train']['edge'].to(data.x.device)

    total_loss = total_examples = 0
    for perm in DataLoader(range(pos_train_edge.size(0)), batch_size,
                           shuffle=True):

        optimizer.zero_grad()

        h = model(data.x, data.adj_t)

        edge = pos_train_edge[perm].t()
        pos_out = predictor(h[edge[0]], h[edge[1]])
        pos_loss = -torch.log(pos_out + 1e-15).mean()

        # Just do some trivial random sampling.
        edge = torch.randint(0, data.num_nodes, edge.size(), dtype=torch.long,
                             device=h.device)

        neg_out = predictor(h[edge[0]], h[edge[1]])
        neg_loss = -torch.log(1 - neg_out + 1e-15).mean()

        loss = pos_loss + neg_loss
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        torch.nn.utils.clip_grad_norm_(predictor.parameters(), 1.0)

        optimizer.step()

        num_examples = pos_out.size(0)
        total_loss += loss.item() * num_examples
        total_examples += num_examples

    return total_loss / total_examples


@torch.no_grad()
def test(model, predictor, data, split_edge, evaluator, batch_size):
    model.eval()

    h = model(data.x, data.adj_t)

    pos_train_edge = split_edge['train']['edge'].to(h.device)
    pos_valid_edge = split_edge['valid']['edge'].to(h.device)
    neg_valid_edge = split_edge['valid']['edge_neg'].to(h.device)
    pos_test_edge = split_edge['test']['edge'].to(h.device)
    neg_test_edge = split_edge['test']['edge_neg'].to(h.device)

    pos_train_preds = []
    for perm in DataLoader(range(pos_train_edge.size(0)), batch_size):
        edge = pos_train_edge[perm].t()
        pos_train_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_train_pred = torch.cat(pos_train_preds, dim=0)

    pos_valid_preds = []
    for perm in DataLoader(range(pos_valid_edge.size(0)), batch_size):
        edge = pos_valid_edge[perm].t()
        pos_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_valid_pred = torch.cat(pos_valid_preds, dim=0)

    neg_valid_preds = []
    for perm in DataLoader(range(neg_valid_edge.size(0)), batch_size):
        edge = neg_valid_edge[perm].t()
        neg_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    neg_valid_pred = torch.cat(neg_valid_preds, dim=0)

    pos_test_preds = []
    for perm in DataLoader(range(pos_test_edge.size(0)), batch_size):
        edge = pos_test_edge[perm].t()
        pos_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_test_pred = torch.cat(pos_test_preds, dim=0)

    neg_test_preds = []
    for perm in DataLoader(range(neg_test_edge.size(0)), batch_size):
        edge = neg_test_edge[perm].t()
        neg_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    neg_test_pred = torch.cat(neg_test_preds, dim=0)

    results = {}
    for K in [10, 50, 100]:
        evaluator.K = K
        train_hits = evaluator.eval({
            'y_pred_pos': pos_train_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        valid_hits = evaluator.eval({
            'y_pred_pos': pos_valid_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        test_hits = evaluator.eval({
            'y_pred_pos': pos_test_pred,
            'y_pred_neg': neg_test_pred,
        })[f'hits@{K}']

        results[f'Hits@{K}'] = (train_hits, valid_hits, test_hits)

    return results

In [8]:
def train(model, predictor, data, split_edge, optimizer, batch_size):
    model.train()
    predictor.train()

    pos_train_edge = split_edge['train']['edge'].to(data.x.device)

    total_loss = total_examples = 0
    for perm in DataLoader(range(pos_train_edge.size(0)), batch_size,
                           shuffle=True):

        optimizer.zero_grad()

        h = model(data.x, data.adj_t)

        edge = pos_train_edge[perm].t()
        pos_out = predictor(h[edge[0]], h[edge[1]])
        pos_loss = -torch.log(pos_out + 1e-15).mean()

        # Just do some trivial random sampling.
        edge = torch.randint(0, data.num_nodes, edge.size(), dtype=torch.long,
                             device=h.device)

        neg_out = predictor(h[edge[0]], h[edge[1]])
        neg_loss = -torch.log(1 - neg_out + 1e-15).mean()

        loss = pos_loss + neg_loss
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        torch.nn.utils.clip_grad_norm_(predictor.parameters(), 1.0)

        optimizer.step()

        num_examples = pos_out.size(0)
        total_loss += loss.item() * num_examples
        total_examples += num_examples

    return total_loss / total_examples


@torch.no_grad()
def test(model, predictor, data, split_edge, evaluator, batch_size):
    model.eval()

    h = model(data.x, data.adj_t)

    pos_train_edge = split_edge['train']['edge'].to(h.device)
    pos_valid_edge = split_edge['valid']['edge'].to(h.device)
    neg_valid_edge = split_edge['valid']['edge_neg'].to(h.device)
    pos_test_edge = split_edge['test']['edge'].to(h.device)
    neg_test_edge = split_edge['test']['edge_neg'].to(h.device)

    pos_train_preds = []
    for perm in DataLoader(range(pos_train_edge.size(0)), batch_size):
        edge = pos_train_edge[perm].t()
        pos_train_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_train_pred = torch.cat(pos_train_preds, dim=0)

    pos_valid_preds = []
    for perm in DataLoader(range(pos_valid_edge.size(0)), batch_size):
        edge = pos_valid_edge[perm].t()
        pos_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_valid_pred = torch.cat(pos_valid_preds, dim=0)

    neg_valid_preds = []
    for perm in DataLoader(range(neg_valid_edge.size(0)), batch_size):
        edge = neg_valid_edge[perm].t()
        neg_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    neg_valid_pred = torch.cat(neg_valid_preds, dim=0)

    pos_test_preds = []
    for perm in DataLoader(range(pos_test_edge.size(0)), batch_size):
        edge = pos_test_edge[perm].t()
        pos_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_test_pred = torch.cat(pos_test_preds, dim=0)

    neg_test_preds = []
    for perm in DataLoader(range(neg_test_edge.size(0)), batch_size):
        edge = neg_test_edge[perm].t()
        neg_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    neg_test_pred = torch.cat(neg_test_preds, dim=0)

    results = {}
    for K in [10, 50, 100]:
        evaluator.K = K
        train_hits = evaluator.eval({
            'y_pred_pos': pos_train_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        valid_hits = evaluator.eval({
            'y_pred_pos': pos_valid_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        test_hits = evaluator.eval({
            'y_pred_pos': pos_test_pred,
            'y_pred_neg': neg_test_pred,
        })[f'hits@{K}']

        results[f'Hits@{K}'] = (train_hits, valid_hits, test_hits)

    return results

class Logger(object):
    def __init__(self, runs, info=None):
        self.info = info
        self.results = [[] for _ in range(runs)]

    def add_result(self, run, result):
        assert len(result) == 3
        assert run >= 0 and run < len(self.results)
        self.results[run].append(result)

    def reset(self, run):
        assert run >= 0 and run < len(self.results)
        self.results[run] = []

    def print_statistics(self, run=None):
        if run is not None:
            result = torch.tensor(self.results[run])
            argmax = result[:, 1].argmax().item()
            print(f'Run {run + 1:02d}:')
            print(f'Highest Train: {result[:, 0].max():.4f}')
            print(f'Highest Valid: {result[:, 1].max():.4f}')
            print(f'  Final Train: {result[argmax, 0]:.4f}')
            print(f'   Final Test: {result[argmax, 2]:.4f}')
        else:
            result = torch.tensor(self.results)

            best_results = []
            for r in result:
                train1 = r[:, 0].max().item()
                valid = r[:, 1].max().item()
                train2 = r[r[:, 1].argmax(), 0].item()
                test = r[r[:, 1].argmax(), 2].item()
                best_results.append((train1, valid, train2, test))

            best_result = torch.tensor(best_results)
            print(best_result)

            print(f'All runs:')
            r = best_result[:, 0]
            print(f'Highest Train: {r.mean():.4f} ± {r.std():.4f}')
            r = best_result[:, 1]
            print(f'Highest Valid: {r.mean():.4f} ± {r.std():.4f}')
            r = best_result[:, 2]
            print(f'  Final Train: {r.mean():.4f} ± {r.std():.4f}')
            r = best_result[:, 3]
            print(f'   Final Test: {r.mean():.4f} ± {r.std():.4f}')

# train

In [9]:
import attridict
args = {
    "log_steps" : 1,
    "runs" : 20,
    "lr" : 0.01,
    "epochs" : 4,
    "batch_size" : 64 * 1024,
    "eval_steps" : 2,
    "hidden_channel" : 256,
    "dropout" : 0.1,
    "num_layers" : 3,
}
args = attridict(args)

In [10]:
model = GCN(in_channels=data.num_features, hidden_channels=args.hidden_channel, out_channels=args.hidden_channel, num_layers=args.num_layers, dropout=args.dropout).to(device)

In [11]:
predictor = LinkPredictor(in_channels=args.hidden_channel, hidden_channels=args.hidden_channel, out_channels=1, num_layers=args.num_layers, dropout=0).to(device)

In [12]:
evaluator = Evaluator(name='ogbl-ppa')
loggers = {
    'Hits@10': Logger(args.runs, args),
    'Hits@50': Logger(args.runs, args),
    'Hits@100': Logger(args.runs, args),
}

In [13]:
for run in range(args.runs):
    model.reset_parameters()
    predictor.reset_parameters()
    optimizer = torch.optim.Adam(
        list(model.parameters()) + list(predictor.parameters()),
        lr=args.lr)

    for epoch in range(1, 1 + args.epochs):
        loss = train(model, predictor, data, split_edge, optimizer,
                        args.batch_size)

        if epoch % args.eval_steps == 0:
            results = test(model, predictor, data, split_edge, evaluator,
                            args.batch_size)
            for key, result in results.items():
                loggers[key].add_result(run, result)

            if epoch % args.log_steps == 0:
                for key, result in results.items():
                    train_hits, valid_hits, test_hits = result
                    print(key)
                    print(f'Run: {run + 1:02d}, '
                            f'Epoch: {epoch:02d}, '
                            f'Loss: {loss:.4f}, '
                            f'Train: {100 * train_hits:.2f}%, '
                            f'Valid: {100 * valid_hits:.2f}%, '
                            f'Test: {100 * test_hits:.2f}%')

    for key in loggers.keys():
        print(key)
        loggers[key].print_statistics(run)

for key in loggers.keys():
    print(key)
    loggers[key].print_statistics()

Hits@10
Run: 01, Epoch: 02, Loss: 0.2184, Train: 0.63%, Valid: 0.61%, Test: 0.83%
Hits@50
Run: 01, Epoch: 02, Loss: 0.2184, Train: 4.86%, Valid: 4.75%, Test: 5.32%
Hits@100
Run: 01, Epoch: 02, Loss: 0.2184, Train: 8.71%, Valid: 8.57%, Test: 9.61%
Hits@10
Run: 01, Epoch: 04, Loss: 0.1531, Train: 0.41%, Valid: 0.36%, Test: 0.84%
Hits@50
Run: 01, Epoch: 04, Loss: 0.1531, Train: 2.69%, Valid: 2.51%, Test: 3.18%
Hits@100
Run: 01, Epoch: 04, Loss: 0.1531, Train: 5.88%, Valid: 5.59%, Test: 5.45%
Hits@10
Run 01:
Highest Train: 0.0063
Highest Valid: 0.0061
  Final Train: 0.0063
   Final Test: 0.0083
Hits@50
Run 01:
Highest Train: 0.0486
Highest Valid: 0.0475
  Final Train: 0.0486
   Final Test: 0.0532
Hits@100
Run 01:
Highest Train: 0.0871
Highest Valid: 0.0857
  Final Train: 0.0871
   Final Test: 0.0961
Hits@10
Run: 02, Epoch: 02, Loss: 0.1996, Train: 0.85%, Valid: 0.81%, Test: 1.04%
Hits@50
Run: 02, Epoch: 02, Loss: 0.1996, Train: 6.52%, Valid: 6.38%, Test: 5.74%
Hits@100
Run: 02, Epoch: 02, 