In [None]:
from lib.data.cityscapes import CityscapesDataset

# Configuration for the Cityscapes validation split explored in the cells below.
# Collected in one dict so the settings read as a single config block.
_dataset_kwargs = dict(
    root="data/cityscapes",
    split="val",
    # Resize targets for the two splits (passed through as given).
    train_size=(1024, 1024),
    val_size=(1024, 2048),
    # Per-channel normalization constants (the standard ImageNet RGB mean/std values).
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
    ignore_index=255,
    bbox_format="pascal_voc",
    logger=None,
)
ds = CityscapesDataset(**_dataset_kwargs)

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt


def process_image(image, ds):
    """Undo dataset normalization and return a displayable ``uint8`` image.

    Args:
        image: Normalized image array in channel-first layout (C, H, W); the
            transpose below assumes this layout.
        ds: Object exposing the per-channel ``mean`` and ``std`` sequences that
            were used to normalize ``image``.

    Returns:
        C-contiguous ``np.uint8`` array of shape (H, W, C) with values in [0, 255].
    """
    mean = np.array(ds.mean)
    std = np.array(ds.std)
    hwc = np.transpose(image, (1, 2, 0))
    restored = (hwc * std + mean) * 255
    # Round rather than truncate: the float round-trip turns e.g. 100 into 99.9999,
    # which a bare astype would floor to 99. Clip so out-of-range values saturate
    # instead of wrapping around when cast to uint8.
    as_uint8 = np.clip(np.rint(restored), 0, 255).astype(np.uint8)
    # The transpose leaves a non-C-contiguous layout; OpenCV drawing needs a
    # contiguous buffer, so hand back one directly.
    return np.ascontiguousarray(as_uint8)


def process_mask(seg_mask, ds):
    """Render a 2-D class-id mask as an RGB image using ``ds.COLOR_PALETTE``.

    Args:
        seg_mask: 2-D integer array of per-pixel class ids.
        ds: Object exposing ``COLOR_PALETTE``, either a mapping
            ``{class_id: color}`` or a sequence whose position is the class id.

    Returns:
        ``np.uint8`` array of shape ``(*seg_mask.shape[:2], 3)``; pixels whose id
        has no palette entry stay black.
    """
    seg_mask = np.asarray(seg_mask)
    # Size the output from the mask itself, not from any ``image`` that happens
    # to exist in the notebook namespace.
    colored_mask = np.zeros((*seg_mask.shape[:2], 3), dtype=np.uint8)
    palette = ds.COLOR_PALETTE
    # Pair each class id with its color: keys for a mapping, positions for a sequence.
    entries = palette.items() if hasattr(palette, "items") else enumerate(palette)
    for class_id, color in entries:
        colored_mask[seg_mask == class_id] = color
    return colored_mask


def process_bboxes(image, targets, ds, palette_offset=11):
    """Draw each target box and a "<class name> <label id>" tag onto ``image``.

    Args:
        image: HWC ``uint8`` image to draw on. OpenCV drawing functions modify a
            contiguous input in place, so pass a copy if the original must be kept.
        targets: Sample dict with ``"bboxes"`` (one box per object; the first four
            values are used as x1, y1, x2, y2) and ``"labels"`` (one scalar tensor
            per box, read with ``.item()``).
        ds: Dataset exposing ``LOCALIZATION_CLASSES`` (indexed by label id) and
            ``COLOR_PALETTE`` (indexed by ``palette_offset + label id``).
        palette_offset: Palette index used for localization label 0. Presumably the
            localization classes start at palette entry 11 -- TODO confirm.

    Returns:
        The annotated image.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.2
    text_thickness = 3
    box_thickness = 5
    pad = 10  # vertical padding between the tag text and the box edge

    # cv2 rejects non-contiguous arrays; this is a no-op for contiguous input.
    image = np.ascontiguousarray(image)
    for bbox, label_tensor in zip(targets["bboxes"], targets["labels"]):
        # Python ints keep cv2's point parsing happy across OpenCV versions.
        x1, y1, x2, y2 = np.asarray(bbox).astype(np.int32)[:4].tolist()
        label_id = label_tensor.item()
        text = ds.LOCALIZATION_CLASSES[label_id] + " " + str(label_id)
        color = tuple(int(c) for c in ds.COLOR_PALETTE[palette_offset + label_id])

        image = cv2.rectangle(image, (x1, y1), (x2, y2), color, box_thickness)

        (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, text_thickness)
        # Place the tag above the box; if it would run off the top edge, tuck it
        # just inside the box instead so it stays visible.
        tag_bottom = y1 if y1 - text_h - pad >= 0 else y1 + text_h + pad
        image = cv2.rectangle(
            image, (x1, tag_bottom - text_h - pad), (x1 + text_w, tag_bottom), color, -1
        )
        image = cv2.putText(
            image,
            text,
            (x1, tag_bottom - pad),
            font,
            font_scale,
            [255, 255, 255],
            text_thickness,
        )
    return image


# Visualize one validation sample: de-normalized image with box annotations and
# the colored semantic mask blended on top.
SAMPLE_INDEX = 2
MASK_ALPHA = 0.55  # opacity of the colored mask over the image

targets = ds[SAMPLE_INDEX]
image = process_image(targets["image"].numpy(), ds)
seg_mask = targets["mask"].numpy().astype(np.uint8)
colored_mask = process_mask(seg_mask, ds)
# ndarray.copy() yields a fresh C-ordered buffer for OpenCV to draw into.
image = process_bboxes(image.copy(), targets, ds)

fig, ax = plt.subplots(figsize=(10, 20))
ax.set_title(f"Image: {targets['info']['name']}")
ax.imshow(image)
ax.imshow(colored_mask, alpha=MASK_ALPHA)
ax.axis("off")
plt.show()