In [None]:
import matplotlib.pyplot as plt
import mpl_lego as mplego
import numpy as np

from hate_target import keys, utils
from mpl_lego.labels import bold_text, apply_subplot_labels
from pyprojroot import here

%matplotlib inline

In [None]:
# Apply mpl_lego's LaTeX plotting style before any figure is created.
# NOTE(review): presumably this updates global matplotlib rcParams (and may
# require a local LaTeX installation) — confirm in mpl_lego.style.
mplego.style.use_latex_style()

In [None]:
# Location of the saved Figure 1 experiment results, resolved from the project root.
results_path = here('experiments') / 'figure1_results.pkl'

In [None]:
# Summarize the stored experiment into per-group evaluation metrics.
# NOTE(review): `soft=True` presumably selects soft-label evaluation — confirm in
# hate_target.utils. The results file is a .pkl; only load files you trust.
analysis = utils.analyze_experiment(
    results_path,
    verbose=True,
    soft=True,
)

In [None]:
# Per-group metadata shared by both Figure 1 panels.
# Number of identity groups = number of metric columns (axis 1 indexes groups;
# axis 0 presumably indexes repeated runs/folds — TODO confirm).
n_groups = analysis['roc_aucs'].shape[1]
# Positive-class (incidence) rate for each identity group. Coerced to an ndarray
# so that `incidence_rates[sorted_idx]` below is positional fancy indexing even
# if the analysis returns a list or a pandas Series.
incidence_rates = np.asarray(analysis['incidence_rate'])
# Group indices ordered from most to least prevalent (descending incidence rate).
sorted_idx = np.flip(np.argsort(incidence_rates))
# Bold-faced group names in alphabetical order (NOT yet reordered by sorted_idx).
# NOTE(review): assumes the metric columns in `analysis` follow
# sorted(keys.target_labels) order — confirm in hate_target.utils.
labels = bold_text(sorted(keys.target_labels))

# Fail fast if the groups disagree in count; a mismatch would silently
# mislabel or drop bars in the figure.
assert len(incidence_rates) == n_groups, (
    f'{len(incidence_rates)} incidence rates for {n_groups} groups')
assert len(labels) == n_groups, f'{len(labels)} labels for {n_groups} groups'

In [None]:
def _metric_mean_std(metric_key):
    """Return ``(mean, std)`` of ``analysis[metric_key]`` over axis 0.

    Both results are reordered by ``sorted_idx`` (descending incidence rate).
    The std is the population standard deviation (``np.std`` with ``ddof=0``).
    """
    metric_values = analysis[metric_key]
    mean = metric_values.mean(axis=0)[sorted_idx]
    std = np.std(metric_values, axis=0)[sorted_idx]
    return mean, std


# Summary statistics for every Figure 1 metric, in plotting order.
precision_mean, precision_std = _metric_mean_std('precision')
recall_mean, recall_std = _metric_mean_std('recall')
f1_mean, f1_std = _metric_mean_std('f1_scores')
roc_auc_mean, roc_auc_std = _metric_mean_std('roc_aucs')
pr_auc_mean, pr_auc_std = _metric_mean_std('pr_aucs')

In [None]:
"""
Figure 1:
Transformer models are predictive of target identity groups
"""
fig, axes = plt.subplots(1, 2, figsize=(14, 4))
fig.subplots_adjust(wspace=0.1)

# Bar fill colors (matplotlib default color cycle plus a neutral grey).
precision_color, recall_color, f1_color = 'C0', 'C1', 'C2'
roc_auc_color, pr_auc_color = 'C4', 'lightgrey'

# Bar geometry: one slot per identity group. Panel (a) packs three bars per
# slot and panel (b) packs two, hence the different widths and cap sizes.
center_points = np.arange(n_groups)
bar_edge_color = 'black'
fig_1a_bar_width, fig1a_error_capsize = 0.28, 2
fig_1b_bar_width, fig1b_error_capsize = 0.40, 3

# Font sizes (points).
legend_size, ax_label_size = 13, 17
ax_tick_label_size, subplot_label_size = 14, 18

