In [None]:
!pip install git+https://github.com/RobustBench/robustbench.git

In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision 
import torch.utils.data as torch_data 
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import random_split
from torch.utils.data import TensorDataset, random_split
from robustbench.data import load_cifar10c
import copy
import json
from collections import defaultdict

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

Using device: cuda


# Test Time Adaptation with Hard Pseudolabeling

First initialize neural network architecture, we start with the `BasicBlock` class, which consists of two convolutional layers and a shortcut connection.

In [2]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

The `ResNet` class stacks these blocks in layers to create a deep neural network, which will be used to train our source classifier on the CIFAR-10 dataset

In [3]:
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, feature_maps=False):
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out = self.layer4(out3)
        out = F.avg_pool2d(out, 4)
        out4 = out.view(out.size(0), -1)
        out = self.linear(out4)

        if feature_maps:
            return out, [out1, out2, out3, out4]
        else:
            return out

To improve training stability, we normalize the input images with mean and standard deviation values specific to the CIFAR-10 dataset, the `Normalized_ResNet` class extends `ResNet` through this normalization

In [4]:
class Normalized_ResNet(ResNet):
    def __init__(self, device="cuda", depth=18, num_classes=10):
        if depth == 18:
            super(Normalized_ResNet, self).__init__(BasicBlock, [2,2,2,2], num_classes)
        elif depth == 50:
            super(Normalized_ResNet, self).__init__(Bottleneck, [3,4,6,3], num_classes)
        elif depth == 26:
            super(Normalized_ResNet, self).__init__(BasicBlock, [3,3,3,3], num_classes)
        else:
            pass 

        self.mu = torch.Tensor([0.4914, 0.4822, 0.4465]).float().view(3, 1, 1).to(device)
        self.sigma = torch.Tensor([0.2023, 0.1994, 0.2010]).float().view(3, 1, 1).to(device)

    def forward(self, x, feature_maps=False):
        x = (x - self.mu) / self.sigma
        return super(Normalized_ResNet, self).forward(x, feature_maps)

We begin by initializing our neural network, as well as our data augmentation stack, which involves random cropping and horizontal flipping of the images

In [5]:
save_path = 'saved_models/pretrained/trained_model.pth'
net = Normalized_ResNet(depth=26)
net.to(device)
# net = torch.nn.DataParallel(net)
cudnn.benchmark = True

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

We now load the CIFAR-10 dataset, as well as take a subset of size 5000 of the dataset as a proof of concept, creating training and testing sets of 5000 images each

In [7]:
train_data = torchvision.datasets.CIFAR10("./data", True, transform=train_transform, download=True)
test_data = torchvision.datasets.CIFAR10("./data", False, transform=transforms.Compose([transforms.ToTensor()]), download=True)
subset_size = 5000
train_data, _ = random_split(train_data, [subset_size, len(train_data) - subset_size])
test_data, _ = random_split(test_data, [subset_size, len(test_data) - subset_size])

train_loader = torch_data.DataLoader(train_data, batch_size=256, shuffle=True, num_workers=4)
test_loader = torch_data.DataLoader(test_data, batch_size=256, shuffle=False, num_workers=4)

Files already downloaded and verified
Files already downloaded and verified


We now initialize our optimizer, scheduler, and loss function:
- **Optimizer**: We use SGD with momentum, which helps stabilize the training
- **Scheduler**: Cosine Annealing is applied to the learning rate, which gradually reduces it to optimize training speed
- **Loss Function**: Cross-Entropy Loss is used, which is standard for multi-class classification problems

In [7]:
optimizer = torch.optim.SGD(net.parameters(), lr=0.1,
                      momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
ce_loss = nn.CrossEntropyLoss(reduction='none')

We now train our model on 200 epochs, and for each epoch we:
1. **Train**: Perform forward pass, compute loss, and backpropagate the gradients
2. **Update**: Clip gradients to prevent insatbility and then update model parameters
3. **Evaluate**: Calculate accuracy on the test set at each epoch and save the best performing model

In [8]:
best_acc = 0.0

train_loss_history = []
test_acc_history = []

for epoch in range(200):
    net.train()
    epoch_loss = 0.0
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)
        preds = net(data)
        loss = ce_loss(preds, labels).mean()
        epoch_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(net.parameters(), 1, norm_type=2.0)
        optimizer.step()

    avg_epoch_loss = epoch_loss / len(train_loader)
    train_loss_history.append(avg_epoch_loss)    
    
    scheduler.step()

    acc = 0.0
    net.eval()
    for batch_idx, (data, labels) in enumerate(test_loader):
        data, labels = data.to(device), labels.to(device)
        preds = net(data)
        acc += (preds.max(1)[1] == labels).float().sum()

    acc = acc / 10000
    test_acc_history.append(acc.item())
    
    print(f"Epoch : {epoch} : Acc : {acc}")

    if acc > best_acc:
        best_acc = acc 
        torch.save({"net": net.state_dict()}, save_path)

