Skip to content

Commit 427352f

Browse files
feat(cli): add --lr for setting adam learning rate
1 parent ed662e3 commit 427352f

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

libraries/mathy_python/mathy/agents/base_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class Config:
3636
init_model_from: Optional[str] = None
3737
train: bool = False
3838
verbose: bool = False
39-
lr: float = 3e-4
39+
lr: float = 1e-3
4040
max_eps: int = 15000
4141
# How often to write histograms to tensorboard (in training steps)
4242
summary_interval: int = 100

libraries/mathy_python/mathy/cli.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,13 @@ def cli_print_problems(environment: str, difficulty: str, number: int):
139139
type=int,
140140
help="Number of dimensions to use for math vectors and model dimensions",
141141
)
142+
@click.option(
143+
"lr",
144+
"--lr",
145+
default=3e-4,
146+
type=float,
147+
help="The learning rate to use with adam optimizer",
148+
)
142149
@click.option(
143150
"embeddings",
144151
"--embeddings",
@@ -220,6 +227,7 @@ def cli_train(
220227
profile: bool,
221228
episodes: int,
222229
mcts_sims: int,
230+
lr: float,
223231
show: bool,
224232
verbose: bool,
225233
training_iterations: int,
@@ -253,6 +261,7 @@ def cli_train(
253261
lstm_units=rnn,
254262
units=units,
255263
embedding_units=embeddings,
264+
lr=lr,
256265
mcts_sims=mcts_sims,
257266
model_dir=folder,
258267
init_model_from=transfer,
@@ -274,6 +283,7 @@ def cli_train(
274283
verbose=verbose,
275284
difficulty=difficulty,
276285
topics=topics_list,
286+
lr=lr,
277287
lstm_units=rnn,
278288
units=units,
279289
embedding_units=embeddings,

0 commit comments

Comments
 (0)