In [16]:
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset

In [17]:
DATASETS_PATH = "datasets"

CLASSES_DICT = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck",
}

In [18]:
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]
        ),  # normalize the cifar10 images
    ]
)

test_dataset = datasets.CIFAR10(
    root=DATASETS_PATH,
    train=False,
    download=True,
    transform=test_transform,
)

# Get the indices of the test_dataset
indices = list(range(len(test_dataset)))

# Split the indices into test_indices and validation_indices
test_indices, validation_indices = train_test_split(
    indices, test_size=0.5, random_state=42, stratify=test_dataset.targets
)

# Create validation and test datasets
validation_dataset = Subset(test_dataset, validation_indices)
test_dataset = Subset(test_dataset, test_indices)

calibration_dl = DataLoader(
    validation_dataset, batch_size=256, shuffle=False, pin_memory=True
)
test_dl = DataLoader(test_dataset, batch_size=256, shuffle=False, pin_memory=True)

Files already downloaded and verified


In [19]:
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, self.expansion * planes, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def resnet18_cifar(num_classes: int, **kwargs):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)

In [20]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [21]:
resnet18 = resnet18_cifar(10).to(device)
resnet18.load_state_dict(torch.load("resnet18_cifar10.pth"))

<All keys matched successfully>

In [22]:
# test the accuracy
resnet18.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_dl:
        images, labels = images.to(device), labels.to(device)
        outputs = resnet18(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy of the model on the test images: {100 * correct / total}%")

Accuracy of the model on the test images: 94.48%


In [23]:
def calculate_class_coverage(model, test_dl, class_thresholds):
    class_correct = [0 for _ in range(10)]
    class_total = [0 for _ in range(10)]
    set_sizes = []

    model.eval()
    for images, labels in test_dl:
        images, labels = images.to(device), labels.to(device)
        predicted_prob = model(images).softmax(dim=1)
        scores = 1 - predicted_prob

        sets = scores <= torch.tensor(
            [class_thresholds[i] for i in range(10)], device=device
        )
        for i in range(len(labels)):
            label = labels[i].item()
            class_total[label] += 1
            if sets[i, label]:
                class_correct[label] += 1
        set_sizes.extend(sets.sum(dim=1).cpu().numpy())

    # Print coverage for each class
    for i in range(10):
        if class_total[i] > 0:
            print(
                f"Coverage of {CLASSES_DICT[i]}: {100 * class_correct[i] / class_total[i]:.2f}%"
            )
    print(f"Average coverage: {100 * np.sum(class_correct) / np.sum(class_total):.2f}%")
    print(f"Average set size: {np.mean(set_sizes):.2f}")

In [24]:
# calibrate the model
conformity_scores = []

for images, labels in calibration_dl:
    images, labels = images.to(device), labels.to(device)
    predicted_prob = resnet18(images).softmax(dim=1)
    scores = 1 - predicted_prob[range(len(labels)), labels]

    conformity_scores.extend(scores.detach().cpu().numpy())

In [25]:
print(conformity_scores[:5])

[0.9940668, 0.00043052435, 0.99718106, 0.00038987398, 0.00040358305]


In [26]:
COVERAGE = 0.95

quantile = COVERAGE * ((len(calibration_dl.dataset) + 1) / len(calibration_dl.dataset))
print(f"Quantile: {quantile}")
threshold = np.quantile(conformity_scores, quantile)
print(f"Threshold: {threshold}")

single_threshold = {i: threshold for i in range(10)}

Quantile: 0.95019
Threshold: 0.6108211415642502


In [27]:
calculate_class_coverage(resnet18, test_dl, single_threshold)

Coverage of airplane: 95.80%
Coverage of automobile: 97.80%
Coverage of bird: 92.60%
Coverage of cat: 89.40%
Coverage of deer: 96.20%
Coverage of dog: 89.80%
Coverage of frog: 96.40%
Coverage of horse: 96.60%
Coverage of ship: 97.20%
Coverage of truck: 95.60%
Average coverage: 94.74%
Average set size: 1.00


## Class balanced conformal prediction

In [28]:
# Initialize conformity scores list for each class
class_conformity_scores = {i: [] for i in range(10)}

for images, labels in calibration_dl:
    images, labels = images.to(device), labels.to(device)
    predicted_prob = resnet18(images).softmax(dim=1)

    for i in range(len(labels)):
        class_conformity_scores[labels[i].item()].append(
            1 - predicted_prob[i, labels[i]].item()
        )

In [29]:
CLASS_THRESHOLD = 0.95

class_quantiles = {
    i: CLASS_THRESHOLD
    * ((len(class_conformity_scores[i]) + 1) / len(class_conformity_scores[i]))
    for i in range(10)
}

class_thresholds = {
    i: np.quantile(class_conformity_scores[i], class_quantiles[i]) for i in range(10)
}

print(class_thresholds)

{0: 0.2820496103167527, 1: 0.09507762341499224, 2: 0.9458017223857336, 3: 0.9905364011596888, 4: 0.5570856909364452, 5: 0.871548156116902, 6: 0.7361661282956528, 7: 0.3156715995252131, 8: 0.49610651172994913, 9: 0.2510034379839891}


In [30]:
calculate_class_coverage(resnet18, test_dl, class_thresholds)

Coverage of airplane: 94.40%
Coverage of automobile: 95.60%
Coverage of bird: 95.20%
Coverage of cat: 97.00%
Coverage of deer: 95.80%
Coverage of dog: 92.20%
Coverage of frog: 97.20%
Coverage of horse: 95.60%
Coverage of ship: 96.80%
Coverage of truck: 94.00%
Average coverage: 95.38%
Average set size: 1.05
