# 続・detectron2 for まちカドまぞく ～訓練編～

<img src="https://user-images.githubusercontent.com/33882378/79055210-2ff0e600-7c86-11ea-93c6-8a65112f80f0.jpg">

detectron2 の訓練をカスタマイズする方法

---

In [1]:
import os
import numpy as np
import json
import matplotlib.pyplot as plt
import cv2
import random

In [2]:
# VoTT のエクスポートファイルや、画像が格納されているディレクトリ
BASE_DIRECTORY = './vott-json-export/'
# VoTT のエクスポートファイル名
EXPORT_FILENAME = 'Machikado-export.json'
# 訓練データに使用する割合
TRAIN_RATIO = 0.8
# 乱数シード
RANDOM_STATE = 0

### DatasetCatalogを用意する

In [3]:
from detectron2.data import DatasetCatalog, MetadataCatalog
from Machikado_vott import get_cat_names, get_machikado_dicts

# vott エクスポートファイルの読み込み
CAT_NAME2ID, CAT_ID2NAME = get_cat_names(os.path.join(BASE_DIRECTORY, EXPORT_FILENAME))
dataset_dicts = get_machikado_dicts(os.path.join(BASE_DIRECTORY, EXPORT_FILENAME), BASE_DIRECTORY, CAT_NAME2ID)

# 訓練用、テスト用に分ける
random.seed(RANDOM_STATE)
random.shuffle(dataset_dicts)

split_idx = int(len(dataset_dicts) * TRAIN_RATIO) + 1

# 登録
DatasetCatalog.clear()
DatasetCatalog.register('train', lambda : dataset_dicts[:split_idx])
DatasetCatalog.register('test', lambda : dataset_dicts[split_idx:])

MetadataCatalog.get('train').set(thing_classes=list(CAT_NAME2ID.keys()))
MetadataCatalog.get('test').set(thing_classes=list(CAT_NAME2ID.keys()))

警告: name: 59.jpg - 画像サイズの不整合 image_size:(268, 201), ./vott-json-export/Machikado-export.json: (600, 600)


Metadata(name='test', thing_classes=['Shamiko', 'Gosenzo', 'Lilith', 'Momo', 'Mikan', 'Mob'])

In [4]:
from detectron2.data import DatasetCatalog

# 訓練用、テスト用に分ける
random.seed(RANDOM_STATE)
random.shuffle(dataset_dicts)
split_idx = int(len(dataset_dicts) * TRAIN_RATIO) + 1

# 登録
DatasetCatalog.clear()
DatasetCatalog.register('train', lambda : dataset_dicts[:split_idx])
DatasetCatalog.register('test', lambda : dataset_dicts[split_idx:])

---
## 学習

In [5]:
from detectron2.config import get_cfg
from MachikadoDatasetMapper import get_custom_cfg

cfg = get_cfg()
cfg.OUTPUT_DIR = './output'
cfg.CUDA = 'cuda:0'

# cfg.merge_from_file("../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
# cfg.MODEL.WEIGHTS = './coco_models/model_final_f10217.pkl'
# cfg.SOLVER.IMS_PER_BATCH = 2

# 重いけど、これ精度良いです。
cfg.merge_from_file('../configs/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml')
cfg.MODEL.WEIGHTS = './coco_models/model_final_2d9806.pkl'
cfg.SOLVER.IMS_PER_BATCH = 1 # GTX2070 ではこれが限界

cfg.DATASETS.TRAIN = ('train',)
cfg.DATASETS.TEST = ()   # no metrics implemented for this dataset
cfg.DATALOADER.NUM_WORKERS = 2
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 1500    # 300 iterations seems good enough, but you can certainly train longer <- とあるが、まあデータセットによるよね・・・
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128   # faster, and good enough for this toy dataset
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(CAT_ID2NAME) 

cfg.INPUT.CROP.ENABLED = True
cfg.INPUT.CROP.SIZE = [0.8, 0.8]

custom_cfg = get_custom_cfg()

