In [None]:
!pip install open_spiel

: 

In [None]:
'''
Implement the AlphaZero algorithm with PyTorch for the Chess game. 
'''
from random import sample
import torch
from torch import nn


'''
Create a convolutional block made of :
    -A convolution of 256 filters of kernel size 3x3 with stride 1 
    - batch normalization
    -A rectifier non linearity 
'''
import torch.nn.functional as F
import random 
import pyspiel
import numpy as np
import os 

: 

In [None]:

class Node:
    
    def __init__(self,state,obs,player_root_node,parent,prior_probability,player_turn,alphazero_network):
        '''
        Node of the MCTS tree. It contains:
            -state: the pyspiel state of the node
            -obs: the observation tensor of `state` (as returned by
                  `state.observation_tensor()`), or None for a terminal state
            -player_root_node: the player stored as root player when the node was created
            -parent: the parent Node, or [] for a root node
            -prior_probability: prior probability given by the network to the action leading here
            -player_turn: the player to move in `state`
            -neural_network: network whose `forward(obs)` returns (policy, value)
            -children: dict mapping the index of a legal action (its position in
                       `state.legal_actions()`) to the corresponding child Node
            -is_leaf: True until `expand` has been called
            -visit_count: number of times the node has been visited
            -tot_act_value: sum of the values backed up through the node
            -mean_value: tot_act_value / visit_count. MCTS backs up values from
                         the point of view of the search root's player.
        '''
        
        self.state = state
        self.obs = obs
        
        self.children = {}
        self.player_root_node = player_root_node
        self.parent = parent  # [] if the node is the root node
        
        self.player_turn = player_turn
        self.neural_network = alphazero_network
        
        self.prior_probability = prior_probability
        self.visit_count = 0.
        self.tot_act_value = 0.
        self.mean_value = 0.
        
        self.is_leaf = True
        
    def expand(self):
        '''
        Create one child per legal action of the node's state.

        The prior of each child is read from the policy head at the action id
        (not at the position of the action in the legal-action list), then the
        priors are renormalised over the legal actions (illegal moves are masked
        out, as in AlphaZero). If the policy gives no usable mass to the legal
        actions, a uniform prior is used.

        Raises:
            Exception: if the node's state is terminal.
        '''
        if self.state.is_terminal():
            raise Exception('Cannot expand a terminal node!!!')
        
        # Inference only: no autograd graph is needed to read the priors.
        with torch.no_grad():
            policy = self.neural_network.forward(self.obs)[0]
        policy = policy.detach().cpu().numpy().flatten()
        
        legal_actions = self.state.legal_actions()
        priors = policy[legal_actions].astype(np.float64)
        total = priors.sum()
        if total > 0 and np.isfinite(total):
            priors = priors / total
        elif len(legal_actions) > 0:
            priors = np.full(len(legal_actions), 1.0 / len(legal_actions))
        
        for index, action in enumerate(legal_actions):
            child_state = self.state.clone()
            child_state.apply_action(action)
            # Terminal states are never fed to the network, so they get no observation.
            child_obs = None if child_state.is_terminal() else child_state.observation_tensor()
            self.children[index] = Node(child_state, child_obs, self.player_root_node, self,
                                        priors[index], child_state.current_player(),
                                        self.neural_network)
            
        self.is_leaf = False
        
        return
    
    def ucb_score(self,id_child,cput,sign=1.0):
        '''
        Compute the PUCT score of the child `id_child`:
            sign * Q(child) + cput * P(child) * sqrt(N(self)) / (1 + N(child))
        `sign` (default 1.0) flips the stored mean value, which is expressed from
        the search root's point of view, when the player to move is the opponent.
        '''
        child_node = self.children[id_child]
        exploitation = sign * child_node.mean_value
        exploration = cput * child_node.prior_probability * np.sqrt(self.visit_count) / (1 + child_node.visit_count)
        return exploitation + exploration
    
    def _search_root(self):
        '''Return the root of the tree containing this node (the ancestor whose parent is not a Node).'''
        node = self
        while isinstance(node.parent, Node):
            node = node.parent
        return node
        
    def select_child(self,cput):
        '''
        Select the child with the highest PUCT score for the player to move.

        Mean values are stored from the search root's point of view, so at a
        node where the opponent is to move they are negated: the opponent picks
        the move that is worst for the root player.

        Raises:
            Exception: if the node has no children.
        '''
        if not self.children:
            raise Exception("The node has no children")
        
        root_player = self._search_root().player_turn
        sign = 1.0 if self.player_turn == root_player else -1.0
        ucb_scores = [self.ucb_score(id_child, cput, sign) for id_child in range(len(self.children))]
        return self.children[int(np.argmax(ucb_scores))]
        
    def update_node(self,value):
        ''' 
            Add `value` to the node statistics (visit count, total and mean value).
        '''
        self.visit_count += 1
        self.tot_act_value += value
        self.mean_value = self.tot_act_value/self.visit_count
        return
    
    def summary(self):
        '''
            Print the summary of the node
        '''
        print("Node summary...........")
        print("Player turn: ",self.player_turn)
        print("Visit count: ",self.visit_count)
        print("Mean action value: ",self.mean_value)
        print("Prior probability: ",self.prior_probability)
        print("Is leaf: ",self.is_leaf)
        print("Is terminal: ",self.state.is_terminal())
        print('------------------------------')
        return
    
    
