New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Classic FrankWolfe and OMP #1041
Changes from 7 commits
91e0db2
6cb9f3a
b959694
20aee0c
b2ba0bc
3e1ebac
bc056f1
41fb641
4d43b4e
49b23c0
d395596
19f14d8
7a4b8c4
918cfb3
4baa16e
61a5864
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ set(DIRS | |
ada_grad | ||
adam | ||
aug_lagrangian | ||
fw | ||
gradient_descent | ||
lbfgs | ||
minibatch_sgd | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
set(SOURCES | ||
frank_wolfe.hpp | ||
frank_wolfe_impl.hpp | ||
constr_lpball.hpp | ||
update_classic.hpp | ||
update_span.hpp | ||
func_sq.hpp | ||
test_func_fw.hpp | ||
) | ||
|
||
set(DIR_SRCS) | ||
foreach(file ${SOURCES}) | ||
set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file}) | ||
endforeach() | ||
|
||
set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
/** | ||
* @file constr_lpball.hpp | ||
* @author Chenzhe Diao | ||
* | ||
* Lp ball constraint for the FrankWolfe algorithm. Used as LinearConstrSolverType. | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
#ifndef MLPACK_CORE_OPTIMIZERS_FW_CONSTR_LPBALL_HPP | ||
#define MLPACK_CORE_OPTIMIZERS_FW_CONSTR_LPBALL_HPP | ||
|
||
#include <mlpack/prereqs.hpp> | ||
|
||
namespace mlpack { | ||
namespace optimization { | ||
|
||
/** | ||
* LinearConstrSolver for FrankWolfe algorithm. Constraint domain given in the | ||
* form of lp ball. That is, given \f$ v \f$, solve | ||
* \f[ | ||
* s:=arg\min_{s\in D} <s, v> | ||
* \f] | ||
* when \f$ D \f$ is an lp ball. | ||
* | ||
* For \f$ p=1 \f$: take (one) \f$ k = arg\max_j |v_j|\f$, then the solution is: | ||
* \f[ | ||
* s_k = -sign(v_k), \qquad s_j = 0, j\neq k. | ||
* \f] | ||
* | ||
* For \f$ 1<p<\infty \f$: the solution is | ||
* \f[ | ||
* s_j = -sign(v_j) |v_j|^{p-1} | ||
* \f] | ||
* | ||
* For \f$ p=\infty \f$: the solution is | ||
* \f[ | ||
* s_j = -sign(v_j) | ||
* \f] | ||
* | ||
* where \f$ \alpha \f$ is a parameter which specifies the step size. \f$ i \f$ | ||
* is chosen according to \f$ j \f$ (the iteration number). | ||
*/ | ||
class ConstrLpBallSolver | ||
{ | ||
public: | ||
/** | ||
* Construct the solver of constrained problem. The constrained domain should | ||
* be unit lp ball for this class. | ||
* | ||
* @param p The constraint is unit lp ball. | ||
*/ | ||
ConstrLpBallSolver(const double p) : p(p) | ||
{ /* Do nothing. */ } | ||
|
||
/** | ||
* Optimizer of Linear Constrained Problem for FrankWolfe. | ||
* | ||
* @param v Input local gradient. | ||
* @param s Output optimal solution in the constrained domain (lp ball). | ||
*/ | ||
void Optimize(const arma::mat& v, | ||
arma::mat& s) | ||
{ | ||
if (p == std::numeric_limits<double>::infinity()) | ||
{ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It might be more straightforward to use |
||
// l-inf ball | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When wrapping a line, the next line should be tabbed twice from where the previous line began. For more information, take a look at: https://github.com/mlpack/mlpack/wiki/DesignGuidelines#line-length-and-wrapping. |
||
s = -sign(v); | ||
return; | ||
} | ||
else if (p > 1.0) | ||
{ | ||
// lp ball with 1<p<inf | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please use complete sentences with proper grammar and punctuation: https://github.com/mlpack/mlpack/wiki/DesignGuidelines#comments in most cases you can just add |
||
s = -sign(v) % pow(abs(v), p-1); | ||
return; | ||
} | ||
else if (p == 1.0) | ||
{ | ||
// l1 ball, used in OMP | ||
arma::mat tmp = arma::abs(v); | ||
arma::uword k = tmp.index_max(); // linear index of matrix | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it's time to bump up the required Armadillo version somewhat we can do that, although I'd like to make sure that we don't go past 6.500.5 since that's all that's available in Ubuntu 16.04. The RHEL6 and RHEL7 versions are really old, but... I think that I am the maintainer for those, so technically that is my responsibility that I am failing at... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess it doesn't matter to me~ Or I could even write a function that go through all the elements in the matrix, and put it somewhere you like ^_^ There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @rcurtin I think it would be really helpful if we can bump up the required Armadillo version. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ShangtongZhang: sure---can you propose a version we should jump to, and the functionality that it would get us? Like I said before, unfortunately I don't think we should go past 6.500.5. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, how about this---we can backport
My thought is, we can go with 6.500.5 in this case, if the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, let's go with 6.500 then. I'll open an issue about it. |
||
s.zeros(v.n_rows, v.n_cols); | ||
s(k) = -sign_double(v(k)); | ||
return; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you just do this:
I think this could save some computation. |
||
} | ||
else | ||
{ | ||
Log::Fatal << "Wrong norm p!" << std::endl; | ||
return; | ||
} | ||
} | ||
|
||
private: | ||
//! lp norm, take 1<=p<inf, | ||
// use std::numeric_limits<double>::infinity() for inf norm. | ||
double p; | ||
|
||
//! Signum function for double. | ||
double sign_double(const double x) const {return (x > 0) - (x < 0);} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please use camel casing for all names. Also, maybe it makes sense to put the function in
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like the std library doesn't have a sign() function. I could put it inside Also, thanks for pointing out the coding style problems. I will read the DesignGuideline page again. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Putting the |
||
}; | ||
|
||
} // namespace optimization | ||
} // namespace mlpack | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
/** | ||
* @file frank_wolfe.hpp | ||
* @author Chenzhe Diao | ||
* | ||
* Frank-Wolfe Algorithm. | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
#ifndef MLPACK_CORE_OPTIMIZERS_FW_FRANK_WOLFE_HPP | ||
#define MLPACK_CORE_OPTIMIZERS_FW_FRANK_WOLFE_HPP | ||
|
||
|
||
#include <mlpack/prereqs.hpp> | ||
#include "update_classic.hpp" | ||
#include "update_span.hpp" | ||
#include "constr_lpball.hpp" | ||
#include "func_sq.hpp" | ||
|
||
namespace mlpack { | ||
namespace optimization { | ||
|
||
/** | ||
* Frank-Wolfe is a technique to minimize a continuously differentiable convex | ||
* function \f$ f \f$ over a compact convex subset \f$ D \f$ of a vector space. | ||
* It is also known as conditional gradient method. | ||
* | ||
* To find the minimum of a function using Frank-Wolfe, in each iteration \f$ k \f$: | ||
* 1. Optimize the linearized constrained problem, using LinearConstrSolver: | ||
* \f[ | ||
* s_k:= arg\min_{s\in D} <s, \nabla f(x_k)> | ||
* \f] | ||
* | ||
* 2. Update \f$ x \f$ using UpdateRule: | ||
* \f[ | ||
* x_{k+1} := (1-\gamma) x_k + \gamma s_k | ||
* \f] | ||
* for some \f$ \gamma \in (0, 1) \f$, or use Fully-Corrective Variant: | ||
* \f[ | ||
* x_{k+1}:= arg\min_{x\in conv(s_0, \cdots, s_k)} f(x) | ||
* \f] | ||
* | ||
* | ||
* The algorithm continues until \f$ k \f$ reaches the maximum number of iterations, | ||
* or when the duality gap is bounded by a certain tolerance \f$ \epsilon \f$. | ||
* That is, | ||
* | ||
* \f[ | ||
* g(x):= \max_{s\in D} <x-s, \nabla f(x)> \quad \leq \epsilon, | ||
* \f] | ||
* | ||
* we also know that \f$ g(x) \geq f(x) - f(x^*) \f$, where \f$ x^* \f$ is the optimal | ||
* solution. | ||
* | ||
* The parameter \f$ \epsilon \f$ is specified by the tolerance parameter to the | ||
* constructor. | ||
* | ||
* For FrankWolfe to work, FunctionType, LinearConstrSolverType and UpdateRuleType | ||
* template parameters are required. | ||
* These classes must implement the following functions: | ||
* | ||
* FunctionType: | ||
* | ||
* double Evaluate(const arma::mat& coordinates); | ||
* void Gradient(const arma::mat& coordinates, | ||
* arma::mat& gradient); | ||
* | ||
* LinearConstrSolverType: | ||
* | ||
* void Optimize(const arma::mat& gradient, | ||
* arma::mat& s); | ||
* | ||
* UpdateRuleType: | ||
* | ||
* void Update(const arma::mat& old_coords, | ||
* const arma::mat& s, | ||
* arma::mat& new_coords, | ||
* const size_t num_iter); | ||
* | ||
* @tparam FunctionType Objective function type to be | ||
* minimized. | ||
* @tparam LinearConstrSolverType Solver for the linear constrained problem. | ||
* @tparam UpdateRuleType Rule to update the solution in each iteration. | ||
* | ||
*/ | ||
template< | ||
typename LinearConstrSolverType, | ||
typename UpdateRuleType | ||
> | ||
class FrankWolfe | ||
{ | ||
public: | ||
/** | ||
* Construct the Frank-Wolfe optimizer with the given function and | ||
* parameters. Notice that the constraint domain \f$ D \f$ is input | ||
* at the initialization of linear_constr_solver, the function to be | ||
* optimized is stored in update_rule. | ||
* | ||
* @param linear_constr_solver Solver for linear constrained problem. | ||
* @param update_rule Rule for updating solution in each iteration. | ||
* @param maxIterations Maximum number of iterations allowed (0 means no | ||
* limit). | ||
* @param tolerance Maximum absolute tolerance to terminate algorithm. | ||
*/ | ||
FrankWolfe(const LinearConstrSolverType linear_constr_solver, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please use camel casing for all names. |
||
const UpdateRuleType update_rule, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you align the upcoming parameter with the first parameter? |
||
const size_t maxIterations = 100000, | ||
const double tolerance = 1e-10); | ||
|
||
/** | ||
* Optimize the given function using FrankWolfe. The given starting | ||
* point will be modified to store the finishing point of the algorithm, and | ||
* the final objective value is returned. | ||
* | ||
* @param function Function to be optimized. | ||
* @param iterate Starting point (will be modified). | ||
* @return Objective value of the final point. | ||
*/ | ||
template<typename FunctionType> | ||
double Optimize(FunctionType& function, arma::mat& iterate); | ||
|
||
//! Get the linear constrained solver. | ||
LinearConstrSolverType LinearConstrSolver() | ||
const { return linear_constr_solver; } | ||
//! Modify the linear constrained solver. | ||
LinearConstrSolverType& LinearConstrSolver() { return linear_constr_solver; } | ||
|
||
//! Get the update rule. | ||
UpdateRuleType UpdateRule() const { return update_rule; } | ||
//! Modify the update rule. | ||
UpdateRuleType& UpdateRule() { return update_rule; } | ||
|
||
//! Get the maximum number of iterations (0 indicates no limit). | ||
size_t MaxIterations() const { return maxIterations; } | ||
//! Modify the maximum number of iterations (0 indicates no limit). | ||
size_t& MaxIterations() { return maxIterations; } | ||
|
||
//! Get the tolerance for termination. | ||
double Tolerance() const { return tolerance; } | ||
//! Modify the tolerance for termination. | ||
double& Tolerance() { return tolerance; } | ||
|
||
private: | ||
//! The solver for constrained linear problem in first step. | ||
LinearConstrSolverType linear_constr_solver; | ||
|
||
//! The rule to update, used in the second step. | ||
UpdateRuleType update_rule; | ||
|
||
//! The maximum number of allowed iterations. | ||
size_t maxIterations; | ||
|
||
//! The tolerance for termination. | ||
double tolerance; | ||
}; | ||
|
||
//! Orthogonal Matching Pursuit | ||
using OMP = FrankWolfe<ConstrLpBallSolver, UpdateSpan>; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add a little extra documentation here please on what OMP is, what it's for, etc.? Nothing too in-depth, just something simple that can point users in the right direction. Users may come across this in the Doxygen documentation, kind of like this: so if there's not much information, it will only say There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added a brief document in the new commit. However, if users want to know how to use it, I guess they need to consult the test file. I can make it a method in |
||
|
||
} // namespace optimization | ||
} // namespace mlpack | ||
|
||
// Include implementation. | ||
#include "frank_wolfe_impl.hpp" | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
/** | ||
* @file frank_wolfe_impl.hpp | ||
* @author Chenzhe Diao | ||
* | ||
* Frank-Wolfe Algorithm. | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
#ifndef MLPACK_CORE_OPTIMIZERS_FW_FRANK_WOLFE_IMPL_HPP | ||
#define MLPACK_CORE_OPTIMIZERS_FW_FRANK_WOLFE_IMPL_HPP | ||
|
||
// In case it hasn't been included yet. | ||
#include "frank_wolfe.hpp" | ||
|
||
namespace mlpack { | ||
namespace optimization { | ||
|
||
template< | ||
typename LinearConstrSolverType, | ||
typename UpdateRuleType | ||
> | ||
FrankWolfe<LinearConstrSolverType, UpdateRuleType>::FrankWolfe( | ||
const LinearConstrSolverType linear_constr_solver, | ||
const UpdateRuleType update_rule, | ||
const size_t maxIterations, | ||
const double tolerance) : | ||
linear_constr_solver(linear_constr_solver), | ||
update_rule(update_rule), | ||
maxIterations(maxIterations), | ||
tolerance(tolerance) | ||
{ /* Nothing to do*/ } | ||
|
||
|
||
//! Optimize the function (minimize). | ||
template< | ||
typename LinearConstrSolverType, | ||
typename UpdateRuleType | ||
> | ||
template<typename FunctionType> | ||
double FrankWolfe<LinearConstrSolverType, UpdateRuleType> | ||
::Optimize(FunctionType& function, arma::mat& iterate) | ||
{ | ||
// To keep track of the function value | ||
double CurrentObjective = function.Evaluate(iterate); | ||
double PreviousObjective = DBL_MAX; | ||
|
||
arma::mat gradient(iterate.n_rows, iterate.n_cols); | ||
arma::mat s(iterate.n_rows, iterate.n_cols); | ||
arma::mat iterate_new(iterate.n_rows, iterate.n_cols); | ||
double gap = 0; | ||
|
||
for (size_t i=1; i != maxIterations; ++i) | ||
{ | ||
// Output current objective function | ||
Log::Info << "Iteration " << i << ", objective " | ||
<< CurrentObjective << "." << std::endl; | ||
|
||
// Reset counter variables. | ||
PreviousObjective = CurrentObjective; | ||
|
||
// Calculate the gradient | ||
function.Gradient(iterate, gradient); | ||
|
||
// Solve linear constrained problem, solution saved in s. | ||
linear_constr_solver.Optimize(gradient, s); | ||
|
||
// Check duality gap for return condition | ||
gap = std::fabs(dot(iterate-s, gradient)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. another pedantic style issue, |
||
if (gap < tolerance) | ||
{ | ||
Log::Info << "FrankWolfe: minimized within tolerance " | ||
<< tolerance << "; " << "terminating optimization." << std::endl; | ||
return CurrentObjective; | ||
} | ||
|
||
|
||
// Update solution, save in iterate_new | ||
update_rule.Update(function, iterate, s, iterate_new, i); | ||
|
||
iterate = std::move(iterate_new); | ||
CurrentObjective = function.Evaluate(iterate); | ||
} | ||
Log::Info << "Frank Wolfe: maximum iterations (" << maxIterations | ||
<< ") reached; " << "terminating optimization." << std::endl; | ||
return CurrentObjective; | ||
} | ||
|
||
|
||
} // namespace optimization | ||
} // namespace mlpack | ||
|
||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tabs should be two spaces wide. For more information, take a look at: https://github.com/mlpack/mlpack/wiki/DesignGuidelines#tabbing.