In [1]:
from abc import ABC, abstractmethod
from typing import Any, Dict, Hashable, List, Optional, Tuple

import chess
import chess.engine
import chess.svg
import numpy as np
import tensorflow as tf
from IPython.display import SVG, clear_output, display
from tqdm import trange

In [2]:
def get_hash(board: chess.Board) -> Hashable:
    """Return a hashable key identifying a position for repetition counting.

    The key covers piece placement (per-type and per-colour bitboards), the side
    to move, the castling rights and the en-passant square. The en-passant
    square is included only when an en-passant capture is actually legal, so
    positions that differ just by an unusable ep square compare equal.
    """
    piece_bitboards = (
        board.pawns, board.knights, board.bishops,
        board.rooks, board.queens, board.kings,
    )
    colour_bitboards = (board.occupied_co[chess.WHITE], board.occupied_co[chess.BLACK])
    side_and_rights = (
        board.turn,
        board.clean_castling_rights(),
        board.ep_square if board.has_legal_en_passant() else None,
    )
    return piece_bitboards + colour_bitboards + side_and_rights

# The policy holds 73 move-type planes per from-square (AlphaZero-style layout):
#   [ 0,  9)  under-promotions: rook / bishop / knight x (straight, right, left)
#   [ 9, 23)  horizontal slides: 7 leftward distances, then 7 rightward
#   [23, 37)  vertical slides: 7 downward distances, then 7 upward
#   [37, 65)  diagonal slides: NW, SW, NE, SE x 7 distances each
#   [65, 73)  knight jumps
NUM_PROMOTION_DIRS = 3       # straight push, capture to the right, capture to the left
NUM_SINGLE_SQUARE_STEPS = 7  # a sliding move covers a distance of 1..7 squares

# Under-promotion planes (queen promotions are encoded as ordinary forward slides).
ROOK_PROMOTION_START_IDX = 0
BISHOP_PROMOTION_START_IDX = NUM_PROMOTION_DIRS * 1
KNIGHT_PROMOTION_START_IDX = NUM_PROMOTION_DIRS * 2

# Sliding-move planes, each direction group directly after the previous one.
HORIZONTAL_MOVE_START_IDX = NUM_PROMOTION_DIRS * 3
VERTICAL_MOVE_START_IDX = HORIZONTAL_MOVE_START_IDX + NUM_SINGLE_SQUARE_STEPS * 2
DIAGONAL_MOVE_START_IDX = VERTICAL_MOVE_START_IDX + NUM_SINGLE_SQUARE_STEPS * 2
KNIGHT_MOVE_START_IDX = DIAGONAL_MOVE_START_IDX + NUM_SINGLE_SQUARE_STEPS * 4

