Skip to content

Commit

Permalink
apacheGH-39163: [C++] Add missing data copy in StreamDecoder::Consume…
Browse files Browse the repository at this point in the history
…(data)

We need to copy data for metadata message. Because it may be used in
subsequent `Consume(data)` calls. We can't assume that the given
`data` is still valid in subsequent `Consume(data)`.

We also need to copy buffered `data` because it's used in subsequent
`Consume(data)` calls.
  • Loading branch information
kou committed Dec 10, 2023
1 parent e3c8187 commit 9a18aec
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 19 deletions.
37 changes: 27 additions & 10 deletions cpp/src/arrow/ipc/message.cc
Original file line number Diff line number Diff line change
Expand Up @@ -626,10 +626,24 @@ class MessageDecoder::MessageDecoderImpl {
RETURN_NOT_OK(ConsumeMetadataLengthData(data, next_required_size_));
break;
case State::METADATA: {
auto buffer = std::make_shared<Buffer>(data, next_required_size_);
// We need to copy metadata because it's used in
// ConsumeBody(). ConsumeBody() may be called from another
// ConsumeData(). We can't assume that the given data for
// the current ConsumeData() call is still valid in the
// next ConsumeData() call. So we need to copy metadata
// here.
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> buffer,
AllocateBuffer(next_required_size_, pool_));
memcpy(buffer->mutable_data(), data, next_required_size_);
RETURN_NOT_OK(ConsumeMetadataBuffer(buffer));
} break;
case State::BODY: {
// We don't need to copy the given data for body because
// we can assume that a decoded record batch should be
// valid only in a listener_->OnMessageDecoded() call. If
// the passed message is needed to be valid after the
// call, it's a listener_'s responsibility. The listener_
// may copy the data for it.
auto buffer = std::make_shared<Buffer>(data, next_required_size_);
RETURN_NOT_OK(ConsumeBodyBuffer(buffer));
} break;
Expand All @@ -645,7 +659,12 @@ class MessageDecoder::MessageDecoderImpl {
return Status::OK();
}

chunks_.push_back(std::make_shared<Buffer>(data, size));
// We need to copy unused data because the given data for the
// current ConsumeData() call may be invalid in the next
// ConsumeData() call.
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> chunk, AllocateBuffer(size, pool_));
memcpy(chunk->mutable_data(), data, size);
chunks_.push_back(std::move(chunk));
buffered_size_ += size;
return ConsumeChunks();
}
Expand Down Expand Up @@ -830,8 +849,7 @@ class MessageDecoder::MessageDecoderImpl {
}
buffered_size_ -= next_required_size_;
} else {
ARROW_ASSIGN_OR_RAISE(auto metadata, AllocateBuffer(next_required_size_, pool_));
metadata_ = std::shared_ptr<Buffer>(metadata.release());
ARROW_ASSIGN_OR_RAISE(metadata_, AllocateBuffer(next_required_size_, pool_));
RETURN_NOT_OK(ConsumeDataChunks(next_required_size_, metadata_->mutable_data()));
}
return ConsumeMetadata();
Expand All @@ -846,9 +864,8 @@ class MessageDecoder::MessageDecoderImpl {
next_required_size_ = skip_body_ ? 0 : body_length;
RETURN_NOT_OK(listener_->OnBody());
if (next_required_size_ == 0) {
ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(0, pool_));
std::shared_ptr<Buffer> shared_body(body.release());
return ConsumeBody(&shared_body);
auto body = std::make_shared<Buffer>(nullptr, 0);
return ConsumeBody(&body);
} else {
return Status::OK();
}
Expand All @@ -872,10 +889,10 @@ class MessageDecoder::MessageDecoderImpl {
buffered_size_ -= used_size;
return Status::OK();
} else {
ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(next_required_size_, pool_));
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> body,
AllocateBuffer(next_required_size_, pool_));
RETURN_NOT_OK(ConsumeDataChunks(next_required_size_, body->mutable_data()));
std::shared_ptr<Buffer> shared_body(body.release());
return ConsumeBody(&shared_body);
return ConsumeBody(&body);
}
}

