In [12]:
from functools import partial
import random
import wandb
import sys

# MyTorch imports
from mytorch.utils.goodies import *

# Local imports
from parse_wd15k import Quint
from load import DataManager
from utils import *
from evaluation import EvaluationBench, acc, mrr, mr, hits_at, evaluate_pointwise
from models import TransE, BaseModule
from corruption import Corruption
from sampler import SimpleSampler
from loops import training_loop

"""
    CONFIG Things
"""

# Clamp the randomness
np.random.seed(42)
random.seed(42)

"""
    Explanation:
        *ENT_POS_FILTERED* 
            a flag which if False, implies that while making negatives, 
                we should exclude entities that appear ONLY in non-corrupting positions.
            Do not turn it off if the experiment is about predicting qualifiers, of course.

        *POSITIONS*
            the positions on which we should inflect the negatives.
"""
DEFAULT_CONFIG = {
    'EMBEDDING_DIM': 50,
    'NORM_FOR_NORMALIZATION_OF_ENTITIES': 2,
    'NORM_FOR_NORMALIZATION_OF_RELATIONS': 2,
    'SCORING_FUNCTION_NORM': 1,
    'MARGIN_LOSS': 1,
    'LEARNING_RATE': 0.001,
    'NEGATIVE_SAMPLING_PROBS': [0.3, 0.0, 0.2, 0.5],
    'NEGATIVE_SAMPLING_TIMES': 10,
    'BATCH_SIZE': 64,
    'EPOCHS': 1000,
    'STATEMENT_LEN': -1,
    'EVAL_EVERY': 10,
    'WANDB': False,
    'RUN_TESTBENCH_ON_TRAIN': True,
    'DATASET': 'wd15k',
    'CORRUPTION_POSITIONS': [0, 2],
    'DEVICE': 'cpu',
    'ENT_POS_FILTERED': True,
    'USE_TEST': False,
    'MAX_QPAIRS': 43,
    'NUM_FILTERS': 10,
    'PROJECT_QUALIFIERS': False,
    'SELF_ATTENTION'
}

In [13]:
# Custom Sanity Checks
if DEFAULT_CONFIG['DATASET'] == 'wd15k':
    assert DEFAULT_CONFIG['STATEMENT_LEN'] is not None, \
        "You use WD15k dataset and don't specify whether to treat them as quints or not. Nicht cool'"
if max(DEFAULT_CONFIG['CORRUPTION_POSITIONS']) > 2:     # If we're corrupting something apart from S and O
    assert DEFAULT_CONFIG['ENT_POS_FILTERED'] is False, \
        f"Since we're corrupting objects at pos. {DEFAULT_CONFIG['CORRUPTION_POSITIONS']}, " \
        f"You must allow including entities which appear exclusively in qualifiers, too!"

"""
    Load data based on the args/config
"""
data = DataManager.load(config=DEFAULT_CONFIG)()
try:
    training_triples, valid_triples, test_triples, num_entities, num_relations, e2id, r2id = data.values()
except ValueError:
    raise ValueError(f"Honey I broke the loader for {DEFAULT_CONFIG['DATASET']}")

if DEFAULT_CONFIG['ENT_POS_FILTERED']:
    ent_excluded_from_corr = DataManager.gather_missing_entities(data=training_triples + valid_triples + test_triples,
                                                                 positions=DEFAULT_CONFIG['CORRUPTION_POSITIONS'],
                                                                 n_ents=num_entities)
    DEFAULT_CONFIG['NUM_ENTITIES_FILTERED'] = len(ent_excluded_from_corr)
else:
    ent_excluded_from_corr = []
    DEFAULT_CONFIG['NUM_ENTITIES_FILTERED'] = len(ent_excluded_from_corr)

print(num_entities-DEFAULT_CONFIG['NUM_ENTITIES_FILTERED'])
DEFAULT_CONFIG['NUM_ENTITIES'] = num_entities
DEFAULT_CONFIG['NUM_RELATIONS'] = num_relations

43051


In [14]:
config = DEFAULT_CONFIG.copy()
config['DEVICE'] = torch.device(config['DEVICE'])


# YOUR MODEL COMES HERE

In [15]:
import torch.autograd
import torch.nn.functional as F


