### Assignment Goals


1. Familiarity with the Vision Transformer architecture
2. Familiarity with the self-attention algorithm
3. Practice with PyTorch matrix operations



### Tasks
1. Implement multi-head self-attention
2. Incorporate that into a ViT

### Runtime Acceleration
Colab limits GPU usage, so set `device` below as `'cpu'` and change your runtime to CPU as well (Runtime > Change runtime type) when you're developing, and only change it to `'cuda'` (and your runtime to GPU) when you're ready to train.

In [1]:
#device = 'cpu'
device = 'cuda'

### Multi-head self-attention
Begin by implementing multiheaded self-attention. Do **not** use any `for` loops, and instead put all of the calculations into [batch matrix multiplications](https://pytorch.org/docs/stable/generated/torch.bmm.html) or [Linear layers](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html).

Useful references include the lecture slides on transformers and ViTs, and the [illustrated transformer](https://jalammar.github.io/illustrated-transformer/) blog post.



In [2]:
import torch.nn.functional as F
from torch import nn
import torch

class MSA(nn.Module):
  def __init__(self, input_dim, embed_dim, num_heads):
    '''
    input_dim: Dimension of input token embeddings
    embed_dim: Dimension of internal key, query, and value embeddings
    num_heads: Number of self-attention heads
    '''

    super().__init__()

    self.input_dim = input_dim
    self.embed_dim = embed_dim
    self.num_heads = num_heads

    self.K_embed = nn.Linear(input_dim, embed_dim, bias=False)
    self.Q_embed = nn.Linear(input_dim, embed_dim, bias=False)
    self.V_embed = nn.Linear(input_dim, embed_dim, bias=False)
    self.out_embed = nn.Linear(embed_dim, embed_dim, bias=False)

  def forward(self, x):
    '''
    x: input of shape (batch_size, max_length, input_dim)
    return: output of shape (batch_size, max_length, embed_dim)
    '''

    batch_size, max_length, given_input_dim = x.shape
    assert given_input_dim == self.input_dim
    assert max_length % self.num_heads == 0

    # You shouldn't need to initialize any new modules. Everything you need is
    # already in __init__

    # HINT: If you're stuck on how to handle multiple heads without for loops, try to
    # reshape matrix such that the batch_size is num_heads * batch_size
    # e.g. if you have two heads, you'd be doing self-attention twice per instance
    # in the batch, so you essentially have batch_size * 2

    # HINT 2: Feel free to reference: https://d2l.ai/chapter_attention-mechanisms-and-transformers/multihead-attention.html
    # although make sure you understand what each command does

    # this implementation projects KQV before splitting into multiple heads
    # but you can also split into multiple heads first

    # compute KQV as a whole, embedding and
    x = x.reshape(batch_size * max_length, -1)
    K = self.K_embed(x).reshape(batch_size, max_length, self.embed_dim) # (batch_size, max_length, embed_dim)
    # TODO: Compute Q
    Q = self.Q_embed(x).reshape(batch_size, max_length, self.embed_dim)
    # TODO: Compute V
    V = self.V_embed(x).reshape(batch_size, max_length, self.embed_dim)

    # TODO: split each KQV into heads, by reshaping each into (batch_size, max_length, self.num_heads, indiv_dim)
    indiv_dim = self.embed_dim // self.num_heads
    K = K.view(batch_size, self.num_heads, max_length, indiv_dim)
    Q = Q.view(batch_size, self.num_heads, max_length, indiv_dim)
    V = V.view(batch_size, self.num_heads, max_length, indiv_dim)

    #K = K.permute(0, 2, 1, 3) # (batch_size, num_heads, max_length, embed_dim / num_heads)
    #Q = Q.permute(0, 2, 1, 3)
    #V = V.permute(0, 2, 1, 3)

    K = K.reshape(batch_size * self.num_heads, max_length, indiv_dim)
    Q = Q.reshape(batch_size * self.num_heads, max_length, indiv_dim)
    V = V.reshape(batch_size * self.num_heads, max_length, indiv_dim)

    # transpose and batch matrix multiply
    # This is our K transposed so we can do a simple batched matrix multiplication (see torch.bmm for more details and the quick solution)
    K_T = K.permute(0, 2, 1)
    # TODO: Compute the weights before dividing by square root of d (batch_size * num_heads, max_length, max_length)
    QK = torch.bmm(Q, K_T)

    # calculate weights by dividing everything by the square root of d (self.embed_dim)
    weights = QK / (self.embed_dim ** 0.5)
    weights = F.softmax(weights, dim=-1)
    # TODO Take the softmax over the last dimension (see torch.functional.Softmax) (batch_size * num_heads, max_length, max_length)

    # TODO get weighted average... see torch.bmm for a one line solution
    # weights is (batch_size * num_heads, max_length, max_length) and V is (batch_size * self.num_heads, max_length, indiv_dim)
    # so we want the matrix multiplication of weights and V
    w_V = torch.bmm(weights, V)

    # rejoin heads
    w_V = w_V.reshape(batch_size, self.num_heads, max_length, indiv_dim)
    w_V = w_V.permute(0, 2, 1, 3) # (batch_size, max_length, num_heads, embed_dim / num_heads)
    w_V = w_V.reshape(batch_size, max_length, self.embed_dim)

    out = self.out_embed(w_V)

    return out

### Implement the ViT architecture
You will be implementing the ViT architecture based on the "An image is worth 16x16 words" paper.

Although the ViT and Transformer architecture are very similar, note a few differences:

1. Image patches instead of discrete tokens as input.
2. [GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used for the linear layers in the transformer layer (instead of ReLU)
3. LayerNorm before the sublayer instead of after.
4. Dropout after every linear layer except for KQV projections and also directly after adding positional embeddings to the patch embeddings.
5. Learnable [CLS] token at the beginning of the input.

A useful reference is Figure 1 in the [paper](https://arxiv.org/pdf/2010.11929.pdf).

First, implement a single layer:

In [3]:
class ViTLayer(nn.Module):
  def __init__(self, num_heads, input_dim, embed_dim, mlp_hidden_dim, dropout=0.1):
    '''
    num_heads: Number of heads for multi-head self-attention
    embed_dim: Dimension of internal key, query, and value embeddings
    mlp_hidden_dim: Hidden dimension of the linear layer
    dropout: Dropout rate
    '''

    super().__init__()

    self.input_dim = input_dim
    self.msa = MSA(input_dim, embed_dim, num_heads)

    self.layernorm1 = nn.LayerNorm(embed_dim)
    self.w_o_dropout = nn.Dropout(dropout)
    self.layernorm2 = nn.LayerNorm(embed_dim)
    self.mlp = nn.Sequential(nn.Linear(embed_dim, mlp_hidden_dim),
                              nn.GELU(),
                              nn.Dropout(dropout),
                              nn.Linear(mlp_hidden_dim, embed_dim),
                              nn.Dropout(dropout))

  def forward(self, x):
    '''
    x: input embeddings (batch_size, max_length, input_dim)
    return: output embeddings (batch_size, max_length, embed_dim)
    '''

    # TODO: Fill in the code for the forward pass below
    # You shouldn't need to initialize any more modules, everything you need is already
    # in __init__
    # A forward function consists of:
    # 1) LayerNorm of x
    # 2) Self-Attention on output of 1)
    # 3) Dropout
    # 4) Residual w/ original x
    # 5) LayerNorm
    # 6) MLP
    # 7) Residual
    norm_x = self.layernorm1(x)
    msa_out = self.msa(norm_x)
    msa_out = self.w_o_dropout(msa_out)
    msa_out = x + msa_out
    norm_x = self.layernorm2(msa_out)
    mlp_out = self.mlp(norm_x)
    out = msa_out + mlp_out

    return out


A portion of the full network is already implemented for you. Your task is to implement the preprocessing code, converting raw images into patch embeddings + positional embeddings + dropout, with a learnable CLS token at the beginning of the input.

Note that patch embeddings are to be added to positional embeddings elementwise, so the input embedding dimensions is size embed_dim.

In [4]:
class ViT(nn.Module):
    def __init__(self, patch_dim, image_dim, num_layers, num_heads, embed_dim, mlp_hidden_dim, num_classes, dropout):
        '''
        patch_dim: patch length and width to split image by
        image_dim: image length and width
        num_layers: number of layers in network
        num_heads: number of heads for multi-head attention
        embed_dim: dimension to project images patches to and dimension to use for position embeddings
        mlp_hidden_dim: hidden dimension of linear layer
        num_classes: number of classes to classify in data
        dropout: dropout rate
        '''

        super().__init__()
        self.num_layers = num_layers
        self.patch_dim = patch_dim
        self.image_dim = image_dim
        self.input_dim = self.patch_dim * self.patch_dim * 3
        self.num_heads = num_heads

        self.patch_embedding = nn.Linear(self.input_dim, embed_dim)
        self.position_embedding = nn.Parameter(torch.zeros(1, (image_dim // patch_dim) ** 2 + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.embedding_dropout = nn.Dropout(dropout)

        self.encoder_layers = nn.ModuleList([])
        for i in range(num_layers):
            self.encoder_layers.append(ViTLayer(num_heads, embed_dim, embed_dim, mlp_hidden_dim, dropout))

        self.mlp_head = nn.Linear(embed_dim, num_classes)
        self.layernorm = nn.LayerNorm(embed_dim)

    def forward(self, images):
        '''
        images: raw image data (batch_size, channels, rows, cols)
        '''

        # Don't hardcode dimensions (except for maybe channels = 3), use the variables in __init__.
        # You shouldn't need to add anything else to __init__, all of the embeddings,
        # dropout etc. are already initialized for you.

        # Put the preprocessed patches in variable "out" with shape (batch_size, length, embed_dim).

        # HINT: You can make image patches with .reshape
        # e.g.
        # x = torch.ones((100, 100))
        # x_patches = x.reshape(4, 25, 4, 25)
        # where you have 4 * 4 patches with each patch being 25 by 25

        h = w = self.image_dim // self.patch_dim
        N = images.size(0)
        images = images.reshape(N, 3, h, self.patch_dim, w, self.patch_dim)
        images = torch.einsum("nchpwq -> nhwpqc", images)
        patches = images.reshape(N, h * w, self.input_dim) # (batch, num_patches_per_image, patch_size_unrolled)

        patch_embeddings = self.patch_embedding(patches)
        # TODO: Pass through our patch embedding layer
        patch_embeddings = torch.cat([torch.tile(self.cls_token, (N, 1, 1)), patch_embeddings], dim=1)
        out = patch_embeddings + torch.tile(self.position_embedding, (N, 1, 1)) # We add positional embeddings to our tokens (not concatenated)
        out = self.embedding_dropout(out)
        # TODO: Pass through our embedding dropout layer

        # add padding s.t. input length is multiple of num_heads
        add_len = (self.num_heads - out.shape[1]) % self.num_heads
        out = torch.cat([out, torch.zeros(N, add_len, out.shape[2], device=device)], dim=1)

        # TODO: Pass through each one of our encoder layers
        for layer in self.encoder_layers:
            out = layer(out)
        # Pop off and read our classification token we added, see what the value is
        cls_head = self.layernorm(torch.squeeze(out[:, 0], dim=1))
        logits = self.mlp_head(cls_head)
        return logits

def get_vit_tiny(num_classes=10, patch_dim=4, image_dim=32):
    return ViT(patch_dim=patch_dim, image_dim=image_dim, num_layers=12, num_heads=3,
               embed_dim=192, mlp_hidden_dim=768, num_classes=num_classes, dropout=0.1)

def get_vit_small(num_classes=10, patch_dim=4, image_dim=32):
    return ViT(patch_dim=patch_dim, image_dim=image_dim, num_layers=12, num_heads=6,
               embed_dim=384, mlp_hidden_dim=1536, num_classes=num_classes, dropout=0.1)

def get_vit_base(num_classes=10, patch_dim=4, image_dim=32):
    return ViT(patch_dim=patch_dim, image_dim=image_dim, num_layers=12, num_heads=12,
               embed_dim=768, mlp_hidden_dim=3072, num_classes=num_classes, dropout=0.1)

Now let's train the model! You don't need to write any code for this - just run the cell.

Remember to change the device variable (in the cell at the beginning of the notebook) to 'cuda' and change your runtime to GPU (Runtime > Change runtime type) as well. For reference, each epoch in the staff solution takes ~3 minutes (so training for 30 epochs will take ~1.5 hours on the Colab GPU; we know this is a long training session)

Try to get 65%+ accuracy after 30 epochs.

In [6]:
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torchvision.datasets as datasets
import torchvision
import math
import torch.optim as optim
from tqdm.notebook import tqdm

data_root = './data/cifar10'
train_size = 400
val_size = 100

batch_size = 32

transform_train = T.Compose([
    T.Resize(40),
    T.RandomCrop(32),
    T.RandomHorizontalFlip(),
    T.RandomAffine(degrees=0, translate=(0.2, 0.2), scale=(0.95, 1.05)),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

transform_val = T.Compose([
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

train_dataset = datasets.CIFAR10(
    root=data_root,
    train=True,
    download=True,
    transform=transform_train,
)

val_dataset = datasets.CIFAR10(
    root=data_root,
    train=True,
    download=True,
    transform=transform_val,
)

from torch.utils.data import sampler

train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          sampler=sampler.SubsetRandomSampler(range(train_size)))

val_loader = DataLoader(val_dataset, batch_size=batch_size,
                        sampler=sampler.SubsetRandomSampler(range(train_size, 50000)))

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

vit = get_vit_small().to(device)

learning_rate = 5e-4 * batch_size / 256
num_epochs = 30
weight_decay = 0.1

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(vit.parameters(), lr=learning_rate, betas=(0.9, 0.95), weight_decay=weight_decay)

train_losses = []
val_losses = []
val_accuracies = []
for epoch in range(num_epochs):
    train_loss = 0.0
    train_acc = 0.0
    train_total = 0
    vit.train()
    for inputs, labels in tqdm(train_loader):
        """TODO:
        1. Set inputs and labels to be on device
        2. zero out our gradients
        3. pass our inputs through the ViT
        4. pass our outputs / labels into our loss / criterion
        5. backpropagate
        6. step our optimizeer
        """
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = vit(inputs)
        loss = criterion(outputs, labels.long())
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * inputs.shape[0]
        train_acc += torch.sum((torch.argmax(outputs, dim=1) == labels)).item()
        train_total += inputs.shape[0]
    train_loss = train_loss / train_total
    train_acc = train_acc / train_total
    train_losses.append(train_loss)

    val_loss = 0.0
    val_acc = 0.0
    val_total = 0
    vit.eval()
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = vit(inputs)
            loss = criterion(outputs, labels.long())

            val_loss += loss.item() * inputs.shape[0]
            val_acc += torch.sum((torch.argmax(outputs, dim=1) == labels)).item()
            val_total += inputs.shape[0]
    val_loss = val_loss / val_total
    val_acc = val_acc / val_total
    val_losses.append(val_loss)

    val_accuracies.append(val_acc)
    if val_acc >= max(val_accuracies):
        torch.save(vit.state_dict(), 'best_model.pth')

    print(f'[{epoch + 1:2d}] train loss: {train_loss:.3f} | train accuracy: {train_acc:.3f} | val loss: {val_loss:.3f} | val accuracy: {val_acc:.3f}')

print('Finished Training')

Files already downloaded and verified
Files already downloaded and verified


  0%|          | 0/13 [00:00<?, ?it/s]

[ 1] train loss: 2.362 | train accuracy: 0.160 | val loss: 2.287 | val accuracy: 0.226


  0%|          | 0/13 [00:00<?, ?it/s]

[ 2] train loss: 2.137 | train accuracy: 0.215 | val loss: 2.183 | val accuracy: 0.227


  0%|          | 0/13 [00:00<?, ?it/s]

[ 3] train loss: 2.035 | train accuracy: 0.268 | val loss: 2.144 | val accuracy: 0.247


  0%|          | 0/13 [00:00<?, ?it/s]

[ 4] train loss: 1.965 | train accuracy: 0.282 | val loss: 2.111 | val accuracy: 0.269


  0%|          | 0/13 [00:00<?, ?it/s]

[ 5] train loss: 1.929 | train accuracy: 0.282 | val loss: 2.107 | val accuracy: 0.244


  0%|          | 0/13 [00:00<?, ?it/s]

[ 6] train loss: 1.870 | train accuracy: 0.302 | val loss: 2.089 | val accuracy: 0.249


  0%|          | 0/13 [00:00<?, ?it/s]

[ 7] train loss: 1.917 | train accuracy: 0.300 | val loss: 2.096 | val accuracy: 0.275


  0%|          | 0/13 [00:00<?, ?it/s]

[ 8] train loss: 1.895 | train accuracy: 0.275 | val loss: 2.106 | val accuracy: 0.272


  0%|          | 0/13 [00:00<?, ?it/s]

[ 9] train loss: 1.779 | train accuracy: 0.372 | val loss: 2.126 | val accuracy: 0.264


  0%|          | 0/13 [00:00<?, ?it/s]

[10] train loss: 1.759 | train accuracy: 0.355 | val loss: 2.088 | val accuracy: 0.268


  0%|          | 0/13 [00:00<?, ?it/s]

[11] train loss: 1.769 | train accuracy: 0.338 | val loss: 2.134 | val accuracy: 0.265


  0%|          | 0/13 [00:00<?, ?it/s]

[12] train loss: 1.766 | train accuracy: 0.357 | val loss: 2.132 | val accuracy: 0.267


  0%|          | 0/13 [00:00<?, ?it/s]

[13] train loss: 1.753 | train accuracy: 0.343 | val loss: 2.164 | val accuracy: 0.263


  0%|          | 0/13 [00:00<?, ?it/s]

[14] train loss: 1.713 | train accuracy: 0.378 | val loss: 2.104 | val accuracy: 0.271


  0%|          | 0/13 [00:00<?, ?it/s]

[15] train loss: 1.695 | train accuracy: 0.385 | val loss: 2.102 | val accuracy: 0.287


  0%|          | 0/13 [00:00<?, ?it/s]

[16] train loss: 1.712 | train accuracy: 0.370 | val loss: 2.164 | val accuracy: 0.269


  0%|          | 0/13 [00:00<?, ?it/s]

[17] train loss: 1.695 | train accuracy: 0.375 | val loss: 2.107 | val accuracy: 0.275


  0%|          | 0/13 [00:00<?, ?it/s]

[18] train loss: 1.679 | train accuracy: 0.360 | val loss: 2.117 | val accuracy: 0.285


  0%|          | 0/13 [00:00<?, ?it/s]

[19] train loss: 1.645 | train accuracy: 0.393 | val loss: 2.132 | val accuracy: 0.280


  0%|          | 0/13 [00:00<?, ?it/s]

[20] train loss: 1.609 | train accuracy: 0.422 | val loss: 2.137 | val accuracy: 0.281


  0%|          | 0/13 [00:00<?, ?it/s]

[21] train loss: 1.584 | train accuracy: 0.427 | val loss: 2.133 | val accuracy: 0.281


  0%|          | 0/13 [00:00<?, ?it/s]

[22] train loss: 1.578 | train accuracy: 0.432 | val loss: 2.177 | val accuracy: 0.275


  0%|          | 0/13 [00:00<?, ?it/s]

[23] train loss: 1.544 | train accuracy: 0.412 | val loss: 2.171 | val accuracy: 0.281


  0%|          | 0/13 [00:00<?, ?it/s]

[24] train loss: 1.547 | train accuracy: 0.435 | val loss: 2.114 | val accuracy: 0.297


  0%|          | 0/13 [00:00<?, ?it/s]

[25] train loss: 1.539 | train accuracy: 0.460 | val loss: 2.157 | val accuracy: 0.297


  0%|          | 0/13 [00:00<?, ?it/s]

[26] train loss: 1.580 | train accuracy: 0.448 | val loss: 2.139 | val accuracy: 0.294


  0%|          | 0/13 [00:00<?, ?it/s]

[27] train loss: 1.510 | train accuracy: 0.410 | val loss: 2.115 | val accuracy: 0.288


  0%|          | 0/13 [00:00<?, ?it/s]

[28] train loss: 1.509 | train accuracy: 0.443 | val loss: 2.222 | val accuracy: 0.289


  0%|          | 0/13 [00:00<?, ?it/s]

[29] train loss: 1.468 | train accuracy: 0.487 | val loss: 2.173 | val accuracy: 0.293


  0%|          | 0/13 [00:00<?, ?it/s]

[30] train loss: 1.460 | train accuracy: 0.487 | val loss: 2.235 | val accuracy: 0.293
Finished Training
