-
Notifications
You must be signed in to change notification settings - Fork 119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SPSA optimizer implementation #69
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
d01e8b8
Ported SPSA from older MLPack to Ensmallen
rajiv26051997 a2b7946
Merge branch 'master' of https://github.com/mlpack/ensmallen
rajiv26051997 bded377
Updated spsa algorithm files
rajiv26051997 6c1b7e9
Updated spsa test files
rajiv26051997 69f921d
Made suggested changes in variable names and code style
rajiv26051997 d0e43b1
Added SPSA documentation
rajiv26051997 250fffa
Updated SPSA documentation
rajiv26051997 8e324af
Added newlines to the spsa files and added paper link to the SPSA doc…
rajiv26051997 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,141 @@ | ||||||||
/** | ||||||||
* @file spsa.hpp | ||||||||
* @author N Rajiv Vaidyanathan | ||||||||
* @author Marcus Edel | ||||||||
* | ||||||||
* SPSA (Simultaneous perturbation stochastic approximation) method for | ||||||||
* faster convergence. | ||||||||
* | ||||||||
* ensmallen is free software; you may redistribute it and/or modify it under | ||||||||
* the terms of the 3-clause BSD license. You should have received a copy of | ||||||||
* the 3-clause BSD license along with ensmallen. If not, see | ||||||||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||||||||
*/ | ||||||||
#ifndef ENSMALLEN_SPSA_SPSA_HPP | ||||||||
#define ENSMALLEN_SPSA_SPSA_HPP | ||||||||
|
||||||||
namespace ens { | ||||||||
|
||||||||
/** | ||||||||
* Implementation of the SPSA method. The SPSA algorithm approximates the | ||||||||
* gradient of the function by finite differences along stochastic directions. | ||||||||
* | ||||||||
* For more information, see the following. | ||||||||
* | ||||||||
* @code | ||||||||
* @article{Spall1998, | ||||||||
* author = {Spall, J. C.}, | ||||||||
* title = {An Overview of the Simultaneous Perturbation Method for | ||||||||
* Efficient Optimization}, | ||||||||
* journal = {Johns Hopkins APL Technical Digest}, | ||||||||
* volume = {19}, | ||||||||
* number = {4}, | ||||||||
* pages = {482--492}, | ||||||||
* year = {1998} | ||||||||
* } | ||||||||
* @endcode | ||||||||
* | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove the extra line here and add something like: ensmallen/include/ensmallen_bits/sgd/sgd.hpp Lines 54 to 56 in db6917e
|
||||||||
* SPSA can optimize differentiable separable functions. For more details, | ||||||||
* see the documentation on function types included with this distribution or on | ||||||||
* the ensmallen website. | ||||||||
*/ | ||||||||
class SPSA | ||||||||
{ | ||||||||
public: | ||||||||
/** | ||||||||
* Construct the SPSA optimizer with the given function and parameters. The | ||||||||
* defaults here are not necessarily good for the given problem, so it is | ||||||||
* suggested that the values used be tailored to the task at hand. The | ||||||||
* maximum number of iterations refers to the maximum number of points that | ||||||||
* are processed (i.e., one iteration equals one point; one iteration does not | ||||||||
* equal one pass over the dataset). | ||||||||
* | ||||||||
* @param alpha Scaling exponent for the step size. | ||||||||
* @param batchSize Batch size to use for each step. | ||||||||
* @param gamma Scaling exponent for evaluation step size. | ||||||||
* @param stepSize Scaling parameter for step size (named as 'a' in the paper). | ||||||||
* @param evaluationStepSize Scaling parameter for evaluation step size (named as 'c' in the paper). | ||||||||
* @param maxIterations Maximum number of iterations allowed (0 means no | ||||||||
* limit). | ||||||||
* @param tolerance Maximum absolute tolerance to terminate algorithm. | ||||||||
* @param shuffle If true, the function order is shuffled; otherwise, each | ||||||||
* function is visited in linear order. | ||||||||
*/ | ||||||||
SPSA(const double alpha = 0.602, | ||||||||
const size_t batchSize = 32, | ||||||||
const double gamma = 0.101, | ||||||||
const double stepSize = 0.16, | ||||||||
const double evaluationStepSize = 0.3, | ||||||||
const size_t maxIterations = 100000, | ||||||||
const double tolerance = 1e-5, | ||||||||
const bool shuffle = true); | ||||||||
|
||||||||
template<typename DecomposableFunctionType> | ||||||||
double Optimize(DecomposableFunctionType& function, arma::mat& iterate); | ||||||||
|
||||||||
//! Get the scaling exponent for the step size. | ||||||||
double Alpha() const { return alpha; } | ||||||||
//! Modify the scaling exponent for the step size. | ||||||||
double& Alpha() { return alpha; } | ||||||||
|
||||||||
//! Get the batch size. | ||||||||
size_t BatchSize() const { return batchSize; } | ||||||||
//! Modify the batch size. | ||||||||
size_t& BatchSize() { return batchSize; } | ||||||||
|
||||||||
//! Get the scaling exponent for evaluation step size. | ||||||||
double Gamma() const { return gamma; } | ||||||||
//! Modify the scaling exponent for evaluation step size. | ||||||||
double& Gamma() { return gamma; } | ||||||||
|
||||||||
//! Get the scaling parameter for step size. | ||||||||
double StepSize() const { return stepSize; } | ||||||||
//! Modify the scaling parameter for step size. | ||||||||
double& StepSize() { return stepSize; } | ||||||||
|
||||||||
//! Get the scaling parameter for step size. | ||||||||
double EvaluationStepSize() const { return evaluationStepSize; } | ||||||||
//! Modify the scaling parameter for step size. | ||||||||
double& EvaluationStepSize() { return evaluationStepSize; } | ||||||||
|
||||||||
//! Get the maximum number of iterations (0 indicates no limit). | ||||||||
size_t MaxIterations() const { return maxIterations; } | ||||||||
//! Modify the maximum number of iterations (0 indicates no limit). | ||||||||
size_t& MaxIterations() { return maxIterations; } | ||||||||
|
||||||||
private: | ||||||||
//! Scaling exponent for the step size. | ||||||||
double alpha; | ||||||||
|
||||||||
//! The batch size for processing. | ||||||||
size_t batchSize; | ||||||||
|
||||||||
//! Scaling exponent for evaluation step size. | ||||||||
double gamma; | ||||||||
|
||||||||
//! Scaling parameter for step size. | ||||||||
double stepSize; | ||||||||
|
||||||||
//! Scaling parameter for step size. | ||||||||
double evaluationStepSize; | ||||||||
|
||||||||
//! The maximum number of allowed iterations. | ||||||||
size_t maxIterations; | ||||||||
|
||||||||
//! The tolerance for termination. | ||||||||
double tolerance; | ||||||||
|
||||||||
//! Controls whether or not the individual functions are shuffled when | ||||||||
//! iterating. | ||||||||
bool shuffle; | ||||||||
|
||||||||
//! Control the amount of gradient update. | ||||||||
double Ak; | ||||||||
}; | ||||||||
|
||||||||
} // namespace ens | ||||||||
|
||||||||
// Include implementation. | ||||||||
#include "spsa_impl.hpp" | ||||||||
|
||||||||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
/** | ||
* @file spsa.hpp | ||
* @author N Rajiv Vaidyanathan | ||
* @author Marcus Edel | ||
* | ||
* SPSA (Simultaneous perturbation stochastic approximation) | ||
* update for faster convergence. | ||
* | ||
* ensmallen is free software; you may redistribute it and/or modify it under | ||
* the terms of the 3-clause BSD license. You should have received a copy of | ||
* the 3-clause BSD license along with ensmallen. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
#ifndef ENSMALLEN_SPSA_SPSA_IMPL_HPP | ||
#define ENSMALLEN_SPSA_SPSA_IMPL_HPP | ||
|
||
// In case it hasn't been included yet. | ||
#include "spsa.hpp" | ||
|
||
#include <ensmallen_bits/function.hpp> | ||
|
||
namespace ens { | ||
|
||
inline SPSA::SPSA(const double alpha, | ||
const size_t batchSize, | ||
const double gamma, | ||
const double stepSize, | ||
const double evaluationStepSize, | ||
const size_t maxIterations, | ||
const double tolerance, | ||
const bool shuffle) : | ||
alpha(alpha), | ||
batchSize(batchSize), | ||
gamma(gamma), | ||
stepSize(stepSize), | ||
evaluationStepSize(evaluationStepSize), | ||
Ak(0.001 * maxIterations), | ||
maxIterations(maxIterations), | ||
tolerance(tolerance), | ||
shuffle(shuffle) | ||
{ /* Nothing to do. */ } | ||
|
||
template<typename DecomposableFunctionType> | ||
inline double SPSA::Optimize( | ||
DecomposableFunctionType& function, arma::mat& iterate) | ||
{ | ||
// Make sure that we have the methods that we need. | ||
traits::CheckNonDifferentiableDecomposableFunctionTypeAPI< | ||
DecomposableFunctionType>(); | ||
|
||
arma::mat gradient(iterate.n_rows, iterate.n_cols); | ||
arma::mat spVector(iterate.n_rows, iterate.n_cols); | ||
|
||
// To keep track of where we are and how things are going. | ||
double overallObjective = 0; | ||
double lastObjective = DBL_MAX; | ||
|
||
const size_t actualMaxIterations = (maxIterations == 0) ? | ||
std::numeric_limits<size_t>::max() : maxIterations; | ||
for (size_t k = 0; k < actualMaxIterations; /* incrementing done manually */) | ||
{ | ||
// Is this iteration the start of a sequence? | ||
if (k > 0) | ||
{ | ||
// Output current objective function. | ||
Info << "SPSA: iteration " << k << ", objective " << overallObjective | ||
<< "." << std::endl; | ||
|
||
if (std::isnan(overallObjective) || std::isinf(overallObjective)) | ||
{ | ||
Warn << "SPSA: converged to " << overallObjective << "; terminating" | ||
<< " with failure. Try a smaller step size?" << std::endl; | ||
return overallObjective; | ||
} | ||
|
||
if (std::abs(lastObjective - overallObjective) < tolerance) | ||
{ | ||
Info << "SPSA: minimized within tolerance " << tolerance << "; " | ||
<< "terminating optimization." << std::endl; | ||
return overallObjective; | ||
} | ||
|
||
// Reset the counter variables. | ||
lastObjective = overallObjective; | ||
|
||
if (shuffle) // Determine order of visitation. | ||
function.Shuffle(); | ||
} | ||
|
||
// Gain sequences. | ||
const double ak = stepSize / std::pow(k + 1 + Ak, alpha); | ||
const double ck = evaluationStepSize / std::pow(k + 1, gamma); | ||
|
||
gradient.zeros(); | ||
for (size_t b = 0; b < batchSize; b++) | ||
{ | ||
// Stochastic directions. | ||
spVector = arma::conv_to<arma::mat>::from( | ||
arma::randi(iterate.n_rows, iterate.n_cols, | ||
arma::distr_param(0, 1))) * 2 - 1; | ||
|
||
iterate += ck * spVector; | ||
const double fPlus = function.Evaluate(iterate, 0, iterate.n_elem); | ||
|
||
iterate -= 2 * ck * spVector; | ||
const double fMinus = function.Evaluate(iterate, 0, iterate.n_elem); | ||
iterate += ck * spVector; | ||
|
||
gradient += (fPlus - fMinus) * (1 / (2 * ck * spVector)); | ||
} | ||
|
||
gradient /= (double) batchSize; | ||
iterate -= ak * gradient; | ||
|
||
overallObjective = function.Evaluate(iterate, 0, iterate.n_elem); | ||
k += batchSize; | ||
} | ||
|
||
// Calculate final objective. | ||
return function.Evaluate(iterate, 0, iterate.n_elem); | ||
} | ||
|
||
} // namespace ens | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you think we should link https://pdfs.semanticscholar.org/bf67/0fb6b1bd319938c6a879570fa744cf36b240.pdf as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah it will be useful 👍