This notebook presents an alternative approach to training a classifier that is robust against adversarial examples. The goal is to achieve high accuracy on clean inputs as well as on adversarial examples created with FGSM and PGD. It is based on Assignment 3 of the Trustworthy Machine Learning course offered during the Summer Semester 2024.

In [None]:
import requests
import torch
import torch.nn as nn
import os
from torchvision import models
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.models import resnet18, resnet34, resnet50

In [None]:
# Mount Google Drive so that files stored there become reachable under /content/drive.
from google.colab import drive

drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
# Work from the Drive root so the relative paths used below (./Train.pt, ./submission.pt) resolve there.
os.chdir("/content/drive/MyDrive")

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
import requests
import torch
import torch.nn as nn
import numpy as np
import json
import io
import sys
import base64
from torch.utils.data import Dataset
from typing import Tuple
import pickle
import os

# Print the working directory so it is visible where relative paths will resolve.
cwd = os.getcwd()
print('cwd: ', cwd)

class TaskDataset(Dataset):
    """Holds parallel lists of ids, images and labels, with an optional image transform.

    The lists start out empty. The class and attribute names are kept unchanged
    because a populated instance is later loaded from ``./Train.pt``.
    """

    def __init__(self, transform=None):
        self.ids = []
        self.imgs = []
        self.labels = []
        # Optional callable applied to each image when it is fetched.
        self.transform = transform

    def __getitem__(self, index) -> Tuple[int, torch.Tensor, int]:
        """Return ``(id, image, label)`` for ``index``; the image is transformed if a transform is set."""
        sample_id, image = self.ids[index], self.imgs[index]
        if self.transform is not None:
            image = self.transform(image)
        return sample_id, image, self.labels[index]

    def __len__(self):
        """Number of samples, measured by the length of the id list."""
        return len(self.ids)

cwd:  /content/drive/MyDrive


Loading the dataset and applying transformation

In [None]:
# Train.pt is loaded as a whole TaskDataset object (not a plain state dict), so it needs full
# unpickling. From PyTorch 2.6 on, torch.load defaults to weights_only=True, which refuses
# custom classes, hence the explicit weights_only=False.
# NOTE: unpickling can execute arbitrary code; only load files from a source you trust.
data: TaskDataset = torch.load("./Train.pt", map_location="cpu", weights_only=False)

In [None]:
# Per-image preprocessing: convert to three-channel RGB first, then turn the image into a tensor.
transform = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB")),
    transforms.ToTensor(),
])

In [None]:
# Attach the preprocessing pipeline; TaskDataset.__getitem__ applies it to every image it returns.
data.transform = transform

In [None]:
# Hold out 10% of the samples for validation. For the 100000 samples mentioned in the original
# note this gives the same 90000 / 10000 split, but the sizes now follow len(data).
val_size = len(data) // 10
train_size = len(data) - val_size

# Split with a fixed generator seed so the validation set is identical in every session.
# Training below resumes from a saved checkpoint; with a freshly drawn split each run, samples
# the checkpoint was already trained on end up in validation and inflate its accuracy.
train_dataset, val_dataset = random_split(
    data,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(0),
)

In [None]:
# Only the training batches are shuffled; validation batches keep a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=128, shuffle=False)

