In [None]:
import gym
import boardgame2
import copy
import time
import numpy as np
import tensorflow as tf

from alpha_zero import ReplayBuffer
from alpha_zero import play_game
from alpha_zero import run_mcts

In [None]:
class AlphaZeroConfig:
    """Hyperparameters for 6x6 Othello AlphaZero self-play, training and play.

    The instance is handed to ``alpha_zero.ReplayBuffer``, ``play_game`` and
    ``run_mcts``; the precise meaning of the search / self-play fields is
    defined in that module. The per-group descriptions below are presumed
    from the field names — confirm against ``alpha_zero``.
    """

    def __init__(self):
        # Self-play (presumably: number of opening moves sampled from visit
        # counts, a hard cap on moves per game, MCTS simulations per move).
        self.num_sampling_moves, self.max_moves, self.num_simulations = 15, 36, 50

        # Root exploration noise (presumably Dirichlet noise mixed into the
        # root priors of each search).
        self.root_dirichlet_alpha, self.root_exploration_fraction = 0.15, 0.25

        # PUCT exploration constants (presumably consumed by run_mcts).
        self.pb_c_base, self.pb_c_init = 19652, 1.25

        # Training schedule: number of update steps, replay window
        # (presumably measured in games), and samples per batch.
        self.training_steps, self.window_size, self.batch_size = 2500, 1000, 256

        # Optimisation: learning_rate feeds the Adam optimizer; weight_decay
        # scales the L2 penalty added to the loss in update_weights.
        self.learning_rate, self.weight_decay = 0.001, 0.01

        # Directory prefix under which the trained weights are saved / loaded.
        self.base_dir = '/tmp/othello_6x6/'


In [None]:
class Game(object):
    """Reversi game wrapper around the gym ``Reversi-v0`` env that records history.

    ``history`` is a list of ``(board, player, reward, done)`` tuples. The first
    entry is the freshly reset env state with reward ``0.0`` and ``done=False``;
    each later entry is the ``(board, player)`` state, reward (as float) and done
    flag returned by ``env.step`` for one move. ``child_visits`` holds one
    (board_size, board_size) visit distribution per searched position.
    """

    def __init__(self, history=None, board_size=6):
        # board_size defaults to 6; NOTE(review): Network hard-codes a 6x6 input
        # and 36 policy logits, so other sizes only work with a matching network.
        self.board_size = board_size
        self.env = gym.make('Reversi-v0', board_shape=[self.board_size, self.board_size])
        self.env.reset()

        # A new game starts from the reset board; a cloned game resumes the env
        # from the last recorded state.
        self.history = history or [(self.env.board, self.env.player, 0.0, False)]
        self.env.board, self.env.player, _, _ = self.history[-1]

        self.child_visits = []
        self.num_actions = self.board_size ** 2

    def terminal(self):
        """Return the ``done`` flag of the latest recorded state."""
        return self.history[-1][3]

    def terminal_value(self):
        """Return the reward recorded for the latest state (as reported by env.step)."""
        return self.history[-1][2]

    def legal_actions(self):
        """Return flat (row-major) indices of the cells the side to move may play."""
        board, player, _, _ = self.history[-1]
        valid = self.env.get_valid((board, player))
        return np.where(valid.flatten() == 1)[0]

    def clone(self):
        """Return a new Game sharing a shallow copy of the history (child_visits not copied)."""
        return Game(copy.copy(self.history), board_size=self.board_size)

    def apply(self, action):
        """Play the flat (row-major) cell index ``action`` for the side to move."""
        row, col = divmod(action, self.board_size)
        self.step(row, col)

    def step(self, row, col):
        """Play cell (row, col) and append the resulting state to the history."""
        (board, player), reward, done, _ = self.env.step([row, col])
        # Tuples keep every history entry the same (immutable) type as the first one.
        self.history.append((board, player, float(reward), done))

    def store_search_statistics(self, root):
        """Append the root's normalised child visit counts as a board-shaped policy target.

        Args:
            root: search node whose ``children`` dict maps flat action index to a
                child node exposing ``visit_count``.
        """
        sum_visits = sum(child.visit_count for child in root.children.values())
        # Guard: with no visits at all, store an all-zero target instead of dividing by zero.
        visit_dist = [
            root.children[a].visit_count / sum_visits
            if sum_visits and a in root.children else 0.0
            for a in range(self.num_actions)
        ]
        visit_dist = np.asarray(visit_dist).reshape([self.board_size, self.board_size])
        self.child_visits.append(visit_dist)

    def make_image(self, state_index: int):
        """Return the ``(board, player)`` network input for history position ``state_index``."""
        board, player, _, _ = self.history[state_index]
        return board, player

    def make_target(self, state_index: int):
        """Return ``(value_target, policy_target)`` for history position ``state_index``.

        NOTE(review): the value target is the final recorded reward regardless of
        whose turn it was at ``state_index``. The network receives the side to move
        as an input plane, so this may be intentional — confirm against how
        alpha_zero.run_mcts interprets network values.
        """
        return self.terminal_value(), self.child_visits[state_index]

    def to_play(self):
        """Return the player whose turn it is in the latest recorded state."""
        return self.history[-1][1]


