based on main.py

runs Continuous Query Decomposition

ref: https://github.com/pminervini/KGReaso

changed self.embeddings = nn.ModuleList([nn.Embedding(s, 2 * rank, sparse=False) for s in sizes[:2]]) sparse to False in cqd/base.pyning

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import json
import logging
import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader
from models import KGReasoning
# SingledirectionalOneShotIterator provides negative sampling
from dataloader import TestDataset, TrainDataset, SingledirectionalOneShotIterator
# from tensorboardX import SummaryWriter
import time
import pickle
from collections import defaultdict
from tqdm import tqdm
from util import flatten_query, list2tuple, parse_time, set_global_seed, eval_tuple

import collections
import random
from tqdm.notebook import tqdm

import torch.nn.functional as F

from cqd import CQD

# Metadata

In [2]:
query_name_dict = {('e',('r',)): '1p', 
                    ('e', ('r', 'r')): '2p',
                    ('e', ('r', 'r', 'r')): '3p',
                    (('e', ('r',)), ('e', ('r',))): '2i',
                    (('e', ('r',)), ('e', ('r',)), ('e', ('r',))): '3i',
                    ((('e', ('r',)), ('e', ('r',))), ('r',)): 'ip',
                    (('e', ('r', 'r')), ('e', ('r',))): 'pi',
                    (('e', ('r',)), ('e', ('r', 'n'))): '2in',
                    (('e', ('r',)), ('e', ('r',)), ('e', ('r', 'n'))): '3in',
                    ((('e', ('r',)), ('e', ('r', 'n'))), ('r',)): 'inp',
                    (('e', ('r', 'r')), ('e', ('r', 'n'))): 'pin',
                    (('e', ('r', 'r', 'n')), ('e', ('r',))): 'pni',
                    (('e', ('r',)), ('e', ('r',)), ('u',)): '2u-DNF',
                    ((('e', ('r',)), ('e', ('r',)), ('u',)), ('r',)): 'up-DNF',
                    ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n',)): '2u-DM',
                    ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n', 'r')): 'up-DM'
                }
name_query_dict = {value: key for key, value in query_name_dict.items()}
all_tasks = list(name_query_dict.keys()) # ['1p', '2p', '3p', '2i', '3i', 'ip', 'pi', '2in', '3in', 'inp', 'pin', 'pni', '2u-DNF', '2u-DM', 'up-DNF', 'up-DM']

print(all_tasks)

['1p', '2p', '3p', '2i', '3i', 'ip', 'pi', '2in', '3in', 'inp', 'pin', 'pni', '2u-DNF', 'up-DNF', '2u-DM', 'up-DM']


In [3]:
class DummyArgs:
    def __init__(self):
        None
        
args = DummyArgs()
args.cuda = True
args.geo = "cqd" # choices=['vec', 'box', 'beta', 'cqd']
args.gamma = 12.0
args.box_mode = "(none,0.02)" # Query2box
args.beta_mode = "(1600,2)" # BetaE relational projection
args.cqd_type = "discrete" # Continuous Query Decomposition
args.cqd_t_norm = CQD.PROD_NORM
args.cqd_k = 5
args.cqd_sigmoid_scores = False
args.cqd_normalize_scores = False
args.reg_weight = 1e-3
args.use_qa_iterator = True
args.data_path = "data/FB15k-237-betae"
args.batch_size = 64 #1024
args.cpu_num = 1# 10
args.negative_sample_size = 128
args.hidden_dim = 500
args.test_batch_size = 3 #2
args.print_on_screen = True
args.test_log_steps = 1000
args.learning_rate = 0.0001
args.max_steps = 1000 #100000
args.evaluate_union = "DNF" # choices=['DNF', 'DM'] evaluate union querioes, disjunctive normal form (DNF) or de Morgan's laws (DM)

In [4]:
# different query types
tasks = all_tasks[0:1]

if args.geo in ['box']:
    tmp_str = "g-{}-mode-{}".format(args.gamma, args.box_mode)
elif args.geo in ['vec']:
    tmp_str = "g-{}".format(args.gamma)
elif args.geo == 'beta':
    tmp_str = "g-{}-mode-{}".format(args.gamma, args.beta_mode)
elif args.geo == 'cqd':
    tmp_str = "g-cqd"

