In [1]:
import sys
sys.path.insert(0, '../')  # Add the parent directory to sys.path
import os
import requests
from pathlib import Path
from PIL import Image
from io import BytesIO
import torch
from torch import nn, optim
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, SequentialLR
import torchvision
import torchvision.transforms as transforms
from train import create_dataloaders, train

  from .autonotebook import tqdm as notebook_tqdm


## Step 0: Download image 

In [2]:
# URL of a random image
url = 'https://lmb.informatik.uni-freiburg.de/people/dosovits/Dosovitskiy_photo.JPG'

# Download the image
response = requests.get(url)
img = Image.open(BytesIO(response.content))

# Resize and convert to tensor
img_transform = transforms.Compose([
    transforms.Resize(size=(224,224)),
    transforms.ToTensor()
])
img_tensor = img_transform(img)
img_tensor.shape

torch.Size([3, 224, 224])

# Obtain a 1D sequence from a 2D image for input to the ViT encoder.

## Step 1: Split image tensor into patches using CNN 

In [3]:
patch_size = 16

# Hidden size D (dimensions)
vector_size = 768

patcher = nn.Conv2d(in_channels=3, out_channels=vector_size,
                    kernel_size=patch_size, stride=patch_size)
patched_img = patcher(img_tensor)
patched_img.shape

torch.Size([768, 14, 14])

## Step 2: Flatten 2d patches and permute dimensions

In [4]:
flatter = nn.Flatten(start_dim=1)

# Changing dimensions due to input requirements of MSA (multi-head self attention)
flattened_patches = flatter(patched_img).permute(1,0)
flattened_patches.shape

torch.Size([196, 768])

## Step 3: Prepend extra learnable [class] embedding

In [5]:
class_token = nn.Parameter(torch.randn(1, 768))
flattened_patches = torch.cat((class_token, flattened_patches))
flattened_patches.shape

torch.Size([197, 768])

## Step 4: Add position embeddings

In [6]:
pos_embs = nn.Parameter(torch.randn(flattened_patches.shape))
encoder_input =  pos_embs + flattened_patches

dropout_rate = 0.1
embedding_dropout = nn.Dropout(dropout_rate)
encoder_input = embedding_dropout(encoder_input)

print(f'First 3 dimensions of first patch of flattened_patches, pos_embs and their sum\
\n\n {flattened_patches[0][:3]}\n {pos_embs[0][:3]}\n {encoder_input[0][:3]}')

First 3 dimensions of first patch of flattened_patches, pos_embs and their sum

 tensor([-0.2942, -0.1906, -0.5862], grad_fn=<SliceBackward0>)
 tensor([-0.6299,  0.4337,  0.7066], grad_fn=<SliceBackward0>)
 tensor([-1.0267,  0.2700,  0.1338], grad_fn=<SliceBackward0>)


# Create ViT encoder 

## Step 5: Layer Normalization, Multi-head Self-Attention and Residual connection

In [7]:
LN1 = nn.LayerNorm(vector_size)
normalized_vecs = LN1(encoder_input).unsqueeze(0)
print(f'Normalized embeddings:\n    {normalized_vecs[:,0,:3]}')

MSA = nn.MultiheadAttention(batch_first=True, num_heads=12, embed_dim=vector_size)
contextualized_embs, _ = MSA(query=normalized_vecs, key=normalized_vecs, 
                            value=normalized_vecs, need_weights=False)
print(f'Contextualized embeddings:\n    {contextualized_embs[:,0,:3]}')

intermediate_embeddings = encoder_input + contextualized_embs
print(f'Embeddings after first skip or residual connection:\n   {intermediate_embeddings[:,0,:3]}')

Normalized embeddings:
    tensor([[-0.7083,  0.1522,  0.0618]], grad_fn=<SliceBackward0>)
Contextualized embeddings:
    tensor([[ 0.0124, -0.0056,  0.1641]], grad_fn=<SliceBackward0>)
Embeddings after first skip or residual connection:
   tensor([[-1.0144,  0.2644,  0.2979]], grad_fn=<SliceBackward0>)


## Step 6: Layer Normalization, Multi Layer Perceptron and Residual connection

In [8]:
LN2 = nn.LayerNorm(vector_size)
normalized_vecs = LN2(intermediate_embeddings)

