In [1]:
%matplotlib inline

In [2]:
import json
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union

from matplotlib.offsetbox import AnnotationBbox, OffsetImage
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm

In [3]:
# Directory that is globbed for character sprite PNGs; each file's stem is used as the
# character name (path is relative, so it is resolved against the current working directory)
BASE_DIR = 'sprites/no-bg/'

In [4]:
# Nintendo DS has a 4:3 aspect ratio. Stored as (height, width): index 0 bounds the random
# y-coordinate and index 1 bounds the random x-coordinate when characters are placed
DS_SCREEN_SIZE = (191, 255)


# Sorted so the list (and therefore any seeded random choice over it) does not depend on the
# order in which the filesystem happens to return directory entries
ALL_IMAGES = sorted(Path(BASE_DIR).glob('*.png'))


# Hex edge colour used for each character's bounding box when boxes are drawn on the plot
CHARACTER_BOX_COLOR_DICTIONARY = dict(
    luigi='#34eb43',
    wario='#ebdb34',
    yoshi='#cc1bae',
    mario='#cc1b1b',
)

# COCO "categories" section: one entry per character, with IDs assigned 1..4 in the order below
COCO_CHARACTER_CATEGORIES_DICT = {
    'categories': [
        {'id': category_id, 'name': character_name}
        for category_id, character_name in enumerate(
            ('mario', 'luigi', 'wario', 'yoshi'),
            start=1,
        )
    ],
}

