In [1]:
import os
import random
import sys

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, progress
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader

In [2]:
### Set seed for reproducibility
np.random.seed(123)
random.seed(123)
torch.manual_seed(123)

<torch._C.Generator at 0x10c3e8370>

# Transformers

"Attention is All You Need"

<img src="imgs/transformer.png" style="height:600px" class="center" alt="transformer"/><br>

Transformers involve no recurrence. Instead, they use multi-head self-attention, a significant improvement on the basic attention mechanism we saw with RNNs.

Generally, the attention mechanism focuses on learning 3 components: the "Query Vector" (Q; $d_{k}$), "Key Vector" (K; $d_{k}$), and the "Value Vector" (V; $d_{v}$). These represent the word that is being considered, the other words that will be "attended to," and the information learned about all of the words, respectively. In the context of encoder-decoder attention, the queries come from the decoder, while the keys and values come from the encoder.

Adding multiple attention mechanisms, i.e., multiple learnable query, key, and value matrices, can improve performance in the same way that collaborating with a team across fields can be useful. This is "multi-head attention." While this embeds each element into a smaller dimensional space (by a factor of the number of heads) than single-head attention, the use of multiple spaces allows the model to capture diversity more effectively.

<img src="imgs/multihead_attention.png" style="height:500px" class="center" alt="mh_attention"/><br>
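To make the head-splitting concrete, here is a minimal NumPy sketch of multi-head self-attention. It is an illustration under stated assumptions, not the reference implementation: the function name `multi_head_self_attention`, the random weight matrices, and the sizes (`n=5`, `d_model=16`, `n_heads=4`) are all made up for the example, and each head applies scaled dot-product attention in a subspace of size `d_model // n_heads`.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max along the axis for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(X, W_q, W_k, W_v, W_o, n_heads):
    """X: (n, d_model); each W_*: (d_model, d_model). Returns output and weights."""
    n, d_model = X.shape
    d_head = d_model // n_heads  # each head works in a smaller subspace

    def split_heads(M):
        # (n, d_model) -> (n_heads, n, d_head)
        return M.reshape(n, n_heads, d_head).transpose(1, 0, 2)

    Q, K, V = split_heads(X @ W_q), split_heads(X @ W_k), split_heads(X @ W_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)   # (n_heads, n, n)
    weights = softmax(scores, axis=-1)                     # one attention map per head
    heads = weights @ V                                    # (n_heads, n, d_head)
    concat = heads.transpose(1, 0, 2).reshape(n, d_model)  # concatenate the heads
    return concat @ W_o, weights

# Hypothetical sizes and random (untrained) weights, for illustration only.
rng = np.random.default_rng(123)
n, d_model, n_heads = 5, 16, 4
X = rng.normal(size=(n, d_model))
W_q, W_k, W_v, W_o = (
    rng.normal(size=(d_model, d_model)) / np.sqrt(d_model) for _ in range(4)
)
out, weights = multi_head_self_attention(X, W_q, W_k, W_v, W_o, n_heads)
print(out.shape, weights.shape)  # (5, 16) (4, 5, 5)
```

Note that the output keeps the full `d_model` width even though each head only sees `d_head = 4` dimensions; the per-head attention maps in `weights` are what one would typically visualize.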

By using the context vector along with the hidden state, we add the global context of the current state with respect to the entire past, thereby allowing information to flow across time or be ignored.

Another great aspect of attention is that it is easily visualized, so you can see where the network is focusing.

Simply put, self-attention looks at how the *inputs* interact with one another rather than how the hidden states evolve over time. This gives the model much more information about the context of an input vector (e.g., a word) itself rather than the hidden state it results in, giving it the ability to dynamically focus on different parts of the input sequence. This reduces the "path length" between any two words to $O(1)$, allowing information to flow across long ranges in the sequence. It also means that all positions can be processed in parallel: a layer needs only $O(1)$ sequential operations, compared with $O(n)$ for an RNN over a sequence of length $n$, which provides a huge speedup. The trade-off is that the attention matrix itself costs $O(n^{2})$ in time and memory.

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left( \frac{Q K^{T}}{\sqrt{d_{k}}}\right)V$$
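The formula translates almost line for line into code. The following is a minimal NumPy sketch, assuming a single unbatched sequence and no masking; the function names and the random inputs are illustrative choices, not part of any library.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row-wise max for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    """Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v) -> output (n_q, d_v), weights (n_q, n_k)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)     # similarity of every query with every key
    weights = softmax(scores, axis=-1)  # each row is a distribution over the keys
    return weights @ V, weights

# Hypothetical example: 4 queries attending over 6 key/value pairs.
rng = np.random.default_rng(123)
Q = rng.normal(size=(4, 8))
K = rng.normal(size=(6, 8))
V = rng.normal(size=(6, 16))
out, weights = scaled_dot_product_attention(Q, K, V)
print(out.shape, weights.shape)  # (4, 16) (4, 6)
```

Each row of `weights` sums to 1, so every output vector is a weighted average of the value vectors. The division by $\sqrt{d_{k}}$ keeps the dot products from growing with the dimension and saturating the softmax.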

However, this all comes with the caveat that now we lose a lot of the positional information that RNNs easily capture. This leads to the idea of "positional encoding" wherein explicit information about the position is added to each embedded word. In the original transformer:

$$PE(pos, 2i) = \sin\left(pos/10000^{2i/d_{\mathrm{model}}}\right)$$
$$PE(pos, 2i+1) = \cos\left(pos/10000^{2i/d_{\mathrm{model}}}\right)$$

where $pos$ is the position in the sequence, $i$ indexes the sine/cosine pairs of embedding dimensions, and $d_{\mathrm{model}}$ is the embedding dimension of the model. This can be made more complex and/or be learned.
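As a quick check of the two formulas, here is a small NumPy sketch that builds the full encoding table. The function name `sinusoidal_positional_encoding` and the sizes used are assumptions for illustration, and the sketch assumes an even embedding dimension.

```python
import numpy as np

def sinusoidal_positional_encoding(max_len, d_model):
    """Return a (max_len, d_model) array of encodings; d_model is assumed even."""
    pos = np.arange(max_len)[:, None]           # (max_len, 1)
    two_i = np.arange(0, d_model, 2)[None, :]   # (1, d_model / 2), the values 2i
    angles = pos / np.power(10000.0, two_i / d_model)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)  # even dimensions use sine
    pe[:, 1::2] = np.cos(angles)  # odd dimensions use cosine
    return pe

pe = sinusoidal_positional_encoding(max_len=50, d_model=16)
print(pe.shape)  # (50, 16)
```

The resulting table is simply added to the embedded inputs, so row `pos` of `pe` is added to the embedding of the token at that position.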

Together, these improvements have made transformers the go-to architecture for natural language processing.

# Vision Transformers (ViT)

ViT is based on the realization that regions of an image ("patches") can replace the typical token embeddings in a transformer. This makes the image a sequence of patches ("An Image is Worth 16x16 Words"). Each patch is flattened and mapped into the model's embedding space by a learnable (typically linear) projection. The network then proceeds like a normal transformer.

<img src="imgs/vit.png" style="height:500px" class="center" alt="vit"/><br>
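The patch-to-sequence step can be sketched in a few lines of NumPy. This is an illustration under assumptions: the helper `patchify`, the channels-last layout, the 28x28 single-channel inputs (the size of Fashion MNIST images), the patch size of 7, and the random matrix `W_embed` standing in for a learnable projection are all choices made for the example.

```python
import numpy as np

def patchify(images, patch_size):
    """images: (B, H, W, C) -> (B, num_patches, patch_size * patch_size * C)."""
    B, H, W, C = images.shape
    assert H % patch_size == 0 and W % patch_size == 0, "image must tile evenly"
    gh, gw = H // patch_size, W // patch_size  # patch grid height and width
    x = images.reshape(B, gh, patch_size, gw, patch_size, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)          # (B, gh, gw, p, p, C)
    return x.reshape(B, gh * gw, patch_size * patch_size * C)

rng = np.random.default_rng(123)
images = rng.normal(size=(2, 28, 28, 1))    # stand-in for a batch of grayscale images
patches = patchify(images, patch_size=7)    # 4 x 4 = 16 patches of 49 pixels each
W_embed = rng.normal(size=(49, 64)) * 0.02  # stand-in for a learnable linear projection
tokens = patches @ W_embed                  # the sequence fed to the transformer
print(patches.shape, tokens.shape)  # (2, 16, 49) (2, 16, 64)
```

From here on, `tokens` plays the same role as a sequence of word embeddings; a positional encoding is added before the sequence enters the encoder.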

Similar to attention in traditional sequences, the attention can be visualized across the image.

<img src="imgs/vit_attention.png" style="height:600px" class="center" alt="vit_attention"/><br>

Positional encoding is still necessary, but its exact form matters relatively little. Often, the positional encoding is simply learned.

<img src="imgs/vit_pos_enc.png" style="height:600px" class="center" alt="vit_pos_enc"/><br>




This removes convolutions from the process entirely. It is *not a convolutional neural network*. It is an entirely different type of architecture. When trained on a sufficient amount of data, it can outperform CNNs in many contexts. More recently, there have been efforts to more effectively join the two paradigms.



__Homework__: apply ViT to Fashion MNIST. Compare with CNN.