This notebook performs the inference and analysis for a chosen experiment.

**Instructions**
* In the second cell you may change the paths to the results folders. The default values should work if the default results paths were kept when creating the worlds and running the experiments.
* In the section "Choose Experiment for Analysis", select which experiment to analyse by uncommenting the corresponding assignment. To rerun the analysis for another experiment, restart the notebook and uncomment only the experiment you want to analyse.
* The `with_spatial_flip` variable indicates whether the SpatialFlip method should be included in the analysis. It is set to true only for statistical parity experiments, and only for those where SpatialFlip ran within reasonable time. You can set it to false to run the analysis without it.
* During inference, PROMIS methods with `apply_fit_flips=True` and the SpatialFlip method just apply the precomputed flips.

**Analysis**
1. Reads related experiment info data.
2. Reads pretrained models for SpatialFlip, PROMIS methods, performs predictions for test set and reads precomputed predictions for FairWhere method.
3. Computes MLR (for statistical parity or equal opportunity depending on the experiment).
4. Accuracy/F1 score, except for LAR (LAR includes only predictions), unfair by design (is synthetic) experiment.
5. Disparity (FairWhere unfairness score definition) only for the DNN experiment.
6. Computes normalized statistics (LR) (by dividing with the maximum statistic of the initial world).
7. Shows the above computed metrics plus fit times, the budgets where PROMIS Opt reached its limit, and the final-budget metrics.

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import matplotlib
sys.path.append(os.path.abspath(os.path.join("..")))
from analysis.analyse_results_func import *
from sklearn import metrics
from utils.plot_utils import *
from utils.data_utils import (
    read_scanned_regs,
    get_y,
    get_pos_info_regions,
    read_all_models,
)
from utils.scores import get_mlr
from utils.results_names_utils import get_train_val_test_paths, combine_world_info
from sklearn import metrics
import ast
import random
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

In [2]:
# Base folders used throughout the notebook (relative paths).
base_path = "../../data/"
results_base_path = "../../results/" # path to the base results folder
save_plots_base_path = "../../plots/" # path to the base plots folder to save the plots or "" to not save them 

# Defaults; the experiment-selection cells below override some of these when uncommented.
with_spatial_flip = False
apply_fit_flips = False
only_methods = []  # forwarded as `methods=` to read_all_models; presumably an empty list selects all methods — TODO confirm

dnn_exp_dir = "dnn_exp/" # directory name for the DNN experiments
xgb_eq_opp_dir = "xgb_eq_opp_exp/" # directory name for the XGB experiments with equal opportunity fairness notion
lar_exp_dir = "lar_exp/" # directory name for the LAR experiments
semi_synth_dir = "crime_semi_synth_exp/" # directory name for the semi synthetic experiments
dataset_name = "crime" # default dataset name. For the LAR dataset, it is set to "lar"
lar_dataset_name = "lar"

# Plot appearance shared by the plotting cells.
figsize = (20, 8) 
display_title = True

# Seed both NumPy's and Python's global random generators.
seed = 42
np.random.seed(seed)  
random.seed(seed)

## Choose Experiment for Analysis

### DNN Experiment (Equal Opportunity)

In [3]:
# clf_name, partioning_type_name, overlap, fairness_notion, dir_name = (
#     "dnn",
#     "5_x_5",
#     True,
#     "equal_opportunity",
#     dnn_exp_dir,
# )

# clf_name, partioning_type_name, overlap, fairness_notion, dir_name = (
#     "dnn",
#     "non_overlap_k_8",
#     True,
#     "equal_opportunity",
#     dnn_exp_dir,
# )

# clf_name, partioning_type_name, overlap, fairness_notion, dir_name = (
#     "dnn",
#     "overlap_k_10_radii_4",
#     True,
#     "equal_opportunity",
#     dnn_exp_dir,
# )

### DNN Experiment (Statistical Parity)

In [4]:
# clf_name, partioning_type_name, overlap, fairness_notion, dir_name, with_spatial_flip = (
#     "dnn",
#     "5_x_5",
#     True,
#     "statistical_parity",
#     dnn_exp_dir,
#     True
# )

# clf_name, partioning_type_name, overlap, fairness_notion, dir_name, with_spatial_flip = (
#     "dnn",
#     "non_overlap_k_8",
#     True,
#     "statistical_parity",
#     dnn_exp_dir,
#     True
# )

# clf_name, partioning_type_name, overlap, fairness_notion, dir_name, with_spatial_flip = (
#     "dnn",
#     "overlap_k_10_radii_4",
#     True,
#     "statistical_parity",
#     dnn_exp_dir,
#     True
# )

### LAR Experiment

In [5]:
# clf_name, partioning_type_name, overlap, fairness_notion, dir_name, dataset_name, with_spatial_flip, apply_fit_flips = (
#     "",
#     "non_overlap_k_100",
#     False,
#     "statistical_parity",
#     lar_exp_dir,
#     lar_dataset_name,
#     True,
#     True
# )

