In [1]:
import collections
import datetime
import itertools
import os
import random
import shutil
from functools import reduce

import gym
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tqdm import tqdm_notebook

from go_ai import rl_utils, metrics, mcts

In [2]:
# Configure absl logging for this notebook.
from absl import logging
# Overrides a private absl attribute; presumably silences absl's pre-initialisation
# stderr warning — NOTE(review): confirm against the installed absl version.
logging._warn_preinit_stderr = 0
# INFO verbosity: the logging.debug(...) calls in the training loop are not emitted.
logging.set_verbosity(logging.INFO)

# Hyperparameters

In [3]:
# Board side length; passed to gym.make(size=...) and rl_utils.make_actor_critic,
# and used to derive MAX_STEPS.
BOARD_SIZE = 5

In [4]:
# Number of outer training-loop iterations (self-play, model update, periodic evaluation).
ITERATIONS = 256
# Self-play games generated in each iteration before the model update.
EPISODES_PER_ITERATION = 32
# Passed as max_steps to self-play and evaluation: twice the number of board points.
MAX_STEPS = 2 * BOARD_SIZE**2
# Target mini-batch size; the replay memory is split into len(replay_mem) // BATCH_SIZE chunks.
BATCH_SIZE = 32

In [5]:
# Games played per evaluation against the stored opponent (num_games in metrics.evaluate).
NUM_EVAL_GAMES = 32
# Evaluation runs whenever (iteration + 1) is a multiple of this value.
ITERATIONS_PER_EVAL = 4

In [6]:
# First two positional arguments handed to tf.keras.optimizers.Adam in the Optimization section.
LEARNING_RATE = 2e-3
BETA_1 = 0.9

In [7]:
# Forwarded as mc_sims to self_play / evaluate and positionally to play_against;
# presumably the number of Monte Carlo search simulations per move — TODO confirm in rl_utils.
MC_SIMS = 0
# Forwarded as temp_threshold to self_play and evaluate; its exact semantics live in rl_utils.
TEMP_THRESHOLD = 4

In [8]:
# When True, both actor_critic and opponent load their weights from ACTOR_CRITIC_PATH.
LOAD_SAVED_MODELS = False
# Prefix for weight files; ACTOR_CRITIC_PATH is built by plain string concatenation,
# so the trailing slash is required.
WEIGHTS_DIR = 'model_weights/'

# Go Environment
Train on a small board with a heuristic reward for fast training and efficient debugging.

In [9]:
# Shared Go environment used for self-play, evaluation and the demo plots.
go_env = gym.make('gym_go:go-v0', size=BOARD_SIZE)

# Metrics and Tensorboard

In [17]:
# Delete the ./logs/ tree left by earlier runs (TensorBoard event files and the
# figures that later cells write under logs/).
# shutil.rmtree replaces the `!rm -rf` shell escape so the cell also works where
# `rm` is unavailable; ignore_errors=True mirrors `-f` (a missing directory is not an error).
shutil.rmtree('./logs/', ignore_errors=True)

Metrics

In [18]:
# Running-mean trackers, keyed by the name they are reported under.
tb_metrics = {
    name: tf.keras.metrics.Mean(name, dtype=tf.float32)
    for name in ('val_loss', 'overall_loss', 'num_steps', 'move_loss')
}

# Agreement between the sign of the critic's value estimate and the game outcome
# (updated in update_actor_critic).
tb_metrics['pred_win_acc'] = tf.keras.metrics.Accuracy()

Tensorboard

In [19]:
# Timestamp (YYYYMMDD-HHMMSS) that gives every run its own log directory.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = 'logs/actor_critic/{}/main'.format(current_time)
# Writer handed to metrics.log_to_tensorboard at the end of every training iteration.
summary_writer = tf.summary.create_file_writer(log_dir)    

# Machine Learning Models

In [10]:
# File the accepted actor-critic weights are saved to and reloaded from.
ACTOR_CRITIC_PATH = WEIGHTS_DIR + 'tmp/actor_critic.h5'
# Writing an .h5 file does not create missing parent directories, so make sure the
# target directory exists before the training loop first calls save_weights.
os.makedirs(os.path.dirname(ACTOR_CRITIC_PATH), exist_ok=True)

In [11]:
# Build the network that is trained below. The string arguments are interpreted by
# rl_utils.make_actor_critic; 'tanh' presumably selects the value-head activation
# (the critic output is later thresholded at 0) — TODO confirm in rl_utils.
actor_critic = rl_utils.make_actor_critic(BOARD_SIZE, 'val_net', 'tanh')

In [12]:
# Forward function derived from the network by rl_utils.make_mcts_forward.
# NOTE(review): mct_forward is not referenced by any other cell visible in this notebook.
mct_forward = rl_utils.make_mcts_forward(actor_critic)

In [38]:
# Render the model graph to logs/model.png; assigning to `_` keeps the return value
# from being echoed as the cell output.
_ = tf.keras.utils.plot_model(actor_critic, to_file='logs/model.png')

