In [59]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [60]:
import ast
import json
import pickle

import numpy as np
import pandas as pd
import statsmodels.stats.api as sms

In [61]:
# NOTE: pickle.load can execute arbitrary code while deserializing, so only
# point this at a snapshot that comes from a trusted source.
with open("mldb_2021-07-16.pickle", mode="rb") as snapshot_file:
    evals = pickle.load(snapshot_file)

In [62]:
# Tabulate the unpickled evaluation data. Later cells rely on the columns
# epoch, stats, train_set, test_set, model_id, model_family and rule_params.
df = pd.DataFrame(evals)

In [63]:
# Ignore checkpoints for now: keep only the rows whose epoch is -1
# (presumably the final, non-checkpoint evaluations -- confirm against the
# code that produced the pickle).
# Take an explicit copy so that the column assignments in the following cells
# write to an independent frame rather than to a slice of the unfiltered one,
# which is what triggers pandas' SettingWithCopyWarning.
df = df[df.epoch == -1].copy()

In [64]:
# Add accuracy metric and its confidence interval.
# Every stats dict stores a (num_correct, size) pair under 'num_correct_and_size'.
correct_and_size = df.stats.apply(lambda s: s['num_correct_and_size'])
df['accuracy'] = correct_and_size.apply(lambda pair: pair[0] / pair[1])

num_correct_vals = correct_and_size.apply(lambda pair: pair[0]).values
size_vals = correct_and_size.apply(lambda pair: pair[1]).values
# method='beta' is the Clopper-Pearson interval; alpha=0.05 gives a 95% interval.
cis = sms.proportion_confint(num_correct_vals, size_vals, alpha=0.05, method='beta')
# `cis` is a (lower_bounds, upper_bounds) pair of arrays; zip turns it into
# one (lower, upper) tuple per row.
df['accuracy_ci'] = list(zip(*cis))

In [65]:
# Add macro_f1 and worst-region accuracy metrics with their CIs
def confint(acc, n, alpha=0.05, method="beta"):
    """Confidence interval for an accuracy `acc` measured on `n` examples.

    The success count passed to statsmodels is reconstructed as ``acc * n``.
    Returns whatever ``sms.proportion_confint`` returns for that count with
    the given `alpha` and `method`.
    """
    num_successes = acc * n
    return sms.proportion_confint(num_successes, n, alpha=alpha, method=method)

def worst_region_acc_ci(ev):
    """Compute Clopper-Pearson CI for the worst-region subgroup.

    Args:
        ev: stats dict of one evaluation. When it has a 'wilds_metrics' entry
            that contains 'acc_worst_region', that entry must also provide
            'acc_region:<name>' and 'count_region:<name>' for the regions
            listed below.

    Returns:
        The 95% interval returned by ``sms.proportion_confint`` for the worst
        region, or the placeholder ``(0., 0.)`` when `ev` carries no
        worst-region accuracy.

    Raises:
        KeyError: if a per-region accuracy or count key is missing.
        AssertionError: if no region's accuracy matches 'acc_worst_region'.
    """
    if 'wilds_metrics' not in ev or 'acc_worst_region' not in ev['wilds_metrics']:
        return (0., 0.)
    wilds = ev['wilds_metrics']
    regions = ["Asia", "Europe", "Africa", "Americas", "Oceania", "Other"]
    worst_acc = wilds["acc_worst_region"]
    # Find the number of points in the worst region by matching its accuracy
    # against the per-region accuracies. If several regions tie, the first
    # one in `regions` wins.
    worst_region_size = None
    for region in regions:
        if np.isclose(worst_acc, wilds[f"acc_region:{region}"]):
            worst_region_size = wilds[f"count_region:{region}"]
            break
    assert worst_region_size is not None, "no region matches acc_worst_region"
    # Note: This confidence interval isn't exactly correct because the region
    # was first selected as the worst one on this same data.
    # Round instead of truncating: size * acc can land just below the true
    # integer count (100 * 0.29 == 28.999999999999996), and int() alone would
    # silently drop one correct prediction.
    num_correct = int(round(worst_region_size * worst_acc))
    return sms.proportion_confint(
        num_correct, worst_region_size, alpha=0.05, method="beta"
    )

