In [1]:
import os
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [2]:
def TPR_gap(df, occupation, causal=False):
    """Gender gap in true-positive rate for a single occupation.

    Only rows whose ``label`` equals ``occupation`` contribute.

    Args:
        df: Predictions table with the columns ``label``, ``gender``
            (values ``"F"`` / ``"M"`` are the ones used), ``title_scrubbed_pred``
            and, for the causal variant, ``title_scrubbed_gender_swapped_pred``.
        occupation: The label value to evaluate.
        causal: Selects the variant, see below.

    Returns:
        * ``causal=False``: share of ``"F"`` rows predicted as ``occupation``
          minus the same share for ``"M"`` rows.
        * ``causal=True``: mean per-row difference between the factual hit
          (``title_scrubbed_pred``) and the hit on the gender-swapped
          prediction. ``"F"`` rows contribute factual - counterfactual and
          ``"M"`` rows counterfactual - factual.

    Raises:
        ZeroDivisionError: in the non-causal variant when one gender has no
            rows for ``occupation``.
    """
    if not causal:
        has_label = df["label"] == occupation
        predicted = df.title_scrubbed_pred == occupation
        rates = []
        for gender in ("F", "M"):
            members = has_label & (df.gender == gender)
            # Plain Python ints so that an empty group raises ZeroDivisionError.
            rates.append(int((predicted & members).sum()) / int(members.sum()))
        return rates[0] - rates[1]

    rows = df.loc[df["label"] == occupation]
    women = rows.loc[rows.gender == "F"]
    men = rows.loc[rows.gender == "M"]

    women_factual = (women.title_scrubbed_pred == women["label"]).astype(int)
    women_swapped = (women.title_scrubbed_gender_swapped_pred == women["label"]).astype(int)
    men_factual = (men.title_scrubbed_pred == men["label"]).astype(int)
    men_swapped = (men.title_scrubbed_gender_swapped_pred == men["label"]).astype(int)

    # Opposite orientation for the two genders (F: factual - swapped,
    # M: swapped - factual).
    women_gap = women_factual - women_swapped
    men_gap = men_swapped - men_factual
    return (women_gap.sum() + men_gap.sum()) / (len(women) + len(men))

def get_tpr_results(preds):
    """Tabulate the statistical and causal TPR gap for every occupation.

    Args:
        preds: Predictions table accepted by ``TPR_gap``; the distinct values
            of its ``label`` column define the occupations.

    Returns:
        DataFrame with columns ``occupation``, ``statistical_parity`` and
        ``causal_parity``, sorted by ``occupation``.
    """
    occupations = preds["label"].unique().tolist()

    # All statistical gaps are computed before any causal gap.
    statistical = [TPR_gap(preds, name) for name in occupations]
    causal = [TPR_gap(preds, name, causal=True) for name in occupations]

    results = pd.DataFrame(
        {
            "occupation": occupations,
            "statistical_parity": statistical,
            "causal_parity": causal,
        }
    )
    return results.sort_values("occupation")

def get_avg_tpr(root_dir, runs=3):
    """Average the per-occupation TPR gaps over several training runs.

    Run 1 lives in ``root_dir``; run ``k`` (k >= 2) in ``f"{root_dir}-{k}"``.
    For each run, ``test_tpr.csv`` is read if present; otherwise the gaps are
    computed from ``test_preds.csv`` and written to ``test_tpr.csv``.

    Args:
        root_dir: Directory of the first run.
        runs: Number of runs to average.

    Returns:
        DataFrame with ``occupation``, ``statistical_parity`` and
        ``causal_parity``, where the two gap columns hold the sum over runs
        divided by ``runs``.
    """
    totals = None
    for run_idx in range(runs):
        run_dir = root_dir if run_idx == 0 else f"{root_dir}-{run_idx+1}"
        cache_file = os.path.join(run_dir, "test_tpr.csv")

        if os.path.exists(cache_file):
            run_gaps = pd.read_csv(cache_file)
        else:
            run_preds = pd.read_csv(os.path.join(run_dir, "test_preds.csv"))
            run_gaps = get_tpr_results(run_preds)
            run_gaps.to_csv(cache_file, index=False)

        if totals is None:
            totals = run_gaps
            continue

        # Align on occupation; an occupation absent from one side counts as 0.
        summed = totals.set_index('occupation').add(
            run_gaps.set_index('occupation'), fill_value=0
        )
        totals = summed.reset_index()

    totals.statistical_parity = totals.statistical_parity / runs
    totals.causal_parity = totals.causal_parity / runs
    return totals

