diff --git a/src/mlpack/methods/ann/layer/drop_connect_layer.hpp b/src/mlpack/methods/ann/layer/drop_connect_layer.hpp
new file mode 100644
index 00000000000..f3208a43a10
--- /dev/null
+++ b/src/mlpack/methods/ann/layer/drop_connect_layer.hpp
@@ -0,0 +1,382 @@
+/**
+ * @file drop_connect_layer.hpp
+ * @author Abhinav Chanda
+ *
+ * Definition of the DropConnectLayer class, which implements a regularizer
+ * that randomly sets weights to zero between two fully connected layers.
+ * It prevents units from co-adapting and is a generalization of dropout.
+ */
+#ifndef __MLPACK_METHODS_ANN_LAYER_DROPCONNECT_LAYER_HPP
+#define __MLPACK_METHODS_ANN_LAYER_DROPCONNECT_LAYER_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/methods/ann/layer/layer_traits.hpp>
+
+namespace mlpack {
+namespace ann /** Artificial Neural Network. */ {
+
+/**
+ * The DropConnect layer is a regularizer that randomly, with probability
+ * ratio, sets weights between two fully connected layers to zero and scales
+ * the remaining weights by a factor of 1 / (1 - ratio). In deterministic mode
+ * (during testing) no weights are dropped; if rescale is true the weights are
+ * still scaled by 1 / (1 - ratio), otherwise they are used as-is.
+ *
+ * Note: During training you should set deterministic to false and during
+ * testing you should set deterministic to true.
+ *
+ * For more information, see the following.
+ *
+ * @code
+ * @article{Wan2013,
+ *   author  = {Li Wan and Matthew Zeiler and Sixin Zhang and Yann LeCun and
+ *              Rob Fergus},
+ *   title   = {Regularization of Neural Networks using DropConnect},
+ *   journal = {JMLR Workshop and Conference Proceedings},
+ *   volume  = {28},
+ *   year    = {2013},
+ * }
+ * @endcode
+ *
+ * @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
+ *         arma::sp_mat or arma::cube).
+ * @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
+ *         arma::sp_mat or arma::cube).
+ */
+template <
+    typename InputDataType = arma::mat,
+    typename OutputDataType = arma::mat
+>
+class DropConnectLayer
+{
+ public:
+  /**
+   * Create the DropConnectLayer object using the specified number of units.
+   *
+   * @param inSize The number of input units.
+   * @param outSize The number of output units.
+   * @param ratio The probability of setting a weight to zero.
+   * @param rescale If true the weights are rescaled in deterministic mode.
+   */
+  DropConnectLayer(const size_t inSize,
+                   const size_t outSize,
+                   const double ratio = 0.5,
+                   const bool rescale = true) :
+      inSize(inSize),
+      outSize(outSize),
+      ratio(ratio),
+      scale(1.0 / (1.0 - ratio)),
+      deterministic(false),
+      rescale(rescale)
+  {
+    weights.set_size(outSize, inSize);
+  }
+
+  /**
+   * Ordinary feed forward pass of a neural network, evaluating the function
+   * f(x) by propagating the activity forward through f.
+   *
+   * @param input Input data used for evaluating the specified function.
+   * @param output Resulting output activation.
+   */
+  template<typename eT>
+  void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output)
+  {
+    // The mask is not applied to the weights in the deterministic mode
+    // (during testing).
+    if (deterministic)
+    {
+      if (!rescale)
+      {
+        output = weights * input;
+      }
+      else
+      {
+        output = weights * scale * input;
+      }
+    }
+    else
+    {
+      // Set weights to zero with probability ratio and scale the remaining
+      // weights by 1 / (1 - ratio).
+      mask = arma::randu<arma::Mat<eT> >(weights.n_rows, weights.n_cols);
+      mask.transform( [&](double val) { return (val > ratio); } );
+      output = (weights % mask) * scale * input;
+    }
+  }
+
+  /**
+   * Ordinary feed forward pass of a neural network, evaluating the function
+   * f(x) by propagating the activity forward through f.
+   *
+   * @param input Input data used for evaluating the specified function.
+   * @param output Resulting output activation.
+   */
+  template<typename eT>
+  void Forward(const arma::Cube<eT>& input, arma::Mat<eT>& output)
+  {
+    // Flatten the 3rd order tensor into a dense column vector so that the
+    // weight matrix can be applied in both modes.
+    arma::Mat<eT> data(input.n_elem, 1);
+
+    for (size_t s = 0, c = 0; s < input.n_slices / data.n_cols; s++)
+    {
+      for (size_t i = 0; i < data.n_cols; i++, c++)
+      {
+        data.col(i).subvec(s * input.n_rows * input.n_cols, (s + 1) *
+            input.n_rows * input.n_cols - 1) = arma::vectorise(input.slice(c));
+      }
+    }
+
+    // The mask is not applied to the weights in the deterministic mode
+    // (during testing).
+    if (deterministic)
+    {
+      if (!rescale)
+      {
+        output = weights * data;
+      }
+      else
+      {
+        output = weights * scale * data;
+      }
+    }
+    else
+    {
+      // Set weights to zero with probability ratio and scale the remaining
+      // weights by 1 / (1 - ratio).
+      mask = arma::randu<arma::Mat<eT> >(weights.n_rows, weights.n_cols);
+      mask.transform( [&](double val) { return (val > ratio); } );
+      output = (weights % mask) * scale * data;
+    }
+  }
+
+  /**
+   * Ordinary feed backward pass of a neural network, calculating the function
+   * f(x) by propagating x backwards through f, using the results from the
+   * feed forward pass.
+   *
+   * @param input The propagated input activation.
+   * @param gy The backpropagated error.
+   * @param g The calculated gradient.
+   */
+  template<typename InputType, typename eT>
+  void Backward(const InputType& /* unused */,
+                const arma::Mat<eT>& gy,
+                arma::Mat<eT>& g)
+  {
+    g = ((weights % mask) * scale).t() * gy;
+  }
+
+  /*
+   * Calculate the gradient using the output delta and the input activation.
+   *
+   * @param d The calculated error.
+   * @param g The calculated gradient.
+   */
+  template<typename eT, typename GradientDataType>
+  void Gradient(const arma::Mat<eT>& d, GradientDataType& g)
+  {
+    GradientDelta(inputParameter, d, g);
+  }
+
+  //! Get the weights.
+  OutputDataType const& Weights() const { return weights; }
+  //! Modify the weights.
+  OutputDataType& Weights() { return weights; }
+
+  //! Get the input parameter.
+  InputDataType const& InputParameter() const { return inputParameter; }
+  //! Modify the input parameter.
+  InputDataType& InputParameter() { return inputParameter; }
+
+  //! Get the output parameter.
+  OutputDataType const& OutputParameter() const { return outputParameter; }
+  //! Modify the output parameter.
+  OutputDataType& OutputParameter() { return outputParameter; }
+
+  //! Get the delta.
+  OutputDataType const& Delta() const { return delta; }
+  //! Modify the delta.
+  OutputDataType& Delta() { return delta; }
+
+  //! Get the gradient.
+  OutputDataType const& Gradient() const { return gradient; }
+  //! Modify the gradient.
+  OutputDataType& Gradient() { return gradient; }
+
+  //! The value of the deterministic parameter.
+  bool Deterministic() const { return deterministic; }
+  //! Modify the value of the deterministic parameter.
+  bool& Deterministic() { return deterministic; }
+
+  //! The probability of setting a weight to zero.
+  double Ratio() const { return ratio; }
+
+  //! Modify the probability of setting a weight to zero.
+  void Ratio(const double r)
+  {
+    ratio = r;
+    scale = 1.0 / (1.0 - ratio);
+  }
+
+  //! The value of the rescale parameter.
+  bool Rescale() const { return rescale; }
+  //! Modify the value of the rescale parameter.
+  bool& Rescale() { return rescale; }
+
+  /**
+   * Serialize the layer.
+   */
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */)
+  {
+    ar & data::CreateNVP(ratio, "ratio");
+    ar & data::CreateNVP(rescale, "rescale");
+    ar & data::CreateNVP(weights, "weights");
+  }
+
+ private:
+  /*
+   * Calculate the gradient using the output delta (3rd order tensor) and the
+   * input activation (3rd order tensor).
+   *
+   * @param input The input parameter used for calculating the gradient.
+   * @param d The output delta.
+   * @param g The calculated gradient.
+   */
+  template<typename eT>
+  void GradientDelta(const arma::Cube<eT>& input,
+                     const arma::Mat<eT>& d,
+                     arma::Cube<eT>& g)
+  {
+    g = arma::Cube<eT>(weights.n_rows, weights.n_cols, 1);
+    arma::Mat<eT> data = arma::Mat<eT>(d.n_cols, input.n_elem / d.n_cols);
+
+    for (size_t s = 0, c = 0; s < input.n_slices / data.n_rows; s++)
+    {
+      for (size_t i = 0; i < data.n_rows; i++, c++)
+      {
+        data.row(i).subvec(s * input.n_rows * input.n_cols, (s + 1) *
+            input.n_rows * input.n_cols - 1) = arma::vectorise(
+            input.slice(c), 1);
+      }
+    }
+
+    g.slice(0) = d * data / d.n_cols;
+  }
+
+  /*
+   * Calculate the gradient (3rd order tensor) using the output delta
+   * (dense matrix) and the input activation (dense matrix).
+   *
+   * @param input The input parameter used for calculating the gradient.
+   * @param d The output delta.
+   * @param g The calculated gradient.
+   */
+  template<typename eT>
+  void GradientDelta(const arma::Mat<eT>& /* input unused */,
+                     const arma::Mat<eT>& d,
+                     arma::Cube<eT>& g)
+  {
+    g = arma::Cube<eT>(weights.n_rows, weights.n_cols, 1);
+    Gradient(d, g.slice(0));
+  }
+
+  /*
+   * Calculate the gradient (dense matrix) using the output delta
+   * (dense matrix) and the input activation (3rd order tensor).
+   *
+   * @param input The input parameter used for calculating the gradient.
+   * @param d The output delta.
+   * @param g The calculated gradient.
+   */
+  template<typename eT>
+  void GradientDelta(const arma::Cube<eT>& /* input unused */,
+                     const arma::Mat<eT>& d,
+                     arma::Mat<eT>& g)
+  {
+    arma::Cube<eT> grad = arma::Cube<eT>(weights.n_rows, weights.n_cols, 1);
+    Gradient(d, grad);
+    g = grad.slice(0);
+  }
+
+  /*
+   * Calculate the gradient (dense matrix) using the output delta
+   * (dense matrix) and the input activation (dense matrix).
+   *
+   * @param input The input parameter used for calculating the gradient.
+   * @param d The output delta.
+   * @param g The calculated gradient.
+   */
+  template<typename eT>
+  void GradientDelta(const arma::Mat<eT>& input,
+                     const arma::Mat<eT>& d,
+                     arma::Mat<eT>& g)
+  {
+    g = d * input.t();
+  }
+
+  //! Locally-stored number of input units.
+  size_t inSize;
+
+  //! Locally-stored number of output units.
+  size_t outSize;
+
+  //! Locally-stored weight object.
+  OutputDataType weights;
+
+  //! Locally-stored delta object.
+  OutputDataType delta;
+
+  //! Locally-stored gradient object.
+  OutputDataType gradient;
+
+  //! Locally-stored input parameter object.
+  InputDataType inputParameter;
+
+  //! Locally-stored output parameter object.
+  OutputDataType outputParameter;
+
+  //! Locally-stored mask object.
+  OutputDataType mask;
+
+  //! The probability of setting a weight to zero.
+  double ratio;
+
+  //! The scale fraction.
+  double scale;
+
+  //! If true dropconnect and scaling is disabled; see notes above.
+  bool deterministic;
+
+  //! If true the weights are rescaled when deterministic is true.
+  bool rescale;
+}; // class DropConnectLayer
+
+/**
+ * Mapping layer to map between 3rd order tensors and dense matrices.
+ */
+template <
+    typename InputDataType = arma::cube,
+    typename OutputDataType = arma::mat
+>
+using DropConnectLinearMappingLayer =
+    DropConnectLayer<InputDataType, OutputDataType>;
+
+//! Layer traits for the dropconnect layer.
+template<
+    typename InputDataType,
+    typename OutputDataType
+>
+class LayerTraits<DropConnectLayer<InputDataType, OutputDataType> >
+{
+ public:
+  static const bool IsBinary = false;
+  static const bool IsOutputLayer = false;
+  static const bool IsBiasLayer = false;
+  static const bool IsLSTMLayer = false;
+  static const bool IsConnection = true;
+};
+
+} // namespace ann
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/tests/feedforward_network_test.cpp b/src/mlpack/tests/feedforward_network_test.cpp
index e8b663682ca..7231f8a2558 100644
--- a/src/mlpack/tests/feedforward_network_test.cpp
+++ b/src/mlpack/tests/feedforward_network_test.cpp
@@ -15,6 +15,7 @@
 #include <mlpack/methods/ann/layer/bias_layer.hpp>
 #include <mlpack/methods/ann/layer/linear_layer.hpp>
 #include <mlpack/methods/ann/layer/base_layer.hpp>
+#include <mlpack/methods/ann/layer/drop_connect_layer.hpp>
 #include <mlpack/methods/ann/layer/dropout_layer.hpp>
 #include <mlpack/methods/ann/layer/binary_classification_layer.hpp>
@@ -286,4 +287,131 @@ BOOST_AUTO_TEST_CASE(DropoutNetworkTest)
     (dataset, labels, dataset, labels, 8, 30, 0.4);
 }
 
+/**
+ * Train and evaluate a DropConnect network with the specified structure.
+ */
+template<
+    typename PerformanceFunction,
+    typename OutputLayerType,
+    typename PerformanceFunctionType,
+    typename MatType = arma::mat
+>
+void BuildDropConnectNetwork(MatType& trainData,
+                             MatType& trainLabels,
+                             MatType& testData,
+                             MatType& testLabels,
+                             const size_t hiddenLayerSize,
+                             const size_t maxEpochs,
+                             const double classificationErrorThreshold)
+{
+  /*
+   * Construct a feed forward network with trainData.n_rows input nodes,
+   * hiddenLayerSize hidden nodes and trainLabels.n_rows output nodes. The
+   * network structure looks like:
+   *
+   *  Input         Hidden        DropConnect   Output
+   *  Layer         Layer         Layer         Layer
+   * +-----+       +-----+       +-----+       +-----+
+   * |     |       |     |       |     |       |     |
+   * |     +------>|     +------>|     +------>|     |
+   * |     |     +>|     |       |     |       |     |
+   * +-----+     | +--+--+       +-----+       +-----+
+   *             |
+   *  Bias       |
+   *  Layer      |
+   * +-----+     |
+   * |     |     |
+   * |     +-----+
+   * |     |
+   * +-----+
+   */
+
+  LinearLayer<> inputLayer(trainData.n_rows, hiddenLayerSize);
+  BiasLayer<> biasLayer(hiddenLayerSize);
+  BaseLayer<PerformanceFunction> hiddenLayer0;
+  DropConnectLayer<> dropconnectLayer0(hiddenLayerSize, trainLabels.n_rows);
+
+  BaseLayer<PerformanceFunction> outputLayer;
+
+  OutputLayerType classOutputLayer;
+
+  auto modules = std::tie(inputLayer, biasLayer, hiddenLayer0,
+      dropconnectLayer0, outputLayer);
+
+  FFN<decltype(modules), decltype(classOutputLayer), RandomInitialization,
+      PerformanceFunctionType> net(modules, classOutputLayer);
+
+  RMSprop<decltype(net)> opt(net, 0.01, 0.88, 1e-8,
+      maxEpochs * trainData.n_cols, 1e-18);
+
+  net.Train(trainData, trainLabels, opt);
+
+  MatType prediction;
+  net.Predict(testData, prediction);
+
+  size_t error = 0;
+  for (size_t i = 0; i < testData.n_cols; i++)
+  {
+    if (arma::sum(arma::sum(
+        arma::abs(prediction.col(i) - testLabels.col(i)))) == 0)
+    {
+      error++;
+    }
+  }
+
+  double classificationError = 1 - double(error) / testData.n_cols;
+  BOOST_REQUIRE_LE(classificationError, classificationErrorThreshold);
+}
+
+/**
+ * Train the DropConnect network on a larger dataset.
+ */
+BOOST_AUTO_TEST_CASE(DropConnectNetworkTest)
+{
+  // Load the dataset.
+  arma::mat dataset;
+  data::Load("thyroid_train.csv", dataset, true);
+
+  arma::mat trainData = dataset.submat(0, 0, dataset.n_rows - 4,
+      dataset.n_cols - 1);
+  arma::mat trainLabels = dataset.submat(dataset.n_rows - 3, 0,
+      dataset.n_rows - 1, dataset.n_cols - 1);
+
+  data::Load("thyroid_test.csv", dataset, true);
+
+  arma::mat testData = dataset.submat(0, 0, dataset.n_rows - 4,
+      dataset.n_cols - 1);
+  arma::mat testLabels = dataset.submat(dataset.n_rows - 3, 0,
+      dataset.n_rows - 1, dataset.n_cols - 1);
+
+  // Vanilla neural net with logistic activation function.
+  // Because 92 percent of the patients are not hyperthyroid, the neural
+  // network must be significantly better than 92%.
+  BuildDropConnectNetwork<LogisticFunction,
+                          BinaryClassificationLayer,
+                          MeanSquaredErrorFunction>
+      (trainData, trainLabels, testData, testLabels, 4, 100, 0.1);
+
+  dataset.load("mnist_first250_training_4s_and_9s.arm");
+
+  // Normalize each point since these are images.
+  for (size_t i = 0; i < dataset.n_cols; ++i)
+    dataset.col(i) /= norm(dataset.col(i), 2);
+
+  arma::mat labels = arma::zeros(1, dataset.n_cols);
+  labels.submat(0, labels.n_cols / 2, 0, labels.n_cols - 1).fill(1);
+
+  // Vanilla neural net with logistic activation function.
+  BuildDropConnectNetwork<LogisticFunction,
+                          BinaryClassificationLayer,
+                          MeanSquaredErrorFunction>
+      (dataset, labels, dataset, labels, 8, 30, 0.4);
+
+  // Vanilla neural net with tanh activation function.
+  BuildDropConnectNetwork<TanhFunction,
+                          BinaryClassificationLayer,
+                          MeanSquaredErrorFunction>
+      (dataset, labels, dataset, labels, 8, 30, 0.4);
+}
+
 BOOST_AUTO_TEST_SUITE_END();
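
For anyone who wants to try the new layer outside the Boost test harness, here is a minimal usage sketch. It mirrors the network assembled in BuildDropConnectNetwork above (bias layer omitted for brevity) and relies only on the mlpack 2.x ANN types the test already uses; the random data and labels are placeholders, so treat this as an illustrative sketch rather than a benchmark.

#include <tuple>

#include <mlpack/core.hpp>
#include <mlpack/methods/ann/activation_functions/logistic_function.hpp>
#include <mlpack/methods/ann/init_rules/random_init.hpp>
#include <mlpack/methods/ann/layer/linear_layer.hpp>
#include <mlpack/methods/ann/layer/base_layer.hpp>
#include <mlpack/methods/ann/layer/drop_connect_layer.hpp>
#include <mlpack/methods/ann/layer/binary_classification_layer.hpp>
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/performance_functions/mse_function.hpp>
#include <mlpack/core/optimizers/rmsprop/rmsprop.hpp>

using namespace mlpack;
using namespace mlpack::ann;

int main()
{
  // Placeholder data: 100 random 10-dimensional points with binary labels.
  arma::mat data = arma::randu<arma::mat>(10, 100);
  arma::mat labels = arma::round(arma::randu<arma::mat>(1, 100));

  LinearLayer<> inputLayer(data.n_rows, 8);
  BaseLayer<LogisticFunction> hiddenLayer;
  // Drop 30% of the weights during training; the surviving weights are
  // scaled by 1 / (1 - 0.3).
  DropConnectLayer<> dropConnectLayer(8, labels.n_rows, 0.3);
  BaseLayer<LogisticFunction> outputLayer;
  BinaryClassificationLayer classOutputLayer;

  auto modules = std::tie(inputLayer, hiddenLayer, dropConnectLayer,
      outputLayer);
  FFN<decltype(modules), decltype(classOutputLayer), RandomInitialization,
      MeanSquaredErrorFunction> net(modules, classOutputLayer);

  // Train() runs the layer with Deterministic() == false, so a fresh random
  // mask is sampled on every forward pass.
  RMSprop<decltype(net)> opt(net, 0.01, 0.88, 1e-8, 10 * data.n_cols, 1e-18);
  net.Train(data, labels, opt);

  // Predict() switches the network to deterministic mode: no weights are
  // dropped and, since rescale defaults to true, the weights are scaled
  // instead.
  arma::mat predictions;
  net.Predict(data, predictions);

  return 0;
}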