In [None]:
import matplotlib.pyplot as plt
import mpl_lego as mplego
import numpy as np
import pandas as pd

from hate_target import keys, datasets, utils
from mpl_lego.labels import bold_text, apply_subplot_labels
from pyprojroot import here
from scipy.stats import iqr

%matplotlib inline

In [None]:
# Apply mpl_lego's LaTeX plotting style before any figure is created.
# NOTE(review): presumably this turns on LaTeX text rendering, which the
# bold_text(...) labels used below would rely on -- confirm in mpl_lego.
mplego.style.use_latex_style()

In [None]:
# Summarise the race sub-group experiment.
# Eight thresholds are supplied, the same count as the entries in
# `race_labels` below; presumably one per sub-group in that order --
# TODO confirm against utils.analyze_experiment.
analysis_race = utils.analyze_experiment(
    here('experiments/race_exp01e.pkl'),
    thresholds=[0.5, 0.5, 0.5, 0.5, 0.5, 0.25, 0.25, 0.5],
    soft=True,
    verbose=True,
)

incidence_race = analysis_race['incidence_rate']
# Sub-group indices ordered from highest to lowest incidence rate.
sorted_race = np.flip(np.argsort(incidence_race))
# Number of sub-groups, read from the second axis of the ROC AUC array.
n_race_groups = analysis_race['roc_aucs'].shape[1]

# Display names for the tick labels, passed through bold_text.
race_labels = bold_text([
    'Asian',
    'Black',
    'Latinx',
    'Middle Eastern',
    'Native American',
    'Other',
    'Pacific Islander',
    'White',
])

In [None]:
# Summarise the gender sub-group experiment.
# Four thresholds are supplied, the same count as the entries in
# `gender_labels` below; presumably one per sub-group in that order --
# TODO confirm against utils.analyze_experiment.
analysis_gender = utils.analyze_experiment(
    here('experiments/subgroups/gender_exp01e.pkl'),
    thresholds=[0.5, 0.25, 0.5, 0.5],
    soft=True,
    verbose=True,
)

incidence_gender = analysis_gender['incidence_rate']
# Sub-group indices ordered from highest to lowest incidence rate.
sorted_gender = np.flip(np.argsort(incidence_gender))
# Number of sub-groups, read from the second axis of the ROC AUC array.
n_gender_groups = analysis_gender['roc_aucs'].shape[1]

# Display names for the tick labels, passed through bold_text.
gender_labels = bold_text([
    'Men',
    'Non-Binary',
    'Transgender',
    'Women',
])

In [None]:
def _grouped_bars(ax, analysis, order, series, width, capsize):
    """Draw one bar group per sub-group, one bar per metric in `series`.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    analysis : mapping of metric name -> array. Each array is reduced over
        axis 0 (mean for bar height, std for the error bar) and then indexed
        by `order`. Axis 0 presumably indexes repetitions/folds -- TODO confirm.
    order : index array giving the left-to-right order of the sub-groups.
    series : iterable of (metric_key, legend_label, color, x_offset).
    width : width of each bar.
    capsize : cap size of the error bars.
    """
    positions = np.arange(len(order))
    for key, label, color, offset in series:
        ax.bar(
            x=positions + offset,
            height=analysis[key].mean(axis=0)[order],
            width=width,
            yerr=np.std(analysis[key], axis=0)[order],
            color=color,
            edgecolor='black',
            error_kw={'capsize': capsize},
            label=label)


def _plot_threshold_metrics(ax, analysis, order, width=0.30):
    """Precision / recall / F1 bars, three side-by-side bars per sub-group."""
    _grouped_bars(
        ax, analysis, order,
        series=[
            ('precision', 'Precision', 'C0', -width),
            ('recall', 'Recall', 'C1', 0),
            ('f1_scores', 'F1 Score', 'C2', width),
        ],
        width=width,
        capsize=2)


def _plot_auc_metrics(ax, analysis, order, width=0.40):
    """ROC AUC / PR AUC bars plus an incidence-rate marker per sub-group."""
    _grouped_bars(
        ax, analysis, order,
        series=[
            ('roc_aucs', 'ROC AUC', 'C4', -width / 2),
            ('pr_aucs', 'PR AUC', 'lightgrey', width / 2),
        ],
        width=width,
        capsize=3)
    # Horizontal black segment at the incidence rate. It spans
    # [idx, idx + width], which is exactly the x-extent of the PR AUC bar
    # (centred at idx + width / 2) -- presumably the chance-level reference
    # for PR AUC.
    for idx, rate in enumerate(analysis['incidence_rate'][order]):
        ax.plot([idx + width, idx], [rate, rate], color='black', lw=2.5)


def _label_groups(ax, labels, order, xlabel):
    """Put one rotated tick label per sub-group on the x axis and title it."""
    ax.set_xticks(np.arange(len(order)))
    ax.set_xticklabels(
        bold_text(np.array(labels)[order]),
        ha='right',
        rotation=20,
        fontsize=13)
    ax.set_xlabel(bold_text(xlabel), fontsize=16)


def _draw_row(axes_row, analysis, order, labels, xlabel):
    """Fill one figure row: threshold metrics on the left, AUCs on the right."""
    ax_scores, ax_aucs = axes_row
    _plot_threshold_metrics(ax_scores, analysis, order)
    _plot_auc_metrics(ax_aucs, analysis, order)
    for ax in (ax_scores, ax_aucs):
        _label_groups(ax, labels, order, xlabel)
    # Legends sit centred just above each panel; one column per bar series.
    ax_scores.legend(bbox_to_anchor=(0.5, 1.08), loc='center', ncol=3, prop={'size': 12})
    ax_aucs.legend(bbox_to_anchor=(0.5, 1.08), loc='center', ncol=2, prop={'size': 12})
    # Only the left-hand panel of each row carries the y-axis label.
    ax_scores.set_ylabel(bold_text('Metric'), fontsize=15)


fig, axes = plt.subplots(2, 2, figsize=(14, 7))
fig.subplots_adjust(wspace=0.15, hspace=0.8)

# Figure 3a, 3b: race sub-groups.
_draw_row(axes[0], analysis_race, sorted_race, race_labels, 'Race Sub-Groups')
# Figure 3c, 3d: gender sub-groups.
_draw_row(axes[1], analysis_gender, sorted_gender, gender_labels, 'Gender Sub-Groups')

for ax in axes.ravel():
    ax.set_yticks([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
    # Slight headroom above 1.0.
    ax.set_ylim([0, 1.03])
    # Pass visible=True explicitly: ax.grid(axis='y') on its own *toggles*
    # the grid, so calling it more than once per axes can switch it back off.
    ax.grid(True, axis='y')
    ax.set_axisbelow(True)
    # NOTE(review): this runs after set_xticklabels(fontsize=13), so the
    # labelsize given here likely supersedes that fontsize for the x tick
    # labels -- ordering kept as in the original; confirm which is intended.
    ax.tick_params(labelsize=14)

apply_subplot_labels(axes, bold=True, x=-0.04, y=1.09, size=20)

fig.savefig('figure3.pdf', bbox_inches='tight')