## Evaluating my KG on OGB dataset

In [1]:
import pandas as pd

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch_geometric.utils import negative_sampling
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, SAGEConv

from kg_model import KG_model

from ogb.linkproppred import Evaluator, PygLinkPropPredDataset

from pykeen.evaluation import RankBasedEvaluator
from pykeen.triples import TriplesFactory
from pykeen.pipeline import pipeline 

  from .autonotebook import tqdm as notebook_tqdm


### Load and prepare data

In [2]:
dataset = PygLinkPropPredDataset(name='ogbl-ddi', transform=T.ToSparseTensor())
data = dataset[0]
data.adj_t

SparseTensor(row=tensor([   0,    0,    0,  ..., 4266, 4266, 4266]),
             col=tensor([   4,    6,    7,  ..., 3953, 3972, 4014]),
             size=(4267, 4267), nnz=2135822, density=11.73%)

In [3]:
split_edge = dataset.get_edge_split()
train_edge, valid_edge, test_edge = split_edge["train"], split_edge["valid"], split_edge["test"]
train_edge

{'edge': tensor([[4039, 2424],
         [4039,  225],
         [4039, 3901],
         ...,
         [ 647,  708],
         [ 708,  338],
         [ 835, 3554]])}

In [4]:
print(valid_edge['edge_neg'].shape)
print(valid_edge['edge'].shape)

torch.Size([101882, 2])
torch.Size([133489, 2])


In [5]:
def convert_to_triples_factory(data):
    tf_data = TriplesFactory.from_labeled_triples(
        data[["head", "relation", "tail"]].values,
        create_inverse_triples=True,
        entity_to_id=None,
        relation_to_id=None,
        compact_id=False 
    )

    print(tf_data.mapped_triples)

    return tf_data

In [6]:
# add relation type - interacts with

train = train_edge['edge']
train = torch.tensor([[x[0], 0, x[1]] for x in train])
train_df = pd.DataFrame(train, columns=['head', 'relation', 'tail']).astype(str)

valid = valid_edge['edge']
valid = torch.tensor([[x[0], 0, x[1]] for x in valid])
valid_df = pd.DataFrame(valid, columns=['head', 'relation', 'tail']).astype(str)

valid_neg = valid_edge['edge_neg']
valid_neg = torch.tensor([[x[0], 0, x[1]] for x in valid_neg])

test = test_edge['edge']
test = torch.tensor([[x[0], 0, x[1]] for x in test])
test_df = pd.DataFrame(test, columns=['head', 'relation', 'tail']).astype(str)

test_neg = test_edge['edge_neg']
test_neg = torch.tensor([[x[0], 0, x[1]] for x in test_neg])

train_tf = convert_to_triples_factory(train_df)
valid_tf = convert_to_triples_factory(valid_df)
test_tf = convert_to_triples_factory(test_df)

tensor([[   0,    0,  667],
        [   0,    0, 1182],
        [   0,    0, 1280],
        ...,
        [4266,    0, 4250],
        [4266,    0, 4252],
        [4266,    0, 4260]])
tensor([[   0,    0,  729],
        [   1,    0,  681],
        [   1,    0,  768],
        ...,
        [3812,    0, 3722],
        [3812,    0, 3758],
        [3812,    0, 3802]])
tensor([[   0,    0,    3],
        [   0,    0,  185],
        [   0,    0,  187],
        ...,
        [1611,    0, 1562],
        [1611,    0, 1573],
        [1611,    0, 1601]])


### Train my KG model

In [None]:
model_kg = KG_model('TransE', train_tf, valid_tf, test_tf, 'ogb')
model_kg.set_params(20, 'Adam', RankBasedEvaluator, 'gpu')
print('Training...')
model_kg.train()
print('Training done')

### Compute scores for given triplets

In [8]:
# compute scores for positive and negative triplets 

batch_size = 512

n = train.size(0) // batch_size
pos_train_preds = []
for i in range(n+1):
    start_idx = i*batch_size
    end_idx = min((i+1)*batch_size, train.size(0))
    edge = train[start_idx:end_idx]
    pos_train_preds += [model_kg.trained_model.model.score_hrt(edge).squeeze().cpu().detach()]
pos_train_pred = torch.cat(pos_train_preds, dim=0)

n = valid.size(0) // batch_size
pos_valid_preds = []
for i in range(n+1):
    start_idx = i*batch_size
    end_idx = min((i+1)*batch_size, valid.size(0))
    edge = valid[start_idx:end_idx]
    pos_valid_preds += [model_kg.trained_model.model.score_hrt(edge).squeeze().cpu().detach()]
pos_valid_pred = torch.cat(pos_valid_preds, dim=0)

n = valid_neg.size(0) // batch_size
neg_valid_preds = []
for i in range(n+1):
    start_idx = i*batch_size
    end_idx = min((i+1)*batch_size, valid_neg.size(0))
    edge = valid_neg[start_idx:end_idx]
    neg_valid_preds += [model_kg.trained_model.model.score_hrt(edge).squeeze().cpu().detach()]
neg_valid_pred = torch.cat(neg_valid_preds, dim=0)

n = test.size(0) // batch_size
pos_test_preds = []
for i in range(n+1):
    start_idx = i*batch_size
    end_idx = min((i+1)*batch_size, test.size(0))
    edge = test[start_idx:end_idx]
    pos_test_preds += [model_kg.trained_model.model.score_hrt(edge).squeeze().cpu().detach()]
pos_test_pred = torch.cat(pos_test_preds, dim=0)

n = test_neg.size(0) // batch_size
neg_test_preds = []
for i in range(n+1):
    start_idx = i*batch_size
    end_idx = min((i+1)*batch_size, test_neg.size(0))
    edge = test_neg[start_idx:end_idx]
    neg_test_preds += [model_kg.trained_model.model.score_hrt(edge).squeeze().cpu().detach()]
neg_test_pred = torch.cat(neg_test_preds, dim=0)

### Evaluate my results

In [9]:
# Evaluate the coputed scores - hits@K

evaluator = Evaluator(name = 'ogbl-ddi')

results = {}
for K in [10, 20, 30]:
    evaluator.K = K
    train_hits = evaluator.eval({
        'y_pred_pos': pos_train_pred,
        'y_pred_neg': neg_valid_pred,
    })[f'hits@{K}']
    valid_hits = evaluator.eval({
        'y_pred_pos': pos_valid_pred,
        'y_pred_neg': neg_valid_pred,
    })[f'hits@{K}']
    test_hits = evaluator.eval({
        'y_pred_pos': pos_test_pred,
        'y_pred_neg': neg_test_pred,
    })[f'hits@{K}']
    
    results[f'Hits@{K}'] = (train_hits, valid_hits, test_hits)
    
    
for hits, result in results.items():
    print(hits)
#     print(result)
    train_hits, valid_hits, test_hits = result
    print(f'Train: {100 * train_hits:.2f}%')
    print(f'Valid: {100 * valid_hits:.2f}%')
    print(f'Test: {100 * test_hits:.2f}%')


Hits@10
Train: 0.01%
Valid: 0.01%
Test: 0.01%
Hits@20
Train: 0.02%
Valid: 0.02%
Test: 0.02%
Hits@30
Train: 0.03%
Valid: 0.03%
Test: 0.03%


