Skip to content

Commit

Permalink
Add option to avoid copying into the chunk cache when writing
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 635977610
Change-Id: I09cb0b54dd7bb5ad4080ca43feddac5ad4c8b67c
  • Loading branch information
jbms authored and copybara-github committed May 22, 2024
1 parent d77c943 commit 7570693
Show file tree
Hide file tree
Showing 24 changed files with 1,247 additions and 384 deletions.
5 changes: 4 additions & 1 deletion python/tensorstore/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -503,8 +503,10 @@ pybind11_cc_library(
":write_futures",
"//tensorstore",
"//tensorstore:array",
"//tensorstore:array_storage_statistics",
"//tensorstore:batch",
"//tensorstore:cast",
"//tensorstore:codec_spec",
"//tensorstore:context",
"//tensorstore:contiguous_layout",
"//tensorstore:data_type",
Expand All @@ -514,6 +516,7 @@ pybind11_cc_library(
"//tensorstore:open_options",
"//tensorstore:progress",
"//tensorstore:rank",
"//tensorstore:read_write_options",
"//tensorstore:resize_options",
"//tensorstore:schema",
"//tensorstore:spec",
Expand All @@ -522,6 +525,7 @@ pybind11_cc_library(
"//tensorstore/driver/array",
"//tensorstore/index_space:index_transform",
"//tensorstore/internal:global_initializer",
"//tensorstore/internal:intrusive_ptr",
"//tensorstore/internal/json:pprint_python",
"//tensorstore/internal/json_binding",
"//tensorstore/kvstore",
Expand All @@ -530,7 +534,6 @@ pybind11_cc_library(
"//tensorstore/util:future",
"//tensorstore/util:unit",
"@com_github_pybind_pybind11//:pybind11",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
],
alwayslink = True,
Expand Down
61 changes: 44 additions & 17 deletions python/tensorstore/tensorstore_class.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#include <variant>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "python/tensorstore/array_type_caster.h"
#include "python/tensorstore/batch.h"
Expand All @@ -45,11 +44,12 @@
#include "python/tensorstore/spec.h"
#include "python/tensorstore/tensorstore_class.h"
#include "python/tensorstore/tensorstore_module_components.h"
#include "python/tensorstore/transaction.h"
#include "python/tensorstore/write_futures.h"
#include "tensorstore/array.h"
#include "tensorstore/array_storage_statistics.h"
#include "tensorstore/batch.h"
#include "tensorstore/cast.h"
#include "tensorstore/codec_spec.h"
#include "tensorstore/context.h"
#include "tensorstore/contiguous_layout.h"
#include "tensorstore/data_type.h"
Expand All @@ -58,6 +58,7 @@
#include "tensorstore/index_space/index_domain.h"
#include "tensorstore/index_space/index_transform.h"
#include "tensorstore/internal/global_initializer.h"
#include "tensorstore/internal/intrusive_ptr.h"
#include "tensorstore/internal/json/pprint_python.h"
#include "tensorstore/internal/json_binding/json_binding.h"
#include "tensorstore/internal/json_binding/std_optional.h"
Expand All @@ -67,6 +68,7 @@
#include "tensorstore/open_options.h"
#include "tensorstore/progress.h"
#include "tensorstore/rank.h"
#include "tensorstore/read_write_options.h"
#include "tensorstore/resize_options.h"
#include "tensorstore/schema.h"
#include "tensorstore/serialization/std_optional.h"
Expand All @@ -92,17 +94,24 @@ namespace py = ::pybind11;

namespace {

template <typename... ParamDef>
WriteFutures IssueCopyOrWrite(
const TensorStore<>& self,
std::variant<PythonTensorStoreObject*, ArrayArgumentPlaceholder> source) {
std::variant<PythonTensorStoreObject*, ArrayArgumentPlaceholder> source,
KeywordArgument<ParamDef>&... arg) {
if (auto* store = std::get_if<PythonTensorStoreObject*>(&source)) {
return tensorstore::Copy((**store).value, self);
CopyOptions options;
ApplyKeywordArguments<ParamDef...>(options, arg...);
return tensorstore::Copy((**store).value, self, std::move(options));
} else {
WriteOptions options;
ApplyKeywordArguments<ParamDef...>(options, arg...);
auto& source_obj = std::get_if<ArrayArgumentPlaceholder>(&source)->value;
SharedArray<const void> source_array;
ConvertToArray<const void, dynamic_rank, /*nothrow=*/false>(
source_obj, &source_array, self.dtype(), 0, self.rank());
return tensorstore::Write(std::move(source_array), self);
return tensorstore::Write(std::move(source_array), self,
std::move(options));
}
}

Expand All @@ -127,6 +136,15 @@ constexpr auto ForwardSpecRequestSetters = [](auto callback,
spec_setters::SetUnbindContext{});
};

constexpr auto ForwardWriteSetters = [](auto callback, auto... other_param) {
callback(other_param...,
// TODO(jbms): Add this option once it is supported.
#if 0
write_setters::SetCanReferenceSourceDataUntilCommit{},
#endif
write_setters::SetCanReferenceSourceDataIndefinitely{});
};

using TensorStoreCls = py::class_<PythonTensorStoreObject>;

TensorStoreCls MakeTensorStoreClass(py::module m) {
Expand Down Expand Up @@ -530,16 +548,8 @@ See also:
)",
py::kw_only(), py::arg("order") = "C", py::arg("batch") = std::nullopt);

cls.def(
"write",
[](Self& self,
std::variant<PythonTensorStoreObject*, ArrayArgumentPlaceholder>
source) {
return PythonWriteFutures(
IssueCopyOrWrite(self.value, std::move(source)),
self.reference_manager());
},
R"(
ForwardWriteSetters([&](auto... param_def) {
std::string doc = R"(
Writes to the current domain.
Example:
Expand Down Expand Up @@ -581,6 +591,10 @@ Writes to the current domain.
:python:`self.dtype`. May be an existing :py:obj:`TensorStore` or any
:py:obj:`~numpy.typing.ArrayLike`, including a scalar.
)";
AppendKeywordArgumentDocs(doc, param_def...);
doc += R"(
Returns:
Future representing the asynchronous result of the write operation.
Expand Down Expand Up @@ -723,8 +737,21 @@ writes to be read:
in any subsequent reads *using the same transaction*. The write is only
durably committed once the *transaction* is committed successfully.
)",
py::arg("source"));
)";
cls.def(
"write",
[](Self& self,
std::variant<PythonTensorStoreObject*, ArrayArgumentPlaceholder>
source,
KeywordArgument<decltype(param_def)>... kwarg) {
return PythonWriteFutures(
IssueCopyOrWrite<decltype(param_def)...>(
self.value, std::move(source), kwarg...),
self.reference_manager());
},
doc.c_str(), py::arg("source"), py::kw_only(),
MakeKeywordArgumentPyArg(param_def)...);
});