print(all_tasks)

['1p', '2p', '3p', '2i', '3i', 'ip', 'pi', '2in', '3in', 'inp', 'pin', 'pni', '2u-DNF', 'up-DNF', '2u-DM', 'up-DM']


In [5]:
with open('%s/stats.txt' % args.data_path) as f:
    entrel = f.readlines()
    nentity = int(entrel[0].split(' ')[-1])
    nrelation = int(entrel[1].split(' ')[-1])

args.nentity = nentity
args.nrelation = nrelation

logging.info('-------------------------------'*3)
logging.info('Geo: %s' % args.geo)
logging.info('Data Path: %s' % args.data_path)
logging.info('#entity: %d' % nentity)
logging.info('#relation: %d' % nrelation)
# logging.info('#max steps: %d' % args.max_steps)
# logging.info('Evaluate unoins using: %s' % args.evaluate_union)


# Load Data

In [6]:
def load_data(args, tasks):
    '''
    Load queries and remove queries not in tasks
    '''
    logging.info("loading data")
    train_queries = pickle.load(open(os.path.join(args.data_path, "train-queries.pkl"), 'rb'))
    train_answers = pickle.load(open(os.path.join(args.data_path, "train-answers.pkl"), 'rb'))
    valid_queries = pickle.load(open(os.path.join(args.data_path, "valid-queries.pkl"), 'rb'))
    valid_hard_answers = pickle.load(open(os.path.join(args.data_path, "valid-hard-answers.pkl"), 'rb'))
    valid_easy_answers = pickle.load(open(os.path.join(args.data_path, "valid-easy-answers.pkl"), 'rb'))
    test_queries = pickle.load(open(os.path.join(args.data_path, "test-queries.pkl"), 'rb'))
    test_hard_answers = pickle.load(open(os.path.join(args.data_path, "test-hard-answers.pkl"), 'rb'))
    test_easy_answers = pickle.load(open(os.path.join(args.data_path, "test-easy-answers.pkl"), 'rb'))
    
    # remove tasks not in args.tasks
    for name in all_tasks:
        if 'u' in name:
            name, evaluate_union = name.split('-')
        else:
            evaluate_union = args.evaluate_union
        if name not in tasks or evaluate_union != args.evaluate_union:
            query_structure = name_query_dict[name if 'u' not in name else '-'.join([name, evaluate_union])]
            if query_structure in train_queries:
                del train_queries[query_structure]
            if query_structure in valid_queries:
                del valid_queries[query_structure]
            if query_structure in test_queries:
                del test_queries[query_structure]

    return train_queries, train_answers, valid_queries, valid_hard_answers, valid_easy_answers, test_queries, test_hard_answers, test_easy_answers


In [7]:
"""
LOAD DATA
based on tasks defined in args
"""
train_queries, train_answers, valid_queries, valid_hard_answers, valid_easy_answers, test_queries, test_hard_answers, test_easy_answers = load_data(args, tasks)

In [39]:
## inspect dataset

# len(test_queries[('e', ('r',))]) #, train_answers, valid_queries, valid_hard_answers, valid_easy_answers, test_queries, test_hard_answers, test_easy_answers
# valid_hard_answers.keys()

1. train_queries: for (e, r) given entityID and relationID, find all other entities that are connected to it.
2. train_answers: given (e, r) this is the groundtruth answers of all entityIDs that are connected

In [8]:
[(type(x), len(x)) for x in [train_queries, train_answers, valid_queries, valid_hard_answers, valid_easy_answers, test_queries, test_hard_answers, test_easy_answers]]

[(collections.defaultdict, 1),
 (collections.defaultdict, 1496890),
 (collections.defaultdict, 1),
 (collections.defaultdict, 95094),
 (collections.defaultdict, 95094),
 (collections.defaultdict, 1),
 (collections.defaultdict, 97804),
 (collections.defaultdict, 97804)]

In [9]:
ex_val_q = (6288, (8,))
print("*" * 10)
print(ex_val_q)
print(ex_val_q in valid_queries[('e', ('r', ))])
print("valid_hard_answers", valid_hard_answers[ex_val_q])
print("valid_easy_answers", valid_easy_answers[ex_val_q])

