In [1]:
import os
import json

from glob import glob
from tqdm import tqdm

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2

from utils.drawing import plot_one_box, draw_text
from utils.bboxes import box_iou

In [2]:
# --- Thresholds -------------------------------------------------------------
# Score cut-off: only results scoring strictly above this value are kept.
CONF_TH = 0.3
# Default `th` of filter_dets_in_window (compared against box_iou output).
FILT_TH = 0.7

# --- Paths ------------------------------------------------------------------
# JSON file holding the detector results to visualise.
res_path = "/home/kos/projects/PhD/ZebraFish/tinyROI/metrics/val-fixed-seq-empty-added-test-rotate-IBS/roi-fix/results-ZebraFish-03-imgF.json"
# Destination folder for the side-by-side images; both thresholds are part of
# its name, so each threshold combination gets a separate folder.
out_dir = f'/home/kos/projects/PhD/ZebraFish/tinyROI/metrics/val-fixed-seq-empty-added-test-rotate-IBS/roi-fix/before-after-{FILT_TH}-{CONF_TH}'

In [3]:
def draw_dets_and_windows(img, dets, windows, vis_conf_th=0.1):
    """Overlay window rectangles, detection boxes and a count summary on an image.

    Args:
        img: Image handed to ``plot_one_box`` / ``draw_text``. The value those
            helpers return is what gets passed on; whether they also modify
            the input in place is not visible here, so pass a copy if the
            original must stay clean.
        dets: 2-D tensor with one detection per row. The first four columns
            are read as ``xmin, ymin, xmax, ymax``; a trailing column, if
            present, is read as the confidence score.
        windows: 2-D tensor (or other sized iterable of boxes) with one window
            per row.
        vis_conf_th: Minimum confidence a detection needs to be drawn. It is
            only applied when ``dets`` has more than four columns.

    Returns:
        The image returned by the final ``draw_text`` call.
    """
    for i, window in enumerate(windows):
        img = plot_one_box(list(map(int, window)), img, color=(0,0,180),
                           label=f'WINDOW {i}', line_thickness=4, draw_label=False)

    # Only threshold on confidence when a score column exists. With plain
    # 4-column boxes the last column is ymax, and comparing it against a
    # confidence threshold would drop boxes based on their position.
    if dets.shape[1] > 4:
        dets = dets[dets[:, -1] >= vis_conf_th]

    for i, det in enumerate(dets.tolist()):
        xmin, ymin, xmax, ymax = det[:4]
        img = plot_one_box(list(map(int, [xmin, ymin, xmax, ymax])), img, color=(180,20,20),
                           label=f'DET {i}', line_thickness=1, draw_label=False)

    # Counts reflect what was actually drawn (detections after the filter).
    stats = ["WINDOWS   %02d" % len(windows), "DETS   %02d" % len(dets)]
    img = draw_text(img, "\n".join(stats), 20, 40, color=(255,255,255))
    return img

def filter_dets_in_window(window, dets, other_dets, th=FILT_TH):
    """Drop detections of a window that duplicate detections found elsewhere.

    Every box in ``other_dets`` is clipped to ``window``; a row of ``dets`` is
    discarded when its ``box_iou`` against any clipped box is greater than
    ``th``.

    Args:
        window: ``(W, 4)`` tensor of ``xmin, ymin, xmax, ymax`` boxes. The
            caller in this notebook passes a single window, i.e. ``(1, 4)``.
        dets: ``(N, 4)`` tensor with the detections to be filtered.
        other_dets: ``(M, 4)`` tensor with the detections to compare against.
        th: Overlap threshold. NOTE: the default is bound to ``FILT_TH`` when
            this ``def`` is executed, so after changing ``FILT_TH`` this cell
            has to be re-run for the new value to take effect.

    Returns:
        The rows of ``dets`` that survive, in their original order.
    """
    # Clip each other-detection to the window: broadcasting (M, 1, 2) against
    # (W, 2) gives one clipped box per (other detection, window) pair.
    clipped_mins = torch.max(other_dets[:, None, :2], window[:, :2])  # xmin, ymin
    clipped_maxs = torch.min(other_dets[:, None, 2:], window[:, 2:])  # xmax, ymax
    clipped = torch.flatten(torch.cat((clipped_mins, clipped_maxs), dim=2),
                            start_dim=0, end_dim=1)

    # NOTE(review): an other-detection lying completely outside the window
    # yields an inverted box (min > max). This relies on box_iou reporting no
    # overlap for such boxes -- confirm against utils.bboxes.box_iou.
    ious = box_iou(clipped, dets)  # columns index the rows of `dets`

    # A detection is a duplicate as soon as one clipped box overlaps it enough.
    # The boolean mask replaces a per-row Python membership test and keeps the
    # same rows in the same order.
    is_duplicate = (ious > th).any(dim=0)
    return dets[~is_duplicate]