# clf_name, partioning_type_name, overlap, fairness_notion, dir_name, dataset_name, apply_fit_flips = (
#     "",
#     "overlap_k_100_radii_30",
#     False,
#     "statistical_parity",
#     lar_exp_dir,
#     lar_dataset_name,
#     True
# )

# clf_name, partioning_type_name, overlap, fairness_notion, dir_name, dataset_name, with_spatial_flip, apply_fit_flips = (
#     "",
#     "5_x_5",
#     False,
#     "statistical_parity",
#     lar_exp_dir,
#     lar_dataset_name,
#     True,
#     True
# )

### Semi Synthetic Experiment

In [6]:
# clf_name, partioning_type_name, overlap, fairness_notion, dir_name, apply_fit_flips = (
#     "semi_synthetic_regions_non_overlap_k_8",
#     "non_overlap_k_8",
#     False,
#     "statistical_parity",
#     semi_synth_dir,
#     True,
# )
# clf_name, partioning_type_name, overlap, fairness_notion, dir_name, apply_fit_flips = (
#     "semi_synthetic_regions_5_x_5",
#     "5_x_5",
#     False,
#     "statistical_parity",
#     semi_synth_dir,
#     True
# )
# clf_name, partioning_type_name, overlap, fairness_notion, dir_name, apply_fit_flips = (
#     "semi_synthetic_regions_overlap_k_10_radii_4",
#     "overlap_k_10_radii_4",
#     True,
#     "statistical_parity",
#     semi_synth_dir,
#     True,
# )

### XGB Experiment (Equal Opportunity)

In [7]:
# clf_name, partioning_type_name, overlap, fairness_notion, dir_name = (
#     "xgb",
#     "overlap_k_10_radii_4",
#     True,
#     "equal_opportunity",
#     xgb_eq_opp_dir,
# )
# clf_name, partioning_type_name, overlap, fairness_notion, dir_name = (
#     "xgb",
#     "non_overlap_k_8",
#     False,
#     "equal_opportunity",
#     xgb_eq_opp_dir,
# )
# clf_name, partioning_type_name, overlap, fairness_notion, dir_name = (
#     "xgb",
#     "5_x_5",
#     True,
#     "equal_opportunity",
#     xgb_eq_opp_dir,
# )

In [8]:
# Narrow the results folder to the selected experiment. `dir_name` only exists once one
# of the experiment assignments above has been uncommented and run. The variable is
# reassigned in place, so re-running this cell without a restart appends `dir_name` again.
results_base_path = os.path.join(results_base_path, dir_name)

## Set Display Settings

In [9]:
# Maps internal method labels to the names shown in plots. Labels of loaded models that
# are missing here are added later (mapped to themselves) by the model-reading cell.
method_to_display_name = {
    "iter": "Spatial Flip",
    "promis_app": "PROMIS App",
    "promis_opt": "PROMIS Opt",
    "promis_opt_wlimit_300": "PROMIS Opt",
    "promis_opt_wlimit_1800": "PROMIS Opt (wlimit=1800)",
    "promis_opt_wlimit_3600": "PROMIS Opt (wlimit=3600)",
    "init": "Initial World",
}

# Colour palette: eight hex colours followed by named matplotlib colours.
# NOTE(review): not referenced by any other cell visible in this notebook.
colors_list = [
    "#1f77b4",
    "#ff7f0e",
    "#2ca02c",
    "#d62728",
    "#9467bd",
    "#8c564b",
    "#e377c2",
    "#7f7f7f",
    "black",
    "darkred",
    "darkgreen",
    "darkblue",
    "darkmagenta",
    "darkcyan",
    "darkorange",
    "darkviolet",
    "darkturquoise",
    "darkslategray",
    "darkgoldenrod",
    "darkolivegreen",
    "darkseagreen",
    "darkslateblue",
    "darkkhaki",
]


# Per-method plot styling handed to the plotting helpers. Every method shares the same
# line/marker settings and differs only in colour, so the entries are built from one
# colour table. This also gives "promis_opt" the "marker_size" entry it previously
# lacked, so all methods expose the same set of keys.
_method_colors = {
    "promis_app": "darkgreen",
    "promis_opt": "black",
    "promis_opt_wlimit_300": "black",
    "promis_opt_wlimit_3600": "purple",
    "promis_opt_wlimit_1800": "purple",
    "iter": "saddlebrown",
    "where": "blue",
    "init": "darkorange",
}

method_to_plot_info = {
    method: {
        "linewidth": 6,
        "color": color,
        "linestyle": "-",
        "scatter_marker": "o",
        "marker_size": 100,
    }
    for method, color in _method_colors.items()
}

## Read Trained Models 

In [None]:
# Resolve the world description and the input file locations for the chosen experiment.
res_desc_label, partioning_name, prediction_name = combine_world_info(
    dataset_name, partioning_type_name, clf_name
)
train_path_info, val_path_info, test_path_info = get_train_val_test_paths(
    base_path, partioning_name, prediction_name, dataset_name
)

