Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Improved test
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao committed Aug 19, 2022
1 parent 99a91b2 commit 9d56c5d
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 30 deletions.
26 changes: 18 additions & 8 deletions arrow-parquet-integration-testing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ def get_file_path(file: str):


def _prepare(
file: str, version: str, compression: str, encoding_utf8: str, projection=None
file: str,
version: str,
compression: str,
encoding_utf8: str,
encoding_int: str,
projection=None,
):
write = f"{file}.parquet"

Expand All @@ -24,8 +29,10 @@ def _prepare(
write,
"--version",
version,
"--encoding-int",
"--encoding-utf8",
encoding_utf8,
"--encoding-int",
encoding_int,
"--compression",
compression,
]
Expand All @@ -38,7 +45,7 @@ def _prepare(
return write


def _expected(file: str):
def _expected(file: str) -> pyarrow.Table:
    """Load the reference data for *file* from its Arrow IPC "gold" file.

    Resolves the on-disk path via ``get_file_path`` and returns every record
    batch in the IPC file merged into a single ``pyarrow.Table``.
    """
    ipc_path = get_file_path(file)
    reader = pyarrow.ipc.RecordBatchFileReader(ipc_path)
    return reader.read_all()


Expand Down Expand Up @@ -75,16 +82,19 @@ def variations():
# "generated_custom_metadata",
]:
# pyarrow does not support decoding "delta"-encoded values.
for encoding in ["plain", "delta"]:
#for encoding in ["plain"]:
for encoding_int in ["plain", "delta"]:
if encoding_int == "delta" and file in {"generated_primitive", "generated_null"}:
# see https://issues.apache.org/jira/browse/ARROW-17465
continue

for compression in ["uncompressed", "zstd", "snappy"]:
yield (version, file, compression, encoding)
yield (version, file, compression, "plain", encoding_int)


if __name__ == "__main__":
for (version, file, compression, encoding_utf8) in variations():
for (version, file, compression, encoding_utf8, encoding_int) in variations():
expected = _expected(file)
path = _prepare(file, version, compression, encoding_utf8)
path = _prepare(file, version, compression, encoding_utf8, encoding_int)

table = pq.read_table(path)
os.remove(path)
Expand Down
56 changes: 34 additions & 22 deletions arrow-parquet-integration-testing/main_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
from main import _prepare, _expected


def test(file: str, version: str, column, compression: str, encoding: str):
def test(
file: str,
version: str,
column: str,
compression: str,
encoding: str,
):
"""
Tests that pyspark can read a parquet file written by arrow2.
Expand All @@ -16,13 +22,13 @@ def test(file: str, version: str, column, compression: str, encoding: str):
In pyspark: read (written) parquet to Python
assert that they are equal
"""
# write parquet
path = _prepare(file, version, compression, encoding, [column[1]])

# read IPC to Python
expected = _expected(file)
expected = next(c for i, c in enumerate(expected) if i == column[1])
expected = expected.combine_chunks().tolist()
column_index = next(i for i, c in enumerate(expected.column_names) if c == column)
expected = expected[column].combine_chunks().tolist()

# write parquet
path = _prepare(file, version, compression, encoding, encoding, [column_index])

# read parquet to Python
spark = pyspark.sql.SparkSession.builder.config(
Expand All @@ -31,28 +37,34 @@ def test(file: str, version: str, column, compression: str, encoding: str):
"false",
).getOrCreate()

result = spark.read.parquet(path).select(column[0]).collect()
result = [r[column[0]] for r in result]
result = spark.read.parquet(path).select(column).collect()
result = [r[column] for r in result]
os.remove(path)

# assert equality
assert expected == result


test("generated_primitive", "2", ("utf8_nullable", 24), "uncompressed", "delta")
test("generated_primitive", "2", ("utf8_nullable", 24), "snappy", "delta")
test("generated_null", "2", "f1", "uncompressed", "delta")

test("generated_primitive", "2", "utf8_nullable", "uncompressed", "delta")
test("generated_primitive", "2", "utf8_nullable", "snappy", "delta")
test("generated_primitive", "2", "int32_nullable", "uncompressed", "delta")
test("generated_primitive", "2", "int32_nullable", "snappy", "delta")
test("generated_primitive", "2", "int16_nullable", "uncompressed", "delta")
test("generated_primitive", "2", "int16_nullable", "snappy", "delta")

test("generated_dictionary", "1", ("dict0", 0), "uncompressed", "plain")
test("generated_dictionary", "1", ("dict0", 0), "snappy", "plain")
test("generated_dictionary", "2", ("dict0", 0), "uncompressed", "plain")
test("generated_dictionary", "2", ("dict0", 0), "snappy", "plain")
test("generated_dictionary", "1", "dict0", "uncompressed", "plain")
test("generated_dictionary", "1", "dict0", "snappy", "plain")
test("generated_dictionary", "2", "dict0", "uncompressed", "plain")
test("generated_dictionary", "2", "dict0", "snappy", "plain")

test("generated_dictionary", "1", ("dict1", 1), "uncompressed", "plain")
test("generated_dictionary", "1", ("dict1", 1), "snappy", "plain")
test("generated_dictionary", "2", ("dict1", 1), "uncompressed", "plain")
test("generated_dictionary", "2", ("dict1", 1), "snappy", "plain")
test("generated_dictionary", "1", "dict1", "uncompressed", "plain")
test("generated_dictionary", "1", "dict1", "snappy", "plain")
test("generated_dictionary", "2", "dict1", "uncompressed", "plain")
test("generated_dictionary", "2", "dict1", "snappy", "plain")

test("generated_dictionary", "1", ("dict2", 2), "uncompressed", "plain")
test("generated_dictionary", "1", ("dict2", 2), "snappy", "plain")
test("generated_dictionary", "2", ("dict2", 2), "uncompressed", "plain")
test("generated_dictionary", "2", ("dict2", 2), "snappy", "plain")
test("generated_dictionary", "1", "dict2", "uncompressed", "plain")
test("generated_dictionary", "1", "dict2", "snappy", "plain")
test("generated_dictionary", "2", "dict2", "uncompressed", "plain")
test("generated_dictionary", "2", "dict2", "snappy", "plain")

0 comments on commit 9d56c5d

Please sign in to comment.