In [1]:
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from construct_model import construct_model, register_models
from constants import C10_CLASSES_DICT, C100_CLASSES_DICT

In [2]:
register_models()

In [3]:
DATASETS_PATH = "datasets"
DATASET = "CIFAR100"

In [4]:
c10_test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]
        ),  # normalize the cifar10 images
    ]
)

c10_test_dataset = datasets.CIFAR10(
    root=DATASETS_PATH,
    train=False,
    download=True,
    transform=c10_test_transform,
)

Files already downloaded and verified


In [5]:
# cifar 100
c100_test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761]
        ),  # normalize the cifar100 images
    ]
)

c100_test_dataset = datasets.CIFAR100(
    root=DATASETS_PATH,
    train=False,
    download=True,
    transform=c100_test_transform,
)

Files already downloaded and verified


In [6]:
if DATASET == "CIFAR100":
    test_dataset = c100_test_dataset
    classes_dict = C100_CLASSES_DICT
    checkpoint_path = "resnet18_cifar100.pth"
    model_name = "resnet18_cifar"
    num_of_classes = 100
elif DATASET == "CIFAR10":
    test_dataset = c10_test_dataset
    classes_dict = C10_CLASSES_DICT
    checkpoint_path = "resnet18_cifar10.pth"
    model_name = "resnet18_cifar"
    num_of_classes = 10

In [7]:
# Get the indices of the test_dataset
indices = list(range(len(test_dataset)))

# Split the indices into test_indices and validation_indices
test_indices, validation_indices = train_test_split(
    indices, test_size=0.5, random_state=42, stratify=test_dataset.targets
)

# Create validation and test datasets
validation_dataset = Subset(test_dataset, validation_indices)
test_dataset = Subset(test_dataset, test_indices)

print(f"Size of validation dataset: {len(validation_dataset)}")

calibration_dl = DataLoader(
    validation_dataset, batch_size=256, shuffle=False, pin_memory=True
)
test_dl = DataLoader(test_dataset, batch_size=256, shuffle=False, pin_memory=True)

Size of validation dataset: 5000


In [8]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [9]:
resnet18 = construct_model(
    model_name, num_classes=num_of_classes, checkpoint_path=checkpoint_path
).to(device)

In [10]:
# Initialize counters
class_correct = [0 for _ in range(num_of_classes)]
class_total = [0 for _ in range(num_of_classes)]

resnet18.eval()
with torch.no_grad():
    for images, labels in test_dl:
        images, labels = images.to(device), labels.to(device)
        outputs = resnet18(images)
        _, predicted = torch.max(outputs.data, 1)
        c = (predicted == labels).squeeze()
        for i in range(len(labels)):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1

# Print accuracy for each class
for i in range(num_of_classes):
    print(
        f"Accuracy of {classes_dict[i]}: {100 * class_correct[i] / class_total[i]:.2f}%"
    )

Accuracy of apple: 90.00%
Accuracy of aquarium_fish: 86.00%
Accuracy of baby: 72.00%
Accuracy of bear: 52.00%
Accuracy of beaver: 66.00%
Accuracy of bed: 84.00%
Accuracy of bee: 66.00%
Accuracy of beetle: 72.00%
Accuracy of bicycle: 94.00%
Accuracy of bottle: 90.00%
Accuracy of bowl: 60.00%
Accuracy of boy: 38.00%
Accuracy of bridge: 86.00%
Accuracy of bus: 60.00%
Accuracy of butterfly: 72.00%
Accuracy of camel: 84.00%
Accuracy of can: 74.00%
Accuracy of castle: 82.00%
Accuracy of caterpillar: 68.00%
Accuracy of cattle: 62.00%
Accuracy of chair: 96.00%
Accuracy of chimpanzee: 86.00%
Accuracy of clock: 78.00%
Accuracy of cloud: 94.00%
Accuracy of cockroach: 82.00%
Accuracy of couch: 66.00%
Accuracy of crab: 52.00%
Accuracy of crocodile: 62.00%
Accuracy of cup: 88.00%
Accuracy of dinosaur: 78.00%
Accuracy of dolphin: 68.00%
Accuracy of elephant: 78.00%
Accuracy of flatfish: 64.00%
Accuracy of forest: 64.00%
Accuracy of fox: 82.00%
Accuracy of girl: 62.00%
Accuracy of hamster: 90.00%
Accu

