### Unit tests for ReaSCAN modules

In [1]:
from collections import namedtuple, OrderedDict
import itertools
import os
import numpy as np
from typing import Tuple
from typing import List
from typing import Dict
import random
from itertools import product
import copy
import re
import random

from utils import one_hot
from utils import generate_possible_object_names
from utils import numpy_array_to_image

from vocabulary import *
from object_vocabulary import *
from world import *
from grammer import *
from simulator import *

Vocabulary

In [2]:
# Build the vocabulary shared by every test below: one word list per
# grammatical category, passed to ``Vocabulary.initialize`` by keyword.
intransitive_verbs = ["walk"]
transitive_verbs = ["push", "pull"]
adverbs = [
    "quickly",
    "slowly",
    "while zigzagging",
    "while spinning",
    "cautiously",
    "hesitantly",
]
nouns = ["circle", "cylinder", "square", "box"]
color_adjectives = ["red", "blue", "green", "yellow"]
size_adjectives = ["big", "small"]
relative_pronouns = ["that is"]
# Kept in the original order in case the vocabulary relies on it.
relation_clauses = [
    "in the same row as",
    "in the same column as",
    "in the same color as",
    "in the same shape as",
    "in the same size as",
    "inside of",
]
vocabulary = Vocabulary.initialize(
    intransitive_verbs=intransitive_verbs,
    transitive_verbs=transitive_verbs,
    adverbs=adverbs,
    nouns=nouns,
    color_adjectives=color_adjectives,
    size_adjectives=size_adjectives,
    relative_pronouns=relative_pronouns,
    relation_clauses=relation_clauses,
)

In [3]:
# Color adjectives: every configured color is registered, an unknown one is not.
registered_colors = vocabulary.get_color_adjectives()
assert "grey" not in registered_colors
for color in color_adjectives:
    assert color in registered_colors, f"missing color adjective: {color!r}"

# Verbs: the configured intransitive verbs are registered as intransitive,
# and no transitive verb leaks into that list.
registered_intransitive = vocabulary.get_intransitive_verbs()
for verb_word in intransitive_verbs:
    assert verb_word in registered_intransitive, (
        f"missing intransitive verb: {verb_word!r}")
for verb_word in transitive_verbs:
    assert verb_word not in registered_intransitive, (
        f"transitive verb listed as intransitive: {verb_word!r}")

print("== Vocabuary Tests Passed ==")

== Vocabuary Tests Passed ==


Object Vocabulary

In [4]:
# Object vocabulary over the vocabulary's semantic shapes and colors, with
# object sizes bounded by min_object_size / max_object_size.
min_object_size, max_object_size = 1, 4
object_vocabulary = ObjectVocabulary(
    shapes=vocabulary.get_semantic_shapes(),
    colors=vocabulary.get_semantic_colors(),
    min_size=min_object_size,
    max_size=max_object_size,
)

In [5]:
# Every (size, color, shape) triple enumerated by the object vocabulary.
object_set = set(object_vocabulary.all_objects)
assert (min_object_size, 'blue', 'box') in object_set
# One past the configured maximum size must not be generated.
assert (max_object_size + 1, 'blue', 'box') not in object_set

assert object_vocabulary.smallest_size == min_object_size
assert object_vocabulary.largest_size == max_object_size

assert len(object_vocabulary.object_colors) == len(color_adjectives)
print("== Object Vocabuary Tests Passed ==")

== Object Vocabuary Tests Passed ==


Grammar (`Grammer`)

In [6]:
# Grammar helper built on the vocabulary above; its pattern sampling,
# dependency-graph and grounding methods are exercised in the next cell.
grammer = Grammer(vocabulary)

