
Commit

feat(bigquery): add client and storage_client params to connect
This allows for fine-grained customization, such as connecting to regional
endpoints. Adds a test to ensure regional endpoints are used when set.
tswast authored and cpcloud committed May 30, 2023
1 parent a286217 commit 4cf1354
Showing 4 changed files with 124 additions and 14 deletions.
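
In practice, the new `client` and `storage_client` parameters let callers build the underlying Google Cloud clients themselves, for example to pin both to a regional endpoint. The following is a minimal sketch mirroring the test added in this commit; the project ID is a placeholder and the region matches the Tokyo test fixtures below.

import google.api_core.client_options
import ibis
from google.cloud import bigquery as bq
from google.cloud import bigquery_storage_v1 as bqstorage

region = "asia-northeast1"

# REST client pinned to the regional BigQuery endpoint.
bq_client = bq.Client(
    project="my-project",  # placeholder project ID
    client_options=google.api_core.client_options.ClientOptions(
        api_endpoint=f"https://{region}-bigquery.googleapis.com"
    ),
)

# The Storage Read API client is gRPC-based, so its endpoint has no protocol prefix.
bqstorage_client = bqstorage.BigQueryReadClient(
    client_options=google.api_core.client_options.ClientOptions(
        api_endpoint=f"{region}-bigquerystorage.googleapis.com"
    ),
)

con = ibis.bigquery.connect(
    client=bq_client,
    storage_client=bqstorage_client,
    project_id="my-project",  # placeholder project ID
)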
63 changes: 49 additions & 14 deletions ibis/backends/bigquery/__init__.py
@@ -8,6 +8,7 @@

import google.auth.credentials
import google.cloud.bigquery as bq
import google.cloud.bigquery_storage_v1 as bqstorage
import pandas as pd
import pydata_google_auth
from pydata_google_auth import cache
@@ -45,9 +46,7 @@
CLIENT_SECRET = "iU5ohAF2qcqrujegE3hQ1cPt"


def _create_client_info(application_name):
from google.api_core.client_info import ClientInfo

def _create_user_agent(application_name: str) -> str:
user_agent = []

if application_name:
@@ -56,7 +55,19 @@ def _create_client_info(application_name):
user_agent_default_template = f"ibis/{ibis.__version__}"
user_agent.append(user_agent_default_template)

return ClientInfo(user_agent=" ".join(user_agent))
return " ".join(user_agent)


def _create_client_info(application_name):
from google.api_core.client_info import ClientInfo

return ClientInfo(user_agent=_create_user_agent(application_name))


def _create_client_info_gapic(application_name):
from google.api_core.gapic_v1.client_info import ClientInfo

return ClientInfo(user_agent=_create_user_agent(application_name))


class Backend(BaseSQLBackend):
@@ -82,6 +93,8 @@ def do_connect(
auth_external_data: bool = False,
auth_cache: str = "default",
partition_column: str | None = "PARTITIONTIME",
client: bq.Client | None = None,
storage_client: bqstorage.BigQueryReadClient | None = None,
):
"""Create a `Backend` for use with Ibis.
@@ -125,15 +138,24 @@ def do_connect(
partition_column
Identifier to use instead of default ``_PARTITIONTIME`` partition
column. Defaults to ``'PARTITIONTIME'``.
client
A ``Client`` from the ``google.cloud.bigquery`` package. If not
set, one is created using the ``project_id`` and ``credentials``.
storage_client
A ``BigQueryReadClient`` from the
``google.cloud.bigquery_storage_v1`` package. If not set, one is
created using the ``project_id`` and ``credentials``.
Returns
-------
Backend
An instance of the BigQuery backend.
"""
default_project_id = ""
default_project_id = client.project if client is not None else project_id

if credentials is None:
# Only need `credentials` to create a `client` and
# `storage_client`, so only one or the other needs to be set.
if (client is None or storage_client is None) and credentials is None:
scopes = SCOPES
if auth_external_data:
scopes = EXTERNAL_DATA_SCOPES
@@ -170,11 +192,23 @@ def do_connect(
self.dataset,
) = parse_project_and_dataset(project_id, dataset_id)

self.client = bq.Client(
project=self.billing_project,
credentials=credentials,
client_info=_create_client_info(application_name),
)
if client is not None:
self.client = client
else:
self.client = bq.Client(
project=self.billing_project,
credentials=credentials,
client_info=_create_client_info(application_name),
)

if storage_client is not None:
self.storage_client = storage_client
else:
self.storage_client = bqstorage.BigQueryReadClient(
credentials=credentials,
client_info=_create_client_info_gapic(application_name),
)

self.partition_column = partition_column

def _parse_project_and_dataset(self, dataset) -> tuple[str, str]:
@@ -315,8 +349,7 @@ def _cursor_to_arrow(
if method is None:
method = lambda result: result.to_arrow(
progress_bar_type=None,
bqstorage_client=None,
create_bqstorage_client=True,
bqstorage_client=self.storage_client,
)
query = cursor.query
query_result = query.result(page_size=chunk_size)
@@ -370,7 +403,9 @@ def to_pyarrow_batches(
cursor = self.raw_sql(sql, params=params, **kwargs)
batch_iter = self._cursor_to_arrow(
cursor,
method=lambda result: result.to_arrow_iterable(),
method=lambda result: result.to_arrow_iterable(
bqstorage_client=self.storage_client
),
chunk_size=chunk_size,
)
return pa.RecordBatchReader.from_batches(schema.to_pyarrow(), batch_iter)
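
Since `_cursor_to_arrow` and `to_pyarrow_batches` now pass `self.storage_client` to `to_arrow` and `to_arrow_iterable`, Arrow results are streamed through whichever Storage Read client was supplied at connect time. A rough usage sketch, continuing the connection from the earlier example (dataset and table names are borrowed from the test fixtures and purely illustrative):

# Continues `con` from the earlier sketch; names are illustrative.
t = con.table("ibis_gbq_testing_tokyo.functional_alltypes")

# Runs the query via the regional bq.Client and fetches the result as
# pandas through Arrow, using the supplied BigQueryReadClient.
df = t.limit(2).execute()

# Streams pyarrow record batches through the same storage client.
for batch in con.to_pyarrow_batches(t.limit(1000)):
    print(batch.num_rows)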
29 changes: 29 additions & 0 deletions ibis/backends/bigquery/tests/conftest.py
@@ -27,6 +27,8 @@
import ibis.expr.types as ir

DATASET_ID = "ibis_gbq_testing"
DATASET_ID_TOKYO = "ibis_gbq_testing_tokyo"
REGION_TOKYO = "asia-northeast1"
DEFAULT_PROJECT_ID = "ibis-gbq"
PROJECT_ID_ENV_VAR = "GOOGLE_BIGQUERY_PROJECT_ID"

@@ -109,6 +111,14 @@ def _load_data(data_dir: Path, script_dir: Path, **_: Any) -> None:
with contextlib.suppress(gexc.NotFound):
client.create_dataset(testing_dataset, exists_ok=True)

testing_dataset_tokyo = bq.Dataset(
bq.DatasetReference(project_id, DATASET_ID_TOKYO)
)
testing_dataset_tokyo.location = REGION_TOKYO

with contextlib.suppress(gexc.NotFound):
client.create_dataset(testing_dataset_tokyo, exists_ok=True)

# day partitioning
functional_alltypes_parted = bq.Table(
bq.TableReference(testing_dataset, "functional_alltypes_parted")
@@ -268,6 +278,25 @@ def _load_data(data_dir: Path, script_dir: Path, **_: Any) -> None:
for table, schema in TEST_TABLES.items()
)

# Test regional endpoints with non-US data.

futures.extend(
e.submit(
make_job,
client.load_table_from_file,
io.BytesIO(
data_dir.joinpath("parquet", f"{table}.parquet").read_bytes()
),
bq.TableReference(testing_dataset_tokyo, table),
job_config=bq.LoadJobConfig(
schema=ibis_schema_to_bq_schema(schema),
write_disposition=write_disposition,
source_format=bq.SourceFormat.PARQUET,
),
)
for table, schema in TEST_TABLES.items()
)

for fut in concurrent.futures.as_completed(futures):
fut.result()

12 changes: 12 additions & 0 deletions ibis/backends/bigquery/tests/system/conftest.py
@@ -12,6 +12,8 @@
DEFAULT_PROJECT_ID = "ibis-gbq"
PROJECT_ID_ENV_VAR = "GOOGLE_BIGQUERY_PROJECT_ID"
DATASET_ID = "ibis_gbq_testing"
DATASET_ID_TOKYO = "ibis_gbq_testing_tokyo"
REGION_TOKYO = "asia-northeast1"


def pytest_addoption(parser):
@@ -34,6 +36,16 @@ def dataset_id() -> str:
return DATASET_ID


@pytest.fixture(scope="session")
def dataset_id_tokyo() -> str:
return DATASET_ID_TOKYO


@pytest.fixture(scope="session")
def region_tokyo() -> str:
return REGION_TOKYO


@pytest.fixture(scope="session")
def default_credentials():
try:
34 changes: 34 additions & 0 deletions ibis/backends/bigquery/tests/system/test_connect.py
@@ -1,10 +1,12 @@
from unittest import mock

import google.api_core.client_options
import google.api_core.exceptions as gexc
import pydata_google_auth
import pytest
from google.auth import credentials as auth
from google.cloud import bigquery as bq
from google.cloud import bigquery_storage_v1 as bqstorage

import ibis

@@ -178,3 +180,35 @@ def test_auth_cache_unknown(project_id):
dataset_id="bigquery-public-data.stackoverflow",
auth_cache="not_a_real_cache",
)


def test_client_with_regional_endpoints(
project_id, credentials, dataset_id, dataset_id_tokyo, region_tokyo
):
bq_options = google.api_core.client_options.ClientOptions(
api_endpoint=f"https://{region_tokyo}-bigquery.googleapis.com"
)
bq_client = bq.Client(
client_options=bq_options, project=project_id, credentials=credentials
)

# Note there is no protocol specifier for gRPC APIs.
bqstorage_options = google.api_core.client_options.ClientOptions(
api_endpoint=f"{region_tokyo}-bigquerystorage.googleapis.com"
)
bqstorage_client = bqstorage.BigQueryReadClient(
client_options=bqstorage_options, credentials=credentials
)

con = ibis.bigquery.connect(
client=bq_client, storage_client=bqstorage_client, project_id=project_id
)

# Fails because dataset not in Tokyo.
with pytest.raises(gexc.NotFound, match=dataset_id):
con.table(f"{dataset_id}.functional_alltypes")

# Succeeds because dataset is in Tokyo.
alltypes = con.table(f"{dataset_id_tokyo}.functional_alltypes")
df = alltypes.limit(2).execute()
assert len(df.index) == 2
