diff --git a/scripts/tw-make b/scripts/tw-make
index 938886d4..736751b2 100755
--- a/scripts/tw-make
+++ b/scripts/tw-make
@@ -46,6 +46,8 @@ def parse_args():
help="Nb. of objects in the world.")
custom_parser.add_argument("--quest-length", type=int, default=5, metavar="LENGTH",
help="Minimum nb. of actions the quest requires to be completed.")
+ custom_parser.add_argument("--quest-breadth", type=int, default=3, metavar="BREADTH",
+ help="Control how non-linear a quest can be.")
challenge_parser = subparsers.add_parser("challenge", parents=[general_parser],
help='Generate a game for one of the challenges.')
@@ -72,7 +74,7 @@ if __name__ == "__main__":
}
if args.subcommand == "custom":
- game_file, game = textworld.make(args.world_size, args.nb_objects, args.quest_length, grammar_flags,
+ game_file, game = textworld.make(args.world_size, args.nb_objects, args.quest_length, args.quest_breadth, grammar_flags,
seed=args.seed, games_dir=args.output)
elif args.subcommand == "challenge":
@@ -87,7 +89,7 @@ if __name__ == "__main__":
print("Game generated: {}".format(game_file))
if args.verbose:
- print(game.quests[0].desc)
+ print(game.objective)
if args.view:
textworld.render.visualize(game, interactive=True)
diff --git a/scripts/tw-stats b/scripts/tw-stats
index 21145ec9..dbef4559 100755
--- a/scripts/tw-stats
+++ b/scripts/tw-stats
@@ -38,7 +38,7 @@ if __name__ == "__main__":
continue
if len(game.quests) > 0:
- objectives[game_filename] = game.quests[0].desc
+ objectives[game_filename] = game.objective
names |= set(info.name for info in game.infos.values() if info.name is not None)
game_logger.collect(game)
diff --git a/tests/test_make_game.py b/tests/test_make_game.py
index 6a6c6263..05b661a8 100644
--- a/tests/test_make_game.py
+++ b/tests/test_make_game.py
@@ -11,21 +11,22 @@ def test_making_game_with_names_to_exclude():
g_rng.set_seed(42)
with make_temp_directory(prefix="test_render_wrapper") as tmpdir:
- game_file1, game1 = textworld.make(2, 20, 3, {"names_to_exclude": []},
+ game_file1, game1 = textworld.make(2, 20, 3, 3, {"names_to_exclude": []},
seed=123, games_dir=tmpdir)
game1_objects_names = [info.name for info in game1.infos.values() if info.name is not None]
- game_file2, game2 = textworld.make(2, 20, 3, {"names_to_exclude": game1_objects_names},
+ game_file2, game2 = textworld.make(2, 20, 3, 3, {"names_to_exclude": game1_objects_names},
seed=123, games_dir=tmpdir)
game2_objects_names = [info.name for info in game2.infos.values() if info.name is not None]
+ game2.grammar.flags.encode()
assert len(set(game1_objects_names) & set(game2_objects_names)) == 0
def test_making_game_is_reproducible_with_seed():
grammar_flags = {}
with make_temp_directory(prefix="test_render_wrapper") as tmpdir:
- game_file1, game1 = textworld.make(2, 20, 3, grammar_flags, seed=123, games_dir=tmpdir)
- game_file2, game2 = textworld.make(2, 20, 3, grammar_flags, seed=123, games_dir=tmpdir)
+ game_file1, game1 = textworld.make(2, 20, 3, 3, grammar_flags, seed=123, games_dir=tmpdir)
+ game_file2, game2 = textworld.make(2, 20, 3, 3, grammar_flags, seed=123, games_dir=tmpdir)
assert game_file1 == game_file2
assert game1 == game2
# Make sure they are not the same Python objects.
diff --git a/tests/test_play_generated_games.py b/tests/test_play_generated_games.py
index 96374eec..0fd050c5 100644
--- a/tests/test_play_generated_games.py
+++ b/tests/test_play_generated_games.py
@@ -16,12 +16,13 @@ def test_play_generated_games():
# Sample game specs.
world_size = rng.randint(1, 10)
nb_objects = rng.randint(0, 20)
- quest_length = rng.randint(1, 10)
+ quest_length = rng.randint(2, 5)
+ quest_breadth = rng.randint(3, 7)
game_seed = rng.randint(0, 65365)
grammar_flags = {} # Default grammar.
with make_temp_directory(prefix="test_play_generated_games") as tmpdir:
- game_file, game = textworld.make(world_size, nb_objects, quest_length, grammar_flags,
+ game_file, game = textworld.make(world_size, nb_objects, quest_length, quest_breadth, grammar_flags,
seed=game_seed, games_dir=tmpdir)
# Solve the game using WalkthroughAgent.
@@ -47,10 +48,10 @@ def test_play_generated_games():
game_state, reward, done = env.step(command)
if done:
- msg = "Finished before playing `max_steps` steps."
+ msg = "Finished before playing `max_steps` steps because of command '{}'.".format(command)
if game_state.has_won:
msg += " (winning)"
- assert game_state._game_progression.winning_policy == []
+ assert len(game_state._game_progression.winning_policy) == 0
if game_state.has_lost:
msg += " (losing)"
diff --git a/tests/test_textworld.py b/tests/test_textworld.py
index a0165e8a..b4fa918f 100644
--- a/tests/test_textworld.py
+++ b/tests/test_textworld.py
@@ -58,7 +58,7 @@ def test_game_walkthrough_agent(self):
agent = textworld.agents.WalkthroughAgent()
env = textworld.start(self.game_file)
env.activate_state_tracking()
- commands = self.game.quests[0].commands
+ commands = self.game.quests[-1].commands
agent.reset(env)
game_state = env.reset()
diff --git a/tests/test_tw_play.py b/tests/test_tw-play.py
similarity index 80%
rename from tests/test_tw_play.py
rename to tests/test_tw-play.py
index 6ff2dfea..1bb38ea9 100644
--- a/tests/test_tw_play.py
+++ b/tests/test_tw-play.py
@@ -7,9 +7,9 @@
from textworld.utils import make_temp_directory
-def test_making_a_custom_game():
+def test_playing_a_game():
with make_temp_directory(prefix="test_tw-play") as tmpdir:
- game_file, _ = textworld.make(5, 10, 5, {}, seed=1234, games_dir=tmpdir)
+ game_file, _ = textworld.make(5, 10, 5, 4, {}, seed=1234, games_dir=tmpdir)
command = ["tw-play", "--max-steps", "100", "--mode", "random", game_file]
assert check_call(command) == 0
diff --git a/textworld/agents/walkthrough.py b/textworld/agents/walkthrough.py
index 7b046b86..9b214d8f 100644
--- a/textworld/agents/walkthrough.py
+++ b/textworld/agents/walkthrough.py
@@ -26,7 +26,7 @@ def reset(self, env):
raise NameError(msg)
# Load command from the generated game.
- self._commands = iter(env.game.quests[0].commands)
+ self._commands = iter(env.game.main_quest.commands)
def act(self, game_state, reward, done):
try:
diff --git a/textworld/envs/glulx/git_glulx_ml.py b/textworld/envs/glulx/git_glulx_ml.py
index 5fab75c6..c37299a6 100644
--- a/textworld/envs/glulx/git_glulx_ml.py
+++ b/textworld/envs/glulx/git_glulx_ml.py
@@ -146,13 +146,10 @@ def init(self, output: str, game=None,
"""
output = _strip_input_prompt_symbol(output)
super().init(output)
- self._game_progression = GameProgression(game, track_quest=compute_intermediate_reward)
+ self._game_progression = GameProgression(game, track_quests=compute_intermediate_reward)
self._state_tracking = state_tracking
self._compute_intermediate_reward = compute_intermediate_reward and len(game.quests) > 0
-
- self._objective = ""
- if len(game.quests) > 0:
- self._objective = game.quests[0].desc
+ self._objective = game.objective
def view(self) -> "GlulxGameState":
"""
@@ -317,6 +314,7 @@ def intermediate_reward(self):
@property
def score(self):
+ # XXX: Should the score reflect the sum of all subquests' reward?
if self.has_won:
return 1
elif self.has_lost:
@@ -326,6 +324,7 @@ def score(self):
@property
def max_score(self):
+ # XXX: Should the score reflect the sum of all subquests' reward?
return 1
@property
diff --git a/textworld/envs/wrappers/tests/test_viewer.py b/textworld/envs/wrappers/tests/test_viewer.py
index ac40fbb0..32a4c4fe 100644
--- a/textworld/envs/wrappers/tests/test_viewer.py
+++ b/textworld/envs/wrappers/tests/test_viewer.py
@@ -17,7 +17,7 @@ def test_html_viewer():
num_items = 10
g_rng.set_seed(1234)
grammar_flags = {"theme": "house", "include_adj": True}
- game = textworld.generator.make_game(world_size=num_nodes, nb_objects=num_items, quest_length=3, grammar_flags=grammar_flags)
+ game = textworld.generator.make_game(world_size=num_nodes, nb_objects=num_items, quest_length=3, quest_breadth=1, grammar_flags=grammar_flags)
game_name = "test_html_viewer_wrapper"
with make_temp_directory(prefix=game_name) as tmpdir:
diff --git a/textworld/generator/__init__.py b/textworld/generator/__init__.py
index 7db5d7da..8496cece 100644
--- a/textworld/generator/__init__.py
+++ b/textworld/generator/__init__.py
@@ -147,7 +147,7 @@ def make_game_with(world, quests=None, grammar=None):
return game
-def make_game(world_size: int, nb_objects: int, quest_length: int,
+def make_game(world_size: int, nb_objects: int, quest_length: int, quest_breadth: int,
grammar_flags: Mapping = {},
rngs: Optional[Dict[str, RandomState]] = None
) -> Game:
@@ -158,6 +158,7 @@ def make_game(world_size: int, nb_objects: int, quest_length: int,
world_size: Number of rooms in the world.
nb_objects: Number of objects in the world.
quest_length: Minimum nb. of actions the quest requires to be completed.
+ quest_breadth: How many branches the quest can have.
grammar_flags: Options for the grammar.
Returns:
@@ -175,14 +176,34 @@ def make_game(world_size: int, nb_objects: int, quest_length: int,
world = make_world(world_size, nb_objects=0, rngs=rngs)
# Sample a quest according to quest_length.
- options = ChainingOptions()
+ class Options(ChainingOptions):
+
+ def get_rules(self, depth):
+ if depth == 0:
+ # Last action should not be "go
".
+ return data.get_rules().get_matching("^(?!go.*).*")
+ else:
+ return super().get_rules(depth)
+
+ options = Options()
options.backward = True
+ options.min_depth = 1
options.max_depth = quest_length
+ options.min_breadth = 1
+ options.max_breadth = quest_breadth
options.create_variables = True
options.rng = rngs['rng_quest']
options.restricted_types = {"r", "d"}
chain = sample_quest(world.state, options)
+
+ subquests = []
+ for i in range(1, len(chain.nodes)):
+ if chain.nodes[i].breadth != chain.nodes[i - 1].breadth:
+ quest = Quest(chain.actions[:i])
+ subquests.append(quest)
+
quest = Quest(chain.actions)
+ subquests.append(quest)
# Set the initial state required for the quest.
world.state = chain.initial_state
@@ -191,7 +212,9 @@ def make_game(world_size: int, nb_objects: int, quest_length: int,
world.populate(nb_objects, rng=rngs['rng_objects'])
grammar = make_grammar(grammar_flags, rng=rngs['rng_grammar'])
- game = make_game_with(world, [quest], grammar)
+ game = make_game_with(world, subquests, grammar)
+ game.change_grammar(grammar)
+
return game
diff --git a/textworld/generator/data/logic/door.twl b/textworld/generator/data/logic/door.twl
index c8a56384..54613ab5 100644
--- a/textworld/generator/data/logic/door.twl
+++ b/textworld/generator/data/logic/door.twl
@@ -34,7 +34,13 @@ type d : t {
link3 :: link(r, d, r') & link(r, d', r') -> fail();
# There cannot be more than four doors in a room.
- dr2 :: link(r, d1: d, r1: r) & link(r, d2: d, r2: r) & link(r, d3: d, r3: r) & link(r, d4: d, r4: r) & link(r, d5: d, r5: r) -> fail();
+ too_many_doors :: link(r, d1: d, r1: r) & link(r, d2: d, r2: r) & link(r, d3: d, r3: r) & link(r, d4: d, r4: r) & link(r, d5: d, r5: r) -> fail();
+
+ # There cannot be more than four doors in a room.
+ dr1 :: free(r, r1: r) & link(r, d2: d, r2: r) & link(r, d3: d, r3: r) & link(r, d4: d, r4: r) & link(r, d5: d, r5: r) -> fail();
+ dr2 :: free(r, r1: r) & free(r, r2: r) & link(r, d3: d, r3: r) & link(r, d4: d, r4: r) & link(r, d5: d, r5: r) -> fail();
+ dr3 :: free(r, r1: r) & free(r, r2: r) & free(r, r3: r) & link(r, d4: d, r4: r) & link(r, d5: d, r5: r) -> fail();
+ dr4 :: free(r, r1: r) & free(r, r2: r) & free(r, r3: r) & free(r, r4: r) & link(r, d5: d, r5: r) -> fail();
free1 :: link(r, d, r') & free(r, r') & closed(d) -> fail();
free2 :: link(r, d, r') & free(r, r') & locked(d) -> fail();
diff --git a/textworld/generator/data/text_grammars/house_instruction.twg b/textworld/generator/data/text_grammars/house_instruction.twg
index e88605e8..7888718d 100644
--- a/textworld/generator/data/text_grammars/house_instruction.twg
+++ b/textworld/generator/data/text_grammars/house_instruction.twg
@@ -112,7 +112,7 @@ action_seperator_go/north: #afterhave# gone north, ;#emptyinstruction1#;#emptyin
action_seperator_go/east: #afterhave# gone east, ;#emptyinstruction1#;#emptyinstruction2#;#emptyinstruction3#;#emptyinstruction4#;#emptyinstruction5#;#emptyinstruction6#;#emptyinstruction7#;#emptyinstruction8#;#emptyinstruction9#;#emptyinstruction10#
action_seperator_go/west: #afterhave# gone west, ;#emptyinstruction1#;#emptyinstruction2#;#emptyinstruction3#;#emptyinstruction4#;#emptyinstruction5#;#emptyinstruction6#;#emptyinstruction7#;#emptyinstruction8#;#emptyinstruction9#;#emptyinstruction10#
action_separator_close: #afterhave# #closed# the #close_open_types#, ; #after# #closing# the #close_open_types#, ;#emptyinstruction1#;#emptyinstruction2#;#emptyinstruction3#;#emptyinstruction4#;#emptyinstruction5#;#emptyinstruction6#;#emptyinstruction7#;#emptyinstruction8#;#emptyinstruction9#;#emptyinstruction10#
-action_separator_drop: #afterhave# #dropped# #obj_types#, ; #after# #dropping# #obj_types#, ;#emptyinstruction1#;#emptyinstruction2#;#emptyinstruction3#;#emptyinstruction4#;#emptyinstruction5#;#emptyinstruction6#;#emptyinstruction7#;#emptyinstruction8#;#emptyinstruction9#;#emptyinstruction10#
+action_separator_drop: #afterhave# #dropped# the #obj_types#, ; #after# #dropping# the #obj_types#, ;#emptyinstruction1#;#emptyinstruction2#;#emptyinstruction3#;#emptyinstruction4#;#emptyinstruction5#;#emptyinstruction6#;#emptyinstruction7#;#emptyinstruction8#;#emptyinstruction9#;#emptyinstruction10#
#Separator Symbols
afterhave:After you have;Having;Once you have;If you have
havetaken:taken;got;picked up
diff --git a/textworld/generator/dependency_tree.py b/textworld/generator/dependency_tree.py
index dddb4da3..cf0dce1a 100644
--- a/textworld/generator/dependency_tree.py
+++ b/textworld/generator/dependency_tree.py
@@ -3,6 +3,7 @@
import textwrap
+from typing import List, Any, Iterable
from textworld.utils import uniquify
@@ -18,42 +19,51 @@ class DependencyTreeElement:
`__str__` accordingly.
"""
- def __init__(self, value):
+ def __init__(self, value: Any):
self.value = value
+ self.parent = None
- def depends_on(self, other):
+ def depends_on(self, other: "DependencyTreeElement") -> bool:
"""
Check whether this element depends on the `other`.
"""
return self.value > other.value
- def is_distinct_from(self, others):
+ def is_distinct_from(self, others: Iterable["DependencyTreeElement"]) -> bool:
"""
Check whether this element is distinct from `others`.
"""
return self.value not in [other.value for other in others]
- def __str__(self):
+ def __str__(self) -> str:
return str(self.value)
class DependencyTree:
class _Node:
- def __init__(self, element):
+ def __init__(self, element: DependencyTreeElement):
self.element = element
self.children = []
+ self.parent = None
- def push(self, node):
+ def push(self, node: "DependencyTree._Node") -> bool:
if node == self:
- return
-
+ return True
+
+ added = False
for child in self.children:
- child.push(node)
+ added |= child.push(node)
if self.element.depends_on(node.element) and not self.already_added(node):
+ node = node.copy()
self.children.append(node)
+ node.element.parent = self.element
+ node.parent = self
+ return True
+
+ return added
- def already_added(self, node):
+ def already_added(self, node: "DependencyTree._Node") -> bool:
# We want to avoid duplicate information about dependencies.
if node in self.children:
return True
@@ -63,14 +73,15 @@ def already_added(self, node):
if not node.element.is_distinct_from((child.element for child in self.children)):
return True
- # for child in self.children:
- # # if node.element.value == child.element.value:
- # if not node.element.is_distinct_from((child.element):
- # return True
-
return False
+
+ def __iter__(self) -> Iterable["DependencyTree._Node"]:
+ for child in self.children:
+ yield from list(child)
+
+ yield self
- def __str__(self):
+ def __str__(self) -> str:
node_text = str(self.element)
txt = [node_text]
@@ -79,85 +90,112 @@ def __str__(self):
return "\n".join(txt)
- def copy(self):
+ def copy(self) -> "DependencyTree._Node":
node = DependencyTree._Node(self.element)
- node.children = [child.copy() for child in self.children]
+ for child in self.children:
+ child_ = child.copy()
+ child_.parent = node
+ node.children.append(child_)
+
return node
- def __init__(self, element_type=DependencyTreeElement):
- self.root = None
+ def __init__(self, element_type: DependencyTreeElement = DependencyTreeElement, trees: Iterable["DependencyTree"] = []):
+ self.roots = []
self.element_type = element_type
- self._update()
-
- def push(self, value):
- element = self.element_type(value)
- node = DependencyTree._Node(element)
- if self.root is None:
- self.root = node
- else:
- self.root.push(node)
+ for tree in trees:
+ self.roots += [root.copy() for root in tree.roots]
- # Recompute leaves.
self._update()
- if element in self.leaves_elements:
- return node
- return None
+ def push(self, value: Any, allow_multi_root: bool = False) -> bool:
+ """ Add a value to this dependency tree.
- def pop(self, value):
+ Adding a value already present in the tree does not modify the tree.
+
+ Args:
+ value: value to add.
+ allow_multi_root: if `True`, allow the value to spawn an
+ additional root if needed.
+
+ """
+ element = self.element_type(value)
+ node = DependencyTree._Node(element)
+
+ added = False
+ for root in self.roots:
+ added |= root.push(node)
+
+ if len(self.roots) == 0 or (not added and allow_multi_root):
+ self.roots.append(node)
+ added = True
+
+ self._update() # Recompute leaves.
+ return added
+
+ def remove(self, value: Any) -> None:
+ """ Remove all leaves having the given value.
+
+ The value to remove needs to belong to at least one leaf in this tree.
+ Otherwise, the tree remains unchanged.
+
+ Args:
+ value: value to remove from the tree.
+
+ Returns:
+ Whether the tree has changed or not.
+ """
if value not in self.leaves_values:
- raise ValueError("That element is not a leaf: {!r}.".format(value))
-
- def _visit(node):
- for child in list(node.children):
- if child.element.value == value:
- node.children.remove(child)
-
- self._postorder(self.root, _visit)
- if self.root.element.value == value:
- self.root = None
-
- # Recompute leaves.
- self._update()
-
- def _postorder(self, node, visit):
- for child in node.children:
- self._postorder(child, visit)
-
- visit(node)
+ return False
- def _update(self):
+ root_to_remove = []
+ for node in self:
+ if node.element.value == value:
+ if node.parent is not None:
+ node.parent.children.remove(node)
+ else:
+ root_to_remove.append(node)
+
+ for node in root_to_remove:
+ self.roots.remove(node)
+
+ self._update() # Recompute leaves.
+ return True
+
+ def _update(self) -> None:
self._leaves_values = []
- self._leaves_elements = set()
+ self._leaves_elements = []
- def _visit(node):
+ for node in self:
if len(node.children) == 0:
- self._leaves_elements.add(node.element)
+ self._leaves_elements.append(node.element)
self._leaves_values.append(node.element.value)
- if self.root is not None:
- self._postorder(self.root, _visit)
-
self._leaves_values = uniquify(self._leaves_values)
+ self._leaves_elements = uniquify(self._leaves_elements)
+
+ def copy(self) -> "DependencyTree":
+ tree = type(self)(element_type=self.element_type)
+ for root in self.roots:
+ tree.roots.append(root.copy())
+
+ tree._update()
+ return tree
- def copy(self):
- tree = DependencyTree(self.element_type)
- if self.root is not None:
- tree.root = self.root.copy()
- tree._update()
+ def __iter__(self) -> Iterable["DependencyTree._Node"]:
+ for root in self.roots:
+ yield from list(root)
- return tree
+ @property
+ def values(self) -> List[Any]:
+ return [node.element.value for node in self]
@property
- def leaves_elements(self):
+ def leaves_elements(self) -> List[DependencyTreeElement]:
return self._leaves_elements
@property
- def leaves_values(self):
+ def leaves_values(self) -> List[Any]:
return self._leaves_values
- def __str__(self):
- if self.root is None:
- return ""
-
- return str(self.root)
+ def __str__(self) -> str:
+ return "\n".join(map(str, self.roots))
diff --git a/textworld/generator/game.py b/textworld/generator/game.py
index 11c27e3f..41d01824 100644
--- a/textworld/generator/game.py
+++ b/textworld/generator/game.py
@@ -4,7 +4,7 @@
import json
-from typing import List, Dict, Optional, Mapping, Any
+from typing import List, Dict, Optional, Mapping, Any, Iterable
from collections import OrderedDict
from textworld.generator import data
@@ -32,6 +32,23 @@ def __init__(self):
super().__init__(msg)
+def gen_commands_from_actions(actions):
+ def _get_name_mapping(action):
+ mapping = data.get_rules()[action.name].match(action)
+ return {ph.name: var.name for ph, var in mapping.items()}
+
+ commands = []
+ for action in actions:
+ command = "None"
+ if action is not None:
+ command = data.INFORM7_COMMANDS[action.name]
+ command = command.format(**_get_name_mapping(action))
+
+ commands.append(command)
+
+ return commands
+
+
class Quest:
""" Quest presentation in TextWorld.
@@ -39,14 +56,15 @@ class Quest:
undertaken with a goal.
"""
- def __init__(self, actions: Optional[List[Action]],
+ def __init__(self, actions: Optional[Iterable[Action]] = None,
winning_conditions: Optional[Collection[Proposition]] = None,
failing_conditions: Optional[Collection[Proposition]] = None,
- desc: str = "") -> None:
+ desc: Optional[str] = None) -> None:
"""
Args:
actions: The actions to be performed to complete the quest.
- If `None`, then `winning_conditions` must be provided.
+ If `None` or an empty list, then `winning_conditions`
+ must be provided.
winning_conditions: Set of propositions that need to be true
before marking the quest as completed.
Default: postconditions of the last action.
@@ -55,9 +73,10 @@ def __init__(self, actions: Optional[List[Action]],
Default: can't fail the quest.
desc: A text description of the quest.
"""
- self.actions = actions
+ self.actions = tuple(actions) if actions else ()
self.desc = desc
- self.commands = []
+ self.commands = gen_commands_from_actions(self.actions)
+ self.reward = 1
self.win_action = self.set_winning_conditions(winning_conditions)
self.fail_action = self.set_failing_conditions(failing_conditions)
@@ -72,14 +91,16 @@ def set_winning_conditions(self, winning_conditions: Optional[Collection[Proposi
An action that is only applicable when the quest is finished.
"""
if winning_conditions is None:
- if self.actions is None:
+ if len(self.actions) == 0:
raise UnderspecifiedQuestError()
# The default winning conditions are the postconditions of the
# last action in the quest.
winning_conditions = self.actions[-1].postconditions
- self.win_action = Action("win", winning_conditions, [Proposition("win")])
+ from textworld.utils import uniquify
+ arguments = uniquify([a for c in winning_conditions for a in c.arguments])
+ self.win_action = Action("win", winning_conditions, [Proposition("win", arguments)] + list(winning_conditions))
return self.win_action
def set_failing_conditions(self, failing_conditions: Optional[Collection[Proposition]]) -> Optional[Action]:
@@ -100,7 +121,7 @@ def set_failing_conditions(self, failing_conditions: Optional[Collection[Proposi
return self.fail_action
def __hash__(self) -> int:
- return hash((tuple(self.actions),
+ return hash((self.actions,
self.win_action,
self.fail_action,
self.desc,
@@ -112,6 +133,7 @@ def __eq__(self, other: Any) -> bool:
self.win_action == other.win_action and
self.fail_action == other.fail_action and
self.desc == other.desc and
+ self.reward == other.reward and
self.commands == other.commands)
@classmethod
@@ -132,6 +154,7 @@ def deserialize(cls, data: Mapping) -> "Quest":
desc = data["desc"]
quest = cls(actions, win_action.preconditions, failing_conditions, desc=desc)
quest.commands = data["commands"]
+ quest.reward = data.get("reward", 1)
return quest
def serialize(self) -> Mapping:
@@ -142,6 +165,7 @@ def serialize(self) -> Mapping:
"""
data = {}
data["desc"] = self.desc
+ data["reward"] = self.reward
data["commands"] = self.commands
data["actions"] = [action.serialize() for action in self.actions]
data["win_action"] = self.win_action.serialize()
@@ -240,15 +264,26 @@ def __init__(self, world: World, grammar: Optional[Grammar] = None,
"""
self.world = world
self.state = world.state.copy() # Current state of the game.
- self.grammar = grammar
self.quests = [] if quests is None else quests
self.metadata = {}
+ self._objective = None
self._infos = self._build_infos()
self._rules = data.get_rules()
self._types = data.get_types()
- # TODO:
- # self.change_names()
- # self.change_descriptions()
+ self.change_grammar(grammar)
+
+ self._main_quest = None
+
+ @property
+ def main_quest(self):
+ if self._main_quest is None:
+ from textworld.generator import inform7
+ from textworld.generator.text_generation import assign_description_to_quest
+ self._main_quest = Quest(actions=GameProgression(self).winning_policy)
+ self._main_quest.desc = assign_description_to_quest(self._main_quest, self, self.grammar)
+ self._main_quest.commands = inform7.gen_commands_from_actions(self._main_quest.actions, self.infos)
+
+ return self._main_quest
@property
def infos(self) -> Dict[str, EntityInfo]:
@@ -270,6 +305,7 @@ def copy(self) -> "Game":
game.state = self.state.copy()
game._rules = self._rules
game._types = self._types
+ game._objective = self._objective
return game
def change_grammar(self, grammar: Grammar) -> None:
@@ -277,16 +313,15 @@ def change_grammar(self, grammar: Grammar) -> None:
from textworld.generator import inform7
from textworld.generator.text_generation import generate_text_from_grammar
self.grammar = grammar
+ if self.grammar is None:
+ return
+
generate_text_from_grammar(self, self.grammar)
for quest in self.quests:
# TODO: should have a generic way of generating text commands from actions
- # insteaf of relying on inform7 convention.
+ # instead of relying on inform7 convention.
quest.commands = inform7.gen_commands_from_actions(quest.actions, self.infos)
- # TODO
- # self.change_names()
- # self.change_descriptions()
-
def save(self, filename: str) -> None:
""" Saves the serialized data of this game to a file. """
with open(filename, 'w') as f:
@@ -307,11 +342,11 @@ def deserialize(cls, data: Mapping) -> "Game":
`Game` object.
"""
world = World.deserialize(data["world"])
- grammar = None
+ game = cls(world)
if "grammar" in data:
- grammar = Grammar(data["grammar"])
- quests = [Quest.deserialize(d) for d in data["quests"]]
- game = cls(world, grammar, quests)
+ game.grammar = Grammar(data["grammar"])
+
+ game.quests = [Quest.deserialize(d) for d in data["quests"]]
game._infos = {k: EntityInfo.deserialize(v)
for k, v in data["infos"]}
game.state = State.deserialize(data["state"])
@@ -319,6 +354,7 @@ def deserialize(cls, data: Mapping) -> "Game":
for k, v in data["rules"]}
game._types = VariableTypeTree.deserialize(data["types"])
game.metadata = data.get("metadata", {})
+ game._objective = data.get("objective", None)
return game
@@ -332,24 +368,27 @@ def serialize(self) -> Mapping:
data["world"] = self.world.serialize()
data["state"] = self.state.serialize()
if self.grammar is not None:
- data["grammar"] = self.grammar.flags
+ data["grammar"] = self.grammar.flags.serialize()
data["quests"] = [quest.serialize() for quest in self.quests]
data["infos"] = [(k, v.serialize()) for k, v in self._infos.items()]
data["rules"] = [(k, v.serialize()) for k, v in self._rules.items()]
data["types"] = self._types.serialize()
data["metadata"] = self.metadata
+ data["objective"] = self._objective
return data
def __eq__(self, other: Any) -> bool:
return (isinstance(other, Game) and
self.world == other.world and
self.infos == other.infos and
- self.quests == other.quests)
+ self.quests == other.quests and
+ self._objective == other._objective)
def __hash__(self) -> int:
state = (self.world,
frozenset(self.quests),
- frozenset(self.infos.items()))
+ frozenset(self.infos.items()),
+ self._objective)
return hash(state)
@@ -395,6 +434,21 @@ def win_condition(self) -> List[Collection[Proposition]]:
""" All win conditions, one for each quest. """
return [q.winning_conditions for q in self.quests]
+ @property
+ def objective(self) -> str:
+ if self._objective is not None:
+ return self._objective
+
+ if len(self.quests) == 0:
+ return ""
+
+ self._objective = self.main_quest.desc
+ return self._objective
+
+ @objective.setter
+ def objective(self, value: str):
+ self._objective = value
+
class ActionDependencyTreeElement(DependencyTreeElement):
""" Representation of an `Action` in the dependency tree.
@@ -438,13 +492,76 @@ def is_distinct_from(self, others: List["ActionDependencyTreeElement"]) -> bool:
return len(new_facts) > 0
def __lt__(self, other: "ActionDependencyTreeElement") -> bool:
- return len(other.action.removed & self.action._pre_set) > 0
+ """ Order ActionDependencyTreeElement elements.
+
+ Actions that remove information needed by other actions
+ should be sorted further in the list.
+
+ Notes:
+ This is not a proper ordering, i.e. two actions
+ can mutually removed information needed by each other.
+ """
+ def _required_facts(node):
+ pre_set = set(node.action._pre_set)
+ while node.parent is not None:
+ pre_set |= node.parent.action._pre_set
+ pre_set -= node.action.added
+ node = node.parent
+
+ return pre_set
+
+ return len(other.action.removed & _required_facts(self)) > len(self.action.removed & _required_facts(other))
+ #return len(other.action.removed & self.action._pre_set) > 0
def __str__(self) -> str:
params = ", ".join(map(str, self.action.variables))
return "{}({})".format(self.action.name, params)
+class ActionDependencyTree(DependencyTree):
+
+ def remove(self, action: Action) -> Optional[Action]:
+ super().remove(action)
+
+ # The last action might have impacted one of the subquests.
+ reverse_action = get_reverse_action(action)
+ if reverse_action is not None:
+ self.push(reverse_action)
+
+ return reverse_action
+
+ def flatten(self) -> Iterable[Action]:
+ """
+ Generates a flatten representation of this dependency tree.
+
+ Actions are greedily yielded by iteratively popping leaves from
+ the dependency tree.
+ """
+ tree = self.copy() # Make a copy of the tree to work on.
+
+ actions = []
+ last_reverse_action = None
+ while len(tree.roots) > 0:
+ # Try leaves that doesn't affect the others first.
+ for leaf in sorted(tree.leaves_elements):
+ if leaf.action != last_reverse_action:
+ break # Choose an action that avoids cycles.
+
+ yield leaf.action
+ last_reverse_action = tree.remove(leaf.action)
+
+ def compress(self):
+ for node in self:
+ if node.parent is None:
+ continue
+
+ if len(node.children) == 1:
+ r_action = get_reverse_action(node.element.action)
+ if r_action == node.children[0].element.action:
+ node.parent.children.remove(node)
+ node.parent.children += node.children[0].children
+
+
class QuestProgression:
""" QuestProgression keeps track of the completion of a quest.
@@ -458,130 +575,153 @@ def __init__(self, quest: Quest) -> None:
quest: The quest to keep track of its completion.
"""
self._quest = quest
- self._tree = DependencyTree(element_type=ActionDependencyTreeElement)
- self._winning_policy = list(quest.actions)
-
- # Build a tree representation
- for i, action in enumerate(quest.actions[::-1]):
+ self._completed = False
+ self._failed = False
+ self._unfinishable = False
+
+ # Build a tree representation of the quest.
+ self._tree = ActionDependencyTree(element_type=ActionDependencyTreeElement)
+ self._tree.push(quest.win_action)
+ for action in quest.actions[::-1]:
self._tree.push(action)
- def is_completed(self, state: State) -> bool:
- """ Check whether the quest is completed. """
- return state.is_applicable(self._quest.win_action)
-
- def has_failed(self, state: State) -> bool:
- """ Check whether the quest has failed. """
- if self._quest.fail_action is None:
- return False
-
- return state.is_applicable(self._quest.fail_action)
+ self._winning_policy = quest.actions + (quest.win_action,)
@property
def winning_policy(self) -> List[Action]:
""" Actions to be performed in order to complete the quest. """
- return self._winning_policy
+ if self.done:
+ return []
+
+ return self._winning_policy[:-1] # Discard "win" action.
- def _pop_action_from_tree(self, action: Action, tree: DependencyTree) -> Optional[Action]:
- # The last action was meaningful for the quest.
- tree.pop(action)
+ @property
+ def done(self) -> bool:
+ """ Check if the quest is done (i.e. completed, failed or unfinishable). """
+ return self.completed or self.failed or self.unfinishable
- reverse_action = None
- if tree.root is not None:
- # The last action might have impacted one of the subquests.
- reverse_action = get_reverse_action(action)
- if reverse_action is not None:
- tree.push(reverse_action)
+ @property
+ def completed(self) -> bool:
+ """ Check whether the quest is completed. """
+ return self._completed
- return reverse_action
+ @property
+ def failed(self) -> bool:
+ """ Check whether the quest has failed. """
+ return self._failed
- def _build_policy(self) -> Optional[List[Action]]:
- """ Build a policy given the current state of the QuestTree.
+ @property
+ def unfinishable(self) -> bool:
+ """ Check whether the quest is in an unfinishable state. """
+ return self._unfinishable
- The policy is greedily built by iteratively popping leaves from
- the dependency tree.
+ def update(self, action: Optional[Action] = None, state: Optional[State] = None) -> None:
+ """ Update quest progression given available information.
+
+ Args:
+ action: Action potentially affecting the quest progression.
+ state: Current game state.
"""
- if self._tree is None:
- return None
+ if self.done:
+ return # Nothing to do, the quest is already done.
- tree = self._tree.copy() # Make a copy of the tree to work on.
+ if state is not None:
+ # Check if quest is completed.
+ if self._quest.win_action is not None:
+ self._completed = state.is_applicable(self._quest.win_action)
- policy = []
- last_reverse_action = None
- while tree.root is not None:
- # Try leaves that doesn't affect the others first.
- for leaf in sorted(tree.leaves_elements):
- if leaf.action != last_reverse_action:
- break # Choose an action that avoids cycles.
+ # Check if quest has failed.
+ if self._quest.fail_action is not None:
+ self._failed = state.is_applicable(self._quest.fail_action)
+
+ # Try compressing the winning policy given the new game state.
+ if self.compress_winning_policy(state):
+ return # A shorter winning policy has been found.
- policy.append(leaf.action)
- last_reverse_action = self._pop_action_from_tree(leaf.action, tree)
+ if action is not None:
+ # Determine if we moved away from the goal or closer to it.
+ reverse_action = self._tree.remove(action)
+ if reverse_action is None: # Irreversible action.
+ self._unfinishable = True # Can't track quest anymore.
- return policy
+ self._winning_policy = tuple(self._tree.flatten()) # Rebuild policy.
- def update(self, action: Action, bypass: Optional[List[Action]] = None) -> None:
- """ Update the state of the quest after a given action was performed.
+ def compress_winning_policy(self, state: State) -> bool:
+ """ Compress the winning policy given a game state.
Args:
- action: Action affecting the state of the quest.
+ state: Current game state.
+
+ Returns:
+ Whether the winning policy was compressed or not.
"""
- if bypass is not None:
- for action in bypass:
- self._pop_action_from_tree(action, self._tree)
- self._winning_policy = self._build_policy()
- return
+ def _find_shorter_policy(policy):
+ for j in range(0, len(policy)):
+ for i in range(j + 1, len(policy))[::-1]:
+ shorter_policy = policy[:j] + policy[i:]
+ if state.is_sequence_applicable(shorter_policy):
+ self._tree = ActionDependencyTree(element_type=ActionDependencyTreeElement)
+ for action in shorter_policy[::-1]:
+ self._tree.push(action)
+
+ return shorter_policy
- # Determine if we moved away from the goal or closer to it.
- if action in self._tree.leaves_values:
- # The last action was meaningful for the quest.
- self._pop_action_from_tree(action, self._tree)
- else:
- # The last action must have moved us away from the goal.
- # We need to reverse it.
- reverse_action = get_reverse_action(action)
- if reverse_action is None:
- # Irreversible action.
- self._tree = None # Can't track quest anymore.
- else:
- self._tree.push(reverse_action)
+ return None
- self._winning_policy = self._build_policy()
+ compressed = False
+ policy = _find_shorter_policy(self._winning_policy)
+ while policy is not None:
+ compressed = True
+ self._winning_policy = policy
+ policy = _find_shorter_policy(policy)
+
+ return compressed
class GameProgression:
""" GameProgression keeps track of the progression of a game.
- If `tracking_quest` is True, then `winning_policy` will be the list
+ If `tracking_quests` is True, then `winning_policy` will be the list
of Action that need to be applied in order to complete the game.
"""
- def __init__(self, game: Game, track_quest: bool = True) -> None:
+ def __init__(self, game: Game, track_quests: bool = True) -> None:
"""
Args:
- game: The game to track progression of.
- track_quest: Whether we should track the quest completion.
+ game: The game for which to track progression.
+ track_quests: whether quest progressions are being tracked.
"""
self.game = game
self.state = game.state.copy()
self._valid_actions = list(self.state.all_applicable_actions(self.game._rules.values(),
self.game._types.constants_mapping))
- self.quest_progression = None
- if track_quest and len(game.quests) > 0:
- self.quest_progression = QuestProgression(game.quests[0])
+
+ self.quest_progressions = []
+ if track_quests:
+ self.quest_progressions = [QuestProgression(quest) for quest in game.quests]
+ for quest_progression in self.quest_progressions:
+ quest_progression.update(action=None, state=self.state)
@property
def done(self) -> bool:
- """ Whether the quest is completed or has failed. """
- if self.quest_progression is None:
- return False
+ """ Whether all quests are completed or at least one has failed or is unfinishable. """
+ if not self.tracking_quests:
+ return False # There is nothing to be "done".
- return (self.quest_progression.is_completed(self.state) or
- self.quest_progression.has_failed(self.state))
+ all_completed = True
+ for quest_progression in self.quest_progressions:
+ if quest_progression.failed or quest_progression.unfinishable:
+ return True
+
+ all_completed &= quest_progression.completed
+
+ return all_completed
@property
- def tracking_quest(self) -> bool:
- """ Whether the quest is tracked or not. """
- return self.quest_progression is not None
+ def tracking_quests(self) -> bool:
+ """ Whether quests are being tracked or not. """
+ return len(self.quest_progressions) > 0
@property
def valid_actions(self) -> List[Action]:
@@ -594,12 +734,21 @@ def winning_policy(self) -> Optional[List[Action]]:
Returns:
A policy that leads to winning the game. It can be `None`
- if `tracking_quest` is `False` or the quest has been failed.
+ if `tracking_quests` is `False` or the quest has failed.
"""
- if not self.tracking_quest or self.quest_progression.winning_policy is None:
+ if not self.tracking_quests:
return None
- return list(self.quest_progression.winning_policy)
+ # Check if any quest has failed.
+ if any(quest.failed or quest.unfinishable for quest in self.quest_progressions):
+ return None
+
+ # Greedily build a new winning policy by merging all individual quests' tree.
+ trees = [quest._tree for quest in self.quest_progressions if not quest.done]
+ master_quest_tree = ActionDependencyTree(element_type=ActionDependencyTreeElement,
+ trees=trees)
+
+ return tuple(a for a in master_quest_tree.flatten() if a.name != "win")
def update(self, action: Action) -> None:
""" Update the state of the game given the provided action.
@@ -614,14 +763,6 @@ def update(self, action: Action) -> None:
self._valid_actions = list(self.state.all_applicable_actions(self.game._rules.values(),
self.game._types.constants_mapping))
- if self.tracking_quest:
- if self.state.is_sequence_applicable(self.winning_policy):
- pass # The last action didn't impact the quest.
- else:
- # Check for shortcut.
- for i in range(1, len(self.winning_policy)):
- if self.state.is_sequence_applicable(self.winning_policy[i:]):
- self.quest_progression.update(action, bypass=self.winning_policy[:i])
- return
-
- self.quest_progression.update(action)
+ # Update all quest progressions given the last action and new state.
+ for quest_progression in self.quest_progressions:
+ quest_progression.update(action, self.state)
diff --git a/textworld/generator/inform7/world2inform7.py b/textworld/generator/inform7/world2inform7.py
index e4858744..2223e373 100644
--- a/textworld/generator/inform7/world2inform7.py
+++ b/textworld/generator/inform7/world2inform7.py
@@ -156,7 +156,6 @@ def generate_inform7_source(game, seed=1234, use_i7_description=False):
quests = game.quests
source = ""
- source += "Use scoring. The maximum score is 1.\n"
source += "When play begins, seed the random-number generator with {}.\n\n".format(seed)
source += define_inform7_kinds()
# Mention that rooms have a special text attribute called 'internal name'.
@@ -221,22 +220,30 @@ def generate_inform7_source(game, seed=1234, use_i7_description=False):
# Place the player.
source += "The player is in {}.\n\n".format(var_infos[world.player_room.id].id)
- quest = None
- if len(quests) > 0:
- quest = quests[0] # TODO: randomly sample a quest.
+ objective = game.objective
+ maximum_score = 0
+ for quest_id, quest in enumerate(quests):
commands = gen_commands_from_actions(quest.actions, var_infos)
quest.commands = commands
+ maximum_score += quest.reward
+
+ quest_completed = textwrap.dedent("""\
+ The quest{quest_id} completed is a truth state that varies.
+ The quest{quest_id} completed is usually false.
+ """)
+ source += quest_completed.format(quest_id=quest_id)
- walkthrough = '\nTest me with "{}"\n\n'.format(" / ".join(commands))
+ walkthrough = '\nTest quest{} with "{}"\n\n'.format(quest_id, " / ".join(commands))
source += walkthrough
- # Add winning and losing conditions.
- ending_condition = """\
+ # Add winning and losing conditions for quest.
+ quest_ending_condition = """\
Every turn:
- if {}:
+ if {losing_tests}:
end the story; [Lost]
- else if {}:
- end the story finally; [Win]
+ else if quest{quest_id} completed is false and {winning_tests}:
+ increase the score by {reward}; [Quest completed]
+ Now the quest{quest_id} completed is true.
"""
@@ -246,8 +253,41 @@ def generate_inform7_source(game, seed=1234, use_i7_description=False):
if quest.fail_action is not None:
losing_tests = gen_source_for_conditions(quest.fail_action.preconditions)
- ending_condition = ending_condition.format(losing_tests, winning_tests)
- source += textwrap.dedent(ending_condition)
+ quest_ending_condition = quest_ending_condition.format(losing_tests=losing_tests,
+ winning_tests=winning_tests,
+ reward=quest.reward,
+ quest_id=quest_id)
+ source += textwrap.dedent(quest_ending_condition)
+
+ # Enable scoring is at least one quest has nonzero reward.
+ if maximum_score != 0:
+ source += "Use scoring. The maximum score is {}.\n".format(maximum_score)
+
+ # Build test condition for winning the game.
+ game_winning_test = "1 is 0 [always false]"
+ if len(quests) > 0:
+ test_template = "quest{} completed is true"
+ game_winning_test = " and ".join(test_template.format(i) for i in range(len(quests)))
+
+ # Remove square bracket when printing score increases. Square brackets are conflicting with
+ # Inform7's events parser in git_glulx_ml.py.
+ # And add winning conditions for the game.
+ source += textwrap.dedent("""\
+ This is the simpler notify score changes rule:
+ If the score is not the last notified score:
+ let V be the score - the last notified score;
+ say "Your score has just gone up by [V in words] ";
+ if V > 1:
+ say "points.";
+ else:
+ say "point.";
+ Now the last notified score is the score;
+ if {game_winning_test}:
+ end the story finally; [Win]
+
+ The simpler notify score changes rule substitutes for the notify score changes rule.
+
+ """.format(game_winning_test=game_winning_test))
if not use_i7_description:
# Remove Inform7 listing of nondescript items.
@@ -292,7 +332,7 @@ def generate_inform7_source(game, seed=1234, use_i7_description=False):
Rule for printing the banner text:
say "{objective}[line break]".
- """.format(objective=quest.desc if quest is not None else ""))
+ """.format(objective=objective))
# Simply display *** The End *** when game ends.
source += textwrap.dedent("""\
@@ -300,7 +340,6 @@ def generate_inform7_source(game, seed=1234, use_i7_description=False):
Rule for printing the player's obituary:
if story has ended finally:
- increase score by 1;
center "*** The End ***";
else:
center "*** You lost! ***";
@@ -451,7 +490,7 @@ def generate_inform7_source(game, seed=1234, use_i7_description=False):
source += textwrap.dedent("""\
An objective is some text that varies. The objective is "{objective}".
- """.format(objective=quest.desc if quest is not None else ""))
+ """.format(objective=objective))
# Special command to print the objective of the game, if any.
source += textwrap.dedent("""\
@@ -618,6 +657,7 @@ def generate_inform7_source(game, seed=1234, use_i7_description=False):
Turning on the restrict commands option is an action applying to nothing.
Carry out turning on the restrict commands option:
+ Decrease turn count by 1;
Now the restrict commands option is true.
Understand "restrict commands" as turning on the restrict commands option.
diff --git a/textworld/generator/maker.py b/textworld/generator/maker.py
index 13aba882..debeabec 100644
--- a/textworld/generator/maker.py
+++ b/textworld/generator/maker.py
@@ -3,6 +3,7 @@
import os
+import itertools
from os.path import join as pjoin
from typing import List, Iterable, Union, Optional
@@ -588,6 +589,44 @@ def set_quest_from_commands(self, commands: List[str], ask_for_state: bool = Fal
# Calling build will generate the description for the quest.
self.build()
return self._quests[0]
+
+ def new_fact(self, name: str, *entities: List["WorldEntity"]) -> None:
+ """ Create new fact.
+
+ Args:
+ name: The name of the new fact.
+ *entities: A list of entities as arguments to the new fact.
+ """
+ args = [entity.var for entity in entities]
+ return Proposition(name, args)
+
+ def new_quest_using_commands(self, commands: List[str], winning_facts: Optional[Iterable[Proposition]] = None) -> Quest:
+ """ Add a new quest using predefined text commands.
+
+ This launches a `textworld.play` session.
+
+ Args:
+ commands: Text commands.
+ winning_facts: set of facts that should be true in order to
+ consider the quest as completed. If `None`,
+ the last action's preconditions will be used.
+
+ Returns:
+ The resulting quest.
+ """
+ with make_temp_directory() as tmpdir:
+ try:
+ game_file = self.compile(pjoin(tmpdir, "record_quest"))
+ recorder = Recorder()
+ agent = textworld.agents.WalkthroughAgent(commands)
+ textworld.play(game_file, agent=agent, wrapper=recorder, silent=True)
+ except textworld.agents.WalkthroughDone:
+ pass # Quest is done.
+
+ # Skip "None" actions.
+ actions = [action for action in recorder.actions if action is not None]
+ quest = Quest(actions=actions, winning_conditions=winning_facts)
+ return quest
def set_quest_from_final_state(self, final_state: Collection[Proposition]) -> Quest:
""" Defines the game's quest using a collection of facts.
diff --git a/textworld/generator/tests/test_dependency_tree.py b/textworld/generator/tests/test_dependency_tree.py
index f62bccd8..02c4f667 100644
--- a/textworld/generator/tests/test_dependency_tree.py
+++ b/textworld/generator/tests/test_dependency_tree.py
@@ -25,60 +25,66 @@ def depends_on(self, other):
class TestDependencyTree(unittest.TestCase):
- def test_pop(self):
+ def test_remove(self):
tree = DependencyTree(element_type=CustomDependencyTreeElement)
- assert tree.root is None
- tree.push("G")
- tree.pop("G")
- assert tree.root is None
-
- tree.push("G")
- tree.push("F")
+ assert len(tree.roots) == 0
+ assert tree.push("G")
+ assert tree.remove("G")
+ assert len(tree.roots) == 0
+ assert list(tree) == []
+ assert tree.values == []
+
+ assert tree.push("G")
+ assert tree.push("F")
# Can't pop a non-leaf element.
- assert_raises(ValueError, tree.pop, "G")
- assert tree.root is not None
+ assert not tree.remove("G")
+ assert len(tree.roots) > 0
assert set(tree.leaves_values) == set("F")
- tree.pop("F")
+ assert tree.remove("F")
assert set(tree.leaves_values) == set("G")
def test_push(self):
tree = DependencyTree(element_type=CustomDependencyTreeElement)
- assert tree.root is None
+ assert len(tree.roots) == 0
assert set(tree.leaves_values) == set()
- node = tree.push("G")
+ tree.push("G")
assert set(tree.leaves_values) == set(["G"]), tree.leaves_values
- assert tree.root.element.value == "G"
- assert tree.root is node
- assert len(node.children) == 0
+ assert tree.roots[0].element.value == "G"
+ assert len(tree.roots[0].children) == 0
- node = tree.push("F")
+ tree.push("F")
assert set(tree.leaves_values) == set(["F"])
- node = tree.push("C")
+ tree.push("C")
+ node = tree.roots[0].children[0].children[0]
assert set(tree.leaves_values) == set(["C"])
- assert tree.root.element.value == "G"
+ assert tree.roots[0].element.value == "G"
assert node.element.value == "C"
- assert len(tree.root.children) == 1
+ assert len(tree.roots[0].children) == 1
assert len(node.children) == 0
# Nothing depends on A at the moment.
- node = tree.push("A")
+ tree.push("A")
assert set(tree.leaves_values) == set(["C"])
- node = tree.push("E")
+ tree_ = tree.copy()
+ tree.push("E")
+ assert tree_.values != tree.values
assert set(tree.leaves_values) == set(["E", "C"])
# Add the same element twice at the same level doesn't change the tree.
- node = tree.push("E")
- assert node is None
+ tree_ = tree.copy()
+ tree.push("E")
+ assert tree_.values == tree.values
assert set(tree.leaves_values) == set(["E", "C"])
- # Cannot remove a value that hasn't been added to the tree.
- assert_raises(ValueError, tree.pop, "Z")
+ # Removing a value not associated to a leaf, does not change the tree.
+ assert not tree.remove("Z")
+ assert tree_.values == tree.values
- node = tree.push("A")
+ tree.push("A")
assert set(tree.leaves_values) == set(["A", "C"])
- node = tree.push("B")
+ tree.push("B")
assert set(tree.leaves_values) == set(["B", "A", "C"])
diff --git a/textworld/generator/tests/test_game.py b/textworld/generator/tests/test_game.py
index 7abda58e..d732a3e5 100644
--- a/textworld/generator/tests/test_game.py
+++ b/textworld/generator/tests/test_game.py
@@ -18,42 +18,55 @@
from textworld.generator import make_small_map, make_grammar, make_game_with
from textworld.generator.chaining import ChainingOptions, sample_quest
+from textworld.logic import Action, State
from textworld.generator.game import Quest, Game
from textworld.generator.game import QuestProgression, GameProgression
from textworld.generator.game import UnderspecifiedQuestError
+from textworld.generator.game import ActionDependencyTree, ActionDependencyTreeElement
from textworld.generator.inform7 import gen_commands_from_actions
from textworld.logic import Proposition
+def _apply_command(command: str, game_progression: GameProgression) -> None:
+ """ Apply a text command to a game_progression object.
+ """
+ valid_commands = gen_commands_from_actions(game_progression.valid_actions, game_progression.game.infos)
+
+ for action, cmd in zip(game_progression.valid_actions, valid_commands):
+ if command == cmd:
+ game_progression.update(action)
+ return
+
+ raise ValueError("Not a valid command: {}. Expected: {}".format(command, valid_commands))
+
+
def test_game_comparison():
rngs = {}
rngs['rng_map'] = np.random.RandomState(1)
rngs['rng_objects'] = np.random.RandomState(2)
rngs['rng_quest'] = np.random.RandomState(3)
rngs['rng_grammar'] = np.random.RandomState(4)
- game1 = make_game(world_size=5, nb_objects=5, quest_length=2, grammar_flags={}, rngs=rngs)
+ game1 = make_game(world_size=5, nb_objects=5, quest_length=2, quest_breadth=2, grammar_flags={}, rngs=rngs)
rngs['rng_map'] = np.random.RandomState(1)
rngs['rng_objects'] = np.random.RandomState(2)
rngs['rng_quest'] = np.random.RandomState(3)
rngs['rng_grammar'] = np.random.RandomState(4)
- game2 = make_game(world_size=5, nb_objects=5, quest_length=2, grammar_flags={}, rngs=rngs)
+ game2 = make_game(world_size=5, nb_objects=5, quest_length=2, quest_breadth=2, grammar_flags={}, rngs=rngs)
assert game1 == game2 # Test __eq__
assert game1 in {game2} # Test __hash__
- game3 = make_game(world_size=5, nb_objects=5, quest_length=2, grammar_flags={}, rngs=rngs)
+ game3 = make_game(world_size=5, nb_objects=5, quest_length=2, quest_breadth=2, grammar_flags={}, rngs=rngs)
assert game1 != game3
-
-
def test_variable_infos(verbose=False):
g_rng.set_seed(1234)
grammar_flags = {"theme": "house", "include_adj": True}
- game = textworld.generator.make_game(world_size=5, nb_objects=10, quest_length=3, grammar_flags=grammar_flags)
+ game = textworld.generator.make_game(world_size=5, nb_objects=10, quest_length=3, quest_breadth=2, grammar_flags=grammar_flags)
for var_id, var_infos in game.infos.items():
if var_id not in ["P", "I"]:
@@ -106,12 +119,12 @@ def test_quest_creation(self):
assert quest.win_action.preconditions == self.quest.actions[-1].postconditions
assert quest.fail_action is None
- quest = Quest(actions=None, winning_conditions=self.quest.actions[-1].postconditions)
- assert quest.actions is None
+ quest = Quest(winning_conditions=self.quest.actions[-1].postconditions)
+ assert len(quest.actions) == 0
assert quest.win_action == self.quest.win_action
assert quest.fail_action is None
- npt.assert_raises(UnderspecifiedQuestError, Quest, actions=None, winning_conditions=None)
+ npt.assert_raises(UnderspecifiedQuestError, Quest, actions=[], winning_conditions=None)
quest = Quest(self.quest.actions, failing_conditions=self.failing_conditions)
assert quest.fail_action == self.quest.fail_action
@@ -242,9 +255,6 @@ class TestQuestProgression(unittest.TestCase):
def setUpClass(cls):
M = GameMaker()
- # The goal
- commands = ["open wooden door", "go west", "take carrot", "go east", "drop carrot"]
-
# Create a 'bedroom' room.
R1 = M.new_room("bedroom")
R2 = M.new_room("kitchen")
@@ -262,15 +272,79 @@ def setUpClass(cls):
chest.add_property("open")
R2.add(chest)
+ # The goal
+ commands = ["open wooden door", "go west", "take carrot", "go east", "drop carrot"]
cls.quest = M.set_quest_from_commands(commands)
+ commands = ["open wooden door", "go west", "take carrot", "eat carrot"]
+ cls.eating_quest = M.new_quest_using_commands(commands)
+
+ cls.game = M.build()
+
def test_winning_policy(self):
quest = QuestProgression(self.quest)
assert quest.winning_policy == self.quest.actions
- quest.update(self.quest.actions[0])
- assert quest.winning_policy == self.quest.actions[1:]
+ quest.update(self.quest.actions[0], state=State())
+ assert tuple(quest.winning_policy) == self.quest.actions[1:]
+
+ def test_failing_quest(self):
+ quest = QuestProgression(self.quest)
+
+ state = self.game.state.copy()
+ for action in self.eating_quest.actions:
+ state.apply(action)
+ if action.name.startswith("eat"):
+ quest.update(action, state)
+ assert len(quest.winning_policy) == 0
+ assert quest.done
+ assert not quest.completed
+ assert not quest.failed
+ assert quest.unfinishable
+ break
+
+ assert not quest.done
+ assert len(quest.winning_policy) > 0
+
+
+class TestGameProgression(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ M = GameMaker()
+
+ # The goal
+ commands = ["open wooden door", "go west", "take carrot", "go east", "drop carrot"]
+
+ # Create a 'bedroom' room.
+ R1 = M.new_room("bedroom")
+ R2 = M.new_room("kitchen")
+ M.set_player(R2)
+
+ path = M.connect(R1.east, R2.west)
+ path.door = M.new(type='d', name='wooden door')
+ path.door.add_property("closed")
+
+ carrot = M.new(type='f', name='carrot')
+ R1.add(carrot)
+
+ # Add a closed chest in R2.
+ chest = M.new(type='c', name='chest')
+ chest.add_property("open")
+ R2.add(chest)
+
+ cls.quest = M.set_quest_from_commands(commands)
+ cls.game = M.build()
+
+ def test_is_game_completed(self):
+ game_progress = GameProgression(self.game)
- def test_cycle_in_winning_policy(cls):
+ for action in self.quest.actions:
+ assert not game_progress.done
+ game_progress.update(action)
+
+ assert game_progress.done
+
+ def test_cycle_in_winning_policy(self):
M = GameMaker()
# Create a map.
@@ -303,23 +377,12 @@ def test_cycle_in_winning_policy(cls):
game = M.build()
game_progression = GameProgression(game)
- def _apply_command(command, game_progression):
- valid_commands = gen_commands_from_actions(game_progression.valid_actions, game.infos)
-
- for action, cmd in zip(game_progression.valid_actions, valid_commands):
- if command == cmd:
- game_progression.update(action)
- return
-
- raise ValueError("Not a valid command: {}. Expected: {}".format(command, valid_commands))
-
_apply_command("go south", game_progression)
expected_commands = ["go north"] + commands
winning_commands = gen_commands_from_actions(game_progression.winning_policy, game.infos)
assert winning_commands == expected_commands, "{} != {}".format(winning_commands, expected_commands)
_apply_command("go east", game_progression)
-
_apply_command("go north", game_progression)
expected_commands = ["go south", "go west", "go north"] + commands
winning_commands = gen_commands_from_actions(game_progression.winning_policy, game.infos)
@@ -332,6 +395,7 @@ def _apply_command(command, game_progression):
# Quest where player's has to pick up the carrot first.
commands = ["go east", "take apple", "go west", "go north", "drop apple"]
+
M.set_quest_from_commands(commands)
game = M.build()
game_progression = GameProgression(game)
@@ -351,15 +415,15 @@ def _apply_command(command, game_progression):
winning_commands = gen_commands_from_actions(game_progression.winning_policy, game.infos)
assert winning_commands == expected_commands, "{} != {}".format(winning_commands, expected_commands)
-
-class TestGameProgression(unittest.TestCase):
-
- @classmethod
- def setUpClass(cls):
+ def test_game_with_multiple_quests(self):
M = GameMaker()
- # The goal
- commands = ["open wooden door", "go west", "take carrot", "go east", "drop carrot"]
+ # The subgoals (needs to be executed in order).
+ commands = [["open wooden door", "go west", "take carrot", "go east", "drop carrot"],
+ # Now, the player is back in the kitchen and the wooden door is open.
+ ["go west", "take lettuce", "go east", "drop lettuce"],
+ # Now, the player is back in the kitchen, there are a carrot and a lettuce on the floor.
+ ["take lettuce", "take carrot", "insert carrot into chest", "insert lettuce into chest", "close chest"]]
# Create a 'bedroom' room.
R1 = M.new_room("bedroom")
@@ -371,24 +435,74 @@ def setUpClass(cls):
path.door.add_property("closed")
carrot = M.new(type='f', name='carrot')
- R1.add(carrot)
+ lettuce = M.new(type='f', name='lettuce')
+ R1.add(carrot, lettuce)
# Add a closed chest in R2.
chest = M.new(type='c', name='chest')
chest.add_property("open")
R2.add(chest)
- cls.quest = M.set_quest_from_commands(commands)
- cls.game = M.build()
+ quest1 = M.new_quest_using_commands(commands[0])
+ quest1.desc = "Fetch the carrot and drop it on the kitchen's ground."
+ quest2 = M.new_quest_using_commands(commands[0] + commands[1])
+ quest2.desc = "Fetch the lettuce and drop it on the kitchen's ground."
+ winning_facts = [M.new_fact("in", lettuce, chest),
+ M.new_fact("in", carrot, chest),
+ M.new_fact("closed", chest),]
+ quest3 = M.new_quest_using_commands(commands[0] + commands[1] + commands[2], winning_facts=winning_facts)
+ quest3.desc = "Put the lettuce and the carrot into the chest before closing it."
+
+ M._quests = [quest1, quest2, quest3]
+ assert len(M._quests) == len(commands)
+ game = M.build()
- def test_is_game_completed(self):
- game_progress = GameProgression(self.game)
+ game_progress = GameProgression(game)
+ assert len(game_progress.quest_progressions) == len(game.quests)
- for action in self.quest.actions:
+ # Following the actions associated to the last quest actually corresponds
+ # to solving the whole game.
+ for action in game_progress.winning_policy:
assert not game_progress.done
game_progress.update(action)
assert game_progress.done
+ assert all(quest_progression.done for quest_progression in game_progress.quest_progressions)
+
+ # Try solving the game by greedily taking the first action from the current winning policy.
+ game_progress = GameProgression(game)
+ while not game_progress.done:
+ action = game_progress.winning_policy[0]
+ game_progress.update(action)
+ # print(action.name, [c.name for c in game_progress.winning_policy])
+
+ # Try solving the second quest (i.e. bringing back the lettuce) first.
+ game_progress = GameProgression(game)
+ for command in ["open wooden door", "go west", "take lettuce", "go east", "drop lettuce"]:
+ _apply_command(command, game_progress)
+
+ assert not game_progress.quest_progressions[0].done
+ assert game_progress.quest_progressions[1].done
+
+ for command in ["go west", "take carrot", "go east", "drop carrot"]:
+ _apply_command(command, game_progress)
+
+ assert game_progress.quest_progressions[0].done
+ assert game_progress.quest_progressions[1].done
+
+ for command in ["take lettuce", "take carrot", "insert carrot into chest", "insert lettuce into chest", "close chest"]:
+ _apply_command(command, game_progress)
+
+ assert game_progress.done
+
+ # Game is done whenever a quest has failed or is unfinishable.
+ game_progress = GameProgression(game)
+
+ for command in ["open wooden door", "go west", "take carrot", "eat carrot"]:
+ assert not game_progress.done
+ _apply_command(command, game_progress)
+
+ assert game_progress.done
def test_game_without_a_quest(self):
M = GameMaker()
@@ -406,3 +520,48 @@ def test_game_without_a_quest(self):
action = game_progress.valid_actions[0]
game_progress.update(action)
assert not game_progress.done
+
+
+class TestActionDependencyTree(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ pass
+
+ def test_tolist(self):
+ action_lock = Action.parse("lock/c :: $at(P, r) & $at(c, r) & $match(k, c) & $in(k, I) & closed(c) -> locked(c)")
+ action_close = Action.parse("close/c :: $at(P, r) & $at(c, r) & open(c) -> closed(c)")
+ action_open = Action.parse("open/c :: $at(P, r) & $at(c, r) & closed(c) -> open(c)")
+ action_insert1 = Action.parse("insert :: $at(P, r) & $at(c, r) & $open(c) & in(o1: o, I) -> in(o1: o, c)")
+ action_insert2 = Action.parse("insert :: $at(P, r) & $at(c, r) & $open(c) & in(o2: o, I) -> in(o2: o, c)")
+ action_take1 = Action.parse("take :: $at(P, r) & at(o1: o, r) -> in(o1: o, I)")
+ action_take2 = Action.parse("take :: $at(P, r) & at(o2: o, r) -> in(o2: o, I)")
+ action_win = Action.parse("win :: $in(o1: o, c) & $in(o2: o, c) & $locked(c) -> win(o1: o, o2: o, c)")
+
+ tree = ActionDependencyTree(element_type=ActionDependencyTreeElement)
+ tree.push(action_win)
+ tree.push(action_lock)
+ tree.push(action_close)
+ tree.push(action_insert1)
+ tree.push(action_insert2)
+ tree.push(action_take1)
+ tree.push(action_take2)
+ tree.push(action_open)
+ tree.push(action_close)
+ tree.compress()
+ actions = list(a.name for a in tree.flatten())
+ assert actions == ['take', 'insert', 'take', 'insert', 'close/c', 'lock/c', 'win'], actions
+
+
+ def test_compress(self):
+ tree = ActionDependencyTree(element_type=ActionDependencyTreeElement)
+
+ action_close = Action.parse("close/d :: $at(P, r) & $at(d, r) & open(d) -> closed(d)")
+ action_open = Action.parse("open/d :: $at(P, r) & $at(d, r) & closed(d) -> open(d)")
+ tree.push(action_close)
+ tree.push(action_open)
+ tree.push(action_close)
+ assert len(list(tree)) == 3
+ tree.compress()
+ assert len(list(tree)) == 1
+
diff --git a/textworld/generator/tests/test_logger.py b/textworld/generator/tests/test_logger.py
index 2a5d1ae0..9d5eed4f 100644
--- a/textworld/generator/tests/test_logger.py
+++ b/textworld/generator/tests/test_logger.py
@@ -18,7 +18,7 @@ def test_logger():
for _ in range(10):
seed = rng.randint(65635)
g_rng.set_seed(seed)
- game = textworld.generator.make_game(world_size=5, nb_objects=10, quest_length=3)
+ game = textworld.generator.make_game(world_size=5, nb_objects=10, quest_length=3, quest_breadth=3)
game_logger.collect(game)
with make_temp_directory(prefix="textworld_tests") as tests_folder:
diff --git a/textworld/generator/tests/test_text_generation.py b/textworld/generator/tests/test_text_generation.py
index 1a34c4c5..ef5f7fee 100644
--- a/textworld/generator/tests/test_text_generation.py
+++ b/textworld/generator/tests/test_text_generation.py
@@ -77,8 +77,10 @@ def test_blend_instructions(verbose=False):
grammar2 = textworld.generator.make_grammar(flags={"blend_instructions": True},
rng=np.random.RandomState(42))
+ quest.desc = None
game.change_grammar(grammar1)
quest1 = quest.copy()
+ quest.desc = None
game.change_grammar(grammar2)
quest2 = quest.copy()
assert len(quest1.desc) > len(quest2.desc)
diff --git a/textworld/generator/tests/test_text_grammar.py b/textworld/generator/tests/test_text_grammar.py
index 3942727f..d78f2207 100644
--- a/textworld/generator/tests/test_text_grammar.py
+++ b/textworld/generator/tests/test_text_grammar.py
@@ -5,6 +5,7 @@
import unittest
from textworld.generator.text_grammar import Grammar
+from textworld.generator.text_grammar import GrammarFlags
class ContainsEveryObjectContainer:
@@ -12,6 +13,13 @@ def __contains__(self, item):
return True
+class TestGrammarFlags(unittest.TestCase):
+ def test_serialization(self):
+ flags = GrammarFlags()
+ data = flags.serialize()
+ flags2 = GrammarFlags.deserialize(data)
+ assert flags == flags2
+
class GrammarTest(unittest.TestCase):
def test_grammar_eq(self):
grammar = Grammar()
@@ -20,7 +28,7 @@ def test_grammar_eq(self):
def test_grammar_eq2(self):
grammar = Grammar()
- grammar2 = Grammar(flags={'unused': 'flag'})
+ grammar2 = Grammar(flags={'theme': 'something'})
self.assertNotEqual(grammar, grammar2, "Testing two different grammar files are not equal")
def test_grammar_get_random_expansion_fail(self):
diff --git a/textworld/generator/text_generation.py b/textworld/generator/text_generation.py
index 89b63c16..2ba3e81b 100644
--- a/textworld/generator/text_generation.py
+++ b/textworld/generator/text_generation.py
@@ -6,6 +6,7 @@
from collections import OrderedDict
from textworld.generator import data
+from textworld.generator.game import Quest
from textworld.generator.text_grammar import Grammar
from textworld.generator.text_grammar import fix_determinant
@@ -21,7 +22,7 @@ def __getitem__(self, item):
return super().__getitem__(item)
-def assign_new_matching_names(obj1_infos, obj2_infos, grammar, include_adj, exclude=[]):
+def assign_new_matching_names(obj1_infos, obj2_infos, grammar, exclude=[]):
tag = "#({}<->{})_match#".format(obj1_infos.type, obj2_infos.type)
if not grammar.has_tag(tag):
return False
@@ -31,8 +32,8 @@ def assign_new_matching_names(obj1_infos, obj2_infos, grammar, include_adj, excl
result = grammar.expand(tag)
first, second = result.split("<->") # Matching arguments are separated by '<->'.
- name1, adj1, noun1 = grammar.split_name_adj_noun(first.strip(), include_adj)
- name2, adj2, noun2 = grammar.split_name_adj_noun(second.strip(), include_adj)
+ name1, adj1, noun1 = grammar.split_name_adj_noun(first.strip(), grammar.flags.include_adj)
+ name2, adj2, noun2 = grammar.split_name_adj_noun(second.strip(), grammar.flags.include_adj)
if name1 not in exclude and name2 not in exclude and name1 != name2:
found_matching_names = True
break
@@ -52,7 +53,7 @@ def assign_new_matching_names(obj1_infos, obj2_infos, grammar, include_adj, excl
return True
-def assign_name_to_object(obj, grammar, game_infos, include_adj):
+def assign_name_to_object(obj, grammar, game_infos):
"""
Assign a name to an object (if needed).
"""
@@ -66,20 +67,19 @@ def assign_name_to_object(obj, grammar, game_infos, include_adj):
# Check if the object should match another one (i.e. same adjective).
if obj.matching_entity_id is not None:
other_obj_infos = game_infos[obj.matching_entity_id]
- success = assign_new_matching_names(obj_infos, other_obj_infos, grammar, include_adj, exclude)
+ success = assign_new_matching_names(obj_infos, other_obj_infos, grammar, exclude)
if success:
return
# Try swapping the objects around i.e. match(o2, o1).
- success = assign_new_matching_names(other_obj_infos, obj_infos, grammar, include_adj, exclude)
+ success = assign_new_matching_names(other_obj_infos, obj_infos, grammar, exclude)
if success:
return
# TODO: Should we enforce it?
# Fall back on generating unmatching object name.
- values = grammar.generate_name(obj.type, room_type=obj_infos.room_type,
- include_adj=include_adj, exclude=exclude)
+ values = grammar.generate_name(obj.type, room_type=obj_infos.room_type, exclude=exclude)
obj_infos.name, obj_infos.adj, obj_infos.noun = values
grammar.used_names.add(obj_infos.name)
@@ -103,19 +103,13 @@ def assign_description_to_object(obj, grammar, game_infos):
def generate_text_from_grammar(game, grammar: Grammar):
- include_adj = grammar.flags.get("include_adj", False)
- only_last_action = grammar.flags.get("only_last_action", False)
- blend_instructions = grammar.flags.get("blend_instructions", False)
- blend_descriptions = grammar.flags.get("blend_descriptions", False)
- ambiguous_instructions = grammar.flags.get("ambiguous_instructions", False)
-
# Assign a specific room type and name to our rooms
for room in game.world.rooms:
# First, generate a unique roomtype and name from the grammar
if game.infos[room.id].room_type is None and grammar.has_tag("#room_type#"):
game.infos[room.id].room_type = grammar.expand("#room_type#")
- assign_name_to_object(room, grammar, game.infos, include_adj)
+ assign_name_to_object(room, grammar, game.infos)
# Next, assure objects contained in a room must have the same room type
for obj in game.world.get_all_objects_in(room):
@@ -127,38 +121,35 @@ def generate_text_from_grammar(game, grammar: Grammar):
if game.infos[obj.id].room_type is None and grammar.has_tag("#room_type#"):
game.infos[obj.id].room_type = grammar.expand("#room_type#")
- # We have to "count" all the adj/noun/types in the world
- # This is important for using "unique" but abstracted references to objects
- counts = OrderedDict()
- counts["adj"] = CountOrderedDict()
- counts["noun"] = CountOrderedDict()
- counts["type"] = CountOrderedDict()
-
# Assign name and description to objects.
for obj in game.world.objects:
if obj.type in ["I", "P"]:
continue
- obj_infos = game.infos[obj.id]
- assign_name_to_object(obj, grammar, game.infos, include_adj)
+ assign_name_to_object(obj, grammar, game.infos)
assign_description_to_object(obj, grammar, game.infos)
- counts['adj'][obj_infos.adj] += 1
- counts['noun'][obj_infos.noun] += 1
- counts['type'][obj.type] += 1
-
# Generate the room descriptions.
for room in game.world.rooms:
- assign_description_to_room(room, game, grammar, blend_descriptions)
+ if game.infos[room.id].desc is None: # Skip rooms which already have a description.
+ game.infos[room.id].desc = assign_description_to_room(room, game, grammar)
# Generate the instructions.
for quest in game.quests:
- assign_description_to_quest(quest, game, grammar, counts, only_last_action, blend_instructions, ambiguous_instructions)
+ if quest.desc is None: # Skip quests which already have a description.
+ quest.desc = assign_description_to_quest(quest, game, grammar)
+
+ if grammar.flags.only_last_action and len(game.quests) > 1:
+ main_quest = Quest(actions=[quest.actions[-1] for quest in game.quests])
+ only_last_action_bkp = grammar.flags.only_last_action
+ grammar.flags.only_last_action = False
+ game.objective = assign_description_to_quest(main_quest, game, grammar)
+ grammar.flags.only_last_action = only_last_action_bkp
return game
-def assign_description_to_room(room, game, grammar, blend_descriptions):
+def assign_description_to_room(room, game, grammar):
"""
Assign a descripton to a room.
"""
@@ -188,7 +179,7 @@ def assign_description_to_room(room, game, grammar, blend_descriptions):
obj_infos = game.infos[obj.id]
adj, noun = obj_infos.adj, obj_infos.noun
- if blend_descriptions:
+ if grammar.flags.blend_descriptions:
found = False
for type in ["noun", "adj"]:
group_filt = []
@@ -248,7 +239,7 @@ def assign_description_to_room(room, game, grammar, blend_descriptions):
exits_desc = []
# Describing exits with door.
- if blend_descriptions and len(exits_with_closed_door) > 1:
+ if grammar.flags.blend_descriptions and len(exits_with_closed_door) > 1:
dirs, door_objs = zip(*exits_with_closed_door)
e_desc = grammar.expand("#room_desc_doors_closed#")
e_desc = replace_num(e_desc, len(door_objs))
@@ -263,7 +254,7 @@ def assign_description_to_room(room, game, grammar, blend_descriptions):
d_desc = d_desc.replace("(dir)", dir_)
exits_desc.append(d_desc)
- if blend_descriptions and len(exits_with_open_door) > 1:
+ if grammar.flags.blend_descriptions and len(exits_with_open_door) > 1:
dirs, door_objs = zip(*exits_with_open_door)
e_desc = grammar.expand("#room_desc_doors_open#")
e_desc = replace_num(e_desc, len(door_objs))
@@ -279,7 +270,7 @@ def assign_description_to_room(room, game, grammar, blend_descriptions):
exits_desc.append(d_desc)
# Describing exits without door.
- if blend_descriptions and len(exits_without_door) > 1:
+ if grammar.flags.blend_descriptions and len(exits_without_door) > 1:
e_desc = grammar.expand("#room_desc_exits#").replace("(dir)", list_to_string(exits_without_door, False))
e_desc = repl_sing_plur(e_desc, len(exits_without_door))
exits_desc.append(e_desc)
@@ -291,7 +282,7 @@ def assign_description_to_room(room, game, grammar, blend_descriptions):
room_desc += " ".join(exits_desc)
# Finally, set the description
- game.infos[room.id].desc = fix_determinant(room_desc)
+ return fix_determinant(room_desc)
class MergeAction:
@@ -308,7 +299,7 @@ def __init__(self):
self.end = None
-def generate_instruction(action, grammar, game_infos, world, counts, ambiguous_instructions):
+def generate_instruction(action, grammar, game_infos, world, counts):
"""
Generate text instruction for a specific action.
"""
@@ -360,7 +351,7 @@ def generate_instruction(action, grammar, game_infos, world, counts, ambiguous_i
obj = world.find_object_by_id(var.name)
obj_infos = game_infos[obj.id]
- if ambiguous_instructions:
+ if grammar.flags.ambiguous_instructions:
assert False, "not tested"
choices = []
@@ -393,29 +384,46 @@ def generate_instruction(action, grammar, game_infos, world, counts, ambiguous_i
return desc, separator
-def assign_description_to_quest(quest, game, grammar, counts, only_last_action, blend_instructions, ambiguous_instructions):
+def assign_description_to_quest(quest, game, grammar):
"""
Assign a descripton to a quest.
"""
+ # We have to "count" all the adj/noun/types in the world
+ # This is important for using "unique" but abstracted references to objects
+ counts = OrderedDict()
+ counts["adj"] = CountOrderedDict()
+ counts["noun"] = CountOrderedDict()
+ counts["type"] = CountOrderedDict()
+
+ # Assign name and description to objects.
+ for obj in game.world.objects:
+ if obj.type in ["I", "P"]:
+ continue
+
+ obj_infos = game.infos[obj.id]
+ counts['adj'][obj_infos.adj] += 1
+ counts['noun'][obj_infos.noun] += 1
+ counts['type'][obj.type] += 1
+
if len(quest.actions) == 0:
# We don't need to say anything if the quest is empty
- quest.desc = "Choose your own adventure!"
+ quest_desc = "Choose your own adventure!"
else:
# Generate a description for either the last, or all commands
- if only_last_action:
- actions_desc, _ = generate_instruction(quest.actions[-1], grammar, game.infos, game.world, counts, ambiguous_instructions)
+ if grammar.flags.only_last_action:
+ actions_desc, _ = generate_instruction(quest.actions[-1], grammar, game.infos, game.world, counts)
only_one_action = True
else:
actions_desc = ""
# Decide if we blend instructions together or not
- if blend_instructions:
+ if grammar.flags.blend_instructions:
instructions = get_action_chains(quest.actions, grammar, game.infos)
else:
instructions = quest.actions
only_one_action = len(instructions) < 2
for c in instructions:
- desc, separator = generate_instruction(c, grammar, game.infos, game.world, counts, ambiguous_instructions)
+ desc, separator = generate_instruction(c, grammar, game.infos, game.world, counts)
actions_desc += desc
if c != instructions[-1] and len(separator) > 0:
actions_desc += separator
@@ -428,7 +436,9 @@ def assign_description_to_quest(quest, game, grammar, counts, only_last_action,
quest_tag = grammar.get_random_expansion("#quest#")
quest_tag = quest_tag.replace("(list_of_actions)", actions_desc.strip())
- quest.desc = grammar.expand(quest_tag)
+ quest_desc = grammar.expand(quest_tag)
+
+ return quest_desc
def get_action_chains(actions, grammar, game_infos):
diff --git a/textworld/generator/text_grammar.py b/textworld/generator/text_grammar.py
index 1417a060..9637e584 100644
--- a/textworld/generator/text_grammar.py
+++ b/textworld/generator/text_grammar.py
@@ -36,7 +36,7 @@ def fix_determinant(var):
class GrammarFlags:
- __slots__ = ['theme', 'include_adj', 'blend_descriptions',
+ __slots__ = ['theme', 'names_to_exclude', 'include_adj', 'blend_descriptions',
'ambiguous_instructions', 'only_last_action',
'blend_instructions',
'allowed_variables_numbering', 'unique_expansion']
@@ -45,6 +45,7 @@ def __init__(self, flags=None, **kwargs):
flags = flags or kwargs
self.theme = flags.get("theme", "house")
+ self.names_to_exclude = flags.get("names_to_exclude", [])
self.allowed_variables_numbering = flags.get("allowed_variables_numbering", False)
self.unique_expansion = flags.get("unique_expansion", False)
self.include_adj = flags.get("include_adj", False)
@@ -53,18 +54,37 @@ def __init__(self, flags=None, **kwargs):
self.blend_descriptions = flags.get("blend_descriptions", False)
self.ambiguous_instructions = flags.get("ambiguous_instructions", False)
- def encode(self):
+ def serialize(self) -> Mapping:
+ return {slot: getattr(self, slot) for slot in self.__slots__}
+
+ @classmethod
+ def deserialize(cls, data: Mapping) -> "GrammarFlags":
+ return cls(data)
+
+ def __eq__(self, other) -> bool:
+ return (isinstance(other, GrammarFlags) and
+ all(getattr(self, slot) == getattr(other, slot) for slot in self.__slots__))
+
+ def encode(self) -> str:
""" Generate UUID for this set of grammar flags.
"""
- values = [int(getattr(self, s)) for s in self.__slots__[1:]]
+ def _unsigned(n):
+ return n & 0xFFFFFFFFFFFFFFFF
+
+ # Skip theme and names_to_exclude.
+ values = [int(getattr(self, s)) for s in self.__slots__[2:]]
flag = "".join(map(str, values))
from hashids import Hashids
hashids = Hashids(salt="TextWorld")
+ if len(self.names_to_exclude) > 0:
+ names_to_exclude_hash = _unsigned(hash(frozenset(self.names_to_exclude)))
+ return self.theme + "-" + hashids.encode(names_to_exclude_hash) + "-" + hashids.encode(int(flag))
+
return self.theme + "-" + hashids.encode(int(flag))
-def encode_flags(flags):
+def encode_flags(flags: Mapping) -> str:
return GrammarFlags(flags).encode()
@@ -83,19 +103,19 @@ def __init__(self, flags: Mapping = {}, rng: Optional[RandomState] = None):
:param rng:
Random generator used for sampling tag expansions.
"""
- self.flags = flags
+ self.flags = GrammarFlags(flags)
self.grammar = OrderedDict()
self.rng = g_rng.next() if rng is None else rng
- self.allowed_variables_numbering = self.flags.get("allowed_variables_numbering", False)
- self.unique_expansion = self.flags.get("unique_expansion", False)
+ self.allowed_variables_numbering = self.flags.allowed_variables_numbering
+ self.unique_expansion = self.flags.unique_expansion
self.all_expansions = defaultdict(list)
# The current used symbols
self.overflow_dict = OrderedDict()
- self.used_names = set(self.flags.get("names_to_exclude", []))
+ self.used_names = set(self.flags.names_to_exclude)
# Load the grammar associated to the provided theme.
- self.theme = self.flags.get("theme", "house")
+ self.theme = self.flags.theme
grammar_contents = []
# Load the object names file
@@ -111,7 +131,7 @@ def __eq__(self, other):
return (isinstance(other, Grammar) and
self.overflow_dict == other.overflow_dict and
self.grammar == other.grammar and
- self.flags == other.flags and
+ self.flags.encode() == other.flags.encode() and
self.used_names == other.used_names)
def _parse(self, lines: List[str]):
@@ -174,7 +194,6 @@ def get_random_expansion(self, tag: str, rng: Optional[RandomState] = None) -> s
self.all_expansions[tag].append(expansion)
return expansion
-
def expand(self, text: str, rng: Optional[RandomState] = None) -> str:
"""
Expand some text until there is no more tag to expand.
@@ -235,7 +254,7 @@ def split_name_adj_noun(self, candidate: str, include_adj: bool) -> Optional[Tup
return name, adj, noun
def generate_name(self, obj_type: str, room_type: str = "",
- include_adj: bool = True, exclude: Container[str] = []) -> Tuple[str, str, str]:
+ include_adj: Optional[bool] = None, exclude: Container[str] = []) -> Tuple[str, str, str]:
"""
Generate a name given an object type and the type room it belongs to.
@@ -248,6 +267,7 @@ def generate_name(self, obj_type: str, room_type: str = "",
include_adj : optional
If True, the name can contain a generated adjective.
If False, any generated adjective will be discarded.
+ Default: use value grammar.flags.include_adj
exclude : optional
List of names we should avoid generating.
@@ -260,6 +280,8 @@ def generate_name(self, obj_type: str, room_type: str = "",
noun :
The noun part of the name.
"""
+ if include_adj is None:
+ include_adj = self.flags.include_adj
# Get room-specialized name, if possible.
symbol = "#{}_({})#".format(room_type, obj_type)
diff --git a/textworld/helpers.py b/textworld/helpers.py
index 6b5b3620..4c4ee7a7 100644
--- a/textworld/helpers.py
+++ b/textworld/helpers.py
@@ -119,7 +119,7 @@ def play(game_file: str, agent: Optional[Agent] = None, max_nb_steps: int = 1000
print(msg)
-def make(world_size: int = 1, nb_objects: int = 5, quest_length: int = 2,
+def make(world_size: int = 1, nb_objects: int = 5, quest_length: int = 2, quest_breadth: int = 1,
grammar_flags: Mapping = {}, seed: int = None,
games_dir: str = "./gen_games/") -> Tuple[str, Game]:
""" Makes a text-based game.
@@ -128,6 +128,7 @@ def make(world_size: int = 1, nb_objects: int = 5, quest_length: int = 2,
world_size: Number of rooms in the world.
nb_objects: Number of objects in the world.
quest_length: Minimum number of actions the quest requires to be completed.
+ quest_breadth: Control how nonlinear a quest can be (1: linear).
grammar_flags: Grammar options.
seed: Random seed for the game generation process.
games_dir: Path to the directory where the game will be saved.
@@ -137,6 +138,6 @@ def make(world_size: int = 1, nb_objects: int = 5, quest_length: int = 2,
"""
g_rng.set_seed(seed)
game_name = "game_{}".format(seed)
- game = make_game(world_size, nb_objects, quest_length, grammar_flags)
+ game = make_game(world_size, nb_objects, quest_length, quest_breadth, grammar_flags)
game_file = compile_game(game, game_name, games_folder=games_dir, force_recompile=True)
return game_file, game
diff --git a/textworld/render/render.py b/textworld/render/render.py
index 50da8ee0..458e77c5 100644
--- a/textworld/render/render.py
+++ b/textworld/render/render.py
@@ -215,18 +215,18 @@ def used_pos():
edges.append((room.name, target.name, room.doors.get(exit)))
# temp_viz(nodes, edges, pos, color=[world.player_room.name])
- pos = {game_infos[k].name: v for k, v in pos.items()}
rooms = {}
player_room = world.player_room
if game_infos is None:
new_game = Game(world, [])
game_infos = new_game.infos
- game_infos["objective"] = new_game.quests[0].desc
for k, v in game_infos.items():
if v.name is None:
v.name = k
+ pos = {game_infos[k].name: v for k, v in pos.items()}
+
for room in world.rooms:
rooms[room.id] = GraphRoom(game_infos[room.id].name, room)
@@ -354,9 +354,7 @@ def visualize(world: Union[Game, State, GlulxGameState, World],
if isinstance(world, Game):
game = world
state = load_state(game.world, game.infos)
- state["objective"] = ""
- if len(game.quests) > 0:
- state["objective"] = game.quests[0].desc
+ state["objective"] = game.objective
elif isinstance(world, GlulxGameState):
state = load_state_from_game_state(game_state=world)
elif isinstance(world, World):