# Fine-tune a pre-trained network
We will fine-tune a pre-trained network (ResNet-50, loaded from Caffe weights) on the Caltech-101 dataset.

# 1. Load Caltech-101

In [None]:
import os, random

# Fix the seed so that the train/test split is the same on every run.
random.seed(20)

# 30 training samples per class and at most 20 testing samples per class
training_sample = 30
test_sample = 20

dataset_dir = os.path.join(os.getcwd(), 'Caltech101', '101_ObjectCategories')
# os.listdir returns entries in arbitrary, filesystem-dependent order. Sort so
# that (a) class indices are stable across machines and (b) the seeded shuffle
# below really yields the same split every run. Non-directory entries (e.g.
# stray metadata files) are skipped because every class must be a folder.
labels = sorted(
    name for name in os.listdir(dataset_dir)
    if os.path.isdir(os.path.join(dataset_dir, name))
)
train_x, train_y, test_x, test_y = [], [], [], []

total = 0
for c, category in enumerate(labels):
    category_dir = os.path.join(dataset_dir, category)
    files = sorted(os.listdir(category_dir))
    total += len(files)
    random.shuffle(files)
    # First `training_sample` shuffled files train, the next (up to)
    # `test_sample` files are held out for validation.
    train_files = files[:training_sample]
    test_files = files[training_sample:training_sample + test_sample]
    train_x += [os.path.join(category_dir, img) for img in train_files]
    train_y += [c] * len(train_files)
    test_x += [os.path.join(category_dir, img) for img in test_files]
    test_y += [c] * len(test_files)

print('Total images: {}'.format(total))
print('Train images: {}: {}'.format(len(train_x), len(train_y)))
print('Validation images: {}: {}'.format(len(test_x), len(test_y)))
# No explicit shuffle of the training list is needed here: the training
# SerialIterator is created with shuffle=True.

# 2. Prepare preprocessing and data augmentation

In [2]:
import cv2, chainer
import numpy as np
from skimage import io
from scipy.misc import imresize


# aspect ratio is kept after resizing
def resize_image(img, minimum_length=256):
    """Resize an image so that its shorter side equals ``minimum_length``.

    The aspect ratio is preserved: the longer side is scaled by the same
    factor and truncated to an int. Greyscale images are tiled to three
    channels and an alpha channel, if present, is dropped, so the result is
    always an H x W x 3 array.

    ``cv2.resize`` with bilinear interpolation is used instead of
    ``scipy.misc.imresize``, which was removed in SciPy 1.3.

    Args:
        img: image array, either H x W (greyscale) or H x W x C.
        minimum_length: target length of the shorter side, in pixels.

    Returns:
        The resized H x W x 3 image, with the same dtype as ``img``.
    """
    y, x = img.shape[:2]
    # Keep the aspect ratio; `sizes` is the output (height, width).
    if y <= x:
        scale = float(minimum_length) / y
        sizes = (minimum_length, int(scale * x))
    else:
        scale = float(minimum_length) / x
        sizes = (int(scale * y), minimum_length)
    if img.ndim == 2:
        # Greyscale picture: replicate the single channel three times.
        img = np.tile(img[:, :, np.newaxis], (1, 1, 3))
    elif img.shape[2] == 4:
        # RGBA picture: drop alpha so downstream code always sees 3 channels.
        # Copy to a contiguous buffer, which OpenCV's bindings expect.
        img = np.ascontiguousarray(img[:, :, :3])
    # cv2.resize takes dsize as (width, height), the reverse of `sizes`.
    return cv2.resize(img, (sizes[1], sizes[0]), interpolation=cv2.INTER_LINEAR)

    
def crop_center(img, sizes=(224, 224)):
    """Return the central ``sizes[0] x sizes[1]`` window of an H x W x C image.

    Along each axis the window spans from ``centre - ceil(n / 2)`` to
    ``centre + floor(n / 2)``, where ``centre`` is half the image length.
    The window is clipped at the image borders, so an image smaller than
    ``sizes`` yields a correspondingly smaller crop instead of raising.
    """
    height, width, _channels = img.shape
    crop_h, crop_w = sizes

    mid_h = int(height / 2)
    top = max(mid_h - int((crop_h + 1) / 2), 0)
    bottom = min(mid_h + int(crop_h / 2), height)

    mid_w = int(width / 2)
    left = max(mid_w - int((crop_w + 1) / 2), 0)
    right = min(mid_w + int(crop_w / 2), width)

    return img[top:bottom, left:right]


