# FAIM

Functions from https://github.com/nliulab/FAIM

## functions and utils

In [None]:
def rgb01_hex(col):
    """Turn a three-channel colour into a ``#rrggbb`` hex string.

    Each channel of ``col`` is multiplied by 255 and rounded with the built-in
    ``round`` before being written as two lowercase hexadecimal digits, so the
    channels are expected to be fractions (nominally between 0 and 1).
    """
    channels = tuple(round(component * 255) for component in col)
    return "#%02x%02x%02x" % channels


def compute_area(fairness_metrics):
    """Summarise a set of fairness metrics as a single area-like number.

    The metric values are flattened into one sequence and each value is
    multiplied by its cyclic successor (the last value pairs with the first).

    * more than two metrics: the products are scaled by ``sin(2*pi/n)`` and
      summed, which is proportional to the area of the polygon the values
      span on a radar chart with ``n`` equally spaced axes;
    * exactly two metrics: the plain sum of the cyclic products;
    * otherwise: the absolute value of the first metric.

    ``n`` is ``len(fairness_metrics)``; the argument must expose ``.values``
    (e.g. a pandas Series).
    """
    n_metric = len(fairness_metrics)
    flat = fairness_metrics.values.flatten().tolist()

    if n_metric < 2:
        return np.abs(flat[0])

    # Pair every value with the next one, wrapping around at the end.
    neighbour_products = np.array(flat) * np.array(flat[1:] + flat[:1])
    if n_metric == 2:
        return np.sum(neighbour_products)

    angle = 2 * np.pi / n_metric
    return np.sum(neighbour_products * np.sin(angle))


def plot_perf_metric(
    perf_metric, eligible, x_range, select=None, plot_selected=False, x_breaks=None
):
    """Plot a histogram of the performance metric of all sampled models.

    Parameters
    ----------
    perf_metric : numpy.array or pandas.Series
        Numeric vector of performance metrics for all sampled models.
    eligible : numpy.array or pandas.Series
        Boolean vector of the same length as ``perf_metric`` indicating
        whether each sampled model is eligible.
    x_range : list
        Range of eligible values for the performance metric; drawn as dashed
        vertical lines on the main histogram.
    select : list or numpy.array, optional (default: None)
        Positional indexes into ``perf_metric`` of the selected models. Only
        used when ``plot_selected`` is True.
    plot_selected : bool, optional (default: False)
        Whether a second histogram restricted to the selected models should
        be produced.
    x_breaks : list, optional (default: None)
        Histogram breaks for the selected-models plot.

    Returns
    -------
    plotnine.ggplot or tuple of plotnine.ggplot
        The main histogram, or ``(main, selected)`` when ``plot_selected``
        is True.
    """
    m = len(perf_metric)
    perf_df = pd.DataFrame(perf_metric, columns=["perf_metric"], index=None)
    n_eligible = np.sum(eligible)
    plot = (
        pn.ggplot(perf_df, pn.aes(x="perf_metric"))
        + pn.geoms.geom_histogram(
            breaks=np.linspace(np.min(perf_metric), np.max(perf_metric), 40)
        )
        + pn.geoms.geom_vline(xintercept=x_range, linetype="dashed", size=0.7)
        + pn.labels.labs(
            x="Ratio of loss to minimum loss",
            title="""Loss of {m:d} sampled models
                \n{n_elg:d} ({per_elg:.1f}%) sampled models are eligible""".format(
                m=m, n_elg=n_eligible, per_elg=n_eligible * 100 / m
            ),
        )
        + pn.themes.theme_bw()
        + pn.themes.theme(
            title=pn.themes.element_text(ha="left"),
            axis_title_x=pn.themes.element_text(ha="center"),
            axis_title_y=pn.themes.element_text(ha="center"),
        )
    )
    if not plot_selected:
        return plot

    if select is None:
        print("'select' vector is not specified!\nUsing all models instead")
        select = list(range(len(perf_df)))
    try:
        perf_select = perf_df.iloc[select]
    except Exception:
        # Best-effort fallback for out-of-range or malformed indexes. Catching
        # Exception (rather than a bare ``except``) keeps KeyboardInterrupt
        # and SystemExit propagating.
        print(
            "Invalid indexes detected in 'select' vector!\nUsing all models instead"
        )
        select = list(range(len(perf_df)))
        perf_select = perf_df.iloc[select]

    plot2 = (
        pn.ggplot(perf_select, pn.aes(x="perf_metric"))
        + pn.geoms.geom_histogram(breaks=x_breaks)
        + pn.labels.labs(
            x="Ratio of loss to minimum loss",
            title="{n_select:d} selected models".format(n_select=len(select)),
        )
        + pn.themes.theme_bw()
        + pn.themes.theme(
            title=pn.themes.element_text(ha="left"),
            axis_title_x=pn.themes.element_text(ha="center"),
            axis_title_y=pn.themes.element_text(ha="center"),
        )
    )
    return (plot, plot2)


def plot_distribution(df, s=4):
    """Draw one histogram per metric column, coloured by exclusion case.

    Args:
        df (pandas.DataFrame): every column except the last two is plotted as
            a metric; the frame must contain a ``sen_var_exclusion`` column,
            which is used as the hue.
        s (number, optional): side length of each subplot; the figure is
            ``s * n_metrics`` wide and ``s`` high. Defaults to 4.

    Returns:
        matplotlib.figure.Figure: the figure (it is also shown via
        ``plt.show()``).
    """

    def describe_exclusion(raw):
        # "" -> "No exclusion", "a_b" -> "Exclusion of a and b",
        # "a_b_c" -> "Exclusion of a, b and c", "a" -> "Exclusion of a".
        if raw == "":
            return "No exclusion"
        parts = raw.split("_")
        if len(parts) == 2:
            return f"Exclusion of {' and '.join(parts)}"
        if len(parts) > 2:
            return f"Exclusion of {', '.join(parts[:-1])} and {parts[-1]}"
        return f"Exclusion of {raw}"

    num_metrics = df.shape[1] - 2
    # Build a fresh list instead of writing into the array returned by
    # ``unique()``, which may not be a plain writable object array.
    labels = [describe_exclusion(raw) for raw in df.sen_var_exclusion.unique()]

    # squeeze=False always yields a 2-D array of Axes, so indexing also works
    # when there is a single metric column (otherwise a bare Axes is returned).
    fig, axes = plt.subplots(
        nrows=1, ncols=num_metrics, figsize=(s * num_metrics, s), squeeze=False
    )
    for i, x in enumerate(df.columns[:-2]):
        ax = axes[0][i]
        sns.histplot(
            data=df, x=x, hue="sen_var_exclusion", bins=50, ax=ax, legend=False
        )
        ax.set_title(x)
        ax.set_xlabel("")
        ax.set_ylabel("Count" if i == 0 else "")

    # The legend is attached to the current (last) axes and placed to its right.
    plt.legend(
        loc="center left",
        title="",
        labels=labels[::-1],
        ncol=1,
        bbox_to_anchor=(1.04, 0.5),
        borderaxespad=0,
    )
    plt.show()

    return fig


def plot_scatter(df, perf, sen_var_exclusion, title, c1=20, c2=0.15, **kwargs):
    """Build a plotly figure ranking models by a fairness index, one panel per metric.

    For every row of ``df`` a fairness index ``1 / compute_area(row)`` is
    computed; rank 0 is the model with the largest index. Each metric column
    gets its own subplot with ``log10(rank + 1)`` on the x axis and the
    (jittered) metric value on the y axis. Marker shape encodes the
    sensitive-variable exclusion case; marker colour encodes the rank, with
    the rank-0 model highlighted.

    Args:
        df (pandas.DataFrame): one row per model (index = model id), one
            column per fairness metric. A column named "Equalized Odds" is
            read when building the returned table.
        perf: not referenced in the function body.
        sen_var_exclusion (pandas.Series): exclusion case of each model;
            ``.copy()``, ``.tolist()`` and ``.unique()`` are called on it.
            Presumably aligned row-by-row with ``df`` -- TODO confirm.
        title (str): figure title passed to ``update_layout``.
        c1 (number, optional): ``figsize[1] / c1`` is the caption font size.
        c2 (number, optional): unit used to position the legend, subtitle
            band and x-axis caption in paper coordinates.
        **kwargs: only ``figsize`` (``[width, height]``) is read.

    Returns:
        tuple: ``(fig, fair_index_df)`` -- the plotly figure and a DataFrame
        with the per-model index, ranking and styling columns.

    Side effects:
        Calls ``np.random.seed(0)``, i.e. reseeds NumPy's global RNG.
    """
    ### basic settings ###
    np.random.seed(0)
    if "figsize" not in kwargs.keys():
        fig_h = 400
        figsize = [fig_h * df.shape[1] * 2.45 / 3, fig_h]
    else:
        figsize = kwargs["figsize"]
    caption_size = figsize[1] / c1  # control font size / figure size
    fig_caption_ratio = 0.8
    fig_font_size = caption_size * fig_caption_ratio

    font_family = "Arial"
    highlight_color = "#D4AF37"
    fig_font_unit = c2  # control the relative position of elements
    caption_font_unit = fig_font_unit * fig_caption_ratio
    d = fig_font_unit / 8
    legend_pos_y = 1 + fig_font_unit
    subtitle_pos = [legend_pos_y + d, legend_pos_y + d + caption_font_unit]
    xlab_pos_y = -fig_font_unit * 2

    # Fairness index per model: reciprocal of the area returned by compute_area.
    area_list = []
    for i, id in enumerate(df.index):
        values = df.loc[id, :]
        area_list.append(1 / compute_area(values))
    # Double argsort turns the descending order into a rank per model
    # (rank 0 = largest fairness index).
    ranking = np.argsort(np.argsort(area_list)[::-1])

    # jittering for display
    # The jitter amplitude depends on the rank band; the top model gets none.
    # Ranks up to 1000 are shifted upwards only (uniform on [0, 1)), the rest
    # in both directions (uniform on [-1, 1)).
    jitter_control = np.zeros(len(ranking))
    for idx in range(len(ranking)):
        if ranking[idx] == 0:
            jitter_control[idx] = 0
        elif ranking[idx] <= 10 and ranking[idx] != 0:
            jitter_control[idx] = 0.01 * np.random.uniform(0, 1)
        elif ranking[idx] <= 10**2 and ranking[idx] > 10:
            jitter_control[idx] = 0.015 * np.random.uniform(0, 1)
        elif ranking[idx] <= 10**3 and ranking[idx] > 10**2:
            jitter_control[idx] = 0.015 * np.random.uniform(0, 1)
        else:
            jitter_control[idx] = 0.02 * np.random.uniform(-1, 1)

    ### plot ###
    best_id = df.index[np.where(ranking == 0)][0]
    worst_id = df.index[np.argmin(area_list)]
    # NOTE(review): this takes the middle element of the ASCENDING index
    # order, while `meduim_metric` below is looked up via the DESCENDING
    # `ranking`; for an even number of models these look like they refer to
    # different rows -- confirm.
    meduim_id = df.index[np.argsort(area_list)[int(len(area_list) / 2)]]

    num_metrics = df.shape[1]
    num_models = df.shape[0]

    fig = make_subplots(cols=num_metrics, rows=1, horizontal_spacing=0.13)
    # One colour per rank, reversed so that rank 0 maps to the darkest shade.
    cmap = sns.light_palette("steelblue", as_cmap=False, n_colors=df.shape[0])
    cmap = cmap[::-1]
    colors = [rgb01_hex(cmap[x]) if x != 0 else highlight_color for x in ranking]
    sizes = [10 if x != 0 else 20 for x in ranking]

    shapes = sen_var_exclusion.copy().tolist()
    cases = sen_var_exclusion.unique()
    # NOTE(review): only four marker symbols are available; with more than
    # four distinct cases `shapes_candidates[i]` below raises IndexError.
    shapes_candidates = ["square", "circle", "triangle-up", "star"][: len(cases)]
    for i, case in enumerate(cases):
        # Replace each model's case label with the marker symbol of its case.
        for j, v in enumerate(sen_var_exclusion):
            if v == case:
                shapes[j] = shapes_candidates[i]

        # Rewrite the raw case string ("", "a", "a_b", "a_b_c", ...) into a
        # human-readable legend label, in place in `cases`.
        if cases[i] == "":
            cases[i] = "No exclusion"
        elif len(cases[i].split("_")) == 2:
            cases[i] = f"Exclusion of {' and '.join(cases[i].split('_'))}"
        elif len(cases[i].split("_")) > 2:
            sens = cases[i].split("_")
            cases[i] = f"Exclusion of {', '.join(sens[:-1])} and {sens[-1]}"
        else:
            cases[i] = f"Exclusion of {cases[i]}"

    fair_index_df = pd.DataFrame(
        {
            "model id": df.index,
            "fair_index": area_list,
            "ranking": ranking,
            "eod": df["Equalized Odds"],
            "colors": colors,
            "shapes": shapes,
            "sizes": sizes,
            "cases": sen_var_exclusion,
            "jitter": jitter_control,
        }
    )

    # Add scatter plots to the subplots
    # One trace per (marker shape, metric) pair, added shape-major.
    for k, s in enumerate(shapes_candidates):
        for i in range(num_metrics):
            # index of sen_var_exclusion(shape) == s
            s_idx = [idx for idx, x in enumerate(shapes) if x == s]
            x = df.iloc[s_idx, i].values
            js = fair_index_df.loc[fair_index_df.shapes == s, "jitter"].values
            jittered_x = x + js

            col = fair_index_df.loc[fair_index_df.shapes == s, "colors"]
            size = fair_index_df.loc[fair_index_df.shapes == s, "sizes"]
            fair_index = fair_index_df.loc[fair_index_df.shapes == s, "fair_index"]
            ids = fair_index_df.loc[fair_index_df.shapes == s, "model id"]
            rank_text = fair_index_df.loc[fair_index_df.shapes == s, "ranking"]
            # x coordinate: log10(rank + 1), so rank 0 sits at x = 0.
            r = (
                fair_index_df.loc[fair_index_df.shapes == s, "ranking"]
                .apply(lambda x: math.log10(x + 1))
                .values
            )
            sen_case = fair_index_df.loc[fair_index_df.shapes == s, "cases"]

            hovertext = [
                f"Fairness index: {f:.3f}. Ranking: {x}. Model id: {i}"
                for f, x, i in zip(fair_index, rank_text, ids)
            ]
            fig.add_trace(
                go.Scatter(
                    x=r,
                    y=jittered_x,
                    customdata=hovertext,
                    mode="markers",
                    marker=dict(
                        color=col,
                        symbol=s,
                        size=size,
                        line=dict(color=col, width=1),
                        opacity=0.8,
                    ),
                    hovertemplate="%{customdata}.",
                    hoverlabel=None,
                    hoverinfo="name+z",
                    name=cases[k],
                ),
                col=i + 1,
                row=1,
            )

            # NOTE(review): both branches below are identical, so the
            # condition currently has no effect.
            if i == int((df.shape[1] + 0.5) / 2):
                fig.update_xaxes(
                    title_text=None,
                    tickvals=[0, 1, 2, 3],
                    ticktext=[1, 10, 100, 1000],
                    col=i + 1,
                    row=1,
                    tickangle=0,
                )
            else:
                fig.update_xaxes(
                    title_text=None,
                    tickvals=[0, 1, 2, 3],
                    ticktext=[1, 10, 100, 1000],
                    col=i + 1,
                    row=1,
                    tickangle=0,
                )
            fig.update_yaxes(
                title_text=df.columns[i],
                col=i + 1,
                row=1,
                showticksuffix="none",
                titlefont={"size": caption_size},
            )

            # Vertical guide at x = 0, where the rank-0 model is drawn.
            fig.add_vline(
                x=0,
                line_width=2,
                line_dash="dot",
                line_color=highlight_color,
                col=i + 1,
                row=1,
            )

            # Metric values of the best, worst and middle-ranked models.
            min_metric = df.loc[ranking == 0, df.columns[i]].values[0]
            max_metric = df.loc[ranking == num_models - 1, df.columns[i]].values[0]
            meduim_metric = df.loc[
                ranking == int(num_models / 2), df.columns[i]
            ].values[0]

            # add annotations
            # NOTE(review): `anno_size` is assigned inside this loop but also
            # read after it (colorbar title); it is undefined if the loop
            # body never runs.
            anno_size = caption_size * 0.7
            # The best-model guide line and label are added once per panel
            # (only while processing the first shape).
            if k == 0:
                fig.add_hline(
                    y=min_metric,
                    line_width=2,
                    line_dash="dot",
                    line_color=highlight_color,
                    col=i + 1,
                    row=1,
                )

                # position_y = np.mean(df.iloc[:, i])
                min_annotation = {
                    "x": 0,
                    "y": min_metric,
                    "text": f"Model ID {best_id}<br> Rank No.1",
                    "showarrow": True,
                    "arrowhead": 6,
                    "xanchor": "left",
                    "yanchor": "bottom",
                    "xref": "x",
                    "yref": "y",
                    "font": {"size": anno_size},
                    "ax": -10,
                    "ay": -10,
                    "xshift": 0,
                    "yshift": 0,
                }
                fig.add_annotation(min_annotation, col=i + 1, row=1)
            # NOTE(review): `ids` is a pandas Series, so `in` tests its index
            # labels rather than its values; and `jitter_control[...]` is
            # indexed positionally with a model id. Both presumably rely on
            # model ids being 0..n-1 -- confirm.
            if meduim_id in ids:
                medium_annotation = {
                    "x": math.log10(int(num_models / 2) + 1),
                    "y": meduim_metric + jitter_control[meduim_id],
                    "text": f"Model ID {meduim_id}<br> Rank No.{int(num_models/2)}",
                    "showarrow": True,
                    "arrowhead": 6,
                    "xanchor": "left",
                    "yanchor": "bottom",
                    "xref": "x",
                    "yref": "y",
                    "font": {"size": anno_size},
                    "ax": -10,
                    "ay": -10,
                    "xshift": 0,
                    "yshift": 0,
                }
                fig.add_annotation(medium_annotation, col=i + 1, row=1)
            # NOTE(review): the worst model's marker is drawn at
            # log10(num_models) (rank num_models - 1, plus 1), whereas this
            # annotation is anchored at log10(num_models + 1) -- confirm.
            if worst_id in ids:
                max_annotation = {
                    "x": math.log10(num_models + 1),
                    "y": max_metric + jitter_control[worst_id],
                    "text": f"Model ID {worst_id}<br> Rank No.{num_models}",
                    "showarrow": True,
                    "arrowhead": 6,
                    "xanchor": "left",
                    "yanchor": "bottom",
                    "xref": "x",
                    "yref": "y",
                    "font": {"size": anno_size},
                    "ax": -10,
                    "ay": -10,
                    "xshift": 0,
                    "yshift": 0,
                    "align": "left",
                }
                fig.add_annotation(max_annotation, col=i + 1, row=1)

    # Invisible trace (no data points) whose only purpose is to show a
    # "Low"/"High" colour bar next to the panels.
    colorbar_trace = go.Scatter(
        x=[None],
        y=[None],
        mode="markers",
        hoverinfo="none",
        marker=dict(
            colorscale=[
                rgb01_hex(np.array((243, 244, 245)) / 255),
                "steelblue",
            ],  # "magma",
            showscale=True,
            cmin=0,
            cmax=2,
            colorbar=dict(
                title=None,
                thickness=10,
                tickvals=[0, 2],
                ticktext=["Low", "High"],
                outlinewidth=0,
                orientation="v",
                x=1,
                y=0.5,
            ),
        ),
    )
    fig.add_trace(colorbar_trace)

    fig.update_layout(
        title=title,
        font=dict(family="Arial", size=fig_font_size),
        hovermode="closest",
        width=figsize[0],
        height=figsize[1],
        showlegend=True,
        template="simple_white",
        legend=dict(x=0, y=legend_pos_y, orientation="h"),
    )

    # Shaded band above the legend that carries the subtitle text.
    rectangle = {
        "type": "rect",
        "x0": -0.1,
        "y0": subtitle_pos[0],
        "x1": 1.1,
        "y1": subtitle_pos[1],
        "xref": "paper",
        "yref": "paper",
        "fillcolor": "steelblue",
        "opacity": 0.1,
    }  # 'line': {'color': 'red', 'width': 2},
    fig.add_shape(rectangle)
    subtitle_annotation = {
        "x": -0.1,
        "y": subtitle_pos[1],
        "text": f"<i> The FAIM model (i.e., fairness-aware model) is with model ID {best_id}, out of {num_models} nearly-optimal models.</i>",
        "showarrow": False,
        "xref": "paper",
        "yref": "paper",
        "font": {"size": caption_size * 1.1},
        "align": "left",
    }
    # Shared x-axis caption, centred below all panels.
    xaxis_annotation = {
        "x": 0.5,
        "y": xlab_pos_y,
        "text": "Model Rank",
        "showarrow": False,
        "xref": "paper",
        "yref": "paper",
        "font": {"size": caption_size},
    }
    colorbar_title = {
        "x": 1.05,
        "y": 0.5,
        "text": "Fairness Ranking Index (FRI)",
        "showarrow": False,
        "xref": "paper",
        "yref": "paper",
        "font": {"size": anno_size * 0.9},
        "textangle": 90,
    }
    fig.add_annotation(subtitle_annotation)
    fig.add_annotation(xaxis_annotation)
    fig.add_annotation(colorbar_title)

    # Keep a single legend entry per shape: only traces whose position
    # satisfies i % num_metrics == 1 stay in the legend.
    # NOTE(review): with a single metric column this condition is never
    # true, so no trace would appear in the legend -- confirm.
    for i, trace in enumerate(fig.data):
        if i % num_metrics == 1:
            trace.update(showlegend=True)
        else:
            trace.update(showlegend=False)
    # fig.show()

    return fig, fair_index_df


