In [None]:
import numpy as np
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

from tqdm.notebook import trange
import random

!pip3 install python-chess
import chess

In [116]:
# The wrapper here makes sense because we need special stuff for the neural network, like get_encoded_state


#TODO: implement change_perspective as flipping the board, but use the input plane to indicate the current player.
# TODO: define game.action_size 
# TODO: define number of feature planes
# TODO: implement all the stuff we saw in TicTacToe, will defer lots of it to python-chess, get_encoded_state will need to represent the board the way AlphaZero does.

class ChessEncoder:
    """Encodes python-chess boards into AlphaZero-style input planes.

    Planes are split by side: "P1" is the side to move in the most recent board and
    "P2" is its opponent. Square index ``s`` is written to ``plane[s // 8, s % 8]``
    with no flip for Black.
    TODO: orient the planes to the side to move (see the change_perspective TODO above).
    """

    def generate_input_planes(self, board_history):
        """
        Generates all 119 input planes for the AlphaZero chess neural network.

        :param board_history: A non-empty list of ``chess.Board`` states, with the most recent
            state last. The side to move in that last board is the current player (P1).
        :return: A numpy array of shape (119, 8, 8).

        Structure of 119 planes returned:
            - 8x (representing up to 7 historical states, 1 current state), ordered from past to present.
              Steps before the start of the game are zero-filled:
                - 6 P1 Piece Planes
                - 6 P2 Piece Planes
                - 2 Repetition Planes
            - 1 Current player color plane (+1 for White, -1 for Black)
            - 1 total move count plane
            - 2 P1 Castling Planes, for Kingside and Queenside Castling
            - 2 P2 Castling Planes, for Kingside and Queenside Castling
            - 1 No-progress count plane
        """
        all_planes = []

        max_history_length = 8 # 7 previous board states + 1 current state
        planes_per_step = 14 # 6 P1 piece + 6 P2 piece + 2 repetition planes

        current_board = board_history[-1]
        current_player_color = current_board.turn

        # Negative indices walk from the oldest requested step up to the current board.
        for i in range(-max_history_length, 0):
            if i < -len(board_history):
                # For board states before the start of the game, add empty planes
                all_planes.extend(np.zeros((8, 8)) for _ in range(planes_per_step))
            else:
                all_planes.extend(self.encode_state_at_time(board_history[i], current_player_color))

        # Constant planes that describe only the current position.
        all_planes.append(self.encode_player_color_plane(current_player_color))
        all_planes.append(self.encode_total_move_count_plane(current_board))
        all_planes.extend(self.encode_castling_planes(current_board))
        all_planes.append(self.encode_no_progress_count_plane(current_board))

        # Stack into one (119, 8, 8) array; integer-valued planes are upcast to float by np.stack.
        return np.stack(all_planes)

    def encode_state_at_time(self, board, current_player_color):
        """Encode the 14 planes for one history step.

        :param board: the ``chess.Board`` for this time step.
        :param current_player_color: side to move in the *most recent* board; its pieces
            always go in the first 6 planes, even for older history steps.
        :return: list of 14 (8, 8) arrays: 6 P1 piece planes, 6 P2 piece planes, 2 repetition planes.
        """
        planes = []

        opponent_player_color = not current_player_color

        # Add 6 planes for P1's pieces (P1 is the current player)
        for piece_type in chess.PIECE_TYPES:
            planes.append(self.encode_piece_plane(board, piece_type, current_player_color))

        # Add 6 planes for P2's pieces (P2 is the opponent)
        for piece_type in chess.PIECE_TYPES:
            planes.append(self.encode_piece_plane(board, piece_type, opponent_player_color))

        # Add 2 repetition planes
        planes.extend(self.encode_repetition_planes(board))

        return planes

    def encode_piece_plane(self, board, piece_type, color):
        """Binary (8, 8) plane with 1 on every square holding ``color``'s ``piece_type``."""
        plane = np.zeros((8,8))

        for board_idx in board.pieces(piece_type=piece_type, color=color):
            # Square index -> (rank, file); no flip for Black (see class docstring).
            row, col = divmod(board_idx, 8)
            plane[row, col] = 1

        return plane

    def encode_repetition_planes(self, board):
        """Two constant planes flagging repetition of ``board``'s position.

        Plane 0 is all ones if ``board.is_repetition(2)`` (the position has occurred at
        least twice). Plane 1 is all ones if ``board.is_repetition(3)``. Otherwise each
        plane is zeros.
        """
        repeated_twice = board.is_repetition(2)
        # A 3-fold repetition implies a 2-fold one, so only probe it when needed.
        repeated_thrice = repeated_twice and board.is_repetition(3)
        return [
            np.ones((8, 8)) if repeated_twice else np.zeros((8, 8)),
            np.ones((8, 8)) if repeated_thrice else np.zeros((8, 8)),
        ]

    # I like the symmetry of using 1 and -1 (from player) here, rather than 1 and 0
    def encode_player_color_plane(self, current_player_color):
        """Constant plane: +1 when White is to move, -1 when Black is to move.

        :raises ValueError: if ``current_player_color`` is neither ``chess.WHITE`` nor ``chess.BLACK``.
        """
        if current_player_color == chess.WHITE:
            player = 1
        elif current_player_color == chess.BLACK:
            player = -1
        else:
            # Raising a bare string is itself a TypeError in Python 3; raise a real exception.
            raise ValueError('Missing board turn')

        return np.full((8,8), player)

    def encode_total_move_count_plane(self, board):
        """Constant plane filled with ``board.fullmove_number`` (unnormalised)."""
        return np.full((8,8), board.fullmove_number)

    def encode_castling_planes(self, board):
        """Four constant 0/1 planes describing castling rights.

        Order: side to move kingside, side to move queenside, opponent kingside,
        opponent queenside.
        """
        planes = []

        current_player_color = board.turn
        opponent_player_color = not current_player_color

        planes.append(np.full((8,8), int(board.has_kingside_castling_rights(current_player_color))))
        planes.append(np.full((8,8), int(board.has_queenside_castling_rights(current_player_color))))

        planes.append(np.full((8,8), int(board.has_kingside_castling_rights(opponent_player_color))))
        planes.append(np.full((8,8), int(board.has_queenside_castling_rights(opponent_player_color))))

        return planes

    def encode_no_progress_count_plane(self, board):
        """Constant plane filled with ``board.halfmove_clock`` (python-chess's no-progress counter)."""
        return np.full((8,8), board.halfmove_clock)
    

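# Input planes (summary of ChessEncoder.generate_input_planes, 119 in total):
    # 8x history of:
        # 6 P1 piece planes, 6 P2 piece planes, 2 repetition planes
    # + 1 colour, 1 total move count, 4 castling and 1 no-progress plane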


In [42]:
# Inspect the python-chess constants the encoder relies on.
print(list(chess.PIECE_TYPES))
for constant_name in ("PAWN", "KNIGHT", "BISHOP", "ROOK", "QUEEN", "KING", "WHITE", "BLACK"):
    print(getattr(chess, constant_name))

print(chess.SQUARES)

[1, 2, 3, 4, 5, 6]
1
2
3
4
5
6
True
False
range(0, 64)


In [117]:
# Setting up basic self play so I get encodings working
import chess.svg
from IPython.display import display, HTML, clear_output
import numpy as np

# Initialize the chess board (standard starting position).
board = chess.Board()

# One shared encoder instance for the experiments in this cell.
encoder = ChessEncoder()

def display_board(board, use_svg=True):
    """Show ``board`` in the notebook.

    Renders a 400px SVG via IPython ``display`` by default. With ``use_svg=False`` it
    prints the board's text form instead and returns None.
    """
    if not use_svg:
        print(board)
        return None
    svg_markup = chess.svg.board(board=board, size=400)
    return display(HTML(svg_markup))

        
# def play_move_interactive():
#     display_board(board)
#     move = input("Enter your move: ")
#     try:
#         board.push_san(move)
#     except ValueError as e:
#         print(f"Invalid move: {e}")
#     clear_output(wait=True)
#     display_board(board)

# play_move_interactive()

# board = board.transform(chess.flip_vertical)
# board = board.transform(chess.flip_vertical)
# display_board(board)

# Confirmed, board.transform(chess.flip_vertical) works.


# Smoke-test the encoder on a fresh starting position.
board = chess.Board()

encoded_state = encoder.generate_input_planes([board])
# 8 history steps x 14 planes + 7 constant planes.
assert encoded_state.shape == (119, 8, 8), encoded_state.shape

# With a single board of history, the first 7 * 14 = 98 planes are zero padding,
# so plane 0 carries no information. Show the first plane of the current position
# (the side to move's pawns) instead.
current_step_offset = 7 * 14
assert not encoded_state[:current_step_offset].any()
print(encoded_state[current_step_offset])

In [None]:
#TODO: use this cell to get it doing 1v1 and possibly pure MCTS play... seems maybe better than jumping right into the NN part of it?

In [None]:
#TODO: this needs updates. board_size (rows * columns) no longer equals the number of possible moves, as it did in TicTacToe. Need to actually think about the state we're passing in... probably matching AlphaZero's input encoding.
    # TODO: define feature_planes, which the initial convolution depends on.
    # TODO: verify that we're passing in 8x8 sized feature planes, then we can just use game.row_count, game.column_count

class ResNet(nn.Module):
    """AlphaZero-style residual network with a policy head and a value head.

    :param game: object exposing ``row_count``, ``column_count`` and ``action_size``.
    :param num_resBlocks: number of ResBlock layers in the backbone.
    :param num_filters: channel width of every convolution in the backbone.
    :param device: torch device the module is moved to; also stored as ``self.device``.
    :param feature_planes: number of input planes per position. Defaults to 3 (the
        TicTacToe / Connect Four encoding); ChessEncoder emits 119.
    """
    def __init__(self, game, num_resBlocks, num_filters, device, feature_planes=3):
        super().__init__()

        self.device = device
        # Cells per plane; the heads flatten (channels * board_size) features.
        board_size = game.row_count * game.column_count
        self.startBlock = nn.Sequential(
            nn.Conv2d(feature_planes, num_filters, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_filters),
            nn.ReLU()
        )

        self.backBone = nn.ModuleList(
            [ResBlock(num_filters) for _ in range(num_resBlocks)]
        )

        # Note, I changed the output channels on policy head from 32->2 compared to the code.  Also removed padding.
        self.policyHead = nn.Sequential(
            nn.Conv2d(num_filters, 2, kernel_size=1),
            nn.BatchNorm2d(2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2 * board_size, game.action_size)
        )

        self.valueHead = nn.Sequential(
            nn.Conv2d(num_filters, 1, kernel_size=1),
            nn.BatchNorm2d(1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(board_size, 1),
            nn.Tanh()
        )

        self.to(device)

    def forward(self, x):
        """Return ``(policy_logits, value)`` for a batch of encoded positions.

        ``x`` must be (batch, feature_planes, row_count, column_count). The 3x3 convs use
        padding=1, so the spatial size reaching the heads' Linear layers is unchanged.
        Policy logits are (batch, action_size); value is (batch, 1) squashed by Tanh.
        """
        x = self.startBlock(x)
        for resBlock in self.backBone:
            x = resBlock(x)
        policy = self.policyHead(x)
        value = self.valueHead(x)
        return policy, value

class ResBlock(nn.Module):
    """Residual block: conv-BN-ReLU, conv-BN, identity skip connection, then ReLU.

    Channel count and spatial size are preserved (in == out channels, 3x3 kernels
    with padding=1), so the skip connection is a plain addition.
    """
    def __init__(self, num_filters):
        super().__init__()
        # Attribute names (and their registration order) define the state_dict keys.
        self.conv1 = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_filters)
        self.conv2 = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_filters)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + x)

In [None]:
# QSTN: Will we pass along states using FEN?
# QSTN: do we want to change the game perspective like we've been doing?  So that the player always thinks they're p1 (white?).
# TODO: we'll need to translate between the outputs of the neural network and specific chess actions

class Node:
    """A node of the Monte Carlo search tree.

    :param game: game wrapper (get_next_state, change_perspective, get_opponent_value are used).
    :param args: dict; ``args['C']`` weights the exploration term in ``get_ucb``.
    :param state: game state stored at this node (children store it after ``change_perspective``).
    :param parent: parent Node, or None for the root.
    :param action_taken: action index that led from ``parent`` to this node.
    :param prior: policy probability of ``action_taken`` assigned when the parent expanded.
    :param visit_count: initial visit count (the root is created with 1).
    """
    def __init__(self, game, args, state, parent=None, action_taken=None, prior=0, visit_count=0):
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken
        self.prior = prior

        self.children = []

        self.visit_count = visit_count
        self.value_sum = 0

    def is_fully_expanded(self):
        """True once ``expand`` has created at least one child (all children are added at once)."""
        return len(self.children) > 0

    def select(self):
        """Return the child with the highest UCB score (first one wins ties)."""
        best_child = None
        best_ucb = -np.inf

        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_child = child
                best_ucb = ucb

        return best_child

    def get_ucb(self, child):
        """PUCT score of ``child`` from this node's point of view."""
        if child.visit_count == 0:
            q_value = 0
        else:
            # We want minimal q_value for child, because it's our opponent.
            # Mean value (assumed in [-1, 1]) -> [0, 1], then flipped to this node's side.
            q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2
        return q_value + self.args['C'] * (math.sqrt(self.visit_count) / (child.visit_count + 1)) * child.prior

    def expand(self, policy):
        """Create one child for every action whose probability in ``policy`` is > 0.

        :param policy: sequence of per-action probabilities, indexed by action.
        :return: the last child created, or None if no action had positive probability.
        """
        child = None  # stays None when no action is playable (avoids UnboundLocalError)
        for action, prob in enumerate(policy):
            if prob <= 0:
                continue
            child_state = self.state.copy()
            child_state = self.game.get_next_state(child_state, action, 1) # The board is always from perspective of P1 moving
            # NOTE(review): called with one argument here but with (states, player) in
            # AlphaZeroParallel.selfPlay — TODO confirm the game's signature supports both.
            child_state = self.game.change_perspective(child_state)

            child = Node(self.game, self.args, child_state, self, action, prob)
            self.children.append(child)

        return child

    def backpropagate(self, value):
        """Add ``value`` to this node's statistics, then propagate the opponent's value to the parent."""
        self.value_sum += value
        self.visit_count += 1

        value = self.game.get_opponent_value(value)
        if self.parent is not None: # True outside root node
            self.parent.backpropagate(value)

class MCTSParallel:
    """Monte Carlo Tree Search run in lock-step over several self-play games.

    All non-terminal leaves of one simulation are evaluated in a single batched forward pass.

    :param game: game wrapper (get_encoded_state, get_valid_moves, get_value_and_terminated,
        get_opponent_value and action_size are used here).
    :param args: dict read for 'dirichlet_epsilon', 'dirichlet_alpha' and 'num_searches'
        (Node additionally reads 'C').
    :param model: network returning ``(policy_logits, value)`` and exposing ``.device``.
    """
    def __init__(self, game, args, model):
        self.game = game
        self.args = args
        self.model = model

    def _to_model_input(self, states):
        # Force float32 to match the network's parameters even if get_encoded_state
        # returns float64 (e.g. np.stack over np.zeros planes).
        return torch.tensor(self.game.get_encoded_state(states), dtype=torch.float32, device=self.model.device)

    @torch.no_grad() # Don't use MCTS for training neural network parameters
    def search(self, states, spGames):
        """Run ``args['num_searches']`` simulations for every game in ``spGames``.

        :param states: batch of states (one per game) as passed by the caller; selfPlay
            passes the perspective-changed "neutral" states.
        :param spGames: SelfPlayGame-like objects; their ``root`` and ``node`` are overwritten.
        """
        # Root priors: network policy mixed with Dirichlet noise for exploration.
        policy, _ = self.model(self._to_model_input(states))
        policy = torch.softmax(policy, dim=1).cpu().numpy()
        epsilon = self.args['dirichlet_epsilon']
        noise = np.random.dirichlet([self.args['dirichlet_alpha']] * self.game.action_size, size=policy.shape[0])
        policy = (1 - epsilon) * policy + epsilon * noise

        for i, spg in enumerate(spGames):
            spg_policy = policy[i]
            valid_moves = self.game.get_valid_moves(states[i])
            spg_policy *= valid_moves # only consider legal moves
            spg_policy /= np.sum(spg_policy)

            spg.root = Node(self.game, self.args, states[i], visit_count=1)
            spg.root.expand(spg_policy)

        for _ in range(self.args['num_searches']):
            # Selection: walk each tree to a leaf; terminal leaves are backed up immediately.
            for spg in spGames:
                spg.node = None
                node = spg.root

                while node.is_fully_expanded():
                    node = node.select()

                value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
                value = self.game.get_opponent_value(value) # The value is from the perspective of the opponent of the person who made the move.

                if is_terminal:
                    # Back Propagation
                    node.backpropagate(value)
                else:
                    spg.node = node

            expandable_spGames = [mappingIdx for mappingIdx, spg in enumerate(spGames) if spg.node is not None]
            if not expandable_spGames:
                continue

            # Evaluate every pending leaf in one batch.
            leaf_states = np.stack([spGames[mappingIdx].node.state for mappingIdx in expandable_spGames])
            policy, value = self.model(self._to_model_input(leaf_states))
            policy = torch.softmax(policy, dim=1).cpu().numpy()
            # Bring values to host memory so Node.value_sum accumulates plain floats, not device tensors.
            value = value.cpu().numpy()

            for i, mappingIdx in enumerate(expandable_spGames):
                node = spGames[mappingIdx].node
                spg_policy, spg_value = policy[i], value[i].item()

                valid_moves = self.game.get_valid_moves(node.state)
                spg_policy *= valid_moves # only consider legal moves
                spg_policy /= np.sum(spg_policy)

                # Expansion, then backpropagation of the network's value estimate.
                node.expand(spg_policy)
                node.backpropagate(spg_value)

In [None]:
class AlphaZeroParallel:
    """Self-play and training loop that plays ``args['num_parallel_games']`` games at once.

    :param model: network returning ``(policy_logits, value)`` and exposing ``.device``.
    :param optimizer: torch optimizer over ``model``'s parameters.
    :param game: game wrapper shared with MCTSParallel.
    :param args: hyper-parameter dict; keys read here: 'num_parallel_games', 'temperature',
        'batch_size', 'num_iterations', 'num_selfPlay_iterations', 'num_epochs'.
    """
    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTSParallel(game, args, model)

    def selfPlay(self):
        """Play a batch of games to completion and return the training samples.

        :return: list of ``(encoded_state, mcts_action_probs, outcome)`` tuples. ``outcome``
            is the terminal value re-expressed for the player recorded with that state
            (assuming get_value_and_terminated scores the game for the player who just moved).
        """
        return_memory = []
        player = 1
        spGames = [SelfPlayGame(self.game) for _ in range(self.args['num_parallel_games'])]

        while len(spGames) > 0:
            states = np.stack([spg.state for spg in spGames])

            neutral_states = self.game.change_perspective(states, player)
            self.mcts.search(neutral_states, spGames)

            # Walk backwards so deleting a finished game does not shift indices still to visit.
            for i in range(len(spGames))[::-1]:
                spg = spGames[i]

                # Normalised MCTS visit counts are the policy training target.
                action_probs = np.zeros(self.game.action_size)
                for child in spg.root.children:
                    action_probs[child.action_taken] = child.visit_count
                action_probs /= np.sum(action_probs)

                spg.memory.append((spg.root.state, action_probs, player))

                # Introduce temperature into probs
                temperature_action_probs = action_probs ** (1 / self.args['temperature'])
                temperature_action_probs /= np.sum(temperature_action_probs)
                action = np.random.choice(self.game.action_size, p=temperature_action_probs)

                spg.state = self.game.get_next_state(spg.state, action, player)

                value, is_terminal = self.game.get_value_and_terminated(spg.state, action)

                if is_terminal:
                    for hist_neutral_state, hist_action_probs, hist_player in spg.memory:
                        hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
                        return_memory.append((
                            self.game.get_encoded_state(hist_neutral_state),
                            hist_action_probs,
                            hist_outcome
                        ))
                    del spGames[i]

            player = self.game.get_opponent(player)

        # Return only after the while-loop: every parallel game must have finished.
        return return_memory

    def train(self, memory):
        """Run one epoch of minibatch training over ``memory``.

        :param memory: list of ``(encoded_state, policy_target, value_target)`` tuples as
            produced by selfPlay. It is shuffled in place.
        """
        # Want to shuffle training data to avoid getting same batches all the time
        random.shuffle(memory)
        batch_size = self.args['batch_size']
        for batchIdx in range(0, len(memory), batch_size):
            # Slicing already clamps at len(memory); the final partial batch keeps its last sample.
            sample = memory[batchIdx:batchIdx + batch_size]
            state, policy_targets, value_targets = zip(*sample)

            state, policy_targets, value_targets = np.array(state), np.array(policy_targets), np.array(value_targets).reshape(-1, 1)

            state = torch.tensor(state, dtype=torch.float32, device=self.model.device)
            policy_targets = torch.tensor(policy_targets, dtype=torch.float32, device=self.model.device)
            value_targets = torch.tensor(value_targets, dtype=torch.float32, device=self.model.device)

            out_policy, out_value = self.model(state)

            # Cross-entropy against the MCTS visit distribution (probability targets).
            policy_loss = F.cross_entropy(out_policy, policy_targets)
            value_loss = F.mse_loss(out_value, value_targets)

            loss = policy_loss + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    # Main method, runs self-play, and uses that data for training.
    def learn(self):
        """Alternate self-play and training for ``args['num_iterations']`` iterations.

        After each iteration, writes model and optimizer state_dicts to
        ``model_{iteration}_{game}.pt`` / ``optimizer_{iteration}_{game}.pt``.
        """
        for iteration in range(self.args['num_iterations']):
            memory = []

            self.model.eval()
            for selfPlay_iteration in trange(self.args['num_selfPlay_iterations'] // self.args['num_parallel_games']):
                memory += self.selfPlay()

            self.model.train()
            for epoch in range(self.args['num_epochs']):
                self.train(memory)

            torch.save(self.model.state_dict(), f"model_{iteration}_{self.game}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}_{self.game}.pt")

class SelfPlayGame:
    """Per-game bookkeeping shared by AlphaZeroParallel.selfPlay and MCTSParallel.search."""
    def __init__(self, game):
        # Current game state, advanced after every move in selfPlay.
        self.state = game.get_initial_state()
        # (root_state, action_probs, player) tuples appended once per move.
        self.memory = []
        # Search-tree root and the leaf awaiting network evaluation; both set by MCTSParallel.search.
        self.root, self.node = None, None

In [None]:
# Training entry point (for when the time comes): builds network, optimizer and
# hyper-parameters, then runs AlphaZero self-play + training.
# NOTE(review): no `Chess` class is defined in the notebook cells shown here — TODO define
# the game wrapper before running, and make sure ResNet's input plane count matches what
# game.get_encoded_state produces (ChessEncoder emits 119 planes).
game = Chess()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ResNet(game, 9, 128, device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
args = dict(
    C=2,                          # exploration weight in Node.get_ucb
    num_searches=600,             # MCTS simulations per move
    num_iterations=8,             # self-play / train cycles in learn()
    num_selfPlay_iterations=500,  # games per iteration (split across parallel batches)
    num_parallel_games=100,       # games played in lock-step per selfPlay() call
    num_epochs=4,                 # passes over the collected memory per iteration
    batch_size=128,
    temperature=1.25,             # action_probs ** (1 / temperature) when sampling moves
    dirichlet_epsilon=0.25,       # weight of Dirichlet noise mixed into root priors
    dirichlet_alpha=0.3,          # Dirichlet concentration parameter
)

alphaZero = AlphaZeroParallel(model, optimizer, game, args)
alphaZero.learn()