Expand Down
11 changes: 8 additions & 3 deletions cpp/src/arrow/ipc/read_write_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1334,7 +1334,7 @@ struct StreamDecoderWriterHelper : public StreamWriterHelper {
Status ReadBatches(const IpcReadOptions& options, RecordBatchVector* out_batches,
ReadStats* out_stats = nullptr,
MetadataVector* out_metadata_list = nullptr) override {
auto listener = std::make_shared<CollectListener>();
auto listener = std::make_shared<CollectListener>(true);
StreamDecoder decoder(listener, options);
RETURN_NOT_OK(DoConsume(&decoder));
*out_batches = listener->record_batches();
Expand All @@ -1358,7 +1358,10 @@ struct StreamDecoderWriterHelper : public StreamWriterHelper {

struct StreamDecoderDataWriterHelper : public StreamDecoderWriterHelper {
Status DoConsume(StreamDecoder* decoder) override {
return decoder->Consume(buffer_->data(), buffer_->size());
// This data is valid only in this function.
ARROW_ASSIGN_OR_RAISE(auto temporary_buffer,
Buffer::Copy(buffer_, arrow::default_cpu_memory_manager()));
return decoder->Consume(temporary_buffer->data(), temporary_buffer->size());
}
};

Expand All @@ -1369,7 +1372,9 @@ struct StreamDecoderBufferWriterHelper : public StreamDecoderWriterHelper {
struct StreamDecoderSmallChunksWriterHelper : public StreamDecoderWriterHelper {
Status DoConsume(StreamDecoder* decoder) override {
for (int64_t offset = 0; offset < buffer_->size() - 1; ++offset) {
RETURN_NOT_OK(decoder->Consume(buffer_->data() + offset, 1));
// This data is valid only in this block.
ARROW_ASSIGN_OR_RAISE(auto temporary_buffer, buffer_->CopySlice(offset, 1));
RETURN_NOT_OK(decoder->Consume(temporary_buffer->data(), temporary_buffer->size()));
}
return Status::OK();
}
Expand Down
33 changes: 33 additions & 0 deletions cpp/src/arrow/ipc/reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2052,6 +2052,39 @@ Status Listener::OnRecordBatchWithMetadataDecoded(
return OnRecordBatchDecoded(std::move(record_batch_with_metadata.batch));
}

namespace {
Status CopyArrayData(std::shared_ptr<ArrayData> data) {
auto& buffers = data->buffers;
for (size_t i = 0; i < buffers.size(); ++i) {
auto& buffer = buffers[i];
if (!buffer) {
continue;
}
ARROW_ASSIGN_OR_RAISE(buffers[i], Buffer::Copy(buffer, buffer->memory_manager()));
}
for (auto child_data : data->child_data) {
ARROW_RETURN_NOT_OK(CopyArrayData(child_data));
}
if (data->dictionary) {
ARROW_RETURN_NOT_OK(CopyArrayData(data->dictionary));
}
return Status::OK();
}
}; // namespace

Status CollectListener::OnRecordBatchWithMetadataDecoded(
RecordBatchWithMetadata record_batch_with_metadata) {
auto record_batch = std::move(record_batch_with_metadata.batch);
if (copy_record_batch_) {
for (auto column_data : record_batch->column_data()) {
ARROW_RETURN_NOT_OK(CopyArrayData(column_data));
}
}
record_batches_.push_back(std::move(record_batch));
metadatas_.push_back(std::move(record_batch_with_metadata.custom_metadata));
return Status::OK();
}

class StreamDecoder::StreamDecoderImpl : public StreamDecoderInternal {
public:
explicit StreamDecoderImpl(std::shared_ptr<Listener> listener, IpcReadOptions options)
Expand Down
14 changes: 8 additions & 6 deletions cpp/src/arrow/ipc/reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,12 @@ class ARROW_EXPORT Listener {
/// \since 0.17.0
class ARROW_EXPORT CollectListener : public Listener {
public:
CollectListener() : schema_(), filtered_schema_(), record_batches_(), metadatas_() {}
CollectListener(bool copy_record_batch = false)
: copy_record_batch_(copy_record_batch),
schema_(),
filtered_schema_(),
record_batches_(),
metadatas_() {}
virtual ~CollectListener() = default;

Status OnSchemaDecoded(std::shared_ptr<Schema> schema,
Expand All @@ -328,11 +333,7 @@ class ARROW_EXPORT CollectListener : public Listener {
}

Status OnRecordBatchWithMetadataDecoded(
RecordBatchWithMetadata record_batch_with_metadata) override {
record_batches_.push_back(std::move(record_batch_with_metadata.batch));
metadatas_.push_back(std::move(record_batch_with_metadata.custom_metadata));
return Status::OK();
}
RecordBatchWithMetadata record_batch_with_metadata) override;

/// \return the decoded schema
std::shared_ptr<Schema> schema() const { return schema_; }
Expand Down Expand Up @@ -375,6 +376,7 @@ class ARROW_EXPORT CollectListener : public Listener {
}

private:
bool copy_record_batch_;
std::shared_ptr<Schema> schema_;
std::shared_ptr<Schema> filtered_schema_;
std::vector<std::shared_ptr<RecordBatch>> record_batches_;
Expand Down

0 comments on commit 9a18aec

Please sign in to comment.