In [1]:
# 모듈 import

from mmcv import Config
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.apis import train_detector
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)
from mmdet.utils import get_device
import wandb

In [14]:
classes = ("General trash", "Paper", "Paper pack", "Metal", "Glass", 
           "Plastic", "Styrofoam", "Plastic bag", "Battery", "Clothing")
    
# config file 들고오기
filename = './configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco_balanced.py'
cfg = Config.fromfile(filename)

# mean, std도 수정해줄 필요가 있을 것
# img_norm_cfg = dict(
#     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)



root='../../dataset/'

# dataset config 수정
# cfg.data.train.classes = classes
# cfg.data.train.img_prefix = root
# cfg.data.train.ann_file = root + 'train.json' # train json 정보
# cfg.data.train.pipeline[2]['img_scale'] = (512,512) # Resize
cfg.data.train.oversample_thr = 0.1
cfg.data.train.dataset.classes = classes
cfg.data.train.dataset.img_prefix = root
cfg.data.train.dataset.ann_file = root + 'train.json' # train json 정보
cfg.data.train.dataset.pipeline[2]['img_scale'] = (512,512) # Resize

cfg.data.test.classes = classes
cfg.data.test.img_prefix = root
cfg.data.test.ann_file = root + 'test.json' # test json 정보
cfg.data.test.pipeline[1]['img_scale'] = (512,512) # Resize

# wandb logger hook 추가
cfg.log_config.hooks = [
    dict(type='TextLoggerHook'),
    dict(type='MMDetWandbHook',
         init_kwargs={'project': "object_detection",
                     'entity' : "cv-2",
                     'name' : filename.split('/')[-1]},
         interval=10,
         log_checkpoint=True,
         log_checkpoint_metadata=True,
         num_eval_images=100)]



cfg.data.samples_per_gpu = 4

cfg.seed = 2022
cfg.gpu_ids = [0]
cfg.work_dir = './work_dirs/cascade_rcnn_r50_fpn_1x_coco_balanced'

# change number of classes
if(filename in ['./configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py',
                './configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco_balanced.py', 
                './configs/convnext/cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco.py']):
    cfg.model.roi_head.bbox_head[0].num_classes=10
    cfg.model.roi_head.bbox_head[1].num_classes=10
    cfg.model.roi_head.bbox_head[2].num_classes=10
elif(filename in ['./configs/swin/retinanet_swin-t-p4-w7_fpn_1x_coco.py',
                     './configs/detr/detr_r50_8x2_150e_coco.py']):
    cfg.model.bbox_head.num_classes = 10
elif(filename in ['./configs/yolox/yolox_tiny_8x8_300e_coco.py']):
    cfg.model.bbox_head.feat_channels=10
elif(filename in ['./configs/yolo/yolov3_d53_320_273e_coco.py']):
    pass
else:
    cfg.model.roi_head.bbox_head.num_classes = 10


#epoch 수 변경
cfg.runner = dict(type='EpochBasedRunner', max_epochs=10)


cfg.optimizer_config.grad_clip = dict(max_norm=35, norm_type=2)
cfg.checkpoint_config = dict(max_keep_ckpts=3, interval=1)
cfg.device = get_device()

In [15]:
# build_dataset
datasets = [build_dataset(cfg.data.train.dataset)]

loading annotations into memory...
Done (t=0.09s)
creating index...
index created!


In [16]:
# dataset 확인
datasets[0]


CocoDataset Train dataset with number of images 4882, and instance counts: 
+-------------------+-------+---------------+-------+-----------------+-------+-------------+-------+--------------+-------+
| category          | count | category      | count | category        | count | category    | count | category     | count |
+-------------------+-------+---------------+-------+-----------------+-------+-------------+-------+--------------+-------+
| 0 [General trash] | 4205  | 1 [Paper]     | 6349  | 2 [Paper pack]  | 909   | 3 [Metal]   | 936   | 4 [Glass]    | 976   |
| 5 [Plastic]       | 2966  | 6 [Styrofoam] | 1267  | 7 [Plastic bag] | 5182  | 8 [Battery] | 159   | 9 [Clothing] | 461   |
+-------------------+-------+---------------+-------+-----------------+-------+-------------+-------+--------------+-------+

In [17]:
# 모델 build 및 pretrained network 불러오기
model = build_detector(cfg.model)
model.init_weights()

2024-01-11 03:59:04,809 - mmcv - INFO - initialize ResNet with init_cfg {'type': 'Pretrained', 'checkpoint': 'torchvision://resnet50'}
2024-01-11 03:59:04,810 - mmcv - INFO - load model from: torchvision://resnet50
2024-01-11 03:59:04,811 - mmcv - INFO - load checkpoint from torchvision path: torchvision://resnet50

unexpected key in source state_dict: fc.weight, fc.bias

2024-01-11 03:59:04,975 - mmcv - INFO - initialize FPN with init_cfg {'type': 'Xavier', 'layer': 'Conv2d', 'distribution': 'uniform'}
2024-01-11 03:59:05,062 - mmcv - INFO - initialize RPNHead with init_cfg {'type': 'Normal', 'layer': 'Conv2d', 'std': 0.01}
2024-01-11 03:59:05,075 - mmcv - INFO - initialize Shared2FCBBoxHead with init_cfg [{'type': 'Normal', 'std': 0.01, 'override': {'name': 'fc_cls'}}, {'type': 'Normal', 'std': 0.001, 'override': {'name': 'fc_reg'}}, {'type': 'Xavier', 'distribution': 'uniform', 'override': [{'name': 'shared_fcs'}, {'name': 'cls_fcs'}, {'name': 'reg_fcs'}]}]
2024-01-11 03:59:05,180 -