In [7]:
# --- Object patterns -------------------------------------------------------
# Only the root object may use an abstract shape ('$ABS_SHAPE').
# NOTE(review): if `_sample_object_pattern` samples at random, the root=True
# assertion is only reliable when the root pattern always contains
# '$ABS_SHAPE'; confirm against grammer.py.
root_shape = grammer._sample_object_pattern(root=True)
assert '$ABS_SHAPE' in root_shape

other_shape = grammer._sample_object_pattern(root=False)
assert '$ABS_SHAPE' not in other_shape

# --- Dependency graphs -----------------------------------------------------
# As asserted below: '^' links the left object to the right one, and '&'
# attaches a further object to the same parent.
dependency = grammer.build_dependency_graph('$OBJ_0 ^ $OBJ_1 & $OBJ_2')
assert '$OBJ_1' in dependency['$OBJ_0']
assert '$OBJ_2' in dependency['$OBJ_0']

dependency = grammer.build_dependency_graph('$OBJ_0 ^ $OBJ_1')
assert '$OBJ_1' in dependency['$OBJ_0']
assert '$OBJ_0' not in dependency['$OBJ_0']  # no self-dependency

dependency = grammer.build_dependency_graph('$OBJ_0 ^ $OBJ_1 ^ $OBJ_2')
assert '$OBJ_1' in dependency['$OBJ_0']
assert '$OBJ_2' in dependency['$OBJ_1']

# --- Relation sampling (harder) ---------------------------------------------
# If an object's pattern already names an attribute explicitly, no sampled
# relation starting at that object may be the matching "same-attribute" one.
attribute_to_same_relation = {
    "$SHAPE": "$SAME_SHAPE",
    "$COLOR": "$SAME_COLOR",
    "$SIZE": "$SAME_SIZE",
}
relations = grammer.sample_object_relation_grammer(
    '$OBJ_0',
    grammer.build_dependency_graph('$OBJ_0 ^ $OBJ_1 & $OBJ_2'))

for relation in relations:
    # Distinct names so the module-level `obj_map` / `rel_map` (defined
    # below and reused by the Simulator cell) are never clobbered.
    sampled_obj_map, sampled_rel_map = relation[0], relation[1]
    for obj, obj_pattern in sampled_obj_map.items():
        for rel_key, rel in sampled_rel_map.items():
            if rel_key[0] != obj:
                continue
            for attribute, same_relation in attribute_to_same_relation.items():
                if attribute in obj_pattern:
                    assert rel != same_relation, (obj, obj_pattern, rel)

# --- Grounding ---------------------------------------------------------------
obj_pattern_map = {'$OBJ_0': '$SHAPE', '$OBJ_1': '$SIZE $SHAPE', '$OBJ_2': '$COLOR $SHAPE'}
rel_map = {('$OBJ_0', '$OBJ_1'): '$SAME_ROW', ('$OBJ_0', '$OBJ_2'): '$SAME_COLUMN'}
grammer_pattern = '$OBJ_0 ^ $OBJ_1 & $OBJ_2'
command_dictionary = grammer.grounding_grammer_with_vocabulary(grammer_pattern, obj_pattern_map, rel_map)

# '$OBJ_0' has no color slot, so it must never be grounded with "red";
# '$OBJ_2' has one, and at least one grounding is expected to use "red".
found = False
for command in command_dictionary:
    assert "red" not in command["$OBJ_0"]
    if "red" in command["$OBJ_2"]:
        found = True
assert found, "no grounding assigned the color 'red' to $OBJ_2"

# --- Command generation ----------------------------------------------------
obj_pattern_map = {'$OBJ_0': '$ABS_SHAPE', '$OBJ_1': '$SHAPE', '$OBJ_2': '$SHAPE'}
rel_map = {('$OBJ_0', '$OBJ_1'): '$SAME_COLUMN', ('$OBJ_0', '$OBJ_2'): '$SAME_ROW'}
obj_map = {'$OBJ_0': 'object', '$OBJ_1': 'circle', '$OBJ_2': 'cylinder'}
grammer_pattern = '$OBJ_0 ^ $OBJ_1 & $OBJ_2'
verb = "walk"
adverb = "cautiously"
obj_determiner_map = {'$OBJ_0': 'the', '$OBJ_1': 'a', '$OBJ_2': 'a'}
command_str = grammer.repre_str_command(
    grammer_pattern, rel_map, obj_map,
    obj_determiner_map,
    verb,
    adverb,
)
assert "object" in command_str
assert "square" not in command_str
assert "the" in command_str
assert " a " in command_str
print("== Grammer Tests Passed ==")

