# Transfer Learning in PyTorch

#### References
* https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
* https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
* https://github.com/mortezamg63/Accessing-and-modifying-different-layers-of-a-pretrained-model-in-pytorch
* https://discuss.pytorch.org/t/insert-new-layer-in-the-middle-of-a-pre-trained-model/12414/4
* https://forums.fast.ai/t/pytorch-best-way-to-get-at-intermediate-layers-in-vgg-and-resnet/5707/2
* https://discuss.pytorch.org/t/extracting-and-using-features-from-a-pretrained-model/20723/6

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
from torch.hub import load_state_dict_from_url
import time
import os
import copy
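import utils_resnet_TL as utils_resnet  # local helper module, not part of torchvision
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)
input_size = 224  # input resolution used with the ImageNet-pretrained ResNets
num_classes = 2   # number of classes in the target dataset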

PyTorch Version:  1.1.0
Torchvision Version:  0.3.0


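#### Get the definition of a module
Here we want to inspect how `resnet18` is implemented. The relevant part is the `forward` method of torchvision's `ResNet` class, shown as it appears in torchvision 0.3.0 (the version used in this notebook):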
``` python
def forward(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.avgpool(x)
    x = x.reshape(x.size(0), -1)
    x = self.fc(x)

    return x
```

In [3]:
import inspect
resnet_source = inspect.getfile(models.resnet18)
print('Resnet Source:', resnet_source)
#with open(resnet_source, 'r') as f:
#    print(f.read())

Resnet Source: /mnt/anaconda3/lib/python3.7/site-packages/torchvision/models/resnet.py


In [4]:
def set_parameter_requires_grad(model, feature_extracting=True):
    if feature_extracting:
        # Freeze every parameter: no gradients are computed for it
        for param in model.parameters():
            param.requires_grad = False
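Setting `requires_grad = False` stops gradient updates, but it does not stop BatchNorm layers from updating their running mean and variance while the model is in training mode. If the frozen backbone should stay exactly as loaded, one option is to also switch its BatchNorm layers to eval mode after each call to `model.train()`. The sketch below uses a small stand-in network; the helper name `freeze_batchnorm_stats` is hypothetical.

```python
import torch
import torch.nn as nn

BATCHNORM_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def freeze_batchnorm_stats(model):
    """Hypothetical helper: put every BatchNorm layer in eval mode.

    model.train() sets .training back to True, so call this again after it.
    """
    for module in model.modules():
        if isinstance(module, BATCHNORM_TYPES):
            module.eval()

# Small stand-in for a pretrained backbone.
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
for param in net.parameters():
    param.requires_grad = False  # same effect as set_parameter_requires_grad

net.train()                  # training mode for the whole model...
freeze_batchnorm_stats(net)  # ...except the BatchNorm layers

before = net[1].running_mean.clone()
with torch.no_grad():
    net(torch.randn(4, 3, 16, 16))
after = net[1].running_mean
print(torch.equal(before, after))  # True: running statistics were not updated
```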

#### Load Pretrained Resnet18

In [5]:
model_ft = models.resnet18(pretrained=True)
set_parameter_requires_grad(model_ft, feature_extracting=True)
# Print model structure
#print(model_ft)

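#### Replace the final layer
The simplest way to do transfer learning is to replace the final (classification) layer of a model and retrain it. Because the new `nn.Linear` is created after the parameters were frozen, it is the only part of the network that will be trained.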

In [6]:
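# Number of input features of the original fc layer
num_ftrs = model_ft.fc.in_features
# New classification layer with num_classes outputs (requires_grad=True by default)
model_ft.fc = nn.Linear(num_ftrs, num_classes)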

### A more advanced way of changing the base model
Suppose you need to change something in the middle of the model, or you want the model to return additional outputs (which requires changing the `forward` definition). The approach above only swaps a module; it cannot change the data flow in `forward`.

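The following example prepares the ResNet model to be used as a feature extractor (backbone) for architectures such as RetinaNet or ChangeNet, which need the intermediate feature maps rather than the class scores.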

In [7]:
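class ResnetFeatures(models.ResNet):
    """ResNet whose forward returns the outputs of the four residual stages."""
    def __init__(self, block, layers, num_classes=1000):
        super(ResnetFeatures, self).__init__(block, layers, num_classes)

    def forward(self, x):
        # Stem: identical to the original ResNet.forward
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Keep the output of every residual stage
        layer1 = self.layer1(x)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        # avgpool and fc are no longer used, but they stay registered so that
        # the pretrained ImageNet weights still load without missing keys
        return layer1, layer2, layer3, layer4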



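#### Redefine a helper function
The local module `utils_resnet_TL` provides a `resnet50` helper that takes the model class as its first argument. Here it instantiates the custom ResNet-50 with pretrained weights, and the cell prints the shapes of the four feature maps.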

In [8]:
model_pretrained = utils_resnet.resnet50(ResnetFeatures, pretrained=True)
x = torch.randn(1, 3, 224, 224)
output = model_pretrained(x)
print('Layer 1:', output[0].shape)
print('Layer 2:', output[1].shape)
print('Layer 3:', output[2].shape)
print('Layer 4:', output[3].shape)

Parameter containing:
tensor([[[[-7.9648e-03, -3.7614e-02,  6.2047e-03,  ..., -1.2469e-02,
            3.7047e-02,  6.4989e-04],
          [ 3.5415e-02,  3.5940e-02, -3.2289e-02,  ...,  3.4159e-03,
            1.3469e-02, -1.0689e-02],
          [ 9.8453e-05,  2.0491e-02, -5.7877e-03,  ...,  1.6177e-02,
           -2.8275e-03, -7.0718e-03],
          ...,
          [ 1.3436e-02, -5.3551e-02,  3.7207e-03,  ..., -1.1829e-02,
           -2.5833e-02,  3.9328e-02],
          [-3.8391e-03,  2.4145e-02, -1.8424e-02,  ...,  5.2373e-02,
           -4.0217e-02,  2.4476e-02],
          [-7.8571e-03,  1.9448e-02, -3.0415e-02,  ..., -3.6407e-02,
            1.2486e-02,  7.0944e-04]],

         [[ 6.4988e-03,  1.0799e-02,  1.0879e-02,  ..., -1.1941e-02,
            2.8047e-02,  1.3347e-02],
          [-3.4459e-02, -1.8909e-02, -3.3859e-03,  ...,  5.7913e-05,
           -5.4104e-04, -1.8716e-02],
          [ 1.5679e-02, -1.0635e-02,  5.9927e-02,  ..., -1.9105e-02,
            1.5371e-02, -1.5712e-02]