def update_path(path,value):
    '''
    Back up `value` along a search path.

    Every node of `path` receives the same `value` through its
    `update_node` method (visit count, total and mean value).

    Input:
        path = list of nodes visited during one simulation
        value = value added to each node
    '''
    for visited in path:
        visited.update_node(value)


: 

In [None]:
def compute_pi_posterior(root_node,temperature,action_size=4672):
    '''
    Compute the MCTS policy pi from the visit counts of the root's children.

        pi[a] is proportional to N(a) ** (1 / temperature) for every legal action
        a of the root state, and is 0 for every other action id.

    Inputs:
        root_node: the expanded root of the tree; its children are keyed by the
                   position of the action in `root_node.state.legal_actions()`
        temperature: temperature parameter. temperature == 1 gives
                     pi[a] = N(a) / sum_b N(b). temperature <= 0 is treated as the
                     limit T -> 0: the mass is shared by the most visited actions.
        action_size: length of the returned vector (default 4672 = 8*8*73 chess action ids)
    Output:
        pi: float32 CPU tensor of shape (action_size,). If no child has been
            visited, the legal actions get a uniform probability.

    The root's children are released afterwards (the tree is not reused).
    '''
    legal_actions = root_node.state.legal_actions()
    visits = np.array([root_node.children[index].visit_count for index in range(len(legal_actions))],
                      dtype=np.float64)
    
    if visits.size and visits.max() > 0:
        # Scale by the max first so that N ** (1/T) cannot overflow for small T.
        scaled = visits / visits.max()
        if temperature <= 0:
            weights = (scaled == 1.0).astype(np.float64)
        else:
            weights = scaled ** (1.0 / temperature)
        weights = weights / weights.sum()
    elif visits.size:
        weights = np.full(visits.size, 1.0 / visits.size)
    else:
        weights = visits
    
    pi = torch.zeros(action_size)
    pi[torch.as_tensor(legal_actions, dtype=torch.long)] = torch.as_tensor(weights, dtype=torch.float32)
    
    # Release the subtree: the next move starts from a fresh root.
    root_node.children = {}
    root_node.is_leaf = True
    return pi
  
    
def _leaf_value(node, root_player):
    '''
    Return the value of the leaf `node` from the point of view of `root_player`.

    Terminal leaf: the game reward `state.player_reward(root_player)`.
    Non-terminal leaf: the value head of the network evaluated on `node.obs`,
    negated when the player to move at `node` is not `root_player`.
    '''
    if node.state.is_terminal():
        return node.state.player_reward(root_player)
    with torch.no_grad():  # inference only: no autograd graph to keep
        _, value = node.neural_network.forward(node.obs)
    value = value.cpu().numpy()[0][0]
    return value if node.player_turn == root_player else -value


