In [None]:
# Tutorial: https://pytorch.org/tutorials/beginner/fgsm_tutorial.html
# FGSM: perturbed_image = image + epsilon * sign(data_grad) = x + ϵ * sign(∇ₓ J(θ, x, y))

In [39]:
# import
import torch
import wandb
from fastai.callback.wandb import WandbCallback
from fastai.data.all import (
    CategoryBlock,
    DataBlock,
    DataLoaders,
    RandomSplitter,
    RegressionBlock,
)

from fastai.vision.all import *
from fastai.losses import CrossEntropyLossFlat
from fastai.vision.data import ImageBlock, ImageDataLoaders
# from fastai.vision.learner import Learner, accuracy, vision_learner
# from fastai.vision.models import resnet18, resnet34
from fastai.vision.utils import get_image_files
from torch import nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
from math import radians

In [24]:
# Import explicitly rather than relying on the star import to leak the module name.
import fastai

fastai.__version__

'2.7.11'

In [25]:
# First CUDA GPU (equivalent to torch.device(0)).
device = torch.device("cuda", 0)
device

device(type='cuda', index=0)

In [51]:
# Folder holding the images to attack (passed to get_dls below).
data_folder = "/data/clark/scr2023/data/WanderingStaticTextures/"
# The exported learner lives in the "models" subfolder of the data folder.
models_folder = data_folder + "models/"
model_name = "wandering-static_rep00.pkl"

In [52]:
from pathlib import Path

# Path join works whether or not models_folder ends with a slash.
# NOTE(review): the .pkl is presumably a pickle (fastai documents load_learner as
# using torch.load), and unpickling can execute arbitrary code, so only load
# learners from trusted sources.
learner = load_learner(Path(models_folder) / model_name, cpu=False)

In [30]:
# FGSM step sizes to evaluate; 0 is the no-attack baseline.
epsilons = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3]

In [72]:
def y_from_filename(rotation_threshold, filename) -> str:
    """Return the steering label ("left", "right" or "forward") encoded in a filename.

    The third underscore-separated field of the file stem holds the angle, with
    "p" standing in for the decimal point (``-1p50`` means -1.50). Angles strictly
    above ``rotation_threshold`` map to "left"; otherwise angles strictly below
    ``-rotation_threshold`` map to "right"; anything else maps to "forward".

    Example: "path/to/file/001_000011_-1p50.png" --> "right"
    """
    stem_fields = Path(filename).stem.split("_")
    angle = float(stem_fields[2].replace("p", "."))

    if angle > rotation_threshold:
        label = "left"
    elif angle < -rotation_threshold:
        label = "right"
    else:
        label = "forward"
    return label
        
def get_dls(data_path: str, rotation_threshold: float = radians(5), bs: int = 1):
    """Build fastai ``ImageDataLoaders`` over the images found under ``data_path``.

    Each image is labelled by ``y_from_filename`` (steering direction parsed from
    the filename). ``valid_pct=0`` puts every image in the training split, so
    callers should read ``.train``; that loader is shuffled.

    Args:
        data_path: folder handed to ``get_image_files`` to collect the images.
        rotation_threshold: angle magnitude above which a sample is labelled
            "left"/"right" instead of "forward". Same units as the angle in the
            filename; the default ``radians(5)`` is the value previously hard-coded.
        bs: batch size (default 1, i.e. one image per batch as before).

    Returns:
        The ``DataLoaders`` produced by ``ImageDataLoaders.from_name_func``.
    """
    image_filenames: list = get_image_files(data_path)  # type:ignore

    # Bind the threshold so fastai can call the labeller with just a filename.
    label_func = partial(y_from_filename, rotation_threshold)

    # NOTE(review): no item/batch transforms (e.g. Normalize) are applied, so
    # images arrive as raw pixel tensors. Confirm this matches how the learner
    # was trained.
    return ImageDataLoaders.from_name_func(
        data_path,
        image_filenames,
        label_func,
        valid_pct=0,
        shuffle=True,
        bs=bs,
    )

In [73]:
# FGSM attack code
def fgsm_attack(image, epsilon, data_grad, clamp_min=0.0, clamp_max=1.0):
    """Apply one Fast Gradient Sign Method step to ``image``.

    Computes ``image + epsilon * sign(data_grad)`` and clips the result to
    ``[clamp_min, clamp_max]``.

    Args:
        image: input tensor being attacked.
        epsilon: perturbation step size.
        data_grad: gradient of the loss w.r.t. ``image`` (must broadcast with it).
        clamp_min, clamp_max: clipping bounds. The defaults keep pixels in [0, 1].
            Pass the bounds of your input space if images are normalized, or
            ``None`` for a bound to leave that side unclipped.

    Returns:
        The perturbed tensor. If ``image`` requires grad, so does the result.
    """
    perturbed_image = image + epsilon * data_grad.sign()
    if clamp_min is None and clamp_max is None:
        return perturbed_image
    return torch.clamp(perturbed_image, min=clamp_min, max=clamp_max)

