In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
import math
from itertools import product
from typing import List, Tuple

from munch import Munch

# Start with what you want to have and write tests.

The model receives as its input

- An image that is represented by a list of entities, where each entity `e` is a 9-tuple consisting of its name `e.name` followed by its four bounding-box corners `e.bbox = (bottom_left, top_left, top_right, bottom_right)`; the $x$/$y$ coordinates of the corners make up the remaining $8 = 4 \times 2$ entries of the 9-tuple.
    - The orientation of each entity is determined by its second and third corners (`top_left` and `top_right`), which mark its front side.
    - To process the input one could use Transformer or graph neural networks, where the graph is supposed to be complete (i.e., each pair of vertices is joined by an edge).
- A question that is represented by a triple `(entity1, relation, entity2)` or `(entity1, relation, entity2, entity3)`

There may already be an existing model that is similar to this; I still need to check.

## Dataset Generator

In [None]:
import numpy as np
from PIL import Image, ImageDraw, ImageOps
from qsr_learning.data.bounding_box_data import (
    above,
    below,
    generate_entities,
    left_of,
    right_of,
)


def draw(entity, base, orientation_marker=True):
    """Paint ``entity`` onto ``base`` and hand ``base`` back.

    The entity's bounding box is filled with ``entity.color`` (opaque black when
    the colour is falsy).  When ``orientation_marker`` is true, a white strip is
    additionally drawn along the ``top_left``/``top_right`` edge to mark the
    front side; its depth is one tenth of the adjacent box sides, truncated to
    whole pixels.

    :param entity: object exposing ``bbox``, ``bbox_float``, ``color``,
        ``top_left`` and ``top_right``
    :param base: the image to draw on
    :param orientation_marker: whether to draw the white front-side strip

    :returns: ``base``
    """
    canvas = ImageDraw.Draw(base)
    fill_color = entity.color or (0, 0, 0, 255)
    canvas.polygon([tuple(corner) for corner in entity.bbox], fill=fill_color)
    if not orientation_marker:
        return base

    corners = entity.bbox_float
    # Step one tenth of the way from each top corner towards the bottom corner
    # on the same side (corner order: bottom_left, top_left, top_right, bottom_right).
    left_offset = ((corners[0] - corners[1]) / 10).astype(int)
    right_offset = ((corners[3] - corners[2]) / 10).astype(int)
    strip = (
        left_offset + entity.top_left,
        entity.top_left,
        entity.top_right,
        right_offset + entity.top_right,
    )
    canvas.polygon([tuple(corner) for corner in strip], fill=(255, 255, 255, 255))
    return base


def show(entities, base, orientation_marker=True):
    """
    Draw all entities onto an image, flip it vertically and label each entity.

    :param entities: an iterable of entities (each exposing ``bbox_float`` and ``name``,
        plus whatever ``draw`` needs)
    :param base: the image to draw on
    :param orientation_marker: whether ``draw`` should mark the front side of each entity

    :returns: the vertically flipped image with the entity names written on it
    """
    for entity in entities:
        # Forward the flag: it used to be dropped here, so markers were always drawn.
        draw(entity, base, orientation_marker=orientation_marker)
    base = ImageOps.flip(base)
    d = ImageDraw.Draw(base)
    for entity in entities:
        # Midpoint of the diagonal between the first and third bounding box corners.
        center = (
            ((entity.bbox_float[0] - entity.bbox_float[2]) / 2 + entity.bbox_float[2])
            .astype(int)
            .tolist()
        )
        # The image has been flipped, so the y coordinate is mirrored before
        # placing the name; the -1 / -2 terms are small fixed offsets.
        d.text(
            (center[0] - 1, base.size[1] - center[1] - 2),
            entity.name,
            fill="pink",
            direction="ttb",
        )
    return base

In [None]:
import random

import numpy as np


def generate_examples(
    entities, relations, size=None
) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str, str]]]:
    """
    Generate positive and negative examples from a list of entities.

    Every ordered pair of unequal entities is tested against every relation.
    A triple goes into the positive list when the relation holds for the pair
    and into the negative list otherwise.

    :param entities: a list of entities, each exposing a ``name`` attribute
    :param relations: a list of callables ``rel(entity1, entity2)`` returning a truth value;
        ``rel.__name__`` is used as the relation label
    :param size: the number of examples to be generated (currently ignored; all examples
        are returned)

    :returns: a pair ``(positive_examples, negative_examples)``, each a list of triples
        ``(entity1.name, relation_name, entity2.name)``
    """
    positive_examples = []
    negative_examples = []
    for entity1, entity2 in product(entities, entities):
        # Skip pairs of equal entities; this includes an entity paired with itself.
        if not entity1 != entity2:
            continue
        for rel in relations:
            holds = rel(entity1, entity2)
            example = (entity1.name, rel.__name__, entity2.name)
            (positive_examples if holds else negative_examples).append(example)
    return positive_examples, negative_examples

