# Fine-tune


In [None]:
import sys
sys.path.append('..')

import numpy as np

import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader

from torchvision import models
from torchvision import transforms as tfs
from torchvision.datasets import ImageFolder

### Visualize images

In [None]:
import os
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
root_path = '/home/yang/dataset/imagenet/fruits/train/'

# (class folder, number of sample images to take from it)
samples_per_class = [('apple', 4), ('avocado', 4), ('banana', 4), ('kiwi', 4), ('watermelon', 5)]
im_list = []
for class_name, count in samples_per_class:
    class_dir = os.path.join(root_path, class_name)
    im_list += [os.path.join(class_dir, fname) for fname in os.listdir(class_dir)[:count]]

nrows = 3
ncols = 3
figsize = (8, 8)
# squeeze=False keeps `figs` 2-D even if nrows or ncols is 1.
_, figs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
for i in range(nrows):
    for j in range(ncols):
        ax = figs[i][j]
        # Row-major position: row i starts at i * ncols. The former nrows*i+j
        # was only correct because the grid happens to be square.
        idx = ncols * i + j
        # Leave a slot empty instead of raising IndexError when there are
        # fewer images than grid cells.
        if idx < len(im_list):
            ax.imshow(Image.open(im_list[idx]))
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
plt.show()

### Preprocess data

In [None]:
# Standard ImageNet per-channel mean/std, shared by every pipeline below
# (Normalize holds no state, so one instance can be reused).
normalize = tfs.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Training: random scale/aspect crop to 224 plus horizontal flips.
train_tf = tfs.Compose([
    tfs.RandomResizedCrop(224),
    tfs.RandomHorizontalFlip(),
    tfs.ToTensor(),
    normalize,
])

# Validation: deterministic 224x224 centre crop.
# NOTE(review): there is no Resize before the crop, so it is taken at the
# image's original resolution; the usual ImageNet recipe resizes to 256 first.
# Confirm this is intended.
valid_tf = tfs.Compose([
    tfs.CenterCrop(224),
    tfs.ToTensor(),
    normalize,
])

# Validation variant: random resized crop applied directly (no prior resize).
valid_tf_random = tfs.Compose([
    tfs.RandomResizedCrop(224),
    tfs.ToTensor(),
    normalize,
])

# Validation variant: resize to 256, then a random resized crop to 224.
valid_tf_resize = tfs.Compose([
    tfs.Resize(256),
    tfs.RandomResizedCrop(224),
    tfs.ToTensor(),
    normalize,
])

### Define dataset

In [None]:
# All splits live under one dataset root.
data_root = '/home/yang/dataset/imagenet/fruits/'
val_dir = data_root + 'val100/'

train_set = ImageFolder(data_root + 'train/', train_tf)
# The same validation images under three different preprocessing pipelines.
valid_set = ImageFolder(val_dir, valid_tf)
valid_set_random = ImageFolder(val_dir, valid_tf_random)
valid_set_resize = ImageFolder(val_dir, valid_tf_resize)

train_data = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
valid_data = DataLoader(valid_set, batch_size=8, shuffle=False, num_workers=2)
# Loaders for valid_set_random / valid_set_resize are currently not built.

In [None]:
class MobileNet(nn.Module):
    """MobileNet (v1-style) classifier built from depthwise-separable convolutions.

    The layout (``self.model`` feature extractor followed by ``self.fc``) yields
    the same ``state_dict`` keys as before, so existing checkpoints still load.

    Args:
        num_classes: output size of the final linear layer (default 1000).
    """

    def __init__(self, num_classes=1000):
        super(MobileNet, self).__init__()

        # Standard convolution block: CONV_3x3 --> BN --> ReLU
        def conv_bn(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True)
            )

        # Depthwise-separable block:
        # depthwise CONV_3x3 (groups=inp) --> BN --> ReLU --> pointwise CONV_1x1 --> BN --> ReLU
        def conv_dw(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),

                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),
            )

        self.model = nn.Sequential(
            conv_bn(  3,  32, 2),
            conv_dw( 32,  64, 1),
            conv_dw( 64, 128, 2),
            conv_dw(128, 128, 1),
            conv_dw(128, 256, 2),
            conv_dw(256, 256, 1),
            conv_dw(256, 512, 2),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 1024, 2),
            conv_dw(1024, 1024, 1),
            # Global average pooling to 1x1. For 224x224 inputs the feature map
            # is 7x7, so this equals the former AvgPool2d(7), but it also
            # handles other input sizes. It has no parameters, so the
            # state_dict is unchanged.
            nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.model(x)
        # Flatten per sample. The former view(-1, 1024) silently folded any
        # leftover spatial positions into the batch dimension.
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [None]:
# Build the network with freshly initialised weights; checkpoint weights are loaded further down.
model = MobileNet()

In [None]:
# Wrap for DataParallel execution; parameter names in state_dict() now carry
# a 'module.' prefix. Then move all parameters to the GPU (.cuda() returns the module).
model = torch.nn.DataParallel(model)
model = model.cuda()

In [None]:
# Checkpoint whose weights are loaded into the model before fine-tuning.
# NOTE(review): this name is spelled 'moblienet_...' while the save cell writes
# 'mobienet_...'; confirm which file is intended.
pretrained_path = 'moblienet_30e.pth.tar'
params = torch.load(pretrained_path)

### Method 1 - load directly

In [None]:
# Load the checkpoint as-is: strict loading requires its key names (including
# any 'module.' prefix) to match the wrapped model exactly.
model.load_state_dict(state_dict=params)

### Method 2 - remove prefix in params

**Save model state_dict into a variable**

