https://int8.io/monte-carlo-tree-search-beginners-guide/

In [78]:
import math
import random
from functools import lru_cache
from itertools import count
from collections import defaultdict
from gamestate import (
    KlonState, state_is_win, 
    play_move, get_legal_moves,
    to_pretty_string
)
from policies import yan_et_al

WIN_REWARD = 1000
CYCLE_REWARD = -10
TIMEOUT_REWARD = -1


class MCTS:
    """
    Monte Carlo tree search with UCT child selection.

    See Algorithm 2: The UCT algorithm

    Browne, C. B., Powley, E., Whitehouse, D., Lucas, S. M.,
    Cowling, P. I., Rohlfshagen, P., ... Colton, S. (2012).
    A Survey of Monte Carlo Tree Search Methods.
    IEEE Transactions on Computational Intelligence and AI in Games, 4(1), 1–43.
    https://doi.org/10.1109/TCIAIG.2012.2186810

    Nodes handed to this class must provide: ``is_terminal()``,
    ``not_fully_expanded()``, ``reward()``, ``play_move(move)``,
    ``rollout_policy_child()``, the sets ``legal_moves``, ``tried_actions``
    and ``children``, the counters ``Q`` (summed reward) and ``N`` (visit
    count), and a ``parent`` link that is ``None`` at the root.

    Every phase prints a trace line to stdout.
    """

    def __init__(self, root):
        """Remember ``root`` as the node every search iteration starts from."""
        self.root = root

    def search(self, rollouts=50):
        """
        Run ``rollouts`` select/expand -> simulate -> backpropagate iterations.

        Returns the root's child with the best mean reward (``best_child``
        with no exploration bonus).  ``max`` raises ``ValueError`` if the root
        ends up with no children at all.
        """
        for _ in range(rollouts):
            leaf = self.traverse(self.root)
            reward = self.simulation(leaf)
            self.backpropagate(leaf, reward)
        return self.best_child(self.root)

    def traverse(self, node):
        """
        Tree policy: walk down from ``node`` to the node to simulate from.

        Descends through fully expanded nodes via UCT and stops at the first
        node that still has an untried move, returning a freshly created child
        of it.  If a terminal node is reached first, that node is returned.
        """
        print(f'traverse {node}')
        while not node.is_terminal():
            if node.not_fully_expanded():
                return self.expand(node)
            node = self.best_child(node, exploration_term=0.9)
        print(f'returning terminal node {node}')
        return node

    def expand(self, node):
        """
        Add one new child to ``node`` for a randomly chosen untried move.

        The move is recorded in ``node.tried_actions`` and the resulting child
        is added to ``node.children``.  The expansion is carried out here,
        through the node's move API, so the created child is in hand to be
        returned to the caller.

        Returns the new child.  ``random.choice`` raises ``IndexError`` if
        every legal move has already been tried.
        """
        print(f'expanding {node}')
        untried = node.legal_moves - node.tried_actions
        move = random.choice(list(untried))
        node.tried_actions.add(move)
        child = node.play_move(move)
        node.children.add(child)
        return child

    def best_child(self, node, exploration_term=0):
        """
        Return the child of ``node`` with the highest UCT score.

        The score is ``Q/N + exploration_term * sqrt(2 ln(N_parent) / N)``.
        With the default ``exploration_term=0`` this is simply the child with
        the best mean reward.  ``max`` raises ``ValueError`` when ``node`` has
        no children.
        """
        print(f'best child for node {node}, exp {exploration_term}')

        def uct(c):
            # An unvisited child has no mean reward yet; rank it last.
            if c.N == 0:
                return float('-inf')
            exploit = c.Q / c.N
            # node.N >= c.N >= 1 here because backpropagate() bumps every
            # ancestor of a visited node, so the logarithm is defined.
            explore = math.sqrt((2 * math.log(node.N)) / c.N)
            return exploit + exploration_term * explore

        best = max(node.children, key=uct)
        print(f'best child {best}')
        return best

    def simulation(self, node):
        """
        Default policy: play rollout moves from ``node`` and return a reward.

        Follows ``rollout_policy_child()`` for at most 1000 steps.  Returns
        the reward of the first terminal node met, or ``TIMEOUT_REWARD`` when
        the step budget runs out first.
        """
        print(f'simulate on node {node}')
        rollout_depth = 1000
        for _ in range(rollout_depth):
            if node.is_terminal():
                r = node.reward()
                print(f'terminal node reward {r}')
                return r
            node = node.rollout_policy_child()
        print(f'timeout reward {TIMEOUT_REWARD}')
        return TIMEOUT_REWARD

    def backpropagate(self, node, reward):
        """Add ``reward`` to ``Q`` and one visit to ``N`` from ``node`` up to the root."""
        while node is not None:
            print(f"  > backpropagate {node}")
            node.Q += reward
            node.N += 1
            node = node.parent
            

