In [1]:
import datetime
import os
import random
import shutil

import gym
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tqdm import tqdm

from go_ai import data, metrics, mcts, models, policies

# Hyperparameters

In [2]:
# Side length of the square Go board; kept small for fast training and easy debugging.
BOARD_SIZE: int = 7

In [3]:
# Training schedule.
ITERATIONS: int = 256              # outer loop count: self-play -> update -> evaluate
EPISODES_PER_ITERATION: int = 256  # self-play games generated per iteration
BATCH_SIZE: int = 32               # approx. samples per batch (data is split into len // BATCH_SIZE chunks)

In [4]:
# Games played in each evaluation matchup (vs. random, greedy and the checkpoint model).
NUM_EVAL_GAMES: int = 32

In [5]:
# Step size for the Adam optimizer used during training.
LEARNING_RATE: float = 0.002

In [6]:
# Simulation count handed to every actor-critic policy as `mc_sims`.
# NOTE(review): presumably 0 means the raw network is used with no tree search — confirm in go_ai.policies.
MC_SIMS: int = 0

In [7]:
# Model weight files.
# - CHECKPOINT_PATH holds the current best model (self-play opponent / baseline).
# - TMP_WEIGHTS_PATH holds the freshly trained candidate while it is evaluated.
WEIGHTS_DIR = 'model_weights/'
CHECKPOINT_PATH = os.path.join(WEIGHTS_DIR, f'checkpoint_{BOARD_SIZE}x{BOARD_SIZE}.h5')
TMP_WEIGHTS_PATH = os.path.join(WEIGHTS_DIR, 'tmp.h5')
LOAD_SAVED_MODELS = True

# save_weights() does not create missing parent directories, so make sure it exists.
os.makedirs(WEIGHTS_DIR, exist_ok=True)

In [8]:
# Number of parallel workers passed to data.make_episodes for self-play and evaluation.
NUM_WORKERS: int = 4

In [9]:
# Directory self-play episodes are written to and then read back from each iteration.
EPISODES_DIR: str = './data/'

# Go Environment
Train on a small board with a heuristic reward for faster training and easier debugging.

In [10]:
# Go environment from the gym_go package, sized to match the model input.
go_env = gym.make(
    'gym_go:go-v0',
    size=BOARD_SIZE,
)

# Metrics and TensorBoard

In [11]:
# Start every run with an empty logs/ directory. It is recreated right away because
# later cells write straight into it (logs/model.png, logs/symmetries.jpg, ...).
# Pure Python instead of `!rm -rf` so it also works where no POSIX shell exists.
shutil.rmtree('./logs/', ignore_errors=True)
os.makedirs('./logs/', exist_ok=True)

Metrics

In [12]:
# Running metrics reported to TensorBoard: one Mean tracker per loss, plus the
# accuracy of the model's win prediction.
tb_metrics = {
    key: tf.keras.metrics.Mean(key, dtype=tf.float32)
    for key in ('val_loss', 'overall_loss', 'move_loss')
}
tb_metrics['pred_win_acc'] = tf.keras.metrics.Accuracy()

TensorBoard

In [13]:
# Each run logs to its own timestamped directory so runs stay separate in TensorBoard.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = f'logs/actor_critic/{current_time}/main'
summary_writer = tf.summary.create_file_writer(log_dir)

# Machine Learning Models

In [14]:
# Build the actor-critic network for this board size. The two string options are
# interpreted by go_ai.models (presumably a model variant name and an output
# activation — confirm there).
actor_critic = models.make_actor_critic(
    BOARD_SIZE,
    'val_net',
    'tanh',
)

In [15]:
# Render an architecture diagram. Requires pydot + graphviz; without them Keras
# only prints a warning (as seen in the output below) and writes no file.
_ = tf.keras.utils.plot_model(
    actor_critic,
    to_file='logs/model.png',
)

Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.


In [16]:
# Print a layer-by-layer summary (output shapes and parameter counts) of the network.
actor_critic.summary()

Model: "actor_critic"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
board (InputLayer)              [(None, 7, 7, 6)]    0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 7, 7, 64)     3520        board[0][0]                      
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 7, 7, 64)     256         conv2d[0][0]                     
__________________________________________________________________________________________________
re_lu (ReLU)                    (None, 7, 7, 64)     0           batch_normalization[0][0]        
_______________________________________________________________________________________

