### Unit tests for ReaSCAN modules

In [1]:
from collections import namedtuple, OrderedDict
import itertools
import os
import numpy as np
from typing import Tuple
from typing import List
from typing import Dict
import random
from itertools import product
import copy
import re
import random

from utils import one_hot
from utils import generate_possible_object_names
from utils import numpy_array_to_image

from vocabulary import *
from object_vocabulary import *
from world import *
from grammer import *
from simulator import *

Vocabulary

In [2]:
# Word lists for the test vocabulary. Multi-word adverbs and relation clauses
# (e.g. "while zigzagging", "in the same row as") are listed as whole phrases.
intransitive_verbs = ["walk"]
transitive_verbs = ["push", "pull"]
adverbs = [
    "quickly", "slowly", "while zigzagging",
    "while spinning", "cautiously", "hesitantly",
]
nouns = ["circle", "cylinder", "square", "box"]
color_adjectives = ["red", "blue", "green", "yellow"]
size_adjectives = ["big", "small"]
relative_pronouns = ["that is"]
relation_clauses = [
    "in the same row as",
    "in the same column as",
    "in the same color as",
    "in the same shape as",
    "in the same size as",
    "inside of",
]

vocabulary_spec = dict(
    intransitive_verbs=intransitive_verbs,
    transitive_verbs=transitive_verbs,
    adverbs=adverbs,
    nouns=nouns,
    color_adjectives=color_adjectives,
    size_adjectives=size_adjectives,
    relative_pronouns=relative_pronouns,
    relation_clauses=relation_clauses,
)
vocabulary = Vocabulary.initialize(**vocabulary_spec)

In [3]:
# Color adjectives: "grey" was never configured; each configured color must be present.
assert "grey" not in vocabulary.get_color_adjectives()
for expected_color in ("red", "blue", "green", "yellow"):
    assert expected_color in vocabulary.get_color_adjectives()

# Verbs: "push" was configured as transitive, so it must not be listed as
# intransitive; "walk" was configured as intransitive and must be.
assert "push" not in vocabulary.get_intransitive_verbs()
assert "walk" in vocabulary.get_intransitive_verbs()

print("== Vocabuary Tests Passed ==")

== Vocabuary Tests Passed ==


Object Vocabulary

In [4]:
# Object sizes range over [min_object_size, max_object_size]; shapes and colors
# come from the semantic words of the vocabulary built above.
min_object_size, max_object_size = 1, 4
object_vocabulary = ObjectVocabulary(
    shapes=vocabulary.get_semantic_shapes(),
    colors=vocabulary.get_semantic_colors(),
    min_size=min_object_size,
    max_size=max_object_size,
)

In [5]:
# Collect every enumerated object. Entries look like (size, color, shape) tuples.
# NOTE(review): the field order is inferred from the literals below; confirm in ObjectVocabulary.
object_set = set(object_vocabulary.all_objects)
assert (1, 'blue', 'box') in object_set
assert (5, 'blue', 'box') not in object_set  # 5 is above the configured maximum size (4)

# Size bounds should mirror the min/max sizes passed to the constructor (1 and 4).
assert object_vocabulary.smallest_size == 1
assert object_vocabulary.largest_size == 4

assert len(object_vocabulary.object_colors) == 4
print("== Object Vocabuary Tests Passed ==")

== Object Vocabuary Tests Passed ==


Grammar

In [6]:
# Grammar helper built on the test vocabulary. The dependency-graph,
# relation-sampling and command-grounding checks below all go through it.
grammer = Grammer(vocabulary)

In [8]:
# The root object ($OBJ_0) can carry an abstract shape.
# Each case lists (pattern, [(parent, child, child_expected_in_parent)]);
# dependency[parent] is expected to contain the objects that hang off parent.
dependency_cases = [
    ('$OBJ_0 ^ $OBJ_1 & $OBJ_2', [('$OBJ_0', '$OBJ_1', True), ('$OBJ_0', '$OBJ_2', True)]),
    ('$OBJ_0 ^ $OBJ_1', [('$OBJ_0', '$OBJ_1', True), ('$OBJ_0', '$OBJ_0', False)]),
    ('$OBJ_0 ^ $OBJ_1 ^ $OBJ_2', [('$OBJ_0', '$OBJ_1', True), ('$OBJ_1', '$OBJ_2', True)]),
]
for pattern, expectations in dependency_cases:
    dependency = grammer.build_dependency_graph(pattern)
    for parent, child, should_contain in expectations:
        assert (child in dependency[parent]) == should_contain

