based on main.py

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import json
import logging
import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader
from models import KGReasoning
from dataloader import TestDataset, TrainDataset, SingledirectionalOneShotIterator
# from tensorboardX import SummaryWriter
import time
import pickle
from collections import defaultdict
from tqdm import tqdm
from util import flatten_query, list2tuple, parse_time, set_global_seed, eval_tuple

import collections
import random
from tqdm.notebook import tqdm

import torch.nn.functional as F

# Metadata

In [2]:
query_name_dict = {('e',('r',)): '1p', 
                    ('e', ('r', 'r')): '2p',
                    ('e', ('r', 'r', 'r')): '3p',
                    (('e', ('r',)), ('e', ('r',))): '2i',
                    (('e', ('r',)), ('e', ('r',)), ('e', ('r',))): '3i',
                    ((('e', ('r',)), ('e', ('r',))), ('r',)): 'ip',
                    (('e', ('r', 'r')), ('e', ('r',))): 'pi',
                    (('e', ('r',)), ('e', ('r', 'n'))): '2in',
                    (('e', ('r',)), ('e', ('r',)), ('e', ('r', 'n'))): '3in',
                    ((('e', ('r',)), ('e', ('r', 'n'))), ('r',)): 'inp',
                    (('e', ('r', 'r')), ('e', ('r', 'n'))): 'pin',
                    (('e', ('r', 'r', 'n')), ('e', ('r',))): 'pni',
                    (('e', ('r',)), ('e', ('r',)), ('u',)): '2u-DNF',
                    ((('e', ('r',)), ('e', ('r',)), ('u',)), ('r',)): 'up-DNF',
                    ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n',)): '2u-DM',
                    ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n', 'r')): 'up-DM'
                }
name_query_dict = {value: key for key, value in query_name_dict.items()}
all_tasks = list(name_query_dict.keys()) # ['1p', '2p', '3p', '2i', '3i', 'ip', 'pi', '2in', '3in', 'inp', 'pin', 'pni', '2u-DNF', '2u-DM', 'up-DNF', 'up-DM']

print(all_tasks)

['1p', '2p', '3p', '2i', '3i', 'ip', 'pi', '2in', '3in', 'inp', 'pin', 'pni', '2u-DNF', 'up-DNF', '2u-DM', 'up-DM']


In [3]:
class DummyArgs:
    def __init__(self):
        None
        
args = DummyArgs()
args.cuda = True
args.geo = "beta" # choices=['vec', 'box', 'beta']
args.gamma = 12.0
args.box_mode = "(none,0.02)" # Query2box
args.beta_mode = "(1600,2)" # BetaE relational projection
args.data_path = "data/FB15k-237-betae"
args.batch_size = 64 #1024
args.cpu_num = 1# 10
args.negative_sample_size = 128
args.hidden_dim = 500
args.test_batch_size = 3 #2
args.print_on_screen = True
args.test_log_steps = 1000
args.learning_rate = 0.0001
args.max_steps = 100000
args.evaluate_union = "DNF" # choices=['DNF', 'DM'] evaluate union querioes, disjunctive normal form (DNF) or de Morgan's laws (DM)

In [4]:
# different query types
tasks = all_tasks[10:11]

if args.geo in ['box']:
    tmp_str = "g-{}-mode-{}".format(args.gamma, args.box_mode)
elif args.geo in ['vec']:
    tmp_str = "g-{}".format(args.gamma)
elif args.geo == 'beta':
    tmp_str = "g-{}-mode-{}".format(args.gamma, args.beta_mode)

print(tasks, all_tasks)

['pin'] ['1p', '2p', '3p', '2i', '3i', 'ip', 'pi', '2in', '3in', 'inp', 'pin', 'pni', '2u-DNF', 'up-DNF', '2u-DM', 'up-DM']


In [5]:
with open('%s/stats.txt' % args.data_path) as f:
    entrel = f.readlines()
    nentity = int(entrel[0].split(' ')[-1])
    nrelation = int(entrel[1].split(' ')[-1])

args.nentity = nentity
args.nrelation = nrelation

logging.info('-------------------------------'*3)
logging.info('Geo: %s' % args.geo)
logging.info('Data Path: %s' % args.data_path)
logging.info('#entity: %d' % nentity)
logging.info('#relation: %d' % nrelation)
# logging.info('#max steps: %d' % args.max_steps)
# logging.info('Evaluate unoins using: %s' % args.evaluate_union)


# Load Data

