Skip to content

Commit

Permalink
Test with stream release machinery from #349
Browse files Browse the repository at this point in the history
  • Loading branch information
brendan-ward committed Apr 8, 2024
1 parent fb6ee64 commit 94adbc0
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 35 deletions.
63 changes: 29 additions & 34 deletions pyogrio/_io.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ from libc.string cimport strlen
from libc.math cimport isnan

cimport cython
from cpython.pycapsule cimport PyCapsule_New, PyCapsule_GetPointer
import numpy as np

from pyogrio._ogr cimport *
Expand Down Expand Up @@ -84,6 +85,25 @@ DTYPE_OGR_FIELD_TYPES = {
}



cdef void pycapsule_array_stream_deleter(object stream_capsule) noexcept:
    """PyCapsule destructor for capsules named 'arrow_array_stream'.

    Releases the wrapped ArrowArrayStream if it still has a release callback,
    then frees the struct itself.
    """
    cdef ArrowArrayStream* c_stream
    c_stream = <ArrowArrayStream*>PyCapsule_GetPointer(
        stream_capsule, 'arrow_array_stream'
    )

    # A NULL release callback marks a used/moved stream (or one that was never
    # populated); in that case there is nothing to release.
    if c_stream.release != NULL:
        c_stream.release(c_stream)

    # The struct itself is always freed here.
    free(c_stream)


cdef object alloc_c_stream(ArrowArrayStream** c_stream):
    """Allocate an ArrowArrayStream and wrap it in a PyCapsule.

    Parameters
    ----------
    c_stream : ArrowArrayStream**
        Output location that receives the pointer to the newly allocated
        struct. It is set to NULL if allocation or capsule creation fails.

    Returns
    -------
    object
        PyCapsule named 'arrow_array_stream' whose destructor
        (pycapsule_array_stream_deleter) releases and frees the struct.

    Raises
    ------
    MemoryError
        If the struct cannot be allocated.
    """
    cdef ArrowArrayStream* stream = <ArrowArrayStream*> malloc(
        sizeof(ArrowArrayStream)
    )
    c_stream[0] = stream

    # malloc may fail; writing to stream.release below would dereference NULL
    if stream == NULL:
        raise MemoryError("Failed to allocate ArrowArrayStream")

    # Ensure the capsule destructor doesn't call a random release pointer
    stream.release = NULL

    try:
        return PyCapsule_New(
            stream, 'arrow_array_stream', &pycapsule_array_stream_deleter
        )
    except BaseException:
        # No capsule took ownership, so the destructor will never run; free the
        # struct here and do not leave a dangling pointer with the caller.
        free(stream)
        c_stream[0] = NULL
        raise


