Skip to content

Commit

Permalink
[Relay][Convert Layout] Handling batch norm layout change. (apache#4600)
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 authored and zhiics committed Jan 11, 2020
1 parent 268bcbb commit 60c422d
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 1 deletion.
29 changes: 29 additions & 0 deletions src/relay/op/nn/nn.cc
Expand Up @@ -617,6 +617,34 @@ The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input
// batch_norm
TVM_REGISTER_NODE_TYPE(BatchNormAttrs);

Array<Array<Layout>> BatchNormInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>>& old_in_shapes) {
BatchNormAttrs* param = const_cast<BatchNormAttrs*>(attrs.as<BatchNormAttrs>());

size_t axis =
param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast<size_t>(param->axis);

Layout ret = Layout::Undef();

// If new_in_layouts are defined, this code tries to modify the layout.
if (new_in_layouts.defined() && old_in_layouts.defined()) {
// Get the new C axis. Extract the dim in old layout. Find the index of that dim in next layout.
const auto& bn_dim = old_in_layouts[0][axis];
auto new_index = new_in_layouts[0].IndexOf(bn_dim);
param->axis = new_index;
ret = new_in_layouts[0];
} else if (old_in_layouts.defined()) {
ret = old_in_layouts[0];
}
// BN has 5 inputs, 3 outputs. The last 4 inputs and last 2 outputs have "C" layout.
Layout c_layout = Layout("C");

return Array<Array<Layout>>{{ret, c_layout, c_layout, c_layout, c_layout},
{ret, c_layout, c_layout}};
}

bool BatchNormRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
Expand Down Expand Up @@ -708,6 +736,7 @@ axis to be the last item in the input shape.
.add_argument("beta", "Tensor", "The beta offset factor.")
.add_argument("moving_mean", "Tensor", "Running mean of input.")
.add_argument("moving_var", "Tensor", "Running variance of input.")
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", BatchNormInferCorrectLayout)
.set_support_level(1)
.add_type_rel("BatchNorm", BatchNormRel);

Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/convert_layout.cc
Expand Up @@ -134,7 +134,7 @@ Pass ConvertLayout(const std::string& desired_layout) {
};
return CreateFunctionPass(
pass_func, 3, "ConvertLayout",
{ir::StringImm::make("InferType"), ir::StringImm::make("SimplifyInference"),
{ir::StringImm::make("InferType"),
ir::StringImm::make("CanonicalizeOps")});
}

Expand Down
49 changes: 49 additions & 0 deletions tests/python/relay/test_pass_convert_op_layout.py
Expand Up @@ -349,6 +349,54 @@ def expected():
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_conv_bn_convert_layout():
""" Check that layout transforms are propagated through bn. """
def before():
x = relay.var("x", shape=(1, 56, 56, 64))
weight = relay.var("weight", shape=(3, 3, 64, 64))
y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
data_layout='NHWC', kernel_layout='HWIO')

dtype = "float32"
beta = relay.var("beta", relay.TensorType((64,), dtype))
gamma = relay.var("gamma", relay.TensorType((64,), dtype))
moving_mean = relay.var("moving_mean", relay.TensorType((64,), dtype))
moving_var = relay.var("moving_var", relay.TensorType((64,), dtype))

y = relay.nn.batch_norm(y, gamma, beta, moving_mean, moving_var, axis=3)
y = relay.nn.relu(y[0])
y = relay.Function(analysis.free_vars(y), y)
return y

def expected():
x = relay.var("x", shape=(1, 56, 56, 64))
w = relay.var("weight", shape=(3, 3, 64, 64))
x = relay.layout_transform(x, 'NHWC', 'NCHW')
w = relay.layout_transform(w, 'HWIO', 'OIHW')
y = relay.nn.conv2d(x, w,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))

dtype = "float32"
beta = relay.var("beta", relay.TensorType((64,), dtype))
gamma = relay.var("gamma", relay.TensorType((64,), dtype))
moving_mean = relay.var("moving_mean", relay.TensorType((64,), dtype))
moving_var = relay.var("moving_var", relay.TensorType((64,), dtype))

y = relay.nn.batch_norm(y, gamma, beta, moving_mean, moving_var, axis=1)
y = relay.nn.relu(y[0])
y = relay.layout_transform(y, "NCHW", "NHWC")
y = relay.Function(analysis.free_vars(y), y)
return y

a = before()
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())

assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


if __name__ == "__main__":
test_no_convert_layout()
test_conv_convert_layout()
Expand All @@ -358,3 +406,4 @@ def expected():
test_bn_convert_layout()
test_resnet_convert_layout()
test_scalar_convert_layout()
test_conv_bn_convert_layout()

0 comments on commit 60c422d

Please sign in to comment.