def plot_radar(df, thresh_show, title, **kwargs):
    """Draw every model as a dotted radar polygon and highlight the best one.

    Args:
        df (pandas.DataFrame): one row per model (index = model id), one
            column per metric; the columns become the radar axes.
        thresh_show (number): upper limit of the radial axis.
        title: accepted but not used by the function body.
        **kwargs: forwarded to ``fig.update_layout``.

    Returns:
        plotly.graph_objects.Figure: the radar chart. The model with the
        smallest ``compute_area`` value is drawn last as a filled, solid
        polygon, and its metrics are printed.
    """
    axis_names = df.columns.tolist()
    axis_names += axis_names[:1]  # repeat the first axis to close the polygon
    palette = sns.diverging_palette(
        200, 20, sep=10, s=50, as_cmap=False, n=df.shape[0]
    )

    def closed_polygon(row):
        # Radii with the first value repeated at the end, plus the hover text
        # listing every metric once.
        radii = row.values.flatten().tolist()
        radii += radii[:1]
        hover = "\n".join(
            f"{axis_names[pos]}: {val:.3f}" for pos, val in enumerate(radii[:-1])
        )
        return radii, hover

    fig = go.Figure()
    areas = []
    for position, model_id in enumerate(df.index):
        row = df.loc[model_id, :]
        areas.append(compute_area(row))
        radii, hover = closed_polygon(row)
        fill_mode = "none"
        if model_id == "FAIReg":
            fill_mode = "toself"
        fig.add_trace(
            go.Scatterpolar(
                r=radii,
                theta=axis_names,
                fill=fill_mode,
                text=hover,
                name=f"{model_id}",
                line={"color": rgb01_hex(palette[position]), "dash": "dot"},
            )
        )

    # Rank 0 is the model with the smallest area.
    ranking = np.argsort(np.argsort(areas))
    best_id = df.index[np.where(ranking == 0)][0]
    print(
        f"The best model is No.{best_id} with metrics on validation set:\n {df.loc[best_id, :]}"
    )

    best_radii, best_hover = closed_polygon(df.loc[best_id, :])
    fig.add_trace(
        go.Scatterpolar(
            r=best_radii,
            theta=axis_names,
            fill="toself",
            text=best_hover,
            name=f"model {best_id}",
            line={"color": "royalblue", "dash": "solid"},
        )
    )

    radial_axis = {
        "showgrid": True,
        "gridwidth": 1,
        "gridcolor": "lightgray",
        "visible": True,
        "range": [0, thresh_show],
    }
    fig.update_layout(
        font={"family": "Arial", "size": 16},
        polar={"radialaxis": radial_axis},
        legend={"x": 0.25, "y": -0.1, "orientation": "h"},
        showlegend=False,
        **kwargs,
    )
    return fig


def plot_bar(
    shap_values, feature_names, original_feature_names, coef=None, title=None, **kwargs
):
    """Plot a horizontal bar chart of feature importance.

    Args:
        shap_values: per-feature SHAP values aligned with ``feature_names``,
            or None to plot ``coef`` instead. Values belonging to the same
            original feature (e.g. one-hot levels named ``<feature>_<level>``)
            are grouped and summarised as the mean absolute value.
        feature_names: names of the (possibly encoded) features.
        original_feature_names: names of the features before encoding; used
            to map an encoded name back to its original feature.
        coef (pandas.Series, optional): coefficients indexed by feature name;
            used only when ``shap_values`` is None.
        title (str, optional): plot title.
        **kwargs: ``color`` sets the bar colour (default "steelblue").

    Returns:
        plotnine.ggplot: the bar chart, sorted by absolute importance, with
        any "const" entry removed.

    Raises:
        ValueError: if both ``shap_values`` and ``coef`` are None.
    """
    color = kwargs.get("color", "steelblue")

    def get_prefix(v):
        # Map an encoded name to the shortest "_"-joined prefix that is an
        # original feature name; names with no matching prefix are kept as-is
        # instead of raising IndexError.
        if "_" in v and (v not in original_feature_names):
            parts = v.split("_")
            for i in range(len(parts)):
                candidate = "_".join(parts[:i])
                if candidate in original_feature_names:
                    return candidate
        return v

    if shap_values is not None:
        # ``axis=0`` is the default and the keyword is deprecated in pandas,
        # so it is not passed.
        grouped_df = pd.DataFrame({"values": shap_values}, index=feature_names).groupby(
            by=get_prefix
        )
        df = {k: np.mean(np.abs(g.values)) for k, g in grouped_df}
        df = pd.DataFrame.from_dict(df, orient="index").reset_index()
        df.columns = ["Var", "Value"]
        df["color"] = ["grey" if i < 0 else "steelblue" for i in df.Value]
        df["order"] = np.abs(df.Value)

    elif coef is not None:
        df = pd.DataFrame(
            {
                "Var": coef.index,
                "Value": coef.values,
                "color": ["grey" if i < 0 else "steelblue" for i in coef.values],
                "order": np.abs(coef.values),
            }
        )
    else:
        raise ValueError("Either shap_value or coef should be provided")

    # Work on a copy so the categorical assignment below does not write into
    # a slice of the unfiltered frame.
    df = df.loc[df["Var"] != "const", :].copy()
    df = df.sort_values(by="order", ascending=True)
    # An ordered categorical keeps the bars in sorted order on the axis.
    df["Var"] = pd.Categorical(df["Var"], categories=df["Var"].tolist(), ordered=True)

    common_theme = theme(
        text=element_text(size=24),
        panel_grid_major_y=element_line(colour="lightgrey"),
        panel_grid_minor=element_blank(),
        panel_background=element_blank(),
        axis_line_x=element_line(colour="black"),
        axis_ticks_major_y=element_blank(),
    )

    x_lab = "Feature importance"

    # The "color" column holds two possible levels ("grey" for negative
    # values, "steelblue" otherwise). Mapping them explicitly keeps the scale
    # valid when both levels occur; a single-element value list cannot colour
    # two levels.
    fill_values = {"grey": "grey", "steelblue": color}

    p = (
        ggplot(data=df, mapping=aes(x="Var", y="Value", fill="color"))
        + geom_hline(yintercept=0, color="grey")
        + geom_bar(stat="identity")
        + common_theme
        + coord_flip()
        + labs(x="", y=x_lab, title=title)
        + theme(legend_position="none")
        + scale_fill_manual(values=fill_values)
    )
    return p



# Fixed seed for reproducibility: seeds NumPy's global RNG and creates the
# module-level `rng` (a RandomState) from which get_ci_auc draws its
# bootstrap indices.
seed = 1234
np.random.seed(seed)
rng = np.random.RandomState(seed)


# metrics
def get_ci_auc(y_true, y_pred, alpha=0.05, type="auc"):
    """Bootstrap a confidence interval for the ROC-AUC or PR-AUC.

    The data are resampled with replacement 1000 times using the module-level
    ``rng``; resamples that contain a single class are skipped because the
    score is undefined for them.

    Args:
        y_true (array-like): True labels.
        y_pred (array-like): Predicted scores or probabilities.
        alpha (float, optional): Significance level; the interval covers
            ``1 - alpha`` of the bootstrap distribution. Default is 0.05.
        type (str, optional): ``'pr'`` for the area under the
            precision-recall curve; any other value (default ``'auc'``) gives
            the ROC-AUC.

    Returns:
        tuple: ``(lower, median, upper)`` -- the lower bound, the median of
        the bootstrap scores, and the upper bound.

    Raises:
        ValueError: if no bootstrap resample contained both classes.
    """
    # Index positionally regardless of the input container (e.g. a pandas
    # Series with a non-default index).
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    n_samples = len(y_pred)

    n_bootstraps = 1000
    bootstrapped_scores = []

    for _ in range(n_bootstraps):
        # Sample indices with replacement. ``randint``'s upper bound is
        # exclusive, so it must be ``n_samples`` for the last observation to
        # be drawable.
        indices = rng.randint(0, n_samples, n_samples)
        resampled_true = y_true[indices]

        if len(np.unique(resampled_true)) < 2:
            continue
        resampled_pred = y_pred[indices]

        if type == "pr":
            precision, recall, _ = precision_recall_curve(
                resampled_true, resampled_pred
            )
            score = auc(recall, precision)
        else:
            score = roc_auc_score(resampled_true, resampled_pred)
        bootstrapped_scores.append(score)

    if not bootstrapped_scores:
        raise ValueError(
            "No bootstrap resample contained both classes; "
            "cannot compute a confidence interval."
        )

    sorted_scores = np.sort(np.array(bootstrapped_scores))
    n_scores = len(sorted_scores)

    # Percentile interval: alpha/2 and 1 - alpha/2 quantiles of the scores.
    lower_idx = int(alpha / 2 * n_scores)
    upper_idx = min(int((1 - alpha / 2) * n_scores), n_scores - 1)
    confidence_lower = sorted_scores[lower_idx]
    confidence_upper = sorted_scores[upper_idx]

    return confidence_lower, np.median(sorted_scores), confidence_upper