In [6]:
def load_data(args, tasks):
    '''
    Load queries and remove queries not in tasks
    '''
    logging.info("loading data")
    train_queries = pickle.load(open(os.path.join(args.data_path, "train-queries.pkl"), 'rb'))
    train_answers = pickle.load(open(os.path.join(args.data_path, "train-answers.pkl"), 'rb'))
    valid_queries = pickle.load(open(os.path.join(args.data_path, "valid-queries.pkl"), 'rb'))
    valid_hard_answers = pickle.load(open(os.path.join(args.data_path, "valid-hard-answers.pkl"), 'rb'))
    valid_easy_answers = pickle.load(open(os.path.join(args.data_path, "valid-easy-answers.pkl"), 'rb'))
    test_queries = pickle.load(open(os.path.join(args.data_path, "test-queries.pkl"), 'rb'))
    test_hard_answers = pickle.load(open(os.path.join(args.data_path, "test-hard-answers.pkl"), 'rb'))
    test_easy_answers = pickle.load(open(os.path.join(args.data_path, "test-easy-answers.pkl"), 'rb'))
    
    # remove tasks not in args.tasks
    for name in all_tasks:
        if 'u' in name:
            name, evaluate_union = name.split('-')
        else:
            evaluate_union = args.evaluate_union
        if name not in tasks or evaluate_union != args.evaluate_union:
            query_structure = name_query_dict[name if 'u' not in name else '-'.join([name, evaluate_union])]
            if query_structure in train_queries:
                del train_queries[query_structure]
            if query_structure in valid_queries:
                del valid_queries[query_structure]
            if query_structure in test_queries:
                del test_queries[query_structure]

    return train_queries, train_answers, valid_queries, valid_hard_answers, valid_easy_answers, test_queries, test_hard_answers, test_easy_answers


In [7]:
"""
LOAD DATA
based on tasks defined in args
"""
train_queries, train_answers, valid_queries, valid_hard_answers, valid_easy_answers, test_queries, test_hard_answers, test_easy_answers = load_data(args, tasks)

1. train_queries: for (e, r) given entityID and relationID, find all other entities that are connected to it.
2. train_answers: given (e, r) this is the groundtruth answers of all entityIDs that are connected

In [8]:
[(type(x), len(x)) for x in [train_queries, train_answers, valid_queries, valid_hard_answers, valid_easy_answers, test_queries, test_hard_answers, test_easy_answers]]

[(collections.defaultdict, 1),
 (collections.defaultdict, 1496890),
 (collections.defaultdict, 1),
 (collections.defaultdict, 95094),
 (collections.defaultdict, 95094),
 (collections.defaultdict, 1),
 (collections.defaultdict, 97804),
 (collections.defaultdict, 97804)]

In [9]:
queryType = (('e', ('r', 'r')), ('e', ('r', 'n')))

ex_val_q = random.sample(valid_queries[queryType], 1)[0]
print("*" * 10)
print(ex_val_q)
print(ex_val_q in valid_queries[queryType])
print("valid_hard_answers", valid_hard_answers[ex_val_q])
print("valid_easy_answers", valid_easy_answers[ex_val_q])

# get example query
ex_q = ex_val_q #(6288, (8,))
print("*" * 10)
print(ex_q)
print(ex_q in train_queries[queryType])
print("train_answers", train_answers[ex_q])

ex_test_q = ex_val_q
print("*" * 10)
print(ex_test_q)
print(ex_test_q in test_queries[queryType])
print("test_hard_answers", test_hard_answers[ex_val_q])
print("test_easy_answers", test_easy_answers[ex_val_q])

**********
((6099, (1, 97)), (32, (97, -2)))
True
valid_hard_answers {2137, 4834, 547, 605, 2213, 11847, 12722, 10098, 1332, 2418, 12246, 4377, 8892, 10269, 350}
valid_easy_answers {6532, 9989, 2951, 7053, 9872, 10385, 10899, 1943, 12951, 6936, 4122, 9628, 12573, 6429, 7330, 7459, 7204, 7848, 11947, 10412, 6829, 7981, 9904, 10672, 10290, 13618, 14000, 4668, 11198, 1982, 1727, 3903, 9282, 3906, 13774, 8407, 6873, 346, 6112, 3172, 8805, 8939, 5355, 11374, 9457, 5876, 5493, 12791, 1656, 10105, 5630}
**********
((6099, (1, 97)), (32, (97, -2)))
False
train_answers {6532, 9989, 2951, 7053, 9872, 10385, 10899, 12951, 1943, 6936, 4122, 9628, 6429, 12573, 7330, 7459, 7204, 7848, 11947, 10412, 6829, 7981, 9904, 10672, 10290, 13618, 14000, 4668, 11198, 1982, 1727, 3903, 9282, 3906, 13774, 8407, 6873, 346, 6112, 3172, 8805, 8939, 5355, 11374, 9457, 5876, 5493, 12791, 1656, 10105, 5630}
**********
((6099, (1, 97)), (32, (97, -2)))
False
test_hard_answers set()
test_easy_answers set()


In [10]:
"""
TRAINING Data Preparation
"""
train_path_queries = defaultdict(set)
for query_structure in train_queries:
#     train_path_queries[query_structure] = set(random.sample(train_queries[query_structure], 1))
    train_path_queries[query_structure] = train_queries[query_structure]
            
