# Experiments in Calibrating Zero-Shot Image Classifiers

This notebook contains code for experimenting with different calibration methods for binary classifiers, applied to the use case of zero-shot image classification.

## Install Dependencies

In [None]:
# System packages: dvipng plus TeX Live extras/fonts and cm-super (presumably
# needed for the LaTeX-rendered plot labels enabled further down — confirm).
! sudo apt install dvipng texlive-latex-extra texlive-fonts-recommended texlive-fonts-extra cm-super
# Python packages: Hugging Face datasets, OpenCLIP and relplot, upgraded to the latest release.
! pip install --upgrade datasets open_clip_torch relplot

## Import Libraries

In [None]:
import pathlib

import datasets
import matplotlib.pyplot as plt
import numpy as np
import open_clip
import pandas as pd
import PIL.Image
import PIL.PngImagePlugin
import relplot
import scipy.special
import sklearn.calibration
import sklearn.isotonic
import sklearn.metrics
import sklearn.model_selection
import torch
import tqdm.auto as tqdm

# Raise Pillow's cap on the size of PNG text chunks to 100 MiB; this fixes
# image loading errors for some large files.
PIL.PngImagePlugin.MAX_TEXT_CHUNK = 100 * (1024**2)

# Have relplot make pretty LaTeX axis labels (relies on a working TeX
# installation being available to matplotlib).
relplot.config.use_tex_fonts = True

## Fetch Datasets

In [None]:
# Google Drive download URLs for the dataset archives, keyed by dataset name.
# The key doubles as the name of the directory each archive is unzipped into.
DATASET_URLS = {
    # Aerial Image Dataset: https://captain-whu.github.io/AID/
    "aid": "https://drive.google.com/uc?id=1CTdCFoo88_ygMb2PNGnK3QTi8xk_naoA",
    # MIT Places 365: http://places2.csail.mit.edu/
    "places365": "https://drive.google.com/uc?id=1w-0LncVMfBsdtqX7jT-jCTdAnLBZuFtU",
}

In [None]:
for dataset_name, dataset_gdrive_url in DATASET_URLS.items():
    ! gdown {dataset_gdrive_url}
    ! mkdir {dataset_name}
    ! unzip /content/{dataset_name}.zip -d {dataset_name}
    ! rm /content/{dataset_name}.zip

## Implement Zero-Shot Classifier and Calibrators

