### Pretrained models
 + https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

### BUGFIX
 + Executing this notebook requires fixing a bug in PILLOW: https://github.com/python-pillow/Pillow/pull/3771

In [None]:
import sys
import os
import logging
import numpy as np
import matplotlib.pyplot as plt

import torch
from torchvision import datasets, models, transforms
from torchsummary import summary
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F
from torch import optim

sys.path.append('..')
import utils

logging.basicConfig(level=logging.INFO)

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torch.device reports GPUs with type 'cuda' (there is no 'gpu' type), so the
# comparison must be against 'cuda' for the explicit device selection to run.
if DEVICE.type == 'cuda':
    torch.cuda.set_device(0)

### Get data

In [None]:
# Call the project helper with '../data' as its argument. Presumably it
# downloads/unpacks the hymenoptera images under that directory (a later cell
# reads '../data/hymenoptera_data') — confirm in utils.get_hymenoptera.
utils.get_hymenoptera('../data')

### Get model.
 + download pre-trained weights if not available (in `~/.torch/models`)
 + restrict the parameters to use for training

In [None]:
# ResNet-18 initialised with pretrained weights.
model = models.resnet18(pretrained=True)

# With freeze_model set, gradient tracking is switched off for every
# parameter currently in the network.
freeze_model = True
if freeze_model:
    for param in model.parameters():
        param.requires_grad_(False)

# Swap the final fully-connected layer for a fresh two-output layer.
# A newly constructed module has requires_grad=True on its parameters, so
# when freeze_model is True this layer is the only part that gets trained.
numb_in_features = model.fc.in_features
numb_out_features = 2
model.fc = torch.nn.Linear(in_features=numb_in_features,
                           out_features=numb_out_features)

### Display summary with dimension propagation.
 + the first dimension (`-1`) stands for the batch size
 + the second dimension is the number of channels
 + the last two dimensions are the image size

In [None]:
# Disabled by default; change the condition to True to print the layer summary.
# The input size (3, 224, 224) matches the CenterCrop(224) output of the data
# transforms, so the reported dimensions are those the model is actually fed.
if False:
    summary(model, (3, 224, 224))

### Get a dictionary with model weights (can be used to save the model)

In [None]:
# state_dict() maps parameter/buffer names to their tensors; this is the
# object typically handed to torch.save when persisting the model.
model_weights = model.state_dict()

### Transformations to be applied

In [None]:
# Both splits use the same preprocessing steps: Resize(256), then a central
# 224x224 crop, then conversion to a tensor. A separate pipeline object is
# built for each split.
data_transforms = {
    split: transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    for split in ('train', 'val')
}

### Create a dataloader

In [None]:
# Root folder holding one sub-directory per split ('train' and 'val').
data_dir = '../data/hymenoptera_data'

# One ImageFolder per split, each paired with that split's transform pipeline.
image_datasets = {
    'train': datasets.ImageFolder(os.path.join(data_dir, 'train'),
                                  data_transforms['train']),
    'val': datasets.ImageFolder(os.path.join(data_dir, 'val'),
                                data_transforms['val']),
}

# Shuffled mini-batches of four samples for both splits.
dataloaders = {
    'train': DataLoader(image_datasets['train'], batch_size=4, shuffle=True),
    'val': DataLoader(image_datasets['val'], batch_size=4, shuffle=True),
}

# Class labels as reported by the training ImageFolder.
class_names = image_datasets['train'].classes

### Customize the dataloader

In [None]:
def preprocess(x, y):
    """Move a batch's inputs and targets to DEVICE and return them as a pair."""
    return x.to(DEVICE), y.to(DEVICE)


# Replace both loaders with the project's wrapper, handing it preprocess
# (presumably applied to each batch — confirm in utils.CustomSizeDataLoader).
dataloaders.update({
    phase: utils.CustomSizeDataLoader(dataloaders[phase], preprocess)
    for phase in ('train', 'val')
})

### Display several images

In [None]:
# Show the first image of up to rows*cols training batches in a grid.
if True:
    rows, cols = 5, 5
    width, height = 3*cols, 3*rows

    axes = plt.subplots(rows, cols, figsize=(width, height))[1].flatten()
    # Hide ticks/frames on every cell up front so that cells left empty
    # (fewer batches than grid cells) do not show stray axes.
    for ax in axes:
        ax.axis('off')
    for ax, (x, y) in zip(axes, dataloaders['train']):
        # ToTensor yields channel-first (C, H, W) images, while imshow expects
        # (H, W, C). permute(1, 2, 0) moves the channel axis last without
        # swapping height and width; transpose(2, 0) would give (W, H, C) and
        # draw the image mirrored across its diagonal.
        # No cmap is passed: imshow ignores colormaps for RGB data.
        ax.imshow(x[0].permute(1, 2, 0).cpu().numpy())

In [None]:
# Put the network on the selected device (Module.to returns the module
# itself), then pick cross-entropy as the loss and SGD with momentum over all
# of the model's parameters as the optimizer.
model = model.to(DEVICE)
loss_func = F.cross_entropy
opt = optim.SGD(model.parameters(), momentum=0.9, lr=1e-3)

In [None]:
# Run the project's training helper with the model, loss, optimizer and the
# train/val loaders. The leading 2 is presumably the number of epochs —
# confirm in utils.fit.
utils.fit(2, model, loss_func, opt, dataloaders['train'], dataloaders['val'])