In [2]:
import math
from copy import deepcopy
import random
import chess
from ChessWrapper import ChessWrapper
import time
import signal
from treelib import Node, Tree
import chess.pgn
import pandas as pd

In [3]:
!pip install treelib


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m23.3.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [4]:
from stockfish import Stockfish

In [5]:
%load_ext autoreload
%autoreload 2

In [6]:
class Node:
    """A single node of the MCTS search tree.

    Each node records the move that led to it, a link to its parent, its
    children (keyed by move) and the two statistics used by UCT:

    * ``Q`` - sum of the rewards backed up through this node
    * ``N`` - number of rollouts that passed through this node

    ``properties`` is a snapshot of those statistics kept for treelib, which
    can only display one data attribute per node; refresh it with
    ``update_properties()``.
    """

    def __init__(self, move, parent, tag=0, explore=math.sqrt(2)):
        """Create an unvisited node.

        move    -- move played from the parent state to reach this state
                   (None for a root node)
        parent  -- parent Node, or None for a root node
        tag     -- unique identifying number, used as the node id in treelib
        explore -- exploration factor of the UCT formula
        """
        self.move = move # move that we played from the parent state to get to this state
        self.parent = parent # parent Node

        self.Q = 0 # accumulated reward
        self.N = 0 # number of rollouts
        self.children = {} # children nodes, keyed by move
        self.tag = tag # unique identifying number for each node. Used to construct trees in treelib
        self.explore = explore

        self.properties = {
            'move' : str(self.move),
            'Q' : self.Q,
            'N' : self.N,
            # computed rather than hard-coded to inf so the snapshot agrees
            # with value() even when explore == 0
            'Value' : self.value()
        }

    def value(self):
        """Return the UCT value of this node.

        * unvisited node: ``inf`` so that selection tries it first, or 0 when
          exploration is disabled (``explore == 0``)
        * root node (no parent): plain average reward ``Q / N``
        * otherwise: ``Q / N + explore * sqrt(ln(parent.N) / N)``
        """
        if self.N == 0:
            # if for some reason we do not want any exploration
            if self.explore == 0:
                return 0
            # in selection phase, we ensure that MCTS at least gives each child node a chance
            return float('inf')

        mean_reward = self.Q / self.N

        # filler value, in practice .value() is never used for the root node
        # only time .value() is called for root node is during testing, when generating trees for Treelib
        if self.parent is None:
            return mean_reward

        # ln(0) is undefined: a parent without visits contributes no
        # exploration bonus instead of raising a math domain error
        if self.parent.N <= 0:
            return mean_reward

        # UCB applied to trees formula
        return mean_reward + self.explore * math.sqrt(math.log(self.parent.N) / self.N)

    def add_children(self, children):
        """Register every Node in the list `children` under its move.

        self.children is a dictionary mapping move -> resulting child, so a
        later child with the same move replaces an earlier one.
        """
        for child in children:
            self.children[child.move] = child

    def print_children(self):
        """Print one "move: [node summary]" string per child, as a list."""
        children_str = [
            str(move) + ': [' + str(child) + ']'
            for move, child in self.children.items()
        ]
        print(children_str)

    def update_properties(self):
        """Refresh the diagnostic snapshot with the current Q, N and value().

        Used for diagnostic purposes: treelib can only print one data property
        at a time, so `properties` aggregates all of the useful information.
        """
        self.properties['Q'] = self.Q
        self.properties['N'] = self.N
        self.properties['Value'] = self.value()

    def print(self):
        """Print diagnostic information (move, Q, N, value) about this node."""
        print("""
        move: {move},
        Q: {Q},
        N: {N},
        value: {V}
        """.format(move=str(self.move), Q=self.Q, N=self.N, V=self.value()))

    def __str__(self):
        return "move: {move}, Q: {Q}, N: {N}".format(move=str(self.move), Q=self.Q, N=self.N)

In [7]:
# Smoke test for Node: hand-build a root (n1) with two visited children so
# the cells below can inspect value(), print_children() and properties.
n1 = Node(None, None)
n1.Q, n1.N = 5, 10

n2 = Node(chess.Move.from_uci('e2e4'), n1)
n2.Q, n2.N = 1, 2

n3 = Node(chess.Move.from_uci('c2c4'), n1)
n3.Q, n3.N = 0, 1

