In [1]:
import tensorflow as tf
import numpy as np
import gym
import datetime
from tqdm import tqdm
import random
from go_ai import data, metrics, mcts, models, policies
import matplotlib.pyplot as plt

# Hyperparameters

In [2]:
# Size of the Go board; passed to the environment and the model factory, and
# embedded in the checkpoint file name.
BOARD_SIZE = 7

In [3]:
# Number of outer training iterations (self-play, model update, evaluation).
ITERATIONS = 256
# Self-play episodes collected in every iteration.
EPISODES_PER_ITERATION = 128
# Step cap passed as `max_steps` to self-play, evaluation and human play.
MAX_STEPS = 2 * BOARD_SIZE**2
# The replay memory is split into len(replay_mem) // BATCH_SIZE batches per update.
BATCH_SIZE = 32

In [4]:
# Games played against the opponent policy at each evaluation.
NUM_EVAL_GAMES = 32
# Evaluation runs whenever (iteration + 1) is a multiple of this value.
ITERATIONS_PER_EVAL = 1

In [5]:
# Learning rate handed to the Adam optimizer.
LEARNING_RATE = 2e-3

In [6]:
# Third argument of policies.MctPolicy.
# NOTE(review): presumably the number of Monte Carlo simulations per move — confirm in go_ai.policies.
MC_SIMS = 0


def _temperature_schedule(step):
    """Return 1/2 for steps below 16 and 0 from step 16 onwards."""
    if step < 16:
        return 1 / 2
    return 0


# Fourth argument of policies.MctPolicy; maps a step index to a temperature.
TEMP_FUNC = _temperature_schedule

In [7]:
# Directory that holds the saved model weights.
WEIGHTS_DIR = 'model_weights/'
# Checkpoint file of the actor-critic; the board size is part of the file name.
ACTOR_CRITIC_PATH = WEIGHTS_DIR + 'checkpoint_{}x{}.h5'.format(BOARD_SIZE, BOARD_SIZE)
# When True, the actor-critic and the opponent both start from ACTOR_CRITIC_PATH.
LOAD_SAVED_MODELS = True

# Go Environment
Train on a small board with a heuristic reward so that training is fast and debugging is efficient.

In [8]:
# Go environment registered as 'gym_go:go-v0', created with the configured board size.
go_env = gym.make('gym_go:go-v0', size=BOARD_SIZE)

# Metrics and TensorBoard

In [9]:
# Delete everything under ./logs/ left over from previous runs (shell escape, relies on `rm`).
!rm -rf ./logs/

Metrics

In [10]:
# Running-mean metrics, one per key; the key doubles as the Keras metric name.
mean_metric_keys = ['val_loss', 'overall_loss', 'num_steps', 'move_loss']
tb_metrics = {
    key: tf.keras.metrics.Mean(key, dtype=tf.float32)
    for key in mean_metric_keys
}
# Accuracy metric stored under its own key.
tb_metrics['pred_win_acc'] = tf.keras.metrics.Accuracy()

TensorBoard

In [11]:
# Every run writes to its own timestamped directory under logs/actor_critic/.
run_started_at = datetime.datetime.now()
current_time = run_started_at.strftime("%Y%m%d-%H%M%S")
log_dir = 'logs/actor_critic/{}/main'.format(current_time)
summary_writer = tf.summary.create_file_writer(log_dir)

# Machine Learning Models

In [12]:
# Build the actor-critic network for the configured board size.
# NOTE(review): the meaning of the 'val_net' and 'tanh' arguments is defined in go_ai.models — confirm there.
actor_critic = models.make_actor_critic(BOARD_SIZE, 'val_net', 'tanh')

In [13]:
# Render the network architecture to logs/model.png; the result is bound to `_`
# so the cell itself displays nothing.
_ = tf.keras.utils.plot_model(actor_critic, to_file='logs/model.png')

In [14]:
# Print the layer-by-layer summary of the network.
actor_critic.summary()

Model: "actor_critic"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
board (InputLayer)              [(None, 7, 7, 6)]    0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 7, 7, 64)     3520        board[0][0]                      
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 7, 7, 64)     256         conv2d[0][0]                     
__________________________________________________________________________________________________
re_lu (ReLU)                    (None, 7, 7, 64)     0           batch_normalization[0][0]        
_______________________________________________________________________________________

