In [None]:
import os
import sys
import math
import pprint
import random
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
from tqdm import tqdm

import torch
from torch import optim
from torch import nn
from torch.nn import functional as F
from torch import distributed as dist
from torch.utils import data as torch_data
from torch_geometric.data import Data

from nbfnet import tasks, util

In [None]:
# Template variables substituted into the NBFNet YAML config.
# (Named `config_vars` rather than `vars` so the Python builtin `vars()` stays usable.)
config_vars = {'gpus': [0], 'version': 'v1'}

# Redirect this path to the trained NBFNet model checkpoint.
# NOTE(review): if util.create_working_directory changes the current working directory,
# a relative path here resolves inside that new directory — prefer an absolute path.
CHECKPOINT_PATH = "path-to-NBFNet-model"

# load the NBFNet config
cfg = util.load_config("config/inductive/fb15k237.yaml", context=config_vars)
working_dir = util.create_working_directory(cfg)

logger = util.get_root_logger()
is_inductive = cfg.dataset["class"].startswith("Ind")
dataset = util.build_dataset(cfg)
cfg.model.num_relation = dataset.num_relations
model = util.build_model(cfg)

device = util.get_device(cfg)
model = model.to(device)
train_data, valid_data, test_data = dataset[0], dataset[1], dataset[2]
train_data = train_data.to(device)
valid_data = valid_data.to(device)
test_data = test_data.to(device)

# torch.load unpickles the file: only load checkpoints from a trusted source.
state = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(state["model"])
# This notebook only runs inference: leave training mode so that any training-only
# behaviour in the model's forward pass (dropout, edge removal, ...) is disabled.
model.eval()

In [None]:
def get_mrr(g, ge):
    """Expected ranking metrics when the target ties with other candidates.

    The target's 1-based rank is treated as uniformly distributed over every
    position from ``g`` to ``ge`` (inclusive), and each metric is averaged over
    that range.

    Args:
        g: best possible rank of the target (1 + number of strictly better candidates).
        ge: worst possible rank of the target (1 + number of candidates scoring at least as high).

    Returns:
        Tuple ``(mrr, hits@1, hits@3, hits@10)`` averaged over the tie range.
    """
    span = ge - g + 1
    cutoffs = (1, 3, 10)
    reciprocal_sum = 0.0
    hits = [0.0] * len(cutoffs)
    for rank in range(g, ge + 1):
        reciprocal_sum += 1.0 / rank
        for slot, cutoff in enumerate(cutoffs):
            if rank <= cutoff:
                hits[slot] += 1
    return reciprocal_sum / span, hits[0] / span, hits[1] / span, hits[2] / span

In [None]:
# Lookup structures for test-time scoring and evaluation.
# Rows of the stacked tensors are (head, tail, relation); the sets store (head, relation, tail).
test_edge_index = test_data.edge_index
test_edge_type = test_data.edge_type
test_triplets = torch.cat([test_data.target_edge_index, test_data.target_edge_type.unsqueeze(0)]).t()

# Queries to evaluate: the held-out target triples of the test split.
query = {(head, rel, tail) for head, tail, rel in test_triplets.tolist()}

# Every triple known to be true at test time (graph edges plus the targets); used to filter rankings.
test_ture_triple = {
    (head, rel, tail)
    for head, tail, rel in torch.cat([test_edge_index, test_edge_type.unsqueeze(0)]).t().tolist()
}
test_ture_triple |= query

