
Monitoring convergence in gradient-based optimization #293

Open
greaa-aws opened this issue Oct 14, 2022 · 8 comments
Labels: enhancement (New feature or request)

@greaa-aws

Is your feature request related to a problem? Please describe.

Currently, model training via gradient-based optimization in Tribuo terminates after a fixed number of epochs. The main problem with using a maximum iteration count as the stopping criterion is that it bears no relation to the optimality of the current iterate. It is difficult to know a priori how many epochs will be sufficient for a given training problem, and there are costs to over- or under-estimating this number (especially to underestimating it).

Describe the solution you'd like

Ideally, for iterative gradient-based optimization we would be able to use problem-specific stopping criteria such as a threshold on relative reduction of the loss or the norm of the gradient. Typically these are accompanied by a (large) max-epoch cutoff to bound computation time and catch cases where the loss diverges. For stochastic algorithms we could also consider early stopping rules, for example based on the loss on a held-out validation set.
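As a concrete illustration of the kind of rule I mean, here is a hedged, self-contained sketch (not Tribuo code; the epoch loss is just a simulated decaying sequence, and all names are hypothetical):

// Hypothetical sketch: relative-loss-reduction stopping rule with a (large) max-epoch cap.
// None of these names are Tribuo API; the loss is simulated so the sketch runs standalone.
public final class RelativeLossStopSketch {
    public static void main(String[] args) {
        double tolerance = 1e-3;   // stop when the relative loss reduction falls below this
        int maxEpochs = 10_000;    // large safety cutoff to bound computation
        double previousLoss = Double.POSITIVE_INFINITY;
        for (int epoch = 0; epoch < maxEpochs; epoch++) {
            double currentLoss = runEpochAndComputeLoss(epoch); // placeholder for a real training epoch
            boolean converged = Double.isFinite(previousLoss)
                    && Math.abs(previousLoss - currentLoss) <= tolerance * Math.abs(previousLoss);
            if (converged) {
                System.out.println("Converged at epoch " + epoch + ", loss " + currentLoss);
                break;
            }
            previousLoss = currentLoss;
        }
    }

    // Stand-in for an epoch of SGD followed by a full-dataset loss evaluation.
    private static double runEpochAndComputeLoss(int epoch) {
        return 1.0 / (epoch + 1.0);
    }
}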

Are there any plans to implement zero- or first-order stopping criteria for optimizers extending AbstractSGDTrainer? Are there other workarounds for checking convergence of the optimizer in the case of linear and logistic regression?

Describe alternatives you've considered

An alternative to implementing new stopping criteria could be to (optionally) report some metric(s) relevant to the specific training problem after training is "completed" according to the max-epoch rule. These could include the norm of the gradient or a sequence of loss values at each epoch.

One alternative that does not work in general is to switch the optimization algorithm away from standard SGD. All of the optimizers implement some form of iterative, gradient-based optimization, so they all face the same problem of choosing an appropriate stopping criterion.

@greaa-aws greaa-aws added the enhancement New feature or request label Oct 14, 2022
@Craigacp
Member

Craigacp commented Oct 15, 2022

We're in the process of figuring out which features from the roadmap to put in v5, and we can certainly look at adding this. It's most likely going to be a breaking change, but that's ok for v5; there are a few other breaking changes under consideration (they'll be marked on the projects page under v5 once we've made the decisions). Those other breaking changes are required so we can work on adding online learning and on expanding SequenceExample into something more like a structured example we could use for (e.g.) learning to rank.

As a rough idea, I think we'd need two interfaces,

interface StoppingCriterionFactory<T extends Output<T>>
        extends Configurable, ProtoSerializable, Provenanceable<ConfiguredObjectProvenance> {
    StoppingCriterion configure(FeatureIDMap fmap, OutputIDInfo<T> outputInfo);
}

and

interface StoppingCriterion {
    boolean step(int stepCount, double lossValue, Parameters model);
    boolean epochCompleted(int epochCount, double lossValue, Parameters model);
}

Unfortunately I don't think we can get away with something simpler, as the stopping criterion needs to be stateful and Trainers are not, so we need to make a fresh one for each call to trainer.train(). I'm not sure where to fit the gradient norm in; passing in a single double seems like it might be very restrictive, especially if we expand the types of model that AbstractSGDTrainer can train. Do you have suggestions for how you'd like to monitor gradient norms?
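Roughly, the loop inside AbstractSGDTrainer would drive these something like the following (just a rough sketch; the batch type and loss computations are placeholders, not the real trainer internals):

// Rough sketch only - placeholder names, not the actual AbstractSGDTrainer code.
StoppingCriterion criterion = criterionFactory.configure(fmap, outputInfo); // fresh, stateful instance per train()
boolean stop = false;
for (int epoch = 0; epoch < maxEpochs && !stop; epoch++) {
    int stepCount = 0;
    for (Batch batch : minibatches) {                            // Batch is a placeholder type
        double batchLoss = applyGradientStep(batch, parameters); // placeholder: one SGD update
        stop = criterion.step(stepCount++, batchLoss, parameters);
        if (stop) {
            break;
        }
    }
    double epochLoss = computeTrainingLoss(parameters);          // placeholder: full-pass training loss
    stop = stop || criterion.epochCompleted(epoch, epochLoss, parameters);
}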

@greaa-aws
Author

Thanks for the update and context here. It would be great to see these features in v5!

I agree that passing in just a double for the gradient norm may be restrictive. While the gradient norm is technically sufficient to monitor convergence in smooth convex problems like linear and logistic regression, in practice it needs to be implemented with a hard stop (e.g. max epochs), and in other problem settings of possible relevance for AbstractSGDTrainer it may not be preferred at all.

As a rough breakdown I think problem-specific stopping criteria can be a function of:

  • The gradient at the current iterate.
  • The training loss at the current and past iterate(s).
  • The validation loss at the current and past iterate(s).

The general input to a problem-specific stopping criterion is then a loss, a model, and a dataset. In addition, for stopping criteria that depend on past (training or validation) losses, we would need to store and recall these values. I am still familiarizing myself with Tribuo’s architecture and defer to your judgment as to whether and how this might be accomplished. I'll continue to dig in here - meanwhile, thanks for the perspective and happy to discuss further.
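To make that concrete, a relative-loss-reduction rule against the draft interface above might look roughly like this (hedged sketch; only the epoch-level check does real work, and the interface itself is still just the draft):

// Hedged sketch of a stateful relative-loss-reduction criterion against the draft interface.
final class RelativeLossReductionCriterion implements StoppingCriterion {
    private final double tolerance;
    private double previousLoss = Double.POSITIVE_INFINITY;

    RelativeLossReductionCriterion(double tolerance) {
        this.tolerance = tolerance;
    }

    @Override
    public boolean step(int stepCount, double lossValue, Parameters model) {
        return false; // per-step minibatch losses are too noisy to stop on
    }

    @Override
    public boolean epochCompleted(int epochCount, double lossValue, Parameters model) {
        boolean converged = Double.isFinite(previousLoss)
                && Math.abs(previousLoss - lossValue) <= tolerance * Math.abs(previousLoss);
        previousLoss = lossValue; // store this epoch's loss for the next comparison
        return converged;
    }
}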

@Craigacp
Member

Yeah, so the training loss I think can go through the method call and be computed by AbstractSGDTrainer. Validation loss will probably best come from constructing the stopping criterion factory with a dataset; at configure time it would check that the dataset is appropriate for the training data (e.g. that it's the right prediction type and the feature & output dimensions overlap). Both are stateful, as they monitor changes to the loss over time, and we can do epoch- or step-based versions (not sure yet whether those should be different classes, that's something to think about). The complexity for validation is figuring out how to make it interact cleanly with the config system, which doesn't currently know about Dataset, only DataSource. That might not be a problem, a DataSource might be sufficient, but it needs thinking through (particularly around transformations, as we're also considering adding things like PCA, and a DataSource isn't transformed).
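As a very rough shape (not a proposal for the final API; the config annotations, provenance, and protobuf serialization are omitted, and the validation loss computation is left as a placeholder hook), it might look like:

// Very rough sketch - config/provenance/serialization omitted, and validationLoss(...)
// is a placeholder for however the validation loss ends up being computed.
final class EarlyStoppingFactory<T extends Output<T>> implements StoppingCriterionFactory<T> {
    private final DataSource<T> validationSource;
    private final int patience; // epochs without validation improvement before stopping

    EarlyStoppingFactory(DataSource<T> validationSource, int patience) {
        this.validationSource = validationSource;
        this.patience = patience;
    }

    @Override
    public StoppingCriterion configure(FeatureIDMap fmap, OutputIDInfo<T> outputInfo) {
        // Here it would check that the validation data is the right prediction type and that
        // its feature & output dimensions overlap with the training data, then return a
        // fresh, stateful criterion for this train() call.
        return new StoppingCriterion() {
            private double bestLoss = Double.POSITIVE_INFINITY;
            private int sinceImprovement = 0;

            @Override
            public boolean step(int stepCount, double lossValue, Parameters model) {
                return false; // validation is only checked at epoch boundaries
            }

            @Override
            public boolean epochCompleted(int epochCount, double lossValue, Parameters model) {
                double valLoss = validationLoss(model, fmap, outputInfo); // placeholder hook
                if (valLoss < bestLoss) {
                    bestLoss = valLoss;
                    sinceImprovement = 0;
                } else {
                    sinceImprovement++;
                }
                return sinceImprovement >= patience;
            }
        };
    }

    // Placeholder: how this gets computed without a gradient is the SGDObjective question below.
    private double validationLoss(Parameters model, FeatureIDMap fmap, OutputIDInfo<T> outputInfo) {
        throw new UnsupportedOperationException("sketch only");
    }
}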

Max iteration caps & epoch caps are straightforward and stateless, so those are fine; the main difficulty there is making sure that the provenance/configuration knows how to evolve v4-style information into the v5 style, as I don't want to break users who want to migrate upwards.

Losses come from SGDObjectives at the moment, but in the case of validation we don't want to compute the gradient, so that needs splitting apart by adding an extra method. I need to think about whether we want the validation loss function to be different from the training loss function, or whether the configure method should accept the training loss function. Though if the stopping criterion owns the computation of the training loss as well, then it can compute whatever function of the gradient norm it wants, as it has the parameters, gradient and loss at that point.

Anyway, I think we can do this, but there's a bunch of design space to think about.

@greaa-aws
Author

It seems like early stopping (validation loss monitoring) requires more extensive design consideration, and understandably so given the overall data flow. The validation data just has to have the same features as the training data; it does seem like this interacts somewhat with planned work on transformations, especially if DataSource isn’t aware of transformations made prior to training. Perhaps early stopping could be split out from the more straightforward gradient- or training loss-based stopping criteria?

In terms of epoch- vs. step-based checks: in the implementations I’ve seen, the convergence criterion is usually only checked once per epoch, if not less frequently (e.g. once every check_cvg epochs, where a default value for check_cvg might be 5 or 10). Multiple checks per epoch are usually considered to add too much computational overhead, especially as each check requires a computation over the entire dataset and not just the current datapoint or minibatch.
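For illustration, the periodic check could even be a thin wrapper around any other criterion (a hedged sketch against the draft interface, with checkCvg playing the role of check_cvg above):

// Hedged sketch: wrap another criterion so it is only evaluated every checkCvg epochs.
final class PeriodicCheckCriterion implements StoppingCriterion {
    private final StoppingCriterion inner;
    private final int checkCvg; // e.g. 5 or 10

    PeriodicCheckCriterion(StoppingCriterion inner, int checkCvg) {
        this.inner = inner;
        this.checkCvg = checkCvg;
    }

    @Override
    public boolean step(int stepCount, double lossValue, Parameters model) {
        return false; // skip per-step checks entirely to avoid full-dataset passes each minibatch
    }

    @Override
    public boolean epochCompleted(int epochCount, double lossValue, Parameters model) {
        if ((epochCount + 1) % checkCvg != 0) { // epochCount assumed 0-based
            return false; // only pay for the convergence check every checkCvg epochs
        }
        return inner.epochCompleted(epochCount, lossValue, model);
    }
}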

Regarding the loss function for validation vs training, I would think it’s fine to assume these are the same. The idea of the validation loss is to keep an unbiased estimate of the population risk, but this only makes sense if the loss is the same as what is being minimized on the training data.

Thanks again for the follow-up here and excited to hear that you think this sounds doable.

@handshape

Adding my voice to the desire for better observability, but for a decidedly more mundane reason: I need to be able to graph convergence for my human users that want to eyeball over/underfit.

@Craigacp
Member

Adding my voice to the desire for better observability, but for a decidedly more mundane reason: I need to be able to graph convergence for my human users that want to eyeball over/underfit.

For that use case do you want to save out intermediate checkpoints?

@handshape

@Craigacp - Intermediate checkpoints would be complementary (but not core to my specific use case), and having control over which conditions would trigger such a save would be a big boon. The intermediate model with the best performance on the evaluation set is important to preserve - especially in overfit cases.

@Craigacp
Member

@Craigacp - Intermediate checkpoints would be complementary (but not core to my specific use case), and having control over which conditions would trigger such a save would be a big boon. The intermediate model with the best performance on the evaluation set is important to preserve - especially in overfit cases.

Ok. Intermediate checkpointing is something we could add as part of adding online learning to the gradient-based models, because a checkpoint is equivalent to a model which can continue to be trained. That's also planned for v5, and checkpointing is easier to implement than the full online learning problem we're considering.