def move_to_policy_idx(move: chess.Move, is_black: bool) -> int:
    """Convert a move to an int in [0, 73) representing its move-type plane in the policy.

    Directions are expressed from the mover's point of view. For black, the rank
    difference is negated, which matches the vertical flip that the state
    encoding applies via `board.mirror()`. Files are left as-is because that
    mirror does not flip them.

    Args:
        move: a legal move (queen-like slide, knight jump or promotion).
        is_black: True when black is making the move.

    Returns:
        Plane index; combined with the (frame-adjusted) from-square it indexes
        an (8, 8, 73) policy.
    """
    rank_diff = chess.square_rank(move.to_square) - chess.square_rank(move.from_square)
    file_diff = chess.square_file(move.to_square) - chess.square_file(move.from_square)
    abs_rank_diff, abs_file_diff = abs(rank_diff), abs(file_diff)

    if is_black:
        # Only ranks are flipped; see docstring.
        rank_diff *= -1

    # Under-promotion. Piece types are ints, so compare by value, not identity.
    # Queen promotions fall through and are encoded as ordinary slides.
    if move.promotion is not None and move.promotion != chess.QUEEN:
        sub_idx = file_diff % NUM_PROMOTION_DIRS  # Straight: 0, Right: 1, Left: 2
        if move.promotion == chess.ROOK:
            return ROOK_PROMOTION_START_IDX + sub_idx
        if move.promotion == chess.BISHOP:
            return BISHOP_PROMOTION_START_IDX + sub_idx
        return KNIGHT_PROMOTION_START_IDX + sub_idx

    # Horizontal: leftward distances first, then rightward
    if rank_diff == 0:
        if file_diff < 0:
            return HORIZONTAL_MOVE_START_IDX + -file_diff - 1
        return HORIZONTAL_MOVE_START_IDX + NUM_SINGLE_SQUARE_STEPS + file_diff - 1

    # Vertical: backward distances first, then forward (from the mover's view)
    if file_diff == 0:
        if rank_diff < 0:
            return VERTICAL_MOVE_START_IDX + -rank_diff - 1
        return VERTICAL_MOVE_START_IDX + NUM_SINGLE_SQUARE_STEPS + rank_diff - 1

    # Diagonal: NW, SW, NE, SE groups of NUM_SINGLE_SQUARE_STEPS each
    if abs_rank_diff == abs_file_diff:
        if file_diff < 0:
            if rank_diff > 0:  # Northwest
                return DIAGONAL_MOVE_START_IDX + rank_diff - 1
            # Southwest
            return DIAGONAL_MOVE_START_IDX + NUM_SINGLE_SQUARE_STEPS + -rank_diff - 1
        if rank_diff > 0:  # Northeast
            return DIAGONAL_MOVE_START_IDX + 2 * NUM_SINGLE_SQUARE_STEPS + rank_diff - 1
        # Southeast
        return DIAGONAL_MOVE_START_IDX + 3 * NUM_SINGLE_SQUARE_STEPS + -rank_diff - 1

    # Knight move: quadrant picks a pair of planes; "tall" (|dr| > |df|) jumps come first
    tall = abs_rank_diff > abs_file_diff
    if file_diff < 0:
        if rank_diff > 0:  # Northwest
            return KNIGHT_MOVE_START_IDX + (0 if tall else 1)
        # Southwest
        return KNIGHT_MOVE_START_IDX + (2 if tall else 3)
    if rank_diff > 0:  # Northeast
        return KNIGHT_MOVE_START_IDX + (4 if tall else 5)
    # Southeast
    return KNIGHT_MOVE_START_IDX + (6 if tall else 7)

EPS = 0.001  # small positive constant for numerical safety

