In [1]:
import os
import time
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import torch.utils.data as td
import random, time
import matplotlib.pyplot as plt
import torchvision
import PIL.Image as Image
from tqdm import tqdm
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.utils as vutils
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
from vgg import vgg11_bn
vgg11 = vgg11_bn()
vgg11.load_state_dict(torch.load('./state_dicts/vgg11_bn.pt'))
vgg11.to(device)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): MaxPool2d(ke

In [3]:
## CREATING THE DATALOADER

# os.chdir('/content/drive/MyDrive/Final Project/Adversarial Examples')

import pickle

perturbed_list = []
pert_label_list = []

for i in range(7):
  file = open(f"perturbed_imgs_combo_{i}.pkl", 'rb')
  perturbed_list.append(pickle.load(file))
  file.close()


  file = open(f"labels_combo_{i}.pkl", 'rb')
  pert_label_list.append(pickle.load(file))
  file.close()

In [4]:
perturbed_list[0][0].size()

torch.Size([50, 3, 32, 32])

In [5]:
len(perturbed_list[0])

1000

In [6]:
flat_pert_imgs = []

for i in range(7):
  if type(perturbed_list[i]) is list:
    for x, (data) in enumerate(perturbed_list[i]):
      flat_pert_imgs.append(data)


    

In [7]:
len(flat_pert_imgs[0])

50

In [8]:
flat_pert_imgs[0].size()

torch.Size([50, 3, 32, 32])

In [9]:
len(pert_label_list)

7

In [10]:
flat_pert_labels = []

for i in range(7):
  if type(pert_label_list[i]) is list:
    for x, (label) in enumerate(pert_label_list[i]):
      flat_pert_labels.append(label)

In [11]:
len(flat_pert_labels[0])

50

In [12]:
# To create the dataset

from torch.utils.data import DataLoader, Dataset

class CustomImageDataset(Dataset):
    def __init__(self, labels, images, transform=None, target_transform=None):
        self.labels = labels
        self.images = images
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels[idx]
        image = self.images[idx]
        return image, label

In [13]:
full_ds = CustomImageDataset(flat_pert_labels, flat_pert_imgs)
full_loader = torch.utils.data.DataLoader(full_ds, batch_size = 1, shuffle = True)

In [14]:
print(len(full_loader.dataset))

7000


In [15]:
print(len(full_loader))

7000


In [16]:
batch_size_cifar = 50

def cifar_loaders(batch_size, shuffle_test=False): 
    data_dir = './data'
    train = datasets.CIFAR10(data_dir, train=True, download=True, 
        transform=transforms.Compose([
            transforms.ToTensor(),
        ]))
    # Once you have downloaded the data by setting download=True, you can
    # change download=True to download=False
    test = datasets.CIFAR10(data_dir, train=False, 
        transform=transforms.Compose([transforms.ToTensor()]))
    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size,
        shuffle=True, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size,
        shuffle=shuffle_test, pin_memory=True)
    return train_loader, test_loader

train_cifar_loader, test_cifar_loader = cifar_loaders(batch_size_cifar)

Files already downloaded and verified


In [17]:
def adversarial_update(images,labels, model,loss, epsilon, step_size, n_iter):
  delta = torch.zeros(images.size()).to(device)
  image_mod = (images+delta)
  for i in range(n_iter):
    image_mod.requires_grad = True
    preds = model(image_mod)
    loss_val = loss(preds,labels)
    if i == 0:
      standard_loss = loss_val.item()
    loss_val.backward()
    gradient = torch.sign(image_mod.grad.data)
    delta = delta + step_size*gradient
    delta = delta.clamp(min=-epsilon, max=epsilon)
    image_mod = (images+delta).clamp(min=0,max=1)
  return (image_mod, standard_loss)

In [19]:
# os.chdir('/content/drive/MyDrive/Final Project/models/state_dicts')
from tqdm import tqdm
import time

vgg11_adv_ens = vgg11_bn()
vgg11_adv_ens.load_state_dict(torch.load('./state_dicts/vgg11_bn.pt'))
vgg11_adv_ens.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vgg11_adv_ens.parameters(),lr=1e-5)
num_epochs = 32

loss_standard = np.zeros(num_epochs)
loss_adv = np.zeros(num_epochs)

