In [1]:
from typing import Union
import glob
import math
import random
import os
import tempfile
import logging
import logging.config

import contextlib
import hashlib
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Optional

import cv2
import numpy as np
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm

from PIL import ExifTags, Image, ImageOps
import torch
import torchvision.transforms as T


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
LOGGING_NAME = "ultralytics"  # shared name of the logger, its handler and its formatter
# Global verbose mode: enabled unless YOLO_VERBOSE is set to something other than "true".
VERBOSE = str(os.getenv("YOLO_VERBOSE", True)).lower() == "true"


def set_logging(name=LOGGING_NAME, verbose=True):
    """Configure a message-only stream logger registered under ``name``.

    The level is INFO when ``verbose`` is truthy and the ``RANK`` environment
    variable is unset, -1 or 0; otherwise it is ERROR.

    Args:
        name (str): key used for the logger, handler and formatter entries.
        verbose (bool): request INFO-level output.
    """
    world_rank = int(os.getenv("RANK", -1))  # rank in world for Multi-GPU trainings
    on_main_process = world_rank in {-1, 0}
    level = logging.INFO if (verbose and on_main_process) else logging.ERROR

    formatter_cfg = {"format": "%(message)s"}
    handler_cfg = {"class": "logging.StreamHandler", "formatter": name, "level": level}
    logger_cfg = {"level": level, "handlers": [name], "propagate": False}

    logging.config.dictConfig(
        dict(
            version=1,
            disable_existing_loggers=False,
            formatters={name: formatter_cfg},
            handlers={name: handler_cfg},
            loggers={name: logger_cfg},
        )
    )


# Set logger
set_logging(LOGGING_NAME, verbose=VERBOSE)  # run before defining LOGGER
LOGGER = logging.getLogger(
    LOGGING_NAME
)  # define globally (used in train.py, val.py, detect.py, etc.)

# Number of worker threads for label verification and image caching: all cores but
# one, clamped to [1, 8]. os.cpu_count() returns None when the count cannot be
# determined, so fall back to 1 instead of failing on `None - 1`.
NUM_THREADS = min(8, max(1, (os.cpu_count() or 1) - 1))
TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}"  # tqdm bar format
IMG_FORMATS = (
    "bmp",
    "dng",
    "jpeg",
    "jpg",
    "mpo",
    "png",
    "tif",
    "tiff",
    "webp",
    "pfm",
)  # include image suffixes
LOCAL_RANK = int(
    os.getenv("LOCAL_RANK", -1)
)  # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv("RANK", -1))
# Global pin_memory flag for dataloaders; on unless PIN_MEMORY is set to a non-"true" value.
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"


