# 続・detectron2 for まちカドまぞく ～訓練編～

<img src="https://user-images.githubusercontent.com/33882378/79055210-2ff0e600-7c86-11ea-93c6-8a65112f80f0.jpg">

detectron2 の訓練をカスタマイズする方法

---

In [1]:
import os
import numpy as np
import json
import matplotlib.pyplot as plt
import cv2
import random

In [2]:
# VoTT のエクスポートファイルや、画像が格納されているディレクトリ
BASE_DIRECTORY = './vott-json-export/'
# VoTT のエクスポートファイル名
EXPORT_FILENAME = 'Machikado-export.json'
# 訓練データに使用する割合
TRAIN_RATIO = 0.8
# 乱数シード
RANDOM_STATE = 0

### DatasetCatalogを用意する

In [3]:
from detectron2.data import DatasetCatalog, MetadataCatalog
from Machikado_vott import get_cat_names, get_machikado_dicts

# vott エクスポートファイルの読み込み
CAT_NAME2ID, CAT_ID2NAME = get_cat_names(os.path.join(BASE_DIRECTORY, EXPORT_FILENAME))
dataset_dicts = get_machikado_dicts(os.path.join(BASE_DIRECTORY, EXPORT_FILENAME), BASE_DIRECTORY, CAT_NAME2ID)

# 訓練用、テスト用に分ける
random.seed(RANDOM_STATE)
random.shuffle(dataset_dicts)

split_idx = int(len(dataset_dicts) * TRAIN_RATIO) + 1

# 登録
DatasetCatalog.clear()
DatasetCatalog.register('train', lambda : dataset_dicts[:split_idx])
DatasetCatalog.register('test', lambda : dataset_dicts[split_idx:])

MetadataCatalog.get('train').set(thing_classes=list(CAT_NAME2ID.keys()))
MetadataCatalog.get('test').set(thing_classes=list(CAT_NAME2ID.keys()))

警告: name: 59.jpg - 画像サイズの不整合 image_size:(268, 201), ./vott-json-export/Machikado-export.json: (600, 600)


Metadata(name='test', thing_classes=['Shamiko', 'Gosenzo', 'Lilith', 'Momo', 'Mikan', 'Mob'])

In [4]:
from detectron2.data import DatasetCatalog

# 訓練用、テスト用に分ける
random.seed(RANDOM_STATE)
random.shuffle(dataset_dicts)
split_idx = int(len(dataset_dicts) * TRAIN_RATIO) + 1

# 登録
DatasetCatalog.clear()
DatasetCatalog.register('train', lambda : dataset_dicts[:split_idx])
DatasetCatalog.register('test', lambda : dataset_dicts[split_idx:])

---
## 学習

### カスタムデータマッパー

オリジナルの DatasetMapper の位置は detectron2/data/dataset_mapper.py

In [5]:
import copy
import logging
import numpy as np
import torch
from fvcore.common.file_io import PathManager
from PIL import Image

from detectron2.data import transforms as T
from detectron2.data import detection_utils as utils

