Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BigQuery: Determine the schema in load_table_from_dataframe based on dtypes. #9049

Merged
merged 5 commits into from Aug 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
40 changes: 40 additions & 0 deletions bigquery/google/cloud/bigquery/_pandas_helpers.py
Expand Up @@ -49,6 +49,21 @@

_PROGRESS_INTERVAL = 0.2 # Maximum time between download status checks, in seconds.

_PANDAS_DTYPE_TO_BQ = {
"bool": "BOOLEAN",
"datetime64[ns, UTC]": "TIMESTAMP",
"datetime64[ns]": "DATETIME",
"float32": "FLOAT",
"float64": "FLOAT",
"int8": "INTEGER",
"int16": "INTEGER",
"int32": "INTEGER",
"int64": "INTEGER",
"uint8": "INTEGER",
"uint16": "INTEGER",
"uint32": "INTEGER",
}


class _DownloadState(object):
"""Flag to indicate that a thread should exit early."""
Expand Down Expand Up @@ -172,6 +187,31 @@ def bq_to_arrow_array(series, bq_field):
return pyarrow.array(series, type=arrow_type)


def dataframe_to_bq_schema(dataframe):
    """Derive a BigQuery schema from the column dtypes of a pandas DataFrame.

    TODO(GH#8140): Add bq_schema argument to allow overriding autodetected
    schema for a subset of columns.

    Args:
        dataframe (pandas.DataFrame):
            DataFrame whose column names and dtypes are inspected.

    Returns:
        Optional[Sequence[google.cloud.bigquery.schema.SchemaField]]:
            One field per column, in column order. Returns None as soon
            as a column is found whose dtype name has no entry in
            ``_PANDAS_DTYPE_TO_BQ``.
    """
    fields = []
    for column_name, column_dtype in zip(dataframe.columns, dataframe.dtypes):
        mapped_type = _PANDAS_DTYPE_TO_BQ.get(column_dtype.name)
        if mapped_type:
            fields.append(schema.SchemaField(column_name, mapped_type))
        else:
            # A single unmapped dtype invalidates the whole schema; the
            # caller gets None rather than a partial result.
            return None
    return tuple(fields)


