Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 42 additions & 43 deletions textworld/generator/chaining.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class ChainingOptions:
subquests:
Whether to also return incomplete quests, which could be extended
without reaching the depth or breadth limits.
independent_chains:
Whether to allow totally independent parallel chains.
create_variables:
Whether new variables may be created during chaining.
fixed_mapping:
Expand All @@ -99,6 +101,7 @@ def __init__(self):
self.min_breadth = 1
self.max_breadth = 1
self.subquests = False
self.independent_chains = False
self.create_variables = False
self.fixed_mapping = data.get_types().constants_mapping
self.rng = None
Expand Down Expand Up @@ -189,30 +192,17 @@ class _Node:
A node in a chain being generated.

Each node is aware of its position (depth, breadth) in the dependency tree
induced by the chain. For generating parallel quests, the backtracks field
holds actions that can be use to go up the dependency tree and start a new
chain.

For example, taking the action node.backtracks[i][j] will produce a new node
at depth (i + 1) and breadth (self.breadth + 1). To avoid duplication, in
trees like this:

root
/ | \
A B C
| | |
.......

A.backtracks[0] will be [B, C], B.backtracks[0] will be [C], and
C.backtracks[0] will be [].
induced by the chain. To avoid duplication when generating parallel chains,
each node stores the actions that have already been used at that depth.
"""

def __init__(self, parent, dep_parent, state, action, backtracks, depth, breadth):
def __init__(self, parent, dep_parent, state, action, rules, used, depth, breadth):
self.parent = parent
self.dep_parent = dep_parent
self.state = state
self.action = action
self.backtracks = backtracks
self.rules = rules
self.used = used
self.depth = depth
self.breadth = breadth

Expand All @@ -235,7 +225,7 @@ def __init__(self, state, options):

def root(self) -> _Node:
"""Create the root node for chaining."""
return _Node(None, None, self.state, None, [], 0, 1)
return _Node(None, None, self.state, None, [], set(), 0, 1)

def chain(self, node: _Node) -> Iterable[_Node]:
"""
Expand All @@ -251,30 +241,21 @@ def chain(self, node: _Node) -> Iterable[_Node]:
if self.rng:
self.rng.shuffle(assignments)

partials = []
actions = []
states = []
used = set()
for partial in assignments:
action = self.try_instantiate(node.state, partial)
if not action:
continue

if not self.check_action(node, action):
if not self.check_action(node, node.state, action):
continue

state = self.apply(node, action)
if not state:
continue

partials.append(partial)
actions.append(action)
states.append(state)

for i, action in enumerate(actions):
# Only allow backtracking into later actions, to avoid duplication
remaining = partials[i+1:]
backtracks = node.backtracks + [remaining]
yield _Node(node, node, states[i], action, backtracks, node.depth + 1, node.breadth)
used = used | {action}
yield _Node(node, node, state, action, rules, used, node.depth + 1, node.breadth)

def backtrack(self, node: _Node) -> Iterable[_Node]:
"""
Expand All @@ -284,21 +265,39 @@ def backtrack(self, node: _Node) -> Iterable[_Node]:
if node.breadth >= self.max_breadth:
return

for i, partials in enumerate(node.backtracks):
backtracks = node.backtracks[:i]

for j, partial in enumerate(partials):
parent = node
parents = []
while parent.dep_parent:
if parent.depth == 1 and not self.options.independent_chains:
break
parents.append(parent)
parent = parent.dep_parent
parents = parents[::-1]

for sibling in parents:
parent = sibling.dep_parent
rules = self.options.get_rules(parent.depth)
assignments = self.all_assignments(node, rules)
if self.rng:
self.rng.shuffle(assignments)

for partial in assignments:
action = self.try_instantiate(node.state, partial)
if not action:
continue

if action in sibling.used:
continue

if not self.check_action(parent, node.state, action):
continue

state = self.apply(node, action)
if not state:
continue

remaining = partials[j+1:]
new_backtracks = backtracks + [remaining]
yield _Node(node, partial.node, state, action, new_backtracks, i + 1, node.breadth + 1)
used = sibling.used | {action}
yield _Node(node, parent, state, action, rules, used, sibling.depth, node.breadth + 1)

def all_assignments(self, node: _Node, rules: Iterable[Rule]) -> Iterable[_PartialAction]:
"""
Expand Down Expand Up @@ -359,7 +358,7 @@ def create_variable(self, state, ph, type_counts):
type_counts[ph.type] += 1
return var

def check_action(self, node: _Node, action: Action) -> bool:
def check_action(self, node: _Node, state: State, action: Action) -> bool:
# Find the last action before a navigation action
# TODO: Fold this behaviour into ChainingOptions.check_action()
nav_parent = node
Expand Down Expand Up @@ -387,7 +386,7 @@ def check_action(self, node: _Node, action: Action) -> bool:
if len(recent.added & relevant) == 0 or len(pre_navigation.added & relevant) == 0:
return False

return self.options.check_action(node.state, action)
return self.options.check_action(state, action)

def _is_navigation(self, action):
return action.name.startswith("go/")
Expand All @@ -405,8 +404,8 @@ def apply(self, node: _Node, action: Action) -> Optional[State]:

new_state.apply(action)

# Some debug checks
assert self.check_state(new_state)
if not self.check_state(new_state):
return None

# Detect cycles
state = new_state.copy()
Expand Down
71 changes: 70 additions & 1 deletion textworld/generator/tests/test_chaining.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,4 +237,73 @@ def test_parallel_quests():
options.min_breadth = 1
options.create_variables = True
chains = list(get_chains(State(), options))
assert len(chains) == 6
assert len(chains) == 5


def test_parallel_quests_navigation():
logic = GameLogic.parse("""
type P {
}

type I {
}

type r {
rules {
move :: at(P, r) & $free(r, r') -> at(P, r');
}

constraints {
atat :: at(P, r) & at(P, r') -> fail();
}
}

type o {
rules {
take :: $at(P, r) & at(o, r) -> in(o, I);
}

constraints {
inat :: in(o, I) & at(o, r) -> fail();
}
}

type flour : o {
}

type eggs : o {
}

type cake {
rules {
bake :: in(flour, I) & in(eggs, I) -> in(cake, I) & in(flour, cake) & in(eggs, cake);
}

constraints {
inincake :: in(o, I) & in(o, cake) -> fail();
atincake :: at(o, r) & in(o, cake) -> fail();
}
}
""")

state = State([
Proposition.parse("at(P, r3: r)"),
Proposition.parse("free(r2: r, r3: r)"),
Proposition.parse("free(r1: r, r2: r)"),
])

bake = [logic.rules["bake"]]
non_bake = [r for r in logic.rules.values() if r.name != "bake"]

options = ChainingOptions()
options.backward = True
options.create_variables = True
options.min_depth = 3
options.max_depth = 3
options.min_breadth = 2
options.max_breadth = 2
options.logic = logic
options.rules_per_depth = [bake, non_bake, non_bake]
options.restricted_types = {"P", "r"}
chains = list(get_chains(state, options))
assert len(chains) == 2