== Grammer Tests Passed ==


Shape World

In [12]:
# 6x6 world over the same shapes, colors and object vocabulary as above.
world = World(grid_size=6, colors=vocabulary.get_semantic_colors(),
              object_vocabulary=object_vocabulary,
              shapes=vocabulary.get_semantic_shapes(),
              save_directory="./tmp/")

# Stacking: place two objects on the same cell, in both orders. These two
# scenarios only check that placement does not raise.
# TODO(review): decide whether stacking should replace the first object or
# keep both, and assert on `placed_objects` accordingly.
world.clear_situation()
world.place_object(Object(size=2, color="green", shape="box"), position=Position(row=2, column=2))
world.place_object(Object(size=3, color="red", shape="cylinder"), position=Position(row=2, column=2))

world.clear_situation()
world.place_object(Object(size=3, color="red", shape="cylinder"), position=Position(row=2, column=2))
world.place_object(Object(size=2, color="green", shape="box"), position=Position(row=2, column=2))

# Adjacent cells in the same row, with the agent on the box's cell.
world.clear_situation()
world.place_object(Object(size=2, color="green", shape="box"), position=Position(row=2, column=2))
world.place_object(Object(size=3, color="red", shape="cylinder"), position=Position(row=2, column=3))
world.place_agent_at(Position(row=2, column=2))
assert len(world.get_current_situation().placed_objects) == 2

# Adjacent cells in the same column, with the agent on the box's cell.
world.clear_situation()
world.place_object(Object(size=2, color="green", shape="box"), position=Position(row=2, column=2))
world.place_object(Object(size=3, color="red", shape="cylinder"), position=Position(row=3, column=2))
world.place_agent_at(Position(row=2, column=2))
assert len(world.get_current_situation().placed_objects) == 2

print("== Shape World Tests Passed ==")

== Shape World Tests Passed ==


Simulator

In [8]:
# Simulator sharing the object/word vocabularies and the 6x6 grid used above;
# n_object_max=10 is presumably the cap on objects per situation (verify).
simulator = Simulator(object_vocabulary, vocabulary, grid_size=6, n_object_max=10)

In [18]:
# Constructor arguments are exposed as attributes.
assert simulator.grid_size == 6
assert simulator.n_object_max == 10

# Sample a world from the command built in the Grammer cell, with every
# distractor type switched off. Deep copies protect the shared maps in case
# the simulator mutates its arguments.
sampled_world = simulator.sample_situations_from_grounded_grammer(
    copy.deepcopy(grammer_pattern),
    copy.deepcopy(obj_pattern_map),
    copy.deepcopy(rel_map),
    copy.deepcopy(obj_map),
    is_plot=False,
    include_relation_distractor=False,
    include_attribute_distractor=False,
    include_isomorphism_distractor=False,
    include_random_distractor=False,
    full_relation_probability=0.5,
    debug=False
)

assert sampled_world, "simulator returned an empty result"
assert "obj_map" in sampled_world
# One entry per object in the input pattern.
assert len(sampled_world["obj_map"]) == len(obj_map)
assert len(sampled_world["obj_pattern_map"]) == len(obj_pattern_map)
# With random distractors disabled, this test expects the counter to be -1.
assert sampled_world["n_random_distractor"] == -1
print("== Simulator Tests Passed ==")

== Simulator Tests Passed ==


Relational Graph