From 9e69c38b0dc62cbcee3b5d8172b137c1ed6af639 Mon Sep 17 00:00:00 2001 From: Chris Jordan-Squire Date: Mon, 18 Sep 2023 16:08:35 -0400 Subject: [PATCH] GH-35095: Prevent write after close This addresses GH-35095 by adding a flag to IpcFormatWriter to track when a writer has been closed, and checking this flag before writes. --- cpp/src/arrow/ipc/read_write_test.cc | 20 ++++++++++++++++++++ cpp/src/arrow/ipc/writer.cc | 8 +++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/ipc/read_write_test.cc b/cpp/src/arrow/ipc/read_write_test.cc index 69b827b8fe78d..e90af8b789af9 100644 --- a/cpp/src/arrow/ipc/read_write_test.cc +++ b/cpp/src/arrow/ipc/read_write_test.cc @@ -1519,6 +1519,23 @@ class ReaderWriterMixin : public ExtensionTypesMixin { } } + void TestWriteAfterClose() { + // Part of GH-35095. + std::shared_ptr batch_ints; + ASSERT_OK(MakeIntRecordBatch(&batch_ints)); + + std::shared_ptr schema = batch_ints->schema(); + + WriterHelper writer_helper; + ASSERT_OK(writer_helper.Init(schema, IpcWriteOptions::Defaults())); + ASSERT_OK(writer_helper.WriteBatch(batch_ints)); + ASSERT_OK(writer_helper.Finish()); + + // Write after close raises status + auto foo = writer_helper.WriteBatch(batch_ints); + // ASSERT_RAISES(Invalid, writer_helper.WriteBatch(batch_ints)); + } + void TestWriteDifferentSchema() { // Test writing batches with a different schema than the RecordBatchWriter // was initialized with. 
@@ -1991,6 +2008,9 @@ TEST_F(TestFileFormatGenerator, DictionaryRoundTrip) { TestDictionaryRoundtrip() TEST_F(TestFileFormatGeneratorCoalesced, DictionaryRoundTrip) { TestDictionaryRoundtrip(); } +TEST_F(TestFileFormat, WriteAfterClose) { TestWriteAfterClose(); } + +TEST_F(TestStreamFormat, WriteAfterClose) { TestWriteAfterClose(); } TEST_F(TestStreamFormat, DifferentSchema) { TestWriteDifferentSchema(); } diff --git a/cpp/src/arrow/ipc/writer.cc b/cpp/src/arrow/ipc/writer.cc index 1d230601566a0..e4b49ed56464e 100644 --- a/cpp/src/arrow/ipc/writer.cc +++ b/cpp/src/arrow/ipc/writer.cc @@ -1070,6 +1070,9 @@ class ARROW_EXPORT IpcFormatWriter : public RecordBatchWriter { Status WriteRecordBatch( const RecordBatch& batch, const std::shared_ptr& custom_metadata) override { + if (closed_) { + return Status::Invalid("Destination already closed"); + } if (!batch.schema()->Equals(schema_, false /* check_metadata */)) { return Status::Invalid("Tried to write record batch with different schema"); } @@ -1101,7 +1104,9 @@ class ARROW_EXPORT IpcFormatWriter : public RecordBatchWriter { Status Close() override { RETURN_NOT_OK(CheckStarted()); - return payload_writer_->Close(); + RETURN_NOT_OK(payload_writer_->Close()); + closed_ = true; + return Status::OK(); } Status Start() { @@ -1213,6 +1218,7 @@ class ARROW_EXPORT IpcFormatWriter : public RecordBatchWriter { std::unordered_map> last_dictionaries_; bool started_ = false; + bool closed_ = false; IpcWriteOptions options_; WriteStats stats_; };