In [15]:
# Second network with the same architecture, used as the opponent in evaluation games.
# Its weights are synchronised with the checkpoint separately via load_weights.
opponent = tf.keras.models.clone_model(actor_critic)

In [16]:
if LOAD_SAVED_MODELS:
    # Both networks start from the same saved checkpoint.
    for network in (actor_critic, opponent):
        network.load_weights(ACTOR_CRITIC_PATH)
    print("Loaded saved models")

Loaded saved models


In [17]:
# State returned by the environment; handed to the policies that take a state argument.
state = go_env.get_state()
# Policies backed by the trained network and by the opponent network.
my_policy = policies.MctPolicy(actor_critic, state, MC_SIMS, TEMP_FUNC)
opponent_policy = policies.MctPolicy(opponent, state, MC_SIMS, TEMP_FUNC)
# Baseline policies used as evaluation opponents in the training loop.
greedy_policy = policies.MctGreedyPolicy(state)
random_policy = policies.RandomPolicy()
# Policy used for the human side in the "Play against our AI" cell.
human_policy = policies.HumanPolicy()

# Demo Trajectories

Symmetries

In [18]:
%%time
# Writes its output to logs/symmetries.jpg.
# NOTE(review): presumably plots the network's response to symmetric variants of a board — confirm in go_ai.metrics.
metrics.plot_symmetries(go_env, actor_critic, 'logs/symmetries.jpg')

CPU times: user 2.4 s, sys: 228 ms, total: 2.63 s
Wall time: 1.25 s


Plot a whole game trajectory

In [None]:
%%time 
traj, _ = data.self_play(go_env, policy=my_policy, max_steps=MAX_STEPS, 
                         get_symmetries=False)

CPU times: user 6.86 s, sys: 221 ms, total: 7.08 s
Wall time: 6.56 s


In [None]:
%%time
fig = metrics.state_responses(actor_critic, traj)
fig.savefig('logs/a_trajectory.jpg')
plt.close()

CPU times: user 17 s, sys: 1.23 s, total: 18.2 s
Wall time: 9.49 s


# Train

In [None]:
# Adam optimizer used for every model update in the training loop.
actor_critic_opt = tf.keras.optimizers.Adam(LEARNING_RATE)
# Replay memory: filled by self-play and cleared at the end of each iteration.
replay_mem = []

In [None]:
# Evaluation settings that are local to the training loop.
BASELINE_EVAL_GAMES = 8  # games against each of the random and greedy baselines
ACCEPT_WIN_RATE = 0.6    # win rate vs. the opponent above which new weights are kept