def find_optimal_cutoff(target, predicted, method="auc"):
    """Find the optimal probability cutoff for a binary classifier.

    Args:
        target (array-like): True labels.
        predicted (array-like): Predicted scores or probabilities.
        method (str, optional): Selection rule. Default is 'auc'.

            * ``'auc'``: the ROC threshold maximising ``tpr + (1 - fpr)``
              (sensitivity + specificity).
            * ``'pr-auc'``: the precision-recall threshold at which precision
              and recall are closest to each other (break-even point).
              NOTE(review): this mirrors the "closest to zero" selection the
              branch was written with -- confirm it is the intended rule.

    Returns:
        list: List containing the selected cutoff value.

    Raises:
        ValueError: if ``method`` is not one of the supported values.
    """
    if method == "auc":
        fpr, tpr, threshold = roc_curve(target, predicted)
        i = np.arange(len(tpr))
        roc = pd.DataFrame(
            {
                "tf": pd.Series(tpr + (1 - fpr), index=i),
                "threshold": pd.Series(threshold, index=i),
            }
        )
        best = roc.iloc[(roc.tf - 0).abs().argsort()[::-1][:1]]
    elif method == "pr-auc":
        precision, recall, threshold = precision_recall_curve(target, predicted)
        # precision/recall have one more entry than threshold (the final
        # point has no threshold), so drop it to keep the columns aligned.
        i = np.arange(len(threshold))
        prc = pd.DataFrame(
            {
                "tf": pd.Series(precision[:-1] - recall[:-1], index=i),
                "threshold": pd.Series(threshold, index=i),
            }
        )
        best = prc.iloc[(prc.tf - 0).abs().argsort()[:1]]
    else:
        raise ValueError(
            f"Unknown method {method!r}; expected 'auc' or 'pr-auc'."
        )

    return list(best["threshold"])


def get_cal_fairness(df):
    """Average area between the calibration curves of every pair of groups.

    For each group a cubic spline of observed probability (``p_obs``) against
    predicted probability (``p_pred``) is fitted. For every pair of groups
    the absolute difference between the two splines is integrated with
    Simpson's rule over 100 equally spaced points, and the mean over all
    pairs is returned.

    Args:
        df (pandas.DataFrame): must contain the columns ``group``, ``p_pred``
            and ``p_obs``. ``CubicSpline`` requires strictly increasing
            ``p_pred`` values within each group.

    Returns:
        float: mean integrated absolute calibration difference (``nan`` when
        there are fewer than two groups).
    """
    groups = df.group.unique()

    # Integration grid: from the smallest to the largest predicted value seen
    # in any group. It does not depend on the pair, so it is built once.
    # NOTE(review): this range extends beyond the support of individual
    # groups, so the splines are extrapolated there -- confirm this is
    # intended rather than the overlap of the groups' ranges.
    pred_by_group = df.groupby("group")["p_pred"]
    x_min_thresh = pred_by_group.min().min()
    x_max_thresh = pred_by_group.max().max()
    num_points = 100
    x_sample = np.linspace(x_min_thresh, x_max_thresh, num_points)

    def fit_spline(group):
        # Calibration curve of one group, sorted by predicted probability.
        points = df.loc[df.group == group, ["p_obs", "p_pred"]].sort_values(
            by="p_pred", ascending=True
        )
        return CubicSpline(points["p_pred"], points["p_obs"])

    diff_cal = []
    for first, second in combinations(groups, 2):
        spline0 = fit_spline(first)
        spline1 = fit_spline(second)
        y_sample = np.abs(spline1(x_sample) - spline0(x_sample))
        # ``x`` is passed by keyword: recent SciPy releases no longer accept
        # it positionally.
        diff_cal.append(integrate.simpson(y_sample, x=x_sample))

    return np.mean(diff_cal)


## small functions
def col_gap(col_train, col_test, x_with_constant):
    """Align a design matrix with the training columns.

    When the training and test column collections differ in length, every
    training column absent from ``col_test`` is added to ``x_with_constant``
    (in place) filled with 0, and the frame is returned with its columns in
    ``col_train`` order. When the lengths match, the frame is returned
    untouched.
    """
    if len(col_train) == len(col_test):
        return x_with_constant

    # Boolean mask over the training columns: True where the test data lacks it.
    absent_mask = [name not in col_test for name in col_train]
    x_with_constant[col_train[absent_mask]] = 0
    return x_with_constant.loc[:, col_train]


def generate_subsets(input_list):
    """Return every subset of ``input_list`` as a list of lists.

    Subsets are ordered by size (the empty subset first, the full list last)
    and, within a size, in the order produced by ``itertools.combinations``.
    """
    return [
        list(members)
        for size in range(len(input_list) + 1)
        for members in itertools.combinations(input_list, size)
    ]

## fairness_base

In [None]:
# Same fixed seed as used earlier in this file: running these lines reseeds
# NumPy's global RNG and rebinds `rng` to a fresh RandomState.
seed = 1234
np.random.seed(seed)
rng = np.random.RandomState(seed)


class FairBase:
    """Shared data handling, model fitting and fairness testing for baselines.

    The class keeps the training data and variable selection, builds design
    matrices (`data_process`), shapes data for the bias-mitigation library
    (`data_prepare`), fits a logistic regression with or without a bias
    mitigation method (`model`), and evaluates a fitted model on new data
    (`test`).
    """

    def __init__(
        self,
        dat_train,
        selected_vars,
        selected_vars_cat,
        y_name,
        sen_name,
        sen_var_ref,
        without_sen=False,
        weighted=True,
        weights=None,
        class_weight="balanced",
    ):
        """Initialize the fairness base class

        Args:
            dat_train (data frame): training data
            selected_vars (list): selected variables including sensitive variables.
                Sensitive variables missing from this list are appended to it in place.
            selected_vars_cat (list): selected categorical variables
            y_name (str): the name of the label
            sen_name (list or str): the name(s) of the sensitive variable(s). A single
                name is wrapped into a one-element list.
            sen_var_ref (dict): the reference level of the sensitive variables
            without_sen (bool, optional): directly exclude the sensitive variables. Defaults to False.
            weighted (bool, optional): compute the weighted version of metrics "tnr" and "tpr". Defaults to True.
            weights (dict, optional): the weightage for "tnr" and "tpr", summing up to 1.
                Defaults to None, which means {"tnr": 0.5, "tpr": 0.5}.
            class_weight (optional): "balanced" makes `compute_class_weights` return
                class-frequency-based weights; any other value yields unit weights.
                Defaults to "balanced".

        Raises:
            ValueError: if a reference level in `sen_var_ref` does not occur in the
                corresponding column of `dat_train`.
        """

        self.dat_train = dat_train
        self.vars = selected_vars
        self.vars_cat = selected_vars_cat
        self.y_name = y_name

        # Normalise the sensitive variable names to a list *before* validating,
        # so that a single name is treated as one variable rather than being
        # iterated character by character.
        if isinstance(sen_name, list):
            sen_names = sen_name
        elif isinstance(sen_name, tuple):
            sen_names = list(sen_name)
        else:
            sen_names = [sen_name]

        for s in sen_names:
            if sen_var_ref[s] not in dat_train[s].unique():
                raise ValueError(
                    f"Please provide the right reference level of sensitive variables {s}!"
                )
            if s not in self.vars:
                self.vars.append(s)

        self.sen_name = sen_names
        self.sen_var_ref = sen_var_ref

        self.without_sen = without_sen
        self.weighted = weighted
        self.weights = {"tnr": 0.5, "tpr": 0.5} if weights is None else weights
        self.class_weight = class_weight

    def compute_class_weights(self, y):
        """Compute class weights for balanced training

        With ``class_weight == "balanced"`` each sample receives
        ``n_samples / (n_classes * count_of_its_class)``. Labels may be of any
        type that ``numpy.unique`` can handle; they are not required to be the
        integers 0..k-1.

        Args:
            y: target labels

        Returns:
            array of weights for each sample (all ones when class weighting is
            not "balanced")
        """
        if self.class_weight != "balanced":
            return np.ones(len(y))

        labels = np.asarray(y).reshape(-1)
        # `inverse` maps every sample to the position of its class in `classes`,
        # which is also the position of that class's weight.
        classes, inverse, counts = np.unique(
            labels, return_inverse=True, return_counts=True
        )
        class_weights = len(labels) / (len(classes) * counts)
        return class_weights[inverse.reshape(-1)]

    def data_process(
        self, dat, selected_vars=None, selected_vars_cat=None, without_sen=None
    ):
        """Data preprocess

        Args:
            dat (data frame): data
            selected_vars (list, optional): selected variables (can include sensitive variables). This needs to be provided if the case considered is beyond completely inclusion or exclusion of sensitive variables. Defaults to None, in which case both the stored variables and the stored categorical variables are used.
            selected_vars_cat (list, optional): selected categorical variables, subset of selected variables. Defaults to None.
            without_sen (bool, optional): directly exclude the sensitive variables. Defaults to None, which falls back to the instance setting (an instance setting of "auto" never drops them).

        Returns:
            x_1: predictors with one-coding and with constant
            sen_var: the vector of sensitive variable. combined by "_", if there are several sensitive variables
            y: the vector of the label

        """

        def combine_sen(dat, sen):
            # Row-wise "_"-joined string of all sensitive variables.
            new_sen = ["_".join(v) for v in zip(*[dat[s].astype("str") for s in sen])]
            return new_sen

        if selected_vars is None:
            selected_vars = self.vars
            selected_vars_cat = self.vars_cat

        # Keep only the selected predictors; the label is never a predictor.
        x = dat.drop(
            columns=[
                c for c in dat.columns if c == self.y_name or c not in selected_vars
            ]
        )

        # An explicit argument wins; otherwise use the instance default unless
        # it is the special value "auto".
        if without_sen is None:
            drop_sen = self.without_sen != "auto" and self.without_sen
        else:
            drop_sen = without_sen
        if drop_sen:
            # A custom `selected_vars` may already exclude the sensitive
            # variables, in which case there is nothing left to drop.
            x = x.drop(columns=self.sen_name, errors="ignore")

        if len(self.sen_name) > 1:
            sen_var = combine_sen(dat, self.sen_name)
        else:
            sen_var = dat[self.sen_name[0]]

        y = dat[self.y_name]

        x_dm, _ = _util.model_matrix(x=x, x_names_cat=selected_vars_cat)
        x_1 = sm.add_constant(x_dm).astype("float")
        return x_1, sen_var, y

    def _with_binary_sen(self, x_with_constant, dat):
        """Copy a design matrix and add 0/1 sensitive columns (0 = reference level)."""
        x_sen_bin = copy.deepcopy(x_with_constant)
        for s in self.sen_name:
            x_sen_bin[s] = [0 if i == self.sen_var_ref[s] else 1 for i in dat[s]]
        return x_sen_bin

    def data_prepare(self, dat_expl=None):
        """Shape the data to AIF360 format

        Requires `self.method` and `self.method_type`, which are set by `model`.

        Args:
            dat_expl (_type_, optional): validation data needed for post-processing methods. Defaults to None.

        Returns:
            * method "Unawareness": the design matrix without sensitive
              variables, built from `dat_expl` if given, else the training data.
            * method type "pre": a `BinaryLabelDataset` of the training data.
            * method type "post": a tuple `(dataset, dataset_pred)` where the
              second carries the original model's scores and thresholded labels.
            * anything else: None.

        Raises:
            ValueError: for method type "post" when `dat_expl` is None.
        """
        if self.method == "Unawareness":
            x_with_constant, _, _ = self.data_process(
                self.dat_train if dat_expl is None else dat_expl, without_sen=True
            )
            return x_with_constant

        if self.method_type == "pre":
            x_with_constant, _, y_train = self.data_process(self.dat_train)
            x_with_constant_sen_bin = self._with_binary_sen(
                x_with_constant, self.dat_train
            )

            pre_train_df = pd.concat([x_with_constant_sen_bin, y_train], axis=1)
            pre_train = BinaryLabelDataset(
                favorable_label=1,
                df=pre_train_df,
                label_names=[self.y_name],
                protected_attribute_names=self.sen_name,
            )
            return pre_train

        elif self.method_type == "post":
            if dat_expl is None:
                raise ValueError("Please provide validation data.")

            x_with_constant_expl, _, y_expl = self.data_process(dat_expl)
            x_with_constant_expl_sen_bin = self._with_binary_sen(
                x_with_constant_expl, dat_expl
            )

            # Predictions of the unmitigated model, binarised at the cutoff
            # returned by find_optimal_cutoff for this data.
            prob_expl_ori = self.lr_results.predict(x_with_constant_expl)
            ori_thresh = find_optimal_cutoff(y_expl, prob_expl_ori)[0]
            pred_expl_ori = prob_expl_ori > ori_thresh

            post_expl_df = pd.concat([x_with_constant_expl_sen_bin, y_expl], axis=1)
            post_expl = BinaryLabelDataset(
                favorable_label=1,
                df=post_expl_df,
                label_names=[self.y_name],
                protected_attribute_names=self.sen_name,
            )
            post_expl_pred = post_expl.copy(deepcopy=True)
            post_expl_pred.scores = prob_expl_ori.values.reshape(-1, 1)
            post_expl_pred.labels = pred_expl_ori.values.reshape(-1, 1)
            return post_expl, post_expl_pred

    def model(self, method_type=None, method=None, dat_expl=None, **kwargs):
        """Fit the model

        The unmitigated logistic regression is always fitted first and stored
        in `self.lr_results`; the requested method is fitted afterwards.

        Args:
            method_type (str, optional): the type of bias mitigation method (pre/in/post). Defaults to None.
            method (str, optional): the name of bias mitigation method. Defaults to None.
            dat_expl (_type_, optional): validation data needed for post-processing methods. Defaults to None.
            **kwargs: "cost_constraint" for "CalEqOdds"; "ub", "lb" and
                "metric_name" for "ROC".

            Methods:
            +------------------+--------------------------------+
            | Method type      | Specific methods               |
            +==================+================================+
            | None             | "OriginalLR", "Unawareness"   |
            +------------------+--------------------------------+
            | "pre"            | "Reweigh"                      |
            +------------------+--------------------------------+
            | "in"             | "Reductions"                   |
            +------------------+--------------------------------+
            | "post"           | "EqOdds", "CalEqOdds", "ROC"  |
            +------------------+--------------------------------+

        Returns:
            model results that can be used for prediction. Note that "Reweigh"
            returns a tuple `(model, results, instance_weights)` instead of a
            single results object.

        Raises:
            ValueError: for an unknown method type, or an unknown method
                within a known type.
        """
        self.method = method
        self.method_type = method_type

        if isinstance(self.sen_name, list):
            privileged_groups = [{s: 0 for s in self.sen_name}]
            unprivileged_groups = [{s: 1 for s in self.sen_name}]
        else:
            privileged_groups = [{self.sen_name: 0}]
            unprivileged_groups = [{self.sen_name: 1}]

        x_with_constant, _, y_train = self.data_process(self.dat_train)

        # original LR
        glm_kwargs = {}
        if self.class_weight == "balanced":
            glm_kwargs["freq_weights"] = self.compute_class_weights(
                self.dat_train[self.y_name]
            )
        lr_model = sm.GLM(
            self.dat_train[self.y_name],
            x_with_constant,
            family=sm.families.Binomial(),
            **glm_kwargs,
        )
        self.lr_results = lr_model.fit()

        if method_type is None:
            if method == "OriginalLR":
                return self.lr_results

            elif method == "Unawareness":
                # The design matrix without sensitive variables is only needed
                # here, so it is built on demand.
                x_with_constant_nosen, _, _ = self.data_process(
                    self.dat_train, without_sen=True
                )
                un_model = sm.GLM(
                    self.dat_train[self.y_name],
                    x_with_constant_nosen,
                    family=sm.families.Binomial(),
                )
                un_results = un_model.fit()
                return un_results
            else:
                raise ValueError(
                    "Please confirm the method: 'OriginalLR' if no bias mitigation is needed; 'Unawareness' if simply excluding the sensitive variabl is enough."
                )

        elif method_type == "pre":
            pre_train = self.data_prepare()

            if method == "Reweigh":
                reweigh_model = Reweighing(
                    privileged_groups=privileged_groups,
                    unprivileged_groups=unprivileged_groups,
                )
                rw_train = reweigh_model.fit_transform(pre_train)
                rw_model = sm.GLM(
                    self.dat_train[self.y_name],
                    x_with_constant,
                    family=sm.families.Binomial(),
                    freq_weights=rw_train.instance_weights,
                )
                # Side effect: draws a histogram of the instance weights.
                plt.hist(rw_train.instance_weights)
                rw_results = rw_model.fit()
                return rw_model, rw_results, rw_train.instance_weights
            else:
                raise ValueError(
                    "Please specify the type of pre-process bias mitigation method among ['Reweigh']!"
                )

        elif method_type == "in":
            if method == "Reductions":

                constraint = EqualizedOdds(difference_bound=0.01)
                np.random.seed(
                    0
                )  # set seed for consistent results with ExponentiatedGradient
                lr_model_sk = LogisticRegression(max_iter=5000, penalty=None)
                mitigator = ExponentiatedGradient(lr_model_sk, constraint)

                mitigator.fit(
                    x_with_constant,
                    y_train,
                    sensitive_features=self.dat_train[self.sen_name],
                )
                return mitigator
            else:
                raise ValueError(
                    "Please specify the type of in-process bias mitigation method among ['Reductions']!"
                )

        elif method_type == "post":
            post_expl, post_expl_pred = self.data_prepare(dat_expl=dat_expl)

            if method == "EqOdds":
                eq_model = EqOddsPostprocessing(
                    privileged_groups=privileged_groups,
                    unprivileged_groups=unprivileged_groups,
                    seed=seed,
                )
                eq_results = eq_model.fit(post_expl, post_expl_pred)
                return eq_results

            elif method == "CalEqOdds":
                # allowed values: "fnr", "fpr", "weighted"
                cost_constraint = kwargs.get("cost_constraint", "weighted")
                cal_eq_model = CalibratedEqOddsPostprocessing(
                    privileged_groups=privileged_groups,
                    unprivileged_groups=unprivileged_groups,
                    cost_constraint=cost_constraint,
                    seed=seed,
                )
                cal_eq_results = cal_eq_model.fit(post_expl, post_expl_pred)
                return cal_eq_results

            elif method == "ROC":
                ub = kwargs.get("ub", 0.05)
                lb = kwargs.get("lb", -0.05)
                # allowed_metrics = ["Statistical parity difference", "Average odds difference", "Equal opportunity difference"]
                metric_name = kwargs.get("metric_name", "Equal opportunity difference")
                ROC_model = RejectOptionClassification(
                    privileged_groups=privileged_groups,
                    unprivileged_groups=unprivileged_groups,
                    low_class_thresh=0.01,
                    high_class_thresh=0.99,
                    num_class_thresh=100,
                    num_ROC_margin=50,
                    metric_name=metric_name,
                    metric_ub=ub,
                    metric_lb=lb,
                )
                ROC_results = ROC_model.fit(post_expl, post_expl_pred)
                return ROC_results

            else:
                raise ValueError(
                    "Please specify the type of post-process bias mitigation method among ['EqOdds', 'CalEqOdds', 'ROC']!"
                )

        else:
            raise ValueError(
                "Please specify the type of bias mitigation method (pre/in/post)!"
            )

    def test(self, dat_test, model=None, params=None, thresh=None, **kwargs):
        """Test the fairness of the model

        Must be called after `model`, which records the method and method type
        used to decide how predictions are produced.

        Args:
            dat_test (data frame): test data
            model (_type_, optional): the fitted model. Defaults to None.
            params (_type_, optional): coefficients for the model. Currently unused. Defaults to None.
            thresh (_type_, optional): threshold applied to predicted probabilities.
                Defaults to None, in which case the cutoff returned by
                `find_optimal_cutoff` on the test data is used. Ignored for
                methods that predict labels directly ("Reductions" and
                post-processing methods).
            **kwargs: accepted for compatibility; a "without_sen" entry
                currently has no effect.

        Returns:
            pred_test / prob_test (array): predicted probabilities when the model produces them, otherwise predicted labels
            fairmetrics (data frame): fairness metrics
            clametrics (data frame): classification metrics

        Raises:
            ValueError: if `model` is None.
        """
        if model is None:
            raise ValueError("Please provide the right model!")

        x_with_constant_test, sen_var, y_test = self.data_process(dat_test)
        prob_test = None

        if self.method_type == "post":
            _, post_test_pred = self.data_prepare(dat_expl=dat_test)
            pred_test = model.predict(post_test_pred).labels.reshape(-1)
        else:
            if self.method_type is None and self.method == "Unawareness":
                x_with_constant_test = self.data_prepare(dat_expl=dat_test)

            if self.method == "Reductions":
                pred_test = model.predict(x_with_constant_test)
            else:
                prob_test = model.predict(x_with_constant_test)
                # Honour a caller-supplied threshold; only search for one when
                # none was given.
                if thresh is None:
                    thresh = find_optimal_cutoff(y_test, prob_test)[0]
                pred_test = prob_test > thresh

        fe = FAIMEvaluator(
            y_true=y_test,
            y_pred=prob_test,
            y_pred_bin=pred_test,
            sen_var=sen_var,
            weighted=self.weighted,
            weights=self.weights,
        )
        fairmetrics = fe.fairmetrics
        clametrics = fe.clametrics

        return pred_test if prob_test is None else prob_test, fairmetrics, clametrics


