# 5 layers

In [0]:
import os
# Download libraries
libraries = ['game.py', 'human_play.py', 'mcts_alphaZero.py', 'mcts_pure.py', 'policy_value_net.py', 
             'policy_value_net_keras.py', 'policy_value_net_numpy.py', 'policy_value_net_pytorch.py',
             'policy_value_net_tensorflow.py']
library_url = 'https://raw.githubusercontent.com/abx67/AlphaZero_Gomoku_my/master/morelayer_version/'

for lib in libraries:
  lib_url = library_url + lib
  if not os.path.exists(lib):
    !curl -O $lib_url

In [0]:
!pip install lasagne
# !pip install --upgrade https://github.com/Lasagne/Lasagne/archive/master.zip
# !pip install pydot==1.0.2 --upgrade

In [0]:
################# ZIP AND UPLOAD FOLDER TO GOOGLE DRIVE ########################

!pip install -U -q PyDrive

from google.colab import files
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
import zipfile
import os
import sys

# 1. Authenticate and create the PyDrive client.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

In [4]:
# -*- coding: utf-8 -*-
"""
An implementation of the training pipeline of AlphaZero for Gomoku

@author: Junxiao Song
"""

from __future__ import print_function
import random
import numpy as np
from collections import defaultdict, deque
from game import Board, Game
from mcts_pure import MCTSPlayer as MCTS_Pure
from mcts_alphaZero import MCTSPlayer
# from policy_value_net import PolicyValueNet  # Theano and Lasagne
# from policy_value_net_pytorch import PolicyValueNet  # Pytorch
# from policy_value_net_tensorflow import PolicyValueNet # Tensorflow
from policy_value_net_keras import PolicyValueNet # Keras

import time

Using TensorFlow backend.