In [3]:
def FPR_gap(df, occupation, causal=False):
    """Gender gap in false-positive rate for a single occupation.

    Only rows whose ``label`` differs from ``occupation`` contribute; a false
    positive is such a row being predicted as ``occupation``.

    Args:
        df: Predictions table with the columns ``label``, ``gender``
            (values ``"F"`` / ``"M"`` are the ones used), ``title_scrubbed_pred``
            and, for the causal variant, ``title_scrubbed_gender_swapped_pred``.
        occupation: The label value to evaluate.
        causal: Selects the variant, see below.

    Returns:
        * ``causal=False``: share of ``"F"`` rows falsely predicted as
          ``occupation`` minus the same share for ``"M"`` rows.
        * ``causal=True``: mean per-row difference between the factual false
          positive (``title_scrubbed_pred == occupation``) and the false
          positive on the gender-swapped prediction. ``"F"`` rows contribute
          factual - counterfactual and ``"M"`` rows counterfactual - factual,
          mirroring the female-minus-male orientation of the statistical gap.

    Raises:
        ZeroDivisionError: in the non-causal variant when one gender has no
            rows with a label other than ``occupation``.
    """
    if causal:
        others = df.loc[df["label"] != occupation]
        female = others.loc[others.gender == "F"]
        male = others.loc[others.gender == "M"]

        # Compare against `occupation`, not against each row's own label:
        # on these rows `pred == label` would measure correctness for the
        # *other* occupations rather than false positives for this one.
        # NOTE(review): get_avg_fpr reuses any existing test_fpr.csv, so files
        # written with the previous definition need to be regenerated.
        female_factual = (female.title_scrubbed_pred == occupation).astype(int)
        female_counterfactual = (female.title_scrubbed_gender_swapped_pred == occupation).astype(int)
        male_factual = (male.title_scrubbed_pred == occupation).astype(int)
        male_counterfactual = (male.title_scrubbed_gender_swapped_pred == occupation).astype(int)

        female_gap = female_factual - female_counterfactual
        male_gap = male_counterfactual - male_factual
        return (female_gap.sum() + male_gap.sum()) / (female.shape[0] + male.shape[0])
    else:
        female_fpr = df[(df.title_scrubbed_pred == occupation) & (df["label"] != occupation) & (df.gender == "F")].shape[0]
        female_fpr /= df[(df.gender == "F") & (df["label"] != occupation)].shape[0]
        male_fpr = df[(df.title_scrubbed_pred == occupation) & (df["label"] != occupation) & (df.gender == "M")].shape[0]
        male_fpr /= df[(df.gender == "M") & (df["label"] != occupation)].shape[0]
        return female_fpr - male_fpr

def get_fpr_results(preds):
    """Tabulate the statistical and causal FPR gap for every occupation.

    Args:
        preds: Predictions table accepted by ``FPR_gap``; the distinct values
            of its ``label`` column define the occupations.

    Returns:
        DataFrame with columns ``occupation``, ``statistical_parity`` and
        ``causal_parity``, sorted by ``occupation``.
    """
    occupations = preds["label"].unique().tolist()

    # All statistical gaps are computed before any causal gap.
    statistical = [FPR_gap(preds, name) for name in occupations]
    causal = [FPR_gap(preds, name, causal=True) for name in occupations]

    results = pd.DataFrame(
        {
            "occupation": occupations,
            "statistical_parity": statistical,
            "causal_parity": causal,
        }
    )
    return results.sort_values("occupation")

