In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import logging

In [None]:
import sys

# Extend the import search path two directories up; presumably this is what
# makes the `mgi` package imported below resolvable — TODO confirm.
sys.path.append("../../")

from collections import defaultdict
import warnings

import pandas as pd
from rich import print
import numpy as np
from IPython.display import display
from tqdm.auto import tqdm

from mgi.data.sampled_datasets import load_sampled_datasets_metadata
from mgi.data.datasets.wn18rrdecoded import WN18RRDecoded
from mgi.data.datasets.fb15k237decoded import FB15K237Decoded
from mgi.data.datasets.wd50k import WD50K
from mgi.data.datasets.conceptnet import ConceptNet
from mgi.data.datasets.yago import YAGO310
from mgi.data.datasets.dataset_utils import get_ds_dataset

In [None]:
# Silence every Python warning and disable all logging calls of severity
# WARNING and below for the rest of the session.
# NOTE(review): this blanket filter also hides deprecation warnings raised by
# the cells below — consider narrowing it.
warnings.filterwarnings("ignore")
logging.disable(logging.WARNING)

In [None]:
# Metadata of the sampled datasets. The cells below iterate over it to obtain
# dataset names and read `.sampling_config` from each entry.
ds_dataset_metadatas = load_sampled_datasets_metadata()

In [None]:
def get_num_of_entities(dataset, subset):
    """Count the distinct values in columns 0 and 2 of a split's triples.

    Args:
        dataset: Object exposing the requested split as an attribute.
        subset: Attribute name of the split, e.g. ``"training"``.

    Returns:
        Number of unique values pooled across columns 0 and 2 (presumably the
        head and tail entities) of ``<split>.triples``.
    """
    split = getattr(dataset, subset)
    head_and_tail = split.triples[:, [0, 2]]
    # np.unique flattens its input, so both columns are counted together.
    return np.unique(head_and_tail).size


def get_num_of_relations(dataset, subset):
    """Count the distinct values in column 1 of a split's triples.

    Args:
        dataset: Object exposing the requested split as an attribute.
        subset: Attribute name of the split, e.g. ``"training"``.

    Returns:
        Number of unique values in column 1 (presumably the relation) of
        ``<split>.triples``.
    """
    split = getattr(dataset, subset)
    relation_column = split.triples[:, 1]
    return np.unique(relation_column).size

# Original

In [None]:
# Collect per-split statistics (triples, entities, relations) for every
# original dataset, then show them as a notebook table and as LaTeX source.
data = []

datasets = {
    "WN18RR": WN18RRDecoded,
    "FB15K237": FB15K237Decoded,
    "WD50K": WD50K,
    "ConceptNet": ConceptNet,
    "YAGO310": YAGO310,
}

for name, dataset_cls in tqdm(list(datasets.items())):
    dataset = dataset_cls.from_path()
    data.append(
        {
            "dataset": name,
            "train_triples": dataset.dataset.training.num_triples,
            "val_triples": dataset.dataset.validation.num_triples,
            "test_triples": dataset.dataset.testing.num_triples,
            "train_entities": get_num_of_entities(dataset, "training"),
            "val_entities": get_num_of_entities(dataset, "validation"),
            "test_entities": get_num_of_entities(dataset, "testing"),
            "train_relations": get_num_of_relations(dataset, "training"),
            "val_relations": get_num_of_relations(dataset, "validation"),
            "test_relations": get_num_of_relations(dataset, "testing"),
        }
    )

df = pd.DataFrame(data)

# Turn "dataset" into an ordered categorical *before* sorting so the rows
# follow the presentation order of `datasets` rather than alphabetical order.
df["dataset"] = pd.Categorical(df["dataset"], categories=list(datasets), ordered=True)
# sort_values returns a new frame; the result must be assigned to take effect.
df = df.sort_values("dataset")

display(df)

# Raw strings keep the LaTeX backslashes literal instead of relying on
# invalid escape sequences.
print(
    df.style.format_index(axis=1, formatter="${}$".format)
    .hide(axis=0)
    .to_latex(convert_css=True)
    .replace("%", r"\%")
    .replace("±", r"\pm")
)

# Sampled

In [None]:
# For every sampled-dataset name, fetch one dataset per entry in `seeds`
# via get_ds_dataset(name, seed) and keep them grouped by name.
seeds = [121371, 59211, 44185]
datasets = defaultdict(list)

