From fa3290d28b0693121444c1b3b558eed8e737d9f9 Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Tue, 30 May 2023 23:46:03 +0000 Subject: [PATCH 1/8] add draft --- .../function_libs/torch_lib/ops/core.py | 13 ---- onnxscript/function_libs/torch_lib/ops/nn.py | 66 ++++++++++++++++++- .../function_libs/torch_lib/ops_test_data.py | 1 + 3 files changed, 65 insertions(+), 15 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index d9b576a3f6..66039b336d 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -684,19 +684,6 @@ def aten_atleast_3d(self: TensorType) -> TensorType: raise NotImplementedError() -def aten_avg_pool1d( - self: TensorType, - kernel_size: Sequence[int], - stride: Optional[Sequence[int]] = None, - padding: Sequence[int] = (0,), - ceil_mode: bool = False, - count_include_pad: bool = True, -) -> TensorType: - """avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor""" - - raise NotImplementedError() - - @torch_op("aten::baddbmm") def aten_baddbmm( self: TrealOrUInt8, diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index c9c5a3bf07..5968d02789 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -125,6 +125,17 @@ def aten_adaptive_max_pool3d_backward( raise NotImplementedError() +def aten_avg_pool1d( + self: TensorType, + kernel_size: Sequence[int], + stride: Optional[Sequence[int]] = None, + padding: Sequence[int] = (0,), + ceil_mode: bool = False, + count_include_pad: bool = True, +) -> TensorType: + """avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor""" + + raise NotImplementedError() @torch_op("aten::avg_pool2d", trace_only=True) def aten_avg_pool2d( @@ -206,7 +217,7 @@ def aten_avg_pool2d_backward( raise NotImplementedError() - +@torch_op("aten::avg_pool3d", trace_only=True) def aten_avg_pool3d( self: TensorType, kernel_size: Sequence[int], @@ -218,7 +229,58 @@ def aten_avg_pool3d( ) -> TensorType: """avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor""" - raise NotImplementedError() + # Torch prefer to use single number x for kerne,stride,pad,dilation on both side implicitly + # But ONNX needs pair number [x,y] to specify on each side explicitly + # For pool3d, this number should be 3 + expand_size = 3 + + # The kernel_shape should be [x, y] + if isinstance(kernel_size, int): # x -> [x, x] + kernel_shape = [kernel_size] * expand_size + else: # assert(len(kernel_size)==2), already [x, y] + kernel_shape = kernel_size + + # The pads should be [w, x, y, z] + if isinstance(padding, int): # w -> [w, w, w, w] + pads = [padding] * expand_size * 2 + elif len(padding) == 1: # [w] -> [w, w, w, w] + pads = padding * 4 + elif len(padding) == 2: # [w, x] -> [w, x, w, x] + pads = padding * 2 + else: # assert len(padding) == 4, already [w, x, y, z] + pads = padding + + # The strides should be [x, y] + if isinstance(stride, int): # x -> [x, x] + strides = [stride] * expand_size + elif stride is None: + strides = kernel_shape + else: + strides = stride + + result = op.AveragePool( + self, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + kernel_shape=kernel_shape, + pads=pads, + strides=strides, + ) + + # TODO: if want to support divisor_override argument, need to op.Mul(result, mask) + # mask = [ + # 1, 2, 3, S,..3, 2, 1 + # 2, 4, 6, 2S, 6, 4, 2 + # 3, 6, 9, 3S, 9, 6, 3 + # S, 2S,3S,SS,3S,2S, S + # 3, 6, 9, 3S, 9, 6, 3 + # 2, 4, 6, 2S, 6, 4, 2 + # 1, 2, 3, S,..3, 2, 1 + # ] + # S is stride size, in this case S=4, + # S may dup lot of times according to the image size + + return result def aten_avg_pool3d_backward( diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index bea709d4e1..c7e4c4cd2d 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -538,6 +538,7 @@ def _where_input_wrangler( "native_group_norm": core_ops.aten_native_group_norm, "native_layer_norm": core_ops.aten_native_layer_norm, "nn.functional.avg_pool2d": (nn_ops.aten_avg_pool2d, _avg_pool2d_input_wrangler), + "nn.functional.avg_pool3d": (nn_ops.aten_avg_pool3d, _avg_pool2d_input_wrangler), "nn.functional.conv1d": core_ops.aten_conv1d, "nn.functional.conv2d": core_ops.aten_conv2d, "nn.functional.conv3d": core_ops.aten_conv3d, From 251489dff4668d9245a052ba69ae87e0eedbb673 Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Wed, 31 May 2023 01:10:09 +0000 Subject: [PATCH 2/8] Use avg_pool2d format --- onnxscript/function_libs/torch_lib/ops/nn.py | 135 +++++++++++------- .../function_libs/torch_lib/ops_test_data.py | 33 +++-- 2 files changed, 106 insertions(+), 62 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 5968d02789..8e6ddf5c01 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -125,17 +125,83 @@ def aten_adaptive_max_pool3d_backward( raise NotImplementedError() + +def _adjust_attributes_of_avg_pool( + expand_size: int, + kernel_size: Sequence[int], + stride: Sequence[int], + padding: Sequence[int], +) -> Tuple[Sequence[int], Sequence[int], Sequence[int]]: + """Adjust attributes of avg_pool to match ONNX specification.""" + + if isinstance(kernel_size, int): + kernel_shape = [kernel_size] * expand_size + else: + kernel_shape = kernel_size + + if isinstance(padding, int): + pads = [padding] * expand_size * 2 + elif len(padding) == 1: + pads = padding * expand_size * 2 + elif len(padding) == 2: + pads = padding * expand_size + else: + pads = padding + + if isinstance(stride, int): + strides = [stride] * expand_size + elif not stride: + strides = kernel_shape + else: + strides = stride + + return (kernel_shape, strides, pads) + + +@torch_op("aten::avg_pool1d", trace_only=True) def aten_avg_pool1d( - self: TensorType, + self: TFloat, kernel_size: Sequence[int], - stride: Optional[Sequence[int]] = None, + stride: Sequence[int] = (), padding: Sequence[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True, -) -> TensorType: +) -> TFloat: """avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor""" - raise NotImplementedError() + # Torch prefer to use single number x for kerne,stride,pad,dilation on both side implicitly + # But ONNX needs pair number [x,y] to specify on each side explicitly + # For pool3d, this number should be 3 + expand_size = 1 + + kernel_shape, strides, pads = _adjust_attributes_of_avg_pool( + expand_size, kernel_size, stride, padding + ) + + result = op.AveragePool( + self, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + kernel_shape=kernel_shape, + pads=pads, + strides=strides, + ) + + # TODO: if want to support divisor_override argument, need to op.Mul(result, mask) + # mask = [ + # 1, 2, 3, S,..3, 2, 1 + # 2, 4, 6, 2S, 6, 4, 2 + # 3, 6, 9, 3S, 9, 6, 3 + # S, 2S,3S,SS,3S,2S, S + # 3, 6, 9, 3S, 9, 6, 3 + # 2, 4, 6, 2S, 6, 4, 2 + # 1, 2, 3, S,..3, 2, 1 + # ] + # S is stride size, in this case S=4, + # S may dup lot of times according to the image size + + return result + @torch_op("aten::avg_pool2d", trace_only=True) def aten_avg_pool2d( @@ -154,29 +220,9 @@ def aten_avg_pool2d( # For pool3d, this number should be 3 expand_size = 2 - # The kernel_shape should be [x, y] - if isinstance(kernel_size, int): # x -> [x, x] - kernel_shape = [kernel_size] * expand_size - else: # assert(len(kernel_size)==2), already [x, y] - kernel_shape = kernel_size - - # The pads should be [w, x, y, z] - if isinstance(padding, int): # w -> [w, w, w, w] - pads = [padding] * expand_size * 2 - elif len(padding) == 1: # [w] -> [w, w, w, w] - pads = padding * 4 - elif len(padding) == 2: # [w, x] -> [w, x, w, x] - pads = padding * 2 - else: # assert len(padding) == 4, already [w, x, y, z] - pads = padding - - # The strides should be [x, y] - if isinstance(stride, int): # x -> [x, x] - strides = [stride] * expand_size - elif stride is None: - strides = kernel_shape - else: - strides = stride + kernel_shape, strides, pads = _adjust_attributes_of_avg_pool( + expand_size, kernel_size, stride, padding + ) result = op.AveragePool( self, @@ -217,16 +263,17 @@ def aten_avg_pool2d_backward( raise NotImplementedError() + @torch_op("aten::avg_pool3d", trace_only=True) def aten_avg_pool3d( - self: TensorType, + self: TFloat, kernel_size: Sequence[int], - stride: Optional[Sequence[int]] = None, + stride: Sequence[int] = (), padding: Sequence[int] = (0, 0, 0), ceil_mode: bool = False, count_include_pad: bool = True, - divisor_override: Optional[int] = None, -) -> TensorType: + divisor_override: Optional[int] = None, # pylint: disable=unused-argument +) -> TFloat: """avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor""" # Torch prefer to use single number x for kerne,stride,pad,dilation on both side implicitly @@ -234,29 +281,9 @@ def aten_avg_pool3d( # For pool3d, this number should be 3 expand_size = 3 - # The kernel_shape should be [x, y] - if isinstance(kernel_size, int): # x -> [x, x] - kernel_shape = [kernel_size] * expand_size - else: # assert(len(kernel_size)==2), already [x, y] - kernel_shape = kernel_size - - # The pads should be [w, x, y, z] - if isinstance(padding, int): # w -> [w, w, w, w] - pads = [padding] * expand_size * 2 - elif len(padding) == 1: # [w] -> [w, w, w, w] - pads = padding * 4 - elif len(padding) == 2: # [w, x] -> [w, x, w, x] - pads = padding * 2 - else: # assert len(padding) == 4, already [w, x, y, z] - pads = padding - - # The strides should be [x, y] - if isinstance(stride, int): # x -> [x, x] - strides = [stride] * expand_size - elif stride is None: - strides = kernel_shape - else: - strides = stride + kernel_shape, strides, pads = _adjust_attributes_of_avg_pool( + expand_size, kernel_size, stride, padding + ) result = op.AveragePool( self, diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index c7e4c4cd2d..f6773128d7 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -71,7 +71,7 @@ def _amin_amax_input_wrangler( return args, kwargs -def _avg_pool2d_input_wrangler( +def _avg_pool_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: if "dim" not in kwargs: @@ -87,10 +87,11 @@ def _avg_pool2d_input_wrangler( # Cannot using list(padding) here, because the element will be numpy.int64 instead of int padding = padding.tolist() kwargs["padding"] = padding - stride = args.pop(2) - if isinstance(stride, np.ndarray): - stride = stride.tolist() - kwargs["stride"] = stride + if len(args) > 2: + stride = args.pop(2) + if isinstance(stride, np.ndarray): + stride = stride.tolist() + kwargs["stride"] = stride kernel_size = args.pop(1) if isinstance(kernel_size, np.ndarray): kernel_size = kernel_size.tolist() @@ -537,8 +538,9 @@ def _where_input_wrangler( "native_batch_norm": core_ops.aten_native_batch_norm, "native_group_norm": core_ops.aten_native_group_norm, "native_layer_norm": core_ops.aten_native_layer_norm, - "nn.functional.avg_pool2d": (nn_ops.aten_avg_pool2d, _avg_pool2d_input_wrangler), - "nn.functional.avg_pool3d": (nn_ops.aten_avg_pool3d, _avg_pool2d_input_wrangler), + "nn.functional.avg_pool1d": (nn_ops.aten_avg_pool1d, _avg_pool_input_wrangler), + "nn.functional.avg_pool2d": (nn_ops.aten_avg_pool2d, _avg_pool_input_wrangler), + "nn.functional.avg_pool3d": (nn_ops.aten_avg_pool3d, _avg_pool_input_wrangler), "nn.functional.conv1d": core_ops.aten_conv1d, "nn.functional.conv2d": core_ops.aten_conv2d, "nn.functional.conv3d": core_ops.aten_conv3d, @@ -949,7 +951,14 @@ def _where_input_wrangler( ), xfail( "nn.functional.avg_pool2d", - matcher=lambda sample: len(sample.args) > 5 and sample.args[5] is not None, + matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) + or sample.kwargs.get("divisor_override") is not None, + reason="ONNX doesn't support divisor_override argument", + ), + xfail( + "nn.functional.avg_pool3d", + matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) + or sample.kwargs.get("divisor_override") is not None, reason="ONNX doesn't support divisor_override argument", ), xfail( @@ -1864,10 +1873,18 @@ def _where_input_wrangler( torch.float32, # torch.float16, # FIXME: ORT inference error GlobalAveragePool ), + "nn.functional.avg_pool1d": ( + torch.float32, + torch.float16, + ), "nn.functional.avg_pool2d": ( torch.float32, torch.float16, ), + "nn.functional.avg_pool3d": ( + torch.float32, + torch.float16, + ), "nn.functional.celu": ( torch.float32, torch.float16, From f98aaaa8630e25eb8602febe487f77452395a693 Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Wed, 31 May 2023 19:31:50 +0000 Subject: [PATCH 3/8] adjust padding --- onnxscript/function_libs/torch_lib/ops/nn.py | 8 ++++---- onnxscript/tests/function_libs/torch_lib/ops_test_data.py | 5 +++++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 7651e1141a..1c45680aaa 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -146,7 +146,7 @@ def _adjust_attributes_of_avg_pool( elif len(padding) == 2: pads = padding * expand_size else: - pads = padding + pads = padding * 2 if isinstance(stride, int): strides = [stride] * expand_size @@ -287,11 +287,11 @@ def aten_avg_pool3d( result = op.AveragePool( self, - ceil_mode=ceil_mode, - count_include_pad=count_include_pad, kernel_shape=kernel_shape, - pads=pads, strides=strides, + pads=pads, + count_include_pad=count_include_pad, + ceil_mode=ceil_mode, ) # TODO: if want to support divisor_override argument, need to op.Mul(result, mask) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index f6773128d7..c802ba53d5 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -949,6 +949,11 @@ def _where_input_wrangler( matcher=lambda sample: sample.args[0] != (1, 1, 1), reason="only global pooling is supported; only batched inputs are supported", ), + # xfail( + # "nn.functional.avg_pool1d", + # matcher=lambda sample: sample.kwargs.get("ceil_mode") is True, + # reason="ONNX has different ciel mode strategy to PyTorch", + # ), xfail( "nn.functional.avg_pool2d", matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) From 4d8b5b8673ab7cbda935ad3ac715eabd5e9b67af Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Thu, 1 Jun 2023 21:14:03 +0000 Subject: [PATCH 4/8] address comment --- onnxscript/function_libs/torch_lib/ops/nn.py | 13 ------------- .../tests/function_libs/torch_lib/ops_test_data.py | 2 +- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 1c45680aaa..37dd87509c 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -187,19 +187,6 @@ def aten_avg_pool1d( strides=strides, ) - # TODO: if want to support divisor_override argument, need to op.Mul(result, mask) - # mask = [ - # 1, 2, 3, S,..3, 2, 1 - # 2, 4, 6, 2S, 6, 4, 2 - # 3, 6, 9, 3S, 9, 6, 3 - # S, 2S,3S,SS,3S,2S, S - # 3, 6, 9, 3S, 9, 6, 3 - # 2, 4, 6, 2S, 6, 4, 2 - # 1, 2, 3, S,..3, 2, 1 - # ] - # S is stride size, in this case S=4, - # S may dup lot of times according to the image size - return result diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index c802ba53d5..0789250198 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -952,7 +952,7 @@ def _where_input_wrangler( # xfail( # "nn.functional.avg_pool1d", # matcher=lambda sample: sample.kwargs.get("ceil_mode") is True, - # reason="ONNX has different ciel mode strategy to PyTorch", + # reason="ONNX has different ceil mode strategy to PyTorch", # ), xfail( "nn.functional.avg_pool2d", From d6bf99b70d46606fb326fa10f074a80c8d5945a7 Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Thu, 13 Jul 2023 16:04:01 +0000 Subject: [PATCH 5/8] fix merged conflict --- onnxscript/tests/function_libs/torch_lib/ops_test_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index c2d94b9124..e8ac4e0637 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -1338,7 +1338,7 @@ def _where_input_wrangler( TorchLibOpInfo( "nn.functional.avg_pool2d", nn_ops.aten_avg_pool2d, - input_wrangler=_avg_pool2d_input_wrangler, + input_wrangler=_avg_pool_input_wrangler, trace_only=True, ).xfail( matcher=lambda sample: len(sample.args) > 5 and sample.args[5] is not None, From 4e698cab9478bdad5ee2fbd58227f0aa65fb67e2 Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Tue, 18 Jul 2023 16:36:19 +0000 Subject: [PATCH 6/8] add tests --- .../function_libs/torch_lib/ops_test_data.py | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 998b8ee175..bdc83070ca 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -1394,13 +1394,34 @@ def _where_input_wrangler( trace_only=True, tolerance={torch.float32: (3.7e-5, 1.8e-4)}, ), + TorchLibOpInfo( + "nn.functional.avg_pool1d", + nn_ops.aten_avg_pool1d, + input_wrangler=_avg_pool_input_wrangler, + trace_only=True, + ).xfail( + matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) + or (sample.kwargs.get("divisor_override") is not None), + reason="ONNX doesn't support divisor_override argument", + ), TorchLibOpInfo( "nn.functional.avg_pool2d", nn_ops.aten_avg_pool2d, input_wrangler=_avg_pool_input_wrangler, trace_only=True, ).xfail( - matcher=lambda sample: len(sample.args) > 5 and sample.args[5] is not None, + matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) + or (sample.kwargs.get("divisor_override") is not None), + reason="ONNX doesn't support divisor_override argument", + ), + TorchLibOpInfo( + "nn.functional.avg_pool3d", + nn_ops.aten_avg_pool3d, + input_wrangler=_avg_pool_input_wrangler, + trace_only=True, + ).xfail( + matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) + or (sample.kwargs.get("divisor_override") is not None), reason="ONNX doesn't support divisor_override argument", ), TorchLibOpInfo( From 769675b88ac89184af1e86899d6b506b34bacfad Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Tue, 18 Jul 2023 22:18:54 +0000 Subject: [PATCH 7/8] add xfail on ceil_mode --- .../tests/function_libs/torch_lib/ops_test_data.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index bdc83070ca..88cb78bd7f 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -1399,10 +1399,15 @@ def _where_input_wrangler( nn_ops.aten_avg_pool1d, input_wrangler=_avg_pool_input_wrangler, trace_only=True, - ).xfail( + ) + .xfail( matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) or (sample.kwargs.get("divisor_override") is not None), reason="ONNX doesn't support divisor_override argument", + ) + .xfail( + matcher=lambda sample: sample.kwargs.get("ceil_mode", True), + reason="ONNXRUNTIME doesn't match PyTorch when ceil_mode=True until opset 19", ), TorchLibOpInfo( "nn.functional.avg_pool2d", @@ -1419,10 +1424,15 @@ def _where_input_wrangler( nn_ops.aten_avg_pool3d, input_wrangler=_avg_pool_input_wrangler, trace_only=True, - ).xfail( + ) + .xfail( matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) or (sample.kwargs.get("divisor_override") is not None), reason="ONNX doesn't support divisor_override argument", + ) + .xfail( + matcher=lambda sample: sample.kwargs.get("ceil_mode", True), + reason="ONNXRUNTIME doesn't match PyTorch when ceil_mode=True until opset 19", ), TorchLibOpInfo( "nn.functional.conv1d", From a425eb901b81fab76445237f9dde5dc2d1881b4e Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Tue, 18 Jul 2023 23:04:59 +0000 Subject: [PATCH 8/8] update and have more specific xfail --- .../tests/function_libs/torch_lib/ops_test_data.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 88cb78bd7f..ffe3da3a23 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -1406,8 +1406,14 @@ def _where_input_wrangler( reason="ONNX doesn't support divisor_override argument", ) .xfail( - matcher=lambda sample: sample.kwargs.get("ceil_mode", True), - reason="ONNXRUNTIME doesn't match PyTorch when ceil_mode=True until opset 19", + matcher=lambda sample: (sample.kwargs.get("ceil_mode") is True) + and ( + sample.kwargs.get("count_include_pad") is True + or sample.input.shape[2] + % (sample.args[0][0] if isinstance(sample.args[0], tuple) else sample.args[0]) + != 0 + ), + reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19", ), TorchLibOpInfo( "nn.functional.avg_pool2d", @@ -1431,8 +1437,8 @@ def _where_input_wrangler( reason="ONNX doesn't support divisor_override argument", ) .xfail( - matcher=lambda sample: sample.kwargs.get("ceil_mode", True), - reason="ONNXRUNTIME doesn't match PyTorch when ceil_mode=True until opset 19", + matcher=lambda sample: sample.kwargs.get("ceil_mode") is True, + reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19", ), TorchLibOpInfo( "nn.functional.conv1d",