In [None]:
class ZeroShotClassifier:
    """Wrapper for loading and running zero-shot image classifiers.

    The classifier scores images against a single text prompt: the prompt is
    expanded with every entry of ``TEMPLATES``, the resulting text embeddings
    are averaged into one unit vector, and an image's score is the dot product
    between its normalized image embedding and that vector.
    """

    # Prompt templates; "{}" is replaced by the class text in `set_text`.
    TEMPLATES = [
        "itap of a {}.",
        "a bad photo of the {}.",
        "a origami {}.",
        "a photo of the large {}.",
        "a {} in a video game.",
        "art of the {}.",
        "a photo of the small {}.",
    ]

    def __init__(self, model_name, pretrained_source, device):
        """Create a new instance of the zero-shot classifier.

        model_name: OpenCLIP model architecture name
        pretrained_source: name of the pretrained weights to load
        device: torch device the model and its inputs are moved to
        """
        self.model_name = model_name
        self.pretrained_source = pretrained_source
        self.device = device
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            self.model_name,
            pretrained=self.pretrained_source,
        )
        self.model.eval().to(device)
        self.tokenizer = open_clip.get_tokenizer(self.model_name)
        # both are populated by `set_text`
        self.text = None
        self.text_features = None

    def set_text(self, text):
        """Set the text prompt for the classifier.

        Encodes one prompt per template, normalizes each embedding, averages
        them, re-normalizes the mean and stores it as a NumPy vector in
        ``self.text_features``.
        """
        self.text = text
        tokens = self.tokenizer(
            [t.format(self.text) for t in ZeroShotClassifier.TEMPLATES]
        )
        with torch.no_grad(), torch.amp.autocast("cuda"):
            text_features = self.model.encode_text(tokens.to(self.device))
            text_features /= text_features.norm(dim=-1, keepdim=True)
            text_features = text_features.mean(dim=0)
            # the mean of unit vectors is generally not unit length
            text_features /= text_features.norm()
        self.text_features = text_features.to("cpu").numpy()

    def get_image_features(self, image: "PIL.Image.Image"):
        """Get CLIP features for a single image.

        Returns the normalized embedding as a NumPy array with a leading batch
        axis of size one (the image is unsqueezed into a batch before encoding).
        """
        input_features = self.preprocess(image).unsqueeze(0)
        with torch.no_grad(), torch.amp.autocast("cuda"):
            image_features = self.model.encode_image(input_features.to(self.device))
            image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features.to("cpu").numpy()

    def get_image_features_batch(self, images):
        """Get CLIP features for a batch of images.

        Returns the normalized embeddings as a NumPy array, one row per image.
        """
        input_features = torch.stack([self.preprocess(im) for im in images])
        with torch.no_grad(), torch.amp.autocast("cuda"):
            image_features = self.model.encode_image(input_features.to(self.device))
            image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features.cpu().numpy()

    def score_image(self, image: "PIL.Image.Image", with_features: bool = False):
        """Score an image based on the text prompt.

        Returns a dict with the ``np.float32`` score under ``"score"`` and,
        when ``with_features`` is true, the image embedding under ``"features"``.
        Requires `set_text` to have been called first.
        """
        image_features = self.get_image_features(image)
        # `.item()` returns a plain Python float, which has no `.astype`;
        # convert through the np.float32 constructor instead.
        score = np.float32((image_features @ self.text_features).item())
        results = {"score": score}
        if with_features:
            results["features"] = image_features
        return results

    def score_image_batch(self, images, with_features=False):
        """Score a batch of images based on the text prompt.

        Returns a dict with a float32 array of scores under ``"score"`` and,
        when ``with_features`` is true, the image embeddings under ``"features"``.
        Requires `set_text` to have been called first.
        """
        image_features = self.get_image_features_batch(images)
        scores = image_features @ self.text_features
        results = {"score": scores.astype(np.float32)}
        if with_features:
            results["features"] = image_features
        return results

    def score_features(self, image_features: np.ndarray):
        """Score image features based on the text prompt.

        ``image_features`` must reduce to a single score (one embedding);
        the result is returned as an ``np.float32`` scalar.
        """
        # `.item()` returns a plain Python float, which has no `.astype`;
        # convert through the np.float32 constructor instead.
        return np.float32((image_features @ self.text_features).item())

    def score_features_batch(self, image_features: np.ndarray):
        """Score a batch of image features based on the text prompt.

        Returns a float32 array with one score per row of ``image_features``.
        """
        scores = image_features @ self.text_features
        return scores.astype(np.float32)


class BaseCalibrator:
    """Interface shared by the score calibration methods.

    Subclasses override both methods; the versions here only signal that no
    implementation has been provided.
    """

    def fit(self, scores: np.ndarray, labels: np.ndarray):
        """Learn the calibration mapping from scores and their labels."""
        raise NotImplementedError()

    def predict_proba(self, scores: np.ndarray) -> np.ndarray:
        """Map raw scores to calibrated probabilities."""
        raise NotImplementedError()


