In [1]:
# -*- coding: utf-8 -*-
"""
An implementation of the training pipeline of AlphaZero for Gomoku

@author: Junxiao Song
"""

from __future__ import print_function
import random
import numpy as np
from collections import defaultdict, deque
from game import Board, Game
from mcts_pure import MCTSPlayer as MCTS_Pure
from mcts_alphaZero import MCTSPlayer
from policy_value_net_keras import PolicyValueNet # Keras

Using TensorFlow backend.


In [2]:
class TrainPipeline():
    """Self-play training pipeline for a policy-value network on Gomoku.

    Each iteration of ``run`` plays self-play games with an MCTS player that
    is guided by the policy-value network, augments the collected samples by
    rotation and flipping, pushes them into a bounded replay buffer and trains
    the network on a random mini-batch. Every ``check_freq`` iterations the
    network is evaluated against a pure-MCTS opponent and saved to disk.
    """

    def __init__(self, init_model=None):
        """Create the board, hyper-parameters, network and self-play player.

        init_model: optional value forwarded to ``PolicyValueNet`` as
            ``model_file``; when falsy a new network is created instead.
        """
        # params of the board and the game
        self.board_width = 12
        self.board_height = 12
        self.n_in_row = 5
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)
        self.game = Game(self.board)
        # training params
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1.0  # the temperature param
        self.n_playout = 400  # num of simulations for each move
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 512  # mini-batch size for training
        # bounded replay buffer: the oldest samples fall out once it is full
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.02
        self.check_freq = 50
        self.game_batch_num = 1500
        self.best_win_ratio = 0.0
        # length of the most recent self-play game; defined up front so that
        # ``run`` can always report it, even if no game has been played yet
        self.episode_len = 0
        # num of simulations used for the pure mcts, which is used as
        # the opponent to evaluate the trained policy
        self.pure_mcts_playout_num = 1000
        if init_model:
            # start training from an initial policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height,
                                                   model_file=init_model)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """augment the data set by rotation and flipping

        play_data: [(state, mcts_prob, winner_z), ..., ...]
            ``state`` is iterated plane by plane and ``mcts_prob`` is reshaped
            to (board_height, board_width).

        Returns a list with 8 samples per input sample: for each of the four
        quarter-turn rotations, the rotated sample and its left-right mirror.
        The winner value is copied unchanged.
        """
        extend_data = []
        for state, mcts_prob, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                # The probabilities are flipped vertically before the
                # transform and flipped back before flattening.
                # NOTE(review): presumably the state planes are stored
                # vertically flipped relative to the flat move index --
                # confirm against the Board implementation.
                equi_mcts_prob = np.rot90(np.flipud(
                    mcts_prob.reshape(self.board_height, self.board_width)), i)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """collect self-play data for training

        Plays ``n_games`` self-play games, records the length of the latest
        one in ``self.episode_len`` and appends the augmented samples to
        ``self.data_buffer``.
        """
        for _ in range(n_games):
            _winner, play_data = self.game.start_self_play(self.mcts_player,
                                                           temp=self.temp)
            # materialise once so the data can be both measured and iterated
            play_data = list(play_data)
            self.episode_len = len(play_data)
            # augment the data
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)

    @staticmethod
    def _explained_variance(targets, predictions):
        """Return 1 - Var[targets - predictions] / Var[targets].

        Returns NaN when the targets have zero variance, because the ratio is
        undefined there; this avoids a divide-by-zero warning and a
        misleading -inf in the training log.
        """
        target_var = np.var(targets)
        if target_var == 0:
            return np.nan
        return 1 - np.var(targets - predictions) / target_var

    def policy_update(self):
        """update the policy-value net

        Samples ``batch_size`` entries from the replay buffer, runs up to
        ``epochs`` training steps on them (stopping early when the KL
        divergence between the pre-update and current policy outputs exceeds
        4 * ``kl_targ``), adapts ``lr_multiplier`` from the final KL value and
        prints a one-line summary.

        Returns the (loss, entropy) pair reported by the last training step.
        """
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        # network outputs before any update: the reference for the KL term
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                    state_batch,
                    mcts_probs_batch,
                    winner_batch,
                    self.learn_rate*self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            # mean over the batch of KL(old || new); 1e-10 keeps log() finite
            kl = np.mean(np.sum(old_probs * (
                    np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                    axis=1)
            )
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        # how much of the outcome variance the value output accounts for,
        # before and after this update
        winners = np.array(winner_batch)
        explained_var_old = self._explained_variance(winners, old_v.flatten())
        explained_var_new = self._explained_variance(winners, new_v.flatten())
        print(("kl:{:.5f},"
               "lr_multiplier:{:.3f},"
               "loss:{},"
               "entropy:{},"
               "explained_var_old:{:.3f},"
               "explained_var_new:{:.3f}"
               ).format(kl,
                        self.lr_multiplier,
                        loss,
                        entropy,
                        explained_var_old,
                        explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """
        Evaluate the trained policy by playing against the pure MCTS player
        Note: this is only for monitoring the progress of training

        Plays ``n_games`` games, alternating the starting player, and returns
        the win ratio in which a result of 1 counts as a full point and a
        result of -1 (reported as a tie) as half a point.
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            # NOTE(review): assumes start_play reports the first player passed
            # in (the current policy) as winner id 1 -- confirm in Game.
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=i % 2,
                                          is_shown=0)
            win_cnt[winner] += 1
        win_ratio = 1.0*(win_cnt[1] + 0.5*win_cnt[-1]) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
                self.pure_mcts_playout_num,
                win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    def run(self):
        """run the training pipeline

        Repeats ``game_batch_num`` times: collect self-play data, train once
        the buffer holds more than ``batch_size`` samples, and every
        ``check_freq`` iterations evaluate and save the model. A
        KeyboardInterrupt stops the loop cleanly.
        """
        try:
            for i in range(self.game_batch_num):
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(
                        i+1, self.episode_len))
                if len(self.data_buffer) > self.batch_size:
                    self.policy_update()
                # check the performance of the current model,
                # and save the model params
                if (i+1) % self.check_freq == 0:
                    print("current self-play batch: {}".format(i+1))
                    win_ratio = self.policy_evaluate()
                    self.policy_value_net.save_model('./current_policy.model')
                    if win_ratio > self.best_win_ratio:
                        print("New best policy!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        # update the best_policy
                        self.policy_value_net.save_model('./best_policy.model')
                        # once the opponent is beaten in every game, make it
                        # stronger and restart the best-ratio tracking
                        if (self.best_win_ratio == 1.0 and
                                self.pure_mcts_playout_num < 5000):
                            self.pure_mcts_playout_num += 1000
                            self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')

In [3]:
# Build the pipeline with a freshly initialised network (no checkpoint is
# loaded) and start the self-play training loop.
training_pipeline = TrainPipeline(init_model=None)
training_pipeline.run()

batch i:1, episode_len:38
batch i:2, episode_len:73
kl:0.00503,lr_multiplier:1.500,loss:5.943051815032959,entropy:4.966987609863281,explained_var_old:-0.054,explained_var_new:0.188
batch i:3, episode_len:37
kl:0.05150,lr_multiplier:1.000,loss:5.553645133972168,entropy:4.860776901245117,explained_var_old:0.200,explained_var_new:0.464
batch i:4, episode_len:56
kl:0.02971,lr_multiplier:1.000,loss:5.410057067871094,entropy:4.787075519561768,explained_var_old:0.215,explained_var_new:0.513
batch i:5, episode_len:32
kl:0.02008,lr_multiplier:1.000,loss:5.13229513168335,entropy:4.752115249633789,explained_var_old:0.446,explained_var_new:0.663
batch i:6, episode_len:38
kl:0.07499,lr_multiplier:0.667,loss:4.870880603790283,entropy:4.762134552001953,explained_var_old:0.659,explained_var_new:0.856
batch i:7, episode_len:56
kl:0.03427,lr_multiplier:0.667,loss:4.696565628051758,entropy:4.4845380783081055,explained_var_old:0.870,explained_var_new:0.916
batch i:8, episode_len:31
kl:0.01761,lr_multiplie

batch i:55, episode_len:38
kl:0.02298,lr_multiplier:0.198,loss:4.271576881408691,entropy:3.6626110076904297,explained_var_old:0.361,explained_var_new:0.404
batch i:56, episode_len:46
kl:0.02387,lr_multiplier:0.198,loss:4.194460868835449,entropy:3.609116315841675,explained_var_old:0.390,explained_var_new:0.422
batch i:57, episode_len:49
kl:0.03046,lr_multiplier:0.198,loss:4.384311199188232,entropy:3.6817238330841064,explained_var_old:0.272,explained_var_new:0.314
batch i:58, episode_len:40
kl:0.02416,lr_multiplier:0.198,loss:4.406652450561523,entropy:3.658456325531006,explained_var_old:0.264,explained_var_new:0.291
batch i:59, episode_len:40
kl:0.03256,lr_multiplier:0.198,loss:4.370965003967285,entropy:3.6282715797424316,explained_var_old:0.223,explained_var_new:0.267
batch i:60, episode_len:36
kl:0.02907,lr_multiplier:0.198,loss:4.331374168395996,entropy:3.5723471641540527,explained_var_old:0.257,explained_var_new:0.304
batch i:61, episode_len:43
kl:0.01446,lr_multiplier:0.198,loss:4.2

kl:0.01438,lr_multiplier:0.132,loss:3.9821057319641113,entropy:3.396735668182373,explained_var_old:0.447,explained_var_new:0.483
batch i:108, episode_len:30
kl:0.01321,lr_multiplier:0.132,loss:3.7858104705810547,entropy:3.322211742401123,explained_var_old:0.507,explained_var_new:0.536
batch i:109, episode_len:28
kl:0.01663,lr_multiplier:0.132,loss:3.888829231262207,entropy:3.4206864833831787,explained_var_old:0.514,explained_var_new:0.550
batch i:110, episode_len:34
kl:0.00989,lr_multiplier:0.198,loss:3.8703882694244385,entropy:3.329573154449463,explained_var_old:0.475,explained_var_new:0.518
batch i:111, episode_len:40
kl:0.01717,lr_multiplier:0.198,loss:3.8181393146514893,entropy:3.393954038619995,explained_var_old:0.557,explained_var_new:0.587
batch i:112, episode_len:28
kl:0.01786,lr_multiplier:0.198,loss:3.8193001747131348,entropy:3.337203025817871,explained_var_old:0.528,explained_var_new:0.574
batch i:113, episode_len:48
kl:0.02338,lr_multiplier:0.198,loss:3.818274736404419,entr

batch i:159, episode_len:30
kl:0.01800,lr_multiplier:0.132,loss:4.1100993156433105,entropy:3.342484951019287,explained_var_old:0.194,explained_var_new:0.265
batch i:160, episode_len:19
kl:0.01805,lr_multiplier:0.132,loss:4.205097675323486,entropy:3.4262475967407227,explained_var_old:0.148,explained_var_new:0.234
batch i:161, episode_len:34
kl:0.01955,lr_multiplier:0.132,loss:4.132944107055664,entropy:3.3625197410583496,explained_var_old:0.183,explained_var_new:0.246
batch i:162, episode_len:48
kl:0.02411,lr_multiplier:0.132,loss:4.091619491577148,entropy:3.3374452590942383,explained_var_old:0.240,explained_var_new:0.301
batch i:163, episode_len:30
kl:0.01443,lr_multiplier:0.132,loss:3.9918808937072754,entropy:3.3414306640625,explained_var_old:0.283,explained_var_new:0.351
batch i:164, episode_len:24
kl:0.01873,lr_multiplier:0.132,loss:4.016975402832031,entropy:3.2900748252868652,explained_var_old:0.267,explained_var_new:0.314
batch i:165, episode_len:40
kl:0.02146,lr_multiplier:0.132,l

batch i:211, episode_len:34
kl:0.02299,lr_multiplier:0.088,loss:4.136329650878906,entropy:3.471571922302246,explained_var_old:0.306,explained_var_new:0.335
batch i:212, episode_len:40
kl:0.01320,lr_multiplier:0.088,loss:4.036716461181641,entropy:3.3616538047790527,explained_var_old:0.272,explained_var_new:0.322
batch i:213, episode_len:50
kl:0.01808,lr_multiplier:0.088,loss:4.167263031005859,entropy:3.2963786125183105,explained_var_old:0.143,explained_var_new:0.203
batch i:214, episode_len:24
kl:0.02112,lr_multiplier:0.088,loss:4.124012470245361,entropy:3.3429696559906006,explained_var_old:0.228,explained_var_new:0.267
batch i:215, episode_len:54
kl:0.03273,lr_multiplier:0.088,loss:4.128477573394775,entropy:3.386174201965332,explained_var_old:0.251,explained_var_new:0.293
batch i:216, episode_len:35
kl:0.03087,lr_multiplier:0.088,loss:4.2256550788879395,entropy:3.483306884765625,explained_var_old:0.203,explained_var_new:0.260
batch i:217, episode_len:28
kl:0.02229,lr_multiplier:0.088,l

batch i:263, episode_len:56
kl:0.02632,lr_multiplier:0.088,loss:3.9552834033966064,entropy:3.3965532779693604,explained_var_old:0.358,explained_var_new:0.431
batch i:264, episode_len:25
kl:0.02528,lr_multiplier:0.088,loss:3.8760783672332764,entropy:3.357593059539795,explained_var_old:0.381,explained_var_new:0.466
batch i:265, episode_len:11
kl:0.01312,lr_multiplier:0.088,loss:3.8707940578460693,entropy:3.327479839324951,explained_var_old:0.448,explained_var_new:0.496
batch i:266, episode_len:15
kl:0.01923,lr_multiplier:0.088,loss:3.895585536956787,entropy:3.3546533584594727,explained_var_old:0.429,explained_var_new:0.465
batch i:267, episode_len:33
kl:0.02024,lr_multiplier:0.088,loss:3.8919146060943604,entropy:3.336574077606201,explained_var_old:0.424,explained_var_new:0.492
batch i:268, episode_len:21
kl:0.03254,lr_multiplier:0.088,loss:3.8193554878234863,entropy:3.3995370864868164,explained_var_old:0.518,explained_var_new:0.565
batch i:269, episode_len:46
kl:0.01922,lr_multiplier:0.0

batch i:315, episode_len:13
kl:0.02606,lr_multiplier:0.088,loss:4.004856109619141,entropy:3.40632963180542,explained_var_old:0.335,explained_var_new:0.389
batch i:316, episode_len:17
kl:0.02494,lr_multiplier:0.088,loss:3.955728530883789,entropy:3.344294548034668,explained_var_old:0.356,explained_var_new:0.408
batch i:317, episode_len:17
kl:0.02022,lr_multiplier:0.088,loss:4.021249771118164,entropy:3.35247802734375,explained_var_old:0.321,explained_var_new:0.357
batch i:318, episode_len:33
kl:0.02699,lr_multiplier:0.088,loss:3.966035842895508,entropy:3.386528968811035,explained_var_old:0.319,explained_var_new:0.423
batch i:319, episode_len:19
kl:0.02386,lr_multiplier:0.088,loss:3.9249227046966553,entropy:3.3512024879455566,explained_var_old:0.373,explained_var_new:0.405
batch i:320, episode_len:29
kl:0.01561,lr_multiplier:0.088,loss:4.010831832885742,entropy:3.380427360534668,explained_var_old:0.341,explained_var_new:0.400
batch i:321, episode_len:23
kl:0.01593,lr_multiplier:0.088,loss:

batch i:367, episode_len:15
kl:0.01678,lr_multiplier:0.088,loss:3.8127384185791016,entropy:3.233351707458496,explained_var_old:0.374,explained_var_new:0.403
batch i:368, episode_len:9
kl:0.01013,lr_multiplier:0.088,loss:3.779733180999756,entropy:3.235827922821045,explained_var_old:0.487,explained_var_new:0.518
batch i:369, episode_len:12
kl:0.01615,lr_multiplier:0.088,loss:3.7392168045043945,entropy:3.2101919651031494,explained_var_old:0.456,explained_var_new:0.482
batch i:370, episode_len:19
kl:0.01215,lr_multiplier:0.088,loss:3.8043034076690674,entropy:3.1846394538879395,explained_var_old:0.408,explained_var_new:0.442
batch i:371, episode_len:9
kl:0.00922,lr_multiplier:0.132,loss:3.654968738555908,entropy:3.216559648513794,explained_var_old:0.516,explained_var_new:0.560
batch i:372, episode_len:43
kl:0.02236,lr_multiplier:0.132,loss:3.664642810821533,entropy:3.116549015045166,explained_var_old:0.437,explained_var_new:0.482
batch i:373, episode_len:16
kl:0.03074,lr_multiplier:0.132,lo

batch i:419, episode_len:16
kl:0.02555,lr_multiplier:0.132,loss:3.65087628364563,entropy:3.1306490898132324,explained_var_old:0.465,explained_var_new:0.526
batch i:420, episode_len:17
kl:0.01890,lr_multiplier:0.132,loss:3.6412088871002197,entropy:3.136336326599121,explained_var_old:0.427,explained_var_new:0.502
batch i:421, episode_len:23
kl:0.02679,lr_multiplier:0.132,loss:3.555161237716675,entropy:3.049464702606201,explained_var_old:0.483,explained_var_new:0.559
batch i:422, episode_len:22
kl:0.02511,lr_multiplier:0.132,loss:3.574303388595581,entropy:3.1153316497802734,explained_var_old:0.497,explained_var_new:0.547
batch i:423, episode_len:14
kl:0.02223,lr_multiplier:0.132,loss:3.6049273014068604,entropy:3.049152374267578,explained_var_old:0.460,explained_var_new:0.531
batch i:424, episode_len:24
kl:0.02722,lr_multiplier:0.132,loss:3.5092363357543945,entropy:3.0781874656677246,explained_var_old:0.472,explained_var_new:0.548
batch i:425, episode_len:22
kl:0.02723,lr_multiplier:0.132,

batch i:471, episode_len:9
kl:0.04800,lr_multiplier:0.088,loss:3.403177499771118,entropy:3.034151077270508,explained_var_old:0.472,explained_var_new:0.553
batch i:472, episode_len:26
kl:0.04496,lr_multiplier:0.088,loss:3.4826412200927734,entropy:2.921539545059204,explained_var_old:0.449,explained_var_new:0.515
batch i:473, episode_len:13
kl:0.02705,lr_multiplier:0.088,loss:3.5242884159088135,entropy:2.9411964416503906,explained_var_old:0.397,explained_var_new:0.442
batch i:474, episode_len:19
kl:0.01290,lr_multiplier:0.088,loss:3.4383435249328613,entropy:2.997856378555298,explained_var_old:0.448,explained_var_new:0.494
batch i:475, episode_len:16
kl:0.01242,lr_multiplier:0.088,loss:3.5599730014801025,entropy:3.011693000793457,explained_var_old:0.424,explained_var_new:0.471
batch i:476, episode_len:14
kl:0.01135,lr_multiplier:0.088,loss:3.4709465503692627,entropy:2.980170965194702,explained_var_old:0.453,explained_var_new:0.502
batch i:477, episode_len:32
kl:0.01320,lr_multiplier:0.088,

batch i:523, episode_len:27
kl:0.01501,lr_multiplier:0.088,loss:3.108696222305298,entropy:2.7158970832824707,explained_var_old:0.605,explained_var_new:0.656
batch i:524, episode_len:13
kl:0.01481,lr_multiplier:0.088,loss:3.1767282485961914,entropy:2.6911487579345703,explained_var_old:0.536,explained_var_new:0.600
batch i:525, episode_len:11
kl:0.01858,lr_multiplier:0.088,loss:3.1149652004241943,entropy:2.72245192527771,explained_var_old:0.601,explained_var_new:0.649
batch i:526, episode_len:27
kl:0.01932,lr_multiplier:0.088,loss:3.0984697341918945,entropy:2.7442569732666016,explained_var_old:0.626,explained_var_new:0.661
batch i:527, episode_len:19
kl:0.01874,lr_multiplier:0.088,loss:3.0148208141326904,entropy:2.606722831726074,explained_var_old:0.579,explained_var_new:0.653
batch i:528, episode_len:11
kl:0.02031,lr_multiplier:0.088,loss:3.14241361618042,entropy:2.759479522705078,explained_var_old:0.585,explained_var_new:0.642
batch i:529, episode_len:42
kl:0.01959,lr_multiplier:0.088,

batch i:575, episode_len:25
kl:0.03082,lr_multiplier:0.132,loss:3.27242112159729,entropy:2.5544981956481934,explained_var_old:0.268,explained_var_new:0.359
batch i:576, episode_len:46
kl:0.02017,lr_multiplier:0.132,loss:3.483335018157959,entropy:2.739053249359131,explained_var_old:0.202,explained_var_new:0.281
batch i:577, episode_len:13
kl:0.03526,lr_multiplier:0.132,loss:3.3245275020599365,entropy:2.6541075706481934,explained_var_old:0.281,explained_var_new:0.392
batch i:578, episode_len:25
kl:0.03665,lr_multiplier:0.132,loss:3.3530068397521973,entropy:2.6904191970825195,explained_var_old:0.242,explained_var_new:0.346
batch i:579, episode_len:23
kl:0.04399,lr_multiplier:0.088,loss:3.3799638748168945,entropy:2.7583487033843994,explained_var_old:0.285,explained_var_new:0.354
batch i:580, episode_len:19
kl:0.01836,lr_multiplier:0.088,loss:3.303767442703247,entropy:2.6793441772460938,explained_var_old:0.338,explained_var_new:0.390
batch i:581, episode_len:30
kl:0.01271,lr_multiplier:0.08

batch i:627, episode_len:62
kl:0.02245,lr_multiplier:0.088,loss:3.3555634021759033,entropy:2.865142345428467,explained_var_old:0.466,explained_var_new:0.506
batch i:628, episode_len:39
kl:0.01404,lr_multiplier:0.088,loss:3.442932367324829,entropy:2.8992342948913574,explained_var_old:0.414,explained_var_new:0.473
batch i:629, episode_len:11
kl:0.01562,lr_multiplier:0.088,loss:3.4228150844573975,entropy:2.866746664047241,explained_var_old:0.352,explained_var_new:0.444
batch i:630, episode_len:13
kl:0.01288,lr_multiplier:0.088,loss:3.401580333709717,entropy:2.867344379425049,explained_var_old:0.423,explained_var_new:0.466
batch i:631, episode_len:14
kl:0.01305,lr_multiplier:0.088,loss:3.4882454872131348,entropy:2.840941905975342,explained_var_old:0.370,explained_var_new:0.445
batch i:632, episode_len:23
kl:0.01300,lr_multiplier:0.088,loss:3.420516014099121,entropy:2.8508026599884033,explained_var_old:0.419,explained_var_new:0.472
batch i:633, episode_len:14
kl:0.01118,lr_multiplier:0.088,

batch i:679, episode_len:12
kl:0.02097,lr_multiplier:0.088,loss:3.9109692573547363,entropy:3.1154747009277344,explained_var_old:0.213,explained_var_new:0.274
batch i:680, episode_len:11
kl:0.03005,lr_multiplier:0.088,loss:3.8493239879608154,entropy:3.187422037124634,explained_var_old:0.271,explained_var_new:0.317
batch i:681, episode_len:11
kl:0.02079,lr_multiplier:0.088,loss:3.745438814163208,entropy:3.0564820766448975,explained_var_old:0.312,explained_var_new:0.354
batch i:682, episode_len:31
kl:0.01989,lr_multiplier:0.088,loss:3.93438458442688,entropy:3.129962921142578,explained_var_old:0.189,explained_var_new:0.294
batch i:683, episode_len:31
kl:0.03116,lr_multiplier:0.088,loss:3.858407974243164,entropy:3.1640520095825195,explained_var_old:0.247,explained_var_new:0.300
batch i:684, episode_len:9
kl:0.01266,lr_multiplier:0.088,loss:3.863063335418701,entropy:3.1824045181274414,explained_var_old:0.269,explained_var_new:0.327
batch i:685, episode_len:15
kl:0.01557,lr_multiplier:0.088,l