Skip to content

Commit 92695d6

Browse files
feat: add mathy.example helper for generating inputs
- It uses a random problem from the PolySimplify environment and returns a window observation, which can be used to derive X/Y/shape and related values.
1 parent 0c76609 commit 92695d6

File tree

10 files changed

+217
-72
lines changed

10 files changed

+217
-72
lines changed

libraries/mathy_python/mathy/agents/a3c/worker.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import gc
12
import math
23
import os
34
import queue
@@ -9,10 +10,9 @@
910
import gym
1011
import numpy as np
1112
import tensorflow as tf
13+
from memory_profiler import profile
1214
from wasabi import msg
1315

14-
from ...util import print_error
15-
1616
from ...envs.gym.mathy_gym_env import MathyGymEnv
1717
from ...state import (
1818
MathyEnvState,
@@ -22,7 +22,7 @@
2222
observations_to_window,
2323
)
2424
from ...teacher import Teacher
25-
from ...util import calculate_grouping_control_signal, discount
25+
from ...util import calculate_grouping_control_signal, discount, print_error
2626
from .. import action_selectors
2727
from ..episode_memory import EpisodeMemory
2828
from ..mcts import MCTS
@@ -31,6 +31,8 @@
3131
from .config import A3CConfig
3232
from .util import record, truncate
3333

34+
gc.set_debug(gc.DEBUG_UNCOLLECTABLE | gc.DEBUG_SAVEALL)
35+
3436

3537
class A3CWorker(threading.Thread):
3638