# No validation data is loaded in this notebook; the names are kept defined as None.
(
    val_regions_df,
    val_pred_df,
    val_labels_df,
    y_pred_val,
    y_pred_probs_val,
    y_true_val,
    val_points_per_region,
    pos_y_true_indices_val,
    pos_points_per_region_val,
) = (None,) * 9

if dataset_name == "lar":
    # LAR: regions are read from the train path info and the "label" column of the
    # preprocessed CSV is used as the predictions; there is no ground truth, so no
    # probabilities and no positive-only region info.
    test_regions_df = read_scanned_regs(train_path_info["regions"])
    test_pred_df = pd.read_csv(f"{base_path}preprocess/lar.csv")
    y_pred_test = get_y(test_pred_df, "label")
    y_true_test = None
    y_pred_probs_test = None

    test_points_per_region = test_regions_df["points"].tolist()
    pos_y_true_indices_test, pos_points_per_region_test = None, None
else:
    test_regions_df = read_scanned_regs(test_path_info["regions"])
    test_pred_df = pd.read_csv(test_path_info["predictions"])
    test_labels_df = pd.read_csv(test_path_info["labels"])
    y_pred_test = get_y(test_pred_df, "pred")
    # Probabilities are not read for the semi-synthetic worlds.
    y_pred_probs_test = (
        None if clf_name.startswith("semi_synthetic") else get_y(test_pred_df, "prob")
    )
    y_true_test = get_y(test_labels_df, "label")

    test_points_per_region = test_regions_df["points"].tolist()
    pos_y_true_indices_test, pos_points_per_region_test = get_pos_info_regions(
        y_true_test, test_points_per_region
    )

results_path = f"{results_base_path}{res_desc_label}/"

# Always define save_plots_path: later cells pass it to the plotting helpers
# unconditionally, so leaving it undefined when saving is disabled raised a NameError.
# NOTE(review): presumably the plotting helpers treat an empty path as "do not save"
# — confirm against utils.plot_utils.
save_plots_path = ""
if save_plots_base_path:
    save_plots_path = os.path.join(
        save_plots_base_path, dir_name, res_desc_label, f"{fairness_notion}/"
    )
    print(f"Save plots path: {save_plots_path}")
    os.makedirs(save_plots_path, exist_ok=True)

# SpatialFlip models are only read when requested; PROMIS models are always read.
sp_flip_meths_2_pretrained_models = {}
if with_spatial_flip:
    sp_flip_meths_2_pretrained_models = read_all_models(
        f"{results_path}spatial_flip_models/{fairness_notion}/",
        False,
        methods=only_methods,
    )

sp_opt_meths_2_pretrained_models = read_all_models(
    f"{results_path}spatial_optim_models/{fairness_notion}/", True, methods=only_methods
)
all_meths_2_pretrained_models = {
    **sp_flip_meths_2_pretrained_models,
    **sp_opt_meths_2_pretrained_models,
}

splitted_labels = get_all_methods_modes_labels(list(all_meths_2_pretrained_models))

opt_methods_display_labels = splitted_labels["opt_labels"]
# Methods without a configured display name are shown under their raw label.
for label in all_meths_2_pretrained_models:
    method_to_display_name.setdefault(label, label)

## Perform Predictions - Compute Results Information

In [None]:
# Run every loaded model on the test set; the helper returns the per-method results
# together with the budget range.
all_methods_to_results_info, budget_range = compute_all_results_info(
    all_meths_2_pretrained_models=all_meths_2_pretrained_models,
    test_points_per_region=test_points_per_region,
    y_pred_test_probs=y_pred_probs_test,
    y_true_test=y_true_test,
    y_pred_test_orig=y_pred_test,
    apply_fit_flips=apply_fit_flips,
)

# Split the results by method family using the label groups computed earlier.
sp_flip_methods_2_results_info = dict(
    (name, info)
    for name, info in all_methods_to_results_info.items()
    if name in splitted_labels["heu_labels"]
)
sp_opt_methods_2_results_info = dict(
    (name, info)
    for name, info in all_methods_to_results_info.items()
    if name in splitted_labels["opt_labels"]
)

## Compute Disparity, Metrics on FairWhere Predictions 

In [12]:
# FairWhere quantities default to None; they are only filled in for the DNN experiments.
(
    P_where_test,
    N_where_test,
    RHO_where_test,
    TP_where_test,
    TPR_where_test,
    mlr_where_test_st_par,
    mlr_where_test_eq_opp,
    acc_where_test,
    f1_where_test,
    where_fit_time,
    where_fair_score_test,
    init_fair_score_test,
    where_fairness_loss_sum_test,
    init_fairness_loss_sum_test,
    where_fairness_loss_weighted_sum_test,
    init_fairness_loss_sum_weighted_test,
) = (None,) * 16


def get_pr(y_pred):
    """Return the sum of ``y_pred`` divided by its length, or 0 when it is empty."""
    n_preds = len(y_pred)
    return np.sum(y_pred) / n_preds if n_preds else 0


