In [1]:
import os
from urllib import request

from mmcv import Config

from utils import modify_path, modify_num_classes

## Model Name / Config name / Checkpoint URL

In [2]:
model_name = \
"detectors"

model_cfg = \
"detectors_cascade_rcnn_r50_1x_coco"

model_url = \
"http://download.openmmlab.com/mmdetection/v2.0/detectors/detectors_cascade_rcnn_r50_1x_coco/detectors_cascade_rcnn_r50_1x_coco-32a10ba0.pth"

## Set Config Setting Values.

In [3]:
lr_gamma = 0.1

batch_size = 8

multi_scale_factor = [0.5, 0.75, 1.0, 1.25, 1.5]

## Load Config

In [4]:
cfg = Config.fromfile(os.path.join("./configs", model_name, f"{model_cfg}.py"))

## Modify Config

In [5]:
# modify lr_gamma * default lr.
cfg.optimizer["lr"] *= lr_gamma

# modify evaluation image scales.
if multi_scale_factor:
    test_h, test_w = cfg.data.test.pipeline[1]["img_scale"]
    cfg.data.test.pipeline[1]["img_scale"] = [(int(test_h * scale_factor), int(test_w * scale_factor)) for scale_factor in multi_scale_factor]
    cfg.data.test.pipeline[1]["flip"] = True

# add pretrained model from url.
cfg.load_from = os.path.abspath(os.path.join("pretrained", f"{model_cfg}.pth"))

# # add wandb log.
# cfg.log_config = dict(
#     interval=50, hooks=[dict(type="TextLoggerHook"), dict(type="WandbLoggerHook", init_kwargs=dict(project="pstage-3-od", name=model_cfg))]
# )

# pth model file would be saved for most recent 3 epoch.
cfg.checkpoint_config = dict(interval=1, max_keep_ckpts=3)

# validation processed for every epoch, and best bbox mAP 50 model would be saved.
cfg.evaluation = dict(interval=1, metric='bbox', save_best="bbox_mAP_50")

# modify data path, save path.
modify_path(cfg, data_path="../input/data", save_path=os.path.join("./save_dir", model_cfg))

# modify num of class.
modify_num_classes(cfg, class_num=11)

# batch size and num of worker.
cfg.data.samples_per_gpu = batch_size
cfg.data.workers_per_gpu = 4

## Save Config

In [6]:
os.makedirs(os.path.join("custom_configs"), exist_ok=True)
cfg.dump(os.path.join("custom_configs", f"{model_cfg}.py"))

## Pretrained Checkpoint Save.

In [7]:
if not os.path.isdir("pretrained"):
    os.makedirs("pretrained")

ckpt_name = os.path.join("pretrained", f"{model_cfg}.pth")

request.urlretrieve(model_url, ckpt_name)

('pretrained/detectors_cascade_rcnn_r50_1x_coco.pth',
 <http.client.HTTPMessage at 0x7f78572586d0>)