# Setup

In [1]:
import sys
from pathlib import Path

# Make the parent of the working directory importable (the `my_project`
# imports below rely on it). The membership check keeps the cell idempotent:
# re-running it does not append the same entry to sys.path again.
project_root = str(Path.cwd().parent)
if project_root not in sys.path:
    sys.path.append(project_root)


In [36]:
import random
from typing import Any, List, Iterable, Tuple, NamedTuple, TypedDict

from uuid import uuid4
import cv2
import numpy as np
from matplotlib import pyplot as plt
import torch
import polars as pl
import torchvision
from xml.etree import ElementTree

from my_project.di import get_injector
from my_project.directories.interface import InputDir, OutputDir


In [3]:
# Obtain the injector from my_project.di; the following cell uses it to look
# up the input and output directories.
injector = get_injector()


In [14]:
# Look up the directory objects registered under the InputDir / OutputDir keys.
# Later cells combine them with the `/` operator to build file paths.
input_dir = injector.get(InputDir)
output_dir = injector.get(OutputDir)


# Load the pets dataset

In [6]:
def load_filenames(list_file: Path) -> Iterable[str]:
    """Yield the image name found at the start of every data line of a list file.

    Blank lines and lines starting with ``#`` are skipped. On every other
    line the first whitespace-delimited token is taken as the image name and
    the remaining fields are ignored.

    Args:
        list_file: Path of the text file to read.

    Yields:
        The image name of each data line, in file order.
    """
    with list_file.open("r") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            # Splitting on any whitespace (instead of a single space) also
            # accepts tab-separated lines and lines that contain only the
            # image name without further fields.
            yield line.split(maxsplit=1)[0]


In [7]:
class BBox(NamedTuple):
    """Axis-aligned bounding box described by its centre point and its size.

    The tuple does not fix a unit: in this notebook it holds pixel values
    when produced by ``load_bbox`` and fractions of the image size once the
    values have been divided by the image width / height.
    """

    # Horizontal position of the box centre.
    center_x: float
    # Vertical position of the box centre.
    center_y: float
    # Horizontal extent of the box.
    width: float
    # Vertical extent of the box.
    height: float


In [8]:
def load_bbox(file_path: Path) -> BBox:
    """Read the bounding box of the first annotated object from an XML file.

    The file is expected to contain an ``object/bndbox`` element with integer
    ``xmin``, ``xmax``, ``ymin`` and ``ymax`` children. If several objects
    are annotated, only the first ``object/bndbox`` is used.

    Args:
        file_path: Path of the XML annotation file.

    Returns:
        The box converted from corner coordinates to centre point plus
        width and height, in the same units as the coordinates in the file.

    Raises:
        AssertionError: If ``file_path`` is not an existing file.
        ValueError: If the bounding box or one of its coordinates is missing,
            empty, or not an integer.
    """
    assert file_path.is_file(), file_path
    tree = ElementTree.parse(file_path)
    bndbox = tree.find("./object/bndbox")
    if bndbox is None:
        # find() returns None when nothing matches; fail with a message that
        # names the file instead of an AttributeError on None further down.
        raise ValueError(f"No object/bndbox element found in {file_path}")

    def read_coordinate(tag: str) -> int:
        """Return the integer value of the ``tag`` child of the bounding box."""
        text = bndbox.findtext(tag)
        if text is None or not text.strip():
            raise ValueError(f"Missing <{tag}> value in bndbox of {file_path}")
        return int(text)

    xmin = read_coordinate("xmin")
    xmax = read_coordinate("xmax")
    ymin = read_coordinate("ymin")
    ymax = read_coordinate("ymax")

    return BBox(
        center_x=(xmin + xmax) / 2,
        center_y=(ymin + ymax) / 2,
        width=xmax - xmin,
        height=ymax - ymin,
    )