## fairness_evaluation

In [None]:
# Suppress every RuntimeWarning from this point on.
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Base metrics, keyed by the short names used throughout FAIMEvaluator, which
# passes this dict as the `metrics` of its MetricFrame.
# NOTE: "conf_mat" must stay the last key -- FAIMEvaluator._fairmetrics_generation
# skips the last entry when running its quality check.
my_fairness_bases = {
    "tpr": recall_score,
    "tnr": true_negative_rate,
    "sr": selection_rate,
    "acc": accuracy_score,
    "conf_mat": confusion_matrix,
}
# Bounds used by the quality check in FAIMEvaluator._fairmetrics_generation:
# each overall base metric is tested for being strictly greater than its bound
# (the confusion matrix has no bound).
my_bases_bound = {"tpr": 0.6, "tnr": 0.6, "sr": 0, "acc": 0.6, "conf_mat": pd.NA}


def fairarea(fairness_metrics):
    """Collapse a set of fairness metrics into a single area-like score.

    The values of ``fairness_metrics`` are flattened and each is paired with
    its successor (wrapping around to the first). With ``n = len(fairness_metrics)``:

    * ``n > 2``: sum of ``v_i * v_{i+1} * sin(2*pi / n)``
    * ``n == 2``: sum of ``v_i * v_{i+1}``
    * otherwise: absolute value of the first flattened value

    Args:
        fairness_metrics: pandas object exposing ``.values``.

    Returns:
        The score described above.
    """
    n_metric = len(fairness_metrics)
    radii = np.array(fairness_metrics.values.flatten().tolist())

    if n_metric < 2:
        return np.abs(radii[0])

    # Cyclic successor of every value: [v1, v2, ..., v0].
    successors = np.roll(radii, -1)
    if n_metric == 2:
        return np.sum(radii * successors)

    theta_c = 2 * np.pi / n_metric
    return np.sum(radii * successors * np.sin(theta_c))


class FAIMEvaluator:
    """Compute group-wise base metrics, fairness metrics and classification metrics.

    After construction the following attributes are available:

    * ``fairsummary``: the ``MetricFrame`` of base metrics per sensitive group
    * ``performance_table``: TPR/FPR/TNR/FNR/SR/Accuracy/BER per group
    * ``fairmetrics``: one-row data frame of fairness metrics
    * ``disparity_table``: min/max/reference summary for three fairness metrics
    * ``qc``: one-row data frame of overall-metric quality checks
    * ``clametrics``: one-row data frame of classification metrics
      (only when ``fair_only`` is False)
    """

    def __init__(
        self,
        y_true,
        y_pred,
        y_pred_bin,
        sen_var,
        fair_only=False,
        cla_metrics=["auc"],
        weighted=False,
        weights=None,
        bases=my_fairness_bases,
        bound=my_bases_bound,
    ):
        """Initialize the fairness evaluator.

        Args:
            y_true: true labels
            y_pred: predicted scores or probabilities (may be None; the binary
                predictions are then used for the AUC)
            y_pred_bin: predicted binary labels
            sen_var: the vector of sensitive variables
            fair_only (bool, optional): whether to compute fairness metrics only. Defaults to False.
            cla_metrics (list, optional): classification metrics. Defaults to ["auc"].
            weighted (bool, optional): whether to create a customized fairness metric based on weighted combining of 'tnr' and 'tpr'. Defaults to False.
            weights (_type_, optional): the weights for weighted combining of 'tnr' and 'tpr'. Required when `weighted` is True. Defaults to None.
            bases (_type_, optional): the bases for fairness metrics. Defaults to my_fairness_bases (see above).
            bound (_type_, optional): the bound for base metrics. Defaults to my_bases_bound.

        Raises:
            TypeError: if `weighted` is set and `weights` is not a dict.
            ValueError: if `weighted` is set and `weights` does not hold two
                values summing to 1, or if `cla_metrics` is None while
                classification metrics are requested.
        """
        self.y_true = y_true
        self.y_pred = y_pred
        self.y_pred_bin = pd.Series(y_pred_bin)
        self.sen_var = pd.Series(sen_var)
        self.cla_metrics = cla_metrics

        self.my_fairness_bases = bases
        self.my_bases_bound = bound

        if weighted:
            if weights is None or not isinstance(weights, dict):
                raise TypeError(
                    "The weights need to be specified and the type should be dict!"
                )
            # Tolerant comparison: sums such as 0.1 + 0.2 + ... are not always
            # exactly representable as 1.0 in floating point.
            if len(weights) != 2 or not np.isclose(
                np.sum(list(weights.values())), 1
            ):
                raise ValueError(
                    "The weights should be a dict containing two values respectively for 'tpr' and 'tnr'. In addition, the sum of weights should be equal to 1!"
                )

        # Always defined, so that the attributes exist for unweighted
        # evaluators as well.
        self.weighted = weighted
        self.weights = weights

        self._fairsummary_generation()
        self._fairmetrics_generation()
        if not fair_only:
            if cla_metrics is None:
                raise ValueError("The classification metrics should be specified!")
            self._clametric_generation()

    @staticmethod
    def _check_sen(y_obs, sen_var, sens_var_ref):
        """Return the sensitive variable with its first observed value as reference.

        `y_obs` and `sens_var_ref` are accepted but not used.
        """
        return {"sens_var": sen_var, "sens_var_ref": pd.unique(sen_var)[0]}

    def _fairsummary_generation(self):
        """Compute the base metrics (e.g., TPR, TNR, etc.) among subgroups.

        Sets `self.fairsummary` (a ``MetricFrame`` built from the binary
        predictions) and `self.performance_table` (per-group TPR, FPR, TNR,
        FNR, SR, Accuracy and BER).
        """
        self.fairsummary = MetricFrame(
            y_true=self.y_true,
            y_pred=self.y_pred_bin,
            metrics=self.my_fairness_bases,
            sensitive_features=self.sen_var,
        )

        by_group = self.fairsummary.by_group
        tpr = by_group["tpr"]
        tnr = by_group["tnr"]

        # Error rates are the complements of TNR / TPR; BER is their mean.
        fpr = 1 - tnr
        fnr = 1 - tpr
        ber = (fpr + fnr) / 2

        self.performance_table = pd.DataFrame({
            "TPR": tpr,
            "FPR": fpr,
            "TNR": tnr,
            "FNR": fnr,
            "SR": by_group["sr"],
            "Accuracy": by_group["acc"],
            "BER": ber
        })

    @staticmethod
    def _find_reference_group(groups):
        """Pick the reference group by name, or None if no group matches.

        Preference order: a group containing '@Male_@White', then one
        containing '@White' but not '@Male', then one containing '@Male' but
        not '@White'.
        """
        # For intersectional (Sex_Race)
        for group in groups:
            if '@Male_@White' in str(group):
                return group

        # For race only
        for group in groups:
            if '@White' in str(group) and '@Male' not in str(group):
                return group

        # For sex only
        for group in groups:
            if '@Male' in str(group) and '@White' not in str(group):
                return group

        return None

    @staticmethod
    def _disparity_row(metric, values, min_type, gap, group_col_name, reference_group):
        """Build one row of the disparity table from per-group `values`.

        The reference columns are filled only when a reference group exists
        and is neither the minimum nor the maximum group; the reference gap is
        the absolute distance between the reference value and the minimum.
        """
        min_group = values.idxmin()
        max_group = values.idxmax()
        row = {
            "Metric": metric,
            "Min Value": values.min(),
            f"Min {group_col_name}": min_group,
            "Min Type": min_type,
            "Max Value": values.max(),
            f"Max {group_col_name}": max_group,
            "Gap": gap
        }

        if reference_group is not None and reference_group not in [min_group, max_group]:
            row[f"Reference {group_col_name}"] = reference_group
            row["Reference Value"] = values[reference_group]
            row["Reference Gap"] = abs(values[reference_group] - values.min())
        else:
            row[f"Reference {group_col_name}"] = None
            row["Reference Value"] = None
            row["Reference Gap"] = None

        return row

    def _fairmetrics_generation(self):
        """Generate fairness metrics and disparity tables.

        Sets `self.fairmetrics`, `self.disparity_table` and `self.qc`.
        """
        by_group = self.fairsummary.by_group
        diff_ = self.fairsummary.difference()

        # ----- quality check: overall base metric must exceed its bound -----
        # The last base (the confusion matrix in the default bases) is skipped.
        qc = {}
        for b in list(self.my_fairness_bases.keys())[:-1]:
            qc[b] = self.fairsummary.overall[b] > self.my_bases_bound[b]

        # ----- machine learning performance-based fairness metrics -----
        fairmetrics = {}
        fairmetrics["Equal Opportunity"] = diff_["tpr"]
        fairmetrics["Equalized Odds"] = np.max([diff_["tpr"], diff_["tnr"]])
        fairmetrics["Statistical Parity"] = diff_["sr"]
        fairmetrics["Accuracy Equality"] = diff_["acc"]

        # BER Equality = range of the balanced error rate across groups
        # (the same formula applies whether or not `weighted` is set).
        fpr_values = 1 - by_group["tnr"]
        fnr_values = 1 - by_group["tpr"]
        ber_values = 0.5 * (fpr_values + fnr_values)
        fairmetrics["BER Equality"] = ber_values.max() - ber_values.min()

        self.fairmetrics = pd.DataFrame([fairmetrics])

        # ----- fairness disparity summary table -----
        tpr_diff = diff_["tpr"]
        fpr_diff = abs(diff_["tnr"])  # FPR diff = TNR diff
        equalized_odds = max(tpr_diff, fpr_diff)

        # Report the per-group values of whichever rate drives Equalized Odds.
        if tpr_diff >= fpr_diff:
            eq_odds_min_type = "TPR"
            eq_odds_values = by_group["tpr"]
        else:
            eq_odds_min_type = "FPR"
            eq_odds_values = 1 - by_group["tnr"]

        # A group name containing "_" is treated as an intersection of several
        # attributes (e.g., 'Male_White'); only the first group is inspected.
        groups = by_group.index.tolist()
        is_intersect = '_' in str(groups[0]) if len(groups) > 0 else False
        group_col_name = "Intersection" if is_intersect else "Group"

        reference_group = self._find_reference_group(groups)

        disparity_data = [
            self._disparity_row(
                "Equalized Odds (max of TPR/FPR)",
                eq_odds_values,
                eq_odds_min_type,
                equalized_odds,
                group_col_name,
                reference_group,
            ),
            self._disparity_row(
                "Equal Opportunity (TPR)",
                by_group["tpr"],
                "TPR",
                tpr_diff,
                group_col_name,
                reference_group,
            ),
            self._disparity_row(
                "BER Equality",
                ber_values,
                "BER",
                fairmetrics["BER Equality"],
                group_col_name,
                reference_group,
            ),
        ]
        self.disparity_table = pd.DataFrame(disparity_data)

        # Round numeric columns to 4 decimals. Non-numeric columns are left
        # alone: a reference column holding only None has nothing to round.
        for col in ["Min Value", "Max Value", "Gap", "Reference Value", "Reference Gap"]:
            if col in self.disparity_table.columns and pd.api.types.is_numeric_dtype(
                self.disparity_table[col]
            ):
                self.disparity_table[col] = self.disparity_table[col].round(4)

        self.qc = pd.DataFrame([qc])

    def _clametric_generation(self):
        """Compute classification metrics and store them in `self.clametrics`.

        Sensitivity and specificity come from the binary predictions. When
        "auc" is requested, `get_ci_auc` is called on the predicted scores
        (or on the binary predictions if no scores were given) and its three
        return values are stored as "auc_low", "auc" and "auc_high".
        """
        clametrics = {}
        pred = self.y_pred if self.y_pred is not None else self.y_pred_bin

        # Add sensitivity (TPR) and specificity (TNR)
        clametrics["sensitivity"] = recall_score(self.y_true, self.y_pred_bin)
        clametrics["specificity"] = true_negative_rate(self.y_true, self.y_pred_bin)

        if "auc" in self.cla_metrics:
            clametrics["auc_low"], clametrics["auc"], clametrics["auc_high"] = (
                get_ci_auc(self.y_true, pred, alpha=0.05, type="auc")
            )

        self.clametrics = pd.DataFrame([clametrics])

