Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/test_play_generated_games.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def test_play_generated_games():
game_state, reward, done = env.step(command)

if done:
msg = "Finished before playing `max_steps` steps."
msg = "Finished before playing `max_steps` steps because of command '{}'.".format(command)
if game_state.has_won:
msg += " (winning)"
assert game_state._game_progression.winning_policy == []
assert len(game_state._game_progression.winning_policy) == 0

if game_state.has_lost:
msg += " (losing)"
Expand Down
2 changes: 1 addition & 1 deletion textworld/envs/glulx/git_glulx_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def init(self, output: str, game=None,
"""
output = _strip_input_prompt_symbol(output)
super().init(output)
self._game_progression = GameProgression(game, track_quest=compute_intermediate_reward)
self._game_progression = GameProgression(game, track_quests=compute_intermediate_reward)
self._state_tracking = state_tracking
self._compute_intermediate_reward = compute_intermediate_reward and len(game.quests) > 0

Expand Down
166 changes: 102 additions & 64 deletions textworld/generator/dependency_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@


import textwrap
from typing import List, Any, Iterable

from textworld.utils import uniquify

Expand All @@ -18,42 +19,51 @@ class DependencyTreeElement:
`__str__` accordingly.
"""

def __init__(self, value):
def __init__(self, value: Any):
self.value = value
self.parent = None

def depends_on(self, other):
def depends_on(self, other: "DependencyTreeElement") -> bool:
"""
Check whether this element depends on the `other`.
"""
return self.value > other.value

def is_distinct_from(self, others):
def is_distinct_from(self, others: Iterable["DependencyTreeElement"]) -> bool:
"""
Check whether this element is distinct from `others`.
"""
return self.value not in [other.value for other in others]

def __str__(self):
def __str__(self) -> str:
return str(self.value)


class DependencyTree:
class _Node:
def __init__(self, element):
def __init__(self, element: DependencyTreeElement):
self.element = element
self.children = []
self.parent = None

def push(self, node):
def push(self, node: "DependencyTree._Node") -> bool:
if node == self:
return
return True

added = False
for child in self.children:
child.push(node)
added |= child.push(node)

if self.element.depends_on(node.element) and not self.already_added(node):
node = node.copy()
self.children.append(node)
node.element.parent = self.element
node.parent = self
return True

def already_added(self, node):
return added

def already_added(self, node: "DependencyTree._Node") -> bool:
# We want to avoid duplicate information about dependencies.
if node in self.children:
return True
Expand All @@ -63,14 +73,15 @@ def already_added(self, node):
if not node.element.is_distinct_from((child.element for child in self.children)):
return True

# for child in self.children:
# # if node.element.value == child.element.value:
# if not node.element.is_distinct_from((child.element):
# return True

return False

def __str__(self):
def __iter__(self) -> Iterable["DependencyTree._Node"]:
for child in self.children:
yield from list(child)

yield self

def __str__(self) -> str:
node_text = str(self.element)

txt = [node_text]
Expand All @@ -79,85 +90,112 @@ def __str__(self):

return "\n".join(txt)

def copy(self):
def copy(self) -> "DependencyTree._Node":
node = DependencyTree._Node(self.element)
node.children = [child.copy() for child in self.children]
for child in self.children:
child_ = child.copy()
child_.parent = node
node.children.append(child_)

return node

def __init__(self, element_type=DependencyTreeElement):
self.root = None
def __init__(self, element_type: type = DependencyTreeElement, trees: Iterable["DependencyTree"] = []):
self.roots = []
self.element_type = element_type
for tree in trees:
self.roots += [root.copy() for root in tree.roots]

self._update()

def push(self, value):
def push(self, value: Any, allow_multi_root: bool = False) -> bool:
""" Add a value to this dependency tree.

Adding a value already present in the tree does not modify the tree.

Args:
value: value to add.
allow_multi_root: if `True`, allow the value to spawn an
additional root if needed.

"""
element = self.element_type(value)
node = DependencyTree._Node(element)
if self.root is None:
self.root = node
else:
self.root.push(node)

# Recompute leaves.
self._update()
if element in self.leaves_elements:
return node
added = False
for root in self.roots:
added |= root.push(node)

return None
if len(self.roots) == 0 or (not added and allow_multi_root):
self.roots.append(node)
added = True

def pop(self, value):
if value not in self.leaves_values:
raise ValueError("That element is not a leaf: {!r}.".format(value))
self._update() # Recompute leaves.
return added

def _visit(node):
for child in list(node.children):
if child.element.value == value:
node.children.remove(child)
def remove(self, value: Any) -> None:
""" Remove all leaves having the given value.

self._postorder(self.root, _visit)
if self.root.element.value == value:
self.root = None
The value to remove needs to belong to at least one leaf in this tree.
Otherwise, the tree remains unchanged.

# Recompute leaves.
self._update()
Args:
value: value to remove from the tree.

Returns:
Whether the tree has changed or not.
"""
if value not in self.leaves_values:
return False

root_to_remove = []
for node in self:
if node.element.value == value:
if node.parent is not None:
node.parent.children.remove(node)
else:
root_to_remove.append(node)

def _postorder(self, node, visit):
for child in node.children:
self._postorder(child, visit)
for node in root_to_remove:
self.roots.remove(node)

visit(node)
self._update() # Recompute leaves.
return True

def _update(self):
def _update(self) -> None:
self._leaves_values = []
self._leaves_elements = set()
self._leaves_elements = []

def _visit(node):
for node in self:
if len(node.children) == 0:
self._leaves_elements.add(node.element)
self._leaves_elements.append(node.element)
self._leaves_values.append(node.element.value)

if self.root is not None:
self._postorder(self.root, _visit)

self._leaves_values = uniquify(self._leaves_values)
self._leaves_elements = uniquify(self._leaves_elements)

def copy(self):
tree = DependencyTree(self.element_type)
if self.root is not None:
tree.root = self.root.copy()
tree._update()
def copy(self) -> "DependencyTree":
tree = type(self)(element_type=self.element_type)
for root in self.roots:
tree.roots.append(root.copy())

tree._update()
return tree

def __iter__(self) -> Iterable["DependencyTree._Node"]:
for root in self.roots:
yield from list(root)

@property
def leaves_elements(self):
def values(self) -> List[Any]:
return [node.element.value for node in self]

@property
def leaves_elements(self) -> List[DependencyTreeElement]:
return self._leaves_elements

@property
def leaves_values(self):
def leaves_values(self) -> List[Any]:
return self._leaves_values

def __str__(self):
if self.root is None:
return ""

return str(self.root)
def __str__(self) -> str:
return "\n".join(map(str, self.roots))
Loading