In [3]:
class BaseDataset(Dataset):
    """Base dataset that discovers image files, loads labels and serves per-image label dicts.

    Subclasses must implement ``get_labels``; ``update_labels_info`` is an optional hook.

    Args:
        img_path (str | list): image directory, a text file listing image paths, or a list of either.
        imgsz (int): target length of the longest image side after resizing.
        cache (bool | str): falsy disables caching; "disk" writes ``*.npy`` files next to the
            images; any other truthy value keeps the loaded images in RAM.
        augment (bool): stored on the instance; here it only affects the interpolation
            method chosen in ``load_image``.
        hyp: currently unused (the transform pipeline is commented out).
        prefix (str): string prepended to log and error messages.
        rect (bool): if True, sort samples by aspect ratio and compute per-batch shapes
            (requires ``batch_size``).
        batch_size (int | None): batch size used for rectangular batching.
        stride (int): per-batch shapes are rounded up to a multiple of this value.
        pad (float): padding term (in stride units) added before rounding batch shapes.
        single_cls (bool): if True, every class id is rewritten to 0.
    """

    def __init__(
        self,
        img_path,
        imgsz=640,
        cache=False,
        augment=True,
        hyp=None,
        prefix="",
        rect=False,
        batch_size=None,
        stride=32,
        pad=0.5,
        single_cls=False,
    ):
        super().__init__()
        self.img_path = img_path
        self.imgsz = imgsz
        self.augment = augment
        self.single_cls = single_cls
        self.prefix = prefix

        self.im_files = self.get_img_files(self.img_path)
        self.labels = self.get_labels()
        if self.single_cls:
            self.update_labels(include_class=[])

        self.ni = len(self.labels)

        # rect stuff
        self.rect = rect
        self.batch_size = batch_size
        self.stride = stride
        self.pad = pad
        if self.rect:
            assert self.batch_size is not None
            self.set_rectangle()

        # cache stuff (npy paths are derived after set_rectangle may have reordered im_files)
        self.ims = [None] * self.ni
        self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
        if cache:
            self.cache_images(cache)

        # transforms
        # self.transforms = self.build_transforms(hyp=hyp)

    def get_img_files(self, img_path):
        """Collect and sort all image file paths found under ``img_path``.

        Directories are globbed recursively; regular files are read as text, one image
        path per line, with a leading "./" resolved against the listing file's folder.

        Returns:
            (list[str]): sorted paths whose extension is in ``IMG_FORMATS``.

        Raises:
            FileNotFoundError: if a path does not exist, no images are found, or any
                other error occurs while reading.
        """
        try:
            f = []  # image files
            for p in img_path if isinstance(img_path, list) else [img_path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / "**" / "*.*"), recursive=True)
                elif p.is_file():  # file
                    with open(p) as t:
                        lines = t.read().strip().splitlines()
                    parent = str(p.parent) + os.sep
                    f += [
                        x.replace("./", parent) if x.startswith("./") else x
                        for x in lines
                    ]  # local to global path
                else:
                    raise FileNotFoundError(f"{self.prefix}{p} does not exist")
            im_files = sorted(
                x.replace("/", os.sep)
                for x in f
                if x.split(".")[-1].lower() in IMG_FORMATS
            )
            assert im_files, f"{self.prefix}No images found"
        except Exception as e:
            raise FileNotFoundError(
                f"{self.prefix}Error loading data from {img_path}\n"
            ) from e
        return im_files

    def update_labels(self, include_class: Optional[list]):
        """Filter labels in place to the given classes and/or collapse all classes to 0.

        Args:
            include_class (list | None): class ids to keep; an empty list or None keeps
                every label. When ``self.single_cls`` is set, all remaining class ids
                are overwritten with 0.
        """
        include_class_array = np.array(include_class).reshape(1, -1)
        for i in range(len(self.labels)):
            if include_class:
                cls = self.labels[i]["cls"]
                bboxes = self.labels[i]["bboxes"]
                segments = self.labels[i]["segments"]
                keypoints = self.labels[i].get("keypoints")
                j = (cls == include_class_array).any(1)  # boolean row mask, one entry per label
                self.labels[i]["cls"] = cls[j]
                self.labels[i]["bboxes"] = bboxes[j]
                if segments:
                    # segments is a plain Python list and cannot be indexed with a
                    # boolean array, so filter it element-wise.
                    self.labels[i]["segments"] = [
                        seg for seg, keep in zip(segments, j) if keep
                    ]
                if keypoints is not None:
                    # keep keypoints aligned with cls/bboxes (assumes an array with one
                    # leading row per label — TODO confirm for custom get_labels)
                    self.labels[i]["keypoints"] = keypoints[j]
            if self.single_cls:
                self.labels[i]["cls"][:, 0] = 0

    def load_image(self, i):
        """Load image ``i`` and resize it so that its longest side equals ``self.imgsz``.

        Returns:
            (tuple): ``(im, (h0, w0), (h, w))`` — the image, its original height/width
            and its height/width after resizing. RAM-cached entries are returned as stored.

        Raises:
            FileNotFoundError: if ``cv2.imread`` cannot read the file.
        """
        im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
        if im is None:  # not cached in RAM
            if fn.exists():  # load npy
                im = np.load(fn)
            else:  # read image
                im = cv2.imread(f)  # BGR
                if im is None:
                    raise FileNotFoundError(f"Image Not Found {f}")
            h0, w0 = im.shape[:2]  # orig hw
            r = self.imgsz / max(h0, w0)  # ratio
            if r != 1:  # if sizes are not equal
                interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
                im = cv2.resize(
                    im, (math.ceil(w0 * r), math.ceil(h0 * r)), interpolation=interp
                )
            return im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
        return self.ims[i], self.im_hw0[i], self.im_hw[i]  # im, hw_original, hw_resized

    def cache_images(self, cache):
        """Cache every image to disk (``cache == "disk"``) or to RAM (any other value).

        Work is spread over a thread pool; the progress bar reports the accumulated size.
        """
        gb = 0  # bytes cached so far (shown in GB in the progress bar)
        self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
        fcn = self.cache_images_to_disk if cache == "disk" else self.load_image
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(fcn, range(self.ni))
            pbar = tqdm(
                enumerate(results),
                total=self.ni,
                bar_format=TQDM_BAR_FORMAT,
                disable=LOCAL_RANK > 0,
            )
            for i, x in pbar:
                if cache == "disk":
                    gb += self.npy_files[i].stat().st_size
                else:  # 'ram'
                    (
                        self.ims[i],
                        self.im_hw0[i],
                        self.im_hw[i],
                    ) = x  # im, hw_orig, hw_resized = load_image(self, i)
                    gb += self.ims[i].nbytes
                pbar.desc = f"{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})"
            pbar.close()

    def cache_images_to_disk(self, i):
        """Save image ``i`` as an ``*.npy`` file for faster loading, unless it already exists."""
        f = self.npy_files[i]
        if not f.exists():
            np.save(f.as_posix(), cv2.imread(self.im_files[i]))

    def set_rectangle(self):
        """Sort samples by aspect ratio and compute one target shape per batch.

        Consumes the "shape" entry of every label, reorders ``im_files`` and ``labels``
        accordingly, and sets ``self.batch_shapes`` and ``self.batch``.
        """
        bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches

        s = np.array([x.pop("shape") for x in self.labels])  # hw
        ar = s[:, 0] / s[:, 1]  # aspect ratio
        irect = ar.argsort()
        self.im_files = [self.im_files[i] for i in irect]
        self.labels = [self.labels[i] for i in irect]
        ar = ar[irect]

        # Set training image shapes
        shapes = [[1, 1]] * nb
        for i in range(nb):
            ari = ar[bi == i]
            mini, maxi = ari.min(), ari.max()
            if maxi < 1:
                shapes[i] = [maxi, 1]
            elif mini > 1:
                shapes[i] = [1, 1 / mini]

        # scale to imgsz and round up to a multiple of stride
        self.batch_shapes = (
            np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int)
            * self.stride
        )
        self.batch = bi  # batch index of image

    def __getitem__(self, index):
        """Return the label dict (including the loaded image) for ``index``."""
        return self.get_label_info(index)  # self.transforms(self.get_label_info(index))

    def get_label_info(self, index):
        """Build the sample dict for ``index``: label fields plus image and shape info."""
        label = self.labels[index].copy()  # shallow copy; arrays are shared with self.labels
        label.pop("shape", None)  # shape is for rect, remove it
        label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(
            index
        )
        label["ratio_pad"] = (
            label["resized_shape"][0] / label["ori_shape"][0],
            label["resized_shape"][1] / label["ori_shape"][1],
        )  # for evaluation
        if self.rect:
            label["rect_shape"] = self.batch_shapes[self.batch[index]]
        label = self.update_labels_info(label)
        return label

    def __len__(self):
        """Return the number of labelled samples."""
        return len(self.labels)

    def update_labels_info(self, label):
        """Hook for customising the label format; the base implementation returns it unchanged."""
        return label

    # def build_transforms(self, hyp=None):
    #     """Users can custom augmentations here
    #     like:
    #         if self.augment:
    #             # training transforms
    #             return Compose([])
    #         else:
    #             # val transforms
    #             return Compose([])
    #     """
    #     raise NotImplementedError

    def get_labels(self):
        """Users can custom their own format here.
        Make sure your output is a list with each element like below:
            dict(
                im_file=im_file,
                shape=shape,  # format: (height, width)
                cls=cls,
                bboxes=bboxes, # xywh
                segments=segments,  # xy
                keypoints=keypoints, # xy
                normalized=True, # or False
                bbox_format="xyxy",  # or xywh, ltwh
            )
        """
        raise NotImplementedError

