diff --git a/src/mlpack/methods/ann/layer/pooling_layer.hpp b/src/mlpack/methods/ann/layer/pooling_layer.hpp index 9eb57aa2b6c..a33ffe9ac5d 100644 --- a/src/mlpack/methods/ann/layer/pooling_layer.hpp +++ b/src/mlpack/methods/ann/layer/pooling_layer.hpp @@ -10,6 +10,7 @@ #include #include +#include #include namespace mlpack { @@ -39,8 +40,9 @@ class PoolingLayer * @param kSize Size of the pooling window. * @param pooling The pooling strategy. */ - PoolingLayer(const size_t kSize, PoolingRule pooling = PoolingRule()) : - kSize(kSize), pooling(pooling) + PoolingLayer(const size_t kSize, const size_t stride = 1, + PoolingRule pooling = PoolingRule()) : + kSize(kSize), pooling(pooling), stride(stride) { // Nothing to do here. } @@ -68,8 +70,8 @@ class PoolingLayer template void Forward(const arma::Cube& input, arma::Cube& output) { - output = arma::zeros >(input.n_rows / kSize, - input.n_cols / kSize, input.n_slices); + output = arma::zeros >((input.n_rows - kSize) / stride + 1, + (input.n_cols - kSize) / stride + 1, input.n_slices); for (size_t s = 0; s < input.n_slices; s++) Pooling(input.slice(s), output.slice(s)); @@ -155,6 +157,7 @@ class PoolingLayer { ar & data::CreateNVP(kSize, "kSize"); ar & data::CreateNVP(pooling, "pooling"); + ar & data::CreateNVP(stride, "stride"); } private: @@ -167,16 +170,16 @@ class PoolingLayer template void Pooling(const arma::Mat& input, arma::Mat& output) { - const size_t rStep = kSize; const size_t cStep = kSize; - for (size_t j = 0; j < input.n_cols; j += cStep) + for (size_t j = 0, colidx = 0; j < output.n_cols; ++j, colidx += stride) { - for (size_t i = 0; i < input.n_rows; i += rStep) + for (size_t i = 0, rowidx = 0; i < output.n_rows; ++i, rowidx += stride) { - output(i / rStep, j / cStep) += pooling.Pooling( - input(arma::span(i, i + rStep - 1), arma::span(j, j + cStep - 1))); + output(i, j) += pooling.Pooling( + input(arma::span(rowidx, rowidx + rStep - 1), + arma::span(colidx, colidx + cStep - 1))); } } } @@ -215,6 +218,9 @@ class PoolingLayer //! Locally-stored size of the pooling window. size_t kSize; + //! Locally-stored stride value by which we move filter. + size_t stride; + //! Locally-stored delta object. OutputDataType delta; @@ -249,3 +255,4 @@ class LayerTraits > } // namespace mlpack #endif +