Skip to content

Commit

Permalink
Fix ArrowEngine bug in use of clear_known_categories (#6887)
Browse files Browse the repository at this point in the history
* fix bug calling clear_known_categories

* handle index in general - not just unnamed

* remove unused NONE_LABEL

* no need to target index explicitly
  • Loading branch information
rjzamora committed Nov 25, 2020
1 parent 25bdd20 commit 2bfc965
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
8 changes: 6 additions & 2 deletions dask/dataframe/io/parquet/arrow.py
Expand Up @@ -23,7 +23,6 @@
preserve_ind_supported = pa.__version__ >= LooseVersion("0.15.0")
schema_field_supported = pa.__version__ >= LooseVersion("0.15.0")


#
# Private Helper Functions
#
Expand Down Expand Up @@ -314,7 +313,12 @@ def _generate_dd_meta(schema, index, categories, partition_info):

index_cols = index or ()
meta = _meta_from_dtypes(all_columns, dtypes, index_cols, column_index_names)
meta = clear_known_categories(meta, cols=categories)
if categories:
# Make sure all categories are set to "unknown".
# Cannot include index names in the `cols` argument.
meta = clear_known_categories(
meta, cols=[c for c in categories if c not in meta.index.names]
)

if partition_obj:
for partition in partition_obj:
Expand Down
21 changes: 21 additions & 0 deletions dask/dataframe/io/tests/test_parquet.py
Expand Up @@ -914,6 +914,27 @@ def test_categories(tmpdir, engine):
ddf2 = dd.read_parquet(fn, categories=["foo"], engine=engine)


def test_categories_unnamed_index(tmpdir, engine):
    """Round-trip a frame with an unnamed categorical index through parquet.

    Regression test for https://github.com/dask/dask/issues/6885

    The frame is written with ``to_parquet`` and read back with
    ``read_parquet`` using the same ``engine``; only the index is compared,
    and divisions are ignored in that comparison.
    """
    if engine == "pyarrow":
        # Older PyArrow releases are not supported by this test.
        if pa.__version__ < LooseVersion("0.15.0"):
            pytest.skip("PyArrow>=0.15 Required.")

    path = str(tmpdir)

    # The index is given no name on purpose; "B" is the column to categorize.
    frame = pd.DataFrame(
        {"A": [1, 2, 3], "B": ["a", "a", "b"]},
        index=["x", "y", "y"],
    )
    expected = dd.from_pandas(frame, npartitions=1).categorize(columns=["B"])

    expected.to_parquet(path, engine=engine)
    roundtrip = dd.read_parquet(path, engine=engine)

    assert_eq(expected.index, roundtrip.index, check_divisions=False)


def test_empty_partition(tmpdir, engine):
fn = str(tmpdir)
df = pd.DataFrame({"a": range(10), "b": range(10)})
Expand Down

0 comments on commit 2bfc965

Please sign in to comment.