forked from Zeta36/chess-alpha-zero
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
A rebase of my previous more-concurrency branch, which I messed up accidentally. @Akababa I perhaps simplified things after all, by using (the traditional) `concurrent.futures.as_completed` in `self_play.py`. The idea is that since self-play will have to be halted _anyway_ while the model is refreshed (bad things happened if I didn't do this), you might as well just schedule `nb_game_in_file` futures up front and let them run out the old way. (Instead of continuously rolling over futures.) Then flush, reload, and repeat.
- Loading branch information
1 parent
588b4d4
commit f9d30b3
Showing
16 changed files
with
455 additions
and
490 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,39 @@ | ||
from chess_zero.config import Config | ||
|
||
from threading import Thread | ||
import numpy as np | ||
import multiprocessing as mp | ||
from multiprocessing import connection | ||
import time | ||
|
||
class ChessModelAPI: | ||
def __init__(self, config: Config, model): | ||
self.config = config | ||
def __init__(self, model): | ||
self.model = model | ||
self.pipes = [] | ||
|
||
def predict(self, x): | ||
assert x.ndim in (3, 4) | ||
input_stack_height = self.config.model.input_stack_height | ||
assert x.shape == (input_stack_height, 8, 8) or x.shape[1:] == (input_stack_height, 8, 8) # should I get rid of these assertions...? they will change. | ||
orig_x = x | ||
if x.ndim == 3: | ||
x = x.reshape(1, input_stack_height, 8, 8) | ||
def start(self): | ||
prediction_worker = Thread(target=self.predict_batch_worker, name="prediction_worker") | ||
prediction_worker.daemon = True | ||
prediction_worker.start() | ||
|
||
with self.model.graph.as_default(): | ||
policy, value = self.model.model.predict_on_batch(x) | ||
def get_pipe(self): | ||
me, you = mp.Pipe() | ||
self.pipes.append(me) | ||
return you | ||
|
||
if orig_x.ndim == 3: | ||
return policy[0], value[0] | ||
else: | ||
return policy, value | ||
def predict_batch_worker(self): | ||
while True: | ||
ready = mp.connection.wait(self.pipes, timeout=0.001) | ||
if not ready: | ||
continue | ||
data, result_pipes = [], [] | ||
for pipe in ready: | ||
while pipe.poll(): | ||
data.append(pipe.recv()) | ||
result_pipes.append(pipe) | ||
if not data: | ||
continue | ||
data = np.asarray(data, dtype=np.float32) | ||
with self.model.graph.as_default(): | ||
policy_ary, value_ary = self.model.model.predict_on_batch(data) | ||
for pipe, p, v in zip(result_pipes, policy_ary, value_ary): | ||
pipe.send((p, float(v))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.