In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import itertools
import logging
import math
import random
import statistics

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import statsmodels.api as sm
from scipy.stats import norm, pearsonr

from utils import model_init
from score import cs_score, csk_score, f_score, ss_score
from dataset import logger, get_dataset_by_name

# Raise the threshold of the logger imported from `dataset` to WARNING.
logger.setLevel(logging.WARNING)

In [None]:
def binary_ci(success: int, total: int, alpha: float = 0.95):
    """
    Agresti-Coull interval for a binomial proportion.

    `success` is the number of successes out of `total` trials and `alpha`
    is the confidence level (0.95 by default).

    Returns a pair `(centre, half_width)`: the adjusted centre of the
    interval and the distance from that centre to either bound.
    """
    quantile = statistics.NormalDist().inv_cdf((1 + alpha) / 2)
    quantile_sq = quantile**2
    # Agresti-Coull adjustment: add z^2 pseudo-trials, half of them successes.
    adjusted_total = total + quantile_sq
    centre = (success + quantile_sq / 2) / adjusted_total
    half_width = quantile * math.sqrt(centre * (1 - centre) / adjusted_total)
    return centre, half_width


def bootstrap_ci(scores, alpha=0.95):
    """
    Bootstrap estimate of the mean of `scores` and its uncertainty.

    The scores are resampled with replacement 1000 times; a normal
    distribution is fitted to the resampled means and its `alpha` interval
    is taken.

    Returns a pair `(mean, half_width)` where `mean` is the fitted location
    of the raw scores and `half_width` is the distance from that mean down
    to the lower bound of the bootstrap interval.
    """
    sample_size = len(scores)
    center, _ = norm.fit(scores)

    resampled_means = []
    for _ in range(1000):
        draw = random.choices(scores, k=sample_size)
        resampled_means.append(sum(draw) / sample_size)

    lower_bound, _ = norm.interval(alpha, *norm.fit(resampled_means))
    return center, center - lower_bound

# Figures

In [None]:
def get_ci_pred(x, y):
    """
    Fit an OLS line `y ~ const + x` and return its prediction summary.

    The `(x, y)` pairs are sorted by `x` (ties broken by `y`) before fitting,
    so the rows of the returned frame line up with `sorted(x)`.

    Returns the `summary_frame()` of the fitted model's predictions at the
    sorted `x` values; the plotting cells read its `mean_ci_lower` and
    `mean_ci_upper` columns.
    """
    # sort the pairs by x so the returned rows run left to right on the plot
    x, y = zip(*sorted(zip(x, y)))

    # design matrix with an intercept column
    design = sm.add_constant(x)
    fitted = sm.OLS(y, design).fit()

    return fitted.get_prediction(design).summary_frame()

## `ss` score

In [None]:
# Compute `ss` scores for the figure below. Every result is a pair:
# (scores from the default call, scores from the call with swap=True).
# The figure plots the first element on the "Original group" axis and the
# second on the "Control group" axis.
model, tokenizer = model_init('roberta-base')

dt = get_dataset_by_name('stereoset_genderswap', tokenizer)
ss_gender = ss_score(dt, tokenizer, model), ss_score(dt, tokenizer, model, swap=True)

dt = get_dataset_by_name('stereoset_genderswap_filtered', tokenizer)
ss_gender_filter = ss_score(dt, tokenizer, model), ss_score(dt, tokenizer, model, swap=True)

dt = get_dataset_by_name('stereoset_race_control', tokenizer)
ss_race = ss_score(dt, tokenizer, model), ss_score(dt, tokenizer, model, swap=True)

dt = get_dataset_by_name('stereoset_profession_control', tokenizer)
ss_profession = ss_score(dt, tokenizer, model), ss_score(dt, tokenizer, model, swap=True)

# The Slovak dataset is scored with a different model, so model and tokenizer
# are re-initialised (and the dataset is loaded with the new tokenizer).
model, tokenizer = model_init('gerulata/slovakbert')
dt = get_dataset_by_name('slovak_gender', tokenizer)
ss_slovak_gender = ss_score(dt, tokenizer, model), ss_score(dt, tokenizer, model, swap=True)

In [None]:
plt.rcParams["figure.figsize"] = (6,6)