class ChessState:
    """A game position plus the repetition counts needed for draw detection.

    `get_next_state` copies the board and the transposition table, so the
    methods of this class never mutate an existing state.
    """

    def __init__(self, board: chess.Board=None, transposition_table: Dict[Hashable, int]=None) -> None:
        """
        Args:
            board: position to wrap; a fresh starting position when None.
            transposition_table: maps `get_hash` keys to how often that position
                has occurred in this game; when None the current position is
                counted once.
        """
        self.board = chess.Board() if board is None else board

        if transposition_table is None:
            self.transposition_table = {get_hash(self.board): 1}
        else:
            self.transposition_table = transposition_table

    def get_next_state(self, action: chess.Move) -> 'ChessState':
        """Return the state after playing `action`, with its repetition count bumped."""
        next_board = self.board.copy()
        next_board.push(action)

        next_transposition_table = self.transposition_table.copy()
        next_hash = get_hash(next_board)
        next_transposition_table[next_hash] = next_transposition_table.get(next_hash, 0) + 1

        return ChessState(next_board, next_transposition_table)

    def get_valid_moves(self) -> chess.LegalMoveGenerator:
        """Legal moves for the side to move."""
        return self.board.legal_moves

    def get_value_and_terminated(self) -> Tuple[int, bool]:
        """Return (value, terminated).

        Value is from parent's perspective, i.e. the player who just moved:
        1 if that player has won, 0 for a draw or an unfinished game.
        Claimable draws (threefold repetition, fifty-move rule) end the game.
        """
        outcome = self.board.outcome(claim_draw=True)

        if outcome is None:
            return 0, False
        if outcome.winner is None:
            return 0, True
        return 1, True

    def get_player(self) -> chess.Color:
        """Colour of the side to move."""
        return self.board.turn

    def show_state(self, orientation: chess.Color) -> None:
        """Render the board as a 400px SVG, seen from `orientation`'s side."""
        display(SVG(
            chess.svg.board(self.board, size=400, orientation=orientation)
        ))

    def get_encoded_state(self) -> np.ndarray:
        """Encode the position as an (8, 8, 19) float array from the mover's view.

        For black the board is mirrored, so the side to move always appears as
        "white". Planes:
          0-5:   side-to-move pieces (pawn .. king)
          6-11:  opponent pieces (pawn .. king)
          12-15: castling rights (own K-side, own Q-side, opp K-side, opp Q-side)
          16:    occurrence count of this position
          17:    halfmove clock / 100
          18:    fullmove number / 50
        Rows index ranks and columns index files, matching the policy layout.
        """
        board = self.board.mirror() if self.get_player() is chess.BLACK else self.board

        encoding = []

        # Piece positions
        for color in (chess.WHITE, chess.BLACK):
            for piece in chess.PIECE_TYPES:
                encoding.append(np.reshape(board.pieces(piece, color).tolist(), (8, 8)).astype(np.float64))

        # Castling rights
        for color in (chess.WHITE, chess.BLACK):
            encoding.append(board.has_kingside_castling_rights(color) * np.ones((8, 8)))
            encoding.append(board.has_queenside_castling_rights(color) * np.ones((8, 8)))

        # Number of occurrences (for 3-fold repetition); the current position has
        # occurred at least once even if a caller supplied a table without it.
        encoding.append(self.transposition_table.get(get_hash(self.board), 1) * np.ones((8, 8)))
        # 50 move rule counter
        encoding.append(board.halfmove_clock * np.ones((8, 8)) / 100) # TODO: Do we floor divide halfmove_clock by 2?
        # Total move counter
        encoding.append(board.fullmove_number * np.ones((8, 8)) / 50)

        return np.transpose(np.array(encoding), (1, 2, 0))

    @staticmethod
    def _policy_coords(move: chess.Move, is_black: bool) -> Tuple[int, int, int]:
        """(rank, file, plane) of `move` in the mover-relative (8, 8, 73) policy."""
        from_rank = chess.square_rank(move.from_square)
        if is_black:
            from_rank = 7 - from_rank  # the encoding mirrors ranks for black
        return from_rank, chess.square_file(move.from_square), move_to_policy_idx(move, is_black)

    def decode_policy(self, policy: np.ndarray) -> Dict[chess.Move, float]:
        """Convert policy (given as if we were white) of shape (8, 8, 73) to dictionary mapping legal moves to probabilities.

        The mass on legal moves is renormalized to sum to exactly 1. If the
        network puts zero mass on every legal move, a uniform distribution is
        returned so that the node can still be expanded.
        """
        is_black = self.get_player() is chess.BLACK
        probs: Dict[chess.Move, float] = {
            move: float(policy[self._policy_coords(move, is_black)])
            for move in self.get_valid_moves()
        }

        total_prob = sum(probs.values())
        if total_prob > 0:
            return {move: prob / total_prob for move, prob in probs.items()}
        if probs:
            uniform = 1 / len(probs)
            return {move: uniform for move in probs}
        return probs

    def encode_policy(self, probs: Dict[chess.Move, float]) -> np.ndarray:
        """Convert dictionary mapping legal move to probabilities to policy of shape (8, 8, 73) as if we were white."""
        is_black = self.get_player() is chess.BLACK
        policy = np.zeros((8, 8, 73))

        for move, prob in probs.items():
            policy[self._policy_coords(move, is_black)] = prob

        return policy

