In [1]:
# 모듈 import

from mmcv import Config
from mmdet.apis import set_random_seed, train_detector
from mmdet.datasets import build_dataset
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)
from mmdet.models import build_detector
from mmdet.utils import get_device

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# The ten annotation categories in label-index order (index 0 = "General trash").
classes = (
    "General trash", "Paper", "Paper pack", "Metal", "Glass",
    "Plastic", "Styrofoam", "Plastic bag", "Battery", "Clothing",
)

# Dataset root: later used as img_prefix and as the folder holding the annotation JSONs.
root = '/data/ephemeral/home/dataset/'

# Load the UniverseNet-101 GFL base config (per its filename: fp16,
# multi-scale 480-960 training, 2x schedule); it is customised in the next cells.
cfg = Config.fromfile(
    '/data/ephemeral/home/UniverseNet/configs/universenet/'
    'universenet101_gfl_fp16_4x4_mstrain_480_960_2x_coco.py'
)

In [3]:
# Show the optimizer inherited from the base config before it is overridden.
print(cfg['optimizer'])

{'type': 'SGD', 'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0.0001}


In [4]:
# Adapt the base config to this dataset and training setup.

IMG_SCALE = (512, 512)  # target img_scale for both the train and test pipelines


def set_resize_scale(pipeline, scale):
    """Set ``img_scale`` on the first top-level pipeline step that defines it.

    Looking the step up by key instead of a hard-coded list index keeps this
    correct if the base config inserts or removes transforms, and fails loudly
    instead of silently writing the key onto an unrelated step.

    Args:
        pipeline: list of transform config dicts (an mmcv pipeline config).
        scale: new ``img_scale`` value, e.g. ``(512, 512)``.

    Raises:
        KeyError: if no top-level step defines ``img_scale``.
    """
    for step in pipeline:
        if 'img_scale' in step:
            step['img_scale'] = scale
            return
    raise KeyError('no top-level pipeline step defines img_scale')


# Training split
cfg.data.train.classes = classes
cfg.data.train.img_prefix = root
cfg.data.train.ann_file = root + 'train_rgt.json'  # train annotations
set_resize_scale(cfg.data.train.pipeline, IMG_SCALE)

# Test split
cfg.data.test.classes = classes
cfg.data.test.img_prefix = root
cfg.data.test.ann_file = root + 'test.json'  # test annotations
set_resize_scale(cfg.data.test.pipeline, IMG_SCALE)
# NOTE(review): cfg.data.val is left as in the base config; harmless while
# train_detector runs with validate=False, but set it before enabling validation.

cfg.data.samples_per_gpu = 4

cfg.runner.max_epochs = 30
# NOTE(review): the LR schedule (StepLrUpdaterHook in the training log) keeps the
# base config's step milestones; if those were chosen for the 2x (24-epoch)
# schedule, consider rescaling cfg.lr_config.step for 30 epochs — TODO confirm.

cfg.seed = 2022
# Setting cfg.seed alone does not seed the global RNGs: in mmdet 2.x,
# train_detector only forwards it to the data loaders, while tools/train.py
# calls set_random_seed separately. Seed Python/NumPy/PyTorch here, before the
# dataset and model are built, so weight init and augmentation are reproducible.
set_random_seed(cfg.seed)

cfg.gpu_ids = [0]
# NOTE(review): the directory name says "epoch_50" but max_epochs is 30 above;
# kept unchanged so existing checkpoints are still found — rename if unintended.
cfg.work_dir = './work_dirs/universenet_epoch_50_newdata'

# Derive the head size from the class list so the two cannot drift apart.
cfg.model.bbox_head.num_classes = len(classes)

# NOTE(review): lr=0.01 matches the base SGD optimizer printed above; that is
# unusually high for AdamW — confirm it is intended.
cfg.optimizer = dict(type='AdamW', lr=0.01, weight_decay=0.0001)
cfg.optimizer_config.grad_clip = dict(max_norm=35, norm_type=2)  # clip global L2 grad norm at 35
cfg.checkpoint_config = dict(max_keep_ckpts=3, interval=1)  # save every interval, keep the 3 newest
cfg.device = get_device()

In [5]:
# Build the training dataset from the edited train config. It is kept in a list;
# datasets[0] is what is handed to train_detector.
datasets = [build_dataset(cfg['data']['train'])]

loading annotations into memory...
Done (t=0.08s)
creating index...
index created!


In [6]:
# Inspect the training dataset: its repr tabulates instance counts per class.
train_dataset = datasets[0]
train_dataset


CocoDataset Train dataset with number of images 4883, and instance counts: 
+-------------------+-------+---------------+-------+-----------------+-------+-------------+-------+--------------+-------+
| category          | count | category      | count | category        | count | category    | count | category     | count |
+-------------------+-------+---------------+-------+-----------------+-------+-------------+-------+--------------+-------+
| 0 [General trash] | 701   | 1 [Paper]     | 6352  | 2 [Paper pack]  | 897   | 3 [Metal]   | 936   | 4 [Glass]    | 982   |
| 5 [Plastic]       | 2943  | 6 [Styrofoam] | 1263  | 7 [Plastic bag] | 5178  | 8 [Battery] | 159   | 9 [Clothing] | 468   |
+-------------------+-------+---------------+-------+-----------------+-------+-------------+-------+--------------+-------+

In [7]:
# Instantiate the detector from the edited model config, then run weight
# initialisation, which loads the pretrained backbone checkpoint named in its init_cfg.
# NOTE(review): the "missing keys ... conv_offset" warning presumably comes from
# deformable-conv layers absent in the ImageNet checkpoint — confirm that is expected.
model = build_detector(cfg['model'])
model.init_weights()

2024-01-18 11:46:48,069 - mmcv - INFO - initialize Res2Net with init_cfg {'type': 'Pretrained', 'checkpoint': 'open-mmlab://res2net101_v1d_26w_4s'}
2024-01-18 11:46:48,070 - mmcv - INFO - load model from: open-mmlab://res2net101_v1d_26w_4s
2024-01-18 11:46:48,071 - mmcv - INFO - load checkpoint from openmmlab path: open-mmlab://res2net101_v1d_26w_4s

unexpected key in source state_dict: fc.weight, fc.bias

missing keys in source state_dict: layer2.0.convs.0.conv_offset.weight, layer2.0.convs.0.conv_offset.bias, layer2.0.convs.1.conv_offset.weight, layer2.0.convs.1.conv_offset.bias, layer2.0.convs.2.conv_offset.weight, layer2.0.convs.2.conv_offset.bias, layer2.1.convs.0.conv_offset.weight, layer2.1.convs.0.conv_offset.bias, layer2.1.convs.1.conv_offset.weight, layer2.1.convs.1.conv_offset.bias, layer2.1.convs.2.conv_offset.weight, layer2.1.convs.2.conv_offset.bias, layer2.2.convs.0.conv_offset.weight, layer2.2.convs.0.conv_offset.bias, layer2.2.convs.1.conv_offset.weight, layer2.2.convs

In [8]:
# Train on the single training dataset, non-distributed, with validation disabled.
train_detector(
    model=model,
    dataset=datasets[0],
    cfg=cfg,
    validate=False,
    distributed=False,
)

2024-01-18 11:46:49,257 - mmdet - INFO - Automatic scaling of learning rate (LR) has been disabled.
2024-01-18 11:46:51,221 - mmdet - INFO - Start running, host: root@instance-5874, work_dir: /data/ephemeral/home/UniverseNet/work_dirs/universenet_epoch_50_newdata
2024-01-18 11:46:51,222 - mmdet - INFO - Hooks will be executed in the following order:
before_run:
(VERY_HIGH   ) StepLrUpdaterHook                  
(NORMAL      ) CheckpointHook                     
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
before_train_epoch:
(VERY_HIGH   ) StepLrUpdaterHook                  
(NORMAL      ) NumClassCheckHook                  
(LOW         ) IterTimerHook                      
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
before_train_iter:
(VERY_HIGH   ) StepLrUpdaterHook                  
(LOW         ) IterTimerHook                      
 -------------------- 
after_train_iter:
(ABOVE_NORMAL) OptimizerHook                      
