## **AlphaZero Algorithm**

#### Imports

In [None]:
import math
import random

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import trange

torch.manual_seed(0)  # set the PyTorch seed to 0


In [None]:
%run MCTS_Attax.ipynb
%run NeuralNetwork_Attax.ipynb 

#### AlphaZero Class

When evaluating an AlphaZero algorithm, it is important to look both at how well it learns and plays the game and at how efficient and effective its learning process is. The evaluation metric used here is the training loss: the policy loss and the value loss are monitored during training to understand how well the model learns from, and generalizes beyond, the self-play data. A TensorBoard session was set up to analyze and visualize these metrics.


The TensorBoard plots include:

**Policy Loss:**
   - This loss measures how well the policy head of the neural network predicts which action to take in each position. In the context of games, the policy is a probability distribution over the possible moves.
   - It is calculated with a cross-entropy loss between the network's predicted move probabilities and the search probabilities produced by MCTS during self-play, which serve as the target distribution.
   - In AlphaZero, the policy network guides the search by providing prior probabilities to the Monte Carlo Tree Search (MCTS) algorithm.

**Value Loss:**
   - The value loss measures how accurately the value head of the neural network estimates the expected outcome (win, loss, or draw) from a given board state.
   - It is calculated with the mean squared error (MSE) between the predicted value and the actual game outcome.
   - In AlphaZero, the value estimate lets MCTS evaluate board states without simulating all the way to the end of the game.

**Total Loss:**
   - The total loss is the sum of the policy loss and the value loss. It reflects how well the network both predicts the next move (policy) and estimates the game's outcome from the current position (value).
   - Minimizing the total loss is the goal of the training process: better policy and value predictions should translate into stronger play.
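Written out for a mini-batch of $B$ positions, the total loss minimized by `train()` can be sketched as below, where $z_i$ is the game outcome, $v_i$ the value head's output, $\boldsymbol{\pi}_i$ the MCTS search probabilities and $\mathbf{p}_i$ the softmax of the policy head's logits. This mirrors the `F.mse_loss` and `F.cross_entropy` calls with their default mean reduction; the original AlphaZero formulation also adds an L2 weight-regularization term $c\lVert\theta\rVert^2$, which this code does not include.

$$
\mathcal{L} = \underbrace{\frac{1}{B}\sum_{i=1}^{B}\left(z_i - v_i\right)^2}_{\text{value loss}} \;\underbrace{-\;\frac{1}{B}\sum_{i=1}^{B}\boldsymbol{\pi}_i^{\top}\log \mathbf{p}_i}_{\text{policy loss}}, \qquad \mathbf{p}_i = \operatorname{softmax}(\text{logits}_i)
$$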


- The model's predictions are obtained by passing the current state through the model (`self.model(state)`).
- The `policy_loss` is computed by comparing the model's policy output to the target policy probabilities using cross-entropy loss.
- The `value_loss` is computed by comparing the model's value output to the actual game outcome using mean squared error loss.
- The `loss` variable is the sum of `policy_loss` and `value_loss`, and represents the total loss that needs to be minimized.
- `self.optimizer.zero_grad()` clears old gradients, `loss.backward()` computes the gradient of the loss with respect to the model parameters, and `self.optimizer.step()` updates the model parameters to reduce the loss.
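The training step described in the bullets above can be sketched in isolation as follows. This is a minimal, self-contained illustration, not the notebook's model: `TinyPolicyValueNet` is a hypothetical two-headed network and the input and action sizes are arbitrary. It assumes PyTorch 1.10 or later, where `F.cross_entropy` accepts probability targets such as MCTS search probabilities.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyPolicyValueNet(nn.Module):
    """Hypothetical stand-in for the real network: one shared layer, two heads."""

    def __init__(self, n_inputs=9, n_actions=4):
        super().__init__()
        self.body = nn.Linear(n_inputs, 16)
        self.policy_head = nn.Linear(16, n_actions)  # raw logits
        self.value_head = nn.Linear(16, 1)

    def forward(self, x):
        h = F.relu(self.body(x))
        return self.policy_head(h), torch.tanh(self.value_head(h))


def train_step(model, optimizer, states, policy_targets, value_targets):
    out_policy, out_value = model(states)
    # Cross-entropy against soft targets (rows are probability distributions)
    policy_loss = F.cross_entropy(out_policy, policy_targets)
    # MSE against outcomes in {-1, 0, 1}, shaped (batch, 1) like out_value
    value_loss = F.mse_loss(out_value, value_targets)
    loss = policy_loss + value_loss

    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss.backward()        # compute d(loss)/d(parameters)
    optimizer.step()       # update the parameters
    return policy_loss.item(), value_loss.item(), loss.item()


torch.manual_seed(0)
model = TinyPolicyValueNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
states = torch.randn(8, 9)
policy_targets = torch.softmax(torch.randn(8, 4), dim=1)  # each row sums to 1
value_targets = torch.randint(-1, 2, (8, 1)).float()

p, v, total = train_step(model, optimizer, states, policy_targets, value_targets)
print(p, v, total)
```

Returning the three values with `.item()` is what lets them be appended to plain Python lists for plotting, as in `train()`.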

The losses are appended to lists (`policy_losses`, `value_losses`, `total_losses`) which can be used for visualization, to monitor the training process. 

Early stopping and validation in machine learning are techniques used to prevent overfitting, which occurs when a model learns the training data too well, including noise and details that do not generalize to unseen data. Here's how they work in our AlphaZero implementation:

### Early Stopping:

Early stopping monitors the model's performance on a validation set during training. A validation set is a portion of the data that the model has never seen during training and is used solely to assess the model's generalization performance.

Here's the flow of early stopping in our code:

It starts by setting the best validation loss to infinity and defining the patience: the number of epochs to wait for an improvement in the validation loss before stopping training.

In each epoch, after the model has been updated on the training data (the self-play games), the loss on the validation set is computed.

The `calculate_validation_loss` function does a forward pass of the validation data through the model and computes the loss without performing any backpropagation or updating the model's weights.

The current validation loss is then compared to the best validation loss seen so far. If it is lower, the model is performing better on the validation set, so the best validation loss is updated to this new value and the patience counter is reset.
If the validation loss does not improve, the counter is incremented. When the counter reaches the patience threshold, early stopping is triggered and the remaining epochs of the current iteration are skipped to avoid overfitting.
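The patience logic can be isolated as a small pure function that replays a sequence of validation losses. This is a sketch for illustration only: `early_stopping_epochs` is a hypothetical helper, and the loss values are made up.

```python
def early_stopping_epochs(validation_losses, patience):
    """Return how many epochs run before early stopping fires (or all of them)."""
    best = float('inf')
    counter = 0
    for epoch, loss in enumerate(validation_losses, start=1):
        if loss < best:
            best = loss   # new best: remember it and reset the counter
            counter = 0
        else:
            counter += 1  # no improvement this epoch
        if counter >= patience:
            return epoch  # stop after this epoch
    return len(validation_losses)


losses = [1.0, 0.8, 0.7, 0.72, 0.71, 0.75, 0.69, 0.70]
print(early_stopping_epochs(losses, patience=3))  # stops after epoch 6
print(early_stopping_epochs(losses, patience=4))  # the 0.69 at epoch 7 resets the counter
```

A larger patience tolerates noisier validation curves at the cost of more wasted epochs.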

### Validation:

Validation is performed by setting aside a part of the dataset that the model does not learn from during the training process. This set is used to evaluate the model's performance and to ensure that it can generalize well to new, unseen data.

In our code, the collected self-play memory is shuffled and split into a training set (90%) and a validation set (10%). The model learns from the training set but is evaluated on the validation set. Evaluating the model on data it has not learned from gives an estimate of how well it will perform on new positions and games.
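The shuffle, 90/10 split and batching of the validation set can be sketched on their own as follows. This is a hedged illustration: `split_and_batch` is a hypothetical helper, the string states are placeholders for encoded boards, and a seeded `random.Random` is used only to make the example repeatable.

```python
import random


def split_and_batch(memory, batch_size, train_fraction=0.9, seed=None):
    rng = random.Random(seed)
    memory = list(memory)  # copy so the caller's list is not shuffled in place
    rng.shuffle(memory)
    split_index = int(len(memory) * train_fraction)
    train_memory = memory[:split_index]
    validation = memory[split_index:]
    # Group the validation samples into batches of at most batch_size
    validation_batches = [validation[i:i + batch_size]
                          for i in range(0, len(validation), batch_size)]
    return train_memory, validation_batches


samples = [(f"state{i}", {"move": 1.0}, 1) for i in range(25)]
train_memory, validation_batches = split_and_batch(samples, batch_size=2, seed=0)
print(len(train_memory), [len(b) for b in validation_batches])  # 22 [2, 1]
```

Because the split happens after shuffling, positions from the same game can land on both sides, so the validation loss may be somewhat optimistic.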

Together, early stopping and validation help ensure that our AlphaZero model does not just memorize the training data but learns patterns that generalize to positions it has not seen before, which is crucial for creating a robust AI player.

In [None]:
class AlphaZero:

    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.game = game
        self.optimizer = optimizer
        self.args = args
        self.mcts = MCTS(game, args, 1, model)
        self.current_state = self.game.get_initial_state()
        # Step counter shared by all calls to train(), so TensorBoard curves
        # from successive epochs and iterations do not overlap
        self.global_step = 0

        # To start TensorBoard, run this command in a terminal: tensorboard --logdir=runs
        self.writer = SummaryWriter('runs/alphazero_experiment')


    def selfPlay(self):
        memory = []
        player = 1

        state = self.game.get_initial_state()

        while True:
            neutral_state = self.game.change_perspective(state, player)

            action_probs = self.mcts.search(neutral_state)

            if action_probs:  # a non-empty dict means the current player still has moves
                # Only positions with a search policy are stored as training samples
                memory.append((neutral_state, action_probs, player))

                actions = list(action_probs.keys())
                probs = list(action_probs.values())
                # Sample an index rather than the key itself, so np.random.choice
                # works whatever type the move keys are
                action = actions[np.random.choice(len(actions), p=probs)]
                state = self.game.get_next_state(state, action)
                value = 0
                is_terminal = False

            else:
                # No legal moves: the game is over; value is from `player`'s perspective
                value = 1
                is_terminal = True

            if is_terminal:

                returnMemory = []

                for hist_neutral_state, hist_action_probs, hist_player in memory:
                    hist_outcome = value if hist_player == player else -value

                    returnMemory.append((
                        # Encode each stored position, not the final one
                        self.game.get_encoded_state(hist_neutral_state),
                        hist_action_probs,
                        hist_outcome
                    ))

                return returnMemory

            player = -player

    def update_state(self, move_opponent):

        # Assume the opponent's move is a string in the format "x1y1_x2y2"

        # Split the string to get the start and end coordinates
        coords = move_opponent.split('_')
        xi, yi = int(coords[0][0]), int(coords[0][1])
        xf, yf = int(coords[1][0]), int(coords[1][1])

        # Create a movement object
        opponent_move = self.game.movement(xi, yi, xf, yf, -1, 0)  # -1 for the opponent, move type 0 initially

        # Check whether the move is valid (and update its move type)
        if self.game.is_valid_move(self.current_state, opponent_move):
            # Apply the move to the current game state
            self.current_state = self.game.get_next_state(self.current_state, opponent_move)
        else:
            # Handle invalid moves (e.g. ignore them or log an error)
            pass


    def train(self, memory):

        # Shuffle the training data
        random.shuffle(memory)

        # Lists storing the per-batch losses for visualization
        policy_losses = []
        value_losses = []
        total_losses = []

        # Loop over batches
        batch_size = self.args['batch_size']
        for batchIdx in range(0, len(memory), batch_size):

            # Slicing past the end is safe in Python, so the last (partial) batch is kept
            sample = memory[batchIdx:batchIdx + batch_size]
            state, policy_targets, value_targets = zip(*sample)  # transpose the sample

            state = torch.tensor(np.array(state), dtype=torch.float32, device=self.model.device)

            # Assumes every policy dict has the same length and key order as the policy head's output
            policy_values = [list(policy_dict.values()) for policy_dict in policy_targets]
            policy_targets = torch.tensor(policy_values, dtype=torch.float32, device=self.model.device)
            value_targets = torch.tensor(np.array(value_targets).reshape(-1, 1), dtype=torch.float32, device=self.model.device)

            out_policy, out_value = self.model(state)

            policy_loss = F.cross_entropy(out_policy, policy_targets)
            value_loss = F.mse_loss(out_value, value_targets)
            loss = policy_loss + value_loss

            # Log losses to TensorBoard
            self.writer.add_scalar('Loss/Policy', policy_loss.item(), self.global_step)
            self.writer.add_scalar('Loss/Value', value_loss.item(), self.global_step)
            self.writer.add_scalar('Loss/Total', loss.item(), self.global_step)

            # Store losses for plotting
            policy_losses.append(policy_loss.item())
            value_losses.append(value_loss.item())
            total_losses.append(loss.item())

            # Minimize the loss by backpropagating it
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Increment the global step count
            self.global_step += 1

        return policy_losses, value_losses, total_losses

    def visualize_losses(self, policy_losses, value_losses, total_losses):
        plt.figure(figsize=(10, 5))
        plt.plot(policy_losses, label='Policy Loss')
        plt.plot(value_losses, label='Value Loss')
        plt.plot(total_losses, label='Total Loss')
        plt.xlabel('Training Steps')
        plt.ylabel('Loss')
        plt.title('Training Loss Over Time')
        plt.legend()
        plt.show()

  
    
    def calculate_validation_loss(self, validation_memory):
        was_training = self.model.training
        self.model.eval()  # Set the model to evaluation mode

        total_loss = 0.0
        total_samples = 0

        # Sum the losses over each batch and divide by the number of samples at the end,
        # so the result is on the same per-sample scale as the training loss
        value_criterion = torch.nn.MSELoss(reduction='sum')
        policy_criterion = torch.nn.CrossEntropyLoss(reduction='sum')

        with torch.no_grad():  # No gradients are needed for validation
            # validation_memory is a list of batches
            for batch in validation_memory:

                states, policy_targets, value_targets = zip(*batch)

                # Convert to tensors
                states_tensor = torch.tensor(np.array(states), dtype=torch.float32, device=self.model.device)
                policy_values = [list(policy_dict.values()) for policy_dict in policy_targets]
                policy_targets_tensor = torch.tensor(policy_values, dtype=torch.float32, device=self.model.device)
                # Reshape to (batch, 1) to match the value head's output shape
                value_targets_tensor = torch.tensor(value_targets, dtype=torch.float32, device=self.model.device).reshape(-1, 1)

                # Forward pass
                out_policy, out_value = self.model(states_tensor)

                # Calculate loss
                value_loss = value_criterion(out_value, value_targets_tensor)
                policy_loss = policy_criterion(out_policy, policy_targets_tensor)
                loss = value_loss + policy_loss

                # Accumulate the loss
                total_loss += loss.item()
                total_samples += len(batch)

        self.model.train(was_training)  # Restore the mode the model was in before validation

        average_loss = total_loss / total_samples if total_samples > 0 else 0
        return average_loss



    def learn(self):

        all_policy_losses = []
        all_value_losses = []
        all_total_losses = []

        # Early stopping criteria
        patience = 10

        for iterations in range(self.args['num_iterations']):
            memory = []
            self.model.eval()

            for selfPlay_iteration in trange(self.args['num_selfPlay_iterations']):  # trange shows a progress bar
                memory += self.selfPlay()

            # Shuffle and split memory into training (90%) and validation (10%) sets
            random.shuffle(memory)
            split_index = int(len(memory) * 0.9)
            train_memory = memory[:split_index]
            validation_memory_unbatched = memory[split_index:]
            batch_size = self.args['batch_size']
            validation_memory = [validation_memory_unbatched[i:i + batch_size]
                                 for i in range(0, len(validation_memory_unbatched), batch_size)]

            # Each iteration has a fresh validation set, so early stopping starts over
            best_validation_loss = float('inf')
            patience_counter = 0

            for epoch in trange(self.args['num_epochs']):

                # Make sure the model is in training mode (validation uses eval mode)
                self.model.train()

                # Train on the memory from self-play
                policy_losses, value_losses, total_losses = self.train(train_memory)

                all_policy_losses.extend(policy_losses)
                all_value_losses.extend(value_losses)
                all_total_losses.extend(total_losses)

                # Calculate validation loss
                current_validation_loss = self.calculate_validation_loss(validation_memory)

                # One point per epoch, with a step that keeps increasing across iterations
                validation_step = iterations * self.args['num_epochs'] + epoch
                self.writer.add_scalar('Loss/Validation', current_validation_loss, validation_step)

                # Early stopping check
                if current_validation_loss < best_validation_loss:
                    best_validation_loss = current_validation_loss
                    patience_counter = 0
                else:
                    patience_counter += 1

                if patience_counter >= patience:
                    print("Early stopping triggered")
                    break

            # After the epochs of this iteration, plot the losses accumulated so far
            self.visualize_losses(all_policy_losses, all_value_losses, all_total_losses)

            # Store the weights of the model and the optimizer state
            torch.save(self.model.state_dict(), f"model{iterations}Attax{self.game.N}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer{iterations}Attax{self.game.N}.pt")

        self.writer.close()  # Close the TensorBoard writer
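
The outcome-labelling step at the end of `selfPlay()` can be isolated as a small pure function. This is a sketch: `assign_outcomes` is a hypothetical helper, not part of the notebook, and the string states stand in for encoded boards. Each stored position gets `value` if it was recorded for the player that `value` refers to, and `-value` otherwise, so both sides learn from the same game.

```python
def assign_outcomes(history, final_player, value):
    """Label each (state, policy, player) sample with the outcome from that player's view."""
    return [(state, policy, value if player == final_player else -value)
            for state, policy, player in history]


history = [("s0", {}, 1), ("s1", {}, -1), ("s2", {}, 1)]
labelled = assign_outcomes(history, final_player=-1, value=1)
print([outcome for _, _, outcome in labelled])  # [-1, 1, -1]
```

With `value = 0` (a draw) every sample is labelled 0, whichever player it belongs to.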


### Old versions

In [None]:
# class AlphaZero:
#     def __init__(self, model, optimizer, game, args):
#         self.model = model
#         self.game = game
#         self.optimizer = optimizer
#         self.args = args
#         self.mcts = MCTS(game, args, 1, model)

#     def selfPlay(self):
#         memory = []
#         player = 1
#         state = self.game.get_initial_state()

#         while True:
#             neutral_state = self.game.change_perspective(state, player)
            
#             action_probs = self.mcts.search(neutral_state)
            
#             memory.append((neutral_state, action_probs, player))
            
#             if not action_probs: # if it is not empty which means if the game is not over
#                 action = np.random.choice(list(action_probs.keys()), p=list(action_probs.values()))
#                 state = self.game.get_next_state(state, action)
#                 value = 0
#                 is_terminal = False

#             else:
#                 value = 1
#                 is_terminal = True

            
#             if is_terminal:
#                 returnMemory = []
#                 for hist_neutral_state, hist_action_probs, hist_player in memory:
#                     # print("ciclo for")
#                     hist_outcome = value if hist_player == player else -value
#                     returnMemory.append((
#                         self.game.get_encoded_state(neutral_state),
#                         hist_action_probs,
#                         hist_outcome
#                     ))
#                 return returnMemory
            
#             player = -player

#     def train(self, memory):
#         #shuffle our trainning data
#         random.shuffle(memory)

#         #loop over batches
#         for batchIdx in range(0, len(memory), self.args['batch_size']):
#             sample = memory[batchIdx:min(len(memory)-1, batchIdx + self.args['batch_size'])]
#             state, policy_targets, value_targets = zip(*sample) #transpose our sample arround

#             state, policy_targets, value_targets = np.array(state), np.array(policy_targets), np.array(value_targets).reshape(-1, 1)

#             state = torch.tensor(state, dtype=torch.float32, device = self.model.device)
#             # print(policy_targets)
#             # policy_targets = torch.tensor(policy_targets, dtype=torch.float32, device = self.model.device)
#             policy_values = [list(policy_dict.values()) for policy_dict in policy_targets]
#             policy_targets = torch.tensor(policy_values, dtype=torch.float32, device=self.model.device)
#             value_targets = torch.tensor(value_targets, dtype=torch.float32, device = self.model.device)

#             out_policy, out_value = self.model(state)

#             policy_loss = F.cross_entropy(out_policy, policy_targets)
#             value_loss = F.mse_loss(out_value, value_targets)
#             loss = policy_loss + value_loss

#             #minimize the loss by backpropagating it
#             self.optimizer.zero_grad()
#             loss.backward()
#             self.optimizer.step()

#     def learn(self):
#         for iterations in range(self.args['num_iterations']):
#             # print("iteration number: ", iterations)
#             memory = []

#             self.model.eval()

#             for selfPlay_iteration in trange(self.args['num_selfPlay_iterations']): #trange was used so that we can visualise the progress bars
#                 # print("selfplay iterations: ", selfPlay_iteration)
#                 memory += self.selfPlay()

#             self.model.train() 
#             for epoch in trange(self.args['num_epochs']):
#                 # print("epoch", epoch)
#                 self.train(memory)

#             #store the weights of the model
#             torch.save(self.model.state_dict(), f"model_{iterations}_Attax.pt")
#             torch.save(self.optimizer.state_dict(), f"optimizer_{iterations}_Attax.pt")

In [None]:
# class AlphaZero:
#     def __init__(self, model, optimizer, game, args):
#         self.model = model
#         self.game = game
#         self.optimizer = optimizer
#         self.args = args
#         self.mcts = MCTS(game, args, 1, model)

#     def selfPlay(self):
#         memory = []
#         player = 1
#         state = self.game.get_initial_state()

#         while True:
#             neutral_state = self.game.change_perspective(state, player)
#             action_probs = self.mcts.search(neutral_state)
#             print(action_probs)
#             memory.append((neutral_state, action_probs, player))

#             action = np.random.choice(list(action_probs.keys()), p=list(action_probs.values()))

#             state = self.game.get_next_state(state, action)

#             value,_, is_terminal = self.game.get_value_and_terminated(state, player)

#             if is_terminal:
#                 returnMemory = []
#                 for hist_neutral_state, hist_action_probs, hist_player in memory:
#                     hist_outcome = value if hist_player == player else -value
#                     returnMemory.append((
#                         self.game.get_encoded_state(neutral_state),
#                         hist_action_probs,
#                         hist_outcome
#                     ))
#                 return returnMemory
            
#             player = -player

#     def train(self, memory):
#         #shuffle our trainning data
#         random.shuffle(memory)

#         #loop over batches
#         for batchIdx in range(0, len(memory), self.args['batch_size']):
#             sample = memory[batchIdx:min(len(memory)-1, batchIdx + self.args['batch_size'])]
#             state, policy_targets, value_targets = zip(*sample) #transpose our sample arround

#             state, policy_targets, value_targets = np.array(state), np.array(policy_targets), np.array(value_targets).reshape(-1, 1)

#             state = torch.tensor(state, dtype=torch.float32, device = self.model.device)
#             policy_targets = torch.tensor(policy_targets, dtype=torch.float32, device = self.model.device)
#             value_targets = torch.tensor(value_targets, dtype=torch.float32, device = self.model.device)

#             out_policy, out_value = self.model(state)

#             policy_loss = F.cross_entropy(out_policy, policy_targets)
#             value_loss = F.mse_loss(out_value, value_targets)
#             loss = policy_loss + value_loss

#             #minimize the loss by backpropagating it
#             self.optimizer.zero_grad()
#             loss.backward()
#             self.optimizer.step()

#     def learn(self):
#         for iterations in range(self.args['num_iterations']):
#             memory = []

#             self.model.eval()

#             for selfPlay_iteration in trange(self.args['num_selfPlay_iterations']): #trange was used so that we can visualise the progress bars
#                 memory += self.selfPlay()

#             self.model.train() 
#             for epoch in trange(self.args['num_epochs']):
#                 self.train(memory)

#             #store the weights of the model
#             torch.save(self.model.state_dict(), f"model_{iterations}_{self.game}.pt")
#             torch.save(self.optimizer.state_dict(), f"optimizer_{iterations}_{self.game}.pt")

##### Old TensorBoard code

In [None]:
# To use TensorBoard with your AlphaZero implementation, you'll need to make use of the TensorBoard logging through `SummaryWriter` from the `torch.utils.tensorboard` module. Below is a step-by-step guide to integrate TensorBoard into your existing AlphaZero code.

# 1. **Install TensorBoard**:
#    Ensure you have TensorBoard installed in your environment. If not, you can install it via pip:

#     ```bash
#     pip install tensorboard
#     ```

# 2. **Import SummaryWriter**:
#    At the beginning of your script, import the necessary TensorBoard class:

#     ```python
#     from torch.utils.tensorboard import SummaryWriter
#     ```

# 3. **Initialize SummaryWriter**:
#    Create a `SummaryWriter` instance at the beginning of your training script. This object will be used to write logs into a directory that TensorBoard will later read from.

#     ```python
#     writer = SummaryWriter('runs/alphazero_experiment_1')
#     ```

# 4. **Log Data**:
#    Throughout your training loop and other functions, use the `writer` to log data, such as loss, accuracy, or custom metrics. Here's how you might incorporate it into your AlphaZero training loop:

#     ```python
#     def train(self, memory):
#         # ... [your existing code] ...
        
#         # Loop over batches
#         for batchIdx in range(0, len(memory), self.args['batch_size']):
#             # ... [your existing code] ...
            
#             # After optimizer step, log the losses
#             writer.add_scalar('Loss/policy', policy_loss.item(), global_step)
#             writer.add_scalar('Loss/value', value_loss.item(), global_step)
#             writer.add_scalar('Loss/total', loss.item(), global_step)
            
#             # Increment your global step counter
#             global_step += 1
#     ```

# 5. **Log Custom Metrics and Visualizations**:
#    Besides scalar values, you might want to log histograms of parameters, images of the game board, or distributions of move probabilities:

#     ```python
#     # Log parameters (histograms)
#     for name, param in self.model.named_parameters():
#         writer.add_histogram(name, param, global_step)

#     # Log example game states as images
#     # Convert your game state to an image (assuming you have a function for this)
#     img = game_state_to_image(state)  
#     writer.add_image('Game/Board', img, global_step)
#     ```

# 6. **Start TensorBoard**:
#    Once your script is running and logging data, start TensorBoard in a terminal pointing it to the directory where the logs are being written:

#     ```bash
#     tensorboard --logdir=runs
#     ```

# 7. **View Your Logs**:
#    Open your browser and go to `localhost:6006` (or the URL provided in the terminal when you start TensorBoard) to view the logs and visualizations.

# 8. **Close SummaryWriter**:
#    At the end of training, or when you're done logging data, make sure to close the SummaryWriter to flush any remaining outputs to disk:

#     ```python
#     writer.close()
#     ```

# By following these steps, you'll be able to integrate TensorBoard into your AlphaZero model for rich logging and visualization capabilities. This will help you monitor your training process, understand your model's behavior, and make informed decisions to improve its performance. Remember to customize the logging according to what's most relevant for your specific scenario and model.