## **Neural Network**

#### Imports

In [2]:
import numpy as np

import math

import torch

import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0) # set the seed as 0

from tqdm.notebook import trange

import random

## ResNet Class

### Description

The `ResNet` class is a powerful neural network architecture designed for game playing. Inspired by the residual network (ResNet) design principles, this class leverages the idea of residual learning to enable effective training of deep networks. Residual blocks within the architecture facilitate the flow of information across layers, mitigating the vanishing gradient problem and fostering the learning of complex features.

### Key Components

#### Initial Convolutional Block
The model starts with a convolutional block that processes the input game state. This block extracts features from the encoded board planes using a 3×3 convolution, batch normalization for stable training, and a rectified linear unit (ReLU) activation.

#### Residual Blocks
The backbone of the network consists of a stack of residual blocks. Each block contains two convolutional layers with batch normalization and ReLU activation, allowing the model to learn intricate patterns and features from the game environment. The residual (skip) connections let gradients flow directly through the stack during training, which makes deeper networks easier to train.

#### Policy Head
The policy head predicts a probability distribution over the actions the agent can take. It refines the features extracted by the backbone with a 3×3 convolution (32 channels), batch normalization, and a ReLU activation. It then flattens the result and applies a fully connected layer. The forward pass converts the resulting logits into probabilities with a softmax.

#### Value Head
The value head predicts the state value, an estimate of the expected outcome of the current game state. Like the policy head, it applies a 3×3 convolution (3 channels), batch normalization, and a ReLU activation. These are followed by a fully connected layer with a single output and a hyperbolic tangent (tanh) activation, which bounds the value to [-1, 1].

In [3]:
class ResNet(nn.Module):
    """Residual convolutional network mapping a board state to (policy, value).

    Args:
        game: Object exposing ``GRID_SIZE`` (board side length) and
            ``action_size``; both are used to size the fully connected layers.
        num_resBlocks: Number of ``ResBlock`` layers in the backbone.
        num_hidden: Number of channels used throughout the backbone.
        device: Device the model is moved to at construction time.
        num_input_planes: Number of input feature planes (channels) in the
            encoded state. Defaults to 3, the value previously hard-coded.
    """

    def __init__(self, game, num_resBlocks, num_hidden, device, num_input_planes=3):
        super().__init__()
        self.device = device

        # Initial convolutional block: lifts the input planes to num_hidden channels.
        self.startBlock = nn.Sequential(
            # NOTE(review): the original author's note says the default of 3 planes
            # corresponds to the -1 / 0 / 1 board encoding — confirm against the
            # game's state encoder.
            nn.Conv2d(num_input_planes, num_hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_hidden),  # Batch normalization for stabilizing training
            nn.ReLU(),
        )

        # Residual backbone; kernel 3 / padding 1 keeps the spatial size unchanged.
        self.backBone = nn.ModuleList(
            [ResBlock(num_hidden) for _ in range(num_resBlocks)]
        )

        # Policy head. It outputs action_size + 1 logits; the extra output is
        # presumably a "pass" move — TODO confirm against the game definition.
        self.policyHead = nn.Sequential(
            nn.Conv2d(num_hidden, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * game.GRID_SIZE * game.GRID_SIZE, game.action_size + 1),
        )

        # Value head: a single scalar squashed into [-1, 1] by tanh.
        self.valueHead = nn.Sequential(
            nn.Conv2d(num_hidden, 3, kernel_size=3, padding=1),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3 * game.GRID_SIZE * game.GRID_SIZE, 1),
            nn.Tanh(),
        )

        # Send the model to the device (CPU or GPU)
        self.to(device)

    def forward(self, x):
        """Run a forward pass and return ``(policy, value)``.

        Args:
            x: Encoded board batch. The flatten/linear sizes require shape
               (batch, num_input_planes, GRID_SIZE, GRID_SIZE).

        Returns:
            policy: (batch, action_size + 1) softmax probabilities.
            value: (batch, 1) tanh-bounded scores in [-1, 1].

        Raises:
            ValueError: If either output contains a NaN.
        """
        x = self.startBlock(x)
        for resBlock in self.backBone:
            x = resBlock(x)
        policy_logits = self.policyHead(x)
        value = self.valueHead(x)

        # Normalize policy logits to probabilities over the action dimension.
        policy = F.softmax(policy_logits, dim=1)

        # Fail loudly instead of propagating NaNs into search/training.
        if torch.isnan(policy).any() or torch.isnan(value).any():
            raise ValueError("NaN value detected in network output")

        return policy, value


