In [1]:
import tensorflow as tf
import numpy as np
import gym
from go_ai import data, metrics, policies
from go_ai.models import value_model
import matplotlib.pyplot as plt
import shutil
import multiprocessing as mp
import os

# Hyperparameters

In [2]:
# Board size: passed as `size` to the Go environment and to every policy/model
# constructor, and embedded in the checkpoint filename as "<N>x<N>".
BOARD_SIZE = 4

In [3]:
# Number of passes through the training loop (self-play -> optimize -> evaluate).
ITERATIONS = 256
# Self-play episodes generated at the start of every iteration.
EPISODES_PER_ITERATION = 128
# Forwarded as `random_beginning` to data.make_episodes.
# NOTE(review): presumably the number of random opening moves per game — confirm in go_ai.data.
RANDOM_BEGINNING = 0
# Games played for each evaluation match-up (vs. checkpoint, random and greedy).
NUM_EVAL_GAMES = 128

In [4]:
# Batch size handed to value_model.optimize in the training loop.
BATCH_SIZE = 32

In [5]:
# True: resume from the checkpoint file already on disk.
# False: build a fresh model and overwrite the checkpoint with it.
LOAD_SAVED_MODELS = False

# Data Parameters

In [6]:
# Worker count passed to data.make_episodes: one per CPU reported by multiprocessing.
NUM_WORKERS = mp.cpu_count()
print("Workers: ", NUM_WORKERS)

Workers:  8


In [7]:
# Directory used as `outdir` when generating self-play episodes and read back
# with data.episodes_from_dir before each optimization step.
EPISODES_DIR = './data/'

In [8]:
# Directory holding the saved model files (note the trailing slash: paths below
# are built by plain string interpolation).
MODELS_DIR = 'models/'
# Accepted ("checkpoint") model; one file per board size, e.g. models/checkpoint_4x4.h5.
CHECKPOINT_PATH = f'{MODELS_DIR}checkpoint_{BOARD_SIZE}x{BOARD_SIZE}.h5'
# Working copy of the model that is trained and then evaluated against the checkpoint.
TMP_MODEL_PATH = f'{MODELS_DIR}tmp.h5'

In [9]:
# Output image for the figure returned by metrics.gen_traj_fig.
DEMO_TRAJECTORY_PATH = 'logs/a_trajectory.png'

# Go Environment
We train on a small board so that training is fast and debugging is easy.

In [10]:
# Go environment from the gym_go package, created with the configured board size.
# Used below for the symmetry demo and for trajectory figures.
go_env = gym.make('gym_go:go-v0', size=BOARD_SIZE)

# Preview Model

In [11]:
# Make sure the models directory exists before anything is saved or copied into it;
# neither shutil.copy nor (in older TF releases) an .h5 save creates it.
os.makedirs(MODELS_DIR, exist_ok=True)

if LOAD_SAVED_MODELS:
    # Resume from a previous run. Fail with a descriptive error rather than a bare
    # assert (asserts are stripped under `python -O` and carry no message).
    if not os.path.exists(CHECKPOINT_PATH):
        raise FileNotFoundError(
            f"LOAD_SAVED_MODELS is True but no checkpoint was found at {CHECKPOINT_PATH}")
    print("Starting from checkpoint")
else:
    # Start from scratch: build a fresh value model and make it the checkpoint.
    val_net = value_model.make_model(BOARD_SIZE)
    val_net.save(CHECKPOINT_PATH)
    print("Initialized checkpoint and temp")
print()

# Sync temp with checkpoint so training starts from the checkpoint's weights
shutil.copy(CHECKPOINT_PATH, TMP_MODEL_PATH)

# Load the working copy back from disk and show its architecture
model = tf.keras.models.load_model(TMP_MODEL_PATH)
model.summary()

