# The first three cells of this notebook are the beginning of a pipeline for training stronger versions of ResNet18 (and VGG19) on CIFAR-10. I (Jack) didn't finish this, so please pick it up from here. You should be able to run these cells and generate a state_dict for any image classifier you'd like.

# I have already trained ResNet18 to above 85% accuracy on CIFAR-10 and included the checkpoint in the repo. We still need someone to train VGG19 (the model we're going to transfer our attack to) for about 75 epochs.

# Before running our experiments, we should make sure we can attack this version of ResNet18. You may need to modify the cells after the first three for the attack to work with it. For example, the CIFAR-10 datasets used to train and validate ResNet18 in the first three cells are different from the CIFAR-10 datasets used in the rest of the notebook. Please update any later cells to work with this stronger ResNet18, and let me know if you have any questions about the code!

In [None]:
# Clone the git repository for PyTorch models optimized for performance on CIFAR-10
# Reference: https://github.com/kuangliu/pytorch-cifar
!rm -rf /content/pytorch-cifar
!git clone https://github.com/kuangliu/pytorch-cifar.git
%cd /content/pytorch-cifar/

Cloning into 'pytorch-cifar'...
remote: Enumerating objects: 382, done.
remote: Total 382 (delta 0), reused 0 (delta 0), pack-reused 382
Receiving objects: 100% (382/382), 81.31 KiB | 2.62 MiB/s, done.
Resolving deltas: 100% (198/198), done.
/content/pytorch-cifar


In [None]:
# Run the main function from the above repo to train a model to >80% accuracy on CIFAR-10

#####################################################################################################

# IMPORTANT!!! Modify main.py so that the training runs for only 75 epochs. Otherwise, it will train for 200 epochs!
# IMPORTANT!!! For Experiment #1 and Experiment #2, modify main.py so that you train ResNet18
# IMPORTANT!!! For Experiment #3, modify main.py so that you ALSO train VGG19 (you will need to run this cell twice: once for ResNet18 and once for VGG19)

#####################################################################################################

!python main.py

==> Preparing data..
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
100% 170498071/170498071 [00:02<00:00, 70791676.78it/s]
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
==> Building model..

Epoch: 0
  self.pid = os.fork()
Saving..

Epoch: 1
Saving..

Epoch: 2
Saving..

Epoch: 3
Traceback (most recent call last):
  File "/content/pytorch-cifar/main.py", line 152, in <module>
    train(epoch)
  File "/content/pytorch-cifar/main.py", line 104, in train
    loss.backward()
  File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
KeyboardInterrupt
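# Editing main.py by hand works, but the epoch change can also be scripted. This is a minimal sketch that assumes main.py contains exactly one `start_epoch+200` loop bound and one `T_max=200` in its cosine learning-rate schedule; check your copy of the repository, because the function raises an error rather than silently doing nothing if either pattern is missing. `set_training_epochs` and `patch_main_py` are hypothetical helper names. Changing `T_max` together with the loop bound lets the learning rate anneal fully within 75 epochs; with `T_max=200`, a 75-epoch run would stop while the learning rate is still about 70% of its starting value.

```python
import re
from pathlib import Path


def set_training_epochs(source: str, epochs: int = 75) -> str:
    """Return the text of main.py with the epoch count changed.

    Assumes (check your copy) one `start_epoch+N` loop bound and one
    `T_max=N` argument to the cosine learning-rate scheduler.
    """
    patched, n_loop = re.subn(r"start_epoch\s*\+\s*\d+", f"start_epoch+{epochs}", source)
    patched, n_sched = re.subn(r"T_max\s*=\s*\d+", f"T_max={epochs}", patched)
    if n_loop != 1 or n_sched != 1:
        raise ValueError(
            f"expected one loop bound and one T_max, found {n_loop} and {n_sched}"
        )
    return patched


def patch_main_py(path: str = "main.py", epochs: int = 75) -> None:
    """Rewrite main.py in place (hypothetical convenience wrapper)."""
    file = Path(path)
    file.write_text(set_training_epochs(file.read_text(), epochs))
```

# In the notebook, run `patch_main_py()` from `/content/pytorch-cifar` before `!python main.py`. The network itself is still chosen by editing the `net = ...` assignment in main.py's "Building model" section. As far as we know, main.py saves its best checkpoint to `./checkpoint/ckpt.pth` on every run, so copy it (for example to `checkpoint/VGG19.pth`, the path loaded in the next cell) before training the next model.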


In [None]:
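# Instantiate our model(s)
import torch
from collections import OrderedDict
from models import *

# map_location lets a GPU-trained checkpoint load on a CPU-only runtime
state_dict = torch.load("/content/pytorch-cifar/checkpoint/VGG19.pth", map_location="cpu")["net"]

# main.py wraps the network in nn.DataParallel on GPU, which prefixes every key with "module."
# Reference: https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/3
new_state_dict = OrderedDict()
for k, v in state_dict.items():
  # Strip the prefix only when it is present (a checkpoint trained without DataParallel has none)
  name = k[len("module."):] if k.startswith("module.") else k
  new_state_dict[name] = v

model = VGG("VGG19")
model.load_state_dict(new_state_dict)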

<All keys matched successfully>

# The cells below this point stand on their own. Run them to perform a targeted or untargeted patch attack on a weaker version of ResNet18 (~65% accuracy before the attack). Again, the code in the rest of the notebook may need to be modified to work with the stronger ResNet18. In particular, the dataset used below should be replaced with the one used above: both are CIFAR-10, but the dataset above uses different image preprocessing.
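# One way to handle the preprocessing mismatch without changing the loaders below is to keep every image (and the patch) in raw [0, 1] pixel space and move the normalization inside the model. This sketch assumes the checkpoint was trained with `transforms.Normalize` and the CIFAR-10 mean/std constants shown, which we believe are the values in kuangliu's main.py; verify them against the repository before relying on this. `NormalizedModel` is a hypothetical helper name, not part of that repo.

```python
import torch
import torch.nn as nn

# Per-channel CIFAR-10 statistics; assumed to match main.py's transforms.Normalize (verify)
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)


class NormalizedModel(nn.Module):
    """Wraps a classifier so it takes raw [0, 1] images and normalizes them internally."""

    def __init__(self, model: nn.Module, mean=CIFAR10_MEAN, std=CIFAR10_STD):
        super().__init__()
        self.model = model
        # Buffers follow .to(device) but are not trainable parameters
        self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model((x - self.mean) / self.std)
```

# Load the checkpoint weights into the bare network first, then wrap it, e.g. `vgg19 = NormalizedModel(model).to(device)`; the wrapper's own state_dict adds a `model.` prefix plus the `mean`/`std` buffers. The random crops and flips in main.py are training-time augmentation only, so `ToTensor()` followed by this wrapper should reproduce its test-time preprocessing. Keeping normalization inside the model also means that clamping the patch to [0, 1] keeps it a valid image.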

In [None]:
# Import necessary libraries
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import random_split, DataLoader
import torchvision.transforms.functional as TF
import random
import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [None]:
# Load dataset
transform = transforms.Compose([
    transforms.ToTensor()
])

cifar_10 = datasets.CIFAR10(root="./data",
                            train=True,
                            download=True,
                            transform=transform)

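# Split the 50,000 CIFAR-10 training images into training, val, and test sets
# (the official 10,000-image CIFAR-10 test split is not used here)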
train, test_set = random_split(cifar_10, [40000, 10000])
train_set, val_set = random_split(train, [35000, 5000])

# Define dataloaders
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False)

# Load pre-trained model
resnet18 = models.resnet18(weights="DEFAULT")

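# This ResNet18 is pre-trained on ImageNet, which has 1000 classes,
# so we replace the output layer with one for CIFAR-10's 10 classes
resnet18.fc = nn.Linear(resnet18.fc.in_features, 10)

# Replace the first convolution with a freshly initialized layer of the same shape.
# torchvision's ResNet18 already accepts 3-channel 32x32 images, so this only discards the
# pretrained conv1 weights (and adds a bias); it does not change how much the input is downsampled.
first_layer = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
resnet18.conv1 = first_layer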

Files already downloaded and verified


In [None]:
# This function evaluates a model's accuracy on the validation set
# Optionally, one can pass an adversarial patch as an argument to evaluate the model's performance against a patch attack
def eval(model, patch=None, target_class=None):
  # Stats to use to calculate accuracy after the eval loop
  total_correct = 0
  total = 0
  total_target = 0
  # Put model on GPU and switch to eval mode
  model = model.to(device)
  model.eval()
  # Evaluation loop
  with torch.no_grad():
    for batch_idx, (images, labels) in enumerate(val_loader):
      # Put data on GPU
      images = images.to(device)
      if patch is not None:
        images = apply(patch, images)
      labels = labels.to(device)
      # Make predictions
      predictions = model(images)
      predictions = torch.argmax(predictions, dim=1)
      # Update validation accuracy information
      total += len(images)
      num_correct = (predictions == labels).float().sum().item()
      total_correct += num_correct
      if target_class is not None:
        num_target = (predictions == target_class).float().sum().item()
        total_target += num_target
  # For a targeted patch attack, also report how often the model predicts the target class
  # (this count includes samples whose true label already is the target class)
  if target_class is not None:
    target_percentage = total_target / total
    print(f"Percentage of samples predicted as target class {target_class}: {100 * target_percentage}")
  # Calculate accuracy
  accuracy = total_correct / total
  return accuracy
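# The percentage printed by `eval` counts every validation image predicted as the target class, including the roughly one tenth of CIFAR-10 images whose true label already is the target. A stricter targeted success rate only counts images that the patch actually pushed into the target class. A minimal sketch; `targeted_counts` is a hypothetical helper, not part of the notebook:

```python
from typing import Tuple

import torch


def targeted_counts(predictions: torch.Tensor, labels: torch.Tensor, target_class: int) -> Tuple[int, int]:
    """Return (flipped, eligible) for one batch.

    eligible: samples whose true label is NOT target_class
    flipped:  eligible samples that the model predicts as target_class
    """
    eligible = labels != target_class
    flipped = (predictions[eligible] == target_class).sum().item()
    return flipped, eligible.sum().item()
```

# Inside the evaluation loop, add both numbers to running totals and report `total_flipped / total_eligible` at the end, alongside the existing percentage.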

In [None]:
# This function is designed to take in a pretrained ResNet model and fine-tune its weights for the CIFAR-10 dataset
# The idea is to fine-tune ResNet for the CIFAR-10 dataset (accuracy should be around 65%) and then degrade that performance via an adversarial patch attack
def fine_tune_for_cifar10(model, num_epochs=10):
  # Put model on GPU
  model = model.to(device)
  # Define loss function and optimizer
  criterion = nn.CrossEntropyLoss()
  optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
  # Training loop
  for i in range(num_epochs):
    # Switch to training mode at the start of every epoch: eval() below leaves the model in eval mode,
    # which would otherwise stop the batch-norm running statistics from updating after epoch 1
    model.train()
    # Stats to use for calculating accuracy
    total_correct = 0
    total = 0
    # Iterate through each batch of data
    for batch_idx, (images, labels) in enumerate(train_loader):
      # Put data on GPU
      images = images.to(device)
      labels = labels.to(device)
      # Make predictions
      predictions = model(images)
      # Calculate loss for the batch
      loss = criterion(predictions, labels)
      # Gradient descent
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      # Update training accuracy information
      total += len(images)
      predictions = torch.argmax(predictions, dim=1)
      num_correct = (predictions == labels).float().sum().item()
      total_correct += num_correct
    # Print training accuracy
    print(f"Epoch {str(i + 1)}: Training accuracy = {str(total_correct / total)}")
    # Print validation accuracy
    print(f"Validation accuracy: {str(eval(model))}")

# Applying the function above to ResNet18
fine_tune_for_cifar10(resnet18)

Epoch 1: Training accuracy = 0.28454285714285715
Validation accuracy: 0.3652
Epoch 2: Training accuracy = 0.4175142857142857
Validation accuracy: 0.4798
Epoch 3: Training accuracy = 0.5023428571428571
Validation accuracy: 0.534


KeyboardInterrupt: 

In [None]:
# Apply patch to a batch of images (modifies the batch in place and also returns it)
def apply(patch, batch_of_images):
  num_images = batch_of_images.shape[0]
  image_size = batch_of_images.shape[-1]  # 32 for CIFAR-10
  patch_size = patch.shape[1]
  # Iterate through each image in the batch
  for i in range(num_images):
    # Rotate the patch by a random number of degrees (pixels outside the rotated square are filled with 0)
    degree = random.uniform(0, 360)
    patch_rotated = TF.rotate(patch, angle=degree)
    # Randomly choose a (row, column) coordinate on the 32x32 CIFAR-10 image
    # This coordinate will be where the top left corner of the rotated patch goes
    # (randint is inclusive, so the largest valid corner is image_size - patch_size)
    top = random.randint(0, image_size - patch_size)
    left = random.randint(0, image_size - patch_size)
    # Apply the randomly rotated patch (not the original one) at the random location
    batch_of_images[i, :, top:top+patch_size, left:left+patch_size] = patch_rotated
  return batch_of_images
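# A rotated square patch, pasted as a square, carries its zero-filled corners with it, so every patched image also gets black triangles that the optimizer never controls. One alternative is to rotate a mask together with the patch and blend only where the patch landed. This sketch swaps `TF.rotate` (nearest-neighbour by default) for bilinear rotation built from `F.affine_grid` and `F.grid_sample`, and returns a new tensor instead of writing into the batch; `rotate_with_mask` and `apply_masked` are hypothetical names, and the rotation direction may differ from `TF.rotate`, which does not matter for uniformly random angles.

```python
import math
import random

import torch
import torch.nn.functional as F


def rotate_with_mask(patch: torch.Tensor, degrees: float):
    """Rotate a (C, P, P) patch; also return a (1, P, P) mask of where patch pixels landed."""
    rad = math.radians(degrees)
    cos, sin = math.cos(rad), math.sin(rad)
    # 2x3 affine matrix (rotation, no translation) in affine_grid's normalized coordinates
    theta = torch.tensor([[cos, -sin, 0.0], [sin, cos, 0.0]],
                         dtype=patch.dtype, device=patch.device).unsqueeze(0)
    # Rotate the patch and an all-ones channel together so the mask lines up exactly
    stacked = torch.cat([patch, torch.ones_like(patch[:1])], dim=0).unsqueeze(0)
    grid = F.affine_grid(theta, list(stacked.shape), align_corners=False)
    rotated = F.grid_sample(stacked, grid, mode="bilinear",
                            padding_mode="zeros", align_corners=False)[0]
    return rotated[:-1], rotated[-1:]


def apply_masked(patch: torch.Tensor, images: torch.Tensor) -> torch.Tensor:
    """Blend a randomly rotated, randomly placed patch into each image; returns a new tensor."""
    _, height, width = images.shape[1:]
    patch_size = patch.shape[-1]
    out = []
    for image in images:
        rotated, mask = rotate_with_mask(patch, random.uniform(0, 360))
        top = random.randint(0, height - patch_size)
        left = random.randint(0, width - patch_size)
        # Zero-pad patch and mask to full image size, placing them at (top, left)
        pad = (left, width - patch_size - left, top, height - patch_size - top)
        full_patch, full_mask = F.pad(rotated, pad), F.pad(mask, pad)
        # Keep the image wherever the mask is 0, use the patch wherever it is 1
        out.append(full_mask * full_patch + (1 - full_mask) * image)
    return torch.stack(out)
```

# Because it builds a new tensor, `apply_masked` stays differentiable with respect to the patch without relying on in-place writes into the batch.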

In [None]:
# This function optimizes an adversarial patch against a provided white-box model
# Model accuracy against the patch attack is reported after each epoch
def generate_adversarial_patch(model=resnet18, patch_size=8, target_class=None, num_epochs=10, lr=1e-1, momentum=0.8):
  # Put the model on the GPU in eval mode: only the patch is being optimized
  model = model.to(device)
  model.eval()
  # Initialize patch to all zeros, on the same device as the images
  patch = nn.Parameter(torch.zeros(3, patch_size, patch_size, device=device))
  optimizer = optim.SGD([patch], lr=lr, momentum=momentum)
  criterion = nn.CrossEntropyLoss()
  # Optimize the patch
  for i in range(num_epochs):
    print(f"Epoch {i + 1}")
    for batch_idx, (images, labels) in enumerate(train_loader):
      # Put data on the GPU
      images = images.to(device)
      labels = labels.to(device)
      # Apply the patch at a random location and with a random rotation for each image in the batch
      images = apply(patch, images)
      # Make predictions on the patched images with the model passed in (not the global resnet18)
      predictions = model(images)
      # For an untargeted attack, create false labels by incrementing the true labels by 1
      if target_class is None:
        false_labels = (labels + 1) % 10
      # For a targeted attack, set all the false labels to the target class
      else:
        false_labels = torch.full_like(labels, target_class)
      # Tune the patch
      loss = criterion(predictions, false_labels)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      # Keep the patch a valid image: pixel values in [0, 1], matching transforms.ToTensor()
      with torch.no_grad():
        patch.clamp_(0, 1)
    # See how the patch performs
    print(f"Target class: {target_class}")
    accuracy = eval(model, patch=patch, target_class=target_class)
    print(f"Accuracy: {accuracy}\n")
  return patch

patch = generate_adversarial_patch(target_class=5)
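# Since the patch is meant to transfer to VGG19, it helps to score the source and target models on the same patched images. A minimal sketch: `patched_accuracies` is a hypothetical helper that applies the patch once per batch (with whatever function is passed as `apply_fn`, such as `apply`) and evaluates every model on those identical inputs. It clones each batch first because `apply` writes into its input.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def patched_accuracies(nets: dict, loader, patch: torch.Tensor, apply_fn, device: str = "cpu") -> dict:
    """Accuracy of each model in `nets` on the same patched batches from `loader`."""
    for net in nets.values():
        net.to(device).eval()
    correct = {name: 0 for name in nets}
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        # Patch once per batch so every model sees exactly the same perturbed inputs
        patched = apply_fn(patch.to(device), images.clone())
        for name, net in nets.items():
            correct[name] += (net(patched).argmax(dim=1) == labels).sum().item()
        total += labels.numel()
    return {name: c / total for name, c in correct.items()}
```

# For example, `patched_accuracies({"resnet18": resnet18, "vgg19": model}, val_loader, patch, apply, device)` with a learned `patch` tensor. If the VGG19 checkpoint expects normalized inputs while the patch lives in [0, 1] pixel space, do the normalization inside the VGG19 model rather than in the loader, so both models receive the same patched pixels.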