train_path_queries = flatten_query(train_path_queries)
print("number of train_queries", len(train_path_queries))
train_path_iterator = SingledirectionalOneShotIterator(DataLoader(
                            # {(e,r): (entityID, relID)} and {(e, r): [entityID]}
                            TrainDataset(train_path_queries, nentity, nrelation, args.negative_sample_size, train_answers),
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.cpu_num,
                            collate_fn=TrainDataset.collate_fn
                        ))


number of train_queries 14968


In [11]:
"""
VALIDATION Data Preparation
"""
# for each type of structure / key
for query_structure in valid_queries:
    try:
        print(query_name_dict[query_structure[1]]+": "+str(len(valid_queries[query_structure])))
    except:
        print("error", query_structure)

valid_dataloader = DataLoader(
    TestDataset(
        flatten_query(valid_queries), 
        args.nentity, 
        args.nrelation, 
    ), 
    batch_size=args.test_batch_size,
    num_workers=args.cpu_num, 
    collate_fn=TestDataset.collate_fn
)


error (('e', ('r', 'r')), ('e', ('r', 'n')))


In [12]:
"""
MODEL DEFINITION
"""
model = KGReasoning(
    nentity=nentity,
    nrelation=nrelation,
    hidden_dim=args.hidden_dim,
    gamma=args.gamma,
    geo=args.geo,
    use_cuda = args.cuda,
    box_mode=eval_tuple(args.box_mode),
    beta_mode = eval_tuple(args.beta_mode),
    test_batch_size=args.test_batch_size,
    query_name_dict = query_name_dict
)

logging.info('Model Parameter Configuration:')
num_params = 0
for name, param in model.named_parameters():
    print('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad)))
    if param.requires_grad:
        num_params += np.prod(param.size())
print('Parameter Number: %d' % num_params)

model = model.cuda()

current_learning_rate = args.learning_rate
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), 
    lr=current_learning_rate
)
warm_up_steps = args.max_steps // 2


Parameter gamma: torch.Size([1]), require_grad = False
Parameter embedding_range: torch.Size([1]), require_grad = False
Parameter entity_embedding: torch.Size([14505, 1000]), require_grad = True
Parameter relation_embedding: torch.Size([474, 500]), require_grad = True
Parameter center_net.layer1.weight: torch.Size([1000, 1000]), require_grad = True
Parameter center_net.layer1.bias: torch.Size([1000]), require_grad = True
Parameter center_net.layer2.weight: torch.Size([500, 1000]), require_grad = True
Parameter center_net.layer2.bias: torch.Size([500]), require_grad = True
Parameter projection_net.layer1.weight: torch.Size([1600, 1500]), require_grad = True
Parameter projection_net.layer1.bias: torch.Size([1600]), require_grad = True
Parameter projection_net.layer0.weight: torch.Size([1000, 1600]), require_grad = True
Parameter projection_net.layer0.bias: torch.Size([1000]), require_grad = True
Parameter projection_net.layer2.weight: torch.Size([1600, 1600]), require_grad = True
Paramet

# Training

In [13]:
"""
TRAIN LOOP
"""
init_step = 0

training_logs = []
# loop over all batches
# initialising train_path_iterator is expensive. So at certain loop it might take some time to load
for step in tqdm(range(init_step, args.max_steps)):
    # train model for a step over all batches
    log = model.train_step(model, optimizer, train_path_iterator, args, step)
    training_logs.append(log)

  0%|          | 0/100000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [28]:
# save the trained model
torch.save(
    {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    },
    os.path.join("checkpoint", "model-beta-er.pt")
)

## DEBUG TRAINING STEP

In [50]:
step = 0

# model.train_step(model, optimizer, train_other_iterator, args, step)

model, optimizer, train_iterator, args, step = model, optimizer, train_path_iterator, args, step

model.train()
optimizer.zero_grad()

# there's an overhead operation to shufflte the train_iterator that makes this expensive to run
positive_sample, negative_sample, subsampling_weight, batch_queries, query_structures = next(train_iterator)

# group queries into batch
batch_queries_dict = collections.defaultdict(list)
batch_idxs_dict = collections.defaultdict(list)
for i, query in enumerate(batch_queries): # group queries with same structure
    batch_queries_dict[query_structures[i]].append(query)
    batch_idxs_dict[query_structures[i]].append(i)

for query_structure in batch_queries_dict:
    if args.cuda:
        batch_queries_dict[query_structure] = torch.LongTensor(batch_queries_dict[query_structure]).cuda()
    else:
        batch_queries_dict[query_structure] = torch.LongTensor(batch_queries_dict[query_structure])

if args.cuda:
    positive_sample = positive_sample.cuda()
    negative_sample = negative_sample.cuda()
    subsampling_weight = subsampling_weight.cuda()

# score positive and negative samples
positive_logit, negative_logit, subsampling_weight, _ = model(positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict)