class MachikadoDatasetMapper:
    def __init__(self, cfg, is_train=True):
        if cfg.INPUT.CROP.ENABLED and is_train:
            self.crop_gen = T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)
            logging.getLogger(__name__).info("CropGen used in training: " + str(self.crop_gen))
        else:
            self.crop_gen = None

        self.tfm_gens = utils.build_transform_gen(cfg, is_train)

        # fmt: off
        self.img_format     = cfg.INPUT.FORMAT
        self.mask_on        = cfg.MODEL.MASK_ON
        self.mask_format    = cfg.INPUT.MASK_FORMAT
        self.keypoint_on    = cfg.MODEL.KEYPOINT_ON
        self.load_proposals = cfg.MODEL.LOAD_PROPOSALS
        # fmt: on
        if self.keypoint_on and is_train:
            # Flip only makes sense in training
            self.keypoint_hflip_indices = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
        else:
            self.keypoint_hflip_indices = None

        if self.load_proposals:
            self.min_box_side_len = cfg.MODEL.PROPOSAL_GENERATOR.MIN_SIZE
            self.proposal_topk = (
                cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
                if is_train
                else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
            )
        self.is_train = is_train

    def __call__(self, dataset_dict):
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        # USER: Write your own image loading if it's not from a file
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        if "annotations" not in dataset_dict:
            image, transforms = T.apply_transform_gens(
                ([self.crop_gen] if self.crop_gen else []) + self.tfm_gens, image
            )
        else:
            # Crop around an instance if there are instances in the image.
            # USER: Remove if you don't use cropping
            if self.crop_gen:
                crop_tfm = utils.gen_crop_transform_with_instance(
                    self.crop_gen.get_crop_size(image.shape[:2]),
                    image.shape[:2],
                    np.random.choice(dataset_dict["annotations"]),
                )
                image = crop_tfm.apply_image(image)
            image, transforms = T.apply_transform_gens(self.tfm_gens, image)
            if self.crop_gen:
                transforms = crop_tfm + transforms

        image_shape = image.shape[:2]  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))

        # USER: Remove if you don't use pre-computed proposals.
        if self.load_proposals:
            utils.transform_proposals(
                dataset_dict, image_shape, transforms, self.min_box_side_len, self.proposal_topk
            )

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            dataset_dict.pop("sem_seg_file_name", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            # USER: Modify this if you want to keep them for some reason.
            for anno in dataset_dict["annotations"]:
                if not self.mask_on:
                    anno.pop("segmentation", None)
                if not self.keypoint_on:
                    anno.pop("keypoints", None)

            # USER: Implement additional transformations if you have other types of data
            annos = [
                utils.transform_instance_annotations(
                    obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
                )
                for obj in dataset_dict.pop("annotations")
                if obj.get("iscrowd", 0) == 0
            ]
            instances = utils.annotations_to_instances(
                annos, image_shape, mask_format=self.mask_format
            )
            # Create a tight bounding box from masks, useful when image is cropped
            if self.crop_gen and instances.has("gt_masks"):
                instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
            dataset_dict["instances"] = utils.filter_empty_instances(instances)

        # USER: Remove if you don't do semantic/panoptic segmentation.
        if "sem_seg_file_name" in dataset_dict:
            with PathManager.open(dataset_dict.pop("sem_seg_file_name"), "rb") as f:
                sem_seg_gt = Image.open(f)
                sem_seg_gt = np.asarray(sem_seg_gt, dtype="uint8")
            sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
            sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
            dataset_dict["sem_seg"] = sem_seg_gt
        return dataset_dict

### カスタムトレイナー

DefaultTrainer の位置は detectron2/engine/defaults.py

In [6]:
from detectron2.engine import DefaultTrainer

from detectron2.data import (
    build_detection_test_loader,
    build_detection_train_loader,
)

class MachikadoTrainer(DefaultTrainer):
    def __init__(self, cfg):
        super().__init__(cfg)
    
    # これをオーバライドする
    @classmethod
    def build_train_loader(cls, cfg):
#         return build_detection_train_loader(cfg, mapper=mapper)
        print('OK!!')
        return build_detection_train_loader(cfg, MachikadoDatasetMapper(cfg, True))

### 訓練

In [7]:
from detectron2.config import get_cfg

cfg = get_cfg()
cfg.OUTPUT_DIR = './output'
cfg.CUDA = 'cuda:0'

# cfg.merge_from_file("../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
# cfg.MODEL.WEIGHTS = './coco_models/model_final_f10217.pkl'
# cfg.SOLVER.IMS_PER_BATCH = 2

# 重いけど、これ精度良いです。
cfg.merge_from_file('../configs/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml')
cfg.MODEL.WEIGHTS = './coco_models/model_final_2d9806.pkl'
cfg.SOLVER.IMS_PER_BATCH = 1 # GTX2070 ではこれが限界

cfg.DATASETS.TRAIN = ('train',)
cfg.DATASETS.TEST = ()   # no metrics implemented for this dataset
cfg.DATALOADER.NUM_WORKERS = 2
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 10    # 300 iterations seems good enough, but you can certainly train longer <- とあるが、まあデータセットによるよね・・・
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128   # faster, and good enough for this toy dataset
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(CAT_ID2NAME) 

In [8]:
# 出力先のディレクトリを作る
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

trainer = MachikadoTrainer(cfg) 
trainer.resume_or_load(resume=False) # True で途中から学習できるらしい
trainer.train()

OK!!
[32m[04/12 11:48:19 d2.data.detection_utils]: [0mTransformGens used in training: [ResizeShortestEdge(short_edge_length=(640, 672, 704, 736, 768, 800), max_size=1333, sample_style='choice'), RandomFlip()]
[32m[04/12 11:48:19 d2.data.build]: [0mRemoved 0 images with no usable annotations. 57 images left.
[32m[04/12 11:48:19 d2.data.build]: [0mDistribution of instances among all 6 categories:
[36m|  category  | #instances   |  category  | #instances   |  category  | #instances   |
|:----------:|:-------------|:----------:|:-------------|:----------:|:-------------|
|  Shamiko   | 38           |  Gosenzo   | 20           |   Lilith   | 10           |
|    Momo    | 20           |   Mikan    | 8            |    Mob     | 8            |
|            |              |            |              |            |              |
|   total    | 104          |            |              |            |              |[0m
[32m[04/12 11:48:19 d2.data.common]: [0mSerializing 57 elements to by

'roi_heads.box_predictor.cls_score.weight' has shape (81, 1024) in the checkpoint but (7, 1024) in the model! Skipped.
'roi_heads.box_predictor.cls_score.bias' has shape (81,) in the checkpoint but (7,) in the model! Skipped.
'roi_heads.box_predictor.bbox_pred.weight' has shape (320, 1024) in the checkpoint but (24, 1024) in the model! Skipped.
'roi_heads.box_predictor.bbox_pred.bias' has shape (320,) in the checkpoint but (24,) in the model! Skipped.
'roi_heads.mask_head.predictor.weight' has shape (80, 256, 1, 1) in the checkpoint but (6, 256, 1, 1) in the model! Skipped.
'roi_heads.mask_head.predictor.bias' has shape (80,) in the checkpoint but (6,) in the model! Skipped.


[32m[04/12 11:48:21 d2.engine.train_loop]: [0mStarting training from iteration 0
[32m[04/12 11:48:34 d2.utils.events]: [0m eta: 0:00:00  iter: 9  total_loss: 3.374  loss_cls: 1.948  loss_box_reg: 0.739  loss_mask: 0.691  loss_rpn_cls: 0.010  loss_rpn_loc: 0.022  time: 0.4708  data_time: 0.0096  lr: 0.000002  max_mem: 3690M
[32m[04/12 11:48:34 d2.engine.hooks]: [0mOverall training speed: 7 iterations in 0:00:03 (0.5381 s / it)
[32m[04/12 11:48:34 d2.engine.hooks]: [0mTotal training time: 0:00:12 (0:00:08 on hooks)