In [16]:
class ConvKB(BaseModule):
    """
    An implementation of ConvKB.
    
    A Novel Embedding Model for Knowledge Base CompletionBased on Convolutional Neural Network. 
    """

    model_name = 'ConvKB'

    def __init__(self, config) -> None:

        self.margin_ranking_loss_size_average: bool = True
        self.entity_embedding_max_norm: Optional[int] = None
        self.entity_embedding_norm_type: int = 2
        self.model_name = 'ConvKB'
        super().__init__(config)
        self.statement_len = config['STATEMENT_LEN']

        # Embeddings
        self.l_p_norm_entities = config['NORM_FOR_NORMALIZATION_OF_ENTITIES']
        self.scoring_fct_norm = config['SCORING_FUNCTION_NORM']
        self.relation_embeddings = nn.Embedding(config['NUM_RELATIONS'], config['EMBEDDING_DIM'], padding_idx=0)

        self.config = config

        self.criterion = nn.SoftMarginLoss(
            reduction='sum'
        )
        
        self.conv = nn.Conv2d(in_channels=1, 
                              out_channels=config['NUM_FILTER'], kernel_size= (config['MAX_QPAIRS'],1), 
                             bias=True)
        
        self.fc = nn.Linear(config['NUM_FILTER']*self.embedding_dim,1, bias=False)
        
        self._initialize()

        # Make pad index zero. # TODO: Should pad index be configurable? Probably not, right? Cool? Cool.
        # self.entity_embeddings.weight.data[0] = torch.zeros_like(self.entity_embeddings.weight[0], requires_grad=True)
        # self.relation_embeddings.weight.data[0] = torch.zeros_like(self.relation_embeddings.weight[0], requires_grad=True)

    def _initialize(self):
        embeddings_init_bound = 6 / np.sqrt(self.config['EMBEDDING_DIM'])
        nn.init.uniform_(
            self.entity_embeddings.weight.data,
            a=-embeddings_init_bound,
            b=+embeddings_init_bound,
        )
        nn.init.uniform_(
            self.relation_embeddings.weight.data,
            a=-embeddings_init_bound,
            b=+embeddings_init_bound,
        )

        norms = torch.norm(self.relation_embeddings.weight,
                           p=self.config['NORM_FOR_NORMALIZATION_OF_RELATIONS'], dim=1).data
        self.relation_embeddings.weight.data = self.relation_embeddings.weight.data.div(
            norms.view(self.num_relations, 1).expand_as(self.relation_embeddings.weight))

        self.relation_embeddings.weight.data[0] = torch.zeros(1, self.embedding_dim)
        self.entity_embeddings.weight.data[0] = torch.zeros(1, self.embedding_dim)  # zeroing the padding index
        
        

    def predict(self, triples):
        scores = self._score_triples(triples)
        return scores
    
    def _compute_loss(self, positive_scores: torch.Tensor, negative_scores: torch.Tensor) -> torch.Tensor:
        # Let n items in pos score.
        y = np.repeat([-1], repeats=positive_scores.shape[0])          # n item here (all -1)
        y = torch.tensor(y, dtype=torch.float, device=self.device)

        pos_loss = self.criterion(positive_scores, -1*y)
        neg_loss = self.criterion(negative_scores, y)
        return pos_loss + neg_loss

    def forward(self, batch_positives, batch_negatives) \
            -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:

        # Normalize embeddings of entities
        norms = torch.norm(self.entity_embeddings.weight, p=self.l_p_norm_entities, dim=1).data
        
        self.entity_embeddings.weight.data = self.entity_embeddings.weight.data.div(
            norms.view(self.num_entities, 1).expand_as(self.entity_embeddings.weight))
        
        self.entity_embeddings.weight.data[0] = torch.zeros(1, self.embedding_dim)  # zeroing the padding index

        positive_scores = self._score_triples(batch_positives)
        negative_scores = self._score_triples(batch_negatives)
        loss = self._compute_loss(positive_scores=positive_scores, negative_scores=negative_scores)
        return (positive_scores, negative_scores), loss

    def _score_triples(self, triples) -> torch.Tensor:
        """ Get triple/quint embeddings, and compute scores """
        scores = self._compute_scores(*self._get_triple_embeddings(triples))
        return scores


    def _compute_scores(self, head_embeddings, relation_embeddings, tail_embeddings,
                        qual_relation_embeddings=None, qual_entity_embeddings=None):
        """
            Compute the scores based on the head, relation, and tail embeddings.

        :param head_embeddings: embeddings of head entities of dimension batchsize x embedding_dim
        :param relation_embeddings: embeddings of relation embeddings of dimension batchsize x embedding_dim
        :param tail_embeddings: embeddings of tail entities of dimension batchsize x embedding_dim
        :param qual_entity_embeddings: embeddings of qualifier relations of dimensinos batchsize x embeddig_dim
        :param qual_relation_embeddings: embeddings of qualifier entities of dimension batchsize x embedding_dim
        :return: Tensor of dimension batch_size containing the scores for each batch element
        """
        
        statement_emb = torch.zeros(head_embeddings.shape[0],
                                    relation_embeddings.shape[1]*2+1,
                                     head_embeddings.shape[1], 
                                 device=self.config['DEVICE'],
                                   dtype=head_embeddings.dtype) # 1 for head embedding
        
        # Assignment
        statement_emb[:,0] = head_embeddings
        statement_emb[:,1::2] = relation_embeddings
        statement_emb[:,2::2] = tail_embeddings
        
        
        # Convolutional operation
        statement_emb = F.relu(self.conv(statement_emb.unsqueeze(1))).squeeze(-1) # bs*number_of_filter*emb_dim            
        statement_emb = statement_emb.view(statement_emb.shape[0], -1)
        score = self.fc(statement_emb)
        
        return score.squeeze()

    def _get_triple_embeddings(self, triples):
        
        head, statement_entities, statement_relations = slice_triples(triples, -1)
        return (
            self._get_entity_embeddings(head),
            self.relation_embeddings(statement_relations),
            self.entity_embeddings(statement_entities)
        )

    def _get_relation_embeddings(self, relations):
        return self.relation_embeddings(relations).view(-1, self.embedding_dim)


