From 9d56c5d187541810d5fc1e176e67006feb888fed Mon Sep 17 00:00:00 2001
From: "Jorge C. Leitao"
Date: Fri, 19 Aug 2022 05:55:36 +0000
Subject: [PATCH] Improved test

---
 arrow-parquet-integration-testing/main.py | 26 ++++++---
 .../main_spark.py                         | 56 +++++++++++--------
 2 files changed, 52 insertions(+), 30 deletions(-)

diff --git a/arrow-parquet-integration-testing/main.py b/arrow-parquet-integration-testing/main.py
index 2ca76d99b1f..a880af617d8 100644
--- a/arrow-parquet-integration-testing/main.py
+++ b/arrow-parquet-integration-testing/main.py
@@ -10,7 +10,12 @@ def get_file_path(file: str):
 
 
 def _prepare(
-    file: str, version: str, compression: str, encoding_utf8: str, projection=None
+    file: str,
+    version: str,
+    compression: str,
+    encoding_utf8: str,
+    encoding_int: str,
+    projection=None,
 ):
     write = f"{file}.parquet"
 
@@ -24,8 +29,10 @@ def _prepare(
         write,
         "--version",
         version,
-        "--encoding-int",
+        "--encoding-utf8",
         encoding_utf8,
+        "--encoding-int",
+        encoding_int,
         "--compression",
         compression,
     ]
@@ -38,7 +45,7 @@ def _prepare(
     return write
 
 
-def _expected(file: str):
+def _expected(file: str) -> pyarrow.Table:
     return pyarrow.ipc.RecordBatchFileReader(get_file_path(file)).read_all()
 
 
@@ -75,16 +82,19 @@ def variations():
                 # "generated_custom_metadata",
             ]:
                 # pyarrow does not support decoding "delta"-encoded values.
-                for encoding in ["plain", "delta"]:
-                #for encoding in ["plain"]:
+                for encoding_int in ["plain", "delta"]:
+                    if encoding_int == "delta" and file in {"generated_primitive", "generated_null"}:
+                        # see https://issues.apache.org/jira/browse/ARROW-17465
+                        continue
+
                     for compression in ["uncompressed", "zstd", "snappy"]:
-                        yield (version, file, compression, encoding)
+                        yield (version, file, compression, "plain", encoding_int)
 
 
 if __name__ == "__main__":
-    for (version, file, compression, encoding_utf8) in variations():
+    for (version, file, compression, encoding_utf8, encoding_int) in variations():
         expected = _expected(file)
-        path = _prepare(file, version, compression, encoding_utf8)
+        path = _prepare(file, version, compression, encoding_utf8, encoding_int)
 
         table = pq.read_table(path)
         os.remove(path)
diff --git a/arrow-parquet-integration-testing/main_spark.py b/arrow-parquet-integration-testing/main_spark.py
index c33a1aa63ad..e29655fb46b 100644
--- a/arrow-parquet-integration-testing/main_spark.py
+++ b/arrow-parquet-integration-testing/main_spark.py
@@ -7,7 +7,13 @@
 from main import _prepare, _expected
 
 
-def test(file: str, version: str, column, compression: str, encoding: str):
+def test(
+    file: str,
+    version: str,
+    column: str,
+    compression: str,
+    encoding: str,
+):
     """
     Tests that pyspark can read a parquet file written by arrow2.
 
@@ -16,13 +22,13 @@ def test(file: str, version: str, column, compression: str, encoding: str):
     In pyspark: read (written) parquet to Python
     assert that they are equal
     """
-    # write parquet
-    path = _prepare(file, version, compression, encoding, [column[1]])
-
     # read IPC to Python
     expected = _expected(file)
-    expected = next(c for i, c in enumerate(expected) if i == column[1])
-    expected = expected.combine_chunks().tolist()
+    column_index = next(i for i, c in enumerate(expected.column_names) if c == column)
+    expected = expected[column].combine_chunks().tolist()
+
+    # write parquet
+    path = _prepare(file, version, compression, encoding, encoding, [column_index])
 
     # read parquet to Python
     spark = pyspark.sql.SparkSession.builder.config(
@@ -31,28 +37,34 @@
         "false",
     ).getOrCreate()
 
-    result = spark.read.parquet(path).select(column[0]).collect()
-    result = [r[column[0]] for r in result]
+    result = spark.read.parquet(path).select(column).collect()
+    result = [r[column] for r in result]
 
     os.remove(path)
 
     # assert equality
     assert expected == result
 
-test("generated_primitive", "2", ("utf8_nullable", 24), "uncompressed", "delta")
-test("generated_primitive", "2", ("utf8_nullable", 24), "snappy", "delta")
+test("generated_null", "2", "f1", "uncompressed", "delta")
+
+test("generated_primitive", "2", "utf8_nullable", "uncompressed", "delta")
+test("generated_primitive", "2", "utf8_nullable", "snappy", "delta")
+test("generated_primitive", "2", "int32_nullable", "uncompressed", "delta")
+test("generated_primitive", "2", "int32_nullable", "snappy", "delta")
+test("generated_primitive", "2", "int16_nullable", "uncompressed", "delta")
+test("generated_primitive", "2", "int16_nullable", "snappy", "delta")
 
-test("generated_dictionary", "1", ("dict0", 0), "uncompressed", "plain")
-test("generated_dictionary", "1", ("dict0", 0), "snappy", "plain")
-test("generated_dictionary", "2", ("dict0", 0), "uncompressed", "plain")
-test("generated_dictionary", "2", ("dict0", 0), "snappy", "plain")
+test("generated_dictionary", "1", "dict0", "uncompressed", "plain")
+test("generated_dictionary", "1", "dict0", "snappy", "plain")
+test("generated_dictionary", "2", "dict0", "uncompressed", "plain")
+test("generated_dictionary", "2", "dict0", "snappy", "plain")
 
-test("generated_dictionary", "1", ("dict1", 1), "uncompressed", "plain")
-test("generated_dictionary", "1", ("dict1", 1), "snappy", "plain")
-test("generated_dictionary", "2", ("dict1", 1), "uncompressed", "plain")
-test("generated_dictionary", "2", ("dict1", 1), "snappy", "plain")
+test("generated_dictionary", "1", "dict1", "uncompressed", "plain")
+test("generated_dictionary", "1", "dict1", "snappy", "plain")
+test("generated_dictionary", "2", "dict1", "uncompressed", "plain")
+test("generated_dictionary", "2", "dict1", "snappy", "plain")
 
-test("generated_dictionary", "1", ("dict2", 2), "uncompressed", "plain")
-test("generated_dictionary", "1", ("dict2", 2), "snappy", "plain")
-test("generated_dictionary", "2", ("dict2", 2), "uncompressed", "plain")
-test("generated_dictionary", "2", ("dict2", 2), "snappy", "plain")
+test("generated_dictionary", "1", "dict2", "uncompressed", "plain")
+test("generated_dictionary", "1", "dict2", "snappy", "plain")
+test("generated_dictionary", "2", "dict2", "uncompressed", "plain")
+test("generated_dictionary", "2", "dict2", "snappy", "plain")
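
Context for readers of this patch: both scripts perform a round trip. main.py
reads an arrow-testing file via IPC (the "expected" side), has arrow2 write it
to parquet via _prepare, and reads it back with pyarrow; main_spark.py does the
same but reads back with pyspark. Below is a minimal, self-contained sketch of
that round trip using pyarrow for both the write and the read; the in-memory
table is a hypothetical stand-in for the IPC file, and arrow2 is not involved:

    import os

    import pyarrow
    import pyarrow.parquet as pq

    # stand-in for the table that _expected() reads from the IPC file
    expected = pyarrow.table(
        {"int32_nullable": [1, None, 3], "utf8_nullable": ["a", None, "c"]}
    )

    # the real test writes this file with arrow2 (_prepare), not pyarrow
    pq.write_table(expected, "roundtrip.parquet", compression="snappy")
    result = pq.read_table("roundtrip.parquet")
    os.remove("roundtrip.parquet")

    # pyarrow Tables compare by value, mirroring the asserts in the scripts
    assert expected == result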