In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
import math
import random
from copy import deepcopy
from itertools import permutations, product
from typing import Callable, List, Optional, Sequence, Tuple

import numpy as np
from munch import Munch
from qsr_learning.entity import Entity
from qsr_learning.relation import above, below, left_of, right_of

In [None]:
from qsr_learning.entity import emoji_names

In [None]:
def inside_canvas(entity: Entity, canvas_size: Tuple[int, int]) -> bool:
    """Return True iff every bbox point of ``entity`` lies strictly inside the canvas.

    :param entity: entity whose ``bbox`` is indexed as ``bbox[:, 0]`` (x) and
        ``bbox[:, 1]`` (y) -- presumably one point per row; TODO confirm
        against ``Entity``.
    :param canvas_size: (width, height); points exactly on 0 or on the
        width/height edge count as outside.
    """
    width, height = canvas_size[0], canvas_size[1]
    x_coords = entity.bbox[:, 0]
    y_coords = entity.bbox[:, 1]
    return all((0 < x_coords) & (x_coords < width)) and all(
        (0 < y_coords) & (y_coords < height)
    )


def in_relation(
    entity1: Entity, entity2: Entity, relations: List[Callable[[Entity, Entity], bool]]
) -> bool:
    """Return True if at least one predicate in ``relations`` holds for (entity1, entity2).

    Predicates are tried in the given order and evaluation stops at the first
    truthy result; an empty ``relations`` gives False.
    """
    for predicate in relations:
        if predicate(entity1, entity2):
            return True
    return False

In [None]:
def generate_entities(
    emoji_names,
    num_entities: int = 5,
    absolute_direction=False,
    w_range: Tuple[int, int] = (10, 30),
    h_range: Tuple[int, int] = (10, 30),
    canvas_size: Tuple[int, int] = (224, 224),
    relations: Sequence[Callable[[Entity, Entity], bool]] = (
        left_of,
        right_of,
        above,
        below,
    ),
    max_attempts: Optional[int] = None,
) -> List[Entity]:
    """
    Randomly place distinct emoji entities on a canvas by rejection sampling.

    A shuffled copy of ``emoji_names`` is taken once and its first
    ``num_entities`` names are used. Each round draws an orientation, position
    and size for every entity. The round is accepted only if every entity lies
    strictly inside the canvas and every ordered pair of distinct entities
    satisfies at least one of ``relations``. This avoids ambiguous boundary
    cases.

    Randomness comes from both ``numpy.random`` (the name shuffle) and the
    ``random`` module (poses and sizes), so seed both for reproducibility.

    :param emoji_names: candidate emoji names (any sequence; not modified).
        Fewer than ``num_entities`` entities are returned if fewer names exist.
    :param num_entities: number of entities to place
    :param absolute_direction: forwarded unchanged to ``Entity``
    :param w_range: inclusive (min, max) range for each entity's width
    :param h_range: inclusive (min, max) range for each entity's height
    :param canvas_size: (width, height)
    :param relations: binary predicates; every ordered pair must satisfy one
    :param max_attempts: maximum number of sampling rounds; ``None`` (default)
        retries indefinitely
    :returns: the placed entities
    :raises RuntimeError: if ``max_attempts`` rounds all fail
    """
    # A shallow list copy is enough for a sequence of names. It also lets tuple
    # inputs be shuffled, which np.random.shuffle cannot do in place.
    candidate_names = list(emoji_names)
    np.random.shuffle(candidate_names)
    chosen_names = candidate_names[:num_entities]

    attempt = 0
    while max_attempts is None or attempt < max_attempts:
        attempt += 1
        entities = []
        for emoji_name in chosen_names:
            # Keep the draw order (theta, position, size) fixed so that a given
            # seed always reproduces the same scene.
            theta = random.uniform(0.0, 2 * math.pi)
            p = (random.uniform(0, canvas_size[0]), random.uniform(0, canvas_size[1]))
            size = (random.randint(*w_range), random.randint(*h_range))
            entities.append(
                Entity(
                    name=emoji_name,
                    absolute_direction=absolute_direction,
                    p=p,
                    theta=theta,
                    size=size,
                )
            )
        # Reject rounds with clipped entities or with pairs that satisfy none of
        # the relations (boundary cases). permutations() yields ordered pairs of
        # distinct positions without relying on Entity.__eq__.
        entities_in_canvas = all(
            inside_canvas(entity, canvas_size) for entity in entities
        )
        if entities_in_canvas and all(
            in_relation(entity1, entity2, relations)
            for entity1, entity2 in permutations(entities, 2)
        ):
            return entities
    raise RuntimeError(
        f"Could not place {len(chosen_names)} entities within {max_attempts} "
        "attempts; check w_range/h_range against canvas_size."
    )

