In [1]:
import multiprocessing as mp
import os
import shutil

import gym
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from go_ai import data, metrics, models, policies

# Hyperparameters

In [2]:
BOARD_SIZE: int = 7  # side length of the square Go board

In [3]:
# Training schedule
ITERATIONS: int = 256               # outer training iterations
EPISODES_PER_ITERATION: int = 32    # self-play episodes generated per iteration
NUM_EVAL_GAMES: int = 32            # games played against each opponent per evaluation

In [4]:
EPISODES_DIR: str = './data/'  # self-play episodes are written here and read back for training

In [5]:
# Optimizer settings
BATCH_SIZE: int = 32
LEARNING_RATE: float = 0.001  # Adam step size

In [6]:
WEIGHTS_DIR = 'model_weights/'
# Best-so-far weights, namespaced by board size
CHECKPOINT_PATH = f'{WEIGHTS_DIR}checkpoint_{BOARD_SIZE}x{BOARD_SIZE}.h5'
# Candidate weights being trained; copied over the checkpoint when they win evaluation
TMP_WEIGHTS_PATH = f'{WEIGHTS_DIR}tmp.h5'
# Resume from CHECKPOINT_PATH rather than writing a freshly built model to it
LOAD_SAVED_MODELS = True

In [7]:
NUM_WORKERS: int = mp.cpu_count()  # one episode worker per CPU

# Go Environment
Train on a small board so that training is fast and debugging is easy.

In [8]:
# Go environment from the gym_go package, sized to the configured board
go_env = gym.make(
    'gym_go:go-v0',
    size=BOARD_SIZE,
)

# Metrics and Tensorboard

In [9]:
DEMO_TRAJECTORY_PATH: str = 'logs/a_trajectory.png'  # overwritten with a fresh trajectory plot during training

In [10]:
# Start every run with an empty ./logs/ directory (plots and model diagrams are
# written there below). Pure-Python equivalent of `rm -rf ./logs/ && mkdir ./logs/`,
# so it also works where those shell commands are unavailable (e.g. Windows).
shutil.rmtree('./logs/', ignore_errors=True)  # like `rm -rf`: no error if missing
os.makedirs('./logs/', exist_ok=True)

Metrics

In [11]:
# Running metrics passed to the optimizer and reset after every training
# iteration: mean 'val_loss' and 'move_loss', plus win-prediction accuracy.
tb_metrics = {
    name: tf.keras.metrics.Mean(name, dtype=tf.float32)
    for name in ('val_loss', 'move_loss')
}
tb_metrics['pred_win_acc'] = tf.keras.metrics.Accuracy()

Tensorboard

# Machine Learning Models

In [12]:
# Build the actor-critic Keras model for a BOARD_SIZE x BOARD_SIZE board.
actor_critic = models.make_actor_critic(BOARD_SIZE)

In [13]:
# Save a diagram of the network architecture. This needs pydot + graphviz;
# without them Keras prints a warning instead (see the output below).
# The trailing `;` suppresses the cell's output without rebinding IPython's `_`.
tf.keras.utils.plot_model(actor_critic, to_file='logs/model.png');

Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.


In [14]:
# Print a layer-by-layer summary (output shapes, parameter counts) of the network.
actor_critic.summary()

Model: "actor_critic"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
board (InputLayer)              [(None, 7, 7, 6)]    0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 7, 7, 64)     3520        board[0][0]                      
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 7, 7, 64)     256         conv2d[0][0]                     
__________________________________________________________________________________________________
re_lu (ReLU)                    (None, 7, 7, 64)     0           batch_normalization[0][0]        
_______________________________________________________________________________________

In [15]:
# Either resume from the saved checkpoint or seed a new checkpoint from the
# freshly built model. Falling back when the checkpoint file is missing means
# "Restart & Run All" works on a clean machine even with LOAD_SAVED_MODELS = True.
os.makedirs(WEIGHTS_DIR, exist_ok=True)  # writing .h5 weights fails if the directory is missing
if LOAD_SAVED_MODELS and os.path.exists(CHECKPOINT_PATH):
    actor_critic.load_weights(CHECKPOINT_PATH)
    print("Starting from checkpoint")
