In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from scipy.spatial.distance import cdist, pdist, squareform
from pykeen.triples import TriplesFactory

In [2]:
output_dir = Path("3_outputs")

# graph_names = ['Hetionet',]
# triples_files = [Path('./hetionet/data/Hetionet_training_large.csv'), ]
# ir_labels = ['Disease::DOID:99999', ]

graph_names = ["OBL", "enrichedOBL"]
ds_names_list = [["small_DS"], ["large"]]
ts_names_list = [["SmallTS"], ["LargeTS"]]
triples_files = [
    Path("./data/OpenBioLink_training_large.csv"),
    Path("./data/OpenBioLink_training_large.csv"),
]
ir_labels = ["HP:0000855", "HP:0000855"]

model_names = ["TransE", "CompGCN", "RotatE"]

subset_names = ["WholeSet", "TestSet"]
subset_queries = [None, "test"]

In [3]:
def get_ir_id(triples_file, ir_label):
    ds = TriplesFactory.from_labeled_triples(
        np.loadtxt(triples_file, dtype=str, delimiter=","),
    )
    ir_id = ds.entity_to_id[ir_label]

    return ir_id, ds


def calc_ir_distances(
    model_name, out_dir, ir_id, metric="euclidean", subset_query=None
):
    out_dir = Path(out_dir)

    # distances = {}
    distances_norm = {}
    for i in range(10):
        name = f"{model_name}_X_{i}"

        ref_df = pd.read_csv(out_dir / "ref_df.csv", index_col=0)
        y_all = np.load(out_dir / "y.npy")

        X_all = np.load(out_dir / f"{name}.npy")

        if subset_query is None:
            ref_df = ref_df
        else:
            ref_df = ref_df.query(subset_query)

        X = X_all[ref_df.index]
        y = y_all[ref_df.index]

        x_irr = X_all[ir_id,]

        dists_all = cdist(X_all, [x_irr]).ravel()

        dists = cdist(X, [x_irr], metric=metric).ravel()
        distances_norm[name] = (dists - dists_all.mean()) / dists_all.std()

    dist_df = pd.DataFrame(dict(irr=y, **distances_norm), index=ref_df.index)

    return dist_df


def plot_distance_boxplot(dist_df, ax=None):
    plot = sns.boxplot(
        dist_df.groupby("irr").mean().T.melt(value_name="distance"),
        x="irr",
        y="distance",
        ax=ax,
    )
    plot.set_ylabel("Normalized Euclidean Distance")
    plot.set_xlabel("")
    plot.set_xticklabels(["Unknown", "IRR"])

    return plot


def aggregate_distances(dist_df, add_mean_row=True):
    dist_df_agg = pd.concat(
        [dist_df.groupby("irr").mean().T, dist_df.groupby("irr").std().T], axis=1
    ).reset_index()

    dist_df_agg.columns = [
        "model",
        "distance_mean_unknown",
        "distance_mean_irr",
        "distance_std_unknown",
        "distance_std_irr",
    ]

    prefix = "_".join(dist_df_agg["model"][0].split("_")[:-1])
    if add_mean_row:
        dist_df_agg = pd.concat(
            [
                dist_df_agg,
                pd.DataFrame(
                    {
                        c: [v]
                        for c, v in zip(
                            dist_df_agg.columns,
                            [
                                f"{prefix}_all",
                                *dist_df_agg.mean(axis=0, numeric_only=True),
                            ],
                        )
                    }
                ),
            ]
        )

    return dist_df_agg


def get_gene_df(gene_ids, node_ids, email):
    from Bio import Entrez
    from tqdm import tqdm

    Entrez.email = email

    def get_gene_info(gene_ids):
        handle = Entrez.esummary(db="gene", id=",".join(gene_ids))
        result = Entrez.read(handle)
        handle.close()

        return result["DocumentSummarySet"]["DocumentSummary"]

    batch_size = 1000
    gene_infos = []
    for i in tqdm(range(0, len(gene_ids), batch_size)):
        gene_infos.extend(get_gene_info(gene_ids[i : i + batch_size]))

    gene_names = [x["Name"] for x in gene_infos]
    gene_df = pd.DataFrame(
        {
            "gene_name": gene_names,
            "gene_label": gene_ids,
            "gene_id": node_ids,
        }
    )
    return gene_df

