In [None]:
import os
import pickle
import argparse
import numpy as np
import pandas as pd

from data_zoo import get_llm_classification_data
from model_zoo import get_llm
from utils.llm_utils import get_experiment_records, get_recalib_predata

def config():
    """Parse the options that drive the experiment.

    Options:
        --dataset_name: dataset identifier (default ``"imdb"``).
        --model_name: model identifier (default ``"alpaca7b"``).
        --model-path: path to the model weights. Falls back to the
            ``MODEL_PATH`` environment variable; it is only a hard
            requirement when that variable is unset or empty.
        --output-dir: directory for cached results and figures
            (default ``"./outputs"``).
        --batch-size: batch size used when computing records (default 16).
        -f: accepted and otherwise ignored, so that the connection-file
            argument a Jupyter kernel is launched with does not abort parsing.

    Returns:
        argparse.Namespace with ``dataset_name``, ``model_name``,
        ``model_path``, ``output_dir``, ``batch_size`` and ``f``.

    Raises:
        SystemExit: raised by argparse when no model path is available from
            either the command line or ``MODEL_PATH``.
    """
    # Inside a notebook the kernel's argv only carries `-f <connection file>`,
    # so a strictly required flag could never be satisfied there. Allow the
    # path to come from the environment instead; an explicit flag still wins.
    env_model_path = os.environ.get("MODEL_PATH") or None

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", type=str, default="imdb")
    parser.add_argument("--model_name", type=str, default="alpaca7b")
    parser.add_argument(
        "--model-path",
        type=str,
        default=env_model_path,
        required=env_model_path is None,
        help="Path to the Alpaca7b model (defaults to $MODEL_PATH when set).",
    )
    parser.add_argument("--output-dir", type=str, default="./outputs")
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("-f")
    return parser.parse_args()

# Parse the options once; the cells below read their settings from `args`.
args = config()

In [None]:
# Build the LLM first: its tokenizer is handed to the data loader as
# `filter_tokenizer`.
model = get_llm(args.model_name, model_path=args.model_path)
(
    train_sentences,
    train_labels,
    test_sentences,
    test_labels,
    ds_params,
) = get_llm_classification_data(
    args.dataset_name,
    filter_tokenizer=model.tokenizer,
)
label_choices = list(ds_params["inv_label_dict"].keys())

# Position 0 of each encoding is the start token (this holds for Llama family
# models), so the token at index 1 is taken for every label choice.
label_tokens = [
    model.tokenizer.encode(choice)[1]
    for choice in label_choices
]
ds_params["label_tokens"] = label_tokens

In [None]:
# Every cached artefact is written under the output directory, so make sure it
# exists before the first write (the default "./outputs" may not be there yet).
# An empty output dir means "current directory" and needs no creation.
if args.output_dir:
    os.makedirs(args.output_dir, exist_ok=True)

calib_result_path = os.path.join(args.output_dir, f"calib_llm_{args.model_name}_{args.dataset_name}.pkl")
test_result_path = os.path.join(args.output_dir, f"test_llm_{args.model_name}_{args.dataset_name}.pkl")
recalib_info_path = os.path.join(args.output_dir, f"recalib_llm_{args.model_name}_{args.dataset_name}.pkl")

# get_experiment_records returns a set of stats for the experiment, including the data for completions, predicted probabilities, and atypicality values

# If not already computed, compute the records.
# NOTE(review): the cache file names depend only on the model and dataset
# names; delete the files to force a recompute after changing anything else.
if not os.path.exists(test_result_path):
    experiment_records = get_experiment_records(model, test_sentences, test_labels, ds_params, batch_size=args.batch_size)
    with open(test_result_path, "wb") as f:
        pickle.dump(experiment_records, f)

if (not os.path.exists(calib_result_path)) and (args.dataset_name != "imdb"):
    experiment_records = get_experiment_records(model, train_sentences, train_labels, ds_params, batch_size=args.batch_size)
    with open(calib_result_path, "wb") as f:
        pickle.dump(experiment_records, f)

if not os.path.exists(recalib_info_path):
    recalib_info = get_recalib_predata(model, ds_params)
    # Use a context manager so the handle is closed (and the data flushed)
    # deterministically before the file is read back below.
    with open(recalib_info_path, "wb") as f:
        pickle.dump(recalib_info, f)

# Load the stats.
# SECURITY: pickle.load can execute arbitrary code from the file. These files
# are caches written by this notebook; do not point the output directory at
# files from an untrusted source.
# If IMDB, we do not use the training set due to the contamination (Llama is very likely trained on the IMDB train set).
if args.dataset_name != "imdb":
    with open(calib_result_path, "rb") as f:
        calib_records = pickle.load(f)
with open(test_result_path, "rb") as f:
    test_records = pickle.load(f)
with open(recalib_info_path, "rb") as f:
    recalib_predata = pickle.load(f)

