diff --git a/src/mlpack/methods/ann/cnn.hpp b/src/mlpack/methods/ann/cnn.hpp index 51215e7fe65..f9359a9c136 100644 --- a/src/mlpack/methods/ann/cnn.hpp +++ b/src/mlpack/methods/ann/cnn.hpp @@ -326,7 +326,7 @@ class CNN Update(T& t, P& /* unused */, D& delta) { t.Gradient(delta, t.Gradient()); - t.Optimizer().Update(); + t.UpdateOptimizer(); } template @@ -374,8 +374,8 @@ class CNN HasGradientCheck::value, void>::type Apply(T& t, P& /* unused */, D& /* unused */) { - t.Optimizer().Optimize(); - t.Optimizer().Reset(); + t.Optimize(); + t.ResetOptimizer(); } template @@ -401,7 +401,7 @@ class CNN LayerTypes network; //! The outputlayer used to evaluate the network - OutputLayerType& outputLayer; + OutputLayerType outputLayer; //! The class used to evaluate the performance of the network PerformanceFunction performanceFunction; diff --git a/src/mlpack/methods/ann/ffn.hpp b/src/mlpack/methods/ann/ffn.hpp index a6121087f75..f0f3e164516 100644 --- a/src/mlpack/methods/ann/ffn.hpp +++ b/src/mlpack/methods/ann/ffn.hpp @@ -344,7 +344,7 @@ class FFN Update(T& t, P& /* unused */, D& delta) { t.Gradient(delta, t.Gradient()); - t.Optimizer().Update(); + t.UpdateOptimizer(); } template @@ -392,8 +392,8 @@ class FFN HasGradientCheck::value, void>::type Apply(T& t, P& /* unused */, D& /* unused */) { - t.Optimizer().Optimize(); - t.Optimizer().Reset(); + t.Optimize(); + t.ResetOptimizer(); } template diff --git a/src/mlpack/methods/ann/layer/base_layer.hpp b/src/mlpack/methods/ann/layer/base_layer.hpp index 3a6b491e7a3..0cd6b2e64ac 100644 --- a/src/mlpack/methods/ann/layer/base_layer.hpp +++ b/src/mlpack/methods/ann/layer/base_layer.hpp @@ -51,19 +51,7 @@ class BaseLayer // Nothing to do here. } - BaseLayer(BaseLayer &&layer) noexcept - { - *this = std::move(layer); - } - BaseLayer& operator=(BaseLayer &&layer) noexcept - { - delta.swap(layer.delta); - inputParameter.swap(layer.inputParameter); - outputParameter.swap(layer.outputParameter); - - return *this; - } /** * Ordinary feed forward pass of a neural network, evaluating the function diff --git a/src/mlpack/methods/ann/layer/bias_layer.hpp b/src/mlpack/methods/ann/layer/bias_layer.hpp index 59061d92b93..f42ae86c8df 100644 --- a/src/mlpack/methods/ann/layer/bias_layer.hpp +++ b/src/mlpack/methods/ann/layer/bias_layer.hpp @@ -31,7 +31,7 @@ namespace ann /** Artificial Neural Network. */ { * arma::sp_mat or arma::cube). */ template < - template class OptimizerType = mlpack::ann::RMSPROP, + template< typename> class OptimizerType = mlpack::ann::RMSPROP, class WeightInitRule = NguyenWidrowInitialization, typename InputDataType = arma::mat, typename OutputDataType = arma::mat @@ -52,48 +52,11 @@ class BiasLayer const double bias = 1, WeightInitRule weightInitRule = WeightInitRule()) : outSize(outSize), - bias(bias), - optimizer(new OptimizerType, - InputDataType>(*this)), - ownsOptimizer(true) + bias(bias) { weightInitRule.Initialize(weights, outSize, 1); } - BiasLayer(BiasLayer &&layer) noexcept - { - *this = std::move(layer); - } - - BiasLayer& operator=(BiasLayer &&layer) noexcept - { - optimizer = layer.optimizer; - layer.optimizer = nullptr; - ownsOptimizer = layer.ownsOptimizer; - layer.ownsOptimizer = false; - - outSize = layer.outSize; - bias = layer.bias; - weights.swap(layer.weights); - delta.swap(layer.delta); - gradient.swap(layer.gradient); - inputParameter.swap(layer.inputParameter); - outputParameter.swap(layer.outputParameter); - - return *this; - } - - /** - * Delete the bias layer object and its optimizer. 
- */ - ~BiasLayer() - { - if (ownsOptimizer) - delete optimizer; - } /** * Ordinary feed forward pass of a neural network, evaluating the function @@ -169,22 +132,17 @@ class BiasLayer { g = d * bias; } - - //! Get the optimizer. - OptimizerType, InputDataType>& Optimizer() const + void UpdateOptimizer() + { + optimizer.Update(Gradient()); + } + void Optimize() { - return *optimizer; + optimizer.Optimize(Weights()); } - //! Modify the optimizer. - OptimizerType, InputDataType>& Optimizer() + void ResetOptimizer() { - return *optimizer; + optimizer.Reset(); } //! Get the weights. @@ -234,19 +192,14 @@ class BiasLayer //! Locally-stored output parameter object. OutputDataType outputParameter; - //! Locally-stored pointer to the optimzer object. - OptimizerType, InputDataType>* optimizer; + //! Locally-stored optimzer object. + OptimizerType optimizer; - //! Parameter that indicates if the class owns a optimizer object. - bool ownsOptimizer; }; // class BiasLayer //! Layer traits for the bias layer. template< - template class OptimizerType, + template< typename> class OptimizerType, typename WeightInitRule, typename InputDataType, typename OutputDataType @@ -266,7 +219,7 @@ class LayerTraits class OptimizerType = mlpack::ann::RMSPROP, + template< typename> class OptimizerType = mlpack::ann::RMSPROP, class WeightInitRule = NguyenWidrowInitialization, typename InputDataType = arma::mat, typename OutputDataType = arma::cube diff --git a/src/mlpack/methods/ann/layer/conv_layer.hpp b/src/mlpack/methods/ann/layer/conv_layer.hpp index 26386f066b9..55d720a7bef 100644 --- a/src/mlpack/methods/ann/layer/conv_layer.hpp +++ b/src/mlpack/methods/ann/layer/conv_layer.hpp @@ -32,7 +32,7 @@ namespace ann /** Artificial Neural Network. */ { * arma::sp_mat or arma::cube). */ template < - template class OptimizerType = mlpack::ann::RMSPROP, + template class OptimizerType = mlpack::ann::RMSPROP, class WeightInitRule = NguyenWidrowInitialization, typename ForwardConvolutionRule = NaiveConvolution, typename BackwardConvolutionRule = NaiveConvolution, @@ -74,57 +74,11 @@ class ConvLayer xStride(xStride), yStride(yStride), wPad(wPad), - hPad(hPad), - optimizer(new OptimizerType, - OutputDataType>(*this)), - ownsOptimizer(true) + hPad(hPad) { weightInitRule.Initialize(weights, wfilter, hfilter, inMaps * outMaps); } - ConvLayer(ConvLayer &&layer) noexcept - { - *this = std::move(layer); - } - - ConvLayer& operator=(ConvLayer &&layer) noexcept - { - optimizer = layer.optimizer; - ownsOptimizer = layer.ownsOptimizer; - layer.optimizer = nullptr; - layer.ownsOptimizer = false; - - wfilter = layer.wfilter; - hfilter = layer.hfilter; - inMaps = layer.inMaps; - outMaps = layer.outMaps; - xStride = layer.xStride; - yStride = layer.yStride; - wPad = layer.wPad; - hPad = layer.hPad; - weights.swap(layer.weights); - delta.swap(layer.delta); - gradient.swap(layer.gradient); - inputParameter.swap(layer.inputParameter); - outputParameter.swap(layer.outputParameter); - - return *this; - } - - /** - * Delete the convolution layer object and its optimizer. - */ - ~ConvLayer() - { - if (ownsOptimizer) - delete optimizer; - } /** * Ordinary feed forward pass of a neural network, evaluating the function @@ -215,27 +169,17 @@ class ConvLayer } } - //! Get the optimizer. - OptimizerType, OutputDataType>& Optimizer() const + void UpdateOptimizer() + { + optimizer.Update(Gradient()); + } + void Optimize() { - return *optimizer; + optimizer.Optimize(Weights()); } - //! Modify the optimizer. 
- OptimizerType, OutputDataType>& Optimizer() + void ResetOptimizer() { - return *optimizer; + optimizer.Reset(); } //! Get the weights. @@ -350,21 +294,12 @@ class ConvLayer OutputDataType outputParameter; //! Locally-stored pointer to the optimzer object. - OptimizerType, OutputDataType>* optimizer; - - //! Parameter that indicates if the class owns a optimizer object. - bool ownsOptimizer; + OptimizerType optimizer; }; // class ConvLayer //! Layer traits for the convolution layer. template< - template class OptimizerType, + template class OptimizerType, typename WeightInitRule, typename ForwardConvolutionRule, typename BackwardConvolutionRule, diff --git a/src/mlpack/methods/ann/layer/dropout_layer.hpp b/src/mlpack/methods/ann/layer/dropout_layer.hpp index 2b8b85a0756..f422492fa28 100644 --- a/src/mlpack/methods/ann/layer/dropout_layer.hpp +++ b/src/mlpack/methods/ann/layer/dropout_layer.hpp @@ -66,24 +66,8 @@ class DropoutLayer // Nothing to do here. } - DropoutLayer(DropoutLayer &&layer) noexcept - { - *this = std::move(layer); - } - DropoutLayer& operator=(DropoutLayer &&layer) noexcept - { - delta.swap(layer.delta); - inputParameter.swap(layer.inputParameter); - outputParameter.swap(outputParameter); - mask.swap(layer.mask); - ratio = layer.ratio; - scale = layer.scale; - deterministic = layer.deterministic; - rescale = layer.rescale; - return *this; - } /** * Ordinary feed forward pass of the dropout layer. diff --git a/src/mlpack/methods/ann/layer/linear_layer.hpp b/src/mlpack/methods/ann/layer/linear_layer.hpp index a5568626dbd..54647040a94 100644 --- a/src/mlpack/methods/ann/layer/linear_layer.hpp +++ b/src/mlpack/methods/ann/layer/linear_layer.hpp @@ -28,7 +28,7 @@ namespace ann /** Artificial Neural Network. */ { * arma::sp_mat or arma::cube). */ template < - template class OptimizerType = mlpack::ann::RMSPROP, + template class OptimizerType = mlpack::ann::RMSPROP, class WeightInitRule = NguyenWidrowInitialization, typename InputDataType = arma::mat, typename OutputDataType = arma::mat @@ -48,48 +48,13 @@ class LinearLayer const size_t outSize, WeightInitRule weightInitRule = WeightInitRule()) : inSize(inSize), - outSize(outSize), - optimizer(new OptimizerType, - OutputDataType>(*this)), - ownsOptimizer(true) + outSize(outSize) { weightInitRule.Initialize(weights, outSize, inSize); } - LinearLayer(LinearLayer &&layer) noexcept - { - *this = std::move(layer); - } - LinearLayer& operator=(LinearLayer &&layer) noexcept - { - ownsOptimizer = layer.ownsOptimizer; - layer.ownsOptimizer = false; - optimizer = layer.optimizer; - layer.optimizer = nullptr; - - inSize = layer.inSize; - outSize = layer.outSize; - weights.swap(layer.weights); - delta.swap(layer.delta); - gradient.swap(layer.gradient); - inputParameter.swap(layer.inputParameter); - outputParameter.swap(layer.outputParameter); - - return *this; - } - /** - * Delete the linear layer object and its optimizer. - */ - ~LinearLayer() - { - if (ownsOptimizer) - delete optimizer; - } /** * Ordinary feed forward pass of a neural network, evaluating the function @@ -157,21 +122,17 @@ class LinearLayer GradientDelta(inputParameter, d, g); } - //! Get the optimizer. - OptimizerType, OutputDataType>& Optimizer() const + void UpdateOptimizer() { - return *optimizer; - } - //! Modify the optimizer. - OptimizerType, OutputDataType>& Optimizer() + optimizer.Update(Gradient()); + } + void Optimize() { - return *optimizer; + optimizer.Optimize(Weights()); + } + void ResetOptimizer() + { + optimizer.Reset(); } //! Get the weights. 
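
[Editor's note on the hunks above] The removed `Optimizer()` accessors are replaced by three thin wrappers that delegate to an optimizer the layer now owns by value. The following is a minimal sketch of that pattern only; the class and member names other than `UpdateOptimizer`, `Optimize`, and `ResetOptimizer` are illustrative and not taken from the patch.

```cpp
// Sketch of the new layer-side optimizer interface, assuming an optimizer
// type that exposes Update(gradient), Optimize(weights), and Reset().
template<template<typename> class OptimizerType, typename OutputDataType>
class ExampleLayer
{
 public:
  //! Accumulate the current gradient into the owned optimizer.
  void UpdateOptimizer() { optimizer.Update(gradient); }

  //! Apply the accumulated update to the layer weights.
  void Optimize() { optimizer.Optimize(weights); }

  //! Clear the accumulated gradient before the next batch.
  void ResetOptimizer() { optimizer.Reset(); }

 private:
  OutputDataType weights;
  OutputDataType gradient;

  //! Owned by value: no new/delete, no ownsOptimizer flag, and the
  //! compiler-generated copy and move operations are sufficient.
  OptimizerType<OutputDataType> optimizer;
};
```

Because the optimizer no longer stores a reference back to the layer, the hand-written move constructors, move assignments, and destructors removed in these hunks are no longer needed; the defaulted special members behave correctly.
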
@@ -306,20 +267,17 @@ class LinearLayer OutputDataType outputParameter; //! Locally-stored pointer to the optimzer object. - OptimizerType, OutputDataType>* optimizer; + OptimizerType optimizer; //! Parameter that indicates if the class owns a optimizer object. - bool ownsOptimizer; + }; // class LinearLayer /** * Linear Mapping layer to map between 3rd order tensors and dense matrices. */ template < - template class OptimizerType = mlpack::ann::RMSPROP, + template< typename> class OptimizerType = mlpack::ann::RMSPROP, class WeightInitRule = NguyenWidrowInitialization, typename InputDataType = arma::cube, typename OutputDataType = arma::mat @@ -329,7 +287,7 @@ using LinearMappingLayer = LinearLayer< //! Layer traits for the linear layer. template< - template class OptimizerType, + template< typename> class OptimizerType, typename WeightInitRule, typename InputDataType, typename OutputDataType diff --git a/src/mlpack/methods/ann/layer/pooling_layer.hpp b/src/mlpack/methods/ann/layer/pooling_layer.hpp index c8eca2e56c4..f5b77774d0b 100644 --- a/src/mlpack/methods/ann/layer/pooling_layer.hpp +++ b/src/mlpack/methods/ann/layer/pooling_layer.hpp @@ -45,22 +45,6 @@ class PoolingLayer // Nothing to do here. } - PoolingLayer(PoolingLayer &&layer) noexcept - { - *this = std::move(layer); - } - - PoolingLayer& operator=(PoolingLayer &&layer) noexcept - { - kSize = layer.kSize; - delta.swap(layer.delta); - inputParameter.swap(layer.inputParameter); - outputParameter.swap(layer.outputParameter); - pooling = std::move(layer.pooling); - - return *this; - } - /** * Ordinary feed forward pass of a neural network, evaluating the function * f(x) by propagating the activity forward through f. diff --git a/src/mlpack/methods/ann/layer/softmax_layer.hpp b/src/mlpack/methods/ann/layer/softmax_layer.hpp index 860d57a9479..91f5944a302 100644 --- a/src/mlpack/methods/ann/layer/softmax_layer.hpp +++ b/src/mlpack/methods/ann/layer/softmax_layer.hpp @@ -36,20 +36,6 @@ class SoftmaxLayer // Nothing to do here. } - SoftmaxLayer(SoftmaxLayer &&layer) noexcept - { - *this = std::move(layer); - } - - SoftmaxLayer& operator=(SoftmaxLayer &&layer) noexcept - { - delta.swap(layer.delta); - inputParameter.swap(layer.inputParameter); - outputParameter.swap(layer.outputParameter); - - return *this; - } - /** * Ordinary feed forward pass of a neural network, evaluating the function * f(x) by propagating the activity forward through f. @@ -88,7 +74,7 @@ class SoftmaxLayer InputDataType& InputParameter() { return inputParameter; } //! Get the output parameter. - OutputDataType& OutputParameter() const {return outputParameter; } + const OutputDataType& OutputParameter() const {return outputParameter; } //! Modify the output parameter. OutputDataType& OutputParameter() { return outputParameter; } diff --git a/src/mlpack/methods/ann/layer/sparse_bias_layer.hpp b/src/mlpack/methods/ann/layer/sparse_bias_layer.hpp index e08d94a9f43..c9b5a314eec 100644 --- a/src/mlpack/methods/ann/layer/sparse_bias_layer.hpp +++ b/src/mlpack/methods/ann/layer/sparse_bias_layer.hpp @@ -27,7 +27,7 @@ namespace ann /** Artificial Neural Network. */ { * arma::sp_mat or arma::cube). 
*/ template < - template class OptimizerType = mlpack::ann::RMSPROP, + template class OptimizerType = mlpack::ann::RMSPROP, class WeightInitRule = ZeroInitialization, typename InputDataType = arma::mat, typename OutputDataType = arma::mat @@ -49,52 +49,12 @@ class SparseBiasLayer const size_t batchSize, WeightInitRule weightInitRule = WeightInitRule()) : outSize(outSize), - batchSize(batchSize), - optimizer(new OptimizerType, - InputDataType>(*this)), - ownsOptimizer(true) + batchSize(batchSize) { weightInitRule.Initialize(weights, outSize, 1); } - SparseBiasLayer(SparseBiasLayer &&layer) noexcept - { - *this = std::move(layer); - } - - SparseBiasLayer& operator=(SparseBiasLayer &&layer) noexcept - { - optimizer = new OptimizerType, - InputDataType>(*this); - ownsOptimizer = layer.ownsOptimizer; - layer.optimizer = nullptr; - layer.ownsOptimizer = false; - - outSize = layer.outSize; - batchSize = layer.batchSize; - weights.swap(layer.weights); - delta.swap(layer.delta); - gradient.swap(layer.gradient); - inputParameter.swap(layer.inputParameter); - outputParameter.swap(layer.outputParameter); - - return *this; - } - - /** - * Delete the bias layer object and its optimizer. - */ - ~SparseBiasLayer() - { - if (ownsOptimizer) - delete optimizer; - } + /** * Ordinary feed forward pass of a neural network, evaluating the function @@ -138,22 +98,6 @@ class SparseBiasLayer g = arma::sum(d, 1) / static_cast(batchSize); } - //! Get the optimizer. - OptimizerType, InputDataType>& Optimizer() const - { - return *optimizer; - } - //! Modify the optimizer. - OptimizerType, InputDataType>& Optimizer() - { - return *optimizer; - } //! Get the batch size size_t BatchSize() const { return batchSize; } @@ -208,18 +152,13 @@ class SparseBiasLayer OutputDataType outputParameter; //! Locally-stored pointer to the optimzer object. - OptimizerType, InputDataType>* optimizer; + OptimizerType< InputDataType> optimizer; - //! Parameter that indicates if the class owns a optimizer object. - bool ownsOptimizer; }; // class SparseBiasLayer //! Layer traits for the bias layer. template< - template class OptimizerType, + template class OptimizerType, typename WeightInitRule, typename InputDataType, typename OutputDataType diff --git a/src/mlpack/methods/ann/layer/sparse_input_layer.hpp b/src/mlpack/methods/ann/layer/sparse_input_layer.hpp index 4faf2d1fd55..2894ee65662 100644 --- a/src/mlpack/methods/ann/layer/sparse_input_layer.hpp +++ b/src/mlpack/methods/ann/layer/sparse_input_layer.hpp @@ -30,7 +30,7 @@ namespace ann /** Artificial Neural Network. */ { * arma::sp_mat or arma::cube). 
*/ template < - template class OptimizerType = mlpack::ann::RMSPROP, + template class OptimizerType = mlpack::ann::RMSPROP, class WeightInitRule = RandomInitialization, typename InputDataType = arma::mat, typename OutputDataType = arma::mat @@ -53,55 +53,14 @@ class SparseInputLayer const double lambda = 0.0001) : inSize(inSize), outSize(outSize), - lambda(lambda), - optimizer(new OptimizerType, - OutputDataType>(*this)), - ownsOptimizer(true) + lambda(lambda) { weightInitRule.Initialize(weights, outSize, inSize); } - SparseInputLayer(SparseInputLayer &&layer) noexcept - { - *this = std::move(layer); - } + - SparseInputLayer& operator=(SparseInputLayer &&layer) noexcept - { - ownsOptimizer = layer.ownsOptimizer; - layer.ownsOptimizer = false; - - optimizer = new OptimizerType, - OutputDataType>(*this); - - layer.optimizer = nullptr; - - inSize = layer.inSize; - outSize = layer.outSize; - lambda = layer.lambda; - weights.swap(layer.weights); - delta.swap(layer.delta); - gradient.swap(layer.gradient); - inputParameter.swap(layer.inputParameter); - outputParameter.swap(layer.outputParameter); - - return *this; - } - - /** - * Delete the linear layer object and its optimizer. - */ - ~SparseInputLayer() - { - if (ownsOptimizer) - delete optimizer; - } + /** * Ordinary feed forward pass of a neural network, evaluating the function @@ -147,22 +106,7 @@ class SparseInputLayer lambda * weights; } - //! Get the optimizer. - OptimizerType, OutputDataType>& Optimizer() const - { - return *optimizer; - } - //! Modify the optimizer. - OptimizerType, OutputDataType>& Optimizer() - { - return *optimizer; - } + //! Get the weights. OutputDataType const& Weights() const { return weights; } @@ -215,18 +159,13 @@ class SparseInputLayer OutputDataType outputParameter; //! Locally-stored pointer to the optimzer object. - OptimizerType, OutputDataType>* optimizer; + OptimizerType< OutputDataType> optimizer; - //! Parameter that indicates if the class owns a optimizer object. - bool ownsOptimizer; }; // class SparseInputLayer //! Layer traits for the SparseInputLayer. template< - template class OptimizerType, + template class OptimizerType, typename WeightInitRule, typename InputDataType, typename OutputDataType diff --git a/src/mlpack/methods/ann/layer/sparse_output_layer.hpp b/src/mlpack/methods/ann/layer/sparse_output_layer.hpp index a9f97f0c1b1..35121336774 100644 --- a/src/mlpack/methods/ann/layer/sparse_output_layer.hpp +++ b/src/mlpack/methods/ann/layer/sparse_output_layer.hpp @@ -27,7 +27,7 @@ namespace ann /** Artificial Neural Network. */ { * arma::sp_mat or arma::cube). 
*/ template < - template class OptimizerType = mlpack::ann::RMSPROP, + template class OptimizerType = mlpack::ann::RMSPROP, class WeightInitRule = RandomInitialization, typename InputDataType = arma::mat, typename OutputDataType = arma::mat @@ -53,53 +53,11 @@ class SparseOutputLayer outSize(outSize), lambda(lambda), beta(beta), - rho(rho), - optimizer(new OptimizerType, - OutputDataType>(*this)), - ownsOptimizer(true) + rho(rho) { weightInitRule.Initialize(weights, outSize, inSize); } - SparseOutputLayer(SparseOutputLayer &&layer) noexcept - { - *this = std::move(layer); - } - - SparseOutputLayer& operator=(SparseOutputLayer &&layer) noexcept - { - ownsOptimizer = layer.ownsOptimizer; - optimizer = layer.optimizer; - layer.ownsOptimizer = false; - layer.optimizer = nullptr; - - beta = layer.beta; - rho = layer.rho; - lambda = layer.lambda; - inSize = layer.inSize; - outSize = layer.outSize; - weights.swap(layer.weights); - delta.swap(layer.delta); - gradient.swap(layer.gradient); - inputParameter.swap(layer.inputParameter); - outputParameter.swap(layer.outputParameter); - rhoCap.swap(layer.rhoCap); - - return *this; - } - - /** - * Delete the linear layer object and its optimizer. - */ - ~SparseOutputLayer() - { - if (ownsOptimizer) - delete optimizer; - } - /** * Ordinary feed forward pass of a neural network, evaluating the function * f(x) by propagating the activity forward through f. @@ -152,23 +110,7 @@ class SparseOutputLayer lambda * weights; } - //! Get the optimizer. - OptimizerType, OutputDataType>& Optimizer() const - { - return *optimizer; - } - //! Modify the optimizer. - OptimizerType, OutputDataType>& Optimizer() - { - return *optimizer; - } - + //! Sets the KL divergence parameter. void Beta(const double b) { @@ -258,18 +200,14 @@ class SparseOutputLayer OutputDataType outputParameter; //! Locally-stored pointer to the optimzer object. - OptimizerType, OutputDataType>* optimizer; + OptimizerType< OutputDataType> optimizer; - //! Parameter that indicates if the class owns a optimizer object. - bool ownsOptimizer; + }; // class SparseOutputLayer //! Layer traits for the SparseOutputLayer. template< - template class OptimizerType, + template class OptimizerType, typename WeightInitRule, typename InputDataType, typename OutputDataType diff --git a/src/mlpack/methods/ann/optimizer/ada_delta.hpp b/src/mlpack/methods/ann/optimizer/ada_delta.hpp index 76dc19530fd..800586f6565 100644 --- a/src/mlpack/methods/ann/optimizer/ada_delta.hpp +++ b/src/mlpack/methods/ann/optimizer/ada_delta.hpp @@ -32,7 +32,7 @@ namespace ann /** Artificial Neural Network. */ { * } * @endcode */ -template +template< typename DataType> class AdaDelta { public: @@ -45,10 +45,9 @@ class AdaDelta * @param eps The eps coefficient to avoid division by zero (numerical * stability). */ - AdaDelta(DecomposableFunctionType& function, + AdaDelta( const double rho = 0.95, const double eps = 1e-6) : - function(function), rho(rho), eps(eps) { @@ -58,32 +57,32 @@ class AdaDelta /** * Optimize the given function using AdaDelta. */ - void Optimize() + void Optimize(DataType & weights) { if (meanSquaredGradient.n_elem == 0) { - meanSquaredGradient = function.Weights(); + meanSquaredGradient = weights; meanSquaredGradient.zeros(); meanSquaredGradientDx = meanSquaredGradient; } - Optimize(function.Weights(), gradient, meanSquaredGradient, + Optimize(weights, gradient, meanSquaredGradient, meanSquaredGradientDx); } /* * Sum up all gradients and store the results in the gradients storage. 
*/ - void Update() + void Update(DataType const& function_Gradient) { if (gradient.n_elem != 0) { - gradient += function.Gradient(); + gradient += function_Gradient; } else { - gradient = function.Gradient(); + gradient = function_Gradient; } } @@ -150,8 +149,6 @@ class AdaDelta weights -= dx; } - //! The instantiated function. - DecomposableFunctionType& function; //! The value used as interpolation parameter. const double rho; diff --git a/src/mlpack/methods/ann/optimizer/adam.hpp b/src/mlpack/methods/ann/optimizer/adam.hpp index 324020771c9..87dcc11beb7 100644 --- a/src/mlpack/methods/ann/optimizer/adam.hpp +++ b/src/mlpack/methods/ann/optimizer/adam.hpp @@ -30,7 +30,7 @@ namespace ann /** Artificial Neural Network. */ { * } * @endcode */ -template +template< typename DataType> class Adam { public: @@ -44,12 +44,11 @@ class Adam * @param eps The eps coefficient to avoid division by zero (numerical * stability). */ - Adam(DecomposableFunctionType& function, + Adam( const double lr = 0.001, const double beta1 = 0.9, const double beta2 = 0.999, const double eps = 1e-8) : - function(function), lr(lr), beta1(beta1), beta2(beta2), @@ -61,31 +60,31 @@ class Adam /** * Optimize the given function using Adam. */ - void Optimize() + void Optimize(DataType& weights) { if (mean.n_elem == 0) { - mean = function.Weights(); + mean = weights; mean.zeros(); variance = mean; } - Optimize(function.Weights(), gradient, mean, variance); + Optimize(weights, gradient, mean, variance); } /* * Sum up all gradients and store the results in the gradients storage. */ - void Update() + void Update(DataType const& function_gradients) { if (gradient.n_elem != 0) { - gradient += function.Gradient(); + gradient += function_gradients; } else { - gradient = function.Gradient(); + gradient = function_gradients; } } @@ -146,8 +145,6 @@ class Adam weights -= lr * mean / (arma::sqrt(variance) + eps); } - //! The instantiated function. - DecomposableFunctionType& function; //! The value used as learning rate. const double lr; diff --git a/src/mlpack/methods/ann/optimizer/rmsprop.hpp b/src/mlpack/methods/ann/optimizer/rmsprop.hpp index 6fcde445ebd..2fe49644f85 100644 --- a/src/mlpack/methods/ann/optimizer/rmsprop.hpp +++ b/src/mlpack/methods/ann/optimizer/rmsprop.hpp @@ -34,7 +34,7 @@ namespace ann /** Artificial Neural Network. */ { * } * @endcode */ -template +template class RMSPROP { public: @@ -47,11 +47,9 @@ class RMSPROP * @param eps The eps coefficient to avoid division by zero (numerical * stability). */ - RMSPROP(DecomposableFunctionType& function, - const double lr = 0.01, + RMSPROP(const double lr = 0.01, const double alpha = 0.99, const double eps = 1e-8) : - function(function), lr(lr), alpha(alpha), eps(eps) @@ -62,30 +60,30 @@ class RMSPROP /** * Optimize the given function using RmsProp. */ - void Optimize() + void Optimize(DataType& weights) { if (meanSquaredGad.n_elem == 0) { - meanSquaredGad = function.Weights(); + meanSquaredGad = weights; meanSquaredGad.zeros(); } - Optimize(function.Weights(), gradient, meanSquaredGad); + Optimize(weights, gradient, meanSquaredGad); } /* * Sum up all gradients and store the results in the gradients storage. 
*/ - void Update() + void Update(DataType const& function_gradients) { if (gradient.n_elem != 0) { - DataType outputGradient = function.Gradient(); + DataType outputGradient = function_gradients; // function.Gradient(); gradient += outputGradient; } else { - gradient = function.Gradient(); + gradient = function_gradients; } } @@ -138,8 +136,6 @@ class RMSPROP weights -= lr * gradient / (arma::sqrt(meanSquaredGradient) + eps); } - //! The instantiated function. - DecomposableFunctionType& function; //! The value used as learning rate. const double lr; diff --git a/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder.hpp b/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder.hpp index 316b9613489..a0036782c27 100644 --- a/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder.hpp +++ b/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder.hpp @@ -67,7 +67,7 @@ template< typename HiddenActivate = BaseLayer, typename OutputActivate = HiddenActivate, typename MatType = arma::mat, - template class Optimizer = RMSPROP, + template class Optimizer = RMSPROP, typename HiddenLayer = SparseInputLayer< Optimizer, RandomInitialization, MatType, MatType>, typename OutputLayer = SparseOutputLayer< @@ -296,7 +296,7 @@ class SparseAutoencoder */ template< typename MatType = arma::mat, - template class Optimizer = RMSPROP + template class Optimizer = RMSPROP > using LogisticSparseAutoencoder = SparseAutoencoder< BaseLayer, diff --git a/src/mlpack/tests/convolutional_network_test.cpp b/src/mlpack/tests/convolutional_network_test.cpp index bd20116e8b5..9a4a3bcc8e2 100644 --- a/src/mlpack/tests/convolutional_network_test.cpp +++ b/src/mlpack/tests/convolutional_network_test.cpp @@ -34,58 +34,117 @@ using namespace mlpack::ann; BOOST_AUTO_TEST_SUITE(ConvolutionalNetworkTest); /** - * Train and evaluate a vanilla network with the specified structure. - */ +* Train and evaluate a vanilla network with the specified structure. +*/ template< - typename PerformanceFunction + typename PerformanceFunction > void BuildVanillaNetwork() { - arma::mat X; - X.load("mnist_first250_training_4s_and_9s.arm"); - - // Normalize each point since these are images. - arma::uword nPoints = X.n_cols; - for (arma::uword i = 0; i < nPoints; i++) - { - X.col(i) /= norm(X.col(i), 2); - } - - // Build the target matrix. - arma::mat Y = arma::zeros(10, nPoints); - for (size_t i = 0; i < nPoints; i++) - { - if (i < nPoints / 2) - { - Y.col(i)(5) = 1; - } - else - { - Y.col(i)(8) = 1; - } - } - - arma::cube input = arma::cube(28, 28, nPoints); - for (size_t i = 0; i < nPoints; i++) - input.slice(i) = arma::mat(X.colptr(i), 28, 28); + arma::mat X; + X.load("mnist_first250_training_4s_and_9s.arm"); + + // Normalize each point since these are images. + arma::uword nPoints = X.n_cols; + for (arma::uword i = 0; i < nPoints; i++) + { + X.col(i) /= norm(X.col(i), 2); + } + + // Build the target matrix. + arma::mat Y = arma::zeros(10, nPoints); + for (size_t i = 0; i < nPoints; i++) + { + if (i < nPoints / 2) + { + Y.col(i)(5) = 1; + } + else + { + Y.col(i)(8) = 1; + } + } + + arma::cube input = arma::cube(28, 28, nPoints); + for (size_t i = 0; i < nPoints; i++) + input.slice(i) = arma::mat(X.colptr(i), 28, 28); + + /* + * Construct a convolutional neural network with a 28x28x1 input layer, + * 24x24x8 convolution layer, 12x12x8 pooling layer, 8x8x12 convolution layer + * and a 4x4x12 pooling layer which is fully connected with the output layer. 
+ * The network structure looks like: + * + * Input Convolution Pooling Convolution Pooling Output + * Layer Layer Layer Layer Layer Layer + * + * +---+ +---+ +---+ +---+ + * | +---+ | +---+ | +---+ | +---+ + * +---+ | | +---+ | | +---+ | | +---+ | | +---+ +---+ + * | | | | | | | | | | | | | | | | | | | | + * | +--> +-+ | +--> +-+ | +--> +-+ | +--> +-+ | +--> | | + * | | +-+ | +-+ | +-+ | +-+ | | | + * +---+ +---+ +---+ +---+ +---+ +---+ + */ + + ConvLayer convLayer0(1, 8, 5, 5); + BiasLayer2D biasLayer0(8); + BaseLayer2D baseLayer0; + PoolingLayer<> poolingLayer0(2); + + + + + ConvLayer convLayer1(8, 12, 5, 5); + BiasLayer2D biasLayer1(12); + BaseLayer2D baseLayer1; + PoolingLayer<> poolingLayer1(2); + + LinearMappingLayer linearLayer0(192, 10); + BiasLayer biasLayer2(10); + SoftmaxLayer<> softmaxLayer0; + + OneHotLayer outputLayer; + + auto modules = std::tie(convLayer0, biasLayer0, baseLayer0, poolingLayer0, + convLayer1, biasLayer1, baseLayer1, poolingLayer1, + linearLayer0, biasLayer2, softmaxLayer0); + + CNN + net(modules, outputLayer); + + Trainer trainer(net, 50, 1, 0.7); + trainer.Train(input, Y, input, Y); + + BOOST_REQUIRE_LE(trainer.ValidationError(), 0.7); +} +#if (__cplusplus >= 201402L) || (defined(_MSC_VER) && _MSC_VER >= 1900) +/** +* Train and evaluate a vanilla network with the specified structure. +*/ +template< + typename PerformanceFunction +> +auto GetVanillaNetwork() +{ /* - * Construct a convolutional neural network with a 28x28x1 input layer, - * 24x24x8 convolution layer, 12x12x8 pooling layer, 8x8x12 convolution layer - * and a 4x4x12 pooling layer which is fully connected with the output layer. - * The network structure looks like: - * - * Input Convolution Pooling Convolution Pooling Output - * Layer Layer Layer Layer Layer Layer - * - * +---+ +---+ +---+ +---+ - * | +---+ | +---+ | +---+ | +---+ - * +---+ | | +---+ | | +---+ | | +---+ | | +---+ +---+ - * | | | | | | | | | | | | | | | | | | | | - * | +--> +-+ | +--> +-+ | +--> +-+ | +--> +-+ | +--> | | - * | | +-+ | +-+ | +-+ | +-+ | | | - * +---+ +---+ +---+ +---+ +---+ +---+ - */ + * Construct a convolutional neural network with a 28x28x1 input layer, + * 24x24x8 convolution layer, 12x12x8 pooling layer, 8x8x12 convolution layer + * and a 4x4x12 pooling layer which is fully connected with the output layer. + * The network structure looks like: + * + * Input Convolution Pooling Convolution Pooling Output + * Layer Layer Layer Layer Layer Layer + * + * +---+ +---+ +---+ +---+ + * | +---+ | +---+ | +---+ | +---+ + * +---+ | | +---+ | | +---+ | | +---+ | | +---+ +---+ + * | | | | | | | | | | | | | | | | | | | | + * | +--> +-+ | +--> +-+ | +--> +-+ | +--> +-+ | +--> | | + * | | +-+ | +-+ | +-+ | +-+ | | | + * +---+ +---+ +---+ +---+ +---+ +---+ + */ ConvLayer convLayer0(1, 8, 5, 5); BiasLayer2D biasLayer0(8); @@ -106,19 +165,55 @@ void BuildVanillaNetwork() OneHotLayer outputLayer; - auto modules = std::tie(convLayer0, biasLayer0, baseLayer0, poolingLayer0, - convLayer1, biasLayer1, baseLayer1, poolingLayer1, - linearLayer0, biasLayer2, softmaxLayer0); + auto modules = std::make_tuple(convLayer0, biasLayer0, baseLayer0, poolingLayer0, + convLayer1, biasLayer1, baseLayer1, poolingLayer1, + linearLayer0, biasLayer2, softmaxLayer0); CNN - net(modules, outputLayer); + net(modules, outputLayer); + return net; - Trainer trainer(net, 50, 1, 0.7); +} +/** +* Train the vanilla network on a larger dataset. 
+*/ +BOOST_AUTO_TEST_CASE(GetVanillaNetworkTest) +{ + arma::mat X; + X.load("mnist_first250_training_4s_and_9s.arm"); + + // Normalize each point since these are images. + arma::uword nPoints = X.n_cols; + for (arma::uword i = 0; i < nPoints; i++) + { + X.col(i) /= norm(X.col(i), 2); + } + // Build the target matrix. + arma::mat Y = arma::zeros(10, nPoints); + for (size_t i = 0; i < nPoints; i++) + { + if (i < nPoints / 2) + { + Y.col(i)(5) = 1; + } + else + { + Y.col(i)(8) = 1; + } + } + + arma::cube input = arma::cube(28, 28, nPoints); + for (size_t i = 0; i < nPoints; i++) + input.slice(i) = arma::mat(X.colptr(i), 28, 28); + BOOST_TEST_MESSAGE("OK here 1"); + + auto net = GetVanillaNetwork(); + Trainer trainer(net,50,1,0.7); trainer.Train(input, Y, input, Y); BOOST_REQUIRE_LE(trainer.ValidationError(), 0.7); } - +#endif /** * Train the vanilla network on a larger dataset. */ diff --git a/src/mlpack/tests/feedforward_network_test.cpp b/src/mlpack/tests/feedforward_network_test.cpp index 4c54e77e359..662873e1f68 100644 --- a/src/mlpack/tests/feedforward_network_test.cpp +++ b/src/mlpack/tests/feedforward_network_test.cpp @@ -30,7 +30,7 @@ using namespace mlpack::ann; BOOST_AUTO_TEST_SUITE(FeedForwardNetworkTest); - +#if (__cplusplus >= 201402L) || (defined(_MSC_VER) && _MSC_VER >= 1900) /** * Train and evaluate a vanilla network with the specified structure. */ @@ -40,6 +40,108 @@ template< typename PerformanceFunctionType, typename MatType = arma::mat > +auto GetVanillaNetwork( + const size_t inputDataSize, + const size_t hiddenLayerSize, + const size_t outputLabelSize +) + { + /* + * Construct a feed forward network with trainData.n_rows input nodes, + * hiddenLayerSize hidden nodes and trainLabels.n_rows output nodes. The + * network structure looks like: + * + * Input Hidden Output + * Layer Layer Layer + * +-----+ +-----+ +-----+ + * | | | | | | + * | +------>| +------>| | + * | | +>| | +>| | + * +-----+ | +--+--+ | +-----+ + * | | + * Bias | Bias | + * Layer | Layer | + * +-----+ | +-----+ | + * | | | | | | + * | +-----+ | +-----+ + * | | | | + * +-----+ +-----+ + */ + + LinearLayer<> inputLayer(inputDataSize, hiddenLayerSize); + BiasLayer<> inputBiasLayer(hiddenLayerSize); + BaseLayer inputBaseLayer; + + LinearLayer<> hiddenLayer1(hiddenLayerSize, outputLabelSize); + BiasLayer<> hiddenBiasLayer1(outputLabelSize); + BaseLayer outputLayer; + + OutputLayerType classOutputLayer; + + auto modules = std::make_tuple(inputLayer, inputBiasLayer, inputBaseLayer, + hiddenLayer1, hiddenBiasLayer1, outputLayer); + + return FFN + (modules, classOutputLayer); + + + } + +/** +* Train and evaluate a vanilla network with the specified structure. 
+*/ +template< + typename PerformanceFunction, + typename OutputLayerType, + typename PerformanceFunctionType, + typename MatType = arma::mat +> +void CopyCtorTest(MatType& trainData, + MatType& trainLabels, + MatType& testData, + MatType& testLabels, + const size_t hiddenLayerSize, + const size_t maxEpochs, + const double classificationErrorThreshold, + const double ValidationErrorThreshold) + { + + auto net = GetVanillaNetwork(trainData.n_rows, hiddenLayerSize, trainLabels.n_rows); + + Trainer trainer(net, maxEpochs, 1, 0.01); + trainer.Train(trainData, trainLabels, testData, testLabels); + + MatType prediction; + size_t error = 0; + + for (size_t i = 0; i < testData.n_cols; i++) + { + MatType predictionInput = testData.unsafe_col(i); + MatType targetOutput = testLabels.unsafe_col(i); + + net.Predict(predictionInput, prediction); + + if (arma::sum(arma::sum(arma::abs(prediction - targetOutput))) == 0) + error++; + } + + double classificationError = 1 - double(error) / testData.n_cols; + + BOOST_REQUIRE_LE(classificationError, classificationErrorThreshold); + BOOST_REQUIRE_LE(trainer.ValidationError(), ValidationErrorThreshold); + } +#endif +/** +* Train and evaluate a vanilla network with the specified structure. +*/ +template< + typename PerformanceFunction, + typename OutputLayerType, + typename PerformanceFunctionType, + typename MatType = arma::mat +> void BuildVanillaNetwork(MatType& trainData, MatType& trainLabels, MatType& testData, @@ -110,10 +212,11 @@ void BuildVanillaNetwork(MatType& trainData, BOOST_REQUIRE_LE(trainer.ValidationError(), ValidationErrorThreshold); } + /** * Train the vanilla network on a larger dataset. */ -BOOST_AUTO_TEST_CASE(VanillaNetworkTest) +BOOST_AUTO_TEST_CASE(VanillaNetworkThyRoidTest) { // Load the dataset. arma::mat dataset; @@ -139,6 +242,14 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest) MeanSquaredErrorFunction> (trainData, trainLabels, testData, testLabels, 4, 500, 0.1, 60); + } + +BOOST_AUTO_TEST_CASE(VanillaNetworkMnistTest) + { + // Load the dataset. + arma::mat dataset; + + dataset.load("mnist_first250_training_4s_and_9s.arm"); // Normalize each point since these are images. @@ -161,6 +272,68 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest) (dataset, labels, dataset, labels, 10, 200, 0.6, 20); } +#if (__cplusplus >= 201402L) || (defined(_MSC_VER) && _MSC_VER >= 1900) + +/** +* Train the vanilla network on a larger dataset. +*/ +BOOST_AUTO_TEST_CASE(CopyCtorThyRoidTest) +{ + // Load the dataset. + arma::mat dataset; + data::Load("thyroid_train.csv", dataset, true); + + arma::mat trainData = dataset.submat(0, 0, dataset.n_rows - 4, + dataset.n_cols - 1); + arma::mat trainLabels = dataset.submat(dataset.n_rows - 3, 0, + dataset.n_rows - 1, dataset.n_cols - 1); + + data::Load("thyroid_test.csv", dataset, true); + + arma::mat testData = dataset.submat(0, 0, dataset.n_rows - 4, + dataset.n_cols - 1); + arma::mat testLabels = dataset.submat(dataset.n_rows - 3, 0, + dataset.n_rows - 1, dataset.n_cols - 1); + + // Vanilla neural net with logistic activation function. + // Because 92 percent of the patients are not hyperthyroid the neural + // network must be significant better than 92%. + CopyCtorTest + (trainData, trainLabels, testData, testLabels, 4, 500, 0.1, 60); + +} +#endif +#if (__cplusplus >= 201402L) || (defined(_MSC_VER) && _MSC_VER >= 1900) +BOOST_AUTO_TEST_CASE(CopyCtorMnistTest) +{ + // Load the dataset. + arma::mat dataset; + + + dataset.load("mnist_first250_training_4s_and_9s.arm"); + + // Normalize each point since these are images. 
+ for (size_t i = 0; i < dataset.n_cols; ++i) + dataset.col(i) /= norm(dataset.col(i), 2); + + arma::mat labels = arma::zeros(1, dataset.n_cols); + labels.submat(0, labels.n_cols / 2, 0, labels.n_cols - 1).fill(1); + + // Vanilla neural net with logistic activation function. + CopyCtorTest + (dataset, labels, dataset, labels, 30, 100, 0.6, 10); + + // Vanilla neural net with tanh activation function. + CopyCtorTest + (dataset, labels, dataset, labels, 10, 200, 0.6, 20); +} +#endif /** * Train and evaluate a Dropout network with the specified structure. */ @@ -238,10 +411,57 @@ void BuildDropoutNetwork(MatType& trainData, BOOST_REQUIRE_LE(trainer.ValidationError(), ValidationErrorThreshold); } +#if (__cplusplus >= 201402L) || (defined(_MSC_VER) && _MSC_VER >= 1900) + +/** +* Train and evaluate a Dropout network with the specified structure. +*/ +template< + typename PerformanceFunction, + typename OutputLayerType, + typename PerformanceFunctionType, + typename MatType = arma::mat +> +void CopyCtorDropoutTest(MatType& trainData, + MatType& trainLabels, + MatType& testData, + MatType& testLabels, + const size_t hiddenLayerSize, + const size_t maxEpochs, + const double classificationErrorThreshold, + const double ValidationErrorThreshold) +{ + + auto net = GetVanillaNetwork(trainData.n_rows, hiddenLayerSize, trainLabels.n_rows); + + Trainer trainer(net, maxEpochs, 1, 0.001); + trainer.Train(trainData, trainLabels, testData, testLabels); + + MatType prediction; + size_t error = 0; + + for (size_t i = 0; i < testData.n_cols; i++) + { + MatType input = testData.unsafe_col(i); + net.Predict(input, prediction); + if (arma::sum(arma::sum(arma::abs( + prediction - testLabels.unsafe_col(i)))) == 0) + error++; + } + + double classificationError = 1 - double(error) / testData.n_cols; + + BOOST_REQUIRE_LE(classificationError, classificationErrorThreshold); + BOOST_REQUIRE_LE(trainer.ValidationError(), ValidationErrorThreshold); +} + +#endif /** * Train the dropout network on a larger dataset. */ -BOOST_AUTO_TEST_CASE(DropoutNetworkTest) +BOOST_AUTO_TEST_CASE(DropoutNetworkThyroidTest) { // Load the dataset. arma::mat dataset; @@ -267,6 +487,47 @@ BOOST_AUTO_TEST_CASE(DropoutNetworkTest) MeanSquaredErrorFunction> (trainData, trainLabels, testData, testLabels, 4, 100, 0.1, 60); + } + +#if (__cplusplus >= 201402L) || (defined(_MSC_VER) && _MSC_VER >= 1900) +/** +* Train the dropout network on a larger dataset. +*/ +BOOST_AUTO_TEST_CASE(DropoutNetworkCtorTest) +{ + // Load the dataset. + arma::mat dataset; + data::Load("thyroid_train.csv", dataset, true); + + arma::mat trainData = dataset.submat(0, 0, dataset.n_rows - 4, + dataset.n_cols - 1); + arma::mat trainLabels = dataset.submat(dataset.n_rows - 3, 0, + dataset.n_rows - 1, dataset.n_cols - 1); + + data::Load("thyroid_test.csv", dataset, true); + + arma::mat testData = dataset.submat(0, 0, dataset.n_rows - 4, + dataset.n_cols - 1); + arma::mat testLabels = dataset.submat(dataset.n_rows - 3, 0, + dataset.n_rows - 1, dataset.n_cols - 1); + + // Vanilla neural net with logistic activation function. + // Because 92 percent of the patients are not hyperthyroid the neural + // network must be significant better than 92%. + CopyCtorDropoutTest + (trainData, trainLabels, testData, testLabels, 4, 100, 0.1, 60); + +} +#endif + +BOOST_AUTO_TEST_CASE(DropoutNetworkMnistTest) +{ + // Load the dataset. + arma::mat dataset; + + dataset.load("mnist_first250_training_4s_and_9s.arm"); // Normalize each point since these are images.
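
[Editor's note] Across this patch the optimizer classes (RMSPROP, Adam, AdaDelta) become templates over the data type only and drop the `DecomposableFunctionType&` back-reference; weights and gradients are instead passed in by the layer through `Update`, `Optimize`, and `Reset`, as the CNN/FFN helpers (`t.UpdateOptimizer()`, `t.Optimize()`, `t.ResetOptimizer()`) show at the top of the diff. The sketch below illustrates that decoupled interface with RMSprop-style updates; constants, member names, and the update details are simplified for illustration and are not a byte-for-byte copy of the mlpack implementation.

```cpp
#include <armadillo>

// Sketch of the decoupled optimizer interface after this change: the layer
// pushes gradients in via Update() and hands its weights to Optimize().
template<typename DataType>
class ExampleRMSProp
{
 public:
  ExampleRMSProp(const double lr = 0.01,
                 const double alpha = 0.99,
                 const double eps = 1e-8) : lr(lr), alpha(alpha), eps(eps) { }

  //! Accumulate a gradient supplied by the layer.
  void Update(const DataType& layerGradient)
  {
    if (gradient.n_elem != 0)
      gradient += layerGradient;
    else
      gradient = layerGradient;
  }

  //! Apply the accumulated gradient to the given weights.
  void Optimize(DataType& weights)
  {
    if (meanSquaredGradient.n_elem == 0)
      meanSquaredGradient = arma::zeros<DataType>(arma::size(weights));

    meanSquaredGradient *= alpha;
    meanSquaredGradient += (1 - alpha) * (gradient % gradient);
    weights -= lr * gradient / (arma::sqrt(meanSquaredGradient) + eps);
  }

  //! Forget the accumulated gradient before the next batch.
  void Reset() { gradient.zeros(); }

 private:
  const double lr;
  const double alpha;
  const double eps;

  DataType gradient;
  DataType meanSquaredGradient;
};
```

With this interface, a layer that owns `ExampleRMSProp<arma::mat> optimizer;` simply calls `optimizer.Update(Gradient())` during backpropagation, `optimizer.Optimize(Weights())` once per batch, and `optimizer.Reset()` afterwards, which is what makes the value-owned layers copyable and allows the C++14 `auto GetVanillaNetwork()` factories and `std::make_tuple`-based module lists added in the tests.
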