In [None]:
import os
import sys

from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
import numpy as np
from sklearn.metrics import classification_report

from jax import vmap
import jax.numpy as jnp
from jax.nn import sigmoid

# Make the current working directory and its parent importable so the local
# `plt_settings` module can be found (presumably also the `data` package used
# in the next cell). os.path.abspath("") resolves to the current working
# directory, which is presumably the notebook's folder under Jupyter — TODO confirm.
CWD = os.path.abspath("")
sys.path.append(CWD)
sys.path.append(os.path.join(CWD, ".."))
# This import must stay after the sys.path edits above.
from plt_settings import plt_settings

# Figure geometry used for matplotlib's "figure.figsize" (inches): full text
# width of the target document and a golden-ratio height/width ratio.
full_width = 5.5
ratio = 1 / 1.618

In [None]:
# Load data
from data.adult import get_data

# Unpack the dataset splits. The order of these names must match the tuple
# returned by `get_data` (defined in data/adult.py, outside this notebook).
# From their use below: gender_idx indexes a column of X, while male_idx /
# female_idx (and test_male_idx / test_female_idx) select rows of X / X_test.
(
    X,
    y,
    var_names,
    gender_idx,
    male_idx,
    female_idx,
    X_test,
    y_test,
    test_male_idx,
    test_female_idx,
) = get_data()

# Counterfactual test set with the gender column flipped via 1 - x (assumes a
# 0/1 encoding of that column — TODO confirm). `.at[...].set` returns a new
# array, so X_test itself is left unchanged.
# NOTE(review): X_test_cf is not used in any of the cells shown here.
X_test_cf = X_test.at[:, gender_idx].set(1 - X_test[:, gender_idx])


def ineqconst(beta):
    """Group-fairness inequality constraints of the logistic model ``beta``.

    On the training set, with ``p = sigmoid(X @ beta)``, returns for each group
    g in (male, female)::

        100 * (mean(p) - mean(p[g])) - 1

    i.e. by how many percentage points the group's average predicted
    probability falls below the overall average, minus a 1-point tolerance.
    Presumably used in the standard form ``ineqconst(beta) <= 0`` (gap of at
    most one percentage point) — confirm against the sampler that consumes it.

    Args:
        beta: model weights; ``jnp.dot(X, beta)`` must be defined. Reads the
            module-level ``X``, ``male_idx`` and ``female_idx``.

    Returns:
        Array of shape (2,): ``[male slack, female slack]``.
    """
    # Evaluate the training-set probabilities once and index them per group.
    # The previous form recomputed sigmoid(X @ beta) for every entry (and again
    # on row subsets); when vmapped over many samples each recomputation
    # materialises a (num_samples, num_rows) array.
    probs = sigmoid(jnp.dot(X, beta))
    overall = probs.mean()
    return jnp.array(
        [
            100 * (overall - probs[male_idx].mean()) - 1,
            100 * (overall - probs[female_idx].mean()) - 1,
        ]
    )


def report_acc_overall(beta):
    """Print sklearn's classification report on the test split; return accuracy.

    Hard labels are 0 where ``X_test @ beta < 0`` and 1 otherwise.

    Args:
        beta: weights of the linear model.

    Returns:
        Fraction of test rows whose predicted label equals ``y_test``.
    """
    scores = jnp.dot(X_test, beta)
    labels = jnp.where(scores < 0.0, 0.0, 1.0).astype(jnp.int32)
    print(classification_report(y_test, labels))
    correct = labels == y_test
    return correct.mean()


def prevalence(X, beta, male_idx, female_idx):
    """Positive-prediction rates overall and for each group.

    A row counts as positive unless ``X @ beta < 0``.

    Args:
        X: feature matrix, one row per individual.
        beta: weights of the linear model.
        male_idx, female_idx: row selectors for the two groups.

    Returns:
        Array ``[overall, male, female]`` of positive-prediction rates.
    """
    positive = jnp.where(jnp.dot(X, beta) < 0.0, 0.0, 1.0)
    rates = [positive.mean()] + [positive[rows].mean() for rows in (male_idx, female_idx)]
    return jnp.array(rates)


def disparity(X, beta, male_idx, female_idx):
    """Group prevalence minus overall prevalence.

    Args:
        X, beta, male_idx, female_idx: forwarded unchanged to ``prevalence``.

    Returns:
        Array ``[male - overall, female - overall]``.
    """
    rates = prevalence(X, beta, male_idx, female_idx)
    # rates is [overall, male, female]; subtract the overall rate from both groups.
    return rates[1:] - rates[0]


def prevalence_test(beta):
    """``prevalence`` on the test split: ``[overall, male, female]`` rates."""
    return prevalence(
        X_test, beta, male_idx=test_male_idx, female_idx=test_female_idx
    )