In [6]:
from MachikadoDatasetMapper import MachikadoDatasetMapper
from detectron2.engine import DefaultTrainer

from detectron2.data import (
    build_detection_test_loader,
    build_detection_train_loader,
)

class MachikadoTrainer(DefaultTrainer):
    def __init__(self, cfg):
        super().__init__(cfg)
        
    @classmethod
    def build_train_loader(cls, cfg):
        return build_detection_train_loader(cfg, MachikadoDatasetMapper(cfg, custom_cfg=custom_cfg))

In [7]:
# 出力先のディレクトリを作る
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

trainer = MachikadoTrainer(cfg) 
trainer.resume_or_load(resume=False) # True で途中から学習できるらしい
trainer.train()

[32m[04/13 07:39:29 d2.data.detection_utils]: [0mTransformGens used in training: [ResizeShortestEdge(short_edge_length=(640, 672, 704, 736, 768, 800), max_size=1333, sample_style='choice'), RandomFlip()]
[32m[04/13 07:39:29 d2.data.build]: [0mRemoved 0 images with no usable annotations. 57 images left.
[32m[04/13 07:39:29 d2.data.build]: [0mDistribution of instances among all 6 categories:
[36m|  category  | #instances   |  category  | #instances   |  category  | #instances   |
|:----------:|:-------------|:----------:|:-------------|:----------:|:-------------|
|  Shamiko   | 38           |  Gosenzo   | 20           |   Lilith   | 10           |
|    Momo    | 20           |   Mikan    | 8            |    Mob     | 8            |
|            |              |            |              |            |              |
|   total    | 104          |            |              |            |              |[0m
[32m[04/13 07:39:29 d2.data.common]: [0mSerializing 57 elements to byte te

'roi_heads.box_predictor.cls_score.weight' has shape (81, 1024) in the checkpoint but (7, 1024) in the model! Skipped.
'roi_heads.box_predictor.cls_score.bias' has shape (81,) in the checkpoint but (7,) in the model! Skipped.
'roi_heads.box_predictor.bbox_pred.weight' has shape (320, 1024) in the checkpoint but (24, 1024) in the model! Skipped.
'roi_heads.box_predictor.bbox_pred.bias' has shape (320,) in the checkpoint but (24,) in the model! Skipped.
'roi_heads.mask_head.predictor.weight' has shape (80, 256, 1, 1) in the checkpoint but (6, 256, 1, 1) in the model! Skipped.
'roi_heads.mask_head.predictor.bias' has shape (80,) in the checkpoint but (6,) in the model! Skipped.


[32m[04/13 07:39:30 d2.engine.train_loop]: [0mStarting training from iteration 0
[32m[04/13 07:39:40 d2.utils.events]: [0m eta: 0:11:32  iter: 19  total_loss: 3.243  loss_cls: 1.975  loss_box_reg: 0.498  loss_mask: 0.692  loss_rpn_cls: 0.002  loss_rpn_loc: 0.007  time: 0.4553  data_time: 0.0099  lr: 0.000005  max_mem: 3668M
[32m[04/13 07:39:49 d2.utils.events]: [0m eta: 0:11:16  iter: 39  total_loss: 3.366  loss_cls: 1.791  loss_box_reg: 0.887  loss_mask: 0.690  loss_rpn_cls: 0.003  loss_rpn_loc: 0.022  time: 0.4647  data_time: 0.0018  lr: 0.000010  max_mem: 3668M
[32m[04/13 07:39:59 d2.utils.events]: [0m eta: 0:11:20  iter: 59  total_loss: 3.113  loss_cls: 1.568  loss_box_reg: 0.885  loss_mask: 0.688  loss_rpn_cls: 0.005  loss_rpn_loc: 0.013  time: 0.4688  data_time: 0.0018  lr: 0.000015  max_mem: 3668M
[32m[04/13 07:40:08 d2.utils.events]: [0m eta: 0:11:22  iter: 79  total_loss: 2.478  loss_cls: 1.138  loss_box_reg: 0.583  loss_mask: 0.677  loss_rpn_cls: 0.002  loss_rpn_loc