Skip to content

Commit

Permalink
apacheGH-33984: [C++][Python] DLPack implementation for Arrow Arrays …
Browse files Browse the repository at this point in the history
…(producer) (apache#38472)

### Rationale for this change

DLPack was selected as the data interchange protocol for the Array API, so it is important to implement it for Arrow/PyArrow Arrays as well. This is possible for primitive type arrays (int, uint and float) with no validity buffer. Device support is out of scope for this PR (CPU only).

### What changes are included in this PR?

- `ExportArray` and `ExportDevice` methods on Arrow C++ Arrays
- `__dlpack__` method on the base PyArrow Array class exposing `ExportArray` method
-  `__dlpack_device__` method on the base PyArrow Array class exposing `ExportDevice` method

### Are these changes tested?

Yes, tests are added to `dlpack_test.cc` and `test_array.py`.

### Are there any user-facing changes?

No.

* Closes: apache#33984

Lead-authored-by: AlenkaF <frim.alenka@gmail.com>
Co-authored-by: Alenka Frim <AlenkaF@users.noreply.github.com>
Co-authored-by: Antoine Pitrou <antoine@python.org>
Co-authored-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
Signed-off-by: Antoine Pitrou <antoine@python.org>
  • Loading branch information
3 people committed Dec 19, 2023
1 parent 3e182f2 commit 6c326db
Show file tree
Hide file tree
Showing 15 changed files with 982 additions and 3 deletions.
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ set(ARROW_SRCS
type_traits.cc
visitor.cc
c/bridge.cc
c/dlpack.cc
io/buffered.cc
io/caching.cc
io/compressed.cc
Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/c/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

add_arrow_test(bridge_test PREFIX "arrow-c")
add_arrow_test(dlpack_test)

add_arrow_benchmark(bridge_benchmark)

Expand Down
133 changes: 133 additions & 0 deletions cpp/src/arrow/c/dlpack.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "arrow/c/dlpack.h"

#include "arrow/array/array_base.h"
#include "arrow/c/dlpack_abi.h"
#include "arrow/device.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"

namespace arrow::dlpack {

namespace {

// Translate an Arrow data type into the matching DLPack DLDataType descriptor.
//
// Signed integers, unsigned integers and floating-point types are mapped to
// kDLInt, kDLUInt and kDLFloat respectively. Every other type yields a
// TypeError.
Result<DLDataType> GetDLDataType(const DataType& type) {
  DLDataType descriptor;

  // First pick the DLPack type code (or bail out for unsupported types) ...
  switch (type.id()) {
    case Type::INT8:
    case Type::INT16:
    case Type::INT32:
    case Type::INT64:
      descriptor.code = DLDataTypeCode::kDLInt;
      break;
    case Type::UINT8:
    case Type::UINT16:
    case Type::UINT32:
    case Type::UINT64:
      descriptor.code = DLDataTypeCode::kDLUInt;
      break;
    case Type::HALF_FLOAT:
    case Type::FLOAT:
    case Type::DOUBLE:
      descriptor.code = DLDataTypeCode::kDLFloat;
      break;
    case Type::BOOL:
      // DLPack only supports byte-packed booleans, so bit-packed
      // boolean data cannot be exported.
      return Status::TypeError("Bit-packed boolean data type not supported by DLPack.");
    default:
      return Status::TypeError("DataType is not compatible with DLPack spec: ",
                               type.ToString());
  }

  // ... then fill in the element width. Exported values are always scalar
  // (a single lane).
  descriptor.bits = type.bit_width();
  descriptor.lanes = 1;
  return descriptor;
}

// Heap-allocated context backing a DLManagedTensor produced by ExportArray.
// ExportArray stores a pointer to this struct in `tensor.manager_ctx`, and the
// tensor's deleter deletes the struct, which releases the array data reference.
struct ManagerCtx {
// Reference to the exported array's data; held so that the buffer and the
// `length` field pointed to by the tensor stay alive as long as the context.
std::shared_ptr<ArrayData> array;
// The DLPack tensor handed out to the consumer (embedded, not separately
// allocated).
DLManagedTensor tensor;
};

} // namespace

// Export `arr` as a one-dimensional DLPack tensor that shares the array's
// data buffer (no copy is made).
//
// The returned DLManagedTensor is embedded in a heap-allocated ManagerCtx
// which also holds a reference to the array data. Ownership of that context
// passes to the caller, who releases it by invoking the tensor's `deleter`.
//
// Errors are propagated from ExportDevice (nulls present, unsupported type,
// non-CPU buffer) and from GetDLDataType (unsupported type).
Result<DLManagedTensor*> ExportArray(const std::shared_ptr<Array>& arr) {
  // Define DLDevice struct and check if array type is supported
  // by the DLPack protocol at the same time. Raise TypeError if not.
  // Supported data types: int, uint, float with no nulls.
  ARROW_ASSIGN_OR_RAISE(auto device, ExportDevice(arr));

  // Define the DLDataType struct
  const DataType& type = *arr->type();
  std::shared_ptr<ArrayData> data = arr->data();
  ARROW_ASSIGN_OR_RAISE(auto dlpack_type, GetDLDataType(type));

  // Create the ManagerCtx that will serve as the owner of the DLManagedTensor.
  // make_unique value-initializes the context, so no tensor field is ever
  // left indeterminate, and the context is freed on any early return.
  auto ctx = std::make_unique<ManagerCtx>();
  DLTensor& dl_tensor = ctx->tensor.dl_tensor;

  // Define the data pointer of the DLTensor.
  // If the array is of length 0, the data pointer should be null.
  if (arr->length() == 0) {
    dl_tensor.data = nullptr;
  } else {
    // Apply the array offset here so that byte_offset can stay 0.
    const auto data_offset = data->offset * type.byte_width();
    dl_tensor.data = const_cast<uint8_t*>(data->buffers[1]->data() + data_offset);
  }

  dl_tensor.device = device;
  dl_tensor.ndim = 1;
  dl_tensor.dtype = dlpack_type;
  // The shape points at the ArrayData's own length field, which is kept alive
  // by the reference stored in the context below.
  dl_tensor.shape = const_cast<int64_t*>(&data->length);
  // Null strides: the tensor is exported without explicit stride information.
  dl_tensor.strides = nullptr;
  dl_tensor.byte_offset = 0;

  ctx->array = std::move(data);
  ctx->tensor.manager_ctx = ctx.get();
  ctx->tensor.deleter = [](struct DLManagedTensor* self) {
    // manager_ctx is a void* that was set from a ManagerCtx* above.
    delete static_cast<ManagerCtx*>(self->manager_ctx);
  };
  // Hand ownership of the context over to the consumer (released via deleter).
  return &ctx.release()->tensor;
}

// Validate that `arr` can be exposed through DLPack and describe the device
// its data buffer lives on.
//
// Returns TypeError when the array contains nulls or its type is not an
// integer or floating-point type, and NotImplemented when the data buffer is
// not allocated on the CPU.
Result<DLDevice> ExportDevice(const std::shared_ptr<Array>& arr) {
  // Arrays containing nulls are rejected outright.
  if (arr->null_count() > 0) {
    return Status::TypeError("Can only use DLPack on arrays with no nulls.");
  }

  const DataType& arr_type = *arr->type();
  const auto type_id = arr_type.id();

  // Booleans get their own, more specific error message.
  if (type_id == Type::BOOL) {
    return Status::TypeError("Bit-packed boolean data type not supported by DLPack.");
  }
  const bool is_numeric = is_integer(type_id) || is_floating(type_id);
  if (!is_numeric) {
    return Status::TypeError("DataType is not compatible with DLPack spec: ",
                             arr_type.ToString());
  }

  // Only CPU-resident data buffers are handled.
  if (arr->data()->buffers[1]->device_type() != DeviceAllocationType::kCPU) {
    return Status::NotImplemented(
        "DLPack support is implemented only for buffers on CPU device.");
  }

  DLDevice cpu_device;
  cpu_device.device_type = DLDeviceType::kDLCPU;
  cpu_device.device_id = 0;
  return cpu_device;
}

} // namespace arrow::dlpack
51 changes: 51 additions & 0 deletions cpp/src/arrow/c/dlpack.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include "arrow/array/array_base.h"
#include "arrow/c/dlpack_abi.h"

namespace arrow::dlpack {

/// \brief Export Arrow array as DLPack tensor.
///
/// DLManagedTensor is produced as defined by the DLPack protocol,
/// see https://dmlc.github.io/dlpack/latest/.
///
/// Data types for which the protocol is supported are
/// integer and floating-point data types.
///
/// DLPack protocol only supports arrays with one contiguous
/// memory region which means Arrow Arrays containing nulls
/// (i.e. needing a validity buffer) are not supported.
///
/// The returned tensor holds a reference to the array's data; that
/// reference is released when the tensor's `deleter` is called.
///
/// \param[in] arr Arrow array
/// \return DLManagedTensor struct, or TypeError if the array contains
/// nulls or has an unsupported data type, or NotImplemented if the
/// data is not on a CPU device
ARROW_EXPORT
Result<DLManagedTensor*> ExportArray(const std::shared_ptr<Array>& arr);

/// \brief Get the DLDevice describing where the array's data is stored.
///
/// The DLDevice holds an enumerator specifying the type of the device
/// and the index of the device, which is 0 for CPU.
///
/// \param[in] arr Arrow array
/// \return DLDevice struct, or TypeError if the array contains nulls or
/// has an unsupported data type, or NotImplemented if the data is not
/// on a CPU device
ARROW_EXPORT
Result<DLDevice> ExportDevice(const std::shared_ptr<Array>& arr);

} // namespace arrow::dlpack
Loading

0 comments on commit 6c326db

Please sign in to comment.