# Adversarial attacks with the Adversarial Robustness Toolbox (ART)

In [None]:
# reload edited modules automatically before executing each cell
%load_ext autoreload
%autoreload 2
# render matplotlib figures inline in the notebook
%matplotlib inline

# make modules in the parent directory importable (presumably where adv_utils lives — verify)
import sys
sys.path.append('..')

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import resnet101, ResNet101_Weights

from art.estimators.classification import PyTorchClassifier
from art.attacks.evasion import (
    FastGradientMethod,
    ProjectedGradientDescent
)

from adv_utils import download_file

## Load model

In [None]:
# pick the default pretrained ResNet-101 weights
weights = ResNet101_Weights.DEFAULT

# inference preprocessing transform shipped with those weights
preprocessor = weights.transforms()

# build the network with the pretrained weights and switch it to inference mode
model = resnet101(weights=weights).eval()

# category names, indexed by class id
class_names = weights.meta['categories']

In [None]:
def preprocess(img):
    """Apply the weights' preprocessing transform and prepend a batch axis.

    Returns a tensor of shape (1, C, H, W) suitable for feeding to ``model``.
    """
    return preprocessor(img).unsqueeze(0)


# per-channel normalization statistics, shaped (C, 1, 1) so they broadcast over (C, H, W)
mean = torch.as_tensor(preprocessor.mean).view(-1, 1, 1)
std = torch.as_tensor(preprocessor.std).view(-1, 1, 1)

# map a normalized (C, H, W) tensor back to displayable [0, 1] pixel values
renormalize = transforms.Compose([
    transforms.Lambda(lambda x: x * std + mean),  # reverse normalization
    transforms.Lambda(lambda x: x.clamp(0, 1)),  # clip to valid range
])

In [None]:
# The model consumes *normalized* inputs, so the valid image domain [0, 1] maps to
# [(0 - mean) / std, (1 - mean) / std] per channel. ART expects per-feature bounds
# matching input_shape, so broadcast the per-channel bounds to (3, 224, 224).
input_shape = (3, 224, 224)
clip_min = ((0 - mean) / std).expand(input_shape).numpy()
clip_max = ((1 - mean) / std).expand(input_shape).numpy()

# create ART model wrapper
estimator = PyTorchClassifier(
    model=model,
    loss=nn.CrossEntropyLoss(),
    input_shape=input_shape,
    nb_classes=len(class_names),
    # keep adversarial examples inside the valid pixel range, so the model
    # classifies exactly what is later displayed
    clip_values=(clip_min, clip_max),
    # ART moves the wrapped model to its device; the rest of the notebook feeds
    # `model` CPU tensors directly, so pin it to CPU to avoid device mismatches
    device_type='cpu',
)

## Load image

In [None]:
# load image (downloading a sample photo if it is not present yet)
image_path = '../test.jpg'

if not Path(image_path).exists():
    _ = download_file(
        url='https://upload.wikimedia.org/wikipedia/commons/4/48/Augustine_volcano_Jan_24_2006_-_Cyrus_Read.jpg',
        save_path=image_path
    )

# Force 3-channel RGB: the model expects 3 input channels, but image files may be
# grayscale, CMYK or RGBA. convert() reads the pixel data, so the file handle can
# be closed right away by the context manager.
with Image.open(image_path) as img:
    image = img.convert('RGB')

In [None]:
# display the original, unprocessed image
fig = plt.figure(figsize=(6, 4))
ax = fig.add_subplot(1, 1, 1)
ax.imshow(np.asarray(image))
ax.set_aspect('equal', adjustable='box')
fig.tight_layout()

## Run model

In [None]:
# preprocess image into a batch of one
x = preprocess(image)  # (1, 3, h, w)

# run model; no gradients are needed for plain inference
with torch.no_grad():
    logits = model(x)  # (1, 1000)

# top-1 class and its softmax probability for every image in the batch
probs = logits.softmax(dim=1)  # (1, 1000)
label_probs, label_ids = probs.max(dim=1)  # (1,)
labels = [class_names[lidx.item()] for lidx in label_ids]

