Skip to content

Commit

Permalink
Add to_bigquery() function to BigQueryRetrievalJob (#1634)
Browse files Browse the repository at this point in the history
* Add to_bigquery() function for bq retrieval job

Signed-off-by: Vivian Tao <vivian.tao@shopify.com>

* Using tenacity for retries

Signed-off-by: Vivian Tao <vivian.tao@shopify.com>

* Refactoring to_bigquery function

Signed-off-by: Vivian Tao <vivian.tao@shopify.com>

* Adding tenacity dependency and changing temp table prefix to historical

Signed-off-by: Vivian Tao <vivian.tao@shopify.com>

* Use self.client instead of creating a new client

Signed-off-by: Vivian Tao <vivian.tao@shopify.com>

* pin tenacity to major version

Signed-off-by: Vivian Tao <vivian.tao@shopify.com>

* Tenacity dependency range

Signed-off-by: Vivian Tao <vivian.tao@shopify.com>
  • Loading branch information
vtao2 authored and Tsotne Tabidze committed Jun 17, 2021
1 parent f682e7e commit d1c9e2b
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 4 deletions.
34 changes: 31 additions & 3 deletions sdk/python/feast/infra/offline_stores/bigquery.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import time
import uuid
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta
from datetime import date, datetime, timedelta
from typing import List, Optional, Set, Union

import pandas
import pyarrow
from jinja2 import BaseLoader, Environment
from tenacity import retry, stop_after_delay, wait_fixed

from feast import errors
from feast.data_source import BigQuerySource, DataSource
Expand Down Expand Up @@ -118,7 +120,7 @@ def get_historical_features(
entity_df_event_timestamp_col=entity_df_event_timestamp_col,
)

job = BigQueryRetrievalJob(query=query, client=client)
job = BigQueryRetrievalJob(query=query, client=client, config=config)
return job


Expand Down Expand Up @@ -206,15 +208,41 @@ def _infer_event_timestamp_from_dataframe(entity_df: pandas.DataFrame) -> str:


class BigQueryRetrievalJob(RetrievalJob):
    """Retrieval job backed by a BigQuery SQL query.

    Holds the generated historical-features query plus the client/config
    needed to run it, and materializes the result either as a pandas
    DataFrame (to_df) or as a BigQuery destination table (to_bigquery).
    """

    def __init__(self, query, client, config):
        # query: SQL string produced by get_historical_features.
        # client: google.cloud.bigquery.Client used to execute the query.
        # config: repo config; config.offline_store.dataset names the dataset
        #         that receives temporary destination tables.
        self.query = query
        self.client = client
        self.config = config

    def to_df(self):
        # TODO: Ideally only start this job when the user runs "get_historical_features", not when they run to_df()
        df = self.client.query(self.query).to_dataframe(create_bqstorage_client=True)
        return df

    def to_bigquery(self, dry_run=False) -> Optional[str]:
        """Execute the query, writing the result to a temporary BigQuery table.

        Args:
            dry_run: If True, only report the estimated bytes the query would
                process; no table is created and None is returned.

        Returns:
            Fully qualified "project.dataset.table" path of the destination
            table, or None for a dry run.
        """

        @retry(wait=wait_fixed(10), stop=stop_after_delay(1800), reraise=True)
        def _block_until_done():
            # BUG FIX: tenacity's default policy retries only on *exceptions*.
            # The original returned a bool, so the decorator never polled a
            # second time and execution fell through while the job could still
            # be running. Raising while the job is in flight makes the retry
            # loop actually block; reraise=True surfaces the error if the job
            # has not finished after 1800s.
            if self.client.get_job(bq_job.job_id).state in ["PENDING", "RUNNING"]:
                raise RuntimeError(f"BigQuery job {bq_job.job_id} is still running")

        # Unique, date-stamped destination table inside the configured dataset.
        today = date.today().strftime("%Y%m%d")
        rand_id = str(uuid.uuid4())[:7]
        path = f"{self.client.project}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}"
        job_config = bigquery.QueryJobConfig(destination=path, dry_run=dry_run)
        bq_job = self.client.query(self.query, job_config=job_config)

        if dry_run:
            print(
                "This query will process {} bytes.".format(bq_job.total_bytes_processed)
            )
            return None

        _block_until_done()

        # exception() returns None on success; propagate any recorded failure.
        if bq_job.exception():
            raise bq_job.exception()

        print(f"Done writing to '{path}'.")
        return path


@dataclass(frozen=True)
class FeatureViewQueryContext:
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"pydantic>=1.0.0",
"PyYAML==5.3.*",
"tabulate==0.8.*",
"tenacity>=7.*",
"toml==0.10.*",
"tqdm==4.*",
]
Expand Down Expand Up @@ -88,7 +89,6 @@
"pytest-mock==1.10.4",
"Sphinx!=4.0.0",
"sphinx-rtd-theme",
"tenacity",
"adlfs==0.5.9",
"firebase-admin==4.5.2",
"pre-commit",
Expand Down
16 changes: 16 additions & 0 deletions sdk/python/tests/test_historical_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,22 @@ def test_historical_features_from_bigquery_sources(
],
)

# Just a dry run, should not create table
bq_dry_run = job_from_sql.to_bigquery(dry_run=True)
assert bq_dry_run is None

bq_temp_table_path = job_from_sql.to_bigquery()
assert bq_temp_table_path.split(".")[0] == gcp_project

if provider_type == "gcp_custom_offline_config":
assert bq_temp_table_path.split(".")[1] == "foo"
else:
assert bq_temp_table_path.split(".")[1] == bigquery_dataset

# Check that this table actually exists
actual_bq_temp_table = bigquery.Client().get_table(bq_temp_table_path)
assert actual_bq_temp_table.table_id == bq_temp_table_path.split(".")[-1]

start_time = datetime.utcnow()
actual_df_from_sql_entities = job_from_sql.to_df()
end_time = datetime.utcnow()
Expand Down

0 comments on commit d1c9e2b

Please sign in to comment.