Skip to content

Commit

Permalink
[BUGFIX] Fix MKLDNN BatchNorm with even number of channels (apache#19150
Browse files Browse the repository at this point in the history
) apache#19299 (apache#19425)

* Fix MKLDNN BatchNorm with even number of channels (apache#19150)

Even number of channels results in data reordering before batch
norm operation. Therefore, if BatchNorm data array is view of
another array and the data is stored in MKLDNN format, the data
needs to be converted to the default format.

* Add or updated test to verify Batchnorm odd & even number of channels

* Fix for Batchnorm odd & even chnls number context
  • Loading branch information
akarbown committed Oct 30, 2020
1 parent 0514233 commit 0faecf0
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 11 deletions.
16 changes: 6 additions & 10 deletions src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,6 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
return it->second;
}

template<typename DType>
static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
const OpContext &ctx, const NDArray &in_data,
mkldnn::normalization_flags flags) {
return GetBNForward<DType>(param, ctx, in_data.GetMKLDNNData(), flags);
}

template <typename DType>
void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
Expand Down Expand Up @@ -182,8 +175,11 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
aux_states,
ctx.is_train && !param.use_global_stats,
fuse_relu);
const NDArray &data = in_data[batchnorm::kData];
auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
NDArray &data = in_data[batchnorm::kData];
if (data.IsMKLDNNData() && data.IsView())
data = data.Reorder2Default();
auto data_mem = data.GetMKLDNNData();
auto &fwd = GetBNForward<DType>(param, ctx, data_mem, flags);

// for output memory
auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
Expand Down Expand Up @@ -221,7 +217,7 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
}

mkldnn_args_map_t net_args;
net_args[MKLDNN_ARG_SRC] = *data.GetMKLDNNData();
net_args[MKLDNN_ARG_SRC] = *data_mem;
net_args[MKLDNN_ARG_SCALE_SHIFT] = weight_mem;
net_args[MKLDNN_ARG_DST] = *out_mem;
if (fuse_relu) {
Expand Down
2 changes: 1 addition & 1 deletion tests/python/mkl/test_mkldnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def test_mkldnn_sum_inplace_with_cpu_layout():

def test_batchnorm():
def check_batchnorm_training(stype):
for shape in [(2, 3), (2, 3, 2, 2)]:
for shape in [(2, 3), (2, 4), (2, 3, 2, 2), (2, 4, 2, 2)]:
data_tmp = np.random.normal(-0.1, 0.1, size=shape)
s = shape[1],
gamma = np.ones(s)
Expand Down
34 changes: 34 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import mxnet as mx
from mxnet import gluon
from mxnet import init
from mxnet.gluon import nn
from mxnet.base import py_str, MXNetError
from mxnet.test_utils import assert_almost_equal, default_context, assert_allclose
Expand Down Expand Up @@ -1978,6 +1979,39 @@ def hybrid_forward(self, F, x):
check_layer_forward_withinput(net, x)


def test_batchnorm_chnls():
chn_list = [1024, 512, 256, 128, 64, 45, 32, 16, 3]
class Net(gluon.HybridBlock):
def __init__(self,
chn_num,
norm_kwargs=None,
in_channels=3,
**kwargs):
super(Net, self).__init__(**kwargs)
self.in_channels = in_channels
self.conv1 = gluon.nn.Conv3D(
in_channels=self.in_channels,
channels=chn_num,
kernel_size=(1, 7, 7),
strides=(1, 2, 2),
padding=(0, 3, 3),
use_bias=False,
)
self.bn1 = gluon.nn.BatchNorm(in_channels=chn_num, **({} if norm_kwargs is None else norm_kwargs))

def hybrid_forward(self, F, x):
"""Hybrid forward of R2+1D net"""
conv = self.conv1(x)
out = self.bn1(conv)
return out

for i in range(len(chn_list)):
net = Net(chn_list[i])
net.initialize(init=init.Constant(1))
x = mx.nd.zeros((1, 3, 8, 160, 160))
net(x).asnumpy()


def test_concat():
chn_list = [16, 64]
shapes = [1, 3, 5]
Expand Down

0 comments on commit 0faecf0

Please sign in to comment.