In [31]:
# Locations of the source data, all resolved relative to `input_dir`.
# Pet images, read later as "<name>.jpg".
pet_image_dir = input_dir / "pet" / "images"
# Per-image XML annotation files, looked up later as "<name>.xml".
pet_xml_dir = input_dir / "pet" / "annotations" / "xmls"
# Text file listing the image names (parsed by load_filenames).
pet_list_file = input_dir / "pet" / "annotations" / "list.txt"
# Landscape images, globbed later with "*.jpg".
landscape_image_dir = input_dir / "landscapes"


In [41]:
# Build the combined dataset: every source image is resized to a square of
# `output_size` pixels and saved under a random UUID name in
# `dataset_output_dir`; `records` collects one row per saved image.
# NOTE(review): because the file names are random, re-running this cell adds a
# fresh copy of every image to the directory instead of overwriting the old ones.
output_size = 224
records = []
dataset_output_dir = output_dir / "pets_and_landscapes"
dataset_output_dir.mkdir(exist_ok=True)


def save_resized_image(image: torch.Tensor) -> str:
    """Resize ``image`` to ``output_size`` x ``output_size`` and save it as a JPEG.

    Shared by the pet and the landscape loop so both write files the same way.

    Args:
        image: Image tensor as returned by ``torchvision.io.read_image``.

    Returns:
        The generated file name (``<uuid4>.jpg``) inside ``dataset_output_dir``.
    """
    resized_image = torchvision.transforms.functional.resize(
        image, (output_size, output_size)
    )
    output_filename = f"{uuid4()}.jpg"
    torchvision.io.write_jpeg(
        resized_image, str(dataset_output_dir / output_filename), quality=90
    )
    return output_filename


# Pets: class 1, bounding box taken from the XML annotation.
for filename in load_filenames(pet_list_file):
    annotation_path = pet_xml_dir / f"{filename}.xml"
    if not annotation_path.is_file():
        # Listed images without an annotation file are left out of the dataset.
        continue
    # Parse the annotation once per image.
    bbox = load_bbox(annotation_path)
    image = torchvision.io.read_image(
        str(pet_image_dir / f"{filename}.jpg"),
        mode=torchvision.io.ImageReadMode.RGB,
    )

    channels, img_height, img_width = image.shape
    assert channels == 3

    # Store the box as fractions of the original image size so the values
    # do not depend on the resize applied below.
    bbox_factor = BBox(
        center_x=bbox.center_x / img_width,
        center_y=bbox.center_y / img_height,
        width=bbox.width / img_width,
        height=bbox.height / img_height,
    )
    records.append(
        {
            "filename": save_resized_image(image),
            "class": 1,
            "bbox_center_x": bbox_factor.center_x,
            "bbox_center_y": bbox_factor.center_y,
            "bbox_width": bbox_factor.width,
            "bbox_height": bbox_factor.height,
        }
    )

# Landscapes: class 0, with a box that covers the whole image.
for landscape_path in landscape_image_dir.glob("*.jpg"):
    try:
        image = torchvision.io.read_image(
            str(landscape_path), mode=torchvision.io.ImageReadMode.RGB
        )
    except RuntimeError:
        # Unreadable files are reported and skipped rather than aborting the run.
        print("Failed to read image", landscape_path)
        continue
    assert image.shape[0] == 3
    records.append(
        {
            "filename": save_resized_image(image),
            "class": 0,
            "bbox_center_x": 0.5,
            "bbox_center_y": 0.5,
            "bbox_width": 1.0,
            "bbox_height": 1.0,
        }
    )


Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9


Failed to read image /inputs/landscapes/00000051_(3).jpg


In [42]:
# Write the collected records (one row per saved image) as a CSV file into the
# dataset directory, next to the images.
# NOTE(review): the file name reads "ladscapes", not "landscapes" — looks like
# a typo; confirm nothing reads this exact name before renaming it.
pl.DataFrame(records).write_csv(dataset_output_dir / "pets_and_ladscapes.csv")
