In [3]:
from typing import Union
from pathlib import Path

import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype

from sklearn import metrics

from src.utils import paths


# Dataset directory names iterated by display_table_3, in table row order.
DATASETS = [
    "robustness",
    "sensitivity",
    "monotonicity",
]
# Ordered categorical of the capitalised dataset names, so grouping/sorting
# follows the order of DATASETS rather than alphabetical order.
DATASETS_CAT = CategoricalDtype(list(map(str.capitalize, DATASETS)), ordered=True)

# Experiment directory name -> display label. This dict is the single source
# of truth; its insertion order is the table order.
EXPERIMENTS2HUMAN = {
    "0-baseline": "Baseline",
    "1-connectivity": "+Connectivity",
    "2-source-dest": "+SourceDest",
    "3-node-type": "+NodeType",
    "4-edge-aware": "+EdgeAware",
}
EXPERIMENTS = list(EXPERIMENTS2HUMAN)
EXPERIMENTS_HUMANIZE = list(EXPERIMENTS2HUMAN.values())
EXPERIMENTS_CAT = CategoricalDtype(EXPERIMENTS_HUMANIZE, ordered=True)

# Fold sub-directory names: fold_0 ... fold_4.
FOLDS = ["fold_{}".format(idx) for idx in range(5)]


def fetch_scores(path: Union[str, Path], threshold: float = 0.5) -> dict:
    """Compute binary-classification test metrics for one evaluation run.

    Reads ``<path>/logs/test_predictions.csv``. The file must have a
    ``Prediction`` column of scores (presumably probabilities in [0, 1] --
    TODO confirm) and a ``Target`` column with the 0/1 labels.

    Args:
        path: Directory of a single evaluation run (e.g. one fold).
        threshold: Scores strictly greater than this are predicted as class 1.
            The default 0.5 reproduces the previous ``Prediction.round()``
            behaviour for scores in [0, 1]: pandas/numpy round half to even,
            so a score of exactly 0.5 maps to class 0.

    Returns:
        Dict mapping metric name to value:
        - AUROC, computed from the raw scores.
        - ACC, WF1, MCC, SENS (recall of class 1) and SPEC (recall of
          class 0), computed from the thresholded predictions.
    """
    csv_path = Path(path) / "logs" / "test_predictions.csv"
    df = pd.read_csv(csv_path)

    probas = df["Prediction"].to_numpy()
    # Explicit threshold instead of round(): rounding would turn scores outside
    # [0, 1] (e.g. logits) into labels such as -1 or 2, silently changing the
    # metrics below from binary to multiclass.
    preds = (probas > threshold).astype(int)
    targets = df["Target"].to_numpy()

    return {
        "AUROC": metrics.roc_auc_score(targets, probas),
        "ACC": metrics.accuracy_score(targets, preds),
        # NOTE(review): f1_score defaults to average="binary", i.e. the F1 of
        # the positive class -- not a weighted F1 as the "WF1" key suggests.
        # Kept unchanged so reported numbers do not move; confirm intent.
        "WF1": metrics.f1_score(targets, preds),
        "MCC": metrics.matthews_corrcoef(targets, preds),
        "SENS": metrics.recall_score(targets, preds, pos_label=1),
        "SPEC": metrics.recall_score(targets, preds, pos_label=0),
    }


def _collect_fold_scores() -> pd.DataFrame:
    """Return one row of test metrics per (dataset, experiment, fold) run."""
    rows = []
    for dataset in DATASETS:
        for exp in EXPERIMENTS:
            for fold in FOLDS:
                run_dir = paths.EXPS_DIR / dataset / exp / "eval" / fold
                rows.append({
                    "Dataset": dataset.capitalize(),
                    "Experiment": EXPERIMENTS2HUMAN[exp],
                    "Fold": int(fold.split("_")[-1]),
                } | fetch_scores(run_dir))
    # Ordered categoricals make the later groupby emit rows in the declared
    # DATASETS / EXPERIMENTS order instead of alphabetically.
    return pd.DataFrame(rows).astype(
        {"Dataset": DATASETS_CAT, "Experiment": EXPERIMENTS_CAT}
    )


def display_table_3() -> pd.DataFrame:
    """Build Table 3: per-fold test metrics aggregated per dataset and experiment.

    For every dataset in DATASETS, experiment in EXPERIMENTS and fold in FOLDS,
    scores are read via ``fetch_scores`` from
    ``paths.EXPS_DIR / <dataset> / <experiment> / "eval" / <fold>``. Every run
    directory must exist, otherwise ``fetch_scores`` fails.

    Returns:
        DataFrame with:
        - index: (Dataset, Experiment);
        - columns: (metric, "mean" | "std");
        - values: the mean and the sample std (pandas' default ddof=1) over
          folds, rounded to 3 decimals.
    """
    return (
        _collect_fold_scores()
        # The fold index is only an identifier: drop it before aggregating
        # rather than computing (and then discarding) its mean/std.
        .drop(columns="Fold")
        # observed=True: only (dataset, experiment) pairs that actually have
        # runs become rows. It also silences pandas' FutureWarning about the
        # changing default for categorical groupers.
        .groupby(["Dataset", "Experiment"], observed=True)
        .agg(["mean", "std"])
        .round(3)
    )


In [4]:
# Table 3: mean and std of each test metric across folds, per dataset and experiment.
display_table_3()

Unnamed: 0_level_0,Unnamed: 1_level_0,AUROC,AUROC,ACC,ACC,WF1,WF1,MCC,MCC,SENS,SENS,SPEC,SPEC
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std
Dataset,Experiment,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
Robustness,Baseline,0.618,0.008,0.547,0.182,0.639,0.2,0.113,0.032,0.536,0.259,0.607,0.273
Robustness,+Connectivity,0.862,0.005,0.753,0.016,0.837,0.013,0.423,0.01,0.74,0.022,0.829,0.022
Robustness,+SourceDest,0.941,0.002,0.903,0.009,0.941,0.006,0.663,0.021,0.916,0.012,0.824,0.025
Robustness,+NodeType,0.952,0.002,0.917,0.007,0.95,0.005,0.705,0.008,0.93,0.013,0.841,0.032
Robustness,+EdgeAware,0.953,0.001,0.919,0.006,0.951,0.004,0.71,0.01,0.93,0.009,0.848,0.014
Sensitivity,Baseline,0.635,0.006,0.424,0.133,0.41,0.014,0.132,0.059,0.838,0.155,0.295,0.221
Sensitivity,+Connectivity,0.819,0.003,0.748,0.014,0.576,0.008,0.425,0.011,0.723,0.029,0.755,0.026
Sensitivity,+SourceDest,0.936,0.004,0.878,0.007,0.764,0.01,0.687,0.014,0.831,0.014,0.893,0.012
Sensitivity,+NodeType,0.947,0.003,0.893,0.006,0.791,0.009,0.723,0.012,0.849,0.017,0.907,0.01
Sensitivity,+EdgeAware,0.949,0.002,0.896,0.007,0.796,0.008,0.73,0.011,0.853,0.018,0.91,0.014