# Harder checks on the relation grammars sampled for '$OBJ_0 ^ $OBJ_1 & $OBJ_2'.
# Each sampled relation is read as relation[0] = object-pattern map and
# relation[1] = relation map keyed by (source, target) object edges.
relations = grammer.sample_object_relation_grammer(
    '$OBJ_0',
    grammer.build_dependency_graph('$OBJ_0 ^ $OBJ_1 & $OBJ_2'))

# If an edge's source object already spells out an attribute in its pattern,
# the edge must not be the matching "same attribute" relation.
attribute_to_same_relation = {
    "$SHAPE": "$SAME_SHAPE",
    "$COLOR": "$SAME_COLOR",
    "$SIZE": "$SAME_SIZE",
}
for relation in relations:
    pattern_by_obj, relation_by_edge = relation[0], relation[1]
    for edge, relation_type in relation_by_edge.items():
        if edge[0] not in pattern_by_obj:
            continue
        source_pattern = pattern_by_obj[edge[0]]
        for attribute, same_relation in attribute_to_same_relation.items():
            if attribute in source_pattern:
                assert relation_type != same_relation

# Shared-attribute edges: "same shape" requires both endpoints to use an abstract
# shape, and "same color" forbids an explicit color on either endpoint.
for relation in relations:
    pattern_by_obj, relation_by_edge = relation[0], relation[1]
    for edge, relation_type in relation_by_edge.items():
        if relation_type == "$SAME_SHAPE":
            assert all("$ABS_SHAPE" in pattern_by_obj[edge[end]] for end in (0, 1))
        if relation_type == "$SAME_COLOR":
            assert not any("$COLOR" in pattern_by_obj[edge[end]] for end in (0, 1))
                
# Grounding: $OBJ_0 has a bare '$SHAPE' pattern and $OBJ_2 has '$COLOR $SHAPE', so
# no binding may put "red" on $OBJ_0, while at least one binding should put it on $OBJ_2.
obj_pattern_map = {'$OBJ_0': '$SHAPE', '$OBJ_1': '$SIZE $SHAPE', '$OBJ_2': '$COLOR $SHAPE'}
rel_map = {('$OBJ_0', '$OBJ_1'): '$SAME_ROW', ('$OBJ_0', '$OBJ_2'): '$SAME_COLUMN'}
grammer_pattern = '$OBJ_0 ^ $OBJ_1 & $OBJ_2'
command_dictionary = grammer.grounding_grammer_with_vocabulary(grammer_pattern, obj_pattern_map, rel_map)

found = False
for grounded in command_dictionary:
    assert "red" not in grounded["$OBJ_0"]
    found = ("red" in grounded["$OBJ_2"]) or found
assert found

# Command generation: render a fully grounded structure into a command string.
obj_pattern_map = {'$OBJ_0': '$ABS_SHAPE', '$OBJ_1': '$SHAPE', '$OBJ_2': '$SHAPE'}
rel_map = {('$OBJ_0', '$OBJ_1'): '$SAME_COLUMN', ('$OBJ_0', '$OBJ_2'): '$SAME_ROW'}
obj_map = {'$OBJ_0': 'object', '$OBJ_1': 'circle', '$OBJ_2': 'cylinder'}
grammer_pattern = '$OBJ_0 ^ $OBJ_1 & $OBJ_2'
verb = "walk"
adverb = "cautiously"
obj_determiner_map = {'$OBJ_0': 'the', '$OBJ_1': 'a', '$OBJ_2': 'a'}
command_str = grammer.repre_str_command(
    grammer_pattern,
    rel_map,
    obj_map,
    obj_determiner_map,
    verb,
    adverb,
)

# (fragment, should_appear): required fragments map to True, forbidden ones to False.
for fragment, should_appear in (("object", True), ("square", False), ("the", True), (" a ", True)):
    assert (fragment in command_str) == should_appear

# Build stress-test (ST) inputs: ground every sampled relation grammar for the
# 3-object pattern, pairing each binding with a randomly drawn adverb and verb.
grammer_pattern = '$OBJ_0 ^ $OBJ_1 & $OBJ_2'
relations = grammer.sample_object_relation_grammer(
    '$OBJ_0',
    grammer.build_dependency_graph(grammer_pattern))

candidate_adverbs = vocabulary.get_adverbs()
candidate_verbs = vocabulary.get_transitive_verbs() + vocabulary.get_intransitive_verbs()

command_structs = []
for relation in relations:
    obj_pattern_map, rel_map = relation[0], relation[1]
    grammer_bindings = grammer.grounding_grammer_with_vocabulary(grammer_pattern, obj_pattern_map, rel_map)
    for obj_map in grammer_bindings:
        # The adverb is drawn before the verb, matching the dict's key order.
        command_structs.append({
            "obj_pattern_map": obj_pattern_map,
            "rel_map": rel_map,
            "obj_map": obj_map,
            "grammer_pattern": grammer_pattern,
            "adverb": random.choice(candidate_adverbs),
            "verb": random.choice(candidate_verbs),
        })

