pt: add explicit decay_rate for lr (#3445)
This is for multitask training, where explicitly setting decay_rate is much more convenient for long training runs.
iProzd committed Mar 20, 2024
1 parent be95d09 commit 47366f6
Showing 3 changed files with 96 additions and 14 deletions.
44 changes: 31 additions & 13 deletions deepmd/pt/utils/learning_rate.py
@@ -3,14 +3,35 @@
 
 
 class LearningRateExp:
-    def __init__(self, start_lr, stop_lr, decay_steps, stop_steps, **kwargs):
-        """Construct an exponential-decayed learning rate.
+    def __init__(
+        self,
+        start_lr,
+        stop_lr,
+        decay_steps,
+        stop_steps,
+        decay_rate=None,
+        **kwargs,
+    ):
+        """
+        Construct an exponential-decayed learning rate.
 
-        Args:
-        - start_lr: Initial learning rate.
-        - stop_lr: Learning rate at the last step.
-        - decay_steps: Decay learning rate every N steps.
-        - stop_steps: When is the last step.
+        Parameters
+        ----------
+        start_lr
+            The learning rate at the start of the training.
+        stop_lr
+            The desired learning rate at the end of the training.
+            When decay_rate is explicitly set, this value will serve as
+            the minimum learning rate during training. In other words,
+            if the learning rate decays below stop_lr, stop_lr will be applied instead.
+        decay_steps
+            The learning rate is decaying every this number of training steps.
+        stop_steps
+            The total training steps for learning rate scheduler.
+        decay_rate
+            The decay rate for the learning rate.
+            If provided, the decay rate will be set instead of
+            calculating it through interpolation between start_lr and stop_lr.
         """
         self.start_lr = start_lr
         default_ds = 100 if stop_steps // 10 > 100 else stop_steps // 100 + 1
@@ -20,12 +41,9 @@ def __init__(self, start_lr, stop_lr, decay_steps, stop_steps, **kwargs):
         self.decay_rate = np.exp(
             np.log(stop_lr / self.start_lr) / (stop_steps / self.decay_steps)
         )
-        if "decay_rate" in kwargs:
-            self.decay_rate = kwargs["decay_rate"]
-        if "min_lr" in kwargs:
-            self.min_lr = kwargs["min_lr"]
-        else:
-            self.min_lr = 3e-10
+        if decay_rate is not None:
+            self.decay_rate = decay_rate
+        self.min_lr = stop_lr
 
     def value(self, step):
         """Get the learning rate at the given step."""
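The body of value() is collapsed in this diff. As a rough guide to how the pieces above fit together, here is a minimal sketch (an assumption, not the committed implementation) of evaluating such an exponential schedule: decay once every decay_steps, clamped at min_lr, which after this change is set to stop_lr.

import numpy as np

# Sketch only: the committed value() body is not shown in this diff.
def exp_lr_value(step, start_lr, decay_rate, decay_steps, min_lr):
    # Decay by decay_rate once every decay_steps, never going below min_lr.
    lr = start_lr * np.power(decay_rate, step // decay_steps)
    return max(lr, min_lr)

# e.g. with start_lr=1e-3, decay_rate=0.98, decay_steps=5000, min_lr=1e-5:
# exp_lr_value(100_000, 1e-3, 0.98, 5000, 1e-5) = 1e-3 * 0.98**20 ≈ 6.7e-4,
# still above the floor.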
19 changes: 18 additions & 1 deletion deepmd/utils/argcheck.py
@@ -1517,15 +1517,32 @@ def linear_ener_model_args() -> Argument:
 # --- Learning rate configurations: --- #
 def learning_rate_exp():
     doc_start_lr = "The learning rate at the start of the training."
-    doc_stop_lr = "The desired learning rate at the end of the training."
+    doc_stop_lr = (
+        "The desired learning rate at the end of the training. "
+        f"When decay_rate {doc_only_pt_supported}is explicitly set, "
+        "this value will serve as the minimum learning rate during training. "
+        "In other words, if the learning rate decays below stop_lr, stop_lr will be applied instead."
+    )
     doc_decay_steps = (
         "The learning rate is decaying every this number of training steps."
     )
+    doc_decay_rate = (
+        "The decay rate for the learning rate. "
+        "If this is provided, it will be used directly as the decay rate for learning rate "
+        "instead of calculating it through interpolation between start_lr and stop_lr."
+    )
 
     args = [
         Argument("start_lr", float, optional=True, default=1e-3, doc=doc_start_lr),
         Argument("stop_lr", float, optional=True, default=1e-8, doc=doc_stop_lr),
         Argument("decay_steps", int, optional=True, default=5000, doc=doc_decay_steps),
+        Argument(
+            "decay_rate",
+            float,
+            optional=True,
+            default=None,
+            doc=doc_only_pt_supported + doc_decay_rate,
+        ),
     ]
     return args
 
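With the new argument registered, a learning_rate section can pin the decay rate directly instead of having it derived from stop_lr and stop_steps; stop_lr then only acts as a floor. A hypothetical fragment and how it maps onto the constructor above (the values, and the idea of unpacking a config dict this way, are illustrative assumptions, not part of this commit):

from deepmd.pt.utils.learning_rate import LearningRateExp

# Hypothetical values; keys follow the Arguments defined above.
lr_cfg = {
    "start_lr": 1e-3,
    "stop_lr": 1e-5,     # with decay_rate set, this becomes the minimum lr
    "decay_steps": 5000,
    "decay_rate": 0.98,  # used as-is, not interpolated from start_lr/stop_lr
}

# stop_steps would come from the total number of training steps.
scheduler = LearningRateExp(stop_steps=1_000_000, **lr_cfg)
print(scheduler.value(100_000))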
47 changes: 47 additions & 0 deletions source/tests/pt/test_lr.py
@@ -27,6 +27,7 @@ def test_consistency(self):
         self.decay_step = decay_step
         self.stop_step = stop_step
         self.judge_it()
+        self.decay_rate_pt()
 
     def judge_it(self):
         base_lr = learning_rate.LearningRateExp(
@@ -54,6 +55,52 @@ def judge_it(self):
         self.assertTrue(np.allclose(base_vals, my_vals))
         tf.reset_default_graph()
 
+    def decay_rate_pt(self):
+        my_lr = LearningRateExp(
+            self.start_lr, self.stop_lr, self.decay_step, self.stop_step
+        )
+
+        default_ds = 100 if self.stop_step // 10 > 100 else self.stop_step // 100 + 1
+        if self.decay_step >= self.stop_step:
+            self.decay_step = default_ds
+        decay_rate = np.exp(
+            np.log(self.stop_lr / self.start_lr) / (self.stop_step / self.decay_step)
+        )
+        my_lr_decay = LearningRateExp(
+            self.start_lr,
+            1e-10,
+            self.decay_step,
+            self.stop_step,
+            decay_rate=decay_rate,
+        )
+        min_lr = 1e-5
+        my_lr_decay_trunc = LearningRateExp(
+            self.start_lr,
+            min_lr,
+            self.decay_step,
+            self.stop_step,
+            decay_rate=decay_rate,
+        )
+        my_vals = [
+            my_lr.value(step_id)
+            for step_id in range(self.stop_step)
+            if step_id % self.decay_step != 0
+        ]
+        my_vals_decay = [
+            my_lr_decay.value(step_id)
+            for step_id in range(self.stop_step)
+            if step_id % self.decay_step != 0
+        ]
+        my_vals_decay_trunc = [
+            my_lr_decay_trunc.value(step_id)
+            for step_id in range(self.stop_step)
+            if step_id % self.decay_step != 0
+        ]
+        self.assertTrue(np.allclose(my_vals_decay, my_vals))
+        self.assertTrue(
+            np.allclose(my_vals_decay_trunc, np.clip(my_vals, a_min=min_lr, a_max=None))
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
