[View in Colaboratory](https://colab.research.google.com/github/mogamin/chainer-ssd/blob/master/chainer_SSD_pascal_original_dataset_train.ipynb)

In [0]:
# Set up chainer, chainercv and the CUDA 8.0 runtime libraries.
!apt-get install -y -qq libcusparse8.0 libnvrtc8.0 libnvtoolsext1
!ln -snf /usr/lib/x86_64-linux-gnu/libnvrtc-builtins.so.8.0 /usr/lib/x86_64-linux-gnu/libnvrtc-builtins.so
!pip install cupy-cuda80==4.3.0
!pip install chainer==4.3.0
# Pin chainercv as well: chainer/cupy are pinned to 4.3.0, and an unpinned
# install may pull a newer ChainerCV that expects a newer Chainer.
# NOTE(review): 0.10.0 is believed to be the release contemporary with
# Chainer 4.3.0 -- confirm against the ChainerCV release notes.
!pip install chainercv==0.10.0

# Start from a clean output directory (removes results of any previous run).
!rm -rf result/

In [0]:
# Check the GPU environment: report whether Chainer sees CUDA and cuDNN.
import chainer

gpu_available = chainer.cuda.available
cudnn_enabled = chainer.cuda.cudnn_enabled
print('GPU availability:', gpu_available)
print('cuDNN availablility:', cudnn_enabled)

GPU availability: True
cuDNN availablility: True


In [0]:
# Fix the random seeds (Python, NumPy and, when CUDA is available, CuPy).
import random
import numpy as np

RANDOM_SEED = 0
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
if chainer.cuda.available:
    chainer.cuda.cupy.random.seed(RANDOM_SEED)

# Timestamp of this run, expressed in UTC+9 (JST), formatted YYYYmmdd-HHMMSS.
# An explicit fixed-offset timezone is used instead of adding 9 hours to the
# naive local time, which was only correct when the host clock is UTC.
import datetime
JST = datetime.timezone(datetime.timedelta(hours=9))
now = datetime.datetime.now(JST).strftime('%Y%m%d-%H%M%S')

In [0]:
import argparse
import copy
import numpy as np

import chainer
from chainer.datasets import ConcatenatedDataset
from chainer.datasets import TransformDataset
from chainer.optimizer_hooks import WeightDecay
from chainer import serializers
from chainer import training
from chainer.training import extensions
from chainer.training import triggers

from chainercv.datasets import voc_bbox_label_names
from chainercv.datasets import VOCBboxDataset
from chainercv.extensions import DetectionVOCEvaluator
from chainercv.links.model.ssd import GradientScaling
from chainercv.links.model.ssd import multibox_loss
from chainercv.links import SSD300
from chainercv.links import SSD512
from chainercv import transforms

from chainercv.links.model.ssd import random_crop_with_bbox_constraints
from chainercv.links.model.ssd import random_distort
from chainercv.links.model.ssd import resize_with_random_interpolation

In [0]:
class MultiboxTrainChain(chainer.Chain):
    """Training wrapper that turns a detection model into a loss function.

    The wrapped ``model`` is called on a batch of images and its outputs are
    fed to ``multibox_loss`` together with the ground-truth targets.

    Args:
        model: Link called as ``model(imgs)`` and returning
            ``(mb_locs, mb_confs)``.
        alpha: Weight applied to the localization loss before it is added to
            the confidence loss.
        k: Forwarded unchanged as the last argument of ``multibox_loss``
            (presumably the hard-negative mining ratio -- confirm in the
            ChainerCV documentation).
    """

    def __init__(self, model, alpha=1, k=3):
        super().__init__()
        self.alpha = alpha
        self.k = k
        # Register the model as a child link of this chain.
        with self.init_scope():
            self.model = model

    def __call__(self, imgs, gt_mb_locs, gt_mb_labels):
        """Compute, report and return the total multibox loss for a batch.

        Reports ``loss``, ``loss/loc`` and ``loss/conf`` through
        ``chainer.reporter`` with this chain as the observer.

        Returns:
            ``loc_loss * alpha + conf_loss``.
        """
        pred_locs, pred_confs = self.model(imgs)
        loc_loss, conf_loss = multibox_loss(
            pred_locs, pred_confs, gt_mb_locs, gt_mb_labels, self.k)
        total_loss = loc_loss * self.alpha + conf_loss

        observation = {
            'loss': total_loss,
            'loss/loc': loc_loss,
            'loss/conf': conf_loss,
        }
        chainer.reporter.report(observation, self)

        return total_loss

In [0]:
class Transform(object):
    """Training-time data augmentation and target encoding.

    Instances are callable on an ``(img, bbox, label)`` example and return
    ``(img, mb_loc, mb_label)``.

    Args:
        coder: Object providing ``to_cpu()`` and ``encode(bbox, label)``.
            A copy of it is moved to the CPU and kept on the instance.
        size: Side length of the square image produced by the resize step.
        mean: Value used to fill expanded regions and subtracted from the
            image at the end.
    """

    def __init__(self, coder, size, mean):
        # Move a copy of the coder to the CPU rather than the one passed in.
        cpu_coder = copy.copy(coder)
        cpu_coder.to_cpu()
        self.coder = cpu_coder

        self.size = size
        self.mean = mean

    def __call__(self, in_data):
        """Augment one example and encode its boxes and labels.

        Steps, in order:
            1. Color augmentation
            2. Random expansion (applied when ``np.random.randint(2)`` is 1)
            3. Random cropping
            4. Resizing with random interpolation
            5. Random horizontal flipping

        followed by mean subtraction and encoding with the coder.
        """
        img, bbox, label = in_data
        out_size = (self.size, self.size)

        # 1. Color augmentation
        img = random_distort(img)

        # 2. Random expansion; boxes are shifted by the same offsets.
        if np.random.randint(2):
            img, expand_param = transforms.random_expand(
                img, fill=self.mean, return_param=True)
            bbox = transforms.translate_bbox(
                bbox,
                y_offset=expand_param['y_offset'],
                x_offset=expand_param['x_offset'])

        # 3. Random cropping; labels are re-indexed with the indices that
        #    crop_bbox returns alongside the cropped boxes.
        img, crop_param = random_crop_with_bbox_constraints(
            img, bbox, return_param=True)
        bbox, kept_param = transforms.crop_bbox(
            bbox,
            y_slice=crop_param['y_slice'],
            x_slice=crop_param['x_slice'],
            allow_outside_center=False,
            return_param=True)
        label = label[kept_param['index']]

        # 4. Resizing with random interpolation
        _, height, width = img.shape
        img = resize_with_random_interpolation(img, out_size)
        bbox = transforms.resize_bbox(bbox, (height, width), out_size)

        # 5. Random horizontal flipping; boxes follow the image.
        img, flip_param = transforms.random_flip(
            img, x_random=True, return_param=True)
        bbox = transforms.flip_bbox(
            bbox, out_size, x_flip=flip_param['x_flip'])

        # Preparation for the SSD network
        img -= self.mean
        mb_loc, mb_label = self.coder.encode(bbox, label)

        return img, mb_loc, mb_label

In [0]:
# Build an SSD300 detector with one foreground class per VOC label name,
# starting from the 'imagenet' pretrained weights.
n_fg_class = len(voc_bbox_label_names)
model = SSD300(n_fg_class=n_fg_class, pretrained_model='imagenet')
model.use_preset('evaluate')

# Make GPU 0 the current device, move the detector onto it and wrap it in
# the chain that computes the training loss.
gpu_device = chainer.cuda.get_device_from_id(0)
gpu_device.use()
model.to_gpu()
train_chain = MultiboxTrainChain(model)

In [0]:
BATCH_SIZE = 32

# Training set: VOC2007 + VOC2012 trainval, each example passed through
# Transform (augmentation + encoding with the model's coder).
voc_trainval = ConcatenatedDataset(
    VOCBboxDataset(year='2007', split='trainval'),
    VOCBboxDataset(year='2012', split='trainval'),
)
train_transform = Transform(model.coder, model.insize, model.mean)
train = TransformDataset(voc_trainval, train_transform)
train_iter = chainer.iterators.MultiprocessIterator(train, BATCH_SIZE)

# Test set: VOC2007 test, iterated once in order (no repeat, no shuffle).
test = VOCBboxDataset(
    year='2007', split='test', use_difficult=True, return_difficult=True)
test_iter = chainer.iterators.SerialIterator(
    test, BATCH_SIZE, repeat=False, shuffle=False)

In [0]:
# The learning rate is not passed here: the initial lr is set to 1e-3 by
# ExponentialShift.
optimizer = chainer.optimizers.MomentumSGD()
optimizer.setup(train_chain)

# Parameters named 'b' get GradientScaling(2); every other parameter gets
# WeightDecay(0.0005).
for param in train_chain.params():
    is_bias = param.name == 'b'
    hook = GradientScaling(2) if is_bias else WeightDecay(0.0005)
    param.update_rule.add_hook(hook)

In [0]:
# Train on device 0 for 120000 iterations, writing outputs to 'result/ssd'.
updater = training.updaters.StandardUpdater(train_iter, optimizer, device=0)
trainer = training.Trainer(
    updater, stop_trigger=(120000, 'iteration'), out='result/ssd')

# Learning-rate schedule: start at 1e-3 and multiply by 0.1 at iterations
# 80000 and 100000.
lr_shift = extensions.ExponentialShift('lr', 0.1, init=1e-3)
lr_shift_trigger = triggers.ManualScheduleTrigger([80000, 100000], 'iteration')
trainer.extend(lr_shift, trigger=lr_shift_trigger)

In [0]:
# Evaluate on the test iterator every 10000 iterations (VOC07 metric).
evaluator = DetectionVOCEvaluator(
    test_iter, model, use_07_metric=True, label_names=voc_bbox_label_names)
trainer.extend(evaluator, trigger=(10000, 'iteration'))

# Logging: collect, record the learning rate and print every 100 iterations.
log_interval = (100, 'iteration')
report_keys = [
    'epoch', 'iteration', 'elapsed_time', 'lr',
    'main/loss', 'main/loss/loc', 'main/loss/conf',
    'validation/main/map',
]
trainer.extend(extensions.LogReport(trigger=log_interval))
trainer.extend(extensions.observe_lr(), trigger=log_interval)
trainer.extend(extensions.PrintReport(report_keys), trigger=log_interval)
trainer.extend(extensions.ProgressBar(update_interval=1000))

# Checkpoints: full trainer snapshot every 10000 iterations, model-only
# snapshot every 1000 iterations.
trainer.extend(extensions.snapshot(), trigger=(10000, 'iteration'))
model_snapshot = extensions.snapshot_object(
    model, 'model_iter_{.updater.iteration}')
trainer.extend(model_snapshot, trigger=(1000, 'iteration'))

trainer.run()

epoch       iteration   elapsed_time  lr          main/loss   main/loss/loc  main/loss/conf  validation/main/map
[J0           100         450.031       0.001       12.2222     2.77544        9.44673                              
[J0           200         895.884       0.001       7.65592     2.51781        5.13811                              
[J0           300         1341.8        0.001       7.22516     2.30498        4.92018                              
[J0           400         1785.43       0.001       6.87099     2.08318        4.78781                              
[J0           500         2226.99       0.001       6.6508      1.99207        4.65872                              
[J1           600         2668.55       0.001       6.40884     1.8889         4.51994                              
[J1           700         3110.16       0.001       6.26858     1.8286         4.43999                              
[J1           800         3554.03       0.001       6.12099 