# Row/column relations must never involve a box on either endpoint.
for command_struct in command_structs:
    grounded = command_struct["obj_map"]
    for edge, rel in command_struct["rel_map"].items():
        if rel in ("$SAME_ROW", "$SAME_COLUMN"):
            assert not any("box" in grounded[edge[end]] for end in (0, 1))

# When no object uses the abstract word "object", the root's head noun (last word)
# must differ from the head noun of each child.
for command_struct in command_structs:
    grounded = command_struct["obj_map"]
    if not any("object" in grounded[name] for name in ("$OBJ_0", "$OBJ_1", "$OBJ_2")):
        root_noun = grounded["$OBJ_0"].rsplit(" ", 1)[-1]
        for child in ("$OBJ_1", "$OBJ_2"):
            assert root_noun != grounded[child].rsplit(" ", 1)[-1]

# The plain single-word form is no longer generated: every non-root object phrase
# must contain at least one space (i.e. at least two words).
for command_struct in command_structs:
    for obj_name, obj in command_struct["obj_map"].items():
        assert obj_name == "$OBJ_0" or " " in obj

# Longer relational clauses: for the 4-object pattern, every sampled relation map
# should use exactly three distinct relation types.
grammer_pattern = '$OBJ_0 ^ $OBJ_1 & $OBJ_2 & $OBJ_3'
relations = grammer.sample_object_relation_grammer(
    '$OBJ_0',
    grammer.build_dependency_graph(grammer_pattern))

for relation in relations:
    type_set = set(relation[1].values())
    assert len(type_set) == 3

print("== Grammer Tests Passed ==")

== Grammer Tests Passed ==


Shape World

In [9]:
world = World(grid_size=6, colors=vocabulary.get_semantic_colors(),
              object_vocabulary=object_vocabulary,
              shapes=vocabulary.get_semantic_shapes(),
              save_directory="./tmp/")

# Placement scenarios on the 6x6 grid, each run on a cleared situation:
# (objects to place in order as (Object kwargs, (row, column)), agent (row, column) or None).
green_box = dict(size=2, color="green", shape="box")
red_cylinder = dict(size=3, color="red", shape="cylinder")
placement_scenarios = [
    ([(green_box, (2, 2)), (red_cylinder, (2, 2))], None),    # cylinder on the box's cell
    ([(red_cylinder, (2, 2)), (green_box, (2, 2))], None),    # box on the cylinder's cell
    ([(green_box, (2, 2)), (red_cylinder, (2, 3))], (2, 2)),
    ([(green_box, (2, 2)), (red_cylinder, (3, 2))], (2, 2)),
]
for placements, agent_at in placement_scenarios:
    world.clear_situation()
    for spec, (row, column) in placements:
        world.place_object(Object(**spec), position=Position(row=row, column=column))
    if agent_at is not None:
        world.place_agent_at(Position(row=agent_at[0], column=agent_at[1]))

# The last scenario leaves both objects on the grid.
assert len(world.get_current_situation().placed_objects) == 2

# Out-of-bounds (OOB) placement cases on a 3x3 grid.
world = World(grid_size=3, colors=vocabulary.get_semantic_colors(),
              object_vocabulary=object_vocabulary,
              shapes=vocabulary.get_semantic_shapes(),
              save_directory="./tmp/")
# A size-2 box at (1, 1) must be accepted; reaching the next line means no error was thrown.
world.place_object(Object(size=2, color="green", shape="box"), position=Position(row=1, column=1))

# A size-2 box at (2, 2) is expected to be rejected with an error.
world.clear_situation()
# Catch Exception rather than using a bare `except:`, so KeyboardInterrupt/SystemExit
# still propagate. Any ordinary error from place_object counts as the expected rejection.
try:
    world.place_object(Object(size=2, color="green", shape="box"), position=Position(row=2, column=2))
except Exception:
    is_thrown = True
else:
    is_thrown = False
assert is_thrown, "placing a size-2 box at (2, 2) on a 3x3 grid should raise"

world.clear_situation()
world = World(grid_size=3, colors=vocabulary.get_semantic_colors(),
              object_vocabulary=object_vocabulary,
              shapes=vocabulary.get_semantic_shapes(),
              save_directory="./tmp/")
world.place_object(Object(size=2, color="green", shape="box"), position=Position(row=1, column=1))

