### Dependencies

In [19]:
import torch
import torchvision
from torchvision import transforms

# Cropping tensors (instead of PIL images) with torchvision transforms requires torchvision >= 0.8.0.
print('torchvision version:', torchvision.__version__)
# Compare only (major, minor); drop any local build suffix such as "+cu101" before parsing.
TV_MAJOR_MINOR = tuple(int(part) for part in torchvision.__version__.split('+')[0].split('.')[:2])
assert TV_MAJOR_MINOR >= (0, 8), 'torchvision >= 0.8.0 is required to crop tensors'

from vision_transformer import vit_small
from vision_transformer4k import vit4k_xs
from main_dino4k import DataAugmentationDINO

# Seed torch so the torch.randn demo inputs in later cells are reproducible.
SEED = 0
torch.manual_seed(SEED)

def count_parameters(model):
    """Return the total number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad=True`` are counted; frozen
    parameters are skipped. A module with no parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

torchvision version: 0.8.0


### ViT-16 Model

In [20]:
# ViT-16 demo: encode a small batch of random 256 x 256 RGB images.
model = vit_small()
print("Num Parameters:", count_parameters(model))

x = torch.randn(4, 3, 256, 256)  # (batch=4, channels=3, height=256, width=256)
print("1. Input Shape:", x.shape)
# Inference-only shape check: eval mode disables training-time stochastic layers,
# and no_grad avoids building an autograd graph for a forward pass we never backprop.
model.eval()
with torch.no_grad():
    out = model(x)
print("2. Output Shape:", out.shape)
assert out.shape[0] == x.shape[0], "expected one embedding per input image"

Num Parameters: 21665664
1. Input Shape: torch.Size([4, 3, 256, 256])
2. Output Shape: torch.Size([4, 384])


### ViT-256 Model

In [25]:
model = vit4k_xs()
print("Num Parameters:", count_parameters(model))

# Geometry of the embedding grid, measured in grid cells (one cell per 256 x 256 patch embedding).
GRID_SIZE = 16    # 16 x 16 cells -> 256 patch embeddings per 4K x 4K region
GLOBAL_CROP = 14  # "global" crop side length
LOCAL_CROP = 6    # "local" crop side length

t_tensorcrop = transforms.Compose([
    transforms.RandomCrop(GLOBAL_CROP),
])

# [14 x 14] crop in a 16 x 16 grid retains the same relative extent as [224 x 224] in a 256 x 256 img
assert 224 / 256 == GLOBAL_CROP / GRID_SIZE

# [6 x 6] crop in a 16 x 16 grid retains the same relative extent as [96 x 96] in a 256 x 256 img
assert 96 / 256 == LOCAL_CROP / GRID_SIZE

x_seq = torch.randn(GRID_SIZE * GRID_SIZE, 384)
print('1. For a 4K x 4K image, torch.load in 256-len sequence of 384-dim embeddings:', x_seq.shape)
# (256, 384) -> unsqueeze -> (1, 256, 384) -> unfold dim 1 into 16 windows of 16 -> (1, 16, 384, 16)
# -> transpose(1, 2) -> (1, 384, 16, 16), with x_grid[0, c, i, j] == x_seq[16 * i + j, c] (row-major).
x_grid = x_seq.unsqueeze(dim=0).unfold(1, GRID_SIZE, GRID_SIZE).transpose(1, 2)
print('2. Reshape this sequence to be a 2D image grid (B NC W H):', x_grid.shape)
x_crop = t_tensorcrop(x_grid)
print('3. Applying 2D cropping (B NC W H):', x_crop.shape)
assert x_crop.shape[-2:] == (GLOBAL_CROP, GLOBAL_CROP)

# Inference-only forward pass; bind to a real name rather than IPython's `_` output variable.
model.eval()
with torch.no_grad():
    out = model(x_crop)
print('4. Out:', out.shape)

# of Patches: 196
Num Parameters: 2793792
1. For a 4K x 4K image, torch.load in 256-len sequence of 384-dim embeddings: torch.Size([256, 384])
2. Reshape this sequence to be a 2D image grid (B NC W H): torch.Size([1, 384, 16, 16])
3. Applying 2D cropping (B NC W H): torch.Size([1, 384, 14, 14])
4. Out: torch.Size([1, 192])


In [14]:
# NOTE(review): the positional argument 8 appears to be the number of local crops -- the
# recorded output shows 2 global (14 x 14) + 8 local (6 x 6) crops. Confirm against main_dino4k.
t_dino = DataAugmentationDINO(8)

x_bag = torch.randn(256, 384)  # 256-len sequence of 384-dim embeddings, as in the ViT-256 demo
x_crops = t_dino(x_bag)
for crop_num, crop in enumerate(x_crops, start=1):
    print('Crop %d:' % crop_num, crop.shape)

Crop 1: torch.Size([384, 14, 14])
Crop 2: torch.Size([384, 14, 14])
Crop 3: torch.Size([384, 6, 6])
Crop 4: torch.Size([384, 6, 6])
Crop 5: torch.Size([384, 6, 6])
Crop 6: torch.Size([384, 6, 6])
Crop 7: torch.Size([384, 6, 6])
Crop 8: torch.Size([384, 6, 6])
Crop 9: torch.Size([384, 6, 6])
Crop 10: torch.Size([384, 6, 6])