In [None]:
# Collect the arrays needed downstream: logits, labels and atypicality.
# Atypicality is the negated "atypicality_total_logprob" entry with an extra
# axis inserted at position 1.
test_logits = test_records["logits"]
test_labels = test_records["labels"]
test_atypicality = -np.expand_dims(
    test_records["atypicality_total_logprob"],
    axis=1,
)

if args.dataset_name != "imdb":
    # A separate calibration set was computed from the training split.
    calib_logits = calib_records["logits"]
    calib_labels = calib_records["labels"]
    calib_atypicality = -np.expand_dims(
        calib_records["atypicality_total_logprob"],
        axis=1,
    )
else:
    # No calibration records exist for IMDB, so the test records are split in
    # half (stratified on the labels, fixed seed): one half stays the test
    # set, the other becomes the calibration set.
    from sklearn.model_selection import train_test_split

    (
        test_logits,
        calib_logits,
        test_labels,
        calib_labels,
        test_atypicality,
        calib_atypicality,
    ) = train_test_split(
        test_logits,
        test_labels,
        test_atypicality,
        test_size=0.5,
        stratify=test_labels,
        random_state=1,
    )

In [None]:
from calibration import TemperatureScaler, AtypicalityAwareCalibrator
from scipy.special import softmax
from utils.calibration import compute_calibration 
from utils.plots import get_fig_records


# Keyword arguments forwarded to get_fig_records alongside the probabilities.
metadata = dict(model=args.model_name, dataset=args.dataset_name)

# Uncalibrated baseline: softmax over the raw logits, plus top-1 accuracy.
probs = softmax(test_logits, axis=1)
predicted_labels = np.argmax(test_logits, axis=1)
accuracy = (predicted_labels == test_labels).mean()
print(f"Accuracy: {accuracy}")

# Fit both recalibrators on the calibration split.
ts = TemperatureScaler()
ts.fit(calib_logits, calib_labels)

aar_calib = AtypicalityAwareCalibrator()
aar_calib.fit(calib_logits, calib_atypicality, calib_labels, max_iters=1500)

# Test-set probabilities for each method, keyed by the method's display name.
calibrated_probs = {"Uncalibrated": probs}
calibrated_probs["Temp. Scaling"] = ts.predict_proba(test_logits)
calibrated_probs["Atypicality-Aware"] = aar_calib.predict_proba(test_logits, test_atypicality)

prob_info = dict(
    probs=calibrated_probs,
    input_atypicality=test_atypicality,
    labels=test_labels,
)

all_records = get_fig_records(prob_info, N_groups=5, **metadata)

In [None]:
from matplotlib.ticker import FormatStrFormatter
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Put the records returned by get_fig_records into a DataFrame; the plotting
# cell below passes it to seaborn and reads the "quantile", "ECE", "Accuracy"
# and "Recalibration" columns from it.
data = pd.DataFrame(all_records)

In [None]:
# Two stacked panels: ECE per atypicality quantile on top, accuracy below.
sns.set_context("paper", font_scale=2)  # Adjust the font_scale as needed
fig, axs = plt.subplots(2, 1, figsize=(5, 4.2))

# Shorten the method names for display; names without an entry are kept as-is.
maps = {"Atypicality-Aware": "AAR(Ours)", "Temp. Scaling": "TS"}
data["Recalibration"] = data["Recalibration"].apply(lambda name: maps.get(name, name))

# Top panel: ECE vs quantile, one line per recalibration method.
sns.lineplot(
    data=data,
    x='quantile',
    y='ECE',
    hue='Recalibration',
    linewidth=2.5,
    errorbar=('ci', 95),
    legend=True,
    ax=axs[0],
)
# Bottom panel: accuracy per recalibration method.
barplot = sns.barplot(
    data=data,
    x='Recalibration',
    y='Accuracy',
    hue='Recalibration',
    dodge=False,
    errorbar=None,
    linewidth=2.5,
    ax=axs[1],
)

# Label every bar with its height, in white, offset 20 points below the top.
for p in barplot.patches:
    bar_top = p.get_height()
    bar_center = p.get_x() + p.get_width() / 2.
    barplot.annotate(
        f"{bar_top:.2f}",
        (bar_center, bar_top),
        ha='center',
        va='baseline',
        color="white",
        xytext=(0, -20),
        textcoords='offset points',
    )

# Take the legend entries from the top panel before removing its legend, so
# they can be reused for a single figure-level legend.
handles, labels = axs[0].get_legend_handles_labels()
axs[0].get_legend().remove()
axs[0].set_xlabel("Input Atypicality Quantile")
axs[1].set_xticklabels([])
axs[1].set_xlabel("")
for handle in handles:
    handle.set_linewidth(6)  # Thicker lines for the legend entries

fig.legend(
    handles=handles,
    labels=labels,
    fontsize=15,
    loc="upper center",
    bbox_to_anchor=(0.5, 0.1),
    fancybox=True,
    shadow=True,
    ncol=len(labels),
)
fig.tight_layout()
fig.savefig(os.path.join(args.output_dir, f"{args.dataset_name}_{args.model_name}_llmfigure.pdf"))
