Atom domain FrankWolfe Algorithm for vector problems #1087

Merged
merged 27 commits into from Sep 5, 2017
Changes from 22 commits
9d6f685
Added line search update method for classical FW algorithm, and
czdiao Jul 26, 2017
f3c7f5f
Fixed some bugs in update_linesearch.hpp
czdiao Jul 27, 2017
de0ec2a
Tested regularization.
czdiao Aug 8, 2017
0f29403
Changed the new atom interface. All the tests still pass.
czdiao Aug 8, 2017
f74d22d
Support Prune for UpdateSpan tested.
czdiao Aug 8, 2017
3e716ff
Added atom norm constraint in update steps. Using Projected Gradient
czdiao Aug 9, 2017
39a418a
Added documentations.
czdiao Aug 9, 2017
665e222
Added structure group constraint.
czdiao Aug 9, 2017
0f90149
Changed the code for structured group constraint solver.
czdiao Aug 9, 2017
da900ed
Fixed regularized OMP test.
czdiao Aug 9, 2017
b6425f7
Added documentation.
czdiao Aug 9, 2017
d1f793a
Merge branch 'master' into atomVector
czdiao Aug 9, 2017
e186c50
Fixed merge problems.
czdiao Aug 9, 2017
c281c17
Fixed style problems.
czdiao Aug 10, 2017
1402de5
Fixed variants of OMP tests.
czdiao Aug 10, 2017
674ff50
Fixed armadillo mat & vec initialization problems.
czdiao Aug 11, 2017
1b9bfb5
Added the line search optimization solver.
czdiao Aug 20, 2017
a8690ec
In FrankWolfe type optimizer, changed the line search update rule to use
czdiao Aug 20, 2017
fd24ddc
Fixed style issues. Fixed the atoms to be arma::vec.
czdiao Aug 20, 2017
a967720
Added Proximal class, and unit test for it.
czdiao Aug 27, 2017
d8d97d7
Changed the projection to l1 ball function to use the new class.
czdiao Aug 28, 2017
f59a008
Fixed style issues.
czdiao Aug 28, 2017
dc8f3b7
Fixed regspace and static proximal methods.
czdiao Aug 28, 2017
6af5eca
Fixed static cpp file.
czdiao Aug 28, 2017
0847231
Fixed the unit test.
czdiao Aug 28, 2017
bb242c4
More efficient implementation of atoms class.
czdiao Aug 31, 2017
0b3307e
Fixed add atom bug.
czdiao Sep 3, 2017
2 changes: 2 additions & 0 deletions src/mlpack/core/optimizers/CMakeLists.txt
@@ -6,7 +6,9 @@ set(DIRS
fw
gradient_descent
lbfgs
line_search
minibatch_sgd
proximal
rmsprop
sa
sdp
4 changes: 4 additions & 0 deletions src/mlpack/core/optimizers/fw/CMakeLists.txt
@@ -1,9 +1,13 @@
set(SOURCES
atoms.hpp
frank_wolfe.hpp
frank_wolfe_impl.hpp
constr_lpball.hpp
constr_structure_group.hpp
update_classic.hpp
update_span.hpp
update_linesearch.hpp
update_full_correction.hpp
func_sq.hpp
test_func_fw.hpp
)
203 changes: 203 additions & 0 deletions src/mlpack/core/optimizers/fw/atoms.hpp
@@ -0,0 +1,203 @@
/**
* @file atoms.hpp
* @author Chenzhe Diao
*
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_CORE_OPTIMIZERS_FW_ATOMS_HPP
#define MLPACK_CORE_OPTIMIZERS_FW_ATOMS_HPP

#include <mlpack/prereqs.hpp>
#include <mlpack/core/optimizers/proximal/proximal.hpp>
#include "func_sq.hpp"

namespace mlpack {
namespace optimization {

/**
* Class to hold the information and operations of current atoms in the
* solution space.
*/
class Atoms
{
public:
Atoms(){ /* Nothing to do. */ }

/**
* Add atom into the solution space.
*
* @param v new atom to be added.
* @param c coefficient of the new atom.
*/
void AddAtom(const arma::vec& v, const double c = 0)
{
if (currentAtoms.is_empty())
{
CurrentAtoms() = v;
CurrentCoeffs().set_size(1);
CurrentCoeffs().fill(c);
}
else
{
currentAtoms.insert_cols(0, v);
arma::vec cVec(1);
cVec(0) = c;
currentCoeffs.insert_rows(0, cVec);
}
}


//! Recover the solution coordinate from the coefficients of current atoms.
void RecoverVector(arma::mat& x)
{
x = currentAtoms * currentCoeffs;
}

/**
* Prune the support, delete previous atoms if they don't contribute much.
* See Algorithm 2 of paper:
* @code
* @article{RaoShaWri:2015Forward--backward,
* Author = {Rao, Nikhil and Shah, Parikshit and Wright, Stephen},
* Journal = {IEEE Transactions on Signal Processing},
* Number = {21},
* Pages = {5798--5811},
* Publisher = {IEEE},
* Title = {Forward--backward greedy algorithms for atomic norm regularization},
* Volume = {63},
* Year = {2015}
* }
* @endcode
*
* @param F thresholding number.
* @param function function to be optimized.
*/
void PruneSupport(const double F, FuncSq& function)
{
arma::mat atomSqTerm = function.MatrixA() * currentAtoms;
atomSqTerm = sum(square(atomSqTerm), 0);
Member

I wonder if it would be helpful and slightly more efficient to inline this into one line:

arma::mat atomSqTerm = arma::sum(arma::square(function.MatrixA() * currentAtoms), 0);

Contributor Author

This atomSqTerm could be calculated when each atom is added, so we don't need to recalculate all of them each time we prune the support. I just modified the implementation. Thanks for pointing out the problem.
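
The idea in this reply, computing each atom's squared term once when the atom is added rather than for all atoms on every prune, can be sketched roughly as follows. This is only an illustration with plain `std::vector` so that it compiles without Armadillo; `AtomCache`, `sqNorms`, and `RemoveAtom` are hypothetical names, and this is not the code that was eventually committed.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical illustration only: A is a dense matrix stored row by row, and
// atoms are std::vector<double>. The PR itself uses arma::mat / arma::vec.
struct AtomCache
{
  std::vector<std::vector<double>> A;      // m rows, each of length n.
  std::vector<std::vector<double>> atoms;  // Current atoms, each of length n.
  std::vector<double> sqNorms;             // Cached ||A * atom||^2 per atom.

  // Compute ||A * v||^2 once, at insertion time, and store it with the atom.
  void AddAtom(const std::vector<double>& v)
  {
    double sq = 0.0;
    for (const std::vector<double>& row : A)
    {
      assert(row.size() == v.size());
      double dot = 0.0;
      for (std::size_t j = 0; j < v.size(); ++j)
        dot += row[j] * v[j];
      sq += dot * dot;
    }
    atoms.push_back(v);
    sqNorms.push_back(sq);
  }

  // Removing an atom drops its cached value; nothing else is recomputed.
  void RemoveAtom(const std::size_t ind)
  {
    assert(ind < atoms.size());
    atoms.erase(atoms.begin() + static_cast<std::ptrdiff_t>(ind));
    sqNorms.erase(sqNorms.begin() + static_cast<std::ptrdiff_t>(ind));
  }
};
```

With such a cache, a pruning routine would only need to scale the stored values by 0.5 times the squared coefficients, instead of multiplying A with every atom again.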

atomSqTerm = 0.5 * atomSqTerm.t() % square(currentCoeffs);

while (true)
{
// Solve for current gradient.
arma::mat x;
RecoverVector(x);
arma::mat gradient(size(x));
function.Gradient(x, gradient);

// Find possible atom to be deleted.
arma::vec gap = atomSqTerm -
currentCoeffs % trans(gradient.t() * currentAtoms);
arma::uword ind;
gap.min(ind);

// Try deleting the atom.
arma::mat newAtoms = currentAtoms;
newAtoms.shed_col(ind);
Member

Hm, so these two lines will first copy the full currentAtoms matrix, then another copy will be incurred by the shed_col() function (where everything except the column ind will be copied). I think you can avoid the extra copy with code like this:

arma::mat newAtoms(currentAtoms.n_rows, currentAtoms.n_cols - 1);
newAtoms.cols(0, ind - 1) = currentAtoms.cols(0, ind - 1);
newAtoms.cols(ind, newAtoms.n_cols - 1) = currentAtoms.cols(ind + 1, currentAtoms.n_cols - 1);

I think that is correct, you should probably double-check my logic.

Contributor Author

I just changed it to be this way. Only need to add an 'if' to check the bound of ind here.
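
The bound check mentioned here matters at both ends: with `ind == 0` the expression `ind - 1` wraps around for an unsigned index, and with `ind` equal to the last column the second range is empty. A minimal sketch of a single-copy column removal that handles both cases is below. It uses a hypothetical column-major `ColMajor` struct in place of `arma::mat` so that it compiles without Armadillo; `CopyWithoutColumn` is likewise an illustrative name, not part of the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for arma::mat: a column-major buffer.
struct ColMajor
{
  std::size_t nRows;
  std::size_t nCols;
  std::vector<double> data;  // Element (row, col) is data[col * nRows + row].
};

// Return a copy of m without column ind, copying every kept element once.
ColMajor CopyWithoutColumn(const ColMajor& m, const std::size_t ind)
{
  assert(ind < m.nCols);
  assert(m.data.size() == m.nRows * m.nCols);

  ColMajor out{m.nRows, m.nCols - 1, {}};
  out.data.reserve(out.nRows * out.nCols);

  // Columns [0, ind): an empty range when ind == 0, so nothing like
  // `ind - 1` is ever evaluated.
  out.data.insert(out.data.end(), m.data.begin(),
      m.data.begin() + static_cast<std::ptrdiff_t>(ind * m.nRows));

  // Columns (ind, nCols): an empty range when ind is the last column.
  out.data.insert(out.data.end(),
      m.data.begin() + static_cast<std::ptrdiff_t>((ind + 1) * m.nRows),
      m.data.end());

  return out;
}
```

In Armadillo terms, the equivalent guard would presumably be an `if (ind > 0)` around the first submatrix copy and an `if (ind < currentAtoms.n_cols - 1)` around the second.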

// Recalculate the coefficients.
arma::vec newCoeffs =
solve(function.MatrixA() * newAtoms, function.Vectorb());
// Evaluate the function again.
double Fnew = function.Evaluate(newAtoms * newCoeffs);

if (Fnew > F)
// Should not delete the atom.
break;
else
{
// Delete the atom from current atoms.
currentAtoms = newAtoms;
currentCoeffs = newCoeffs;
atomSqTerm.shed_row(ind);
} // else
} // while
}


/**
* Enhance the solution in the convex hull of current atoms with atom norm
* constraint tau. Used in UpdateFullCorrection class for update step.
*
* Minimize the function in the atom domain defined by current atoms,
* where the solution still needs to have atom norm (defined by current atoms)
* less than or equal to tau. We use the projected gradient method to solve it;
* see the "Enhancement step" of the following paper:
* @code
* @article{RaoShaWri:2015Forward--backward,
* Author = {Rao, Nikhil and Shah, Parikshit and Wright, Stephen},
* Journal = {IEEE Transactions on Signal Processing},
* Number = {21},
* Pages = {5798--5811},
* Publisher = {IEEE},
* Title = {Forward--backward greedy algorithms for atomic norm regularization},
* Volume = {63},
* Year = {2015}
* }
* @endcode
*
* @param function function to be minimized.
* @param tau atom norm constraint.
* @param stepSize step size for projected gradient method.
* @param maxIteration maximum iteration number.
* @param tolerance tolerance for projected gradient method.
*/
template<typename FunctionType>
void ProjectedGradientEnhancement(FunctionType& function,
double tau,
double stepSize,
size_t maxIteration = 100,
double tolerance = 1e-3)
{
arma::mat x;
RecoverVector(x);
double value = function.Evaluate(x);

Proximal proximal(tau);
for (size_t iter = 1; iter<maxIteration; iter++)
{
// Update currentCoeffs with gradient descent method.
arma::mat g;
function.Gradient(x, g);
g = currentAtoms.t() * g;
currentCoeffs = currentCoeffs - stepSize * g;

// Projection of currentCoeffs to satisfy the atom norm constraint.
proximal.ProjectToL1Ball(currentCoeffs);

RecoverVector(x);
double valueNew = function.Evaluate(x);

if ((value - valueNew) < tolerance)
break;

value = valueNew;
}
}


//! Get the current atom coefficients.
const arma::vec& CurrentCoeffs() const { return currentCoeffs; }
//! Modify the current atom coefficients.
arma::vec& CurrentCoeffs() { return currentCoeffs; }

//! Get the current atoms.
const arma::mat& CurrentAtoms() const { return currentAtoms; }
//! Modify the current atoms.
arma::mat& CurrentAtoms() { return currentAtoms; }

private:
//! Coefficients of current atoms.
arma::vec currentCoeffs;

//! Current atoms in the solution space.
arma::mat currentAtoms;
}; // class Atoms
} // namespace optimization
} // namespace mlpack

#endif
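
The enhancement loop in `ProjectedGradientEnhancement` alternates a gradient step on the coefficients with a projection onto the l1 ball of radius tau. The projection itself lives in the `Proximal` class, which is not shown in this part of the diff. For orientation, one common sort-and-threshold way to compute such a projection is sketched below; `ProjectToL1BallSketch` is a hypothetical name, it works on `std::vector<double>` rather than `arma::vec`, and the PR's `ProjectToL1Ball` may well be implemented differently.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

// Euclidean projection of v onto { x : ||x||_1 <= tau }, in place.
// Hypothetical free function; not the PR's Proximal::ProjectToL1Ball.
void ProjectToL1BallSketch(std::vector<double>& v, const double tau)
{
  assert(tau > 0.0);

  double l1 = 0.0;
  for (const double x : v)
    l1 += std::abs(x);
  if (l1 <= tau)
    return;  // Already inside the ball; nothing to do.

  // Sort the magnitudes in decreasing order.
  std::vector<double> u(v.size());
  for (std::size_t i = 0; i < v.size(); ++i)
    u[i] = std::abs(v[i]);
  std::sort(u.begin(), u.end(), std::greater<double>());

  // Find the soft-threshold level theta from the running sum of magnitudes.
  double cumSum = 0.0;
  double theta = 0.0;
  for (std::size_t i = 0; i < u.size(); ++i)
  {
    cumSum += u[i];
    const double candidate = (cumSum - tau) / static_cast<double>(i + 1);
    if (u[i] > candidate)
      theta = candidate;
    else
      break;
  }

  // Shrink every entry towards zero by theta, keeping its sign.
  for (double& x : v)
  {
    const double mag = std::max(std::abs(x) - theta, 0.0);
    x = (x < 0.0) ? -mag : mag;
  }
}
```

For example, projecting (3, 1) onto the ball of radius 2 yields (2, 0): the threshold is 1, which removes the smaller entry entirely.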
85 changes: 73 additions & 12 deletions src/mlpack/core/optimizers/fw/constr_lpball.hpp
@@ -20,28 +20,37 @@ namespace optimization {
/**
* LinearConstrSolver for FrankWolfe algorithm. Constraint domain given in the
* form of lp ball. That is, given \f$ v \f$, solve
- * \f[
+ * \f$
  * s:=arg\min_{s\in D} <s, v>
+ * \f$
+ * when \f$ D \f$ is a regularized lp ball. That is,
+ * \f[
+ * D = \{ x: (\sum_j|\lambda_j x_j|^p)^{1/p}\leq 1 \}.
  * \f]
- * when \f$ D \f$ is an lp ball.
+ * If \f$ \lambda \f$ is not given in the constructor, default is using all
+ * \f$ \lambda_j = 1 \f$ for all \f$ j \f$.
+ *
+ * In applications such as Orthogonal Matching Pursuit (OMP), \f$ \lambda \f$
+ * could be ideally set to the norm of the elements in the dictionary.
  *
- * For \f$ p=1 \f$: take (one) \f$ k = arg\max_j |v_j|\f$, then the solution is:
+ * For \f$ p=1 \f$: take (one) \f$ k = arg\max_j |v_j/\lambda_j|\f$, then the
+ * solution is:
  * \f[
- * s_k = -sign(v_k), \qquad s_j = 0, j\neq k.
+ * s_k = -sign(v_k)/\lambda_k, \qquad s_j = 0, \quad j\neq k.
  * \f]
  *
  * For \f$ 1<p<\infty \f$: the solution is
  * \f[
- * s_j = -sign(v_j) |v_j|^{p-1}
+ * t_j = -sign(v_j) |v_j/\lambda_j|^{q-1}, \qquad
+ * s_j = \frac{t_j}{||t||_p\cdot\lambda_j}, \quad
+ * 1/p + 1/q = 1.
  * \f]
  *
  * For \f$ p=\infty \f$: the solution is
  * \f[
- * s_j = -sign(v_j)
+ * s_j = -sign(v_j)/\lambda_j
  * \f]
  *
- * where \f$ \alpha \f$ is a parameter which specifies the step size. \f$ i \f$
- * is chosen according to \f$ j \f$ (the iteration number).
*/
class ConstrLpBallSolver
{
@@ -55,6 +64,18 @@ class ConstrLpBallSolver
ConstrLpBallSolver(const double p) : p(p)
{ /* Do nothing. */ }

+ /**
+ * Construct the solver of constrained problem, with regularization parameter
+ * lambda here.
+ *
+ * @param p The constraint is unit lp ball.
+ * @param lambda Regularization parameter.
+ */
+ ConstrLpBallSolver(const double p, const arma::vec lambda) :
+ p(p), regFlag(true), lambda(lambda)
+ { /* Do nothing. */ }
+
+
/**
* Optimizer of Linear Constrained Problem for FrankWolfe.
*
@@ -68,20 +89,39 @@ class ConstrLpBallSolver
{
// l-inf ball.
s = -sign(v);
+ if (regFlag)
+ s = s / lambda; // element-wise division.
  }
  else if (p > 1.0)
  {
  // lp ball with 1<p<inf.
- s = -sign(v) % pow(abs(v), p-1);
+ if (regFlag)
+ s = v / lambda;
+ else
+ s = v;
+
+ double q = 1 / (1.0 - 1.0 / p);
+ s = - sign(v) % pow(abs(s), q - 1); // element-wise multiplication.
+ s = arma::normalise(s, p);
+
+ if (regFlag)
+ s = s / lambda;
  }
  else if (p == 1.0)
  {
  // l1 ball, also used in OMP.
- arma::mat tmp = arma::abs(v);
+ if (regFlag)
+ s = arma::abs(v / lambda);
+ else
+ s = arma::abs(v);
+
  arma::uword k;
- tmp.max(k); // k is the linear index of the largest element.
- s.zeros(v.n_rows, v.n_cols);
+ s.max(k); // k is the linear index of the largest element.
+ s.zeros();
  s(k) = - mlpack::math::Sign(v(k));
+
+ if (regFlag)
+ s = s / lambda;
}
else
{
@@ -91,10 +131,31 @@ class ConstrLpBallSolver
return;
}

+ //! Get the p-norm.
+ double P() const { return p; }
+ //! Modify the p-norm.
+ double& P() { return p;}
+
+ //! Get regularization flag.
+ bool RegFlag() const {return regFlag;}
+ //! Modify regularization flag.
+ bool& RegFlag() {return regFlag;}
+
+ //! Get the regularization parameter.
+ arma::vec Lambda() const {return lambda;}
+ //! Modify the regularization parameter.
+ arma::vec& Lambda() {return lambda;}
+
private:
//! lp norm, 1<=p<=inf;
//! use std::numeric_limits<double>::infinity() for inf norm.
double p;
+
+ //! Regularization flag.
+ bool regFlag = false;
+
+ //! Regularization parameter.
+ arma::vec lambda;
};

} // namespace optimization
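
To make the three cases in the `ConstrLpBallSolver` documentation concrete, here is a small stand-alone sketch of the linear subproblem over the regularized lp ball. It follows the formulas in the doc comment but is written against `std::vector<double>` so that it compiles without Armadillo; `LpBallLmoSketch` is a hypothetical name, and the sketch assumes every `lambda[j]` is strictly positive. It is not the PR's implementation.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Minimize <s, v> over { x : (sum_j |lambda_j x_j|^p)^(1/p) <= 1 }.
// Hypothetical helper; assumes lambda[j] > 0 for all j and 1 <= p <= inf.
std::vector<double> LpBallLmoSketch(const std::vector<double>& v,
                                    const std::vector<double>& lambda,
                                    const double p)
{
  assert(v.size() == lambda.size());
  assert(p >= 1.0);

  const std::size_t n = v.size();
  std::vector<double> s(n, 0.0);
  if (n == 0)
    return s;

  auto sign = [](const double x) { return (x > 0.0) - (x < 0.0); };

  if (p == std::numeric_limits<double>::infinity())
  {
    // l-inf ball: s_j = -sign(v_j) / lambda_j.
    for (std::size_t j = 0; j < n; ++j)
      s[j] = -sign(v[j]) / lambda[j];
  }
  else if (p > 1.0)
  {
    // 1 < p < inf: t_j = -sign(v_j) |v_j / lambda_j|^(q - 1), 1/p + 1/q = 1,
    // then s_j = t_j / (||t||_p * lambda_j).
    const double q = p / (p - 1.0);
    double normP = 0.0;
    for (std::size_t j = 0; j < n; ++j)
    {
      s[j] = -sign(v[j]) * std::pow(std::abs(v[j] / lambda[j]), q - 1.0);
      normP += std::pow(std::abs(s[j]), p);
    }
    normP = std::pow(normP, 1.0 / p);
    if (normP > 0.0)
    {
      for (std::size_t j = 0; j < n; ++j)
        s[j] /= normP * lambda[j];
    }
  }
  else
  {
    // l1 ball: put all weight on one index k maximizing |v_k / lambda_k|.
    std::size_t k = 0;
    for (std::size_t j = 1; j < n; ++j)
    {
      if (std::abs(v[j] / lambda[j]) > std::abs(v[k] / lambda[k]))
        k = j;
    }
    s[k] = -sign(v[k]) / lambda[k];
  }

  return s;
}
```

With all `lambda[j]` equal to 1 this reduces to the unregularized lp ball; for p = 2 and v = (3, 4) the result is approximately (-0.6, -0.8), the unit vector pointing against v.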