diff --git a/dask_bigquery/core.py b/dask_bigquery/core.py
index 09ed804..fb435f8 100644
--- a/dask_bigquery/core.py
+++ b/dask_bigquery/core.py
@@ -5,7 +5,6 @@
 from contextlib import contextmanager
 from functools import partial
 
-import dask.dataframe as dd
 import pandas as pd
 import pyarrow
 from dask.base import tokenize
@@ -48,13 +47,12 @@ def _stream_to_dfs(bqs_client, stream_name, schema, timeout):
     ]
 
 
-def bigquery_arrow_read(
-    *,
+def bigquery_read_partition_field(
     make_create_read_session_request: callable,
-    partition_field: str = None,
     project_id: str,
-    stream_name: str = None,
     timeout: int,
+    partition_field: str,
+    row_filter: str,
 ) -> pd.DataFrame:
     """Read a single batch of rows via BQ Storage API, in Arrow binary format.
     Args:
@@ -64,6 +62,41 @@ def bigquery_arrow_read(
         partition_field: BigQuery field for partitions, to be used as Dask index col for
             divisions
             NOTE: Please set if specifying `row_restriction` filters in TableReadOptions.
+    Adapted from
+    https://github.com/googleapis/python-bigquery-storage/blob/a0fc0af5b4447ce8b50c365d4d081b9443b8490e/google/cloud/bigquery_storage_v1/reader.py.
+    """
+    with bigquery_client(project_id, with_storage_api=True) as (bq_client, bqs_client):
+        session = bqs_client.create_read_session(
+            make_create_read_session_request(row_filter=row_filter)
+        )
+        schema = pyarrow.ipc.read_schema(
+            pyarrow.py_buffer(session.arrow_schema.serialized_schema)
+        )
+
+        shards = [
+            df
+            for stream in session.streams
+            for df in _stream_to_dfs(bqs_client, stream.name, schema, timeout=timeout)
+        ]
+        # NOTE: if no rows satisfying the row_restriction, then `shards` will be empty list
+        if len(shards) == 0:
+            shards = [schema.empty_table().to_pandas()]
+        shards = [shard.set_index(partition_field, drop=True) for shard in shards]
+
+    return pd.concat(shards)
+
+
+def bigquery_read(
+    make_create_read_session_request: callable,
+    project_id: str,
+    timeout: int,
+    stream_name: str,
+) -> pd.DataFrame:
+    """Read a single batch of rows via BQ Storage API, in Arrow binary format.
+    Args:
+        project_id: BigQuery project
+        create_read_session_request: kwargs to pass to `bqs_client.create_read_session`
+            as `request`
         stream_name: BigQuery Storage API Stream "name".
             NOTE: Please set if reading from Storage API without any `row_restriction`.
                 https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1beta1#stream
@@ -76,35 +109,10 @@ def bigquery_arrow_read(
         schema = pyarrow.ipc.read_schema(
             pyarrow.py_buffer(session.arrow_schema.serialized_schema)
         )
-
-        if (partition_field is not None) and (stream_name is not None):
-            raise ValueError(
-                "The kwargs `partition_field` and `stream_name` are mutually exclusive."
-            )
-
-        elif partition_field is not None:
-            shards = [
-                df
-                for stream in session.streams
-                for df in _stream_to_dfs(
-                    bqs_client, stream.name, schema, timeout=timeout
-                )
-            ]
-            # NOTE: if no rows satisfying the row_restriction, then `shards` will be empty list
-            if len(shards) == 0:
-                shards = [schema.empty_table().to_pandas()]
-            shards = [shard.set_index(partition_field, drop=True) for shard in shards]
-
-        elif stream_name is not None:
-            shards = _stream_to_dfs(bqs_client, stream_name, schema, timeout=timeout)
-            # NOTE: BQ Storage API can return empty streams
-            if len(shards) == 0:
-                shards = [schema.empty_table().to_pandas()]
-
-        else:
-            raise NotImplementedError(
-                "Please specify either `partition_field` or `stream_name`."
-            )
+        shards = _stream_to_dfs(bqs_client, stream_name, schema, timeout=timeout)
+        # NOTE: BQ Storage API can return empty streams
+        if len(shards) == 0:
+            shards = [schema.empty_table().to_pandas()]
 
     return pd.concat(shards)
 
@@ -184,12 +192,24 @@ def make_create_read_session_request(row_filter=""):
             pyarrow.py_buffer(session.arrow_schema.serialized_schema)
         )
         meta = schema.empty_table().to_pandas()
-        delayed_kwargs = {}
+
+        label = "read-gbq-"
+        output_name = label + tokenize(
+            project_id,
+            dataset_id,
+            table_id,
+            partition_field,
+            partitions,
+            row_filter,
+            fields,
+            read_timeout,
+        )
 
         if partition_field is not None:
             if row_filter:
                 raise ValueError("Cannot pass both `partition_field` and `row_filter`")
-            delayed_kwargs["meta"] = meta.set_index(partition_field, drop=True)
+
+            meta = meta.set_index(partition_field, drop=True)
 
             if partitions is None:
                 logging.info(
@@ -203,47 +223,38 @@
                 # TODO generalize to ranges (as opposed to discrete values)
 
             partitions = sorted(partitions)
-            delayed_kwargs["divisions"] = (*partitions, partitions[-1])
             row_filters = [
                 f'{partition_field} = "{partition_value}"'
                 for partition_value in partitions
             ]
-            delayed_dfs = [
-                bigquery_arrow_read(
-                    make_create_read_session_request=partial(
-                        make_create_read_session_request, row_filter=row_filter
-                    ),
-                    partition_field=partition_field,
-                    project_id=project_id,
-                    timeout=read_timeout,
-                )
-                for row_filter in row_filters
-            ]
-            return dd.from_delayed(dfs=delayed_dfs, **delayed_kwargs)
-        else:
-            label = "read-gbq-"
-            output_name = label + tokenize(
-                project_id,
-                dataset_id,
-                table_id,
-                partition_field,
-                partitions,
-                row_filter,
-                fields,
-                read_timeout,
+            layer = DataFrameIOLayer(
+                output_name,
+                meta.columns,
+                row_filters,
+                partial(
+                    bigquery_read_partition_field,
+                    make_create_read_session_request,
+                    project_id,
+                    read_timeout,
+                    partition_field,
+                ),
+                label=label,
             )
-            # Create Blockwise layer
+            divisions = (*partitions, partitions[-1])
+        else:
             layer = DataFrameIOLayer(
                 output_name,
                 meta.columns,
                 [stream.name for stream in session.streams],
-                bigquery_arrow_read(
-                    make_create_read_session_request=make_create_read_session_request,
-                    project_id=project_id,
-                    timeout=read_timeout,
+                partial(
+                    bigquery_read,
+                    make_create_read_session_request,
+                    project_id,
+                    read_timeout,
                 ),
                 label=label,
             )
             divisions = tuple([None] * (len(session.streams) + 1))
-
-            graph = HighLevelGraph({output_name: layer}, {output_name: set()})
-            return new_dd_object(graph, output_name, meta, divisions)
+
+        graph = HighLevelGraph({output_name: layer}, {output_name: set()})
+        return new_dd_object(graph, output_name, meta, divisions)
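For context, a minimal usage sketch of the public `read_gbq` entry point as it behaves after this patch. The project, dataset, table, and partition values below are placeholders, not taken from the diff; the keyword names follow the arguments `read_gbq` passes to `tokenize` above. With `partition_field`/`partitions` set, each partition value becomes one `DataFrameIOLayer` input handled by `bigquery_read_partition_field`; otherwise one input is created per BigQuery Storage API stream and handled by `bigquery_read`.

# Usage sketch (hypothetical identifiers; not part of the patch)
import dask_bigquery

# Partitioned read: one Dask partition per listed partition value, with
# `partition_field` used as the index and divisions taken from `partitions`.
ddf = dask_bigquery.read_gbq(
    project_id="my-project",      # placeholder project
    dataset_id="my_dataset",      # placeholder dataset
    table_id="my_table",          # placeholder table
    partition_field="date",
    partitions=["2021-09-01", "2021-09-02"],
)

# Unpartitioned read: one Dask partition per BQ Storage API stream, with
# unknown divisions (None for every partition boundary).
ddf_streams = dask_bigquery.read_gbq(
    project_id="my-project",
    dataset_id="my_dataset",
    table_id="my_table",
)

df = ddf.compute()  # materialize as a pandas DataFrame indexed by "date"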