In [10]:
class Node:
    """A node of the MCTS tree.

    `value_sum` accumulates values from the perspective of the player to move
    in this node's state. A parent scoring a child flips the sign (see `get_ucb`).
    """

    def __init__(self, state: ChessState, args: dict, parent: 'Node'=None, action_taken: chess.Move=None, prior: float=0, visit_count: int=0) -> None:
        """
        Args:
            state: position at this node.
            args: search hyper-parameters; `get_ucb` reads 'C'.
            parent: parent node, None for the root.
            action_taken: move that led from `parent` to this node.
            prior: policy probability the parent assigned to `action_taken`.
            visit_count: initial visit count (the root starts at 1).
        """
        self.state = state
        self.args = args
        self.parent = parent
        self.action_taken = action_taken
        self.prior = prior

        self.children: List['Node'] = []

        self.visit_count = visit_count
        self.value_sum = 0

    def is_fully_expanded(self) -> bool:
        """True once `expand` has created at least one child."""
        return len(self.children) > 0

    def select(self) -> 'Node':
        """Return the child with the highest UCB score (first one on ties, None if childless)."""
        best_child = None
        best_ucb = -np.inf

        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_ucb = ucb
                best_child = child

        return best_child

    def get_ucb(self, child: 'Node') -> float:
        """PUCT score of `child` from this node's perspective.

        q (in [0, 1], 0 for unvisited children) + C * sqrt(N) / (1 + n) * prior.
        """
        if child.visit_count == 0:
            q = 0
        else:
            # value_sum / visit_count can be between -1 and 1, we want between 0 and 1 like probability
            # child's perspective is opposite self's perspective so we flip with 1 - q
            q = 1 - (child.value_sum / child.visit_count + 1) / 2
        return q + self.args['C'] * np.sqrt(self.visit_count) / (1 + child.visit_count) * child.prior

    def expand(self, policy: Dict[chess.Move, float]):
        """Add one child per move with a strictly positive prior."""
        for action, prob in policy.items():
            if prob > 0:
                child_state = self.state.get_next_state(action)
                self.children.append(Node(child_state, self.args, self, action, prob))

    def backprop(self, value: int) -> None:
        """Add `value` (this node's perspective) to this node and its ancestors.

        The sign alternates at each level, since consecutive nodes belong to
        opposite players. Iterative rather than recursive, so deep trees cannot
        exhaust the interpreter's recursion limit.
        """
        node = self
        while node is not None:
            node.visit_count += 1
            node.value_sum += value
            value = value * -1
            node = node.parent

class Model(ABC):
    """Interface for position evaluators used by MCTS and the training loop."""

    @abstractmethod
    def __call__(self, state: ChessState) -> Tuple[Dict[chess.Move, float], float]:
        """Evaluate `state`.

        Returns:
            (policy, value): `policy` maps legal moves to prior probabilities and
            `value` scores the position, both from the perspective of the player
            to move in `state`.
        """
        ...

    @abstractmethod
    def fit(
            self,
            states: List[np.ndarray],
            policies: List[np.ndarray], 
            values: List[np.ndarray],
            epochs: int,
            callbacks: List[tf.keras.callbacks.Callback]=[]
        ) -> Any:
        """Train on encoded states with target policies and values for `epochs` epochs."""
        ...

