## Transfer Learning with PyTorch

This tutorial shows how to perform transfer learning with PyTorch. It follows the official PyTorch tutorial:

- [1] http://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time, copy
from pathlib import Path

# Enable matplotlib interactive mode so figures are drawn as they are created
# rather than only when plt.show() is called.
plt.ion()

### Load the Data

We want to classify images of bees and ants. We will be training on 120 images.

In [None]:
# Data augmentation and normalization for training. Validation images are only
# resized, center-cropped and normalized, so their preprocessing is deterministic.
# This uses the transforms from the torchvision module.
PHASES = ('train', 'val')

# Per-channel normalization statistics (the standard ImageNet mean/std that
# torchvision's pretrained models are documented to expect).
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

BATCH_SIZE = 4
NUM_WORKERS = 4

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]),
}

# Use torchvision's ImageFolder to load <data_dir>/train and <data_dir>/val;
# each class is a sub-directory.
data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(str(Path(data_dir) / x), data_transforms[x])
                  for x in PHASES}

# Data loaders (batch size, shuffling and worker processes are set here).
# Only the training split is shuffled; shuffling validation data has no benefit.
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                              batch_size=BATCH_SIZE,
                                              shuffle=(x == 'train'),
                                              num_workers=NUM_WORKERS)
               for x in PHASES}

dataset_sizes = {x: len(image_datasets[x]) for x in PHASES}
class_names = image_datasets['train'].classes

# Whether to move the model and batches to the GPU is decided explicitly.
use_gpu = torch.cuda.is_available()

### Visualize Images

Let's visualize some of the training images.

