In [32]:
import os
import sys
import numpy as np
from xml.etree import ElementTree
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
from PIL import Image, ImageStat
from random import shuffle
import json
from torchvision.transforms import functional
from torchvision.utils import draw_bounding_boxes
import torch
from skimage.transform import resize
import cv2

sys.path.append('../detection')
from fathomnethelper.json_loader import Taxonomicon 

In [33]:
# Output folder for the annotated sample images (one sub-folder per class).
save_dest = '../data/sample_imgs2'
# How many sample images to keep for each class.
N_per_class = 40


In [34]:
# Class definition: either a list of taxon names, or a dict mapping each output
# class to a list of taxa whose subtrees are merged into that class.
with open("../classes", "r") as f:
    classes = json.load(f)

root = '../data'
imgs = sorted(os.listdir(os.path.join(root, 'images')))
anns = sorted(os.listdir(os.path.join(root, 'annotations')))

# Pair each image with the annotation that shares its base name. Zipping the two
# sorted listings would silently misalign every later pair as soon as a single
# file is missing from either directory.
# NOTE(review): assumes annotations are named <image stem>.<ext> (e.g. VOC .xml) — confirm.
ann_by_stem = {os.path.splitext(ann)[0]: ann for ann in anns}
imgs_and_anns = [
    (img, ann_by_stem[os.path.splitext(img)[0]])
    for img in imgs
    if os.path.splitext(img)[0] in ann_by_stem
]
if imgs and anns and not imgs_and_anns:
    raise ValueError('No image/annotation pairs share a file stem; check naming convention.')

label_mapping = {cls: i + 1 for i, cls in enumerate(sorted(classes))}
tax = Taxonomicon()

shuffle(imgs_and_anns)

# Map every taxon in each class's subtree(s) onto that class name.
if not isinstance(classes, (list, dict)):
    raise TypeError('Class definition needs to be of type list or dict.')
class_mapping = {}
for cls in classes:
    if isinstance(classes, list):
        nodes = set(tax.get_subtree_nodes(cls))
    else:
        # set().union(...) also copes with an empty taxon list, whereas
        # set.union() called with no arguments raises TypeError.
        nodes = set().union(*(tax.get_subtree_nodes(cls2) for cls2 in classes[cls]))
    for node in nodes:
        class_mapping[node] = cls

# Per output class: image filename -> list of (class, taxon name, (xmin, ymin, xmax, ymax)).
imgs_and_boxes = {cls: {} for cls in classes}

for img, ann in imgs_and_anns:
    # Stop parsing once every class has at least N_per_class candidate images.
    if all(len(imgs_and_boxes[cls]) >= N_per_class for cls in classes):
        break
    tree_root = ElementTree.parse(os.path.join(root, 'annotations', ann)).getroot()
    for obj in tree_root.iter('object'):
        name = obj.find('name').text
        if name not in class_mapping:
            continue
        cls = class_mapping[name]
        # Coordinates may be written as floats ("12.0"); int() alone would reject them.
        xmin, ymin, xmax, ymax = (
            int(float(obj.find(f'bndbox/{tag}').text))
            for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        )
        imgs_and_boxes[cls].setdefault(img, []).append((cls, name, (xmin, ymin, xmax, ymax)))
    

In [52]:
# Minimum length (px) of the shorter image side; smaller images are upscaled
# before the boxes are drawn.
min_size = 256

# BGR colours for cv2: boxes of the folder's own class are white, boxes of other
# classes that appear in the same image are grey.
box_colours = {"white": (255, 255, 255), "grey": (128, 128, 128)}

# makedirs(exist_ok=True) keeps the cell re-runnable and creates missing parents.
os.makedirs(save_dest, exist_ok=True)
for cls in classes:
    cls_dir = os.path.join(save_dest, cls)
    os.makedirs(cls_dir, exist_ok=True)
    for img, img_boxes in list(imgs_and_boxes[cls].items())[:N_per_class]:
        img_path = os.path.join(root, 'images', img)
        im = cv2.imread(img_path)
        if im is None:
            # cv2.imread returns None instead of raising when a file can't be read.
            print(f"Skipping unreadable image: {img_path}")
            continue
        h, w = im.shape[:2]
        scale = 1
        if h < min_size or w < min_size:
            scale = min_size / min(h, w)
            new_h, new_w = int(h * scale), int(w * scale)
            # cv2.resize takes dsize as (width, height). INTER_AREA is intended for
            # shrinking; INTER_LINEAR is the recommended fast choice for enlarging.
            im = cv2.resize(im, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        parent_classes, _, coords = zip(*img_boxes)
        colours = ["white" if pc == cls else "grey" for pc in parent_classes]
        for (xmin, ymin, xmax, ymax), colour in zip(coords, colours):
            # Box coordinates refer to the original image, so rescale them too.
            im = cv2.rectangle(
                im,
                (int(scale * xmin), int(scale * ymin)),
                (int(scale * xmax), int(scale * ymax)),
                color=box_colours[colour],
                thickness=2,
            )
        out_path = os.path.join(cls_dir, img)
        if not cv2.imwrite(out_path, im):
            print(f"Failed to write image: {out_path}")

In [44]:
# NOTE(review): this repeats the shape tuple (it does not double each dimension),
# and it relies on `im` left over from the previous cell's last iteration.
(*im.shape, *im.shape)

(486, 720, 3, 486, 720, 3)