In [None]:
# Build the entities for the demo scene (the argument 5 is presumably the
# number of entities to generate -- confirm against generate_entities).
entities = generate_entities(5)

In [None]:
# Render the entities on a 224x224 RGBA canvas filled with (127, 127, 127, 127)
# and display the labelled image returned by show().
canvas = Image.new("RGBA", (224, 224), (127, 127, 127, 127))
display(show(entities, canvas))
# Relation predicates that are evaluated on entity pairs when generating examples.
relations = [left_of, right_of, above, below]

In [None]:
# Show the generated entities as the cell output.
entities

In [None]:
# Split all (entity1, relation, entity2) triples into those for which the
# relation holds and those for which it does not.
positive_examples, negative_examples = generate_examples(entities, relations)

## Dataset

In [None]:
import torch

from torch.utils.data import Dataset, DataLoader

class RelationLearningDataset(Dataset):
    """Placeholder dataset that holds a single hard-coded (image, question) example."""

    def __init__(self):
        super().__init__()

    def __getitem__(self, idx):
        """
        Return the example at position ``idx``.

        :param idx: index of the example; must satisfy ``-len(self) <= idx < len(self)``

        :returns: a pair ``(image, question)`` of float tensors with 9 and 3 entries
        :raises IndexError: if ``idx`` is out of range
        """
        # Without this check plain iteration over the dataset (which stops on
        # IndexError) would never terminate.
        if not -len(self) <= idx < len(self):
            raise IndexError("dataset index out of range")

        # [entity_id, bottom_left_x, bottom_left_y, top_left_x, top_left_y, top_right_x, top_right_y, bottom_right_x, bottom_right_y]
        image = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 0], dtype=torch.float)
        # Presumably encodes (entity1, relation, entity2) -- TODO confirm the id scheme.
        question = torch.tensor([0, 0, 1], dtype=torch.float)
        return image, question

    def __len__(self):
        """Return the number of examples (always 1 for this placeholder)."""
        return 1

# Smoke test: wrap the dataset in a DataLoader with batch size 1 and print every batch.
loader = DataLoader(RelationLearningDataset(), batch_size=1)

for batch in loader:
    print(batch)

# Tests

## Entities

In [None]:
from PIL import Image, ImageDraw, ImageFont, ImageOps
from qsr_learning.data.bounding_box_data import Entity
# Two entities built with the same size arguments and rotation at different positions.
# NOTE(review): theta=30 here, whereas the interactive relation test in this
# notebook sweeps theta over [0, 2*pi] -- confirm which unit Entity expects.
entity1 = Entity(32, 32, p=(32, 32), theta=30)
entity2 = Entity(32, 32, p=(64, 64), theta=30)

# Create a 224x224 RGBA canvas, draw both entities on it and show the
# vertically flipped result as the cell output.
base = Image.new("RGBA", (224, 224), (127, 127, 127, 127))
entity1.draw(base)
ImageOps.flip(entity2.draw(base))

## Relations

In [None]:
import numpy as np
from ipywidgets import interact

from qsr_learning.data.bounding_box_data import above, below, left_of, right_of


@interact(
    x1=(0, 150),
    y1=(0, 150),
    w1=(10, 150),
    h1=(10, 150),
    theta1=(0.0, 2 * np.pi, 2 * np.pi / 360),
    x2=(0, 150),
    y2=(0, 150),
    w2=(10, 150),
    h2=(10, 150),
    theta2=(0.0, 2 * np.pi, 2 * np.pi / 360),
)
def test_spatial_relations(
    x1=3, y1=30, w1=30, h1=30, theta1=0.0, x2=60, y2=60, w2=30, h2=30, theta2=0.0
):
    """Interactively inspect the four spatial relations between two entities.

    A "green" and a "red" entity are built from the slider values, drawn on a
    fresh canvas and displayed (vertically flipped).  Both entities are then
    printed, followed by the result of ``left_of``, ``right_of``, ``above`` and
    ``below`` evaluated on the pair (green, red).
    """
    scene = Image.new("RGBA", (224, 224), (127, 127, 127, 127))
    green = Entity(w1, h1, p=(x1, y1), theta=theta1, name="green", color="green")
    red = Entity(w2, h2, p=(x2, y2), theta=theta2, name="red", color="red")
    pair = (green, red)

    for entity in pair:
        entity.draw(scene)
    display(ImageOps.flip(scene))

    for entity in pair:
        print(entity)

    names = (green.name, red.name)
    print("left_of({}, {}): {}".format(*names, left_of(green, red)))
    print("right_of({}, {}): {}".format(*names, right_of(green, red)))
    print("above({}, {}): {}".format(*names, above(green, red)))
    print("below({}, {}): {}".format(*names, below(green, red)))