[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lorenzobasile/DeepLearningMHPC/blob/main/3_adv_attention.ipynb)

# Lab 3: Adversarial attacks and Attention mechanism

### Recap from previous Lab

* We saw the main techniques to mitigate overfitting in neural networks;
* We built and trained a convolutional network for image classification;
* We saw how to leverage pre-trained parameters to transfer the network's knowledge to new tasks.

### Today

We will show an example of one of the most intriguing phenomena in neural networks, **adversarial vulnerability**, in the context of CIFAR-10 classification. Then, we will move away from CNNs and towards transformers, by implementing the **attention** layer and a simple transformer model from scratch.

In [None]:
import torch
import matplotlib.pyplot as plt
import torchvision

# Adversarial attacks

Adversarial attacks are small and maliciously crafted perturbations to the input data of DL systems, usually imperceptible to human eyes but highly disruptive for the neural network. By exploiting the model’s sensitivity to slight changes in input, these attacks can cause incorrect and often unpredictable behavior, even in well-trained systems. 

<img src="./images/adv.png" width="800"/>

This phenomenon reveals fundamental weaknesses in how neural networks process information, highlighting a gap between human perception and machine learning models. Understanding adversarial attacks is crucial for assessing the reliability of deep learning, especially in applications where robustness and security are critical, such as autonomous driving or medical diagnosis.

## Training a CNN on CIFAR-10

The first step we take is training a convolutional classifier on CIFAR-10 images. CIFAR-10 is another popular benchmark dataset, much like MNIST. However, it represents a significant step up in complexity from MNIST: images are now in colour and slightly larger, with 32x32 pixels over 3 channels (red, green, blue). They belong to 10 categories: airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks.

We are not aiming for state-of-the-art performance, but to improve the classification capability of our CNN, we can use light data augmentation. When using `torchvision` datasets, it is very simple to include data augmentation through the `transforms` submodule. Each time an image is loaded by the dataloader, it is horizontally flipped with probability 0.5, then padded by 4 pixels on each side and randomly cropped back to 32x32, which adds variability to the data distribution.

In [None]:
batch_size = 64

transforms = {
    'train': torchvision.transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomCrop(32, 4),
        torchvision.transforms.ToTensor()
        ]),
    'test': torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ]),
}

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms['train'])
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms['test'])

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)

The data and network architecture we will be using today start to approach real-world sizes, so it is a good idea to speed up training through GPU acceleration. A GPU runtime can be selected in Colab through the menu option `Runtime>Change runtime type`. Once this has been done, the following PyTorch command will select `cuda:0` (the first GPU) as the current device, falling back to the CPU when no GPU is available.

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Our network has a simple structure, with 2 convolutional layers and ReLU activations. We have to pay extra attention to the device: by default, PyTorch creates the model (and every tensor) on the CPU; to use a GPU, the model must be moved to it with `.to(device)`, and the same holds for each batch of data.

In [None]:
class CNN(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, stride=1, padding=0),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
            torch.nn.Dropout(p=0.2),
            torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
        )
        self.pool = torch.nn.Sequential(
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(kernel_size=2),
            torch.nn.Flatten()
        )

        self.head = torch.nn.Linear(128*7*7, 10)

    def forward(self, x):
        x = self.conv(x)
        self.last_feature_map = x
        x = self.pool(x)
        return self.head(x)

model = CNN().to(device)
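
The input size of the linear head, `128*7*7`, follows from how each layer changes the tensor shape. As a rough check, the sketch below pushes a dummy image through layers with the same hyperparameters and records the shape after each one. Dropout is left out because it does not change shapes, and `layers` and `trace` are names introduced here for illustration.

```python
import torch

# Same layer hyperparameters as the CNN above, listed one by one to trace shapes
layers = [
    torch.nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=0),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2),
    torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
    torch.nn.ReLU(),
    torch.nn.AvgPool2d(kernel_size=2),
    torch.nn.Flatten(),
]

x = torch.zeros(1, 3, 32, 32)  # one dummy CIFAR-10-sized image
trace = []
with torch.no_grad():
    for layer in layers:
        x = layer(x)
        trace.append((type(layer).__name__, tuple(x.shape)))

for name, shape in trace:
    print(f"{name:10s} -> {shape}")
```

The spatial size goes 32 → 28 (5x5 convolution without padding) → 14 (max pooling) → 14 (3x3 convolution with padding 1) → 7 (average pooling), leaving 128 feature maps of size 7x7.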

Just the usual optimizer and loss definition...

In [None]:
optimizer=torch.optim.Adam(model.parameters(), lr=1e-3)
loss=torch.nn.CrossEntropyLoss()

...and the usual code to compute the accuracy.

In [None]:
def get_accuracy(model, dataloader, device):
    model.eval()
    with torch.no_grad():
        correct=0
        for x, y in iter(dataloader):
            x=x.to(device)
            y=y.to(device)
            out=model(x)
            correct+=(torch.argmax(out, axis=1)==y).sum()
        return (correct/len(dataloader.dataset)).item()

Our model reaches a reasonable accuracy relatively quickly. This value is far from the state of the art on CIFAR-10 (which is well above 99%), but it serves as a baseline for the experiments that follow.

In [None]:
epochs=20
for epoch in range(epochs):
    print("Test accuracy: ", get_accuracy(model, testloader, device))
    model.train()
    print("Epoch: ", epoch)
    for x, y in iter(trainloader):
        x=x.to(device)
        y=y.to(device)
        out=model(x)
        l=loss(out, y)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
print("Final accuracy: ", get_accuracy(model, testloader, device))

## Fast Gradient Sign attack

We will implement the simplest form of gradient-based adversarial attack, the Fast Gradient Sign Method (FGSM), introduced by Goodfellow et al. in 2014. Given an input image $x$, the adversarial perturbation is computed as the sign of the gradient of the loss function (in our case, the cross-entropy) with respect to $x$:
$$
\Delta = \text{sign}(\nabla_x L(y,\hat{y}))
$$
The perturbation is then rescaled by a factor $\epsilon$, so that it does not exceed a given threshold in $\ell_\infty$ norm. Common values for $\epsilon$ include $\{1,2,3,...,8\}/255$.

The adversarial image is then obtained as:
$$
x'=x+\epsilon\Delta
$$
After this computation, we clamp the perturbed image to $[0,1]$, the original range of the pixels of the clean image.
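
One way to see why the sign appears, sketched here under a first-order (linear) approximation of the loss around $x$: among all perturbations $\delta$ with $\|\delta\|_\infty\le\epsilon$, the one that increases the linearized loss the most is exactly $\epsilon\Delta$. Here $g=\nabla_x L(y,\hat{y})$ is shorthand introduced for this sketch, and $L(x+\delta)$ denotes the loss evaluated at the perturbed input.
$$
L(x+\delta) \approx L(x) + \delta^\top g, \qquad
\max_{\|\delta\|_\infty\le\epsilon} \delta^\top g = \epsilon\,\|g\|_1, \quad \text{attained at } \delta = \epsilon\,\text{sign}(g) = \epsilon\Delta
$$
Each pixel is pushed by the full budget $\epsilon$ in the direction that increases the loss, which is why a single gradient computation is enough.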

In this lab, given its simplicity, we are implementing FGSM from scratch. However, in standard research practice, attack algorithms can be taken directly from libraries, such as [advertorch](https://advertorch.readthedocs.io/en/latest/index.html) or [foolbox](https://foolbox.readthedocs.io/en/stable/).

In [None]:
def fgsm_attack(model, image, label, epsilon=8/255):
    # Work on a detached copy so the caller's tensor is not modified
    image = image.clone().detach().requires_grad_(True)

    # Forward pass: Get the model's prediction
    output = model(image)
    loss = torch.nn.functional.cross_entropy(output, label)

    # Backward pass: Compute gradients of the loss w.r.t. the image
    model.zero_grad()
    loss.backward()

    # Get the sign of the gradients
    sign_data_grad = image.grad.sign()

    # Create adversarial example (detached: it is a plain input, not part of a graph)
    perturbed_image = image.detach() + epsilon * sign_data_grad
    perturbed_image = torch.clamp(perturbed_image, 0, 1)  # Keep pixel values valid

    # Return the adversarial image and the perturbation actually applied
    return perturbed_image, perturbed_image - image.detach()
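
As a quick sanity check of the recipe, the sketch below runs one FGSM step on a tiny stand-in classifier and random "images", then measures the size of the perturbation. Everything here is hypothetical and only meant to run on its own: `toy_model`, the 3x8x8 input size, and the compact re-implementation of the attack step.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in classifier: flatten + linear layer on 3x8x8 "images"
toy_model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 10))
images = torch.rand(16, 3, 8, 8)          # pixel values in [0, 1)
labels = torch.randint(0, 10, (16,))
epsilon = 4 / 255

# Compact FGSM step (same recipe as above, rewritten so this block runs alone)
x = images.clone().requires_grad_(True)
loss = torch.nn.functional.cross_entropy(toy_model(x), labels)
grad = torch.autograd.grad(loss, x)[0]
adv = torch.clamp(images + epsilon * grad.sign(), 0, 1)

linf = (adv - images).abs().max().item()  # size of the largest pixel change
with torch.no_grad():
    clean_loss = torch.nn.functional.cross_entropy(toy_model(images), labels).item()
    adv_loss = torch.nn.functional.cross_entropy(toy_model(adv), labels).item()
print(linf, clean_loss, adv_loss)
```

No pixel moves by more than $\epsilon$, the adversarial image stays in $[0,1]$, and the loss on the adversarial batch should not be lower than on the clean one.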

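FGSM needs the gradient of the loss with respect to the input, so the attack itself cannot run inside the `torch.no_grad()` context manager. Only the final forward pass on the adversarial images, which needs no gradients, is wrapped in it.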

In [None]:
def get_adversarial_accuracy(model, dataloader, attack, device):
    model.eval()
    correct=0
    for x, y in iter(dataloader):
        x=x.to(device)
        y=y.to(device)
        adv, _ =attack(model, x, y)
        with torch.no_grad():
            out=model(adv)
            correct+=(torch.argmax(out, axis=1)==y).sum()
    return (correct/len(dataloader.dataset)).item()

A very small perturbation, $\epsilon=\frac{4}{255}$, is already enough to disrupt most of the classifier's performance.

In [None]:
from functools import partial

get_adversarial_accuracy(model, testloader, partial(fgsm_attack, epsilon=4/255), device)

We can now visualize a few examples of clean and adversarially attacked images.

In [None]:
x,y=next(iter(testloader))
x=x.to(device)
y=y.to(device)
adv, pert =fgsm_attack(model, x, y, 4/255)
adversarial_y = model(adv).argmax(-1)

In [None]:
image_idx=42

fig, axes = plt.subplots(1, 2, figsize=(10, 5))

axes[0].imshow(x[image_idx].detach().cpu().permute(1, 2, 0))
axes[0].set_title(f"Clean: {testset.classes[y[image_idx]]}")

axes[1].imshow(adv[image_idx].detach().cpu().permute(1, 2, 0))
axes[1].set_title(f"Adversarial: {testset.classes[adversarial_y[image_idx]]}")

plt.show()


## Adversarial defense

Defending against adversarial attacks is complex, and it is still an open problem in DL research. 

A possible approach is adversarial training, in which the network is fine-tuned on both clean and adversarial data. Note that adversarial examples depend on the current network weights, so during adversarial training the attack has to be run again after each optimization step.

Adversarial training can significantly improve robustness, at the cost of a (hopefully small) drop in clean accuracy.
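
The training loop in the next cell can be summarized by the following objective, written as a sketch in the notation of the FGSM section. The symbol $f_\theta$, denoting the network with weights $\theta$, is introduced here for illustration.
$$
\min_\theta \; \mathbb{E}_{(x,y)}\Big[ L\big(y, f_\theta(x)\big) + L\big(y, f_\theta(x')\big) \Big], \qquad
x' = \text{clamp}_{[0,1]}\Big(x+\epsilon\,\text{sign}\big(\nabla_x L(y, f_\theta(x))\big)\Big)
$$
The adversarial image $x'$ is recomputed at every step with the current $\theta$, and it is treated as a fixed input when the loss is differentiated with respect to $\theta$.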

In [None]:
epochs=10
optimizer=torch.optim.Adam(model.parameters(), lr=1e-3)
loss=torch.nn.CrossEntropyLoss()
for epoch in range(epochs):
    # Evaluated at the start of every epoch, i.e. before that epoch's updates
    print("Test accuracy: ", get_accuracy(model, testloader, device))
    print("Adversarial accuracy: ", get_adversarial_accuracy(model, testloader,partial(fgsm_attack, epsilon=4/255), device))
    model.train()
    print("Epoch: ", epoch)
    for x, y in iter(trainloader):
        x=x.to(device)
        y=y.to(device)
        clean_out=model(x)
        l=loss(clean_out, y)
        adv, pert =fgsm_attack(model, x, y, 4/255)
        adv_out=model(adv)
        l+=loss(adv_out, y)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
print("Final clean accuracy: ", get_accuracy(model, testloader, device))
print("Final adversarial accuracy: ", get_adversarial_accuracy(model, testloader,partial(fgsm_attack, epsilon=4/255), device))

# Attention
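
For reference, the layer implemented in the next cell computes single-head scaled dot-product self-attention. The symbols below are introduced here and do not appear in the notebook: $X$ is the input sequence (one row per token), $W_Q, W_K, W_V, W_O$ are the (transposed) weight matrices of `q_proj`, `k_proj`, `v_proj` and `out_proj`, $b_O$ is the output bias, and $d$ is the projection dimension.
$$
Q = XW_Q,\quad K = XW_K,\quad V = XW_V,\qquad
\text{Attention}(X) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V\,W_O + b_O
$$
The softmax is applied row by row, so each token receives a set of non-negative weights over all tokens that sum to 1. The division by $\sqrt{d}$ keeps the dot products from growing with the dimension.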

In [None]:
import torch
import math

class Attention(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.d = output_dim
        self.q_proj = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.k_proj = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.v_proj = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.out_proj = torch.nn.Linear(output_dim, output_dim, bias=True)

    def forward(self, x):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        weights = torch.nn.functional.softmax(torch.einsum('btd,bTd->btT', q, k)/math.sqrt(self.d), dim=-1)
        #weights = torch.nn.functional.softmax(torch.bmm(q, k.permute(0,2,1))/math.sqrt(self.d), dim=-1)
        out = torch.bmm(weights, v)
        return self.out_proj(out)
attn = Attention(128, 128)
x=torch.randn(8, 10, 128)
myattn = attn(x)

y = x.permute(1, 0, 2)
mha2_out = torch.nn.functional.multi_head_attention_forward(
    query=y,
    key=y,
    value=y,
    embed_dim_to_check=128,
    num_heads=1,
    use_separate_proj_weight=True,
    in_proj_weight=None,
    in_proj_bias=None,
    bias_k=None,
    bias_v=None,
    need_weights=False,
    q_proj_weight=attn.q_proj.weight,
    k_proj_weight=attn.k_proj.weight,
    v_proj_weight=attn.v_proj.weight,
    out_proj_bias=attn.out_proj.bias,
    out_proj_weight=attn.out_proj.weight,
    add_zero_attn=False,
    dropout_p=0,
)[0].permute(1, 0, 2)
print(torch.allclose(mha2_out, myattn, atol=1e-6))
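
The core of the forward pass (softmax of scaled dot products, then a weighted sum of values) can also be checked in isolation, without any projection layers. The sketch below compares a manual computation against `torch.nn.functional.scaled_dot_product_attention`, which is assumed to be available (PyTorch 2.0 or later). The tensor sizes are arbitrary choices for illustration.

```python
import math
import torch

torch.manual_seed(0)

# Random queries, keys and values: batch of 2 sequences, 5 tokens, dimension 16
B, T, d = 2, 5, 16
q, k, v = torch.randn(B, T, d), torch.randn(B, T, d), torch.randn(B, T, d)

# Manual scaled dot-product attention, as in the forward pass above
scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(d)  # (B, T, T)
weights = torch.softmax(scores, dim=-1)                  # each row sums to 1
manual = torch.bmm(weights, v)                           # (B, T, d)

# Built-in fused implementation (assumes PyTorch >= 2.0)
builtin = torch.nn.functional.scaled_dot_product_attention(q, k, v)

row_sums = weights.sum(dim=-1)
max_diff = (manual - builtin).abs().max().item()
print(max_diff)
```

The two results should agree up to floating-point error, and every row of `weights` sums to 1.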



In [None]:
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, input_dim, output_dim, num_heads):
        super().__init__()
        self.d = output_dim
        self.h = num_heads
        assert self.d % self.h == 0, "Output dimension must be divisible by the number of heads"
        self.q_proj = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.k_proj = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.v_proj = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.out_proj = torch.nn.Linear(output_dim, output_dim, bias=True)

    def forward(self, x):
        B, T, input_dim = x.shape
        H = self.h
        D = self.d
        q = self.q_proj(x).reshape(B, T, H, D//H)
        k = self.k_proj(x).reshape(B, T, H, D//H)
        v = self.v_proj(x).reshape(B, T, H, D//H)

        weights = torch.einsum('bthd,bThd->bhtT', q, k)/math.sqrt(D//H)
        weights = torch.nn.functional.softmax(weights, dim=-1)
        out = torch.einsum('bhtT,bThd->bthd', weights, v).reshape(B, T, D)
        return self.out_proj(out)
attn = MultiHeadAttention(128, 128, 2)
x=torch.randn(8, 10, 128)
myattn = attn(x)

y = x.permute(1, 0, 2)
mha2_out = torch.nn.functional.multi_head_attention_forward(
    query=y,
    key=y,
    value=y,
    embed_dim_to_check=128,
    num_heads=2,
    use_separate_proj_weight=True,
    in_proj_weight=None,
    in_proj_bias=None,
    bias_k=None,
    bias_v=None,
    need_weights=False,
    q_proj_weight=attn.q_proj.weight,
    k_proj_weight=attn.k_proj.weight,
    v_proj_weight=attn.v_proj.weight,
    out_proj_bias=attn.out_proj.bias,
    out_proj_weight=attn.out_proj.weight,
    add_zero_attn=False,
    dropout_p=0,
)[0].permute(1, 0, 2)
print(myattn.shape)
print(torch.allclose(mha2_out, myattn, atol=1e-7))
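
The reshape and `einsum` calls in `MultiHeadAttention` are a vectorized way of running several independent single-head attentions on slices of the projected vectors and concatenating the results. A minimal sketch of that equivalence, with projections omitted and arbitrary sizes chosen for illustration:

```python
import math
import torch

torch.manual_seed(0)

B, T, D, H = 2, 5, 16, 4
Dh = D // H  # dimension of each head
q, k, v = torch.randn(B, T, D), torch.randn(B, T, D), torch.randn(B, T, D)

# Vectorized version: reshape into heads and use einsum, as in the class above
qh, kh, vh = (t.reshape(B, T, H, Dh) for t in (q, k, v))
w = torch.softmax(torch.einsum('bthd,bThd->bhtT', qh, kh) / math.sqrt(Dh), dim=-1)
vectorized = torch.einsum('bhtT,bThd->bthd', w, vh).reshape(B, T, D)

# Explicit version: run single-head attention on each slice, then concatenate
per_head = []
for h in range(H):
    sl = slice(h * Dh, (h + 1) * Dh)
    scores = torch.bmm(q[..., sl], k[..., sl].transpose(1, 2)) / math.sqrt(Dh)
    per_head.append(torch.bmm(torch.softmax(scores, dim=-1), v[..., sl]))
looped = torch.cat(per_head, dim=-1)

same = torch.allclose(vectorized, looped, atol=1e-5)
print(same)
```

Head `h` only ever sees features `h*Dh` to `(h+1)*Dh - 1`, and the scaling uses the per-head dimension `Dh` rather than the full dimension `D`.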

# Homework (optional)

TODO