# 2x2 grid sharing both axes; flatten the axes array so the panels can be
# addressed as ax[0]..ax[3].
fig, ax = plt.subplots(2,2,sharex=True,sharey=True)
ax = list(itertools.chain.from_iterable(ax))

# Panel 0: StereoSet gender -- original and filtered sets overlaid, each with
# the confidence band of its OLS fit.
pred = get_ci_pred(*ss_gender)
ax[0].set_title('StereoSet gender')
ax[0].set_ylabel('Control group')
ax[0].scatter(*ss_gender, s=4)
ax[0].fill_between(sorted(ss_gender[0]), pred['mean_ci_lower'], pred['mean_ci_upper'], color='lightblue', alpha=0.5)

pred = get_ci_pred(*ss_gender_filter)
ax[0].scatter(*ss_gender_filter, s=4)
ax[0].fill_between(sorted(ss_gender_filter[0]), pred['mean_ci_lower'], pred['mean_ci_upper'], color='oldlace', alpha=0.5)
# dotted y = x reference line
ax[0].axline((0, 0), slope=1, color="black", linestyle=':')
# manual legend entries for the two overlaid series
ax[0].legend(handles=[
    mpatches.Patch(color='tab:blue', label='Original'),
    mpatches.Patch(color='tab:orange', label='Filtered')
])

# Panel 1: StereoSet race. The band is fitted on all points; only a random
# subset of at most 500 points is drawn (presumably to keep the scatter
# readable). `min` keeps random.sample from raising on smaller datasets.
pred = get_ci_pred(*ss_race)
ax[1].set_title('StereoSet race')
ids = random.sample(range(len(ss_race[0])), k=min(500, len(ss_race[0])))
ax[1].scatter([ss_race[0][i] for i in ids], [ss_race[1][i] for i in ids], s=4)
ax[1].fill_between(sorted(ss_race[0]), pred['mean_ci_lower'], pred['mean_ci_upper'], color='lightblue', alpha=0.5)
ax[1].axline((0, 0), slope=1, color="black", linestyle=':')


# Panel 2: StereoSet profession, subsampled the same way as the race panel.
pred = get_ci_pred(*ss_profession)
ax[2].set_xlabel('Original group')
ax[2].set_ylabel('Control group')
ax[2].set_title('StereoSet profession')
ids = random.sample(range(len(ss_profession[0])), k=min(500, len(ss_profession[0])))
ax[2].scatter([ss_profession[0][i] for i in ids], [ss_profession[1][i] for i in ids], s=4)
ax[2].fill_between(sorted(ss_profession[0]), pred['mean_ci_lower'], pred['mean_ci_upper'], color='lightblue', alpha=0.5)
ax[2].axline((0, 0), slope=1, color="black", linestyle=':')


# Panel 3: Slovak gender. Uses `ss_slovak_gender`, the name under which the
# scores are stored by the cell that computes them.
pred = get_ci_pred(*ss_slovak_gender)
ax[3].set_xlabel('Original group')
ax[3].set_title('Slovak gender')
ax[3].scatter(*ss_slovak_gender, s=4)
ax[3].fill_between(sorted(ss_slovak_gender[0]), pred['mean_ci_lower'], pred['mean_ci_upper'], color='lightblue', alpha=0.5)
ax[3].axline((0, 0), slope=1, color="black", linestyle=':')
fig.tight_layout()

plt.savefig('1.pdf')

## `cs` score

In [None]:
# Compute `cs` scores for the figure below. Every result is a pair:
# (scores from the default call, scores from the call with swap=True).
# Dataset names use the underscore spelling used by every other cell in this
# notebook (e.g. 'stereoset_genderswap', 'crows_negation').
model, tokenizer = model_init('roberta-base')

dt = get_dataset_by_name('stereoset_genderswap', tokenizer)
cs_gender = cs_score(dt, tokenizer, model), cs_score(dt, tokenizer, model, swap=True)

dt = get_dataset_by_name('stereoset_genderswap_filtered', tokenizer)
cs_gender_filter = cs_score(dt, tokenizer, model), cs_score(dt, tokenizer, model, swap=True)

dt = get_dataset_by_name('crows_negation', tokenizer)
cs_neg = cs_score(dt, tokenizer, model), cs_score(dt, tokenizer, model, swap=True)

dt = get_dataset_by_name('crows_antistereotypes', tokenizer)
cs_anti = cs_score(dt, tokenizer, model), cs_score(dt, tokenizer, model, swap=True)