def get_avg_fpr(root_dir, runs=3):
    """Average the per-occupation FPR gaps over several training runs.

    Run 1 lives in ``root_dir``; run ``k`` (k >= 2) in ``f"{root_dir}-{k}"``.
    For each run, ``test_fpr.csv`` is read if present; otherwise the gaps are
    computed from ``test_preds.csv`` and written to ``test_fpr.csv``.

    Args:
        root_dir: Directory of the first run.
        runs: Number of runs to average.

    Returns:
        DataFrame with ``occupation``, ``statistical_parity`` and
        ``causal_parity``, where the two gap columns hold the sum over runs
        divided by ``runs``.
    """
    totals = None
    for run_idx in range(runs):
        run_dir = root_dir if run_idx == 0 else f"{root_dir}-{run_idx+1}"
        cache_file = os.path.join(run_dir, "test_fpr.csv")

        if os.path.exists(cache_file):
            run_gaps = pd.read_csv(cache_file)
        else:
            run_preds = pd.read_csv(os.path.join(run_dir, "test_preds.csv"))
            run_gaps = get_fpr_results(run_preds)
            run_gaps.to_csv(cache_file, index=False)

        if totals is None:
            totals = run_gaps
            continue

        # Align on occupation; an occupation absent from one side counts as 0.
        summed = totals.set_index('occupation').add(
            run_gaps.set_index('occupation'), fill_value=0
        )
        totals = summed.reset_index()

    totals.statistical_parity = totals.statistical_parity / runs
    totals.causal_parity = totals.causal_parity / runs
    return totals

In [4]:
def get_avg_acc(root_dir, runs=3):
    """Mean test accuracy over several training runs.

    Run 1 lives in ``root_dir``; run ``k`` (k >= 2) in ``f"{root_dir}-{k}"``.
    Each run's accuracy is the share of rows in its ``test_preds.csv`` whose
    ``title_scrubbed_pred`` equals ``label``.

    Args:
        root_dir: Directory of the first run.
        runs: Number of runs to average.

    Returns:
        The sum of the per-run accuracies divided by ``runs``.
    """
    def _run_accuracy(run_idx):
        run_dir = root_dir if run_idx == 0 else f"{root_dir}-{run_idx+1}"
        run_preds = pd.read_csv(os.path.join(run_dir, "test_preds.csv"))
        n_correct = int((run_preds.title_scrubbed_pred == run_preds["label"]).sum())
        return n_correct / run_preds.shape[0]

    accumulated = 0
    for run_idx in range(runs):
        accumulated += _run_accuracy(run_idx)
    return accumulated / runs

In [5]:
# Display name of each debiasing method -> directory of its first training run.
# Insertion order matters: later cells iterate this dict to build their tables.
all_paths = dict(
    [
        ("Baseline", "models/albert-large-biasbios"),
        ("Zari", "models/zari-albert-biasbios"),
        ("CDA", "models/albert-large-biasbios-cda"),
        ("Zari w/ CDA", "models/zari-albert-biasbios-cda"),
        ("Reweight", "models/albert-large-biasbios-reweight"),
        ("Oversampling", "models/albert-large-biasbios-oversampling"),
        ("Undersampling", "models/albert-large-biasbios-subsampling"),
        ("Oversampling w/ CDA", "models/albert-large-biasbios-oversampling-cda"),
        ("Undersampling w/ CDA", "models/albert-large-biasbios-subsampling-cda"),
        ("Reweight w/ CDA", "models/albert-large-biasbios-reweight-cda"),
    ]
)

In [10]:
# One run-averaged gap table per debiasing method, tagged with the method name
# and stacked into a single long frame (each method's own row index is kept).
tpr_parity_df = pd.concat(
    [get_avg_tpr(path).assign(Method=method) for method, path in all_paths.items()]
)

# Same construction for the false-positive-rate gaps.
fpr_parity_df = pd.concat(
    [get_avg_fpr(path).assign(Method=method) for method, path in all_paths.items()]
)

