In [1]:
import math
import random
from decimal import Decimal
from functools import lru_cache
from itertools import count
from collections import defaultdict
from gamestate import (
    KlonState, state_is_win, 
    play_move, get_legal_moves,
    to_pretty_string
)
from benchmarking import random_state
from policies import yan_et_al

In [2]:
CYCLE_REWARD = -10
DEAD_END_REWARD = -20
WIN_REWARD = 100
TIMEOUT_REWARD = -1
NON_TERMINAL = 0


class MCTS:
    
    def __init__(self):
        self.children = dict()
        self.ancestors = dict()
        self.Q = defaultdict(int)
        self.N = defaultdict(int)
        
    def search(tree, node, budget=10):
        for _ in range(budget):
            leaf = tree.tree_policy(node)
            reward = tree.rollout_policy(leaf)
            tree.backup(leaf, reward)
        return tree.uct_best_child(node, exploration=0)
    
    ### tree policy
    
    def tree_policy(tree, node):
        while not node.is_terminal:
            if tree._not_fully_expanded(node):
                result = tree._expand(node)
                if result is not False:
                    return result
                # result=False => terminal node
            else: # fully expanded
                node = tree.uct_best_child(node, exploration=1)
        return node
    
    def _not_fully_expanded(tree, node):
        return node.has_available_child_moves
    
    def _expand(tree, node):
        while node.has_available_child_moves:
            untried = random.choice(list(node.untried_actions))
            child = node.play_move(untried)
            # mark this move as explored
            node.explored_moves.add(untried)
            if tree._contains_state(child):
                # disable this move because it creates a visited child
                node.disabled_moves.add(untried)
            else:
                # valid new child, add to the tree
                if node not in tree.children:
                    tree.children[node] = set()
                tree.children[node].add(child)
                return child
        return False

    def _contains_state(tree, state):
        nodes = set(tree.children.keys())
        children = {c for cs in tree.children.values() for c in cs}
        all_nodes = nodes.union(children)
        return state in all_nodes
    
    #####
    
    def rollout_policy(tree, node, max_depth=5_000):
        visited = set() # states visited during this rollout
        for _ in range(max_depth):
            if node.is_terminal:
                return node.reward()
            visited.add(node)
            actions = node.untried_actions
            a = random.choice(list(actions))
            child = node.play_move(a)
            node.explored_moves.add(a)
            if child in visited or tree._contains_state(child):
                node.disabled_moves.add(a)
            else:
                # since we don't add the child to tree.children
                # it is ephemeral during rollout
                node = child
        return TIMEOUT_REWARD
    
    def backup(tree, node, reward):
        while node is not None:
            tree.Q[node] += reward
            tree.N[node] += 1
            node = node.parent
    
    def uct_best_child(tree, node, exploration=0):
        def uct(c):
            """ see Algorithm 2: UCT in Browne et al (2012) """
            if tree.N[c] == 0:
                return float('-inf') # avoid unvisited nodes
            exploit = tree.Q[c] / tree.N[c]
            explore = math.sqrt(2*math.log(tree.N[node]) / tree.N[c])
            return exploit + exploration*explore
        children = tree.children[node]
        uct_vals = [uct(x) for x in children]
        return max(children, key=uct)
    

class KlonNode(KlonState):
    def __new__(cls, *klonstate, parent=None):
        self = super(KlonNode, cls).__new__(cls, *klonstate)
        
        self.parent = parent
        
        if parent is not None:
            self.ancestors = parent.ancestors.union(frozenset([parent]))
        else:
            self.ancestors = frozenset()

        # these are all the standard moves available
        self.all_legal_moves = get_legal_moves(self)
        # these moves would create repeated child states
        self.disabled_moves = set()
        # these moves were taken to create children
        self.explored_moves = set()
        return self
    
    @property
    def is_terminal(node):
        if node.is_dead_end:
            return True
        if node.is_cycle:
            return True
        if node.is_win:
            return True
        return False
    
    def reward(state):
        if state.is_dead_end:
            return DEAD_END_REWARD
        elif state.is_cycle:
            return CYCLE_REWARD
        elif state.is_win:
            return WIN_REWARD
        return NON_TERMINAL
    
    @property
    def is_win(self):
        return state_is_win(self)
    
    @property
    def is_cycle(self):
        """ this should probably never happen! """
        return self in self.ancestors
    
    @property
    def is_dead_end(node):
        return not node.has_available_child_moves
    
    def play_move(self, move):
        child = play_move(self, move)
        return KlonNode(*child, parent=self)
    
    @property
    def untried_actions(self):
        return self.all_legal_moves - self.disabled_moves - self.explored_moves
    
    @property
    def has_available_child_moves(self):
        return len(self.untried_actions) > 0
    
    def __repr__(self):
        return f"K{hash(self)%999999}"
    
    def to_pretty_string(self):
        return to_pretty_string(self)

In [9]:
# %pdb
tree = MCTS()
random.seed(0)
state = random_state()
node = KlonNode(*state)
visited = set([node])
print(node)
# print(node.to_pretty_string())

for i in range(29): #count(1):
    print(f'{i:3}: search on {node}')
    node = tree.search(node)
#     print(f' got node {node}')
#     print(node.to_pretty_string())
    if node in visited:
        print(f"WTF CYCLE {node} in visited")
    visited.add(node)
#     if i % 10 == 1:
#         print(node.to_pretty_string())

Automatic pdb calling has been turned ON
K708477
  0: search on K708477
  1: search on K871058
  2: search on K521982
  3: search on K238575
  4: search on K329515
  5: search on K371434
  6: search on K821241
  7: search on K817369
  8: search on K111309
  9: search on K807371
 10: search on K977766
 11: search on K954617
 12: search on K33266
 13: search on K326217
 14: search on K91011
 15: search on K399164
 16: search on K827536
 17: search on K828624
 18: search on K748028
 19: search on K159407
 20: search on K252552
 21: search on K305976
 22: search on K739848
 23: search on K91093
 24: search on K805577
 25: search on K830853
 26: search on K690969
 27: search on K12857
 28: search on K17363


KeyError: K17363

> [0;32m<ipython-input-2-8bfc3c5f6c79>[0m(96)[0;36muct_best_child[0;34m()[0m
[0;32m     94 [0;31m            [0mexplore[0m [0;34m=[0m [0mmath[0m[0;34m.[0m[0msqrt[0m[0;34m([0m[0;36m2[0m[0;34m*[0m[0mmath[0m[0;34m.[0m[0mlog[0m[0;34m([0m[0mtree[0m[0;34m.[0m[0mN[0m[0;34m[[0m[0mnode[0m[0;34m][0m[0;34m)[0m [0;34m/[0m [0mtree[0m[0;34m.[0m[0mN[0m[0;34m[[0m[0mc[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     95 [0;31m            [0;32mreturn[0m [0mexploit[0m [0;34m+[0m [0mexploration[0m[0;34m*[0m[0mexplore[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 96 [0;31m        [0mchildren[0m [0;34m=[0m [0mtree[0m[0;34m.[0m[0mchildren[0m[0;34m[[0m[0mnode[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     97 [0;31m        [0muct_vals[0m [0;34m=[0m [0;34m[[0m[0muct[0m[0;34m([0m[0mx[0m[0;34m)[0m [0;32mfor[0m [0mx[0m [0;32min[0m [0mchildren[0m[0;34m][0m[0;34m[0m[0;34m[0m