In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
import time
import pickle
from tqdm import tqdm
# for some reason, need to go to the sheaf_kg directory in order for torch.load to work
os.chdir('/home/gebhart/projects/sheaf_kg/sheaf_kg')

import sheaf_kg.batch_harmonic_extension as harmonic_extension
from sheaf_kg.sheafE_models import SheafE_Multisection, SheafE_Diag

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import pykeen
import torch
from pykeen.pipeline import pipeline

In [2]:
dataset = 'FB15k-237'
num_test = 500
num_train = 1000
batch_size = 500
use_section = 0
model_name = 'SheafE_Multisection_64embdim_64esdim_1sec_1norm_1000epochs_SoftplusLossloss_20210304-1923'
save_loc = '/home/gebhart/projects/sheaf_kg/data/{}/{}/trained_model.pkl'.format(dataset,model_name)
betae_path = '/home/gebhart/projects/sheaf_kg/data/{}-betae'.format(dataset)
model = torch.load(save_loc).to('cpu')

In [3]:
query_structures = [('e', ('r', 'r')), ('e', ('r', 'r', 'r')), (('e', ('r',)), ('e', ('r',))), (('e', ('r',)), ('e', ('r',)), ('e', ('r',))), (('e', ('r', 'r')), ('e', ('r',))), ((('e', ('r',)), ('e', ('r',))), ('r',))]

query_name_dict = {('e',('r',)): '1p', 
                    ('e', ('r', 'r')): '2p',
                    ('e', ('r', 'r', 'r')): '3p',
                    (('e', ('r',)), ('e', ('r',))): '2i',
                    (('e', ('r',)), ('e', ('r',)), ('e', ('r',))): '3i',
                    ((('e', ('r',)), ('e', ('r',))), ('r',)): 'ip',
                    (('e', ('r', 'r')), ('e', ('r',))): 'pi',
                    (('e', ('r',)), ('e', ('r', 'n'))): '2in',
                    (('e', ('r',)), ('e', ('r',)), ('e', ('r', 'n'))): '3in',
                    ((('e', ('r',)), ('e', ('r', 'n'))), ('r',)): 'inp',
                    (('e', ('r', 'r')), ('e', ('r', 'n'))): 'pin',
                    (('e', ('r', 'r', 'n')), ('e', ('r',))): 'pni',
                    (('e', ('r',)), ('e', ('r',)), ('u',)): '2u-DNF',
                    ((('e', ('r',)), ('e', ('r',)), ('u',)), ('r',)): 'up-DNF',
                    ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n',)): '2u-DM',
                    ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n', 'r')): 'up-DM'
                }

In [4]:
# ds = pykeen.datasets.get_dataset(dataset=dataset)
ds = pykeen.datasets.get_dataset(dataset=dataset, dataset_kwargs=dict(create_inverse_triples=True))
training = ds.training.mapped_triples
relid2label = ds.training.relation_id_to_label 
label2relid = {v:k for k,v in relid2label.items()}

entid2label = ds.training.entity_id_to_label 
label2entid = {v:k for k,v in entid2label.items()}

You're trying to map triples with 30 entities and 0 relations that are not in the training set. These triples will be excluded from the mapping.
In total 28 from 20466 triples were filtered out


In [5]:
with open(os.path.join(betae_path,'test-queries.pkl'), 'rb') as f:
    test_queries = pickle.load(f)

with open(os.path.join(betae_path,'test-easy-answers.pkl'), 'rb') as f:
    test_answers = pickle.load(f)

with open(os.path.join(betae_path,'train-queries.pkl'), 'rb') as f:
    train_queries = pickle.load(f)

with open(os.path.join(betae_path,'train-answers.pkl'), 'rb') as f:
    train_answers = pickle.load(f)
    
with open(os.path.join(betae_path,'id2rel.pkl'), 'rb') as f:
    id2rel = pickle.load(f)
    
