In [1]:
%matplotlib inline
import numpy as np
import cv2
from matplotlib import pyplot as plt

In [3]:
from tqdm import tqdm

import shutil
import os

In [3]:
from mmdet.apis import init_detector, inference_detector, show_result_pyplot
import mmcv

In [4]:
# Recreate an empty output directory so images from earlier runs never mix with this run's.
PRED_DIR = 'predictions'
shutil.rmtree(PRED_DIR, ignore_errors=True)
os.makedirs(PRED_DIR)

In [4]:
BOX_COLOR = (0, 0, 255)
TEXT_COLOR = (255, 255, 255)


def visualize_bbox(img, bbox,
                   score=None,
                   color=BOX_COLOR, thickness=2, score_thr=0.5):
    """Draw one bounding box and a small text label onto ``img`` (in place).

    Args:
        img: image array passed to the OpenCV drawing calls; it is modified in place.
        bbox: box corners ``(x_min, y_min, x_max, y_max)``. Any extra trailing
            values are ignored. Coordinates are truncated to ints.
        score: detection confidence. ``None`` marks a ground-truth box, which is
            labelled ``'gt'``. Otherwise the label is the score with three decimals.
        color: color tuple for the box and the label background, in the image's
            channel order (BGR for images read with ``cv2.imread``).
        thickness: box outline thickness in pixels.
        score_thr: predicted boxes with ``score < score_thr`` are skipped entirely.
            Nothing is drawn for them.

    Returns:
        The same ``img`` object.
    """
    # Filter low-confidence predictions *before* drawing. Otherwise they would
    # still leave an unlabelled rectangle on the image.
    # `is not None` (not truthiness) so a score of 0.0 is not mistaken for ground truth.
    if score is not None and score < score_thr:
        return img

    # Built-in int instead of the removed `np.int` alias (gone in NumPy 1.24).
    x_min, y_min, x_max, y_max = (int(v) for v in np.asarray(bbox, dtype=float)[:4])
    cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=color, thickness=thickness)

    label = 'gt' if score is None else '%.3f' % score
    (text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1)
    # Filled strip just above the box's top-left corner, then the label text on it.
    cv2.rectangle(img, (x_min, y_min - int(1.3 * text_height)), (x_min + text_width, y_min), color, -1)
    cv2.putText(img, label, (x_min, y_min - int(0.3 * text_height)),
                cv2.FONT_HERSHEY_SIMPLEX, 0.35, TEXT_COLOR, lineType=cv2.LINE_AA)
    return img



In [5]:
# Experiment to evaluate: an mmdetection config and a checkpoint trained from it.
# The checkpoint is read from work_dirs/<experiment>/ (presumably written there
# by the local training run), not from a model-zoo download.
# Both paths derive from one name so they cannot drift apart.
EXP_NAME = 'fr50_lite_dcn_gn_icf_ms49_1549_2x'
EPOCH = 21
config_file = 'mmdetection/configs/icartoonface/%s.py' % EXP_NAME
checkpoint_file = 'work_dirs/%s/epoch_%d.pth' % (EXP_NAME, EPOCH)

In [6]:
# Instantiate the detector described by `config_file` and load the weights
# stored in `checkpoint_file`, placing the model on the first CUDA device.
DEVICE = 'cuda:0'
model = init_detector(config_file, checkpoint_file, device=DEVICE)

In [7]:
# Path to the pickled validation-set annotations. This cell only sets the path;
# the file is read with mmcv.load in the visualization cell.
DATA_ROOT = '../data/'
val_pkl = DATA_ROOT + 'icartoonface/dval.pkl'



In [8]:
# Visualize ground-truth boxes (green) and model predictions on the validation
# set, writing one annotated copy of each image into ./predictions/.
# NOTE(review): mmcv.load on a .pkl unpickles it - only use trusted files.
data_val = mmcv.load(val_pkl)

for data in tqdm(data_val):
    file_name = data['filename']
    img_path = os.path.join('../data', file_name)
    img = cv2.imread(img_path)
    # cv2.imread signals a missing/unreadable file by returning None instead of raising.
    if img is None:
        raise FileNotFoundError('could not read image: %s' % img_path)

    # Run inference on the clean image *before* drawing anything. visualize_bbox
    # draws in place, so drawing GT first would feed annotated pixels to the detector.
    result = inference_detector(model, img)

    # Draw ground-truth boxes from the annotations.
    for bbox in data['ann']['bboxes']:
        img = visualize_bbox(img, bbox, color=(0, 255, 0))

    # Draw predicted boxes; each row is used as [x1, y1, x2, y2, score].
    # NOTE(review): result[0] assumes a bbox-only model returning a list of
    # per-class arrays, with class 0 the only class of interest - confirm.
    for bbox in result[0]:
        img = visualize_bbox(img, bbox[:4], score=bbox[-1])

    out_path = os.path.join('./predictions', os.path.basename(file_name))
    # cv2.imwrite reports failure via its return value, not an exception.
    if not cv2.imwrite(out_path, img):
        raise IOError('could not write image: %s' % out_path)

100%|██████████| 2500/2500 [07:17<00:00,  5.71it/s]
