In [5]:

# print("CUDA :", torch.cuda.is_available())
# print("number of GPU :", torch.cuda.device_count())
# if torch.cuda.is_available():
#     print("GPU name :", torch.cuda.get_device_name(0))

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import chess

from tqdm.auto import trange

torch.manual_seed(0)

import random


In [44]:
# changement pour chess : action passe de int à tuple (initial_position, final_position)

class TicTacToe:

    def __init__(self):
        self.row_count = 3
        self.column_count = 3
        self.action_size = self.row_count * self.column_count

    # défini dans les champs de la classe chess 
    def get_initial_state(self):
        return np.zeros((self.row_count, self.column_count))
    
    # Fonction move_piece() dans chess, state = self, initial_position : pièce à bouger, final_position : pièce bougée (forme l'action),
    # player est directement dans le champ self.player
    def get_next_state(self, state, action, player):
        row = action // self.column_count
        column = action % self.column_count
        state[row, column] = player
        return state


    # équivalent : actions()
    def get_valid_moves(self, state):
        return (state.reshape(-1) == 0).astype(np.uint8)
    

    # à implémenter dans chess, à l'aide de in_check_possible_moves(), si le retour est vide, alors c'est un échec et mat
    # attention à vérifier l'échec avant avec check_status()

    # check_status()
    # False : action = action()
    # True : action = in_check_possible_moves

    def check_win(self, state, action):

        if action is None:
            return False

        row = action // self.row_count
        column = action % self.column_count
        player = state[row, column]

        # check row
        if np.all(state[row, :] == player):
            return True

        # check column
        if np.all(state[:, column] == player):
            return True
        
        # check diagonal
        if row == column and np.all(np.diag(state) == player):
            return True

        # check anti-diagonal
        if row + column == self.row_count - 1 and np.all(np.diag(np.fliplr(state)) == player):
            return True
        
        return False

    def get_value_and_terminated(self, state, action):
        if self.check_win(state, action):
            return 1, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False
    
    def get_opponent(self, player):
        return -player

    def get_opponent_value(self, value):
        return -value

    def change_perspective(self, state, player):
        return state * player

    def get_encoded_state(self, state):
        encoded_state = np.stack((state == 1, state == 0, state == -1)).astype(np.float32)

        # check for batch dimension and swap axis
        if len(state.shape) == 3:
            encoded_state = np.swapaxes(encoded_state, 0, 1)


        return encoded_state


