In [None]:
# Re-import edited modules automatically before each cell runs, so changes to
# helper modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

# Scratch

## Entity

## Relations

In [None]:
def _signed_side(point, dline):
    """Return the 2-D cross product ``dline.vector x (point - dline.point)``.

    Positive: ``point`` is on the counter-clockwise (left) side of the directed
    line in a y-up frame; negative: on the right side; zero: on the line.
    Inputs are converted with ``np.asarray`` so tuples/lists work as well as arrays.
    """
    p = np.asarray(point)
    q = np.asarray(dline.point)
    d = np.asarray(dline.vector)
    d_rot = np.array([-d[1], d[0]])  # d rotated by 90 degrees counter-clockwise
    return (p - q) @ d_rot


def point_left_of_directed_line(point, dline):
    """Return True if ``point`` lies strictly left of the directed line ``dline``.

    ``dline`` must provide ``point`` (any point on the line) and ``vector``
    (the line's direction). Points exactly on the line are neither left nor right.
    """
    return _signed_side(point, dline) > 0


def point_right_of_directed_line(point, dline):
    """Return True if ``point`` lies strictly right of the directed line ``dline``.

    ``dline`` must provide ``point`` (any point on the line) and ``vector``
    (the line's direction). Points exactly on the line are neither left nor right.
    """
    return _signed_side(point, dline) < 0


def _corners(entity):
    """Return the corners of ``entity`` as (top_left, top_right, bottom_left, bottom_right)."""
    return (
        entity.top_left,
        entity.top_right,
        entity.bottom_left,
        entity.bottom_right,
    )


def _all_corners_satisfy(entity, dline, side_test):
    """Return True iff ``side_test(corner, dline)`` holds for every corner of ``entity``."""
    return all(side_test(corner, dline) for corner in _corners(entity))


def left_of(entity1, entity2):
    """Return True if every corner of ``entity1`` is strictly left of ``entity2``'s left edge.

    The left edge is directed from ``entity2.bottom_left`` to ``entity2.top_left``,
    so "left" is measured in ``entity2``'s own (possibly rotated) frame.
    """
    dline = Munch(
        point=entity2.bottom_left, vector=entity2.top_left - entity2.bottom_left
    )
    return _all_corners_satisfy(entity1, dline, point_left_of_directed_line)


def right_of(entity1, entity2):
    """Return True if every corner of ``entity1`` is strictly right of ``entity2``'s right edge.

    The right edge is directed from ``entity2.bottom_right`` to ``entity2.top_right``,
    so "right" is measured in ``entity2``'s own (possibly rotated) frame.
    """
    dline = Munch(
        point=entity2.bottom_right, vector=entity2.top_right - entity2.bottom_right
    )
    return _all_corners_satisfy(entity1, dline, point_right_of_directed_line)


def above(entity1, entity2):
    """Return True if every corner of ``entity1`` is strictly above ``entity2``'s top edge.

    The top edge is directed from ``entity2.top_left`` to ``entity2.top_right``;
    the left side of that direction is "above" in ``entity2``'s own frame.
    """
    dline = Munch(point=entity2.top_left, vector=entity2.top_right - entity2.top_left)
    return _all_corners_satisfy(entity1, dline, point_left_of_directed_line)


def below(entity1, entity2):
    """Return True if every corner of ``entity1`` is strictly below ``entity2``'s bottom edge.

    The bottom edge is directed from ``entity2.bottom_left`` to ``entity2.bottom_right``;
    the right side of that direction is "below" in ``entity2``'s own frame.
    """
    dline = Munch(
        point=entity2.bottom_left, vector=entity2.bottom_right - entity2.bottom_left
    )
    return _all_corners_satisfy(entity1, dline, point_right_of_directed_line)


# Tests
def test_left_of():
    """entity2 is entity1 shifted by (3, 3): entity1 is left of it and not right of it."""
    entity1 = Entity(2, 3)
    entity2 = Entity(2, 3, p=(3, 3))
    assert left_of(entity1, entity2)
    assert not right_of(entity1, entity2)


def test_right_of():
    """Mirror of test_left_of: entity2 is right of entity1 and not left of it."""
    entity1 = Entity(2, 3)
    entity2 = Entity(2, 3, p=(3, 3))
    assert right_of(entity2, entity1)
    assert not left_of(entity2, entity1)


