In [None]:
import datasets 
import matplotlib.pyplot as plt
import mpl_lego as mplego
import numpy as np

from mpl_lego.colorbar import append_colorbar_to_axis
from mpl_lego.labels import bold_text
from hate_measure import keys

%matplotlib inline

In [None]:
# Apply mpl_lego's LaTeX plotting style before any figure is created.
# NOTE(review): presumably this turns on LaTeX text rendering, which the
# `bold_text(...)` labels and the '$>$100k' label used later appear to rely
# on -- confirm against the mpl_lego documentation.
mplego.style.use_latex_style()

In [None]:
# Load the 'binary' configuration of the measuring-hate-speech dataset.
dataset = datasets.load_dataset('ucberkeley-dlab/measuring-hate-speech', 'binary')

# Reduce the training split to one row per annotator: rows are ordered by
# `annotator_id` and only the first row for each id is kept.
df = (
    dataset['train']
    .to_pandas()
    .sort_values('annotator_id')
    .drop_duplicates('annotator_id')
)

# Number of distinct annotators (one row each after de-duplication).
n_annotators = df.shape[0]

In [None]:
# Coarse education indicators: each new boolean column is True for an
# annotator when any column in the given slice of
# `keys.annotator_education_cols` is truthy for that row.
# NOTE(review): which education levels fall into each slice depends on the
# order of `keys.annotator_education_cols`, which is defined outside this
# notebook -- confirm.
education_bins = {
    'annotator_education_high_school': slice(None, 2),
    'annotator_education_college': slice(2, 4),
    'annotator_education_graduate_school': slice(4, None),
}
for column, span in education_bins.items():
    df[column] = df[keys.annotator_education_cols[span]].any(axis=1)

In [None]:
# Coarse income indicators: each new boolean column is True for an annotator
# when any column in the given slice of `keys.annotator_income_cols` is
# truthy for that row.
# NOTE(review): which income brackets fall into each slice depends on the
# order of `keys.annotator_income_cols`, which is defined outside this
# notebook -- confirm.
income_bins = {
    'annotator_income_0-50k': slice(None, 2),
    'annotator_income_50-100k': slice(2, 3),
    'annotator_income_more_than_100k': slice(3, None),
}
for column, span in income_bins.items():
    df[column] = df[keys.annotator_income_cols[span]].any(axis=1)

In [None]:
# Columns to plot for each demographic category, one list per category.
race_group = keys.annotator_race_cols
# First three gender columns followed by the first two trans columns.
gender_group = keys.annotator_gender_cols[:3] + keys.annotator_trans_cols[:2]
sexuality_group = keys.annotator_sexuality_cols
religion_group = keys.annotator_religion_cols
# Education and income use the coarse indicator columns added to `df` above.
education_group = [
    'annotator_education_high_school',
    'annotator_education_college',
    'annotator_education_graduate_school',
]
income_group = [
    'annotator_income_0-50k',
    'annotator_income_50-100k',
    'annotator_income_more_than_100k',
]

# Order matters: the label overrides and the category titles in the cells
# below are written against exactly this sequence.
groups = [
    race_group,
    gender_group,
    sexuality_group,
    religion_group,
    education_group,
    income_group,
]

In [None]:
# Flatten the per-category column lists into a single list of column names,
# in the same order as `groups`.
label_columns = [column for group in groups for column in group]

# Default display label: drop the first two underscore-separated chunks of
# the column name and capitalize each remaining chunk.
labels = [
    ' '.join(chunk.capitalize() for chunk in column.split('_')[2:])
    for column in label_columns
]

# Hand-written overrides for columns that come from `hate_measure.keys`.
# NOTE(review): these are positional and assume the current length and order
# of `keys.annotator_race_cols` and of the gender/trans slices used in
# `groups` -- confirm whenever those lists change.
labels[3] = 'Mid. East.'
labels[4] = 'Native Amer.'
labels[5] = 'Pac. Isl.'
labels[10] = 'Non-Binary'
labels[11] = 'Transgender'
labels[12] = 'Cisgender'

# Overrides for columns whose names are spelled out in this notebook are
# looked up by name rather than by absolute position, so they stay attached
# to the right bar regardless of how many columns the `keys` lists contribute.
labels[label_columns.index('annotator_education_graduate_school')] = 'Grad School'
labels[label_columns.index('annotator_income_more_than_100k')] = '$>$100k'

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 9))

# Bar geometry, all in y-axis data units.
height = 0.15  # bar thickness, also the spacing between bars in a category
base = 0       # y position of the first bar of the category being drawn
gap = 0.23     # extra blank space between consecutive categories
ticks = []     # y position of every bar, bottom to top
middles = []   # vertical centre of each category, bottom to top
counter = 0    # not used anywhere in this cell

# Categories are drawn bottom-up in reverse order, so the first entry of
# `groups` ends up at the top of the figure. (`idx` is not used in the body.)
for idx, group in enumerate(reversed(groups)):
    n_entries = len(group)
    y = np.linspace(base, base + height * (n_entries - 1), n_entries)
    ticks.extend(y)
    middles.append(np.mean(y))

    # The next category starts one bar plus one gap above this one's last bar.
    base = y[-1] + height + gap

    # Column means of the indicator columns, with the entries reversed inside
    # the category as well so they line up with the reversed tick labels.
    counts = df[list(reversed(group))].mean()
    ax.barh(
        y=y,
        width=counts,
        height=height,
        color='lightgray',
        edgecolor='black',
    )

# One tick per bar; labels are reversed as a whole to match the bottom-up
# drawing order used above.
ax.set_yticks(ticks)
ax.set_yticklabels(list(reversed(bold_text(labels))))
ax.grid(axis='x')
ax.set_axisbelow(True)
ax.set_xticks([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1])
ax.set_xlim(0, 1)
ax.set_ylim(-0.20, ticks[-1] + 0.20)
ax.tick_params(labelsize=12.5)
ax.set_xlabel(bold_text('Proportion of Annotators'), fontsize=15)

# Rotated category titles to the left of the tick labels, listed bottom to
# top so they pair with `middles` (i.e. `groups` in reverse order).
category_titles = ['Income', 'Educ.', 'Religion', 'Sexuality', 'Gender', 'Race']
for middle, title in zip(middles, category_titles):
    ax.text(
        x=-0.32,
        y=middle,
        s=bold_text(title),
        va='center',
        fontsize=14,
        rotation=90,
        transform=ax.transData,
    )

plt.savefig('figure4.pdf', bbox_inches='tight')