Commit
Merge pull request #111 from luigibonati/lr_scheduler
Add lr_scheduler options
luigibonati committed Dec 22, 2023
2 parents c6381fd + 28c2e1b commit fd7bce8
Showing 2 changed files with 87 additions and 5 deletions.
77 changes: 74 additions & 3 deletions docs/notebooks/tutorials/intro_3_loss_optim.ipynb
@@ -36,14 +36,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The optimizer used is returned by the function `configure_optimizers` which is called by the lightning trainer. The default optimizer is `Adam`. To change it, or to customize the optimizer's arguments, you can interact with the CV's members `optimizer_name` and `optimizer_kwargs`.\n",
"The optimizer used is returned by the function `configure_optimizers` which is called by the lightning trainer. The default optimizer is `Adam`. To change it, or to customize the optimizer's arguments, you can interact with the CV's members `optimizer_name` and `optimizer_kwargs`. \n",
"\n",
"For instance, this could be used to add an L2 regularization through the `weight_decay` argument."
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -59,9 +59,17 @@
},
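(The source of the cell above is collapsed in this view; a minimal sketch of the idea, assuming `RegressionCV` from `mlcolvar.cvs` as used later in this tutorial:)

# hypothetical sketch: swap the default Adam optimizer and add L2 regularization
from mlcolvar.cvs import RegressionCV

cv = RegressionCV(layers=[10, 5, 5, 1])
cv.optimizer_name = 'SGD'  # any optimizer name from torch.optim
cv.optimizer_kwargs = {'lr': 1e-3, 'weight_decay': 1e-4}  # weight_decay adds L2 regularization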
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/etrizio@iit.local/Bin/miniconda3/envs/mlcvs_test/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
@@ -87,6 +95,69 @@
"print(f'Arguments: {cv.optimizer_kwargs}')"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Options to the default `Adam` optimizer can also be passed using the `options` parameter of the CV model using the keyword `optimizer` in the dictionary. The provided options will be registered in `optimizer_kwargs`.\n",
"\n",
"For example we can set the `lr` and the `weight_decay`"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"optimizer_kwargs: {'lr': 0.002, 'weight_decay': 0.0001}\n"
]
}
],
"source": [
"# define optimizer options\n",
"options = {'optimizer' : {'lr' : 2e-3, 'weight_decay' : 1e-4} }\n",
"\n",
"# define example CV\n",
"cv = RegressionCV(layers=[10,5,5,1], options=options)\n",
"\n",
"print(f'optimizer_kwargs: {cv.optimizer_kwargs}')"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also associate to the optimizer a **learning rate scheduler**, which allows to modify the learning rate of the optimizer as the optimization proceeds to facilitate the training. For example, to reduce the learning rate as a function of the epochs.\n",
"\n",
"To do this we can easily use the schedulers implemented in `torch.optim.lr_scheduler`.\n",
"\n",
"This can also be passed using the `options` parameter of the CV model using the keyword `lr_scheduler` in the dictionary. \n",
"The scheduler object **must** be included under the key `scheduler`, the parameters of the chosen scheduler should be passed under the corresponding names."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import torch \n",
"# choose the scheduler\n",
"lr_scheduler = torch.optim.lr_scheduler.ExponentialLR # requires gamma as parameter\n",
"\n",
"# define scheduler options\n",
"options = {'lr_scheduler' : { 'scheduler' : lr_scheduler, 'gamma' : 0.9999} }\n",
"\n",
"# define example CV\n",
"cv = RegressionCV(layers=[10,5,5,1], options=options)"
]
},
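As with the optimizer options, the entries provided under `lr_scheduler` are registered on the model; a quick check (a sketch, assuming the `lr_scheduler_kwargs` attribute introduced in `mlcolvar/cvs/cv.py` below):

# sketch: the scheduler class and its parameters are stored in lr_scheduler_kwargs
print(cv.lr_scheduler_kwargs)
# expected (assumed): {'scheduler': <class 'torch.optim.lr_scheduler.ExponentialLR'>, 'gamma': 0.9999}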
{
"attachments": {},
"cell_type": "markdown",
15 changes: 13 additions & 2 deletions mlcolvar/cvs/cv.py
@@ -46,6 +46,7 @@ def __init__(
# OPTIM
self._optimizer_name = "Adam"
self.optimizer_kwargs = {}
self.lr_scheduler_kwargs = {}

# PRE/POST
self.preprocessing = preprocessing
@@ -85,9 +86,11 @@ def parse_options(self, options: dict = None):
if o not in self.BLOCKS:
if o == "optimizer":
self.optimizer_kwargs.update(options[o])
elif o == "lr_scheduler":
self.lr_scheduler_kwargs.update(options[o])
else:
raise ValueError(
f'The key {o} is not available in this class. The available keys are: {", ".join(self.BLOCKS)}, optimizer and lr_scheduler.'
)

return options
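For reference, a sketch of how both non-block keys can be combined in a single `options` dict (assuming `RegressionCV` as in the tutorial above):

import torch
from mlcolvar.cvs import RegressionCV

# parse_options routes each key to the corresponding kwargs dict
options = {
    'optimizer': {'lr': 2e-3, 'weight_decay': 1e-4},
    'lr_scheduler': {'scheduler': torch.optim.lr_scheduler.ExponentialLR, 'gamma': 0.9999},
}
cv = RegressionCV(layers=[10, 5, 5, 1], options=options)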
@@ -195,10 +198,18 @@ def configure_optimizers(self):
torch.optim
Torch optimizer
"""

optimizer = getattr(torch.optim, self._optimizer_name)(
self.parameters(), **self.optimizer_kwargs
)

if self.lr_scheduler_kwargs:
scheduler_cls = self.lr_scheduler_kwargs['scheduler']
scheduler_kwargs = {k: v for k, v in self.lr_scheduler_kwargs.items() if k != 'scheduler'}
# the scheduler must be built on top of the optimizer it modulates
lr_scheduler = scheduler_cls(optimizer, **scheduler_kwargs)
return [optimizer], [lr_scheduler]
else:
return optimizer
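When scheduler options are present, `configure_optimizers` returns the `([optimizers], [schedulers])` pair that PyTorch Lightning accepts. A hedged usage sketch, assuming `RegressionCV` as in the tutorial above:

import torch
from mlcolvar.cvs import RegressionCV

options = {'lr_scheduler': {'scheduler': torch.optim.lr_scheduler.ExponentialLR, 'gamma': 0.9999}}
cv = RegressionCV(layers=[10, 5, 5, 1], options=options)

optimizers, schedulers = cv.configure_optimizers()
schedulers[0].step()                # Lightning steps the scheduler once per epoch by default
print(schedulers[0].get_last_lr())  # learning rate after one decay step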

def __setattr__(self, key, value):
# PyTorch overrides __setattr__ to raise a TypeError when you try to assign
Expand Down
