Skip to content

Commit

Permalink
test(bigquery): add tests for concatenating categorical columns (#10180)
Browse files Browse the repository at this point in the history
  • Loading branch information
plamut committed Jan 31, 2020
1 parent 193a1dd commit 77dd923
Showing 1 changed file with 168 additions and 0 deletions.
168 changes: 168 additions & 0 deletions bigquery/tests/unit/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3242,6 +3242,174 @@ def test_to_dataframe_w_bqstorage_snapshot(self):
with pytest.raises(ValueError):
row_iterator.to_dataframe(bqstorage_client)

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(
bigquery_storage_v1beta1 is None, "Requires `google-cloud-bigquery-storage`"
)
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_to_dataframe_concat_categorical_dtype_w_pyarrow(self):
from google.cloud.bigquery import schema
from google.cloud.bigquery import table as mut
from google.cloud.bigquery_storage_v1beta1 import reader

arrow_fields = [
# Not alphabetical to test column order.
pyarrow.field("col_str", pyarrow.utf8()),
# The backend returns strings, and without other info, pyarrow contains
# string data in categorical columns, too (and not maybe the Dictionary
# type that corresponds to pandas.Categorical).
pyarrow.field("col_category", pyarrow.utf8()),
]
arrow_schema = pyarrow.schema(arrow_fields)

# create a mock BQ storage client
bqstorage_client = mock.create_autospec(
bigquery_storage_v1beta1.BigQueryStorageClient
)
bqstorage_client.transport = mock.create_autospec(
big_query_storage_grpc_transport.BigQueryStorageGrpcTransport
)
session = bigquery_storage_v1beta1.types.ReadSession(
streams=[{"name": "/projects/proj/dataset/dset/tables/tbl/streams/1234"}],
arrow_schema={"serialized_schema": arrow_schema.serialize().to_pybytes()},
)
bqstorage_client.create_read_session.return_value = session

mock_rowstream = mock.create_autospec(reader.ReadRowsStream)
bqstorage_client.read_rows.return_value = mock_rowstream

# prepare the iterator over mocked rows
mock_rows = mock.create_autospec(reader.ReadRowsIterable)
mock_rowstream.rows.return_value = mock_rows
page_items = [
[
pyarrow.array(["foo", "bar", "baz"]), # col_str
pyarrow.array(["low", "medium", "low"]), # col_category
],
[
pyarrow.array(["foo_page2", "bar_page2", "baz_page2"]), # col_str
pyarrow.array(["medium", "high", "low"]), # col_category
],
]

mock_pages = []

for record_list in page_items:
page_record_batch = pyarrow.RecordBatch.from_arrays(
record_list, schema=arrow_schema
)
mock_page = mock.create_autospec(reader.ReadRowsPage)
mock_page.to_arrow.return_value = page_record_batch
mock_pages.append(mock_page)

type(mock_rows).pages = mock.PropertyMock(return_value=mock_pages)

schema = [
schema.SchemaField("col_str", "IGNORED"),
schema.SchemaField("col_category", "IGNORED"),
]

row_iterator = mut.RowIterator(
_mock_client(),
None, # api_request: ignored
None, # path: ignored
schema,
table=mut.TableReference.from_string("proj.dset.tbl"),
selected_fields=schema,
)

# run the method under test
got = row_iterator.to_dataframe(
bqstorage_client=bqstorage_client,
dtypes={
"col_category": pandas.core.dtypes.dtypes.CategoricalDtype(
categories=["low", "medium", "high"], ordered=False,
),
},
)

# Are the columns in the expected order?
column_names = ["col_str", "col_category"]
self.assertEqual(list(got), column_names)

# Have expected number of rows?
total_pages = len(mock_pages) # we have a single stream, thus these two equal
total_rows = len(page_items[0][0]) * total_pages
self.assertEqual(len(got.index), total_rows)

# Are column types correct?
expected_dtypes = [
pandas.core.dtypes.dtypes.np.dtype("O"), # the default for string data
pandas.core.dtypes.dtypes.CategoricalDtype(
categories=["low", "medium", "high"], ordered=False,
),
]
self.assertEqual(list(got.dtypes), expected_dtypes)

# And the data in the categorical column?
self.assertEqual(
list(got["col_category"]),
["low", "medium", "low", "medium", "high", "low"],
)

# Don't close the client if it was passed in.
bqstorage_client.transport.channel.close.assert_not_called()

@unittest.skipIf(pandas is None, "Requires `pandas`")
def test_to_dataframe_concat_categorical_dtype_wo_pyarrow(self):
from google.cloud.bigquery.schema import SchemaField

schema = [
SchemaField("col_str", "STRING"),
SchemaField("col_category", "STRING"),
]
row_data = [
[u"foo", u"low"],
[u"bar", u"medium"],
[u"baz", u"low"],
[u"foo_page2", u"medium"],
[u"bar_page2", u"high"],
[u"baz_page2", u"low"],
]
path = "/foo"

rows = [{"f": [{"v": field} for field in row]} for row in row_data[:3]]
rows_page2 = [{"f": [{"v": field} for field in row]} for row in row_data[3:]]
api_request = mock.Mock(
side_effect=[{"rows": rows, "pageToken": "NEXTPAGE"}, {"rows": rows_page2}]
)

row_iterator = self._make_one(_mock_client(), api_request, path, schema)

with mock.patch("google.cloud.bigquery.table.pyarrow", None):
got = row_iterator.to_dataframe(
dtypes={
"col_category": pandas.core.dtypes.dtypes.CategoricalDtype(
categories=["low", "medium", "high"], ordered=False,
),
},
)

self.assertIsInstance(got, pandas.DataFrame)
self.assertEqual(len(got), 6) # verify the number of rows
expected_columns = [field.name for field in schema]
self.assertEqual(list(got), expected_columns) # verify the column names

# Are column types correct?
expected_dtypes = [
pandas.core.dtypes.dtypes.np.dtype("O"), # the default for string data
pandas.core.dtypes.dtypes.CategoricalDtype(
categories=["low", "medium", "high"], ordered=False,
),
]
self.assertEqual(list(got.dtypes), expected_dtypes)

# And the data in the categorical column?
self.assertEqual(
list(got["col_category"]),
["low", "medium", "low", "medium", "high", "low"],
)


class TestPartitionRange(unittest.TestCase):
def _get_target_class(self):
Expand Down

0 comments on commit 77dd923

Please sign in to comment.