else:
    if LOAD_SAVED_MODELS:
        print(f"No checkpoint found at {CHECKPOINT_PATH}; initializing a new one")
    actor_critic.save_weights(CHECKPOINT_PATH)
    print("Initialized checkpoint and temp")

# Sync temp with checkpoint so training starts from the checkpoint's weights
actor_critic.save_weights(TMP_WEIGHTS_PATH)

Loaded saved models


# Policies

In [16]:
# Policy spec for the model being trained (weights read from TMP_WEIGHTS_PATH)
temp_policy_args = dict(
    mode='actor_critic',
    model_path=TMP_WEIGHTS_PATH,
)

In [17]:
# Policy spec for the best-so-far model (weights read from CHECKPOINT_PATH)
checkpoint_policy_args = dict(
    mode='actor_critic',
    model_path=CHECKPOINT_PATH,
)

In [18]:
# Baseline opponent: random policy
random_policy_args = dict(mode='random')

In [19]:
# Baseline opponent: greedy policy
greedy_policy_args = dict(mode='greedy')

# Demo and Time Games

Symmetries

In [20]:
%%time
go_env.reset()
action = (1, 2)
next_state, _, _, _ = go_env.step(action)
metrics.plot_symmetries(next_state, 'logs/symmetries.jpg')

CPU times: user 134 ms, sys: 9.07 ms, total: 143 ms
Wall time: 147 ms


Without replay memory

In [21]:
%%time
# Time one temp-vs-temp episode on a single worker. No `outdir` is passed, so
# presumably nothing is added to the replay memory. The displayed return value
# is used as a win rate in the training loop below.
data.make_episodes(BOARD_SIZE, temp_policy_args, temp_policy_args, 1, num_workers=1)

Episode worker: 1it [00:00,  1.81it/s]
Episodes: 100%|██████████| 1/1 [00:00<00:00, 1041.80it/s, 0.0% WIN]

CPU times: user 775 ms, sys: 20.1 ms, total: 795 ms
Wall time: 769 ms





0.0

With replay memory

In [22]:
%%time
# Same as the previous cell, but episodes are also written to EPISODES_DIR
# (the directory the training loop reads back as its replay memory).
data.make_episodes(BOARD_SIZE, temp_policy_args, temp_policy_args, 1, num_workers=1, 
                   outdir=EPISODES_DIR)

Episode worker: 1it [00:00,  2.10it/s]
Episodes: 100%|██████████| 1/1 [00:00<00:00, 1127.80it/s, 100.0% WIN]

CPU times: user 782 ms, sys: 22.7 ms, total: 805 ms
Wall time: 780 ms





1.0

In [23]:
%%time
fig = metrics.gen_traj_fig(go_env, TMP_WEIGHTS_PATH)
fig.savefig(DEMO_TRAJECTORY_PATH)
plt.close()

CPU times: user 12.9 s, sys: 459 ms, total: 13.4 s
Wall time: 12.4 s


# Train

In [24]:
# Adam optimizer created once and reused by every optimization step in the
# training loop below (so its internal state persists across iterations).
actor_critic_opt = tf.keras.optimizers.Adam(LEARNING_RATE)

In [None]:
# Evaluation cadence and promotion thresholds (win rate of temp vs. checkpoint).
EVAL_EVERY = 2            # evaluate every EVAL_EVERY iterations
ACCEPT_WIN_RATE = 0.6     # strictly above: promote temp weights to the checkpoint
CONTINUE_WIN_RATE = 0.5   # strictly above (but not accepted): keep training temp weights


