Skip to content

Commit

Permalink
[Frontend][MXNet] Add support for MXNet GroupNorm (apache#7409)
Browse files Browse the repository at this point in the history
* Add support for MXNet GroupNorm

* Fix python lint

* Fix lint
  • Loading branch information
Trevor Morris authored and trevor-m committed Mar 2, 2021
1 parent d7a72d2 commit 48c6099
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
14 changes: 14 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Expand Up @@ -495,6 +495,19 @@ def _mx_layer_norm(inputs, attrs):
return _op.nn.layer_norm(*inputs, **new_attrs)


def _mx_group_norm(inputs, attrs):
assert len(inputs) == 3
if attrs.get_bool("output_mean_var", False):
raise tvm.error.OpAttributeUnimplemented(
'Attribute "output_mean_var" is not supported for operator Group Norm.'
)
new_attrs = {}
new_attrs["axis"] = 1
new_attrs["num_groups"] = attrs.get_int("num_groups", 1)
new_attrs["epsilon"] = attrs.get_float("eps", 1e-5)
return _op.nn.group_norm(*inputs, **new_attrs)


def _mx_slice(inputs, attrs):
new_attrs = {}
begin = list(attrs.get_int_tuple("begin", None))
Expand Down Expand Up @@ -2599,6 +2612,7 @@ def _mx_npi_where_rscalar(inputs, attrs):
"_contrib_SyncBatchNorm": _mx_batch_norm,
"InstanceNorm": _mx_instance_norm,
"LayerNorm": _mx_layer_norm,
"GroupNorm": _mx_group_norm,
"LRN": _mx_lrn,
"L2Normalization": _mx_l2_normalize,
"slice": _mx_slice,
Expand Down
32 changes: 32 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
Expand Up @@ -1263,6 +1263,38 @@ def verify(shape, axis=-1):
verify((2, 5, 6))


@tvm.testing.uses_gpu
def test_forward_group_norm():
def verify(shape, num_groups=1):
x = np.random.uniform(size=shape).astype("float32")
gamma = np.random.uniform(size=(shape[1])).astype("float32")
beta = np.random.uniform(size=(shape[1])).astype("float32")
ref_res = mx.nd.GroupNorm(
data=mx.nd.array(x),
gamma=mx.nd.array(gamma),
beta=mx.nd.array(beta),
num_groups=num_groups,
)
mx_sym = mx.sym.GroupNorm(
mx.sym.var("x"), mx.sym.var("gamma"), mx.sym.var("beta"), num_groups=num_groups
)
shape_dict = {"x": x.shape, "gamma": gamma.shape, "beta": beta.shape}
mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
for target, ctx in tvm.testing.enabled_targets():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x, gamma, beta)
tvm.testing.assert_allclose(
op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5
)

verify((1, 4, 2), num_groups=4)
# TODO(trevmorr): MXNet GroupNorm implementation is bugged for cases when num_groups != num_channels
# https://github.com/apache/incubator-mxnet/pull/18199
# verify((1, 4, 2, 3), num_groups=2)
# verify((1, 4, 2, 3))


@tvm.testing.uses_gpu
def test_forward_one_hot():
def verify(indices_shape, depth, on_value, off_value, dtype):
Expand Down

0 comments on commit 48c6099

Please sign in to comment.