Concurrency (#19)
A rebase of my previous `more-concurrency` branch, which I accidentally messed up.

@Akababa I may have simplified things after all, by using the (traditional) `concurrent.futures.as_completed` in `self_play.py`. The idea is that since self-play has to be halted _anyway_ while the model is refreshed (bad things happened when I didn't do this), you might as well schedule `nb_game_in_file` futures up front and let them run out the old way, instead of continuously rolling futures over. Then flush, reload, and repeat.
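For concreteness, a minimal sketch of that scheduling pattern. Everything here except `as_completed` and the `nb_game_in_file` idea is a stand-in; the real `self_play.py` wires this through the repo's config and its own worker functions:

```python
from concurrent.futures import ProcessPoolExecutor, as_completed

NB_GAME_IN_FILE = 100   # stand-in for config.play_data.nb_game_in_file
MAX_WORKERS = 8         # stand-in for a max-processes config entry

def play_one_game(game_idx):
    # Stand-in for a full self-play game; would return its moves and result.
    return f"game-{game_idx}"

def self_play_loop():
    with ProcessPoolExecutor(max_workers=MAX_WORKERS) as executor:
        while True:  # one iteration per play-data file
            # Schedule a full file's worth of games up front...
            futures = [executor.submit(play_one_game, i)
                       for i in range(NB_GAME_IN_FILE)]
            # ...and let them all run out, rather than rolling futures over.
            finished = [f.result() for f in as_completed(futures)]
            print(f"flushing {len(finished)} games; reloading model")
            # <- here: flush the file to disk and refresh the model weights,
            #    while no games are in flight.
```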
benediamond committed Dec 26, 2017
1 parent 588b4d4 commit f9d30b3
Showing 16 changed files with 455 additions and 490 deletions.
48 changes: 32 additions & 16 deletions src/chess_zero/agent/api_chess.py
@@ -1,23 +1,39 @@
 from chess_zero.config import Config
 
+from threading import Thread
+import numpy as np
+import multiprocessing as mp
+from multiprocessing import connection
+import time
+
 class ChessModelAPI:
-    def __init__(self, config: Config, model):
-        self.config = config
+    def __init__(self, model):
         self.model = model
+        self.pipes = []
 
-    def predict(self, x):
-        assert x.ndim in (3, 4)
-        input_stack_height = self.config.model.input_stack_height
-        assert x.shape == (input_stack_height, 8, 8) or x.shape[1:] == (input_stack_height, 8, 8)  # should I get rid of these assertions...? they will change.
-        orig_x = x
-        if x.ndim == 3:
-            x = x.reshape(1, input_stack_height, 8, 8)
+    def start(self):
+        prediction_worker = Thread(target=self.predict_batch_worker, name="prediction_worker")
+        prediction_worker.daemon = True
+        prediction_worker.start()
 
-        with self.model.graph.as_default():
-            policy, value = self.model.model.predict_on_batch(x)
+    def get_pipe(self):
+        me, you = mp.Pipe()
+        self.pipes.append(me)
+        return you
 
-        if orig_x.ndim == 3:
-            return policy[0], value[0]
-        else:
-            return policy, value
+    def predict_batch_worker(self):
+        while True:
+            ready = mp.connection.wait(self.pipes, timeout=0.001)
+            if not ready:
+                continue
+            data, result_pipes = [], []
+            for pipe in ready:
+                while pipe.poll():
+                    data.append(pipe.recv())
+                    result_pipes.append(pipe)
+            if not data:
+                continue
+            data = np.asarray(data, dtype=np.float32)
+            with self.model.graph.as_default():
+                policy_ary, value_ary = self.model.model.predict_on_batch(data)
+            for pipe, p, v in zip(result_pipes, policy_ary, value_ary):
+                pipe.send((p, float(v)))
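For orientation, this is how a client of one of these pipes would talk to the batching worker. A hedged sketch: the variable names and the 18-plane stack height are assumptions; only the send/recv protocol comes from the code above.

```python
import numpy as np

# `model` is a loaded ChessModel; get_pipes() is added in model_chess.py below.
pipe = model.get_pipes(num=1)[0]

planes = np.zeros((18, 8, 8), dtype=np.float32)  # assuming an 18-plane input stack
pipe.send(planes)              # picked up by predict_batch_worker via mp.connection.wait
policy, value = pipe.recv()    # mirrors pipe.send((p, float(v))) on the server side
```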
34 changes: 18 additions & 16 deletions src/chess_zero/agent/model_chess.py
@@ -4,15 +4,15 @@
 from logging import getLogger
 # noinspection PyPep8Naming
 
-import tensorflow as tf
+from tensorflow import get_default_graph
 from keras.engine.topology import Input
 from keras.engine.training import Model
 from keras.layers.convolutional import Conv2D
 from keras.layers.core import Activation, Dense, Flatten
 from keras.layers.merge import Add
 from keras.layers.normalization import BatchNormalization
-from keras.losses import categorical_crossentropy, mean_squared_error
 from keras.regularizers import l2
+from chess_zero.agent.api_chess import ChessModelAPI
 
 from chess_zero.config import Config
 
@@ -23,14 +23,21 @@ class ChessModel:
     def __init__(self, config: Config):
         self.config = config
         self.model = None  # type: Model
-        self.digest = None
+        self.graph = None
+        self.digest = None
+        self.api = None
+
+    def get_pipes(self, num=1):
+        if self.api is None:
+            self.api = ChessModelAPI(self)
+            self.api.start()
+        return [self.api.get_pipe() for _ in range(num)]
 
     def build(self):
         mc = self.config.model
         in_x = x = Input((mc.input_stack_height, 8, 8))
         # (batch, channels, height, width)
-        x = Conv2D(filters=mc.cnn_filter_num, kernel_size=mc.cnn_filter_size, padding="same", data_format="channels_first", kernel_initializer='glorot_normal', bias_initializer='zeros', kernel_regularizer=l2(mc.l2_reg), input_shape=(mc.input_stack_height, 8, 8))(x)
+        x = Conv2D(filters=mc.cnn_filter_num, kernel_size=mc.cnn_filter_size, padding="same", data_format="channels_first", use_bias=False, kernel_initializer='glorot_normal', bias_initializer='zeros', kernel_regularizer=l2(mc.l2_reg), input_shape=(mc.input_stack_height, 8, 8))(x)
         x = BatchNormalization(axis=1)(x)
         x = Activation("relu")(x)
@@ -39,30 +46,31 @@ def build(self):

         res_out = x
         # for policy output
-        x = Conv2D(filters=2, kernel_size=1, data_format="channels_first", kernel_initializer='glorot_normal', bias_initializer='zeros', kernel_regularizer=l2(mc.l2_reg))(res_out)
+        x = Conv2D(filters=2, kernel_size=1, data_format="channels_first", use_bias=False, kernel_initializer='glorot_normal', bias_initializer='zeros', kernel_regularizer=l2(mc.l2_reg))(res_out)
         x = BatchNormalization(axis=1)(x)
         x = Activation("relu")(x)
         x = Flatten()(x)
         # no output for 'pass'
         policy_out = Dense(self.config.n_labels, kernel_initializer='glorot_normal', bias_initializer='zeros', kernel_regularizer=l2(mc.l2_reg), activation="softmax", name="policy_out")(x)
 
         # for value output
-        x = Conv2D(filters=1, kernel_size=1, data_format="channels_first", kernel_initializer='glorot_normal', bias_initializer='zeros', kernel_regularizer=l2(mc.l2_reg))(res_out)
+        x = Conv2D(filters=1, kernel_size=1, data_format="channels_first", use_bias=False, kernel_initializer='glorot_normal', bias_initializer='zeros', kernel_regularizer=l2(mc.l2_reg))(res_out)
         x = BatchNormalization(axis=1)(x)
         x = Activation("relu")(x)
         x = Flatten()(x)
         x = Dense(mc.value_fc_size, kernel_initializer='glorot_normal', bias_initializer='zeros', kernel_regularizer=l2(mc.l2_reg), activation="relu")(x)
         value_out = Dense(1, kernel_initializer='glorot_normal', bias_initializer='zeros', kernel_regularizer=l2(mc.l2_reg), activation="tanh", name="value_out")(x)
 
         self.model = Model(in_x, [policy_out, value_out], name="chess_model")
+        self.graph = get_default_graph()
 
     def _build_residual_block(self, x):
         mc = self.config.model
         in_x = x
-        x = Conv2D(filters=mc.cnn_filter_num, kernel_size=mc.cnn_filter_size, padding="same", data_format="channels_first", kernel_initializer='glorot_normal', bias_initializer='zeros', kernel_regularizer=l2(mc.l2_reg))(x)
+        x = Conv2D(filters=mc.cnn_filter_num, kernel_size=mc.cnn_filter_size, padding="same", data_format="channels_first", use_bias=False, kernel_initializer='glorot_normal', bias_initializer='zeros', kernel_regularizer=l2(mc.l2_reg))(x)
         x = BatchNormalization(axis=1)(x)
         x = Activation("relu")(x)
-        x = Conv2D(filters=mc.cnn_filter_num, kernel_size=mc.cnn_filter_size, padding="same", data_format="channels_first", kernel_initializer='glorot_normal', bias_initializer='zeros', kernel_regularizer=l2(mc.l2_reg))(x)
+        x = Conv2D(filters=mc.cnn_filter_num, kernel_size=mc.cnn_filter_size, padding="same", data_format="channels_first", use_bias=False, kernel_initializer='glorot_normal', bias_initializer='zeros', kernel_regularizer=l2(mc.l2_reg))(x)
         x = BatchNormalization(axis=1)(x)
         x = Add()([in_x, x])
         x = Activation("relu")(x)
@@ -82,6 +90,8 @@ def load(self, config_path, weight_path):
         with open(config_path, "rt") as f:
             self.model = Model.from_config(json.load(f))
         self.model.load_weights(weight_path)
+        self.graph = get_default_graph()
+        # self.model._make_predict_function()
         self.digest = self.fetch_digest(weight_path)
         logger.debug(f"loaded model digest = {self.digest}")
         return True
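A note on the two lines added here (and the matching `self.graph = get_default_graph()` in `build`): with Keras on the TensorFlow 1.x backend, a model can only be evaluated under the graph it was created in, so the handle captured at build/load time is what the background prediction thread re-enters before each batched call, as in this illustrative fragment:

```python
# Illustrative fragment, not part of the diff: predictions issued from
# the prediction thread must run under the captured graph, otherwise
# TF1 raises "Tensor ... is not an element of this graph".
with model.graph.as_default():
    policy_ary, value_ary = model.model.predict_on_batch(batch)
```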
@@ -96,11 +106,3 @@ def save(self, config_path, weight_path):
         self.model.save_weights(weight_path)
         self.digest = self.fetch_digest(weight_path)
         logger.debug(f"saved model digest {self.digest}")
-
-
-def loss_function_for_policy(y_true, y_pred):
-    return categorical_crossentropy(y_true, y_pred)
-
-
-def loss_function_for_value(y_true, y_pred):
-    return mean_squared_error(y_true, y_pred)
(Diffs for the remaining 14 changed files are not shown.)
