In [1]:
import os

In [2]:
# Step one directory up — presumably to the project root, so that relative
# paths used later in this notebook resolve (TODO confirm).
os.chdir(path="../")

In [3]:
import torch
import torchvision
import torchvision.models as models
import torch.nn as nn

In [4]:
def get_model_transfer_learning(model_name="resnet18", n_classes=50):
    """Return a pretrained torchvision model prepared for transfer learning.

    The model is built with pretrained weights, every existing parameter is
    frozen, and the final ``nn.Linear`` classification layer is replaced by a
    fresh ``nn.Linear`` with ``n_classes`` outputs.

    Args:
        model_name: Name of a model builder function in ``torchvision.models``
            (e.g. ``"resnet18"``, ``"googlenet"``, ``"mobilenet_v3_large"``).
        n_classes: Number of output classes of the new head.

    Returns:
        The modified ``nn.Module``. Only the parameters of the new head have
        ``requires_grad=True``.

    Raises:
        ValueError: If ``model_name`` is not a model builder function in
            ``torchvision.models``, or if no final ``nn.Linear`` head can be
            located on the built model.
    """
    # Look up the builder function. A plain hasattr() check would also accept
    # submodules (e.g. "resnet") and classes (e.g. "ResNet"), which then fail
    # with an unrelated error when called with ``pretrained=True``.
    builder = getattr(models, model_name, None)
    if not callable(builder) or isinstance(builder, type):
        torchvision_major_minor = ".".join(torchvision.__version__.split(".")[:2])
        raise ValueError(f"Model {model_name} is not known. List of available models: "
                         f"https://pytorch.org/vision/{torchvision_major_minor}/models.html")

    # NOTE: ``pretrained=True`` is deprecated since torchvision 0.13 in favour
    # of ``weights=...``; it is kept so that older torchvision versions work.
    model_transfer = builder(pretrained=True)

    # Freeze the whole pretrained network. The new head is created afterwards,
    # so its freshly initialised parameters remain trainable.
    for param in model_transfer.parameters():
        param.requires_grad = False

    _replace_classification_head(model_transfer, n_classes)

    return model_transfer


def _replace_classification_head(model, n_classes):
    """Swap the model's final ``nn.Linear`` for a new one with ``n_classes`` outputs.

    Supported layouts, checked in order:
      * ``model.fc`` is an ``nn.Linear`` (e.g. ResNet, GoogLeNet);
      * ``model.classifier`` is an ``nn.Linear``;
      * ``model.classifier`` is an ``nn.Sequential``: its last ``nn.Linear``
        is replaced (e.g. MobileNetV3, whose last element is that layer).

    Raises:
        ValueError: If none of the layouts above matches.
    """
    fc = getattr(model, "fc", None)
    if isinstance(fc, nn.Linear):
        model.fc = nn.Linear(fc.in_features, n_classes)
        return

    classifier = getattr(model, "classifier", None)
    if isinstance(classifier, nn.Linear):
        model.classifier = nn.Linear(classifier.in_features, n_classes)
        return

    if isinstance(classifier, nn.Sequential):
        # Walk backwards so trailing non-linear layers (dropout, activations)
        # are skipped and the output projection is the one replaced.
        for idx in reversed(range(len(classifier))):
            layer = classifier[idx]
            if isinstance(layer, nn.Linear):
                classifier[idx] = nn.Linear(layer.in_features, n_classes)
                return

    raise ValueError(
        f"Cannot locate a final nn.Linear classification layer on "
        f"{type(model).__name__}; replace its head manually."
    )