In [67]:
from PIL import Image
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import os
import matplotlib.pyplot as plt
from torchvision import models
from tqdm.notebook import tqdm
import numpy as np


In [68]:
class AdversarialDataset(Dataset):
    """Image/label dataset described by a CSV annotation file.

    Each item is ``(image, label, image_id)`` where ``image`` is the file
    ``<img_dir>/<ImageId>.png`` decoded as RGB (then passed through
    ``x_transform`` if given) and ``label`` is ``TrueLabel - 1`` (then passed
    through ``y_transform`` if given).

    Args:
        annotation_file: CSV with at least ``ImageId`` and ``TrueLabel`` columns.
        categories_file: CSV loaded into ``self.categories``; not used inside this class.
        img_dir: directory containing the ``<ImageId>.png`` files.
        x_transform: optional callable applied to the PIL image.
        y_transform: optional callable applied to the zero-based label.
    """

    def __init__(self, annotation_file, categories_file, img_dir, x_transform=None, y_transform=None):
        self.img_dir = img_dir
        annotations = pd.read_csv(annotation_file)
        self.categories = pd.read_csv(categories_file)
        self.images = annotations["ImageId"]
        self.labels = annotations["TrueLabel"]
        self.x_transform = x_transform
        self.y_transform = y_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_id = self.images[idx]
        img_path = os.path.join(self.img_dir, image_id + ".png")

        # Decode inside a context manager so the file handle is released right
        # away instead of whenever the lazily-loaded image is garbage-collected.
        # convert("RGB") returns an independent, fully-loaded copy and guarantees
        # the 3 channels that the 3-element Normalize and the model expect.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.x_transform:
            image = self.x_transform(image)

        # Shift TrueLabel down by one (presumably 1-based in the CSV — confirm).
        label = self.labels[idx] - 1
        if self.y_transform:
            label = self.y_transform(label)

        return image, label, image_id

In [69]:
def to_displayable_image(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Turn a normalised (C, H, W) image into an (H, W, C) CPU tensor in [0, 1].

    Undoes ``(x - mean) / std`` per channel, clamps to the valid display range,
    and moves channels last as matplotlib's ``imshow`` expects. Accepts tensors
    on any device (including ones that require grad) or array-likes such as
    loaded ``.npy`` data.
    """
    image = torch.as_tensor(image).detach().cpu().float()
    mean_t = torch.tensor(mean, dtype=image.dtype).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=image.dtype).view(-1, 1, 1)
    return (image * std_t + mean_t).clamp(0, 1).permute(1, 2, 0)


def show_image(datarow, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Plot a dataset item, titled with its label.

    Args:
        datarow: indexable whose element 0 is a normalised (C, H, W) image and
            element 1 is the label used as the title.
        mean, std: the per-channel statistics the image was normalised with;
            the defaults equal the values used by the notebook's ``transform``.
    """
    plt.title(datarow[1])
    # Plotting the raw normalised tensor would clip most values and distort colours.
    plt.imshow(to_displayable_image(datarow[0], mean, std).numpy())

In [70]:
# Per-channel normalisation statistics (the standard ImageNet values); the
# C&W cell also derives its search box from them.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# PIL image -> float tensor (in [0, 1] for 8-bit images) -> per-channel (x - mean) / std.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

In [71]:
# Prefer the first CUDA GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [72]:
# Items are (image, label, image_id) triples read via images.csv from images/.
dataset = AdversarialDataset(
    "images.csv",
    "categories.csv",
    "images",
    x_transform=transform,
)
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    drop_last=False,
)

# Pretrained Inception v3 classifier placed on the selected device.
model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT).to(device)

loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

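<h1>Training</h1>

Optional fine-tuning of the pretrained model. With `n_epochs = 0` (the current setting), this step is skipped and the pretrained weights are used unchanged.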

In [73]:
# Optional fine-tuning of the pretrained model; set n_epochs > 0 to enable it.
n_epochs = 0
running_loss = 0.0
n_correct = 0

for epoch in range(n_epochs):
    # Make the mode explicit: the C&W cell switches the model to eval mode.
    model.train()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    # The dataset yields (image, label, image_id); the ids are not needed here.
    for x_batch, y_batch, _image_ids in tqdm(dataloader, desc=f"epoch {epoch + 1}/{n_epochs}"):
        x_batch = x_batch.float().to(device)
        y_batch = y_batch.long().to(device)

        # In train mode torchvision's Inception v3 returns (logits, aux_logits);
        # only the main logits are used for the loss.
        output, _aux_logits = model(x_batch)
        loss = loss_function(output, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss (default reduction) averages over the batch, so weight
        # by the batch size to accumulate a per-sample sum.
        n_batch = x_batch.size(0)
        running_loss += loss.item() * n_batch
        n_seen += n_batch

        _, preds = torch.max(output, dim=1)
        n_correct += (preds == y_batch).sum().item()

    n_seen = max(n_seen, 1)  # guard against an empty dataloader
    print(f"epoch {epoch + 1}/{n_epochs}: loss={running_loss / n_seen:.4f}, accuracy={n_correct / n_seen:.4f}")

<h1>Testing</h1>

In [74]:
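# Accuracy of `model` on the clean dataset (disabled: it is a full pass over the data).
# Note: the dataset yields (image, label, image_id) triples, and in eval mode
# Inception v3 returns just the logits tensor.
# model.eval()
# n_correct = 0
# 
# with torch.no_grad():
#     for x_batch, y_batch, _ in tqdm(dataloader):
#         x_batch = x_batch.to(device)
#         y_batch = y_batch.to(device)
# 
#         output = model(x_batch)
# 
#         _, preds = torch.max(output, dim=1)
#         n_correct += torch.sum(preds == y_batch).item()
# 
# print(n_correct/len(dataset))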

<h3>Generate Carlini&ndash;Wagner (C&amp;W) adversarial images</h3>

In [75]:
import cw_impl.cw as cw


def generate_cw_samples(model, dataloader, save_to_disk=True, save_dir="adversarial_images", device=None):
    """Run an untargeted Carlini-Wagner L2 attack over every batch of ``dataloader``.

    Args:
        model: classifier under attack. It is switched to eval mode in place.
        dataloader: yields ``(inputs, targets, input_ids)`` batches, with ``inputs``
            normalised using the notebook-level ``mean``/``std``.
        save_to_disk: when True, each adversarial example is written to
            ``<save_dir>/<input_id>.npy``.
        save_dir: output directory; it is created if it does not exist.
        device: where to run the attack; defaults to the device holding the
            model's parameters, so inputs always land next to the model.
    """
    if device is None:
        device = next(model.parameters()).device

    # The attack operates in normalised space, so map the pixel range [0, 1]
    # through (x - mean) / std. The adversary takes one scalar box, so the
    # loosest bounds over all channels are used (per-channel limits are not exact).
    inputs_box = (
        min((0 - m) / s for m, s in zip(mean, std)),
        max((1 - m) / s for m, s in zip(mean, std)),
    )

    adversary = cw.L2Adversary(targeted=False,
                               confidence=0.0,
                               search_steps=10,
                               abort_early=True,
                               box=inputs_box,
                               optimizer_lr=5e-4)

    if save_to_disk:
        os.makedirs(save_dir, exist_ok=True)

    model.eval()
    for inputs, targets, input_ids in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        adversarial_examples = adversary(model, inputs, targets, to_numpy=False)

        if save_to_disk:
            # np.save cannot convert CUDA tensors or tensors that require grad,
            # so bring the batch to host memory as a plain array first.
            batch = adversarial_examples
            if torch.is_tensor(batch):
                batch = batch.detach().cpu().numpy()
            for i in range(batch.shape[0]):
                # Save as .npy rather than .png to avoid the precision lost when rescaling.
                np.save(os.path.join(save_dir, input_ids[i] + ".npy"), batch[i])


generate_cw_samples(model, dataloader)
        

Using scale consts: [0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001]
batch [0] loss: 0.060366254299879074
batch [10] loss: 0.9654200077056885
batch [20] loss: 0.23800158500671387
batch [30] loss: 0.06144149973988533
batch [40] loss: 0.07452215254306793
batch [50] loss: 0.06957646459341049
batch [60] loss: 0.06344868242740631
batch [70] loss: 0.06080745905637741
batch [80] loss: 0.060240693390369415
batch [90] loss: 0.06039150804281235
batch [100] loss: 0.06029155105352402
batch [110] loss: 0.06021219864487648
batch [120] loss: 0.06022156774997711
batch [130] loss: 0.06021089851856232
batch [140] loss: 0.06022205203771591
batch [150] loss: 0.06020442396402359
batch [160] loss: 0.060220785439014435
batch [170] loss: 0.060203634202480316
batch [180] loss: 0.060196541249752045
batch [190] loss: 0.060198940336704254
batch [200] loss: 0.06019965186715126
batch [210] loss: 0.06021055579185486
batch [220] loss: 0.06022053584456444

In [76]:
def test_batch(inputs, outputs, device):
    inputs = inputs.to(device)
    outputs = outputs.to(device)
    
    output = model(inputs)

    _, preds = torch.max(output, dim=1)
    print(preds)
    return torch.sum(preds == outputs)/inputs.shape[0]