In the following block, we define the TRADES loss function, adapted from the reference implementation of [TRadeoff-inspired Adversarial DEfense via Surrogate-loss minimization (TRADES)](https://github.com/yaodongyu/TRADES).

In [None]:
def trades_loss(model, x_natural, y, optimizer, step_size=0.007, epsilon=0.050, perturb_steps=15, beta=6.0):
    """Build the TRADES training loss for one batch.

    An adversarial batch is first crafted inside the L-infinity ball of radius ``epsilon``
    around ``x_natural`` by signed-gradient ascent on the KL divergence between the model's
    predictions on the clean and on the perturbed inputs. The returned loss is

        cross_entropy(model(x_natural), y) + beta * KL(softmax(model(x_natural)) || softmax(model(x_adv)))

    with both terms averaged over the batch.

    Args:
        model: classifier returning logits of shape (batch, classes).
        x_natural: clean input batch. Adversarial inputs are clamped to [0, 1], so the inputs
            are assumed to live in that range.
        y: integer class targets for ``x_natural``.
        optimizer: optimizer whose gradients are zeroed before the loss graph is built.
        step_size: size of each signed-gradient step of the attack.
        epsilon: maximum absolute per-element perturbation.
        perturb_steps: number of attack iterations.
        beta: weight of the robustness (KL) term.

    Returns:
        Scalar loss tensor with its graph attached; the caller runs ``backward()`` and
        ``optimizer.step()``.

    Side effects:
        The model is put in eval mode while the attack is crafted and left in train mode.
    """
    model.eval()

    # Tiny random start; randn_like keeps the noise on x_natural's device and dtype.
    x_adv = x_natural.detach() + 0.001 * torch.randn_like(x_natural)

    # The clean prediction is a constant target during the attack, so compute it once.
    with torch.no_grad():
        probs_natural = F.softmax(model(x_natural), dim=1)

    for _ in range(perturb_steps):
        x_adv.requires_grad_()
        with torch.enable_grad():
            loss_kl = F.kl_div(F.log_softmax(model(x_adv), dim=1),
                               probs_natural,
                               reduction='sum')

        # Only the sign of the gradient is used, so no gradient clipping is needed.
        grad = torch.autograd.grad(loss_kl, [x_adv])[0]
        x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
        # Project back into the epsilon-ball around the clean input, then into the valid range.
        x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
        x_adv = torch.clamp(x_adv, 0.0, 1.0)

    model.train()
    x_adv = x_adv.detach()

    optimizer.zero_grad()
    logits_natural = model(x_natural)
    logits_adv = model(x_adv)
    loss_natural = F.cross_entropy(logits_natural, y)
    # reduction='batchmean' already divides by the batch size. Multiplying by 1/batch_size on
    # top of it (as before) shrank the robustness term by another factor of the batch size.
    loss_robust = F.kl_div(F.log_softmax(logits_adv, dim=1),
                           F.softmax(logits_natural, dim=1),
                           reduction='batchmean')
    return loss_natural + beta * loss_robust

Defining a pretrained ResNet-50 model and restoring the previously saved checkpoint

In [None]:
# Resume from the saved checkpoint when it exists; otherwise start from ImageNet weights.
state_dict_path = './submission.pt'
resume_from_checkpoint = os.path.exists(state_dict_path)

# The checkpoint is loaded strictly and therefore replaces every parameter, so the ImageNet
# weights are only requested when there is no checkpoint to resume from.
model = models.resnet50(
    weights=None if resume_from_checkpoint else models.ResNet50_Weights.DEFAULT
)
# Replace the classification head with a 10-way linear layer.
model.fc = nn.Linear(model.fc.in_features, 10)

if resume_from_checkpoint:
    state_dict = torch.load(state_dict_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
model = model.to(device)

In [None]:
# Hyperparameters.
learning_rate = 0.1
num_epochs = 150
# NOTE(review): epsilon / alpha / perturb_steps / beta look like the attack budget, step size,
# step count and TRADES weight, but the training call does not pass them to trades_loss, which
# therefore runs with its own defaults. Confirm which values are intended.
epsilon = 0.031
alpha = 0.007
perturb_steps = 10
beta = 6.0

# SGD with momentum; the learning rate is multiplied by 0.1 at epochs 50 and 75.
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 75], gamma=0.1)
criterion = nn.CrossEntropyLoss()

In [None]:
# TRADES training with a clean-accuracy validation pass after every epoch.
# The epoch count comes from num_epochs so the loop matches the "[epoch/num_epochs]" log line.
for epoch in range(num_epochs):
    model.train()
    correct = 0
    total = 0
    train_loss = 0.0
    for batch_idx, (_, images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # trades_loss crafts the adversarial batch and returns the combined loss.
        loss = trades_loss(model, images, labels, optimizer)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # Running training accuracy on the clean batch. Eval mode and no_grad keep this extra
        # forward pass from building a graph or updating BatchNorm running statistics.
        model.eval()
        with torch.no_grad():
            predicted = model(images).argmax(dim=1)
        model.train()
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # Batchwise accuracy and loss
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}]\tLoss: {:.6f}\tAccuracy: {:.6f}%'.format(
                epoch + 1, batch_idx, loss.item(), 100.*correct/total))
    scheduler.step()
    # Epoch accuracy and loss
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss/len(train_loader)}, Accuracy: {100.*correct/total:.2f}%')

    # Validation phase (clean inputs only)
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for _, images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            val_total += labels.size(0)
            val_correct += predicted.eq(labels).sum().item()

    val_accuracy = 100. * val_correct / val_total
    print(f'Validation Loss: {val_loss/len(val_loader):.4f}, Validation Accuracy: {val_accuracy:.2f}%')