In [None]:
class Network(tf.keras.Model):
    """Shared conv/dense trunk feeding a scalar value head and a 36-way policy head.

    ``call`` takes a (batch, 6, 6, 2) float tensor: a board plane plus a plane
    filled with the side to move, as built by ``inference`` and
    ``create_train_data``. It returns ``(values, policy_logits)`` with shapes
    (batch,) and (batch, 36).
    """

    def __init__(self):
        super().__init__()
        layers = tf.keras.layers

        trunk = []
        # Conv blocks: the 'same' conv keeps 6x6; the 'valid' 3x3 conv shrinks it to 4x4.
        for conv_name, bn_name, padding in (('conv1', 'batch_norm_1', "same"),
                                            ('conv2', 'batch_norm_2', "valid")):
            trunk.extend([
                layers.Conv2D(128, 3, padding=padding, use_bias=False, name=conv_name),
                layers.BatchNormalization(name=bn_name),
                layers.ReLU(),
            ])
        trunk.append(layers.Flatten())
        # Dense blocks. Biases are off, presumably because BatchNormalization follows.
        # The 'batch_norm' layer names matter: update_weights excludes variables
        # whose name contains 'batch_norm' from weight decay.
        for fc_name, bn_name, units in (('fc1', 'batch_norm_3', 256),
                                        ('fc2', 'batch_norm_4', 128)):
            trunk.extend([
                layers.Dense(units, use_bias=False, name=fc_name),
                layers.BatchNormalization(name=bn_name),
                layers.ReLU(),
                layers.Dropout(0.3),
            ])
        self.representation = tf.keras.models.Sequential(trunk, name='representation')

        self.value = layers.Dense(1, activation='tanh', name='value')
        self.policy = layers.Dense(36, name='policy')

    def call(self, inputs, training):
        features = self.representation(inputs, training=training)
        # Value head emits (batch, 1); flatten it to (batch,).
        flat_values = tf.reshape(self.value(features), [-1])
        logits = self.policy(features)
        return flat_values, logits

    @tf.function(input_signature=(
        tf.TensorSpec(shape=[None, 6, 6], dtype=tf.float32),
        tf.TensorSpec(shape=[None], dtype=tf.float32),
    ))
    def inference(self, boards, players):
        """Evaluate raw (batch, 6, 6) boards with per-sample side-to-move values."""
        player_plane = tf.ones_like(boards) * players[:, tf.newaxis, tf.newaxis]
        stacked = tf.stack([boards, player_plane], axis=3)
        return self.call(stacked, training=False)

# Training

In [None]:
def create_train_data(dataset):
    """Convert replay samples into float32 training arrays.

    Args:
        dataset: iterable of ``((board, player), (value, policy))`` pairs, where
            ``board`` and ``policy`` are 2-D arrays.

    Returns:
        ``(inputs, value_targets, policy_targets)``. ``inputs`` stacks each board
        with a same-shaped plane filled with ``player`` on the last axis;
        ``policy_targets`` holds each policy flattened row-major.
    """
    samples = list(dataset)

    features = [
        np.stack([board, np.ones_like(board) * player], axis=2)
        for (board, player), _ in samples
    ]
    values = [value for _, (value, _) in samples]
    policies = [policy.flatten() for _, (_, policy) in samples]

    return (
        np.array(features, dtype=np.float32),
        np.array(values, dtype=np.float32),
        np.array(policies, dtype=np.float32),
    )

