In [12]:
import timm
import torch

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def count_all_parameters(model):
    return sum(p.numel() for p in model.parameters())

x = torch.rand((4,3,384,384))

# Possible Speedup (didn't really work so far)
https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html

In [2]:
%%time
vits16.eval()
for i in range(25):
    vits16(x)

NameError: name 'vits16' is not defined

In [3]:
%%time
frozen_mod = torch.jit.optimize_for_inference(torch.jit.script(vits16.eval()))
for i in range(25):
    frozen_mod.forward(x)

AttributeError: module 'torch.jit' has no attribute 'optimize_for_inference'

# ResNet

In [2]:
resnet18 = timm.create_model("resnet18", pretrained=True)
count_parameters(resnet18), resnet18.forward_features(x).shape, resnet18.forward(x).shape

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /home/memmel/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


(11689512, torch.Size([4, 512, 12, 12]), torch.Size([4, 1000]))

In [3]:
resnet50 = timm.create_model("resnet50", pretrained=True)
count_parameters(resnet50), resnet50.forward_features(x).shape, resnet50.forward(x).shape

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth" to /home/memmel/.cache/torch/hub/checkpoints/resnet50_a1_0-14fe96d1.pth


(25557032, torch.Size([4, 2048, 12, 12]), torch.Size([4, 1000]))

In [18]:
resnet101 = timm.create_model("resnet101", pretrained=True)
count_parameters(resnet101), resnet101.forward_features(x).shape, resnet101.forward(x).shape

(44549160, torch.Size([4, 2048, 12, 12]), torch.Size([4, 1000]))

In [19]:
resnet152 = timm.create_model("resnet152", pretrained=True)
count_parameters(resnet152), resnet152.forward_features(x).shape, resnet152.forward(x).shape

(60192808, torch.Size([4, 2048, 12, 12]), torch.Size([4, 1000]))

In [20]:
resnet200 = timm.create_model("resnet200", pretrained=True)
count_parameters(resnet200), resnet200.forward_features(x).shape, resnet200.forward(x).shape

No pretrained weights exist for this model. Using random initialization.


(64673832, torch.Size([4, 2048, 12, 12]), torch.Size([4, 1000]))

# ConvNext

In [4]:
convnext_tiny = timm.create_model("convnext_tiny", pretrained=True)
count_parameters(convnext_tiny), convnext_tiny.forward_features(x).shape, convnext_tiny.forward(x).shape

Downloading: "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth" to /home/memmel/.cache/torch/hub/checkpoints/convnext_tiny_1k_224_ema.pth


(28589128, torch.Size([4, 768, 12, 12]), torch.Size([4, 1000]))

In [5]:
convnext_small = timm.create_model("convnext_small", pretrained=True)
count_parameters(convnext_small), convnext_small.forward_features(x).shape, convnext_small.forward(x).shape

Downloading: "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth" to /home/memmel/.cache/torch/hub/checkpoints/convnext_small_1k_224_ema.pth


(50223688, torch.Size([4, 768, 12, 12]), torch.Size([4, 1000]))

In [6]:
convnext_base = timm.create_model("convnext_base", pretrained=True)
count_parameters(convnext_base), convnext_base.forward_features(x).shape, convnext_base.forward(x).shape

Downloading: "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth" to /home/memmel/.cache/torch/hub/checkpoints/convnext_base_1k_224_ema.pth


(88591464, torch.Size([4, 1024, 12, 12]), torch.Size([4, 1000]))

In [16]:
convnext_large = timm.create_model("convnext_large", pretrained=True)
count_parameters(convnext_large), convnext_large.forward_features(x).shape, convnext_large.forward(x).shape

(197767336, torch.Size([4, 1536, 12, 12]), torch.Size([4, 1000]))

# ViT small

## timm

In [4]:
vits16 = timm.create_model("vit_small_patch16_384", pretrained=True)
count_parameters(vits16), vits16.forward_features(x).shape

(22196584, torch.Size([4, 384]))

In [2]:
x = torch.rand((4,3,224,224))
vits16 = timm.create_model("vit_small_patch16_224", pretrained=True)
count_parameters(vits16), vits16.forward_features(x).shape

(22050664, torch.Size([4, 384]))

In [13]:
x = torch.rand((4,3,224,224))
vits16_in21k = timm.create_model("vit_small_patch16_224_in21k", pretrained=True)
count_parameters(vits16_in21k), count_all_parameters(vits16_in21k), vits16_in21k.forward_features(x).shape

(30075219, 30075219, torch.Size([4, 197, 384]))

