Skip to content

Commit 684191d

Browse files
feat(a3c): add exponential decay to learning rate
- The model learns well but never seems to improve beyond a certain level of performance. I suspect the higher learning rate (0.01) is good up front but needs to decrease over time.
1 parent 16f86ff commit 684191d

File tree

4 files changed

+13
-21
lines changed

4 files changed

+13
-21
lines changed

libraries/mathy_python/mathy/agents/a3c/worker.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -490,17 +490,13 @@ def compute_policy_value_loss(
490490
# Scale the policy loss down by the seq_len to make invariant to length
491491
total_loss = value_loss + policy_loss + entropy_loss + rp_loss
492492
prefix = self.tb_prefix
493-
tf.summary.scalar(f"{prefix}/policy_loss", data=policy_loss, step=step)
494-
tf.summary.scalar(f"{prefix}/value_loss", data=value_loss, step=step)
495-
tf.summary.scalar(f"{prefix}/entropy_loss", data=entropy_loss, step=step)
496-
tf.summary.scalar(f"{prefix}/rp_loss", data=rp_loss, step=step)
493+
tf.summary.scalar(f"losses/{prefix}/policy_loss", data=policy_loss, step=step)
494+
tf.summary.scalar(f"losses/{prefix}/value_loss", data=value_loss, step=step)
495+
tf.summary.scalar(f"losses/{prefix}/entropy_loss", data=entropy_loss, step=step)
496+
tf.summary.scalar(f"losses/{prefix}/rp_loss", data=rp_loss, step=step)
497497
tf.summary.scalar(
498-
f"{prefix}/advantage", data=tf.reduce_mean(advantage), step=step
498+
f"settings/learning_rate", data=self.optimizer.lr(step), step=step
499499
)
500-
tf.summary.scalar(
501-
f"{prefix}/entropy", data=tf.reduce_mean(h_loss.extra.entropy), step=step
502-
)
503-
504500
return (
505501
(
506502
policy_loss,

libraries/mathy_python/mathy/agents/base_config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ class Config:
3737
train: bool = False
3838
verbose: bool = False
3939
# Initial learning rate that decays over time.
40-
lr: float = 0.01
40+
lr_initial: float = 0.01
41+
lr_decay_steps: int = 1000
42+
lr_decay_rate: float = 0.96
43+
lr_decay_staircase: bool = True
4144
max_eps: int = 15000
4245
# How often to write histograms to tensorboard (in training steps)
4346
summary_interval: int = 100

libraries/mathy_python/mathy/agents/policy_value_model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@ def __init__(
4242
if args is None:
4343
args = BaseConfig()
4444
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
45-
args.lr, decay_steps=100000, decay_rate=0.96, staircase=True
45+
args.lr_initial,
46+
decay_steps=args.lr_decay_steps,
47+
decay_rate=args.lr_decay_rate,
48+
staircase=args.lr_decay_staircase,
4649
)
4750
self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
4851
self.args = args

libraries/mathy_python/mathy/cli.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,6 @@ def cli_print_problems(environment: str, difficulty: str, number: int):
131131
type=int,
132132
help="Number of dimensions to use for math vectors and model dimensions",
133133
)
134-
@click.option(
135-
"lr",
136-
"--lr",
137-
default=0.01,
138-
type=float,
139-
help="The learning rate to use with adam optimizer",
140-
)
141134
@click.option(
142135
"embeddings",
143136
"--embeddings",
@@ -205,7 +198,6 @@ def cli_train(
205198
profile: bool,
206199
episodes: int,
207200
mcts_sims: int,
208-
lr: float,
209201
show: bool,
210202
verbose: bool,
211203
training_iterations: int,
@@ -237,7 +229,6 @@ def cli_train(
237229
topics=topics_list,
238230
units=units,
239231
embedding_units=embeddings,
240-
lr=lr,
241232
mcts_sims=mcts_sims,
242233
model_dir=folder,
243234
init_model_from=transfer,
@@ -257,7 +248,6 @@ def cli_train(
257248
verbose=verbose,
258249
difficulty=difficulty,
259250
topics=topics_list,
260-
lr=lr,
261251
units=units,
262252
embedding_units=embeddings,
263253
mcts_sims=mcts_sims,

0 commit comments

Comments
 (0)