n1.add_children([n2, n3])

In [8]:
# Diagnostics for n2; its value is Q/N plus the exploration bonus derived from parent n1's N.
n2.print()


        move: e2e4,
        Q: 1,
        N: 2,
        value: 2.0174271293851467
        


In [9]:
# Diagnostics for n3; it has fewer rollouts than n2, so its exploration bonus is larger.
n3.print()


        move: c2c4,
        Q: 0,
        N: 1,
        value: 2.1459660262893476
        


In [10]:
# List the root's children as "move: [node summary]" strings.
n1.print_children()

['e2e4: [move: e2e4, Q: 1, N: 2]', 'c2c4: [move: c2c4, Q: 0, N: 1]']


In [11]:
# Refresh the cached diagnostics, then display them. n1 has no parent, so
# value() returns the plain Q/N ratio.
n1.update_properties()
n1.properties

{'move': 'None', 'Q': 5, 'N': 10, 'Value': 0.5}

In [12]:
class MCTS:
    """Monte Carlo Tree Search driver.

    The search keeps a tree of `Node` objects rooted at `root_node`, which
    corresponds to the game position `root_state`. One search iteration is
    select -> expand -> rollout -> backprop.

    The state object must provide the methods this class calls on it:
    push(), is_game_over(), eval_legal_moves(), get_legal_moves(), eval(),
    outcome() and get_turn() (plus fen() and legal_moves when expand() is
    called with use_sf=True).
    """

    def __init__(self, state, max_depth, explore=math.sqrt(2)):
        """Start a search tree at `state`.

        state     -- game position; it is deep-copied, the caller's object is
                     never mutated by the search
        max_depth -- maximum number of moves played in a depth-limited rollout
        explore   -- exploration factor handed to every Node (UCT constant)
        """
        self.explore = explore

        self.root_node = Node(None, None, explore=self.explore)
        self.root_state = deepcopy(state)
        self.max_depth = max_depth

        self.num_nodes = 0 # use current number of nodes as a tag for the nodes that go into the treelib Tree
        self.Tree = Tree() # to show the tree using treelib
        self.Tree.create_node(str(self.root_node), self.num_nodes, data=self.root_node)

        # statistics
        self.total_rollouts = 0
        self.total_time = 0
        self.total_depth = 0

    def select(self):
        """Walk down from the root, always following a best-valued child.

        Stops at the first node that has no children or has never been
        rolled out. Children whose values are within 0.001 of the best value
        are treated as tied and one of them is picked at random.

        Returns (node, state) where state is a copy of the root state with
        the moves along the chosen path applied.
        """
        curr_node = self.root_node
        curr_state = deepcopy(self.root_state)

        # continue until we hit a leaf
        while len(curr_node.children) != 0:
            scored = [(child.value(), child) for child in curr_node.children.values()]

            # values can be negative (rewards are signed), so take the true
            # maximum instead of starting from a fixed floor
            max_val = max(val for val, _ in scored)

            # `val == max_val` is needed for unexplored children: their value
            # is inf and inf - inf is nan, which never passes the tolerance test
            max_val_children = [
                child for val, child in scored
                if val == max_val or abs(val - max_val) <= 0.001
            ]

            # defensive fallback (e.g. nan values): consider every child
            if len(max_val_children) == 0:
                max_val_children = [child for _, child in scored]

            curr_node = random.choice(max_val_children)
            curr_state.push(curr_node.move)

            # unexplored, has never done a rollout
            if curr_node.N == 0:
                return curr_node, curr_state

        return curr_node, curr_state

    def expand(self, leaf_node, leaf_state, use_sf=False):
        """Add a child to `leaf_node` for each candidate move and pick one.

        The candidate moves and their weights come from
        leaf_state.eval_legal_moves(); the child to simulate is drawn at
        random using those weights. `leaf_state` is advanced in place by the
        chosen child's move.

        use_sf -- FOR TESTING PURPOSES ONLY: restrict the candidates to
                  Stockfish's best and worst move (picked with weights
                  .75 / .25).

        Returns (child_node, leaf_state), or (None, None) when the game is
        already over at `leaf_state`.
        """
        if leaf_state.is_game_over(): # TODO: is_game_over must be part of the interface
            return None, None

        # informed view of the legal moves: one weight per move
        prob, next_legal_moves = leaf_state.eval_legal_moves()

        # FOR TESTING PURPOSES ONLY: get stockfish's top move and worst move
        def get_sf_best_worst(leaf_state):
            # NOTE(review): hard-coded local Stockfish binary, and a new engine
            # is created on every call - acceptable for testing only
            sf = Stockfish('/opt/homebrew/Cellar/stockfish/16/bin/stockfish')
            sf.set_fen_position(leaf_state.fen())
            lm = list(leaf_state.legal_moves)
            sf_top = sf.get_top_moves(len(lm))
            # last entry = lowest ranked move among those returned
            return [chess.Move.from_uci(sf_top[0]['Move']), chess.Move.from_uci(sf_top[-1]['Move'])]

        if use_sf:
            next_legal_moves = get_sf_best_worst(leaf_state) # FOR TESTING PURPOSES ONLY

        # Node(move, parent). the leaf node is the parent, and the next move gets us to the child
        children = [] # add the children to the leaf node
        for next_move in next_legal_moves:
            self.num_nodes += 1
            children.append(Node(next_move, leaf_node, self.num_nodes, explore=self.explore))

        leaf_node.add_children(children)

        # mirror the new children in the treelib representation
        for child in leaf_node.children.values():
            self.Tree.create_node(str(child.move), child.tag, parent=leaf_node.tag, data=child)

        if not use_sf:
            # pick weighted random child to simulate based on how "good" a move is
            node = random.choices(children, weights=prob)[0]
        else:
            # pick best move with weight .75, worst move with weight .25
            node = random.choices(children, weights=[.75, .25])[0]

        # (a leftover uniform random.choice() used to overwrite the weighted
        # pick here; it has been removed so the weights take effect)
        leaf_state.push(node.move)

        return node, leaf_state

    def rollout_policy(self, state, eval=True):
        """Pick the next move to play during a rollout.

        eval=True  -- sample a move using the weights from state.eval_legal_moves()
        eval=False -- sample uniformly from state.get_legal_moves()
        """
        if eval == False:
            return random.choice(state.get_legal_moves())
        else:
            prob, moves = state.eval_legal_moves()
            return random.choices(moves, weights=prob)[0]

    def rollout(self, state, limit_depth=True):
        """Simulate a game from `state` (which is not modified).

        limit_depth=True  -- play at most self.max_depth moves, then score the
                             position with state.eval(). Returns
                             (winner, eval / 100.0) where winner is chess.WHITE
                             for a positive eval, chess.BLACK for a negative
                             one, or (None, 0) for an eval of exactly 0.
        limit_depth=False -- play until the game is over and return
                             state.outcome().
                             NOTE(review): backprop() unpacks a
                             (winner, eval) tuple, so this return value cannot
                             be passed to it as is - confirm intended use.

        Also updates the total_rollouts / total_depth statistics.
        """
        depth = 0
        state_copy = deepcopy(state) # IMPORTANT: deepcopy the state!

        # limit_depth = False: run vanilla MCTS until completion. no eval() function required
        if limit_depth == False:
            while not state_copy.is_game_over():
                state_copy.push(self.rollout_policy(state_copy))
                depth += 1
            # update the number of total rollouts performed so we can calculate average rollouts per second
            self.total_rollouts += 1
            # update the number of total depth searched so we can calculate average depth per rollout
            self.total_depth += depth
            return state_copy.outcome() #TODO: define state.outcome() in interface
        else:
            while depth < self.max_depth and not state_copy.is_game_over():
                state_copy.push(self.rollout_policy(state_copy))
                depth += 1
            self.total_rollouts += 1
            self.total_depth += depth

            static_eval = state_copy.eval() # TODO: allow for eval() as part of API

            # the sign of the evaluation decides the "winner"; the magnitude
            # (divided by 100) is kept as the size of the reward
            if static_eval > 0:
                return (chess.WHITE, static_eval/100.0)
            elif static_eval < 0:
                return (chess.BLACK, static_eval/100.0)
            else:
                return (None, 0)

    def backprop(self, node, turn, outcome):
        """Propagate a rollout result from `node` up through its ancestors.

        node    -- node the rollout started from
        turn    -- side to move in that node's state
        outcome -- (winner, static_eval) tuple as returned by a depth-limited
                   rollout; winner None means a draw

        The reward at `node` is -|static_eval| when the winner is the side to
        move there, +|static_eval| otherwise, 0 for a draw; its sign flips at
        every step towards the root. Every node on the path gets N += 1 and
        Q += reward, then the diagnostic properties of the path nodes and
        their children are refreshed.
        """
        winner, static_eval = outcome
        if winner is None: # draw
            reward = 0
        else:
            reward = -abs(static_eval) if winner == turn else abs(static_eval)

        node_path = []
        while node is not None:
            node_path.append(node)
            node.N += 1
            node.Q += reward
            node = node.parent
            # True = white wins, False = black wins, None = draw (reward stays 0)
            reward = -reward

        # refresh the treelib diagnostics from the top of the path down
        node_path.reverse()
        for node in node_path:
            for child in node.children.values():
                child.update_properties()
            node.update_properties()

    def search(self, time_limit, operation_limit=float('inf')):
        """Run search iterations until a limit is reached.

        time_limit      -- wall-clock budget in seconds, enforced with
                           signal.alarm (SIGALRM)
        operation_limit -- maximum number of iterations (default: unlimited)

        The elapsed wall-clock time is added to self.total_time whichever
        limit ends the search. The alarm can interrupt an iteration part-way
        through.
        """
        def handle_timeout(signum, frame):
            raise TimeoutError

        previous_handler = signal.signal(signal.SIGALRM, handle_timeout)
        started = time.monotonic()
        signal.alarm(time_limit) # when the time limit is reached, break out

        try:
            operations = 0
            while operations < operation_limit:
                leaf_node, leaf_state = self.select()
                node, state = self.expand(leaf_node, leaf_state)

                if node is None or state is None: # the game is terminated
                    node, state = leaf_node, leaf_state # TODO: does this make sense??
                outcome = self.rollout(state, limit_depth=True)
                self.backprop(node, state.get_turn(), outcome) #TODO: define to_play in the interface
                operations += 1

        except TimeoutError:
            # time budget exhausted: normal way for the search to end
            pass
        finally:
            signal.alarm(0)
            # put back whatever SIGALRM handler was installed before the search
            # (signal.signal returns None if it was not installed from Python)
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)
            self.total_time += time.monotonic() - started

    def move(self, move):
        """Advance the root of the search by `move`.

        If the move is already a child of the root, that subtree (and its
        statistics) becomes the new root. Otherwise the search restarts from
        a fresh root node and a fresh treelib Tree.
        """
        self.root_state.push(move) # move forward in the MCTS
        if move in self.root_node.children:
            self.root_node = self.root_node.children[move]
            return

        self.root_node = Node(None, None, explore=self.explore)
        # the old treelib tree describes a position that is no longer
        # reachable; start a new one so later children attach to the new root
        self.Tree = Tree()
        self.Tree.create_node(str(self.root_node), self.root_node.tag, data=self.root_node)

    def find_best_move(self):
        """Return the move of the most visited child of the root.

        choose the action leading to the state w the highest N (most explored
        node, probably most worthwhile); ties are broken at random.

        Raises Exception if the game is over, IndexError if the root has no
        children (no search has been run from this position).
        """
        if self.root_state.is_game_over():
            raise Exception("Cannot find best move, game has already terminated")

        children = list(self.root_node.children.values())
        if not children:
            raise IndexError("Cannot find best move, root node has no children (run search() first)")

        max_N = max(child.N for child in children)
        max_N_children = [child for child in children if child.N == max_N]

        return random.choice(max_N_children).move