In [11]:
def calculate_class_coverage(
    model: torch.nn.Module,
    test_dl: DataLoader,
    class_thresholds: dict[int, float],
    num_of_classes: int,
    classes_dict: dict[int, str],
):
    class_correct = [0 for _ in range(num_of_classes)]
    class_total = [0 for _ in range(num_of_classes)]
    set_sizes = []

    model.eval()
    for images, labels in test_dl:
        images, labels = images.to(device), labels.to(device)
        predicted_prob = model(images).softmax(dim=1)
        scores = 1 - predicted_prob

        sets = scores <= torch.tensor(
            [class_thresholds[i] for i in range(num_of_classes)], device=device
        )
        for i in range(len(labels)):
            label = labels[i].item()
            class_total[label] += 1
            if sets[i, label]:
                class_correct[label] += 1
        set_sizes.extend(sets.sum(dim=1).cpu().numpy())

    # Print coverage for each class
    for i in range(num_of_classes):
        if class_total[i] > 0:
            print(
                f"Coverage of {classes_dict[i]}: {100 * class_correct[i] / class_total[i]:.2f}%"
            )
    print(f"Average coverage: {100 * np.sum(class_correct) / np.sum(class_total):.2f}%")
    print(f"Average set size: {np.mean(set_sizes):.2f}")

In [12]:
# calibrate the model
conformity_scores = []

for images, labels in calibration_dl:
    images, labels = images.to(device), labels.to(device)
    predicted_prob = resnet18(images).softmax(dim=1)
    scores = 1 - predicted_prob[range(len(labels)), labels]

    conformity_scores.extend(scores.detach().cpu().numpy())

In [13]:
print(conformity_scores[:5])

[0.0017024875, 0.00048708916, 0.001824975, 0.2402181, 0.025772631]


In [14]:
COVERAGE = 0.95

quantile = COVERAGE * ((len(calibration_dl.dataset) + 1) / len(calibration_dl.dataset))
print(f"Quantile: {quantile}")
threshold = np.quantile(conformity_scores, quantile)
print(f"Threshold: {threshold}")

single_threshold = {i: threshold for i in range(num_of_classes)}

Quantile: 0.95019
Threshold: 0.9948265487289428


In [15]:
calculate_class_coverage(
    resnet18, test_dl, single_threshold, num_of_classes, classes_dict
)

Coverage of apple: 100.00%
Coverage of aquarium_fish: 98.00%
Coverage of baby: 90.00%
Coverage of bear: 88.00%
Coverage of beaver: 94.00%
Coverage of bed: 94.00%
Coverage of bee: 94.00%
Coverage of beetle: 98.00%
Coverage of bicycle: 100.00%
Coverage of bottle: 96.00%
Coverage of bowl: 90.00%
Coverage of boy: 88.00%
Coverage of bridge: 96.00%
Coverage of bus: 90.00%
Coverage of butterfly: 92.00%
Coverage of camel: 94.00%
Coverage of can: 98.00%
Coverage of castle: 94.00%
Coverage of caterpillar: 90.00%
Coverage of cattle: 86.00%
Coverage of chair: 98.00%
Coverage of chimpanzee: 96.00%
Coverage of clock: 94.00%
Coverage of cloud: 98.00%
Coverage of cockroach: 98.00%
Coverage of couch: 94.00%
Coverage of crab: 98.00%
Coverage of crocodile: 94.00%
Coverage of cup: 96.00%
Coverage of dinosaur: 94.00%
Coverage of dolphin: 90.00%
Coverage of elephant: 98.00%
Coverage of flatfish: 90.00%
Coverage of forest: 90.00%
Coverage of fox: 96.00%
Coverage of girl: 86.00%
Coverage of hamster: 96.00%
Co

