In [None]:
# (Re)load IPython's autoreload extension and switch it to mode 2, so that edits
# to imported modules are picked up live without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
import math
import os
from copy import deepcopy

import numpy as np
import torch
from ipywidgets import interact
from munch import Munch
from PIL import Image

from qsr_learning import relation
from qsr_learning.entity import Entity

# Test the Relations

In [None]:
def rotate_bbox(entity1, entity2):
    """
    Build a re-oriented copy of `entity1` whose bounding-box bottom side points to `entity2`.

    All changes are applied to a deep copy of `entity1`; the copy is returned with
    its `frame_of_reference` set to "intrinsic".
    """
    # Offset from entity2 to entity1, each located at `p + center`.
    offset = (entity1.p + entity1.center) - (entity2.p + entity2.center)
    # Angle between that offset and the y-axis, reduced modulo one full turn.
    full_turn = 2 * math.pi
    theta = (-np.arctan2(offset[0], offset[1])) % full_turn

    rotated = deepcopy(entity1)
    # The same translate / rotate / translate sequence runs twice: first by -theta
    # in the "absolute" frame, then by +theta in the "intrinsic" frame.
    # NOTE(review): the closing translate re-reads `p` after the opening translate
    # and the rotation; whether it puts the entity back where it started depends on
    # how Entity.translate / Entity.rotate update `p` -- confirm against Entity.
    for frame, angle in (("absolute", -theta), ("intrinsic", theta)):
        rotated.frame_of_reference = frame
        rotated.translate(-rotated.p)
        rotated.rotate(angle)
        rotated.translate(rotated.p)
    return rotated

def left_of(entity1: Entity, entity2: Entity, entity3: Entity):
    """
    Evaluate `relation.left_of` for `entity1` against `entity2`, after `entity2`
    has been re-oriented with `rotate_bbox(entity2, entity3)`.
    """
    reference = rotate_bbox(entity2, entity3)
    verdict = relation.left_of(entity1, reference)
    return verdict


def right_of(entity1: Entity, entity2: Entity, entity3: Entity):
    """
    Evaluate `relation.right_of` for `entity1` against `entity2`, after `entity2`
    has been re-oriented with `rotate_bbox(entity2, entity3)`.
    """
    reference = rotate_bbox(entity2, entity3)
    verdict = relation.right_of(entity1, reference)
    return verdict


def in_front_of(entity1: Entity, entity2: Entity, entity3: Entity):
    """
    Evaluate `relation.below` for `entity1` against `entity2`, after `entity2`
    has been re-oriented with `rotate_bbox(entity2, entity3)`.

    "In front of" is expressed through `relation.below` on the re-oriented entity.
    """
    reference = rotate_bbox(entity2, entity3)
    verdict = relation.below(entity1, reference)
    return verdict


def behind(entity1: Entity, entity2: Entity, entity3: Entity):
    """
    Evaluate `relation.above` for `entity1` against `entity2`, after `entity2`
    has been re-oriented with `rotate_bbox(entity2, entity3)`.

    "Behind" is expressed through `relation.above` on the re-oriented entity.
    """
    reference = rotate_bbox(entity2, entity3)
    verdict = relation.above(entity1, reference)
    return verdict

