Skip to content

Commit

Permalink
Merge pull request #1390 from ShikharJ/AtrousConv
Browse files Browse the repository at this point in the history
Atrous (Dilated) Convolution Implementation.
  • Loading branch information
zoq committed May 18, 2018
2 parents a37c9ca + 5b215ec commit 24d564d
Show file tree
Hide file tree
Showing 8 changed files with 872 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace ann /** Artificial Neural Network. */ {
/**
* Computes the two-dimensional convolution through fft. This class allows
* specification of the type of the border type. The convolution can be
* computed with the valid border type of the full border type (default).
* computed with the valid border type or the full border type (default).
*
* FullConvolution: returns the full two-dimensional convolution.
* ValidConvolution: returns only those parts of the convolution that are
Expand Down
83 changes: 62 additions & 21 deletions src/mlpack/methods/ann/convolution_rules/naive_convolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class NaiveConvolution
* @param output Output data that contains the results of the convolution.
* @param dW Stride of filter application in the x direction.
* @param dH Stride of filter application in the y direction.
* @param dilationW The dilation factor in x direction.
* @param dilationH The dilation factor in y direction.
*/
template<typename eT, typename Border = BorderMode>
static typename std::enable_if<
Expand All @@ -51,10 +53,13 @@ class NaiveConvolution
const arma::Mat<eT>& filter,
arma::Mat<eT>& output,
const size_t dW = 1,
const size_t dH = 1)
const size_t dH = 1,
const size_t dilationW = 1,
const size_t dilationH = 1)
{
output = arma::zeros<arma::Mat<eT> >((input.n_rows - filter.n_rows + 1) /
dW, (input.n_cols - filter.n_cols + 1) / dH);
output = arma::zeros<arma::Mat<eT> >(
(input.n_rows - (filter.n_rows - 1) * dilationW - 1) / dW + 1,
(input.n_cols - (filter.n_cols - 1) * dilationH - 1) / dH + 1);

// It seems to be about 3.5 times faster to use pointers instead of
// filter(ki, kj) * input(leftInput + ki, topInput + kj) and output(i, j).
Expand All @@ -67,8 +72,9 @@ class NaiveConvolution
const eT* kernelPtr = filter.memptr();
for (size_t kj = 0; kj < filter.n_cols; ++kj)
{
const eT* inputPtr = input.colptr(kj + j * dW) + i * dH;
for (size_t ki = 0; ki < filter.n_rows; ++ki, ++kernelPtr, ++inputPtr)
const eT* inputPtr = input.colptr(kj * dilationW + j * dW) + i * dH;
for (size_t ki = 0; ki < filter.n_rows; ++ki, ++kernelPtr,
inputPtr += dilationH)
*outputPtr += *kernelPtr * (*inputPtr);
}
}
Expand All @@ -83,6 +89,8 @@ class NaiveConvolution
* @param output Output data that contains the results of the convolution.
* @param dW Stride of filter application in the x direction.
* @param dH Stride of filter application in the y direction.
* @param dilationW The dilation factor in x direction.
* @param dilationH The dilation factor in y direction.
*/
template<typename eT, typename Border = BorderMode>
static typename std::enable_if<
Expand All @@ -91,20 +99,41 @@ class NaiveConvolution
const arma::Mat<eT>& filter,
arma::Mat<eT>& output,
const size_t dW = 1,
const size_t dH = 1)
const size_t dH = 1,
const size_t dilationW = 1,
const size_t dilationH = 1)
{
const size_t outputRows = (input.n_rows + 2 * (filter.n_rows - 1)) * dW;
const size_t outputCols = (input.n_cols + 2 * (filter.n_cols - 1)) * dH;
size_t outputRows = (input.n_rows - 1) * dW + 2 * (filter.n_rows - 1)
* dilationW + 1;
size_t outputCols = (input.n_cols - 1) * dH + 2 * (filter.n_cols - 1)
* dilationH + 1;

for (size_t i = 0; i < dW; i++)
{
if (((((i + outputRows - 2 * (filter.n_rows - 1) * dilationW - 1) % dW)
+ dW) % dW) == i){
outputRows += i;
break;
}
}
for (size_t i = 0; i < dH; i++)
{
if (((((i + outputCols - 2 * (filter.n_cols - 1) * dilationH - 1) % dH)
+ dH) % dH) == i){
outputCols += i;
break;
}
}

// Pad filter and input to the working output shape.
arma::Mat<eT> inputPadded = arma::zeros<arma::Mat<eT> >(outputRows,
outputCols);
inputPadded.submat(filter.n_rows - 1, filter.n_cols - 1,
filter.n_rows - 1 + input.n_rows - 1,
filter.n_cols - 1 + input.n_cols - 1) = input;
inputPadded.submat((filter.n_rows - 1) * dilationW, (filter.n_cols - 1)
* dilationH, (filter.n_rows - 1) * dilationW + input.n_rows - 1,
(filter.n_cols - 1) * dilationH + input.n_cols - 1) = input;

NaiveConvolution<ValidConvolution>::Convolution(inputPadded, filter,
output, 1, 1);
output, 1, 1, dilationW, dilationH);
}