# calculate loss
# positive samples should have label of 1 while negative samples label of 0
negative_score = F.logsigmoid(-negative_logit).mean(dim=1)
positive_score = F.logsigmoid(positive_logit).mean(dim=1)
# aggregate loss with subsampling_weight
positive_sample_loss = - (subsampling_weight * positive_score).sum()
negative_sample_loss = - (subsampling_weight * negative_score).sum()
positive_sample_loss /= subsampling_weight.sum()
negative_sample_loss /= subsampling_weight.sum()
loss = (positive_sample_loss + negative_sample_loss)/2

loss.backward()
optimizer.step()

log = {
    'positive_sample_loss': positive_sample_loss.item(),
    'negative_sample_loss': negative_sample_loss.item(),
    'loss': loss.item(),
}
log

{'positive_sample_loss': 0.39944127202033997,
 'negative_sample_loss': 1.0847606658935547,
 'loss': 0.7421009540557861}

In [43]:
print("=== input ===")
print(positive_sample.shape, negative_sample.shape, len(batch_queries), len(query_structures), subsampling_weight.shape)
print("=== output ===")
print(positive_logit.shape, negative_logit.shape)

=== input ===
torch.Size([64]) torch.Size([64, 128]) 64 64 torch.Size([64])
=== output ===
torch.Size([64, 1]) torch.Size([64, 128])


# Testing

In [19]:
# load trained model
checkpoint = torch.load(os.path.join("checkpoint", "model-beta-er.pt"))
model.load_state_dict(checkpoint["model_state_dict"])

<All keys matched successfully>

In [13]:
def evaluate(model, tp_answers, fn_answers, args, dataloader, query_name_dict, mode, step):
    '''
    Evaluate queries in dataloader
    '''
    average_metrics = defaultdict(float)
    all_metrics = defaultdict(float)

    metrics = model.test_step(model, tp_answers, fn_answers, args, dataloader, query_name_dict)
    num_query_structures = 0
    num_queries = 0
    for query_structure in metrics:
        log_metrics(mode+" "+query_name_dict[query_structure], step, metrics[query_structure])
        for metric in metrics[query_structure]:
            all_metrics["_".join([query_name_dict[query_structure], metric])] = metrics[query_structure][metric]
            if metric != 'num_queries':
                average_metrics[metric] += metrics[query_structure][metric]
        num_queries += metrics[query_structure]['num_queries']
        num_query_structures += 1

    for metric in average_metrics:
        average_metrics[metric] /= num_query_structures
        all_metrics["_".join(["average", metric])] = average_metrics[metric]
    log_metrics('%s average'%mode, step, average_metrics)

    return all_metrics


In [14]:
def log_metrics(mode, step, metrics):
    '''
    Print the evaluation logs
    '''
    for metric in metrics:
        logging.info('%s %s at step %d: %f' % (mode, metric, step, metrics[metric]))


In [17]:
# before training the model


In [15]:
"""`
EVAL
"""
step = -1
valid_all_metrics = evaluate(model, valid_easy_answers, valid_hard_answers, args, valid_dataloader, query_name_dict, 'Valid', step)
valid_all_metrics

# after training the model

100%|██████████| 1667/1667 [02:33<00:00, 10.85it/s]


defaultdict(float,
            {'pin_MRR': 0.0007912687872187234,
             'pin_HITS1': 2.38745866343379e-05,
             'pin_HITS3': 0.00017614683248102665,
             'pin_HITS10': 0.001281430758163333,
             'pin_num_queries': 5000,
             'average_MRR': 0.0007912687872187234,
             'average_HITS1': 2.38745866343379e-05,
             'average_HITS3': 0.00017614683248102665,
             'average_HITS10': 0.001281430758163333})

## DEBUG TEST STEP

In [16]:
"""def evaluate in main.py"""

metrics = model.test_step(model, valid_easy_answers, valid_hard_answers, args, valid_dataloader, query_name_dict)
metrics

100%|██████████| 1667/1667 [02:34<00:00, 10.76it/s]


defaultdict(<function models.KGReasoning.test_step.<locals>.<lambda>()>,
            {(('e', ('r', 'r')), ('e', ('r', 'n'))): defaultdict(int,
                         {'MRR': 0.0007912687872187234,
                          'HITS1': 2.38745866343379e-05,
                          'HITS3': 0.00017614683248102665,
                          'HITS10': 0.001281430758163333,
                          'num_queries': 5000})})

In [17]:
for negative_sample, queries, queries_unflatten, query_structures in tqdm(valid_dataloader, disable=not args.print_on_screen):
    break
    
negative_sample, queries, 

  0%|          | 0/1667 [00:00<?, ?it/s]

(tensor([[    0,     1,     2,  ..., 14502, 14503, 14504],
         [    0,     1,     2,  ..., 14502, 14503, 14504],
         [    0,     1,     2,  ..., 14502, 14503, 14504]]),
 [[2250, 456, 63, 160, 29, -2],
  [11233, 216, 69, 32, 57, -2],
  [90, 333, 202, 3161, 202, -2]])

1. why is there no positive_sample in test_dataloader
2. how to embed query vector and perform projection & intersection

