In [1]:
from IPython.display import display
from ipywidgets import widgets, interact
from pathlib import Path
from xml.etree import ElementTree as ET
from PIL import Image, ImageDraw
import cv2

class ImageInteractor:
    """Collect ``.jpg`` images below a dataset folder and display them with boxes.

    An image may have a sibling annotation file with the same stem and an
    ``.xml`` suffix. The code reads ``object`` elements directly under the XML
    root, each with a ``property/category`` label and a ``bndbox`` holding
    ``xmin``/``ymin``/``xmax``/``ymax`` children that are parsed with ``int``.
    """

    def __init__(self, dataset_path: str, target: str, category: str = None):
        """Store the search settings and build the sorted image list.

        Args:
            dataset_path: Root directory of the dataset.
            target: Sub-directory of ``dataset_path`` that is searched recursively.
            category: If truthy, keep only images whose annotation contains at
                least one object with this category; otherwise keep every
                ``.jpg`` found.
        """
        self.dataset_path = Path(dataset_path)
        self.target = target
        self.category = category
        # Sorted list of image paths; the notebook slider indexes into it.
        self.imgs = self.get_images()

    @staticmethod
    def _category_of(obj):
        """Return the text of ``property/category`` for ``obj``, or None if absent."""
        node = obj.find("property/category")
        return None if node is None else node.text

    def get_images(self):
        """Return the sorted image paths selected by ``self.category``."""
        target_path = self.dataset_path / self.target
        if self.category:
            imgs = self.get_images_by_category(target_path)
        else:
            imgs = [path for path in target_path.glob("**/*") if path.suffix == ".jpg"]
        return sorted(imgs)

    def get_images_by_category(self, target_path):
        """Return images whose annotation has an object of ``self.category``.

        Each image is returned at most once, even when several of its objects
        match. Annotations without a sibling ``.jpg`` are skipped before
        parsing because they can never contribute a result.

        Args:
            target_path: Directory searched recursively for ``.xml`` files.

        Returns:
            Unsorted list of ``.jpg`` paths (one per matching annotation file).
        """
        imgs = []
        for xml_file in target_path.glob("**/*"):
            if xml_file.suffix != ".xml":
                continue
            img = xml_file.with_suffix(".jpg")
            if not img.exists():
                continue
            root = ET.parse(xml_file).getroot()
            # any() stops at the first match, so an image is never added twice.
            if any(self._category_of(obj) == self.category for obj in root.findall("object")):
                imgs.append(img)
        return imgs

    @staticmethod
    def draw_bbox(img, bbox, color=(255, 0, 0), thickness=3):
        """Draw a rectangle outline on ``img`` and return ``img``.

        Args:
            img: Image passed to ``ImageDraw.Draw``; it is drawn on directly.
            bbox: Rectangle coordinates forwarded to ``ImageDraw.rectangle``;
                callers in this class pass ``(xmin, ymin, xmax, ymax)``.
            color: Outline colour.
            thickness: Outline width passed as ``width``.
        """
        draw = ImageDraw.Draw(img)
        draw.rectangle(bbox, outline=color, width=thickness)
        return img

    def print_bounding_boxes(self, xml, img):
        """Print every object's category and box from ``xml`` and draw the boxes.

        Args:
            xml: Path of the annotation file to parse.
            img: Image handed to ``draw_bbox`` for each object.

        Returns:
            The value returned by the last ``draw_bbox`` call, or ``img``
            unchanged when the annotation has no objects.
        """
        root = ET.parse(xml).getroot()
        for obj in root.findall("object"):
            print("category:", self._category_of(obj))
            # Look up each coordinate once; the raw text is what gets printed.
            raw = [obj.find(f"bndbox/{tag}").text for tag in ("xmin", "ymin", "xmax", "ymax")]
            print("bbox:", *raw)
            xmin, ymin, xmax, ymax = (int(value) for value in raw)
            print("width:", xmax - xmin, "height:", ymax - ymin)
            print("")
            img = self.draw_bbox(img, (xmin, ymin, xmax, ymax))
        return img

    def display_image(self, idx: int):
        """Show image ``idx``; if it has an annotation, also show it with boxes.

        The plain image is always displayed first. The annotated copy is only
        displayed when the sibling ``.xml`` exists, so an image without
        annotations is not rendered twice.

        Args:
            idx: Index into ``self.imgs``.
        """
        img_path = self.imgs[idx]
        # The context manager releases the file and pixel data once rendering is
        # done, which matters when sliding through many images.
        with Image.open(img_path) as img:
            display(img)

            xml = img_path.with_suffix(".xml")
            if xml.exists():
                display(self.print_bounding_boxes(xml, img))

In [2]:
### User-editable settings ###
dataset_path = "/mnt/disks/data1/aihub/Training"
target = "Modified_Data"
# Category to inspect. Values listed by the notebook author:
# 선박 (ship), 부표 (buoy), 어망부표 (fishing-net buoy),
# 기타부유물 (other floating debris), 등대 (lighthouse).
category = "선박"
### End of user-editable settings ###

image_interactor = ImageInteractor(dataset_path, target, category)

# With no images the slider would get max=-1 below min=0; fail early with a
# message that names the folder and category instead of a widget error.
if not image_interactor.imgs:
    raise ValueError(
        f"No images found under {image_interactor.dataset_path / target} "
        f"for category {category!r}"
    )

# display images using slider
@interact(idx=widgets.IntSlider(min=0, max=len(image_interactor.imgs) - 1, step=1, value=0))
def display_image(idx):
    """Slider callback: show the image at position ``idx`` of the collected list."""
    image_interactor.display_image(idx)

interactive(children=(IntSlider(value=0, description='idx', max=21782), Output()), _dom_classes=('widget-inter…

interactive(children=(IntSlider(value=0, description='idx', max=1486), Output()), _dom_classes=('widget-intera…