In [1]:
import collections
import datetime
import itertools
import os
import random
from functools import reduce

import gym
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tqdm import tqdm_notebook

from go_ai import data, metrics, mcts, models

In [2]:
from absl import logging
# Clear absl's private `_warn_preinit_stderr` flag.
# NOTE(review): presumably this silences absl's "logging before flag parsing"
# warning on stderr; it is a private attribute, so confirm it still exists when
# upgrading absl.
logging._warn_preinit_stderr = 0
# Emit INFO-level messages and above; the `logging.debug(...)` calls in the
# training loop stay silent at this verbosity.
logging.set_verbosity(logging.INFO)

# Hyperparameters

In [3]:
# Board side length: passed as `size` to the Go environment and as the first
# argument to `models.make_actor_critic`; MAX_STEPS is also derived from it.
BOARD_SIZE = 5

In [4]:
# Number of outer training iterations (self-play -> update -> evaluate).
ITERATIONS = 256
# Self-play games generated in each iteration before the model is updated.
EPISODES_PER_ITERATION = 128
# Passed as `max_steps` to self-play and evaluation: twice the number of
# board points.
MAX_STEPS = 2 * BOARD_SIZE**2
# Target mini-batch size; the replay memory is split into
# `len(replay_mem) // BATCH_SIZE` roughly equal batches.
BATCH_SIZE = 32

In [5]:
# Passed as `num_games` to `metrics.evaluate` when comparing against the opponent.
NUM_EVAL_GAMES = 32
# Evaluation runs whenever (iteration + 1) is a multiple of this value;
# 1 means every iteration.
ITERATIONS_PER_EVAL = 1

In [6]:
# Learning rate handed to the Adam optimizer that trains the actor-critic.
LEARNING_RATE = 1e-3

In [7]:
# Value passed as `mc_sims` to self-play, evaluation and the interactive game.
MC_SIMS = 0


def TEMP_FUNC(x):
    """Return 1 while ``x`` is below 4, and 0 from then on.

    Passed as ``temp_func`` to self-play and evaluation.
    NOTE(review): ``x`` is presumably the move/step index and the result a
    sampling temperature -- confirm against ``go_ai.data.self_play``.
    """
    if x < 4:
        return 1
    return 0

In [8]:
# Root directory for persisted weights (the trailing slash is significant: the
# path below is built by plain string substitution).
WEIGHTS_DIR = 'model_weights/'
# HDF5 file the actor-critic is saved to / restored from.
ACTOR_CRITIC_PATH = '{}tmp/actor_critic.h5'.format(WEIGHTS_DIR)
# When True, both networks are initialised from ACTOR_CRITIC_PATH.
LOAD_SAVED_MODELS = False

# Go Environment
Train on a small board with a heuristic reward for fast training and efficient debugging.

In [9]:
# Create the Go environment (gym id 'gym_go:go-v0') on a BOARD_SIZE x BOARD_SIZE board.
go_env = gym.make('gym_go:go-v0', size=BOARD_SIZE)

# Metrics and Tensorboard

In [10]:
# WARNING: deletes the whole ./logs/ directory (TensorBoard runs and saved
# figures from earlier runs) so that each run starts from a clean slate.
!rm -rf ./logs/

Metrics

In [11]:
# Running-mean trackers, keyed by the name they are reported under.
# (The 'explore_weight' / 'explore_loss' metrics are currently disabled.)
tb_metrics = {
    name: tf.keras.metrics.Mean(name, dtype=tf.float32)
    for name in ('val_loss', 'overall_loss', 'num_steps', 'move_loss')
}

# Accuracy tracker stored under the 'pred_win_acc' key.
tb_metrics['pred_win_acc'] = tf.keras.metrics.Accuracy()

Tensorboard

In [12]:
# Timestamp each run so that successive runs land in separate TensorBoard folders.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = 'logs/actor_critic/' + current_time + '/main'
summary_writer = tf.summary.create_file_writer(log_dir)

# Machine Learning Models

In [13]:
# Build the network that is trained below.
# NOTE(review): 'val_net' and 'tanh' are presumably the network mode and the
# value-head activation -- confirm against `go_ai.models.make_actor_critic`.
actor_critic = models.make_actor_critic(BOARD_SIZE, 'val_net', 'tanh')

