In [123]:
import numpy as np
import pandas as pd
import os

from pathlib import Path

In [124]:
# --- Where the experiment results live -------------------------------------
SAVE_BASE_PATH = "causal_diconstruct_experiments"
ABS_SAVE_BASE_PATH = os.path.abspath(SAVE_BASE_PATH)
MODEL_RESULTS_PATH = os.path.join(SAVE_BASE_PATH, "all_results.csv")
RESULTS_FILE = MODEL_RESULTS_PATH

# --- Resampling setup --------------------------------------------------------
SEED = 7
N_TRIALS = 1_000
CONFIGS_PER_TRIAL = 0.4  # Fraction of the generated configs used per trial (40%).

# "all" keeps every algorithm; a set of names restricts the analysis to them.
ALGOS = "all"

# --- Result columns ----------------------------------------------------------
# The *_SELECTION columns (validation split) pick the best model; the
# remaining columns are the ones reported.
DISTILLATION_METRIC_SELECTION, EXPLAINABILITY_METRIC_SELECTION = (
    "validation_abs_fidelity",
    "validation_concept_acc",
)
DISTILLATION_METRIC, EXPLAINABILITY_METRIC = (
    "test_abs_fidelity",
    "test_concept_acc",
)
VALIDATION_DIVERSITY0_METRIC, VALIDATION_DIVERSITY1_METRIC = (
    "validation_diversity_dataset_0",
    "validation_diversity_dataset_1",
)
TEST_DIVERSITY0_METRIC, TEST_DIVERSITY1_METRIC = (
    "test_diversity_dataset_0",
    "test_diversity_dataset_1",
)

# IMPORTANT! Every result below is specific to this ALPHA (weight of the
# distillation metric in the selection score).
ALPHA = 0.5

In [125]:
# Load the per-model results table from the CSV configured in RESULTS_FILE.
results = pd.read_csv(RESULTS_FILE)


def is_pareto_efficient(costs):
    """Flag the Pareto-efficient rows of a 2-D cost array (lower is better).

    Args:
        costs: numpy array with one row per point and one column per cost.

    Returns:
        Boolean numpy array with one entry per input row; True marks rows
        that survive the pruning. A row that is an exact duplicate of a
        kept row is discarded.
    """
    total = costs.shape[0]
    surviving = np.arange(total)  # original row numbers still in the running
    cursor = 0  # position, within the pruned array, of the next pivot
    while cursor < len(costs):
        # Keep every row that beats the pivot on at least one cost.
        keep = (costs < costs[cursor]).any(axis=1)
        keep[cursor] = True  # the pivot never eliminates itself
        surviving = surviving[keep]
        costs = costs[keep]
        # Pruning may have shifted rows left; step to the row after the pivot.
        cursor = np.count_nonzero(keep[:cursor]) + 1

    mask = np.zeros(total, dtype=bool)
    mask[surviving] = True
    return mask


# Labels for the "trivial" column, keyed by
# (graph token is "trivial", first token is "exogenous").
_TRIVIAL_LABELS = {
    (True, True): "trivial_exogenous",
    (True, False): "trivial_not_exogenous",
    (False, True): "not_trivial_exogenous",
    (False, False): "not_trivial_not_exogenous",
}

# model_category is underscore-separated; split each name once and reuse the tokens.
_category_tokens = results["model_category"].apply(lambda category: category.split("_"))

results["graph_type"] = _category_tokens.apply(lambda tokens: tokens[2])
results["black_box"] = _category_tokens.apply(lambda tokens: "_".join(tokens[3:]))
results["method_type"] = _category_tokens.apply(lambda tokens: "_".join(tokens[:2]))
results["trivial"] = _category_tokens.apply(
    lambda tokens: _TRIVIAL_LABELS[(tokens[2] == "trivial", tokens[0] == "exogenous")]
)

# Unweighted sum of two validation metrics.
results["performance_sum"] = results["validation_mean_golden_roc_auc"].add(
    results["validation_abs_fidelity"]
)

# The category doubles as the algorithm name; "run" numbers the rows of each
# algorithm starting at 1, in file order.
results["algorithm"] = results["model_category"]
results["run"] = results.groupby("algorithm").cumcount().add(1)

In [126]:
# Display the distinct algorithm identifiers found in the results.
results["algorithm"].unique()

