Serialize all pyarrow extension arrays efficiently #9740

Merged · 7 commits · Dec 15, 2022
dask/dataframe/_compat.py: 25 changes (13 additions, 12 deletions)
@@ -4,18 +4,19 @@

import numpy as np
import pandas as pd
from packaging.version import parse as parse_version

PANDAS_VERSION = parse_version(pd.__version__)
PANDAS_GT_104 = PANDAS_VERSION >= parse_version("1.0.4")
PANDAS_GT_110 = PANDAS_VERSION >= parse_version("1.1.0")
PANDAS_GT_120 = PANDAS_VERSION >= parse_version("1.2.0")
PANDAS_GT_121 = PANDAS_VERSION >= parse_version("1.2.1")
PANDAS_GT_130 = PANDAS_VERSION >= parse_version("1.3.0")
PANDAS_GT_131 = PANDAS_VERSION >= parse_version("1.3.1")
PANDAS_GT_133 = PANDAS_VERSION >= parse_version("1.3.3")
PANDAS_GT_140 = PANDAS_VERSION >= parse_version("1.4.0")
PANDAS_GT_150 = PANDAS_VERSION >= parse_version("1.5.0")
from packaging.version import Version

PANDAS_VERSION = Version(pd.__version__)
PANDAS_GT_104 = PANDAS_VERSION >= Version("1.0.4")
PANDAS_GT_110 = PANDAS_VERSION >= Version("1.1.0")
PANDAS_GT_120 = PANDAS_VERSION >= Version("1.2.0")
PANDAS_GT_121 = PANDAS_VERSION >= Version("1.2.1")
PANDAS_GT_130 = PANDAS_VERSION >= Version("1.3.0")
PANDAS_GT_131 = PANDAS_VERSION >= Version("1.3.1")
PANDAS_GT_133 = PANDAS_VERSION >= Version("1.3.3")
PANDAS_GT_140 = PANDAS_VERSION >= Version("1.4.0")
PANDAS_GT_150 = PANDAS_VERSION >= Version("1.5.0")
PANDAS_GT_200 = PANDAS_VERSION.major >= 2

import pandas.testing as tm

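A side note on the new PANDAS_GT_200 flag (this sketch is not part of the diff): it checks Version(pd.__version__).major >= 2 rather than PANDAS_VERSION >= Version("2.0.0"), presumably so that pandas 2.0 pre-release builds are also treated as 2.x. A minimal illustration with packaging:

```python
# Sketch only, not part of this PR: why a major-version check differs from a
# plain ">=" comparison for pre-release builds.
from packaging.version import Version

dev = Version("2.0.0.dev0")        # e.g. a pandas 2.0 development build
print(dev >= Version("2.0.0"))     # False: pre-releases sort below the final release
print(dev.major >= 2)              # True: the major component is already 2
```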
dask/dataframe/_pyarrow_compat.py: 144 changes (24 additions, 120 deletions)
@@ -1,143 +1,47 @@
import copyreg
import math

import numpy as np
import pandas as pd

try:
import pyarrow as pa
except ImportError:
pa = None

from dask.dataframe._compat import PANDAS_GT_130, PANDAS_GT_150, PANDAS_GT_200

# Pickling of pyarrow arrays is effectively broken - pickling a slice of an
# array ends up pickling the entire backing array.
#
# See https://issues.apache.org/jira/browse/ARROW-10739
#
# This comes up when using pandas `string[pyarrow]` dtypes, which are backed by
# a `pyarrow.StringArray`. To fix this, we register a *global* override for
# pickling `pandas.core.arrays.ArrowStringArray` types. We do this at the
# pandas level rather than the pyarrow level for efficiency reasons (a pandas
# ArrowStringArray may contain many small pyarrow StringArray objects).
#
# This pickling implementation manually mucks with the backing buffers in a
# fairly efficient way:
#
# - The data buffer is never copied
# - The offsets buffer is only copied if the array is sliced with a start index
# (x[start:])
# - The mask buffer is never copied
#
# This implementation works with pickle protocol 5, allowing support for true
# zero-copy sends.
# pickling `ArrowStringArray` or `ArrowExtensionArray` types (where available).
# We do this at the pandas level rather than the pyarrow level for efficiency reasons
# (a pandas ArrowStringArray may contain many small pyarrow StringArray objects).
#
# XXX: Once pyarrow (or pandas) has fixed this bug, we should skip registering
# with copyreg for versions that lack this issue.


def pyarrow_stringarray_to_parts(array):
"""Decompose a ``pyarrow.StringArray`` into a tuple of components.

The resulting tuple can be passed to
``pyarrow_stringarray_from_parts(*components)`` to reconstruct the
``pyarrow.StringArray``.
"""
# Access the backing buffers.
#
# - mask: None, or a bitmask of length ceil(nitems / 8). 0 bits mark NULL
# elements, only present if NULL data is present, commonly None.
# - offsets: A uint32 array of nitems + 1 items marking the start/stop
# indices for the individual elements in `data`
# - data: All the utf8 string data concatenated together
#
# The structure of these buffers comes from the arrow format, documented at
# https://arrow.apache.org/docs/format/Columnar.html#physical-memory-layout.
# In particular, this is a `StringArray` (4 byte offsets), rather than a
# `LargeStringArray` (8 byte offsets).
assert pa.types.is_string(array.type)

mask, offsets, data = array.buffers()
nitems = len(array)

if not array.offset:
# No leading offset, only need to slice any unnecessary data from the
# backing buffers
offsets = offsets[: 4 * (nitems + 1)]
data_stop = int.from_bytes(offsets[-4:], "little")
data = data[:data_stop]
if mask is None:
return nitems, offsets, data
else:
mask = mask[: math.ceil(nitems / 8)]
return nitems, offsets, data, mask

# There is a leading offset. This complicates things a bit.
offsets_start = array.offset * 4
offsets_stop = offsets_start + (nitems + 1) * 4
data_start = int.from_bytes(offsets[offsets_start : offsets_start + 4], "little")
data_stop = int.from_bytes(offsets[offsets_stop - 4 : offsets_stop], "little")
data = data[data_start:data_stop]

if mask is None:
npad = 0
else:
# Since the mask is a bitmask, it can only represent even units of 8
# elements. To avoid shifting any bits, we pad the array with up to 7
# elements so the mask array can always be serialized zero copy.
npad = array.offset % 8
mask_start = array.offset // 8
mask_stop = math.ceil((array.offset + nitems) / 8)
mask = mask[mask_start:mask_stop]

# Subtract the offset of the starting element from every used offset in the
# offsets array, ensuring the first element in the serialized `offsets`
# array is always 0.
offsets_array = np.frombuffer(offsets, dtype="i4")
offsets_array = (
offsets_array[array.offset : array.offset + nitems + 1]
- offsets_array[array.offset]
)
# Pad the new offsets by `npad` offsets of 0 (see the `mask` comment above). We wrap
# this in a `pyarrow.py_buffer`, since this type transparently supports pickle 5,
# avoiding an extra copy inside the pickler.
offsets = pa.py_buffer(
b"\x00" * (4 * npad) + offsets_array.data if npad else offsets_array.data
)

if mask is None:
return nitems, offsets, data
else:
return nitems, offsets, data, mask, npad


def pyarrow_stringarray_from_parts(nitems, data_offsets, data, mask=None, offset=0):
"""Reconstruct a ``pyarrow.StringArray`` from the parts returned by
``pyarrow_stringarray_to_parts``."""
return pa.StringArray.from_buffers(nitems, data_offsets, data, mask, offset=offset)
# The implementation here is based on https://github.com/pandas-dev/pandas/pull/49078
# which is included in pandas=2+. We can remove all this once Dask's minimum
# supported pandas version is at least 2.0.0.


def rebuild_arrowstringarray(*chunk_parts):
"""Rebuild a ``pandas.core.arrays.ArrowStringArray``"""
array = pa.chunked_array(
[pyarrow_stringarray_from_parts(*parts) for parts in chunk_parts],
type=pa.string(),
)
return pd.arrays.ArrowStringArray(array)
def rebuild_arrowextensionarray(type_, chunks):
array = pa.chunked_array(chunks)
return type_(array)


def reduce_arrowstringarray(x):
"""A pickle override for ``pandas.core.arrays.ArrowStringArray`` that avoids
serializing unnecessary data, while also avoiding/minimizing data copies"""
# Decompose each chunk in the backing ChunkedArray into their individual
# components for serialization. We filter out 0-length chunks, since they
# add no meaningful value to the chunked array.
chunks = tuple(
pyarrow_stringarray_to_parts(chunk)
for chunk in x._data.chunks
if len(chunk) > 0
)
return (rebuild_arrowstringarray, chunks)
def reduce_arrowextensionarray(x):
return (rebuild_arrowextensionarray, (type(x), x._data.combine_chunks()))


if hasattr(pd.arrays, "ArrowStringArray") and pa is not None:
copyreg.dispatch_table[pd.arrays.ArrowStringArray] = reduce_arrowstringarray
# `pandas=2` includes efficient serialization of `pyarrow`-backed extension arrays.
# See https://github.com/pandas-dev/pandas/pull/49078 for details.
# We only need to backport efficient serialization for `pandas<2`.
if pa is not None and not PANDAS_GT_200:
if PANDAS_GT_150:
# Applies to all `pyarrow`-backed extension arrays (e.g. `string[pyarrow]`, `int64[pyarrow]`)
for type_ in [pd.arrays.ArrowExtensionArray, pd.arrays.ArrowStringArray]:
copyreg.dispatch_table[type_] = reduce_arrowextensionarray
Comment on lines +43 to +44 (Member, Author):

When available, we need to make sure to register copyreg entries for both pd.arrays.ArrowExtensionArray and pd.arrays.ArrowStringArray. This way, both pyarrow string implementations in pandas will pick up the serialization fixes here. I've added a test which makes sure we handle both pd.StringDtype("pyarrow") and pd.ArrowDtype(pa.string()) cases.

elif PANDAS_GT_130:
# Only `string[pyarrow]` is implemented, so just patch that
copyreg.dispatch_table[pd.arrays.ArrowStringArray] = reduce_arrowextensionarray
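The review comment above mentions a test covering both pd.StringDtype("pyarrow") and pd.ArrowDtype(pa.string()). As a rough illustration only (this is not the test added in the PR), a pickle roundtrip check for both pyarrow-backed string implementations might look like the sketch below; it assumes pandas>=1.5 (for pd.ArrowDtype), pyarrow installed, and that importing dask.dataframe._pyarrow_compat registers the copyreg overrides shown in this diff.

```python
# Hypothetical roundtrip check, not the test added in this PR.
import pickle

import pandas as pd
import pyarrow as pa

# Importing this module registers the copyreg overrides from this diff.
import dask.dataframe._pyarrow_compat  # noqa: F401

data = ["foo", "bar", None] * 1000

# Both pyarrow-backed string implementations in pandas (pd.ArrowDtype needs pandas>=1.5).
arrays = [
    pd.array(data, dtype=pd.StringDtype("pyarrow")),   # pd.arrays.ArrowStringArray
    pd.array(data, dtype=pd.ArrowDtype(pa.string())),  # pd.arrays.ArrowExtensionArray
]

for arr in arrays:
    # Values should survive a pickle roundtrip unchanged.
    result = pickle.loads(pickle.dumps(arr))
    pd.testing.assert_extension_array_equal(result, arr)

    # Serializing a small slice should not drag the whole backing buffer along,
    # which is the point of the copyreg override registered above.
    small = pickle.dumps(arr[:10], protocol=5)
    full = pickle.dumps(arr, protocol=5)
    assert len(small) < len(full)
```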