@tf.function
def update_weights(batch_inputs, batch_value_targets, batch_policy_targets):
    """Run one optimizer step on a training batch.

    Reads the module-level ``network``, ``optimizer`` and ``config``.

    Args:
        batch_inputs: float32 network inputs (first output of create_train_data).
        batch_value_targets: float32 value targets, one per sample.
        batch_policy_targets: float32 policy distributions, one row per sample.

    Returns:
        ``(loss, value_loss, policy_loss)``; ``loss`` also includes the L2
        weight-decay penalty.
    """
    with tf.GradientTape() as tape:
        predicted_values, predicted_logits = network(batch_inputs, training=True)
        # Targets first, matching Keras' (y_true, y_pred) order; MSE is symmetric.
        value_loss = tf.losses.mean_squared_error(batch_value_targets, predicted_values)
        per_sample_ce = tf.nn.softmax_cross_entropy_with_logits(
            labels=batch_policy_targets, logits=predicted_logits)
        policy_loss = tf.reduce_mean(per_sample_ce)

        loss = value_loss + policy_loss
        # L2 penalty on weights only: biases and batch-norm variables are skipped by name.
        decayed = [
            var for var in network.trainable_variables
            if 'bias' not in var.name and 'batch_norm' not in var.name
        ]
        for var in decayed:
            loss += config.weight_decay * tf.nn.l2_loss(var)

    params = network.trainable_variables
    optimizer.apply_gradients(zip(tape.gradient(loss, params), params))

    return loss, value_loss, policy_loss

In [None]:
# Fresh training state. update_weights reads `network`, `optimizer` and `config`
# from this namespace, so these names must stay as they are.
config = AlphaZeroConfig()
replay_buffer = ReplayBuffer(config)
network = Network()
optimizer = tf.keras.optimizers.Adam(learning_rate=config.learning_rate)

In [None]:
# Warm-up: store config.batch_size self-play games before any gradient step,
# presumably so the first sample_batch() call has enough positions to draw from.
for warmup_idx in range(config.batch_size):
    replay_buffer.save_game(play_game(config, Game(), network))

In [None]:
# Main loop: each step plays one new self-play game, then trains on a sampled batch.
LOG_EVERY = 10  # training steps between progress reports

loss_metric = tf.keras.metrics.Mean(name='loss')
value_metric = tf.keras.metrics.Mean(name='value_loss')
policy_metric = tf.keras.metrics.Mean(name='policy_loss')
start = time.time()

for i in range(config.training_steps):
    game = play_game(config, Game(), network)
    replay_buffer.save_game(game)

    batch_dataset = replay_buffer.sample_batch()
    batch_inputs, batch_value_targets, batch_policy_targets = create_train_data(batch_dataset)
    loss, value_loss, policy_loss = update_weights(
        batch_inputs, batch_value_targets, batch_policy_targets)

    loss_metric(loss)
    value_metric(value_loss)
    policy_metric(policy_loss)

    if (i + 1) % LOG_EVERY == 0:
        # Losses are averaged over the last LOG_EVERY steps; elapsed time covers the same window.
        print('{}/{}, Loss: {:.4f}, Value Loss: {:.4f}, Policy Loss: {:.4f}, Elapsed Time: {:.2f}'.format(
            i + 1,
            config.training_steps,
            loss_metric.result().numpy(),
            value_metric.result().numpy(),
            policy_metric.result().numpy(),
            time.time() - start
        ))

        for metric in (loss_metric, value_metric, policy_metric):
            metric.reset_states()
        start = time.time()

# Make sure the checkpoint directory exists before writing the weights.
tf.io.gfile.makedirs(config.base_dir)
network.save_weights(config.base_dir + 'weights')

# Play Game

In [None]:
import matplotlib.pyplot as plt

