Convert automatically to arrow strings (#86)
phofl committed Jul 25, 2024
1 parent 6e6bc41 commit 1215a51
Showing 3 changed files with 77 additions and 6 deletions.
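
In user-facing terms, this change makes `read_gbq` return pyarrow-backed string columns whenever Dask's `dataframe.convert-string` setting is enabled, which the new test below exercises. A minimal usage sketch, with hypothetical project/dataset/table identifiers and a string column `name` taken from the test table:

import dask
import pandas as pd
from dask_bigquery import read_gbq

# With pyarrow string conversion enabled (the default in recent Dask),
# string columns come back as arrow-backed strings instead of object dtype.
with dask.config.set({"dataframe.convert-string": True}):
    ddf = read_gbq(
        project_id="my-project",   # hypothetical identifiers
        dataset_id="my_dataset",
        table_id="my_table",
    )

assert ddf.dtypes["name"] == pd.StringDtype(storage="pyarrow")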
51 changes: 50 additions & 1 deletion dask_bigquery/core.py
@@ -10,6 +10,8 @@
import pandas as pd
import pyarrow
from dask.base import tokenize
from dask.dataframe._compat import PANDAS_GE_220
from dask.dataframe.utils import pyarrow_strings_enabled
from google.api_core import client_info as rest_client_info
from google.api_core import exceptions
from google.api_core.gapic_v1 import client_info as grpc_client_info
@@ -95,6 +97,7 @@ def bigquery_read(
    read_kwargs: dict,
    arrow_options: dict,
    credentials: dict = None,
    convert_string: bool = False,
) -> pd.DataFrame:
    """Read a single batch of rows via BQ Storage API, in Arrow binary format.
@@ -114,7 +117,15 @@ def bigquery_read(
        BigQuery Storage API Stream "name"
        NOTE: Please set if reading from Storage API without any `row_restriction`.
        https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1beta1#stream
    convert_string: bool
        Whether to convert strings directly to arrow strings in the output DataFrame
    """
    arrow_options = arrow_options.copy()
    if convert_string:
        types_mapper = _get_types_mapper(arrow_options.get("types_mapper", {}.get))
        if types_mapper is not None:
            arrow_options["types_mapper"] = types_mapper

    with bigquery_clients(project_id, credentials=credentials) as (_, bqs_client):
        session = bqs_client.create_read_session(make_create_read_session_request())
        schema = pyarrow.ipc.read_schema(
@@ -130,6 +141,37 @@ def bigquery_read(
    return pd.concat(shards)


def _get_types_mapper(user_mapper):
    type_mappers = []

    # always use the user-defined mapper first, if available
    if user_mapper is not None:
        type_mappers.append(user_mapper)

    type_mappers.append({pyarrow.string(): pd.StringDtype("pyarrow")}.get)
    if PANDAS_GE_220:
        type_mappers.append({pyarrow.large_string(): pd.StringDtype("pyarrow")}.get)
    type_mappers.append({pyarrow.date32(): pd.ArrowDtype(pyarrow.date32())}.get)
    type_mappers.append({pyarrow.date64(): pd.ArrowDtype(pyarrow.date64())}.get)

    def _convert_decimal_type(type):
        if pyarrow.types.is_decimal(type):
            return pd.ArrowDtype(type)
        return None

    type_mappers.append(_convert_decimal_type)

    def default_types_mapper(pyarrow_dtype):
        """Try all type mappers in order, starting from the user type mapper."""
        for type_converter in type_mappers:
            converted_type = type_converter(pyarrow_dtype)
            if converted_type is not None:
                return converted_type

    if len(type_mappers) > 0:
        return default_types_mapper


def read_gbq(
    project_id: str,
    dataset_id: str,
@@ -196,13 +238,19 @@ def make_create_read_session_request():
                ),
            )

        arrow_options_meta = arrow_options.copy()
        if pyarrow_strings_enabled():
            types_mapper = _get_types_mapper(arrow_options.get("types_mapper", {}.get))
            if types_mapper is not None:
                arrow_options_meta["types_mapper"] = types_mapper

        # Create a read session in order to detect the schema.
        # Read sessions are light weight and will be auto-deleted after 24 hours.
        session = bqs_client.create_read_session(make_create_read_session_request())
        schema = pyarrow.ipc.read_schema(
            pyarrow.py_buffer(session.arrow_schema.serialized_schema)
        )
        meta = schema.empty_table().to_pandas(**arrow_options)
        meta = schema.empty_table().to_pandas(**arrow_options_meta)

        return dd.from_map(
            partial(
@@ -212,6 +260,7 @@ def make_create_read_session_request():
                read_kwargs=read_kwargs,
                arrow_options=arrow_options,
                credentials=credentials,
                convert_string=pyarrow_strings_enabled(),
            ),
            [stream.name for stream in session.streams],
            meta=meta,
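Because `_get_types_mapper` consults a user-supplied mapper first and only falls back to the automatic string/date/decimal conversions for types that mapper does not handle, a custom `types_mapper` passed through `arrow_options` keeps working alongside the new default. A sketch of that interaction, mirroring the updated test and again using hypothetical table identifiers (`number` and `name` are columns of the test table):

import pandas as pd
import pyarrow as pa
from dask_bigquery import read_gbq

# A user mapper that only overrides int64; it returns None for string types,
# so string columns still pick up the automatic pyarrow-string conversion
# when Dask's `dataframe.convert-string` setting is enabled.
ddf = read_gbq(
    project_id="my-project",   # hypothetical identifiers
    dataset_id="my_dataset",
    table_id="my_table",
    arrow_options={"types_mapper": {pa.int64(): pd.Float32Dtype()}.get},
)

assert ddf.dtypes["number"] == pd.Float32Dtype()  # user mapping applied first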
Empty file added dask_bigquery/tests/__init__.py
32 changes: 27 additions & 5 deletions dask_bigquery/tests/test_core.py
@@ -5,13 +5,14 @@
import uuid
from datetime import datetime, timedelta, timezone

import dask
import dask.dataframe as dd
import gcsfs
import google.auth
import pandas as pd
import pyarrow as pa
import pytest
from dask.dataframe.utils import assert_eq
from dask.dataframe.utils import assert_eq, pyarrow_strings_enabled
from distributed.utils_test import cleanup # noqa: F401
from distributed.utils_test import client # noqa: F401
from distributed.utils_test import cluster_fixture # noqa: F401
@@ -380,11 +381,32 @@ def test_arrow_options(table):
        project_id=project_id,
        dataset_id=dataset_id,
        table_id=table_id,
        arrow_options={
            "types_mapper": {pa.string(): pd.StringDtype(storage="pyarrow")}.get
        },
        arrow_options={"types_mapper": {pa.int64(): pd.Float32Dtype()}.get},
    )
    assert ddf.dtypes["name"] == pd.StringDtype(storage="pyarrow")
    assert ddf.dtypes["number"] == pd.Float32Dtype()


@pytest.mark.parametrize("convert_string", [True, False, None])
def test_convert_string(table, convert_string, df):
    project_id, dataset_id, table_id = table
    config = {}
    if convert_string is not None:
        config = {"dataframe.convert-string": convert_string}
    with dask.config.set(config):
        ddf = read_gbq(
            project_id=project_id,
            dataset_id=dataset_id,
            table_id=table_id,
        )
    # Roundtrip through `dd.from_pandas` to check consistent
    # behavior with Dask DataFrame
    result = dd.from_pandas(df, npartitions=1)
    if convert_string is True or (convert_string is None and pyarrow_strings_enabled()):
        assert ddf.dtypes["name"] == pd.StringDtype(storage="pyarrow")
    else:
        assert ddf.dtypes["name"] == object

    assert assert_eq(ddf.set_index("idx"), result.set_index("idx"))


@pytest.mark.skipif(sys.platform == "darwin", reason="Segfaults on macOS")
