From a27254140cf8b756543487cc415e3312c0c89a57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 28 Aug 2018 22:30:53 -0400 Subject: [PATCH 1/9] Adding the quest_breadth option when making a game --- scripts/tw-make | 6 ++++-- tests/test_make_game.py | 8 ++++---- tests/test_play_generated_games.py | 7 ++++--- textworld/envs/wrappers/tests/test_viewer.py | 2 +- textworld/generator/tests/test_logger.py | 2 +- textworld/helpers.py | 5 +++-- 6 files changed, 17 insertions(+), 13 deletions(-) diff --git a/scripts/tw-make b/scripts/tw-make index 938886d4..736751b2 100755 --- a/scripts/tw-make +++ b/scripts/tw-make @@ -46,6 +46,8 @@ def parse_args(): help="Nb. of objects in the world.") custom_parser.add_argument("--quest-length", type=int, default=5, metavar="LENGTH", help="Minimum nb. of actions the quest requires to be completed.") + custom_parser.add_argument("--quest-breadth", type=int, default=3, metavar="BREADTH", + help="Control how non-linear a quest can be.") challenge_parser = subparsers.add_parser("challenge", parents=[general_parser], help='Generate a game for one of the challenges.') @@ -72,7 +74,7 @@ if __name__ == "__main__": } if args.subcommand == "custom": - game_file, game = textworld.make(args.world_size, args.nb_objects, args.quest_length, grammar_flags, + game_file, game = textworld.make(args.world_size, args.nb_objects, args.quest_length, args.quest_breadth, grammar_flags, seed=args.seed, games_dir=args.output) elif args.subcommand == "challenge": @@ -87,7 +89,7 @@ if __name__ == "__main__": print("Game generated: {}".format(game_file)) if args.verbose: - print(game.quests[0].desc) + print(game.objective) if args.view: textworld.render.visualize(game, interactive=True) diff --git a/tests/test_make_game.py b/tests/test_make_game.py index 6a6c6263..b8b9be69 100644 --- a/tests/test_make_game.py +++ b/tests/test_make_game.py @@ -11,11 +11,11 @@ def test_making_game_with_names_to_exclude(): g_rng.set_seed(42) with 
make_temp_directory(prefix="test_render_wrapper") as tmpdir: - game_file1, game1 = textworld.make(2, 20, 3, {"names_to_exclude": []}, + game_file1, game1 = textworld.make(2, 20, 3, 3, {"names_to_exclude": []}, seed=123, games_dir=tmpdir) game1_objects_names = [info.name for info in game1.infos.values() if info.name is not None] - game_file2, game2 = textworld.make(2, 20, 3, {"names_to_exclude": game1_objects_names}, + game_file2, game2 = textworld.make(2, 20, 3, 3, {"names_to_exclude": game1_objects_names}, seed=123, games_dir=tmpdir) game2_objects_names = [info.name for info in game2.infos.values() if info.name is not None] assert len(set(game1_objects_names) & set(game2_objects_names)) == 0 @@ -24,8 +24,8 @@ def test_making_game_with_names_to_exclude(): def test_making_game_is_reproducible_with_seed(): grammar_flags = {} with make_temp_directory(prefix="test_render_wrapper") as tmpdir: - game_file1, game1 = textworld.make(2, 20, 3, grammar_flags, seed=123, games_dir=tmpdir) - game_file2, game2 = textworld.make(2, 20, 3, grammar_flags, seed=123, games_dir=tmpdir) + game_file1, game1 = textworld.make(2, 20, 3, 3, grammar_flags, seed=123, games_dir=tmpdir) + game_file2, game2 = textworld.make(2, 20, 3, 3, grammar_flags, seed=123, games_dir=tmpdir) assert game_file1 == game_file2 assert game1 == game2 # Make sure they are not the same Python objects. diff --git a/tests/test_play_generated_games.py b/tests/test_play_generated_games.py index 96374eec..562892f1 100644 --- a/tests/test_play_generated_games.py +++ b/tests/test_play_generated_games.py @@ -16,12 +16,13 @@ def test_play_generated_games(): # Sample game specs. world_size = rng.randint(1, 10) nb_objects = rng.randint(0, 20) - quest_length = rng.randint(1, 10) + quest_length = rng.randint(1, 5) + quest_breadth = rng.randint(3, 7) game_seed = rng.randint(0, 65365) grammar_flags = {} # Default grammar. 
with make_temp_directory(prefix="test_play_generated_games") as tmpdir: - game_file, game = textworld.make(world_size, nb_objects, quest_length, grammar_flags, + game_file, game = textworld.make(world_size, nb_objects, quest_length, quest_breadth, grammar_flags, seed=game_seed, games_dir=tmpdir) # Solve the game using WalkthroughAgent. @@ -47,7 +48,7 @@ def test_play_generated_games(): game_state, reward, done = env.step(command) if done: - msg = "Finished before playing `max_steps` steps." + msg = "Finished before playing `max_steps` steps because of command '{}'.".format(command) if game_state.has_won: msg += " (winning)" assert game_state._game_progression.winning_policy == [] diff --git a/textworld/envs/wrappers/tests/test_viewer.py b/textworld/envs/wrappers/tests/test_viewer.py index ac40fbb0..32a4c4fe 100644 --- a/textworld/envs/wrappers/tests/test_viewer.py +++ b/textworld/envs/wrappers/tests/test_viewer.py @@ -17,7 +17,7 @@ def test_html_viewer(): num_items = 10 g_rng.set_seed(1234) grammar_flags = {"theme": "house", "include_adj": True} - game = textworld.generator.make_game(world_size=num_nodes, nb_objects=num_items, quest_length=3, grammar_flags=grammar_flags) + game = textworld.generator.make_game(world_size=num_nodes, nb_objects=num_items, quest_length=3, quest_breadth=1, grammar_flags=grammar_flags) game_name = "test_html_viewer_wrapper" with make_temp_directory(prefix=game_name) as tmpdir: diff --git a/textworld/generator/tests/test_logger.py b/textworld/generator/tests/test_logger.py index 2a5d1ae0..9d5eed4f 100644 --- a/textworld/generator/tests/test_logger.py +++ b/textworld/generator/tests/test_logger.py @@ -18,7 +18,7 @@ def test_logger(): for _ in range(10): seed = rng.randint(65635) g_rng.set_seed(seed) - game = textworld.generator.make_game(world_size=5, nb_objects=10, quest_length=3) + game = textworld.generator.make_game(world_size=5, nb_objects=10, quest_length=3, quest_breadth=3) game_logger.collect(game) with 
make_temp_directory(prefix="textworld_tests") as tests_folder: diff --git a/textworld/helpers.py b/textworld/helpers.py index 6b5b3620..4c4ee7a7 100644 --- a/textworld/helpers.py +++ b/textworld/helpers.py @@ -119,7 +119,7 @@ def play(game_file: str, agent: Optional[Agent] = None, max_nb_steps: int = 1000 print(msg) -def make(world_size: int = 1, nb_objects: int = 5, quest_length: int = 2, +def make(world_size: int = 1, nb_objects: int = 5, quest_length: int = 2, quest_breadth: int = 1, grammar_flags: Mapping = {}, seed: int = None, games_dir: str = "./gen_games/") -> Tuple[str, Game]: """ Makes a text-based game. @@ -128,6 +128,7 @@ def make(world_size: int = 1, nb_objects: int = 5, quest_length: int = 2, world_size: Number of rooms in the world. nb_objects: Number of objects in the world. quest_length: Minimum number of actions the quest requires to be completed. + quest_breadth: Control how nonlinear a quest can be (1: linear). grammar_flags: Grammar options. seed: Random seed for the game generation process. games_dir: Path to the directory where the game will be saved. @@ -137,6 +138,6 @@ def make(world_size: int = 1, nb_objects: int = 5, quest_length: int = 2, """ g_rng.set_seed(seed) game_name = "game_{}".format(seed) - game = make_game(world_size, nb_objects, quest_length, grammar_flags) + game = make_game(world_size, nb_objects, quest_length, quest_breadth, grammar_flags) game_file = compile_game(game, game_name, games_folder=games_dir, force_recompile=True) return game_file, game From eaa504d00e32e28149466d258947e50dc45dac67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 28 Aug 2018 22:32:49 -0400 Subject: [PATCH 2/9] Typo in text grammar. 
--- textworld/generator/data/text_grammars/house_instruction.twg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/textworld/generator/data/text_grammars/house_instruction.twg b/textworld/generator/data/text_grammars/house_instruction.twg index e88605e8..7888718d 100644 --- a/textworld/generator/data/text_grammars/house_instruction.twg +++ b/textworld/generator/data/text_grammars/house_instruction.twg @@ -112,7 +112,7 @@ action_seperator_go/north: #afterhave# gone north, ;#emptyinstruction1#;#emptyin action_seperator_go/east: #afterhave# gone east, ;#emptyinstruction1#;#emptyinstruction2#;#emptyinstruction3#;#emptyinstruction4#;#emptyinstruction5#;#emptyinstruction6#;#emptyinstruction7#;#emptyinstruction8#;#emptyinstruction9#;#emptyinstruction10# action_seperator_go/west: #afterhave# gone west, ;#emptyinstruction1#;#emptyinstruction2#;#emptyinstruction3#;#emptyinstruction4#;#emptyinstruction5#;#emptyinstruction6#;#emptyinstruction7#;#emptyinstruction8#;#emptyinstruction9#;#emptyinstruction10# action_separator_close: #afterhave# #closed# the #close_open_types#, ; #after# #closing# the #close_open_types#, ;#emptyinstruction1#;#emptyinstruction2#;#emptyinstruction3#;#emptyinstruction4#;#emptyinstruction5#;#emptyinstruction6#;#emptyinstruction7#;#emptyinstruction8#;#emptyinstruction9#;#emptyinstruction10# -action_separator_drop: #afterhave# #dropped# #obj_types#, ; #after# #dropping# #obj_types#, ;#emptyinstruction1#;#emptyinstruction2#;#emptyinstruction3#;#emptyinstruction4#;#emptyinstruction5#;#emptyinstruction6#;#emptyinstruction7#;#emptyinstruction8#;#emptyinstruction9#;#emptyinstruction10# +action_separator_drop: #afterhave# #dropped# the #obj_types#, ; #after# #dropping# the #obj_types#, ;#emptyinstruction1#;#emptyinstruction2#;#emptyinstruction3#;#emptyinstruction4#;#emptyinstruction5#;#emptyinstruction6#;#emptyinstruction7#;#emptyinstruction8#;#emptyinstruction9#;#emptyinstruction10# #Separator Symbols afterhave:After you have;Having;Once you 
have;If you have havetaken:taken;got;picked up From c89140028386f8f9cacf3c2066bcb668a970f307 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 28 Aug 2018 22:34:19 -0400 Subject: [PATCH 3/9] Add contraints to prevent adding a door when there is no free exit --- textworld/generator/data/logic/door.twl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/textworld/generator/data/logic/door.twl b/textworld/generator/data/logic/door.twl index c8a56384..54613ab5 100644 --- a/textworld/generator/data/logic/door.twl +++ b/textworld/generator/data/logic/door.twl @@ -34,7 +34,13 @@ type d : t { link3 :: link(r, d, r') & link(r, d', r') -> fail(); # There cannot be more than four doors in a room. - dr2 :: link(r, d1: d, r1: r) & link(r, d2: d, r2: r) & link(r, d3: d, r3: r) & link(r, d4: d, r4: r) & link(r, d5: d, r5: r) -> fail(); + too_many_doors :: link(r, d1: d, r1: r) & link(r, d2: d, r2: r) & link(r, d3: d, r3: r) & link(r, d4: d, r4: r) & link(r, d5: d, r5: r) -> fail(); + + # There cannot be more than four doors in a room. + dr1 :: free(r, r1: r) & link(r, d2: d, r2: r) & link(r, d3: d, r3: r) & link(r, d4: d, r4: r) & link(r, d5: d, r5: r) -> fail(); + dr2 :: free(r, r1: r) & free(r, r2: r) & link(r, d3: d, r3: r) & link(r, d4: d, r4: r) & link(r, d5: d, r5: r) -> fail(); + dr3 :: free(r, r1: r) & free(r, r2: r) & free(r, r3: r) & link(r, d4: d, r4: r) & link(r, d5: d, r5: r) -> fail(); + dr4 :: free(r, r1: r) & free(r, r2: r) & free(r, r3: r) & free(r, r4: r) & link(r, d5: d, r5: r) -> fail(); free1 :: link(r, d, r') & free(r, r') & closed(d) -> fail(); free2 :: link(r, d, r') & free(r, r') & locked(d) -> fail(); From 5af220a5bc4397c44e4532c336aaf141c0ce724a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 28 Aug 2018 22:38:20 -0400 Subject: [PATCH 4/9] Objective is now part of the Game object. 
--- scripts/tw-stats | 2 +- textworld/envs/glulx/git_glulx_ml.py | 7 +++---- textworld/generator/text_generation.py | 6 ++++++ textworld/render/render.py | 5 +---- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/scripts/tw-stats b/scripts/tw-stats index 21145ec9..dbef4559 100755 --- a/scripts/tw-stats +++ b/scripts/tw-stats @@ -38,7 +38,7 @@ if __name__ == "__main__": continue if len(game.quests) > 0: - objectives[game_filename] = game.quests[0].desc + objectives[game_filename] = game.objective names |= set(info.name for info in game.infos.values() if info.name is not None) game_logger.collect(game) diff --git a/textworld/envs/glulx/git_glulx_ml.py b/textworld/envs/glulx/git_glulx_ml.py index 5fab75c6..9bd3a4b4 100644 --- a/textworld/envs/glulx/git_glulx_ml.py +++ b/textworld/envs/glulx/git_glulx_ml.py @@ -149,10 +149,7 @@ def init(self, output: str, game=None, self._game_progression = GameProgression(game, track_quest=compute_intermediate_reward) self._state_tracking = state_tracking self._compute_intermediate_reward = compute_intermediate_reward and len(game.quests) > 0 - - self._objective = "" - if len(game.quests) > 0: - self._objective = game.quests[0].desc + self._objective = game.objective def view(self) -> "GlulxGameState": """ @@ -317,6 +314,7 @@ def intermediate_reward(self): @property def score(self): + # XXX: Should the score reflect the sum of all subquests' reward? if self.has_won: return 1 elif self.has_lost: @@ -326,6 +324,7 @@ def score(self): @property def max_score(self): + # XXX: Should the score reflect the sum of all subquests' reward? 
return 1 @property diff --git a/textworld/generator/text_generation.py b/textworld/generator/text_generation.py index 89b63c16..6d594435 100644 --- a/textworld/generator/text_generation.py +++ b/textworld/generator/text_generation.py @@ -6,6 +6,7 @@ from collections import OrderedDict from textworld.generator import data +from textworld.generator.game import Quest from textworld.generator.text_grammar import Grammar from textworld.generator.text_grammar import fix_determinant @@ -155,6 +156,11 @@ def generate_text_from_grammar(game, grammar: Grammar): for quest in game.quests: assign_description_to_quest(quest, game, grammar, counts, only_last_action, blend_instructions, ambiguous_instructions) + if only_last_action and len(game.quests) > 1: + main_quest = Quest(actions=[quest.actions[-1] for quest in game.quests]) + assign_description_to_quest(main_quest, game, grammar, counts, False, blend_instructions, ambiguous_instructions) + game.objective = main_quest.desc + return game diff --git a/textworld/render/render.py b/textworld/render/render.py index 50da8ee0..58e63ab9 100644 --- a/textworld/render/render.py +++ b/textworld/render/render.py @@ -222,7 +222,6 @@ def used_pos(): if game_infos is None: new_game = Game(world, []) game_infos = new_game.infos - game_infos["objective"] = new_game.quests[0].desc for k, v in game_infos.items(): if v.name is None: v.name = k @@ -354,9 +353,7 @@ def visualize(world: Union[Game, State, GlulxGameState, World], if isinstance(world, Game): game = world state = load_state(game.world, game.infos) - state["objective"] = "" - if len(game.quests) > 0: - state["objective"] = game.quests[0].desc + state["objective"] = game.objective elif isinstance(world, GlulxGameState): state = load_state_from_game_state(game_state=world) elif isinstance(world, World): From 7b987254940883603493c7b4b18cbc95c1f2aac8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 28 Aug 2018 22:41:00 -0400 Subject: [PATCH 5/9] 
Support game.quests with more than one quest --- tests/test_textworld.py | 2 +- textworld/agents/walkthrough.py | 2 +- textworld/generator/__init__.py | 29 +++++++- textworld/generator/inform7/world2inform7.py | 70 +++++++++++++++----- 4 files changed, 83 insertions(+), 20 deletions(-) diff --git a/tests/test_textworld.py b/tests/test_textworld.py index a0165e8a..b4fa918f 100644 --- a/tests/test_textworld.py +++ b/tests/test_textworld.py @@ -58,7 +58,7 @@ def test_game_walkthrough_agent(self): agent = textworld.agents.WalkthroughAgent() env = textworld.start(self.game_file) env.activate_state_tracking() - commands = self.game.quests[0].commands + commands = self.game.quests[-1].commands agent.reset(env) game_state = env.reset() diff --git a/textworld/agents/walkthrough.py b/textworld/agents/walkthrough.py index 7b046b86..e75ed08c 100644 --- a/textworld/agents/walkthrough.py +++ b/textworld/agents/walkthrough.py @@ -26,7 +26,7 @@ def reset(self, env): raise NameError(msg) # Load command from the generated game. - self._commands = iter(env.game.quests[0].commands) + self._commands = iter(env.game.quests[-1].commands) def act(self, game_state, reward, done): try: diff --git a/textworld/generator/__init__.py b/textworld/generator/__init__.py index 7db5d7da..514ae963 100644 --- a/textworld/generator/__init__.py +++ b/textworld/generator/__init__.py @@ -147,7 +147,7 @@ def make_game_with(world, quests=None, grammar=None): return game -def make_game(world_size: int, nb_objects: int, quest_length: int, +def make_game(world_size: int, nb_objects: int, quest_length: int, quest_breadth: int, grammar_flags: Mapping = {}, rngs: Optional[Dict[str, RandomState]] = None ) -> Game: @@ -158,6 +158,7 @@ def make_game(world_size: int, nb_objects: int, quest_length: int, world_size: Number of rooms in the world. nb_objects: Number of objects in the world. quest_length: Minimum nb. of actions the quest requires to be completed. + quest_breadth: How many branches the quest can have. 
grammar_flags: Options for the grammar. Returns: @@ -175,14 +176,34 @@ def make_game(world_size: int, nb_objects: int, quest_length: int, world = make_world(world_size, nb_objects=0, rngs=rngs) # Sample a quest according to quest_length. - options = ChainingOptions() + class Options(ChainingOptions): + + def get_rules(self, depth): + if depth == 0: + # Last action should not be "go ". + return data.get_rules().get_matching("^(?!go.*).*") + else: + return super().get_rules(depth) + + options = Options() options.backward = True + options.min_depth = quest_length options.max_depth = quest_length + options.min_breadth = quest_breadth + options.max_breadth = quest_breadth options.create_variables = True options.rng = rngs['rng_quest'] options.restricted_types = {"r", "d"} chain = sample_quest(world.state, options) + + subquests = [] + for i in range(1, len(chain.nodes)): + if chain.nodes[i].breadth != chain.nodes[i - 1].breadth: + quest = Quest(chain.actions[:i]) + subquests.append(quest) + quest = Quest(chain.actions) + subquests.append(quest) # Set the initial state required for the quest. world.state = chain.initial_state @@ -191,7 +212,9 @@ def make_game(world_size: int, nb_objects: int, quest_length: int, world.populate(nb_objects, rng=rngs['rng_objects']) grammar = make_grammar(grammar_flags, rng=rngs['rng_grammar']) - game = make_game_with(world, [quest], grammar) + game = make_game_with(world, subquests, grammar) + game.change_grammar(grammar) + return game diff --git a/textworld/generator/inform7/world2inform7.py b/textworld/generator/inform7/world2inform7.py index e4858744..2223e373 100644 --- a/textworld/generator/inform7/world2inform7.py +++ b/textworld/generator/inform7/world2inform7.py @@ -156,7 +156,6 @@ def generate_inform7_source(game, seed=1234, use_i7_description=False): quests = game.quests source = "" - source += "Use scoring. 
The maximum score is 1.\n" source += "When play begins, seed the random-number generator with {}.\n\n".format(seed) source += define_inform7_kinds() # Mention that rooms have a special text attribute called 'internal name'. @@ -221,22 +220,30 @@ def generate_inform7_source(game, seed=1234, use_i7_description=False): # Place the player. source += "The player is in {}.\n\n".format(var_infos[world.player_room.id].id) - quest = None - if len(quests) > 0: - quest = quests[0] # TODO: randomly sample a quest. + objective = game.objective + maximum_score = 0 + for quest_id, quest in enumerate(quests): commands = gen_commands_from_actions(quest.actions, var_infos) quest.commands = commands + maximum_score += quest.reward + + quest_completed = textwrap.dedent("""\ + The quest{quest_id} completed is a truth state that varies. + The quest{quest_id} completed is usually false. + """) + source += quest_completed.format(quest_id=quest_id) - walkthrough = '\nTest me with "{}"\n\n'.format(" / ".join(commands)) + walkthrough = '\nTest quest{} with "{}"\n\n'.format(quest_id, " / ".join(commands)) source += walkthrough - # Add winning and losing conditions. - ending_condition = """\ + # Add winning and losing conditions for quest. + quest_ending_condition = """\ Every turn: - if {}: + if {losing_tests}: end the story; [Lost] - else if {}: - end the story finally; [Win] + else if quest{quest_id} completed is false and {winning_tests}: + increase the score by {reward}; [Quest completed] + Now the quest{quest_id} completed is true. 
""" @@ -246,8 +253,41 @@ def generate_inform7_source(game, seed=1234, use_i7_description=False): if quest.fail_action is not None: losing_tests = gen_source_for_conditions(quest.fail_action.preconditions) - ending_condition = ending_condition.format(losing_tests, winning_tests) - source += textwrap.dedent(ending_condition) + quest_ending_condition = quest_ending_condition.format(losing_tests=losing_tests, + winning_tests=winning_tests, + reward=quest.reward, + quest_id=quest_id) + source += textwrap.dedent(quest_ending_condition) + + # Enable scoring is at least one quest has nonzero reward. + if maximum_score != 0: + source += "Use scoring. The maximum score is {}.\n".format(maximum_score) + + # Build test condition for winning the game. + game_winning_test = "1 is 0 [always false]" + if len(quests) > 0: + test_template = "quest{} completed is true" + game_winning_test = " and ".join(test_template.format(i) for i in range(len(quests))) + + # Remove square bracket when printing score increases. Square brackets are conflicting with + # Inform7's events parser in git_glulx_ml.py. + # And add winning conditions for the game. + source += textwrap.dedent("""\ + This is the simpler notify score changes rule: + If the score is not the last notified score: + let V be the score - the last notified score; + say "Your score has just gone up by [V in words] "; + if V > 1: + say "points."; + else: + say "point."; + Now the last notified score is the score; + if {game_winning_test}: + end the story finally; [Win] + + The simpler notify score changes rule substitutes for the notify score changes rule. + + """.format(game_winning_test=game_winning_test)) if not use_i7_description: # Remove Inform7 listing of nondescript items. @@ -292,7 +332,7 @@ def generate_inform7_source(game, seed=1234, use_i7_description=False): Rule for printing the banner text: say "{objective}[line break]". 
- """.format(objective=quest.desc if quest is not None else "")) + """.format(objective=objective)) # Simply display *** The End *** when game ends. source += textwrap.dedent("""\ @@ -300,7 +340,6 @@ def generate_inform7_source(game, seed=1234, use_i7_description=False): Rule for printing the player's obituary: if story has ended finally: - increase score by 1; center "*** The End ***"; else: center "*** You lost! ***"; @@ -451,7 +490,7 @@ def generate_inform7_source(game, seed=1234, use_i7_description=False): source += textwrap.dedent("""\ An objective is some text that varies. The objective is "{objective}". - """.format(objective=quest.desc if quest is not None else "")) + """.format(objective=objective)) # Special command to print the objective of the game, if any. source += textwrap.dedent("""\ @@ -618,6 +657,7 @@ def generate_inform7_source(game, seed=1234, use_i7_description=False): Turning on the restrict commands option is an action applying to nothing. Carry out turning on the restrict commands option: + Decrease turn count by 1; Now the restrict commands option is true. Understand "restrict commands" as turning on the restrict commands option. 
From c99177918931fcefe39e4751b933a1d15f6a86dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 28 Aug 2018 22:41:53 -0400 Subject: [PATCH 6/9] DependencyTree can now support more than one root --- textworld/generator/dependency_tree.py | 95 ++++++++++++------- .../generator/tests/test_dependency_tree.py | 39 ++++---- 2 files changed, 82 insertions(+), 52 deletions(-) diff --git a/textworld/generator/dependency_tree.py b/textworld/generator/dependency_tree.py index dddb4da3..e03fd72b 100644 --- a/textworld/generator/dependency_tree.py +++ b/textworld/generator/dependency_tree.py @@ -3,6 +3,7 @@ import textwrap +from typing import List, Any from textworld.utils import uniquify @@ -45,13 +46,17 @@ def __init__(self, element): def push(self, node): if node == self: - return - + return True + + added = False for child in self.children: - child.push(node) + added |= child.push(node) if self.element.depends_on(node.element) and not self.already_added(node): self.children.append(node) + return True + + return added def already_added(self, node): # We want to avoid duplicate information about dependencies. 
@@ -63,11 +68,6 @@ def already_added(self, node): if not node.element.is_distinct_from((child.element for child in self.children)): return True - # for child in self.children: - # # if node.element.value == child.element.value: - # if not node.element.is_distinct_from((child.element): - # return True - return False def __str__(self): @@ -84,25 +84,37 @@ def copy(self): node.children = [child.copy() for child in self.children] return node - def __init__(self, element_type=DependencyTreeElement): - self.root = None + def __init__(self, element_type=DependencyTreeElement, trees=[]): + self.roots = [] self.element_type = element_type + for tree in trees: + self.roots += [root.copy() for root in tree.roots] + self._update() - def push(self, value): + def push(self, value: Any, allow_multi_root: bool = False): + """ Add a value to this dependency tree. + + Adding a value already present in the tree does not modify the tree. + + Args: + value: value to add. + allow_multi_root: if `True`, allow the value to spawn an + additional root if needed. + + """ element = self.element_type(value) node = DependencyTree._Node(element) - if self.root is None: - self.root = node - else: - self.root.push(node) + + added = False + for root in self.roots: + added |= root.push(node) + + if len(self.roots) == 0 or (not added and allow_multi_root): + self.roots.append(node) # Recompute leaves. self._update() - if element in self.leaves_elements: - return node - - return None def pop(self, value): if value not in self.leaves_values: @@ -113,9 +125,14 @@ def _visit(node): if child.element.value == value: node.children.remove(child) - self._postorder(self.root, _visit) - if self.root.element.value == value: - self.root = None + root_to_remove = [] + for i, root in enumerate(self.roots): + self._postorder(root, _visit) + if root.element.value == value: + root_to_remove.append(i) + + for i in root_to_remove[::-1]: + del self.roots[i] # Recompute leaves. 
self._update() @@ -128,26 +145,38 @@ def _postorder(self, node, visit): def _update(self): self._leaves_values = [] - self._leaves_elements = set() + self._leaves_elements = [] def _visit(node): if len(node.children) == 0: - self._leaves_elements.add(node.element) + self._leaves_elements.append(node.element) self._leaves_values.append(node.element.value) - if self.root is not None: - self._postorder(self.root, _visit) + for root in self.roots: + self._postorder(root, _visit) self._leaves_values = uniquify(self._leaves_values) + self._leaves_elements = uniquify(self._leaves_elements) def copy(self): - tree = DependencyTree(self.element_type) - if self.root is not None: - tree.root = self.root.copy() - tree._update() - + tree = type(self)(element_type=self.element_type) + for root in self.roots: + tree.roots.append(root.copy()) + + tree._update() return tree + def tolist(self) -> List[Any]: + values = [] + + def _visit(node): + values.append(node.element.value) + + for root in self.roots: + self._postorder(root, _visit) + + return values + @property def leaves_elements(self): return self._leaves_elements @@ -157,7 +186,5 @@ def leaves_values(self): return self._leaves_values def __str__(self): - if self.root is None: - return "" + return "\n".join(map(str, self.roots)) - return str(self.root) diff --git a/textworld/generator/tests/test_dependency_tree.py b/textworld/generator/tests/test_dependency_tree.py index f62bccd8..6e05f1e1 100644 --- a/textworld/generator/tests/test_dependency_tree.py +++ b/textworld/generator/tests/test_dependency_tree.py @@ -27,16 +27,16 @@ class TestDependencyTree(unittest.TestCase): def test_pop(self): tree = DependencyTree(element_type=CustomDependencyTreeElement) - assert tree.root is None + assert len(tree.roots) == 0 tree.push("G") tree.pop("G") - assert tree.root is None + assert len(tree.roots) == 0 tree.push("G") tree.push("F") # Can't pop a non-leaf element. 
assert_raises(ValueError, tree.pop, "G") - assert tree.root is not None + assert len(tree.roots) > 0 assert set(tree.leaves_values) == set("F") tree.pop("F") @@ -44,41 +44,44 @@ def test_pop(self): def test_push(self): tree = DependencyTree(element_type=CustomDependencyTreeElement) - assert tree.root is None + assert len(tree.roots) == 0 assert set(tree.leaves_values) == set() - node = tree.push("G") + tree.push("G") assert set(tree.leaves_values) == set(["G"]), tree.leaves_values - assert tree.root.element.value == "G" - assert tree.root is node - assert len(node.children) == 0 + assert tree.roots[0].element.value == "G" + assert len(tree.roots[0].children) == 0 - node = tree.push("F") + tree.push("F") assert set(tree.leaves_values) == set(["F"]) - node = tree.push("C") + tree.push("C") + node = tree.roots[0].children[0].children[0] assert set(tree.leaves_values) == set(["C"]) - assert tree.root.element.value == "G" + assert tree.roots[0].element.value == "G" assert node.element.value == "C" - assert len(tree.root.children) == 1 + assert len(tree.roots[0].children) == 1 assert len(node.children) == 0 # Nothing depends on A at the moment. - node = tree.push("A") + tree.push("A") assert set(tree.leaves_values) == set(["C"]) - node = tree.push("E") + tree_ = tree.copy() + tree.push("E") + assert tree_.tolist() != tree.tolist() assert set(tree.leaves_values) == set(["E", "C"]) # Add the same element twice at the same level doesn't change the tree. - node = tree.push("E") - assert node is None + tree_ = tree.copy() + tree.push("E") + assert tree_.tolist() == tree.tolist() assert set(tree.leaves_values) == set(["E", "C"]) # Cannot remove a value that hasn't been added to the tree. 
assert_raises(ValueError, tree.pop, "Z") - node = tree.push("A") + tree.push("A") assert set(tree.leaves_values) == set(["A", "C"]) - node = tree.push("B") + tree.push("B") assert set(tree.leaves_values) == set(["B", "A", "C"]) From 1c4a1a685a42daedca77984b9515fc27eac47cfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 28 Aug 2018 22:44:49 -0400 Subject: [PATCH 7/9] Add necessary attributes for subquests support + add proper multi-quests tracking --- textworld/generator/game.py | 250 +++++++++++++++++-------- textworld/generator/maker.py | 39 ++++ textworld/generator/tests/test_game.py | 160 +++++++++++++--- 3 files changed, 342 insertions(+), 107 deletions(-) diff --git a/textworld/generator/game.py b/textworld/generator/game.py index 11c27e3f..95105659 100644 --- a/textworld/generator/game.py +++ b/textworld/generator/game.py @@ -58,6 +58,7 @@ def __init__(self, actions: Optional[List[Action]], self.actions = actions self.desc = desc self.commands = [] + self.reward = 1 self.win_action = self.set_winning_conditions(winning_conditions) self.fail_action = self.set_failing_conditions(failing_conditions) @@ -79,7 +80,9 @@ def set_winning_conditions(self, winning_conditions: Optional[Collection[Proposi # last action in the quest. 
winning_conditions = self.actions[-1].postconditions - self.win_action = Action("win", winning_conditions, [Proposition("win")]) + from textworld.utils import uniquify + arguments = uniquify([a for c in winning_conditions for a in c.arguments]) + self.win_action = Action("win", winning_conditions, [Proposition("win", arguments)] + list(winning_conditions)) return self.win_action def set_failing_conditions(self, failing_conditions: Optional[Collection[Proposition]]) -> Optional[Action]: @@ -112,6 +115,7 @@ def __eq__(self, other: Any) -> bool: self.win_action == other.win_action and self.fail_action == other.fail_action and self.desc == other.desc and + self.reward == other.reward and self.commands == other.commands) @classmethod @@ -132,6 +136,7 @@ def deserialize(cls, data: Mapping) -> "Quest": desc = data["desc"] quest = cls(actions, win_action.preconditions, failing_conditions, desc=desc) quest.commands = data["commands"] + quest.reward = data.get("reward", 1) return quest def serialize(self) -> Mapping: @@ -142,6 +147,7 @@ def serialize(self) -> Mapping: """ data = {} data["desc"] = self.desc + data["reward"] = self.reward data["commands"] = self.commands data["actions"] = [action.serialize() for action in self.actions] data["win_action"] = self.win_action.serialize() @@ -199,7 +205,7 @@ def __hash__(self) -> int: def __str__(self) -> str: return "Info({}: {} | {})".format(self.name, self.adj, self.noun) - + @classmethod def deserialize(cls, data: Mapping) -> "EntityInfo": """ Creates a `EntityInfo` from serialized data. 
@@ -243,6 +249,7 @@ def __init__(self, world: World, grammar: Optional[Grammar] = None, self.grammar = grammar self.quests = [] if quests is None else quests self.metadata = {} + self._objective = None self._infos = self._build_infos() self._rules = data.get_rules() self._types = data.get_types() @@ -270,6 +277,7 @@ def copy(self) -> "Game": game.state = self.state.copy() game._rules = self._rules game._types = self._types + game._objective = self._objective return game def change_grammar(self, grammar: Grammar) -> None: @@ -319,6 +327,7 @@ def deserialize(cls, data: Mapping) -> "Game": for k, v in data["rules"]} game._types = VariableTypeTree.deserialize(data["types"]) game.metadata = data.get("metadata", {}) + game._objective = data.get("objective", None) return game @@ -338,18 +347,21 @@ def serialize(self) -> Mapping: data["rules"] = [(k, v.serialize()) for k, v in self._rules.items()] data["types"] = self._types.serialize() data["metadata"] = self.metadata + data["objective"] = self._objective return data def __eq__(self, other: Any) -> bool: return (isinstance(other, Game) and self.world == other.world and self.infos == other.infos and - self.quests == other.quests) + self.quests == other.quests and + self._objective == other._objective) def __hash__(self) -> int: state = (self.world, frozenset(self.quests), - frozenset(self.infos.items())) + frozenset(self.infos.items()), + self._objective) return hash(state) @@ -395,6 +407,21 @@ def win_condition(self) -> List[Collection[Proposition]]: """ All win conditions, one for each quest. """ return [q.winning_conditions for q in self.quests] + @property + def objective(self) -> str: + if self._objective is not None: + return self._objective + + if len(self.quests) == 0: + return "" + + # We assume the last quest includes all actions needed to solve the game. 
+ return self.quests[-1].desc + + @objective.setter + def objective(self, value: str): + self._objective = value + class ActionDependencyTreeElement(DependencyTreeElement): """ Representation of an `Action` in the dependency tree. @@ -438,6 +465,15 @@ def is_distinct_from(self, others: List["ActionDependencyTreeElement"]) -> bool: return len(new_facts) > 0 def __lt__(self, other: "ActionDependencyTreeElement") -> bool: + """ Order ActionDependencyTreeElement elements. + + Actions that remove information needed by other actions + should be sorted further in the list. + + Notes: + This is not a proper ordering, i.e. two actions + can mutually removed information needed by the other. + """ return len(other.action.removed & self.action._pre_set) > 0 def __str__(self) -> str: @@ -445,6 +481,48 @@ def __str__(self) -> str: return "{}({})".format(self.action.name, params) +class ActionDependencyTree(DependencyTree): + + def pop(self, action: Action) -> Optional[Action]: + super().pop(action) + + reverse_action = None + # The last action might have impacted one of the subquests. + reverse_action = get_reverse_action(action) + if reverse_action is not None: + self.push(reverse_action) + + return reverse_action + + def tolist(self) -> Optional[List[Action]]: + """ Builds a list with the actions contained in this dependency tree. + + The list is greedily built by iteratively popping leaves from + the dependency tree. + """ + tree = self.copy() # Make a copy of the tree to work on. + + actions = [] + last_reverse_action = None + while len(tree.roots) > 0: + # Try leaves that doesn't affect the others first. + for leaf in sorted(tree.leaves_elements): + if leaf.action != last_reverse_action: + break # Choose an action that avoids cycles. 
+ + actions.append(leaf.action) + last_reverse_action = tree.pop(leaf.action) + + return actions + + def compress(self): + tree = ActionDependencyTree(element_type=ActionDependencyTreeElement) + for action in self.tolist()[::-1]: + tree.push(action) + + self.roots = tree.roots + + class QuestProgression: """ QuestProgression keeps track of the completion of a quest. @@ -458,13 +536,40 @@ def __init__(self, quest: Quest) -> None: quest: The quest to keep track of its completion. """ self._quest = quest - self._tree = DependencyTree(element_type=ActionDependencyTreeElement) - self._winning_policy = list(quest.actions) + self._winning_policy = None - # Build a tree representation - for i, action in enumerate(quest.actions[::-1]): + # Build a tree representation of the quest. + self._tree = ActionDependencyTree(element_type=ActionDependencyTreeElement) + self._tree.push(quest.win_action) + for action in quest.actions[::-1]: self._tree.push(action) + # We compress the tree since quest's actions might not be optimal. + # e.g. go west > go east > go west cycles + self._tree.compress() + self._rebuild_policy() + + def _rebuild_policy(self): + self._winning_policy = None + if self._tree is not None: + self._winning_policy = self._tree.tolist() + + @property + def winning_policy(self) -> List[Action]: + """ Actions to be performed in order to complete the quest. """ + if self._winning_policy is None: + return None + + return self._winning_policy[:-1] # Discard "win" action. + + @property + def done(self): + """ Check whether the quest is done. """ + if self.winning_policy is None: + return True + + return len(self.winning_policy) == 0 + def is_completed(self, state: State) -> bool: """ Check whether the quest is completed. 
""" return state.is_applicable(self._quest.win_action) @@ -476,65 +581,16 @@ def has_failed(self, state: State) -> bool: return state.is_applicable(self._quest.fail_action) - @property - def winning_policy(self) -> List[Action]: - """ Actions to be performed in order to complete the quest. """ - return self._winning_policy - - def _pop_action_from_tree(self, action: Action, tree: DependencyTree) -> Optional[Action]: - # The last action was meaningful for the quest. - tree.pop(action) - - reverse_action = None - if tree.root is not None: - # The last action might have impacted one of the subquests. - reverse_action = get_reverse_action(action) - if reverse_action is not None: - tree.push(reverse_action) - - return reverse_action - - def _build_policy(self) -> Optional[List[Action]]: - """ Build a policy given the current state of the QuestTree. - - The policy is greedily built by iteratively popping leaves from - the dependency tree. - """ - if self._tree is None: - return None - - tree = self._tree.copy() # Make a copy of the tree to work on. - - policy = [] - last_reverse_action = None - while tree.root is not None: - # Try leaves that doesn't affect the others first. - for leaf in sorted(tree.leaves_elements): - if leaf.action != last_reverse_action: - break # Choose an action that avoids cycles. - - policy.append(leaf.action) - last_reverse_action = self._pop_action_from_tree(leaf.action, tree) - - return policy - - def update(self, action: Action, bypass: Optional[List[Action]] = None) -> None: + def update(self, action: Action) -> None: """ Update the state of the quest after a given action was performed. Args: action: Action affecting the state of the quest. """ - if bypass is not None: - for action in bypass: - self._pop_action_from_tree(action, self._tree) - - self._winning_policy = self._build_policy() - return - # Determine if we moved away from the goal or closer to it. 
if action in self._tree.leaves_values: # The last action was meaningful for the quest. - self._pop_action_from_tree(action, self._tree) + self._tree.pop(action) else: # The last action must have moved us away from the goal. # We need to reverse it. @@ -545,7 +601,22 @@ def update(self, action: Action, bypass: Optional[List[Action]] = None) -> None: else: self._tree.push(reverse_action) - self._winning_policy = self._build_policy() + self._rebuild_policy() + + def compress_winning_policy(self, state: State): + for j in range(0, len(self._winning_policy)): + for i in range(j + 1, len(self._winning_policy))[::-1]: + if state.is_sequence_applicable(self._winning_policy[:j] + self._winning_policy[i:]): + for action in self._winning_policy[:i]: + self._tree.pop(action) + + for action in self._winning_policy[:j][::-1]: + self._tree.push(action) + + self._rebuild_policy() + return True + + return False class GameProgression: @@ -558,30 +629,39 @@ class GameProgression: def __init__(self, game: Game, track_quest: bool = True) -> None: """ Args: - game: The game to track progression of. - track_quest: Whether we should track the quest completion. + game: The game to track progression of. + track_quest: Whether we should track the quest completion. """ self.game = game self.state = game.state.copy() self._valid_actions = list(self.state.all_applicable_actions(self.game._rules.values(), self.game._types.constants_mapping)) - self.quest_progression = None + self.quest_progressions = None if track_quest and len(game.quests) > 0: - self.quest_progression = QuestProgression(game.quests[0]) + self.quest_progressions = [QuestProgression(quest) for quest in game.quests] + for quest_progression in self.quest_progressions: + while quest_progression.compress_winning_policy(self.state): + pass @property def done(self) -> bool: """ Whether the quest is completed or has failed.
""" - if self.quest_progression is None: + if self.quest_progressions is None: return False - return (self.quest_progression.is_completed(self.state) or - self.quest_progression.has_failed(self.state)) + all_completed = True + for quest_progression in self.quest_progressions: + if quest_progression.has_failed(self.state): + return True + + all_completed &= quest_progression.done + + return all_completed @property def tracking_quest(self) -> bool: """ Whether the quest is tracked or not. """ - return self.quest_progression is not None + return self.quest_progressions is not None @property def valid_actions(self) -> List[Action]: @@ -594,12 +674,23 @@ def winning_policy(self) -> Optional[List[Action]]: Returns: A policy that leads to winning the game. It can be `None` - if `tracking_quest` is `False` or the quest has been failed. + if `tracking_quest` is `False` or the quest has failed. """ - if not self.tracking_quest or self.quest_progression.winning_policy is None: + if not self.tracking_quest: return None - - return list(self.quest_progression.winning_policy) + + # Check if any quest has failed. + if any(quest_progression.winning_policy is None for quest_progression in self.quest_progressions): + return None + + # Greedily build a new winning policy by merging all individual quests' tree. + trees = [qp._tree for qp in self.quest_progressions if not qp.done] + master_quest_tree = ActionDependencyTree(element_type=ActionDependencyTreeElement, + trees=trees) + + # print(master_quest_tree) + winning_policy = master_quest_tree.tolist() + return [a for a in winning_policy if a.name != "win"] def update(self, action: Action) -> None: """ Update the state of the game given the provided action. @@ -618,10 +709,11 @@ def update(self, action: Action) -> None: if self.state.is_sequence_applicable(self.winning_policy): pass # The last action didn't impact the quest. else: - # Check for shortcut. 
- for i in range(1, len(self.winning_policy)): - if self.state.is_sequence_applicable(self.winning_policy[i:]): - self.quest_progression.update(action, bypass=self.winning_policy[:i]) - return - - self.quest_progression.update(action) + for quest_progression in self.quest_progressions: + if quest_progression.done: + continue + + # Try compressing the winning policy for the quest, + # otherwise update its progression given the new action. + if not quest_progression.compress_winning_policy(self.state): + quest_progression.update(action) diff --git a/textworld/generator/maker.py b/textworld/generator/maker.py index 13aba882..debeabec 100644 --- a/textworld/generator/maker.py +++ b/textworld/generator/maker.py @@ -3,6 +3,7 @@ import os +import itertools from os.path import join as pjoin from typing import List, Iterable, Union, Optional @@ -588,6 +589,44 @@ def set_quest_from_commands(self, commands: List[str], ask_for_state: bool = Fal # Calling build will generate the description for the quest. self.build() return self._quests[0] + + def new_fact(self, name: str, *entities: List["WorldEntity"]) -> None: + """ Create new fact. + + Args: + name: The name of the new fact. + *entities: A list of entities as arguments to the new fact. + """ + args = [entity.var for entity in entities] + return Proposition(name, args) + + def new_quest_using_commands(self, commands: List[str], winning_facts: Optional[Iterable[Proposition]] = None) -> Quest: + """ Add a new quest using predefined text commands. + + This launches a `textworld.play` session. + + Args: + commands: Text commands. + winning_facts: set of facts that should be true in order to + consider the quest as completed. If `None`, + the last action's postconditions will be used. + + Returns: + The resulting quest.
+ """ + with make_temp_directory() as tmpdir: + try: + game_file = self.compile(pjoin(tmpdir, "record_quest")) + recorder = Recorder() + agent = textworld.agents.WalkthroughAgent(commands) + textworld.play(game_file, agent=agent, wrapper=recorder, silent=True) + except textworld.agents.WalkthroughDone: + pass # Quest is done. + + # Skip "None" actions. + actions = [action for action in recorder.actions if action is not None] + quest = Quest(actions=actions, winning_conditions=winning_facts) + return quest def set_quest_from_final_state(self, final_state: Collection[Proposition]) -> Quest: """ Defines the game's quest using a collection of facts. diff --git a/textworld/generator/tests/test_game.py b/textworld/generator/tests/test_game.py index 7abda58e..edd9ee44 100644 --- a/textworld/generator/tests/test_game.py +++ b/textworld/generator/tests/test_game.py @@ -27,24 +27,38 @@ from textworld.logic import Proposition + +def _apply_command(command: str, game_progression: GameProgression): + """ Apply a text command to a game_progression object. + """ + valid_commands = gen_commands_from_actions(game_progression.valid_actions, game_progression.game.infos) + + for action, cmd in zip(game_progression.valid_actions, valid_commands): + if command == cmd: + game_progression.update(action) + return + + raise ValueError("Not a valid command: {}. 
Expected: {}".format(command, valid_commands)) + + def test_game_comparison(): rngs = {} rngs['rng_map'] = np.random.RandomState(1) rngs['rng_objects'] = np.random.RandomState(2) rngs['rng_quest'] = np.random.RandomState(3) rngs['rng_grammar'] = np.random.RandomState(4) - game1 = make_game(world_size=5, nb_objects=5, quest_length=2, grammar_flags={}, rngs=rngs) + game1 = make_game(world_size=5, nb_objects=5, quest_length=2, quest_breadth=2, grammar_flags={}, rngs=rngs) rngs['rng_map'] = np.random.RandomState(1) rngs['rng_objects'] = np.random.RandomState(2) rngs['rng_quest'] = np.random.RandomState(3) rngs['rng_grammar'] = np.random.RandomState(4) - game2 = make_game(world_size=5, nb_objects=5, quest_length=2, grammar_flags={}, rngs=rngs) + game2 = make_game(world_size=5, nb_objects=5, quest_length=2, quest_breadth=2, grammar_flags={}, rngs=rngs) assert game1 == game2 # Test __eq__ assert game1 in {game2} # Test __hash__ - game3 = make_game(world_size=5, nb_objects=5, quest_length=2, grammar_flags={}, rngs=rngs) + game3 = make_game(world_size=5, nb_objects=5, quest_length=2, quest_breadth=2, grammar_flags={}, rngs=rngs) assert game1 != game3 @@ -53,7 +67,7 @@ def test_game_comparison(): def test_variable_infos(verbose=False): g_rng.set_seed(1234) grammar_flags = {"theme": "house", "include_adj": True} - game = textworld.generator.make_game(world_size=5, nb_objects=10, quest_length=3, grammar_flags=grammar_flags) + game = textworld.generator.make_game(world_size=5, nb_objects=10, quest_length=3, quest_breadth=2, grammar_flags=grammar_flags) for var_id, var_infos in game.infos.items(): if var_id not in ["P", "I"]: @@ -270,7 +284,46 @@ def test_winning_policy(self): quest.update(self.quest.actions[0]) assert quest.winning_policy == self.quest.actions[1:] - def test_cycle_in_winning_policy(cls): + +class TestGameProgression(unittest.TestCase): + + @classmethod + def setUpClass(cls): + M = GameMaker() + + # The goal + commands = ["open wooden door", "go west", "take 
carrot", "go east", "drop carrot"] + + # Create a 'bedroom' room. + R1 = M.new_room("bedroom") + R2 = M.new_room("kitchen") + M.set_player(R2) + + path = M.connect(R1.east, R2.west) + path.door = M.new(type='d', name='wooden door') + path.door.add_property("closed") + + carrot = M.new(type='f', name='carrot') + R1.add(carrot) + + # Add a closed chest in R2. + chest = M.new(type='c', name='chest') + chest.add_property("open") + R2.add(chest) + + cls.quest = M.set_quest_from_commands(commands) + cls.game = M.build() + + def test_is_game_completed(self): + game_progress = GameProgression(self.game) + + for action in self.quest.actions: + assert not game_progress.done + game_progress.update(action) + + assert game_progress.done + + def test_cycle_in_winning_policy(self): M = GameMaker() # Create a map. @@ -303,16 +356,6 @@ def test_cycle_in_winning_policy(cls): game = M.build() game_progression = GameProgression(game) - def _apply_command(command, game_progression): - valid_commands = gen_commands_from_actions(game_progression.valid_actions, game.infos) - - for action, cmd in zip(game_progression.valid_actions, valid_commands): - if command == cmd: - game_progression.update(action) - return - - raise ValueError("Not a valid command: {}. 
Expected: {}".format(command, valid_commands)) - _apply_command("go south", game_progression) expected_commands = ["go north"] + commands winning_commands = gen_commands_from_actions(game_progression.winning_policy, game.infos) @@ -351,15 +394,26 @@ def _apply_command(command, game_progression): winning_commands = gen_commands_from_actions(game_progression.winning_policy, game.infos) assert winning_commands == expected_commands, "{} != {}".format(winning_commands, expected_commands) - -class TestGameProgression(unittest.TestCase): - - @classmethod - def setUpClass(cls): + def test_game_with_multiple_quests(self): M = GameMaker() - # The goal - commands = ["open wooden door", "go west", "take carrot", "go east", "drop carrot"] + # The subgoals (needs to be executed in order). + # commands = [["open wooden door"], + # ["go west", "take carrot"], + # ["go east", "drop carrot"], + # # Now, the player is back in the kitchen and the wooden door is open. + # ["go west", "take lettuce"], + # ["go east", "drop lettuce"], + # # Now, the player is back in the kitchen, there are a carrot and a lettuce on the floor. + # ["take lettuce"], + # ["take carrot"], + # ["insert carrot into chest"], + # ["insert lettuce into chest", "close chest"]] + commands = [["open wooden door", "go west", "take carrot", "go east", "drop carrot"], + # Now, the player is back in the kitchen and the wooden door is open. + ["go west", "take lettuce", "go east", "drop lettuce"], + # Now, the player is back in the kitchen, there are a carrot and a lettuce on the floor. + ["take lettuce", "take carrot", "insert carrot into chest", "insert lettuce into chest", "close chest"]] # Create a 'bedroom' room. R1 = M.new_room("bedroom") @@ -371,22 +425,72 @@ def setUpClass(cls): path.door.add_property("closed") carrot = M.new(type='f', name='carrot') - R1.add(carrot) + lettuce = M.new(type='f', name='lettuce') + R1.add(carrot, lettuce) # Add a closed chest in R2. 
chest = M.new(type='c', name='chest') chest.add_property("open") R2.add(chest) - cls.quest = M.set_quest_from_commands(commands) - cls.game = M.build() + quest1 = M.new_quest_using_commands(commands[0]) + quest1.desc = "Fetch the carrot and drop it on the kitchen's ground." + quest2 = M.new_quest_using_commands(commands[0] + commands[1]) + quest2.desc = "Fetch the lettuce and drop it on the kitchen's ground." + winning_facts = [M.new_fact("in", lettuce, chest), + M.new_fact("in", carrot, chest), + M.new_fact("closed", chest),] + quest3 = M.new_quest_using_commands(commands[0] + commands[1] + commands[2], winning_facts=winning_facts) + # quest3 = M.new_quest_using_commands(commands[0] + commands[1] + commands[2]) + quest3.desc = "Put the lettuce and the carrot into the chest before closing it." + + M._quests = [quest1, quest2, quest3] + assert len(M._quests) == len(commands) + game = M.build() - def test_is_game_completed(self): - game_progress = GameProgression(self.game) + game_progress = GameProgression(game) + assert len(game_progress.quest_progressions) == len(game.quests) - for action in self.quest.actions: + # for i, qp in enumerate(game_progress.quest_progressions): + # print() + # print(qp._tree) + # print(i, [c.name for c in qp.winning_policy]) + + # print() + # print([c.name for c in game_progress.winning_policy]) + + # Following the actions associated to the last quest actually corresponds + # to solving the whole game. + for action in game_progress.winning_policy: assert not game_progress.done game_progress.update(action) + # print(action.name, [c.name for c in game_progress.winning_policy]) + + assert game_progress.done + assert all(quest_progression.done for quest_progression in game_progress.quest_progressions) + + # Try solving the game by greedily taking the first action from the current winning policy. 
+ game_progress = GameProgression(game) + while not game_progress.done: + action = game_progress.winning_policy[0] + game_progress.update(action) + + # Try solving the second quest (i.e. bringing back the lettuce) first. + game_progress = GameProgression(game) + for command in ["open wooden door", "go west", "take lettuce", "go east", "drop lettuce"]: + _apply_command(command, game_progress) + + assert not game_progress.quest_progressions[0].done + assert game_progress.quest_progressions[1].done + + for command in ["go west", "take carrot", "go east", "drop carrot"]: + _apply_command(command, game_progress) + + assert game_progress.quest_progressions[0].done + assert game_progress.quest_progressions[1].done + + for command in ["take lettuce", "take carrot", "insert carrot into chest", "insert lettuce into chest", "close chest"]: + _apply_command(command, game_progress) assert game_progress.done From 0bd83c7891134a3916e82be666de90d8395399fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 4 Sep 2018 15:35:18 -0400 Subject: [PATCH 8/9] Fixing bugs + refactoring --- tests/test_play_generated_games.py | 2 +- tests/{test_tw_play.py => test_tw-play.py} | 4 +- textworld/agents/walkthrough.py | 8 +- textworld/envs/glulx/git_glulx_ml.py | 2 +- textworld/generator/__init__.py | 6 +- textworld/generator/dependency_tree.py | 127 ++++----- textworld/generator/game.py | 240 ++++++++++-------- .../generator/tests/test_dependency_tree.py | 25 +- textworld/generator/tests/test_game.py | 117 ++++++--- textworld/render/render.py | 3 +- 10 files changed, 317 insertions(+), 217 deletions(-) rename tests/{test_tw_play.py => test_tw-play.py} (80%) diff --git a/tests/test_play_generated_games.py b/tests/test_play_generated_games.py index 562892f1..a70ac2a7 100644 --- a/tests/test_play_generated_games.py +++ b/tests/test_play_generated_games.py @@ -16,7 +16,7 @@ def test_play_generated_games(): # Sample game specs. 
world_size = rng.randint(1, 10) nb_objects = rng.randint(0, 20) - quest_length = rng.randint(1, 5) + quest_length = rng.randint(2, 5) quest_breadth = rng.randint(3, 7) game_seed = rng.randint(0, 65365) grammar_flags = {} # Default grammar. diff --git a/tests/test_tw_play.py b/tests/test_tw-play.py similarity index 80% rename from tests/test_tw_play.py rename to tests/test_tw-play.py index 6ff2dfea..1bb38ea9 100644 --- a/tests/test_tw_play.py +++ b/tests/test_tw-play.py @@ -7,9 +7,9 @@ from textworld.utils import make_temp_directory -def test_making_a_custom_game(): +def test_playing_a_game(): with make_temp_directory(prefix="test_tw-play") as tmpdir: - game_file, _ = textworld.make(5, 10, 5, {}, seed=1234, games_dir=tmpdir) + game_file, _ = textworld.make(5, 10, 5, 4, {}, seed=1234, games_dir=tmpdir) command = ["tw-play", "--max-steps", "100", "--mode", "random", game_file] assert check_call(command) == 0 diff --git a/textworld/agents/walkthrough.py b/textworld/agents/walkthrough.py index e75ed08c..319ca197 100644 --- a/textworld/agents/walkthrough.py +++ b/textworld/agents/walkthrough.py @@ -26,7 +26,13 @@ def reset(self, env): raise NameError(msg) # Load command from the generated game. 
- self._commands = iter(env.game.quests[-1].commands) + from textworld.generator.game import GameProgression, Quest + from textworld.generator.inform7 import gen_commands_from_actions + + game_progression = GameProgression(env.game) + main_quest = Quest(actions=game_progression.winning_policy) + commands = gen_commands_from_actions(main_quest.actions, env.game.infos) + self._commands = iter(commands) def act(self, game_state, reward, done): try: diff --git a/textworld/envs/glulx/git_glulx_ml.py b/textworld/envs/glulx/git_glulx_ml.py index 9bd3a4b4..c37299a6 100644 --- a/textworld/envs/glulx/git_glulx_ml.py +++ b/textworld/envs/glulx/git_glulx_ml.py @@ -146,7 +146,7 @@ def init(self, output: str, game=None, """ output = _strip_input_prompt_symbol(output) super().init(output) - self._game_progression = GameProgression(game, track_quest=compute_intermediate_reward) + self._game_progression = GameProgression(game, track_quests=compute_intermediate_reward) self._state_tracking = state_tracking self._compute_intermediate_reward = compute_intermediate_reward and len(game.quests) > 0 self._objective = game.objective diff --git a/textworld/generator/__init__.py b/textworld/generator/__init__.py index 514ae963..8496cece 100644 --- a/textworld/generator/__init__.py +++ b/textworld/generator/__init__.py @@ -180,16 +180,16 @@ class Options(ChainingOptions): def get_rules(self, depth): if depth == 0: - # Last action should not be "go ". + # Last action should not be "go ". 
return data.get_rules().get_matching("^(?!go.*).*") else: return super().get_rules(depth) options = Options() options.backward = True - options.min_depth = quest_length + options.min_depth = 1 options.max_depth = quest_length - options.min_breadth = quest_breadth + options.min_breadth = 1 options.max_breadth = quest_breadth options.create_variables = True options.rng = rngs['rng_quest'] diff --git a/textworld/generator/dependency_tree.py b/textworld/generator/dependency_tree.py index e03fd72b..cf0dce1a 100644 --- a/textworld/generator/dependency_tree.py +++ b/textworld/generator/dependency_tree.py @@ -3,7 +3,7 @@ import textwrap -from typing import List, Any +from typing import List, Any, Iterable from textworld.utils import uniquify @@ -19,32 +19,34 @@ class DependencyTreeElement: `__str__` accordingly. """ - def __init__(self, value): + def __init__(self, value: Any): self.value = value + self.parent = None - def depends_on(self, other): + def depends_on(self, other: "DependencyTreeElement") -> bool: """ Check whether this element depends on the `other`. """ return self.value > other.value - def is_distinct_from(self, others): + def is_distinct_from(self, others: Iterable["DependencyTreeElement"]) -> bool: """ Check whether this element is distinct from `others`. 
""" return self.value not in [other.value for other in others] - def __str__(self): + def __str__(self) -> str: return str(self.value) class DependencyTree: class _Node: - def __init__(self, element): + def __init__(self, element: DependencyTreeElement): self.element = element self.children = [] + self.parent = None - def push(self, node): + def push(self, node: "DependencyTree._Node") -> bool: if node == self: return True @@ -53,12 +55,15 @@ def push(self, node): added |= child.push(node) if self.element.depends_on(node.element) and not self.already_added(node): + node = node.copy() self.children.append(node) + node.element.parent = self.element + node.parent = self return True return added - def already_added(self, node): + def already_added(self, node: "DependencyTree._Node") -> bool: # We want to avoid duplicate information about dependencies. if node in self.children: return True @@ -69,8 +74,14 @@ def already_added(self, node): return True return False + + def __iter__(self) -> Iterable["DependencyTree._Node"]: + for child in self.children: + yield from list(child) + + yield self - def __str__(self): + def __str__(self) -> str: node_text = str(self.element) txt = [node_text] @@ -79,12 +90,16 @@ def __str__(self): return "\n".join(txt) - def copy(self): + def copy(self) -> "DependencyTree._Node": node = DependencyTree._Node(self.element) - node.children = [child.copy() for child in self.children] + for child in self.children: + child_ = child.copy() + child_.parent = node + node.children.append(child_) + return node - def __init__(self, element_type=DependencyTreeElement, trees=[]): + def __init__(self, element_type: DependencyTreeElement = DependencyTreeElement, trees: Iterable["DependencyTree"] = []): self.roots = [] self.element_type = element_type for tree in trees: @@ -92,7 +107,7 @@ def __init__(self, element_type=DependencyTreeElement, trees=[]): self._update() - def push(self, value: Any, allow_multi_root: bool = False): + def push(self, value: Any, 
allow_multi_root: bool = False) -> bool: """ Add a value to this dependency tree. Adding a value already present in the tree does not modify the tree. @@ -112,53 +127,53 @@ def push(self, value: Any, allow_multi_root: bool = False): if len(self.roots) == 0 or (not added and allow_multi_root): self.roots.append(node) + added = True + + self._update() # Recompute leaves. + return added - # Recompute leaves. - self._update() + def remove(self, value: Any) -> None: + """ Remove all leaves having the given value. - def pop(self, value): + The value to remove needs to belong to at least one leaf in this tree. + Otherwise, the tree remains unchanged. + + Args: + value: value to remove from the tree. + + Returns: + Whether the tree has changed or not. + """ if value not in self.leaves_values: - raise ValueError("That element is not a leaf: {!r}.".format(value)) - - def _visit(node): - for child in list(node.children): - if child.element.value == value: - node.children.remove(child) + return False root_to_remove = [] - for i, root in enumerate(self.roots): - self._postorder(root, _visit) - if root.element.value == value: - root_to_remove.append(i) - - for i in root_to_remove[::-1]: - del self.roots[i] - - # Recompute leaves. - self._update() - - def _postorder(self, node, visit): - for child in node.children: - self._postorder(child, visit) - - visit(node) - - def _update(self): + for node in self: + if node.element.value == value: + if node.parent is not None: + node.parent.children.remove(node) + else: + root_to_remove.append(node) + + for node in root_to_remove: + self.roots.remove(node) + + self._update() # Recompute leaves. 
+ return True + + def _update(self) -> None: self._leaves_values = [] self._leaves_elements = [] - def _visit(node): + for node in self: if len(node.children) == 0: self._leaves_elements.append(node.element) self._leaves_values.append(node.element.value) - for root in self.roots: - self._postorder(root, _visit) - self._leaves_values = uniquify(self._leaves_values) self._leaves_elements = uniquify(self._leaves_elements) - def copy(self): + def copy(self) -> "DependencyTree": tree = type(self)(element_type=self.element_type) for root in self.roots: tree.roots.append(root.copy()) @@ -166,25 +181,21 @@ def copy(self): tree._update() return tree - def tolist(self) -> List[Any]: - values = [] - - def _visit(node): - values.append(node.element.value) - + def __iter__(self) -> Iterable["DependencyTree._Node"]: for root in self.roots: - self._postorder(root, _visit) - - return values + yield from list(root) @property - def leaves_elements(self): + def values(self) -> List[Any]: + return [node.element.value for node in self] + + @property + def leaves_elements(self) -> List[DependencyTreeElement]: return self._leaves_elements @property - def leaves_values(self): + def leaves_values(self) -> List[Any]: return self._leaves_values - def __str__(self): + def __str__(self) -> str: return "\n".join(map(str, self.roots)) - diff --git a/textworld/generator/game.py b/textworld/generator/game.py index 95105659..6d8381d9 100644 --- a/textworld/generator/game.py +++ b/textworld/generator/game.py @@ -205,7 +205,7 @@ def __hash__(self) -> int: def __str__(self) -> str: return "Info({}: {} | {})".format(self.name, self.adj, self.noun) - + @classmethod def deserialize(cls, data: Mapping) -> "EntityInfo": """ Creates a `EntityInfo` from serialized data. @@ -414,9 +414,9 @@ def objective(self) -> str: if len(self.quests) == 0: return "" - + # We assume the last quest includes all actions needed to solve the game. 
- return self.quests[-1].desc + return self.quests[-1].desc @objective.setter def objective(self, value: str): @@ -472,9 +472,19 @@ def __lt__(self, other: "ActionDependencyTreeElement") -> bool: Notes: This is not a proper ordering, i.e. two actions - can mutually removed information needed by the other. + can mutually remove information needed by each other. """ - return len(other.action.removed & self.action._pre_set) > 0 + def _required_facts(node): + pre_set = set(node.action._pre_set) + while node.parent is not None: + pre_set |= node.parent.action._pre_set + pre_set -= node.action.added + node = node.parent + + return pre_set + + return len(other.action.removed & _required_facts(self)) > len(self.action.removed & _required_facts(other)) + #return len(other.action.removed & self.action._pre_set) > 0 def __str__(self) -> str: params = ", ".join(map(str, self.action.variables)) @@ -483,10 +493,9 @@ def __str__(self) -> str: class ActionDependencyTree(DependencyTree): - def pop(self, action: Action) -> Optional[Action]: - super().pop(action) + def remove(self, action: Action) -> Optional[Action]: + super().remove(action) - reverse_action = None # The last action might have impacted one of the subquests. reverse_action = get_reverse_action(action) if reverse_action is not None: @@ -511,16 +520,20 @@ def tolist(self) -> Optional[List[Action]]: break # Choose an action that avoids cycles. 
actions.append(leaf.action) - last_reverse_action = tree.pop(leaf.action) + last_reverse_action = tree.remove(leaf.action) return actions - + def compress(self): - tree = ActionDependencyTree(element_type=ActionDependencyTreeElement) - for action in self.tolist()[::-1]: - tree.push(action) - - self.roots = tree.roots + for node in self: + if node.parent is None: + continue + + if len(node.children) == 1: + r_action = get_reverse_action(node.element.action) + if r_action == node.children[0].element.action: + node.parent.children.remove(node) + node.parent.children += node.children[0].children class QuestProgression: @@ -536,7 +549,9 @@ def __init__(self, quest: Quest) -> None: quest: The quest to keep track of its completion. """ self._quest = quest - self._winning_policy = None + self._completed = False + self._failed = False + self._unfinishable = False # Build a tree representation of the quest. self._tree = ActionDependencyTree(element_type=ActionDependencyTreeElement) @@ -544,124 +559,143 @@ def __init__(self, quest: Quest) -> None: for action in quest.actions[::-1]: self._tree.push(action) - # We compress the tree since quest's actions might not be optimal. - # e.g. go west > go east > go west cycles - self._tree.compress() - self._rebuild_policy() - - def _rebuild_policy(self): - self._winning_policy = None - if self._tree is not None: - self._winning_policy = self._tree.tolist() - + self._winning_policy = quest.actions + [quest.win_action] + @property def winning_policy(self) -> List[Action]: """ Actions to be performed in order to complete the quest. """ - if self._winning_policy is None: - return None + if self.done: + return [] return self._winning_policy[:-1] # Discard "win" action. @property - def done(self): - """ Check whether the quest is done. """ - if self.winning_policy is None: - return True - - return len(self.winning_policy) == 0 - - def is_completed(self, state: State) -> bool: + def done(self) -> bool: + """ Check if the quest is done (i.e. 
completed, failed or unfinishable). """ + return self.completed or self.failed or self.unfinishable + + @property + def completed(self) -> bool: """ Check whether the quest is completed. """ - return state.is_applicable(self._quest.win_action) + return self._completed - def has_failed(self, state: State) -> bool: + @property + def failed(self) -> bool: """ Check whether the quest has failed. """ - if self._quest.fail_action is None: - return False + return self._failed - return state.is_applicable(self._quest.fail_action) + @property + def unfinishable(self) -> bool: + """ Check whether the quest is in an unfinishable state. """ + return self._unfinishable - def update(self, action: Action) -> None: - """ Update the state of the quest after a given action was performed. + def update(self, action: Optional[Action] = None, state: Optional[State] = None) -> None: + """ Update quest progression given available information. + + Args: + action: Action potentially affecting the quest progression. + state: Current game state. + """ + if self.done: + return # Nothing to do, the quest is already done. + + if state is not None: + # Check if quest is completed. + if self._quest.win_action is not None: + self._completed = state.is_applicable(self._quest.win_action) + + # Check if quest has failed. + if self._quest.fail_action is not None: + self._failed = state.is_applicable(self._quest.fail_action) + + # Try compressing the winning policy given the new game state. + if self.compress_winning_policy(state): + return # A shorter winning policy has been found. + + if action is not None: + # Determine if we moved away from the goal or closer to it. + reverse_action = self._tree.remove(action) + if reverse_action is None: # Irreversible action. + self._unfinishable = True # Can't track quest anymore. + + self._winning_policy = self._tree.tolist() # Rebuild policy. + + def compress_winning_policy(self, state: State) -> bool: + """ Compress the winning policy given a game state. 
Args: - action: Action affecting the state of the quest. + state: Current game state. + + Returns: + Whether the winning policy was compressed or not. """ - # Determine if we moved away from the goal or closer to it. - if action in self._tree.leaves_values: - # The last action was meaningful for the quest. - self._tree.pop(action) - else: - # The last action must have moved us away from the goal. - # We need to reverse it. - reverse_action = get_reverse_action(action) - if reverse_action is None: - # Irreversible action. - self._tree = None # Can't track quest anymore. - else: - self._tree.push(reverse_action) - - self._rebuild_policy() - - def compress_winning_policy(self, state: State): - for j in range(0, len(self._winning_policy)): - for i in range(j + 1, len(self._winning_policy))[::-1]: - if state.is_sequence_applicable(self._winning_policy[:j] + self._winning_policy[i:]): - for action in self._winning_policy[:i]: - self._tree.pop(action) - - for action in self._winning_policy[:j][::-1]: - self._tree.push(action) - - self._rebuild_policy() - return True - - return False + + def _find_shorter_policy(policy): + for j in range(0, len(policy)): + for i in range(j + 1, len(policy))[::-1]: + shorter_policy = policy[:j] + policy[i:] + if state.is_sequence_applicable(shorter_policy): + self._tree = ActionDependencyTree(element_type=ActionDependencyTreeElement) + for action in shorter_policy[::-1]: + self._tree.push(action) + + return shorter_policy + + return None + + compressed = False + policy = _find_shorter_policy(self._winning_policy) + while policy is not None: + compressed = True + self._winning_policy = policy + policy = _find_shorter_policy(policy) + + return compressed class GameProgression: """ GameProgression keeps track of the progression of a game. - If `tracking_quest` is True, then `winning_policy` will be the list + If `tracking_quests` is True, then `winning_policy` will be the list of Action that need to be applied in order to complete the game. 
""" - def __init__(self, game: Game, track_quest: bool = True) -> None: + def __init__(self, game: Game, track_quests: bool = True) -> None: """ Args: - game: The gaquest_progressionogression of. - track_quest:quest_progressionould track the quest completion. + game: The game for which to track progression. + track_quests: whether quest progressions are being tracked. """ self.game = game self.state = game.state.copy() self._valid_actions = list(self.state.all_applicable_actions(self.game._rules.values(), self.game._types.constants_mapping)) - self.quest_progressions = None - if track_quest and len(game.quests) > 0: + + self.quest_progressions = [] + if track_quests: self.quest_progressions = [QuestProgression(quest) for quest in game.quests] for quest_progression in self.quest_progressions: - while quest_progression.compress_winning_policy(self.state): - pass + quest_progression.update(action=None, state=self.state) @property def done(self) -> bool: - """ Whether the quest is completed or has failed. """ - if self.quest_progressions is None: - return False + """ Whether all quests are completed or at least one has failed or is unfinishable. """ + if not self.tracking_quests: + return False # There is nothing to be "done". all_completed = True for quest_progression in self.quest_progressions: - if quest_progression.has_failed(self.state): + if quest_progression.failed or quest_progression.unfinishable: return True - all_completed &= quest_progression.done + all_completed &= quest_progression.completed return all_completed @property - def tracking_quest(self) -> bool: - """ Whether the quest is tracked or not. """ - return self.quest_progressions is not None + def tracking_quests(self) -> bool: + """ Whether quests are being tracked or not. """ + return len(self.quest_progressions) > 0 @property def valid_actions(self) -> List[Action]: @@ -674,21 +708,20 @@ def winning_policy(self) -> Optional[List[Action]]: Returns: A policy that leads to winning the game. 
It can be `None` - if `tracking_quest` is `False` or the quest has failed. + if `tracking_quests` is `False` or the quest has failed. """ - if not self.tracking_quest: + if not self.tracking_quests: return None - + # Check if any quest has failed. if any(quest_progression.winning_policy is None for quest_progression in self.quest_progressions): return None - + # Greedily build a new winning policy by merging all individual quests' tree. trees = [qp._tree for qp in self.quest_progressions if not qp.done] master_quest_tree = ActionDependencyTree(element_type=ActionDependencyTreeElement, trees=trees) - - # print(master_quest_tree) + winning_policy = master_quest_tree.tolist() return [a for a in winning_policy if a.name != "win"] @@ -705,15 +738,6 @@ def update(self, action: Action) -> None: self._valid_actions = list(self.state.all_applicable_actions(self.game._rules.values(), self.game._types.constants_mapping)) - if self.tracking_quest: - if self.state.is_sequence_applicable(self.winning_policy): - pass # The last action didn't impact the quest. - else: - for quest_progression in self.quest_progressions: - if quest_progression.done: - continue - - # Try compressing the winning policy for the quest, - # otherwise update its progression given the new action. - if not quest_progression.compress_winning_policy(self.state): - quest_progression.update(action) + # Update all quest progressions given the last action and new state. 
+ for quest_progression in self.quest_progressions: + quest_progression.update(action, self.state) diff --git a/textworld/generator/tests/test_dependency_tree.py b/textworld/generator/tests/test_dependency_tree.py index 6e05f1e1..02c4f667 100644 --- a/textworld/generator/tests/test_dependency_tree.py +++ b/textworld/generator/tests/test_dependency_tree.py @@ -25,21 +25,23 @@ def depends_on(self, other): class TestDependencyTree(unittest.TestCase): - def test_pop(self): + def test_remove(self): tree = DependencyTree(element_type=CustomDependencyTreeElement) assert len(tree.roots) == 0 - tree.push("G") - tree.pop("G") + assert tree.push("G") + assert tree.remove("G") assert len(tree.roots) == 0 + assert list(tree) == [] + assert tree.values == [] - tree.push("G") - tree.push("F") + assert tree.push("G") + assert tree.push("F") # Can't pop a non-leaf element. - assert_raises(ValueError, tree.pop, "G") + assert not tree.remove("G") assert len(tree.roots) > 0 assert set(tree.leaves_values) == set("F") - tree.pop("F") + assert tree.remove("F") assert set(tree.leaves_values) == set("G") def test_push(self): @@ -69,16 +71,17 @@ def test_push(self): tree_ = tree.copy() tree.push("E") - assert tree_.tolist() != tree.tolist() + assert tree_.values != tree.values assert set(tree.leaves_values) == set(["E", "C"]) # Add the same element twice at the same level doesn't change the tree. tree_ = tree.copy() tree.push("E") - assert tree_.tolist() == tree.tolist() + assert tree_.values == tree.values assert set(tree.leaves_values) == set(["E", "C"]) - # Cannot remove a value that hasn't been added to the tree. - assert_raises(ValueError, tree.pop, "Z") + # Removing a value not associated with a leaf does not change the tree. 
+ assert not tree.remove("Z") + assert tree_.values == tree.values tree.push("A") assert set(tree.leaves_values) == set(["A", "C"]) diff --git a/textworld/generator/tests/test_game.py b/textworld/generator/tests/test_game.py index edd9ee44..f0be0d6c 100644 --- a/textworld/generator/tests/test_game.py +++ b/textworld/generator/tests/test_game.py @@ -18,17 +18,18 @@ from textworld.generator import make_small_map, make_grammar, make_game_with from textworld.generator.chaining import ChainingOptions, sample_quest +from textworld.logic import Action, State from textworld.generator.game import Quest, Game from textworld.generator.game import QuestProgression, GameProgression from textworld.generator.game import UnderspecifiedQuestError +from textworld.generator.game import ActionDependencyTree, ActionDependencyTreeElement from textworld.generator.inform7 import gen_commands_from_actions from textworld.logic import Proposition - -def _apply_command(command: str, game_progression: GameProgression): +def _apply_command(command: str, game_progression: GameProgression) -> None: """ Apply a text command to a game_progression object. """ valid_commands = gen_commands_from_actions(game_progression.valid_actions, game_progression.game.infos) @@ -62,8 +63,6 @@ def test_game_comparison(): assert game1 != game3 - - def test_variable_infos(verbose=False): g_rng.set_seed(1234) grammar_flags = {"theme": "house", "include_adj": True} @@ -256,9 +255,6 @@ class TestQuestProgression(unittest.TestCase): def setUpClass(cls): M = GameMaker() - # The goal - commands = ["open wooden door", "go west", "take carrot", "go east", "drop carrot"] - # Create a 'bedroom' room. 
R1 = M.new_room("bedroom") R2 = M.new_room("kitchen") @@ -276,14 +272,39 @@ def setUpClass(cls): chest.add_property("open") R2.add(chest) + # The goal + commands = ["open wooden door", "go west", "take carrot", "go east", "drop carrot"] cls.quest = M.set_quest_from_commands(commands) + commands = ["open wooden door", "go west", "take carrot", "eat carrot"] + cls.eating_quest = M.new_quest_using_commands(commands) + + cls.game = M.build() + def test_winning_policy(self): quest = QuestProgression(self.quest) assert quest.winning_policy == self.quest.actions - quest.update(self.quest.actions[0]) + quest.update(self.quest.actions[0], state=State()) assert quest.winning_policy == self.quest.actions[1:] + def test_failing_quest(self): + quest = QuestProgression(self.quest) + + state = self.game.state.copy() + for action in self.eating_quest.actions: + state.apply(action) + if action.name.startswith("eat"): + quest.update(action, state) + assert len(quest.winning_policy) == 0 + assert quest.done + assert not quest.completed + assert not quest.failed + assert quest.unfinishable + break + + assert not quest.done + assert len(quest.winning_policy) > 0 + class TestGameProgression(unittest.TestCase): @@ -362,7 +383,6 @@ def test_cycle_in_winning_policy(self): assert winning_commands == expected_commands, "{} != {}".format(winning_commands, expected_commands) _apply_command("go east", game_progression) - _apply_command("go north", game_progression) expected_commands = ["go south", "go west", "go north"] + commands winning_commands = gen_commands_from_actions(game_progression.winning_policy, game.infos) @@ -375,6 +395,7 @@ def test_cycle_in_winning_policy(self): # Quest where player's has to pick up the carrot first. 
commands = ["go east", "take apple", "go west", "go north", "drop apple"] + M.set_quest_from_commands(commands) game = M.build() game_progression = GameProgression(game) @@ -398,17 +419,6 @@ def test_game_with_multiple_quests(self): M = GameMaker() # The subgoals (needs to be executed in order). - # commands = [["open wooden door"], - # ["go west", "take carrot"], - # ["go east", "drop carrot"], - # # Now, the player is back in the kitchen and the wooden door is open. - # ["go west", "take lettuce"], - # ["go east", "drop lettuce"], - # # Now, the player is back in the kitchen, there are a carrot and a lettuce on the floor. - # ["take lettuce"], - # ["take carrot"], - # ["insert carrot into chest"], - # ["insert lettuce into chest", "close chest"]] commands = [["open wooden door", "go west", "take carrot", "go east", "drop carrot"], # Now, the player is back in the kitchen and the wooden door is open. ["go west", "take lettuce", "go east", "drop lettuce"], @@ -441,7 +451,6 @@ def test_game_with_multiple_quests(self): M.new_fact("in", carrot, chest), M.new_fact("closed", chest),] quest3 = M.new_quest_using_commands(commands[0] + commands[1] + commands[2], winning_facts=winning_facts) - # quest3 = M.new_quest_using_commands(commands[0] + commands[1] + commands[2]) quest3.desc = "Put the lettuce and the carrot into the chest before closing it." M._quests = [quest1, quest2, quest3] @@ -451,29 +460,21 @@ def test_game_with_multiple_quests(self): game_progress = GameProgression(game) assert len(game_progress.quest_progressions) == len(game.quests) - # for i, qp in enumerate(game_progress.quest_progressions): - # print() - # print(qp._tree) - # print(i, [c.name for c in qp.winning_policy]) - - # print() - # print([c.name for c in game_progress.winning_policy]) - # Following the actions associated to the last quest actually corresponds # to solving the whole game. 
for action in game_progress.winning_policy: assert not game_progress.done game_progress.update(action) - # print(action.name, [c.name for c in game_progress.winning_policy]) assert game_progress.done assert all(quest_progression.done for quest_progression in game_progress.quest_progressions) - + # Try solving the game by greedily taking the first action from the current winning policy. game_progress = GameProgression(game) while not game_progress.done: action = game_progress.winning_policy[0] game_progress.update(action) + # print(action.name, [c.name for c in game_progress.winning_policy]) # Try solving the second quest (i.e. bringing back the lettuce) first. game_progress = GameProgression(game) @@ -494,6 +495,15 @@ def test_game_with_multiple_quests(self): assert game_progress.done + # Game is done whenever a quest has failed or is unfinishable. + game_progress = GameProgression(game) + + for command in ["open wooden door", "go west", "take carrot", "eat carrot"]: + assert not game_progress.done + _apply_command(command, game_progress) + + assert game_progress.done + def test_game_without_a_quest(self): M = GameMaker() @@ -510,3 +520,48 @@ def test_game_without_a_quest(self): action = game_progress.valid_actions[0] game_progress.update(action) assert not game_progress.done + + +class TestActionDependencyTree(unittest.TestCase): + + @classmethod + def setUpClass(cls): + pass + + def test_tolist(self): + action_lock = Action.parse("lock/c :: $at(P, r) & $at(c, r) & $match(k, c) & $in(k, I) & closed(c) -> locked(c)") + action_close = Action.parse("close/c :: $at(P, r) & $at(c, r) & open(c) -> closed(c)") + action_open = Action.parse("open/c :: $at(P, r) & $at(c, r) & closed(c) -> open(c)") + action_insert1 = Action.parse("insert :: $at(P, r) & $at(c, r) & $open(c) & in(o1: o, I) -> in(o1: o, c)") + action_insert2 = Action.parse("insert :: $at(P, r) & $at(c, r) & $open(c) & in(o2: o, I) -> in(o2: o, c)") + action_take1 = Action.parse("take :: $at(P, r) & at(o1: o, 
r) -> in(o1: o, I)") + action_take2 = Action.parse("take :: $at(P, r) & at(o2: o, r) -> in(o2: o, I)") + action_win = Action.parse("win :: $in(o1: o, c) & $in(o2: o, c) & $locked(c) -> win(o1: o, o2: o, c)") + + tree = ActionDependencyTree(element_type=ActionDependencyTreeElement) + tree.push(action_win) + tree.push(action_lock) + tree.push(action_close) + tree.push(action_insert1) + tree.push(action_insert2) + tree.push(action_take1) + tree.push(action_take2) + tree.push(action_open) + tree.push(action_close) + tree.compress() + actions = [a.name for a in tree.tolist()] + assert actions == ['take', 'insert', 'take', 'insert', 'close/c', 'lock/c', 'win'], actions + + + def test_compress(self): + tree = ActionDependencyTree(element_type=ActionDependencyTreeElement) + + action_close = Action.parse("close/d :: $at(P, r) & $at(d, r) & open(d) -> closed(d)") + action_open = Action.parse("open/d :: $at(P, r) & $at(d, r) & closed(d) -> open(d)") + tree.push(action_close) + tree.push(action_open) + tree.push(action_close) + assert len(list(tree)) == 3 + tree.compress() + assert len(list(tree)) == 1 + diff --git a/textworld/render/render.py b/textworld/render/render.py index 58e63ab9..458e77c5 100644 --- a/textworld/render/render.py +++ b/textworld/render/render.py @@ -215,7 +215,6 @@ def used_pos(): edges.append((room.name, target.name, room.doors.get(exit))) # temp_viz(nodes, edges, pos, color=[world.player_room.name]) - pos = {game_infos[k].name: v for k, v in pos.items()} rooms = {} player_room = world.player_room @@ -226,6 +225,8 @@ def used_pos(): if v.name is None: v.name = k + pos = {game_infos[k].name: v for k, v in pos.items()} + for room in world.rooms: rooms[room.id] = GraphRoom(game_infos[room.id].name, room) From c6655010ed418d1961c529f9359aa29479ae385f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 19 Sep 2018 15:50:41 -0400 Subject: [PATCH 9/9] Refactored text_generation --- tests/test_make_game.py | 1 + 
tests/test_play_generated_games.py | 2 +- textworld/agents/walkthrough.py | 8 +- textworld/generator/game.py | 97 ++++++++++------- textworld/generator/tests/test_game.py | 26 ++--- .../generator/tests/test_text_generation.py | 2 + .../generator/tests/test_text_grammar.py | 10 +- textworld/generator/text_generation.py | 100 +++++++++--------- textworld/generator/text_grammar.py | 46 +++++--- 9 files changed, 174 insertions(+), 118 deletions(-) diff --git a/tests/test_make_game.py b/tests/test_make_game.py index b8b9be69..05b661a8 100644 --- a/tests/test_make_game.py +++ b/tests/test_make_game.py @@ -18,6 +18,7 @@ def test_making_game_with_names_to_exclude(): game_file2, game2 = textworld.make(2, 20, 3, 3, {"names_to_exclude": game1_objects_names}, seed=123, games_dir=tmpdir) game2_objects_names = [info.name for info in game2.infos.values() if info.name is not None] + game2.grammar.flags.encode() assert len(set(game1_objects_names) & set(game2_objects_names)) == 0 diff --git a/tests/test_play_generated_games.py b/tests/test_play_generated_games.py index a70ac2a7..0fd050c5 100644 --- a/tests/test_play_generated_games.py +++ b/tests/test_play_generated_games.py @@ -51,7 +51,7 @@ def test_play_generated_games(): msg = "Finished before playing `max_steps` steps because of command '{}'.".format(command) if game_state.has_won: msg += " (winning)" - assert game_state._game_progression.winning_policy == [] + assert len(game_state._game_progression.winning_policy) == 0 if game_state.has_lost: msg += " (losing)" diff --git a/textworld/agents/walkthrough.py b/textworld/agents/walkthrough.py index 319ca197..9b214d8f 100644 --- a/textworld/agents/walkthrough.py +++ b/textworld/agents/walkthrough.py @@ -26,13 +26,7 @@ def reset(self, env): raise NameError(msg) # Load command from the generated game. 
- from textworld.generator.game import GameProgression, Quest - from textworld.generator.inform7 import gen_commands_from_actions - - game_progression = GameProgression(env.game) - main_quest = Quest(actions=game_progression.winning_policy) - commands = gen_commands_from_actions(main_quest.actions, env.game.infos) - self._commands = iter(commands) + self._commands = iter(env.game.main_quest.commands) def act(self, game_state, reward, done): try: diff --git a/textworld/generator/game.py b/textworld/generator/game.py index 6d8381d9..41d01824 100644 --- a/textworld/generator/game.py +++ b/textworld/generator/game.py @@ -4,7 +4,7 @@ import json -from typing import List, Dict, Optional, Mapping, Any +from typing import List, Dict, Optional, Mapping, Any, Iterable from collections import OrderedDict from textworld.generator import data @@ -32,6 +32,23 @@ def __init__(self): super().__init__(msg) +def gen_commands_from_actions(actions): + def _get_name_mapping(action): + mapping = data.get_rules()[action.name].match(action) + return {ph.name: var.name for ph, var in mapping.items()} + + commands = [] + for action in actions: + command = "None" + if action is not None: + command = data.INFORM7_COMMANDS[action.name] + command = command.format(**_get_name_mapping(action)) + + commands.append(command) + + return commands + + class Quest: """ Quest presentation in TextWorld. @@ -39,14 +56,15 @@ class Quest: undertaken with a goal. """ - def __init__(self, actions: Optional[List[Action]], + def __init__(self, actions: Optional[Iterable[Action]] = None, winning_conditions: Optional[Collection[Proposition]] = None, failing_conditions: Optional[Collection[Proposition]] = None, - desc: str = "") -> None: + desc: Optional[str] = None) -> None: """ Args: actions: The actions to be performed to complete the quest. - If `None`, then `winning_conditions` must be provided. + If `None` or an empty list, then `winning_conditions` + must be provided. 
winning_conditions: Set of propositions that need to be true before marking the quest as completed. Default: postconditions of the last action. @@ -55,9 +73,9 @@ def __init__(self, actions: Optional[List[Action]], Default: can't fail the quest. desc: A text description of the quest. """ - self.actions = actions + self.actions = tuple(actions) if actions else () self.desc = desc - self.commands = [] + self.commands = gen_commands_from_actions(self.actions) self.reward = 1 self.win_action = self.set_winning_conditions(winning_conditions) self.fail_action = self.set_failing_conditions(failing_conditions) @@ -73,7 +91,7 @@ def set_winning_conditions(self, winning_conditions: Optional[Collection[Proposi An action that is only applicable when the quest is finished. """ if winning_conditions is None: - if self.actions is None: + if len(self.actions) == 0: raise UnderspecifiedQuestError() # The default winning conditions are the postconditions of the @@ -103,7 +121,7 @@ def set_failing_conditions(self, failing_conditions: Optional[Collection[Proposi return self.fail_action def __hash__(self) -> int: - return hash((tuple(self.actions), + return hash((self.actions, self.win_action, self.fail_action, self.desc, @@ -246,16 +264,26 @@ def __init__(self, world: World, grammar: Optional[Grammar] = None, """ self.world = world self.state = world.state.copy() # Current state of the game. 
- self.grammar = grammar self.quests = [] if quests is None else quests self.metadata = {} self._objective = None self._infos = self._build_infos() self._rules = data.get_rules() self._types = data.get_types() - # TODO: - # self.change_names() - # self.change_descriptions() + self.change_grammar(grammar) + + self._main_quest = None + + @property + def main_quest(self): + if self._main_quest is None: + from textworld.generator import inform7 + from textworld.generator.text_generation import assign_description_to_quest + self._main_quest = Quest(actions=GameProgression(self).winning_policy) + self._main_quest.desc = assign_description_to_quest(self._main_quest, self, self.grammar) + self._main_quest.commands = inform7.gen_commands_from_actions(self._main_quest.actions, self.infos) + + return self._main_quest @property def infos(self) -> Dict[str, EntityInfo]: @@ -285,16 +313,15 @@ def change_grammar(self, grammar: Grammar) -> None: from textworld.generator import inform7 from textworld.generator.text_generation import generate_text_from_grammar self.grammar = grammar + if self.grammar is None: + return + generate_text_from_grammar(self, self.grammar) for quest in self.quests: # TODO: should have a generic way of generating text commands from actions - # insteaf of relying on inform7 convention. + # instead of relying on inform7 convention. quest.commands = inform7.gen_commands_from_actions(quest.actions, self.infos) - # TODO - # self.change_names() - # self.change_descriptions() - def save(self, filename: str) -> None: """ Saves the serialized data of this game to a file. """ with open(filename, 'w') as f: @@ -315,11 +342,11 @@ def deserialize(cls, data: Mapping) -> "Game": `Game` object. 
""" world = World.deserialize(data["world"]) - grammar = None + game = cls(world) if "grammar" in data: - grammar = Grammar(data["grammar"]) - quests = [Quest.deserialize(d) for d in data["quests"]] - game = cls(world, grammar, quests) + game.grammar = Grammar(data["grammar"]) + + game.quests = [Quest.deserialize(d) for d in data["quests"]] game._infos = {k: EntityInfo.deserialize(v) for k, v in data["infos"]} game.state = State.deserialize(data["state"]) @@ -341,7 +368,7 @@ def serialize(self) -> Mapping: data["world"] = self.world.serialize() data["state"] = self.state.serialize() if self.grammar is not None: - data["grammar"] = self.grammar.flags + data["grammar"] = self.grammar.flags.serialize() data["quests"] = [quest.serialize() for quest in self.quests] data["infos"] = [(k, v.serialize()) for k, v in self._infos.items()] data["rules"] = [(k, v.serialize()) for k, v in self._rules.items()] @@ -415,8 +442,8 @@ def objective(self) -> str: if len(self.quests) == 0: return "" - # We assume the last quest includes all actions needed to solve the game. - return self.quests[-1].desc + self._objective = self.main_quest.desc + return self._objective @objective.setter def objective(self, value: str): @@ -503,10 +530,11 @@ def remove(self, action: Action) -> Optional[Action]: return reverse_action - def tolist(self) -> Optional[List[Action]]: - """ Builds a list with the actions contained in this dependency tree. + def flatten(self) -> Iterable[Action]: + """ + Generates a flatten representation of this dependency tree. - The list is greedily built by iteratively popping leaves from + Actions are greedily yielded by iteratively popping leaves from the dependency tree. """ tree = self.copy() # Make a copy of the tree to work on. @@ -519,11 +547,9 @@ def tolist(self) -> Optional[List[Action]]: if leaf.action != last_reverse_action: break # Choose an action that avoids cycles. 
- actions.append(leaf.action) + yield leaf.action last_reverse_action = tree.remove(leaf.action) - return actions - def compress(self): for node in self: if node.parent is None: @@ -559,7 +585,7 @@ def __init__(self, quest: Quest) -> None: for action in quest.actions[::-1]: self._tree.push(action) - self._winning_policy = quest.actions + [quest.win_action] + self._winning_policy = quest.actions + (quest.win_action,) @property def winning_policy(self) -> List[Action]: @@ -618,7 +644,7 @@ def update(self, action: Optional[Action] = None, state: Optional[State] = None) if reverse_action is None: # Irreversible action. self._unfinishable = True # Can't track quest anymore. - self._winning_policy = self._tree.tolist() # Rebuild policy. + self._winning_policy = tuple(self._tree.flatten()) # Rebuild policy. def compress_winning_policy(self, state: State) -> bool: """ Compress the winning policy given a game state. @@ -714,16 +740,15 @@ def winning_policy(self) -> Optional[List[Action]]: return None # Check if any quest has failed. - if any(quest_progression.winning_policy is None for quest_progression in self.quest_progressions): + if any(quest.failed or quest.unfinishable for quest in self.quest_progressions): return None # Greedily build a new winning policy by merging all individual quests' tree. - trees = [qp._tree for qp in self.quest_progressions if not qp.done] + trees = [quest._tree for quest in self.quest_progressions if not quest.done] master_quest_tree = ActionDependencyTree(element_type=ActionDependencyTreeElement, trees=trees) - winning_policy = master_quest_tree.tolist() - return [a for a in winning_policy if a.name != "win"] + return tuple(a for a in master_quest_tree.flatten() if a.name != "win") def update(self, action: Action) -> None: """ Update the state of the game given the provided action. 
diff --git a/textworld/generator/tests/test_game.py b/textworld/generator/tests/test_game.py index f0be0d6c..d732a3e5 100644 --- a/textworld/generator/tests/test_game.py +++ b/textworld/generator/tests/test_game.py @@ -119,12 +119,12 @@ def test_quest_creation(self): assert quest.win_action.preconditions == self.quest.actions[-1].postconditions assert quest.fail_action is None - quest = Quest(actions=None, winning_conditions=self.quest.actions[-1].postconditions) - assert quest.actions is None + quest = Quest(winning_conditions=self.quest.actions[-1].postconditions) + assert len(quest.actions) == 0 assert quest.win_action == self.quest.win_action assert quest.fail_action is None - npt.assert_raises(UnderspecifiedQuestError, Quest, actions=None, winning_conditions=None) + npt.assert_raises(UnderspecifiedQuestError, Quest, actions=[], winning_conditions=None) quest = Quest(self.quest.actions, failing_conditions=self.failing_conditions) assert quest.fail_action == self.quest.fail_action @@ -285,7 +285,7 @@ def test_winning_policy(self): quest = QuestProgression(self.quest) assert quest.winning_policy == self.quest.actions quest.update(self.quest.actions[0], state=State()) - assert quest.winning_policy == self.quest.actions[1:] + assert tuple(quest.winning_policy) == self.quest.actions[1:] def test_failing_quest(self): quest = QuestProgression(self.quest) @@ -395,7 +395,7 @@ def test_cycle_in_winning_policy(self): # Quest where player's has to pick up the carrot first. commands = ["go east", "take apple", "go west", "go north", "drop apple"] - + M.set_quest_from_commands(commands) game = M.build() game_progression = GameProgression(game) @@ -452,7 +452,7 @@ def test_game_with_multiple_quests(self): M.new_fact("closed", chest),] quest3 = M.new_quest_using_commands(commands[0] + commands[1] + commands[2], winning_facts=winning_facts) quest3.desc = "Put the lettuce and the carrot into the chest before closing it." 
- + M._quests = [quest1, quest2, quest3] assert len(M._quests) == len(commands) game = M.build() @@ -461,14 +461,14 @@ def test_game_with_multiple_quests(self): assert len(game_progress.quest_progressions) == len(game.quests) # Following the actions associated to the last quest actually corresponds - # to solving the whole game. + # to solving the whole game. for action in game_progress.winning_policy: assert not game_progress.done game_progress.update(action) - + assert game_progress.done assert all(quest_progression.done for quest_progression in game_progress.quest_progressions) - + # Try solving the game by greedily taking the first action from the current winning policy. game_progress = GameProgression(game) while not game_progress.done: @@ -480,7 +480,7 @@ def test_game_with_multiple_quests(self): game_progress = GameProgression(game) for command in ["open wooden door", "go west", "take lettuce", "go east", "drop lettuce"]: _apply_command(command, game_progress) - + assert not game_progress.quest_progressions[0].done assert game_progress.quest_progressions[1].done @@ -497,7 +497,7 @@ def test_game_with_multiple_quests(self): # Game is done whenever a quest has failed or is unfinishable. 
game_progress = GameProgression(game) - + for command in ["open wooden door", "go west", "take carrot", "eat carrot"]: assert not game_progress.done _apply_command(command, game_progress) @@ -537,7 +537,7 @@ def test_tolist(self): action_take1 = Action.parse("take :: $at(P, r) & at(o1: o, r) -> in(o1: o, I)") action_take2 = Action.parse("take :: $at(P, r) & at(o2: o, r) -> in(o2: o, I)") action_win = Action.parse("win :: $in(o1: o, c) & $in(o2: o, c) & $locked(c) -> win(o1: o, o2: o, c)") - + tree = ActionDependencyTree(element_type=ActionDependencyTreeElement) tree.push(action_win) tree.push(action_lock) @@ -549,7 +549,7 @@ def test_tolist(self): tree.push(action_open) tree.push(action_close) tree.compress() - actions = [a.name for a in tree.tolist()] + actions = list(a.name for a in tree.flatten()) assert actions == ['take', 'insert', 'take', 'insert', 'close/c', 'lock/c', 'win'], actions diff --git a/textworld/generator/tests/test_text_generation.py b/textworld/generator/tests/test_text_generation.py index 1a34c4c5..ef5f7fee 100644 --- a/textworld/generator/tests/test_text_generation.py +++ b/textworld/generator/tests/test_text_generation.py @@ -77,8 +77,10 @@ def test_blend_instructions(verbose=False): grammar2 = textworld.generator.make_grammar(flags={"blend_instructions": True}, rng=np.random.RandomState(42)) + quest.desc = None game.change_grammar(grammar1) quest1 = quest.copy() + quest.desc = None game.change_grammar(grammar2) quest2 = quest.copy() assert len(quest1.desc) > len(quest2.desc) diff --git a/textworld/generator/tests/test_text_grammar.py b/textworld/generator/tests/test_text_grammar.py index 3942727f..d78f2207 100644 --- a/textworld/generator/tests/test_text_grammar.py +++ b/textworld/generator/tests/test_text_grammar.py @@ -5,6 +5,7 @@ import unittest from textworld.generator.text_grammar import Grammar +from textworld.generator.text_grammar import GrammarFlags class ContainsEveryObjectContainer: @@ -12,6 +13,13 @@ def __contains__(self, item): 
return True +class TestGrammarFlags(unittest.TestCase): + def test_serialization(self): + flags = GrammarFlags() + data = flags.serialize() + flags2 = GrammarFlags.deserialize(data) + assert flags == flags2 + class GrammarTest(unittest.TestCase): def test_grammar_eq(self): grammar = Grammar() @@ -20,7 +28,7 @@ def test_grammar_eq(self): def test_grammar_eq2(self): grammar = Grammar() - grammar2 = Grammar(flags={'unused': 'flag'}) + grammar2 = Grammar(flags={'theme': 'something'}) self.assertNotEqual(grammar, grammar2, "Testing two different grammar files are not equal") def test_grammar_get_random_expansion_fail(self): diff --git a/textworld/generator/text_generation.py b/textworld/generator/text_generation.py index 6d594435..2ba3e81b 100644 --- a/textworld/generator/text_generation.py +++ b/textworld/generator/text_generation.py @@ -22,7 +22,7 @@ def __getitem__(self, item): return super().__getitem__(item) -def assign_new_matching_names(obj1_infos, obj2_infos, grammar, include_adj, exclude=[]): +def assign_new_matching_names(obj1_infos, obj2_infos, grammar, exclude=[]): tag = "#({}<->{})_match#".format(obj1_infos.type, obj2_infos.type) if not grammar.has_tag(tag): return False @@ -32,8 +32,8 @@ def assign_new_matching_names(obj1_infos, obj2_infos, grammar, include_adj, excl result = grammar.expand(tag) first, second = result.split("<->") # Matching arguments are separated by '<->'. 
- name1, adj1, noun1 = grammar.split_name_adj_noun(first.strip(), include_adj) - name2, adj2, noun2 = grammar.split_name_adj_noun(second.strip(), include_adj) + name1, adj1, noun1 = grammar.split_name_adj_noun(first.strip(), grammar.flags.include_adj) + name2, adj2, noun2 = grammar.split_name_adj_noun(second.strip(), grammar.flags.include_adj) if name1 not in exclude and name2 not in exclude and name1 != name2: found_matching_names = True break @@ -53,7 +53,7 @@ def assign_new_matching_names(obj1_infos, obj2_infos, grammar, include_adj, excl return True -def assign_name_to_object(obj, grammar, game_infos, include_adj): +def assign_name_to_object(obj, grammar, game_infos): """ Assign a name to an object (if needed). """ @@ -67,20 +67,19 @@ def assign_name_to_object(obj, grammar, game_infos, include_adj): # Check if the object should match another one (i.e. same adjective). if obj.matching_entity_id is not None: other_obj_infos = game_infos[obj.matching_entity_id] - success = assign_new_matching_names(obj_infos, other_obj_infos, grammar, include_adj, exclude) + success = assign_new_matching_names(obj_infos, other_obj_infos, grammar, exclude) if success: return # Try swapping the objects around i.e. match(o2, o1). - success = assign_new_matching_names(other_obj_infos, obj_infos, grammar, include_adj, exclude) + success = assign_new_matching_names(other_obj_infos, obj_infos, grammar, exclude) if success: return # TODO: Should we enforce it? # Fall back on generating unmatching object name. 
- values = grammar.generate_name(obj.type, room_type=obj_infos.room_type, - include_adj=include_adj, exclude=exclude) + values = grammar.generate_name(obj.type, room_type=obj_infos.room_type, exclude=exclude) obj_infos.name, obj_infos.adj, obj_infos.noun = values grammar.used_names.add(obj_infos.name) @@ -104,19 +103,13 @@ def assign_description_to_object(obj, grammar, game_infos): def generate_text_from_grammar(game, grammar: Grammar): - include_adj = grammar.flags.get("include_adj", False) - only_last_action = grammar.flags.get("only_last_action", False) - blend_instructions = grammar.flags.get("blend_instructions", False) - blend_descriptions = grammar.flags.get("blend_descriptions", False) - ambiguous_instructions = grammar.flags.get("ambiguous_instructions", False) - # Assign a specific room type and name to our rooms for room in game.world.rooms: # First, generate a unique roomtype and name from the grammar if game.infos[room.id].room_type is None and grammar.has_tag("#room_type#"): game.infos[room.id].room_type = grammar.expand("#room_type#") - assign_name_to_object(room, grammar, game.infos, include_adj) + assign_name_to_object(room, grammar, game.infos) # Next, assure objects contained in a room must have the same room type for obj in game.world.get_all_objects_in(room): @@ -128,43 +121,35 @@ def generate_text_from_grammar(game, grammar: Grammar): if game.infos[obj.id].room_type is None and grammar.has_tag("#room_type#"): game.infos[obj.id].room_type = grammar.expand("#room_type#") - # We have to "count" all the adj/noun/types in the world - # This is important for using "unique" but abstracted references to objects - counts = OrderedDict() - counts["adj"] = CountOrderedDict() - counts["noun"] = CountOrderedDict() - counts["type"] = CountOrderedDict() - # Assign name and description to objects. 
for obj in game.world.objects: if obj.type in ["I", "P"]: continue - obj_infos = game.infos[obj.id] - assign_name_to_object(obj, grammar, game.infos, include_adj) + assign_name_to_object(obj, grammar, game.infos) assign_description_to_object(obj, grammar, game.infos) - counts['adj'][obj_infos.adj] += 1 - counts['noun'][obj_infos.noun] += 1 - counts['type'][obj.type] += 1 - # Generate the room descriptions. for room in game.world.rooms: - assign_description_to_room(room, game, grammar, blend_descriptions) + if game.infos[room.id].desc is None: # Skip rooms which already have a description. + game.infos[room.id].desc = assign_description_to_room(room, game, grammar) # Generate the instructions. for quest in game.quests: - assign_description_to_quest(quest, game, grammar, counts, only_last_action, blend_instructions, ambiguous_instructions) + if quest.desc is None: # Skip quests which already have a description. + quest.desc = assign_description_to_quest(quest, game, grammar) - if only_last_action and len(game.quests) > 1: + if grammar.flags.only_last_action and len(game.quests) > 1: main_quest = Quest(actions=[quest.actions[-1] for quest in game.quests]) - assign_description_to_quest(main_quest, game, grammar, counts, False, blend_instructions, ambiguous_instructions) - game.objective = main_quest.desc + only_last_action_bkp = grammar.flags.only_last_action + grammar.flags.only_last_action = False + game.objective = assign_description_to_quest(main_quest, game, grammar) + grammar.flags.only_last_action = only_last_action_bkp return game -def assign_description_to_room(room, game, grammar, blend_descriptions): +def assign_description_to_room(room, game, grammar): """ Assign a descripton to a room. 
""" @@ -194,7 +179,7 @@ def assign_description_to_room(room, game, grammar, blend_descriptions): obj_infos = game.infos[obj.id] adj, noun = obj_infos.adj, obj_infos.noun - if blend_descriptions: + if grammar.flags.blend_descriptions: found = False for type in ["noun", "adj"]: group_filt = [] @@ -254,7 +239,7 @@ def assign_description_to_room(room, game, grammar, blend_descriptions): exits_desc = [] # Describing exits with door. - if blend_descriptions and len(exits_with_closed_door) > 1: + if grammar.flags.blend_descriptions and len(exits_with_closed_door) > 1: dirs, door_objs = zip(*exits_with_closed_door) e_desc = grammar.expand("#room_desc_doors_closed#") e_desc = replace_num(e_desc, len(door_objs)) @@ -269,7 +254,7 @@ def assign_description_to_room(room, game, grammar, blend_descriptions): d_desc = d_desc.replace("(dir)", dir_) exits_desc.append(d_desc) - if blend_descriptions and len(exits_with_open_door) > 1: + if grammar.flags.blend_descriptions and len(exits_with_open_door) > 1: dirs, door_objs = zip(*exits_with_open_door) e_desc = grammar.expand("#room_desc_doors_open#") e_desc = replace_num(e_desc, len(door_objs)) @@ -285,7 +270,7 @@ def assign_description_to_room(room, game, grammar, blend_descriptions): exits_desc.append(d_desc) # Describing exits without door. 
- if blend_descriptions and len(exits_without_door) > 1: + if grammar.flags.blend_descriptions and len(exits_without_door) > 1: e_desc = grammar.expand("#room_desc_exits#").replace("(dir)", list_to_string(exits_without_door, False)) e_desc = repl_sing_plur(e_desc, len(exits_without_door)) exits_desc.append(e_desc) @@ -297,7 +282,7 @@ def assign_description_to_room(room, game, grammar, blend_descriptions): room_desc += " ".join(exits_desc) # Finally, set the description - game.infos[room.id].desc = fix_determinant(room_desc) + return fix_determinant(room_desc) class MergeAction: @@ -314,7 +299,7 @@ def __init__(self): self.end = None -def generate_instruction(action, grammar, game_infos, world, counts, ambiguous_instructions): +def generate_instruction(action, grammar, game_infos, world, counts): """ Generate text instruction for a specific action. """ @@ -366,7 +351,7 @@ def generate_instruction(action, grammar, game_infos, world, counts, ambiguous_i obj = world.find_object_by_id(var.name) obj_infos = game_infos[obj.id] - if ambiguous_instructions: + if grammar.flags.ambiguous_instructions: assert False, "not tested" choices = [] @@ -399,29 +384,46 @@ def generate_instruction(action, grammar, game_infos, world, counts, ambiguous_i return desc, separator -def assign_description_to_quest(quest, game, grammar, counts, only_last_action, blend_instructions, ambiguous_instructions): +def assign_description_to_quest(quest, game, grammar): """ Assign a descripton to a quest. """ + # We have to "count" all the adj/noun/types in the world + # This is important for using "unique" but abstracted references to objects + counts = OrderedDict() + counts["adj"] = CountOrderedDict() + counts["noun"] = CountOrderedDict() + counts["type"] = CountOrderedDict() + + # Assign name and description to objects. 
+ for obj in game.world.objects: + if obj.type in ["I", "P"]: + continue + + obj_infos = game.infos[obj.id] + counts['adj'][obj_infos.adj] += 1 + counts['noun'][obj_infos.noun] += 1 + counts['type'][obj.type] += 1 + if len(quest.actions) == 0: # We don't need to say anything if the quest is empty - quest.desc = "Choose your own adventure!" + quest_desc = "Choose your own adventure!" else: # Generate a description for either the last, or all commands - if only_last_action: - actions_desc, _ = generate_instruction(quest.actions[-1], grammar, game.infos, game.world, counts, ambiguous_instructions) + if grammar.flags.only_last_action: + actions_desc, _ = generate_instruction(quest.actions[-1], grammar, game.infos, game.world, counts) only_one_action = True else: actions_desc = "" # Decide if we blend instructions together or not - if blend_instructions: + if grammar.flags.blend_instructions: instructions = get_action_chains(quest.actions, grammar, game.infos) else: instructions = quest.actions only_one_action = len(instructions) < 2 for c in instructions: - desc, separator = generate_instruction(c, grammar, game.infos, game.world, counts, ambiguous_instructions) + desc, separator = generate_instruction(c, grammar, game.infos, game.world, counts) actions_desc += desc if c != instructions[-1] and len(separator) > 0: actions_desc += separator @@ -434,7 +436,9 @@ def assign_description_to_quest(quest, game, grammar, counts, only_last_action, quest_tag = grammar.get_random_expansion("#quest#") quest_tag = quest_tag.replace("(list_of_actions)", actions_desc.strip()) - quest.desc = grammar.expand(quest_tag) + quest_desc = grammar.expand(quest_tag) + + return quest_desc def get_action_chains(actions, grammar, game_infos): diff --git a/textworld/generator/text_grammar.py b/textworld/generator/text_grammar.py index 1417a060..9637e584 100644 --- a/textworld/generator/text_grammar.py +++ b/textworld/generator/text_grammar.py @@ -36,7 +36,7 @@ def fix_determinant(var): class 
GrammarFlags: - __slots__ = ['theme', 'include_adj', 'blend_descriptions', + __slots__ = ['theme', 'names_to_exclude', 'include_adj', 'blend_descriptions', 'ambiguous_instructions', 'only_last_action', 'blend_instructions', 'allowed_variables_numbering', 'unique_expansion'] @@ -45,6 +45,7 @@ def __init__(self, flags=None, **kwargs): flags = flags or kwargs self.theme = flags.get("theme", "house") + self.names_to_exclude = flags.get("names_to_exclude", []) self.allowed_variables_numbering = flags.get("allowed_variables_numbering", False) self.unique_expansion = flags.get("unique_expansion", False) self.include_adj = flags.get("include_adj", False) @@ -53,18 +54,37 @@ def __init__(self, flags=None, **kwargs): self.blend_descriptions = flags.get("blend_descriptions", False) self.ambiguous_instructions = flags.get("ambiguous_instructions", False) - def encode(self): + def serialize(self) -> Mapping: + return {slot: getattr(self, slot) for slot in self.__slots__} + + @classmethod + def deserialize(cls, data: Mapping) -> "GrammarFlags": + return cls(data) + + def __eq__(self, other) -> bool: + return (isinstance(other, GrammarFlags) and + all(getattr(self, slot) == getattr(other, slot) for slot in self.__slots__)) + + def encode(self) -> str: """ Generate UUID for this set of grammar flags. """ - values = [int(getattr(self, s)) for s in self.__slots__[1:]] + def _unsigned(n): + return n & 0xFFFFFFFFFFFFFFFF + + # Skip theme and names_to_exclude. 
+ values = [int(getattr(self, s)) for s in self.__slots__[2:]] flag = "".join(map(str, values)) from hashids import Hashids hashids = Hashids(salt="TextWorld") + if len(self.names_to_exclude) > 0: + names_to_exclude_hash = _unsigned(hash(frozenset(self.names_to_exclude))) + return self.theme + "-" + hashids.encode(names_to_exclude_hash) + "-" + hashids.encode(int(flag)) + return self.theme + "-" + hashids.encode(int(flag)) -def encode_flags(flags): +def encode_flags(flags: Mapping) -> str: return GrammarFlags(flags).encode() @@ -83,19 +103,19 @@ def __init__(self, flags: Mapping = {}, rng: Optional[RandomState] = None): :param rng: Random generator used for sampling tag expansions. """ - self.flags = flags + self.flags = GrammarFlags(flags) self.grammar = OrderedDict() self.rng = g_rng.next() if rng is None else rng - self.allowed_variables_numbering = self.flags.get("allowed_variables_numbering", False) - self.unique_expansion = self.flags.get("unique_expansion", False) + self.allowed_variables_numbering = self.flags.allowed_variables_numbering + self.unique_expansion = self.flags.unique_expansion self.all_expansions = defaultdict(list) # The current used symbols self.overflow_dict = OrderedDict() - self.used_names = set(self.flags.get("names_to_exclude", [])) + self.used_names = set(self.flags.names_to_exclude) # Load the grammar associated to the provided theme. 
- self.theme = self.flags.get("theme", "house") + self.theme = self.flags.theme grammar_contents = [] # Load the object names file @@ -111,7 +131,7 @@ def __eq__(self, other): return (isinstance(other, Grammar) and self.overflow_dict == other.overflow_dict and self.grammar == other.grammar and - self.flags == other.flags and + self.flags.encode() == other.flags.encode() and self.used_names == other.used_names) def _parse(self, lines: List[str]): @@ -174,7 +194,6 @@ def get_random_expansion(self, tag: str, rng: Optional[RandomState] = None) -> s self.all_expansions[tag].append(expansion) return expansion - def expand(self, text: str, rng: Optional[RandomState] = None) -> str: """ Expand some text until there is no more tag to expand. @@ -235,7 +254,7 @@ def split_name_adj_noun(self, candidate: str, include_adj: bool) -> Optional[Tup return name, adj, noun def generate_name(self, obj_type: str, room_type: str = "", - include_adj: bool = True, exclude: Container[str] = []) -> Tuple[str, str, str]: + include_adj: Optional[bool] = None, exclude: Container[str] = []) -> Tuple[str, str, str]: """ Generate a name given an object type and the type room it belongs to. @@ -248,6 +267,7 @@ def generate_name(self, obj_type: str, room_type: str = "", include_adj : optional If True, the name can contain a generated adjective. If False, any generated adjective will be discarded. + Default: use value grammar.flags.include_adj exclude : optional List of names we should avoid generating. @@ -260,6 +280,8 @@ def generate_name(self, obj_type: str, room_type: str = "", noun : The noun part of the name. """ + if include_adj is None: + include_adj = self.flags.include_adj # Get room-specialized name, if possible. symbol = "#{}_({})#".format(room_type, obj_type)