In [14]:
# Print the layer-by-layer summary (output shapes and parameter counts).
actor_critic.summary()

Model: "actor_critic"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
board (InputLayer)              [(None, 5, 5, 6)]    0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 5, 5, 64)     3520        board[0][0]                      
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 5, 5, 64)     36928       conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 5, 5, 2)      130         conv2d_1[0][0]                   
_______________________________________________________________________________________

In [15]:
# Reference copy of the network that new versions are evaluated against.
# clone_model reproduces only the architecture and creates freshly initialised
# weights, so copy the current weights explicitly; otherwise the first evaluation
# would be played against an unrelated random network rather than the previous model.
opponent = tf.keras.models.clone_model(actor_critic)
opponent.set_weights(actor_critic.get_weights())

In [16]:
# Optionally resume from the last accepted checkpoint; both networks start from
# the same saved weights.
if LOAD_SAVED_MODELS:
    for _net in (actor_critic, opponent):
        _net.load_weights(ACTOR_CRITIC_PATH)
    logging.info("Loaded saved models")

# Demo Trajectories

Symmetries

In [20]:
# Write the symmetry visualisation produced by metrics.plot_symmetries to logs/symmetries.png.
metrics.plot_symmetries(go_env, actor_critic, 'logs/symmetries.png')

Plot a whole game trajectory

In [41]:
%%time 
# Time a single self-play game. The second return value is discarded, and `traj`
# is not used by any later cell visible here.
# NOTE(review): temp_threshold is hard-coded to 100 here rather than TEMP_THRESHOLD.
traj, _ = rl_utils.self_play(go_env, policy=actor_critic, max_steps=MAX_STEPS, mc_sims=MC_SIMS, 
                             temp_threshold=100, get_symmetries=False)

CPU times: user 1.24 s, sys: 23.9 ms, total: 1.26 s
Wall time: 1.2 s


In [None]:
# Generate a figure of one full game trajectory and write it to disk.
fig = metrics.gen_traj_fig(go_env, actor_critic, MAX_STEPS, MC_SIMS)
# matplotlib figures are written with savefig(); Figure has no save() method.
# NOTE(review): assumes gen_traj_fig returns a matplotlib Figure — confirm in go_ai.metrics.
fig.savefig('logs/a_trajectory.png')

# Optimization

In [24]:
# Adam optimizer shared by every gradient step in update_actor_critic.
actor_critic_opt = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE, beta_1=BETA_1)

In [25]:
def update_actor_critic(batched_mem, iteration, metrics):
    """
    Runs one optimisation pass of the actor-critic over the batched replay memory.

    For each mini-batch a single gradient step is taken on the sum of the
    move (actor) loss and the value (critic) loss. Losses and win-prediction
    accuracy are accumulated in the module-level `tb_metrics` dict; the
    module-level `actor_critic` and `actor_critic_opt` are updated in place.

    :param batched_mem: iterable of 7-tuples
        (states, actions, next_states, rewards, terminals, wins, mcts_action_probs);
        only states, wins and mcts_action_probs are used here
    :param iteration: unused; kept so existing callers keep working
    :param metrics: unused; kept so existing callers keep working. Note that this
        name shadows the imported `metrics` module inside the function
    :return: None
    """
    binary_cross_entropy = tf.keras.losses.BinaryCrossentropy()
    batch_iter = tqdm_notebook(batched_mem, desc='Updating actor_critic', leave=False)
    for states, actions, next_states, rewards, terminals, wins, mcts_action_probs in batch_iter:
        # Add a trailing axis so wins lines up with the critic output (checked by the assert below)
        wins = wins[:, np.newaxis]
        with tf.GradientTape() as tape:
            move_prob_distrs, state_vals = rl_utils.forward_pass(states, actor_critic, training=True)

            # Actor: match the predicted move distribution to the search-derived targets
            move_loss = binary_cross_entropy(mcts_action_probs, move_prob_distrs)

            # Critic: mean squared error between the game outcome and the predicted value
            assert state_vals.shape == wins.shape
            val_loss = tf.reduce_mean((wins - state_vals)**2)

            overall_loss = val_loss + move_loss

        tb_metrics['move_loss'].update_state(move_loss)
        tb_metrics['val_loss'].update_state(val_loss)
        tb_metrics['overall_loss'].update_state(overall_loss)

        # Win-prediction accuracy: negative outcomes are mapped to 0 and compared
        # with whether the predicted value is positive
        wins_01 = np.copy(wins)
        wins_01[wins_01 < 0] = 0
        tb_metrics['pred_win_acc'].update_state(wins_01, state_vals > 0)

        # Compute and apply gradients
        gradients = tape.gradient(overall_loss, actor_critic.trainable_variables)
        actor_critic_opt.apply_gradients(zip(gradients, actor_critic.trainable_variables))

# Train

In [26]:
# Events collected from self-play; filled and then cleared once per training iteration.
replay_mem = []