cdef int start_transaction(OGRDataSourceH ogr_dataset, int force) except 1:
cdef int err = GDALDatasetStartTransaction(ogr_dataset, force)
if err == OGRERR_FAILURE:
Expand Down Expand Up @@ -1112,8 +1132,6 @@ def ogr_read(
cdef int feature_count = 0
cdef double xmin, ymin, xmax, ymax

print(f"ogr_open_arrow: {path}")

path_b = path.encode('utf-8')
path_c = path_b

Expand Down Expand Up @@ -1291,8 +1309,7 @@ def ogr_open_arrow(
cdef char **fields_c = NULL
cdef const char *field_c = NULL
cdef char **options = NULL
cdef ArrowArrayStream stream
cdef ArrowSchema schema
cdef ArrowArrayStream *stream

# this block prevents compilation of remaining code in this function, which
# fails for GDAL < 3.6.0 because OGR_L_GetArrowStream is undefined
Expand Down Expand Up @@ -1330,28 +1347,22 @@ def ogr_open_arrow(

reader = None
try:
print("set dataset options")
dataset_options = dict_to_options(dataset_kwargs)
print("open ogr dataset")
ogr_dataset = ogr_open(path_c, 0, dataset_options)

print("get ogr layer")
if sql is None:
if layer is None:
layer = get_default_layer(ogr_dataset)
ogr_layer = get_ogr_layer(ogr_dataset, layer)
else:
ogr_layer = execute_sql(ogr_dataset, sql, sql_dialect)

print("get crs")
crs = get_crs(ogr_layer)

# Encoding is derived from the user, from the dataset capabilities / type,
# or from the system locale
print("detect encoding")
encoding = encoding or detect_encoding(ogr_dataset, ogr_layer)

print("get fields")
fields = get_fields(ogr_layer, encoding, use_arrow=True)

ignored_fields = []
Expand All @@ -1361,13 +1372,10 @@ def ogr_open_arrow(
if not read_geometry:
ignored_fields.append("OGR_GEOMETRY")

print("get geometry type")
geometry_type = get_geometry_type(ogr_layer)

print("get geometry name")
geometry_name = get_string(OGR_L_GetGeometryColumn(ogr_layer))

print("get fid column")
fid_column = get_string(OGR_L_GetFIDColumn(ogr_layer))
# OGR_L_GetFIDColumn returns the column name if it is a custom column,
# or "" if not. For arrow, the default column name is "OGC_FID".
Expand All @@ -1376,24 +1384,17 @@ def ogr_open_arrow(

# Apply the attribute filter
if where is not None and where != "":
print("apply where filter")
apply_where_filter(ogr_layer, where)
print("done setting where filter")

# Apply the spatial filter
if bbox is not None:
print("apply bbox filter")
apply_bbox_filter(ogr_layer, bbox)
print("done setting bbox filter")

elif mask is not None:
print("apply mask filter")
apply_geometry_filter(ogr_layer, mask)
print("done setting mask filter")

# Limit to specified columns
if ignored_fields:
print("set ignored fields")
for field in ignored_fields:
field_b = field.encode("utf-8")
field_c = field_b
Expand All @@ -1402,11 +1403,9 @@ def ogr_open_arrow(
OGR_L_SetIgnoredFields(ogr_layer, <const char**>fields_c)

if not return_fids:
print("set no return fids")
options = CSLSetNameValue(options, "INCLUDE_FID", "NO")

if batch_size > 0:
print("set batch size")
batch_size_b = str(batch_size).encode('UTF-8')
batch_size_c = batch_size_b
options = CSLSetNameValue(
Expand All @@ -1416,35 +1415,31 @@ def ogr_open_arrow(
)

# Default to geoarrow metadata encoding (only used for GDAL >= 3.8.0)
print("set GEOMETRY_METADATA_ENCODING")
options = CSLSetNameValue(
options,
"GEOMETRY_METADATA_ENCODING",
"GEOARROW"
)

# make sure layer is read from beginning
print("reset reading")
OGR_L_ResetReading(ogr_layer)

print("get arrow stream")
if not OGR_L_GetArrowStream(ogr_layer, &stream, options):
# allocate the stream struct and wrap in capsule to ensure clean-up on error
capsule = alloc_c_stream(&stream)

if not OGR_L_GetArrowStream(ogr_layer, stream, options):
raise RuntimeError("Failed to open ArrowArrayStream from Layer")

stream_ptr = <uintptr_t> &stream
stream_ptr = <uintptr_t> stream

if skip_features:
# only supported for GDAL >= 3.8.0; have to do this after getting
# the Arrow stream
print("set skip features")
OGR_L_SetNextByIndex(ogr_layer, skip_features)

# stream has to be consumed before the Dataset is closed
print("get reader")
import pyarrow as pa
reader = pa.RecordBatchStreamReader._import_from_c(stream_ptr)
print("got reader")

meta = {
'crs': crs,
'encoding': encoding,
Expand All @@ -1457,12 +1452,11 @@ def ogr_open_arrow(
yield meta, reader

finally:
print("in ogr_open_arrow finally block")
if reader is not None:
print("closing reader")
# Mark reader as closed to prevent reading batches
reader.close()
print("closed reader")

# `stream` will be freed through `capsule` destructor

if options != NULL:
CSLDestroy(options)
Expand All @@ -1484,6 +1478,7 @@ def ogr_open_arrow(

print("done with ogr_open_arrow finally block")


def ogr_read_bounds(
str path,
object layer=None,
Expand Down
1 change: 1 addition & 0 deletions pyogrio/_ogr.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ cdef extern from "arrow_bridge.h":

struct ArrowArrayStream:
int (*get_schema)(ArrowArrayStream* stream, ArrowSchema* out)
void (*release)(ArrowArrayStream*) noexcept nogil


cdef extern from "ogr_api.h":
Expand Down
2 changes: 1 addition & 1 deletion pyogrio/tests/test_geopandas_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def test_read_where_in(naturalearth_lowres_all_ext, use_arrow):

def test_read_where_range(naturalearth_lowres_all_ext, use_arrow):
if naturalearth_lowres_all_ext.suffix not in {".geojsonl", ".gpkg"}:
pytest.skip("only test gpkg")
pytest.skip("only test geojsonl or gpkg")

# should return items within range
df = read_dataframe(
Expand Down

0 comments on commit 94adbc0

Please sign in to comment.