In [2]:
from complex2.models.ComplEx2 import ComplEx2
from complex2.data.TriplesDataset import TriplesDataset 

import pandas as pd 
import numpy as np 
import torch 

from ogb.linkproppred import LinkPropPredDataset
from ogb.linkproppred import Evaluator

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
dataset = LinkPropPredDataset(name='ogbl-biokg', root='./data/ogbl-biokg/')
split_edge = dataset.get_edge_split()
# Each split is indexed by column name (e.g. 'relation', used further below).
train_triples = split_edge["train"]
valid_triples = split_edge["valid"]
test_triples = split_edge["test"]

  self.graph = torch.load(pre_processed_file_path, 'rb')
  train = torch.load(osp.join(path, 'train.pt'))
  valid = torch.load(osp.join(path, 'valid.pt'))
  test = torch.load(osp.join(path, 'test.pt'))


In [4]:
# Shuffled mini-batches of training triples for the optimisation loop below.
train_dataset = TriplesDataset(train_triples)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=1024,
    shuffle=True,
    num_workers=4,
)

In [5]:
# Full graph of the dataset; it is handed to the model constructor below.
# NOTE(review): LinkPropPredDataset (not the PyG variant) is loaded above, so
# this is presumably a plain dict (the evaluation code indexes
# model.data['num_nodes_dict'] by entity-type name) rather than a PyG
# heterogeneous data object — confirm.
data = dataset[0]

In [6]:
# Prefer the GPU when one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ComplEx2 model configured from the dataset graph, then moved to `device`.
model = ComplEx2(
    data=data,
    hidden_channels=64,
    dtype=torch.float32,
    scale_grad_by_freq=False,
    dropout=0,
)
model = model.to(device)

# AdamW over all model parameters.
optim = torch.optim.AdamW(model.parameters(), lr=1e-2, weight_decay=1e-2)

In [7]:
for epoch in range(1):
    model.train()
    n_batches = len(train_loader)
    epoch_loss = 0.0

    # The loader yields (head, tail, relation); forward() takes them by keyword.
    for batch_idx, (head, tail, relation) in enumerate(train_loader):
        optim.zero_grad()

        scores = model.forward(head=head.to(device),
                               relation=relation.to(device),
                               tail=tail.to(device))
        # The negated forward() output is used as the per-triple loss.
        nll = -scores.squeeze(-1)
        loss = nll.mean()

        loss.backward()
        optim.step()

        # Read the scalar once (each .item() is a device sync).
        batch_loss = loss.item()
        epoch_loss += batch_loss
        print(f'[batch: {batch_idx+1}/{n_batches}] Loss: {batch_loss:.2f}', end='\r')

    print(f"Epoch {epoch+1}, Loss: {epoch_loss / n_batches:.2f}")


[batch: 8/4652] Loss: 18.50

KeyboardInterrupt: 

In [8]:
# Official OGB evaluator for ogbl-biokg; show the input/output contract it expects.
evaluator = Evaluator(name='ogbl-biokg')
for fmt in (evaluator.expected_input_format, evaluator.expected_output_format):
    print(fmt)

==== Expected input format of Evaluator for ogbl-biokg
{'y_pred_pos': y_pred_pos, 'y_pred_neg': y_pred_neg}
- y_pred_pos: numpy ndarray or torch tensor of shape (num_edges, ). Torch tensor on GPU is recommended for efficiency.
- y_pred_neg: numpy ndarray or torch tensor of shape (num_edges, num_nodes_neg). Torch tensor on GPU is recommended for efficiency.
y_pred_pos is the predicted scores for positive edges.
y_pred_neg is the predicted scores for negative edges. It needs to be a 2d matrix.
y_pred_pos[i] is ranked among y_pred_neg[i].
Note: As the evaluation metric is ranking-based, the predicted scores need to be different for different edges.
==== Expected output format of Evaluator for ogbl-biokg
{'hits@1_list': hits@1_list, 'hits@3_list': hits@3_list, 
'hits@10_list': hits@10_list, 'mrr_list': mrr_list}
- mrr_list (list of float): list of scores for calculating MRR 
- hits@1_list (list of float): list of scores for calculating Hits@1 
- hits@3_list (list of float): list of scores 