## fairness_modelling

In [None]:
class FAIMGenerator(FairBase):
    def __init__(
        self,
        dat_train,
        selected_vars,
        selected_vars_cat,
        y_name,
        sen_name,
        sen_var_ref,
        output_dir,
        criterion="loss",
        epsilon=0.05,
        m=800,
        n_final=350,
        without_sen=False,
        pre=False,
        pre_method="Reweigh",
        post=False,
        post_method="equalizedodds",
        class_weight="balanced",
    ):
        """Initialize the class of FAIM and fit the optimal model.

        Args:
            dat_train (data frame): the training data
            selected_vars (list): the selected variables that include sensitive variables
            selected_vars_cat (list): the selected categorical variables that include sensitive variables
            y_name (str): the name of the label, e.g. "y", "label", etc.
            sen_name (list): the name of the sensitive variables
            sen_var_ref (dict): the reference values of the sensitive variables e.g. {"gender": "F"}
            output_dir: the output directory to store the nearly optimal models results
            criterion (str, optional): the criterion to generate nearly optimal models. Defaults to "loss".
            epsilon (float, optional): the control factor of "nearly optimality", i.e. the gap to the optimal model. Defaults to 0.05.
            m (int, optional): stored as ``self.m`` and passed to ``draw_models`` by ``nearly_optimal_model``. Defaults to 800.
            n_final (int, optional): stored as ``self.n_final`` and passed to ``draw_models`` by ``nearly_optimal_model``. Defaults to 350.
            without_sen (bool or str, optional): forwarded to the base class. Note that ``FAIM_model`` only generates models when this equals "auto". Defaults to False.
            pre (bool, optional): whether to use pre-process bias mitigation methods before FAIM. Defaults to False.
            pre_method (str, optional): specific pre-process method. Defaults to "Reweigh".
            post (bool, optional): whether to use post-process bias mitigation methods after FAIM. Defaults to False.
            post_method (str, optional): specific post-process method. Defaults to "equalizedodds".
            class_weight (str, optional): forwarded to the base class; "balanced" makes ``optimal_model`` use class-balanced sample weights when reweighing is not active. Defaults to "balanced".
        """

        super().__init__(
            dat_train,
            selected_vars,
            selected_vars_cat,
            y_name,
            sen_name,
            sen_var_ref,
            without_sen,
            class_weight=class_weight,
        )

        self.criterion = criterion
        self.output_dir = output_dir
        self.epsilon = epsilon
        self.m = m
        self.n_final = n_final

        # Record every mitigation setting unconditionally so the attributes
        # exist whichever flags were chosen (previously ``post``,
        # ``post_method`` and ``pre_method`` were only set when enabled).
        self.pre = pre
        self.pre_method = pre_method
        self.post = post
        self.post_method = post_method

        if pre:
            self.rw_model, self.rw_results, self.rw_weights = self.pre_mitigate()
            plt.hist(self.rw_weights)

        # Fit the reference model on the full variable set.
        self.optim_obj = self.optimal_model(selected_vars, selected_vars_cat)
        self.optim_results = self.optim_obj.model_optim
        self.optim_model = self.optim_obj.model_optim.model

        # Filled in later by FAIM_model()/fairness_compute() and test().
        self.dat_expl = None
        self.dat_test = None

    # def __reduce__(self):
    #     return (self.__class__, (self.coefs, self.best_coef, self.best_optim_base_obj, self.best_sen_exclusion, self.fairmetrics_df))

    def pre_mitigate(self):
        """Run the configured pre-processing bias mitigation.

        Delegates to ``self.model`` with ``method_type="pre"`` and the method
        named by ``self.pre_method``.

        Returns:
            tuple: the three values produced by ``self.model``; ``__init__``
            stores them as ``rw_model``, ``rw_results`` and ``rw_weights``.
        """
        fitted, results, weights = self.model(
            method=self.pre_method, method_type="pre"
        )
        return fitted, results, weights

    def optimal_model(self, selected_vars, selected_vars_cat):
        """Fit the optimal model on the training data for one variable set.

        Args:
            selected_vars (list): columns of ``self.dat_train`` kept as predictors
            selected_vars_cat (list): categorical predictors, passed on as ``x_names_cat``

        Returns:
            object: the ``model.models`` instance built from the training data
        """

        def resolve_sample_weights():
            # Exactly one of three weighting schemes applies.
            # 1) Reweighing pre-processing: reuse the weights computed in __init__.
            if self.pre and self.pre_method == "Reweigh":
                print("✓ Using Reweighing pre-processing weights (fairness intervention)")
                return self.rw_weights

            # 2) Class balancing: weight = n_samples / (n_classes * class count).
            if self.class_weight == "balanced":
                labels = self.dat_train[self.y_name]
                total = len(labels)
                classes, counts = np.unique(labels, return_counts=True)
                class_weight_dict = {
                    label: total / (len(classes) * count)
                    for label, count in zip(classes, counts)
                }
                per_row = np.array([class_weight_dict[label] for label in labels])
                print(f"✓ Using balanced class weights (for imbalance): {class_weight_dict}")
                return per_row

            # 3) No weighting at all.
            print("✓ No sample weighting")
            return None

        # Drop the outcome and every column that was not requested.
        unwanted = [
            col
            for col in self.dat_train.columns
            if col == self.y_name or col not in selected_vars
        ]
        x = self.dat_train.drop(columns=unwanted)

        sample_weights = resolve_sample_weights()

        return model.models(
            x=x,
            y=self.dat_train[self.y_name],
            x_names_cat=selected_vars_cat,
            output_dir=self.output_dir,
            criterion=self.criterion,
            sample_w=sample_weights,
        )

    def nearly_optimal_model(self, optim_base_obj, m=200, n_final=50, epsilon=None):
        """Sample nearly optimal models around an optimal model and load them.

        Args:
            optim_base_obj (object): the object of the optimal model
            m (int, optional): passed to ``init_hyper_params``. Defaults to 200.
            n_final (int, optional): accepted but not forwarded; the draw itself uses ``self.m`` and ``self.n_final``. Defaults to 50.
            epsilon (float, optional): falls back to ``self.epsilon`` when None. Defaults to None.

        Returns:
            coefs (data frame): contents of "models_near_optim.csv" in ``self.output_dir``
            plots (plot): ``optim_base_obj.models_plot``
        """
        # NOTE(review): the resolved epsilon is not passed on to the sampler
        # in this method -- confirm whether that is intended.
        epsilon = self.epsilon if epsilon is None else epsilon

        hyper_1, hyper_2 = optim_base_obj.init_hyper_params(m=m)
        draw_settings = {
            "u1": hyper_1,
            "u2": hyper_2,
            "m": self.m,
            "n_final": self.n_final,
            "random_state": 1234,
        }
        optim_base_obj.draw_models(**draw_settings)

        # The sampled coefficients are read back from the output directory.
        csv_path = os.path.join(self.output_dir, "models_near_optim.csv")
        sampled = pd.read_csv(csv_path, index_col=0)
        return sampled, optim_base_obj.models_plot

    def fairness_compute(
        self,
        dat_expl,
        optim_base_obj,
        coefs,
        weighted=True,
        weights={"tnr": 0.5, "tpr": 0.5},
        **kwargs,
    ):
        """Compute the fairness metrics of the nearly optimal models

        Every row of ``coefs`` is treated as one model: it is used to predict
        on the validation data, the predictions are binarised at the cutoff
        returned by ``find_optimal_cutoff``, and a ``FAIMEvaluator`` is built
        from the result.

        Args:
            dat_expl (data frame): the data frame of the validation data
            optim_base_obj (object): the object of the optimal model
            coefs (data frame): the coefficients of the nearly optimal models, with a "perf_metric" column
            weighted (bool, optional): whether to use weighted fairness metrics. Defaults to True.
            weights (dict, optional): the weights of the weighted fairness metrics. Defaults to {"tnr": 0.5, "tpr": 0.5}.
            **kwargs: forwarded to ``self.data_process``

        Returns:
            fairmetrics_df (data frame): the fairness metrics of the nearly optimal models, one row per model
            qc_df (data frame): the quality control results of the nearly optimal models, one row per model
        """
        # Copy a dict argument so that neither the shared default nor the
        # caller's dict ends up aliased by ``self.weights``.
        if isinstance(weights, dict):
            weights = dict(weights)

        # Always record the weighting settings: test() reads both attributes,
        # so leaving them unset when ``weighted`` is False would break it.
        self.weighted = weighted
        self.weights = weights
        self.dat_expl = dat_expl

        optim_base_model = optim_base_obj.model_optim.model

        # Loop-invariant work, done once instead of once per sampled model.
        x_with_constant, sen_var, y_expl = self.data_process(dat_expl, **kwargs)
        coef_table = coefs.drop(columns=["perf_metric"])

        fairmetrics_list = []
        qc_list = []

        for i in range(coef_table.shape[0]):
            coef = coef_table.iloc[i, :]

            # Align the validation columns with the coefficient names. The
            # fitted results object is deliberately left untouched rather
            # than having its ``params`` overwritten by the sampled model.
            x_aligned = col_gap(coef.index, x_with_constant.columns, x_with_constant)

            prob_expl = optim_base_model.predict(params=coef, exog=x_aligned)
            thresh = find_optimal_cutoff(y_expl, prob_expl)[0]
            pred_expl = prob_expl > thresh

            fe = FAIMEvaluator(
                y_true=y_expl,
                y_pred=prob_expl,
                y_pred_bin=pred_expl,
                sen_var=sen_var,
                fair_only=True,
                weighted=weighted,
                weights=weights,
            )
            fairmetrics_list.append(fe.fairmetrics)
            qc_list.append(fe.qc)

        # Give both tables a clean 0..n-1 index so rows line up with ``coefs``
        # positionally; the per-model frames would otherwise all carry the
        # same index label.
        fairmetrics_df = pd.concat(fairmetrics_list).reset_index(drop=True)
        qc_df = pd.concat(qc_list).reset_index(drop=True)

        return fairmetrics_df, qc_df

    def compare(self, dat_expl, optim_base_results, selected_vars, selected_vars_cat):
        """Decide whether a sensitive-variable exclusion keeps performance close to the original optimal model

        Args:
            dat_expl (data frame): the data frame of the validation data
            optim_base_results (object): the fitted results of the optimal model for the exclusion case under test
            selected_vars (list): the selected variables of that case
            selected_vars_cat (list): the selected categorical variables of that case

        Returns:
            bool: True when the case is close enough to be expanded to nearly
            optimal models. For "auc" the reduced model's AUC must exceed the
            original AUC times sqrt(1 - epsilon); for "loss" the ratio of the
            two ``loglike`` values must stay below sqrt(1 + epsilon). Any
            other criterion yields None.
        """
        design_full, sen_var, y_expl = self.data_process(
            dat_expl, selected_vars=self.vars, selected_vars_cat=self.vars_cat
        )
        design_reduced, sen_var, y_expl = self.data_process(
            dat_expl, selected_vars=selected_vars, selected_vars_cat=selected_vars_cat
        )
        scores_full = self.optim_results.predict(design_full)
        scores_reduced = optim_base_results.predict(design_reduced)

        criterion = self.criterion

        if criterion == "auc":
            auc_full = roc_auc_score(y_expl, scores_full)
            auc_reduced = roc_auc_score(y_expl, scores_reduced)
            tolerance = np.sqrt(1 - self.epsilon)
            return auc_reduced > auc_full * tolerance

        if criterion != "loss":
            # Unknown criterion: no decision is made.
            return None

        loss_ori = self.optim_model.loglike(self.optim_results.params)
        loss_base = optim_base_results.model.loglike(optim_base_results.params)
        ratio = loss_base / loss_ori
        print(f"loss_ori: {loss_ori}, loss_base: {loss_base}, ratio: {ratio}")
        tolerance = np.sqrt(1 + self.epsilon)
        return ratio < tolerance

    def FAIM_model(self, dat_expl):
        """FAIM: Generate the nearly optimal models and compute the fairness metrics of the nearly optimal models

        For every subset of sensitive variables produced by
        ``generate_subsets``, an optimal model without that subset is fitted.
        Subsets accepted by ``compare`` are expanded into nearly optimal
        models whose fairness and quality-control metrics are appended to
        ``self.fairmetrics_df`` and ``self.qc_df``. Results are also stored
        per subset in ``self.coefs``, ``self.plots`` and
        ``self.optim_base_obj_list``, keyed by the subset joined with "_".

        Nothing is generated unless ``self.without_sen`` equals "auto"; the
        result containers are still reset.

        Args:
            dat_expl (data frame): the data frame of the validation data
        """
        self.dat_expl = dat_expl

        self.coefs = {}
        self.plots = {}
        self.optim_base_obj_list = {}
        self.fairmetrics_df = pd.DataFrame()
        self.qc_df = pd.DataFrame()

        if self.without_sen != "auto":
            return

        sen_senarios = generate_subsets(self.sen_name)
        pbar = tqdm(
            sen_senarios, desc="Generating nearly optimal models", postfix="*Start*"
        )
        for x in pbar:
            pbar.set_postfix(postfix=f"exclusion: {x}")
            key = "_".join(x)

            selected_vars = [i for i in self.vars if i not in x]
            selected_vars_cat = [i for i in self.vars_cat if i not in x]
            optim_base_obj = self.optimal_model(selected_vars, selected_vars_cat)
            optim_base_results = optim_base_obj.model_optim

            if not self.compare(
                dat_expl, optim_base_results, selected_vars, selected_vars_cat
            ):
                print(f"Exclusion of {x} degrades the discrimination performance!")
                continue

            self.optim_base_obj_list[key] = optim_base_obj

            if self.criterion == "auc":
                epsilon = 1 - np.sqrt(1 - self.epsilon)
            elif self.criterion == "loss":
                epsilon = np.sqrt(1 + self.epsilon) - 1
            else:
                # Let nearly_optimal_model fall back to self.epsilon instead
                # of failing on an unbound name.
                epsilon = None

            coefs, plots = self.nearly_optimal_model(
                optim_base_obj, n_final=self.n_final, epsilon=epsilon
            )
            self.coefs[key] = coefs
            self.plots[key] = plots
            fairmetrics_df, qc_df = self.fairness_compute(
                dat_expl,
                optim_base_obj,
                coefs,
                selected_vars=selected_vars,
                selected_vars_cat=selected_vars_cat,
            )

            # Attach the performance metric positionally. Both frames hold one
            # row per row of ``coefs`` in the same order, whereas assigning
            # the Series itself aligns on index labels and mislabels rows
            # whenever the indexes differ.
            perf = coefs["perf_metric"].to_numpy()
            fairmetrics_df["auc"] = perf
            qc_df["auc"] = perf

            fairmetrics_df["sen_var_exclusion"] = key
            qc_df["sen_var_exclusion"] = key
            self.fairmetrics_df = pd.concat([self.fairmetrics_df, fairmetrics_df])
            self.qc_df = pd.concat([self.qc_df, qc_df])

        self.fairmetrics_df = self.fairmetrics_df.reset_index(drop=True)
        self.qc_df = self.qc_df.reset_index(drop=True)

    def describe(self, selected_metrics=None):
        """Describe the distribution of fairness metrics for all nearly optimal models

        Prints how many models pass quality control and, for each metric, the
        qualified model with the smallest value.

        Args:
            selected_metrics (list, optional): the selected fairness metrics, e.g. ["Statistical Parity", "Equalized Odds", "Average Accuracy"]. Defaults to None.

        Raises:
            ValueError: if a selected metric is not a column of the fairness metrics

        Returns:
            fig: the plot of the distribution when every model passes quality
            control; otherwise a tuple of (metrics of the qualified models,
            their excluded sensitive variables, fig)
        """
        sen_var_exclusion = self.fairmetrics_df["sen_var_exclusion"]
        metrics_table = self.fairmetrics_df.drop(columns=["sen_var_exclusion"])
        qc_table = self.qc_df.drop(columns=["sen_var_exclusion", "auc"])

        if selected_metrics is not None:
            for m in selected_metrics:
                if m not in metrics_table.columns:
                    raise ValueError(f"The metric {m} is not in the fairness metrics!")

            qc_table = qc_table[selected_metrics]
            metrics_table = metrics_table[selected_metrics]

        # A model qualifies when its row of QC flags sums to the column count.
        passed_all = (np.sum(qc_table, 1) == qc_table.shape[1]).tolist()
        ids_after_qc = np.arange(qc_table.shape[0])[passed_all]
        print(f"{len(ids_after_qc)} are qualified after quality control")

        qualified = metrics_table.iloc[ids_after_qc, :]
        # Position, within the qualified rows, of each metric's minimum.
        min_ones = qualified.apply(axis=0, func=np.argmin)

        for m in min_ones.index:
            model_id = qualified.index[min_ones[m]]
            print(
                f"the model with minimal {m}: No.{model_id} -- {metrics_table.loc[model_id, m]:.3f}, with {sen_var_exclusion[model_id]} excluded from regression"
            )

        plot_df = self.fairmetrics_df.loc[ids_after_qc, :]
        fig = plot_distribution(plot_df)

        if len(ids_after_qc) >= metrics_table.shape[0]:
            return fig

        return (
            qualified,
            [sen_var_exclusion[i] for i in ids_after_qc],
            fig,
        )

    def transmit(
        self,
        targeted_metrics=["Average Accuracy", "Statistical Parity", "Equalized Odds"],
        thresh_show=0.3,
        best_id=None,
        best_sen_exclusion=None,
        **kwargs,
    ):
        """Select the best model regarding fairness

        Models are kept when every targeted metric is below ``thresh_show``.
        Unless ``best_id`` is given, the kept model with the smallest
        ``fairarea`` over the targeted metrics is selected. Standard errors of
        the selected coefficients are then estimated from a sample of at most
        50000 training rows.

        Args:
            targeted_metrics (list, optional): the targeted fairness metrics. Defaults to ["Average Accuracy", "Statistical Parity", "Equalized Odds"].
            thresh_show (float, optional): the threshold to filter the models. Defaults to 0.3.
            best_id (optional): index of a model to use instead of the automatic choice; requires ``best_sen_exclusion``. Defaults to None.
            best_sen_exclusion (str, optional): key of the exclusion case that ``best_id`` belongs to. Defaults to None.
            **kwargs: ``title`` is passed to the radar and scatter plots

        Raises:
            ValueError: if no model has all targeted metrics below ``thresh_show``

        Returns:
            best_results (dict): "best_coef", "best_sen_exclusion", "best_se" and "best_optim_base_obj" of the selected model
            fair_idx_df: the second value returned by ``plot_scatter``
        """
        # Work on a private copy so the shared default list is never aliased.
        targeted_metrics = list(targeted_metrics)

        # Check the mask's content, not its length: the mask always has one
        # entry per model, so its length says nothing about how many passed.
        ids = (self.fairmetrics_df[targeted_metrics] < thresh_show).all(axis=1)
        if not ids.any():
            raise ValueError("The thresh is too low!")
        fairmetrics_df = self.fairmetrics_df.loc[ids, :]

        sen_var_exclusion = fairmetrics_df["sen_var_exclusion"]
        perf = fairmetrics_df["auc"]
        df = fairmetrics_df[targeted_metrics]

        print(f"There are {df.shape[0]} models for final fairness selection.")

        title = kwargs.get("title")
        plot_radar(df, thresh_show=thresh_show, title=title)
        p, fair_idx_df = plot_scatter(df, perf, sen_var_exclusion, title=title)
        self.p = p
        p.show()

        FAIM_area_list = [fairarea(df.loc[idx, :]) for idx in df.index]

        if best_id is not None:
            assert (
                best_sen_exclusion is not None
            ), "best_sen_exclusion is required when best_id is given"
            self.best_id = best_id
            self.best_sen_exclusion = best_sen_exclusion
        else:
            # One arg-min drives both attributes so they always describe the
            # same model, including when several models tie.
            best_pos = int(np.argmin(FAIM_area_list))
            self.best_id = df.index[best_pos]
            self.best_sen_exclusion = sen_var_exclusion.iloc[best_pos]

        id_senario = [
            index
            for index, item in enumerate(list(self.coefs.keys()))
            if item == self.best_sen_exclusion
        ][0]

        # best_id indexes the concatenated table; subtracting the rows of the
        # preceding cases gives the row within this case's coefficients.
        # NOTE(review): this assumes every case contributed exactly
        # self.n_final models -- confirm against draw_models.
        self.best_coef = (
            self.coefs[self.best_sen_exclusion]
            .drop(columns=["perf_metric"])
            .loc[self.best_id - self.n_final * id_senario, :]
        )
        self.best_optim_base_obj = self.optim_base_obj_list[self.best_sen_exclusion]

        # confidence interval
        dat_uncertainty = self.dat_train.sample(
            n=min(50000, self.dat_train.shape[0]), random_state=42
        )
        excluded_vars = self.best_sen_exclusion.split("_")
        selected_vars = [i for i in self.vars if i not in excluded_vars]
        selected_vars_cat = [i for i in self.vars_cat if i not in excluded_vars]
        x_with_constant = self.data_process(
            dat_uncertainty,
            selected_vars=selected_vars,
            selected_vars_cat=selected_vars_cat,
        )[0].values

        # test() records its input as self.dat_test; restore the previous
        # value so a training sample is not left behind as "test data".
        previous_dat_test = getattr(self, "dat_test", None)
        try:
            prob_train, _, _ = self.test(dat_uncertainty)
        finally:
            self.dat_test = previous_dat_test

        # X' diag(w) X computed by scaling the rows of X, which avoids
        # materialising the n-by-n diagonal matrix.
        row_weights = np.asarray(prob_train * (1 - prob_train))
        fisher_information = x_with_constant.T @ (
            x_with_constant * row_weights[:, None]
        )
        print("multiplication successed!")
        cov = np.linalg.pinv(fisher_information)
        best_se = list(np.sqrt(np.diag(cov)))

        best_results = {
            "best_coef": self.best_coef,
            "best_sen_exclusion": self.best_sen_exclusion,
            "best_se": best_se,
            "best_optim_base_obj": self.best_optim_base_obj,
        }

        return best_results, fair_idx_df

    def post_mitigate(self):
        """Placeholder for post-processing bias mitigation; currently does nothing."""
        return None

    def test(self, dat_test, model=None, params=None, thresh=None):
        """Test the best model regarding fairness

        The data is processed with the variables left after removing the
        best model's excluded sensitive variables, so ``transmit`` must have
        been run first.

        Args:
            dat_test (data frame): the data frame of the test data
            model (object, optional): the object of the model to be tested. Defaults to None.
            params (optional): the parameters of the model to be tested. Defaults to None.
            thresh (optional): the threshold used to binarise the predicted
                probabilities. When None, the cutoff returned by
                ``find_optimal_cutoff`` on this data is used. Defaults to None.

        Methods:
        +--------------+------------+----------------------------------------+
        | Model        | Params     | Description                            |
        +--------------+------------+----------------------------------------+
        | None         | None       | Test the best model produced by FAIM.  |
        +--------------+------------+----------------------------------------+
        | model results| None       | Test the provided fitted results with  |
        |              |            | the parameters embedded.               |
        +--------------+------------+----------------------------------------+
        | model        | as required| Test the provided model with the       |
        |              |            | parameters additionally provided.      |
        +--------------+------------+----------------------------------------+

        Raises:
            ValueError: if ``model`` and ``params`` match none of the rows above

        Returns:
            prob_test: the predicted probabilities of the test data
            fairmetrics: the fairness metrics of the test data
            clametrics: the classification metrics of the test data

        """
        self.dat_test = dat_test
        excluded_vars = self.best_sen_exclusion.split("_")
        selected_vars = [i for i in self.vars if i not in excluded_vars]
        selected_vars_cat = [i for i in self.vars_cat if i not in excluded_vars]
        x_with_constant, sen_var, y_test = self.data_process(
            dat_test, selected_vars=selected_vars, selected_vars_cat=selected_vars_cat
        )

        if model is None:
            prob_test = self.best_optim_base_obj.model_optim.model.predict(
                params=self.best_coef, exog=x_with_constant
            )
        elif isinstance(model, type(self.optim_results)) and params is None:
            prob_test = model.predict(exog=x_with_constant)
        elif isinstance(model, type(self.optim_model)) and params is not None:
            prob_test = model.predict(params=params, exog=x_with_constant)
        else:
            raise ValueError("Please provide the right model!")

        # Honour a caller-supplied threshold; only derive one from the data
        # when none was given.
        if thresh is None:
            thresh = find_optimal_cutoff(y_test, prob_test)[0]
        pred_test = prob_test > thresh

        fe = FAIMEvaluator(
            y_true=np.array(y_test),
            y_pred_bin=pred_test,
            y_pred=prob_test,
            sen_var=sen_var,
            weighted=self.weighted,
            weights=self.weights,
        )
        fairmetrics = fe.fairmetrics
        clametrics = fe.clametrics

        return prob_test, fairmetrics, clametrics


    def explain(self, method="best"):
        """Compute SHAP values for FAIM (best) or baseline (ori) model.

        A ``shap.KernelExplainer`` is built on a k-means summary (at most 50
        centres) of the processed training data and evaluated on a sample of
        at most 200 rows of the processed validation data ``self.dat_expl``.
        The values are also written to "<output_dir>/explain/<method>.csv".

        Args:
            method (str, optional): "best" explains the model selected by
                ``transmit``; any other value explains the original optimal
                model. Defaults to "best".

        Returns:
            The SHAP values, one column per entry of the model's
            ``exog_names``. When the explainer returns a list, its second
            element is used.
        """
        import os

        import pandas as pd
        import shap

        # Choose model + coefficients
        if method == "best":
            coef = self.best_coef
            excluded = set(self.best_sen_exclusion.split("_"))
            used_vars = [v for v in self.vars if v not in excluded]
            used_vars_cat = [v for v in self.vars_cat if v not in excluded]
            model = self.best_optim_base_obj.model_optim.model
        else:  # method == "ori"
            coef = self.optim_results.params
            used_vars = self.vars
            used_vars_cat = self.vars_cat
            model = self.optim_model

        # Columns the chosen model was fitted on, in its own order.
        model_cols = model.exog_names

        # ---- Build background data ----
        # Reduce to the model's columns before summarising.
        bg_full = self.data_process(
            self.dat_train,
            selected_vars=used_vars,
            selected_vars_cat=used_vars_cat,
        )[0][model_cols]
        # Never ask for more centres than there are rows.
        bg_data = shap.kmeans(bg_full, k=min(50, len(bg_full)))

        # ---- Build explain data ----
        ex_full = self.data_process(
            self.dat_expl,
            selected_vars=used_vars,
            selected_vars_cat=used_vars_cat,
        )[0][model_cols]
        # Sampling more rows than exist would raise, so cap the sample size.
        ex_data = ex_full.sample(n=min(200, len(ex_full)), random_state=42)

        # ---- Model function that always uses the model's columns ----
        def f(X):
            if not isinstance(X, pd.DataFrame):
                X = pd.DataFrame(X, columns=model_cols)
            return model.predict(params=coef, exog=X)

        # ---- Run SHAP ----
        explainer = shap.KernelExplainer(f, bg_data)
        shap_vals = explainer.shap_values(ex_data)

        # KernelExplainer output may be list
        shap_vals = shap_vals[1] if isinstance(shap_vals, list) else shap_vals

        # ---- Save ----
        output_dir = os.path.join(self.output_dir, "explain")
        os.makedirs(output_dir, exist_ok=True)

        pd.DataFrame(shap_vals, columns=model_cols).to_csv(
            os.path.join(output_dir, f"{method}.csv")
        )

        return shap_vals


    def compare_explain(self, overide=True, top_n=None):
        """Compare the SHAP values of the best model and original model

        Draws two horizontal bar charts of mean absolute SHAP value per
        feature (best model on the left, original model on the right) and
        saves the figure as "shap_compare.png" in "<output_dir>/explain".

        Args:
            overide (bool, optional): recompute the SHAP values via ``explain``
                even when a cached CSV exists. Defaults to True.
            top_n (int, optional): number of top features shown per panel;
                None shows all. Defaults to None.

        Returns:
            fig: the matplotlib figure holding both panels
        """
        import matplotlib.pyplot as plt
        import os
        import pandas as pd
        import numpy as np

        output_dir = os.path.join(self.output_dir, "explain")
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        def shap_matrix(method):
            # Reuse the cached CSV only when allowed and present.
            cached = os.path.join(output_dir, f"{method}.csv")
            if overide or not os.path.exists(cached):
                return self.explain(method=method)
            return pd.read_csv(cached, index_col=0).values

        def shorten(lbl, max_len=40):
            lbl = lbl.replace("_", " ")
            return lbl[:max_len] + "..." if len(lbl) > max_len else lbl

        def ranked(features, matrix):
            # Mean absolute SHAP value per feature, largest first.
            importance = np.mean(np.abs(matrix), axis=0)
            limit = len(features) if top_n is None else top_n
            table = (
                pd.DataFrame({"feature": features, "shap": importance})
                .sort_values("shap", ascending=False)
                .head(limit)
            )
            table["feature"] = table["feature"].apply(shorten)
            return table

        def annotate_bars(ax, values, fmt="{:.4f}"):
            # Write each bar's value at the end of the bar.
            for row, value in enumerate(values):
                ax.text(
                    value,
                    row,
                    fmt.format(value),
                    va="center",
                    ha="left",
                    fontsize=16,
                )

        # Both SHAP matrices are obtained before any aggregation.
        shap_best = shap_matrix("best")
        shap_ori = shap_matrix("ori")

        shap_best_df = ranked(self.best_coef.index, shap_best)
        shap_ori_df = ranked(self.optim_results.params.index, shap_ori)

        fig, axes = plt.subplots(1, 2, figsize=(22, 12), sharex=False)

        panels = (
            (
                axes[0],
                shap_best_df,
                "Fairness-aware model (FAIM) - Top SHAP features",
                {},
            ),
            (
                axes[1],
                shap_ori_df,
                "Fairness-unaware model (Baseline) - Top SHAP features",
                {"color": "darkorange"},
            ),
        )
        spine_visibility = (
            ("top", False),
            ("right", False),
            ("left", True),
            ("bottom", True),
        )

        for ax, table, heading, bar_style in panels:
            ax.barh(table["feature"], table["shap"], **bar_style)
            ax.invert_yaxis()
            ax.set_title(heading, fontsize=20, weight="bold")
            ax.set_xlabel("Mean |SHAP value|", fontsize=18)
            annotate_bars(ax, table["shap"].values)
            ax.tick_params(axis="y", labelsize=16)
            ax.tick_params(axis="x", labelsize=14)
            for side, visible in spine_visibility:
                ax.spines[side].set_visible(visible)

        plt.tight_layout(pad=4)
        plt.subplots_adjust(wspace=0.4)

        plt.savefig(os.path.join(output_dir, "shap_compare.png"), dpi=300)

        return fig

