Merge pull request #6944 from hvy/kernels
Introduce `chainerx/kernels/` and rename existing device "op"s to "kernel"s
niboshi committed Apr 17, 2019
2 parents 8fc8e08 + 863d05c commit 299637b
Showing 84 changed files with 1,649 additions and 1,359 deletions.
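At a glance, the change at each call site is mechanical: the CallOp entry point and the *Op class names become CallKernel and *Kernel. The two call sites from array.cc in this diff show the before/after shape of the rename:

// Before: device operations dispatched through CallOp.
device().backend().CallOp<AsTypeOp>(*this, out);
device().backend().CallOp<FillOp>(*this, value);

// After: the same dispatch, renamed to kernels.
device().backend().CallKernel<AsTypeKernel>(*this, out);
device().backend().CallKernel<FillKernel>(*this, value);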
7 changes: 4 additions & 3 deletions chainerx_cc/chainerx/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(kernels)
add_subdirectory(routines)
add_subdirectory(native)
add_subdirectory(testing)
@@ -42,13 +43,13 @@ install(FILES
index_iterator.h
indexable_array.h
indexer.h
+kernel.h
+kernel_registry.h
macro.h
numerical_gradient.h
numeric.h
numeric_limits.h
-op.h
op_node.h
-op_registry.h
optional_container_arg.h
platform.h
reduction_kernel_arg.h
@@ -166,10 +167,10 @@ if(${CHAINERX_BUILD_TEST})
index_iterator_test.cc
indexable_array_test.cc
indexer_test.cc
+kernel_registry_test.cc
numeric_limits_test.cc
numerical_gradient_test.cc
numeric_test.cc
-op_registry_test.cc
optional_container_arg_test.cc
scalar_test.cc
shape_test.cc
6 changes: 3 additions & 3 deletions chainerx_cc/chainerx/array.cc
@@ -31,6 +31,7 @@
#include "chainerx/dtype.h"
#include "chainerx/error.h"
#include "chainerx/graph.h"
#include "chainerx/kernels/misc.h"
#include "chainerx/macro.h"
#include "chainerx/native/native_backend.h"
#include "chainerx/op_node.h"
@@ -40,7 +41,6 @@
#include "chainerx/routines/logic.h"
#include "chainerx/routines/manipulation.h"
#include "chainerx/routines/math.h"
#include "chainerx/routines/misc.h"
#include "chainerx/routines/routines_util.h"
#include "chainerx/routines/sorting.h"
#include "chainerx/routines/statistics.h"
@@ -313,7 +313,7 @@ Array Array::AsType(Dtype dtype, bool copy) const {
}

Array out = Empty(shape(), dtype, device());
-device().backend().CallOp<AsTypeOp>(*this, out);
+device().backend().CallKernel<AsTypeKernel>(*this, out);

if (GetKind(dtype) == DtypeKind::kFloat) {
BackwardBuilder bb{"astype", *this, out};
@@ -329,7 +329,7 @@ Array Array::AsType(Dtype dtype, bool copy) const {

void Array::Fill(Scalar value) const {
internal::CheckNoUnsafeInplace(*this, {});
-device().backend().CallOp<FillOp>(*this, value);
+device().backend().CallKernel<FillKernel>(*this, value);
}

const nonstd::optional<Array>& Array::GetGrad(const nonstd::optional<BackpropId>& backprop_id) const {
4 changes: 2 additions & 2 deletions chainerx_cc/chainerx/backend.cc
@@ -4,15 +4,15 @@
#include <utility>

#include "chainerx/device.h"
#include "chainerx/op_registry.h"
#include "chainerx/kernel_registry.h"

namespace chainerx {

Backend::~Backend() = default;

Backend::Backend(Context& context) : context_{context} {}

-void Backend::Initialize() { op_registry_ = OpRegistry{&GetParentOpRegistry()}; }
+void Backend::Initialize() { kernel_registry_ = KernelRegistry{&GetParentKernelRegistry()}; }

Device& Backend::GetDevice(int index) {
if (index < 0) {
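Initialize() is where registries get chained: each backend constructs its own KernelRegistry with a pointer to a parent registry, so a lookup that misses locally falls back to the parent (ultimately a backend-global registry). A minimal sketch of that chained lookup, assuming a simple hash-map store; the real chainerx::KernelRegistry in kernel_registry.h differs in detail:

#include <memory>
#include <stdexcept>
#include <typeindex>
#include <unordered_map>

class Kernel {  // stand-in for chainerx::Kernel
public:
    virtual ~Kernel() = default;
};

class RegistrySketch {
public:
    explicit RegistrySketch(RegistrySketch* parent = nullptr) : parent_{parent} {}

    // Looks up a kernel in this registry, falling back along the parent chain.
    Kernel& Get(std::type_index key) {
        auto it = kernels_.find(key);
        if (it != kernels_.end()) {
            return *it->second;  // registered directly on this backend
        }
        if (parent_ != nullptr) {
            return parent_->Get(key);  // e.g. the backend-global registry
        }
        throw std::runtime_error{"kernel not registered"};
    }

private:
    RegistrySketch* parent_;
    std::unordered_map<std::type_index, std::unique_ptr<Kernel>> kernels_;
};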
22 changes: 11 additions & 11 deletions chainerx_cc/chainerx/backend.h
@@ -6,8 +6,8 @@
#include <utility>
#include <vector>

#include "chainerx/op.h"
#include "chainerx/op_registry.h"
#include "chainerx/kernel.h"
#include "chainerx/kernel_registry.h"

namespace chainerx {

@@ -40,7 +40,7 @@ class Backend {
Context& context() const { return context_; }

// Returns the op registry.
-OpRegistry& op_registry() { return op_registry_; }
+KernelRegistry& kernel_registry() { return kernel_registry_; }

// Returns the device for the given index.
//
@@ -50,16 +50,16 @@
// Queries if the backend supports data transfer between two devices.
virtual bool SupportsTransfer(Device& src_device, Device& dst_device) = 0;

-// Calls the op implementation.
-template <typename OpType, typename... Args>
-auto CallOp(Args&&... args) {
-Op& op = op_registry_.GetOp<OpType>();
-return dynamic_cast<OpType&>(op).Call(std::forward<Args>(args)...);
+// Calls the kernel implementation.
+template <typename KernelType, typename... Args>
+auto CallKernel(Args&&... args) {
+Kernel& kernel = kernel_registry_.GetKernel<KernelType>();
+return dynamic_cast<KernelType&>(kernel).Call(std::forward<Args>(args)...);
}

protected:
-// Returns a backend-specific global op registry.
-virtual OpRegistry& GetParentOpRegistry() = 0;
+// Returns a backend-specific global kernel registry.
+virtual KernelRegistry& GetParentKernelRegistry() = 0;

private:
// Creates a new device.
@@ -72,7 +72,7 @@

std::mutex devices_mutex_;

-OpRegistry op_registry_;
+KernelRegistry kernel_registry_;
};

} // namespace chainerx
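CallKernel above is the single dispatch point: it asks the registry for the instance registered under KernelType, downcasts it with dynamic_cast (so a mismatched registration fails loudly rather than silently), and perfect-forwards the arguments to the concrete Call. A caller-side usage sketch; FillKernel's Call signature is inferred from Array::Fill in array.cc, and the header locations are assumptions:

#include "chainerx/array.h"
#include "chainerx/kernels/misc.h"  // assumed home of FillKernel
#include "chainerx/scalar.h"

void FillWithZeros(const chainerx::Array& a) {
    // Resolves to the native or CUDA FillKernel depending on a's device.
    a.device().backend().CallKernel<chainerx::FillKernel>(a, chainerx::Scalar{0});
}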
8 changes: 4 additions & 4 deletions chainerx_cc/chainerx/backend_testdata/backend0.cc
@@ -2,8 +2,8 @@
#include <string>

#include "chainerx/context.h"
#include "chainerx/kernel_registry.h"
#include "chainerx/native/native_backend.h"
#include "chainerx/op_registry.h"

namespace {

@@ -14,9 +14,9 @@ class Backend0 : public chainerx::native::NativeBackend {
std::string GetName() const override { return "backend0"; }

protected:
-chainerx::OpRegistry& GetParentOpRegistry() override {
-static chainerx::OpRegistry op_registry{&chainerx::native::NativeBackend::GetGlobalOpRegistry()};
-return op_registry;
+chainerx::KernelRegistry& GetParentKernelRegistry() override {
+static chainerx::KernelRegistry kernel_registry{&chainerx::native::NativeBackend::GetGlobalKernelRegistry()};
+return kernel_registry;
}
};

8 changes: 4 additions & 4 deletions chainerx_cc/chainerx/backend_testdata/backend1.cc
@@ -15,10 +15,10 @@ class Backend1 : public chainerx::native::NativeBackend {
std::string GetName() const override { return "backend1"; }

protected:
-chainerx::OpRegistry& GetParentOpRegistry() override {
-static gsl::owner<chainerx::OpRegistry*> op_registry =
-new chainerx::OpRegistry{&chainerx::native::NativeBackend::GetGlobalOpRegistry()};
-return *op_registry;
+chainerx::KernelRegistry& GetParentKernelRegistry() override {
+static gsl::owner<chainerx::KernelRegistry*> kernel_registry =
+new chainerx::KernelRegistry{&chainerx::native::NativeBackend::GetGlobalKernelRegistry()};
+return *kernel_registry;
}
};

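backend0.cc and backend1.cc exercise two lifetime idioms for the parent registry, restated side by side below: a function-local static object, destroyed at program exit, and a deliberately leaked heap object (gsl::owner marks the raw pointer as owning) that is never destroyed and so cannot take part in static-destruction-order problems at shutdown:

#include <gsl/gsl>

#include "chainerx/kernel_registry.h"
#include "chainerx/native/native_backend.h"

// backend0.cc style: function-local static object.
chainerx::KernelRegistry& StaticLocalRegistry() {
    static chainerx::KernelRegistry registry{&chainerx::native::NativeBackend::GetGlobalKernelRegistry()};
    return registry;
}

// backend1.cc style: intentionally leaked, never destroyed.
chainerx::KernelRegistry& LeakedRegistry() {
    static gsl::owner<chainerx::KernelRegistry*> registry =
            new chainerx::KernelRegistry{&chainerx::native::NativeBackend::GetGlobalKernelRegistry()};
    return *registry;
}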
2 changes: 1 addition & 1 deletion chainerx_cc/chainerx/cuda/CMakeLists.txt
@@ -10,10 +10,10 @@ install(FILES
data_type.cuh
elementwise.cuh
float16.cuh
+kernel_regist.h
memory_pool.h
numeric.cuh
numeric_limits.cuh
-op_regist.h
reduce.cuh
DESTINATION include/chainerx/cuda
)
10 changes: 5 additions & 5 deletions chainerx_cc/chainerx/cuda/cuda_backend.h
@@ -9,7 +9,7 @@
#include "chainerx/backend.h"
#include "chainerx/context.h"
#include "chainerx/device.h"
#include "chainerx/op_registry.h"
#include "chainerx/kernel_registry.h"

namespace chainerx {
namespace cuda {
@@ -50,13 +50,13 @@ class CudaBackend : public Backend {
// Gets maximum cuDNN workspace size.
size_t GetCudnnMaxWorkspaceSize();

-static OpRegistry& GetGlobalOpRegistry() {
-static OpRegistry* global_op_registry = new OpRegistry{};
-return *global_op_registry;
+static KernelRegistry& GetGlobalKernelRegistry() {
+static KernelRegistry* global_kernel_registry = new KernelRegistry{};
+return *global_kernel_registry;
}

protected:
-OpRegistry& GetParentOpRegistry() override { return GetGlobalOpRegistry(); }
+KernelRegistry& GetParentKernelRegistry() override { return GetGlobalKernelRegistry(); }

private:
std::unique_ptr<Device> CreateDevice(int index) override;
1 change: 1 addition & 0 deletions chainerx_cc/chainerx/cuda/cuda_conv.cc
@@ -20,6 +20,7 @@
#include "chainerx/dtype.h"
#include "chainerx/error.h"
#include "chainerx/hash_combine.h"
#include "chainerx/kernels/connection.h"
#include "chainerx/macro.h"
#include "chainerx/routines/connection.h"
#include "chainerx/routines/creation.h"
9 changes: 5 additions & 4 deletions chainerx_cc/chainerx/cuda/cuda_conv_test.cc
@@ -9,6 +9,7 @@
#include "chainerx/constant.h"
#include "chainerx/cuda/cuda_device.h"
#include "chainerx/device_id.h"
#include "chainerx/kernels/connection.h"
#include "chainerx/routines/connection.h"
#include "chainerx/shape.h"
#include "chainerx/stack_vector.h"
@@ -155,9 +156,9 @@ TEST(CudaConvTest, BwdFilterAlgoCache) {
Array gy = testing::BuildArray(out_shape).WithLinearData(-0.3f, 0.1f).WithPadding(1);

EXPECT_EQ(size_t{0}, cuda_internal::CudaConvTest::GetBwdFilterAlgoCacheMapSize(cuda_conv));
-device.backend().CallOp<ConvGradWeightOp>(w_dtype, w_shape, x, gy, stride, pad, cover_all, nonstd::nullopt);
+device.backend().CallKernel<ConvGradWeightKernel>(w_dtype, w_shape, x, gy, stride, pad, cover_all, nonstd::nullopt);
EXPECT_EQ(size_t{1}, cuda_internal::CudaConvTest::GetBwdFilterAlgoCacheMapSize(cuda_conv));
-device.backend().CallOp<ConvGradWeightOp>(w_dtype, w_shape, x, gy, stride, pad, cover_all, nonstd::nullopt);
+device.backend().CallKernel<ConvGradWeightKernel>(w_dtype, w_shape, x, gy, stride, pad, cover_all, nonstd::nullopt);
EXPECT_EQ(size_t{1}, cuda_internal::CudaConvTest::GetBwdFilterAlgoCacheMapSize(cuda_conv));
}
{
@@ -171,9 +172,9 @@ TEST(CudaConvTest, BwdFilterAlgoCache) {
Array gy = testing::BuildArray(out_shape).WithLinearData(-0.3f, 0.1f).WithPadding(1);

EXPECT_EQ(size_t{1}, cuda_internal::CudaConvTest::GetBwdFilterAlgoCacheMapSize(cuda_conv));
-device.backend().CallOp<ConvGradWeightOp>(w_dtype, w_shape, x, gy, stride, pad, cover_all, nonstd::nullopt);
+device.backend().CallKernel<ConvGradWeightKernel>(w_dtype, w_shape, x, gy, stride, pad, cover_all, nonstd::nullopt);
EXPECT_EQ(size_t{2}, cuda_internal::CudaConvTest::GetBwdFilterAlgoCacheMapSize(cuda_conv));
-device.backend().CallOp<ConvGradWeightOp>(w_dtype, w_shape, x, gy, stride, pad, cover_all, nonstd::nullopt);
+device.backend().CallKernel<ConvGradWeightKernel>(w_dtype, w_shape, x, gy, stride, pad, cover_all, nonstd::nullopt);
EXPECT_EQ(size_t{2}, cuda_internal::CudaConvTest::GetBwdFilterAlgoCacheMapSize(cuda_conv));
}
}
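These assertions pin down the cache contract for the backward-filter algorithm: the first ConvGradWeightKernel call for a given configuration runs the cuDNN algorithm search and stores the winner, and a repeat call with the same configuration must hit the cache rather than grow it. A minimal sketch of that compute-if-absent pattern; the key and search function here are hypothetical, not ChainerX's actual cache types:

#include <cstddef>
#include <cstdint>
#include <map>
#include <tuple>

// Hypothetical cache key: just enough configuration to identify an algo choice.
struct AlgoCacheKey {
    std::int64_t in_channels, out_channels, height, width;
    bool operator<(const AlgoCacheKey& rhs) const {
        return std::tie(in_channels, out_channels, height, width) <
               std::tie(rhs.in_channels, rhs.out_channels, rhs.height, rhs.width);
    }
};

struct AlgoPerf {
    int algo;
    std::size_t workspace_size;
};

AlgoPerf RunAlgoSearch(const AlgoCacheKey&) { return {0, 0}; }  // stub for the expensive search

AlgoPerf FindAlgo(std::map<AlgoCacheKey, AlgoPerf>& cache, const AlgoCacheKey& key) {
    auto it = cache.find(key);
    if (it != cache.end()) {
        return it->second;  // repeat call: cache size unchanged
    }
    AlgoPerf perf = RunAlgoSearch(key);
    cache.emplace(key, perf);  // first call: cache grows by one, as the EXPECT_EQs check
    return perf;
}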
2 changes: 2 additions & 0 deletions chainerx_cc/chainerx/cuda/cuda_device.h
@@ -21,6 +21,8 @@
#include "chainerx/cuda/memory_pool.h"
#include "chainerx/device.h"
#include "chainerx/dtype.h"
#include "chainerx/kernels/normalization.h"
#include "chainerx/kernels/pooling.h"
#include "chainerx/routines/normalization.h"
#include "chainerx/routines/pooling.h"
#include "chainerx/scalar.h"
19 changes: 10 additions & 9 deletions chainerx_cc/chainerx/cuda/cuda_device/activation.cu
@@ -9,10 +9,11 @@
#include "chainerx/cuda/cuda_runtime.h"
#include "chainerx/cuda/cuda_set_device_scope.h"
#include "chainerx/cuda/elementwise.cuh"
#include "chainerx/cuda/kernel_regist.h"
#include "chainerx/cuda/numeric.cuh"
#include "chainerx/cuda/op_regist.h"
#include "chainerx/device.h"
#include "chainerx/dtype.h"
#include "chainerx/kernels/math.h"
#include "chainerx/numeric.h"
#include "chainerx/routines/math.h"
#include "chainerx/routines/type_util.h"
@@ -31,7 +32,7 @@ struct IfLessElseASSAImpl {
OutCudaType pos;
};

-class CudaIfLessElseASSAOp : public IfLessElseASSAOp {
+class CudaIfLessElseASSAKernel : public IfLessElseASSAKernel {
public:
void Call(const Array& x1, Scalar x2, Scalar pos, const Array& neg, const Array& out) override {
Device& device = x1.device();
@@ -53,7 +54,7 @@ public:
}
};

-CHAINERX_CUDA_REGISTER_OP(IfLessElseASSAOp, CudaIfLessElseASSAOp);
+CHAINERX_CUDA_REGISTER_KERNEL(IfLessElseASSAKernel, CudaIfLessElseASSAKernel);

template <typename In, typename Out>
struct IfGreaterElseASSAImpl {
@@ -64,7 +65,7 @@ struct IfGreaterElseASSAImpl {
OutCudaType pos;
};

-class CudaIfGreaterElseASSAOp : public IfGreaterElseASSAOp {
+class CudaIfGreaterElseASSAKernel : public IfGreaterElseASSAKernel {
public:
void Call(const Array& x1, Scalar x2, Scalar pos, const Array& neg, const Array& out) override {
Device& device = x1.device();
@@ -86,7 +87,7 @@ public:
}
};

-CHAINERX_CUDA_REGISTER_OP(IfGreaterElseASSAOp, CudaIfGreaterElseASSAOp);
+CHAINERX_CUDA_REGISTER_KERNEL(IfGreaterElseASSAKernel, CudaIfGreaterElseASSAKernel);

template <typename In, typename Out>
struct IfGreaterElseAAAAImpl {
@@ -97,7 +98,7 @@ struct IfGreaterElseAAAAImpl {
}
};

-class CudaIfGreaterElseAAAAOp : public IfGreaterElseAAAAOp {
+class CudaIfGreaterElseAAAAKernel : public IfGreaterElseAAAAKernel {
public:
void Call(const Array& x1, const Array& x2, const Array& pos, const Array& neg, const Array& out) override {
Device& device = x1.device();
@@ -119,15 +120,15 @@ public:
}
};

-CHAINERX_CUDA_REGISTER_OP(IfGreaterElseAAAAOp, CudaIfGreaterElseAAAAOp);
+CHAINERX_CUDA_REGISTER_KERNEL(IfGreaterElseAAAAKernel, CudaIfGreaterElseAAAAKernel);

template <typename T>
struct TanhImpl {
using CudaType = cuda_internal::DataType<T>;
__device__ void operator()(int64_t /*i*/, CudaType x, CudaType& out) { out = cuda::Tanh(x); }
};

-class CudaTanhOp : public TanhOp {
+class CudaTanhKernel : public TanhKernel {
public:
void Call(const Array& x, const Array& out) override {
Device& device = x.device();
@@ -141,7 +142,7 @@ public:
}
};

-CHAINERX_CUDA_REGISTER_OP(TanhOp, CudaTanhOp);
+CHAINERX_CUDA_REGISTER_KERNEL(TanhKernel, CudaTanhKernel);

} // namespace
} // namespace cuda
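Every kernel in this file follows the same three-part pattern: a per-dtype functor with a __device__ operator(), a class whose Call launches the functor elementwise, and a registration macro that binds it to the CUDA backend. As a sketch in that style, a hypothetical sigmoid kernel; the Call body is inferred from ChainerX's elementwise conventions (the bodies above are collapsed in this view), and SigmoidKernel and cuda::Exp are assumed to exist:

namespace chainerx {
namespace cuda {
namespace {

template <typename T>
struct SigmoidImpl {
    using CudaType = cuda_internal::DataType<T>;
    __device__ void operator()(int64_t /*i*/, CudaType x, CudaType& out) {
        out = CudaType{1} / (CudaType{1} + cuda::Exp(-x));
    }
};

class CudaSigmoidKernel : public SigmoidKernel {  // SigmoidKernel base: assumed
public:
    void Call(const Array& x, const Array& out) override {
        Device& device = x.device();
        device.CheckDevicesCompatible(x, out);
        CudaSetDeviceScope scope{device.index()};
        VisitFloatingPointDtype(out.dtype(), [&](auto pt) {
            using T = typename decltype(pt)::type;
            // Launches SigmoidImpl<T> once per element of x into out.
            Elementwise<const T, T>(SigmoidImpl<T>{}, x, out);
        });
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(SigmoidKernel, CudaSigmoidKernel);

}  // namespace
}  // namespace cuda
}  // namespace chainerx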