In [45]:
class ChessRL:
    def __init__(self):
        self.board = chess.Board()
        self.row_count = 8
        self.column_count = 8
        self.all_moves = self._generate_all_moves()
        self.action_size = len(self.all_moves)  # Taille du vecteur one-hot
        self.move_to_index = {move: i for i, move in enumerate(self.all_moves)}
        self.index_to_move = {i: move for i, move in enumerate(self.all_moves)}


    def _generate_all_moves(self):
        """Génère tous les coups UCI possibles aux échecs, indépendamment de l'état du plateau."""
        all_moves = set()
        board = chess.Board.empty()  # Échiquier vide pour générer tous les coups possibles

        # Liste de toutes les cases
        squares = list(chess.SQUARES)

        # Déplacements possibles pour chaque type de pièce
        piece_moves = {
            chess.PAWN: [chess.Move(from_sq, to_sq) for from_sq in squares for to_sq in squares if abs(from_sq - to_sq) in [8, 16, 7, 9]],
            chess.KNIGHT: [chess.Move(from_sq, to_sq) for from_sq in squares for to_sq in squares if abs(from_sq // 8 - to_sq // 8) * abs(from_sq % 8 - to_sq % 8) == 2],
            chess.BISHOP: [chess.Move(from_sq, to_sq) for from_sq in squares for to_sq in squares if abs(from_sq // 8 - to_sq // 8) == abs(from_sq % 8 - to_sq % 8)],
            chess.ROOK: [chess.Move(from_sq, to_sq) for from_sq in squares for to_sq in squares if (from_sq // 8 == to_sq // 8 or from_sq % 8 == to_sq % 8)],
            chess.QUEEN: [chess.Move(from_sq, to_sq) for from_sq in squares for to_sq in squares if (from_sq // 8 == to_sq // 8 or from_sq % 8 == to_sq % 8 or abs(from_sq // 8 - to_sq // 8) == abs(from_sq % 8 - to_sq % 8))],
            chess.KING: [chess.Move(from_sq, to_sq) for from_sq in squares for to_sq in squares if max(abs(from_sq // 8 - to_sq // 8), abs(from_sq % 8 - to_sq % 8)) == 1]
        }

        # Ajoute tous les coups non spécifiques aux règles
        for move_list in piece_moves.values():
            for move in move_list:
                all_moves.add(move.uci())

        # Ajoute les promotions de pions (blancs et noirs)
        promotion_pieces = ['q', 'r', 'b', 'n']
        for file in range(8):
            for piece in promotion_pieces:
                all_moves.add(f"{chr(97 + file)}7{chr(97 + file)}8{piece}")  # Blancs
                all_moves.add(f"{chr(97 + file)}2{chr(97 + file)}1{piece}")  # Noirs

        # Ajoute les roques
        all_moves.update(["e1g1", "e1c1", "e8g8", "e8c8"])  # Petit et grand roque

        # Ajoute les prises en passant (théoriquement possibles)
        for file in range(8):
            all_moves.add(f"{chr(97 + file)}5{chr(97 + file + (-1 if file > 0 else 1))}6")  # Blancs en passant
            all_moves.add(f"{chr(97 + file)}4{chr(97 + file + (-1 if file > 0 else 1))}3")  # Noirs en passant

        return sorted(all_moves)  # Trie pour assurer un ordre fixe

    def get_initial_state(self):
        return self.board.fen()

    def get_next_state(self, state, action, player):
        board = chess.Board(state)
        move = self.decode_move(action)
        if move in board.legal_moves:
            board.push(move)
        return board.fen()

    def get_valid_moves(self, state):
        board = chess.Board(state)
        valid_moves = np.zeros(self.action_size, dtype=np.uint8)
        for i, move in enumerate(board.legal_moves):
            idx = np.where(self.encode_move(move) == 1)[0][0]
            valid_moves[idx] = 1
        return valid_moves

    def check_win(self, state, action):
        board = chess.Board(state)
        if board.is_checkmate():
            return True
        return False

    def get_value_and_terminated(self, state, action):
        board = chess.Board(state)
        if board.is_checkmate():
            return 1, True
        if board.is_stalemate() or board.is_insufficient_material() or board.is_seventyfive_moves():
            return 0, True
        return 0, False

    def get_opponent(self, player):
        return -player

    def get_opponent_value(self, value):
        return -value

    def change_perspective(self, state, player):
        return self.flip_board(state)

    def get_encoded_state(self, state):
        board = chess.Board(state)
        encoded_state = np.zeros((12, 8, 8), dtype=np.float32)
        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece:
                encoded_state[piece.piece_type - 1, square // 8, square % 8] = 1 if piece.color else -1
        return encoded_state

    def flip_board(self, state):
        board = chess.Board(state)
        board.apply_mirror()
        return board.fen()

    def encode_move(self, move):
        """Encode un coup UCI en one-hot (index unique)."""

        move_uci = move.uci()
        if move_uci not in self.move_to_index:
            raise ValueError(f"Mouvement UCI inconnu : {move_uci}")
        
        one_hot = np.zeros(self.action_size, dtype=np.uint8)
        one_hot[self.move_to_index[move_uci]] = 1
        return one_hot

    def decode_move(self, index):
        """Décodage d'un index en coup UCI."""
        if index < 0 or index >= self.action_size:
            raise ValueError(f"Index de mouvement hors limite : {index}")
        
        move_uci = self.index_to_move[index]
        return chess.Move.from_uci(move_uci)


In [46]:
board = ChessRL()

state = board.get_initial_state()

print(state)

valid_moves = board.get_valid_moves(state)
move = np.where(valid_moves == 1)[0][1]


next_state = board.get_next_state(state, move, 1)

change_perspective = board.change_perspective(next_state, -1)
print(change_perspective)

print(board.get_valid_moves(change_perspective).sum())


rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1
rnbqkbnr/1ppppppp/8/p7/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1
20


In [47]:
# Adapté à 22 channels, tout est encodé 

class ResNet(nn.Module):
    def __init__(self, game, num_resBlocks, num_hidden, device):
        super(ResNet, self).__init__()

        self.device = device
        self.to(device)

        self.startBlock = nn.Sequential(
            nn.Conv2d(12, num_hidden, kernel_size = 3, padding=1),
            nn.BatchNorm2d(num_hidden),
            nn.ReLU()
            )

        self.backBone = nn.ModuleList([ResBlock(num_hidden) for i in range(num_resBlocks)])

        self.policyHead = nn.Sequential(
            nn.Conv2d(num_hidden, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32*game.row_count * game.column_count, game.action_size)
            )
        
        self.valueHead = nn.Sequential(
            nn.Conv2d(num_hidden, 3, kernel_size = 3, padding=1),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3*game.row_count*game.column_count, 1),
            nn.Tanh()
            )
    
    def forward(self, x):
        x = self.startBlock(x)
        for block in self.backBone:
            x = block(x)
        policy = self.policyHead(x)
        value = self.valueHead(x)
        return policy, value


class ResBlock(nn.Module):
    def __init__(self, num_hidden):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_hidden, num_hidden, kernel_size = 3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_hidden)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(num_hidden, num_hidden, kernel_size = 3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_hidden)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        return self.relu(out + x)

In [72]:

class Node:
    def __init__(self, game, args, state, parent= None, action_taken=None, prior=0, visit_count=0):
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken
        self.prior = prior # probability of selecting this node (from the policy)

        self.children = []

        self.value_sum = 0
        self.visit_count = visit_count
    
    def is_expanded(self):
        return len(self.children) > 0

    def select(self):
        best_child = None
        best_ucb = -np.inf

        for child in self.children:
            ucb = self.calculate_ucb(child)
            if ucb > best_ucb:
                best_ucb = ucb
                best_child = child
        
        return best_child

    def calculate_ucb(self, child):
        if child.visit_count == 0:
            q_value = 0
        else:
            q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2 
        return q_value + self.args['C'] * np.sqrt((self.visit_count) / (child.visit_count + 1)) * child.prior

    def expand(self, policy):

        for action, prob in enumerate(policy):
            if prob > 0:

                board = chess.Board(self.state)

                # print(board)

                board_copy = board.copy()
                child_state = board_copy.fen()

                # action = np.zeros(self.game.action_size)
                # action[action] = 1

                # print(action)

                child_state = self.game.get_next_state(child_state, action, 1)
                child_state = self.game.change_perspective(child_state, player=-1)
                child = Node(self.game, self.args, child_state, self, action, prob)
                self.children.append(child)

        return child
    
    
    def backpropagation(self, value):
        node = self
        while node is not None:
            node.visit_count += 1
            node.value_sum += value

            value = self.game.get_opponent_value(value)
            node = node.parent

class MCTSParallel:

    def __init__(self, game, args,model):
        self.game = game
        self.args = args
        self.model = model.to(model.device)
    
    # Use for prediction, not training
    @torch.no_grad()
    def search(self, states, spGames):

        liste = []

        for state in states:
            liste.append(self.game.get_encoded_state(state))
        
        liste = np.array(liste)
        states_tensor = torch.tensor(
            liste,
            dtype=torch.float32
        ).to(self.model.device)


        # states_tensor = torch.tensor(
        #     self.game.get_encoded_state(states),
        #     dtype=torch.float32
        # ).to(self.model.device)
        
        policy, _ = self.model(states_tensor)
        policy = torch.softmax(policy, axis = 1).detach().cpu().numpy()

        # add some noise to the policy to encourage exploration (dirichlet noise)
        policy = (1 - self.args['epsilon']) * policy + self.args['epsilon'] * np.random.dirichlet([self.args['dirichlet_alpha']] * self.game.action_size, size = policy.shape[0])


        for i, spg in enumerate(spGames):
            spg_policy = policy[i]

            valid_moves = self.game.get_valid_moves(states[i])


            # print("spg_policy", spg_policy)
            # print("valid_moves", valid_moves.sum())

            # renormalize
            spg_policy = spg_policy / np.sum(spg_policy)


            spg.root = Node(self.game, self.args, states[i], visit_count=1)

            print("spg_policy", spg_policy.sum())
            spg.root.expand(spg_policy)


        for search in range(self.args['num_searches']):
            for i, spg in enumerate(spGames):
                spg.node = None
                node = spg.root

                while node.is_expanded():

                    node = node.select()
                
                value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
                value = self.game.get_opponent_value(value)

                if is_terminal:
                    node.backpropagation(value)
                else: 
                    spg.node = node

            expandable_spGames = [mappingIdx for mappingIdx in range(len(spGames)) if spGames[mappingIdx].node is not None]

            if len(expandable_spGames) > 0:
                states = np.stack([spGames[mappingIdx].node.state for mappingIdx in expandable_spGames])

                # Apply get_encoded_state to all states (mapping=)
                liste = []
                for state in states:
                    liste.append(self.game.get_encoded_state(state))
                liste = np.array(liste)

                states_tensor = torch.tensor(
                    liste,
                    dtype=torch.float32
                ).to(self.model.device)

                policy, value = self.model(
                    states_tensor
                )
                
                # policy, value = self.model(
                #     torch.tensor(self.game.get_encoded_state(states), device=self.model.device)
                # )

                policy = torch.softmax(policy, axis=1).cpu().numpy()
                value = value.cpu().numpy()
            
            for i, mappingIdx in enumerate(expandable_spGames):
                node = spGames[mappingIdx].node
                spg_policy, spg_value = policy[i], value[i]
                
                valid_moves = self.game.get_valid_moves(node.state)
                spg_policy *= valid_moves
                spg_policy /= np.sum(spg_policy)

                node.expand(spg_policy)
                node.backpropagation(spg_value)

In [70]:
class AlphaZeroParallel:
    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTSParallel(game, args, model)
    
    # Play multiple games at the same time
    def selfPlay(self):
        return_memory = []
        player = 1
        spGames = [SPG(self.game) for spg in range(self.args['num_processes'])]

        while len(spGames) > 0:
            states = np.stack([spg.state for spg in spGames])

            liste = []
            for state in states:
                liste.append(self.game.change_perspective(state, player))
            liste = np.array(liste)

            neutral_states = liste
            
            # neutral_states = self.game.change_perspective(states, player)

            self.mcts.search(neutral_states, spGames)


            for i in range(len(spGames))[::-1]:

                spg = spGames[i]

                action_probs = np.zeros(self.game.action_size)
                for child in spg.root.children:
                    action_probs[child.action_taken] = child.visit_count
                action_probs = action_probs / np.sum(action_probs)


                spg.memory.append([spg.root.state, action_probs, player])

                temperature_action_probs = action_probs ** (1 / self.args['temperature'])
                temperature_action_probs /= np.sum(temperature_action_probs)

                action = np.random.choice(self.game.action_size, p=temperature_action_probs)

                spg.state = self.game.get_next_state(spg.state, action, player)

                value, is_terminal = self.game.get_value_and_terminated(spg.state, action)

                if is_terminal:
                    for hist_neutral_state, hist_action_probs, hist_player in spg.memory:

                        # adapted, more general and work for 1 player game
                        hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
                        
                        return_memory.append([
                            self.game.get_encoded_state(hist_neutral_state),
                            hist_action_probs,
                            hist_outcome
                        ])
                    del spGames[i]
                
            player = self.game.get_opponent(player)
        
        return return_memory




    def train(self, memory):
        
        random.shuffle(memory)

        for batchidx in range(0, len(memory), self.args['batch_size']):
            batch = memory[batchidx:min(len(memory) + 1,batchidx+self.args['batch_size'])]

            # list of list, convert to tensor, these are the targets for policy and value
            states, action_probs, outcomes = zip(*batch)

            states, action_probs, outcomes = np.array(states), np.array(action_probs), np.array(outcomes).reshape(-1, 1)

            states = torch.tensor(states, dtype=torch.float32).to(self.model.device)
            action_probs = torch.tensor(action_probs, dtype=torch.float32).to(self.model.device)
            outcomes = torch.tensor(outcomes, dtype=torch.float32).to(self.model.device)
            
            self.optimizer.zero_grad()

            policy, value = self.model(states)

            value_loss = F.mse_loss(value, outcomes)
            policy_loss = F.cross_entropy(policy, action_probs)

            loss = value_loss + policy_loss

            loss.backward()
            self.optimizer.step()


    def learn(self):
        for iteration in range(self.args['num_iterations']):
            memory = []
            
            self.model.eval()

            for selfPlay_iteration in trange(self.args['num_selfPlay_iterations'] // self.args['num_processes']):
                memory += self.selfPlay()

            self.model.train()
            for epoch in trange(self.args['num_epochs']):
                self.train(memory)

            torch.save(self.model.state_dict(), f'save\model_{iteration}.pth')
            torch.save(self.optimizer.state_dict(), f'save\optimizer_{iteration}.pth')

# self play game
class SPG:
    def __init__(self, game):
        self.state = game.get_initial_state()
        self.memory = []
        self.root = None
        self.node = None

In [73]:
game = ChessRL()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ResNet(game, 4, 64, device)

# temperature ->  eploitation / exploration tradeoff, same role as gamma in Q-learning. High temperature -> more exploration (rd distribution), low temperature -> more exploitation (peak distribution)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
 
args = {
    'C': 2,
    'num_searches': 600,
    'num_iterations': 8,
    'num_selfPlay_iterations': 10,
    'num_processes': 10,
    'num_epochs': 4,
    'batch_size': 128,
    'temperature': 1.25,
    'epsilon': 0.25,
    'dirichlet_alpha': 0.3
}

alphaZero = AlphaZeroParallel(model, optimizer, game, args)
alphaZero.learn()



  0%|          | 0/1 [00:00<?, ?it/s]

spg_policy 1.0


InvalidMoveError: invalid uci (use 0000 for null moves): 'a2a2'