def dataframe_to_arrow(dataframe, bq_schema):
"""Convert pandas dataframe to Arrow table, using BigQuery schema.

Expand Down
15 changes: 15 additions & 0 deletions bigquery/google/cloud/bigquery/client.py
Expand Up @@ -21,6 +21,7 @@
except ImportError: # Python 2.7
import collections as collections_abc

import copy
import functools
import gzip
import io
Expand Down Expand Up @@ -1521,11 +1522,25 @@ def load_table_from_dataframe(

if job_config is None:
job_config = job.LoadJobConfig()
else:
# Make a copy so that the job config isn't modified in-place.
job_config_properties = copy.deepcopy(job_config._properties)
job_config = job.LoadJobConfig()
job_config._properties = job_config_properties
job_config.source_format = job.SourceFormat.PARQUET

if location is None:
location = self.location

if not job_config.schema:
autodetected_schema = _pandas_helpers.dataframe_to_bq_schema(dataframe)

# Only use an explicit schema if we were able to determine one
# matching the dataframe. If not, fall back to the pandas to_parquet
# method.
if autodetected_schema:
job_config.schema = autodetected_schema

tmpfd, tmppath = tempfile.mkstemp(suffix="_job_{}.parquet".format(job_id[:8]))
os.close(tmpfd)

Expand Down
76 changes: 76 additions & 0 deletions bigquery/tests/system.py
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import base64
import collections
import concurrent.futures
import csv
import datetime
Expand Down Expand Up @@ -634,6 +635,81 @@ def test_load_table_from_local_avro_file_then_dump_table(self):
sorted(row_tuples, key=by_wavelength), sorted(ROWS, key=by_wavelength)
)

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_load_table_from_dataframe_w_automatic_schema(self):
    """Upload a DataFrame without a schema and check the resulting table.

    Every column uses a dtype that maps well to a BigQuery type, so the
    load is expected to succeed with no schema supplied by the caller.

    https://github.com/googleapis/google-cloud-python/issues/9044
    """
    moments = [
        datetime.datetime(2010, 1, 2, 3, 44, 50),
        datetime.datetime(2011, 2, 3, 14, 50, 59),
        datetime.datetime(2012, 3, 14, 15, 16),
    ]

    # Insertion order of the OrderedDict fixes the DataFrame column order,
    # which the schema assertion below depends on.
    columns = collections.OrderedDict()
    columns["bool_col"] = pandas.Series([True, False, True], dtype="bool")
    columns["ts_col"] = pandas.Series(
        moments, dtype="datetime64[ns]"
    ).dt.tz_localize(pytz.utc)
    columns["dt_col"] = pandas.Series(moments, dtype="datetime64[ns]")
    columns["float32_col"] = pandas.Series([1.0, 2.0, 3.0], dtype="float32")
    columns["float64_col"] = pandas.Series([4.0, 5.0, 6.0], dtype="float64")
    columns["int8_col"] = pandas.Series([-12, -11, -10], dtype="int8")
    columns["int16_col"] = pandas.Series([-9, -8, -7], dtype="int16")
    columns["int32_col"] = pandas.Series([-6, -5, -4], dtype="int32")
    columns["int64_col"] = pandas.Series([-3, -2, -1], dtype="int64")
    columns["uint8_col"] = pandas.Series([0, 1, 2], dtype="uint8")
    columns["uint16_col"] = pandas.Series([3, 4, 5], dtype="uint16")
    columns["uint32_col"] = pandas.Series([6, 7, 8], dtype="uint32")
    dataframe = pandas.DataFrame(columns, columns=columns.keys())

    dataset_id = _make_dataset_id("bq_load_test")
    self.temp_dataset(dataset_id)
    table_id = "{}.{}.load_table_from_dataframe_w_automatic_schema".format(
        Config.CLIENT.project, dataset_id
    )

    # No job_config is passed: the client has to work the schema out itself.
    Config.CLIENT.load_table_from_dataframe(dataframe, table_id).result()

    table = Config.CLIENT.get_table(table_id)
    expected_types = [
        ("bool_col", "BOOLEAN"),
        ("ts_col", "TIMESTAMP"),
        ("dt_col", "DATETIME"),
        ("float32_col", "FLOAT"),
        ("float64_col", "FLOAT"),
        ("int8_col", "INTEGER"),
        ("int16_col", "INTEGER"),
        ("int32_col", "INTEGER"),
        ("int64_col", "INTEGER"),
        ("uint8_col", "INTEGER"),
        ("uint16_col", "INTEGER"),
        ("uint32_col", "INTEGER"),
    ]
    expected_schema = tuple(
        bigquery.SchemaField(name, field_type)
        for name, field_type in expected_types
    )
    self.assertEqual(tuple(table.schema), expected_schema)
    self.assertEqual(table.num_rows, 3)

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_load_table_from_dataframe_w_nulls(self):
Expand Down
74 changes: 72 additions & 2 deletions bigquery/tests/unit/test_client.py
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import copy
import collections
import datetime
import decimal
import email
Expand Down Expand Up @@ -5325,9 +5326,78 @@ def test_load_table_from_dataframe_w_custom_job_config(self):
)

sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
assert sent_config is job_config
assert sent_config.source_format == job.SourceFormat.PARQUET

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_load_table_from_dataframe_w_automatic_schema(self):
    """Check the job config sent when no job config is supplied.

    ``load_table_from_file`` is patched out, so only the arguments handed
    to it are inspected: the source format must be Parquet and the schema
    must match the DataFrame's columns.
    """
    from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
    from google.cloud.bigquery import job
    from google.cloud.bigquery.schema import SchemaField

    client = self._make_client()

    moments = [
        datetime.datetime(2010, 1, 2, 3, 44, 50),
        datetime.datetime(2011, 2, 3, 14, 50, 59),
        datetime.datetime(2012, 3, 14, 15, 16),
    ]
    # Insertion order of the OrderedDict fixes the DataFrame column order,
    # which the schema assertion below depends on.
    columns = collections.OrderedDict()
    columns["int_col"] = [1, 2, 3]
    columns["float_col"] = [1.0, 2.0, 3.0]
    columns["bool_col"] = [True, False, True]
    columns["dt_col"] = pandas.Series(moments, dtype="datetime64[ns]")
    columns["ts_col"] = pandas.Series(
        moments, dtype="datetime64[ns]"
    ).dt.tz_localize(pytz.utc)
    dataframe = pandas.DataFrame(columns, columns=columns.keys())

    with mock.patch(
        "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
    ) as load_table_from_file:
        client.load_table_from_dataframe(
            dataframe, self.TABLE_REF, location=self.LOCATION
        )

    load_table_from_file.assert_called_once_with(
        client,
        mock.ANY,
        self.TABLE_REF,
        num_retries=_DEFAULT_NUM_RETRIES,
        rewind=True,
        job_id=mock.ANY,
        job_id_prefix=None,
        location=self.LOCATION,
        project=None,
        job_config=mock.ANY,
    )

    # Index 2 of a mock call entry holds the keyword arguments.
    passed_config = load_table_from_file.mock_calls[0][2]["job_config"]
    assert passed_config.source_format == job.SourceFormat.PARQUET

    expected_types = [
        ("int_col", "INTEGER"),
        ("float_col", "FLOAT"),
        ("bool_col", "BOOLEAN"),
        ("dt_col", "DATETIME"),
        ("ts_col", "TIMESTAMP"),
    ]
    assert tuple(passed_config.schema) == tuple(
        SchemaField(name, field_type) for name, field_type in expected_types
    )

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_load_table_from_dataframe_struct_fields_error(self):
Expand Down Expand Up @@ -5509,7 +5579,7 @@ def test_load_table_from_dataframe_w_nulls(self):
)

sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
assert sent_config is job_config
assert sent_config.schema == schema
assert sent_config.source_format == job.SourceFormat.PARQUET

# Low-level tests
Expand Down