for name in ds_dataset_metadatas:
    datasets[name].extend(get_ds_dataset(name, seed) for seed in seeds)

In [None]:
# One record per loaded sampled dataset: the sampling configuration of its
# name plus the triple / entity / relation counts of each split.
data = []

for name, list_datasets in tqdm(datasets.items()):
    for ds_dataset in list_datasets:
        # Start from the sampling configuration, then add the statistics in
        # the same column order as before.
        record = dict(ds_dataset_metadatas[name].sampling_config)
        record.update(
            train_triples=ds_dataset.dataset.training.num_triples,
            val_triples=ds_dataset.dataset.validation.num_triples,
            test_triples=ds_dataset.dataset.testing.num_triples,
            train_entities=get_num_of_entities(ds_dataset, "training"),
            val_entities=get_num_of_entities(ds_dataset, "validation"),
            test_entities=get_num_of_entities(ds_dataset, "testing"),
            train_relations=get_num_of_relations(ds_dataset, "training"),
            val_relations=get_num_of_relations(ds_dataset, "validation"),
            test_relations=get_num_of_relations(ds_dataset, "testing"),
        )
        data.append(record)

df = pd.DataFrame(data)

In [None]:
def add_empty_rows_on_dataset_change(df):
    """Insert a separator row in front of each run of rows sharing a dataset.

    Walks ``df`` from top to bottom and, whenever the value in the
    ``"dataset"`` column differs from the previous row's, emits an extra row
    in which every cell holds that dataset value (so, despite the function's
    name, the inserted row is not empty), followed by the original rows.

    Args:
        df: Frame with a ``"dataset"`` column. The columns may be flat, or a
            MultiIndex in which ``"dataset"`` selects exactly one column.

    Returns:
        A new DataFrame built from the separator rows and the original rows,
        in order.
    """
    rows = []
    prev_dataset = None
    for _, row in df.iterrows():
        current_dataset = row["dataset"]
        # With MultiIndex columns the lookup yields a one-element Series;
        # with flat columns it may already be a plain Python value.
        if hasattr(current_dataset, "item"):
            current_dataset = current_dataset.item()
        if current_dataset != prev_dataset:
            rows.append(pd.Series([current_dataset] * len(df.columns), index=df.columns))
        rows.append(row)
        prev_dataset = current_dataset

    return pd.DataFrame(rows)

In [None]:
# Aggregate the per-seed statistics into mean and standard deviation for each
# (dataset, sampling, p) combination.
grouped = df.groupby(["dataset", "sampling", "p"]).agg(["mean", "std"])

to_display = grouped.reset_index()
to_display = to_display[["dataset", "sampling", "p"]].copy()

# Number of decimals used in the "mean(std)" cells.
decimal_places = 0
for col in grouped.columns.levels[0]:
    to_display[col] = (
        (grouped[[(col, "mean"), (col, "std")]])
        # Each row holds (mean, std) under tuple labels, so read the two
        # values by position explicitly with .iloc.
        .apply(lambda x: f"{x.iloc[0]:.{decimal_places}f}({x.iloc[1]:.{decimal_places}f})", axis=1)
        .tolist()
    )
# Ordered categoricals make the sort below follow presentation order.
# NOTE(review): values not listed in `categories` become NaN — confirm the
# sampled datasets only cover these three names.
to_display["dataset"] = pd.Categorical(
    to_display["dataset"], categories=["WN18RR", "FB15K237", "WD50K"], ordered=True
)
to_display["sampling"] = pd.Categorical(
    to_display["sampling"], categories=["triple", "node", "relation"], ordered=True
)
to_display = to_display.sort_values(["dataset", "sampling"])
# Insert a row carrying the dataset name ahead of each dataset's rows, then
# drop the now-redundant "dataset" column.
to_display = add_empty_rows_on_dataset_change(to_display)
to_display = to_display.drop(columns=["dataset"])
display(to_display)
# Raw strings keep the LaTeX backslashes literal instead of relying on
# invalid escape sequences.
print(
    to_display.style.format(precision=1)
    .format_index(axis=1, formatter="${}$".format)
    .hide(axis=0)
    .to_latex(convert_css=True)
    .replace("%", r"\%")
    .replace("±", r"\pm")
)