In [1]:
import tensorflow as tf
import numpy as np
import gym
from go_ai import data, metrics, policies
from go_ai.models import value_model
import matplotlib.pyplot as plt
import shutil
import multiprocessing as mp
import os

# Hyperparameters

In [2]:
# Board size shared by the whole notebook: passed as `size` to the gym Go
# environment, to value_model.make_model, and to every PolicyArgs below.
# It is also baked into the checkpoint file name.
BOARD_SIZE = 4

In [3]:
# Number of generate -> optimize -> evaluate rounds run by the training loop.
ITERATIONS = 256
# Self-play episodes generated at the start of every training iteration.
EPISODES_PER_ITERATION = 128
# Forwarded as `random_beginning` to data.make_episodes for self-play and for
# the evaluation against the checkpoint.
# NOTE(review): presumably the number of random opening moves -- confirm in go_ai.data.
RANDOM_BEGINNING = 0
# Games played per evaluation match-up (vs. checkpoint, random and greedy).
NUM_EVAL_GAMES = 128

In [4]:
# Batch size handed to value_model.optimize in the training loop.
BATCH_SIZE = 32

In [5]:
# True:  resume from the checkpoint file that must already exist on disk.
# False: build a fresh value model and save it as the new checkpoint.
LOAD_SAVED_MODELS = False

# Data Parameters

In [6]:
# Worker count passed as `num_workers` to data.make_episodes: one per CPU
# reported by the multiprocessing module.
NUM_WORKERS = mp.cpu_count()
# Two spaces after the colon reproduce print()'s default separator output.
print(f"Workers:  {NUM_WORKERS}")

Workers:  8


In [7]:
# Directory used as `outdir` when data.make_episodes writes self-play episodes,
# and read back by data.episodes_from_dir in the training loop.
EPISODES_DIR = './data/'

In [8]:
# All model files live under MODELS_DIR. The trailing slash is significant
# because the paths below are built by plain string interpolation.
MODELS_DIR = 'models/'
# Accepted ("checkpoint") weights, one file per board size, e.g. checkpoint_4x4.h5.
CHECKPOINT_PATH = f'{MODELS_DIR}checkpoint_{BOARD_SIZE}x{BOARD_SIZE}.h5'
# Working copy that is trained and then either promoted or rolled back.
TMP_MODEL_PATH = f'{MODELS_DIR}tmp.h5'

In [9]:
# Where the demo trajectory figure is saved; the training loop overwrites it
# every time a new model is accepted.
DEMO_TRAJECTORY_PATH = 'logs/a_trajectory.png'

# Go Environment
We train on a small board so that training is fast and debugging is efficient.

In [10]:
# Go environment registered by the gym_go package, sized to BOARD_SIZE.
# Used below for the symmetry demo and for generating trajectory figures.
go_env = gym.make('gym_go:go-v0', size=BOARD_SIZE)

# Preview Model

In [11]:
# Make sure the models directory exists so that saving / copying below cannot
# fail on a fresh checkout.
os.makedirs(MODELS_DIR, exist_ok=True)

if LOAD_SAVED_MODELS:
    # Resuming requires an existing checkpoint. Raise explicitly instead of
    # using `assert`, which is stripped under `python -O` and gives no message.
    if not os.path.exists(CHECKPOINT_PATH):
        raise FileNotFoundError(
            f"LOAD_SAVED_MODELS is True but no checkpoint exists at {CHECKPOINT_PATH}")
    print("Starting from checkpoint")
else:
    # Fresh start: build a new value network and make it the checkpoint.
    val_net = value_model.make_model(BOARD_SIZE)
    val_net.save(CHECKPOINT_PATH)
    print("Initialized checkpoint and temp")
print()

# Sync temp with checkpoint so training always starts from the accepted weights
shutil.copy(CHECKPOINT_PATH, TMP_MODEL_PATH)

# Load the working copy back from disk and show its architecture.
model = tf.keras.models.load_model(TMP_MODEL_PATH)
model.summary()

