In [1]:
import math
import random
from decimal import Decimal
from functools import lru_cache
from itertools import count
from collections import defaultdict
from gamestate import (
    KlonState, state_is_win, 
    play_move, get_legal_moves,
    to_pretty_string
)
from benchmarking import random_state
from policies import yan_et_al

In [2]:
CYCLE_REWARD = -10
DEAD_END_REWARD = -20
WIN_REWARD = 100
TIMEOUT_REWARD = -1
NON_TERMINAL = 0


class MCTS:
    
    def __init__(self):
        self.children = dict()
        self.ancestors = dict()
        self.Q = defaultdict(int)
        self.N = defaultdict(int)
        
    def search(tree, node, budget=10):
        for _ in range(budget):
#             if node.is_terminal:
                # keep this inside the for-loop
                # because node may be marked terminal after just one rollout
#                 return False
            leaf = tree.tree_policy(node)
            reward = tree.rollout_policy(leaf)
            tree.backup(leaf, reward)
#         return
        if tree.children.get(node):
            return tree.uct_best_child(node, exploration=0)
        else:
            raise Exception(f'no children for node {node}')
    
    ### tree policy
    
    def tree_policy(tree, node):
        while not node.is_terminal:
            if tree._not_fully_expanded(node):
                result = tree._expand(node)
                if result is not False:
                    return result
                # result=False => terminal node
            else: # fully expanded
                node = tree.uct_best_child(node, exploration=1)
        return node
    
    def _not_fully_expanded(tree, node):
        return node.has_untried_actions
    
    def _expand(tree, node):
        while node.has_untried_actions:
            untried = random.choice(list(node.untried_actions))
            child = node.play_move(untried)
            # mark this move as explored
            node.explored_moves.add(untried)
            if tree._contains_state(child):
                # disable this move because it creates a visited child
                node.disabled_moves.add(untried)
            else:
                # valid new child, add to the tree
                if node not in tree.children:
                    tree.children[node] = set()
                tree.children[node].add(child)
                node.child_moves.add(untried)
                return child
        return False

    def _contains_state(tree, state):
        nodes = set(tree.children.keys())
        children = {c for cs in tree.children.values() for c in cs}
        all_nodes = nodes.union(children)
        return state in all_nodes
    
    #####
    
    def rollout_policy(tree, node, max_depth=5_000):
        visited = set() # states visited during this rollout
        rollout_disabled_moves = set()
        for _ in range(max_depth):
            if node.is_terminal:
                return node.reward()
            visited.add(node)
            actions = node.untried_actions - rollout_disabled_moves
            if len(actions) == 0:
                return DEAD_END_REWARD
            a = random.choice(list(actions))
            child = node.play_move(a)
            if child in visited or tree._contains_state(child):
                rollout_disabled_moves.add(a)
            else:
                # since we don't add the child to tree.children
                # it is ephemeral during rollout
                node = child
                rollout_disabled_moves = set()
        return TIMEOUT_REWARD
    
    def backup(tree, node, reward):
        while node is not None:
            tree.Q[node] += reward
            tree.N[node] += 1
            node = node.parent
    
    def uct_best_child(tree, node, exploration=0):
        def uct(c):
            """ see Algorithm 2: UCT in Browne et al (2012) """
            if tree.N[c] == 0:
                return float('-inf') # avoid unvisited nodes
            exploit = tree.Q[c] / tree.N[c]
            explore = math.sqrt(2*math.log(tree.N[node]) / tree.N[c])
            return exploit + exploration*explore
        children = tree.children[node]
        uct_vals = [uct(x) for x in children]
        return max(children, key=uct)
    

class KlonNode(KlonState):
    def __new__(cls, *klonstate, parent=None, action=None):
        self = super(KlonNode, cls).__new__(cls, *klonstate)
        
        self.parent = parent
        self.action = action
        
        if parent is not None:
            self.ancestors = parent.ancestors.union(frozenset([parent]))
        else:
            self.ancestors = frozenset()

        # these are all the standard moves available
        self.all_legal_moves = get_legal_moves(self)
        # these moves would create repeated child states
        self.disabled_moves = set()
        # these moves were taken to create children
        self.child_moves = set()
        # these moves were tried
        self.explored_moves = set()
        return self
    
    @property
    def is_terminal(node):
        if node.is_dead_end:
            return True
        if node.is_cycle:
            return True
        if node.is_win:
            return True
        return False
    
    def reward(state):
        if state.is_dead_end:
            return DEAD_END_REWARD
        elif state.is_cycle:
            return CYCLE_REWARD
        elif state.is_win:
            return WIN_REWARD
        return NON_TERMINAL
    
    @property
    def is_win(self):
        return state_is_win(self)
    
    @property
    def is_cycle(self):
        """ this should probably never happen! """
        return self in self.ancestors
    
    @property
    def is_dead_end(self):
        # all legal moves have been disabled because they're in the tree already
        # this would be a dead end
        available = self.all_legal_moves - self.disabled_moves
        return len(available) == 0
    
    def play_move(self, move):
        child = play_move(self, move)
        return KlonNode(*child, parent=self, action=move)
    
    @property
    def has_untried_actions(self):
        return len(self.untried_actions) > 0

    @property
    def untried_actions(self):
        # there have been moves that haven't been tried yet
        return self.all_legal_moves - self.explored_moves
    
    def __repr__(self):
        return f"K{hash(self)%999999}"
    
    def to_pretty_string(self):
        return to_pretty_string(self)

In [3]:
# %pdb
tree = MCTS()
random.seed(0)
state = random_state()
node = KlonNode(*state)
path = [node]
# print(node.to_pretty_string())

for i in count(1):
    if node.is_terminal:
        print(f"{node} is TERMINAL")
        if node.is_win:
            print("WIN!!!")
            print(path)
            break
        else:
            path.pop() # discard terminal node
            node = path.pop()
    print(f'{i:3}: search on {node}')
    node = tree.search(node)
    if node in path:
        print(f"CYCLE {node} in visited")
    path.append(node)

print(node)
print(node.to_pretty_string())

  1: search on K581459
  2: search on K947409
  3: search on K443855
  4: search on K465144
  5: search on K423075
  6: search on K681356
  7: search on K106946
  8: search on K811683
  9: search on K769992
 10: search on K606670
 11: search on K297615
 12: search on K707496
 13: search on K648793
 14: search on K171562
 15: search on K47183
 16: search on K564423
 17: search on K274836
 18: search on K296128
 19: search on K616299
 20: search on K708939
K685468 is TERMINAL
 21: search on K708939
K685468 is TERMINAL
 22: search on K616299
 23: search on K708939
K685468 is TERMINAL
 24: search on K708939
K685468 is TERMINAL
 25: search on K296128
 26: search on K616299
 27: search on K708939
K685468 is TERMINAL
 28: search on K708939
K685468 is TERMINAL
 29: search on K616299
 30: search on K708939
K685468 is TERMINAL
 31: search on K708939
K685468 is TERMINAL
 32: search on K274836
 33: search on K296128
 34: search on K616299
 35: search on K708939
K685468 is TERMINAL
 36: search on K

KeyboardInterrupt: 