From 47447c1cd74dae18cdcb47cea237d50fabff833d Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Thu, 30 Mar 2023 19:43:54 +0800
Subject: [PATCH 01/20] update

---
 .../function_libs/torch_aten/ops/core.py      | 18 ++++++++++--------
 .../torch_aten/ops_correctness_test.py        | 18 ++++++++++++++++++
 2 files changed, 28 insertions(+), 8 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 5cd22079b9..c1ec01482a 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4173,19 +4173,21 @@ def aten_native_dropout_backward(
     raise NotImplementedError()
 
 
+@torch_op("aten::native_group_norm", trace_only=True)
 def aten_native_group_norm(
-    input: TensorType,
+    input: TReal,
     weight: Optional[TensorType],
     bias: Optional[TensorType],
-    N: INT64,
-    C: INT64,
-    HxW: INT64,
-    group: int,
-    eps: float,
-) -> tuple[TensorType, TensorType, TensorType]:
+    N: INT64 = None,
+    C: INT64 = None,
+    HxW: INT64 = None,
+    group: int = None,
+    eps: float = None,
+) -> tuple[TReal, TReal, TReal]:
     """native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)"""
 
-    raise NotImplementedError()
+    result = op.GroupNormalization(input, weight, bias, epsilon=eps, num_groups=group)
+    return result
 
 
 def aten_native_group_norm_backward(
diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
index c916283099..ebeebf30e1 100644
--- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -314,6 +314,17 @@ def _mse_loss_input_wrangler(
     return args, kwargs
 
 
+def _native_group_norm_input_wrangler(
+    args: list[Any], kwargs: dict[str, Any]
+) -> tuple[list[Any], dict[str, Any]]:
+    kwargs["group"] = args.pop(1)
+    args.append(kwargs["weight"])
+    args.append(kwargs["bias"])
+    del kwargs["weight"]
+    del kwargs["bias"]
+    return args, kwargs
+
+
 def _nll_loss_input_wrangler(
     args: list[Any], kwargs: dict[str, Any]
 ) -> tuple[list[Any], dict[str, Any]]:
@@ -498,6 +509,7 @@ def _where_input_wrangler(
     "mul": core_ops.aten_mul,
     "narrow": core_ops.aten_narrow,
     # "native_dropout": core_ops.aten_native_dropout,  # native_dropout is not in OPS_DB
+    "native_group_norm": (core_ops.aten_native_group_norm, _native_group_norm_input_wrangler),
     "ne": core_ops.aten_ne,
     "neg": core_ops.aten_neg,
     "new_full": core_ops.aten_new_full,
@@ -1006,6 +1018,8 @@ def _where_input_wrangler(
 
 duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",))
 
+duplicate_opinfo(OPS_DB, "nn.functional.group_norm", ("native_group_norm",))
+
 duplicate_opinfo(OPS_DB, "new_ones", ("new_ones_dtype",))
 duplicate_opinfo(OPS_DB, "new_zeros", ("new_zeros_dtype",))
 
@@ -1348,6 +1362,10 @@ def run_test_output_match(
             ),
             kwargs=repr(cpu_sample.kwargs),
         ):
+            if i == 0:
+                print(i)
+            else:
+                continue
             skip_reason = _should_skip_test_sample(op.name, cpu_sample)
             if skip_reason is not None:
                 # Cannot use self.skip because pytest would skip the entire test

From 081c55cc9c91c9a7e9cc2b2abe2327621e92b07b Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Thu, 30 Mar 2023 19:55:23 +0800
Subject: [PATCH 02/20] Update core.py

---
 onnxscript/function_libs/torch_aten/ops/core.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index c1ec01482a..cba3eaf286 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4186,6 +4186,9 @@ def aten_native_group_norm(
 ) -> tuple[TReal, TReal, TReal]:
     """native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)"""
 
+    # input_len = len(input.shape)
+    # if input_len == 3:
+    #     input = op.Unsqueeze(input, axes=0)
     result = op.GroupNormalization(input, weight, bias, epsilon=eps, num_groups=group)
     return result
 

From f056fd7869a7f41256d9dffdc825c56569d111f5 Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Mon, 3 Apr 2023 12:17:12 +0800
Subject: [PATCH 03/20] Update ops_correctness_test.py

---
 .../tests/function_libs/torch_aten/ops_correctness_test.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
index ebeebf30e1..b5aa40c1f9 100644
--- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -317,9 +317,9 @@ def _mse_loss_input_wrangler(
 def _native_group_norm_input_wrangler(
     args: list[Any], kwargs: dict[str, Any]
 ) -> tuple[list[Any], dict[str, Any]]:
-    kwargs["group"] = args.pop(1)
-    args.append(kwargs["weight"])
-    args.append(kwargs["bias"])
+    kwargs["group"] = args.pop(1)  # move group(int) to kwargs as attribute
+    args.append(kwargs["weight"])  # move weight(tensor) to args as input
+    args.append(kwargs["bias"])  # move bias(tensor) to args as input
     del kwargs["weight"]
     del kwargs["bias"]
     return args, kwargs
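Note on the wrangler added in PATCH 01 and commented in PATCH 03: it only reshuffles the OpInfo sample for nn.functional.group_norm so that the integer num_groups positional argument becomes the group attribute and the weight/bias keyword tensors become positional graph inputs. A minimal sketch of that reshuffling on one hypothetical sample; the tensors and shapes below are invented for illustration and are not taken from the test suite:

    import torch

    x = torch.randn(2, 6, 4)              # hypothetical sample input with C=6 channels
    w, b = torch.ones(6), torch.zeros(6)

    args = [x, 3]                         # [input, num_groups], as the OpInfo yields them
    kwargs = {"weight": w, "bias": b}

    # Same steps as the wrangler: group -> attribute, weight/bias -> graph inputs
    kwargs["group"] = args.pop(1)
    args.append(kwargs.pop("weight"))
    args.append(kwargs.pop("bias"))

    assert kwargs == {"group": 3}
    assert [a.shape for a in args] == [x.shape, w.shape, b.shape]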
From 580218871c00b57aaa013a8b6c8c62d2ea748f95 Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Mon, 17 Apr 2023 17:00:52 +0800
Subject: [PATCH 04/20] update

---
 .../function_libs/torch_aten/ops/core.py      | 35 +++++++++++++++----
 .../torch_aten/ops_correctness_test.py        |  9 ++---
 2 files changed, 33 insertions(+), 11 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index cba3eaf286..e8e6125b24 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4176,8 +4176,8 @@ def aten_native_dropout_backward(
 @torch_op("aten::native_group_norm", trace_only=True)
 def aten_native_group_norm(
     input: TReal,
-    weight: Optional[TensorType],
-    bias: Optional[TensorType],
+    weight: Optional[TReal],
+    bias: Optional[TReal],
     N: INT64 = None,
     C: INT64 = None,
     HxW: INT64 = None,
@@ -4186,11 +4186,32 @@ def aten_native_group_norm(
 ) -> tuple[TReal, TReal, TReal]:
     """native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)"""
 
-    # input_len = len(input.shape)
-    # if input_len == 3:
-    #     input = op.Unsqueeze(input, axes=0)
-    result = op.GroupNormalization(input, weight, bias, epsilon=eps, num_groups=group)
-    return result
+    # Create weight_instance_norm and bias_instance_norm
+    weight_inst = op.Constant(value_floats=[1.0] * group)
+    bias_inst = op.Constant(value_floats=[0.0] * group)
+    # 0 in the shape list keeps dimension value unchanged, for InstanceNorm need
+    shape = op.Constant(value_ints=[0, group, -1])
+
+    return _aten_native_group_norm_onnx(input, weight, bias, weight_inst, bias_inst, shape, eps)
+
+
+@torch_op("aten::native_group_norm", private=True)
+def _aten_native_group_norm_onnx(
+    input: TReal,
+    weight: Optional[TReal],
+    bias: Optional[TReal],
+    w1: TReal,
+    b1: TReal,
+    shape: INT64,
+    eps: float = None,
+) -> TReal:  # We can only return one TReal instead of [x,y,z]
+    input_reshaped = op.Reshape(input, shape)
+    norm_reshaped = op.InstanceNormalization(input_reshaped, w1, b1, epsilon=eps)
+    norm = op.Reshape(norm_reshaped, op.Shape(input))
+    input_rank = op.Size(op.Shape(input))
+    axes = op.Range(1, input_rank - 1, 1)
+    # Using the real weight and bias to computer again
+    return op.Add(op.Mul(norm, op.Unsqueeze(weight, axes)), op.Unsqueeze(bias, axes))
 
 
 def aten_native_group_norm_backward(
diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
index b5aa40c1f9..011880c842 100644
--- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -850,6 +850,11 @@ def _where_input_wrangler(
         or (len(sample.args) > 0 and not isinstance(sample.args[0], int)),
         reason="this ATen overload only support one tensor as input and another int as args",
     ),
+    skip(
+        "native_group_norm",
+        matcher=lambda sample: len(sample.input.shape) == 2,
+        reason="ONNX only support input shape >= 3",
+    ),
     skip(
         "new_ones",
         matcher=lambda sample: sample.kwargs.get("dtype") is not None,
@@ -1362,10 +1367,6 @@ def run_test_output_match(
             ),
             kwargs=repr(cpu_sample.kwargs),
         ):
-            if i == 0:
-                print(i)
-            else:
-                continue
             skip_reason = _should_skip_test_sample(op.name, cpu_sample)
             if skip_reason is not None:
                 # Cannot use self.skip because pytest would skip the entire test

From 4fb749d80d0142904d16b785e7e4a91861c4382a Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Mon, 17 Apr 2023 17:01:40 +0800
Subject: [PATCH 05/20] Update core.py

---
 onnxscript/function_libs/torch_aten/ops/core.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index e8e6125b24..d87615e76f 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4180,7 +4180,7 @@ def aten_native_group_norm(
     bias: Optional[TReal],
     N: INT64 = None,
     C: INT64 = None,
-    HxW: INT64 = None,
+    HxW: INT64 = None,  # pylint: disable=unused-argument
     group: int = None,
     eps: float = None,
@@ -4192,7 +4192,9 @@ def aten_native_group_norm(
     # 0 in the shape list keeps dimension value unchanged, for InstanceNorm need
     shape = op.Constant(value_ints=[0, group, -1])
 
-    return _aten_native_group_norm_onnx(input, weight, bias, weight_inst, bias_inst, shape, eps)
+    return _aten_native_group_norm_onnx(
+        input, weight, bias, weight_inst, bias_inst, shape, eps
+    )
 
 
 @torch_op("aten::native_group_norm", private=True)

From 4f49db6b4bd35e487510cc445d53ed78abfb92a0 Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Mon, 17 Apr 2023 17:05:16 +0800
Subject: [PATCH 06/20] Update core.py

---
 onnxscript/function_libs/torch_aten/ops/core.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index d87615e76f..f165499875 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4206,7 +4206,10 @@ def _aten_native_group_norm_onnx(
     b1: TReal,
     shape: INT64,
     eps: float = None,
-) -> TReal:  # We can only return one TReal instead of [x,y,z]
+# FIXME: We can only return one TReal instead of [x,y,z]
+# Because we don't how to computer the running_var and running_mean
+# No native_group_norm test case, and the group_norm function in torch only return one output
+) -> TReal:
     input_reshaped = op.Reshape(input, shape)
     norm_reshaped = op.InstanceNormalization(input_reshaped, w1, b1, epsilon=eps)
     norm = op.Reshape(norm_reshaped, op.Shape(input))

From 5500bd112a8792f47ccce6eec366bb18c16cd6c7 Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Mon, 17 Apr 2023 17:05:48 +0800
Subject: [PATCH 07/20] Update core.py

---
 onnxscript/function_libs/torch_aten/ops/core.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index f165499875..a9836f02ba 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4183,7 +4183,10 @@ def aten_native_group_norm(
     HxW: INT64 = None,  # pylint: disable=unused-argument
     group: int = None,
     eps: float = None,
-) -> tuple[TReal, TReal, TReal]:
+# FIXME: We can only return one TReal instead of [x,y,z]
+# Because we don't how to computer the running_var and running_mean
+# No native_group_norm test case, and the group_norm function in torch only return one output
+) -> TReal:
     """native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)"""
 
     # Create weight_instance_norm and bias_instance_norm
@@ -4206,9 +4209,6 @@ def _aten_native_group_norm_onnx(
     b1: TReal,
     shape: INT64,
     eps: float = None,
-# FIXME: We can only return one TReal instead of [x,y,z]
-# Because we don't how to computer the running_var and running_mean
-# No native_group_norm test case, and the group_norm function in torch only return one output
 ) -> TReal:
     input_reshaped = op.Reshape(input, shape)
     norm_reshaped = op.InstanceNormalization(input_reshaped, w1, b1, epsilon=eps)
     norm = op.Reshape(norm_reshaped, op.Shape(input))
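Note on PATCH 04: the single GroupNormalization call is replaced by InstanceNormalization over a [0, group, -1] reshape plus a per-channel affine, because aten::native_group_norm carries weight/bias of length C while the ONNX normalization op expects one scale/bias per normalized channel (here, per group). The numpy sketch below is not part of the PR; it only checks, under an assumed small shape, that this reshape trick reproduces per-group normalization:

    import numpy as np

    N, C, HxW, group, eps = 2, 6, 5, 3, 1e-5
    x = np.random.rand(N, C, HxW).astype(np.float32)
    weight = np.random.rand(C).astype(np.float32)   # per-channel, like the ATen op
    bias = np.random.rand(C).astype(np.float32)

    # InstanceNormalization on the [N, group, -1] reshape == per-group normalization
    xg = x.reshape(N, group, -1)
    norm = (xg - xg.mean(-1, keepdims=True)) / np.sqrt(xg.var(-1, keepdims=True) + eps)
    norm = norm.reshape(N, C, HxW)

    # Apply the per-channel affine afterwards, as the ONNX function does
    y = norm * weight[:, None] + bias[:, None]

    # Reference: group norm computed directly per (sample, group) slice
    ref = np.empty_like(x)
    cpg = C // group                                # channels per group
    for n in range(N):
        for g in range(group):
            sl = x[n, g * cpg : (g + 1) * cpg]
            ref[n, g * cpg : (g + 1) * cpg] = (sl - sl.mean()) / np.sqrt(sl.var() + eps)
    ref = ref * weight[:, None] + bias[:, None]

    np.testing.assert_allclose(y, ref, rtol=1e-4, atol=1e-5)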
From 264de391fac36a9adf5fc2d67319faae63dc67a3 Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Mon, 17 Apr 2023 17:39:34 +0800
Subject: [PATCH 08/20] update

---
 onnxscript/function_libs/torch_aten/ops/core.py       | 11 ++++++-----
 .../function_libs/torch_aten/ops_correctness_test.py  |  2 +-
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 446d97cae3..4cf2336aab 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4213,15 +4213,16 @@ def aten_native_group_norm(
     input: TReal,
     weight: Optional[TReal],
     bias: Optional[TReal],
-    N: INT64 = None,
-    C: INT64 = None,
+    N: INT64 = None,  # pylint: disable=unused-argument
+    C: INT64 = None,  # pylint: disable=unused-argument
     HxW: INT64 = None,  # pylint: disable=unused-argument
     group: int = None,
     eps: float = None,
-# FIXME: We can only return one TReal instead of [x,y,z]
-# Because we don't how to computer the running_var and running_mean
-# No native_group_norm test case, and the group_norm function in torch only return one output
 ) -> TReal:
+    # FIXME: for the return, we can only return one TReal instead of [x,y,z]
+    # Because we don't how to computer the running_var and running_mean
+    # No native_group_norm test case, and the group_norm function in torch only return one output
+
     """native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)"""
 
     # Create weight_instance_norm and bias_instance_norm
diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
index c6755b56c2..8a96b7dcbd 100644
--- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -602,7 +602,6 @@ def _where_input_wrangler(
     "mul": core_ops.aten_mul,
     "narrow": core_ops.aten_narrow,
     # "native_dropout": core_ops.aten_native_dropout,  # native_dropout is not in OPS_DB
-    "native_group_norm": (core_ops.aten_native_group_norm, _native_group_norm_input_wrangler),
     "ne": core_ops.aten_ne,
     "neg": core_ops.aten_neg,
     "new_full": core_ops.aten_new_full,
@@ -713,6 +712,7 @@ def _where_input_wrangler(
     "index_select": core_ops.aten_index_select,
     "layer_norm": core_ops.aten_layer_norm,
     "max": core_ops.aten_max,
+    "native_group_norm": (core_ops.aten_native_group_norm, _native_group_norm_input_wrangler),
     "native_layer_norm": core_ops.aten_native_layer_norm,
     "new_empty": core_ops.aten_new_empty,
     "new_empty_strided": core_ops.aten_new_empty_strided,

From 4761f0d0042231174a9ee09822364553da9b5d8c Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Mon, 17 Apr 2023 18:27:22 +0800
Subject: [PATCH 09/20] Update core.py

---
 onnxscript/function_libs/torch_aten/ops/core.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 4cf2336aab..0759119a14 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4222,7 +4222,6 @@ def aten_native_group_norm(
     # FIXME: for the return, we can only return one TReal instead of [x,y,z]
     # Because we don't how to computer the running_var and running_mean
     # No native_group_norm test case, and the group_norm function in torch only return one output
-
     """native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)"""
 
     # Create weight_instance_norm and bias_instance_norm

From 27dce7eb19b24a758abbfec9ba222796edd536e2 Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Mon, 17 Apr 2023 18:32:14 +0800
Subject: [PATCH 10/20] Update core.py

---
 .../function_libs/torch_aten/ops/core.py      | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 0759119a14..a0f11441c3 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4210,15 +4210,15 @@ def aten_native_dropout_backward(
 
 @torch_op("aten::native_group_norm", trace_only=True)
 def aten_native_group_norm(
-    input: TReal,
-    weight: Optional[TReal],
-    bias: Optional[TReal],
+    input: TFloat,
+    weight: Optional[TFloat],
+    bias: Optional[TFloat],
     N: INT64 = None,  # pylint: disable=unused-argument
     C: INT64 = None,  # pylint: disable=unused-argument
     HxW: INT64 = None,  # pylint: disable=unused-argument
     group: int = None,
     eps: float = None,
-) -> TReal:
+) -> TFloat:
     # FIXME: for the return, we can only return one TReal instead of [x,y,z]
     # Because we don't how to computer the running_var and running_mean
     # No native_group_norm test case, and the group_norm function in torch only return one output
@@ -4237,16 +4237,16 @@ def aten_native_group_norm(
 
 @torch_op("aten::native_group_norm", private=True)
 def _aten_native_group_norm_onnx(
-    input: TReal,
-    weight: Optional[TReal],
-    bias: Optional[TReal],
-    w1: TReal,
-    b1: TReal,
+    input: TFloat,
+    weight: Optional[TFloat],
+    bias: Optional[TFloat],
+    weight_inst: TFloat,
+    bias_inst: TFloat,
     shape: INT64,
     eps: float = None,
 ) -> TReal:
     input_reshaped = op.Reshape(input, shape)
-    norm_reshaped = op.InstanceNormalization(input_reshaped, w1, b1, epsilon=eps)
+    norm_reshaped = op.InstanceNormalization(input_reshaped, weight_inst, bias_inst, epsilon=eps)
     norm = op.Reshape(norm_reshaped, op.Shape(input))
     input_rank = op.Size(op.Shape(input))
     axes = op.Range(1, input_rank - 1, 1)
     # Using the real weight and bias to computer again
     return op.Add(op.Mul(norm, op.Unsqueeze(weight, axes)), op.Unsqueeze(bias, axes))
From f40f0d393e34144a15d497930a121fa67ed46539 Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Tue, 18 Apr 2023 18:29:36 +0800
Subject: [PATCH 11/20] update

---
 .../function_libs/torch_aten/ops/core.py      | 24 ++++++++---
 .../function_libs/torch_aten/extra_opinfo.py  | 40 +++++++++++++++++++
 .../torch_aten/ops_correctness_test.py        | 15 +------
 3 files changed, 60 insertions(+), 19 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index a0f11441c3..7075ff4fbf 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4218,21 +4218,21 @@ def aten_native_group_norm(
     HxW: INT64 = None,  # pylint: disable=unused-argument
     group: int = None,
     eps: float = None,
-) -> TFloat:
-    # FIXME: for the return, we can only return one TReal instead of [x,y,z]
-    # Because we don't how to computer the running_var and running_mean
-    # No native_group_norm test case, and the group_norm function in torch only return one output
+) -> Tuple[TFloat, TFloat, TFloat]:
     """native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)"""
 
+    # Assert(weight is not None, and, bias is not None)
     # Create weight_instance_norm and bias_instance_norm
     weight_inst = op.Constant(value_floats=[1.0] * group)
     bias_inst = op.Constant(value_floats=[0.0] * group)
     # 0 in the shape list keeps dimension value unchanged, for InstanceNorm need
     shape = op.Constant(value_ints=[0, group, -1])
 
-    return _aten_native_group_norm_onnx(
+    norm = _aten_native_group_norm_onnx(
         input, weight, bias, weight_inst, bias_inst, shape, eps
     )
+    # weight_inst, bias_inst are fake output, because we must return 3 outputs
+    return norm, weight_inst, bias_inst
 
 
 @torch_op("aten::native_group_norm", private=True)
@@ -4245,6 +4245,8 @@ def _aten_native_group_norm_onnx(
     shape: INT64,
     eps: float = None,
 ) -> TReal:
+    # Using InstanceNorm to simulate GroupNorm, because GroupNorm need weight[group] and bias[group]
+    # But the input is weight[channel] and bias[channel]
     input_reshaped = op.Reshape(input, shape)
     norm_reshaped = op.InstanceNormalization(input_reshaped, weight_inst, bias_inst, epsilon=eps)
     norm = op.Reshape(norm_reshaped, op.Shape(input))
@@ -4254,6 +4256,18 @@ def _aten_native_group_norm_onnx(
     return op.Add(op.Mul(norm, op.Unsqueeze(weight, axes)), op.Unsqueeze(bias, axes))
 
 
+# def test_aten_native_group_norm():
+#     import numpy as np
+#     input = (np.arange(24).reshape(2,4,3)*1.0 + 1.0).astype(np.float32)
+#     weight = (np.ones((4,)) * 1.0).astype(np.float32)
+#     bias = (np.zeros((4,)) * 1.0).astype(np.float32)
+#     # import torch as t
+#     # r = t.ops.aten.native_batch_norm(input, weight, bias, run_m, run_b, True, 0.5, 0.1)
+#     r = aten_native_group_norm(input, weight, bias, None, None, None, 2, 0.0)
+#     print(r)
+# test_aten_native_group_norm()
+# exit(0)
+
 def aten_native_group_norm_backward(
     grad_out: TensorType,
     input: TensorType,
diff --git a/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py b/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py
index 5aec4b48dc..f48c506451 100644
--- a/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py
@@ -4,6 +4,7 @@
 """
 
 import functools
+import itertools
 from typing import Any, List
 
 import torch
@@ -227,6 +228,37 @@ def sample_inputs_max_pool3d_with_indices(op_info, device, dtype, requires_grad,
     yield opinfo_core.SampleInput(arg, kwargs=kwargs)
 
 
+def sample_inputs_native_group_norm(op_info, device, dtype, requires_grad, **kwargs):
+    del op_info
+    make_arg = functools.partial(torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as input shape, num groups, and kwargs for eps
+    cases = (
+        ((1, 6, 3), (6,), (6,), 1, 6, 3, {'group': 2, 'eps' : 0.5}),
+        # ((2, 6, 3), 2, {'group': 2, 'eps' : -0.5}),
+        # ((1, 3), 1, {'group': 2, 'eps' : 1e-5}),
+        # ((0, 2), 1, {'group': 2, 'eps' : 1e-5}),
+        # ((5, 5, 5), 1, {'group': 2, 'eps' : 0.5}),
+    )
+
+    for input_shape, weight, bias, N, C, HxW, kwargs in cases:
+        # args: running mean, running var, weight and bias should necessarily be of shape: (channels,)
+        channels = input_shape[1] if len(input_shape) > 1 else 0
+        weight = make_arg(channels) if channels > 0 else None
+        bias = make_arg(channels) if channels > 0 else None
+
+        yield opinfo_core.SampleInput(
+            make_arg(input_shape),
+            args=(
+                weight,
+                bias,
+                N,
+                C,
+                HxW,
+            ),
+            kwargs=kwargs
+        )
+
 def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs):
     del op_info
     # input_shape, output_size, kernal, dilation, padding, stride
@@ -316,6 +348,14 @@ def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs):
         skips=(),
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "native_group_norm",
+        op=torch.ops.aten.native_group_norm,
+        aten_name="native_group_norm",
+        dtypes=common_dtype.floating_and_complex_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_native_group_norm,
+        supports_out=False,
+    ),
     opinfo_core.OpInfo(
         "nn.functional.conv3d",
         aliases=("conv3d",),
diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
index 8a96b7dcbd..79433b2f51 100644
--- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -379,17 +379,6 @@ def _mse_loss_input_wrangler(
     return args, kwargs
 
 
-def _native_group_norm_input_wrangler(
-    args: list[Any], kwargs: dict[str, Any]
-) -> tuple[list[Any], dict[str, Any]]:
-    kwargs["group"] = args.pop(1)  # move group(int) to kwargs as attribute
-    args.append(kwargs["weight"])  # move weight(tensor) to args as input
-    args.append(kwargs["bias"])  # move bias(tensor) to args as input
-    del kwargs["weight"]
-    del kwargs["bias"]
-    return args, kwargs
-
-
 def _nll_loss_input_wrangler(
     args: list[Any], kwargs: dict[str, Any]
 ) -> tuple[list[Any], dict[str, Any]]:
@@ -712,7 +701,7 @@ def _where_input_wrangler(
     "index_select": core_ops.aten_index_select,
     "layer_norm": core_ops.aten_layer_norm,
     "max": core_ops.aten_max,
-    "native_group_norm": (core_ops.aten_native_group_norm, _native_group_norm_input_wrangler),
+    "native_group_norm": core_ops.aten_native_group_norm,
     "native_layer_norm": core_ops.aten_native_layer_norm,
     "new_empty": core_ops.aten_new_empty,
     "new_empty_strided": core_ops.aten_new_empty_strided,
@@ -1230,8 +1219,6 @@ def _where_input_wrangler(
 
 duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",))
 
-duplicate_opinfo(OPS_DB, "nn.functional.group_norm", ("native_group_norm",))
-
 duplicate_opinfo(OPS_DB, "new_ones", ("new_ones_dtype",))
 duplicate_opinfo(OPS_DB, "new_zeros", ("new_zeros_dtype",))

From f90e0953b311c7ccbaa7cbda70c6816bce6c5ac6 Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Tue, 18 Apr 2023 18:51:18 +0800
Subject: [PATCH 12/20] update

---
 .../function_libs/torch_aten/ops/core.py      | 19 ++++++-------------
 .../function_libs/torch_aten/extra_opinfo.py  | 14 ++++++--------
 .../torch_aten/ops_correctness_test.py        | 10 ++++++++++
 3 files changed, 22 insertions(+), 21 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 6269f35574..6eb3bd1f7f 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4248,7 +4248,12 @@ def _aten_native_group_norm_onnx(
     # Using InstanceNorm to simulate GroupNorm, because GroupNorm need weight[group] and bias[group]
     # But the input is weight[channel] and bias[channel]
     input_reshaped = op.Reshape(input, shape)
-    norm_reshaped = op.InstanceNormalization(input_reshaped, weight_inst, bias_inst, epsilon=eps)
+    norm_reshaped = op.InstanceNormalization(
+        input_reshaped,
+        weight_inst,
+        bias_inst,
+        epsilon=eps
+    )
     norm = op.Reshape(norm_reshaped, op.Shape(input))
     input_rank = op.Size(op.Shape(input))
     axes = op.Range(1, input_rank - 1, 1)
@@ -4256,18 +4261,6 @@ def _aten_native_group_norm_onnx(
     return op.Add(op.Mul(norm, op.Unsqueeze(weight, axes)), op.Unsqueeze(bias, axes))
 
 
-# def test_aten_native_group_norm():
-#     import numpy as np
-#     input = (np.arange(24).reshape(2,4,3)*1.0 + 1.0).astype(np.float32)
-#     weight = (np.ones((4,)) * 1.0).astype(np.float32)
-#     bias = (np.zeros((4,)) * 1.0).astype(np.float32)
-#     # import torch as t
-#     # r = t.ops.aten.native_batch_norm(input, weight, bias, run_m, run_b, True, 0.5, 0.1)
-#     r = aten_native_group_norm(input, weight, bias, None, None, None, 2, 0.0)
-#     print(r)
-# test_aten_native_group_norm()
-# exit(0)
-
 def aten_native_group_norm_backward(
     grad_out: TensorType,
     input: TensorType,
diff --git a/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py b/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py
index f48c506451..07f677ae0b 100644
--- a/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py
@@ -4,7 +4,6 @@
 """
 
 import functools
-import itertools
 from typing import Any, List
 
 import torch
@@ -232,13 +231,12 @@ def sample_inputs_native_group_norm(op_info, device, dtype, requires_grad, **kwa
     del op_info
     make_arg = functools.partial(torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
 
-    # Ordered as input shape, num groups, and kwargs for eps
+    # Ordered as input shape, C,N,HxW, and kwargs for group and eps
     cases = (
-        ((1, 6, 3), (6,), (6,), 1, 6, 3, {'group': 2, 'eps' : 0.5}),
-        # ((2, 6, 3), 2, {'group': 2, 'eps' : -0.5}),
-        # ((1, 3), 1, {'group': 2, 'eps' : 1e-5}),
-        # ((0, 2), 1, {'group': 2, 'eps' : 1e-5}),
-        # ((5, 5, 5), 1, {'group': 2, 'eps' : 0.5}),
+        ((1, 6, 3), (6,), (6,), 1, 6, 3, {"group": 2, "eps": 0.5}),
+        ((2, 6, 3), (6,), (6,), 2, 6, 3, {"group": 3, "eps": -0.5}),
+        ((5, 5, 5), (5,), (5,), 5, 5, 5, {"group": 1, "eps": 1e-5}),
+        ((5, 8, 10), (8,), (8,), 5, 8, 10, {"group": 4, "eps": 1e-5}),
     )
 
     for input_shape, weight, bias, N, C, HxW, kwargs in cases:
         # args: running mean, running var, weight and bias should necessarily be of shape: (channels,)
         channels = input_shape[1] if len(input_shape) > 1 else 0
         weight = make_arg(channels) if channels > 0 else None
         bias = make_arg(channels) if channels > 0 else None
 
         yield opinfo_core.SampleInput(
             make_arg(input_shape),
             args=(
                 weight,
                 bias,
                 N,
                 C,
                 HxW,
             ),
-            kwargs=kwargs
+            kwargs=kwargs,
         )
 
 def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs):
     del op_info
     # input_shape, output_size, kernal, dilation, padding, stride
diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
index 79433b2f51..d74058051e 100644
--- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -811,6 +811,16 @@ def _where_input_wrangler(
         test_class_name="TestOutputConsistencyFullGraph",
         enabled_if=version_utils.onnxruntime_older_than("1.15"),
     ),
+    xfail(
+        "native_group_norm",
+        reason="fixme: ONNX InstanceNorm only return 1 output, but Torch need 3",
+        test_class_name="TestOutputConsistencyFullGraph",
+    ),
+    xfail(
+        "native_group_norm",
+        reason="fixme: ONNX InstanceNorm only return 1 output, but Torch need 3",
+        test_class_name="TestOutputConsistencyEager",
+    ),
     xfail(
         "new_ones",
         reason="fixme: ORT fails with invalid model: 'ONNX Schema aten_new_full: failed validating the check: !(it.GetName().empty())'",
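Note on PATCH 11/12: the new OpInfo feeds samples such as the (1, 6, 3) case directly into torch.ops.aten.native_group_norm, which returns three tensors; the xfail entries exist because the ONNX decomposition at this point still produces only the normalized output. A short sketch of what one such sample returns, assuming torch.ops.aten.native_group_norm is callable from Python in the installed PyTorch build:

    import torch

    x = torch.randn(1, 6, 3)
    weight, bias = torch.ones(6), torch.zeros(6)
    N, C, HxW, group, eps = 1, 6, 3, 2, 0.5

    out, mean, rstd = torch.ops.aten.native_group_norm(x, weight, bias, N, C, HxW, group, eps)
    print(out.shape, mean.shape, rstd.shape)  # expected: (1, 6, 3), (1, 2), (1, 2)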
From e62482208683b3c2ffce95022404f8dc647fd688 Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Tue, 18 Apr 2023 18:58:56 +0800
Subject: [PATCH 13/20] update

---
 onnxscript/function_libs/torch_aten/ops/core.py           | 5 +----
 onnxscript/tests/function_libs/torch_aten/extra_opinfo.py | 4 +++-
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 6eb3bd1f7f..befbe64e1b 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4249,10 +4249,7 @@ def _aten_native_group_norm_onnx(
     # But the input is weight[channel] and bias[channel]
     input_reshaped = op.Reshape(input, shape)
     norm_reshaped = op.InstanceNormalization(
-        input_reshaped,
-        weight_inst,
-        bias_inst,
-        epsilon=eps
+        input_reshaped, weight_inst, bias_inst, epsilon=eps
     )
     norm = op.Reshape(norm_reshaped, op.Shape(input))
     input_rank = op.Size(op.Shape(input))
diff --git a/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py b/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py
index 07f677ae0b..1ce04b2754 100644
--- a/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py
@@ -229,7 +229,9 @@ def sample_inputs_max_pool3d_with_indices(op_info, device, dtype, requires_grad,
 
 def sample_inputs_native_group_norm(op_info, device, dtype, requires_grad, **kwargs):
     del op_info
-    make_arg = functools.partial(torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_arg = functools.partial(
+        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )

From 497659a99023d9d7cb8fcd3d53b96740905a0876 Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Tue, 18 Apr 2023 19:04:11 +0800
Subject: [PATCH 14/20] Update extra_opinfo.py

---
 onnxscript/tests/function_libs/torch_aten/extra_opinfo.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py b/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py
index 1ce04b2754..eeed97bf34 100644
--- a/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py
@@ -259,6 +259,7 @@ def sample_inputs_native_group_norm(op_info, device, dtype, requires_grad, **kwa
             kwargs=kwargs,
         )
 
+
 def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs):
     del op_info
     # input_shape, output_size, kernal, dilation, padding, stride

From c83905c767a8c9389308e0dfbafd485482b2dfbf Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Wed, 19 Apr 2023 08:39:28 +0800
Subject: [PATCH 15/20] update

---
 .../function_libs/torch_aten/ops/core.py      | 72 +++++++++++--------
 .../torch_aten/ops_correctness_test.py        |  5 --
 2 files changed, 44 insertions(+), 33 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index befbe64e1b..9670fcd528 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4213,49 +4213,65 @@ def aten_native_group_norm(
     input: TFloat,
     weight: Optional[TFloat],
     bias: Optional[TFloat],
-    N: INT64 = None,  # pylint: disable=unused-argument
-    C: INT64 = None,  # pylint: disable=unused-argument
-    HxW: INT64 = None,  # pylint: disable=unused-argument
-    group: int = None,
-    eps: float = None,
+    N: Optional[INT64] = None,  # pylint: disable=unused-argument
+    C: Optional[INT64] = None,  # pylint: disable=unused-argument
+    HxW: Optional[INT64] = None,  # pylint: disable=unused-argument
+    group: Optional[int] = None,
+    eps: Optional[float] = 1e-05,
 ) -> Tuple[TFloat, TFloat, TFloat]:
     """native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)"""
 
-    # Assert(weight is not None, and, bias is not None)
-    # Create weight_instance_norm and bias_instance_norm
-    weight_inst = op.Constant(value_floats=[1.0] * group)
-    bias_inst = op.Constant(value_floats=[0.0] * group)
-    # 0 in the shape list keeps dimension value unchanged, for InstanceNorm need
-    shape = op.Constant(value_ints=[0, group, -1])
+    # Actually we don't need N,C,HxW value because the input tensor has that information
+    if group is None:
+        group = 1  # Equal to LayerNorm
 
-    norm = _aten_native_group_norm_onnx(
-        input, weight, bias, weight_inst, bias_inst, shape, eps
-    )
-    # weight_inst, bias_inst are fake output, because we must return 3 outputs
-    return norm, weight_inst, bias_inst
+    if weight is None:  # Set to 1.0 as default, the shape is Channel size
+        weight = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2))
+
+    if bias is None:  # Set to 0.0 as default, the shape is Channel size
+        bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2))
+
+    norm, fake_mean, fake_var = _aten_native_group_norm_onnx(input, weight, bias, group, eps)
+    # FIXME: return fake value because we must return 3 outputs(norm, mean, var)
+    # We know how the 'mean' was computed in Torch, but don't know the 'var'
+    return norm, fake_mean, fake_var
 
 
 @torch_op("aten::native_group_norm", private=True)
 def _aten_native_group_norm_onnx(
     input: TFloat,
-    weight: Optional[TFloat],
-    bias: Optional[TFloat],
-    weight_inst: TFloat,
-    bias_inst: TFloat,
-    shape: INT64,
-    eps: float = None,
-) -> TReal:
-    # Using InstanceNorm to simulate GroupNorm, because GroupNorm need weight[group] and bias[group]
-    # But the input is weight[channel] and bias[channel]
-    input_reshaped = op.Reshape(input, shape)
+    weight: TFloat,
+    bias: TFloat,
+    group: int,
+    eps: float,
+) -> Tuple[TFloat, TFloat, TFloat]:
+    # Using InstanceNorm to simulate op.GroupNorm, because op.GroupNorm need weight[group] and bias[group]
+    # But the input is weight[channel] and bias[channel], the size mismatched
+    # Create weight_instance_norm and bias_instance_norm
+    shape_group = op.Reshape(op.Constant(value_int=group), op.Constant(value_ints=[-1]))
+    # 0 in the shape list keeps dimension value unchanged, for InstanceNorm need
+    shape_input = op.Concat(
+        op.Constant(value_ints=[0]),
+        shape_group,
+        op.Constant(value_ints=[-1]),
+        axis=0
+    )
+    input_reshaped = op.Reshape(input, shape_input)
+    weight_inst_norm = op.Expand(op.Constant(value_floats=[1.0]), shape_group)
+    bias_inst_norm = op.Expand(op.Constant(value_floats=[0.0]), shape_group)
     norm_reshaped = op.InstanceNormalization(
-        input_reshaped, weight_inst, bias_inst, epsilon=eps
+        input_reshaped, weight_inst_norm, bias_inst_norm, epsilon=eps
     )
     norm = op.Reshape(norm_reshaped, op.Shape(input))
     input_rank = op.Size(op.Shape(input))
-    axes = op.Range(1, input_rank - 1, 1)
-    # Using the real weight and bias to computer again
-    return op.Add(op.Mul(norm, op.Unsqueeze(weight, axes)), op.Unsqueeze(bias, axes))
+    axes_unsqueeze = op.Range(1, input_rank - 1, 1)
+    # Using the input weight and bias to do affine
+    # But need to unsqueeze to the target shape for easy broadcasting
+    weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze)
+    bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze)
+    norm_mul_weight = op.Mul(norm, weight_full_shape)
+    norm_result = op.Add(norm_mul_weight, bias_full_shape)
+    return norm_result, weight_inst_norm, bias_inst_norm
 
 
 def aten_native_group_norm_backward(
diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
index d74058051e..54d1fdc7fe 100644
--- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -1011,11 +1011,6 @@ def _where_input_wrangler(
         or (len(sample.args) > 0 and not isinstance(sample.args[0], int)),
         reason="this ATen overload only support one tensor as input and another int as args",
     ),
-    skip(
-        "native_group_norm",
-        matcher=lambda sample: len(sample.input.shape) == 2,
-        reason="ONNX only support input shape >= 3",
-    ),
     skip(
         "new_ones",
         matcher=lambda sample: sample.kwargs.get("dtype") is not None,

From 8f68565e9767e9e2ce1cfcbb937919181c2fd0d2 Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Wed, 19 Apr 2023 09:30:43 +0800
Subject: [PATCH 16/20] Update core.py

---
 onnxscript/function_libs/torch_aten/ops/core.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 9670fcd528..0676887123 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4216,15 +4216,12 @@ def aten_native_group_norm(
     N: Optional[INT64] = None,  # pylint: disable=unused-argument
     C: Optional[INT64] = None,  # pylint: disable=unused-argument
     HxW: Optional[INT64] = None,  # pylint: disable=unused-argument
-    group: Optional[int] = None,
-    eps: Optional[float] = 1e-05,
+    group: int = 1,
+    eps: float = 1e-05,
 ) -> Tuple[TFloat, TFloat, TFloat]:
     """native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)"""
 
     # Actually we don't need N,C,HxW value because the input tensor has that information
-    if group is None:
-        group = 1  # Equal to LayerNorm
-
     if weight is None:  # Set to 1.0 as default, the shape is Channel size
         weight = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2))

From a8fc8f7c4addc83a815fc87b236642b2e792a10e Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Wed, 19 Apr 2023 10:03:17 +0800
Subject: [PATCH 17/20] Update core.py

---
 onnxscript/function_libs/torch_aten/ops/core.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 0676887123..b7bc258d66 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4211,8 +4211,8 @@ def aten_native_dropout_backward(
 @torch_op("aten::native_group_norm", trace_only=True)
 def aten_native_group_norm(
     input: TFloat,
-    weight: Optional[TFloat],
-    bias: Optional[TFloat],
+    weight: Optional[TFloat] = None,
+    bias: Optional[TFloat] = None,
     N: Optional[INT64] = None,  # pylint: disable=unused-argument
     C: Optional[INT64] = None,  # pylint: disable=unused-argument
     HxW: Optional[INT64] = None,  # pylint: disable=unused-argument
     group: int = 1,
     eps: float = 1e-05,
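Note on PATCH 15-17: the trace_only wrapper now defaults weight and bias to per-channel ones/zeros built from op.Shape(input, start=1, end=2) and op.Expand, so the graph works even when the ATen call passes None. A hedged numpy analogue of that default construction (the array names are invented for illustration):

    import numpy as np

    x = np.zeros((2, 6, 3, 3), dtype=np.float32)            # any N,C,... input
    c = x.shape[1]                                           # op.Shape(input, start=1, end=2) -> [C]
    default_weight = np.full((c,), 1.0, dtype=np.float32)   # Expand(Constant([1.0]), [C])
    default_bias = np.full((c,), 0.0, dtype=np.float32)     # Expand(Constant([0.0]), [C])
    assert default_weight.shape == (6,) and default_bias.shape == (6,)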
From 958820bbe96b930369abea1f9f422eb46ed85016 Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Thu, 20 Apr 2023 14:42:27 +0800
Subject: [PATCH 18/20] update

---
 .../function_libs/torch_aten/ops/core.py      | 69 +++++++++++++------
 .../torch_aten/ops_correctness_test.py        | 10 ---
 2 files changed, 47 insertions(+), 32 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index fe19687bb5..b5e4eb4fe9 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4308,10 +4308,9 @@ def aten_native_group_norm(
     if bias is None:  # Set to 0.0 as default, the shape is Channel size
         bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2))
 
-    norm, fake_mean, fake_var = _aten_native_group_norm_onnx(input, weight, bias, group, eps)
-    # FIXME: return fake value because we must return 3 outputs(norm, mean, var)
-    # We know how the 'mean' was computed in Torch, but don't know the 'var'
-    return norm, fake_mean, fake_var
+    # According to Torch, return rstd instead of var
+    norm, mean, rstd = _aten_native_group_norm_onnx(input, weight, bias, group, eps)
+    return norm, mean, rstd
 
 
 @torch_op("aten::native_group_norm", private=True)
@@ -4319,36 +4318,62 @@ def _aten_native_group_norm_onnx(
     input: TFloat,
     weight: TFloat,
     bias: TFloat,
-    group: int,
+    group: INT64,
     eps: float,
 ) -> Tuple[TFloat, TFloat, TFloat]:
-    # Using InstanceNorm to simulate op.GroupNorm, because op.GroupNorm need weight[group] and bias[group]
-    # But the input is weight[channel] and bias[channel], the size mismatched
-    # Create weight_instance_norm and bias_instance_norm
-    shape_group = op.Reshape(op.Constant(value_int=group), op.Constant(value_ints=[-1]))
-    # 0 in the shape list keeps dimension value unchanged, for InstanceNorm need
+    # Because onnx.GroupNorm() needs size=group for weight and bias
+    # But the torch aten function's input needs size=channel, so the sizes mismatch
+    # So we have to use onnx.InstanceNorm() to simulate
+    neg_1 = op.Constant(value_ints=[-1])
+    # Create weight_instance_norm and bias_instance_norm, copied from Torch ONNX converter
+    group_tensor = op.Reshape(group, neg_1)
+    # 0 in the shape list keeps dimension value unchanged, for InstanceNorm need [0,group,-1]
     shape_input = op.Concat(
         op.Constant(value_ints=[0]),
-        shape_group,
-        op.Constant(value_ints=[-1]),
+        group_tensor,
+        neg_1,
         axis=0
     )
     input_reshaped = op.Reshape(input, shape_input)
-    weight_inst_norm = op.Expand(op.Constant(value_floats=[1.0]), shape_group)
-    bias_inst_norm = op.Expand(op.Constant(value_floats=[0.0]), shape_group)
-    norm_reshaped = op.InstanceNormalization(
+    weight_inst_norm = op.Expand(op.Constant(value_floats=[1.0]), group_tensor)
+    bias_inst_norm = op.Expand(op.Constant(value_floats=[0.0]), group_tensor)
+    norm = op.InstanceNormalization(
         input_reshaped, weight_inst_norm, bias_inst_norm, epsilon=eps
     )
-    norm = op.Reshape(norm_reshaped, op.Shape(input))
-    input_rank = op.Size(op.Shape(input))
-    axes_unsqueeze = op.Range(1, input_rank - 1, 1)
-    # Using the input weight and bias to do affine
+    # Reshape back to input's shape
+    norm = op.Reshape(norm, op.Shape(input))
+    # Using the input weight and bias to do affine
     # But need to unsqueeze to the target shape for easy broadcasting
+    input_rank = op.Size(op.Shape(input))
+    axes_unsqueeze = op.Range(1, input_rank - 1, 1)
     weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze)
     bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze)
     norm_mul_weight = op.Mul(norm, weight_full_shape)
     norm_result = op.Add(norm_mul_weight, bias_full_shape)
-    return norm_result, weight_inst_norm, bias_inst_norm
+    # Compute mean and rstd, but using Torch algorithm
+    # The returned shape for mean and rstd should be [N, group, -1]
+    N = op.Shape(input, start=0, end=1)
+    C = op.Shape(input, start=1, end=2)
+    HxW = op.ReduceProd(op.Shape(input, start=2))
+    shape_N_group_neg1 = op.Concat(
+        N,
+        group_tensor,
+        neg_1,
+        axis=0
+    )
+    input_N_group_neg1 = op.Reshape(input, shape_N_group_neg1)
+    axes = op.Constant(value_ints=[2])  # output size is [N, group]
+    # Get mean which size is [N, group, 1], for broadcasting
+    mean = op.ReduceMean(input_N_group_neg1, axes)
+    input_sub_mean = op.Sub(input_N_group_neg1, mean)
+    sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean)
+    sum = op.ReduceSum(sqr_input_sub_mean, axes, keepdims=0)
+    # In Pytorch, rstd = 1/(sqrt(sum/n + eps))
+    n = op.Cast(HxW * C / group, to=FLOAT.dtype)
+    rstd = op.Div(1.0, op.Sqrt(sum / n + eps))
+    # Get the correct shape [N, group] for mean again
+    mean = op.ReduceMean(input_N_group_neg1, axes, keepdims=0)
+    return norm_result, mean, rstd
 
 
 def aten_native_group_norm_backward(
diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
index 1364b37191..8450aec0c2 100644
--- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -822,16 +822,6 @@ def _where_input_wrangler(
         test_class_name="TestOutputConsistencyFullGraph",
         enabled_if=version_utils.onnxruntime_older_than("1.15"),
     ),
-    xfail(
-        "native_group_norm",
-        reason="fixme: ONNX InstanceNorm only return 1 output, but Torch need 3",
-        test_class_name="TestOutputConsistencyFullGraph",
-    ),
-    xfail(
-        "native_group_norm",
-        reason="fixme: ONNX InstanceNorm only return 1 output, but Torch need 3",
-        test_class_name="TestOutputConsistencyEager",
-    ),
     xfail(
         "new_ones",
         reason="fixme: ORT fails with invalid model: 'ONNX Schema aten_new_full: failed validating the check: !(it.GetName().empty())'",

From b38af29de310537dac3496d84866f6a4b9e3c42d Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Thu, 20 Apr 2023 14:44:14 +0800
Subject: [PATCH 19/20] Update core.py

---
 onnxscript/function_libs/torch_aten/ops/core.py | 14 ++------------
 1 file changed, 2 insertions(+), 12 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index b5e4eb4fe9..c10ed770e4 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4328,12 +4328,7 @@ def _aten_native_group_norm_onnx(
     # Create weight_instance_norm and bias_instance_norm, copied from Torch ONNX converter
     group_tensor = op.Reshape(group, neg_1)
     # 0 in the shape list keeps dimension value unchanged, for InstanceNorm need [0,group,-1]
-    shape_input = op.Concat(
-        op.Constant(value_ints=[0]),
-        group_tensor,
-        neg_1,
-        axis=0
-    )
+    shape_input = op.Concat(op.Constant(value_ints=[0]), group_tensor, neg_1, axis=0)
     input_reshaped = op.Reshape(input, shape_input)
     weight_inst_norm = op.Expand(op.Constant(value_floats=[1.0]), group_tensor)
     bias_inst_norm = op.Expand(op.Constant(value_floats=[0.0]), group_tensor)
@@ -4350,12 +4345,7 @@ def _aten_native_group_norm_onnx(
     N = op.Shape(input, start=0, end=1)
     C = op.Shape(input, start=1, end=2)
     HxW = op.ReduceProd(op.Shape(input, start=2))
-    shape_N_group_neg1 = op.Concat(
-        N,
-        group_tensor,
-        neg_1,
-        axis=0
-    )
+    shape_N_group_neg1 = op.Concat(N, group_tensor, neg_1, axis=0)
     input_N_group_neg1 = op.Reshape(input, shape_N_group_neg1)
     axes = op.Constant(value_ints=[2])  # output size is [N, group]
     # Get mean which size is [N, group, 1], for broadcasting

From c3e3b677ee90afe023a5d558e752874679565ac6 Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Thu, 20 Apr 2023 18:39:59 +0800
Subject: [PATCH 20/20] Update core.py

---
 onnxscript/function_libs/torch_aten/ops/core.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index c10ed770e4..1e8544b4cd 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4348,19 +4348,17 @@ def _aten_native_group_norm_onnx(
     # Compute mean and rstd, but using Torch algorithm
     # The returned shape for mean and rstd should be [N, group, -1]
     N = op.Shape(input, start=0, end=1)
-    C = op.Shape(input, start=1, end=2)
-    HxW = op.ReduceProd(op.Shape(input, start=2))
     shape_N_group_neg1 = op.Concat(N, group_tensor, neg_1, axis=0)
     input_N_group_neg1 = op.Reshape(input, shape_N_group_neg1)
-    axes = op.Constant(value_ints=[2])  # output size is [N, group]
+    # The output size is [N, group], so dims = [2]
+    axes = op.Constant(value_ints=[2])
     # Get mean which size is [N, group, 1], for broadcasting
     mean = op.ReduceMean(input_N_group_neg1, axes)
     input_sub_mean = op.Sub(input_N_group_neg1, mean)
     sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean)
-    sum = op.ReduceSum(sqr_input_sub_mean, axes, keepdims=0)
-    # In Pytorch, rstd = 1/(sqrt(sum/n + eps))
-    n = op.Cast(HxW * C / group, to=FLOAT.dtype)
-    rstd = op.Div(1.0, op.Sqrt(sum / n + eps))
+    # In Pytorch, rstd = 1/(sqrt(var + eps))
+    var = op.ReduceMean(sqr_input_sub_mean, axes, keepdims=0)
+    rstd = op.Div(1.0, op.Sqrt(var + eps))
     # Get the correct shape [N, group] for mean again
     mean = op.ReduceMean(input_N_group_neg1, axes, keepdims=0)
     return norm_result, mean, rstd
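Note on PATCH 18-20: the final graph returns mean and rstd computed over the [N, group, -1] view of the input, with rstd = 1/sqrt(var + eps). The numpy sketch below is not part of the PR; it only checks that these statistics are self-consistent with the instance-norm style normalization used above (shapes are illustrative):

    import numpy as np

    N, C, HxW, group, eps = 2, 6, 5, 3, 1e-5
    x = np.random.rand(N, C, HxW).astype(np.float64)

    xg = x.reshape(N, group, -1)
    mean = xg.mean(axis=2, keepdims=True)                  # ReduceMean over dim 2, kept for broadcasting
    var = ((xg - mean) ** 2).mean(axis=2, keepdims=True)   # mean of squared deviations
    rstd = 1.0 / np.sqrt(var + eps)                        # 1/sqrt(var + eps), as in PATCH 20

    # (x - mean) * rstd reproduces the InstanceNormalization-style result
    renormed = ((xg - mean) * rstd).reshape(N, C, HxW)
    ref = ((xg - xg.mean(2, keepdims=True)) / np.sqrt(xg.var(2, keepdims=True) + eps)).reshape(N, C, HxW)
    np.testing.assert_allclose(renormed, ref)

    print(mean.squeeze(-1).shape, rstd.squeeze(-1).shape)  # (2, 3) and (2, 3): the [N, group] outputs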