Skip to content

Commit b04fbce

Browse files
feat(a3c): replace set_weights/get_weights with thinc from_bytes/to_bytes
- this ended up being quite painful because of thread synchronization troubles. - it turns out that `keras_model.predict` is not really thread-safe if you're passing references to models from other threads. - remove all the handshake nonsense that is now handled by thinc internally to get the subclassed models ready for training/inference.
1 parent 1a28a9d commit b04fbce

File tree

2 files changed

+16
-53
lines changed

2 files changed

+16
-53
lines changed

libraries/mathy_python/mathy/agents/a3c/worker.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import gc
21
import math
32
import os
43
import queue
@@ -30,8 +29,6 @@
3029
from .config import A3CConfig
3130
from .util import record, truncate
3231

33-
gc.set_debug(gc.DEBUG_UNCOLLECTABLE | gc.DEBUG_SAVEALL)
34-
3532

3633
class A3CWorker(threading.Thread):
3734

@@ -466,11 +463,11 @@ def update_global_network(
466463

467464
self.optimizer.apply_gradients(zipped_gradients)
468465
# Update local model with new weights
469-
# TODO: This fails with a thread local error @honnibal
470-
# self.local_model.from_bytes(self.global_model.to_bytes())
471-
self.local_model.unwrapped.set_weights(
472-
self.global_model.unwrapped.get_weights()
473-
)
466+
self.local_model.from_bytes(self.global_model.to_bytes())
467+
# self.local_model.unwrapped.set_weights(
468+
# self.global_model.unwrapped.get_weights()
469+
# )
470+
474471
episode_memory.clear()
475472

476473
def finish_episode(self, episode_reward, episode_steps, last_state: MathyEnvState):

libraries/mathy_python/mathy/agents/policy_value_model.py

Lines changed: 11 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@
3939

4040

4141
@keras_subclass(
42-
"TFPVModel.v0", X=eg.to_inputs(), Y=eg.mask, input_shape=eg.to_input_shapes()
42+
"TFPVModel.v0",
43+
X=eg.to_inputs(as_tf_tensor=False),
44+
Y=eg.mask,
45+
input_shape=eg.to_input_shapes(),
46+
compile_args={"loss": "binary_crossentropy", "metrics": ["accuracy"]},
4347
)
4448
class TFPVModel(tf.keras.Model):
4549
args: BaseConfig
@@ -183,23 +187,16 @@ def PolicyValueModel(
183187
tf_model = TFPVModel(args, predictions, **kwargs)
184188
return TensorFlowWrapper(
185189
tf_model,
186-
build_model=True,
187-
input_shape=eg.to_input_shapes(),
188190
model_class=ThincPolicyValueModel,
189191
model_name="agent",
192+
optimizer=tf_model.optimizer,
190193
)
191194

192195

193196
def _load_model(
194-
model: ThincPolicyValueModel,
195-
model_file: str,
196-
optimizer_file: str,
197-
build_fn: Callable[[ThincPolicyValueModel], None] = None,
197+
model: ThincPolicyValueModel, model_file: str, optimizer_file: str,
198198
) -> ThincPolicyValueModel:
199199
model.from_disk(model_file)
200-
if build_fn is not None:
201-
build_fn(model)
202-
model.unwrapped._make_train_function()
203200
with open(optimizer_file, "rb") as f:
204201
weight_values = pickle.load(f)
205202
model.unwrapped.optimizer.set_weights(weight_values)
@@ -248,28 +245,15 @@ def get_or_create_policy_model(
248245
model = PolicyValueModel(args=args, predictions=predictions, name="agent")
249246
init_inputs = initial_state.to_inputs()
250247

251-
def handshake_keras(m: ThincPolicyValueModel):
252-
253-
m.unwrapped.compile(
254-
optimizer=m.unwrapped.optimizer,
255-
loss="binary_crossentropy",
256-
metrics=["accuracy"],
257-
)
258-
m.unwrapped.build(initial_state.to_input_shapes())
259-
m.unwrapped.predict(init_inputs)
260-
m.predict_next(init_inputs)
261-
262-
handshake_keras(model)
263-
264248
opt = f"{model_path}.optimizer"
265249
mod = f"{model_path}.bytes"
266250
if os.path.exists(mod):
267251
if is_main and args.verbose:
268252
with msg.loading(f"Loading model: {mod}..."):
269-
_load_model(model, mod, opt, build_fn=handshake_keras)
253+
_load_model(model, mod, opt)
270254
msg.good(f"Loaded model: {mod}")
271255
else:
272-
_load_model(model, mod, opt, build_fn=handshake_keras)
256+
_load_model(model, mod, opt)
273257

274258
# If we're doing transfer, reset optimizer steps
275259
if is_main and args.init_model_from is not None:
@@ -303,35 +287,17 @@ def load_policy_value_model(
303287
raise ValueError(f"model not found: {model_file}")
304288
if not optimizer_file.exists():
305289
raise ValueError(f"optimizer not found: {optimizer_file}")
306-
307290
env: MathyEnv = PolySimplify()
308291
observation: MathyObservation = env.state_to_observation(
309292
env.get_initial_state()[0], rnn_size=args.lstm_units
310293
)
311294
initial_state: MathyWindowObservation = observations_to_window([observation])
312295
model = PolicyValueModel(args=args, predictions=env.action_size, name="agent")
313296
init_inputs = initial_state.to_inputs()
314-
315-
def handshake_keras(m: ThincPolicyValueModel):
316-
317-
m.unwrapped.compile(
318-
optimizer=m.unwrapped.optimizer,
319-
loss="binary_crossentropy",
320-
metrics=["accuracy"],
321-
)
322-
m.unwrapped.build(initial_state.to_input_shapes())
323-
m.unwrapped.predict(init_inputs)
324-
m.predict_next(init_inputs)
325-
326-
handshake_keras(model)
327297
if not silent:
328298
with msg.loading(f"Loading model: {model_file}..."):
329-
_load_model(
330-
model, str(model_file), str(optimizer_file), build_fn=handshake_keras
331-
)
299+
_load_model(model, str(model_file), str(optimizer_file))
332300
msg.good(f"Loaded model: {model_file}")
333301
else:
334-
_load_model(
335-
model, str(model_file), str(optimizer_file), build_fn=handshake_keras
336-
)
302+
_load_model(model, str(model_file), str(optimizer_file))
337303
return model, args

0 commit comments

Comments
 (0)