class IsotonicCalibrator(BaseCalibrator):
    """Score calibrator backed by scikit-learn's isotonic regression."""

    def __init__(self) -> None:
        """Set up an unfitted isotonic regressor.

        The regressor is created with ``out_of_bounds="clip"``.
        """
        regressor = sklearn.isotonic.IsotonicRegression(out_of_bounds="clip")
        self.calibrator = regressor

    def fit(self, scores: np.ndarray, labels: np.ndarray) -> "IsotonicCalibrator":
        """Fit the isotonic regressor on scores and labels; returns ``self``."""
        regressor = self.calibrator
        regressor.fit(scores, labels)
        return self

    def predict_proba(self, scores: np.ndarray) -> np.ndarray:
        """Return the fitted regressor's predictions for the given scores."""
        calibrated = self.calibrator.predict(scores)
        return calibrated

class SigmoidCalibrator(BaseCalibrator):
    """Score calibrator using sigmoid (Platt) scaling."""

    def __init__(self) -> None:
        """Set up an unfitted sigmoid calibration model."""
        # brittle: `_SigmoidCalibration` is a private scikit-learn class and
        # may change or disappear between releases.
        sigmoid_model = sklearn.calibration._SigmoidCalibration()
        self.calibrator = sigmoid_model

    def fit(self, scores: np.ndarray, labels: np.ndarray) -> "SigmoidCalibrator":
        """Fit the sigmoid model on scores and labels; returns ``self``."""
        sigmoid_model = self.calibrator
        sigmoid_model.fit(scores, labels)
        return self

    def predict_proba(self, scores: np.ndarray) -> np.ndarray:
        """Return the fitted sigmoid model's predictions for the given scores."""
        calibrated = self.calibrator.predict(scores)
        return calibrated

class SimilarityBinningAveragingCalibrator(BaseCalibrator):
    """Calibrate scores using Similarity-Binning Averaging calibration.

    Each query is calibrated by averaging the inner calibrator's probabilities
    of its ``k`` most similar training samples, where similarity is a weighted
    mix of feature similarity and closeness of the calibrated probabilities.

    Unlike the other calibrators, `fit` and `predict_proba` take an additional
    ``features`` argument.
    """

    def __init__(
        self,
        k: int = 10,
        alpha: float = 0.95,
        inner_calibration_class: "type[BaseCalibrator]" = SigmoidCalibrator,
    ) -> None:
        """Create a new instance of the SBA calibrator.

        k: number of nearest neighbors in the bin to average over
        alpha: weighting factor between CLIP features and class probs for similarity
        inner_calibration_class: calibrator class (not an instance) used for
            getting calibrated scores from calibration data; it is instantiated
            with no arguments
        """
        self.k = k
        self.alpha = alpha
        self.inner_calibrator = inner_calibration_class()
        # populated by `fit`
        self._train_features = None
        self._train_probs = None

    def fit(
        self, scores: np.ndarray, labels: np.ndarray, features: np.ndarray
    ) -> "SimilarityBinningAveragingCalibrator":
        """Fit the SBA calibrator.

        Fits the inner calibrator on ``scores``/``labels`` and stores the
        training features together with their calibrated probabilities, which
        act as the neighbor pool at prediction time. Returns ``self``.
        """
        self.inner_calibrator.fit(scores, labels)
        self._train_features = features
        self._train_probs = self.inner_calibrator.predict_proba(scores)
        return self

    def predict_proba(self, scores: np.ndarray, features: np.ndarray) -> np.ndarray:
        """Apply SBA calibration to scores.

        scores: raw scores of the query samples
        features: feature vectors of the query samples, one row per score

        Returns one probability per query: the mean calibrated probability of
        its ``k`` most similar training samples.

        Raises RuntimeError if called before `fit`.

        Note: the feature similarities are min-max scaled over the whole
        similarity matrix of this call, so the result for a sample can depend
        on which other samples are passed alongside it.
        """
        if self._train_features is None or self._train_probs is None:
            raise RuntimeError(
                "SimilarityBinningAveragingCalibrator must be fit before predict_proba"
            )

        feature_similarities = features @ self._train_features.T
        # Min-max scale into [0, 1] so the feature term is on the same scale as
        # the probability term below. The divisor must be the range
        # (max - min); dividing by the max alone does not map onto [0, 1].
        fs_min = feature_similarities.min()
        fs_range = feature_similarities.max() - fs_min
        if fs_range > 0:
            feature_similarities = (feature_similarities - fs_min) / fs_range
        else:
            # all similarities are identical: the feature term is uninformative
            feature_similarities = np.zeros_like(feature_similarities)

        probs = self.inner_calibrator.predict_proba(scores)
        # pairwise closeness between query and training probabilities
        probs_similarities = 1 - np.abs(probs[:, np.newaxis] - self._train_probs)
        similarities = (
            self.alpha * feature_similarities + (1 - self.alpha) * probs_similarities
        )
        # indices of the k most similar training samples for every query
        # (slicing yields fewer than k when the training pool is smaller)
        top_k_indices = np.argsort(-similarities, axis=1)[:, : self.k]
        top_k_probs = self._train_probs[top_k_indices]
        return top_k_probs.mean(axis=1)