def _batch_episode_data(np_data, batch_size):
    """Split aligned episode arrays into mini-batches.

    Args:
        np_data: sequence of arrays as returned by ``data.episodes_from_dir``;
            assumed to share the same length along axis 0 (TODO confirm).
        batch_size: target samples per batch. ``np.array_split`` spreads the
            remainder, so actual batches may be slightly larger.

    Returns:
        List of tuples; tuple ``i`` holds the ``i``-th chunk of every array.
    """
    # `len // batch_size` is 0 when there are fewer samples than batch_size, and
    # np.array_split raises ValueError for 0 sections -- always make >= 1 batch.
    num_batches = max(1, len(np_data[0]) // batch_size)
    batched = [np.array_split(datum, num_batches) for datum in np_data]
    return list(zip(*batched))


for iteration in range(ITERATIONS):
    # Self-play with the temp model; episodes are written to EPISODES_DIR
    data.make_episodes(BOARD_SIZE, temp_policy_args, temp_policy_args, EPISODES_PER_ITERATION, 
                       num_workers=NUM_WORKERS, outdir=EPISODES_DIR)
    # Read the episode data back from EPISODES_DIR and split it into batches
    np_data = data.episodes_from_dir(EPISODES_DIR)
    batched_mem = _batch_episode_data(np_data, BATCH_SIZE)

    # Optimize the weights stored at TMP_WEIGHTS_PATH, then reset the running metrics
    models.optimize_actor_critic(TMP_WEIGHTS_PATH, BOARD_SIZE, batched_mem, actor_critic_opt, tb_metrics)
    metrics.reset_metrics(tb_metrics)

    if (iteration + 1) % EVAL_EVERY != 0:
        continue

    # Plot samples of states and response heatmaps
    fig = metrics.gen_traj_fig(go_env, TMP_WEIGHTS_PATH)
    fig.savefig(DEMO_TRAJECTORY_PATH)
    plt.close(fig)  # close this figure explicitly so figures don't pile up across iterations

    # Evaluate against checkpoint model and other baselines
    rand_win_rate = data.make_episodes(BOARD_SIZE, temp_policy_args, random_policy_args, 
                                       NUM_EVAL_GAMES, num_workers=NUM_WORKERS)
    greed_win_rate = data.make_episodes(BOARD_SIZE, temp_policy_args, greedy_policy_args, 
                                        NUM_EVAL_GAMES, num_workers=NUM_WORKERS)
    opp_win_rate = data.make_episodes(BOARD_SIZE, temp_policy_args, checkpoint_policy_args, 
                                      NUM_EVAL_GAMES, num_workers=NUM_WORKERS)

    stats = f"{100*opp_win_rate:.1f}%O, {100*greed_win_rate:.1f}%G, {100*rand_win_rate:.1f}%R"

    if opp_win_rate > ACCEPT_WIN_RATE:
        # Clearly better than the checkpoint: promote the temp weights
        shutil.copy(TMP_WEIGHTS_PATH, CHECKPOINT_PATH)
        print(f"{stats} Accepted new model")
    elif opp_win_rate > CONTINUE_WIN_RATE:
        print(f"{stats} Continuing to train current weights")
    else:
        # Not better than the checkpoint: roll the temp weights back to it.
        # NOTE(review): actor_critic_opt keeps its Adam state across this
        # rollback -- confirm that is intended.
        shutil.copy(CHECKPOINT_PATH, TMP_WEIGHTS_PATH)
        print(f"{stats} Rejected new model")

Episodes: 100%|██████████| 32/32 [00:08<00:00,  3.96it/s, 67.2% WIN]
Updating:   0%|          | 0/56 [00:00<?, ?it/s]

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Updating:  34%|███▍      | 19/56 [01:01<02:00,  3.26s/it, 60.4% ACC 0.949VL 2.671ML]

# Evaluate

Play against our AI

In [None]:
# Play a game against the AI in a fresh environment: the AI plays black and
# the human plays white.
# FIXME(review): `opponent_policy` and `human_policy` are not defined anywhere in
# this notebook, so this cell raises NameError on a fresh kernel. Define them
# first (presumably via the `policies` module, e.g. from `checkpoint_policy_args`);
# the exact constructor API is not visible here.
go_env = gym.make('gym_go:go-v0', size=BOARD_SIZE)
data.pit(go_env, black_policy=opponent_policy, white_policy=human_policy)