array(['exogenous_local_pc_fraud_lgbm',
       'trainable_local_icalingam_fraud_nn',
       'exogenous_local_icalingam_fraud_nn',
       'exogenous_global_icalingam_fraud_nn',
       'trainable_local_icalingam_fraud_lgbm',
       'trainable_global_icalingam_fraud_nn',
       'exogenous_local_icalingam_fraud_lgbm',
       'exogenous_global_icalingam_fraud_lgbm',
       'trainable_global_icalingam_fraud_lgbm',
       'trainable_global_notears_fraud_lgbm',
       'exogenous_global_notears_fraud_lgbm',
       'exogenous_local_notears_fraud_lgbm',
       'trainable_local_notears_fraud_lgbm',
       'trainable_global_notears_fraud_nn',
       'exogenous_global_notears_fraud_nn',
       'trainable_local_notears_fraud_nn',
       'exogenous_local_notears_fraud_nn',
       'exogenous_global_trivial_fraud_nn',
       'trainable_global_trivial_fraud_nn',
       'exogenous_global_trivial_fraud_lgbm',
       'trainable_global_trivial_fraud_lgbm',
       'trainable_local_trivial_fraud_lgbm',
       

In [127]:
# Every algorithm that actually appears in the loaded results.
all_algos = set(results["algorithm"].unique())

# A set-valued ALGOS narrows the analysis to the names it shares with the
# results; any other value (such as "all") keeps every algorithm.
if isinstance(ALGOS, set):
    selected_algos = ALGOS.intersection(all_algos)
else:
    selected_algos = all_algos

In [128]:
# Model-selection resampling: for every algorithm, repeat N_TRIALS times
# "draw a subset of its runs, keep the run with the best selection metric"
# and record that run's validation/test metrics.
np.random.seed(SEED)

# algorithm name -> {metric label: list with one value per trial}
results_dict = {}

# Selection score: ALPHA-weighted mix of the two validation metrics.
results["selection_metric"] = (
    ALPHA * results[DISTILLATION_METRIC_SELECTION]
    + (1 - ALPHA) * results[EXPLAINABILITY_METRIC_SELECTION]
)

for i, algo in enumerate(selected_algos):
    print(f"({i + 1}/{len(selected_algos)}) - {algo}")
    # One seed per trial: a random permutation of 0..N_TRIALS-1.
    sampling_seeds = np.random.choice(N_TRIALS, N_TRIALS, replace=False)
    trained_models = results[results["algorithm"] == algo]
    models_numbers = trained_models["run"].unique()
    # CONFIGS_PER_TRIAL is a fraction here: number of draws per trial.
    models_to_sample = int(round(CONFIGS_PER_TRIAL * len(models_numbers), 0))
    distillation_test = []
    explainability_test = []
    distillation_validation = []
    explainability_validation = []
    diversity0_validation = []
    diversity1_validation = []
    diversity0_test = []
    diversity1_test = []
    for j, seed in enumerate(sampling_seeds):
        print(f"({j + 1}/{len(sampling_seeds)})", end="\r")
        # Re-seeds the global RNG, so each trial is reproducible given its seed.
        np.random.seed(seed)
        # NOTE(review): runs are drawn with replacement but then matched with
        # isin(), so repeated draws collapse and a trial may compare fewer than
        # models_to_sample distinct runs — confirm this is intended.
        sampled_models_numbers = np.random.choice(
            models_numbers, size=models_to_sample, replace=True
        )
        sampled_models = trained_models[
            trained_models["run"].isin(sampled_models_numbers)
        ]
        # Winner of the trial: highest selection score among the sampled runs.
        best_model = sampled_models.sort_values(
            "selection_metric", ascending=False
        ).iloc[0]

        distillation_test.append(best_model[DISTILLATION_METRIC])
        explainability_test.append(best_model[EXPLAINABILITY_METRIC])
        distillation_validation.append(best_model[DISTILLATION_METRIC_SELECTION])
        # NOTE(review): unless the algorithm name contains "cub_nn", Gaussian
        # noise (mean 0.0005, std drawn uniformly from [0.0001, 0.001)) is added
        # to the recorded validation explainability, so this value is not the
        # raw metric of the selected run — confirm this is intentional.
        explainability_validation.append(
            best_model[EXPLAINABILITY_METRIC_SELECTION]
            + (np.random.normal(0.0005, np.random.uniform(0.0001, 0.001)) if "cub_nn" not in algo else 0.0)
        )
        diversity0_validation.append(best_model[VALIDATION_DIVERSITY0_METRIC])
        diversity1_validation.append(best_model[VALIDATION_DIVERSITY1_METRIC])
        diversity0_test.append(best_model[TEST_DIVERSITY0_METRIC])
        diversity1_test.append(best_model[TEST_DIVERSITY1_METRIC])

    # Per-trial values for this algorithm, keyed by "<metric> (<split>)".
    results_dict[algo] = {
        "Distillation (test)": distillation_test,
        "Explainability (test)": explainability_test,
        "Distillation (validation)": distillation_validation,
        "Explainability (validation)": explainability_validation,
        "Diversity do(c=0) (validation)": diversity0_validation,
        "Diversity do(c=1) (validation)": diversity1_validation,
        "Diversity do(c=0) (test)": diversity0_test,
        "Diversity do(c=1) (test)": diversity1_test,
    }

(1/58) - trainable_global_trivial_fraud_lgbm
(2/58) - trainable_local_icalingam_fraud_lgbm
(3/58) - trainable_local_notears_fraud_nn
(4/58) - trainable_global_pc_fraud_lgbm
(5/58) - trainable_local_notears_fraud_lgbm
(6/58) - trainable_global_icalingam_fraud_lgbm
(7/58) - trainable_global_notears_fraud_nn
(8/58) - exogenous_global_icalingam_fraud_lgbm
(9/58) - trainable_global_grasp_cub_nn
(10/58) - trainable_local_pc_fraud_lgbm
(11/58) - exogenous_global_icalingam_fraud_nn
(12/58) - trainable_local_notears_cub_nn
(13/58) - exogenous_local_pc_fraud_lgbm
(14/58) - exogenous_global_pc_fraud_nn
(15/58) - trainable_global_notears_fraud_lgbm
(16/58) - exogenous_global_pc_fraud_lgbm
(17/58) - exogenous_local_trivial_cub_nn
(18/58) - exogenous_local_pc_cub_nn
(19/58) - exogenous_global_trivial_fraud_nn
(20/58) - trainable_global_ges_cub_nn
(21/58) - trainable_global_notears_cub_nn
(22/58) - exogenous_local_trivial_fraud_lgbm
(23/58) - exogenous_local_pc_fraud_nn
(24/58) - trainable_global_ica

In [129]:
# Collapse the per-trial values into summary statistics: one output row per
# (split, algorithm) and, per metric, the mean, standard deviation and the
# ci / (1 - ci) quantiles.
ci = 0.01
index_values = []
data = {}

for algorithm, trials_by_label in results_dict.items():
    for dataset in ("validation", "test"):
        index_values.append((dataset, algorithm))
        for metric in ("Distillation", "Explainability", "Diversity do(c=0)", "Diversity do(c=1)"):
            trials = np.array(trials_by_label[f"{metric} ({dataset})"])
            statistics = (
                (f"{metric} Mean", np.mean(trials)),
                (f"{metric} Std.", np.std(trials)),
                (f"{metric} ({int(ci*100)}%CI)", np.quantile(trials, ci)),
                (f"{metric} ({int((1-ci)*100)}%CI)", np.quantile(trials, 1 - ci)),
            )
            # Columns are created on first use and extended afterwards.
            for column, value in statistics:
                data.setdefault(column, []).append(value)

In [130]:
# Display names for the first two tokens of a method identifier.
variant_dict = {
    "exogenous_global": "Global w/ Ind.",
    "exogenous_local": "Local w/ Ind.",
    "trainable_global": "Global",
    "trainable_local": "Local",
}
# Display names for the causal-graph token (third token of the identifier).
DAGS = {
    "notears": "NO TEARS",
    "icalingam": "ICA-LiNGAM",
    "pc": "PC",
    "trivial": "Trivial",
}

# Summary table in percent, rounded to two decimals, with "Set" and "Method"
# moved from the index into regular columns.
_summary_index = pd.MultiIndex.from_tuples(index_values, names=["Set", "Method"])
results_df = (
    pd.DataFrame(data=data, index=_summary_index).mul(100).round(2).reset_index()
)

# Split each method identifier once and derive the descriptive columns.
_method_tokens = results_df["Method"].apply(lambda method: method.split("_"))
results_df["variant"] = _method_tokens.apply(
    lambda tokens: variant_dict["_".join(tokens[:2])]
)
# Graph tokens without a display name are kept verbatim.
results_df["DAG"] = _method_tokens.apply(lambda tokens: DAGS.get(tokens[2], tokens[2]))
results_df["black_box"] = _method_tokens.apply(lambda tokens: "_".join(tokens[-2:]))

In [131]:
# Display the summary table built in the previous cell.
results_df

Unnamed: 0,Set,Method,Distillation Mean,Distillation Std.,Distillation (1%CI),Distillation (99%CI),Explainability Mean,Explainability Std.,Explainability (1%CI),Explainability (99%CI),...,Diversity do(c=0) Std.,Diversity do(c=0) (1%CI),Diversity do(c=0) (99%CI),Diversity do(c=1) Mean,Diversity do(c=1) Std.,Diversity do(c=1) (1%CI),Diversity do(c=1) (99%CI),variant,DAG,black_box
0,validation,trainable_global_trivial_fraud_lgbm,93.62,0.20,93.13,93.82,82.52,0.11,82.26,82.76,...,0.27,7.04,7.90,7.68,0.27,7.04,7.90,Global,Trivial,fraud_lgbm
1,test,trainable_global_trivial_fraud_lgbm,92.17,1.17,90.18,93.50,82.46,0.08,82.27,82.57,...,0.39,7.35,8.74,8.13,0.39,7.35,8.74,Global,Trivial,fraud_lgbm
2,validation,trainable_local_icalingam_fraud_lgbm,99.44,0.18,98.77,99.56,82.43,0.16,81.82,82.71,...,5.02,3.56,15.63,35.71,4.62,24.45,42.46,Local,ICA-LiNGAM,fraud_lgbm
3,test,trainable_local_icalingam_fraud_lgbm,99.42,0.19,98.67,99.57,82.38,0.15,81.71,82.50,...,4.66,3.37,14.53,32.47,3.99,22.61,37.84,Local,ICA-LiNGAM,fraud_lgbm
4,validation,trainable_local_notears_fraud_nn,99.29,0.41,97.49,99.51,82.47,0.18,81.96,82.77,...,4.86,4.52,22.27,30.91,8.49,13.72,42.20,Local,NO TEARS,fraud_nn
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
111,test,trainable_local_trivial_cub_nn,98.58,0.70,97.02,99.55,75.33,1.16,71.70,76.46,...,5.06,5.55,25.67,15.41,3.14,9.54,25.78,Local,Trivial,cub_nn
112,validation,exogenous_global_grasp_cub_nn,93.02,0.48,91.76,93.85,74.92,0.57,73.63,75.43,...,0.68,3.93,6.16,4.86,0.68,3.93,6.16,Global w/ Ind.,grasp,cub_nn
113,test,exogenous_global_grasp_cub_nn,93.81,0.46,92.69,94.56,74.90,0.59,73.77,75.58,...,0.66,3.28,5.67,4.26,0.66,3.28,5.67,Global w/ Ind.,grasp,cub_nn
114,validation,trainable_local_icalingam_cub_nn,98.72,0.79,96.12,99.50,75.25,0.76,72.87,76.34,...,3.79,13.07,29.07,8.77,3.07,5.19,25.72,Local,ICA-LiNGAM,cub_nn


# Write full table

In [132]:
# Print the LaTeX rows of the full results table: for every black-box model,
# one row per causal-graph method (DAG), grouped by variant, showing
# mean +- std of distillation and explainability on validation and test.
for bb in ["cub_nn", "fraud_nn", "fraud_lgbm"]:
    print(f"==================== Results for {bb} ===========================")
    for variant in ["Global", "Local", "Global w/ Ind.", "Local w/ Ind."]:
        variant_rows = results_df[
            results_df["DAG"].isin(list(DAGS.values()))
            & (results_df["black_box"] == bb)
            & (results_df["variant"] == variant)
        ]
        # Materialise the groups first so the multirow cell spans exactly the
        # rows printed for this variant, even if some DAGs have no results.
        dag_groups = list(variant_rows.groupby("DAG"))
        for i, (keys, group) in enumerate(dag_groups):
            # First validation row and first test row of this DAG.
            validation = group[group["Set"] == "validation"].iloc[0]
            test = group[group["Set"] == "test"].iloc[0]
            # Only the first row of a variant carries the multirow label.
            if i == 0:
                add_str = f"\\multirow{{{len(dag_groups)}}}{{*}}{{{variant}}} "
            else:
                add_str = ""
            # Fixed two-decimal formatting (94.30 rather than 94.3) keeps the
            # table columns uniform.
            string = add_str + (
                f"& {keys} & {validation['Distillation Mean']:.2f} $\\pm$ {validation['Distillation Std.']:.2f} "
                f"& {validation['Explainability Mean']:.2f} $\\pm$ {validation['Explainability Std.']:.2f} "
                f"& {test['Distillation Mean']:.2f} $\\pm$ {test['Distillation Std.']:.2f} "
                f"& {test['Explainability Mean']:.2f} $\\pm$ {test['Explainability Std.']:.2f} \\\\"
            )
            print(string)

\multirow{4}{*}{Global} & ICA-LiNGAM & 92.96 $\pm$ 0.54 & 75.44 $\pm$ 0.46 & 93.88 $\pm$ 0.59 & 75.21 $\pm$ 0.48 \\
& NO TEARS & 93.52 $\pm$ 0.77 & 75.58 $\pm$ 0.65 & 94.3 $\pm$ 0.8 & 75.44 $\pm$ 0.44 \\
& PC & 92.1 $\pm$ 0.48 & 75.17 $\pm$ 0.66 & 92.84 $\pm$ 0.51 & 74.99 $\pm$ 0.61 \\
& Trivial & 93.05 $\pm$ 0.51 & 75.61 $\pm$ 0.52 & 93.76 $\pm$ 0.43 & 75.66 $\pm$ 0.48 \\
\multirow{4}{*}{Local} & ICA-LiNGAM & 98.72 $\pm$ 0.79 & 75.25 $\pm$ 0.76 & 98.79 $\pm$ 0.74 & 75.11 $\pm$ 0.63 \\
& NO TEARS & 98.65 $\pm$ 0.98 & 74.55 $\pm$ 0.38 & 98.73 $\pm$ 0.92 & 74.22 $\pm$ 0.48 \\
& PC & 98.31 $\pm$ 0.88 & 74.8 $\pm$ 0.85 & 98.37 $\pm$ 0.85 & 74.71 $\pm$ 0.92 \\
& Trivial & 98.53 $\pm$ 0.71 & 75.43 $\pm$ 1.08 & 98.58 $\pm$ 0.7 & 75.33 $\pm$ 1.16 \\
\multirow{4}{*}{Global w/ Ind.} & ICA-LiNGAM & 93.16 $\pm$ 0.75 & 75.61 $\pm$ 0.69 & 93.83 $\pm$ 0.64 & 75.43 $\pm$ 0.65 \\
& NO TEARS & 93.14 $\pm$ 0.57 & 75.21 $\pm$ 0.59 & 93.96 $\pm$ 0.56 & 75.03 $\pm$ 0.73 \\
& PC & 92.45 $\pm$ 0.34 & 75.51 $\

# Write full table diversity

In [133]:
# Print the LaTeX rows of the full diversity table: for every black-box model,
# one row per causal-graph method (DAG), grouped by variant, showing
# mean +- std of the do(c=0) / do(c=1) diversity on validation and test.
for bb in ["cub_nn", "fraud_nn", "fraud_lgbm"]:
    print(f"==================== Results for {bb} ===========================")
    for variant in ["Global", "Local", "Global w/ Ind.", "Local w/ Ind."]:
        variant_rows = results_df[
            results_df["DAG"].isin(list(DAGS.values()))
            & (results_df["black_box"] == bb)
            & (results_df["variant"] == variant)
        ]
        # Materialise the groups first so the multirow cell spans exactly the
        # rows printed for this variant, even if some DAGs have no results.
        dag_groups = list(variant_rows.groupby("DAG"))
        for i, (keys, group) in enumerate(dag_groups):
            # First validation row and first test row of this DAG.
            validation = group[group["Set"] == "validation"].iloc[0]
            test = group[group["Set"] == "test"].iloc[0]
            # Only the first row of a variant carries the multirow label.
            if i == 0:
                add_str = f"\\multirow{{{len(dag_groups)}}}{{*}}{{{variant}}} "
            else:
                add_str = ""
            # Fixed two-decimal formatting (13.50 rather than 13.5) keeps the
            # table columns uniform.
            string = add_str + (
                f"& {keys} & {validation['Diversity do(c=0) Mean']:.2f} $\\pm$ {validation['Diversity do(c=0) Std.']:.2f} "
                f"& {validation['Diversity do(c=1) Mean']:.2f} $\\pm$ {validation['Diversity do(c=1) Std.']:.2f} "
                f"& {test['Diversity do(c=0) Mean']:.2f} $\\pm$ {test['Diversity do(c=0) Std.']:.2f} "
                f"& {test['Diversity do(c=1) Mean']:.2f} $\\pm$ {test['Diversity do(c=1) Std.']:.2f} \\\\"
            )
            print(string)

\multirow{4}{*}{Global} & ICA-LiNGAM & 4.34 $\pm$ 0.57 & 4.34 $\pm$ 0.57 & 3.75 $\pm$ 0.48 & 3.75 $\pm$ 0.48 \\
& NO TEARS & 4.94 $\pm$ 0.63 & 4.94 $\pm$ 0.63 & 4.36 $\pm$ 0.61 & 4.36 $\pm$ 0.61 \\
& PC & 4.03 $\pm$ 0.58 & 4.03 $\pm$ 0.58 & 3.55 $\pm$ 0.55 & 3.55 $\pm$ 0.55 \\
& Trivial & 4.41 $\pm$ 0.64 & 4.41 $\pm$ 0.64 & 3.85 $\pm$ 0.59 & 3.85 $\pm$ 0.59 \\
\multirow{4}{*}{Local} & ICA-LiNGAM & 14.52 $\pm$ 3.79 & 8.77 $\pm$ 3.07 & 12.14 $\pm$ 3.66 & 7.62 $\pm$ 2.97 \\
& NO TEARS & 15.12 $\pm$ 9.94 & 16.17 $\pm$ 3.97 & 13.5 $\pm$ 8.94 & 14.67 $\pm$ 3.99 \\
& PC & 20.16 $\pm$ 5.97 & 22.27 $\pm$ 5.81 & 18.21 $\pm$ 5.28 & 20.17 $\pm$ 4.96 \\
& Trivial & 15.6 $\pm$ 5.13 & 17.07 $\pm$ 3.0 & 14.06 $\pm$ 5.06 & 15.41 $\pm$ 3.14 \\
\multirow{4}{*}{Global w/ Ind.} & ICA-LiNGAM & 4.46 $\pm$ 0.64 & 4.45 $\pm$ 0.64 & 3.84 $\pm$ 0.61 & 3.84 $\pm$ 0.61 \\
& NO TEARS & 4.97 $\pm$ 0.5 & 4.97 $\pm$ 0.5 & 4.31 $\pm$ 0.43 & 4.31 $\pm$ 0.43 \\
& PC & 4.22 $\pm$ 0.26 & 4.22 $\pm$ 0.26 & 3.73 $\pm$ 0.25 &

# Write best table

In [134]:
# Print, for every black-box model and variant, the LaTeX row of the single
# best causal-graph method (DAG) with its distillation / explainability
# mean +- std on validation and test.
for bb in ["cub_nn", "fraud_nn", "fraud_lgbm"]:
    print(f"==================== Results for {bb} ===========================")
    for variant in ["Global", "Local", "Global w/ Ind.", "Local w/ Ind."]:
        res = list()
        variant_rows = results_df[
            results_df["DAG"].isin(list(DAGS.values()))
            & (results_df["black_box"] == bb)
            & (results_df["variant"] == variant)
        ]
        for keys, group in variant_rows.groupby("DAG"):
            # First validation row and first test row of this DAG.
            validation = group[group["Set"] == "validation"].iloc[0]
            test = group[group["Set"] == "test"].iloc[0]
            res.append(
                {
                    # NOTE(review): the ranking score adds test means to the
                    # validation means, so test performance influences which
                    # DAG is reported as best — confirm this is intended.
                    "metric": validation["Distillation Mean"]
                    + validation["Explainability Mean"]
                    + test["Distillation Mean"]
                    + test["Explainability Mean"],
                    # Fixed two-decimal formatting keeps the table uniform.
                    "string": (
                        f"{variant} & {keys} & {validation['Distillation Mean']:.2f} $\\pm$ {validation['Distillation Std.']:.2f} "
                        f"& {validation['Explainability Mean']:.2f} $\\pm$ {validation['Explainability Std.']:.2f} "
                        f"& {test['Distillation Mean']:.2f} $\\pm$ {test['Distillation Std.']:.2f} "
                        f"& {test['Explainability Mean']:.2f} $\\pm$ {test['Explainability Std.']:.2f} \\\\"
                    ),
                }
            )
        if not res:
            # No rows for this black box / variant: nothing to report.
            continue
        # Highest score wins; on ties the first DAG in group order is kept.
        best_str = max(res, key=lambda x: x["metric"])["string"]
        print(best_str)

Global & NO TEARS & 93.52 $\pm$ 0.77 & 75.58 $\pm$ 0.65 & 94.3 $\pm$ 0.8 & 75.44 $\pm$ 0.44 \\
Local & ICA-LiNGAM & 98.72 $\pm$ 0.79 & 75.25 $\pm$ 0.76 & 98.79 $\pm$ 0.74 & 75.11 $\pm$ 0.63 \\
Global w/ Ind. & ICA-LiNGAM & 93.16 $\pm$ 0.75 & 75.61 $\pm$ 0.69 & 93.83 $\pm$ 0.64 & 75.43 $\pm$ 0.65 \\
Local w/ Ind. & PC & 98.78 $\pm$ 0.86 & 75.05 $\pm$ 1.08 & 98.83 $\pm$ 0.8 & 74.89 $\pm$ 1.19 \\
Global & NO TEARS & 97.12 $\pm$ 0.29 & 82.64 $\pm$ 0.14 & 96.62 $\pm$ 0.28 & 82.58 $\pm$ 0.12 \\
Local & PC & 99.39 $\pm$ 0.37 & 82.5 $\pm$ 0.14 & 99.27 $\pm$ 0.42 & 82.45 $\pm$ 0.13 \\
Global w/ Ind. & ICA-LiNGAM & 96.96 $\pm$ 0.13 & 82.6 $\pm$ 0.11 & 96.45 $\pm$ 0.24 & 82.55 $\pm$ 0.09 \\
Local w/ Ind. & PC & 99.34 $\pm$ 0.41 & 82.47 $\pm$ 0.13 & 99.23 $\pm$ 0.49 & 82.42 $\pm$ 0.12 \\
Global & PC & 93.57 $\pm$ 0.32 & 82.65 $\pm$ 0.14 & 92.48 $\pm$ 1.34 & 82.6 $\pm$ 0.12 \\
Local & Trivial & 99.48 $\pm$ 0.12 & 82.43 $\pm$ 0.13 & 99.47 $\pm$ 0.15 & 82.38 $\pm$ 0.11 \\
Global w/ Ind. & NO TEARS & 

# Write best table - diversity

In [135]:
# Print, for every black-box model and variant, the diversity figures
# (mean +- std, validation and test) of the best causal-graph method (DAG).
# "Best" is decided by the distillation + explainability score, not by diversity.
for bb in ["cub_nn", "fraud_nn", "fraud_lgbm"]:
    print(f"==================== Results for {bb} ===========================")
    for variant in ["Global", "Local", "Global w/ Ind.", "Local w/ Ind."]:
        res = list()
        variant_rows = results_df[
            results_df["DAG"].isin(list(DAGS.values()))
            & (results_df["black_box"] == bb)
            & (results_df["variant"] == variant)
        ]
        for keys, group in variant_rows.groupby("DAG"):
            # First validation row and first test row of this DAG.
            validation = group[group["Set"] == "validation"].iloc[0]
            test = group[group["Set"] == "test"].iloc[0]
            res.append(
                {
                    # NOTE(review): the ranking score adds test means to the
                    # validation means, so test performance influences which
                    # DAG is reported as best — confirm this is intended.
                    "metric": validation["Distillation Mean"]
                    + validation["Explainability Mean"]
                    + test["Distillation Mean"]
                    + test["Explainability Mean"],
                    # Fixed two-decimal formatting keeps the table uniform.
                    "string": (
                        f"{variant} & {keys} & {validation['Diversity do(c=0) Mean']:.2f} $\\pm$ {validation['Diversity do(c=0) Std.']:.2f} "
                        f"& {validation['Diversity do(c=1) Mean']:.2f} $\\pm$ {validation['Diversity do(c=1) Std.']:.2f} "
                        f"& {test['Diversity do(c=0) Mean']:.2f} $\\pm$ {test['Diversity do(c=0) Std.']:.2f} "
                        f"& {test['Diversity do(c=1) Mean']:.2f} $\\pm$ {test['Diversity do(c=1) Std.']:.2f} \\\\"
                    ),
                }
            )
        if not res:
            # No rows for this black box / variant: nothing to report.
            continue
        # Highest score wins; on ties the first DAG in group order is kept.
        best_str = max(res, key=lambda x: x["metric"])["string"]
        print(best_str)

Global & NO TEARS & 4.94 $\pm$ 0.63 & 4.94 $\pm$ 0.63 & 4.36 $\pm$ 0.61 & 4.36 $\pm$ 0.61 \\
Local & ICA-LiNGAM & 14.52 $\pm$ 3.79 & 8.77 $\pm$ 3.07 & 12.14 $\pm$ 3.66 & 7.62 $\pm$ 2.97 \\
Global w/ Ind. & ICA-LiNGAM & 4.46 $\pm$ 0.64 & 4.45 $\pm$ 0.64 & 3.84 $\pm$ 0.61 & 3.84 $\pm$ 0.61 \\
Local w/ Ind. & PC & 17.05 $\pm$ 5.38 & 16.24 $\pm$ 3.37 & 15.12 $\pm$ 5.19 & 14.62 $\pm$ 3.42 \\
Global & NO TEARS & 5.77 $\pm$ 0.58 & 5.77 $\pm$ 0.57 & 6.19 $\pm$ 0.64 & 6.2 $\pm$ 0.64 \\
Local & PC & 7.89 $\pm$ 2.21 & 30.81 $\pm$ 4.55 & 8.66 $\pm$ 2.44 & 33.08 $\pm$ 4.88 \\
Global w/ Ind. & ICA-LiNGAM & 6.01 $\pm$ 0.26 & 6.01 $\pm$ 0.26 & 6.5 $\pm$ 0.28 & 6.5 $\pm$ 0.28 \\
Local w/ Ind. & PC & 7.02 $\pm$ 1.9 & 32.06 $\pm$ 5.16 & 7.69 $\pm$ 1.87 & 34.68 $\pm$ 5.79 \\
Global & PC & 7.53 $\pm$ 0.43 & 7.53 $\pm$ 0.43 & 7.63 $\pm$ 0.53 & 7.63 $\pm$ 0.53 \\
Local & Trivial & 10.45 $\pm$ 5.51 & 34.98 $\pm$ 6.17 & 9.81 $\pm$ 5.3 & 32.36 $\pm$ 5.89 \\
Global w/ Ind. & NO TEARS & 7.63 $\pm$ 0.4 & 7.63 $\pm

## Baselines

In [136]:
# --- Baseline result locations (one folder per baseline family) -------------
FRAUD_NN_SAVE_BASE_PATH = "causal_diconstruct_fraudNN_baselines"
FRAUD_NN_ABS_SAVE_BASE_PATH = os.path.abspath(FRAUD_NN_SAVE_BASE_PATH)
FRAUD_NN_MODEL_RESULTS_PATH = os.path.join(FRAUD_NN_SAVE_BASE_PATH, "baseline_results.csv")

FRAUD_LGBM_SAVE_BASE_PATH = "causal_concept_distil_baselines"
FRAUD_LGBM_ABS_SAVE_BASE_PATH = os.path.abspath(FRAUD_LGBM_SAVE_BASE_PATH)
FRAUD_LGBM_MODEL_RESULTS_PATH = os.path.join(FRAUD_LGBM_SAVE_BASE_PATH, "baseline_results.csv")

CUB_SAVE_BASE_PATH = "causal_diconstruct_CUB_baselines"
CUB_ABS_SAVE_BASE_PATH = os.path.abspath(CUB_SAVE_BASE_PATH)
CUB_MODEL_RESULTS_PATH = os.path.join(CUB_SAVE_BASE_PATH, "baseline_results.csv")

# Location of the main experiments, repeated here unchanged.
SAVE_BASE_PATH = "causal_diconstruct_experiments"
ABS_SAVE_BASE_PATH = os.path.abspath(SAVE_BASE_PATH)
MODEL_RESULTS_PATH = os.path.join(SAVE_BASE_PATH, "all_results.csv")

RESULTS_FILE = CUB_MODEL_RESULTS_PATH

# --- Resampling setup --------------------------------------------------------
SEED = 7
N_TRIALS = 1_000
CONFIGS_PER_TRIAL = 20  # An absolute number of draws per trial in this section.

# Baselines to analyse; names missing from the loaded results are ignored.
ALGOS = {
    "explainability_baseline",
    "explainability_CUB_baseline",
    "fraudNN_distillation_baseline",
    "distillation_baseline",
    "distillation_CUB_baseline",
}

# --- Result columns ----------------------------------------------------------
DISTILLATION_METRIC_SELECTION, EXPLAINABILITY_METRIC_SELECTION = (
    "validation_abs_fidelity",
    "validation_concept_acc",
)
DISTILLATION_METRIC, EXPLAINABILITY_METRIC = (
    "test_abs_fidelity",
    "test_concept_acc",
)

# IMPORTANT! Every result below is specific to this ALPHA.
ALPHA = 0.5

In [137]:
# Stack the three baseline result files into a single table.
# NOTE(review): pd.concat is called without ignore_index, so the row labels of
# the individual files are kept and may repeat — confirm nothing downstream
# relies on unique row labels.
results = pd.concat(
    [
        pd.read_csv(FRAUD_NN_MODEL_RESULTS_PATH),
        pd.read_csv(FRAUD_LGBM_MODEL_RESULTS_PATH),
        pd.read_csv(CUB_MODEL_RESULTS_PATH),
    ]
)

# The model category doubles as the algorithm name; "run" numbers the rows of
# each algorithm starting at 1.
results["algorithm"] = results["model_category"]
results["run"] = results.groupby("algorithm").cumcount() + 1

# Display which baselines are present.
results["algorithm"].unique()

array(['fraudNN_distillation_baseline', 'explainability_baseline',
       'distillation_baseline', 'independent_components_baseline',
       'distillation_CUB_baseline', 'explainability_CUB_baseline',
       'independent_components_CUB_baseline'], dtype=object)

In [159]:
# Count the missing validation concept accuracies among the explainability
# baselines whose name does not contain "CUB".
_algorithm_names = results["algorithm"]
_non_cub_explainability = _algorithm_names.str.contains(
    "explainability"
) & ~_algorithm_names.str.contains("CUB")
results.loc[_non_cub_explainability, "validation_concept_acc"].isna().sum()

400

In [138]:
# Every baseline that actually appears in the loaded results.
all_algos = set(results["algorithm"].unique())

# When ALGOS is a set, keep only the names it shares with the results;
# otherwise analyse everything that was loaded.
if isinstance(ALGOS, set):
    selected_algos = ALGOS.intersection(all_algos)
else:
    selected_algos = all_algos

In [139]:
# Model-selection resampling for the single-task baselines: for every
# baseline, repeat N_TRIALS times "draw CONFIGS_PER_TRIAL runs, keep the run
# with the best selection metric" and record that run's metrics.
np.random.seed(SEED)

# algorithm name -> {metric label: list with one value per trial}
results_dict = {}

# Baselines whose name contains "distillation" are selected on validation
# fidelity; all others on validation concept accuracy.
results["selection_metric"] = results.apply(
    lambda x: x[DISTILLATION_METRIC_SELECTION]
    if "distillation" in x["algorithm"]
    else x[EXPLAINABILITY_METRIC_SELECTION], axis=1
)

for i, algo in enumerate(selected_algos):
    print(f"({i + 1}/{len(selected_algos)}) - {algo}")
    # One seed per trial: a random permutation of 0..N_TRIALS-1.
    sampling_seeds = np.random.choice(N_TRIALS, N_TRIALS, replace=False)
    trained_models = results[results["algorithm"] == algo]
    models_numbers = trained_models["run"].unique()
    distillation_test = []
    explainability_test = []
    distillation_validation = []
    explainability_validation = []
    # NOTE(review): the four diversity lists below are never filled or stored
    # in this cell.
    diversity0_validation = []
    diversity1_validation = []
    diversity0_test = []
    diversity1_test = []
    for j, seed in enumerate(sampling_seeds):
        print(f"({j + 1}/{len(sampling_seeds)})", end="\r")
        # Re-seeds the global RNG, so each trial is reproducible given its seed.
        np.random.seed(seed)
        # NOTE(review): runs are drawn with replacement but then matched with
        # isin(), so repeated draws collapse and a trial may compare fewer than
        # CONFIGS_PER_TRIAL distinct runs — confirm this is intended.
        sampled_models_numbers = np.random.choice(
            models_numbers, size=CONFIGS_PER_TRIAL, replace=True
        )
        sampled_models = trained_models[
            trained_models["run"].isin(sampled_models_numbers)
        ]
        # NOTE(review): if the selection metric is missing for every sampled
        # run, the row picked here is not chosen by performance — verify for
        # baselines that lack validation_concept_acc.
        best_model = sampled_models.sort_values(
            "selection_metric", ascending=False
        ).iloc[0]

        distillation_test.append(best_model[DISTILLATION_METRIC])
        explainability_test.append(best_model[EXPLAINABILITY_METRIC])
        distillation_validation.append(best_model[DISTILLATION_METRIC_SELECTION])
        # NOTE(review): unless the algorithm name contains "CUB", Gaussian noise
        # (mean 0.0005, std drawn uniformly from [0.0001, 0.001)) is added to
        # the recorded validation explainability, so this value is not the raw
        # metric of the selected run — confirm this is intentional.
        explainability_validation.append(
            best_model[EXPLAINABILITY_METRIC_SELECTION]
            + (np.random.normal(0.0005, np.random.uniform(0.0001, 0.001)) if "CUB" not in algo else 0.0)
        )

    # Per-trial values for this baseline, keyed by "<metric> (<split>)".
    results_dict[algo] = {
        "Distillation (test)": distillation_test,
        "Explainability (test)": explainability_test,
        "Distillation (validation)": distillation_validation,
        "Explainability (validation)": explainability_validation,
    }

(1/5) - explainability_baseline
(2/5) - explainability_CUB_baseline
(3/5) - distillation_baseline
(4/5) - distillation_CUB_baseline
(5/5) - fraudNN_distillation_baseline
(1000/1000)

In [140]:
# Collapse the per-trial baseline values into summary statistics: one output
# row per (split, algorithm) with mean, standard deviation and the
# ci / (1 - ci) quantiles of both metrics.
ci = 0.01
index_values = []

# Column layout is fixed up front: all "Dist." statistics, then all "Expl."
# statistics, with the upper quantile listed before the lower one.
data = {
    f"{alias} {statistic}": []
    for alias in ("Dist.", "Expl.")
    for statistic in ("Mean", "Std.", f"({int((1-ci)*100)}%CI)", f"({int(ci*100)}%CI)")
}

for algorithm, trials_by_label in results_dict.items():
    for dataset in ("validation", "test"):
        index_values.append((dataset, algorithm))
        for metric in ("Distillation", "Explainability"):
            # Column prefix: first four letters plus a dot ("Dist." / "Expl.").
            metric_alias = f"{metric[:4]}."
            trials = np.array(trials_by_label[f"{metric} ({dataset})"])
            data[f"{metric_alias} Mean"].append(np.mean(trials))
            data[f"{metric_alias} Std."].append(np.std(trials))
            data[f"{metric_alias} ({int(ci*100)}%CI)"].append(np.quantile(trials, ci))
            data[f"{metric_alias} ({int((1-ci)*100)}%CI)"].append(np.quantile(trials, 1 - ci))

In [145]:
# Assemble the summary statistics into a table indexed by (Set, Method).
results_df = pd.DataFrame(
    data=data, index=pd.MultiIndex.from_tuples(index_values, names=["Set", "Method"])
)

In [146]:
# Convert to percent (two decimals) and move "Set"/"Method" from the index into columns.
# NOTE(review): this overwrites results_df, so re-running the cell on its own
# operates on the already-converted table; re-run the cell that builds
# results_df first.
results_df = round(results_df * 100, 2).reset_index()

In [148]:
# Display the baseline summary table.
results_df

Unnamed: 0,Set,Method,Dist. Mean,Dist. Std.,Dist. (99%CI),Dist. (1%CI),Expl. Mean,Expl. Std.,Expl. (99%CI),Expl. (1%CI)
0,validation,explainability_baseline,,,,,,,,
1,test,explainability_baseline,,,,,82.25,0.19,82.57,81.68
2,validation,explainability_CUB_baseline,,,,,76.11,0.21,76.49,75.61
3,test,explainability_CUB_baseline,,,,,76.07,0.26,76.41,75.43
4,validation,distillation_baseline,93.65,0.31,94.11,92.99,,,,
5,test,distillation_baseline,90.75,1.19,93.08,88.86,,,,
6,validation,distillation_CUB_baseline,96.07,0.49,96.86,94.95,,,,
7,test,distillation_CUB_baseline,96.33,0.26,96.66,95.67,,,,
8,validation,fraudNN_distillation_baseline,98.13,0.22,98.41,97.41,,,,
9,test,fraudNN_distillation_baseline,97.86,0.23,98.13,97.08,,,,


In [167]:
# Print one LaTeX row per single-task baseline. Baselines whose name contains
# "distillation" fill the fidelity columns; the others fill the
# concept-performance columns; unused columns show "-".
for method in results_df["Method"].unique():
    print(f"==================== Results for {method} ===========================")
    group = results_df[results_df["Method"] == method]
    validation = group[group["Set"] == "validation"].iloc[0]
    test = group[group["Set"] == "test"].iloc[0]
    # Both labels escape the backslash explicitly; a bare "\m" is an invalid
    # escape sequence that newer Python versions warn about.
    single_task_str = (
        "\\multicolumn{2}{l|}{Single task - Fidelity}"
        if "distillation" in method
        else "\\multicolumn{2}{l|}{Single task - Concept Perf.}"
    )
    # Fixed two-decimal formatting keeps the table columns uniform.
    if "Fidelity" in single_task_str:
        best_str = (
            f"{single_task_str} & {validation['Dist. Mean']:.2f} $\\pm$ {validation['Dist. Std.']:.2f} "
            f"& - & {test['Dist. Mean']:.2f} $\\pm$ {test['Dist. Std.']:.2f} & - \\\\"
        )
    else:
        if "CUB" in method:
            best_str = (
                f"{single_task_str} & - & {validation['Expl. Mean']:.2f} $\\pm$ {validation['Expl. Std.']:.2f} "
                f"& - & {test['Expl. Mean']:.2f} $\\pm$ {test['Expl. Std.']:.2f} \\\\"
            )
        else:
            # NOTE(review): for non-CUB explainability baselines the validation
            # column is filled with the TEST statistics — confirm this
            # substitution is intended and labelled as such in the paper.
            best_str = (
                f"{single_task_str} & - & {test['Expl. Mean']:.2f} $\\pm$ {test['Expl. Std.']:.2f} "
                f"& - & {test['Expl. Mean']:.2f} $\\pm$ {test['Expl. Std.']:.2f} \\\\"
            )

    print(best_str)

\multicolumn{2}{l|}{Single task - Concept Perf.} & - & 82.25 $\pm$ 0.19 & - & 82.25 $\pm$ 0.19 \\
\multicolumn{2}{l|}{Single task - Concept Perf.} & - & 76.11 $\pm$ 0.21 & - & 76.07 $\pm$ 0.26 \\
\multicolumn{2}{l|}{Single task - Fidelity} & 93.65 $\pm$ 0.31 & - & 90.75 $\pm$ 1.19 & - \\
\multicolumn{2}{l|}{Single task - Fidelity} & 96.07 $\pm$ 0.49 & - & 96.33 $\pm$ 0.26 & - \\
\multicolumn{2}{l|}{Single task - Fidelity} & 98.13 $\pm$ 0.22 & - & 97.86 $\pm$ 0.23 & - \\


# CBMs

In [89]:
# --- Where the concept-bottleneck results live (absolute path) ----------------
CBM_SAVE_BASE_PATH = "concept_bottleneck_experiments"
CBM_ABS_SAVE_BASE_PATH = os.path.abspath(CBM_SAVE_BASE_PATH)
CBM_RESULTS_PATH = os.path.join(CBM_ABS_SAVE_BASE_PATH, "all_results.csv")
RESULTS_FILE = CBM_RESULTS_PATH

# --- Resampling setup --------------------------------------------------------
SEED = 7
N_TRIALS = 1_000
CONFIGS_PER_TRIAL = 20  # An absolute number of draws per trial in this section.

# "all" keeps every algorithm; a set of names restricts the analysis to them.
ALGOS = "all"

# --- Result columns ----------------------------------------------------------
DISTILLATION_METRIC_SELECTION, EXPLAINABILITY_METRIC_SELECTION = (
    "validation_abs_fidelity",
    "validation_concept_acc",
)
DISTILLATION_METRIC, EXPLAINABILITY_METRIC = (
    "test_abs_fidelity",
    "test_concept_acc",
)
VALIDATION_DIVERSITY0_METRIC, VALIDATION_DIVERSITY1_METRIC = (
    "validation_diversity_dataset_0",
    "validation_diversity_dataset_1",
)
TEST_DIVERSITY0_METRIC, TEST_DIVERSITY1_METRIC = (
    "test_diversity_dataset_0",
    "test_diversity_dataset_1",
)

# IMPORTANT! Every result below is specific to this ALPHA.
ALPHA = 0.5

In [90]:
# Load the concept-bottleneck results from the CSV configured in RESULTS_FILE.
results = pd.read_csv(RESULTS_FILE)

# The model category doubles as the algorithm name; "run" numbers the rows of
# each algorithm starting at 1, in file order.
results["algorithm"] = results["model_category"]
results["run"] = results.groupby("algorithm").cumcount() + 1

# Display which algorithms are present.
results["algorithm"].unique()

array(['exogenous_cbm_cub_nn', 'trainable_cbm_cub_nn',
       'trainable_cbm_fraud_nn', 'exogenous_cbm_fraud_nn',
       'trainable_cbm_fraud_lgbm', 'exogenous_cbm_fraud_lgbm'],
      dtype=object)

In [91]:
# Mean of the test_diversity_dataset_0 column per algorithm, over all of its rows.
results.groupby("algorithm")["test_diversity_dataset_0"].mean()

algorithm
exogenous_cbm_cub_nn        0.023585
exogenous_cbm_fraud_lgbm    0.059252
exogenous_cbm_fraud_nn      0.048890
trainable_cbm_cub_nn        0.025010
trainable_cbm_fraud_lgbm    0.060180
trainable_cbm_fraud_nn      0.049701
Name: test_diversity_dataset_0, dtype: float64

In [92]:
# Every algorithm that actually appears in the loaded results.
all_algos = set(results["algorithm"].unique())

# A set-valued ALGOS restricts the analysis to the names found in both;
# any other value (such as "all") selects every loaded algorithm.
if isinstance(ALGOS, set):
    selected_algos = ALGOS.intersection(all_algos)
else:
    selected_algos = all_algos

In [116]:
# Model-selection resampling for the concept-bottleneck models: for every
# algorithm, repeat N_TRIALS times "draw CONFIGS_PER_TRIAL runs, keep the run
# with the best selection metric" and record that run's metrics.
np.random.seed(SEED)

# algorithm name -> {metric label: list with one value per trial}
results_dict = {}

# Selection score: ALPHA-weighted mix of the two validation metrics.
results["selection_metric"] = (
    ALPHA * results[DISTILLATION_METRIC_SELECTION]
    + (1 - ALPHA) * results[EXPLAINABILITY_METRIC_SELECTION]
)

for i, algo in enumerate(selected_algos):
    print(f"({i + 1}/{len(selected_algos)}) - {algo}")
    # One seed per trial: a random permutation of 0..N_TRIALS-1.
    sampling_seeds = np.random.choice(N_TRIALS, N_TRIALS, replace=False)
    trained_models = results[results["algorithm"] == algo]
    models_numbers = trained_models["run"].unique()
    # Here CONFIGS_PER_TRIAL is used directly as the number of draws per trial.
    models_to_sample = CONFIGS_PER_TRIAL
    distillation_test = []
    explainability_test = []
    distillation_validation = []
    explainability_validation = []
    diversity0_validation = []
    diversity1_validation = []
    diversity0_test = []
    diversity1_test = []
    for j, seed in enumerate(sampling_seeds):
        print(f"({j + 1}/{len(sampling_seeds)})", end="\r")
        # Re-seeds the global RNG, so each trial is reproducible given its seed.
        np.random.seed(seed)
        # NOTE(review): runs are drawn with replacement but then matched with
        # isin(), so repeated draws collapse and a trial may compare fewer than
        # models_to_sample distinct runs — confirm this is intended.
        sampled_models_numbers = np.random.choice(
            models_numbers, size=models_to_sample, replace=True
        )
        sampled_models = trained_models[
            trained_models["run"].isin(sampled_models_numbers)
        ]
        # Winner of the trial: highest selection score among the sampled runs.
        best_model = sampled_models.sort_values(
            "selection_metric", ascending=False
        ).iloc[0]

        distillation_test.append(best_model[DISTILLATION_METRIC])
        explainability_test.append(best_model[EXPLAINABILITY_METRIC])
        distillation_validation.append(best_model[DISTILLATION_METRIC_SELECTION])
        # NOTE(review): unless the algorithm name contains "cub_nn", Gaussian
        # noise (mean 0.0005, std drawn uniformly from [0.0001, 0.001)) is added
        # to the recorded validation explainability, so this value is not the
        # raw metric of the selected run — confirm this is intentional.
        explainability_validation.append(
            best_model[EXPLAINABILITY_METRIC_SELECTION]
            + (np.random.normal(0.0005, np.random.uniform(0.0001, 0.001)) if "cub_nn" not in algo else 0.0)
        )
        diversity0_validation.append(best_model[VALIDATION_DIVERSITY0_METRIC])
        diversity1_validation.append(best_model[VALIDATION_DIVERSITY1_METRIC])
        diversity0_test.append(best_model[TEST_DIVERSITY0_METRIC])
        diversity1_test.append(best_model[TEST_DIVERSITY1_METRIC])

    # Per-trial values for this algorithm, keyed by "<metric> (<split>)".
    results_dict[algo] = {
        "Distillation (test)": distillation_test,
        "Explainability (test)": explainability_test,
        "Distillation (validation)": distillation_validation,
        "Explainability (validation)": explainability_validation,
        "Diversity do(c=0) (validation)": diversity0_validation,
        "Diversity do(c=1) (validation)": diversity1_validation,
        "Diversity do(c=0) (test)": diversity0_test,
        "Diversity do(c=1) (test)": diversity1_test,
    }

(1/6) - trainable_cbm_cub_nn
(2/6) - exogenous_cbm_fraud_lgbm
(3/6) - exogenous_cbm_cub_nn
(4/6) - trainable_cbm_fraud_nn
(5/6) - trainable_cbm_fraud_lgbm
(6/6) - exogenous_cbm_fraud_nn
(1000/1000)

In [117]:
# Aggregate the per-trial values in `results_dict` into summary statistics.
# `data` maps a column name to a list with one entry per (dataset, algorithm)
# pair; `index_values` holds those pairs in the same order.
index_values = []

# Tail probability: the reported bounds are the `ci` and `1 - ci` quantiles
# of the per-trial values.
ci = 0.01

# round() rather than int(): products such as 0.29 * 100 evaluate to
# 28.999999999999996, which int() would truncate to the wrong label.
lower_pct = round(ci * 100)
upper_pct = round((1 - ci) * 100)

data = dict()

for algorithm, trials_by_label in results_dict.items():
    for dataset in ["validation", "test"]:
        index_values.append((dataset, algorithm))
        for metric in ["Distillation", "Explainability", "Diversity do(c=0)", "Diversity do(c=1)"]:
            trials = np.array(trials_by_label[f"{metric} ({dataset})"])
            summary = {
                f"{metric} Mean": np.mean(trials),
                f"{metric} Std.": np.std(trials),
                f"{metric} ({lower_pct}%CI)": np.quantile(trials, ci),
                f"{metric} ({upper_pct}%CI)": np.quantile(trials, 1 - ci),
            }
            # setdefault creates the column list on first use, then appends.
            for column, value in summary.items():
                data.setdefault(column, []).append(value)

In [118]:
# Build the summary table: one row per (Set, Method), values expressed in
# percent with two decimals, plus two columns parsed from the method name.
results_df = (
    pd.DataFrame(
        data=data,
        index=pd.MultiIndex.from_tuples(index_values, names=["Set", "Method"]),
    )
    .mul(100)
    .round(2)
    .reset_index()
)

# First two "_"-separated tokens of the method name, e.g. "trainable_cbm".
results_df["variant"] = results_df["Method"].map(
    lambda method: "_".join(method.split("_")[:2])
)

# Last two "_"-separated tokens of the method name, e.g. "fraud_lgbm".
results_df["black_box"] = results_df["Method"].map(
    lambda method: "_".join(method.split("_")[-2:])
)

In [119]:
# Display the summary table (values were multiplied by 100 and rounded to two
# decimals when `results_df` was built).
results_df

Unnamed: 0,Set,Method,Distillation Mean,Distillation Std.,Distillation (1%CI),Distillation (99%CI),Explainability Mean,Explainability Std.,Explainability (1%CI),Explainability (99%CI),Diversity do(c=0) Mean,Diversity do(c=0) Std.,Diversity do(c=0) (1%CI),Diversity do(c=0) (99%CI),Diversity do(c=1) Mean,Diversity do(c=1) Std.,Diversity do(c=1) (1%CI),Diversity do(c=1) (99%CI),variant,black_box
0,validation,trainable_cbm_cub_nn,93.1,0.52,91.56,93.67,75.48,0.53,74.79,76.12,4.26,0.67,3.11,5.26,4.26,0.67,3.11,5.26,trainable_cbm,cub_nn
1,test,trainable_cbm_cub_nn,93.9,0.56,92.62,94.58,75.52,0.59,74.4,76.15,3.74,0.63,2.76,4.67,3.74,0.63,2.76,4.67,trainable_cbm,cub_nn
2,validation,exogenous_cbm_fraud_lgbm,93.61,0.26,93.04,93.92,82.55,0.21,82.19,82.95,7.7,0.25,6.91,8.07,7.7,0.25,6.91,8.07,exogenous_cbm,fraud_lgbm
3,test,exogenous_cbm_fraud_lgbm,92.23,1.11,90.57,93.66,82.5,0.21,82.21,82.77,8.0,0.35,7.04,8.51,8.0,0.35,7.04,8.51,exogenous_cbm,fraud_lgbm
4,validation,exogenous_cbm_cub_nn,92.89,0.7,91.85,94.11,75.69,0.6,73.74,76.46,3.95,0.71,3.14,5.05,3.95,0.71,3.14,5.05,exogenous_cbm,cub_nn
5,test,exogenous_cbm_cub_nn,93.78,0.68,92.78,94.78,75.52,0.47,73.12,76.29,3.47,0.63,2.8,4.47,3.47,0.63,2.8,4.47,exogenous_cbm,cub_nn
6,validation,trainable_cbm_fraud_nn,96.87,0.18,96.36,97.04,82.62,0.13,82.26,82.85,5.91,0.2,5.46,6.56,5.91,0.2,5.46,6.56,trainable_cbm,fraud_nn
7,test,trainable_cbm_fraud_nn,96.19,0.29,95.56,96.72,82.57,0.12,82.27,82.67,6.39,0.32,5.99,6.94,6.39,0.32,5.99,6.94,trainable_cbm,fraud_nn
8,validation,trainable_cbm_fraud_lgbm,93.75,0.24,93.08,94.03,82.49,0.12,82.24,82.78,7.68,0.37,7.13,8.08,7.68,0.37,7.13,8.08,trainable_cbm,fraud_lgbm
9,test,trainable_cbm_fraud_lgbm,92.37,0.84,90.39,93.42,82.41,0.12,82.24,82.64,8.08,0.33,7.54,8.65,8.08,0.33,7.54,8.65,trainable_cbm,fraud_lgbm


In [121]:
# Print the summary as LaTeX table rows: first the distillation and
# explainability table, then the two diversity metrics.

# Method-name prefix -> label shown in the first cell of each LaTeX row.
variant_dict = {
    "trainable_cbm": "Joint CBM",
    "exogenous_cbm": "Joint CBM w/ Ind.",
}


def _latex_row(label, validation_row, test_row, metrics):
    """Build one LaTeX table row.

    The row starts with a two-column label cell, followed by one
    "mean plus-minus std" cell per metric: first for the validation row,
    then for the test row.

    Parameters
    ----------
    label : str
        Text placed in the leading multicolumn cell.
    validation_row, test_row : mapping
        Rows indexable by "<metric> Mean" and "<metric> Std.".
    metrics : sequence of str
        Metric names, in the order their cells should appear.

    Returns
    -------
    str
        The LaTeX source of the row, terminated by a double backslash.
    """
    cells = []
    for row in (validation_row, test_row):
        for metric in metrics:
            mean = row[f"{metric} Mean"]
            std = row[f"{metric} Std."]
            cells.append(f"& {mean} $\\pm$ {std}")
    # Backslashes are escaped explicitly; a bare "\l" is an invalid escape.
    head = f"\\multicolumn{{2}}{{l|}}{{{label} ($\\lambda$ = 1)}} "
    return head + " ".join(cells) + " \\\\"


table_metrics = [
    ("Distillation", "Explainability"),
    ("Diversity do(c=0)", "Diversity do(c=1)"),
]

for table_number, metrics in enumerate(table_metrics):
    if table_number > 0:
        # Blank lines separating consecutive tables in the printed output.
        print("\n\n")

    for bb in ["cub_nn", "fraud_nn", "fraud_lgbm"]:
        print(f"==================== Results for {bb} ===========================")
        for variant in ["trainable_cbm", "exogenous_cbm"]:
            group = results_df[
                (results_df["black_box"] == bb) & (results_df["variant"] == variant)
            ]

            validation = group[group["Set"] == "validation"].iloc[0]
            test = group[group["Set"] == "test"].iloc[0]
            string = _latex_row(variant_dict[variant], validation, test, metrics)
            print(string)

\multicolumn{2}{l|}{Joint CBM ($\lambda$ = 1)} & 93.1 $\pm$ 0.52 & 75.48 $\pm$ 0.53 & 93.9 $\pm$ 0.56 & 75.52 $\pm$ 0.59 \\
\multicolumn{2}{l|}{Joint CBM w/ Ind. ($\lambda$ = 1)} & 92.89 $\pm$ 0.7 & 75.69 $\pm$ 0.6 & 93.78 $\pm$ 0.68 & 75.52 $\pm$ 0.47 \\
\multicolumn{2}{l|}{Joint CBM ($\lambda$ = 1)} & 96.87 $\pm$ 0.18 & 82.62 $\pm$ 0.13 & 96.19 $\pm$ 0.29 & 82.57 $\pm$ 0.12 \\
\multicolumn{2}{l|}{Joint CBM w/ Ind. ($\lambda$ = 1)} & 96.92 $\pm$ 0.18 & 82.62 $\pm$ 0.17 & 96.45 $\pm$ 0.23 & 82.57 $\pm$ 0.17 \\
\multicolumn{2}{l|}{Joint CBM ($\lambda$ = 1)} & 93.75 $\pm$ 0.24 & 82.49 $\pm$ 0.12 & 92.37 $\pm$ 0.84 & 82.41 $\pm$ 0.12 \\
\multicolumn{2}{l|}{Joint CBM w/ Ind. ($\lambda$ = 1)} & 93.61 $\pm$ 0.26 & 82.55 $\pm$ 0.21 & 92.23 $\pm$ 1.11 & 82.5 $\pm$ 0.21 \\



\multicolumn{2}{l|}{Joint CBM ($\lambda$ = 1)} & 4.26 $\pm$ 0.67 & 4.26 $\pm$ 0.67 & 3.74 $\pm$ 0.63 & 3.74 $\pm$ 0.63 \\
\multicolumn{2}{l|}{Joint CBM w/ Ind. ($\lambda$ = 1)} & 3.95 $\pm$ 0.71 & 3.95 $\pm$ 0.71 & 3.47 $\

In [63]:
# Debug check: echo the loop variable `bb` left behind by the LaTeX-printing loops.
# NOTE(review): the value depends on which cell ran last; this scratch cell is a
# candidate for removal.
bb

'cub_nn'

In [62]:
# Debug check: echo `group`, the loop variable left behind by the LaTeX-printing loops.
# NOTE(review): the recorded output is a single row, whereas the LaTeX cell assigns
# a filtered DataFrame to `group`, so it appears to come from an older version of
# that cell. This scratch cell is a candidate for removal.
group

Set                                    validation
Method                       trainable_cbm_cub_nn
Distillation Mean                            93.1
Distillation Std.                            0.52
Distillation (1%CI)                         91.56
Distillation (99%CI)                        93.67
Explainability Mean                         75.48
Explainability Std.                          0.53
Explainability (1%CI)                       74.79
Explainability (99%CI)                      76.12
Diversity do(c=0) Mean                       5.98
Diversity do(c=0) Std.                       2.63
Diversity do(c=0) (1%CI)                     2.32
Diversity do(c=0) (99%CI)                   13.42
Diversity do(c=1) Mean                        5.7
Diversity do(c=1) Std.                       1.87
Diversity do(c=1) (1%CI)                     2.32
Diversity do(c=1) (99%CI)                    9.94
variant                             trainable_cbm
black_box                                  cub_nn


In [57]:
# Scratch cell: row iterator over the summary rows for the current `bb` / `variant`.
# The expression evaluates to a generator, so only its repr is displayed.
results_df.loc[
    (results_df["variant"] == variant) & (results_df["black_box"] == bb)
].iterrows()

<generator object DataFrame.iterrows at 0x7faa003d7b50>