In [None]:
# Enable IPython's autoreload extension in mode 2, which re-imports all loaded
# modules before each cell executes, so edits to imported source files are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
from clean_dfine import get_data_dir
from clean_dfine.dataset import BatchImageCollateFunction, DataLoader, HFImageDataset
from torchvision.transforms import ToPILImage

from PIL import Image, ImageDraw, ImageFont
import torch
from typing import List, Tuple, Union

def draw_bounding_boxes_on_image(
    image: Image.Image,
    boxes: Union[torch.Tensor, List[Tuple[float, float, float, float]]],
    colors: Union[str, List[str], List[Tuple[int, int, int]]] = 'red',
    width: int = 2,
    labels: Union[List[str], None] = None,
    font: Union[ImageFont.ImageFont, None] = None
) -> Image.Image:
    """
    Draw bounding boxes, and optionally text labels, on a copy of a PIL image.

    The input image is never modified; all drawing happens on a copy.

    Args:
        image (PIL.Image.Image): The input image.
        boxes (torch.Tensor or List of Tuples): Bounding boxes in
            [x_min, y_min, x_max, y_max] pixel format. If torch.Tensor, it is
            converted with ``tolist()`` and should have shape (N, 4).
        colors (str or list of str or list of tuples, optional): Outline color(s).
            Either a single color (a string or one RGB/RGBA tuple of ints) used
            for every box, a one-element sequence used for every box, or a
            sequence with exactly one color per box. Default is 'red'.
        width (int, optional): Line width for the bounding boxes. Default is 2.
        labels (List[str], optional): One label per box. Empty/None entries are
            skipped. Default is None (no labels).
        font (PIL.ImageFont.ImageFont, optional): Font for the labels. If None,
            PIL's default font is used.

    Returns:
        PIL.Image.Image: A copy of ``image`` with the boxes (and labels) drawn.

    Raises:
        TypeError: If ``boxes`` is neither a torch.Tensor nor a list.
        ValueError: If the number of colors or labels does not match the number
            of boxes, or if a box does not have exactly 4 elements.
    """
    # If boxes are torch.Tensor, convert to list
    if isinstance(boxes, torch.Tensor):
        boxes = boxes.tolist()
    elif not isinstance(boxes, list):
        raise TypeError("boxes should be a torch.Tensor or a list of tuples/lists.")

    num_boxes = len(boxes)

    # Normalize colors to exactly one entry per box.
    is_single_rgb = (
        isinstance(colors, tuple)
        and len(colors) in (3, 4)
        and all(isinstance(channel, int) for channel in colors)
    )
    if isinstance(colors, str) or is_single_rgb:
        # A bare color name or a bare RGB(A) tuple applies to every box.
        box_colors = [colors] * num_boxes
    elif isinstance(colors, (list, tuple)):
        # Copy into a list so tuples of colors are indexed per box as well.
        box_colors = list(colors)
        if len(box_colors) == 1:
            box_colors = box_colors * num_boxes
        elif len(box_colors) != num_boxes:
            raise ValueError("Number of colors must match number of boxes or be 1.")
    else:
        # Any other single value is handed to PIL unchanged for every box.
        box_colors = [colors] * num_boxes

    # Handle labels
    if labels is None:
        labels = [None] * num_boxes
    elif len(labels) != num_boxes:
        raise ValueError("Number of labels must match number of boxes.")

    # Draw on a copy so the caller's image stays untouched.
    image_with_boxes = image.copy()
    draw = ImageDraw.Draw(image_with_boxes)

    for idx, (box, color, label) in enumerate(zip(boxes, box_colors, labels)):
        if len(box) != 4:
            raise ValueError(f"Box at index {idx} does not have 4 elements.")

        x_min, y_min, x_max, y_max = box
        draw.rectangle([(x_min, y_min), (x_max, y_max)], outline=color, width=width)

        if not label:
            continue

        if font is None:
            # Resolved lazily, and only once, when the first label is drawn.
            font = ImageFont.load_default()

        # textbbox returns (left, top, right, bottom) of the rendered text when
        # anchored at the given point, not (width, height).
        left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
        text_width = right - left
        text_height = bottom - top

        # Place the label just above the box, but keep it inside the image when
        # the box touches the top edge.
        label_top = max(y_min - text_height, 0)
        draw.rectangle(
            [x_min, label_top, x_min + text_width, label_top + text_height],
            fill=color,
        )
        # Offset by the bbox origin so the glyphs land exactly on the background.
        draw.text((x_min - left, label_top - top), label, fill='black', font=font)

    return image_with_boxes


# Load the "dataset-test" dataset from the project's data directory.
# NOTE(review): 640 is presumably the target image size (boxes are later
# rescaled by the same number) — confirm against HFImageDataset.from_path.
dataset_train = HFImageDataset.from_path(
    get_data_dir() / "dataset-test",
    640,
)

# One shuffled sample per batch, loaded by two persistent worker processes and
# assembled with the project's collate function.
dataloader_train = DataLoader(
    dataset_train,
    collate_fn=BatchImageCollateFunction(),
    batch_size=1,
    shuffle=True,
    num_workers=2,
    persistent_workers=True,
)

In [None]:
# Pull a single batch from a fresh iterator over the training loader.
batch_iterator = iter(dataloader_train)
img, target = next(batch_iterator)

# Drop the leading dimension (size 1, since batch_size=1) to get one image.
img = img.squeeze(dim=0)

In [None]:
# Convert the image fetched above into a PIL image so it can be drawn on.
# Note that this rebinds `img`: re-running this cell without re-running the
# batch-fetching cell feeds the already-converted PIL image back into `t`.
t = ToPILImage()
img = t(img)

In [None]:
# Scale the target boxes up to pixel coordinates.
# NOTE(review): assumes boxes are normalized (center_x, center_y, width, height),
# the usual DETR/D-FINE target layout, and that 640 is the image size used when
# the dataset was built — confirm against HFImageDataset.
boxes = (target[0]["boxes"] * 640).tolist()

# Convert center/size boxes to the [x_min, y_min, x_max, y_max] corners that
# draw_bounding_boxes_on_image expects: each corner lies half the extent away
# from the center. True division keeps sub-pixel precision.
boxes = [
    [cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2]
    for cx, cy, w, h in boxes
]
boxes

In [None]:
# Render the boxes with their class labels as captions.
# Elements obtained by iterating a tensor are 0-dim tensors whose str() looks
# like "tensor(3)"; unwrap them with .item() when available so the caption is
# just the value. Plain ints/strings are passed through unchanged.
draw_bounding_boxes_on_image(
    img,
    boxes,
    labels=[
        str(label.item() if hasattr(label, "item") else label)
        for label in target[0]["labels"]
    ],
)

In [None]:
# Peek at the first batch only: print the image batch's shape and the targets.
# Nothing is printed (and nothing is rebound) if the loader yields no batches.
first_batch = next(iter(dataloader_train), None)
if first_batch is not None:
    img, target = first_batch
    print(img.shape)
    print(target)