# The Slovak dataset is scored with a different model, so model and tokenizer
# are re-initialised (and the dataset is loaded with the new tokenizer).
model, tokenizer = model_init('gerulata/slovakbert')
dt = get_dataset_by_name('slovak_gender', tokenizer)
cs_slovak_gender = cs_score(dt, tokenizer, model), cs_score(dt, tokenizer, model, swap=True)

In [None]:
plt.rcParams["figure.figsize"] = (6,6)

# 2x2 grid sharing both axes; flatten the axes array so the panels can be
# addressed as ax[0]..ax[3].
fig, ax = plt.subplots(2,2,sharex=True,sharey=True)
ax = list(itertools.chain.from_iterable(ax))

# Panel 0: StereoSet gender -- original and filtered sets overlaid, each with
# the confidence band of its OLS fit.
pred = get_ci_pred(*cs_gender)
ax[0].set_title('StereoSet gender')
ax[0].set_ylabel('Control pair')
ax[0].scatter(*cs_gender, s=4)
ax[0].fill_between(sorted(cs_gender[0]), pred['mean_ci_lower'], pred['mean_ci_upper'], color='lightblue', alpha=0.5)

pred = get_ci_pred(*cs_gender_filter)
ax[0].scatter(*cs_gender_filter, s=4)
ax[0].fill_between(sorted(cs_gender_filter[0]), pred['mean_ci_lower'], pred['mean_ci_upper'], color='oldlace', alpha=0.5)
# dotted y = x reference line
ax[0].axline((0, 0), slope=1, color="black", linestyle=':')
# manual legend entries for the two overlaid series
ax[0].legend(handles=[
    mpatches.Patch(color='tab:blue', label='Original'),
    mpatches.Patch(color='tab:orange', label='Filtered')
])


# Panel 1: CrowS negation. The band is fitted on all points; only the first
# 100 points are drawn.
pred = get_ci_pred(*cs_neg)
ax[1].set_title('CrowS Negation')
ax[1].set_ylabel('Control pair')
ax[1].scatter(cs_neg[0][:100], cs_neg[1][:100], s=4)
ax[1].fill_between(sorted(cs_neg[0]), pred['mean_ci_lower'], pred['mean_ci_upper'], color='lightblue', alpha=0.5)
ax[1].axline((0, 0), slope=1, color="black", linestyle=':')


# Panel 2: CrowS antistereotypes, first 100 points drawn.
pred = get_ci_pred(*cs_anti)
ax[2].set_xlabel('Original pair')
ax[2].set_ylabel('Control pair')
ax[2].set_title('CrowS Antistereotype')
ax[2].scatter(cs_anti[0][:100], cs_anti[1][:100], s=4)
ax[2].fill_between(sorted(cs_anti[0]), pred['mean_ci_lower'], pred['mean_ci_upper'], color='lightblue', alpha=0.5)
ax[2].axline((0, 0), slope=1, color="black", linestyle=':')


# Panel 3: Slovak gender, first 100 points drawn. Uses `cs_slovak_gender`,
# the name under which the scores are stored by the cell that computes them.
pred = get_ci_pred(*cs_slovak_gender)
ax[3].set_xlabel('Original pair')
ax[3].set_title('Slovak gender')
ax[3].scatter(cs_slovak_gender[0][:100], cs_slovak_gender[1][:100], s=4)
ax[3].fill_between(sorted(cs_slovak_gender[0]), pred['mean_ci_lower'], pred['mean_ci_upper'], color='lightblue', alpha=0.5)
ax[3].axline((0, 0), slope=1, color="black", linestyle=':')
fig.tight_layout()
plt.savefig('2.pdf')

## `csk` score

In [None]:
# Compute `csk` scores for the figure below. Every result is a pair:
# (scores from the default call, scores from the call with swap=True).
model, tokenizer = model_init('roberta-base')

dt = get_dataset_by_name('stereoset_genderswap', tokenizer)
csk_gender = csk_score(dt, tokenizer, model), csk_score(dt, tokenizer, model, swap=True)

dt = get_dataset_by_name('stereoset_genderswap_filtered', tokenizer)
csk_gender_filter = csk_score(dt, tokenizer, model), csk_score(dt, tokenizer, model, swap=True)