def MCTS(root,num_simulations,temperature=1):
    '''
    Perform a Monte Carlo Tree Search guided by the root's neural network.

    Each simulation:
        - Selection: descend from the root through expanded nodes with the highest PUCT score.
        - Expansion: a non-terminal leaf that was already visited is expanded and
          one of its children is selected.
        - Evaluation: the reached leaf is scored (game reward if terminal,
          network value otherwise), from the root player's point of view.
        - Backup: the value is added to every node of the search path.
    Terminal leaves are backed up on every visit. Otherwise a revisited
    terminal node would keep its statistics and be selected again and again.

    Inputs:
        root: the (unexpanded) root Node of the tree
        num_simulations: the number of simulations to perform
        temperature: temperature passed to compute_pi_posterior (default 1)
    Output:
        pi: the policy computed from the visit counts of the root's children
    '''
    root.expand()
    root_player = root.player_turn
    # Large at first so that exploration is driven by the network's priors;
    # decayed by 0.99 per simulation for the first two thirds, then set to 1.
    cput = 1e4
    
    for i in range(num_simulations):
        if i <= num_simulations*(2/3):
            cput = cput*0.99
        else:
            cput = 1
        
        # Selection
        search_path = [root]
        current = root
        while not current.is_leaf:
            current = current.select_child(cput)
            search_path.append(current)
        
        # Expansion: a leaf evaluated before is expanded and the search goes one level deeper.
        if current.visit_count > 0 and not current.state.is_terminal():
            current.expand()
            current = current.select_child(cput)
            search_path.append(current)
        
        # Evaluation and backup
        update_path(search_path, _leaf_value(current, root_player))
    
    return compute_pi_posterior(root, temperature)

: 

In [None]:
def choose_action(pi,state):
    '''
    Sample an action from the probability distribution pi restricted to the
    legal actions of `state`.

    The probabilities of the legal actions are renormalised. This gives the
    same distribution as redrawing from pi until a legal action comes out,
    without looping forever when pi puts no mass on legal actions. If pi gives
    zero total mass to the legal actions, a legal action is drawn uniformly.
    Input:
        -pi: tensor of probabilities indexed by action id (CPU or GPU)
        -state: pyspiel state providing legal_actions()
    Output:
        -id_action: the id of the chosen action (python int)
    Raises:
        ValueError: if the state has no legal action.
    '''
    legal_actions = state.legal_actions()
    if not legal_actions:
        raise ValueError('No legal action to choose from')
    
    pi_np = np.nan_to_num(pi.detach().cpu().numpy().astype(np.float64).ravel())
    weights = np.clip(pi_np[legal_actions], 0.0, None)
    total = weights.sum()
    if total > 0:
        weights = weights / total
    else:
        weights = np.full(len(legal_actions), 1.0 / len(legal_actions))
    
    return int(legal_actions[np.random.choice(len(legal_actions), p=weights)])


: 