In [31]:
model_kg.predict_head('2424', '0')

Leuprolide - decrease_adverse_effects:
      head_id head_label     score  in_training
1776     1776       2597 -4.225868         True
1787     1787       2606 -4.259010         True
1549     1549       2392 -4.277555         True
3626     3626       4261 -4.303725         True
4148     4148        892 -4.309386         True
1917     1917       2723 -4.312338        False
1929     1929       2734 -4.312355         True
1610     1610       2447 -4.334918         True
3955     3955        718 -4.335896        False
1236     1236        211 -4.366917         True


In [32]:
print(model_kg.trained_model.get_metric('hits@1'))
print(model_kg.trained_model.get_metric('hits@5'))
print(model_kg.trained_model.get_metric('hits@10'))

0.0005993003168800426
0.0036557319329682597
0.007348920135741521


### Optimizing parameters

In [9]:
from pykeen.hpo import hpo_pipeline_from_config

config = {
    'optuna': dict(
        n_trials=5,
    ),
    'pipeline': dict(
#         dataset='Nations',
        training = train_tf,
        testing = test_tf,
        validation = valid_tf,
        model='TransE',
        model_kwargs_ranges=dict(
               embedding_dim=dict(type=int, low=50, high=220, q=20),
        ),
        optimizer='Adam',
        optimizer_kwargs=dict(lr=0.01),
        loss='marginranking',
        loss_kwargs=dict(margin=1),
        training_loop='slcwa',
        training_kwargs=dict(num_epochs=20, batch_size=128),
        negative_sampler='basic',
        negative_sampler_kwargs=dict(num_negs_per_pos=1),
        evaluator_kwargs=dict(filtered=True),
        evaluation_kwargs=dict(batch_size=128),
        stopper='early',
        stopper_kwargs=dict(frequency=5, patience=2, relative_delta=0.002),
    )
}

In [10]:
hpo_pipeline_result = hpo_pipeline_from_config(config)