Initialized checkpoint and temp

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten (Flatten)            (None, 96)                0         
_________________________________________________________________
dense (Dense)                (None, 256)               24832     
_________________________________________________________________
batch_normalization (BatchNo (None, 256)               1024      
_________________________________________________________________
re_lu (ReLU)                 (None, 256)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 512)               131584    
_________________________________________________________________
batch_normalization_1 (Batch (None, 512)               2048      
_________________________________________________________________
re_lu_1 (ReLU)         

# Policies

In [12]:
# Model-backed policies: both use the 'qtemp' policy type with the same
# temperature and differ only in the weights file and the name they carry.
tmp_policy_args = policies.PolicyArgs(
    'qtemp', BOARD_SIZE, TMP_MODEL_PATH,
    name='tmp', temperature=1/32,
)
checkpoint_policy_args = policies.PolicyArgs(
    'qtemp', BOARD_SIZE, CHECKPOINT_PATH,
    name='checkpoint', temperature=1/32,
)

# Baseline opponents and the interactive player; none of them takes a model path.
random_policy_args = policies.PolicyArgs('random', BOARD_SIZE)
greedy_policy_args = policies.PolicyArgs('greedy', BOARD_SIZE)
human_policy_args = policies.PolicyArgs('human', BOARD_SIZE)

# Demo and Time Games

Symmetries: play one move and plot the symmetries of the resulting board state.

In [13]:
%%time
# Ensure the output directory exists before the plot is written.
os.makedirs('logs', exist_ok=True)

# Play a single move on a freshly reset environment and plot the symmetries
# of the resulting state.
go_env.reset()
action = (1, 1)
# Only the first element of step()'s result (the next state) is needed.
# Indexing avoids clobbering IPython's `_` and does not depend on how many
# values this gym version returns from step().
next_state = go_env.step(action)[0]
metrics.plot_symmetries(next_state, 'logs/symmetries.jpg')

CPU times: user 219 ms, sys: 13.1 ms, total: 232 ms
Wall time: 235 ms


With replay memory: play one self-play episode and write it to the episode directory.

In [14]:
%%time
# Play a single tmp-vs-tmp episode with one worker and write it to EPISODES_DIR.
# The call is the cell's last expression, so its return value is displayed
# (the training loop treats this return value as a win rate).
data.make_episodes(tmp_policy_args, tmp_policy_args, 1, num_workers=1, 
                   outdir=EPISODES_DIR)

Episode worker: 1it [00:00,  3.10it/s]
tmp vs. tmp: 100%|██████████| 1/1 [00:00<00:00, 1196.66it/s, 0.0% WIN]

CPU times: user 513 ms, sys: 12.5 ms, total: 525 ms
Wall time: 514 ms





0.0

In [15]:
%%time
# Generate a sample trajectory figure for the tmp policy and save it to disk.
fig = metrics.gen_traj_fig(go_env, tmp_policy_args)
# Make sure the target directory exists; savefig does not create it.
os.makedirs(os.path.dirname(DEMO_TRAJECTORY_PATH) or '.', exist_ok=True)
fig.savefig(DEMO_TRAJECTORY_PATH)
# Close this specific figure; a bare plt.close() only closes whatever pyplot
# currently considers the "current" figure.
plt.close(fig)

CPU times: user 2.12 s, sys: 101 ms, total: 2.22 s
Wall time: 1.62 s


# Train

In [None]:
# Win-rate thresholds (tmp vs. checkpoint) deciding the fate of the freshly
# trained weights after each iteration.
ACCEPT_WIN_RATE = 0.6    # strictly above: promote tmp to checkpoint
CONTINUE_WIN_RATE = 0.5  # at or above (but not accepted): keep training tmp

for iteration in range(ITERATIONS):
    print(f"Iteration {iteration}")

    # Make and write out the episode data (self-play of the tmp policy)
    data.make_episodes(tmp_policy_args, tmp_policy_args, EPISODES_PER_ITERATION,
                       num_workers=NUM_WORKERS, outdir=EPISODES_DIR,
                       random_beginning=RANDOM_BEGINNING)
    # Read in the episode data
    replay_data = data.episodes_from_dir(EPISODES_DIR)

    # Optimize the tmp model on the replay data
    value_model.optimize(tmp_policy_args, replay_data, BATCH_SIZE)

    # Evaluate against checkpoint model and other baselines.
    # Despite its name, `opp_win_rate` is used here as tmp's win rate against
    # the opponent (the progress bar reports it as "tmp vs. checkpoint ... WIN").
    opp_win_rate = data.make_episodes(tmp_policy_args, checkpoint_policy_args,
                                      NUM_EVAL_GAMES, NUM_WORKERS,
                                      random_beginning=RANDOM_BEGINNING)

    if opp_win_rate > ACCEPT_WIN_RATE:
        # Clearly better than the checkpoint: promote tmp to checkpoint
        shutil.copy(TMP_MODEL_PATH, CHECKPOINT_PATH)
        print(f"{100*opp_win_rate:.1f}% Accepted new model")

        # Report strength against the fixed baselines (no random opening)
        rand_win_rate = data.make_episodes(tmp_policy_args, random_policy_args,
                                           NUM_EVAL_GAMES, NUM_WORKERS,
                                           random_beginning=0)
        greed_win_rate = data.make_episodes(tmp_policy_args, greedy_policy_args,
                                            NUM_EVAL_GAMES, NUM_WORKERS,
                                            random_beginning=0)
        print(f"{100*greed_win_rate:.1f}%G {100*rand_win_rate:.1f}%R")

        # Plot samples of states and response heatmaps
        fig = metrics.gen_traj_fig(go_env, tmp_policy_args)
        # savefig does not create missing directories
        os.makedirs(os.path.dirname(DEMO_TRAJECTORY_PATH) or '.', exist_ok=True)
        fig.savefig(DEMO_TRAJECTORY_PATH)
        # Close exactly this figure so figures do not pile up over many
        # iterations; a bare plt.close() only closes the "current" figure.
        plt.close(fig)

    elif opp_win_rate >= CONTINUE_WIN_RATE:
        # Not convincingly better yet: leave both files untouched and keep training
        print(f"{100*opp_win_rate:.1f}% Continuing to train current weights")

    else:
        # Worse than the checkpoint: roll tmp back to the checkpoint weights
        shutil.copy(CHECKPOINT_PATH, TMP_MODEL_PATH)
        print(f"{100*opp_win_rate:.1f}% Rejected new model")

Iteration 0


tmp vs. tmp: 100%|██████████| 128/128 [00:11<00:00, 10.91it/s, 50.8% WIN]


Train on 2015 samples
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


tmp vs. checkpoint: 100%|██████████| 128/128 [00:14<00:00,  8.54it/s, 62.5% WIN]


62.5% Accepted new model


tmp vs. random: 100%|██████████| 128/128 [00:10<00:00, 11.84it/s, 86.3% WIN]
tmp vs. greedy: 100%|██████████| 128/128 [00:14<00:00,  8.70it/s, 18.0% WIN]


18.0%G 86.3%R
Iteration 1


tmp vs. tmp: 100%|██████████| 128/128 [00:15<00:00,  8.16it/s, 46.5% WIN]


Train on 2478 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:19<00:00,  6.61it/s, 45.7% WIN]