dt = get_dataset_by_name('stereoset_race_control', tokenizer)
csk_race = csk_score(dt, tokenizer, model), csk_score(dt, tokenizer, model, swap=True)

dt = get_dataset_by_name('stereoset_profession_control', tokenizer)
csk_profession = csk_score(dt, tokenizer, model), csk_score(dt, tokenizer, model, swap=True)

# The Slovak dataset is scored with a different model, so model and tokenizer
# are re-initialised (and the dataset is loaded with the new tokenizer).
model, tokenizer = model_init('gerulata/slovakbert')
dt = get_dataset_by_name('slovak_gender', tokenizer)
csk_slovak_gender = csk_score(dt, tokenizer, model), csk_score(dt, tokenizer, model, swap=True)

In [None]:
plt.rcParams["figure.figsize"] = (20,5)

# One row of four panels; axes are not shared, so each panel draws its own
# y = x reference segment over the range of its x values.
# NOTE(review): only ax[0]..ax[2] are filled; ax[3] stays empty although
# `csk_slovak_gender` is computed earlier -- confirm whether a fourth panel
# was intended.
fig, ax = plt.subplots(1, 4)

# Panel 0: StereoSet gender -- original and filtered sets overlaid, each with
# the confidence band of its OLS fit.
pred = get_ci_pred(*csk_gender)
ax[0].set_title('StereoSet gender')
ax[0].set_xlabel('Original group')
ax[0].set_ylabel('Control group')
ax[0].scatter(*csk_gender, s=4)
ax[0].fill_between(sorted(csk_gender[0]), pred['mean_ci_lower'], pred['mean_ci_upper'], color='lightblue', alpha=0.5)

pred = get_ci_pred(*csk_gender_filter)
ax[0].scatter(*csk_gender_filter, s=4)
ax[0].fill_between(sorted(csk_gender_filter[0]), pred['mean_ci_lower'], pred['mean_ci_upper'], color='oldlace', alpha=0.5)
ax[0].plot([min(csk_gender_filter[0]), max(csk_gender_filter[0])], [min(csk_gender_filter[0]), max(csk_gender_filter[0])], linestyle=':', c='black')


# Panel 1: StereoSet race (scores come from 'stereoset_race_control').
# The band is fitted on all points; only the first 100 points are drawn.
pred = get_ci_pred(*csk_race)
ax[1].set_title('StereoSet race')
ax[1].set_xlabel('Original group')
ax[1].set_ylabel('Control group')
ax[1].scatter(csk_race[0][:100], csk_race[1][:100], s=4)
ax[1].fill_between(sorted(csk_race[0]), pred['mean_ci_lower'], pred['mean_ci_upper'], color='lightblue', alpha=0.5)
ax[1].plot([min(csk_race[0]), max(csk_race[0])], [min(csk_race[0]), max(csk_race[0])], linestyle=':', c='black')


# Panel 2: StereoSet profession (scores come from
# 'stereoset_profession_control'), first 100 points drawn.
pred = get_ci_pred(*csk_profession)
ax[2].set_xlabel('Original group')
ax[2].set_ylabel('Control group')
ax[2].set_title('StereoSet profession')
ax[2].scatter(csk_profession[0][:100], csk_profession[1][:100], s=4)
ax[2].fill_between(sorted(csk_profession[0]), pred['mean_ci_lower'], pred['mean_ci_upper'], color='lightblue', alpha=0.5)
ax[2].plot([min(csk_profession[0]), max(csk_profession[0])], [min(csk_profession[0]), max(csk_profession[0])], linestyle=':', c='black')


plt.savefig('3.pdf')

# Tables

In [1]:
# Model names passed to `model_init` when building the tables for the
# StereoSet / CrowS datasets.
english_models = [
    'roberta-base',
    'bert-base-uncased',
    'distilbert-base-uncased',
    'xlm-roberta-base',
    'albert-base-v2',
    'albert-xxlarge-v2',
    'bert-base-multilingual-cased',
]

# Model names passed to `model_init` when building the tables for the
# 'slovak_gender' dataset.
slovak_models = [
    'gerulata/slovakbert',
    'xlm-roberta-base',
    'bert-base-multilingual-cased',
]

## `ss` score

