In [2]:
import torch
import config
import utils
import glob
import numpy as np
import cv2
from pathlib import Path

In [3]:
def draw_bboxes(img, preds, thre, class_colors, save_fname):
    preds = [{k: v.to('cpu') for k,v in t.items()} for t in preds]

    if len(preds[0]['boxes']) != 0:
        boxes = preds[0]['boxes'].data.numpy()
        scores = preds[0]['scores'].data.numpy()
        print(f"boxes={boxes}, scores = {scores}")
        
        boxes = boxes[scores >= thre].astype(np.int32)
        pred_classes = [i for i in preds[0]['labels'].cpu().numpy() ]

        for j, box in enumerate(boxes):
            color = class_colors[pred_classes[j]]
            cv2.rectangle(img,
                        (int(box[0]), int(box[1])),
                        (int(box[2]), int(box[3])),
                        color, 2)

        # save the image
        cv2.imwrite(save_fname, img)

        
def inference_1img(model, img_name, device, thre, class_colors):
    in_img = cv2.imread(img_name)

    # convert to tensor
    img = cv2.cvtColor(in_img, cv2.COLOR_BGR2RGB).astype(np.float32)
    img /= 255.0
    img = np.transpose(img, (2,0,1)) # HWC -> CHW
    img = torch.tensor(img, dtype=torch.float).to(device)
    img = torch.unsqueeze(img,0) # add batch dim

    # run inference
    with torch.no_grad():
        preds = model(img)
    print(f"inference on {img_name} done.")

    save_fname = str(Path(config.result_img_dir) / Path(img_name).name)
    draw_bboxes(in_img, preds, thre, class_colors, save_fname)    


In [4]:
Path(config.result_img_dir).mkdir(parents=True, exist_ok=True)
    
device = torch.device("cpu")
saved_name = './result/small.pth'
checkpoint = torch.load(saved_name, map_location=device)
model = utils.get_model_object_detector(config.num_classes)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device).eval()

test_dir = config.test_data_dir
img_format = config.test_img_format
test_imgs = glob.glob(f"{test_dir}/*.{img_format}")

class_colors = np.random.uniform(0, 255, size=(config.num_classes, 3))

for i in range(len(test_imgs)):
    img_name = test_imgs[i]
    inference_1img(model, img_name, device, config.detection_threshold, class_colors)



inference on data/test/img_2.jpg done.
boxes=[[224.67178  248.14891  245.86023  265.09232 ]
 [ 30.520754 115.15315   46.269424 135.59442 ]
 [ 27.924904 115.05443   48.9541   158.35545 ]
 [177.67117  233.8971   193.63019  270.55157 ]
 [231.41232  251.14005  249.11319  267.15506 ]
 [217.1475   276.6272   336.6364   389.6994  ]
 [283.85684  282.6336   320.26477  332.4831  ]], scores = [0.81870145 0.44065732 0.34783265 0.29893565 0.13558812 0.09093992
 0.05539606]
inference on data/test/image-1271-2023-01-14T16-51-24-752193_jpg.rf.9c92e8ba39bd63b02fe0df04299586cd.jpg done.
boxes=[[227.64348 283.97336 359.67453 416.     ]
 [238.28873 288.76794 288.3917  409.38745]], scores = [0.9710005 0.1672477]
inference on data/test/img_0.jpg done.
boxes=[[232.31015 231.91377 291.35388 311.5991 ]
 [152.15683 217.93265 301.97818 319.1821 ]
 [449.31195 295.09418 450.      302.8758 ]], scores = [0.05954088 0.05657011 0.05539208]
inference on data/test/img_5.jpg done.
boxes=[[318.15524  99.04579 475.19348 34