@@ -242,7 +244,7 @@ def run_episode(self, episode_memory: EpisodeMemory) -> float:
242244
seq_start = env.state.to_start_observation(rnn_state_h, rnn_state_c)
243245
try:
244246
window = observations_to_window([seq_start, last_observation])
245-
selector.model([window.to_inputs()], is_train=True)
247+
selector.model([window.to_inputs()], is_train=False)
246248
except BaseException as err:
247249
print_error(
248250
err, f"Episode begin thinking steps prediction failed.",
@@ -462,7 +464,11 @@ def update_global_network(
462464

463465
self.optimizer.apply_gradients(zipped_gradients)
464466
# Update local model with new weights
465-
self.local_model.unwrapped.set_weights(self.global_model.unwrapped.get_weights())
467+
# TODO: This fails with a thread local error @honnibal
468+
# self.local_model.from_bytes(self.global_model.to_bytes())
469+
self.local_model.unwrapped.set_weights(
470+
self.global_model.unwrapped.get_weights()
471+
)
466472
episode_memory.clear()
467473

468474
def finish_episode(self, episode_reward, episode_steps, last_state: MathyEnvState):
@@ -519,7 +525,7 @@ def compute_policy_value_loss(
519525
bootstrap_value = 0.0 # terminal
520526
else:
521527
# Predict the reward using the local network
522-
_, values, _ = self.local_model.unwrapped.call(
528+
_, values, _ = self.local_model.predict(
523529
observations_to_window([observation]).to_inputs()
524530
)
525531
# Select the last timestep
@@ -538,11 +544,10 @@ def compute_policy_value_loss(
538544
batch_size = len(episode_memory.actions)
539545
sequence_length = len(episode_memory.observations[0].nodes)
540546
inputs = episode_memory.to_episode_window().to_inputs()
541-
logits, values, trimmed_logits = self.local_model.unwrapped(inputs, apply_mask=False)
547+
logits, values, trimmed_logits = self.local_model.unwrapped(inputs)
542548
# TODO: don't call unwrapped here
543549

544550
logits = tf.reshape(logits, [batch_size, -1])
545-
policy_logits = tf.reshape(trimmed_logits, [batch_size, -1])
546551

547552
# Calculate entropy and policy loss
548553
h_loss = discrete_policy_entropy_loss(
@@ -570,7 +575,7 @@ def compute_policy_value_loss(
570575

571576
# Policy Loss
572577
policy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
573-
labels=episode_memory.actions, logits=policy_logits
578+
labels=episode_memory.actions, logits=logits
574579
)
575580

576581
policy_loss *= advantage

libraries/mathy_python/mathy/agents/action_selectors.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
class ActionSelector:
1212
"""An episode-specific selector of actions"""
1313

14+
model: ThincPolicyValueModel
15+
worker_id: int
16+
episode: int
17+
1418
def __init__(self, *, model: ThincPolicyValueModel, episode: int, worker_id: int):
1519
self.model = model
1620
self.worker_id = worker_id
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from typing import Optional
2+
3+
from ..envs import PolySimplify
4+
from ..state import MathyWindowObservation, observations_to_window
5+
from .base_config import BaseConfig
6+
7+
8+
def example() -> MathyWindowObservation:
9+
"""Helper to return a random Window observation that can be
10+
passed forward through a Mathy model. """
11+
env = PolySimplify()
12+
state = env.get_initial_state()[0]
13+
observation = env.state_to_observation(state, rnn_size=BaseConfig().lstm_units)
14+
return observations_to_window([observation])

libraries/mathy_python/mathy/agents/policy_value_model.py

Lines changed: 53 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,28 @@
11
import os
22
import pickle
33
import time
4-
from shutil import copyfile
5-
from typing import Any, Dict, Optional, Tuple, List, Callable
64
from pathlib import Path
5+
from shutil import copyfile
6+
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
7+
from memory_profiler import profile
8+
79
import numpy as np
810
import srsly
911
import tensorflow as tf
12+
import thinc
1013
from tensorflow.keras import backend as K
14+
from thinc.api import TensorFlowWrapper, keras_subclass, tensorflow2xp, xp2tensorflow
15+
from thinc.backends import Ops, get_current_ops
16+
from thinc.layers import Linear
17+
from thinc.model import Model
18+
from thinc.optimizers import Adam
19+
from thinc.shims.tensorflow import TensorFlowShim
20+
from thinc.types import Array, Array1d, Array2d, ArrayNd
21+
from thinc.util import has_tensorflow, to_categorical
1122
from wasabi import msg
1223

24+
from mathy.agents.example import example
25+
1326
from ..env import MathyEnv
1427
from ..envs import PolySimplify
1528
from ..state import (
@@ -22,17 +35,13 @@
2235
from ..util import print_error
2336
from .base_config import BaseConfig
2437
from .embedding import MathyEmbedding
25-
import thinc
26-
from thinc.layers import Linear
27-
from thinc.api import TensorFlowWrapper, tensorflow2xp, xp2tensorflow
28-
from thinc.backends import Ops, get_current_ops
29-
from thinc.model import Model
30-
from thinc.optimizers import Adam
31-
from thinc.types import Array, Array1d, Array2d, ArrayNd
32-
from thinc.shims.tensorflow import TensorFlowShim
33-
from thinc.util import has_tensorflow, to_categorical
3438

39+
eg = example()
3540

41+
42+
@keras_subclass(
43+
"TFPVModel.v0", X=eg.to_inputs(), Y=eg.mask, input_shape=eg.to_input_shapes()
44+
)
3645
class TFPVModel(tf.keras.Model):
3746
args: BaseConfig
3847
optimizer: tf.optimizers.Optimizer
@@ -83,7 +92,13 @@ def compute_output_shape(
8392
)
8493

8594
def call(
86-
self, features_window: MathyInputsType, apply_mask=True
95+
self, features_window: MathyInputsType
96+
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
97+
return self._call(features_window)
98+
99+
# @profile
100+
def _call(
101+
self, features_window: MathyInputsType
87102
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
88103
call_print = self.args.print_model_call_times
89104
nodes = features_window[ObservationFeatureIndices.nodes]
@@ -99,10 +114,13 @@ def call(
99114
values = self.normalize_v(self.value_logits(self.embedding.state_h))
100115
logits = self.normalize_pi(self.policy_logits(sequence_inputs))
101116
mask_logits = self.apply_pi_mask(logits, features_window)
102-
mask_result = logits if not apply_mask else mask_logits
103117
if call_print is True:
104-
print("call took : {0:03f}".format(time.time() - start))
105-
return logits, values, mask_result
118+
print(
119+
"call took : {0:03f} for batch {1}".format(
120+
time.time() - start, batch_size
121+
)
122+
)
123+
return logits, values, mask_logits
106124

107125
def apply_pi_mask(
108126
self, logits: tf.Tensor, features_window: MathyInputsType,
@@ -125,14 +143,18 @@ def apply_pi_mask(
125143
return negative_mask_logits
126144

127145

128-
class ThincPolicyValueModel(thinc.model.Model[ArrayNd, Tuple[Array1d, Array2d]]):
146+
class ThincPolicyValueModel(
147+
thinc.model.Model[ArrayNd, Tuple[Array2d, Array1d, Array2d]]
148+
):
129149
@property
130150
def unwrapped(self) -> TFPVModel:
131-
tf_shim: TensorFlowShim = self.shims[0]
151+
tf_shim = cast(TensorFlowShim, self.shims[0])
132152
assert isinstance(tf_shim, TensorFlowShim), "only tensorflow shim is supported"
133153
return tf_shim._model
134154

135-
def predict_next(self, inputs: MathyInputsType) -> Tuple[tf.Tensor, tf.Tensor]:
155+
def predict_next(
156+
self, inputs: MathyInputsType, is_train: bool = False
157+
) -> Tuple[tf.Tensor, tf.Tensor]:
136158
"""Predict one probability distribution and value for the
137159
given sequence of inputs """
138160
logits, values, masked = self.unwrapped.call(inputs)
@@ -148,21 +170,22 @@ def save(self) -> None:
148170
model_path = os.path.join(
149171
self.unwrapped.args.model_dir, self.unwrapped.args.model_name
150172
)
173+
save_model_file = f"{model_path}.bytes"
174+
self.to_disk(save_model_file)
151175
with open(f"{model_path}.optimizer", "wb") as f:
152176
pickle.dump(self.unwrapped.optimizer.get_weights(), f)
153-
model_path += ".h5"
154-
self.unwrapped.save_weights(model_path, save_format="keras")
155177
step = self.unwrapped.optimizer.iterations.numpy()
156-
print(f"[save] step({step}) model({model_path})")
178+
print(f"[save] step({step}) model({save_model_file})")
157179

158180

159181
def PolicyValueModel(
160-
args: BaseConfig = None, predictions=2, initial_state: Any = None, **kwargs,
182+
args: BaseConfig = None, predictions=2, **kwargs,
161183
):
162-
tf_model = TFPVModel(args, predictions, initial_state, **kwargs)
184+
tf_model = TFPVModel(args, predictions, **kwargs)
163185
return TensorFlowWrapper(
164186
tf_model,
165-
build_model=False,
187+
build_model=True,
188+
input_shape=eg.to_input_shapes(),
166189
model_class=ThincPolicyValueModel,
167190
model_name="agent",
168191
)
@@ -174,7 +197,7 @@ def _load_model(
174197
optimizer_file: str,
175198
build_fn: Callable[[ThincPolicyValueModel], None] = None,
176199
) -> ThincPolicyValueModel:
177-
model.unwrapped.load_weights(model_file)
200+
model.from_disk(model_file)
178201
if build_fn is not None:
179202
build_fn(model)
180203
model.unwrapped._make_train_function()
@@ -205,8 +228,8 @@ def get_or_create_policy_model(
205228
if is_main and args.init_model_from is not None:
206229
init_model_path = os.path.join(args.init_model_from, args.model_name)
207230
opt = f"{init_model_path}.optimizer"
208-
mod = f"{init_model_path}.h5"
209-
if os.path.exists(f"{model_path}.h5"):
231+
mod = f"{init_model_path}.bytes"
232+
if os.path.exists(f"{model_path}.bytes"):
210233
print_error(
211234
ValueError("Model Exists"),
212235
f"Cannot initialize on top of model: {model_path}",
@@ -215,7 +238,7 @@ def get_or_create_policy_model(
215238
if os.path.exists(opt) and os.path.exists(mod):
216239
print(f"initialize model from: {init_model_path}")
217240
copyfile(opt, f"{model_path}.optimizer")
218-
copyfile(mod, f"{model_path}.h5")
241+
copyfile(mod, f"{model_path}.bytes")
219242
else:
220243
print_error(
221244
ValueError("Model Exists"),
@@ -240,7 +263,7 @@ def handshake_keras(m: ThincPolicyValueModel):
240263
handshake_keras(model)
241264

242265
opt = f"{model_path}.optimizer"
243-
mod = f"{model_path}.h5"
266+
mod = f"{model_path}.bytes"
244267
if os.path.exists(mod):
245268
if is_main and args.verbose:
246269
with msg.loading(f"Loading model: {mod}..."):
@@ -275,7 +298,7 @@ def load_policy_value_model(
275298
if not meta_file.exists():
276299
raise ValueError(f"model meta not found: {meta_file}")
277300
args = BaseConfig(**srsly.read_json(str(meta_file)))
278-
model_file = Path(model_data_folder) / "model.h5"
301+
model_file = Path(model_data_folder) / "model.bytes"
279302
optimizer_file = Path(model_data_folder) / "model.optimizer"
280303
if not model_file.exists():
281304
raise ValueError(f"model not found: {model_file}")

libraries/mathy_python/mathy/agents/zero/trainer.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ...util import discount
2222
from .. import action_selectors
2323
from ..episode_memory import EpisodeMemory
24-
from ..policy_value_model import PolicyValueModel
24+
from ..policy_value_model import ThincPolicyValueModel
2525
from ..trfl import discrete_policy_entropy_loss, td_lambda
2626
from .config import SelfPlayConfig
2727
from .lib.average_meter import AverageMeter
@@ -33,7 +33,7 @@ class SelfPlayTrainer:
3333
args: SelfPlayConfig
3434

3535
def __init__(
36-
self, args: SelfPlayConfig, model: PolicyValueModel, action_size: int,
36+
self, args: SelfPlayConfig, model: ThincPolicyValueModel, action_size: int,
3737
):
3838
super(SelfPlayTrainer, self).__init__()
3939
import tensorflow as tf
@@ -141,13 +141,11 @@ def compute_policy_value_loss(
141141

142142
batch_size = len(inputs.nodes)
143143
step = self.model.optimizer.iterations
144-
logits, values, trimmed_logits = self.model(
145-
inputs.to_inputs(), apply_mask=False
146-
)
144+
logits, values, _ = self.model(inputs.to_inputs())
147145
value_loss = tf.losses.mean_squared_error(
148146
target_v, tf.reshape(values, shape=[-1])
149147
)
150-
policy_logits = tf.reshape(trimmed_logits, [batch_size, -1])
148+
policy_logits = tf.reshape(logits, [batch_size, -1])
151149
policy_logits = policy_logits[:, : target_pi.shape[1]]
152150
policy_loss = tf.nn.softmax_cross_entropy_with_logits(
153151
labels=target_pi, logits=policy_logits

libraries/mathy_python/mathy/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949

5050

5151
REQUIRED_META_KEYS = ["units", "embedding_units", "lstm_units", "version"]
52-
REQUIRED_MODEL_FILES = ["model.h5", "model.optimizer", "model.config.json"]
52+
REQUIRED_MODEL_FILES = ["model.bytes", "model.optimizer", "model.config.json"]
5353

5454

5555
def load_model(name: str, **overrides) -> Mathy:

0 commit comments

Comments (0)