In [4]:
os.makedirs(out_dir, exist_ok=True)

# Read the results inside a context manager so the file handle is closed.
with open(res_path) as res_file:
    results = json.load(res_file)

# Group the result entries per frame in a single pass instead of rescanning
# the whole result list once for every frame.
results_by_frame = {}
for entry in results:
    results_by_frame.setdefault(entry['image_path'], []).append(entry)

unique_frames = sorted(results_by_frame)

for unique_frame in tqdm(unique_frames):
    res = results_by_frame[unique_frame]
    im = cv2.imread(unique_frame)
    if im is None:
        # cv2.imread signals failure with None rather than an exception.
        raise FileNotFoundError(f"cv2.imread could not load image: {unique_frame}")

    # Keep detections scoring above CONF_TH, drop the score column and
    # convert the boxes from (x, y, w, h) to (xmin, ymin, xmax, ymax).
    bboxes = torch.tensor([x['bbox']+[x['score']] for x in res])
    indices = torch.where(bboxes[:,-1] > CONF_TH)[0]
    bboxes = bboxes[indices,:]
    bboxes = bboxes[:,:-1].to(torch.int32)
    bboxes[:,2] = bboxes[:, 0]+bboxes[:,2]
    bboxes[:,3] = bboxes[:, 1]+bboxes[:,3]

    # NOTE: from here on the boxes carry no score column; they were already
    # thresholded on CONF_TH above, which is why the same value is passed on
    # as vis_conf_th to the drawing helper.

    if 'window_bbox' not in res[0].keys():
        # No window information for this frame: nothing to filter, so the
        # "before" and "after" halves of the output are the same picture.
        windows = torch.empty((0,4)).to(torch.int32)
        im_before = draw_dets_and_windows(im, bboxes, windows, vis_conf_th=CONF_TH)
        im = np.hstack((im_before, im_before))
        cv2.imwrite(f"{out_dir}/{os.path.basename(unique_frame)}", im)
        continue

    # One window per detection, restricted to the detections kept above.
    windows = torch.tensor([x['window_bbox'] for x in res]).to(torch.int32)
    windows = windows[indices,:]
    unique_windows = torch.unique(windows, dim=0)

    # TODO (from the original author): remove already-filtered detections
    # after the first iteration and sort windows by size beforehand?
    kept_per_window = []
    for unique_window in unique_windows:
        # A detection belongs to the window only if ALL four coordinates of
        # its window match. Matching on any single coordinate would also pull
        # in detections of windows that merely share an edge coordinate.
        in_window = (windows == unique_window).all(dim=1)
        window_dets = bboxes[in_window]
        other_dets = bboxes[~in_window]

        kept_per_window.append(
            filter_dets_in_window(unique_window.unsqueeze(0), window_dets, other_dets)
        )

    # Same dtype as the boxes; falls back to an empty (0, 4) tensor when no
    # detection passed the score threshold.
    dets_after = torch.cat(kept_per_window) if kept_per_window else bboxes.new_empty((0, 4))

    # Draw each distinct window once so the WINDOWS counter shows the number
    # of windows rather than the number of detections.
    im_before = draw_dets_and_windows(im.copy(), bboxes, unique_windows, vis_conf_th=CONF_TH)
    im_after = draw_dets_and_windows(im, dets_after, unique_windows, vis_conf_th=CONF_TH)

    im = np.hstack((im_before, im_after))
    cv2.imwrite(f"{out_dir}/{os.path.basename(unique_frame)}", im)

100%|█████████████████████████████████████████████| 1800/1800 [04:26<00:00,  6.75it/s]