Initialized checkpoint and temp

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten (Flatten)            (None, 96)                0         
_________________________________________________________________
dense (Dense)                (None, 256)               24832     
_________________________________________________________________
batch_normalization (BatchNo (None, 256)               1024      
_________________________________________________________________
re_lu (ReLU)                 (None, 256)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 512)               131584    
_________________________________________________________________
batch_normalization_1 (Batch (None, 512)               2048      
_________________________________________________________________
re_lu_1 (ReLU)         

# Policies

In [12]:
# Learned policies: both are 'qtemp' policies at temperature 1/64 and differ only in
# the model file they point to (the working copy vs. the accepted checkpoint).
tmp_policy_args = policies.PolicyArgs(
    'qtemp', BOARD_SIZE, TMP_MODEL_PATH,
    name='tmp', temperature=1/64,
)
checkpoint_policy_args = policies.PolicyArgs(
    'qtemp', BOARD_SIZE, CHECKPOINT_PATH,
    name='checkpoint', temperature=1/64,
)

# Policies constructed without a model path: two baselines used for evaluation,
# plus the 'human' policy used in the final cell.
random_policy_args = policies.PolicyArgs('random', BOARD_SIZE)
greedy_policy_args = policies.PolicyArgs('greedy', BOARD_SIZE)
human_policy_args = policies.PolicyArgs('human', BOARD_SIZE)

# Demo and Time Games

Symmetries

In [13]:
%%time
# Reset the board, play one move at (1, 1), and plot the symmetries of the
# resulting state to an image file. The other three values returned by
# go_env.step are not needed here and are discarded.
go_env.reset()
action = (1, 1)
next_state, _, _, _ = go_env.step(action)
metrics.plot_symmetries(next_state, 'logs/symmetries.jpg')

CPU times: user 190 ms, sys: 5.5 ms, total: 195 ms
Wall time: 194 ms


With replay memory (play one self-play game and write its episode data to disk)

In [14]:
%%time
# Time one self-play episode (tmp vs. tmp) on a single worker, writing the episode
# under EPISODES_DIR. The call's return value is echoed as the cell output; the
# training loop treats the same return value as a win rate.
data.make_episodes(tmp_policy_args, tmp_policy_args, 1, num_workers=1, 
                   outdir=EPISODES_DIR)

Episode worker: 1it [00:00, 14.57it/s]
tmp vs. tmp: 100%|██████████| 1/1 [00:00<00:00, 1016.31it/s, 100.0% WIN]

CPU times: user 255 ms, sys: 9.82 ms, total: 265 ms
Wall time: 259 ms





1.0

In [15]:
%%time
fig = metrics.gen_traj_fig(go_env, tmp_policy_args)
fig.savefig(DEMO_TRAJECTORY_PATH)
plt.close()

CPU times: user 2.33 s, sys: 103 ms, total: 2.43 s
Wall time: 1.83 s


# Train

In [None]:
# Win-rate thresholds (tmp vs. checkpoint) that decide what happens to the freshly trained weights.
ACCEPT_WIN_RATE = 0.6         # strictly above this: promote tmp to be the new checkpoint
KEEP_TRAINING_WIN_RATE = 0.5  # at or above this (but not accepted): keep training tmp as-is

# Figure.savefig does not create missing directories, so make sure the target one exists.
os.makedirs(os.path.dirname(DEMO_TRAJECTORY_PATH) or '.', exist_ok=True)

for iteration in range(ITERATIONS):
    print(f"Iteration {iteration}")

    # Make and write out the episode data (tmp plays against itself)
    data.make_episodes(tmp_policy_args, tmp_policy_args, EPISODES_PER_ITERATION,
                       num_workers=NUM_WORKERS, outdir=EPISODES_DIR,
                       random_beginning=RANDOM_BEGINNING)
    # Read in the episode data
    replay_data = data.episodes_from_dir(EPISODES_DIR)

    # Optimize
    # NOTE(review): presumably this updates the weights stored at TMP_MODEL_PATH — confirm
    # in go_ai.models.value_model.
    value_model.optimize(tmp_policy_args, replay_data, BATCH_SIZE)

    # Evaluate the trained weights against the checkpoint model
    opp_win_rate = data.make_episodes(tmp_policy_args, checkpoint_policy_args,
                                      NUM_EVAL_GAMES, NUM_WORKERS,
                                      random_beginning=RANDOM_BEGINNING)

    if opp_win_rate > ACCEPT_WIN_RATE:
        # Clearly better than the checkpoint: promote tmp to be the new checkpoint
        shutil.copy(TMP_MODEL_PATH, CHECKPOINT_PATH)
        print(f"{100*opp_win_rate:.1f}% Accepted new model")

        # Only accepted models are benchmarked against the fixed baselines
        rand_win_rate = data.make_episodes(tmp_policy_args, random_policy_args,
                                           NUM_EVAL_GAMES, NUM_WORKERS,
                                           random_beginning=0)
        greed_win_rate = data.make_episodes(tmp_policy_args, greedy_policy_args,
                                            NUM_EVAL_GAMES, NUM_WORKERS,
                                            random_beginning=0)
        print(f"{100*greed_win_rate:.1f}%G {100*rand_win_rate:.1f}%R")

        # Plot samples of states and response heatmaps
        fig = metrics.gen_traj_fig(go_env, tmp_policy_args)
        fig.savefig(DEMO_TRAJECTORY_PATH)
        # Close this exact figure so figures do not pile up over many iterations
        plt.close(fig)

    elif opp_win_rate >= KEEP_TRAINING_WIN_RATE:
        # Not convincingly better, but not worse: leave both files untouched
        print(f"{100*opp_win_rate:.1f}% Continuing to train current weights")

    else:
        # Worse than the checkpoint: discard tmp by restoring the checkpoint over it
        shutil.copy(CHECKPOINT_PATH, TMP_MODEL_PATH)
        print(f"{100*opp_win_rate:.1f}% Rejected new model")

Iteration 0


tmp vs. tmp: 100%|██████████| 128/128 [00:10<00:00, 12.44it/s, 46.9% WIN]


Train on 1529 samples
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


tmp vs. checkpoint: 100%|██████████| 128/128 [00:13<00:00,  9.36it/s, 94.1% WIN]


94.1% Accepted new model


tmp vs. random: 100%|██████████| 128/128 [00:10<00:00, 12.17it/s, 87.5% WIN]
tmp vs. greedy: 100%|██████████| 128/128 [00:16<00:00,  7.71it/s, 34.4% WIN]


34.4%G 87.5%R
Iteration 1


tmp vs. tmp: 100%|██████████| 128/128 [00:16<00:00,  7.77it/s, 49.6% WIN]


Train on 2972 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:18<00:00,  6.90it/s, 54.3% WIN]