In [17]:
# Opponent network with the same architecture as actor_critic.
# clone_model builds new layers with freshly initialised weights, so copy the
# weights explicitly to make the opponent start as an exact copy.
opponent = tf.keras.models.clone_model(actor_critic)
opponent.set_weights(actor_critic.get_weights())

In [18]:
# Restore both networks from the checkpoint when requested and available;
# otherwise start from the freshly initialised weights.
if LOAD_SAVED_MODELS and os.path.exists(CHECKPOINT_PATH):
    actor_critic.load_weights(CHECKPOINT_PATH)
    opponent.load_weights(CHECKPOINT_PATH)
    print("Loaded saved models")
else:
    if LOAD_SAVED_MODELS:
        print("No checkpoint at {}; starting from fresh weights".format(CHECKPOINT_PATH))
    # clone_model does not copy weights, so sync the opponent explicitly.
    opponent.set_weights(actor_critic.get_weights())
    if not os.path.exists(CHECKPOINT_PATH):
        # Self-play and evaluation are configured with model_path=CHECKPOINT_PATH
        # (checkpoint_policy_args), so seed it with the initial weights.
        # An existing checkpoint is never overwritten here.
        os.makedirs(WEIGHTS_DIR, exist_ok=True)
        actor_critic.save_weights(CHECKPOINT_PATH)

Loaded saved models


# Policies

In [19]:
# Policy config for the candidate model: weights are read from TMP_WEIGHTS_PATH,
# which the training loop overwrites before each evaluation.
temp_policy_args = dict(
    mode='actor_critic',
    model_path=TMP_WEIGHTS_PATH,
    mc_sims=MC_SIMS,
)

In [20]:
# Policy config for the current best (checkpoint) model; used for self-play
# and as the baseline the candidate must beat.
checkpoint_policy_args = dict(
    mode='actor_critic',
    model_path=CHECKPOINT_PATH,
    mc_sims=MC_SIMS,
)

In [21]:
# Baseline opponent: random policy.
random_policy_args = dict(mode='random')

In [22]:
# Baseline opponent: greedy policy.
greedy_policy_args = dict(mode='greedy')

# Demo Trajectories

In [23]:
# Demo policy for the cells below: wraps the current network, starting from the
# environment's present state, with MC_SIMS passed as its simulation count.
state = go_env.get_state()
my_policy = policies.MctPolicy(actor_critic, state, MC_SIMS)

Symmetries

In [24]:
%%time
# Plot the model's responses across board symmetries of the current state
# (which transformations are used is defined in go_ai.metrics) and save the figure.
metrics.plot_symmetries(go_env, actor_critic, 'logs/symmetries.jpg')

CPU times: user 1.26 s, sys: 129 ms, total: 1.39 s
Wall time: 810 ms


Plot a whole game trajectory

In [25]:
%%time 
# Play one full self-play game with the demo policy; the trajectory is plotted in
# the next cell. get_symmetries=False presumably skips symmetry augmentation.
# NOTE(review): assigning to `_` shadows IPython's last-output variable.
_, traj, num_steps = data.self_play(go_env, policy=my_policy, get_symmetries=False)
print(f"{num_steps} steps")

83
CPU times: user 6.92 s, sys: 218 ms, total: 7.13 s
Wall time: 6.7 s


In [26]:
%%time
fig = metrics.state_responses(actor_critic, traj)
fig.savefig('logs/a_trajectory.jpg')
plt.close()

CPU times: user 7 s, sys: 241 ms, total: 7.24 s
Wall time: 6.85 s


# Train

In [27]:
# Optimizer shared across all training iterations.
actor_critic_opt = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

In [28]:
# A candidate replaces the checkpoint only if it wins more than this fraction
# of the evaluation games played against the current checkpoint.
ACCEPT_WIN_RATE = 0.6