In [18]:
"""model.test_step in models.py"""
model, easy_answers, hard_answers, args, test_dataloader, query_name_dict = model, valid_easy_answers, valid_hard_answers, args, valid_dataloader, query_name_dict

model.eval()

step = 0
total_steps = len(test_dataloader)
logs = collections.defaultdict(list)

with torch.no_grad():
    # process per batch
    # loop through every validation dataset
    # negative sample is basically all entities in nentity
    for negative_sample, queries, queries_unflatten, query_structures in tqdm(test_dataloader, disable=not args.print_on_screen):
        batch_queries_dict = collections.defaultdict(list)
        batch_idxs_dict = collections.defaultdict(list)
        
        # for each test query
        for i, query in enumerate(queries):
            batch_queries_dict[query_structures[i]].append(query)
            batch_idxs_dict[query_structures[i]].append(i)
            
        # convert each positive query to a LongTensor
        for query_structure in batch_queries_dict:
            if args.cuda:
                batch_queries_dict[query_structure] = torch.LongTensor(batch_queries_dict[query_structure]).cuda()
            else:
                batch_queries_dict[query_structure] = torch.LongTensor(batch_queries_dict[query_structure])
            
        # convert each negative query to a LongTensor
        if args.cuda:
            negative_sample = negative_sample.cuda()
            
        # get prediction output
        # (2, number_of_edges) same dimension with negative_sample
        # this is basically embedding lookup
        """model.forward"""
        # negative_sample here is just a list of all entities, some of them are actually positive
        _, negative_logit, _, idxs = model(None, negative_sample, None, batch_queries_dict, batch_idxs_dict)
          
        # ... scoring ...
        queries_unflatten = [queries_unflatten[i] for i in idxs]
        query_structures = [query_structures[i] for i in idxs]
        # sort from maximum value based on logit values
        argsort = torch.argsort(negative_logit, dim=1, descending=True)
        ranking = argsort.clone().to(torch.float)
        if len(argsort) == args.test_batch_size: # if it is the same shape with test_batch_size, we can reuse batch_entity_range without creating a new one
            ranking = ranking.scatter_(1, argsort, model.batch_entity_range) # achieve the ranking of all entities
        else: # otherwise, create a new torch Tensor for batch_entity_range
            if args.cuda:
                ranking = ranking.scatter_(1, argsort, torch.arange(model.nentity).to(torch.float).repeat(argsort.shape[0], 1).cuda()) # achieve the ranking of all entities
            else:
                ranking = ranking.scatter_(1, argsort, torch.arange(model.nentity).to(torch.float).repeat(argsort.shape[0], 1)) # achieve the ranking of all entities
        
        # loop through ranking
        # score every query in the dataloader batch
        for idx, (i, query, query_structure) in enumerate(zip(argsort[:, 0], queries_unflatten, query_structures)):
            # get groundtruth labels
            hard_answer = hard_answers[query]
            easy_answer = easy_answers[query]
            num_hard = len(hard_answer)
            num_easy = len(easy_answer)
            assert len(hard_answer.intersection(easy_answer)) == 0

            # compare ranking of groundtruth (easy_answer, hard_answer) and predicted results
            cur_ranking = ranking[idx, list(easy_answer) + list(hard_answer)]
            cur_ranking, indices = torch.sort(cur_ranking)
            masks = indices >= num_easy
            if args.cuda:
                answer_list = torch.arange(num_hard + num_easy).to(torch.float).cuda()
            else:
                answer_list = torch.arange(num_hard + num_easy).to(torch.float)
            cur_ranking = cur_ranking - answer_list + 1 # filtered setting
            cur_ranking = cur_ranking[masks] # only take indices that belong to the hard answers

            mrr = torch.mean(1./cur_ranking).item()
            h1 = torch.mean((cur_ranking <= 1).to(torch.float)).item()
            h3 = torch.mean((cur_ranking <= 3).to(torch.float)).item()
            h10 = torch.mean((cur_ranking <= 10).to(torch.float)).item()

            logs[query_structure].append({
                'MRR': mrr,
                'HITS1': h1,
                'HITS3': h3,
                'HITS10': h10,
                'num_hard_answer': num_hard,
            })

metrics = collections.defaultdict(lambda: collections.defaultdict(int))
for query_structure in logs:
    for metric in logs[query_structure][0].keys():
        if metric in ['num_hard_answer']:
            continue
        metrics[query_structure][metric] = sum([log[metric] for log in logs[query_structure]])/len(logs[query_structure])
    metrics[query_structure]['num_queries'] = len(logs[query_structure])

metrics

  0%|          | 0/1667 [00:00<?, ?it/s]

defaultdict(<function __main__.<lambda>()>,
            {(('e', ('r', 'r')), ('e', ('r', 'n'))): defaultdict(int,
                         {'MRR': 0.0007912687872187234,
                          'HITS1': 2.38745866343379e-05,
                          'HITS3': 0.00017614683248102665,
                          'HITS10': 0.001281430758163333,
                          'num_queries': 5000})})

