
Commit

Merge pull request #8464 from emcastillo/chx-bad-bn
Support `chainerx.batch_norm` with 2D input on CUDA
mergify[bot] committed Nov 21, 2019
2 parents b0f88cf + d190e92 commit 7c4c79c
Showing 2 changed files with 20 additions and 8 deletions.
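
As context for the diff below, here is a rough, hypothetical usage sketch (not part of this commit) of the call that this change enables: running chainerx.batch_norm on a 2-D (batch, channels) array placed on the CUDA device. Argument names and defaults are assumed from the ChainerX Python API and may differ slightly.

import chainerx as chx

# Hypothetical illustration: a 2-D (batch, channels) input on the CUDA backend.
# Before this commit, cuDNN's 4-/5-D requirement made this case fail on CUDA.
device = 'cuda:0'
x = chx.ones((3, 2), dtype=chx.float32, device=device)
gamma = chx.ones((2,), dtype=chx.float32, device=device)
beta = chx.zeros((2,), dtype=chx.float32, device=device)
running_mean = chx.zeros((2,), dtype=chx.float32, device=device)
running_var = chx.ones((2,), dtype=chx.float32, device=device)

y = chx.batch_norm(x, gamma, beta, running_mean, running_var)
print(y.shape)  # (3, 2)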
24 changes: 17 additions & 7 deletions chainerx_cc/chainerx/cuda/cuda_device/batch_norm.cc
@@ -29,7 +29,6 @@ namespace chainerx {
namespace cuda {
namespace {

-// TODO(sonots): Support other than 4, 5 dimensional arrays by reshaping into 4-dimensional arrays as Chainer does.
cudnnBatchNormMode_t GetBatchNormMode(const Axes& axis) {
if (axis.ndim() == 1 && axis[0] == 0) { // (1, channels, (depth, )height, width)
return CUDNN_BATCHNORM_PER_ACTIVATION;
@@ -63,6 +62,19 @@ void UpdateRunning(const Array& running, const Array& running_updated) {
internal::GetRawOffsetData(running), internal::GetRawOffsetData(running_casted_back), running.GetNBytes(), device);
}

+// Appends singleton axes to make an array with at least 4 dimensions.
+// Used for cuDNN BatchNorm, which only supports 4 or 5 dimension input.
+Array ExpandToAtLeast4D(const Array& x) {
+if (x.ndim() >= 4) {
+return x;
+}
+Shape shape = x.shape();
+while (shape.size() < 4) {
+shape.push_back(1);
+}
+return x.Reshape(shape);
+}

// Derives a secondary tensor descriptor for the batch normalization parameters.
cuda_internal::CudnnTensorDescriptor DeriveBatchNormTensorDescriptor(
const cuda_internal::CudnnTensorDescriptor& x_desc, cudnnBatchNormMode_t mode) {
@@ -122,7 +134,7 @@ class CudaBatchNormKernel : public BatchNormKernel {
CudaSetDeviceScope scope{device.index()};

Array x_cont = AsContiguous(x);
-cuda_internal::CudnnTensorDescriptor x_desc{x_cont};
+cuda_internal::CudnnTensorDescriptor x_desc{ExpandToAtLeast4D(x_cont)};

cudnnBatchNormMode_t mode = GetBatchNormMode(axis);
cuda_internal::CudnnTensorDescriptor gamma_beta_mean_var_desc = DeriveBatchNormTensorDescriptor(x_desc, mode);
@@ -199,6 +211,7 @@ class CudaBatchNormGradKernel : public BatchNormGradKernel {
const absl::optional<Array>& ggamma,
const absl::optional<Array>& gbeta) override {
CHAINERX_ASSERT(gamma.shape() == internal::ReduceShape(x.shape(), axis, true));
+CHAINERX_ASSERT(x.dtype() == gout.dtype());
CHAINERX_ASSERT(x.shape() == gout.shape());
CHAINERX_ASSERT(&x.device() == &gamma.device());
CHAINERX_ASSERT(&x.device() == &gout.device());
@@ -239,10 +252,7 @@ class CudaBatchNormGradKernel : public BatchNormGradKernel {

Array gout_cont = AsContiguous(gout);
Array actual_gx = EmptyLike(x, device);
-cuda_internal::CudnnTensorDescriptor x_desc{x_cont};
-
-// The CudnnTensorDescriptor for `x_cont` can be reused for `gout_cont`.
-CHAINERX_ASSERT(x_desc.GetDtype() == cuda_internal::CudnnTensorDescriptor{gout_cont}.GetDtype());
+cuda_internal::CudnnTensorDescriptor x_desc{ExpandToAtLeast4D(x_cont)};

cudnnBatchNormMode_t mode = GetBatchNormMode(axis);

@@ -339,7 +349,7 @@ class CudaFixedBatchNormKernel : public FixedBatchNormKernel {
CudaSetDeviceScope scope{device.index()};

Array x_cont = AsContiguous(x);
-cuda_internal::CudnnTensorDescriptor x_desc{x_cont};
+cuda_internal::CudnnTensorDescriptor x_desc{ExpandToAtLeast4D(x_cont)};

cudnnBatchNormMode_t mode = GetBatchNormMode(axis);

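The key change in this file is the new ExpandToAtLeast4D helper, which appends singleton axes so that inputs with fewer than four dimensions still satisfy cuDNN's 4-/5-D requirement before the tensor descriptor is built. A NumPy sketch of the equivalent shape handling (illustration only, not part of the diff):

import numpy as np

def expand_to_at_least_4d(x):
    # Mirror of the C++ helper above: append size-1 axes until the array has
    # at least 4 dimensions; 4-D and 5-D inputs pass through unchanged.
    if x.ndim >= 4:
        return x
    return x.reshape(x.shape + (1,) * (4 - x.ndim))

x = np.zeros((3, 2), dtype=np.float32)   # 2-D (batch, channels) input
print(expand_to_at_least_4d(x).shape)    # (3, 2, 1, 1)

The hunk that follows is from the second changed file, the Python test for batch normalization, which adds the 2-D case to the parameter list.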
@@ -44,9 +44,11 @@ def _create_batch_norm_ndarray_args(


# Note that CUDA (cuDNN) only supports batch normalization with 4 or
-# 5-dimenisional data.
+# 5-dimensional data. Arrays with smaller dimensions are supported by the
+# CUDA backend, while those with larger dimensions are not.
# x_shape,reduced_shape,axis
_batch_norm_params = [
+((3, 2), (2,), None),
((5, 4, 3, 2), (4, 3, 2), None),
((5, 4, 3, 2), (4, 3, 2), (0,)),
((5, 4, 3, 2), (4,), (0, 2, 3)),
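For readers checking the new 2-D entry, here is a small sketch (not from the commit) of how reduced_shape relates to x_shape and axis in the parameter tuples above; with axis=None the batch axis (0,) is reduced, and what remains of x_shape is the shape used for the per-channel arrays such as gamma and beta in the example near the top.

def reduced_shape(x_shape, axis):
    # Hedged sketch: drop the reduced axes from x_shape; axis=None is treated
    # as reducing over the batch axis (0,), matching the tuples listed above.
    if axis is None:
        axis = (0,)
    return tuple(s for i, s in enumerate(x_shape) if i not in axis)

assert reduced_shape((3, 2), None) == (2,)
assert reduced_shape((5, 4, 3, 2), (0,)) == (4, 3, 2)
assert reduced_shape((5, 4, 3, 2), (0, 2, 3)) == (4,)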
