Port threshold to structure (pytorch#57810)
Summary:
Related: pytorch#55070
Port threshold and threshold_backward to the structured kernels mechanism.

Pull Request resolved: pytorch#57810

Reviewed By: agolynski

Differential Revision: D28382716

Pulled By: ezyang

fbshipit-source-id: 8d0702ad074b52e8512524d9807c93bfe04c51d6
liuyuanqiang@bytedance authored and Kushashwa Shrimali committed May 18, 2021
1 parent 4082e13 commit 5122b9b
Showing 6 changed files with 54 additions and 93 deletions.
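
Note (not part of the change): every kernel touched below implements the same elementwise rule, documented in the code comments as `result = self <= threshold ? value : other`, where `other` is `self` in threshold() and the incoming gradient in threshold_backward(). A standalone C++ sketch of that rule, using a made-up helper name (`threshold_ref`) and plain std::vector inputs:

```cpp
// Reference sketch only; threshold_ref is an illustrative name, not part of ATen.
#include <cstddef>
#include <cstdio>
#include <vector>

std::vector<float> threshold_ref(const std::vector<float>& self,
                                 const std::vector<float>& other,
                                 float threshold, float value) {
  std::vector<float> result(self.size());
  for (std::size_t i = 0; i < self.size(); ++i) {
    // result = self <= threshold ? value : other
    result[i] = self[i] <= threshold ? value : other[i];
  }
  return result;
}

int main() {
  std::vector<float> x = {-1.0f, 0.5f, 2.0f};
  std::vector<float> grad = {1.0f, 1.0f, 1.0f};
  // forward: other == self, so entries <= 0 are replaced by the fill value
  auto y = threshold_ref(x, x, /*threshold=*/0.0f, /*value=*/0.0f);
  // backward: other == grad, so the gradient is zeroed where self <= threshold
  auto gx = threshold_ref(x, grad, /*threshold=*/0.0f, /*value=*/0.0f);
  for (std::size_t i = 0; i < x.size(); ++i) {
    std::printf("x=%g y=%g dx=%g\n", x[i], y[i], gx[i]);
  }
  return 0;
}
```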
74 changes: 36 additions & 38 deletions aten/src/ATen/native/Activation.cpp
@@ -13,7 +13,38 @@

#include <c10/util/irange.h>

namespace at { namespace native {
namespace at {
namespace meta {
// computes `result = self <= threshold ? value : other`
// other is `self` in threshold() and `grad` in threshold_backward()
TORCH_META_FUNC(threshold)(const Tensor& self, const Scalar& threshold, const Scalar& value) {
const Tensor& result = maybe_get_output();
build(TensorIteratorConfig()
.set_check_mem_overlap(false) // threshold is idempotent, so overlap is okay
.add_output(result)
.add_input(self)
.add_input(self) // other
.allow_cpu_scalars(true)
.promote_inputs_to_common_dtype(true)
.cast_common_dtype_to_outputs(true)
.enforce_safe_casting_to_output(true));
}
// computes `result = self <= threshold ? value : other`
// other is `self` in threshold() and `grad` in threshold_backward()
TORCH_META_FUNC(threshold_backward)(const Tensor& grad, const Tensor& self, const Scalar& threshold) {
const Tensor& gradInput = maybe_get_output();
build(TensorIteratorConfig()
.set_check_mem_overlap(false) // threshold is idempotent, so overlap is okay
.add_output(gradInput)
.add_input(self)
.add_input(grad) // other
.allow_cpu_scalars(true)
.promote_inputs_to_common_dtype(true)
.cast_common_dtype_to_outputs(true)
.enforce_safe_casting_to_output(true));
}
} // namespace meta
namespace native {

static const double SELU_ALPHA = 1.6732632423543772848170429916717;
static const double SELU_SCALE = 1.0507009873554804934193349852946;
@@ -406,45 +437,12 @@ Tensor softplus_backward(
return iter.output();
}

// computes `result = self <= threshold ? value : other`
// other is `self` in threshold() and `grad` in threshold_backward()
static Tensor threshold_out(
optional<Tensor> opt_result,
const Tensor& self,
const Scalar& threshold,
const Scalar& value,
const Tensor& other) {
Tensor result = opt_result.value_or(Tensor());
auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false) // threshold is idempotent, so overlap is okay
.add_output(result)
.add_input(self)
.add_input(other)
.allow_cpu_scalars(true)
.promote_inputs_to_common_dtype(true)
.cast_common_dtype_to_outputs(true)
.enforce_safe_casting_to_output(true)
.build();
threshold_stub(iter.device_type(), iter, threshold, value);
return iter.output();
}

Tensor threshold(const Tensor& self, const Scalar& threshold, const Scalar& value) {
return threshold_out(nullopt, self, threshold, value, self);
}

Tensor& threshold_(Tensor& self, const Scalar& threshold, const Scalar& value) {
threshold_out(make_optional(self), self, threshold, value, self);
return self;
}

Tensor& threshold_out(const Tensor& self, const Scalar& threshold, const Scalar& value, Tensor& result) {
threshold_out(make_optional(result), self, threshold, value, self);
return result;
TORCH_IMPL_FUNC(threshold_out)(const Tensor& self, const Scalar& threshold, const Scalar& value, const Tensor& result) {
threshold_stub(device_type(), *this, threshold, value);
}

Tensor threshold_backward(const Tensor& grad, const Tensor& self, const Scalar& threshold) {
return threshold_out(nullopt, self, threshold, 0, grad);
TORCH_IMPL_FUNC(threshold_backward_out)(const Tensor& grad, const Tensor& self, const Scalar& threshold, const Tensor& gradInput) {
threshold_stub(device_type(), *this, threshold, 0);
}

// -----------------------------------
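
Note (not part of the change): the port above splits each op into a TORCH_META_FUNC, which only builds the TensorIterator (shape/dtype checks and output allocation), and a TORCH_IMPL_FUNC, which simply calls threshold_stub on *this. Because the .out overload is declared with `structured_inherits: TensorIteratorBase` (see native_functions.yaml further down), the generated op class is itself the iterator that the meta function configured. A toy analogue of that split, with invented class names rather than ATen's real machinery:

```cpp
// Toy analogue only; ToyIteratorBase/ToyThresholdOp are invented for illustration
// and are not the classes produced by the structured-kernels codegen.
#include <cstddef>
#include <cstdio>
#include <vector>

struct ToyIteratorBase {
  std::vector<float> self, other, out;  // stand-ins for the configured operands
};

struct ToyThresholdOp : ToyIteratorBase {
  // "meta": configure operands and size the output; no arithmetic happens here
  void meta(const std::vector<float>& s, const std::vector<float>& o) {
    self = s;
    other = o;
    out.assign(s.size(), 0.0f);
  }
  // "impl": the kernel runs directly on *this, the iterator that meta() built
  void impl(float threshold, float value) {
    for (std::size_t i = 0; i < self.size(); ++i) {
      out[i] = self[i] <= threshold ? value : other[i];
    }
  }
};

int main() {
  std::vector<float> x = {-2.0f, 3.0f};
  ToyThresholdOp op;
  op.meta(x, x);        // forward pass: other == self
  op.impl(0.0f, 0.0f);
  std::printf("%g %g\n", op.out[0], op.out[1]);
  return 0;
}
```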
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Activation.h
@@ -14,7 +14,7 @@ using activation_fn = void (*)(TensorIterator&);
using activation_backward_fn = void (*)(TensorIterator&);
using softplus_fn = void (*)(TensorIterator&, const Scalar&, const Scalar&);
using softplus_backward_fn = void (*)(TensorIterator&, const Scalar&, const Scalar&);
using threshold_fn = void (*)(TensorIterator&, const Scalar&, const Scalar&);
using threshold_fn = void (*)(TensorIteratorBase&, const Scalar&, const Scalar&);
using hardtanh_backward_fn = void (*)(TensorIterator&, const Scalar&, const Scalar&);
using hardsigmoid_fn = void(*)(TensorIterator&);
using hardsigmoid_backward_fn = void(*)(TensorIterator&);
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/Activation.cpp
@@ -84,7 +84,7 @@ static void log_sigmoid_backward_cpu_kernel(TensorIterator& iter) {
}

static void threshold_kernel(
TensorIterator& iter,
TensorIteratorBase& iter,
const Scalar& threshold_scalar,
const Scalar& value_scalar) {
AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.dtype(), "threshold_cpu", [&] {
46 changes: 3 additions & 43 deletions aten/src/ATen/native/cuda/Activation.cu
@@ -305,13 +305,13 @@ void softplus_backward_kernel(TensorIterator& iter, const Scalar& beta_, const S
}

template <typename scalar_t>
void threshold_kernel_impl(TensorIterator& iter, scalar_t threshold, scalar_t value) {
void threshold_kernel_impl(TensorIteratorBase& iter, scalar_t threshold, scalar_t value) {
gpu_kernel_with_scalars(iter, [=]GPU_LAMBDA(scalar_t x, scalar_t other) -> scalar_t {
return x <= threshold ? value : other;
});
}

static void threshold_kernel(TensorIterator& iter, const Scalar& threshold, const Scalar& value) {
static void threshold_kernel_cuda(TensorIteratorBase& iter, const Scalar& threshold, const Scalar& value) {
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "threshold_cuda", [&] {
threshold_kernel_impl<scalar_t>(iter, threshold.to<scalar_t>(), value.to<scalar_t>());
});
@@ -524,47 +524,6 @@ Tensor gelu_backward_cuda(const Tensor& grad, const Tensor& self) {
return dX;
}

// computes `result = self <= threshold ? value : other`
// other is `self` in threshold() and `grad` in threshold_backward()
static Tensor threshold_out_cuda(
optional<Tensor> opt_result,
const Tensor& self,
const Scalar& threshold,
const Scalar& value,
const Tensor& other) {
Tensor result = opt_result.value_or(Tensor());
auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false) // threshold is idempotent, so overlap is okay
.add_output(result)
.add_input(self)
.add_input(other)
.allow_cpu_scalars(true)
.promote_inputs_to_common_dtype(true)
.cast_common_dtype_to_outputs(true)
.enforce_safe_casting_to_output(true)
.build();
threshold_kernel(iter, threshold, value);
return iter.output();
}

Tensor threshold_cuda(const Tensor& self, const Scalar& threshold, const Scalar& value) {
return threshold_out_cuda(nullopt, self, threshold, value, self);
}

Tensor& threshold__cuda(Tensor& self, const Scalar& threshold, const Scalar& value) {
threshold_out_cuda(make_optional(self), self, threshold, value, self);
return self;
}

Tensor& threshold_out_cuda(const Tensor& self, const Scalar& threshold, const Scalar& value, Tensor& result) {
threshold_out_cuda(make_optional(result), self, threshold, value, self);
return result;
}

Tensor threshold_backward_cuda(const Tensor& grad, const Tensor& self, const Scalar& threshold) {
return threshold_out_cuda(nullopt, self, threshold, 0, grad);
}

REGISTER_DISPATCH(hardtanh_backward_stub, &hardtanh_backward_kernel);
REGISTER_DISPATCH(hardshrink_stub, &hardshrink_kernel);
REGISTER_DISPATCH(softshrink_stub, &softshrink_kernel);
@@ -581,6 +540,7 @@ REGISTER_DISPATCH(softplus_stub, &softplus_kernel);
REGISTER_DISPATCH(softplus_backward_stub, &softplus_backward_kernel);
REGISTER_DISPATCH(silu_stub, &silu_kernel);
REGISTER_DISPATCH(silu_backward_stub, &silu_backward_kernel);
REGISTER_DISPATCH(threshold_stub, &threshold_kernel_cuda);

} // namespace native
} // namespace at
21 changes: 12 additions & 9 deletions aten/src/ATen/native/native_functions.yaml
@@ -4035,29 +4035,32 @@
- func: threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor
device_check: NoCheck # TensorIterator
variants: function
structured_delegate: threshold.out
dispatch:
CPU: threshold
CUDA: threshold_cuda
QuantizedCPU: threshold_quantized_cpu

- func: threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: function
dispatch:
CPU: threshold_
CUDA: threshold__cuda
structured_delegate: threshold.out

- func: threshold.out(Tensor self, Scalar threshold, Scalar value, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA: threshold_out

- func: threshold_backward.grad_input(Tensor grad_output, Tensor self, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!)
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU: threshold_out
CUDA: threshold_out_cuda
CPU, CUDA: threshold_backward_out

- func: threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor
variants: function
structured_delegate: threshold_backward.grad_input
dispatch:
CPU: threshold_backward
CUDA: threshold_backward_cuda
MkldnnCPU: mkldnn_relu_backward
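
Note (not part of the change): `structured_delegate: threshold.out` in the entries above means the functional and in-place variants no longer carry hand-written CPU/CUDA dispatch entries; they are generated on top of the structured .out kernel. Roughly — an illustration with made-up wrapper names, not the actual generated code, which allocates the output through the meta function rather than empty_like():

```cpp
#include <ATen/ATen.h>

// Illustrative wrappers only; my_threshold/my_threshold_ are not real ATen symbols.
at::Tensor my_threshold(const at::Tensor& self,
                        const at::Scalar& threshold,
                        const at::Scalar& value) {
  at::Tensor out = at::empty_like(self);                  // stand-in for the meta step
  return at::threshold_out(out, self, threshold, value);  // the structured out= kernel
}

at::Tensor& my_threshold_(at::Tensor& self,
                          const at::Scalar& threshold,
                          const at::Scalar& value) {
  return at::threshold_out(self, self, threshold, value); // in-place: out == self
}
```

This is also why the static-runtime call site in ops.cpp below changes from `at::native::threshold_out(in0_t, 0, 0, out_t)` to the generated `at::cpu::threshold_out(out_t, in0_t, 0, 0)`, whose out tensor comes first.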

- func: tile(Tensor self, int[] dims) -> Tensor
2 changes: 1 addition & 1 deletion torch/csrc/jit/runtime/static/ops.cpp
@@ -586,7 +586,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::relu, aten_relu, [](Node* n) -> SROperator {
auto& out_t = p_node->Output(0).toTensor();
if (!te->supports(in0_t)) {
fastResizeToZero(out_t);
at::native::threshold_out(in0_t, 0, 0, out_t);
at::cpu::threshold_out(out_t, in0_t, 0, 0);
} else {
at::native::resize_(out_t, in0_t.sizes(), c10::nullopt);
int64_t nn = in0_t.numel();
