In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import getpass
import os
import sys
import time

sys.path.append("../src")
import matplotlib.pyplot as plt
import pandas as pd
import pykeen
import torch
from pykeen.pipeline import pipeline
from pykeen.datasets import get_dataset

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Dataset under study; the alternative tried earlier was:
# ds = get_dataset(dataset="FB15k237")
ds = get_dataset(dataset="Nations")

# Evaluated but not rendered: a cell only displays its last expression.
ds.num_entities, ds.num_relations

# Keep a handle on each split of the dataset.
training = ds.training
testing = ds.testing
validation = ds.validation

# Pooling all splits into a single tensor would look like this:
# all_triples = torch.cat((training.mapped_triples, testing.mapped_triples, validation.mapped_triples))

# Cell output: number of training triples.
len(training.mapped_triples)

1592

In [9]:
import pandas as pd
def test_model(dataset="Nations", model="TransE", num_epochs=100, random_seed=1235, device="cpu"):
    """Train ``model`` on ``dataset`` with and without perturbation and compare metrics.

    Two pipelines are run with identical settings that differ only in the
    ``use_perturbation`` flag. The perturbed run is executed first, then the
    unperturbed one.

    Parameters
    ----------
    dataset : str
        Dataset name forwarded to ``pipeline(dataset=...)``.
    model : str
        Model name forwarded to ``pipeline(model=...)``.
    num_epochs : int
        Number of training epochs for each run (default 100; shouldn't take
        more than a minute or two on a nice computer).
    random_seed : int
        Seed forwarded to both pipeline runs so they are comparable.
    device : str
        Device forwarded to both pipeline runs.

    Returns
    -------
    pandas.DataFrame
        The metric table of the unperturbed run with its ``Value`` column
        replaced by two columns, ``Normal`` and ``Perturbed``.

    Notes
    -----
    The perturbed values are matched to the unperturbed rows through the
    non-``Value`` columns of the metric table rather than by row position, so
    a different row order between the two tables cannot pair a value with the
    wrong metric. A metric that is absent from the perturbed table shows up as
    NaN. If the key columns are missing or not unique, the values are paired
    by position instead.
    """

    def run(use_perturbation):
        # NOTE(review): `use_perturbation` / `perturbator` are presumably
        # provided by the locally modified pykeen pipeline -- confirm.
        return pipeline(
            dataset=dataset,
            model=model,
            random_seed=random_seed,
            device=device,
            training_kwargs=dict(num_epochs=num_epochs),
            negative_sampler="basic",
            use_perturbation=use_perturbation,
            perturbator="random",
        )

    # Same execution order as before: perturbed run first, then the baseline.
    result_perturb = run(True)
    result_normal = run(False)

    normal_df = result_normal.metric_results.to_df()
    perturb_df = result_perturb.metric_results.to_df()

    key_cols = [col for col in normal_df.columns if col != "Value"]
    all_df = normal_df.drop(columns=["Value"])
    all_df["Normal"] = normal_df["Value"].values

    # Key-based alignment is only well defined when every key column exists in
    # both tables and identifies each row uniquely.
    keys_usable = (
        bool(key_cols)
        and all(col in perturb_df.columns for col in key_cols)
        and not normal_df.duplicated(subset=key_cols).any()
        and not perturb_df.duplicated(subset=key_cols).any()
    )
    if keys_usable:
        wanted_rows = normal_df.set_index(key_cols).index
        perturbed = perturb_df.set_index(key_cols)["Value"].reindex(wanted_rows).values
    else:
        # Fall back to pairing rows by position.
        perturbed = perturb_df["Value"].values

    all_df["Perturbed"] = perturbed
    return all_df

In [11]:
%%capture tqdmm
results_transe = test_model("Nations", "TransE")
results_rotate = test_model("Nations", "RotatE")
results_transd = test_model("Nations", "TransD")
results_conve = test_model("Nations", "ConvE")
results_complex = test_model("Nations", "ComplEx")
print("Done!")

In [12]:
# results_complex

Unnamed: 0,Side,Type,Metric,Normal,Perturbed
0,head,avg,adjusted_mean_rank,0.987182,1.000077
1,tail,avg,adjusted_mean_rank,0.857059,0.821176
2,both,avg,adjusted_mean_rank,0.922121,0.910626
3,head,avg,mean_rank,4.557214,4.562189
4,head,avg,mean_reciprocal_rank,0.301029,0.294364
5,head,avg,hits_at_1,0.004975,0.0
6,head,avg,hits_at_3,0.497512,0.477612
7,head,avg,hits_at_5,0.706468,0.696517
8,head,avg,hits_at_10,0.955224,0.950249
9,head,best,mean_rank,4.557214,4.562189