# get example query
ex_q = ex_val_q #(6288, (8,))
print("*" * 10)
print(ex_q)
print(ex_q in train_queries[('e', ('r', ))])
print("train_answers", train_answers[ex_q])

ex_test_q = ex_val_q
print("*" * 10)
print(ex_test_q)
print(ex_test_q in test_queries[('e', ('r', ))])
print("test_hard_answers", test_hard_answers[ex_val_q])
print("test_easy_answers", test_easy_answers[ex_val_q])

**********
(6288, (8,))
False
valid_hard_answers set()
valid_easy_answers set()
**********
(6288, (8,))
True
train_answers {9, 11, 117, 399}
**********
(6288, (8,))
False
test_hard_answers set()
test_easy_answers set()


In [10]:
"""
TRAINING Data Preparation
"""
train_path_queries = defaultdict(set)
for query_structure in train_queries:
#     train_path_queries[query_structure] = set(random.sample(train_queries[query_structure], 1))
    train_path_queries[query_structure] = train_queries[query_structure]
            
train_path_queries = flatten_query(train_path_queries)
print("number of train_queries", len(train_path_queries))
train_path_iterator = SingledirectionalOneShotIterator(DataLoader(
                            # {(e,r): (entityID, relID)} and {(e, r): [entityID]}
                            TrainDataset(train_path_queries, nentity, nrelation, args.negative_sample_size, train_answers),
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.cpu_num,
                            collate_fn=TrainDataset.collate_fn
                        ))


number of train_queries 149689


In [11]:
"""
VALIDATION Data Preparation
"""
# for each type of structure / key
for query_structure in valid_queries:
    try:
        print(query_name_dict[query_structure[1]]+": "+str(len(valid_queries[query_structure])))
    except:
        print("error", query_structure)

valid_dataloader = DataLoader(
    TestDataset(
        flatten_query(valid_queries), 
        args.nentity, 
        args.nrelation, 
    ), 
    batch_size=args.test_batch_size,
    num_workers=args.cpu_num, 
    collate_fn=TestDataset.collate_fn
)


error ('e', ('r',))


In [12]:
model = CQD(nentity,
    nrelation,
    rank=args.hidden_dim,
    test_batch_size=args.test_batch_size,
    reg_weight=args.reg_weight,
    query_name_dict=query_name_dict,
    method=args.cqd_type,
    t_norm_name=args.cqd_t_norm,
    k=args.cqd_k,
    do_sigmoid=args.cqd_sigmoid_scores,
    do_normalize=args.cqd_normalize_scores,
    use_cuda=args.cuda)

