Fix callback return values and clarify docs #383

Merged (6 commits) on Nov 15, 2023
6 changes: 3 additions & 3 deletions COPYRIGHT.txt
@@ -7,10 +7,10 @@ Source:
 
 Files: *
 Copyright:
-Copyright 2008-2018, Ryan Curtin <ryan@ratml.org>
+Copyright 2008-2023, Ryan Curtin <ryan@ratml.org>
 Copyright 2008-2012, Dongryeol Lee <dongryel@cc.gatech.edu>
 Copyright 2010-2012, James Cline <james.cline@gatech.edu>
-Copyright 2013-2018, Marcus Edel <marcus.edel@fu-berlin.de>
+Copyright 2013-2023, Marcus Edel <marcus.edel@fu-berlin.de>
 Copyright 2013-2018, Sumedh Ghaisas <sumedhghaisas@gmail.com>
 Copyright 2013, Mudit Raj Gupta <mudit.raaj.gupta@gmail.com>
 Copyright 2014, Ryan Birmingham <birm@gatech.edu>
@@ -33,7 +33,7 @@ Copyright:
 Copyright 2018, B Kartheek Reddy <bkartheekreddy@gmail.com>
 Copyright 2018, Moksh Jain <mokshjn00@gmail.com>
 Copyright 2018, Shikhar Jaiswal <jaiswalshikhar87@gmail.com>
-Copyright 2018, Conrad Sanderson
+Copyright 2018-2023, Conrad Sanderson
 Copyright 2018, Dan Timson
 Copyright 2019, Rahul Ganesh Prabhu
 Copyright 2019, Roberto Hueso <robertohueso96@gmail.com>
3 changes: 3 additions & 0 deletions HISTORY.md
@@ -1,5 +1,8 @@
 ### ensmallen ?.??.?: "???"
 ###### ????-??-??
 
+* Clarify return values for different callback types
+  ([#383](https://github.com/mlpack/ensmallen/pull/383)).
+
 * Fix return types of callbacks
   ([#382](https://github.com/mlpack/ensmallen/pull/382)).
51 changes: 34 additions & 17 deletions doc/callbacks.md
@@ -424,18 +424,18 @@ std::cout << "The optimized model found by AdaDelta has the "
 
 Callbacks are called at several states during the optimization process:
 
 * At the beginning and end of the optimization process.
-* After any call to `Evaluate()` and `EvaluateConstraint`.
-* After any call to `Gradient()` and `GradientConstraint`.
+* After any call to `Evaluate()` and `EvaluateConstraint()`.
+* After any call to `Gradient()` and `GradientConstraint()`.
 * At the start and end of an epoch.
 
-Each callback provides optimization relevant information that can be accessed or
+Each callback provides optimization-relevant information that can be accessed or
 modified.
 
 ### BeginOptimization
 
 Called at the beginning of the optimization process.
 
-* `BeginOptimization(`_`optimizer, function, coordinates`_`)`
+* `void BeginOptimization(`_`optimizer, function, coordinates`_`)`
 
 #### Attributes
 
@@ -449,7 +449,7 @@ Called at the beginning of the optimization process.
 
 Called at the end of the optimization process.
 
-* `EndOptimization(`_`optimizer, function, coordinates`_`)`
+* `void EndOptimization(`_`optimizer, function, coordinates`_`)`
 
 #### Attributes
 
@@ -463,7 +463,9 @@ Called at the end of the optimization process.
 
 Called after any call to `Evaluate()`.
 
-* `Evaluate(`_`optimizer, function, coordinates, objective`_`)`
+* `bool Evaluate(`_`optimizer, function, coordinates, objective`_`)`
+
+If the callback returns `true`, the optimization will be terminated.
 
 #### Attributes
 
@@ -476,9 +478,11 @@ Called after any call to `Evaluate()`.
 
 ### EvaluateConstraint
 
 Called after any call to `EvaluateConstraint()`.
 
-* `EvaluateConstraint(`_`optimizer, function, coordinates, constraint, constraintValue`_`)`
+* `bool EvaluateConstraint(`_`optimizer, function, coordinates, constraint, constraintValue`_`)`
+
+If the callback returns `true`, the optimization will be terminated.
 
 #### Attributes
 
@@ -492,9 +496,11 @@ Called after any call to `Evaluate()`.
 
 ### Gradient
 
 Called after any call to `Gradient()`.
 
-* `Gradient(`_`optimizer, function, coordinates, gradient`_`)`
+* `bool Gradient(`_`optimizer, function, coordinates, gradient`_`)`
+
+If the callback returns `true`, the optimization will be terminated.
 
 #### Attributes
 
@@ -507,9 +513,11 @@ Called after any call to `Evaluate()`.
 
 ### GradientConstraint
 
 Called after any call to `GradientConstraint()`.
 
-* `GradientConstraint(`_`optimizer, function, coordinates, constraint, gradient`_`)`
+* `bool GradientConstraint(`_`optimizer, function, coordinates, constraint, gradient`_`)`
+
+If the callback returns `true`, the optimization will be terminated.
 
 #### Attributes
 
@@ -526,7 +534,9 @@ Called after any call to `Evaluate()`.
 
 Called at the beginning of a pass over the data. The objective may be exact or
 an estimate depending on `exactObjective` value.
 
-* `BeginEpoch(`_`optimizer, function, coordinates, epoch, objective`_`)`
+* `bool BeginEpoch(`_`optimizer, function, coordinates, epoch, objective`_`)`
+
+If the callback returns `true`, the optimization will be terminated.
 
 #### Attributes
 
@@ -543,7 +553,9 @@ an estimate depending on `exactObjective` value.
 
 Called at the end of a pass over the data. The objective may be exact or
 an estimate depending on `exactObjective` value.
 
-* `EndEpoch(`_`optimizer, function, coordinates, epoch, objective`_`)`
+* `bool EndEpoch(`_`optimizer, function, coordinates, epoch, objective`_`)`
+
+If the callback returns `true`, the optimization will be terminated.
 
 #### Attributes
 
@@ -560,7 +572,9 @@ an estimate depending on `exactObjective` value.
 
 Called after the evolution of a single generation. Intended specifically for
 MultiObjective Optimizers.
 
-* `GenerationalStepTaken(`_`optimizer, function, coordinates, objectives, frontIndices`_`)`
+* `bool GenerationalStepTaken(`_`optimizer, function, coordinates, objectives, frontIndices`_`)`
+
+If the callback returns `true`, the optimization will be terminated.
 
 #### Attributes
 
@@ -615,7 +629,7 @@ class ExponentialDecay
   // Callback function called at the end of a pass over the data. We are only
   // interested in the current epoch and the optimizer, we ignore the rest.
   template<typename OptimizerType, typename FunctionType, typename MatType>
-  void EndEpoch(OptimizerType& optimizer,
+  bool EndEpoch(OptimizerType& optimizer,
                 FunctionType& /* function */,
                 const MatType& /* coordinates */,
                 const size_t epoch,
@@ -624,6 +638,9 @@ class ExponentialDecay
     // Update the learning rate.
     optimizer.StepSize() = learningRate * (1.0 - std::pow(decay,
         (double) epoch));
+
+    // Do not terminate the optimization.
+    return false;
   }
 
   double learningRate;
@@ -695,7 +712,7 @@ class EarlyStop
   // the current objective. We are only interested in the objective and ignore
   // the rest.
   template<typename OptimizerType, typename FunctionType, typename MatType>
-  void EndEpoch(OptimizerType& /* optimizer */,
+  bool EndEpoch(OptimizerType& /* optimizer */,
                 FunctionType& /* function */,
                 const MatType& /* coordinates */,
                 const size_t /* epoch */,
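For illustration, a complete user-side callback under the corrected `bool` convention can be quite small. The sketch below is not part of this PR; the name `ObjectiveThreshold` and its target value are hypothetical, and only the `Evaluate()` hook is implemented, since ensmallen callbacks may implement any subset of the hooks.

```c++
#include <ensmallen.hpp>

// Hypothetical callback: request termination as soon as the objective
// drops below a fixed target value.
class ObjectiveThreshold
{
 public:
  ObjectiveThreshold(const double target) : target(target) { }

  // Called after any Evaluate(); returning true terminates the
  // optimization, per the convention documented above.
  template<typename OptimizerType, typename FunctionType, typename MatType>
  bool Evaluate(OptimizerType& /* optimizer */,
                FunctionType& /* function */,
                const MatType& /* coordinates */,
                const double objective)
  {
    return objective < target;
  }

 private:
  double target;
};
```

It would be passed like any other callback, e.g. `optimizer.Optimize(f, coordinates, ObjectiveThreshold(1e-5));`.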
24 changes: 13 additions & 11 deletions include/ensmallen_bits/aug_lagrangian/aug_lagrangian_impl.hpp
@@ -117,8 +117,8 @@ AugLagrangian::Optimize(
   for (size_t i = 0; i < function.NumConstraints(); i++)
   {
     const ElemType p = std::pow(function.EvaluateConstraint(i, coordinates), 2);
-    Callback::EvaluateConstraint(*this, function, coordinates, i, p,
-        callbacks...);
+    terminate |= Callback::EvaluateConstraint(*this, function, coordinates, i,
+        p, callbacks...);
 
     penalty += p;
   }
@@ -129,8 +129,7 @@ AugLagrangian::Optimize(
   // The odd comparison allows user to pass maxIterations = 0 (i.e. no limit on
   // number of iterations).
   size_t it;
-  terminate |= Callback::BeginOptimization(*this, function, coordinates,
-      callbacks...);
+  Callback::BeginOptimization(*this, function, coordinates, callbacks...);
   for (it = 0; it != (maxIterations - 1) && !terminate; it++)
   {
     Info << "AugLagrangian on iteration " << it
@@ -143,7 +142,7 @@ AugLagrangian::Optimize(
 
     const ElemType objective = function.Evaluate(coordinates);
 
-    Callback::Evaluate(*this, function, coordinates, objective,
+    terminate |= Callback::Evaluate(*this, function, coordinates, objective,
         callbacks...);
 
     // Check if we are done with the entire optimization (the threshold we are
@@ -170,14 +169,17 @@ AugLagrangian::Optimize(
     {
       const ElemType p = std::pow(function.EvaluateConstraint(i, coordinates),
           2);
-      Callback::EvaluateConstraint(*this, function, coordinates, i, p,
-          callbacks...);
+      terminate |= Callback::EvaluateConstraint(*this, function, coordinates, i,
+          p, callbacks...);
 
       penalty += p;
     }
 
-    Info << "Penalty is " << penalty << " (threshold "
-        << penaltyThreshold << ")." << std::endl;
+    Info << "Penalty is " << penalty << " (threshold " << penaltyThreshold
+        << ")." << std::endl;
+
+    if (terminate)
+      break;
 
     if (penalty < penaltyThreshold) // We update lambda.
     {
@@ -186,8 +188,8 @@ AugLagrangian::Optimize(
       for (size_t i = 0; i < function.NumConstraints(); i++)
       {
         const ElemType p = function.EvaluateConstraint(i, coordinates);
-        Callback::EvaluateConstraint(*this, function, coordinates, i, p,
-            callbacks...);
+        terminate |= Callback::EvaluateConstraint(*this, function, coordinates,
+            i, p, callbacks...);
 
         augfunc.Lambda()[i] -= augfunc.Sigma() * p;
       }
23 changes: 16 additions & 7 deletions include/ensmallen_bits/bigbatch_sgd/bigbatch_sgd_impl.hpp
@@ -104,7 +104,7 @@ BigBatchSGD<UpdatePolicyType>::Optimize(
   BaseGradType functionGradient(iterate.n_rows, iterate.n_cols);
   const size_t actualMaxIterations = (maxIterations == 0) ?
       std::numeric_limits<size_t>::max() : maxIterations;
-  terminate |= Callback::BeginOptimization(*this, f, iterate, callbacks...);
+  Callback::BeginOptimization(*this, f, iterate, callbacks...);
   for (size_t i = 0; i < actualMaxIterations && !terminate;
       /* incrementing done manually */)
   {
@@ -191,6 +191,9 @@ BigBatchSGD<UpdatePolicyType>::Optimize(
       }
     }
 
+    if (terminate)
+      break;
+
     instUpdatePolicy.As<InstUpdatePolicyType>().Update(f, stepSize, iterate,
         gradient, gB, vB, currentFunction, batchSize, effectiveBatchSize,
         reset);
@@ -229,9 +232,7 @@ BigBatchSGD<UpdatePolicyType>::Optimize(
       return overallObjective;
     }
 
-    if (std::abs(lastObjective - overallObjective) < tolerance ||
-        Callback::BeginEpoch(*this, f, iterate, epoch, overallObjective,
-        callbacks...))
+    if (std::abs(lastObjective - overallObjective) < tolerance)
     {
       Info << "Big-batch SGD: minimized within tolerance " << tolerance
           << "; terminating optimization." << std::endl;
 
       return overallObjective;
     }
 
+    terminate |= Callback::BeginEpoch(*this, f, iterate, epoch,
+        overallObjective, callbacks...);
+
     // Reset the counter variables.
     lastObjective = overallObjective;
     overallObjective = 0;
@@ -250,8 +254,11 @@ BigBatchSGD<UpdatePolicyType>::Optimize(
     }
   }
 
-  Info << "Big-batch SGD: maximum iterations (" << maxIterations << ") "
-      << "reached; terminating optimization." << std::endl;
+  if (!terminate)
+  {
+    Info << "Big-batch SGD: maximum iterations (" << maxIterations << ") "
+        << "reached; terminating optimization." << std::endl;
+  }
 
   // Calculate final objective if exactObjective is set to true.
   if (exactObjective)
@@ -263,7 +270,9 @@ BigBatchSGD<UpdatePolicyType>::Optimize(
       const ElemType objective = f.Evaluate(iterate, i, effectiveBatchSize);
       overallObjective += objective;
 
-      Callback::Evaluate(*this, f, iterate, objective, callbacks...);
+      // The optimization is finished, so we don't need to care what the
+      // callback returns.
+      (void) Callback::Evaluate(*this, f, iterate, objective, callbacks...);
     }
   }
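Throughout these implementation files, `Callback::Evaluate()`, `Callback::BeginEpoch()`, and friends fan a single event out to every callback in the variadic pack and report whether any of them requested termination. The sketch below is a rough, hypothetical illustration of that fan-out idea, much simpler than ensmallen's real dispatcher, which also detects at compile time whether each callback implements a given hook.

```c++
#include <iostream>

// Hypothetical fan-out over a callback pack (not ensmallen's actual code):
// each callback handles the event and the results are OR-ed together, so a
// single `true` requests termination.  The bitwise fold (|) deliberately
// avoids short-circuiting, so every callback still observes the event.
template<typename... CallbackTypes>
bool AnyRequestsTermination(const double objective,
                            CallbackTypes&&... callbacks)
{
  return (false | ... | callbacks(objective));
}

int main()
{
  auto logger = [](const double obj)
  {
    std::cout << "objective: " << obj << "\n";
    return false;  // never requests termination
  };
  auto stopper = [](const double obj) { return obj < 1e-3; };

  std::cout << std::boolalpha
            << AnyRequestsTermination(0.5, logger, stopper) << "\n"    // false
            << AnyRequestsTermination(1e-4, logger, stopper) << "\n";  // true
}
```

Whether the real dispatcher short-circuits or, as in this sketch, always invokes every callback is an implementation detail not visible in this diff.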
