In [1]:
import time
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import torch.utils.data as td
import random, time
import matplotlib.pyplot as plt
import torchvision
import PIL.Image as Image
from tqdm import tqdm
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.utils as vutils
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
batch_size_cifar = 50

def cifar_loaders(batch_size, shuffle_test=False): 
    data_dir = './data'
    train = datasets.CIFAR10(data_dir, train=True, download=True, 
        transform=transforms.Compose([
            transforms.ToTensor(),
        ]))
    # Once you have downloaded the data by setting download=True, you can
    # change download=True to download=False
    test = datasets.CIFAR10(data_dir, train=False, 
        transform=transforms.Compose([transforms.ToTensor()]))
    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size,
        shuffle=True, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size,
        shuffle=shuffle_test, pin_memory=True)
    return train_loader, test_loader

train_cifar_loader, test_cifar_loader = cifar_loaders(batch_size_cifar)

Files already downloaded and verified


In [3]:
from vgg import vgg11_bn
vgg11 = vgg11_bn(pretrained=True).to(device)


In [4]:
def adversarial_update(images,labels, model,loss, epsilon, step_size, n_iter):
  delta = torch.zeros(images.size()).to(device)
  image_mod = (images+delta)
  for i in range(n_iter):
    image_mod.requires_grad = True
    preds = model(image_mod)
    loss_val = loss(preds,labels)
    if i == 0:
      standard_loss = loss_val.item()
    loss_val.backward()
    gradient = torch.sign(image_mod.grad.data)
    delta = delta + step_size*gradient
    delta = delta.clamp(min=-epsilon, max=epsilon)
    image_mod = (images+delta).clamp(min=0,max=1)
  return (image_mod, standard_loss)

In [5]:
vgg11_adv = vgg11_bn(pretrained=True).to(device)

optimizer = optim.Adam(vgg11_adv.parameters(),lr=1e-5)
num_epochs = 256

criterion = nn.CrossEntropyLoss()

loss_standard = np.zeros(num_epochs)
loss_adv = np.zeros(num_epochs)

In [6]:
start_total = time.perf_counter()

for i in range(num_epochs):
  running_loss_adv = 0
  running_loss = 0
  for batch, (data, labels) in enumerate(train_cifar_loader):
    data = data.to(device)
    labels = labels.to(device)
    optimizer.zero_grad()
    data_mod, standard_loss = adversarial_update(data,labels, vgg11_adv,criterion, epsilon=.0625, step_size=.01, n_iter=7)
    running_loss += standard_loss
    preds = vgg11_adv(data_mod)
    loss = criterion(preds, labels)
    loss.backward()
    optimizer.step()
    running_loss_adv += loss.item()
    batches = batch+1
  loss_standard[i] = running_loss/batches
  loss_adv[i] = running_loss_adv/batches
  print("Epoch "+str(i)+": Adversarial Loss - "+str(round(loss_adv[i],3))+", Loss - "+str(round(loss_standard[i],3)))


finish_total = time.perf_counter()

print(f'Total finished in {round((finish_total-start_total)/60, 2)} minutes')


Epoch 0: Adversarial Loss - 2.587, Loss - 1.498
Epoch 1: Adversarial Loss - 2.265, Loss - 1.245
Epoch 2: Adversarial Loss - 2.184, Loss - 1.095
Epoch 3: Adversarial Loss - 2.122, Loss - 0.984
Epoch 4: Adversarial Loss - 2.054, Loss - 0.871
Epoch 5: Adversarial Loss - 1.981, Loss - 0.767
Epoch 6: Adversarial Loss - 1.894, Loss - 0.661
Epoch 7: Adversarial Loss - 1.807, Loss - 0.566
Epoch 8: Adversarial Loss - 1.709, Loss - 0.482
Epoch 9: Adversarial Loss - 1.609, Loss - 0.409
Epoch 10: Adversarial Loss - 1.531, Loss - 0.351
Epoch 11: Adversarial Loss - 1.446, Loss - 0.301
Epoch 12: Adversarial Loss - 1.377, Loss - 0.265
Epoch 13: Adversarial Loss - 1.306, Loss - 0.23
Epoch 14: Adversarial Loss - 1.242, Loss - 0.202
Epoch 15: Adversarial Loss - 1.201, Loss - 0.179
Epoch 16: Adversarial Loss - 1.147, Loss - 0.159
Epoch 17: Adversarial Loss - 1.115, Loss - 0.143
Epoch 18: Adversarial Loss - 1.064, Loss - 0.13
Epoch 19: Adversarial Loss - 1.036, Loss - 0.117
Epoch 20: Adversarial Loss - 1.0

In [8]:
torch.save(vgg11_adv.state_dict(), "/nas/longleaf/home/judychao/STOR566/vgg11_adv_256.pt")
vgg11_adv.load_state_dict(torch.load("/nas/longleaf/home/judychao/STOR566/vgg11_adv_256.pt"))

<All keys matched successfully>