General requirements.

In [None]:
from google.colab import drive

# Make Google Drive available inside the Colab filesystem at this mount point.
mount_point = '/content/drive'
drive.mount(mount_point)

Mounted at /content/drive


In [None]:
!pip install advertorch

import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import random
from advertorch.attacks import LinfPGDAttack
from advertorch.context import ctx_noparamgrad_and_eval

# Seed every RNG this notebook draws from (the training cell uses `random`
# for augmentation decisions; torch drives data shuffling) so that a
# Restart & Run All produces comparable results.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ResNet-18 initialised with torchvision's pretrained weights.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour
# of `weights=`; kept as-is for compatibility with the installed version.
model = torchvision.models.resnet18(pretrained=True)
print(device)

cuda:0


Save and load the model weights.

In [None]:
# Write the model's current parameters (its state_dict) to a checkpoint file.
PATH = './cifar_net.pth'
model_state = model.state_dict()
torch.save(model_state, PATH)

In [None]:
# Restore the model parameters from the checkpoint file.
PATH = './cifar_net.pth'
# map_location remaps the stored tensors onto the device selected for this
# session, so a checkpoint written from a GPU run still loads on a CPU-only
# runtime instead of raising a CUDA deserialization error.
model.load_state_dict(torch.load(PATH, map_location=device))

Load the required dataset (CIFAR-10).

In [None]:
# Mini-batch size shared by the train and test loaders (normally 50).
batch_size = 50

# Dataset location and the value forwarded as `download=` to CIFAR10.
root_dir = './data'
download_bool = True

# Preprocessing pipeline: convert images to tensors, then normalise each
# channel with the given per-channel mean and standard deviation.
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
transform = transforms.Compose([to_tensor, normalize])

# Arguments common to both splits, spelled out once.
dataset_args = dict(root=root_dir, download=download_bool, transform=transform)
loader_args = dict(batch_size=batch_size, num_workers=2)

# Training split is reshuffled by the loader; the test split keeps its order.
trainset = torchvision.datasets.CIFAR10(train=True, **dataset_args)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_args)

testset = torchvision.datasets.CIFAR10(train=False, **dataset_args)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_args)

# Human-readable names for the ten CIFAR-10 class indices.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Train the network.

In [None]:
# Fine-tune `model` on the training loader, optionally with rotation
# augmentation (plus auxiliary rotation labels) and PGD adversarial training.

# Number of mini-batches between loss reports (about every 10000 samples).
# Clamped to 1 so a batch_size above 10000 cannot cause a modulo by zero.
report_interval = max(1, 10000 // batch_size)

num_epochs = 10
rotation_label_offset = 10  # a rotation of k * 90 degrees is encoded as class index 10 + k
rotate_threshold = 0.4      # a batch is rotated when random.random() exceeds this
adv_threshold = 0.7         # a batch is attacked when random.random() exceeds this

model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

model.train()
rotate_train = True   # randomly rotate whole batches by a multiple of 90 degrees
rotate_label = True   # overwrite every second label of a rotated batch with the rotation class
adv_train = False     # replace some batches with PGD adversarial examples

if adv_train:
    # advertorch clips adversarial examples to [0, 1] by default, but the
    # inputs here have been through Normalize and live in a wider range.
    # Derive scalar bounds for the normalised pixel range from the transform
    # so the attack does not squash the images. (Scalar bounds are slightly
    # loose per channel, but never cut into valid pixel values.)
    norm_step = next(t for t in transform.transforms if isinstance(t, transforms.Normalize))
    clip_min = min((0.0 - float(m)) / float(s) for m, s in zip(norm_step.mean, norm_step.std))
    clip_max = max((1.0 - float(m)) / float(s) for m, s in zip(norm_step.mean, norm_step.std))
    adversary = LinfPGDAttack(
        model,
        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
        eps=0.01,
        clip_min=clip_min,
        clip_max=clip_max,
    )

for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()

        if rotate_train and random.random() > rotate_threshold:
            degree = random.randint(0, 3)
            inputs = torchvision.transforms.functional.rotate(inputs, degree * 90)
            if rotate_label:
                # Every second sample (odd positions) is relabelled with the
                # rotation class; the rest keep their object label. Strided
                # assignment is equivalent to the former reshape to (-1, 2)
                # but also works when the batch has an odd number of samples.
                # (A disabled earlier variant relabelled the whole batch and
                # averaged the loss against the original labels.)
                labels[1::2] = degree + rotation_label_offset

        if adv_train and random.random() > adv_threshold:
            # Craft adversarial inputs with the model in eval mode and
            # parameter gradients switched off; both are restored afterwards.
            with ctx_noparamgrad_and_eval(model):
                inputs = adversary.perturb(inputs, labels)

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % report_interval == report_interval - 1:
            # Mean loss over the last `report_interval` mini-batches.
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / report_interval))
            running_loss = 0.0