/*
Expand All @@ -115,17 +144,21 @@ class NaiveConvolution
* @param output Output data that contains the results of the convolution.
* @param dW Stride of filter application in the x direction.
* @param dH Stride of filter application in the y direction.
* @param dilationW The dilation factor in x direction.
* @param dilationH The dilation factor in y direction.
*/
template<typename eT>
static void Convolution(const arma::Cube<eT>& input,
const arma::Cube<eT>& filter,
arma::Cube<eT>& output,
const size_t dW = 1,
const size_t dH = 1)
const size_t dH = 1,
const size_t dilationW = 1,
const size_t dilationH = 1)
{
arma::Mat<eT> convOutput;
NaiveConvolution<BorderMode>::Convolution(input.slice(0), filter.slice(0),
convOutput, dW, dH);
convOutput, dW, dH, dilationW, dilationH);

output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
input.n_slices);
Expand All @@ -134,7 +167,7 @@ class NaiveConvolution
for (size_t i = 1; i < input.n_slices; i++)
{
NaiveConvolution<BorderMode>::Convolution(input.slice(i), filter.slice(i),
output.slice(i), dW, dH);
output.slice(i), dW, dH, dilationW, dilationH);
}
}

Expand All @@ -147,17 +180,21 @@ class NaiveConvolution
* @param output Output data that contains the results of the convolution.
* @param dW Stride of filter application in the x direction.
* @param dH Stride of filter application in the y direction.
* @param dilationW The dilation factor in x direction.
* @param dilationH The dilation factor in y direction.
*/
template<typename eT>
static void Convolution(const arma::Mat<eT>& input,
const arma::Cube<eT>& filter,
arma::Cube<eT>& output,
const size_t dW = 1,
const size_t dH = 1)
const size_t dH = 1,
const size_t dilationW = 1,
const size_t dilationH = 1)
{
arma::Mat<eT> convOutput;
NaiveConvolution<BorderMode>::Convolution(input, filter.slice(0),
convOutput, dW, dH);
convOutput, dW, dH, dilationW, dilationH);

output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
filter.n_slices);
Expand All @@ -166,7 +203,7 @@ class NaiveConvolution
for (size_t i = 1; i < filter.n_slices; i++)
{
NaiveConvolution<BorderMode>::Convolution(input, filter.slice(i),
output.slice(i), dW, dH);
output.slice(i), dW, dH, dilationW, dilationH);
}
}

Expand All @@ -179,17 +216,21 @@ class NaiveConvolution
* @param output Output data that contains the results of the convolution.
* @param dW Stride of filter application in the x direction.
* @param dH Stride of filter application in the y direction.
* @param dilationW The dilation factor in x direction.
* @param dilationH The dilation factor in y direction.
*/
template<typename eT>
static void Convolution(const arma::Cube<eT>& input,
const arma::Mat<eT>& filter,
arma::Cube<eT>& output,
const size_t dW = 1,
const size_t dH = 1)
const size_t dH = 1,
const size_t dilationW = 1,
const size_t dilationH = 1)
{
arma::Mat<eT> convOutput;
NaiveConvolution<BorderMode>::Convolution(input.slice(0), filter,
convOutput, dW, dH);
convOutput, dW, dH, dilationW, dilationH);

output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
input.n_slices);
Expand All @@ -198,7 +239,7 @@ class NaiveConvolution
for (size_t i = 1; i < input.n_slices; i++)
{
NaiveConvolution<BorderMode>::Convolution(input.slice(i), filter,
output.slice(i), dW, dH);
output.slice(i), dW, dH, dilationW, dilationH);
}
}
}; // class NaiveConvolution
Expand Down
2 changes: 2 additions & 0 deletions src/mlpack/methods/ann/layer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ set(SOURCES
add_merge_impl.hpp
alpha_dropout.hpp
alpha_dropout_impl.hpp
atrous_convolution.hpp
atrous_convolution_impl.hpp
base_layer.hpp
bilinear_interpolation.hpp
bilinear_interpolation_impl.hpp
Expand Down
Loading

0 comments on commit 24d564d

Please sign in to comment.