# FairWhere predictions are only read for the DNN experiments.
if clf_name.startswith("dnn"):
    # Score used inside the partitioning fairness loss: recall for equal opportunity,
    # positive rate otherwise.
    # NOTE(review): metrics.recall_score takes (y_true, y_pred) whereas get_pr takes a
    # single argument — confirm how get_partionings_fairness_loss_all calls score_func.
    if fairness_notion == "equal_opportunity":
        fair_score_func = metrics.recall_score
    else:
        fair_score_func = get_pr

    # The fit time is stored as a single float in a text file.
    with open(f"{results_path}{dataset_name}_{fairness_notion}_where_fit_time.txt", "r") as file:
        where_fit_time = float(file.read())


    where_pred_test_df = pd.read_csv(
        f"{results_path}{dataset_name}_{fairness_notion}_where_model_test_pred.csv"
    )
    test_partitioning_id_df = pd.read_csv(
        f"{base_path}partitionings/test_{partioning_name}_partitioning_ids.csv"
    )

    y_pred_where_test = get_y(where_pred_test_df, "pred")

    # Both columns are stored as Python-literal strings in the CSV; parse them back.
    test_partitioning_id_df["id"] = test_partitioning_id_df["id"].apply(
        ast.literal_eval
    )
    test_partitioning_id_df["partitioning"] = test_partitioning_id_df[
        "partitioning"
    ].apply(ast.literal_eval)

    test_ids = test_partitioning_id_df["id"].tolist()
    test_partitionings = test_partitioning_id_df["partitioning"].tolist()

    # Global counts and rates of the FairWhere predictions.
    P_where_test = np.sum(y_pred_where_test)
    N_where_test = len(y_pred_where_test)
    RHO_where_test = P_where_test / N_where_test
    TP_where_test = np.sum(y_pred_where_test[pos_y_true_indices_test])
    TPR_where_test = TP_where_test / len(pos_y_true_indices_test)

    # MLR
    # Statistical parity uses all points; equal opportunity restricts to true positives.
    mlr_where_test_st_par = get_mlr(y_pred_where_test, test_points_per_region)
    mlr_where_test_eq_opp = get_mlr(
        y_pred_where_test[pos_y_true_indices_test], pos_points_per_region_test
    )

    # Accuracy
    acc_where_test = metrics.accuracy_score(y_true_test, y_pred_where_test)
    f1_where_test = metrics.f1_score(y_true_test, y_pred_where_test)

    # Unweighted fairness loss per partitioning: FairWhere predictions first, then the
    # initial predictions.
    where_fairness_loss_list_test = get_partionings_fairness_loss_all(
        y_pred_where_test,
        test_ids,
        test_partitionings,
        y_true_test,
        weighted=False,
        score_func=fair_score_func,
    )

    where_fairness_loss_sum_test = np.sum(where_fairness_loss_list_test)

    init_fairness_loss_list_test = get_partionings_fairness_loss_all(
        y_pred_test,
        test_ids,
        test_partitionings,
        y_true_test,
        weighted=False,
        score_func=fair_score_func,
    )

    init_fairness_loss_sum_test = np.sum(init_fairness_loss_list_test)

    # fair loss score weighed per partitioning
    where_fairness_loss_list_weighted_test = get_partionings_fairness_loss_all(
        y_pred_where_test,
        test_ids,
        test_partitionings,
        y_true_test,
        weighted=True,
        score_func=fair_score_func,
    )

    where_fairness_loss_weighted_sum_test = np.sum(
        where_fairness_loss_list_weighted_test
    )

    init_fairness_loss_list_weighted_test = get_partionings_fairness_loss_all(
        y_pred_test,
        test_ids,
        test_partitionings,
        y_true_test,
        weighted=True,
        score_func=fair_score_func,
    )

    init_fairness_loss_sum_weighted_test = np.sum(init_fairness_loss_list_weighted_test)

    # Add the same four fairness-loss columns to every method's results frame, computed
    # from each row's "y_pred_test" entry.
    for method, res_df in all_methods_to_results_info.items():
        res_df["fair_loss_list_test"] = res_df["y_pred_test"].apply(
            lambda x: get_partionings_fairness_loss_all(
                x,
                test_ids,
                test_partitionings,
                y_true_test,
                weighted=False,
                score_func=fair_score_func,
            )
        )

        res_df["fair_loss_sum_test"] = res_df["fair_loss_list_test"].apply(
            lambda x: np.sum(x)
        )

        res_df["fair_loss_list_weighted_test"] = res_df["y_pred_test"].apply(
            lambda x: get_partionings_fairness_loss_all(
                x,
                test_ids,
                test_partitionings,
                y_true_test,
                weighted=True,
                score_func=fair_score_func,
            )
        )

        res_df["fair_loss_sum_weighted_test"] = res_df[
            "fair_loss_list_weighted_test"
        ].apply(lambda x: np.sum(x))