In [4]:
def img2label_paths(img_paths):
    """Map image paths to label paths.

    The last ``<sep>images<sep>`` component of each path is swapped for
    ``<sep>labels<sep>`` and the file extension is replaced with ``.txt``.
    """
    images_part = f"{os.sep}images{os.sep}"  # /images/ substring
    labels_part = f"{os.sep}labels{os.sep}"  # /labels/ substring
    label_paths = []
    for img_path in img_paths:
        swapped = labels_part.join(img_path.rsplit(images_part, 1))
        without_ext = swapped.rsplit(".", 1)[0]
        label_paths.append(without_ext + ".txt")
    return label_paths


def get_hash(paths):
    """Return one SHA-256 hex digest for a list of paths (files or dirs).

    The digest covers the summed size of the paths that exist, followed by the
    concatenation of all path strings.
    """
    total_size = 0
    for p in paths:
        if os.path.exists(p):
            total_size += os.path.getsize(p)
    digest = hashlib.sha256()
    digest.update(str(total_size).encode())  # hash sizes
    digest.update("".join(paths).encode())  # hash paths
    return digest.hexdigest()

def seed_worker(worker_id):  # noqa
    """Seed NumPy and ``random`` from torch's initial seed (DataLoader ``worker_init_fn``).

    See https://pytorch.org/docs/stable/notes/randomness.html#dataloader
    """
    seed = torch.initial_seed() % (1 << 32)  # reduce to the 32-bit range
    for apply_seed in (np.random.seed, random.seed):
        apply_seed(seed)