In [13]:
def play_chess(time_limit, max_depth=10, explore=math.sqrt(2)):
    """Play an interactive console game: the user moves first, MCTS replies.

    time_limit -- seconds MCTS may search for each of its moves
    max_depth  -- rollout depth limit passed to MCTS
    explore    -- UCT exploration factor passed to MCTS

    Moves are typed in SAN. The prompt repeats until a legal move is entered.
    """
    state = ChessWrapper()
    mcts = MCTS(state, max_depth, explore=explore)

    def ask_for_move():
        # keep prompting until the input is a legal move
        while True:
            input_move = input('What is the next move? \n')
            try:
                user_move = state.parse_san(input_move)
            except ValueError:
                # presumably parse_san follows python-chess and raises a
                # ValueError subclass for unparsable/illegal SAN - TODO confirm
                user_move = None

            if user_move is not None and user_move in state.get_legal_moves():
                return user_move
            print("Cannot make illegal move")

    def announce_result(last_mover):
        # a finished game is not necessarily a win for whoever moved last
        result = state.outcome()
        if result is not None and result.winner is None:
            print("Game drawn")
        else:
            print(last_mover + " wins")

    while not state.is_game_over():
        print("Current evaluation:")
        print(state.eval())

        print("Current state:")
        state.print()

        user_move = ask_for_move()

        # keep the displayed board and the search tree in step
        state.push(user_move)
        mcts.move(user_move)

        state.print()

        if state.is_game_over():
            announce_result("Player one")
            break

        print("Searching for best move")

        mcts.search(time_limit)
        move = mcts.find_best_move()

        print("MCTS's best move: ", state.san(move))

        state.push(move)
        mcts.move(move)

        if state.is_game_over():
            announce_result("Player two")
            break