In [None]:
def ss_table_results(model, tokenizer, dt):
    """
    Yield the rows of the `ss` table for one model/dataset pair.

    Scores are computed twice with `ss_score`: once with the default call
    (original) and once with `swap=True` (control). The rows are yielded in
    this order:

    1. `bootstrap_ci` of the original scores
    2. `bootstrap_ci` of the control scores
    3. `binary_ci` of the share of original scores that are > 0
    4. `binary_ci` of the share of control scores that are > 0
    5. Pearson correlation between original and control scores
    6. false positive rate: among samples with original score > 0, the share
       whose control score is larger than the original score
    7. false negative rate: among samples with original score < 0, the share
       whose control score is smaller than the original score

    Rows 6 and 7 are NaN when there are no positive / negative original
    scores, instead of raising a division-by-zero error.
    """
    ss_results = ss_score(dt, tokenizer, model)
    ss_swap_results = ss_score(dt, tokenizer, model, swap=True)

    # ssmu original
    yield bootstrap_ci(ss_results)

    # ssmu control
    yield bootstrap_ci(ss_swap_results)

    # ss+ original
    yield binary_ci(sum([ss > 0 for ss in ss_results]), len(ss_results))

    # ss+ control
    yield binary_ci(sum([ss > 0 for ss in ss_swap_results]), len(ss_swap_results))

    # ss pearson
    yield pearsonr(ss_results, ss_swap_results)[0]

    # false positive rate
    positives = sum(ss > 0 for ss in ss_results)
    false_positives = sum(ss > 0 and ss_swap > ss for ss, ss_swap in zip(ss_results, ss_swap_results))
    yield false_positives / positives if positives else math.nan

    # false negative rate
    negatives = sum(ss < 0 for ss in ss_results)
    false_negatives = sum(ss < 0 and ss_swap < ss for ss, ss_swap in zip(ss_results, ss_swap_results))
    yield false_negatives / negatives if negatives else math.nan


In [None]:
# Row labels of the `ss` table, in the order `ss_table_results` yields them.
row_names = [
    '$ss\\mu$ Original',
    '$ss\\mu$ Control',
    '$ss+$ Original',
    '$ss+$ Control',
    '$ss\\ \\rho$',
    'False Positive Rate',
    'False Negative Rate',
]

# Print one LaTeX table body per English model: one column per dataset.
for model_name in english_models:
    model, tokenizer = model_init(model_name)
    print(model_name)
    datasets = [
        get_dataset_by_name('stereoset_genderswap', tokenizer),
        get_dataset_by_name('stereoset_genderswap_filtered', tokenizer),
        get_dataset_by_name('stereoset_race_control', tokenizer),
        get_dataset_by_name('stereoset_profession_control', tokenizer),
    ]
    dataset_values = [list(ss_table_results(model, tokenizer, dt)) for dt in datasets]

    # transpose: iterate over table rows instead of datasets
    for row_name, values in zip(row_names, zip(*dataset_values)):
        cells = ''.join(
            # plain floats are single numbers, everything else is (value, +-)
            f' & ${value:.2}$' if isinstance(value, float)
            else f' & ${value[0]:.2} \\pm {value[1]:.2}$'
            for value in values
        )
        if row_name in ['$ss\\mu$ Control', '$ss+$ Control']:
            row_end = '\\\\ \\midrule'
        else:
            row_end = '\\\\'
        print(row_name + cells + row_end)
        

In [None]:
# `ss` table for the Slovak dataset: one column per model.
# Uses `row_names` as defined by the English `ss` table cell.
model_values = []
for model_name in slovak_models:
    model, tokenizer = model_init(model_name)
    dt = get_dataset_by_name('slovak_gender', tokenizer)
    # `ss_table_results` is a generator; materialise it here so the scoring
    # runs inside this loop (as in the English table cell) rather than being
    # deferred to the printing loop below.
    model_values.append(list(ss_table_results(model, tokenizer, dt)))

# transpose: iterate over table rows instead of models
for values, row_name in zip(zip(*model_values), row_names):
    print(row_name, end='')
    for value in values:
        if isinstance(value, float):
            print(f' & ${value:.2}$', end='')
        else:
            # (value, +-) pair
            print(f' & ${value[0]:.2} \\pm {value[1]:.2}$', end='')
    if row_name in ['$ss\\mu$ Control', '$ss+$ Control']:
        print('\\\\ \\midrule')
    else:
        print('\\\\')

