Skip to content

Commit 47ad597

Browse files
feat(config): add print_model_call_times option
1 parent 91e2240 commit 47ad597

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed

libraries/mathy_python/mathy/agents/base_config.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,8 @@ class Config:
4646
# Verbose setting to print out worker_0 training steps. Useful for trying
4747
# to find problems.
4848
print_training: bool = False
49+
# This is very verbose and prints every policy_value_model.call time
50+
print_model_call_times: bool = False
4951
# Print mode for output. "terminal" is the default, also supports "attention"
5052
# NOTE: attention is gone (like... the layer)
5153
print_mode: str = "terminal"

libraries/mathy_python/mathy/agents/policy_value_model.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,7 @@ def compute_output_shape(
7676
def call(
7777
self, features_window: MathyInputsType, apply_mask=True
7878
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
79-
call_print = False
79+
call_print = self.args.print_model_call_times
8080
nodes = features_window[ObservationFeatureIndices.nodes]
8181
batch_size = (
8282
len(nodes)
@@ -123,7 +123,7 @@ def call_graph(
123123
return self.call(inputs)
124124

125125
def predict_next(
126-
self, inputs: MathyInputsType, use_graph=False
126+
self, inputs: MathyInputsType, use_graph: bool = False
127127
) -> Tuple[tf.Tensor, tf.Tensor]:
128128
"""Predict one probability distribution and value for the
129129
given sequence of inputs """

libraries/mathy_python/mathy/cli.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -149,7 +149,6 @@ def cli_print_problems(environment: str, difficulty: str, number: int):
149149
@click.option(
150150
"use_lstm",
151151
"--use-lstm",
152-
default=True,
153152
type=bool,
154153
help="Whether to use the recurrent architecture or not",
155154
)
@@ -260,10 +259,11 @@ def cli_train(
260259
num_workers=workers,
261260
profile=profile,
262261
print_training=show,
263-
use_lstm=use_lstm,
264262
)
265263
if episodes is not None:
266264
args.max_eps = episodes
265+
if use_lstm is not None:
266+
args.use_lstm = use_lstm
267267
instance = A3CAgent(args)
268268
instance.train()
269269
elif agent == "zero":
@@ -285,8 +285,9 @@ def cli_train(
285285
self_play_problems=self_play_problems,
286286
print_training=show,
287287
profile=profile,
288-
use_lstm=use_lstm,
289288
)
289+
if use_lstm is not None:
290+
self_play_cfg.use_lstm = use_lstm
290291
if episodes is not None:
291292
self_play_cfg.max_eps = episodes
292293

0 commit comments

Comments
 (0)