In [None]:
import shutil
from validate_detector import *


# Run configuration for this notebook: dataset YAML, model weights, I/O folders
# and the YOLO confidence threshold. Every later cell reads its settings from here.
params = {
    # dataset YAML handed to yolov5/detect.py via --data
    'data': 'LogoDet-3K_CIL.yaml',
    # detector weights handed to yolov5/detect.py via --weights
    'yolo_model_path': 'weights/yolov5m6-CIL-512px-1000cls.pt',
    # classifier checkpoint used to label each detected crop (resolved against ROOT)
    'cil_model_path': 'weights/CIL_1000_250_2993-WA-mem50-resnet34-pretrained-drop0.5-augmented-adam.pt',
    # second checkpoint passed to load_cil_model (presumably the distilled student — verify)
    'student_model_path': 'weights/kd_resnet50-drop0.3-mem50_STATE_DICT.pt',
    # run name, i.e. results land in yolov5/runs/detect/<detection_out>
    'detection_out': 'tmp',
    # folder of images to run detection on (--source)
    'detection_input': 'dataset/LogoDet-3K_det4cil/inf',
    # minimum detector confidence (--conf-thres)
    'conf_thres': 0.4,
}

In [None]:
# Folder where yolov5/detect.py writes this run's results; its `labels/`
# sub-folder holds the per-image .txt label files read further below.
yolo_res_path = Path('yolov5/runs/detect', params['detection_out'])

# NOTE(review): `inf_out` is not used by the cells below — the annotated images are
# written to `<detection_input>-res` instead. Kept so nothing else breaks.
inf_out = Path('yolov5/runs/detect', params['detection_input'] + '-res')

# Metadata DataFrame (presumably describing the cropped LogoDet-3K images — verify).
# NOTE(review): read_pickle unpickles the file and can run arbitrary code; only load trusted data.
metadata = pd.read_pickle(
    Path(DATASET_PATH, LOGODET_3K_NORMAL_PATH, METADATA_CROPPED_IMAGE_PATH)
)


In [None]:
# Configure
cil_model, cil_idx2class, cil_class2idx, cil_class_remap = load_cil_model(
    ROOT / params['cil_model_path'],
    params['student_model_path']
)
cil_model.eval()


In [None]:
import subprocess
import os
import shutil
import sys
from pathlib import Path

# Folder that the annotation cell writes its output images into.
annotated_dir = Path(params['detection_input'] + '-res')

# Start from a clean slate so labels from a previous run cannot leak into img2label.
shutil.rmtree(yolo_res_path, ignore_errors=True)
shutil.rmtree(annotated_dir, ignore_errors=True)
# exist_ok: rmtree above ignores errors, so the folder may survive it.
os.makedirs(annotated_dir, exist_ok=True)

# Argument list instead of a shell string: no quoting problems with paths, no shell.
# sys.executable runs detect.py with the kernel's own interpreter/environment rather
# than whatever `python` happens to be first on PATH.
# `--name` is given exactly once so results land in yolo_res_path
# (previously a second, hard-coded `--name tmp` overrode params['detection_out']).
cmd = [
    sys.executable, 'yolov5/detect.py',
    '--data', params['data'],
    '--weights', params['yolo_model_path'],
    '--source', params['detection_input'],
    '--conf-thres', str(params['conf_thres']),
    '--name', params['detection_out'],
    '--augment', '--save-txt', '--agnostic-nms', '--exist-ok',
]

# check=True: fail loudly instead of continuing with missing/stale detections.
completed = subprocess.run(cmd, check=True)
p_status = completed.returncode


In [None]:
import os
from pathlib import Path

# Map every input image path -> list of YOLO boxes found in it. Each box is the
# label-file row without its first field (the YOLOv5 class id), i.e. the
# remaining normalised coordinates; images without a label file get [].
img2label = {}

# Entries without an extension (e.g. dot-less directories) are skipped.
# Sorted so the "Image i" indices printed later are reproducible across runs.
inf_images = sorted(
    name for name in os.listdir(params['detection_input'])
    if Path(name).suffix != ''
)
print(inf_images)