In [5]:
# Find the numeric EXIF tag id whose name is "Orientation"; after the loop the
# module-level `orientation` holds that id (read by exif_size below).
for orientation, _tag_name in ExifTags.TAGS.items():
    if _tag_name == "Orientation":
        break


def exif_size(img):
    """Return the image size as (width, height), swapped when EXIF orientation is 6 or 8.

    Any error while reading the EXIF data is ignored and the stored size is returned.
    """
    size = img.size  # (width, height)
    try:
        exif_tags = dict(img._getexif().items())
        if exif_tags[orientation] in [6, 8]:  # rotation 270 or 90
            size = (size[1], size[0])
    except Exception:
        pass  # best effort: missing or unreadable EXIF leaves the size untouched
    return size


def is_dir_writeable(dir_path: Union[str, Path]) -> bool:
    """
    Report whether a temporary file can be created inside a directory.

    Args:
        dir_path (str) or (Path): The path to the directory.

    Returns:
        bool: True if the temporary file could be created, False if an OSError occurred.
    """
    try:
        with tempfile.TemporaryFile(dir=dir_path):
            pass
    except OSError:
        writeable = False
    else:
        writeable = True
    return writeable


def xyxy2xywh(x):
    """
    Convert boxes from corner format (x1, y1, x2, y2) to centre format (x, y, width, height).

    Args:
        x (np.ndarray) or (torch.Tensor): Boxes in (x1, y1, x2, y2) format along the last axis.
    Returns:
       y (np.ndarray) or (torch.Tensor): A copy of the input with the first four
           last-axis entries replaced by (x centre, y centre, width, height).
    """
    if isinstance(x, torch.Tensor):
        out = x.clone()
    else:
        out = np.copy(x)
    left, top, right, bottom = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
    out[..., 0] = (left + right) / 2  # x center
    out[..., 1] = (top + bottom) / 2  # y center
    out[..., 2] = right - left  # width
    out[..., 3] = bottom - top  # height
    return out


def segments2boxes(segments):
    """
    Convert segment labels to box labels by taking each segment's bounding box.

    Args:
      segments (list): list of segments; each segment is an array whose transpose
        unpacks into an x row and a y row.

    Returns:
      (np.ndarray): the xywh coordinates of the bounding boxes.
    """
    corner_boxes = [
        [xs.min(), ys.min(), xs.max(), ys.max()]  # xyxy
        for xs, ys in (seg.T for seg in segments)
    ]
    return xyxy2xywh(np.array(corner_boxes))  # xywh