In [74]:
def test(model, test_loader, epsilon, max_examples=5):
    """Measure a classifier's accuracy under an FGSM attack of strength ``epsilon``.

    Clean images are classified first. Samples that are already misclassified
    are not attacked and count as failures. Correctly classified samples are
    perturbed with ``fgsm_attack`` and re-classified. Accuracy is the fraction
    of all samples still classified correctly after the attack.

    Args:
        model: a fastai ``Learner`` (its ``.model`` is used) or a plain
            ``nn.Module`` that returns class logits of shape (batch, n_classes).
        test_loader: iterable of ``(images, targets)`` batches. Targets must be
            class indices in the same label order as the model's outputs.
        epsilon: FGSM step size passed to ``fgsm_attack``.
        max_examples: maximum number of examples kept for visualisation.

    Returns:
        ``(final_acc, adv_examples)``. ``adv_examples`` holds up to
        ``max_examples`` tuples of ``(initial_pred, final_pred, image_array)``:
        successful attacks, or correct clean predictions when ``epsilon == 0``.
    """
    # A Learner is not the network itself: Learner.predict is a single-item,
    # no-grad inference helper. The attack needs gradients from the wrapped module.
    net = model if isinstance(model, nn.Module) else model.model
    was_training = net.training
    net.eval()  # inference behaviour for BatchNorm/Dropout during the attack
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # NOTE(review): if the network was trained on normalized inputs but the loader
    # yields raw pixels, predictions (and the [0, 1] clamp in fgsm_attack) will be
    # off. Confirm the preprocessing matches training.
    correct = 0
    total = 0
    adv_examples = []

    try:
        for data, target in test_loader:
            # Fresh leaf tensor tracking gradients; the loader's tensor is untouched.
            data = data.to(device).detach().clone().requires_grad_(True)
            target = target.to(device).view(-1).long()
            total += target.numel()

            output = net(data)
            init_pred = output.argmax(dim=1)
            initially_correct = init_pred == target
            # Nothing to attack in this batch if every sample is already wrong.
            if not bool(initially_correct.any()):
                continue

            # Summed loss: in eval mode samples don't interact, so each sample's
            # input gradient equals the gradient of its own loss.
            loss = nn.functional.cross_entropy(output, target, reduction="sum")
            (data_grad,) = torch.autograd.grad(loss, data)

            perturbed_data = fgsm_attack(data, epsilon, data_grad)

            with torch.no_grad():
                final_pred = net(perturbed_data).argmax(dim=1)

            for i in initially_correct.nonzero().flatten().tolist():
                init_label = init_pred[i].item()
                final_label = final_pred[i].item()
                still_correct = final_label == target[i].item()
                if still_correct:
                    correct += 1
                # Keep successful attacks for plotting. With epsilon == 0 there is
                # no attack, so keep the clean correct predictions instead.
                if (not still_correct or epsilon == 0) and len(adv_examples) < max_examples:
                    adv_ex = perturbed_data[i].squeeze().detach().cpu().numpy()
                    adv_examples.append((init_label, final_label, adv_ex))
    finally:
        net.train(was_training)

    final_acc = correct / total if total else 0.0
    print(f"Epsilon: {epsilon}\tTest Accuracy = {correct} / {total} = {final_acc}")

    return final_acc, adv_examples

In [None]:
accuracies = []
examples = []

test_dls = get_dls(data_folder)
# `test` compares the learner's predicted class indices with this loader's target
# indices, which only makes sense if both use the same label vocabulary.
assert list(test_dls.vocab) == list(learner.dls.vocab), (
    f"label vocab mismatch: data {list(test_dls.vocab)} vs learner {list(learner.dls.vocab)}"
)
test_loader = test_dls.train

# Run test for each epsilon
for eps in epsilons:
    acc, ex = test(learner, test_loader, eps)
    accuracies.append(acc)
    examples.append(ex)

In [None]:
# Plot the stored adversarial examples: one row per epsilon, one column per example.
n_rows = len(epsilons)
# Rows can hold different numbers of examples, so size the grid by the longest row.
n_cols = max((len(row) for row in examples), default=0)
if n_cols == 0:
    print("No adversarial examples to plot.")
else:
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(8, 10), squeeze=False)
    for i, (eps, row) in enumerate(zip(epsilons, examples)):
        for j, ax in enumerate(axes[i]):
            ax.set_xticks([])
            ax.set_yticks([])
            if j == 0:
                ax.set_ylabel(f"Eps: {eps}", fontsize=14)
            if j >= len(row):
                # Unused cell. Keep column 0 visible so the row label still shows.
                if j > 0:
                    ax.axis("off")
                continue
            orig, adv, ex = row[j]
            # Presumably channels-first (C, H, W) for colour images; imshow
            # needs (H, W, C).
            if ex.ndim == 3 and ex.shape[0] == 3:
                ex = np.transpose(ex, (1, 2, 0))
            ax.set_title(f"{orig} -> {adv}")
            ax.imshow(ex, cmap="gray")
    fig.tight_layout()
    plt.show()

In [None]:
from pathlib import Path

# `args` is not defined anywhere in this notebook (it looks copied from an
# argparse script), so derive the run name from the evaluated model instead.
run_name = "fgsm-" + Path(model_name).stem

run = wandb.init(
    name=run_name,
    project="DataAugmentation",
    entity="arcslaboratory",
    notes="Training models with adversarial training",
    job_type="train",
)

if run is None:
    raise Exception("wandb.init() failed")

# TODO(review): nothing is logged to this run yet, e.g. the per-epsilon `accuracies`.