with open(os.path.join(betae_path,'id2ent.pkl'), 'rb') as f:
    id2ent = pickle.load(f)

In [6]:
def map_ent(e):
    return label2entid[id2ent[e]]

def map_rel(r):
    relname = id2rel[r]
    return label2relid[relname[1:]]

def orient_rel(r):
    orientation = 1
    relname = id2rel[r]
    if relname[0] == '-':
        orientation = -1
    return orientation

In [7]:
test_queries

defaultdict(set,
            {('e', ('r',)): {(3594, (119,)),
              (7330, (38,)),
              (1212, (38,)),
              (12120, (172,)),
              (3494, (3,)),
              (1543, (206,)),
              (7714, (41,)),
              (4992, (12,)),
              (9807, (77,)),
              (4994, (188,)),
              (2027, (311,)),
              (6843, (164,)),
              (6569, (96,)),
              (9242, (52,)),
              (6547, (117,)),
              (6992, (50,)),
              (9225, (143,)),
              (12396, (62,)),
              (2515, (59,)),
              (5806, (286,)),
              (3028, (27,)),
              (8854, (370,)),
              (889, (185,)),
              (13632, (415,)),
              (4720, (38,)),
              (406, (39,)),
              (7344, (219,)),
              (1680, (12,)),
              (3566, (65,)),
              (6620, (212,)),
              (2248, (39,)),
              (9195, (89,)),
              (3793, (38,)

In [8]:
def L_p(queries, model):
    '''query of form ('e', ('r', 'r', ... , 'r')).
    here we assume 2 or more relations are present so 2p or greater
    '''
    all_ents = [map_ent(query[0]) for query in queries]
    all_rels = [[map_rel(r) for r in query[1]] for query in queries]
    all_invs = [[orient_rel(r) for r in query[1]] for query in queries]
    n_path_ents = len(all_rels[0])
    num_queries = len(queries)
    
    edge_indices = np.concatenate([np.arange(0,n_path_ents)[:,np.newaxis].T, np.arange(1,n_path_ents+1)[:,np.newaxis].T], axis=0)
    edge_indices = torch.LongTensor(np.repeat(edge_indices[np.newaxis, :, :], num_queries, axis=0))
    
    rel_idx_tensor = torch.LongTensor(all_rels)
    
    left_restrictions = torch.index_select(model.left_embeddings, 0, rel_idx_tensor.flatten()).view(-1,rel_idx_tensor.shape[1], model.edge_stalk_dim, model.embedding_dim)
    right_restrictions = torch.index_select(model.right_embeddings, 0, rel_idx_tensor.flatten()).view(-1,rel_idx_tensor.shape[1], model.edge_stalk_dim, model.embedding_dim)
    
    restrictions = torch.empty((num_queries, rel_idx_tensor.shape[1], 2, left_restrictions.shape[-2], left_restrictions.shape[-1]))
    for ainvix in range(len(all_invs)):
        invs = all_invs[ainvix]
        for invix in range(len(invs)):
            if invs[invix] == -1:
                restrictions[ainvix,invix,0,:,:] = right_restrictions[ainvix,invix,:,:]
                restrictions[ainvix,invix,1,:,:] = left_restrictions[ainvix,invix,:,:]
            else:
                restrictions[ainvix,invix,0,:,:] = left_restrictions[ainvix,invix,:,:]
                restrictions[ainvix,invix,1,:,:] = right_restrictions[ainvix,invix,:,:]
    
    ent_idx_tensor = torch.LongTensor(all_ents)
    source_embeddings = torch.index_select(model.ent_embeddings, 0, ent_idx_tensor).view(-1, model.embedding_dim, model.num_sections)
    
    B = torch.LongTensor(np.repeat(np.array([0,n_path_ents],np.int)[np.newaxis,:], num_queries, axis=0))
    U = torch.LongTensor(np.repeat(np.array(range(1,n_path_ents),np.int)[np.newaxis,:], num_queries, axis=0))
    source_vertices = np.zeros((num_queries,1), dtype=np.int)
    target_vertices = np.full((num_queries,1), 1, dtype=np.int)
    LSchur = harmonic_extension.Kron_reduction(edge_indices, restrictions, B, U)
    return LSchur, source_vertices, target_vertices, source_embeddings


def L_i(queries, model):
    '''query of form (('e', ('r',)), ('e', ('r',)), ... , ('e', ('r',)))'''
    num_intersects = len(queries[0])
    all_ents = [[map_ent(pair[0]) for pair in query] for query in queries]
    all_rels = [[map_rel(pair[1][0]) for pair in query] for query in queries]
    all_invs = [[orient_rel(pair[1][0]) for pair in query] for query in queries]
    n_ents = len(all_ents[0])
    num_queries = len(queries)
    
    edge_indices = np.concatenate([np.full(n_ents,n_ents)[:,np.newaxis].T, np.arange(0,n_ents)[:,np.newaxis].T], axis=0)
    edge_indices = torch.LongTensor(np.repeat(edge_indices[np.newaxis, :, :], num_queries, axis=0))
    
    rel_idx_tensor = torch.LongTensor(all_rels)
    left_restrictions = torch.index_select(model.left_embeddings, 0, rel_idx_tensor.flatten()).view(-1,rel_idx_tensor.shape[1], model.edge_stalk_dim, model.embedding_dim)
    right_restrictions = torch.index_select(model.right_embeddings, 0, rel_idx_tensor.flatten()).view(-1,rel_idx_tensor.shape[1], model.edge_stalk_dim, model.embedding_dim)
    
    restrictions = torch.empty((num_queries, rel_idx_tensor.shape[1], 2, left_restrictions.shape[-2], left_restrictions.shape[-1]))
    for ainvix in range(len(all_invs)):
        invs = all_invs[ainvix]
        for invix in range(len(invs)):
            # why is this not -1 like with all the others?! can't figure this out!
            if invs[invix] == -1:
                restrictions[ainvix,invix,0,:,:] = right_restrictions[ainvix,invix,:,:]
                restrictions[ainvix,invix,1,:,:] = left_restrictions[ainvix,invix,:,:]
            else:
                restrictions[ainvix,invix,0,:,:] = left_restrictions[ainvix,invix,:,:]
                restrictions[ainvix,invix,1,:,:] = right_restrictions[ainvix,invix,:,:]
    
    ent_idx_tensor = torch.LongTensor(all_ents)
    source_embeddings = torch.index_select(model.ent_embeddings, 0, ent_idx_tensor.flatten()).view(-1, model.embedding_dim, model.num_sections)
    
    L = harmonic_extension.Laplacian(edge_indices, restrictions)
    source_vertices = np.repeat(np.arange(n_ents)[np.newaxis,:], num_queries, axis=0)
    target_vertices = np.full((num_queries, 1),n_ents, dtype=np.int)
    return torch.transpose(L,1,2), source_vertices, target_vertices, source_embeddings

def L_ip(queries, model):
    '''query of form ((('e', ('r',)), ('e', ('r',))), ('r',))'''
    all_ents = [[map_ent(t[0]) for t in query[0]] for query in queries] 
    all_rels = [[map_rel(query[0][0][1][0]), map_rel(query[0][1][1][0]), map_rel(query[1][0])] for query in queries]
    all_invs = [[orient_rel(query[0][0][1][0]), orient_rel(query[0][1][1][0]), orient_rel(query[1][0])] for query in queries]
    n_ents = len(all_ents[0])
    num_queries = len(queries)
    
    edge_indices = torch.LongTensor(np.repeat(np.array([[0,2],[1,2],[2,3]],np.int).T[np.newaxis,:,:], num_queries, axis=0))
    
    rel_idx_tensor = torch.LongTensor(all_rels)
    left_restrictions = torch.index_select(model.left_embeddings, 0, rel_idx_tensor.flatten()).view(-1,rel_idx_tensor.shape[1], model.edge_stalk_dim, model.embedding_dim)
    right_restrictions = torch.index_select(model.right_embeddings, 0, rel_idx_tensor.flatten()).view(-1,rel_idx_tensor.shape[1], model.edge_stalk_dim, model.embedding_dim)
    
    restrictions = torch.empty((num_queries, rel_idx_tensor.shape[1], 2, left_restrictions.shape[-2], left_restrictions.shape[-1]))
    for ainvix in range(len(all_invs)):
        invs = all_invs[ainvix]
        for invix in range(len(invs)):
            if invs[invix] == -1:
                restrictions[ainvix,invix,0,:,:] = right_restrictions[ainvix,invix,:,:]
                restrictions[ainvix,invix,1,:,:] = left_restrictions[ainvix,invix,:,:]
            else:
                restrictions[ainvix,invix,0,:,:] = left_restrictions[ainvix,invix,:,:]
                restrictions[ainvix,invix,1,:,:] = right_restrictions[ainvix,invix,:,:]
    
    ent_idx_tensor = torch.LongTensor(all_ents)
    source_embeddings = torch.index_select(model.ent_embeddings, 0, ent_idx_tensor.flatten()).view(-1, model.embedding_dim, model.num_sections)
    
    B = torch.LongTensor(np.repeat(np.array([0,2,3],dtype=np.int)[np.newaxis,:], num_queries, axis=0))
    U = torch.LongTensor(np.full((num_queries,1), 1, dtype=np.int))
    source_vertices = np.repeat(np.array([0,1], dtype=np.int)[np.newaxis,:], num_queries, axis=0)
    target_vertices = np.full((num_queries,1), 2, dtype=np.int)
    LSchur = harmonic_extension.Kron_reduction(edge_indices, restrictions, B, U)
    return LSchur, source_vertices, target_vertices, source_embeddings

def L_pi(queries, model):
    '''query of form (('e', ('r', 'r')), ('e', ('r',)))'''
    all_ents = [[map_ent(t[0]) for t in query] for query in queries]
    all_rels = [[map_rel(query[0][1][0]), map_rel(query[0][1][1]), map_rel(query[1][1][0])] for query in queries]
    all_invs = [[orient_rel(query[0][1][0]), orient_rel(query[0][1][1]), orient_rel(query[1][1][0])] for query in queries]
    n_ents = len(all_ents[0])
    num_queries = len(queries)
    
    edge_indices = torch.LongTensor(np.repeat(np.array([[0,2],[2,3],[1,3]],np.int).T[np.newaxis,:,:], num_queries, axis=0))
    
    rel_idx_tensor = torch.LongTensor(all_rels)
    left_restrictions = torch.index_select(model.left_embeddings, 0, rel_idx_tensor.flatten()).view(-1,rel_idx_tensor.shape[1], model.edge_stalk_dim, model.embedding_dim)
    right_restrictions = torch.index_select(model.right_embeddings, 0, rel_idx_tensor.flatten()).view(-1,rel_idx_tensor.shape[1], model.edge_stalk_dim, model.embedding_dim)
    
    restrictions = torch.empty((num_queries, rel_idx_tensor.shape[1], 2, left_restrictions.shape[-2], left_restrictions.shape[-1]))
    for ainvix in range(len(all_invs)):
        invs = all_invs[ainvix]
        for invix in range(len(invs)):
            if invs[invix] == -1:
                restrictions[ainvix,invix,0,:,:] = right_restrictions[ainvix,invix,:,:]
                restrictions[ainvix,invix,1,:,:] = left_restrictions[ainvix,invix,:,:]
            else:
                restrictions[ainvix,invix,0,:,:] = left_restrictions[ainvix,invix,:,:]
                restrictions[ainvix,invix,1,:,:] = right_restrictions[ainvix,invix,:,:]
    
    ent_idx_tensor = torch.LongTensor(all_ents)
    source_embeddings = torch.index_select(model.ent_embeddings, 0, ent_idx_tensor.flatten()).view(-1, model.embedding_dim, model.num_sections)
    
    B = torch.LongTensor(np.repeat(np.array([0,1,3], dtype=np.int)[np.newaxis, :], num_queries, axis=0))
    U = torch.LongTensor(np.full((num_queries, 1), 2, dtype=np.int))
    source_vertices = np.repeat(np.array([0,1], dtype=np.int).T[np.newaxis,:], num_queries, axis=0)
    target_vertices = np.full((num_queries,1), 2, dtype=np.int)
    LSchur = harmonic_extension.Kron_reduction(edge_indices, restrictions, B, U)
    return LSchur, source_vertices, target_vertices, source_embeddings


query_name_fn_dict = {'2p':L_p, '3p':L_p, '2i':L_i, '3i':L_i, 'ip':L_ip, 'pi':L_pi}

def softplusloss(logits, labels):
    loss_fn = torch.nn.Softplus(beta=1, threshold=20)
    assert 0. <= labels.min() and labels.max() <= 1.
    # scale labels from [0, 1] to [-1, 1]
    labels = 2 * labels - 1
    loss = loss_fn((-1) * labels * logits)
    loss = torch.mean(loss)
    return loss

In [9]:
def train_step(model, optimizer, query_type, batch_queries, batch_targets, batch_answers):
    model.train()
    optimizer.zero_grad()

    fn = query_name_fn_dict[query_name]
    
    L, source_vertices, target_vertices, source_embeddings = fn(qs, model)
    target_embeddings = torch.index_select(model.ent_embeddings, 0, batch_targets).view(-1, model.embedding_dim, model.num_sections)
    Q = harmonic_extension.compute_costs(L,source_vertices,target_vertices,torch.mean(source_embeddings, -1).view(batch_size, -1),target_embeddings,source_embeddings.shape[1])
    loss = softplusloss(Q,batch_answers)

    loss.backward()
    optimizer.step()
    
    return loss.item()

In [10]:
%%time
allhits1 = []
allhits3 = []
allhits5 = []
allhits10 = []
allmrr = []
query_names = []
# target_embeddings = model.ent_embeddings.view(-1, model.embedding_dim, model.num_sections)[:,:,use_section].T
target_embeddings = torch.mean(model.ent_embeddings.view(-1, model.embedding_dim, model.num_sections), 2).T
for query_structure in query_structures:
    print('Running query : {}'.format(query_structure))
    query_name = query_name_dict[query_structure]
    query_names.append(query_name)
    fn = query_name_fn_dict[query_name]
    hits1 = 0.
    hits3 = 0.
    hits5 = 0.
    hits10 = 0.
    mrr = 0.
    cnt = 0
    # the len() > 0 part is to determine whether we have an "easy" query
    queries = [q for q in test_queries[query_structure] if len(test_answers[q]) > 0] 
    for qix in tqdm(range(0, num_test, batch_size)):
        qs = queries[qix:qix+batch_size]
        # we have a non-trivial "easy" query
        if len(qs) > 0:
            all_answers = [[map_ent(a) for a in test_answers[query]] for query in qs]
            L, source_vertices, target_vertices, source_embeddings = fn(qs, model)
            Q = harmonic_extension.compute_costs(L,source_vertices,target_vertices,torch.mean(source_embeddings, -1).view(batch_size, -1),target_embeddings,source_embeddings.shape[1])
            for i in range(len(qs)):
                Qi = Q[i].squeeze()
                answers = all_answers[i]
                sortd,_ = torch.sort(Qi)
                idxleft = torch.searchsorted(sortd, Qi[answers], right=False) + 1
                idxright = torch.searchsorted(sortd, Qi[answers], right=True) + 1
                nl = idxleft.shape[0]
                nr = idxright.shape[0]
                # idxright = idxleft # throw this for optimistic ranking
                hits1 += ((torch.sum(idxleft <= 1)/nl + torch.sum(idxright <= 1)/nr) / 2.)
                hits3 += ((torch.sum(idxleft <= 3)/nl + torch.sum(idxright <= 3)/nr) / 2.)
                hits5 += ((torch.sum(idxleft <= 5)/nl + torch.sum(idxright <= 5)/nr) / 2.)
                hits10 += ((torch.sum(idxleft <= 10)/nl + torch.sum(idxright <= 10)/nr) / 2.)
                mrr += ((torch.sum(1./idxleft)/nl + torch.sum(1./idxright)/nr) / 2.)
                cnt += 1
    if cnt > 0:
        allhits1.append(hits1/cnt)
        allhits3.append(hits3/cnt)
        allhits5.append(hits5/cnt)
        allhits10.append(hits10/cnt)
        allmrr.append(mrr/cnt)
    else:
        default = 0.
        allhits1.append(default)
        allhits3.append(default)
        allhits5.append(default)
        allhits10.append(default)
        allmrr.append(default)

  0%|          | 0/1 [00:00<?, ?it/s]

Running query : ('e', ('r', 'r'))


100%|██████████| 1/1 [00:01<00:00,  1.53s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

Running query : ('e', ('r', 'r', 'r'))


100%|██████████| 1/1 [00:01<00:00,  1.82s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

Running query : (('e', ('r',)), ('e', ('r',)))


100%|██████████| 1/1 [00:01<00:00,  1.42s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

Running query : (('e', ('r',)), ('e', ('r',)), ('e', ('r',)))


100%|██████████| 1/1 [00:01<00:00,  1.49s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

Running query : (('e', ('r', 'r')), ('e', ('r',)))


100%|██████████| 1/1 [00:01<00:00,  1.63s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

Running query : ((('e', ('r',)), ('e', ('r',))), ('r',))


100%|██████████| 1/1 [00:01<00:00,  1.66s/it]

CPU times: user 29.2 s, sys: 4.97 s, total: 34.1 s
Wall time: 9.58 s





In [11]:
cols = ['hits@1', 'hits@3', 'hits@5', 'hits@10', 'mrr']
df_before = pd.DataFrame(np.array([allhits1, allhits3, allhits5, allhits10, allmrr]).T, columns=cols, index=query_names) 
df_before * 100

Unnamed: 0,hits@1,hits@3,hits@5,hits@10,mrr
2p,0.024201,0.129982,0.274562,1.051856,0.448053
3p,0.00661,0.030103,0.07629,0.200592,0.213284
2i,0.0,0.016667,0.039863,0.063492,0.082921
3i,0.0,0.00517,0.006346,0.061399,0.060924
pi,0.026984,0.241568,0.530733,1.226756,0.665777
ip,0.069039,0.293503,0.459427,0.929549,0.699226


In [None]:
query_names = []
for query_structure in query_structures:
    print('Running query : {}'.format(query_structure))
    query_name = query_name_dict[query_structure]
    query_names.append(query_name)

    # the len() > 0 part is to determine whether we have an "easy" query
    queries = [q for q in train_queries[query_structure] if len(train_answers[q]) > 0]
    
    params = list(filter(lambda p: p.requires_grad, model.parameters()))
    print([param.shape for param in params])
    optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=1e-4
        )
    
    for qix in tqdm(range(0, num_train, batch_size)):
        qs = queries[qix:qix+batch_size]
        # we have a non-trivial "easy" query
        if len(qs) > 0:
            true_targets = torch.LongTensor([np.random.choice([map_ent(a) for a in train_answers[query]]) for query in qs])
            true_answers = torch.ones(true_targets.shape)
            neg_targets = torch.randint(model.num_entities, true_targets.shape)
            neg_answers = torch.zeros(neg_targets.shape)
            
            true_loss = train_step(model, optimizer, query_name, qs, true_targets, true_answers)
            neg_loss = train_step(model, optimizer, query_name, qs, neg_targets, neg_answers)
                        

  0%|          | 0/2 [00:00<?, ?it/s]

Running query : ('e', ('r', 'r'))
[torch.Size([14505, 64, 1]), torch.Size([237, 64, 64]), torch.Size([237, 64, 64])]


100%|██████████| 2/2 [00:56<00:00, 28.37s/it]
  0%|          | 0/2 [00:00<?, ?it/s]

Running query : ('e', ('r', 'r', 'r'))
[torch.Size([14505, 64, 1]), torch.Size([237, 64, 64]), torch.Size([237, 64, 64])]


 50%|█████     | 1/2 [01:04<01:04, 64.34s/it]

In [None]:
%%time
allhits1 = []
allhits3 = []
allhits5 = []
allhits10 = []
allmrr = []
query_names = []
model.eval()
# target_embeddings = model.ent_embeddings.view(-1, model.embedding_dim, model.num_sections)[:,:,use_section].T
target_embeddings = torch.mean(model.ent_embeddings.view(-1, model.embedding_dim, model.num_sections), 2).T
for query_structure in query_structures:
    print('Running query : {}'.format(query_structure))
    query_name = query_name_dict[query_structure]
    query_names.append(query_name)
    fn = query_name_fn_dict[query_name]
    hits1 = 0.
    hits3 = 0.
    hits5 = 0.
    hits10 = 0.
    mrr = 0.
    cnt = 0
    # the len() > 0 part is to determine whether we have an "easy" query
    queries = [q for q in test_queries[query_structure] if len(test_answers[q]) > 0] 
    for qix in tqdm(range(0, num_test, batch_size)):
        qs = queries[qix:qix+batch_size]
        # we have a non-trivial "easy" query
        if len(qs) > 0:
            all_answers = [[map_ent(a) for a in test_answers[query]] for query in qs]
            L, source_vertices, target_vertices, source_embeddings = fn(qs, model)
            Q = harmonic_extension.compute_costs(L,source_vertices,target_vertices,torch.mean(source_embeddings, -1).view(batch_size, -1),target_embeddings,source_embeddings.shape[1])
            for i in range(len(qs)):
                Qi = Q[i].squeeze()
                answers = all_answers[i]
                sortd,_ = torch.sort(Qi)
                idxleft = torch.searchsorted(sortd, Qi[answers], right=False) + 1
                idxright = torch.searchsorted(sortd, Qi[answers], right=True) + 1
                nl = idxleft.shape[0]
                nr = idxright.shape[0]
                # idxright = idxleft # throw this for optimistic ranking
                hits1 += ((torch.sum(idxleft <= 1)/nl + torch.sum(idxright <= 1)/nr) / 2.)
                hits3 += ((torch.sum(idxleft <= 3)/nl + torch.sum(idxright <= 3)/nr) / 2.)
                hits5 += ((torch.sum(idxleft <= 5)/nl + torch.sum(idxright <= 5)/nr) / 2.)
                hits10 += ((torch.sum(idxleft <= 10)/nl + torch.sum(idxright <= 10)/nr) / 2.)
                mrr += ((torch.sum(1./idxleft)/nl + torch.sum(1./idxright)/nr) / 2.)
                cnt += 1
    if cnt > 0:
        allhits1.append(hits1/cnt)
        allhits3.append(hits3/cnt)
        allhits5.append(hits5/cnt)
        allhits10.append(hits10/cnt)
        allmrr.append(mrr/cnt)
    else:
        default = 0.
        allhits1.append(default)
        allhits3.append(default)
        allhits5.append(default)
        allhits10.append(default)
        allmrr.append(default)

In [None]:
cols = ['hits@1', 'hits@3', 'hits@5', 'hits@10', 'mrr']
df_after = pd.DataFrame(np.array([allhits1, allhits3, allhits5, allhits10, allmrr]).T, columns=cols, index=query_names) 
df_after * 100

In [None]:
(df_after - df_before) * 100