test_N = test_data.num_nodes                   # number of nodes in the test graph
offset = int(dataset.num_relations // 2)       # presumably: ids < offset are forward relations, the rest inverses
bsz = 128                                      # heads scored per forward pass

In [None]:
# Score every (head, tail, relation) combination of the test graph into a dense CPU tensor
# pred[h, t, r] for r in range(offset), processing one relation and `bsz` heads at a time.
model.eval()  # inference only; also safe if this cell is run without the setup cell's eval()

# Accumulators consumed by the evaluation cell.
results = list()
r_results = list()
rank1 = list()
rank2 = list()

pred = torch.zeros((test_N, test_N, offset))
# Every node is a candidate tail. Indices are built on `device`, where `model` and
# `test_data` live, so the forward pass does not mix CPU and GPU tensors.
all_index = torch.arange(test_N, device=device)

with torch.no_grad():
    for r in range(offset):
        with tqdm(range(0, test_N, bsz), ncols=80) as _tqdm:
            _tqdm.set_description_str(f'{r}')
            for batch_st in _tqdm:
                batch_ed = min(test_N, batch_st + bsz)
                h_index = torch.arange(batch_st, batch_ed, device=device)
                # (rows, test_N) grids: row i pairs head batch_st + i with every tail.
                h_index, t_index = torch.meshgrid(h_index, all_index)
                r_index = r * torch.ones_like(h_index)
                batch = torch.stack([h_index, t_index, r_index], dim=-1)
                pred0 = model(test_data, batch)
                # pred is kept on CPU to bound device memory; copy the scores over explicitly.
                pred[batch_st:batch_ed, :, r] = pred0.to(pred.device)
            torch.cuda.empty_cache()

In [None]:
# Evaluate the dense score tensor pred[h, t, r] on every test query (s, r, t):
#   * relation prediction (s, ?, t)  -> r_results
#   * tail prediction     (s, r, ?)  -> results
#   * head prediction     (?, r, t)  -> results
#   * triple classification: the true triple vs. one random corrupted triple -> AUC / accuracy
# Rankings are "filtered" (candidates that form a known true triple are dropped) and are then
# computed against at most NUM_NEGATIVES sampled remaining candidates; ties are spread by get_mrr.
NUM_NEGATIVES = 50
EVAL_SEED = 0


def sampled_filtered_rank(scores, target_score, keep_mask, num_negatives=NUM_NEGATIVES):
    """Rank ``target_score`` against a random sample of the non-filtered candidates.

    Args:
        scores: 1-D score tensor over all candidates.
        target_score: score of the true candidate.
        keep_mask: bool tensor, True for candidates that are NOT known true triples.
        num_negatives: maximum number of candidates sampled (without replacement).

    Returns:
        ``(g, ge)``: counts of sampled candidates scoring strictly higher / at least as high.
    """
    candidates = scores[keep_mask]
    n_sample = min(len(candidates), num_negatives)
    index = torch.LongTensor(random.sample(range(len(candidates)), n_sample))
    candidates = candidates[index]
    return (target_score < candidates).sum(), (target_score <= candidates).sum()


def mean_metrics(metric_rows):
    """Column-wise mean of (mrr, h1, h3, h10) tuples as Python floats (NaN if empty)."""
    if not metric_rows:
        return (float('nan'),) * 4
    return tuple(float(sum(column)) / len(metric_rows) for column in zip(*metric_rows))


assert len(query) > 0, "no test queries to evaluate"
# Negative sampling below is random; seed it so the reported metrics are reproducible.
random.seed(EVAL_SEED)

# (Re)initialise every accumulator here so re-running this cell does not double-count.
cnt = 0
results, r_results = list(), list()
rank1, rank2 = list(), list()
auc_pred, auc_target = list(), list()
num_rel = pred.shape[2]  # size of pred's relation axis

with torch.no_grad():
    for s, r, t in query:
        # Relation prediction: rank r among all relations between s and t.
        rel_keep = torch.tensor(
            [(s, rr, t) not in test_ture_triple for rr in range(num_rel)], dtype=torch.bool)
        g, ge = sampled_filtered_rank(pred[s, t, :], pred[s, t, r], rel_keep)
        rank1.append(g)
        rank2.append(ge)
        r_results.append(get_mrr(g + 1, ge + 1))

        # The true triple is a positive example for AUC / accuracy.
        test_target = pred[s, t, r]
        auc_pred.append(test_target.sigmoid().item())
        auc_target.append(1)

        # Tail prediction: rank t among all tails of (s, r, ?).
        tail_keep = torch.tensor(
            [(s, r, tt) not in test_ture_triple for tt in range(test_N)], dtype=torch.bool)
        g, ge = sampled_filtered_rank(pred[s, :, r], test_target, tail_keep)
        rank1.append(g)
        rank2.append(ge)
        results.append(get_mrr(g + 1, ge + 1))

        # Head prediction: rank s among all heads of (?, r, t).
        head_keep = torch.tensor(
            [(ss, r, t) not in test_ture_triple for ss in range(test_N)], dtype=torch.bool)
        g, ge = sampled_filtered_rank(pred[:, t, r], test_target, head_keep)
        rank1.append(g)
        rank2.append(ge)
        results.append(get_mrr(g + 1, ge + 1))

        # Negative example: corrupt head, relation and tail until the triple is not known to be true.
        ns, nr, nt = s, r, t
        while (ns, nr, nt) in test_ture_triple:
            ns = random.randint(0, test_N - 1)
            nr = random.randint(0, num_rel - 1)
            nt = random.randint(0, test_N - 1)
        # nr < num_rel by construction, so it indexes pred's relation axis directly
        # and no inverse-relation remapping is needed.
        neg_target = pred[ns, nt, nr]
        auc_pred.append(neg_target.sigmoid().item())
        auc_target.append(0)
        if test_target > neg_target:
            cnt += 1

# Triple classification from sigmoid scores.
auc_pred = np.array(auc_pred)
auc_target = np.array(auc_target)
auc_score = roc_auc_score(auc_target, auc_pred)
acc_score = accuracy_score(auc_target, auc_pred > 0.5)
# Pairwise accuracy: fraction of queries whose true triple outscores its corruption.
cnt /= len(query)

mrr, h1, h3, h10 = mean_metrics(results)
r_mrr, r_h1, r_h3, r_h10 = mean_metrics(r_results)
print(f'h3:{h3:.4f}, rh3:{r_h3:.4f}, acc:{acc_score:.4f}, auc:{auc_score:.4f}')
print(f'entity   mrr:{mrr:.4f}, h1:{h1:.4f}, h3:{h3:.4f}, h10:{h10:.4f}')
print(f'relation mrr:{r_mrr:.4f}, h1:{r_h1:.4f}, h3:{r_h3:.4f}, h10:{r_h10:.4f}')
print(f'pairwise acc:{cnt:.4f}')