In [1]:
import requests
from bs4 import BeautifulSoup
from PIL import Image
from io import BytesIO
import torch
from torch import nn
import torchvision.transforms as transforms

## Step 0: Download image 

In [4]:
import requests
from PIL import Image
from io import BytesIO
import torchvision.transforms as transforms

# URL of a random image
url = 'https://lmb.informatik.uni-freiburg.de/people/dosovits/Dosovitskiy_photo.JPG'

# Download the image
response = requests.get(url)
img = Image.open(BytesIO(response.content))

# Resize and convert to tensor
img_transform = transforms.Compose([
    transforms.Resize(size=(224,224)),
    transforms.ToTensor()
])
img_tensor = img_transform(img)
img_tensor.shape

torch.Size([3, 224, 224])

# Obtain a 1D sequence from a 2D image for input to the ViT encoder.

## Step 1: Split image tensor into patches using CNN 

In [21]:
patch_size = 16

# Hidden size D (dimensions)
vector_size = 768

patcher = nn.Conv2d(in_channels=3, out_channels=vector_size,
                    kernel_size=patch_size, stride=patch_size)
patched_img = patcher(img_tensor)
patched_img.shape

torch.Size([768, 14, 14])

## Step 2: Flatten 2d patches and permute dimensions

In [124]:
flatter = nn.Flatten(start_dim=1)

# Changing dimensions due to input requirements of MSA (multi-head self attention)
flattened_patches = flatter(patched_img).permute(1,0)
flattened_patches.shape

torch.Size([196, 768])

## Step 3: Prepend extra learnable [class] embedding

In [125]:
class_token = nn.Parameter(torch.randn(1, 768))
flattened_patches = torch.cat((class_token, flattened_patches))
flattened_patches.shape

torch.Size([197, 768])

## Step 4: Add position embeddings

In [152]:
pos_embs = nn.Parameter(torch.randn(flattened_patches.shape))
encoder_input =  pos_embs + flattened_patches

dropout_rate = 0.1
embedding_dropout = nn.Dropout(dropout_rate)
encoder_input = embedding_dropout(encoder_input)

print(f'First 3 dimensions of first patch of flattened_patches, pos_embs and their sum\
\n\n {flattened_patches[0][:3]}\n {pos_embs[0][:3]}\n {encoder_input[0][:3]}')

First 3 dimensions of first patch of flattened_patches, pos_embs and their sum

 tensor([1.1299, 0.9617, 0.6921], grad_fn=<SliceBackward0>)
 tensor([ 1.6853, -0.3063,  0.5664], grad_fn=<SliceBackward0>)
 tensor([3.1281, 0.7282, 1.3983], grad_fn=<SliceBackward0>)


# Create ViT encoder 

## Step 5: Layer Normalization, Multi-head Self-Attention and Residual connection

In [164]:
LN1 = nn.LayerNorm(vector_size)
normalized_vecs = LN1(encoder_input)
print(f'Normalized embeddings:\n    {normalized_vecs[0][:3]}')

MSA = nn.MultiheadAttention(batch_first=True, num_heads=12, embed_dim=vector_size)
contextualized_embs, _ = MSA(query=normalized_vecs, key=normalized_vecs, 
                            value=normalized_vecs, need_weights=False)
print(f'Contextualized embeddings:\n    {contextualized_embs[0][:3]}')

intermediate_embeddings = encoder_input + contextualized_embs
print(f'Embeddings after first skip or residual connection:\n   {intermediate_embeddings[0][:3]}')

Normalized embeddings:
    tensor([2.0638, 0.4959, 0.9337], grad_fn=<SliceBackward0>)
Contextualized embeddings:
    tensor([ 0.0149, -0.0182,  0.0260], grad_fn=<SliceBackward0>)
Embeddings after first skip or residual connection:
   tensor([3.1429, 0.7100, 1.4243], grad_fn=<SliceBackward0>)


## Step 6: Layer Normalization, Multi Layer Perceptron and Residual connection

In [None]:
intermediate_embeddings = 