In [5]:
def create_wanted_minigame_screen(
    num_characters: int = 15,
    image_filename: Optional[str] = None,
    plot_aspect_ratio: Tuple[int, int] = DS_SCREEN_SIZE,
    display_plot: bool = True,
    display_boxes: bool = False,
    proximity_threshold: float = 5,
    zoom_multiplier: float = 1,
    pad_inches: float = 0.0,
) -> List[Tuple[str, Tuple[float, float, float, float]]]:
    """
    Create a mockup of a "Wanted!" minigame screenshot entirely using Matplotlib.

    Parameters
    ----------
    num_characters: int
        Number of characters to generate in the screenshot, each randomly chosen from the sprites
        in ``ALL_IMAGES``. If ``proximity_threshold`` is too large, there is a chance that
        characters are plotted in a way that it is not possible to satisfy the conditions needed to
        plot additional characters. In this case, fewer than ``num_characters`` characters will be
        plotted in the screenshot. Currently, this happens silently
    image_filename: str
        Filename to name the screenshot when saving to disk. If ``None``, the screenshot will not
        be saved
    plot_aspect_ratio: tuple
        ``(height, width)`` of the black background image used as the base of the plot. Note that
        because we save the plot with ``bbox_inches='tight'``, the saved image will likely NOT
        have this aspect ratio. Instead, it is used as a base for the plot
    display_plot: bool
        Whether or not to display the plot after drawing but before saving
    display_boxes: bool
        Whether or not to draw boxes on the characters in the plot before saving
    proximity_threshold: float
        Threshold for how close a character can be to another character to be plotted. This is used
        to avoid plots where a character is plotted right on top of another, where it would be
        impossible for a model (or human) to detect both characters correctly. A candidate position
        is rejected only when it is closer than the threshold to an existing character along BOTH
        the x and y axes. Note that this value is multiplied by ``zoom_multiplier`` squared
    zoom_multiplier: float
        Determines how much to zoom to apply to the character sprites and character proximity
        threshold ``proximity_threshold``. A larger ``zoom_multiplier`` means larger character
        sprites on the plot
    pad_inches: float
        Number of inches to pad the saved image with. A negative value means characters may be
        cropped during saving

    Returns
    -------
    characters_bbox_list: list of tuples
        List of characters and their bounding boxes in the generated plot. Format is a list of
        tuples with the first value in the tuple being the character name and the second value being
        another tuple containing the bounding box x-position, y-position, width, and height (as
        assigned by Matplotlib). The length of ``characters_bbox_list`` will be equal to the number
        of characters generated in the plot

        NOTE(review): the bounds are read from the annotation patches after the optional
        show/save steps; confirm they are meaningful when the figure is neither displayed nor saved

    """
    # total number of random positions tried per character before giving up
    max_placement_attempts = 101

    # loop-invariant minimum per-axis distance between two character centers
    min_separation = proximity_threshold * (zoom_multiplier ** 2)

    fig, ax = plt.subplots()

    characters_chosen = list()
    annotation_boxes = list()
    used_coordinates_list = list()

    # black background that defines the data coordinate system of the plot
    ax.imshow(X=np.zeros(shape=plot_aspect_ratio + (3,)))

    for _ in range(num_characters):
        for _attempt in range(max_placement_attempts):
            x = np.random.randint(low=0, high=plot_aspect_ratio[1])
            y = np.random.randint(low=0, high=plot_aspect_ratio[0])

            # avoid overlap with previously-plotted characters
            too_close = any(
                abs(used_x - x) < min_separation and abs(used_y - y) < min_separation
                for (used_x, used_y) in used_coordinates_list
            )

            if not too_close:
                # valid position found (even if this was the final attempt)
                break
        else:
            # character can't fit without occlusion - stop generating
            break

        used_coordinates_list.append((x, y))

        character_chosen_path = Path(np.random.choice(ALL_IMAGES))
        character_chosen = character_chosen_path.stem
        characters_chosen.append(character_chosen)

        img = plt.imread(fname=character_chosen_path)
        img = OffsetImage(arr=img, zoom=zoom_multiplier)
        img.image.axes = ax

        bbox_kwargs = {
            'frameon': True,
            'bboxprops': {
                'edgecolor': CHARACTER_BOX_COLOR_DICTIONARY.get(character_chosen),
            },
            'pad': 0.1,
        }

        ab = AnnotationBbox(
            offsetbox=img,
            xy=(x, y),
            xycoords='data',
            **(bbox_kwargs if display_boxes else {}),
        )

        # the frame is never filled so the sprite's transparency shows the background
        ab.patch.set_facecolor((0, 0, 0, 0))

        if not display_boxes:
            # set the edges of the ``AnnotationBbox`` transparent
            ab.patch.set_edgecolor((0, 0, 0, 0))

        ax.add_artist(a=ab)
        annotation_boxes.append(ab)

    ax.set_axis_off()

    if display_plot:
        plt.show()

    if image_filename:
        Path(image_filename).parent.mkdir(parents=True, exist_ok=True)

        fig.savefig(
            fname=image_filename,
            bbox_inches='tight',
            pad_inches=pad_inches,
            facecolor=(0, 0, 0),  # black color code
        )

    plt.close(fig)

    # pair each character name with the bounds of the annotation box created for it
    return [
        (
            character_name,
            (ab.patch.get_x(), ab.patch.get_y(), ab.patch.get_width(), ab.patch.get_height()),
        )
        for character_name, ab in zip(characters_chosen, annotation_boxes)
    ]


def _clip_span_to_image(start, length, limit):
    """Clip the 1-D span ``[start, start + length]`` to ``[0, limit]``; return ``(start, length)``."""
    clipped_start = max(start, 0)
    clipped_end = min(start + length, limit)

    return clipped_start, clipped_end - clipped_start


