In [1]:
import timm
import torch
import types
import itertools

In [2]:
# Pretrained timm ViT-S/16 (built for 384x384 inputs); its forward path is patched below.
vits16 = timm.create_model(model_name="vit_small_patch16_384", pretrained=True)

In [3]:
# Reference model: DINO ViT-S/16 from torch.hub (served from the local hub cache when present).
# The patched timm model is compared against it in the smoke tests below.
vits16dino = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')

Using cache found in /home/memmelma/.cache/torch/hub/facebookresearch_dino_main


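Override the forward-path functions of the timm ViT with DINO-style versions that interpolate the positional embeddings to the input resolution. The image-size asserts in ~/anaconda3/envs/iprl/lib/python3.7/site-packages/timm/models/layers/patch_embed.py must also be removed, because they keep us from using different image resolutions.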

In [4]:
import math
import torch.nn as nn
def interpolate_pos_encoding(self, x, w, h):
    """Return positional embeddings matching the token sequence `x`.

    Adapted from DINO's ``interpolate_pos_encoding``. The pretrained patch
    position grid (assumed square, ``sqrt(N) x sqrt(N)``) is resized with
    bicubic interpolation to the patch grid of the current input; the [CLS]
    position embedding (index 0) is kept unchanged.

    Args:
        x: token sequence whose dim 1 is ``1 + num_patches`` ([CLS] first) and
            whose last dim is the embedding size.
        w: size of input-image dim 2 (``prepare_tokens`` passes ``x.shape[2]``).
        h: size of input-image dim 3 (``prepare_tokens`` passes ``x.shape[3]``).

    Returns:
        Tensor of shape ``(1, 1 + num_patches, dim)``; ``self.pos_embed`` itself
        when the token count and the square-input condition already match.
    """
    npatch = x.shape[1] - 1
    N = self.pos_embed.shape[1] - 1
    if npatch == N and w == h:
        return self.pos_embed
    class_pos_embed = self.pos_embed[:, 0]
    patch_pos_embed = self.pos_embed[:, 1:]
    dim = x.shape[-1]
    M = int(math.sqrt(N))  # side length of the (square) pretrained patch grid

    # timm stores patch_size as a per-axis tuple, while DINO's original code uses
    # a plain int; accept both and use the matching entry for each image axis so
    # non-square patches give the right grid size.
    patch_size = self.patch_embed.patch_size
    if isinstance(patch_size, int):
        patch_size = (patch_size, patch_size)
    w0 = w // patch_size[0]
    h0 = h // patch_size[1]
    # we add a small number to avoid floating point error in the interpolation
    # see discussion at https://github.com/facebookresearch/dino/issues/8
    w0, h0 = w0 + 0.1, h0 + 0.1
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
        scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
        mode='bicubic',
        # Same values as the PyTorch (>= 1.6) defaults, so the result matches
        # DINO exactly; stating them explicitly silences the align_corners and
        # float-scale_factor UserWarnings.
        align_corners=False,
        recompute_scale_factor=False,
    )
    assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
    # Back to token layout (1, grid_0 * grid_1, dim); reshape also copes with a
    # non-contiguous interpolation result.
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
    return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

def prepare_tokens(self, x):
    """Turn an image batch into the ViT input token sequence.

    Steps: patch-embed the images, prepend the (batch-expanded) [CLS] token,
    add positional encodings sized for this input via
    ``self.interpolate_pos_encoding``, then apply ``self.pos_drop``.

    Args:
        x: 4-D image batch; its dims 2 and 3 are forwarded to
            ``interpolate_pos_encoding`` as ``w`` and ``h``.
    """
    batch_size, _, size_dim2, size_dim3 = x.shape
    patch_tokens = self.patch_embed(x)  # patch linear embedding

    # [CLS] goes in front of the patch tokens, one copy per sample
    cls_per_sample = self.cls_token.expand(batch_size, -1, -1)
    tokens = torch.cat((cls_per_sample, patch_tokens), dim=1)

    # positional encoding resized to the current image resolution
    pos_encoding = self.interpolate_pos_encoding(tokens, size_dim2, size_dim3)
    tokens = tokens + pos_encoding
    return self.pos_drop(tokens)

def forward(self, x):
    """Run the patched ViT and return the normalized [CLS] embedding.

    Tokens from ``self.prepare_tokens`` pass through each module in
    ``self.blocks`` in order, then ``self.norm``; only token 0 is returned.
    """
    tokens = self.prepare_tokens(x)
    for block in self.blocks:
        tokens = block(tokens)
    normalized = self.norm(tokens)
    cls_embedding = normalized[:, 0]
    return cls_embedding

# Bind the three functions above as methods of this one timm model instance
# (the VisionTransformer class itself is left untouched).
for _attr_name, _func in (
    ("interpolate_pos_encoding", interpolate_pos_encoding),
    ("prepare_tokens", prepare_tokens),
    ("forward", forward),
):
    setattr(vits16, _attr_name, types.MethodType(_func, vits16))
del _attr_name, _func  # keep the notebook namespace clean

In [5]:
# Smoke test: run the patched timm ViT on every (dim2, dim3) resolution combination
# and print the output shape, or the error if a resolution is not supported.
sizes = (192, 384, 768)
with torch.no_grad():  # inference only: no autograd graph is needed
    for height, width in itertools.product(sizes, repeat=2):
        x = torch.rand((4, 3, height, width))
        try:
            # Call the module (not .forward) so nn.Module hooks run; nn.Module
            # dispatches to the instance-level `forward` bound earlier.
            print(f'success vit {height},{width} {vits16(x).shape}')
        except Exception as err:  # report the cause instead of a bare `except:`
            print(f'failed vit {height},{width}: {err!r}')

  "See the documentation of nn.Upsample for details.".format(mode)
  "The default behavior for interpolate/upsample with float scale_factor changed "


success vit 192,192 torch.Size([4, 384])
success vit 192,384 torch.Size([4, 384])
success vit 192,768 torch.Size([4, 384])
success vit 384,192 torch.Size([4, 384])
success vit 384,384 torch.Size([4, 384])
success vit 384,768 torch.Size([4, 384])
success vit 768,192 torch.Size([4, 384])
success vit 768,384 torch.Size([4, 384])
success vit 768,768 torch.Size([4, 384])


In [6]:
# Same smoke test for the reference DINO hub model: print the output shape for
# every (dim2, dim3) resolution combination, or the error if one fails.
sizes = (192, 384, 768)
with torch.no_grad():  # inference only: no autograd graph is needed
    for height, width in itertools.product(sizes, repeat=2):
        x = torch.rand((4, 3, height, width))
        try:
            # Call the module (not .forward) so nn.Module hooks run.
            print(f'success dino {height},{width} {vits16dino(x).shape}')
        except Exception as err:  # report the cause instead of a bare `except:`
            print(f'failed dino {height},{width}: {err!r}')

success dino 192,192 torch.Size([4, 384])
success dino 192,384 torch.Size([4, 384])
success dino 192,768 torch.Size([4, 384])
success dino 384,192 torch.Size([4, 384])
success dino 384,384 torch.Size([4, 384])
success dino 384,768 torch.Size([4, 384])
success dino 768,192 torch.Size([4, 384])
success dino 768,384 torch.Size([4, 384])
success dino 768,768 torch.Size([4, 384])