In [None]:
# shuffle=False keeps predictions in the row order of `test_triples` (assuming
# TriplesDataset preserves row order); the per-relation breakdown relies on it.
test_dataset = TriplesDataset(test_triples)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1024,
    shuffle=False,
    num_workers=4,
)

In [None]:
def eval_pos_neg(model, test_loader, n_neg=10, device=None):
    """Score each test triple against ``n_neg`` randomly corrupted tails.

    For every positive triple (h, r, t) yielded by ``test_loader`` the model
    scores the triple itself and ``n_neg`` corruptions (h, r, t'). Each t' is
    drawn uniformly from the entities of the tail type of relation ``r``. The
    output matches the layout the OGB ``Evaluator`` expects: ``y_pred_pos[i]``
    is ranked among ``y_pred_neg[i]``.

    NOTE(review): the negatives are random and unfiltered, so a sampled t' can
    coincide with the true tail or another known true triple. With only a
    handful of negatives (default 10), Hits@10 is close to trivial. The
    official ogbl-biokg protocol presumably ranks against the fixed candidate
    sets shipped with the split (``head_neg`` / ``tail_neg``). Confirm before
    comparing numbers with the leaderboard.

    Parameters
    ----------
    model : torch.nn.Module
        Link-prediction model. It must expose:

        - ``rel2type``: relation id -> (head_type, relation_name, tail_type).
        - ``data['num_nodes_dict']``: entity type -> number of entities.
        - ``forward(head=..., relation=..., tail=...)``: returns one score per
          triple.
    test_loader : iterable with ``len``
        Yields ``(head, tail, relation)`` integer tensors of equal length,
        e.g. a ``torch.utils.data.DataLoader``.
    n_neg : int, default 10
        Number of negative tails sampled per positive triple.
    device : torch.device, optional
        Device used for scoring. Defaults to the device of the model's first
        parameter, or CPU if the model has no parameters.

    Returns
    -------
    y_pred_pos : torch.Tensor   # shape (N_pos,), on CPU
    y_pred_neg : torch.Tensor   # shape (N_pos, n_neg), on CPU

    Raises
    ------
    ValueError
        If ``n_neg`` is smaller than 1.
    """
    if n_neg < 1:
        raise ValueError(f'n_neg must be a positive integer, got {n_neg!r}')

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()
    pos_chunks, neg_chunks = [], []
    n_batches = len(test_loader)

    try:
        with torch.no_grad():
            for batch_idx, (h_pos, t_pos, r_pos) in enumerate(test_loader):
                print(f'[batch: {batch_idx+1}/{n_batches}] Processing batch...', end='\r')

                h_pos = h_pos.to(device, non_blocking=True)
                t_pos = t_pos.to(device, non_blocking=True)
                r_pos = r_pos.to(device, non_blocking=True)
                B = h_pos.size(0)

                # Positive scores. Flatten so the Evaluator always receives a
                # 1-D tensor, even if forward() returns (B, 1).
                pos_scores = model.forward(head=h_pos, relation=r_pos, tail=t_pos)
                pos_chunks.append(pos_scores.reshape(-1).cpu())

                # Corrupt the tail n_neg times per positive. The flattened batch
                # is laid out as [triple 0 x n_neg, triple 1 x n_neg, ...], so a
                # single forward call scores all negatives.
                h_neg = h_pos.repeat_interleave(n_neg)
                r_neg = r_pos.repeat_interleave(n_neg)
                t_neg = _sample_negative_tails(model, r_pos, n_neg, device).reshape(-1)

                neg_scores = model.forward(head=h_neg, relation=r_neg, tail=t_neg)
                neg_chunks.append(neg_scores.reshape(B, n_neg).cpu())
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    if not pos_chunks:
        # Empty loader: torch.cat would fail on an empty list.
        return torch.empty(0), torch.empty(0, n_neg)

    return torch.cat(pos_chunks, dim=0), torch.cat(neg_chunks, dim=0)