# Candidate positions on the 3x3 grid with a size-2 box at (1, 1):
# (condition, box_size, expected number of positions).
position_count_cases = [
    ("normal", 2, 9),
    ("box", 3, 1),
    ("box", 4, 0),
]
for condition, box_size, expected_count in position_count_cases:
    all_positions = world.sample_position_complex(condition=condition, box_size=box_size, sample_one=False)
    assert len(all_positions) == expected_count

# Overlapping placements at (1, 1): a size-2 box and a square, in both orders.
for first_shape, second_shape in (("box", "square"), ("square", "box")):
    world.clear_situation()
    for shape in (first_shape, second_shape):
        world.place_object(Object(size=2, color="green", shape=shape), position=Position(row=1, column=1))

# A non-box object does not block box placement: with only a square at (1, 1),
# box_size=2 still has 4 candidate positions and box_size=1 has 9.
world.clear_situation()
world.place_object(Object(size=2, color="green", shape="square"), position=Position(row=1, column=1))
for box_size, expected_count in ((2, 4), (1, 9)):
    all_positions = world.sample_position_complex(condition="box", box_size=box_size, sample_one=False)
    assert len(all_positions) == expected_count

# An existing box does reduce the candidates: with a size-2 box (plus a square)
# at (1, 1), only 3 positions remain for box_size=2.
world.clear_situation()
world.place_object(Object(size=2, color="green", shape="box"), position=Position(row=1, column=1))
world.place_object(Object(size=2, color="green", shape="square"), position=Position(row=1, column=1))
all_positions = world.sample_position_complex(condition="box", box_size=2, sample_one=False)
assert len(all_positions) == 3

world = World(grid_size=6, colors=vocabulary.get_semantic_colors(),
              object_vocabulary=object_vocabulary,
              shapes=vocabulary.get_semantic_shapes(),
              save_directory="./tmp/")

# Stress test (ST): replay one fixed scene many times. Nothing is asserted per
# iteration; the test passes if no call raises.
st_iterations = 20000  # total replays of the scene
st_log_every = 1000    # print progress every this many iterations

# Loop-invariant inputs: translate the words once instead of on every iteration.
verb = "push"
adverb = "cautiously"
action = "walk"  # navigation always uses the plain "walk" primitive
primitive_command = vocabulary.translate_word(action)
semantic_action = vocabulary.translate_word(verb)

for i in range(st_iterations):
    if i % st_log_every == 0:
        # NOTE: `i` counts iterations started so far; it is not a pass rate.
        print(f"ST Test - Passing rate = {i}/{st_iterations}")
    world.clear_situation()
    world.place_object(Object(size=2, color="green", shape="box"), position=Position(row=2, column=2))
    world.place_object(Object(size=2, color="green", shape="box"), position=Position(row=1, column=1))
    world.place_object(Object(size=3, color="red", shape="cylinder"), position=Position(row=3, column=2))
    world.place_agent_at(Position(row=5, column=2))

    # Walk to (3, 2), the red cylinder's cell, then apply the verb's object action.
    target_position = Position(row=3, column=2)
    world.go_to_position(position=target_position, manner=adverb, primitive_command=primitive_command)
    world.move_object_to_wall(action=semantic_action, manner=adverb)
    target_commands, target_demonstration = world.get_current_observations()

print("== Shape World Tests Passed ==")

ST Test - Passing rate = 0/20000
ST Test - Passing rate = 1000/20000
ST Test - Passing rate = 2000/20000
ST Test - Passing rate = 3000/20000
ST Test - Passing rate = 4000/20000
ST Test - Passing rate = 5000/20000
ST Test - Passing rate = 6000/20000
ST Test - Passing rate = 7000/20000
ST Test - Passing rate = 8000/20000
ST Test - Passing rate = 9000/20000
ST Test - Passing rate = 10000/20000
ST Test - Passing rate = 11000/20000
ST Test - Passing rate = 12000/20000
ST Test - Passing rate = 13000/20000
ST Test - Passing rate = 14000/20000
ST Test - Passing rate = 15000/20000
ST Test - Passing rate = 16000/20000
ST Test - Passing rate = 17000/20000
ST Test - Passing rate = 18000/20000
ST Test - Passing rate = 19000/20000
== Shape World Tests Passed ==


Simulator

In [10]:
# Simulator configuration: grid_size=6 and n_object_max=10
# (NOTE(review): presumably the per-world object cap; confirm in Simulator).
simulator_kwargs = {"grid_size": 6, "n_object_max": 10}
simulator = Simulator(object_vocabulary, vocabulary, **simulator_kwargs)

In [11]:
assert simulator.grid_size == 6
assert simulator.n_object_max == 10

