Enable max.unary_out (pytorch#85926)
Pull Request resolved: pytorch#85926
Approved by: https://github.com/bdhirsh
tugsbayasgalan authored and pytorchmergebot committed Oct 10, 2022
1 parent e18d466 commit 16a0fa1
Showing 6 changed files with 50 additions and 18 deletions.
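This commit re-enables the previously commented-out `max.unary_out` overload, so the full (no-dim) max reduction can write into a caller-provided tensor. A minimal usage sketch of what this is expected to enable from Python (the tensor values here are illustrative, not taken from the PR):

```python
import torch

x = torch.tensor([1.0, 5.0, 3.0])

# Preallocated output; the unary_out kernel resizes it to a 0-dim result.
out = torch.empty(())

# With max.unary_out registered, the out= form of the full reduction
# should dispatch to the new kernel.
torch.max(x, out=out)
print(out)  # expected: tensor(5.)
```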
13 changes: 10 additions & 3 deletions aten/src/ATen/native/ReduceAllOps.cpp
@@ -34,9 +34,16 @@ Tensor max(const Tensor &self) {
}

Tensor& max_unary_out(const Tensor &self, Tensor& out) {
-  Tensor tmp_output = at::max(self);
-  at::native::resize_output(out, tmp_output.sizes());
-  out.copy_(tmp_output);
+  // First check if the devices match (CPU vs GPU)
+  TORCH_CHECK(self.device() == out.device());
+
+  TORCH_CHECK(canCast(
+      typeMetaToScalarType(self.dtype()),
+      typeMetaToScalarType(out.dtype())));
+
+  at::native::resize_output(out, {});
+
+  max_all_stub(self.device().type(), out, self.contiguous());
return out;
}

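The rewritten max_unary_out above drops the allocate-then-copy approach: it checks that self and out live on the same device, checks that self's dtype can be cast to out's dtype, resizes out to a 0-dim tensor, and calls directly into max_all_stub. A small sketch of how those checks are expected to surface from Python (behavior inferred from the TORCH_CHECKs above, not from the PR's tests):

```python
import torch

x = torch.arange(6, dtype=torch.float32)

# Matching dtypes: out is resized to a 0-dim tensor holding the reduction result.
out = torch.empty(0, dtype=torch.float32)
torch.max(x, out=out)
print(out.shape, out.item())  # expected: torch.Size([]) 5.0

# float32 -> int32 is not an implicit cast that canCast allows,
# so the dtype check is expected to reject this call.
out_bad = torch.empty(0, dtype=torch.int32)
try:
    torch.max(x, out=out_bad)
except RuntimeError as err:
    print("rejected:", type(err).__name__)
```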
13 changes: 6 additions & 7 deletions aten/src/ATen/native/native_functions.yaml
@@ -8711,13 +8711,6 @@
MPS: max_mps
QuantizedCPU: max_quantized_cpu

-# Not to be confused with binary op `max.out`. Commented because of failed CI
-# FIXME: enable this
-#- func: max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
-#  device_check: NoCheck # TensorIterator
-#  dispatch:
-#    CompositeExplicitAutograd: max_unary_out

- func: fmax(Tensor self, Tensor other) -> Tensor
structured_delegate: fmax.out
device_check: NoCheck # TensorIterator
@@ -8752,6 +8745,12 @@
- func: max.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator

+- func: max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck # TensorIterator
+  dispatch:
+    CPU, CUDA: max_unary_out
+    QuantizedCPU: max_quantized_unary_out

- func: minimum(Tensor self, Tensor other) -> Tensor
structured_delegate: minimum.out
device_check: NoCheck # TensorIterator
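With the schema restored in native_functions.yaml and dispatch entries registered for CPU, CUDA, and QuantizedCPU, the overload is also reachable directly through the dispatcher. A hedged sketch of a direct call (argument names follow the schema above; the exact printed form is assumed):

```python
import torch

x = torch.tensor([2.0, 7.0, 4.0])
out = torch.empty(())

# aten::max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
# `out` is keyword-only, mirroring the `*` in the schema.
torch.ops.aten.max.unary_out(x, out=out)
print(out)  # expected: tensor(7.)
```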
13 changes: 13 additions & 0 deletions aten/src/ATen/native/quantized/TensorCompare.cpp
@@ -14,6 +14,19 @@ Tensor max_quantized_cpu(const Tensor& self) {
return std::get<0>(self.reshape({-1}).max(/*dim=*/0));
}

+Tensor& max_quantized_unary_out(const Tensor& self, Tensor& out) {
+  // TODO this implementation is inefficient for now.
+  TORCH_CHECK(self.device() == out.device());
+
+  TORCH_CHECK(canCast(
+      typeMetaToScalarType(self.dtype()),
+      typeMetaToScalarType(out.dtype())));
+  Tensor temp = max_quantized_cpu(self);
+  at::native::resize_output(out, temp.sizes());
+  out.copy_(temp);
+  return out;
+}

Tensor min_quantized_cpu(const Tensor& self) {
return std::get<0>(self.reshape({-1}).min(/*dim=*/0));
}
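For quantized CPU tensors there is no dedicated all-reduce kernel yet, so max_quantized_unary_out performs the same device/dtype checks, delegates the reduction to max_quantized_cpu, and copies the result into out (hence the TODO about inefficiency). In Python terms, the delegated reduction is roughly equivalent to the sketch below (quantization parameters are illustrative):

```python
import torch

x = torch.tensor([0.5, 3.0, 1.5])
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)

# max_quantized_cpu flattens the tensor and reduces along dim 0, i.e.
# std::get<0>(self.reshape({-1}).max(/*dim=*/0)) in the C++ above.
qmax = qx.reshape(-1).max(dim=0).values
print(qmax.dequantize())  # expected: tensor(3.)
```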
8 changes: 7 additions & 1 deletion torch/_meta_registrations.py
@@ -110,11 +110,17 @@ def meta_index_select_out(self, dim, index, out):
return out.copy_(torch.index_select(self, dim, index))


-@register_meta([aten.max.default, aten.min.default])
+@register_meta([aten.max.default, aten.max.unary_out])
@out_wrapper()
def meta_max(self):
return self.new_empty(())


+@register_meta([aten.min.default])
+def meta_min(self):
+    return self.new_empty(())


@register_meta(aten.angle.default)
def meta_angle(self):
if self.is_complex():
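Registering aten.max.unary_out alongside aten.max.default as a meta kernel lets shape inference run without real data: a full max reduction always produces a 0-dim tensor, which is exactly what self.new_empty(()) encodes. A small sketch of the expected effect on meta tensors, assuming the registration above is active:

```python
import torch

x = torch.empty(4, 5, device="meta")

# The meta kernel only propagates shape/dtype/device; no values are computed.
y = torch.max(x)
print(y.shape, y.device)  # expected: torch.Size([]) meta
```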
19 changes: 13 additions & 6 deletions torch/csrc/jit/frontend/schema_matching.cpp
@@ -324,16 +324,23 @@ static bool varargsCanBeUsedAsList(
!typevar_list;
}

-// Note (@zasdfgbnm):
-// This is a workaround for https://github.com/pytorch/pytorch/issues/47964
-// Currently JIT does not distinguish ScalarType vs int, so there is really
-// no way to distinguish x.view(1) vs x.view(torch.int8). So we have to hardcode
-// the aten::view.dtype here to block this overload. This blocklist should be
-// removed when JIT fully suports ScalarType as its own type.
bool isBlockListedSchema(const FunctionSchema& schema) {
+  // Note (@zasdfgbnm):
+  // This is a workaround for https://github.com/pytorch/pytorch/issues/47964
+  // Currently JIT does not distinguish ScalarType vs int, so there is really
+  // no way to distinguish x.view(1) vs x.view(torch.int8). So we have to
+  // hardcode the aten::view.dtype here to block this overload. This blocklist
+  // should be removed when JIT fully suports ScalarType as its own type.
if (schema.name() == "aten::view" && schema.overload_name() == "dtype") {
return true;
}
+  // Note (@tugsbayasgalan)
+  // TorchScript doesn't suport kwargs so this op collides with aten.max.others
+  // since both of them have 2 Tensor inputs. Since we don't expect users to
+  // use this op in TS, we just skip it
+  if (schema.name() == "aten::max" && schema.overload_name() == "unary_out") {
+    return true;
+  }
return false;
}

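Because TorchScript's schema matcher cannot rely on keyword-only arguments, aten::max.unary_out (self plus out) would collide with the binary two-tensor max overload during scripting, so it is blocklisted; eager-mode use of the new overload is unaffected. A hedged sketch of the behavior this preserves:

```python
import torch

@torch.jit.script
def pairwise_max(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # With max.unary_out skipped during schema matching, this still resolves
    # to the elementwise binary overload of torch.max.
    return torch.max(a, b)

print(pairwise_max(torch.tensor([1.0, 4.0]), torch.tensor([3.0, 2.0])))
# expected: tensor([3., 4.])
```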
2 changes: 1 addition & 1 deletion torch/testing/_internal/common_methods_invocations.py
@@ -10238,7 +10238,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
OpInfo('max',
variant_test_name='reduction_no_dim',
dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
-supports_out=False,
+supports_out=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_max_min_reduction_no_dim,