In [14]:
# Wrap the actor-critic with `models.make_mcts_forward`.
# NOTE(review): `mct_forward` is not referenced again in the visible notebook;
# confirm whether it is still needed.
mct_forward = models.make_mcts_forward(actor_critic)

In [15]:
# Write a diagram of the network to logs/model.png; the result is bound to `_`
# so the notebook does not render it as cell output.
_ = tf.keras.utils.plot_model(actor_critic, to_file='logs/model.png')

In [16]:
# Print the layer-by-layer summary (output shapes and parameter counts).
actor_critic.summary()

Model: "actor_critic"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
board (InputLayer)              [(None, 5, 5, 6)]    0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 5, 5, 64)     3520        board[0][0]                      
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 5, 5, 64)     256         conv2d[0][0]                     
__________________________________________________________________________________________________
re_lu (ReLU)                    (None, 5, 5, 64)     0           batch_normalization[0][0]        
_______________________________________________________________________________________

In [17]:
# The opponent is the "former self" that each newly trained model must beat.
# `clone_model` only copies the architecture and creates freshly initialised
# weights, so copy the weights explicitly to make the opponent start out
# identical to the untrained actor-critic.
opponent = tf.keras.models.clone_model(actor_critic)
opponent.set_weights(actor_critic.get_weights())

In [18]:
# Optionally resume: both networks start from the same saved checkpoint.
if LOAD_SAVED_MODELS:
    for network in (actor_critic, opponent):
        network.load_weights(ACTOR_CRITIC_PATH)
    logging.info("Loaded saved models")

# Demo Trajectories

Symmetries

In [19]:
# NOTE(review): presumably renders the model's outputs on symmetric variants of
# a board state and saves the figure to the given path -- confirm against
# `go_ai.metrics.plot_symmetries`.
metrics.plot_symmetries(go_env, actor_critic, 'logs/symmetries.png')

Plot a whole game trajectory

In [20]:
%%time 
# Time a single self-play game with the untrained network. `temp_func` is
# constant 1 here (instead of TEMP_FUNC) and symmetry generation is turned off.
# Only the wall time matters: `traj` is not used later in the visible notebook.
traj, _ = data.self_play(go_env, policy=actor_critic, max_steps=MAX_STEPS, mc_sims=MC_SIMS, 
                             temp_func=lambda x: 1, get_symmetries=False)

CPU times: user 1.03 s, sys: 24.4 ms, total: 1.05 s
Wall time: 989 ms


In [21]:
# Generate a figure for one game played by the current network (constant
# `temp_func` of 1) and save it under logs/.
fig = metrics.gen_traj_fig(go_env, actor_critic, lambda x: 1, MAX_STEPS, MC_SIMS)
fig.savefig('logs/a_trajectory.png')
# Close this figure by handle: a bare `plt.close()` closes whichever figure is
# "current", which is not guaranteed to be `fig`.
plt.close(fig)

# Train

In [22]:
# Optimizer used for every actor-critic update in the training loop.
actor_critic_opt = tf.keras.optimizers.Adam(LEARNING_RATE)
# Buffer of self-play samples; the training loop fills it each iteration and
# clears it at the end of that iteration.
replay_mem = []

