In [35]:
from pykeen.models import ERModel, TransE, DistMult, RotatE
from pykeen.nn import Embedding
from pykeen.nn.modules import Interaction, NormBasedInteraction
from torch import FloatTensor
from pykeen.pipeline import pipeline
from class_resolver import Hint, HintOrType, OptionalKwargs
from torch.nn import functional
from pykeen.nn.init import xavier_uniform_, xavier_uniform_norm_, xavier_normal_norm_
from pykeen.typing import Constrainer, Initializer
from pykeen.regularizers import Regularizer, LpRegularizer
from typing import Union, Any, ClassVar, Mapping
from pykeen.utils import negative_norm_of_sum, tensor_product
from pykeen.constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE

In [18]:
def kgcmodel_interaction(
    h: FloatTensor,
    r: FloatTensor,
    t: FloatTensor,
    p: Union[int, str] = 2,
    power_norm: bool = False,
) -> FloatTensor:
    """Score triples as the product of a trilinear term and a norm-of-sum term.

    :param h: head representations; assumed broadcastable with ``r`` and ``t``
        (exact shapes are decided by pykeen's interaction machinery — confirm).
    :param r: relation representations.
    :param t: tail representations.
    :param p: norm order, forwarded to ``negative_norm_of_sum``.
    :param power_norm: forwarded unchanged to ``negative_norm_of_sum``.
    :return: ``tensor_product(h, r, t).sum(dim=-1)`` multiplied by
        ``negative_norm_of_sum(h, r, -t, p=p, power_norm=power_norm)``.

    NOTE(review): presumably the second factor is non-positive (a negative
    norm). If so, a larger trilinear term *lowers* the combined score whenever
    the translation error is non-zero, and the sign of the trilinear term
    flips the ranking direction. Confirm this sign interaction is intended.
    """
    # DistMult-style term: element-wise product of h, r, t reduced over the last axis.
    trilinear_score = tensor_product(h, r, t).sum(dim=-1)
    # TransE-style term, presumably -||h + r - t|| (expressed as a sum with -t).
    translation_score = negative_norm_of_sum(h, r, -t, p=p, power_norm=power_norm)
    return trilinear_score * translation_score

In [19]:
class KGCModelInteraction(NormBasedInteraction[FloatTensor, FloatTensor, FloatTensor]):
    """Interaction module that scores triples with ``kgcmodel_interaction``.

    Head, relation and tail are each a single tensor. The norm settings
    (``p`` and, presumably, ``power_norm``) are accepted by the
    ``NormBasedInteraction`` constructor and forwarded to ``func`` — verify
    against the installed pykeen version.
    """

    # Functional scoring form; pykeen's functional-interaction base class
    # presumably calls it with h, r, t plus the norm settings.
    func = kgcmodel_interaction

In [28]:
class KGCModel(ERModel):
    """Hybrid model scoring triples with ``KGCModelInteraction``.

    Entities and relations are plain ``Embedding`` objects of the same size.
    Defaults take TransE-like initialisation of the embeddings, normalise the
    entity embeddings with ``functional.normalize``, and regularise the
    relation embeddings with an ``LpRegularizer`` using DistMult's defaults.
    """

    # Hyper-parameter ranges (pykeen's ``hpo_default`` convention).
    hpo_default: ClassVar[Mapping[str, Any]] = dict(
        embedding_dim=DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE,
        scoring_fct_norm=dict(type=int, low=1, high=2),
    )

    def __init__(
        self,
        *,
        embedding_dim: int = 50,
        scoring_fct_norm: int = 1,
        power_norm: bool = False,
        entity_initializer: Hint[Initializer] = xavier_uniform_,
        entity_constrainer: Hint[Constrainer] = functional.normalize,
        relation_initializer: Hint[Initializer] = xavier_uniform_norm_,
        relation_constrainer: Hint[Constrainer] = None,
        regularizer: HintOrType[Regularizer] = LpRegularizer,
        regularizer_kwargs: OptionalKwargs = None,
        **kwargs,
    ) -> None:
        """Build entity/relation embeddings and the KGCModel interaction.

        :param embedding_dim: size of both entity and relation embeddings.
        :param scoring_fct_norm: norm order ``p`` passed to the interaction.
        :param power_norm: passed to the interaction alongside ``p``
            (presumably selects the p-th power of the norm, as in TransE).
            Defaults to ``False``, which matches the previous behaviour.
        :param entity_initializer: initializer for entity embeddings.
        :param entity_constrainer: constrainer applied to entity embeddings.
        :param relation_initializer: initializer for relation embeddings.
        :param relation_constrainer: constrainer for relation embeddings.
        :param regularizer: regularizer applied to relation embeddings only.
        :param regularizer_kwargs: kwargs for ``regularizer``. When ``None``
            and ``regularizer`` is ``LpRegularizer``, a copy of DistMult's
            default kwargs is used.
        :param kwargs: forwarded to ``ERModel.__init__`` (presumably e.g. the
            triples factory supplied by ``pipeline``).
        """
        if regularizer is LpRegularizer and regularizer_kwargs is None:
            # Copy so this model never aliases (or mutates) DistMult's
            # class-level default mapping.
            regularizer_kwargs = dict(DistMult.regularizer_default_kwargs)

        super().__init__(
            interaction=KGCModelInteraction,
            interaction_kwargs=dict(p=scoring_fct_norm, power_norm=power_norm),
            entity_representations=Embedding,
            entity_representations_kwargs=dict(
                embedding_dim=embedding_dim,
                initializer=entity_initializer,
                constrainer=entity_constrainer,
            ),
            relation_representations=Embedding,
            relation_representations_kwargs=dict(
                embedding_dim=embedding_dim,
                initializer=relation_initializer,
                constrainer=relation_constrainer,
                regularizer=regularizer,
                regularizer_kwargs=regularizer_kwargs,
            ),
            **kwargs,
        )

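## Сравнение значений метрик моделей TransE, DistMult и KGCModel на датасете Nations

Все модели обучаются 100 эпох с одним и тем же `random_seed`; остальные параметры пайплайна PyKEEN — по умолчанию. Для каждой модели выводятся MRR и Hits@k (k = 1, 3, 5, 10) из метрик `both.realistic.*`.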

In [None]:
# Train and evaluate KGCModel on Nations: 100 epochs, fixed seed,
# all other pipeline settings left at pykeen's defaults.
result_KGCModel = pipeline(
    dataset='nations',
    model=KGCModel,
    random_seed=1603073093,
    training_kwargs=dict(num_epochs=100),
)

In [30]:
# Build the flat metric mapping once instead of once per printed value.
metrics_KGCModel = result_KGCModel.metric_results.to_flat_dict()
print(f"MRR: {metrics_KGCModel['both.realistic.inverse_harmonic_mean_rank']}")
for k in [1, 3, 5, 10]:
    print(f"Hits@{k} : {metrics_KGCModel['both.realistic.hits_at_' + str(k)]}")

MRR: 0.6104003190994263
Hits@1 : 0.43781094527363185
Hits@3 : 0.7189054726368159
Hits@5 : 0.8159203980099502
Hits@10 : 0.9701492537313433


In [None]:
# Train and evaluate DistMult on Nations: 100 epochs, fixed seed,
# all other pipeline settings left at pykeen's defaults.
result_DistMult = pipeline(
    dataset='nations',
    model=DistMult,
    random_seed=1603073093,
    training_kwargs=dict(num_epochs=100),
)

In [24]:
# Build the flat metric mapping once instead of once per printed value.
metrics_DistMult = result_DistMult.metric_results.to_flat_dict()
print(f"MRR: {metrics_DistMult['both.realistic.inverse_harmonic_mean_rank']}")
for k in [1, 3, 5, 10]:
    print(f"Hits@{k} : {metrics_DistMult['both.realistic.hits_at_' + str(k)]}")

MRR: 0.59869384765625
Hits@1 : 0.43034825870646765
Hits@3 : 0.6691542288557214
Hits@5 : 0.8034825870646766
Hits@10 : 0.9776119402985075


In [None]:
# Train and evaluate TransE on Nations: 100 epochs, fixed seed,
# all other pipeline settings left at pykeen's defaults.
result_TransE = pipeline(
    dataset='nations',
    model=TransE,
    random_seed=1603073093,
    training_kwargs=dict(num_epochs=100),
)

In [34]:
# Build the flat metric mapping once instead of once per printed value.
metrics_TransE = result_TransE.metric_results.to_flat_dict()
print(f"MRR: {metrics_TransE['both.realistic.inverse_harmonic_mean_rank']}")
for k in [1, 3, 5, 10]:
    print(f"Hits@{k} : {metrics_TransE['both.realistic.hits_at_' + str(k)]}")

MRR: 0.33192434906959534
Hits@1 : 0.0
Hits@3 : 0.5621890547263682
Hits@5 : 0.7985074626865671
Hits@10 : 0.9751243781094527


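## Сравнение значений метрик моделей TransE, DistMult и KGCModel на датасете FB15k-237

Настройки обучения и набор метрик те же, что и для Nations: 100 эпох, одинаковый `random_seed`, MRR и Hits@k из метрик `both.realistic.*`.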

In [None]:
# Train and evaluate DistMult on FB15k-237: 100 epochs, fixed seed,
# all other pipeline settings left at pykeen's defaults.
result_DistMultModel_FB15k_237 = pipeline(
    dataset='FB15k-237',
    model=DistMult,
    random_seed=1603073093,
    training_kwargs=dict(num_epochs=100),
)

In [39]:
# Build the flat metric mapping once instead of once per printed value.
metrics_DistMultModel_FB15k_237 = result_DistMultModel_FB15k_237.metric_results.to_flat_dict()
print(f"MRR: {metrics_DistMultModel_FB15k_237['both.realistic.inverse_harmonic_mean_rank']}")
for k in [1, 3, 5, 10]:
    print(f"Hits@{k} : {metrics_DistMultModel_FB15k_237['both.realistic.hits_at_' + str(k)]}")

MRR: 0.17628440260887146
Hits@1 : 0.1092572658772874
Hits@3 : 0.1839465701144926
Hits@5 : 0.23326646442900478
Hits@10 : 0.3159311087190527


In [None]:
# Train and evaluate TransE on FB15k-237: 100 epochs, fixed seed,
# all other pipeline settings left at pykeen's defaults.
result_TransEModel_FB15k_237 = pipeline(
    dataset='FB15k-237',
    model=TransE,
    random_seed=1603073093,
    training_kwargs=dict(num_epochs=100),
)

In [41]:
# Build the flat metric mapping once instead of once per printed value.
metrics_TransEModel_FB15k_237 = result_TransEModel_FB15k_237.metric_results.to_flat_dict()
print(f"MRR: {metrics_TransEModel_FB15k_237['both.realistic.inverse_harmonic_mean_rank']}")
for k in [1, 3, 5, 10]:
    print(f"Hits@{k} : {metrics_TransEModel_FB15k_237['both.realistic.hits_at_' + str(k)]}")

MRR: 0.18099358677864075
Hits@1 : 0.11050494177512477
Hits@3 : 0.19214208826695373
Hits@5 : 0.24466679714257755
Hits@10 : 0.32495841080340543


In [None]:
# Train and evaluate KGCModel on FB15k-237: 100 epochs, fixed seed,
# all other pipeline settings left at pykeen's defaults.
result_KGCModel_FB15k_237 = pipeline(
    dataset='FB15k-237',
    model=KGCModel,
    random_seed=1603073093,
    training_kwargs=dict(num_epochs=100),
)

In [38]:
# Build the flat metric mapping once instead of once per printed value.
metrics_KGCModel_FB15k_237 = result_KGCModel_FB15k_237.metric_results.to_flat_dict()
print(f"MRR: {metrics_KGCModel_FB15k_237['both.realistic.inverse_harmonic_mean_rank']}")
for k in [1, 3, 5, 10]:
    print(f"Hits@{k} : {metrics_KGCModel_FB15k_237['both.realistic.hits_at_' + str(k)]}")

MRR: 0.19223463535308838
Hits@1 : 0.11830903219493101
Hits@3 : 0.20669830707505626
Hits@5 : 0.26071533418142673
Hits@10 : 0.34482336823563947
