13 changes: 0 additions & 13 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -812,19 +812,6 @@ def aten_atleast_3d_single_tensor(self: TTensor) -> TTensor:
return self


def aten_avg_pool1d(
self: TensorType,
kernel_size: Sequence[int],
stride: Optional[Sequence[int]] = None,
padding: Sequence[int] = (0,),
ceil_mode: bool = False,
count_include_pad: bool = True,
) -> TensorType:
"""avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor"""

raise NotImplementedError()


@torch_op("aten::baddbmm")
def aten_baddbmm(
self: TRealOrUInt8,
133 changes: 104 additions & 29 deletions onnxscript/function_libs/torch_lib/ops/nn.py
@@ -127,6 +127,70 @@ def aten_adaptive_max_pool3d_backward(
raise NotImplementedError()


def _adjust_attributes_of_avg_pool(
expand_size: int,
kernel_size: Sequence[int],
stride: Sequence[int],
padding: Sequence[int],
) -> Tuple[Sequence[int], Sequence[int], Sequence[int]]:
"""Adjust attributes of avg_pool to match ONNX specification."""

if isinstance(kernel_size, int):
kernel_shape = [kernel_size] * expand_size
else:
kernel_shape = kernel_size

if isinstance(padding, int):
pads = [padding] * expand_size * 2
elif len(padding) == 1:
pads = padding * expand_size * 2
elif len(padding) == 2:
pads = padding * expand_size
else:
pads = padding * 2

if isinstance(stride, int):
strides = [stride] * expand_size
elif not stride:
strides = kernel_shape
else:
strides = stride

return (kernel_shape, strides, pads)
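
# A quick sketch (hypothetical values) of the expansion rules above, for
# expand_size=2 (2D pooling):
#   _adjust_attributes_of_avg_pool(2, 3, (), 1)
#     -> ([3, 3], [3, 3], [1, 1, 1, 1])   # empty stride falls back to kernel_shape
#   _adjust_attributes_of_avg_pool(2, [2, 2], [1, 1], [0, 1])
#     -> ([2, 2], [1, 1], [0, 1, 0, 1])   # ONNX pads order: [x1_begin, x2_begin, x1_end, x2_end]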


@torch_op("aten::avg_pool1d", trace_only=True)
def aten_avg_pool1d(
self: TFloat,
kernel_size: Sequence[int],
stride: Sequence[int] = (),
padding: Sequence[int] = (0,),
ceil_mode: bool = False,
count_include_pad: bool = True,
) -> TFloat:
"""avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor"""

# Torch prefers a single number x for kernel, stride, pad, and dilation, applied implicitly to both sides,
# but ONNX needs an explicit pair [x, y] for each side.
# For pool1d, this number is 1.
expand_size = 1

kernel_shape, strides, pads = _adjust_attributes_of_avg_pool(
expand_size, kernel_size, stride, padding
)
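# e.g. (hypothetical) kernel_size=[2], stride=(), padding=[1] becomes
# kernel_shape=[2], strides=[2] (empty stride falls back to the kernel), pads=[1, 1]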

result = op.AveragePool(
self,
ceil_mode=ceil_mode,
count_include_pad=count_include_pad,
kernel_shape=kernel_shape,
pads=pads,
strides=strides,
)

return result


@torch_op("aten::avg_pool2d", trace_only=True)
def aten_avg_pool2d(
self: TFloat,
@@ -144,30 +208,9 @@ def aten_avg_pool2d(
# For pool2d, this number is 2
expand_size = 2

# The kernel_shape should be [x, y]
if isinstance(kernel_size, int): # x -> [x, x]
kernel_shape = [kernel_size] * expand_size
else: # assert(len(kernel_size)==2), already [x, y]
kernel_shape = kernel_size

# The pads should be [w, x, y, z]
if isinstance(padding, int): # w -> [w, w, w, w]
pads = [padding] * expand_size * 2
elif len(padding) == 1: # [w] -> [w, w, w, w]
pads = padding * 4
elif len(padding) == 2: # [w, x] -> [w, x, w, x]
pads = padding * 2
else: # assert len(padding) == 4, already [w, x, y, z]
pads = padding

# The strides should be [x, y]
if isinstance(stride, int): # x -> [x, x]
strides = [stride] * expand_size
elif not stride:
# stride is empty
strides = kernel_shape
else:
strides = stride
kernel_shape, strides, pads = _adjust_attributes_of_avg_pool(
expand_size, kernel_size, stride, padding
)

result = op.AveragePool(
self,
@@ -209,18 +252,50 @@ def aten_avg_pool2d_backward(
raise NotImplementedError()


@torch_op("aten::avg_pool3d", trace_only=True)
def aten_avg_pool3d(
self: TensorType,
self: TFloat,
kernel_size: Sequence[int],
stride: Optional[Sequence[int]] = None,
stride: Sequence[int] = (),
padding: Sequence[int] = (0, 0, 0),
ceil_mode: bool = False,
count_include_pad: bool = True,
divisor_override: Optional[int] = None,
) -> TensorType:
divisor_override: Optional[int] = None, # pylint: disable=unused-argument
) -> TFloat:
"""avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor"""

raise NotImplementedError()
# Torch prefers a single number x for kernel, stride, pad, and dilation, applied implicitly to both sides,
# but ONNX needs an explicit pair [x, y] for each side.
# For pool3d, this number is 3.
expand_size = 3

kernel_shape, strides, pads = _adjust_attributes_of_avg_pool(
expand_size, kernel_size, stride, padding
)

result = op.AveragePool(
self,
kernel_shape=kernel_shape,
strides=strides,
pads=pads,
count_include_pad=count_include_pad,
ceil_mode=ceil_mode,
)

# TODO: to support the divisor_override argument, the result would need to be
# rescaled, e.g. via op.Mul(result, mask), where the mask encodes the
# per-position count adjustment:
# mask = [
#     1,  2,  3,  S, ...,  3,  2,  1,
#     2,  4,  6, 2S, ...,  6,  4,  2,
#     3,  6,  9, 3S, ...,  9,  6,  3,
#     S, 2S, 3S, SS, ..., 3S, 2S,  S,
#     3,  6,  9, 3S, ...,  9,  6,  3,
#     2,  4,  6, 2S, ...,  6,  4,  2,
#     1,  2,  3,  S, ...,  3,  2,  1,
# ]
# where S is the stride size (S=4 in this example); S may repeat many times
# depending on the image size.
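# One possible sketch, assuming count_include_pad=True and ceil_mode=False:
# AveragePool divides each window sum by the kernel volume, so
# sum / divisor_override could be recovered by rescaling the result:
#   scale = math.prod(kernel_shape) / divisor_override
#   result = op.Mul(result, op.CastLike(op.Constant(value_float=scale), result))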

return result


def aten_avg_pool3d_backward(
52 changes: 45 additions & 7 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -181,7 +181,7 @@ def _amin_amax_input_wrangler(
return args, kwargs


def _avg_pool2d_input_wrangler(
def _avg_pool_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "dim" not in kwargs:
@@ -197,10 +197,11 @@ def _avg_pool2d_input_wrangler(
# Cannot use list(padding) here, because the elements would be numpy.int64 instead of int
padding = padding.tolist()
kwargs["padding"] = padding
stride = args.pop(2)
if isinstance(stride, np.ndarray):
stride = stride.tolist()
kwargs["stride"] = stride
if len(args) > 2:
stride = args.pop(2)
if isinstance(stride, np.ndarray):
stride = stride.tolist()
kwargs["stride"] = stride
kernel_size = args.pop(1)
if isinstance(kernel_size, np.ndarray):
kernel_size = kernel_size.tolist()
@@ -1393,15 +1394,52 @@ def _where_input_wrangler(
trace_only=True,
tolerance={torch.float32: (3.7e-5, 1.8e-4)},
),
TorchLibOpInfo(
"nn.functional.avg_pool1d",
nn_ops.aten_avg_pool1d,
input_wrangler=_avg_pool_input_wrangler,
trace_only=True,
)
.xfail(
matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None)
or (sample.kwargs.get("divisor_override") is not None),
reason="ONNX doesn't support divisor_override argument",
)
.xfail(
matcher=lambda sample: (sample.kwargs.get("ceil_mode") is True)
and (
sample.kwargs.get("count_include_pad") is True
or sample.input.shape[2]
% (sample.args[0][0] if isinstance(sample.args[0], tuple) else sample.args[0])
!= 0
),
reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19",
),
TorchLibOpInfo(
"nn.functional.avg_pool2d",
nn_ops.aten_avg_pool2d,
input_wrangler=_avg_pool2d_input_wrangler,
input_wrangler=_avg_pool_input_wrangler,
trace_only=True,
).xfail(
matcher=lambda sample: len(sample.args) > 5 and sample.args[5] is not None,
matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None)
or (sample.kwargs.get("divisor_override") is not None),
reason="ONNX doesn't support divisor_override argument",
),
TorchLibOpInfo(
"nn.functional.avg_pool3d",
nn_ops.aten_avg_pool3d,
input_wrangler=_avg_pool_input_wrangler,
trace_only=True,
)
.xfail(
matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None)
or (sample.kwargs.get("divisor_override") is not None),
reason="ONNX doesn't support divisor_override argument",
)
.xfail(
matcher=lambda sample: sample.kwargs.get("ceil_mode") is True,
reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19",
),
TorchLibOpInfo(
"nn.functional.conv1d",
core_ops.aten_conv1d,