In [None]:
def imshow(inp, title=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalized image tensor with matplotlib.

    Undoes ``transforms.Normalize(mean, std)``, clips the result to [0, 1],
    shows it, and pauses briefly so an interactive-mode figure gets redrawn.

    :param inp: channels-first image tensor (C, H, W), e.g. a single image or
        the output of ``torchvision.utils.make_grid``. It may live on the GPU
        or track gradients.
    :param title: optional plot title.
    :param mean: per-channel mean subtracted during normalization.
    :param std: per-channel std divided out during normalization.
    """
    # detach()/cpu() make .numpy() work for GPU tensors and tensors in a graph.
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))  # CHW -> HWC for matplotlib
    img = np.asarray(std) * img + np.asarray(mean)  # invert Normalize: x * std + mean
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.0001)  # give the GUI event loop a chance to draw the figure
    
# Pull one batch of (images, label indices) from the training loader.
inputs, classes = next(iter(dataloaders['train']))

# Tile the batch into a single image and title it with the class names.
out = torchvision.utils.make_grid(inputs)
batch_titles = [class_names[label] for label in classes]
imshow(out, title=batch_titles)

### Train the model

Train the model using learning-rate scheduling, keeping the weights of the model with the best validation accuracy.

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25,
                loaders=None, sizes=None, device=None):
    """Train a model and return it loaded with its best validation weights.

    Each epoch runs a 'train' phase (forward, backward, optimizer step) followed
    by a 'val' phase (forward only). Whenever a 'val' phase beats the best
    accuracy seen so far, the weights are snapshotted. The best snapshot is
    loaded back into ``model`` before returning.

    :param model: the ``nn.Module`` to train, already on ``device``.
    :param criterion: loss called as ``criterion(outputs, labels)``. It is
        assumed to average over the batch, like ``nn.CrossEntropyLoss``'s
        default; the epoch loss is reported per sample under that assumption.
    :param optimizer: optimizer over the parameters to be trained.
    :param scheduler: a LR scheduler object from torch.optim.lr_scheduler.
        It is stepped once per epoch, after that epoch's optimizer updates.
    :param num_epochs: number of train+val epochs to run.
    :param loaders: dict mapping 'train'/'val' to iterables of
        ``(inputs, labels)`` batches. Defaults to the module-level ``dataloaders``.
    :param sizes: dict mapping 'train'/'val' to the number of samples in that
        phase. Defaults to the module-level ``dataset_sizes``.
    :param device: device that batches are moved to. Defaults to CUDA when the
        module-level ``use_gpu`` is true, else CPU.
    :return: ``model`` (the same object) with the best validation weights loaded.
    """
    if loaders is None:
        loaders = dataloaders
    if sizes is None:
        sizes = dataset_sizes
    if device is None:
        device = torch.device('cuda' if use_gpu else 'cpu')

    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch %s/%s' % (epoch, num_epochs - 1))
        print('-' * 10)

        # Every epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            # train(False) == eval(): switches BatchNorm/Dropout to inference behavior.
            model.train(is_train)

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in loaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Only record the autograd graph when we are going to backprop.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    if is_train:
                        loss.backward()
                        optimizer.step()

                preds = outputs.argmax(dim=1)
                # Scale the batch-mean loss by batch size so the epoch total is a sum.
                running_loss += loss.item() * inputs.size(0)
                # .item() keeps the count a plain Python int (not a device tensor).
                running_corrects += (preds == labels).sum().item()

            if is_train:
                # Since PyTorch 1.1 the scheduler must be stepped *after* the
                # optimizer; stepping first skips the initial learning rate.
                scheduler.step()

            epoch_loss = running_loss / sizes[phase]
            epoch_acc = running_corrects / sizes[phase]

            print('%s Loss: %.4f Acc: %.4f' % (phase, epoch_loss, epoch_acc))

            # Deep-copy so later updates don't mutate the saved snapshot.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in %.0fm %.0fs' % (time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: %.4f' % best_acc)

    # Load the best model weights and return the trained model.
    model.load_state_dict(best_model_wts)
    return model

### Visualizing model predictions

In [None]:
def visualize_model(model, num_images=6):
    """Plot the model's predicted class for the first ``num_images`` validation images.

    Images are laid out in a two-column grid, each titled with its predicted
    class name. Inference runs in eval mode without gradient tracking. The
    model's previous train/eval mode is restored afterwards, even on error.

    :param model: classifier to evaluate (on the GPU when ``use_gpu`` is true).
    :param num_images: number of images to show. Odd values are supported;
        non-positive values show nothing.
    """
    if num_images <= 0:
        return

    was_training = model.training
    model.eval()
    # Round rows up so an odd image count still fits in the 2-column grid.
    n_rows = (num_images + 1) // 2
    images_so_far = 0
    plt.figure()

    try:
        with torch.no_grad():
            for inputs, _ in dataloaders['val']:
                if use_gpu:
                    inputs = inputs.cuda()
                preds = model(inputs).argmax(dim=1)
                cpu_inputs = inputs.cpu()  # imshow converts to numpy, so move once per batch
                for j in range(inputs.size(0)):
                    images_so_far += 1
                    ax = plt.subplot(n_rows, 2, images_so_far)
                    ax.axis('off')
                    ax.set_title('predicted: %s' % class_names[preds[j].item()])
                    imshow(cpu_inputs[j])

                    if images_so_far == num_images:
                        return
    finally:
        model.train(was_training)

### Finetune the convnet

Load a pretrained model and replace its final fully connected layer with a new one sized for our classes.

In [None]:
# Pretrained ResNet-18 whose classification head is replaced for our dataset.
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# New head: same input features as the old fc layer, one output per class
# (derived from the dataset instead of hard-coding 2).
model_ft.fc = nn.Linear(num_ftrs, len(class_names))

if use_gpu:
    model_ft = model_ft.cuda()

# Loss function comes from the torch.nn module.
criterion = nn.CrossEntropyLoss()

# All parameters of the network are optimized (full fine-tuning).
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Scheduler to decay the learning rate by a factor of 0.1 every 7 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

### Train the model

Note that we let the gradients propagate through the entire ResNet model, so all of its weights are fine-tuned.

In [None]:
# Fine-tune every layer of the pretrained network for 25 epochs,
# then plot predictions on a few validation images.
model_ft = train_model(model=model_ft,
                       criterion=criterion,
                       optimizer=optimizer_ft,
                       scheduler=exp_lr_scheduler,
                       num_epochs=25)
visualize_model(model_ft, num_images=6)

### Conv-net as a fixed feature extractor

In the previous training, we let the gradients flow back through the entire feature extractor. That is often unnecessary, so let's instead freeze the convolutional base and train only the new final layer. This will increase training speed.

In [None]:
# Pretrained ResNet-18 used as a fixed feature extractor.
model_conv = models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False  # freeze the pretrained weights

# Parameters of newly constructed layers have requires_grad=True by default,
# so only this new head will be trained. One output per class.
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, len(class_names))

if use_gpu:
    model_conv = model_conv.cuda()

criterion = nn.CrossEntropyLoss()

# Only the parameters of the final layer are passed to the optimizer.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay the learning rate by a factor of 0.1 every 7 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

In [None]:
# Train only the new classification head on top of the frozen backbone,
# then plot predictions on a few validation images.
model_conv = train_model(model=model_conv,
                         criterion=criterion,
                         optimizer=optimizer_conv,
                         scheduler=exp_lr_scheduler,
                         num_epochs=25)

visualize_model(model_conv, num_images=6)

# Turn interactive mode back off, then display all open figures.
plt.ioff()
plt.show()