<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/Transformer_scale_invariant.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install the TIMM library

In [1]:
!pip -q install timm

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m18.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m57.9 MB/s[0m eta [36m0:00:00[0m
[?25h

# ViT with a fixed image size: 224x224

In [4]:
from timm import create_model
import torch
from torch import nn
from PIL import Image
import requests
from torchvision.models import resnet18, resnet34, resnet101
from torchvision import transforms

# Walk a single 224x224 image through a pretrained ViT-B/16 step by step
# (patch embedding -> [CLS] + positional embedding -> encoder blocks -> norm)
# and print the tensor shape at every stage.

img_url = 'https://www.animalfunfacts.net/images/stories/pets/dogs/pembroke_welsh_corgi_l.jpg'
# Fail fast on network problems instead of hanging forever or trying to decode
# an HTTP error page as an image.
response = requests.get(img_url, stream=True, timeout=30)
response.raise_for_status()
# Force three channels: the normalisation statistics below have three entries.
img_raw = Image.open(response.raw).convert('RGB')

mean, std = torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])
transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=mean, std=std)])

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# The input must live on the same device as the model; leaving it on the CPU
# raises a device-mismatch error as soon as a GPU runtime is used.
img = transform(img_raw)[None].to(device)
model = create_model("vit_base_patch16_224", pretrained=True).to(device)
model.eval()  # inference-mode behaviour for any dropout / stochastic-depth layers

# Shape walk-through only, so no autograd graph is needed for the backbone.
# Remove this context manager if the backbone is to be fine-tuned.
with torch.no_grad():
    patches = model.patch_embed(img)
    print('patches:', patches.shape)

    print('model.cls_token', model.cls_token.shape)
    print('model.pos_embed', model.pos_embed.shape)
    # cls_token is stored with batch size 1; expand it so the concatenation also
    # works for batches larger than one.
    cls_tokens = model.cls_token.expand(patches.shape[0], -1, -1)
    transformer_input = torch.cat((cls_tokens, patches), dim=1) + model.pos_embed
    print("Input tensor to Transformer (z0): ", transformer_input.shape)

    x = transformer_input.clone()
    for i, blk in enumerate(model.blocks):
        print("Entering the Transformer Encoder {}, input:{}".format(i, x.shape))
        x = blk(x)
    x = model.norm(x)
    # Index 0 along the token axis is the [CLS] token that was concatenated first.
    transformer_output = x[:, 0]
    print("Output vector from Transformer (z12-0):", transformer_output.shape)

#then use any classification head
num_labels=10
# Take the feature width from the model instead of hard-coding 768, and keep the
# head on the same device as the features it consumes.
cls_head = nn.Linear(model.embed_dim, num_labels).to(device)
logits = cls_head(transformer_output)
print(logits.shape)

patches: torch.Size([1, 196, 768])
model.cls_token torch.Size([1, 1, 768])
model.pos_embed torch.Size([1, 197, 768])
Input tensor to Transformer (z0):  torch.Size([1, 197, 768])
Entering the Transformer Encoder 0, input:torch.Size([1, 197, 768])
Entering the Transformer Encoder 1, input:torch.Size([1, 197, 768])
Entering the Transformer Encoder 2, input:torch.Size([1, 197, 768])
Entering the Transformer Encoder 3, input:torch.Size([1, 197, 768])
Entering the Transformer Encoder 4, input:torch.Size([1, 197, 768])
Entering the Transformer Encoder 5, input:torch.Size([1, 197, 768])
Entering the Transformer Encoder 6, input:torch.Size([1, 197, 768])
Entering the Transformer Encoder 7, input:torch.Size([1, 197, 768])
Entering the Transformer Encoder 8, input:torch.Size([1, 197, 768])
Entering the Transformer Encoder 9, input:torch.Size([1, 197, 768])
Entering the Transformer Encoder 10, input:torch.Size([1, 197, 768])
Entering the Transformer Encoder 11, input:torch.Size([1, 197, 768])
Outp

# ViT with a variable (non-224x224) input image size

In [10]:
from timm import create_model
import torch
from torch import nn
from PIL import Image
import requests
from torchvision.models import resnet18, resnet34, resnet101
from torchvision import transforms

# Run a pretrained ViT-B/16 on a non-square 224x324 image. The patch embedding is
# switched to accept (and pad) arbitrary sizes, and the pretrained positional
# embedding is resampled to the new patch grid.

# Imported here so this cell stays self-contained; timm is already a dependency
# of this notebook.
from timm.layers import resample_abs_pos_embed