class MCTS:
    """Monte Carlo tree search guided by a `Model`'s priors and value estimates."""

    def __init__(self, args: dict, model: Model) -> None:
        """
        Args:
            args: hyper-parameters. This class reads 'num_searches',
                'dirichlet_eps' and 'dirichlet_alpha'; `Node` reads 'C'.
            model: callable returning (policy, value) for a `ChessState`.
        """
        self.args = args
        self.model = model

    def _add_dirichlet_noise(self, policy: Dict[chess.Move, float]) -> Dict[chess.Move, float]:
        """Mix root priors with one Dirichlet(alpha, ..., alpha) sample and renormalize.

        The sample needs one concentration parameter per move. A one-element
        vector always samples exactly [1.0], which adds no randomness at all.
        """
        if not policy:
            return policy

        eps, alpha = self.args['dirichlet_eps'], self.args['dirichlet_alpha']
        noise = np.random.default_rng().dirichlet([alpha] * len(policy))
        noisy = {
            move: (1 - eps) * prob + eps * float(sample)
            for (move, prob), sample in zip(policy.items(), noise)
        }

        total_prob = sum(noisy.values())
        if total_prob <= 0:
            return noisy
        return {move: prob / total_prob for move, prob in noisy.items()}

    def search(self, state: ChessState) -> Dict[chess.Move, float]:
        """Run `num_searches` simulations from `state`.

        Returns:
            The root's visit-count distribution over every legal move (unvisited
            moves get 0). If no child was visited, the distribution is uniform.
        """
        root = Node(state, self.args, visit_count=1)

        policy, _ = self.model(state)
        root.expand(self._add_dirichlet_noise(policy))

        for _ in range(self.args['num_searches']):
            # Selection
            node = root
            while node.is_fully_expanded():
                node = node.select()

            value, is_terminal = node.state.get_value_and_terminated()
            if is_terminal:
                # get_value_and_terminated scores the game for the player who just
                # moved (node.parent's side). Node values are stored for the side
                # to move at the node, so a terminal value is always negated.
                value = -value
            else:
                # policy and value are from node's perspective
                policy, value = self.model(node.state)
                node.expand(policy)

            # Backprop
            node.backprop(value)

        # Return normalized visit counts
        visit_counts = {move: 0 for move in root.state.get_valid_moves()}
        for child in root.children:
            visit_counts[child.action_taken] = child.visit_count

        total_visit_count = sum(visit_counts.values())
        if total_visit_count == 0:
            # No children (e.g. every prior was 0); avoid dividing by zero.
            return {move: 1 / len(visit_counts) for move in visit_counts}
        return {move: count / total_visit_count for move, count in visit_counts.items()}

In [4]:
class Stockfish(Model):
    """`Model` backed by a local Stockfish UCI engine using depth-0 evaluations.

    This is a fixed evaluator: `fit` is not supported. Call `close()` when done
    to stop the engine subprocess.
    """

    DEFAULT_ENGINE_PATH = '/opt/homebrew/opt/stockfish/bin/stockfish'

    def __init__(self, engine_path: str = DEFAULT_ENGINE_PATH) -> None:
        """Start the engine binary at `engine_path` (defaults to a Homebrew install)."""
        self.model = chess.engine.SimpleEngine.popen_uci(engine_path)

    def _white_expectation(self, board: chess.Board) -> float:
        """Stockfish's expected score in [0, 1] for white in `board`."""
        info = self.model.analyse(board, chess.engine.Limit(depth=0))
        return info['score'].white().wdl().expectation()

    def __call__(self, state: ChessState) -> Tuple[Dict[chess.Move, float], float]:
        """Policy and value are given from this node's perspective.

        The value is the mover's expected score rescaled to [-1, 1]. The policy
        weights each legal move by the mover's expected score after it, then
        normalizes; the policy is empty if every move scores 0.
        """
        mover_is_white = state.get_player() is chess.WHITE

        value = self._white_expectation(state.board) * 2 - 1
        if not mover_is_white:
            value *= -1

        next_values = {}
        total_value = 0

        for move in state.get_valid_moves():
            next_board = state.board.copy()
            next_board.push(move)
            next_value = self._white_expectation(next_board)
            if not mover_is_white:
                next_value = 1 - next_value

            next_values[move] = next_value
            total_value += next_value

        if total_value == 0:
            return {}, value

        policy = {move: next_values[move] / total_value for move in next_values}

        return policy, value

    def fit(
            self,
            states: List[np.ndarray],
            policies: List[np.ndarray],
            values: List[np.ndarray],
            epochs: int,
            callbacks: Optional[List[tf.keras.callbacks.Callback]]=None
        ) -> Any:
        """Not supported. Without this override the abstract `Model.fit`
        makes `Stockfish()` impossible to instantiate."""
        raise NotImplementedError('Stockfish is a fixed evaluator and cannot be trained.')

    def close(self) -> None:
        """Shut down the engine subprocess."""
        self.model.quit()

## MCTS Test with Stockfish Eval

