diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
new file mode 100644
index 0000000..229c576
--- /dev/null
+++ b/.github/workflows/pre-commit.yml
@@ -0,0 +1,16 @@
+name: Linting
+
+on:
+  push:
+    branches: main
+  pull_request:
+    branches: main
+
+jobs:
+  checks:
+    name: "pre-commit hooks"
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions/setup-python@v2
+      - uses: pre-commit/action@v2.0.0
\ No newline at end of file
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..ee7164d
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,17 @@
+repos:
+- repo: https://github.com/psf/black
+  rev: 20.8b1
+  hooks:
+  - id: black
+    language_version: python3
+    exclude: versioneer.py
+- repo: https://gitlab.com/pycqa/flake8
+  rev: 3.8.3
+  hooks:
+  - id: flake8
+    language_version: python3
+- repo: https://github.com/pycqa/isort
+  rev: 5.8.0
+  hooks:
+  - id: isort
+    language_version: python3
\ No newline at end of file
diff --git a/dask_bigquery/__init__.py b/dask_bigquery/__init__.py
new file mode 100644
index 0000000..80a7c9c
--- /dev/null
+++ b/dask_bigquery/__init__.py
@@ -0,0 +1 @@
+from .core import read_gbq
diff --git a/dask_bigquery/core.py b/dask_bigquery/core.py
new file mode 100644
index 0000000..244ed02
--- /dev/null
+++ b/dask_bigquery/core.py
@@ -0,0 +1,232 @@
+from __future__ import annotations
+
+import logging
+from collections.abc import Iterable
+from contextlib import contextmanager
+from functools import partial
+
+import dask
+import dask.dataframe as dd
+import pandas as pd
+import pyarrow
+from google.cloud import bigquery, bigquery_storage
+
+
+@contextmanager
+def bigquery_client(project_id=None, with_storage_api=False):
+    """This context manager is a temporary solution until there is an
+    upstream solution to handle this.
+    See googleapis/google-cloud-python#9457
+    and googleapis/gapic-generator-python#575 for reference.
+    """
+
+    bq_storage_client = None
+    bq_client = bigquery.Client(project_id)
+    try:
+        if with_storage_api:
+            bq_storage_client = bigquery_storage.BigQueryReadClient(
+                credentials=bq_client._credentials
+            )
+            yield bq_client, bq_storage_client
+        else:
+            yield bq_client
+    finally:
+        bq_client.close()
+
+
+def _stream_to_dfs(bqs_client, stream_name, schema, timeout):
+    """Given a Storage API client and a stream name, return all dataframes."""
+    return [
+        pyarrow.ipc.read_record_batch(
+            pyarrow.py_buffer(message.arrow_record_batch.serialized_record_batch),
+            schema,
+        ).to_pandas()
+        for message in bqs_client.read_rows(name=stream_name, offset=0, timeout=timeout)
+    ]
+
+
+@dask.delayed
+def bigquery_arrow_read(
+    *,
+    make_create_read_session_request: callable,
+    partition_field: str = None,
+    project_id: str,
+    stream_name: str = None,
+    timeout: int,
+) -> pd.DataFrame:
+    """Read a single batch of rows via BQ Storage API, in Arrow binary format.
+
+    Args:
+        project_id: BigQuery project
+        make_create_read_session_request: kwargs to pass to
+            `bqs_client.create_read_session` as `request`
+        partition_field: BigQuery field for partitions, to be used as Dask index col
+            for divisions
+            NOTE: Please set if specifying `row_restriction` filters in TableReadOptions.
+        stream_name: BigQuery Storage API Stream "name".
+            NOTE: Please set if reading from Storage API without any `row_restriction`.
+            https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1beta1#stream
+
+    NOTE: `partition_field` and `stream_name` kwargs are mutually exclusive.
+
+    Adapted from
+    https://github.com/googleapis/python-bigquery-storage/blob/a0fc0af5b4447ce8b50c365d4d081b9443b8490e/google/cloud/bigquery_storage_v1/reader.py.
+    """
+    with bigquery_client(project_id, with_storage_api=True) as (bq_client, bqs_client):
+        session = bqs_client.create_read_session(make_create_read_session_request())
+        schema = pyarrow.ipc.read_schema(
+            pyarrow.py_buffer(session.arrow_schema.serialized_schema)
+        )
+
+        if (partition_field is not None) and (stream_name is not None):
+            raise ValueError(
+                "The kwargs `partition_field` and `stream_name` are mutually exclusive."
+            )
+
+        elif partition_field is not None:
+            shards = [
+                df
+                for stream in session.streams
+                for df in _stream_to_dfs(
+                    bqs_client, stream.name, schema, timeout=timeout
+                )
+            ]
+            # NOTE: if no rows satisfy the row_restriction, then `shards` will be an empty list
+            if len(shards) == 0:
+                shards = [schema.empty_table().to_pandas()]
+            shards = [shard.set_index(partition_field, drop=True) for shard in shards]
+
+        elif stream_name is not None:
+            shards = _stream_to_dfs(bqs_client, stream_name, schema, timeout=timeout)
+            # NOTE: BQ Storage API can return empty streams
+            if len(shards) == 0:
+                shards = [schema.empty_table().to_pandas()]
+
+        else:
+            raise NotImplementedError(
+                "Please specify either `partition_field` or `stream_name`."
+            )
+
+        return pd.concat(shards)
+
+
+def read_gbq(
+    project_id: str,
+    dataset_id: str,
+    table_id: str,
+    partition_field: str = None,
+    partitions: Iterable[str] = None,
+    row_filter="",
+    fields: list[str] = (),
+    read_timeout: int = 3600,
+):
+    """Read table as dask dataframe using BigQuery Storage API via Arrow format.
+
+    If `partition_field` and `partitions` are specified, then the resulting dask dataframe
+    will be partitioned along the same boundaries. Otherwise, partitions will be
+    approximately balanced according to BigQuery stream allocation logic.
+
+    If `partition_field` is specified but not included in `fields` (either implicitly by
+    requesting all fields, or explicitly by omission from the list `fields`), then it will
+    still be included in the query in order to have it available for dask dataframe indexing.
+
+    Args:
+        project_id: BigQuery project
+        dataset_id: BigQuery dataset within project
+        table_id: BigQuery table within dataset
+        partition_field: to specify filters of form "WHERE {partition_field} = ..."
+        partitions: all values to select of `partition_field`
+        fields: names of the fields (columns) to select (default: all fields, "SELECT *")
+        read_timeout: number of seconds an individual read request has before timing out
+
+    Returns:
+        dask dataframe
+
+    See https://github.com/dask/dask/issues/3121 for additional context.
+    """
+    if (partition_field is None) and (partitions is not None):
+        raise ValueError("Specified `partitions` without `partition_field`.")
+
+    # If `partition_field` is not part of the `fields` filter, fetch it anyway to be able
+    # to set it as the dask dataframe index. We want this so that BQ partitioning, dask
+    # divisions, and pandas index values stay consistent.
+    if (partition_field is not None) and fields and (partition_field not in fields):
+        fields = (partition_field, *fields)
+
+    # These read tasks seem to cause deadlocks (or at least workers stuck for a long time,
+    # out of touch with the scheduler), particularly when mixed with other tasks that
+    # execute C code. Anecdotally, annotating the tasks with a higher priority seems to
+    # help (but not fully solve) the issue, at the expense of higher cluster memory usage.
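+    # For example (a sketch only; effectiveness is workload-dependent), a caller could
+    # experiment with this by constructing the dataframe under Dask's annotation context
+    # manager:
+    #
+    #     with dask.annotate(priority=1):
+    #         ddf = read_gbq(...)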
+    with bigquery_client(project_id, with_storage_api=True) as (
+        bq_client,
+        bqs_client,
+    ):
+        table_ref = bq_client.get_table(".".join((dataset_id, table_id)))
+        if table_ref.table_type == "VIEW":
+            raise TypeError("Table type VIEW not supported")
+
+        # The protobuf types can't be pickled (may be able to tweak w/ copyreg), so instead
+        # build the request lazily via a small factory function.
+        def make_create_read_session_request(row_filter=""):
+            return bigquery_storage.types.CreateReadSessionRequest(
+                max_stream_count=100,  # 0 -> use as many streams as BQ Storage will provide
+                parent=f"projects/{project_id}",
+                read_session=bigquery_storage.types.ReadSession(
+                    data_format=bigquery_storage.types.DataFormat.ARROW,
+                    read_options=bigquery_storage.types.ReadSession.TableReadOptions(
+                        row_restriction=row_filter,
+                        selected_fields=fields,
+                    ),
+                    table=table_ref.to_bqstorage(),
+                ),
+            )
+
+        # Create a read session in order to detect the schema.
+        # Read sessions are lightweight and will be auto-deleted after 24 hours.
+        session = bqs_client.create_read_session(
+            make_create_read_session_request(row_filter=row_filter)
+        )
+        schema = pyarrow.ipc.read_schema(
+            pyarrow.py_buffer(session.arrow_schema.serialized_schema)
+        )
+        meta = schema.empty_table().to_pandas()
+        delayed_kwargs = {}
+
+        if partition_field is not None:
+            if row_filter:
+                raise ValueError("Cannot pass both `partition_field` and `row_filter`")
+            delayed_kwargs["meta"] = meta.set_index(partition_field, drop=True)
+
+            if partitions is None:
+                logging.info(
+                    "Specified `partition_field` without `partitions`; reading full table."
+                )
+                partitions = [
+                    p
+                    for p in bq_client.list_partitions(f"{dataset_id}.{table_id}")
+                    if p != "__NULL__"
+                ]
+                # TODO generalize to ranges (as opposed to discrete values)
+
+            partitions = sorted(partitions)
+            delayed_kwargs["divisions"] = (*partitions, partitions[-1])
+            row_filters = [
+                f'{partition_field} = "{partition_value}"'
+                for partition_value in partitions
+            ]
+            delayed_dfs = [
+                bigquery_arrow_read(
+                    make_create_read_session_request=partial(
+                        make_create_read_session_request, row_filter=row_filter
+                    ),
+                    partition_field=partition_field,
+                    project_id=project_id,
+                    timeout=read_timeout,
+                )
+                for row_filter in row_filters
+            ]
+        else:
+            delayed_kwargs["meta"] = meta
+            delayed_dfs = [
+                bigquery_arrow_read(
+                    make_create_read_session_request=make_create_read_session_request,
+                    project_id=project_id,
+                    stream_name=stream.name,
+                    timeout=read_timeout,
+                )
+                for stream in session.streams
+            ]
+
+    return dd.from_delayed(dfs=delayed_dfs, **delayed_kwargs)
diff --git a/dask_bigquery/tests/test_core.py b/dask_bigquery/tests/test_core.py
new file mode 100644
index 0000000..cfce96a
--- /dev/null
+++ b/dask_bigquery/tests/test_core.py
@@ -0,0 +1,97 @@
+import random
+import uuid
+
+import pandas as pd
+import pytest
+from dask.dataframe.utils import assert_eq
+from distributed.utils_test import cluster_fixture  # noqa: F401
+from distributed.utils_test import client, loop  # noqa: F401
+from google.cloud import bigquery
+
+from dask_bigquery import read_gbq
+
+# These tests are run locally and assume the user is already authenticated.
+# They also assume that the user has created a project called dask-bigquery.
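+#
+# As one possible local setup (an assumption, not a requirement imposed by these tests),
+# Application Default Credentials can be configured with
+#
+#     gcloud auth application-default login
+#
+# before running the suite against that project.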
+
+
+@pytest.fixture
+def df():
+    records = [
+        {
+            "name": random.choice(["fred", "wilma", "barney", "betty"]),
+            "number": random.randint(0, 100),
+            "idx": i,
+        }
+        for i in range(10)
+    ]
+
+    yield pd.DataFrame(records)
+
+
+@pytest.fixture
+def dataset(df):
+    "Push some data to BigQuery using pandas-gbq"
+    dataset_id = uuid.uuid4().hex
+    project_id = "dask-bigquery"
+    # push data to gbq
+    pd.DataFrame.to_gbq(
+        df,
+        destination_table=f"{dataset_id}.table_test",
+        project_id=project_id,
+        chunksize=5,
+        if_exists="append",
+    )
+    yield f"{project_id}.{dataset_id}.table_test"
+
+    with bigquery.Client() as bq_client:
+        bq_client.delete_dataset(
+            dataset=f"{project_id}.{dataset_id}",
+            delete_contents=True,
+        )
+
+
+# test simple read
+def test_read_gbq(df, dataset, client):
+    """Test simple read of data pushed to BigQuery using pandas-gbq"""
+    project_id, dataset_id, table_id = dataset.split(".")
+
+    ddf = read_gbq(project_id=project_id, dataset_id=dataset_id, table_id=table_id)
+
+    assert ddf.columns.tolist() == ["name", "number", "idx"]
+    assert len(ddf) == 10
+    assert ddf.npartitions == 2
+
+    assert assert_eq(ddf.set_index("idx"), df.set_index("idx"))
+
+
+# test partitioned data: this test requires a copy of the public dataset
+# bigquery-public-data.covid19_public_forecasts.county_14d into the
+# project dask-bigquery
+
+
+@pytest.mark.parametrize(
+    "fields",
+    ([], ["county_name"], ["county_name", "county_fips_code"]),
+    ids=["no_fields", "missing_partition_field", "fields"],
+)
+def test_read_gbq_partitioning(fields, client):
+    partitions = ["Teton", "Loudoun"]
+    ddf = read_gbq(
+        project_id="dask-bigquery",
+        dataset_id="covid19_public_forecasts",
+        table_id="county_14d",
+        partition_field="county_name",
+        partitions=partitions,
+        fields=fields,
+    )
+
+    assert len(ddf)  # check it's not empty
+    loaded = set(ddf.columns) | {ddf.index.name}
+
+    if fields:
+        assert loaded == set(fields) | {"county_name"}
+    else:  # all columns loaded
+        assert loaded >= set(["county_name", "county_fips_code"])
+
+    assert ddf.npartitions == len(partitions)
+    assert list(ddf.divisions) == sorted(ddf.divisions)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..30f84f1
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,7 @@
+dask
+distributed
+google-cloud-bigquery
+google-cloud-bigquery-storage
+pandas
+pandas-gbq
+pyarrow
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..4921ff4
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,4 @@
+[flake8]
+exclude = __init__.py
+max-line-length = 120
+ignore = F811
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..48525a0
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,21 @@
+#!/usr/bin/env python
+
+from setuptools import setup
+
+with open("README.md", "r", encoding="utf-8") as f:
+    long_description = f.read()
+
+setup(
+    name="dask-bigquery",
+    version="0.0.1",
+    description="Dask + BigQuery integration",
+    license="BSD",
+    packages=["dask_bigquery"],
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    python_requires=">=3.7",
+    install_requires=open("requirements.txt").read().strip().split("\n"),
+    extras_require={"test": ["pytest"]},
+    include_package_data=True,
+    zip_safe=False,
+)
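
For reference, a minimal usage sketch of `read_gbq` (assuming the package is installed and Google Cloud credentials are already configured; the project, dataset, table, and column names below are placeholders, not names from this changeset):

```python
from dask_bigquery import read_gbq

# Read a whole table; Dask partitions follow BigQuery Storage API stream allocation.
ddf = read_gbq(
    project_id="my-project",   # placeholder project
    dataset_id="my_dataset",   # placeholder dataset
    table_id="my_table",       # placeholder table
)

# Read selected columns, with one Dask partition per value of `partition_field`;
# that field also becomes the dataframe index and defines the divisions.
ddf_by_partition = read_gbq(
    project_id="my-project",
    dataset_id="my_dataset",
    table_id="my_table",
    partition_field="partition_col",
    partitions=["value_a", "value_b"],
    fields=["partition_col", "other_col"],
)

result = ddf_by_partition.compute()  # materialize as a pandas DataFrame
```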