[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lorenzobasile/DeepLearningMHPC/blob/main/3_adv_interp_attention.ipynb)

# Lab 3: Adversarial attacks, Interpretability and Attention mechanism

### Recap from previous Lab

* We saw the main techniques to mitigate overfitting in neural networks;
* We built and trained a convolutional network for image classification;
* We saw how to leverage pre-trained parameters to transfer the network's knowledge to new tasks.

### Today

We will show an example of one of the most intriguing phenomena in neural networks, **adversarial vulnerability**, in the context of CIFAR-10 classification. Then, we will interpret the behaviour of our image classifier using **GradCAM**. Finally, we will move away from CNNs and towards transformers, by implementing an **attention** layer from scratch.

In [None]:
import torch
import matplotlib.pyplot as plt
import torchvision

# Adversarial attacks

Adversarial attacks are small and maliciously crafted perturbations to the input data of DL systems, usually imperceptible to human eyes but highly disruptive for the neural network. By exploiting the model’s sensitivity to slight changes in input, these attacks can cause incorrect and often unpredictable behavior, even in well-trained systems.

<img src="https://github.com/lorenzobasile/DeepLearningMHPC/blob/main/images/adv.png?raw=1" width="800"/>

This phenomenon reveals fundamental weaknesses in how neural networks process information, highlighting a gap between human perception and machine learning models. Understanding adversarial attacks is crucial for assessing the reliability of deep learning, especially in applications where robustness and security are critical, such as autonomous driving or medical diagnosis.

## Training a CNN on CIFAR-10

The first step we take is training a convolutional classifier on CIFAR-10 images. CIFAR-10 is another popular benchmark dataset, much like MNIST. However, it represents a significant step up in terms of complexity from MNIST: images are now coloured, and slightly larger (each image is represented as 32x32 pixels over 3 (red, green, blue) channels). They belong to 10 categories: airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks.

We are not aiming for state-of-the-art performance, but to improve the classification capability of our CNN, we can use light data augmentation. When using `torchvision` datasets, it is very simple to include data augmentation by exploiting the `transforms` submodule. Each time an image is loaded from the dataloader, it may get horizontally flipped and/or cropped, adding to the variability of the data distribution.

In [None]:
batch_size = 64

transforms = {
    'train': torchvision.transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomCrop(32, 4),
        torchvision.transforms.ToTensor()
        ]),
    'test': torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ]),
}

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms['train'])
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms['test'])

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)

The data and network architecture we will be using today start to approach real-world sizes. For this experiment, it is a good idea to speed up training through GPU acceleration. A GPU runtime can be selected in Colab through the menu option `Runtime>Change runtime type`. Once this has been done, the following PyTorch command will automatically select `cuda:0` (the first GPU) as the current device.

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
!nvidia-smi

Our network has a simple structure, with 2 convolutional layers and ReLU activations. We have to pay extra attention to the device: by default, PyTorch loads the model (and everything else) to CPU; if a GPU has to be used, the model must be loaded to the correct device by using `.to(device)`.

In [None]:
class CNN(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, stride=1, padding=0),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
            torch.nn.Dropout(p=0.2),
            torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
        )
        self.pool = torch.nn.Sequential(
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(kernel_size=2),
            torch.nn.Flatten()
        )

        self.head = torch.nn.Linear(128*7*7, 10)

    def forward(self, x):
        x = self.conv(x)
        self.last_feature_map = x
        # this will become useful later
        if self.last_feature_map.requires_grad:
            self.last_feature_map.retain_grad()
        x = self.pool(x)
        return self.head(x)

model = CNN().to(device)
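
The input size of the final linear layer, `128*7*7`, follows from the way each layer changes the spatial resolution. A minimal sketch that traces the shapes with a dummy batch is given below. The layer list mirrors the network above but is rebuilt here so that the snippet is self-contained; dropout is omitted since it does not affect shapes.

```python
import torch

layers = [
    torch.nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=0),    # 32x32 -> 28x28
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2),                             # 28x28 -> 14x14
    torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),  # 14x14 -> 14x14
    torch.nn.ReLU(),
    torch.nn.AvgPool2d(kernel_size=2),                             # 14x14 -> 7x7
    torch.nn.Flatten(),                                            # 128*7*7 = 6272
]

dummy = torch.zeros(1, 3, 32, 32)  # a dummy batch with one CIFAR-10-sized image
shapes = []
with torch.no_grad():
    for layer in layers:
        dummy = layer(dummy)
        shapes.append(tuple(dummy.shape))
        print(type(layer).__name__, tuple(dummy.shape))
```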

Just the usual optimizer and loss definition...

In [None]:
optimizer=torch.optim.Adam(model.parameters(), lr=1e-3)
loss=torch.nn.CrossEntropyLoss()

...and the usual code to compute the accuracy.

In [None]:
def get_accuracy(model, dataloader, device):
    model.eval()
    with torch.no_grad():
        correct=0
        for x, y in iter(dataloader):
            x=x.to(device)
            y=y.to(device)
            out=model(x)
            correct+=(torch.argmax(out, axis=1)==y).sum()
        return (correct/len(dataloader.dataset)).item()

Our model reaches a reasonable accuracy relatively quickly. This value is far from the state of the art on CIFAR-10 (which is well above 99%), but it serves as a baseline for the experiments below.

In [None]:
epochs=10
for epoch in range(epochs):
    print("Test accuracy: ", get_accuracy(model, testloader, device))
    model.train()
    print("Epoch: ", epoch)
    for x, y in iter(trainloader):
        x=x.to(device)
        y=y.to(device)
        out=model(x)
        l=loss(out, y)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
print("Final accuracy: ", get_accuracy(model, testloader, device))

## Fast Gradient Sign attack

We will implement the simplest form of gradient-based adversarial attack, the Fast Gradient Sign Method (FGSM), introduced by Goodfellow et al. in 2014. Given an input image $x$ with true label $y$ and network prediction $\hat{y}$, the adversarial perturbation is simply the sign of the gradient of the loss function (in our case, the cross-entropy) with respect to $x$:
$$
\Delta = \text{sign}(\nabla_x L(y,\hat{y}))
$$
The perturbation is then rescaled by a factor $\epsilon$, so that it does not exceed a given threshold in $\ell_\infty$ norm. Common values for $\epsilon$ include $\{1,2,3,...,8\}/255$.

The adversarial image is then obtained as:
$$
x'=x+\epsilon\Delta
$$
After this computation, we clamp the perturbed image to $[0,1]$, as this is the original range of the pixels of the clean image.

In this lab, given its simplicity, we are implementing FGSM from scratch. However, in standard research practice, attack algorithms can be taken directly from libraries, such as [advertorch](https://advertorch.readthedocs.io/en/latest/index.html) or [foolbox](https://foolbox.readthedocs.io/en/stable/).

In [None]:
def fgsm_attack(model, image, label, epsilon=8/255):
    # work on a detached copy that requires gradients: the caller's tensor is
    # left untouched, and repeated calls do not accumulate gradients on it
    image = image.clone().detach().requires_grad_(True)

    output = model(image)
    loss = torch.nn.functional.cross_entropy(output, label)

    # backward pass: compute gradients of the loss w.r.t. the image
    model.zero_grad()
    loss.backward()

    # create adversarial example
    perturbed_image = image + epsilon * image.grad.sign()
    # keep pixel values valid and cut the graph back to the input
    perturbed_image = torch.clamp(perturbed_image, 0, 1).detach()

    # return the adversarial image and the perturbation actually applied
    return perturbed_image, perturbed_image - image.detach()
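
As a quick sanity check of the formulas above, the sketch below runs one FGSM step on a hypothetical toy setup (a linear classifier and random "images", both assumptions made only to keep the snippet self-contained) and verifies that the perturbation stays within the $\ell_\infty$ budget $\epsilon$. In this toy setting the loss on the perturbed batch is also expected not to decrease.

```python
import torch

torch.manual_seed(0)

# hypothetical stand-ins: a linear classifier and a batch of random "images" in [0, 1]
toy_model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
images = torch.rand(4, 3, 32, 32)
labels = torch.randint(0, 10, (4,))
epsilon = 8 / 255

# one FGSM step, following the formulas above
images.requires_grad_(True)
clean_loss = torch.nn.functional.cross_entropy(toy_model(images), labels)
clean_loss.backward()
adv = torch.clamp(images + epsilon * images.grad.sign(), 0, 1).detach()
pert = adv - images.detach()

with torch.no_grad():
    adv_loss = torch.nn.functional.cross_entropy(toy_model(adv), labels)

# the perturbation never exceeds epsilon in l-infinity norm (up to rounding)
print(pert.abs().max().item() <= epsilon + 1e-6)
# loss on the clean batch vs. loss on the perturbed batch
print(clean_loss.item(), adv_loss.item())
```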

We have to compute gradients to apply FGSM, hence the context manager `torch.no_grad()` cannot be applied here.

In [None]:
def get_adversarial_accuracy(model, dataloader, attack, device):
    model.eval()
    correct=0
    for x, y in iter(dataloader):
        x=x.to(device)
        y=y.to(device)
        adv, _ =attack(model, x, y)
        with torch.no_grad():
            out=model(adv)
            correct+=(torch.argmax(out, axis=1)==y).sum()
    return (correct/len(dataloader.dataset)).item()
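
To see why `torch.no_grad()` can wrap the final forward pass but not the attack itself, consider the following minimal sketch on a hypothetical tiny linear layer: inside the context manager no computational graph is recorded, so there is nothing to backpropagate through.

```python
import torch

layer = torch.nn.Linear(4, 2)  # hypothetical tiny model
x = torch.rand(1, 4, requires_grad=True)

# inside no_grad, no computational graph is recorded
with torch.no_grad():
    out_no_grad = layer(x)
print(out_no_grad.requires_grad)  # False: calling backward() here would fail

# outside no_grad, gradients w.r.t. the input are available
out = layer(x)
out.sum().backward()
print(x.grad.shape)
```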

A perturbation as small as $\epsilon=\frac{8}{255}$ is already enough to disrupt most of the classifier's performance.

In [None]:
from functools import partial

get_adversarial_accuracy(model, testloader, partial(fgsm_attack, epsilon=8/255), device)

We can now visualize a few examples of clean and adversarially attacked images.

In [None]:
x,y=next(iter(testloader))
x=x.to(device)
y=y.to(device)
adv, pert =fgsm_attack(model, x, y, 8/255)
adversarial_y = model(adv).argmax(-1)

In [None]:
image_idx=1

fig, axes = plt.subplots(1, 3, figsize=(10, 5))

axes[0].imshow(x[image_idx].detach().cpu().permute(1, 2, 0))
axes[0].set_title(f"Clean: {testset.classes[y[image_idx]]}")

axes[1].imshow(adv[image_idx].detach().cpu().permute(1, 2, 0))
axes[1].set_title(f"Adversarial: {testset.classes[adversarial_y[image_idx]]}")

# the perturbation lies in [-8/255, 8/255]: rescale it to [0, 1] to make it visible
axes[2].imshow((pert[image_idx].detach().cpu().permute(1, 2, 0) * 255 / 16 + 0.5).clamp(0, 1))
axes[2].set_title("Perturbation (rescaled)")

plt.show()


In [None]:
pert[1].max()

## Adversarial defense

Defending against adversarial attacks is complex, and it is still an open problem in DL research.

A possible approach is adversarial training, in which the network is fine-tuned on both clean and adversarial data. Note that adversarial data depend on the current network weights: during adversarial training, the attack algorithm has to be run again after each optimization step.

Adversarial training can significantly improve robustness, at the cost of a (hopefully small) drop in clean accuracy.

In [None]:
epochs=10
optimizer=torch.optim.Adam(model.parameters(), lr=1e-3)
loss=torch.nn.CrossEntropyLoss()
for epoch in range(epochs):
    print("Initial test accuracy: ", get_accuracy(model, testloader, device))
    print("Initial adversarial accuracy: ", get_adversarial_accuracy(model, testloader,partial(fgsm_attack, epsilon=4/255), device))
    model.train()
    print("Epoch: ", epoch)
    for x, y in iter(trainloader):
        x=x.to(device)
        y=y.to(device)
        clean_out=model(x)
        l=loss(clean_out, y)
        adv, pert =fgsm_attack(model, x, y, 4/255)
        adv_out=model(adv)
        l+=loss(adv_out, y)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
print("Final clean accuracy: ", get_accuracy(model, testloader, device))
print("Final adversarial accuracy: ", get_adversarial_accuracy(model, testloader,partial(fgsm_attack, epsilon=4/255), device))

# GradCAM

Neural networks, especially deep ones, are often criticized as black boxes: they can achieve impressive accuracy, but give us little insight into why they make a particular decision. This lack of interpretability can be a serious concern in applications where understanding the model’s reasoning is crucial.

GradCAM (Gradient-weighted Class Activation Mapping, Selvaraju et al., 2017) is a simple but powerful method to visualize what parts of an input image a convolutional neural network focuses on when making a prediction. It works by tracing back the gradients from a specific class score to the last convolutional layer (or, in principle, any other convolutional layer), producing a heatmap that highlights the regions of the image most influential to the decision. This helps us see whether the network is attending to the right features or relying on spurious patterns.

<img src="https://github.com/lorenzobasile/DeepLearningMHPC/blob/main/images/gradcam.png?raw=1" width="800"/>

GradCAM is straightforward to implement, assuming that one has access to the activations and gradients of the convolutional layer of interest.

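In a nutshell, let $y_c$ be the pre-softmax activation (logit) of the desired class $c$, and $A_k$ the $k$-th feature map of the chosen convolutional layer. The GradCAM activation map is computed as:
$$
M_c(i,j)=\text{ReLU}\left(\sum_k \alpha_k^c A_k(i,j)\right)
$$
where the weight $\alpha_k^c$ is the gradient of $y_c$ with respect to $A_k$, averaged over the $Z$ spatial positions of the feature map:
$$
\alpha_k^c=\frac{1}{Z}\sum_{i,j}\frac{\partial{y_c}}{\partial A_k(i,j)}
$$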

The heatmap computed by GradCAM has the same spatial dimensions as the feature maps of the chosen layer, which usually differ from those of the input. To 'project' the heatmap onto the input space, we upsample it using bilinear interpolation.
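
The two formulas and the upsampling step can be sketched on a hypothetical toy network (one strided convolution and a linear classifier, chosen only so that the feature maps are smaller than the input). The names `conv`, `head` and the class index `c` are assumptions made for illustration.

```python
import torch

torch.manual_seed(0)

# hypothetical toy network: one strided convolution followed by a linear classifier
conv = torch.nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
head = torch.nn.Linear(8 * 8 * 8, 10)

image = torch.rand(1, 3, 16, 16)
A = conv(image)          # feature maps A_k, shape [1, 8, 8, 8]
A.retain_grad()          # keep dy_c/dA after the backward pass
logits = head(torch.relu(A).flatten(1))

c = 3                    # class of interest (arbitrary choice)
logits[0, c].backward()  # backprop from the logit y_c, not from a loss

# alpha_k^c: gradient averaged over the Z = H*W spatial positions
alpha = A.grad.mean(dim=(2, 3))                                      # [1, 8]
# M_c: ReLU of the alpha-weighted sum of feature maps
cam = torch.relu((alpha[:, :, None, None] * A).sum(dim=1)).detach()  # [1, 8, 8]

# upsample the 8x8 heatmap to the 16x16 input resolution
upsampled = torch.nn.functional.interpolate(
    cam.unsqueeze(1), size=image.shape[2:], mode='bilinear', align_corners=False
)
print(alpha.shape, cam.shape, upsampled.shape)
```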

We will make frequent use of the `squeeze` and `unsqueeze` functions of torch. `squeeze` removes all the dimensions of size 1 from the tensor (or only the specified one, if `dim` is given); `unsqueeze` adds a dimension of size 1 where specified.


In [None]:
x = torch.rand(3,4)

In [None]:
x=x.unsqueeze(dim=0)
print(x.shape)
x=x.unsqueeze(dim=-1)
print(x.shape)
x=x.squeeze()
print(x.shape)

In [None]:
def gradcam(image, label, model):
    model.eval()
    image=image.unsqueeze(0) # we need the batch dimension for the forward pass
    output=model(image).squeeze() # but we can remove it afterwards
    prediction = output.argmax(-1)
    print("Prediction: ", testset.classes[prediction])
    print("True label: ", testset.classes[label])
    model.zero_grad()
    output[label].backward() # we are doing backprop not on the loss function, but on the output
    feature_maps = model.last_feature_map.squeeze() # we remove the batch dimension
    weights = model.last_feature_map.grad.squeeze().mean(dim=(1,2)) # weight is the average importance over a filter
    weighted_feature_maps=(feature_maps*weights.reshape(-1,1,1)) # broadcasting
    grad_cam=torch.nn.functional.relu(weighted_feature_maps.sum(dim=0))
    grad_cam = grad_cam.unsqueeze(0).unsqueeze(0) # interpolate wants [N,C,H,W] data
    print(grad_cam.shape)
    grad_cam = torch.nn.functional.interpolate(grad_cam, size=(image.shape[2:]), mode='bilinear', align_corners=False)
    return grad_cam.squeeze().detach().cpu()

Once the class activation map has been computed, it can simply be overlaid on the original image to see which parts of it matter the most for the classification.

In [None]:
x,y=next(iter(testloader))
x=x.to(device)
y=y.to(device)
image_idx=0

heatmap = gradcam(x[image_idx], y[image_idx], model)
heatmap = heatmap - heatmap.min()
heatmap = heatmap / heatmap.max()

fig, axes = plt.subplots(1, 2, figsize=(10, 5))

axes[0].imshow(x[image_idx].detach().cpu().permute(1, 2, 0))
axes[0].set_title(f"Image: {testset.classes[y[image_idx]]}")

axes[1].imshow(x[image_idx].detach().cpu().permute(1, 2, 0))
axes[1].set_title(f"GradCAM: {testset.classes[y[image_idx]]}")
axes[1].imshow(heatmap, alpha=heatmap, cmap='jet', interpolation='bilinear')

plt.show()


# The attention mechanism

Attention is the core computational mechanism of transformer models. The aim of attention is to make different *tokens* communicate, either within the same sequence (self-attention) or between different sequences (cross-attention). In this notebook, we will implement self-attention, as it is the most common in modern transformers.

Given an input sequence of tokens $X$, the computation of attention starts by projecting it into three corresponding sequences $Q$ (queries), $K$ (keys) and $V$ (values), using learnable projection matrices:

$$
Q=W_QX, \quad K=W_KX, \quad V=W_VX
$$

Then, an *attention map* is produced. This map, whose shape is $|X|\times|X|$, encodes the relevance of each key entry with respect to each query entry. In simple terms, it represents how important each position in the sequence is to every other position.

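$$
A = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_K}}\right)
$$

where $d_K$ is the dimensionality of queries and keys, and the softmax is applied over the key dimension, so that the weights associated with each query sum to 1.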

Once these weights are computed, the output sequence is obtained by weighting the values and applying a final projection $W_O$:

$$
Y = W_O(A V)
$$
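
The three steps above can be sketched on a single short sequence with plain matrix products. This is only an illustration under a few assumptions: random matrices stand in for the learnable projections, tokens are stored as rows (so the projections multiply from the right), and the scores are scaled by $\sqrt{d_K}$, the same scaling used in the implementation below.

```python
import math
import torch

torch.manual_seed(0)

T, d_in, d_k = 5, 16, 8   # sequence length and dimensions (arbitrary)
X = torch.randn(T, d_in)  # one token per row

# random matrices standing in for the learnable projections
W_Q, W_K, W_V = (torch.randn(d_in, d_k) for _ in range(3))
W_O = torch.randn(d_k, d_k)

# tokens are stored as rows here, so the projections multiply from the right
Q, K, V = X @ W_Q, X @ W_K, X @ W_V

A = torch.softmax(Q @ K.T / math.sqrt(d_k), dim=-1)  # attention map, [T, T]
Y = (A @ V) @ W_O                                     # output sequence, [T, d_k]

print(A.shape, Y.shape)
print(A.sum(dim=-1))  # each row sums to 1 (up to rounding)
```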

## Self-attention

As a first step, we will implement the self-attention computation in plain PyTorch, by only using linear layers.

In [None]:
import math

class Attention(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.d = output_dim
        self.q_proj = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.k_proj = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.v_proj = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.out_proj = torch.nn.Linear(output_dim, output_dim, bias=True)

    def forward(self, x):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        # q, k: [batch, tokens, features]; contract over the feature index d
        unnormalized_map = torch.einsum('Btd,BTd->BtT', q, k)/math.sqrt(self.d)
        weights = torch.nn.functional.softmax(unnormalized_map, dim=-1)
        out = torch.einsum('BtT,BTd->Btd', weights, v)
        #alternative with bmm
        #unnormalized_map = torch.bmm(q, k.permute(0,2,1))/math.sqrt(self.d)
        #out = torch.bmm(weights, v)
        return self.out_proj(out)
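
For single-head attention on 3D tensors of shape `[batch, tokens, features]`, the `einsum` and `bmm` formulations compute the same quantities. A minimal sketch that checks this on random tensors (which stand in for the projected queries, keys and values) is given below.

```python
import math
import torch

torch.manual_seed(0)

B, T, d = 2, 5, 8                  # batch size, sequence length, feature size
q, k, v = torch.randn(3, B, T, d)  # random stand-ins for the projected sequences

# einsum: contract the feature index d, keep one index per query (t) and key (T)
scores_einsum = torch.einsum('Btd,BTd->BtT', q, k) / math.sqrt(d)
# bmm: batched matrix product between q and the transposed k
scores_bmm = torch.bmm(q, k.permute(0, 2, 1)) / math.sqrt(d)

weights = torch.softmax(scores_einsum, dim=-1)
out_einsum = torch.einsum('BtT,BTd->Btd', weights, v)
out_bmm = torch.bmm(weights, v)

print(torch.allclose(scores_einsum, scores_bmm, atol=1e-5))
print(torch.allclose(out_einsum, out_bmm, atol=1e-5))
```

The `einsum` notation makes explicit which index is summed over, while `bmm` is closer to the matrix formulas written above; which one to prefer is mostly a matter of readability.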



Actually, torch already implements attention, so there is no need to write this code from scratch when using transformers. We can compare our implementation with torch's built-in one to verify that there is no difference in the results.

In most applications, the simplest way to use attention is by means of the class [MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html), but here we use the functional interface because it allows greater flexibility.

In [None]:
attn = Attention(128, 128)
x=torch.randn(8, 10, 128)
myattn = attn(x)

#in torch, the multihead attention is implemented supposing the first dimension is the token dimension
y = x.permute(1, 0, 2)

mha2_out = torch.nn.functional.multi_head_attention_forward(
    query=y,
    key=y,
    value=y,
    embed_dim_to_check=128,
    num_heads=1,
    use_separate_proj_weight=True,
    in_proj_weight=None,
    in_proj_bias=None,
    bias_k=None,
    bias_v=None,
    need_weights=False,
    q_proj_weight=attn.q_proj.weight,
    k_proj_weight=attn.k_proj.weight,
    v_proj_weight=attn.v_proj.weight,
    out_proj_bias=attn.out_proj.bias,
    out_proj_weight=attn.out_proj.weight,
    add_zero_attn=False,
    dropout_p=0,
)[0].permute(1, 0, 2)
print(torch.allclose(mha2_out, myattn, atol=1e-6))


# Homework

- Read the paper [Adversarial Examples in the Physical World](https://arxiv.org/pdf/1607.02533), and implement the Basic Iterative Method explained in section 2.2. In short, BIM is an iterative version of FGSM.
- Test your code on our CNN, classifying CIFAR-10 data, and visually inspect the parts of the image that the attack is using to misguide the network. For this step, you can use GradCAM on some clean and adversarial images, using the correct label in the former case and the adversarial one in the latter.
- Perform adversarial training on BIM, and keep track of clean and adversarial accuracy during training. Extra: you can also measure the robustness of the model trained on BIM on the vanilla FGSM attack.