Skip to content

Commit

Permalink
Merge pull request #8197 from emcastillo/chx-take-index
Browse files Browse the repository at this point in the history
Add mode argument to `chainerx.Take`
  • Loading branch information
mergify[bot] committed Oct 10, 2019
2 parents f98136f + 34c461f commit 4fd2210
Show file tree
Hide file tree
Showing 13 changed files with 272 additions and 68 deletions.
9 changes: 9 additions & 0 deletions chainerx/_docs/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,10 @@ def _docs_indexing():
The indices of the values to extract. Treatment of out-of-bounds
indices is controlled by ``mode``.
axis (int): The axis over which to select values.
mode (str): Specifies how out-of-bounds indices will behave.
'raise' - raise an error
'wrap' - wrap around
'clip' - clip to the range
Returns:
:func:`~chainerx.ndarray`: Output array.
Expand All @@ -578,6 +582,11 @@ def _docs_indexing():
During backpropagation, this function propagates the gradient of the
output array to the input array ``a``.
Note:
The default mode for the native backend is 'raise', while for the CUDA
backend it is 'wrap', in order to avoid device synchronization.
'raise' mode is currently not supported in the CUDA backend.
.. seealso:: :func:`numpy.take`
""")

Expand Down
2 changes: 1 addition & 1 deletion chainerx_cc/chainerx/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ Array Array::Any(const OptionalAxes& axis, bool keepdims) const { return chainer

// Computes the dot product with `b`; thin delegating wrapper over the
// free function chainerx::Dot.
Array Array::Dot(const Array& b) const { return chainerx::Dot(*this, b); }

// Takes elements along `axis` at the given `indices`; `mode` selects how
// out-of-bounds indices are handled (see IndexBoundsMode). Thin delegating
// wrapper over the free function chainerx::Take.
// NOTE(review): the diff rendering retained the old no-`mode` definition
// alongside this one; only the post-change signature is kept here.
Array Array::Take(const Array& indices, int8_t axis, IndexBoundsMode mode) const { return chainerx::Take(*this, indices, axis, mode); }

// Creates a copy of this array; thin delegating wrapper over the free
// function chainerx::Copy.
Array Array::Copy() const { return chainerx::Copy(*this); }

Expand Down
4 changes: 3 additions & 1 deletion chainerx_cc/chainerx/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ inline std::shared_ptr<ArrayBody>&& MoveArrayBody(Array&& array);

} // namespace internal

enum class IndexBoundsMode;

// The user interface of multi-dimensional arrays.
//
// This wraps an ArrayBody, providing accessors, an interface for graph operations and differentiable operations.
Expand Down Expand Up @@ -202,7 +204,7 @@ class Array {
// TODO(niboshi): Support Scalar and StackVector as indices.
// TODO(niboshi): Support axis=None behavior in NumPy.
// TODO(niboshi): Support indices dtype other than int64.
Array Take(const Array& indices, int8_t axis) const;
Array Take(const Array& indices, int8_t axis, IndexBoundsMode mode) const;

// Creates a copy.
// It will be connected to all the graphs.
Expand Down
3 changes: 2 additions & 1 deletion chainerx_cc/chainerx/array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "chainerx/indexable_array.h"
#include "chainerx/indexer.h"
#include "chainerx/op_node.h"
#include "chainerx/routines/indexing.h"
#include "chainerx/scalar.h"
#include "chainerx/shape.h"
#include "chainerx/slice.h"
Expand Down Expand Up @@ -1168,7 +1169,7 @@ TEST_P(ArrayTest, Take) {
Shape output_shape{2, 2, 3};
Array a = testing::BuildArray(input_shape).WithLinearData<T>().WithPadding(1);
Array indices = testing::BuildArray(indices_shape).WithData<int64_t>({0, 14, 3, 1, -10, 1});
Array b = a.Take(indices, 1);
Array b = a.Take(indices, 1, IndexBoundsMode::kWrap);

EXPECT_EQ(output_shape, b.shape());
Array e = testing::BuildArray(output_shape).WithData<T>({0, 2, 3, 1, 2, 1, 4, 6, 7, 5, 6, 5});
Expand Down
153 changes: 129 additions & 24 deletions chainerx_cc/chainerx/cuda/cuda_device/indexing.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Axes MakeRollingPermutation(int8_t first_axis, int8_t last_axis, int8_t ndim) {
}

template <typename T, typename TIndex>
__global__ void TakeCudaKernel(
__global__ void TakeWrapCudaKernel(
IndexableArray<const T> a_iarray,
IndexableArray<T> out_iarray,
IndexableArray<const TIndex> indices_iarray,
Expand Down Expand Up @@ -76,7 +76,31 @@ __global__ void TakeCudaKernel(
}

template <typename T, typename TIndex>
__global__ void AddAtCudaKernel(
__global__ void TakeClipCudaKernel(
        IndexableArray<const T> a_iarray,
        IndexableArray<T> out_iarray,
        IndexableArray<const TIndex> indices_iarray,
        Indexer<> a_indexer,
        Indexer<> out_indexer,
        Indexer<> indices_indexer,
        TIndex common_total_size,
        TIndex axis_dim) {
    // Take kernel for mode='clip': gathered indices are clamped into
    // [0, axis_dim) instead of wrapping. One output element per iteration,
    // grid-stride over the flattened output.
    static_assert(std::is_same<TIndex, int64_t>::value || std::is_same<TIndex, int32_t>::value, "");
    for (auto it = out_indexer.It(blockIdx.x * blockDim.x + threadIdx.x, blockDim.x * gridDim.x); it; ++it) {
        // Split the flat output index into (position in indices, position in
        // the common part); layout is axis-position major.
        TIndex raw = static_cast<TIndex>(it.raw_index());
        TIndex pos_in_indices = raw / common_total_size;
        TIndex pos_in_common = raw % common_total_size;

        // Clamp the requested index into the valid range.
        TIndex index = indices_iarray[indices_indexer.It(pos_in_indices)];
        if (index < TIndex{0}) {
            index = TIndex{0};
        } else if (index > axis_dim - 1) {
            index = axis_dim - 1;
        }
        CHAINERX_ASSERT(0 <= index);
        CHAINERX_ASSERT(index < axis_dim);

        out_iarray[it] = a_iarray[a_indexer.It(index * common_total_size + pos_in_common)];
    }
}

template <typename T, typename TIndex>
__global__ void AddAtWrapCudaKernel(
IndexableArray<const T> a_iarray,
IndexableArray<const T> b_iarray,
IndexableArray<T> out_iarray,
Expand Down Expand Up @@ -114,8 +138,43 @@ __global__ void AddAtCudaKernel(
}
}

// AddAt kernel for mode='clip': out-of-range indices are clamped into
// [0, axis_dim). Computes out = a with elements of b accumulated at the
// axis positions named by the (clipped) indices; the flat layout is
// axis-position major (index * common_total_size + common_pos).
//
// Each output element is owned by one grid-stride iteration; the full
// indices array is scanned serially per element, so duplicate indices
// accumulate deterministically without atomics.
template <typename T, typename TIndex>
__global__ void AddAtClipCudaKernel(
        IndexableArray<const T> a_iarray,
        IndexableArray<const T> b_iarray,
        IndexableArray<T> out_iarray,
        IndexableArray<const TIndex> indices_iarray,
        Indexer<> b_indexer,
        Indexer<> out_indexer,
        Indexer<> indices_indexer,
        TIndex common_total_size,
        TIndex axis_dim) {
    // Only 32-bit and 64-bit signed index types are supported.
    static_assert(std::is_same<TIndex, int64_t>::value || std::is_same<TIndex, int32_t>::value, "");
    for (auto it = out_indexer.It(blockIdx.x * blockDim.x + threadIdx.x, blockDim.x * gridDim.x); it; ++it) {
        // Decompose the flat output index into (axis position, offset within
        // the common part).
        TIndex axis_pos = static_cast<TIndex>(it.raw_index()) / common_total_size;
        TIndex common_pos = static_cast<TIndex>(it.raw_index()) % common_total_size;

        // Start the accumulator from the corresponding element of a.
        cuda_internal::DataType<T> out_value = cuda_internal::StorageToDataType<const T>(a_iarray[it]);

        // Serially scan all indices; accumulate the matching b element
        // whenever the clamped index hits this output's axis position.
        for (auto it_indices = indices_indexer.It(0); it_indices; ++it_indices) {
            TIndex index = indices_iarray[it_indices];

            // Clamp into the valid range [0, axis_dim); the asserts are
            // trivially satisfied after clamping and document the invariant.
            index = max(TIndex{0}, min(index, axis_dim - 1));
            CHAINERX_ASSERT(0 <= index);
            CHAINERX_ASSERT(index < axis_dim);

            if (index == axis_pos) {
                out_value += cuda_internal::StorageToDataType<const T>(
                        b_iarray[b_indexer.It(it_indices.raw_index() * common_total_size + common_pos)]);
            }
        }

        out_iarray[it] = cuda_internal::DataToStorageType<T>(out_value);
    }
}

template <typename TIndex>
void TakeImpl(Device& device, const Array& a, const Array& indices, int8_t axis, const Array& out) {
void TakeImpl(Device& device, const Array& a, const Array& indices, int8_t axis, const Array& out, IndexBoundsMode mode) {
static_assert(std::is_same<TIndex, int64_t>::value || std::is_same<TIndex, int32_t>::value, "");
CHAINERX_ASSERT(
(std::is_same<TIndex, int64_t>::value && indices.dtype() == Dtype::kInt64) ||
Expand All @@ -124,7 +183,7 @@ void TakeImpl(Device& device, const Array& a, const Array& indices, int8_t axis,

CudaSetDeviceScope scope{device.index()};

VisitDtype(out.dtype(), [&a, &indices, axis, &out](auto pt) {
VisitDtype(out.dtype(), [&a, &indices, axis, &out, mode](auto pt) {
using T = typename decltype(pt)::type;

// a and out are transposed as follows.
Expand Down Expand Up @@ -156,18 +215,33 @@ void TakeImpl(Device& device, const Array& a, const Array& indices, int8_t axis,

// TODO(niboshi): Calculate kMaxBlockSize per device
std::lock_guard<std::mutex> lock{*cuda_internal::g_mutex};
static const int kMaxBlockSize = CudaOccupancyMaxPotentialBlockSize(&TakeCudaKernel<T, TIndex>).block_size;
int64_t total_size = out_indexer.total_size();
int64_t grid_size = (total_size + kMaxBlockSize - 1) / kMaxBlockSize;
int64_t block_size = std::min<TIndex>(total_size, kMaxBlockSize);

TakeCudaKernel<<<grid_size, block_size>>>(
a_iarray, out_iarray, indices_iarray, a_indexer, out_indexer, indices_indexer, common_total_size, axis_dim);
static const int kWrapMaxBlockSize = CudaOccupancyMaxPotentialBlockSize(&TakeWrapCudaKernel<T, TIndex>).block_size;
static const int kClipMaxBlockSize = CudaOccupancyMaxPotentialBlockSize(&TakeClipCudaKernel<T, TIndex>).block_size;
int64_t grid_size;
int64_t block_size;
switch (mode) {
case IndexBoundsMode::kRaise:
throw BackendError{"Take with mode='raise' is not supported with CUDA backend"};
case IndexBoundsMode::kDefault:
case IndexBoundsMode::kWrap:
grid_size = (total_size + kWrapMaxBlockSize - 1) / kWrapMaxBlockSize;
block_size = std::min<TIndex>(total_size, kWrapMaxBlockSize);
TakeWrapCudaKernel<<<grid_size, block_size>>>(
a_iarray, out_iarray, indices_iarray, a_indexer, out_indexer, indices_indexer, common_total_size, axis_dim);
break;
case IndexBoundsMode::kClip:
grid_size = (total_size + kClipMaxBlockSize - 1) / kClipMaxBlockSize;
block_size = std::min<TIndex>(total_size, kClipMaxBlockSize);
TakeClipCudaKernel<<<grid_size, block_size>>>(
a_iarray, out_iarray, indices_iarray, a_indexer, out_indexer, indices_indexer, common_total_size, axis_dim);
break;
}
});
}

template <typename TIndex>
void AddAtImpl(Device& device, const Array& a, const Array& indices, int8_t axis, const Array& b, const Array& out) {
void AddAtImpl(Device& device, const Array& a, const Array& indices, int8_t axis, const Array& b, const Array& out, IndexBoundsMode mode) {
// TODO(niboshi): Current implementation only distributes output elements in respective threads. Summation on the indices is performed
// serially in each thread. This implementation can be improved by distributing indices as well, possibly using atomicAdd.

Expand All @@ -180,7 +254,7 @@ void AddAtImpl(Device& device, const Array& a, const Array& indices, int8_t axis

CudaSetDeviceScope scope{device.index()};

VisitDtype(out.dtype(), [&a, &indices, axis, &b, &out](auto pt) {
VisitDtype(out.dtype(), [&a, &indices, axis, &b, &out, mode](auto pt) {
using T = typename decltype(pt)::type;

// b and out are transposed as follows.
Expand Down Expand Up @@ -217,30 +291,61 @@ void AddAtImpl(Device& device, const Array& a, const Array& indices, int8_t axis

TIndex axis_dim = gsl::narrow<TIndex>(a_shape[0]);

static const int kMaxBlockSize = CudaOccupancyMaxPotentialBlockSize(&AddAtCudaKernel<T, TIndex>).block_size;
int64_t total_size = out_indexer.total_size();
int64_t grid_size = (total_size + kMaxBlockSize - 1) / kMaxBlockSize;
int64_t block_size = std::min<int64_t>(total_size, kMaxBlockSize);

AddAtCudaKernel<<<grid_size, block_size>>>(
a_iarray, b_iarray, out_iarray, indices_iarray, b_indexer, out_indexer, indices_indexer, common_total_size, axis_dim);
static const int kWrapMaxBlockSize = CudaOccupancyMaxPotentialBlockSize(&AddAtWrapCudaKernel<T, TIndex>).block_size;
static const int kClipMaxBlockSize = CudaOccupancyMaxPotentialBlockSize(&AddAtClipCudaKernel<T, TIndex>).block_size;
int64_t grid_size;
int64_t block_size;
switch (mode) {
case IndexBoundsMode::kRaise:
throw BackendError{"Take with mode='raise' is not supported with CUDA backend"};
case IndexBoundsMode::kDefault:
case IndexBoundsMode::kWrap:
grid_size = (total_size + kWrapMaxBlockSize - 1) / kWrapMaxBlockSize;
block_size = std::min<int64_t>(total_size, kWrapMaxBlockSize);
AddAtWrapCudaKernel<<<grid_size, block_size>>>(
a_iarray,
b_iarray,
out_iarray,
indices_iarray,
b_indexer,
out_indexer,
indices_indexer,
common_total_size,
axis_dim);
break;
case IndexBoundsMode::kClip:
grid_size = (total_size + kClipMaxBlockSize - 1) / kClipMaxBlockSize;
block_size = std::min<int64_t>(total_size, kClipMaxBlockSize);
AddAtClipCudaKernel<<<grid_size, block_size>>>(
a_iarray,
b_iarray,
out_iarray,
indices_iarray,
b_indexer,
out_indexer,
indices_indexer,
common_total_size,
axis_dim);
break;
}
});
}

// CUDA implementation of TakeKernel. Dispatches to a typed TakeImpl based on
// the indices dtype: int64 indices are used as-is, anything else is cast to
// int32 first. `mode` selects out-of-bounds index handling; mode='raise' is
// rejected inside TakeImpl (it would require device synchronization).
// NOTE(review): the diff rendering retained the old no-`mode` Call overload
// alongside the new one; only the post-change signature is kept here.
class CudaTakeKernel : public TakeKernel {
public:
    void Call(const Array& a, const Array& indices, int8_t axis, const Array& out, IndexBoundsMode mode) override {
        Device& device = a.device();
        CHAINERX_ASSERT(GetKind(indices.dtype()) == DtypeKind::kInt || GetKind(indices.dtype()) == DtypeKind::kUInt);
        device.CheckDevicesCompatible(a, indices, out);

        CudaSetDeviceScope scope{device.index()};

        if (indices.dtype() == Dtype::kInt64) {
            TakeImpl<int64_t>(device, a, indices, axis, out, mode);
        } else {
            // Non-int64 index dtypes are normalized to int32 before dispatch.
            const Array& indices_cast = indices.dtype() == Dtype::kInt32 ? indices : indices.AsType(Dtype::kInt32);
            TakeImpl<int32_t>(device, a, indices_cast, axis, out, mode);
        }
    }
};
Expand All @@ -249,18 +354,18 @@ CHAINERX_CUDA_REGISTER_KERNEL(TakeKernel, CudaTakeKernel);

// CUDA implementation of AddAtKernel (the backward counterpart of Take).
// Dispatches to a typed AddAtImpl based on the indices dtype: int64 indices
// are used as-is, anything else is cast to int32 first. `mode` selects
// out-of-bounds index handling; mode='raise' is rejected inside AddAtImpl.
// NOTE(review): the diff rendering retained the old no-`mode` Call overload
// alongside the new one; only the post-change signature is kept here.
class CudaAddAtKernel : public AddAtKernel {
public:
    void Call(const Array& a, const Array& indices, int8_t axis, const Array& b, const Array& out, IndexBoundsMode mode) override {
        Device& device = a.device();
        CHAINERX_ASSERT(GetKind(indices.dtype()) == DtypeKind::kInt || GetKind(indices.dtype()) == DtypeKind::kUInt);
        device.CheckDevicesCompatible(a, indices, out);

        CudaSetDeviceScope scope{device.index()};

        if (indices.dtype() == Dtype::kInt64) {
            AddAtImpl<int64_t>(device, a, indices, axis, b, out, mode);
        } else {
            // Non-int64 index dtypes are normalized to int32 before dispatch.
            const Array& indices_cast = indices.dtype() == Dtype::kInt32 ? indices : indices.AsType(Dtype::kInt32);
            AddAtImpl<int32_t>(device, a, indices_cast, axis, b, out, mode);
        }
    }
};
Expand Down
5 changes: 3 additions & 2 deletions chainerx_cc/chainerx/kernels/indexing.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
#include "chainerx/array.h"
#include "chainerx/array_index.h"
#include "chainerx/kernel.h"
#include "chainerx/routines/indexing.h"

namespace chainerx {

// Backend kernel interface for AddAt: computes out = a with elements of b
// accumulated at the positions selected by `indices` along `axis`. `mode`
// selects how out-of-bounds indices are handled (see IndexBoundsMode).
// NOTE(review): the diff rendering retained the old no-`mode` pure-virtual
// overload alongside the new one; only the post-change signature is kept,
// since keeping both would force implementers to override two methods.
class AddAtKernel : public Kernel {
public:
    virtual void Call(const Array& a, const Array& indices, int8_t axis, const Array& b, const Array& out, IndexBoundsMode mode) = 0;
};

// Backend kernel interface for Take: gathers elements of `a` at `indices`
// along `axis` into `out`. `mode` selects how out-of-bounds indices are
// handled (see IndexBoundsMode).
// NOTE(review): the diff rendering retained the old no-`mode` pure-virtual
// overload alongside the new one; only the post-change signature is kept.
class TakeKernel : public Kernel {
public:
    virtual void Call(const Array& a, const Array& indices, int8_t axis, const Array& out, IndexBoundsMode mode) = 0;
};

class WhereKernel : public Kernel {
Expand Down
Loading

0 comments on commit 4fd2210

Please sign in to comment.