# Tutorial: Transformers, Vision Transformers and applications

**Filled notebook:**
[![View on Github](https://img.shields.io/static/v1.svg?logo=github&label=Repo&message=View%20On%20Github&color=lightgrey)](https://github.com/metrics-lab/transformer-tutorial/blob/main/tutorial/transformers_tutorial.ipynb)
[![Open In Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/metrics-lab/transformer-tutorial/blob/main/tutorial/transformers_tutorial_aml.ipynb)  
**Pre-trained models:**
[![View files on Github](https://img.shields.io/static/v1.svg?logo=github&label=Repo&message=View%20On%20Github&color=lightgrey)](https://github.com/github/metrics-lab/transformer-tutorial/blob/main/saved_models/)

**Author:** Simon Dahan

**Contact:** Please contact me at simon.dahan@kcl.ac.uk if you have questions on the tutorial or want to discuss some concepts further.


In this tutorial, we will discuss one of the major deep learning architectures of recent years, the Transformer, and uncover how it can be used in different applications for Natural Language Processing (NLP), Computer Vision (CV) and medical imaging. This tutorial is greatly inspired by the [*UVA Deep Learning Notebooks*](https://uvadlc-notebooks.readthedocs.io/en/latest/), particularly [*Tutorial 6*](https://colab.research.google.com/github/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.ipynb#scrollTo=1hkNROGHXvaz) and [*Tutorial 15*](https://colab.research.google.com/github/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial15/Vision_Transformer.ipynb#scrollTo=DbUKvP9NXy-H), which are implemented with PyTorch Lightning. Here, we have adapted the code to plain PyTorch, with a few small twists for the case of medical imaging.





# Tutorial overview

<details>
    <summary><b>Part 1 - The Transformer Architecture (NLP)</b></summary>
    <p><a href="#background">1.1. Background</a></p>
    <p><a href="#attention">1.2. What is Attention?</a></p>
    <p><a href="#attention_basics">1.3. Learning Keys, Queries and Values</a></p>
    <p><a href="#attention_op">1.4. Scaled Dot Product Attention</a></p>
    <p><a href="#multi head attention">1.5. Multi-Head Attention</a></p>
    <p><a href="#transformer encoder">1.6. Transformer Encoder</a></p>
    <p><a href="#positional encoding">1.7. Positional Encoding</a></p>
    <p><a href="#transformer model">1.8. Transformer Model</a></p>
</details>
<br>

<details>
    <summary><b>Part 2 - The Vision Transformer Architecture (CV)</b></summary>
    <p><a href="#vision transformer">2.1. Vision Transformers</a></p>
    <p><a href="#experiments_part_2 transformer">2.2. Experiments: Image classification</a></p>
    <p><a href="#visualising attention">2.3. Visualising Attention Maps</a></p>
    
</details>
<br>
<details>
    <summary><b>Part 3 - Application 1: Vision Transformers in Medical Imaging</b></summary>
    <p><a href="#background">1.1. Background</a></p>
</details>
<br>
<details>
    <summary><b>Part 4 - Application 2: Surface Vision Transformers</b></summary>
    <p><a href="#background">1.1. Background</a></p>
</details>


# Import

In [None]:
## Standard libraries
import os
import numpy as np
import random
import math
import json
from functools import partial

## Imports for plotting
import matplotlib.pyplot as plt
plt.set_cmap('cividis')
%matplotlib inline
import matplotlib_inline.backend_inline  # replaces the deprecated IPython.display.set_matplotlib_formats
matplotlib_inline.backend_inline.set_matplotlib_formats('svg', 'pdf') # For export
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()


## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = "./data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "../saved_models/"

os.makedirs(DATASET_PATH, exist_ok=True)
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

<a id="transformer"></a>
# 1. The Transformer Architecture (NLP)

The Transformer architecture was introduced in [Attention Is All You Need](https://arxiv.org/abs/1706.03762) by Vaswani et al. in 2017 in the context of machine translation and other Natural Language Processing (NLP) tasks. Transformer models have revolutionised the NLP field, surpassing the popular RNN and LSTM architectures through a *self-attention mechanism* which supports the modelling of long-range context. Nowadays, the Transformer architecture is the backbone of many popular Large Language Models such as the *GPT* models [A. Radford et al 2019](https://insightcivic.s3.us-east-1.amazonaws.com/language-models.pdf) or *Mixtral* [A.Q. Jiang et al 2024](https://arxiv.org/pdf/2401.04088.pdf). More recently, Transformers have shown potential as a domain-agnostic architecture, allowing their adaptation to natural image, graph and surface domains, as we will discuss later in this tutorial.

<a id="#background"></a>
## 1.1. Some background and references


In the first part of this notebook, we will implement the Transformer architecture by hand. As the architecture is so popular, there already exists a PyTorch module `nn.Transformer` ([documentation](https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html)) and a [tutorial](https://pytorch.org/tutorials/beginner/transformer_tutorial.html) on how to use it for next token prediction. However, we will implement part of it ourselves here, to get down to the smallest details.

There are of course many more tutorials out there about attention and Transformers. Below, we list a few that are worth exploring if you are interested in the topic and want yet another perspective after this one:

* [Transformer: A Novel Neural Network Architecture for Language Understanding (Jakob Uszkoreit, 2017)](https://ai.googleblog.com/2017/08/transformer-novel-neural-network.html) - The original Google blog post about the Transformer paper, focusing on the application in machine translation.
* ⭐ [The Illustrated Transformer (Jay Alammar, 2018)](http://jalammar.github.io/illustrated-transformer/) - A very popular and great blog post intuitively explaining the Transformer architecture with many nice visualizations. The focus is on NLP.
* ⭐ [Attention? Attention! (Lilian Weng, 2018)](https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html) - A nice blog post summarizing attention mechanisms in many domains including vision.
* [Illustrated: Self-Attention (Raimi Karim, 2019)](https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a) - A nice visualization of the steps of self-attention. Recommended going through if the explanation below is too abstract for you.
* [The Transformer family (Lilian Weng, 2020)](https://lilianweng.github.io/lil-log/2020/04/07/the-transformer-family.html) - A very detailed blog post reviewing more variants of Transformers besides the original one.

<a id="#attention"></a>
## 1.2 What is Attention?

The **self-attention mechanism** was introduced prior to the Transformer architecture in various parallel works, notably in [*A Structured Self-attentive Sentence Embedding*, Z. Lin et al 2017](https://arxiv.org/pdf/1703.03130.pdf). There are therefore a number of different definitions; the one we will use here is: _the attention mechanism describes a weighted average of (sequence) elements, with weights dynamically computed based on an input query and the elements' keys_.

**So what exactly does this mean?** Let's take an example. Imagine that we are interested in understanding the meaning of an English sentence. The self-attention mechanism will assess how much the meaning of each word depends on all the others. For instance, in the figure below, the self-attention mechanism allows the model to *attend* to the words *The* and *Animal* in order to understand the meaning of the word *it*.

<center width="100%" style="padding:25px"><img src="https://github.com/metrics-lab/transformer-tutorial/blob/main/tutorial/attention_example_1.png?raw=1" height="600px" width="600px"></center>


In practice, the goal is to update the value of each word by taking a weighted average of all other words in the sentence (or wider prose). This process is implemented dynamically, allowing the model to learn many different levels of meaning. This is achieved by breaking the self-attention mechanism down into four component parts:

* **Query**: The query is a feature vector that describes the word/token that we compare against all others, i.e. what we might want to pay attention to.
* **Keys**: For each input element (word/token) we also have a key vector. This is what we compare the query against. The keys should be designed such that we can identify the words/tokens we want to pay attention to based on the query.
* **Values**: For each input element, we also have a value vector. This feature vector is the one we want to average over.
* **Score function**: To rate which elements we want to pay attention to, we need to specify a score function $f_{attn}$. The score function takes the query and a key as input, and outputs an estimate of similarity for each query-key pair. It is usually implemented by a dot product (see Section 1.4).

The scores are then normalised and passed through a softmax to output self-attention weights. This will assign a higher contribution to value vectors whose key is most similar to the query. Visually, we can show the attention over a sequence of words as follows:

<center width="100%" style="padding:25px"><img src="https://github.com/metrics-lab/transformer-tutorial/blob/main/tutorial/attention_example_2.svg?raw=1" width="750px"></center>

For every word, we have one key and one value vector. The query is compared to all keys with a score function (in this case the dot product) to determine the weights. Finally, the value vectors of all words are averaged using the attention weights.
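To make the weighted average concrete, here is a minimal sketch in plain Python. The query, key and value vectors are made-up toy numbers (they are not taken from the figure), and `softmax` and `dot` are small helper functions written for this illustration only.

```python
import math

def softmax(scores):
    # Subtract the maximum before exponentiating for numerical stability
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    # Dot product of two equally sized vectors
    return sum(x * y for x, y in zip(a, b))

# Toy (hypothetical) vectors for a three-word sequence
query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]

scores = [dot(query, k) for k in keys]   # score function: similarity of the query to each key
weights = softmax(scores)                # attention weights, which sum to 1
# Weighted average of the value vectors, one output feature at a time
output = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(2)]

print("weights:", weights)
print("output:", output)
```

The key most similar to the query receives the largest weight, so its value vector contributes most to the output.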

<a id="#attention_basics"></a>
## 1.3 Learning Keys, Queries and Values
The keys, queries and values are not given; rather, they are learnt from the data by tuning weight matrices: $W_{1...h}^{Q}\in\mathbb{R}^{D \times d_k}$, $W_{1...h}^{K}\in\mathbb{R}^{D \times d_k}$, $W_{1...h}^{V}\in\mathbb{R}^{D \times d_v}$ s.t. for input $X\in \mathbb{R}^{T \times D}$, $Q=XW^Q$, $K=XW^K$ and $V=XW^V$.

Here $T$ is the sequence length (e.g. number of words in a sentence, number of patches in an image), $D$ is the length of each input token, and $d_k$ and $d_v$ are the output lengths of the queries/keys and values respectively. In practice, it is standard to keep the token length the same throughout the network s.t. $d_k=d_v=D$; going forward we will therefore use $D$ to represent all token/feature vector lengths throughout the network.

Note that training learnable weights for $Q$, $K$ and $V$ is vital to allow the network to model different relationships between words/tokens, in order to learn a hierarchy of 'meanings' for each sequence (see slides).

The final matrices therefore have shape: $Q\in\mathbb{R}^{T\times D}$, $K\in\mathbb{R}^{T\times D}$ and $V\in\mathbb{R}^{T\times D}$.
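As a quick shape check, the projections can be sketched with plain matrix products. This is only an illustration: the sizes `T = 4` and `D = 8` are arbitrary, and the randomly initialised `W_q`, `W_k`, `W_v` stand in for matrices that would normally be learnt during training.

```python
import torch

torch.manual_seed(0)
T, D = 4, 8                      # illustrative sequence length and token dimension
X = torch.randn(T, D)            # input tokens, one row per token

# Randomly initialised stand-ins for the learnt matrices W^Q, W^K, W^V (with d_k = d_v = D)
W_q = torch.randn(D, D)
W_k = torch.randn(D, D)
W_v = torch.randn(D, D)

# Q = X W^Q, K = X W^K, V = X W^V
Q, K, V = X @ W_q, X @ W_k, X @ W_v
print(Q.shape, K.shape, V.shape)
```

Each of the three matrices keeps one row per token, so the sequence length is unchanged by the projection.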

<a id="#attention_op"></a>
## 1.4 Scaled Dot Product Attention

Our goal is to have an attention mechanism with which any element in a sequence can attend to any other while still being efficient to compute. This can be achieved through calculating a scaled dot product:

$$\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{D}}\right)V$$

Here, the matrix multiplication $QK^T$ performs the dot product for every possible pair of queries and keys, resulting in a matrix of shape $T\times T$. Each row then represents the attention logits for a specific query $i$ with respect to the keys of all elements in the sequence. We scale these logits by the square root of the length of the key/query vectors ($\sqrt D$), in order to preserve unit variance throughout the model, and then apply a softmax. This provides the self-attention weights, which are then multiplied with the values (for each element) to obtain a weighted mean. The computation graph for these operations is visualized below (figure credit - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)).

<center width="100%"><img src="https://github.com/metrics-lab/transformer-tutorial/blob/main/tutorial/scaled_dot_product_attn.svg?raw=1" width="210px"></center>

Note, the block `Mask (opt.)` in the diagram above represents the optional masking of specific entries in the attention matrix. We will not use it going forward.
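The role of the $\sqrt D$ scaling can be checked empirically. The sketch below assumes that query and key entries are independent with zero mean and unit variance (a simplification of what happens in a trained model), and uses arbitrary dimensions and sample counts chosen only for illustration.

```python
import torch

torch.manual_seed(0)
variances = {}
for D in (4, 64, 256):
    q = torch.randn(10000, D)        # entries with zero mean and unit variance
    k = torch.randn(10000, D)
    logits = (q * k).sum(dim=-1)     # one query-key dot product per row
    scaled = logits / D ** 0.5       # divide by sqrt(D), as in the attention formula
    variances[D] = (logits.var().item(), scaled.var().item())
    print(D, variances[D])
```

Under these assumptions the variance of the raw dot products grows roughly in proportion to $D$, while the scaled logits stay close to unit variance. Without the scaling, large logits would saturate the softmax and produce very small gradients.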

Let's start by writing a function which computes the output features given the triple of queries, keys, and values:

#### **Exercise 1** implementing the scaled dot product

Here, you can find an implementation of the scaled dot product, essential for the attention computation.


- `Task 1.1`: In the scaled dot product function, apply the correct normalisation to the attention logits, before the softmax operation.

**Hint**: check the attention operation above.

- `Task 1.2`: In the "Apply self-attention" cell: create three torch random tensors for the Keys, Queries and Values. The sequence has a length of **3** and each token has a dimension of **6**.

- `Task 1.3`: Answer the question and check the size of tensors.

In [None]:
# --------------------------------------------------task 1 ------------------------------------------------------------
# Task 1.1
def scaled_dot_product(q, k, v):
    #embedding size for normalisation
    d_k = q.size()[-1]
    #dot products between all query and keys pairs
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    #normalisation
    attn_logits = attn_logits / None
    #apply softmax to obtain attention score
    attention = F.softmax(attn_logits, dim=-1)
    #use the attention scores to weight the contribution of value tokens
    #to generate the new sequence
    new_values = torch.matmul(attention, v)

    return new_values, attention

In [None]:
# Task 1.2

seq_len = None
emb_dim = None

q = None
k = None
v = None

values, attention = scaled_dot_product(q, k, v)


**Question: What do you expect the size of the Attention matrix to be? and why?**

Answer:

In [None]:
print("Q has shape {}\n".format(q.shape), q)
print('')
print("K has shape {}\n".format(k.shape), k)
print('')
print("V has shape {}\n".format(v.shape), v)
print('')
print("Values have shape {}\n".format(values.shape), values)
print('')
print("Attention has shape {}\n".format(attention.shape), attention)

<a id="multi head attention"></a>
## 1.5 Multi-Head Attention

The scaled dot product attention allows a network to attend over a sequence. However, often there are multiple different levels of 'meaning' to any sequence, and this must be modelled through a range of self-attention weightings. This is addressed through the multi-head attention mechanism which models multiple different query-key-value triplets in parallel. Specifically, given a query, key, and value matrix, we transform those into $h$ sub-queries, sub-keys, and sub-values, which are then each passed independently through a scaled dot product attention operation. Afterward, the outputs of all heads are concatenated and combined with a final weight matrix $W^O$:

$$
\begin{split}
    \text{Multihead}(Q,K,V) & = \text{Concat}(\text{head}_1,...,\text{head}_h)W^{O}\\
    \text{where } \text{head}_i & = \text{Attention}(Q_i,K_i, V_i)
\end{split}
$$

Here, Keys, Queries and Values are split across $h$ heads s.t. $Q_i\in\mathbb{R}^{T\times D/h}$, $K_i\in\mathbb{R}^{T\times D/h}$, $V_i\in\mathbb{R}^{T\times D/h}$, and $W^{O}\in\mathbb{R}^{D \times D}$ is the final projection matrix. Expressed in a computational graph, we can visualize it as (figure credit - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)).

<center width="100%"><img src="https://github.com/metrics-lab/transformer-tutorial/blob/main/tutorial/multihead_attention.svg?raw=1" width="230px"></center>
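The splitting into heads and the final concatenation are pure reshaping operations. The following sketch shows them on a dummy tensor; the sizes `B`, `T`, `D` and `h` are arbitrary illustrative choices, and no attention is computed here.

```python
import torch

B, T, D, h = 2, 5, 12, 3                 # illustrative batch size, sequence length, width, heads
x = torch.arange(B * T * D, dtype=torch.float).reshape(B, T, D)

# Split the feature dimension into h heads of size D/h: [B, T, D] -> [B, h, T, D/h]
heads = x.reshape(B, T, h, D // h).permute(0, 2, 1, 3)
print(heads.shape)

# Concat(head_1, ..., head_h): undo the split to recover [B, T, D]
merged = heads.permute(0, 2, 1, 3).reshape(B, T, D)
print(torch.equal(merged, x))
```

Because each head sees only $D/h$ features, running $h$ heads in parallel costs about the same as a single attention operation over all $D$ features.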


#### **Exercise 2** implementing the Multi-Head Self-Attention Layer

In practice, Keys, Queries and Values are computed together from a single linear layer (see lines 14 and 31 below). This means that we learn one large weight matrix with $3 \times h \times D/h = 3D$ rows, then split its output into $h$ heads (line 34), then split these into the Query, Key, and Value for each head (line 36). This is done for computational efficiency. Finally, these are pushed through the scaled dot product attention operation (line 40).

In this exercise we will take you through this process step by step.

*First Complete the init function:*

- `Task 2.1`: Complete the linear layer for estimation of the weights matrix (line 14).

**Hint** For this you need to calculate what the number of output channels should be (number of rows of the output matrix). Remember, above we said it simultaneously learns all weights for *all queries, keys and values, for all heads*.

- `Task 2.2`: Reset the parameters of the layers (line 18)

*Now complete the forward function:*

- `Task 2.3`: Pass the correct tensors, in correct order, to the self-attention computation (line 40).

**Hint** remember we defined this above. It calculates the self-attention weights and outputs the transformed values (```new_values```) and attention weights

- `Task 2.4`: Implement the final linear layer (to multiply the output with $W^O$) (line 46, referencing line 15)

In [None]:
class MultiheadAttention(nn.Module):

    def __init__(self, input_dim, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Stack all weight matrices 1...h together for efficiency
        # --------------------------------------------------task 2 ------------------------------------------------------------
        # Task 2.1: implement keys, queries, values projection layer
        self.qkv_proj = nn.Linear(input_dim, None)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # Task 2.2: reset the parameters of the layers
        None

    def _reset_parameters(self):
        # Original Transformer initialization, see PyTorch documentation
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        self.qkv_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.out_proj.weight)
        self.out_proj.bias.data.fill_(0)

    def forward(self, x,  return_attention=False):

        batch_size, seq_length, _ = x.size()

        qkv = self.qkv_proj(x)

        # Separate Q, K, V from linear output
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3*self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3) # [Batch, Head, SeqLen, Dims]
        q, k, v = qkv.chunk(3, dim=-1)

        # apply the attention layer
        # Task 2.3: implement the scaled dot product attention
        new_values, attention = scaled_dot_product(None, None, None)

        new_values = new_values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]
        new_values = new_values.reshape(batch_size, seq_length, self.embed_dim)

        # Task 2.4: implement the final linear projection
        out = self.out_proj(None)

        # ---------------------------------------------------------------------------------------------------------------------

        if return_attention:
            return out, attention
        else:
            return out

One crucial characteristic of multi-head attention is that it is permutation-equivariant with respect to its inputs. This means that if we switch two input elements in the sequence, e.g. $X_1\leftrightarrow X_2$ (neglecting the batch dimension for now), the output is exactly the same except that elements 1 and 2 are switched. Hence, multi-head attention actually looks at the input not as a sequence, but as a set of elements. This property is what makes the multi-head attention block and the Transformer architecture so powerful! But what if the order of the input is actually important for solving the task? Order is often a key component of the structure of languages *and images*. The answer is _positional encodings_, which we will look at in Section 1.7 below.
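Permutation equivariance can be verified numerically. The sketch below uses PyTorch's built-in `nn.MultiheadAttention` (not the class from the exercise above, so that it runs independently of the exercise solutions); the sizes and the chosen permutation are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Built-in attention layer, used here as a stand-in for the exercise implementation
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True).eval()

x = torch.randn(1, 5, 8)                 # [Batch, SeqLen, Dims]
perm = torch.tensor([1, 0, 2, 3, 4])     # swap the first two tokens
x_perm = x[:, perm]

with torch.no_grad():
    out, _ = mha(x, x, x)                        # self-attention on the original order
    out_perm, _ = mha(x_perm, x_perm, x_perm)    # self-attention on the permuted order

# Permuting the original output should match the output for the permuted input
equivariant = torch.allclose(out[:, perm], out_perm, atol=1e-5)
print(equivariant)
```

The check compares the two outputs up to a small floating-point tolerance, since the order of summation inside the softmax and matrix products changes with the permutation.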


## 1.6 Transformer Encoder

First, let's look at how to apply the multi-head attention block inside the architecture of a Transformer encoder (see the left-hand side of the figure):

<center width="100%"><img src="https://github.com/metrics-lab/transformer-tutorial/blob/main/tutorial/transformer_architecture.svg?raw=1" width="400px"></center>

credit - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)

The encoder applies $N$ identical layers/blocks sequentially. Each layer first applies a Multi-Head Attention block, whose output is added to its input through a residual connection. The sum is then passed through Layer Normalization.

The residual connection is crucial in the Transformer architecture for two reasons:

1. Similar to ResNets, Transformers are designed to be very deep. Some models contain more than 24 blocks in the encoder. Hence, the residual connections are crucial for enabling a smooth gradient flow through the model, as they make it easier for the network to learn the identity operation when no transform is needed.
2. Without the residual connection, the information about the original sequence is lost. Remember that the Multi-Head Attention layer ignores the position of elements in a sequence, and can only learn it from positional encoding. Removing the residual connections would mean that this information is lost.

Layer Normalization also plays a vital role, as it ensures that all features within each token have a similar magnitude, which provides regularisation and enables faster training. We do not use Batch Normalization because batches are often small with Transformers, as they require a lot of GPU memory. BatchNorm has, in any case, been shown to perform poorly for language, since the features of word tokens display high variance (there are many rare words that must be well modelled for a good distribution estimate).
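The per-token behaviour of Layer Normalization can be checked against a manual computation. This sketch assumes `nn.LayerNorm` with its default (freshly initialised) affine parameters, and uses arbitrary tensor sizes and a deliberately off-scale input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 4, 16) * 5.0 + 3.0    # [Batch, SeqLen, Dims], deliberately off-scale
norm = nn.LayerNorm(16)                  # default affine parameters: weight=1, bias=0

y = norm(x)

# Manual version: mean and (biased) variance over the feature dimension of each token
mean = x.mean(-1, keepdim=True)
var = x.var(-1, unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + norm.eps)

matches = torch.allclose(y, manual, atol=1e-5)
print(matches, y.mean(-1).abs().max().item())
```

The statistics are computed separately for every token, so the result does not depend on the batch size or on the other tokens in the sequence, unlike Batch Normalization.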

Finally, each layer contains a small fully-connected (feed-forward) network: a Linear$\to$ReLU$\to$Linear MLP that is applied to each transformed token ($x$) separately, using the same weights for every token:

$$
\begin{split}
    \text{FFN}(x) & = \max(0, xW_1+b_1)W_2 + b_2\\
    x & = \text{LayerNorm}(x + \text{FFN}(x))
\end{split}
$$

This MLP adds extra complexity to the model, and can be seen as "post-processing" the output from the previous Multi-Head Attention operation, to prepare it for the next attention block. Usually, the inner dimensionality of the MLP is 2-8$\times$ larger than $D$. The general advantage of a wider layer instead of a narrow, multi-layer MLP is the faster, parallelizable execution.
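That the feed-forward network treats every token identically can be confirmed with a small sketch. The layer sizes here are arbitrary (with the hidden width set to $4\times D$), and dropout is left out for simplicity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D, hidden = 8, 32                         # illustrative sizes, hidden = 4 x D
ffn = nn.Sequential(nn.Linear(D, hidden), nn.ReLU(), nn.Linear(hidden, D))

x = torch.randn(1, 5, D)                  # [Batch, SeqLen, Dims]
with torch.no_grad():
    full = ffn(x)                         # all tokens processed in one call
    # Process each token on its own, then stack along the sequence dimension
    per_token = torch.stack([ffn(x[:, t]) for t in range(x.size(1))], dim=1)

same = torch.allclose(full, per_token, atol=1e-5)
print(same)
```

Since `nn.Linear` acts only on the last dimension, no information is exchanged between tokens in this part of the block; mixing across tokens happens only in the attention layer.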


#### **Exercise 3** implementing the Encoder block

Let's start implementing it with a single encoder block. Note the addition of dropout layers in the MLP and output of the MLP and Multi-Head Attention (for regularization).

*Implement the following steps:*

- `Task 3.1`: apply self attention layer on the input (line 33); remembering we have already defined this function above

- `Task 3.2`: apply a dropout layer on the output of the attention; then add a residual layer (line 35 - one line!)

- `Task 3.3`: apply the first Layer Normalisation layer (line 37)

- `Task 3.4`: apply the feed forward network (MLP) on the output of the layer norm (line 41)

- `Task 3.5`: again apply dropout and a residual layer (line 43)

- `Task 3.6`: finally apply the second Layer Normalisation (line 45)

**Hint** all of these operations are defined in the ```__init__``` function, and you just need to apply them!


In [None]:
class EncoderBlock(nn.Module):

    def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0):
        """
        Inputs:
            input_dim - Dimensionality of the input
            num_heads - Number of heads to use in the attention block
            dim_feedforward - Dimensionality of the hidden layer in the MLP
            dropout - Dropout probability to use in the dropout layers
        """
        super().__init__()

        # Attention layer
        self.self_attn = MultiheadAttention(input_dim, input_dim, num_heads)

        # Two-layer MLP
        self.linear_net = nn.Sequential(
            nn.Linear(input_dim, dim_feedforward),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(dim_feedforward, input_dim)
        )

        # Layers to apply in between the main layers
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # --------------------------------------------------task 3 ------------------------------------------------------------
        # Attention part
        # Task 3.1 apply self attention layer on the input x
        attn_out = None
        # Task 3.2 apply dropout layer on the attention output and add a residual layer
        x = None
        # Task 3.3 apply normalisation layer
        x = None

        # MLP part
        # Task 3.4 apply the feed forward network on the attention part output
        linear_out = None
        # Task 3.5 apply dropout layer on the FFN output and add a residual layer
        x = None
        # Task 3.6 apply the normalisation layer
        x = None
        # ---------------------------------------------------------------------------------------------------------------------

        return x

Based on this block, we can implement a module for the full Transformer encoder. In addition to a forward function that iterates through the sequence of encoder blocks, we also provide a function called `get_attention_maps`. The idea of this function is to return the attention probabilities for all Multi-Head Attention blocks in the encoder. This helps us in understanding, and in a sense, explaining the model. However, the attention probabilities should be interpreted with some scepticism, as they do not necessarily reflect the true interpretation of the model (there is a series of papers about this, including [Attention is not Explanation](https://arxiv.org/abs/1902.10186) and [Attention is not not Explanation](https://arxiv.org/abs/1908.04626)).

In [None]:
class TransformerEncoder(nn.Module):

    def __init__(self, num_layers, **block_args):
        super().__init__()
        self.layers = nn.ModuleList([EncoderBlock(**block_args) for _ in range(num_layers)])

    def forward(self, x):
        for l in self.layers:
            x = l(x)
        return x

    def get_attention_maps(self, x):
        attention_maps = []
        for l in self.layers:
            _, attn_map = l.self_attn(x, return_attention=True)
            attention_maps.append(attn_map)
            x = l(x)
        return attention_maps

In [None]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len=5000):
        """
        Inputs
            d_model - Hidden dimensionality of the input.
            max_len - Maximum length of a sequence to expect.
        """
        super().__init__()

        # Create matrix of [SeqLen, HiddenDim] representing the positional encoding for max_len inputs
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)

        # register_buffer => Tensor which is not a parameter, but should be part of the modules state.
        # Used for tensors that need to be on the same device as the module.
        # persistent=False tells PyTorch to not add the buffer to the state dict (e.g. when we save the model)
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return x

<a id="positional encoding"></a>
## 1.7 Positional encoding

We have discussed before that the Multi-Head Attention block is permutation-equivariant and cannot distinguish whether one input comes before another in the sequence, while also highlighting that *the order of tokens in a sequence is almost always important* for understanding context. To embed the concept of order into the network, we must therefore add positional encodings. These may be learnt or pre-defined. In the original 'Attention Is All You Need' paper, Vaswani et al. pre-defined the encodings from sine and cosine functions of different frequencies:

$$
PE_{(pos,i)} = \begin{cases}
    \sin\left(\frac{pos}{10000^{i/d_{\text{model}}}}\right) & \text{if}\hspace{3mm} i \text{ mod } 2=0\\
    \cos\left(\frac{pos}{10000^{(i-1)/d_{\text{model}}}}\right) & \text{otherwise}\\
\end{cases}
$$

Here $PE_{(pos,i)}$ represents the positional encoding at position $pos$ in the sequence and hidden (feature) dimension $i$. These values, concatenated over all hidden dimensions, are added to the original input features.
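The case formula can be evaluated directly for a single position. The function below is a plain-Python sketch written for this illustration (the name `positional_encoding` and the tiny `d_model = 4` are arbitrary choices); it is meant to mirror the formula rather than to be efficient.

```python
import math

def positional_encoding(pos, d_model):
    """Sinusoidal encoding of one position, following the case formula above."""
    pe = []
    for i in range(d_model):
        if i % 2 == 0:
            # Even hidden dimensions use the sine
            pe.append(math.sin(pos / 10000 ** (i / d_model)))
        else:
            # Odd hidden dimensions use the cosine at the frequency of the preceding even dimension
            pe.append(math.cos(pos / 10000 ** ((i - 1) / d_model)))
    return pe

print(positional_encoding(0, 4))
print(positional_encoding(1, 4))
```

At position 0 every sine term is 0 and every cosine term is 1, so the encoding alternates between 0 and 1. Higher hidden dimensions use lower frequencies and therefore change more slowly with position.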

To understand the positional encoding, we can build from the following [PyTorch tutorial](https://pytorch.org/tutorials/beginner/transformer_tutorial.html#define-the-model). This will allow us to visualise the positional encoding, over feature dimensions and position in a sequence:

In [None]:
encod_block = PositionalEncoding(d_model=48, max_len=96)
pe = encod_block.pe.squeeze().T.cpu().numpy()

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8,3))
pos = ax.imshow(pe, cmap="RdGy", extent=(1,pe.shape[1]+1,pe.shape[0]+1,1))
fig.colorbar(pos, ax=ax)
ax.set_xlabel("Position in sequence")
ax.set_ylabel("Hidden dimension")
ax.set_title("Positional encoding over hidden dimensions")
ax.set_xticks([1]+[i*10 for i in range(1,1+pe.shape[1]//10)])
ax.set_yticks([1]+[i*10 for i in range(1,1+pe.shape[0]//10)])
plt.show()

In [None]:
sns.set_theme()
fig, ax = plt.subplots(2, 2, figsize=(12,4))
ax = [a for a_list in ax for a in a_list]
for i in range(len(ax)):
    ax[i].plot(np.arange(1,17), pe[i,:16], color=f'C{i}', marker="o", markersize=6, markeredgecolor="black")
    ax[i].set_title(f"Encoding in hidden dimension {i+1}")
    ax[i].set_xlabel("Position in sequence", fontsize=10)
    ax[i].set_ylabel("Positional encoding", fontsize=10)
    ax[i].set_xticks(np.arange(1,17))
    ax[i].tick_params(axis='both', which='major', labelsize=10)
    ax[i].tick_params(axis='both', which='minor', labelsize=8)
    ax[i].set_ylim(-1.2, 1.2)
fig.subplots_adjust(hspace=0.8)
sns.reset_orig()
plt.show()

# 2. The Vision Transformer Architecture

The motivation behind Transformers was to propose a mechanism that would support the learning of long-range attention to encode more complicated semantic concepts within language. The contribution of "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" [Alexey Dosovitskiy et al.](https://openreview.net/pdf?id=YicbFdNTTy) was to show that image-understanding also benefits from a more holistic view of each scene.

Specifically, Vision Transformers approach image understanding as a sequence modelling problem, by splitting images of, for example, $48\times 48$ pixels into 9 patches of $16\times 16$ pixels. Each patch is first embedded with a linear layer and then treated as a "word"/"token". Positional encodings are then added as before, and a separate token is added for classification. Beyond this, the sequence is treated like any other sequence and processed with a series of Multi-Head Attention layers, with the exact same architecture as used for language models. A nice GIF visualization of the architecture is shown below (figure credit - [Phil Wang](https://github.com/lucidrains/vit-pytorch/blob/main/images/vit.gif)):

<center width="100%"><img src="https://github.com/metrics-lab/transformer-tutorial/blob/main/tutorial/vit.gif?raw=1" width="600px"></center>
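
The token construction described above can be sketched in a few lines of PyTorch. This is a minimal illustration under stated assumptions rather than the implementation of any particular library: the class name `PatchTokens`, the embedding size of 64 and the use of learnable positional embeddings are choices made here for brevity.

```python
import torch
import torch.nn as nn


class PatchTokens(nn.Module):
    """Hypothetical helper: turns images into a sequence of ViT input tokens."""

    def __init__(self, image_size=48, patch_size=16, channels=3, dim=64):
        super().__init__()
        assert image_size % patch_size == 0, "image size must be divisible by patch size"
        self.patch_size = patch_size
        num_patches = (image_size // patch_size) ** 2
        # Linear embedding of each flattened patch (C * p * p values) into `dim` features
        self.to_embedding = nn.Linear(channels * patch_size**2, dim)
        # One extra learnable token used for classification
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        # Learnable positional embeddings, one per token (patches + classification token)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim) * 0.02)

    def forward(self, x):
        B, C, H, W = x.shape
        p = self.patch_size
        # Cut the image into non-overlapping p x p patches: [B, C, H/p, W/p, p, p]
        patches = x.unfold(2, p, p).unfold(3, p, p)
        # Reorder to [B, H/p * W/p, C * p * p] so that each row is one flattened patch
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * p * p)
        tokens = self.to_embedding(patches)
        # Prepend the classification token, then add the positional embeddings
        cls = self.cls_token.expand(B, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)
        return tokens + self.pos_embedding


images = torch.randn(2, 3, 48, 48)
tokens = PatchTokens()(images)
print(tokens.shape)  # torch.Size([2, 10, 64])
```

Each $48\times 48$ image becomes a sequence of 10 tokens: nine patch tokens plus the classification token.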

Let's start by importing everything we need:

In [None]:
## Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms

In [None]:
test_transform = transforms.Compose([transforms.ToTensor(),
                                     transforms.Normalize([0.49139968, 0.48215841, 0.44653091], [0.24703223, 0.24348513, 0.26158784])
                                     ])
# For training, we add some augmentation. Networks are too powerful and would overfit.
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.RandomResizedCrop((32,32),scale=(0.8,1.0),ratio=(0.9,1.1)),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.49139968, 0.48215841, 0.44653091], [0.24703223, 0.24348513, 0.26158784])
                                     ])
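# Loading the training dataset. We need to split it into a training and validation part.
# We need a little trick because the validation set should not use the augmentation:
# the data is loaded twice (with and without augmentation) and both copies are split
# with the same seed, so that the index permutations match and the two sets do not overlap.
train_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=train_transform, download=True)
val_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=test_transform, download=True)
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(42))
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000], generator=torch.Generator().manual_seed(42))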

# Loading the test set
test_set = CIFAR10(root=DATASET_PATH, train=False, transform=test_transform, download=True)

# We define a set of data loaders that we can use for various purposes later.
train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True, drop_last=True, pin_memory=True, num_workers=4)
val_loader = data.DataLoader(val_set, batch_size=128, shuffle=False, drop_last=False, num_workers=4)
test_loader = data.DataLoader(test_set, batch_size=128, shuffle=False, drop_last=False, num_workers=4)

# Visualize some examples
NUM_IMAGES = 16
CIFAR_images = torch.stack([val_set[idx][0] for idx in range(NUM_IMAGES)], dim=0)
img_grid = torchvision.utils.make_grid(CIFAR_images, nrow=4, normalize=True, pad_value=0.9)
img_grid = img_grid.permute(1, 2, 0)

plt.figure(figsize=(8,8))
plt.title("Image examples of the CIFAR10 dataset")
plt.imshow(img_grid)
plt.axis('off')
plt.show()
plt.close()

<a id="#vision transformers"></a>
## 2.1 Vision Transformers


We will walk step by step through the Vision Transformer, and implement all parts by ourselves.

The first step is to implement patching of each $N\times N$ image into $(N/M)^2$ patches of size $M\times M$.
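
As a quick arithmetic check of this formula, the sketch below computes the sequence length for two image sizes that are not part of the exercises. The helper name `num_patches` is hypothetical and only used for illustration.

```python
def num_patches(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping M x M patches in an N x N image, i.e. (N/M)^2."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    return (image_size // patch_size) ** 2


# The 48 x 48 example from the introduction of this section
print(num_patches(48, 16))   # 9
# A 224 x 224 image with 16 x 16 patches
print(num_patches(224, 16))  # 196
```

Halving the patch size quadruples the sequence length, which matters because the cost of self-attention grows quadratically with the number of tokens.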

In [None]:
def img_to_patch(x, patch_size, flatten_channels=True):
    """
    Inputs:
        x - torch.Tensor representing the image of shape [B, C, H, W]
        patch_size - Number of pixels per dimension of the patches (integer)
        flatten_channels - If True, the patches will be returned in a flattened format
                           as a feature vector instead of an image grid.
    """
    B, C, H, W = x.shape
    x = x.reshape(B, C, H//patch_size, patch_size, W//patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5) # [B, H', W', C, p_H, p_W]
    x = x.flatten(1,2)              # [B, H'*W', C, p_H, p_W]
    if flatten_channels:
        x = x.flatten(2,4)          # [B, H'*W', C*p_H*p_W]
    return x

#### **Exercise 6** Image patches

Let's take a look at how that works for our $32\times 32$ `CIFAR images` examples above.

- `Task 6.1`: Visualise patches of different sizes (for instance 4, 8 or 16). What is the length of the sequence in each case?

In [None]:
# --------------------------------------------------task 6 ------------------------------------------------------------
# Task 6.1 Try out different patch sizes and visualize the patches
img_patches = img_to_patch(CIFAR_images, patch_size=None, flatten_channels=False)

fig, ax = plt.subplots(CIFAR_images.shape[0], 1, figsize=(14,3))
fig.suptitle("Images as input sequences of patches")
for i in range(CIFAR_images.shape[0]):
    img_grid = torchvision.utils.make_grid(img_patches[i], nrow=64, normalize=True, pad_value=0.9)
    img_grid = img_grid.permute(1, 2, 0)
    ax[i].imshow(img_grid)
    ax[i].axis('off')
plt.show()
plt.close()

Compared to the original images, it is much harder to recognize the objects from those patch lists. The inductive bias of CNNs, namely that an image is a grid of pixels, is lost in this input format. With the help of positional encodings, the model must learn for itself how to combine the patches to recognize the objects.
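
To see why positional encodings are needed, the sketch below checks that a self-attention layer on its own ignores the order of its input tokens: shuffling the patches only shuffles the outputs in the same way. It uses an untrained `nn.MultiheadAttention` layer with random tokens as stand-ins for patch embeddings; the sizes are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A single self-attention layer without any positional information
attention = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
attention.eval()

tokens = torch.randn(1, 8, 16)   # [batch, sequence of 8 patch tokens, features]
perm = torch.randperm(8)         # a random re-ordering of the patches

with torch.no_grad():
    out, _ = attention(tokens, tokens, tokens)
    shuffled = tokens[:, perm]
    out_shuffled, _ = attention(shuffled, shuffled, shuffled)

# Shuffling the input patches only shuffles the outputs in the same way
order_is_ignored = torch.allclose(out[:, perm], out_shuffled, atol=1e-5)
print(order_is_ignored)
```

Without positional encodings, the layer therefore has no way of telling where a patch came from in the image.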


<a id="#experiments_part_2"></a>
## 2.2 Experiments: Image Classification

Let's now try out the Vision Transformer (ViT) on image classification.
We will use the GitHub repository [ViT Pytorch](https://github.com/lucidrains/vit-pytorch), which implements many of the latest vision transformer architectures. We will compare the performance of the ViT against a ResNet CNN.

In [None]:
#install library
!pip install vit-pytorch timm

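import torch
import torch.nn as nn        # used for the loss function in the training loops
import torch.optim as optim  # used for the optimizers in the training loops
from vit_pytorch import ViT
import timm
from torch.optim.lr_scheduler import StepLR  # Or any other scheduler you prefer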

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

#### **Exercise 7** Comparing Vision Transformer and CNN for image classification

Let's first see if vision transformers can indeed solve image classification tasks.

- `Task 7.1`: Instantiate a ViT model from the vit_pytorch library. `patch_size` is set to 4 by default, but feel free to try different values to see how it impacts training.

- `Task 7.2`: Set the training parameters to `n_epochs=10` and `lr=0.001`. Then, try to improve the accuracy by playing with the parameters. You can also modify the scheduler.

- `Task 7.3`: Train the vision transformer model and report the accuracy after 10 epochs.

- `Task 7.4`: Instantiate a `resnet50` model using the timm library and compare the validation accuracy between CNN and ViT.
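
Since Task 7.2 invites you to modify the scheduler, it helps to see what `StepLR` actually does to the learning rate. The sketch below records the learning rate at the start of each epoch for a dummy parameter; `step_size=3` and `gamma=0.1` are hypothetical values chosen so that the decay is visible within 10 epochs.

```python
import torch
from torch.optim.lr_scheduler import StepLR

# A dummy parameter stands in for the model weights (illustration only)
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=0.001)
scheduler = StepLR(optimizer, step_size=3, gamma=0.1)  # hypothetical values

learning_rates = []
for epoch in range(10):
    learning_rates.append(optimizer.param_groups[0]["lr"])
    optimizer.step()    # would normally follow loss.backward()
    scheduler.step()    # called once per epoch

print(learning_rates)
```

The learning rate is multiplied by `gamma` every `step_size` epochs. With a `step_size` larger than the number of epochs, the scheduler never changes the learning rate during training.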


In [None]:
# --------------------------------------------------task 7 ------------------------------------------------------------
# Task 7.1: instantiate the model here
transformer_model = ViT(
                          image_size = None, #height/width of the image
                          patch_size = 4, #size of the patch to form the sequence
                          num_classes = None,
                          dim = 192, #embedding dimension D
                          depth = 12, #number of encoder layers L
                          heads = 3, # number of heads in the MHSA
                          mlp_dim = 4*192, #hidden dimension of the FFN
                          dropout = 0.0, #dropout applied inside the transformer layers (attention and FFN)
                          emb_dropout = 0.0 #dropout applied to the embedded tokens, prior to the first transformer layer
                      )


transformer_model.to(device)

print('Number of parameters ViT model: {:,}'.format(sum(p.numel() for p in transformer_model.parameters() if p.requires_grad)))

# Task 7.2: Play with the hyperparameters
n_epochs = None
optimizer = optim.Adam(transformer_model.parameters(), lr=None)  # Set your own learning rate
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)  # Adjust according to your needs

In [None]:
# Task 7.3: Train the model
for epoch in range(n_epochs):  # Set your own number of epochs
    transformer_model.train()
    for batch in train_loader:
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = transformer_model(inputs)
        loss = nn.CrossEntropyLoss()(outputs, targets)
        loss.backward()
        optimizer.step()

    scheduler.step()  # Adjust learning rate

    # Validation step
    transformer_model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for batch in val_loader:
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = transformer_model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

        val_accuracy = correct / total
        print(f'Epoch {epoch}, Val Accuracy: {val_accuracy}')


In [None]:
# Task 7.4: Train a ResNet Model
resnet_model = timm.create_model(None, pretrained=False, num_classes=10)
resnet_model.to(device)

from torch.optim.lr_scheduler import StepLR  # Or any other scheduler you prefer

n_epochs = 10
optimizer = optim.Adam(resnet_model.parameters(), lr=0.001)  # Set your own learning rate
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)  # Adjust according to your needs


for epoch in range(n_epochs):  # Set your own number of epochs
    resnet_model.train()
    for batch in train_loader:
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = resnet_model(inputs)
        loss = nn.CrossEntropyLoss()(outputs, targets)
        loss.backward()
        optimizer.step()

    scheduler.step()  # Adjust learning rate

    # Validation step
    resnet_model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for batch in val_loader:
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = resnet_model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

        val_accuracy = correct / total
        print(f'Epoch {epoch}, Val Accuracy: {val_accuracy}')

**Question: Which of the two models, the CNN or the ViT, gives the best classification results? Why do you think this is the case?**

Answer:

<a id="#visualising attention"></a>
## 2.3 Visualising attention maps

Over the years, various interactive tools have been developed to visualise the attention weights learned while training transformers on images or text. Here are a few examples that you can explore:

- [AttentionViz: A Global View of Transformer Attention](https://catherinesyeh.github.io/attn-docs/#view-info)
- ⭐ [Visualization of Self-Attention Maps in Vision](https://epfml.github.io/attention-cnn/) (my personal favourite tool)


<center width="100%" style="padding:25px"><img src="attention_maps.png" height="500px" width="800px"></center>


You will notice that some of the attention maps look a bit like segmentation maps. This idea has inspired various models that achieve high segmentation accuracy without any ground-truth labels, relying only on self-supervision. One of the most notable models is [DINO: Emerging Properties in Self-Supervised Vision Transformers](https://arxiv.org/abs/2104.14294) by M. Caron et al., 2021. See results below:

<center width="100%" style="padding:25px"><img src="attention_maps2.png" height="200px" width="800px"></center>
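
Attention maps of this kind are commonly obtained by taking the attention weights from the classification token to every patch token and reshaping them onto the patch grid. The sketch below shows the mechanics on an untrained `nn.MultiheadAttention` layer with random tokens, so the resulting map is meaningless; the grid size and embedding dimension are assumptions made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

grid = 7                      # 7 x 7 patches, e.g. a 224 x 224 image with patch size 32
dim = 32
attention = nn.MultiheadAttention(embed_dim=dim, num_heads=4, batch_first=True)
attention.eval()

# One classification token followed by grid * grid patch tokens (random stand-ins)
tokens = torch.randn(1, 1 + grid * grid, dim)

with torch.no_grad():
    # need_weights=True also returns the attention weights, averaged over heads
    _, weights = attention(tokens, tokens, tokens, need_weights=True)

# Row 0 holds the attention from the classification token to every token
cls_to_patches = weights[0, 0, 1:]
attention_map = cls_to_patches.reshape(grid, grid)
print(attention_map.shape)  # torch.Size([7, 7])
```

The resulting `attention_map` can be displayed with `plt.imshow` and upsampled to the image resolution to overlay it on the input.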

# 3. Vision Transformers in Medical Imaging

Vision Transformer models have naturally been adapted and tested on medical datasets. They show a lot of potential, as they seem able to inherently model long-range spatial dependencies. This is very valuable in medical imaging settings where, for instance, tumours can be scattered over the entire scan or organs can have very particular shapes.
This [survey paper](https://arxiv.org/pdf/2201.09873.pdf) reviews many deep learning papers that adapt vision transformers to the medical field, for tasks such as segmentation, classification, registration or anomaly detection. The GitHub repository [Awesome Medical Transformer](https://github.com/fahadshamshad/awesome-transformers-in-medical-imaging) also compiles the open-source codebases available for the latest medical transformer publications.


In [None]:
!pip install monai
import os
import matplotlib.pyplot as plt
import PIL
import torch
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from sklearn.metrics import classification_report

from monai.apps import download_and_extract
from monai.data import decollate_batch, DataLoader
from monai.metrics import ROCAUCMetric
from monai.transforms import (
    Activations,
    EnsureChannelFirst,
    AsDiscrete,
    Compose,
    LoadImage,
    RandFlip,
    RandRotate,
    RandZoom,
    ScaleIntensity,
)


directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = './data/'

resource = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz"
md5 = "0bc7306e7427e00ad1c5526a6677552d"

compressed_file = os.path.join(root_dir, "MedNIST.tar.gz")
data_dir = os.path.join(root_dir, "MedNIST")
if not os.path.exists(data_dir):
    download_and_extract(resource, compressed_file, root_dir, md5)

## 3.1 MedNIST dataset
Here we will use the popular **MONAI** library to implement image classification on the MedNIST dataset.

In [None]:
class_names = sorted(x for x in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, x)))
num_class = len(class_names)
image_files = [
    [os.path.join(data_dir, class_names[i], x) for x in os.listdir(os.path.join(data_dir, class_names[i]))]
    for i in range(num_class)
]
num_each = [len(image_files[i]) for i in range(num_class)]
image_files_list = []
image_class = []
for i in range(num_class):
    image_files_list.extend(image_files[i])
    image_class.extend([i] * num_each[i])
num_total = len(image_class)
image_width, image_height = PIL.Image.open(image_files_list[0]).size

print(f"Total image count: {num_total}")
print(f"Image dimensions: {image_width} x {image_height}")
print(f"Label names: {class_names}")
print(f"Label counts: {num_each}")

In [None]:
plt.subplots(3, 3, figsize=(8, 8))
for i, k in enumerate(np.random.randint(num_total, size=9)):
    im = PIL.Image.open(image_files_list[k])
    arr = np.array(im)
    plt.subplot(3, 3, i + 1)
    plt.xlabel(class_names[image_class[k]])
    plt.imshow(arr, cmap="gray", vmin=0, vmax=255)
plt.tight_layout()
plt.show()

In [None]:
val_frac = 0.1
test_frac = 0.1
length = len(image_files_list)
indices = np.arange(length)
np.random.shuffle(indices)

test_split = int(test_frac * length)
val_split = int(val_frac * length) + test_split
test_indices = indices[:test_split]
val_indices = indices[test_split:val_split]
train_indices = indices[val_split:]

train_x = [image_files_list[i] for i in train_indices]
train_y = [image_class[i] for i in train_indices]
val_x = [image_files_list[i] for i in val_indices]
val_y = [image_class[i] for i in val_indices]
test_x = [image_files_list[i] for i in test_indices]
test_y = [image_class[i] for i in test_indices]

print(f"Training count: {len(train_x)}, Validation count: " f"{len(val_x)}, Test count: {len(test_x)}")

In [None]:
train_transforms = Compose(
    [
        LoadImage(image_only=True),
        EnsureChannelFirst(),
        ScaleIntensity(),
        RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
        RandFlip(spatial_axis=0, prob=0.5),
        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
    ]
)

val_transforms = Compose([LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity()])

y_pred_trans = Compose([Activations(softmax=True)])
y_trans = Compose([AsDiscrete(to_onehot=num_class)])

In [None]:
class MedNISTDataset(torch.utils.data.Dataset):
    def __init__(self, image_files, labels, transforms):
        self.image_files = image_files
        self.labels = labels
        self.transforms = transforms

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        return self.transforms(self.image_files[index]), self.labels[index]


train_ds = MedNISTDataset(train_x, train_y, train_transforms)
train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10)

val_ds = MedNISTDataset(val_x, val_y, val_transforms)
val_loader = DataLoader(val_ds, batch_size=300, num_workers=10)

test_ds = MedNISTDataset(test_x, test_y, val_transforms)
test_loader = DataLoader(test_ds, batch_size=300, num_workers=10)

#### **Exercise 8** Define ViT model and train

**MONAI** has its own implementations of deep learning models such as Vision Transformers, to accommodate the formats and particularities of medical images.

- `Task 8.1`: Define the ViT model from the monai library with the correct parameters `in_channels`, `img_size`, `patch_size`

- `Task 8.2`: Train the ViT model (it will take a couple of minutes)

- `Task 8.3`: Load the best model and evaluate its performance on the test dataset
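
Task 8.3 relies on the usual PyTorch pattern of saving a model's `state_dict` and loading it back into a freshly built model of the same architecture. The sketch below shows this pattern on a toy linear layer; the file name and temporary directory are hypothetical and unrelated to the paths used in this tutorial.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)          # toy stand-in for the trained network
checkpoint_dir = tempfile.mkdtemp()
checkpoint_path = os.path.join(checkpoint_dir, "toy_checkpoint.pth")  # hypothetical file name

# Save only the parameters (the state dict), not the whole Python object
torch.save(model.state_dict(), checkpoint_path)

# Later: rebuild the same architecture, then load the stored parameters into it
restored = nn.Linear(4, 2)
state_dict = torch.load(checkpoint_path, map_location="cpu")
restored.load_state_dict(state_dict)
restored.eval()                  # switch to evaluation mode before testing

same_weights = torch.equal(model.weight, restored.weight)
print(same_weights)
```

Loading with `map_location` makes the checkpoint usable on a machine without a GPU; the model can then be moved to the desired device with `.to(device)`.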

In [None]:
# --------------------------------------------------task 8 ------------------------------------------------------------

from monai.networks.nets import ViT

# Define the device first, since the model is moved to it on creation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Task 8.1: instantiate the model here
model = ViT(in_channels=None,
            img_size=None,
            patch_size=None,
            hidden_size=192,
            mlp_dim=192*4,
            num_layers=12,
            num_heads=3,
            proj_type='conv',  # `pos_embed` is deprecated in recent MONAI releases in favour of `proj_type`
            pos_embed_type='learnable',
            classification=True,
            num_classes=num_class,
            dropout_rate=0.0,
            spatial_dims=2,
            post_activation=False,
            qkv_bias=False,
            save_attn=False).to(device)

loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 1e-5)
max_epochs = 4
val_interval = 1
auc_metric = ROCAUCMetric()

## 3.2 Training vision transformer model on medical images



In [None]:
# Task 8.2: train the model

best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
writer = SummaryWriter()

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs[0], labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(f"{step}/{len(train_ds) // train_loader.batch_size}, " f"train_loss: {loss.item():.4f}")
        epoch_len = len(train_ds) // train_loader.batch_size
        writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            y_pred = torch.tensor([], dtype=torch.float32, device=device)
            y = torch.tensor([], dtype=torch.long, device=device)
            for val_data in val_loader:
                val_images, val_labels = (
                    val_data[0].to(device),
                    val_data[1].to(device),
                )
                y_pred = torch.cat([y_pred, model(val_images)[0]], dim=0)
                y = torch.cat([y, val_labels], dim=0)
            y_onehot = [y_trans(i) for i in decollate_batch(y, detach=False)]
            y_pred_act = [y_pred_trans(i) for i in decollate_batch(y_pred)]
            auc_metric(y_pred_act, y_onehot)
            result = auc_metric.aggregate()
            auc_metric.reset()
            del y_pred_act, y_onehot
            metric_values.append(result)
            acc_value = torch.eq(y_pred.argmax(dim=1), y)
            acc_metric = acc_value.sum().item() / len(acc_value)
            if result > best_metric:
                best_metric = result
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current AUC: {result:.4f}"
                f" current accuracy: {acc_metric:.4f}"
                f" best AUC: {best_metric:.4f}"
                f" at epoch: {best_metric_epoch}"
            )
            writer.add_scalar("val_accuracy", acc_metric, epoch + 1)

print(f"train completed, best_metric: {best_metric:.4f} " f"at epoch: {best_metric_epoch}")
writer.close()

In [None]:
plt.figure("train", (12, 6))
plt.subplot(1, 2, 1)
plt.title("Epoch Average Loss")
x = [i + 1 for i in range(len(epoch_loss_values))]
y = epoch_loss_values
plt.xlabel("epoch")
plt.plot(x, y)
plt.subplot(1, 2, 2)
plt.title("Val AUC")
x = [val_interval * (i + 1) for i in range(len(metric_values))]
y = metric_values
plt.xlabel("epoch")
plt.plot(x, y)
plt.show()

### Evaluate the model on test dataset


In [None]:
# Task 8.3: evaluate the model performances
model.load_state_dict(torch.load(None))
model.eval()
y_true = []
y_pred = []
with torch.no_grad():
    for test_data in test_loader:
        test_images, test_labels = (
            test_data[0].to(device),
            test_data[1].to(device),
        )
        pred = model(test_images)[0].argmax(dim=1)
        for i in range(len(pred)):
            y_true.append(test_labels[i].item())
            y_pred.append(pred[i].item())

In [None]:
print(classification_report(y_true, y_pred, target_names=class_names, digits=4))

# 4. Application: Surface Vision Transformers


In this last part of the tutorial, we will investigate an extension of the Vision Transformer to non-Euclidean geometries. As you have probably understood by now, the transformer architecture can be used in many different data domains. We often describe the transformer architecture as *agnostic* to the domain, as long as the input data can be represented as a sequence of tokens. [Dahan et al. 2022](https://arxiv.org/abs/2203.16414) extended the vision transformer architecture to study cortical surfaces represented on a regular icosahedron. This is achieved by patching sphericalised meshes using low-resolution icospheric grids and then using a regular ViT to process the cortical patches. The model is named the Surface Vision Transformer (SiT).

<center width="100%"><img src="https://github.com/metrics-lab/transformer-tutorial/blob/main/tutorial/sit_gif.gif?raw=1"  width="800px"></center>

First, let's clone the code for the SiT model and install some necessary dependencies:

In [None]:
!git clone https://github.com/metrics-lab/surface-vision-transformers.git
!pip install einops
!pip install vit-pytorch timm

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import sys
import os
sys.path.append('./surface-vision-transformers')
from models.sit import SiT

## 4.1 dHCP dataset

Here you will download data from the dHCP dataset. The data has already been processed so that the SiT can be used out-of-the-box. Four cortical metrics (myelin maps, cortical thickness, sulcal depth, curvature) are used, and ico6 sphericalised meshes are patched using an ico2 sphericalised grid. This leads to a sequence of 320 non-overlapping patches of 153 vertices each (see illustration).
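
The numbers quoted above follow from the geometry of a subdivided icosahedron. The sketch below reproduces them with plain arithmetic, assuming the standard subdivision scheme in which every triangular face is split into four at each level; the function names are hypothetical.

```python
def ico_faces(level: int) -> int:
    """Number of triangular faces of an icosphere after `level` subdivisions."""
    return 20 * 4**level


def ico_vertices(level: int) -> int:
    """Number of vertices of an icosphere after `level` subdivisions."""
    return 10 * 4**level + 2


def vertices_per_patch(mesh_level: int, grid_level: int) -> int:
    """Vertices of the fine mesh lying on one face of the coarse grid (edges included)."""
    points_per_edge = 2 ** (mesh_level - grid_level) + 1
    return points_per_edge * (points_per_edge + 1) // 2  # triangular number


n_patches = ico_faces(2)                         # faces of the ico2 grid
n_vertices_per_patch = vertices_per_patch(6, 2)  # ico6 vertices per ico2 face
print(n_patches, n_vertices_per_patch, ico_vertices(6))  # 320 153 40962
```

Note that $320 \times 153 = 48960$ exceeds the 40962 vertices of the ico6 mesh: the triangular faces themselves do not overlap, but vertices lying on a shared edge or corner are counted in more than one patch.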

We will use data from the scan age experiment as per [A. Fawaz et al 2021](https://www.biorxiv.org/content/10.1101/2021.12.01.470730v1.full.pdf).

In [None]:
!gdown https://drive.google.com/uc?id=1DJhrERb1hk8Ekp_Cq2qwnxYvjq2nM0kc

!unzip -q dhcp_scan_age_template_processed.zip -d ./data/


In [None]:
train_data = np.load('./data/train_data.npy')
validation_data = np.load('./data/validation_data.npy')

train_labels = np.load('./data/train_labels.npy')
validation_labels = np.load('./data/validation_labels.npy')

print(train_data.shape, validation_data.shape, train_labels.shape, validation_labels.shape)

**Question: What does each dimension of the train_data tensor correspond to?**

Answer:

- 1st dimension:  ?
- 2nd dimension:  ?
- 3rd dimension:  ?
- 4th dimension:  ?


In [None]:
batch_size = 8

train_data_dataset = torch.utils.data.TensorDataset(torch.from_numpy(train_data).float(),
                                                    torch.from_numpy(train_labels).float())

train_loader = torch.utils.data.DataLoader(train_data_dataset,
                                                batch_size = batch_size,
                                                shuffle=True,
                                                num_workers=16)

val_data_dataset = torch.utils.data.TensorDataset(torch.from_numpy(validation_data).float(),
                                                torch.from_numpy(validation_labels).float())


val_loader = torch.utils.data.DataLoader(val_data_dataset,
                                        batch_size = batch_size,
                                        shuffle=False,
                                        num_workers=16)

## 4.2 SiT training

We will train the model for the task of scan age (PMA) prediction. This is a regression task. Therefore, we will use the MSELoss as the training criterion and the Adam optimiser.
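
The training loop in this section optimises the mean squared error but reports the mean absolute error (MAE), which is easier to read in weeks. The sketch below computes both on made-up numbers, and flattens the predictions first so that the two tensors have matching shapes; all values are for illustration only.

```python
import torch
import torch.nn as nn

# Made-up predictions and targets in weeks, for illustration only
outputs = torch.tensor([[30.0], [35.0], [40.0]])   # model output, shape [B, 1]
targets = torch.tensor([31.0, 35.0, 38.0])         # labels, shape [B]

# Flatten the predictions so that both tensors have shape [B];
# otherwise [B, 1] and [B] would broadcast to [B, B] inside the loss
predictions = outputs.reshape(-1)

mse = nn.MSELoss(reduction="mean")(predictions, targets)   # training criterion
mae = torch.mean(torch.abs(predictions - targets))         # reported metric

print(f"MSE: {mse.item():.4f}, MAE: {mae.item():.4f}")  # MSE: 1.6667, MAE: 1.0000
```

The MSE penalises the 2-week error four times as much as the 1-week error, whereas the MAE weighs them proportionally.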

#### **Exercise 9** Define the SiT model and train on PMA prediction

- `Task 9.1`: Using the dimension values from the previous question, define the SiT model correctly

- `Task 9.2`: Train the model and comment on its performance

In [None]:
# --------------------------------------------------task 9 ------------------------------------------------------------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Task 9.1: instantiate the model here
sit_model = SiT(dim=192,
            depth=12,
            heads=3,
            mlp_dim=4*192,
            pool='cls',
            num_patches=None,
            num_classes=1,
            num_channels=None,
            num_vertices=None,
            dim_head=64,
            dropout=0.0,
            emb_dropout=0.0)

sit_model.to(device)

num_training_epochs=100
num_val_epochs=10
criterion = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(sit_model.parameters(), lr=0.0003, weight_decay=0.0)

In [None]:
# Task 9.2: Train the model here
best_mae = 100000000
best_epoch = -1  # stays at -1 if no validation step has been run
mae_val_epoch = 100000000
running_val_loss = 100000000

for epoch in range(num_training_epochs):

    running_loss = 0

    sit_model.train()

    targets_ =  []
    preds_ = []

    for i, data in enumerate(train_loader):

        inputs, targets = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        outputs = sit_model(inputs)

        loss = criterion(outputs.squeeze(), targets)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        targets_.append(targets.cpu().numpy())
        preds_.append(outputs.reshape(-1).cpu().detach().numpy())

    mae_epoch = np.mean(np.abs(np.concatenate(targets_) - np.concatenate(preds_)))

    if (epoch+1)%5==0:
        print('| Epoch - {} | Loss - {:.4f} | MAE - {:.4f} | LR - {}'.format(epoch+1, running_loss/(i+1), round(mae_epoch,4), optimizer.param_groups[0]['lr']))

    ##############################
    ######    VALIDATION    ######
    ##############################

    if (epoch+1)%num_val_epochs==0:

        running_val_loss = 0

        sit_model.eval()

        with torch.no_grad():

            targets_ = []
            preds_ = []

            for i, data in enumerate(val_loader):

                inputs, targets = data[0].to(device), data[1].to(device)

                outputs = sit_model(inputs)

                loss = criterion(outputs.squeeze(), targets)

                running_val_loss += loss.item()

                targets_.append(targets.cpu().numpy())
                preds_.append(outputs.reshape(-1).cpu().numpy())


        mae_val_epoch = np.mean(np.abs(np.concatenate(targets_)- np.concatenate(preds_)))

        # Report the validation loss averaged over batches, as done for the training loss
        print('| Validation | Epoch - {} | Loss - {:.4f} | MAE - {:.4f} |'.format(epoch+1, running_val_loss/(i+1), mae_val_epoch))

        if mae_val_epoch < best_mae:
            best_mae = mae_val_epoch
            best_epoch = epoch+1

            df = pd.DataFrame()
            df['preds'] = np.concatenate(preds_).reshape(-1)
            df['targets'] = np.concatenate(targets_).reshape(-1)
            df.to_csv(os.path.join(CHECKPOINT_PATH, 'preds_test.csv'))

            torch.save(sit_model.state_dict(), os.path.join(CHECKPOINT_PATH,'checkpoint.pth'))


print('Final results: best model obtained at epoch {} - mean absolute error in weeks {}'.format(best_epoch,best_mae))


# END OF THE TUTORIAL