In [1]:
# Module imports

import os.path as osp

import wandb
from mmcv import Config
from mmdet.apis import set_random_seed, train_detector
from mmdet.datasets import build_dataset
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)
from mmdet.models import build_detector
from mmdet.utils import get_device



In [69]:
# Competition classes; their order defines the label indices 0..9.
classes = ("General trash", "Paper", "Paper pack", "Metal", "Glass",
           "Plastic", "Styrofoam", "Plastic bag", "Battery", "Clothing")

# Base config to fine-tune.
filename = './configs/cascade_rcnn/cascade_rcnn_swin_l_fpn_1x_coco.py'
cfg = Config.fromfile(filename)

# Directory holding the images plus train.json / test.json.
root = '../../dataset/'


def point_to_competition_data(dataset_cfg, ann_name, data_root=root, class_names=classes):
    """Point one dataset config at the competition data, in place.

    Args:
        dataset_cfg: the dataset ConfigDict to edit (for wrapped train sets this
            must be the inner ``.dataset`` entry, not the wrapper).
        ann_name: annotation file name relative to ``data_root``.
        data_root: image prefix and annotation directory.
        class_names: class names to load. These must be set explicitly: if left
            unset, mmdet's CustomDataset uses the dataset class's built-in
            CLASSES (the COCO names for CocoDataset), which do not match train.json.
    """
    dataset_cfg.classes = class_names
    dataset_cfg.img_prefix = data_root
    dataset_cfg.ann_file = data_root + ann_name


# Edit the dataset section of the config.
if 'mstrain' in filename or 'scp' in filename:
    # The train set is wrapped, so the CocoDataset settings live one level down.
    point_to_competition_data(cfg.data.train.dataset, 'train.json')
    # NOTE(review): val reuses test.json — confirm it carries annotations before enabling validation.
    point_to_competition_data(cfg.data.val, 'test.json')
    point_to_competition_data(cfg.data.test, 'test.json')
elif 'aug1' in filename:
    # The train set is wrapped here as well.
    point_to_competition_data(cfg.data.train.dataset, 'train.json')
    point_to_competition_data(cfg.data.test, 'test.json')
else:
    point_to_competition_data(cfg.data.train, 'train.json')
    point_to_competition_data(cfg.data.test, 'test.json')
    # Resize scale; assumes the resize step sits at these pipeline indices in
    # the base config — TODO re-check if the base config changes.
    cfg.data.train.pipeline[2]['img_scale'] = (1024, 1024)
    cfg.data.test.pipeline[1]['img_scale'] = (1024, 1024)


# Logging: plain-text logs plus a Weights & Biases hook.
# The W&B run is named after the config file name (last path component of `filename`).
wandb_hook = dict(
    type='MMDetWandbHook',
    init_kwargs=dict(
        project="object_detection",
        entity="cv-2",
        name=filename.split('/')[-1],
    ),
    interval=10,
    log_checkpoint=True,
    log_checkpoint_metadata=True,
    num_eval_images=100,
)
cfg.log_config.hooks = [dict(type='TextLoggerHook'), wandb_hook]

# Runtime settings.
cfg.data.samples_per_gpu = 4  # per-GPU batch size (4 is used for Swin-S/L backbones)
cfg.gpu_ids = [0]

# Reproducibility: cfg.seed by itself is only consumed inside train_detector
# (data-loader seeding in mmdet 2.x). Seed the global RNGs as well, so that the
# model built and initialised in the following cells is repeatable.
cfg.seed = 2022
set_random_seed(cfg.seed)

# Name the output directory after the config file so it stays in sync with
# `filename` and switching configs cannot overwrite another run's checkpoints.
cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(filename))[0])



# Change the number of classes of every bbox head to match `classes`.
def set_num_classes(model_cfg, num_classes):
    """Set ``num_classes`` on every bbox head of an mmdet model config, in place.

    Two-stage detectors keep their heads under ``roi_head.bbox_head``. That entry
    is a list for cascade models (one head per stage) and a single dict otherwise.
    Detectors without an ``roi_head`` (RetinaNet, ATSS, DETR, YOLOv3, YOLOX, ...)
    carry ``bbox_head`` directly on the model config.

    Args:
        model_cfg: the ``cfg.model`` ConfigDict (plain dicts with the same layout also work).
        num_classes: number of foreground classes.
    """
    roi_head = model_cfg.get('roi_head')
    if roi_head is None:
        model_cfg['bbox_head']['num_classes'] = num_classes
        return

    bbox_heads = roi_head['bbox_head']
    if not isinstance(bbox_heads, (list, tuple)):
        bbox_heads = [bbox_heads]
    for bbox_head in bbox_heads:
        bbox_head['num_classes'] = num_classes


# Derive the count from `classes` so the heads cannot disagree with the dataset.
set_num_classes(cfg.model, len(classes))


# Training schedule: number of epochs.
cfg.runner = dict(type='EpochBasedRunner', max_epochs=50)
# NOTE(review): the base config is a "1x" schedule. If its lr_config still uses the
# 1x step milestones, the LR decays long before epoch 50 — TODO confirm and
# rescale cfg.lr_config.step together with max_epochs.

# Swin-S/L runs use batch size 4, so that is the reference batch size for LR scaling.
# This must be assigned on cfg: train_detector reads cfg.auto_scale_lr, and a bare
# notebook variable is never seen by it.
cfg.auto_scale_lr = dict(enable=False, base_batch_size=4)

# Clip gradients by their global L2 norm.
cfg.optimizer_config.grad_clip = dict(max_norm=35, norm_type=2)
# Save a checkpoint every epoch and keep at most two of them.
cfg.checkpoint_config = dict(max_keep_ckpts=2, interval=1)

cfg.device = get_device()


