In [23]:
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset

In [3]:
DATASETS_PATH = "datasets"

In [120]:
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]
        ),  # normalize the cifar10 images
    ]
)

test_dataset = datasets.CIFAR10(
    root=DATASETS_PATH,
    train=False,
    download=True,
    transform=test_transform,
)

# Get the indices of the test_dataset
indices = list(range(len(test_dataset)))

# Split the indices into test_indices and validation_indices
test_indices, validation_indices = train_test_split(
    indices, test_size=0.5, random_state=42
)

# Create validation and test datasets
validation_dataset = Subset(test_dataset, validation_indices)
test_dataset = Subset(test_dataset, test_indices)

calibration_dl = DataLoader(
    validation_dataset, batch_size=256, shuffle=False, pin_memory=True
)
test_dl = DataLoader(test_dataset, batch_size=256, shuffle=False, pin_memory=True)

Files already downloaded and verified


In [121]:
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, self.expansion * planes, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def resnet18_cifar(num_classes: int, **kwargs):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)

In [133]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [134]:
resnet18 = resnet18_cifar(10).to(device)
resnet18.load_state_dict(torch.load("resnet18_cifar10.pth"))

<All keys matched successfully>

In [135]:
# test the accuracy
resnet18.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_dl:
        images, labels = images.to(device), labels.to(device)
        outputs = resnet18(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy of the model on the test images: {100 * correct / total}%")

Accuracy of the model on the test images: 94.48%


In [136]:
# calibrate the model
conformity_scores = []

for images, labels in calibration_dl:
    images, labels = images.to(device), labels.to(device)
    predicted_prob = resnet18(images).softmax(dim=1)
    scores = 1 - predicted_prob[range(len(labels)), labels]

    conformity_scores.extend(scores.detach().cpu().numpy())

In [137]:
print(conformity_scores[:5])

[0.00055629015, 0.0005043745, 0.4528237, 0.00020855665, 0.3005407]


In [138]:
COVERAGE = 0.96

quantile = COVERAGE * ((len(calibration_dl.dataset) + 1) / len(calibration_dl.dataset))
print(f"Quantile: {quantile}")
threshold = np.quantile(conformity_scores, quantile)
print(f"Threshold: {threshold}")

Quantile: 0.9601919999999999
Threshold: 0.8349157995147701


In [None]:
def calculate_coverage(scores, threshold):
    return np.mean(scores > threshold)

In [140]:
# test the conformal prediction sets

classes_dict = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck",
}

resnet18.eval()
for images, labels in test_dl:
    images, labels = images.to(device), labels.to(device)
    predicted_prob = resnet18(images).softmax(dim=1)
    scores = 1 - predicted_prob

    sets = scores <= threshold
    correct = 0
    total = 0
    for i in range(len(labels)):
        total += 1
        if sets[i, labels[i]]:
            correct += 1

print(f"Coverage of the model on the test images: {100 * correct / total:.2f}%")

Coverage of the model on the test images: 96.32%