45.7% Rejected new model
Iteration 2


tmp vs. tmp: 100%|██████████| 128/128 [00:16<00:00,  7.90it/s, 49.6% WIN]


Train on 2261 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:19<00:00,  6.71it/s, 73.0% WIN]


73.0% Accepted new model


tmp vs. random: 100%|██████████| 128/128 [00:12<00:00, 10.34it/s, 89.5% WIN]
tmp vs. greedy: 100%|██████████| 128/128 [00:19<00:00,  6.72it/s, 26.2% WIN]


26.2%G 89.5%R
Iteration 3


tmp vs. tmp: 100%|██████████| 128/128 [00:20<00:00,  6.30it/s, 47.3% WIN]


Train on 3091 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:24<00:00,  5.27it/s, 50.8% WIN]


50.8% Continuing to train current weights
Iteration 4


tmp vs. tmp: 100%|██████████| 128/128 [00:18<00:00,  6.91it/s, 48.8% WIN]


Train on 2917 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:20<00:00,  6.36it/s, 45.7% WIN]


45.7% Rejected new model
Iteration 5


tmp vs. tmp: 100%|██████████| 128/128 [00:17<00:00,  7.25it/s, 49.2% WIN]


Train on 2826 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:20<00:00,  6.19it/s, 59.4% WIN]


59.4% Continuing to train current weights
Iteration 6


tmp vs. tmp: 100%|██████████| 128/128 [00:19<00:00,  6.57it/s, 52.0% WIN]


Train on 3187 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:20<00:00,  6.17it/s, 64.1% WIN]


64.1% Accepted new model


tmp vs. random: 100%|██████████| 128/128 [00:12<00:00, 10.62it/s, 96.1% WIN]
tmp vs. greedy: 100%|██████████| 128/128 [00:16<00:00,  7.56it/s, 46.5% WIN]


46.5%G 96.1%R
Iteration 7


tmp vs. tmp: 100%|██████████| 128/128 [00:17<00:00,  7.25it/s, 51.2% WIN]


Train on 3013 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:20<00:00,  6.27it/s, 56.6% WIN]


56.6% Continuing to train current weights
Iteration 8


tmp vs. tmp: 100%|██████████| 128/128 [00:19<00:00,  6.62it/s, 50.8% WIN]


Train on 2992 samples


tmp vs. checkpoint:   0%|          | 0/128 [00:00<?, ?it/s]

# Evaluate

Play a game against our AI (the current checkpoint model).

In [None]:
# Play one game between the checkpoint policy and the 'human' policy.
# NOTE(review): presumably the human policy prompts for moves interactively --
# confirm in go_ai.policies.
data.make_episodes(checkpoint_policy_args, human_policy_args, 1)