In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported project modules
# are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys

# Extend the module search path two directory levels up before the `mgi` imports below
# (presumably so the project package resolves from the notebook's folder — TODO confirm).
sys.path.append("../../")
from dataclasses import asdict
import warnings
import logging

import srsly
# NOTE: from here on `print` is rich's print, not the builtin.
from rich import print
import pandas as pd
import seaborn as sns
import torch
from omegaconf import OmegaConf
from tqdm.auto import tqdm
from pykeen.datasets import Dataset
from pykeen.predict import predict_triples
from IPython.display import display
from mgi.data.sampled_datasets import load_sampled_datasets_metadata

from mgi.defaults import ROOT_PATH

# Make ROOT_PATH the working directory for the rest of the notebook; this runs before the
# remaining project imports below.
os.chdir(ROOT_PATH)

from mgi.utils.config import load_training_config
# NOTE(review): ROOT_PATH is imported a second time here; only PLOTS_PATH is new.
from mgi.defaults import ROOT_PATH, PLOTS_PATH

In [None]:
warnings.filterwarnings("ignore")
logging.disable(logging.WARNING)

In [None]:
# Output folder for the figures of this analysis; created (with parents) if missing.
# Assumes PLOTS_PATH is a pathlib.Path — TODO confirm.
plots_path = PLOTS_PATH.joinpath("loss_analysis")
plots_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Global seaborn look for every figure in this notebook.
sns.set_theme(
    style="whitegrid",
    context="paper",
    font_scale=1.1,
)

In [None]:
# Collect the training configuration of every non-debug training-items file.
configs = []
training_items_dir = ROOT_PATH / "experiments/configs/training/training_items"

# os.scandir yields entries in arbitrary, platform-dependent order; sorting by file name
# keeps the order of `configs` (and of everything derived from it) identical across runs.
# The `with` block releases the directory handle even if a file fails to parse.
with os.scandir(training_items_dir) as dir_entries:
    for items_file in sorted(dir_entries, key=lambda dir_entry: dir_entry.name):
        if "debug" in items_file.name:
            continue
        # Each file must hold exactly one top-level key; the single-element unpacking
        # raises ValueError otherwise.
        [config_training_items] = srsly.read_yaml(items_file).values()
        configs += [load_training_config(**item) for item in config_training_items]

In [None]:
# Turn every config into a plain container with interpolations resolved, then tabulate
# them: one row per training configuration.
configs = [
    OmegaConf.to_container(training_config, resolve=True)
    for training_config in configs
]
configs_df = pd.DataFrame(data=configs)

In [None]:
# Look up the sampling metadata of each row's `ds_dataset`, flatten it into columns and
# append those columns to the config table.
# Presumably the loader returns a mapping from dataset name to a metadata dataclass — TODO confirm.
metadata_lookup = load_sampled_datasets_metadata()
metadata_records = (
    configs_df["ds_dataset"]
    .map(metadata_lookup)
    # Rows without metadata stay missing; everything else becomes a plain dict.
    .apply(lambda meta: meta if pd.isna(meta) else asdict(meta))
)
sampling = pd.json_normalize(metadata_records).drop(columns="name")
configs_df = pd.concat([configs_df, sampling], axis=1)

In [None]:
# Restrict the analysis to configurations whose sampling strategy is "triple".
configs_df = configs_df.loc[configs_df["sampling_config.sampling"].eq("triple")]

In [None]:
# Score the training triples of every trained run (run 0) and keep the "same-as" triples,
# flagging each one as correct when head and tail name the same entity.
preds_all = []

for gk_dataset in tqdm(["WN18RR", "FB15K237", "WD50K"]):
    # Configurations trained on this dataset whose sampled dataset name starts with it.
    selected_configs = configs_df[
        configs_df.ds_dataset.str.startswith(gk_dataset) & (configs_df.gk_dataset == gk_dataset)
    ]
    # list(...) lets tqdm know the total number of entries.
    for entry in tqdm(list(selected_configs.itertuples())):
        run_dir = entry.experiment_dir / "runs" / "0"
        dataset_dir = run_dir / "dataset"
        if not dataset_dir.exists():
            # Experiments without a stored dataset are skipped.
            continue

        dataset = Dataset.from_directory_binary(dataset_dir)
        # NOTE(review): torch.load unpickles the stored object — only load checkpoints you
        # trust. Newer torch releases default to weights_only=True, in which case loading a
        # whole pickled model needs an explicit weights_only=False.
        model = torch.load(run_dir / "model.pt", map_location=torch.device("cpu"))

        preds = (
            predict_triples(model=model, triples=dataset.training)
            .process(factory=dataset.training)
            .df
        )
        # .copy() gives an independent frame, so the column assignments below never write
        # into a slice of the full prediction table.
        preds = preds[preds.relation_label == "same-as"].copy()
        # regex=False pins literal replacement regardless of the installed pandas default.
        head_entity = preds.head_label.str.replace("left:", "", regex=False)
        tail_entity = preds.tail_label.str.replace("right:", "", regex=False)
        preds["correct"] = head_entity == tail_entity
        preds["experiment"] = entry.experiment_name
        preds["loss"] = entry.loss
        preds["ds_dataset"] = entry.ds_dataset
        preds["gk_dataset"] = entry.gk_dataset
        preds_all.append(preds)