In [13]:
# Two stacked panels comparing the initial ("base") and FairWhere ("where") fairness
# loss per partitioning: unweighted on top, weighted below. DNN experiments only.
if clf_name.startswith("dnn"):
    fig, axes = plt.subplots(2, 1, figsize=(16, 6))

    # Partitioning ids are plotted as categorical x values.
    test_ids_str = list(map(str, test_ids))

    # Top panel: unweighted loss.
    axes[0].plot(test_ids_str, init_fairness_loss_list_test, label="base")
    axes[0].scatter(test_ids_str, init_fairness_loss_list_test)
    axes[0].plot(
        test_ids_str, where_fairness_loss_list_test, label="where", linestyle="dashed"
    )
    axes[0].scatter(test_ids_str, where_fairness_loss_list_test)
    axes[0].set(xlabel="Partitioning Id", ylabel="Fairness Loss")
    if display_title:
        axes[0].set_title("Fairness Loss per Partitioning (Test)")
    axes[0].legend()

    # Bottom panel: weighted loss.
    axes[1].plot(test_ids_str, init_fairness_loss_list_weighted_test, label="base")
    axes[1].scatter(test_ids_str, init_fairness_loss_list_weighted_test)
    axes[1].plot(
        test_ids_str,
        where_fairness_loss_list_weighted_test,
        label="where",
        linestyle="dashed",
    )
    axes[1].scatter(test_ids_str, where_fairness_loss_list_weighted_test)
    axes[1].set(xlabel="Partitioning Id", ylabel="Weighted Fairness Loss")
    if display_title:
        axes[1].set_title("Weighted Fairness Loss per Partitioning (Test)")
    axes[1].legend()

    fig.tight_layout()

    if save_plots_base_path:
        fig.savefig(
            f"{save_plots_path}fairness_loss_per_partitioning.pdf", format="pdf"
        )
    plt.show()

## Experiments Statistics

In [None]:
# Build the fairness metric labels for the chosen notion. Previously both conditions
# tested "equal_opportunity", so statistical parity produced no labels and equal
# opportunity produced both suffixes; each notion now gets its own suffix.
fair_labels = []
fair_base_labels = ["mlr", "fair_mlr_ratio"]
if fairness_notion == "statistical_parity":
    fair_labels.extend(f"{fair_base_label}_st_par" for fair_base_label in fair_base_labels)
elif fairness_notion == "equal_opportunity":
    fair_labels.extend(f"{fair_base_label}_eq_opp" for fair_base_label in fair_base_labels)

# Global statistics of the initial test predictions.
P_test = np.sum(y_pred_test)
N_test = len(y_pred_test)
RHO_test = P_test / N_test


print()
print(f"N_test: {N_test}")
print(f"P_test: {P_test}")
print(f"RHO_test: {RHO_test:.3f}")

# The true positive rate needs ground-truth labels, which the LAR dataset lacks.
if dataset_name != "lar":
    TP_test = np.sum(y_pred_test[pos_y_true_indices_test])
    TPR_test = TP_test / len(pos_y_true_indices_test)
    print(f"TPR_test: {TPR_test:.3f}")

In [15]:
# NOTE(review): neither constant is referenced by any other cell visible in this
# notebook; presumably a significance level and a number of alternative worlds for a
# significance test — TODO confirm or remove.
signif_level = 0.005
n_alt_worlds = 200


def to_region_dict(pts_per_region):
    """Wrap each region's points in a ``{"points": ...}`` dict, preserving order."""
    region_dicts = []
    for region_points in pts_per_region:
        region_dicts.append({"points": region_points})
    return region_dicts


# MLR and per-region statistics of the initial world under statistical parity.
init_test_mlr_st_par, init_test_stats_st_par = get_mlr(
    y_pred_test, test_points_per_region, with_stats=True
)

# The equal-opportunity variant restricts to true positives, which LAR does not have.
if dataset_name != "lar":
    init_test_mlr_eq_opp, init_test_stats_eq_opp = get_mlr(
        y_pred_test[pos_y_true_indices_test],
        pos_points_per_region_test,
        with_stats=True,
    )
else:
    init_test_mlr_eq_opp, init_test_stats_eq_opp = (None, None)

# Performance of the initial predictions, only available with ground-truth labels.
if y_true_test is not None:
    init_acc_test = metrics.accuracy_score(y_true_test, y_pred_test)
    init_f1_test = metrics.f1_score(y_true_test, y_pred_test)
else:
    init_acc_test = None
    init_f1_test = None

## Plot Budgets where Optimization Reaches Its Limit (i.e. PROMIS Opt)

In [16]:
# Report every solver status other than 1 or 3 together with the experiment indexes
# where it occurred.
for method in splitted_labels["opt_labels"]:
    if "status" not in all_methods_to_results_info[method].columns:
        continue
    unique_status = list(all_methods_to_results_info[method]["status"].unique())
    for status in unique_status:
        if status in [1, 3]:
            continue
        other_status_exp_idxs = all_methods_to_results_info[method][
            all_methods_to_results_info[method]["status"] == status
        ]["exp_idx"].unique()
        print(
            f"Found status {status} for method {method} for exp indexes: {other_status_exp_idxs}"
        )