In [14]:
#play_chess(10, max_depth=0, explore=10)

In [17]:
# wtm = True means MCTS makes the first move from the starting position
def vs_sf(elo=1600, max_depth = 0, time_limit=5, explore=math.sqrt(2), wtm=True, fen=None,
          sf_path='/opt/homebrew/Cellar/stockfish/16/bin/stockfish'):
    """Play one game of MCTS against Stockfish and record it as PGN.

    elo        -- Elo rating Stockfish is limited to
    max_depth  -- rollout depth limit passed to MCTS
    time_limit -- seconds MCTS may search per move
    explore    -- UCT exploration factor passed to MCTS
    wtm        -- True: MCTS makes the first move, False: Stockfish does;
                  the two then alternate
    fen        -- optional starting position (default: ChessWrapper() default)
    sf_path    -- path of the Stockfish executable

    Every position is printed as the game goes on.
    Returns (outcome, game): the board's outcome() and the chess.pgn.Game.
    """
    # identical setup for both cases; only the starting position differs
    b = ChessWrapper() if fen is None else ChessWrapper(fen)

    sf = Stockfish(sf_path)
    if fen is not None:
        sf.set_fen_position(fen)
    sf.set_elo_rating(elo)

    mcts = MCTS(b, max_depth=max_depth, explore=explore)

    game = chess.pgn.Game()
    game.headers["Event"] = "sf elo: " + str(elo) + ', max_depth: ' + str(max_depth) + ', time_limit: ' + str(time_limit) + ', explore: ' + str(explore)

    if fen is not None:
        game.setup(fen)

    # the game itself is the first PGN node; each move extends the main line
    node = game

    while not b.is_game_over():
        if wtm:
            mcts.search(time_limit)
            m = mcts.find_best_move()
        else:
            m = chess.Move.from_uci(sf.get_best_move())

        # keep the board, the search tree and the engine in step
        b.push(m)
        mcts.move(m)
        sf.make_moves_from_current_position([str(m)])

        node = node.add_variation(m)

        if wtm:
            print('mcts move')
        else:
            print('sf move')

        print(b)
        print('-' * 15)
        wtm = not wtm
    return b.outcome(), game