# Run the checks when the cell executes so Restart & Run All actually exercises them.
test_left_of()
test_right_of()

## Dataset Generator

In [None]:
import math
import random
from copy import deepcopy
from itertools import permutations, product
from typing import Callable, Iterable

from PIL import Image, ImageOps

# Seed the stdlib RNG used by the scene sampler so Restart & Run All reproduces
# the same scene; change SEED to explore other scenes.
SEED = 0
random.seed(SEED)

canvas_size = Munch(w=224, h=224)
num_entities = 5
# Sampling ranges: entity width/height, translation (x, y) kept within the canvas,
# and rotation angle theta in radians.
ranges = Munch(
    w=(10, 30),
    h=(10, 30),
    x=(0, canvas_size.w),
    y=(0, canvas_size.h),
    theta=(0, math.pi),
)
# Every ordered pair of entities must satisfy at least one of these relations.
relations = (left_of, right_of, above, below)


def inside_canvas(entity, canvas_size):
    """Return True if every bounding-box point of ``entity`` lies strictly inside the canvas.

    ``entity.bbox`` is read as a 2-D array whose column 0 holds x and column 1 holds y.
    A coordinate equal to 0, ``canvas_size.w`` or ``canvas_size.h`` counts as outside.
    """
    corner_xs = entity.bbox[:, 0]
    corner_ys = entity.bbox[:, 1]
    if not all((0 < corner_xs) & (corner_xs < canvas_size.w)):
        return False
    return all((0 < corner_ys) & (corner_ys < canvas_size.h))


def in_relation(
    entity1: Entity,
    entity2: Entity,
    relations: Iterable[Callable[[Entity, Entity], bool]],
) -> bool:
    """Check whether entity1 and entity2 satisfy at least one of the given relations.

    Args:
        entity1: First argument passed to every relation.
        entity2: Second argument passed to every relation.
        relations: Binary predicates such as ``left_of``. They are evaluated in
            order, and evaluation stops at the first one that returns True.

    Returns:
        True if any ``relation(entity1, entity2)`` is truthy; False when
        ``relations`` is empty.
    """
    return any(relation(entity1, entity2) for relation in relations)


# Template entities: only width, height and name are sampled here. The placement
# loop below deep-copies these templates and rotates/translates the copies.
entities_origin = [
    Entity(random.uniform(*ranges.w), random.uniform(*ranges.h), name=str(index))
    for index in range(num_entities)
]

In [None]:
# Rejection sampling: rotate and translate fresh copies of the template entities
# until every entity lies inside the canvas and every ordered pair of distinct
# entities satisfies at least one relation in `relations`.
# NOTE(review): there is no attempt limit; if the constraints become unsatisfiable
# (e.g. larger num_entities or entity sizes) this loop never terminates.
while True:
    entities = deepcopy(entities_origin)
    for entity in entities:
        entity.rotate(random.uniform(*ranges.theta))
    for entity in entities:
        entity.translate((random.uniform(*ranges.x), random.uniform(*ranges.y)))

    # Cheap per-entity check first; the O(n^2) pairwise check only runs if it passes.
    if not all(inside_canvas(entity, canvas_size) for entity in entities):
        continue
    # permutations() yields ordered pairs at distinct positions, so this does not
    # depend on how (or whether) Entity implements __eq__.
    if all(
        in_relation(entity1, entity2, relations)
        for entity1, entity2 in permutations(entities, 2)
    ):
        break

In [None]:
# Render the sampled scene. PIL places the origin at the top-left with y pointing
# down, so the canvas is flipped vertically before display.
# Note: PIL's Image.new expects the size as (width, height).
canvas = Image.new("RGBA", (canvas_size.w, canvas_size.h), (255, 255, 255, 20))
for entity in entities:
    entity.draw(canvas)
display(ImageOps.flip(canvas))

# Start with what you want to have and write tests.

The model receives as its input

- An image, represented as a list of entities, where each entity `e` is a 9-tuple consisting of its name `e.name` followed by the coordinates of its four bounding-box corners `e.bbox = (bottom_left, top_left, top_right, bottom_right)`, which make up the remaining $8 = 4 \times 2$ entries of the 9-tuple.
    - The orientation of each entity is determined by its second and third corners (`top_left` and `top_right`), i.e., by the direction of its top edge.
    - To process the input, one could use a Transformer or a graph neural network; in the latter case the graph is complete (i.e., every pair of entities is joined by an edge).
