From 115f56d5b2e2e84b6e3679c95ca40ce2858aae7c Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Thu, 1 Jun 2023 23:19:14 +0000 Subject: [PATCH 1/9] Add op (hstack) | feat (torchlib) --- .../function_libs/torch_lib/ops/core.py | 19 +++++++++++++++++-- .../function_libs/torch_lib/ops_test_data.py | 5 +++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index d9b576a3f6..cf71cc4495 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2811,10 +2811,25 @@ def aten_hspmm(mat1: TensorType, mat2: TensorType) -> TensorType: raise NotImplementedError() -def aten_hstack(tensors: Sequence[TensorType]) -> TensorType: +@torch_op("aten::hstack") +def aten_hstack(tensors: Sequence[TTensor]) -> TTensor: """hstack(Tensor[] tensors) -> Tensor""" - raise NotImplementedError() + # TODO: Due to lack of for loop, we couldn't use at::atleast_1d + # and examine the first tensor dim like torch did in their implementation. 
+ + # In PyTorch: + # Tensor hstack(TensorList tensors) { + # TORCH_CHECK(!tensors.empty(), + # "hstack expects a non-empty TensorList"); + # auto rep = at::atleast_1d(tensors); + # if (rep[0].dim() == 1) { + # return at::cat(rep, 0); + # } + # return at::cat(rep, 1); + # } + + return op.ConcatFromSequence(tensors, axis=-1, new_axis=0) def aten_hypot(self: TensorType, other: TensorType) -> TensorType: diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index fc9c1eda26..c401159612 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -382,6 +382,7 @@ def _where_input_wrangler( # "greater_equal": core_ops.aten_greater_equal, # no test case in OPS_DB # "greater": core_ops.aten_greater, # no test case in OPS_DB "gt": core_ops.aten_gt, + "hstack": core_ops.aten_hstack, # "is_same_size": core_ops.aten_is_same_size, # no test case in OPS_DB # "is_nonzero": core_ops.aten_is_nonzero, # no test case in OPS_DB "index_put_bool": core_ops.aten_index_put_bool, @@ -1647,6 +1648,10 @@ def _where_input_wrangler( torch.float32, torch.float16, ), + "hstack": ( + torch.float32, + torch.float16, + ), # "is_same_size": core_ops.aten_is_same_size, # no test case in OPS_DB # "is_nonzero": core_ops.aten_is_nonzero, # no test case in OPS_DB "index_put_bool": ( From 10a1c1adae62ec7fb2d922e7041423a0cff507b0 Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Fri, 2 Jun 2023 19:27:08 +0000 Subject: [PATCH 2/9] add logic to check first element --- onnxscript/function_libs/torch_lib/ops/core.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index cf71cc4495..900108516e 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2829,7 +2829,12 @@ def 
aten_hstack(tensors: Sequence[TTensor]) -> TTensor: # return at::cat(rep, 1); # } - return op.ConcatFromSequence(tensors, axis=-1, new_axis=0) + first_tensor = op.SequenceAt(tensors, 0) + if op.Size(op.Shape(first_tensor)) == 1: + result = op.ConcatFromSequence(tensors, axis=0, new_axis=0) + else: + result = op.ConcatFromSequence(tensors, axis=1, new_axis=0) + return result def aten_hypot(self: TensorType, other: TensorType) -> TensorType: From 5c311b50e50752180cb3a7fa86bda94dd2d16f81 Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Fri, 2 Jun 2023 19:58:27 +0000 Subject: [PATCH 3/9] add vstack --- onnxscript/function_libs/torch_lib/ops/core.py | 14 ++++++++++++-- .../function_libs/torch_lib/ops_test_common.py | 2 ++ .../tests/function_libs/torch_lib/ops_test_data.py | 10 ++++++++++ 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 900108516e..bd6cb8d12d 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -6552,10 +6552,20 @@ def aten_view_copy(self: TensorType, size: INT64) -> TensorType: raise NotImplementedError() -def aten_vstack(tensors: Sequence[TensorType]) -> TensorType: +@torch_op("aten::vstack") +def aten_vstack(tensors: Sequence[TTensor]) -> TTensor: """vstack(Tensor[] tensors) -> Tensor""" - raise NotImplementedError() + # TODO: Support at_least2d + + # Tensor vstack(TensorList tensors) { + # TORCH_CHECK(!tensors.empty(), + # "vstack expects a non-empty TensorList"); + # auto rep = at::atleast_2d(tensors); + # return at::cat(rep, 0); + # } + + return op.ConcatFromSequence(tensors, axis=0) @torch_op("aten::where") diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py index ba2b2f9695..4e776309be 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py +++ 
b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py @@ -485,6 +485,8 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args, else: onnxscript_kwargs[key] = value + # print("onnxscript_args", onnxscript_args) + with onnxscript.evaluator.default_as(tracer): symbolic_outputs = function(*onnxscript_args, **onnxscript_kwargs) if not isinstance(symbolic_outputs, Sequence): diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index c401159612..f3c81f6629 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -505,6 +505,7 @@ def _where_input_wrangler( "unflatten": (core_ops.aten_unflatten, _unflatten_input_wrangler), "unsqueeze": core_ops.aten_unsqueeze, "view": core_ops.aten_view, + "vstack": core_ops.aten_vstack, "where": (core_ops.aten_where, _where_input_wrangler), "xlogy": special_ops.aten_special_xlogy, "zeros": core_ops.aten_zeros, @@ -1152,6 +1153,11 @@ def _where_input_wrangler( matcher=lambda sample: len(sample.args) > 0 or "correction" not in sample.kwargs, reason="this Aten overload only support when correction attribute exists", ), + xfail( + "vstack", + matcher=lambda sample: len(sample.input[0].shape) < 2, + reason="fixme: Need aten::at_least2d supported", + ), ) ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim",)) @@ -2231,6 +2237,10 @@ def _where_input_wrangler( torch.float32, torch.float16, ), + "vstack": ( + torch.float32, + torch.float16, + ), "where": ( torch.float32, torch.float16, From 7a70e3241bfbbfcc26dca3b13e9403bd27b86594 Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Fri, 2 Jun 2023 19:59:53 +0000 Subject: [PATCH 4/9] remove print --- onnxscript/tests/function_libs/torch_lib/ops_test_common.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py 
b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py index 4e776309be..ba2b2f9695 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py @@ -485,8 +485,6 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args, else: onnxscript_kwargs[key] = value - # print("onnxscript_args", onnxscript_args) - with onnxscript.evaluator.default_as(tracer): symbolic_outputs = function(*onnxscript_args, **onnxscript_kwargs) if not isinstance(symbolic_outputs, Sequence): From 879ef75cf9dceb70f4c413fd17d6a7f4260c01c0 Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Mon, 5 Jun 2023 15:46:25 +0000 Subject: [PATCH 5/9] trace only the hstack --- onnxscript/function_libs/torch_lib/ops/core.py | 5 ++--- onnxscript/tests/function_libs/torch_lib/ops_test_data.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index bd6cb8d12d..cfbaa8345d 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2811,7 +2811,7 @@ def aten_hspmm(mat1: TensorType, mat2: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::hstack") +@torch_op("aten::hstack", trace_only=True) def aten_hstack(tensors: Sequence[TTensor]) -> TTensor: """hstack(Tensor[] tensors) -> Tensor""" @@ -2829,8 +2829,7 @@ def aten_hstack(tensors: Sequence[TTensor]) -> TTensor: # return at::cat(rep, 1); # } - first_tensor = op.SequenceAt(tensors, 0) - if op.Size(op.Shape(first_tensor)) == 1: + if len(tensors[0].shape) == 1: result = op.ConcatFromSequence(tensors, axis=0, new_axis=0) else: result = op.ConcatFromSequence(tensors, axis=1, new_axis=0) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index f3c81f6629..53b1e133c9 100644 --- 
a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -382,7 +382,6 @@ def _where_input_wrangler( # "greater_equal": core_ops.aten_greater_equal, # no test case in OPS_DB # "greater": core_ops.aten_greater, # no test case in OPS_DB "gt": core_ops.aten_gt, - "hstack": core_ops.aten_hstack, # "is_same_size": core_ops.aten_is_same_size, # no test case in OPS_DB # "is_nonzero": core_ops.aten_is_nonzero, # no test case in OPS_DB "index_put_bool": core_ops.aten_index_put_bool, @@ -529,6 +528,7 @@ def _where_input_wrangler( "convolution": core_ops.aten_convolution, "empty_like": core_ops.aten_empty_like, "grid_sampler_2d": core_ops.aten_grid_sampler_2d, + "hstack": core_ops.aten_hstack, "nn.functional.grid_sample": (core_ops.aten_grid_sampler, _grid_sample_input_wrangler), "index_select": core_ops.aten_index_select, "layer_norm": core_ops.aten_layer_norm, From 06b3de23b092e8a0b3928adc8d5601548c0421ab Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Wed, 7 Jun 2023 20:33:09 +0000 Subject: [PATCH 6/9] add atleast_Nd into hstack/vstack --- .../function_libs/torch_lib/ops/core.py | 37 ++++++++----------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index d79853ba01..7129092ee3 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2879,20 +2879,11 @@ def aten_hspmm(mat1: TensorType, mat2: TensorType) -> TensorType: def aten_hstack(tensors: Sequence[TTensor]) -> TTensor: """hstack(Tensor[] tensors) -> Tensor""" - # TODO: Due to lack of for loop, we couldn't use at::atleast_1d - # and examine the first tensor dim like torch did in their implementation. 
- - # In PyTorch: - # Tensor hstack(TensorList tensors) { - # TORCH_CHECK(!tensors.empty(), - # "hstack expects a non-empty TensorList"); - # auto rep = at::atleast_1d(tensors); - # if (rep[0].dim() == 1) { - # return at::cat(rep, 0); - # } - # return at::cat(rep, 1); - # } + # Use another onnx function + tensors = aten_atleast_1d(tensors) + # NOTE: The if/else graph has different shape/type which breaks the + # graph matching. We need to use trace_only. if len(tensors[0].shape) == 1: result = op.ConcatFromSequence(tensors, axis=0, new_axis=0) else: @@ -6619,16 +6610,18 @@ def aten_view_copy(self: TensorType, size: INT64) -> TensorType: def aten_vstack(tensors: Sequence[TTensor]) -> TTensor: """vstack(Tensor[] tensors) -> Tensor""" - # TODO: Support at_least2d - - # Tensor vstack(TensorList tensors) { - # TORCH_CHECK(!tensors.empty(), - # "vstack expects a non-empty TensorList"); - # auto rep = at::atleast_2d(tensors); - # return at::cat(rep, 0); - # } + # TODO: This is exactly from aten_atleast_2d + @graph() + def reshape_to_2d(tensor): + shape = op.Shape(tensor) + rank = op.Size(shape) + if rank <= 1: + tensor = op.Reshape(tensor, op.Constant(value_ints=[1, -1])) + return tensor + + new_tensors = op.SequenceMap(tensors, body=reshape_to_2d) - return op.ConcatFromSequence(tensors, axis=0) + return op.ConcatFromSequence(new_tensors, axis=0) @torch_op("aten::where") From 0480fbd8f0926c2ba6d0695bdcc8d214cd3184d4 Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Wed, 7 Jun 2023 20:43:13 +0000 Subject: [PATCH 7/9] add atleast_Nd --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 7129092ee3..fc4957f07e 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -6618,7 +6618,7 @@ def reshape_to_2d(tensor): if rank <= 1: tensor = op.Reshape(tensor, 
op.Constant(value_ints=[1, -1])) return tensor - + new_tensors = op.SequenceMap(tensors, body=reshape_to_2d) return op.ConcatFromSequence(new_tensors, axis=0) From 2d25896ff006f70a75008a3d1b143e90c7737f08 Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Thu, 8 Jun 2023 17:09:58 +0000 Subject: [PATCH 8/9] add xfail and shared func --- .../function_libs/torch_lib/ops/core.py | 34 +++++++++---------- .../function_libs/torch_lib/ops_test_data.py | 13 ++++--- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index fc4957f07e..1912c48ff8 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -666,8 +666,8 @@ def aten_atanh(self: TFloat) -> TFloat: return op.Atanh(self) -@torch_op("aten::atleast_1d") -def aten_atleast_1d(self: Sequence[TTensor]) -> TTensor: +@torch_op("aten::atleast_1d", private=True) +def _aten_atleast_1d_onnx(self: Sequence[TTensor]) -> TTensor: """atleast_1d(Tensor self) -> Tensor""" @graph() @@ -681,6 +681,11 @@ def reshape_to_1d(tensor): return op.SequenceMap(self, body=reshape_to_1d) +@torch_op("aten::atleast_1d", trace_only=True) +def aten_atleast_1d(self: Sequence[TTensor]) -> TTensor: + return _aten_atleast_1d_onnx(self) + + @torch_op("aten::atleast_1d") def aten_atleast_1d_single_tensor(self: TTensor) -> TTensor: """atleast_1d(Tensor self) -> Tensor""" @@ -692,8 +697,8 @@ def aten_atleast_1d_single_tensor(self: TTensor) -> TTensor: return self -@torch_op("aten::atleast_2d") -def aten_atleast_2d(self: Sequence[TTensor]) -> TTensor: +@torch_op("aten::atleast_2d", private=True) +def _aten_atleast_2d_onnx(self: Sequence[TTensor]) -> TTensor: """atleast_2d(Tensor self) -> Tensor""" @graph() @@ -707,6 +712,11 @@ def reshape_to_2d(tensor): return op.SequenceMap(self, body=reshape_to_2d) +@torch_op("aten::atleast_2d", trace_only=True) +def aten_atleast_2d(self: Sequence[TTensor]) -> 
TTensor: + return _aten_atleast_2d_onnx(self) + + @torch_op("aten::atleast_2d") def aten_atleast_2d_single_tensor(self: TTensor) -> TTensor: """atleast_2d(Tensor self) -> Tensor""" @@ -2880,7 +2890,7 @@ def aten_hstack(tensors: Sequence[TTensor]) -> TTensor: """hstack(Tensor[] tensors) -> Tensor""" # Use another onnx function - tensors = aten_atleast_1d(tensors) + tensors = _aten_atleast_1d_onnx(tensors) # NOTE: The if/else graph has different shape/type which breaks the # graph matching. We need to use trace_only. @@ -6610,18 +6620,8 @@ def aten_view_copy(self: TensorType, size: INT64) -> TensorType: def aten_vstack(tensors: Sequence[TTensor]) -> TTensor: """vstack(Tensor[] tensors) -> Tensor""" - # TODO: This is exactly from aten_atleast_2d - @graph() - def reshape_to_2d(tensor): - shape = op.Shape(tensor) - rank = op.Size(shape) - if rank <= 1: - tensor = op.Reshape(tensor, op.Constant(value_ints=[1, -1])) - return tensor - - new_tensors = op.SequenceMap(tensors, body=reshape_to_2d) - - return op.ConcatFromSequence(new_tensors, axis=0) + tensors = _aten_atleast_2d_onnx(tensors) + return op.ConcatFromSequence(tensors, axis=0) @torch_op("aten::where") diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index e15d128484..6280c8c526 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -622,6 +622,10 @@ def _where_input_wrangler( variant_name="partial_views", reason="ONNX doesn't have partial view for tensor", ), + xfail( + "hstack", + reason="fixme: A bug of constant-propagation optimization within the subgraph, we can avoid it by turning off graph-optimizations in session options", + ), xfail("logcumsumexp", reason="naive implementation not numerically stable"), xfail( "max", @@ -762,6 +766,10 @@ def _where_input_wrangler( reason="fixme: ORT fails with invalid model: 'INVALID_ARGUMENT : Failed to load 
model with error: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1)'", test_class_name="TestOutputConsistencyFullGraph", ), + xfail( + "vstack", + reason="fixme: A bug of constant-propagation optimization within the subgraph, we can avoid it by turning off graph-optimizations in session options", + ), ) @@ -1174,11 +1182,6 @@ def _where_input_wrangler( matcher=lambda sample: len(sample.args) > 0 or "correction" not in sample.kwargs, reason="this Aten overload only support when correction attribute exists", ), - xfail( - "vstack", - matcher=lambda sample: len(sample.input[0].shape) < 2, - reason="fixme: Need aten::at_least2d supported", - ), ) ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim",)) From a24e6208940e6b59c439b22c2963759f8567c1ce Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Thu, 8 Jun 2023 18:50:30 +0000 Subject: [PATCH 9/9] remove trace only --- onnxscript/function_libs/torch_lib/ops/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 1912c48ff8..17e86897ef 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -681,7 +681,7 @@ def reshape_to_1d(tensor): return op.SequenceMap(self, body=reshape_to_1d) -@torch_op("aten::atleast_1d", trace_only=True) +@torch_op("aten::atleast_1d") def aten_atleast_1d(self: Sequence[TTensor]) -> TTensor: return _aten_atleast_1d_onnx(self) @@ -712,7 +712,7 @@ def reshape_to_2d(tensor): return op.SequenceMap(self, body=reshape_to_2d) -@torch_op("aten::atleast_2d", trace_only=True) +@torch_op("aten::atleast_2d") def aten_atleast_2d(self: Sequence[TTensor]) -> TTensor: return _aten_atleast_2d_onnx(self)