In [8]:
# Family of each debiasing method. Series.map returns NaN for any method that
# is missing here, so every key of all_paths needs an entry; the three
# sampling/reweighting "w/ CDA" variants are labelled "Combination", the same
# group heading the figures use for them.
debias_type_map = {
    "Baseline": "Baseline",
    "CDA": "Causal", "Zari": "Causal", "Zari w/ CDA": "Causal",
    "Oversampling": "Statistical", "Undersampling": "Statistical", "Reweight": "Statistical",
    "Oversampling w/ CDA": "Combination",
    "Undersampling w/ CDA": "Combination",
    "Reweight w/ CDA": "Combination",
}
tpr_parity_df["Debias Type"] = tpr_parity_df["Method"].map(debias_type_map)
fpr_parity_df["Debias Type"] = fpr_parity_df["Method"].map(debias_type_map)

In [13]:
# Run-averaged test accuracy of every method, in all_paths order.
methods = list(all_paths.keys())
all_acc = [get_avg_acc(path) for path in all_paths.values()]

acc_df = pd.DataFrame({"Method": methods, "Acc": all_acc})
acc_df

Unnamed: 0,Method,Acc
0,Baseline,0.954928
1,Zari,0.952318
2,CDA,0.954684
3,Zari w/ CDA,0.951969
4,Reweight,0.952552
5,Oversampling,0.95502
6,Undersampling,0.947928
7,Oversampling w/ CDA,0.953938
8,Undersampling w/ CDA,0.947267
9,Reweight w/ CDA,0.954342


## TPR Gap

In [14]:
# Figure: per-occupation TPR gaps for every debiasing method.
# Two panels side by side sharing one y-axis (the method):
# column 1 = statistical gap, column 2 = causal gap.
fig = make_subplots(
    rows=1, cols=2, shared_yaxes=True, horizontal_spacing=0.03
)

# Build each box plot with plotly express, then move its traces into the grid.
for t in px.box(tpr_parity_df, x="statistical_parity", y="Method", color_discrete_sequence=px.colors.qualitative.D3, hover_data=["occupation"]).data:
    fig.add_trace(t, row=1, col=1)
for t in px.box(tpr_parity_df, x="causal_parity", y="Method", color_discrete_sequence=px.colors.qualitative.D3, hover_data=["occupation"]).data:
    # Legend entries are switched off for the second panel's traces.
    t["showlegend"] = False
    fig.add_trace(t, row=1, col=2)

