In [None]:
# 모듈 import

from mmcv import Config
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.apis import train_detector
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)

In [None]:
classes = ("General trash", "Paper", "Paper pack", "Metal", "Glass",
           "Plastic", "Styrofoam", "Plastic bag", "Battery", "Clothing")

# Load the base config file.
cfg = Config.fromfile('./configs/swin/swin_t.py')

# Root directory for images and annotation files.
root = '../dataset/'

# Dataset config: class names, image root and annotation file for each split.
cfg.data.train.classes = classes
cfg.data.train.img_prefix = root
cfg.data.train.ann_file = root + 'split_train_v2.json'  # train annotations

cfg.data.val.classes = classes
cfg.data.val.img_prefix = root
cfg.data.val.ann_file = root + 'split_valid_v2.json'  # validation annotations

cfg.data.test.classes = classes
cfg.data.test.img_prefix = root
cfg.data.test.ann_file = root + 'test.json'  # test annotations
cfg.data.test.pipeline[1]['img_scale'] = (1024, 1024)  # test-time Resize scale

# Per-GPU batch size and dataloader worker count.
cfg.data.samples_per_gpu = 4
cfg.data.workers_per_gpu = 4

# Random seed.
cfg.seed = 2021

# GPU ids to train on.
cfg.gpu_ids = [0]

# Output directory for logs and checkpoints.
cfg.work_dir = './work_dirs/swin_t'

# Every bbox head of the RoI head must predict exactly the dataset's classes.
# Looping (instead of indexing heads 0..2) keeps this correct for any number of
# heads and keeps num_classes in sync with `classes`.
for bbox_head in cfg.model.roi_head.bbox_head:
    bbox_head.num_classes = len(classes)

# Dataset-specific image normalization statistics.
cfg.img_norm_cfg = dict(mean=[127.49413776397705, 127.43779182434082, 127.46098327636719],
                        std=[73.86627551077616, 73.88234865304638, 73.8944344154546], to_rgb=True)


def _apply_img_norm(pipeline, norm_cfg):
    """Overwrite the settings of every ``Normalize`` step in *pipeline* with *norm_cfg*.

    Config.fromfile expands ``img_norm_cfg`` into the pipelines when the config
    file is loaded, so assigning ``cfg.img_norm_cfg`` afterwards does not reach
    the pipelines on its own; this pushes the new values into each step.
    Recurses into the nested transform lists of ``MultiScaleFlipAug``
    (``transforms``) and ``AutoAugment`` (``policies``).
    """
    for step in pipeline:
        step_type = step.get('type')
        if step_type == 'Normalize':
            step.update(norm_cfg)
        elif step_type == 'MultiScaleFlipAug':
            _apply_img_norm(step.get('transforms', []), norm_cfg)
        elif step_type == 'AutoAugment':
            for policy in step.get('policies', []):
                _apply_img_norm(policy, norm_cfg)


for _split in ('train', 'val', 'test'):
    _apply_img_norm(cfg.data[_split].pipeline, cfg.img_norm_cfg)

# Clip gradients by their L2 norm (max_norm=35).
cfg.optimizer_config.grad_clip = dict(max_norm=35, norm_type=2)
# Save a checkpoint every interval (1) and keep only the 3 most recent ones.
cfg.checkpoint_config = dict(max_keep_ckpts=3, interval=1)

# Evaluation: run every interval and keep the checkpoint with the best bbox mAP@0.5.
cfg.evaluation.save_best = 'bbox_mAP_50'
cfg.evaluation.interval = 1

# Log training to Weights & Biases.
cfg.log_config.hooks.append(
    dict(
        type='WandbLoggerHook',
        init_kwargs=dict(
            project='mmdetection',
            name='number-model',
            entity='passion-ate',
        ),
    )
)

In [None]:
# Multi-scale training: candidate Resize scales inside the train-time
# AutoAugment policies.
# NOTE(review): assumes train pipeline[3] is an AutoAugment step whose
# policies[0][0], policies[1][0] and policies[1][2] accept `img_scale`
# (Resize-like) — confirm against './configs/swin/swin_t.py'.
# First element runs 614, 655, ..., 983, 1024 (step 41); second is fixed at 1024.
_fine_scales = [(first, 1024) for first in range(614, 1024 + 1, 41)]

_policies = cfg.data.train.pipeline[3].policies
_policies[0][0].img_scale = list(_fine_scales)
_policies[1][0].img_scale = [(512, 1024), (640, 1024), (768, 1024)]
_policies[1][2].img_scale = list(_fine_scales)

In [None]:
# Point the pretrained checkpoint at a local file.
# Checkpoints: https://github.com/microsoft/Swin-Transformer
# NOTE(review): Config.fromfile has already evaluated the config file. If
# './configs/swin/swin_t.py' copies its `pretrained` variable into the model
# dict (e.g. a backbone init_cfg checkpoint), this top-level assignment does
# not reach cfg.model, and build_detector(cfg.model) will not see it. Confirm,
# and set the backbone's checkpoint field directly if so.
# NOTE(review): the file name suggests Swin-B weights (window 12, 384px,
# ImageNet-22k) while the config is named swin_t. Check that the backbone
# architecture matches these weights.
cfg.pretrained = '/opt/ml/detection/mmdetection/configs/swin/swin_base_patch4_window12_384_22k.pth'

In [None]:
# Build the training dataset from cfg.data.train, wrapped in a list.
datasets = [build_dataset(cfg.data.train)]

In [None]:
# Inspect the training dataset. As the last expression of a notebook cell,
# it is displayed as the cell output.
datasets[0]

In [None]:
# Build the detector from the model config, then initialize its weights.
# NOTE(review): whether init_weights() loads pretrained backbone weights
# depends on the init settings inside cfg.model (e.g. a backbone init_cfg
# checkpoint), not on the top-level cfg.pretrained. Confirm.
model = build_detector(cfg.model)
model.init_weights()

In [None]:
# Train the model (single process, no distributed training), with validation
# enabled during training.
# Attach the class names to the model, as mmdetection's tools/train.py does.
# mmcv's save_checkpoint records model.CLASSES in each checkpoint's meta, so
# later inference/visualization can recover the label names from the file.
model.CLASSES = datasets[0].CLASSES
train_detector(model, datasets[0], cfg, distributed=False, validate=True)