def show_board(board, player, legal_actions, ai_moves):
    """Draw a Reversi position with matplotlib.

    Args:
        board: 2-D array; cells equal to 1 are drawn as black stones, -1 as white.
        player: side to move (1 or -1); selects the outline colour of move hints.
        legal_actions: iterable of flat (row-major) cell indices to outline as
            candidate moves.
        ai_moves: iterable of flat cell indices to mark with a red triangle.

    NOTE: mutates the global matplotlib rcParams (axes/text/tick colours), which
    also affects figures drawn later in the session.
    """
    num_row, num_col = board.shape

    # Sets of (row, col) give O(1) membership tests inside the drawing loop.
    legal_cells = {divmod(int(a), num_col) for a in legal_actions}
    ai_cells = {divmod(int(a), num_col) for a in ai_moves}

    plt.rcParams['axes.facecolor'] = 'g'
    plt.rcParams['text.color'] = 'k'
    plt.rcParams['xtick.color'] = 'k'
    plt.rcParams['ytick.color'] = 'k'
    fig, ax = plt.subplots(figsize=(4, 4), facecolor='w')
    ax.set_title('●:○ = {}:{}'.format(
        np.sum(board == 1),
        np.sum(board == -1)
    ), y=-0.14, fontsize=12)

    # Grid lines sit half a cell off the integer cell centres.
    for y_pos in range(num_row):
        ax.axhline(y_pos - .5, color='k', lw=2)
    for x_pos in range(num_col):
        ax.axvline(x_pos - .5, color='k', lw=2)

    for y_pos in range(num_row):
        for x_pos in range(num_col):
            if board[y_pos, x_pos] == 1:
                ax.plot(x_pos, y_pos, 'o', color='k', ms=30)
            elif board[y_pos, x_pos] == -1:
                ax.plot(x_pos, y_pos, 'o', color='w', ms=30)

            if (y_pos, x_pos) in legal_cells:
                # Hollow circle in the colour of the side to move.
                ax.plot(
                    x_pos,
                    y_pos,
                    'o',
                    color='k' if player == 1 else 'w',
                    ms=30,
                    markerfacecolor='none'
                )

            if (y_pos, x_pos) in ai_cells:
                ax.plot(x_pos, y_pos, '^', color='r', ms=10)

    ax.set_xlim([-.5, num_col - .5])
    ax.set_ylim([-.5, num_row - .5])
    ax.set_aspect('equal', adjustable='box')
    # Label every column and row with its index (the original set y ticks twice).
    ax.set_xticks(range(num_col))
    ax.set_yticks(range(num_row))
    # Row 0 at the top and column labels on top, matching board[row, col] indexing.
    ax.invert_yaxis()
    ax.xaxis.set_ticks_position('top')
    ax.tick_params(length=0)
    fig.tight_layout()


In [None]:
# Rebuild the model for play and restore the weights written at the end of training.
config = AlphaZeroConfig()
network = Network()
weights_path = config.base_dir + 'weights'
network.load_weights(weights_path)

In [None]:
# Start a new game and show the opening position with the legal moves outlined.
game = Game()
show_board(
    board=game.env.board,
    player=game.env.player,
    legal_actions=game.legal_actions(),
    ai_moves=[],
)

In [None]:
# One turn vs. the AI: edit (human_row, human_col) before re-running this cell.
human_row, human_col = 4, 5

# Validate the human move first, so an illegal coordinate fails loudly here
# instead of being passed straight to env.step.
legal = game.legal_actions()
if human_row * game.board_size + human_col not in legal:
    raise ValueError('Illegal move ({}, {}); legal moves: {}'.format(
        human_row, human_col,
        [divmod(int(a), game.board_size) for a in legal]))
game.step(human_row, human_col)

ai_moves = []
# Only search if the human move did not end the game.
if not game.terminal():
    _, node = run_mcts(config, game, network)
    # Play the most-visited root child.
    best_action = max(node.children, key=lambda a: node.children[a].visit_count)
    game.apply(best_action)
    ai_moves = [best_action]

show_board(game.env.board, game.env.player, game.legal_actions(), ai_moves)