In [1]:
import chess
import gym_chess
import gym
import random
from IPython.display import display, clear_output
import time
import numpy as np

C = np.sqrt(2)


In [2]:
select_t = 0
expand_t = 0
simulate_t = 0
backprop_t = 0
total_t = 0

In [3]:
class Node:
    def __init__(self, board, parent=None):
        self.state = board
        self.utility = 0
        self.n_playthrough = 0
        self.parent = parent
        self.children = list()
        
    def UCB1(self):
        return self.utility/self.n_playthrough + C*np.sqrt(np.log(self.parent.n_playthrough)/self.n_playthrough)
    
    def peek(self):
        return self.state.peek()


In [4]:
class MonteCarloAgent:
    # main function for the Monte Carlo Tree Search
    def monte_carlo_tree_search(self, state):
        global select_t
        global expand_t
        global simulate_t
        global backprop_t
        global total_t
        
        tree = Node(state)
        for i in range(2000):
            start = time.perf_counter()
            leaf = self.select(tree)
            end = time.perf_counter()
            select_t += end - start
            
            start = time.perf_counter()
            child = self.expand(leaf)
            end = time.perf_counter()
            expand_t += end - start
            
            start = time.perf_counter()
            result = self.simulate(child)
            end = time.perf_counter()
            simulate_t += end - start
            
            start = time.perf_counter()
            if not result == state.turn:
                self.back_prop(0, child)
            else:
                self.back_prop(1, child)
            end = time.perf_counter()
            backprop_t += end - start
            
        best_child = max(tree.children, key=lambda k: k.n_playthrough)
        return best_child.peek(), best_child.UCB1()
    
    def select(self, tree):
        while tree.children:
            tree = self.select(max(tree.children, key=lambda k: k.UCB1()))
        
        return tree
    
    def expand(self, leaf):
        assert not leaf.children, 'This guy is not a leaf'
        _board = leaf.state.copy()
        if _board.legal_moves:
            move = random.choice(list(_board.legal_moves))
            _board.push(move) 
            leaf.children.append(Node(_board, leaf))
            return leaf.children[0]        
        # if the current node is checkmated, i.e. no legal moves
        # then return the current node 
        return leaf
    
    def simulate(self, child):
        _board = child.state.copy()
        while not _board.is_game_over():
            move = random.choice(list(_board.legal_moves))
            _board.push(move)
            
            # simulate would take forever without this if statement
            # basically the game would go on forever with 2 kings moving around without this check
            if not _board.has_insufficient_material(_board.turn) or _board.can_claim_fifty_moves():
                return not _board.turn
            
        return _board.outcome().winner
        
    def back_prop(self, result, child):
        while child:
            child.n_playthrough += 1
            child.utility += result
            child = child.parent

In [5]:
env = gym.make('Chess-v0')
env.reset()
print(env.render())
done = False
agent = MonteCarloAgent()

select_t = 0
expand_t = 0
simulate_t = 0
backprop_t = 0
total_t = 0

while not done:
    clear_output(wait=True)
    print(env.render())
    print(f'{env._observation().turn}')
    if env._observation().turn:
        start = time.perf_counter()
        action, score = agent.monte_carlo_tree_search(env._observation())
        end = time.perf_counter()
        total_t += end - start
    else:
        action = random.choice(env.legal_moves)
    observation, reward, done, _ = env.step(action)
    
print(reward)
env.close()
    

⭘ ⭘ ♕ ⭘ ⭘ ⭘ ♝ ⭘
♖ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ♙ ⭘ ♘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ♚ ⭘ ⭘ ⭘ ⭘
⭘ ♜ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ♔ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
True


KeyboardInterrupt: 

In [6]:
f'Time to select: {select_t}'

'Time to select: 683.1487228999497'

In [7]:
f'Time to expand: {expand_t}'

'Time to expand: 450.38253839997424'

In [8]:
f'Time to simulate: {simulate_t}'

'Time to simulate: 475.9871828999956'

In [9]:
f'Time to backprop: {backprop_t}'

'Time to backprop: 52.94526730004248'

In [10]:
f'Total time: {total_t}'

'Total time: 1645.0892389000007'