<a href="https://colab.research.google.com/github/cameronbc/CS475-project-calibrating-zero-shot-image-classifiers/blob/main/zero-shot-calibration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Experiments in Calibrating Zero-Shot Image Classifiers

This notebook contains code for experimenting with different calibration methods for binary classifiers, applied to zero-shot image classification.

## Install Dependencies

In [None]:
# System LaTeX packages, needed because relplot is configured below to render
# axis labels with TeX (relplot.config.use_tex_fonts = True).
! sudo apt install dvipng texlive-latex-extra texlive-fonts-recommended texlive-fonts-extra cm-super
# Python packages imported below that are installed/upgraded here.
! pip install --upgrade datasets open_clip_torch relplot

## Import Libraries

In [None]:
import glob
import pathlib

import datasets
import matplotlib.pyplot as plt
import numpy as np
import open_clip
import pandas as pd
import PIL.Image
import PIL.PngImagePlugin
import relplot
import scipy.special
import seaborn as sns
import sklearn.calibration
import sklearn.isotonic
import sklearn.metrics
import sklearn.model_selection
import torch
import tqdm.auto as tqdm

# Allow PNG text chunks up to 100 MiB: some large images in the datasets carry
# metadata bigger than Pillow's default limit and otherwise fail to load.
PIL.PngImagePlugin.MAX_TEXT_CHUNK = 100 * 1024 * 1024

# Render relplot axis labels with LaTeX (relies on the TeX packages installed
# by the apt command in the dependency cell).
relplot.config.use_tex_fonts = True

## Fetch Datasets

In [None]:
# Google Drive links to zipped, pre-processed copies of each dataset; each is
# unzipped to /content/<name> and later read with datasets.load_from_disk.
DATASET_URLS = dict(
    # Aerial Image Dataset: https://captain-whu.github.io/AID/
    aid="https://drive.google.com/uc?id=1CTdCFoo88_ygMb2PNGnK3QTi8xk_naoA",
    # MIT Places 365: http://places2.csail.mit.edu/
    places365="https://drive.google.com/uc?id=1w-0LncVMfBsdtqX7jT-jCTdAnLBZuFtU",
)

In [None]:
# Download each zipped dataset from Google Drive and extract it into
# ./<dataset_name>, deleting the archive afterwards.
# NOTE(review): assumes the working directory is /content and that gdown
# saves the archive as <dataset_name>.zip (gdown presumably keeps the Drive
# filename; passing -O /content/<name>.zip would make this explicit). `mkdir`
# without -p prints an error when re-running with the directory present.
for dataset_name, dataset_gdrive_url in DATASET_URLS.items():
    ! gdown {dataset_gdrive_url}
    ! mkdir {dataset_name}
    ! unzip /content/{dataset_name}.zip -d {dataset_name}
    ! rm /content/{dataset_name}.zip

## Implement Zero-Shot Classifier and Calibrators

In [None]:
class ZeroShotClassifier:
    """Wrapper for loading and running zero-shot image classifiers.

    Loads an ``open_clip`` model and scores images against one text prompt.
    :meth:`set_text` embeds the prompt through every entry of ``TEMPLATES``
    and averages the normalized embeddings. Images, or precomputed image
    features, are then scored by their dot product with that embedding.
    :meth:`set_text` must be called before any ``score_*`` method.
    """

    # Prompt templates; set_text() fills each with the class text and averages
    # the normalized embeddings of the filled prompts into one text vector.
    TEMPLATES = [
        "itap of a {}.",
        "a bad photo of the {}.",
        "a origami {}.",
        "a photo of the large {}.",
        "a {} in a video game.",
        "art of the {}.",
        "a photo of the small {}.",
    ]

    def __init__(self, model_name, pretrained_source, device):
        """Create a new instance of the zero-shot classifier.

        model_name: open_clip model name (e.g. "ViT-B-32-quickgelu").
        pretrained_source: open_clip pretrained-weights tag (e.g. "openai").
        device: torch device the model is moved to (e.g. "cuda" or "cpu").
        """
        self.model_name = model_name
        self.pretrained_source = pretrained_source
        self.device = device
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            self.model_name,
            pretrained=self.pretrained_source,
        )
        self.model.eval().to(device)
        self.tokenizer = open_clip.get_tokenizer(self.model_name)
        self.text = None
        self.text_features = None

    def set_text(self, text):
        """Set the text prompt for the classifier.

        Stores a unit-length numpy text embedding in ``self.text_features``.
        """
        self.text = text
        tokens = self.tokenizer(
            [t.format(self.text) for t in ZeroShotClassifier.TEMPLATES]
        )
        with torch.no_grad(), torch.amp.autocast("cuda"):
            text_features = self.model.encode_text(tokens.to(self.device))
            # Normalize each prompt embedding, average them, then renormalize
            # the mean so the final text vector has unit length.
            text_features /= text_features.norm(dim=-1, keepdim=True)
            text_features = text_features.mean(dim=0)
            text_features /= text_features.norm()
        self.text_features = text_features.to("cpu").numpy()

    def get_image_features(self, image: "PIL.Image.Image"):
        """Get L2-normalized CLIP features for a single image, shape ``(1, D)``."""
        input_features = self.preprocess(image).unsqueeze(0)
        with torch.no_grad(), torch.amp.autocast("cuda"):
            image_features = self.model.encode_image(input_features.to(self.device))
            image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features.to("cpu").numpy()

    def get_image_features_batch(self, images):
        """Get L2-normalized CLIP features for a batch of images, one row each."""
        input_features = torch.stack([self.preprocess(im) for im in images])
        with torch.no_grad(), torch.amp.autocast("cuda"):
            image_features = self.model.encode_image(input_features.to(self.device))
            image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features.cpu().numpy()

    def score_image(self, image: "PIL.Image.Image", with_features: bool = False):
        """Score an image based on the text prompt.

        Returns a dict with ``"score"`` (``np.float32``) and, when
        ``with_features`` is true, the ``(1, D)`` image ``"features"``.
        """
        image_features = self.get_image_features(image)
        results = {"score": self.score_features(image_features)}
        if with_features:
            results["features"] = image_features
        return results

    def score_image_batch(self, images, with_features=False):
        """Score a batch of images based on the text prompt.

        Returns a dict with ``"score"`` (float32 array, one per image) and,
        when ``with_features`` is true, the image ``"features"`` matrix.
        """
        image_features = self.get_image_features_batch(images)
        results = {"score": self.score_features_batch(image_features)}
        if with_features:
            results["features"] = image_features
        return results

    def score_features(self, image_features: np.ndarray):
        """Score one image's features based on the text prompt.

        ``image_features`` must hold a single feature vector (shape ``(D,)``
        or ``(1, D)``). The score is returned as ``np.float32``.
        """
        # .item() yields a plain Python float, which has no .astype();
        # convert with np.float32 instead.
        return np.float32((image_features @ self.text_features).item())

    def score_features_batch(self, image_features: np.ndarray):
        """Score a batch of image features (one row each) against the text prompt."""
        scores = image_features @ self.text_features
        return scores.astype(np.float32)