logging.info('Model Parameter Configuration:')
num_params = 0
for name, param in model.named_parameters():
    print('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad)))
    if param.requires_grad:
        num_params += np.prod(param.size())
print('Parameter Number: %d' % num_params)

model = model.cuda()

current_learning_rate = args.learning_rate
optimizer = torch.optim.Adam(
# optimizer = torch.optim.SparseAdam(
# optimizer = torch.optim.RMSprop(
# optimizer = torch.optim.Adagrad(
    filter(lambda p: p.requires_grad, model.parameters()), 
    lr=current_learning_rate
)
warm_up_steps = args.max_steps // 2
_steps = args.max_steps // 2


Parameter embeddings.0.weight: torch.Size([14505, 1000]), require_grad = True
Parameter embeddings.1.weight: torch.Size([474, 1000]), require_grad = True
Parameter Number: 14979000


# Training

In [13]:
import wandb
import datetime

timenow = datetime.datetime.utcnow().isoformat()

wandb.init(project="KGReasoning", name=timenow, group="cqd")



wandb: Currently logged in as: eekosasih (use `wandb login --relogin` to force relogin)


In [19]:
args.max_steps = 10000

In [20]:
"""
TRAIN LOOP
"""
init_step = 0

training_logs = []
# loop over all batches
# initialising train_path_iterator is expensive. So at certain loop it might take some time to load
for step in tqdm(range(init_step, args.max_steps)):
    # train model for a step over all batches
    log = KGReasoning.train_step(model, optimizer, train_path_iterator, args, step)
    training_logs.append(log)

    # raise Exception("")

    log["epoch"] = step
    wandb.log(log, step=step)

KeyboardInterrupt: 

In [21]:
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
positive_sample_loss,16.42752
negative_sample_loss,16.42752
loss,16.42752
epoch,6778.0
_runtime,743.0
_timestamp,1631542849.0
_step,6778.0


0,1
positive_sample_loss,████████▇▇▇▇▇▇▆▇▆▇▆▅▆▆▄▅▄▄▅▄▄▄▃▃▂▂▂▃▁▃▁▁
negative_sample_loss,████████▇▇▇▇▇▇▆▇▆▇▆▅▆▆▄▅▄▄▅▄▄▄▃▃▂▂▂▃▁▃▁▁
loss,████████▇▇▇▇▇▇▆▇▆▇▆▅▆▆▄▅▄▄▅▄▄▄▃▃▂▂▂▃▁▃▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▁▁▃▃▄▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇████
_timestamp,▁▁▁▁▁▁▃▃▄▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇████
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


In [25]:
# save the trained model
torch.save(
    {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    },
    os.path.join("checkpoint", "model.pt")
)

## DEBUG TRAINING STEP

In [13]:
step = 0

model, optimizer, train_iterator, args, step = model, optimizer, train_path_iterator, args, step

# there's an overhead operation to shufflte the train_iterator that makes this expensive to run
positive_sample, negative_sample, subsampling_weight, batch_queries, query_structures = next(train_iterator)

In [17]:
# step = 0

# KGReasoning.train_step(model, optimizer, train_other_iterator, args, step)

model, optimizer, train_iterator, args, step = model, optimizer, train_path_iterator, args, step

model.train()
optimizer.zero_grad()

# # there's an overhead operation to shufflte the train_iterator that makes this expensive to run
# positive_sample, negative_sample, subsampling_weight, batch_queries, query_structures = next(train_iterator)

# group queries into batch
batch_queries_dict = collections.defaultdict(list)
batch_idxs_dict = collections.defaultdict(list)
for i, query in enumerate(batch_queries): # group queries with same structure
    batch_queries_dict[query_structures[i]].append(query)
    batch_idxs_dict[query_structures[i]].append(i)

for query_structure in batch_queries_dict:
    if args.cuda:
        batch_queries_dict[query_structure] = torch.LongTensor(batch_queries_dict[query_structure]).cuda()
    else:
        batch_queries_dict[query_structure] = torch.LongTensor(batch_queries_dict[query_structure])

if args.cuda:
    positive_sample = positive_sample.cuda()
    negative_sample = negative_sample.cuda()
    subsampling_weight = subsampling_weight.cuda()

if isinstance(model, CQD):
    input_batch = batch_queries_dict[('e', ('r',))]
    input_batch = torch.cat((input_batch, positive_sample.unsqueeze(1)), dim=1)
    
    model = model.cpu()
    input_batch = input_batch.cpu()
    
    loss = model.loss(input_batch)

    positive_sample_loss = negative_sample_loss = loss

    
# # score positive and negative samples
# positive_logit, negative_logit, subsampling_weight, _ = model(positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict)

# # calculate loss
# # positive samples should have label of 1 while negative samples label of 0
# negative_score = F.logsigmoid(-negative_logit).mean(dim=1)
# positive_score = F.logsigmoid(positive_logit).mean(dim=1)
# # aggregate loss with subsampling_weight
# positive_sample_loss = - (subsampling_weight * positive_score).sum()
# negative_sample_loss = - (subsampling_weight * negative_score).sum()
# positive_sample_loss /= subsampling_weight.sum()
# negative_sample_loss /= subsampling_weight.sum()
# loss = (positive_sample_loss + negative_sample_loss)/2

loss.backward()
optimizer.step()

log = {
    'positive_sample_loss': positive_sample_loss.item(),
    'negative_sample_loss': negative_sample_loss.item(),
    'loss': loss.item(),
}
log

{'positive_sample_loss': 19.164501190185547,
 'negative_sample_loss': 19.164501190185547,
 'loss': 19.164501190185547}

In [19]:
""" track loss = model.loss(input_batch) """
triples = input_batch
(scores_o, scores_s), factors = model.score_candidates(triples)
l_fit = model.loss_fn(scores_o, triples[:, 2]) + model.loss_fn(scores_s, triples[:, 0])
l_reg = model.regularizer.forward(factors)

# l_fit.backward()
l_reg.backward()
optimizer.step()

l_fit

tensor(19.1645, grad_fn=<AddBackward0>)

In [53]:
model

CQD(
  (embeddings): ModuleList(
    (0): Embedding(14505, 1000, sparse=True)
    (1): Embedding(474, 1000, sparse=True)
  )
  (loss_fn): CrossEntropyLoss()
)

In [None]:
print("=== input ===")
print(positive_sample.shape, negative_sample.shape, len(batch_queries), len(query_structures), subsampling_weight.shape)
print("=== output ===")
print(positive_logit.shape, negative_logit.shape)

=== input ===
torch.Size([64]) torch.Size([64, 128]) 64 64 torch.Size([64])
=== output ===
torch.Size([64, 1]) torch.Size([64, 128])


# Testing

In [None]:
# load trained model
checkpoint = torch.load(os.path.join("checkpoint", "model.pt"))
model.load_state_dict(checkpoint["model_state_dict"])

<All keys matched successfully>

In [None]:
def evaluate(model, tp_answers, fn_answers, args, dataloader, query_name_dict, mode, step):
    '''
    Evaluate queries in dataloader
    '''
    average_metrics = defaultdict(float)
    all_metrics = defaultdict(float)

    metrics = model.test_step(model, tp_answers, fn_answers, args, dataloader, query_name_dict)
    num_query_structures = 0
    num_queries = 0
    for query_structure in metrics:
        log_metrics(mode+" "+query_name_dict[query_structure], step, metrics[query_structure])
        for metric in metrics[query_structure]:
            all_metrics["_".join([query_name_dict[query_structure], metric])] = metrics[query_structure][metric]
            if metric != 'num_queries':
                average_metrics[metric] += metrics[query_structure][metric]
        num_queries += metrics[query_structure]['num_queries']
        num_query_structures += 1

    for metric in average_metrics:
        average_metrics[metric] /= num_query_structures
        all_metrics["_".join(["average", metric])] = average_metrics[metric]
    log_metrics('%s average'%mode, step, average_metrics)

    return all_metrics


In [None]:
def log_metrics(mode, step, metrics):
    '''
    Print the evaluation logs
    '''
    for metric in metrics:
        logging.info('%s %s at step %d: %f' % (mode, metric, step, metrics[metric]))


In [None]:
"""`
EVAL
"""
step = -1
valid_all_metrics = evaluate(model, valid_easy_answers, valid_hard_answers, args, valid_dataloader, query_name_dict, 'Valid', step)
valid_all_metrics

# after training the model

100%|██████████| 6698/6698 [01:06<00:00, 100.05it/s]


defaultdict(float,
            {'1p_MRR': 0.366261465070138,
             '1p_HITS1': 0.2592751371505161,
             '1p_HITS3': 0.41276127284368985,
             '1p_HITS10': 0.5799358917393381,
             '1p_num_queries': 20094,
             'average_MRR': 0.366261465070138,
             'average_HITS1': 0.2592751371505161,
             'average_HITS3': 0.41276127284368985,
             'average_HITS10': 0.5799358917393381})

## DEBUG TEST STEP

In [16]:
"""def evaluate in main.py"""

metrics = model.test_step(model, valid_easy_answers, valid_hard_answers, args, valid_dataloader, query_name_dict)
metrics

100%|██████████| 6698/6698 [01:07<00:00, 99.49it/s] 


defaultdict(<function models.KGReasoning.test_step.<locals>.<lambda>()>,
            {('e', ('r',)): defaultdict(int,
                         {'MRR': 0.009518244747182089,
                          'HITS1': 0.008909909141400957,
                          'HITS3': 0.009125562239005588,
                          'HITS10': 0.009491169651853655,
                          'num_queries': 20094})})

In [17]:
for negative_sample, queries, queries_unflatten, query_structures in tqdm(valid_dataloader, disable=not args.print_on_screen):
    break
    
negative_sample, queries, 

  0%|          | 0/6698 [00:00<?, ?it/s]

(tensor([[    0,     1,     2,  ..., 14502, 14503, 14504],
         [    0,     1,     2,  ..., 14502, 14503, 14504],
         [    0,     1,     2,  ..., 14502, 14503, 14504]]),
 [[11217, 143], [5673, 34], [4190, 62]])

1. why is there no positive_sample in test_dataloader
2. how to embed query vector and perform projection & intersection

In [29]:
negative_sample

tensor([[    0,     1,     2,  ..., 14502, 14503, 14504],
        [    0,     1,     2,  ..., 14502, 14503, 14504],
        [    0,     1,     2,  ..., 14502, 14503, 14504]], device='cuda:0')

In [18]:
"""model.test_step in models.py"""
model, easy_answers, hard_answers, args, test_dataloader, query_name_dict = model, valid_easy_answers, valid_hard_answers, args, valid_dataloader, query_name_dict

model.eval()

step = 0
total_steps = len(test_dataloader)
logs = collections.defaultdict(list)

with torch.no_grad():
    # process per batch
    # loop through every validation dataset
    # negative sample is basically all entities in nentity
    for negative_sample, queries, queries_unflatten, query_structures in tqdm(test_dataloader, disable=not args.print_on_screen):
        batch_queries_dict = collections.defaultdict(list)
        batch_idxs_dict = collections.defaultdict(list)
        
        # for each test query
        for i, query in enumerate(queries):
            batch_queries_dict[query_structures[i]].append(query)
            batch_idxs_dict[query_structures[i]].append(i)
            
        # convert each positive query to a LongTensor
        for query_structure in batch_queries_dict:
            if args.cuda:
                batch_queries_dict[query_structure] = torch.LongTensor(batch_queries_dict[query_structure]).cuda()
            else:
                batch_queries_dict[query_structure] = torch.LongTensor(batch_queries_dict[query_structure])
            
        # convert each negative query to a LongTensor
        if args.cuda:
            negative_sample = negative_sample.cuda()
            
        # get prediction output
        # (2, number_of_edges) same dimension with negative_sample
        # this is basically embedding lookup
        """model.forward"""
        # negative_sample here is just a list of all entities, some of them are actually positive
        _, negative_logit, _, idxs = model(None, negative_sample, None, batch_queries_dict, batch_idxs_dict)
          
        # ... scoring ...
        queries_unflatten = [queries_unflatten[i] for i in idxs]
        query_structures = [query_structures[i] for i in idxs]
        # sort from maximum value based on logit values
        argsort = torch.argsort(negative_logit, dim=1, descending=True)
        ranking = argsort.clone().to(torch.float)
        if len(argsort) == args.test_batch_size: # if it is the same shape with test_batch_size, we can reuse batch_entity_range without creating a new one
            ranking = ranking.scatter_(1, argsort, model.batch_entity_range) # achieve the ranking of all entities
        else: # otherwise, create a new torch Tensor for batch_entity_range
            if args.cuda:
                ranking = ranking.scatter_(1, argsort, torch.arange(model.nentity).to(torch.float).repeat(argsort.shape[0], 1).cuda()) # achieve the ranking of all entities
            else:
                ranking = ranking.scatter_(1, argsort, torch.arange(model.nentity).to(torch.float).repeat(argsort.shape[0], 1)) # achieve the ranking of all entities
        
        # loop through ranking
        # score every query in the dataloader batch
        for idx, (i, query, query_structure) in enumerate(zip(argsort[:, 0], queries_unflatten, query_structures)):
            # get groundtruth labels
            hard_answer = hard_answers[query]
            easy_answer = easy_answers[query]
            num_hard = len(hard_answer)
            num_easy = len(easy_answer)
            assert len(hard_answer.intersection(easy_answer)) == 0

            # compare ranking of groundtruth (easy_answer, hard_answer) and predicted results
            cur_ranking = ranking[idx, list(easy_answer) + list(hard_answer)]
            cur_ranking, indices = torch.sort(cur_ranking)
            masks = indices >= num_easy
            if args.cuda:
                answer_list = torch.arange(num_hard + num_easy).to(torch.float).cuda()
            else:
                answer_list = torch.arange(num_hard + num_easy).to(torch.float)
            cur_ranking = cur_ranking - answer_list + 1 # filtered setting
            cur_ranking = cur_ranking[masks] # only take indices that belong to the hard answers

            mrr = torch.mean(1./cur_ranking).item()
            h1 = torch.mean((cur_ranking <= 1).to(torch.float)).item()
            h3 = torch.mean((cur_ranking <= 3).to(torch.float)).item()
            h10 = torch.mean((cur_ranking <= 10).to(torch.float)).item()

            logs[query_structure].append({
                'MRR': mrr,
                'HITS1': h1,
                'HITS3': h3,
                'HITS10': h10,
                'num_hard_answer': num_hard,
            })

metrics = collections.defaultdict(lambda: collections.defaultdict(int))
for query_structure in logs:
    for metric in logs[query_structure][0].keys():
        if metric in ['num_hard_answer']:
            continue
        metrics[query_structure][metric] = sum([log[metric] for log in logs[query_structure]])/len(logs[query_structure])
    metrics[query_structure]['num_queries'] = len(logs[query_structure])

metrics

  0%|          | 0/6698 [00:00<?, ?it/s]

defaultdict(<function __main__.<lambda>()>,
            {('e', ('r',)): defaultdict(int,
                         {'MRR': 0.009518244747182089,
                          'HITS1': 0.008909909141400957,
                          'HITS3': 0.009125562239005588,
                          'HITS10': 0.009491169651853655,
                          'num_queries': 20094})})

In [41]:
"""model.forward"""
# model.forward(self, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict)

# model.geo == "vec"

# done per batch
"""model.forward_vec"""
# model.forward_vec(self, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict)
# negative_sample is a list of entities to be scored as potential answers to the queries
# batch_queries_dict stores list of queries (in this case (e, r) to be scored)
positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict = None, negative_sample, None, batch_queries_dict, batch_idxs_dict

all_center_embeddings, all_idxs = [], []
all_union_center_embeddings, all_union_idxs = [], []
for query_structure in batch_queries_dict:
    if 'u' in model.query_name_dict[query_structure]:
        # embedding 
        center_embedding, _ = model.embed_query_vec(model.transform_union_query(batch_queries_dict[query_structure], 
                                                            query_structure), 
                                                        model.transform_union_structure(query_structure), 0)
        all_union_center_embeddings.append(center_embedding)
        all_union_idxs.extend(batch_idxs_dict[query_structure])
    else:
        # get vector embedding for each query (anchor + projection + intersection)
        # this will be compared with the groundtruth answer entity embedding
        center_embedding, _ = model.embed_query_vec(batch_queries_dict[query_structure], query_structure, 0)
        all_center_embeddings.append(center_embedding)
        all_idxs.extend(batch_idxs_dict[query_structure])

if len(all_center_embeddings) > 0:
    all_center_embeddings = torch.cat(all_center_embeddings, dim=0).unsqueeze(1)
if len(all_union_center_embeddings) > 0:
    all_union_center_embeddings = torch.cat(all_union_center_embeddings, dim=0).unsqueeze(1)
    all_union_center_embeddings = all_union_center_embeddings.view(all_union_center_embeddings.shape[0]//2, 2, 1, -1)

if type(subsampling_weight) != type(None):
    subsampling_weight = subsampling_weight[all_idxs+all_union_idxs]

# in test, positive samples are not given
if type(positive_sample) != type(None):
    if len(all_center_embeddings) > 0:
        positive_sample_regular = positive_sample[all_idxs]
        positive_embedding = torch.index_select(model.entity_embedding, dim=0, index=positive_sample_regular).unsqueeze(1)
        positive_logit = model.cal_logit_vec(positive_embedding, all_center_embeddings)
    else:
        positive_logit = torch.Tensor([]).to(model.entity_embedding.device)

    if len(all_union_center_embeddings) > 0:
        positive_sample_union = positive_sample[all_union_idxs]
        positive_embedding = torch.index_select(model.entity_embedding, dim=0, index=positive_sample_union).unsqueeze(1).unsqueeze(1)
        positive_union_logit = model.cal_logit_vec(positive_embedding, all_union_center_embeddings)
        positive_union_logit = torch.max(positive_union_logit, dim=1)[0]
    else:
        positive_union_logit = torch.Tensor([]).to(model.entity_embedding.device)
    positive_logit = torch.cat([positive_logit, positive_union_logit], dim=0)
else:
    positive_logit = None

# in test, negative samples are basically all possible list of entities
if type(negative_sample) != type(None):
    if len(all_center_embeddings) > 0:
        negative_sample_regular = negative_sample[all_idxs]
        batch_size, negative_size = negative_sample_regular.shape
        negative_embedding = torch.index_select(model.entity_embedding, dim=0, index=negative_sample_regular.view(-1)).view(batch_size, negative_size, -1)
        negative_logit = model.cal_logit_vec(negative_embedding, all_center_embeddings)
    else:
        negative_logit = torch.Tensor([]).to(model.entity_embedding.device)

    if len(all_union_center_embeddings) > 0:
        negative_sample_union = negative_sample[all_union_idxs]
        batch_size, negative_size = negative_sample_union.shape
        negative_embedding = torch.index_select(model.entity_embedding, dim=0, index=negative_sample_union.view(-1)).view(batch_size, 1, negative_size, -1)
        negative_union_logit = model.cal_logit_vec(negative_embedding, all_union_center_embeddings)
        negative_union_logit = torch.max(negative_union_logit, dim=1)[0]
    else:
        negative_union_logit = torch.Tensor([]).to(model.entity_embedding.device)
    negative_logit = torch.cat([negative_logit, negative_union_logit], dim=0)
else:
    negative_logit = None

positive_logit, negative_logit, subsampling_weight, all_idxs+all_union_idxs

(None,
 tensor([[0.8224, 0.7421, 0.8782,  ..., 0.5703, 0.7527, 0.5376],
         [0.5680, 0.4127, 0.2810,  ..., 0.4131, 0.6748, 0.3881],
         [0.6553, 1.1643, 1.2482,  ..., 0.8182, 1.0112, 1.0331]],
        device='cuda:0', grad_fn=<CatBackward>),
 None,
 [0, 1, 2])

In [79]:
"""model.embed_query_vec in models.py"""
# get the query embedding
# input is (batch_size, query_type_length=2), query_structure=(e, r)
# model.embed_query_vec(batch_queries_dict[query_structure], query_structure, 0)

queries, query_structure, idx = batch_queries_dict[query_structure], query_structure, 0

all_relation_flag = True
for ele in query_structure[-1]:
    if ele not in ['r', 'n']:
        all_relation_flag = False
        break
if all_relation_flag:
    # if anchor
    if query_structure[0] == "e":
        # get entity embedding
        embedding = torch.index_select(model.entity_embedding, dim=0, index=queries[:, idx])
        idx += 1
    else:
        # recursion
        embedding, idx = model.embed_query_vec(queries, query_structure[0], idx)
    for i in range(len(query_structure[-1])):
        if query_structure[-1][i] == "n":
            assert False, "vec cannot handle queries with negation"
        else:
            # get relation embedding
            r_embedding = torch.index_select(model.relation_embedding, dim=0, index=queries[:, idx])
            # add entity and relation embedding
            embedding += r_embedding
        idx += 1
else:
    embedding_list = []
    for i in range(len(query_structure)):
        embedding, idx = model.embed_query_vec(queries, query_structure[i], idx)
        embedding_list.append(embedding)
    embedding = model.center_net(torch.stack(embedding_list))
    
embedding, idx

(tensor([[ 0.0005,  0.0071, -0.0040,  ..., -0.0015,  0.0015, -0.0132],
         [-0.0148,  0.0471,  0.0523,  ...,  0.0172, -0.0143,  0.0200],
         [ 0.0341, -0.0213, -0.0102,  ..., -0.0171,  0.0311, -0.0060]],
        device='cuda:0', grad_fn=<AddBackward0>),
 2)

In [65]:
""" model.cal_logit_vec """
# score / calculate distance between candidate embedding and query embedding
# negative 
print(negative_embedding.shape, all_center_embeddings.shape)
# model.cal_logit_vec(negative_embedding, all_center_embeddings)

#     def cal_logit_vec(self, entity_embedding, query_embedding):
#         distance = entity_embedding - query_embedding
#         logit = self.gamma - torch.norm(distance, p=1, dim=-1)
#         return logit

entity_embedding, query_embedding = negative_embedding, all_center_embeddings
distance = entity_embedding - query_embedding
logit = model.gamma - torch.norm(distance, p=1, dim=-1)

distance.shape, logit.shape

torch.Size([3, 14505, 500]) torch.Size([3, 1, 500])


(torch.Size([3, 14505, 500]), torch.Size([3, 14505]))

In [69]:
logit.argmax(axis=1)

tensor([1718, 4586, 8803], device='cuda:0')