def _wilds_metric(stats, key):
    """Look up `key` under stats['wilds_metrics']; None if either level is missing."""
    return stats.get('wilds_metrics', {}).get(key, None)


# Macro F1 is read from the 'wilds_metrics' entry; its CI comes from the
# 'iwc_f1_approx_ci_95' entry, with (0., 0.) standing in when that is absent.
df['macro_f1'] = df.stats.apply(lambda s: _wilds_metric(s, 'F1-macro_all'))
df['macro_f1_ci'] = df.stats.apply(lambda s: s.get('iwc_f1_approx_ci_95', (0., 0.)))  # TEMPORARY

# Worst-region accuracy and the CI computed by worst_region_acc_ci.
df['worst_region_accuracy'] = df.stats.apply(lambda s: _wilds_metric(s, 'acc_worst_region'))
df['worst_region_accuracy_ci'] = df.stats.apply(worst_region_acc_ci)



In [66]:
# In-distribution test sets available for each training set. Insertion order
# matters: it fixes the order of ID_PAIRS.
_ID_TEST_SETS_BY_TRAIN_SET = {
    "cifar10-train": ("cifar10-test", "cifar10-STL10classes"),
    "FMoW-train": ("FMoW-id_test", "FMoW-id_val"),
    "Camelyon17-train": ("Camelyon17-id_val", "Camelyon17-id_test"),
    "IWildCamOfficialV2-train": (
        "IWildCamOfficialV2-id_val",
        "IWildCamOfficialV2-id_test",
    ),
}

# Pairs of id-train, id-test
ID_PAIRS = [
    (train_set, test_set)
    for train_set, test_sets in _ID_TEST_SETS_BY_TRAIN_SET.items()
    for test_set in test_sets
]

In [67]:
def reformat(_df, train, test):
    """Pair each shifted evaluation in `_df` with the in-distribution result.

    Args:
        _df: evaluation rows (the caller passes the rows of a single model).
        train: name of the training set; rows evaluated on it are excluded.
        test: name of the in-distribution test set.

    Returns:
        One row per evaluation whose test_set is neither `train` nor `test`,
        carrying the model columns, the train/test/shift set names, and for
        each metric the test_* value (copied from the first row evaluated on
        `test`) next to the shift_* value of that row. An empty DataFrame is
        returned when `_df` has no row for `test`.
    """
    id_rows = _df.loc[_df["test_set"] == test]
    if id_rows.shape[0] == 0:
        return pd.DataFrame()
    # Only the first matching in-distribution row is used.
    id_row = id_rows.iloc[0]

    is_shift = ~_df["test_set"].isin([train, test])
    shift_rows = _df.loc[is_shift]

    out = shift_rows[["model_family", "model_id", "epoch", "rule_params"]].rename(
        columns={"rule_params": "hyperparameters"}
    )
    out["train_set"] = train
    out["test_set"] = test
    out["shift_set"] = shift_rows["test_set"]

    n_rows = len(out)
    for metric in ("accuracy", "macro_f1", "worst_region_accuracy"):
        ci_column = f"{metric}_ci"
        out[f"test_{metric}"] = id_row[metric]
        # The CI is a tuple, so repeat it explicitly once per output row.
        out[f"test_{metric}_ci"] = [id_row[ci_column]] * n_rows
        out[f"shift_{metric}"] = shift_rows[metric]
        out[f"shift_{metric}_ci"] = shift_rows[ci_column]
    return out

In [68]:
# For every (train set, in-distribution test set) pair, build one reformatted
# frame per model trained on that train set. `new_df` is a plain list of
# frames at this point; a later cell concatenates it.
new_df = []
for train, test in ID_PAIRS:
    df_train = df[df.train_set == train]
    for _, modeldf in df_train.groupby("model_id"):
        new_df.append(reformat(modeldf, train, test))

In [69]:
# Stack the per-model frames into a single DataFrame; pd.concat keeps each
# row's original index label by default.
# NOTE: this rebinds `new_df` from a list of frames to a DataFrame, so the
# cell is only meant to run once after the list has been rebuilt.
new_df = pd.concat(new_df)