In [13]:
# 모델 학습
train_detector(model, datasets[0], cfg, distributed=False, validate=False)

2024-01-11 03:01:01,222 - mmdet - INFO - Automatic scaling of learning rate (LR) has been disabled.
2024-01-11 03:01:01,225 - mmdet - INFO - Start running, host: root@instance-5032, work_dir: /data/ephemeral/home/level2-objectdetection-cv-02/mmdetection/work_dirs/cascade_rcnn_r50_fpn_1x_coco_balanced
2024-01-11 03:01:01,226 - mmdet - INFO - Hooks will be executed in the following order:
before_run:
(VERY_HIGH   ) StepLrUpdaterHook                  
(NORMAL      ) CheckpointHook                     
(VERY_LOW    ) TextLoggerHook                     
(VERY_LOW    ) MMDetWandbHook                     
 -------------------- 
before_train_epoch:
(VERY_HIGH   ) StepLrUpdaterHook                  
(NORMAL      ) NumClassCheckHook                  
(LOW         ) IterTimerHook                      
(VERY_LOW    ) TextLoggerHook                     
(VERY_LOW    ) MMDetWandbHook                     
 -------------------- 
before_train_iter:
(VERY_HIGH   ) StepLrUpdaterHook                  
(LO

2024-01-11 03:01:21,523 - mmdet - INFO - Epoch [1][50/1221]	lr: 1.978e-03, eta: 1:11:14, time: 0.351, data_time: 0.050, memory: 2577, loss_rpn_cls: 0.4160, loss_rpn_bbox: 0.0433, s0.loss_cls: 0.7079, s0.acc: 87.0537, s0.loss_bbox: 0.1248, s1.loss_cls: 0.2950, s1.acc: 87.2021, s1.loss_bbox: 0.0366, s2.loss_cls: 0.1319, s2.acc: 93.7324, s2.loss_bbox: 0.0051, loss: 1.7606, grad_norm: 9.1915
2024-01-11 03:01:35,628 - mmdet - INFO - Epoch [1][100/1221]	lr: 3.976e-03, eta: 1:03:56, time: 0.282, data_time: 0.007, memory: 2577, loss_rpn_cls: 0.1364, loss_rpn_bbox: 0.0347, s0.loss_cls: 0.3199, s0.acc: 93.5049, s0.loss_bbox: 0.1593, s1.loss_cls: 0.0835, s1.acc: 97.2617, s1.loss_bbox: 0.0454, s2.loss_cls: 0.0236, s2.acc: 98.7051, s2.loss_bbox: 0.0060, loss: 0.8087, grad_norm: 2.9059
2024-01-11 03:01:49,552 - mmdet - INFO - Epoch [1][150/1221]	lr: 5.974e-03, eta: 1:01:06, time: 0.278, data_time: 0.006, memory: 2577, loss_rpn_cls: 0.1152, loss_rpn_bbox: 0.0369, s0.loss_cls: 0.3386, s0.acc: 92.8672,

VBox(children=(Label(value='4762.293 MB of 5273.069 MB uploaded\r'), FloatProgress(value=0.9031350714381969, m…

0,1
learning_rate,▃▅██████████████████████████████▁▁▁▁▁▁▁▁
momentum,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/grad_norm,▅█▂▅▄▂▂▄▃▁▂▃▃▃▂▂▃▂▅▂▄▃▅▄▄▅▄▃▂▅▂▄▂▃▃▁▂▁▆▂
train/loss,▅█▇▆▇▅▄▇▅▅▅▆▆▅▄▄▃▄▆▄▆▄▅▅▄▄▄▄▂▅▂▅▂▃▂▁▂▁▄▁
train/loss_rpn_bbox,▄▇▅▄█▄▂▅▃▄▃▃▅▅▄▄▂▆▅▃▅▃▄▄▄▃▄▃▄▄▂▄▂▃▂▁▃▁▄▁
train/loss_rpn_cls,▅█▅▄▄▄▂▄▃▃▂▃▃▃▃▃▂▃▄▂▃▂▃▃▃▃▂▂▂▂▂▂▁▂▁▁▁▁▂▁
train/s0.acc,▄▂▃▃▂▃▅▁▃▃▃▂▂▃▄▅▅▄▃▄▃▄▃▃▄▄▄▃▇▃▇▄▇▆▆█▇█▄▇
train/s0.loss_bbox,▆█▅▄▅▄▃▆▄▄▃▄▅▄▃▃▂▃▄▃▄▃▄▄▃▃▃▄▁▃▂▃▁▂▁▁▂▁▃▁
train/s0.loss_cls,▅█▇▅▆▅▄▇▄▅▅▆▆▅▄▃▃▄▅▄▅▃▅▅▄▄▄▄▂▅▂▄▂▂▂▁▂▁▄▂
train/s1.acc,█▅▄▃▂▄▅▁▂▃▃▃▂▃▄▄▄▄▃▄▃▄▄▃▄▃▄▄▇▄▆▄▇▅▇█▆█▄▇

0,1
learning_rate,0.002
momentum,0.9
train/grad_norm,2.74297
train/loss,0.54596
train/loss_rpn_bbox,0.02431
train/loss_rpn_cls,0.02502
train/s0.acc,93.75977
train/s0.loss_bbox,0.06985
train/s0.loss_cls,0.18638
train/s1.acc,93.64341
