# Patch overview (text ahead of the first "diff --git" header is not part of any hunk).
#
# dask/dataframe/__init__.py
#   The body of the `if _dask_expr_enabled():` branch -- the `from dask_expr import (...)`
#   block, the dask.dataframe helper imports, and the `raise_not_implemented_error`
#   stubs (_Frame, Aggregation, read_fwf, melt) -- moves inside a `try:`.
#   On ImportError the module imports itself as `dd` and calls `importlib.reload(dd)`.
#
# dask/dataframe/io/tests/test_parquet.py
#   Drops the `skipif(DASK_EXPR_ENABLED, reason="circular import")` marker from
#   `test_to_parquet_with_get`, so the test also runs with dask-expr enabled.
#
# NOTE(review): the `except ImportError` covers every import inside the `try`, not only
# the circular-import case. The body of `_dask_expr_enabled()` is not shown in this
# patch; confirm it guarantees dask_expr is importable, otherwise a genuinely missing
# dependency would also take the reload path.
diff --git a/dask/dataframe/__init__.py b/dask/dataframe/__init__.py
index b247c4b763e..0f64a152d5b 100644
--- a/dask/dataframe/__init__.py
+++ b/dask/dataframe/__init__.py
@@ -14,67 +14,82 @@ def _dask_expr_enabled() -> bool:
 
 
 if _dask_expr_enabled():
-    from dask_expr import (
-        DataFrame,
-        Index,
-        Series,
-        concat,
-        from_array,
-        from_dask_array,
-        from_dask_dataframe,
-        from_delayed,
-        from_dict,
-        from_graph,
-        from_map,
-        from_pandas,
-        get_dummies,
-        isna,
-        map_overlap,
-        map_partitions,
-        merge,
-        merge_asof,
-        pivot_table,
-        read_csv,
-        read_hdf,
-        read_json,
-        read_orc,
-        read_parquet,
-        read_sql,
-        read_sql_query,
-        read_sql_table,
-        read_table,
-        repartition,
-        to_bag,
-        to_csv,
-        to_datetime,
-        to_hdf,
-        to_json,
-        to_numeric,
-        to_orc,
-        to_parquet,
-        to_records,
-        to_sql,
-        to_timedelta,
-    )
+    try:
+        from dask_expr import (
+            DataFrame,
+            Index,
+            Series,
+            concat,
+            from_array,
+            from_dask_array,
+            from_dask_dataframe,
+            from_delayed,
+            from_dict,
+            from_graph,
+            from_map,
+            from_pandas,
+            get_dummies,
+            isna,
+            map_overlap,
+            map_partitions,
+            merge,
+            merge_asof,
+            pivot_table,
+            read_csv,
+            read_hdf,
+            read_json,
+            read_orc,
+            read_parquet,
+            read_sql,
+            read_sql_query,
+            read_sql_table,
+            read_table,
+            repartition,
+            to_bag,
+            to_csv,
+            to_datetime,
+            to_hdf,
+            to_json,
+            to_numeric,
+            to_orc,
+            to_parquet,
+            to_records,
+            to_sql,
+            to_timedelta,
+        )
+
+        import dask.dataframe._pyarrow_compat
+        from dask.base import compute
+        from dask.dataframe import backends, dispatch
+        from dask.dataframe.io import demo
+        from dask.dataframe.utils import assert_eq
+
+        def raise_not_implemented_error(attr_name):
+            def inner_func(*args, **kwargs):
+                raise NotImplementedError(
+                    f"Function {attr_name} is not implemented for dask-expr."
+                )
+
+            return inner_func
 
-    import dask.dataframe._pyarrow_compat
-    from dask.base import compute
-    from dask.dataframe import backends, dispatch
-    from dask.dataframe.io import demo
-    from dask.dataframe.utils import assert_eq
+        _Frame = raise_not_implemented_error("_Frame")
+        Aggregation = raise_not_implemented_error("Aggregation")
+        read_fwf = raise_not_implemented_error("read_fwf")
+        melt = raise_not_implemented_error("melt")
 
-    def raise_not_implemented_error(attr_name):
-        def inner_func(*args, **kwargs):
-            raise NotImplementedError(
-                f"Function {attr_name} is not implemented for dask-expr."
-            )
+    # Due to the natural circular imports caused from dask-expr
+    # wanting to import things from dask.dataframe, this module's init
+    # can be run multiple times as it walks code trying to import
+    # dask-expr while dask-expr is also trying to import from dask.dataframe
+    # Each time this happens and hits a circular import, we can reload
+    # dask.dataframe to update itself until dask-expr is fully initialized.
+    # TODO: This can go away when dask-expr is merged into dask
+    except ImportError:
+        import importlib
 
-        return inner_func
+        import dask.dataframe as dd
 
-    _Frame = raise_not_implemented_error("_Frame")
-    Aggregation = raise_not_implemented_error("Aggregation")
-    read_fwf = raise_not_implemented_error("read_fwf")
-    melt = raise_not_implemented_error("melt")
+        dd = importlib.reload(dd)
 
 else:
     try:
diff --git a/dask/dataframe/io/tests/test_parquet.py b/dask/dataframe/io/tests/test_parquet.py
index 6d423b1bad3..cfd3b91ae75 100644
--- a/dask/dataframe/io/tests/test_parquet.py
+++ b/dask/dataframe/io/tests/test_parquet.py
@@ -2267,7 +2267,6 @@ def test_writing_parquet_with_unknown_kwargs(tmpdir, engine):
         ddf.to_parquet(fn, engine=engine, unknown_key="unknown_value")
 
 
-@pytest.mark.skipif(DASK_EXPR_ENABLED, reason="circular import")
 def test_to_parquet_with_get(tmpdir, engine):
     from dask.multiprocessing import get as mp_get
 