In [None]:
state = ChessState()
args = {
    'C': 2,  # PUCT exploration constant (np.sqrt(2) is the textbook UCT value)
    'num_searches': 1000,
    'dirichlet_eps': 0.25,
    'dirichlet_alpha': 0.3
}
model = Stockfish()
mcts = MCTS(args, model)
human = chess.BLACK


def _read_human_move(board: chess.Board):
    """Prompt until the user enters a legal SAN move; return None when they type 'quit'."""
    while True:
        text = input('Move (algebraic notation): ')
        if text == 'quit':
            return None
        try:
            return board.parse_san(text)
        except ValueError:
            # python-chess signals illegal/invalid/ambiguous SAN with ValueError subclasses
            print(f'Could not play {text!r}, try again.')


try:
    while True:
        clear_output()
        state.show_state(human)

        if state.get_player() is human:
            action = _read_human_move(state.board)
            if action is None:
                print('Game aborted')
                break
        else:
            mcts_probs = mcts.search(state)
            action = max(mcts_probs, key=mcts_probs.get)

        state = state.get_next_state(action)

        value, is_terminal = state.get_value_and_terminated()
        if is_terminal:
            clear_output()
            state.show_state(human)
            if value == 0:
                print('Draw')
            else:
                # Winner is the previous player
                winner = 'White' if state.get_player() is not chess.WHITE else 'Black'
                print(f'{winner} won')

            break
finally:
    # Stop the Stockfish subprocess started by popen_uci.
    model.model.quit()

In [5]:
class ResNetBlock(tf.keras.Model):
    """Residual block: relu(BN(conv(relu(BN(conv(x))))) + x), channel count preserved."""

    def __init__(self, num_filters):
        super().__init__()

        layers = tf.keras.layers
        # Attribute names and creation order are kept stable so checkpoints stay loadable.
        self.conv_block = tf.keras.Sequential([
            layers.Conv2D(num_filters, kernel_size=(3, 3), padding='same'),
            layers.BatchNormalization(axis=-1),
            layers.Activation('relu'),
            layers.Conv2D(num_filters, kernel_size=(3, 3), padding='same'),
            layers.BatchNormalization(axis=-1),
        ])
        self.add = layers.Add()
        self.relu = layers.Activation('relu')

    def call(self, inputs):
        residual = self.conv_block(inputs)
        return self.relu(self.add([residual, inputs]))

class ResNet(tf.keras.Model):
    """Two-headed policy/value network over encoded chess states.

    Input: tensors of shape (batch, 8, 8, num_planes).
    Outputs (dict):
        'policy_output': (batch, 8 * 8 * num_actions) softmax over the flattened
            (rank, file, plane) move grid.
        'value_output': (batch, 1) tanh value in [-1, 1].

    The defaults reproduce the original architecture: one residual block with
    256 filters, 19 input planes and 73 move planes.
    """

    def __init__(self, num_blocks: int = 1, num_filters: int = 256, num_planes: int = 19, num_actions: int = 73):
        super().__init__()
        if num_blocks < 0:
            raise ValueError(f'num_blocks must be non-negative, got {num_blocks}')

        self.torso = tf.keras.Sequential(
            [tf.keras.layers.Conv2D(num_filters, (3, 3), padding='same', activation='relu', input_shape=(8, 8, num_planes))]
            + [ResNetBlock(num_filters) for _ in range(num_blocks)]
        )
        self.policy_head = tf.keras.Sequential([
            tf.keras.layers.Conv2D(256, (1, 1), activation='relu'),
            tf.keras.layers.Conv2D(num_actions, (1, 1), activation=None),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Activation('softmax')
        ])
        self.value_head = tf.keras.Sequential([
            tf.keras.layers.Conv2D(1, (1, 1), activation='relu'),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(256, activation='relu'),
            tf.keras.layers.Dense(1, activation='tanh')
        ])

    def call(self, inputs):
        torso_output = self.torso(inputs)
        policy = self.policy_head(torso_output)
        value = self.value_head(torso_output)
        return {'policy_output': policy, 'value_output': value}
    

In [6]:
# Build with a batch-agnostic input shape so summary() can report parameter counts.
model = ResNet()
model.build(input_shape=(None, 8, 8, 19))
model.summary()