# Persist the final weights; this overwrites the checkpoint that training resumed from.
torch.save(model.state_dict(), "submission.pt")

In [None]:
# Keep a second copy of the current weights under a separate file name.
torch.save(model.state_dict(), "submission_new2.pt")

In [None]:
def evaluate_model(model, loader):
    """Return the top-1 accuracy, in percent, of ``model`` on the unperturbed batches of ``loader``.

    Args:
        model: classifier returning logits of shape (batch, classes).
        loader: iterable yielding ``(ids, images, labels)`` batches; the ids are ignored.

    Returns:
        ``100 * correct / total`` as a float.

    Side effects:
        The model is switched to (and left in) eval mode so the result does not depend on the
        mode the caller happened to leave it in.
    """
    model.eval()
    # Move inputs to wherever the model lives rather than relying on a global device.
    model_device = next(model.parameters()).device
    correct = 0
    total = 0
    with torch.no_grad():
        for _, images, labels in loader:
            images, labels = images.to(model_device), labels.to(model_device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

# Accuracy on the unperturbed validation split.
clean_accuracy = evaluate_model(model, val_loader)
print(f'Clean accuracy: {clean_accuracy:.2f}%')

  return F.conv2d(input, weight, bias, self.stride,


Clean accuracy: 96.25%


In [None]:
def pgd_attack(model, images, labels, eps=0.031, alpha=0.007, iters=10):
    """Craft L-infinity PGD adversarial examples against ``model``.

    Starting from the clean images, each iteration takes a step of size ``alpha`` along the sign
    of the cross-entropy gradient, projects the result back into the ``eps``-ball around the
    clean images and clamps it to [0, 1].

    Args:
        model: classifier returning logits; its train/eval mode is not changed here.
        images: clean input batch (not modified).
        labels: integer class targets for ``images``.
        eps: maximum absolute per-element perturbation.
        alpha: step size of each iteration.
        iters: number of iterations; with 0 a detached copy of ``images`` is returned.

    Returns:
        Detached tensor of adversarial images on the model's device.
    """
    model_device = next(model.parameters()).device
    original_images = images.clone().detach().to(model_device)
    labels = labels.clone().detach().to(model_device)
    adv_images = original_images.clone()

    for _ in range(iters):
        adv_images.requires_grad_(True)
        loss = F.cross_entropy(model(adv_images), labels)
        # Differentiate w.r.t. the input only, so no gradients pile up on the model parameters.
        (grad,) = torch.autograd.grad(loss, adv_images)

        stepped = adv_images.detach() + alpha * grad.sign()
        eta = torch.clamp(stepped - original_images, min=-eps, max=eps)
        adv_images = torch.clamp(original_images + eta, min=0, max=1).detach()

    return adv_images

In [None]:
def fgsm_attack(image, epsilon, data_grad):
    """Single FGSM step: shift ``image`` by ``epsilon`` along the sign of ``data_grad``.

    The result is clipped so that every value stays inside the [0, 1] range.
    """
    step = epsilon * data_grad.sign()
    return torch.clamp(image + step, 0, 1)

In [None]:
# Calculate FGSM accuracy
def evaluate_fgsm(model, loader, epsilon):
    """Return the top-1 accuracy, in percent, of ``model`` on FGSM-perturbed batches of ``loader``.

    Args:
        model: classifier returning logits; it is switched to (and left in) eval mode.
        loader: iterable yielding ``(ids, images, labels)`` batches; the ids are ignored.
        epsilon: FGSM step size passed to ``fgsm_attack``.

    Returns:
        ``100 * correct / total`` as a float.
    """
    model.eval()
    model_device = next(model.parameters()).device
    correct = 0
    total = 0
    for _, images, labels in loader:
        labels = labels.to(model_device)
        # Work on a detached view so requires_grad is never switched on for the loader's tensor.
        images = images.to(model_device).detach().requires_grad_(True)

        loss = F.cross_entropy(model(images), labels)
        # Gradient w.r.t. the input only; model parameters keep no gradients.
        (data_grad,) = torch.autograd.grad(loss, images)
        perturbed_data = fgsm_attack(images.detach(), epsilon, data_grad)

        # Scoring the adversarial batch needs no autograd graph.
        with torch.no_grad():
            predicted = model(perturbed_data).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    return 100 * correct / total

# Accuracy on the validation split under a single-step FGSM attack with epsilon = 0.031.
fgsm_accuracy = evaluate_fgsm(model, val_loader, epsilon=0.031)
print(f'FGSM accuracy: {fgsm_accuracy:.2f}%')

FGSM accuracy: 42.18%


In [None]:
# Calculate PGD accuracy
def evaluate_pgd(model, loader, epsilon, alpha, iters):
    """Return the top-1 accuracy, in percent, of ``model`` on PGD-perturbed batches of ``loader``.

    Args:
        model: classifier returning logits; it is switched to (and left in) eval mode.
        loader: iterable yielding ``(ids, images, labels)`` batches; the ids are ignored.
        epsilon: perturbation budget forwarded to ``pgd_attack``.
        alpha: step size forwarded to ``pgd_attack``.
        iters: number of attack iterations forwarded to ``pgd_attack``.

    Returns:
        ``100 * correct / total`` as a float.
    """
    model.eval()
    model_device = next(model.parameters()).device
    correct = 0
    total = 0
    for _, images, labels in loader:
        images, labels = images.to(model_device), labels.to(model_device)
        adv_images = pgd_attack(model, images, labels, epsilon, alpha, iters)
        # Scoring the adversarial batch needs no autograd graph.
        with torch.no_grad():
            predicted = model(adv_images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    return 100 * correct / total

# Accuracy on the validation split under a 10-step PGD attack (epsilon = 0.031, step size 0.007).
pgd_accuracy = evaluate_pgd(model, val_loader, epsilon=0.031, alpha=0.007, iters=10)
print(f'PGD accuracy: {pgd_accuracy:.2f}%')

PGD accuracy: 13.39%


In [None]:
# torch.save(model.state_dict(), "submission.pt")

#### Tests ####
# (these are run on the eval endpoint for every submission)
# Local replica of those checks: the saved file must load strictly into an allowed architecture
# with a 10-way head, and a (1, 3, 32, 32) input must produce a (1, 10) output.
# NOTE: this cell rebinds the global `model` to a freshly built CPU model.

allowed_models = {
    "resnet18": models.resnet18,
    "resnet34": models.resnet34,
    "resnet50": models.resnet50,
}
with open("./submission.pt", "rb") as f:
    try:
        # Architecture only (weights=None); the parameters come from the submission file below.
        model: torch.nn.Module = allowed_models["resnet50"](weights=None)
        model.fc = torch.nn.Linear(model.fc.weight.shape[1], 10)
        # replace_relu_with_silu(model)
    except Exception as e:
        raise Exception(
            f"Invalid model class, {e=}, only {allowed_models.keys()} are allowed",
        )
    try:
        state_dict = torch.load(f, map_location=torch.device("cpu"))
        # strict=True: the checkpoint keys must match the architecture exactly.
        model.load_state_dict(state_dict, strict=True)
        model.eval()
        # Smoke-test one forward pass on a random input.
        out = model(torch.randn(1, 3, 32, 32))
    except Exception as e:
        raise Exception(f"Invalid model, {e=}")

    assert out.shape == (1, 10), "Invalid output shape"


# Send the model to the server.
# The API token is taken from the ROBUSTNESS_API_TOKEN environment variable when it is set, so
# it does not have to live in the notebook.
# TODO(review): the literal fallback only keeps the notebook running as before. Rotate the
# token and remove the fallback before sharing this file.
api_token = os.environ.get("ROBUSTNESS_API_TOKEN", "92593601")

# Open the checkpoint in a context manager so the file handle is closed after the upload.
with open("./submission.pt", "rb") as submission_file:
    response = requests.post(
        "http://34.71.138.79:9090/robustness",
        files={"file": submission_file},
        headers={"token": api_token, "model-name": "resnet50"},
    )

# The original note expects a 400 response when the clean accuracy is too low; the JSON body
# is printed in either case.
print(response.json())

{'clean_accuracy': 0.598, 'fgsm_accuracy': 0.298, 'pgd_accuracy': 0.07166666666666667}
