In [38]:
# 모듈 import

from mmcv import Config
from mmdet.apis import set_random_seed, train_detector
from mmdet.datasets import build_dataset
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)
from mmdet.models import build_detector
from mmdet.utils import get_device

In [39]:
# Detection target classes. Their order fixes each class's label index
# (0 = General trash, ..., 9 = Clothing).
classes = ("General trash", "Paper", "Paper pack", "Metal", "Glass",
           "Plastic", "Styrofoam", "Plastic bag", "Battery", "Clothing")

# ------------------ Per-experiment settings ------------------
# Config file stem under ./configs/cascade_rcnn/.
model = "cascade_rcnn_x101_64x4d_fpn_1x_coco"
# In this notebook this flag only labels the wandb run name; it does not
# change the data pipeline.
augmentation = False
# --------------------------------------------------------------

cfg = Config.fromfile(f'./configs/cascade_rcnn/{model}.py')

root = '../../dataset/'

# Point every split at the local dataset: class names, image root and the
# split's annotation file.
for split, ann_file in (('train', 'train.json'),
                        ('val', 'val.json'),
                        ('test', 'test.json')):
    split_cfg = cfg.data[split]
    split_cfg.classes = classes
    split_cfg.img_prefix = root
    split_cfg.ann_file = root + ann_file
# To change the input resolution, edit the Resize step's img_scale, e.g.
#   cfg.data.train.pipeline[2]['img_scale'] = (512, 512)
#   cfg.data.val.pipeline[1]['img_scale'] = (1024, 1024)

cfg.data.samples_per_gpu = 4  # batch size per GPU

cfg.seed = 2022
cfg.gpu_ids = [0]
# Setting cfg.seed by itself does not seed the global Python/NumPy/PyTorch
# RNGs, and no other cell does; seed them here so the detector's weight
# initialisation in the model-build cell is repeatable.
set_random_seed(cfg.seed)

cfg.work_dir = f'./work_dirs/{model}_trash'

# Set num_classes on every bbox head.
# - Two-stage detectors keep their heads under roi_head; Cascade R-CNN has one
#   head per stage (a list), other two-stage models have a single dict.
# - Single-stage detectors have model.bbox_head directly.
# Config nodes are ConfigDict (a dict subclass), so `type(x) == dict` never
# matches them -- isinstance is required.
num_classes = len(classes)
if 'roi_head' in cfg.model:
    bbox_head = cfg.model.roi_head.bbox_head
else:
    bbox_head = cfg.model.bbox_head
heads = bbox_head if isinstance(bbox_head, (list, tuple)) else [bbox_head]
for each_head in heads:
    if 'num_classes' not in each_head:
        raise Exception("Num_classes가 없습니다. 제대로 찾으셨나요?")
    each_head.num_classes = num_classes

# Clip gradients by their global L2 norm (max_norm=35).
cfg.optimizer_config.grad_clip = dict(max_norm=35, norm_type=2)
# Save a checkpoint every interval (=1) and keep at most the 3 most recent.
cfg.checkpoint_config = dict(max_keep_ckpts=3, interval=1)
# Device string chosen by mmdet.utils.get_device().
cfg.device = get_device()

In [40]:
# Configure the Weights & Biases logger hook: the account/team the run is
# logged under and the run's display name.
# The hook is looked up by type rather than by its position in
# log_config.hooks, so adding or reordering hooks in the base config cannot
# silently make this cell edit the wrong hook.
wandb_hooks = [hook for hook in cfg.log_config.hooks if 'Wandb' in hook['type']]
if not wandb_hooks:
    raise ValueError("No wandb logger hook found in cfg.log_config.hooks")
wandb_hook = wandb_hooks[0]
wandb_hook.init_kwargs.entity = "imsmile2000"  # set this to your own wandb user/team ID
wandb_hook.init_kwargs.name = f"{model}+aug={augmentation}"

In [41]:
# Display the wandb init kwargs (project/entity/name) to confirm them before
# training; the hook is found by type, not by list position.
next(hook for hook in cfg.log_config.hooks if 'Wandb' in hook['type']).init_kwargs

{'project': 'trash_detection V1 (model selection)',
 'entity': 'imsmile2000',
 'name': 'cascade_rcnn_x101_64x4d_fpn_1x_coco+aug=False'}

In [42]:
# Build the training dataset described by cfg.data.train.
train_dataset = build_dataset(cfg.data.train)
datasets = [train_dataset]

loading annotations into memory...
Done (t=0.08s)
creating index...
index created!


In [43]:
# Inspect the training dataset: number of images and per-class instance counts.
train_preview = datasets[0]
train_preview


