Initial read_gbq implementation (WIP) #1

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .github/workflows/pre-commit.yml
@@ -0,0 +1,16 @@
name: Linting

on:
  push:
    branches: main
  pull_request:
    branches: main

jobs:
  checks:
    name: "pre-commit hooks"
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: actions/setup-python@v2
      - uses: pre-commit/action@v2.0.0
17 changes: 17 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,17 @@
repos:
  - repo: https://github.com/psf/black
    rev: 20.8b1
    hooks:
      - id: black
        language_version: python3
        exclude: versioneer.py
  - repo: https://gitlab.com/pycqa/flake8
    rev: 3.8.3
    hooks:
      - id: flake8
        language_version: python3
  - repo: https://github.com/pycqa/isort
    rev: 5.8.0
    hooks:
      - id: isort
        language_version: python3
1 change: 1 addition & 0 deletions dask_bigquery/__init__.py
@@ -0,0 +1 @@
from .core import read_gbq
232 changes: 232 additions & 0 deletions dask_bigquery/core.py
@@ -0,0 +1,232 @@
from __future__ import annotations

import logging
from collections.abc import Iterable
from contextlib import contextmanager
from functools import partial

import dask
import dask.dataframe as dd
import pandas as pd
import pyarrow
from google.cloud import bigquery, bigquery_storage


@contextmanager
def bigquery_client(project_id=None, with_storage_api=False):
"""This context manager is a temporary solution until there is an
upstream solution to handle this.
See googleapis/google-cloud-python#9457
and googleapis/gapic-generator-python#575 for reference.
"""

bq_storage_client = None
bq_client = bigquery.Client(project_id)
try:
if with_storage_api:
bq_storage_client = bigquery_storage.BigQueryReadClient(
credentials=bq_client._credentials
)
yield bq_client, bq_storage_client
else:
yield bq_client
finally:
bq_client.close()


def _stream_to_dfs(bqs_client, stream_name, schema, timeout):
"""Given a Storage API client and a stream name, yield all dataframes."""
return [
pyarrow.ipc.read_record_batch(
pyarrow.py_buffer(message.arrow_record_batch.serialized_record_batch),
schema,
).to_pandas()
for message in bqs_client.read_rows(name=stream_name, offset=0, timeout=timeout)
]


@dask.delayed
def _read_rows_arrow(

Review comment (Member): If we use delayed (see comment below) then this name will be the name that shows up in the task stream, progress bars, etc. We may want to make it more clearly GBQ-related, like bigquery_read (see the naming sketch right after this function).

    *,
    make_create_read_session_request: callable,
    partition_field: str = None,
    project_id: str,
    stream_name: str = None,
    timeout: int,
) -> pd.DataFrame:
    """Read a single batch of rows via BQ Storage API, in Arrow binary format.

    Args:
        project_id: BigQuery project
        make_create_read_session_request: kwargs to pass to `bqs_client.create_read_session`
            as `request`
        partition_field: BigQuery field for partitions, to be used as Dask index col for
            divisions
            NOTE: Please set if specifying `row_restriction` filters in TableReadOptions.
        stream_name: BigQuery Storage API Stream "name"
            NOTE: Please set if reading from Storage API without any `row_restriction`.
            https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1beta1#stream

    NOTE: `partition_field` and `stream_name` kwargs are mutually exclusive.

    Adapted from
    https://github.com/googleapis/python-bigquery-storage/blob/a0fc0af5b4447ce8b50c365d4d081b9443b8490e/google/cloud/bigquery_storage_v1/reader.py.
    """
    with bigquery_client(project_id, with_storage_api=True) as (bq_client, bqs_client):
        session = bqs_client.create_read_session(make_create_read_session_request())
        schema = pyarrow.ipc.read_schema(
            pyarrow.py_buffer(session.arrow_schema.serialized_schema)
        )

        if (partition_field is not None) and (stream_name is not None):
            raise ValueError(
                "The kwargs `partition_field` and `stream_name` are mutually exclusive."
            )

        elif partition_field is not None:
            shards = [
                df
                for stream in session.streams
                for df in _stream_to_dfs(
                    bqs_client, stream.name, schema, timeout=timeout
                )
            ]
            # NOTE: if no rows satisfy the row_restriction, then `shards` will be an empty list
            if len(shards) == 0:
                shards = [schema.empty_table().to_pandas()]
            shards = [shard.set_index(partition_field, drop=True) for shard in shards]

        elif stream_name is not None:
            shards = _stream_to_dfs(bqs_client, stream_name, schema, timeout=timeout)
            # NOTE: BQ Storage API can return empty streams
            if len(shards) == 0:
                shards = [schema.empty_table().to_pandas()]

        else:
            raise NotImplementedError(
                "Please specify either `partition_field` or `stream_name`."
            )

        return pd.concat(shards)
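
Following up on the review comment attached to the signature above: one low-effort way to get GBQ-flavored names into the task stream, without renaming the function, is to pass dask_key_name when calling the delayed function. A minimal standalone sketch (bigquery_read below is a hypothetical stand-in for _read_rows_arrow, not part of this PR):

import dask


@dask.delayed
def bigquery_read(stream_name):
    # hypothetical stand-in for _read_rows_arrow; a real implementation would
    # read the given BQ Storage stream here
    return stream_name


# dask_key_name overrides the auto-generated key, so the task stream, progress
# bars, etc. show "bigquery-read-<i>" instead of the generic function name
tasks = [
    bigquery_read(f"stream-{i}", dask_key_name=f"bigquery-read-{i}") for i in range(3)
]
print([t.key for t in tasks])  # ['bigquery-read-0', 'bigquery-read-1', 'bigquery-read-2']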


def read_gbq(
    project_id: str,
    dataset_id: str,
    table_id: str,
    partition_field: str = None,
    partitions: Iterable[str] = None,
    row_filter="",
    fields: list[str] = (),
    read_timeout: int = 3600,
):
    """Read a table as a Dask dataframe using the BigQuery Storage API via Arrow format.

    If `partition_field` and `partitions` are specified, then the resulting Dask dataframe
    will be partitioned along the same boundaries. Otherwise, partitions will be approximately
    balanced according to BigQuery stream allocation logic.

    If `partition_field` is specified but not included in `fields`, it will still be included
    in the query so that it is available for Dask dataframe indexing (when `fields` is empty,
    all fields, including `partition_field`, are selected anyway).

    Args:
        project_id: BigQuery project
        dataset_id: BigQuery dataset within project
        table_id: BigQuery table within dataset
        partition_field: to specify filters of form "WHERE {partition_field} = ..."
        partitions: all values to select of `partition_field`
        row_filter: SQL text filtering statement, passed to the Storage API as a `row_restriction`
        fields: names of the fields (columns) to select (default selects all fields, i.e. "SELECT *")
        read_timeout: number of seconds an individual read request has before timing out

    Returns:
        Dask dataframe

    See https://github.com/dask/dask/issues/3121 for additional context.
    """
    if (partition_field is None) and (partitions is not None):
        raise ValueError("Specified `partitions` without `partition_field`.")

    # If `partition_field` is not part of the `fields` filter, fetch it anyway so it can be
    # set as the Dask dataframe index. We want this in order to keep consistent:
    # BQ partitioning + Dask divisions + pandas index values
    if (partition_field is not None) and fields and (partition_field not in fields):
        fields = (partition_field, *fields)

    # These read tasks seem to cause deadlocks (or at least long-stuck workers out of touch with
    # the scheduler), particularly when mixed with other tasks that execute C code. Anecdotally,
    # annotating the tasks with a higher priority seems to help (but not fully solve) the issue,
    # at the expense of higher cluster memory usage.

Review comment (Contributor): This annotate is maybe a bad idea; it would be nice to have @jrbourbeau or someone weigh in. Note that we observed this behavior with now-fairly-old dask and bigquery_storage/pyarrow versions, so I have no idea if it's still relevant. (A small illustration of what dask.annotate records on the task graph follows this file's diff.)

    with bigquery_client(project_id, with_storage_api=True) as (
        bq_client,
        bqs_client,
    ), dask.annotate(priority=1):

Review comment (Member): We would definitely prefer to not have this annotation if possible. Data generation tasks should be de-prioritized if anything.

        table_ref = bq_client.get_table(".".join((dataset_id, table_id)))
        if table_ref.table_type == "VIEW":
            raise TypeError("Table type VIEW not supported")

        # The protobuf types can't be pickled (may be able to tweak w/ copyreg), so instead use a
        # generator func.
        def make_create_read_session_request(row_filter=""):
            return bigquery_storage.types.CreateReadSessionRequest(
                max_stream_count=100,  # 0 -> use as many streams as BQ Storage will provide
                parent=f"projects/{project_id}",
                read_session=bigquery_storage.types.ReadSession(
                    data_format=bigquery_storage.types.DataFormat.ARROW,
                    read_options=bigquery_storage.types.ReadSession.TableReadOptions(
                        row_restriction=row_filter,
                        selected_fields=fields,
                    ),
                    table=table_ref.to_bqstorage(),
                ),
            )

        # Create a read session in order to detect the schema.
        # Read sessions are lightweight and will be auto-deleted after 24 hours.
        session = bqs_client.create_read_session(
            make_create_read_session_request(row_filter=row_filter)
        )
        schema = pyarrow.ipc.read_schema(
            pyarrow.py_buffer(session.arrow_schema.serialized_schema)
        )
        meta = schema.empty_table().to_pandas()
        delayed_kwargs = dict(prefix=f"{dataset_id}.{table_id}-")

        if partition_field is not None:
            if row_filter:
                raise ValueError("Cannot pass both `partition_field` and `row_filter`")
            delayed_kwargs["meta"] = meta.set_index(partition_field, drop=True)

            if partitions is None:
                logging.info(
                    "Specified `partition_field` without `partitions`; reading full table."
                )
                partitions = [
                    p
                    for p in bq_client.list_partitions(f"{dataset_id}.{table_id}")
                    if p != "__NULL__"
                ]
                # TODO generalize to ranges (as opposed to discrete values)

            partitions = sorted(partitions)
            delayed_kwargs["divisions"] = (*partitions, partitions[-1])

Review comment (Contributor Author): @bnaul I noticed in the example I ran that this line causes the last partition to contain only 1 element, even though that element could have fit into the previous partition. What is the reason you separate the last partition?

Review comment (Contributor, @bnaul, Aug 11, 2021): Not sure why it's not working correctly for you, but the idea is that you need n+1 divisions for n partitions (each division is a partition boundary, so a final value is needed to close the last partition). It seems to work OK here:

import numpy as np
import pandas as pd

import dask.dataframe as dd
from dask import delayed

@delayed
def make_df(d):
    return pd.DataFrame({"date": d, "x": np.random.random(10)}).set_index("date")

dates = pd.date_range("2020-01-01", "2020-01-08")
ddf = dd.from_delayed([make_df(d) for d in dates], divisions=[*dates, dates[-1]])

ddf
Out[61]:
Dask DataFrame Structure:
                     x
npartitions=8
2020-01-01     float64
2020-01-02         ...
...                ...
2020-01-08         ...
2020-01-08         ...
Dask Name: from-delayed, 16 tasks


ddf.map_partitions(len).compute()
Out[62]:
0    10
1    10
2    10
3    10
4    10
5    10
6    10
7    10
dtype: int64

Review comment (Contributor Author): I wonder if it's related to how the data was originally partitioned. For example, when I read one of the tables of the covid public dataset that I copied into "my_project", I see this:

from dask_bigquery import read_gbq

ddf = read_gbq(
    project_id="my_project",
    dataset_id="covid19_public_forecasts",
    table_id="county_14d",
)

ddf.map_partitions(len).compute()

Notice the last two partitions...

0     3164
1     3164
2     3164
3     3164
4     3164
5     3164
6     3164
7     3164
8     3164
9     3164
10    3164
11    3164
12    3164
13    3164
14    3164
15    3164
16    3164
17    3164
18    3164
19    3164
20    3164
21    3164
22    3164
23    3164
24    3164
25    3164
26    3164
27    3164
28    3164
29    3164
30    3164
31    3164
32    3164
33    3164
34    3164
35    3164
36    3164
37    3164
38    3164
39    3164
40    3164
41    3163
42       1
dtype: int64

            row_filters = [
                f'{partition_field} = "{partition_value}"'
                for partition_value in partitions
            ]
            delayed_dfs = [
                _read_rows_arrow(
                    make_create_read_session_request=partial(
                        make_create_read_session_request, row_filter=row_filter
                    ),
                    partition_field=partition_field,
                    project_id=project_id,
                    timeout=read_timeout,
                )
                for row_filter in row_filters
            ]

Review comment (Member): This is great for now, but at some point we may want to use raw task graphs. They're a bit cleaner in a few ways. Delayed is more designed for user code; if we have the time, we prefer to use raw graphs in dev code.

For example, in some cases I wouldn't be surprised if each Delayed task produces a single TaskGroup, rather than having all of the tasks in a single TaskGroup. Sure, this will compute just fine, but other features (like the task group visualization, or Coiled telemetry) may be sad.

Review comment (Contributor Author): @jrbourbeau and I gave HighLevelGraphs a try, and we realized that this will require modifying the structure of the function _read_rows_arrow, since as it is now the inputs don't match the format required by DataFrameIOLayer:
https://github.com/dask/dask/blob/95fb60a31a87c6b94b01ed75ab6533fa04d51f19/dask/layers.py#L1159-L1166

We might want to move this to a separate PR (a rough sketch of the raw-graph approach follows this file's diff).

        else:
            delayed_kwargs["meta"] = meta
            delayed_dfs = [
                _read_rows_arrow(
                    make_create_read_session_request=make_create_read_session_request,
                    project_id=project_id,
                    stream_name=stream.name,
                    timeout=read_timeout,
                )
                for stream in session.streams
            ]

    return dd.from_delayed(dfs=delayed_dfs, **delayed_kwargs)
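
As referenced in the review discussion above about raw task graphs: a rough, hedged sketch (not part of this PR) of how the same collection could be assembled from a plain task graph instead of delayed objects. bigquery_read is a hypothetical positional-argument variant of _read_rows_arrow; the dask APIs used (tokenize, HighLevelGraph.from_collections, new_dd_object) are public, but the overall structure here is only illustrative.

from dask.base import tokenize
from dask.dataframe.core import new_dd_object
from dask.highlevelgraph import HighLevelGraph


def bigquery_read(make_request, project_id, stream_name, timeout):
    # hypothetical positional-argument variant of _read_rows_arrow; a real
    # implementation would read the stream and return a pandas DataFrame
    ...


def collection_from_streams(streams, meta, project_id, make_request, timeout):
    # one task per BQ Storage stream, all living in a single graph layer
    name = "read-gbq-" + tokenize(project_id, [stream.name for stream in streams])
    dsk = {
        (name, i): (bigquery_read, make_request, project_id, stream.name, timeout)
        for i, stream in enumerate(streams)
    }
    graph = HighLevelGraph.from_collections(name, dsk, dependencies=[])
    divisions = [None] * (len(streams) + 1)  # divisions unknown for stream-based reads
    return new_dd_object(graph, name, meta, divisions)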
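
And for the annotation discussion earlier in this file: a tiny standalone illustration (again, not part of the diff) of what dask.annotate actually does. It records metadata such as priority on every task-graph layer created inside the context, which the distributed scheduler can then take into account; whether that mitigates the deadlock behavior described in the code comment is, per the review thread, an open question.

import dask

with dask.annotate(priority=1):
    x = dask.delayed(sum)([1, 2, 3])

# the annotation is stored on the HighLevelGraph layers created inside the context
print([layer.annotations for layer in x.__dask_graph__().layers.values()])
# expected on recent dask versions: [{'priority': 1}]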
98 changes: 98 additions & 0 deletions dask_bigquery/tests/test_core.py
@@ -0,0 +1,98 @@
import random

import pandas as pd
import pytest
from dask.dataframe.utils import assert_eq
from distributed.utils_test import cluster_fixture  # noqa: F401
from distributed.utils_test import client, loop  # noqa: F401
from google.cloud import bigquery

from dask_bigquery import read_gbq

# These tests are run locally and assume the user is already authenticated.
# They also assume that the user has created a project called dask-bigquery.


@pytest.fixture
def df():
    records = [
        {
            "name": random.choice(["fred", "wilma", "barney", "betty"]),
            "number": random.randint(0, 100),
            "idx": i,
        }
        for i in range(10)
    ]

    yield pd.DataFrame(records)


@pytest.fixture
def dataset(df):
    "Push some data to BigQuery using pandas-gbq"

    with bigquery.Client() as bq_client:
        try:
            bq_client.delete_dataset(
                dataset="dask-bigquery.dataset_test",
                delete_contents=True,
            )
        except:  # noqa: E722
            pass

    # push data to gbq
    pd.DataFrame.to_gbq(
        df,
        destination_table="dataset_test.table_test",
        project_id="dask-bigquery",
        chunksize=5,
        if_exists="append",
    )
    yield "dask-bigquery.dataset_test.table_test"


# test simple read
def test_read_gbq(df, dataset, client):
    """Test simple read of data pushed to BigQuery using pandas-gbq"""
    project_id, dataset_id, table_id = dataset.split(".")

    ddf = read_gbq(project_id=project_id, dataset_id=dataset_id, table_id=table_id)

    assert ddf.columns.tolist() == ["name", "number", "idx"]
    assert len(ddf) == 10
    assert ddf.npartitions == 2

Review comment (Member): Can we verify that the data is actually the same as the data created in push_data instead? Pushing the fixture idea a little further:

@pytest.fixture
def df():
    ...

@pytest.fixture
def dataset(df):
    ...

def test_read_gbq(client, dataset, df):
    ddf = read_gbq(...)

    assert_eq(ddf, df)

Maybe there are sorting issues that get in the way (is GBQ ordered?). If so, then, as you did before:

assert_eq(ddf.set_index("idx"), df.set_index("idx"))

In general we want to use assert_eq if possible. It runs lots of cleanliness checks on the Dask collection, graph, metadata, and so on.

Review comment (Contributor Author, @ncclementi, Aug 18, 2021): Yes, I get some ordering issues when reading back from GBQ, mainly because the default index on the way back goes from 0 to chunksize-1, where chunksize was chosen when I pushed the pandas dataframe. This was part of the reason I had "idx" as an extra column.

But thanks for pointing out assert_eq; I forgot we had that. Although, in order for that line to work, I had to call compute on the Dask dataframe. I'm assuming this is because I'm comparing a Dask dataframe with a pandas one?

assert_eq(ddf.set_index("idx").compute(), df.set_index("idx"))

Review comment (Member): assert_eq should handle comparing Dask and pandas DataFrames. What error do you get without the compute()?

Review comment (Contributor Author): I get this:

____________________________________________________________________ test_read_gbq _____________________________________________________________________

df =      name  number  idx
0   betty      71    0
1    fred      36    1
2   wilma      75    2
3   betty      13    3
4  ...   4
5    fred      74    5
6   wilma      69    6
7    fred      31    7
8  barney      31    8
9   betty      97    9
dataset = 'dask-bigquery.dataset_test.table_test', client = <Client: 'tcp://127.0.0.1:55212' processes=2 threads=2, memory=32.00 GiB>

    def test_read_gbq(df, dataset, client):
        """Test simple read of data pushed to BigQuery using pandas-gbq"""
        project_id, dataset_id, table_id = dataset.split(".")
    
        ddf = read_gbq(
            project_id=project_id, dataset_id=dataset_id, table_id=table_id
        )
    
        assert ddf.columns.tolist() == ["name", "number", "idx"]
        assert len(ddf) == 10
        assert ddf.npartitions == 2
        #breakpoint()
>       assert assert_eq(ddf.set_index("idx"), df.set_index("idx"))

test_core.py:67: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../../../../mambaforge/envs/test_gbq/lib/python3.8/site-packages/dask/dataframe/utils.py:541: in assert_eq
    assert_sane_keynames(a)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

ddf = Dask DataFrame Structure:
                 name number
npartitions=2               
0              object  int64
4                 ...    ...
9                 ...    ...
Dask Name: sort_index, 22 tasks

    def assert_sane_keynames(ddf):
        if not hasattr(ddf, "dask"):
            return
        for k in ddf.dask.keys():
            while isinstance(k, tuple):
                k = k[0]
            assert isinstance(k, (str, bytes))
            assert len(k) < 100
            assert " " not in k
>           assert k.split("-")[0].isidentifier()
E           AssertionError

../../../../../mambaforge/envs/test_gbq/lib/python3.8/site-packages/dask/dataframe/utils.py:621: AssertionError

Review comment (Member): Oh, interesting! Do you have any additional information in the traceback? I'm wondering what k is. You could try adding --pdb to the end of your pytest command, which will drop you into a pdb session at the point the test raises an error. You can then do pp k to see what k is.

Review comment (Member): Also xref dask/dask#8061.

Review comment (Contributor Author): There is no additional information in the traceback, but pp k returns
'dataset_test.table_test--46e9ff0148164adf1b543e44137043cd'

Review comment (Member): This is coming from the

delayed_kwargs = dict(prefix=f"{dataset_id}.{table_id}-")

line earlier in this function. Ultimately we'll want to move away from delayed and construct the Dask graph ourselves, so for now I think it's okay to drop the prefix= here and use

delayed_kwargs = {}

instead. That should allow you to also drop the compute() call in assert_eq (a short standalone assert_eq example follows this test file's diff).


    assert assert_eq(ddf.set_index("idx").compute(), df.set_index("idx"))


# test partitioned data: this test requires a copy of the public dataset
# bigquery-public-data.covid19_public_forecasts.county_14d into the
# project dask-bigquery


@pytest.mark.parametrize(
    "fields",
    ([], ["county_name"], ["county_name", "county_fips_code"]),
    ids=["no_fields", "missing_partition_field", "fields"],
)
def test_read_gbq_partitioning(fields, client):
    partitions = ["Teton", "Loudoun"]
    ddf = read_gbq(
        project_id="dask-bigquery",
        dataset_id="covid19_public_forecasts",
        table_id="county_14d",
        partition_field="county_name",
        partitions=partitions,
        fields=fields,
    )

    assert len(ddf)  # check it's not empty
    loaded = set(ddf.columns) | {ddf.index.name}

    if fields:
        assert loaded == set(fields) | {"county_name"}
    else:  # all columns loaded
        assert loaded >= set(["county_name", "county_fips_code"])

    assert ddf.npartitions == len(partitions)
    assert list(ddf.divisions) == sorted(ddf.divisions)
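
As a coda to the assert_eq discussion above: a tiny standalone example (toy data, no BigQuery involved) showing that assert_eq can compare a Dask dataframe directly against its pandas source without calling compute(), which is what dropping the prefix= keyword from delayed_kwargs unlocks.

import pandas as pd

import dask.dataframe as dd
from dask.dataframe.utils import assert_eq

pdf = pd.DataFrame({"idx": range(10), "number": range(10, 20)})
ddf = dd.from_pandas(pdf, npartitions=2)

# assert_eq computes the Dask collection itself and also runs sanity checks on
# the graph, key names, metadata, and divisions
assert_eq(ddf.set_index("idx"), pdf.set_index("idx"))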
7 changes: 7 additions & 0 deletions requirements.txt
@@ -0,0 +1,7 @@
dask
distributed
google-cloud-bigquery
google-cloud-bigquery-storage
pandas
pandas-gbq
pyarrow
4 changes: 4 additions & 0 deletions setup.cfg
@@ -0,0 +1,4 @@
[flake8]
exclude = __init__.py
max-line-length = 120
ignore = F811