In [1]:
import chess
import re 
import gymnasium as gym
import numpy as np
from stable_baselines3 import A2C
from stable_baselines3.common.policies import ActorCriticPolicy
from torch.nn import ReLU

In [421]:
MOVE_LIST_REGEX = r'\((.*?)\)'
MAX_MOVES_IN_ANY_POS = 218

board = chess.Board()
starting_pos = np.zeros((8, 8, 12), dtype=np.uint8)
for square, piece in board.piece_map().items():
    piece_type = piece.piece_type
    piece_color = int(piece.color)
    starting_pos[square // 8][square % 8][piece_type - 1 + 6 * piece_color] = 1

STARTING_POS_TENSOR = starting_pos

In [422]:
def get_king_square(board: chess.Board, white: bool):
    for square, piece in board.piece_map().items():
        if white:
            if piece == chess.KING and piece.color == chess.WHITE:
                return square
        else:
            if piece == chess.KING and piece.color == chess.BLACK:
                return square
            
KNIGHT_DANGER = 0.8
BISHOP_DANGER = 0.7
ROOK_DANGER = 0.5
QUEEN_DANGER = 0.2
def get_king_tropism(board: chess.Board):
    weighted_danger = 0
    num_enemy_pieces = 0
    relevant_pieces_dict = {}
    if board.turn == chess.BLACK: # ie White just made a move
        # Locate white king and black pieces
        w_king_square = get_king_square(board, True)
        for square, piece in board.piece_map().items():
            if piece.color == chess.BLACK:
                match piece:
                    case chess.KNIGHT: # We put the value in a list because there could be multiple knights
                        relevant_pieces_dict[chess.KNIGHT] += [
                            KNIGHT_DANGER * chess.square_knight_distance(square, w_king_square)]
                    case chess.BISHOP:
                        relevant_pieces_dict[chess.BISHOP] += [
                            BISHOP_DANGER * chess.square_distance(square, w_king_square)]
                    case chess.ROOK:
                        relevant_pieces_dict[chess.ROOK] += [
                            ROOK_DANGER * chess.square_distance(square, w_king_square)]
                    case chess.QUEEN:
                        relevant_pieces_dict[chess.QUEEN] += [
                            QUEEN_DANGER * chess.square_distance(square, w_king_square)]

    else:
        b_king_square = get_king_square(board, False)
        for square, piece in board.piece_map().items():
            if piece.color == chess.WHITE:
                match piece:
                    case chess.KNIGHT: # We put the value in a list because there could be multiple pieces of this type
                        relevant_pieces_dict[chess.KNIGHT] += [
                            KNIGHT_DANGER * chess.square_knight_distance(square, b_king_square)]
                    case chess.BISHOP:
                        relevant_pieces_dict[chess.BISHOP] += [
                            BISHOP_DANGER * chess.square_distance(square, b_king_square)]
                    case chess.ROOK:
                        relevant_pieces_dict[chess.ROOK] += [
                            ROOK_DANGER * chess.square_distance(square, b_king_square)]
                    case chess.QUEEN:
                        relevant_pieces_dict[chess.QUEEN] += [
                            QUEEN_DANGER * chess.square_distance(square, b_king_square)]
    
    for _, dist_list in relevant_pieces_dict.items():
        for weighted_distance in dist_list:
            weighted_danger += weighted_distance
            num_enemy_pieces += 1
    
    return weighted_danger / num_enemy_pieces if num_enemy_pieces > 0 else 0

In [423]:
def make_move_list(board: chess.Board):
    moves = [move.strip() for move in re.findall(MOVE_LIST_REGEX, str(board.legal_moves))[0].split(',')]
    move_dict = {i: move for i, move in enumerate(moves)}
    return move_dict

class ChessEnvironment(gym.Env):
    def __init__(self):
        super().__init__()
        # Define the observation and action spaces
        # We represent the board with an 8x8 grid, then imagine one-hot encoding tensors extending
        # into the 3rd dimension of the chessboard, representing what piece is where (both side and color)
        # Since there are 6 pieces per side (pawn, knight, bishop, rook, queen, and king), there are 12 total
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(8, 8, 12), dtype=np.uint8)
        # Initialize the chessboards
        self.board = chess.Board()
        self.last_board = self.board
        # Action space 
        self.move_dict = make_move_list(self.board)
        self.action_space = gym.spaces.MultiDiscrete([8, 8, 73])

        self.piece_value_dict = {
            chess.PAWN: 1,
            chess.KNIGHT: 3,
            chess.BISHOP: 3.25, # Maybe change this
            chess.ROOK: 5,
            chess.QUEEN: 9,
        }
        # Hyperparameters (for reward function)
        self.flex_weight = 0.25
        self.material_weight = 0.5
        self.king_safety_weight = 0.25
        self.move_penalty_base = 0.01
        self.draw_base = 0.01
        self.move_list = []

        # Policy kwargs for A2C model
        self.policy_kwargs = dict(
            net_arch=dict(pi=[128, 128, 64], vf=[128, 128, 64]),
            activation_fn=ReLU,
            #use_sde=True, # Some errors getting 'two values provided' error, look into this later
        )
    
    def flexibility_reward(self): # Reward computer for making moves that limit opponents options
        legal_move_count = -len(list(self.board.legal_moves))
        # Normalizing
        legal_move_count = legal_move_count / MAX_MOVES_IN_ANY_POS

        return legal_move_count * self.flex_weight
    
    def get_net_material(self, board: chess.Board):
        difference = 0
        for _, piece in board.piece_map().items():
            value = self.piece_value_dict.get(piece.piece_type, 0) # default to 0 if no piece on square
            if piece.color == chess.WHITE:
                difference += value
            else:
                difference -= value
        return difference

    def material_reward(self):
        difference = self.get_net_material(self.board) 
        # bring extreme outliers in line with rest of data (since once you're that far ahead it doesn't really matter)
        # Also normalizing the value
        difference = np.clip(difference, -15.0, 15.0) / 15
        return difference * self.material_weight
    
    def king_safety_reward(self):
        # We're going to use something called "King Tropism", where we calculate the distance from enemy pieces to the king
        # Note: this technique I'm just fully yoinking from chessprogramming.org/King_Safety
        # Go check them out if you're trying something similar to this, they are very helpful
        val = get_king_tropism(self.board)
        if val > 0:
            val = 1 / val # get reciprocal since small values mean pieces are close ie danger is high, large values mean pieces are far away ie danger low
        return val * self.king_safety_weight
        
    
    def calculate_reward(self):
        # Technically board.turn is True when it is White's turn, and false when Black's, but when this fn is called, 
        # the move has already been made on the board, and thus the turn player has swapped. This will allow us to
        # evaluate the previous move from the perspective of that side
        evaluatingBlack = bool(self.board.turn)
        game_status = self.board.result()
        if (evaluatingBlack and game_status == '0-1') or (not evaluatingBlack and game_status == '1-0'): # Agent delivered checkmate on the previous move
            # We're gonna normalize the position evaluation heuristics so this is 8x the max value for that
            return 8.0 
        elif game_status == '1/2-1/2': #Draw
            # Check the ply before the draw for the material imbalance--we want to reward draws found while down material
            # while punishing draws while up material
            # Will be positive if white has more material, else negative
            mat_difference = self.get_net_material(self.last_board)
            #Drew with more material, negative reward scaled w how much material
            if (evaluatingBlack and mat_difference < 0) or (not evaluatingBlack and mat_difference > 0):
                return -1 * abs(mat_difference) * self.draw_base
            # Drew with less material, positive reward scaled with the material difference
            elif (evaluatingBlack and mat_difference > 0) or (not evaluatingBlack and mat_difference < 0):
                return abs(mat_difference) * self.draw_base
        # In case where no checkmate has been achieved, use heuristics
        flex_score = self.flexibility_reward()
        material_score = self.material_reward()
        if evaluatingBlack: # Invert value if eval'ing Black so we're not feeding negative into reward val
            material_score *= -1
        king_safety_score = self.king_safety_reward()
        move_penalty = self.board.ply() * self.move_penalty_base
        return (flex_score + material_score - king_safety_score) / 3 - move_penalty
        
            

    def sync_action_space(self):
        self.move_dict = make_move_list(self.board)
        self.action_space = gym.spaces.Discrete(len(self.move_dict))
    
    def reset(self, seed=None, options=None):
        # Reset the chessboard to the starting position
        self.board = chess.Board()
        self.sync_action_space()
        self.move_list = []
        # Return the initial observation
        return STARTING_POS_TENSOR, {}

    def step(self, action):
        move = self.move_dict[action]
        self.move_list.append(move)
        # Save the board as is before making the move for reward calculation
        self.last_board = self.board
        # Execute the specified action on the chessboard
        self.board.push_san(move)
        self.sync_action_space()

        # Convert the board to the observation format
        observation = np.zeros((8, 8, 12), dtype=np.uint8)
        for square, piece in self.board.piece_map().items():
            piece_type = piece.piece_type
            piece_color = int(piece.color)
            observation[square // 8][square % 8][piece_type - 1 + 6 * piece_color] = 1


        # Calculate the reward
        reward = self.calculate_reward()
        # Check if the episode is done
        terminated = (self.board.result() != '*')
        # Return the observation, reward, done flag, and additional info
        return observation, reward, terminated, False, {}

    def render(self):
        out_str = ""
        move_num = 1
        for ind, move in enumerate(self.move_list):
            if ind % 2 == 0:
                out_str += f"{move_num}.{move} "
                move_num += 1
            else:
                out_str += f"{move} "
        
        print(out_str)

In [424]:
env = ChessEnvironment()
model = A2C(ActorCriticPolicy, env, policy_kwargs=env.policy_kwargs, verbose=1)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [425]:
# Manually train the model
total_timesteps = int(1e5)
checkpoint_timestep = int(1e4)
log_interval = 100

for timestep in range(total_timesteps):
    observation, _ = env.reset()
    done = False
    feature_game = timestep % log_interval == 0
    save_checkpoint_model = timestep % checkpoint_timestep == 0 and timestep != 0

    while not done:
        action, _ = model.predict(observation)
        # JANKY SOLUTION: IF ISSUES LOOK HERE!!!
        # For some reason when one side is in check there's indexing issues so
        available_moves = len(env.move_dict)
        while action >= available_moves:
            # Make more predictions until one is in range of the indices
            action, _ = model.predict(observation)
        observation, reward, done, _, _ = env.step(int(action))
    if feature_game:
        env.render()
        print(f"Games completed: {timestep}/{total_timesteps}")
    
    if save_checkpoint_model:
        model.save("a2c_chess_model" + str(timestep) + "_games_exp")

# Save the trained model
model.save("a2c_chess_model_v1")

1.f3 h5 2.g3 g6 3.b3 Bh6 4.Nh3 f6 5.Na3 c6 6.Nb5 Bg5 7.Nc7+ Kf7 8.Nxa8 Qa5 9.Nb6 Nh6 10.Na4 Ke8 11.Nc3 Qc7 12.Bg2 Rf8 13.Na4 Qd8 14.Rg1 Kf7 15.f4 Ng8 16.Nb2 b6 17.Bf1 Bh6 18.e3 Bb7 19.Nf2 Ba8 20.Qg4 b5 21.Qh3 Qa5 22.Nbd3 Rc8 23.Rg2 Bf8 24.Be2 Qb4 25.Qg4 Ke8 26.Ne5 Bh6 27.Nxg6 Bf8 28.Qxd7+ Kf7 29.Qxc8 Nh6 30.Qxf8+ Kxg6 31.Qh8 Nd7 32.Qg8+ Nxg8 33.Bb2 Nb6 34.Bd4 Nd5 35.Bb6 Qa5 36.Bf1 Qa3 37.Kd1 Qxb3 38.Bd8 Nc7 39.Ke2 Qxe3+ 40.Kxe3 Bb7 41.Nd3 Na6 42.Bb6 Bc8 43.Nf2 Nc7 44.Kd4 Kf5 45.Kc3 b4+ 46.Kb2 Nb5 47.Bd3+ Ke6 48.Bc7 Kf7 49.Be2 Nc3 50.Bb6 Nd1+ 51.Kb3 Kg6 52.Nd3 Be6+ 53.c4 Bd7 54.Nc1 Nc3 55.Bc5 Nb5 56.Bxe7 Bg4 57.Bc5 Nc7 58.Bg1 Nb5 59.Bf1 Kg7 60.Kc2 Nc3 61.Kd3 Kh7 62.h3 Be2+ 63.Kc2 Bg4 64.Re2 Nb5 65.Re8 Kh8 66.Bc5 Bf5+ 67.Bd3 a6 68.Bd6 Be4 69.Rd8 a5 70.Rxg8+ Kxg8 71.Ne2 Nd4+ 72.Nxd4 Bg6 73.Bf5 c5 74.Kb2 Bxf5 75.Rg1 Kh7 76.Nc2 b3 77.Ne1 Bxh3 78.Bf8 Kg8 79.Bh6 bxa2 80.Rh1 Kh7 81.g4 a1=R 82.Bf8 Rc1 83.gxh5 Bf5 84.Rg1 Rxe1 85.Bd6 Re2 86.Ka3 Be6 87.Rg3 Bd7 88.Rf3 Bc6 89.Be7 Re4 90.Re3 Kg8 91.R

KeyboardInterrupt: 

In [436]:
model = A2C.load('a2c_chess_model_v1')
old_model = A2C.load('a2c_chess_model10000_games_exp.zip')

In [447]:
def make_move(move, board: chess.Board):
    board.push_san(move)
    return board

def get_obs_from_board(board: chess.Board):
    observation = np.zeros((8, 8, 12), dtype=np.uint8)
    for square, piece in board.piece_map().items():
        piece_type = piece.piece_type
        piece_color = int(piece.color)
        observation[square // 8][square % 8][piece_type - 1 + 6 * piece_color] = 1
    return observation

def play_games(model1, model2, num_games=10):
    wins_model1, wins_model2, draws = 0, 0, 0
    move_counts = []

    for game in range(num_games):
        obs = STARTING_POS_TENSOR
        env = ChessEnvironment()

        done = False
        while not done:
            if env.board.is_checkmate() or env.board.can_claim_draw():
                done = True
                draws += 1
                move_counts.append(env.board.ply() // 2)
                env.render()
                break

            # Model 1 (100k games model)
            env.sync_action_space()
            action_idx, _ = model1.predict(obs)
            available_moves = env.board.legal_moves.count()
            if available_moves == 0:
                # If no available moves, end the game
                env.render()
                done = True
                continue
            while action_idx >= available_moves:
                # Make more predictions until one is in range of the indices
                action_idx, _ = model1.predict(obs)

            white_move = env.move_dict[int(action_idx)]
            #print("Model 1 (100k games) move: ", black_move)
            make_move(white_move, env.board)
            obs = get_obs_from_board(env.board)

            if env.board.is_checkmate():
                done = True
                env.render()
                move_counts.append(env.board.ply() // 2)
                wins_model2 += 1
                break

            # Model 2 (10k games model)
            env.sync_action_space()
            action_idx, _ = model2.predict(obs)
            available_moves = env.board.legal_moves.count()
            while action_idx >= available_moves:
                # Make more predictions until one is in range of the indices
                action_idx, _ = model2.predict(obs)

            black_move = env.move_dict[int(action_idx)]
            #print("Model 2 (10k games) move: ", black_move)
            make_move(black_move, env.board)
            obs = get_obs_from_board(env.board)

            if env.board.is_checkmate():
                done = True
                env.render()
                move_counts.append(env.board.ply() // 2)
                wins_model1 += 1
                break

            env.sync_action_space()

        print(f"Game {game + 1}/{num_games} completed.")

    print("Results:")
    print(f"Model 1 (100k games) wins: {wins_model1}")
    print(f"Model 2 (10k games) wins: {wins_model2}")
    print(f"Draws: {draws}")
    return move_counts

# Load the models
model1 = A2C.load('a2c_chess_model_v1')
model2 = A2C.load('a2c_chess_model10000_games_exp.zip')

# Play the games
counts = play_games(model1, model2, num_games=10)
print(counts)
    


Game 1/10 completed.

Game 2/10 completed.

Game 3/10 completed.

Game 4/10 completed.

Game 5/10 completed.

Game 6/10 completed.

Game 7/10 completed.

Game 8/10 completed.

Game 9/10 completed.

Game 10/10 completed.
Results:
Model 1 (100k games) wins: 0
Model 2 (10k games) wins: 1
Draws: 8
[223, 191, 192, 234, 274, 205, 234, 48, 231]
