Support "MultiStepLR" and "StepLR" learning rate schedulers.
muammar committed Dec 8, 2019
1 parent fd5f55b commit 0833137
Showing 2 changed files with 28 additions and 0 deletions.
ml4chem/optim/handler.py (26 additions, 0 deletions)
@@ -147,10 +147,36 @@ def get_lr_scheduler(optimizer, lr_scheduler):
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **kwargs)
        name = "ReduceLROnPlateau"

    elif scheduler_name == "multisteplr":
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, **kwargs)
        name = "MultiStepLR"

    elif scheduler_name == "steplr":
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **kwargs)
        name = "StepLR"

    logger.info("Learning Rate Scheduler")
    logger.info("-----------------------")
    logger.info(" - Name: {}.".format(name))
    logger.info(" - Args: {}.".format(kwargs))
    logger.info("")

    return scheduler


def get_lr(optimizer):
    """Get current learning rate

    Parameters
    ----------
    optimizer : obj
        An optimizer object.

    Returns
    -------
    lr
        Current learning rate.
    """
    for param_group in optimizer.param_groups:
        return param_group['lr']
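
For reference, a minimal sketch of what the two newly supported schedulers do in plain PyTorch. The model, optimizer, milestones, step_size, and gamma values below are illustrative placeholders, not values taken from this repository; only the torch.optim.lr_scheduler calls match the ones added above.

import torch

# Illustrative model and optimizer; not part of the commit.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# StepLR multiplies the learning rate by gamma every step_size epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# MultiStepLR instead decays at the specific epochs listed in milestones, e.g.:
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)

for epoch in range(100):
    optimizer.step()   # the actual training step would happen here
    scheduler.step()   # advance the schedule once per epoch

The get_lr helper above simply returns the learning rate stored in the optimizer's first parameter group.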
ml4chem/potentials.py (2 additions, 0 deletions)
@@ -216,6 +216,7 @@ def train(
        lossfxn=None,
        regularization=0.0,
        batch_size=None,
        **kwargs
    ):
        """Method to train models
@@ -315,6 +316,7 @@ def train(
            lossfxn=lossfxn,
            device=device,
            batch_size=batch_size,
            **kwargs
        )

        self.save(self.model, features=self.features, path=self.path, label=self.label)
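
The new **kwargs in both the signature and the trainer call mean that any extra keyword argument passed to train() is forwarded to the underlying trainer untouched, which is how a scheduler choice such as MultiStepLR or StepLR can reach get_lr_scheduler. Below is a minimal sketch of that forwarding pattern with hypothetical function names; the (name, kwargs) tuple format for lr_scheduler is an assumption based on the scheduler_name/kwargs unpacking in handler.py.

def run_trainer(epochs=10, lr_scheduler=None):
    # Stand-in for the real trainer, which would hand lr_scheduler to
    # get_lr_scheduler(optimizer, lr_scheduler).
    print(epochs, lr_scheduler)


def train(epochs=10, **kwargs):
    # As in Potentials.train() after this commit, unrecognized keywords are
    # forwarded rather than rejected.
    run_trainer(epochs=epochs, **kwargs)


train(epochs=100, lr_scheduler=("MultiStepLR", {"milestones": [30, 80], "gamma": 0.1}))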