## `cs` score

In [None]:
def cs_table_results(model, tokenizer, dt, csk=True):
    """
    Yield the rows of the `cs` table for one model/dataset pair.

    Scores are computed with `cs_score` (and, when `csk` is true, with
    `csk_score`), once with the default call (original) and once with
    `swap=True` (control). Nine rows are yielded:

    1-2. `bootstrap_ci` of the original / control `cs` scores
    3-4. `bootstrap_ci` of the original / control `csk` scores
    5-6. `binary_ci` of the share of original / control `cs` scores > 0
    7.   Pearson correlation of original vs. control `cs` scores
    8.   Pearson correlation of original vs. control `csk` scores
    9.   Pearson correlation of original `cs` vs. original `csk` scores

    When `csk` is false, rows 3, 4, 8 and 9 are yielded as `None`.
    """
    original_scores = cs_score(dt, tokenizer, model)
    control_scores = cs_score(dt, tokenizer, model, swap=True)
    csk_original = csk_score(dt, tokenizer, model) if csk else None
    csk_control = csk_score(dt, tokenizer, model, swap=True) if csk else None

    # csmu original, csmu control
    yield bootstrap_ci(original_scores)
    yield bootstrap_ci(control_scores)

    # cskmu original, cskmu control
    if not csk:
        yield None
        yield None
    else:
        yield bootstrap_ci(csk_original)
        yield bootstrap_ci(csk_control)

    # cs+ original, cs+ control
    for scores in (original_scores, control_scores):
        yield binary_ci(sum(cs > 0 for cs in scores), len(scores))

    # cs rho
    yield pearsonr(original_scores, control_scores)[0]

    # csk rho, cs-csk rho
    if not csk:
        yield None
        yield None
        return
    yield pearsonr(csk_original, csk_control)[0]
    yield pearsonr(original_scores, csk_original)[0]


In [None]:
# Row labels of the `cs` table, in the order `cs_table_results` yields them.
row_names = [
    '$cs\\mu$ Original',
    '$cs\\mu$ Control',
    '$csk\\mu$ Original',
    '$csk\\mu$ Control',
    '$cs+$ Original',
    '$cs+$ Control',
    '$cs\\ \\rho$',
    '$csk\\ \\rho$',
    '$cs{-}csk\\ \\rho$',
]

# Print one LaTeX table body per English model: one column per dataset.
for model_name in english_models:
    model, tokenizer = model_init(model_name)
    print(model_name)

    dataset_names = [
        'stereoset_genderswap',
        'stereoset_genderswap_filtered',
        'crows_negation',
        'crows_antistereotypes',
    ]
    dataset_values = [
        list(
            cs_table_results(
                model,
                tokenizer,
                get_dataset_by_name(dt_name, tokenizer),
                # `csk` rows are only computed for the StereoSet datasets;
                # the CrowS columns get `None` for them.
                csk=dt_name.startswith('stereoset')
            )
        )
        for dt_name in dataset_names
    ]

    # transpose: iterate over table rows instead of datasets
    for values, row_name in zip(zip(*dataset_values), row_names):
        print(row_name, end='')
        for value in values:
            # The three cases are mutually exclusive; with a plain `if` chain
            # a `None` cell would fall through to `value[0]` and raise.
            if value is None:
                print(' & -', end='')
            elif isinstance(value, float):
                print(f' & ${value:.2}$', end='')
            else:
                # (value, +-) pair
                print(f' & ${value[0]:.2} \\pm {value[1]:.2}$', end='')
        if row_name in ['$cs\\mu$ Control', '$csk\\mu$ Control', '$cs+$ Control']:
            print('\\\\ \\midrule')
        else:
            print('\\\\')
        

In [None]:
# `cs` table for the Slovak dataset: one column per model.
# Uses `row_names` as defined by the English `cs` table cell.
model_values = []
for model_name in slovak_models:
    model, tokenizer = model_init(model_name)
    dt = get_dataset_by_name('slovak_gender', tokenizer)
    # `cs_table_results` is a generator; materialise it here so the scoring
    # runs inside this loop rather than being deferred to the printing loop.
    model_values.append(list(cs_table_results(model, tokenizer, dt)))

