Add a solver to the python MCTS.
Fixes #81, at least in python.

PiperOrigin-RevId: 272641652
Change-Id: I18259a715c461ab94cb9ca9bdac7bb00751e70b7
DeepMind Technologies Ltd authored and open_spiel@google.com committed Oct 3, 2019
1 parent 976feb7 commit 4ecd616
Showing 3 changed files with 220 additions and 25 deletions.
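
The heart of the change is MCTS-Solver-style backpropagation: when a simulation reaches a terminal state its returns are stored as a proven `outcome`, and on the way back up a node is marked proven if the player to move has a proven win or if all of its children are already solved. Below is a minimal, self-contained sketch of that backup rule, assuming a stripped-down `Node` class, a two-player game with `MAX_UTILITY = 1.0`, and no chance nodes; the committed code in mcts.py further down also handles chance nodes and keeps the usual reward/visit statistics alongside.

# Illustrative sketch of the solver backup rule, not the committed code.
MAX_UTILITY = 1.0  # assumed best achievable return for the player to move


class Node(object):

  def __init__(self, player, children=None, outcome=None):
    self.player = player          # the player who chose the action leading here
    self.children = children or []
    self.outcome = outcome        # per-player returns once proven, else None


def back_up_solved(visit_path):
  """Marks nodes on the visit path as proven, walking leaf to root."""
  for node in reversed(visit_path):
    if not node.children:
      continue  # the terminal leaf's outcome is set where the game ended
    player = node.children[0].player  # the player to move at `node`
    best, all_solved = None, True
    for child in node.children:
      if child.outcome is None:
        all_solved = False
      elif best is None or child.outcome[player] > best.outcome[player]:
        best = child
    # Proven if the mover has a proven win, or if every reply is already solved.
    if best is not None and (all_solved or best.outcome[player] == MAX_UTILITY):
      node.outcome = best.outcome
    else:
      return  # this node stays uncertain, so nothing above it can be proven


# Tiny example: both replies are solved (a loss and a draw), so the root is a draw.
root = Node(player=0, children=[Node(player=0, outcome=[-1.0, 1.0]),
                                Node(player=0, outcome=[0.0, 0.0])])
back_up_solved([root])
print(root.outcome)  # [0.0, 0.0]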
111 changes: 87 additions & 24 deletions open_spiel/python/algorithms/mcts.py
@@ -81,34 +81,56 @@ class SearchNode(object):
total_reward: The sum of rewards of rollouts through this node, from the
parent node's perspective. The average reward of this node is
`total_reward / explore_count`
outcome: The rewards for all players if this is a terminal node or the
subtree has been proven, otherwise None.
children: A list of SearchNodes representing the possible actions from this
node, along with their expected rewards.
"""
__slots__ = [
"action", "player", "explore_count", "total_reward", "children"
"action", "player", "explore_count", "total_reward", "outcome",
"children",
]

def __init__(self, action, player):
self.action = action
self.player = player
self.explore_count = 0
self.total_reward = 0.0
self.outcome = None
self.children = []

def uct_value(self, parent_explore_count, uct_c, child_default_value):
"""Returns the UCT value of child."""
if self.outcome is not None:
return self.outcome[self.player]

if self.explore_count == 0:
return child_default_value

return self.total_reward / self.explore_count + uct_c * math.sqrt(
math.log(parent_explore_count) / self.explore_count)

def most_visited_child(self):
"""Returns the most visited action from this node."""
return max(self.children, key=lambda c: c.explore_count)
def best_child(self):
"""Returns the best action from this node, either proven or most visited.
This ordering leads to choosing:
- Highest proven score > 0 over anything else, including a promising but
unproven action.
- A proven draw only if it has higher exploration than others that are
uncertain, or the others are losses.
- Uncertain action with most exploration over loss of any difficulty
- Hardest loss if everything is a loss
- Highest expected reward if explore counts are equal (unlikely).
- Longest win, if multiple are proven (unlikely due to early stopping).
"""
def key(c):
return (0 if c.outcome is None else c.outcome[c.player],
c.explore_count, c.total_reward)

return max(self.children, key=key)

def children_str(self, state=None):
"""Returns the string representation of all children.
"""Returns the string representation of this node's children.
They are ordered in decreasing explore count.
@@ -122,11 +144,10 @@ def children_str(self, state=None):
])

def to_str(self, state=None):
"""Returns the string repr.
"""Returns the string representation of this node.
of this node children.
Looks like: "d4h: player: 1, value: 244.0 / 2017 = 0.121, 20 children"
Looks like:
"d4h: player: 1, value: 244.0 / 2017 = 0.121, outcome: 1.0, 20 children"
Args:
state: A `pyspiel.State` object, to be used to convert the action id into
@@ -135,9 +156,11 @@ def to_str(self, state=None):
action = (state.action_to_string(state.current_player(), self.action)
if state else str(self.action))
return ("{:>3}: player: {}, value: {:6.1f} / {:4d} = {:6.3f}, "
"{:3d} children").format(
"outcome: {}, {:3d} children").format(
action, self.player, self.total_reward, self.explore_count,
self.explore_count and self.total_reward / self.explore_count,
("{:4.1f}".format(self.outcome[self.player])
if self.outcome else "none"),
len(self.children))

def __str__(self):
@@ -153,6 +176,7 @@ def __init__(self,
uct_c,
max_simulations,
evaluator,
solve=True,
random_state=None,
verbose=False):
"""Initializes a MCTS Search algorithm in the form of a bot.
@@ -170,6 +194,7 @@ def __init__(self,
search tree should be evaluated. This is correlated with memory size and
tree depth.
evaluator: A `Evaluator` object to use to evaluate a leaf node.
solve: Whether to back up solved states.
random_state: An optional numpy RandomState to make it deterministic.
verbose: Whether to print information about the search tree before
returning the action. Useful for confirming the search is working
@@ -191,11 +216,27 @@ def __init__(self,
self.child_default_value = float("inf")
self.player = player
self.verbose = verbose
self.solve = solve
self.max_utility = game.max_utility()
self._random_state = random_state or np.random.RandomState()

def step(self, state):
"""Returns bot's policy and action at given state."""
mcts_action = self.mcts_search(state)
root = self.mcts_search(state)

best = root.best_child()

if self.verbose:
print("Root:", root.to_str())
print("Children:")
print(root.children_str(state))
print("Children of chosen:")
chosen_state = state.clone()
chosen_state.apply_action(best.action)
print(best.children_str(chosen_state))

mcts_action = best.action

policy = [(action, (1.0 if action == mcts_action else 0.0))
for action in state.legal_actions(self.player_id())]

@@ -295,22 +336,44 @@ def mcts_search(self, state):
visit_path, working_state = self._apply_tree_policy(root, state)
if working_state.is_terminal():
returns = working_state.returns()
visit_path[-1].outcome = returns
solved = self.solve
else:
returns = self.evaluator.evaluate(working_state, self._random_state)
solved = False

for node in visit_path:
for node in reversed(visit_path):
node.total_reward += returns[node.player]
node.explore_count += 1

most_visited = root.most_visited_child()

if self.verbose:
print("Root:", root.to_str())
print("Children:")
print(root.children_str(working_state))
print("Children of chosen:")
chosen_state = state.clone()
chosen_state.apply_action(most_visited.action)
print(most_visited.children_str(chosen_state))

return most_visited.action
if solved and node.children:
player = node.children[0].player
if player == pyspiel.PlayerId.CHANCE:
# Only back up chance nodes if all have the same outcome.
# An alternative would be to back up the weighted average of
# outcomes if all children are solved, but that is less clear.
outcome = node.children[0].outcome
if (outcome is not None and
all(np.array_equal(c.outcome, outcome) for c in node.children)):
node.outcome = outcome
else:
solved = False
else:
# If any have max utility (won?), or all children are solved,
# choose the one best for the player choosing.
best = None
all_solved = True
for child in node.children:
if child.outcome is None:
all_solved = False
elif best is None or child.outcome[player] > best.outcome[player]:
best = child
if (best is not None and
(all_solved or best.outcome[player] == self.max_utility)):
node.outcome = best.outcome
else:
solved = False
if root.outcome is not None:
break

return root
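
To make the ordering in `SearchNode.best_child` above concrete, here is a small example with invented actions, counts, and outcomes (it assumes the open_spiel Python package is importable). The key tuple `(outcome value or 0, explore_count, total_reward)` lets a lightly explored proven win outrank a heavily visited but unproven child, and pushes a proven loss below both.

# Hypothetical numbers illustrating best_child's key; not taken from any game.
from open_spiel.python.algorithms import mcts


def node(action, explore_count, total_reward, outcome=None):
  n = mcts.SearchNode(action, player=0)
  n.explore_count, n.total_reward, n.outcome = explore_count, total_reward, outcome
  return n


root = mcts.SearchNode(None, player=0)
root.children = [
    node(0, explore_count=500, total_reward=300),                 # most visited, unproven
    node(1, explore_count=20, total_reward=15, outcome=[1.0]),    # proven win
    node(2, explore_count=80, total_reward=-60, outcome=[-1.0]),  # proven loss
]
print(root.best_child().action)  # 1: the proven win, despite far fewer visits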
132 changes: 131 additions & 1 deletion open_spiel/python/algorithms/mcts_test.py
@@ -19,17 +19,48 @@
from __future__ import print_function

import math
from absl.testing import absltest
import random

from absl.testing import absltest
import numpy as np

from open_spiel.python.algorithms import evaluate_bots
from open_spiel.python.algorithms import mcts
import pyspiel


def _get_action(state, action_str):
for action in state.legal_actions():
if action_str == state.action_to_string(state.current_player(), action):
return action
raise ValueError("invalid action string: {}".format(action_str))


def search_tic_tac_toe_state(initial_actions):
game = pyspiel.load_game("tic_tac_toe")
state = game.new_initial_state()
for action_str in initial_actions.split(" "):
state.apply_action(_get_action(state, action_str))
bot = mcts.MCTSBot(
game, player=state.current_player(), uct_c=math.sqrt(2),
max_simulations=10000, solve=True,
evaluator=mcts.RandomRolloutEvaluator(n_rollouts=20))
return bot.mcts_search(state), state


def make_node(action, player=0, **kwargs):
node = mcts.SearchNode(action, player)
for k, v in kwargs.items():
setattr(node, k, v)
return node


class MctsBotTest(absltest.TestCase):

def assertTTTStateStr(self, state, expected):
expected = expected.replace(" ", "").strip()
self.assertEqual(str(state), expected)

def test_can_play_tic_tac_toe(self):
game = pyspiel.load_game("tic_tac_toe")
uct_c = math.sqrt(2)
@@ -70,6 +101,105 @@ def test_can_play_three_player_game(self):
v = evaluate_bots.evaluate_bots(game.new_initial_state(), bots, np.random)
self.assertEqual(sum(v), 0)

def test_solve_draw(self):
root, state = search_tic_tac_toe_state("x(1,1) o(0,0) x(2,2)")
self.assertTTTStateStr(state, """
o..
.x.
..x
""")
self.assertEqual(root.outcome[root.player], 0)
for c in root.children:
self.assertLessEqual(c.outcome[c.player], 0) # No winning moves.

best = root.best_child()
self.assertEqual(best.outcome[best.player], 0)
self.assertIn(state.action_to_string(best.player, best.action),
("o(0,2)", "o(2,0)")) # All others lose.

def test_solve_loss(self):
root, state = search_tic_tac_toe_state("x(1,1) o(0,0) x(2,2) o(1,0) x(2,0)")
self.assertTTTStateStr(state, """
oox
.x.
..x
""")
self.assertEqual(root.outcome[root.player], -1)
for c in root.children:
self.assertEqual(c.outcome[c.player], -1) # All losses.

def test_solve_win(self):
root, state = search_tic_tac_toe_state("x(1,0) o(2,2)")
self.assertTTTStateStr(state, """
.x.
...
..o
""")
self.assertEqual(root.outcome[root.player], 1)
best = root.best_child()
self.assertEqual(best.outcome[best.player], 1)
self.assertEqual(state.action_to_string(best.player, best.action), "x(2,0)")

def assertBestChild(self, choice, children):
# If this causes flakiness, the key in `SearchNode.best_child` is bad.
random.shuffle(children)
root = make_node(-1, children=children)
self.assertEqual(root.best_child().action, choice)

def test_choose_most_visited_when_not_solved(self):
self.assertBestChild(0, [
make_node(0, explore_count=50, total_reward=30),
make_node(1, explore_count=40, total_reward=40),
])

def test_choose_win_over_most_visited(self):
self.assertBestChild(1, [
make_node(0, explore_count=50, total_reward=30),
make_node(1, explore_count=40, total_reward=40, outcome=[1]),
])

def test_choose_best_over_good(self):
self.assertBestChild(1, [
make_node(0, explore_count=50, total_reward=30, outcome=[0.5]),
make_node(1, explore_count=40, total_reward=40, outcome=[0.8]),
])

def test_choose_bad_over_worst(self):
self.assertBestChild(0, [
make_node(0, explore_count=50, total_reward=30, outcome=[-0.5]),
make_node(1, explore_count=40, total_reward=40, outcome=[-0.8]),
])

def test_choose_positive_reward_over_promising(self):
self.assertBestChild(1, [
make_node(0, explore_count=50, total_reward=40), # more promising
make_node(1, explore_count=10, total_reward=1, outcome=[0.1]), # solved
])

def test_choose_most_visited_over_loss(self):
self.assertBestChild(0, [
make_node(0, explore_count=50, total_reward=30),
make_node(1, explore_count=40, total_reward=40, outcome=[-1]),
])

def test_choose_most_visited_over_draw(self):
self.assertBestChild(0, [
make_node(0, explore_count=50, total_reward=30),
make_node(1, explore_count=40, total_reward=40, outcome=[0]),
])

def test_choose_uncertainty_over_most_visited_loss(self):
self.assertBestChild(1, [
make_node(0, explore_count=50, total_reward=30, outcome=[-1]),
make_node(1, explore_count=40, total_reward=40),
])

def test_choose_slowest_loss(self):
self.assertBestChild(1, [
make_node(0, explore_count=50, total_reward=10, outcome=[-1]),
make_node(1, explore_count=60, total_reward=15, outcome=[-1]),
])


if __name__ == "__main__":
absltest.main()
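
A condensed end-to-end usage sketch, following the test helper above; the game, simulation budget, and rollout count are arbitrary choices, and a small budget may not prove the opening position.

# Usage sketch mirroring search_tic_tac_toe_state above; parameters are arbitrary.
import math

from open_spiel.python.algorithms import mcts
import pyspiel

game = pyspiel.load_game("tic_tac_toe")
state = game.new_initial_state()
bot = mcts.MCTSBot(
    game, player=state.current_player(), uct_c=math.sqrt(2),
    max_simulations=10000, solve=True,
    evaluator=mcts.RandomRolloutEvaluator(n_rollouts=5))
root = bot.mcts_search(state)
if root.outcome is not None:  # the solver proved the root's value
  print("proven value for the player to move:", root.outcome[root.player])
print("chosen action:", state.action_to_string(
    state.current_player(), root.best_child().action))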
2 changes: 2 additions & 0 deletions open_spiel/python/examples/mcts.py
@@ -40,6 +40,7 @@
flags.DEFINE_integer("rollout_count", 10, "How many rollouts to do.")
flags.DEFINE_integer("max_simulations", 10000, "How many nodes to expand.")
flags.DEFINE_integer("num_games", 1, "How many games to play.")
flags.DEFINE_bool("solve", True, "Whether to use MCTS-Solver.")
flags.DEFINE_bool("quiet", False, "Don't show the moves as they're played.")
flags.DEFINE_bool("verbose", False, "Show the MCTS stats of possible moves.")

@@ -61,6 +62,7 @@ def _init_bot(bot_type, game, player_id):
FLAGS.uct_c,
FLAGS.max_simulations,
evaluator,
solve=FLAGS.solve,
verbose=FLAGS.verbose)
if bot_type == "random":
return uniform_random.UniformRandomBot(game, player_id, np.random)