54.3% Continuing to train current weights
Iteration 2


tmp vs. tmp: 100%|██████████| 128/128 [00:16<00:00,  7.90it/s, 58.6% WIN]


Train on 2731 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:17<00:00,  7.33it/s, 65.6% WIN]


65.6% Accepted new model


tmp vs. random: 100%|██████████| 128/128 [00:11<00:00, 11.62it/s, 94.5% WIN]
tmp vs. greedy: 100%|██████████| 128/128 [00:15<00:00,  8.20it/s, 54.7% WIN]


54.7%G 94.5%R
Iteration 3


tmp vs. tmp: 100%|██████████| 128/128 [00:15<00:00,  8.17it/s, 43.4% WIN]


Train on 2515 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:12<00:00, 10.36it/s, 9.0% WIN]


9.0% Rejected new model
Iteration 4


tmp vs. tmp: 100%|██████████| 128/128 [00:15<00:00,  8.26it/s, 58.6% WIN]


Train on 2367 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:17<00:00,  7.15it/s, 60.5% WIN]


60.5% Accepted new model


tmp vs. random: 100%|██████████| 128/128 [00:11<00:00, 11.43it/s, 98.8% WIN]
tmp vs. greedy: 100%|██████████| 128/128 [00:16<00:00,  7.78it/s, 47.3% WIN]