def crop_randomly(img, sizes=(224, 224)):
    """Return a uniformly random ``sizes[0] x sizes[1]`` window of an image.

    Every valid top-left offset, from 0 up to the last position where the
    window still fits, is drawn with ``np.random.randint``. An image whose
    size equals the crop size therefore returns the whole image.

    Args:
        img: image array of shape H x W (x C).
        sizes: (height, width) of the crop.

    Raises:
        ValueError: if the crop is larger than the image in either dimension.
    """
    y, x = img.shape[:2]
    length_y, length_x = sizes
    if length_y > y or length_x > x:
        raise ValueError(
            'crop size {} is larger than image size {}'.format(sizes, (y, x)))
    # randint's upper bound is exclusive, so offsets cover 0..(y - length_y).
    # Starting at 1 (as before) would never pick the top/left edge.
    start_y = np.random.randint(0, y - length_y + 1)
    start_x = np.random.randint(0, x - length_x + 1)
    return img[start_y:start_y + length_y, start_x:start_x + length_x]


class PreprocessedDataset(chainer.dataset.DatasetMixin):
    """Chainer dataset that loads, augments and normalises images on the fly.

    Each example is read from disk with ``io.imread`` and resized so that its
    shorter side lies in ``resize``. It is then cropped to ``crop_size``:
    a centre crop when ``test`` is True, otherwise a random crop plus an
    optional horizontal flip. The channels are reversed to BGR, the mean is
    subtracted, and the result is returned as a float32
    ``(channel, height, width)`` array together with its int32 label.

    Args:
        x: sequence of image file paths.
        y: sequence of integer labels aligned with ``x``.
        crop_size: (height, width) of the returned crop.
        resize: ``(low, high)`` range for the shorter side after resizing;
            ``high`` is exclusive, as in ``np.random.randint``.
        horizontal_flip: in training mode, mirror images with probability 0.5.
        test: evaluation mode. Uses a deterministic centre crop, no flip, and
            fixes the shorter side to ``resize[0]``.
        gpu: stored on the instance but not used by this class.
    """

    # Per-channel mean in BGR order, subtracted from every image. Presumably
    # the mean pixel used when the pre-trained Caffe weights were trained --
    # confirm against the model's source.
    _MEAN_BGR = np.array([103.063, 115.903, 123.152], dtype=np.float32)

    def __init__(self, x, y, crop_size=(224, 224), resize=(256, 512), horizontal_flip=True, test=False, gpu=-1):
        self.x, self.y = x, y
        self.crop_size = crop_size
        self.resize = resize
        self.horizontal_flip = horizontal_flip
        self.test = test
        self.gpu = gpu

    def __len__(self):
        return len(self.x)

    def get_example(self, i):
        """Return ``(image, label)`` for the ``i``-th sample."""
        img = io.imread(self.x[i])
        if self.test:
            # Evaluation must be deterministic: always use the lower bound.
            minimum_length = int(self.resize[0])
        else:
            # Scale augmentation: random shorter side in [low, high).
            minimum_length = int(np.random.randint(*self.resize))
        img = resize_image(img, minimum_length=minimum_length)
        if self.test:
            img = crop_center(img, sizes=self.crop_size)
        else:
            img = crop_randomly(img, sizes=self.crop_size)
            if self.horizontal_flip and np.random.rand() >= 0.5:
                # Horizontal flip with probability 0.5 (mirror the width axis).
                img = img[:, ::-1]
        # Reverse channel order to BGR, then subtract the mean in float32.
        img = img[:, :, ::-1].astype(np.float32) - self._MEAN_BGR
        # (height, width, channel) -> (channel, height, width)
        img = img.transpose((2, 0, 1))
        label = np.array(self.y[i], dtype=np.int32)
        return img, label

