
Commit fb9dc83
Merge branch 'master' into 656-using-patch-miner-without-label-header-results-in-failure
sarthakpati authored May 22, 2023
2 parents f2787c6 + b6dc393 commit fb9dc83
Showing 2 changed files with 23 additions and 0 deletions.
2 changes: 2 additions & 0 deletions GANDLF/optimizers/__init__.py
@@ -9,6 +9,7 @@
    adadelta,
    adagrad,
    rmsprop,
    radam,
)

global_optimizer_dict = {
@@ -22,6 +23,7 @@
"adadelta": adadelta,
"adagrad": adagrad,
"rmsprop": rmsprop,
"radam": radam,
}
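
With `radam` registered in `global_optimizer_dict`, the optimizer can be selected by name from a parameters dict. Below is a minimal sketch of how the registry might be exercised; the config layout (in particular the "type" key) and the stand-in model are assumptions for illustration, not the exact GANDLF call path:

import torch

from GANDLF.optimizers import global_optimizer_dict

model = torch.nn.Linear(10, 2)  # stand-in model, for illustration only
params = {
    "model_parameters": model.parameters(),
    "learning_rate": 1e-3,
    "optimizer": {"type": "radam", "weight_decay": 1e-5},  # assumed config layout
}

# look up the wrapper registered under the configured name and build the optimizer
optimizer = global_optimizer_dict[params["optimizer"]["type"]](params)
print(type(optimizer))  # expected: torch.optim.RAdam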


21 changes: 21 additions & 0 deletions GANDLF/optimizers/wrap_torch.py
@@ -9,6 +9,7 @@
    Adadelta,
    Adagrad,
    RMSprop,
    RAdam
)


@@ -223,3 +224,23 @@ def rmsprop(parameters):
        weight_decay=parameters["optimizer"].get("weight_decay", 0),
    )

def radam(parameters):
    """
    Creates a RAdam optimizer from the PyTorch `torch.optim` module using the input parameters.

    Args:
        parameters (dict): A dictionary containing the input parameters for the optimizer.

    Returns:
        optimizer (torch.optim.RAdam): A RAdam optimizer.
    """
    # Create the optimizer using the input parameters
    return RAdam(
        parameters["model_parameters"],
        lr=parameters.get("learning_rate"),
        betas=parameters["optimizer"].get("betas", (0.9, 0.999)),
        eps=parameters["optimizer"].get("eps", 1e-8),
        weight_decay=parameters["optimizer"].get("weight_decay", 0),
        foreach=parameters["optimizer"].get("foreach", None),
    )
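
As a quick sanity check, the new wrapper can also be called directly with a parameters dict shaped like the one its docstring describes. A minimal, self-contained sketch; the toy model and the concrete values are placeholders, only the keys read by `radam` above are taken from the change:

import torch

from GANDLF.optimizers.wrap_torch import radam

model = torch.nn.Linear(4, 1)  # toy model, placeholder for a real network
parameters = {
    "model_parameters": model.parameters(),
    "learning_rate": 1e-3,
    "optimizer": {"betas": (0.9, 0.999), "eps": 1e-8, "weight_decay": 0},
}

optimizer = radam(parameters)
assert isinstance(optimizer, torch.optim.RAdam)

# one dummy step to confirm the optimizer updates the toy model
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()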
