<a href="https://colab.research.google.com/github/erodola/DLAI-s2-2023/blob/main/labs/09/CycleGAN_and_Adversarial_Attacks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Deep Learning & Applied AI

We recommend going through the notebook using Google Colaboratory.

# Tutorial 9: CycleGAN and Adversarial Attacks


In this tutorial, we will cover:

- GAN, cGAN and CycleGAN
- Adversarial Attacks

Based on original material by Dr. Luca Moschella, Dr. Antonio Norelli and Dr. Marco Fumero.

Course:

- Website and notebooks will be available at https://github.com/erodola/DLAI-s2-2025/

In [None]:
!pip install wandb
!pip install pytorch-lightning==1.8.0
!pip install pyyaml==5.4.1

In [None]:
# @title import dependencies

from typing import Sequence, List, Dict, Tuple, Optional, Any, Set, Union, Callable, Mapping
import itertools

import dataclasses
from dataclasses import dataclass
from dataclasses import asdict
from pathlib import Path
from pprint import pprint
from urllib.request import urlopen
import random

from PIL import Image
import PIL

import torchvision.utils
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px

import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, Dataset
from torch import nn, optim
import torch.nn.functional as F

import wandb
import pytorch_lightning as pl
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

import torchvision
from torchvision import transforms
from tqdm.notebook import tqdm


In [None]:
# @title reproducibility stuff

import random
np.random.seed(0)
random.seed(0)

torch.cuda.manual_seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True  # Note that this Deterministic mode can have a performance impact
torch.backends.cudnn.benchmark = False

# This prevoius seeds should be redundant.
_ = pl.seed_everything(0)

> 🚜🚜🚜 Since the download we will perform later on is quite slow, it may be convenient to start the download by executing the "Download datasets" cell as soon as possible

# Generative Adversarial Networks

The Generative Adversarial Networks (GANs) are based on a game theoretic scenario.

- The **discriminator** tries its best to discriminate *real* samples drawn from the training data and *fake* samples drawn from the generator.

- The **generator** does its best to trick the discriminator and produce a fake sample, that is wrongly recognized as a real one.

The generator network directly produces samples $\mathbf{x}= g(\mathbf{z}; \mathbf{\theta}^{(g)})$. The discriminator emits a probability value given by $d(\mathbf{x};\mathbf{\theta}^{(d)})$, indicating the probability that $\mathbf{x}$ is a real training example rather than a fake sample drawn from the generator network.

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/10/pics/gan.png)


## Learning formulation

The simplest way to formulate learning in generative adversarial networks is a zero-sum game.

We choose a function $v(\mathbf{\theta}^{(g)}, \mathbf{\theta}^{(d)})$ as the reward for the the discriminator and $-v(\mathbf{\theta}^{(g)}, \mathbf{\theta}^{(d)})$ as the generator reward. They both want to maximise their reward!

This can be phrased mathematically:

$$
g^* = \min_g \max_d v(g, d)
$$

The default choice for $v$ is:

$$
v(\mathbf{\theta}^{(g)}, \mathbf{\theta}^{(d)})
=
\mathbb{E}_{\mathbf{x} \sim p_{\text{data}}} \log d(\mathbf{x})
+
\mathbb{E}_{\mathbf{x} \sim p_{\text{model}}} \log (1 - d(\mathbf{x}))
$$

This drives the discriminator to attempt learning to correctly classify samples as real or fake.
Simultaneously, the generator attempts to fool the classifier into believing its samples are real! Since the discriminator tries to maximise and the generator to minimise.

At convergence, the generator's samples are ingistinguishable from real data and the discriminator outputs $d(\mathbf{x}) = \frac{1}{2}$ everywhere.
This means that the generator is able to produce data that lies in the same distribution of the training data.


In the slides you can find the closed form derivation for the previous formulation.


## The Generator

Let's look with more attention at what the generator is doing.

It takes in input *noise* and produces an image from the desired distribution.
The noise is usually drawn from a gaussian distribution:

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/10/pics/generator.png)


Intuitively, it is learning how to translate samples drawn from a normal distribution to samples drawn from the desired distribution (e.g. the distribution of faces).
It translates one distribution to another.

Why should we limit the generator to translate only the plain normal distribution to other distributions? Can't we condition the starting distribution somehow? Can't we translate arbitrary distributions from one into another one?
Welcome to the Conditional Adversarial Networks (cGAN).



---

References:

- Deep Learning, Goodfellow et al, 2016. Section 20.10.4

Image credits:

- [Thalles Silva](https://www.freecodecamp.org/news/an-intuitive-introduction-to-generative-adversarial-networks-gans-7a2264a81394/)

- [Pankaj Kishore](https://towardsdatascience.com/art-of-generative-adversarial-networks-gan-62e96a21bc35)

## Conditional Adversarial Network

The normal GANs learn a mapping from a random noise vector $z$ to output image $y$, $G: z \to y$. The conditional GANs proposed in (Isola et al.) learns a mapping from an observed image $x$ *and* a random noise vector $z$ to $y$, $G: \{x,  z\} \to y$.

> Interestingly, in (Isola et al.) the noise vector $z$ is provided only in the form of dropout (!), applied on several layers of the generator both at training and test time. Despite the dropout at test time, they observe only minor stochasticity in the output of the net.

For example, this is the high level architecture of a cGAN that maps edges $\to$ photos:

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/10/pics/cGAN.png)

The discriminator $D$ learns to classify between fake and real $\{\text{edge},\text{photo}\}$ tuples. The generator $G$ learns to fool the discriminator. Unlike an uncoditional GAN, **both the generator and discriminator observe the input edge map**!


## Learning formulation

The objective of a conditional GAN can be expressed as:

$$
\mathcal{L}_{cGAN}(G, D)
=
\mathbb{E}_{x, y} \log D(x, y)
+
\mathbb{E}_{x, z} \log (1 - D(x, G(x, z)))
$$

where $G$ tries to minimize this objective against an adversarial $D$ that tries to maximise it:

$$
G^* = \arg \min_G \max_D \mathcal{L}_{cGAN}(G, D)
$$


This technique is very general and is able to tackle many different tasks!

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/10/pics/pix2pix.png)

Take a look at the paper to see more detailed performance.

## The need of paired data

Let's think a bit about the data required to implement this architecture.
Both the generator and discriminator observe the conditioning image. The discriminator observes tuples of conditioning image and real image, otherwise it has no way to understand the correlation between the two images and make a sensible decision.

What does it mean? **To implement this architecture we need paired data**!
i.e. we need tuples of $(\text{conditioning image}, \text{real image})$.

It may be possible to get this data for some tasks. For example, if we want to translate a photo taken during the day to a overnight photo the needed data is possible to acquire. Maybe it is expensive, but it is enough to take the same photo at different hours of the day.

But what if we want to translate our photo, that lies on the distribution of realistic images, to the distribution where the Van Gogh paintings lie? i.e. we want to transform our photo to a Van Gogh painting.
How can we acquire enough paired data? Is it even possible?

It turns out that it is possible to solve this styling task, the previous tasks you just saw and many others... **without paired data**. In the next sections we are going to go deep and implement this technique that exploits unpaired data: **CycleGAN**.

---

References:

- Isola et al. [“Image-to-Image Translation with Conditional Adversarial Networks.”](http://arxiv.org/abs/1611.07004)

## CycleGAN: Unpaired Image-to-Image Translation

CycleGAN [(Zhu et al)](https://arxiv.org/abs/1703.10593) is able to **automatically translate** an image from one set of images into another and viceversa.

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/10/pics/cycleGAN.png)

Moreover, it does so **without the need of paired data**!

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/10/pics/unpaired.png)

As already said, paired training data consists of samples
$\{x_i, y_i\}^N$
where the corrispondence between $x_i$ and $y_i$ exists.
The *unpaired* data consist of a source set $\{x_i\}^N (x_i \in X)$ and a target set $\{y_j\}^N (y_j \in Y)$... there is no information as to which $x_i$ matches which $y_i$!


## Learning formulation

The goal is to learn mapping functions between two domains $X$ and $Y$ given training examples $\{x_i\}^N$ where $x_i \in X$ and $\{y_j\}^N$ where $y_j \in Y$. Let's denote the data distribution as $x \sim p_{data}(x)$  and $y \sim p_{data}(y)$.


CycleGAN includes two mappings $G: X \to Y$ and $F: Y \to X$, moreover, both $G$ and $F$ have an adversarial discriminator $D_X$ and $D_Y$.

The role of these adversarial discriminators is to distinguish generated *fake* images from *real* images: $D_X$ aims to distinguish between images $\{x\}$ and translated images $\{F(y)\}$; $D_Y$ aims to discriminate between $\{y\}$ and $\{G(x)\}$

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/10/pics/cycle.png)

The objective contains two different terms:

- A standard *adversarial loss*, two in total, one for each GAN
- A *cycle consistency loss* to prevent incosistencies between the mappings $G$ and $F$.

### Adversarial losses
The adversarial losses are very similar to the conditional GANs, but:

- The data is unpaired, thus the discriminator doesn't have any image pairs to look at.
- There is no noise vector $z$, not even in the form of dropout.