def prevalence_train(beta):
    """``prevalence`` on the training split: ``[overall, male, female]`` rates."""
    return prevalence(
        X, beta, male_idx=male_idx, female_idx=female_idx
    )


def disparity_test(beta):
    """``disparity`` on the test split: ``[male - overall, female - overall]``."""
    return disparity(
        X_test, beta, male_idx=test_male_idx, female_idx=test_female_idx
    )


def disparity_train(beta):
    """``disparity`` on the training split: ``[male - overall, female - overall]``."""
    return disparity(
        X, beta, male_idx=male_idx, female_idx=female_idx
    )

In [None]:
# NOTE(review): ITERATIONS is not used in the cells shown here; kept in case
# later code relies on it.
ITERATIONS = int(2e4)
# Leading iterations discarded before computing sample averages.
BURN_IN = int(1e4)
# Values of N (LMC iterations, N_b) whose runs are stored in the archive.
N_VALUES = (1, 10, 100, 1000)

# Materialise the archive into a plain dict so the NpzFile (and its open file
# handle) is closed immediately; `sampling_data[key]` works exactly as before.
with np.load(os.path.join(CWD, "average_vs_N.npz")) as npz_file:
    sampling_data = dict(npz_file)

# Post-burn-in primal samples, dual-variable traces and wall-clock times per N.
pdlmc_samples = {N: sampling_data[f"pdlmc_x_{N}"][BURN_IN:, :] for N in N_VALUES}
pdlmc_lambda = {N: sampling_data[f"pdlmc_lambda_{N}"] for N in N_VALUES}
pdlmc_time = {N: sampling_data[f"pdlmc_t_{N}"] for N in N_VALUES}

# Constraint slacks along the *whole* chain (burn-in included, unlike
# pdlmc_samples): they feed the running-mean convergence plot below.
slacks = {N: vmap(ineqconst)(sampling_data[f"pdlmc_x_{N}"]) for N in N_VALUES}

# Test prevalence and test/train disparities for every post-burn-in sample.
pdlmc_prev = {}
pdlmc_disp = {}
pdlmc_disp_train = {}
for N, samples in pdlmc_samples.items():
    pdlmc_prev[N] = vmap(prevalence_test)(samples)
    pdlmc_disp[N] = vmap(disparity_test)(samples)
    pdlmc_disp_train[N] = vmap(disparity_train)(samples)

# Line style used to distinguish the values of N in plots.
linestyle = {1: "solid", 10: "dashed", 100: "dotted", 1000: "dashdot"}

In [None]:
# For each N: print the run time, evaluate the sample-mean weights on the test
# split (report_acc_overall prints sklearn's classification report and returns
# the accuracy), then print the disparities averaged over post-burn-in samples.
pdlmc_acc = {}
for N, samples in pdlmc_samples.items():
    print("")
    print(f"LMC iterations = {N} ({pdlmc_time[N]} s)")

    pdlmc_acc[N] = report_acc_overall(samples.mean(axis=0))

    print("PDLMC disparity [male, female] (train):", pdlmc_disp_train[N].mean(axis=0))
    print("PDLMC disparity [male, female] (test):", pdlmc_disp[N].mean(axis=0))

In [None]:
# Dual-variable traces for every N. Colour encodes the group: column 0 is drawn
# as Male (C0) and column 1 as Female (C3), matching the legend. Line style
# encodes N via the `linestyle` mapping.
# Use a copy of the rc settings instead of mutating the shared `plt_settings`
# dict, so this cell cannot leak its figure size into other cells.
lambda_rc = {**plt_settings, "figure.figsize": (full_width / 2, ratio * full_width / 2)}

with plt.rc_context(lambda_rc):
    _, axs = plt.subplots(1, 1, dpi=300)
    for N, lmbda in pdlmc_lambda.items():
        axs.plot(lmbda[:, 0], c="C0", linestyle=linestyle[N])
        axs.plot(lmbda[:, 1], c="C3", linestyle=linestyle[N])
    axs.grid()
    axs.set_xlabel("Iterations")
    axs.set_ylabel(r"Dual variable ($\lambda$)")

    group_handles = [
        Rectangle((0, 0), 1, 1, color="C3", alpha=1.0),
        Rectangle((0, 0), 1, 1, color="C0", alpha=1.0),
    ]
    # Data-free proxy lines so the legend also explains which line style
    # belongs to which N (otherwise the styles are unlabeled).
    n_handles = [axs.plot([], [], c="black", linestyle=linestyle[N])[0] for N in pdlmc_lambda]
    axs.legend(
        group_handles + n_handles,
        ["Female", "Male"] + [f"$N_b = {N}$" for N in pdlmc_lambda],
        handlelength=1.5,
        ncol=2,
    )

    plt.show()