def _batch_episode_data(np_data, batch_size):
    """Group parallel episode arrays into mini-batches.

    Every array in ``np_data`` is split with ``np.array_split`` into the same
    number of chunks, ``len(np_data[0]) // batch_size``. That count is clamped
    to at least one, so an iteration with fewer than ``batch_size`` samples
    yields a single small batch instead of making ``np.array_split`` raise
    ``ValueError``. Because the count is rounded down, chunks hold
    ``batch_size`` or more samples.

    Returns a list whose i-th element is a tuple of the i-th chunk of every
    array in ``np_data``.
    """
    num_batches = max(1, len(np_data[0]) // batch_size)
    batched_np_data = [np.array_split(datum, num_batches) for datum in np_data]
    return list(zip(*batched_np_data))


for iteration in range(ITERATIONS):
    # Self-play: the checkpoint model plays itself; episodes go to EPISODES_DIR.
    data.make_episodes(BOARD_SIZE, checkpoint_policy_args, checkpoint_policy_args,
                       EPISODES_PER_ITERATION, num_workers=NUM_WORKERS, outdir=EPISODES_DIR)
    # Read the episode data back and split it into mini-batches.
    np_data = data.episodes_from_dir(EPISODES_DIR)
    batched_mem = _batch_episode_data(np_data, BATCH_SIZE)

    # Optimize
    models.update_win_prediction(actor_critic, batched_mem, actor_critic_opt, iteration, tb_metrics)

    # Save the candidate; temp_policy_args points the evaluation games at this file.
    actor_critic.save_weights(TMP_WEIGHTS_PATH)

    # Evaluate the candidate against the random, greedy and checkpoint baselines.
    # make_episodes' return value is used as the first policy's win rate.
    rand_win_rate = data.make_episodes(BOARD_SIZE, temp_policy_args, random_policy_args,
                                       NUM_EVAL_GAMES, num_workers=NUM_WORKERS)
    greed_win_rate = data.make_episodes(BOARD_SIZE, temp_policy_args, greedy_policy_args,
                                        NUM_EVAL_GAMES, num_workers=NUM_WORKERS)
    opp_win_rate = data.make_episodes(BOARD_SIZE, temp_policy_args, checkpoint_policy_args,
                                      NUM_EVAL_GAMES, num_workers=NUM_WORKERS)

    stats = "{:.1f}%O, {:.1f}%R, {:.1f}%G".format(100*opp_win_rate,
                                                 100*rand_win_rate,
                                                 100*greed_win_rate)

    # Promote the candidate only if it clearly beats the checkpoint.
    if opp_win_rate > ACCEPT_WIN_RATE:
        actor_critic.save_weights(CHECKPOINT_PATH)
        print("{} Accepted new model".format(stats))
    else:
        print("{} Rejected new model".format(stats))
    # Both networks continue from the checkpoint: after an acceptance it is the
    # new model; after a rejection actor_critic reverts to the previous best.
    if os.path.exists(CHECKPOINT_PATH):
        opponent.load_weights(CHECKPOINT_PATH)
        actor_critic.load_weights(CHECKPOINT_PATH)

    # Log results and resets the metrics
    metrics.log_to_tensorboard(summary_writer, tb_metrics, iteration, go_env,
                               actor_critic, 'logs/a_trajectory.jpg')

Episodes: 100%|██████████| 256/256 [08:29<00:00,  1.99s/it, 49.8%]
Updating:   0%|          | 0/5183 [00:00<?, ?it/s]

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Updating: 100%|██████████| 5183/5183 [03:00<00:00, 28.64it/s, 81.9% 0.473L]
Episodes: 100%|██████████| 32/32 [00:29<00:00,  1.08it/s, 100.0%]
Episodes: 100%|██████████| 32/32 [01:08<00:00,  2.15s/it, 46.9%]
Episodes: 100%|██████████| 32/32 [01:09<00:00,  2.18s/it, 59.4%]


59.4%O, 100.0%R, 46.9%G Rejected new model


KeyboardInterrupt: 

# Evaluate

Play against our AI

In [None]:
# Play an interactive game against the model on a fresh environment.
# NOTE(review): `opponent_policy` and `human_policy` are not defined in any cell
# shown in this notebook, so this cell raises NameError on a fresh kernel.
# TODO: construct both first, e.g. a policies.MctPolicy wrapping `opponent`
# (as done for `my_policy` above) and whatever human-input policy go_ai provides.
go_env = gym.make('gym_go:go-v0', size=BOARD_SIZE)
data.pit(go_env, black_policy=opponent_policy, white_policy=human_policy)