class BaseCalibrator:
    """Common interface for turning raw classifier scores into probabilities.

    Concrete calibrators override both methods: ``fit`` learns the mapping
    from scores plus binary labels, and ``predict_proba`` applies it.
    """

    def predict_proba(self, scores: np.ndarray) -> np.ndarray:
        """Map raw scores to calibrated probabilities (subclasses must override)."""
        raise NotImplementedError

    def fit(self, scores: np.ndarray, labels: np.ndarray):
        """Learn the score-to-probability mapping (subclasses must override)."""
        raise NotImplementedError


class IsotonicCalibrator(BaseCalibrator):
    """Non-parametric calibration with a monotone (isotonic) score-to-probability fit.

    Scores outside the range seen during fitting are clipped to the nearest
    fitted endpoint (``out_of_bounds="clip"``) rather than mapped to NaN.
    """

    def __init__(self) -> None:
        """Wrap an unfitted isotonic regressor that clips out-of-range inputs."""
        regressor = sklearn.isotonic.IsotonicRegression(out_of_bounds="clip")
        self.calibrator = regressor

    def fit(self, scores: np.ndarray, labels: np.ndarray) -> "IsotonicCalibrator":
        """Fit the monotone mapping from scores to 0/1 labels; returns ``self``."""
        self.calibrator.fit(scores, labels)
        return self

    def predict_proba(self, scores: np.ndarray) -> np.ndarray:
        """Evaluate the fitted isotonic mapping at *scores*."""
        return self.calibrator.predict(scores)


class SigmoidCalibrator(BaseCalibrator):
    """Calibrate scores with sigmoid (Platt) scaling.

    Delegates to scikit-learn's private ``_SigmoidCalibration`` helper, the
    routine behind ``CalibratedClassifierCV(method="sigmoid")``. Because it is
    private API, it may change between scikit-learn releases.
    """

    def __init__(self) -> None:
        """Wrap an unfitted Platt-scaling model."""
        platt = sklearn.calibration._SigmoidCalibration()  # private sklearn API: brittle
        self.calibrator = platt

    def fit(self, scores: np.ndarray, labels: np.ndarray) -> "SigmoidCalibrator":
        """Learn the sigmoid's slope and intercept from scores and 0/1 labels."""
        self.calibrator.fit(scores, labels)
        return self

    def predict_proba(self, scores: np.ndarray) -> np.ndarray:
        """Return the fitted sigmoid evaluated at *scores*."""
        return self.calibrator.predict(scores)


class SigLIPCalibrator(BaseCalibrator):
    """Calibrate scores with a SigLIP model's own learned sigmoid parameters.

    Probabilities are ``expit(scale * score + shift)``. Here ``scale`` is
    ``exp(model.logit_scale)``. ``shift`` is ``model.logit_bias``, or the
    log-odds of ``target_prior`` when one is given. Nothing is learned in
    :meth:`fit`.
    """

    def __init__(self, zsc: ZeroShotClassifier, target_prior=None) -> None:
        """Create a new instance of the SigLIP sigmoid-style calibrator.

        zsc: zero-shot classifier wrapping a SigLIP model (checked by name).
        target_prior: optional positive-class prior in (0, 1). When given, its
            log-odds replaces the model's learned bias as the intercept.

        Raises ValueError if ``zsc`` is not a SigLIP model or if
        ``target_prior`` is outside (0, 1).
        """
        if "siglip" not in zsc.model_name.lower():
            raise ValueError("Zero-shot classifier not a sigLIP instance")
        if target_prior is not None and not 0.0 < target_prior < 1.0:
            raise ValueError(f"target_prior must be in (0, 1), got {target_prior}")
        with torch.no_grad():
            # detach() before .numpy(): the model is only put in eval mode, not
            # frozen, so its parameters still require grad. For a model on the
            # CPU, .cpu() returns the parameter itself and .numpy() would raise.
            self.scale = zsc.model.logit_scale.exp().detach().cpu().numpy()  # AKA slope
            self.shift = zsc.model.logit_bias.detach().cpu().numpy()  # AKA intercept
        self.target_prior = target_prior
        if self.target_prior is not None:
            # Replaces (rather than adjusts) the learned bias with the prior's log-odds.
            self.shift = np.log(self.target_prior / (1 - self.target_prior))

    def fit(self, scores: np.ndarray, labels: np.ndarray) -> "SigLIPCalibrator":
        """Fit the sigLIP calibrator (does nothing, comes trained from sigLIP)"""
        return self

    def predict_proba(self, scores: np.ndarray) -> np.ndarray:
        """Apply the SigLIP sigmoid, ``expit(scale * scores + shift)``."""
        return scipy.special.expit(self.scale * scores + self.shift)


