Description
Hello!
The pseudocode of LATS in your paper unifies three phases of MCTS: expansion, simulation and selection.
In this setting, even if you set n = 1, some nodes can still end up with more than one child, because during the unified phase a node that has already been expanded can be selected and expanded again, adding one more child to it.
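For concreteness, here is a minimal sketch of how I read the unified phase; `Node`, `select`, and `unified_step` are my own illustrative names, not your code or the paper's notation:

```python
import random

# Minimal sketch of the unified selection/expansion/simulation step as I read
# the pseudocode. Node, select, and unified_step are illustrative names only.
class Node:
    def __init__(self, state, parent=None):
        self.state = state
        self.parent = parent
        self.children = []

def select(root):
    # Stand-in for UCT selection: walk down the tree, possibly stopping at a
    # node that already has children.
    node = root
    while node.children and random.random() < 0.7:
        node = random.choice(node.children)
    return node

def unified_step(root, n=1):
    node = select(root)                # may be a node expanded in an earlier iteration
    for _ in range(n):                 # even with n = 1 ...
        child = Node(state=node.state + "->a", parent=node)
        node.children.append(child)    # ... repeated selection of the same node
                                       # gives it additional children over time
    return node

root = Node("root")
for _ in range(50):
    unified_step(root, n=1)
print(len(root.children))  # typically well above 1, even though n = 1
```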
However, in the code you provided in hotpot/lats.py (excerpted below), the selection, expansion, and simulation (i.e. rollout) phases are independent.
There are two major differences:
- During the selection phase, no new actions are generated, so no new nodes are added.
- During the simulation phase, the newly generated nodes are not saved (i.e. stored in the `.children` of their parent nodes) but are discarded instead (see the sketch after this list).
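To illustrate Difference 2, here is a toy sketch of what it means for simulated nodes to be discarded; again, these are my own illustrative names, not your actual rollout code:

```python
# Toy sketch of Difference 2: the simulation builds a chain of nodes but never
# attaches it to the tree (illustrative names only).
class Node:
    def __init__(self, state, parent=None):
        self.state = state
        self.parent = parent
        self.children = []

def rollout_like(node, max_depth=4):
    current = node
    for d in range(max_depth):
        child = Node(state=f"{current.state}->a{d}", parent=current)
        # child is NOT appended to current.children, so the tree never keeps
        # the trajectory explored during simulation
        current = child
    return current  # only the terminal node is returned, for backpropagation

root = Node("root")
terminal = rollout_like(root)
print(len(root.children))  # 0 -- the simulated nodes are effectively discarded
```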
So, here's the problem. Suppose k = 50 but n = 1 (the hotpot setting in your paper) and the depth limit is L = 10. After 10 iterations the tree has been extended to its maximum depth, so no more nodes can be added (already-expanded nodes cannot be expanded again, which is Difference 1), and further simulation does not help either (the maximum depth has been reached). The remaining 40 iterations just select all the way down to the depth limit and start over, which is meaningless.
Following your pseudocode, however, you still have the chance to add more nodes (with n = 1, one node per iteration) under nodes that have already been expanded, so exploration is possible in all 50 iterations.
I guess the setting k = 50, n = 1 (with a depth limit smaller than 50) is meaningful in your paper in the sense that every iteration can mean more exploration, but the code you provided here on GitHub seems to say otherwise.
```python
def lats_search(args, task, idx, iterations=30, to_print=True):
    # ...
    for i in range(iterations):
        logging.info(f"Iteration {i + 1}...")

        # Selection phase: no new actions are generated here
        node = select_node(root)
        while node is None or (node.is_terminal and node.reward != 1):
            logging.info(f"Need to backtrack or terminal node with reward 0 found at iteration {i + 1}, reselecting...")
            node = select_node(root)

        if node is None:
            logging.info("All paths lead to terminal nodes with reward 0. Ending search.")
            break

        if node.is_terminal and node.reward == 1:
            logging.info(f"Terminal node with reward 1 found at iteration {i + 1}")
            return node.state, node.value, all_nodes, node.reward, node.em

        # Expansion phase: only a previously unexpanded node gets children
        expand_node(node, args, task)
        while node.is_terminal or not node.children:
            logging.info(f"Depth limit node found at iteration {i + 1}, reselecting...")
            node = select_node(root)
            expand_node(node, args, task)

        value = evaluate_node(node, args, task)

        # Simulation (rollout) phase: find the child with the highest value
        reward, terminal_node = rollout(max(node.children, key=lambda child: child.value), args, task, idx, max_depth=4)
        terminal_nodes.append(terminal_node)

        if terminal_node.reward == 1:
            logging.info("SUCCESSFUL TRAJECTORY FOUND DURING SIMULATION")
            return terminal_node.state, terminal_node.value, [], terminal_node.reward, terminal_node.em

        # Backpropagation
        backpropagate(terminal_node, reward)
    # ...
```
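To make the arithmetic above concrete, here is a toy count of how many new tree nodes each scheme can add with n = 1, L = 10, and k = 50; this is a simplified model of my argument, not your code:

```python
# Toy count: new tree nodes added under each scheme with n = 1 expansions per
# iteration, depth limit L = 10, and k = 50 iterations, assuming the tree is a
# single chain when n = 1 (simplified model, not the repo's code).
K, N, L = 50, 1, 10

# Separated phases (hotpot/lats.py): each iteration expands one previously
# unexpanded node; once the chain reaches depth L, nothing more can be added.
separated_nodes = sum(N for i in range(K) if i < L)

# Unified phases (the paper's pseudocode): an already-expanded node can be
# expanded again, so every iteration can still add N new nodes somewhere.
unified_nodes = sum(N for _ in range(K))

print(separated_nodes, unified_nodes)  # 10 50
```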