In [24]:
module_list_vits16_in21k = []
param_list_vits16_in21k = []
for name, module in vits16_in21k.named_modules():
    module_list_vits16_in21k += [name]
    tmp = 0
    for p in module.parameters():
        tmp += p.numel()
    param_list_vits16_in21k += [tmp]

In [14]:
x = torch.rand((4,3,224,224))
vits16dino = timm.create_model("vit_small_patch16_224_dino", pretrained=True)
count_parameters(vits16dino), count_all_parameters(vits16dino), vits16dino.forward_features(x).shape

(21665664, 21665664, torch.Size([4, 197, 384]))

In [25]:
module_list_vits16_dino = []
param_list_vits16_dino = []
for name, module in vits16dino.named_modules():
    module_list_vits16_dino += [name]
    tmp = 0
    for p in module.parameters():
        tmp += p.numel()
    param_list_vits16_dino += [tmp]

In [26]:
len(module_list_vits16_in21k), len(module_list_vits16_dino), len(param_list_vits16_in21k), len(param_list_vits16_dino)

(225, 225, 225, 225)

In [27]:
import numpy as np
np.intersect1d(param_list_vits16_in21k, param_list_vits16_dino)

array([       0,      768,   147840,   295296,   443520,   590208,
         591360,  1181568,  1774464, 21293568])

## DINO

In [3]:
vits16dino = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
count_parameters(vits16dino), vits16dino(x).shape

Using cache found in /home/memmelma/.cache/torch/hub/facebookresearch_dino_main
  "See the documentation of nn.Upsample for details.".format(mode)
  "The default behavior for interpolate/upsample with float scale_factor changed "


(21665664, torch.Size([4, 384]))

# ViT base

## timm

In [4]:
vitb16 = timm.create_model("vit_base_patch16_384", pretrained=True)
count_parameters(vitb16), vitb16.forward_features(x).shape

(86859496, torch.Size([4, 768]))

## DINO

In [5]:
vitb16dino = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
count_parameters(vitb16dino), vitb16dino(x).shape

Using cache found in /home/memmelma/.cache/torch/hub/facebookresearch_dino_main


(85798656, torch.Size([4, 768]))

# ViT base hybrid

## timm

In [6]:
vitl16res50 = timm.create_model("vit_base_r50_s16_384", pretrained=True)
count_parameters(vitl16res50), vitl16res50.forward_features(x).shape

(98950952, torch.Size([4, 768]))

## DPT (Omnidata)

In [7]:
import os
from dpt.dpt_depth import DPTDepthModel

backbone = 'vitb_rn50_384' # vitl16_384
dpt_depth = DPTDepthModel(backbone=backbone)

load_depth_omnidata = 'dpt/pretrained_models'
assert os.path.exists(load_depth_omnidata), f"Path doesn't exist: {load_depth_omnidata}!"
pretrained_weights_path = os.path.join(load_depth_omnidata, f'omnidata_rgb2depth_dpt_hybrid.pth')

checkpoint = torch.load(pretrained_weights_path, map_location='cpu')

if 'state_dict' in checkpoint:
    state_dict = {}
    for k, v in checkpoint['state_dict'].items():
        state_dict[k[6:]] = v
else:
    state_dict = checkpoint

dpt_depth.load_state_dict(state_dict, strict=False)

count_parameters(dpt_depth.pretrained.model), dpt_depth.pretrained.model.forward_features(x).shape

(98950952, torch.Size([4, 768]))

# ViT large

## timm

In [8]:
vitl16res50 = timm.create_model("vit_large_patch16_384", pretrained=True)
count_parameters(vitl16res50), vitl16res50.forward_features(x).shape

(304715752, torch.Size([4, 1024]))

## DPT (Omnidata)

In [9]:
import os
from dpt.dpt_depth import DPTDepthModel

backbone = 'vitl16_384'
dpt_depth = DPTDepthModel(backbone=backbone)

load_depth_omnidata = 'dpt/pretrained_models'
assert os.path.exists(load_depth_omnidata), f"Path doesn't exist: {load_depth_omnidata}!"
pretrained_weights_path = os.path.join(load_depth_omnidata, f'omnidata_rgb2depth_dpt_large.pth')

checkpoint = torch.load(pretrained_weights_path, map_location='cpu')

if 'state_dict' in checkpoint:
    state_dict = {}
    for k, v in checkpoint['state_dict'].items():
        state_dict[k[6:]] = v
else:
    state_dict = checkpoint

dpt_depth.load_state_dict(state_dict, strict=False)

count_parameters(dpt_depth.pretrained.model), dpt_depth.pretrained.model.forward_features(x).shape

(304715752, torch.Size([4, 1024]))