diff --git a/dask/dataframe/io/parquet/arrow.py b/dask/dataframe/io/parquet/arrow.py
index d87bae2928e..12256e97f05 100644
--- a/dask/dataframe/io/parquet/arrow.py
+++ b/dask/dataframe/io/parquet/arrow.py
@@ -72,6 +72,7 @@ def _write_partitioned(
     pandas_to_arrow_table,
     preserve_index,
     index_cols=(),
+    return_metadata=True,
     **kwargs,
 ):
     """Write table to a partitioned dataset with pyarrow.
@@ -119,8 +120,14 @@ def _write_partitioned(
         fs.mkdirs(prefix, exist_ok=True)
         full_path = fs.sep.join([prefix, filename])
         with fs.open(full_path, "wb") as f:
-            pq.write_table(subtable, f, metadata_collector=md_list, **kwargs)
-        md_list[-1].set_file_path(fs.sep.join([subdir, filename]))
+            pq.write_table(
+                subtable,
+                f,
+                metadata_collector=md_list if return_metadata else None,
+                **kwargs,
+            )
+        if return_metadata:
+            md_list[-1].set_file_path(fs.sep.join([subdir, filename]))
 
     return md_list
 
@@ -680,6 +687,7 @@ def write_partition(
                 preserve_index,
                 index_cols=index_cols,
                 compression=compression,
+                return_metadata=return_metadata,
                 **kwargs,
             )
             if md_list:
@@ -693,7 +701,7 @@ def write_partition(
                     t,
                     fil,
                     compression=compression,
-                    metadata_collector=md_list,
+                    metadata_collector=md_list if return_metadata else None,
                     **kwargs,
                 )
                 if md_list:
diff --git a/dask/dataframe/io/parquet/core.py b/dask/dataframe/io/parquet/core.py
index 04837494e12..a82245b4790 100644
--- a/dask/dataframe/io/parquet/core.py
+++ b/dask/dataframe/io/parquet/core.py
@@ -808,10 +808,10 @@ def to_parquet(
 
     # Collect metadata and write _metadata.
     # TODO: Use tree-reduction layer (when available)
-    meta_name = "metadata-" + data_write._name
     if write_metadata_file:
+        final_name = "metadata-" + data_write._name
         dsk = {
-            (meta_name, 0): (
+            (final_name, 0): (
                 apply,
                 engine.write_metadata,
                 [
@@ -824,16 +824,22 @@ def to_parquet(
             )
         }
     else:
-        dsk = {(meta_name, 0): (lambda x: None, data_write.__dask_keys__())}
+        # NOTE: We still define a single task to tie everything together
+        # when we are not writing a _metadata file. We do not want to
+        # return `data_write` (or a `data_write.to_bag()`), because calling
+        # `compute()` on a multi-partition collection requires the overhead
+        # of trying to concatenate results on the client.
+        final_name = "store-" + data_write._name
+        dsk = {(final_name, 0): (lambda x: None, data_write.__dask_keys__())}
 
     # Convert data_write + dsk to computable collection
-    graph = HighLevelGraph.from_collections(meta_name, dsk, dependencies=(data_write,))
+    graph = HighLevelGraph.from_collections(final_name, dsk, dependencies=(data_write,))
     if compute:
         return compute_as_if_collection(
-            Scalar, graph, [(meta_name, 0)], **compute_kwargs
+            Scalar, graph, [(final_name, 0)], **compute_kwargs
        )
     else:
-        return Scalar(graph, meta_name, "")
+        return Scalar(graph, final_name, "")
 
 
 def create_metadata_file(
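
Illustrative usage (not part of the patch): a minimal sketch of the code path these changes affect, assuming the rest of the PR wires `write_metadata_file` through to the engine's new `return_metadata` flag. When no `_metadata` file is requested, the pyarrow engine can pass `metadata_collector=None` to `pq.write_table` and skip footer collection entirely.

    import pandas as pd
    import dask.dataframe as dd

    ddf = dd.from_pandas(
        pd.DataFrame({"a": range(10), "b": ["x", "y"] * 5}), npartitions=2
    )

    # No _metadata file is written, so per-file footer metadata does not
    # need to be collected on the workers or gathered on the client.
    ddf.to_parquet("out.parquet", engine="pyarrow", write_metadata_file=False)

    # Writing the _metadata file still collects footer metadata as before.
    ddf.to_parquet("out_meta.parquet", engine="pyarrow", write_metadata_file=True)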