47.3%G 98.8%R
Iteration 5


tmp vs. tmp: 100%|██████████| 128/128 [00:16<00:00,  7.80it/s, 53.9% WIN]


Train on 2966 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:18<00:00,  6.76it/s, 70.7% WIN]


70.7% Accepted new model


tmp vs. random: 100%|██████████| 128/128 [00:12<00:00, 10.59it/s, 95.7% WIN]
tmp vs. greedy: 100%|██████████| 128/128 [00:16<00:00,  7.80it/s, 68.0% WIN]


68.0%G 95.7%R
Iteration 6


tmp vs. tmp: 100%|██████████| 128/128 [00:16<00:00,  7.61it/s, 52.3% WIN]


Train on 3140 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:20<00:00,  6.40it/s, 46.9% WIN]


46.9% Rejected new model
Iteration 7


tmp vs. tmp: 100%|██████████| 128/128 [00:17<00:00,  7.15it/s, 43.8% WIN]


Train on 3066 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:19<00:00,  6.57it/s, 45.7% WIN]


45.7% Rejected new model
Iteration 8


tmp vs. tmp: 100%|██████████| 128/128 [00:18<00:00,  7.03it/s, 55.9% WIN]


Train on 3130 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:19<00:00,  6.44it/s, 55.9% WIN]


55.9% Continuing to train current weights
Iteration 9


tmp vs. tmp: 100%|██████████| 128/128 [00:17<00:00,  7.27it/s, 58.6% WIN]


Train on 2778 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:20<00:00,  6.20it/s, 48.0% WIN]


48.0% Rejected new model
Iteration 10


tmp vs. tmp: 100%|██████████| 128/128 [00:18<00:00,  6.77it/s, 44.1% WIN]


Train on 3262 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:20<00:00,  6.32it/s, 51.6% WIN]


51.6% Continuing to train current weights
Iteration 11


tmp vs. tmp: 100%|██████████| 128/128 [00:17<00:00,  7.28it/s, 46.9% WIN]


Train on 2868 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:19<00:00,  6.57it/s, 52.0% WIN]


52.0% Continuing to train current weights
Iteration 12


tmp vs. tmp: 100%|██████████| 128/128 [00:18<00:00,  6.88it/s, 52.3% WIN]


Train on 3118 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:19<00:00,  6.58it/s, 55.9% WIN]


55.9% Continuing to train current weights
Iteration 13


tmp vs. tmp: 100%|██████████| 128/128 [00:17<00:00,  7.26it/s, 51.6% WIN]


Train on 2805 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:20<00:00,  6.22it/s, 47.3% WIN]


47.3% Rejected new model
Iteration 14


tmp vs. tmp: 100%|██████████| 128/128 [00:18<00:00,  7.06it/s, 44.5% WIN]


Train on 3125 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:19<00:00,  6.54it/s, 52.7% WIN]


52.7% Continuing to train current weights
Iteration 15


tmp vs. tmp: 100%|██████████| 128/128 [00:17<00:00,  7.14it/s, 50.8% WIN]


Train on 2989 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:19<00:00,  6.45it/s, 59.4% WIN]


59.4% Continuing to train current weights
Iteration 16


tmp vs. tmp: 100%|██████████| 128/128 [00:18<00:00,  6.90it/s, 54.7% WIN]


Train on 3220 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:19<00:00,  6.52it/s, 65.2% WIN]


65.2% Accepted new model


tmp vs. random: 100%|██████████| 128/128 [00:11<00:00, 10.80it/s, 98.0% WIN]
tmp vs. greedy: 100%|██████████| 128/128 [00:16<00:00,  7.81it/s, 69.5% WIN]


69.5%G 98.0%R
Iteration 17


tmp vs. tmp: 100%|██████████| 128/128 [00:17<00:00,  7.24it/s, 44.5% WIN]


Train on 3172 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:20<00:00,  6.23it/s, 39.8% WIN]


39.8% Rejected new model
Iteration 18