## Class balanced conformal prediction

In [16]:
# Initialize conformity scores list for each class
class_conformity_scores = {i: [] for i in range(num_of_classes)}

for images, labels in calibration_dl:
    images, labels = images.to(device), labels.to(device)
    predicted_prob = resnet18(images).softmax(dim=1)

    for i in range(len(labels)):
        class_conformity_scores[labels[i].item()].append(
            1 - predicted_prob[i, labels[i]].item()
        )

In [17]:
CLASS_THRESHOLD = 0.95

class_quantiles = {
    i: CLASS_THRESHOLD
    * ((len(class_conformity_scores[i]) + 1) / len(class_conformity_scores[i]))
    for i in range(num_of_classes)
}

class_thresholds = {
    i: np.quantile(class_conformity_scores[i], class_quantiles[i])
    for i in range(num_of_classes)
}

print(class_thresholds)

{0: 0.9659874893398956, 1: 0.8374806193709374, 2: 0.993297046628315, 3: 0.9968778236630605, 4: 0.9989432904410059, 5: 0.9593531790254638, 6: 0.9718274052161724, 7: 0.9993298088825541, 8: 0.9527631902229041, 9: 0.9931393048067112, 10: 0.99926412077251, 11: 0.998794112010859, 12: 0.986573983468581, 13: 0.9972474765734515, 14: 0.997645575151546, 15: 0.9625182339809836, 16: 0.9974929474468809, 17: 0.964166000502184, 18: 0.9967109735913109, 19: 0.994714277726598, 20: 0.9981684394085314, 21: 0.9730360584463924, 22: 0.9925736276605167, 23: 0.9959616230572574, 24: 0.9932541564456188, 25: 0.9968068491402082, 26: 0.9943388196822344, 27: 0.9886328814402222, 28: 0.9924112147563137, 29: 0.9985804146205192, 30: 0.9953321123765781, 31: 0.9950303460296709, 32: 0.989819327079691, 33: 0.9936385888806545, 34: 0.9903942774292082, 35: 0.9985304254308576, 36: 0.9892522467491217, 37: 0.9851930956458673, 38: 0.9958755964385345, 39: 0.9919543799958191, 40: 0.9961062597543933, 41: 0.9663613795004785, 42: 0.9974

In [18]:
calculate_class_coverage(
    resnet18, test_dl, class_thresholds, num_of_classes, classes_dict
)

Coverage of apple: 96.00%
Coverage of aquarium_fish: 90.00%
Coverage of baby: 90.00%
Coverage of bear: 92.00%
Coverage of beaver: 94.00%
Coverage of bed: 92.00%
Coverage of bee: 88.00%
Coverage of beetle: 100.00%
Coverage of bicycle: 96.00%
Coverage of bottle: 94.00%
Coverage of bowl: 92.00%
Coverage of boy: 98.00%
Coverage of bridge: 96.00%
Coverage of bus: 92.00%
Coverage of butterfly: 94.00%
Coverage of camel: 94.00%
Coverage of can: 100.00%
Coverage of castle: 88.00%
Coverage of caterpillar: 94.00%
Coverage of cattle: 86.00%
Coverage of chair: 100.00%
Coverage of chimpanzee: 92.00%
Coverage of clock: 94.00%
Coverage of cloud: 98.00%
Coverage of cockroach: 98.00%
Coverage of couch: 98.00%
Coverage of crab: 98.00%
Coverage of crocodile: 94.00%
Coverage of cup: 96.00%
Coverage of dinosaur: 98.00%
Coverage of dolphin: 90.00%
Coverage of elephant: 98.00%
Coverage of flatfish: 86.00%
Coverage of forest: 90.00%
Coverage of fox: 94.00%
Coverage of girl: 92.00%
Coverage of hamster: 96.00%
C