**Description**: Tries to answer [my question on
stats.stackexchange.com](https://stats.stackexchange.com/q/611877/337906) by training
BERT on real classification datasets. This experiment suggests that training on test set
features (no labels) can be okay.

**Estimated runtime**: ~30 minutes on a Google Colab GPU or TPU. Eternity on CPU.

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import permutation_test
import seaborn as sns
from statsmodels.stats.multitest import fdrcorrection

# Individual analysis

In [None]:
# Choose the dataset to analyse: exactly one assignment below should be active.
# The commented-out lines are the other available choices.
# dataset = "ag_news"
# dataset = "enron_spam"
# dataset = "amazon_counterfactual_en"
# dataset = "yelp_review_full"
# dataset = "craigslist_bargains"
# dataset = "emotion"
# dataset = "ethos"
# dataset = "yahoo_answers_topics"
# dataset = "trec"
# dataset = "mtop_domain"
# dataset = "clickbait_notclickbait_dataset"
# dataset = "financial_phrasebank"
# dataset = "app_reviews"
dataset = "rotten_tomatoes"

# Each dataset's accuracies live in bert_accuracies/<dataset>.csv.
csv_path = os.path.join("bert_accuracies", f"{dataset}.csv")
df = pd.read_csv(csv_path)

# Summary statistics of every accuracy column, rounded for readability.
summary = df.describe().round(3)
print(summary)

In [None]:
# Paired, one-sided permutation test on the mean difference between the "test" and
# "extra" accuracy columns of the selected dataset.
# The resampling is random, so the seed is fixed: re-running the cell then gives the
# same p-value. `random_state` is used because it is accepted by both older and
# newer SciPy releases (newer ones also offer `rng`).
permutation_test(
    data=(df["test"], df["extra"]),
    statistic=lambda x, y: np.mean(x - y),
    alternative="greater",  # acc_test (unfair) > acc_extra (fair)
    permutation_type="samples",  # paired observations
    n_resamples=10_000,
    random_state=0,  # reproducible p-value across kernel restarts
).pvalue

In [None]:
bins = 10
n_columns = len(df.columns)

# One stacked histogram panel per column of the selected dataset's frame.
fig, axes = plt.subplots(nrows=n_columns, ncols=1, figsize=(6, n_columns * 1.5))

axes: list[plt.Axes]

# Use the same x-range on every panel so the distributions are directly comparable.
x_common_min = df.min().min()
x_common_max = df.max().max()
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

# Draw each column's histogram on its own axis, in its own colour
for i, column in enumerate(df.columns):
    column: str
    ax = axes[i]
    df[column].hist(
        bins=bins,
        ax=ax,
        range=(x_common_min, x_common_max),
        color=colors[i],
    )
    ax.set_ylim((0, 20))
    ax.set_xlabel(f"{column} accuracy")
    ax.set_ylabel("frequency")

fig.suptitle(f"Dataset: {dataset}", fontsize=12)
plt.tight_layout()
# Extra vertical spacing so each panel's x-label does not collide with the next panel.
plt.subplots_adjust(hspace=1.2)

In [None]:
# Hard-coded p-values, corrected for multiple comparisons (false discovery rate).
# NOTE(review): presumably these are the permutation-test p-values from the cell
# above, one per dataset in the order of the dataset list - confirm before relying
# on the positions.
pvals = [
    0.65, 0.63, 0.58, 0.09, 0.84, 0.78, 0.18,
    0.9988, 0.28, 0.86, 0.86, 0.29, 0.09, 0.71,
]
fdrcorrection(pvals)

# Meta-analysis

In [None]:
# Load every per-dataset accuracy CSV and stack them into one long-format frame
# with a "dataset" column naming the file each row came from.
_dfs = []
for accuracy_csv in sorted(os.listdir("bert_accuracies")):
    # The directory may contain non-result entries (e.g. .ipynb_checkpoints or
    # .DS_Store); reading those would crash or add a bogus dataset.
    if not accuracy_csv.endswith(".csv"):
        continue
    _df = pd.read_csv(os.path.join("bert_accuracies", accuracy_csv))
    _df["dataset"] = accuracy_csv.removesuffix(".csv")
    _dfs.append(_df)
# ignore_index gives every row a unique label instead of repeating each file's
# 0..n-1 labels, which avoids duplicate-label pitfalls in later alignment.
accuracy_df = pd.concat(_dfs, ignore_index=True)
accuracy_df = accuracy_df[["dataset", "base", "extra", "test"]]

In [None]:
num_test = 200  # taken from bert/run.ipynb
# Convert accuracies back to counts of correct predictions.
# Accuracies are floats, so acc * num_test can land just below the true integer
# (e.g. 0.29 * 200 == 57.99999999999999). A bare int cast truncates that to 57,
# so round to the nearest integer before casting.
num_correct_df = (
    (accuracy_df[["base", "extra", "test"]] * num_test).round().astype(int)
)
num_correct_df["dataset"] = accuracy_df["dataset"].copy()
num_correct_df = num_correct_df[["dataset", "base", "extra", "test"]]
num_correct_df

In [None]:
# Derived columns used by the summary tables and violin plots further down:
#   diff    = acc_test  - acc_extra
#   control = acc_extra - acc_base
accuracy_df["diff"] = accuracy_df["test"].sub(accuracy_df["extra"])
accuracy_df["control"] = accuracy_df["extra"].sub(accuracy_df["base"])

In [None]:
# Apply seaborn's "darkgrid" theme globally; it affects every figure drawn after this cell.
sns.set_theme(style="darkgrid")

### Does pretraining help?

In [None]:
# Per-dataset mean and standard deviation of the "control" column
# (acc_extra - acc_base), rounded to three decimals.
control_stats = accuracy_df.groupby("dataset")["control"].describe()
control_stats[["mean", "std"]].round(3)

In [None]:
fig, axes = plt.subplots(figsize=(16, 2))
axes: plt.Axes

# One violin per dataset showing the distribution of the "control" column.
sns.violinplot(data=accuracy_df, x="dataset", y="control", ax=axes)

control_ylabel = (
    "Accuracy boost\n"
    "($\\text{acc}_\\text{extra}$ - $\\text{acc}_\\text{base}$)"
)
axes.set_title("BERT for text classification")
axes.set_xlabel("Dataset")
# Horizontal y-label so the two-line caption reads naturally beside the axis.
axes.set_ylabel(control_ylabel, rotation="horizontal", ha="right", va="top")
axes.yaxis.grid(True)

# Tilt the dataset names so they do not overlap.
plt.xticks(rotation=45, ha="right")

plt.show()

### Does pretraining on test cause bias?

In [None]:
# Per-dataset mean and standard deviation of the "diff" column
# (acc_test - acc_extra), rounded to three decimals.
diff_stats = accuracy_df.groupby("dataset")["diff"].describe()
diff_stats[["mean", "std"]].round(3)

In [None]:
fig, axes = plt.subplots(figsize=(16, 2))
axes: plt.Axes

# One violin per dataset showing the distribution of the "diff" column.
sns.violinplot(data=accuracy_df, x="dataset", y="diff", ax=axes, color="darkorange")

diff_ylabel = (
    "Accuracy overestimation\n"
    "($\\text{acc}_\\text{test}$ - $\\text{acc}_\\text{extra}$)"
)
axes.set_title("BERT for text classification")
axes.set_xlabel("Dataset")
# Horizontal y-label so the two-line caption reads naturally beside the axis.
axes.set_ylabel(diff_ylabel, rotation="horizontal", ha="right", va="center")
axes.yaxis.grid(True)

# Tilt the dataset names so they do not overlap.
plt.xticks(rotation=45, ha="right")

plt.show()