tmp vs. tmp: 100%|██████████| 128/128 [00:17<00:00,  7.37it/s, 49.2% WIN]


Train on 3051 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:19<00:00,  6.62it/s, 54.3% WIN]


54.3% Continuing to train current weights
Iteration 19


tmp vs. tmp: 100%|██████████| 128/128 [00:18<00:00,  7.06it/s, 57.8% WIN]


Train on 3141 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:19<00:00,  6.48it/s, 40.6% WIN]


40.6% Rejected new model
Iteration 20


tmp vs. tmp: 100%|██████████| 128/128 [00:19<00:00,  6.63it/s, 47.7% WIN]


Train on 3195 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:20<00:00,  6.39it/s, 58.2% WIN]


58.2% Continuing to train current weights
Iteration 21


tmp vs. tmp: 100%|██████████| 128/128 [00:20<00:00,  6.36it/s, 52.7% WIN]


Train on 3179 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:17<00:00,  7.46it/s, 28.5% WIN]


28.5% Rejected new model
Iteration 22


tmp vs. tmp: 100%|██████████| 128/128 [00:18<00:00,  6.96it/s, 48.4% WIN]


Train on 3027 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:19<00:00,  6.43it/s, 62.5% WIN]


62.5% Accepted new model


tmp vs. random: 100%|██████████| 128/128 [00:11<00:00, 10.74it/s, 97.7% WIN]
tmp vs. greedy: 100%|██████████| 128/128 [00:16<00:00,  7.62it/s, 79.3% WIN]


79.3%G 97.7%R
Iteration 23


tmp vs. tmp: 100%|██████████| 128/128 [00:20<00:00,  6.15it/s, 51.6% WIN]


Train on 2913 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:22<00:00,  5.63it/s, 41.0% WIN]


41.0% Rejected new model
Iteration 24


tmp vs. tmp: 100%|██████████| 128/128 [00:23<00:00,  5.43it/s, 51.6% WIN]


Train on 2981 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:25<00:00,  4.94it/s, 37.1% WIN]


37.1% Rejected new model
Iteration 25


tmp vs. tmp: 100%|██████████| 128/128 [00:20<00:00,  6.26it/s, 60.5% WIN]


Train on 2889 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:23<00:00,  5.53it/s, 36.7% WIN]


36.7% Rejected new model
Iteration 26


tmp vs. tmp: 100%|██████████| 128/128 [00:19<00:00,  6.51it/s, 47.3% WIN]


Train on 2928 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:21<00:00,  5.97it/s, 19.5% WIN]


19.5% Rejected new model
Iteration 27


tmp vs. tmp: 100%|██████████| 128/128 [00:22<00:00,  5.68it/s, 45.7% WIN]


Train on 2923 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:24<00:00,  5.33it/s, 43.0% WIN]


43.0% Rejected new model
Iteration 28


tmp vs. tmp: 100%|██████████| 128/128 [00:18<00:00,  7.08it/s, 45.7% WIN]


Train on 2824 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:24<00:00,  5.16it/s, 39.5% WIN]


39.5% Rejected new model
Iteration 29


tmp vs. tmp: 100%|██████████| 128/128 [00:22<00:00,  5.71it/s, 50.0% WIN]


Train on 2923 samples


tmp vs. checkpoint: 100%|██████████| 128/128 [00:22<00:00,  5.63it/s, 45.7% WIN]


45.7% Rejected new model
Iteration 30


tmp vs. tmp: 100%|██████████| 128/128 [00:20<00:00,  6.14it/s, 53.9% WIN]


Train on 2919 samples


tmp vs. checkpoint:  48%|████▊     | 61/128 [00:18<00:08,  8.05it/s, 36.1% WIN]

# Evaluate

Play against our AI

In [None]:
# Play a single game: the accepted checkpoint policy against the 'human' policy.
# NOTE(review): presumably the human policy prompts for moves interactively — confirm in go_ai.policies.
data.make_episodes(checkpoint_policy_args, human_policy_args, 1)