In [107]:
import sys

sys.path.append("../")

import numpy as np
import pandas as pd
from dynalign.experiments.paths import LP_EVALUATION_RESULTS, PREV_EXPERIMENTS_PATH, DATA_PATH
from pathlib import Path
from typing import List, Dict, Any, Union, Tuple
from collections import defaultdict


def get_dirs_from_path(path: Path, only_files_with_extension: str = "") -> List[Path]:
    """List the entries directly under ``path``.

    Args:
        path: Directory to list.
        only_files_with_extension: When non-empty, only entries whose names end
            with this string are returned (via ``Path.glob("*<ext>")``).

    Returns:
        The matching entries. Without an extension filter, every entry is
        returned except those whose full path string contains ".gitignore".
    """
    if not only_files_with_extension:
        entries = path.iterdir()
        return [entry for entry in entries if ".gitignore" not in str(entry)]
    pattern = "*" + only_files_with_extension
    return list(path.glob(pattern))


# Columns that identify a single result row (kept for use elsewhere in the notebook).
DF_COLUMNS_TO_AGGREGATION = ["run", "embeddings_aggregation"]


def aggregate_aligner_results_last_snapshot(
    df: pd.DataFrame, metric_name: str, precision: int = 3
) -> Dict[str, Tuple[float, float]]:
    """Summarise ``metric_name`` at the latest prediction snapshot.

    Keeps only the rows whose ``prediction_snapshot`` equals the column's
    maximum, groups them by ``embeddings_aggregation`` and computes the mean
    and (sample) standard deviation of ``metric_name`` within each group.

    Args:
        df: Per-run results containing ``prediction_snapshot``,
            ``embeddings_aggregation`` and ``metric_name`` columns.
        metric_name: Name of the metric column to summarise.
        precision: Number of decimal places to round to. It is coerced to
            ``int`` because ``round`` rejects float ``ndigits``.

    Returns:
        Mapping ``embeddings_aggregation -> (mean, std)``, in sorted key order.
        ``std`` is NaN for single-row groups. Returns ``{}`` when no rows remain.
    """
    ndigits = int(precision)
    last_snapshot = df[df["prediction_snapshot"] == df["prediction_snapshot"].max()]
    # Aggregate only the metric column; other columns are irrelevant here.
    stats = last_snapshot.groupby("embeddings_aggregation")[metric_name].agg(["mean", "std"])
    return {
        aggregation: (round(row["mean"], ndigits), round(row["std"], ndigits))
        for aggregation, row in stats.iterrows()
    }


def aggreagte_all_results_last_snapshot(
    paths: List[Path], metric_name: str, precision: int = 3
) -> Dict[str, Dict[str, Any]]:
    """Aggregate the last-snapshot results of every method under ``paths``.

    Each element of ``paths`` is a method directory whose name is used as the
    method name. Every ``*.pkl`` file inside it is read as one dataset's
    results, and the file name without ``.pkl`` is used as the dataset name.

    Args:
        paths: Method result directories.
        metric_name: Metric column to summarise.
        precision: Decimal places for rounding. Previously ignored (always 3).

    Returns:
        ``defaultdict(dict)`` shaped ``{dataset: {method: summary}}``, where
        ``summary`` is the output of ``aggregate_aligner_results_last_snapshot``.
    """
    results = defaultdict(dict)
    for method_results_path in paths:
        method_name = method_results_path.name
        ds_result_files = get_dirs_from_path(
            method_results_path, only_files_with_extension=".pkl"
        )
        for ds_result_file in ds_result_files:
            # The glob above guarantees a ".pkl" suffix, so `stem` strips exactly that.
            ds_name = ds_result_file.stem
            # NOTE: pd.read_pickle can execute arbitrary code; only load trusted files.
            results[ds_name][method_name] = aggregate_aligner_results_last_snapshot(
                df=pd.read_pickle(ds_result_file),
                metric_name=metric_name,
                precision=precision,
            )

    return results


