# 続・detectron2 for まちカドまぞく ～カスタム訓練編～

<img src="https://user-images.githubusercontent.com/33882378/79110283-0f737980-7db5-11ea-89b0-3b6b19ae1716.jpg">

detectron2 の訓練をカスタマイズする方法

ここでは、「カスタムトレーナー」を作りますが、今までの工程では一番楽でした。。

---

In [1]:
import os
import numpy as np
import json
import matplotlib.pyplot as plt
import cv2
import random

---
##  VoTT Export からの読み込み

In [2]:
# VoTT のエクスポートファイルや、画像が格納されているディレクトリ
BASE_DIRECTORY = './vott-json-export/'
# VoTT のエクスポートファイル名
EXPORT_FILENAME = 'Machikado-export.json'
# 訓練データに使用する割合
TRAIN_RATIO = 0.8
# 乱数シード
RANDOM_STATE = 0

In [3]:
from detectron2.data import DatasetCatalog, MetadataCatalog
from machikado_util.Machikado_vott import get_cat_names, get_machikado_dicts

# vott エクスポートファイルの読み込み
CAT_NAME2ID, CAT_ID2NAME = get_cat_names(os.path.join(BASE_DIRECTORY, EXPORT_FILENAME))
dataset_dicts = get_machikado_dicts(os.path.join(BASE_DIRECTORY, EXPORT_FILENAME), BASE_DIRECTORY, CAT_NAME2ID)

# 訓練用、テスト用に分ける
random.seed(RANDOM_STATE)
random.shuffle(dataset_dicts)

split_idx = int(len(dataset_dicts) * TRAIN_RATIO) + 1

# 登録
DatasetCatalog.clear()
DatasetCatalog.register('train', lambda : dataset_dicts[:split_idx])
DatasetCatalog.register('test', lambda : dataset_dicts[split_idx:])

MetadataCatalog.get('train').set(thing_classes=list(CAT_NAME2ID.keys()))
MetadataCatalog.get('test').set(thing_classes=list(CAT_NAME2ID.keys()))

警告: name: 59.jpg - 画像サイズが不一致であるためスキップ image_size:(268, 201), ./vott-json-export/Machikado-export.json: (600, 600)


Metadata(name='test', thing_classes=['Shamiko', 'Gosenzo', 'Lilith', 'Momo', 'Mikan', 'Mob'])

In [4]:
from detectron2.data import DatasetCatalog

# 訓練用、テスト用に分ける
random.seed(RANDOM_STATE)
random.shuffle(dataset_dicts)
split_idx = int(len(dataset_dicts) * TRAIN_RATIO) + 1

# 登録
DatasetCatalog.clear()
DatasetCatalog.register('train', lambda : dataset_dicts[:split_idx])
DatasetCatalog.register('test', lambda : dataset_dicts[split_idx:])

---
## 学習

* 前回より訓練回数を持っています。

In [5]:
from detectron2.config import get_cfg
from machikado_util.custom_config import append_custom_cfg

cfg = get_cfg()
append_custom_cfg(cfg)

cfg.OUTPUT_DIR = './output'
cfg.CUDA = 'cuda:0'

# cfg.merge_from_file("../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
# cfg.MODEL.WEIGHTS = './coco_models/model_final_f10217.pkl'
# cfg.SOLVER.IMS_PER_BATCH = 2

# 重いけど、これ精度良いです。
cfg.merge_from_file('../configs/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml')
cfg.MODEL.WEIGHTS = './coco_models/model_final_2d9806.pkl'
cfg.SOLVER.IMS_PER_BATCH = 1 # GTX2070 ではこれが限界

cfg.DATASETS.TRAIN = ('train',)
cfg.DATASETS.TEST = ()   # no metrics implemented for this dataset
cfg.DATALOADER.NUM_WORKERS = 2
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 2500
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128   # faster, and good enough for this toy dataset
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(CAT_ID2NAME) 

cfg.INPUT.CROP.ENABLED = True
cfg.INPUT.CROP.SIZE = [0.8, 0.8]

#### カスタムトレーナー

* DefaultTrainer は detectron2/engine/defaults.py にあります。
* DefaultTrainer を全部書き換える必要は無く、今回の場合であればデーターマッパーを渡す部分を変更するだけで良いです。（オーバーライドすれば良いです。）
* あとは、同じです。