def verify_image_label(args):
    """Verify one image/label pair and parse its label file.

    Args:
        args (tuple): ``(im_file, lb_file, prefix, keypoint, num_cls)`` — image path,
            label path, log prefix, keypoint flag and number of dataset classes.

    Returns:
        On success a tuple ``(im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg)``
        where ``lb`` is an (n, 5) float32 array of (cls, x, y, w, h), ``shape`` is (h, w)
        and nm/nf/ne/nc flag a missing/found/empty/corrupt label. On any exception a list
        with five ``None`` entries followed by the counters (``nc == 1``) and a warning message.
    """
    im_file, lb_file, prefix, keypoint, num_cls = args
    # NOTE(review): the caller's keypoint flag is overridden here, so the keypoint
    # branches below never run. Left as-is to preserve behaviour — confirm intent.
    keypoint = False
    # number (missing, found, empty, corrupt), message, segments, keypoints
    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
    try:
        # verify images
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        shape = (shape[1], shape[0])  # hw
        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
        if im.format.lower() in ("jpg", "jpeg"):
            with open(im_file, "rb") as f:
                f.seek(-2, 2)  # last two bytes should be the JPEG end-of-image marker
                if f.read() != b"\xff\xd9":  # corrupt JPEG
                    ImageOps.exif_transpose(Image.open(im_file)).save(
                        im_file, "JPEG", subsampling=0, quality=100
                    )
                    msg = (
                        f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
                    )

        # verify labels
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file) as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
                if any(len(x) > 6 for x in lb) and (not keypoint):  # is segment
                    classes = np.array([x[0] for x in lb], dtype=np.float32)
                    segments = [
                        np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb
                    ]  # (cls, xy1...)
                    lb = np.concatenate(
                        (classes.reshape(-1, 1), segments2boxes(segments)), 1
                    )  # (cls, xywh)
                lb = np.array(lb, dtype=np.float32)
            nl = len(lb)
            if nl:
                if keypoint:
                    assert lb.shape[1] == 56, "labels require 56 columns each"
                    assert (
                        lb[:, 5::3] <= 1
                    ).all(), "non-normalized or out of bounds coordinate labels"
                    assert (
                        lb[:, 6::3] <= 1
                    ).all(), "non-normalized or out of bounds coordinate labels"
                    kpts = np.zeros((lb.shape[0], 39))
                    for i in range(len(lb)):
                        kpt = np.delete(
                            lb[i, 5:], np.arange(2, lb.shape[1] - 5, 3)
                        )  # remove occlusion param from GT
                        kpts[i] = np.hstack((lb[i, :5], kpt))
                    lb = kpts
                    assert (
                        lb.shape[1] == 39
                    ), "labels require 39 columns each after removing occlusion parameter"
                else:
                    assert (
                        lb.shape[1] == 5
                    ), f"labels require 5 columns, {lb.shape[1]} columns detected"
                    assert (
                        lb[:, 1:] <= 1
                    ).all(), f"non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}"
                # All labels
                max_cls = int(lb[:, 0].max())  # largest class id in this file
                # Valid class ids are 0..num_cls-1, so the largest id must be strictly
                # below num_cls (the previous `<=` let class id == num_cls through).
                assert max_cls < num_cls, (
                    f"Label class {max_cls} exceeds dataset class count {num_cls}. "
                    f"Possible class labels are 0-{num_cls - 1}. Labels:\n{lb}"
                )
                assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}"
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
                    if segments:
                        segments = [segments[x] for x in i]
                    msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
            else:
                ne = 1  # label empty
                lb = (
                    np.zeros((0, 39), dtype=np.float32)
                    if keypoint
                    else np.zeros((0, 5), dtype=np.float32)
                )
        else:
            nm = 1  # label missing
            lb = (
                np.zeros((0, 39), dtype=np.float32)
                if keypoint
                else np.zeros((0, 5), dtype=np.float32)
            )
        if keypoint:
            keypoints = lb[:, 5:].reshape(-1, 17, 2)
        lb = lb[:, :5]
        return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
    except Exception as e:
        # Any failure marks the pair as corrupt; the caller skips entries with im_file None.
        nc = 1
        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
        return [None, None, None, None, None, nm, nf, ne, nc, msg]


