In [167]:
import chess
import re
import gymnasium as gym
import numpy as np
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import DummyVecEnv

In [168]:
MOVE_LIST_REGEX = r'\((.*?)\)'

board = chess.Board()
starting_pos = np.zeros((8, 8, 12), dtype=np.float32)
for square, piece in board.piece_map().items():
    piece_type = piece.piece_type
    piece_color = int(piece.color)
    starting_pos[square // 8][square % 8][piece_type - 1 + 6 * piece_color] = 1

STARTING_POS_TENSOR = starting_pos

In [170]:
def calculate_reward(board: chess.Board):
    result = board.result()
    if result == '1-0':
        return 1  # Agent wins
    elif result == '0-1':
        return -1  # Agent loses
    else:
        return 0
    
def make_move_list(board: chess.Board):
    moves = [move.strip() for move in re.findall(MOVE_LIST_REGEX, str(board.legal_moves))[0].split(',')]
    move_dict = {i: move for i, move in enumerate(moves)}
    return move_dict

class ChessEnvironment(gym.Env):
    def __init__(self):
        super().__init__()
        # Define the observation and action spaces
        # We represent the board with an 8x8 grid, then imagine one-hot encoding tensors extending
        # into the 3rd dimension of the chessboard, representing what piece is where (both side and color)
        # Since there are 6 pieces per side (pawn, knight, bishop, rook, queen, and king), there are 12 total
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(8, 8, 12), dtype=np.float32)
        # Initialize the chessboard
        self.board = chess.Board()
        # Action space 
        self.move_dict = make_move_list(self.board)
        self.action_space = gym.spaces.Discrete(len(self.move_dict))

    def sync_action_space(self):
        self.move_dict = make_move_list(self.board)
        self.action_space = gym.spaces.Discrete(len(self.move_dict))
    
    def reset(self, seed=None, options=None):
        # Reset the chessboard to the starting position
        self.board = chess.Board()
        self.sync_action_space()
        # Return the initial observation
        return STARTING_POS_TENSOR, {}

    def step(self, action):
        move = self.move_dict[action]
        # Execute the specified action on the chessboard
        self.board.push_san(move)
        self.render()
        self.sync_action_space()
        print(self.move_dict)
        print(self.action_space)

        # Convert the board to the observation format
        observation = np.zeros((8, 8, 12), dtype=np.float32)
        for square, piece in self.board.piece_map().items():
            piece_type = piece.piece_type
            piece_color = int(piece.color)
            observation[square // 8][square % 8][piece_type - 1 + 6 * piece_color] = 1


        # Calculate the reward
        reward = calculate_reward(self.board)
        # Check if the episode is done
        terminated = (self.board.result() != '*')
        # Return the observation, reward, done flag, and additional info
        return observation, reward, terminated, False, {}

    def render(self):
        print(self.board, "\n")

In [171]:
def make_chess_env():
    return ChessEnvironment()

env = DummyVecEnv([make_chess_env])

#env.action_space = env.envs[0].action_space
model = A2C("MlpPolicy", env, verbose=1)

Using cpu device


In [172]:
# Manually train the model
total_timesteps = int(1e5)
log_interval = 1000

for timestep in range(total_timesteps):
    observation = env.reset()
    done = False

    while not done:
        action, _ = model.predict(observation)
        print(f"The next action index is {action}")
        observation, reward, done, _ = env.step(action)

    if timestep % log_interval == 0:
        print(f"Timestep: {timestep}/{total_timesteps}")

# Save the trained model
model.save("a2c_chess_model")

The next action index is [12]
r n b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. . . . . . . P
. . . . . . . .
P P P P P P P .
R N B Q K B N R 

{0: 'Nh6', 1: 'Nf6', 2: 'Nc6', 3: 'Na6', 4: 'h6', 5: 'g6', 6: 'f6', 7: 'e6', 8: 'd6', 9: 'c6', 10: 'b6', 11: 'a6', 12: 'h5', 13: 'g5', 14: 'f5', 15: 'e5', 16: 'd5', 17: 'c5', 18: 'b5', 19: 'a5'}
Discrete(20)
The next action index is [2]
r . b q k b n r
p p p p p p p p
. . n . . . . .
. . . . . . . .
. . . . . . . P
. . . . . . . .
P P P P P P P .
R N B Q K B N R 

{0: 'Rh3', 1: 'Rh2', 2: 'Nh3', 3: 'Nf3', 4: 'Nc3', 5: 'Na3', 6: 'h5', 7: 'g3', 8: 'f3', 9: 'e3', 10: 'd3', 11: 'c3', 12: 'b3', 13: 'a3', 14: 'g4', 15: 'f4', 16: 'e4', 17: 'd4', 18: 'c4', 19: 'b4', 20: 'a4'}
Discrete(21)
The next action index is [1]
r . b q k b n r
p p p p p p p p
. . n . . . . .
. . . . . . . .
. . . . . . . P
. . . . . . . .
P P P P P P P R
R N B Q K B N . 

{0: 'Nh6', 1: 'Nf6', 2: 'Rb8', 3: 'Nb8', 4: 'Ne5', 5: 'Na5', 6: 'Nd4', 7: 'Nb4', 8: 'h6', 9: '

KeyError: 8