class SimilarityBinningAveragingCalibrator(BaseCalibrator):
    """Calibrate scores using Similarity-Binning Averaging (SBA) calibration.

    An inner calibrator maps scores to probabilities. At prediction time, each
    test example is compared with every training example using a blend of
    feature similarity and closeness of inner-calibrated probabilities. The
    output is the mean inner-calibrated probability of the ``k`` most similar
    training examples.
    """

    def __init__(
        self,
        k: int = 10,
        alpha: float = 0.95,
        inner_calibration_class: "type[BaseCalibrator]" = SigmoidCalibrator,
    ) -> None:
        """Create a new instance of the SBA calibrator.

        k: number of most-similar training examples to average over
        alpha: weight on feature similarity; ``1 - alpha`` weights the
            similarity of inner-calibrated probabilities
        inner_calibration_class: calibrator class, constructed with no
            arguments, used to get calibrated scores from calibration data
        """
        self.k = k
        self.alpha = alpha
        self.inner_calibrator = inner_calibration_class()
        self._train_features = None
        self._train_probs = None

    def fit(
        self, scores: np.ndarray, labels: np.ndarray, features: np.ndarray
    ) -> "SimilarityBinningAveragingCalibrator":
        """Fit the inner calibrator and memorize the training set.

        features: one feature row per entry of ``scores``.
        """
        self.inner_calibrator.fit(scores, labels)
        self._train_features = np.asarray(features)
        # ndarray (not e.g. a pandas Series) so 2-D fancy indexing works below.
        self._train_probs = np.asarray(self.inner_calibrator.predict_proba(scores))
        return self

    def predict_proba(self, scores: np.ndarray, features: np.ndarray) -> np.ndarray:
        """Apply SBA calibration to scores.

        features: one feature row per entry of ``scores``.
        Raises RuntimeError if called before :meth:`fit`.
        """
        if self._train_features is None:
            raise RuntimeError(
                "SimilarityBinningAveragingCalibrator must be fit before predict_proba"
            )
        # (n_test, n_train) dot products; these are cosine similarities when the
        # features are L2-normalized, as ZeroShotClassifier produces them.
        feature_similarities = np.asarray(features) @ self._train_features.T
        # Min-max scale to [0, 1] so the feature term is on the same scale as
        # probs_similarities before the two are blended.
        fs_min = feature_similarities.min()
        fs_range = feature_similarities.max() - fs_min
        if fs_range > 0:
            feature_similarities = (feature_similarities - fs_min) / fs_range
        else:
            # All pairs are equally similar; the feature term cannot rank
            # neighbours, and a constant avoids dividing by zero.
            feature_similarities = np.zeros_like(feature_similarities)
        probs = np.asarray(self.inner_calibrator.predict_proba(scores))
        probs_similarities = 1 - np.abs(probs[:, np.newaxis] - self._train_probs)
        similarities = (
            self.alpha * feature_similarities + (1 - self.alpha) * probs_similarities
        )
        # Indices of the k most similar training examples for each test row.
        top_k_indices = np.argsort(-similarities, axis=1)[:, : self.k]
        top_k_probs = self._train_probs[top_k_indices]
        return top_k_probs.mean(axis=1)


def build_features_dataset(
    zero_shot_classifier: ZeroShotClassifier,
    image_dataset: datasets.Dataset,
    batch_size: int = 32,
    keep_in_memory: bool = False,
):
    """Replace the ``image`` column of a dataset with CLIP image features.

    Images are embedded in batches of ``batch_size`` via
    ``zero_shot_classifier.get_image_features_batch``. The result is stored in
    a new ``features`` column, and the ``image`` column is then dropped.
    """

    def _embed_batch(images):
        # Batched map: receives a list of images, returns the new column.
        return {"features": zero_shot_classifier.get_image_features_batch(images)}

    with_features = image_dataset.map(
        _embed_batch,
        batched=True,
        batch_size=batch_size,
        input_columns=["image"],
        keep_in_memory=keep_in_memory,
        desc="Extracting features",
    )
    return with_features.remove_columns("image")


