From 5a4798e285b4201f31fe6e400fc4a857a82b0c1a Mon Sep 17 00:00:00 2001 From: vasanth kalingeri Date: Mon, 14 Mar 2016 11:32:30 +0530 Subject: [PATCH 1/3] Adam implementation --- src/mlpack/core/optimizers/adam/adam_impl.hpp | 129 ++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 src/mlpack/core/optimizers/adam/adam_impl.hpp diff --git a/src/mlpack/core/optimizers/adam/adam_impl.hpp b/src/mlpack/core/optimizers/adam/adam_impl.hpp new file mode 100644 index 00000000000..408e116e75f --- /dev/null +++ b/src/mlpack/core/optimizers/adam/adam_impl.hpp @@ -0,0 +1,129 @@ +#ifndef __MLPACK_CORE_OPTIMIZERS_ADAM_ADAM_IMPL_HPP +#define __MLPACK_CORE_OPTIMIZERS_ADAM_ADAM_IMPL_HPP + +//Have to write the outer wrapper +#include "adam.hpp" + +namespace mlpack { +namespace optimization { + +template +Adam::Adam(DecomposableFunctionType& function, + const double stepSize, + const double beta1, + const double beta2, + const double eps, + const size_t maxIterations, + const double tolerance, + const bool shuffle) : + function(function), + stepSize(stepSize), + beta1(beta1), + beta2(beta2), + eps(eps), + maxIterations(maxIterations), + tolerance(tolerance), + shuffle(shuffle) +{ /* Nothing to do. */ } + +//! Optimize the function (minimize). +template +double Adam::Optimize(arma::mat& iterate) +{ + // Find the number of functions to use. + const size_t numFunctions = function.NumFunctions(); + + // This is used only if shuffle is true. + arma::Col visitationOrder; + if (shuffle) + visitationOrder = arma::shuffle(arma::linspace>(0, + (numFunctions - 1), numFunctions)); + + // To keep track of where we are and how things are going. + size_t currentFunction = 0; + double overallObjective = 0; + double lastObjective = DBL_MAX; + + // Calculate the first objective function. + for (size_t i = 0; i < numFunctions; ++i) + overallObjective += function.Evaluate(iterate, i); + + // Now iterate! + arma::mat gradient(iterate.n_rows, iterate.n_cols); + + //1st moment vector + arma::mat mean = arma::zeros(iterate.n_rows, + iterate.n_cols); + + //2nd moment vector + arma::mat variance = arma::zeros(iterate.n_rows, + iterate.n_cols); + + for (size_t i = 1; i != maxIterations; ++i, ++currentFunction) + { + // Is this iteration the start of a sequence? + if ((currentFunction % numFunctions) == 0) + { + // Output current objective function. + Log::Info << "Adam: iteration " << i << ", objective " + << overallObjective << "." << std::endl; + + if (std::isnan(overallObjective) || std::isinf(overallObjective)) + { + Log::Warn << "Adam: converged to " << overallObjective + << "; terminating with failure. Try a smaller step size?" + << std::endl; + return overallObjective; + } + + if (std::abs(lastObjective - overallObjective) < tolerance) + { + Log::Info << "Adam: minimized within tolerance " << tolerance << "; " + << "terminating optimization." << std::endl; + return overallObjective; + } + + // Reset the counter variables. + lastObjective = overallObjective; + overallObjective = 0; + currentFunction = 0; + + if (shuffle) // Determine order of visitation. + visitationOrder = arma::shuffle(visitationOrder); + } + + // Evaluate the gradient for this iteration. + if (shuffle) + function.Gradient(iterate, visitationOrder[currentFunction], gradient); + else + function.Gradient(iterate, currentFunction, gradient); + + // And update the iterate. + // Accumulate updates. + mean += (1 - beta1) * (gradient - mean); + variance += (1 - beta2) * (gradient % gradient - variance); + + // Apply update. 
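+    // For reference, the bias-corrected form of this step from the Adam
+    // paper (Kingma & Ba, 2015) would look roughly like the following
+    // sketch (names here are illustrative, not part of this patch):
+    //   const double biasCorrection1 = 1.0 - std::pow(beta1, (double) i);
+    //   const double biasCorrection2 = 1.0 - std::pow(beta2, (double) i);
+    //   iterate -= stepSize * (mean / biasCorrection1) /
+    //       (arma::sqrt(variance / biasCorrection2) + eps);
+    // The update below applies the raw (uncorrected) moment estimates.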
+ iterate -= stepSize * mean / (arma::sqrt(variance) + eps); + + // Now add that to the overall objective function. + if (shuffle) + overallObjective += function.Evaluate(iterate, + visitationOrder[currentFunction]); + else + overallObjective += function.Evaluate(iterate, currentFunction); + } + + Log::Info << "Adam: maximum iterations (" << maxIterations << ") reached; " + << "terminating optimization." << std::endl; + // Calculate final objective. + overallObjective = 0; + for (size_t i = 0; i < numFunctions; ++i) + overallObjective += function.Evaluate(iterate, i); + return overallObjective; +} + +} // namespace optimization +} // namespace mlpack + +#endif From 7f54cee830bb1b22f8919033c423d8df32cfe4af Mon Sep 17 00:00:00 2001 From: vasanth kalingeri Date: Mon, 14 Mar 2016 12:26:30 +0530 Subject: [PATCH 2/3] Reimplementing Adadelta with tests --- .../core/optimizers/adadelta/CMakeLists.txt | 11 ++ .../core/optimizers/adadelta/ada_delta.hpp | 144 ++++++++++++++ .../optimizers/adadelta/ada_delta_impl.hpp | 129 +++++++++++++ src/mlpack/tests/ada_delta_test.cpp | 177 +++++++++++------- 4 files changed, 396 insertions(+), 65 deletions(-) create mode 100644 src/mlpack/core/optimizers/adadelta/CMakeLists.txt create mode 100644 src/mlpack/core/optimizers/adadelta/ada_delta.hpp create mode 100644 src/mlpack/core/optimizers/adadelta/ada_delta_impl.hpp diff --git a/src/mlpack/core/optimizers/adadelta/CMakeLists.txt b/src/mlpack/core/optimizers/adadelta/CMakeLists.txt new file mode 100644 index 00000000000..3cd516b87a7 --- /dev/null +++ b/src/mlpack/core/optimizers/adadelta/CMakeLists.txt @@ -0,0 +1,11 @@ +set(SOURCES + ada_delta.hpp + ada_delta_impl.hpp +) + +set(DIR_SRCS) +foreach(file ${SOURCES}) + set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file}) +endforeach() + +set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE) diff --git a/src/mlpack/core/optimizers/adadelta/ada_delta.hpp b/src/mlpack/core/optimizers/adadelta/ada_delta.hpp new file mode 100644 index 00000000000..dbe58868dd8 --- /dev/null +++ b/src/mlpack/core/optimizers/adadelta/ada_delta.hpp @@ -0,0 +1,144 @@ +#ifndef __MLPACK_CORE_OPTIMIZERS_ADADELTA_ADA_DELTA_HPP +#define __MLPACK_CORE_OPTIMIZERS_ADADELTA_ADA_DELTA_HPP + +#include + +namespace mlpack { +namespace optimization { + +/** + * Adadelta is an optimizer that uses two ideas to improve upon the two main + * drawbacks of the Adagrad method: + * + * - Accumulate Over Window + * - Correct Units with Hessian Approximation + * + * For more information, see the following. + * + * @code + * @article{Zeiler2012, + * author = {Matthew D. Zeiler}, + * title = {{ADADELTA:} An Adaptive Learning Rate Method}, + * journal = {CoRR}, + * year = {2012} + * } + * @endcode + * + + * For AdaDelta to work, a DecomposableFunctionType template parameter is + * required. This class must implement the following function: + * + * size_t NumFunctions(); + * double Evaluate(const arma::mat& coordinates, const size_t i); + * void Gradient(const arma::mat& coordinates, + * const size_t i, + * arma::mat& gradient); + * + * NumFunctions() should return the number of functions (\f$n\f$), and in the + * other two functions, the parameter i refers to which individual function (or + * gradient) is being evaluated. 
So, for the case of a data-dependent function, + * such as NCA (see mlpack::nca::NCA), NumFunctions() should return the number + * of points in the dataset, and Evaluate(coordinates, 0) will evaluate the + * objective function on the first point in the dataset (presumably, the dataset + * is held internally in the DecomposableFunctionType). + * + * @tparam DecomposableFunctionType Decomposable objective function type to be + * minimized. + */ +template +class AdaDelta +{ + public: + /** + * Construct the AdaDelta optimizer with the given function and parameters. The + * defaults here are not necessarily good for the given problem, so it is + * suggested that the values used be tailored to the task at hand. The + * maximum number of iterations refers to the maximum number of points that + * are processed (i.e., one iteration equals one point; one iteration does not + * equal one pass over the dataset). + * + * @param function Function to be optimized (minimized). + * @param rho Smoothing constant + * @param eps Value used to initialise the mean squared gradient parameter. + * @param maxIterations Maximum number of iterations allowed (0 means no + * limit). + * @param tolerance Maximum absolute tolerance to terminate algorithm. + * @param shuffle If true, the function order is shuffled; otherwise, each + * function is visited in linear order. + */ + AdaDelta(DecomposableFunctionType& function, + const double rho = 0.95, + const double eps = 1e-6, + const size_t maxIterations = 100000, + const double tolerance = 1e-5, + const bool shuffle = true); + + /** + * Optimize the given function using AdaDelta. The given starting point will be + * modified to store the finishing point of the algorithm, and the final + * objective value is returned. + * + * @param iterate Starting point (will be modified). + * @return Objective value of the final point. + */ + double Optimize(arma::mat& iterate); + + //! Get the instantiated function to be optimized. + const DecomposableFunctionType& Function() const { return function; } + //! Modify the instantiated function. + DecomposableFunctionType& Function() { return function; } + + //! Get the smoothing parameter. + double Rho() const { return rho; } + //! Modify the smoothing parameter. + double& Rho() { return rho; } + + //! Get the value used to initialise the mean squared gradient parameter. + double Epsilon() const { return eps; } + //! Modify the value used to initialise the mean squared gradient parameter. + double& Epsilon() { return eps; } + + //! Get the maximum number of iterations (0 indicates no limit). + size_t MaxIterations() const { return maxIterations; } + //! Modify the maximum number of iterations (0 indicates no limit). + size_t& MaxIterations() { return maxIterations; } + + //! Get the tolerance for termination. + double Tolerance() const { return tolerance; } + //! Modify the tolerance for termination. + double& Tolerance() { return tolerance; } + + //! Get whether or not the individual functions are shuffled. + bool Shuffle() const { return shuffle; } + //! Modify whether or not the individual functions are shuffled. + bool& Shuffle() { return shuffle; } + + private: + //! The instantiated function. + DecomposableFunctionType& function; + + //! The smoothing parameter. + double rho; + + //! The value used to initialise the mean squared gradient parameter. + double eps; + + //! The maximum number of allowed iterations. + size_t maxIterations; + + //! The tolerance for termination. + double tolerance; + + //! 
Controls whether or not the individual functions are shuffled when + //! iterating. + bool shuffle; +}; + +} // namespace optimization +} // namespace mlpack + +// Include implementation. +#include "ada_delta_impl.hpp" + +#endif + diff --git a/src/mlpack/core/optimizers/adadelta/ada_delta_impl.hpp b/src/mlpack/core/optimizers/adadelta/ada_delta_impl.hpp new file mode 100644 index 00000000000..ac08a62047e --- /dev/null +++ b/src/mlpack/core/optimizers/adadelta/ada_delta_impl.hpp @@ -0,0 +1,129 @@ +#ifndef __MLPACK_CORE_OPTIMIZERS_ADADELTA_ADA_DELTA_IMPL_HPP +#define __MLPACK_CORE_OPTIMIZERS_ADADELTA_ADA_DELTA_IMPL_HPP + +#include "ada_delta.hpp" + +namespace mlpack { +namespace optimization { + +template +AdaDelta::AdaDelta(DecomposableFunctionType& function, + const double rho, + const double eps, + const size_t maxIterations, + const double tolerance, + const bool shuffle) : + function(function), + rho(rho), + eps(eps), + maxIterations(maxIterations), + tolerance(tolerance), + shuffle(shuffle) +{ /* Nothing to do. */ } + +//! Optimize the function (minimize). +template +double AdaDelta::Optimize(arma::mat& iterate) +{ + // Find the number of functions to use. + const size_t numFunctions = function.NumFunctions(); + + // This is used only if shuffle is true. + arma::Col visitationOrder; + if (shuffle) + visitationOrder = arma::shuffle(arma::linspace>(0, + (numFunctions - 1), numFunctions)); + + // To keep track of where we are and how things are going. + size_t currentFunction = 0; + double overallObjective = 0; + double lastObjective = DBL_MAX; + + // Calculate the first objective function. + for (size_t i = 0; i < numFunctions; ++i) + overallObjective += function.Evaluate(iterate, i); + + // Now iterate! + arma::mat gradient(iterate.n_rows, iterate.n_cols); + + // Leaky sum of squares of parameter gradient. + arma::mat meanSquaredGradient = arma::zeros(iterate.n_rows, + iterate.n_cols); + + // Leaky sum of squares of parameter gradient. + arma::mat meanSquaredGradientDx = arma::zeros(iterate.n_rows, + iterate.n_cols); + + for (size_t i = 1; i != maxIterations; ++i, ++currentFunction) + { + // Is this iteration the start of a sequence? + if ((currentFunction % numFunctions) == 0) + { + // Output current objective function. + Log::Info << "AdaDelta: iteration " << i << ", objective " + << overallObjective << "." << std::endl; + + if (std::isnan(overallObjective) || std::isinf(overallObjective)) + { + Log::Warn << "AdaDelta: converged to " << overallObjective + << "; terminating with failure. Try a smaller step size?" + << std::endl; + return overallObjective; + } + + if (std::abs(lastObjective - overallObjective) < tolerance) + { + Log::Info << "AdaDelta: minimized within tolerance " << tolerance << "; " + << "terminating optimization." << std::endl; + return overallObjective; + } + + // Reset the counter variables. + lastObjective = overallObjective; + overallObjective = 0; + currentFunction = 0; + + if (shuffle) // Determine order of visitation. + visitationOrder = arma::shuffle(visitationOrder); + } + + // Evaluate the gradient for this iteration. + if (shuffle) + function.Gradient(iterate, visitationOrder[currentFunction], gradient); + else + function.Gradient(iterate, currentFunction, gradient); + + // Accumulate gradient. + meanSquaredGradient *= rho; + meanSquaredGradient += (1 - rho) * (gradient % gradient); + arma::mat dx = arma::sqrt((meanSquaredGradientDx + eps) / + (meanSquaredGradient + eps)) % gradient; + + // Accumulate updates. 
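+    // meanSquaredGradientDx holds the leaky (exponentially decaying) average
+    // of the squared parameter updates (E[dx^2] in Zeiler, 2012), while
+    // meanSquaredGradient above holds E[g^2]; their ratio under the square
+    // root is the RMS[dx]_{t-1} / RMS[g]_t factor that scales the gradient
+    // in the dx expression computed above.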
+ meanSquaredGradientDx *= rho; + meanSquaredGradientDx += (1 - rho) * (dx % dx); + + // Apply update. + iterate -= dx; + + // Now add that to the overall objective function. + if (shuffle) + overallObjective += function.Evaluate(iterate, + visitationOrder[currentFunction]); + else + overallObjective += function.Evaluate(iterate, currentFunction); + } + + Log::Info << "AdaDelta: maximum iterations (" << maxIterations << ") reached; " + << "terminating optimization." << std::endl; + // Calculate final objective. + overallObjective = 0; + for (size_t i = 0; i < numFunctions; ++i) + overallObjective += function.Evaluate(iterate, i); + return overallObjective; +} + +} // namespace optimization +} // namespace mlpack + +#endif diff --git a/src/mlpack/tests/ada_delta_test.cpp b/src/mlpack/tests/ada_delta_test.cpp index 961fede49ff..01fdd63b215 100644 --- a/src/mlpack/tests/ada_delta_test.cpp +++ b/src/mlpack/tests/ada_delta_test.cpp @@ -2,28 +2,34 @@ * @file ada_delta_test.cpp * @author Marcus Edel * - * Tests the AdaDelta optimizer on a couple test models. + * Tests the AdaDelta optimizer */ #include -#include +#include +#include -#include +#include +#include +#include +#include +#include #include #include #include -#include - -#include -#include -#include -#include #include #include "old_boost_test_definitions.hpp" +using namespace arma; using namespace mlpack; +using namespace mlpack::optimization; +using namespace mlpack::optimization::test; + +using namespace mlpack::distribution; +using namespace mlpack::regression; + using namespace mlpack::ann; BOOST_AUTO_TEST_SUITE(AdaDeltaTest); @@ -33,81 +39,122 @@ BOOST_AUTO_TEST_SUITE(AdaDeltaTest); * iris data, the data set contains 3 classes. One class is linearly separable * from the other 2. The other two aren't linearly separable from each other. */ + BOOST_AUTO_TEST_CASE(SimpleAdaDeltaTestFunction) { - const size_t hiddenLayerSize = 10; - const size_t maxEpochs = 300; + SGDTestFunction f; + AdaDelta optimizer(f, 0.99, 1e-8, 5000000, 1e-9, true); + + arma::mat coordinates = f.GetInitialPoint(); + const double result = optimizer.Optimize(coordinates); - // Load the dataset. - arma::mat dataset, labels, labelsIdx; - data::Load("iris_train.csv", dataset, true); - data::Load("iris_train_labels.csv", labelsIdx, true); + BOOST_REQUIRE_LE(std::abs(result) - 1.0, 0.2); + BOOST_REQUIRE_SMALL(coordinates[0], 1e-3); + BOOST_REQUIRE_SMALL(coordinates[1], 1e-3); + BOOST_REQUIRE_SMALL(coordinates[2], 1e-3); +} - // Create target matrix. - labels = arma::zeros(labelsIdx.max() + 1, labelsIdx.n_cols); - for (size_t i = 0; i < labelsIdx.n_cols; i++) - labels(labelsIdx(0, i), i) = 1; +/** + * Run AdaDelta on logistic regression and make sure the results are acceptable. + */ +BOOST_AUTO_TEST_CASE(LogisticRegressionTest) +{ + // Generate a two-Gaussian dataset. + GaussianDistribution g1(arma::vec("1.0 1.0 1.0"), arma::eye(3, 3)); + GaussianDistribution g2(arma::vec("9.0 9.0 9.0"), arma::eye(3, 3)); + + arma::mat data(3, 1000); + arma::Row responses(1000); + for (size_t i = 0; i < 500; ++i) + { + data.col(i) = g1.Random(); + responses[i] = 0; + } + for (size_t i = 500; i < 1000; ++i) + { + data.col(i) = g2.Random(); + responses[i] = 1; + } - // Construct a feed forward network using the specified parameters. - RandomInitialization randInit(0.1, 0.1); + // Shuffle the dataset. 
+ arma::uvec indices = arma::shuffle(arma::linspace(0, + data.n_cols - 1, data.n_cols)); + arma::mat shuffledData(3, 1000); + arma::Row shuffledResponses(1000); + for (size_t i = 0; i < data.n_cols; ++i) + { + shuffledData.col(i) = data.col(indices[i]); + shuffledResponses[i] = responses[indices[i]]; + } - LinearLayer inputLayer(dataset.n_rows, - hiddenLayerSize, randInit); - BiasLayer inputBiasLayer(hiddenLayerSize, - 1, randInit); - BaseLayer inputBaseLayer; + // Create a test set. + arma::mat testData(3, 1000); + arma::Row testResponses(1000); + for (size_t i = 0; i < 500; ++i) + { + testData.col(i) = g1.Random(); + testResponses[i] = 0; + } + for (size_t i = 500; i < 1000; ++i) + { + testData.col(i) = g2.Random(); + testResponses[i] = 1; + } - LinearLayer hiddenLayer1(hiddenLayerSize, - labels.n_rows, randInit); - BiasLayer hiddenBiasLayer1(labels.n_rows, - 1, randInit); - BaseLayer outputLayer; + LogisticRegression<> lr(shuffledData.n_rows, 0.5); - OneHotLayer classOutputLayer; + LogisticRegressionFunction<> lrf(shuffledData, shuffledResponses, 0.5); + AdaDelta > AdaDelta(lrf); + lr.Train(AdaDelta); - auto modules = std::tie(inputLayer, inputBiasLayer, inputBaseLayer, - hiddenLayer1, hiddenBiasLayer1, outputLayer); + // Ensure that the error is close to zero. + const double acc = lr.ComputeAccuracy(data, responses); + BOOST_REQUIRE_CLOSE(acc, 100.0, 0.3); // 0.3% error tolerance. - FFN - net(modules, classOutputLayer); + const double testAcc = lr.ComputeAccuracy(testData, testResponses); + BOOST_REQUIRE_CLOSE(testAcc, 100.0, 0.6); // 0.6% error tolerance. +} - arma::mat prediction; - size_t error = 0; +/** + * Run AdaDelta on a feedforward neural network and make sure the results are + * acceptable. + */ +BOOST_AUTO_TEST_CASE(FeedforwardTest) +{ + // Test on a non-linearly separable dataset (XOR). + arma::mat input, labels; + input << 0 << 1 << 1 << 0 << arma::endr + << 1 << 0 << 1 << 0 << arma::endr; + labels << 0 << 0 << 1 << 1; - // Evaluate the feed forward network. - for (size_t i = 0; i < dataset.n_cols; i++) - { - arma::mat input = dataset.unsafe_col(i); - net.Predict(input, prediction); + // Instantiate the first layer. + LinearLayer<> inputLayer(input.n_rows, 4); + BiasLayer<> biasLayer(4); + SigmoidLayer<> hiddenLayer0; - if (arma::sum(arma::sum(arma::abs( - prediction - labels.unsafe_col(i)))) == 0) - error++; - } + // Instantiate the second layer. + LinearLayer<> hiddenLayer1(4, labels.n_rows); + SigmoidLayer<> outputLayer; - // Check if the selected model isn't already optimized. - double classificationError = 1 - double(error) / dataset.n_cols; - BOOST_REQUIRE_GE(classificationError, 0.09); + // Instantiate the output layer. + BinaryClassificationLayer classOutputLayer; - // Train the feed forward network. - Trainer trainer(net, maxEpochs, 1, 0.01, false); - trainer.Train(dataset, labels, dataset, labels); + // Instantiate the feedforward network. + auto modules = std::tie(inputLayer, biasLayer, hiddenLayer0, hiddenLayer1, + outputLayer); + FFN net(modules, classOutputLayer); - // Evaluate the feed forward network. 
- error = 0; - for (size_t i = 0; i < dataset.n_cols; i++) - { - arma::mat input = dataset.unsafe_col(i); - net.Predict(input, prediction); + AdaDelta opt(net, 0.88, 1e-15, + 300 * input.n_cols, 1e-18); - if (arma::sum(arma::sum(arma::abs( - prediction - labels.unsafe_col(i)))) == 0) - error++; - } + net.Train(input, labels, opt); - classificationError = 1 - double(error) / dataset.n_cols; + arma::mat prediction; + net.Predict(input, prediction); - BOOST_REQUIRE_LE(classificationError, 0.09); + const bool b = arma::accu(prediction - labels) == 0; + BOOST_REQUIRE_EQUAL(b, true); } BOOST_AUTO_TEST_SUITE_END(); From ce0daad9a7c2f390b24f7c21f89351df4c3f8884 Mon Sep 17 00:00:00 2001 From: vasanth kalingeri Date: Mon, 14 Mar 2016 21:53:35 +0530 Subject: [PATCH 3/3] Fixed std::logic_error When initialized, the class returned this error error: Mat::init(): requested size is not compatible with column vector layout terminate called after throwing an instance of 'std::logic_error' what(): Mat::init(): requested size is not compatible with column vector layout Aborted (core dumped) --- .../ann/init_rules/kathirvalavakumar_subavathi_init.hpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp b/src/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp index df6b6e3ed0b..0486756073a 100644 --- a/src/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +++ b/src/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp @@ -24,8 +24,8 @@ #include #include - -#include "random_init.hpp" +#include +#include namespace mlpack { namespace ann /** Artificial Neural Network. */ { @@ -61,7 +61,7 @@ class KathirvalavakumarSubavathiInitialization KathirvalavakumarSubavathiInitialization(const arma::Mat& data, const double s) : s(s) { - dataSum = arma::sum(data + data); + dataSum = arma::sum(data); } /** @@ -77,7 +77,6 @@ class KathirvalavakumarSubavathiInitialization { arma::Row b = s * arma::sqrt(3 / (rows * dataSum)); const double theta = b.min(); - RandomInitialization randomInit(-theta, theta); randomInit.Initialize(W, rows, cols); } @@ -104,7 +103,7 @@ class KathirvalavakumarSubavathiInitialization private: //! Parameter that defines the sum of elements in each column. - arma::colvec dataSum; + arma::rowvec dataSum; //! Parameter that defines the active region. const double s;
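A note on the third patch: for a matrix argument, arma::sum(X) reduces along the
first dimension and returns a row vector of per-column sums, so the result can
only be stored in a column-vector type when X has a single column. That is why
dataSum is changed from arma::colvec to arma::rowvec (the patch also drops the
doubled arma::sum(data + data) in favour of arma::sum(data)). A minimal
standalone snippet, illustrative only and not part of the patch, that shows the
layout issue:

    #include <armadillo>

    int main()
    {
      arma::mat data(3, 5, arma::fill::randu);

      // arma::sum(data) is 1 x 5: one sum per column of 'data'.
      arma::rowvec ok = arma::sum(data);     // Row-vector layout matches; fine.

      // arma::colvec bad = arma::sum(data); // Would throw std::logic_error:
                                             // "Mat::init(): requested size is
                                             //  not compatible with column
                                             //  vector layout".
      return 0;
    }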