In [1]:
import multiprocessing as mp
import os
import shutil

import gym
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from go_ai import data, metrics, policies
from go_ai.models import value_model

# Hyperparameters

In [2]:
BOARD_SIZE = 7  # side length of the square Go board; kept small for quick training/debugging

In [3]:
# Training schedule: number of outer training iterations, self-play episodes
# generated per iteration, and games played in each evaluation match.
ITERATIONS, EPISODES_PER_ITERATION, NUM_EVAL_GAMES = 256, 32, 32

In [4]:
EPISODES_DIR = './data/'  # self-play episodes are written here and read back for training

In [5]:
# Optimization settings.
BATCH_SIZE = 32        # target number of samples per mini-batch
LEARNING_RATE = 0.002  # passed to value_model.optimize_val_net

In [6]:
# Weight files: the last accepted model ("checkpoint") and the model currently
# being trained ("tmp"). Trained weights are copied over the checkpoint only
# when they beat it in evaluation.
WEIGHTS_DIR = 'model_weights/'
CHECKPOINT_PATH = f'{WEIGHTS_DIR}checkpoint_{BOARD_SIZE}x{BOARD_SIZE}.h5'
TMP_WEIGHTS_PATH = f'{WEIGHTS_DIR}tmp.h5'
# True: resume from CHECKPOINT_PATH. False: start from freshly initialized weights.
LOAD_SAVED_MODELS = False

In [7]:
NUM_WORKERS = mp.cpu_count()  # one episode worker per CPU core

# Go Environment
Train on a small board so that training is fast and debugging is easy.

In [8]:
# Go environment sized to BOARD_SIZE; used for demos and trajectory plots.
go_env = gym.make('gym_go:go-v0', size=BOARD_SIZE)

# Metrics and TensorBoard

In [9]:
DEMO_TRAJECTORY_PATH = 'logs/a_trajectory.png'  # sample-trajectory figure, refreshed when a model is accepted

In [10]:
# Start each run with an empty ./logs/ directory. Done in pure Python rather
# than `!rm -rf` / `!mkdir` so it also works where those shell commands are
# unavailable (e.g. Windows).
shutil.rmtree('./logs/', ignore_errors=True)
os.makedirs('./logs/', exist_ok=True)

Metrics

In [11]:
# Running metrics passed to value_model.optimize_val_net and cleared with
# metrics.reset_metrics after every training iteration.
tb_metrics = {
    name: tf.keras.metrics.Mean(name, dtype=tf.float32)
    for name in ('val_loss', 'move_loss')
}
# Presumably the accuracy of predicted game outcomes — confirm in optimize_val_net.
tb_metrics['pred_win_acc'] = tf.keras.metrics.Accuracy()

TensorBoard

# Machine Learning Models

In [12]:
val_net = value_model.make_val_net(BOARD_SIZE)  # value network sized for the configured board

In [13]:
# Render the network architecture to logs/model.png (requires pydot and graphviz;
# without them Keras only prints a warning).
_ = tf.keras.utils.plot_model(val_net, to_file='logs/model.png')

Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.


In [14]:
val_net.summary()  # layer-by-layer overview of the value network

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 7, 7, 128)         7040      
_________________________________________________________________
batch_normalization (BatchNo (None, 7, 7, 128)         512       
_________________________________________________________________
re_lu (ReLU)                 (None, 7, 7, 128)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 7, 7, 128)         147584    
_________________________________________________________________
batch_normalization_1 (Batch (None, 7, 7, 128)         512       
_________________________________________________________________
re_lu_1 (ReLU)               (None, 7, 7, 128)         0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 7, 7, 64)          7

In [15]:
# Ensure the weights directory exists before any weights file is written;
# on a fresh checkout `model_weights/` may be missing and saving can fail.
os.makedirs(WEIGHTS_DIR, exist_ok=True)