def build_scores_dataset(
    zero_shot_classifier: ZeroShotClassifier,
    features_dataset: pd.DataFrame,
    batch_size: int = 1024,
    keep_in_memory: bool = False,
):
    """Return a copy of *features_dataset* with a ``score`` column added.

    Scores are computed against the classifier's current text prompt, in a
    single call over the stacked ``features`` column. ``batch_size`` and
    ``keep_in_memory`` are accepted but currently unused.
    """
    feature_matrix = np.stack(features_dataset["features"].values)
    result = features_dataset.copy()
    result["score"] = zero_shot_classifier.score_features_batch(feature_matrix)
    return result


def places365_clean_label(label: str):
    """Turn a Places365 class name into natural-language prompt text.

    A trailing ``-qualifier`` suffix is moved in front of the name
    (``"field-cultivated"`` -> ``"cultivated field"``), and underscores become
    spaces.
    """
    head, sep, qualifier = label.rpartition("-")
    if sep:
        label = f"{qualifier} {head}"
    return label.replace("_", " ")


def create_binary_dataset(
    multiclass_dataset: pd.DataFrame,
    target_class_id: int,
    target_prior: float = 0.5,
    keep_in_memory: bool = False,
    rng=None,
):
    """Create a binary dataset from a multiclass dataset.

    The binary dataset will have a target class and a non-target class, with the
    non-target class sampled at a specified ratio to the target class.

    Every target row is kept. Each non-target row is kept independently with
    probability ``(1 - p) / p * n_target / n_non_target``, where
    ``p = target_prior``. In expectation, the target class is therefore a
    ``target_prior`` fraction of the result. If that probability is >= 1, all
    rows are kept. Labels are rewritten to 1 (target) / 0 (non-target).

    multiclass_dataset: frame with a ``label`` column; its index may be arbitrary.
    target_class_id: label value treated as the positive class.
    target_prior: desired expected fraction of positives, in (0, 1].
    keep_in_memory: unused; kept for interface compatibility.
    rng: optional ``numpy.random.Generator`` for reproducible subsampling.
        Defaults to NumPy's global random state.

    Returns a new DataFrame (original row order, fresh RangeIndex).
    Raises ValueError if ``target_prior`` is not in (0, 1].
    """
    if not 0.0 < target_prior <= 1.0:
        raise ValueError(f"target_prior must be in (0, 1], got {target_prior}")

    is_target = multiclass_dataset["label"].to_numpy() == target_class_id
    num_target_samples = int(is_target.sum())
    num_non_target_samples = len(is_target) - num_target_samples

    keep = is_target.copy()
    if num_non_target_samples > 0:
        non_target_odds = (1.0 - target_prior) / target_prior
        num_non_target_keep = non_target_odds * num_target_samples
        non_target_keep_prob = num_non_target_keep / num_non_target_samples
        draw = np.random.random if rng is None else rng.random
        # One uniform draw per non-target row, in row order.
        keep[~is_target] = draw(num_non_target_samples) < non_target_keep_prob

    # Positional selection: DataFrame.filter() would match index *labels*,
    # which is wrong for frames whose index is not a 0..n-1 RangeIndex.
    binary_dataset = multiclass_dataset.iloc[np.flatnonzero(keep)].copy()
    binary_dataset["label"] = is_target[keep].astype(int)
    binary_dataset.reset_index(drop=True, inplace=True)
    return binary_dataset


def plot_score_distributions(score_dataset: datasets.Dataset, title: str):
    """Overlay density histograms of target vs. non-target scores.

    Draws on the current matplotlib figure: two 30-bin density histograms,
    dashed vertical lines at each group's mean, axis labels, title, legend and
    grid. The caller is responsible for saving/showing/closing the figure.
    """
    scores = score_dataset["score"]
    # (scores, histogram label, histogram color, mean label, mean-line color)
    groups = [
        (scores[score_dataset["label"] == 1], "Target", "#2ecc71", "Target Mean", "#27ae60"),
        (scores[score_dataset["label"] == 0], "Non-Target", "#e74c3c", "Non-Target Mean", "#c0392b"),
    ]

    for group_scores, hist_label, hist_color, _, _ in groups:
        sns.histplot(
            group_scores,
            stat="density",
            alpha=0.6,
            color=hist_color,
            label=hist_label,
            bins=30,
        )

    for group_scores, _, _, mean_label, line_color in groups:
        plt.axvline(
            group_scores.mean(),
            color=line_color,
            linestyle="--",
            alpha=0.8,
            label=mean_label,
        )

    plt.xlabel("Model Score")
    plt.ylabel("Density")
    plt.title(title)
    plt.legend()
    plt.grid(True, alpha=0.3)

## Create Zero-Shot Classifier and Extract Image Features