CocoDataset Train dataset with number of images 4883, and instance counts: 
+-------------------+-------+---------------+-------+-----------------+-------+-------------+-------+--------------+-------+
| category          | count | category      | count | category        | count | category    | count | category     | count |
+-------------------+-------+---------------+-------+-----------------+-------+-------------+-------+--------------+-------+
| 0 [General trash] | 3965  | 1 [Paper]     | 6352  | 2 [Paper pack]  | 897   | 3 [Metal]   | 936   | 4 [Glass]    | 982   |
| 5 [Plastic]       | 2943  | 6 [Styrofoam] | 1263  | 7 [Plastic bag] | 5178  | 8 [Battery] | 159   | 9 [Clothing] | 468   |
+-------------------+-------+---------------+-------+-----------------+-------+-------------+-------+--------------+-------+

In [44]:
# Build the detector described by cfg.model, then run weight initialisation.
# init_weights() applies each module's init_cfg, including loading the
# pretrained backbone checkpoint.
# NOTE(review): `model` previously held the config-name string from the config
# cell and is rebound to the detector here. Re-run the config cell before
# re-running the wandb cell, which formats `model` into the run name.
detector = build_detector(cfg.model)
detector.init_weights()
model = detector  # the training cell passes `model` to train_detector

2023-05-08 08:55:51,845 - mmcv - INFO - initialize ResNeXt with init_cfg {'type': 'Pretrained', 'checkpoint': 'open-mmlab://resnext101_64x4d'}
2023-05-08 08:55:51,846 - mmcv - INFO - load model from: open-mmlab://resnext101_64x4d
2023-05-08 08:55:51,847 - mmcv - INFO - load checkpoint from openmmlab path: open-mmlab://resnext101_64x4d
2023-05-08 08:55:53,388 - mmcv - INFO - initialize FPN with init_cfg {'type': 'Xavier', 'layer': 'Conv2d', 'distribution': 'uniform'}
2023-05-08 08:55:53,475 - mmcv - INFO - initialize RPNHead with init_cfg {'type': 'Normal', 'layer': 'Conv2d', 'std': 0.01}
2023-05-08 08:55:53,490 - mmcv - INFO - initialize Shared2FCBBoxHead with init_cfg [{'type': 'Normal', 'std': 0.01, 'override': {'name': 'fc_cls'}}, {'type': 'Normal', 'std': 0.001, 'override': {'name': 'fc_reg'}}, {'type': 'Xavier', 'distribution': 'uniform', 'override': [{'name': 'shared_fcs'}, {'name': 'cls_fcs'}, {'name': 'reg_fcs'}]}]
2023-05-08 08:55:53,626 - mmcv - INFO - initialize Shared2FCBBo

In [45]:
# Train the detector on the training dataset.
# validate=True turns on periodic evaluation on the validation split
# (cfg.data.val).
train_detector(
    model,
    datasets[0],
    cfg,
    distributed=False,  # single-process run (cfg.gpu_ids = [0])
    validate=True,
)

2023-05-08 08:56:01,973 - mmdet - INFO - Automatic scaling of learning rate (LR) has been disabled.


loading annotations into memory...


2023-05-08 08:56:02,298 - mmdet - INFO - Start running, host: root@c2982ff33823, work_dir: /opt/ml/baseline/baseline_cv11/work_dirs/cascade_rcnn_x101_64x4d_fpn_1x_coco_trash
2023-05-08 08:56:02,299 - mmdet - INFO - Hooks will be executed in the following order:
before_run:
(VERY_HIGH   ) StepLrUpdaterHook                  
(NORMAL      ) CheckpointHook                     
(LOW         ) EvalHook                           
(VERY_LOW    ) TextLoggerHook                     
(VERY_LOW    ) MMDetWandbHook                     
 -------------------- 
before_train_epoch:
(VERY_HIGH   ) StepLrUpdaterHook                  
(NORMAL      ) NumClassCheckHook                  
(LOW         ) IterTimerHook                      
(LOW         ) EvalHook                           
(VERY_LOW    ) TextLoggerHook                     
(VERY_LOW    ) MMDetWandbHook                     
 -------------------- 
before_train_iter:
(VERY_HIGH   ) StepLrUpdaterHook                  
(LOW         ) IterTimerHook 

Done (t=0.28s)
creating index...
index created!


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666897584994634, max=1.0)…

Problem at: /opt/conda/envs/detection/lib/python3.7/site-packages/mmcv/runner/hooks/logger/wandb.py 97 before_run


CommError: Run initialization has timed out after 60.0 sec. 
Please refer to the documentation for additional information: https://docs.wandb.ai/guides/track/tracking-faq#initstarterror-error-communicating-with-wandb-process-