## ResBlock Class

### Description

The `ResBlock` class represents a residual block, a fundamental building block in deep neural network architectures, specifically employed in the `ResNet` model for game playing. Residual blocks enable the training of very deep networks by mitigating the vanishing gradient problem and facilitating the flow of information through the network.


In [4]:
# Residual block class 
class ResBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers wrapped by an identity skip connection.

    Kernel size 3 with padding 1 keeps both the channel count and the spatial
    size unchanged, so the input can be added directly to the block's output.

    Args:
        num_hidden: Number of input (and output) channels.
    """

    def __init__(self, num_hidden):
        super().__init__()
        same_size = dict(kernel_size=3, padding=1)
        # Attribute names (and their order) define the state_dict keys; keep them stable.
        self.conv1 = nn.Conv2d(num_hidden, num_hidden, **same_size)
        self.bn1 = nn.BatchNorm2d(num_hidden)
        self.conv2 = nn.Conv2d(num_hidden, num_hidden, **same_size)
        self.bn2 = nn.BatchNorm2d(num_hidden)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        skip = x
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + skip)

## GoRNN Class Documentation

### Description

The `GoRNN` class is a recurrent neural network (LSTM) for processing sequences of game states, intended for the game of Go. A (possibly multi-layer) LSTM captures temporal dependencies in the input sequence. Two fully connected layers then read the hidden state of the last time step to produce the policy and value predictions. Unlike `ResNet`, it returns raw outputs: no softmax is applied to the policy and no tanh to the value.

In [5]:
# RNN model for Go 
class GoRNN(nn.Module):
    """LSTM model mapping a sequence of game states to policy and value outputs.

    Args:
        input_size: Number of features per time step (after flattening any
            trailing spatial dimensions).
        hidden_size: Hidden size of each LSTM layer.
        num_layers: Number of stacked LSTM layers.
        policy_size: Number of outputs of the policy head.
        value_size: Number of outputs of the value head.
    """

    def __init__(self, input_size, hidden_size, num_layers, policy_size, value_size):
        super().__init__()
        self.hidden_size = hidden_size  # Hidden size of the LSTM
        self.num_layers = num_layers  # Number of layers of the LSTM

        # Batch-first LSTM: inputs are (batch, seq_len, input_size).
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

        # Linear heads applied to the hidden state of the last time step.
        self.policy_fc = nn.Linear(hidden_size, policy_size)
        self.value_fc = nn.Linear(hidden_size, value_size)

    def forward(self, x):
        """Run the LSTM over ``x`` and decode the last time step.

        Args:
            x: Batch-first input of shape (batch, seq_len, input_size), or
               (batch, seq_len, d1, d2, ...) whose trailing dimensions are
               flattened; their product must equal ``input_size``.

        Returns:
            policy: (batch, policy_size) raw outputs (no softmax applied).
            value: (batch, value_size) raw outputs (no activation applied).
        """
        if x.dim() > 3:
            # flatten() copies when necessary, so unlike .view() it also
            # accepts non-contiguous inputs (e.g. after transpose/permute).
            x = x.flatten(start_dim=2)

        # Zero initial states created directly on x's device with x's dtype,
        # so float64/half inputs do not hit an LSTM dtype mismatch.
        batch_size = x.size(0)
        h0 = x.new_zeros(self.num_layers, batch_size, self.hidden_size)
        c0 = x.new_zeros(self.num_layers, batch_size, self.hidden_size)

        out, _ = self.lstm(x, (h0, c0))  # out: (batch, seq_len, hidden_size)

        last_step = out[:, -1, :]
        policy = self.policy_fc(last_step)
        value = self.value_fc(last_step)
        return policy, value