In [None]:
# Start every run from an empty /content/results directory.
! mkdir -p /content/results
! rm -rf /content/results/*

In [None]:
print("Creating zero-shot classifier")
device = "cuda" if torch.cuda.is_available() else "cpu"
# Select one open_clip (model, pretrained weights) pair; uncomment an
# alternative line to rerun the experiments with a different model.
model_name, provider = "ViT-B-32-quickgelu", "openai"
#model_name, provider = "ViT-B-16-SigLIP", "webli"
#model_name, provider = "ViT-B-32", "datacomp_m_s128m_b4k"
#model_name, provider = "RN50-quickgelu", "yfcc15m"
zsc = ZeroShotClassifier(model_name, provider, device)

# Embed every image once; build_features_dataset adds a "features" column and
# drops "image". Scores for each target class are later computed from these.
print("Extracting CLIP features for AID dataset")
aid_dataset = datasets.load_from_disk("/content/aid")
aid_dataset = build_features_dataset(zsc, aid_dataset)
aid_dataset.set_format(type="numpy")

print("Extracting CLIP features for Places365 dataset (takes several minutes)")
places_dataset = datasets.load_from_disk("/content/places365")
places_dataset = build_features_dataset(zsc, places_dataset)
places_dataset.set_format(type="numpy")

# class names from AID and Places365 already mapped to match during data prep
class_names = places_dataset.features["label"].names
class_ids = {c: places_dataset.features["label"].str2int(c) for c in class_names}

# use pandas dataframes instead of HF datasets since images no longer needed
aid_dataset = pd.DataFrame(aid_dataset)
places_dataset = pd.DataFrame(places_dataset)
# Per-model output folders for metric CSVs (results_dir) and figures (plots_dir).
model_slug = f"{model_name}--{provider}"
results_dir = pathlib.Path("/content/results") / model_slug
results_dir.mkdir(exist_ok=True, parents=True)
plots_dir = results_dir / "plots"
plots_dir.mkdir(exist_ok=True, parents=True)

## Run Zero-Shot Classification and In-Class Calibration and Collect Metrics

In [None]:
# Expected fraction of positives in each one-vs-rest set (non-targets are
# subsampled to reach it); also passed to the SigLIP calibrator as its prior.
target_prior = 1.0 / 3.0
calibration_methods = ["sigmoid", "isotonic", "SBA"]
# SigLIP models carry their own sigmoid head, evaluated as an extra "calibrator".
if "siglip" in zsc.model_name.lower():
    calibration_methods.append("siglip")
    siglip_calibrator = SigLIPCalibrator(zsc, target_prior=target_prior)
in_domain_results = []  # metric rows, Places365 held-out split
ood_results = []  # metric rows, AID (out-of-domain) set
calibrators = {}  # fitted non-SBA calibrators keyed by cal_slug, reused for transfer
in_domain_calibration_test_sets = {}  # held-out Places365 split per class name
show_plots = False  # True: display figures inline; False: only save them

In [None]:
# For each class: build one-vs-rest sets, fit every calibration method on half
# of the Places365 set, and evaluate on the other half (in-domain) and on AID
# (out-of-domain), saving metric rows and plots.
for target_class_name in tqdm.tqdm(class_names):
    target_class_id = class_ids[target_class_name]
    target_class_name_clean = places365_clean_label(target_class_name)
    eval_slug = f"eval--{target_class_id:>02}--{target_class_name}--inclass"
    print(f"Processing {eval_slug}")
    # All scores below are similarities to this class's prompt embedding.
    zsc.set_text(target_class_name_clean)

    # make one-vs-rest version of dataset
    # note, we filter out some non-targets to make the dataset more balanced
    # NOTE(review): create_binary_dataset samples with NumPy's global RNG,
    # which is not seeded anywhere in this notebook, so the binary sets (and
    # hence the metrics) vary between runs.
    calibration_dataset = create_binary_dataset(places_dataset, target_class_id, target_prior)
    calibration_dataset = build_scores_dataset(zsc, calibration_dataset)
    # 50/50 split: calibrators are fit on cal_train and evaluated on cal_test.
    cal_train, cal_test = sklearn.model_selection.train_test_split(
        calibration_dataset, test_size=0.5, random_state=42
    )
    plot_score_distributions(
        cal_test,
        f"In-domain Score Distribution for {target_class_name_clean.title()}",
    )
    plt.savefig(plots_dir / f"{eval_slug}--indomain--scores.png", bbox_inches="tight")
    plt.show() if show_plots else plt.close()

    # use AID as the out-of-domain test set
    # (the whole binarized AID set is used for evaluation only; nothing is fit on it)
    calibration_dataset_ood = create_binary_dataset(aid_dataset, target_class_id, target_prior)
    cal_test_ood = build_scores_dataset(zsc, calibration_dataset_ood)
    plot_score_distributions(
        cal_test_ood,
        f"Out-of-domain Score Distribution for {target_class_name_clean.title()}",
    )
    plt.savefig(plots_dir / f"{eval_slug}--ood-scores.png", bbox_inches="tight")
    plt.show() if show_plots else plt.close()

    # cache calibration test partition to test in-domain calibration transfer later
    # (features dropped: the transfer cell only uses "score" and "label")
    in_domain_calibration_test_sets[target_class_name] = cal_test.drop("features", axis=1)

    for calibration_method in calibration_methods:
        # Positional arguments for fit()/predict_proba(); SBA additionally
        # takes the stacked image-feature matrices.
        calibrator_args = [cal_train["score"], cal_train["label"]]
        predict_proba_args = [cal_test["score"]]
        ood_predict_proba_args = [cal_test_ood["score"]]

        if calibration_method == "sigmoid":
            calibrator = SigmoidCalibrator()
        elif calibration_method == "isotonic":
            calibrator = IsotonicCalibrator()
        elif calibration_method == "siglip":
            # Shared instance; its fit() is a no-op.
            calibrator = siglip_calibrator
        elif calibration_method == "SBA":
            calibrator = SimilarityBinningAveragingCalibrator()
            calibrator_args.append(np.stack(cal_train["features"].values))
            predict_proba_args.append(np.stack(cal_test["features"].values))
            ood_predict_proba_args.append(np.stack(cal_test_ood["features"].values))
        else:
            raise ValueError(f"Unknown calibration method {calibration_method}")

        calibrator.fit(*calibrator_args)
        labels = cal_test["label"]
        probs = calibrator.predict_proba(*predict_proba_args)
        # Hard predictions use a fixed 0.5 probability threshold.
        preds = probs >= 0.5
        ood_labels = cal_test_ood["label"]
        ood_probs = calibrator.predict_proba(*ood_predict_proba_args)
        ood_preds = ood_probs >= 0.5

        cal_slug = f"places365--{target_class_id:>02}--{target_class_name}--{calibration_method}"
        # SBA is not cached: it needs per-example features at predict time, and
        # the transfer evaluation only has scores, so it is skipped there.
        if calibration_method != "SBA":
            calibrators[cal_slug] = calibrator

        results_metadata = {
            "model_name": model_name,
            "provider": provider,
            "calibration_method": calibration_method,
            "train_class_id": target_class_id,
            "train_class_name": target_class_name,
            "train_class_name_clean": target_class_name_clean,
            "test_class_id": target_class_id,
            "test_class_name": target_class_name,
            "test_class_name_clean": target_class_name_clean,
        }

        results_slug = f"{eval_slug}--indomain--{calibration_method}"
        print(f"Computing plots and metrics for {results_slug}")
        results = {
            # classifier/calibrator metadata
            "test_class_domain": "indomain",
            "slug": results_slug,
            # in-domain classifier metrics
            "accuracy": sklearn.metrics.accuracy_score(labels, preds),
            "precision": sklearn.metrics.precision_score(labels, preds),
            "recall": sklearn.metrics.recall_score(labels, preds),
            "f1": sklearn.metrics.f1_score(labels, preds),
            # in-domain calibration metrics
            "brier_score": sklearn.metrics.brier_score_loss(labels, probs),
            "sm_ece": relplot.smECE(probs, labels),
            "binned_ece": relplot.metrics.binnedECE(probs, labels, nbins=15),
        }
        in_domain_results.append({**results_metadata, **results})

        sklearn.metrics.ConfusionMatrixDisplay.from_predictions(labels, preds)
        plt.savefig(plots_dir / f"{results_slug}-cm.png", bbox_inches="tight")
        plt.show() if show_plots else plt.close()
        relplot.rel_diagram(probs, labels)
        plt.savefig(plots_dir / f"{results_slug}-relplot.png", bbox_inches="tight")
        plt.show() if show_plots else plt.close()
        relplot.rel_diagram_binned(probs, labels)
        plt.savefig(plots_dir / f"{results_slug}-relplot-binned.png", bbox_inches="tight")
        plt.show() if show_plots else plt.close()

        results_slug = f"{eval_slug}--ood--{calibration_method}"
        print(f"Computing plots and metrics for {results_slug}")
        results = {
            "test_class_domain": "ood",
            "slug": results_slug,
            # out-of-domain classifier metrics
            "accuracy": sklearn.metrics.accuracy_score(ood_labels, ood_preds),
            "precision": sklearn.metrics.precision_score(ood_labels, ood_preds),
            "recall": sklearn.metrics.recall_score(ood_labels, ood_preds),
            "f1": sklearn.metrics.f1_score(ood_labels, ood_preds),
            # out-of-domain calibration metrics
            "brier_score": sklearn.metrics.brier_score_loss(ood_labels, ood_probs),
            "sm_ece": relplot.smECE(ood_probs, ood_labels),
            "binned_ece": relplot.metrics.binnedECE(ood_probs, ood_labels, nbins=15),
        }
        ood_results.append({**results_metadata, **results})

        sklearn.metrics.ConfusionMatrixDisplay.from_predictions(ood_labels, ood_preds)
        plt.savefig(plots_dir / f"{results_slug}-cm.png", bbox_inches="tight")
        plt.show() if show_plots else plt.close()
        relplot.rel_diagram(ood_probs, ood_labels)
        plt.savefig(plots_dir / f"{results_slug}-relplot.png", bbox_inches="tight")
        plt.show() if show_plots else plt.close()
        relplot.rel_diagram_binned(ood_probs, ood_labels)
        plt.savefig(plots_dir / f"{results_slug}-relplot-binned.png", bbox_inches="tight")
        plt.show() if show_plots else plt.close()

# One row per (class, method); to_csv also writes the DataFrame index.
in_domain_df = pd.DataFrame(in_domain_results)
save_path = results_dir / f"metrics--indomain.csv"
in_domain_df.to_csv(save_path)
print(f"wrote in-domain metrics to {save_path}")

ood_df = pd.DataFrame(ood_results)
save_path = results_dir / f"metrics--ood.csv"
ood_df.to_csv(save_path)
print(f"wrote out-of-domain metrics to {save_path}")

## Run In-Domain Calibration Transfer Evaluations

In [None]:
# Cross-class transfer: apply every cached calibrator (fit on class A's
# in-domain training split) to every class B's held-out in-domain test split.
in_domain_transfer_results = []
for cal_test_class_name in tqdm.tqdm(class_names):
    print(f"Evaluating transfer of in-domain calibrators (i.e. trained on other classes in Places365) to {cal_test_class_name} calibration test set")
    cal_test_class_id = class_ids[cal_test_class_name]
    cal_test_class_name_clean = places365_clean_label(cal_test_class_name)
    # "score"/"label" columns cached by the in-class cell; the scores are
    # similarities to the test class's own prompt.
    cal_test = in_domain_calibration_test_sets[cal_test_class_name]
    eval_slug = f"eval--{cal_test_class_id:>02}--{cal_test_class_name}--xfer"
    labels = cal_test["label"]
    for cal_train_class_name in tqdm.tqdm(class_names):
        # could skip when train == test, but keep as a sanity test for now
        # if cal_train_class_name == cal_test_class_name:
        #     continue
        cal_train_class_id = class_ids[cal_train_class_name]
        cal_train_class_name_clean = places365_clean_label(cal_train_class_name)
        for calibration_method in calibration_methods:
            if calibration_method == "SBA":
                # skip SBA since it cannot transfer in-domain (non-parameteric)
                continue
            cal_slug = f"places365--{cal_train_class_id:>02}--{cal_train_class_name}--{calibration_method}"
            calibrator = calibrators[cal_slug]
            probs = calibrator.predict_proba(cal_test["score"])
            preds = probs >= 0.5
            # Include the calibration method so slugs (and plot filenames) are
            # unique per method instead of being overwritten by the next method.
            results_slug = f"{eval_slug}--{cal_train_class_id:>02}--{cal_train_class_name}--indomain--{calibration_method}"
            results = {
                # classifier/calibrator metadata
                "model_name": model_name,
                "provider": provider,
                "calibration_method": calibration_method,
                "train_class_id": cal_train_class_id,
                "train_class_name": cal_train_class_name,
                "train_class_name_clean": cal_train_class_name_clean,
                "test_class_id": cal_test_class_id,
                "test_class_name": cal_test_class_name,
                "test_class_name_clean": cal_test_class_name_clean,
                "test_class_domain": "xfer",
                "slug": results_slug,
                # in-domain class-transfer classifier metrics
                "accuracy": sklearn.metrics.accuracy_score(labels, preds),
                "precision": sklearn.metrics.precision_score(labels, preds),
                "recall": sklearn.metrics.recall_score(labels, preds),
                "f1": sklearn.metrics.f1_score(labels, preds),
                # in-domain class-transfer calibration metrics
                "brier_score": sklearn.metrics.brier_score_loss(labels, probs),
                "sm_ece": relplot.smECE(probs, labels),
                "binned_ece": relplot.metrics.binnedECE(probs, labels, nbins=15),
            }
            in_domain_transfer_results.append(results)
            # Plotting is off by default: classes x classes x methods figures.
            show = False
            make_plots = False
            if make_plots:
                sklearn.metrics.ConfusionMatrixDisplay.from_predictions(labels, preds)
                plt.savefig(plots_dir / f"{results_slug}-cm.png", bbox_inches="tight")
                plt.show() if show else plt.close()
                relplot.rel_diagram(probs, labels)
                plt.savefig(plots_dir / f"{results_slug}-relplot.png", bbox_inches="tight")
                plt.show() if show else plt.close()
                relplot.rel_diagram_binned(probs, labels)
                plt.savefig(plots_dir / f"{results_slug}-relplot-binned.png", bbox_inches="tight")
                plt.show() if show else plt.close()
xfer_df = pd.DataFrame(in_domain_transfer_results)
save_path = results_dir / f"metrics--xfer.csv"
xfer_df.to_csv(save_path)
print(f"wrote in-domain transfer metrics to {save_path}")

In [None]:
# Archive the results folder as results.zip for download.
! zip -r results results

## Analyze Results

In [None]:
# Choose which model's saved metrics to analyze. NOTE: this overrides the
# model_name/provider chosen for the experiment cells; it must name a model
# whose results exist under /content/results.
#model_name, provider = "ViT-B-32-quickgelu", "openai"
model_name, provider = "ViT-B-16-SigLIP", "webli"
model_slug = f"{model_name}--{provider}"
results_dir = pathlib.Path("/content/results") / model_slug

# The CSVs were written with their index, so each frame presumably gains an
# extra "Unnamed: 0" column here.
in_domain_df = pd.read_csv(results_dir / "metrics--indomain.csv")
ood_df = pd.read_csv(results_dir / "metrics--ood.csv")
xfer_df = pd.read_csv(results_dir / "metrics--xfer.csv")

# Long-format table of all conditions, distinguished by "test_class_domain".
combined_df = pd.concat([in_domain_df, ood_df, xfer_df])
sns.set_style("whitegrid")
sns.set_context("notebook", font_scale=1.2)

In [None]:
# Mean SmoothECE and accuracy per (method, condition), averaged over classes.
summary_stats = combined_df.groupby(["calibration_method", "test_class_domain"])[["sm_ece", "accuracy"]].mean()
# Accuracy becomes a percentage string here, so round(3) below only affects SmoothECE.
summary_stats["accuracy"] = summary_stats["accuracy"].apply(lambda x: f"{x:.1%}")
# Fixed row order. SigLIP rows are NaN for non-SigLIP models, and SBA has no
# "xfer" row because it is not evaluated for cross-class transfer.
summary_stats = summary_stats.reindex([
    ("sigmoid", "indomain"),
    ("sigmoid", "ood"),
    ("sigmoid", "xfer"),
    ("isotonic", "indomain"),
    ("isotonic", "ood"),
    ("isotonic", "xfer"),
    ("siglip", "indomain"),
    ("siglip", "ood"),
    ("siglip", "xfer"),
    ("SBA", "indomain"),
    ("SBA", "ood"),
])
# Human-readable labels for the table.
summary_stats = summary_stats.rename(index={
    "indomain": "In Domain",
    "ood": "Out of Domain",
    "sigmoid": "Sigmoid",
    "isotonic": "Iso. Reg.",
    "SBA": "SBA",
    "siglip": "SigLIP",
    "xfer": "Cross Class",
})
summary_stats.index = summary_stats.index.set_names(["Calibration Method", "Condition"])
summary_stats = summary_stats.rename(columns={
    "sm_ece": "SmoothECE",
    "accuracy": "Accuracy"
})
summary_stats = summary_stats.round(3)
summary_stats

In [None]:
# Method order and matching display labels; "siglip" exists only for SigLIP
# models. Both lists are reused by the plotting cells below.
order = ["sigmoid", "isotonic", "siglip", "SBA"] if "siglip" in model_name.lower() else ["sigmoid", "isotonic", "SBA",]
display_names = ["Sigmoid", "Iso. Reg.", "SigLIP", "SBA"] if "siglip" in model_name.lower() else ["Sigmoid", "Iso. Reg.", "SBA"]

# Mean SmoothECE per method, one bar per condition (SBA has no "xfer" bar).
plt.figure(figsize=(10, 6))
sns.barplot(
    data=combined_df,
    x="calibration_method",
    order=order,
    y="sm_ece",
    hue="test_class_domain",
    hue_order=["indomain", "ood", "xfer"],
    palette="pastel"
)
#plt.bar_label(plt.gca().containers[0], fmt="%.3f", label_type="edge")
#plt.bar_label(plt.gca().containers[1], fmt="%.3f", label_type="edge")
plt.title(f"Calibration Performance Across Domains for {model_name}", pad=20)
plt.ylabel("SmoothECE", labelpad=10)
# Tick count equals len(display_names).
plt.xticks(range(4 if "siglip" in model_name.lower() else 3), display_names, ha="center")
plt.xlabel("Calibration Method", labelpad=10)
# Relabel legend handles (in hue_order) with readable condition names.
boxes, _ = plt.gca().get_legend_handles_labels()
plt.legend(boxes, ["In Domain", "Out of Domain", "Cross Class"], bbox_to_anchor=(1.05, 1), loc="upper left")
plt.tight_layout()
plt.savefig(results_dir / "plot-calibrations.png")
plt.show()

In [None]:
# Mean accuracy (0.5 threshold) per method, one bar per condition; uses the
# order/display_names lists from the previous cell.
plt.figure(figsize=(10, 6))
sns.barplot(
    data=combined_df,
    x="calibration_method",
    order=order,
    y="accuracy",
    hue="test_class_domain",
    hue_order=["indomain", "ood", "xfer"],
    palette="pastel"
)
plt.title(f"Accuracy Across Domains for {model_name}", pad=20)
plt.ylabel("Accuracy", labelpad=10)
plt.xticks(range(4 if "siglip" in model_name.lower() else 3), display_names, ha="center")
plt.xlabel("Calibration Method", labelpad=10)
# Relabel legend handles (in hue_order) with readable condition names.
boxes, _ = plt.gca().get_legend_handles_labels()
plt.legend(boxes, ["In Domain", "Out of Domain", "Cross Class"], bbox_to_anchor=(1.05, 1), loc="upper left")
plt.tight_layout()
plt.savefig(results_dir / "plot-accuracies.png")
plt.show()

In [None]:
plt.figure(figsize=(10, 6))
colors = ["#1f77b4", "#2ca02c", "#ff7f0e", "#d44674"] if "siglip" in model_name.lower() else ["#1f77b4", "#2ca02c", "#d44674"]
markers = ["o", "s", "^", "x"] if "siglip" in model_name.lower() else ["o", "s", "x"]

# One in-domain (smECE, accuracy) point per class for each calibration method,
# plus a dashed least-squares trend line per method. Each series is labeled
# from display_names, which tracks `order`, so the legend stays correct both
# with and without the SigLIP method.
for method, display_name, color, marker in zip(order, display_names, colors, markers):
    method_data = in_domain_df[in_domain_df["calibration_method"] == method]
    plt.scatter(
        method_data["sm_ece"],
        method_data["accuracy"],
        label=display_name,
        alpha=0.7,
        c=color,
        marker=marker,
        s=100
    )

    # A linear fit needs at least two points.
    if len(method_data) < 2:
        continue
    trend = np.poly1d(np.polyfit(method_data["sm_ece"], method_data["accuracy"], 1))
    # Sort x so the trend line is drawn once from left to right.
    xs = np.sort(method_data["sm_ece"].to_numpy())
    plt.plot(
        xs,
        trend(xs),
        c=color,
        linestyle="--",
        alpha=0.5
    )

plt.xlabel("smECE")
plt.ylabel("Accuracy")
plt.title(f"Accuracy vs smECE by Calibration Method for {model_name}")
plt.legend(title="Calibration Method")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(results_dir / "plot-calibrations-vs-accuracies.png")
plt.show()

In [None]:
# Rebuild results.zip so it also contains the analysis plots saved above
# (`rm` only prints an error if the archive does not exist yet).
! rm results.zip
! zip -r results results