diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index c2db97849c..9e50e73c9a 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4299,19 +4299,80 @@ def aten_native_dropout_backward(
     raise NotImplementedError()
 
 
+@torch_op("aten::native_group_norm", trace_only=True)
 def aten_native_group_norm(
-    input: TensorType,
-    weight: Optional[TensorType],
-    bias: Optional[TensorType],
-    N: INT64,
-    C: INT64,
-    HxW: INT64,
-    group: int,
-    eps: float,
-) -> tuple[TensorType, TensorType, TensorType]:
+    input: TFloat,
+    weight: Optional[TFloat] = None,
+    bias: Optional[TFloat] = None,
+    N: Optional[INT64] = None,  # pylint: disable=unused-argument
+    C: Optional[INT64] = None,  # pylint: disable=unused-argument
+    HxW: Optional[INT64] = None,  # pylint: disable=unused-argument
+    group: int = 1,
+    eps: float = 1e-05,
+) -> Tuple[TFloat, TFloat, TFloat]:
     """native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)"""
 
-    raise NotImplementedError()
+    # We do not need the N, C, HxW values here because the input tensor already carries that information
+    if weight is None:  # Set to 1.0 as default; the shape is the channel size
+        weight = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2))
+
+    if bias is None:  # Set to 0.0 as default; the shape is the channel size
+        bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2))
+
+    # According to Torch, return rstd instead of var
+    norm, mean, rstd = _aten_native_group_norm_onnx(input, weight, bias, group, eps)
+    return norm, mean, rstd
+
+
+@torch_op("aten::native_group_norm", private=True)
+def _aten_native_group_norm_onnx(
+    input: TFloat,
+    weight: TFloat,
+    bias: TFloat,
+    group: INT64,
+    eps: float,
+) -> Tuple[TFloat, TFloat, TFloat]:
+    # ONNX GroupNormalization expects weight and bias of size=group,
+    # but the Torch aten function takes them with size=channel, so the sizes mismatch.
+    # Therefore we use ONNX InstanceNormalization to simulate it.
+    neg_1 = op.Constant(value_ints=[-1])
+    # Create weight_instance_norm and bias_instance_norm, copied from the Torch ONNX converter
+    group_tensor = op.Reshape(group, neg_1)
+    # 0 in the shape list keeps that dimension unchanged; InstanceNormalization needs [0, group, -1]
+    shape_input = op.Concat(op.Constant(value_ints=[0]), group_tensor, neg_1, axis=0)
+    input_reshaped = op.Reshape(input, shape_input)
+    weight_inst_norm = op.Expand(op.Constant(value_floats=[1.0]), group_tensor)
+    bias_inst_norm = op.Expand(op.Constant(value_floats=[0.0]), group_tensor)
+    norm = op.InstanceNormalization(
+        input_reshaped, weight_inst_norm, bias_inst_norm, epsilon=eps
+    )
+    # Reshape back to the input's shape
+    norm = op.Reshape(norm, op.Shape(input))
+    # Apply the affine transform with the input weight and bias,
+    # unsqueezing them to the target shape so that they broadcast easily
+    input_rank = op.Size(op.Shape(input))
+    axes_unsqueeze = op.Range(1, input_rank - 1, 1)
+    weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze)
+    bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze)
+    norm_mul_weight = op.Mul(norm, weight_full_shape)
+    norm_result = op.Add(norm_mul_weight, bias_full_shape)
+    # Compute mean and rstd with the Torch algorithm:
+    # reshape the input to [N, group, -1] and reduce over the last axis
+    N = op.Shape(input, start=0, end=1)
+    shape_N_group_neg1 = op.Concat(N, group_tensor, neg_1, axis=0)
+    input_N_group_neg1 = op.Reshape(input, shape_N_group_neg1)
+    # The output size is [N, group], so the reduction dims are [2]
+    axes = op.Constant(value_ints=[2])
+    # Get mean with size [N, group, 1], for broadcasting
+    mean = op.ReduceMean(input_N_group_neg1, axes)
+    input_sub_mean = op.Sub(input_N_group_neg1, mean)
+    sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean)
+    # In PyTorch, rstd = 1 / sqrt(var + eps)
+    var = op.ReduceMean(sqr_input_sub_mean, axes, keepdims=0)
+    rstd = op.Div(1.0, op.Sqrt(var + eps))
+    # Get the correct shape [N, group] for mean again
+    mean = op.ReduceMean(input_N_group_neg1, axes, keepdims=0)
+    return norm_result, mean, rstd
 
 
 def aten_native_group_norm_backward(
diff --git a/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py b/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py
index 10f4ec2ed2..6d116aaf7c 100644
--- a/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py
@@ -317,6 +317,39 @@ def sample_inputs_max_pool3d_with_indices(op_info, device, dtype, requires_grad,
         yield opinfo_core.SampleInput(arg, kwargs=kwargs)
 
 
+def sample_inputs_native_group_norm(op_info, device, dtype, requires_grad, **kwargs):
+    del op_info
+    make_arg = functools.partial(
+        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+
+    # Ordered as input shape, weight shape, bias shape, N, C, HxW, and kwargs for group and eps
+    cases = (
+        ((1, 6, 3), (6,), (6,), 1, 6, 3, {"group": 2, "eps": 0.5}),
+        ((2, 6, 3), (6,), (6,), 2, 6, 3, {"group": 3, "eps": -0.5}),
+        ((5, 5, 5), (5,), (5,), 5, 5, 5, {"group": 1, "eps": 1e-5}),
+        ((5, 8, 10), (8,), (8,), 5, 8, 10, {"group": 4, "eps": 1e-5}),
+    )
+
+    for input_shape, weight, bias, N, C, HxW, kwargs in cases:
+        # The weight and bias args must be of shape (channels,)
+        channels = input_shape[1] if len(input_shape) > 1 else 0
+        weight = make_arg(channels) if channels > 0 else None
+        bias = make_arg(channels) if channels > 0 else None
+
+        yield opinfo_core.SampleInput(
+            make_arg(input_shape),
+            args=(
+                weight,
+                bias,
+                N,
+                C,
+                HxW,
+            ),
+            kwargs=kwargs,
+        )
+
+
 def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs):
     del op_info
     # input_shape, output_size, kernal, dilation, padding, stride
@@ -406,6 +439,14 @@ def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs):
        skips=(),
        supports_out=False,
    ),
+    opinfo_core.OpInfo(
+        "native_group_norm",
+        op=torch.ops.aten.native_group_norm,
+        aten_name="native_group_norm",
+        dtypes=common_dtype.floating_and_complex_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_native_group_norm,
+        supports_out=False,
+    ),
    opinfo_core.OpInfo(
        "max_pool2d",
        variant_test_name="empty_strides",
diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
index db42b0c90b..6258f61a4b 100644
--- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -704,6 +704,7 @@ def _where_input_wrangler(
     "max_pool2d": nn_ops.aten_max_pool2d,  # Custom from extra_opinfo
     "max_pool3d": nn_ops.aten_max_pool3d,  # Custom from extra_opinfo
     "native_batch_norm": core_ops.aten_native_batch_norm,
+    "native_group_norm": core_ops.aten_native_group_norm,
     "native_layer_norm": core_ops.aten_native_layer_norm,
     "new_empty": core_ops.aten_new_empty,
     "new_empty_strided": core_ops.aten_new_empty_strided,
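
Note (not part of the patch): the sketch below is a minimal PyTorch check of the same [N, group, -1] decomposition that _aten_native_group_norm_onnx builds with Reshape + InstanceNormalization, compared against torch.ops.aten.native_group_norm. The shapes, tolerances, and variable names are illustrative assumptions only; it assumes a recent PyTorch where torch.ops.aten.native_group_norm and torch.testing.assert_close are available.

    import torch

    # Arbitrary test sizes; C must be divisible by group.
    N, C, HxW, group, eps = 2, 6, 3, 3, 1e-5
    x = torch.randn(N, C, HxW)
    weight = torch.randn(C)
    bias = torch.randn(C)

    # Reference outputs from the ATen op itself: (normalized output, mean, rstd).
    ref_out, ref_mean, ref_rstd = torch.ops.aten.native_group_norm(
        x, weight, bias, N, C, HxW, group, eps
    )

    # Manual decomposition: normalize per (N, group) slice, then apply the per-channel affine.
    x_grouped = x.reshape(N, group, -1)                       # [N, group, C/group * HxW]
    mean = x_grouped.mean(dim=2, keepdim=True)                # [N, group, 1]
    var = x_grouped.var(dim=2, unbiased=False, keepdim=True)  # biased variance, as in group norm
    rstd = 1.0 / torch.sqrt(var + eps)
    norm = ((x_grouped - mean) * rstd).reshape(N, C, HxW)
    out = norm * weight.reshape(1, C, 1) + bias.reshape(1, C, 1)

    # mean and rstd are returned with shape [N, group].
    torch.testing.assert_close(out, ref_out, rtol=1e-4, atol=1e-4)
    torch.testing.assert_close(mean.squeeze(-1), ref_mean, rtol=1e-4, atol=1e-4)
    torch.testing.assert_close(rstd.squeeze(-1), ref_rstd, rtol=1e-4, atol=1e-4)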