fig.update_layout(
    autosize=False,
    width=1000,
    height=535,
    plot_bgcolor='white',
    font=dict(size=14),
    legend_title="Debias Type",
    margin=dict(l=15, r=15, t=20, b=20),
    boxgap=0.5,
    # Custom y tick labels: ticktext[i] is shown for the category tickvals[i].
    # Each label is HTML with the display name and the method's accuracy from
    # acc_df as a percentage; the first method of a group also carries the bold
    # group heading. `.item()` requires exactly one acc_df row per method.
    yaxis = dict(
        tickmode = 'array',
        tickvals = ["Baseline", "Oversampling", "Undersampling", "Reweight", "CDA", "Zari", "Zari w/ CDA", "Oversampling w/ CDA", "Undersampling w/ CDA", "Reweight w/ CDA"],
        ticktext = [
            f"<b>Normal</b><br>Acc={acc_df[acc_df.Method == 'Baseline'].Acc.item()*100: .2f}%<br>", 
            f"<b>Statistical</b><br>Oversampling<br>Acc={acc_df[acc_df.Method == 'Oversampling'].Acc.item()*100: .2f}%", 
            f"Undersampling<br>Acc={acc_df[acc_df.Method == 'Undersampling'].Acc.item()*100: .2f}%",
            f"Reweighting<br>Acc={acc_df[acc_df.Method == 'Reweight'].Acc.item()*100: .2f}%<br>",
            f"<b>Causal</b><br>CDA<br>Acc={acc_df[acc_df.Method == 'CDA'].Acc.item()*100: .2f}%<br>", 
            f"Zari<br>Acc={acc_df[acc_df.Method == 'Zari'].Acc.item()*100: .2f}%<br>",
            f"Zari w/ CDA<br>Acc={acc_df[acc_df.Method == 'Zari w/ CDA'].Acc.item()*100: .2f}%<br>",
            f"<b>Combination</b><br>OS-CDA<br>Acc={acc_df[acc_df.Method == 'Oversampling w/ CDA'].Acc.item()*100: .2f}%",
            f"US-CDA<br>Acc={acc_df[acc_df.Method == 'Undersampling w/ CDA'].Acc.item()*100: .2f}%",
            f"RW-CDA<br>Acc={acc_df[acc_df.Method == 'Reweight w/ CDA'].Acc.item()*100: .2f}%",
        ]
    ),
    # Fixed range and ticks for the causal panel's x-axis.
    xaxis2 = dict(range=[-0.052, 0.105], tickvals=[-0.05, 0, 0.05, 0.1])
)
fig.update_xaxes(title_text='Statistical TPR Gap',row=1, col=1)
fig.update_xaxes(title_text='Causal TPR Gap',row=1, col=2)
# Show every occupation as an individual point next to its box.
fig.update_traces(boxpoints="all", jitter=0.3, pointpos=-2)
fig.update_xaxes(
    mirror=True,
    showgrid=True,
    gridcolor='darkgrey',
    zeroline = True,
    zerolinecolor='darkgrey',
)
# Explicit category order for the method axis; "Baseline" is listed last.
# NOTE(review): the circle positions below start at the top for "Baseline",
# so this order presumably puts Baseline in the top row - confirm on the plot.
fig.update_yaxes(
    mirror=True,
    showgrid=True,
    gridcolor='lightgrey',
    categoryorder='array',
    categoryarray= ["Reweight w/ CDA", "Undersampling w/ CDA", "Oversampling w/ CDA", "Zari w/ CDA", "Zari", "CDA", "Reweight", "Undersampling", "Oversampling", 'Baseline']
)

# Left panel: label "paralegal" at the x position of its Baseline statistical gap.
baseline_x = tpr_parity_df[(tpr_parity_df.occupation == "paralegal")&(tpr_parity_df.Method == "Baseline")].statistical_parity.item()
fig.add_annotation(text="paralegal", x=baseline_x, yref="paper", y=0.97, showarrow=False)

# Red circle on "paralegal" in every method's row. x is in data units of the
# left x-axis; y is in paper coordinates and steps down by `offset` per method.
# NOTE(review): 0.91 / 0.92 / 0.0993 and the x half-width look hand-tuned to
# this figure size and method list - re-check them if either changes.
offset = 0.0993
methods = ["Baseline", "Oversampling", "Undersampling", "Reweight", "CDA", "Zari", "Zari w/ CDA", "Oversampling w/ CDA", "Undersampling w/ CDA", "Reweight w/ CDA"]
for i, method in enumerate(methods):
    x_pos = tpr_parity_df[(tpr_parity_df.occupation == "paralegal")&(tpr_parity_df.Method == method)].statistical_parity.item()
    fig.add_shape(type="circle", xref="x1", x0=x_pos-0.0023, x1=x_pos+0.0023, yref="paper", y0=0.91-(offset*i), y1=0.92-(offset*i), line_color="red", fillcolor="red")

# Right panel: the same treatment for "interior_designer" on the causal gap
# (label nudged left by 0.005, narrower circle half-width).
baseline_x = tpr_parity_df[(tpr_parity_df.occupation == "interior_designer")&(tpr_parity_df.Method == "Baseline")].causal_parity.item()
fig.add_annotation(text="interior_designer", xref="x2", x=baseline_x-0.005, yref="paper", y=0.97, showarrow=False)
offset = 0.0993
for i, method in enumerate(methods):
    x_pos = tpr_parity_df[(tpr_parity_df.occupation == "interior_designer")&(tpr_parity_df.Method == method)].causal_parity.item()
    fig.add_shape(type="circle", xref="x2", x0=x_pos-0.0008, x1=x_pos+0.0008, yref="paper", y0=0.91-(offset*i), y1=0.92-(offset*i), line_color="red", fillcolor="red")
