
ENH: extract spatial partitioning information from partitioned Parquet dataset #28

Merged
Changes from 1 commit
52 changes: 51 additions & 1 deletion dask_geopandas/io/parquet.py
@@ -1,9 +1,11 @@
from functools import partial
import json
from typing import TYPE_CHECKING

import pandas as pd

import geopandas
import shapely.geometry

import dask.dataframe as dd

@@ -17,6 +19,40 @@
import pyarrow


def _get_partition_bounds(part):
"""
Based on the part information gathered by dask, get the partition bounds
if available.

"""
    from pyarrow.parquet import read_metadata

    # read the metadata from the actual file (this is again file IO, but
    # we can't rely on the schema metadata, because that is only the
    # metadata of the first piece)
    pq_metadata = None
    if "piece" in part:
        path = part["piece"][0]
        if isinstance(path, str):
            pq_metadata = read_metadata(path)
    if pq_metadata is None:
        return None

    metadata_str = pq_metadata.metadata.get(b"geo", None)
    if metadata_str is None:
        return None

    metadata = json.loads(metadata_str.decode("utf-8"))

    # for now only check the primary column (TODO generalize this to follow
    # the logic of geopandas to fall back to other geometry columns)
    geometry = metadata["primary_column"]
    bbox = metadata["columns"][geometry].get("bbox", None)
    if bbox is None:
        return None
    return shapely.geometry.box(*bbox)
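
For reference, a minimal sketch of the decoded "geo" metadata this helper inspects. The values below are illustrative, not taken from the PR; only "primary_column" and the per-column "bbox" entry are actually read:

# Hypothetical decoded contents of the b"geo" file-level metadata key;
# _get_partition_bounds only relies on "primary_column" and "bbox".
geo_metadata = {
    "primary_column": "geometry",
    "columns": {
        "geometry": {
            "encoding": "WKB",
            "bbox": [-180.0, -90.0, 180.0, 90.0],  # (minx, miny, maxx, maxy)
        }
    },
}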


class GeoArrowEngine(ArrowEngine):
    @classmethod
    def read_metadata(cls, *args, **kwargs):
@@ -27,6 +63,12 @@ def read_metadata(cls, *args, **kwargs):
        # for a default "geometry" column)
        meta = geopandas.GeoDataFrame(meta)

        # get spatial partitions if available
        regions = geopandas.GeoSeries([_get_partition_bounds(part) for part in parts])
        if regions.notna().all():
            # a bit hacky, but stashing the bounds in meta.attrs lets the
            # read_parquet wrapper below retrieve them from the collection
            meta.attrs["spatial_partitions"] = regions

        return (meta, stats, parts, index)

    @classmethod
@@ -138,5 +180,13 @@ def write_partition(
to_parquet = partial(dd.to_parquet, engine=GeoArrowEngine)
to_parquet.__doc__ = dd.to_parquet.__doc__

read_parquet = partial(dd.read_parquet, engine=GeoArrowEngine)

def read_parquet(*args, **kwargs):
    result = dd.read_parquet(*args, engine=GeoArrowEngine, **kwargs)
    # check if spatial partitioning information was stored
    spatial_partitions = result._meta.attrs.get("spatial_partitions", None)
    result.spatial_partitions = spatial_partitions
    return result


read_parquet.__doc__ = dd.read_parquet.__doc__
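
As a usage sketch (the path here is hypothetical), the spatial partitioning information is picked up transparently when reading back a partitioned dataset:

import dask_geopandas

# hypothetical path to a dataset written with dask_geopandas.to_parquet
ddf = dask_geopandas.read_parquet("/path/to/dataset/")
# a geopandas.GeoSeries with one shapely box per partition, or None if
# any partition lacked bounding-box metadata
print(ddf.spatial_partitions)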
4 changes: 3 additions & 1 deletion tests/io/test_parquet.py
@@ -28,8 +28,10 @@ def test_parquet_roundtrip(tmp_path):

    # reading back gives identical GeoDataFrame
    result = dask_geopandas.read_parquet(basedir)
    assert ddf.npartitions == 4
    assert result.npartitions == 4
    assert_geodataframe_equal(result.compute(), df)
    # reading back also populates the spatial partitioning property
    assert result.spatial_partitions is not None

    # the written dataset is also readable by plain geopandas
    result_gpd = geopandas.read_parquet(basedir)
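
A follow-up assertion one could add here (not part of this PR) to check that each stored partition geometry actually covers its partition's data, assuming the roundtrip setup above:

from shapely.geometry import box

for i, bounds_geom in enumerate(result.spatial_partitions):
    part = result.get_partition(i).compute()
    # the stored bounding box should cover the partition's total bounds
    assert bounds_geom.covers(box(*part.total_bounds))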