2023-07-04 18:16:34.910207: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1
2023-07-04 18:16:34.910231: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2023-07-04 18:16:34.910238: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2023-07-04 18:16:34.910275: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-07-04 18:16:34.910295: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


Model: "res_net"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 sequential_1 (Sequential)   (None, 8, 8, 256)         1226240   
                                                                 
 sequential_2 (Sequential)   (None, 4672)              84553     
                                                                 
 sequential_3 (Sequential)   (None, 1)                 17154     
                                                                 
Total params: 1327947 (5.07 MB)
Trainable params: 1326923 (5.06 MB)
Non-trainable params: 1024 (4.00 KB)
_________________________________________________________________


In [7]:
class ResNetModel(Model):
    """`Model` backed by a trainable `ResNet`, compiled with Adam.

    Losses: categorical cross-entropy on the policy head and MSE on the value head.
    """

    def __init__(self) -> None:
        self.model = ResNet()
        self.model.compile(
            optimizer='Adam',
            loss={'policy_output': 'categorical_crossentropy', 'value_output': 'mean_squared_error'},
            metrics={'policy_output': 'categorical_accuracy', 'value_output': 'mean_squared_error'}
        )

    def __call__(self, state: ChessState) -> Tuple[Dict[chess.Move, float], float]:
        """Policy and value are given from this node's perspective."""
        batch = np.array([state.get_encoded_state()])  # batch of one encoded state
        # Explicit inference mode, so BatchNorm uses its moving statistics.
        outputs = self.model(batch, training=False)
        policy = np.reshape(outputs['policy_output'], (8, 8, -1))
        value = outputs['value_output'].numpy().item()
        return state.decode_policy(policy), value

    def fit(
            self,
            states: List[np.ndarray],
            policies: List[np.ndarray], 
            values: List[np.ndarray],
            epochs: int,
            callbacks: Optional[List[tf.keras.callbacks.Callback]]=None
        ) -> Any:
        """Train on encoded states with target policies and values.

        Args:
            states: encoded states (see `ChessState.get_encoded_state`).
            policies: target policies, either (8, 8, 73) or already flattened;
                each is flattened to match the network's policy output.
            values: scalar value targets, one per state.
            epochs: number of passes over the data.
            callbacks: Keras callbacks; None means no callbacks.

        Returns:
            The Keras `History` object from `Model.fit`.

        Raises:
            ValueError: if `states` is empty.
        """
        if len(states) == 0:
            raise ValueError('fit() needs at least one training example')

        x = np.asarray(states, dtype=np.float32)
        policy_targets = np.asarray(policies, dtype=np.float32).reshape((len(policies), -1))
        # Match the (batch, 1) shape of the value head instead of relying on broadcasting.
        value_targets = np.asarray(values, dtype=np.float32).reshape((-1, 1))

        return self.model.fit(
            x=x,
            y={'policy_output': policy_targets, 'value_output': value_targets},
            epochs=epochs,
            validation_split=0,
            shuffle=True,
            callbacks=list(callbacks) if callbacks else [],
            verbose=1
        )