In [None]:
# Running mean of the constraint slacks over the whole chain, one colour per N.
# Slack columns follow ineqconst's output order: left panel = male constraint,
# right panel = female constraint. The black line at 0 marks the boundary of
# ineqconst(beta) <= 0.
# Use a copy of the rc settings instead of mutating the shared `plt_settings`.
slack_rc = {**plt_settings, "figure.figsize": (full_width, ratio * full_width / 2)}

with plt.rc_context(slack_rc):
    _, axs = plt.subplots(1, 2, dpi=300)

    for idx, (N, slack) in enumerate(slacks.items()):
        # Cumulative mean along the iteration axis (row k averages rows 0..k).
        cum_mean = np.cumsum(slack, axis=0) / np.arange(1, slack.shape[0] + 1)[:, None]
        for col, ax in enumerate(axs):
            ax.plot(cum_mean[:, col], linestyle="--", color=f"C{idx}", label=f"$N_b = {N}$")

    for ax, group in zip(axs, ("Male", "Female")):
        ax.axhline(0.0, color="black", linewidth=0.8)
        ax.set_xlim((-1e2, 8e2))
        ax.grid()
        ax.set_title(group)
        ax.set_xlabel("Iterations")
        ax.set_ylabel(r"Constraints slacks")
    # Both panels share the same colour-per-N encoding; one legend suffices.
    axs[0].legend()

    plt.show()

In [None]:
def prob_test(beta):
    """Hard 0/1 predictions on the test split: 0.0 where ``X_test @ beta < 0``, else 1.0.

    Despite the name, this returns labels rather than probabilities; averaging
    over test rows gives the positive-prediction rate (prevalence).
    """
    is_negative = jnp.dot(X_test, beta) < 0
    return jnp.where(is_negative, 0.0, 1.0)


# Hard test-set predictions for every post-burn-in sample: one array of shape
# (num_samples, num_test_rows) per N.
pdlmc_p_d = {N: vmap(prob_test)(samples) for N, samples in pdlmc_samples.items()}

In [None]:
plt_settings.update({"figure.figsize": (full_width / 2, ratio * full_width / 2)})
# Scale of the overall-prevalence marker: it spans +/- 1.1 * width around each row.
width = 0.25


def violin_prev(ax, y, samples_f, samples_m):
    """Draw a split horizontal violin at height ``y`` on ``ax``.

    The female samples fill the ``side="high"`` half in colour C3; the male
    samples fill the ``side="low"`` half in colour C0.

    Args:
        ax: matplotlib Axes to draw on.
        y: vertical position of the violin.
        samples_f: female samples (one value per posterior sample).
        samples_m: male samples (one value per posterior sample).
    """
    halves = (
        (samples_f, "high", "C3"),
        (samples_m, "low", "C0"),
    )
    for data, half, colour in halves:
        violin = ax.violinplot(
            dataset=data,
            positions=[y],
            widths=0.8,
            showextrema=False,
            vert=False,
            side=half,
        )
        for body in violin["bodies"]:
            body.set_facecolor(colour)
            body.set_edgecolor(colour)
            body.set_alpha(0.8)


# Split violins of the per-sample test prevalence (female: upper half, male:
# lower half), one row per N at y = 0, -1, -2, ...; the black tick marks the
# prevalence averaged over all samples and all test rows.
with plt.rc_context(plt_settings):
    fig, ax = plt.subplots(1, 1, dpi=300)

    for idx, (N, samples) in enumerate(pdlmc_p_d.items()):
        violin_prev(
            ax,
            -idx,
            samples[:, test_female_idx].mean(axis=1),
            samples[:, test_male_idx].mean(axis=1),
        )
        ax.vlines(samples.flatten().mean(), -idx - 1.1 * width, -idx + 1.1 * width, color="black")

    ax.set_xlabel(r"Prevalence of $>$ \$50k (\%)")
    # Tick positions are fractions; derive the percent labels from them so the
    # two can never drift apart.
    xticks = np.arange(0, 0.4, 0.1)
    ax.set_xticks(xticks)
    ax.set_xticklabels([f"{100 * t:.0f}" for t in xticks])
    ax.set_xlim(0.01, 0.35)
    # One row per N, derived from the data instead of assuming exactly four N.
    ax.set_yticks(-np.arange(len(pdlmc_p_d)))
    ax.set_yticklabels([f"{N}\n({pdlmc_time[N]:.0f} s)" for N in pdlmc_p_d.keys()])
    ax.set_ylabel(r"LMC iterations ($N_b$)")
    ax.grid()
    ax.legend(
        [
            Rectangle((0, 0), 1, 1, color="C3", alpha=1.0),
            Rectangle((0, 0), 1, 1, color="C0", alpha=1.0),
        ],
        ["Female", "Male"],
        handlelength=0.7,
        loc="lower right",
        bbox_to_anchor=(1.01, 0.54),
    )

    fig.savefig("../../paper/figures/fairness_average_vs_N_prevalence.pdf")
    plt.show()