In [None]:
# For every optimisation method that records a "status" column, collect its status and
# budget lists and count how many runs ended with status 3.
method_tlimit_cnt = {}
labels = []
status_lists = []
budget_lists = []
for method in splitted_labels["opt_labels"]:
    if "status" not in all_methods_to_results_info[method].columns:
        continue

    status_list = all_methods_to_results_info[method]["status"].to_list()
    budget_list = all_methods_to_results_info[method]["budget"].to_list()
    tlimit_cnt = np.count_nonzero(np.array(status_list) == 3)
    method_tlimit_cnt[method] = method_tlimit_cnt.get(method, 0) + tlimit_cnt

    labels.append(method)
    status_lists.append(status_list)
    budget_lists.append(budget_list)

plot_opt_methods_status(
    labels=labels,
    budget_lists=budget_lists,
    status_lists=status_lists,
    save_path=save_plots_path,
    figsize=figsize,
)

In [None]:
# Smallest budget at which each optimisation method ended with status 3. Methods whose
# results have no "status" column are skipped, matching the guard used by the previous
# status cells (accessing the missing column would otherwise raise).
meths_min_C_reach_limit = {}
for method in splitted_labels["opt_labels"]:
    res_df = all_methods_to_results_info[method]
    if "status" not in res_df.columns:
        continue
    _limit_rows = res_df[res_df["status"] == 3]
    if not _limit_rows.empty:
        meths_min_C_reach_limit[method] = _limit_rows["budget"].min()


plot_min_C_reach_limit(
    meths_min_C_reach_limit,
    method_to_plot_info,
    method_to_display_name,
    opt_methods_display_labels,
    figsize=figsize,
    save_path=save_plots_path,
    display_title=display_title,
)

## Plot MLR

In [None]:
flips_limit = None

# Notion-specific arguments for the MLR plot: (initial MLR, title suffix, file suffix,
# score column, FairWhere reference MLR, reference label).
_mlr_plot_args = {
    "statistical_parity": (
        init_test_mlr_st_par,
        " (Statistical Parity - Test Set)",
        "_st_par_test",
        "mlr_st_par_test",
        mlr_where_test_st_par,
        "Where MLR",
    ),
    "equal_opportunity": (
        init_test_mlr_eq_opp,
        " (Equal Opportunity - Test Set)",
        "_eq_opp_test",
        "mlr_eq_opp_test",
        mlr_where_test_eq_opp,
        "Where",
    ),
}

# Any other notion plots nothing.
if fairness_notion in _mlr_plot_args:
    (
        _init_mlr,
        _title_suffix,
        _save_suffix,
        _score_label,
        _where_mlr,
        _where_label,
    ) = _mlr_plot_args[fairness_notion]
    plot_scores(
        all_methods_to_results_info,
        _init_mlr,
        method_to_plot_info,
        method_to_display_name,
        opt_methods_display_labels,
        save_plots_path,
        figsize=figsize,
        flips_limit=flips_limit,
        append_to_title=_title_suffix,
        append_to_save=_save_suffix,
        score_label=_score_label,
        display_title=display_title,
        axhline_mlr=_where_mlr,
        axhline_mlr_label=_where_label,
    )

<!-- params[['mlr', 'pos_mlr', 'test_mlr', 'pos_test_mlr']]
all_methods_to_results_info['cont_in_ov_over_eq_opp'][['sol_mlr', 'pos_mlr', 'new_val_mlr', 'val_new_pos_mlr', 'new_test_mlr', 'test_new_pos_mlr', 'budget']] -->

## Plot Times

In [None]:
# Title and file suffixes for the time plots of the chosen notion; any other notion
# plots nothing.
if fairness_notion == "statistical_parity":
    _time_title, _time_save = " (Statistical Parity)", "_st_par"
elif fairness_notion == "equal_opportunity":
    _time_title, _time_save = "(Equal Opportunity)", "_eq_opp"
else:
    _time_title = _time_save = None

if _time_title is not None:
    # Draw the plot twice: first with log_time=False, then with log_time=True.
    for _log_time in (False, True):
        plot_flips_time(
            all_methods_to_results_info,
            method_to_display_name,
            method_to_plot_info,
            opt_methods_display_labels,
            title_append=_time_title,
            save_append=_time_save,
            save_plots_path=save_plots_path,
            log_time=_log_time,
            figsize=figsize,
            display_title=display_title,
            axhline_time=where_fit_time,
            axhline_time_label="Where",
        )

## Computes Maximum Budget Info