For the mapping $G: X \to Y$ and associated discriminator $D_Y$ we have the adversarial objective:


$$
\mathcal{L}_{GAN}(G, D_Y, X, Y)
=
\mathbb{E}_{y \sim p_{data}(y)} \log D_Y(y)
+
\mathbb{E}_{x \sim p_{data}(x)} \log (1 - D_Y(G(x))
$$

where again the generator tries to minimize the objective against and adversary $D$ that tries to maximize it: $\min_G \max_{D_Y} \mathcal{L}_{GAN}(G, D_Y, X, Y)$

Equivalently for the mapping $F: Y \to X$ and associated discriminator $D_X$ we have the adversarial objective:

$$
\mathcal{L}_{GAN}(F, D_X, Y, X)
=
\mathbb{E}_{x \sim p_{data}(x)} \log D_X(x)
+
\mathbb{E}_{y \sim p_{data}(y)} \log (1 - D_X(F(y))
$$

with: $\min_F \max_{D_X} \mathcal{L}_{GAN}(F, D_X, Y, X)$


### Cycle consistency

Adversarial losses alone cannot guarantee that the learned function maps a given $x_i$ to a desired output $y_i$. To force the semantically "correct" mapping, it is enforced the cycle consistency.

Intuitively we want that $x \approx F(G(x))$ and $y \approx G(F(y))$.


![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/10/pics/cycle2.png)

In practice this can be enforced using this objective:

$$
\mathcal{L}_{cyc}(G, F) =
\mathbb{E}_{x \sim p_{data}(x)}|| F(G(x)) - x ||_1
+
\mathbb{E}_{y \sim p_{data}(y)} || G(F(y)) - y ||_1
$$


### Identity consistency

In some applications (Zhu et al) found useful to enforce a form of *identity consistency*. Meaning that the mapping $G: X \to Y$ if given $y\in Y$ in input should produce the same $y$. Similarly for the mapping $F$.

This can be enforce with the objective:

$$
\mathcal{L}_{\text{identity}}(G, F) =
\mathbb{E}_{y \sim p_{data}(y)}|| G(y) - y ||_1
+
\mathbb{E}_{x \sim p_{data}(x)} || F(x) - x ||_1
$$


### Full Objective

The full objective is given by:

$$
\mathcal{L}(G, F, D_X, D_Y) =
\mathcal{L}_{GAN}(G, D_Y, X, Y)
+
\mathcal{L}_{GAN}(F, D_X, Y, X)
+
\lambda\mathcal{L}_{cyc}(G, F)
+
\beta\mathcal{L}_{\text{identity}}(G, F)
$$

The goal is to solve:

$$
G^*, F^* =
\arg
\min_{G, F}
\max_{D_X, D_Y}
\mathcal{L}(G, F, D_X, D_Y)
$$

> **QUESTION:** What could happen if we minimize only the $\mathcal{L}_{cyc}$ and $\mathcal{L}_{\text{identity}}$ terms, so removing the $\mathcal{L}_{GAN}$ ones?

### Implementation

This is a brief overview of all the implementation and training tricks in the paper that we are going to implement.

#### 1. Architecture

Regarding the architecture, keep in mind that it is not required to undersand every detail of the generators and discriminators to get the main idea of CycleGAN!

The technique proposed is general and can be applied to a variety of types of generators and discriminators.


#### 2. Loss stability

(Zhu et al.) **replaces all the negative log likelihood objectives with least-squares losses** to improve stability.

In particular for a GAN loss $\mathcal{L}_{GAN}(G, D, X, Y)$ they train:

- $G$ to minimize
  $$\mathbb{E}_{x \sim p_{data}(x)} (D(G(x)) - 1)^2$$

- $D$ to minimize
  $$
  \mathbb{E}_{y \sim p_{data}(y)} (D(y) - 1)^2
  +
  \mathbb{E}_{x \sim p_{data}(x)} D(G(x))^2
  $$


#### 3. Model oscillations

They don't update the discriminator using the images generated by the latest generator. Instead, they keep a history of generated images and randomly swap some images generated by the latest generator with the old ones.
This buffer has length 50.


#### 4. Weights initialization

All the weights are initialized from a Gaussian distribution $\mathcal{N}(0, 0.02)$


#### 5. Convergence speed

They divide the objective by two while optimizing the discriminator, in order to slow down the rate at which $D$ learns relative to the rate of $G$.

#### 6. Learning rate decay

They use a learning rate scheduler. It keeps the same learning rate for the first 100 epochs, then, linearly decay the rate to zero over the next 100 epochs.


---

References:

- Zhu et al. ["Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks."](https://arxiv.org/abs/1703.10593)


Implementation inspired by:

- [PyTorch GAN](https://github.com/eriklindernoren/PyTorch-GAN/). Several components of the architecture comes from here.

- Official [paper repository](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)


## PyTorch Lightning


Since the code for this example will be fairly long, we are going to use [PyTorch Lightning](https://www.pytorchlightning.ai/) to keep the code clean and organized.

Lightning is a way to organize your PyTorch code to **decouple the science code from the engineering**. It's more of a PyTorch style-guide than a framework.

In Lightning, you organize your code into 3 distinct categories:

- Research code (goes in the LightningModule).
- Engineering code (you delete, and is handled by the Trainer).
- Non-essential research code (logging, etc... this goes in Callbacks).

Here's an example of how to refactor your research code into a [LightningModule](https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction.html):

![](https://github.com/PyTorchLightning/pytorch-lightning/raw/master/docs/source/_static/images/general/pl_quick_start_full_compressed.gif)

---

References:

- Description from the [Lightning docs](https://pytorch-lightning.readthedocs.io/en/latest/)

## Weights and Biases

We will use Weights and Biases to log metrics and sample images in this section.

As you will see, it is nicely integrated in PyTorch Lightning and requires minimal effort to set it up. Remember that even if Weights and Biases is free for academic use, there are many other loggers you can use (even integrated in [Lightning](https://pytorch-lightning.readthedocs.io/en/latest/api_references.html#loggers))

In [None]:
!wandb login

## Download datasets

The official repository of (Zhu et al) provides scripts to download the datasets they used in their experiments.

We are going to use those scripts to download the `maps` and `ukiyoe2photo` dataset. We will not import any python code from the repository.


All the dataset are composed of (at least) four folders:

- `trainA`: train images from distribution A
- `trainB`: train images from distribution B
- `testA`: test images from distribution A
- `testB`: test images from distribution B

In [None]:
# Slow download... :)
!git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix.git
!cd pytorch-CycleGAN-and-pix2pix; ./datasets/download_cyclegan_dataset.sh ukiyoe2photo
!cd pytorch-CycleGAN-and-pix2pix; ./datasets/download_cyclegan_dataset.sh maps

> **How to read code**
>
> Some code cells will be quite long, don't be scared.
> There won't be line-by-line comments or extended explanations in markdown, since the code is self-documenting.
>
> The code is documented and typed.
> 1. Read the docstrings and look at the types to get a rough idea of *what* a function does.
> 2. Look at the variable names that are usually meaningful enough to understand *what* the code is doing at a finer level.
> 3. Only after that, read the whole code to understand the details of *how* it does that.
>
> *Curiosity*: probably we will have [statically checked tensor shapes](https://github.com/pytorch/pytorch/issues/26889) (with generic types) in PyTorch! That will be a huge improvement in code readability :]

## Dataset definition

Let's define the `Dataset` to read the data we just downloaded.

Keep in mind that this is an `unpaired` task, thus we do not have $(a_i, b_i)$ couples. The $i$-sample will be made of $x_i$ and a random $y$.

To ease the visualization and see how the image generation evolves with time, we add a parameter `fixed_pairs` that if set to `True` returns always the same couple.

In [None]:
class DatasetUnpaired(Dataset):

    def __init__(self,
                 folderA: Path,
                 folderB: Path,
                 transform: Optional[Callable] = None,
                 fixed_pairs: bool = False,
        ) -> None:
        """
        Dataset to handle unpaired images, i.e. the number of images in folderA
        and in folderB may be different.

        :param folderA: path to the folder that contains the A images
        :param folderB: path to the folder that contains the B images
        :param tranform: tranform to apply to the images
        """
        super().__init__()
        self.folderA: Path = Path(folderA)
        self.folderB: Path = Path(folderB)

        if not (folderA.is_dir() and folderB.is_dir()):
            raise RuntimeError(f"The folders are not valid!\n\t- Folder A: {folderA}\n\t- Folder B: {folderB}")

        self.filesA: List[Path] = list(sorted(folderA.rglob('*.jpg')))
        self.filesB: List[Path] = list(sorted(folderB.rglob('*.jpg')))

        if not self.filesA:
            raise RuntimeError("Empty image lists for folderA!")

        if not self.filesB:
            raise RuntimeError("Empty image lists for folderB!")

        self.filesA_num: int = len(self.filesA)
        self.filesB_num: int = len(self.filesB)

        self.transform: Optional[Callable] = transform
        self.fixed_pairs: bool = fixed_pairs

    def __len__(self) -> int:
        """
        Since it is unpaired, it is not well defined.
        We will use the maximum number of images between folderA and folderB

        :returns: maximum number between #imagesA and #imagesB
        """
        return max(self.filesA_num, self.filesB_num)

    def pil_loader(self, path: Path) -> PIL.Image:
        """ PIL loader implementation from the Pytorch's ImageDataset class
        https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder

        :param path: the path to an image
        :returns: an image converted into RGB format
        """
        # open path as file to avoid ResourceWarning
        # (https://github.com/python-pillow/Pillow/issues/835)
        with path.open('rb') as f:
            img = PIL.Image.open(f)
            return img.convert('RGB')

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
        """
        Return a random sample imageA-imageB

        :param index: index of the sample (not relevant)
        :returns: a dictionary containing:
                    - A: the imageA
                    - B: the imageB
                    - pathA: the path to the imageA
                    - pathB: the path to the imageB
        """

        # Enforce a valid index for `filesA`
        fileA = self.filesA[index % self.filesA_num]

        if self.fixed_pairs:
            # When e.g. testing use reproducible samples
            fileB = self.filesB[index % self.filesB_num]

        else:
            # When training, get a random image from filesB
            fileB = self.filesB[random.randint(0, self.filesB_num - 1)]

        imageA = self.pil_loader(fileA)

        imageB = self.pil_loader(fileB)

        if self.transform is not None:
            imageA = self.transform(imageA)
            imageB = self.transform(imageB)

        return {
            'A': imageA,
            'A_path': str(fileA),
            'B': imageB,
            'B_path': str(fileB),
        }

## Data selection

The official CycleGAN repository provides many different datasets, to showcase different applications.

In this tutorial you can choose to use the `ukiyo-e` or the `maps` dataset.

*The choice you make now will persist for the rest of the notebook!*

In [None]:
#@title visualization utility functions

def plot_images(images,
                images_per_row: int,
                border: float = 3.,
                pad_value: float = 1,
                title = 'Some images'):
    """
    Visualize many images in a nice grid

    :param images: the images to visualize
    :param images_per_row: number of images per row
    :param border: the border size of the grid
    :param pad_value: border color
    :param title: the title of plot
    """
    # Matplolib plot, much faster for static images
    # First visualise the original images
    plt.figure(figsize = (17, 17))
    plt.imshow(torchvision.utils.make_grid(images,images_per_row,border,pad_value=pad_value).permute(1, 2, 0))
    plt.title(title)
    plt.axis('off')

In [None]:
# @title select the dataset to use  { run: "auto" }
dataset_name = "maps"  # @param ["maps", "ukiyoe2photo"]

trainA = Path(f"pytorch-CycleGAN-and-pix2pix/datasets/{dataset_name}/trainA")
trainB = Path(f"pytorch-CycleGAN-and-pix2pix/datasets/{dataset_name}/trainB")
testA = Path(f"pytorch-CycleGAN-and-pix2pix/datasets/{dataset_name}/testA")
testB = Path(f"pytorch-CycleGAN-and-pix2pix/datasets/{dataset_name}/testB")


visualize_batch_idx = 3  # @param {type:"slider", min:1, max:50, step:1}


import torchvision.utils
import matplotlib.pyplot as plt


loader = DataLoader(
    DatasetUnpaired(
        testA,
        testB,
        transform=transforms.Compose([transforms.ToTensor()]),
        fixed_pairs=True,
    ),
    batch_size=10,
    shuffle=False,
)

load_iter = iter(loader)

# ugly :]
for _ in range(visualize_batch_idx):
    batch = next(load_iter)


plot_images(batch["A"], images_per_row=5, border=30, pad_value=1, title="A images")
plot_images(batch["B"], images_per_row=5, border=30, pad_value=1, title="B images")

## Hyperparameters

Let's define the hyperparameters that we are going to use

In [None]:
# The dataclass are fancy classes to hold data

# Working with dataclasses is particularly comfortable
# since you can specify types and get autocomplete/suggestion
# of the available hyperparameters

@dataclass
class Config:
    #dataset_name: str = "ukiyoe2photo"  # name of the dataset

    # Run we did:
    # map: 200 epochs, 100 decay
    # ukiyo: 40 epochs, 20 decay (due to time contraints)
    # They took ~7 hours each on a 2080ti
    n_epochs: int = 200  # number of epochs of training
    decay_epoch: int = 100  # epoch from which to start lr decay

    img_height: int = 128  # size of image height # default 256x256
    img_width: int = 128  # size of image width

    batch_size: int = 1  # size of the batches
    lr: float = 0.0002  # adam: learning rate
    b1: float = 0.5  # adam: decay of first order momentum of gradient
    b2: float = 0.999  # adam: decay of first order momentum of gradient

    channels: int = 3  # number of image channels
    n_residual_blocks: int = 6  # number of residual blocks in generator # original 9
    lambda_cyc: float = 10.0  # cycle loss weight
    lambda_id: float = 5.0  # identity loss weight

    n_cpu: int = 8  # number of cpu threads to use for the dataloaders

    log_images: int = min(25, 100)  # number of images to log


cfg = Config()
pprint(asdict(cfg))

In [None]:
# Hyperparameters are just attributes of an object
cfg.batch_size

## Model sub-components

### Residual block

This is a generic residual block. We already saw residual networks in the Convolutional Neural Networks lecture.

In [None]:
class ResidualBlock(nn.Module):
    def __init__(self, in_features: int) -> None:
        """
        A generic residual block.

        The input is transformed by a block,
        then, the transformation is summed up to the original input

        :param in_features: number of input features
        """
        super().__init__()

        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: tensor with shape [batch, channels, w, h]

        :returns: tensor with shape [batch, channels, w, h]
        """
        return x + self.block(x)


In [None]:
# Example of inner working
res = ResidualBlock(in_features=3)

batch = torch.rand(10, 3, 128, 128)
batch.shape

In [None]:
res(batch).shape

### Generator

The image-conditioned generator we are going to use.
At high level it performs an image downsampling and then an upsampling.

In [None]:
class GeneratorResNet(nn.Module):
    def __init__(self, input_shape: Sequence[int], num_residual_blocks: int) -> None:
        """
        Image-conditioned image generator.

        It takes in input an image and produces another image.

        :param input_shape: shape of expected input image
        :param num_residual_blocks: number of residual blocks to use
        """
        super().__init__()

        channels = input_shape[0]

        # Initial convolution block
        out_features = 64
        model = [
            nn.ReflectionPad2d(channels),
            nn.Conv2d(channels, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

        # Downsampling
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Residual blocks
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(out_features)]

        # Upsampling
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output layer
        model += [
            nn.ReflectionPad2d(channels),
            nn.Conv2d(out_features, channels, 7),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: tensor with shape [batch, channels, w, h]

        :returns: tensor with shape [batch, channels, w, h]
        """
        return self.model(x)


In [None]:
# Example of inner working
g = GeneratorResNet(input_shape=[3, 50, 50], num_residual_blocks=6)

In [None]:
batch = torch.rand(2, 3, 128, 128)
batch.shape

In [None]:
g(batch).shape

### Discriminator

The discriminator we are going to use.

It tries to predict $1$ for *real* images and $0$ for *fake* images.
In practice it does not try to predict a single $1$ or $0$ but a $3\times 3$ matrix of ones or zeros.

In [None]:
class Discriminator(nn.Module):
    def __init__(self, input_shape: Sequence[int]) -> None:
        """
        Discriminator that tries to infer if an image is:
        - fake, i.e. it has been generated by a generator
        - real, i.e. it has not been generated by a generator

        :param input_shape: shape of the expected image
        """
        super().__init__()

        channels, height, width = input_shape

        # Calculate output shape of image discriminator (PatchGAN)
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        def discriminator_block(
            in_filters: int, out_filters: int, normalize: bool = True
        ) -> Sequence[nn.Module]:
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(channels, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1)
        )

    def forward(self, img: torch.tensor) -> torch.Tensor:
        """
        :param img: tensor with shape [batch, channels, w, h]

        :returns: tensor with shape [batch, 1, 3, 3]
        """
        return self.model(img)


In [None]:
# Example of inner working
d = Discriminator(input_shape=[3, 50, 50])

In [None]:
batch = torch.rand(2, 3, 50, 50)
batch.shape

In [None]:
# Single output channel!
# The aim of the discriminator is to predict all ones if the image is real
# and all zeros if the the image is fake.
d(batch).shape

### Image Buffer

This is the training trick they use to increase the robustness and reduce the model oscillation.

They don't update the discriminator using the images generated by the latest generator. Instead, they keep a history of generated images and randomly swap some images generated by the latest generator with the old ones.
This buffer has length 50.



In [None]:
class ReplayBuffer:
    def __init__(self, max_size: int = 50) -> None:
        """
        Image buffer to increase the robustness of the generator.

        Once it is full, i.e. it contains max_size images, each image in a given batch
        is swapped with probability p=0.5 with another one contained in the buffer.

        """
        assert (
            max_size > 0
        ), "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data: torch.Tensor) -> torch.Tensor:
        """
        Fill the buffer with each element in data.
        If the buffer is full, with p=0.5 swap each element in data with
        another in the buffer.

        :param data: tensor with shape [batch, ...]

        :returns: tensor with shape [batch, ...]
        """
        to_return = []

        for i in range(data.shape[0]):
            element = data[[i], ...]

            if len(self.data) < self.max_size:
                self.data.append(element)

            elif random.uniform(0, 1) > 0.5:
                i = random.randint(0, self.max_size - 1)
                self.data[i], element = element, self.data[i]

            to_return.append(element)

        return torch.cat(to_return)


In [None]:
# Example of inner working
b = ReplayBuffer(max_size=5)
batch_s = 0
batch_size = 5
batch_e = batch_s + batch_size

In [None]:
# Execute multiple times this cell!
a = torch.arange(batch_s, batch_e)[..., None]
batch_s = batch_e
batch_e = batch_s + batch_size
batch = b.push_and_pop(a)
print(f"Input batch:\n{a}\n\nOutput batch:\n{batch}\n\nHidden buffer state:\n{b.data}")

### LR Scheduler


They use a learning rate scheduler. It keeps the same learning rate for the first 100 epochs, then, linearly decay the rate to zero over the next 100 epochs.

This is a parametric implementation of this idea, where it is possible to specify the total number of epochs and when to start the linear decay.

In [None]:
class LambdaLR:
    def __init__(self, n_epochs: int, decay_start_epoch: int) -> None:
        """
        Linearly decay the leraning rate to 0, starting from `decay_start_epoch`
        to the final epoch.

        In practice

        :param n_epochs: total number of epochs
        :param decay_start_epoch: epoch in which the learning rate starts to decay
        """
        assert (
            n_epochs - decay_start_epoch
        ) > 0, "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch: int) -> float:
        """
        One step of lr decay:
        - if `epoch < self.decay_start_epoch` it doesn't change the learning rate.
        - Otherwise, it linearly decay the lr to reach zero

        :param epoch: current epoch
        :returns: learning rate multiplicative factor
        """
        return 1.0 - max(0, epoch - self.decay_start_epoch) / (
            self.n_epochs - self.decay_start_epoch
        )

In [None]:
# Example of the inner workings
n_epochs = 10
decay_from = 3
lr = LambdaLR(n_epochs, decay_from)
for i in range(n_epochs + 1):
    if i == decay_from:
        print("\tStarting to decay")
    print(lr.step(i))
    if i == n_epochs:
        print("\tEnd of the decay")


## CycleGAN lightning module

This is the main model. It encapsulates all the logic into a clear and well defined framework, as defined by Lightning.

The main methods of every Lightning model are:

- `train_dataloader` and `val_dataloader`: defines the dataloader for the train and test set

- `configure_optimizers`: configure optimizers and schedulers. For each couple (optimizer, scheduler) there will be one call to `training_step` with the appropriate `optimizer_idx` to identify the optimizer.

- `training_step`: defines what happens in a single training step

- `validation_step`: defines what happens in a single validation step

- `validation_epoch_end`: receive in input an aggregation of all the output of the `validation_step`. It is useful to compute metrics and log examples.

In [None]:
class CycleGAN(pl.LightningModule):
    def __init__(
        self,
        hparams: Union[Dict, Config],
        trainA_folder: Path,
        trainB_folder: Path,
        testA_folder: Path,
        testB_folder: Path,
    ) -> None:
        """
        The CycleGAN model.

        :param hparams: dictionary that contains all the hyperparameters
        :param trainA_folder: Path to the folder that contains the trainA images
        :param trainB_folder: Path to the folder that contains the trainB images
        :param testA_folder: Path to the folder that contains the testA images
        :param testB_folder: Path to the folder that contains the testB images
        """
        super().__init__()
        self.save_hyperparameters(asdict(hparams) if not isinstance(hparams, Mapping) else hparams)

        # Dataset paths
        self.trainA_folder = trainA_folder
        self.trainB_folder = trainB_folder
        self.testA_folder = testA_folder
        self.testB_folder = testB_folder

        # Expected image shape
        self.input_shape = (self.hparams["channels"], self.hparams["img_height"], self.hparams["img_width"])

        # Generators A->B and B->A
        self.G_AB = GeneratorResNet(self.input_shape, self.hparams["n_residual_blocks"])
        self.G_BA = GeneratorResNet(self.input_shape, self.hparams["n_residual_blocks"])

        # Discriminators
        self.D_A = Discriminator(self.input_shape)
        self.D_B = Discriminator(self.input_shape)

        # Initialize weights
        # https://pytorch.org/docs/stable/nn.html?highlight=nn%20module%20apply#torch.nn.Module.apply
        self.G_AB.apply(self.weights_init_normal)
        self.G_BA.apply(self.weights_init_normal)
        self.D_A.apply(self.weights_init_normal)
        self.D_B.apply(self.weights_init_normal)

        # Image Normalizations
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(int(self.hparams["img_height"] * 1.12), Image.BICUBIC),
                transforms.RandomCrop((self.hparams["img_height"], self.hparams["img_width"])),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

        # Image Normalization for the validation: remove source of randomness
        self.val_image_transforms = transforms.Compose(
            [
                transforms.Resize(int(self.hparams["img_height"] * 1.12), Image.BICUBIC),
                transforms.CenterCrop((self.hparams["img_height"], self.hparams["img_width"])),
                # transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

        # Image buffers
        self.fake_A_buffer = ReplayBuffer()
        self.fake_B_buffer = ReplayBuffer()

        # Forward pass cache to avoid re-doing some computation
        self.fake_A = None
        self.fake_B = None

        # Losses
        self.mse = torch.nn.MSELoss()
        self.l1 = torch.nn.L1Loss()

        # Ignore this.
        # It avoids wandb logging when lighting does a sanity check on the validation
        self.is_sanity = True

    def forward(self, x: torch.Tensor, a_to_b: bool) -> torch.Tensor:
        """
        Forward pass for this model.

        This is not used while training!

        :param x: input of the forward pass with shape [batch, channel, w, h]
        :param a_to_b: if True uses the mapping A->B, otherwise uses B->A

        :returns: the translated image with shape [batch, channel, w, h]
        """
        if a_to_b:
            return self.G_AB(x)
        else:
            return self.G_BA(x)

    def weights_init_normal(self, m: nn.Module) -> None:
        """
        Initialize the weights with a gaussian N(0, 0.02) as described in the paper.

        :param m: the module that contains the weights to initialise
        """
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
            if hasattr(m, "bias") and m.bias is not None:
                torch.nn.init.constant_(m.bias.data, 0.0)
        elif classname.find("BatchNorm2d") != -1:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)

    def train_dataloader(self) -> DataLoader:
        """ Create the train set DataLoader

        :returns: the train set DataLoader
        """
        train_loader = DataLoader(
            DatasetUnpaired(
                self.trainA_folder, self.trainB_folder, transform=self.image_transforms
            ),
            batch_size=self.hparams["batch_size"],
            shuffle=True,
            num_workers=2,
            pin_memory=True,
        )
        return train_loader

    def val_dataloader(self, custom_batch_size: Optional[int] = None) -> DataLoader:
        """ Create the validation set DataLoader.

        It is deterministic.
        It does not shuffle and does not use random transformation on each image.

        :returns: the validation set DataLoader
        """
        test_loader = DataLoader(
            DatasetUnpaired(
                self.testA_folder,
                self.testB_folder,
                transform=self.val_image_transforms,
                fixed_pairs=True,
            ),
            batch_size=custom_batch_size if custom_batch_size is not None else 32,
            shuffle=False,
            num_workers=2,
            pin_memory=True,
        )
        return test_loader

    def configure_optimizers(
        self,
    ) -> Tuple[Sequence[optim.Optimizer], Sequence[Dict[str, Any]]]:
        """ Instantiate the optimizers and schedulers.

        We have three optimizers (and relative schedulers):

        - Optimizer with index 0: optimizes the parameters of both generators
        - Optimizer with index 1: optimizes the parameters of D_A
        - Optimizer with index 2: optimizes the parameters of D_B

        Each scheduler implements a linear decay to 0 after `cfg.hparams["decay_epoch"]`

        :returns: the optimizers and relative schedulers (look at the return type!)
        """
        # Optimizers
        optimizer_G = torch.optim.Adam(
            itertools.chain(self.G_AB.parameters(), self.G_BA.parameters()),
            lr=self.hparams["lr"],
            betas=(self.hparams["b1"], self.hparams["b2"]),
        )
        optimizer_D_A = torch.optim.Adam(
            self.D_A.parameters(), lr=self.hparams["lr"], betas=(self.hparams["b1"], self.hparams["b2"])
        )
        optimizer_D_B = torch.optim.Adam(
            self.D_B.parameters(), lr=self.hparams["lr"], betas=(self.hparams["b1"], self.hparams["b2"])
        )

        # Schedulers for each optimizers
        lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            optimizer_G,
            lr_lambda=LambdaLR(self.hparams["n_epochs"], self.hparams["decay_epoch"]).step,
        )
        lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
            optimizer_D_A,
            lr_lambda=LambdaLR(self.hparams["n_epochs"], self.hparams["decay_epoch"]).step,
        )
        lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
            optimizer_D_B,
            lr_lambda=LambdaLR(self.hparams["n_epochs"], self.hparams["decay_epoch"]).step,
        )

        return (
            [optimizer_G, optimizer_D_A, optimizer_D_B],
            [
                {"scheduler": lr_scheduler_G, "interval": "epoch", "frequency": 1},
                {"scheduler": lr_scheduler_D_A, "interval": "epoch", "frequency": 1},
                {"scheduler": lr_scheduler_D_B, "interval": "epoch", "frequency": 1},
            ],
        )

    def criterion_GAN(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """ The loss criterion for GAN losses

        :param x: tensor with any shape
        :param y: tensor with any shape

        :returns: the mse between x and y
        """
        return self.mse(x, y)

    def criterion_cycle(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """ The loss criterion for Cycle losses

        :param x: tensor with any shape
        :param y: tensor with any shape

        :returns: the l1 between x and y
        """
        return self.l1(x, y)

    def criterion_identity(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """ The loss criterion for Identity losses

        :param x: tensor with any shape
        :param y: tensor with any shape

        :returns: the l1 between x and y
        """
        return self.l1(x, y)

    def identity_loss(self, image: torch.Tensor, generator: nn.Module) -> torch.Tensor:
        """ Implements the identity loss for the given generator

        :param generator: a generator module that maps X -> Y
        :param image: an image in the Y distribution with shape [batch, channel, w, h]

        :returns: the identity loss for these (generator, image)
        """
        return self.criterion_identity(generator(image), image)

    def gan_loss(
        self,
        generator: nn.Module,
        discriminator: nn.Module,
        image: torch.Tensor,
        expected_label: torch.Tensor,
    ) -> torch.Tensor:
        """ Implements the GAN loss for the given generator and discriminator

        :param image: the input image with shape [batch, channle, w, h]
        :param generator: the generator module to use to translate the image from X -> Y
        :param discriminator: the discriminator that tries to distinguish fake and real images
        :expected_label: tensor with shape compatible to the discriminator's output.
                         It is full of ones when training the generator. We feed a fake
                         image to the discriminator and we expect to get ones
                         (for the discriminator this is an error!)

        :returns: the GAN loss for these (image, generator, discriminator)
        """
        fake_image = generator(image)
        predicted_label = discriminator(fake_image)
        loss_GAN = self.criterion_GAN(predicted_label, expected_label)
        return loss_GAN, fake_image

    def cycle_loss(
        self,
        fake_image: torch.Tensor,
        reverse_generator: nn.Module,
        original_image: torch.Tensor,
    ) -> torch.Tensor:
        """ Implements the cycle consistency loss

        It takes in input a fake image, to avoid repeated computation,
        thus it only needs the reverse mapping that produced that fake image.

        :param fake_image: a image produced by a mapping X->Y with shape [batch, channel, w, h]
        :param reverse_generator: the generator module that maps Y->X
        :param original_image: the original image in X with shape [batch, channel, w, h]
                               to compare with the reconstructed fake image

        :returns: the cycle consistency loss for this (fake_image, reverse_generator, original_image)
        """
        recovered_image = reverse_generator(fake_image)
        return self.criterion_cycle(recovered_image, original_image)

    def discriminator_loss(
        self,
        discriminator: nn.Module,
        proposed_image: torch.Tensor,
        expected_label: torch.Tensor,
    ) -> torch.Tensor:
        """ Implements the loss used to train the discriminator

        :param discriminator: the discriminator model to train
        :param proposed_image: the fake or real image proposed with shape [batch, channel, w, h]
        :param expected_label: tensor with shape compatible to the discriminator's output,
                               full of zeros if the proposed image is fake
                               full of ones if the proposed image is real

        :returns: the discriminator loss for this (discriminator, proposed_image, expected_label)
        """
        predicted_label = discriminator(proposed_image)
        return self.criterion_GAN(predicted_label, expected_label)

    def training_step(
        self, batch: Dict[str, torch.Tensor], batch_nb: int, optimizer_idx: int
    ) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """ Implements a single training step

        The parameter `optimizer_idx` identifies with optimizer "called" this training step,
        this we can change the behaviour of the training depending on which optimizers
        is currently performing the optimization

        :param batch: current training batch
        :param batch_nb: the index of the current batch
        :param optimizer_idx: the index of the optimizer in use, see the function `configure_optimizers`

        :returns: the total loss for the current training step, together with other information for the
                  logging and possibly the progress bar
        """
        # Unpack the batch
        real_A = batch["A"]
        real_B = batch["B"]

        # Adversarial ground truths
        valid = torch.ones(
            (real_A.size(0), *self.D_A.output_shape), device=real_A.device
        )
        fake = torch.zeros(
            (real_A.size(0), *self.D_A.output_shape), device=real_A.device
        )

        # The first optimizer is for the two generators!
        if optimizer_idx == 0:

            # Identity A and B loss
            loss_id_A = self.identity_loss(real_A, self.G_BA)
            loss_id_B = self.identity_loss(real_B, self.G_AB)
            loss_identity = self.hparams["lambda_id"] * ((loss_id_A + loss_id_B) / 2)

            # GAN A loss and GAN B loss
            loss_GAN_AB, self.fake_B = self.gan_loss(self.G_AB, self.D_B, real_A, valid)
            loss_GAN_BA, self.fake_A = self.gan_loss(self.G_BA, self.D_A, real_B, valid)
            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss: A -> B -> A  and  B -> A -> B
            loss_cycle_A = self.cycle_loss(self.fake_B, self.G_BA, real_A)
            loss_cycle_B = self.cycle_loss(self.fake_A, self.G_AB, real_B)
            loss_cycle = self.hparams["lambda_cyc"] * ((loss_cycle_A + loss_cycle_B) / 2)

            # Total loss
            loss_G = loss_GAN + loss_cycle + loss_identity

            self.log_dict({
                    "total_loss_generators": loss_G,
                    "loss_GAN": loss_GAN,
                    "loss_cycle": loss_cycle,
                    "loss_identity": loss_identity,
                }
            )
            return loss_G

        # The second optimizer is to train the D_A discriminator
        elif optimizer_idx == 1:

            # Real loss
            loss_real = self.discriminator_loss(self.D_A, real_A, valid)

            # Fake loss (on batch of previously generated samples)
            loss_fake = self.discriminator_loss(
                self.D_A, self.fake_A_buffer.push_and_pop(self.fake_A).detach(), fake
            )

            # Total loss
            loss_D_A = (loss_real + loss_fake) / 2
            self.log_dict({
                    "total_D_A": loss_D_A,
                    "loss_D_A_real": loss_real,
                    "loss_D_A_fake": loss_fake,
                }
            )
            return loss_D_A


        # The second optimizer is to train the D_B discriminator
        elif optimizer_idx == 2:

            # Real loss
            loss_real = self.discriminator_loss(self.D_B, real_B, valid)

            # Fake loss (on batch of previously generated samples)
            loss_fake = self.discriminator_loss(
                self.D_B, self.fake_B_buffer.push_and_pop(self.fake_B).detach(), fake
            )

            # Total loss
            loss_D_B = (loss_real + loss_fake) / 2

            self.log_dict({
                    "total_D_B": loss_D_B,
                    "loss_D_B_real": loss_real,
                    "loss_D_B_fake": loss_fake,
                }
            )
            return loss_D_B

        raise RuntimeError("There is an error in the optimizers configuration!")

    def get_image_examples(
        self, real: torch.Tensor, fake: torch.Tensor
    ) -> Sequence[wandb.Image]:
        """
        Given real and "fake" translated images, produce a nice coupled images to log

        :param real: the real images with shape [batch, channel, w, h]
        :param fake: the fake image with shape [batch, channel, w, h]

        :returns: a sequence of wandb.Image to log and visualize the performance
        """
        example_images = []
        for i in range(real.shape[0]):
            couple = torchvision.utils.make_grid(
                [real[i], fake[i]],
                nrow=2,
                normalize=True,
                scale_each=True,
                pad_value=1,
                padding=4,
            )
            example_images.append(
                wandb.Image(couple.permute(1, 2, 0).detach().cpu().numpy(), mode="RGB")
            )
        return example_images

    def validation_step(
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> Dict[str, Union[torch.Tensor,Sequence[wandb.Image]]]:
        """ Implements a single validation step

        In each validation step some translation examples are produced and a
        validation loss that uses the cycle consistency is computed

        :param batch: the current validation batch
        :param batch_idx: the index of the current validation batch

        :returns: the loss and example images
        """

        real_B = batch["B"]
        fake_A = self.G_BA(real_B)
        images_BA = self.get_image_examples(real_B, fake_A)

        real_A = batch["A"]
        fake_B = self.G_AB(real_A)
        images_AB = self.get_image_examples(real_A, fake_B)

        ####

        real_A = batch["A"]
        real_B = batch["B"]

        fake_B = self.G_AB(real_A)
        fake_A = self.G_BA(real_B)

        # Cycle loss A -> B -> A
        recov_A = self.G_BA(fake_B)
        loss_cycle_A = self.criterion_cycle(recov_A, real_A)

        # Cycle loss B -> A -> B
        recov_B = self.G_AB(fake_A)
        loss_cycle_B = self.criterion_cycle(recov_B, real_B)

        # Cycle loss aggregation
        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
        loss_cycle = self.hparams["lambda_cyc"] * loss_cycle

        # Total loss
        loss_G = loss_cycle

        return {"val_loss": loss_G, "images_BA": images_BA, "images_AB": images_AB}

    def validation_epoch_end(
        self, outputs: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, Union[torch.Tensor, Dict[str, Union[torch.Tensor,Sequence[wandb.Image]]]]]:
        """ Implements the behaviouir at the end of a validation epoch

        Currently it gathers all the produced examples and log them to wandb,
        limiting the logged examples to `hparams["log_images"]`.

        Then computes the mean of the losses and returns it.
        Updates the progress bar label with this loss.

        :param outputs: a sequence that aggregates all the outputs of the validation steps

        :returns: the aggregated validation loss and information to update the progress bar
        """
        images_AB = []
        images_BA = []

        for x in outputs:
            images_AB.extend(x["images_AB"])
            images_BA.extend(x["images_BA"])

        images_AB = images_AB[: self.hparams["log_images"]]
        images_BA = images_BA[: self.hparams["log_images"]]

        if not self.is_sanity:  # ignore if it not a real validation epoch. The first one is not.
            print(f"Logged {len(images_AB)} images for each category.")

            self.logger.experiment.log(
                {f"images_AB": images_AB, f"images_BA": images_BA,},
                step=self.global_step,
            )
        self.is_sanity = False

        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        self.log_dict({"val_loss": avg_loss})
        return {"val_loss": avg_loss}


## Training

Since we followed the `Lightning Module` template, we can exploit the `Trainer` to implement the train-loop.

It is extremely easy to do so. After the training, we log the trained model on wandb.


> ⏱⏱⏱ *long training time* ⏱⏱⏱
>
> Pay attention that, depending on the hyperparameters that you choose, the training may take a *lot* of time.
>
> We did `200` epochs for the `map` dataset, and only `40` for the `ukiyoe2photo`. We reduced the image size to `128x128` (from the original `256x256`) to shorten the training time.
>
> Each run took $\approx7$ hours to complete on a 2080 ti GPU.
> Keep in mind that if you use the GPUs provided by Colab the training times are about $\times 5$ slower.
>
> **In the next section, you can load pre-trained models and look at the statitics of the runs on the W&B dashboard.**

In [None]:
# ⏱⏱⏱ slow executing cell ⏱⏱⏱
# Suggested to use pre-trained models!

# Instantiate the model
gan_model = CycleGAN(hparams=cfg,
                     trainA_folder=trainA,
                     trainB_folder=trainB,
                     testA_folder=testA,
                     testB_folder=testB)

# Define the logger
# https://www.wandb.com/articles/pytorch-lightning-with-weights-biases.
wandb_logger = WandbLogger(project="CycleGAN Tutorial 2021", log_model=True)

## Currently it does not log the model weights, there is a bug in wandb and/or lightning.
wandb_logger.experiment.watch(gan_model, log='all', log_freq=100)

# Define the trainer
trainer = pl.Trainer(logger=wandb_logger,
                     max_epochs=cfg.n_epochs,
                     gpus=1,
                     limit_val_batches=.2,
                     val_check_interval=0.25)

# Start the training
trainer.fit(gan_model)

# Log the trained model
trainer.save_checkpoint('model.pth')
wandb.save('model.pth')

## Pre-trained models

We have pretrained two models for the two datasets in these tutorial.

The project is public on Weights and Biases, thus you can inspect the hyperparameters, statistics and example images in each one of the two run that produced these two models. In the next sections we'll see in detail what this model is able to produce!


- [W&B run](https://app.wandb.ai/lucmos/CycleGAN%20Tutorial/runs/2vce8p04/overview?workspace=user-lucmos) for the model trained on the ukiyoe2photo dataset

- [W&B run](https://app.wandb.ai/lucmos/CycleGAN%20Tutorial/runs/1p5mk2ia/overview?workspace=user-lucmos) for the model trained on the maps dataset



In [None]:
!wget https://api.wandb.ai/files/lucmos/CycleGAN%20Tutorial/1p5mk2ia/model.pth -O map-model.pth
!wget https://api.wandb.ai/files/lucmos/CycleGAN%20Tutorial/2vce8p04/model2.pth -O ukiyoe-model.pth

In [None]:
!ls

You can choose which model to use in the following sections. There are three possibilities:
- The model you **just trained**
- The **pre-trained `ukiyoe2photo` model**
- The **pre-trained `map` model**

In [None]:
#@title Load model weights  { run: "auto" }

load_model = "Load model trained on Maps" #@param ["Use your own model just trained", "Load model trained on Maps", "Load model trained on Ukiyo-e"]

if load_model == "Use your own model just trained":
    loaded_gan_model = gan_model
    print('Continuing to use your own model.')
elif load_model == "Load model trained on Maps":
    loaded_gan_model = CycleGAN.load_from_checkpoint('map-model.pth',
                                                    # ugly workaround to load old lightning checkpoint :[
                                                    hparams=torch.load('map-model.pth')['hparams'],
                                                    trainA_folder=trainA,
                                                    trainB_folder=trainB,
                                                    testA_folder=testA,
                                                    testB_folder=testB)
    print('Maps model loaded.')
elif load_model == "Load model trained on Ukiyo-e":
    loaded_gan_model = CycleGAN.load_from_checkpoint('ukiyoe-model.pth',
                                                    # ugly workaround to load old lightning checkpoint :[
                                                    hparams=torch.load('ukiyoe-model.pth')['hparams'],
                                                    trainA_folder=trainA,
                                                    trainB_folder=trainB,
                                                    testA_folder=testA,
                                                    testB_folder=testB)
    print('Ukiyo-e model loaded.')

loaded_gan_model = loaded_gan_model.cuda()

print(f"\nRemember to select the right dataset at the beginning!\nDataset currently selected: **{dataset_name}**")

## Playground: validation exploration

Let's see how this CycleGAN behave!

In the next cell you can select:

- The input image to feed the generator: $a \in A$ or $b \in B$
- The generator direction: $A \to B$ or $B \to A$

In [None]:
# Precompute batches

loader = loaded_gan_model.val_dataloader(custom_batch_size=10)
batches = [x for x in loader]

In [None]:
#@title Playgrond: explore validation set  { run: "auto" }

visualize_batch_idx = 34  #@param {type:"slider", min:1, max:100, step:1}


batch_extractor = {
    'A': lambda x: x['A'],
    'B': lambda x: x['B'],
}
generator_input =  'A' #@param ["A", "B"]
batch = batches[visualize_batch_idx]
gen_input = batch_extractor[generator_input](batch)


generators = {
    'A to B': loaded_gan_model.G_AB,
    'B to A': loaded_gan_model.G_BA,
}
generator_direction =  'A to B' #@param ["A to B", "B to A"]
generator = generators[generator_direction]


plot_images(gen_input, images_per_row=5, border = 5, pad_value=1, title='Generator input batch')
plot_images(generator(gen_input.cuda()).cpu().detach(), images_per_row=5, border=5, pad_value=1, title='Generator output batch')

## **EXERCISE**
>
> - What happens if you select a generator $X \to Y$ and use $y \in Y$ as the input? Why?
> - Do the two mapping directions $A \to B$ and $B \to A$ have the same complexity? When using the `ukiyoe2photo` dataset which direction is harder? Why?
> - Which hyperparameter would you change to improve the fidelity of the translated image with respect to the original image?

## Performance examples

All the examples below consist of an image pair:
- On the left there is the input image of the generator.
- On the right the translated left image.


You can look up this recap and the one in the next section from the W&B dashboard:

- [Ukiyo-e run](https://app.wandb.ai/lucmos/CycleGAN%20Tutorial/runs/2vce8p04?workspace=user-lucmos)
- [Map run](https://app.wandb.ai/lucmos/CycleGAN%20Tutorial/runs/1p5mk2ia?workspace=user-lucmos)

### Photo $\to$ Ukiyo-e

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/10/pics/photo2ukiyoe.png)

### Ukiyo-e $\to$ Photo

That tree looks good!

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/10/pics/ukiyoe2photo.png)

### Satellite view $\to$ Map

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/10/pics/sat2map.png)

### Map $\to$ Satellite view

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/10/pics/map2sat.png)





## Performance examples over time

Here there are some grid visualization, from the W&B dashboard, that let us visualise how the translation performance got better in function of the number of training steps.

If we pay attention to some images, we can see the model oscillations: a typical behaviour of GANs.

### Photo $\to$ Ukiyo-e

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/10/pics/photo2ukiyoe_time.png)

### Ukiyo-e $\to$ Photo

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/10/pics/ukiyoe2photo_time.png)

### Satellite view $\to$ Map

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/10/pics/sat2map_time.png)

### Map $\to$ Satellite view

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/10/pics/map2sat_time.png)



## Playground: online image translation

In the next cell you can translate any RGB image on the web, just insert the url of the image you want to translate in the `custom_image_url`.

In [None]:
#@title Styling online images  { run: "auto" }

urls = {
        'lake': 'https://media-cdn.tripadvisor.com/media/photo-s/01/4b/8c/20/hopfen-am-see.jpg',
        'sea': 'https://media-cdn.tripadvisor.com/media/photo-s/01/53/4d/5d/see-photo-album-for-more.jpg',
        'department': 'https://web.uniroma1.it/i3s/sites/default/files/pictures/salariaok.jpg',
        'department ...artistic?': 'https://www.di.uniroma1.it/sites/default/files/pictures/salaria.jpg',
        'colosseo': 'https://i.pinimg.com/originals/88/db/b4/88dbb44ed72207fcd8a1e9a26cbfc58d.jpg',
        'colosseo2': 'https://d9k3q4j9.stackpathcdn.com/wp-content/uploads/2016/10/Colosseo-laptop_1040_529-815x500.jpeg',
        'map': 'https://agenziauscite.openstreetmap.it/img/colosseo_nokia_maps.png',
        'turing': 'https://www.focus.it/site_stored/imgs/0003/005/h_00528110.630x360.jpg',
        'piramide': 'https://i.guim.co.uk/img/media/e3d9827f235ac40064f15d7df25024aec60500cb/0_134_5616_3370/master/5616.jpg?width=1200&height=1200&quality=85&auto=format&fit=crop&s=56f9da8e992f2558c4709614daf82a69',
        'vertical forest': 'https://i.imgur.com/mYUXjEg.png'
}



example_image = "colosseo" #@param ["department", "department ...artistic?", "lake", "sea", "colosseo", "colosseo2", "map", 'turing', 'piramide', 'vertical forest']
image_url = "" #@param {type:"string"}
url = image_url if image_url else urls[example_image]

generators = {
    'A to B': loaded_gan_model.G_AB,
    'B to A': loaded_gan_model.G_BA,
}

generator_direction =  'A to B' #@param ["A to B", "B to A"]
generator = generators[generator_direction]

force_image_resize = False #@param {type:"boolean"}


def styling_from_url(url, gen, image_transforms):
    img =  Image.open(urlopen(url))

    rimg = image_transforms(img)[None, ...].cuda()
    rimg_u = gen(rimg)

    imgs = torchvision.utils.make_grid(
        [rimg.squeeze(), rimg_u.squeeze()],
        nrow=2,
        normalize=True,
        pad_value=1,
        ).permute(1, 2, 0).detach().cpu()
    return imgs

img_transforms = transforms.Compose(
            [
                transforms.Resize([cfg.img_height, cfg.img_width], Image.BICUBIC),
                loaded_gan_model.val_image_transforms
            ]
        )


img = styling_from_url(url, generator, img_transforms if force_image_resize else loaded_gan_model.val_image_transforms)
plt.imshow(img)
plt.axis('off')
plt.title('Generator input                 Generator output')
plt.show()

## **EXERCISE**
> Implement another application of the CycleGAN besides the `maps` and `ukiyoe2photo` we saw in this section.
>
> To reduce the long training times you may:
> - Try to reduce the float precision of the model weights/use the TPUs. The Trainer class is able to do this automagically [through a parameter](https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/mnist-tpu-training.html).
> - Try to re-use one of the two pretrained models provided in this notebook and do **transfer learning** on the new dataset, e.g. freezing some layers and fine-tuning the others (success is not guaranteed!).
>
> You can use the pre-made datasets from the [official repository](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix), e.g. to translate horses in zebras
>
> ![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/10/pics/horse2zebra.gif)
>
> If you feel brave you can use any two dataset of images with different but intuitively correlated distributions you can find online.
>
> *Bonus question*: why does the grass change colour in the previous gif?

# Adversarial examples

We have heard many times the [pompous claim](https://www.forbes.com/sites/michaelthomsen/2015/02/19/microsofts-deep-learning-project-outperforms-humans-in-image-recognition/#3a6934a2740b) of deep neural networks being better than humans in the task of recognizing object in images. If such a statement undoubtely does its job in pointing out to the laymen the huge leap forward made by neural networks in this and a multitude of other tasks, the deep learning practitioner should know its [extent](http://karpathy.github.io/2012/10/22/state-of-computer-vision/).

Today we will see spectacular failures of state-of-the-art deep neural networks right in object recongition; imperceptible but targeted changes to correctly classified examples lead to surprising misclassifications from a human point of view, we call them *adversarial examples*.

<img src="https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/10/pics/bus.jpeg" width="400">

### [Supervised learning: powerful feature learning yet still fragile](https://www.qualcomm.com/news/onq/2020/05/13/far-ai-can-see-what-we-still-need-build-human-level-intelligence?fbclid=IwAR12fZ-hh7LghpK_GxkR2c_XiFiMTsJ1dZSYSxSWwS6jLIl2Ia8wk68Wnrk)

Apart from the obvious security concerns that adversarial examples raise - consider even for a moment the application of autonomous driving - this vulnerability casts many doubts on the degree of generalization reached by DNNs. After seeing the image above, we have also probably updated our judgement about the *intelligence* needed to solve the task of object recognition in images.

Indeed this is genuine progress, we need to understand better *what is* intelligence before making an artificial intelligence.

## The fast gradient sign attack

Today we will experiment with the Fast Gradient Sign Attack (FGSM) on a large ResNet50 pretrained on ImageNet. This attack is described by Goodfellow et al. in the very readable paper [*Explaining and Harnessing Adversarial Examples*](https://arxiv.org/abs/1412.6572), upon which this tutorial is based.



### The linear explanation of adversarial examples

As in the paper, we start explaining the existence of adversarial examples in high dimensional linear models.

Let $x$ be a correctly classified example and $\tilde{x} = x + \eta$ its adversarial, obtained adding a small perturbation $\eta$. How can we build such a $\eta$ and how big should it be to lead to a misclassification?

The output of a linear model for the input $x$ can be written as $w^\top x$, so for the adversarial example:

$$w^\top \tilde{x} = w^\top x + w^ \top \eta$$

Respect to the original input, the adversarial perturbation causes the output to grow by $w^\top \eta$. We can maximize this increase by taking $\eta = \epsilon \, \text{sign}(w)$, where $\epsilon$ controls the magnitude of the perturbation, remember we want it to be small.

>**EXERCISE**: A typical constraint to keep the perturbation small is $\|\eta \|_\infty < k$.
Can you imagine why the infinite norm is a reasonable bound to a perturbation? Remember that $\|\eta \|_\infty = \max_{i}((\eta_1, ... , \eta_n))$. Answer in the paper, section 3.

Let's make some considerations on the magnitude of the perturbation when $\eta = \epsilon \, \text{sign}(w)$.

$w^\top \eta$ will be proportional to $\epsilon m n$, where $m$ is the average absolute value of the weight vector $w$ and $n$ is the dimensionaliy of the input. Notice that $\|\eta \|_\infty$ does not growth with the dimensionality of the problem, so the change in the output caused by the perturbation grows linearly with $n$, i.e. many infinitesimal changes to the input add up to a huge change in the output in a high-dimensional problem.

Notice also how the average magnitude of the weights $m$ is involved, a model with small weights is less vulnerable than one with large weights.

>**EXERCISE:** Do we have studied methods which promote actively small weights? Which was the rationale behind such methods and how it relates to this new vulnerability concept?

### Linear perturbation of non-linear models

I know what are you thinking right now...

>*Wait! All these arguments are simple and convincing but in the end do not apply to DNNs, since DNNs are highly non-linear models! Am I wrong?*

Actually DNNs are definetely non-linear respect to their parameters, but are quite linear respect to their inputs. More precisely, DNNs are piecewise linear, with RELU activations, and these linear intervals are much bigger than we expect.

The piecewise linearity of a DNN using RELU as activation function should not be so surprising. Nevertheless this apply also to more non-linear models such as networks using sigmoids as activation, since these are designed to spend most of their time in the linear regime of the sigmoid, where gradients does not go to zero, for instance using batch normalization. This linearity is a key-ingredient for an efficient gradient-based optimization.

Are DNNs so much linear to be vulnerable to such kind of attack?

We will experiment it in a minute. Let $f_\theta$ be our DNN with parameters $\theta$, if its output $y$ was linear respect to the input $x$ then $y = f_\theta(x) \sim w^\top x$ for some $w(\theta)$. Let $J(\theta, x, y)$ be the cost used to train the neural network, which will be in a form:

$$J(\theta, x, y) = \|f_\theta(x) - y\| \sim \| w^\top x - y\|$$

now if we take the gradient respect to the input $x$:

$$\nabla_x J(\theta, x, y) \sim \nabla_x(\| w^\top x - y\|) = w$$

so we can attack our DNN using the perturbation defined above $\eta = \epsilon \, \text{sign}(w)$ taking:

$$\eta = \epsilon \, \text{sign}(\nabla_x J(\theta, x, y))$$

We refer to this as the **fast gradient sign attack**, note that such a gradient can be computed very efficiently using backpropagation.

This kind of attack will work only if DNNs are sufficiently linear, let's see!

## Experimenting with adversarial examples on Imagenette


### Goliath
Let's start by loading a state-of-the-art model in object classification.

Ladies and gentlemen, directly from the pretrained models on torchvision...

**ResNet50**

Capable of a quite remarkable [76.15%](https://pytorch.org/vision/stable/models.html) top-1 accuracy on the ImageNet classification task (1000 classes).


In [None]:
model = torchvision.models.resnet50(pretrained=True)
model.eval()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

### Imagenette, a convenient subset of ImageNet

The full ImageNet dataset weighs approximately 300 GB.

A very small fraction of it would be sufficient for our experiment, so we will download [Imagenette](https://github.com/fastai/imagenette), a subset of 10 easily classified classes from Imagenet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute).

Our model takes in input 224x224 images, so the version resized to 320 pixel is enough for us.

In [None]:
!wget 'https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz'

Let's unpack it and load in a Pytorch Dataset transformed just as the pretrained model wants.

In [None]:
!tar -xzf '/content/imagenette2-320.tgz'

In [None]:
test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),  # this transform perform always the same crop in the center of the image, we do not want to augment the validation dataset
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

testset = torchvision.datasets.ImageFolder(root='/content/imagenette2-320/val', transform=test_transform)

A dictionary containing the human readable labels of ImageNet will turn useful.

In [None]:
!wget 'https://raw.githubusercontent.com/deep-learning-with-pytorch/dlwpt-code/master/data/p1ch2/imagenet_classes.txt'

In [None]:
class_names = {}
names_class = {}
label = 0
with open('/content/imagenet_classes.txt') as f:
    for line in f:
        class_string = line.split(',')[0].replace('\n', '')
        class_names[label] = class_string
        names_class[class_string] = label
        label += 1

In [None]:
sub_class_names = ['tench', 'English springer', 'cassette player', 'chain saw', 'church', 'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute']

Actually we do not need all the images in Imagenette for our experiment. Let's take only 50 samples evenly separated in index to grasp every class. The ```Subset``` class of PyTorch will make it very easy.


In [None]:
print('test set size: {}'.format(len(testset)))
indices = list(range(0, len(testset), len(testset) // 50 ))
sub_testset = torch.utils.data.Subset(testset, indices)
print('subtest set size: {}'.format(len(sub_testset)))

test_loader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=True, num_workers=4)

As always, let's visualize some samples to get an idea of the classification task.

Finally we work with something more than a toy-model capable to process only postage-stamp sized images.

We are ready for the fall of the giant.

In [None]:
# @title Visualize samples function

def visualize_samples(inputs, title=None):
    """
    Visualization of transformed samples, a standard call:
        inputs, classes = next(iter(dataloaders['train']))
        visualize_samples(inputs)
    Arguments:
    batch_of_samples -- a batch from the dataloader; a PyTorch tensor of shape (batch_size, 3, 224, 224)

    Return:
    None (A nice plot)
    """

    # Make a grid from batch
    inp = torchvision.utils.make_grid(inputs)

    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)  # plotly accepts the colour information both in the 0-1 range and in the 0-255 range
    fig = px.imshow(inp, title=title)
    fig.show()




In [None]:
# Get a batch of training data
inputs, classes = next(iter(test_loader))

visualize_samples(inputs, title=f'Ground truth: {[sub_class_names[x] for x in classes]}')

### The fall of the giant

Let's craft our slingshot, the function defining the FGSM attack.

$$\tilde{x} = x + \epsilon \, \text{sign}(\nabla_x J(\theta, x, y))$$



In [None]:
def fgsm_attack(image, epsilon, data_grad):

    perturbed_image = image + epsilon * data_grad.sign()

    return perturbed_image

Now it's time to wrap a convenient test function. We will run our attack tracking the accuracy of the model on the adversarial dataset over several $\epsilon$.

For each $\epsilon$ we will save five miscalssified images to look later at their appearence, we expect them to be almost indistinguishable by human eye from the original ones.




In [None]:
def test( model, device, test_loader, epsilon ):

    # Accuracy counter
    correct = 0
    old_target_item = 0
    adv_examples = []

    # Loop over all examples in test set
    for data, target in tqdm(test_loader):

        # The target of the reduced dataset should be mapped to the one of ImageNet
        target.data = torch.tensor([names_class[sub_class_names[target.item()]]])
        data, target = data.to(device), target.to(device)

        # Differently fromt training, we require the gradient respect to the input
        data.requires_grad = True

        # We collect the prediction on the original sample
        output = model(data)
        init_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability

        # If the initial prediction is wrong, dont bother attacking, just move on
        if init_pred.item() != target.item():
            continue

        # We will use negative log likelihood since this pretrained model comes equipped with a final LogSoftmax layer
        loss = F.nll_loss(output, target)

        model.zero_grad()
        loss.backward()

        # We have to collect the gradient respect to the input
        data_grad = data.grad.data

        # We get the perturbed image from the FGSM function
        perturbed_data = fgsm_attack(data, epsilon, data_grad)

        # We collect the prediction on the adversarial example
        output = model(perturbed_data)

        # We want to track the accuracy and keep some perturbed images
        final_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
        if final_pred.item() == target.item():
            correct += 1
            # Special case for saving 0 epsilon examples
            if (epsilon == 0) and (len(adv_examples) < 5) and final_pred.item() != old_target_item:
                adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                adv_examples.append( (init_pred.item(), final_pred.item(), adv_ex) )
                old_target_item = final_pred.item()
        else:
            # Save some adv examples for visualization later
            if len(adv_examples) < 5:
                adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                adv_examples.append( (init_pred.item(), final_pred.item(), adv_ex) )

    # Calculate final accuracy for this epsilon
    final_acc = correct/float(len(test_loader))
    print("Epsilon: {}\tTest Accuracy = {} / {} = {}".format(epsilon, correct, len(test_loader), final_acc))

    # Return the accuracy and an adversarial example
    return final_acc, adv_examples

Time to attack!

Let's define a bunch of $\epsilon$, we will track the accuracy of the model on our dataset for each one.

> **EXERCISE:** How do you expect the trend of accuracy vs $\epsilon$?

We will include the $\epsilon=0$ case which represents the original test accuracy,




In [None]:
epsilons = [0, 0.0005, 0.0013, 0.002, 0.004, 0.006, 0.008, 0.01, 0.1]

# in this example we will compute the adversarial example of one sample at a time, so we redefine a testloader woth batch_size=1
test_loader = torch.utils.data.DataLoader(sub_testset, batch_size=1,
                                         shuffle=False, num_workers=4)

In [None]:
accuracies = []
examples = []

# Run test for each epsilon
for eps in epsilons:
    acc, ex = test(model, device, test_loader, eps)
    accuracies.append(acc)
    examples.append(ex)

### Results

Let's take a look at the accuracy trend respect to $\epsilon$, as expected we register a degradation of the performance as $\epsilon$ increase.

The trend is almost linear in a finite range of $\epsilon$, then we register a slowdown in the approach to zero accuracy.

>**EXERCISE**: What do you expect to happen for negative values of $\epsilon$? Try it, how do you explain this behaviour?

>**EXERCISE**: Are you able to reach a 0 accuracy? Which is the minimum $\epsilon$ needed for such a result? Have you noticed something unexpected, how do you explain this behaviour?




In [None]:
# @title Accuracy vs epsilon plot

plt.figure(figsize=(20,5))
plt.plot(epsilons, accuracies, "*-")
plt.yticks(np.arange(0, 1.1, step=0.1))
plt.xticks(np.arange(0, 0.105, step=0.005))
plt.title("Accuracy vs Epsilon")
plt.xlabel("Epsilon")
plt.ylabel("Accuracy")
plt.show()

It is now the moment to look at some perturbed images!

As expected the perturbations are almost imperceptible, humans have no problem at all in identifying such images, we have a chance to spot something only for the higher $\epsilon$s, where the accuracy of the model is already below 30%.

The giant has fallen.



In [None]:
# @title A nice plot of some misclassified perturbed images for each epsilon
cnt = 0
plt.figure(figsize=(20,40))
for i in range(len(epsilons)):
    for j in range(len(examples[i])):
        cnt += 1
        plt.subplot(len(epsilons),len(examples[0]),cnt)
        plt.xticks([], [])
        plt.yticks([], [])
        if j == 0:
            plt.ylabel("Eps: {}".format(epsilons[i]), fontsize=14)
        orig,adv,ex = examples[i][j]

        ex = ex.transpose((1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        ex = std * ex + mean
        ex = np.clip(ex, 0, 1)
        ex = ex.transpose((2, 0, 1))

        plt.title("{} -> {}".format(class_names[orig], class_names[adv]))
        plt.imshow(np.moveaxis(ex, 0, -1))
plt.tight_layout()
plt.show()

>**EXERCISE:** Visualize the adversarial perturbation for some images, i.e. $\tilde{x} - x$

>**EXERCISE**: A nice feature of the FGSM attack is its reduced cost, we can perform it very efficiently through backpropagation. Nevertheless in this notebook we have processed each input separately, can you arrange the code to work with ```batch_size``` > 1 and be even more efficient?

*tutorial on adversarial examples adapted from [this](https://pytorch.org/tutorials/beginner/fgsm_tutorial.html) PyTorch tutorial*
