Support dtype_backend="pandas|pyarrow" configuration (#9719)
jrbourbeau committed Dec 16, 2022
1 parent 936d9f7 commit 1ac0b11
Showing 6 changed files with 139 additions and 28 deletions.
8 changes: 8 additions & 0 deletions dask/dask-schema.yaml
@@ -72,6 +72,14 @@ properties:
task when reading a parquet dataset from a REMOTE file system.
Specifying 0 will result in serial execution on the client.
dtype_backend:
enum:
- pandas
- pyarrow
description: |
The nullable dtype implementation to use. Must be either "pandas" or
"pyarrow". Default is "pandas".
array:
type: object
properties:
1 change: 1 addition & 0 deletions dask/dask.yaml
@@ -12,6 +12,7 @@ dataframe:
parquet:
metadata-task-size-local: 512 # Number of files per local metadata-processing task
metadata-task-size-remote: 16 # Number of files per remote metadata-processing task
dtype_backend: "pandas" # Dtype implementation to use

array:
backend: "numpy" # Backend array library for input IO and data creation
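For reference, the new key plugs into Dask's standard configuration machinery. A minimal sketch of how it could be toggled at runtime (not part of this commit; it assumes only the `dataframe.dtype_backend` key defined above):

import dask

# "pandas" by default, per the dask.yaml entry above
dask.config.get("dataframe.dtype_backend")

# Temporarily switch to pyarrow-backed dtypes for a block of code
with dask.config.set({"dataframe.dtype_backend": "pyarrow"}):
    ...  # read_parquet(..., use_nullable_dtypes=True) now returns pyarrow-backed dtypes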
22 changes: 17 additions & 5 deletions dask/dataframe/io/parquet/arrow.py
@@ -1643,19 +1643,31 @@ def _arrow_table_to_pandas(
_kwargs.update({"use_threads": False, "ignore_metadata": False})

if use_nullable_dtypes:
# Determine whether `pandas`- or `pyarrow`-backed dtypes should be used
if use_nullable_dtypes == "pandas":
default_types_mapper = PYARROW_NULLABLE_DTYPE_MAPPING.get
else:
# use_nullable_dtypes == "pyarrow"

def default_types_mapper(pyarrow_dtype): # type: ignore
# Special case pyarrow strings to use more feature complete dtype
# See https://github.com/pandas-dev/pandas/issues/50074
if pyarrow_dtype == pa.string():
return pd.StringDtype("pyarrow")
else:
return pd.ArrowDtype(pyarrow_dtype)

if "types_mapper" in _kwargs:
# User-provided entries take priority over PYARROW_NULLABLE_DTYPE_MAPPING
# User-provided entries take priority over default_types_mapper
types_mapper = _kwargs["types_mapper"]

def _types_mapper(pa_type):
return types_mapper(pa_type) or PYARROW_NULLABLE_DTYPE_MAPPING.get(
pa_type
)
return types_mapper(pa_type) or default_types_mapper(pa_type)

_kwargs["types_mapper"] = _types_mapper

else:
_kwargs["types_mapper"] = PYARROW_NULLABLE_DTYPE_MAPPING.get
_kwargs["types_mapper"] = default_types_mapper

return arrow_table.to_pandas(categories=categories, **_kwargs)

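For context (not part of the diff): the `types_mapper` that `_arrow_table_to_pandas` now builds is simply a callable that `pyarrow.Table.to_pandas` consults per column type. A standalone sketch of the same idea, assuming pandas 1.5+ and pyarrow are installed:

import pandas as pd
import pyarrow as pa

table = pa.table({"x": [1, 2, None], "s": ["a", None, "c"]})

def types_mapper(pyarrow_dtype):
    # Mirror the special case above: strings use pandas' pyarrow-backed
    # StringDtype, everything else is wrapped in a generic ArrowDtype.
    if pyarrow_dtype == pa.string():
        return pd.StringDtype("pyarrow")
    return pd.ArrowDtype(pyarrow_dtype)

df = table.to_pandas(types_mapper=types_mapper)
# df.dtypes: "x" -> int64[pyarrow], "s" -> pandas' pyarrow-backed string dtype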
23 changes: 21 additions & 2 deletions dask/dataframe/io/parquet/core.py
@@ -189,7 +189,7 @@ def read_parquet(
index=None,
storage_options=None,
engine="auto",
use_nullable_dtypes=False,
use_nullable_dtypes: bool = False,
calculate_divisions=None,
ignore_metadata_file=False,
metadata_task_size=None,
@@ -264,6 +264,22 @@
engine : {'auto', 'pyarrow', 'fastparquet'}, default 'auto'
Parquet library to use. Defaults to 'auto', which uses ``pyarrow`` if
it is installed, and falls back to ``fastparquet`` otherwise.
use_nullable_dtypes : {False, True}
Whether to use extension dtypes for the resulting ``DataFrame``.
``use_nullable_dtypes=True`` is only supported when ``engine="pyarrow"``.
.. note::
Use the ``dataframe.dtype_backend`` config option to select which
dtype implementation to use.
``dataframe.dtype_backend="pandas"`` (the default) will use
pandas' ``numpy``-backed nullable dtypes (e.g. ``Int64``,
``string[python]``, etc.) while ``dataframe.dtype_backend="pyarrow"``
will use ``pyarrow``-backed extension dtypes (e.g. ``int64[pyarrow]``,
``string[pyarrow]``, etc.). ``dataframe.dtype_backend="pyarrow"``
requires ``pandas`` 1.5+.
calculate_divisions : bool, default False
Whether to use min/max statistics from the footer metadata (or global
``_metadata`` file) to calculate divisions for the output DataFrame
@@ -381,6 +397,9 @@ def read_parquet(
pyarrow.parquet.ParquetDataset
"""

if use_nullable_dtypes:
use_nullable_dtypes = dask.config.get("dataframe.dtype_backend")

# "Pre-deprecation" warning for `chunksize`
if chunksize:
warnings.warn(
@@ -586,7 +605,7 @@ def read_parquet(
if "retries" not in annotations and not _is_local_fs(fs):
ctx = dask.annotate(retries=5)
else:
ctx = contextlib.nullcontext()
ctx = contextlib.nullcontext() # type: ignore

with ctx:
# Construct the output collection with from_map
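A hedged usage sketch tying the config option and the new keyword together (the path and engine choice are illustrative only, not taken from the commit):

import dask
import dask.dataframe as dd

# Default backend: pandas' numpy-backed nullable dtypes (Int64, boolean, string, ...)
ddf = dd.read_parquet("data/", engine="pyarrow", use_nullable_dtypes=True)

# Opt in to pyarrow-backed extension dtypes (int64[pyarrow], string[pyarrow], ...);
# per the docstring above, this requires pandas 1.5+.
with dask.config.set({"dataframe.dtype_backend": "pyarrow"}):
    ddf = dd.read_parquet("data/", engine="pyarrow", use_nullable_dtypes=True)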
68 changes: 48 additions & 20 deletions dask/dataframe/io/tests/test_parquet.py
@@ -15,7 +15,12 @@
import dask.dataframe as dd
import dask.multiprocessing
from dask.blockwise import Blockwise, optimize_blockwise
from dask.dataframe._compat import PANDAS_GT_110, PANDAS_GT_121, PANDAS_GT_130
from dask.dataframe._compat import (
PANDAS_GT_110,
PANDAS_GT_121,
PANDAS_GT_130,
PANDAS_GT_150,
)
from dask.dataframe.io.parquet.core import get_engine
from dask.dataframe.io.parquet.utils import _parse_pandas_metadata
from dask.dataframe.optimize import optimize_dataframe_getitem
@@ -618,17 +623,37 @@ def test_roundtrip_nullable_dtypes(tmp_path, write_engine, read_engine):


@PYARROW_MARK
def test_use_nullable_dtypes(tmp_path, engine):
@pytest.mark.parametrize(
"dtype_backend",
[
"pandas",
pytest.param(
"pyarrow",
marks=pytest.mark.skipif(
not PANDAS_GT_150, reason="Requires pyarrow-backed nullable dtypes"
),
),
],
)
def test_use_nullable_dtypes(tmp_path, engine, dtype_backend):
"""
Test reading a parquet file without pandas metadata,
but forcing use of nullable dtypes where appropriate
"""

if dtype_backend == "pandas":
dtype_extra = ""
else:
# dtype_backend == "pyarrow"
dtype_extra = "[pyarrow]"
df = pd.DataFrame(
{
"a": pd.Series([1, 2, pd.NA, 3, 4], dtype="Int64"),
"b": pd.Series([True, pd.NA, False, True, False], dtype="boolean"),
"c": pd.Series([0.1, 0.2, 0.3, pd.NA, 0.4], dtype="Float64"),
"d": pd.Series(["a", "b", "c", "d", pd.NA], dtype="string"),
"a": pd.Series([1, 2, pd.NA, 3, 4], dtype=f"Int64{dtype_extra}"),
"b": pd.Series(
[True, pd.NA, False, True, False], dtype=f"boolean{dtype_extra}"
),
"c": pd.Series([0.1, 0.2, 0.3, pd.NA, 0.4], dtype=f"Float64{dtype_extra}"),
"d": pd.Series(["a", "b", "c", "d", pd.NA], dtype=f"string{dtype_extra}"),
}
)
ddf = dd.from_pandas(df, npartitions=2)
@@ -644,21 +669,24 @@ def write_partition(df, i):
partitions = ddf.to_delayed()
dask.compute([write_partition(p, i) for i, p in enumerate(partitions)])

# Not supported by fastparquet
if engine == "fastparquet":
with pytest.raises(ValueError, match="`use_nullable_dtypes` is not supported"):
dd.read_parquet(tmp_path, engine=engine, use_nullable_dtypes=True)
with dask.config.set({"dataframe.dtype_backend": dtype_backend}):
# Not supported by fastparquet
if engine == "fastparquet":
with pytest.raises(
ValueError, match="`use_nullable_dtypes` is not supported"
):
dd.read_parquet(tmp_path, engine=engine, use_nullable_dtypes=True)

# Works in pyarrow
else:
# Doesn't round-trip by default when we aren't using nullable dtypes
with pytest.raises(AssertionError):
ddf2 = dd.read_parquet(tmp_path, engine=engine)
assert_eq(df, ddf2)

# Round trip works when we use nullable dtypes
ddf2 = dd.read_parquet(tmp_path, engine=engine, use_nullable_dtypes=True)
assert_eq(df, ddf2, check_index=False)
# Works in pyarrow
else:
# Doesn't round-trip by default when we aren't using nullable dtypes
with pytest.raises(AssertionError):
ddf2 = dd.read_parquet(tmp_path, engine=engine)
assert_eq(df, ddf2)

# Round trip works when we use nullable dtypes
ddf2 = dd.read_parquet(tmp_path, engine=engine, use_nullable_dtypes=True)
assert_eq(df, ddf2, check_index=False)


@PYARROW_MARK
45 changes: 44 additions & 1 deletion dask/tests/test_spark_compat.py
@@ -1,19 +1,22 @@
import decimal
import signal
import sys
import threading

import pytest

import dask
from dask.datasets import timeseries

dd = pytest.importorskip("dask.dataframe")
pyspark = pytest.importorskip("pyspark")
pytest.importorskip("pyarrow")
pa = pytest.importorskip("pyarrow")
pytest.importorskip("fastparquet")

import numpy as np
import pandas as pd

from dask.dataframe._compat import PANDAS_GT_150
from dask.dataframe.utils import assert_eq

pytestmark = pytest.mark.skipif(
@@ -149,3 +152,43 @@ def test_roundtrip_parquet_spark_to_dask_extension_dtypes(spark_session, tmpdir)
[pd.api.types.is_extension_array_dtype(dtype) for dtype in ddf.dtypes]
), ddf.dtypes
assert_eq(ddf, pdf, check_index=False)


@pytest.mark.skipif(not PANDAS_GT_150, reason="Requires pyarrow-backed nullable dtypes")
def test_read_decimal_dtype_pyarrow(spark_session, tmpdir):
tmpdir = str(tmpdir)
npartitions = 3
size = 6

decimal_data = [
decimal.Decimal("8093.234"),
decimal.Decimal("8094.234"),
decimal.Decimal("8095.234"),
decimal.Decimal("8096.234"),
decimal.Decimal("8097.234"),
decimal.Decimal("8098.234"),
]
pdf = pd.DataFrame(
{
"a": range(size),
"b": decimal_data,
}
)
sdf = spark_session.createDataFrame(pdf)
sdf = sdf.withColumn("b", sdf["b"].cast(pyspark.sql.types.DecimalType(7, 3)))
# We are not overwriting any data, but spark complains if the directory
# already exists (as tmpdir does) and we don't set overwrite
sdf.repartition(npartitions).write.parquet(tmpdir, mode="overwrite")

with dask.config.set({"dataframe.dtype_backend": "pyarrow"}):
ddf = dd.read_parquet(tmpdir, engine="pyarrow", use_nullable_dtypes=True)
assert ddf.b.dtype.pyarrow_dtype == pa.decimal128(7, 3)
assert ddf.b.compute().dtype.pyarrow_dtype == pa.decimal128(7, 3)
expected = pdf.astype(
{
"a": "int64[pyarrow]",
"b": pd.ArrowDtype(pa.decimal128(7, 3)),
}
)

assert_eq(ddf, expected, check_index=False)
