diff --git a/chainerx_cc/chainerx/CMakeLists.txt b/chainerx_cc/chainerx/CMakeLists.txt index 2bde04fed8ca..b73fc779957f 100644 --- a/chainerx_cc/chainerx/CMakeLists.txt +++ b/chainerx_cc/chainerx/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(kernels) add_subdirectory(routines) add_subdirectory(native) add_subdirectory(testing) @@ -42,13 +43,13 @@ install(FILES index_iterator.h indexable_array.h indexer.h + kernel.h + kernel_registry.h macro.h numerical_gradient.h numeric.h numeric_limits.h - op.h op_node.h - op_registry.h optional_container_arg.h platform.h reduction_kernel_arg.h @@ -166,10 +167,10 @@ if(${CHAINERX_BUILD_TEST}) index_iterator_test.cc indexable_array_test.cc indexer_test.cc + kernel_registry_test.cc numeric_limits_test.cc numerical_gradient_test.cc numeric_test.cc - op_registry_test.cc optional_container_arg_test.cc scalar_test.cc shape_test.cc diff --git a/chainerx_cc/chainerx/array.cc b/chainerx_cc/chainerx/array.cc index c3024045939e..c74a00c33688 100644 --- a/chainerx_cc/chainerx/array.cc +++ b/chainerx_cc/chainerx/array.cc @@ -31,6 +31,7 @@ #include "chainerx/dtype.h" #include "chainerx/error.h" #include "chainerx/graph.h" +#include "chainerx/kernels/misc.h" #include "chainerx/macro.h" #include "chainerx/native/native_backend.h" #include "chainerx/op_node.h" @@ -40,7 +41,6 @@ #include "chainerx/routines/logic.h" #include "chainerx/routines/manipulation.h" #include "chainerx/routines/math.h" -#include "chainerx/routines/misc.h" #include "chainerx/routines/routines_util.h" #include "chainerx/routines/sorting.h" #include "chainerx/routines/statistics.h" @@ -313,7 +313,7 @@ Array Array::AsType(Dtype dtype, bool copy) const { } Array out = Empty(shape(), dtype, device()); - device().backend().CallOp(*this, out); + device().backend().CallKernel(*this, out); if (GetKind(dtype) == DtypeKind::kFloat) { BackwardBuilder bb{"astype", *this, out}; @@ -329,7 +329,7 @@ Array Array::AsType(Dtype dtype, bool copy) const { void Array::Fill(Scalar value) const { internal::CheckNoUnsafeInplace(*this, {}); - device().backend().CallOp(*this, value); + device().backend().CallKernel(*this, value); } const nonstd::optional& Array::GetGrad(const nonstd::optional& backprop_id) const { diff --git a/chainerx_cc/chainerx/backend.cc b/chainerx_cc/chainerx/backend.cc index db9d055c01c4..9b45c494db29 100644 --- a/chainerx_cc/chainerx/backend.cc +++ b/chainerx_cc/chainerx/backend.cc @@ -4,7 +4,7 @@ #include #include "chainerx/device.h" -#include "chainerx/op_registry.h" +#include "chainerx/kernel_registry.h" namespace chainerx { @@ -12,7 +12,7 @@ Backend::~Backend() = default; Backend::Backend(Context& context) : context_{context} {} -void Backend::Initialize() { op_registry_ = OpRegistry{&GetParentOpRegistry()}; } +void Backend::Initialize() { kernel_registry_ = KernelRegistry{&GetParentKernelRegistry()}; } Device& Backend::GetDevice(int index) { if (index < 0) { diff --git a/chainerx_cc/chainerx/backend.h b/chainerx_cc/chainerx/backend.h index 1a88c3a26876..5cc39f1bac33 100644 --- a/chainerx_cc/chainerx/backend.h +++ b/chainerx_cc/chainerx/backend.h @@ -6,8 +6,8 @@ #include #include -#include "chainerx/op.h" -#include "chainerx/op_registry.h" +#include "chainerx/kernel.h" +#include "chainerx/kernel_registry.h" namespace chainerx { @@ -40,7 +40,7 @@ class Backend { Context& context() const { return context_; } // Returns the op registry. - OpRegistry& op_registry() { return op_registry_; } + KernelRegistry& kernel_registry() { return kernel_registry_; } // Returns the device for the given index. 
// @@ -50,16 +50,16 @@ class Backend { // Queries if the backend supports data transfer between two devices. virtual bool SupportsTransfer(Device& src_device, Device& dst_device) = 0; - // Calls the op implementation. - template - auto CallOp(Args&&... args) { - Op& op = op_registry_.GetOp(); - return dynamic_cast(op).Call(std::forward(args)...); + // Calls the kernel implementation. + template + auto CallKernel(Args&&... args) { + Kernel& kernel = kernel_registry_.GetKernel(); + return dynamic_cast(kernel).Call(std::forward(args)...); } protected: - // Returns a backend-specific global op registry. - virtual OpRegistry& GetParentOpRegistry() = 0; + // Returns a backend-specific global kernel registry. + virtual KernelRegistry& GetParentKernelRegistry() = 0; private: // Creates a new device. @@ -72,7 +72,7 @@ class Backend { std::mutex devices_mutex_; - OpRegistry op_registry_; + KernelRegistry kernel_registry_; }; } // namespace chainerx diff --git a/chainerx_cc/chainerx/backend_testdata/backend0.cc b/chainerx_cc/chainerx/backend_testdata/backend0.cc index cbb51d1e29a7..7396ca3a4953 100644 --- a/chainerx_cc/chainerx/backend_testdata/backend0.cc +++ b/chainerx_cc/chainerx/backend_testdata/backend0.cc @@ -2,8 +2,8 @@ #include #include "chainerx/context.h" +#include "chainerx/kernel_registry.h" #include "chainerx/native/native_backend.h" -#include "chainerx/op_registry.h" namespace { @@ -14,9 +14,9 @@ class Backend0 : public chainerx::native::NativeBackend { std::string GetName() const override { return "backend0"; } protected: - chainerx::OpRegistry& GetParentOpRegistry() override { - static chainerx::OpRegistry op_registry{&chainerx::native::NativeBackend::GetGlobalOpRegistry()}; - return op_registry; + chainerx::KernelRegistry& GetParentKernelRegistry() override { + static chainerx::KernelRegistry kernel_registry{&chainerx::native::NativeBackend::GetGlobalKernelRegistry()}; + return kernel_registry; } }; diff --git a/chainerx_cc/chainerx/backend_testdata/backend1.cc b/chainerx_cc/chainerx/backend_testdata/backend1.cc index 2264213f0682..0c0d1cd0cd7a 100644 --- a/chainerx_cc/chainerx/backend_testdata/backend1.cc +++ b/chainerx_cc/chainerx/backend_testdata/backend1.cc @@ -15,10 +15,10 @@ class Backend1 : public chainerx::native::NativeBackend { std::string GetName() const override { return "backend1"; } protected: - chainerx::OpRegistry& GetParentOpRegistry() override { - static gsl::owner op_registry = - new chainerx::OpRegistry{&chainerx::native::NativeBackend::GetGlobalOpRegistry()}; - return *op_registry; + chainerx::KernelRegistry& GetParentKernelRegistry() override { + static gsl::owner kernel_registry = + new chainerx::KernelRegistry{&chainerx::native::NativeBackend::GetGlobalKernelRegistry()}; + return *kernel_registry; } }; diff --git a/chainerx_cc/chainerx/cuda/CMakeLists.txt b/chainerx_cc/chainerx/cuda/CMakeLists.txt index 28a36eefc202..4f4620065110 100644 --- a/chainerx_cc/chainerx/cuda/CMakeLists.txt +++ b/chainerx_cc/chainerx/cuda/CMakeLists.txt @@ -10,10 +10,10 @@ install(FILES data_type.cuh elementwise.cuh float16.cuh + kernel_regist.h memory_pool.h numeric.cuh numeric_limits.cuh - op_regist.h reduce.cuh DESTINATION include/chainerx/cuda ) diff --git a/chainerx_cc/chainerx/cuda/cuda_backend.h b/chainerx_cc/chainerx/cuda/cuda_backend.h index d798683dace3..b6e9e1561fb4 100644 --- a/chainerx_cc/chainerx/cuda/cuda_backend.h +++ b/chainerx_cc/chainerx/cuda/cuda_backend.h @@ -9,7 +9,7 @@ #include "chainerx/backend.h" #include "chainerx/context.h" #include "chainerx/device.h" 
-#include "chainerx/op_registry.h" +#include "chainerx/kernel_registry.h" namespace chainerx { namespace cuda { @@ -50,13 +50,13 @@ class CudaBackend : public Backend { // Gets maximum cuDNN workspace size. size_t GetCudnnMaxWorkspaceSize(); - static OpRegistry& GetGlobalOpRegistry() { - static OpRegistry* global_op_registry = new OpRegistry{}; - return *global_op_registry; + static KernelRegistry& GetGlobalKernelRegistry() { + static KernelRegistry* global_kernel_registry = new KernelRegistry{}; + return *global_kernel_registry; } protected: - OpRegistry& GetParentOpRegistry() override { return GetGlobalOpRegistry(); } + KernelRegistry& GetParentKernelRegistry() override { return GetGlobalKernelRegistry(); } private: std::unique_ptr CreateDevice(int index) override; diff --git a/chainerx_cc/chainerx/cuda/cuda_conv.cc b/chainerx_cc/chainerx/cuda/cuda_conv.cc index e85c43f8712c..5ae2d422e1a7 100644 --- a/chainerx_cc/chainerx/cuda/cuda_conv.cc +++ b/chainerx_cc/chainerx/cuda/cuda_conv.cc @@ -20,6 +20,7 @@ #include "chainerx/dtype.h" #include "chainerx/error.h" #include "chainerx/hash_combine.h" +#include "chainerx/kernels/connection.h" #include "chainerx/macro.h" #include "chainerx/routines/connection.h" #include "chainerx/routines/creation.h" diff --git a/chainerx_cc/chainerx/cuda/cuda_conv_test.cc b/chainerx_cc/chainerx/cuda/cuda_conv_test.cc index 9577e00da0b8..47285927ea4c 100644 --- a/chainerx_cc/chainerx/cuda/cuda_conv_test.cc +++ b/chainerx_cc/chainerx/cuda/cuda_conv_test.cc @@ -9,6 +9,7 @@ #include "chainerx/constant.h" #include "chainerx/cuda/cuda_device.h" #include "chainerx/device_id.h" +#include "chainerx/kernels/connection.h" #include "chainerx/routines/connection.h" #include "chainerx/shape.h" #include "chainerx/stack_vector.h" @@ -155,9 +156,9 @@ TEST(CudaConvTest, BwdFilterAlgoCache) { Array gy = testing::BuildArray(out_shape).WithLinearData(-0.3f, 0.1f).WithPadding(1); EXPECT_EQ(size_t{0}, cuda_internal::CudaConvTest::GetBwdFilterAlgoCacheMapSize(cuda_conv)); - device.backend().CallOp(w_dtype, w_shape, x, gy, stride, pad, cover_all, nonstd::nullopt); + device.backend().CallKernel(w_dtype, w_shape, x, gy, stride, pad, cover_all, nonstd::nullopt); EXPECT_EQ(size_t{1}, cuda_internal::CudaConvTest::GetBwdFilterAlgoCacheMapSize(cuda_conv)); - device.backend().CallOp(w_dtype, w_shape, x, gy, stride, pad, cover_all, nonstd::nullopt); + device.backend().CallKernel(w_dtype, w_shape, x, gy, stride, pad, cover_all, nonstd::nullopt); EXPECT_EQ(size_t{1}, cuda_internal::CudaConvTest::GetBwdFilterAlgoCacheMapSize(cuda_conv)); } { @@ -171,9 +172,9 @@ TEST(CudaConvTest, BwdFilterAlgoCache) { Array gy = testing::BuildArray(out_shape).WithLinearData(-0.3f, 0.1f).WithPadding(1); EXPECT_EQ(size_t{1}, cuda_internal::CudaConvTest::GetBwdFilterAlgoCacheMapSize(cuda_conv)); - device.backend().CallOp(w_dtype, w_shape, x, gy, stride, pad, cover_all, nonstd::nullopt); + device.backend().CallKernel(w_dtype, w_shape, x, gy, stride, pad, cover_all, nonstd::nullopt); EXPECT_EQ(size_t{2}, cuda_internal::CudaConvTest::GetBwdFilterAlgoCacheMapSize(cuda_conv)); - device.backend().CallOp(w_dtype, w_shape, x, gy, stride, pad, cover_all, nonstd::nullopt); + device.backend().CallKernel(w_dtype, w_shape, x, gy, stride, pad, cover_all, nonstd::nullopt); EXPECT_EQ(size_t{2}, cuda_internal::CudaConvTest::GetBwdFilterAlgoCacheMapSize(cuda_conv)); } } diff --git a/chainerx_cc/chainerx/cuda/cuda_device.h b/chainerx_cc/chainerx/cuda/cuda_device.h index ea7244e62bb2..cf0fe659152e 100644 --- 
a/chainerx_cc/chainerx/cuda/cuda_device.h +++ b/chainerx_cc/chainerx/cuda/cuda_device.h @@ -21,6 +21,8 @@ #include "chainerx/cuda/memory_pool.h" #include "chainerx/device.h" #include "chainerx/dtype.h" +#include "chainerx/kernels/normalization.h" +#include "chainerx/kernels/pooling.h" #include "chainerx/routines/normalization.h" #include "chainerx/routines/pooling.h" #include "chainerx/scalar.h" diff --git a/chainerx_cc/chainerx/cuda/cuda_device/activation.cu b/chainerx_cc/chainerx/cuda/cuda_device/activation.cu index 0234583b4e7a..1db2076919d1 100644 --- a/chainerx_cc/chainerx/cuda/cuda_device/activation.cu +++ b/chainerx_cc/chainerx/cuda/cuda_device/activation.cu @@ -9,10 +9,11 @@ #include "chainerx/cuda/cuda_runtime.h" #include "chainerx/cuda/cuda_set_device_scope.h" #include "chainerx/cuda/elementwise.cuh" +#include "chainerx/cuda/kernel_regist.h" #include "chainerx/cuda/numeric.cuh" -#include "chainerx/cuda/op_regist.h" #include "chainerx/device.h" #include "chainerx/dtype.h" +#include "chainerx/kernels/math.h" #include "chainerx/numeric.h" #include "chainerx/routines/math.h" #include "chainerx/routines/type_util.h" @@ -31,7 +32,7 @@ struct IfLessElseASSAImpl { OutCudaType pos; }; -class CudaIfLessElseASSAOp : public IfLessElseASSAOp { +class CudaIfLessElseASSAKernel : public IfLessElseASSAKernel { public: void Call(const Array& x1, Scalar x2, Scalar pos, const Array& neg, const Array& out) override { Device& device = x1.device(); @@ -53,7 +54,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(IfLessElseASSAOp, CudaIfLessElseASSAOp); +CHAINERX_CUDA_REGISTER_KERNEL(IfLessElseASSAKernel, CudaIfLessElseASSAKernel); template struct IfGreaterElseASSAImpl { @@ -64,7 +65,7 @@ struct IfGreaterElseASSAImpl { OutCudaType pos; }; -class CudaIfGreaterElseASSAOp : public IfGreaterElseASSAOp { +class CudaIfGreaterElseASSAKernel : public IfGreaterElseASSAKernel { public: void Call(const Array& x1, Scalar x2, Scalar pos, const Array& neg, const Array& out) override { Device& device = x1.device(); @@ -86,7 +87,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(IfGreaterElseASSAOp, CudaIfGreaterElseASSAOp); +CHAINERX_CUDA_REGISTER_KERNEL(IfGreaterElseASSAKernel, CudaIfGreaterElseASSAKernel); template struct IfGreaterElseAAAAImpl { @@ -97,7 +98,7 @@ struct IfGreaterElseAAAAImpl { } }; -class CudaIfGreaterElseAAAAOp : public IfGreaterElseAAAAOp { +class CudaIfGreaterElseAAAAKernel : public IfGreaterElseAAAAKernel { public: void Call(const Array& x1, const Array& x2, const Array& pos, const Array& neg, const Array& out) override { Device& device = x1.device(); @@ -119,7 +120,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(IfGreaterElseAAAAOp, CudaIfGreaterElseAAAAOp); +CHAINERX_CUDA_REGISTER_KERNEL(IfGreaterElseAAAAKernel, CudaIfGreaterElseAAAAKernel); template struct TanhImpl { @@ -127,7 +128,7 @@ struct TanhImpl { __device__ void operator()(int64_t /*i*/, CudaType x, CudaType& out) { out = cuda::Tanh(x); } }; -class CudaTanhOp : public TanhOp { +class CudaTanhKernel : public TanhKernel { public: void Call(const Array& x, const Array& out) override { Device& device = x.device(); @@ -141,7 +142,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(TanhOp, CudaTanhOp); +CHAINERX_CUDA_REGISTER_KERNEL(TanhKernel, CudaTanhKernel); } // namespace } // namespace cuda diff --git a/chainerx_cc/chainerx/cuda/cuda_device/arithmetic.cu b/chainerx_cc/chainerx/cuda/cuda_device/arithmetic.cu index ccc378cbd16a..f0bea40cd8c1 100644 --- a/chainerx_cc/chainerx/cuda/cuda_device/arithmetic.cu +++ 
b/chainerx_cc/chainerx/cuda/cuda_device/arithmetic.cu @@ -10,9 +10,10 @@ #include "chainerx/cuda/cuda_set_device_scope.h" #include "chainerx/cuda/elementwise.cuh" #include "chainerx/cuda/float16.cuh" -#include "chainerx/cuda/op_regist.h" +#include "chainerx/cuda/kernel_regist.h" #include "chainerx/device.h" #include "chainerx/dtype.h" +#include "chainerx/kernels/math.h" #include "chainerx/routines/math.h" #include "chainerx/scalar.h" @@ -20,7 +21,7 @@ namespace chainerx { namespace cuda { namespace { -CHAINERX_CUDA_REGISTER_ELTWISE_BINARY_OP(AddOp, { out = ArithmeticOps::Add(x1, x2); }); +CHAINERX_CUDA_REGISTER_ELTWISE_BINARY_KERNEL(AddKernel, { out = ArithmeticOps::Add(x1, x2); }); template struct AddASImpl { @@ -29,7 +30,7 @@ struct AddASImpl { CudaType x2; }; -class CudaAddASOp : public AddASOp { +class CudaAddASKernel : public AddASKernel { public: void Call(const Array& x1, Scalar x2, const Array& out) override { Device& device = x1.device(); @@ -44,9 +45,9 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(AddASOp, CudaAddASOp); +CHAINERX_CUDA_REGISTER_KERNEL(AddASKernel, CudaAddASKernel); -CHAINERX_CUDA_REGISTER_ELTWISE_DTYPE_BINARY_OP(SubtractOp, { out = ArithmeticOps::Subtract(x1, x2); }, VisitNumericDtype); +CHAINERX_CUDA_REGISTER_ELTWISE_DTYPE_BINARY_KERNEL(SubtractKernel, { out = ArithmeticOps::Subtract(x1, x2); }, VisitNumericDtype); template struct SubtractASImpl { @@ -55,7 +56,7 @@ struct SubtractASImpl { CudaType x2; }; -class CudaSubtractASOp : public SubtractASOp { +class CudaSubtractASKernel : public SubtractASKernel { public: void Call(const Array& x1, Scalar x2, const Array& out) override { Device& device = x1.device(); @@ -70,10 +71,10 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(SubtractASOp, CudaSubtractASOp); +CHAINERX_CUDA_REGISTER_KERNEL(SubtractASKernel, CudaSubtractASKernel); // TODO(sonots): support stream -CHAINERX_CUDA_REGISTER_ELTWISE_BINARY_OP(MultiplyOp, { out = ArithmeticOps::Multiply(x1, x2); }); +CHAINERX_CUDA_REGISTER_ELTWISE_BINARY_KERNEL(MultiplyKernel, { out = ArithmeticOps::Multiply(x1, x2); }); template struct MultiplyASImpl { @@ -82,7 +83,7 @@ struct MultiplyASImpl { CudaType x2; }; -class CudaMultiplyASOp : public MultiplyASOp { +class CudaMultiplyASKernel : public MultiplyASKernel { public: void Call(const Array& x1, Scalar x2, const Array& out) override { Device& device = x1.device(); @@ -97,7 +98,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(MultiplyASOp, CudaMultiplyASOp); +CHAINERX_CUDA_REGISTER_KERNEL(MultiplyASKernel, CudaMultiplyASKernel); // CUDA does not have std::div. __device__ int8_t FloorDivide(int8_t x, int8_t y) { return x / y - ((y >= 0 ? x % y : -(x % y)) < 0 ? 
1 : 0); } @@ -117,7 +118,7 @@ __device__ cuda::Float16 FloorDivide(cuda::Float16 x, cuda::Float16 y) { return cuda::Float16{FloorDivide(static_cast(x), static_cast(y))}; } -CHAINERX_CUDA_REGISTER_ELTWISE_DTYPE_BINARY_OP(FloorDivideOp, { out = cuda::FloorDivide(x1, x2); }, VisitNumericDtype); +CHAINERX_CUDA_REGISTER_ELTWISE_DTYPE_BINARY_KERNEL(FloorDivideKernel, { out = cuda::FloorDivide(x1, x2); }, VisitNumericDtype); template struct FloorDivideASImpl { @@ -126,7 +127,7 @@ struct FloorDivideASImpl { CudaType x2; }; -class CudaFloorDivideASOp : public FloorDivideASOp { +class CudaFloorDivideASKernel : public FloorDivideASKernel { public: void Call(const Array& x1, Scalar x2, const Array& out) override { Device& device = x1.device(); @@ -141,9 +142,9 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(FloorDivideASOp, CudaFloorDivideASOp); +CHAINERX_CUDA_REGISTER_KERNEL(FloorDivideASKernel, CudaFloorDivideASKernel); -CHAINERX_CUDA_REGISTER_ELTWISE_BINARY_OP(DivideOp, { out = ArithmeticOps::Divide(x1, x2); }); +CHAINERX_CUDA_REGISTER_ELTWISE_BINARY_KERNEL(DivideKernel, { out = ArithmeticOps::Divide(x1, x2); }); template struct DivideASImpl { @@ -152,7 +153,7 @@ struct DivideASImpl { CudaType x2; }; -class CudaDivideASOp : public DivideASOp { +class CudaDivideASKernel : public DivideASKernel { public: void Call(const Array& x1, Scalar x2, const Array& out) override { Device& device = x1.device(); @@ -167,7 +168,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(DivideASOp, CudaDivideASOp); +CHAINERX_CUDA_REGISTER_KERNEL(DivideASKernel, CudaDivideASKernel); } // namespace } // namespace cuda diff --git a/chainerx_cc/chainerx/cuda/cuda_device/batch_norm.cc b/chainerx_cc/chainerx/cuda/cuda_device/batch_norm.cc index 8b9be1dabb24..98bcdf40afe8 100644 --- a/chainerx_cc/chainerx/cuda/cuda_device/batch_norm.cc +++ b/chainerx_cc/chainerx/cuda/cuda_device/batch_norm.cc @@ -14,10 +14,11 @@ #include "chainerx/backend_util.h" #include "chainerx/cuda/cuda_set_device_scope.h" #include "chainerx/cuda/cudnn.h" -#include "chainerx/cuda/op_regist.h" +#include "chainerx/cuda/kernel_regist.h" #include "chainerx/device.h" #include "chainerx/dtype.h" #include "chainerx/error.h" +#include "chainerx/kernels/normalization.h" #include "chainerx/macro.h" #include "chainerx/routines/creation.h" #include "chainerx/routines/normalization.h" @@ -70,7 +71,7 @@ cuda_internal::CudnnTensorDescriptor DeriveBatchNormTensorDescriptor( return derive_desc; } -class CudaBatchNormOp : public BatchNormOp { +class CudaBatchNormKernel : public BatchNormKernel { public: std::tuple> Call( const Array& x, @@ -183,9 +184,9 @@ class CudaBatchNormOp : public BatchNormOp { } }; -CHAINERX_CUDA_REGISTER_OP(BatchNormOp, CudaBatchNormOp); +CHAINERX_CUDA_REGISTER_KERNEL(BatchNormKernel, CudaBatchNormKernel); -class CudaBatchNormGradOp : public BatchNormGradOp { +class CudaBatchNormGradKernel : public BatchNormGradKernel { public: std::tuple Call( const Array& x, @@ -294,9 +295,9 @@ class CudaBatchNormGradOp : public BatchNormGradOp { } }; -CHAINERX_CUDA_REGISTER_OP(BatchNormGradOp, CudaBatchNormGradOp); +CHAINERX_CUDA_REGISTER_KERNEL(BatchNormGradKernel, CudaBatchNormGradKernel); -class CudaFixedBatchNormOp : public FixedBatchNormOp { +class CudaFixedBatchNormKernel : public FixedBatchNormKernel { public: Array Call( const Array& x, @@ -376,7 +377,7 @@ class CudaFixedBatchNormOp : public FixedBatchNormOp { } }; -CHAINERX_CUDA_REGISTER_OP(FixedBatchNormOp, CudaFixedBatchNormOp); +CHAINERX_CUDA_REGISTER_KERNEL(FixedBatchNormKernel, CudaFixedBatchNormKernel); } // 
namespace } // namespace cuda diff --git a/chainerx_cc/chainerx/cuda/cuda_device/comparison.cu b/chainerx_cc/chainerx/cuda/cuda_device/comparison.cu index 6aec59f3695e..97db6bbf0ced 100644 --- a/chainerx_cc/chainerx/cuda/cuda_device/comparison.cu +++ b/chainerx_cc/chainerx/cuda/cuda_device/comparison.cu @@ -9,10 +9,11 @@ #include "chainerx/cuda/cuda_set_device_scope.h" #include "chainerx/cuda/data_type.cuh" #include "chainerx/cuda/elementwise.cuh" -#include "chainerx/cuda/op_regist.h" +#include "chainerx/cuda/kernel_regist.h" #include "chainerx/cuda/reduce.cuh" #include "chainerx/device.h" #include "chainerx/dtype.h" +#include "chainerx/kernels/logic.h" #include "chainerx/routines/logic.h" namespace chainerx { @@ -25,7 +26,7 @@ struct EqualImpl { __device__ void operator()(int64_t /*i*/, CudaType x1, CudaType x2, bool& out) { out = x1 == x2; } }; -class CudaEqualOp : public EqualOp { +class CudaEqualKernel : public EqualKernel { public: void Call(const Array& x1, const Array& x2, const Array& out) override { Device& device = x1.device(); @@ -41,7 +42,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(EqualOp, CudaEqualOp); +CHAINERX_CUDA_REGISTER_KERNEL(EqualKernel, CudaEqualKernel); template struct NotEqualImpl { @@ -49,7 +50,7 @@ struct NotEqualImpl { __device__ void operator()(int64_t /*i*/, CudaType x1, CudaType x2, bool& out) { out = x1 != x2; } }; -class CudaNotEqualOp : public NotEqualOp { +class CudaNotEqualKernel : public NotEqualKernel { public: void Call(const Array& x1, const Array& x2, const Array& out) override { Device& device = x1.device(); @@ -65,7 +66,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(NotEqualOp, CudaNotEqualOp); +CHAINERX_CUDA_REGISTER_KERNEL(NotEqualKernel, CudaNotEqualKernel); template struct GreaterImpl { @@ -73,7 +74,7 @@ struct GreaterImpl { __device__ void operator()(int64_t /*i*/, CudaType x1, CudaType x2, bool& out) { out = x1 > x2; } }; -class CudaGreaterOp : public GreaterOp { +class CudaGreaterKernel : public GreaterKernel { public: void Call(const Array& x1, const Array& x2, const Array& out) override { Device& device = x1.device(); @@ -89,7 +90,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(GreaterOp, CudaGreaterOp); +CHAINERX_CUDA_REGISTER_KERNEL(GreaterKernel, CudaGreaterKernel); template struct GreaterEqualImpl { @@ -97,7 +98,7 @@ struct GreaterEqualImpl { __device__ void operator()(int64_t /*i*/, CudaType x1, CudaType x2, bool& out) { out = x1 >= x2; } }; -class CudaGreaterEqualOp : public GreaterEqualOp { +class CudaGreaterEqualKernel : public GreaterEqualKernel { public: void Call(const Array& x1, const Array& x2, const Array& out) override { Device& device = x1.device(); @@ -113,7 +114,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(GreaterEqualOp, CudaGreaterEqualOp); +CHAINERX_CUDA_REGISTER_KERNEL(GreaterEqualKernel, CudaGreaterEqualKernel); template struct LogicalNotImpl { @@ -121,7 +122,7 @@ struct LogicalNotImpl { __device__ void operator()(int64_t /*i*/, CudaType x, bool& out) { out = !x; } }; -class CudaLogicalNotOp : public LogicalNotOp { +class CudaLogicalNotKernel : public LogicalNotKernel { public: void Call(const Array& x, const Array& out) override { Device& device = x.device(); @@ -134,7 +135,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(LogicalNotOp, CudaLogicalNotOp); +CHAINERX_CUDA_REGISTER_KERNEL(LogicalNotKernel, CudaLogicalNotKernel); template struct LogicalAndImpl { @@ -142,7 +143,7 @@ struct LogicalAndImpl { __device__ void operator()(int64_t /*i*/, CudaType x1, CudaType x2, bool& out) { out = x1 && x2; } }; -class 
CudaLogicalAndOp : public LogicalAndOp { +class CudaLogicalAndKernel : public LogicalAndKernel { public: void Call(const Array& x1, const Array& x2, const Array& out) override { Device& device = x1.device(); @@ -158,7 +159,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(LogicalAndOp, CudaLogicalAndOp); +CHAINERX_CUDA_REGISTER_KERNEL(LogicalAndKernel, CudaLogicalAndKernel); template struct LogicalOrImpl { @@ -166,7 +167,7 @@ struct LogicalOrImpl { __device__ void operator()(int64_t /*i*/, CudaType x1, CudaType x2, bool& out) { out = x1 || x2; } }; -class CudaLogicalOrOp : public LogicalOrOp { +class CudaLogicalOrKernel : public LogicalOrKernel { public: void Call(const Array& x1, const Array& x2, const Array& out) override { Device& device = x1.device(); @@ -182,7 +183,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(LogicalOrOp, CudaLogicalOrOp); +CHAINERX_CUDA_REGISTER_KERNEL(LogicalOrKernel, CudaLogicalOrKernel); template struct AllImpl { @@ -193,7 +194,7 @@ struct AllImpl { __device__ bool MapOut(bool accum) { return accum; } }; -class CudaAllOp : public AllOp { +class CudaAllKernel : public AllKernel { public: void Call(const Array& a, const Axes& axis, const Array& out) { CHAINERX_ASSERT(internal::IsValidReductionShape(a.shape(), axis, out.shape(), true)); @@ -210,7 +211,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(AllOp, CudaAllOp); +CHAINERX_CUDA_REGISTER_KERNEL(AllKernel, CudaAllKernel); template struct AnyImpl { @@ -221,7 +222,7 @@ struct AnyImpl { __device__ bool MapOut(bool accum) { return accum; } }; -class CudaAnyOp : public AnyOp { +class CudaAnyKernel : public AnyKernel { public: void Call(const Array& a, const Axes& axis, const Array& out) { CHAINERX_ASSERT(internal::IsValidReductionShape(a.shape(), axis, out.shape(), true)); @@ -238,7 +239,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(AnyOp, CudaAnyOp); +CHAINERX_CUDA_REGISTER_KERNEL(AnyKernel, CudaAnyKernel); } // namespace } // namespace cuda diff --git a/chainerx_cc/chainerx/cuda/cuda_device/conv.cc b/chainerx_cc/chainerx/cuda/cuda_device/conv.cc index c2aea073dec2..eafa687a369e 100644 --- a/chainerx_cc/chainerx/cuda/cuda_device/conv.cc +++ b/chainerx_cc/chainerx/cuda/cuda_device/conv.cc @@ -5,11 +5,11 @@ #include "chainerx/array.h" #include "chainerx/constant.h" #include "chainerx/cuda/cuda_conv.h" -#include "chainerx/cuda/op_regist.h" +#include "chainerx/cuda/kernel_regist.h" #include "chainerx/dtype.h" #include "chainerx/error.h" -#include "chainerx/native/op_regist.h" -#include "chainerx/routines/connection.h" +#include "chainerx/kernels/connection.h" +#include "chainerx/native/kernel_regist.h" #include "chainerx/shape.h" #include "chainerx/stack_vector.h" @@ -17,7 +17,7 @@ namespace chainerx { namespace cuda { namespace { -class CudaConvOp : public ConvOp { +class CudaConvKernel : public ConvKernel { public: Array Call( const Array& x, @@ -39,9 +39,9 @@ class CudaConvOp : public ConvOp { } }; -CHAINERX_CUDA_REGISTER_OP(ConvOp, CudaConvOp); +CHAINERX_CUDA_REGISTER_KERNEL(ConvKernel, CudaConvKernel); -class CudaConvTransposeOp : public ConvTransposeOp { +class CudaConvTransposeKernel : public ConvTransposeKernel { public: Array Call( const Array& x, @@ -62,9 +62,9 @@ class CudaConvTransposeOp : public ConvTransposeOp { } }; -CHAINERX_CUDA_REGISTER_OP(ConvTransposeOp, CudaConvTransposeOp); +CHAINERX_CUDA_REGISTER_KERNEL(ConvTransposeKernel, CudaConvTransposeKernel); -class CudaConvGradWeightOp : public ConvGradWeightOp { +class CudaConvGradWeightKernel : public ConvGradWeightKernel { public: Array Call( Dtype w_dtype, @@ 
-85,7 +85,7 @@ class CudaConvGradWeightOp : public ConvGradWeightOp { } }; -CHAINERX_CUDA_REGISTER_OP(ConvGradWeightOp, CudaConvGradWeightOp); +CHAINERX_CUDA_REGISTER_KERNEL(ConvGradWeightKernel, CudaConvGradWeightKernel); } // namespace } // namespace cuda diff --git a/chainerx_cc/chainerx/cuda/cuda_device/copy.cu b/chainerx_cc/chainerx/cuda/cuda_device/copy.cu index 8da44b2af3a9..fa02dacffc33 100644 --- a/chainerx_cc/chainerx/cuda/cuda_device/copy.cu +++ b/chainerx_cc/chainerx/cuda/cuda_device/copy.cu @@ -8,17 +8,18 @@ #include "chainerx/cuda/cuda_runtime.h" #include "chainerx/cuda/cuda_set_device_scope.h" #include "chainerx/cuda/elementwise.cuh" -#include "chainerx/cuda/op_regist.h" +#include "chainerx/cuda/kernel_regist.h" #include "chainerx/device.h" #include "chainerx/dtype.h" +#include "chainerx/kernels/creation.h" +#include "chainerx/kernels/misc.h" #include "chainerx/routines/creation.h" -#include "chainerx/routines/misc.h" namespace chainerx { namespace cuda { namespace { -CHAINERX_CUDA_REGISTER_ELTWISE_UNARY_OP(CopyOp, { out = x; }); +CHAINERX_CUDA_REGISTER_ELTWISE_UNARY_KERNEL(CopyKernel, { out = x; }); template struct AsTypeImpl { @@ -27,7 +28,7 @@ struct AsTypeImpl { __device__ void operator()(int64_t /*i*/, InCudaType a, OutCudaType& out) { out = static_cast(a); } }; -class CudaAsTypeOp : public AsTypeOp { +class CudaAsTypeKernel : public AsTypeKernel { public: void Call(const Array& a, const Array& out) override { Device& device = a.device(); @@ -42,7 +43,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(AsTypeOp, CudaAsTypeOp); +CHAINERX_CUDA_REGISTER_KERNEL(AsTypeKernel, CudaAsTypeKernel); } // namespace } // namespace cuda diff --git a/chainerx_cc/chainerx/cuda/cuda_device/dot.cu b/chainerx_cc/chainerx/cuda/cuda_device/dot.cu index ab423ee2e280..0b880408d2df 100644 --- a/chainerx_cc/chainerx/cuda/cuda_device/dot.cu +++ b/chainerx_cc/chainerx/cuda/cuda_device/dot.cu @@ -17,16 +17,18 @@ #include "chainerx/cuda/cuda_set_device_scope.h" #include "chainerx/cuda/data_type.cuh" #include "chainerx/cuda/float16.cuh" -#include "chainerx/cuda/op_regist.h" +#include "chainerx/cuda/kernel_regist.h" #include "chainerx/device.h" #include "chainerx/dtype.h" #include "chainerx/error.h" #include "chainerx/float16.h" +#include "chainerx/kernels/creation.h" +#include "chainerx/kernels/linalg.h" +#include "chainerx/kernels/math.h" +#include "chainerx/kernels/misc.h" #include "chainerx/macro.h" #include "chainerx/routines/creation.h" -#include "chainerx/routines/linalg.h" #include "chainerx/routines/math.h" -#include "chainerx/routines/misc.h" namespace chainerx { namespace cuda { @@ -83,7 +85,7 @@ struct GemmInputLayout { } // namespace -class CudaDotOp : public DotOp { +class CudaDotKernel : public DotKernel { public: void Call(const Array& a, const Array& b, const Array& out) override { Device& device = a.device(); @@ -110,15 +112,15 @@ public: // TODO(hvy): Avoid unnecessary cast here when multiplication supports mixed dtypes. const Array& a_cast = a.dtype() == out.dtype() ? a : a.AsType(out.dtype()); const Array& b_cast = b.dtype() == out.dtype() ? 
b : b.AsType(out.dtype()); - device.backend().CallOp(a_cast.Reshape({k}) * b_cast.Reshape({k}), Axes{0}, out.Reshape({})); + device.backend().CallKernel(a_cast.Reshape({k}) * b_cast.Reshape({k}), Axes{0}, out.Reshape({})); return; } if (out.dtype() == Dtype::kFloat16) { // TODO(imanishi): Use cublasHgemm Array out_float32 = Empty(out.shape(), Dtype::kFloat32, device); - device.backend().CallOp(a.AsType(Dtype::kFloat32), b.AsType(Dtype::kFloat32), out_float32); - device.backend().CallOp(out_float32, out); + device.backend().CallKernel(a.AsType(Dtype::kFloat32), b.AsType(Dtype::kFloat32), out_float32); + device.backend().CallKernel(out_float32, out); return; } @@ -184,12 +186,12 @@ public: } if (!is_out_contiguous) { - device.backend().CallOp(out_contiguous, out); + device.backend().CallKernel(out_contiguous, out); } } }; -CHAINERX_CUDA_REGISTER_OP(DotOp, CudaDotOp); +CHAINERX_CUDA_REGISTER_KERNEL(DotKernel, CudaDotKernel); } // namespace cuda } // namespace chainerx diff --git a/chainerx_cc/chainerx/cuda/cuda_device/exp_log.cu b/chainerx_cc/chainerx/cuda/cuda_device/exp_log.cu index 5d97637f3571..8f417a3dcd4e 100644 --- a/chainerx_cc/chainerx/cuda/cuda_device/exp_log.cu +++ b/chainerx_cc/chainerx/cuda/cuda_device/exp_log.cu @@ -9,11 +9,11 @@ #include "chainerx/cuda/cuda_runtime.h" #include "chainerx/cuda/cuda_set_device_scope.h" #include "chainerx/cuda/elementwise.cuh" +#include "chainerx/cuda/kernel_regist.h" #include "chainerx/cuda/numeric.cuh" -#include "chainerx/cuda/op_regist.h" #include "chainerx/device.h" #include "chainerx/dtype.h" -#include "chainerx/routines/math.h" +#include "chainerx/kernels/math.h" namespace chainerx { namespace cuda { @@ -25,7 +25,7 @@ struct ExpImpl { __device__ void operator()(int64_t /*i*/, CudaType x, CudaType& out) { out = cuda::Exp(x); } }; -class CudaExpOp : public ExpOp { +class CudaExpKernel : public ExpKernel { public: void Call(const Array& x, const Array& out) override { Device& device = x.device(); @@ -39,7 +39,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(ExpOp, CudaExpOp); +CHAINERX_CUDA_REGISTER_KERNEL(ExpKernel, CudaExpKernel); template struct LogImpl { @@ -47,7 +47,7 @@ struct LogImpl { __device__ void operator()(int64_t /*i*/, CudaType x, CudaType& out) { out = cuda::Log(x); } }; -class CudaLogOp : public LogOp { +class CudaLogKernel : public LogKernel { public: void Call(const Array& x, const Array& out) override { Device& device = x.device(); @@ -61,7 +61,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(LogOp, CudaLogOp); +CHAINERX_CUDA_REGISTER_KERNEL(LogKernel, CudaLogKernel); } // namespace } // namespace cuda diff --git a/chainerx_cc/chainerx/cuda/cuda_device/fill.cu b/chainerx_cc/chainerx/cuda/cuda_device/fill.cu index cbdbe2d0b66c..d0aaf32704e9 100644 --- a/chainerx_cc/chainerx/cuda/cuda_device/fill.cu +++ b/chainerx_cc/chainerx/cuda/cuda_device/fill.cu @@ -12,14 +12,15 @@ #include "chainerx/cuda/cuda_runtime.h" #include "chainerx/cuda/cuda_set_device_scope.h" #include "chainerx/cuda/elementwise.cuh" -#include "chainerx/cuda/op_regist.h" +#include "chainerx/cuda/kernel_regist.h" #include "chainerx/device.h" #include "chainerx/dtype.h" #include "chainerx/indexable_array.h" #include "chainerx/indexer.h" +#include "chainerx/kernels/creation.h" +#include "chainerx/kernels/misc.h" #include "chainerx/macro.h" #include "chainerx/routines/creation.h" -#include "chainerx/routines/misc.h" #include "chainerx/scalar.h" #include "chainerx/shape.h" @@ -35,7 +36,7 @@ struct ArangeImpl { CudaType step; }; -class CudaArangeOp : public ArangeOp { +class 
CudaArangeKernel : public ArangeKernel { public: void Call(Scalar start, Scalar step, const Array& out) override { Device& device = out.device(); @@ -48,7 +49,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(ArangeOp, CudaArangeOp); +CHAINERX_CUDA_REGISTER_KERNEL(ArangeKernel, CudaArangeKernel); template struct IdentityImpl { @@ -58,7 +59,7 @@ struct IdentityImpl { int64_t n_plus_one; }; -class CudaIdentityOp : public IdentityOp { +class CudaIdentityKernel : public IdentityKernel { public: void Call(const Array& out) override { CHAINERX_ASSERT(out.ndim() == 2); @@ -73,7 +74,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(IdentityOp, CudaIdentityOp); +CHAINERX_CUDA_REGISTER_KERNEL(IdentityKernel, CudaIdentityKernel); template struct EyeImpl { @@ -87,7 +88,7 @@ struct EyeImpl { int64_t step; }; -class CudaEyeOp : public EyeOp { +class CudaEyeKernel : public EyeKernel { public: void Call(int64_t k, const Array& out) override { Device& device = out.device(); @@ -99,7 +100,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(EyeOp, CudaEyeOp); +CHAINERX_CUDA_REGISTER_KERNEL(EyeKernel, CudaEyeKernel); template __global__ void SetVecInMat( @@ -117,7 +118,7 @@ __global__ void SetVecInMat( } } -class CudaDiagflatOp : public DiagflatOp { +class CudaDiagflatKernel : public DiagflatKernel { public: void Call(const Array& v, int64_t k, const Array& out) override { CHAINERX_ASSERT(v.ndim() == 1); @@ -139,7 +140,7 @@ public: } // Initialize all elements to 0 first instead of conditionally filling in the diagonal. - device.backend().CallOp(out, T{0}); + device.backend().CallKernel(out, T{0}); IndexableArray v_iarray{v}; IndexableArray out_iarray{out}; @@ -158,7 +159,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(DiagflatOp, CudaDiagflatOp); +CHAINERX_CUDA_REGISTER_KERNEL(DiagflatKernel, CudaDiagflatKernel); template struct LinspaceImpl { @@ -172,7 +173,7 @@ struct LinspaceImpl { double stop; }; -class CudaLinspaceOp : public LinspaceOp { +class CudaLinspaceKernel : public LinspaceKernel { public: void Call(double start, double stop, const Array& out) override { CHAINERX_ASSERT(out.ndim() == 1); @@ -188,7 +189,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(LinspaceOp, CudaLinspaceOp); +CHAINERX_CUDA_REGISTER_KERNEL(LinspaceKernel, CudaLinspaceKernel); template struct FillImpl { @@ -197,7 +198,7 @@ struct FillImpl { CudaType value; }; -class CudaFillOp : public FillOp { +class CudaFillKernel : public FillKernel { public: void Call(const Array& out, Scalar value) override { CudaSetDeviceScope scope{out.device().index()}; @@ -209,7 +210,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(FillOp, CudaFillOp); +CHAINERX_CUDA_REGISTER_KERNEL(FillKernel, CudaFillKernel); } // namespace } // namespace cuda diff --git a/chainerx_cc/chainerx/cuda/cuda_device/indexing.cu b/chainerx_cc/chainerx/cuda/cuda_device/indexing.cu index 3323ce4ca917..d184bccf97ad 100644 --- a/chainerx_cc/chainerx/cuda/cuda_device/indexing.cu +++ b/chainerx_cc/chainerx/cuda/cuda_device/indexing.cu @@ -17,18 +17,18 @@ #include "chainerx/cuda/cuda_set_device_scope.h" #include "chainerx/cuda/data_type.cuh" #include "chainerx/cuda/elementwise.cuh" -#include "chainerx/cuda/op_regist.h" +#include "chainerx/cuda/kernel_regist.h" #include "chainerx/device.h" #include "chainerx/dtype.h" #include "chainerx/indexable_array.h" #include "chainerx/indexer.h" +#include "chainerx/kernels/indexing.h" #include "chainerx/macro.h" #include "chainerx/routines/indexing.h" #include "chainerx/shape.h" namespace chainerx { namespace cuda { - namespace { // Makes axes for permutation that moves 
[first_axis, last_axis) to the head. @@ -48,7 +48,7 @@ Axes MakeRollingPermutation(int8_t first_axis, int8_t last_axis, int8_t ndim) { } template -__global__ void TakeKernel( +__global__ void TakeCudaKernel( IndexableArray a_iarray, IndexableArray out_iarray, IndexableArray indices_iarray, @@ -76,7 +76,7 @@ __global__ void TakeKernel( } template -__global__ void AddAtKernel( +__global__ void AddAtCudaKernel( IndexableArray a_iarray, IndexableArray b_iarray, IndexableArray out_iarray, @@ -156,12 +156,12 @@ void TakeImpl(Device& device, const Array& a, const Array& indices, int8_t axis, // TODO(niboshi): Calculate kMaxBlockSize per device std::lock_guard lock{*cuda_internal::g_mutex}; - static const int kMaxBlockSize = CudaOccupancyMaxPotentialBlockSize(&TakeKernel).block_size; + static const int kMaxBlockSize = CudaOccupancyMaxPotentialBlockSize(&TakeCudaKernel).block_size; int64_t total_size = out_indexer.total_size(); int64_t grid_size = (total_size + kMaxBlockSize - 1) / kMaxBlockSize; int64_t block_size = std::min(total_size, kMaxBlockSize); - TakeKernel<<>>( + TakeCudaKernel<<>>( a_iarray, out_iarray, indices_iarray, a_indexer, out_indexer, indices_indexer, common_total_size, axis_dim); }); } @@ -217,17 +217,17 @@ void AddAtImpl(Device& device, const Array& a, const Array& indices, int8_t axis TIndex axis_dim = gsl::narrow(a_shape[0]); - static const int kMaxBlockSize = CudaOccupancyMaxPotentialBlockSize(&AddAtKernel).block_size; + static const int kMaxBlockSize = CudaOccupancyMaxPotentialBlockSize(&AddAtCudaKernel).block_size; int64_t total_size = out_indexer.total_size(); int64_t grid_size = (total_size + kMaxBlockSize - 1) / kMaxBlockSize; int64_t block_size = std::min(total_size, kMaxBlockSize); - AddAtKernel<<>>( + AddAtCudaKernel<<>>( a_iarray, b_iarray, out_iarray, indices_iarray, b_indexer, out_indexer, indices_indexer, common_total_size, axis_dim); }); } -class CudaTakeOp : public TakeOp { +class CudaTakeKernel : public TakeKernel { public: void Call(const Array& a, const Array& indices, int8_t axis, const Array& out) override { Device& device = a.device(); @@ -245,9 +245,9 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(TakeOp, CudaTakeOp); +CHAINERX_CUDA_REGISTER_KERNEL(TakeKernel, CudaTakeKernel); -class CudaAddAtOp : public AddAtOp { +class CudaAddAtKernel : public AddAtKernel { public: void Call(const Array& a, const Array& indices, int8_t axis, const Array& b, const Array& out) override { Device& device = a.device(); @@ -265,7 +265,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(AddAtOp, CudaAddAtOp); +CHAINERX_CUDA_REGISTER_KERNEL(AddAtKernel, CudaAddAtKernel); } // namespace } // namespace cuda diff --git a/chainerx_cc/chainerx/cuda/cuda_device/misc.cu b/chainerx_cc/chainerx/cuda/cuda_device/misc.cu index 712331d3754e..ecab538e99ae 100644 --- a/chainerx_cc/chainerx/cuda/cuda_device/misc.cu +++ b/chainerx_cc/chainerx/cuda/cuda_device/misc.cu @@ -9,11 +9,11 @@ #include "chainerx/cuda/cuda_runtime.h" #include "chainerx/cuda/cuda_set_device_scope.h" #include "chainerx/cuda/elementwise.cuh" +#include "chainerx/cuda/kernel_regist.h" #include "chainerx/cuda/numeric.cuh" -#include "chainerx/cuda/op_regist.h" #include "chainerx/device.h" #include "chainerx/dtype.h" -#include "chainerx/routines/math.h" +#include "chainerx/kernels/math.h" namespace chainerx { namespace cuda { @@ -25,7 +25,7 @@ struct SquareImpl { __device__ void operator()(int64_t /*i*/, CudaType x, CudaType& out) { out = x * x; } }; -class CudaSquareOp : public SquareOp { +class CudaSquareKernel : public SquareKernel { 
public: void Call(const Array& x, const Array& out) override { Device& device = x.device(); @@ -38,7 +38,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(SquareOp, CudaSquareOp); +CHAINERX_CUDA_REGISTER_KERNEL(SquareKernel, CudaSquareKernel); template struct SqrtImpl { @@ -46,7 +46,7 @@ struct SqrtImpl { __device__ void operator()(int64_t /*i*/, CudaType x, CudaType& out) { out = cuda::Sqrt(x); } }; -class CudaSqrtOp : public SqrtOp { +class CudaSqrtKernel : public SqrtKernel { public: void Call(const Array& x, const Array& out) override { Device& device = x.device(); @@ -60,7 +60,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(SqrtOp, CudaSqrtOp); +CHAINERX_CUDA_REGISTER_KERNEL(SqrtKernel, CudaSqrtKernel); template struct IsNanImpl { @@ -68,7 +68,7 @@ struct IsNanImpl { __device__ void operator()(int64_t /*i*/, CudaType x, bool& out) { out = cuda::IsNan(x); } }; -class CudaIsNanOp : public IsNanOp { +class CudaIsNanKernel : public IsNanKernel { public: void Call(const Array& x, const Array& out) override { Device& device = x.device(); @@ -81,7 +81,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(IsNanOp, CudaIsNanOp); +CHAINERX_CUDA_REGISTER_KERNEL(IsNanKernel, CudaIsNanKernel); template struct IsInfImpl { @@ -89,7 +89,7 @@ struct IsInfImpl { __device__ void operator()(int64_t /*i*/, CudaType x, bool& out) { out = cuda::IsInf(x); } }; -class CudaIsInfOp : public IsInfOp { +class CudaIsInfKernel : public IsInfKernel { public: void Call(const Array& x, const Array& out) override { Device& device = x.device(); @@ -102,7 +102,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(IsInfOp, CudaIsInfOp); +CHAINERX_CUDA_REGISTER_KERNEL(IsInfKernel, CudaIsInfKernel); template struct CeilImpl { @@ -110,7 +110,7 @@ struct CeilImpl { __device__ void operator()(int64_t /*i*/, CudaType x, CudaType& out) { out = cuda::Ceil(x); } }; -class CudaCeilOp : public CeilOp { +class CudaCeilKernel : public CeilKernel { public: void Call(const Array& x, const Array& out) override { Device& device = x.device(); @@ -124,7 +124,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(CeilOp, CudaCeilOp); +CHAINERX_CUDA_REGISTER_KERNEL(CeilKernel, CudaCeilKernel); template struct FloorImpl { @@ -132,7 +132,7 @@ struct FloorImpl { __device__ void operator()(int64_t /*i*/, CudaType x, CudaType& out) { out = cuda::Floor(x); } }; -class CudaFloorOp : public FloorOp { +class CudaFloorKernel : public FloorKernel { public: void Call(const Array& x, const Array& out) override { Device& device = x.device(); @@ -146,7 +146,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(FloorOp, CudaFloorOp); +CHAINERX_CUDA_REGISTER_KERNEL(FloorKernel, CudaFloorKernel); } // namespace } // namespace cuda diff --git a/chainerx_cc/chainerx/cuda/cuda_device/pool.cu b/chainerx_cc/chainerx/cuda/cuda_device/pool.cu index f4d9c4879230..aa222df7bbbf 100644 --- a/chainerx_cc/chainerx/cuda/cuda_device/pool.cu +++ b/chainerx_cc/chainerx/cuda/cuda_device/pool.cu @@ -17,12 +17,13 @@ #include "chainerx/cuda/cuda_set_device_scope.h" #include "chainerx/cuda/cudnn.h" #include "chainerx/cuda/data_type.cuh" -#include "chainerx/cuda/op_regist.h" +#include "chainerx/cuda/kernel_regist.h" #include "chainerx/device.h" #include "chainerx/dtype.h" #include "chainerx/error.h" #include "chainerx/indexable_array.h" #include "chainerx/indexer.h" +#include "chainerx/kernels/pooling.h" #include "chainerx/macro.h" #include "chainerx/numeric_limits.h" #include "chainerx/routines/connection.h" @@ -258,7 +259,7 @@ Array MaxPoolGradGrad( return actual_ggout; } -class CudaMaxPoolOp : public MaxPoolOp { +class 
CudaMaxPoolKernel : public MaxPoolKernel { public: std::tuple> Call( const Array& x, @@ -278,9 +279,9 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(MaxPoolOp, CudaMaxPoolOp); +CHAINERX_CUDA_REGISTER_KERNEL(MaxPoolKernel, CudaMaxPoolKernel); -class CudaMaxPoolGradOp : public MaxPoolGradOp { +class CudaMaxPoolGradKernel : public MaxPoolGradKernel { public: std::tuple> Call( const Array& gout, @@ -305,9 +306,9 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(MaxPoolGradOp, CudaMaxPoolGradOp); +CHAINERX_CUDA_REGISTER_KERNEL(MaxPoolGradKernel, CudaMaxPoolGradKernel); -class CudaMaxPoolGradGradOp : public MaxPoolGradGradOp { +class CudaMaxPoolGradGradKernel : public MaxPoolGradGradKernel { public: Array Call( const Array& ggx, @@ -328,7 +329,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(MaxPoolGradGradOp, CudaMaxPoolGradGradOp); +CHAINERX_CUDA_REGISTER_KERNEL(MaxPoolGradGradKernel, CudaMaxPoolGradGradKernel); cudnnPoolingMode_t GetCudnnPoolingMode(AveragePoolPadMode pad_mode) { switch (pad_mode) { @@ -341,7 +342,7 @@ cudnnPoolingMode_t GetCudnnPoolingMode(AveragePoolPadMode pad_mode) { } } -class CudaAveragePoolOp : public AveragePoolOp { +class CudaAveragePoolKernel : public AveragePoolKernel { public: std::tuple> Call( const Array& x, @@ -361,9 +362,9 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(AveragePoolOp, CudaAveragePoolOp); +CHAINERX_CUDA_REGISTER_KERNEL(AveragePoolKernel, CudaAveragePoolKernel); -class CudaAveragePoolGradOp : public AveragePoolGradOp { +class CudaAveragePoolGradKernel : public AveragePoolGradKernel { public: Array Call( const Array& gout, @@ -384,7 +385,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(AveragePoolGradOp, CudaAveragePoolGradOp); +CHAINERX_CUDA_REGISTER_KERNEL(AveragePoolGradKernel, CudaAveragePoolGradKernel); } // namespace } // namespace cuda diff --git a/chainerx_cc/chainerx/cuda/cuda_device/reduction.cu b/chainerx_cc/chainerx/cuda/cuda_device/reduction.cu index 7750656d2a87..4009b77143f7 100644 --- a/chainerx_cc/chainerx/cuda/cuda_device/reduction.cu +++ b/chainerx_cc/chainerx/cuda/cuda_device/reduction.cu @@ -10,17 +10,18 @@ #include "chainerx/cuda/cuda_runtime.h" #include "chainerx/cuda/cuda_set_device_scope.h" #include "chainerx/cuda/data_type.cuh" +#include "chainerx/cuda/kernel_regist.h" #include "chainerx/cuda/numeric.cuh" #include "chainerx/cuda/numeric_limits.cuh" -#include "chainerx/cuda/op_regist.h" #include "chainerx/cuda/reduce.cuh" #include "chainerx/device.h" #include "chainerx/dtype.h" +#include "chainerx/kernels/math.h" +#include "chainerx/kernels/sorting.h" #include "chainerx/macro.h" #include "chainerx/numeric_limits.h" #include "chainerx/reduction_kernel_arg.h" #include "chainerx/routines/math.h" -#include "chainerx/routines/sorting.h" #include "chainerx/shape.h" namespace chainerx { @@ -45,7 +46,7 @@ struct ArgMaxImpl { __device__ int64_t MapOut(MaxAndArgMax accum) { return accum.argmax; } }; -class CudaArgMaxOp : public ArgMaxOp { +class CudaArgMaxKernel : public ArgMaxKernel { public: void Call(const Array& a, const Axes& axis, const Array& out) override { Device& device = a.device(); @@ -58,7 +59,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(ArgMaxOp, CudaArgMaxOp); +CHAINERX_CUDA_REGISTER_KERNEL(ArgMaxKernel, CudaArgMaxKernel); template struct SumImpl { @@ -70,7 +71,7 @@ struct SumImpl { __device__ OutCudaType MapOut(OutCudaType accum) { return accum; } }; -class CudaSumOp : public SumOp { +class CudaSumKernel : public SumKernel { public: void Call(const Array& a, const Axes& axis, const Array& out) override { Device& device = a.device(); @@ 
-88,7 +89,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(SumOp, CudaSumOp); +CHAINERX_CUDA_REGISTER_KERNEL(SumKernel, CudaSumKernel); template struct AMaxImpl { @@ -103,7 +104,7 @@ struct AMaxImpl { __device__ CudaType MapOut(CudaType accum) { return accum; } }; -class CudaAMaxOp : public AMaxOp { +class CudaAMaxKernel : public AMaxKernel { public: void Call(const Array& a, const Axes& axis, const Array& out) override { Device& device = a.device(); @@ -117,7 +118,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(AMaxOp, CudaAMaxOp); +CHAINERX_CUDA_REGISTER_KERNEL(AMaxKernel, CudaAMaxKernel); template struct AMinImpl { @@ -132,7 +133,7 @@ struct AMinImpl { __device__ CudaType MapOut(CudaType accum) { return accum; } }; -class CudaAMinOp : public AMinOp { +class CudaAMinKernel : public AMinKernel { public: void Call(const Array& a, const Axes& axis, const Array& out) override { Device& device = a.device(); @@ -146,7 +147,7 @@ public: } }; -CHAINERX_CUDA_REGISTER_OP(AMinOp, CudaAMinOp); +CHAINERX_CUDA_REGISTER_KERNEL(AMinKernel, CudaAMinKernel); } // namespace } // namespace cuda diff --git a/chainerx_cc/chainerx/cuda/cuda_device/trigonometric.cu b/chainerx_cc/chainerx/cuda/cuda_device/trigonometric.cu index 139a5af67b60..0c1beaeb5cfe 100644 --- a/chainerx_cc/chainerx/cuda/cuda_device/trigonometric.cu +++ b/chainerx_cc/chainerx/cuda/cuda_device/trigonometric.cu @@ -9,37 +9,37 @@ #include "chainerx/cuda/cuda_runtime.h" #include "chainerx/cuda/cuda_set_device_scope.h" #include "chainerx/cuda/elementwise.cuh" +#include "chainerx/cuda/kernel_regist.h" #include "chainerx/cuda/numeric.cuh" -#include "chainerx/cuda/op_regist.h" #include "chainerx/device.h" #include "chainerx/dtype.h" +#include "chainerx/kernels/math.h" #include "chainerx/numeric.h" -#include "chainerx/routines/math.h" #include "chainerx/scalar.h" namespace chainerx { namespace cuda { namespace { -CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_OP(SinOp, { out = cuda::Sin(x); }); +CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(SinKernel, { out = cuda::Sin(x); }); -CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_OP(CosOp, { out = cuda::Cos(x); }); +CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(CosKernel, { out = cuda::Cos(x); }); -CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_OP(TanOp, { out = cuda::Tan(x); }); +CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(TanKernel, { out = cuda::Tan(x); }); -CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_OP(ArcsinOp, { out = cuda::Arcsin(x); }); +CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(ArcsinKernel, { out = cuda::Arcsin(x); }); -CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_OP(ArccosOp, { out = cuda::Arccos(x); }); +CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(ArccosKernel, { out = cuda::Arccos(x); }); -CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_OP(ArctanOp, { out = cuda::Arctan(x); }); +CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(ArctanKernel, { out = cuda::Arctan(x); }); -CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_OP(SinhOp, { out = cuda::Sinh(x); }); +CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(SinhKernel, { out = cuda::Sinh(x); }); -CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_OP(CoshOp, { out = cuda::Cosh(x); }); +CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(CoshKernel, { out = cuda::Cosh(x); }); -CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_OP(ArcsinhOp, { out = cuda::Arcsinh(x); }); +CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(ArcsinhKernel, { out = cuda::Arcsinh(x); }); -CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_OP(ArccoshOp, { out = cuda::Arccosh(x); }); 
+CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(ArccoshKernel, { out = cuda::Arccosh(x); }); } // namespace } // namespace cuda diff --git a/chainerx_cc/chainerx/cuda/cuda_device_test.cc b/chainerx_cc/chainerx/cuda/cuda_device_test.cc index b723aa408d02..c313a57cd847 100644 --- a/chainerx_cc/chainerx/cuda/cuda_device_test.cc +++ b/chainerx_cc/chainerx/cuda/cuda_device_test.cc @@ -13,7 +13,7 @@ #include "chainerx/cuda/cuda_backend.h" #include "chainerx/cuda/cuda_runtime.h" #include "chainerx/device.h" -#include "chainerx/routines/linalg.h" +#include "chainerx/kernels/linalg.h" #include "chainerx/testing/array.h" #include "chainerx/testing/array_check.h" #include "chainerx/testing/device_session.h" @@ -161,7 +161,7 @@ TEST(CudaDeviceTest, DotNonContiguousOut) { Array a = testing::BuildArray({2, 3}).WithLinearData(1.f); Array b = testing::BuildArray({3, 2}).WithData({1.f, 2.f, -1.f, -3.f, 2.f, 4.f}); Array c = testing::BuildArray({2, 2}).WithData({0.f, 0.f, 0.f, 0.f}).WithPadding(1); - a.device().backend().CallOp(a, b, c); + a.device().backend().CallKernel(a, b, c); Array e = testing::BuildArray({2, 2}).WithData({5.f, 8.f, 11.f, 17.f}); EXPECT_ARRAY_EQ(e, c); diff --git a/chainerx_cc/chainerx/cuda/kernel_regist.h b/chainerx_cc/chainerx/cuda/kernel_regist.h new file mode 100644 index 000000000000..1009b654ccdb --- /dev/null +++ b/chainerx_cc/chainerx/cuda/kernel_regist.h @@ -0,0 +1,76 @@ +#pragma once + +#include "chainerx/cuda/cuda_backend.h" +#include "chainerx/kernel_registry.h" + +// Register a kernel statically in CudaBackend. +#define CHAINERX_CUDA_REGISTER_KERNEL(key_kernel_cls, kernel_cls) \ + static chainerx::internal::KernelRegistrar \ + s_cuda_backend_kernel_##kernel_cls{}; + +#define CHAINERX_CUDA_REGISTER_ELTWISE_DTYPE_UNARY_KERNEL(key_kernel_cls, kernel_body, visit_dtype) \ + \ + template \ + struct Cuda##key_kernel_cls##Impl { \ + using CudaType = cuda_internal::DataType; \ + __device__ void operator()(int64_t i, CudaType x, CudaType& out) { \ + (void)i; \ + kernel_body \ + } \ + }; \ + \ + class Cuda##key_kernel_cls : public key_kernel_cls { \ + public: \ + void Call(const Array& x, const Array& out) override { \ + Device& device = x.device(); \ + device.CheckDevicesCompatible(x, out); \ + CudaSetDeviceScope scope{device.index()}; \ + const Array& x_cast = x.dtype() == out.dtype() ?
x : x.AsType(out.dtype()); \ + visit_dtype(out.dtype(), [&](auto pt) { \ + using T = typename decltype(pt)::type; \ + Elementwise(Cuda##key_kernel_cls##Impl{}, x_cast, out); \ + }); \ + } \ + }; \ + \ + CHAINERX_CUDA_REGISTER_KERNEL(key_kernel_cls, Cuda##key_kernel_cls) + +#define CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(key_kernel_cls, kernel_body) \ + CHAINERX_CUDA_REGISTER_ELTWISE_DTYPE_UNARY_KERNEL(key_kernel_cls, kernel_body, VisitFloatingPointDtype) + +#define CHAINERX_CUDA_REGISTER_ELTWISE_UNARY_KERNEL(key_kernel_cls, kernel_body) \ + CHAINERX_CUDA_REGISTER_ELTWISE_DTYPE_UNARY_KERNEL(key_kernel_cls, kernel_body, VisitDtype) + +#define CHAINERX_CUDA_REGISTER_ELTWISE_DTYPE_BINARY_KERNEL(key_kernel_cls, kernel_body, visit_dtype) \ + \ + template \ + struct Cuda##key_kernel_cls##Impl { \ + using CudaType = cuda_internal::DataType; \ + __device__ void operator()(int64_t i, CudaType x1, CudaType x2, CudaType& out) { \ + (void)i; \ + kernel_body \ + } \ + }; \ + \ + class Cuda##key_kernel_cls : public key_kernel_cls { \ + public: \ + void Call(const Array& x1, const Array& x2, const Array& out) override { \ + Device& device = x1.device(); \ + device.CheckDevicesCompatible(x1, x2, out); \ + const Array& x1_cast = x1.dtype() == out.dtype() ? x1 : x1.AsType(out.dtype()); \ + const Array& x2_cast = x2.dtype() == out.dtype() ? x2 : x2.AsType(out.dtype()); \ + CudaSetDeviceScope scope{device.index()}; \ + visit_dtype(out.dtype(), [&](auto pt) { \ + using T = typename decltype(pt)::type; \ + Elementwise(Cuda##key_kernel_cls##Impl{}, x1_cast, x2_cast, out); \ + }); \ + } \ + }; \ + \ + CHAINERX_CUDA_REGISTER_KERNEL(key_kernel_cls, Cuda##key_kernel_cls); + +#define CHAINERX_CUDA_REGISTER_ELTWISE_FLOAT_BINARY_KERNEL(key_kernel_cls, kernel_body) \ + CHAINERX_CUDA_REGISTER_ELTWISE_DTYPE_BINARY_KERNEL(key_kernel_cls, kernel_body, VisitFloatingPointDtype) + +#define CHAINERX_CUDA_REGISTER_ELTWISE_BINARY_KERNEL(key_kernel_cls, kernel_body) \ + CHAINERX_CUDA_REGISTER_ELTWISE_DTYPE_BINARY_KERNEL(key_kernel_cls, kernel_body, VisitDtype) diff --git a/chainerx_cc/chainerx/op.h b/chainerx_cc/chainerx/kernel.h similarity index 60% rename from chainerx_cc/chainerx/op.h rename to chainerx_cc/chainerx/kernel.h index 18962fc897f7..07bbcbfc474b 100644 --- a/chainerx_cc/chainerx/op.h +++ b/chainerx_cc/chainerx/kernel.h @@ -2,9 +2,9 @@ namespace chainerx { -class Op { +class Kernel { public: - virtual ~Op() = default; + virtual ~Kernel() = default; }; } // namespace chainerx diff --git a/chainerx_cc/chainerx/kernel_registry.h b/chainerx_cc/chainerx/kernel_registry.h new file mode 100644 index 000000000000..e8be718a2c54 --- /dev/null +++ b/chainerx_cc/chainerx/kernel_registry.h @@ -0,0 +1,73 @@ +#pragma once + +#include +#include +#include +#include + +#include "chainerx/error.h" +#include "chainerx/kernel.h" + +namespace chainerx { + +// Manages dynamic registration and dispatch of kernels. +// This class is hierarchical: it has an optional pointer to a parent KernelRegistry and falls back if a kernel is not found in this +// instance. +class KernelRegistry { +public: + KernelRegistry() {} + + explicit KernelRegistry(KernelRegistry* parent) : parent_{parent} {} + + // Registers a kernel. + // Registers an instance of KernelType with the type_index of KeyKernelType as the key. + // KernelType must be a subclass of KeyKernelType. 
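+ // Illustrative sketch (not part of this header): using TanhKernel/CudaTanhKernel from this
+ // patch, and writing `registry` for a backend's kernel_registry(), registration and lookup go
+ // roughly as follows; the template arguments are assumptions inferred from the KernelRegistrar
+ // and Backend::CallKernel call sites rather than spelled out here.
+ //   registry.RegisterKernel<TanhKernel, CudaTanhKernel>();  // normally done by KernelRegistrar
+ //   Kernel& kernel = registry.GetKernel<TanhKernel>();      // falls back to the parent registry
+ //   dynamic_cast<TanhKernel&>(kernel).Call(x, out);         // what Backend::CallKernel does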
+ template + void RegisterKernel() { + static_assert(std::is_base_of::value, "KernelType must be a subclass of KeyKernelType."); + std::lock_guard lock{*mutex_}; + auto pair = kernels_.emplace(std::type_index{typeid(KeyKernelType)}, std::make_unique()); + if (!pair.second) { + throw ChainerxError{"Duplicate kernel: ", KeyKernelType::name()}; + } + } + + // Looks up a kernel. + template + Kernel& GetKernel() { + std::type_index key{typeid(KeyKernelType)}; + { + std::lock_guard lock{*mutex_}; + auto it = kernels_.find(key); + if (it != kernels_.end()) { + return *it->second; + } + } + if (parent_ != nullptr) { + return parent_->GetKernel(); + } + throw ChainerxError{"Kernel not found: ", KeyKernelType::name()}; + } + +private: + std::unique_ptr mutex_{std::make_unique()}; + + KernelRegistry* parent_{}; + + std::unordered_map> kernels_{}; +}; + +namespace internal { + +// A facility to register kernels statically. +template +class KernelRegistrar { +public: + KernelRegistrar() noexcept { + KernelRegistry& kernel_registry = BackendType::GetGlobalKernelRegistry(); + kernel_registry.RegisterKernel(); + } +}; + +} // namespace internal +} // namespace chainerx diff --git a/chainerx_cc/chainerx/kernel_registry_test.cc b/chainerx_cc/chainerx/kernel_registry_test.cc new file mode 100644 index 000000000000..c32d84ea30aa --- /dev/null +++ b/chainerx_cc/chainerx/kernel_registry_test.cc @@ -0,0 +1,161 @@ +#include "chainerx/kernel_registry.h" + +#include +#include + +#include + +#include "chainerx/backend.h" +#include "chainerx/context.h" +#include "chainerx/error.h" +#include "chainerx/kernel.h" +#include "chainerx/macro.h" +#include "chainerx/native/native_backend.h" +#include "chainerx/testing/threading.h" +#include "chainerx/util.h" + +namespace chainerx { +namespace { + +TEST(KernelRegistryTest, KernelRegistry) { + KernelRegistry kernel_registry{}; + + class MyKernel : public Kernel { + public: + static const char* name() { return "mykernel"; } + virtual std::string Call(int a, const std::string& b) { return std::to_string(a) + b; } + }; + + kernel_registry.RegisterKernel(); + + Kernel& kernel = kernel_registry.GetKernel(); + + // no throw + MyKernel& mykernel = dynamic_cast(kernel); + + EXPECT_EQ(mykernel.Call(3, " is 3"), "3 is 3"); +} + +TEST(KernelRegistryTest, KernelRegistryHierarchy) { + KernelRegistry parent_kernel_registry{}; + KernelRegistry kernel_registry1{&parent_kernel_registry}; + KernelRegistry kernel_registry2{&parent_kernel_registry}; + + class MyKernel1 : public Kernel { + public: + static const char* name() { return "mykernel1"; } + virtual std::string Call(int a, const std::string& b) { return std::to_string(a) + b; } + }; + + class MyParentKernel : public Kernel { + public: + static const char* name() { return "myparentkernel"; } + virtual std::string Call(const std::string& a, float b) { return a + std::to_string(b); } + }; + + kernel_registry1.RegisterKernel(); + parent_kernel_registry.RegisterKernel(); + + EXPECT_THROW({ kernel_registry2.GetKernel(); }, ChainerxError); + EXPECT_THROW({ parent_kernel_registry.GetKernel(); }, ChainerxError); + // no throw + Kernel& kernel1p = kernel_registry1.GetKernel(); + Kernel& kernel2p = kernel_registry2.GetKernel(); + Kernel& kernelpp = parent_kernel_registry.GetKernel(); + + Kernel& kernel = kernel_registry1.GetKernel(); + EXPECT_EQ(&kernel1p, &kernel2p); + EXPECT_EQ(&kernel1p, &kernelpp); + EXPECT_NE(&kernel1p, &kernel); + + MyKernel1& mykernel = dynamic_cast(kernel); + EXPECT_EQ(mykernel.Call(3, " is 3"), "3 is 3"); +} + 
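For orientation while reading the rest of this diff, here is a minimal sketch (not part of the change itself) of how a key kernel class, a backend implementation, static registration, and dispatch fit together. MyScaleKernel and NativeMyScaleKernel are hypothetical names used only for illustration; the registration macro is the CHAINERX_NATIVE_REGISTER_KERNEL introduced later in chainerx/native/kernel_regist.h.

// Key class: the dispatch interface that device-agnostic code programs against
// (this is what the headers under chainerx/kernels/ declare).
class MyScaleKernel : public Kernel {
public:
    static const char* name() { return "MyScale"; }
    virtual void Call(const Array& x, Scalar s, const Array& out) = 0;
};

// Backend-specific implementation (what files under chainerx/native/native_device/ define).
class NativeMyScaleKernel : public MyScaleKernel {
public:
    void Call(const Array& x, Scalar s, const Array& out) override {
        // A real implementation would run an elementwise multiply here.
        (void)x;
        (void)s;
        (void)out;
    }
};

// Static registration into NativeBackend's global registry; backend instances
// pick it up through the parent-registry fallback exercised by the tests above.
CHAINERX_NATIVE_REGISTER_KERNEL(MyScaleKernel, NativeMyScaleKernel);

// Dispatch from device-agnostic code resolves the registered implementation:
//   x.device().backend().CallKernel<MyScaleKernel>(x, Scalar{2}, out);
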
+TEST(KernelRegistryTest, KernelRegistryWithBackend) { + // TODO(imanishi): Restore the environment variable after this test. + SetEnv("CHAINERX_PATH", CHAINERX_TEST_DIR "/backend_testdata"); + Context ctx; + Backend& backend0 = ctx.GetBackend("backend0"); + + KernelRegistry& kernel_registry = backend0.kernel_registry(); + + class MyKernel : public Kernel { + public: + static const char* name() { return "mykernel"; } + virtual std::string Call(int a, const std::string& b) { return std::to_string(a) + b; } + }; + + kernel_registry.RegisterKernel(); + + Kernel& kernel = kernel_registry.GetKernel(); + + // no throw + MyKernel& mykernel = dynamic_cast(kernel); + + EXPECT_EQ(mykernel.Call(3, " is 3"), "3 is 3"); + + // MyKernel should not be registered to the base backend: NativeBackend. + EXPECT_THROW({ native::NativeBackend::GetGlobalKernelRegistry().GetKernel(); }, ChainerxError); + + // MyKernel should not be registered to another Backend0 instance. + { + Context ctx_another; + Backend& backend_another = ctx_another.GetBackend("backend0"); + EXPECT_THROW({ backend_another.kernel_registry().GetKernel(); }, ChainerxError); + } +} + +TEST(KernelRegistryTest, KernelRegistryThreadSafe) { + KernelRegistry parent_kernel_registry{}; + KernelRegistry kernel_registry1{&parent_kernel_registry}; + + class MyKernel1 : public Kernel { + public: + static const char* name() { return "mykernel1"; } + virtual std::string Call(int a, const std::string& b) { return std::to_string(a) + b; } + }; + + class MyParentKernel : public Kernel { + public: + static const char* name() { return "myparentkernel"; } + virtual std::string Call(const std::string& a, float b) { return a + std::to_string(b); } + }; + + class MyKernel2 : public Kernel { + public: + static const char* name() { return "mykernel2"; } + virtual std::string Call(int a, const std::string& b) { return std::to_string(a) + b; } + }; + + class MyParentKernel2 : public Kernel { + public: + static const char* name() { return "myparentkernel2"; } + virtual std::string Call(const std::string& a, float b) { return a + std::to_string(b); } + }; + + kernel_registry1.RegisterKernel(); + parent_kernel_registry.RegisterKernel(); + + testing::RunThreads(4U, [&parent_kernel_registry, &kernel_registry1](size_t thread_index) { + switch (thread_index) { + case 0: + kernel_registry1.GetKernel(); + break; + case 1: + kernel_registry1.GetKernel(); + break; + case 2: + kernel_registry1.RegisterKernel(); + break; + case 3: + parent_kernel_registry.RegisterKernel(); + break; + default: + CHAINERX_NEVER_REACH(); + } + }); +} + +} // namespace +} // namespace chainerx diff --git a/chainerx_cc/chainerx/kernels/CMakeLists.txt b/chainerx_cc/chainerx/kernels/CMakeLists.txt new file mode 100644 index 000000000000..8679f8c033d0 --- /dev/null +++ b/chainerx_cc/chainerx/kernels/CMakeLists.txt @@ -0,0 +1,13 @@ +install(FILES + connection.h + creation.h + indexing.h + linalg.h + logic.h + math.h + misc.h + normalization.h + pooling.h + sorting.h + DESTINATION include/chainerx/kernels + ) diff --git a/chainerx_cc/chainerx/kernels/connection.h b/chainerx_cc/chainerx/kernels/connection.h new file mode 100644 index 000000000000..fc0b4c98dabb --- /dev/null +++ b/chainerx_cc/chainerx/kernels/connection.h @@ -0,0 +1,73 @@ +#pragma once + +#include + +#include + +#include "chainerx/array.h" +#include "chainerx/constant.h" +#include "chainerx/kernel.h" +#include "chainerx/stack_vector.h" + +namespace chainerx { + +// Computes the n-dimensional convolution. 
+// +// x: (batch_size, in_channels, in_1, in_2, ..., in_n) +// w: (out_channels, in_channels, k_1, k_2, ..., k_n) +// b: (out_channels) +// +// Returns an array of shape (batch_size, out_channels, out_1, out_2, ..., out_n). +class ConvKernel : public Kernel { +public: + static const char* name() { return "Conv"; } + + virtual Array Call( + const Array& x, + const Array& w, + const nonstd::optional& b, + const StackVector& stride, + const StackVector& pad, + bool cover_all, + Dtype out_dtype, + const nonstd::optional& out) = 0; +}; + +// Computes the n-dimensional transposed convolution. +// +// x: (batch_size, in_channels, in_1, in_2, ..., in_n) +// w: (in_channels, out_channels, k_1, k_2, ..., k_n) +// b: (out_channels) +// +// Returns an array of shape (batch_size, out_channels, out_1, out_2, ..., out_n). +class ConvTransposeKernel : public Kernel { +public: + static const char* name() { return "ConvTranspose"; } + + virtual Array Call( + const Array& x, + const Array& w, + const nonstd::optional& b, + const StackVector& stride, + const StackVector& pad, + const StackVector& out_size, + Dtype out_dtype, + const nonstd::optional& out) = 0; +}; + +class ConvGradWeightKernel : public Kernel { +public: + static const char* name() { return "ConvGradWeight"; } + + virtual Array Call( + Dtype w_dtype, + const Shape& w_shape, + const Array& x, + const Array& gy, + const StackVector& stride, + const StackVector& pad, + bool cover_all, + const nonstd::optional& out) = 0; +}; + +} // namespace chainerx diff --git a/chainerx_cc/chainerx/kernels/creation.h b/chainerx_cc/chainerx/kernels/creation.h new file mode 100644 index 000000000000..fb4d315e02e6 --- /dev/null +++ b/chainerx_cc/chainerx/kernels/creation.h @@ -0,0 +1,62 @@ +#pragma once + +#include + +#include "chainerx/array.h" +#include "chainerx/kernel.h" +#include "chainerx/scalar.h" + +namespace chainerx { + +class ArangeKernel : public Kernel { +public: + static const char* name() { return "Arange"; } + + virtual void Call(Scalar start, Scalar step, const Array& out) = 0; +}; + +class CopyKernel : public Kernel { +public: + static const char* name() { return "Copy"; } + + // Copies the elements from one array to the other. + // + // The arrays must match in shape and dtype and need to reside on this device. + virtual void Call(const Array& a, const Array& out) = 0; +}; + +class IdentityKernel : public Kernel { +public: + static const char* name() { return "Identity"; } + + // Creates the identity array. + // out must be a square 2-dim array. + virtual void Call(const Array& out) = 0; +}; + +class EyeKernel : public Kernel { +public: + static const char* name() { return "Eye"; } + + // Creates a 2-dimensional array with ones along the k-th diagonal and zeros elsewhere. + // out must be a square 2-dim array. + virtual void Call(int64_t k, const Array& out) = 0; +}; + +class DiagflatKernel : public Kernel { +public: + static const char* name() { return "Diagflat"; } + + virtual void Call(const Array& v, int64_t k, const Array& out) = 0; +}; + +class LinspaceKernel : public Kernel { +public: + static const char* name() { return "Linspace"; } + + // Creates an evenly spaced 1-d array. + // `out.ndim()` must be 1 with at least 1 elements. 
+ virtual void Call(double start, double stop, const Array& out) = 0; +}; + +} // namespace chainerx diff --git a/chainerx_cc/chainerx/kernels/indexing.h b/chainerx_cc/chainerx/kernels/indexing.h new file mode 100644 index 000000000000..3e29ef63f737 --- /dev/null +++ b/chainerx_cc/chainerx/kernels/indexing.h @@ -0,0 +1,25 @@ +#pragma once + +#include + +#include "chainerx/array.h" +#include "chainerx/array_index.h" +#include "chainerx/kernel.h" + +namespace chainerx { + +class AddAtKernel : public Kernel { +public: + static const char* name() { return "AddAt"; } + + virtual void Call(const Array& a, const Array& indices, int8_t axis, const Array& b, const Array& out) = 0; +}; + +class TakeKernel : public Kernel { +public: + static const char* name() { return "Take"; } + + virtual void Call(const Array& a, const Array& indices, int8_t axis, const Array& out) = 0; +}; + +} // namespace chainerx diff --git a/chainerx_cc/chainerx/kernels/linalg.h b/chainerx_cc/chainerx/kernels/linalg.h new file mode 100644 index 000000000000..f6d61052b63c --- /dev/null +++ b/chainerx_cc/chainerx/kernels/linalg.h @@ -0,0 +1,19 @@ +#pragma once + +#include "chainerx/array.h" +#include "chainerx/kernel.h" + +namespace chainerx { + +// Matrix multiplication. All the operands are matrices (i.e., two-dimensional arrays). +// Let the shapes of `a` and `b` be `(M, K)` and `(L, N)`, respectively. +// Then, it must hold that `K == L` and the shape of `out` must be `(M, N)`. +// Otherwise, the behavior is undefined. +class DotKernel : public Kernel { +public: + static const char* name() { return "Dot"; } + + virtual void Call(const Array& a, const Array& b, const Array& out) = 0; +}; + +} // namespace chainerx diff --git a/chainerx_cc/chainerx/kernels/logic.h b/chainerx_cc/chainerx/kernels/logic.h new file mode 100644 index 000000000000..776d3297bde6 --- /dev/null +++ b/chainerx_cc/chainerx/kernels/logic.h @@ -0,0 +1,72 @@ +#pragma once + +#include "chainerx/array.h" +#include "chainerx/axes.h" +#include "chainerx/kernel.h" + +namespace chainerx { + +class EqualKernel : public Kernel { +public: + static const char* name() { return "Equal"; } + + virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; +}; + +class NotEqualKernel : public Kernel { +public: + static const char* name() { return "NotEqual"; } + + virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; +}; + +class GreaterKernel : public Kernel { +public: + static const char* name() { return "Greater"; } + + virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; +}; + +class GreaterEqualKernel : public Kernel { +public: + static const char* name() { return "GreaterEqual"; } + + virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; +}; + +class LogicalNotKernel : public Kernel { +public: + static const char* name() { return "LogicalNot"; } + + virtual void Call(const Array& x, const Array& out) = 0; +}; + +class LogicalAndKernel : public Kernel { +public: + static const char* name() { return "LogicalAnd"; } + + virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; +}; + +class LogicalOrKernel : public Kernel { +public: + static const char* name() { return "LogicalOr"; } + + virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; +}; + +class AllKernel : public Kernel { +public: + static const char* name() { return "All"; } + + virtual void Call(const Array& a, const Axes& axis, const Array& out) = 0; +}; + +class AnyKernel : public Kernel { 
+public: + static const char* name() { return "Any"; } + + virtual void Call(const Array& a, const Axes& axis, const Array& out) = 0; +}; + +} // namespace chainerx diff --git a/chainerx_cc/chainerx/kernels/math.h b/chainerx_cc/chainerx/kernels/math.h new file mode 100644 index 000000000000..3fcd5c280582 --- /dev/null +++ b/chainerx_cc/chainerx/kernels/math.h @@ -0,0 +1,274 @@ +#pragma once + +#include + +#include + +#include "chainerx/array.h" +#include "chainerx/axes.h" +#include "chainerx/kernel.h" +#include "chainerx/scalar.h" + +namespace chainerx { + +class AddKernel : public Kernel { +public: + static const char* name() { return "Add"; } + + virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; +}; + +class AddASKernel : public Kernel { +public: + static const char* name() { return "AddAS"; } + + virtual void Call(const Array& x1, Scalar x2, const Array& out) = 0; +}; + +class SubtractKernel : public Kernel { +public: + static const char* name() { return "Subtract"; } + + virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; +}; + +class SubtractASKernel : public Kernel { +public: + static const char* name() { return "SubtractAS"; } + + virtual void Call(const Array& x1, Scalar x2, const Array& out) = 0; +}; + +class MultiplyKernel : public Kernel { +public: + static const char* name() { return "Multiply"; } + + virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; +}; + +class MultiplyASKernel : public Kernel { +public: + static const char* name() { return "MultiplyAS"; } + + virtual void Call(const Array& x1, Scalar x2, const Array& out) = 0; +}; + +class FloorDivideKernel : public Kernel { +public: + static const char* name() { return "FloorDivide"; } + + virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; +}; + +class FloorDivideASKernel : public Kernel { +public: + static const char* name() { return "FloorDivideAS"; } + + virtual void Call(const Array& x1, Scalar x2, const Array& out) = 0; +}; + +class DivideKernel : public Kernel { +public: + static const char* name() { return "Divide"; } + + virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; +}; + +class DivideASKernel : public Kernel { +public: + static const char* name() { return "DivideAS"; } + + virtual void Call(const Array& x1, Scalar x2, const Array& out) = 0; +}; + +class ExpKernel : public Kernel { +public: + static const char* name() { return "Exp"; } + + virtual void Call(const Array& x, const Array& out) = 0; +}; + +class LogKernel : public Kernel { +public: + static const char* name() { return "Log"; } + + virtual void Call(const Array& x, const Array& out) = 0; +}; + +class SquareKernel : public Kernel { +public: + static const char* name() { return "Square"; } + + virtual void Call(const Array& x, const Array& out) = 0; +}; + +class SqrtKernel : public Kernel { +public: + static const char* name() { return "Sqrt"; } + + virtual void Call(const Array& x, const Array& out) = 0; +}; + +class SinKernel : public Kernel { +public: + static const char* name() { return "Sin"; } + + virtual void Call(const Array& x, const Array& out) = 0; +}; + +class CosKernel : public Kernel { +public: + static const char* name() { return "Cos"; } + + virtual void Call(const Array& x, const Array& out) = 0; +}; + +class TanKernel : public Kernel { +public: + static const char* name() { return "Tan"; } + + virtual void Call(const Array& x, const Array& out) = 0; +}; + +class SinhKernel : public Kernel { +public: + static const 
char* name() { return "Sinh"; } + + virtual void Call(const Array& x, const Array& out) = 0; +}; + +class CoshKernel : public Kernel { +public: + static const char* name() { return "Cosh"; } + + virtual void Call(const Array& x, const Array& out) = 0; +}; + +class TanhKernel : public Kernel { +public: + static const char* name() { return "Tanh"; } + + virtual void Call(const Array& x, const Array& out) = 0; +}; + +class ArcsinKernel : public Kernel { +public: + static const char* name() { return "Arcsin"; } + + virtual void Call(const Array& x, const Array& out) = 0; +}; + +class ArccosKernel : public Kernel { +public: + static const char* name() { return "Arccos"; } + + virtual void Call(const Array& x, const Array& out) = 0; +}; + +class ArctanKernel : public Kernel { +public: + static const char* name() { return "Arctan"; } + + virtual void Call(const Array& x, const Array& out) = 0; +}; + +class ArcsinhKernel : public Kernel { +public: + static const char* name() { return "Archsinh"; } + + virtual void Call(const Array& x, const Array& out) = 0; +}; + +class ArccoshKernel : public Kernel { +public: + static const char* name() { return "Arccosh"; } + + virtual void Call(const Array& x, const Array& out) = 0; +}; + +class CeilKernel : public Kernel { +public: + static const char* name() { return "Ceil"; } + + virtual void Call(const Array& x, const Array& out) = 0; +}; + +class FloorKernel : public Kernel { +public: + static const char* name() { return "Floor"; } + + virtual void Call(const Array& x, const Array& out) = 0; +}; + +class IsNanKernel : public Kernel { +public: + static const char* name() { return "IsNan"; } + + virtual void Call(const Array& x, const Array& out) = 0; +}; + +class IsInfKernel : public Kernel { +public: + static const char* name() { return "IsInf"; } + + virtual void Call(const Array& x, const Array& out) = 0; +}; + +// Calculate the sum of an array. +// It will be summed over the specified axes. +// `axis` must be normalized so that +// - it has only positive values, +// - it is sorted, and +// - it has no duplicated values. +// Otherwise, the behavior is undefined. +class SumKernel : public Kernel { +public: + static const char* name() { return "Sum"; } + + virtual void Call(const Array& a, const Axes& axis, const Array& out) = 0; +}; + +// Calculates the maximum along specified axes. +// See Sum() for the explanation of arguments. +class AMaxKernel : public Kernel { +public: + static const char* name() { return "AMax"; } + + virtual void Call(const Array& src, const Axes& axis, const Array& out) = 0; +}; + +// Calculates the minimum along specified axes. +// See Sum() for the explanation of arguments. +class AMinKernel : public Kernel { +public: + static const char* name() { return "AMin"; } + + virtual void Call(const Array& src, const Axes& axis, const Array& out) = 0; +}; + +// Compares x1 and x2 and assign either pos or neg according to the result. +// Formally, it calculates: out = x1 < x2 ? pos : neg +class IfLessElseASSAKernel : public Kernel { +public: + static const char* name() { return "IfLessElseASSA"; } + + virtual void Call(const Array& x1, Scalar x2, Scalar pos, const Array& neg, const Array& out) = 0; +}; + +// Compares x1 and x2 and assign either pos or neg according to the result. +// Formally, it calculates: out = x1 > x2 ? 
pos : neg +class IfGreaterElseASSAKernel : public Kernel { +public: + static const char* name() { return "IfGreaterElseASSA"; } + + virtual void Call(const Array& x1, Scalar x2, Scalar pos, const Array& neg, const Array& out) = 0; +}; + +class IfGreaterElseAAAAKernel : public Kernel { +public: + static const char* name() { return "IfGreaterElseAAAA"; } + + virtual void Call(const Array& x1, const Array& x2, const Array& pos, const Array& neg, const Array& out) = 0; +}; + +} // namespace chainerx diff --git a/chainerx_cc/chainerx/routines/misc.h b/chainerx_cc/chainerx/kernels/misc.h similarity index 66% rename from chainerx_cc/chainerx/routines/misc.h rename to chainerx_cc/chainerx/kernels/misc.h index 9ede28d672cb..8b848fbf64cf 100644 --- a/chainerx_cc/chainerx/routines/misc.h +++ b/chainerx_cc/chainerx/kernels/misc.h @@ -1,12 +1,14 @@ #pragma once +// TODO(hvy): Consider moving the content in this file to e.g. kernels/creation.h, in which case this file can be removed. + #include "chainerx/array.h" -#include "chainerx/op.h" +#include "chainerx/kernel.h" #include "chainerx/scalar.h" namespace chainerx { -class FillOp : public Op { +class FillKernel : public Kernel { public: static const char* name() { return "Fill"; } @@ -14,7 +16,7 @@ class FillOp : public Op { }; // Casts the elements from one array to the other dtype, and store into the other. -class AsTypeOp : public Op { +class AsTypeKernel : public Kernel { public: static const char* name() { return "AsType"; } diff --git a/chainerx_cc/chainerx/kernels/normalization.h b/chainerx_cc/chainerx/kernels/normalization.h new file mode 100644 index 000000000000..506c41397f89 --- /dev/null +++ b/chainerx_cc/chainerx/kernels/normalization.h @@ -0,0 +1,132 @@ +#pragma once + +#include +#include +#include + +#include + +#include "chainerx/array.h" +#include "chainerx/axes.h" +#include "chainerx/dtype.h" +#include "chainerx/kernel.h" +#include "chainerx/scalar.h" + +namespace chainerx { + +// Intermediate results from `BatchNormKernel::Call` can be stored in this construct and be reused in `BatchNormGradKernel::Call`. +// The objects to store may vary depending on backend so each backend should derive this class to define the actual set of intermediate +// results. +class BatchNormGradState { +public: + virtual ~BatchNormGradState() = default; +}; + +class BatchNormKernel : public Kernel { +public: + static const char* name() { return "BatchNorm"; } + + // The returned state should be a `nullptr` if `return_state` is `false`. + virtual std::tuple> Call( + const Array& x, + const Array& gamma, + const Array& beta, + const Array& running_mean, + const Array& running_var, + Scalar eps, + Scalar decay, + const Axes& axis, + bool return_state, + const nonstd::optional& out) = 0; +}; + +class BatchNormGradKernel : public Kernel { +public: + static const char* name() { return "BatchNormGrad"; } + + // Returns gx, ggamma, gbeta. 
+ virtual std::tuple Call( + const Array& x, + const Array& gamma, + const Array& gout, + Scalar eps, + const Axes& axis, + const std::shared_ptr& state, + const nonstd::optional& gx, + const nonstd::optional& ggamma, + const nonstd::optional& gbeta) = 0; +}; + +class GenericBatchNormGradState : public BatchNormGradState { +public: + GenericBatchNormGradState(Array x_mean, Array x_inv_std, Dtype beta_dtype) + : x_mean_{std::move(x_mean)}, x_inv_std_{std::move(x_inv_std)}, beta_dtype_{beta_dtype} {} + + const Array& x_mean() const { return x_mean_; } + const Array& x_inv_std() const { return x_inv_std_; } + Dtype beta_dtype() const { return beta_dtype_; } + +private: + Array x_mean_; + Array x_inv_std_; + Dtype beta_dtype_; +}; + +class GenericBatchNormKernel : public BatchNormKernel { +public: + std::tuple> Call( + const Array& x, + const Array& gamma, + const Array& beta, + const Array& running_mean, + const Array& running_var, + Scalar eps, + Scalar decay, + const Axes& axis, + bool return_state, + const nonstd::optional& out) override; +}; + +class GenericBatchNormGradKernel : public BatchNormGradKernel { +public: + std::tuple Call( + const Array& x, + const Array& gamma, + const Array& gout, + Scalar eps, + const Axes& axis, + const std::shared_ptr& state, + const nonstd::optional& gx, + const nonstd::optional& ggamma, + const nonstd::optional& gbeta) override; +}; + +class FixedBatchNormKernel : public Kernel { +public: + static const char* name() { return "FixedBatchNorm"; } + + virtual Array Call( + const Array& x, + const Array& gamma, + const Array& beta, + const Array& mean, + const Array& var, + Scalar eps, + const Axes& axis, + const nonstd::optional& out) = 0; +}; + +class GenericFixedBatchNormKernel : public FixedBatchNormKernel { +public: + Array Call( + const Array& x, + const Array& gamma, + const Array& beta, + const Array& mean, + const Array& var, + Scalar eps, + const Axes& axis, + const nonstd::optional& out) override; +}; + +} // namespace chainerx diff --git a/chainerx_cc/chainerx/kernels/pooling.h b/chainerx_cc/chainerx/kernels/pooling.h new file mode 100644 index 000000000000..7fdd28538e40 --- /dev/null +++ b/chainerx_cc/chainerx/kernels/pooling.h @@ -0,0 +1,101 @@ +#pragma once + +#include +#include +#include + +#include + +#include "chainerx/array.h" +#include "chainerx/constant.h" +#include "chainerx/kernel.h" +#include "chainerx/stack_vector.h" + +namespace chainerx { + +class MaxPoolGradState { +public: + virtual ~MaxPoolGradState() = default; +}; + +class MaxPoolKernel : public Kernel { +public: + static const char* name() { return "MaxPool"; } + + virtual std::tuple> Call( + const Array& x, + StackVector kernel_size, + StackVector stride, + StackVector pad, + bool cover_all, + bool return_state, + const nonstd::optional& out) = 0; +}; + +class MaxPoolGradGradState { +public: + virtual ~MaxPoolGradGradState() = default; +}; + +class MaxPoolGradKernel : public Kernel { +public: + static const char* name() { return "MaxPoolGrad"; } + + virtual std::tuple> Call( + const Array& gout, + StackVector kernel_size, + StackVector stride, + StackVector pad, + const std::shared_ptr& state, + bool return_state, + const nonstd::optional& gx) = 0; +}; + +class MaxPoolGradGradKernel : public Kernel { +public: + static const char* name() { return "MaxPoolGradGrad"; } + + virtual Array Call( + const Array& ggx, + StackVector kernel_size, + StackVector stride, + StackVector pad, + bool cover_all, + const std::shared_ptr& state, + const nonstd::optional& ggout) = 0; +}; + 
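The pooling kernels in this header follow the same forward/backward state pattern as the batch normalization kernels above: the forward kernel can return an opaque state object that the corresponding grad kernel later consumes. As a hedged illustration (hypothetical class and member names, not part of this diff), a backend would typically derive its own state type and downcast it inside its grad kernel:

// A backend-specific state derived from the opaque base class above.
class MyBackendMaxPoolGradState : public MaxPoolGradState {
public:
    explicit MyBackendMaxPoolGradState(Array indices) : indices_{std::move(indices)} {}

    const Array& indices() const { return indices_; }

private:
    Array indices_;  // e.g. argmax positions recorded during the forward pass
};

// Inside the backend's MaxPoolGradKernel implementation, the state handed back
// by the caller is downcast to recover those intermediates:
//   auto& my_state = dynamic_cast<MyBackendMaxPoolGradState&>(*state);
//   const Array& indices = my_state.indices();
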
+class AveragePoolGradState { +public: + virtual ~AveragePoolGradState() = default; +}; + +class AveragePoolKernel : public Kernel { +public: + static const char* name() { return "AveragePool"; } + + virtual std::tuple> Call( + const Array& x, + StackVector kernel_size, + StackVector stride, + StackVector pad, + AveragePoolPadMode pad_mode, + bool return_state, + const nonstd::optional& out) = 0; +}; + +class AveragePoolGradKernel : public Kernel { +public: + static const char* name() { return "AveragePoolGrad"; } + + virtual Array Call( + const Array& gout, + StackVector kernel_size, + StackVector stride, + StackVector pad, + AveragePoolPadMode pad_mode, + const std::shared_ptr& state, + const nonstd::optional& gx) = 0; +}; + +} // namespace chainerx diff --git a/chainerx_cc/chainerx/kernels/sorting.h b/chainerx_cc/chainerx/kernels/sorting.h new file mode 100644 index 000000000000..563d56e48d69 --- /dev/null +++ b/chainerx_cc/chainerx/kernels/sorting.h @@ -0,0 +1,16 @@ +#pragma once + +#include "chainerx/array.h" +#include "chainerx/axes.h" +#include "chainerx/kernel.h" + +namespace chainerx { + +class ArgMaxKernel : public Kernel { +public: + static const char* name() { return "ArgMax"; } + + virtual void Call(const Array& a, const Axes& axis, const Array& out) = 0; +}; + +} // namespace chainerx diff --git a/chainerx_cc/chainerx/native/CMakeLists.txt b/chainerx_cc/chainerx/native/CMakeLists.txt index dd5df19199d2..8f91287c88cb 100644 --- a/chainerx_cc/chainerx/native/CMakeLists.txt +++ b/chainerx_cc/chainerx/native/CMakeLists.txt @@ -3,7 +3,7 @@ install(FILES native_backend.h data_type.h elementwise.h - op_regist.h + kernel_regist.h reduce.h col2im.h im2col.h diff --git a/chainerx_cc/chainerx/native/im2col.cc b/chainerx_cc/chainerx/native/im2col.cc index f0aa472c7733..56404012d48c 100644 --- a/chainerx_cc/chainerx/native/im2col.cc +++ b/chainerx_cc/chainerx/native/im2col.cc @@ -11,6 +11,7 @@ #include "chainerx/device.h" #include "chainerx/indexable_array.h" #include "chainerx/indexer.h" +#include "chainerx/kernels/creation.h" #include "chainerx/macro.h" #include "chainerx/routines/connection.h" #include "chainerx/routines/creation.h" @@ -102,7 +103,7 @@ Array Im2Col( } Array padded_x = static_cast(pad_value) == int64_t{0} ? Zeros(padded_shape, x.dtype(), device) : Full(padded_shape, pad_value, x.dtype(), device); - device.backend().CallOp(x, padded_x.At(unpadded_slice)); + device.backend().CallKernel(x, padded_x.At(unpadded_slice)); CHAINERX_ASSERT(ndim + 2 == padded_x.ndim()); // Create the output array. diff --git a/chainerx_cc/chainerx/native/kernel_regist.h b/chainerx_cc/chainerx/native/kernel_regist.h new file mode 100644 index 000000000000..0ba320d1f0b1 --- /dev/null +++ b/chainerx_cc/chainerx/native/kernel_regist.h @@ -0,0 +1,66 @@ +#pragma once + +#include "chainerx/kernel_registry.h" +#include "chainerx/native/native_backend.h" + +// Register a kernel statically in NativeBackend. +#define CHAINERX_NATIVE_REGISTER_KERNEL(key_kernel_cls, kernel_cls) \ + static chainerx::internal::KernelRegistrar \ + s_native_backend_kernel_##kernel_cls{}; // NOLINT(cert-err58-cpp) + +#define CHAINERX_NATIVE_REGISTER_ELTWISE_DTYPE_UNARY_KERNEL(key_kernel_cls, kernel_body, visit_dtype) \ + class Native##key_kernel_cls : public key_kernel_cls { \ + public: \ + void Call(const Array& x, const Array& out) override { \ + Device& device = x.device(); \ + device.CheckDevicesCompatible(x, out); \ + const Array& x_cast = x.dtype() == out.dtype() ? 
x : x.AsType(out.dtype()); \ + visit_dtype(out.dtype(), [&](auto pt) { \ + using T = typename decltype(pt)::type; \ + struct Impl { \ + void operator()(int64_t i, T x, T& out) { \ + (void)i; \ + kernel_body \ + } \ + }; \ + Elementwise(Impl{}, x_cast, out); \ + }); \ + } \ + }; \ + \ + CHAINERX_NATIVE_REGISTER_KERNEL(key_kernel_cls, Native##key_kernel_cls); + +#define CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(key_kernel_cls, kernel_body) \ + CHAINERX_NATIVE_REGISTER_ELTWISE_DTYPE_UNARY_KERNEL(key_kernel_cls, kernel_body, VisitFloatingPointDtype) + +#define CHAINERX_NATIVE_REGISTER_ELTWISE_UNARY_KERNEL(key_kernel_cls, kernel_body) \ + CHAINERX_NATIVE_REGISTER_ELTWISE_DTYPE_UNARY_KERNEL(key_kernel_cls, kernel_body, VisitDtype) + +#define CHAINERX_NATIVE_REGISTER_ELTWISE_DTYPE_BINARY_KERNEL(key_kernel_cls, kernel_body, visit_dtype) \ + class Native##key_kernel_cls : public key_kernel_cls { \ + public: \ + void Call(const Array& x1, const Array& x2, const Array& out) override { \ + Device& device = x1.device(); \ + device.CheckDevicesCompatible(x1, x2, out); \ + const Array& x1_cast = x1.dtype() == out.dtype() ? x1 : x1.AsType(out.dtype()); \ + const Array& x2_cast = x2.dtype() == out.dtype() ? x2 : x2.AsType(out.dtype()); \ + visit_dtype(out.dtype(), [&](auto pt) { \ + using T = typename decltype(pt)::type; \ + struct Impl { \ + void operator()(int64_t i, T x1, T x2, T& out) { \ + (void)i; \ + kernel_body \ + } \ + }; \ + Elementwise(Impl{}, x1_cast, x2_cast, out); \ + }); \ + } \ + }; \ + \ + CHAINERX_NATIVE_REGISTER_KERNEL(key_kernel_cls, Native##key_kernel_cls); + +#define CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_BINARY_KERNEL(key_kernel_cls, kernel_body) \ + CHAINERX_NATIVE_REGISTER_ELTWISE_DTYPE_BINARY_KERNEL(key_kernel_cls, kernel_body, VisitFloatingPointDtype) + +#define CHAINERX_NATIVE_REGISTER_ELTWISE_BINARY_KERNEL(key_kernel_cls, kernel_body) \ + CHAINERX_NATIVE_REGISTER_ELTWISE_DTYPE_BINARY_KERNEL(key_kernel_cls, kernel_body, VisitDtype) diff --git a/chainerx_cc/chainerx/native/native_backend.h b/chainerx_cc/chainerx/native/native_backend.h index 58a43692dcb0..a82d74246ffb 100644 --- a/chainerx_cc/chainerx/native/native_backend.h +++ b/chainerx_cc/chainerx/native/native_backend.h @@ -5,7 +5,7 @@ #include "chainerx/backend.h" #include "chainerx/device.h" -#include "chainerx/op_registry.h" +#include "chainerx/kernel_registry.h" namespace chainerx { namespace native { @@ -35,13 +35,13 @@ class NativeBackend : public Backend { bool SupportsTransfer(Device& src_device, Device& dst_device) override; - static OpRegistry& GetGlobalOpRegistry() { - static OpRegistry* global_op_registry = new OpRegistry{}; - return *global_op_registry; + static KernelRegistry& GetGlobalKernelRegistry() { + static KernelRegistry* global_kernel_registry = new KernelRegistry{}; + return *global_kernel_registry; } protected: - OpRegistry& GetParentOpRegistry() override { return GetGlobalOpRegistry(); } + KernelRegistry& GetParentKernelRegistry() override { return GetGlobalKernelRegistry(); } private: std::unique_ptr CreateDevice(int index) override; diff --git a/chainerx_cc/chainerx/native/native_device.h b/chainerx_cc/chainerx/native/native_device.h index ce3c673b9e00..bfb177e901c8 100644 --- a/chainerx_cc/chainerx/native/native_device.h +++ b/chainerx_cc/chainerx/native/native_device.h @@ -13,6 +13,7 @@ #include "chainerx/dtype.h" #include "chainerx/indexable_array.h" #include "chainerx/indexer.h" +#include "chainerx/kernels/pooling.h" #include "chainerx/native/native_backend.h" #include 
"chainerx/routines/pooling.h" #include "chainerx/scalar.h" diff --git a/chainerx_cc/chainerx/native/native_device/activation.cc b/chainerx_cc/chainerx/native/native_device/activation.cc index 66792e3330b7..2dc3c88a5ed2 100644 --- a/chainerx_cc/chainerx/native/native_device/activation.cc +++ b/chainerx_cc/chainerx/native/native_device/activation.cc @@ -6,8 +6,9 @@ #include "chainerx/array.h" #include "chainerx/device.h" #include "chainerx/dtype.h" +#include "chainerx/kernels/math.h" #include "chainerx/native/elementwise.h" -#include "chainerx/native/op_regist.h" +#include "chainerx/native/kernel_regist.h" #include "chainerx/numeric.h" #include "chainerx/routines/math.h" #include "chainerx/routines/type_util.h" @@ -17,7 +18,7 @@ namespace chainerx { namespace native { namespace { -class NativeIfLessElseASSAOp : public IfLessElseASSAOp { +class NativeIfLessElseASSAKernel : public IfLessElseASSAKernel { public: void Call(const Array& x1, Scalar x2, Scalar pos, const Array& neg, const Array& out) override { x1.device().CheckDevicesCompatible(x1, neg, out); @@ -39,9 +40,9 @@ class NativeIfLessElseASSAOp : public IfLessElseASSAOp { } }; -CHAINERX_NATIVE_REGISTER_OP(IfLessElseASSAOp, NativeIfLessElseASSAOp); +CHAINERX_NATIVE_REGISTER_KERNEL(IfLessElseASSAKernel, NativeIfLessElseASSAKernel); -class NativeIfGreaterElseASSAOp : public IfGreaterElseASSAOp { +class NativeIfGreaterElseASSAKernel : public IfGreaterElseASSAKernel { public: void Call(const Array& x1, Scalar x2, Scalar pos, const Array& neg, const Array& out) override { x1.device().CheckDevicesCompatible(x1, neg, out); @@ -63,9 +64,9 @@ class NativeIfGreaterElseASSAOp : public IfGreaterElseASSAOp { } }; -CHAINERX_NATIVE_REGISTER_OP(IfGreaterElseASSAOp, NativeIfGreaterElseASSAOp); +CHAINERX_NATIVE_REGISTER_KERNEL(IfGreaterElseASSAKernel, NativeIfGreaterElseASSAKernel); -class NativeIfGreaterElseAAAAOp : public IfGreaterElseAAAAOp { +class NativeIfGreaterElseAAAAKernel : public IfGreaterElseAAAAKernel { public: void Call(const Array& x1, const Array& x2, const Array& pos, const Array& neg, const Array& out) override { x1.device().CheckDevicesCompatible(x1, x2, pos, neg, out); @@ -87,9 +88,9 @@ class NativeIfGreaterElseAAAAOp : public IfGreaterElseAAAAOp { } }; -CHAINERX_NATIVE_REGISTER_OP(IfGreaterElseAAAAOp, NativeIfGreaterElseAAAAOp); +CHAINERX_NATIVE_REGISTER_KERNEL(IfGreaterElseAAAAKernel, NativeIfGreaterElseAAAAKernel); -class NativeTanhOp : public TanhOp { +class NativeTanhKernel : public TanhKernel { public: void Call(const Array& x, const Array& out) override { x.device().CheckDevicesCompatible(x, out); @@ -104,7 +105,7 @@ class NativeTanhOp : public TanhOp { } }; -CHAINERX_NATIVE_REGISTER_OP(TanhOp, NativeTanhOp); +CHAINERX_NATIVE_REGISTER_KERNEL(TanhKernel, NativeTanhKernel); } // namespace } // namespace native diff --git a/chainerx_cc/chainerx/native/native_device/arithmetic.cc b/chainerx_cc/chainerx/native/native_device/arithmetic.cc index 5d6afb05d49e..d0e26af102f2 100644 --- a/chainerx_cc/chainerx/native/native_device/arithmetic.cc +++ b/chainerx_cc/chainerx/native/native_device/arithmetic.cc @@ -7,8 +7,9 @@ #include "chainerx/device.h" #include "chainerx/dtype.h" #include "chainerx/float16.h" +#include "chainerx/kernels/math.h" #include "chainerx/native/elementwise.h" -#include "chainerx/native/op_regist.h" +#include "chainerx/native/kernel_regist.h" #include "chainerx/routines/math.h" #include "chainerx/scalar.h" @@ -16,9 +17,9 @@ namespace chainerx { namespace native { namespace { 
-CHAINERX_NATIVE_REGISTER_ELTWISE_BINARY_OP(AddOp, { out = ArithmeticOps::Add(x1, x2); }); +CHAINERX_NATIVE_REGISTER_ELTWISE_BINARY_KERNEL(AddKernel, { out = ArithmeticOps::Add(x1, x2); }); -class NativeAddASOp : public AddASOp { +class NativeAddASKernel : public AddASKernel { public: void Call(const Array& x1, Scalar x2, const Array& out) override { Device& device = x1.device(); @@ -35,11 +36,11 @@ class NativeAddASOp : public AddASOp { } }; -CHAINERX_NATIVE_REGISTER_OP(AddASOp, NativeAddASOp); +CHAINERX_NATIVE_REGISTER_KERNEL(AddASKernel, NativeAddASKernel); -CHAINERX_NATIVE_REGISTER_ELTWISE_DTYPE_BINARY_OP(SubtractOp, { out = ArithmeticOps::Subtract(x1, x2); }, VisitNumericDtype); +CHAINERX_NATIVE_REGISTER_ELTWISE_DTYPE_BINARY_KERNEL(SubtractKernel, { out = ArithmeticOps::Subtract(x1, x2); }, VisitNumericDtype); -class NativeSubtractASOp : public SubtractASOp { +class NativeSubtractASKernel : public SubtractASKernel { public: void Call(const Array& x1, Scalar x2, const Array& out) override { Device& device = x1.device(); @@ -56,11 +57,11 @@ class NativeSubtractASOp : public SubtractASOp { } }; -CHAINERX_NATIVE_REGISTER_OP(SubtractASOp, NativeSubtractASOp); +CHAINERX_NATIVE_REGISTER_KERNEL(SubtractASKernel, NativeSubtractASKernel); -CHAINERX_NATIVE_REGISTER_ELTWISE_BINARY_OP(MultiplyOp, { out = ArithmeticOps::Multiply(x1, x2); }); +CHAINERX_NATIVE_REGISTER_ELTWISE_BINARY_KERNEL(MultiplyKernel, { out = ArithmeticOps::Multiply(x1, x2); }); -class NativeMultiplyASOp : public MultiplyASOp { +class NativeMultiplyASKernel : public MultiplyASKernel { public: void Call(const Array& x1, Scalar x2, const Array& out) override { Device& device = x1.device(); @@ -77,7 +78,7 @@ class NativeMultiplyASOp : public MultiplyASOp { } }; -CHAINERX_NATIVE_REGISTER_OP(MultiplyASOp, NativeMultiplyASOp); +CHAINERX_NATIVE_REGISTER_KERNEL(MultiplyASKernel, NativeMultiplyASKernel); int32_t FloorDivide(int32_t x, int32_t y) { auto div = std::div(x, y); @@ -102,9 +103,9 @@ chainerx::Float16 FloorDivide(chainerx::Float16 x, chainerx::Float16 y) { return chainerx::Float16{FloorDivide(static_cast(x), static_cast(y))}; } -CHAINERX_NATIVE_REGISTER_ELTWISE_DTYPE_BINARY_OP(FloorDivideOp, { out = native::FloorDivide(x1, x2); }, VisitNumericDtype); +CHAINERX_NATIVE_REGISTER_ELTWISE_DTYPE_BINARY_KERNEL(FloorDivideKernel, { out = native::FloorDivide(x1, x2); }, VisitNumericDtype); -class NativeFloorDivideASOp : public FloorDivideASOp { +class NativeFloorDivideASKernel : public FloorDivideASKernel { public: void Call(const Array& x1, Scalar x2, const Array& out) override { Device& device = x1.device(); @@ -121,11 +122,11 @@ class NativeFloorDivideASOp : public FloorDivideASOp { } }; -CHAINERX_NATIVE_REGISTER_OP(FloorDivideASOp, NativeFloorDivideASOp); +CHAINERX_NATIVE_REGISTER_KERNEL(FloorDivideASKernel, NativeFloorDivideASKernel); -CHAINERX_NATIVE_REGISTER_ELTWISE_BINARY_OP(DivideOp, { out = ArithmeticOps::Divide(x1, x2); }); +CHAINERX_NATIVE_REGISTER_ELTWISE_BINARY_KERNEL(DivideKernel, { out = ArithmeticOps::Divide(x1, x2); }); -class NativeDivideASOp : public DivideASOp { +class NativeDivideASKernel : public DivideASKernel { public: void Call(const Array& x1, Scalar x2, const Array& out) override { Device& device = x1.device(); @@ -142,7 +143,7 @@ class NativeDivideASOp : public DivideASOp { } }; -CHAINERX_NATIVE_REGISTER_OP(DivideASOp, NativeDivideASOp); +CHAINERX_NATIVE_REGISTER_KERNEL(DivideASKernel, NativeDivideASKernel); } // namespace } // namespace native diff --git 
a/chainerx_cc/chainerx/native/native_device/batch_norm.cc b/chainerx_cc/chainerx/native/native_device/batch_norm.cc index f37d53458fae..5fd005faf4f3 100644 --- a/chainerx_cc/chainerx/native/native_device/batch_norm.cc +++ b/chainerx_cc/chainerx/native/native_device/batch_norm.cc @@ -1,12 +1,12 @@ -#include "chainerx/native/op_regist.h" -#include "chainerx/routines/normalization.h" +#include "chainerx/kernels/normalization.h" +#include "chainerx/native/kernel_regist.h" namespace chainerx { namespace native { -CHAINERX_NATIVE_REGISTER_OP(BatchNormOp, GenericBatchNormOp); -CHAINERX_NATIVE_REGISTER_OP(BatchNormGradOp, GenericBatchNormGradOp); -CHAINERX_NATIVE_REGISTER_OP(FixedBatchNormOp, GenericFixedBatchNormOp); +CHAINERX_NATIVE_REGISTER_KERNEL(BatchNormKernel, GenericBatchNormKernel); +CHAINERX_NATIVE_REGISTER_KERNEL(BatchNormGradKernel, GenericBatchNormGradKernel); +CHAINERX_NATIVE_REGISTER_KERNEL(FixedBatchNormKernel, GenericFixedBatchNormKernel); } // namespace native } // namespace chainerx diff --git a/chainerx_cc/chainerx/native/native_device/comparison.cc b/chainerx_cc/chainerx/native/native_device/comparison.cc index 0cccae4fe57e..e48748dd5998 100644 --- a/chainerx_cc/chainerx/native/native_device/comparison.cc +++ b/chainerx_cc/chainerx/native/native_device/comparison.cc @@ -5,8 +5,9 @@ #include "chainerx/array.h" #include "chainerx/device.h" #include "chainerx/dtype.h" +#include "chainerx/kernels/logic.h" #include "chainerx/native/elementwise.h" -#include "chainerx/native/op_regist.h" +#include "chainerx/native/kernel_regist.h" #include "chainerx/native/reduce.h" #include "chainerx/routines/logic.h" @@ -14,7 +15,7 @@ namespace chainerx { namespace native { namespace { -class NativeEqualOp : public EqualOp { +class NativeEqualKernel : public EqualKernel { public: void Call(const Array& x1, const Array& x2, const Array& out) override { Device& device = x1.device(); @@ -32,9 +33,9 @@ class NativeEqualOp : public EqualOp { } }; -CHAINERX_NATIVE_REGISTER_OP(EqualOp, NativeEqualOp); +CHAINERX_NATIVE_REGISTER_KERNEL(EqualKernel, NativeEqualKernel); -class NativeNotEqualOp : public NotEqualOp { +class NativeNotEqualKernel : public NotEqualKernel { public: void Call(const Array& x1, const Array& x2, const Array& out) override { Device& device = x1.device(); @@ -52,9 +53,9 @@ class NativeNotEqualOp : public NotEqualOp { } }; -CHAINERX_NATIVE_REGISTER_OP(NotEqualOp, NativeNotEqualOp); +CHAINERX_NATIVE_REGISTER_KERNEL(NotEqualKernel, NativeNotEqualKernel); -class NativeGreaterOp : public GreaterOp { +class NativeGreaterKernel : public GreaterKernel { public: void Call(const Array& x1, const Array& x2, const Array& out) override { Device& device = x1.device(); @@ -72,9 +73,9 @@ class NativeGreaterOp : public GreaterOp { } }; -CHAINERX_NATIVE_REGISTER_OP(GreaterOp, NativeGreaterOp); +CHAINERX_NATIVE_REGISTER_KERNEL(GreaterKernel, NativeGreaterKernel); -class NativeGreaterEqualOp : public GreaterEqualOp { +class NativeGreaterEqualKernel : public GreaterEqualKernel { public: void Call(const Array& x1, const Array& x2, const Array& out) override { Device& device = x1.device(); @@ -92,9 +93,9 @@ class NativeGreaterEqualOp : public GreaterEqualOp { } }; -CHAINERX_NATIVE_REGISTER_OP(GreaterEqualOp, NativeGreaterEqualOp); +CHAINERX_NATIVE_REGISTER_KERNEL(GreaterEqualKernel, NativeGreaterEqualKernel); -class NativeLogicalNotOp : public LogicalNotOp { +class NativeLogicalNotKernel : public LogicalNotKernel { public: void Call(const Array& x, const Array& out) override { Device& device = x.device(); @@ 
-109,9 +110,9 @@ class NativeLogicalNotOp : public LogicalNotOp { } }; -CHAINERX_NATIVE_REGISTER_OP(LogicalNotOp, NativeLogicalNotOp); +CHAINERX_NATIVE_REGISTER_KERNEL(LogicalNotKernel, NativeLogicalNotKernel); -class NativeLogicalAndOp : public LogicalAndOp { +class NativeLogicalAndKernel : public LogicalAndKernel { public: void Call(const Array& x1, const Array& x2, const Array& out) override { Device& device = x1.device(); @@ -129,9 +130,9 @@ class NativeLogicalAndOp : public LogicalAndOp { } }; -CHAINERX_NATIVE_REGISTER_OP(LogicalAndOp, NativeLogicalAndOp); +CHAINERX_NATIVE_REGISTER_KERNEL(LogicalAndKernel, NativeLogicalAndKernel); -class NativeLogicalOrOp : public LogicalOrOp { +class NativeLogicalOrKernel : public LogicalOrKernel { public: void Call(const Array& x1, const Array& x2, const Array& out) override { Device& device = x1.device(); @@ -149,9 +150,9 @@ class NativeLogicalOrOp : public LogicalOrOp { } }; -CHAINERX_NATIVE_REGISTER_OP(LogicalOrOp, NativeLogicalOrOp); +CHAINERX_NATIVE_REGISTER_KERNEL(LogicalOrKernel, NativeLogicalOrKernel); -class NativeAllOp : public AllOp { +class NativeAllKernel : public AllKernel { public: void Call(const Array& a, const Axes& axis, const Array& out) override { CHAINERX_ASSERT(internal::IsValidReductionShape(a.shape(), axis, out.shape(), true)); @@ -171,9 +172,9 @@ class NativeAllOp : public AllOp { } }; -CHAINERX_NATIVE_REGISTER_OP(AllOp, NativeAllOp); +CHAINERX_NATIVE_REGISTER_KERNEL(AllKernel, NativeAllKernel); -class NativeAnyOp : public AnyOp { +class NativeAnyKernel : public AnyKernel { public: void Call(const Array& a, const Axes& axis, const Array& out) override { CHAINERX_ASSERT(internal::IsValidReductionShape(a.shape(), axis, out.shape(), true)); @@ -194,7 +195,7 @@ class NativeAnyOp : public AnyOp { } }; -CHAINERX_NATIVE_REGISTER_OP(AnyOp, NativeAnyOp); +CHAINERX_NATIVE_REGISTER_KERNEL(AnyKernel, NativeAnyKernel); } // namespace } // namespace native diff --git a/chainerx_cc/chainerx/native/native_device/conv.cc b/chainerx_cc/chainerx/native/native_device/conv.cc index 647eddc26459..ca6198280e37 100644 --- a/chainerx_cc/chainerx/native/native_device/conv.cc +++ b/chainerx_cc/chainerx/native/native_device/conv.cc @@ -16,13 +16,13 @@ #include "chainerx/error.h" #include "chainerx/indexable_array.h" #include "chainerx/indexer.h" +#include "chainerx/kernels/connection.h" +#include "chainerx/kernels/creation.h" #include "chainerx/macro.h" #include "chainerx/native/col2im.h" #include "chainerx/native/im2col.h" -#include "chainerx/native/op_regist.h" +#include "chainerx/native/kernel_regist.h" #include "chainerx/native/tensor_dot.h" -#include "chainerx/routines/connection.h" -#include "chainerx/routines/creation.h" #include "chainerx/routines/manipulation.h" #include "chainerx/shape.h" #include "chainerx/stack_vector.h" @@ -31,7 +31,7 @@ namespace chainerx { namespace native { namespace { -class NativeConvOp : public ConvOp { +class NativeConvKernel : public ConvKernel { public: Array Call( const Array& x, @@ -81,9 +81,9 @@ class NativeConvOp : public ConvOp { } }; -CHAINERX_NATIVE_REGISTER_OP(ConvOp, NativeConvOp); +CHAINERX_NATIVE_REGISTER_KERNEL(ConvKernel, NativeConvKernel); -class NativeConvGradWeightOp : public ConvGradWeightOp { +class NativeConvGradWeightKernel : public ConvGradWeightKernel { public: Array Call( Dtype w_dtype, @@ -120,9 +120,9 @@ class NativeConvGradWeightOp : public ConvGradWeightOp { } }; -CHAINERX_NATIVE_REGISTER_OP(ConvGradWeightOp, NativeConvGradWeightOp); 
+CHAINERX_NATIVE_REGISTER_KERNEL(ConvGradWeightKernel, NativeConvGradWeightKernel); -class NativeConvTransposeOp : public ConvTransposeOp { +class NativeConvTransposeKernel : public ConvTransposeKernel { public: Array Call( const Array& x, @@ -158,7 +158,7 @@ class NativeConvTransposeOp : public ConvTransposeOp { } }; -CHAINERX_NATIVE_REGISTER_OP(ConvTransposeOp, NativeConvTransposeOp); +CHAINERX_NATIVE_REGISTER_KERNEL(ConvTransposeKernel, NativeConvTransposeKernel); } // namespace } // namespace native diff --git a/chainerx_cc/chainerx/native/native_device/copy.cc b/chainerx_cc/chainerx/native/native_device/copy.cc index 6563398927a3..c0ca11a32728 100644 --- a/chainerx_cc/chainerx/native/native_device/copy.cc +++ b/chainerx_cc/chainerx/native/native_device/copy.cc @@ -5,18 +5,19 @@ #include "chainerx/array.h" #include "chainerx/device.h" #include "chainerx/dtype.h" +#include "chainerx/kernels/creation.h" +#include "chainerx/kernels/misc.h" #include "chainerx/native/elementwise.h" -#include "chainerx/native/op_regist.h" +#include "chainerx/native/kernel_regist.h" #include "chainerx/routines/creation.h" -#include "chainerx/routines/misc.h" namespace chainerx { namespace native { namespace { -CHAINERX_NATIVE_REGISTER_ELTWISE_UNARY_OP(CopyOp, { out = x; }); +CHAINERX_NATIVE_REGISTER_ELTWISE_UNARY_KERNEL(CopyKernel, { out = x; }); -class NativeAsTypeOp : public AsTypeOp { +class NativeAsTypeKernel : public AsTypeKernel { public: void Call(const Array& a, const Array& out) override { a.device().CheckDevicesCompatible(a, out); @@ -32,7 +33,7 @@ class NativeAsTypeOp : public AsTypeOp { } }; -CHAINERX_NATIVE_REGISTER_OP(AsTypeOp, NativeAsTypeOp); +CHAINERX_NATIVE_REGISTER_KERNEL(AsTypeKernel, NativeAsTypeKernel); } // namespace } // namespace native diff --git a/chainerx_cc/chainerx/native/native_device/dot.cc b/chainerx_cc/chainerx/native/native_device/dot.cc index cd7f51cd0b7b..8667cf8a44dc 100644 --- a/chainerx_cc/chainerx/native/native_device/dot.cc +++ b/chainerx_cc/chainerx/native/native_device/dot.cc @@ -13,12 +13,13 @@ #include "chainerx/device.h" #include "chainerx/dtype.h" #include "chainerx/indexable_array.h" +#include "chainerx/kernels/creation.h" +#include "chainerx/kernels/linalg.h" #include "chainerx/macro.h" #include "chainerx/native/data_type.h" #include "chainerx/native/elementwise.h" -#include "chainerx/native/op_regist.h" +#include "chainerx/native/kernel_regist.h" #include "chainerx/routines/creation.h" -#include "chainerx/routines/linalg.h" #include "chainerx/shape.h" namespace chainerx { @@ -116,7 +117,7 @@ void Gemm(const Array& a, const Array& b, const Array& out) { } if (!is_out_contiguous) { - out.device().backend().CallOp(out_contiguous, out); + out.device().backend().CallKernel(out_contiguous, out); } } @@ -140,7 +141,7 @@ double MultiplyAdd(double x, double y, double z) { return std::fma(x, y, z); } } // namespace -class NativeDotOp : public DotOp { +class NativeDotKernel : public DotKernel { public: void Call(const Array& a, const Array& b, const Array& out) override { Device& device = a.device(); @@ -223,7 +224,7 @@ class NativeDotOp : public DotOp { } }; -CHAINERX_NATIVE_REGISTER_OP(DotOp, NativeDotOp); +CHAINERX_NATIVE_REGISTER_KERNEL(DotKernel, NativeDotKernel); } // namespace native } // namespace chainerx diff --git a/chainerx_cc/chainerx/native/native_device/exp_log.cc b/chainerx_cc/chainerx/native/native_device/exp_log.cc index b2962a173654..028c651b1b4b 100644 --- a/chainerx_cc/chainerx/native/native_device/exp_log.cc +++ 
b/chainerx_cc/chainerx/native/native_device/exp_log.cc @@ -5,16 +5,16 @@ #include "chainerx/array.h" #include "chainerx/device.h" #include "chainerx/dtype.h" +#include "chainerx/kernels/math.h" #include "chainerx/native/elementwise.h" -#include "chainerx/native/op_regist.h" +#include "chainerx/native/kernel_regist.h" #include "chainerx/numeric.h" -#include "chainerx/routines/math.h" namespace chainerx { namespace native { namespace { -class NativeExpOp : public ExpOp { +class NativeExpKernel : public ExpKernel { public: void Call(const Array& x, const Array& out) override { x.device().CheckDevicesCompatible(x, out); @@ -29,9 +29,9 @@ class NativeExpOp : public ExpOp { } }; -CHAINERX_NATIVE_REGISTER_OP(ExpOp, NativeExpOp); +CHAINERX_NATIVE_REGISTER_KERNEL(ExpKernel, NativeExpKernel); -class NativeLogOp : public LogOp { +class NativeLogKernel : public LogKernel { public: void Call(const Array& x, const Array& out) override { x.device().CheckDevicesCompatible(x, out); @@ -46,7 +46,7 @@ class NativeLogOp : public LogOp { } }; -CHAINERX_NATIVE_REGISTER_OP(LogOp, NativeLogOp); +CHAINERX_NATIVE_REGISTER_KERNEL(LogKernel, NativeLogKernel); } // namespace } // namespace native diff --git a/chainerx_cc/chainerx/native/native_device/fill.cc b/chainerx_cc/chainerx/native/native_device/fill.cc index 077e7b9bcfb9..76fc580215dd 100644 --- a/chainerx_cc/chainerx/native/native_device/fill.cc +++ b/chainerx_cc/chainerx/native/native_device/fill.cc @@ -7,12 +7,12 @@ #include "chainerx/dtype.h" #include "chainerx/indexable_array.h" #include "chainerx/indexer.h" +#include "chainerx/kernels/creation.h" +#include "chainerx/kernels/misc.h" #include "chainerx/macro.h" #include "chainerx/native/data_type.h" #include "chainerx/native/elementwise.h" -#include "chainerx/native/op_regist.h" -#include "chainerx/routines/creation.h" -#include "chainerx/routines/misc.h" +#include "chainerx/native/kernel_regist.h" #include "chainerx/scalar.h" #include "chainerx/shape.h" @@ -20,7 +20,7 @@ namespace chainerx { namespace native { namespace { -class NativeArangeOp : public ArangeOp { +class NativeArangeKernel : public ArangeKernel { public: void Call(Scalar start, Scalar step, const Array& out) override { VisitDtype(out.dtype(), [&](auto pt) { @@ -35,9 +35,9 @@ class NativeArangeOp : public ArangeOp { } }; -CHAINERX_NATIVE_REGISTER_OP(ArangeOp, NativeArangeOp); +CHAINERX_NATIVE_REGISTER_KERNEL(ArangeKernel, NativeArangeKernel); -class NativeIdentityOp : public IdentityOp { +class NativeIdentityKernel : public IdentityKernel { public: void Call(const Array& out) override { CHAINERX_ASSERT(out.ndim() == 2); @@ -55,9 +55,9 @@ class NativeIdentityOp : public IdentityOp { } }; -CHAINERX_NATIVE_REGISTER_OP(IdentityOp, NativeIdentityOp); +CHAINERX_NATIVE_REGISTER_KERNEL(IdentityKernel, NativeIdentityKernel); -class NativeEyeOp : public EyeOp { +class NativeEyeKernel : public EyeKernel { public: void Call(int64_t k, const Array& out) override { VisitDtype(out.dtype(), [k, &out](auto pt) { @@ -74,9 +74,9 @@ class NativeEyeOp : public EyeOp { } }; -CHAINERX_NATIVE_REGISTER_OP(EyeOp, NativeEyeOp); +CHAINERX_NATIVE_REGISTER_KERNEL(EyeKernel, NativeEyeKernel); -class NativeDiagflatOp : public DiagflatOp { +class NativeDiagflatKernel : public DiagflatKernel { public: void Call(const Array& v, int64_t k, const Array& out) override { CHAINERX_ASSERT(v.ndim() == 1); @@ -115,9 +115,9 @@ class NativeDiagflatOp : public DiagflatOp { } }; -CHAINERX_NATIVE_REGISTER_OP(DiagflatOp, NativeDiagflatOp); +CHAINERX_NATIVE_REGISTER_KERNEL(DiagflatKernel, 
NativeDiagflatKernel); -class NativeLinspaceOp : public LinspaceOp { +class NativeLinspaceKernel : public LinspaceKernel { public: void Call(double start, double stop, const Array& out) override { CHAINERX_ASSERT(out.ndim() == 1); @@ -141,9 +141,9 @@ class NativeLinspaceOp : public LinspaceOp { } }; -CHAINERX_NATIVE_REGISTER_OP(LinspaceOp, NativeLinspaceOp); +CHAINERX_NATIVE_REGISTER_KERNEL(LinspaceKernel, NativeLinspaceKernel); -class NativeFillOp : public FillOp { +class NativeFillKernel : public FillKernel { public: void Call(const Array& out, Scalar value) override { VisitDtype(out.dtype(), [&](auto pt) { @@ -157,7 +157,7 @@ class NativeFillOp : public FillOp { } }; -CHAINERX_NATIVE_REGISTER_OP(FillOp, NativeFillOp); +CHAINERX_NATIVE_REGISTER_KERNEL(FillKernel, NativeFillKernel); } // namespace } // namespace native diff --git a/chainerx_cc/chainerx/native/native_device/indexing.cc b/chainerx_cc/chainerx/native/native_device/indexing.cc index e596130f3f7c..12f7aa9378f0 100644 --- a/chainerx_cc/chainerx/native/native_device/indexing.cc +++ b/chainerx_cc/chainerx/native/native_device/indexing.cc @@ -7,9 +7,10 @@ #include "chainerx/dtype.h" #include "chainerx/indexable_array.h" #include "chainerx/indexer.h" +#include "chainerx/kernels/indexing.h" #include "chainerx/macro.h" #include "chainerx/native/elementwise.h" -#include "chainerx/native/op_regist.h" +#include "chainerx/native/kernel_regist.h" #include "chainerx/routines/indexing.h" #include "chainerx/shape.h" @@ -17,7 +18,7 @@ namespace chainerx { namespace native { namespace { -class NativeTakeOp : public TakeOp { +class NativeTakeKernel : public TakeKernel { public: void Call(const Array& a, const Array& indices, int8_t axis, const Array& out) override { CHAINERX_ASSERT(GetKind(indices.dtype()) == DtypeKind::kInt || GetKind(indices.dtype()) == DtypeKind::kUInt); @@ -81,9 +82,9 @@ class NativeTakeOp : public TakeOp { } }; -CHAINERX_NATIVE_REGISTER_OP(TakeOp, NativeTakeOp); +CHAINERX_NATIVE_REGISTER_KERNEL(TakeKernel, NativeTakeKernel); -class NativeAddAtOp : public AddAtOp { +class NativeAddAtKernel : public AddAtKernel { public: void Call(const Array& a, const Array& indices, int8_t axis, const Array& b, const Array& out) override { CHAINERX_ASSERT(a.shape() == out.shape()); @@ -155,7 +156,7 @@ class NativeAddAtOp : public AddAtOp { } }; -CHAINERX_NATIVE_REGISTER_OP(AddAtOp, NativeAddAtOp); +CHAINERX_NATIVE_REGISTER_KERNEL(AddAtKernel, NativeAddAtKernel); } // namespace } // namespace native diff --git a/chainerx_cc/chainerx/native/native_device/misc.cc b/chainerx_cc/chainerx/native/native_device/misc.cc index 5a7f661904b3..5930b8a43a76 100644 --- a/chainerx_cc/chainerx/native/native_device/misc.cc +++ b/chainerx_cc/chainerx/native/native_device/misc.cc @@ -6,8 +6,9 @@ #include "chainerx/array.h" #include "chainerx/device.h" #include "chainerx/dtype.h" +#include "chainerx/kernels/math.h" #include "chainerx/native/elementwise.h" -#include "chainerx/native/op_regist.h" +#include "chainerx/native/kernel_regist.h" #include "chainerx/numeric.h" #include "chainerx/routines/math.h" @@ -15,7 +16,7 @@ namespace chainerx { namespace native { namespace { -class NativeSquareOp : public SquareOp { +class NativeSquareKernel : public SquareKernel { public: void Call(const Array& x, const Array& out) override { x.device().CheckDevicesCompatible(x, out); @@ -29,9 +30,9 @@ class NativeSquareOp : public SquareOp { } }; -CHAINERX_NATIVE_REGISTER_OP(SquareOp, NativeSquareOp); +CHAINERX_NATIVE_REGISTER_KERNEL(SquareKernel, NativeSquareKernel); -class 
NativeSqrtOp : public SqrtOp { +class NativeSqrtKernel : public SqrtKernel { public: void Call(const Array& x, const Array& out) override { x.device().CheckDevicesCompatible(x, out); @@ -46,9 +47,9 @@ class NativeSqrtOp : public SqrtOp { } }; -CHAINERX_NATIVE_REGISTER_OP(SqrtOp, NativeSqrtOp); +CHAINERX_NATIVE_REGISTER_KERNEL(SqrtKernel, NativeSqrtKernel); -class NativeIsNanOp : public IsNanOp { +class NativeIsNanKernel : public IsNanKernel { public: void Call(const Array& x, const Array& out) override { x.device().CheckDevicesCompatible(x, out); @@ -62,9 +63,9 @@ class NativeIsNanOp : public IsNanOp { } }; -CHAINERX_NATIVE_REGISTER_OP(IsNanOp, NativeIsNanOp); +CHAINERX_NATIVE_REGISTER_KERNEL(IsNanKernel, NativeIsNanKernel); -class NativeIsInfOp : public IsInfOp { +class NativeIsInfKernel : public IsInfKernel { public: void Call(const Array& x, const Array& out) override { x.device().CheckDevicesCompatible(x, out); @@ -78,9 +79,9 @@ class NativeIsInfOp : public IsInfOp { } }; -CHAINERX_NATIVE_REGISTER_OP(IsInfOp, NativeIsInfOp); +CHAINERX_NATIVE_REGISTER_KERNEL(IsInfKernel, NativeIsInfKernel); -class NativeCeilOp : public CeilOp { +class NativeCeilKernel : public CeilKernel { public: void Call(const Array& x, const Array& out) override { Device& device = x.device(); @@ -96,9 +97,9 @@ class NativeCeilOp : public CeilOp { } }; -CHAINERX_NATIVE_REGISTER_OP(CeilOp, NativeCeilOp); +CHAINERX_NATIVE_REGISTER_KERNEL(CeilKernel, NativeCeilKernel); -class NativeFloorOp : public FloorOp { +class NativeFloorKernel : public FloorKernel { public: void Call(const Array& x, const Array& out) override { Device& device = x.device(); @@ -114,7 +115,7 @@ class NativeFloorOp : public FloorOp { } }; -CHAINERX_NATIVE_REGISTER_OP(FloorOp, NativeFloorOp); +CHAINERX_NATIVE_REGISTER_KERNEL(FloorKernel, NativeFloorKernel); } // namespace } // namespace native diff --git a/chainerx_cc/chainerx/native/native_device/pool.cc b/chainerx_cc/chainerx/native/native_device/pool.cc index af72e2b6f361..a9a850b7ef7b 100644 --- a/chainerx_cc/chainerx/native/native_device/pool.cc +++ b/chainerx_cc/chainerx/native/native_device/pool.cc @@ -13,11 +13,14 @@ #include "chainerx/constant.h" #include "chainerx/dtype.h" #include "chainerx/error.h" +#include "chainerx/kernels/indexing.h" +#include "chainerx/kernels/math.h" +#include "chainerx/kernels/pooling.h" #include "chainerx/macro.h" #include "chainerx/native/col2im.h" #include "chainerx/native/elementwise.h" #include "chainerx/native/im2col.h" -#include "chainerx/native/op_regist.h" +#include "chainerx/native/kernel_regist.h" #include "chainerx/native/tensor_dot.h" #include "chainerx/numeric_limits.h" #include "chainerx/routines/connection.h" @@ -54,7 +57,7 @@ Axes GetSwapSpatialDimensionsAxes(size_t n) { return axes; } -class NativeMaxPoolOp : public MaxPoolOp { +class NativeMaxPoolKernel : public MaxPoolKernel { public: std::tuple> Call( const Array& x, @@ -86,9 +89,9 @@ class NativeMaxPoolOp : public MaxPoolOp { } }; -CHAINERX_NATIVE_REGISTER_OP(MaxPoolOp, NativeMaxPoolOp); +CHAINERX_NATIVE_REGISTER_KERNEL(MaxPoolKernel, NativeMaxPoolKernel); -class NativeMaxPoolGradOp : public MaxPoolGradOp { +class NativeMaxPoolGradKernel : public MaxPoolGradKernel { public: std::tuple> Call( const Array& gout, @@ -122,7 +125,7 @@ class NativeMaxPoolGradOp : public MaxPoolGradOp { Device& device = x.device(); Array gcol = Zeros({out_total_size * kernel_total_size}, x.dtype(), device); Array offset = Arange(0, out_total_size * kernel_total_size, kernel_total_size, indices.dtype(), device); - 
device.backend().CallOp(gcol, indices.Reshape(out_flat) + offset, 0, gout.Reshape(out_flat), gcol); + device.backend().CallKernel(gcol, indices.Reshape(out_flat) + offset, 0, gout.Reshape(out_flat), gcol); // Reshape col gradients to (batch_size, channel, out_1, out_2, ..., out_n, k_1, k_2, ..., k_n). Shape out_shape_with_kernel = gout.shape(); @@ -142,9 +145,9 @@ class NativeMaxPoolGradOp : public MaxPoolGradOp { } }; -CHAINERX_NATIVE_REGISTER_OP(MaxPoolGradOp, NativeMaxPoolGradOp); +CHAINERX_NATIVE_REGISTER_KERNEL(MaxPoolGradKernel, NativeMaxPoolGradKernel); -class NativeMaxPoolGradGradOp : public MaxPoolGradGradOp { +class NativeMaxPoolGradGradKernel : public MaxPoolGradGradKernel { public: Array Call( const Array& ggx, @@ -176,13 +179,13 @@ class NativeMaxPoolGradGradOp : public MaxPoolGradGradOp { } }; -CHAINERX_NATIVE_REGISTER_OP(MaxPoolGradGradOp, NativeMaxPoolGradGradOp); +CHAINERX_NATIVE_REGISTER_KERNEL(MaxPoolGradGradKernel, NativeMaxPoolGradGradKernel); // TODO(hvy): Use Device::Mean when implemented. void Mean(const Array& a, const Axes& axis, const Array& out) { Device& device = a.device(); - device.backend().CallOp(a, axis, out); - device.backend().CallOp(out, internal::CountItemsAlongAxes(a.shape(), axis), out); + device.backend().CallKernel(a, axis, out); + device.backend().CallKernel(out, internal::CountItemsAlongAxes(a.shape(), axis), out); } Array GetPadModeIgnorePoolingWidths( @@ -245,7 +248,7 @@ Array GetPadModeIgnorePoolingWidths( return widths; } -class NativeAveragePoolOp : public AveragePoolOp { +class NativeAveragePoolKernel : public AveragePoolKernel { public: std::tuple> Call( const Array& x, @@ -280,10 +283,10 @@ class NativeAveragePoolOp : public AveragePoolOp { break; case AveragePoolPadMode::kIgnore: { Device& device = x.device(); - device.backend().CallOp(col, kernel_axes, actual_out); + device.backend().CallKernel(col, kernel_axes, actual_out); width_ignore = GetPadModeIgnorePoolingWidths(x.shape(), kernel_size, stride, pad, x.dtype()).BroadcastTo(actual_out.shape()); - device.backend().CallOp(actual_out, *width_ignore, actual_out); + device.backend().CallKernel(actual_out, *width_ignore, actual_out); break; } default: @@ -297,9 +300,9 @@ class NativeAveragePoolOp : public AveragePoolOp { } }; -CHAINERX_NATIVE_REGISTER_OP(AveragePoolOp, NativeAveragePoolOp); +CHAINERX_NATIVE_REGISTER_KERNEL(AveragePoolKernel, NativeAveragePoolKernel); -class NativeAveragePoolGradOp : public AveragePoolGradOp { +class NativeAveragePoolGradKernel : public AveragePoolGradKernel { public: Array Call( const Array& gout, @@ -347,7 +350,7 @@ class NativeAveragePoolGradOp : public AveragePoolGradOp { } }; -CHAINERX_NATIVE_REGISTER_OP(AveragePoolGradOp, NativeAveragePoolGradOp); +CHAINERX_NATIVE_REGISTER_KERNEL(AveragePoolGradKernel, NativeAveragePoolGradKernel); } // namespace } // namespace native diff --git a/chainerx_cc/chainerx/native/native_device/reduction.cc b/chainerx_cc/chainerx/native/native_device/reduction.cc index 852610dad66f..f303ab259fac 100644 --- a/chainerx_cc/chainerx/native/native_device/reduction.cc +++ b/chainerx_cc/chainerx/native/native_device/reduction.cc @@ -6,20 +6,21 @@ #include "chainerx/axes.h" #include "chainerx/device.h" #include "chainerx/dtype.h" +#include "chainerx/kernels/math.h" +#include "chainerx/kernels/sorting.h" #include "chainerx/macro.h" -#include "chainerx/native/op_regist.h" +#include "chainerx/native/kernel_regist.h" #include "chainerx/native/reduce.h" #include "chainerx/numeric.h" #include "chainerx/numeric_limits.h" #include 
"chainerx/routines/math.h" -#include "chainerx/routines/sorting.h" #include "chainerx/shape.h" namespace chainerx { namespace native { namespace { -class NativeArgMaxOp : public ArgMaxOp { +class NativeArgMaxKernel : public ArgMaxKernel { public: void Call(const Array& a, const Axes& axis, const Array& out) override { CHAINERX_ASSERT(std::all_of(axis.begin(), axis.end(), [&a](int8_t i) { return a.shape()[i] > 0; })); @@ -48,9 +49,9 @@ class NativeArgMaxOp : public ArgMaxOp { } }; -CHAINERX_NATIVE_REGISTER_OP(ArgMaxOp, NativeArgMaxOp); +CHAINERX_NATIVE_REGISTER_KERNEL(ArgMaxKernel, NativeArgMaxKernel); -class NativeSumOp : public SumOp { +class NativeSumKernel : public SumKernel { public: void Call(const Array& a, const Axes& axis, const Array& out) override { CHAINERX_ASSERT(internal::IsValidReductionShape(a.shape(), axis, out.shape(), true)); @@ -73,9 +74,9 @@ class NativeSumOp : public SumOp { } }; -CHAINERX_NATIVE_REGISTER_OP(SumOp, NativeSumOp); +CHAINERX_NATIVE_REGISTER_KERNEL(SumKernel, NativeSumKernel); -class NativeAMaxOp : public AMaxOp { +class NativeAMaxKernel : public AMaxKernel { public: void Call(const Array& a, const Axes& axis, const Array& out) override { CHAINERX_ASSERT(internal::IsValidReductionShape(a.shape(), axis, out.shape(), true)); @@ -98,9 +99,9 @@ class NativeAMaxOp : public AMaxOp { } }; -CHAINERX_NATIVE_REGISTER_OP(AMaxOp, NativeAMaxOp); +CHAINERX_NATIVE_REGISTER_KERNEL(AMaxKernel, NativeAMaxKernel); -class NativeAMinOp : public AMinOp { +class NativeAMinKernel : public AMinKernel { public: void Call(const Array& a, const Axes& axis, const Array& out) override { CHAINERX_ASSERT(internal::IsValidReductionShape(a.shape(), axis, out.shape(), true)); @@ -123,7 +124,7 @@ class NativeAMinOp : public AMinOp { } }; -CHAINERX_NATIVE_REGISTER_OP(AMinOp, NativeAMinOp); +CHAINERX_NATIVE_REGISTER_KERNEL(AMinKernel, NativeAMinKernel); } // namespace } // namespace native diff --git a/chainerx_cc/chainerx/native/native_device/trigonometric.cc b/chainerx_cc/chainerx/native/native_device/trigonometric.cc index 52b72d062af2..5ee6ecc132a4 100644 --- a/chainerx_cc/chainerx/native/native_device/trigonometric.cc +++ b/chainerx_cc/chainerx/native/native_device/trigonometric.cc @@ -6,35 +6,35 @@ #include "chainerx/array.h" #include "chainerx/device.h" #include "chainerx/dtype.h" +#include "chainerx/kernels/math.h" #include "chainerx/native/elementwise.h" -#include "chainerx/native/op_regist.h" +#include "chainerx/native/kernel_regist.h" #include "chainerx/numeric.h" -#include "chainerx/routines/math.h" #include "chainerx/scalar.h" namespace chainerx { namespace native { namespace { -CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_OP(SinOp, { out = chainerx::Sin(x); }); +CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(SinKernel, { out = chainerx::Sin(x); }); -CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_OP(CosOp, { out = chainerx::Cos(x); }); +CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(CosKernel, { out = chainerx::Cos(x); }); -CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_OP(TanOp, { out = chainerx::Tan(x); }); +CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(TanKernel, { out = chainerx::Tan(x); }); -CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_OP(ArcsinOp, { out = chainerx::Arcsin(x); }); +CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(ArcsinKernel, { out = chainerx::Arcsin(x); }); -CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_OP(ArccosOp, { out = chainerx::Arccos(x); }); +CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(ArccosKernel, { out = chainerx::Arccos(x); 
}); -CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_OP(ArctanOp, { out = chainerx::Arctan(x); }); +CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(ArctanKernel, { out = chainerx::Arctan(x); }); -CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_OP(SinhOp, { out = chainerx::Sinh(x); }); +CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(SinhKernel, { out = chainerx::Sinh(x); }); -CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_OP(CoshOp, { out = chainerx::Cosh(x); }); +CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(CoshKernel, { out = chainerx::Cosh(x); }); -CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_OP(ArcsinhOp, { out = chainerx::Arcsinh(x); }); +CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(ArcsinhKernel, { out = chainerx::Arcsinh(x); }); -CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_OP(ArccoshOp, { out = chainerx::Arccosh(x); }); +CHAINERX_NATIVE_REGISTER_ELTWISE_FLOAT_UNARY_KERNEL(ArccoshKernel, { out = chainerx::Arccosh(x); }); } // namespace } // namespace native diff --git a/chainerx_cc/chainerx/native/tensor_dot.cc b/chainerx_cc/chainerx/native/tensor_dot.cc index 8d30714df7f7..80d05c1946fb 100644 --- a/chainerx_cc/chainerx/native/tensor_dot.cc +++ b/chainerx_cc/chainerx/native/tensor_dot.cc @@ -9,9 +9,9 @@ #include "chainerx/array.h" #include "chainerx/axes.h" #include "chainerx/device.h" +#include "chainerx/kernels/linalg.h" #include "chainerx/macro.h" #include "chainerx/routines/creation.h" -#include "chainerx/routines/linalg.h" #include "chainerx/shape.h" namespace chainerx { @@ -89,7 +89,8 @@ Array TensorDot(const Array& a, const Array& b, const Axes& a_axis, const Axes& // Compute the dot product between a and b reshaped to 2-dimensions. Shape dot_shape{a_remain_total_size, b_remain_total_size}; Array dot_out = Empty(dot_shape, out_dtype, a.device()); - a.device().backend().CallOp(a.Transpose(a_roll_axes).Reshape(a_shape), b.Transpose(b_roll_axes).Reshape(b_shape), dot_out); + a.device().backend().CallKernel( + a.Transpose(a_roll_axes).Reshape(a_shape), b.Transpose(b_roll_axes).Reshape(b_shape), dot_out); // Reshape and return the output array. Shape out_shape = a_remain_dims; diff --git a/chainerx_cc/chainerx/op_registry.h b/chainerx_cc/chainerx/op_registry.h deleted file mode 100644 index 6ee85787dfc3..000000000000 --- a/chainerx_cc/chainerx/op_registry.h +++ /dev/null @@ -1,72 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -#include "chainerx/error.h" -#include "chainerx/op.h" - -namespace chainerx { - -// Manages dynamic registration and dispatch of ops. -// This class is hierarchical: it has an optional pointer to a parent OpRegistry and falls back if an op is not found in this instance. -class OpRegistry { -public: - OpRegistry() {} - - explicit OpRegistry(OpRegistry* parent) : parent_{parent} {} - - // Registers an op. - // Registers an instance of OpType with the type_index of KeyOpType as the key. - // OpType must be a subclass of KeyOpType. - template - void RegisterOp() { - static_assert(std::is_base_of::value, "OpType must be a subclass of KeyOpType."); - std::lock_guard lock{*mutex_}; - auto pair = ops_.emplace(std::type_index{typeid(KeyOpType)}, std::make_unique()); - if (!pair.second) { - throw ChainerxError{"Duplicate op: ", KeyOpType::name()}; - } - } - - // Looks up an op. 
- template - Op& GetOp() { - std::type_index key{typeid(KeyOpType)}; - { - std::lock_guard lock{*mutex_}; - auto it = ops_.find(key); - if (it != ops_.end()) { - return *it->second; - } - } - if (parent_ != nullptr) { - return parent_->GetOp(); - } - throw ChainerxError{"Op not found: ", KeyOpType::name()}; - } - -private: - std::unique_ptr mutex_{std::make_unique()}; - - OpRegistry* parent_{}; - - std::unordered_map> ops_{}; -}; - -namespace internal { - -// A facility to register ops statically. -template -class OpRegistrar { -public: - OpRegistrar() noexcept { - OpRegistry& op_registry = BackendType::GetGlobalOpRegistry(); - op_registry.RegisterOp(); - } -}; - -} // namespace internal -} // namespace chainerx diff --git a/chainerx_cc/chainerx/op_registry_test.cc b/chainerx_cc/chainerx/op_registry_test.cc deleted file mode 100644 index 40d4ba6e2fc2..000000000000 --- a/chainerx_cc/chainerx/op_registry_test.cc +++ /dev/null @@ -1,161 +0,0 @@ -#include "chainerx/op_registry.h" - -#include -#include - -#include - -#include "chainerx/backend.h" -#include "chainerx/context.h" -#include "chainerx/error.h" -#include "chainerx/macro.h" -#include "chainerx/native/native_backend.h" -#include "chainerx/op.h" -#include "chainerx/testing/threading.h" -#include "chainerx/util.h" - -namespace chainerx { -namespace { - -TEST(OpRegistryTest, OpRegistry) { - OpRegistry op_registry{}; - - class MyOp : public Op { - public: - static const char* name() { return "myop"; } - virtual std::string Call(int a, const std::string& b) { return std::to_string(a) + b; } - }; - - op_registry.RegisterOp(); - - Op& op = op_registry.GetOp(); - - // no throw - MyOp& myop = dynamic_cast(op); - - EXPECT_EQ(myop.Call(3, " is 3"), "3 is 3"); -} - -TEST(OpRegistryTest, OpRegistryHierarchy) { - OpRegistry parent_op_registry{}; - OpRegistry op_registry1{&parent_op_registry}; - OpRegistry op_registry2{&parent_op_registry}; - - class MyOp1 : public Op { - public: - static const char* name() { return "myop1"; } - virtual std::string Call(int a, const std::string& b) { return std::to_string(a) + b; } - }; - - class MyParentOp : public Op { - public: - static const char* name() { return "myparentop"; } - virtual std::string Call(const std::string& a, float b) { return a + std::to_string(b); } - }; - - op_registry1.RegisterOp(); - parent_op_registry.RegisterOp(); - - EXPECT_THROW({ op_registry2.GetOp(); }, ChainerxError); - EXPECT_THROW({ parent_op_registry.GetOp(); }, ChainerxError); - // no throw - Op& op1p = op_registry1.GetOp(); - Op& op2p = op_registry2.GetOp(); - Op& oppp = parent_op_registry.GetOp(); - - Op& op = op_registry1.GetOp(); - EXPECT_EQ(&op1p, &op2p); - EXPECT_EQ(&op1p, &oppp); - EXPECT_NE(&op1p, &op); - - MyOp1& myop = dynamic_cast(op); - EXPECT_EQ(myop.Call(3, " is 3"), "3 is 3"); -} - -TEST(OpRegistryTest, OpRegistryWithBackend) { - // TODO(imanishi): Restore the environment variable after this test. - SetEnv("CHAINERX_PATH", CHAINERX_TEST_DIR "/backend_testdata"); - Context ctx; - Backend& backend0 = ctx.GetBackend("backend0"); - - OpRegistry& op_registry = backend0.op_registry(); - - class MyOp : public Op { - public: - static const char* name() { return "myop"; } - virtual std::string Call(int a, const std::string& b) { return std::to_string(a) + b; } - }; - - op_registry.RegisterOp(); - - Op& op = op_registry.GetOp(); - - // no throw - MyOp& myop = dynamic_cast(op); - - EXPECT_EQ(myop.Call(3, " is 3"), "3 is 3"); - - // MyOp should not be regsitered to the base backend: NativeBackend. 
- EXPECT_THROW({ native::NativeBackend::GetGlobalOpRegistry().GetOp(); }, ChainerxError); - - // MyOp should not be regsitered to another Backend0 instance. - { - Context ctx_another; - Backend& backend_another = ctx_another.GetBackend("backend0"); - EXPECT_THROW({ backend_another.op_registry().GetOp(); }, ChainerxError); - } -} - -TEST(OpRegistryTest, OpRegistryThreadSafe) { - OpRegistry parent_op_registry{}; - OpRegistry op_registry1{&parent_op_registry}; - - class MyOp1 : public Op { - public: - static const char* name() { return "myop1"; } - virtual std::string Call(int a, const std::string& b) { return std::to_string(a) + b; } - }; - - class MyParentOp : public Op { - public: - static const char* name() { return "myparentop"; } - virtual std::string Call(const std::string& a, float b) { return a + std::to_string(b); } - }; - - class MyOp2 : public Op { - public: - static const char* name() { return "myop2"; } - virtual std::string Call(int a, const std::string& b) { return std::to_string(a) + b; } - }; - - class MyParentOp2 : public Op { - public: - static const char* name() { return "myparentop2"; } - virtual std::string Call(const std::string& a, float b) { return a + std::to_string(b); } - }; - - op_registry1.RegisterOp(); - parent_op_registry.RegisterOp(); - - testing::RunThreads(4U, [&parent_op_registry, &op_registry1](size_t thread_index) { - switch (thread_index) { - case 0: - op_registry1.GetOp(); - break; - case 1: - op_registry1.GetOp(); - break; - case 2: - op_registry1.RegisterOp(); - break; - case 3: - parent_op_registry.RegisterOp(); - break; - default: - CHAINERX_NEVER_REACH(); - } - }); -} - -} // namespace -} // namespace chainerx diff --git a/chainerx_cc/chainerx/routines/CMakeLists.txt b/chainerx_cc/chainerx/routines/CMakeLists.txt index d478a1230d0f..5171009995cf 100644 --- a/chainerx_cc/chainerx/routines/CMakeLists.txt +++ b/chainerx_cc/chainerx/routines/CMakeLists.txt @@ -21,7 +21,6 @@ install(FILES logic.h manipulation.h math.h - misc.h normalization.h pooling.h routines_util.h diff --git a/chainerx_cc/chainerx/routines/connection.cc b/chainerx_cc/chainerx/routines/connection.cc index 1f8fc95e4db3..440a9f461f95 100644 --- a/chainerx_cc/chainerx/routines/connection.cc +++ b/chainerx_cc/chainerx/routines/connection.cc @@ -17,8 +17,11 @@ #include "chainerx/dims.h" #include "chainerx/error.h" #include "chainerx/graph.h" +#include "chainerx/kernel_registry.h" +#include "chainerx/kernels/connection.h" +#include "chainerx/kernels/linalg.h" +#include "chainerx/kernels/math.h" #include "chainerx/macro.h" -#include "chainerx/op_registry.h" #include "chainerx/routines/creation.h" #include "chainerx/routines/linalg.h" #include "chainerx/routines/math.h" @@ -69,7 +72,7 @@ Array ConvGradWeight( Array out{}; { NoBackpropModeScope scope{}; - out = x.device().backend().CallOp(w_dtype, w_shape, x, gy, stride, pad, cover_all, nonstd::nullopt); + out = x.device().backend().CallKernel(w_dtype, w_shape, x, gy, stride, pad, cover_all, nonstd::nullopt); CHAINERX_ASSERT(out.dtype() == w_dtype); } @@ -142,7 +145,7 @@ Array Conv( Array out{}; { NoBackpropModeScope scope{}; - out = x.device().backend().CallOp(x, w, b, stride, pad, cover_all, real_out_dtype, nonstd::nullopt); + out = x.device().backend().CallKernel(x, w, b, stride, pad, cover_all, real_out_dtype, nonstd::nullopt); } { @@ -254,7 +257,7 @@ Array ConvTranspose( Array out{}; { NoBackpropModeScope scope{}; - out = x.device().backend().CallOp(x, w, b, stride, pad, real_out_size, real_out_dtype, nonstd::nullopt); + out = 
x.device().backend().CallKernel(x, w, b, stride, pad, real_out_size, real_out_dtype, nonstd::nullopt); } { @@ -339,10 +342,10 @@ Array Linear(const Array& x, const Array& w, const nonstd::optional& b, u { NoBackpropModeScope scope{}; - x.device().backend().CallOp(x_matrix, w.Transpose(), out_matrix); + x.device().backend().CallKernel(x_matrix, w.Transpose(), out_matrix); if (has_bias) { - x.device().backend().CallOp(out_matrix, b_matrix.AsType(out_dtype, false), out_matrix); + x.device().backend().CallKernel(out_matrix, b_matrix.AsType(out_dtype, false), out_matrix); } } diff --git a/chainerx_cc/chainerx/routines/connection.h b/chainerx_cc/chainerx/routines/connection.h index 22308ee8d773..f51b9920cee0 100644 --- a/chainerx_cc/chainerx/routines/connection.h +++ b/chainerx_cc/chainerx/routines/connection.h @@ -6,7 +6,6 @@ #include "chainerx/array.h" #include "chainerx/constant.h" -#include "chainerx/op.h" #include "chainerx/stack_vector.h" namespace chainerx { @@ -21,63 +20,6 @@ int64_t GetConvTransposeOutDim(int64_t in_dim, int64_t kernel_size, int64_t stri } // namespace internal -// Computes the n-dimensional convolution. -// -// x: (batch_size, in_channels, in_1, in_2, ..., in_n) -// w: (out_channels, in_channels, k_1, k_2, ..., k_n) -// b: (out_channels) -// -// Returns an array of shape (batch_size, out_channels, out_1, out_2, ..., out_n). -class ConvOp : public Op { -public: - static const char* name() { return "Conv"; } - - virtual Array Call( - const Array& x, - const Array& w, - const nonstd::optional& b, - const StackVector& stride, - const StackVector& pad, - bool cover_all, - Dtype out_dtype, - const nonstd::optional& out) = 0; -}; - -// Computes the n-dimensional transposed convolution. -// -// x: (batch_size, in_channels, in_1, in_2, ..., in_n) -// w: (in_channels, out_channels, k_1, k_2, ..., k_n) -// b: (out_channels) -// -// Returns an array of shape (batch_size, out_channels, out_1, out_2, ..., out_n). -class ConvTransposeOp : public Op { -public: - static const char* name() { return "ConvTranspose"; } - virtual Array Call( - const Array& x, - const Array& w, - const nonstd::optional& b, - const StackVector& stride, - const StackVector& pad, - const StackVector& out_size, - Dtype out_dtype, - const nonstd::optional& out) = 0; -}; - -class ConvGradWeightOp : public Op { -public: - static const char* name() { return "ConvGradWeight"; } - virtual Array Call( - Dtype w_dtype, - const Shape& w_shape, - const Array& x, - const Array& gy, - const StackVector& stride, - const StackVector& pad, - bool cover_all, - const nonstd::optional& out) = 0; -}; - // Computes the n-dimensional convolution. 
// // x: (batch_size, in_channels, in_1, in_2, ..., in_n) diff --git a/chainerx_cc/chainerx/routines/creation.cc b/chainerx_cc/chainerx/routines/creation.cc index 38b4c501321e..0d0374e09e51 100644 --- a/chainerx_cc/chainerx/routines/creation.cc +++ b/chainerx_cc/chainerx/routines/creation.cc @@ -18,8 +18,9 @@ #include "chainerx/device.h" #include "chainerx/dtype.h" #include "chainerx/graph.h" +#include "chainerx/kernels/creation.h" +#include "chainerx/kernels/misc.h" #include "chainerx/macro.h" -#include "chainerx/routines/misc.h" #include "chainerx/routines/type_util.h" #include "chainerx/scalar.h" #include "chainerx/shape.h" @@ -127,7 +128,7 @@ Array Arange(Scalar start, Scalar stop, Scalar step, Dtype dtype, Device& device } Array out = Empty({size}, dtype, device); - device.backend().CallOp(start, step, out); + device.backend().CallKernel(start, step, out); return out; } @@ -159,7 +160,7 @@ Array Copy(const Array& a) { Array out = EmptyLike(a, a.device()); { NoBackpropModeScope scope{}; - a.device().backend().CallOp(a, out); + a.device().backend().CallKernel(a, out); } BackwardBuilder bb{"copy", a, out}; @@ -181,7 +182,7 @@ Array Identity(int64_t n, Dtype dtype, Device& device) { Array out = Empty(Shape{n, n}, dtype, device); { NoBackpropModeScope scope{}; - device.backend().CallOp(out); + device.backend().CallKernel(out); } return out; } @@ -203,7 +204,7 @@ Array Eye(int64_t n, nonstd::optional m, nonstd::optional k, n Array out = Empty({n, m.value()}, dtype.value(), device); { NoBackpropModeScope scope{}; - device.backend().CallOp(k.value(), out); + device.backend().CallKernel(k.value(), out); } return out; } @@ -218,7 +219,7 @@ Array AsContiguous(const Array& a, Dtype dtype) { Array out = Empty(a.shape(), dtype, a.device()); { NoBackpropModeScope scope{}; - a.device().backend().CallOp(a.AsGradStopped(), out); + a.device().backend().CallKernel(a.AsGradStopped(), out); } if (GetKind(dtype) == DtypeKind::kFloat) { @@ -268,7 +269,7 @@ Array Diag(const Array& v, int64_t k, Device& device) { out = Empty(Shape{n, n}, v.dtype(), device); { NoBackpropModeScope scope{}; - device.backend().CallOp(v, k, out); + device.backend().CallKernel(v, k, out); } } else if (ndim == 2) { // Return the diagonal as a 1D array. @@ -338,7 +339,7 @@ Array Linspace( } { NoBackpropModeScope scope{}; - device.backend().CallOp(start_value, stop_value, out); + device.backend().CallKernel(start_value, stop_value, out); } } return out; diff --git a/chainerx_cc/chainerx/routines/creation.h b/chainerx_cc/chainerx/routines/creation.h index 72981ccb5864..d60d949cafa9 100644 --- a/chainerx_cc/chainerx/routines/creation.h +++ b/chainerx_cc/chainerx/routines/creation.h @@ -12,7 +12,6 @@ #include "chainerx/device.h" #include "chainerx/dtype.h" #include "chainerx/graph.h" -#include "chainerx/op.h" #include "chainerx/scalar.h" #include "chainerx/shape.h" @@ -55,57 +54,6 @@ Array FromData( int64_t offset = 0, Device& device = GetDefaultDevice()); -class ArangeOp : public Op { -public: - static const char* name() { return "Arange"; } - - virtual void Call(Scalar start, Scalar step, const Array& out) = 0; -}; - -class CopyOp : public Op { -public: - static const char* name() { return "Copy"; } - - // Copies the elements from one array to the other. - // - // The arrays must match in shape and dtype and need to reside on this device. - virtual void Call(const Array& a, const Array& out) = 0; -}; - -class IdentityOp : public Op { -public: - static const char* name() { return "Identity"; } - - // Creates the identity array. 
- // out must be a square 2-dim array. - virtual void Call(const Array& out) = 0; -}; - -class EyeOp : public Op { -public: - static const char* name() { return "Eye"; } - - // Creates a 2-dimensional array with ones along the k-th diagonal and zeros elsewhere. - // out must be a square 2-dim array. - virtual void Call(int64_t k, const Array& out) = 0; -}; - -class DiagflatOp : public Op { -public: - static const char* name() { return "Diagflat"; } - - virtual void Call(const Array& v, int64_t k, const Array& out) = 0; -}; - -class LinspaceOp : public Op { -public: - static const char* name() { return "Linspace"; } - - // Creates an evenly spaced 1-d array. - // `out.ndim()` must be 1 with at least 1 elements. - virtual void Call(double start, double stop, const Array& out) = 0; -}; - Array Empty(const Shape& shape, Dtype dtype, Device& device = GetDefaultDevice()); Array Full(const Shape& shape, Scalar fill_value, Dtype dtype, Device& device = GetDefaultDevice()); Array Full(const Shape& shape, Scalar fill_value, Device& device = GetDefaultDevice()); diff --git a/chainerx_cc/chainerx/routines/indexing.cc b/chainerx_cc/chainerx/routines/indexing.cc index eb722f8d0e94..de88e9274cb3 100644 --- a/chainerx_cc/chainerx/routines/indexing.cc +++ b/chainerx_cc/chainerx/routines/indexing.cc @@ -17,6 +17,8 @@ #include "chainerx/constant.h" #include "chainerx/dtype.h" #include "chainerx/graph.h" +#include "chainerx/kernels/indexing.h" +#include "chainerx/kernels/math.h" #include "chainerx/macro.h" #include "chainerx/routines/creation.h" #include "chainerx/routines/math.h" @@ -44,7 +46,7 @@ Array AddAt(const Array& a, const std::vector& indices, const Array& { NoBackpropModeScope scope{}; - a.device().backend().CallOp(b, out_view, out_view); + a.device().backend().CallKernel(b, out_view, out_view); } { @@ -134,7 +136,7 @@ Array AddAt(const Array& a, const Array& indices, int8_t axis, const Array& b) { { NoBackpropModeScope scope{}; - a.device().backend().CallOp(a, indices, axis, b, out); + a.device().backend().CallKernel(a, indices, axis, b, out); } { @@ -170,7 +172,7 @@ Array Take(const Array& a, const Array& indices, int8_t axis) { { NoBackpropModeScope scope{}; - a.device().backend().CallOp(a, indices, axis_norm, out); + a.device().backend().CallKernel(a, indices, axis_norm, out); } BackwardBuilder bb{"take", a, out}; diff --git a/chainerx_cc/chainerx/routines/indexing.h b/chainerx_cc/chainerx/routines/indexing.h index 0a35734c7a8c..93c41a6f3aa6 100644 --- a/chainerx_cc/chainerx/routines/indexing.h +++ b/chainerx_cc/chainerx/routines/indexing.h @@ -3,30 +3,10 @@ #include #include -#include "nonstd/optional.hpp" - #include "chainerx/array.h" #include "chainerx/array_index.h" -#include "chainerx/backend.h" -#include "chainerx/device.h" -#include "chainerx/op.h" namespace chainerx { - -class AddAtOp : public Op { -public: - static const char* name() { return "AddAt"; } - - virtual void Call(const Array& a, const Array& indices, int8_t axis, const Array& b, const Array& out) = 0; -}; - -class TakeOp : public Op { -public: - static const char* name() { return "Take"; } - - virtual void Call(const Array& a, const Array& indices, int8_t axis, const Array& out) = 0; -}; - namespace internal { // Returns a view selected with the indices. 
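For reference, the dispatch that the renamed CallKernel call sites in this patch rely on can be sketched standalone: a registry keyed by std::type_index with fallback to an optional parent registry, mirroring the op_registry.h removed above under the new Kernel naming. This is a minimal, self-contained sketch for illustration only; every class and function name below is illustrative and none of it is part of the patch or the actual ChainerX headers.

// Standalone sketch of the kernel-registry dispatch pattern (illustrative names only).
#include <iostream>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <typeindex>
#include <unordered_map>
#include <utility>

class Kernel {
public:
    virtual ~Kernel() = default;
};

// Mirrors the removed OpRegistry with the Op -> Kernel naming of this patch:
// kernels are keyed by the std::type_index of their interface class, and
// lookups fall back to an optional parent registry.
class KernelRegistry {
public:
    KernelRegistry() = default;
    explicit KernelRegistry(KernelRegistry* parent) : parent_{parent} {}

    template <typename KeyKernelType, typename KernelType>
    void RegisterKernel() {
        static_assert(std::is_base_of<KeyKernelType, KernelType>::value, "KernelType must be a subclass of KeyKernelType.");
        std::lock_guard<std::mutex> lock{mutex_};
        if (!kernels_.emplace(std::type_index{typeid(KeyKernelType)}, std::make_unique<KernelType>()).second) {
            throw std::runtime_error{std::string{"Duplicate kernel: "} + KeyKernelType::name()};
        }
    }

    template <typename KeyKernelType>
    Kernel& GetKernel() {
        {
            std::lock_guard<std::mutex> lock{mutex_};
            auto it = kernels_.find(std::type_index{typeid(KeyKernelType)});
            if (it != kernels_.end()) {
                return *it->second;
            }
        }
        if (parent_ != nullptr) {
            return parent_->GetKernel<KeyKernelType>();
        }
        throw std::runtime_error{std::string{"Kernel not found: "} + KeyKernelType::name()};
    }

    // Analogue of the device.backend().CallKernel<...>(...) call sites:
    // look up by interface type, downcast, then call.
    template <typename KeyKernelType, typename... Args>
    auto CallKernel(Args&&... args) {
        return dynamic_cast<KeyKernelType&>(GetKernel<KeyKernelType>()).Call(std::forward<Args>(args)...);
    }

private:
    std::mutex mutex_;
    KernelRegistry* parent_{};
    std::unordered_map<std::type_index, std::unique_ptr<Kernel>> kernels_;
};

// Interface comparable to the kernel classes declared under chainerx/kernels/.
class AddKernel : public Kernel {
public:
    static const char* name() { return "Add"; }
    virtual int Call(int x1, int x2) = 0;
};

// Backend-specific implementation, comparable to the Native*Kernel classes above.
class NativeAddKernel : public AddKernel {
public:
    int Call(int x1, int x2) override { return x1 + x2; }
};

int main() {
    KernelRegistry parent;
    KernelRegistry backend_registry{&parent};
    parent.RegisterKernel<AddKernel, NativeAddKernel>();
    // Resolved via the fallback to the parent registry.
    std::cout << backend_registry.CallKernel<AddKernel>(2, 3) << "\n";  // prints 5
}

The parent fallback is what lets a backend-local registry resolve kernels registered on a shared global registry while keeping its own registrations private, which is the behavior exercised by the removed op_registry_test.cc.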
diff --git a/chainerx_cc/chainerx/routines/linalg.cc b/chainerx_cc/chainerx/routines/linalg.cc index c62c50b27473..6c807f83f434 100644 --- a/chainerx_cc/chainerx/routines/linalg.cc +++ b/chainerx_cc/chainerx/routines/linalg.cc @@ -17,6 +17,7 @@ #include "chainerx/dtype.h" #include "chainerx/error.h" #include "chainerx/graph.h" +#include "chainerx/kernels/linalg.h" #include "chainerx/routines/creation.h" #include "chainerx/routines/type_util.h" #include "chainerx/shape.h" @@ -72,7 +73,7 @@ Array Dot(const Array& a, const Array& b, nonstd::optional out_dtype) { Array out_matrix = Empty({m, n}, real_out_dtype, a.device()); { NoBackpropModeScope scope{}; - a.device().backend().CallOp(a_matrix, b_matrix, out_matrix); + a.device().backend().CallKernel(a_matrix, b_matrix, out_matrix); } { diff --git a/chainerx_cc/chainerx/routines/linalg.h b/chainerx_cc/chainerx/routines/linalg.h index 99f6d2844785..479e7462fd99 100644 --- a/chainerx_cc/chainerx/routines/linalg.h +++ b/chainerx_cc/chainerx/routines/linalg.h @@ -4,21 +4,9 @@ #include "chainerx/array.h" #include "chainerx/dtype.h" -#include "chainerx/op.h" namespace chainerx { -// Matrix multiplication. All the operands are matrices (i.e., two-dimensional arrays). -// Let the shapes of `a` and `b` be `(M, K)` and `(L, N)`, respectively. -// Then, it must hold that `K == L` and the shape of `out` must be `(M, N)`. -// Otherwise, the behavior is undefined. -class DotOp : public Op { -public: - static const char* name() { return "Dot"; } - - virtual void Call(const Array& a, const Array& b, const Array& out) = 0; -}; - Array Dot(const Array& a, const Array& b, nonstd::optional out_dtype = nonstd::nullopt); } // namespace chainerx diff --git a/chainerx_cc/chainerx/routines/logic.cc b/chainerx_cc/chainerx/routines/logic.cc index b6c3647b187f..b5b021847846 100644 --- a/chainerx_cc/chainerx/routines/logic.cc +++ b/chainerx_cc/chainerx/routines/logic.cc @@ -4,6 +4,7 @@ #include "chainerx/backprop_mode.h" #include "chainerx/device.h" #include "chainerx/dtype.h" +#include "chainerx/kernels/logic.h" #include "chainerx/routines/creation.h" #include "chainerx/routines/manipulation.h" #include "chainerx/routines/type_util.h" @@ -47,25 +48,27 @@ void CheckLogicDtypes(const Array& x1, const Array& x2) { Array Equal(const Array& x1, const Array& x2) { CheckLogicDtypes(x1, x2); - auto func = [](const Array& x1, const Array& x2, Array& out) { x1.device().backend().CallOp(x1, x2, out); }; + auto func = [](const Array& x1, const Array& x2, Array& out) { x1.device().backend().CallKernel(x1, x2, out); }; return BroadcastComparison(func, x1, x2); } Array NotEqual(const Array& x1, const Array& x2) { CheckLogicDtypes(x1, x2); - auto func = [](const Array& x1, const Array& x2, Array& out) { x1.device().backend().CallOp(x1, x2, out); }; + auto func = [](const Array& x1, const Array& x2, Array& out) { x1.device().backend().CallKernel(x1, x2, out); }; return BroadcastComparison(func, x1, x2); } Array Greater(const Array& x1, const Array& x2) { CheckLogicDtypes(x1, x2); - auto func = [](const Array& x1, const Array& x2, Array& out) { x1.device().backend().CallOp(x1, x2, out); }; + auto func = [](const Array& x1, const Array& x2, Array& out) { x1.device().backend().CallKernel(x1, x2, out); }; return BroadcastComparison(func, x1, x2); } Array GreaterEqual(const Array& x1, const Array& x2) { CheckLogicDtypes(x1, x2); - auto func = [](const Array& x1, const Array& x2, Array& out) { return x1.device().backend().CallOp(x1, x2, out); }; + auto func = [](const Array& x1, const Array& 
x2, Array& out) { + return x1.device().backend().CallKernel(x1, x2, out); + }; return BroadcastComparison(func, x1, x2); } @@ -73,20 +76,22 @@ Array LogicalNot(const Array& x) { Array out = Empty(x.shape(), Dtype::kBool, x.device()); { NoBackpropModeScope scope{}; - x.device().backend().CallOp(x, out); + x.device().backend().CallKernel(x, out); } return out; } Array LogicalAnd(const Array& x1, const Array& x2) { CheckLogicDtypes(x1, x2); - auto func = [](const Array& x1, const Array& x2, Array& out) { return x1.device().backend().CallOp(x1, x2, out); }; + auto func = [](const Array& x1, const Array& x2, Array& out) { + return x1.device().backend().CallKernel(x1, x2, out); + }; return BroadcastComparison(func, x1, x2); } Array LogicalOr(const Array& x1, const Array& x2) { CheckLogicDtypes(x1, x2); - auto func = [](const Array& x1, const Array& x2, Array& out) { return x1.device().backend().CallOp(x1, x2, out); }; + auto func = [](const Array& x1, const Array& x2, Array& out) { return x1.device().backend().CallKernel(x1, x2, out); }; return BroadcastComparison(func, x1, x2); } @@ -95,7 +100,7 @@ Array All(const Array& a, const OptionalAxes& axis, bool keepdims) { Array out = internal::EmptyReduced(a.shape(), Dtype::kBool, sorted_axis, keepdims, a.device()); { NoBackpropModeScope scope{}; - a.device().backend().CallOp(a, sorted_axis, out); + a.device().backend().CallKernel(a, sorted_axis, out); } return out; } @@ -105,7 +110,7 @@ Array Any(const Array& a, const OptionalAxes& axis, bool keepdims) { Array out = internal::EmptyReduced(a.shape(), Dtype::kBool, sorted_axis, keepdims, a.device()); { NoBackpropModeScope scope{}; - a.device().backend().CallOp(a, sorted_axis, out); + a.device().backend().CallKernel(a, sorted_axis, out); } return out; } diff --git a/chainerx_cc/chainerx/routines/logic.h b/chainerx_cc/chainerx/routines/logic.h index 65a354981b78..961a047ce77f 100644 --- a/chainerx_cc/chainerx/routines/logic.h +++ b/chainerx_cc/chainerx/routines/logic.h @@ -1,75 +1,12 @@ #pragma once +#include + #include "chainerx/array.h" -#include "chainerx/backend.h" -#include "chainerx/device.h" -#include "chainerx/op.h" +#include "chainerx/axes.h" namespace chainerx { -class EqualOp : public Op { -public: - static const char* name() { return "Equal"; } - - virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; -}; - -class NotEqualOp : public Op { -public: - static const char* name() { return "NotEqual"; } - - virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; -}; - -class GreaterOp : public Op { -public: - static const char* name() { return "Greater"; } - - virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; -}; - -class GreaterEqualOp : public Op { -public: - static const char* name() { return "GreaterEqual"; } - - virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; -}; - -class LogicalNotOp : public Op { -public: - static const char* name() { return "LogicalNot"; } - - virtual void Call(const Array& x, const Array& out) = 0; -}; - -class LogicalAndOp : public Op { -public: - static const char* name() { return "LogicalAnd"; } - - virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; -}; - -class LogicalOrOp : public Op { -public: - static const char* name() { return "LogicalOr"; } - - virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; -}; - -class AllOp : public Op { -public: - static const char* name() { return "All"; } - - virtual void Call(const Array& a, const 
Axes& axis, const Array& out) = 0; -}; - -class AnyOp : public Op { -public: - static const char* name() { return "Any"; } - - virtual void Call(const Array& a, const Axes& axis, const Array& out) = 0; -}; - // Returns an elementwise equality array. // // Dtype casting is not supported: if x1 and x2 have different types, DtypeError is thrown. diff --git a/chainerx_cc/chainerx/routines/manipulation.cc b/chainerx_cc/chainerx/routines/manipulation.cc index f4c33f68e444..c85cd89a9953 100644 --- a/chainerx_cc/chainerx/routines/manipulation.cc +++ b/chainerx_cc/chainerx/routines/manipulation.cc @@ -22,9 +22,10 @@ #include "chainerx/dtype.h" #include "chainerx/error.h" #include "chainerx/graph.h" +#include "chainerx/kernels/creation.h" +#include "chainerx/kernels/misc.h" #include "chainerx/macro.h" #include "chainerx/routines/creation.h" -#include "chainerx/routines/misc.h" #include "chainerx/routines/type_util.h" #include "chainerx/shape.h" #include "chainerx/strides.h" @@ -445,7 +446,7 @@ Array ConcatenateImpl(const std::vector& arrays, int8_t axis) { Array sliced_out = internal::MakeArray(shape, strides, out_dtype, device, out.data(), out_offset); Dtype in_dtype = array.dtype(); in_dtypes.emplace_back(in_dtype); - device.backend().CallOp(array, sliced_out); + device.backend().CallKernel(array, sliced_out); array_refs.emplace_back(ConstArrayRef{array}); out_offset += strides[axis] * shape[axis]; } @@ -570,7 +571,7 @@ Array Stack(const std::vector& arrays, int8_t axis) { int64_t out_offset = 0; for (const Array& array : arrays) { Array sliced_out = internal::MakeArray(array.shape(), strides, dtype, device, out.data(), out_offset); - device.backend().CallOp(array, sliced_out); + device.backend().CallKernel(array, sliced_out); out_offset += step; } } diff --git a/chainerx_cc/chainerx/routines/math.cc b/chainerx_cc/chainerx/routines/math.cc index 6e3f65bfbfe1..7a6f52479732 100644 --- a/chainerx_cc/chainerx/routines/math.cc +++ b/chainerx_cc/chainerx/routines/math.cc @@ -16,6 +16,7 @@ #include "chainerx/enum.h" #include "chainerx/error.h" #include "chainerx/graph.h" +#include "chainerx/kernels/math.h" #include "chainerx/macro.h" #include "chainerx/routines/creation.h" #include "chainerx/routines/logic.h" @@ -123,7 +124,7 @@ void AddImpl(const Array& x1, const Array& x2, const Array& out) { { NoBackpropModeScope scope{}; - x1.device().backend().CallOp(x1, x2, out); + x1.device().backend().CallKernel(x1, x2, out); } { @@ -147,7 +148,7 @@ void AddImpl(const Array& x1, const Array& x2, const Array& out) { void AddASImpl(const Array& x1, Scalar x2, const Array& out) { { NoBackpropModeScope scope{}; - x1.device().backend().CallOp(x1, x2, out); + x1.device().backend().CallKernel(x1, x2, out); } BackwardBuilder bb{"add_scalar", x1, out}; @@ -182,7 +183,7 @@ void SubtractImpl(const Array& x1, const Array& x2, const Array& out) { { NoBackpropModeScope scope{}; - x1.device().backend().CallOp(x1, x2, out); + x1.device().backend().CallKernel(x1, x2, out); } { @@ -206,7 +207,7 @@ void SubtractImpl(const Array& x1, const Array& x2, const Array& out) { void SubtractASImpl(const Array& x1, Scalar x2, const Array& out) { { NoBackpropModeScope scope{}; - x1.device().backend().CallOp(x1, x2, out); + x1.device().backend().CallKernel(x1, x2, out); } BackwardBuilder bb{"subtract_scalar", x1, out}; @@ -249,7 +250,7 @@ void MultiplyImpl(const Array& x1, const Array& x2, const Array& out) { { NoBackpropModeScope scope{}; - x1.device().backend().CallOp(x1, x2, out); + x1.device().backend().CallKernel(x1, x2, out); } { @@ 
-275,7 +276,7 @@ void MultiplyImpl(const Array& x1, const Array& x2, const Array& out) { void MultiplyASImpl(const Array& x1, Scalar x2, const Array& out) { { NoBackpropModeScope scope{}; - x1.device().backend().CallOp(x1, x2, out); + x1.device().backend().CallKernel(x1, x2, out); } BackwardBuilder bb{"multiply_scalar", x1, out}; @@ -309,12 +310,12 @@ void FloorDivideImpl(const Array& x1, const Array& x2, const Array& out) { CheckEqual(x1.shape(), x2.shape()); NoBackpropModeScope scope{}; - x1.device().backend().CallOp(x1, x2, out); + x1.device().backend().CallKernel(x1, x2, out); } void FloorDivideASImpl(const Array& x1, Scalar x2, const Array& out) { NoBackpropModeScope scope{}; - x1.device().backend().CallOp(x1, x2, out); + x1.device().backend().CallKernel(x1, x2, out); } namespace internal { @@ -342,7 +343,7 @@ void DivideImpl(const Array& x1, const Array& x2, const Array& out) { { NoBackpropModeScope scope{}; - x1.device().backend().CallOp(x1, x2, out); + x1.device().backend().CallKernel(x1, x2, out); } { @@ -369,7 +370,7 @@ void DivideImpl(const Array& x1, const Array& x2, const Array& out) { void DivideASImpl(const Array& x1, Scalar x2, const Array& out) { { NoBackpropModeScope scope{}; - x1.device().backend().CallOp(x1, x2, out); + x1.device().backend().CallKernel(x1, x2, out); } BackwardBuilder bb{"divide_scalar", x1, out}; @@ -452,7 +453,7 @@ Array Sum(const Array& a, const OptionalAxes& axis, bool keepdims) { Array out = internal::EmptyReduced(a.shape(), out_dtype, sorted_axis, keepdims, a.device()); { NoBackpropModeScope scope{}; - a.device().backend().CallOp(a, sorted_axis, out); + a.device().backend().CallKernel(a, sorted_axis, out); } BackwardBuilder bb{"sum", a, out}; @@ -488,7 +489,7 @@ Array AMax(const Array& a, const OptionalAxes& axis, bool keepdims) { { NoBackpropModeScope scope{}; - a.device().backend().CallOp(a, sorted_axis, out); + a.device().backend().CallKernel(a, sorted_axis, out); } BackwardBuilder bb{"amax", a, out}; @@ -533,7 +534,7 @@ Array AMin(const Array& a, const OptionalAxes& axis, bool keepdims) { { NoBackpropModeScope scope{}; - a.device().backend().CallOp(a, sorted_axis, out); + a.device().backend().CallKernel(a, sorted_axis, out); } BackwardBuilder bb{"amin", a, out}; @@ -577,7 +578,7 @@ Array IfLessElse(const Array& x1, Scalar x2, Scalar pos, const Array& neg) { { NoBackpropModeScope scope{}; - x1.device().backend().CallOp(x1, x2, pos, neg, out); + x1.device().backend().CallKernel(x1, x2, pos, neg, out); } BackwardBuilder bb{"if_less_else", neg, out}; @@ -612,7 +613,7 @@ Array IfGreaterElse(const Array& x1, Scalar x2, Scalar pos, const Array& neg) { { NoBackpropModeScope scope{}; - x1.device().backend().CallOp(x1, x2, pos, neg, out); + x1.device().backend().CallKernel(x1, x2, pos, neg, out); } BackwardBuilder bb{"if_greater_else", neg, out}; @@ -635,7 +636,7 @@ void IfGreaterElseImpl(const Array& x1, const Array& x2, const Array& pos, const CheckEqual(x1.shape(), x2.shape()); { NoBackpropModeScope scope{}; - x1.device().backend().CallOp(x1, x2, pos, neg, out); + x1.device().backend().CallKernel(x1, x2, pos, neg, out); } { BackwardBuilder bb{"if_greater_else", {pos, neg}, out}; @@ -703,7 +704,7 @@ Array Exp(const Array& x) { { NoBackpropModeScope scope{}; - x.device().backend().CallOp(x, out); + x.device().backend().CallKernel(x, out); } BackwardBuilder bb{"exp", x, out}; @@ -724,7 +725,7 @@ Array Log(const Array& x) { { NoBackpropModeScope scope{}; - x.device().backend().CallOp(x, out); + x.device().backend().CallKernel(x, out); } BackwardBuilder 
bb{"log", x, out}; @@ -771,7 +772,7 @@ Array Square(const Array& x) { { NoBackpropModeScope scope{}; - x.device().backend().CallOp(x, out); + x.device().backend().CallKernel(x, out); } BackwardBuilder bb{"square", x, out}; @@ -794,7 +795,7 @@ Array Sqrt(const Array& x) { { NoBackpropModeScope scope{}; - x.device().backend().CallOp(x, out); + x.device().backend().CallKernel(x, out); } BackwardBuilder bb{"sqrt", x, out}; @@ -816,7 +817,7 @@ Array Tanh(const Array& x) { { NoBackpropModeScope scope{}; - x.device().backend().CallOp(x, out); + x.device().backend().CallKernel(x, out); } BackwardBuilder bb{"tanh", x, out}; @@ -838,7 +839,7 @@ Array Sin(const Array& x) { { NoBackpropModeScope scope{}; - x.device().backend().CallOp(x, out); + x.device().backend().CallKernel(x, out); } BackwardBuilder bb{"sin", x, out}; @@ -860,7 +861,7 @@ Array Cos(const Array& x) { { NoBackpropModeScope scope{}; - x.device().backend().CallOp(x, out); + x.device().backend().CallKernel(x, out); } BackwardBuilder bb{"cos", x, out}; @@ -882,7 +883,7 @@ Array Tan(const Array& x) { { NoBackpropModeScope scope{}; - x.device().backend().CallOp(x, out); + x.device().backend().CallKernel(x, out); } BackwardBuilder bb{"tan", x, out}; @@ -905,7 +906,7 @@ Array Arcsin(const Array& x) { { NoBackpropModeScope scope{}; - x.device().backend().CallOp(x, out); + x.device().backend().CallKernel(x, out); } BackwardBuilder bb{"arcsin", x, out}; @@ -927,7 +928,7 @@ Array Arccos(const Array& x) { { NoBackpropModeScope scope{}; - x.device().backend().CallOp(x, out); + x.device().backend().CallKernel(x, out); } BackwardBuilder bb{"arccos", x, out}; @@ -949,7 +950,7 @@ Array Arctan(const Array& x) { { NoBackpropModeScope scope{}; - x.device().backend().CallOp(x, out); + x.device().backend().CallKernel(x, out); } BackwardBuilder bb{"arctan", x, out}; @@ -971,7 +972,7 @@ Array Sinh(const Array& x) { { NoBackpropModeScope scope{}; - x.device().backend().CallOp(x, out); + x.device().backend().CallKernel(x, out); } BackwardBuilder bb{"sinh", x, out}; @@ -993,7 +994,7 @@ Array Cosh(const Array& x) { { NoBackpropModeScope scope{}; - x.device().backend().CallOp(x, out); + x.device().backend().CallKernel(x, out); } BackwardBuilder bb{"cosh", x, out}; @@ -1015,7 +1016,7 @@ Array Arcsinh(const Array& x) { { NoBackpropModeScope scope{}; - x.device().backend().CallOp(x, out); + x.device().backend().CallKernel(x, out); } BackwardBuilder bb{"arcsinh", x, out}; @@ -1037,7 +1038,7 @@ Array Arccosh(const Array& x) { { NoBackpropModeScope scope{}; - x.device().backend().CallOp(x, out); + x.device().backend().CallKernel(x, out); } BackwardBuilder bb{"arccosh", x, out}; @@ -1058,7 +1059,7 @@ Array Ceil(const Array& x) { Array out = Empty(x.shape(), dtype, x.device()); { NoBackpropModeScope scope{}; - x.device().backend().CallOp(x, out); + x.device().backend().CallKernel(x, out); } return out; } @@ -1068,7 +1069,7 @@ Array Floor(const Array& x) { Array out = Empty(x.shape(), dtype, x.device()); { NoBackpropModeScope scope{}; - x.device().backend().CallOp(x, out); + x.device().backend().CallKernel(x, out); } return out; } @@ -1077,7 +1078,7 @@ Array IsNan(const Array& x) { Array out = Empty(x.shape(), Dtype::kBool, x.device()); { NoBackpropModeScope scope{}; - x.device().backend().CallOp(x, out); + x.device().backend().CallKernel(x, out); } return out; } @@ -1086,7 +1087,7 @@ Array IsInf(const Array& x) { Array out = Empty(x.shape(), Dtype::kBool, x.device()); { NoBackpropModeScope scope{}; - x.device().backend().CallOp(x, out); + 
x.device().backend().CallKernel(x, out); } return out; } diff --git a/chainerx_cc/chainerx/routines/math.h b/chainerx_cc/chainerx/routines/math.h index b30d0118b6a3..82ac3a9cf491 100644 --- a/chainerx_cc/chainerx/routines/math.h +++ b/chainerx_cc/chainerx/routines/math.h @@ -10,266 +10,6 @@ namespace chainerx { -class AddOp : public Op { -public: - static const char* name() { return "Add"; } - - virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; -}; - -class AddASOp : public Op { -public: - static const char* name() { return "AddAS"; } - - virtual void Call(const Array& x1, Scalar x2, const Array& out) = 0; -}; - -class SubtractOp : public Op { -public: - static const char* name() { return "Subtract"; } - - virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; -}; - -class SubtractASOp : public Op { -public: - static const char* name() { return "SubtractAS"; } - - virtual void Call(const Array& x1, Scalar x2, const Array& out) = 0; -}; - -class MultiplyOp : public Op { -public: - static const char* name() { return "Multiply"; } - - virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; -}; - -class MultiplyASOp : public Op { -public: - static const char* name() { return "MultiplyAS"; } - - virtual void Call(const Array& x1, Scalar x2, const Array& out) = 0; -}; - -class FloorDivideOp : public Op { -public: - static const char* name() { return "FloorDivide"; } - - virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; -}; - -class FloorDivideASOp : public Op { -public: - static const char* name() { return "FloorDivideAS"; } - - virtual void Call(const Array& x1, Scalar x2, const Array& out) = 0; -}; - -class DivideOp : public Op { -public: - static const char* name() { return "Divide"; } - - virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0; -}; - -class DivideASOp : public Op { -public: - static const char* name() { return "DivideAS"; } - - virtual void Call(const Array& x1, Scalar x2, const Array& out) = 0; -}; - -// Calculate the sum of an array. -// It will be summed over the specified axes. -// `axis` must be normalized so that -// - it has only positive values, -// - it is sorted, and -// - it has no duplicated values. -// Otherwise, the behavior is undefined. -class SumOp : public Op { -public: - static const char* name() { return "Sum"; } - - virtual void Call(const Array& a, const Axes& axis, const Array& out) = 0; -}; - -// Calculates the maximum along specified axes. -// See Sum() for the explanation of arguments. -class AMaxOp : public Op { -public: - static const char* name() { return "AMax"; } - - virtual void Call(const Array& src, const Axes& axis, const Array& out) = 0; -}; - -// Calculates the minimum along specified axes. -// See Sum() for the explanation of arguments. -class AMinOp : public Op { -public: - static const char* name() { return "AMin"; } - - virtual void Call(const Array& src, const Axes& axis, const Array& out) = 0; -}; - -// Compares x1 and x2 and assign either pos or neg according to the result. -// Formally, it calculates: out = x1 < x2 ? pos : neg -class IfLessElseASSAOp : public Op { -public: - static const char* name() { return "IfLessElseASSA"; } - - virtual void Call(const Array& x1, Scalar x2, Scalar pos, const Array& neg, const Array& out) = 0; -}; - -// Compares x1 and x2 and assign either pos or neg according to the result. -// Formally, it calculates: out = x1 > x2 ? 
pos : neg
-class IfGreaterElseASSAOp : public Op {
-public:
-    static const char* name() { return "IfGreaterElseASSA"; }
-
-    virtual void Call(const Array& x1, Scalar x2, Scalar pos, const Array& neg, const Array& out) = 0;
-};
-
-class IfGreaterElseAAAAOp : public Op {
-public:
-    static const char* name() { return "IfGreaterElseAAAA"; }
-
-    virtual void Call(const Array& x1, const Array& x2, const Array& pos, const Array& neg, const Array& out) = 0;
-};
-
-class SinhOp : public Op {
-public:
-    static const char* name() { return "Sinh"; }
-
-    virtual void Call(const Array& x, const Array& out) = 0;
-};
-
-class CoshOp : public Op {
-public:
-    static const char* name() { return "Cosh"; }
-
-    virtual void Call(const Array& x, const Array& out) = 0;
-};
-
-class TanhOp : public Op {
-public:
-    static const char* name() { return "Tanh"; }
-
-    virtual void Call(const Array& x, const Array& out) = 0;
-};
-
-class ArcsinhOp : public Op {
-public:
-    static const char* name() { return "Archsinh"; }
-
-    virtual void Call(const Array& x, const Array& out) = 0;
-};
-
-class ArccoshOp : public Op {
-public:
-    static const char* name() { return "Arccosh"; }
-
-    virtual void Call(const Array& x, const Array& out) = 0;
-};
-
-class ExpOp : public Op {
-public:
-    static const char* name() { return "Exp"; }
-
-    virtual void Call(const Array& x, const Array& out) = 0;
-};
-
-class LogOp : public Op {
-public:
-    static const char* name() { return "Log"; }
-
-    virtual void Call(const Array& x, const Array& out) = 0;
-};
-
-class SquareOp : public Op {
-public:
-    static const char* name() { return "Square"; }
-
-    virtual void Call(const Array& x, const Array& out) = 0;
-};
-
-class SqrtOp : public Op {
-public:
-    static const char* name() { return "Sqrt"; }
-
-    virtual void Call(const Array& x, const Array& out) = 0;
-};
-
-class IsNanOp : public Op {
-public:
-    static const char* name() { return "IsNan"; }
-
-    virtual void Call(const Array& x, const Array& out) = 0;
-};
-
-class IsInfOp : public Op {
-public:
-    static const char* name() { return "IsInf"; }
-
-    virtual void Call(const Array& x, const Array& out) = 0;
-};
-
-class SinOp : public Op {
-public:
-    static const char* name() { return "Sin"; }
-
-    virtual void Call(const Array& x, const Array& out) = 0;
-};
-
-class CosOp : public Op {
-public:
-    static const char* name() { return "Cos"; }
-
-    virtual void Call(const Array& x, const Array& out) = 0;
-};
-
-class TanOp : public Op {
-public:
-    static const char* name() { return "Tan"; }
-
-    virtual void Call(const Array& x, const Array& out) = 0;
-};
-
-class ArcsinOp : public Op {
-public:
-    static const char* name() { return "Arcsin"; }
-
-    virtual void Call(const Array& x, const Array& out) = 0;
-};
-
-class ArccosOp : public Op {
-public:
-    static const char* name() { return "Arccos"; }
-
-    virtual void Call(const Array& x, const Array& out) = 0;
-};
-
-class ArctanOp : public Op {
-public:
-    static const char* name() { return "Arctan"; }
-
-    virtual void Call(const Array& x, const Array& out) = 0;
-};
-
-class CeilOp : public Op {
-public:
-    static const char* name() { return "Ceil"; }
-
-    virtual void Call(const Array& x, const Array& out) = 0;
-};
-
-class FloorOp : public Op {
-public:
-    static const char* name() { return "Floor"; }
-
-    virtual void Call(const Array& x, const Array& out) = 0;
-};
-
 Array Negative(const Array& x);
 
 namespace internal {
diff --git a/chainerx_cc/chainerx/routines/normalization.cc b/chainerx_cc/chainerx/routines/normalization.cc
index 8e804541a1d3..4521cbde2427 100644
--- a/chainerx_cc/chainerx/routines/normalization.cc
+++ b/chainerx_cc/chainerx/routines/normalization.cc
@@ -17,6 +17,7 @@
 #include "chainerx/dtype.h"
 #include "chainerx/error.h"
 #include "chainerx/graph.h"
+#include "chainerx/kernels/normalization.h"
 #include "chainerx/macro.h"
 #include "chainerx/routines/creation.h"
 #include "chainerx/routines/math.h"
@@ -152,7 +153,7 @@ std::tuple<Array, std::unique_ptr<BatchNormGradState>> ApplyGenericBatchNorm(
 
 }  // namespace
 
-std::tuple<Array, std::unique_ptr<BatchNormGradState>> GenericBatchNormOp::Call(
+std::tuple<Array, std::unique_ptr<BatchNormGradState>> GenericBatchNormKernel::Call(
         const Array& x,
         const Array& gamma,
         const Array& beta,
@@ -192,7 +193,7 @@ std::tuple<Array, std::unique_ptr<BatchNormGradState>> GenericBatchNormOp::Call(
     return result;
 }
 
-std::tuple<Array, Array, Array> GenericBatchNormGradOp::Call(
+std::tuple<Array, Array, Array> GenericBatchNormGradKernel::Call(
         const Array& x,
         const Array& gamma,
         const Array& gout,
@@ -246,7 +247,7 @@ std::tuple<Array, Array, Array> GenericBatchNormGradOp::Call(
     return std::make_tuple(std::move(actual_gx), std::move(actual_ggamma), std::move(actual_gbeta));
 }
 
-Array GenericFixedBatchNormOp::Call(
+Array GenericFixedBatchNormKernel::Call(
         const Array& x,
         const Array& gamma,
         const Array& beta,
@@ -285,7 +286,7 @@ Array BatchNorm(
     std::shared_ptr<BatchNormGradState> state{};
     {
         NoBackpropModeScope scope{};
-        std::tie(out, state) = device.backend().CallOp<BatchNormOp>(
+        std::tie(out, state) = device.backend().CallKernel<BatchNormKernel>(
                 x.AsGradStopped(),
                 gamma_reshaped.AsGradStopped(),
                 beta_reshaped.AsGradStopped(),
@@ -322,7 +323,7 @@ Array BatchNorm(
         Array gbeta{};
         {
             NoBackpropModeScope scope{};
-            std::tie(gx, ggamma, gbeta) = device.backend().CallOp<BatchNormGradOp>(
+            std::tie(gx, ggamma, gbeta) = device.backend().CallKernel<BatchNormGradKernel>(
                     x, gamma_reshaped, gout, eps, sorted_axis, state, nonstd::nullopt, nonstd::nullopt, nonstd::nullopt);
         }
         CHAINERX_ASSERT(internal::GetArrayBody(gx)->nodes().empty());
@@ -416,7 +417,7 @@ Array FixedBatchNorm(
 
     {
         NoBackpropModeScope scope{};
-        return x.device().backend().CallOp<FixedBatchNormOp>(
+        return x.device().backend().CallKernel<FixedBatchNormKernel>(
                 x.AsGradStopped(), result.gamma, result.beta, result.mean, result.var, eps, result.sorted_axis, nonstd::nullopt);
     }
 }
diff --git a/chainerx_cc/chainerx/routines/normalization.h b/chainerx_cc/chainerx/routines/normalization.h
index e1a0fc5b1390..65d780e4794e 100644
--- a/chainerx_cc/chainerx/routines/normalization.h
+++ b/chainerx_cc/chainerx/routines/normalization.h
@@ -1,136 +1,13 @@
 #pragma once
 
-#include <memory>
-#include <tuple>
-#include <utility>
-
 #include <nonstd/optional.hpp>
 
 #include "chainerx/array.h"
 #include "chainerx/axes.h"
-#include "chainerx/constant.h"
-#include "chainerx/dtype.h"
-#include "chainerx/op.h"
 #include "chainerx/scalar.h"
-#include "chainerx/stack_vector.h"
 
 namespace chainerx {
 
-// Intermediate results from `BatchNormOp::Call` can be stored in this construct and be reused in `BatchNormGradOp::Call`.
-// The objects to store may vary depending on backend so each backend should derive this class to define the actual set of intermediate
-// results.
-class BatchNormGradState {
-public:
-    virtual ~BatchNormGradState() = default;
-};
-
-class BatchNormOp : public Op {
-public:
-    static const char* name() { return "BatchNorm"; }
-
-    // The returned state should be a `nullptr` if `return_state` is `false`.
-    virtual std::tuple<Array, std::unique_ptr<BatchNormGradState>> Call(
-            const Array& x,
-            const Array& gamma,
-            const Array& beta,
-            const Array& running_mean,
-            const Array& running_var,
-            Scalar eps,
-            Scalar decay,
-            const Axes& axis,
-            bool return_state,
-            const nonstd::optional<Array>& out) = 0;
-};
-
-class BatchNormGradOp : public Op {
-public:
-    static const char* name() { return "BatchNormGrad"; }
-
-    // Returns gx, ggamma, gbeta.
-    virtual std::tuple<Array, Array, Array> Call(
-            const Array& x,
-            const Array& gamma,
-            const Array& gout,
-            Scalar eps,
-            const Axes& axis,
-            const std::shared_ptr<BatchNormGradState>& state,
-            const nonstd::optional<Array>& gx,
-            const nonstd::optional<Array>& ggamma,
-            const nonstd::optional<Array>& gbeta) = 0;
-};
-
-class GenericBatchNormGradState : public BatchNormGradState {
-public:
-    GenericBatchNormGradState(Array x_mean, Array x_inv_std, Dtype beta_dtype)
-        : x_mean_{std::move(x_mean)}, x_inv_std_{std::move(x_inv_std)}, beta_dtype_{beta_dtype} {}
-
-    const Array& x_mean() const { return x_mean_; }
-    const Array& x_inv_std() const { return x_inv_std_; }
-    Dtype beta_dtype() const { return beta_dtype_; }
-
-private:
-    Array x_mean_;
-    Array x_inv_std_;
-    Dtype beta_dtype_;
-};
-
-class GenericBatchNormOp : public BatchNormOp {
-public:
-    std::tuple<Array, std::unique_ptr<BatchNormGradState>> Call(
-            const Array& x,
-            const Array& gamma,
-            const Array& beta,
-            const Array& running_mean,
-            const Array& running_var,
-            Scalar eps,
-            Scalar decay,
-            const Axes& axis,
-            bool return_state,
-            const nonstd::optional<Array>& out) override;
-};
-
-class GenericBatchNormGradOp : public BatchNormGradOp {
-public:
-    std::tuple<Array, Array, Array> Call(
-            const Array& x,
-            const Array& gamma,
-            const Array& gout,
-            Scalar eps,
-            const Axes& axis,
-            const std::shared_ptr<BatchNormGradState>& state,
-            const nonstd::optional<Array>& gx,
-            const nonstd::optional<Array>& ggamma,
-            const nonstd::optional<Array>& gbeta) override;
-};
-
-class FixedBatchNormOp : public Op {
-public:
-    static const char* name() { return "FixedBatchNorm"; }
-
-    virtual Array Call(
-            const Array& x,
-            const Array& gamma,
-            const Array& beta,
-            const Array& mean,
-            const Array& var,
-            Scalar eps,
-            const Axes& axis,
-            const nonstd::optional<Array>& out) = 0;
-};
-
-class GenericFixedBatchNormOp : public FixedBatchNormOp {
-public:
-    Array Call(
-            const Array& x,
-            const Array& gamma,
-            const Array& beta,
-            const Array& mean,
-            const Array& var,
-            Scalar eps,
-            const Axes& axis,
-            const nonstd::optional<Array>& out) override;
-};
-
 // Computes the batch normalization along the given axis.
 // If axis is omitted, the first axis is treated as the batch axis and will be reduced during normalization.
 // Running mean and running variance that are passed as arguments will be updated in-place.
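Note: the declarations removed from routines/normalization.h above are not dropped; normalization.cc now includes "chainerx/kernels/normalization.h", and the GenericBatchNorm*Kernel definitions it keeps imply that the same interfaces reappear in that new header with every *Op class renamed to *Kernel. The new header itself is not part of this excerpt, so the following is only a sketch under that assumption (the include list and exact layout are guesses; only the signatures are taken from the removed code above).

// Sketch of chainerx/kernels/normalization.h (assumed, not shown in this patch excerpt).
// It mirrors the BatchNormGradState/BatchNormOp declarations removed above, rebased on
// the new Kernel base class from chainerx/kernel.h.
#pragma once

#include <memory>
#include <tuple>

#include <nonstd/optional.hpp>

#include "chainerx/array.h"
#include "chainerx/axes.h"
#include "chainerx/kernel.h"
#include "chainerx/scalar.h"

namespace chainerx {

// Backend-specific intermediate results produced by the forward kernel and reused
// by the gradient kernel, exactly as BatchNormGradState behaved before the rename.
class BatchNormGradState {
public:
    virtual ~BatchNormGradState() = default;
};

class BatchNormKernel : public Kernel {
public:
    static const char* name() { return "BatchNorm"; }

    // The returned state should be a `nullptr` if `return_state` is `false`.
    virtual std::tuple<Array, std::unique_ptr<BatchNormGradState>> Call(
            const Array& x,
            const Array& gamma,
            const Array& beta,
            const Array& running_mean,
            const Array& running_var,
            Scalar eps,
            Scalar decay,
            const Axes& axis,
            bool return_state,
            const nonstd::optional<Array>& out) = 0;
};

}  // namespace chainerx

Call sites then dispatch through the backend, as the normalization.cc hunks above do with device.backend().CallKernel<BatchNormKernel>(...).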
diff --git a/chainerx_cc/chainerx/routines/pooling.cc b/chainerx_cc/chainerx/routines/pooling.cc
index 8611b2f9f8c5..6c2ba0130acb 100644
--- a/chainerx_cc/chainerx/routines/pooling.cc
+++ b/chainerx_cc/chainerx/routines/pooling.cc
@@ -18,6 +18,7 @@
 #include "chainerx/dims.h"
 #include "chainerx/error.h"
 #include "chainerx/graph.h"
+#include "chainerx/kernels/pooling.h"
 #include "chainerx/routines/math.h"
 #include "chainerx/routines/routines_util.h"
 #include "chainerx/stack_vector.h"
@@ -65,8 +66,8 @@ Array MaxPool(
     std::shared_ptr<MaxPoolGradState> state{};
     {
         NoBackpropModeScope scope{};
-        std::tie(out, state) =
-                x.device().backend().CallOp<MaxPoolOp>(x.AsGradStopped(), kernel_size, stride, pad, cover_all, true, nonstd::nullopt);
+        std::tie(out, state) = x.device().backend().CallKernel<MaxPoolKernel>(
+                x.AsGradStopped(), kernel_size, stride, pad, cover_all, true, nonstd::nullopt);
     }
 
     internal::MakeViewForForwardBackwardOutput(out);
@@ -80,7 +81,7 @@ Array MaxPool(
     std::shared_ptr<MaxPoolGradGradState> grad_grad_state{};
     {
         NoBackpropModeScope scope{};
-        std::tie(gx, grad_grad_state) = gout.device().backend().CallOp<MaxPoolGradOp>(
+        std::tie(gx, grad_grad_state) = gout.device().backend().CallKernel<MaxPoolGradKernel>(
                 gout.AsGradStopped(), kernel_size, stride, pad, state, true, nonstd::nullopt);
     }
     internal::MakeViewForForwardBackwardOutput(gx);
@@ -94,7 +95,7 @@ Array MaxPool(
     Array ggout{};
     {
         NoBackpropModeScope scope{};
-        ggout = ggx.device().backend().CallOp<MaxPoolGradGradOp>(
+        ggout = ggx.device().backend().CallKernel<MaxPoolGradGradKernel>(
                 ggx.AsGradStopped(), st.kernel_size, st.stride, st.pad, st.cover_all, grad_grad_state, nonstd::nullopt);
     }
 
@@ -146,8 +147,8 @@ Array AveragePool(
     std::shared_ptr<AveragePoolGradState> state{};
     {
         NoBackpropModeScope scope{};
-        std::tie(out, state) =
-                x.device().backend().CallOp<AveragePoolOp>(x.AsGradStopped(), kernel_size, stride, pad, pad_mode, true, nonstd::nullopt);
+        std::tie(out, state) = x.device().backend().CallKernel<AveragePoolKernel>(
+                x.AsGradStopped(), kernel_size, stride, pad, pad_mode, true, nonstd::nullopt);
     }
 
     internal::MakeViewForForwardBackwardOutput(out);
@@ -160,7 +161,7 @@ Array AveragePool(
     Array gx{};
     {
         NoBackpropModeScope scope{};
-        gx = gout.device().backend().CallOp<AveragePoolGradOp>(
+        gx = gout.device().backend().CallKernel<AveragePoolGradKernel>(
                 gout.AsGradStopped(), kernel_size, stride, pad, pad_mode, state, nonstd::nullopt);
     }
 
diff --git a/chainerx_cc/chainerx/routines/pooling.h b/chainerx_cc/chainerx/routines/pooling.h
index 96717b8757c1..e2b6424e6fa9 100644
--- a/chainerx_cc/chainerx/routines/pooling.h
+++ b/chainerx_cc/chainerx/routines/pooling.h
@@ -1,103 +1,13 @@
 #pragma once
 
 #include <cstdint>
-#include <memory>
-#include <tuple>
-
-#include <nonstd/optional.hpp>
 
 #include "chainerx/array.h"
 #include "chainerx/constant.h"
-#include "chainerx/op.h"
 #include "chainerx/stack_vector.h"
 
 namespace chainerx {
 
-class MaxPoolGradState {
-public:
-    virtual ~MaxPoolGradState() = default;
-};
-
-class MaxPoolOp : public Op {
-public:
-    static const char* name() { return "MaxPool"; }
-
-    virtual std::tuple<Array, std::unique_ptr<MaxPoolGradState>> Call(
-            const Array& x,
-            StackVector<int64_t, kMaxNdim> kernel_size,
-            StackVector<int64_t, kMaxNdim> stride,
-            StackVector<int64_t, kMaxNdim> pad,
-            bool cover_all,
-            bool return_state,
-            const nonstd::optional<Array>& out) = 0;
-};
-
-class MaxPoolGradGradState {
-public:
-    virtual ~MaxPoolGradGradState() = default;
-};
-
-class MaxPoolGradOp : public Op {
-public:
-    static const char* name() { return "MaxPoolGrad"; }
-
-    virtual std::tuple<Array, std::unique_ptr<MaxPoolGradGradState>> Call(
-            const Array& gout,
-            StackVector<int64_t, kMaxNdim> kernel_size,
-            StackVector<int64_t, kMaxNdim> stride,
-            StackVector<int64_t, kMaxNdim> pad,
-            const std::shared_ptr<MaxPoolGradState>& state,
-            bool return_state,
-            const nonstd::optional<Array>& gx) = 0;
-};
-
-class MaxPoolGradGradOp : public Op {
-public:
-    static const char* name() { return "MaxPoolGradGrad"; }
-
-    virtual Array Call(
-            const Array& ggx,
-            StackVector<int64_t, kMaxNdim> kernel_size,
-            StackVector<int64_t, kMaxNdim> stride,
-            StackVector<int64_t, kMaxNdim> pad,
-            bool cover_all,
-            const std::shared_ptr<MaxPoolGradGradState>& state,
-            const nonstd::optional<Array>& ggout) = 0;
-};
-
-class AveragePoolGradState {
-public:
-    virtual ~AveragePoolGradState() = default;
-};
-
-class AveragePoolOp : public Op {
-public:
-    static const char* name() { return "AveragePool"; }
-
-    virtual std::tuple<Array, std::unique_ptr<AveragePoolGradState>> Call(
-            const Array& x,
-            StackVector<int64_t, kMaxNdim> kernel_size,
-            StackVector<int64_t, kMaxNdim> stride,
-            StackVector<int64_t, kMaxNdim> pad,
-            AveragePoolPadMode pad_mode,
-            bool return_state,
-            const nonstd::optional<Array>& out) = 0;
-};
-
-class AveragePoolGradOp : public Op {
-public:
-    static const char* name() { return "AveragePoolGrad"; }
-
-    virtual Array Call(
-            const Array& gout,
-            StackVector<int64_t, kMaxNdim> kernel_size,
-            StackVector<int64_t, kMaxNdim> stride,
-            StackVector<int64_t, kMaxNdim> pad,
-            AveragePoolPadMode pad_mode,
-            const std::shared_ptr<AveragePoolGradState>& state,
-            const nonstd::optional<Array>& gx) = 0;
-};
-
 Array MaxPool(
         const Array& x,
         const StackVector<int64_t, kMaxNdim>& kernel_size,
diff --git a/chainerx_cc/chainerx/routines/sorting.cc b/chainerx_cc/chainerx/routines/sorting.cc
index 73afdeff1614..7dc77589ef8a 100644
--- a/chainerx_cc/chainerx/routines/sorting.cc
+++ b/chainerx_cc/chainerx/routines/sorting.cc
@@ -9,6 +9,7 @@
 #include "chainerx/backprop_mode.h"
 #include "chainerx/dtype.h"
 #include "chainerx/error.h"
+#include "chainerx/kernels/sorting.h"
 #include "chainerx/routines/creation.h"
 #include "chainerx/shape.h"
 
@@ -42,7 +43,7 @@ Array ArgMax(const Array& a, const OptionalAxes& axis) {
     Array out = Empty(out_shape, Dtype::kInt64, a.device());
     {
         NoBackpropModeScope scope{};
-        a.device().backend().CallOp<ArgMaxOp>(a, sorted_axis, out);
+        a.device().backend().CallKernel<ArgMaxKernel>(a, sorted_axis, out);
     }
     return out;
 }
diff --git a/chainerx_cc/chainerx/routines/sorting.h b/chainerx_cc/chainerx/routines/sorting.h
index 852aaea2e585..21ff2f8be631 100644
--- a/chainerx_cc/chainerx/routines/sorting.h
+++ b/chainerx_cc/chainerx/routines/sorting.h
@@ -1,22 +1,12 @@
 #pragma once
 
-#include <cstdint>
+#include <nonstd/optional.hpp>
 
 #include "chainerx/array.h"
 #include "chainerx/axes.h"
-#include "chainerx/backend.h"
-#include "chainerx/device.h"
-#include "chainerx/op.h"
 
 namespace chainerx {
 
-class ArgMaxOp : public Op {
-public:
-    static const char* name() { return "ArgMax"; }
-
-    virtual void Call(const Array& a, const Axes& axis, const Array& out) = 0;
-};
-
 Array ArgMax(const Array& a, const OptionalAxes& axis = nonstd::nullopt);
 
 }  // namespace chainerx
diff --git a/chainerx_cc/chainerx/routines/statistics.cc b/chainerx_cc/chainerx/routines/statistics.cc
index 5fce5e280fa3..87ca29762c0b 100644
--- a/chainerx_cc/chainerx/routines/statistics.cc
+++ b/chainerx_cc/chainerx/routines/statistics.cc
@@ -8,6 +8,7 @@
 #include "chainerx/backward_builder.h"
 #include "chainerx/backward_context.h"
 #include "chainerx/dtype.h"
+#include "chainerx/kernels/math.h"
 #include "chainerx/macro.h"
 #include "chainerx/routines/creation.h"
 #include "chainerx/routines/math.h"
@@ -25,8 +26,8 @@ Array Mean(const Array& a, const OptionalAxes& axis, bool keepdims) {
 
     {
         NoBackpropModeScope scope{};
-        a.device().backend().CallOp(a, sorted_axis, out);
-        a.device().backend().CallOp(out, n, out);
+        a.device().backend().CallKernel(a, sorted_axis, out);
+        a.device().backend().CallKernel(out, n, out);
     }
 
     BackwardBuilder bb{"mean", a, out};
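Taken together, the routine-side pattern in the hunks above is always the same: look up a kernel by type on the array's backend and invoke it inside a NoBackpropModeScope. The Mean hunk just above composes two such calls, but does not spell out the kernel types it passes to CallKernel; a minimal sketch of that composition is below, where SumKernel and DivideASKernel are assumed names for the reduction and array/scalar-divide kernels from "chainerx/kernels/math.h", and MeanSketch itself is a hypothetical helper, not code from this patch.

// Sketch only: how a mean can be built from two kernel dispatches under the new API.
#include "chainerx/array.h"
#include "chainerx/axes.h"
#include "chainerx/backend.h"
#include "chainerx/backprop_mode.h"
#include "chainerx/device.h"
#include "chainerx/kernels/math.h"
#include "chainerx/scalar.h"

namespace chainerx {

// `out` is a pre-allocated reduced-shape array; `n` is the number of reduced elements.
Array MeanSketch(const Array& a, const Axes& sorted_axis, Scalar n, const Array& out) {
    NoBackpropModeScope scope{};
    a.device().backend().CallKernel<SumKernel>(a, sorted_axis, out);  // out <- sum of a over sorted_axis
    a.device().backend().CallKernel<DivideASKernel>(out, n, out);     // out <- out / n, in place
    return out;
}

}  // namespace chainerx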