In [6]:
import chess
import gym_chess
import gym
import random
from IPython.display import display, clear_output
import time
import numpy as np

C = np.sqrt(2)


In [7]:
select_t = 0
expand_t = 0
simulate_t = 0
backprop_t = 0
total_t = 0

In [8]:
class Node:
    def __init__(self, board, parent=None):
        self.state = board
        self.utility = 0
        self.n_playthrough = 0
        self.parent = parent
        self.children = list()
    
    def peek(self):
        return self.state.peek()


In [14]:

class MonteCarloAgent:
    from functools import lru_cache
    def __init__(self, state):
        self.tree = None
    
    # main function for the Monte Carlo Tree Search
    def monte_carlo_tree_search(self, state):
#         global select_t
#         global expand_t
#         global simulate_t
#         global backprop_t
#         global total_t
        
        self.tree = Node(state)
        for i in range(1000):
#             start = time.perf_counter()
            leaf = self.select(tree)
#             select_t += time.perf_counter() - start
            
#             start = time.perf_counter()
            child = self.expand(leaf)
#             expand_t += time.perf_counter() - start
            
#             start = time.perf_counter()
            result = self.simulate(child)
#             simulate_t += time.perf_counter() - start
            
#             start = time.perf_counter()
            if not result == state.turn:
                self.back_prop(0, child)
            else:
                self.back_prop(1, child)
#             backprop_t += time.perf_counter() - start
            
        return max(tree.children, key=lambda k: k.n_playthrough).peek()
    
    def select(self, tree):
        while tree.children:
            tree = max(tree.children, key=lambda k: self.UCB1(k.utility, k.n_playthrough, k.parent.n_playthrough))
        
        return tree
    
    @lru_cache(maxsize=10000)
    def UCB1(self, U_n, N_n, P_n):
        return U_n/N_n + C*np.sqrt(np.log(P_n)/N_n)
    
    def expand(self, leaf):
        assert not leaf.children, 'This guy is not a leaf'
        _board = leaf.state.copy()
        if _board.legal_moves:
            move = random.choice(list(_board.legal_moves))
            _board.push(move) 
            leaf.children.append(Node(_board, leaf))
            return leaf.children[0]        
        # if the current node is checkmated, i.e. no legal moves
        # then return the current node 
        return leaf
    
    def simulate(self, child):
        _board = child.state.copy(stack=False)
        while not _board.is_game_over():
            move = random.choice(list(_board.legal_moves))
            _board.push(move)
            
            # simulate would take forever without this if statement
            # basically the game would go on forever with 2 kings moving around without this check
            if not _board.has_insufficient_material(_board.turn) or _board.can_claim_fifty_moves():
                return not _board.turn
            
        return _board.outcome().winner
        
    def back_prop(self, result, child):
        while child:
            child.n_playthrough += 1
            child.utility += result
            child = child.parent

In [15]:
import cProfile
import pstats

env = gym.make('Chess-v0')
env.reset()
print(env.render())
done = False
agent = MonteCarloAgent()

select_t = 0
expand_t = 0
simulate_t = 0
backprop_t = 0
total_t = 0

with cProfile.Profile() as pr:
    while not done:
        clear_output(wait=True)
        print(env.render())
        print(f'{env._observation().turn}')
        if env._observation().turn:
        #         start = time.perf_counter()
                action = agent.monte_carlo_tree_search(env._observation())
        #         total_t += time.perf_counter() - start
        else:
            action = random.choice(env.legal_moves)
        observation, reward, done, _ = env.step(action)

stats = pstats.Stats(pr)
stats.sort_stats(pstats.SortKey.TIME)
stats.print_stats()

print(reward)
env.close()
    

⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ♖ ♔
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ♚ ⭘ ⭘ ⭘ ⭘
False
         2782081674 function calls in 1146.035 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
175850982  270.581    0.000  856.606    0.000 C:\Users\engli\anaconda3\lib\copy.py:66(copy)
175850982  164.552    0.000  423.847    0.000 C:\Users\engli\anaconda3\lib\copy.py:258(_reconstruct)
176130979  152.148    0.000  152.148    0.000 {built-in method __new__ of type object at 0x00007FF8EBD0B810}
351701964   66.283    0.000   66.283    0.000 {built-in method builtins.getattr}
115651297   48.497    0.000   87.680    0.000 <ipython-input-14-7c8b5f54c8a9>:39(<lambda>)
175850982   47.232    0.000  199.328    0.000 C:\Users\engli\anaconda3\lib\copyreg.py:90(__newobj__)
115652032   47.185    0.000  134.865    0.000 {built-in method builtins.max}
175850982   46.222    0.000   46.222    0.000 {method '_

In [None]:
f'Time to select: {select_t}'

In [None]:
f'Time to expand: {expand_t}'

In [None]:
f'Time to simulate: {simulate_t}'

In [None]:
f'Time to backprop: {backprop_t}'

In [None]:
f'Total time: {total_t}'

In [None]:
stats = pstats.Stats(pr)
stats.sort_stats(pstats.SortKey.TIME)
stats.print_stats()