In [None]:
from mmdet.apis import init_detector
from mmdet.apis import inference_detector
from mmengine.config import Config
import tqdm
import os
import time 
from copy import deepcopy
from ensemble_boxes import weighted_boxes_fusion

In [2]:
# Detector 1: config / checkpoint pair and the inference-time overrides.
config_file = "configs/_rdd/main.py"
checkpoint_file = "model1.pth"
imgsz, conf_thres, iou_thres = 640, 0.5, 0.999

cfg = Config.fromfile(config_file)

# NOTE(review): this writes an `img_scale` key into the second step of
# `cfg.test_pipeline`; confirm that this is the pipeline (and the key) that is
# actually read at inference time.
cfg['test_pipeline'][1]['img_scale'] = (imgsz, imgsz)

# Override the score cut-off and the NMS IoU threshold of the `rcnn` test config.
rcnn_test_cfg = cfg.model.test_cfg.rcnn
rcnn_test_cfg.score_thr = float(conf_thres)
rcnn_test_cfg.nms.iou_threshold = float(iou_thres)

In [3]:
# Test-time-augmentation (TTA) settings, attached to the currently loaded `cfg`.
# NOTE(review): nothing visible in this notebook copies `cfg.tta_model` into
# `cfg.model` or `cfg.tta_pipeline` into the test dataloader -- confirm whether
# these settings actually influence inference.
cfg.tta_model = {
    'type': 'DetTTAModel',
    'tta_cfg': {
        'nms': {'type': 'nms', 'iou_threshold': iou_thres},
        'max_per_img': 100,
    },
}

# Walk through nested dataset wrappers until the innermost dataset config
# (the one without a further 'dataset' key) is reached.
test_data_cfg = cfg.test_dataloader.dataset
while 'dataset' in test_data_cfg:
    test_data_cfg = test_data_cfg['dataset']

# Work on a deep copy so the regular test pipeline is left untouched.
cfg.tta_pipeline = deepcopy(test_data_cfg.pipeline)

flip_tta = {
    'type': 'TestTimeAug',
    'transforms': [
        # Two variants per image: always flipped (prob=1) and never flipped (prob=0).
        [
            {'type': 'RandomFlip', 'prob': 1.},
            {'type': 'RandomFlip', 'prob': 0.},
        ],
        [
            {
                'type': 'PackDetInputs',
                'meta_keys': ('img_id', 'img_path', 'ori_shape',
                              'img_shape', 'scale_factor', 'flip',
                              'flip_direction'),
            },
        ],
    ],
}

# The last step of the copied pipeline is replaced by the TTA wrapper.
cfg.tta_pipeline[-1] = flip_tta

In [None]:
# Build detector 1 from the current `cfg` and load its checkpoint onto GPU 2.
model1 = init_detector(config=cfg, checkpoint=checkpoint_file, device='cuda:2')

In [5]:
# Detector 2: config / checkpoint pair and the inference-time overrides.
config_file = "configs/_rdd/main2.py"
checkpoint_file = "model2.pth"
imgsz, conf_thres, iou_thres = 640, 0.5, 0.999

cfg = Config.fromfile(config_file)

# NOTE(review): this writes an `img_scale` key into the second step of
# `cfg.test_pipeline`; confirm that this is the pipeline (and the key) that is
# actually read at inference time.
cfg['test_pipeline'][1]['img_scale'] = (imgsz, imgsz)

# Override the score cut-off and the NMS IoU threshold of the `rcnn` test config.
rcnn_test_cfg = cfg.model.test_cfg.rcnn
rcnn_test_cfg.score_thr = float(conf_thres)
rcnn_test_cfg.nms.iou_threshold = float(iou_thres)

In [None]:
# Build detector 2 from the current `cfg` and load its checkpoint onto GPU 2.
model2 = init_detector(config=cfg, checkpoint=checkpoint_file, device='cuda:2')

In [7]:
# Detector 3: config / checkpoint pair and the inference-time overrides.
config_file = "configs/_rdd/main3.py"
checkpoint_file = "model3.pth"
imgsz, conf_thres, iou_thres = 640, 0.5, 0.999

cfg = Config.fromfile(config_file)

# NOTE(review): this writes an `img_scale` key into the second step of
# `cfg.test_pipeline`; confirm that this is the pipeline (and the key) that is
# actually read at inference time.
cfg['test_pipeline'][1]['img_scale'] = (imgsz, imgsz)

# Override the score cut-off and the NMS IoU threshold of the `rcnn` test config.
rcnn_test_cfg = cfg.model.test_cfg.rcnn
rcnn_test_cfg.score_thr = float(conf_thres)
rcnn_test_cfg.nms.iou_threshold = float(iou_thres)

In [None]:
# Build detector 3 from the current `cfg` and load its checkpoint onto GPU 2.
model3 = init_detector(config=cfg, checkpoint=checkpoint_file, device='cuda:2')

In [9]:
# Detector 4: config / checkpoint pair and the inference-time overrides.
config_file = "configs/_rdd/main7.py"
checkpoint_file = "model4.pth"
imgsz, conf_thres, iou_thres = 640, 0.5, 0.999

cfg = Config.fromfile(config_file)

# NOTE(review): this writes an `img_scale` key into the second step of
# `cfg.test_pipeline`; confirm that this is the pipeline (and the key) that is
# actually read at inference time.
cfg['test_pipeline'][1]['img_scale'] = (imgsz, imgsz)

# Override the score cut-off and the NMS IoU threshold of the `rcnn` test config.
rcnn_test_cfg = cfg.model.test_cfg.rcnn
rcnn_test_cfg.score_thr = float(conf_thres)
rcnn_test_cfg.nms.iou_threshold = float(iou_thres)