In [0]:
class TrainPipeline():
    def __init__(self, init_model=None):
        # params of the board and the game
        self.bash_output = ''
        self.f = open("output8by8_5layerslayer.txt","w+")
        self.time_now = time.time()
        
        self.board_width = 8 
        self.board_height = 8
        self.n_in_row = 5
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)
        self.game = Game(self.board)
        # training params
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1.0  # the temperature param
        self.n_playout = 400  # num of simulations for each move
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 512  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.02
        self.check_freq = 200
        self.game_batch_num = 2000
        self.best_win_ratio = 0.0
        # num of simulations used for the pure mcts, which is used as
        # the opponent to evaluate the trained policy
        self.pure_mcts_playout_num = 1000
        if init_model:
            # start training from an initial policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height,
                                                   model_file=init_model)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """augment the data set by rotation and flipping
        play_data: [(state, mcts_prob, winner_z), ..., ...]
        """
        extend_data = []
        for state, mcts_porb, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(np.flipud(
                    mcts_porb.reshape(self.board_height, self.board_width)), i)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """collect self-play data for training"""
        for i in range(n_games):
            winner, play_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp)
            print(play_data)  # fanerror
            play_data = list(play_data)[:]
            self.episode_len = len(play_data)
            # augment the data
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """update the policy-value net"""
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                    state_batch,
                    mcts_probs_batch,
                    winner_batch,
                    self.learn_rate*self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(np.sum(old_probs * (
                    np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                    axis=1)
            )
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        explained_var_old = (1 -
                             np.var(np.array(winner_batch) - old_v.flatten()) /
                             np.var(np.array(winner_batch)))
        explained_var_new = (1 -
                             np.var(np.array(winner_batch) - new_v.flatten()) /
                             np.var(np.array(winner_batch)))
        print(("kl:{:.5f},"
               "lr_multiplier:{:.3f},"
               "loss:{},"
               "entropy:{},"
               "explained_var_old:{:.3f},"
               "explained_var_new:{:.3f}"
               ).format(kl,
                        self.lr_multiplier,
                        loss,
                        entropy,
                        explained_var_old,
                        explained_var_new))
        
        self.bash_output = ("kl:{:.5f},"
                           "lr_multiplier:{:.3f},"
                           "loss:{},"
                           "entropy:{},"
                           "explained_var_old:{:.3f},"
                           "explained_var_new:{:.3f}"
                           ).format(kl,
                            self.lr_multiplier,
                            loss,
                            entropy,
                            explained_var_old,
                            explained_var_new)
        self.f.write(self.bash_output)
        self.f.write('\n')
        self.bash_output = ''
        
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """
        Evaluate the trained policy by playing against the pure MCTS player
        Note: this is only for monitoring the progress of training
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=i % 2,
                                          is_shown=0)
            win_cnt[winner] += 1
        win_ratio = 1.0*(win_cnt[1] + 0.5*win_cnt[-1]) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
                self.pure_mcts_playout_num,
                win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    def run(self):
        """run the training pipeline"""
        self.time_now = time.time()
        start_time = time.time()
        try:
            for i in range(self.game_batch_num):
              
                print('Time elapsed: {} seconds'.format(round(time.time() - self.time_now)) + 
                      '\t Total time elapsed: {} seconds'.format(round(time.time() - start_time)))
                self.f.write('Time elapsed: {} seconds'.format(round(time.time() - self.time_now)) + 
                      '\t Total time elapsed: {} seconds'.format(round(time.time() - start_time)))
                self.f.write("\n")
                
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(
                        i+1, self.episode_len))
                self.f.write("batch i:{}, episode_len:{}".format(
                        i+1, self.episode_len))
                self.f.write("\n")
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                    
                # check the performance of the current model,
                # and save the model params
                if (i+1) % self.check_freq == 0:
                  
                    self.f.close()
                    # save the output figures in google drive
                    auth.authenticate_user()
                    gauth = GoogleAuth()
                    gauth.credentials = GoogleCredentials.get_application_default()
                    drive = GoogleDrive(gauth)

                    file = drive.CreateFile()
                    file.SetContentFile('output8by8_5layers.txt')
                    file.Upload()
                    self.f = open("output8by8_5layers.txt","a")
                  
                  
                  
                    print("current self-play batch: {}".format(i+1))
                    win_ratio = self.policy_evaluate()
                    self.policy_value_net.save_model('./current_policy.model')
                    
                    # save the output figures in google drive

                    file = drive.CreateFile()
                    file.SetContentFile('current_policy.model')
                    file.Upload()
                    
                    if win_ratio > self.best_win_ratio:
                        print("New best policy!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        # update the best_policy
                        self.policy_value_net.save_model('./best_policy.model')
                        
                        # save the output figures in google drive
                        auth.authenticate_user()
                        gauth = GoogleAuth()
                        gauth.credentials = GoogleCredentials.get_application_default()
                        drive = GoogleDrive(gauth)

                        file = drive.CreateFile()
                        file.SetContentFile('best_policy.model')
                        file.Upload()
                  
                        if (self.best_win_ratio == 1.0 and
                                self.pure_mcts_playout_num < 5000):
                            self.pure_mcts_playout_num += 1000
                            self.best_win_ratio = 0.0
                  self.time_now = time.time()
            self.f.close()
        except KeyboardInterrupt:
            print('\n\rquit')
            self.f.close()

In [0]:
if __name__ == '__main__':
    training_pipeline = TrainPipeline()
    training_pipeline.run()


Time elapsed: 0 seconds	 Total time elapsed: 0 seconds
<zip object at 0x7f88475663c8>
batch i:1, episode_len:27
Time elapsed: 38 seconds	 Total time elapsed: 39 seconds
<zip object at 0x7f88475d0a08>
batch i:2, episode_len:36
Time elapsed: 49 seconds	 Total time elapsed: 88 seconds
<zip object at 0x7f884b046908>
batch i:3, episode_len:19
kl:0.00108,lr_multiplier:1.500,loss:5.137330532073975,entropy:4.1588544845581055,explained_var_old:-0.004,explained_var_new:0.082
Time elapsed: 50 seconds	 Total time elapsed: 139 seconds
<zip object at 0x7f8846734548>
batch i:4, episode_len:38
kl:0.00089,lr_multiplier:2.250,loss:4.894233226776123,entropy:4.158131122589111,explained_var_old:0.066,explained_var_new:0.412
Time elapsed: 75 seconds	 Total time elapsed: 214 seconds
<zip object at 0x7f8845eb37c8>
batch i:5, episode_len:26
kl:0.01138,lr_multiplier:2.250,loss:4.60254430770874,entropy:4.138716697692871,explained_var_old:0.307,explained_var_new:0.714
Time elapsed: 59 seconds	 Total time elapsed:

In [0]:
1373/30