diff --git a/dask/dataframe/_compat.py b/dask/dataframe/_compat.py index acf625efd9a..ec307e746bd 100644 --- a/dask/dataframe/_compat.py +++ b/dask/dataframe/_compat.py @@ -13,6 +13,7 @@ PANDAS_GT_150 = PANDAS_VERSION >= Version("1.5.0") PANDAS_GT_200 = PANDAS_VERSION.major >= 2 PANDAS_GT_201 = PANDAS_VERSION.release >= (2, 0, 1) +PANDAS_GT_202 = PANDAS_VERSION.release >= (2, 0, 2) PANDAS_GT_210 = PANDAS_VERSION.release >= (2, 1, 0) import pandas.testing as tm diff --git a/dask/dataframe/io/parquet/arrow.py b/dask/dataframe/io/parquet/arrow.py index 3bfd89b2409..f50289b59f4 100644 --- a/dask/dataframe/io/parquet/arrow.py +++ b/dask/dataframe/io/parquet/arrow.py @@ -1256,7 +1256,9 @@ def _create_dd_meta(cls, dataset_info): # Make sure all categories are set to "unknown". # Cannot include index names in the `cols` argument. meta = clear_known_categories( - meta, cols=[c for c in categories if c not in meta.index.names] + meta, + cols=[c for c in categories if c not in meta.index.names], + dtype_backend=dtype_backend, ) if partition_obj: diff --git a/dask/dataframe/io/tests/test_parquet.py b/dask/dataframe/io/tests/test_parquet.py index 3e7586e2301..b70a93a9ccb 100644 --- a/dask/dataframe/io/tests/test_parquet.py +++ b/dask/dataframe/io/tests/test_parquet.py @@ -16,7 +16,7 @@ import dask.multiprocessing from dask.array.numpy_compat import _numpy_124 from dask.blockwise import Blockwise, optimize_blockwise -from dask.dataframe._compat import PANDAS_GT_150, PANDAS_GT_200 +from dask.dataframe._compat import PANDAS_GT_150, PANDAS_GT_200, PANDAS_GT_202 from dask.dataframe.io.parquet.core import get_engine from dask.dataframe.io.parquet.utils import _parse_pandas_metadata from dask.dataframe.optimize import optimize_dataframe_getitem @@ -4904,3 +4904,15 @@ def test_read_parquet_preserve_categorical_column_dtype(tmp_path): index=[0, 0], ) assert_eq(ddf, expected) + + +@PYARROW_MARK +@pytest.mark.skipif(not PANDAS_GT_200, reason="Requires pd.ArrowDtype") +def 
test_dtype_backend_categoricals(tmp_path): + df = pd.DataFrame({"a": pd.Series(["x", "y"], dtype="category"), "b": [1, 2]}) + outdir = tmp_path / "out.parquet" + df.to_parquet(outdir, engine="pyarrow") + ddf = dd.read_parquet(outdir, engine="pyarrow", dtype_backend="pyarrow") + pdf = pd.read_parquet(outdir, engine="pyarrow", dtype_backend="pyarrow") + # sort_results is only enabled on pandas >= 2.0.2 because of a pandas bug up to 2.0.1 + assert_eq(ddf, pdf, sort_results=PANDAS_GT_202) diff --git a/dask/dataframe/utils.py b/dask/dataframe/utils.py index 3e5348bfe2d..2d96e112b09 100644 --- a/dask/dataframe/utils.py +++ b/dask/dataframe/utils.py @@ -261,7 +261,7 @@ def strip_unknown_categories(x, just_drop_unknown=False): return x -def clear_known_categories(x, cols=None, index=True): +def clear_known_categories(x, cols=None, index=True, dtype_backend=None): """Set categories to be unknown. Parameters @@ -273,7 +273,15 @@ def clear_known_categories(x, cols=None, index=True): index : bool, optional If True and x is a Series or DataFrame, set the clear known categories in the index as well. + dtype_backend : string, optional + If ``"pyarrow"``, ``x`` is returned unchanged, because categoricals are + then implemented as PyArrow dictionaries """ + if dtype_backend == "pyarrow": + # Right now Categorical with PyArrow is implemented as a dictionary and the + # categorical accessor is not yet available + return x + if isinstance(x, (pd.Series, pd.DataFrame)): x = x.copy() if isinstance(x, pd.DataFrame):