fig.show()

## FPR Gap

In [17]:
# Figure: per-occupation FPR gaps for every debiasing method.
# Two panels side by side sharing one y-axis (the method):
# column 1 = statistical gap, column 2 = causal gap.
fig = make_subplots(
    rows=1, cols=2, shared_yaxes=True, horizontal_spacing=0.03
)

# Build each box plot with plotly express, then move its traces into the grid.
for t in px.box(fpr_parity_df, x="statistical_parity", y="Method", color_discrete_sequence=px.colors.qualitative.D3, hover_data=["occupation"]).data:
    fig.add_trace(t, row=1, col=1)
for t in px.box(fpr_parity_df, x="causal_parity", y="Method", color_discrete_sequence=px.colors.qualitative.D3, hover_data=["occupation"]).data:
    # Legend entries are switched off for the second panel's traces.
    t["showlegend"] = False
    fig.add_trace(t, row=1, col=2)

fig.update_layout(
    autosize=False,
    width=1000,
    height=535,
    plot_bgcolor='white',
    font=dict(size=14),
    legend_title="Debias Type",
    margin=dict(l=15, r=15, t=20, b=20),
    boxgap=0.5,
    # Custom y tick labels: ticktext[i] is shown for the category tickvals[i].
    # Each label is HTML with the display name and the method's accuracy from
    # acc_df as a percentage; the first method of a group also carries the bold
    # group heading. `.item()` requires exactly one acc_df row per method.
    yaxis = dict(
        tickmode = 'array',
        tickvals = ["Baseline", "Oversampling", "Undersampling", "Reweight", "CDA", "Zari", "Zari w/ CDA", "Oversampling w/ CDA", "Undersampling w/ CDA", "Reweight w/ CDA"],
        ticktext = [
            f"<b>Normal</b><br>Acc={acc_df[acc_df.Method == 'Baseline'].Acc.item()*100: .2f}%<br>", 
            f"<b>Statistical</b><br>Oversampling<br>Acc={acc_df[acc_df.Method == 'Oversampling'].Acc.item()*100: .2f}%", 
            f"Undersampling<br>Acc={acc_df[acc_df.Method == 'Undersampling'].Acc.item()*100: .2f}%",
            f"Reweighting<br>Acc={acc_df[acc_df.Method == 'Reweight'].Acc.item()*100: .2f}%<br>",
            f"<b>Causal</b><br>CDA<br>Acc={acc_df[acc_df.Method == 'CDA'].Acc.item()*100: .2f}%<br>", 
            f"Zari<br>Acc={acc_df[acc_df.Method == 'Zari'].Acc.item()*100: .2f}%<br>",
            f"Zari w/ CDA<br>Acc={acc_df[acc_df.Method == 'Zari w/ CDA'].Acc.item()*100: .2f}%<br>",
            f"<b>Combination</b><br>OS-CDA<br>Acc={acc_df[acc_df.Method == 'Oversampling w/ CDA'].Acc.item()*100: .2f}%",
            f"US-CDA<br>Acc={acc_df[acc_df.Method == 'Undersampling w/ CDA'].Acc.item()*100: .2f}%",
            f"RW-CDA<br>Acc={acc_df[acc_df.Method == 'Reweight w/ CDA'].Acc.item()*100: .2f}%",
        ]
    ),
    # Disabled: a fixed range/ticks for the causal panel's x-axis. While it is
    # commented out, this cell sets no explicit x range for that panel.
    # xaxis2 = dict(range=[-0.052, 0.105], tickvals=[-0.05, 0, 0.05, 0.1])
)
fig.update_xaxes(title_text='Statistical FPR Gap',row=1, col=1)
fig.update_xaxes(title_text='Causal FPR Gap',row=1, col=2)
# Show every occupation as an individual point next to its box.
fig.update_traces(boxpoints="all", jitter=0.3, pointpos=-2)
fig.update_xaxes(
    mirror=True,
    showgrid=True,
    gridcolor='darkgrey',
    zeroline = True,
    zerolinecolor='darkgrey',
)
# Explicit category order for the method axis; "Baseline" is listed last.
# NOTE(review): the circle positions below start at the top for "Baseline",
# so this order presumably puts Baseline in the top row - confirm on the plot.
fig.update_yaxes(
    mirror=True,
    showgrid=True,
    gridcolor='lightgrey',
    categoryorder='array',
    categoryarray= ["Reweight w/ CDA", "Undersampling w/ CDA", "Oversampling w/ CDA", "Zari w/ CDA", "Zari", "CDA", "Reweight", "Undersampling", "Oversampling", 'Baseline']
)