def generate_coco_annotations_and_images_dict(
    characters_bbox_list: List[Tuple[str, Tuple[float, float, float, float]]],
    starting_character_id: int,
    image_id: int,
    image_filename: str,
) -> Tuple[
    List[Dict[str, Union[int, float, List[float]]]],
    List[Dict[str, Union[int, str]]],
]:
    """
    Convert the output of ``create_wanted_minigame_screen`` into dictionaries in a valid COCO
    annotation format.

    Parameters
    ----------
    characters_bbox_list: list
        Output of ``create_wanted_minigame_screen``. Format is a list of tuples with the first value
        in the tuple being the character name and the second value being another tuple containing
        the bounding box x-position, y-position, width, and height (as assigned by Matplotlib).
    starting_character_id: int
        Starting character ID to be used for the character annotation. COCO expects that each
        character annotated has a unique character ID across all images and annotations. This
        function assigns the first character's ID to be ``starting_character_id`` and increments
        this by ``1`` for each additional character (IDs of characters that are skipped for being
        entirely off screen are not reused). If this function is called multiple times, the
        ``starting_character_id`` provided should be externally incremented by at least
        ``len(characters_bbox_list)``
    image_id: int
        Unique image ID to represent the image
    image_filename: str
        Filename where the image is saved to on disk. The file is read to obtain the image's
        height and width

    Returns
    -------
    coco_annotations_list: list of dicts
        List of dictionaries containing annotations in COCO format. Each sub-dictionary has the
        following format:

            * id: int
                Character ID

            * image_id: int
                Image ID

            * category_id: int
                Category ID to represent the character plotted (Mario, Luigi, Wario, or Yoshi)

            * segmentation: list
                Ignored for this project, will always just be an empty list

            * area: float
                Area of the (clipped) character bounding box

            * bbox: list of float
                Bounding box coordinates in COCO format: ``[x, y, w, h]``, clipped so that the box
                does not extend past any edge of the image

            * iscrowd: int
                Ignored for this project, will always just be ``0``

        There will be a dictionary for each character provided in ``characters_bbox_list`` if the
        character is on the screen in any capacity

    coco_images_list: list of dict
        List (of length one) of dictionaries containing the images in COCO format. Each
        sub-dictionary has the following format:

            * id: int
                Image ID

            * width: int

            * height: int

            * file_name: str

    Raises
    ------
    StopIteration
        If a character name has no matching entry in ``COCO_CHARACTER_CATEGORIES_DICT``

    """
    coco_annotations_list = list()

    image_height, image_width, _ = plt.imread(fname=image_filename).shape

    for idx, (character_name, (bbox_x, bbox_y, bbox_w, bbox_h)) in enumerate(characters_bbox_list):
        # flip the y-axis: the incoming y is measured from the bottom of the image whereas COCO
        # measures from the top
        bbox_y = image_height - bbox_y - bbox_h

        # if entire character box is outside range, don't label it
        if (
            (bbox_x + bbox_w) <= 0  # too far left
            or bbox_x >= image_width  # too far right
            or (bbox_y + bbox_h) <= 0  # too far up
            or bbox_y >= image_height  # too far down
        ):
            continue

        # adjusting all four bbox coordinates to not go off screen; each axis is clipped on both
        # sides so a box overhanging two opposite edges is still confined to the image
        adjusted_bbox_x, adjusted_bbox_w = _clip_span_to_image(bbox_x, bbox_w, image_width)
        adjusted_bbox_y, adjusted_bbox_h = _clip_span_to_image(bbox_y, bbox_h, image_height)

        coco_annotations_list.append(
            {
                'id': starting_character_id + idx,
                'image_id': image_id,
                'category_id': next(
                    d['id']
                    for d in COCO_CHARACTER_CATEGORIES_DICT['categories']
                    if d['name'] == character_name
                ),
                'segmentation': [],  # N/A
                'area': (adjusted_bbox_w * adjusted_bbox_h),
                'bbox': [
                    adjusted_bbox_x,
                    adjusted_bbox_y,
                    adjusted_bbox_w,
                    adjusted_bbox_h,
                ],
                'iscrowd': 0,  # N/A
            }
        )

    coco_images_list = [
        {
            'id': image_id,
            'width': image_width,
            'height': image_height,
            'file_name': str(image_filename),
        },
    ]

    return coco_annotations_list, coco_images_list