In [None]:
class ConvBlock(nn.Module):
    '''
    Input block of the network: a 3x3 convolution (20 -> 256 channels,
    stride 1, padding 1), batch normalization and a ReLU.
    '''
    def __init__(self):
        super(ConvBlock, self).__init__()
        self.action_size = 8*8*73
        self.conv1 = nn.Conv2d(20, 256, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(256)

    def forward(self, s):
        '''
        s: board observation(s); anything convertible to a tensor holding
           20*8*8 values per board (e.g. the flat list returned by
           observation_tensor(), or a tensor).
        Returns a (batch, 256, 8, 8) tensor.
        '''
        # Put the input on the same device as the weights (CPU or GPU) instead of
        # always 'cuda'; as_tensor avoids copying inputs that are already tensors.
        s = torch.as_tensor(s, dtype=torch.float32, device=self.conv1.weight.device)
        s = s.reshape(-1, 20, 8, 8)  # batch_size x channels x board_x x board_y
        return F.relu(self.bn1(self.conv1(s)))
  
class ResBlock(nn.Module):
    '''
    Residual block made of:
        - a 3x3 convolution (`planes` filters, stride `stride`, padding 1, no bias),
          batch normalization and a ReLU
        - a 3x3 convolution (`planes` filters, stride 1, padding 1, no bias) and
          batch normalization
        - a skip connection adding the block input (passed through `downsample`
          when one is given, e.g. when stride != 1 or inplanes != planes)
        - a final ReLU
    '''
    def __init__(self, inplanes=256, planes=256, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # Only the first convolution is strided; striding both would make the
        # output smaller than the (downsampled) residual and break the addition.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                     padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        residual = x if self.downsample is None else self.downsample(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + residual
        return F.relu(out)
    

class OutBlock(nn.Module):
    '''
    Output heads of the network, applied to a (batch, 256, 8, 8) feature map.

    Value head: 1x1 conv (256 -> 1), batch norm, ReLU, Linear(64, 64), ReLU,
    Linear(64, 1), tanh -> v of shape (batch, 1) with values in [-1, 1].
    Policy head: 1x1 conv (256 -> 128), batch norm, ReLU, Linear(8*8*128, 8*8*73)
    -> p of shape (batch, 4672), the exponential of a log-softmax (probabilities).
    '''
    def __init__(self):
        super(OutBlock, self).__init__()
        # value head layers
        self.conv = nn.Conv2d(256, 1, kernel_size=1)
        self.bn = nn.BatchNorm2d(1)
        self.fc1 = nn.Linear(8*8, 64)
        self.fc2 = nn.Linear(64, 1)
        # policy head layers
        self.conv1 = nn.Conv2d(256, 128, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.fc = nn.Linear(8*8*128, 8*8*73)
    
    def forward(self,s):
        # Value head: flatten the single 8x8 plane to (batch, 64).
        value_features = F.relu(self.bn(self.conv(s))).view(-1, 8*8)
        value = torch.tanh(self.fc2(F.relu(self.fc1(value_features))))
        
        # Policy head: flatten the 128 planes to (batch, 8*8*128).
        policy_features = F.relu(self.bn1(self.conv1(s))).view(-1, 8*8*128)
        policy = self.logsoftmax(self.fc(policy_features)).exp()
        return policy, value
        
class Alphazero_net(nn.Module):
    '''
    AlphaZero-style network for the game of chess: an input ConvBlock,
    19 residual blocks with 256 filters and an OutBlock producing a policy
    (8*8*73 = 4672 action probabilities) and a value in [-1, 1].
    The network owns its SGD optimizer (lr 0.01, momentum 0.9) and a StepLR
    scheduler (step_size 400, gamma 0.1).
    '''
    def __init__(self):
        super(Alphazero_net, self).__init__()
        self.conv = ConvBlock()
        for block in range(19):
            self.add_module('resblock'+str(block),ResBlock())
        self.outblock = OutBlock()
        self.optimizer = torch.optim.SGD(self.parameters(), lr=0.01, momentum=0.9)
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=400, gamma=0.1)
        
    def forward(self, x):
        '''Return (p, v) for the observation(s) x.'''
        s = self.conv(x)
        for block in range(19):
            s = getattr(self, 'resblock'+str(block))(s)
        return self.outblock(s)
    
    def checkpoint(self,epoch=None):
        '''
        Save the weights to 'checkpoint.pth.tar' in the current working directory.
        `epoch` is accepted for compatibility but is not used.
        '''
        torch.save(self.state_dict(), 'checkpoint.pth.tar')
        print('Checkpoint saved !')
        
    def loss_function(self,p,v,pi,z,eps=1e-8):
        '''
        AlphaZero loss: cross entropy between pi and p plus squared error
        between z and v:
            -sum(pi * log(p)) + sum((z - v) ** 2)
        p is clamped to `eps` before the log, so a zero predicted probability
        does not make the loss infinite or NaN (0 * log(0)).
        '''
        return -torch.sum(pi*torch.log(p.clamp_min(eps))) + torch.sum((z-v)**2)
    
    def fit_to_self_play(self,data_batch,cpu=1):
        '''
        One optimisation step on a batch of self-play samples:
            - predict (p, v) for every state of the batch
            - sum the losses over the batch
            - backpropagate and update the weights
        data_batch: iterable of (state, pi, z) samples. `cpu` is unused.
        Returns the summed loss tensor (a zero tensor for an empty batch, in
        which case no update is made).
        '''
        self.optimizer.zero_grad()
        loss_batch = None
        
        for state, pi_board, v_board in data_batch:
            p_predicted, v_predicted = self.forward(state)
            loss = self.loss_function(p_predicted, v_predicted, pi_board, v_board)
            loss_batch = loss if loss_batch is None else loss_batch + loss
        
        if loss_batch is None:
            return torch.tensor(0.0)
        
        loss_batch.backward()
        self.optimizer.step()
        return loss_batch
    
    def evaluator (self):
        pass
    
    def self_play(self,n_simu):
        '''
        Play one game of chess against itself, each move chosen by an MCTS of
        `n_simu` simulations guided by this network.
        Inputs:
            -n_simu: number of MCTS simulations per move
        Outputs (tensors on the network's device):
            - state_history: observation of every position where a move was chosen
            - pi_history: MCTS policy pi for each of those positions (same order)
            - reward: final reward of the player who moved first
        '''
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        neural_network = self.to(device)
        game = pyspiel.load_game("chess")
        state = game.new_initial_state()
        player_turn_root = state.current_player()
        root = Node(state, state.observation_tensor(), player_turn_root, [], 1, player_turn_root, neural_network)
        
        state_history = []
        pi_history = []
        play = 0
        while not root.state.is_terminal():
            if play % 50 == 0:
                print("play = ", play)
            # Record each position together with the policy computed for it,
            # so that state_history and pi_history stay aligned.
            state_history.append(root.obs)
            pi = MCTS(root, n_simu)
            pi_history.append(pi.cpu().detach().numpy())
            id_action = choose_action(pi, root.state)
            root.state.apply_action(id_action)
            
            if root.state.is_terminal():
                print('Game over')
                break
            next_root_state = root.state.clone()
            root = Node(next_root_state, next_root_state.observation_tensor(), player_turn_root, [], 1,
                        next_root_state.current_player(), neural_network)
            play += 1
        
        reward = root.state.player_reward(player_turn_root)
        
        state_history = torch.tensor(np.array(state_history), dtype=torch.float32).to(device)
        pi_history = torch.tensor(np.array(pi_history)).to(device)
        reward = torch.tensor(reward).to(device)
        return state_history, pi_history, reward
    
    def sample(self,data,batch_size):
        '''Return min(batch_size, len(data)) distinct elements of data, drawn uniformly.'''
        return random.sample(list(data), min(batch_size, len(data)))
    
    def update_parameters(self,data,batch_size=500,epochs=50):
        '''
        Run `epochs` optimisation steps, each on a random batch of `data`
        (a sequence of (state, pi, z) samples). Returns the list of batch losses.
        '''
        losses = []
        for _ in range(epochs):
            data_batch = self.sample(data, batch_size)
            losses.append(float(self.fit_to_self_play(data_batch).detach()))
        return losses
    
    def train(self,num_iterations=1000):
        '''
        Training pipeline of the AlphaZero algorithm (not implemented yet):
            - generate data from self-play games monitored by an MCTS
            - update the parameters of the neural network
            - save a checkpoint every 50 iterations

        NOTE: this method shadows nn.Module.train(mode). nn.Module.eval()
        calls self.train(False), so a boolean argument is forwarded to
        nn.Module.train to keep the train/eval mode switch working.
        '''
        if isinstance(num_iterations, bool):
            return super(Alphazero_net, self).train(num_iterations)
        # TODO: build (state, pi, z) samples from self_play(n_simu) and call
        # update_parameters / checkpoint; z must be expressed from the point of
        # view of the player to move in each state.
        return None
        


###############################################


: 

In [None]:
'''
    In this script we will implement the self play method based on MCTS algorithm guided by a neural network
    
'''
# Working directory captured when this cell runs; play_mcts_guided_game
# stores its self-play data under directory + '/data_self_play'.
directory = os.getcwd()


def _play_self_play_game(neural_network, n_simu):
    '''
    Play one chess game against itself with MCTS guided by `neural_network`.
    Returns (state_history, pi_history, reward): the observation of every
    position where a move was chosen, the MCTS pi for each of them, and the
    final reward of the player who moved first.
    '''
    game = pyspiel.load_game("chess")
    state = game.new_initial_state()
    player_turn_root = state.current_player()
    root = Node(state, state.observation_tensor(), player_turn_root, [], 1, player_turn_root, neural_network)
    
    state_history = []
    pi_history = []
    play = 0
    while not root.state.is_terminal():
        if play % 50 == 0:
            print("play = ", play)
        # Record each position with the policy computed for it so both histories stay aligned.
        state_history.append(root.obs)
        pi = MCTS(root, n_simu)
        pi_history.append(pi.cpu().detach().numpy())
        id_action = choose_action(pi, root.state)
        root.state.apply_action(id_action)
        
        if root.state.is_terminal():
            print('Game over')
            break
        next_root_state = root.state.clone()
        root = Node(next_root_state, next_root_state.observation_tensor(), player_turn_root, [], 1,
                    next_root_state.current_player(), neural_network)
        play += 1
    
    return state_history, pi_history, root.state.player_reward(player_turn_root)


def play_mcts_guided_game(neural_network,n_simu,num_game):
    '''
    Play one self-play game with MCTS guided by `neural_network` and save it.

    Inputs:
        -neural_network: network guiding the search (a new Alphazero_net is
                         created when None is given)
        -n_simu: number of MCTS simulations per move
        -num_game: index of the game, used in the output folder and file names
    Outputs, written to <directory>/data_self_play/game_<num_game>/ (the
    folder is created if needed and existing files are overwritten):
        - neural_network_weights<num_game>.pt : network weights
        - state_history<num_game>.npy : observation of every position where a move was chosen
        - pi_history<num_game>.npy : MCTS pi for each of those positions
        - reward<num_game>.npy : final reward of the player who moved first
    Returns None. The working directory is not changed.
    '''
    if neural_network is None:
        neural_network = Alphazero_net()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    neural_network = neural_network.to(device)
    
    state_history, pi_history, reward = _play_self_play_game(neural_network, n_simu)
    
    game_dir = os.path.join(directory, 'data_self_play', 'game_'+str(num_game))
    os.makedirs(game_dir, exist_ok=True)
    torch.save(neural_network.state_dict(),
               os.path.join(game_dir, 'neural_network_weights'+str(num_game)+'.pt'))
    np.save(os.path.join(game_dir, 'state_history'+str(num_game)+'.npy'), np.array(state_history))
    np.save(os.path.join(game_dir, 'pi_history'+str(num_game)+'.npy'), np.array(pi_history))
    np.save(os.path.join(game_dir, 'reward'+str(num_game)+'.npy'), reward)
    
    print('Game number ',num_game,' is over')
    print('Reward = ',reward)
    return

#neural_network = Alphazero_net()
#for game in range (1):
#    print('Game number: ',game)
#    play_mcts_guided_game(neural_network,150,game)
#    print('-------------------------------------------------------------------')
#print('Reward: ',reward)
#print('End of the games')

: 

In [None]:

# Build a freshly constructed (untrained) network and run one self-play game
# with 10 MCTS simulations per move; the returned histories are discarded.
neural_network = Alphazero_net()
neural_network.self_play(10)

: 