def merge_results_with_prev_results(
    results: Dict[str, Any], prev_results: Dict[str, Any]
) -> Dict[str, Any]:
    """Pair each method's new ("zero") result with its previous ("prev") result.

    For every dataset in ``prev_results``, every method that appears in either
    ``prev_results[ds]`` or ``results[ds]`` gets two entries. Before this fix,
    methods present only in ``results[ds]`` were dropped because the method set
    was unioned with itself.

    The "prev" value is ``prev_results[ds][method]`` if present, otherwise
    ``prev_results[ds][f"{method}_prev"]``, otherwise ``None``. The "zero"
    value is ``results[ds][method]`` or ``None``.

    Returns:
        ``defaultdict(dict)`` mapping ``ds -> {(method, "zero"): ...,
        (method, "prev"): ...}``. Datasets only present in ``results`` are not
        included.
    """
    agg_results = defaultdict(dict)
    for ds, ds_prev in prev_results.items():
        ds_results = results.get(ds, {})
        # dict.fromkeys de-duplicates while keeping a deterministic order (a set would not).
        methods = dict.fromkeys([*ds_prev, *ds_results])
        for method in methods:
            if method in ds_prev:
                prev_result = ds_prev[method]
            else:
                prev_result = ds_prev.get(f"{method}_prev")
            agg_results[ds].update(
                {
                    (method, "zero"): ds_results.get(method),
                    (method, "prev"): prev_result,
                }
            )

    return agg_results


def convert_float_to_str(x: float) -> str:
    """Render ``x`` as fixed-point text with exactly two decimals (0.5 -> "0.50")."""
    return format(x, ".2f")


def percentage_style(
    x: Union[float, Tuple[float, float], List[float]]
) -> Union[float, Tuple[float, float], List[float]]:
    """Percentage style fn: scale fractions by 100 and round to 2 decimals.

    Args:
        x: A real number, or a ``(mean, std)`` pair given as a tuple or list.
            The notebook uses ``[0.0, 0.0]`` lists as placeholders for missing
            results, so lists are accepted as well. Only the first two items of
            a pair are used.

    Returns:
        The scaled number, or the scaled pair in the same container type
        (tuple in, tuple out; list in, list out).

    Raises:
        ValueError: For any other input (e.g. ``None``, strings, booleans).
    """
    if isinstance(x, (tuple, list)):
        scaled = (round(x[0] * 100, 2), round(x[1] * 100, 2))
        return scaled if isinstance(x, tuple) else list(scaled)
    # bool is an int subclass but is not a meaningful percentage.
    if isinstance(x, (int, float, np.integer, np.floating)) and not isinstance(x, bool):
        return round(x * 100, 2)
    raise ValueError("X parsing error")


def highlight_max(x: pd.Series, props: str = "color:red") -> List[str]:
    """``Styler.apply`` helper: style the cell holding the largest mean.

    Each cell is expected to be a ``(mean, std)`` tuple or list. Any other value
    (``None``, NaN from a ``concat`` with missing columns, empty sequences)
    counts as 0 instead of raising. Only the first maximum is styled.

    Args:
        x: One column (or row) passed in by ``Styler.apply``.
        props: CSS applied to the winning cell.

    Returns:
        One CSS string per cell: ``props`` for the maximum, ``""`` otherwise.
    """
    means = [
        cell[0] if isinstance(cell, (tuple, list)) and len(cell) > 0 else 0
        for cell in x.values
    ]
    if not means:
        # np.argmax raises on an empty sequence.
        return []
    best = int(np.argmax(means))
    return [props if i == best else "" for i in range(len(means))]

In [108]:
# NOTE(review): `paths` is only assigned in cell In [110] further down, so this
# cell depends on leftover kernel state and raises NameError on a fresh kernel.
paths

[PosixPath('/Users/kamiltagowski/PycharmProjects/reg-alignment/notebooks/../data/evaluation/lp/Node2Vec'),
 PosixPath('/Users/kamiltagowski/PycharmProjects/reg-alignment/notebooks/../data/evaluation/lp/GNN_AE')]

In [109]:
# NOTE(review): exact duplicate of cell In [108]; like it, this displays `paths`
# before cell In [110] assigns it (NameError on a fresh kernel).
paths

[PosixPath('/Users/kamiltagowski/PycharmProjects/reg-alignment/notebooks/../data/evaluation/lp/Node2Vec'),
 PosixPath('/Users/kamiltagowski/PycharmProjects/reg-alignment/notebooks/../data/evaluation/lp/GNN_AE')]

In [110]:
# Current link-prediction results, plus the "evaluation_full" runs used as baselines.
paths = get_dirs_from_path(LP_EVALUATION_RESULTS)
full_paths = get_dirs_from_path(DATA_PATH / "evaluation_full" / "lp")