class MLP_block(nn.Module):
    def __init__(self, MLP_size, dropout_rate):
        super().__init__()
        self.fc1 = nn.Linear(in_features=vector_size, out_features=MLP_size)
        self.gelu = nn.GELU('tanh')
        self.fc2 = nn.Linear(in_features=MLP_size, out_features=vector_size)

        # Applied after each dense layer
        self.dropout = nn.Dropout(p=dropout_rate)
    
    def forward(self, x):
        x = self.dropout(self.fc1(x))
        x = self.gelu(x)
        x = self.dropout(self.fc2(x))
        return x

MLP = MLP_block(MLP_size=3072, dropout_rate=dropout_rate)
mlp_embeddings = MLP(normalized_vecs)
encoder_out_embeddings = intermediate_embeddings + mlp_embeddings
encoder_out_embeddings.shape

torch.Size([1, 197, 768])

# Create classification head

## Step 7: Adjustable classifier

In [9]:
classes = ['sausage', 'not a sausage']

LNf = nn.LayerNorm(vector_size)
normalized_embs = LNf(encoder_out_embeddings)

# Single linear layer used during fine-tuning (while pre-training was used MLP with one hidden layer)
classifier = nn.Linear(in_features=vector_size, out_features=len(classes))

# Put [class] token in classifier head
logits = classifier(normalized_embs[:,0])
logits

tensor([[-0.4039,  0.1328]], grad_fn=<AddmmBackward0>)

# Create ViT class (for fine-tuning) using steps

In [10]:
class ViTEncoder(nn.Module):
    def __init__(self, MLP_size, num_heads, vector_size, dropout_rate):
        super().__init__()
        self.LN1 = nn.LayerNorm(vector_size)
        self.MSA = nn.MultiheadAttention(batch_first=True, num_heads=num_heads, embed_dim=vector_size)
        self.LN2 = nn.LayerNorm(vector_size)
        self.mlp = MLP_block(MLP_size, dropout_rate=dropout_rate)
        

    def forward(self, x):
        normalized_vecs = self.LN1(x)
        contextualized_embs, _ = self.MSA(query=normalized_vecs, key=normalized_vecs,
                                            value=normalized_vecs, need_weights=False)
        intermediate_embeddings = x + contextualized_embs
        normalized_vecs = self.LN2(intermediate_embeddings)
        x = self.mlp(normalized_vecs)
        return x


class ViTbasehybrid(nn.Module):
    def __init__(self, num_labels:int, img_size:tuple[int, int], patch_size:int=16, 
                    dropout_rate:float=0.1, vector_size:int=768, num_heads:int=12, 
                    num_transformer_layers:int=12, MLP_size:int=3072):
        super().__init__()

        # Obtain 1d seq from 2d images
        self.patcher = nn.Conv2d(in_channels=3, out_channels=vector_size,
                    kernel_size=patch_size, stride=patch_size)
        self.flatter = nn.Flatten(start_dim=2) #start_dim=2 because we added batch_size dim
        self.class_token = nn.Parameter(torch.randn(1, 1, vector_size), requires_grad=True)
        assert (img_size[0]*img_size[1]) % patch_size == 0, 'Img size must be divisible by patch_size!'
        number_of_patches = int((img_size[0]*img_size[1]) / patch_size**2)
        self.pos_embs = nn.Parameter(torch.randn(1, number_of_patches+1, vector_size), requires_grad=True)
        self.dropout = nn.Dropout(dropout_rate)
        
        # Create encoder blocks
        self.transformer_encoder = nn.Sequential(*[ViTEncoder(MLP_size,num_heads,
                                                              vector_size,dropout_rate) 
                                                    for layer in range(num_transformer_layers)])

        # Classifier
        self.classifier = nn.Sequential(nn.LayerNorm(vector_size),
                         nn.Linear(in_features=vector_size, out_features=num_labels))


    def forward(self, x):
        batch_size = x.shape[0]

        x = self.patcher(x)
        x = self.flatter(x).permute(0,2,1)

        # Making class token suitable for current batch_size
        cls_token = self.class_token.expand(batch_size, -1, -1)
        
        x = torch.cat((cls_token, x), dim=1)
        x = self.pos_embs + x
        x = self.dropout(x)
        x = self.transformer_encoder(x)
        x = self.classifier(x[:, 0])
        return x 

