In [6]:
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns

from scipy.spatial.distance import cdist
from pykeen.triples import TriplesFactory

In [7]:
# Root directory under which all result CSVs of this notebook are written.
output_dir = Path("3_outputs")

# The five lists below are parallel: position k in each list describes one
# graph configuration and they are zipped together by the driver loop.
graph_names = ["Hetionet", "enrichedHetionet"]
# Per graph: names of the dataset folders (read from ./1_outputs/<ds_name>/).
ds_names_list = [["small_DS"], ["large"]]
# Per graph: names used for the result sub-directory, paired with ds_names.
ts_names_list = [["SmallTS"], ["LargeTS"]]
# Per graph: CSV of labeled triples used to resolve the entity label -> id map.
# NOTE(review): both entries point at the same file — confirm this is intended
# for "enrichedHetionet".
triples_files = [
    Path("../../data/Hetionet_training_large.csv"),
    Path("../../data/Hetionet_training_large.csv"),
]
# Per graph: entity label whose embedding is the reference point for distances.
# NOTE(review): identical for both graphs — confirm.
ir_labels = ["Disease::DOID:99999", "Disease::DOID:99999"]

# Embedding models; files are looked up as "<model_name>_X_<i>.npy".
model_names = ["TransE", "CompGCN", "RotatE"]

# Parallel lists: result sub-directory name and the pandas ``query`` expression
# applied to ref_df (None means no filtering).
subset_names = ["WholeSet", "TestSet"]
subset_queries = [None, "test"]

In [8]:
def get_ir_id(triples_file, ir_label):
    """Return the integer entity id that corresponds to ``ir_label``.

    The comma-separated file ``triples_file`` is read as an array of strings
    and handed to ``TriplesFactory.from_labeled_triples``; the id is then
    looked up in the factory's ``entity_to_id`` mapping.

    Parameters
    ----------
    triples_file : path-like
        Comma-delimited text file of labeled triples.
    ir_label : str
        Entity label to resolve.

    Returns
    -------
    The value stored under ``ir_label`` in ``entity_to_id``.

    Raises
    ------
    KeyError
        If ``ir_label`` is not a key of ``entity_to_id``.
    """
    labeled_triples = np.loadtxt(triples_file, dtype=str, delimiter=",")
    factory = TriplesFactory.from_labeled_triples(labeled_triples)
    return factory.entity_to_id[ir_label]


def calc_ir_distances(
    model_name, out_dir, ir_id, metric="euclidean", subset_query=None, n_runs=10
):
    """Compute z-normalised distances of embeddings to a reference entity.

    For every run ``i`` in ``range(n_runs)`` the embedding matrix
    ``<out_dir>/<model_name>_X_<i>.npy`` is loaded, the distance of every row
    to row ``ir_id`` is computed with ``metric``, and the distances are
    standardised with the mean and standard deviation of the distances over
    *all* rows of that matrix. Only the rows selected by ``subset_query`` are
    returned.

    Parameters
    ----------
    model_name : str
        Prefix of the embedding files (``"<model_name>_X_<i>.npy"``).
    out_dir : path-like
        Directory containing ``ref_df.csv``, ``y.npy`` and the embedding files.
    ir_id : int
        Row index of the reference entity inside each embedding matrix.
    metric : str, default "euclidean"
        Any metric accepted by ``scipy.spatial.distance.cdist``. It is used
        both for the reported distances and for the normalising statistics.
    subset_query : str or None, default None
        Expression passed to ``DataFrame.query`` on ``ref_df``; ``None`` keeps
        every row.
    n_runs : int, default 10
        Number of embedding files (runs) to load.

    Returns
    -------
    pandas.DataFrame
        Indexed like the (filtered) ``ref_df``; column ``"irr"`` holds the
        matching entries of ``y.npy`` and there is one column per run named
        ``"<model_name>_X_<i>"`` with the normalised distances.

    Notes
    -----
    The index of ``ref_df.csv`` is used directly to index ``y.npy`` and the
    embedding matrices, i.e. it is assumed to hold integer row positions.
    """
    out_dir = Path(out_dir)

    # The reference table and labels are the same for every run, so they are
    # read (and filtered) once instead of inside the loop.
    ref_df = pd.read_csv(out_dir / "ref_df.csv", index_col=0)
    if subset_query is not None:
        ref_df = ref_df.query(subset_query)
    row_ids = ref_df.index.to_numpy()

    y = np.load(out_dir / "y.npy")[row_ids]

    distances_norm = {}
    for i in range(n_runs):
        name = f"{model_name}_X_{i}"
        X_all = np.load(out_dir / f"{name}.npy")

        x_irr = X_all[ir_id]

        # The same metric must be used for the normalising population and for
        # the reported subset; otherwise the z-score mixes two distance scales.
        dists_all = cdist(X_all, x_irr.reshape(1, -1), metric=metric).ravel()

        # Distances are computed row-wise, so the subset is just a selection
        # from the full distance vector.
        dists = dists_all[row_ids]
        distances_norm[name] = (dists - dists_all.mean()) / dists_all.std()

    return pd.DataFrame(dict(irr=y, **distances_norm), index=ref_df.index)