In [9]:
def display_all(results: list, models: list):
    """Combine per-model comparison tables into one wide table.

    Parameters
    ----------
    results : list of pandas.DataFrame
        One table per model, each with ``Normal`` and ``Perturbed`` columns
        plus the columns that identify a metric row.
    models : list of str
        Model labels, parallel to ``results``; used as column-name prefixes.

    Returns
    -------
    pandas.DataFrame
        The identifying columns of ``results[0]`` followed, for every model,
        by ``<name>_normal``, ``<name>_perturbed`` and ``<name>_diff``
        (normal minus perturbed).

    Notes
    -----
    Row labels come from ``results[0]``. Every table is matched to those rows
    through the identifying columns, not by row position, so a table whose
    rows are in a different order still lands on the right metric. A metric
    missing from a table yields NaN. When the identifying columns are absent
    or do not identify rows uniquely, values are attached by position.
    The input frames are not modified.
    """
    reference = results[0]
    metrics = reference.drop(columns=["Normal", "Perturbed"])
    key_cols = list(metrics.columns)

    # Row identity of the reference table, or None if it cannot be used as a key.
    target_rows = None
    if key_cols and not reference.duplicated(subset=key_cols).any():
        target_rows = reference.set_index(key_cols).index

    for frame, name in zip(results, models):
        aligned = frame
        if (
            target_rows is not None
            and all(col in frame.columns for col in key_cols)
            and not frame.duplicated(subset=key_cols).any()
        ):
            # Reorder this model's rows to match the reference table.
            aligned = frame.set_index(key_cols).reindex(target_rows)

        normal = aligned["Normal"].values
        perturbed = aligned["Perturbed"].values
        metrics[f"{name}_normal"] = normal
        metrics[f"{name}_perturbed"] = perturbed
        metrics[f"{name}_diff"] = normal - perturbed
    return metrics
# Pair each model label with its comparison table, then split the pairs back
# into the two parallel lists that display_all takes.
names, results = map(
    list,
    zip(
        ("TransE", results_transe),
        ("RotatE", results_rotate),
        ("TransD", results_transd),
        ("ConvE", results_conve),
        ("ComplEx", results_complex),
    ),
)

all_results = display_all(results, names)

In [11]:
# Render the combined table: one row per metric, three columns
# (normal, perturbed, diff) per model.
all_results

Unnamed: 0,Side,Type,Metric,TransE_normal,TransE_perturbed,TransE_diff,RotatE_normal,RotatE_perturbed,RotatE_diff,TransD_normal,TransD_perturbed,TransD_diff,ConvE_normal,ConvE_perturbed,ConvE_diff,ComplEx_normal,ComplEx_perturbed,ComplEx_diff
0,tail,avg,adjusted_mean_rank,0.857059,0.821176,0.035884,0.744914,0.718223,0.026691,0.81862,0.805658,0.012963,0.493267,0.531049,-0.037781,0.840149,0.924635,-0.084486
1,head,avg,adjusted_mean_rank,0.987182,1.000077,-0.012895,0.961463,0.958712,0.002751,0.964128,0.954529,0.009599,0.753187,0.80712,-0.053933,0.870323,0.912059,-0.041735
2,both,avg,adjusted_mean_rank,0.922121,0.910626,0.011495,0.853189,0.838468,0.014721,0.891374,0.880093,0.011281,0.623227,0.669084,-0.045857,0.855236,0.918347,-0.06311
3,tail,worst,mean_rank,3.945274,3.766169,0.179104,3.537313,3.40796,0.129353,3.791045,3.726368,0.064677,2.263682,2.492537,-0.228856,4.084577,4.557214,-0.472637
4,tail,worst,mean_reciprocal_rank,0.327227,0.337933,-0.010707,0.49711,0.51676,-0.01965,0.344585,0.359672,-0.015086,0.695319,0.647886,0.047433,0.441699,0.346454,0.095245
5,tail,worst,hits_at_1,0.004975,0.0,0.004975,0.288557,0.318408,-0.029851,0.029851,0.029851,0.0,0.542289,0.482587,0.059701,0.243781,0.124378,0.119403
6,tail,worst,hits_at_3,0.557214,0.577114,-0.0199,0.636816,0.616915,0.0199,0.547264,0.59204,-0.044776,0.79602,0.771144,0.024876,0.487562,0.427861,0.059701
7,tail,worst,hits_at_5,0.79602,0.800995,-0.004975,0.791045,0.80597,-0.014925,0.81592,0.781095,0.034826,0.905473,0.910448,-0.004975,0.731343,0.666667,0.064677
8,tail,worst,hits_at_10,0.9801,0.99005,-0.00995,0.975124,0.9801,-0.004975,0.975124,0.985075,-0.00995,1.0,0.99005,0.00995,0.955224,0.970149,-0.014925
9,tail,best,mean_rank,3.945274,3.766169,0.179104,3.537313,3.40796,0.129353,3.791045,3.726368,0.064677,2.263682,2.492537,-0.228856,4.084577,4.557214,-0.472637


In [17]:
# Hetio: use degree-preserving random permutations, so that each entity still has the same degree.
# How much noise can be added before we see degradation?
# Preserve node types when permuting.