In [11]:
from torchinfo import summary

# Setting 1000 images as in ImageNet
custom_vit = ViTbasehybrid(num_labels=1000, img_size=tuple(img_tensor.shape[1:]))

summary(model=custom_vit,
        input_size=(32, 3, 224, 224), # (batch_size, color_channels, height, width)
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
ViTbasehybrid (ViTbasehybrid)            [32, 3, 224, 224]    [32, 1000]           152,064              True
├─Conv2d (patcher)                       [32, 3, 224, 224]    [32, 768, 14, 14]    590,592              True
├─Flatten (flatter)                      [32, 768, 14, 14]    [32, 768, 196]       --                   --
├─Dropout (dropout)                      [32, 197, 768]       [32, 197, 768]       --                   --
├─Sequential (transformer_encoder)       [32, 197, 768]       [32, 197, 768]       --                   True
│    └─ViTEncoder (0)                    [32, 197, 768]       [32, 197, 768]       --                   True
│    │    └─LayerNorm (LN1)              [32, 197, 768]       [32, 197, 768]       1,536                True
│    │    └─MultiheadAttention (MSA)     --                   [32, 197, 768]       2,362,368            True
│    │    └─LayerN

## As it can be seen in output the "Total Parameters" amount equal to 86,567,656. 
## Same as in torchvision.models.vit_b_16()
## The model is successfuly replicated!

In [12]:
summary(model=torchvision.models.vit_b_16(),
        input_size=(32, 3, 224, 224), # (batch_size, color_channels, height, width)
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
VisionTransformer (VisionTransformer)                        [32, 3, 224, 224]    [32, 1000]           768                  True
├─Conv2d (conv_proj)                                         [32, 3, 224, 224]    [32, 768, 14, 14]    590,592              True
├─Encoder (encoder)                                          [32, 197, 768]       [32, 197, 768]       151,296              True
│    └─Dropout (dropout)                                     [32, 197, 768]       [32, 197, 768]       --                   --
│    └─Sequential (layers)                                   [32, 197, 768]       [32, 197, 768]       --                   True
│    │    └─EncoderBlock (encoder_layer_0)                   [32, 197, 768]       [32, 197, 768]       7,087,872            True
│    │    └─EncoderBlock (encoder_layer_1)                   [32, 197, 768]       [32, 197, 76

# Check the model

In [13]:
# Prepare dataloaders
data_dir = '../' / Path('CV_test_data/')
train_threshold = 0.8
batch_size = 32
transform = transforms.Compose([
    transforms.Resize(size=(384,384)),
    transforms.ToTensor()
])

train_dataloader, test_dataloader = create_dataloaders(data_dir, train_threshold, batch_size, transform)

# Fine-tune the model (using parameters for fine-tuning mentioned in Research Paper)
## Except for batch_size and used Polyak & Juditsky (1992) averaging
num_labels = len([folder for folder in os.listdir(data_dir) 
                                if (data_dir / folder).is_dir()])
test_vit = ViTbasehybrid(num_labels=num_labels, img_size=tuple(img_tensor.shape[1:]))
loss_fn = nn.CrossEntropyLoss()

initial_lr = 0.001
optimizer = optim.SGD(params=custom_vit.classifier.parameters(), lr=initial_lr,
                        weight_decay=0, momentum=0.9)
def lr_lambda(step):
    if step < 125:
        return 0.001 / initial_lr
    elif step < 250:
        return 0.003 / initial_lr
    elif step < 375:
        return 0.01 / initial_lr
    else:
        return 0.03 / initial_lr

warmup_scheduler = LambdaLR(optimizer, lr_lambda)
decay_scheduler = CosineAnnealingLR(optimizer, T_max=100)
# Combine schedulers using SequentialLR
scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[500])


device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 32
dropout_rate = 0.1
epochs = 3
max_norm = 1.0

train(test_vit,train_dataloader,test_dataloader,
        loss_fn,optimizer,device,batch_size,dropout_rate,epochs,
        max_norm=max_norm,scheduler=scheduler)

  0%|          | 0/3 [00:00<?, ?it/s]


RuntimeError: The size of tensor a (197) must match the size of tensor b (577) at non-singleton dimension 1