def build_features_dataset(zero_shot_classifier, image_dataset, batch_size=32):
    """Return ``image_dataset`` extended with a "features" column.

    The "image" column is mapped in batches of ``batch_size`` through the
    classifier's ``get_image_features_batch``.
    """

    def extract_features(image_batch):
        embeddings = zero_shot_classifier.get_image_features_batch(image_batch)
        return {"features": embeddings}

    return image_dataset.map(
        extract_features,
        input_columns=["image"],
        batched=True,
        batch_size=batch_size,
    )


def build_scores_dataset(zero_shot_classifier, features_dataset, batch_size=1024):
    """Return ``features_dataset`` extended with a "score" column.

    The "features" column is mapped in batches of ``batch_size`` through the
    classifier's ``score_features_batch``.
    """
    return features_dataset.map(
        lambda feature_batch: {
            "score": zero_shot_classifier.score_features_batch(feature_batch),
        },
        input_columns=["features"],
        batched=True,
        batch_size=batch_size,
    )


def places365_clean_label(label):
    """Turn a Places365 label into natural text for zero-shot prompting.

    Whatever follows the last "-" is moved in front of the rest (so
    "noun-adj" becomes "adj noun"), then underscores become spaces.
    """
    noun, separator, adjective = label.rpartition("-")
    reordered = f"{adjective} {noun}" if separator else label
    return " ".join(reordered.split("_"))


def create_binary_dataset(
    multiclass_dataset: datasets.Dataset,
    target_class_id: int,
    non_target_ratio: float = 2.0,
    seed: "int | None" = None,
):
    """Create a binary dataset from a multiclass dataset.

    The binary dataset will have a target class and a non-target class, with the
    non-target class sampled at a specified ratio to the target class.

    multiclass_dataset: dataset with an integer "label" column
    target_class_id: label id that becomes the positive ("target") class
    non_target_ratio: expected number of non-target rows kept per target row;
        the sampling is random, so the realized ratio only matches on average
    seed: optional seed that makes the non-target subsampling reproducible.
        When omitted, NumPy's global random state is used, as before.

    Returns the filtered dataset with "label" recoded to 0 (non-target) /
    1 (target) and cast to a ``ClassLabel`` with those two names.
    """
    all_labels = np.asarray(multiclass_dataset["label"])
    num_target_samples = int((all_labels == target_class_id).sum())
    num_non_target_samples = len(all_labels) - num_target_samples
    if num_non_target_samples > 0:
        non_target_prob = non_target_ratio * num_target_samples / num_non_target_samples
    else:
        # nothing to subsample; avoids a division by zero
        non_target_prob = 0.0

    rng = None if seed is None else np.random.default_rng(seed)

    def filter_dataset(labels_batch):
        batch = np.asarray(labels_batch)
        if rng is None:
            draws = np.random.random(len(batch))
        else:
            draws = rng.random(len(batch))
        # keep every target row, and each non-target row with non_target_prob
        keep = (batch == target_class_id) | (draws < non_target_prob)
        return keep.tolist()

    binary_dataset = multiclass_dataset.filter(
        filter_dataset,
        input_columns=["label"],
        batched=True,
        batch_size=1024,
    )
    # recode the multiclass ids as 1 for the target class and 0 otherwise
    binary_dataset = binary_dataset.map(
        lambda labels: dict(
            label=np.array([int(label == target_class_id) for label in labels]),
        ),
        input_columns=["label"],
        batched=True,
        batch_size=1024,
    )
    binary_dataset = binary_dataset.cast_column(
        "label", datasets.ClassLabel(names=["non-target", "target"])
    )
    return binary_dataset