In [22]:
#def vs_sf(elo=1600, max_depth = 0, time_limit=5, explore=math.sqrt(2), wtm=True, fen=None):

def est_elo(max_depth=0, time_limit=5, explore=math.sqrt(2)):
    """Estimate the playing strength of MCTS by climbing Stockfish Elo levels.

    Starting at Elo 100, MCTS plays up to three pairs of games (one with each
    colour per opening) against Stockfish limited to that Elo. A level stops
    early once the result is decided. If MCTS scores at least 3 points the Elo
    is raised by 100 and the match repeats, otherwise the estimate stops.

    Returns (df, lo): df has one row per game played (columns max_depth,
    time_limit, explore, elo, outcome, game; outcome is True when MCTS did
    not lose) and lo is the first Elo level MCTS failed to pass.
    """
    # sicilian (open), queen's indian, reti transposed to english
    # (all three are white to move, so wtm=True makes MCTS play white)
    openings = [
        'rnbqkb1r/pp2pppp/3p1n2/8/3NP3/8/PPP2PPP/RNBQKB1R w KQkq - 1 5',
        'rn1qkb1r/p1pp1ppp/bp2pn2/8/2PP4/5NP1/PP2PP1P/RNBQKB1R w KQkq - 1 5',
        'rnbqk2r/ppp1ppbp/3p1np1/8/2P1P3/2N2N2/PP1P1PPP/R1BQKB1R w KQkq - 0 5'
    ]

    random.shuffle(openings)

    lo = 100

    columns = ['max_depth', 'time_limit', 'explore', 'elo', 'outcome', 'game']
    rows = [] # one record per game; turned into a DataFrame once at the end

    def play_game(elo, mcts_is_white, fen):
        # play one game and return (mcts points, stockfish points)
        print('-' * 15)
        print(max_depth, time_limit, explore, elo)

        oc, game = vs_sf(elo=elo, max_depth=max_depth, time_limit=time_limit, explore=explore, wtm=mcts_is_white, fen=fen)

        mcts_color = chess.WHITE if mcts_is_white else chess.BLACK
        sf_color = not mcts_color

        if oc.winner is None: # draw
            points = (0.5, 0.5)
        elif oc.winner == mcts_color:
            points = (1, 0)
        else:
            points = (0, 1)

        rows.append({'max_depth': max_depth, 'time_limit': time_limit, 'explore': explore, 'elo': elo, 'outcome': oc.winner != sf_color, 'game': game})
        print(oc)
        return points

    while True:
        mcts_score = 0
        sf_score = 0
        decided = False

        # play 3 pairs: MCTS takes white first, then black, in each opening
        for fen in openings:
            for mcts_is_white in (True, False):
                mcts_points, sf_points = play_game(lo, mcts_is_white, fen)
                mcts_score += mcts_points
                sf_score += sf_points

                # if either player cuts off early, don't need to play the rest
                if mcts_score >= 3 or sf_score >= 3.5:
                    decided = True
                    break
            if decided:
                break

        if mcts_score >= 3:
            lo += 100
        else:
            break

    df = pd.DataFrame(rows, columns=columns)
    return df, lo

