In [1]:
import tensorflow as tf
import numpy as np
import gym
from go_ai import data, metrics, policies
from go_ai.models import value_model
import matplotlib.pyplot as plt
import shutil
import multiprocessing as mp
import os

# Hyperparameters

In [2]:
BOARD_SIZE: int = 5  # side length of the square Go board used everywhere below

In [3]:
# Training schedule
ITERATIONS: int = 256              # number of self-play / optimize / evaluate rounds
EPISODES_PER_ITERATION: int = 128  # self-play games generated each round
NUM_EVAL_GAMES: int = 64           # games played per evaluation match

In [4]:
# Optimizer settings
BATCH_SIZE: int = 32         # target number of samples per optimization batch
LEARNING_RATE: float = 2e-3  # step size handed to value_model.optimize_val_net

In [5]:
LOAD_SAVED_MODELS: bool = False  # True: resume from CHECKPOINT_PATH; False: start from a new network

# Data Parameters

In [6]:
NUM_WORKERS: int = mp.cpu_count()  # parallel episode workers: one per available CPU

In [7]:
EPISODES_DIR: str = './data/'  # self-play episodes are written here and read back for training

In [12]:
# Model files: the checkpoint holds the best accepted weights,
# the tmp file holds the weights currently being trained.
MODELS_DIR = 'models/'
CHECKPOINT_PATH = os.path.join(MODELS_DIR, f'checkpoint_{BOARD_SIZE}x{BOARD_SIZE}.h5')
TMP_MODEL_PATH = os.path.join(MODELS_DIR, 'tmp.h5')

In [13]:
DEMO_TRAJECTORY_PATH = 'logs/a_trajectory.png'
# Figures are saved under logs/ (this path and the symmetries plot). Create the
# directory up front so savefig does not fail on a fresh checkout.
os.makedirs(os.path.dirname(DEMO_TRAJECTORY_PATH) or '.', exist_ok=True)

# Go Environment
We train on a small board so that episodes are short, which keeps training fast and debugging easy.

In [14]:
# Go environment on a BOARD_SIZE x BOARD_SIZE board (small board -> quick episodes)
go_env = gym.make(
    'gym_go:go-v0',
    size=BOARD_SIZE,
)

# Preview Model

In [15]:
# Either resume from an existing checkpoint or start from a freshly initialized
# value network, then sync the temp model with the checkpoint and preview it.

# val_net.save and shutil.copy below both need the models directory to exist
os.makedirs(MODELS_DIR, exist_ok=True)

if LOAD_SAVED_MODELS:
    if not os.path.exists(CHECKPOINT_PATH):
        raise FileNotFoundError(f"No checkpoint to resume from at {CHECKPOINT_PATH}")
    print("Starting from checkpoint")
else:
    # Fresh start: always overwrite any stale checkpoint. Previously an existing
    # checkpoint was silently kept, so LOAD_SAVED_MODELS=False still resumed.
    val_net = value_model.make_val_net(BOARD_SIZE)
    val_net.save(CHECKPOINT_PATH)
    print("Initialized checkpoint and temp")

# Sync temp with checkpoint: training always starts from the checkpoint weights
shutil.copy(CHECKPOINT_PATH, TMP_MODEL_PATH)

model = tf.keras.models.load_model(TMP_MODEL_PATH)
model.summary()