## fairness_plotting

In [None]:
def rgb01_hex(col):
    """Convert a colour with three channels on a 0-1 scale to a "#rrggbb" string.

    Each channel is multiplied by 255 and rounded before being written as
    two lowercase hexadecimal digits.
    """
    channels = tuple(round(level * 255) for level in col)
    return "#%02x%02x%02x" % channels


def compute_area(fairness_metrics):
    """Summarise a set of fairness metrics as a single area-like score.

    With more than two metrics, each value is multiplied by its cyclic
    successor and by sin(2*pi / n), and the products are summed. With exactly
    two metrics the sine factor is omitted. With fewer than two, the absolute
    value of the first metric is returned.
    """
    n_metric = len(fairness_metrics)
    values = np.array(fairness_metrics.values.flatten().tolist())

    if n_metric < 2:
        return np.abs(values[0])

    # Pair every value with the next one, wrapping the last back to the first.
    successors = np.roll(values, -1)

    if n_metric == 2:
        return np.sum(values * successors)

    theta_c = 2 * np.pi / n_metric
    return np.sum(values * successors * np.sin(theta_c))


def plot_perf_metric(
    perf_metric, eligible, x_range, select=None, plot_selected=False, x_breaks=None
):
    """Plot performance metrics of sampled models.

    Parameters
    ----------
    perf_metric : numpy.array or pandas.Series
        Numeric vector of performance metrics for all sampled models.
    eligible : numpy.array or pandas.Series
        Boolean vector of the same length as 'perf_metric', indicating
        whether each sample is eligible.
    x_range : list
        Numeric vector indicating the range of eligible values for
        performance metrics. Drawn as dashed vertical lines.
    select : list or numpy.array, optional (default: None)
        Positional indexes of 'perf_metric' to be selected. Only used when
        'plot_selected' is True; if missing or invalid, all models are used.
    plot_selected : bool, optional (default: False)
        Whether performance metrics of selected models should be plotted in
        a secondary figure.
    x_breaks : list, optional (default: None)
        If selected models are to be plotted, the breaks to use in the
        histogram.

    Returns
    -------
    plot : plotnine.ggplot or tuple of plotnine.ggplot
        The histogram of all sampled models. When 'plot_selected' is True a
        tuple ``(plot, plot2)`` is returned, where ``plot2`` is the
        histogram of the selected models.
    """

    def _left_title_theme():
        # Left-aligned plot title, centred axis titles (shared by both plots).
        return pn.themes.theme(
            title=pn.themes.element_text(ha="left"),
            axis_title_x=pn.themes.element_text(ha="center"),
            axis_title_y=pn.themes.element_text(ha="center"),
        )

    # pandas re-labels a *named* Series against ``columns`` and would produce
    # an all-NaN "perf_metric" column; use the raw values instead.
    if isinstance(perf_metric, pd.Series):
        perf_metric = perf_metric.values

    m = len(perf_metric)
    n_eligible = np.sum(eligible)
    perf_df = pd.DataFrame(perf_metric, columns=["perf_metric"], index=None)
    plot = (
        pn.ggplot(perf_df, pn.aes(x="perf_metric"))
        + pn.geoms.geom_histogram(
            breaks=np.linspace(np.min(perf_metric), np.max(perf_metric), 40)
        )
        + pn.geoms.geom_vline(xintercept=x_range, linetype="dashed", size=0.7)
        + pn.labels.labs(
            x="Ratio of loss to minimum loss",
            title="""Loss of {m:d} sampled models
                \n{n_elg:d} ({per_elg:.1f}%) sampled models are eligible""".format(
                m=m, n_elg=n_eligible, per_elg=n_eligible * 100 / m
            ),
        )
        + pn.themes.theme_bw()
        + _left_title_theme()
    )
    if not plot_selected:
        return plot

    if select is None:
        print("'select' vector is not specified!\nUsing all models instead")
        select = list(range(len(perf_df)))
    try:
        perf_select = perf_df.iloc[select]
    except Exception:
        # Best-effort fallback for any invalid indexer, without swallowing
        # KeyboardInterrupt/SystemExit as a bare ``except`` would.
        print(
            "Invalid indexes detected in 'select' vector!\nUsing all models instead"
        )
        select = list(range(len(perf_df)))
        perf_select = perf_df.iloc[select]

    plot2 = (
        pn.ggplot(perf_select, pn.aes(x="perf_metric"))
        + pn.geoms.geom_histogram(breaks=x_breaks)
        + pn.labels.labs(
            x="Ratio of loss to minimum loss",
            # Count the rows actually selected so a boolean mask is
            # reported correctly as well.
            title="{n_select:d} selected models".format(n_select=len(perf_select)),
        )
        + pn.themes.theme_bw()
        + _left_title_theme()
    )
    return (plot, plot2)


