Skip to content

Commit

Permalink
Merge pull request #7319 from asi1024/chx-abs
Browse files Browse the repository at this point in the history
Add `chainerx::Absolute` device implementation
  • Loading branch information
niboshi committed Aug 22, 2019
2 parents 6ee8337 + 0062d70 commit 03557d6
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 18 deletions.
2 changes: 1 addition & 1 deletion chainerx_cc/chainerx/cuda/cuda_device/misc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(SqrtKernel, { out = cuda::Sqrt

CHAINERX_CUDA_REGISTER_ELTWISE_DTYPE_UNARY_KERNEL(SquareKernel, { out = x * x; }, VisitNumericDtype);

// Abs is registered for every numeric dtype; cuda::Abs supplies the
// per-dtype device overloads (see cuda/numeric.cuh).
CHAINERX_CUDA_REGISTER_ELTWISE_DTYPE_UNARY_KERNEL(AbsKernel, { out = cuda::Abs(x); }, VisitNumericDtype);

CHAINERX_CUDA_REGISTER_ELTWISE_DTYPE_UNARY_KERNEL(SignKernel, { out = cuda::Sign(x); }, VisitNumericDtype);

Expand Down
10 changes: 9 additions & 1 deletion chainerx_cc/chainerx/cuda/numeric.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,15 @@ __device__ inline float Exp2(float x) { return std::exp2f(x); }

__device__ inline cuda::Float16 Exp2(cuda::Float16 x) { return cuda::Float16{std::exp2f(static_cast<float>(x))}; }

// Device-side elementwise absolute value, overloaded for each chainerx
// numeric dtype so that AbsKernel can dispatch on any of them.
// uint8_t is unsigned, so Abs is the identity.
__device__ inline uint8_t Abs(uint8_t x) { return x; }
// Narrow signed ints go through std::labs (argument promotes to long); the
// result is truncated back to the parameter width on return, which is exact
// for |x| of these types. int64_t uses std::llabs to avoid truncation.
__device__ inline int8_t Abs(int8_t x) { return std::labs(x); }
__device__ inline int16_t Abs(int16_t x) { return std::labs(x); }
__device__ inline int32_t Abs(int32_t x) { return std::labs(x); }
__device__ inline int64_t Abs(int64_t x) { return std::llabs(x); }
__device__ inline double Abs(double x) { return std::fabs(x); }
__device__ inline float Abs(float x) { return std::fabs(x); }
// Float16 has no fabs of its own here; round-trip through float.
__device__ inline cuda::Float16 Abs(cuda::Float16 x) { return static_cast<cuda::Float16>(std::fabs(static_cast<float>(x))); }

#define CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(name, func) \
template <typename T> \
__device__ inline T name(T x) { \
Expand All @@ -110,7 +119,6 @@ CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(Exp, std::exp)
CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(Log, std::log)
CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(Log10, std::log10)
CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(Sqrt, std::sqrt)
CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(Fabs, std::fabs)

namespace numeric_detail {

Expand Down
4 changes: 2 additions & 2 deletions chainerx_cc/chainerx/kernels/misc.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ class SquareKernel : public Kernel {
virtual void Call(const Array& x, const Array& out) = 0;
};

// Elementwise absolute-value kernel interface. Backends (native, cuda)
// register concrete implementations under this kernel's name.
class AbsKernel : public Kernel {
public:
    // Registry key used by CallKernel<AbsKernel> dispatch.
    static const char* name() { return "Abs"; }

    // Computes |x| elementwise into the preallocated array `out`.
    virtual void Call(const Array& x, const Array& out) = 0;
};
Expand Down
2 changes: 1 addition & 1 deletion chainerx_cc/chainerx/native/native_device/misc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(SqrtKernel, { out = chainerx

CHAINERX_NATIVE_REGISTER_ELTWISE_DTYPE_UNARY_KERNEL(SquareKernel, { out = x * x; }, VisitNumericDtype);

// Abs is registered for every numeric dtype; chainerx::Abs supplies the
// scalar implementation (see numeric.h).
CHAINERX_NATIVE_REGISTER_ELTWISE_DTYPE_UNARY_KERNEL(AbsKernel, { out = chainerx::Abs(x); }, VisitNumericDtype);

CHAINERX_NATIVE_REGISTER_ELTWISE_DTYPE_UNARY_KERNEL(SignKernel, { out = chainerx::Sign(x); }, VisitNumericDtype);

Expand Down
2 changes: 1 addition & 1 deletion chainerx_cc/chainerx/numeric.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Log10, std::log10)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Log2, std::log2)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Log1p, std::log1p)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Sqrt, std::sqrt)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Fabs, std::fabs)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Abs, std::abs)

namespace numeric_detail {

Expand Down
32 changes: 20 additions & 12 deletions chainerx_cc/chainerx/routines/misc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,23 +161,15 @@ Array Square(const Array& x) {
return out;
}

Array Absolute(const Array& x) {
Array x_flip_1 = IfGreaterElse(x, 0.0, 0.0, -x);
Array x_flip_2 = IfLessElse(x, 0.0, 0.0, x);

Array out = x_flip_1 + x_flip_2;
return out;
}
namespace {

Array Fabs(const Array& x) {
Dtype dtype = internal::GetMathResultDtype(x.dtype());
Array out = Empty(x.shape(), dtype, x.device());
void AbsoluteImpl(const Array& x, const Array& out) {
{
NoBackpropModeScope scope{};
x.device().backend().CallKernel<FabsKernel>(x, out);
x.device().backend().CallKernel<AbsKernel>(x, out);
}

BackwardBuilder bb{"fabs", x, out};
BackwardBuilder bb{"abs", x, out};
if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
bt.Define([inp_tok = bb.RetainInput(0)](BackwardContext& bctx) {
const Array& gout = *bctx.output_grad();
Expand All @@ -186,7 +178,23 @@ Array Fabs(const Array& x) {
});
}
bb.Finalize();
}

} // namespace

// Elementwise absolute value for numeric arrays.
//
// Throws DtypeError for boolean input; all other dtypes are delegated to
// AbsoluteImpl, writing into an output allocated with EmptyLike(x).
Array Absolute(const Array& x) {
    // Bool has no meaningful absolute value; reject it up front.
    if (x.dtype() == Dtype::kBool) {
        throw DtypeError{"Absolute does not support boolean array"};
    }
    Array result = EmptyLike(x);
    AbsoluteImpl(x, result);
    return result;
}

// Absolute value computed into the "math result" dtype of the input
// (as chosen by internal::GetMathResultDtype) rather than the input's own
// dtype; the elementwise work is shared with Absolute via AbsoluteImpl.
Array Fabs(const Array& x) {
    const Dtype result_dtype = internal::GetMathResultDtype(x.dtype());
    Array result = Empty(x.shape(), result_dtype, x.device());
    AbsoluteImpl(x, result);
    return result;
}

Expand Down

0 comments on commit 03557d6

Please sign in to comment.