In [None]:
# Sweep the UCT exploration factor (1..9) with a fixed rollout depth and time
# limit, estimating the Elo reached for each setting.
# The recorded time_limit now comes from the same variable that is passed to
# est_elo (the summary rows used to say 10 while the games ran with 180).
sweep_max_depth = 0
sweep_time_limit = 180

results_columns = ['max_depth', 'time_limit', 'explore', 'elo']
df_results2 = pd.DataFrame(columns = results_columns)
df_games2 = pd.DataFrame(columns = ['max_depth', 'time_limit', 'explore', 'elo', 'outcome', 'game'])

result_rows = []
game_frames = []

for d in range(1, 10):
    df, elo = est_elo(max_depth=sweep_max_depth, time_limit=sweep_time_limit, explore=d)
    result_rows.append({'max_depth': sweep_max_depth, 'time_limit': sweep_time_limit, 'explore': d, 'elo': elo})
    game_frames.append(df)

    # rebuild the running tables after every setting so progress is visible
    df_results2 = pd.DataFrame(result_rows, columns=results_columns)
    df_games2 = pd.concat(game_frames)
    print(df_results2)
    print(df_games2)

---------------
0 180 1 100
mcts move
r n b q k . . r
p p p . p p b p
. . . p . n p .
. . . . . . . .
. . P P P . . .
. . N . . N . .
P P . . . P P P
R . B Q K B . R
---------------
sf move
r n b q k . . r
. p p . p p b p
. . . p . n p .
p . . . . . . .
. . P P P . . .
. . N . . N . .
P P . . . P P P
R . B Q K B . R
---------------
mcts move
r n b q k . . r
. p p . p p b p
. . . p . n p .
p . . . . . . .
. . P P P . . .
. . N B . N . .
P P . . . P P P
R . B Q K . . R
---------------
sf move
r n . q k . . r
. p p . p p b p
. . . p . n p .
p . . . . . . .
. . P P P . b .
. . N B . N . .
P P . . . P P P
R . B Q K . . R
---------------
mcts move
r n . q k . . r
. p p . p p b p
. . . p . n p .
p . . . . . . .
Q . P P P . b .
. . N B . N . .
P P . . . P P P
R . B . K . . R
---------------
sf move
r n . q k . . r
. p p b p p b p
. . . p . n p .
p . . . . . . .
Q . P P P . . .
. . N B . N . .
P P . . . P P P
R . B . K . . R
---------------
mcts move
r n . q k . . r
. p p b p p b p
. . . p . n 