In [5]:
import chess
import gym_chess
import gym
import random
from IPython.display import display, clear_output
import time
import numpy as np

C = np.sqrt(2)

In [13]:
class Node:
    def __init__(self, board, parent=None):
        self.state = board
        self.utility = 0
        self.n_playthrough = 0
        self.parent = parent
        self.children = list()
        
    def UCB1(self):
        if self.n_playthrough == 0:
            return 0
        return self.utility/self.n_playthrough + C + np.sqrt(np.log(parent.n_playthrough)/self.n_playthrough)
    

In [19]:
class MonteCarloAgent:
    
    # main function for the Monte Carlo Tree Search
    def monte_carlo_tree_search(self, state):
        tree = Node(state)
        for i in range(10):
            leaf = self.select(tree)
            child = self.expand(leaf)
            result = self.simulate(child)
            
            if result.winner is state.turn:                
                self.back_prop(1, child)
            else:
                self.back_prop(0, child)
                
        return self.select(tree).state.peek()
    
    def select(self, tree):
        if tree.children:
            return max(tree.children, key=lambda k: k.UCB1())
        else:
            return self.expand(tree)
    
    def expand(self, leaf):
        for move in leaf.state.legal_moves:
            leaf.state.push(move)
            leaf.children.append(Node(leaf.state.copy()))
            leaf.state.pop()
        return random.choice(leaf.children)
    
    def simulate(self, child):
        _board = child.state.copy()
        while not _board.is_game_over():
            move = random.choice(list(_board.legal_moves))
            _board.push(move)
        return _board.outcome()
        
    def back_prop(self, result, child):
        if not child:
            return
        child.n_playthrough += 1
        child.utility += result
        self.back_prop(result, child.parent)

In [25]:
env = gym.make('Chess-v0')
env.reset()
print(env.render())
done = False
agent = MonteCarloAgent()

while not done:
    clear_output(wait=True)
    action = agent.monte_carlo_tree_search(env._observation())
    observation, reward, done, _ = env.step(action)   
    print(env.render())
    

♜ ♞ ♝ ♛ ♚ ♘ ⭘ ♜
♟ ♟ ♟ ♟ ♟ ♟ ♟ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ♞
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
♙ ♙ ♙ ♙ ♙ ♙ ♙ ♙
♖ ♘ ♗ ♕ ♔ ♗ ⭘ ♖