In [None]:
from PIL import Image

In [None]:
def generate_positive_examples(
    entities, relations, size=1
) -> List[Tuple[str, str, str]]:
    """
    Sample true (head, relation, tail) triples from a list of entities.

    Every ordered pair of distinct entities is tested against every relation.
    The triples that hold are shuffled with ``numpy.random`` and the first
    ``size`` are returned.

    :param entities: entities exposing a ``name`` attribute
    :param relations: binary predicates ``rel(entity1, entity2) -> bool``; each
        predicate's ``__name__`` is used as the relation label
    :param size: maximum number of triples to return; ``None`` returns all of
        them. Fewer triples are returned if fewer hold.

    :returns: a list of triples (entity1.name, relation.__name__, entity2.name)
    """
    # No a-priori bound on ``size`` is checked. The number of true triples
    # depends on the scene geometry and on len(relations), not on len(entities)
    # alone, and the slice below already caps the result.
    all_positive_examples = [
        (entity1.name, rel.__name__, entity2.name)
        for entity1, entity2 in permutations(entities, 2)
        for rel in relations
        if rel(entity1, entity2)
    ]
    np.random.shuffle(all_positive_examples)
    return all_positive_examples[:size]


def generate_one_negative_example(
    entity_names, relation_names, positive_examples, negative_sample_type="relation"
):
    """
    Corrupt one randomly chosen positive triple into a candidate negative triple.

    :param entity_names: set of all entity names
    :param relation_names: set of all relation names
    :param positive_examples: non-empty collection of (head, relation, tail)
        triples to corrupt
    :param negative_sample_type: which slot to replace, one of "head",
        "relation" or "tail". A replacement head/tail is drawn from
        ``entity_names`` minus both the original head and tail. A replacement
        relation is drawn from ``relation_names`` minus the original relation.

    :returns: the corrupted triple (head, relation, tail). It is not checked
        against the actual relations, so it may still be a true statement.
    :raises ValueError: if ``negative_sample_type`` is not recognised
    :raises IndexError: if ``positive_examples`` is empty or no replacement
        candidate remains (e.g. "head"/"tail" with only two entities)
    """
    # Candidates are sorted before random.choice. Iterating a set of strings
    # follows hash order, which changes between interpreter runs, so a seeded
    # ``random`` would otherwise still produce irreproducible draws.
    head, relation, tail = random.choice(sorted(positive_examples))
    if negative_sample_type == "head":
        head = random.choice(sorted(entity_names - {head, tail}))
    elif negative_sample_type == "relation":
        relation = random.choice(sorted(relation_names - {relation}))
    elif negative_sample_type == "tail":
        tail = random.choice(sorted(entity_names - {head, tail}))
    else:
        raise ValueError(
            f"Unknown negative_sample_type {negative_sample_type!r}; "
            "expected 'head', 'relation' or 'tail'."
        )
    return (head, relation, tail)