results = aggreagte_all_results_last_snapshot(paths, metric_name="auc", precision=3)
full_results = aggreagte_all_results_last_snapshot(full_paths, metric_name="auc", precision=3)
baselines = merge_results_with_prev_results(results=full_results, prev_results=results)
# `prev_results` and `agg_results` are built in the following cell (In [111]).

In [111]:
# Previous-experiment results; the "zero" slot stays empty (no new results passed).
# TODO(review): `prev_paths` is never assigned in this notebook, so this cell
# raises NameError on a fresh kernel. It presumably should list the per-method
# directories of the previous experiments (PREV_EXPERIMENTS_PATH is imported
# but unused). Confirm and define it above.
prev_results = aggreagte_all_results_last_snapshot(
    paths=prev_paths,
    metric_name="auc",
    precision=3,
)
agg_results = merge_results_with_prev_results(
    prev_results=prev_results,
    results={},
)

In [112]:
# Row (method) order used when displaying the per-dataset tables.
order = [
    "Node2Vec",
    "PosthocALL",
    "PosthocEJ",
    "PosthocTB",
    "Node2VecAligned_L2_ALL",
    "Node2VecAligned_L2_EJ",
    "Node2VecAligned_L2_EJ_Weighted",
    "Node2VecAligned_L2_TB",
    "Node2VecAligned_L2_TB_Weighted",
    "GNN_AE",
]
# `order` without the plain Node2Vec and the weighted EJ variant.
# NOTE(review): not used by any cell visible here.
order_prev = [
    method
    for method in order
    if method not in {"Node2Vec", "Node2VecAligned_L2_EJ_Weighted"}
]

In [113]:
# Datasets whose comparison table is displayed.
DATASETS_TO_SHOW = {"fb-messages"}
# Placeholder for missing baseline cells; a tuple, matching real (mean, std) results.
MISSING_RESULT = (0.0, 0.0)

for ds in agg_results.keys():
    if ds not in DATASETS_TO_SHOW:
        continue
    print(ds)
    ds_df = pd.concat(
        [
            pd.DataFrame.from_dict(agg_results[ds]).T.dropna(),
            # Column-wise Series.map instead of the deprecated DataFrame.applymap.
            pd.DataFrame.from_dict(baselines[ds]).T.apply(
                lambda column: column.map(
                    lambda x: x if isinstance(x, tuple) else MISSING_RESULT
                )
            ),
        ],
        axis=0,
    )

    # `.loc` with a list raises KeyError for absent labels; show only methods present.
    methods_present = set(ds_df.index.get_level_values(0))
    rows = [method for method in order if method in methods_present]
    display(ds_df.loc[rows].style.apply(highlight_max))

fb-messages


Unnamed: 0,Unnamed: 1,FILDNE,k-FILDNE,last,mean
Node2Vec,zero,"[0.0, 0.0]","[0.0, 0.0]","(0.656, 0.096)","[0.0, 0.0]"
Node2Vec,prev,"(0.742, 0.041)","(0.65, 0.062)","(0.757, 0.034)","(0.634, 0.063)"
PosthocALL,prev,"(0.729, 0.047)","(0.655, 0.077)","(0.752, 0.051)","(0.653, 0.068)"
PosthocEJ,prev,"(0.73, 0.049)","(0.652, 0.086)","(0.754, 0.057)","(0.647, 0.071)"
PosthocTB,prev,"(0.735, 0.04)","(0.649, 0.077)","(0.761, 0.03)","(0.637, 0.079)"
Node2VecAligned_L2_ALL,prev,"(0.774, 0.031)","(0.665, 0.08)","(0.812, 0.028)","(0.637, 0.081)"
Node2VecAligned_L2_EJ,prev,"(0.78, 0.03)","(0.685, 0.084)","(0.809, 0.029)","(0.656, 0.073)"
Node2VecAligned_L2_EJ_Weighted,prev,"(0.771, 0.041)","(0.709, 0.052)","(0.791, 0.041)","(0.683, 0.061)"
Node2VecAligned_L2_TB,prev,"(0.741, 0.052)","(0.647, 0.095)","(0.778, 0.031)","(0.618, 0.099)"
Node2VecAligned_L2_TB_Weighted,prev,"(0.769, 0.036)","(0.685, 0.086)","(0.806, 0.03)","(0.655, 0.08)"
