# Using Pre-Trained Weights

https://pytorch.org/vision/stable/models.html  
https://pytorch.org/vision/stable/models.html#table-of-all-available-classification-weights  
https://pytorch.org/vision/stable/_modules/torchvision/models/_api.html  
https://github.com/pytorch/vision/blob/main/torchvision/transforms/_presets.py


## Imports

In [None]:
import torch
from torchvision import models

## Available Models

https://pytorch.org/vision/stable/generated/torchvision.models.list_models.html

The `list_models()` function can be used to list all available models:

In [None]:
for model in models.list_models():
    print(model)

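**Filters** can be used to narrow down the list of available models. The `include` and `exclude` arguments accept Unix shell-style wildcard patterns (e.g., `"vgg*"`), either as a single string or as a list of strings: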

In [None]:
# VGG models
vgg_models = models.list_models(include="vgg*")
print("ALL VGG MODELS:")
for model in vgg_models:
    print(model)

# VGG models that do not use batch normalization
vgg_models = models.list_models(include="vgg*", exclude="*bn")
print("\nVGG MODELS WITHOUT BATCH NORMALIZATION:")
for model in vgg_models:
    print(model)
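
The `include` and `exclude` patterns behave like Unix shell-style wildcards. The sketch below reproduces this filtering with Python's standard `fnmatch` module on a short, hard-coded list of model names. The helper `filter_names` and the name list are hypothetical and only illustrate how the patterns combine; we assume that torchvision first keeps every name matching any `include` pattern and then drops every name matching any `exclude` pattern.

```python
import fnmatch

# Hypothetical subset of model names, standing in for models.list_models()
names = ["resnet18", "resnet50", "vgg11", "vgg11_bn", "vgg16", "vgg16_bn"]


def filter_names(all_names, include=None, exclude=None):
    """Keep names matching any `include` pattern, then drop names matching any `exclude` pattern."""
    # Accept a single pattern or a list of patterns
    if isinstance(include, str):
        include = [include]
    if isinstance(exclude, str):
        exclude = [exclude]

    selected = set(all_names)
    if include:
        # Union of the matches of all include patterns
        selected = set()
        for pattern in include:
            selected |= set(fnmatch.filter(all_names, pattern))
    if exclude:
        # Remove the matches of every exclude pattern
        for pattern in exclude:
            selected -= set(fnmatch.filter(all_names, pattern))
    return sorted(selected)


print(filter_names(names, include="vgg*"))                 # all VGG variants
print(filter_names(names, include="vgg*", exclude="*bn"))  # VGG variants without batch norm
print(filter_names(names, include=["resnet*", "vgg16"]))   # union of two patterns
```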

## Pre-Trained Weights

* Available classification weights: https://pytorch.org/vision/stable/models.html#table-of-all-available-classification-weights
* `get_model_weights()` function: https://pytorch.org/vision/stable/generated/torchvision.models.get_model_weights.html
* `get_weight()` function: https://pytorch.org/vision/stable/generated/torchvision.models.get_weight.html



### Technical Details

The `get_model_weights()` function returns all available pre-trained weights for a given model. Its return type is `Type[WeightsEnum]`, i.e., it returns the weights enum **class** of the associated model (not an instance of that class). The `WeightsEnum` class defined [here](https://pytorch.org/vision/main/_modules/torchvision/models/_api.html) inherits from Python's built-in `Enum` base class for creating enumerated constants (see [here](https://docs.python.org/3/library/enum.html) and [here](https://docs.python.org/3/howto/enum.html#enum-basic-tutorial) for details). The returned class, e.g., `ResNet50_Weights`, is a subclass of `WeightsEnum` and represents the different pre-trained weights available for the given model. Each member of this enumeration is a unique instance of this class, representing a specific set of pre-trained weights.
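
To make the class/member/value distinction concrete without downloading anything, here is a small analogue built only from the standard library. `ToyWeights` and `ToyModel_Weights` are hypothetical stand-ins for torchvision's `Weights` class and a model-specific weights enum such as `ResNet50_Weights`; the URLs and metadata are made up, and only the structure (an `Enum` whose member values are dataclass instances) is meant to mirror torchvision.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict


@dataclass
class ToyWeights:
    """Hypothetical stand-in for torchvision.models._api.Weights."""
    url: str
    meta: Dict[str, Any]


class ToyModel_Weights(Enum):
    """Hypothetical stand-in for a model-specific weights enum."""
    # Each class attribute becomes an enumeration member whose value is a ToyWeights instance
    V1 = ToyWeights(url="https://example.com/toy-v1.pth", meta={"acc": 0.70})
    V2 = ToyWeights(url="https://example.com/toy-v2.pth", meta={"acc": 0.75})

    @property
    def url(self) -> str:
        # Convenience accessor that forwards to the wrapped value
        return self.value.url


member = ToyModel_Weights.V2          # attribute access returns a member
same_member = ToyModel_Weights["V2"]  # subscript access by name returns the same member

print(member is same_member)                     # members are singletons
print(isinstance(member, ToyModel_Weights))      # members are instances of the enum class
print(member.name, type(member.value).__name__)  # the name is a str, the value a ToyWeights
print(member.url)                                # forwarded from member.value.url
```

The `url` property mirrors how torchvision's `WeightsEnum` exposes fields of its value directly on each member, which is why both `weights.url` and `weights.value.url` work on real weights.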

In [None]:
weights_enum = models.get_model_weights("resnet50")
print(weights_enum)

We can list all available weights for a given model as follows:

In [None]:
print("AVAILABLE RESNET-50 WEIGHTS:")
for weights in weights_enum:
    print(f"{weights.name}: {weights}")

`weights_enum` is an enumeration with one member for each set of pre-trained weights available for the model. Each **member** is also an **attribute** of `weights_enum`, so we can access it with dot notation:

In [None]:
# ImageNet weights (old version)
print(weights_enum.IMAGENET1K_V1)

# ImageNet weights (new version)
print(weights_enum.IMAGENET1K_V2)

We can also access a member by its name using subscript notation:

In [None]:
resnet50_weights_v2 = weights_enum["IMAGENET1K_V2"]

Each member has a `name` and a `value` associated with it. The `name` of a member is the string we just used to look up the member in the previous cell:

In [None]:
print(resnet50_weights_v2.name == "IMAGENET1K_V2")

As stated earlier, `weights_enum` is the `WeightsEnum` **class** associated with a given model, and its members/attributes are  
**instances** of that class.

In [None]:
print(type(resnet50_weights_v2))
isinstance(resnet50_weights_v2, weights_enum)  # members are instances of the WeightsEnum class

The **value** of each member is an instance of the `Weights` class defined in `torchvision.models._api`:

In [None]:
type(resnet50_weights_v2.value)

**To sum up**:

* The `get_model_weights()` function returns the weights enum class associated with the given model (e.g., `ResNet50_Weights`). This class is a subclass of `WeightsEnum`, which in turn inherits from `enum.Enum`, and it is an **enumeration** of all the pre-trained weights available for the model.
* The **attributes** of that class are **enumeration members**, and are functionally constants.
* Each member has a **name** and a **value** associated with it.
* The **value** of each member is an instance of the `Weights` class defined in `torchvision.models._api`.

In [None]:
weights_enum = models.get_model_weights("resnet50")  # weights enum class associated with ResNet-50 (enumeration)
resnet50_weights_v2 = weights_enum.IMAGENET1K_V2     # member of the enumeration
print(type(resnet50_weights_v2.value))               # the actual weights (instance of class `Weights`)

Finally, it is also possible to access a **particular member** of a model's weights enum class directly with the `get_weight()` function, which takes the fully qualified name of the member as a string (`"<EnumClassName>.<MEMBER_NAME>"`):

In [None]:
models.get_model_weights("resnet50")["IMAGENET1K_V2"] == models.get_weight("ResNet50_Weights.IMAGENET1K_V2")
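
Conceptually, `get_weight()` splits the string at the dot, looks up the enum class by name among torchvision's model modules, and then indexes that class with the member name. The sketch below mimics this lookup with a hypothetical `REGISTRY` dict and a hypothetical toy enum `ToyNet_Weights`; it illustrates the idea and is not torchvision's code.

```python
from enum import Enum


class ToyNet_Weights(Enum):
    """Hypothetical weights enum; values are plain URL strings for brevity."""
    IMAGENET1K_V1 = "https://example.com/toynet-v1.pth"
    IMAGENET1K_V2 = "https://example.com/toynet-v2.pth"
    DEFAULT = IMAGENET1K_V2  # alias of IMAGENET1K_V2


# Hypothetical registry mapping enum class names to enum classes
REGISTRY = {cls.__name__: cls for cls in [ToyNet_Weights]}


def toy_get_weight(name: str) -> Enum:
    """Resolve a string such as 'ToyNet_Weights.IMAGENET1K_V2' to an enum member."""
    try:
        enum_name, member_name = name.split(".")
    except ValueError:
        raise ValueError(f"Invalid weight name: {name!r}") from None
    if enum_name not in REGISTRY:
        raise ValueError(f"Unknown weights enum: {enum_name!r}")
    # Enum.__getitem__ looks members (and aliases) up by name
    return REGISTRY[enum_name][member_name]


print(toy_get_weight("ToyNet_Weights.IMAGENET1K_V2") is ToyNet_Weights.IMAGENET1K_V2)
print(toy_get_weight("ToyNet_Weights.DEFAULT").name)  # the alias resolves to the canonical member
```

Because the final step is an ordinary enum lookup by name, aliases work too: the real function also accepts strings such as `"ResNet50_Weights.DEFAULT"`.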

All available pre-trained weights are listed [here](https://pytorch.org/vision/stable/models.html#table-of-all-available-classification-weights).


### Working with Pre-Trained Weights

The value of each enum member is an instance of the `Weights` class introduced in `torchvision.models._api` (see [here](https://pytorch.org/vision/main/_modules/torchvision/models/_api.html)). The enum members expose the fields of this value directly, so the following useful properties are available on each member:

* `.url`: Returns the **URL** from which the pre-trained weights can be downloaded.
* `.meta`: Returns a `Dict[str, Any]` containing useful metadata about the pre-trained weights, such as the **categories** (of the classification task), the **number of parameters**, and the **training recipe**.
* `.transforms`: Returns a callable that, when called, builds the **preprocessing transforms** to be used with the pre-trained weights (hence `.transforms()` in the code below).

In [None]:
# Get the ImageNet-1K weights (an enum member) for VGG11
vgg11_weights = models.get_weight("VGG11_Weights.IMAGENET1K_V1")

# URL to download weights
print(f"URL:\n{vgg11_weights.url}\n")

# Keys of the dictionary returned by `.meta`
print("KEYS IN META DICT:")
for k in vgg11_weights.meta:
    print(k)

# File size of model weights
print(f"\nFILE SIZE:\n{vgg11_weights.meta['_file_size']} MB")

# Link to training recipe extracted from dict returned by `.meta`
print(f"\nTRAINING RECIPE:\n{vgg11_weights.meta['recipe']}")

# Preprocessing transforms
print(f"\nPREPROCESSING TRANSFORMS:\n{vgg11_weights.transforms()}")

**NOTE**: The **_V2** weights improve upon the results of the original paper by using TorchVision’s [new training recipe](https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/).

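To obtain the checkpoint storing the pre-trained weights, we can use the `get_state_dict()` method. This **downloads** the checkpoint (or reuses a locally cached copy) and **loads** the state dictionary. Passing `progress=True` displays a download progress bar; some older torchvision releases require this argument.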

In [None]:
vgg11_weights_dict = vgg11_weights.get_state_dict(progress=True)

Let's take a look at the keys of the ordered dict `vgg11_weights_dict`:

In [None]:
for k in vgg11_weights_dict:
    print(k)

To illustrate how to use these weights, let's instantiate a VGG11 model without pre-trained weights:

In [None]:
vgg11 = models.vgg11()
print(vgg11)

As we can see, each key in `vgg11_weights_dict` corresponds to a parameter tensor (a weight or a bias) of a VGG11 layer with trainable parameters. We can also list **only** the **trainable parameters** of the model directly:

In [None]:
for name, param in vgg11.named_parameters():
    print(name)

Since we didn't specify any weights when initializing the network, its weights were initialized randomly: the shapes match those in the checkpoint, but the values do not.

In [None]:
print(
    "Shape identical: "
    f"{vgg11.state_dict()['features.0.weight'].size() == vgg11_weights_dict['features.0.weight'].size()}"
)
print(
    "Weights identical: "
    f"{torch.equal(vgg11.state_dict()['features.0.weight'], vgg11_weights_dict['features.0.weight'])}"
)

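Assigning pre-trained weights to **individual layers** is easy. Copying the tensor in place under `torch.no_grad()` is the idiomatic way to do it: unlike assigning to `.data`, it does not make the layer share memory with the checkpoint tensor, so later training leaves `vgg11_weights_dict` untouched.

In [None]:
# Copy the weights into the first convolutional layer in place (no autograd tracking)
with torch.no_grad():
    vgg11.features[0].weight.copy_(vgg11_weights_dict["features.0.weight"])

# Check whether the weights have been assigned successfully
print(
    "Weights identical: "
    f"{torch.equal(vgg11.state_dict()['features.0.weight'], vgg11_weights_dict['features.0.weight'])}"
)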