## Create Zero-Shot Classifier and Extract Image Features

In [None]:
# Make sure the results directory exists, then delete everything inside it
# (this clears the outputs of any previous run).
! mkdir -p /content/results
! rm -rf /content/results/*

In [None]:
print("Creating zero-shot classifier")
# run on the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# model under test; the commented-out lines are alternative configurations
model_name, provider = "ViT-B-32-quickgelu", "openai"
#model_name, provider = "ViT-B-16-SigLIP", "webli"
#model_name, provider = "ViT-B-32", "datacomp_m_s128m_b4k"
#model_name, provider = "RN50-quickgelu", "yfcc15m"
zsc = ZeroShotClassifier(model_name, provider, device)


def _load_dataset_with_features(dataset_path):
    """Load a saved dataset into memory, add CLIP features, use NumPy format."""
    dataset = datasets.load_from_disk(dataset_path, keep_in_memory=True)
    dataset = build_features_dataset(zsc, dataset)
    dataset.set_format(type="numpy")
    return dataset


print("Extracting features for AID dataset")
aid_dataset = _load_dataset_with_features("/content/aid")

print("Extracting features for Places365 dataset (takes several minutes)")
places_dataset = _load_dataset_with_features("/content/places365")

## Run Zero-Shot Classification and Calibration and Collect Metrics

In [None]:
results_dir = pathlib.Path("/content/results")
all_results = []  # one metrics dict per (target class, calibration method)
calibrators = {}  # fitted calibrators, keyed by slug
show = True  # display every figure after saving it; set to False to only save


def _compute_metrics(labels, preds, probs, prefix=""):
    """Return classifier and calibration metrics, with keys prefixed by `prefix`."""
    return {
        # classifier metrics
        f"{prefix}accuracy": sklearn.metrics.accuracy_score(labels, preds),
        f"{prefix}precision": sklearn.metrics.precision_score(labels, preds),
        f"{prefix}recall": sklearn.metrics.recall_score(labels, preds),
        f"{prefix}f1": sklearn.metrics.f1_score(labels, preds),
        # calibration metrics
        f"{prefix}brier_score": sklearn.metrics.brier_score_loss(labels, probs),
        f"{prefix}sm_ece": relplot.smECE(probs, labels),
        f"{prefix}binned_ece": relplot.metrics.binnedECE(probs, labels, nbins=15),
    }


def _save_diagnostic_plots(labels, preds, probs, stem, show=True):
    """Save (and optionally show) the confusion matrix and reliability diagrams.

    Files are written to `results_dir` as "<stem>-<plot name>.png".
    """
    plotters = [
        (
            "confusion_matrix",
            lambda: sklearn.metrics.ConfusionMatrixDisplay.from_predictions(labels, preds),
        ),
        ("reliability_diagram", lambda: relplot.rel_diagram(probs, labels)),
        ("reliability_diagram_binned", lambda: relplot.rel_diagram_binned(probs, labels)),
    ]
    for plot_name, draw in plotters:
        draw()
        plt.savefig(results_dir / f"{stem}-{plot_name}.png", bbox_inches="tight")
        if show:
            plt.show()
        else:
            plt.close()


for target_class_name in tqdm.tqdm(places_dataset.features["label"].names):
    target_class_id = places_dataset.features["label"].str2int(target_class_name)
    target_class_name_clean = places365_clean_label(target_class_name)
    print(f"Processing {target_class_name_clean} (class {target_class_id}) from Places365")
    zsc.set_text(target_class_name_clean)

    # make one-vs-rest version of dataset
    # note, we filter out some non-targets to make the dataset more balanced
    calibration_dataset = create_binary_dataset(places_dataset, target_class_id)
    calibration_dataset = build_scores_dataset(zsc, calibration_dataset)
    calibration_dataset = calibration_dataset.train_test_split(test_size=0.5, seed=42)

    # use AID as the out-of-domain test set
    # NOTE(review): the class id is looked up in Places365 but applied to AID's
    # labels; this assumes both datasets share the same label ids — TODO confirm.
    ood_calibration_dataset = create_binary_dataset(aid_dataset, target_class_id)
    ood_calibration_dataset = build_scores_dataset(zsc, ood_calibration_dataset)

    # Read the columns once per class: they do not depend on the calibration
    # method, and every column access materializes the data again.
    train_scores = calibration_dataset["train"]["score"]
    train_labels = calibration_dataset["train"]["label"]
    train_features = calibration_dataset["train"]["features"]
    test_scores = calibration_dataset["test"]["score"]
    test_features = calibration_dataset["test"]["features"]
    labels = calibration_dataset["test"]["label"]
    ood_scores = ood_calibration_dataset["score"]
    ood_features = ood_calibration_dataset["features"]
    ood_labels = ood_calibration_dataset["label"]

    for calibration_method in ["sigmoid", "isotonic", "sba"]:
        calibrator_args = [train_scores, train_labels]
        predict_proba_args = [test_scores]
        ood_predict_proba_args = [ood_scores]

        if calibration_method == "sigmoid":
            calibrator = SigmoidCalibrator()
        elif calibration_method == "isotonic":
            calibrator = IsotonicCalibrator()
        elif calibration_method == "sba":
            # SBA additionally needs the image features for its neighbor search
            calibrator = SimilarityBinningAveragingCalibrator()
            calibrator_args.append(train_features)
            predict_proba_args.append(test_features)
            ood_predict_proba_args.append(ood_features)
        else:
            raise ValueError(f"Unknown calibration method {calibration_method}")

        calibrator.fit(*calibrator_args)
        probs = calibrator.predict_proba(*predict_proba_args)
        preds = probs >= 0.5
        ood_probs = calibrator.predict_proba(*ood_predict_proba_args)
        ood_preds = ood_probs >= 0.5

        slug = f"{provider}--{model_name}--{target_class_id:>02}--{target_class_name}--{calibration_method}"
        calibrators[slug] = calibrator
        print(f"Computing plots and metrics for {slug}")
        results = {
            # classifier/calibrator metadata
            "model_name": model_name,
            "provider": provider,
            "target_class_id": target_class_id,
            "target_class_name": target_class_name,
            "target_class_name_clean": target_class_name_clean,
            "calibration_method": calibration_method,
            "slug": slug,
            # in-domain metrics
            **_compute_metrics(labels, preds, probs),
            # out-of-domain metrics
            **_compute_metrics(ood_labels, ood_preds, ood_probs, prefix="ood_"),
        }
        all_results.append(results)

        _save_diagnostic_plots(labels, preds, probs, slug, show=show)
        _save_diagnostic_plots(ood_labels, ood_preds, ood_probs, f"{slug}-ood", show=show)

## Analyze Results

In [None]:
# Collect the per-run metric dicts into a table and save it as a CSV named
# after the provider and model.
metrics_csv_path = results_dir / f"metrics-{provider}-{model_name}.csv"
results_df = pd.DataFrame(all_results)
results_df.to_csv(metrics_csv_path)

In [None]:
# the calibration method is the last "--"-separated component of the slug
results_df["calibration_method"] = results_df["slug"].str.split("--").str[-1]

# Per-method mean/std of the smoothed ECE, in-domain and out-of-domain.
overall_raw = results_df.groupby("calibration_method").agg(
    {"sm_ece": ["mean", "std"], "ood_sm_ece": ["mean", "std"]}
)
overall_stats = overall_raw.round(4)
# "degradation" is the ratio of out-of-domain to in-domain mean smECE. It is
# computed from the unrounded means: dividing values already rounded to four
# decimals loses precision and can divide by a value that was rounded to zero.
overall_stats["degradation"] = (
    overall_raw[("ood_sm_ece", "mean")] / overall_raw[("sm_ece", "mean")]
).round(2)

# The same comparison, broken down by target class and method.
per_class_raw = results_df.groupby(
    ["target_class_name_clean", "calibration_method"]
).agg({"sm_ece": "mean", "ood_sm_ece": "mean"})
per_class_stats = per_class_raw.round(4)
per_class_stats["degradation"] = (
    per_class_raw["ood_sm_ece"] / per_class_raw["sm_ece"]
).round(2)

In [None]:
# Display the per-method summary: mean/std of smECE in- and out-of-domain, plus degradation.
overall_stats

In [None]:
# Display the summary broken down by target class and calibration method.
per_class_stats

In [None]:
import seaborn as sns

# Reshape to long format: one row per (calibration method, metric, value), so
# the in-domain and out-of-domain smECE can be shown side by side.
ece_columns = ["sm_ece", "ood_sm_ece"]
plot_data = results_df.melt(
    id_vars=["calibration_method"],
    value_vars=ece_columns,
    var_name="metric",
    value_name="value",
)

# One box per calibration method and metric.
sns.boxplot(data=plot_data, x="calibration_method", y="value", hue="metric")
plt.title("smECE Distribution by Calibration Method")
plt.ylabel("smECE Value")

plt.tight_layout()
plt.show()

In [None]:
# Scatter accuracy against smECE for every calibration method, with a linear
# trend line per method and the per-method correlation shown in a text box.
plt.figure(figsize=(10, 6))
methods = ["sigmoid", "isotonic", "sba"]
colors = ["#1f77b4", "#2ca02c", "#ff7f0e"]
markers = ["o", "s", "^"]

correlations = []
for method, color, marker in zip(methods, colors, markers):
    method_data = results_df[results_df["calibration_method"] == method]
    plt.scatter(
        method_data["sm_ece"],
        method_data["accuracy"],
        label=method,
        alpha=0.7,
        c=color,
        marker=marker,
        s=100,
    )

    # a line fit needs at least two points; np.polyfit fails on empty input
    if len(method_data) >= 2:
        z = np.polyfit(method_data["sm_ece"], method_data["accuracy"], 1)
        p = np.poly1d(z)
        plt.plot(
            method_data["sm_ece"],
            p(method_data["sm_ece"]),
            c=color,
            linestyle="--",
            alpha=0.5,
        )

    # builtin round() also handles the plain-float NaN that `corr` can return
    # when there is too little data, which has no `.round` method
    corr = round(float(method_data["accuracy"].corr(method_data["sm_ece"])), 3)
    correlations.append(f"{method}: r = {corr}")

plt.xlabel("smECE")
plt.ylabel("Accuracy")
plt.title("Accuracy vs smECE by Calibration Method")
plt.legend(title="Calibration Method")
plt.grid(True, alpha=0.3)

# Add correlation information as text
plt.text(
    0.02,
    0.02,
    "Correlations:\n" + "\n".join(correlations),
    transform=plt.gca().transAxes,
    bbox=dict(facecolor="white", alpha=0.8),
)

plt.tight_layout()
plt.show()

In [None]:
# Bundle the results directory into results.zip for download.
# NOTE(review): the path is relative, so this assumes the working directory is
# /content (where the results directory was created) — confirm.
! zip -r results results