if LOAD_SAVED_MODELS:
    # Resume from the last accepted model. Fail with a clear message instead
    # of an opaque loader error when there is nothing to resume from.
    if not os.path.isfile(CHECKPOINT_PATH):
        raise FileNotFoundError(
            f"LOAD_SAVED_MODELS is True but no checkpoint exists at {CHECKPOINT_PATH}")
    val_net.load_weights(CHECKPOINT_PATH)
    print("Starting from checkpoint")
else:
    # Fresh run: the newly initialized network becomes the first checkpoint.
    val_net.save_weights(CHECKPOINT_PATH)
    print("Initialized checkpoint and temp")

# Sync temp with checkpoint so training starts from the accepted weights.
val_net.save_weights(TMP_WEIGHTS_PATH)

Initialized checkpoint and temp


# Policies

In [16]:
# Value-network policy backed by the weights currently being trained.
temp_policy_args = dict(
    mode='values',
    board_size=BOARD_SIZE,
    model_path=TMP_WEIGHTS_PATH,
)

In [17]:
# Value-network policy backed by the last accepted (checkpoint) weights.
checkpoint_policy_args = dict(
    mode='values',
    board_size=BOARD_SIZE,
    model_path=CHECKPOINT_PATH,
)

In [18]:
# Baseline evaluation opponent using the 'random' policy mode.
random_policy_args = dict(
    mode='random',
    board_size=BOARD_SIZE,
)

In [19]:
# Baseline evaluation opponent using the 'greedy' policy mode.
greedy_policy_args = dict(
    mode='greedy',
    board_size=BOARD_SIZE,
)

# Demo and Timed Games

Symmetries

In [20]:
%%time
go_env.reset()
action = (1, 2)
next_state, _, _, _ = go_env.step(action)
metrics.plot_symmetries(next_state, 'logs/symmetries.jpg')

CPU times: user 121 ms, sys: 10.5 ms, total: 131 ms
Wall time: 128 ms


Without replay memory (no `outdir`, so episodes are not saved)

In [21]:
%%time
data.make_episodes(temp_policy_args, temp_policy_args, 1, num_workers=1)

Episode worker: 0it [00:00, ?it/s]



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



Episode worker: 1it [00:00,  1.81it/s]
values vs. values: 100%|██████████| 1/1 [00:00<00:00, 1026.51it/s, 0.0% WIN]

CPU times: user 878 ms, sys: 102 ms, total: 980 ms
Wall time: 750 ms





0.0

With replay memory (episodes are written to `EPISODES_DIR`)

In [22]:
%%time
data.make_episodes(temp_policy_args, temp_policy_args, 1, num_workers=1, 
                   outdir=EPISODES_DIR)

Episode worker: 0it [00:00, ?it/s]



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



Episode worker: 1it [00:00,  1.67it/s]
values vs. values: 100%|██████████| 1/1 [00:00<00:00, 1154.50it/s, 0.0% WIN]

CPU times: user 1.07 s, sys: 110 ms, total: 1.18 s
Wall time: 951 ms





0.0

In [23]:
%%time
fig = metrics.gen_traj_fig(go_env, temp_policy_args)
fig.savefig(DEMO_TRAJECTORY_PATH)
plt.close()



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

CPU times: user 3.86 s, sys: 382 ms, total: 4.24 s
Wall time: 3.09 s


# Train

In [None]:
# Evaluation thresholds (win rate of the trained weights vs. the checkpoint):
# above ACCEPT_WIN_RATE the trained weights become the new checkpoint; at or
# above CONTINUE_WIN_RATE they are kept for further training; below that
# they are discarded and reset to the checkpoint.
ACCEPT_WIN_RATE = 0.6
CONTINUE_WIN_RATE = 0.5