for i in range(num_epochs):
  start = time.time()
  running_loss_adv = 0
  running_loss = 0
  print("Completing epoch: ", i)
  for batch, (data, labels) in enumerate(train_cifar_loader):
    data = data.to(device)
    labels = labels.to(device)
    optimizer.zero_grad()
    data_mod, standard_loss = adversarial_update(data,labels, vgg11_adv_ens,criterion, epsilon=.0625, step_size=.01, n_iter=7)
    running_loss += standard_loss
    preds = vgg11_adv_ens(data_mod)
    loss = criterion(preds, labels)
    loss.backward()
    optimizer.step()
    running_loss_adv += loss.item()
    batches = batch+1
  for batch, (data, labels) in enumerate(tqdm(full_ds)):
    # print("Batch #")
    data = data.to(device)
    labels = labels.to(device)
    optimizer.zero_grad()
    preds = vgg11_adv_ens(data)
    loss = criterion(preds, labels)
    loss.backward()
    optimizer.step()
    running_loss_adv += loss.item()
    batches = batches+1
  #loss_standard[i] = running_loss/batches
  loss_adv[i] = running_loss_adv/batches
  end = time.time()
  print("Epoch completed in: ", (end - start)/60, " minutes" )
  print("Epoch "+str(i)+": Adversarial Loss - "+str(round(loss_adv[i],3))+", Loss - "+str(round(loss_standard[i],3)))
torch.save(vgg11_adv_ens.state_dict(), "./state_dicts/vgg11_adv_ens_4.pt")
vgg11_adv_ens.load_state_dict(torch.load( "./state_dicts/vgg11_adv_ens_4.pt"))

Completing epoch:  0


100%|██████████| 7000/7000 [01:57<00:00, 59.72it/s]


Epoch completed in:  3.4485867897669475  minutes
Epoch 0: Adversarial Loss - 0.444, Loss - 0.0
Completing epoch:  1


100%|██████████| 7000/7000 [01:57<00:00, 59.73it/s]


Epoch completed in:  3.448859671751658  minutes
Epoch 1: Adversarial Loss - 0.361, Loss - 0.0
Completing epoch:  2


100%|██████████| 7000/7000 [01:57<00:00, 59.73it/s]


Epoch completed in:  3.4472189505894977  minutes
Epoch 2: Adversarial Loss - 0.331, Loss - 0.0
Completing epoch:  3


100%|██████████| 7000/7000 [01:57<00:00, 59.74it/s]


Epoch completed in:  3.4470033685366315  minutes
Epoch 3: Adversarial Loss - 0.311, Loss - 0.0
Completing epoch:  4


100%|██████████| 7000/7000 [01:57<00:00, 59.73it/s]


Epoch completed in:  3.447192211945852  minutes
Epoch 4: Adversarial Loss - 0.295, Loss - 0.0
Completing epoch:  5


100%|██████████| 7000/7000 [01:57<00:00, 59.73it/s]


Epoch completed in:  3.4473703066507975  minutes
Epoch 5: Adversarial Loss - 0.282, Loss - 0.0
Completing epoch:  6


100%|██████████| 7000/7000 [01:57<00:00, 59.73it/s]


Epoch completed in:  3.4475618759791056  minutes
Epoch 6: Adversarial Loss - 0.268, Loss - 0.0
Completing epoch:  7


100%|██████████| 7000/7000 [01:57<00:00, 59.73it/s]


Epoch completed in:  3.447585713863373  minutes
Epoch 7: Adversarial Loss - 0.255, Loss - 0.0
Completing epoch:  8


100%|██████████| 7000/7000 [01:57<00:00, 59.72it/s]


Epoch completed in:  3.4488937656084695  minutes
Epoch 8: Adversarial Loss - 0.239, Loss - 0.0
Completing epoch:  9


100%|██████████| 7000/7000 [01:57<00:00, 59.75it/s]


Epoch completed in:  3.4465649644533793  minutes
Epoch 9: Adversarial Loss - 0.225, Loss - 0.0
Completing epoch:  10


100%|██████████| 7000/7000 [01:57<00:00, 59.78it/s]


Epoch completed in:  3.4449847181638082  minutes
Epoch 10: Adversarial Loss - 0.215, Loss - 0.0
Completing epoch:  11


100%|██████████| 7000/7000 [01:57<00:00, 59.78it/s]


Epoch completed in:  3.445174849033356  minutes
Epoch 11: Adversarial Loss - 0.208, Loss - 0.0
Completing epoch:  12


100%|██████████| 7000/7000 [01:57<00:00, 59.77it/s]


Epoch completed in:  3.445517587661743  minutes
Epoch 12: Adversarial Loss - 0.195, Loss - 0.0
Completing epoch:  13