[32m[I 2023-01-30 08:57:20,367][0m A new study created in memory with name: no-name-c73c846a-6d4b-4b64-a2f9-afaca0b6b610[0m
No random seed is specified. Setting to 1486788349.
INFO:pykeen.triples.triples_factory:Creating inverse triples.
Training epochs on cuda:0:   0%|                                               | 0/20 [00:00<?, ?epoch/s]INFO:pykeen.triples.triples_factory:Creating inverse triples.

Training batches on cuda:0:   0%|                                           | 0/16687 [00:00<?, ?batch/s][A
Training batches on cuda:0:   0%|                                 | 14/16687 [00:00<02:00, 138.10batch/s][A
Training batches on cuda:0:   0%|                                 | 55/16687 [00:00<00:56, 296.44batch/s][A
Training batches on cuda:0:   1%|▏                               | 101/16687 [00:00<00:45, 367.95batch/s][A
Training batches on cuda:0:   1%|▎                               | 155/16687 [00:00<00:38, 433.99batch/s][A
Training batches on cuda:0:   1%|▍            

Training batches on cuda:0:  24%|███████▍                       | 4007/16687 [00:07<00:21, 580.37batch/s][A
Training batches on cuda:0:  24%|███████▌                       | 4067/16687 [00:07<00:21, 583.98batch/s][A
Training batches on cuda:0:  25%|███████▋                       | 4127/16687 [00:07<00:21, 586.76batch/s][A
Training batches on cuda:0:  25%|███████▊                       | 4187/16687 [00:07<00:21, 588.24batch/s][A
Training batches on cuda:0:  25%|███████▉                       | 4247/16687 [00:07<00:21, 589.53batch/s][A
Training batches on cuda:0:  26%|████████                       | 4307/16687 [00:07<00:20, 590.63batch/s][A
Training batches on cuda:0:  26%|████████                       | 4367/16687 [00:07<00:20, 591.18batch/s][A
Training batches on cuda:0:  27%|████████▏                      | 4427/16687 [00:07<00:20, 591.35batch/s][A
Training batches on cuda:0:  27%|████████▎                      | 4487/16687 [00:07<00:20, 592.11batch/s][A
Training batches on

Training batches on cuda:0:  50%|███████████████▌               | 8406/16687 [00:14<00:14, 590.05batch/s][A
Training batches on cuda:0:  51%|███████████████▋               | 8466/16687 [00:14<00:13, 590.07batch/s][A
Training batches on cuda:0:  51%|███████████████▊               | 8526/16687 [00:14<00:13, 590.26batch/s][A
Training batches on cuda:0:  51%|███████████████▉               | 8586/16687 [00:14<00:13, 590.19batch/s][A
Training batches on cuda:0:  52%|████████████████               | 8646/16687 [00:15<00:13, 590.20batch/s][A
Training batches on cuda:0:  52%|████████████████▏              | 8706/16687 [00:15<00:13, 590.10batch/s][A
Training batches on cuda:0:  53%|████████████████▎              | 8766/16687 [00:15<00:13, 590.07batch/s][A
Training batches on cuda:0:  53%|████████████████▍              | 8826/16687 [00:15<00:13, 590.10batch/s][A
Training batches on cuda:0:  53%|████████████████▌              | 8886/16687 [00:15<00:13, 590.01batch/s][A
Training batches on

Training batches on cuda:0:  77%|███████████████████████▏      | 12899/16687 [00:22<00:06, 541.16batch/s][A
Training batches on cuda:0:  78%|███████████████████████▎      | 12954/16687 [00:22<00:07, 527.48batch/s][A
Training batches on cuda:0:  78%|███████████████████████▍      | 13007/16687 [00:22<00:07, 518.27batch/s][A
Training batches on cuda:0:  78%|███████████████████████▍      | 13059/16687 [00:22<00:07, 509.19batch/s][A
Training batches on cuda:0:  79%|███████████████████████▌      | 13110/16687 [00:22<00:07, 509.36batch/s][A
Training batches on cuda:0:  79%|███████████████████████▋      | 13161/16687 [00:22<00:07, 503.69batch/s][A
Training batches on cuda:0:  79%|███████████████████████▊      | 13215/16687 [00:22<00:06, 513.92batch/s][A
Training batches on cuda:0:  80%|███████████████████████▊      | 13274/16687 [00:23<00:06, 536.18batch/s][A
Training batches on cuda:0:  80%|███████████████████████▉      | 13332/16687 [00:23<00:06, 546.95batch/s][A
Training batches on

Training batches on cuda:0:   3%|▉                               | 486/16687 [00:00<00:27, 580.37batch/s][A
Training batches on cuda:0:   3%|█                               | 545/16687 [00:01<00:27, 583.08batch/s][A
Training batches on cuda:0:   4%|█▏                              | 604/16687 [00:01<00:27, 584.77batch/s][A
Training batches on cuda:0:   4%|█▎                              | 663/16687 [00:01<00:27, 585.82batch/s][A
Training batches on cuda:0:   4%|█▍                              | 722/16687 [00:01<00:27, 586.74batch/s][A
Training batches on cuda:0:   5%|█▍                              | 781/16687 [00:01<00:27, 587.24batch/s][A
Training batches on cuda:0:   5%|█▌                              | 840/16687 [00:01<00:26, 587.97batch/s][A
Training batches on cuda:0:   5%|█▋                              | 899/16687 [00:01<00:26, 588.22batch/s][A
Training batches on cuda:0:   6%|█▊                              | 958/16687 [00:01<00:26, 588.32batch/s][A
Training batches on

Training batches on cuda:0:  30%|█████████▏                     | 4965/16687 [00:08<00:19, 591.41batch/s][A
Training batches on cuda:0:  30%|█████████▎                     | 5025/16687 [00:08<00:19, 591.38batch/s][A
Training batches on cuda:0:  30%|█████████▍                     | 5085/16687 [00:08<00:19, 591.21batch/s][A
Training batches on cuda:0:  31%|█████████▌                     | 5145/16687 [00:08<00:20, 575.26batch/s][A
Training batches on cuda:0:  31%|█████████▋                     | 5204/16687 [00:08<00:19, 578.91batch/s][A
Training batches on cuda:0:  32%|█████████▊                     | 5263/16687 [00:09<00:19, 581.86batch/s][A
Training batches on cuda:0:  32%|█████████▉                     | 5322/16687 [00:09<00:19, 583.85batch/s][A
Training batches on cuda:0:  32%|█████████▉                     | 5381/16687 [00:09<00:19, 585.24batch/s][A
Training batches on cuda:0:  33%|██████████                     | 5440/16687 [00:09<00:19, 586.36batch/s][A
Training batches on

Training batches on cuda:0:  57%|█████████████████▌             | 9446/16687 [00:16<00:12, 596.91batch/s][A
Training batches on cuda:0:  57%|█████████████████▋             | 9506/16687 [00:16<00:12, 595.68batch/s][A
Training batches on cuda:0:  57%|█████████████████▊             | 9566/16687 [00:16<00:11, 594.36batch/s][A
Training batches on cuda:0:  58%|█████████████████▉             | 9626/16687 [00:16<00:11, 593.40batch/s][A
Training batches on cuda:0:  58%|█████████████████▉             | 9686/16687 [00:16<00:11, 593.00batch/s][A
Training batches on cuda:0:  58%|██████████████████             | 9746/16687 [00:16<00:11, 592.76batch/s][A
Training batches on cuda:0:  59%|██████████████████▏            | 9806/16687 [00:16<00:11, 592.41batch/s][A
Training batches on cuda:0:  59%|██████████████████▎            | 9866/16687 [00:16<00:11, 586.23batch/s][A
Training batches on cuda:0:  59%|██████████████████▍            | 9926/16687 [00:16<00:11, 590.25batch/s][A
Training batches on

Training batches on cuda:0:  83%|█████████████████████████     | 13907/16687 [00:23<00:04, 590.40batch/s][A
Training batches on cuda:0:  84%|█████████████████████████     | 13967/16687 [00:23<00:04, 589.82batch/s][A
Training batches on cuda:0:  84%|█████████████████████████▏    | 14026/16687 [00:23<00:04, 586.34batch/s][A
Training batches on cuda:0:  84%|█████████████████████████▎    | 14085/16687 [00:24<00:04, 586.16batch/s][A
Training batches on cuda:0:  85%|█████████████████████████▍    | 14144/16687 [00:24<00:04, 587.30batch/s][A
Training batches on cuda:0:  85%|█████████████████████████▌    | 14206/16687 [00:24<00:04, 594.10batch/s][A
Training batches on cuda:0:  85%|█████████████████████████▋    | 14266/16687 [00:24<00:04, 593.76batch/s][A
Training batches on cuda:0:  86%|█████████████████████████▊    | 14326/16687 [00:24<00:04, 588.38batch/s][A
Training batches on cuda:0:  86%|█████████████████████████▊    | 14386/16687 [00:24<00:03, 589.83batch/s][A
Training batches on

Training batches on cuda:0:   9%|██▉                            | 1578/16687 [00:02<00:25, 599.32batch/s][A
Training batches on cuda:0:  10%|███                            | 1638/16687 [00:02<00:25, 597.10batch/s][A
Training batches on cuda:0:  10%|███▏                           | 1698/16687 [00:02<00:25, 595.77batch/s][A
Training batches on cuda:0:  11%|███▎                           | 1758/16687 [00:03<00:25, 594.61batch/s][A
Training batches on cuda:0:  11%|███▍                           | 1818/16687 [00:03<00:25, 593.69batch/s][A
Training batches on cuda:0:  11%|███▍                           | 1878/16687 [00:03<00:24, 593.03batch/s][A
Training batches on cuda:0:  12%|███▌                           | 1938/16687 [00:03<00:24, 594.70batch/s][A
Training batches on cuda:0:  12%|███▋                           | 2000/16687 [00:03<00:24, 599.98batch/s][A
Training batches on cuda:0:  12%|███▊                           | 2062/16687 [00:03<00:24, 603.17batch/s][A
Training batches on

Training batches on cuda:0:  36%|███████████▎                   | 6060/16687 [00:10<00:19, 557.01batch/s][A
Training batches on cuda:0:  37%|███████████▎                   | 6118/16687 [00:10<00:18, 561.96batch/s][A
Training batches on cuda:0:  37%|███████████▍                   | 6175/16687 [00:10<00:18, 559.38batch/s][A
Training batches on cuda:0:  37%|███████████▌                   | 6232/16687 [00:10<00:18, 560.41batch/s][A
Training batches on cuda:0:  38%|███████████▋                   | 6290/16687 [00:10<00:18, 565.54batch/s][A
Training batches on cuda:0:  38%|███████████▊                   | 6348/16687 [00:10<00:18, 567.80batch/s][A
Training batches on cuda:0:  38%|███████████▉                   | 6406/16687 [00:10<00:17, 571.35batch/s][A
Training batches on cuda:0:  39%|████████████                   | 6465/16687 [00:11<00:17, 574.10batch/s][A
Training batches on cuda:0:  39%|████████████                   | 6524/16687 [00:11<00:17, 576.36batch/s][A
Training batches on

Training batches on cuda:0:  63%|██████████████████▉           | 10500/16687 [00:17<00:10, 580.94batch/s][A
Training batches on cuda:0:  63%|██████████████████▉           | 10561/16687 [00:18<00:10, 587.07batch/s][A
Training batches on cuda:0:  64%|███████████████████           | 10621/16687 [00:18<00:10, 588.91batch/s][A
Training batches on cuda:0:  64%|███████████████████▏          | 10681/16687 [00:18<00:10, 591.61batch/s][A
Training batches on cuda:0:  64%|███████████████████▎          | 10741/16687 [00:18<00:10, 575.34batch/s][A
Training batches on cuda:0:  65%|███████████████████▍          | 10799/16687 [00:18<00:10, 568.19batch/s][A
Training batches on cuda:0:  65%|███████████████████▌          | 10860/16687 [00:18<00:10, 578.55batch/s][A
Training batches on cuda:0:  65%|███████████████████▋          | 10918/16687 [00:18<00:09, 577.60batch/s][A
Training batches on cuda:0:  66%|███████████████████▋          | 10976/16687 [00:18<00:09, 574.57batch/s][A
Training batches on

Training batches on cuda:0:  90%|██████████████████████████▊   | 14936/16687 [00:25<00:02, 588.32batch/s][A
Training batches on cuda:0:  90%|██████████████████████████▉   | 14995/16687 [00:25<00:02, 587.96batch/s][A
Training batches on cuda:0:  90%|███████████████████████████   | 15054/16687 [00:25<00:02, 574.91batch/s][A
Training batches on cuda:0:  91%|███████████████████████████▏  | 15112/16687 [00:25<00:02, 563.19batch/s][A
Training batches on cuda:0:  91%|███████████████████████████▎  | 15171/16687 [00:26<00:02, 568.33batch/s][A
Training batches on cuda:0:  91%|███████████████████████████▍  | 15233/16687 [00:26<00:02, 581.38batch/s][A
Training batches on cuda:0:  92%|███████████████████████████▍  | 15294/16687 [00:26<00:02, 588.74batch/s][A
Training batches on cuda:0:  92%|███████████████████████████▌  | 15353/16687 [00:26<00:02, 587.08batch/s][A
Training batches on cuda:0:  92%|███████████████████████████▋  | 15412/16687 [00:26<00:02, 585.90batch/s][A
Training batches on

Training batches on cuda:0:  15%|████▋                          | 2548/16687 [00:04<00:24, 582.58batch/s][A
Training batches on cuda:0:  16%|████▊                          | 2607/16687 [00:04<00:24, 582.82batch/s][A
Training batches on cuda:0:  16%|████▉                          | 2666/16687 [00:04<00:24, 582.70batch/s][A
Training batches on cuda:0:  16%|█████                          | 2725/16687 [00:04<00:24, 575.01batch/s][A
Training batches on cuda:0:  17%|█████▏                         | 2784/16687 [00:04<00:24, 578.19batch/s][A
Training batches on cuda:0:  17%|█████▎                         | 2842/16687 [00:04<00:24, 575.64batch/s][A
Training batches on cuda:0:  17%|█████▍                         | 2904/16687 [00:05<00:23, 586.73batch/s][A
Training batches on cuda:0:  18%|█████▌                         | 2966/16687 [00:05<00:23, 593.59batch/s][A
Training batches on cuda:0:  18%|█████▌                         | 3026/16687 [00:05<00:23, 592.54batch/s][A
Training batches on

Training batches on cuda:0:  42%|████████████▉                  | 6959/16687 [00:12<00:16, 588.27batch/s][A
Training batches on cuda:0:  42%|█████████████                  | 7019/16687 [00:12<00:16, 589.20batch/s][A
Training batches on cuda:0:  42%|█████████████▏                 | 7078/16687 [00:12<00:16, 580.85batch/s][A
Training batches on cuda:0:  43%|█████████████▎                 | 7137/16687 [00:12<00:16, 583.52batch/s][A
Training batches on cuda:0:  43%|█████████████▎                 | 7197/16687 [00:12<00:16, 585.99batch/s][A
Training batches on cuda:0:  43%|█████████████▍                 | 7256/16687 [00:12<00:16, 585.69batch/s][A
Training batches on cuda:0:  44%|█████████████▌                 | 7315/16687 [00:12<00:16, 574.06batch/s][A
Training batches on cuda:0:  44%|█████████████▋                 | 7373/16687 [00:12<00:16, 568.31batch/s][A
Training batches on cuda:0:  45%|█████████████▊                 | 7432/16687 [00:12<00:16, 574.42batch/s][A
Training batches on

Training batches on cuda:0:  68%|████████████████████▌         | 11408/16687 [00:19<00:09, 568.06batch/s][A
Training batches on cuda:0:  69%|████████████████████▌         | 11467/16687 [00:19<00:09, 572.45batch/s][A
Training batches on cuda:0:  69%|████████████████████▋         | 11526/16687 [00:19<00:08, 575.51batch/s][A
Training batches on cuda:0:  69%|████████████████████▊         | 11585/16687 [00:20<00:08, 577.62batch/s][A
Training batches on cuda:0:  70%|████████████████████▉         | 11644/16687 [00:20<00:08, 579.00batch/s][A
Training batches on cuda:0:  70%|█████████████████████         | 11703/16687 [00:20<00:08, 581.62batch/s][A
Training batches on cuda:0:  70%|█████████████████████▏        | 11762/16687 [00:20<00:08, 582.71batch/s][A
Training batches on cuda:0:  71%|█████████████████████▎        | 11821/16687 [00:20<00:08, 583.17batch/s][A
Training batches on cuda:0:  71%|█████████████████████▎        | 11880/16687 [00:20<00:08, 578.85batch/s][A
Training batches on

Training batches on cuda:0:  95%|████████████████████████████▍ | 15844/16687 [00:27<00:01, 585.24batch/s][A
Training batches on cuda:0:  95%|████████████████████████████▌ | 15903/16687 [00:27<00:01, 585.74batch/s][A
Training batches on cuda:0:  96%|████████████████████████████▋ | 15962/16687 [00:27<00:01, 586.29batch/s][A
Training batches on cuda:0:  96%|████████████████████████████▊ | 16021/16687 [00:27<00:01, 585.90batch/s][A
Training batches on cuda:0:  96%|████████████████████████████▉ | 16080/16687 [00:27<00:01, 580.87batch/s][A
Training batches on cuda:0:  97%|█████████████████████████████ | 16139/16687 [00:27<00:00, 581.00batch/s][A
Training batches on cuda:0:  97%|█████████████████████████████ | 16198/16687 [00:27<00:00, 581.05batch/s][A
Training batches on cuda:0:  97%|█████████████████████████████▏| 16257/16687 [00:28<00:00, 581.09batch/s][A
Training batches on cuda:0:  98%|█████████████████████████████▎| 16316/16687 [00:28<00:00, 581.28batch/s][A
Training batches on

Training batches on cuda:0:  20%|██████▎                        | 3410/16687 [00:06<00:23, 573.25batch/s][A
Training batches on cuda:0:  21%|██████▍                        | 3469/16687 [00:06<00:22, 575.88batch/s][A
Training batches on cuda:0:  21%|██████▌                        | 3528/16687 [00:06<00:22, 577.78batch/s][A
Training batches on cuda:0:  21%|██████▋                        | 3587/16687 [00:06<00:22, 579.24batch/s][A
Training batches on cuda:0:  22%|██████▊                        | 3646/16687 [00:06<00:22, 580.13batch/s][A
Training batches on cuda:0:  22%|██████▉                        | 3705/16687 [00:06<00:22, 580.58batch/s][A
Training batches on cuda:0:  23%|██████▉                        | 3764/16687 [00:06<00:22, 581.01batch/s][A
Training batches on cuda:0:  23%|███████                        | 3823/16687 [00:06<00:22, 581.31batch/s][A
Training batches on cuda:0:  23%|███████▏                       | 3882/16687 [00:06<00:22, 581.49batch/s][A
Training batches on

Training batches on cuda:0:  47%|██████████████▌                | 7824/16687 [00:13<00:14, 595.94batch/s][A
Training batches on cuda:0:  47%|██████████████▋                | 7885/16687 [00:13<00:14, 599.49batch/s][A
Training batches on cuda:0:  48%|██████████████▊                | 7945/16687 [00:13<00:14, 597.41batch/s][A
Training batches on cuda:0:  48%|██████████████▊                | 8005/16687 [00:13<00:14, 578.81batch/s][A
Training batches on cuda:0:  48%|██████████████▉                | 8064/16687 [00:14<00:14, 578.39batch/s][A
Training batches on cuda:0:  49%|███████████████                | 8122/16687 [00:14<00:14, 577.58batch/s][A
Training batches on cuda:0:  49%|███████████████▏               | 8180/16687 [00:14<00:15, 564.14batch/s][A
Training batches on cuda:0:  49%|███████████████▎               | 8238/16687 [00:14<00:14, 566.96batch/s][A
Training batches on cuda:0:  50%|███████████████▍               | 8296/16687 [00:14<00:14, 570.38batch/s][A
Training batches on

Training batches on cuda:0:  73%|█████████████████████▉        | 12236/16687 [00:21<00:07, 580.21batch/s][A
Training batches on cuda:0:  74%|██████████████████████        | 12295/16687 [00:21<00:07, 582.50batch/s][A
Training batches on cuda:0:  74%|██████████████████████▏       | 12354/16687 [00:21<00:07, 583.42batch/s][A
Training batches on cuda:0:  74%|██████████████████████▎       | 12413/16687 [00:21<00:07, 584.73batch/s][A
Training batches on cuda:0:  75%|██████████████████████▍       | 12472/16687 [00:21<00:07, 577.06batch/s][A
Training batches on cuda:0:  75%|██████████████████████▌       | 12531/16687 [00:21<00:07, 579.04batch/s][A
Training batches on cuda:0:  75%|██████████████████████▋       | 12590/16687 [00:21<00:07, 579.37batch/s][A
Training batches on cuda:0:  76%|██████████████████████▋       | 12649/16687 [00:22<00:06, 581.45batch/s][A
Training batches on cuda:0:  76%|██████████████████████▊       | 12708/16687 [00:22<00:06, 575.39batch/s][A
Training batches on

Training batches on cuda:0: 100%|█████████████████████████████▉| 16673/16687 [00:28<00:00, 581.90batch/s][A
Training epochs on cuda:0:  20%|██▊           | 4/20 [02:24<07:42, 28.89s/epoch, loss=0.7, prev_loss=0.7][AINFO:pykeen.evaluation.evaluator:Evaluation took 44.90s seconds
INFO:pykeen.stoppers.early_stopping:New best result at epoch 5: 0.00323996733813273. Saved model weights to /work/.data/pykeen/checkpoints/best-model-weights-9869f37a-379e-45f6-bc48-57d0b6cd18af.pt
INFO:pykeen.training.training_loop:=> Saved checkpoint after having finished epoch 5.
Training epochs on cuda:0:  25%|███▌          | 5/20 [03:09<11:18, 45.20s/epoch, loss=0.7, prev_loss=0.7]
Training batches on cuda:0:   0%|                                           | 0/16687 [00:00<?, ?batch/s][A
Training batches on cuda:0:   0%|                                 | 13/16687 [00:00<02:10, 128.00batch/s][A
Training batches on cuda:0:   0%|▏                                | 69/16687 [00:00<00:43, 380.22batch/s][A
Tr

Training batches on cuda:0:  24%|███████▎                       | 3956/16687 [00:07<00:22, 571.28batch/s][A
Training batches on cuda:0:  24%|███████▍                       | 4014/16687 [00:07<00:22, 570.20batch/s][A
Training batches on cuda:0:  24%|███████▌                       | 4073/16687 [00:07<00:21, 574.31batch/s][A
Training batches on cuda:0:  25%|███████▋                       | 4131/16687 [00:07<00:21, 575.35batch/s][A
Training batches on cuda:0:  25%|███████▊                       | 4191/16687 [00:07<00:21, 579.89batch/s][A
Training batches on cuda:0:  25%|███████▉                       | 4252/16687 [00:07<00:21, 587.04batch/s][A
Training batches on cuda:0:  26%|████████                       | 4311/16687 [00:07<00:21, 587.53batch/s][A
Training batches on cuda:0:  26%|████████                       | 4370/16687 [00:07<00:20, 586.72batch/s][A
Training batches on cuda:0:  27%|████████▏                      | 4429/16687 [00:07<00:20, 585.86batch/s][A
Training batches on

Training batches on cuda:0:  50%|███████████████▌               | 8368/16687 [00:14<00:14, 583.27batch/s][A
Training batches on cuda:0:  51%|███████████████▋               | 8427/16687 [00:14<00:14, 585.14batch/s][A
Training batches on cuda:0:  51%|███████████████▊               | 8486/16687 [00:14<00:13, 586.24batch/s][A
Training batches on cuda:0:  51%|███████████████▊               | 8545/16687 [00:14<00:13, 587.30batch/s][A
Training batches on cuda:0:  52%|███████████████▉               | 8604/16687 [00:15<00:13, 587.55batch/s][A
Training batches on cuda:0:  52%|████████████████               | 8663/16687 [00:15<00:14, 566.24batch/s][A
Training batches on cuda:0:  52%|████████████████▏              | 8721/16687 [00:15<00:14, 568.00batch/s][A
Training batches on cuda:0:  53%|████████████████▎              | 8778/16687 [00:15<00:13, 567.01batch/s][A
Training batches on cuda:0:  53%|████████████████▍              | 8838/16687 [00:15<00:13, 575.57batch/s][A
Training batches on

Training batches on cuda:0:  77%|██████████████████████▉       | 12774/16687 [00:22<00:06, 572.68batch/s][A
Training batches on cuda:0:  77%|███████████████████████       | 12832/16687 [00:22<00:06, 565.46batch/s][A
Training batches on cuda:0:  77%|███████████████████████▏      | 12891/16687 [00:22<00:06, 572.17batch/s][A
Training batches on cuda:0:  78%|███████████████████████▎      | 12950/16687 [00:22<00:06, 577.20batch/s][A
Training batches on cuda:0:  78%|███████████████████████▍      | 13009/16687 [00:22<00:06, 580.76batch/s][A
Training batches on cuda:0:  78%|███████████████████████▍      | 13068/16687 [00:22<00:06, 583.18batch/s][A
Training batches on cuda:0:  79%|███████████████████████▌      | 13127/16687 [00:22<00:06, 584.96batch/s][A
Training batches on cuda:0:  79%|███████████████████████▋      | 13186/16687 [00:23<00:05, 586.26batch/s][A
Training batches on cuda:0:  79%|███████████████████████▊      | 13245/16687 [00:23<00:05, 587.18batch/s][A
Training batches on

Training batches on cuda:0:   2%|▋                               | 363/16687 [00:00<00:29, 561.03batch/s][A
Training batches on cuda:0:   3%|▊                               | 421/16687 [00:00<00:28, 566.99batch/s][A
Training batches on cuda:0:   3%|▉                               | 479/16687 [00:00<00:28, 570.75batch/s][A
Training batches on cuda:0:   3%|█                               | 537/16687 [00:01<00:28, 573.53batch/s][A
Training batches on cuda:0:   4%|█▏                              | 595/16687 [00:01<00:27, 574.82batch/s][A
Training batches on cuda:0:   4%|█▎                              | 653/16687 [00:01<00:27, 575.73batch/s][A
Training batches on cuda:0:   4%|█▎                              | 711/16687 [00:01<00:27, 574.98batch/s][A
Training batches on cuda:0:   5%|█▍                              | 769/16687 [00:01<00:27, 575.74batch/s][A
Training batches on cuda:0:   5%|█▌                              | 827/16687 [00:01<00:27, 576.81batch/s][A
Training batches on

Training batches on cuda:0:  28%|████████▊                      | 4730/16687 [00:08<00:20, 581.12batch/s][A
Training batches on cuda:0:  29%|████████▉                      | 4789/16687 [00:08<00:20, 581.44batch/s][A
Training batches on cuda:0:  29%|█████████                      | 4848/16687 [00:08<00:20, 581.17batch/s][A
Training batches on cuda:0:  29%|█████████                      | 4907/16687 [00:08<00:20, 566.01batch/s][A
Training batches on cuda:0:  30%|█████████▏                     | 4964/16687 [00:08<00:21, 550.27batch/s][A
Training batches on cuda:0:  30%|█████████▎                     | 5023/16687 [00:08<00:20, 559.32batch/s][A
Training batches on cuda:0:  30%|█████████▍                     | 5081/16687 [00:08<00:20, 563.21batch/s][A
Training batches on cuda:0:  31%|█████████▌                     | 5138/16687 [00:09<00:20, 558.18batch/s][A
Training batches on cuda:0:  31%|█████████▋                     | 5196/16687 [00:09<00:20, 561.81batch/s][A
Training batches on

Training batches on cuda:0:  55%|████████████████▉              | 9104/16687 [00:16<00:13, 550.28batch/s][A
Training batches on cuda:0:  55%|█████████████████              | 9161/16687 [00:16<00:13, 554.61batch/s][A
Training batches on cuda:0:  55%|█████████████████▏             | 9222/16687 [00:16<00:13, 569.15batch/s][A
Training batches on cuda:0:  56%|█████████████████▏             | 9282/16687 [00:16<00:12, 577.17batch/s][A
Training batches on cuda:0:  56%|█████████████████▎             | 9340/16687 [00:16<00:12, 574.87batch/s][A
Training batches on cuda:0:  56%|█████████████████▍             | 9398/16687 [00:16<00:12, 562.38batch/s][A
Training batches on cuda:0:  57%|█████████████████▌             | 9456/16687 [00:16<00:12, 566.27batch/s][A
Training batches on cuda:0:  57%|█████████████████▋             | 9515/16687 [00:16<00:12, 571.31batch/s][A
Training batches on cuda:0:  57%|█████████████████▊             | 9574/16687 [00:16<00:12, 575.63batch/s][A
Training batches on

Training batches on cuda:0:  81%|████████████████████████▎     | 13499/16687 [00:23<00:05, 568.41batch/s][A
Training batches on cuda:0:  81%|████████████████████████▎     | 13557/16687 [00:23<00:05, 571.23batch/s][A
Training batches on cuda:0:  82%|████████████████████████▍     | 13615/16687 [00:23<00:05, 573.07batch/s][A
Training batches on cuda:0:  82%|████████████████████████▌     | 13673/16687 [00:23<00:05, 570.59batch/s][A
Training batches on cuda:0:  82%|████████████████████████▋     | 13732/16687 [00:24<00:05, 574.95batch/s][A
Training batches on cuda:0:  83%|████████████████████████▊     | 13790/16687 [00:24<00:05, 561.96batch/s][A
Training batches on cuda:0:  83%|████████████████████████▉     | 13851/16687 [00:24<00:04, 574.26batch/s][A
Training batches on cuda:0:  83%|█████████████████████████     | 13909/16687 [00:24<00:04, 573.28batch/s][A
Training batches on cuda:0:  84%|█████████████████████████     | 13968/16687 [00:24<00:04, 576.46batch/s][A
Training batches on

Training batches on cuda:0:   6%|█▉                             | 1059/16687 [00:01<00:27, 576.69batch/s][A
Training batches on cuda:0:   7%|██                             | 1117/16687 [00:02<00:27, 576.45batch/s][A
Training batches on cuda:0:   7%|██▏                            | 1176/16687 [00:02<00:26, 578.54batch/s][A
Training batches on cuda:0:   7%|██▎                            | 1234/16687 [00:02<00:26, 574.30batch/s][A
Training batches on cuda:0:   8%|██▍                            | 1294/16687 [00:02<00:26, 581.18batch/s][A
Training batches on cuda:0:   8%|██▌                            | 1353/16687 [00:02<00:26, 581.73batch/s][A
Training batches on cuda:0:   8%|██▌                            | 1412/16687 [00:02<00:27, 561.44batch/s][A
Training batches on cuda:0:   9%|██▋                            | 1469/16687 [00:02<00:27, 551.60batch/s][A
Training batches on cuda:0:   9%|██▊                            | 1525/16687 [00:02<00:27, 545.92batch/s][A
Training batches on

Training batches on cuda:0:  33%|██████████                     | 5441/16687 [00:09<00:19, 572.33batch/s][A
Training batches on cuda:0:  33%|██████████▏                    | 5499/16687 [00:09<00:19, 574.04batch/s][A
Training batches on cuda:0:  33%|██████████▎                    | 5557/16687 [00:09<00:19, 562.08batch/s][A
Training batches on cuda:0:  34%|██████████▍                    | 5615/16687 [00:09<00:19, 565.21batch/s][A
Training batches on cuda:0:  34%|██████████▌                    | 5674/16687 [00:10<00:19, 569.63batch/s][A
Training batches on cuda:0:  34%|██████████▋                    | 5732/16687 [00:10<00:19, 572.66batch/s][A
Training batches on cuda:0:  35%|██████████▊                    | 5790/16687 [00:10<00:18, 574.65batch/s][A
Training batches on cuda:0:  35%|██████████▊                    | 5848/16687 [00:10<00:18, 576.18batch/s][A
Training batches on cuda:0:  35%|██████████▉                    | 5907/16687 [00:10<00:18, 580.11batch/s][A
Training batches on

Training batches on cuda:0:  59%|██████████████████▎            | 9886/16687 [00:17<00:11, 578.23batch/s][A
Training batches on cuda:0:  60%|██████████████████▍            | 9944/16687 [00:17<00:11, 578.64batch/s][A
Training batches on cuda:0:  60%|█████████████████▉            | 10002/16687 [00:17<00:11, 578.60batch/s][A
Training batches on cuda:0:  60%|██████████████████            | 10060/16687 [00:17<00:11, 578.78batch/s][A
Training batches on cuda:0:  61%|██████████████████▏           | 10119/16687 [00:17<00:11, 579.18batch/s][A
Training batches on cuda:0:  61%|██████████████████▎           | 10178/16687 [00:17<00:11, 579.71batch/s][A
Training batches on cuda:0:  61%|██████████████████▍           | 10236/16687 [00:17<00:11, 571.91batch/s][A
Training batches on cuda:0:  62%|██████████████████▌           | 10294/16687 [00:17<00:11, 561.99batch/s][A
Training batches on cuda:0:  62%|██████████████████▌           | 10353/16687 [00:18<00:11, 567.92batch/s][A
Training batches on

Training batches on cuda:0:  85%|█████████████████████████▋    | 14260/16687 [00:24<00:04, 592.42batch/s][A
Training batches on cuda:0:  86%|█████████████████████████▋    | 14320/16687 [00:25<00:04, 588.38batch/s][A
Training batches on cuda:0:  86%|█████████████████████████▊    | 14379/16687 [00:25<00:03, 585.95batch/s][A
Training batches on cuda:0:  87%|█████████████████████████▉    | 14438/16687 [00:25<00:03, 584.12batch/s][A
Training batches on cuda:0:  87%|██████████████████████████    | 14497/16687 [00:25<00:03, 582.98batch/s][A
Training batches on cuda:0:  87%|██████████████████████████▏   | 14556/16687 [00:25<00:03, 582.34batch/s][A
Training batches on cuda:0:  88%|██████████████████████████▎   | 14615/16687 [00:25<00:03, 581.86batch/s][A
Training batches on cuda:0:  88%|██████████████████████████▍   | 14674/16687 [00:25<00:03, 579.46batch/s][A
Training batches on cuda:0:  88%|██████████████████████████▍   | 14732/16687 [00:25<00:03, 578.35batch/s][A
Training batches on

Training batches on cuda:0:  11%|███▍                           | 1824/16687 [00:03<00:25, 580.58batch/s][A
Training batches on cuda:0:  11%|███▍                           | 1883/16687 [00:03<00:25, 580.50batch/s][A
Training batches on cuda:0:  12%|███▌                           | 1942/16687 [00:03<00:25, 580.29batch/s][A
Training batches on cuda:0:  12%|███▋                           | 2001/16687 [00:03<00:25, 580.82batch/s][A
Training batches on cuda:0:  12%|███▊                           | 2060/16687 [00:03<00:25, 580.92batch/s][A
Training batches on cuda:0:  13%|███▉                           | 2119/16687 [00:03<00:25, 580.15batch/s][A
Training batches on cuda:0:  13%|████                           | 2178/16687 [00:03<00:25, 580.31batch/s][A
Training batches on cuda:0:  13%|████▏                          | 2237/16687 [00:03<00:24, 580.65batch/s][A
Training batches on cuda:0:  14%|████▎                          | 2296/16687 [00:04<00:24, 580.94batch/s][A
Training batches on

Training batches on cuda:0:  37%|███████████▌                   | 6232/16687 [00:10<00:17, 584.72batch/s][A
Training batches on cuda:0:  38%|███████████▋                   | 6291/16687 [00:11<00:17, 583.64batch/s][A
Training batches on cuda:0:  38%|███████████▊                   | 6351/16687 [00:11<00:17, 587.24batch/s][A
Training batches on cuda:0:  38%|███████████▉                   | 6410/16687 [00:11<00:17, 583.63batch/s][A
Training batches on cuda:0:  39%|████████████                   | 6469/16687 [00:11<00:17, 579.89batch/s][A
Training batches on cuda:0:  39%|████████████▏                  | 6528/16687 [00:11<00:17, 582.47batch/s][A
Training batches on cuda:0:  39%|████████████▏                  | 6587/16687 [00:11<00:17, 573.48batch/s][A
Training batches on cuda:0:  40%|████████████▎                  | 6645/16687 [00:11<00:17, 563.72batch/s][A
Training batches on cuda:0:  40%|████████████▍                  | 6702/16687 [00:11<00:17, 561.45batch/s][A
Training batches on

Training batches on cuda:0:  64%|███████████████████           | 10617/16687 [00:18<00:10, 576.75batch/s][A
Training batches on cuda:0:  64%|███████████████████▏          | 10675/16687 [00:18<00:10, 573.56batch/s][A
Training batches on cuda:0:  64%|███████████████████▎          | 10733/16687 [00:18<00:10, 557.19batch/s][A
Training batches on cuda:0:  65%|███████████████████▍          | 10792/16687 [00:18<00:10, 565.24batch/s][A
Training batches on cuda:0:  65%|███████████████████▌          | 10849/16687 [00:19<00:10, 555.68batch/s][A
Training batches on cuda:0:  65%|███████████████████▌          | 10907/16687 [00:19<00:10, 560.21batch/s][A
Training batches on cuda:0:  66%|███████████████████▋          | 10964/16687 [00:19<00:10, 557.30batch/s][A
Training batches on cuda:0:  66%|███████████████████▊          | 11023/16687 [00:19<00:10, 565.99batch/s][A
Training batches on cuda:0:  66%|███████████████████▉          | 11081/16687 [00:19<00:09, 570.06batch/s][A
Training batches on

KeyboardInterrupt: 

In [None]:
hpo_pipeline_result.save_to_directory('hpo_results')

### Example from OGB

In [22]:
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super(GCN, self).__init__()

        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels, cached=True))
        for _ in range(num_layers - 2):
            self.convs.append(
                GCNConv(hidden_channels, hidden_channels, cached=True))
        self.convs.append(GCNConv(hidden_channels, out_channels, cached=True))

        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, adj_t):
        for conv in self.convs[:-1]:
            x = conv(x, adj_t)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, adj_t)
        return x

In [23]:
class LinkPredictor(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super(LinkPredictor, self).__init__()

        self.lins = torch.nn.ModuleList()
        self.lins.append(torch.nn.Linear(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
        self.lins.append(torch.nn.Linear(hidden_channels, out_channels))

        self.dropout = dropout

    def reset_parameters(self):
        for lin in self.lins:
            lin.reset_parameters()

    def forward(self, x_i, x_j):
        x = x_i * x_j
        for lin in self.lins[:-1]:
            x = lin(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)
        return torch.sigmoid(x)

In [24]:
def train(model, predictor, x, adj_t, split_edge, optimizer, batch_size):

    row, col, _ = adj_t.coo()
    edge_index = torch.stack([col, row], dim=0)

    model.train()
    predictor.train()

    pos_train_edge = split_edge['train']['edge'].to(x.device)

    total_loss = total_examples = 0
    for perm in DataLoader(range(pos_train_edge.size(0)), batch_size,
                           shuffle=True):
        optimizer.zero_grad()

        h = model(x, adj_t)
#         print('h:', h)

        edge = pos_train_edge[perm].t()
#         print('Train: edge:', edge)
#         print()
#         print('h[edge[0]]:', h[edge[0]])
#         print()
#         print('h[edge[1]]:', h[edge[1]])
#         print()

        pos_out = predictor(h[edge[0]], h[edge[1]])
        pos_loss = -torch.log(pos_out + 1e-15).mean()
#         print('pos out:', pos_out)

        edge = negative_sampling(edge_index, num_nodes=x.size(0),
                                 num_neg_samples=perm.size(0), method='dense')

        neg_out = predictor(h[edge[0]], h[edge[1]])
        neg_loss = -torch.log(1 - neg_out + 1e-15).mean()

        loss = pos_loss + neg_loss
        loss.backward()

        torch.nn.utils.clip_grad_norm_(x, 1.0)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        torch.nn.utils.clip_grad_norm_(predictor.parameters(), 1.0)

        optimizer.step()

        num_examples = pos_out.size(0)
        total_loss += loss.item() * num_examples
        total_examples += num_examples

    return total_loss / total_examples

In [29]:
@torch.no_grad()
def test(model, predictor, x, adj_t, split_edge, evaluator, batch_size):
#     print('test')
    
    model.eval()
    predictor.eval()

    h = model(x, adj_t)

    pos_train_edge = split_edge['eval_train']['edge'].to(x.device)
    pos_valid_edge = split_edge['valid']['edge'].to(x.device)
    neg_valid_edge = split_edge['valid']['edge_neg'].to(x.device)
    pos_test_edge = split_edge['test']['edge'].to(x.device)
    neg_test_edge = split_edge['test']['edge_neg'].to(x.device)

    pos_train_preds = []
    for perm in DataLoader(range(pos_train_edge.size(0)), batch_size):
        edge = pos_train_edge[perm].t()
        pos_train_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_train_pred = torch.cat(pos_train_preds, dim=0)

    pos_valid_preds = []
    for perm in DataLoader(range(pos_valid_edge.size(0)), batch_size):
        edge = pos_valid_edge[perm].t()
        pos_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_valid_pred = torch.cat(pos_valid_preds, dim=0)

    neg_valid_preds = []
    for perm in DataLoader(range(neg_valid_edge.size(0)), batch_size):
        edge = neg_valid_edge[perm].t()
        neg_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    neg_valid_pred = torch.cat(neg_valid_preds, dim=0)

    pos_test_preds = []
    for perm in DataLoader(range(pos_test_edge.size(0)), batch_size):
        edge = pos_test_edge[perm].t()
        pos_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_test_pred = torch.cat(pos_test_preds, dim=0)

    neg_test_preds = []
    for perm in DataLoader(range(neg_test_edge.size(0)), batch_size):
        edge = neg_test_edge[perm].t()
        neg_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    neg_test_pred = torch.cat(neg_test_preds, dim=0)
    
#     print('pos_train_pred:', pos_train_pred)
#     print('neg_train_pred:', neg_valid_pred)
#     print()

    results = {}
    for K in [10, 20, 30]:
        evaluator.K = K
        train_hits = evaluator.eval({
            'y_pred_pos': pos_train_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        valid_hits = evaluator.eval({
            'y_pred_pos': pos_valid_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        test_hits = evaluator.eval({
            'y_pred_pos': pos_test_pred,
            'y_pred_neg': neg_test_pred,
        })[f'hits@{K}']

        results[f'Hits@{K}'] = (train_hits, valid_hits, test_hits)

    return results

In [30]:
hidden_channels = 256
num_layers = 2
dropout = 0.5
runs = 4
lr = 0.005
batch_size = 64 * 1024
epochs = 5
log_steps = 1
eval_steps = 1

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device)

dataset = PygLinkPropPredDataset(name='ogbl-ddi', transform=T.ToSparseTensor())
data = dataset[0]
adj_t = data.adj_t.to(device)

split_edge = dataset.get_edge_split()

# We randomly pick some training samples that we want to evaluate on:
torch.manual_seed(12345)
idx = torch.randperm(split_edge['train']['edge'].size(0))
idx = idx[:split_edge['valid']['edge'].size(0)]
split_edge['eval_train'] = {'edge': split_edge['train']['edge'][idx]}


model = GCN(hidden_channels, hidden_channels,
                hidden_channels, num_layers,
                dropout).to(device)

emb = torch.nn.Embedding(data.adj_t.size(0),
                         hidden_channels).to(device)

print('Embedding:', emb)
print()
predictor = LinkPredictor(hidden_channels, hidden_channels, 1,
                          num_layers, dropout).to(device)

evaluator = Evaluator(name='ogbl-ddi')
# loggers = {
#     'Hits@10': Logger(args.runs, args),
#     'Hits@20': Logger(args.runs, args),
#     'Hits@30': Logger(args.runs, args),
# }

for run in range(runs):
    torch.nn.init.xavier_uniform_(emb.weight)
#     print('Weights:', emb.weight)
#     print()
    model.reset_parameters()
    predictor.reset_parameters()
    optimizer = torch.optim.Adam(
        list(model.parameters()) + list(emb.parameters()) +
        list(predictor.parameters()), lr=lr)

    for epoch in range(1, 1 + epochs):
        loss = train(model, predictor, emb.weight, adj_t, split_edge,
                     optimizer, batch_size)

        if epoch % eval_steps == 0:
#             print('Eval')
            results = test(model, predictor, emb.weight, adj_t, split_edge,
                           evaluator, batch_size)
#             for key, result in results.items():
#                 loggers[key].add_result(run, result)

            if epoch % log_steps == 0:
                for key, result in results.items():
                    train_hits, valid_hits, test_hits = result
                    print(key)
                    print(f'Run: {run + 1:02d}, '
                          f'Epoch: {epoch:02d}, '
                          f'Loss: {loss:.4f}, '
                          f'Train: {100 * train_hits:.2f}%, '
                          f'Valid: {100 * valid_hits:.2f}%, '
                          f'Test: {100 * test_hits:.2f}%')
                print('---')

#     for key in loggers.keys():
#         print(key)
#         loggers[key].print_statistics(run)

# for key in loggers.keys():
#     print(key)
#     loggers[key].print_statistics()

Embedding: Embedding(4267, 256)

Hits@10
Run: 01, Epoch: 01, Loss: 1.2921, Train: 0.03%, Valid: 0.02%, Test: 0.01%
Hits@20
Run: 01, Epoch: 01, Loss: 1.2921, Train: 3.96%, Valid: 3.64%, Test: 2.58%
Hits@30
Run: 01, Epoch: 01, Loss: 1.2921, Train: 4.55%, Valid: 4.17%, Test: 3.92%
---
Hits@10
Run: 01, Epoch: 02, Loss: 0.9923, Train: 3.10%, Valid: 2.87%, Test: 5.68%
Hits@20
Run: 01, Epoch: 02, Loss: 0.9923, Train: 4.74%, Valid: 4.37%, Test: 6.92%
Hits@30
Run: 01, Epoch: 02, Loss: 0.9923, Train: 5.45%, Valid: 5.01%, Test: 7.74%
---
Hits@10
Run: 01, Epoch: 03, Loss: 0.8506, Train: 0.26%, Valid: 0.24%, Test: 0.03%
Hits@20
Run: 01, Epoch: 03, Loss: 0.8506, Train: 0.56%, Valid: 0.51%, Test: 0.08%
Hits@30
Run: 01, Epoch: 03, Loss: 0.8506, Train: 0.75%, Valid: 0.70%, Test: 0.15%
---
Hits@10
Run: 01, Epoch: 04, Loss: 0.7528, Train: 2.34%, Valid: 2.15%, Test: 3.71%
Hits@20
Run: 01, Epoch: 04, Loss: 0.7528, Train: 3.19%, Valid: 2.99%, Test: 4.94%
Hits@30
Run: 01, Epoch: 04, Loss: 0.7528, Train: 4.13