Allow variable stepping APIs for different optimizers #39

Merged: 2 commits, Oct 23, 2023
hippynn/experiment/routines.py: 15 changes (5 additions & 10 deletions)
@@ -20,6 +20,7 @@
 from .device import set_devices
 from .. import tools
 from .assembly import TrainingModules
+from .step_functions import get_step_function

 from .. import custom_kernels

@@ -419,12 +420,13 @@ def training_loop(

     epoch = metric_tracker.current_epoch
     device = evaluator.model_device
+    step_function = get_step_function(controller.optimizer)
+    optimizer = controller.optimizer

     continue_training = True # Assume that nobody ran this function without wanting at least 1 epoch.

     while continue_training:

-        optimizer = controller.optimizer
         qprint("_" * 50)
         qprint("Epoch {}:".format(epoch))
         tools.print_lr(optimizer)
@@ -442,20 +444,13 @@
             batch_targets = batch[-n_targets:]
             batch_targets = [x.requires_grad_(False) for x in batch_targets]

-            optimizer.zero_grad(set_to_none=True)
-            batch_model_outputs = model(*batch_inputs)
-
-            # The extra .mean call here deals with an edge case for multi-GPU DataParallel with scalar outputs
-            batch_train_loss = loss(*batch_model_outputs, *batch_targets)[0].mean()
-
-            batch_train_loss.backward()
-            optimizer.step()
+            batch_model_outputs = step_function(optimizer, model, loss, batch_inputs, batch_targets)

             if batch_callbacks:
                 for cb in batch_callbacks:
                     cb(batch_inputs, batch_model_outputs, batch_targets)
             # Allow garbage collection of computed values.
-            del batch_model_outputs, batch_train_loss
+            del batch_model_outputs

         elapsed_epoch_run_time = timeit.default_timer() - epoch_run_time
         qprint("Training time: ", round(elapsed_epoch_run_time, 2), "s")
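With this change, the batch loop no longer hard-codes the zero_grad/forward/backward/step sequence; it calls whichever step function `get_step_function` selected for the optimizer before the epoch loop started. The following is a minimal sketch, not part of the diff, of that dispatch in isolation. The toy model and loss are hypothetical and only mimic hippynn's tuple-returning conventions so that the `loss(*outputs, *targets)[0].mean()` pattern used by the step functions works.

import torch
from hippynn.experiment.step_functions import get_step_function

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 1)

    def forward(self, x):
        # Mimic hippynn: the model returns a tuple of output tensors.
        return (self.linear(x),)

def toy_loss(prediction, target):
    # Mimic hippynn: the loss returns an indexable whose first entry is the training loss.
    return (torch.nn.functional.mse_loss(prediction, target),)

model = ToyModel()
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)

# Chosen once, outside the epoch loop; an LBFGS optimizer resolves to ClosureStep().
step_function = get_step_function(optimizer)

batch_inputs = (torch.randn(8, 4),)
batch_targets = (torch.randn(8, 1),)

# Inside the batch loop, a single call covers zero_grad, forward, backward, and step,
# regardless of which stepping API the optimizer exposes.
batch_model_outputs = step_function(optimizer, model, toy_loss, batch_inputs, batch_targets)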
hippynn/experiment/step_functions.py: 90 additions & 0 deletions (new file)
@@ -0,0 +1,90 @@
"""
This file implements various stepping protocols used by different optimizer APIs.

In particular:
- The "standard" step function which only requires that backwards has been called.
- The "closure" step function for when line search is required (currently only active on LBFGS)
- The "two step" style of Sharpness Aware Minimization algorithms

The main function exposed here is `get_step_function(optimizer) -> callable`.

The step functions are provided as classes whose behavior is implemented in staticmethods.
This is to allow for extension, for example to stepping schemes that require additional
state, or to specifying the step function explicitly within the controller.
"""
from torch.optim import Optimizer, LBFGS


def standard_step_fn(optimizer, model, loss, batch_inputs, batch_targets):
optimizer.zero_grad(set_to_none=True)
batch_model_outputs = model(*batch_inputs)

# The extra .mean call here deals with an edge case for multi-GPU DataParallel with scalar outputs
batch_train_loss = loss(*batch_model_outputs, *batch_targets)[0].mean()

batch_train_loss.backward()
optimizer.step()
return batch_model_outputs


def twostep_step_fn(optimizer, model, loss, batch_inputs, batch_targets):
# Step function for SAM algorithm.
optimizer.zero_grad(set_to_none=True)

batch_model_outputs = model(*batch_inputs)
batch_train_loss = loss(*batch_model_outputs, *batch_targets)[0].mean()
batch_train_loss.backward()
optimizer.first_step(zero_grad=True)

batch_model_outputs_2 = model(*batch_inputs)
loss(*batch_model_outputs_2, *batch_targets)[0].mean().backward()
optimizer.second_step(zero_grad=True)
return batch_model_outputs


def closure_step_fn(optimizer, model, loss, batch_inputs, batch_targets):
return_outputs = None

def closure():
nonlocal return_outputs
optimizer.zero_grad(set_to_none=True)
batch_model_outputs = model(*batch_inputs)
if return_outputs is None:
return_outputs = batch_model_outputs
batch_train_loss = loss(*batch_model_outputs, *batch_targets)[0].mean()
batch_train_loss.backward()
return batch_train_loss

optimizer.step(closure)
return return_outputs


# Note: The staticmethod versions here could be rewritten using class parameters
# and __init_subclass__, but would they always remain staticmethods?
class StepFn:
step = NotImplemented

def __call__(self, *args, **kwargs):
return self.step(*args, **kwargs)


class StandardStep(StepFn):
step = staticmethod(standard_step_fn)


class TwoStep(StepFn):
step = staticmethod(twostep_step_fn)


class ClosureStep(StepFn):
step = staticmethod(closure_step_fn)


def get_step_function(optimizer: Optimizer) -> callable:
if type(optimizer).__name__ == "SAM":
return TwoStep()
if isinstance(optimizer, (LBFGS,)):
return ClosureStep()
else:
return StandardStep()
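The module docstring above leaves room for additional stepping schemes. As a hypothetical illustration, not part of this PR, a gradient-clipping step could follow the same pattern of a module-level function wrapped by a StepFn subclass; the function name, class name, and clipping threshold below are assumptions.

import torch
from hippynn.experiment.step_functions import StepFn

def clipped_step_fn(optimizer, model, loss, batch_inputs, batch_targets):
    # Same contract as standard_step_fn, with gradient clipping before the step.
    optimizer.zero_grad(set_to_none=True)
    batch_model_outputs = model(*batch_inputs)
    batch_train_loss = loss(*batch_model_outputs, *batch_targets)[0].mean()
    batch_train_loss.backward()
    # Clip every parameter the optimizer owns before taking the step.
    parameters = [p for group in optimizer.param_groups for p in group["params"]]
    torch.nn.utils.clip_grad_norm_(parameters, max_norm=5.0)
    optimizer.step()
    return batch_model_outputs

class ClippedStep(StepFn):
    step = staticmethod(clipped_step_fn)

Wiring such a class in would still mean either extending get_step_function or, as the docstring anticipates, letting the controller specify its step function explicitly.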