# Left panel: label "physician" near the x position of its Baseline statistical gap.
baseline_x = fpr_parity_df[(fpr_parity_df.occupation == "physician")&(fpr_parity_df.Method == "Baseline")].statistical_parity.item()
fig.add_annotation(text="physician", x=baseline_x+0.0001, yref="paper", y=0.96, showarrow=False)

# Red circle on "physician" in every method's row. x is in data units of the
# left x-axis; y is in paper coordinates and steps down by `offset` per method.
# NOTE(review): 0.91 / 0.921 / 0.0993 and the x half-width look hand-tuned to
# this figure size and method list - re-check them if either changes.
offset = 0.0993
methods = ["Baseline", "Oversampling", "Undersampling", "Reweight", "CDA", "Zari", "Zari w/ CDA", "Oversampling w/ CDA", "Undersampling w/ CDA", "Reweight w/ CDA"]
for i, method in enumerate(methods):
    x_pos = fpr_parity_df[(fpr_parity_df.occupation == "physician")&(fpr_parity_df.Method == method)].statistical_parity.item()
    fig.add_shape(type="circle", xref="x1", x0=x_pos-0.000045, x1=x_pos+0.000045, yref="paper", y0=0.91-(offset*i), y1=0.921-(offset*i), line_color="red", fillcolor="red")

# Right panel: the same treatment for "architect" on the causal gap
# (narrower circle half-width).
baseline_x = fpr_parity_df[(fpr_parity_df.occupation == "architect")&(fpr_parity_df.Method == "Baseline")].causal_parity.item()
fig.add_annotation(text="architect", xref="x2", x=baseline_x, yref="paper", y=0.96, showarrow=False)
offset = 0.0993
for i, method in enumerate(methods):
    x_pos = fpr_parity_df[(fpr_parity_df.occupation == "architect")&(fpr_parity_df.Method == method)].causal_parity.item()
    fig.add_shape(type="circle", xref="x2", x0=x_pos-0.00002, x1=x_pos+0.00002, yref="paper", y0=0.91-(offset*i), y1=0.921-(offset*i), line_color="red", fillcolor="red")
fig.show()

## Baseline Model (Statistical vs Causal)

In [19]:
def load_preds(root_dir, exclude_idxs=None):
    """Load a run's test predictions, optionally dropping some examples.

    Args:
        root_dir: Directory containing ``test_preds.csv``.
        exclude_idxs: Optional iterable of values of the ``idx`` column whose
            rows should be removed. ``None`` or an empty collection keeps
            every row.

    Returns:
        DataFrame read from ``test_preds.csv`` without the excluded rows.
    """
    preds = pd.read_csv(os.path.join(root_dir, "test_preds.csv"))
    if exclude_idxs is None:
        return preds

    # Materialise once: avoids the ambiguous truth value of numpy arrays and
    # pandas objects, and accepts sets and other iterables.
    excluded = list(exclude_idxs)
    if excluded:
        # `~` drops the listed rows; without the negation only the rows that
        # were meant to be excluded would be kept.
        preds = preds[~preds.idx.isin(excluded)]
    return preds

