Commit

Slight refactor
jrbourbeau committed Aug 20, 2021
1 parent de93e88 commit 3070ae3
Showing 1 changed file with 78 additions and 67 deletions.
dask_bigquery/core.py (145 changed lines: 78 additions, 67 deletions)
@@ -5,7 +5,6 @@
from contextlib import contextmanager
from functools import partial

import dask.dataframe as dd
import pandas as pd
import pyarrow
from dask.base import tokenize
@@ -48,13 +47,12 @@ def _stream_to_dfs(bqs_client, stream_name, schema, timeout):
    ]


def bigquery_arrow_read(
    *,
def bigquery_read_partition_field(
    make_create_read_session_request: callable,
    partition_field: str = None,
    project_id: str,
    stream_name: str = None,
    timeout: int,
    partition_field: str,
    row_filter: str,
) -> pd.DataFrame:
    """Read a single batch of rows via BQ Storage API, in Arrow binary format.
    Args:
@@ -64,6 +62,41 @@ def bigquery_arrow_read(
        partition_field: BigQuery field for partitions, to be used as Dask index col for
            divisions
            NOTE: Please set if specifying `row_restriction` filters in TableReadOptions.
    Adapted from
    https://github.com/googleapis/python-bigquery-storage/blob/a0fc0af5b4447ce8b50c365d4d081b9443b8490e/google/cloud/bigquery_storage_v1/reader.py.
    """
    with bigquery_client(project_id, with_storage_api=True) as (bq_client, bqs_client):
        session = bqs_client.create_read_session(
            make_create_read_session_request(row_filter=row_filter)
        )
        schema = pyarrow.ipc.read_schema(
            pyarrow.py_buffer(session.arrow_schema.serialized_schema)
        )

        shards = [
            df
            for stream in session.streams
            for df in _stream_to_dfs(bqs_client, stream.name, schema, timeout=timeout)
        ]
        # NOTE: if no rows satisfying the row_restriction, then `shards` will be empty list
        if len(shards) == 0:
            shards = [schema.empty_table().to_pandas()]
        shards = [shard.set_index(partition_field, drop=True) for shard in shards]

        return pd.concat(shards)


def bigquery_read(
    make_create_read_session_request: callable,
    project_id: str,
    timeout: int,
    stream_name: str,
) -> pd.DataFrame:
    """Read a single batch of rows via BQ Storage API, in Arrow binary format.
    Args:
        project_id: BigQuery project
        create_read_session_request: kwargs to pass to `bqs_client.create_read_session`
            as `request`
        stream_name: BigQuery Storage API Stream "name".
            NOTE: Please set if reading from Storage API without any `row_restriction`.
                https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1beta1#stream
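Both readers fall back to an empty DataFrame built from the session's Arrow schema when no rows come back. A minimal, self-contained sketch of that schema round trip, not part of this commit; the toy schema and column names below are made up, while the real bytes arrive as session.arrow_schema.serialized_schema:

import pyarrow

# Toy schema standing in for the one the BQ Storage API session advertises
schema = pyarrow.schema([("name", pyarrow.string()), ("number", pyarrow.int64())])
serialized = schema.serialize().to_pybytes()  # analogous to session.arrow_schema.serialized_schema

# Deserialize the IPC-encoded schema, as the readers above do
roundtripped = pyarrow.ipc.read_schema(pyarrow.py_buffer(serialized))

# Empty-table fallback used when `shards` would otherwise be an empty list
empty_df = roundtripped.empty_table().to_pandas()
print(empty_df.dtypes)  # columns exist with the right dtypes, zero rows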
@@ -76,35 +109,10 @@ def bigquery_arrow_read(
        schema = pyarrow.ipc.read_schema(
            pyarrow.py_buffer(session.arrow_schema.serialized_schema)
        )

        if (partition_field is not None) and (stream_name is not None):
            raise ValueError(
                "The kwargs `partition_field` and `stream_name` are mutually exclusive."
            )

        elif partition_field is not None:
            shards = [
                df
                for stream in session.streams
                for df in _stream_to_dfs(
                    bqs_client, stream.name, schema, timeout=timeout
                )
            ]
            # NOTE: if no rows satisfying the row_restriction, then `shards` will be empty list
            if len(shards) == 0:
                shards = [schema.empty_table().to_pandas()]
            shards = [shard.set_index(partition_field, drop=True) for shard in shards]

        elif stream_name is not None:
            shards = _stream_to_dfs(bqs_client, stream_name, schema, timeout=timeout)
            # NOTE: BQ Storage API can return empty streams
            if len(shards) == 0:
                shards = [schema.empty_table().to_pandas()]

        else:
            raise NotImplementedError(
                "Please specify either `partition_field` or `stream_name`."
            )
        shards = _stream_to_dfs(bqs_client, stream_name, schema, timeout=timeout)
        # NOTE: BQ Storage API can return empty streams
        if len(shards) == 0:
            shards = [schema.empty_table().to_pandas()]

        return pd.concat(shards)
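As the later hunks show, each reader is handed to a DataFrameIOLayer with everything except its final argument pre-bound, so the layer only supplies the per-partition input (a row filter or a stream name). A small illustrative sketch of that binding pattern with a toy function, not the real BigQuery client:

from functools import partial

def toy_read(make_request, project_id, timeout, stream_name):
    # Stand-in for bigquery_read: the real function returns a pandas DataFrame
    return f"would read {stream_name!r} from {project_id!r} with timeout={timeout}"

# Bind everything except the final positional argument, as the new code does
reader = partial(toy_read, lambda **kwargs: None, "my-project", 3600)

# The IO layer then calls the partial once per input
print(reader("stream-0"))
print(reader("stream-1"))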

@@ -184,12 +192,24 @@ def make_create_read_session_request(row_filter=""):
            pyarrow.py_buffer(session.arrow_schema.serialized_schema)
        )
        meta = schema.empty_table().to_pandas()
        delayed_kwargs = {}

        label = "read-gbq-"
        output_name = label + tokenize(
            project_id,
            dataset_id,
            table_id,
            partition_field,
            partitions,
            row_filter,
            fields,
            read_timeout,
        )

        if partition_field is not None:
            if row_filter:
                raise ValueError("Cannot pass both `partition_field` and `row_filter`")
            delayed_kwargs["meta"] = meta.set_index(partition_field, drop=True)

            meta = meta.set_index(partition_field, drop=True)

            if partitions is None:
                logging.info(
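The new output_name is built with dask.base.tokenize, which hashes its arguments deterministically, so identical read calls produce identical layer names. A quick sketch with placeholder argument values:

from dask.base import tokenize

args = ("my-project", "my_dataset", "my_table", None, None, "", (), 3600)
name_a = "read-gbq-" + tokenize(*args)
name_b = "read-gbq-" + tokenize(*args)
assert name_a == name_b  # identical inputs hash to the identical token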
@@ -203,47 +223,38 @@ def make_create_read_session_request(row_filter=""):
                # TODO generalize to ranges (as opposed to discrete values)

            partitions = sorted(partitions)
            delayed_kwargs["divisions"] = (*partitions, partitions[-1])
            row_filters = [
                f'{partition_field} = "{partition_value}"'
                for partition_value in partitions
            ]
            delayed_dfs = [
                bigquery_arrow_read(
                    make_create_read_session_request=partial(
                        make_create_read_session_request, row_filter=row_filter
                    ),
                    partition_field=partition_field,
                    project_id=project_id,
                    timeout=read_timeout,
                )
                for row_filter in row_filters
            ]
            return dd.from_delayed(dfs=delayed_dfs, **delayed_kwargs)
        else:
            label = "read-gbq-"
            output_name = label + tokenize(
                project_id,
                dataset_id,
                table_id,
                partition_field,
                partitions,
                row_filter,
                fields,
                read_timeout,
            layer = DataFrameIOLayer(
                output_name,
                meta.columns,
                row_filters,
                partial(
                    bigquery_read_partition_field,
                    make_create_read_session_request,
                    project_id,
                    read_timeout,
                    partition_field,
                ),
                label=label,
            )
            # Create Blockwise layer
            divisions = (*partitions, partitions[-1])
        else:
            layer = DataFrameIOLayer(
                output_name,
                meta.columns,
                [stream.name for stream in session.streams],
                bigquery_arrow_read(
                    make_create_read_session_request=make_create_read_session_request,
                    project_id=project_id,
                    timeout=read_timeout,
                partial(
                    bigquery_read,
                    make_create_read_session_request,
                    project_id,
                    read_timeout,
                ),
                label=label,
            )
            divisions = tuple([None] * (len(session.streams) + 1))
            graph = HighLevelGraph({output_name: layer}, {output_name: set()})
            return new_dd_object(graph, output_name, meta, divisions)

        graph = HighLevelGraph({output_name: layer}, {output_name: set()})
        return new_dd_object(graph, output_name, meta, divisions)
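Taken together, the refactor builds one DataFrameIOLayer whose IO function runs once per input (row filter or stream name), wraps it in a HighLevelGraph, and returns a Dask DataFrame via new_dd_object. Below is a minimal toy sketch of that construction pattern, assuming a dask release contemporary with this commit (mid-2021) where these internal APIs behave as shown:

import pandas as pd
from dask.base import tokenize
from dask.dataframe.core import new_dd_object
from dask.highlevelgraph import HighLevelGraph
from dask.layers import DataFrameIOLayer

def make_part(start):
    # Stand-in for bigquery_read / bigquery_read_partition_field
    return pd.DataFrame({"x": [start, start + 1]})

meta = pd.DataFrame({"x": pd.Series(dtype="int64")})
inputs = [0, 10, 20]  # analogous to row_filters or stream names
label = "toy-io-"
output_name = label + tokenize(inputs)

layer = DataFrameIOLayer(output_name, meta.columns, inputs, make_part, label=label)
graph = HighLevelGraph({output_name: layer}, {output_name: set()})
divisions = tuple([None] * (len(inputs) + 1))  # unknown divisions, one boundary per partition edge

ddf = new_dd_object(graph, output_name, meta, divisions)
print(ddf.compute())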
