In [2]:
import torch
import torchvision
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

##### Refer to the <a href="https://github.com/mgupta70/code-blocks/blob/main/Pytorch%20primer.ipynb">PyTorch Primer</a> for the full pipeline. The only modification we need to make is shown below.

##### To define custom augmentations, refer to <a href="https://github.com/mgupta70/code-blocks/blob/main/DL_Albumentations_Image_Augmentations.ipynb">DL_Albumentations_Image_Augmentations</a>.

##### 1. Fine-tuning the whole network - ideal when the dataset is large

In [4]:
# Number of target classes; used as the output width of the replacement
# `fc` (final fully connected) layer in the cells below.
num_classes = 10

In [6]:
# Load a pretrained ResNet-18 and swap its classifier for a new linear head
# with `num_classes` outputs. Nothing is frozen here, so every parameter of
# the network remains trainable.
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in
# favour of the `weights=` argument (e.g. `ResNet18_Weights.IMAGENET1K_V1`);
# confirm the installed torchvision version before switching.
model = torchvision.models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(
    in_features=model.fc.in_features,  # keep the backbone's feature width
    out_features=num_classes,
)
model = model.to(device)

# Total number of scalar values across all parameters that require gradients.
num_trainable_params = sum(
    weight.numel()
    for weight in model.parameters()
    if weight.requires_grad
)
print('total params for finetuning entire network: ', num_trainable_params)

total params for finetuning entire network:  11181642


##### 2. Transfer learning - train only the new final layer (the head) and freeze all other layers - suitable when the dataset is small.

In [7]:
# Load a pretrained ResNet-18 to use as a frozen feature extractor.
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in
# favour of the `weights=` argument (e.g. `ResNet18_Weights.IMAGENET1K_V1`);
# confirm the installed torchvision version before switching.
model = torchvision.models.resnet18(pretrained=True)

# Turn off gradient tracking for every existing parameter (the backbone).
# NOTE(review): this does not stop BatchNorm running statistics from updating
# while the model is in train() mode - check how the training loop handles it.
for param in model.parameters():
    param.requires_grad_(False)

# Attach a new classification head. It is created after the freeze above, so
# its weight and bias keep the default requires_grad=True and are the only
# trainable parameters left in the model.
model.fc = torch.nn.Linear(
    in_features=model.fc.in_features,
    out_features=num_classes,
)
model = model.to(device)

# Total number of scalar values across all parameters that require gradients.
num_trainable_params = sum(
    weight.numel()
    for weight in model.parameters()
    if weight.requires_grad
)
print('total params in Transfer Learning only the head: ', num_trainable_params)

total params in Transfer Learning only the head:  5130