In [70]:
# Build the training dataset from the edited config; datasets[0] is the train split.
datasets = [build_dataset(cfg.data.train)]

loading annotations into memory...
Done (t=0.08s)
creating index...
index created!


In [71]:
# Inspect the training dataset: its repr lists the image count and per-class instance counts.
datasets[0]


CocoDataset Train dataset with number of images 4883, and instance counts: 
+-------------------+-------+---------------+-------+-----------------+-------+-------------+-------+--------------+-------+
| category          | count | category      | count | category        | count | category    | count | category     | count |
+-------------------+-------+---------------+-------+-----------------+-------+-------------+-------+--------------+-------+
| 0 [General trash] | 3965  | 1 [Paper]     | 6352  | 2 [Paper pack]  | 897   | 3 [Metal]   | 936   | 4 [Glass]    | 982   |
| 5 [Plastic]       | 2943  | 6 [Styrofoam] | 1263  | 7 [Plastic bag] | 5178  | 8 [Battery] | 159   | 9 [Clothing] | 468   |
+-------------------+-------+---------------+-------+-----------------+-------+-------------+-------+--------------+-------+

In [72]:
# Build the detector from cfg.model and initialise its weights.
# init_weights() also loads the pretrained backbone checkpoint referenced by the
# config (the cell output shows the Swin-L 22k checkpoint being fetched).
model = build_detector(cfg.model)
model.init_weights()

2024-01-15 01:14:23,159 - mmdet - INFO - load checkpoint from http path: https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth
2024-01-15 01:14:27,804 - mmdet - INFO - initialize FPN with init_cfg {'type': 'Xavier', 'layer': 'Conv2d', 'distribution': 'uniform'}
2024-01-15 01:14:27,833 - mmdet - INFO - initialize RPNHead with init_cfg {'type': 'Normal', 'layer': 'Conv2d', 'std': 0.01}
2024-01-15 01:14:27,839 - mmdet - INFO - initialize Shared2FCBBoxHead with init_cfg [{'type': 'Normal', 'std': 0.01, 'override': {'name': 'fc_cls'}}, {'type': 'Normal', 'std': 0.001, 'override': {'name': 'fc_reg'}}, {'type': 'Xavier', 'distribution': 'uniform', 'override': [{'name': 'shared_fcs'}, {'name': 'cls_fcs'}, {'name': 'reg_fcs'}]}]
2024-01-15 01:14:27,940 - mmdet - INFO - initialize Shared2FCBBoxHead with init_cfg [{'type': 'Normal', 'std': 0.01, 'override': {'name': 'fc_cls'}}, {'type': 'Normal', 'std': 0.001, 'override': {'name': 'fc_reg'}}, {'

In [None]:
# Train the detector on the training split: single process, no in-loop validation.
train_detector(
    model,
    datasets[0],
    cfg,
    distributed=False,
    validate=False,
)

2024-01-15 01:14:35,872 - mmdet - INFO - Automatic scaling of learning rate (LR) has been disabled.
2024-01-15 01:14:35,877 - mmdet - INFO - Start running, host: root@instance-5087, work_dir: /data/ephemeral/home/level2-objectdetection-cv-02/mmdetection/work_dirs/cascade_rcnn_swin_l_fpn_1x_coco
2024-01-15 01:14:35,877 - mmdet - INFO - Hooks will be executed in the following order:
before_run:
(VERY_HIGH   ) StepLrUpdaterHook                  
(NORMAL      ) CheckpointHook                     
(VERY_LOW    ) TextLoggerHook                     
(VERY_LOW    ) MMDetWandbHook                     
 -------------------- 
before_train_epoch:
(VERY_HIGH   ) StepLrUpdaterHook                  
(NORMAL      ) NumClassCheckHook                  
(LOW         ) IterTimerHook                      
(VERY_LOW    ) TextLoggerHook                     
(VERY_LOW    ) MMDetWandbHook                     
 -------------------- 
before_train_iter:
(VERY_HIGH   ) StepLrUpdaterHook                  
(LOW     

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112509735135568, max=1.0…

2024-01-15 01:17:23,416 - mmdet - INFO - Epoch [1][50/1221]	lr: 1.978e-03, eta: 1 day, 23:04:09, time: 2.778, data_time: 0.062, memory: 26038, loss_rpn_cls: 0.3292, loss_rpn_bbox: 0.0496, s0.loss_cls: 0.4227, s0.acc: 94.4551, s0.loss_bbox: 0.0772, s1.loss_cls: 0.2318, s1.acc: 89.1191, s1.loss_bbox: 0.0224, s2.loss_cls: 0.1408, s2.acc: 85.8809, s2.loss_bbox: 0.0029, loss: 1.2767, grad_norm: 9.1908
2024-01-15 01:19:40,470 - mmdet - INFO - Epoch [1][100/1221]	lr: 3.976e-03, eta: 1 day, 22:43:09, time: 2.741, data_time: 0.012, memory: 26038, loss_rpn_cls: 0.1315, loss_rpn_bbox: 0.0358, s0.loss_cls: 0.2807, s0.acc: 94.4180, s0.loss_bbox: 0.1286, s1.loss_cls: 0.0871, s1.acc: 97.3828, s1.loss_bbox: 0.0402, s2.loss_cls: 0.0306, s2.acc: 98.6113, s2.loss_bbox: 0.0058, loss: 0.7403, grad_norm: 2.2039
2024-01-15 01:21:57,657 - mmdet - INFO - Epoch [1][150/1221]	lr: 5.974e-03, eta: 1 day, 22:35:31, time: 2.744, data_time: 0.013, memory: 26038, loss_rpn_cls: 0.0912, loss_rpn_bbox: 0.0354, s0.loss_cl

No regex pattern specified. Nothing done.