In [17]:
model = TransE(config)
model.to(config['DEVICE'])
optimizer = torch.optim.SGD(model.parameters(), lr=config['LEARNING_RATE'])

In [18]:
data = {'index': np.array(training_triples + test_triples), 'eval': np.array(valid_triples)}
_data = {'index': np.array(valid_triples + test_triples), 'eval': np.array(training_triples)}
tr_data = {'train': np.array(training_triples), 'valid': data['eval']}

eval_metrics = [acc, mrr, mr, partial(hits_at, k=3), partial(hits_at, k=5), partial(hits_at, k=10)]
evaluation_valid = EvaluationBench(data, model, bs=8000,
                                   metrics=eval_metrics, filtered=True,
                                   n_ents=num_entities,
                                   excluding_entities=ent_excluded_from_corr,
                                   positions=config.get('CORRUPTION_POSITIONS', None))
evaluation_train = EvaluationBench(_data, model, bs=8000,
                                   metrics=eval_metrics, filtered=True,
                                   n_ents=num_entities,
                                   excluding_entities=ent_excluded_from_corr,
                                   positions=config.get('CORRUPTION_POSITIONS', None), trim=0.01)

In [19]:
args = {
        "epochs": config['EPOCHS'],
        "data": tr_data,
        "opt": optimizer,
        "train_fn": model,
        "neg_generator": Corruption(n=num_entities, excluding=ent_excluded_from_corr,
                                    position=list(range(0, config['MAX_QPAIRS'], 2))),
        "device": config['DEVICE'],
        "data_fn": partial(SimpleSampler, bs=config["BATCH_SIZE"]),
        "eval_fn_trn": evaluate_pointwise,
        "val_testbench": evaluation_valid.run,
        "trn_testbench": evaluation_train.run,
        "eval_every": config['EVAL_EVERY'],
        "log_wandb": config['WANDB'],
        "run_trn_testbench": config['RUN_TESTBENCH_ON_TRAIN']
    }

In [20]:
traces = training_loop(**args)

HBox(children=(IntProgress(value=0, max=2646), HTML(value='')))




KeyError: 'SELF_ATTENTION'