From 00ebb8099092ec35340b9f17f1434a95237a2c46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 19 Sep 2018 19:46:20 -0400 Subject: [PATCH 1/2] DependencyTree now support more than one root + GameProgression can track multiple quests. --- tests/test_play_generated_games.py | 4 +- textworld/envs/glulx/git_glulx_ml.py | 2 +- textworld/generator/dependency_tree.py | 184 ++++++----- textworld/generator/game.py | 293 +++++++++++------- textworld/generator/maker.py | 34 ++ .../generator/tests/test_dependency_tree.py | 62 ++-- textworld/generator/tests/test_game.py | 206 ++++++++++-- 7 files changed, 542 insertions(+), 243 deletions(-) diff --git a/tests/test_play_generated_games.py b/tests/test_play_generated_games.py index 96374eec..2be94426 100644 --- a/tests/test_play_generated_games.py +++ b/tests/test_play_generated_games.py @@ -47,10 +47,10 @@ def test_play_generated_games(): game_state, reward, done = env.step(command) if done: - msg = "Finished before playing `max_steps` steps." + msg = "Finished before playing `max_steps` steps because of command '{}'.".format(command) if game_state.has_won: msg += " (winning)" - assert game_state._game_progression.winning_policy == [] + assert len(game_state._game_progression.winning_policy) == 0 if game_state.has_lost: msg += " (losing)" diff --git a/textworld/envs/glulx/git_glulx_ml.py b/textworld/envs/glulx/git_glulx_ml.py index 5fab75c6..ac0572bc 100644 --- a/textworld/envs/glulx/git_glulx_ml.py +++ b/textworld/envs/glulx/git_glulx_ml.py @@ -146,7 +146,7 @@ def init(self, output: str, game=None, """ output = _strip_input_prompt_symbol(output) super().init(output) - self._game_progression = GameProgression(game, track_quest=compute_intermediate_reward) + self._game_progression = GameProgression(game, track_quests=compute_intermediate_reward) self._state_tracking = state_tracking self._compute_intermediate_reward = compute_intermediate_reward and len(game.quests) > 0 diff --git a/textworld/generator/dependency_tree.py b/textworld/generator/dependency_tree.py index dddb4da3..cf0dce1a 100644 --- a/textworld/generator/dependency_tree.py +++ b/textworld/generator/dependency_tree.py @@ -3,6 +3,7 @@ import textwrap +from typing import List, Any, Iterable from textworld.utils import uniquify @@ -18,42 +19,51 @@ class DependencyTreeElement: `__str__` accordingly. """ - def __init__(self, value): + def __init__(self, value: Any): self.value = value + self.parent = None - def depends_on(self, other): + def depends_on(self, other: "DependencyTreeElement") -> bool: """ Check whether this element depends on the `other`. """ return self.value > other.value - def is_distinct_from(self, others): + def is_distinct_from(self, others: Iterable["DependencyTreeElement"]) -> bool: """ Check whether this element is distinct from `others`. """ return self.value not in [other.value for other in others] - def __str__(self): + def __str__(self) -> str: return str(self.value) class DependencyTree: class _Node: - def __init__(self, element): + def __init__(self, element: DependencyTreeElement): self.element = element self.children = [] + self.parent = None - def push(self, node): + def push(self, node: "DependencyTree._Node") -> bool: if node == self: - return - + return True + + added = False for child in self.children: - child.push(node) + added |= child.push(node) if self.element.depends_on(node.element) and not self.already_added(node): + node = node.copy() self.children.append(node) + node.element.parent = self.element + node.parent = self + return True + + return added - def already_added(self, node): + def already_added(self, node: "DependencyTree._Node") -> bool: # We want to avoid duplicate information about dependencies. if node in self.children: return True @@ -63,14 +73,15 @@ def already_added(self, node): if not node.element.is_distinct_from((child.element for child in self.children)): return True - # for child in self.children: - # # if node.element.value == child.element.value: - # if not node.element.is_distinct_from((child.element): - # return True - return False + + def __iter__(self) -> Iterable["DependencyTree._Node"]: + for child in self.children: + yield from list(child) + + yield self - def __str__(self): + def __str__(self) -> str: node_text = str(self.element) txt = [node_text] @@ -79,85 +90,112 @@ def __str__(self): return "\n".join(txt) - def copy(self): + def copy(self) -> "DependencyTree._Node": node = DependencyTree._Node(self.element) - node.children = [child.copy() for child in self.children] + for child in self.children: + child_ = child.copy() + child_.parent = node + node.children.append(child_) + return node - def __init__(self, element_type=DependencyTreeElement): - self.root = None + def __init__(self, element_type: DependencyTreeElement = DependencyTreeElement, trees: Iterable["DependencyTree"] = []): + self.roots = [] self.element_type = element_type - self._update() - - def push(self, value): - element = self.element_type(value) - node = DependencyTree._Node(element) - if self.root is None: - self.root = node - else: - self.root.push(node) + for tree in trees: + self.roots += [root.copy() for root in tree.roots] - # Recompute leaves. self._update() - if element in self.leaves_elements: - return node - return None + def push(self, value: Any, allow_multi_root: bool = False) -> bool: + """ Add a value to this dependency tree. - def pop(self, value): + Adding a value already present in the tree does not modify the tree. + + Args: + value: value to add. + allow_multi_root: if `True`, allow the value to spawn an + additional root if needed. + + """ + element = self.element_type(value) + node = DependencyTree._Node(element) + + added = False + for root in self.roots: + added |= root.push(node) + + if len(self.roots) == 0 or (not added and allow_multi_root): + self.roots.append(node) + added = True + + self._update() # Recompute leaves. + return added + + def remove(self, value: Any) -> None: + """ Remove all leaves having the given value. + + The value to remove needs to belong to at least one leaf in this tree. + Otherwise, the tree remains unchanged. + + Args: + value: value to remove from the tree. + + Returns: + Whether the tree has changed or not. + """ if value not in self.leaves_values: - raise ValueError("That element is not a leaf: {!r}.".format(value)) - - def _visit(node): - for child in list(node.children): - if child.element.value == value: - node.children.remove(child) - - self._postorder(self.root, _visit) - if self.root.element.value == value: - self.root = None - - # Recompute leaves. - self._update() - - def _postorder(self, node, visit): - for child in node.children: - self._postorder(child, visit) - - visit(node) + return False - def _update(self): + root_to_remove = [] + for node in self: + if node.element.value == value: + if node.parent is not None: + node.parent.children.remove(node) + else: + root_to_remove.append(node) + + for node in root_to_remove: + self.roots.remove(node) + + self._update() # Recompute leaves. + return True + + def _update(self) -> None: self._leaves_values = [] - self._leaves_elements = set() + self._leaves_elements = [] - def _visit(node): + for node in self: if len(node.children) == 0: - self._leaves_elements.add(node.element) + self._leaves_elements.append(node.element) self._leaves_values.append(node.element.value) - if self.root is not None: - self._postorder(self.root, _visit) - self._leaves_values = uniquify(self._leaves_values) + self._leaves_elements = uniquify(self._leaves_elements) + + def copy(self) -> "DependencyTree": + tree = type(self)(element_type=self.element_type) + for root in self.roots: + tree.roots.append(root.copy()) + + tree._update() + return tree - def copy(self): - tree = DependencyTree(self.element_type) - if self.root is not None: - tree.root = self.root.copy() - tree._update() + def __iter__(self) -> Iterable["DependencyTree._Node"]: + for root in self.roots: + yield from list(root) - return tree + @property + def values(self) -> List[Any]: + return [node.element.value for node in self] @property - def leaves_elements(self): + def leaves_elements(self) -> List[DependencyTreeElement]: return self._leaves_elements @property - def leaves_values(self): + def leaves_values(self) -> List[Any]: return self._leaves_values - def __str__(self): - if self.root is None: - return "" - - return str(self.root) + def __str__(self) -> str: + return "\n".join(map(str, self.roots)) diff --git a/textworld/generator/game.py b/textworld/generator/game.py index 11c27e3f..01ff17cb 100644 --- a/textworld/generator/game.py +++ b/textworld/generator/game.py @@ -4,7 +4,7 @@ import json -from typing import List, Dict, Optional, Mapping, Any +from typing import List, Dict, Optional, Mapping, Any, Iterable from collections import OrderedDict from textworld.generator import data @@ -39,14 +39,15 @@ class Quest: undertaken with a goal. """ - def __init__(self, actions: Optional[List[Action]], + def __init__(self, actions: Optional[Iterable[Action]] = None, winning_conditions: Optional[Collection[Proposition]] = None, failing_conditions: Optional[Collection[Proposition]] = None, desc: str = "") -> None: """ Args: actions: The actions to be performed to complete the quest. - If `None`, then `winning_conditions` must be provided. + If `None` or an empty list, then `winning_conditions` + must be provided. winning_conditions: Set of propositions that need to be true before marking the quest as completed. Default: postconditions of the last action. @@ -55,9 +56,10 @@ def __init__(self, actions: Optional[List[Action]], Default: can't fail the quest. desc: A text description of the quest. """ - self.actions = actions + self.actions = tuple(actions) if actions else () self.desc = desc self.commands = [] + self.reward = 1 self.win_action = self.set_winning_conditions(winning_conditions) self.fail_action = self.set_failing_conditions(failing_conditions) @@ -72,14 +74,17 @@ def set_winning_conditions(self, winning_conditions: Optional[Collection[Proposi An action that is only applicable when the quest is finished. """ if winning_conditions is None: - if self.actions is None: + if len(self.actions) == 0: raise UnderspecifiedQuestError() # The default winning conditions are the postconditions of the # last action in the quest. winning_conditions = self.actions[-1].postconditions - self.win_action = Action("win", winning_conditions, [Proposition("win")]) + # TODO: Make win propositions distinguishable by adding arguments? + win_fact = Proposition("win") + self.win_action = Action("win", preconditions=winning_conditions, + postconditions=list(winning_conditions) + [win_fact]) return self.win_action def set_failing_conditions(self, failing_conditions: Optional[Collection[Proposition]]) -> Optional[Action]: @@ -95,12 +100,15 @@ def set_failing_conditions(self, failing_conditions: Optional[Collection[Proposi """ self.fail_action = None if failing_conditions is not None: - self.fail_action = Action("fail", failing_conditions, [Proposition("fail")]) + # TODO: Make fail propositions distinguishable by adding arguments? + fail_fact = Proposition("fail") + self.fail_action = Action("fail", preconditions=failing_conditions, + postconditions=list(failing_conditions) + [fail_fact]) return self.fail_action def __hash__(self) -> int: - return hash((tuple(self.actions), + return hash((self.actions, self.win_action, self.fail_action, self.desc, @@ -438,13 +446,62 @@ def is_distinct_from(self, others: List["ActionDependencyTreeElement"]) -> bool: return len(new_facts) > 0 def __lt__(self, other: "ActionDependencyTreeElement") -> bool: - return len(other.action.removed & self.action._pre_set) > 0 + """ Order ActionDependencyTreeElement elements. + + Actions that remove information needed by other actions + should be sorted further in the list. + + Notes: + This is not a proper ordering, i.e. two actions + can mutually removed information needed by each other. + """ + def _required_facts(node): + pre_set = set(node.action._pre_set) + while node.parent is not None: + pre_set |= node.parent.action._pre_set + pre_set -= node.action.added + node = node.parent + + return pre_set + + return len(other.action.removed & _required_facts(self)) > len(self.action.removed & _required_facts(other)) def __str__(self) -> str: params = ", ".join(map(str, self.action.variables)) return "{}({})".format(self.action.name, params) +class ActionDependencyTree(DependencyTree): + + def remove(self, action: Action) -> Optional[Action]: + super().remove(action) + + # The last action might have impacted one of the subquests. + reverse_action = get_reverse_action(action) + if reverse_action is not None: + self.push(reverse_action) + + return reverse_action + + def flatten(self) -> Iterable[Action]: + """ + Generates a flatten representation of this dependency tree. + + Actions are greedily yielded by iteratively popping leaves from + the dependency tree. + """ + tree = self.copy() # Make a copy of the tree to work on. + last_reverse_action = None + while len(tree.roots) > 0: + # Use 'sort' to try leaves that doesn't affect the others first. + for leaf in sorted(tree.leaves_elements): + if leaf.action != last_reverse_action: + break # Choose an action that avoids cycles. + + yield leaf.action + last_reverse_action = tree.remove(leaf.action) + + class QuestProgression: """ QuestProgression keeps track of the completion of a quest. @@ -458,130 +515,153 @@ def __init__(self, quest: Quest) -> None: quest: The quest to keep track of its completion. """ self._quest = quest - self._tree = DependencyTree(element_type=ActionDependencyTreeElement) - self._winning_policy = list(quest.actions) - - # Build a tree representation - for i, action in enumerate(quest.actions[::-1]): + self._completed = False + self._failed = False + self._unfinishable = False + + # Build a tree representation of the quest. + self._tree = ActionDependencyTree(element_type=ActionDependencyTreeElement) + self._tree.push(quest.win_action) + for action in quest.actions[::-1]: self._tree.push(action) - def is_completed(self, state: State) -> bool: - """ Check whether the quest is completed. """ - return state.is_applicable(self._quest.win_action) - - def has_failed(self, state: State) -> bool: - """ Check whether the quest has failed. """ - if self._quest.fail_action is None: - return False - - return state.is_applicable(self._quest.fail_action) + self._winning_policy = quest.actions + (quest.win_action,) @property def winning_policy(self) -> List[Action]: """ Actions to be performed in order to complete the quest. """ - return self._winning_policy + if self.done: + return [] - def _pop_action_from_tree(self, action: Action, tree: DependencyTree) -> Optional[Action]: - # The last action was meaningful for the quest. - tree.pop(action) + return self._winning_policy[:-1] # Discard "win" action. - reverse_action = None - if tree.root is not None: - # The last action might have impacted one of the subquests. - reverse_action = get_reverse_action(action) - if reverse_action is not None: - tree.push(reverse_action) + @property + def done(self) -> bool: + """ Check if the quest is done (i.e. completed, failed or unfinishable). """ + return self.completed or self.failed or self.unfinishable - return reverse_action + @property + def completed(self) -> bool: + """ Check whether the quest is completed. """ + return self._completed + + @property + def failed(self) -> bool: + """ Check whether the quest has failed. """ + return self._failed + + @property + def unfinishable(self) -> bool: + """ Check whether the quest is in an unfinishable state. """ + return self._unfinishable - def _build_policy(self) -> Optional[List[Action]]: - """ Build a policy given the current state of the QuestTree. + def update(self, action: Optional[Action] = None, state: Optional[State] = None) -> None: + """ Update quest progression given available information. - The policy is greedily built by iteratively popping leaves from - the dependency tree. + Args: + action: Action potentially affecting the quest progression. + state: Current game state. """ - if self._tree is None: - return None + if self.done: + return # Nothing to do, the quest is already done. - tree = self._tree.copy() # Make a copy of the tree to work on. + if state is not None: + # Check if quest is completed. + if self._quest.win_action is not None: + self._completed = state.is_applicable(self._quest.win_action) - policy = [] - last_reverse_action = None - while tree.root is not None: - # Try leaves that doesn't affect the others first. - for leaf in sorted(tree.leaves_elements): - if leaf.action != last_reverse_action: - break # Choose an action that avoids cycles. + # Check if quest has failed. + if self._quest.fail_action is not None: + self._failed = state.is_applicable(self._quest.fail_action) + + # Try compressing the winning policy given the new game state. + if self.compress_winning_policy(state): + return # A shorter winning policy has been found. - policy.append(leaf.action) - last_reverse_action = self._pop_action_from_tree(leaf.action, tree) + if action is not None: + # Determine if we moved away from the goal or closer to it. + reverse_action = self._tree.remove(action) + if reverse_action is None: # Irreversible action. + self._unfinishable = True # Can't track quest anymore. - return policy + self._winning_policy = tuple(self._tree.flatten()) # Rebuild policy. - def update(self, action: Action, bypass: Optional[List[Action]] = None) -> None: - """ Update the state of the quest after a given action was performed. + def compress_winning_policy(self, state: State) -> bool: + """ Compress the winning policy given a game state. Args: - action: Action affecting the state of the quest. + state: Current game state. + + Returns: + Whether the winning policy was compressed or not. """ - if bypass is not None: - for action in bypass: - self._pop_action_from_tree(action, self._tree) - - self._winning_policy = self._build_policy() - return - - # Determine if we moved away from the goal or closer to it. - if action in self._tree.leaves_values: - # The last action was meaningful for the quest. - self._pop_action_from_tree(action, self._tree) - else: - # The last action must have moved us away from the goal. - # We need to reverse it. - reverse_action = get_reverse_action(action) - if reverse_action is None: - # Irreversible action. - self._tree = None # Can't track quest anymore. - else: - self._tree.push(reverse_action) - - self._winning_policy = self._build_policy() + + def _find_shorter_policy(policy): + for j in range(0, len(policy)): + for i in range(j + 1, len(policy))[::-1]: + shorter_policy = policy[:j] + policy[i:] + if state.is_sequence_applicable(shorter_policy): + self._tree = ActionDependencyTree(element_type=ActionDependencyTreeElement) + for action in shorter_policy[::-1]: + self._tree.push(action) + + return shorter_policy + + return None + + compressed = False + policy = _find_shorter_policy(self._winning_policy) + while policy is not None: + compressed = True + self._winning_policy = policy + policy = _find_shorter_policy(policy) + + return compressed class GameProgression: """ GameProgression keeps track of the progression of a game. - If `tracking_quest` is True, then `winning_policy` will be the list + If `tracking_quests` is True, then `winning_policy` will be the list of Action that need to be applied in order to complete the game. """ - def __init__(self, game: Game, track_quest: bool = True) -> None: + def __init__(self, game: Game, track_quests: bool = True) -> None: """ Args: - game: The game to track progression of. - track_quest: Whether we should track the quest completion. + game: The game for which to track progression. + track_quests: whether quest progressions are being tracked. """ self.game = game self.state = game.state.copy() self._valid_actions = list(self.state.all_applicable_actions(self.game._rules.values(), self.game._types.constants_mapping)) - self.quest_progression = None - if track_quest and len(game.quests) > 0: - self.quest_progression = QuestProgression(game.quests[0]) + + self.quest_progressions = [] + if track_quests: + self.quest_progressions = [QuestProgression(quest) for quest in game.quests] + for quest_progression in self.quest_progressions: + quest_progression.update(action=None, state=self.state) @property def done(self) -> bool: - """ Whether the quest is completed or has failed. """ - if self.quest_progression is None: - return False + """ Whether all quests are completed or at least one has failed or is unfinishable. """ + if not self.tracking_quests: + return False # There is nothing to be "done". + + all_completed = True + for quest_progression in self.quest_progressions: + if quest_progression.failed or quest_progression.unfinishable: + return True + + all_completed &= quest_progression.completed - return (self.quest_progression.is_completed(self.state) or - self.quest_progression.has_failed(self.state)) + return all_completed @property - def tracking_quest(self) -> bool: - """ Whether the quest is tracked or not. """ - return self.quest_progression is not None + def tracking_quests(self) -> bool: + """ Whether quests are being tracked or not. """ + return len(self.quest_progressions) > 0 @property def valid_actions(self) -> List[Action]: @@ -594,12 +674,21 @@ def winning_policy(self) -> Optional[List[Action]]: Returns: A policy that leads to winning the game. It can be `None` - if `tracking_quest` is `False` or the quest has been failed. + if `tracking_quests` is `False` or the quest has failed. """ - if not self.tracking_quest or self.quest_progression.winning_policy is None: + if not self.tracking_quests: return None - return list(self.quest_progression.winning_policy) + # Check if any quest has failed. + if any(quest.failed or quest.unfinishable for quest in self.quest_progressions): + return None + + # Greedily build a new winning policy by merging all individual quests' tree. + trees = [quest._tree for quest in self.quest_progressions if not quest.done] + master_quest_tree = ActionDependencyTree(element_type=ActionDependencyTreeElement, + trees=trees) + + return tuple(a for a in master_quest_tree.flatten() if a.name != "win") def update(self, action: Action) -> None: """ Update the state of the game given the provided action. @@ -614,14 +703,6 @@ def update(self, action: Action) -> None: self._valid_actions = list(self.state.all_applicable_actions(self.game._rules.values(), self.game._types.constants_mapping)) - if self.tracking_quest: - if self.state.is_sequence_applicable(self.winning_policy): - pass # The last action didn't impact the quest. - else: - # Check for shortcut. - for i in range(1, len(self.winning_policy)): - if self.state.is_sequence_applicable(self.winning_policy[i:]): - self.quest_progression.update(action, bypass=self.winning_policy[:i]) - return - - self.quest_progression.update(action) + # Update all quest progressions given the last action and new state. + for quest_progression in self.quest_progressions: + quest_progression.update(action, self.state) diff --git a/textworld/generator/maker.py b/textworld/generator/maker.py index 13aba882..19dbec2f 100644 --- a/textworld/generator/maker.py +++ b/textworld/generator/maker.py @@ -589,6 +589,40 @@ def set_quest_from_commands(self, commands: List[str], ask_for_state: bool = Fal self.build() return self._quests[0] + def new_fact(self, name: str, *entities: List["WorldEntity"]) -> None: + """ Create new fact. + + Args: + name: The name of the new fact. + *entities: A list of entities as arguments to the new fact. + """ + args = [entity.var for entity in entities] + return Proposition(name, args) + + def new_quest_using_commands(self, commands: List[str]) -> Quest: + """ Creates a new quest using predefined text commands. + + This launches a `textworld.play` session to execute provided commands. + + Args: + commands: Text commands. + + Returns: + The resulting quest. + """ + with make_temp_directory() as tmpdir: + try: + game_file = self.compile(pjoin(tmpdir, "record_quest")) + recorder = Recorder() + agent = textworld.agents.WalkthroughAgent(commands) + textworld.play(game_file, agent=agent, wrapper=recorder, silent=True) + except textworld.agents.WalkthroughDone: + pass # Quest is done. + + # Skip "None" actions. + actions = [action for action in recorder.actions if action is not None] + return Quest(actions=actions) + def set_quest_from_final_state(self, final_state: Collection[Proposition]) -> Quest: """ Defines the game's quest using a collection of facts. diff --git a/textworld/generator/tests/test_dependency_tree.py b/textworld/generator/tests/test_dependency_tree.py index f62bccd8..02c4f667 100644 --- a/textworld/generator/tests/test_dependency_tree.py +++ b/textworld/generator/tests/test_dependency_tree.py @@ -25,60 +25,66 @@ def depends_on(self, other): class TestDependencyTree(unittest.TestCase): - def test_pop(self): + def test_remove(self): tree = DependencyTree(element_type=CustomDependencyTreeElement) - assert tree.root is None - tree.push("G") - tree.pop("G") - assert tree.root is None - - tree.push("G") - tree.push("F") + assert len(tree.roots) == 0 + assert tree.push("G") + assert tree.remove("G") + assert len(tree.roots) == 0 + assert list(tree) == [] + assert tree.values == [] + + assert tree.push("G") + assert tree.push("F") # Can't pop a non-leaf element. - assert_raises(ValueError, tree.pop, "G") - assert tree.root is not None + assert not tree.remove("G") + assert len(tree.roots) > 0 assert set(tree.leaves_values) == set("F") - tree.pop("F") + assert tree.remove("F") assert set(tree.leaves_values) == set("G") def test_push(self): tree = DependencyTree(element_type=CustomDependencyTreeElement) - assert tree.root is None + assert len(tree.roots) == 0 assert set(tree.leaves_values) == set() - node = tree.push("G") + tree.push("G") assert set(tree.leaves_values) == set(["G"]), tree.leaves_values - assert tree.root.element.value == "G" - assert tree.root is node - assert len(node.children) == 0 + assert tree.roots[0].element.value == "G" + assert len(tree.roots[0].children) == 0 - node = tree.push("F") + tree.push("F") assert set(tree.leaves_values) == set(["F"]) - node = tree.push("C") + tree.push("C") + node = tree.roots[0].children[0].children[0] assert set(tree.leaves_values) == set(["C"]) - assert tree.root.element.value == "G" + assert tree.roots[0].element.value == "G" assert node.element.value == "C" - assert len(tree.root.children) == 1 + assert len(tree.roots[0].children) == 1 assert len(node.children) == 0 # Nothing depends on A at the moment. - node = tree.push("A") + tree.push("A") assert set(tree.leaves_values) == set(["C"]) - node = tree.push("E") + tree_ = tree.copy() + tree.push("E") + assert tree_.values != tree.values assert set(tree.leaves_values) == set(["E", "C"]) # Add the same element twice at the same level doesn't change the tree. - node = tree.push("E") - assert node is None + tree_ = tree.copy() + tree.push("E") + assert tree_.values == tree.values assert set(tree.leaves_values) == set(["E", "C"]) - # Cannot remove a value that hasn't been added to the tree. - assert_raises(ValueError, tree.pop, "Z") + # Removing a value not associated to a leaf, does not change the tree. + assert not tree.remove("Z") + assert tree_.values == tree.values - node = tree.push("A") + tree.push("A") assert set(tree.leaves_values) == set(["A", "C"]) - node = tree.push("B") + tree.push("B") assert set(tree.leaves_values) == set(["B", "A", "C"]) diff --git a/textworld/generator/tests/test_game.py b/textworld/generator/tests/test_game.py index 7abda58e..7634b0b1 100644 --- a/textworld/generator/tests/test_game.py +++ b/textworld/generator/tests/test_game.py @@ -18,15 +18,30 @@ from textworld.generator import make_small_map, make_grammar, make_game_with from textworld.generator.chaining import ChainingOptions, sample_quest +from textworld.logic import Action, State from textworld.generator.game import Quest, Game from textworld.generator.game import QuestProgression, GameProgression from textworld.generator.game import UnderspecifiedQuestError +from textworld.generator.game import ActionDependencyTree, ActionDependencyTreeElement from textworld.generator.inform7 import gen_commands_from_actions from textworld.logic import Proposition +def _apply_command(command: str, game_progression: GameProgression) -> None: + """ Apply a text command to a game_progression object. + """ + valid_commands = gen_commands_from_actions(game_progression.valid_actions, game_progression.game.infos) + + for action, cmd in zip(game_progression.valid_actions, valid_commands): + if command == cmd: + game_progression.update(action) + return + + raise ValueError("Not a valid command: {}. Expected: {}".format(command, valid_commands)) + + def test_game_comparison(): rngs = {} rngs['rng_map'] = np.random.RandomState(1) @@ -106,12 +121,12 @@ def test_quest_creation(self): assert quest.win_action.preconditions == self.quest.actions[-1].postconditions assert quest.fail_action is None - quest = Quest(actions=None, winning_conditions=self.quest.actions[-1].postconditions) - assert quest.actions is None + quest = Quest(winning_conditions=self.quest.actions[-1].postconditions) + assert len(quest.actions) == 0 assert quest.win_action == self.quest.win_action assert quest.fail_action is None - npt.assert_raises(UnderspecifiedQuestError, Quest, actions=None, winning_conditions=None) + npt.assert_raises(UnderspecifiedQuestError, Quest, actions=[], winning_conditions=None) quest = Quest(self.quest.actions, failing_conditions=self.failing_conditions) assert quest.fail_action == self.quest.fail_action @@ -242,9 +257,6 @@ class TestQuestProgression(unittest.TestCase): def setUpClass(cls): M = GameMaker() - # The goal - commands = ["open wooden door", "go west", "take carrot", "go east", "drop carrot"] - # Create a 'bedroom' room. R1 = M.new_room("bedroom") R2 = M.new_room("kitchen") @@ -262,15 +274,79 @@ def setUpClass(cls): chest.add_property("open") R2.add(chest) + # The goal + commands = ["open wooden door", "go west", "take carrot", "go east", "drop carrot"] cls.quest = M.set_quest_from_commands(commands) + commands = ["open wooden door", "go west", "take carrot", "eat carrot"] + cls.eating_quest = M.new_quest_using_commands(commands) + + cls.game = M.build() + def test_winning_policy(self): quest = QuestProgression(self.quest) assert quest.winning_policy == self.quest.actions - quest.update(self.quest.actions[0]) - assert quest.winning_policy == self.quest.actions[1:] + quest.update(self.quest.actions[0], state=State()) + assert tuple(quest.winning_policy) == self.quest.actions[1:] + + def test_failing_quest(self): + quest = QuestProgression(self.quest) + + state = self.game.state.copy() + for action in self.eating_quest.actions: + state.apply(action) + if action.name.startswith("eat"): + quest.update(action, state) + assert len(quest.winning_policy) == 0 + assert quest.done + assert not quest.completed + assert not quest.failed + assert quest.unfinishable + break + + assert not quest.done + assert len(quest.winning_policy) > 0 + + +class TestGameProgression(unittest.TestCase): + + @classmethod + def setUpClass(cls): + M = GameMaker() + + # The goal + commands = ["open wooden door", "go west", "take carrot", "go east", "drop carrot"] + + # Create a 'bedroom' room. + R1 = M.new_room("bedroom") + R2 = M.new_room("kitchen") + M.set_player(R2) + + path = M.connect(R1.east, R2.west) + path.door = M.new(type='d', name='wooden door') + path.door.add_property("closed") + + carrot = M.new(type='f', name='carrot') + R1.add(carrot) + + # Add a closed chest in R2. + chest = M.new(type='c', name='chest') + chest.add_property("open") + R2.add(chest) + + cls.quest = M.set_quest_from_commands(commands) + cls.game = M.build() - def test_cycle_in_winning_policy(cls): + def test_is_game_completed(self): + game_progress = GameProgression(self.game) + + for action in self.quest.actions: + assert not game_progress.done + game_progress.update(action) + + assert game_progress.done + + def test_cycle_in_winning_policy(self): M = GameMaker() # Create a map. @@ -303,23 +379,12 @@ def test_cycle_in_winning_policy(cls): game = M.build() game_progression = GameProgression(game) - def _apply_command(command, game_progression): - valid_commands = gen_commands_from_actions(game_progression.valid_actions, game.infos) - - for action, cmd in zip(game_progression.valid_actions, valid_commands): - if command == cmd: - game_progression.update(action) - return - - raise ValueError("Not a valid command: {}. Expected: {}".format(command, valid_commands)) - _apply_command("go south", game_progression) expected_commands = ["go north"] + commands winning_commands = gen_commands_from_actions(game_progression.winning_policy, game.infos) assert winning_commands == expected_commands, "{} != {}".format(winning_commands, expected_commands) _apply_command("go east", game_progression) - _apply_command("go north", game_progression) expected_commands = ["go south", "go west", "go north"] + commands winning_commands = gen_commands_from_actions(game_progression.winning_policy, game.infos) @@ -332,6 +397,7 @@ def _apply_command(command, game_progression): # Quest where player's has to pick up the carrot first. commands = ["go east", "take apple", "go west", "go north", "drop apple"] + M.set_quest_from_commands(commands) game = M.build() game_progression = GameProgression(game) @@ -351,15 +417,15 @@ def _apply_command(command, game_progression): winning_commands = gen_commands_from_actions(game_progression.winning_policy, game.infos) assert winning_commands == expected_commands, "{} != {}".format(winning_commands, expected_commands) - -class TestGameProgression(unittest.TestCase): - - @classmethod - def setUpClass(cls): + def test_game_with_multiple_quests(self): M = GameMaker() - # The goal - commands = ["open wooden door", "go west", "take carrot", "go east", "drop carrot"] + # The subgoals (needs to be executed in order). + commands = [["open wooden door", "go west", "take carrot", "go east", "drop carrot"], + # Now, the player is back in the kitchen and the wooden door is open. + ["go west", "take lettuce", "go east", "drop lettuce"], + # Now, the player is back in the kitchen, there are a carrot and a lettuce on the floor. + ["take lettuce", "take carrot", "insert carrot into chest", "insert lettuce into chest", "close chest"]] # Create a 'bedroom' room. R1 = M.new_room("bedroom") @@ -371,24 +437,75 @@ def setUpClass(cls): path.door.add_property("closed") carrot = M.new(type='f', name='carrot') - R1.add(carrot) + lettuce = M.new(type='f', name='lettuce') + R1.add(carrot, lettuce) # Add a closed chest in R2. chest = M.new(type='c', name='chest') chest.add_property("open") R2.add(chest) - cls.quest = M.set_quest_from_commands(commands) - cls.game = M.build() + quest1 = M.new_quest_using_commands(commands[0]) + quest1.desc = "Fetch the carrot and drop it on the kitchen's ground." + quest2 = M.new_quest_using_commands(commands[0] + commands[1]) + quest2.desc = "Fetch the lettuce and drop it on the kitchen's ground." + quest3 = M.new_quest_using_commands(commands[0] + commands[1] + commands[2]) + winning_facts = [M.new_fact("in", lettuce, chest), + M.new_fact("in", carrot, chest), + M.new_fact("closed", chest),] + quest3.set_winning_conditions(winning_facts) + quest3.desc = "Put the lettuce and the carrot into the chest before closing it." + + M._quests = [quest1, quest2, quest3] + assert len(M._quests) == len(commands) + game = M.build() - def test_is_game_completed(self): - game_progress = GameProgression(self.game) + game_progress = GameProgression(game) + assert len(game_progress.quest_progressions) == len(game.quests) - for action in self.quest.actions: + # Following the actions associated to the last quest actually corresponds + # to solving the whole game. + for action in game_progress.winning_policy: assert not game_progress.done game_progress.update(action) assert game_progress.done + assert all(quest_progression.done for quest_progression in game_progress.quest_progressions) + + # Try solving the game by greedily taking the first action from the current winning policy. + game_progress = GameProgression(game) + while not game_progress.done: + action = game_progress.winning_policy[0] + game_progress.update(action) + # print(action.name, [c.name for c in game_progress.winning_policy]) + + # Try solving the second quest (i.e. bringing back the lettuce) first. + game_progress = GameProgression(game) + for command in ["open wooden door", "go west", "take lettuce", "go east", "drop lettuce"]: + _apply_command(command, game_progress) + + assert not game_progress.quest_progressions[0].done + assert game_progress.quest_progressions[1].done + + for command in ["go west", "take carrot", "go east", "drop carrot"]: + _apply_command(command, game_progress) + + assert game_progress.quest_progressions[0].done + assert game_progress.quest_progressions[1].done + + for command in ["take lettuce", "take carrot", "insert carrot into chest", "insert lettuce into chest", "close chest"]: + _apply_command(command, game_progress) + + assert game_progress.done + + # Game is done whenever a quest has failed or is unfinishable. + game_progress = GameProgression(game) + + for command in ["open wooden door", "go west", "take carrot", "eat carrot"]: + assert not game_progress.done + _apply_command(command, game_progress) + + assert game_progress.done def test_game_without_a_quest(self): M = GameMaker() @@ -406,3 +523,26 @@ def test_game_without_a_quest(self): action = game_progress.valid_actions[0] game_progress.update(action) assert not game_progress.done + + +class TestActionDependencyTree(unittest.TestCase): + + def test_flatten(self): + action_lock = Action.parse("lock/c :: $at(P, r) & $at(c, r) & $match(k, c) & $in(k, I) & closed(c) -> locked(c)") + action_close = Action.parse("close/c :: $at(P, r) & $at(c, r) & open(c) -> closed(c)") + action_insert1 = Action.parse("insert :: $at(P, r) & $at(c, r) & $open(c) & in(o1: o, I) -> in(o1: o, c)") + action_insert2 = Action.parse("insert :: $at(P, r) & $at(c, r) & $open(c) & in(o2: o, I) -> in(o2: o, c)") + action_take1 = Action.parse("take :: $at(P, r) & at(o1: o, r) -> in(o1: o, I)") + action_take2 = Action.parse("take :: $at(P, r) & at(o2: o, r) -> in(o2: o, I)") + action_win = Action.parse("win :: $in(o1: o, c) & $in(o2: o, c) & $locked(c) -> win(o1: o, o2: o, c)") + + tree = ActionDependencyTree(element_type=ActionDependencyTreeElement) + tree.push(action_win) + tree.push(action_lock) + tree.push(action_close) + tree.push(action_insert1) + tree.push(action_insert2) + tree.push(action_take1) + tree.push(action_take2) + actions = list(a.name for a in tree.flatten()) + assert actions == ['take', 'insert', 'take', 'insert', 'close/c', 'lock/c', 'win'], actions From 9025cc8b5f4afb7a724c6e5939a712ee575205ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 25 Sep 2018 09:30:42 -0400 Subject: [PATCH 2/2] Address @tavianator --- textworld/generator/dependency_tree.py | 30 +++++++++++++------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/textworld/generator/dependency_tree.py b/textworld/generator/dependency_tree.py index cf0dce1a..8f48966e 100644 --- a/textworld/generator/dependency_tree.py +++ b/textworld/generator/dependency_tree.py @@ -49,7 +49,7 @@ def __init__(self, element: DependencyTreeElement): def push(self, node: "DependencyTree._Node") -> bool: if node == self: return True - + added = False for child in self.children: added |= child.push(node) @@ -74,11 +74,11 @@ def already_added(self, node: "DependencyTree._Node") -> bool: return True return False - + def __iter__(self) -> Iterable["DependencyTree._Node"]: for child in self.children: yield from list(child) - + yield self def __str__(self) -> str: @@ -99,7 +99,7 @@ def copy(self) -> "DependencyTree._Node": return node - def __init__(self, element_type: DependencyTreeElement = DependencyTreeElement, trees: Iterable["DependencyTree"] = []): + def __init__(self, element_type: type = DependencyTreeElement, trees: Iterable["DependencyTree"] = []): self.roots = [] self.element_type = element_type for tree in trees: @@ -110,17 +110,17 @@ def __init__(self, element_type: DependencyTreeElement = DependencyTreeElement, def push(self, value: Any, allow_multi_root: bool = False) -> bool: """ Add a value to this dependency tree. - Adding a value already present in the tree does not modify the tree. - + Adding a value already present in the tree does not modify the tree. + Args: value: value to add. - allow_multi_root: if `True`, allow the value to spawn an + allow_multi_root: if `True`, allow the value to spawn an additional root if needed. - + """ element = self.element_type(value) node = DependencyTree._Node(element) - + added = False for root in self.roots: added |= root.push(node) @@ -128,7 +128,7 @@ def push(self, value: Any, allow_multi_root: bool = False) -> bool: if len(self.roots) == 0 or (not added and allow_multi_root): self.roots.append(node) added = True - + self._update() # Recompute leaves. return added @@ -137,17 +137,17 @@ def remove(self, value: Any) -> None: The value to remove needs to belong to at least one leaf in this tree. Otherwise, the tree remains unchanged. - + Args: value: value to remove from the tree. - + Returns: Whether the tree has changed or not. """ if value not in self.leaves_values: return False - root_to_remove = [] + root_to_remove = [] for node in self: if node.element.value == value: if node.parent is not None: @@ -160,7 +160,7 @@ def remove(self, value: Any) -> None: self._update() # Recompute leaves. return True - + def _update(self) -> None: self._leaves_values = [] self._leaves_elements = [] @@ -177,7 +177,7 @@ def copy(self) -> "DependencyTree": tree = type(self)(element_type=self.element_type) for root in self.roots: tree.roots.append(root.copy()) - + tree._update() return tree