Support incomplete shape inference for slice_channel and convolution (apache#4898)

* Support incomplete shape inference for slice_channel and convolution

* Try fix warning

* Fix shape infer

* fix test cases and unnecessary if

* Revise threshold

* Update submodules

* Update

* add docs to explain the behavior of squeeze_axis

* fix compile error

* revise doc
sxjscience authored and piiswrong committed Feb 24, 2017
1 parent 3439f97 commit 7fc3db5
Showing 4 changed files with 153 additions and 12 deletions.
2 changes: 1 addition & 1 deletion ps-lite
Submodule ps-lite updated 1 file
+1 −1 docs/overview.md
53 changes: 46 additions & 7 deletions src/operator/convolution-inl.h
@@ -406,9 +406,6 @@ class ConvolutionProp : public OperatorProperty {
<< "incorrect stride size: " << param_.stride;
CHECK_GT(param_.dilate.Size(), 0) \
<< "incorrect dilate size: " << param_.dilate;
CHECK(ksize_y <= dshape[2] + 2 * param_.pad[0]
&& ksize_x <= dshape[3] + 2 * param_.pad[1])
<< "kernel size exceed input";
Shape<4> oshape;
oshape[0] = dshape[0];
oshape[1] = param_.num_filter;
@@ -417,6 +414,26 @@
oshape[3] = (dshape[3] + 2 * param_.pad[1] -
(param_.dilate[1] * (ksize_x - 1) + 1)) / param_.stride[1] + 1;
SHAPE_ASSIGN_CHECK(*out_shape, 0, ConvertLayout(oshape, kNCHW, param_.layout.value()));
// Perform incomplete shape inference. Fill in the missing values in data shape.
// 1) We can always fill in the batch_size.
// 2) We can back-calculate the input height/width if the corresponding stride is 1.
oshape = ConvertLayout((*out_shape)[0].get<4>(), param_.layout.value(), kNCHW);
dshape[0] = oshape[0];
if (param_.stride[0] == 1) {
dshape[2] = oshape[2] + param_.dilate[0] * (ksize_y - 1) - 2 * param_.pad[0];
}
if (param_.stride[1] == 1) {
dshape[3] = oshape[3] + param_.dilate[1] * (ksize_x - 1) - 2 * param_.pad[1];
}
SHAPE_ASSIGN_CHECK(*in_shape, conv::kData,
ConvertLayout(dshape, kNCHW, param_.layout.value()));
// Check whether the kernel sizes are valid
if (dshape[2] != 0) {
CHECK_LE(ksize_y, dshape[2] + 2 * param_.pad[0]) << "kernel size exceed input";
}
if (dshape[3] != 0) {
CHECK_LE(ksize_x, dshape[3] + 2 * param_.pad[1]) << "kernel size exceed input";
}
return true;
} else if (param_.kernel.ndim() == 3) {
// 3d conv
@@ -445,10 +462,6 @@
<< "incorrect stride size: " << param_.stride;
CHECK_GT(param_.dilate.Size(), 0) \
<< "incorrect dilate size: " << param_.dilate;
CHECK(ksize_d <= dshape[2] + 2 * param_.pad[0]
&& ksize_y <= dshape[3] + 2 * param_.pad[1]
&& ksize_x <= dshape[4] + 2 * param_.pad[2])
<< "kernel size exceed input";
CHECK_EQ(param_.dilate.Size(), 1)
<< "Dilate is not supported in 3d convolution";
Shape<5> oshape;
@@ -461,6 +474,32 @@
oshape[4] = (dshape[4] + 2 * param_.pad[2] -
(1 * (ksize_x - 1) + 1)) / param_.stride[2] + 1;
SHAPE_ASSIGN_CHECK(*out_shape, 0, ConvertLayout(oshape, kNCDHW, param_.layout.value()));
// Perform incomplete shape inference. Fill in the missing values in data shape.
// 1) We can always fill in the batch_size.
// 2) We can back-calculate the input depth/height/width if the corresponding stride is 1.
oshape = ConvertLayout((*out_shape)[0].get<5>(), param_.layout.value(), kNCDHW);
dshape[0] = oshape[0];
if (param_.stride[0] == 1) {
dshape[2] = oshape[2] + 1 * (ksize_d - 1) - 2 * param_.pad[0];
}
if (param_.stride[1] == 1) {
dshape[3] = oshape[3] + 1 * (ksize_y - 1) - 2 * param_.pad[1];
}
if (param_.stride[2] == 1) {
dshape[4] = oshape[4] + 1 * (ksize_x - 1) - 2 * param_.pad[2];
}
SHAPE_ASSIGN_CHECK(*in_shape, conv::kData,
ConvertLayout(dshape, kNCDHW, param_.layout.value()));
// Check whether the kernel sizes are valid
if (dshape[2] != 0) {
CHECK_LE(ksize_d, dshape[2] + 2 * param_.pad[0]) << "kernel size exceed input";
}
if (dshape[3] != 0) {
CHECK_LE(ksize_y, dshape[3] + 2 * param_.pad[1]) << "kernel size exceed input";
}
if (dshape[4] != 0) {
CHECK_LE(ksize_x, dshape[4] + 2 * param_.pad[2]) << "kernel size exceed input";
}
return true;
} else {
LOG(FATAL) << "Unknown convolution type";
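For reference, the back-calculation added above is just the standard convolution output-size formula inverted. A minimal Python sketch of the arithmetic (standalone, not the MXNet API; the function names are illustrative only):

```python
def conv_out_size(in_size, kernel, pad, stride, dilate=1):
    # Forward direction: mirrors the oshape computation in the diff above.
    return (in_size + 2 * pad - (dilate * (kernel - 1) + 1)) // stride + 1

def conv_in_size(out_size, kernel, pad, stride, dilate=1):
    # Backward direction: exact only when stride == 1, because the integer
    # division in the forward formula discards information otherwise. This is
    # why the C++ code only back-fills dshape when the corresponding stride is 1.
    assert stride == 1, "back-calculation is only exact for stride 1"
    return out_size + dilate * (kernel - 1) - 2 * pad

# A 3x3 kernel with pad=1, stride=1 preserves spatial size, so a 32-wide
# output back-calculates to a 32-wide input.
assert conv_out_size(32, kernel=3, pad=1, stride=1) == 32
assert conv_in_size(32, kernel=3, pad=1, stride=1) == 32
```

The same inversion applies per spatial dimension in both the 2-d and 3-d branches; the batch dimension can always be copied straight from the output shape.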
37 changes: 33 additions & 4 deletions src/operator/slice_channel-inl.h
@@ -36,7 +36,9 @@ struct SliceChannelParam : public dmlc::Parameter<SliceChannelParam> {
DMLC_DECLARE_FIELD(axis).set_default(1)
.describe("Dimension along which to slice.");
DMLC_DECLARE_FIELD(squeeze_axis).set_default(0)
.describe("If true AND the sliced dimension becomes 1, squeeze that dimension.");
.describe("If true, the dimension will be squeezed."
" Also, input.shape[axis] must be the same as `num_outputs`"
" when squeeze_axis is turned on.");
}
}; // struct SliceChannelParam

@@ -157,6 +159,7 @@ class SliceChannelProp : public OperatorProperty {
using namespace mshadow;
CHECK_EQ(in_shape->size(), 1);
TShape dshape = in_shape->at(slice_enum::kData);
TShape ishape = in_shape->at(slice_enum::kData);
if (dshape.ndim() == 0) return false;
if (param_.axis >= 0) {
CHECK_LT(static_cast<size_t>(param_.axis), dshape.ndim());
@@ -171,16 +174,42 @@
<< "num_outputs (" << param_.num_outputs
<< ") does not divide input dimension "
<< real_axis << " (" << dshape[real_axis] << ").";
if (param_.squeeze_axis && ishape[real_axis] != 0) {
CHECK_EQ(ishape[real_axis], param_.num_outputs)
<< "If squeeze axis is True, the size of the sliced axis must be the same as num_outputs."
<< " Input shape=" << ishape << ", axis=" << real_axis
<< ", num_outputs=" << param_.num_outputs << ".";
}
dshape[real_axis] /= param_.num_outputs;
if (param_.squeeze_axis && dshape[real_axis] == 1) {
if (param_.squeeze_axis && (dshape[real_axis] == 1 || ishape[real_axis] == 0)) {
for (int d = real_axis; d < static_cast<int>(dshape.ndim()) - 1; ++d) {
dshape[d] = dshape[d+1];
}
dshape = TShape(&dshape[0], &dshape[dshape.ndim()-1]);
}
out_shape->clear();
CHECK_EQ((*out_shape).size(), param_.num_outputs) << "Size of output shape mismatch!";
for (int i = 0; i < param_.num_outputs; ++i) {
out_shape->push_back(dshape);
SHAPE_ASSIGN_CHECK(*out_shape, i, dshape);
// Perform incomplete shape inference.
// We can back-calculate the inshape based on the out_shape.
TShape back_calculate_dshape = ishape;
if (param_.squeeze_axis && (dshape.ndim() == ishape.ndim() - 1)) {
for (int d = 0; d < real_axis; ++d) {
back_calculate_dshape[d] = (*out_shape)[i][d];
}
back_calculate_dshape[real_axis] = param_.num_outputs;
for (int d = real_axis + 1; d < static_cast<int>(ishape.ndim()); ++d) {
back_calculate_dshape[d] = (*out_shape)[i][d - 1];
}
} else {
for (int d = 0; d < static_cast<int>(ishape.ndim()); ++d) {
back_calculate_dshape[d] = (*out_shape)[i][d];
if (d == real_axis) {
back_calculate_dshape[d] *= param_.num_outputs;
}
}
}
SHAPE_ASSIGN_CHECK(*in_shape, slice_enum::kData, back_calculate_dshape);
}
return true;
}
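As a quick illustration of the revised `squeeze_axis` semantics, here is a usage sketch against the same symbol API the tests below use (shapes follow the rules documented in the diff above):

```python
import mxnet as mx

# With squeeze_axis=True, input.shape[axis] must equal num_outputs, so each
# slice has size 1 along `axis` and that axis is removed from the outputs.
a = mx.sym.Variable('a', shape=(4, 6, 8))
outs = mx.sym.SliceChannel(data=a, num_outputs=6, axis=1, squeeze_axis=True)
_, out_shapes, _ = outs.infer_shape()
assert out_shapes[0] == (4, 8)  # axis 1 squeezed away

# With squeeze_axis=False, num_outputs only needs to divide the axis size.
outs = mx.sym.SliceChannel(data=a, num_outputs=3, axis=1, squeeze_axis=False)
_, out_shapes, _ = outs.infer_shape()
assert out_shapes[0] == (4, 2, 8)  # axis of size 6 split into 3 slices of 2
```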
73 changes: 73 additions & 0 deletions tests/python/unittest/test_infer_shape.py
@@ -44,7 +44,80 @@ def test_backward_infer():
for k, v in true_shapes.items():
assert arg_shape_dict[k] == v


def test_incomplete_infer_elewise():
a = mx.sym.Variable('a', shape=(0, 10))
b = mx.sym.Variable('b', shape=(12, 0))
c = a + b
arg_shapes, _, _ = c.infer_shape()
arg_names = c.list_arguments()
arg_shapes = {k: v for k, v in zip(arg_names, arg_shapes)}
assert arg_shapes['a'] == (12, 10)
assert arg_shapes['b'] == (12, 10)


def test_incomplete_infer_mlp():
a = mx.sym.Variable('a', shape=(0, 10))
b = mx.sym.FullyConnected(data=a, num_hidden=21)
c = mx.sym.Variable('c', shape=(5, 0))
d = b + c
arg_shapes, _, _ = d.infer_shape()
arg_names = d.list_arguments()
arg_shapes = {k: v for k, v in zip(arg_names, arg_shapes)}
assert arg_shapes['a'] == (5, 10)
assert arg_shapes['c'] == (5, 21)


def test_incomplete_infer_slicechannel():
a = mx.sym.Variable('a', shape=(0, 10))
b = mx.sym.SliceChannel(data=a, num_outputs=10, axis=1, squeeze_axis=True)
c = mx.sym.Variable('c', shape=(5,))
d = b[1] + c
arg_shapes, _, _ = d.infer_shape()
arg_names = d.list_arguments()
arg_shapes = {k: v for k, v in zip(arg_names, arg_shapes)}
assert arg_shapes['a'] == (5, 10)

a = mx.sym.Variable('a', shape=(0, 15, 0))
b = mx.sym.SliceChannel(data=a, num_outputs=3, squeeze_axis=False)
c = mx.sym.Variable('c', shape=(3, 5, 2))
d = b[1] + c
arg_shapes, _, _ = d.infer_shape()
arg_names = d.list_arguments()
arg_shapes = {k: v for k, v in zip(arg_names, arg_shapes)}
assert arg_shapes['a'] == (3, 15, 2)


def test_incomplete_infer_convolution():
a = mx.sym.Variable('a', shape=(0, 10, 0, 0))
b = mx.sym.Convolution(data=a, num_filter=21, kernel=(3, 3), dilate=(1, 1), pad=(1, 1))
c = mx.sym.Variable('c', shape=(5, 21, 32, 32))
d = b + c
arg_shapes, _, _ = d.infer_shape()
arg_names = d.list_arguments()
arg_shapes = {k: v for k, v in zip(arg_names, arg_shapes)}
assert arg_shapes['a'] == (5, 10, 32, 32)


def test_incomplete_infer_concat():
a = mx.sym.Variable('a', shape=(0, 10))
b = mx.sym.Variable('b', shape=(0, 5))
c = mx.sym.Concat(a, b, num_args=2, dim=1)
d = mx.sym.Variable('d', shape=(2, 0))
d = d + c
arg_shapes, _, _ = d.infer_shape()
arg_names = d.list_arguments()
arg_shapes = {k: v for k, v in zip(arg_names, arg_shapes)}
assert arg_shapes['a'] == (2, 10)
assert arg_shapes['b'] == (2, 5)
assert arg_shapes['d'] == (2, 15)

if __name__ == "__main__":
test_mlp2_infer_shape()
test_mlp2_infer_error()
test_backward_infer()
test_incomplete_infer_elewise()
test_incomplete_infer_mlp()
test_incomplete_infer_slicechannel()
test_incomplete_infer_convolution()
test_incomplete_infer_concat()