In [None]:
@interact(
    x1=(0, 190),
    y1=(0, 190),
    theta1=(0, 360),
    x2=(0, 150),
    y2=(0, 150),
    theta2=(0, 360),
    x3=(0, 150),
    y3=(0, 150),
    theta3=(0, 360),
)
def test_spatial_relations(
    x1=32, y1=32, theta1=0, x2=64, y2=64, theta2=150, x3=128, y3=128, theta3=150
):
    """
    Interactive check of the viewer-relative relations.

    Draws three 32x32 entities (octopus, trophy, lion) on a 224x224 canvas at the
    slider-selected positions and angles (angles are given in degrees), shows the
    canvas, and prints every relation that holds for
    "<entity1> <relation> <entity2> as_seen_from <entity3>".
    """
    canvas = Image.new("RGBA", (224, 224), (127, 127, 127, 127))

    # (name, position, angle in degrees) for each entity, in drawing order.
    placements = (
        ("octopus", (x1, y1), theta1),
        ("trophy", (x2, y2), theta2),
        ("lion", (x3, y3), theta3),
    )
    entities = []
    for emoji, position, degrees in placements:
        entity = Entity(
            name=emoji,
            frame_of_reference="absolute",
            p=position,
            # Sliders are in degrees; Entity receives the angle in radians.
            theta=degrees / 360 * 2 * math.pi,
            size=(32, 32),
        )
        entity.draw(canvas, add_bbox=False, add_front=False)
        entities.append(entity)
    entity1, entity2, entity3 = entities

    display(canvas)

    # Report each relation that holds, in a fixed order.
    predicates = (
        ("left_of", left_of),
        ("right_of", right_of),
        ("in_front_of", in_front_of),
        ("behind", behind),
    )
    for label, holds in predicates:
        if holds(entity1, entity2, entity3):
            print(entity1.name, label, entity2.name, "as_seen_from", entity3.name)

# Test the Dataset (Relative Frame of Reference)

In [None]:
import math

import torch
from ipywidgets import interact
from PIL import Image
from torch.utils.data import random_split

from qsr_learning.data_relative import DRLDataset
from qsr_learning.entity import emoji_names

# Names of the entities and relations used to build the dataset and its vocabulary.
entity_names = ["octopus", "trophy", "lion", "cat face", "cigarette"]
relation_names = ["left_of", "right_of", "in_front_of", "behind"]

dataset = DRLDataset(
    # Vocabulary: all entity and relation words together, in sorted order.
    vocab=sorted([*entity_names, *relation_names]),
    entity_names=entity_names,
    excluded_entity_names=[],
    relation_names=relation_names,
    excluded_relation_names=[],
    num_entities=3,
    frame_of_reference="relative",
    # Lower and upper bounds coincide for both the width and the height range.
    w_range=(16, 16),
    h_range=(16, 16),
    # Angle range spans one full turn (math.tau == 2 * math.pi).
    theta_range=(0, math.tau),
    add_bbox=False,
    add_front=False,
    transform=None,
    canvas_size=(128, 128),
    num_samples=1000,
)

In [None]:
@interact(idx=(0, len(dataset) - 1))
def display_sample(idx=0):
    """
    Show sample `idx` of `dataset`.

    Displays the de-normalized image, then prints the question as
    "<entity1> <relation> <entity2> as_seen_from <entity3>" followed by the
    ground-truth answer as a bool.
    """
    image_t, question_t, answer_t = dataset[idx]

    # Undo the normalization (x * std + mean) and rescale to the 0-255 range.
    channel_std = dataset.std.view(-1, 1, 1)
    channel_mean = dataset.mean.view(-1, 1, 1)
    pixels = 255 * (channel_std * image_t + channel_mean)
    # Round and clamp before the uint8 cast: a bare cast truncates toward zero
    # (e.g. 127.9999 -> 127) and does not protect against out-of-range values.
    pixels = pixels.round().clamp(0, 255)
    # Move the leading (channel) axis last, as expected by Image.fromarray.
    image = Image.fromarray(pixels.permute(1, 2, 0).numpy().astype("uint8"))

    # The question unpacks into four indices into `dataset.idx2word`.
    # Named `relation_idx` so the imported `relation` module is not shadowed here.
    entity1_idx, relation_idx, entity2_idx, entity3_idx = question_t.tolist()
    question = (
        dataset.idx2word[entity1_idx],
        dataset.idx2word[relation_idx],
        dataset.idx2word[entity2_idx],
        "as_seen_from",
        dataset.idx2word[entity3_idx],
    )
    answer = bool(answer_t)
    display(image)
    print(" ".join(question))
    print("Ground truth: ", answer)

In [None]:
import qsr_learning