for file in inf_images:
    # detect.py --save-txt writes labels/<image stem>.txt next to its results.
    label_path = (yolo_res_path / 'labels' / file).with_suffix('.txt')
    labels = []
    if label_path.is_file():
        with open(label_path) as f:
            for line in f:
                fields = line.split()
                if fields:  # tolerate blank lines instead of emitting an empty row
                    labels.append([float(v) for v in fields[1:]])
    img2label[Path(params['detection_input']) / file] = labels


In [None]:
def xywh2xyxy(x):
    """Convert boxes from centre format ``[x, y, w, h]`` to corners ``[x1, y1, x2, y2]``.

    ``(x1, y1)`` is the top-left corner and ``(x2, y2)`` the bottom-right one.

    Args:
        x: ``torch.Tensor`` or numpy array whose last dimension has size 4, e.g.
           ``(n, 4)`` for n boxes or ``(4,)`` for a single box.

    Returns:
        A new tensor/array of the same type and shape; the input is not modified.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    # Ellipsis indexing works for any number of leading dimensions.
    half_w = x[..., 2] / 2
    half_h = x[..., 3] / 2
    y[..., 0] = x[..., 0] - half_w  # top left x
    y[..., 1] = x[..., 1] - half_h  # top left y
    y[..., 2] = x[..., 0] + half_w  # bottom right x
    y[..., 3] = x[..., 1] + half_h  # bottom right y
    return y

In [None]:
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import torch
from pathlib import Path

bbox_color = 'green'
font_color = (255, 255, 255)
font_size = 18
box_width = 3

# Output folder for both annotated variants. Defined once, up front: previously it
# was only set inside the "has detections" branch, so an image without detections
# coming first raised NameError.
prefix = Path(params['detection_input'] + '-res')


def _load_font(size):
    """Return Arial at `size`, falling back to PIL's default font if Arial is unavailable."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:  # raised when the font file cannot be found (common on Linux)
        return ImageFont.load_default()


font = _load_font(font_size)

# Preprocessing for each crop (test transforms, then common ones); built once, not per box.
crop_trsf = transforms.Compose([*iLogoDet3K_trsf['test'], *iLogoDet3K_trsf['common']])


def _classify_crop(crop):
    """Classify one RGB crop with the CIL model and return the resolved class name."""
    batch = crop_trsf(crop).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():  # inference only: do not build an autograd graph
        logits = cil_model(batch)
    return cil_idx2class[cil_class_remap[logits.argmax().int().item()]]


for i, (img_path, label) in enumerate(img2label.items()):
    bbox_path = prefix / (img_path.stem + '_bbox.png')
    clf_path = prefix / (img_path.stem + '_bbox-clf.png')

    # Context manager closes the underlying file once we are done with the image.
    with Image.open(img_path) as im:
        if not label:
            # No detections: store the untouched image under both names.
            im.save(bbox_path)
            im.save(clf_path)
            continue

        width, height = im.size
        # Normalised xywh -> pixel xyxy (the size vector broadcasts over all boxes).
        scale = torch.tensor([width, height, width, height])
        native_pred = (xywh2xyxy(torch.tensor(label)) * scale).round()

        im1 = im.copy()  # boxes only
        im2 = im.copy()  # boxes + predicted class caption
        draw_bbox = ImageDraw.Draw(im1)
        draw_clf = ImageDraw.Draw(im2)

        for j, pred in enumerate(native_pred):
            box = pred.tolist()  # [x1, y1, x2, y2] in pixels
            resolved_label = _classify_crop(im.crop(box).convert('RGB'))
            print(f'Image {i} - bbox {j}: {resolved_label}')

            draw_bbox.rectangle(box, fill=None, outline=bbox_color, width=box_width)
            draw_clf.rectangle(box, fill=None, outline=bbox_color, width=box_width)
            # Filled caption band of height font_size sitting on the box's top edge.
            band = [box[0], box[1] - font_size, box[2], box[1]]
            draw_clf.rectangle(band, fill=bbox_color, outline=bbox_color, width=box_width)
            draw_clf.text((band[0] + box_width + 1, band[1]), resolved_label, font_color, font=font)

        im1.save(bbox_path)
        im2.save(clf_path)