- A question, represented either as a triple `(entity1, relation, entity2)` or as a quadruple `(entity1, relation, entity2, entity3)` (for ternary relations).

There may already be an existing model similar to this one. **TODO:** check the literature.

## Dataset

In [None]:
import torch

from torch.utils.data import Dataset, DataLoader

class RelationLearningDataset(Dataset):
    """Placeholder map-style dataset yielding a single hard-coded ``(image, question)`` pair.

    It does not yet draw scenes from the generator; it only fixes the tensor layout
    documented in ``__getitem__``.
    """

    def __init__(self):
        super().__init__()

    def __getitem__(self, idx):
        """Return the ``(image, question)`` float tensors for sample ``idx``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``idx`` is out of range. ``torch.utils.data.Dataset``
                defines no ``__iter__``, so ``iter(dataset)`` falls back to calling
                ``__getitem__(0), __getitem__(1), ...`` until IndexError is raised;
                without this check such iteration would never stop.
        """
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError(
                f"index {idx} out of range for dataset of length {len(self)}"
            )

        # [entity_id, bottom_left_x, bottom_left_y, top_left_x, top_left_y, top_right_x, top_right_y, bottom_right_x, bottom_right_y]
        # Here: entity 0 is the unit square with corners (0,0), (0,1), (1,1), (1,0).
        image = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 0], dtype=torch.float)
        # NOTE(review): presumably (entity1, relation, entity2) ids — confirm once
        # the relation vocabulary is fixed.
        question = torch.tensor([0, 0, 1], dtype=torch.float)
        return image, question

    def __len__(self):
        return 1

# Smoke test: push the placeholder dataset through a DataLoader and show each batch.
loader = DataLoader(RelationLearningDataset(), batch_size=1)
for sample_batch in loader:
    print(sample_batch)

# Tests

## Entities

In [None]:
from PIL import Image, ImageDraw, ImageFont, ImageOps

# Two equally sized, equally rotated entities on a shared canvas.
# Angles are radians elsewhere in this notebook (`ranges.theta`, the theta sliders),
# so the intended 30 degrees is converted explicitly.
entity1 = Entity(32, 32, p=(32, 32), theta=math.radians(30))
entity2 = Entity(32, 32, p=(64, 64), theta=math.radians(30))

# Semi-transparent white background (alpha = 50).
base = Image.new("RGBA", (224, 224), (255, 255, 255, 50))
entity1.draw(base)
entity2.draw(base)
# Flip vertically for display, as in the other rendering cells; flipping `base`
# itself does not depend on what Entity.draw returns.
ImageOps.flip(base)

## Relations

In [None]:
from ipywidgets import interact

@interact(
    x1=(0, 150),
    y1=(0, 150),
    w1=(10, 150),
    h1=(10, 150),
    theta1=(0.0, 2 * np.pi, 2 * np.pi / 360),  # radians, 1-degree steps
    x2=(0, 150),
    y2=(0, 150),
    w2=(10, 150),
    h2=(10, 150),
    theta2=(0.0, 2 * np.pi, 2 * np.pi / 360),  # radians, 1-degree steps
)
def test_spatial_relations(
    x1=3, y1=30, w1=30, h1=30, theta1=0.0, x2=60, y2=60, w2=30, h2=30, theta2=0.0
):
    """Interactively draw two entities and print which spatial relations hold.

    The green entity (``entity1``) and red entity (``entity2``) are built from the
    slider values. Each relation is evaluated as ``relation(entity1, entity2)``
    and printed as ``relation_name(green, red): <bool>``.
    """
    base = Image.new("RGBA", (224, 224), (255, 255, 255, 20))
    entity1 = Entity(w1, h1, p=(x1, y1), theta=theta1, name="green", color="green")
    entity2 = Entity(w2, h2, p=(x2, y2), theta=theta2, name="red", color="red")
    entity1.draw(base)
    entity2.draw(base)
    display(ImageOps.flip(base))
    print(entity1)
    print(entity2)
    for relation in (left_of, right_of, above, below):
        print(
            "{}({}, {}): {}".format(
                relation.__name__,
                entity1.name,
                entity2.name,
                relation(entity1, entity2),
            )
        )