In [33]:
# Collect, for the largest budget, each method's MLR, region statistics, run time and
# performance. The first entry of every list describes the initial world.
methods = ["init"]
final_times = [None]
final_mlrs_st_par_test = [init_test_mlr_st_par]
final_mlrs_eq_opp_test = [init_test_mlr_eq_opp]
final_stats_st_par_test_list = [init_test_stats_st_par]
final_stats_eq_opp_test_list = [init_test_stats_eq_opp]
final_fair_score_test_list = [init_fairness_loss_sum_test]

n_flips_ = budget_range[-1]
budget_list = [0] + [n_flips_] * len(all_methods_to_results_info)

# DNN experiments report F1; every other experiment reports accuracy.
_is_dnn = clf_name.startswith("dnn")
final_performance_label = "f1" if _is_dnn else "accuracy"
final_performance_test_list = [init_f1_test if _is_dnn else init_acc_test]

for method, exp_res_df in all_methods_to_results_info.items():
    # Rows of this method at the final budget; the first one is used throughout.
    _final_rows = exp_res_df[exp_res_df["budget"] == n_flips_]

    if fairness_notion == "statistical_parity":
        mlr_st_par_test = _final_rows["mlr_st_par_test"].tolist()[0]
        final_mlrs_st_par_test.append(mlr_st_par_test)
        y_test_pred = _final_rows["y_pred_test"].tolist()[0]
        _, final_stats_st_par_test = get_mlr(
            y_test_pred, test_points_per_region, with_stats=True
        )
        final_stats_st_par_test_list.append(final_stats_st_par_test)
    else:
        mlr_eq_opp_test = _final_rows["mlr_eq_opp_test"].tolist()[0]
        y_test_pred = _final_rows["y_pred_test"].tolist()[0]
        # Equal opportunity statistics only consider the true positives.
        y_test_pred_pos = y_test_pred[pos_y_true_indices_test]
        _, final_stats_eq_opp_test = get_mlr(
            y_test_pred_pos, pos_points_per_region_test, with_stats=True
        )
        final_stats_eq_opp_test_list.append(final_stats_eq_opp_test)

    if _is_dnn:
        final_fair_score_test_list.append(
            _final_rows["fair_loss_sum_test"].tolist()[0]
        )

    final_flip_time = _final_rows["time"].tolist()[0]
    final_times.append(final_flip_time)
    methods.append(method)

    if y_true_test is None:
        final_performance_test_list.append(None)
    else:
        final_performance_test_list.append(
            _final_rows[f"{final_performance_label}_test"].tolist()[0]
        )

    if fairness_notion == "equal_opportunity":
        final_mlrs_eq_opp_test.append(mlr_eq_opp_test)

final_results = {
    "budget": budget_list,
    "Method": methods,
    "time": final_times,
    f"{final_performance_label}_test": final_performance_test_list,
}
if fairness_notion == "statistical_parity":
    final_results["mlr_st_par_test"] = final_mlrs_st_par_test
    final_results["final_stats_st_par_test"] = final_stats_st_par_test_list
else:
    final_results["mlr_eq_opp_test"] = final_mlrs_eq_opp_test
    final_results["final_stats_eq_opp_test"] = final_stats_eq_opp_test_list

# DNN experiments get one extra row for FairWhere plus the fairness-loss column.
if _is_dnn:
    final_results["Method"].append("Where")
    final_results["budget"].append(n_flips_)
    final_results["time"].append(where_fit_time)
    final_results["f1_test"].append(f1_where_test)
    if fairness_notion == "statistical_parity":
        final_results["mlr_st_par_test"].append(mlr_where_test_st_par)
        _, final_stats_st_par_test = get_mlr(
            y_pred_where_test, test_points_per_region, with_stats=True
        )
        final_results["final_stats_st_par_test"].append(final_stats_st_par_test)
    else:
        final_results["mlr_eq_opp_test"].append(mlr_where_test_eq_opp)
        _, final_stats_eq_opp_test = get_mlr(
            y_pred_where_test[pos_y_true_indices_test],
            pos_points_per_region_test,
            with_stats=True,
        )
        final_results["final_stats_eq_opp_test"].append(final_stats_eq_opp_test)

    final_fair_score_test_list.append(where_fairness_loss_sum_test)
    final_results['Fair Loss Sum (Test)'] = final_fair_score_test_list


final_results_df = pd.DataFrame(final_results)
if save_plots_base_path:
    final_results_df.to_csv(f"{save_plots_path}final_res.csv", index=False)

## Plot MLR/Mean Disparity vs Accuracy/F1 score

In [34]:
# Scores of the initial world, keyed by score name and then by data set.
init_scores = {
    "mlr_st_par": {
        "test": init_test_mlr_st_par,
    },
    "mlr_eq_opp": {
        "test": init_test_mlr_eq_opp,
    },
    "accuracy": {
        "test": init_acc_test,
    },
    "f1": {
        "test": init_f1_test,
    },
    "fair_loss_sum":
    {
        "test": init_fairness_loss_sum_test,
    }
}