class Node(KlonState):
    """
    A search-tree node: a ``KlonState`` plus MCTS bookkeeping.

    Attributes set at construction:
        parent:        the node this one was reached from, ``None`` at the root
        action:        the move played on ``parent`` to reach this node
        children:      set of child nodes created so far
        tried_actions: set of moves already expanded from this node
        ancestors:     frozenset of every node on the path from the root
        Q:             sum of rewards backpropagated through this node
        N:             number of times this node was visited

    Equality and hashing are not overridden here, so they come from
    ``KlonState``.  NOTE(review): they are presumably value-based (the cycle
    check only works if a newly built node can equal an ancestor) - confirm
    against ``gamestate``.
    """

    def __new__(cls, *klonstate, parent=None, action=None):
        """Build the state from ``klonstate`` fields and attach tree bookkeeping."""
        self = super(Node, cls).__new__(cls, *klonstate)
        self.parent = parent
        self.action = action
        self.children = set()
        self.tried_actions = set()
        # Test against None explicitly: the parent is itself a state object
        # and should not be judged by its truthiness.
        if parent is not None:
            self.ancestors = parent.ancestors | {parent}
        else:
            self.ancestors = frozenset()
        self.Q = 0
        self.N = 0
        return self

    def __repr__(self):
        """Short tag derived from the node's hash, e.g. ``K25795``."""
        return f"K{hash(self)%99999:05}"

    def reward(self):
        """Return ``WIN_REWARD`` for a win, ``CYCLE_REWARD`` for a repeated state, else 0."""
        if self.is_win:
            return WIN_REWARD
        if self.is_cycled:
            return CYCLE_REWARD
        return 0

    def is_terminal(self):
        """
        Return True when no further play from this node is possible or useful.

        That is a win, a cycle, or a state with no legal moves.  The last case
        has reward 0; without it both the rollout policy and child selection
        would be asked to pick from nothing.
        """
        return self.reward() != 0 or not self.legal_moves

    def not_fully_expanded(self):
        """
        Return True while at least one legal move has not been expanded yet.

        Counts ``tried_actions`` rather than ``children``: ``children`` is a
        set, so two moves whose resulting nodes compare equal collapse into
        one entry and the child count could never reach the move count.
        """
        return len(self.legal_moves) > len(self.tried_actions)

    # NOTE: lru_cache on a method keys the cache on the node itself and, with
    # maxsize=None, keeps a reference to every node it has seen.
    @property
    @lru_cache(maxsize=None)
    def is_win(self):
        """Whether ``state_is_win`` reports this state as won (cached)."""
        return state_is_win(self)

    @property
    def is_cycled(self):
        """Whether a node equal to this one already occurs on the path from the root."""
        return self in self.ancestors

    @property
    @lru_cache(maxsize=None)
    def legal_moves(self):
        """The moves ``get_legal_moves`` returns for this state (cached)."""
        return get_legal_moves(self)

    def play_move(self, move):
        """Return the child node reached by playing ``move`` from this node."""
        child = play_move(self, move)
        return Node(*child, parent=self, action=move)

    def rollout_policy_move(self):
        """
        Return the legal move ranked highest by the ``yan_et_al`` policy.

        Moves are ordered by ``(yan_et_al(move, self), move)``, so the move
        itself breaks ties between equal scores.

        Raises:
            IndexError: if there are no legal moves.
        """
        # copied from policies.py
        moves = self.legal_moves
        if not moves:
            raise IndexError("no legal moves to choose a rollout move from")
        return max(moves, key=lambda mc: (yan_et_al(mc, self), mc))

    def rollout_policy_child(self):
        """Return the child reached by playing the rollout policy's move."""
        move = self.rollout_policy_move()
        return self.play_move(move)

    def expand_random_child(self):
        """
        Expand one randomly chosen untried move and return the new child.

        The move is added to ``tried_actions`` and the child to ``children``.
        ``random.choice`` raises ``IndexError`` when no untried move is left.
        """
        moves = self.legal_moves - self.tried_actions
        move = random.choice(list(moves))
        self.tried_actions.add(move)
        child = self.play_move(move)
        self.children.add(child)
        return child

    def to_pretty_string(self):
        """Return the ``gamestate.to_pretty_string`` rendering of this state."""
        return to_pretty_string(self)


# Demo run: build a starting state and run a few MCTS iterations by hand.
from benchmarking import random_state
# Seed first so the run is repeatable.
# NOTE(review): presumably random_state() draws from the `random` module - confirm.
random.seed(0)
state = Node(*random_state())


# `self` is bound at module level so the loop below can be pasted
# unchanged from (or back into) the body of MCTS.search.
self = MCTS(state)
# Inlined body of the `search(self, rollouts=50)` method, with 10 rollouts:
for i in range(10):
    leaf = self.traverse(self.root)
    reward = self.simulation(leaf)
    self.backpropagate(leaf, reward)
# Default exploration_term=0: pick the root child with the best mean reward.
ret = self.best_child(self.root)
print('final', ret)
print(ret.to_pretty_string())

traverse K25795
expanding K25795
playmove DR7
simulate on node K62559
terminal node reward -10
  > backpropagate K62559
  > backpropagate K25795
traverse K25795
expanding K25795
playmove DR8
simulate on node K67678
terminal node reward -10
  > backpropagate K67678
  > backpropagate K25795
traverse K25795
expanding K25795
playmove 27
simulate on node K34533
terminal node reward -10
  > backpropagate K34533
  > backpropagate K25795
traverse K25795
expanding K25795
playmove DR3
simulate on node K26610
terminal node reward -10
  > backpropagate K26610
  > backpropagate K25795
traverse K25795
expanding K25795
playmove DR2
simulate on node K21664
terminal node reward -10
  > backpropagate K21664
  > backpropagate K25795
traverse K25795
expanding K25795
playmove DR7
simulate on node K62559
terminal node reward -10
  > backpropagate K62559
  > backpropagate K25795
traverse K25795
expanding K25795
playmove 27
simulate on node K34533
terminal node reward -10
  > backpropagate K34533
  > backprop