def plot_distribution(df, s=4):
    """Plot one histogram per metric, coloured by sensitive-variable exclusion.

    Parameters
    ----------
    df : pandas.DataFrame
        Every column except the last two is plotted as a metric. A column
        named ``sen_var_exclusion`` must be present; it is used as the hue.
        Its values are strings in which excluded variables are joined by
        "_" and "" means that nothing was excluded.
    s : number, optional (default: 4)
        Side length of each subplot, in the units of matplotlib's
        ``figsize``.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure holding the histograms. It is also shown via
        ``plt.show()`` before being returned.
    """

    def _exclusion_label(raw):
        # Turn e.g. "a_b_c" into "Exclusion of a, b and c".
        if raw == "":
            return "No exclusion"
        parts = raw.split("_")
        if len(parts) == 2:
            return f"Exclusion of {' and '.join(parts)}"
        if len(parts) > 2:
            return f"Exclusion of {', '.join(parts[:-1])} and {parts[-1]}"
        return f"Exclusion of {raw}"

    num_metrics = df.shape[1] - 2
    # Build a fresh list instead of assigning into the array returned by
    # ``unique()``, which is not writable for every dtype.
    labels = [_exclusion_label(raw) for raw in df.sen_var_exclusion.unique()]

    # ``squeeze=False`` keeps ``axes`` two-dimensional, so a single metric
    # does not hand back a bare Axes that cannot be indexed.
    fig, axes = plt.subplots(
        nrows=1, ncols=num_metrics, figsize=(s * num_metrics, s), squeeze=False
    )
    for i, x in enumerate(df.columns[:-2]):
        ax = axes[0, i]
        sns.histplot(
            data=df, x=x, hue="sen_var_exclusion", bins=50, ax=ax, legend=False
        )
        ax.set_title(x)
        ax.set_xlabel("")
        ax.set_ylabel("Count" if i == 0 else "")

    # The legend is attached to the current (last) axes and placed to the
    # right of it.
    plt.legend(
        loc="center left",
        title="",
        labels=labels[::-1],
        ncol=1,
        bbox_to_anchor=(1.04, 0.5),
        borderaxespad=0,
    )
    plt.show()

    return fig