# transpose: iterate over table rows instead of models
for values, row_name in zip(zip(*model_values), row_names):
    print(row_name, end='')
    for value in values:
        # The three cases are mutually exclusive; with a plain `if` chain a
        # `None` cell would fall through to `value[0]` and raise.
        if value is None:
            print(' & -', end='')
        elif isinstance(value, float):
            print(f' & ${value:.2}$', end='')
        else:
            # (value, +-) pair
            print(f' & ${value[0]:.2} \\pm {value[1]:.2}$', end='')
    if row_name in ['$cs\\mu$ Control', '$csk\\mu$ Control', '$cs+$ Control']:
        print('\\\\ \\midrule')
    else:
        print('\\\\')

## `f` score

In [None]:
def f_table_results(model, tokenizer, dt, cs=True):
    """
    Yield the rows of the `f` table for one model/dataset pair.

    `ss_score`, `cs_score` and `f_score` are each computed once on `dt`.
    The rows are yielded in this order:

    1. `bootstrap_ci` of the `f` scores
    2. `binary_ci` of the share of `f` scores that are > 0
    3. Pearson correlation between `f` and `ss` scores
    4. share of samples where `f` and `ss` agree on being > 0
    5. Pearson correlation between `f` and `cs` scores
    6. share of samples where `f` and `cs` agree on being > 0

    The `cs` parameter is accepted but not used.
    """
    ss_results = ss_score(dt, tokenizer, model)
    cs_results = cs_score(dt, tokenizer, model)
    f_results = f_score(dt, tokenizer, model)

    # fmu
    yield bootstrap_ci(f_results)

    # f+
    yield binary_ci(sum(f > 0 for f in f_results), len(f_results))

    # correlation and sign agreement, first against SS and then against CS
    for reference in (ss_results, cs_results):
        yield pearsonr(f_results, reference)[0]
        matches = sum((ref > 0) == (f > 0) for ref, f in zip(reference, f_results))
        yield matches / len(reference)

In [None]:
# Row labels of the `f` table, in the order `f_table_results` yields them.
# Backslashes are doubled throughout: '\m' is not a valid escape sequence.
row_names = [
    '$f\\mu$',
    '$f+$',
    '$f{-}ss\\ \\rho$',
    '$f{-}ss$ agreement',
    '$f{-}cs\\ \\rho$',
    '$f{-}cs$ agreement',
]
# Print one LaTeX table body per English model: one column per dataset.
for model_name in english_models:
    model, tokenizer = model_init(model_name)
    print(model_name)
    dataset_values = [
        f_table_results(model, tokenizer, dt)
        for dt in [
            get_dataset_by_name('stereoset_genderswap', tokenizer),
            get_dataset_by_name('stereoset_genderswap_filtered', tokenizer),
            get_dataset_by_name('stereoset_race_control', tokenizer),
            get_dataset_by_name('stereoset_profession_control', tokenizer),
        ]
    ]
    # transpose: iterate over table rows instead of datasets
    for values, row_name in zip(zip(*dataset_values), row_names):
        print(row_name, end='')
        for value in values:
            if isinstance(value, float):
                print(f' & ${value:.2}$', end='')
            else:
                # (value, +-) pair
                print(f' & ${value[0]:.2} \\pm {value[1]:.2}$', end='')
        if row_name in ['$f+$', '$f{-}ss$ agreement']:
            print('\\\\ \\midrule')
        else:
            print('\\\\')


In [None]:
# `f` table for the Slovak dataset: one column per model.
# Uses `row_names` as defined by the English `f` table cell.
model_values = []
for model_name in slovak_models:
    model, tokenizer = model_init(model_name)
    dt = get_dataset_by_name('slovak_gender', tokenizer)
    # `f_table_results` is a generator; materialise it here so the scoring
    # runs inside this loop rather than being deferred to the printing loop.
    model_values.append(list(f_table_results(model, tokenizer, dt)))

# transpose: iterate over table rows instead of models
for values, row_name in zip(zip(*model_values), row_names):
    print(row_name, end='')
    for value in values:
        if isinstance(value, float):
            print(f' & ${value:.2}$', end='')
        else:
            # (value, +-) pair
            print(f' & ${value[0]:.2} \\pm {value[1]:.2}$', end='')
    if row_name in ['$f+$', '$f{-}ss$ agreement']:
        print('\\\\ \\midrule')
    else:
        print('\\\\')