def get_results(preds, performance_metric="TPR"):
    """Tabulate the statistical and causal gap of one metric per occupation.

    Args:
        preds: Predictions table accepted by ``TPR_gap`` / ``FPR_gap``; the
            distinct values of its ``label`` column define the occupations.
        performance_metric: ``"TPR"`` uses ``TPR_gap``; any other value uses
            ``FPR_gap``.

    Returns:
        DataFrame with columns ``occupation``, ``statistical_gap`` and
        ``causal_gap``, sorted by ``occupation``.
    """
    gap_fn = TPR_gap if performance_metric == "TPR" else FPR_gap
    occupations = preds["label"].unique().tolist()

    # All statistical gaps are computed before any causal gap.
    statistical = [gap_fn(preds, name) for name in occupations]
    causal = [gap_fn(preds, name, causal=True) for name in occupations]

    results = pd.DataFrame(
        {
            "occupation": occupations,
            "statistical_gap": statistical,
            "causal_gap": causal,
        }
    )
    return results.sort_values("occupation")

In [19]:
# Baseline model: statistical vs. causal TPR gap, one point per occupation.
preds = load_preds(all_paths["Baseline"])
tpr_results = get_results(preds)

# get_results() names its columns "statistical_gap" / "causal_gap"; referring
# to "statistical_parity" / "causal_parity" here fails on a fresh kernel.
fig = px.scatter(tpr_results, x="statistical_gap", y="causal_gap", hover_data=["occupation"])
fig.update_layout(
    autosize=False,
    width=525,
    height=375,
    plot_bgcolor='white',
    font=dict(size=15),
    xaxis_title=r"$\textrm{Statistical TPR Gap } (\mathcal{SG}^{\mathsf{TPR}})$",
    yaxis_title=r"$\textrm{Causal TPR Gap } (\mathcal{CG}^{\mathsf{TPR}})$",
    margin=dict(l=15, r=15, t=20, b=20),
)
fig.update_xaxes(
    mirror=True,
    showgrid=True,
    zeroline=True,
    zerolinecolor='lightgrey',
    gridcolor='lightgrey',
)
fig.update_yaxes(
    mirror=True,
    showgrid=True,
    zeroline=True,
    zerolinecolor='lightgrey',
    gridcolor='lightgrey',
)

# Shade the two quadrants in which the statistical and causal gaps have
# opposite signs (x > 0, y < 0 and x < 0, y > 0).
fig.add_shape(type="rect", x0=0, y0=0, x1=0.3, y1=-0.05, fillcolor="grey", opacity=0.2)
fig.add_shape(type="rect", x0=0, y0=0, x1=-0.15, y1=0.15, fillcolor="grey", opacity=0.2)
fig.update_traces(marker={'size': 8.5, "line": {"width": 1.5, "color": "DarkSlateGrey"}})

# Occupations that get a text label next to their point.
sampled_occupations = [
    "pastor", "poet", "paralegal", "interior_designer",
    "model", "dj", "rapper", "comedian", "dietitian",
    "personal_trainer", "yoga_teacher"]

# Pixel offsets (xshift, yshift) for labels that need a custom position;
# every other label is centred 15 px above its point.
label_shifts = {
    "poet": (23, -3),
    "pastor": (25, -8),
    "comedian": (-5, -12),
    "personal_trainer": (48, -8),
}
default_shift = (0, 15)

for occupation in sampled_occupations:
    xshift, yshift = label_shifts.get(occupation, default_shift)
    # `.item()` requires exactly one row per occupation.
    occupation_row = tpr_results[tpr_results.occupation == occupation]
    fig.add_annotation(
        x=occupation_row.statistical_gap.item(),
        y=occupation_row.causal_gap.item(),
        text=occupation,
        showarrow=False,
        yshift=yshift,
        xshift=xshift,
    )

# Dashed red reference line y = x (causal gap equal to statistical gap).
fig.add_trace(go.Scatter(x=[-0.05, 0.15], y=[-0.05, 0.15], mode="lines", line=dict(dash="dash", color="red"), showlegend=False))
fig.show()