This notebook trains a classifier that is robust to adversarial examples, following [Assignment 3](https://github.com/sprintml/tml_2024/blob/main/Assignment3.pdf) of the Trustworthy Machine Learning course offered in the Summer Semester 2024. The approach aims for the highest possible accuracy on clean inputs as well as on adversarial examples crafted with FGSM and PGD.

In [None]:
import requests
import torch
import torchvision
import os
import json
import io
import sys
import base64
import torch.nn as nn
import numpy as np
import pickle
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader, random_split
from torch.utils.data import Dataset
from torch.autograd import Variable
from typing import Tuple
from torchvision.models import resnet50

In [None]:
# Colab-only: mount Google Drive so files under MyDrive (the training data and
# the saved checkpoint used later in this notebook) are reachable locally.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Work from the Drive root so relative paths such as ./Train.pt resolve there.
os.chdir("/content/drive/MyDrive")

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Print the working directory to confirm the chdir into Drive took effect.
cwd = os.getcwd()
print('cwd: ', cwd)

class TaskDataset(Dataset):
    """Indexable dataset of ``(id, image, label)`` triples held in parallel lists.

    The constructor only creates empty lists. In this notebook the populated
    instance comes from ``torch.load("./Train.pt")`` rather than from calling
    the constructor, so the attribute names ``ids``, ``imgs``, ``labels`` and
    ``transform`` should stay stable (they presumably must match the pickled
    object).
    """

    def __init__(self, transform=None):
        self.ids = []
        self.imgs = []
        self.labels = []
        # Optional callable applied to each image when it is fetched.
        self.transform = transform

    def __getitem__(self, index) -> Tuple[int, torch.Tensor, int]:
        """Return ``(id, image, label)`` at ``index``, applying ``transform`` to the image if set."""
        sample_id = self.ids[index]
        sample_img = self.imgs[index]
        if self.transform is not None:
            sample_img = self.transform(sample_img)
        sample_label = self.labels[index]
        return sample_id, sample_img, sample_label

    def __len__(self):
        """Return the number of samples, i.e. the length of ``ids``."""
        return len(self.ids)

cwd:  /content/drive/MyDrive


Load the dataset and apply the image transformation.

In [None]:
# Train.pt is a pickled TaskDataset object (course-provided, trusted), not a
# plain tensor/state_dict. torch.load defaults to weights_only=True since
# PyTorch 2.6, which refuses to unpickle custom classes, so opt out explicitly.
# Only do this for files you trust: unpickling can execute arbitrary code.
data: TaskDataset = torch.load("./Train.pt", map_location="cpu", weights_only=False)

In [None]:
def _to_rgb(img):
    """Convert an image to 3-channel RGB (assumes a PIL-style object exposing ``convert``)."""
    return img.convert("RGB")


# Per-sample preprocessing: force RGB, then convert to a float tensor.
transform = transforms.Compose([transforms.Lambda(_to_rgb), transforms.ToTensor()])

In [None]:
# Attach the preprocessing; TaskDataset.__getitem__ applies it to every image it returns.
data.transform = transform

In [None]:
# Hold out 10% of the provided samples for validation (the provided 100000
# samples give 90000 train / 10000 validation). The sizes are derived from the
# dataset so random_split does not fail if Train.pt has a different length.
val_size = len(data) // 10
train_size = len(data) - val_size

# Seed the split so re-running the notebook yields the same validation set;
# an unseeded split can leak previously-trained samples into validation.
split_seed = 0
train_dataset, val_dataset = random_split(
    data, [train_size, val_size], generator=torch.Generator().manual_seed(split_seed)
)

In [None]:
# Mini-batch size shared by the training and validation loaders.
BATCH_SIZE = 128

# Reshuffle training batches every epoch; keep validation order fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

We use an ensemble method and create four instances of the ResNet-50 model in order to obtain a more robust model. The intuition is that an adversarial example that fools one model does not necessarily fool the others. Thus, combining the predictions of all four trained models can help the ensemble generalize better.

In [None]:
def _resnet50_10class(weights):
    """Build a ResNet-50 with the given torchvision weights and a fresh 10-way head on ``device``."""
    net = torchvision.models.resnet50(weights=weights)
    net.fc = nn.Linear(net.fc.in_features, 10)
    return net.to(device)


# Four ensemble members: identical ImageNet-pretrained backbones, each with its
# own independently initialised 10-class linear head.
model1 = _resnet50_10class(models.ResNet50_Weights.DEFAULT)
model2 = _resnet50_10class(models.ResNet50_Weights.DEFAULT)
model3 = _resnet50_10class(models.ResNet50_Weights.DEFAULT)
model4 = _resnet50_10class(models.ResNet50_Weights.DEFAULT)

ensemble_model_set = [model1, model2, model3, model4]

In [None]:
# One SGD optimizer per ensemble member, all sharing the same hyper-parameters.
optimizer_model1, optimizer_model2, optimizer_model3, optimizer_model4 = [
    optim.SGD(member.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
    for member in (model1, model2, model3, model4)
]

In the following block, we define a student model. Our idea is to transfer (distill) the knowledge of the ensemble models into the student model. By doing this, we also make the student model more robust.

In [None]:
# The student is a ResNet-50 trained from scratch (no pretrained weights) with
# a 10-way classification head.
student_model = resnet50(weights=None)
student_model.fc = nn.Linear(in_features=student_model.fc.in_features, out_features=10)
student_model = student_model.to(device)

optimizer_student_model = optim.SGD(
    student_model.parameters(),
    lr=0.001,
    momentum=0.9,
    weight_decay=1e-4,
)
# Plain (non-distillation) classification loss.
criterion = nn.CrossEntropyLoss()

The following blocks define the FGSM (Fast Gradient Sign Method) and PGD (Projected Gradient Descent) attacks.

In [None]:
def fgsm_attack(model, images, labels, epsilon):
    """Craft FGSM adversarial examples for one batch.

    Takes a single step of size ``epsilon`` in the direction of the sign of the
    input gradient of the cross-entropy loss, then clips to the pixel range
    [0, 1].

    Args:
        model: Classifier returning logits. It is used in whatever train/eval
            mode it is currently in.
        images: Input batch; values are assumed to lie in [0, 1].
        labels: Class indices for ``images``.
        epsilon: Step size, i.e. the L-infinity size of the perturbation.

    Returns:
        A detached tensor of adversarial images on the model's device.
    """
    # Use the model's device instead of a notebook-global; fall back to the
    # input's device for a model without parameters.
    param = next(model.parameters(), None)
    target_device = param.device if param is not None else images.device

    images = images.clone().detach().to(target_device).requires_grad_(True)
    labels = labels.clone().detach().to(target_device)

    loss = F.cross_entropy(model(images), labels)
    # autograd.grad returns only the input gradient and leaves the model's
    # parameter .grad fields untouched (backward() would overwrite them).
    (data_grad,) = torch.autograd.grad(loss, images)

    # Build the result from a detached copy so no autograd graph is returned
    # to callers that feed it into further training steps.
    perturbed_image = images.detach() + epsilon * data_grad.sign()
    return torch.clamp(perturbed_image, 0, 1)

In [None]:
def pgd_attack(model, images, labels, epsilon=0.031, alpha=0.007, iters=5):
    """Craft L-infinity PGD adversarial examples for one batch.

    Starts from the clean images (no random start) and performs ``iters``
    signed-gradient ascent steps of size ``alpha`` on the cross-entropy loss.
    After every step the perturbation is projected back into the
    ``epsilon``-ball around the originals and the pixels into [0, 1].

    Args:
        model: Classifier returning logits, used in its current train/eval mode.
        images: Input batch; values are assumed to lie in [0, 1].
        labels: Class indices for ``images``.
        epsilon: L-infinity radius of the allowed perturbation.
        alpha: Step size of each iteration.
        iters: Number of ascent steps (0 returns a copy of the clean images).

    Returns:
        A detached tensor of adversarial images on the model's device.
    """
    # Use the model's device instead of a notebook-global; fall back to the
    # input's device for a model without parameters.
    param = next(model.parameters(), None)
    target_device = param.device if param is not None else images.device

    original_images = images.clone().detach().to(target_device)
    labels = labels.clone().detach().to(target_device)
    adv_images = original_images.clone()

    for _ in range(iters):
        adv_images.requires_grad_(True)
        loss = F.cross_entropy(model(adv_images), labels)
        # Input gradient only; the model's parameter .grad fields are untouched.
        (grad,) = torch.autograd.grad(loss, adv_images)
        stepped = adv_images.detach() + alpha * grad.sign()
        # Project onto the epsilon-ball, then onto the valid pixel range.
        eta = torch.clamp(stepped - original_images, min=-epsilon, max=epsilon)
        adv_images = torch.clamp(original_images + eta, min=0, max=1).detach()
    return adv_images

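The following distillation loss combines the cross-entropy loss on the hard labels with a KL-divergence-based distillation loss on the temperature-softened outputs: $\mathcal{L} = \alpha T^2 \, \mathrm{KL}\big(\mathrm{softmax}(z_t/T) \,\|\, \mathrm{softmax}(z_s/T)\big) + (1-\alpha)\,\mathrm{CE}(z_s, y)$, where $z_s$ and $z_t$ are the student and teacher logits, $y$ are the labels, $T$ is the temperature and $\alpha$ weights the distillation term.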

In [None]:
def distillation_loss(output, labels, teacher_outputs, T, alpha):
    """Knowledge-distillation loss (soft teacher targets blended with hard labels).

    Computes::

        alpha * T**2 * KL(softmax(teacher_outputs / T) || softmax(output / T))
        + (1 - alpha) * CE(output, labels)

    Args:
        output: Student logits; softmax is taken over dim 1 (classes).
        labels: Class indices for the batch.
        teacher_outputs: Teacher logits with the same shape as ``output``.
        T: Softmax temperature; the T**2 factor keeps the soft-target gradient
            scale comparable across temperatures.
        alpha: Weight of the KL term; ``1 - alpha`` weights the cross-entropy.

    Returns:
        Scalar loss tensor.
    """
    log_p_student = F.log_softmax(output / T, dim=1)
    p_teacher = F.softmax(teacher_outputs / T, dim=1)
    # 'batchmean' sums over classes and averages over the batch, which is the
    # per-sample KL divergence. The default 'mean' also divides by the number
    # of classes, silently shrinking the soft-target term.
    kd_term = F.kl_div(log_p_student, p_teacher, reduction="batchmean")
    ce_term = F.cross_entropy(output, labels)
    return kd_term * (alpha * T * T) + ce_term * (1. - alpha)

In [None]:
# Adversarial example generation (FGSM and PGD during training).
epsilon = 0.050      # L-infinity perturbation budget
alpha = 0.007        # PGD step size
pgd_iters = 5        # number of PGD steps

# Knowledge distillation.
T = 2.0                   # softmax temperature for the soft targets
alpha_distillation = 0.7  # weight of the KL term; (1 - alpha_distillation) weights cross-entropy

# Optimisation schedule.
num_epochs = 10

The following block trains the ensemble and obtains its outputs for clean images and for adversarial images generated with FGSM and PGD. It then computes the distillation loss for each of the three sets of images (clean, FGSM and PGD), sums these losses to train the student, and finally saves the trained student model.

In [None]:
# Parallel to ensemble_model_set: the optimizer that updates each teacher.
ensemble_optimizers = [optimizer_model1, optimizer_model2, optimizer_model3, optimizer_model4]

for epoch in range(num_epochs):
    # ---- Phase 1: ensemble training ----------------------------------------
    # Each teacher is an ImageNet-pretrained ResNet-50 with a freshly
    # initialised 10-class head, so it must be fine-tuned with its own
    # optimizer before its outputs are meaningful soft targets. Teachers are
    # trained on clean data with plain cross-entropy.
    for teacher_idx, (teacher, teacher_optimizer) in enumerate(
            zip(ensemble_model_set, ensemble_optimizers), start=1):
        teacher.train()
        teacher_running_loss = 0.0
        for _ids, images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            teacher_optimizer.zero_grad()
            teacher_loss = criterion(teacher(images), labels)
            teacher_loss.backward()
            teacher_optimizer.step()
            teacher_running_loss += teacher_loss.item()
        print(f'Epoch [{epoch+1}/{num_epochs}], Teacher {teacher_idx}, '
              f'Loss: {teacher_running_loss/len(train_loader):.4f}')

    # ---- Phase 2: distill the ensemble into the student --------------------
    # Teachers only supply targets here; eval mode makes them use (and not
    # update) their BatchNorm running statistics.
    for teacher in ensemble_model_set:
        teacher.eval()
    # The validation phase below switches the student to eval mode, so switch
    # it back every epoch; otherwise it would train with frozen BatchNorm.
    student_model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (_ids, images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Adversarial examples are crafted against the current student. It is
        # in train mode, so these forward passes also update its BatchNorm stats.
        fgsm_images = fgsm_attack(student_model, images, labels, epsilon)
        pgd_images = pgd_attack(student_model, images, labels, epsilon, alpha, pgd_iters)

        # Clear any parameter gradients left behind by the attacks.
        optimizer_student_model.zero_grad()

        # Soft targets: teacher logits averaged over the ensemble.
        with torch.no_grad():
            ensemble_outputs = sum(t(images) for t in ensemble_model_set) / len(ensemble_model_set)
            ensemble_outputs_fgsm = sum(t(fgsm_images) for t in ensemble_model_set) / len(ensemble_model_set)
            ensemble_outputs_pgd = sum(t(pgd_images) for t in ensemble_model_set) / len(ensemble_model_set)

        outputs = student_model(images)
        loss = distillation_loss(outputs, labels, ensemble_outputs, T, alpha_distillation)

        outputs_fgsm = student_model(fgsm_images)
        loss_fgsm = distillation_loss(outputs_fgsm, labels, ensemble_outputs_fgsm, T, alpha_distillation)

        outputs_pgd = student_model(pgd_images)
        loss_pgd = distillation_loss(outputs_pgd, labels, ensemble_outputs_pgd, T, alpha_distillation)

        # Equal-weight sum of the clean, FGSM and PGD distillation losses.
        total_loss = loss + loss_fgsm + loss_pgd
        total_loss.backward()
        optimizer_student_model.step()

        # Running statistics; accuracy is measured on the clean images only.
        running_loss += total_loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        if batch_idx % 100 == 0:
            # Report the same combined loss that is optimised and accumulated.
            print('Train Epoch: {} [{}]\tLoss: {:.6f}\tAccuracy: {:.6f}%'.format(
                epoch + 1, batch_idx, total_loss.item(), 100.*correct/total))

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader)}, Accuracy: {100.*correct/total:.2f}%')

    # ---- Phase 3: clean validation of the student --------------------------
    student_model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for _ids, images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = student_model(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item()
            _, predicted = outputs.max(1)
            val_total += labels.size(0)
            val_correct += predicted.eq(labels).sum().item()

    val_accuracy = 100. * val_correct / val_total
    print(f'Validation Loss: {val_loss/len(val_loader):.4f}, Validation Accuracy: {val_accuracy:.2f}%')

torch.save(student_model.state_dict(), 'robust_student_model10.pt')

The following block computes the clean accuracy as well as the adversarial accuracy under FGSM and PGD attacks on the validation dataset.

In [None]:
def evaluate(model, loader, attack=None, epsilon=None):
    """Return the top-1 accuracy (in percent) of ``model`` over ``loader``.

    Args:
        model: Classifier returning logits; it is switched to eval mode.
        loader: Iterable of ``(ids, images, labels)`` batches.
        attack: Optional callable ``attack(model, images, labels[, epsilon])``
            returning perturbed images. When given, accuracy is measured on its
            output instead of the clean images.
        epsilon: Perturbation budget forwarded to ``attack``. When None, the
            attack is called without it so that its own default applies.

    Returns:
        Accuracy as a float percentage in [0, 100].

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    # Evaluate on the model's device instead of relying on a notebook-global.
    param = next(model.parameters(), None)
    target_device = param.device if param is not None else torch.device("cpu")

    correct = 0
    total = 0
    for _ids, images, labels in loader:
        images, labels = images.to(target_device), labels.to(target_device)

        if attack is not None:
            # Attacks need input gradients, so they run outside no_grad.
            if epsilon is None:
                images = attack(model, images, labels)
            else:
                images = attack(model, images, labels, epsilon)

        # Scoring needs no autograd graph; skip building one.
        with torch.no_grad():
            _, predicted = model(images).max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    if total == 0:
        raise ValueError("evaluate() received a loader that yielded no samples")
    return 100 * correct / total

# Score the distilled student on the held-out split: first on clean inputs,
# then under each attack with an L-infinity budget of 0.05.
eval_epsilon = 0.050
clean_accuracy = evaluate(student_model, val_loader)
fgsm_accuracy = evaluate(student_model, val_loader, attack=fgsm_attack, epsilon=eval_epsilon)
pgd_accuracy = evaluate(student_model, val_loader, attack=pgd_attack, epsilon=eval_epsilon)

print(f'Clean accuracy: {clean_accuracy:.2f}%')
print(f'FGSM accuracy: {fgsm_accuracy:.2f}%')
print(f'PGD accuracy: {pgd_accuracy:.2f}%')

Submit the model and obtain its clean and adversarial accuracy on the evaluation dataset.

In [None]:
#### Tests ####
# (these are run on the eval endpoint for every submission)
# Re-run the server-side checks locally before uploading: the checkpoint must
# load strictly into a torchvision ResNet-50 with a 10-way head and map a
# 1x3x32x32 input to 10 logits.

allowed_models = {
    "resnet18": models.resnet18,
    "resnet34": models.resnet34,
    "resnet50": models.resnet50,
}
with open("./robust_student_model10.pt", "rb") as f:
    try:
        model: torch.nn.Module = allowed_models["resnet50"](weights=None)
        model.fc = torch.nn.Linear(model.fc.weight.shape[1], 10)
    except Exception as e:
        raise Exception(
            f"Invalid model class, {e=}, only {allowed_models.keys()} are allowed",
        ) from e
    try:
        state_dict = torch.load(f, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict, strict=True)
        model.eval()
        out = model(torch.randn(1, 3, 32, 32))
    except Exception as e:
        raise Exception(f"Invalid model, {e=}") from e

    assert out.shape == (1, 10), "Invalid output shape"


# Send the model to the server. The checkpoint is opened in a context manager
# so the file handle is closed once the upload completes.
# NOTE(review): the API token is hard-coded; anyone who can read this notebook
# can submit with it. Consider loading it from an environment variable or a
# Colab secret instead.
with open("./robust_student_model10.pt", "rb") as model_file:
    response = requests.post(
        "http://34.71.138.79:9090/robustness",
        files={"file": model_file},
        headers={"token": "92593601", "model-name": "resnet50"},
    )

# Per the assignment template, a 400 status means the submission was rejected
# (e.g. clean accuracy too low). Show the status and the body, falling back to
# raw text so a non-JSON error reply does not raise here.
print(response.status_code)
try:
    print(response.json())
except ValueError:
    print(response.text)