img_url = 'https://www.animalfunfacts.net/images/stories/pets/dogs/pembroke_welsh_corgi_l.jpg'
# Fail fast on network problems instead of hanging forever or trying to decode
# an HTTP error page as an image.
response = requests.get(img_url, stream=True, timeout=30)
response.raise_for_status()
# Force three channels: the normalisation statistics below have three entries.
img_raw = Image.open(response.raw).convert('RGB')

mean, std = torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])
transform = transforms.Compose([transforms.Resize((224, 324)),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=mean, std=std)])

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# The input must live on the same device as the model; leaving it on the CPU
# raises a device-mismatch error as soon as a GPU runtime is used.
img = transform(img_raw)[None].to(device)
model = create_model("vit_base_patch16_224", pretrained=True).to(device)
model.eval()  # inference-mode behaviour for any dropout / stochastic-depth layers

# Accept inputs that differ from the 224x224 training size, and let the patch
# embedding pad height/width up to a multiple of the patch size.
model.patch_embed.strict_img_size = False
model.patch_embed.dynamic_img_pad = True

with torch.no_grad():
    patches = model.patch_embed(img)
print('patches:', patches.shape)

# Patch grid after padding: ceil(H / patch_h) x ceil(W / patch_w).
patch_h, patch_w = model.patch_embed.patch_size
img_h, img_w = img.shape[-2:]
grid_size = ((img_h + patch_h - 1) // patch_h, (img_w + patch_w - 1) // patch_w)
assert grid_size[0] * grid_size[1] == patches.shape[1], "patch grid does not match patch count"

# Resample the *pretrained* positional embedding to the new grid. Replacing it
# with random noise would discard the positional information the encoder
# weights were trained with. The embedding width (model.embed_dim) is fixed by
# the pretrained checkpoint and must not be changed here.
with torch.no_grad():
    pos_embed_resized = resample_abs_pos_embed(
        model.pos_embed.detach(),
        new_size=grid_size,
        # pos_embed only covers the prefix ([CLS]) tokens when no_embed_class is False.
        num_prefix_tokens=0 if getattr(model, 'no_embed_class', False) else model.num_prefix_tokens,
    )
model.pos_embed = nn.Parameter(pos_embed_resized)

print('model.cls_token', model.cls_token.shape)
print('model.pos_embed', model.pos_embed.shape)

# Shape walk-through only, so no autograd graph is needed for the backbone.
# Remove this context manager if the backbone is to be fine-tuned.
with torch.no_grad():
    # cls_token is stored with batch size 1; expand it so the concatenation also
    # works for batches larger than one.
    cls_tokens = model.cls_token.expand(patches.shape[0], -1, -1)
    transformer_input = torch.cat((cls_tokens, patches), dim=1) + model.pos_embed
    print("Input tensor to Transformer (z0): ", transformer_input.shape)

    x = transformer_input.clone()
    for i, blk in enumerate(model.blocks):
        print("Entering the Transformer Encoder {}, input:{}".format(i, x.shape))
        x = blk(x)
    x = model.norm(x)
    # Index 0 along the token axis is the [CLS] token that was concatenated first.
    transformer_output = x[:, 0]
    print("Output vector from Transformer (z12-0):", transformer_output.shape)

#then use any classification head
# num_labels=10
# cls_head = nn.Linear(model.embed_dim, num_labels).to(device)
# logits = cls_head(transformer_output)
# print(logits.shape)

patches: torch.Size([1, 294, 768])
model.cls_token torch.Size([1, 1, 768])
model.pos_embed torch.Size([1, 295, 768])
Input tensor to Transformer (z0):  torch.Size([1, 295, 768])
Entering the Transformer Encoder 0, input:torch.Size([1, 295, 768])
Entering the Transformer Encoder 1, input:torch.Size([1, 295, 768])
Entering the Transformer Encoder 2, input:torch.Size([1, 295, 768])
Entering the Transformer Encoder 3, input:torch.Size([1, 295, 768])
Entering the Transformer Encoder 4, input:torch.Size([1, 295, 768])
Entering the Transformer Encoder 5, input:torch.Size([1, 295, 768])
Entering the Transformer Encoder 6, input:torch.Size([1, 295, 768])
Entering the Transformer Encoder 7, input:torch.Size([1, 295, 768])
Entering the Transformer Encoder 8, input:torch.Size([1, 295, 768])
Entering the Transformer Encoder 9, input:torch.Size([1, 295, 768])
Entering the Transformer Encoder 10, input:torch.Size([1, 295, 768])
Entering the Transformer Encoder 11, input:torch.Size([1, 295, 768])
Outp