def plot_scatter(df, perf, sen_var_exclusion, title, c1=20, c2=0.15, **kwargs):
    """Scatter every fairness metric against the fairness rank of each model.

    One subplot is drawn per column of ``df``. The x-axis is
    ``log10(rank + 1)`` where rank 0 is the model with the largest fairness
    index (``1 / compute_area(row)``); the y-axis is the metric value plus a
    small random display jitter. The best, median and worst ranked models
    are annotated in every subplot.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per model (the index holds the model ids) and one column per
        fairness metric. A column named "Equalized Odds" is required.
    perf : object
        Not used by this function.
    sen_var_exclusion : pandas.Series
        Exclusion case of each model, in the same order as the rows of
        ``df``. Values are strings in which excluded variables are joined by
        "_"; "" means no exclusion. At most four distinct cases are
        supported (one marker symbol each).
    title : str
        Figure title.
    c1 : number, optional (default: 20)
        The caption font size is ``figure height / c1``.
    c2 : number, optional (default: 0.15)
        Relative vertical offset used to place the legend, the subtitle and
        the x-axis label.
    **kwargs
        Only ``figsize`` (``[width, height]``) is read.

    Returns
    -------
    fig : plotly figure
        The figure created with ``make_subplots``.
    fair_index_df : pandas.DataFrame
        Per-model fairness index, ranking, marker styling and jitter.

    Notes
    -----
    ``np.random.seed(0)`` is called, which resets NumPy's global random
    state.
    """

    def _case_label(raw):
        # Turn e.g. "a_b_c" into "Exclusion of a, b and c".
        if raw == "":
            return "No exclusion"
        parts = raw.split("_")
        if len(parts) == 2:
            return f"Exclusion of {' and '.join(parts)}"
        if len(parts) > 2:
            return f"Exclusion of {', '.join(parts[:-1])} and {parts[-1]}"
        return f"Exclusion of {raw}"

    ### basic settings ###
    np.random.seed(0)  # fixed seed so that the display jitter is reproducible
    if "figsize" in kwargs:
        figsize = kwargs["figsize"]
    else:
        fig_h = 400
        figsize = [fig_h * df.shape[1] * 2.45 / 3, fig_h]
    caption_size = figsize[1] / c1  # control font size / figure size
    fig_caption_ratio = 0.8
    fig_font_size = caption_size * fig_caption_ratio
    # Defined up front: it is needed both inside the plotting loop and for
    # the colorbar title after it.
    anno_size = caption_size * 0.7

    font_family = "Arial"
    highlight_color = "#D4AF37"
    fig_font_unit = c2  # control the relative position of elements
    caption_font_unit = fig_font_unit * fig_caption_ratio
    d = fig_font_unit / 8
    legend_pos_y = 1 + fig_font_unit
    subtitle_pos = [legend_pos_y + d, legend_pos_y + d + caption_font_unit]
    xlab_pos_y = -fig_font_unit * 2

    num_metrics = df.shape[1]
    num_models = df.shape[0]

    ### fairness index and ranking ###
    area_list = [1 / compute_area(df.loc[model_id, :]) for model_id in df.index]
    # ``order[r]`` is the row position of the model with rank ``r`` and
    # ``ranking`` is its inverse: the rank of every row (0 = fairest).
    order = np.argsort(area_list)[::-1]
    ranking = np.argsort(order)

    # Derive the highlighted models from ``order`` so that the id, metric
    # value and jitter of an annotation always belong to the same model.
    medium_rank = int(num_models / 2)
    worst_rank = num_models - 1
    best_pos = order[0]
    medium_pos = order[medium_rank]
    worst_pos = order[worst_rank]
    best_id = df.index[best_pos]
    medium_id = df.index[medium_pos]
    worst_id = df.index[worst_pos]

    # jittering for display: one random draw per model, none for rank 0
    jitter_control = np.zeros(len(ranking))
    for idx, rank in enumerate(ranking):
        if rank == 0:
            continue
        if rank <= 10:
            jitter_control[idx] = 0.01 * np.random.uniform(0, 1)
        elif rank <= 10**3:
            jitter_control[idx] = 0.015 * np.random.uniform(0, 1)
        else:
            jitter_control[idx] = 0.02 * np.random.uniform(-1, 1)

    ### plot ###
    fig = make_subplots(cols=num_metrics, rows=1, horizontal_spacing=0.13)
    cmap = sns.light_palette("steelblue", as_cmap=False, n_colors=df.shape[0])
    cmap = cmap[::-1]
    colors = [rgb01_hex(cmap[x]) if x != 0 else highlight_color for x in ranking]
    sizes = [10 if x != 0 else 20 for x in ranking]

    # One marker symbol per exclusion case.
    raw_cases = list(sen_var_exclusion.unique())
    shapes_candidates = ["square", "circle", "triangle-up", "star"][: len(raw_cases)]
    shapes = sen_var_exclusion.tolist()
    for i, case in enumerate(raw_cases):
        for j, v in enumerate(sen_var_exclusion):
            if v == case:
                shapes[j] = shapes_candidates[i]
    cases = [_case_label(case) for case in raw_cases]

    fair_index_df = pd.DataFrame(
        {
            "model id": df.index,
            "fair_index": area_list,
            "ranking": ranking,
            "eod": df["Equalized Odds"],
            "colors": colors,
            "shapes": shapes,
            "sizes": sizes,
            "cases": sen_var_exclusion,
            "jitter": jitter_control,
        }
    )

    def _arrow_annotation(x, y, text, **extra):
        # Shared layout of the best / median / worst model annotations.
        annotation = {
            "x": x,
            "y": y,
            "text": text,
            "showarrow": True,
            "arrowhead": 6,
            "xanchor": "left",
            "yanchor": "bottom",
            "xref": "x",
            "yref": "y",
            "font": {"size": anno_size},
            "ax": -10,
            "ay": -10,
            "xshift": 0,
            "yshift": 0,
        }
        annotation.update(extra)
        return annotation

    # Add scatter plots to the subplots: one trace per (case, metric)
    for k, s in enumerate(shapes_candidates):
        # Row positions of the models drawn with symbol ``s``; everything
        # below that depends only on the group is computed once per group.
        s_idx = [idx for idx, x in enumerate(shapes) if x == s]
        group = fair_index_df.loc[fair_index_df.shapes == s]
        col = group["colors"]
        size = group["sizes"]
        js = group["jitter"].values
        r = group["ranking"].apply(lambda x: math.log10(x + 1)).values
        hovertext = [
            f"Fairness index: {fi:.3f}. Ranking: {rk}. Model id: {mid}"
            for fi, rk, mid in zip(
                group["fair_index"], group["ranking"], group["model id"]
            )
        ]

        for i in range(num_metrics):
            jittered_x = df.iloc[s_idx, i].values + js

            fig.add_trace(
                go.Scatter(
                    x=r,
                    y=jittered_x,
                    customdata=hovertext,
                    mode="markers",
                    marker=dict(
                        color=col,
                        symbol=s,
                        size=size,
                        line=dict(color=col, width=1),
                        opacity=0.8,
                    ),
                    hovertemplate="%{customdata}.",
                    hoverlabel=None,
                    hoverinfo="name+z",
                    name=cases[k],
                ),
                col=i + 1,
                row=1,
            )

            if k == 0:
                # Axis styling and guide lines are needed once per subplot,
                # not once per marker group. They are added after the first
                # trace so that the subplot is not empty.
                fig.update_xaxes(
                    title_text=None,
                    tickvals=[0, 1, 2, 3],
                    ticktext=[1, 10, 100, 1000],
                    col=i + 1,
                    row=1,
                    tickangle=0,
                )
                fig.update_yaxes(
                    title_text=df.columns[i],
                    col=i + 1,
                    row=1,
                    showticksuffix="none",
                    # ``title_font`` replaces the deprecated ``titlefont``.
                    title_font={"size": caption_size},
                )
                fig.add_vline(
                    x=0,
                    line_width=2,
                    line_dash="dot",
                    line_color=highlight_color,
                    col=i + 1,
                    row=1,
                )

                min_metric = df.iloc[best_pos, i]
                fig.add_hline(
                    y=min_metric,
                    line_width=2,
                    line_dash="dot",
                    line_color=highlight_color,
                    col=i + 1,
                    row=1,
                )
                fig.add_annotation(
                    _arrow_annotation(
                        0, min_metric, f"Model ID {best_id}<br> Rank No.1"
                    ),
                    col=i + 1,
                    row=1,
                )

            # The median / worst models are annotated only in the group that
            # contains them, i.e. once per subplot. Positions and labels use
            # the same 0-based rank as the markers (x = log10(rank + 1)) and
            # a 1-based rank in the text.
            if medium_pos in s_idx:
                fig.add_annotation(
                    _arrow_annotation(
                        math.log10(medium_rank + 1),
                        df.iloc[medium_pos, i] + jitter_control[medium_pos],
                        f"Model ID {medium_id}<br> Rank No.{medium_rank + 1}",
                    ),
                    col=i + 1,
                    row=1,
                )
            if worst_pos in s_idx:
                fig.add_annotation(
                    _arrow_annotation(
                        math.log10(worst_rank + 1),
                        df.iloc[worst_pos, i] + jitter_control[worst_pos],
                        f"Model ID {worst_id}<br> Rank No.{num_models}",
                        align="left",
                    ),
                    col=i + 1,
                    row=1,
                )

    # Invisible trace whose only purpose is to display the colorbar.
    colorbar_trace = go.Scatter(
        x=[None],
        y=[None],
        mode="markers",
        hoverinfo="none",
        marker=dict(
            colorscale=[
                rgb01_hex(np.array((243, 244, 245)) / 255),
                "steelblue",
            ],
            showscale=True,
            cmin=0,
            cmax=2,
            colorbar=dict(
                title=None,
                thickness=10,
                tickvals=[0, 2],
                ticktext=["Low", "High"],
                outlinewidth=0,
                orientation="v",
                x=1,
                y=0.5,
            ),
        ),
    )
    fig.add_trace(colorbar_trace)

    fig.update_layout(
        title=title,
        font=dict(family=font_family, size=fig_font_size),
        hovermode="closest",
        width=figsize[0],
        height=figsize[1],
        showlegend=True,
        template="simple_white",
        legend=dict(x=0, y=legend_pos_y, orientation="h"),
    )

    # Shaded band behind the subtitle.
    rectangle = {
        "type": "rect",
        "x0": -0.1,
        "y0": subtitle_pos[0],
        "x1": 1.1,
        "y1": subtitle_pos[1],
        "xref": "paper",
        "yref": "paper",
        "fillcolor": "steelblue",
        "opacity": 0.1,
    }
    fig.add_shape(rectangle)
    subtitle_annotation = {
        "x": -0.1,
        "y": subtitle_pos[1],
        "text": f"<i> The FAIM model (i.e., fairness-aware model) is with model ID {best_id}, out of {num_models} nearly-optimal models.</i>",
        "showarrow": False,
        "xref": "paper",
        "yref": "paper",
        "font": {"size": caption_size * 1.1},
        "align": "left",
    }
    xaxis_annotation = {
        "x": 0.5,
        "y": xlab_pos_y,
        "text": "Model Rank",
        "showarrow": False,
        "xref": "paper",
        "yref": "paper",
        "font": {"size": caption_size},
    }
    colorbar_title = {
        "x": 1.05,
        "y": 0.5,
        "text": "Fairness Ranking Index (FRI)",
        "showarrow": False,
        "xref": "paper",
        "yref": "paper",
        "font": {"size": anno_size * 0.9},
        "textangle": 90,
    }
    fig.add_annotation(subtitle_annotation)
    fig.add_annotation(xaxis_annotation)
    fig.add_annotation(colorbar_title)

    # Traces are ordered case by case, metric by metric; showing the legend
    # only for the second metric of each case gives one entry per case.
    for i, trace in enumerate(fig.data):
        trace.update(showlegend=(i % num_metrics == 1))

    return fig, fair_index_df


def plot_radar(df, thresh_show, title, **kwargs):
    """Draw a radar (polar) chart with one outline per model.

    Every row of ``df`` is drawn as a dotted closed outline over the columns
    of ``df``. The model with the smallest ``compute_area`` is then drawn
    again as a solid, filled outline and its metrics are printed.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per model (the index holds the model ids), one column per
        metric.
    thresh_show : number
        Upper limit of the radial axis; the lower limit is 0.
    title : object
        Accepted but not used (the layout title is disabled).
    **kwargs
        Forwarded to ``fig.update_layout``.

    Returns
    -------
    fig : plotly.graph_objects.Figure
    """
    radar = go.Figure()
    palette = sns.diverging_palette(
        200, 20, sep=10, s=50, as_cmap=False, n=df.shape[0]
    )

    # Repeat the first axis at the end so that every outline is closed.
    axis_names = df.columns.tolist()
    axis_names = axis_names + axis_names[:1]

    def _closed_ring(row):
        ring = row.values.flatten().tolist()
        ring.extend(ring[:1])
        return ring

    def _hover_text(ring):
        # The duplicated closing point is left out of the hover text.
        return "\n".join(
            f"{axis_names[pos]}: {val:.3f}" for pos, val in enumerate(ring[:-1])
        )

    areas = []
    for pos, model_id in enumerate(df.index):
        row = df.loc[model_id, :]
        areas.append(compute_area(row))
        ring = _closed_ring(row)
        outline = go.Scatterpolar(
            r=ring,
            theta=axis_names,
            fill="toself" if model_id == "FAIReg" else "none",
            text=_hover_text(ring),
            name=f"{model_id}",
            line=dict(color=rgb01_hex(palette[pos]), dash="dot"),
        )
        radar.add_trace(outline)

    # The first entry of the ascending sort order is the smallest area.
    best_id = df.index[np.argsort(areas)[0]]
    print(
        f"The best model is No.{best_id} with metrics on validation set:\n {df.loc[best_id, :]}"
    )

    best_ring = _closed_ring(df.loc[best_id, :])
    radar.add_trace(
        go.Scatterpolar(
            r=best_ring,
            theta=axis_names,
            fill="toself",
            text=_hover_text(best_ring),
            name=f"model {best_id}",
            line=dict(color="royalblue", dash="solid"),
        )
    )

    radial_axis = dict(
        showgrid=True,
        gridwidth=1,
        gridcolor="lightgray",
        visible=True,
        range=[0, thresh_show],
    )
    radar.update_layout(
        font=dict(family="Arial", size=16),
        polar=dict(radialaxis=radial_axis),
        legend=dict(x=0.25, y=-0.1, orientation="h"),
        showlegend=False,
        **kwargs,
    )
    return radar


def plot_bar(
    shap_values, feature_names, original_feature_names, coef=None, title=None, **kwargs
):
    """Plot a horizontal bar chart of feature importance.

    Parameters
    ----------
    shap_values : array-like or None
        One value per entry of ``feature_names``. When given, values are
        grouped by original feature (see below) and the mean absolute value
        of each group is plotted. Takes precedence over ``coef``.
    feature_names : list
        Names matching ``shap_values``; a name of the form
        ``<original>_<suffix>`` is mapped back to its original feature.
    original_feature_names : collection
        Names of the original features, used to resolve the prefixes.
    coef : pandas.Series, optional (default: None)
        Coefficients indexed by variable name; used only when
        ``shap_values`` is None. Negative coefficients are drawn in grey.
    title : str, optional (default: None)
        Plot title.
    **kwargs
        Only ``color`` is read: the fill colour of the non-negative bars
        (default "steelblue").

    Returns
    -------
    p : plotnine.ggplot

    Raises
    ------
    ValueError
        If both ``shap_values`` and ``coef`` are None.
    """
    color = kwargs.get("color", "steelblue")

    def get_prefix(v):
        # Map e.g. "race_white" to "race" when "race" is an original feature.
        # If no prefix is an original feature, keep the name unchanged.
        if "_" in v and (v not in original_feature_names):
            parts = v.split("_")
            prefixes = ("_".join(parts[:i]) for i in range(len(parts)))
            return next((p for p in prefixes if p in original_feature_names), v)
        return v

    if shap_values is not None:
        # Grouping by a function applies it to the index labels; the
        # deprecated ``axis=0`` argument is simply the default.
        grouped_df = pd.DataFrame({"values": shap_values}, index=feature_names).groupby(
            by=get_prefix
        )
        importance = {k: np.mean(np.abs(g.values)) for k, g in grouped_df}
        df = pd.DataFrame.from_dict(importance, orient="index").reset_index()
        df.columns = ["Var", "Value"]
        df["color"] = ["grey" if i < 0 else "steelblue" for i in df.Value]
        df["order"] = np.abs(df.Value)

    elif coef is not None:
        df = pd.DataFrame(
            {
                "Var": coef.index,
                "Value": coef.values,
                "color": ["grey" if i < 0 else "steelblue" for i in coef.values],
                "order": np.abs(coef.values),
            }
        )
    else:
        raise ValueError("Either shap_value or coef should be provided")

    # Drop the intercept and order the bars by absolute magnitude; the
    # ordered categorical keeps that order on the axis.
    df = df.loc[df["Var"] != "const", :]
    df = df.sort_values(by="order", ascending=True)
    df["Var"] = pd.Categorical(df["Var"], categories=df["Var"].tolist(), ordered=True)

    common_theme = theme(
        text=element_text(size=24),
        panel_grid_major_y=element_line(colour="lightgrey"),
        panel_grid_minor=element_blank(),
        panel_background=element_blank(),
        axis_line_x=element_line(colour="black"),
        axis_ticks_major_y=element_blank(),
    )

    x_lab = "Feature importance"

    p = (
        ggplot(data=df, mapping=aes(x="Var", y="Value", fill="color"))
        + geom_hline(yintercept=0, color="grey")
        + geom_bar(stat="identity")
        + common_theme
        + coord_flip()
        + labs(x="", y=x_lab, title=title)
        + theme(legend_position="none")
        # The "color" column has up to two levels; map each one explicitly
        # so that negative bars stay grey instead of running out of colours.
        + scale_fill_manual(values={"steelblue": color, "grey": "grey"})
    )
    return p