In [70]:
def rename_model_family(model_family):
    """Map a model-family identifier to its short display name.

    Identifiers containing "RandFeatures" become "RandomFeatures" and those
    containing "K_nearest_neighbors" become "KNN" (checked in that order);
    anything else is returned unchanged.
    """
    short_names = (
        ("RandFeatures", "RandomFeatures"),
        ("K_nearest_neighbors", "KNN"),
    )
    for marker, short_name in short_names:
        if marker in model_family:
            return short_name
    return model_family

# Shorten the model-family identifiers in place using rename_model_family.
new_df["model_family"] = new_df["model_family"].apply(rename_model_family)

In [71]:
def rename_test_set(ts):
    """Return the published name of a test set.

    Only "cifar10-STL10classes" is renamed (to "cifar10-test-STL10classes");
    every other value is passed through untouched.
    """
    return "cifar10-test-STL10classes" if ts == "cifar10-STL10classes" else ts
# Apply the test-set renaming to the test_set column of the results table.
new_df["test_set"] = new_df["test_set"].apply(rename_test_set)

In [72]:
# Preview the first rows of the reformatted table as a quick sanity check.
new_df.head()

Unnamed: 0,model_family,model_id,epoch,hyperparameters,train_set,test_set,shift_set,test_accuracy,test_accuracy_ci,shift_accuracy,shift_accuracy_ci,test_macro_f1,test_macro_f1_ci,shift_macro_f1,shift_macro_f1_ci,test_worst_region_accuracy,test_worst_region_accuracy_ci,shift_worst_region_accuracy,shift_worst_region_accuracy_ci
14976,RandomFeatures,00082624-3fbe-46de-8c7c-b45f38a88e87,-1,"{'num_filters': 64, 'patch_size': 6, 'pool_siz...",cifar10-train,cifar10-test,cifar10c_impulse_noise_2,0.6551,"(0.6456894668728717, 0.6644201270832805)",0.4854,"(0.47555981191619107, 0.4952486953322158)",,"(0.0, 0.0)",,"(0.0, 0.0)",,"(0.0, 0.0)",,"(0.0, 0.0)"
14977,RandomFeatures,00082624-3fbe-46de-8c7c-b45f38a88e87,-1,"{'num_filters': 64, 'patch_size': 6, 'pool_siz...",cifar10-train,cifar10-test,cifar10c_fog_2,0.6551,"(0.6456894668728717, 0.6644201270832805)",0.6334,"(0.6238676403350546, 0.6428546097260688)",,"(0.0, 0.0)",,"(0.0, 0.0)",,"(0.0, 0.0)",,"(0.0, 0.0)"
14978,RandomFeatures,00082624-3fbe-46de-8c7c-b45f38a88e87,-1,"{'num_filters': 64, 'patch_size': 6, 'pool_siz...",cifar10-train,cifar10-test,cifar10c_frost_5,0.6551,"(0.6456894668728717, 0.6644201270832805)",0.5104,"(0.5005504702099839, 0.5202434698407776)",,"(0.0, 0.0)",,"(0.0, 0.0)",,"(0.0, 0.0)",,"(0.0, 0.0)"
14979,RandomFeatures,00082624-3fbe-46de-8c7c-b45f38a88e87,-1,"{'num_filters': 64, 'patch_size': 6, 'pool_siz...",cifar10-train,cifar10-test,cifar10c_snow_all,0.6551,"(0.6456894668728717, 0.6644201270832805)",0.5562,"(0.5518320019958354, 0.560561471986289)",,"(0.0, 0.0)",,"(0.0, 0.0)",,"(0.0, 0.0)",,"(0.0, 0.0)"
14980,RandomFeatures,00082624-3fbe-46de-8c7c-b45f38a88e87,-1,"{'num_filters': 64, 'patch_size': 6, 'pool_siz...",cifar10-train,cifar10-test,cifar10c_saturate_2,0.6551,"(0.6456894668728717, 0.6644201270832805)",0.5694,"(0.5596259942066042, 0.5791335647014177)",,"(0.0, 0.0)",,"(0.0, 0.0)",,"(0.0, 0.0)",,"(0.0, 0.0)"