In [6]:
class YOLODataset(BaseDataset):
    """Dataset that reads label files next to the images and caches them in a ``*.cache`` file.

    Args:
        names: class-name collection; only its length is used here (as the class count
            passed to ``verify_image_label``).
        All other arguments are forwarded unchanged to ``BaseDataset``.
    """

    cache_version = "1.0.1"  # dataset labels *.cache version, >= 1.0.0 for YOLOv8
    rand_interp_methods = [
        cv2.INTER_NEAREST,
        cv2.INTER_LINEAR,
        cv2.INTER_CUBIC,
        cv2.INTER_AREA,
        cv2.INTER_LANCZOS4,
    ]

    def __init__(
        self,
        img_path,
        imgsz=640,
        cache=False,
        augment=True,
        hyp=None,
        prefix="",
        rect=False,
        batch_size=None,
        stride=32,
        pad=0.0,
        single_cls=False,
        names=None,
    ):
        # Set before the base initialiser runs: it calls get_labels(), which may call
        # cache_labels(), which reads self.names and self.use_keypoints.
        self.names = names
        self.use_segments = None
        self.use_keypoints = None

        super().__init__(
            img_path=img_path,
            imgsz=imgsz,
            cache=cache,
            augment=augment,
            hyp=hyp,
            prefix=prefix,
            rect=rect,
            batch_size=batch_size,
            stride=stride,
            pad=pad,
            single_cls=single_cls,
        )

    def cache_labels(self, path=Path("./labels.cache")):
        """Verify every image/label pair, collect the labels and write them to a cache file.

        Args:
            path (Path): where to save the cache file (default: Path('./labels.cache')).
        Returns:
            (dict): cache content with keys "labels", "hash", "results", "msgs", "version".
        """
        x = {"labels": []}
        nm = nf = ne = nc = 0  # number missing, found, empty, corrupt
        msgs = []  # warning messages
        desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
        total = len(self.im_files)
        with ThreadPool(NUM_THREADS) as pool:
            jobs = zip(
                self.im_files,
                self.label_files,
                repeat(self.prefix),
                repeat(self.use_keypoints),
                repeat(len(self.names)),
            )
            results = pool.imap(func=verify_image_label, iterable=jobs)
            pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
            for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm, nf, ne, nc = nm + nm_f, nf + nf_f, ne + ne_f, nc + nc_f
                if im_file:  # None for pairs that failed verification
                    x["labels"].append(
                        {
                            "im_file": im_file,
                            "shape": shape,
                            "cls": lb[:, 0:1],  # n, 1
                            "bboxes": lb[:, 1:],  # n, 4
                            "segments": segments,
                            "keypoints": keypoint,
                            "normalized": True,
                            "bbox_format": "xywh",
                        }
                    )
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            pbar.close()

        if msgs:
            LOGGER.info("\n".join(msgs))
        if nf == 0:
            LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}.")
        x.update(
            hash=get_hash(self.label_files + self.im_files),
            results=(nf, nm, ne, nc, len(self.im_files)),
            msgs=msgs,  # warnings
            version=self.cache_version,  # cache version
        )

        if not is_dir_writeable(path.parent):
            LOGGER.warning(
                f"{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved."
            )
            return x

        if path.exists():
            path.unlink()  # remove *.cache file if exists
        np.save(str(path), x)  # save cache for next time (numpy appends ".npy")
        path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
        LOGGER.info(f"{self.prefix}New cache created: {path}")
        return x

    def get_labels(self):
        """Return the list of label dicts, loading them from the cache file when it is valid.

        The cache is rebuilt when it is missing, has a different version, or its hash
        does not match the current label and image files.

        Raises:
            FileNotFoundError: if no label file was found.
            ValueError: if every label file is empty.
        """
        self.label_files = img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
        try:
            # NOTE: allow_pickle=True unpickles the cache file; only load caches you trust.
            cache = np.load(str(cache_path), allow_pickle=True).item()  # load dict
            exists = True
            assert cache["version"] == self.cache_version  # matches current version
            assert cache["hash"] == get_hash(
                self.label_files + self.im_files
            )  # identical hash
        except (FileNotFoundError, AssertionError, AttributeError):
            cache = self.cache_labels(cache_path)  # run cache ops
            exists = False

        # Display cache
        nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in {-1, 0}:
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            tqdm(
                None,
                desc=self.prefix + d,
                total=n,
                initial=n,
                bar_format=TQDM_BAR_FORMAT,
            )  # display cache results
            if cache["msgs"]:
                LOGGER.info("\n".join(cache["msgs"]))  # display warnings
        if nf == 0:  # number of labels found
            raise FileNotFoundError(
                f"{self.prefix}No labels found in {cache_path}, can not start training."
            )

        # Read cache: drop the bookkeeping entries, keep the labels
        for key in ("hash", "version", "msgs"):
            cache.pop(key)
        labels = cache["labels"]
        self.im_files = [lb["im_file"] for lb in labels]  # update im_files

        # Check if the dataset is all boxes or all segments
        counts = [
            (len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels
        ]
        len_cls, len_boxes, len_segments = map(sum, zip(*counts))
        if len_segments and len_boxes != len_segments:
            LOGGER.warning(
                f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
                f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
                "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
            )
            for lb in labels:
                lb["segments"] = []
        if len_cls == 0:
            raise ValueError(
                f"All labels empty in {cache_path}, can not start training without labels."
            )
        return labels

    # TODO: use hyp config to set all these augmentations
    # def build_transforms(self, hyp=None):
    #     if self.augment:
    #         hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
    #         hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
    #         transforms = v8_transforms(self, self.imgsz, hyp)
    #     else:
    #         transforms = Compose(
    #             [LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)]
    #         )
    #     transforms.append(
    #         Format(
    #             bbox_format="xywh",
    #             normalize=True,
    #             return_mask=self.use_segments,
    #             return_keypoint=self.use_keypoints,
    #             batch_idx=True,
    #             mask_ratio=hyp.mask_ratio,
    #             mask_overlap=hyp.overlap_mask,
    #         )
    #     )
    #     return transforms

In [7]:
def colorstr(*input):
    """Wrap a string in ANSI escape codes, i.e. colorstr('blue', 'hello world').

    The last argument is the text and all preceding ones are style names. With a
    single argument the text is rendered blue and bold.
    See https://en.wikipedia.org/wiki/ANSI_escape_code
    """
    if len(input) > 1:
        *styles, text = input
    else:
        styles, text = ["blue", "bold"], input[0]

    codes = {
        "black": "\033[30m",  # basic colors
        "red": "\033[31m",
        "green": "\033[32m",
        "yellow": "\033[33m",
        "blue": "\033[34m",
        "magenta": "\033[35m",
        "cyan": "\033[36m",
        "white": "\033[37m",
        "bright_black": "\033[90m",  # bright colors
        "bright_red": "\033[91m",
        "bright_green": "\033[92m",
        "bright_yellow": "\033[93m",
        "bright_blue": "\033[94m",
        "bright_magenta": "\033[95m",
        "bright_cyan": "\033[96m",
        "bright_white": "\033[97m",
        "end": "\033[0m",  # misc
        "bold": "\033[1m",
        "underline": "\033[4m",
    }
    prefix = "".join(codes[style] for style in styles)
    return prefix + f"{text}" + codes["end"]


# Run configuration
mode = "train"
img_path = "coco128/"
batch = 2
imgsz = 224
# Class id -> class name. This must be a plain dict: YOLODataset uses len(names) as
# the number of classes, and the previous `({...},)` form was a 1-tuple of length 1.
names = {
    0: "person",
    1: "bicycle",
    2: "car",
    3: "motorcycle",
    4: "airplane",
    5: "bus",
    6: "train",
    7: "truck",
    8: "boat",
    9: "traffic light",
    10: "fire hydrant",
    11: "stop sign",
    12: "parking meter",
    13: "bench",
    14: "bird",
    15: "cat",
    16: "dog",
    17: "horse",
    18: "sheep",
    19: "cow",
    20: "elephant",
    21: "bear",
    22: "zebra",
    23: "giraffe",
    24: "backpack",
    25: "umbrella",
    26: "handbag",
    27: "tie",
    28: "suitcase",
    29: "frisbee",
    30: "skis",
    31: "snowboard",
    32: "sports ball",
    33: "kite",
    34: "baseball bat",
    35: "baseball glove",
    36: "skateboard",
    37: "surfboard",
    38: "tennis racket",
    39: "bottle",
    40: "wine glass",
    41: "cup",
    42: "fork",
    43: "knife",
    44: "spoon",
    45: "bowl",
    46: "banana",
    47: "apple",
    48: "sandwich",
    49: "orange",
    50: "broccoli",
    51: "carrot",
    52: "hot dog",
    53: "pizza",
    54: "donut",
    55: "cake",
    56: "chair",
    57: "couch",
    58: "potted plant",
    59: "bed",
    60: "dining table",
    61: "toilet",
    62: "tv",
    63: "laptop",
    64: "mouse",
    65: "remote",
    66: "keyboard",
    67: "cell phone",
    68: "microwave",
    69: "oven",
    70: "toaster",
    71: "sink",
    72: "refrigerator",
    73: "book",
    74: "clock",
    75: "vase",
    76: "scissors",
    77: "teddy bear",
    78: "hair drier",
    79: "toothbrush",
}


# Build the dataset; labels come from the cache file next to the label folder when it is valid.
dataset = YOLODataset(
    img_path=img_path,
    imgsz=imgsz,
    batch_size=batch,
    augment=mode == "train",  # augmentation
    hyp=None,  # TODO: probably add a get_hyps_from_cfg function
    rect=False,  # rectangular batches
    cache=False,
    single_cls=False,
    stride=1,
    pad=0.0 if mode == "train" else 0.5,
    prefix=colorstr(f"{mode}: "),
    names=names,
)


[34m[1mtrain: [0mScanning coco128/labels/train2017.cache... 128 images, 0 backgrounds, 0 corrupt: 100%|██████████| 128/128 [00:00<?, ?it/s]


In [8]:
# Never request more samples per batch than the dataset holds.
batch = min(batch, len(dataset))
loader = DataLoader
# Seeded generator for shuffling, offset by the process rank.
generator = torch.Generator().manual_seed(6148914691236517205 + RANK)
train_loader = loader(
    dataset=dataset,
    batch_size=batch,
    shuffle=True,
    pin_memory=PIN_MEMORY,
    # collate_fn=getattr(dataset, "collate_fn", None),
    worker_init_fn=seed_worker,  # seeds numpy / random inside each worker
    generator=generator,
)


In [9]:
# Fetch the first sample through the iterator protocol to inspect the label dict.
next(iter(dataset))

{'im_file': 'coco128/images/train2017/000000000009.jpg',
 'cls': array([[45.],
        [45.],
        [50.],
        [45.],
        [49.],
        [49.],
        [49.],
        [49.]], dtype=float32),
 'bboxes': array([[0.479492 , 0.688771 , 0.955609 , 0.5955   ],
        [0.736516 , 0.247188 , 0.498875 , 0.476417 ],
        [0.637063 , 0.732938 , 0.494125 , 0.510583 ],
        [0.339438 , 0.418896 , 0.678875 , 0.7815   ],
        [0.646836 , 0.132552 , 0.118047 , 0.0969375],
        [0.773148 , 0.129802 , 0.0907344, 0.0972292],
        [0.668297 , 0.226906 , 0.131281 , 0.146896 ],
        [0.642859 , 0.0792187, 0.148063 , 0.148062 ]], dtype=float32),
 'segments': [],
 'keypoints': None,
 'normalized': True,
 'bbox_format': 'xywh',
 'img': array([[[115,  21,   2],
         [115,  22,   1],
         [115,  22,   1],
         ...,
         [205, 177, 146],
         [198, 174, 140],
         [191, 170, 134]],
 
        [[115,  22,   1],
         [116,  23,   2],
         [117,  24,   1],


In [10]:
# Same sample via direct indexing (goes through __getitem__ -> get_label_info).
dataset[0]

{'im_file': 'coco128/images/train2017/000000000009.jpg',
 'cls': array([[45.],
        [45.],
        [50.],
        [45.],
        [49.],
        [49.],
        [49.],
        [49.]], dtype=float32),
 'bboxes': array([[0.479492 , 0.688771 , 0.955609 , 0.5955   ],
        [0.736516 , 0.247188 , 0.498875 , 0.476417 ],
        [0.637063 , 0.732938 , 0.494125 , 0.510583 ],
        [0.339438 , 0.418896 , 0.678875 , 0.7815   ],
        [0.646836 , 0.132552 , 0.118047 , 0.0969375],
        [0.773148 , 0.129802 , 0.0907344, 0.0972292],
        [0.668297 , 0.226906 , 0.131281 , 0.146896 ],
        [0.642859 , 0.0792187, 0.148063 , 0.148062 ]], dtype=float32),
 'segments': [],
 'keypoints': None,
 'normalized': True,
 'bbox_format': 'xywh',
 'img': array([[[115,  21,   2],
         [115,  22,   1],
         [115,  22,   1],
         ...,
         [205, 177, 146],
         [198, 174, 140],
         [191, 170, 134]],
 
        [[115,  22,   1],
         [116,  23,   2],
         [117,  24,   1],


In [11]:
# Array shape of the first sample's image as returned by load_image.
dataset[0]["img"].shape

(168, 224, 3)

In [12]:
# Batch size after clamping to the dataset length in the DataLoader cell.
batch

2

In [15]:
# Inspect the second sample's label dict.
dataset[1]

{'im_file': 'coco128/images/train2017/000000000025.jpg',
 'cls': array([[23.],
        [23.]], dtype=float32),
 'bboxes': array([[0.770336, 0.489695, 0.335891, 0.697559],
        [0.185977, 0.901608, 0.206297, 0.129554]], dtype=float32),
 'segments': [],
 'keypoints': None,
 'normalized': True,
 'bbox_format': 'xywh',
 'img': array([[[  3,   8,   7],
         [  4,   9,   8],
         [  5,  10,   9],
         ...,
         [191, 192, 192],
         [158, 160, 161],
         [137, 141, 142]],
 
        [[  4,   9,   8],
         [  4,   9,   8],
         [  4,   9,   8],
         ...,
         [225, 221, 217],
         [214, 211, 209],
         [209, 207, 207]],
 
        [[  5,  10,   9],
         [  4,   9,   8],
         [  3,   8,   7],
         ...,
         [228, 220, 211],
         [229, 222, 215],
         [228, 221, 216]],
 
        ...,
 
        [[ 80, 142, 165],
         [ 66, 139, 160],
         [ 48, 132, 152],
         ...,
         [ 67,  92, 133],
         [107, 133, 1