In [12]:
pd.read_csv("3_outputs/enrichedOBL/LargeTS/RotatE/TestSet/distances.csv")

Unnamed: 0.1,Unnamed: 0,gene_name,irr,RotatE_X_0,RotatE_X_1,RotatE_X_2,RotatE_X_3,RotatE_X_4,RotatE_X_5,RotatE_X_6,RotatE_X_7,RotatE_X_8,RotatE_X_9
0,69993,A1BG,False,0.882485,1.029728,-0.640524,0.803798,0.873876,1.727445,0.510214,1.804071,1.263317,1.129373
1,69994,NAT2,False,0.786405,1.375793,-0.119487,1.177939,0.824003,0.397911,0.512779,1.426479,0.404191,0.292733
2,69995,ADA,False,-0.222524,0.720210,-0.719634,0.680083,-0.320620,0.632203,0.091515,0.946580,0.970093,0.677298
3,69996,CDH2,False,0.667808,-0.338023,-0.545086,0.054583,0.864080,1.006927,0.225499,1.181191,0.495307,0.307320
4,69997,AKT3,True,1.195810,1.010652,-0.537919,1.595100,1.026286,0.736791,0.438266,0.753597,0.242062,0.741571
...,...,...,...,...,...,...,...,...,...,...,...,...,...
19570,89563,PTBP3,False,0.190042,1.336074,-0.046641,2.597270,2.019986,1.265438,0.004463,-0.078691,0.493429,0.294553
19571,89564,KCNE2,False,2.794968,1.120276,0.144781,0.266173,2.664760,0.883789,0.683059,0.582614,-0.204707,0.713300
19572,89565,DGCR2,False,1.268925,0.899767,0.286595,0.803269,1.611377,0.421352,1.853347,0.472057,1.265944,0.864980
19573,89566,CASP8AP2,False,0.882334,0.496093,-0.366668,1.226200,1.988349,0.815381,0.880678,0.921882,1.025551,0.055802


In [None]:
for graph_name, triples_file, ir_label, ds_names, ts_names in zip(
    graph_names, triples_files, ir_labels, ds_names_list, ts_names_list
):
    ir_id, ds = get_ir_id(triples_file=triples_file, ir_label=ir_label)
    for ds_name, ts_name in zip(ds_names, ts_names):
        for model_name in model_names:
            for subset_query, subset_name in zip(subset_queries, subset_names):
                dist_df = calc_ir_distances(
                    model_name=model_name,
                    out_dir=Path(f"./1_outputs/{ds_name}/"),
                    ir_id=ir_id,
                    metric="euclidean",
                )
                agg_dist_df = aggregate_distances(dist_df)

                result_dir = (
                    output_dir / graph_name / ts_name / model_name / subset_name
                )
                result_dir.mkdir(exist_ok=True, parents=True)

                distances_file = result_dir / "distances.csv"
                distances_aggregated_file = result_dir / "distances_aggregated.csv"

                gene_df = get_gene_df(
                    [ds.entity_id_to_label[_id].split(":")[1] for _id in dist_df.index],
                    dist_df.index,
                    email="thzo@novonordisk.com",
                )
                gene_df = gene_df.set_index("gene_id")
                dist_df = dist_df.join(gene_df[["gene_name"]])
                dist_df = dist_df.iloc[:, [-1, *np.arange(0, dist_df.shape[1] - 1)]]
                dist_df.to_csv(distances_file, index=True)
                agg_dist_df.to_csv(distances_aggregated_file, index=False)