In [None]:
import os
import pickle
import argparse
import numpy as np
import pandas as pd


from model_zoo import get_image_classifier, split_test
from data_zoo import get_image_dataset
from atypicality import GMMAtypicalityEstimator
%load_ext autoreload
%autoreload 2
def config():
    """Parse the run configuration for this notebook.

    Options are read from ``sys.argv`` via ``argparse``. Every multi-word
    option accepts both an underscore and a hyphen spelling
    (e.g. ``--dataset_name`` / ``--dataset-name``); the attribute name on the
    returned namespace is always the underscore form.

    Returns:
        argparse.Namespace: with attributes ``dataset_name``, ``model_name``,
        ``output_dir``, ``batch_size``, ``device`` and ``f``.
    """
    parser = argparse.ArgumentParser()
    # `dest` is given explicitly so the attribute name does not depend on
    # which spelling is listed first.
    parser.add_argument("--dataset_name", "--dataset-name", dest="dataset_name",
                        type=str, default="imagenet",
                        help="Name of the dataset to evaluate on.")
    parser.add_argument("--model_name", "--model-name", dest="model_name",
                        type=str, default="resnet50",
                        help="Name of the classifier to evaluate.")
    parser.add_argument("--output-dir", "--output_dir", dest="output_dir",
                        type=str, default="./outputs",
                        help="Directory for cached outputs and figures.")
    parser.add_argument("--batch-size", "--batch_size", dest="batch_size",
                        type=int, default=64,
                        help="Batch size used when running the model.")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device string handed to the model loader.")
    # ipykernel puts `-f <connection file>` into sys.argv; accepting the flag
    # keeps parse_args() from aborting inside a notebook. The value is unused.
    parser.add_argument("-f", help="Ignored; absorbs the Jupyter kernel connection-file flag.")
    return parser.parse_args()

# Parse the run options once; later cells read them from `args`.
args = config()
# Create the output directory up front (no error if it already exists).
os.makedirs(name=args.output_dir, exist_ok=True)

In [None]:
# Load the classifier and dataset selected in `args`. The names were
# previously hard-coded to "resnet50"/"imagenet", so --model_name and
# --dataset_name were ignored here even though the results metadata and the
# figure filename are built from them.
model, preprocess = get_image_classifier(args.model_name, device=args.device)
# The dataset is built with the preprocessing returned alongside the model.
train_dataset, test_dataset = get_image_dataset(args.dataset_name, preprocess=preprocess)


In [None]:
# Run the model over both splits, collecting features, logits and labels.
# NOTE(review): both calls receive the same output directory; presumably the
# helper keys its cache per dataset/split -- confirm in model_zoo.
train_features, train_logits, train_labels = model.run_and_cache_outputs(
    train_dataset,
    args.batch_size,
    args.output_dir,
)
test_features, test_logits, test_labels = model.run_and_cache_outputs(
    test_dataset,
    args.batch_size,
    args.output_dir,
)

In [None]:
# Split the test outputs into two parts: an evaluation set (kept under the
# `test_*` names) and a calibration set (`calib_*`). The fourth value returned
# for each part is not needed and is discarded.
# NOTE: this cell overwrites `test_*` in place, so it is not idempotent --
# re-run the previous cell before re-running this one, otherwise the already
# reduced evaluation set is split again.
(
    test_labels, test_features, test_logits, _,
    calib_labels, calib_features, calib_logits, _,
) = split_test(
    test_labels,
    test_features,
    test_logits,
    split=0.5,
    seed=1,
)

In [None]:
# --- Atypicality estimation ---
# Fit the GMM-based estimator on the training features and labels.
atypicality_estimator = GMMAtypicalityEstimator()
atypicality_estimator.fit(train_features, train_labels)

# Score the evaluation and calibration features; each result is reshaped
# into a single column (one row per sample).
test_atypicality = (
    atypicality_estimator
    .predict(test_features)
    .reshape((-1, 1))
)
calib_atypicality = (
    atypicality_estimator
    .predict(calib_features)
    .reshape((-1, 1))
)