In [23]:
# Main training loop. Each iteration:
#   1. generates EPISODES_PER_ITERATION self-play games with the current network,
#   2. updates the network on the shuffled, batched samples from those games,
#   3. periodically evaluates it against `opponent` and promotes it when it wins
#      more than 60% of the evaluation games,
#   4. logs metrics to TensorBoard and empties the replay memory.
for iteration in tqdm_notebook(range(ITERATIONS), desc='Iteration'):
    # Self-play: collect fresh experience with the current network
    logging.debug("Playing games")
    episode_pbar = tqdm_notebook(range(EPISODES_PER_ITERATION), desc='Episode', leave=False)
    for episode in episode_pbar:
        trajectory, num_steps = data.self_play(go_env, policy=actor_critic, max_steps=MAX_STEPS,
                                               mc_sims=MC_SIMS, temp_func=TEMP_FUNC)

        replay_mem.extend(trajectory)
        tb_metrics['num_steps'].update_state(num_steps)

    # Update the model on shuffled mini-batches of this iteration's samples
    logging.debug("Updating model...")
    random.shuffle(replay_mem)
    np_data = data.replay_mem_to_numpy(replay_mem)
    # np.array_split raises ValueError when asked for 0 sections, which would
    # happen whenever fewer than BATCH_SIZE samples were collected; always use
    # at least one batch.
    num_batches = max(1, len(replay_mem) // BATCH_SIZE)
    batched_np_data = [np.array_split(datum, num_batches) for datum in np_data]
    # Regroup from "one list of batches per field" to "one tuple of fields per batch".
    batched_mem = list(zip(*batched_np_data))
    models.update_win_prediction(actor_critic, batched_mem, actor_critic_opt, iteration, tb_metrics)

    # Evaluate against the previously accepted model
    if (iteration+1) % ITERATIONS_PER_EVAL == 0:
        win_rate = metrics.evaluate(go_env, actor_critic, opponent, max_steps=MAX_STEPS,
                                    num_games=NUM_EVAL_GAMES, mc_sims=MC_SIMS, temp_func=TEMP_FUNC)
        if win_rate > 0.6:
            # Make sure the checkpoint directory exists: saving to an .h5 path
            # does not create missing parent directories, and failing here
            # would throw away the training done so far.
            os.makedirs(os.path.dirname(ACTOR_CRITIC_PATH), exist_ok=True)
            actor_critic.save_weights(ACTOR_CRITIC_PATH)
            # The accepted model becomes the new opponent.
            opponent.load_weights(ACTOR_CRITIC_PATH)
            logging.info("{:.1f}% Accepted new model".format(100*win_rate))
        else:
            logging.info("{:.1f}% Rejected new model".format(100*win_rate))

    # Log the results (per the original note, this also resets the metrics)
    logging.debug("Logging metrics to tensorboard...")
    metrics.log_to_tensorboard(summary_writer, tb_metrics, iteration, go_env, actor_critic)

    # Reset memory so the next iteration trains only on its own games
    replay_mem.clear()

HBox(children=(IntProgress(value=0, description='Iteration', max=256, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Episode', max=128, style=ProgressStyle(description_width='ini…

W0917 12:11:56.989238 4451120576 deprecation.py:323] From /usr/local/lib/python3.7/site-packages/tensorflow_core/python/ops/math_grad.py:1394: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


HBox(children=(IntProgress(value=0, description='Evaluating against former self', max=32, style=ProgressStyle(…

I0917 12:13:42.537729 4451120576 <ipython-input-23-345092bc7a6c>:29] 42.2% Rejected new model


HBox(children=(IntProgress(value=0, description='Episode', max=128, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Evaluating against former self', max=32, style=ProgressStyle(…

I0917 12:15:16.508399 4451120576 <ipython-input-23-345092bc7a6c>:27] 67.2% Accepted new model


HBox(children=(IntProgress(value=0, description='Episode', max=128, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Evaluating against former self', max=32, style=ProgressStyle(…

I0917 12:17:33.847632 4451120576 <ipython-input-23-345092bc7a6c>:29] 29.7% Rejected new model


HBox(children=(IntProgress(value=0, description='Episode', max=128, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Evaluating against former self', max=32, style=ProgressStyle(…

I0917 12:22:48.212754 4451120576 <ipython-input-23-345092bc7a6c>:29] 54.7% Rejected new model


HBox(children=(IntProgress(value=0, description='Episode', max=128, style=ProgressStyle(description_width='ini…

KeyboardInterrupt: 

# Evaluate

Play against our AI

In [None]:
# Start from a fresh environment for the interactive game.
go_env = gym.make('gym_go:go-v0', size=BOARD_SIZE)
# Play against `opponent`, i.e. the most recently *accepted* model rather than
# the possibly rejected `actor_critic`.
# NOTE(review): the trailing 0 is presumably a temperature (greedy play) --
# confirm against `go_ai.data.play_against`.
data.play_against(opponent, go_env, MC_SIMS, 0)