cls.def(
"resize",
Expand Down
43 changes: 43 additions & 0 deletions python/tensorstore/tensorstore_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,49 @@ Batch to use for reading any metadata required for opening.

} // namespace open_setters

namespace write_setters {

#if 0
// TODO(jbms): Add this option once it is supported.
struct SetCanReferenceSourceDataUntilCommit {
using type = bool;
static constexpr const char* name = "can_reference_source_data_until_commit";
template <typename Self>
static absl::Status Apply(Self& self, type value) {
if (value) {
self.Set(can_reference_source_data_until_commit);
}
return absl::OkStatus();
}
static constexpr const char* doc = R"(

References to the source data may be retained until the write is committed. The
source data must not be modified until the write is committed.

)";
};
#endif

struct SetCanReferenceSourceDataIndefinitely {
using type = bool;
static constexpr const char* name = "can_reference_source_data_indefinitely";
template <typename Self>
static absl::Status Apply(Self& self, type value) {
if (value) {
self.Set(can_reference_source_data_indefinitely);
}
return absl::OkStatus();
}
static constexpr const char* doc = R"(
References to the source data may be retained indefinitely, even after the write
is committed. The source data must not be modified until all references are
released.
)";
};
} // namespace write_setters

} // namespace internal_python
} // namespace tensorstore

Expand Down
17 changes: 17 additions & 0 deletions python/tensorstore/tests/tensorstore_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,3 +473,20 @@ async def test_spec_open_mode():
store = await ts.open(spec, context=context, open_mode=open_mode)
requested_spec = store.spec(**open_mode_kwargs)
assert requested_spec.open_mode == open_mode


async def test_zero_copy():
store = await ts.open(
{"driver": "zarr3", "kvstore": "memory://"},
dtype=ts.uint32,
shape=[64],
create=True,
)
arr = np.full(shape=[64], fill_value=42, dtype=np.uint32)
await store.write(arr, can_reference_source_data_indefinitely=True)
np.testing.assert_equal(42, await store.read())
# Modify arr. This violates the guarantee indicated by
# `can_reference_source_data_indefinitely=True` but is done here for testing
# purposes.
arr[...] = 43
np.testing.assert_equal(43, await store.read())
3 changes: 3 additions & 0 deletions tensorstore/driver/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,17 @@ tensorstore_cc_library(
hdrs = ["chunk.h"],
deps = [
"//tensorstore:index",
"//tensorstore:read_write_options",
"//tensorstore/index_space:index_transform",
"//tensorstore/index_space:transformed_array",
"//tensorstore/internal:arena",
"//tensorstore/internal:lock_collection",
"//tensorstore/internal:nditerable",
"//tensorstore/internal/poly",
"//tensorstore/util:future",
"//tensorstore/util:result",
"//tensorstore/util:span",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/status",
],
)
Expand Down
16 changes: 13 additions & 3 deletions tensorstore/driver/array/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -317,11 +317,21 @@ void ArrayDriver::Write(
return GetTransformedArrayNDIterable(self->data_, chunk_transform, arena);
}

