# timm (PyTorch image models)

In [None]:
import torch
import timm

In [None]:
# Query timm's registry for the model names that have pretrained weights.
# The result is kept as the cell's last expression so the notebook shows it.
pretrained_names = timm.list_models(pretrained=True)
pretrained_names

## Random inputs

In [None]:
# Random stand-in for real data: a single sample of shape (1, 3, 224, 224),
# reused as the input of every model in the cells below.
input_shape = (1, 3, 224, 224)
x = torch.randn(*input_shape)

## MobileNet

In [None]:
# Build MobileNetV3 and ask timm to load its pretrained weights.
model_name = 'mobilenetv3_large_100'
mobile_net = timm.create_model(model_name, pretrained=True)

# Dump the module tree so the layers can be inspected.
print(mobile_net)

In [None]:
# Recreate the pretrained model for finetuning, this time requesting
# num_classes=23 for the target task.
finetune_kwargs = dict(pretrained=True, num_classes=23)
mobile_net = timm.create_model('mobilenetv3_large_100', **finetune_kwargs)

print(mobile_net)

In [None]:
# compute predictions
# Switch to evaluation mode first. A newly constructed module is in training
# mode, where BatchNorm normalises with the statistics of this single random
# sample and updates its running averages; torch.no_grad() only disables
# gradient tracking and does not change that. The ResNet section below makes
# the same eval() call before running inference.
# NOTE: call mobile_net.train() again before actually finetuning the model.
mobile_net = mobile_net.eval()

with torch.no_grad():
    logits = mobile_net(x)

print(f'Logits shape: {logits.shape}')

## ResNet

In [None]:
# Pretrained ResNet-50, put into evaluation mode right after construction
# (eval() returns the module itself, so the call can be chained).
resnet = timm.create_model('resnet50', pretrained=True).eval()

print(resnet)

In [None]:
# Run the ResNet in two explicit stages, without gradient tracking:
#   1. forward_features(x)     -> intermediate features
#   2. forward_head(features)  -> logits computed from those features
with torch.no_grad():
    features = resnet.forward_features(x)
    logits = resnet.forward_head(features)

# Report the shape produced by each stage.
print(f'Features shape: {features.shape}')
print(f'Logits shape: {logits.shape}')

In [None]:
# Pretrained ResNet-50 without the last pooling and FC layers:
# num_classes=0 and global_pool='' are how that is requested from timm.
headless_kwargs = {'pretrained': True, 'num_classes': 0, 'global_pool': ''}
resnet_features = timm.create_model('resnet50', **headless_kwargs)

print(resnet_features)

In [None]:
# compute features
# resnet_features was created without an eval() call, so it is still in
# training mode; switch it to evaluation mode so BatchNorm uses the pretrained
# running statistics rather than the statistics of this one random sample.
# torch.no_grad() alone only turns off gradient tracking.
resnet_features = resnet_features.eval()

with torch.no_grad():
    features = resnet_features(x)

print(f'Features shape: {features.shape}')

## Transformations

In [None]:
# create default transform
# input_size is given explicitly: some timm releases declare it as a required
# parameter of create_transform, so a bare call would raise TypeError there.
# 224 matches the spatial size of the sample input used in this notebook.
transform = timm.data.create_transform(input_size=224)

print(transform)

In [None]:
# create transform from model metadata
# The first positional parameter of resolve_data_config is for user-supplied
# overrides; leave it empty and hand over the model itself so that its
# pretrained configuration is resolved as the set of defaults.
data_cfg = timm.data.resolve_data_config({}, model=resnet)
transform = timm.data.create_transform(**data_cfg)

print(transform)