【Hackathon 5th No.22】Add CosineAnnealingWarmRestarts API to Paddle (PaddlePaddle#57744)

* add CosineAnnealingWarmRestarts API

* modify test design

* modify api design

* format code

* add test example

* fix documentation
Patrick-Star125 committed Nov 3, 2023
1 parent 1a4fd36 commit dfdab28
Showing 3 changed files with 378 additions and 0 deletions.
163 changes: 163 additions & 0 deletions python/paddle/optimizer/lr.py
@@ -46,6 +46,7 @@
    'OneCycleLR',
    'CyclicLR',
    'LinearLR',
    'CosineAnnealingWarmRestarts',
]


@@ -2347,6 +2348,168 @@ def get_lr(self):
        return self.last_lr * factor


class CosineAnnealingWarmRestarts(LRScheduler):
    r"""
    Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
    is the number of epochs since the last restart and :math:`T_{i}` is the number
    of epochs between two warm restarts in SGDR:

    .. math::

        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)

    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
    When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.

    It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts <https://arxiv.org/abs/1608.03983>`_.

    Args:
        learning_rate (float): Initial learning rate.
        T_0 (int): Number of iterations for the first restart.
        T_mult (int, optional): A factor that increases :math:`T_{i}` after a restart. Default: 1.
        eta_min (float, optional): Minimum learning rate. Default: 0.
        last_epoch (int, optional): The index of the last epoch. Default: -1, which means the initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Returns:
        ``CosineAnnealingWarmRestarts`` instance to schedule learning rate.

    Examples:
        .. code-block:: python
            :name: code-example1

            >>> import paddle
            >>> import numpy as np

            >>> # train on default dynamic graph mode
            >>> linear = paddle.nn.Linear(10, 10)
            >>> scheduler = paddle.optimizer.lr.CosineAnnealingWarmRestarts(learning_rate=0.5, T_0=1, T_mult=2, verbose=True)
            >>> adam = paddle.optimizer.Adam(learning_rate=scheduler, parameters=linear.parameters())
            >>> for epoch in range(10):
            ...     for batch_id in range(10):
            ...         x = paddle.uniform([10, 10])
            ...         out = linear(x)
            ...         loss = paddle.mean(out)
            ...         loss.backward()
            ...         adam.step()
            ...         adam.clear_grad()
            ...         scheduler.step(epoch)    # You should update learning rate each step

        .. code-block:: python
            :name: code-example2

            >>> import paddle
            >>> import numpy as np

            >>> paddle.enable_static()
            >>> main_prog = paddle.static.Program()
            >>> start_prog = paddle.static.Program()
            >>> with paddle.static.program_guard(main_prog, start_prog):
            ...     x = paddle.static.data(name='x', shape=[None, 4, 5])
            ...     y = paddle.static.data(name='y', shape=[None, 4, 5])
            ...     z = paddle.static.nn.fc(x, 100)
            ...     loss = paddle.mean(z)
            ...     scheduler = paddle.optimizer.lr.CosineAnnealingWarmRestarts(learning_rate=0.5, T_0=1, T_mult=2, verbose=True)
            ...     sgd = paddle.optimizer.SGD(learning_rate=scheduler)
            ...     sgd.minimize(loss)

            >>> exe = paddle.static.Executor()
            >>> exe.run(start_prog)
            >>> for epoch in range(10):
            ...     for batch_id in range(10):
            ...         out = exe.run(
            ...             main_prog,
            ...             feed={
            ...                 'x': np.random.randn(3, 4, 5).astype('float32'),
            ...                 'y': np.random.randn(3, 4, 5).astype('float32')
            ...             },
            ...             fetch_list=loss.name)
            ...         scheduler.step(epoch)    # You should update learning rate each step
    """

    def __init__(
        self,
        learning_rate,
        T_0,
        T_mult=1,
        eta_min=0,
        last_epoch=-1,
        verbose=False,
    ):
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError(f"Expected positive integer T_0, but got {T_0}")
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}")
        self.T_0 = T_0
        self.T_i = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min
        self.T_cur = last_epoch
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        return (
            self.eta_min
            + (self.base_lr - self.eta_min)
            * (1 + math.cos(math.pi * self.T_cur / self.T_i))
            / 2
        )

    def step(self, epoch=None):
        """
        step should be called after `optimizer.step()`. It will update the learning rate in the optimizer.
        The new learning rate will take effect on the next epoch.

        Args:
            epoch (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1.

        Returns:
            None

        Examples:
            Please refer to the example of current LRScheduler.
        """
        if epoch is None and self.last_epoch < 0:
            epoch = 0

        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.T_cur = self.T_cur - self.T_i
                self.T_i = self.T_i * self.T_mult
        else:
            if epoch < 0:
                raise ValueError(
                    f"Expected non-negative epoch, but got {epoch}"
                )
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                else:
                    n = int(
                        math.log(
                            (epoch / self.T_0 * (self.T_mult - 1) + 1),
                            self.T_mult,
                        )
                    )
                    self.T_cur = epoch - self.T_0 * (self.T_mult**n - 1) / (
                        self.T_mult - 1
                    )
                    self.T_i = self.T_0 * self.T_mult**n
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
        self.last_epoch = math.floor(epoch)
        self.last_lr = self.get_lr()
        if self.verbose:
            print(
                'Epoch {}: {} set learning rate to {}.'.format(
                    self.last_epoch, self.__class__.__name__, self.last_lr
                )
            )


def autoincreased_step_counter(counter_name=None, begin=1, step=1):
"""
:api_attr: Static Graph
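To make the schedule concrete, here is an editor-added standalone sketch (not part of the commit) that traces the formula and restart rule from the docstring above for learning_rate=0.5, T_0=1, T_mult=2: within each cycle the rate decays along the cosine curve from the initial value toward eta_min, and at every restart it jumps back to the initial value while the cycle length is multiplied by T_mult.

    import math

    # Assumed values, matching the docstring example: eta_max is the initial
    # learning_rate (0.5), eta_min defaults to 0, T_0=1, T_mult=2.
    eta_max, eta_min = 0.5, 0.0
    T_0, T_mult = 1, 2

    T_i, T_cur = T_0, 0
    for epoch in range(8):
        lr = eta_min + (eta_max - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
        print(f"epoch {epoch}: T_i={T_i}, T_cur={T_cur}, lr={lr:.4f}")
        T_cur += 1
        if T_cur >= T_i:    # warm restart: reset T_cur and stretch the next cycle
            T_cur -= T_i
            T_i *= T_mult

The printed rates return to 0.5 at epochs 0, 1, 3, and 7, i.e. the cycle lengths grow as 1, 2, 4, 8 when T_mult=2.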
17 changes: 17 additions & 0 deletions test/dygraph_to_static/test_train_step.py
@@ -436,5 +436,22 @@ def setUp(self):
self.rtol = 1e-4


class TestTrainStepTinyModelCosineAnnealingWarmRestarts(TestTrainStepTinyModel):
    def setUp(self):
        self.input = paddle.randn([10000, 10])
        self.net_creator = TinyModel
        self.lr_creator = partial(
            paddle.optimizer.lr.CosineAnnealingWarmRestarts,
            learning_rate=0.5,
            T_0=1,
            T_mult=1,
        )
        self.optimizer_creator = paddle.optimizer.SGD
        self.loss_fn = loss_fn_tiny_model
        self.train_step_func = train_step_tiny_model
        self.steps = 3
        self.rtol = 1e-4


if __name__ == "__main__":
unittest.main()
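A second editor-added sketch (again not part of the commit, with arbitrarily chosen parameters T_0=2 and T_mult=3) cross-checks the closed-form branch of step(epoch): when an absolute epoch index is passed in, the log-based expression recovers the same (T_cur, T_i) pair as walking through the restarts one at a time.

    import math

    T_0, T_mult = 2, 3  # cycle lengths 2, 6, 18, ...; restarts at epochs 2, 8, 26

    def closed_form(epoch):
        # Mirrors the `epoch >= self.T_0`, `T_mult > 1` branch of step().
        n = int(math.log(epoch / T_0 * (T_mult - 1) + 1, T_mult))
        T_cur = epoch - T_0 * (T_mult**n - 1) / (T_mult - 1)
        T_i = T_0 * T_mult**n
        return T_cur, T_i

    def by_iteration(epoch):
        # Walk restart by restart, as repeated step(None) calls would.
        T_i, start = T_0, 0
        while start + T_i <= epoch:
            start += T_i
            T_i *= T_mult
        return epoch - start, T_i

    for epoch in range(3, 26):
        if epoch == 8:
            # Skip the exact restart point: the floating-point log in the
            # closed form can round just below the integer there.
            continue
        assert closed_form(epoch) == by_iteration(epoch), epoch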