In [6]:
from machikado_util.MachikadoDatasetMapper import MachikadoDatasetMapper
from detectron2.engine import DefaultTrainer

class MachikadoTrainer(DefaultTrainer):
    def __init__(self, cfg):
        super().__init__(cfg)
    
    # ここをオーバーライドして、作った自分のデータマッパーを返すように変更する。
    @classmethod
    def build_train_loader(cls, cfg):
        return build_detection_train_loader(cfg, MachikadoDatasetMapper(cfg))

In [7]:
from detectron2.data import build_detection_test_loader, build_detection_train_loader

# 出力先のディレクトリを作る
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

trainer = MachikadoTrainer(cfg) 
trainer.resume_or_load(resume=False)
trainer.train()

[32m[04/14 07:37:43 d2.data.detection_utils]: [0mTransformGens used in training: [ResizeShortestEdge(short_edge_length=(640, 672, 704, 736, 768, 800), max_size=1333, sample_style='choice'), RandomFlip()]
[32m[04/14 07:37:43 d2.data.build]: [0mRemoved 0 images with no usable annotations. 92 images left.
[32m[04/14 07:37:43 d2.data.build]: [0mDistribution of instances among all 6 categories:
[36m|  category  | #instances   |  category  | #instances   |  category  | #instances   |
|:----------:|:-------------|:----------:|:-------------|:----------:|:-------------|
|  Shamiko   | 74           |  Gosenzo   | 28           |   Lilith   | 11           |
|    Momo    | 42           |   Mikan    | 15           |    Mob     | 19           |
|            |              |            |              |            |              |
|   total    | 189          |            |              |            |              |[0m
[32m[04/14 07:37:43 d2.data.common]: [0mSerializing 92 elements to byte te

'roi_heads.box_predictor.cls_score.weight' has shape (81, 1024) in the checkpoint but (7, 1024) in the model! Skipped.
'roi_heads.box_predictor.cls_score.bias' has shape (81,) in the checkpoint but (7,) in the model! Skipped.
'roi_heads.box_predictor.bbox_pred.weight' has shape (320, 1024) in the checkpoint but (24, 1024) in the model! Skipped.
'roi_heads.box_predictor.bbox_pred.bias' has shape (320,) in the checkpoint but (24,) in the model! Skipped.
'roi_heads.mask_head.predictor.weight' has shape (80, 256, 1, 1) in the checkpoint but (6, 256, 1, 1) in the model! Skipped.
'roi_heads.mask_head.predictor.bias' has shape (80,) in the checkpoint but (6,) in the model! Skipped.


[32m[04/14 07:37:45 d2.engine.train_loop]: [0mStarting training from iteration 0
[32m[04/14 07:37:55 d2.utils.events]: [0m eta: 0:19:54  iter: 19  total_loss: 3.546  loss_cls: 2.001  loss_box_reg: 0.814  loss_mask: 0.690  loss_rpn_cls: 0.002  loss_rpn_loc: 0.017  time: 0.4902  data_time: 0.0101  lr: 0.000005  max_mem: 3782M
[32m[04/14 07:38:04 d2.utils.events]: [0m eta: 0:20:09  iter: 39  total_loss: 3.506  loss_cls: 1.854  loss_box_reg: 0.878  loss_mask: 0.688  loss_rpn_cls: 0.002  loss_rpn_loc: 0.018  time: 0.4871  data_time: 0.0018  lr: 0.000010  max_mem: 3782M
[32m[04/14 07:38:15 d2.utils.events]: [0m eta: 0:20:23  iter: 59  total_loss: 2.860  loss_cls: 1.579  loss_box_reg: 0.577  loss_mask: 0.682  loss_rpn_cls: 0.006  loss_rpn_loc: 0.016  time: 0.4916  data_time: 0.0020  lr: 0.000015  max_mem: 3782M
[32m[04/14 07:38:25 d2.utils.events]: [0m eta: 0:20:13  iter: 79  total_loss: 2.711  loss_cls: 1.152  loss_box_reg: 0.809  loss_mask: 0.666  loss_rpn_cls: 0.003  loss_rpn_loc