feat: read_session optional to ReadRowsStream.rows() (#228)
* feat: `read_session` optional to `ReadRowsStream.rows()`

The schema from the first `ReadRowsResponse` message can be used to decode
messages, instead.

Note: `to_arrow()` and `to_dataframe()` do not work on an empty stream unless a
`read_session` has been passed in, as the schema is not available. This should
not affect `google-cloud-bigquery` and `pandas-gbq`, as those packages use the
lower-level message->dataframe/arrow methods.

* revert change to comment

* use else for empty arrow streams in try-except block

Co-authored-by: Tres Seaver <tseaver@palladion.com>

* update docstring to reflect that `ReadSession` and `ReadRowsResponse` can be used interchangeably

* update arrow deserializer, too

Co-authored-by: Tres Seaver <tseaver@palladion.com>
tswast and tseaver committed Jul 9, 2021
1 parent a8a8c78 commit 4f5602950a0c1959e332aa2964245b9caf4828c8
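
For illustration, here is a minimal sketch of the new behavior described in the commit message (project, dataset, table, and column names are hypothetical; error handling is omitted):

from google.cloud.bigquery_storage_v1 import BigQueryReadClient, types

client = BigQueryReadClient()
requested_session = types.ReadSession(
    table="projects/my-project/datasets/my_dataset/tables/my_table",
    data_format=types.DataFormat.AVRO,
)
session = client.create_read_session(
    parent="projects/my-project",
    read_session=requested_session,
    max_stream_count=1,
)
reader = client.read_rows(session.streams[0].name)

# With this change the schema is taken from the first ReadRowsResponse,
# so read_session no longer has to be passed to rows():
for row in reader.rows():
    print(row["my_column"])

# Passing the session is still supported, and it remains necessary for
# to_arrow()/to_dataframe() to produce a correctly-typed empty result
# when the stream contains no messages:
#     df = reader.to_dataframe(read_session=session)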
Showing with 196 additions and 140 deletions.
  1. +91 −34 google/cloud/bigquery_storage_v1/reader.py
  2. +40 −37 tests/system/conftest.py
  3. +41 −44 tests/unit/test_reader_v1.py
  4. +24 −25 tests/unit/test_reader_v1_arrow.py
google/cloud/bigquery_storage_v1/reader.py
@@ -156,7 +156,7 @@ def _reconnect(self):
read_stream=self._name, offset=self._offset, **self._read_rows_kwargs
)

def rows(self, read_session):
def rows(self, read_session=None):
"""Iterate over all rows in the stream.
This method requires the fastavro library in order to parse row
@@ -169,19 +169,21 @@ def rows(self, read_session):
Args:
read_session ( \
~google.cloud.bigquery_storage_v1.types.ReadSession \
Optional[~google.cloud.bigquery_storage_v1.types.ReadSession] \
):
The read session associated with this read rows stream. This
contains the schema, which is required to parse the data
messages.
DEPRECATED.
This argument was used to specify the schema of the rows in the
stream, but now the first message in a read stream contains
this information.
Returns:
Iterable[Mapping]:
A sequence of rows, represented as dictionaries.
"""
return ReadRowsIterable(self, read_session)
return ReadRowsIterable(self, read_session=read_session)

def to_arrow(self, read_session):
def to_arrow(self, read_session=None):
"""Create a :class:`pyarrow.Table` of all rows in the stream.
This method requires the pyarrow library and a stream using the Arrow
@@ -191,17 +193,19 @@ def to_arrow(self, read_session):
read_session ( \
~google.cloud.bigquery_storage_v1.types.ReadSession \
):
The read session associated with this read rows stream. This
contains the schema, which is required to parse the data
messages.
DEPRECATED.
This argument was used to specify the schema of the rows in the
stream, but now the first message in a read stream contains
this information.
Returns:
pyarrow.Table:
A table of all rows in the stream.
"""
return self.rows(read_session).to_arrow()
return self.rows(read_session=read_session).to_arrow()

def to_dataframe(self, read_session, dtypes=None):
def to_dataframe(self, read_session=None, dtypes=None):
"""Create a :class:`pandas.DataFrame` of all rows in the stream.
This method requires the pandas library to create a data frame and the
@@ -215,9 +219,11 @@ def to_dataframe(self, read_session, dtypes=None):
read_session ( \
~google.cloud.bigquery_storage_v1.types.ReadSession \
):
The read session associated with this read rows stream. This
contains the schema, which is required to parse the data
messages.
DEPRECATED.
This argument was used to specify the schema of the rows in the
stream, but now the first message in a read stream contains
this information.
dtypes ( \
Map[str, Union[str, pandas.Series.dtype]] \
):
@@ -233,7 +239,7 @@ def to_dataframe(self, read_session, dtypes=None):
if pandas is None:
raise ImportError(_PANDAS_REQUIRED)

return self.rows(read_session).to_dataframe(dtypes=dtypes)
return self.rows(read_session=read_session).to_dataframe(dtypes=dtypes)


class ReadRowsIterable(object):
@@ -242,18 +248,25 @@ class ReadRowsIterable(object):
Args:
reader (google.cloud.bigquery_storage_v1.reader.ReadRowsStream):
A read rows stream.
read_session (google.cloud.bigquery_storage_v1.types.ReadSession):
A read session. This is required because it contains the schema
used in the stream messages.
read_session ( \
Optional[~google.cloud.bigquery_storage_v1.types.ReadSession] \
):
DEPRECATED.
This argument was used to specify the schema of the rows in the
stream, but now the first message in a read stream contains
this information.
"""

# This class is modelled after the google.cloud.bigquery.table.RowIterator
# and aims to be API compatible where possible.

def __init__(self, reader, read_session):
def __init__(self, reader, read_session=None):
self._reader = reader
self._read_session = read_session
self._stream_parser = _StreamParser.from_read_session(self._read_session)
if read_session is not None:
self._stream_parser = _StreamParser.from_read_session(read_session)
else:
self._stream_parser = None

@property
def pages(self):
@@ -266,6 +279,10 @@ def pages(self):
# Each page is an iterator of rows. But also has num_items, remaining,
# and to_dataframe.
for message in self._reader:
# Only the first message contains the schema, which is needed to
# decode the messages.
if not self._stream_parser:
self._stream_parser = _StreamParser.from_read_rows_response(message)
yield ReadRowsPage(self._stream_parser, message)

def __iter__(self):
@@ -328,10 +345,11 @@ def to_dataframe(self, dtypes=None):
# pandas dataframe is about 2x faster. This is because pandas.concat is
# rarely no-copy, whereas pyarrow.Table.from_batches + to_pandas is
# usually no-copy.
schema_type = self._read_session._pb.WhichOneof("schema")

if schema_type == "arrow_schema":
try:
record_batch = self.to_arrow()
except NotImplementedError:
pass
else:
df = record_batch.to_pandas()
for column in dtypes:
df[column] = pandas.Series(df[column], dtype=dtypes[column])
@@ -491,6 +509,12 @@ def to_dataframe(self, message, dtypes=None):
def to_rows(self, message):
raise NotImplementedError("Not implemented.")

def _parse_avro_schema(self):
raise NotImplementedError("Not implemented.")

def _parse_arrow_schema(self):
raise NotImplementedError("Not implemented.")

@staticmethod
def from_read_session(read_session):
schema_type = read_session._pb.WhichOneof("schema")
@@ -503,22 +527,38 @@ def from_read_session(read_session):
"Unsupported schema type in read_session: {0}".format(schema_type)
)

@staticmethod
def from_read_rows_response(message):
schema_type = message._pb.WhichOneof("schema")
if schema_type == "avro_schema":
return _AvroStreamParser(message)
elif schema_type == "arrow_schema":
return _ArrowStreamParser(message)
else:
raise TypeError(
"Unsupported schema type in message: {0}".format(schema_type)
)


class _AvroStreamParser(_StreamParser):
"""Helper to parse Avro messages into useful representations."""

def __init__(self, read_session):
def __init__(self, message):
"""Construct an _AvroStreamParser.
Args:
read_session (google.cloud.bigquery_storage_v1.types.ReadSession):
A read session. This is required because it contains the schema
used in the stream messages.
message (Union[
google.cloud.bigquery_storage_v1.types.ReadSession, \
google.cloud.bigquery_storage_v1.types.ReadRowsResponse, \
]):
Either the first message of data from a read rows stream or a
read session. Both types contain a oneof "schema" field, which
can be used to determine how to deserialize rows.
"""
if fastavro is None:
raise ImportError(_FASTAVRO_REQUIRED)

self._read_session = read_session
self._first_message = message
self._avro_schema_json = None
self._fastavro_schema = None
self._column_names = None
@@ -548,6 +588,10 @@ def to_dataframe(self, message, dtypes=None):
strings in the fastavro library.
Args:
message ( \
~google.cloud.bigquery_storage_v1.types.ReadRowsResponse \
):
A message containing Avro bytes to parse into a pandas DataFrame.
dtypes ( \
Map[str, Union[str, pandas.Series.dtype]] \
):
@@ -578,10 +622,11 @@ def _parse_avro_schema(self):
if self._avro_schema_json:
return

self._avro_schema_json = json.loads(self._read_session.avro_schema.schema)
self._avro_schema_json = json.loads(self._first_message.avro_schema.schema)
self._column_names = tuple(
(field["name"] for field in self._avro_schema_json["fields"])
)
self._first_message = None

def _parse_fastavro(self):
"""Convert parsed Avro schema to fastavro format."""
@@ -615,11 +660,22 @@ def to_rows(self, message):


class _ArrowStreamParser(_StreamParser):
def __init__(self, read_session):
def __init__(self, message):
"""Construct an _ArrowStreamParser.
Args:
message (Union[
google.cloud.bigquery_storage_v1.types.ReadSession, \
google.cloud.bigquery_storage_v1.types.ReadRowsResponse, \
]):
Either the first message of data from a read rows stream or a
read session. Both types contain a oneof "schema" field, which
can be used to determine how to deserialize rows.
"""
if pyarrow is None:
raise ImportError(_PYARROW_REQUIRED)

self._read_session = read_session
self._first_message = message
self._schema = None

def to_arrow(self, message):
@@ -659,6 +715,7 @@ def _parse_arrow_schema(self):
return

self._schema = pyarrow.ipc.read_schema(
pyarrow.py_buffer(self._read_session.arrow_schema.serialized_schema)
pyarrow.py_buffer(self._first_message.arrow_schema.serialized_schema)
)
self._column_names = [field.name for field in self._schema]
self._first_message = None
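
As the commit message notes, google-cloud-bigquery and pandas-gbq rely on the lower-level per-message path rather than calling to_dataframe() on the whole stream. A rough sketch of that pattern (reusing the hypothetical client and session from the earlier example; this is not taken from either library's actual code):

import pandas

reader = client.read_rows(session.streams[0].name)
frames = []
for page in reader.rows().pages:
    # Each ReadRowsPage wraps one ReadRowsResponse; when no read_session is
    # supplied, the iterable builds its schema parser from the first response.
    frames.append(page.to_dataframe())

# An empty stream simply yields no pages, so the caller decides what an
# empty result looks like without needing a schema from a read session.
df = pandas.concat(frames, ignore_index=True) if frames else pandas.DataFrame()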
tests/system/conftest.py
@@ -18,13 +18,41 @@
import os
import uuid

import google.auth
from google.cloud import bigquery
import pytest
import test_utils.prefixer

from . import helpers


prefixer = test_utils.prefixer.Prefixer("python-bigquery-storage", "tests/system")


_TABLE_FORMAT = "projects/{}/datasets/{}/tables/{}"
_ASSETS_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), "assets")
_ALL_TYPES_SCHEMA = [
bigquery.SchemaField("string_field", "STRING"),
bigquery.SchemaField("bytes_field", "BYTES"),
bigquery.SchemaField("int64_field", "INT64"),
bigquery.SchemaField("float64_field", "FLOAT64"),
bigquery.SchemaField("numeric_field", "NUMERIC"),
bigquery.SchemaField("bool_field", "BOOL"),
bigquery.SchemaField("geography_field", "GEOGRAPHY"),
bigquery.SchemaField(
"person_struct_field",
"STRUCT",
fields=(
bigquery.SchemaField("name", "STRING"),
bigquery.SchemaField("age", "INT64"),
),
),
bigquery.SchemaField("timestamp_field", "TIMESTAMP"),
bigquery.SchemaField("date_field", "DATE"),
bigquery.SchemaField("time_field", "TIME"),
bigquery.SchemaField("datetime_field", "DATETIME"),
bigquery.SchemaField("string_array_field", "STRING", mode="REPEATED"),
]


@pytest.fixture(scope="session")
@@ -38,18 +66,9 @@ def use_mtls():


@pytest.fixture(scope="session")
def credentials(use_mtls):
import google.auth
from google.oauth2 import service_account

if use_mtls:
# mTLS test uses user credentials instead of service account credentials
creds, _ = google.auth.default()
return creds

# NOTE: the test config in noxfile checks that the env variable is indeed set
filename = os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
return service_account.Credentials.from_service_account_file(filename)
def credentials():
creds, _ = google.auth.default()
return creds


@pytest.fixture()
@@ -77,8 +96,7 @@ def local_shakespeare_table_reference(project_id, use_mtls):
def dataset(project_id, bq_client):
from google.cloud import bigquery

unique_suffix = str(uuid.uuid4()).replace("-", "_")
dataset_name = "bq_storage_system_tests_" + unique_suffix
dataset_name = prefixer.create_prefix()

dataset_id = "{}.{}".format(project_id, dataset_name)
dataset = bigquery.Dataset(dataset_id)
@@ -120,35 +138,20 @@ def bq_client(credentials, use_mtls):
return bigquery.Client(credentials=credentials)


@pytest.fixture(scope="session", autouse=True)
def cleanup_datasets(bq_client: bigquery.Client):
for dataset in bq_client.list_datasets():
if prefixer.should_cleanup(dataset.dataset_id):
bq_client.delete_dataset(dataset, delete_contents=True, not_found_ok=True)


@pytest.fixture
def all_types_table_ref(project_id, dataset, bq_client):
from google.cloud import bigquery

schema = [
bigquery.SchemaField("string_field", "STRING"),
bigquery.SchemaField("bytes_field", "BYTES"),
bigquery.SchemaField("int64_field", "INT64"),
bigquery.SchemaField("float64_field", "FLOAT64"),
bigquery.SchemaField("numeric_field", "NUMERIC"),
bigquery.SchemaField("bool_field", "BOOL"),
bigquery.SchemaField("geography_field", "GEOGRAPHY"),
bigquery.SchemaField(
"person_struct_field",
"STRUCT",
fields=(
bigquery.SchemaField("name", "STRING"),
bigquery.SchemaField("age", "INT64"),
),
),
bigquery.SchemaField("timestamp_field", "TIMESTAMP"),
bigquery.SchemaField("date_field", "DATE"),
bigquery.SchemaField("time_field", "TIME"),
bigquery.SchemaField("datetime_field", "DATETIME"),
bigquery.SchemaField("string_array_field", "STRING", mode="REPEATED"),
]
bq_table = bigquery.table.Table(
table_ref="{}.{}.complex_records".format(project_id, dataset.dataset_id),
schema=schema,
schema=_ALL_TYPES_SCHEMA,
)

created_table = bq_client.create_table(bq_table)

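The conftest changes above replace per-run UUID dataset names with a shared prefix so that resources left behind by crashed runs can be swept up. A condensed sketch of that pattern, assuming test_utils.prefixer behaves as it is used in this diff:

import test_utils.prefixer
from google.cloud import bigquery

prefixer = test_utils.prefixer.Prefixer("python-bigquery-storage", "tests/system")

def cleanup_stale_datasets(bq_client: bigquery.Client) -> None:
    # Delete only datasets whose names carry this suite's prefix, i.e.
    # leftovers from earlier or aborted test runs.
    for dataset in bq_client.list_datasets():
        if prefixer.should_cleanup(dataset.dataset_id):
            bq_client.delete_dataset(dataset, delete_contents=True, not_found_ok=True)

# New datasets are created under a fresh prefixed name, as the dataset
# fixture above does:
dataset_name = prefixer.create_prefix()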