From fc893827e6bcf07ea68bedc53f649f470202e306 Mon Sep 17 00:00:00 2001
From: Micah Kornfield <emkornfield@gmail.com>
Date: Tue, 9 Jan 2024 07:56:50 +0000
Subject: [PATCH] GH-39527: [C++][Parquet] Validate page sizes before truncating to int32

Be defensive instead of writing invalid data.
---
 cpp/src/parquet/column_writer.cc      | 17 ++++++++++--
 cpp/src/parquet/column_writer_test.cc | 40 +++++++++++++++++++++++++++
 2 files changed, 55 insertions(+), 2 deletions(-)

diff --git a/cpp/src/parquet/column_writer.cc b/cpp/src/parquet/column_writer.cc
index 12b2837fbfd1e..04f2f9b960111 100644
--- a/cpp/src/parquet/column_writer.cc
+++ b/cpp/src/parquet/column_writer.cc
@@ -271,7 +271,10 @@ class SerializedPageWriter : public PageWriter {
   }
 
   int64_t WriteDictionaryPage(const DictionaryPage& page) override {
-    int64_t uncompressed_size = page.size();
+    int64_t uncompressed_size = page.buffer()->size();
+    if (uncompressed_size > std::numeric_limits<int32_t>::max()) {
+      throw ParquetException("Uncompressed page size overflows to INT32_MAX.");
+    }
     std::shared_ptr<Buffer> compressed_data;
     if (has_compressor()) {
       auto buffer = std::static_pointer_cast<ResizableBuffer>(
@@ -288,6 +291,9 @@ class SerializedPageWriter : public PageWriter {
     dict_page_header.__set_is_sorted(page.is_sorted());
 
     const uint8_t* output_data_buffer = compressed_data->data();
+    if (compressed_data->size() > std::numeric_limits<int32_t>::max()) {
+      throw ParquetException("Compressed page size overflows to INT32_MAX.");
+    }
     int32_t output_data_len = static_cast<int32_t>(compressed_data->size());
 
     if (data_encryptor_.get()) {
@@ -371,7 +377,7 @@ class SerializedPageWriter : public PageWriter {
     const int64_t uncompressed_size = page.uncompressed_size();
     std::shared_ptr<Buffer> compressed_data = page.buffer();
     const uint8_t* output_data_buffer = compressed_data->data();
-    int32_t output_data_len = static_cast<int32_t>(compressed_data->size());
+    int64_t output_data_len = compressed_data->size();
 
     if (data_encryptor_.get()) {
       PARQUET_THROW_NOT_OK(encryption_buffer_->Resize(
@@ -383,7 +389,14 @@ class 
SerializedPageWriter : public PageWriter {
     }
 
     format::PageHeader page_header;
+
+    if (uncompressed_size > std::numeric_limits<int32_t>::max()) {
+      throw ParquetException("Uncompressed page size overflows to INT32_MAX.");
+    }
     page_header.__set_uncompressed_page_size(static_cast<int32_t>(uncompressed_size));
+    if (output_data_len > std::numeric_limits<int32_t>::max()) {
+      throw ParquetException("Compressed page size overflows to INT32_MAX.");
+    }
     page_header.__set_compressed_page_size(static_cast<int32_t>(output_data_len));
 
     if (page_checksum_verification_) {
diff --git a/cpp/src/parquet/column_writer_test.cc b/cpp/src/parquet/column_writer_test.cc
index 59fc848d7fd57..5d8eacf58802c 100644
--- a/cpp/src/parquet/column_writer_test.cc
+++ b/cpp/src/parquet/column_writer_test.cc
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+#include <limits>
 #include <memory>
 #include <utility>
 
@@ -36,6 +37,7 @@
 #include "parquet/test_util.h"
 #include "parquet/thrift_internal.h"
 #include "parquet/types.h"
+#include "parquet/column_page.h"
 
 namespace bit_util = arrow::bit_util;
 
@@ -889,6 +891,44 @@ TEST_F(TestByteArrayValuesWriter, CheckDefaultStats) {
   ASSERT_TRUE(this->metadata_is_stats_set());
 }
 
+TEST(TestPageWriter, ThrowsOnPagesTooLarge) {
+  NodePtr item = schema::Int32("item");  // optional item
+  NodePtr list(GroupNode::Make("b", Repetition::REPEATED, {item}, ConvertedType::LIST));
+  NodePtr bag(GroupNode::Make("bag", Repetition::OPTIONAL, {list}));  // optional list
+  std::vector<NodePtr> fields = {bag};
+  NodePtr root = GroupNode::Make("schema", Repetition::REPEATED, fields);
+
+  SchemaDescriptor schema;
+  schema.Init(root);
+
+  auto sink = CreateOutputStream();
+  auto props = WriterProperties::Builder().build();
+
+  auto metadata = ColumnChunkMetaDataBuilder::Make(props, schema.Column(0));
+  std::unique_ptr<PageWriter> pager =
+      PageWriter::Open(sink, Compression::UNCOMPRESSED,
+                       Codec::UseDefaultCompressionLevel(), metadata.get());
+
+  uint8_t data;
+  std::shared_ptr<Buffer> buffer =
+      std::make_shared<Buffer>(&data, std::numeric_limits<int32_t>::max() + int64_t{1});
+  DataPageV1 over_compressed_limit(buffer, /*num_values=*/100, Encoding::BIT_PACKED,
+                                   Encoding::BIT_PACKED, Encoding::BIT_PACKED,
+                                   /*uncompressed_size=*/100);
+  EXPECT_THROW(pager->WriteDataPage(over_compressed_limit), ParquetException);
+  DictionaryPage dictionary_over_compressed_limit(buffer, /*num_values=*/100,
+                                                  Encoding::PLAIN);
+  EXPECT_THROW(pager->WriteDictionaryPage(dictionary_over_compressed_limit),
+               ParquetException);
+
+  buffer = std::make_shared<Buffer>(&data, 1);
+  DataPageV1 over_uncompressed_limit(
+      buffer, /*num_values=*/100, Encoding::BIT_PACKED, Encoding::BIT_PACKED,
+      Encoding::BIT_PACKED,
+      /*uncompressed_size=*/std::numeric_limits<int32_t>::max() + int64_t{1});
+  EXPECT_THROW(pager->WriteDataPage(over_uncompressed_limit), ParquetException);
+}
+
 TEST(TestColumnWriter, RepeatedListsUpdateSpacedBug) {
   // In ARROW-3930 we discovered a bug when writing from Arrow when we had data
   // that looks like this: