From 16a0fa1204edb118800261a26281e624988eb239 Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Fri, 7 Oct 2022 13:37:02 -0700 Subject: [PATCH] Enable max.unary_out (#85926) Pull Request resolved: https://github.com/pytorch/pytorch/pull/85926 Approved by: https://github.com/bdhirsh --- aten/src/ATen/native/ReduceAllOps.cpp | 13 ++++++++++--- aten/src/ATen/native/native_functions.yaml | 13 ++++++------- .../ATen/native/quantized/TensorCompare.cpp | 13 +++++++++++++ torch/_meta_registrations.py | 8 +++++++- torch/csrc/jit/frontend/schema_matching.cpp | 19 +++++++++++++------ .../_internal/common_methods_invocations.py | 2 +- 6 files changed, 50 insertions(+), 18 deletions(-) diff --git a/aten/src/ATen/native/ReduceAllOps.cpp b/aten/src/ATen/native/ReduceAllOps.cpp index 31764734b67ab..1ef5e9b93733c 100644 --- a/aten/src/ATen/native/ReduceAllOps.cpp +++ b/aten/src/ATen/native/ReduceAllOps.cpp @@ -34,9 +34,16 @@ Tensor max(const Tensor &self) { } Tensor& max_unary_out(const Tensor &self, Tensor& out) { - Tensor tmp_output = at::max(self); - at::native::resize_output(out, tmp_output.sizes()); - out.copy_(tmp_output); + // First check if the devices match (CPU vs GPU) + TORCH_CHECK(self.device() == out.device()); + + TORCH_CHECK(canCast( + typeMetaToScalarType(self.dtype()), + typeMetaToScalarType(out.dtype()))); + + at::native::resize_output(out, {}); + + max_all_stub(self.device().type(), out, self.contiguous()); return out; } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 3994f4eb6a92b..ec16861d0c104 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -8711,13 +8711,6 @@ MPS: max_mps QuantizedCPU: max_quantized_cpu -# Not to be confused with binary op `max.out`. Commented because of failed CI -# FIXME: enable this -#- func: max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) -# device_check: NoCheck # TensorIterator -# dispatch: -# CompositeExplicitAutograd: max_unary_out - - func: fmax(Tensor self, Tensor other) -> Tensor structured_delegate: fmax.out device_check: NoCheck # TensorIterator @@ -8752,6 +8745,12 @@ - func: max.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator +- func: max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: max_unary_out + QuantizedCPU: max_quantized_unary_out + - func: minimum(Tensor self, Tensor other) -> Tensor structured_delegate: minimum.out device_check: NoCheck # TensorIterator diff --git a/aten/src/ATen/native/quantized/TensorCompare.cpp b/aten/src/ATen/native/quantized/TensorCompare.cpp index 08a104257f4eb..747f8bfe4d301 100644 --- a/aten/src/ATen/native/quantized/TensorCompare.cpp +++ b/aten/src/ATen/native/quantized/TensorCompare.cpp @@ -14,6 +14,19 @@ Tensor max_quantized_cpu(const Tensor& self) { return std::get<0>(self.reshape({-1}).max(/*dim=*/0)); } +Tensor& max_quantized_unary_out(const Tensor& self, Tensor& out) { + // TODO this implementation is inefficient for now. + TORCH_CHECK(self.device() == out.device()); + + TORCH_CHECK(canCast( + typeMetaToScalarType(self.dtype()), + typeMetaToScalarType(out.dtype()))); + Tensor temp = max_quantized_cpu(self); + at::native::resize_output(out, temp.sizes()); + out.copy_(temp); + return out; +} + Tensor min_quantized_cpu(const Tensor& self) { return std::get<0>(self.reshape({-1}).min(/*dim=*/0)); } diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index ee500d74171c5..01bc58aaa2961 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -110,11 +110,17 @@ def meta_index_select_out(self, dim, index, out): return out.copy_(torch.index_select(self, dim, index)) -@register_meta([aten.max.default, aten.min.default]) +@register_meta([aten.max.default, aten.max.unary_out]) +@out_wrapper() def meta_max(self): return self.new_empty(()) +@register_meta([aten.min.default]) +def meta_min(self): + return self.new_empty(()) + + @register_meta(aten.angle.default) def meta_angle(self): if self.is_complex(): diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp index b5e4c395672f3..0315d489fab5a 100644 --- a/torch/csrc/jit/frontend/schema_matching.cpp +++ b/torch/csrc/jit/frontend/schema_matching.cpp @@ -324,16 +324,23 @@ static bool varargsCanBeUsedAsList( !typevar_list; } -// Note (@zasdfgbnm): -// This is a workaround for https://github.com/pytorch/pytorch/issues/47964 -// Currently JIT does not distinguish ScalarType vs int, so there is really -// no way to distinguish x.view(1) vs x.view(torch.int8). So we have to hardcode -// the aten::view.dtype here to block this overload. This blocklist should be -// removed when JIT fully suports ScalarType as its own type. bool isBlockListedSchema(const FunctionSchema& schema) { + // Note (@zasdfgbnm): + // This is a workaround for https://github.com/pytorch/pytorch/issues/47964 + // Currently JIT does not distinguish ScalarType vs int, so there is really + // no way to distinguish x.view(1) vs x.view(torch.int8). So we have to + // hardcode the aten::view.dtype here to block this overload. This blocklist + // should be removed when JIT fully suports ScalarType as its own type. if (schema.name() == "aten::view" && schema.overload_name() == "dtype") { return true; } + // Note (@tugsbayasgalan) + // TorchScript doesn't suport kwargs so this op collides with aten.max.others + // since both of them have 2 Tensor inputs. Since we don't expect users to + // use this op in TS, we just skip it + if (schema.name() == "aten::max" && schema.overload_name() == "unary_out") { + return true; + } return false; } diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 9bb79538e31f3..93386800fe893 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -10238,7 +10238,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): OpInfo('max', variant_test_name='reduction_no_dim', dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), - supports_out=False, + supports_out=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_max_min_reduction_no_dim,