"""
Figure 1a:
Precision, recall, and F1 score across identity groups
"""
# Three side-by-side bars per group: precision (left), recall (center) and
# F1 score (right), each with a +/- 1 std error bar.
fig_1a_series = (
    (-fig_1a_bar_width, precision_mean, precision_std, precision_color, 'Precision'),
    (0, recall_mean, recall_std, recall_color, 'Recall'),
    (fig_1a_bar_width, f1_mean, f1_std, f1_color, 'F1 Score'),
)
for bar_offset, bar_heights, bar_errors, bar_color, bar_label in fig_1a_series:
    axes[0].bar(
        x=center_points + bar_offset,
        height=bar_heights,
        width=fig_1a_bar_width,
        yerr=bar_errors,
        color=bar_color,
        edgecolor=bar_edge_color,
        error_kw={'capsize': fig1a_error_capsize},
        label=bar_label)

# Axis range, label, horizontal grid drawn behind the bars, and a legend
# centered just above the panel.
axes[0].set_ylim(0, 1.03)
axes[0].set_ylabel(bold_text('Metric'), fontsize=ax_label_size)
axes[0].grid(axis='y')
axes[0].set_axisbelow(True)
axes[0].legend(
    loc='center',
    bbox_to_anchor=(0.5, 1.06),
    ncol=3,
    prop={'size': legend_size})

"""
Figure 1b:
ROC / PR AUC scores across identity groups
"""
# Two side-by-side bars per group: ROC AUC in the left half of the slot and
# PR AUC in the right half, each with a +/- 1 std error bar.
# ROC AUC bar plot
axes[1].bar(
    x=center_points - fig_1b_bar_width / 2,
    height=roc_auc_mean,
    width=fig_1b_bar_width,
    yerr=roc_auc_std,
    color=roc_auc_color,
    edgecolor=bar_edge_color,
    error_kw={'capsize': fig1b_error_capsize},
    label='ROC AUC')
# PR AUC bar plot
axes[1].bar(
    x=center_points + fig_1b_bar_width / 2,
    height=pr_auc_mean,
    width=fig_1b_bar_width,
    yerr=pr_auc_std,
    color=pr_auc_color,
    edgecolor=bar_edge_color,
    error_kw={'capsize': fig1b_error_capsize},
    label='PR AUC')

# Overlay each group's incidence rate as a horizontal black segment spanning
# exactly its PR AUC bar. That bar is centered at idx + w/2 with width w, so it
# covers [idx, idx + w]. The positive-class rate is the expected PR AUC of a
# no-skill classifier, so the segment acts as a chance-level baseline.
for idx, rate in enumerate(incidence_rates[sorted_idx]):
    axes[1].plot(
        [idx, idx + fig_1b_bar_width],
        [rate, rate],
        color='black',
        lw=2.5)

axes[1].grid(axis='y')
axes[1].set_axisbelow(True)
axes[1].set_ylim([0, 1.03])

# Shared y-ticks for both panels (all metrics and AUCs lie in [0, 1]).
y_ticks = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
# Group names reordered to match the bar order (descending incidence rate).
# Reuses the alphabetical `labels` built alongside `sorted_idx`.
x_tick_labels = np.array(labels)[sorted_idx]

# Set axes ticks and labels
for ax in axes:
    ax.set_yticks(y_ticks)
    ax.set_yticklabels(y_ticks)
    ax.set_xticks(center_points)
    ax.set_xticklabels(
        x_tick_labels,
        ha='right',
        rotation=20)
    ax.tick_params(labelsize=ax_tick_label_size)

# Create legend
axes[1].legend(
    bbox_to_anchor=(0.5, 1.06),
    loc='center',
    ncol=2,
    prop={'size': legend_size})

# Apply subplot labels
apply_subplot_labels(
    axes,
    bold=True,
    x=-0.04,
    y=1.07,
    size=subplot_label_size)

# Set to True to also write the figure to disk. Save via `fig` and *before*
# plt.show(): the inline backend closes figures on show, so a plt.savefig
# issued afterwards would write a new, empty canvas.
save_figure = False
if save_figure:
    fig.savefig('figure1.pdf', bbox_inches='tight')
plt.show()