Merge pull request #1281 from CartoDB/alasarr/bq_storage_api
BQ Storage API for BQ download
oleurud committed Dec 12, 2019
2 parents 303f17f + b516d66 commit dda81c1
Showing 14 changed files with 317 additions and 181 deletions.
139 changes: 92 additions & 47 deletions cartoframes/data/clients/bigquery_client.py
@@ -4,13 +4,17 @@
import appdirs
import csv
import tqdm
import pandas as pd

from google.auth.exceptions import RefreshError
from google.cloud import bigquery, storage
from google.cloud import bigquery, storage, bigquery_storage_v1beta1 as bigquery_storage
from google.oauth2.credentials import Credentials as GoogleCredentials

from carto.exceptions import CartoException
from ...auth import get_default_credentials
from ...core.logger import log
from ...utils.utils import timelogger


_USER_CONFIG_DIR = appdirs.user_config_dir('cartoframes')
_GCS_CHUNK_SIZE = 25 * 1024 * 1024 # 25MB. This must be a multiple of 256 KB per the API specification.
@@ -36,7 +40,9 @@ def __init__(self, credentials):
self._credentials = credentials or get_default_credentials()
self.bq_client = None
self.gcs_client = None
self.bq_storage_client = None

self._gcp_execution_project = None
self.bq_public_project = None
self.bq_project = None
self.bq_dataset = None
@@ -51,22 +57,86 @@ def _init_clients(self):

self.bq_client = bigquery.Client(
project=do_credentials.gcp_execution_project,
credentials=google_credentials)
credentials=google_credentials
)

self.gcs_client = storage.Client(
project=do_credentials.bq_project,
credentials=google_credentials
)

self.bq_storage_client = bigquery_storage.BigQueryStorageClient(
credentials=google_credentials
)

self._gcp_execution_project = do_credentials.gcp_execution_project
self.bq_public_project = do_credentials.bq_public_project
self.bq_project = do_credentials.bq_project
self.bq_dataset = do_credentials.bq_dataset
self.instant_licensing = do_credentials.instant_licensing
self._gcs_bucket = do_credentials.gcs_bucket

@refresh_clients
def query(self, query, **kwargs):
return self.bq_client.query(query, **kwargs)

def upload_dataframe(self, dataframe, schema, tablename):
# Upload file to Google Cloud Storage
self._upload_dataframe_to_GCS(dataframe, tablename)
self._import_from_GCS_to_BQ(schema, tablename)

@timelogger
def download_to_file(self, job, file_path=None, fail_if_exists=False, column_names=None, progress_bar=True):
if not file_path:
file_name = '{}.csv'.format(job.job_id)
file_path = os.path.join(_USER_CONFIG_DIR, file_name)

if fail_if_exists and os.path.isfile(file_path):
raise CartoException('The file `{}` already exists.'.format(file_path))

try:
rows = self._download_by_bq_storage_api(job)
except Exception:
log.debug('Cannot download using BigQuery Storage API, fallback to standard')
rows = job.result()

_rows_to_file(rows, file_path, column_names, progress_bar)

return file_path

@timelogger
def download_to_dataframe(self, job):
try:
rows = self._download_by_bq_storage_api(job)
data = list(rows)
return pd.DataFrame(data)
except Exception:
log.debug('Cannot download using BigQuery Storage API, fallback to standard')
return job.to_dataframe()

def _download_by_bq_storage_api(self, job):
table_ref = job.destination.to_bqstorage()

parent = 'projects/{}'.format(self._gcp_execution_project)
session = self.bq_storage_client.create_read_session(
table_ref,
parent,
requested_streams=1,
format_=bigquery_storage.enums.DataFormat.AVRO,
# We use a LIQUID strategy because we only read from a
# single stream. Consider BALANCED if requested_streams > 1
sharding_strategy=(bigquery_storage.enums.ShardingStrategy.LIQUID)
)

reader = self.bq_storage_client.read_rows(
bigquery_storage.types.StreamPosition(stream=session.streams[0])
)

return reader.rows(session)
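
For reference, the read-session flow wrapped by _download_by_bq_storage_api can be exercised on its own. A minimal sketch follows; the project ID and query are hypothetical placeholders, credentials come from the environment, and it assumes the same v1beta1 client pinned in setup.py:

    from google.cloud import bigquery, bigquery_storage_v1beta1 as bigquery_storage

    PROJECT = 'my-gcp-project'  # hypothetical execution project

    bq_client = bigquery.Client(project=PROJECT)
    bq_storage_client = bigquery_storage.BigQueryStorageClient()

    job = bq_client.query('SELECT name, number FROM `bigquery-public-data.usa_names.usa_names_1910_2013` LIMIT 100')
    job.result()  # wait so that job.destination points to the temporary results table

    # Same calls as _download_by_bq_storage_api: one LIQUID-sharded AVRO stream.
    session = bq_storage_client.create_read_session(
        job.destination.to_bqstorage(),
        'projects/{}'.format(PROJECT),
        requested_streams=1,
        format_=bigquery_storage.enums.DataFormat.AVRO,
        sharding_strategy=bigquery_storage.enums.ShardingStrategy.LIQUID
    )
    reader = bq_storage_client.read_rows(
        bigquery_storage.types.StreamPosition(stream=session.streams[0])
    )
    rows = reader.rows(session)  # dict-like rows, as consumed by _rows_to_file below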

@refresh_clients
@timelogger
def _upload_dataframe_to_GCS(self, dataframe, tablename):
log.debug('Uploading to GCS')
bucket = self.gcs_client.get_bucket(self._gcs_bucket)
blob = bucket.blob(tablename, chunk_size=_GCS_CHUNK_SIZE)
dataframe.to_csv(tablename, index=False, header=False)
@@ -75,7 +145,11 @@ def upload_dataframe(self, dataframe, schema, tablename):
finally:
os.remove(tablename)

# Import from GCS To BigQuery
@refresh_clients
@timelogger
def _import_from_GCS_to_BQ(self, schema, tablename):
log.debug('Importing to BQ from GCS')

dataset_ref = self.bq_client.dataset(self.bq_dataset, project=self.bq_project)
table_ref = dataset_ref.table(tablename)
schema_wrapped = [bigquery.SchemaField(column, dtype) for column, dtype in schema.items()]
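
The load step itself is collapsed in this view. For orientation, the usual google-cloud-bigquery pattern for loading a CSV blob from GCS into the destination table looks roughly like the sketch below; it reuses the names defined in this method and is a generic illustration, not necessarily the project's exact code:

    job_config = bigquery.LoadJobConfig()
    job_config.schema = schema_wrapped              # the wrapped schema built above
    job_config.source_format = bigquery.SourceFormat.CSV

    uri = 'gs://{bucket}/{blob}'.format(bucket='my-gcs-bucket', blob=tablename)  # bucket name is illustrative
    load_job = self.bq_client.load_table_from_uri(uri, table_ref, job_config=job_config)
    load_job.result()  # waits for the load to complete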
@@ -91,56 +165,27 @@ def upload_dataframe(self, dataframe, schema, tablename):

job.result() # Waits for table load to complete.

-    @refresh_clients
-    def query(self, query, **kwargs):
-        return self.bq_client.query(query, **kwargs)
+    def get_table_column_names(self, project, dataset, table):
+        table_info = self._get_table(project, dataset, table)
+        return [field.name for field in table_info.schema]

     @refresh_clients
-    def get_table(self, project, dataset, table):
+    def _get_table(self, project, dataset, table):
         full_table_name = '{}.{}.{}'.format(project, dataset, table)
         return self.bq_client.get_table(full_table_name)

-    def get_table_column_names(self, project, dataset, table):
-        table_info = self.get_table(project, dataset, table)
-        return [field.name for field in table_info.schema]
-
-    def download_to_file(self, project, dataset, table, file_path=None, limit=None, offset=None,
-                         fail_if_exists=False, progress_bar=True):
-        if not file_path:
-            file_name = '{}.{}.{}.csv'.format(project, dataset, table)
-            file_path = os.path.join(_USER_CONFIG_DIR, file_name)
+
+def _rows_to_file(rows, file_path, column_names=None, progress_bar=True):
+    if progress_bar:
+        pb = tqdm.tqdm_notebook(total=rows.total_rows)
-
-        if fail_if_exists and os.path.isfile(file_path):
-            raise CartoException('The file `{}` already exists.'.format(file_path))
-
-        column_names = self.get_table_column_names(project, dataset, table)
-
-        query = _download_query(project, dataset, table, limit, offset)
-        rows_iter = self.query(query).result()
-
-        if progress_bar:
-            pb = tqdm.tqdm_notebook(total=rows_iter.total_rows)
-
-        with open(file_path, 'w') as csvfile:
-            csvwriter = csv.writer(csvfile)
+
+    with open(file_path, 'w') as csvfile:
+        csvwriter = csv.writer(csvfile)

         if column_names:
             csvwriter.writerow(column_names)

-            for row in rows_iter:
-                csvwriter.writerow(row.values())
-                if progress_bar:
-                    pb.update(1)
-
-        return file_path
-
-
-def _download_query(project, dataset, table, limit=None, offset=None):
-    full_table_name = '`{}.{}.{}`'.format(project, dataset, table)
-    query = 'SELECT * FROM {}'.format(full_table_name)
-
-    if limit:
-        query += ' LIMIT {}'.format(limit)
-    if offset:
-        query += ' OFFSET {}'.format(offset)
-
-    return query
+        for row in rows:
+            csvwriter.writerow(row.values())
+            if progress_bar:
+                pb.update(1)
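
Taken together, the client is now driven by a query job rather than a (project, dataset, table) triple, and both download paths try the BigQuery Storage API before falling back to the standard result iterator. A usage sketch, assuming the client class in this module is BigQueryClient and using illustrative credentials, table and column names:

    from cartoframes.auth import Credentials
    from cartoframes.data.clients.bigquery_client import BigQueryClient

    credentials = Credentials('my-user', 'my-api-key')   # hypothetical CARTO credentials
    client = BigQueryClient(credentials)

    job = client.query('SELECT * FROM `some-project.some_dataset.some_table`')  # illustrative table

    # CSV download: BigQuery Storage API first, job.result() as fallback
    csv_path = client.download_to_file(job, file_path='/tmp/some_table.csv',
                                       column_names=['geoid', 'population'])

    # DataFrame download: same fallback, via job.to_dataframe()
    df = client.download_to_dataframe(job)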
6 changes: 5 additions & 1 deletion cartoframes/data/observatory/catalog/entity.py
@@ -130,7 +130,11 @@ def _download(self, credentials=None, file_path=None):
project, dataset, table = full_remote_table_name.split('.')

try:
file_path = bq_client.download_to_file(project, dataset, table, file_path)
column_names = bq_client.get_table_column_names(project, dataset, table)
query = 'SELECT * FROM `{}`'.format(full_remote_table_name)
job = bq_client.query(query)

file_path = bq_client.download_to_file(job, column_names=column_names)
except NotFound:
raise CartoException('You have not purchased the dataset `{}` yet'.format(self.id))

1 change: 1 addition & 0 deletions cartoframes/data/observatory/catalog/geography.py
@@ -283,6 +283,7 @@ def subscription_info(self, credentials=None):
subscription_info.fetch_subscription_info(self.id, GEOGRAPHY_TYPE, _credentials))

def _is_subscribed(self, credentials=None):

if self.is_public_data:
return True

9 changes: 5 additions & 4 deletions cartoframes/data/observatory/enrichment/enrichment.py
@@ -1,5 +1,5 @@
from .enrichment_service import EnrichmentService, prepare_variables, AGGREGATION_DEFAULT, AGGREGATION_NONE
from ....core.logger import log
from ....utils.utils import timelogger


class Enrichment(EnrichmentService):
@@ -128,6 +128,7 @@ def enrich_points(self, dataframe, variables, geom_col=None, filters=[]):
AGGREGATION_NONE = AGGREGATION_NONE
"""Do not aggregate data in polygons enrichment. More info in :py:attr:`Enrichment.enrich_polygons`"""

@timelogger
def enrich_polygons(self, dataframe, variables, geom_col=None, filters=[], aggregation=AGGREGATION_DEFAULT):
"""Enrich your polygons `DataFrame` with columns (:obj:`Variable`) from one or more :obj:`Dataset` in
the Data Observatory by intersecting the polygons in the source `DataFrame` with geographies in the
@@ -315,12 +316,12 @@ def enrich_polygons(self, dataframe, variables, geom_col=None, filters=[], aggre
enrichment = Enrichment()
cdf_enrich = enrichment.enrich_polygons(df, variables, aggregation=aggregation)
"""
log.debug('Preparing')
variables = prepare_variables(variables, self.credentials, aggregation)
cartodataframe = self._prepare_data(dataframe, geom_col)

cartodataframe = self._prepare_data(dataframe, geom_col)
temp_table_name = self._get_temp_table_name()
log.debug('Uploading')

self._upload_data(temp_table_name, cartodataframe)

queries = self._get_polygon_enrichment_sql(temp_table_name, variables, filters, aggregation)
return self._execute_enrichment(queries, cartodataframe)
34 changes: 29 additions & 5 deletions cartoframes/data/observatory/enrichment/enrichment_service.py
@@ -1,4 +1,5 @@
import uuid
import time

from collections import defaultdict

@@ -11,6 +12,7 @@
from ....core.cartodataframe import CartoDataFrame
from ....core.logger import log
from ....utils.geom_utils import to_geojson
from ....utils.utils import timelogger


_ENRICHMENT_ID = 'enrichment_id'
@@ -71,12 +73,31 @@ def __init__(self, credentials=None):
self.enrichment_id = _ENRICHMENT_ID
self.geojson_column = _GEOJSON_COLUMN

@timelogger
def _execute_enrichment(self, queries, cartodataframe):

dfs_enriched = list()
awaiting_jobs = set()
errors = list()

def callback(job):
if not job.errors:
dfs_enriched.append(self.bq_client.download_to_dataframe(job))
else:
errors.extend(job.errors)

awaiting_jobs.discard(job.job_id)

for query in queries:
df_enriched = self.bq_client.query(query).to_dataframe()
dfs_enriched.append(df_enriched)
job = self.bq_client.query(query)
awaiting_jobs.add(job.job_id)
job.add_done_callback(callback)

while awaiting_jobs:
time.sleep(0.5)

if len(errors) > 0:
raise Exception(errors)

for df in dfs_enriched:
cartodataframe = cartodataframe.merge(df, on=self.enrichment_id, how='left')
@@ -87,6 +108,7 @@ def _execute_enrichment(self, queries, cartodataframe):

return cartodataframe

@timelogger
def _prepare_data(self, dataframe, geom_col):
cartodataframe = CartoDataFrame(dataframe, copy=True)

@@ -156,9 +178,10 @@ def __get_dataset(self, variable, table_name):

def __get_geo_table(self, variable):
geography_id = Dataset.get(variable.dataset).geography
geography = Geography.get(geography_id)
_, dataset_geo_table, geo_table = geography_id.split('.')

if variable.project_name != self.bq_public_project:
if not geography.is_public_data:
return '{project}.{dataset}.view_{dataset_geo_table}_{geo_table}'.format(
project=self.bq_project,
dataset=self.bq_dataset,
@@ -308,6 +331,7 @@ def _build_where_clausule(self, filters):
return where


@timelogger
def prepare_variables(variables, credentials, aggregation=None):
if isinstance(variables, list):
variables = [_prepare_variable(var, aggregation) for var in variables]
@@ -367,12 +391,12 @@ def _is_subscribed(dataset, geography, credentials):
if not dataset._is_subscribed(credentials):
raise EnrichmentException("""
You are not subscribed to the Dataset '{}' yet. Please, use the subscribe method first.
""".format(dataset))
""".format(dataset.id))

if not geography._is_subscribed(credentials):
raise EnrichmentException("""
You are not subscribed to the Geography '{}' yet. Please, use the subscribe method first.
""".format(geography))
""".format(geography.id))


def _get_aggregation(variable, aggregation):
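
_execute_enrichment above now submits every enrichment query at once and gathers the results through job callbacks instead of running them sequentially. A self-contained sketch of that pattern; the client and queries here are placeholders:

    import time

    from google.cloud import bigquery

    def run_concurrently(bq_client, queries):
        results, errors, awaiting_jobs = [], [], set()

        def callback(job):
            # Invoked from a background thread when the query job finishes.
            if job.errors:
                errors.extend(job.errors)
            else:
                results.append(job.to_dataframe())
            awaiting_jobs.discard(job.job_id)

        for query in queries:
            job = bq_client.query(query)
            awaiting_jobs.add(job.job_id)
            job.add_done_callback(callback)

        while awaiting_jobs:      # simple polling loop, as in _execute_enrichment
            time.sleep(0.5)

        if errors:
            raise Exception(errors)
        return results

    # e.g. run_concurrently(bigquery.Client(), ['SELECT 1', 'SELECT 2'])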
12 changes: 12 additions & 0 deletions cartoframes/utils/utils.py
@@ -12,11 +12,13 @@
import requests
import geopandas
import numpy as np
import time

from functools import wraps
from warnings import catch_warnings, filterwarnings

from ..auth.credentials import Credentials
from ..core.logger import log

try:
basestring
@@ -391,3 +393,13 @@ def replacer(match):
re.DOTALL | re.MULTILINE
)
return re.sub(pattern, replacer, text).strip()


def timelogger(method):
def fn(*args, **kw):
start = time.time()
result = method(*args, **kw)
log.debug('%s in %s s', method.__name__, round(time.time() - start, 2))
return result

return fn
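
timelogger wraps any function or method and logs its wall-clock time at debug level under the wrapped function's own name. A usage sketch with an illustrative function:

    from cartoframes.utils.utils import timelogger

    @timelogger
    def build_report(rows):
        # Illustrative workload; any function or method can be decorated.
        return [str(row) for row in rows]

    build_report(range(100000))
    # With debug logging enabled, this logs something like: 'build_report in 0.01 s'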
2 changes: 2 additions & 0 deletions setup.py
@@ -33,6 +33,8 @@ def get_version():
'pyarrow>=0.14.1,<1.0',
'google-cloud-storage>=1.23.0,<2.0',
'google-cloud-bigquery>=1.22.0,<2.0',
'google-cloud-bigquery-storage>=0.7.0,<1.0',
'fastavro>=0.22.7,<1.0',
'mercantile>=1.1.2,<2.0'
# 'Rtree>=0.8.3,<1.0'
]