def generate_object_detection_dataset(
    num_examples_to_generate: int = 20,
    num_characters_range: Tuple[int, int] = (4, 75),
    save_dir: str = 'data/',
    zoom_multiplier_possible_values: Iterable[float] = [1],
    pad_inches_possible_values: Iterable[float] = [0.0],
    **kwargs,
) -> Dict[str, List[Dict[str, Union[int, float, List[float]]]]]:
    """
    Build a full object detection dataset: a set of mock "Wanted!" minigame screenshots drawn
    with Matplotlib, plus a single COCO annotations file describing every character drawn.

    Parameters
    ----------
    num_examples_to_generate: int
        How many screenshots to create. Image IDs run from ``1`` to this value inclusive
    num_characters_range: tuple of two ints
        ``(low, high)`` bounds handed to ``np.random.randint`` to pick how many characters are
        requested for each screenshot; the drawn value is forwarded to
        ``create_wanted_minigame_screen`` as ``num_characters``
    save_dir: str
        Output directory. Screenshots are written to ``Path(save_dir) / 'images'`` as
        ``image_{image_ID}.png`` (ID zero-padded to five digits) and the annotations to
        ``Path(save_dir) / 'annotations.json'``. Existing files with the same names are
        overwritten
    zoom_multiplier_possible_values: list of float
        Pool of values sampled with ``np.random.choice`` for each screenshot and forwarded to
        ``create_wanted_minigame_screen`` as ``zoom_multiplier``
    pad_inches_possible_values: list of float
        Pool of values sampled with ``np.random.choice`` for each screenshot and forwarded to
        ``create_wanted_minigame_screen`` as ``pad_inches``
    kwargs: keyword arguments
        Extra keyword arguments forwarded unchanged to ``create_wanted_minigame_screen``

    Returns
    -------
    coco_annotations_dict: dict of lists
        The COCO dictionary (categories, images, annotations) that was also written to
        ``Path(save_dir) / 'annotations.json'``

    Notes
    -----
    The running character ID is advanced by the number of characters *requested* for each
    screenshot, so IDs stay unique even when fewer characters end up annotated.

    """
    assert len(num_characters_range) == 2

    output_root = Path(save_dir)
    images_dir = output_root / 'images'
    annotations_path = output_root / 'annotations.json'

    all_annotations = []
    all_images = []
    next_character_id = 1

    for image_id in tqdm(range(1, num_examples_to_generate + 1)):
        # random draws happen in a fixed order: character count, then zoom, then padding
        requested_characters = np.random.randint(
            low=num_characters_range[0],
            high=num_characters_range[1],
        )
        screenshot_path = Path(images_dir) / f'image_{image_id:05d}.png'
        sampled_zoom = np.random.choice(a=zoom_multiplier_possible_values)
        sampled_padding = np.random.choice(a=pad_inches_possible_values)

        drawn_characters = create_wanted_minigame_screen(
            num_characters=requested_characters,
            image_filename=screenshot_path,
            plot_aspect_ratio=DS_SCREEN_SIZE,
            zoom_multiplier=sampled_zoom,
            pad_inches=sampled_padding,
            **kwargs,
        )

        new_annotations, new_images = generate_coco_annotations_and_images_dict(
            characters_bbox_list=drawn_characters,
            starting_character_id=next_character_id,
            image_id=image_id,
            image_filename=screenshot_path,
        )

        all_annotations.extend(new_annotations)
        all_images.extend(new_images)

        next_character_id += requested_characters

    # assemble the final COCO payload and persist it next to the images directory
    coco_annotations_dict = dict(
        COCO_CHARACTER_CATEGORIES_DICT,
        images=all_images,
        annotations=all_annotations,
    )

    annotations_path.parent.mkdir(parents=True, exist_ok=True)
    with annotations_path.open('w') as fp:
        json.dump(coco_annotations_dict, fp)

    return coco_annotations_dict

In [6]:
# Generate the full dataset. The two ``*_possible_values`` lists are sampled with
# ``np.random.choice`` (no explicit probabilities), so repeating a value weights it: zoom 1/2/3
# makes up 7/2/1 of the 10 entries and padding 0/-0.25/-0.5/-0.75 makes up 5/2/2/1 of the 10
# entries. Negative padding means characters may be cropped when the image is saved.
coco_annotations_dict = generate_object_detection_dataset(
    num_examples_to_generate=15_000,
    num_characters_range=(4, 150),
    save_dir=Path('../data/').resolve(),
    display_plot=False,
    display_boxes=False,
    proximity_threshold=5,
    zoom_multiplier_possible_values=[1] * 7 + [2] * 2 + [3],
    pad_inches_possible_values=[0] * 5 + [-0.25] * 2 + [-0.5] * 2 + [-0.75],
)


# displayed as (number of images, number of annotations)
len(coco_annotations_dict['images']), len(coco_annotations_dict['annotations'])

  0%|          | 0/15000 [00:00<?, ?it/s]

(15000, 946677)

In [7]:
# TODO: if occlusion is an issue in the future: go in reverse order, establish coordinates, ensure boxes are not overlapping

----- 