def plot_distance_boxplot(dist_df, ax=None):
    """Draw a box plot of the per-run mean distance for each ``irr`` group.

    ``dist_df`` is grouped by its ``"irr"`` column, averaged, transposed so
    that each row is one run, and melted to long form with the columns
    ``"irr"`` and ``"distance"`` before being passed to ``sns.boxplot``.

    Parameters
    ----------
    dist_df : pandas.DataFrame
        Frame with an ``"irr"`` column and one distance column per run.
    ax : optional
        Forwarded to ``sns.boxplot`` as ``ax``.

    Returns
    -------
    The object returned by ``sns.boxplot``, after its axis labels and the two
    x tick labels ("Unknown", "IRR") have been set.
    """
    run_means = dist_df.groupby("irr").mean().T
    long_form = run_means.melt(value_name="distance")

    axes = sns.boxplot(long_form, x="irr", y="distance", ax=ax)

    axes.set_xlabel("")
    axes.set_ylabel("Normalized Euclidean Distance")
    axes.set_xticklabels(["Unknown", "IRR"])

    return axes


def aggregate_distances(dist_df, add_mean_row=True):
    """Summarise per-run distances as mean and standard deviation per group.

    Parameters
    ----------
    dist_df : pandas.DataFrame
        Frame with an ``"irr"`` column holding exactly two distinct values and
        one distance column per run (named ``"<prefix>_<i>"``).
    add_mean_row : bool, default True
        If true, append one extra row named ``"<prefix>_all"`` holding the
        column-wise mean of all per-run rows.

    Returns
    -------
    pandas.DataFrame
        One row per run with the columns ``model``, ``distance_mean_unknown``,
        ``distance_mean_irr``, ``distance_std_unknown`` and
        ``distance_std_irr``. The first group in ``groupby`` sort order is
        labelled "unknown", the second "irr". The index is a unique 0..n-1
        range, including the optional mean row.

    Raises
    ------
    ValueError
        If ``"irr"`` does not contain exactly two groups.
    """
    grouped = dist_df.groupby("irr")
    group_means = grouped.mean().T
    group_stds = grouped.std().T

    # The fixed column names below only make sense for two groups; fail with a
    # clear message instead of pandas' generic "Length mismatch".
    if group_means.shape[1] != 2:
        raise ValueError(
            f"expected exactly 2 groups in 'irr', found {group_means.shape[1]}"
        )

    dist_df_agg = pd.concat([group_means, group_stds], axis=1).reset_index()
    dist_df_agg.columns = [
        "model",
        "distance_mean_unknown",
        "distance_mean_irr",
        "distance_std_unknown",
        "distance_std_irr",
    ]

    if add_mean_row:
        # Run names look like "<prefix>_<i>"; drop the trailing run number.
        prefix = "_".join(dist_df_agg["model"].iloc[0].split("_")[:-1])
        mean_row = {
            "model": f"{prefix}_all",
            **dist_df_agg.mean(axis=0, numeric_only=True).to_dict(),
        }
        # ignore_index keeps the index unique; without it the appended row
        # would reuse label 0.
        dist_df_agg = pd.concat(
            [dist_df_agg, pd.DataFrame([mean_row])], ignore_index=True
        )

    return dist_df_agg

In [None]:
# For every graph / dataset / model / subset combination: compute the
# normalised distances to the reference entity, aggregate them, and write both
# tables to <output_dir>/<graph>/<ts_name>/<model>/<subset>/.
for graph_name, triples_file, ir_label, ds_names, ts_names in zip(
    graph_names, triples_files, ir_labels, ds_names_list, ts_names_list
):
    # The entity id depends only on the graph, so resolve it once per graph.
    ir_id = get_ir_id(triples_file=triples_file, ir_label=ir_label)
    for ds_name, ts_name in zip(ds_names, ts_names):
        for model_name in model_names:
            for subset_query, subset_name in zip(subset_queries, subset_names):
                # subset_query must be forwarded; otherwise every subset
                # directory would receive the unfiltered (whole-set) distances.
                dist_df = calc_ir_distances(
                    model_name=model_name,
                    out_dir=Path(f"./1_outputs/{ds_name}/"),
                    ir_id=ir_id,
                    metric="euclidean",
                    subset_query=subset_query,
                )
                agg_dist_df = aggregate_distances(dist_df)

                result_dir = (
                    output_dir / graph_name / ts_name / model_name / subset_name
                )
                result_dir.mkdir(exist_ok=True, parents=True)

                distances_file = result_dir / "distances.csv"
                distances_aggregated_file = result_dir / "distances_aggregated.csv"

                dist_df.to_csv(distances_file, index=True)
                agg_dist_df.to_csv(distances_aggregated_file, index=False)