In [11]:
import chess
import gym_chess
import gym
import random
from IPython.display import display, clear_output
import time
import numpy as np

C = np.sqrt(2)


In [12]:
select_t = 0
expand_t = 0
simulate_t = 0
backprop_t = 0
total_t = 0

In [30]:
class Node:
    def __init__(self, board, parent=None):
        self.state = board
        self.utility = 0
        self.n_playthrough = 0
        self.parent = parent
        self.children = list()
        
    def UCB1(self):
        return self.utility/self.n_playthrough + C*np.sqrt(np.log(self.parent.n_playthrough)/self.n_playthrough)
    
    def peek(self):
        return self.state.peek()


In [31]:
class MonteCarloAgent:
    # main function for the Monte Carlo Tree Search
    def monte_carlo_tree_search(self, state):
#         global select_t
#         global expand_t
#         global simulate_t
#         global backprop_t
#         global total_t
        
        tree = Node(state)
        for i in range(2000):
#             start = time.perf_counter()
            leaf = self.select(tree)
#             select_t += time.perf_counter() - start
            
#             start = time.perf_counter()
            child = self.expand(leaf)
#             expand_t += time.perf_counter() - start
            
#             start = time.perf_counter()
            result = self.simulate(child)
#             simulate_t += time.perf_counter() - start
            
#             start = time.perf_counter()
            if not result == state.turn:
                self.back_prop(0, child)
            else:
                self.back_prop(1, child)
#             backprop_t += time.perf_counter() - start
            
        return max(tree.children, key=lambda k: k.n_playthrough).peek()
    
    def select(self, tree):
        while tree.children:
            tree = self.select(max(tree.children, key=lambda k: k.UCB1()))
        
        return tree
    
    def expand(self, leaf):
        assert not leaf.children, 'This guy is not a leaf'
        _board = leaf.state.copy()
        if _board.legal_moves:
            move = random.choice(list(_board.legal_moves))
            _board.push(move) 
            leaf.children.append(Node(_board, leaf))
            return leaf.children[0]        
        # if the current node is checkmated, i.e. no legal moves
        # then return the current node 
        return leaf
    
    def simulate(self, child):
        _board = child.state.copy()
        while not _board.is_game_over():
            move = random.choice(list(_board.legal_moves))
            _board.push(move)
            
            # simulate would take forever without this if statement
            # basically the game would go on forever with 2 kings moving around without this check
            if not _board.has_insufficient_material(_board.turn) or _board.can_claim_fifty_moves():
                return not _board.turn
            
        return _board.outcome().winner
        
    def back_prop(self, result, child):
        while child:
            child.n_playthrough += 1
            child.utility += result
            child = child.parent

In [32]:
import cProfile
import pstats

env = gym.make('Chess-v0')
env.reset()
print(env.render())
done = False
agent = MonteCarloAgent()

select_t = 0
expand_t = 0
simulate_t = 0
backprop_t = 0
total_t = 0

# while not done:
clear_output(wait=True)
print(env.render())
print(f'{env._observation().turn}')
if env._observation().turn:
#         start = time.perf_counter()
    with cProfile.Profile() as pr:
        action = agent.monte_carlo_tree_search(env._observation())
#         total_t += time.perf_counter() - start
else:
    action = random.choice(env.legal_moves)
observation, reward, done, _ = env.step(action)

stats = pstats.Stats(pr)
stats.sort_stats(pstats.SortKey.TIME)
stats.print_stats()

# print(reward)
env.close()
    

♜ ♞ ♝ ♛ ♚ ♝ ♞ ♜
♟ ♟ ♟ ♟ ♟ ♟ ♟ ♟
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
♙ ♙ ♙ ♙ ♙ ♙ ♙ ♙
♖ ♘ ♗ ♕ ♔ ♗ ♘ ♖
True
         64620562 function calls (62621562 primitive calls) in 28.941 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
  1999000    6.024    0.000    6.024    0.000 C:\Users\engli\AppData\Local\Temp/ipykernel_21408/2438834728.py:9(UCB1)
  4000000    5.716    0.000   17.523    0.000 C:\Users\engli\AppData\Local\Programs\Python\Python38\lib\copy.py:66(copy)
  4000000    3.144    0.000    7.712    0.000 C:\Users\engli\AppData\Local\Programs\Python\Python38\lib\copy.py:258(_reconstruct)
  4002361    2.172    0.000    2.172    0.000 {built-in method __new__ of type object at 0x00007FF8EBD1A810}
  8000000    1.668    0.000    1.668    0.000 {built-in method builtins.getattr}
2001000/2000    1.420    0.000    9.083    0.005 C:\Users\engli\AppData\Local\Temp/ipykernel_21408/3232594914.py:33(select)
  4

In [22]:
f'Time to select: {select_t}'

'Time to select: 0'

In [23]:
f'Time to expand: {expand_t}'

'Time to expand: 0'

In [24]:
f'Time to simulate: {simulate_t}'

'Time to simulate: 0'

In [25]:
f'Time to backprop: {backprop_t}'

'Time to backprop: 0'

In [26]:
f'Total time: {total_t}'

'Total time: 0'