WriteChunk::EndWriteResult operator()(WriteChunk::EndWrite,
IndexTransformView<> chunk_transform,
bool success, Arena* arena) {
WriteChunk::EndWriteResult operator()(
WriteChunk::EndWrite, IndexTransformView<> /*chunk_transform*/,
bool /*success*/, Arena* /*arena*/) {
return {};
}

bool operator()(
WriteChunk::WriteArray, IndexTransformView<> /*chunk_transform*/,
WriteChunk::GetWriteSourceArrayFunction /*get_source_array*/,
Arena* /*arena*/, WriteChunk::EndWriteResult& /*end_write_result*/) {
// Note: Since the backing array for this driver must remain fixed, we
// can't support zero-copy writes, and therefore this method offers no
// advantage over the generic `NDIterable` code path.
return false;
}
};
// Cancellation does not make sense since there is only a single call to
// `set_value` which occurs immediately after `set_starting`.
Expand Down
11 changes: 11 additions & 0 deletions tensorstore/driver/cast/cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,17 @@ struct WriteChunkImpl {
bool success, Arena* arena) {
return base(WriteChunk::EndWrite{}, chunk_transform, success, arena);
}

bool operator()(WriteChunk::WriteArray, IndexTransformView<> chunk_transform,
WriteChunk::GetWriteSourceArrayFunction get_source_array,
Arena* arena, WriteChunk::EndWriteResult& end_write_result) {
if (!(self->output_conversion_.flags &
DataTypeConversionFlags::kCanReinterpretCast)) {
return false;
}
return base(WriteChunk::WriteArray{}, chunk_transform, get_source_array,
arena, end_write_result);
}
};

template <typename Chunk, typename ChunkImpl>
Expand Down
29 changes: 28 additions & 1 deletion tensorstore/driver/chunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,18 @@
/// provided along with the ReadChunk/WriteChunk object.

#include <mutex>
#include <utility>

#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
#include "tensorstore/index.h"
#include "tensorstore/index_space/index_transform.h"
#include "tensorstore/index_space/transformed_array.h"
#include "tensorstore/internal/arena.h"
#include "tensorstore/internal/lock_collection.h"
#include "tensorstore/internal/nditerable.h"
#include "tensorstore/internal/poly/poly.h"
#include "tensorstore/read_write_options.h"
#include "tensorstore/util/future.h"
#include "tensorstore/util/result.h"
#include "tensorstore/util/span.h"
Expand Down Expand Up @@ -81,6 +85,7 @@ struct ReadChunk {
struct WriteChunk {
struct BeginWrite {};
struct EndWrite {};
struct WriteArray {};

struct [[nodiscard]] EndWriteResult {
/// Indicates an error recording write operation in memory. Errors
Expand All @@ -103,6 +108,13 @@ struct WriteChunk {
Future<const void> commit_future;
};

using TransformedArrayWithReferenceRestriction =
std::pair<TransformedSharedArray<const void>,
SourceDataReferenceRestriction>;

using GetWriteSourceArrayFunction =
absl::FunctionRef<Result<TransformedArrayWithReferenceRestriction>()>;

using Impl = poly::Poly<
sizeof(void*) * 2,
/*Copyable=*/true, //
Expand Down Expand Up @@ -154,7 +166,22 @@ struct WriteChunk {
/// \param arena Non-null pointer to allocation arena that may be
/// used for allocating memory.
EndWriteResult(EndWrite, IndexTransformView<> chunk_transform,
bool success, Arena* arena)>;
bool success, Arena* arena),

/// Writes a transformed array directly, if supported.
///
/// \param chunk_transform Transform with a range that is a subset of the
/// associated `WriteChunk::transform`.
/// \param get_source_array Function that may be called to obtain the
/// source array with the same input domain as `chunk_transform`.
/// \param arena Non-null pointer to allocation arena that may be
/// used for allocating memory.
/// \param end_write_result[out] Set on success.
/// \returns `true` if `WriteArray` is supported and the result has been
/// set in `end_write_result`, `false` otherwise.
bool(WriteArray, IndexTransformView<> chunk_transform,
GetWriteSourceArrayFunction get_source_array, Arena* arena,
EndWriteResult& end_write_result)>;

/// Type-erased chunk implementation. In the case of the chunks produced by
/// `ChunkCache::Write`, for example, the contained object holds a
Expand Down
6 changes: 6 additions & 0 deletions tensorstore/driver/json/driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,12 @@ struct WriteChunkImpl {
return {/*.copy_status=*/{},
/*.commit_future=*/node->transaction()->future()};
}

bool operator()(WriteChunk::WriteArray, IndexTransformView<> chunk_transform,
WriteChunk::GetWriteSourceArrayFunction get_source_array,
Arena* arena, WriteChunk::EndWriteResult& end_write_result) {
return false;
}
};

void JsonDriver::Write(
Expand Down
Loading

0 comments on commit 7570693

Please sign in to comment.