This tutorial is adapted from the notebook imagenet-smallest-sets.ipynb in the conformal-prediction repository (https://github.com/aangelopoulos/conformal-prediction). It is based on Sadinle et al. (2016) (https://arxiv.org/abs/1609.00451). This worksheet in particular was created with the help of Claude 3.7.

# Worksheet: Image classification example using ImageNet

In this tutorial, we will apply conformal prediction to an image classification problem to obtain prediction sets of plausible labels using a pre-trained model. This means we will make use of the model outputs rather than training an image classifier from scratch.

## What is conformal prediction?

Conformal prediction is a framework for quantifying uncertainty in machine learning predictions. Unlike traditional methods that output a single prediction, it produces a set of plausible labels that is guaranteed to contain the true label with at least a specified probability (e.g., 90%), provided the calibration and test data are exchangeable (e.g., i.i.d.).

In [None]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.image import imread
!pip install -U --no-cache-dir gdown --pre

In [None]:
# Requires about 1.31 GB of free disk space!
# Download the data. The data include softmax scores from a pre-trained ResNet-152 model on ImageNet
if not os.path.exists('../data'):
    os.system('gdown 1h7S6N_Rx7gdfO3ZunzErZy6H7620EbZK -O ../data.tar.gz')
    os.system('tar -xf ../data.tar.gz -C ../')
    os.system('rm ../data.tar.gz')
if not os.path.exists('../data/imagenet/human_readable_labels.json'):
    !wget -nv -O ../data/imagenet/human_readable_labels.json -L https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json

# Load the data
data = np.load('../data/imagenet/imagenet-resnet152.npz') # softmax scores and true labels
example_paths = os.listdir('../data/imagenet/examples') # filenames of the example images
smx = data['smx'] # softmax scores of images from a pre-trained model
labels = data['labels'].astype(int) # true labels

# Examine the data shape
print(f"Shape of smx: {smx.shape}") # shows the number of images and number of classes
print(f"Number of example images: {len(example_paths)}")


Each row of $\texttt{smx}$ contains the *softmax scores* of one image, which we can think of as the model's estimated probabilities that the image belongs to each of the $K = 1000$ possible classes (so each row sums to 1). Symbolically, $$\texttt{smx}[i, :] = \hat{f}(\text{Image}_i) \in [0,1]^{K}, \quad \texttt{smx}[i, j] \approx \mathbb{P}\{\text{Image}_i \text{ has Label } j\} \text{ (according to $\hat{f}$)}.$$
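The downloaded `smx` array already stores probabilities, so nothing here needs to be recomputed. Purely for intuition, the following sketch shows how a softmax turns a model's raw outputs (logits) into rows of nonnegative scores that sum to 1. The `softmax` helper and the logits are hypothetical and not part of the data.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax: maps each row of raw model outputs to probabilities."""
    # Subtracting the row maximum leaves the result unchanged but avoids overflow in exp.
    z = logits - logits.max(axis=1, keepdims=True)
    expz = np.exp(z)
    return expz / expz.sum(axis=1, keepdims=True)

# Hypothetical logits for 2 images and K = 4 classes (not from the ImageNet data).
toy_logits = np.array([[2.0, 1.0, 0.1, -1.0],
                       [0.5, 0.5, 3.0, 0.0]])
toy_smx = softmax(toy_logits)

print(toy_smx.sum(axis=1))     # each row sums to 1
print(toy_smx.argmax(axis=1))  # most likely class per image: [0 2]
```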

For this example, we are going to use the nonconformity score $$s(x, y) = 1-\hat{f}(x)_y,$$ i.e., one minus the softmax score that the model assigns to class $y$. Evaluated at the *true* label $y$, it can be interpreted as the probability, *according to* the model $\hat{f}$, that image $x$ does *not* belong to its true class.
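On a small hypothetical example, the score $s(x, y)$ for every image at once can be computed with NumPy integer-array indexing, which pairs row `i` with column `labels[i]`. The toy arrays below are made up for illustration.

```python
import numpy as np

# Hypothetical softmax scores for 3 images and K = 3 classes.
toy_smx = np.array([[0.7, 0.2, 0.1],
                    [0.1, 0.3, 0.6],
                    [0.3, 0.3, 0.4]])
toy_labels = np.array([0, 1, 2])  # hypothetical true labels

# Pair row i with column toy_labels[i] to pick out each image's true-class score.
true_class_scores = toy_smx[np.arange(len(toy_labels)), toy_labels]
toy_scores = 1 - true_class_scores
print(toy_scores)  # approximately [0.3 0.7 0.6]
```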

**Think.** Why does this nonconformity score make sense?

In [None]:
# EXERCISE: Set the parameters
# alpha controls the error rate we're willing to accept
# alpha = 0.1 means we want our prediction sets to contain the true label at least 90% of the time
alpha = # TODO: Set alpha to 0.1 for 90% coverage
m = # TODO: Set m to 1000 (number of calibration points)

In [None]:
# EXERCISE: Split the data into calibration and test sets
# Create a Boolean mask for selecting calibration points
idx = # TODO: Create a Boolean array with m True values and the rest False
np.random.shuffle(idx) # shuffle to randomly select calibration points

# Split the data using the mask
smx_cal, smx_te = # TODO: Split smx into calibration and test sets using idx
labels_cal, labels_te = # TODO: Split labels into calibration and test sets using idx
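For reference, here is one way to form such a random split, sketched on a hypothetical set of 10 points with 4 calibration points. The names `n_toy`, `m_toy`, and `points` are illustrative, and the seeded `np.random.default_rng` generator is an alternative to the global `np.random.shuffle` used above.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded generator, so the split is reproducible
n_toy, m_toy = 10, 4  # hypothetical dataset size and calibration size

# Boolean mask with exactly m_toy True entries, shuffled into random positions.
mask = np.zeros(n_toy, dtype=bool)
mask[:m_toy] = True
rng.shuffle(mask)

points = np.arange(n_toy)  # stand-in for the rows of smx and labels
cal_toy, test_toy = points[mask], points[~mask]
print(cal_toy, test_toy)
```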

In [None]:
# EXERCISE: Compute the nonconformity scores for the calibration set
# Hint: For each calibration point, the score is 1 minus the softmax score for the true class
S_cal = # TODO:

# Sort the scores
S_cal = np.sort(S_cal)

# EXERCISE: Find the threshold (quantile) that ensures the desired coverage
qhat = # TODO: Find the appropriate quantile of S_cal to ensure 1-alpha coverage

print(f"Threshold (qhat): {qhat:.4f}")

Let's visualize the distribution of nonconformity scores and our chosen threshold:

In [None]:
# Visualize the distribution of nonconformity scores
plt.figure(figsize=(10, 5))
plt.hist(S_cal, bins=30, alpha=0.7)
plt.axvline(qhat, color='red', linestyle='--', label=f'Threshold (qhat = {qhat:.3f})')
plt.xlabel('Nonconformity Score (1 - softmax score of true class)')
plt.ylabel('Frequency')
plt.title('Distribution of Nonconformity Scores in Calibration Set')
plt.legend()
plt.show()

Now, applying the same idea as discussed in the previous class, for a test image $X'$ with *unknown* label $Y'$ we have $$\mathbb{P}\left\{S' = 1-\hat{f}(X')_{Y'} \leq \hat{q}_{1-\alpha} \right\} \geq 1-\alpha.$$ Because $S' \leq \hat{q}_{1-\alpha}$ holds exactly when $\hat{f}(X')_{Y'} \geq 1-\hat{q}_{1-\alpha}$, the set $$\hat{C}(x) = \left\{y: \hat{f}(x)_y \geq 1-\hat{q}_{1-\alpha}\right\}$$ contains $Y'$ exactly on that event and therefore satisfies $$\mathbb{P}\left\{Y' \in \hat{C}(X')\right\} \geq 1-\alpha.$$ This is how we construct prediction sets in this case.
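A minimal sketch of the two steps, assuming $\hat{q}_{1-\alpha}$ is the usual split-conformal quantile, namely the $\lceil (m+1)(1-\alpha) \rceil$-th smallest calibration score (compare with the definition from the previous class). The helper names `conformal_quantile` and `prediction_sets` and the toy numbers are hypothetical.

```python
import numpy as np

def conformal_quantile(cal_scores, alpha):
    """Return the k-th smallest calibration score, k = ceil((m + 1) * (1 - alpha))."""
    m_cal = len(cal_scores)
    k = int(np.ceil((m_cal + 1) * (1 - alpha)))
    if k > m_cal:
        # alpha < 1/(m_cal + 1): no finite threshold works, so every class is included.
        return np.inf
    return np.sort(cal_scores)[k - 1]  # k is 1-indexed, arrays are 0-indexed

def prediction_sets(smx_test, qhat):
    """Boolean matrix whose entry (i, j) is True when class j is in the set for image i."""
    return smx_test >= 1 - qhat

# Hypothetical calibration scores (m = 9) and softmax rows for 2 test images (K = 3).
cal_scores_toy = np.array([0.05, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.6, 0.9])
qhat_toy = conformal_quantile(cal_scores_toy, alpha=0.2)  # k = ceil(10 * 0.8) = 8
sets_toy = prediction_sets(np.array([[0.8, 0.15, 0.05],
                                     [0.5, 0.42, 0.08]]), qhat_toy)
print(qhat_toy)              # 0.6
print(sets_toy.sum(axis=1))  # set sizes: [1 2]
```

Returning `np.inf` when $k > m$ makes the threshold $1-\hat{q}$ equal to $-\infty$, so every class is included; that trivial set is the only one that keeps the guarantee when $\alpha < 1/(m+1)$.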

In [None]:
# EXERCISE: Construct the prediction sets for the test data
Chat = # TODO: Create a Boolean matrix where entry (i,j) is True if class j is in the prediction set for test point i

In [None]:
# EXERCISE: Check the empirical coverage
# Hint: The true label is covered if the corresponding entry in Chat is True
empirical_coverage = # TODO:
print(f"The empirical coverage is: {empirical_coverage:.4f}")
print(f"The target coverage was: {1-alpha:.4f}")
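To see the guarantee at work independently of ImageNet, the sketch below simulates a setting where the model is perfectly calibrated by construction: softmax rows are drawn from a Dirichlet distribution and each label is sampled from its own row. All sizes and names (`n_sim`, `K_sim`, `m_sim`, `alpha_sim`, and so on) are hypothetical and chosen so they do not overwrite the worksheet's variables. Rerunning with other seeds should give coverage that fluctuates around $1-\alpha$.

```python
import numpy as np

rng = np.random.default_rng(0)
n_sim, K_sim, m_sim, alpha_sim = 11000, 10, 1000, 0.1  # hypothetical sizes

# Synthetic "softmax" rows; each label is sampled from its own row via the inverse CDF,
# so all points are i.i.d. (hence exchangeable). np.minimum guards against rounding.
probs = rng.dirichlet(np.ones(K_sim), size=n_sim)
u = rng.random((n_sim, 1))
labels_sim = np.minimum((probs.cumsum(axis=1) < u).sum(axis=1), K_sim - 1)

# Random calibration/test split.
perm = rng.permutation(n_sim)
cal_idx, test_idx = perm[:m_sim], perm[m_sim:]

# Nonconformity scores and the conformal quantile on the calibration points.
scores_cal = 1 - probs[cal_idx, labels_sim[cal_idx]]
k = int(np.ceil((m_sim + 1) * (1 - alpha_sim)))
qhat_sim = np.sort(scores_cal)[k - 1]

# Prediction sets and empirical coverage on the test points.
sets_sim = probs[test_idx] >= 1 - qhat_sim
coverage_sim = sets_sim[np.arange(len(test_idx)), labels_sim[test_idx]].mean()
print(f"Empirical coverage: {coverage_sim:.3f} (target {1 - alpha_sim:.2f})")
```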

In [None]:
# Analyze the sizes of prediction sets
set_sizes = Chat.sum(axis=1)
avg_size = set_sizes.mean()
plt.figure(figsize=(10, 5))
plt.hist(set_sizes, bins=30, alpha=0.7)
plt.axvline(avg_size, color='red', linestyle='--', label=f'Average size: {avg_size:.2f}')
plt.xlabel('Prediction Set Size')
plt.ylabel('Frequency')
plt.title('Distribution of Prediction Set Sizes')
plt.legend()
plt.show()
print(f"Average prediction set size: {avg_size:.2f} out of 1000 possible classes")
print(f"Min set size: {set_sizes.min()}, Max set size: {set_sizes.max()}")

Now, let's look at some examples and the prediction sets that conformal prediction produced for them.

In [None]:
# Load the human-readable labels
with open('../data/imagenet/human_readable_labels.json') as f:
    label_strings = np.array(json.load(f))

# EXERCISE: Complete the function to display an example image and its prediction set
def display_example(img_index):
    # Load the image
    img_path = f'../data/imagenet/examples/{img_index}.JPEG' 
    img = imread(img_path)
    
    # Get the true label and prediction set
    true_label = # TODO: Get the true label for this image
    prediction_set = # TODO: Determine which classes are in the prediction set
    
    # Display the image
    plt.figure()
    plt.imshow(img)
    plt.axis('off')
    plt.show()
    
    # Print information
    print("The true label is:", str(label_strings[true_label]))
    print(f"Prediction set size: {np.sum(prediction_set)} classes")
    print("The prediction set includes: %s" % ", ".join(map(str, list(label_strings[prediction_set]))))
    
    # Return whether the true label is in the prediction set
    return prediction_set[true_label]

# Display a few random examples
example_paths = os.listdir('../data/imagenet/examples')
for i in range(5):
    rand_path = np.random.choice(example_paths)
    img_index = int(rand_path.split('.')[0])
    covered = display_example(img_index)
    print(f"True label covered: {covered}")
    print("-" * 80)

## EXERCISE: Experiment with different alpha values

Try changing alpha to different values (e.g., 0.01, 0.05, 0.2) and observe how it affects:
1. The threshold (qhat)
2. The empirical coverage
3. The sizes of the prediction sets
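As a hint for (1), assuming $\hat{q}$ is the $\lceil (m+1)(1-\alpha) \rceil$-th smallest calibration score, the sketch below shows which rank is used for each alpha with $m = 1000$. A smaller alpha picks a higher rank, hence a larger $\hat{q}$, a lower threshold $1-\hat{q}$, and larger prediction sets. The names `m_demo`, `alpha_demo`, and `ks_demo` are hypothetical and chosen so they do not overwrite the worksheet's `m` and `alpha`.

```python
import numpy as np

m_demo = 1000  # calibration size used in this worksheet
ks_demo = []
for alpha_demo in [0.01, 0.05, 0.1, 0.2]:
    # Rank (1-indexed) of the calibration score used as qhat.
    k = int(np.ceil((m_demo + 1) * (1 - alpha_demo)))
    ks_demo.append(k)
    print(f"alpha = {alpha_demo}: qhat is the {k}-th smallest of {m_demo} scores")

# k <= m_demo (so qhat is a finite calibration score) only when alpha >= 1/(m_demo + 1).
print(f"Smallest alpha with a finite qhat: {1 / (m_demo + 1):.5f}")
```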

Complete the function below to rerun the conformal prediction with a new alpha value:

In [None]:
def run_conformal_prediction(alpha, m=1000):
    # Split the data
    idx = np.array([1] * m + [0] * (smx.shape[0]-m)) > 0
    np.random.shuffle(idx)
    smx_cal, smx_te = smx[idx,:], smx[~idx,:]
    labels_cal, labels_te = labels[idx], labels[~idx]
    
    # Compute nonconformity scores
    S_cal = 1 - smx_cal[np.arange(m), labels_cal]
    S_cal = np.sort(S_cal)
    
    # Find the threshold
    qhat = # TODO: Calculate the threshold
    
    # Construct prediction sets
    Chat = # TODO: Construct prediction sets
    
    # Check empirical coverage
    empirical_coverage = # TODO: Calculate the empirical coverage
    
    # Calculate set sizes
    set_sizes = Chat.sum(axis=1)
    
    return {
        'alpha': alpha,
        'target_coverage': 1-alpha,
        'qhat': qhat,
        'empirical_coverage': empirical_coverage,
        'avg_set_size': set_sizes.mean(),
        'min_set_size': set_sizes.min(),
        'max_set_size': set_sizes.max()
    }

# Try different alpha values
results = []
for alpha in [0.01, 0.05, 0.1, 0.2]:
    result = run_conformal_prediction(alpha)
    results.append(result)
    print(f"Alpha: {alpha}, Target coverage: {1-alpha:.2f}")
    print(f"Empirical coverage: {result['empirical_coverage']:.4f}")
    print(f"Average set size: {result['avg_set_size']:.2f}")
    print("-" * 40)

# Plot the relationship between coverage and set size
plt.figure(figsize=(10, 5))
coverages = [r['target_coverage'] for r in results]
set_sizes = [r['avg_set_size'] for r in results]
plt.plot(coverages, set_sizes, 'o-')
plt.xlabel('Target Coverage (1-alpha)')
plt.ylabel('Average Prediction Set Size')
plt.title('Trade-off between Coverage and Prediction Set Size')
plt.grid(True)
plt.show()

## Summary

In this tutorial, we demonstrated how to apply conformal prediction to image classification:

1. **Problem Setup**: We worked with the outputs of a pre-trained ResNet-152 model on ImageNet, which gave us softmax scores for 1000 classes.

2. **Nonconformity Score**: We used a simple nonconformity score: 1 minus the softmax score for the true class, representing the model's estimated probability that the image does not belong to its true class.

3. **Calibration Process**: We:
   - Split the data into calibration and test sets
   - Computed nonconformity scores for the calibration set
   - Found the threshold (quantile) that ensures the desired coverage level

4. **Prediction Sets**: For each test image, we included every class whose softmax score is at least $1-\hat{q}_{1-\alpha}$.

5. **Coverage Guarantee**: We checked that the empirical coverage on the test set was close to the target level $1-\alpha$.

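6. **LABEL Method**: This approach, LABEL (Least Ambiguous set-valued classifiers with Bounded Error Levels) from Sadinle et al. (2016), is simple but effective: its prediction sets contain the true label with probability at least $1-\alpha$, and Sadinle et al. show that if the softmax scores were the true class probabilities, thresholding them in this way would give the smallest expected set size among set-valued classifiers with that coverage.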

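7. **Limitations**: The coverage guarantee is marginal: it holds on average over test images, not for each image individually. Because every image is compared against the same threshold, the method does not adapt to the difficulty of individual examples and can under-cover hard ones; more advanced conformal prediction methods address this.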

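The key benefit of conformal prediction is that it provides statistically valid uncertainty quantification for any underlying model, requiring only that calibration and test data be exchangeable. The model's quality affects how large the prediction sets are, not whether the coverage guarantee holds. This lets us control the error rate of set-valued predictions, which is essential in high-stakes applications.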

## Discussion Questions

1. How does the size of the prediction sets relate to the model's confidence in its predictions? How would you expect the prediction set sizes to change if we used a less accurate model?

2. What happens to the prediction sets as we change the coverage level (1-alpha)? What are the trade-offs involved in choosing different values of alpha?

3. In what real-world applications would having prediction sets (rather than single predictions) be particularly valuable? When might it be less useful?

4. The nonconformity score we used (1 - softmax score for the true class) doesn't adapt to the difficulty of each example. Can you think of alternative nonconformity scores that might better account for example difficulty?

5. How does the calibration set size affect the properties of our prediction sets? What would happen if we used a very small or very large calibration set?

6. Conformal prediction is model-agnostic. How would the process change if we wanted to use a different base model (e.g., ViT instead of ResNet)?

7. How does conformal prediction compare to other uncertainty quantification methods you might be familiar with (e.g., Bayesian methods, ensemble methods)?

8. Can you think of ways to visualize or interpret the prediction sets to gain insights about the model's behavior and limitations?