for iteration in range(ITERATIONS):
    # --- Self-play: fill the replay memory with fresh trajectories ---
    episode_pbar = tqdm(range(EPISODES_PER_ITERATION),
                        desc='Iteration {} - Self Play'.format(iteration),
                        leave=True, position=0)
    for _ in episode_pbar:
        trajectory, num_steps = data.self_play(go_env, policy=my_policy,
                                               max_steps=MAX_STEPS)
        replay_mem.extend(trajectory)
        tb_metrics['num_steps'].update_state(num_steps)

    # --- Update the model on the shuffled replay memory ---
    random.shuffle(replay_mem)
    np_data = data.replay_mem_to_numpy(replay_mem)
    # np.array_split raises ValueError when asked for 0 sections, which happens
    # whenever fewer than BATCH_SIZE samples were collected; keep at least one batch.
    num_batches = max(1, len(replay_mem) // BATCH_SIZE)
    batched_np_data = [np.array_split(datum, num_batches) for datum in np_data]
    # Regroup from "one list of batches per array" to "one tuple of arrays per batch".
    batched_mem = list(zip(*batched_np_data))
    models.update_win_prediction(actor_critic, batched_mem, actor_critic_opt,
                                 iteration, tb_metrics)

    # --- Evaluate against the baselines and the previous model ---
    if (iteration + 1) % ITERATIONS_PER_EVAL == 0:
        rand_win_rate = metrics.evaluate(go_env, my_policy, random_policy,
                                         max_steps=MAX_STEPS,
                                         num_games=BASELINE_EVAL_GAMES)
        greed_win_rate = metrics.evaluate(go_env, my_policy, greedy_policy,
                                          max_steps=MAX_STEPS,
                                          num_games=BASELINE_EVAL_GAMES)
        opp_win_rate = metrics.evaluate(go_env, my_policy, opponent_policy,
                                        max_steps=MAX_STEPS,
                                        num_games=NUM_EVAL_GAMES)

        stats = "{:.1f}%O, {:.1f}%R, {:.1f}%G".format(100 * opp_win_rate,
                                                     100 * rand_win_rate,
                                                     100 * greed_win_rate)
        if opp_win_rate > ACCEPT_WIN_RATE:
            # Accept: persist the new weights and promote them to the opponent.
            actor_critic.save_weights(ACTOR_CRITIC_PATH)
            opponent.load_weights(ACTOR_CRITIC_PATH)

            print("{} Accepted new model".format(stats))
        else:
            print("{} Rejected new model".format(stats))
            # Reject: roll back to the weights of the model the candidate failed
            # to beat. They are taken from the opponent in memory rather than from
            # ACTOR_CRITIC_PATH, because that file may not exist (or may belong to
            # an unrelated run) when training did not start from a saved checkpoint.
            actor_critic.set_weights(opponent.get_weights())

    # Log the results to TensorBoard (the original note says this also resets the metrics).
    metrics.log_to_tensorboard(summary_writer, tb_metrics, iteration, go_env,
                               actor_critic, TEMP_FUNC, 'logs/a_trajectory.jpg')
    # Reset memory so the next iteration trains only on fresh self-play data.
    replay_mem.clear()

Iteration 0 - Self Play: 100%|██████████| 128/128 [30:09<00:00,  7.67s/it]  
W0919 10:46:16.135141 4674909632 deprecation.py:323] From /usr/local/lib/python3.7/site-packages/tensorflow_core/python/ops/math_grad.py:1394: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Updating: 100%|██████████| 2374/2374 [02:01<00:00, 19.07it/s, 79.7% 0.529L]
Evaluation: 100%|██████████| 8/8 [00:27<00:00,  3.48s/it, 1.0 100.0%]
Evaluation: 100%|██████████| 8/8 [01:14<00:00,  9.64s/it, 1.0 87.5%]
Evaluation: 100%|██████████| 32/32 [04:37<00:00,  9.07s/it, 0.0 31.2%]


31.2%O, 100.0%R, 87.5%G Rejected new model


Iteration 1 - Self Play: 100%|██████████| 128/128 [15:39<00:00,  7.05s/it]
Updating: 100%|██████████| 2353/2353 [01:56<00:00, 20.26it/s, 81.6% 0.476L]
Evaluation: 100%|██████████| 8/8 [00:26<00:00,  3.33s/it, 1.0 100.0%]
Evaluation: 100%|██████████| 8/8 [01:01<00:00,  7.34s/it, 1.0 100.0%]
Evaluation: 100%|██████████| 32/32 [04:12<00:00,  8.11s/it, 1.0 43.8%]


43.8%O, 100.0%R, 100.0%G Rejected new model


Iteration 2 - Self Play: 100%|██████████| 128/128 [2:14:36<00:00,  7.15s/it]     
Updating: 100%|██████████| 2356/2356 [02:09<00:00, 18.15it/s, 82.0% 0.469L]
Evaluation: 100%|██████████| 8/8 [00:24<00:00,  2.92s/it, 1.0 100.0%]
Evaluation: 100%|██████████| 8/8 [01:16<00:00,  9.48s/it, 0.0 50.0%]
Evaluation: 100%|██████████| 32/32 [04:51<00:00, 10.04s/it, 0.0 37.5%]


37.5%O, 100.0%R, 50.0%G Rejected new model


Iteration 3 - Self Play:  53%|█████▎    | 68/128 [08:33<07:35,  7.58s/it]

# Evaluate

Play against our AI

In [None]:
# Create a fresh environment for the interactive game.
go_env = gym.make('gym_go:go-v0', size=BOARD_SIZE)
# The opponent network's policy plays black and the human plays white.
data.pit(go_env, black_policy=opponent_policy, white_policy=human_policy, max_steps=MAX_STEPS)