In [1]:
import os

import polars as pl

import utils



In [2]:
# Root directory of the saved accuracies; it holds one "m{num_train}"
# subdirectory per num_train value.
accuracies_home_dir: str = "accuracies_from_paper"
# Grid of (num_train, num_test) settings whose accuracies get loaded.
num_trains = (50, 100)
num_tests = (50, 100, 200, 500)

In [3]:
# Load every (num_train, num_test) configuration and tag each frame with its
# settings so all frames can be concatenated into one long table.
# The comprehension keeps the loop variables scoped to this expression, so no
# stray `num_train` / `num_test` globals linger in the kernel for later cells.
dfs = [
    utils.load_all_accuracies(
        os.path.join(accuracies_home_dir, f"m{num_train}"), num_test
    ).with_columns(
        pl.lit(num_train).alias("num_train"),
        pl.lit(num_test).alias("num_test"),
    )
    for num_train in num_trains
    for num_test in num_tests
]

In [4]:
# Keys identifying one experimental cell; reused for grouping below.
group = ["num_train", "num_test", "lm_type", "dataset"]
# Stack all loaded configurations into one table, ordered by those keys.
accuracies_df: pl.DataFrame = pl.concat(dfs).sort(by=group)

In [5]:
# Derived metrics per row:
#   pretraining_boost = extra - base
#   evaluation_bias   = test  - extra
accuracies_df = accuracies_df.with_columns(
    (pl.col("extra") - pl.col("base")).alias("pretraining_boost"),
    (pl.col("test") - pl.col("extra")).alias("evaluation_bias"),
)

In [6]:
def sem(col: pl.Expr, ddof: int = 1) -> pl.Expr:
    """Standard error of the mean of ``col``, as a Polars expression.

    Computes ``std(col, ddof) / sqrt(n)``, where ``n`` is the number of
    non-null values, so nulls are excluded from both numerator and denominator.

    Args:
        col: Expression to summarise (evaluated per group when used in ``agg``).
        ddof: Delta degrees of freedom for the standard deviation. The default
            of 1 matches the original behaviour (sample standard deviation).

    Returns:
        An expression yielding one SE value per group.
    """
    # ty https://github.com/pola-rs/polars/issues/6175#issuecomment-1416846104
    # Count non-nulls explicitly: std() ignores nulls, while Expr.count() has
    # not excluded nulls on every Polars version, which would skew the SE.
    n_valid = col.is_not_null().sum()
    return col.std(ddof=ddof) / n_valid.sqrt()


# Collapse the subsamples of each (num_train, num_test, lm_type, dataset) cell
# into a mean and standard error per metric. Columns come out as
# <metric>_mean, <metric>_se for each metric in turn.
accuracies_grouped = (
    accuracies_df.group_by(group)
    .agg(
        *[
            summary
            for metric in ("pretraining_boost", "evaluation_bias")
            for summary in (
                pl.col(metric).mean().name.suffix("_mean"),
                sem(pl.col(metric)).name.suffix("_se"),
            )
        ]
    )
    .sort(group)
)

In [7]:
# Preview the per-configuration summary: mean and SE of pretraining_boost and
# evaluation_bias for each (num_train, num_test, lm_type, dataset).
accuracies_grouped

num_train,num_test,lm_type,dataset,pretraining_boost_mean,pretraining_boost_se,evaluation_bias_mean,evaluation_bias_se
i32,i32,str,str,f64,f64,f64,f64
50,50,"""bert""","""FRENK-hate-en""",-0.0264,0.009228,-0.0008,0.009551
50,50,"""bert""","""ag_news""",0.1006,0.0098,-0.0162,0.011223
50,50,"""bert""","""amazon_counter…",0.0198,0.018692,0.0232,0.020512
50,50,"""bert""","""app_reviews""",0.1186,0.012834,0.007,0.015615
50,50,"""bert""","""blog_authorshi…",-0.006,0.007001,0.0064,0.007638
…,…,…,…,…,…,…,…
100,500,"""gpt2""","""silicone""",0.0352,0.009397,-0.0095,0.005943
100,500,"""gpt2""","""trec""",0.0214,0.009417,-0.0009,0.00403
100,500,"""gpt2""","""tweets_hate_sp…",0.0233,0.027941,-0.0056,0.019901
100,500,"""gpt2""","""yahoo_answers_…",-0.0109,0.004161,-0.0015,0.002384


In [8]:
# Datasets whose pretraining_boost (extra - base), averaged across num_train,
# num_test and lm_type, is negative; most negative first. frac_boost_lt_0 is
# the share of those configurations with a negative per-configuration boost.
with pl.Config(tbl_rows=-1):  # print every row, not a truncated preview
    print(
        accuracies_grouped.group_by("dataset")
        .agg(  # across num_train, num_test, lm_type
            pl.col("pretraining_boost_mean").mean(),
            (pl.col("pretraining_boost_mean") < 0).mean().alias("frac_boost_lt_0"),
        )
        .filter(pl.col("pretraining_boost_mean") < 0)
        # group_by output order is not deterministic; break ties on the dataset
        # name so re-runs print an identical table.
        .sort(["pretraining_boost_mean", "dataset"])
    )

shape: (2, 3)
┌────────────────────────┬────────────────────────┬─────────────────┐
│ dataset                ┆ pretraining_boost_mean ┆ frac_boost_lt_0 │
│ ---                    ┆ ---                    ┆ ---             │
│ str                    ┆ f64                    ┆ f64             │
╞════════════════════════╪════════════════════════╪═════════════════╡
│ blog_authorship_corpus ┆ -0.009588              ┆ 0.625           │
│ movie_rationales       ┆ -0.004675              ┆ 0.6875          │
└────────────────────────┴────────────────────────┴─────────────────┘


In [9]:
# Datasets whose evaluation_bias (test - extra), averaged across num_train,
# num_test and lm_type, is positive; largest first. frac_bias_gt_0 is the
# share of those configurations with a positive per-configuration bias.
with pl.Config(tbl_rows=-1):  # print every row, not a truncated preview
    print(
        accuracies_grouped.group_by("dataset")
        .agg(  # across num_train, num_test, lm_type
            pl.col("evaluation_bias_mean").mean(),
            (pl.col("evaluation_bias_mean") > 0).mean().alias("frac_bias_gt_0"),
        )
        .filter(pl.col("evaluation_bias_mean") > 0)
        # group_by output order is not deterministic and ties occur (e.g. two
        # datasets at 0.001088); break ties on the dataset name so re-runs
        # print an identical table.
        .sort(["evaluation_bias_mean", "dataset"], descending=[True, False])
    )

shape: (12, 3)
┌────────────────────────────────┬──────────────────────┬────────────────┐
│ dataset                        ┆ evaluation_bias_mean ┆ frac_bias_gt_0 │
│ ---                            ┆ ---                  ┆ ---            │
│ str                            ┆ f64                  ┆ f64            │
╞════════════════════════════════╪══════════════════════╪════════════════╡
│ amazon_counterfactual_en       ┆ 0.009306             ┆ 0.5625         │
│ financial_phrasebank           ┆ 0.007406             ┆ 0.6875         │
│ trec                           ┆ 0.002819             ┆ 0.5            │
│ app_reviews                    ┆ 0.002575             ┆ 0.6875         │
│ clickbait_notclickbait_dataset ┆ 0.0023625            ┆ 0.625          │
│ craigslist_bargains            ┆ 0.0014               ┆ 0.6875         │
│ emotion                        ┆ 0.001088             ┆ 0.6875         │
│ climate_fever                  ┆ 0.001088             ┆ 0.5625         │
│ blog_aut