def _sample_negative_tails(model, relation, n_neg, device):
    """Return a (len(relation), n_neg) LongTensor of random tail entity ids.

    Row j is drawn uniformly from [0, n), where n is the number of entities of
    the tail type of ``relation[j]``. The tail type is looked up via
    ``model.rel2type`` and the count via ``model.data['num_nodes_dict']``.
    """
    tails = torch.empty(relation.size(0), n_neg, dtype=torch.long, device=device)
    # One randint call per distinct relation instead of one per row. This
    # avoids a device->host sync (.item()) for every triple in the batch.
    for rel in relation.unique().tolist():
        rows = (relation == rel).nonzero(as_tuple=True)[0]
        tail_type = model.rel2type[rel][2]
        n_tail_ents = model.data['num_nodes_dict'][tail_type]
        tails[rows] = torch.randint(0, n_tail_ents, (rows.numel(), n_neg), device=device)
    return tails

In [64]:
# Score every test triple against 10 sampled negative tails.
y_pred_pos, y_pred_neg = eval_pos_neg(model, test_loader, n_neg=10, device=device)

[batch: 160/160] Processing batch...

In [65]:
res_dict = evaluator.eval({'y_pred_pos': y_pred_pos, 'y_pred_neg': y_pred_neg})

# Each entry holds one value per test triple; report the average of each.
for metric_name, per_edge in res_dict.items():
    mean_value = np.mean(per_edge.numpy())
    print(f'mean {metric_name}: {mean_value:.4f}')

mean hits@1_list: 0.4737
mean hits@3_list: 0.6363
mean hits@10_list: 0.9540
mean mrr_list: 0.5998


In [70]:
# Break the test metrics down by relation (edge) type.
# Row i of every res_dict list only corresponds to row i of test_triples
# because test_loader uses shuffle=False (assuming TriplesDataset keeps row
# order). Fail loudly if the lengths disagree instead of mis-attributing metrics.
test_rels = np.asarray(test_triples['relation'])
assert len(res_dict['mrr_list']) == len(test_rels), 'metrics are not aligned with test_triples'

edge_res_dict = {}
for rel in np.unique(test_rels):
    rel_mask = test_rels == rel
    h, r, t = model.rel2type[rel.item()]
    # Store plain Python numbers rather than 0-d tensors.
    edge_res_dict[(h, r, t)] = {
        'mrr': float(res_dict['mrr_list'][rel_mask].mean()),
        'hits@1': float(res_dict['hits@1_list'][rel_mask].mean()),
        'hits@3': float(res_dict['hits@3_list'][rel_mask].mean()),
        'hits@10': float(res_dict['hits@10_list'][rel_mask].mean()),
        'n_test_obs': int(rel_mask.sum()),
    }

In [71]:
for edge_type, metrics in edge_res_dict.items():
    print(f"Edge type {edge_type}:")
    for metric, value in metrics.items():
        # Counts are integers; print them as such instead of e.g. "3742.0000".
        if isinstance(value, (int, np.integer)):
            print(f"\t{metric}: {value}")
        else:
            print(f"\t{metric}: {value:.4f}")
    print()

Edge type ('disease', 'disease-protein', 'protein'):
	mrr: 0.5920
	hits@1: 0.4602
	hits@3: 0.6355
	hits@10: 0.9511
	n_test_obs: 3742.0000

Edge type ('drug', 'drug-disease', 'disease'):
	mrr: 0.5995
	hits@1: 0.4902
	hits@3: 0.6039
	hits@10: 0.9569
	n_test_obs: 255.0000

Edge type ('drug', 'drug-drug_acquired_metabolic_disease', 'drug'):
	mrr: 0.6010
	hits@1: 0.4728
	hits@3: 0.6413
	hits@10: 0.9625
	n_test_obs: 1762.0000

Edge type ('drug', 'drug-drug_bacterial_infectious_disease', 'drug'):
	mrr: 0.5895
	hits@1: 0.4690
	hits@3: 0.6143
	hits@10: 0.9516
	n_test_obs: 516.0000

Edge type ('drug', 'drug-drug_benign_neoplasm', 'drug'):
	mrr: 0.5998
	hits@1: 0.4781
	hits@3: 0.6311
	hits@10: 0.9419
	n_test_obs: 843.0000

Edge type ('drug', 'drug-drug_cancer', 'drug'):
	mrr: 0.6032
	hits@1: 0.4792
	hits@3: 0.6365
	hits@10: 0.9562
	n_test_obs: 1348.0000

Edge type ('drug', 'drug-drug_cardiovascular_system_disease', 'drug'):
	mrr: 0.6024
	hits@1: 0.4748
	hits@3: 0.6421
	hits@10: 0.9507
	n_test_obs