Initialized checkpoint and temp
Model: "dense val net"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
board (InputLayer)              [(None, 5, 5, 6)]    0                                            
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 5, 5, 64)     3520        board[0][0]                      
__________________________________________________________________________________________________
batch_normalization_17 (BatchNo (None, 5, 5, 64)     256         conv2d_18[0][0]                  
__________________________________________________________________________________________________
re_lu_18 (ReLU)                 (None, 5, 5, 64)     0           batch_normalization_17[0][0]     
______________________________________________________

# Policies

In [16]:
# Policy configurations for self-play, gating and baselines.
# 'temp' is the model being trained. 'checkpoint' is the last accepted model:
# the training loop copies temp over it when temp wins the gating match.
temp_policy_args = policies.PolicyArgs('values', BOARD_SIZE, TMP_MODEL_PATH, name='temp')
# Must load CHECKPOINT_PATH. Pointing it at TMP_MODEL_PATH makes the gating
# evaluation play the temp model against itself.
checkpoint_policy_args = policies.PolicyArgs('values', BOARD_SIZE, CHECKPOINT_PATH, name='checkpoint')
random_policy_args = policies.PolicyArgs('random', BOARD_SIZE)
greedy_policy_args = policies.PolicyArgs('greedy', BOARD_SIZE)

# Demo and Time Games

Symmetries: play one move on an empty board and plot the symmetries of the resulting state.

In [17]:
%%time
go_env.reset()
action = (1, 2)
next_state, _, _, _ = go_env.step(action)
metrics.plot_symmetries(next_state, 'logs/symmetries.jpg')

CPU times: user 137 ms, sys: 10.4 ms, total: 148 ms
Wall time: 151 ms


Self-play one game and write it to the replay-memory directory (`EPISODES_DIR`).

In [18]:
%%time
data.make_episodes(temp_policy_args, temp_policy_args, 1, num_workers=1, 
                   outdir=EPISODES_DIR)



Episode worker: 1it [00:00,  1.21it/s]
temp vs. temp: 100%|██████████| 1/1 [00:00<00:00, 1309.08it/s, 0.0% WIN]

CPU times: user 1.74 s, sys: 146 ms, total: 1.89 s
Wall time: 1.69 s





0.0

# Train

In [None]:
# Each iteration: self-play -> optimize the temp model -> gate it against the checkpoint.

# Gating thresholds on temp's win rate against the checkpoint
ACCEPT_WIN_RATE = 0.6    # strictly above: promote temp to checkpoint
CONTINUE_WIN_RATE = 0.5  # at or above (but not accepted): keep training the temp weights

for iteration in range(ITERATIONS):
    # Make and write out the episode data
    data.make_episodes(temp_policy_args, temp_policy_args, EPISODES_PER_ITERATION,
                       num_workers=NUM_WORKERS, outdir=EPISODES_DIR)

    # Read in the episode data and split every array into the same number of batches.
    np_data = data.episodes_from_dir(EPISODES_DIR)
    # np.array_split raises for 0 sections, which happens when there are fewer
    # samples than BATCH_SIZE. Fall back to a single batch in that case.
    # NOTE(review): assumes all arrays in np_data share the first-axis length of np_data[0].
    num_batches = max(1, len(np_data[0]) // BATCH_SIZE)
    batched_np_data = [np.array_split(datum, num_batches) for datum in np_data]
    batched_mem = list(zip(*batched_np_data))

    # Optimize the temp model
    value_model.optimize_val_net(temp_policy_args, batched_mem, LEARNING_RATE)

    # Evaluate the temp model against the checkpoint model
    opp_win_rate = data.make_episodes(temp_policy_args, checkpoint_policy_args,
                                      NUM_EVAL_GAMES, num_workers=NUM_WORKERS)

    if opp_win_rate > ACCEPT_WIN_RATE:
        # Better than the checkpoint: promote it, then report baseline win rates
        shutil.copy(TMP_MODEL_PATH, CHECKPOINT_PATH)
        print(f"{100*opp_win_rate:.1f}% Accepted new model")
        rand_win_rate = data.make_episodes(temp_policy_args, random_policy_args,
                                           NUM_EVAL_GAMES, num_workers=NUM_WORKERS)
        greed_win_rate = data.make_episodes(temp_policy_args, greedy_policy_args,
                                            NUM_EVAL_GAMES, num_workers=NUM_WORKERS)
        print(f"{100*greed_win_rate:.1f}%G {100*rand_win_rate:.1f}%R")

        # Plot samples of states and response heatmaps
        fig = metrics.gen_traj_fig(go_env, temp_policy_args)
        fig.savefig(DEMO_TRAJECTORY_PATH)
        plt.close(fig)  # close this specific figure, not whatever is "current"

    elif opp_win_rate >= CONTINUE_WIN_RATE:
        print(f"{100*opp_win_rate:.1f}% Continuing to train current weights")

    else:
        # Worse than the checkpoint: revert temp to the checkpoint weights
        shutil.copy(CHECKPOINT_PATH, TMP_MODEL_PATH)
        print(f"{100*opp_win_rate:.1f}% Rejected new model")

temp vs. temp: 100%|██████████| 128/128 [00:41<00:00,  3.05it/s, 54.7% WIN]




Updating: 100%|██████████| 63/63 [00:06<00:00,  9.60it/s, 54.8% ACC, 0.827VL]
temp vs. checkpoint:   0%|          | 0/64 [00:00<?, ?it/s]

# Evaluate

Play a game against the checkpoint model.

In [None]:
# Human player, built with the same PolicyArgs wrapper as every other policy passed
# to data.make_episodes, rather than a raw dict of the same fields.
# NOTE(review): confirm policies.PolicyArgs accepts the 'human' mode.
human_policy_args = policies.PolicyArgs('human', BOARD_SIZE)

In [None]:
# A single game: the checkpoint model vs. a human player, in one worker process
data.make_episodes(checkpoint_policy_args, human_policy_args, 1,
                   num_workers=1)