# Same layout for the FairWhere scores; the values are None outside the DNN experiments.
where_scores = {
    "mlr_st_par": {
        "test": mlr_where_test_st_par,
    },
    "mlr_eq_opp": {
        "test": mlr_where_test_eq_opp,
    },
    "accuracy": {
        "test": acc_where_test,
    },
    "f1": {
        "test": f1_where_test,
    },
    "fair_loss_sum":
    {
        "test": where_fairness_loss_sum_test,
    }
}
# Axis/title labels for the fairness scores.
fair_scores_display_labels = {
    "mlr_st_par": "MLR",
    "mlr_eq_opp": "MLR",
    "fair_loss_sum": "Mean Disparity"
}
# Axis/title labels for the performance scores.
performance_scores_display_labels = {
    "accuracy": "Accuracy",
    "f1": "F1 Score",
}
# Display names of the data sets; only "test" is used by the plotting cell below.
sets_display_labels = {
    "sol": "Solution",
    "val": "Validation Set",
    "test": "Test Set"
}

In [None]:
score1_vs_score2_figsize = (14, 8) if display_title else figsize

# Fairness-vs-performance plots need ground-truth labels.
if y_true_test is not None:
    sets = ["test"]
    fair_scores = (
        ["mlr_st_par"] if fairness_notion == "statistical_parity" else ["mlr_eq_opp"]
    )

    # DNN experiments additionally plot the fairness loss and use F1 instead of accuracy.
    if clf_name.startswith("dnn"):
        fair_scores = fair_scores + ["fair_loss_sum"]
        performance_scores = ["f1"]
    else:
        performance_scores = ["accuracy"]

    # One plot per (set, fairness score, performance score) combination, in nested order.
    for set_, fair_score, performance_score in [
        (s, f, p) for s in sets for f in fair_scores for p in performance_scores
    ]:
        plot_score1_vs_score2(
            methods_to_res_info=all_methods_to_results_info,
            score_label1=f"{fair_score}_{set_}",
            score_label2=f"{performance_score}_{set_}",
            score_display_label1=fair_scores_display_labels[fair_score],
            score_display_label2=performance_scores_display_labels[performance_score],
            init_score1=init_scores[fair_score][set_],
            init_score2=init_scores[performance_score][set_],
            method_to_plot_info=method_to_plot_info,
            method_to_display_name=method_to_display_name,
            opt_methods_display_labels=opt_methods_display_labels,
            save_plots_path=save_plots_path,
            figsize=score1_vs_score2_figsize,
            append_to_title=f" ({sets_display_labels[set_]})",
            append_to_save=f"_{fair_score}_{set_}",
            display_title=display_title,
            other_score1=where_scores[fair_score][set_],
            other_score2=where_scores[performance_score][set_],
            other_method_label="Where",
        )



## Plot P, RHO, Actual Flips 

In [None]:
# Compare the methods on the "P_test", "RHO_test" and "actual_flips_test" result
# columns, passing the initial P/RHO values and the FairWhere P/RHO as reference lines
# (the FairWhere values are None outside the DNN experiments).
plot_compare_methods_info(
    all_methods_to_results_info,
    P_test,
    RHO_test,
    p_label="P_test",
    rho_label="RHO_test",
    actual_flips_label="actual_flips_test",
    method_to_plot_info=method_to_plot_info,
    method_to_display_name=method_to_display_name,
    opt_methods_display_labels=opt_methods_display_labels,
    save_path=save_plots_path,
    figsize=figsize,
    append_to_title=f" ({fairness_notion} - Test Set)",
    display_title=display_title,
    axhline_P=P_where_test,
    axhline_RHO=RHO_where_test,
    axhline_label="Where",
)

## Plot Normalized Statistics (LR)

In [None]:
# Show the final-budget summary table. `display` is not imported in this notebook; it
# is the helper IPython/Jupyter makes available in notebook sessions.
display(final_results_df)

In [None]:
xlabel = "Regions"
ylabel = "Normalized LR"

methods_labels = final_results_df["Method"].unique().tolist()

# Choose the statistics column and the initial-world statistics for the chosen notion.
if fairness_notion == "statistical_parity":
    _stats_col, _init_stats = "final_stats_st_par_test", init_test_stats_st_par
else:
    _stats_col, _init_stats = "final_stats_eq_opp_test", init_test_stats_eq_opp

# One flat list of region statistics per method, in order of first appearance.
stats_per_method = (
    final_results_df.groupby("Method", sort=False)[_stats_col].apply(list).tolist()
)
stats_per_method = [np.concatenate(stats).tolist() for stats in stats_per_method]
# Statistics are normalised by the largest statistic of the initial world.
max_init_stat = max(_init_stats)

# Plotting the combined barplot
plot_regions_norm_stats(
    methods_stats=stats_per_method,
    methods_labels=methods_labels,
    xlabel=xlabel,
    ylabel=ylabel,
    max_stat=max_init_stat,
    save_path=save_plots_path,
    append_to_title="(Test Set)",
    display_title=display_title,
    method_to_display_name=method_to_display_name,
    method_to_plot_info=method_to_plot_info,
    figsize=figsize,
)