In [None]:
# Build detector 4 from the current `cfg` and load its checkpoint onto GPU 2.
model4 = init_detector(config=cfg, checkpoint=checkpoint_file, device='cuda:2')

In [11]:
# Detector 5: config / checkpoint pair and the inference-time overrides.
# Unlike the other detectors in this notebook, this one uses imgsz = 1280.
config_file = "configs/_rdd/main8.py"
checkpoint_file = "model5.pth"
imgsz, conf_thres, iou_thres = 1280, 0.5, 0.999

cfg = Config.fromfile(config_file)

# NOTE(review): this writes an `img_scale` key into the second step of
# `cfg.test_pipeline`; confirm that this is the pipeline (and the key) that is
# actually read at inference time.
cfg['test_pipeline'][1]['img_scale'] = (imgsz, imgsz)

# Override the score cut-off and the NMS IoU threshold of the `rcnn` test config.
rcnn_test_cfg = cfg.model.test_cfg.rcnn
rcnn_test_cfg.score_thr = float(conf_thres)
rcnn_test_cfg.nms.iou_threshold = float(iou_thres)

In [None]:
# Build detector 5 from the current `cfg` and load its checkpoint onto GPU 2.
model5 = init_detector(config=cfg, checkpoint=checkpoint_file, device='cuda:2')

In [None]:
# Ensemble inference: run every detector on each test image, fuse the
# predictions with Weighted Boxes Fusion (WBF) and write the fused boxes to one
# text file per country plus a combined `test_all.txt`.
#
# Output line format (one line per image):
#     <image_name>,<label> <xmin> <ymin> <xmax> <ymax> <label> <xmin> ... \n

# Root folder holding one `<country>/test/images/` directory per country.
DATA_ROOT = "/media/oem/storage01/jmjeong/rdd2022/RDD2022"
COUNTRIES = ['India', 'United_States', 'Japan', 'Norway', 'Czech', 'China_MotorBike']

# Every country currently uses the same detectors with the same WBF weights.
ENSEMBLE_MODELS = [model1, model2, model3, model4, model5]
ENSEMBLE_WEIGHTS = [3, 2, 1, 1, 2]

# IoU threshold handed to WBF; reuses the value set in the config cells.
# NOTE(review): with a value this close to 1 only near-identical boxes are
# fused, so overlapping detections from different models are mostly kept as
# separate boxes -- confirm this is intended.
WBF_IOU_THR = iou_thres

# Per-country ensemble definition (kept as a dict so a single country can be
# given its own models / weights later). Each entry gets its own list copies.
model_list = {
    country: {
        'models': list(ENSEMBLE_MODELS),
        'weights': list(ENSEMBLE_WEIGHTS),
    }
    for country in COUNTRIES
}


def _detect_normalized(model, image_path):
    """Run one detector on one image and return WBF-ready predictions.

    Returns:
        (bboxes, scores, labels, (image_height, image_width)) where the three
        prediction entries are NumPy arrays and the box coordinates are
        divided by the original image width / height.
    """
    result = inference_detector(model, image_path)
    image_height, image_width = result.ori_shape

    instances = result.pred_instances
    # WBF is documented for lists / NumPy arrays, so move the predictions off
    # the GPU once here instead of letting it index device tensors element by
    # element.
    bboxes = instances.bboxes.detach().cpu()
    scores = instances.scores.detach().cpu().numpy()
    labels = instances.labels.detach().cpu().numpy()

    if len(bboxes) > 0:
        # Columns are x1, y1, x2, y2: scale x by the width and y by the height.
        bboxes = bboxes / bboxes.new_tensor(
            [image_width, image_height, image_width, image_height])

    return bboxes.numpy(), scores, labels, (image_height, image_width)


def _format_boxes(bboxes, labels, image_width, image_height):
    """Turn fused, normalized boxes into the text that follows the image name.

    Coordinates are scaled back to pixels and truncated to int; labels are
    written shifted by +1. Every box entry ends with a space and the returned
    string ends with a newline.
    """
    parts = []
    for bbox, label in zip(bboxes, labels):
        xmin = int(bbox[0] * image_width)
        ymin = int(bbox[1] * image_height)
        xmax = int(bbox[2] * image_width)
        ymax = int(bbox[3] * image_height)
        parts.append(f'{int(label) + 1} {xmin} {ymin} {xmax} {ymax} ')
    return ''.join(parts) + '\n'


with open('test_all.txt', 'w') as all_file:
    for country in COUNTRIES:
        models = model_list[country]['models']
        weights = model_list[country]['weights']
        image_dir = os.path.join(DATA_ROOT, country, 'test', 'images')

        with open(f'test_{country}.txt', 'w') as country_file:
            # Sorted so the output files are reproducible: os.listdir returns
            # entries in arbitrary order.
            for image_name in tqdm.tqdm(sorted(os.listdir(image_dir)), desc=country):
                image_path = os.path.join(image_dir, image_name)

                bboxes_list, scores_list, labels_list = [], [], []
                for model in models:
                    bboxes, scores, labels, (image_height, image_width) = \
                        _detect_normalized(model, image_path)
                    bboxes_list.append(bboxes)
                    scores_list.append(scores)
                    labels_list.append(labels)

                fused_bboxes, _, fused_labels = weighted_boxes_fusion(
                    bboxes_list, scores_list, labels_list,
                    weights=weights, iou_thr=WBF_IOU_THR)

                # All detectors saw the same image, so the size reported by the
                # last one is used to map the fused boxes back to pixels.
                line = image_name + ',' + _format_boxes(
                    fused_bboxes, fused_labels, image_width, image_height)

                country_file.write(line)
                all_file.write(line)
                