In [116]:
"""model.forward"""
# model.forward(self, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict)

# model.geo == "beta"

# done per batch
"""model.forward_beta"""
# compared to vec, just add union of offset embeddings

# model.forward_beta(self, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict)
# negative_sample is a list of entities to be scored as potential answers to the queries
# batch_queries_dict stores list of queries (in this case (e, r) to be scored)
positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict = None, negative_sample, None, batch_queries_dict, batch_idxs_dict

all_idxs, all_alpha_embeddings, all_beta_embeddings = [], [], []
all_union_idxs, all_union_alpha_embeddings, all_union_beta_embeddings = [], [], []
for query_structure in batch_queries_dict:
    if 'u' in model.query_name_dict[query_structure] and 'DNF' in model.query_name_dict[query_structure]:
        # embedding 
        alpha_embedding, beta_embedding, _ = model.embed_query_beta(model.transform_union_query(batch_queries_dict[query_structure], 
                                                            query_structure), 
                                                        model.transform_union_structure(query_structure), 0)
        all_union_idxs.extend(batch_idxs_dict[query_structure])
        all_union_alpha_embeddings.append(alpha_embedding)
        all_union_beta_embeddings.append(beta_embedding)
    else:
        alpha_embedding, beta_embedding, _ = model.embed_query_beta(batch_queries_dict[query_structure], query_structure, 0)
        all_idxs.extend(batch_idxs_dict[query_structure])
        all_alpha_embeddings.append(alpha_embedding)
        all_beta_embeddings.append(beta_embedding)

if len(all_alpha_embeddings) > 0:
    all_alpha_embeddings = torch.cat(all_alpha_embeddings, dim=0).unsqueeze(1)
    all_beta_embeddings = torch.cat(all_beta_embeddings, dim=0).unsqueeze(1)
    all_dists = torch.distributions.beta.Beta(all_alpha_embeddings, all_beta_embeddings)