Epoch : 0 : Acc : 0.06319999694824219
Epoch : 1 : Acc : 0.12629999220371246
Epoch : 2 : Acc : 0.17409999668598175
Epoch : 3 : Acc : 0.14509999752044678
Epoch : 4 : Acc : 0.17569999396800995
Epoch : 5 : Acc : 0.13989999890327454
Epoch : 6 : Acc : 0.1711999922990799
Epoch : 7 : Acc : 0.19619999825954437
Epoch : 8 : Acc : 0.17109999060630798
Epoch : 9 : Acc : 0.19429999589920044
Epoch : 10 : Acc : 0.23709999024868011
Epoch : 11 : Acc : 0.2556999921798706
Epoch : 12 : Acc : 0.2361999899148941
Epoch : 13 : Acc : 0.24549999833106995
Epoch : 14 : Acc : 0.25849997997283936
Epoch : 15 : Acc : 0.25999999046325684
Epoch : 16 : Acc : 0.24809999763965607
Epoch : 17 : Acc : 0.2662000060081482
Epoch : 18 : Acc : 0.2759000062942505
Epoch : 19 : Acc : 0.23809999227523804
Epoch : 20 : Acc : 0.27799999713897705
Epoch : 21 : Acc : 0.24959999322891235
Epoch : 22 : Acc : 0.2525999844074249
Epoch : 23 : Acc : 0.28360000252723694
Epoch : 24 : Acc : 0.2712000012397766
Epoch : 25 : Acc : 0.27219998836517334
Epo

We then save our test accuracies and previous training loss values to files to track our progress

In [9]:
with open('progress/test_acc_history.txt', 'w') as f:
    for number in test_acc_history:
        f.write(f'{number}\n')
with open('progress/train_loss_history.txt', 'w') as f:
    for number in train_loss_history:
        f.write(f'{number}\n')

Now that we have trained our source classifier, we can benchmark it at test time using the CIFAR-10C dataset, which contains corrupted versions of the CIFAR-10 dataset 

We first begin by initializing our neural net and loading the previously saved weights

In [11]:
net = Normalized_ResNet(depth=26)
checkpoint = torch.load(save_path, map_location=device)
net.load_state_dict(checkpoint['net'])
net.to(device)
net.eval()

# # Remove 'module.' prefix if checkpoint was saved with DataParallel
# checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()}

# net.to(device)
# net.load_state_dict(checkpoint)

Normalized_ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0

We then define the subset size of CIFAR-10C that we would like to use, as well as lists of severity levels as well as corruption types that we would like to test on.

In [21]:
subset_size = 5000
severity_list = [5, 4, 3, 2, 1]
corruption_type_list = [
    "gaussian_noise",
    "shot_noise",
    "impulse_noise",
    "defocus_blur",
    "glass_blur",
    "motion_blur",
    "zoom_blur",
    "snow",
    "frost",
    "fog",
    "brightness",
    "contrast",
    "elastic_transform",
    "pixelate",
    "jpeg_compression"
]

We now test our classifier, evaluating its ability to generate pseudolabels for various severities of corruptions in CIFAR-10C.

In [25]:
error_dict = defaultdict(dict)

# Define a subset size
subset_size = 5000

for i, severity in enumerate(severity_list):  # First severity level only for quick test
    for j, corruption_type in enumerate(corruption_type_list):  # First corruption type only
        # Load a limited subset of the CIFAR-10-C data directly
        x_test, y_test = load_cifar10c(subset_size, severity, './data', True, [corruption_type])
        num_classes = 10

        # Wrap data in TensorDataset
        test_dataset = TensorDataset(x_test, y_test)

        # Create DataLoader directly with the subset
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=200, shuffle=False)

        print("Meta test begin!")
        net_test = copy.deepcopy(net)

        # Use the test DataLoader with the subset
        acc = 0.0
        net_test.eval()
        for x_curr, y_curr in test_loader:
            x_curr, y_curr = x_curr.to(device), y_curr.to(device)
            outputs = net_test(x_curr)
            acc += (outputs.max(1)[1] == y_curr).float().sum()

        acc /= subset_size
        err = 1. - acc
        print(f"error % [{corruption_type}{severity}]: {err:.2%}")
        
        error_dict[corruption_type][severity] = err.item()

Meta test begin!
error % [gaussian_noise5]: 88.86%
Meta test begin!
error % [shot_noise5]: 89.00%
Meta test begin!
error % [impulse_noise5]: 89.48%
Meta test begin!
error % [defocus_blur5]: 88.98%
Meta test begin!
error % [glass_blur5]: 89.04%
Meta test begin!
error % [motion_blur5]: 88.90%
Meta test begin!
error % [zoom_blur5]: 89.12%
Meta test begin!
error % [snow5]: 89.70%
Meta test begin!
error % [frost5]: 89.76%
Meta test begin!
error % [fog5]: 89.76%
Meta test begin!
error % [brightness5]: 89.38%
Meta test begin!
error % [contrast5]: 88.52%
Meta test begin!
error % [elastic_transform5]: 89.26%
Meta test begin!
error % [pixelate5]: 88.66%
Meta test begin!
error % [jpeg_compression5]: 89.12%
Meta test begin!
error % [gaussian_noise4]: 89.04%
Meta test begin!
error % [shot_noise4]: 88.88%
Meta test begin!
error % [impulse_noise4]: 89.04%
Meta test begin!
error % [defocus_blur4]: 88.82%
Meta test begin!
error % [glass_blur4]: 88.92%
Meta test begin!
error % [motion_blur4]: 88.98%
Met

We then write our error rates to an external json, which can be used for other data visualizations.

In [26]:
with open("progress/error_rates.json", "w") as f:
    json.dump(error_dict, f)