# Using Pre-Trained Weights

https://pytorch.org/vision/stable/models.html  
https://pytorch.org/vision/stable/models.html#table-of-all-available-classification-weights  
https://pytorch.org/vision/stable/_modules/torchvision/models/_api.html  
https://github.com/pytorch/vision/blob/main/torchvision/transforms/_presets.py


## Imports

In [1]:
import torch
from torchvision import models

## Available Models

https://pytorch.org/vision/stable/generated/torchvision.models.list_models.html

The `list_models()` function can be used to list all available models:

In [2]:
for model in models.list_models():
    print(model)

alexnet
convnext_base
convnext_large
convnext_small
convnext_tiny
deeplabv3_mobilenet_v3_large
deeplabv3_resnet101
deeplabv3_resnet50
densenet121
densenet161
densenet169
densenet201
efficientnet_b0
efficientnet_b1
efficientnet_b2
efficientnet_b3
efficientnet_b4
efficientnet_b5
efficientnet_b6
efficientnet_b7
efficientnet_v2_l
efficientnet_v2_m
efficientnet_v2_s
fasterrcnn_mobilenet_v3_large_320_fpn
fasterrcnn_mobilenet_v3_large_fpn
fasterrcnn_resnet50_fpn
fasterrcnn_resnet50_fpn_v2
fcn_resnet101
fcn_resnet50
fcos_resnet50_fpn
googlenet
inception_v3
keypointrcnn_resnet50_fpn
lraspp_mobilenet_v3_large
maskrcnn_resnet50_fpn
maskrcnn_resnet50_fpn_v2
maxvit_t
mc3_18
mnasnet0_5
mnasnet0_75
mnasnet1_0
mnasnet1_3
mobilenet_v2
mobilenet_v3_large
mobilenet_v3_small
mvit_v1_b
mvit_v2_s
quantized_googlenet
quantized_inception_v3
quantized_mobilenet_v2
quantized_mobilenet_v3_large
quantized_resnet18
quantized_resnet50
quantized_resnext101_32x8d
quantized_resnext101_64x4d
quantized_shufflenet_v2_x0_

**Filters** can be used to narrow down the list of available models:

In [3]:
# VGG models
vgg_models = models.list_models(include="vgg*")
print("ALL VGG MODELS:")
for model in vgg_models:
    print(model)

# VGG models that do not use batch normalization
vgg_models = models.list_models(include="vgg*", exclude="*bn")
print("\nVGG MODELS WITHOUT BATCH NORMALIZATION:")
for model in vgg_models:
    print(model)

ALL VGG MODELS:
vgg11
vgg11_bn
vgg13
vgg13_bn
vgg16
vgg16_bn
vgg19
vgg19_bn

VGG MODELS WITHOUT BATCH NORMALIZATION:
vgg11
vgg13
vgg16
vgg19
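
The `include` and `exclude` arguments take shell-style wildcard patterns. As a minimal sketch of that matching behaviour, the snippet below reproduces the second filter using only Python's standard `fnmatch` module. It assumes that `list_models()` matches names in this way; the helper `filter_names` and the hard-coded name list are hypothetical and serve only as an illustration.

```python
from fnmatch import fnmatch


def filter_names(names, include=None, exclude=None):
    """Hypothetical helper mimicking include/exclude wildcard filtering."""
    # Keep names that match the include pattern (or all names if none is given)
    kept = [n for n in names if include is None or fnmatch(n, include)]
    # Drop names that match the exclude pattern
    if exclude is not None:
        kept = [n for n in kept if not fnmatch(n, exclude)]
    return sorted(kept)


# Hard-coded stand-in for the list of registered model names
names = ["vgg11", "vgg11_bn", "vgg16", "vgg16_bn", "resnet50"]
vgg_without_bn = filter_names(names, include="vgg*", exclude="*bn")
print(vgg_without_bn)
```

Here `vgg*` selects every name starting with `vgg`, and `*bn` then removes the names ending in `bn`.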


## Pre-Trained Weights

* Available classification weights: https://pytorch.org/vision/stable/models.html#table-of-all-available-classification-weights
* `get_model_weights()` function: https://pytorch.org/vision/stable/generated/torchvision.models.get_model_weights.html
* `get_weight()` function: https://pytorch.org/vision/stable/generated/torchvision.models.get_weight.html



### Technical Details

The `get_model_weights()` function returns all weights available for a model. Its return type is `Type[WeightsEnum]`, i.e., it returns the weights enum **class** associated with the model (not an instance of that class). The `WeightsEnum` class defined [here](https://pytorch.org/vision/main/_modules/torchvision/models/_api.html) inherits from Python's built-in `Enum` base class for creating enumerated constants; see [here](https://docs.python.org/3/library/enum.html) and [here](https://docs.python.org/3/howto/enum.html#enum-basic-tutorial) for details. The returned class enumerates the pre-trained weights that are available for the given model. Each member of the enumeration is a unique instance of that class and represents one specific set of pre-trained weights.

In [4]:
weights_enum = models.get_model_weights("resnet50")
print(weights_enum)

<enum 'ResNet50_Weights'>


We can list all available weights for a given model as follows:

In [5]:
print("AVAILABLE RESNET-50 WEIGHTS:")
for weights in weights_enum:
    print(f"{weights.name}: {weights}")

AVAILABLE RESNET-50 WEIGHTS:
IMAGENET1K_V1: ResNet50_Weights.IMAGENET1K_V1
IMAGENET1K_V2: ResNet50_Weights.IMAGENET1K_V2


`weights_enum` is an enumeration with as many members as there are pre-trained weights available for the given model. Each **member** is also an **attribute** of `weights_enum`, so we can access it by attribute name:

In [6]:
# ImageNet weights (old version)
print(weights_enum.IMAGENET1K_V1)

# ImageNet weights (new version)
print(weights_enum.IMAGENET1K_V2)

ResNet50_Weights.IMAGENET1K_V1
ResNet50_Weights.IMAGENET1K_V2


We can also look up individual members by name using subscript notation:

In [7]:
resnet50_weights_v2 = weights_enum["IMAGENET1K_V2"]

Each member has a `name` and a `value` associated with it. The `name` of a member is the string we just used to look up the member in the previous cell.

In [8]:
print(resnet50_weights_v2.name == "IMAGENET1K_V2")

True


As stated earlier, `weights_enum` is the `WeightsEnum` **class** associated with a given model, and its members/attributes are **instances** of that class.

In [9]:
print(type(resnet50_weights_v2))
isinstance(resnet50_weights_v2, weights_enum)  # members are instances of the WeightsEnum class

<enum 'ResNet50_Weights'>


True

Finally, the **value** of each member is an instance of the `Weights` class defined in `torchvision.models._api`:

In [10]:
type(resnet50_weights_v2.value)

torchvision.models._api.Weights

**To sum up**:

* The `get_model_weights()` function returns the weights enum class associated with the given model. This `WeightsEnum` class inherits from `enum.Enum`, and is an **enumeration** of all the pre-trained weights available for a model.
* The **attributes** of that class are **enumeration members**, and are functionally constants.
* Each member has a **name** and **value** associated with it.
* The **value** of each member is an instance of the `Weights` class defined in `torchvision.models._api`.

In [11]:
weights_enum = models.get_model_weights("resnet50")  # weights enum class associated with ResNet-50 (enumeration)
resnet50_weights_v2 = weights_enum.IMAGENET1K_V2     # member of the enumeration
print(type(resnet50_weights_v2.value))               # the actual weights (instance of class `Weights`)

<class 'torchvision.models._api.Weights'>
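
The relationship between enum class, member, and value can be reproduced with the standard library alone. The following sketch uses two hypothetical stand-ins, `ToyWeights` and `ToyModelWeights`, with made-up URLs. It is not TorchVision's implementation and only mirrors the structure described above.

```python
from dataclasses import dataclass
from enum import Enum


@dataclass(frozen=True)
class ToyWeights:
    """Hypothetical stand-in for the value stored in each enum member."""
    url: str


class ToyModelWeights(Enum):
    """Hypothetical stand-in for a per-model weights enum."""
    V1 = ToyWeights(url="https://example.com/toy-v1.pth")
    V2 = ToyWeights(url="https://example.com/toy-v2.pth")


member = ToyModelWeights["V2"]               # lookup by name
print(member is ToyModelWeights.V2)          # attribute access yields the same object
print(isinstance(member, ToyModelWeights))   # members are instances of the enum class
print(type(member.value).__name__)           # the value is a ToyWeights instance
print([m.name for m in ToyModelWeights])     # iteration follows definition order
```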


Finally, it is also possible to directly access a **particular member** of a model's weights enum class using the `get_weight()` function:

In [12]:
models.get_model_weights("resnet50")["IMAGENET1K_V2"] == models.get_weight("ResNet50_Weights.IMAGENET1K_V2")

True

All available pre-trained weights are listed [here](https://pytorch.org/vision/stable/models.html#table-of-all-available-classification-weights).


### Working with Pre-Trained Weights

Each set of pre-trained weights is an instance of the `Weights` class defined in `torchvision.models._api` (see [here](https://pytorch.org/vision/main/_modules/torchvision/models/_api.html)). As such, the following useful attributes are available:

* `.url`: The **URL** from which the pre-trained weights can be downloaded.
* `.meta`: A `Dict[str, Any]` containing useful metadata about the pre-trained weights, such as **categories** (of the classification task), the **number of parameters**, and the **training recipe**.
* `.transforms`: A callable that returns the **preprocessing transforms** to be used when working with the pre-trained weights.

In [13]:
# Get the ImageNet weights (enum member) for VGG11
vgg11_weights = models.get_weight("VGG11_Weights.IMAGENET1K_V1")

# URL to download weights
print(f"URL:\n{vgg11_weights.url}\n")

# Keys of the dictionary returned by `.meta`
print("KEYS IN META DICT:")
for k in vgg11_weights.meta:
    print(k)

# File size of model weights
print(f"\nFILE SIZE:\n{vgg11_weights.meta['_file_size']} MB")

# Link to training recipe extracted from dict returned by `.meta`
print(f"\nTRAINING RECIPE:\n{vgg11_weights.meta['recipe']}")

# Preprocessing transforms
print(f"\nPREPROCESSING TRANSFORMS:\n{vgg11_weights.transforms()}")

URL:
https://download.pytorch.org/models/vgg11-8a719046.pth

KEYS IN META DICT:
min_size
categories
recipe
_docs
num_params
_metrics
_ops
_file_size

FILE SIZE:
506.84 MB

TRAINING RECIPE:
https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg

PREPROCESSING TRANSFORMS:
ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)


**NOTE**: The **_V2** weights improve upon the results of the original paper by using TorchVision’s [new training recipe](https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/).

To obtain the checkpoint storing the pre-trained weights, we can use the `get_state_dict()` method. This **downloads** the checkpoint and **loads** the state dictionary.

In [14]:
vgg11_weights_dict = vgg11_weights.get_state_dict()

Let's take a look at the keys of the ordered dict `vgg11_weights_dict`:

In [15]:
for k in vgg11_weights_dict:
    print(k)

features.0.weight
features.0.bias
features.3.weight
features.3.bias
features.6.weight
features.6.bias
features.8.weight
features.8.bias
features.11.weight
features.11.bias
features.13.weight
features.13.bias
features.16.weight
features.16.bias
features.18.weight
features.18.bias
classifier.0.weight
classifier.0.bias
classifier.3.weight
classifier.3.bias
classifier.6.weight
classifier.6.bias


To illustrate how to use these weights, let's instantiate VGG11 without pre-trained weights.

In [16]:
vgg11 = models.vgg11()
print(vgg11)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 

As we can see, each entry in `vgg11_weights_dict` corresponds to a trainable parameter tensor (weight or bias) of a VGG11 layer. We can also retrieve **only** the **trainable parameters** of the model as follows:

In [17]:
for name, param in vgg11.named_parameters():
    print(name)

features.0.weight
features.0.bias
features.3.weight
features.3.bias
features.6.weight
features.6.bias
features.8.weight
features.8.bias
features.11.weight
features.11.bias
features.13.weight
features.13.bias
features.16.weight
features.16.bias
features.18.weight
features.18.bias
classifier.0.weight
classifier.0.bias
classifier.3.weight
classifier.3.bias
classifier.6.weight
classifier.6.bias
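
For VGG11 the two key lists coincide because the network has no buffers. As a small illustration of where they diverge, the sketch below uses a toy block with batch normalization (hypothetical, not part of the VGG11 example): `state_dict()` also holds non-trainable buffers such as running statistics, while `named_parameters()` does not.

```python
import torch.nn as nn

# Toy block with batch normalization (illustration only)
block = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))

param_names = {name for name, _ in block.named_parameters()}
state_keys = set(block.state_dict().keys())

# Keys present only in the state dict are buffers, not trainable parameters
buffer_only = sorted(state_keys - param_names)
print(buffer_only)
```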


Since we didn't specify any weights when initializing the network, the weights were initialized randomly.

In [18]:
print(
    "Shape identical: "
    f"{vgg11.state_dict()['features.0.weight'].size() == vgg11_weights_dict['features.0.weight'].size()}"
)
print(
    "Weights identical: "
    f"{torch.all(vgg11.state_dict()['features.0.weight'] == vgg11_weights_dict['features.0.weight'])}"
)

Shape identical: True
Weights identical: False


Assigning pre-trained weights to **individual layers** is easy:

In [19]:
# Assign weights to first convolutional layer (in-place copy, not tracked by autograd)
with torch.no_grad():
    vgg11.features[0].weight.copy_(vgg11_weights_dict["features.0.weight"])

# Check whether weights have successfully been assigned
print(
    "Weights identical: "
    f"{torch.all(vgg11.state_dict()['features.0.weight'] == vgg11_weights_dict['features.0.weight'])}"
)

Weights identical: True
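
To assign all layers at once, `load_state_dict()` is the usual route. The sketch below uses two small, hypothetical `nn.Sequential` networks instead of VGG11 so that it runs without downloading a checkpoint. The same call should apply to `vgg11` and `vgg11_weights_dict`.

```python
import torch
import torch.nn as nn


def make_net():
    """Tiny stand-in network (hypothetical; used instead of VGG11)."""
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


source = make_net()  # plays the role of the pre-trained checkpoint
target = make_net()  # freshly initialized with different random weights

# Copy every tensor from the source state dict into the target network
result = target.load_state_dict(source.state_dict())
print(result)  # reports missing and unexpected keys, if any

all_equal = all(
    torch.equal(target.state_dict()[key], tensor)
    for key, tensor in source.state_dict().items()
)
print(f"All weights identical: {all_equal}")
```

By default `load_state_dict()` is strict, so it raises an error if the keys of the state dict and the model do not match.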