In [None]:
from calibration import TemperatureScaler, AtypicalityAwareCalibrator
from scipy.special import softmax
from utils.calibration import compute_calibration 
from utils.plots import get_fig_records

# Extra fields forwarded to get_fig_records as keyword arguments.
metadata = dict(model=args.model_name, dataset=args.dataset_name)

# Uncalibrated baseline: softmax over the raw logits, plus top-1 accuracy
# on the evaluation split.
probs = softmax(test_logits, axis=1)
accuracy = (np.argmax(test_logits, axis=1) == test_labels).mean()
print(f"Accuracy: {accuracy}")

# Temperature scaling, fitted on the calibration split.
ts = TemperatureScaler()
ts.fit(calib_logits, calib_labels)

# Atypicality-aware recalibration (AAR): fitted on the calibration split,
# additionally conditioned on each sample's atypicality score.
aar_calib = AtypicalityAwareCalibrator()
aar_calib.fit(
    calib_logits,
    calib_atypicality,
    calib_labels,
    max_iters=1500,
)

# Evaluation-split probabilities for each method, bundled with the
# atypicality scores and labels.
prob_info = dict(
    probs={
        "Uncalibrated": probs,
        "Temp. Scaling": ts.predict_proba(test_logits),
        "Atypicality-Aware": aar_calib.predict_proba(test_logits, test_atypicality),
    },
    input_atypicality=test_atypicality,
    labels=test_labels,
)

# Build the plotting records using 5 groups.
all_records = get_fig_records(prob_info, N_groups=5, **metadata)

In [None]:
from matplotlib.ticker import FormatStrFormatter
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Tabulate the plotting records; the plotting cell reads the columns
# 'quantile', 'ECE', 'Accuracy' and 'Recalibration' from this frame.
data = pd.DataFrame(data=all_records)

In [None]:
sns.set_context("paper", font_scale=2)  # Adjust the font_scale as needed
fig, axs = plt.subplots(2, 1, figsize=(5, 4.2))

# Shorten method names for display; names without an entry are kept as-is.
# Re-running this cell is harmless because the short names are not keys.
maps = {"Atypicality-Aware": "AAR(Ours)", "Temp. Scaling": "TS"}
data["Recalibration"] = data["Recalibration"].map(lambda name: maps.get(name, name))

# Top panel: ECE against the atypicality quantile, one line per method, with a
# 95% confidence interval. Bottom panel: accuracy per method as bars.
sns.lineplot(x='quantile', y='ECE', hue='Recalibration', linewidth=2.5, errorbar=('ci', 95), data=data, ax=axs[0], legend=True)
barplot = sns.barplot(x='Recalibration', y='Accuracy', hue='Recalibration', dodge=False, errorbar=None, linewidth=2.5, data=data, ax=axs[1])

# Write each bar's value just below its top edge.
for p in barplot.patches:
    height = p.get_height()
    # With hue equal to x and dodge=False, some seaborn versions appear to
    # add placeholder bars with NaN height for the unused (x, hue) pairs.
    # Skip them so no "nan" label is placed at non-finite coordinates.
    if not np.isfinite(height):
        continue
    barplot.annotate(f"{height:.2f}", (p.get_x() + p.get_width() / 2., height), ha='center', va='baseline', color="white", xytext=(0, -20), textcoords='offset points')

# Format legends and axes: the top panel's own legend is replaced by a single
# figure-level legend, and the bar panel's tick labels and x label are blanked.
handles, labels = axs[0].get_legend_handles_labels()
axs[0].get_legend().remove()
axs[0].set_xlabel("Input Atypicality Quantile")
axs[1].set_xticklabels([])
axs[1].set_xlabel("")
for handle in handles:
    handle.set_linewidth(6)  # Thicker legend swatches for readability

fig.legend(handles=handles, labels=labels, fontsize=15, loc="upper center", bbox_to_anchor=(0.5, 0.1), fancybox=True, shadow=True, ncol=len(labels))
fig.tight_layout()
# Save the figure under the output directory, named by dataset and model.
fig.savefig(os.path.join(args.output_dir, f"{args.dataset_name}_{args.model_name}_llmfigure.pdf"))