In [1]:
import timm
import torch

def count_parameters(model):
    """Return the number of trainable scalar values in ``model``.

    Walks ``model.parameters()`` and adds up ``numel()`` for every parameter
    whose ``requires_grad`` flag is set; parameters with the flag cleared are
    left out of the total.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

# Shared dummy input fed to every model below: uniform-random values of shape
# (4, 3, 384, 384) -- presumably (batch, channels, height, width) to match the
# 384-px checkpoints; only the output shape is inspected, never the values.
x = torch.rand(4, 3, 384, 384)

# ViT small

## timm

In [2]:
# timm ViT-small, patch 16, 384-px checkpoint (weights are fetched on first use).
vits16 = timm.create_model("vit_small_patch16_384", pretrained=True)

# Only the output shape is needed, so run the forward pass without autograd:
# this avoids building a graph and retaining every intermediate activation.
with torch.no_grad():
    vits16_features = vits16.forward_features(x)

count_parameters(vits16), vits16_features.shape

(22196584, torch.Size([4, 384]))

## DINO

In [4]:
# DINO ViT-small/16 loaded through torch.hub from the facebookresearch/dino repo.
vits16dino = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')

# Only the output shape is needed, so run the forward pass without autograd:
# this avoids building a graph and retaining every intermediate activation.
with torch.no_grad():
    vits16dino_features = vits16dino(x)

count_parameters(vits16dino), vits16dino_features.shape

Using cache found in /home/memmelma/.cache/torch/hub/facebookresearch_dino_main
  "See the documentation of nn.Upsample for details.".format(mode)
  "The default behavior for interpolate/upsample with float scale_factor changed "


(21665664, torch.Size([4, 384]))

# ViT base

## timm

In [3]:
# timm ViT-base, patch 16, 384-px checkpoint (weights are fetched on first use).
vitb16 = timm.create_model("vit_base_patch16_384", pretrained=True)

# Only the output shape is needed, so run the forward pass without autograd:
# this avoids building a graph and retaining every intermediate activation.
with torch.no_grad():
    vitb16_features = vitb16.forward_features(x)

count_parameters(vitb16), vitb16_features.shape

(86859496, torch.Size([4, 768]))

## DINO

In [5]:
# DINO ViT-base/16 loaded through torch.hub from the facebookresearch/dino repo.
vitb16dino = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')

# Only the output shape is needed, so run the forward pass without autograd:
# this avoids building a graph and retaining every intermediate activation.
with torch.no_grad():
    vitb16dino_features = vitb16dino(x)

count_parameters(vitb16dino), vitb16dino_features.shape

Using cache found in /home/memmelma/.cache/torch/hub/facebookresearch_dino_main


(85798656, torch.Size([4, 768]))

# ViT base hybrid

## timm

In [13]:
# timm hybrid ViT-base with a ResNet-50 stem ("vit_base_r50_s16_384").
# Renamed from `vitl16res50`: the checkpoint is the *base* hybrid, not a large
# model, and the new name follows the `vitb16` / `vitb16dino` convention.
vitb16res50 = timm.create_model("vit_base_r50_s16_384", pretrained=True)

# Only the output shape is needed, so run the forward pass without autograd:
# this avoids building a graph and retaining every intermediate activation.
with torch.no_grad():
    vitb16res50_features = vitb16res50.forward_features(x)

count_parameters(vitb16res50), vitb16res50_features.shape

(98950952, torch.Size([4, 768]))

## DPT (Omnidata)

In [15]:
import os
from dpt.dpt_depth import DPTDepthModel

# Build the DPT depth model around the hybrid (ViT-base + ResNet-50) backbone.
backbone = 'vitb_rn50_384' # vitl16_384
dpt_depth = DPTDepthModel(backbone=backbone)

# Check the checkpoint file itself (not just its directory) so a missing
# download fails here with a clear message.
load_depth_omnidata = 'dpt/pretrained_models'
pretrained_weights_path = os.path.join(load_depth_omnidata, 'omnidata_rgb2depth_dpt_hybrid.pth')
assert os.path.exists(pretrained_weights_path), f"Path doesn't exist: {pretrained_weights_path}!"

# NOTE(security): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
checkpoint = torch.load(pretrained_weights_path, map_location='cpu')

if 'state_dict' in checkpoint:
    # Drop the first 6 characters of every key -- presumably a "model." wrapper
    # prefix added by the training framework; TODO confirm against the checkpoint.
    state_dict = {k[6:]: v for k, v in checkpoint['state_dict'].items()}
else:
    state_dict = checkpoint

# strict=False never raises on key mismatches, so surface them explicitly;
# otherwise a bad prefix strip would silently leave the weights untouched.
load_result = dpt_depth.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"load_state_dict mismatch: {len(load_result.missing_keys)} missing, "
        f"{len(load_result.unexpected_keys)} unexpected keys"
    )

# Compare the DPT backbone against the timm model: parameter count and feature
# shape. The forward pass runs without autograd since only the shape is read.
dpt_backbone = dpt_depth.pretrained.model
with torch.no_grad():
    dpt_features = dpt_backbone.forward_features(x)

count_parameters(dpt_backbone), dpt_features.shape

(98950952, torch.Size([4, 768]))

# ViT large

## timm

In [16]:
# timm ViT-large, patch 16, 384-px checkpoint (weights are fetched on first use).
# NOTE(review): despite the `res50` suffix this appears to be the plain ViT-L/16
# (model id "vit_large_patch16_384"); the variable name is kept unchanged so any
# later reference to it keeps working -- consider renaming to `vitl16`.
vitl16res50 = timm.create_model("vit_large_patch16_384", pretrained=True)

# Only the output shape is needed, so run the forward pass without autograd:
# this avoids building a graph and retaining every intermediate activation.
with torch.no_grad():
    vitl16_features = vitl16res50.forward_features(x)

count_parameters(vitl16res50), vitl16_features.shape

(304715752, torch.Size([4, 1024]))

## DPT (Omnidata)

In [18]:
import os
from dpt.dpt_depth import DPTDepthModel

# Build the DPT depth model around the ViT-large/16 backbone.
backbone = 'vitl16_384'
dpt_depth = DPTDepthModel(backbone=backbone)

# Check the checkpoint file itself (not just its directory) so a missing
# download fails here with a clear message.
load_depth_omnidata = 'dpt/pretrained_models'
pretrained_weights_path = os.path.join(load_depth_omnidata, 'omnidata_rgb2depth_dpt_large.pth')
assert os.path.exists(pretrained_weights_path), f"Path doesn't exist: {pretrained_weights_path}!"

# NOTE(security): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
checkpoint = torch.load(pretrained_weights_path, map_location='cpu')

if 'state_dict' in checkpoint:
    # Drop the first 6 characters of every key -- presumably a "model." wrapper
    # prefix added by the training framework; TODO confirm against the checkpoint.
    state_dict = {k[6:]: v for k, v in checkpoint['state_dict'].items()}
else:
    state_dict = checkpoint

# strict=False never raises on key mismatches, so surface them explicitly;
# otherwise a bad prefix strip would silently leave the weights untouched.
load_result = dpt_depth.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"load_state_dict mismatch: {len(load_result.missing_keys)} missing, "
        f"{len(load_result.unexpected_keys)} unexpected keys"
    )

# Compare the DPT backbone against the timm model: parameter count and feature
# shape. The forward pass runs without autograd since only the shape is read.
dpt_backbone = dpt_depth.pretrained.model
with torch.no_grad():
    dpt_features = dpt_backbone.forward_features(x)

count_parameters(dpt_backbone), dpt_features.shape

(304715752, torch.Size([4, 1024]))