for iteration in range(ITERATIONS):
    # --- Self-play: make and write out the episode data ---
    data.make_episodes(temp_policy_args, temp_policy_args, EPISODES_PER_ITERATION,
                       num_workers=NUM_WORKERS, outdir=EPISODES_DIR)

    # --- Optimization ---
    # Read the episode data back from EPISODES_DIR. Assumes np_data is a
    # sequence of equally long arrays (their length is taken from np_data[0]).
    np_data = data.episodes_from_dir(EPISODES_DIR)
    # np.array_split raises ValueError for 0 sections, so always make at
    # least one batch even when there are fewer than BATCH_SIZE samples.
    num_batches = max(1, len(np_data[0]) // BATCH_SIZE)
    batched_np_data = [np.array_split(datum, num_batches) for datum in np_data]
    # Regroup so each element holds the matching batch of every data array.
    batched_mem = list(zip(*batched_np_data))

    value_model.optimize_val_net(temp_policy_args, batched_mem, LEARNING_RATE, tb_metrics)
    # Reset the running metrics before the next iteration.
    metrics.reset_metrics(tb_metrics)

    # --- Evaluation against baselines and the checkpoint model ---
    rand_win_rate = data.make_episodes(temp_policy_args, random_policy_args,
                                       NUM_EVAL_GAMES, num_workers=NUM_WORKERS)
    greed_win_rate = data.make_episodes(temp_policy_args, greedy_policy_args,
                                        NUM_EVAL_GAMES, num_workers=NUM_WORKERS)
    opp_win_rate = data.make_episodes(temp_policy_args, checkpoint_policy_args,
                                      NUM_EVAL_GAMES, num_workers=NUM_WORKERS)

    stats = f"{100*opp_win_rate:.1f}%O, {100*greed_win_rate:.1f}%G, {100*rand_win_rate:.1f}%R"

    if opp_win_rate > ACCEPT_WIN_RATE:
        # Better than the checkpoint: promote the trained weights.
        shutil.copy(TMP_WEIGHTS_PATH, CHECKPOINT_PATH)
        print(f"{stats} Accepted new model")

        # Plot samples of states and response heatmaps. Close the figure
        # explicitly so repeated iterations do not accumulate open figures.
        fig = metrics.gen_traj_fig(go_env, temp_policy_args)
        fig.savefig(DEMO_TRAJECTORY_PATH)
        plt.close(fig)

    elif opp_win_rate >= CONTINUE_WIN_RATE:
        print(f"{stats} Continuing to train current weights")

    else:
        # Worse than the checkpoint: discard the trained weights.
        shutil.copy(CHECKPOINT_PATH, TMP_WEIGHTS_PATH)
        print(f"{stats} Rejected new model")

values vs. values: 100%|██████████| 32/32 [00:10<00:00,  3.19it/s, 50.0% WIN]
Updating: 100%|██████████| 4/4 [00:00<00:00, 13.37it/s, 77.3% ACC, 0.378VL]
values vs. random: 100%|██████████| 32/32 [00:27<00:00,  1.14it/s, 96.9% WIN] 
values vs. greedy: 100%|██████████| 32/32 [01:11<00:00,  2.24s/it, 3.1% WIN]
values vs. values: 100%|██████████| 32/32 [00:09<00:00,  3.39it/s, 100.0% WIN]


100.0%O, 3.1%G, 96.9%R Accepted new model


To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



values vs. values: 100%|██████████| 32/32 [01:02<00:00,  1.95s/it, 50.0% WIN]
Updating: 100%|██████████| 68/68 [00:03<00:00, 17.00it/s, 96.2% ACC, 0.127VL]
values vs. random:   0%|          | 0/32 [00:00<?, ?it/s]

# Evaluate

Play against our AI

In [None]:
# Play an interactive game: the AI plays black, a human plays white.
# FIXME(review): `opponent_policy` and `human_policy` are not defined in any cell
# of this notebook, so this cell raises NameError on Restart & Run All. They must
# be constructed before calling data.pit — presumably the AI policy from
# checkpoint_policy_args and a human-input policy; confirm the expected policy
# type against go_ai.data.pit / go_ai.policies.
go_env = gym.make('gym_go:go-v0', size=BOARD_SIZE)
data.pit(go_env, black_policy=opponent_policy, white_policy=human_policy)