def generate_negative_examples(
    entities,
    relations,
    positive_examples,
    size=None,
    mixture=None,
):
    """
    Sample a set of false (head, relation, tail) triples by corrupting positives.

    Each candidate comes from ``generate_one_negative_example``. The slot to
    corrupt is drawn according to ``mixture``. Candidates that are actually
    true are discarded, so every returned triple is a genuine negative. A
    candidate counts as true if it is listed in ``positive_examples`` or its
    relation holds on the named entities.

    :param entities: entities exposing a ``name`` attribute
    :param relations: binary predicates; ``__name__`` is the relation label
    :param positive_examples: non-empty collection of true triples to corrupt
    :param size: number of distinct negatives to return; ``None`` means
        ``len(positive_examples)``, and ``0`` returns an empty set
    :param mixture: mapping from corruption type ("head", "relation", "tail")
        to an unnormalised weight; defaults to equal weights for all three

    :returns: a set of ``size`` distinct triples

    NOTE(review): this loops until ``size`` distinct negatives are found. It
    does not terminate if fewer reachable negatives exist. With only two
    entities a "head"/"tail" draw raises IndexError.
    """
    if size is None:
        size = len(positive_examples)
    if mixture is None:
        # Built per call: a default-argument instance would be shared by all calls.
        mixture = {"head": 1, "relation": 1, "tail": 1}

    entity_by_name = {entity.name: entity for entity in entities}
    relation_by_name = {rel.__name__: rel for rel in relations}
    entity_names = set(entity_by_name)
    relation_names = set(relation_by_name)
    known_positives = set(positive_examples)

    def _holds(triple):
        """Return True if ``triple`` is a true statement, i.e. not a valid negative."""
        if triple in known_positives:
            return True
        head, relation_name, tail = triple
        rel = relation_by_name.get(relation_name)
        entity1 = entity_by_name.get(head)
        entity2 = entity_by_name.get(tail)
        if rel is None or entity1 is None or entity2 is None:
            return False
        return bool(rel(entity1, entity2))

    negative_sample_types = list(mixture.keys())
    p = np.array(list(mixture.values()), dtype=float)
    p = p / p.sum()
    negative_examples = set()
    while len(negative_examples) < size:
        negative_sample_type = str(np.random.choice(negative_sample_types, p=p))
        negative_example = generate_one_negative_example(
            entity_names, relation_names, positive_examples, negative_sample_type
        )
        # Corrupting one slot can still produce a true triple, e.g. swapping
        # left_of for above on a diagonal pair.
        if not _holds(negative_example):
            negative_examples.add(negative_example)
    return negative_examples

In [None]:
def generate_examples(
    entities, relations, size=None
) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str, str]]]:
    """
    Label every (head, relation, tail) triple over the entities as true or false.

    Each ordered pair of distinct entities is tested against every relation,
    iterating entities in list order and then relations in list order. No
    shuffling or sampling is done.

    :param entities: entities exposing a ``name`` attribute
    :param relations: binary predicates ``rel(entity1, entity2) -> bool``; each
        predicate's ``__name__`` is used as the relation label
    :param size: currently ignored; kept so existing callers continue to work

    :returns: ``(positive_examples, negative_examples)``, two lists of triples
        ``(entity1.name, relation.__name__, entity2.name)`` for which the
        relation does / does not hold
    """
    positive_examples = []
    negative_examples = []
    for entity1, entity2 in permutations(entities, 2):
        for rel in relations:
            triple = (entity1.name, rel.__name__, entity2.name)
            if rel(entity1, entity2):
                positive_examples.append(triple)
            else:
                negative_examples.append(triple)
    return positive_examples, negative_examples

In [None]:
from PIL import ImageDraw

In [None]:
def draw_entities(
    entities, canvas_size=(224, 224), show_bbox=True, orientation_marker=False
):
    """
    Render ``entities`` onto a new opaque black RGBA image with a white border.

    :param entities: objects providing ``draw(canvas, show_bbox=..., orientation_marker=...)``
    :param canvas_size: (width, height) of the image
    :param show_bbox: forwarded to every ``entity.draw`` call
    :param orientation_marker: forwarded to every ``entity.draw`` call

    :returns: the newly created image
    """
    canvas = Image.new("RGBA", canvas_size, (0, 0, 0, 255))
    right_edge = canvas_size[0] - 1
    bottom_edge = canvas_size[1] - 1
    border = [
        (0, 0),
        (0, bottom_edge),
        (right_edge, bottom_edge),
        (right_edge, 0),
    ]
    painter = ImageDraw.Draw(canvas)
    painter.polygon(border, fill=None, outline="white")
    for item in entities:
        item.draw(canvas, show_bbox=show_bbox, orientation_marker=orientation_marker)
    return canvas

In [None]:
# Emoji names ending in "face", kept in the order they appear in emoji_names.
face_emojis = list(filter(lambda name: name.endswith("face"), emoji_names))

In [None]:
# Seed both RNGs used by generate_entities (numpy for the emoji shuffle,
# ``random`` for poses and sizes) so Restart & Run All reproduces this scene.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

entities = generate_entities(
    emoji_names=emoji_names,  # a fixed list such as ["octopus", "trophy"] also works
    num_entities=2,
    absolute_direction=True,
    w_range=(16, 64),
    h_range=(16, 64),
)
relations = [right_of, left_of, above, below]
image = draw_entities(entities, show_bbox=True, orientation_marker=False)
display(image)
image.save("tmp.png")  # writes to the current working directory
positive_examples, negative_examples = generate_examples(entities, relations)
display(positive_examples)
display(negative_examples)