# TransE

**Data**: s, p, o, qp, qe

$s, p, o, qp, qe \in \mathbb{R}^{d}$

## Challenge
1. Model
2. Sampling Code


**Model** adapted from [PyKEEN](https://github.com/SmartDataAnalytics/PyKEEN/blob/master/src/pykeen/kge_models/trans_e.py)

In [26]:
import logging
import pickle
# from dataclasses import dataclass
from typing import Dict
from typing import Optional, Union, List
from typing import Tuple

import numpy as np
import pandas as pd
import torch
import torch.autograd
from torch import nn

In [15]:
# Local imports
from raw_parser import *

## Basic Stats
Counts of the unique entities and predicates in the data (`n_entities`, `n_relations`), etc.

In [16]:
# Read the pickled data back from disk.
# NOTE: unpickling can execute arbitrary code, so only load files you trust.
with open('./parsed_raw_data.pkl', mode='rb') as pickled_file:
    raw_data = pickle.load(pickled_file)

In [17]:
# Gather every entity and predicate mentioned in raw_data, then deduplicate.
entities = []
predicates = []

for quint in raw_data:
    # Positions 0 and 2 always count as entities; position 4 only when truthy.
    entities.extend((quint[0], quint[2]))
    if quint[4]:
        entities += [quint[4]]

    # Position 1 always counts as a predicate; position 3 only when truthy.
    predicates += [quint[1]]
    if quint[3]:
        predicates += [quint[3]]

# Round-trip through a set to drop duplicates (the resulting order is arbitrary).
entities = list(set(entities))
predicates = list(set(predicates))

print(len(entities), len(predicates))

174951 659


In [18]:
# Configuration read by BaseModule / TransE.
TRANS_E_CONFIG = {
    # Vocabulary sizes, taken from the deduplicated entity / predicate lists.
    'NUM_ENTITIES': len(entities),
    'NUM_RELATIONS': len(predicates),
    # Length of every entity / relation embedding vector.
    'EMBEDDING_DIM': 200,
    # p of the L_p norm used when rescaling entity / relation embeddings.
    'NORM_FOR_NORMALIZATION_OF_ENTITIES': 2,
    'NORM_FOR_NORMALIZATION_OF_RELATIONS': 2,
    # p of the L_p norm used as the TransE distance.
    'SCORING_FUNCTION_NORM': 1,
    # Margin handed to nn.MarginRankingLoss.
    'MARGIN_LOSS': 4,
    'LEARNING_RATE': 0.0001,
    # BaseModule.__init__ reads config['DEVICE']; without this key model
    # construction fails with a KeyError. CPU matches where nn.Embedding
    # allocates its weights by default; if you switch this to a CUDA device,
    # also move the model there with model.to(...).
    'DEVICE': torch.device('cpu'),
}

In [29]:
def slice_triples(triples: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """Split a 2-D index tensor into one 1-D tensor per column.

    A tensor with exactly three columns yields three slices; any other width
    yields the first five columns.

    :param triples: 2-D tensor with one statement per row
    :return: tuple of the column slices ``triples[:, i]`` (three or five of them)
    :raises IndexError: if the tensor does not have exactly three columns and
        has fewer than five
    """
    # Three-column input is the plain (head, relation, tail) case; everything
    # else keeps the original five-way split.
    n_columns = 3 if triples.shape[1] == 3 else 5
    return tuple(triples[:, column] for column in range(n_columns))
    

In [33]:
# Smoke-test slice_triples on a random 10 x 5 integer tensor with values in [0, 100).
t = torch.randint(low=0, high=100, size=(10, 5))
t, slice_triples(t)

(tensor([[40, 71, 47, 60, 82],
         [14, 66, 86, 91, 81],
         [14, 77,  8, 89, 72],
         [31, 23,  0, 79, 61],
         [14, 61, 66, 62, 17],
         [30, 86, 61, 83, 40],
         [ 5, 12, 16, 58, 29],
         [62, 51, 53, 90, 52],
         [99,  1,  5, 85, 43],
         [49, 75, 36, 73, 37]]),
 (tensor([40, 14, 14, 31, 14, 30,  5, 62, 99, 49]),
  tensor([71, 66, 77, 23, 61, 86, 12, 51,  1, 75]),
  tensor([47, 86,  8,  0, 66, 61, 16, 53,  5, 36]),
  tensor([60, 91, 89, 79, 62, 83, 58, 90, 85, 73]),
  tensor([82, 81, 72, 61, 17, 40, 29, 52, 43, 37])))

In [19]:
class BaseModule(nn.Module):
    """A base class for all of the models.

    Holds the entity embedding table, the margin ranking loss and the dataset
    sizes read from ``config``. Every subclass must define a non-empty
    ``model_name`` class attribute (enforced in ``__init_subclass__``) and may
    override the class-level settings below.
    """

    # Truthy -> reduction='mean', falsy (including None) -> reduction='sum'.
    margin_ranking_loss_size_average: Optional[bool] = None
    # Forwarded as ``max_norm`` / ``norm_type`` to the entity nn.Embedding.
    entity_embedding_max_norm: Optional[int] = None
    entity_embedding_norm_type: int = 2
    # Values (not key names) captured from TRANS_E_CONFIG when this class
    # body is executed.
    hyper_params = [TRANS_E_CONFIG['EMBEDDING_DIM'],
                    TRANS_E_CONFIG['MARGIN_LOSS'],
                    TRANS_E_CONFIG['LEARNING_RATE']]

    def __init__(self, config: Dict) -> None:
        """Build the loss criterion and the entity embedding table.

        :param config: mapping with the keys ``DEVICE``, ``MARGIN_LOSS``,
            ``NUM_ENTITIES``, ``NUM_RELATIONS`` and ``EMBEDDING_DIM``
        :raises KeyError: if one of those keys is missing
        """
        super().__init__()

        # Device selection
        self.device = config['DEVICE']

        # Loss
        self.margin_loss = config['MARGIN_LOSS']
        self.criterion = nn.MarginRankingLoss(
            margin=self.margin_loss,
            reduction='mean' if self.margin_ranking_loss_size_average else 'sum'
        )

        # Entity dimensions
        #: The number of entities in the knowledge graph
        self.num_entities = config['NUM_ENTITIES']
        #: The number of unique relation types in the knowledge graph
        self.num_relations = config['NUM_RELATIONS']
        #: The dimension of the embeddings to generate
        self.embedding_dim = config['EMBEDDING_DIM']

        self.entity_embeddings = nn.Embedding(
            self.num_entities,
            self.embedding_dim,
            norm_type=self.entity_embedding_norm_type,
            max_norm=self.entity_embedding_max_norm,
        )

    def __init_subclass__(cls, **kwargs):  # noqa: D105
        # Chain to the parent hook so the check stays cooperative with any
        # other class in the MRO that customises subclass creation.
        super().__init_subclass__(**kwargs)
        if not getattr(cls, 'model_name', None):
            raise TypeError('missing model_name class attribute')

    def _get_entity_embeddings(self, entities):
        """Look up entity embeddings and flatten them to (-1, embedding_dim)."""
        return self.entity_embeddings(entities).view(-1, self.embedding_dim)

    def _compute_loss(self, positive_scores: torch.Tensor, negative_scores: torch.Tensor) -> torch.Tensor:
        """Return the margin ranking loss between positive and negative scores.

        The target is -1 for every pair, so ``nn.MarginRankingLoss`` penalises
        ``max(0, positive - negative + margin)``: positive scores are pushed
        below negative scores by at least the margin.

        :param positive_scores: scores of the positive examples
        :param negative_scores: scores of the negative examples, paired
            element-wise with ``positive_scores``
        :return: the reduced loss tensor
        """
        # Build the target directly in torch (no NumPy round trip) and on the
        # same device as the scores, so the criterion never sees mixed devices.
        y = torch.full(
            (positive_scores.shape[0],),
            -1.0,
            dtype=torch.float,
            device=positive_scores.device,
        )

        loss = self.criterion(positive_scores, negative_scores, y)
        return loss

In [20]:
class TransE(BaseModule):
    """An implementation of TransE [bordes2013]_.

    This model considers a relation as a translation from the head to the tail entity.

    The class derives from ``BaseModule``, which supplies the entity embedding
    table, the margin ranking loss and the ``num_entities`` /
    ``num_relations`` / ``embedding_dim`` attributes used below.

    .. [bordes2013] Bordes, A., *et al.* (2013). `Translating embeddings for modeling multi-relational data
                    <http://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data.pdf>`_
                    . NIPS.
    .. seealso::
       - Alternative implementation in OpenKE: https://github.com/thunlp/OpenKE/blob/OpenKE-PyTorch/models/TransE.py
    """

    model_name = 'TransE MM'
    margin_ranking_loss_size_average: bool = True
    entity_embedding_max_norm: Optional[int] = None
    entity_embedding_norm_type: int = 2

    def __init__(self, config) -> None:
        """Create the relation embeddings and initialise all weights.

        :param config: mapping with the keys required by ``BaseModule`` plus
            ``NORM_FOR_NORMALIZATION_OF_ENTITIES``,
            ``NORM_FOR_NORMALIZATION_OF_RELATIONS`` and ``SCORING_FUNCTION_NORM``
        """
        # The base initialiser creates entity_embeddings, the criterion and the
        # size attributes that _initialize() and forward() rely on.
        super().__init__(config)

        self.config = config

        # Embeddings
        self.l_p_norm_entities = config['NORM_FOR_NORMALIZATION_OF_ENTITIES']
        self.scoring_fct_norm = config['SCORING_FUNCTION_NORM']
        self.relation_embeddings = nn.Embedding(config['NUM_RELATIONS'], config['EMBEDDING_DIM'])

        self._initialize()

    def _initialize(self):
        """Draw uniform initial weights and rescale relation embeddings to unit norm."""
        embeddings_init_bound = 6 / np.sqrt(self.config['EMBEDDING_DIM'])
        nn.init.uniform_(
            self.entity_embeddings.weight.data,
            a=-embeddings_init_bound,
            b=+embeddings_init_bound,
        )
        nn.init.uniform_(
            self.relation_embeddings.weight.data,
            a=-embeddings_init_bound,
            b=+embeddings_init_bound,
        )

        # Divide every relation vector by its own L_p norm.
        norms = torch.norm(self.relation_embeddings.weight,
                           p=self.config['NORM_FOR_NORMALIZATION_OF_RELATIONS'], dim=1).data
        self.relation_embeddings.weight.data = self.relation_embeddings.weight.data.div(
            norms.view(self.num_relations, 1).expand_as(self.relation_embeddings.weight))

    def predict(self, triples):
        """Score ``triples`` and return the scores as a NumPy array."""
        scores = self._score_triples(triples)
        return scores.detach().cpu().numpy()

    def forward(self, batch_positives, batch_negatives):
        """Return the margin ranking loss for a batch.

        :param batch_positives: index tensor of positive statements
        :param batch_negatives: index tensor of negative statements, paired
            row-wise with ``batch_positives``
        :return: the loss computed by ``_compute_loss``
        """
        # Normalize embeddings of entities
        norms = torch.norm(self.entity_embeddings.weight, p=self.l_p_norm_entities, dim=1).data
        self.entity_embeddings.weight.data = self.entity_embeddings.weight.data.div(
            norms.view(self.num_entities, 1).expand_as(self.entity_embeddings.weight))

        positive_scores = self._score_triples(batch_positives)
        negative_scores = self._score_triples(batch_negatives)
        loss = self._compute_loss(positive_scores=positive_scores, negative_scores=negative_scores)
        return loss

    def _score_triples(self, triples):
        """Return one distance score per row of ``triples``."""
        # TODO: fix/change -- only head, relation and tail are scored; any
        # further columns of the input are currently ignored.
        head_embeddings, relation_embeddings, tail_embeddings = self._get_triple_embeddings(triples)
        scores = self._compute_scores(head_embeddings, relation_embeddings, tail_embeddings)
        return scores

    def _compute_scores(self, head_embeddings, relation_embeddings, tail_embeddings):
        """Compute the scores based on the head, relation, and tail embeddings.

        :param head_embeddings: embeddings of head entities of dimension batchsize x embedding_dim
        :param relation_embeddings: embeddings of relations of dimension batchsize x embedding_dim
        :param tail_embeddings: embeddings of tail entities of dimension batchsize x embedding_dim
        :return: Tensor of dimension batch_size containing the scores for each batch element
        """
        # Translation residual h + r - t, measured with the configured L_p norm.
        sum_res = head_embeddings + relation_embeddings - tail_embeddings
        distances = torch.norm(sum_res, dim=1, p=self.scoring_fct_norm).view(size=(-1,))
        return distances

    def _get_triple_embeddings(self, triples):
        """Return the (head, relation, tail) embeddings for every row of ``triples``."""
        # slice_triples may return more than three slices; keep only the
        # head / relation / tail columns so the unpacking cannot fail.
        heads, relations, tails = slice_triples(triples)[:3]
        return (
            self._get_entity_embeddings(heads),
            self._get_relation_embeddings(relations),
            self._get_entity_embeddings(tails),
        )

    def _get_relation_embeddings(self, relations):
        """Look up relation embeddings and flatten them to (-1, embedding_dim)."""
        return self.relation_embeddings(relations).view(-1, self.embedding_dim)

In [22]:
# Scratch check: slicing up to index 2 keeps the first two items of a list.
list(range(10))[0:2]

[0, 1]