# Smoke test: sample one world for a grounded 3-object command with every distractor
# type disabled. The pattern and all three maps are taken from the SAME command struct.
# The bare globals `grammer_pattern` / `obj_pattern_map` / `rel_map` / `obj_map` leak
# from different earlier loops. In particular, `grammer_pattern` was last rebound to the
# 4-object pattern, which does not match the 3-object maps.
assert command_structs, "expected at least one grounded command from the Grammer section"
smoke_struct = command_structs[-1]
sampled_world = simulator.sample_situations_from_grounded_grammer(
    copy.deepcopy(smoke_struct["grammer_pattern"]),
    copy.deepcopy(smoke_struct["obj_pattern_map"]),
    copy.deepcopy(smoke_struct["rel_map"]),
    copy.deepcopy(smoke_struct["obj_map"]),
    is_plot=False,
    include_relation_distractor=False,
    include_attribute_distractor=False,
    include_isomorphism_distractor=False,
    include_random_distractor=False,
    full_relation_probability=0.5,
    debug=False
)

assert sampled_world
assert "obj_map" in sampled_world.keys()
assert len(sampled_world["obj_map"]) == 3
assert len(sampled_world["obj_pattern_map"]) == 3
# NOTE(review): presumably -1 signals that random distractors were disabled; confirm in Simulator.
assert sampled_world["n_random_distractor"] == -1

# Simulator stress test (ST). Recommended to run, but it takes ~30 minutes, so it is off by default.
# For up to st_max_commands grounded commands, sample st_samples_per_command worlds each
# (all distractor types enabled). For every sample, check that the returned obj_map has
# as many entries as objects placed on the grid.
ST_enabled = False
if ST_enabled:
    st_max_commands = 10000
    st_samples_per_command = 200
    simulator = Simulator(
        object_vocabulary, vocabulary,
        grid_size=6,
        n_object_max=10,
    )
    # Draw a random subset instead of shuffling `command_structs` in place, so
    # later cells still see the list in its original order.
    st_structs = random.sample(command_structs, min(st_max_commands, len(command_structs)))
    for count, test_struct in enumerate(st_structs, start=1):
        if count % 100 == 0:
            print(f"passing rate = {count}/{len(st_structs)}")
        for _ in range(st_samples_per_command):
            sampled_world = simulator.sample_situations_from_grounded_grammer(
                copy.deepcopy(test_struct["grammer_pattern"]),
                copy.deepcopy(test_struct["obj_pattern_map"]),
                copy.deepcopy(test_struct["rel_map"]),
                copy.deepcopy(test_struct["obj_map"]),
                is_plot=False,
                include_relation_distractor=True,
                include_attribute_distractor=True,
                include_isomorphism_distractor=True,
                include_random_distractor=True,
                full_relation_probability=0.5, # 0.5 seems to work as well!
                debug=False
            )
            n_placed = len(simulator._world.get_current_situation().to_representation()["placed_objects"])
            assert len(sampled_world['obj_map']) == n_placed

print("== Simulator Tests Passed ==")

== Simulator Tests Passed ==


In [14]:
# Very long performance run (off by default): sample perf_samples_per_command worlds
# for each of the first perf_max_commands grounded commands, with isomorphism
# distractors disabled. Only absence of exceptions is checked.
performance_test_enabled = False
if performance_test_enabled:
    perf_max_commands = 10000
    perf_samples_per_command = 200
    simulator = Simulator(
        object_vocabulary, vocabulary,
        grid_size=6,
        n_object_max=10,
    )
    # Commands are taken in list order (no shuffling).
    perf_structs = command_structs[:perf_max_commands]
    for count, test_struct in enumerate(perf_structs, start=1):
        if count % 100 == 0:
            print(f"ST Test - Passing rate = {count}/{len(perf_structs)}")
        for _ in range(perf_samples_per_command):
            sampled_world = simulator.sample_situations_from_grounded_grammer(
                copy.deepcopy(test_struct["grammer_pattern"]),
                copy.deepcopy(test_struct["obj_pattern_map"]),
                copy.deepcopy(test_struct["rel_map"]),
                copy.deepcopy(test_struct["obj_map"]),
                is_plot=False,
                include_relation_distractor=True,
                include_attribute_distractor=True,
                include_isomorphism_distractor=False,
                include_random_distractor=True,
                full_relation_probability=0.5, # 0.5 seems to work as well!
                debug=False
            )
    print("== Simulator ST Tests Passed ==")
else:
    print("== Simulator ST Tests Skipped ==")

== Simulator ST Tests Skipped ==


Pick tasks at random and examine them visually.