# timm (PyTorch Image Models)

In [None]:
# Render matplotlib figures inline, directly below the cell that creates them.
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import timm
# from timm.models.vision_transformer import VisionTransformer
from transformers.image_utils import load_image

In [None]:
# list pretrained models
# timm.list_models(pretrained=True)

## Load image

In [None]:
# Fetch the sample image (a COCO val2017 photo) that every model below is run on.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# Uncomment to override the URL above with an alternative test picture:
# url = 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
image = load_image(url)

In [None]:
# Preview the input; the image is converted to a NumPy array for imshow.
fig, ax = plt.subplots(figsize=(6, 4))
ax.imshow(np.asarray(image))
# Keep pixels square and let the axes box (not the data limits) adapt.
ax.set_aspect(aspect='equal', adjustable='box')
fig.tight_layout()

## MobileNet

In [None]:
# Instantiate the pretrained MobileNetV3-Large and switch it to eval mode
# (dropout off, BatchNorm uses its running statistics).
# For finetuning, pass num_classes=<N> to create_model to set the number of outputs.
mobile_net = timm.create_model('mobilenetv3_large_100', pretrained=True).eval()

In [None]:
# Derive the evaluation-time preprocessing from the model's data config,
# then show the resulting pipeline.
data_cfg = timm.data.resolve_data_config(model=mobile_net)
transform = timm.data.create_transform(is_training=False, **data_cfg)
print(transform)

In [None]:
# For a shape-only smoke test a random batch can stand in for the image
# (the transform line below would then need to be commented out):
# x = torch.randn(1, 3, 224, 224)

# Preprocess the image and insert a leading batch axis of size 1.
x = transform(image)[None]  # (1, 3, 224, 224)
print(f'Inputs shape: {x.shape}')

In [None]:
# Forward pass with autograd tracking disabled.
with torch.inference_mode(mode=True):
    logits = mobile_net(x)
print(f'Logits shape: {logits.shape}')

In [None]:
# Convert logits to class probabilities and keep the five most likely classes.
# With a single-image batch both tensors have shape (1, 5); topk returns the
# probabilities in descending order together with the matching class indices.
top5_probas, top5_ids = torch.topk(logits.softmax(dim=1), k=5)

# Display the ranking: the assignment above alone produces no cell output.
for rank, (proba, class_id) in enumerate(
    zip(top5_probas[0].tolist(), top5_ids[0].tolist()), start=1
):
    print(f'{rank}. class {class_id}: {proba:.2%}')

## ResNet

In [None]:
# Instantiate the pretrained ResNet-50 and switch it to eval mode.
resnet = timm.create_model('resnet50', pretrained=True).eval()

In [None]:
# Eval-time preprocessing for ResNet-50. Resolve the data config from the
# model itself, as the MobileNet and DINOv2 sections do; timm reads
# model.pretrained_cfg in that case, so the result matches passing
# pretrained_cfg=resnet.pretrained_cfg explicitly.
data_config = timm.data.resolve_data_config(model=resnet)
transform = timm.data.create_transform(**data_config, is_training=False)

print(transform)

In [None]:
# Apply the ResNet preprocessing and batch the single image.
x = transform(image)[None]  # (1, 3, 224, 224)
print(f'Inputs shape: {x.shape}')

In [None]:
# Run the model in two explicit stages: backbone features first, then the
# head (pooling + classifier) applied to those features.
with torch.inference_mode(mode=True):
    features = resnet.forward_features(x)
    logits = resnet.forward_head(features)

print(
    f'Features shape: {features.shape}',
    f'Logits shape: {logits.shape}',
    sep='\n',
)

In [None]:
# Load ResNet-50 as a pure feature extractor: num_classes=0 drops the final
# FC classifier and global_pool='' disables global pooling, so the model
# returns the unpooled backbone feature map.
resnet_features = timm.create_model(
    'resnet50',
    pretrained=True,
    num_classes=0,
    global_pool='',
)
# Switch to eval mode like every other model in this notebook. Without it,
# BatchNorm layers normalize with the current batch's statistics instead of
# the pretrained running statistics, and the extracted features differ.
resnet_features = resnet_features.eval()

In [None]:
# Run the headless ResNet-50 to obtain its (unpooled) feature map.
with torch.inference_mode(mode=True):
    features = resnet_features(x)
print(f'Features shape: {features.shape}')

## DINOv2

In [None]:
# Load the ViT-Large/14 model with the DINOv2 `lvd142m` pretrained weights
# and switch it to eval mode.
dinov2 = timm.create_model(
    'vit_large_patch14_dinov2.lvd142m',
    pretrained=True,  # trailing comma so the optional line below can be uncommented
    # num_classes=0,  # remove final classifier
    # NOTE(review): this checkpoint may already come without a classifier
    # head; check dinov2.num_classes before relying on the line above.
)
dinov2 = dinov2.eval()

In [None]:
# DINOv2 has its own preprocessing (including a different input resolution),
# so rebuild the transform from this model's data config.
data_config = timm.data.resolve_data_config(model=dinov2)
transform = timm.data.create_transform(is_training=False, **data_config)

In [None]:
# Preprocess for DINOv2 and add the batch axis.
x = transform(image)[None]  # (1, 3, 518, 518)
print(f'Inputs shape: {x.shape}')

In [None]:
# Plain forward pass through the full DINOv2 model.
# NOTE(review): if the model has no classifier head (num_classes == 0), this
# output is a pooled embedding rather than class logits; confirm via
# dinov2.num_classes.
with torch.inference_mode(mode=True):
    logits = dinov2(x)
print(f'Logits shape: {logits.shape}')

In [None]:
# Same model, run in two explicit stages: backbone features, then the head.
# pre_logits=False keeps the final head layer applied.
# NOTE(review): for a ViT the features are presumably (batch, tokens, dim);
# confirm from the printed shape.
with torch.inference_mode(mode=True):
    features = dinov2.forward_features(x)
    logits = dinov2.forward_head(features, pre_logits=False)

print(
    f'Features shape: {features.shape}',
    f'Logits shape: {logits.shape}',
    sep='\n',
)