In [None]:
for iteration in tqdm_notebook(range(ITERATIONS), desc='Iteration'):
    # Generate training data through self-play with the current network
    logging.debug("Playing games")
    episode_pbar = tqdm_notebook(range(EPISODES_PER_ITERATION), desc='Episode', leave=False)
    for episode in episode_pbar:
        trajectory, num_steps = rl_utils.self_play(go_env, policy=actor_critic, max_steps=MAX_STEPS,
                                                   mc_sims=MC_SIMS, temp_threshold=TEMP_THRESHOLD)

        replay_mem.extend(trajectory)
        tb_metrics['num_steps'].update_state(num_steps)

    # Save the first and last events for logging to tensorboard. This has to happen
    # BEFORE the shuffle below; afterwards index 0 / -1 are just random samples.
    first_event, last_event = replay_mem[0], replay_mem[-1]

    # Update the model on the shuffled memory
    logging.debug("Updating model...")
    random.shuffle(replay_mem)
    np_data = rl_utils.replay_mem_to_numpy(replay_mem)
    # np.array_split raises on 0 sections, which would happen whenever fewer than
    # BATCH_SIZE events were collected; always use at least one batch
    num_batches = max(1, len(replay_mem) // BATCH_SIZE)
    batched_np_data = [np.array_split(datum, num_batches) for datum in np_data]
    batched_mem = list(zip(*batched_np_data))
    update_actor_critic(batched_mem, iteration, metrics)

    # Evaluate against the last accepted model
    if (iteration+1) % ITERATIONS_PER_EVAL == 0:
        win_rate = metrics.evaluate(go_env, actor_critic, opponent, max_steps=MAX_STEPS,
                                    num_games=NUM_EVAL_GAMES, mc_sims=MC_SIMS, temp_threshold=TEMP_THRESHOLD)
        # The new weights become the opponent only if they win more than 60% of the games
        if win_rate > 0.6:
            actor_critic.save_weights(ACTOR_CRITIC_PATH)
            opponent.load_weights(ACTOR_CRITIC_PATH)
            logging.info("{:.1f}% Accepted new model".format(100*win_rate))
        else:
            logging.info("{:.1f}% Rejected new model".format(100*win_rate))

    # Log results (per the original note, log_to_tensorboard also resets the metrics)
    logging.debug("Logging metrics to tensorboard...")
    metrics.log_to_tensorboard(summary_writer, tb_metrics, iteration, [first_event] + replay_mem[:6] + [last_event],
                               actor_critic)

    # Reset memory so the next iteration trains only on fresh self-play data
    replay_mem.clear()

HBox(children=(IntProgress(value=0, description='Iteration', max=256, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Updating actor_critic', max=714, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Updating actor_critic', max=149, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Updating actor_critic', max=364, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Updating actor_critic', max=291, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='Evaluating against former self', max=32, style=ProgressStyle(…

I0916 21:19:00.691926 4496848320 <ipython-input-43-e5946eafe30a>:30] 100.0% Accepted new model


HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Updating actor_critic', max=390, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Updating actor_critic', max=196, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Updating actor_critic', max=408, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Updating actor_critic', max=368, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='Evaluating against former self', max=32, style=ProgressStyle(…

I0916 21:23:31.183088 4496848320 <ipython-input-43-e5946eafe30a>:32] 0.0% Rejected new model


HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Updating actor_critic', max=239, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Updating actor_critic', max=408, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Updating actor_critic', max=349, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Updating actor_critic', max=408, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='Evaluating against former self', max=32, style=ProgressStyle(…

I0916 21:28:15.362138 4496848320 <ipython-input-43-e5946eafe30a>:32] 0.0% Rejected new model


HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Updating actor_critic', max=299, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Updating actor_critic', max=326, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Updating actor_critic', max=391, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Updating actor_critic', max=402, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='Evaluating against former self', max=32, style=ProgressStyle(…

I0916 21:32:45.219352 4496848320 <ipython-input-43-e5946eafe30a>:32] 14.1% Rejected new model


HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Updating actor_critic', max=268, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Updating actor_critic', max=332, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Updating actor_critic', max=368, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Updating actor_critic', max=126, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='Evaluating against former self', max=32, style=ProgressStyle(…

I0916 21:36:40.088299 4496848320 <ipython-input-43-e5946eafe30a>:32] 6.2% Rejected new model


HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Updating actor_critic', max=351, style=ProgressStyle(descript…

HBox(children=(IntProgress(value=0, description='Episode', max=32, style=ProgressStyle(description_width='init…

# Evaluate

Play against our AI

In [None]:
# Create a fresh environment for the interactive game (rebinds the go_env used for training).
go_env = gym.make('gym_go:go-v0', size=BOARD_SIZE)
# Play against `opponent`, i.e. the last accepted weights rather than the network being trained.
# NOTE(review): the meaning of the trailing positional 0 is defined by rl_utils.play_against — confirm.
rl_utils.play_against(opponent, go_env, MC_SIMS, 0)