for label, prob in zip(labels, label_probs):
    print(f'Predicted: {label} ({prob:.2f})')

## Untargeted FGSM attack

In [None]:
# untargeted FGSM with perturbation budget eps=0.005 (in normalized input space);
# with y=None ART presumably uses the model's own predictions as labels — verify
fgsm = FastGradientMethod(estimator=estimator, eps=0.005, targeted=False)

# ART operates on numpy arrays: convert the batch in, and the result back to torch
fgsm_x = torch.from_numpy(fgsm.generate(x=x.numpy(), y=None))

In [None]:
# classify the FGSM-perturbed batch
with torch.no_grad():
    fgsm_logits = model(fgsm_x)  # (1, 1000)

# top-1 class and its softmax probability per image
fgsm_probs = torch.softmax(fgsm_logits, dim=1)  # (1, 1000)
fgsm_label_probs, fgsm_label_ids = torch.max(fgsm_probs, dim=1)  # (1,)
fgsm_labels = [class_names[int(idx)] for idx in fgsm_label_ids]

for name, prob in zip(fgsm_labels, fgsm_label_probs):
    print(f'Predicted: {name} ({prob:.2f})')

In [None]:
# compare the original and FGSM-attacked images side by side
plot_idx = 0

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))
ax1, ax2 = axes

panels = [
    (ax1, x, f'Original: {labels[plot_idx]} ({label_probs[plot_idx]:.2f})'),
    (ax2, fgsm_x, f'Attacked: {fgsm_labels[plot_idx]} ({fgsm_label_probs[plot_idx]:.2f})'),
]
for ax, batch, title in panels:
    # undo normalization, then move channels last (C, H, W) -> (H, W, C) for imshow
    ax.imshow(renormalize(batch[plot_idx]).permute(1, 2, 0).numpy())
    ax.set_title(title)
    ax.set_aspect('equal', adjustable='box')
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')

fig.tight_layout()

## Targeted PGD attack

In [None]:
# class id that the targeted attack should push the prediction towards
target_label = 1
target_name = class_names[target_label]

print(f'Target: {target_name}')

In [None]:
# targeted PGD with an L-inf budget eps=0.02 (normalized input space),
# step size eps_step=0.001, for at most max_iter=70 iterations
pgd = ProjectedGradientDescent(
    estimator=estimator,
    norm=np.inf,
    eps=0.02,
    eps_step=0.001,
    max_iter=70,
    targeted=True
)

# one target class id per sample, so the cell works for any batch size (not just 1)
pgd_targets = np.full(x.shape[0], target_label)

pgd_x = pgd.generate(x=x.numpy(), y=pgd_targets)
pgd_x = torch.from_numpy(pgd_x)

In [None]:
# classify the PGD-perturbed batch without tracking gradients
with torch.no_grad():
    pgd_logits = model(pgd_x)  # (1, 1000)

# convert logits to probabilities and keep the top-1 entry per image
pgd_probs = pgd_logits.softmax(dim=1)  # (1, 1000)
top = pgd_probs.max(dim=1)
pgd_label_probs, pgd_label_ids = top.values, top.indices  # (1,)
pgd_labels = [class_names[class_id.item()] for class_id in pgd_label_ids]

for name, prob in zip(pgd_labels, pgd_label_probs):
    print(f'Predicted: {name} ({prob:.2f})')

In [None]:
# compare the original and PGD-attacked images side by side
plot_idx = 0

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))

images = (x[plot_idx], pgd_x[plot_idx])
titles = (
    f'Original: {labels[plot_idx]} ({label_probs[plot_idx]:.2f})',
    f'Attacked: {pgd_labels[plot_idx]} ({pgd_label_probs[plot_idx]:.2f})',
)

for ax, img, title in zip((ax1, ax2), images, titles):
    # de-normalize and reorder (C, H, W) -> (H, W, C) for matplotlib
    ax.imshow(renormalize(img).permute(1, 2, 0).numpy())
    ax.set_title(title)
    ax.set_aspect('equal', adjustable='box')
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')

fig.tight_layout()