In [8]:
class Engine:
    """Self-play training loop: plays games with MCTS and fits the model on them.

    Reads from `args`: 'num_learn_iters', 'num_self_play_iters', 'num_epochs',
    'temperature' (optional, default 1), plus the MCTS keys.
    """

    def __init__(self, model: Model, args) -> None:
        self.model = model
        self.args = args
        self.mcts = MCTS(args, model)

    def _sample_action(self, action_probs: Dict[chess.Move, float]) -> chess.Move:
        """Sample a move from `action_probs` sharpened/flattened by the temperature.

        Higher temperature => squishes probabilities together => encourages more exploration.
        """
        moves = list(action_probs.keys())
        probs = np.fromiter(action_probs.values(), dtype=np.float64, count=len(moves))
        weights = probs ** (1 / self.args.get('temperature', 1))

        total = weights.sum()
        if total > 0:
            weights = weights / total  # np.random.choice requires p to sum to 1
        else:
            weights = np.full(len(moves), 1 / len(moves))

        return moves[np.random.choice(len(moves), p=weights)]

    def self_play(self) -> Tuple[List[np.ndarray], List[np.ndarray], List[int]]:
        """Play one game against itself.

        Returns:
            (encoded states, flattened target policies, value targets). Each value
            target is for the player to move in the corresponding state: 1 for
            the eventual winner, -1 for the loser, 0 for a draw.
        """
        state_history: List[ChessState] = []
        policy_history: List[np.ndarray] = []

        state = ChessState()
        while True:
            action_probs = self.mcts.search(state)

            state_history.append(state)
            policy_history.append(state.encode_policy(action_probs).flatten())

            state = state.get_next_state(self._sample_action(action_probs))
            value, is_terminal = state.get_value_and_terminated()

            if is_terminal:
                # `value` is for the player who just moved. The side to move in the
                # final state is the other player, so its states get -value.
                final_player = state.get_player()
                return (
                    [past_state.get_encoded_state() for past_state in state_history],
                    policy_history,
                    [-value if past_state.get_player() is final_player else value for past_state in state_history]
                )

    def learn(self):
        """Alternate self-play and training.

        Each iteration `i` trains on that iteration's games only and checkpoints
        to `learning/cp{i}.ckpt`.
        """
        for i in range(self.args['num_learn_iters']):
            state_memory: List[np.ndarray] = []
            policy_memory: List[np.ndarray] = []
            value_memory: List[int] = []

            # Self-play
            for _ in trange(self.args['num_self_play_iters'], desc='Self-play'):
                states, policies, values = self.self_play()
                state_memory += states
                policy_memory += policies
                value_memory += values

            # Train
            cp_callback = tf.keras.callbacks.ModelCheckpoint(f'learning/cp{i}.ckpt')
            self.model.fit(state_memory, policy_memory, value_memory, self.args['num_epochs'], [cp_callback])

In [11]:
# Tiny smoke-test configuration: a few searches, games and epochs per iteration.
model = ResNetModel()
args = dict(
    C=2,
    num_searches=3,
    num_learn_iters=3,
    num_self_play_iters=3,
    num_epochs=3,
    temperature=1.25,
    dirichlet_eps=0.25,
    dirichlet_alpha=0.3,
)

engine = Engine(model, args)
engine.learn()

Self-play: 100%|██████████| 3/3 [00:13<00:00,  4.52s/it]

Epoch 1/3



2023-07-04 18:19:00.033846: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.




INFO:tensorflow:Assets written to: learning/cp0.ckpt/assets


Epoch 2/3


INFO:tensorflow:Assets written to: learning/cp0.ckpt/assets


Epoch 3/3


INFO:tensorflow:Assets written to: learning/cp0.ckpt/assets




Self-play: 100%|██████████| 3/3 [01:04<00:00, 21.41s/it]

Epoch 1/3
 3/24 [==>...........................] - ETA: 0s - loss: 2.9893 - policy_output_loss: 2.8304 - value_output_loss: 0.1590 - policy_output_categorical_accuracy: 0.3542 - value_output_mean_squared_error: 0.1590






INFO:tensorflow:Assets written to: learning/cp1.ckpt/assets


Epoch 2/3


INFO:tensorflow:Assets written to: learning/cp1.ckpt/assets


Epoch 3/3


INFO:tensorflow:Assets written to: learning/cp1.ckpt/assets




Self-play: 100%|██████████| 3/3 [00:07<00:00,  2.40s/it]

Epoch 1/3
1/6 [====>.........................] - ETA: 0s - loss: 1.5122 - policy_output_loss: 1.4811 - value_output_loss: 0.0310 - policy_output_categorical_accuracy: 0.5312 - value_output_mean_squared_error: 0.0310






INFO:tensorflow:Assets written to: learning/cp2.ckpt/assets


Epoch 2/3


INFO:tensorflow:Assets written to: learning/cp2.ckpt/assets


Epoch 3/3


INFO:tensorflow:Assets written to: learning/cp2.ckpt/assets