## Add YCB objects results


In [102]:
def _load_ycb_evals(path):
    """Load a YCB evaluation JSON file and normalize its records.

    Args:
        path: path of a JSON file holding an iterable of record dicts, each
            with an "hparams" dict plus "test_ci" and "shift_ci" sequences.

    Returns:
        The decoded records, with "train:outdir" removed from every record's
        hyperparameters and both CI entries converted to tuples.

    Raises:
        KeyError: if a record lacks one of the keys above (including
            "train:outdir" inside "hparams").
    """
    with open(path) as handle:
        records = json.load(handle)
    for record in records:
        # Drop the training output directory from the hyperparameters.
        del record["hparams"]["train:outdir"]
        # JSON has no tuple type, so the CIs arrive as lists; convert them to
        # tuples like the CI values used for the other result tables.
        record["test_ci"] = tuple(record["test_ci"])
        record["shift_ci"] = tuple(record["shift_ci"])
    return records


ycb50 = _load_ycb_evals("ycb_50ktrain_evals.json")
ycb100 = _load_ycb_evals("ycb_100ktrain_evals.json")

In [103]:
def reformat_ycb(ycb_data, trainset):
    """Convert YCB evaluation records to the column layout of the results table.

    Args:
        ycb_data: records accepted by ``pd.DataFrame``; the columns model,
            hparams, test_score, shift_score, test_ci and shift_ci are renamed.
        trainset: value stored in the train_set column of every row.

    Returns:
        A DataFrame with the renamed columns, constant train/test/shift set
        names, and (0., 0.) placeholders in the macro-F1 and worst-region CI
        columns.
    """
    column_names = {
        "model": "model_family",
        "hparams": "hyperparameters",
        "test_score": "test_accuracy",
        "shift_score": "shift_accuracy",
        "test_ci": "test_accuracy_ci",
        "shift_ci": "shift_accuracy_ci",
    }
    frame = pd.DataFrame(ycb_data).rename(columns=column_names)

    frame["train_set"] = trainset
    frame["test_set"] = "YCB ID Test"
    frame["shift_set"] = "YCB OOD Test"

    # These CI columns get a (0., 0.) placeholder in every row.
    placeholder_ci = [(0., 0.)] * len(frame)
    for metric in ("macro_f1", "worst_region_accuracy"):
        for split in ("test", "shift"):
            frame[f"{split}_{metric}_ci"] = placeholder_ci
    return frame

# Convert both YCB runs to the results-table layout, tagging each with its
# training-set name. NOTE: this rebinds ycb50 / ycb100 from the loaded JSON
# data to DataFrames.
ycb50 = reformat_ycb(ycb50, "YCB Train 50k examples")
ycb100 = reformat_ycb(ycb100, "YCB Train 100k examples")

In [104]:
# Stack the reformatted benchmark results and both YCB tables row-wise.
# pd.concat takes the union of the columns by default, so a column that is
# absent from one of the frames is filled with NaN for that frame's rows.
combined_df = pd.concat([new_df, ycb50, ycb100])

In [105]:
# Persist the combined table; index=False leaves the row labels out of the
# file. Dict- and tuple-valued cells are written as their string form.
combined_df.to_csv("results.csv", index=False)

In [106]:
# Reload the CSV that was just written. The columns below hold Python
# literals (dicts / tuples) that were serialized as strings, so each one is
# parsed back with ast.literal_eval (requires `import ast` in the imports cell).
# NOTE: this rebinds `df`, replacing the evaluation frame built earlier.
LITERAL_COLUMNS = [
    "hyperparameters",
    "test_accuracy_ci",
    "shift_accuracy_ci",
    "test_macro_f1_ci",
    "shift_macro_f1_ci",
    "test_worst_region_accuracy_ci",
    "shift_worst_region_accuracy_ci",
]
df = pd.read_csv(
    "results.csv",
    converters={column: ast.literal_eval for column in LITERAL_COLUMNS},
)