In [None]:
# Current parameters/buffers of the DataParallel-wrapped model, keyed by name;
# the matching checkpoint tensors are merged into this dict below.
model_dict = model.state_dict()

**1. Filter out unnecessary keys**

In [None]:
def _match_key(key, target_keys, prefix='module.'):
    """Return the name in ``target_keys`` that corresponds to checkpoint ``key``.

    Tries the key unchanged, with ``prefix`` added and with ``prefix`` removed,
    so checkpoints saved with or without a DataParallel wrapper both match.
    Returns None when no variant is present.
    """
    candidates = [key, prefix + key]
    if key.startswith(prefix):
        candidates.append(key[len(prefix):])
    for cand in candidates:
        if cand in target_keys:
            return cand
    return None


# Some checkpoints wrap the weights as {'state_dict': ...}; unwrap if so.
source_params = params['state_dict'] if 'state_dict' in params else params

# Keep only entries the model actually has, renamed to the model's key names.
# Exact-match filtering alone silently drops *everything* when the only
# difference is a 'module.' prefix, which would make the load a no-op.
pretrained_dict = {}
for k, v in source_params.items():
    target = _match_key(k, model_dict)
    if target is not None:
        pretrained_dict[target] = v
print('matched {} / {} checkpoint entries'.format(len(pretrained_dict), len(source_params)))

**2. Overwrite entries in the existing state dict**

In [None]:
# Overwrite the model's entries with the matched checkpoint tensors; entries
# absent from the checkpoint keep their current values.
model_dict.update(pretrained_dict)

**3. Load the new state dict**

In [None]:
# Load the merged dict; it contains every key the model expects.
model.load_state_dict(state_dict=model_dict)

### Finetune model

In [None]:
# Replace the classifier head for the fruits task.
# The model is wrapped in DataParallel, so the head lives at model.module.fc.
# Assigning model.fc would only register an unused layer on the wrapper and
# leave the old 1000-way head in the forward pass (and in the optimizer below).
# MobileNet's head takes 1024 features, not 2048, so read in_features from the
# existing layer.
# Newly created layers have requires_grad=True by default. .cuda() puts the
# new head on the GPU like the rest of the already-moved model.
num_classes = len(train_set.classes)
model.module.fc = nn.Linear(model.module.fc.in_features, num_classes).cuda()

In [None]:
criterion = nn.CrossEntropyLoss()

# Set different learning rates for the convolutional backbone and the
# classifier head via SGD parameter groups.
# NOTE(review): fine-tuning usually gives the pretrained backbone the smaller
# rate and the head the larger one; this cell does the opposite. Confirm intent.
backbone_group = {'params': model.module.model.parameters(), 'lr': 1e-2}
head_group = {'params': model.module.fc.parameters(), 'lr': 1e-3}
optimizer = torch.optim.SGD([backbone_group, head_group], weight_decay=1e-4)

In [None]:
from utils import train, validate, validate_random, validate_resize

# For each epoch, call utils.train and then utils.validate on the centre-crop
# validation loader. validate_random / validate_resize belong to the
# alternative validation pipelines, whose loaders are currently not built.
epochs = 15
for epoch in range(epochs):
    train(model, train_data, epoch, optimizer, criterion)
    validate(model, valid_data, epoch, optimizer, criterion)

### Save model

In [None]:
# Save the weights as a bare state_dict. Because the model is wrapped in
# DataParallel, its keys carry the 'module.' prefix.
checkpoint_path = "mobienet_30e.pth.tar"
torch.save(model.state_dict(), checkpoint_path)

### Load saved model

Once the model has been saved, we can load it again without retraining.

In [None]:
# Rebuild the same architecture and DataParallel wrapper used for training so
# the saved parameter names (with their 'module.' prefix) line up.
mobilenet_model = MobileNet()
mobilenet_model = torch.nn.DataParallel(mobilenet_model).cuda()

# The save cell writes a bare state_dict, whereas some training scripts wrap it
# as {'state_dict': ...}. Accept both; unconditionally indexing ['state_dict']
# raises KeyError on a bare state_dict.
checkpoint = torch.load('mobienet_30e.pth.tar')
params = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

# A fine-tuned checkpoint may have a different classifier size than the default
# 1000-way head. Resize the head to the checkpoint's shape so strict loading works.
fc_weight = params.get('module.fc.weight')
if fc_weight is not None and fc_weight.size() != mobilenet_model.module.fc.weight.size():
    out_features, in_features = fc_weight.size()
    mobilenet_model.module.fc = nn.Linear(in_features, out_features).cuda()

mobilenet_model.load_state_dict(params)

## Validation

In [None]:
# Switch to inference behaviour (BatchNorm uses its running statistics).
# eval() acts in place and returns the module itself, so no rebinding is needed.
model.eval()

In [None]:
# A single validation image to run through the model.
sample_path = '/home/yang/dataset/imagenet/fruits/val/kiwi/756419172.jpg'
im1 = Image.open(sample_path)
im1  # last expression in the cell, so the notebook displays the image

## Result

In [None]:
# Convert to RGB before the validation transform, as ImageFolder's default
# loader does. Otherwise grayscale/RGBA/CMYK files would produce a tensor whose
# channel count does not match the 3-channel normalisation and first conv.
im = valid_tf(im1.convert('RGB'))
# Add a batch dimension. volatile=True disables autograd bookkeeping for
# inference (pre-0.4 Variable API, consistent with the rest of this file).
out = model(Variable(im.unsqueeze(0), volatile=True).cuda())
# Index of the highest-scoring class. int() makes list indexing work whether
# .data[0] yields a Python number or a 0-dim tensor.
pred_label = int(out.max(1)[1].data[0])
print('predict label: {}'.format(train_set.classes[pred_label]))