if len(all_union_alpha_embeddings) > 0:
    all_union_alpha_embeddings = torch.cat(all_union_alpha_embeddings, dim=0).unsqueeze(1)
    all_union_beta_embeddings = torch.cat(all_union_beta_embeddings, dim=0).unsqueeze(1)
    all_union_alpha_embeddings = all_union_alpha_embeddings.view(all_union_alpha_embeddings.shape[0]//2, 2, 1, -1)
    all_union_beta_embeddings = all_union_beta_embeddings.view(all_union_beta_embeddings.shape[0]//2, 2, 1, -1)
    all_union_dists = torch.distributions.beta.Beta(all_union_alpha_embeddings, all_union_beta_embeddings)

if type(subsampling_weight) != type(None):
    subsampling_weight = subsampling_weight[all_idxs+all_union_idxs]

# in test, positive samples are not given
if type(positive_sample) != type(None):
    if len(all_alpha_embeddings) > 0:
        positive_sample_regular = positive_sample[all_idxs]
        positive_embedding = torch.index_select(model.entity_embedding, dim=0, index=positive_sample_regular).unsqueeze(1)
        positive_logit = model.cal_logit_beta(positive_embedding, all_dists)
    else:
        positive_logit = torch.Tensor([]).to(model.entity_embedding.device)

    if len(all_union_alpha_embeddings) > 0:
        positive_sample_union = positive_sample[all_union_idxs]
        positive_embedding = model.entity_regularizer(
            torch.index_select(model.entity_embedding, dim=0, index=positive_sample_union).unsqueeze(1).unsqueeze(1)
        )
        positive_union_logit = model.cal_logit_beta(positive_embedding, all_union_dists)
        # take max over all
        positive_union_logit = torch.max(positive_union_logit, dim=1)[0]
    else:
        positive_union_logit = torch.Tensor([]).to(model.entity_embedding.device)
        
    # combine normal embeddings and union
    positive_logit = torch.cat([positive_logit, positive_union_logit], dim=0)
else:
    positive_logit = None

# in test, negative samples are basically all possible list of entities
if type(negative_sample) != type(None):
    if len(all_alpha_embeddings) > 0:
        negative_sample_regular = negative_sample[all_idxs]
        batch_size, negative_size = negative_sample_regular.shape
        negative_embedding = model.entity_regularizer(
            torch.index_select(model.entity_embedding, dim=0, index=negative_sample_regular.view(-1)).view(batch_size, negative_size, -1)
        )
        negative_logit = model.cal_logit_beta(negative_embedding, all_dists)
    else:
        negative_logit = torch.Tensor([]).to(model.entity_embedding.device)

    if len(all_union_alpha_embeddings) > 0:
        negative_sample_union = negative_sample[all_union_idxs]
        batch_size, negative_size = negative_sample_union.shape
        negative_embedding = model.entity_regularizer(
            torch.index_select(model.entity_embedding, dim=0, index=negative_sample_union.view(-1)).view(batch_size, 1, negative_size, -1)
        )
        negative_union_logit = model.cal_logit_beta(negative_embedding, all_union_dists)
        negative_union_logit = torch.max(negative_union_logit, dim=1)[0]
    else:
        negative_union_logit = torch.Tensor([]).to(model.entity_embedding.device)
    negative_logit = torch.cat([negative_logit, negative_union_logit], dim=0)
else:
    negative_logit = None

positive_logit, negative_logit, subsampling_weight, all_idxs+all_union_idxs

RuntimeError: CUDA error: device-side assert triggered

In [48]:
"""model.embed_query_beta in models.py"""
# get the query embedding
# input is (batch_size, query_type_length=2), query_structure=(e, r)
# model.embed_query_beta(batch_queries_dict[query_structure], query_structure, 0)

queries, query_structure, idx = batch_queries_dict[query_structure], query_structure, 0

all_relation_flag = True
for ele in query_structure[-1]:
    print(ele)
    # check that this is not a simple (..., r)
    if ele not in ['r', 'n']:
        # if there are double relations then all_relation_flag is True
        all_relation_flag = False
        break
        
# check that the second tuple is also relation or negation only (r, n)
if all_relation_flag:
    # first part
    # if first part is anchor
    if query_structure[0] == "e":
        # get entity embedding
        embedding = model.entity_regularizer(
            torch.index_select(model.entity_embedding, dim=0, index=queries[:, idx])
        )
        idx += 1
    # if first part is a tuple, call the function again i.e. recursion
    else:
        # recursion
        alpha_embedding, beta_embedding, idx = model.embed_query_beta(queries, query_structure[0], idx)

    # second part
    for i in range(len(query_structure[-1])):
        # if any of the second part is negation, inverse the embedding
        if query_structure[-1][i] == "n":
            assert (queries[:, idx] == -2).all()
            embedding = 1./embedding
        # else get the relation embedding
        else:
            # get relation embedding
            r_embedding = torch.index_select(model.relation_embedding, dim=0, index=queries[:, idx])
            # add entity and relation embedding
            embedding = model.projection_net(embedding, r_embedding)
        idx += 1
    alpha_embedding, beta_embedding = torch.chunk(embedding, 2, dim=-1)
else:
    alpha_embedding_list = []
    beta_embedding_list = []
    # for each query_structure
    # find 
    for i in range(len(query_structure)):
        alpha_embedding, beta_embedding, idx = model.embed_query_beta(queries, query_structure[i], idx)
        alpha_embedding_list.append(alpha_embedding)
        beta_embedding_list.append(beta_embedding)
    alpha_embedding, beta_embedding = model.center_net(torch.stack(alpha_embedding_list), torch.stack(beta_embedding_list))
    
alpha_embedding, beta_embedding, idx

e


(tensor([[ 5.8093,  1.0532,  1.2320,  2.1677,  1.0971,  1.4936,  0.7001,  0.9773,
           0.9894,  1.1520,  0.9649,  1.5259,  1.1835,  1.1495,  1.2265,  0.9741,
           0.8203,  2.3615,  0.3846,  1.0021,  1.1155,  2.1380,  1.1384,  0.8335,
           0.8611,  1.4863,  0.9719,  1.1760,  0.9690,  0.9143,  1.5212,  1.0361,
           1.5558,  1.2096,  0.8359,  1.0983,  0.7586,  1.0377,  1.4682,  0.4126,
           1.4653,  0.6982,  0.3125,  0.9255,  1.1584,  0.8919,  0.8637,  1.0260,
           2.0751,  0.7999,  0.9885,  1.5400,  0.5941,  0.6524,  1.3403,  1.3706,
           0.8378,  0.9711,  2.1258,  0.9225,  1.3712,  1.6246,  1.3157,  1.4441,
           0.7829,  1.2202,  0.9627,  1.0495,  1.2817,  0.9193,  1.0477,  0.7604,
           1.2056,  1.0763,  0.9099,  1.1671,  0.9200,  0.6366,  0.9358,  0.9655,
           0.9517,  2.0458,  1.1489,  0.8780,  1.0216,  1.0172,  0.6254,  1.0263,
           0.9308,  1.1204,  1.1252,  0.8459,  1.8095,  0.8713,  0.9986,  0.9017,
           1.500

In [113]:
# define model.projection_net in model.embed_query_beta
# embedding = model.projection_net(embedding, r_embedding)

# model.projection_net
print(model.projection_net.entity_dim, model.projection_net.relation_dim, model.projection_net.hidden_dim, model.projection_net.num_layers)

class BetaProjection(nn.Module):
    def __init__(self, entity_dim, relation_dim, hidden_dim, projection_regularizer, num_layers):
        super(BetaProjection, self).__init__()
        self.entity_dim = entity_dim
        self.relation_dim = relation_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.layer1 = nn.Linear(self.entity_dim + self.relation_dim, self.hidden_dim)
        self.layer0 = nn.Linear(self.hidden_dim, self.entity_dim)
        
        for nl in range(2, num_layers + 1):
            setattr(self, "layer{}".format(nl), nn.Linear(self.hidden_dim, self.hidden_dim))
        for nl in range(num_layers + 1):
            nn.init.xavier_uniform_(getattr(self, "layer{}".format(nl)).weight)
        
        self.projection_regularizer = projection_regularizer
        
    def forward(self, e_embedding, r_embedding):
        print("e_embedding.shape, r_embedding.shape | ", e_embedding.shape, r_embedding.shape)
        x = torch.cat([e_embedding, r_embedding], dim=-1)
        print("x.shape | ", x.shape)
        for nl in range(1, self.num_layers+1):
            x = F.relu(getattr(self, "layer{}".format(nl))(x))
        x = self.layer0(x)
        x = self.projection_regularizer(x)
        
        return x
    
proj_model = BetaProjection(entity_dim=1000, relation_dim=500, hidden_dim=1600, 
                            projection_regularizer=model.projection_regularizer, num_layers=2).to("cpu")

queries, query_structure, idx = batch_queries_dict[query_structure], query_structure, 0
embedding = model.entity_regularizer(
    torch.index_select(model.entity_embedding, dim=0, index=queries[:, idx])
).to("cpu")
r_embedding = torch.index_select(model.relation_embedding, dim=0, index=queries[:, idx]).to("cpu")


proj_model(embedding, r_embedding)

1000 500 1600 2


RuntimeError: CUDA error: device-side assert triggered

In [86]:
## define model.center_net in model.embed_query_beta
import torch.nn as nn

class BetaIntersection(nn.Module):
    def __init__(self, dim):
        super(BetaIntersection, self).__init__()
        self.dim = dim
        self.layer1 = nn.Linear(2 * self.dim, 2 * self.dim)
        self.layer2 = nn.Linear(2 * self.dim, self.dim)
        
        nn.init.xavier_uniform_(self.layer1.weight)
        nn.init.xavier_uniform_(self.layer2.weight)
    
    def forward(self, alpha_embeddings, beta_embeddings):
        print("alpha_embeddings.shape, beta_embeddings.shape |", alpha_embeddings.shape, beta_embeddings.shape)
        
        # input: [(batch_size, no_query_couple, feat_dim), (batch_size, no_query_couple, feat_dim)]
        # output: (batch_size, no_query_couple, 2*feat_dim)
        all_embeddings = torch.cat([alpha_embeddings, beta_embeddings], dim=-1)
        print("all_embeddings.shape | ", all_embeddings.shape)
        # (batch_size, no_query_couple, 2*feat_dim)
        layer1_act = F.relu(self.layer1(all_embeddings))
        print("layer1_act.shape | ", layer1_act.shape)
        # (batch_size, no_query_couple, feat_dim)
        attention = F.softmax(self.layer2(layer1_act), dim=0)
        print("attention.shape | ", attention.shape)

        # (batch_size, feat_dim)
        alpha_embedding = torch.sum(attention * alpha_embeddings, dim=0)
        print("alpha_embedding.shape | ", alpha_embedding.shape)
        # (batch_size, feat_dim)
        beta_embedding = torch.sum(attention * beta_embeddings, dim=0)
        print("beta_embedding.shape | ", beta_embedding.shape)
        
        return alpha_embedding, beta_embedding
    
idx = 0
alpha_embedding_list = []
beta_embedding_list = []
# for each query_structure
# find 
for i in range(len(query_structure)):
    alpha_embedding, beta_embedding, idx = model.embed_query_beta(queries, query_structure[i], idx)
    alpha_embedding_list.append(alpha_embedding)
    beta_embedding_list.append(beta_embedding)

alpha_embedding, beta_embedding = model.center_net(torch.stack(alpha_embedding_list), torch.stack(beta_embedding_list))

emb_model = BetaIntersection(dim = 500).to("cuda")
_, _ = emb_model(torch.stack(alpha_embedding_list), torch.stack(beta_embedding_list))

alpha_embeddings.shape, beta_embeddings.shape | torch.Size([2, 2, 500]) torch.Size([2, 2, 500])
all_embeddings.shape |  torch.Size([2, 2, 1000])
layer1_act.shape |  torch.Size([2, 2, 1000])
attention.shape |  torch.Size([2, 2, 500])
alpha_embedding.shape |  torch.Size([2, 500])
beta_embedding.shape |  torch.Size([2, 500])


In [None]:
""" model.cal_logit_beta """
model.cal_logit_beta(negative_embedding, all_dists)

# score / calculate distance between candidate embedding and query embedding
# negative 
print(negative_embedding.shape, all_dists)

entity_embedding, query_dist = negative_embedding, all_dists

alpha_embedding, beta_embedding = torch.chunk(entity_embedding, 2, dim=-1)
entity_dist = torch.distributions.beta.Beta(alpha_embedding, beta_embedding)
logit = model.gamma - torch.norm(torch.distributions.kl.kl_divergence(entity_dist, query_dist), p=1, dim=-1)
logit