print('Finished Training')

# Checkpoint the trained weights under a separate file name.
PATH = './cifar_net-temp.pth'
torch.save(model.state_dict(), PATH)

Test set evaluation.

In [None]:
# Evaluate clean accuracy (overall and per class) on the test loader and,
# while `attack` is enabled, the success rate of a PGD attack on the samples
# the model classifies correctly.

num_classes = len(classes)   # only the first `num_classes` logits are object classes
max_attacked = 1000          # stop attacking once more than this many samples were attacked

class_correct = [0.] * num_classes
class_total = [0.] * num_classes
correct = 0
total = 0
attack_suc = 0
attack_total = 0

model.to(device)
model.eval()

attack = True

if attack:
    # advertorch clips adversarial examples to [0, 1] by default, but the
    # inputs here have been through Normalize and live in a wider range.
    # Derive scalar bounds for the normalised pixel range from the transform
    # so the attack itself does not distort the images through clipping.
    norm_step = next(t for t in transform.transforms if isinstance(t, transforms.Normalize))
    clip_min = min((0.0 - float(m)) / float(s) for m, s in zip(norm_step.mean, norm_step.std))
    clip_max = max((1.0 - float(m)) / float(s) for m, s in zip(norm_step.mean, norm_step.std))
    adversary = LinfPGDAttack(
        model,
        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
        eps=0.01,
        clip_min=clip_min,
        clip_max=clip_max,
    )

for data in testloader:
    images, labels = data[0].to(device), data[1].to(device)
    with torch.no_grad():
        outputs = model(images)
    # Predict among the object classes only.
    _, predicted = torch.max(outputs[:, :num_classes], 1)
    hits = (predicted == labels)   # kept 1-D so a batch of size 1 still indexes correctly
    total += labels.size(0)
    correct += hits.sum().item()

    if attack:
        # Only input gradients are needed to craft the perturbation, so
        # parameter gradients are disabled while the attack runs.
        with ctx_noparamgrad_and_eval(model):
            perturb_x = adversary.perturb(images, labels)
        with torch.no_grad():
            perturb_y = model(perturb_x)
        # Use the same logit slice as the clean prediction; an argmax over
        # all outputs would count predictions outside the object classes as
        # successes and make the two measurements incomparable.
        _, perturb_labels = torch.max(perturb_y[:, :num_classes], 1)
        # An attack is attempted only on correctly classified samples and
        # succeeds when the perturbed prediction no longer matches the label.
        attack_total += hits.sum().item()
        attack_suc += (hits & (perturb_labels != labels)).sum().item()

    # Iterate over the actual batch contents (not the nominal batch_size) so
    # a smaller final batch cannot cause an out-of-range index.
    for label, hit in zip(labels.tolist(), hits.tolist()):
        class_correct[label] += hit
        class_total[label] += 1

    if attack:
        if attack_total > max_attacked:
            attack = False
        # Report 0 % instead of dividing by zero when nothing was attacked yet.
        success_rate = 100 * attack_suc / attack_total if attack_total else 0
        print('Attack number and success rate: %d %d %%' % (attack_total, success_rate))

print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / max(total, 1)))


for i in range(num_classes):
    if class_total[i] == 0:
        print('Accuracy of %5s : 0 %%' % (classes[i]))
    else:
        print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))