Commit
Add credentials parameter to read_gbq (#78)
jrbourbeau committed Sep 6, 2023
1 parent 5f084d5 commit e28e815
Showing 2 changed files with 50 additions and 4 deletions.
25 changes: 21 additions & 4 deletions dask_bigquery/core.py
@@ -22,7 +22,7 @@


@contextmanager
def bigquery_clients(project_id):
def bigquery_clients(project_id, credentials: dict = None):
"""This context manager is a temporary solution until there is an
upstream solution to handle this.
See googleapis/google-cloud-python#9457
@@ -35,7 +35,15 @@ def bigquery_clients(project_id):
user_agent=f"dask-bigquery/{dask_bigquery.__version__}"
)

with bigquery.Client(project_id, client_info=bq_client_info) as bq_client:
# Google library client needs an instance of google.auth.credentials.Credentials
if isinstance(credentials, dict):
credentials = service_account.Credentials.from_service_account_info(
info=credentials
)

with bigquery.Client(
project_id, credentials=credentials, client_info=bq_client_info
) as bq_client:
bq_storage_client = bigquery_storage.BigQueryReadClient(
credentials=bq_client._credentials,
client_info=bqstorage_client_info,
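
For reference, a minimal standalone sketch of the conversion this hunk introduces: a service-account key parsed into a plain dict is turned into a google.auth.credentials.Credentials object via google.oauth2.service_account.Credentials.from_service_account_info, the same call the context manager makes when it receives a dict. The key file path below is a placeholder, not a value from this repository.

    import json

    from google.oauth2 import service_account

    # Placeholder path to a downloaded service-account key file.
    with open("service-account-key.json") as f:
        info = json.load(f)

    # The same conversion bigquery_clients performs when given a dict:
    # plain service-account info -> google.auth.credentials.Credentials.
    credentials = service_account.Credentials.from_service_account_info(info=info)
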
@@ -88,6 +96,7 @@ def bigquery_read(
project_id: str,
read_kwargs: dict,
arrow_options: dict,
credentials: dict = None,
) -> pd.DataFrame:
"""Read a single batch of rows via BQ Storage API, in Arrow binary format.
@@ -108,7 +117,7 @@ def bigquery_read(
NOTE: Please set if reading from Storage API without any `row_restriction`.
https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1beta1#stream
"""
with bigquery_clients(project_id) as (_, bqs_client):
with bigquery_clients(project_id, credentials=credentials) as (_, bqs_client):
session = bqs_client.create_read_session(make_create_read_session_request())
schema = pyarrow.ipc.read_schema(
pyarrow.py_buffer(session.arrow_schema.serialized_schema)
@@ -132,6 +141,7 @@ def read_gbq(
max_stream_count: int = 0,
read_kwargs: dict = None,
arrow_options: dict = None,
credentials: dict = None,
):
"""Read table as dask dataframe using BigQuery Storage API via Arrow format.
Partitions will be approximately balanced according to BigQuery stream allocation logic.
@@ -157,14 +167,20 @@
kwargs to pass to record_batch.to_pandas() when converting from pyarrow to pandas. See
https://arrow.apache.org/docs/python/generated/pyarrow.RecordBatch.html#pyarrow.RecordBatch.to_pandas
for possible values
credentials : dict, optional
Credentials for accessing Google APIs. Use this parameter to override
default credentials. The dict should contain service account credentials in JSON format.
Returns
-------
Dask DataFrame
"""
read_kwargs = read_kwargs or {}
arrow_options = arrow_options or {}
with bigquery_clients(project_id) as (bq_client, bqs_client):
with bigquery_clients(project_id, credentials=credentials) as (
bq_client,
bqs_client,
):
table_ref = bq_client.get_table(f"{dataset_id}.{table_id}")
if table_ref.table_type == "VIEW":
raise TypeError("Table type VIEW not supported")
@@ -209,6 +225,7 @@ def make_create_read_session_request():
project_id=project_id,
read_kwargs=read_kwargs,
arrow_options=arrow_options,
credentials=credentials,
),
label=label,
)
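A hedged usage sketch of the new parameter from the caller's side: the service-account info is passed to read_gbq as a plain dict, and read_gbq forwards it to bigquery_clients, which performs the conversion shown earlier. The key file path, project, dataset, and table names here are placeholders, not values from this repository.

    import json

    from dask_bigquery import read_gbq

    # Placeholder path to a service-account key file.
    with open("service-account-key.json") as f:
        creds = json.load(f)

    ddf = read_gbq(
        project_id="my-project",    # placeholder GCP project
        dataset_id="my_dataset",    # placeholder BigQuery dataset
        table_id="my_table",        # placeholder BigQuery table
        credentials=creds,          # overrides application-default credentials
    )
    print(ddf.head())
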
29 changes: 29 additions & 0 deletions dask_bigquery/tests/test_core.py
@@ -336,6 +336,35 @@ def test_read_columns(df, table, client):
assert list(ddf.columns) == columns


@pytest.mark.parametrize("dataset_fixture", ["write_dataset", "write_existing_dataset"])
def test_read_gbq_credentials(df, dataset_fixture, request, monkeypatch):
dataset = request.getfixturevalue(dataset_fixture)
credentials, project_id, dataset_id, table_id = dataset
ddf = dd.from_pandas(df, npartitions=2)

monkeypatch.delenv("GOOGLE_DEFAULT_CREDENTIALS", raising=False)
# with explicit credentials
result = to_gbq(
ddf,
project_id=project_id,
dataset_id=dataset_id,
table_id=table_id or "table_to_write",
credentials=credentials,
)
assert result.state == "DONE"

# with explicit credentials
ddf = read_gbq(
project_id=project_id,
dataset_id=dataset_id,
table_id=table_id or "table_to_write",
credentials=credentials,
)

assert list(ddf.columns) == ["name", "number", "timestamp", "idx"]
assert assert_eq(ddf.set_index("idx"), df.set_index("idx"))


def test_max_streams(df, table, client):
project_id, dataset_id, table_id = table
ddf = read_gbq(
