# Adversarial attacks

In [None]:
# IPython setup: autoreload re-imports changed modules (mode 2 = reload all
# modules before each cell executes), so edits to helper code such as
# adv_utils take effect without restarting the kernel; matplotlib figures
# are rendered inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys

# Add the parent directory to the module search path; presumably this is
# where adv_utils (imported in the next cell) lives — confirm against the
# repo layout.
sys.path.append('..')

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights

from adv_utils import FGSMAttack

## Load model

In [None]:
# Pretrained ResNet-18 weights bundle the matching preprocessing transform
# (it exposes the mean/std used for normalization) and the class names.
weights = ResNet18_Weights.DEFAULT
preprocessor = weights.transforms()
class_names = weights.meta['categories']

# Build the network with those weights; eval() returns the module itself,
# switched to inference behavior (e.g. batch-norm uses running statistics).
model = resnet18(weights=weights).eval()

## Load image

In [None]:
# Load the test image from disk.
# - convert('RGB') guarantees 3 channels, so the 3-channel normalization in
#   `preprocessor` also works for grayscale, palette or RGBA files.
# - The context manager closes the underlying file handle; convert() returns
#   an independent, fully loaded copy that stays valid after the close.
image_path = '../test.jpg'

with Image.open(image_path) as raw_image:
    image = raw_image.convert('RGB')

In [None]:
# Display the original (not yet preprocessed) image with square pixels.
fig = plt.figure(figsize=(6, 4))
ax = fig.add_subplot(1, 1, 1)

pixels = np.asarray(image)
ax.imshow(pixels)
ax.set_aspect('equal', adjustable='box')

fig.tight_layout()

## Run model

In [None]:
# Preprocess, then add a leading batch axis: (3, h, w) -> (1, 3, h, w).
x_chw = preprocessor(image)
x = x_chw[None]

# Forward pass without recording an autograd graph.
with torch.no_grad():
    logits = model(x)  # (1, 1000)

# Highest-scoring class per batch item, mapped to its human-readable name.
label_ids = torch.argmax(logits, dim=1)  # (1,)
labels = [class_names[idx] for idx in label_ids.tolist()]

print(f'Label: {labels}')

## Perform FGSM attack

In [None]:
# Untargeted FGSM attack against the model's own prediction (label_ids),
# using cross-entropy as the attacked loss.
# NOTE(review): x has already been normalized by `preprocessor`, so eps is
# presumably measured in normalized units rather than [0, 1] pixel units —
# confirm against FGSMAttack in adv_utils.
fgsm = FGSMAttack(model=model, criterion=nn.CrossEntropyLoss())

fgsm_x = fgsm(image=x, label=label_ids, eps=0.003, targeted=False)

In [None]:
# Classify the adversarial input exactly as the clean one was classified.
with torch.no_grad():
    fgsm_logits = model(fgsm_x)  # (1, 1000)

fgsm_label_ids = torch.argmax(fgsm_logits, dim=1)  # (1,)
fgsm_labels = [class_names[idx] for idx in fgsm_label_ids.tolist()]

print(f'Label: {fgsm_labels}')

In [None]:
# Undo the preprocessor's normalization, (t - mean) / std, by applying
# Normalize with mean' = -mean/std and std' = 1/std. Then clip to the
# displayable [0, 1] range.
inv_mean = [-m / s for m, s in zip(preprocessor.mean, preprocessor.std)]
inv_std = [1 / s for s in preprocessor.std]

renormalizer = transforms.Compose([
    transforms.Normalize(mean=inv_mean, std=inv_std),
    transforms.Lambda(lambda t: t.clamp(0, 1)),
])

In [None]:
# Show the clean and adversarial inputs side by side, each titled with the
# model's prediction. Each tensor is mapped back to [0, 1] pixel space by
# `renormalizer`, then reordered from (C, H, W) to (H, W, C) for imshow.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))

panels = (
    (ax1, x[0], labels[0]),
    (ax2, fgsm_x[0], fgsm_labels[0]),
)
for panel_ax, img_tensor, predicted in panels:
    # detach()/cpu(): Tensor.numpy() raises on tensors that require grad or
    # live on a non-CPU device. The attack output may carry autograd history
    # depending on how FGSMAttack builds it.
    rgb = renormalizer(img_tensor.detach().cpu()).permute(1, 2, 0).numpy()
    panel_ax.imshow(rgb)
    panel_ax.set_aspect('equal', adjustable='box')
    panel_ax.set_title(f'Predicted: {predicted}')

fig.tight_layout()