In [None]:
# Compare the scores of correct vs. incorrect linking triples under both losses:
# first an aggregate table, then one box plot per sampled dataset.
if not preds_all:
    # pd.concat would raise a generic ValueError; say what actually went wrong.
    raise ValueError("preds_all is empty: no experiment produced predictions to analyse")
df = pd.concat(preds_all)

to_plot = df.sort_values(by=["loss", "correct"])
# Loss names missing from this mapping become NaN.
to_plot["loss"] = to_plot["loss"].map({"nssa": "standard", "weighted_nssa": "weighted"})

to_plot["linking triple"] = to_plot.correct.apply(lambda x: "correct" if x else "incorrect")

grouped = to_plot.groupby(["loss", "linking triple"]).agg({"score": ["mean", "median", "std"]})
unstacked = grouped.unstack()
unstacked["score mean diff"] = (
    unstacked[("score", "mean", "correct")] - unstacked[("score", "mean", "incorrect")]
)
unstacked["score median diff"] = (
    unstacked[("score", "median", "correct")] - unstacked[("score", "median", "incorrect")]
)
display(unstacked)

# FacetGrid.map calls the plotting function once per facet, so without an explicit order
# each facet infers category positions from its own subset and facets can disagree.
linking_order = ["incorrect", "correct"]
loss_order = ["standard", "weighted"]

# reset_index() because the concatenated frame carries repeated index labels.
g = sns.FacetGrid(to_plot.reset_index(), col="ds_dataset", col_wrap=3, sharey=False, height=2.4)
g.map(
    sns.boxplot,
    "linking triple",
    "score",
    "loss",
    order=linking_order,
    hue_order=loss_order,
    showfliers=False,
)
g.add_legend()
sns.move_legend(g, "lower center", bbox_to_anchor=(0.5, 1), ncol=3, title=None, frameon=True)

# axes_dict maps each facet's `ds_dataset` value to its axes, so the default
# "ds_dataset = <value>" title does not have to be parsed back.
for facet_dataset, ax in g.axes_dict.items():
    title_parts = facet_dataset.split("_")
    # Assumes the third "_"-separated token holds two digits encoding a fraction,
    # e.g. "05" -> "p=0.5" — TODO confirm against the dataset naming scheme.
    title_parts[2] = f"p={title_parts[2][0]}.{title_parts[2][1]}"
    ax.set_title(", ".join(title_parts))

# `figure` replaces the deprecated `fig` attribute of FacetGrid.
fig = g.figure
fig.tight_layout()
g.savefig(plots_path / "loss_comparison.png", dpi=600)

In [None]:
# Same score statistics as the aggregate table, but split per sampled dataset; the
# "linking triple" level is moved into the columns so correct/incorrect sit side by side.
grouped = to_plot.groupby(by=["ds_dataset", "loss", "linking triple"]).agg(
    {"score": ["mean", "median", "std"]}
)
unstacked = grouped.unstack(level="linking triple")
# Gap between correct and incorrect linking triples for each statistic.
for diff_column, statistic in (("score mean diff", "mean"), ("score median diff", "median")):
    correct_scores = unstacked[("score", statistic, "correct")]
    incorrect_scores = unstacked[("score", statistic, "incorrect")]
    unstacked[diff_column] = correct_scores - incorrect_scores

In [None]:
# For each linking-triple class, report how much the mean score under the "standard" loss
# differs from the one under the "weighted" loss, relative to the "standard" value and
# averaged over the sampled datasets.
# Relies on `unstacked` being indexed by (ds_dataset, loss), as built in the preceding cell.
for x in ["correct", "incorrect"]:
    # Mean-score column for this class, with the (ds_dataset, loss) index turned into columns.
    d = unstacked[("score", "mean", x)].reset_index()
    # One row per ds_dataset and one column per loss; the labels are tuples because the
    # source columns are a three-level MultiIndex.
    d = d.pivot(
        index=[("ds_dataset", "", "")], columns=[("loss", "", "")], values=[("score", "mean", x)]
    )
    # (standard - weighted) / standard per dataset, then the mean over datasets.
    diff = (
        (d[(("score", "mean", x), "standard")] - d[(("score", "mean", x), "weighted")])
        / d[(("score", "mean", x), "standard")]
    ).mean()
    # `print` is rich's print (imported at the top of the notebook).
    print(f"Diff in mean score for {x} relations {diff:.2%}")