100%|██████████| 7000/7000 [01:57<00:00, 59.78it/s]


Epoch completed in:  3.445621184508006  minutes
Epoch 13: Adversarial Loss - 0.192, Loss - 0.0
Completing epoch:  14


100%|██████████| 7000/7000 [01:57<00:00, 59.78it/s]


Epoch completed in:  3.4465801239013674  minutes
Epoch 14: Adversarial Loss - 0.185, Loss - 0.0
Completing epoch:  15


100%|██████████| 7000/7000 [01:57<00:00, 59.78it/s]


Epoch completed in:  3.4451008598009745  minutes
Epoch 15: Adversarial Loss - 0.174, Loss - 0.0
Completing epoch:  16


100%|██████████| 7000/7000 [01:57<00:00, 59.79it/s]


Epoch completed in:  3.444683579603831  minutes
Epoch 16: Adversarial Loss - 0.168, Loss - 0.0
Completing epoch:  17


100%|██████████| 7000/7000 [01:57<00:00, 59.80it/s]


Epoch completed in:  3.4444118181864423  minutes
Epoch 17: Adversarial Loss - 0.161, Loss - 0.0
Completing epoch:  18


100%|██████████| 7000/7000 [01:57<00:00, 59.79it/s]


Epoch completed in:  3.4448420643806457  minutes
Epoch 18: Adversarial Loss - 0.156, Loss - 0.0
Completing epoch:  19


100%|██████████| 7000/7000 [01:57<00:00, 59.79it/s]


Epoch completed in:  3.444656201203664  minutes
Epoch 19: Adversarial Loss - 0.152, Loss - 0.0
Completing epoch:  20


100%|██████████| 7000/7000 [01:57<00:00, 59.74it/s]


Epoch completed in:  3.446317942937215  minutes
Epoch 20: Adversarial Loss - 0.148, Loss - 0.0
Completing epoch:  21


100%|██████████| 7000/7000 [01:57<00:00, 59.71it/s]


Epoch completed in:  3.448064080874125  minutes
Epoch 21: Adversarial Loss - 0.143, Loss - 0.0
Completing epoch:  22


100%|██████████| 7000/7000 [01:57<00:00, 59.72it/s]


Epoch completed in:  3.44717652797699  minutes
Epoch 22: Adversarial Loss - 0.14, Loss - 0.0
Completing epoch:  23


100%|██████████| 7000/7000 [01:57<00:00, 59.71it/s]


Epoch completed in:  3.44740758339564  minutes
Epoch 23: Adversarial Loss - 0.139, Loss - 0.0
Completing epoch:  24


100%|██████████| 7000/7000 [01:57<00:00, 59.72it/s]


Epoch completed in:  3.4465004483858745  minutes
Epoch 24: Adversarial Loss - 0.133, Loss - 0.0
Completing epoch:  25


100%|██████████| 7000/7000 [01:57<00:00, 59.72it/s]


Epoch completed in:  3.4468124866485597  minutes
Epoch 25: Adversarial Loss - 0.127, Loss - 0.0
Completing epoch:  26


100%|██████████| 7000/7000 [01:57<00:00, 59.72it/s]


Epoch completed in:  3.446265653769175  minutes
Epoch 26: Adversarial Loss - 0.123, Loss - 0.0
Completing epoch:  27


100%|██████████| 7000/7000 [01:57<00:00, 59.71it/s]


Epoch completed in:  3.44640851020813  minutes
Epoch 27: Adversarial Loss - 0.121, Loss - 0.0
Completing epoch:  28


100%|██████████| 7000/7000 [01:57<00:00, 59.72it/s]


Epoch completed in:  3.4472744941711424  minutes
Epoch 28: Adversarial Loss - 0.117, Loss - 0.0
Completing epoch:  29


100%|██████████| 7000/7000 [01:57<00:00, 59.72it/s]


Epoch completed in:  3.4461735526720685  minutes
Epoch 29: Adversarial Loss - 0.114, Loss - 0.0
Completing epoch:  30


100%|██████████| 7000/7000 [01:57<00:00, 59.72it/s]


Epoch completed in:  3.446364462375641  minutes
Epoch 30: Adversarial Loss - 0.112, Loss - 0.0
Completing epoch:  31


100%|██████████| 7000/7000 [01:57<00:00, 59.71it/s]


Epoch completed in:  3.446285839875539  minutes
Epoch 31: Adversarial Loss - 0.111, Loss - 0.0


<All keys matched successfully>