# 3. Define the model

In [3]:
import chainer
import chainer.functions as F
from chainer import initializers
import chainer.links as L


class FineTuneResNet(chainer.Chain):
    """Pre-trained ResNet backbone with a freshly initialised linear head.

    The backbone's ``pool5`` features are fed into ``L.Linear(2048, n_class)``.
    The 2048 input width is hard-coded and presumably matches the pool5 output
    of the ResNet depths ``ResNetLayers`` builds -- confirm if changing
    ``layer``. Calling the chain returns the softmax cross-entropy loss and
    reports ``loss`` and ``accuracy``.

    Args:
        path: pre-trained weights passed to ``ResNetLayers`` (the notebook
            passes a ``.caffemodel`` file).
        layer: ResNet depth passed to ``ResNetLayers`` (e.g. 50).
        n_class: number of output classes. Defaults to 102, presumably the
            number of category folders in Caltech-101 -- confirm against the
            dataset's label list.
    """

    def __init__(self, path, layer, n_class=102):
        super(FineTuneResNet, self).__init__()
        with self.init_scope():
            self.resnet = chainer.links.model.vision.resnet.ResNetLayers(path, layer)
            self.linear = L.Linear(2048, n_class)

    def __call__(self, x, t):
        """Return the classification loss for images ``x`` and labels ``t``.

        Also reports ``loss`` and ``accuracy`` via ``chainer.report``.
        """
        feature = self.resnet(x, layers=['pool5'])['pool5']
        h = self.linear(feature)
        loss = F.softmax_cross_entropy(h, t)
        chainer.report({'loss': loss, 'accuracy': F.accuracy(h, t)}, self)
        return loss

# 4. Start fine-tuning

In [None]:
import chainer
from chainer import training
from chainer.training import extensions

train_batchsize = 32
test_batchsize = 32
# Single GPU id shared by the model, the updater and the evaluator, so that
# changing devices is a one-line edit and the components cannot drift apart.
gpu_id = 0

# Define ResNet-50 and load the pre-trained Caffe weights
model = FineTuneResNet('/root/userspace/readonly/chapter1/resnet_50.caffemodel', 50)

# Send to gpu
model.to_gpu(gpu_id)

# Prepare dataset. resize=(256, 257) fixes the shorter side at 256 pixels
# because np.random.randint's upper bound is exclusive.
train = PreprocessedDataset(train_x, train_y, crop_size=(224, 224), resize=(256, 257),
                            horizontal_flip=True, test=False, gpu=gpu_id)
val = PreprocessedDataset(test_x, test_y, crop_size=(224, 224), resize=(256, 257),
                          horizontal_flip=False, test=True, gpu=gpu_id)
# Training data repeats and is shuffled; validation data is a single ordered
# pass (repeat=False, shuffle=False).
train_iter = chainer.iterators.SerialIterator(train, train_batchsize, repeat=True, shuffle=True)
val_iter = chainer.iterators.SerialIterator(val, test_batchsize, repeat=False, shuffle=False)

# Set up an optimizer: Momentum SGD
optimizer = chainer.optimizers.MomentumSGD(lr=0.001, momentum=0.9)
optimizer.setup(model)
# L2 weight decay on all parameters
weight_decay = chainer.optimizer.WeightDecay(5.0e-4)
optimizer.add_hook(weight_decay)

# Set up a trainer: 10 epochs, validating and logging every 100 iterations
updater = training.StandardUpdater(train_iter, optimizer, device=gpu_id)
trainer = training.Trainer(updater, (10, 'epoch'))
val_interval = 100, 'iteration'
log_interval = 100, 'iteration'
trainer.extend(extensions.LogReport(trigger=log_interval))
trainer.extend(extensions.Evaluator(val_iter, model, device=gpu_id), trigger=val_interval)
trainer.extend(extensions.PrintReport([
    'epoch', 'iteration', 'main/loss', 'validation/main/loss',
    'main/accuracy', 'validation/main/accuracy',
]), trigger=log_interval)
trainer.extend(extensions.ProgressBar(update_interval=10))
trainer.run()