Skip to content

Commit

Permalink
DbApiHook: Support kwargs in get_pandas_df (apache#9730)
Browse files Browse the repository at this point in the history
* DbApiHook: Support kwargs in get_pandas_df
* BigQueryHook: Support kwargs in get_pandas_df
* PrestoHook: Support kwargs in get_pandas_df
* HiveServer2Hook: Support kwargs in get_pandas_df

(cherry picked from commit 8f8db89)
  • Loading branch information
22quinn authored and Chris Fei committed Mar 5, 2021
1 parent 6254fab commit 30f43ba
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 10 deletions.
7 changes: 5 additions & 2 deletions airflow/contrib/hooks/bigquery_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def insert_rows(self, table, rows, target_fields=None, commit_every=1000, **kwar
"""
raise NotImplementedError()

def get_pandas_df(self, sql, parameters=None, dialect=None):
def get_pandas_df(self, sql, parameters=None, dialect=None, **kwargs):
"""
Returns a Pandas DataFrame for the results produced by a BigQuery
query. The DbApiHook method must be overridden because Pandas
Expand All @@ -110,6 +110,8 @@ def get_pandas_df(self, sql, parameters=None, dialect=None):
:param dialect: Dialect of BigQuery SQL – legacy SQL or standard SQL
defaults to use `self.use_legacy_sql` if not specified
:type dialect: str in {'legacy', 'standard'}
:param kwargs: (optional) passed into pandas_gbq.read_gbq method
:type kwargs: dict
"""
private_key = self._get_field('key_path', None) or self._get_field('keyfile_dict', None)

Expand All @@ -120,7 +122,8 @@ def get_pandas_df(self, sql, parameters=None, dialect=None):
project_id=self._get_field('project'),
dialect=dialect,
verbose=False,
private_key=private_key)
private_key=private_key,
**kwargs)

def table_exists(self, project_id, dataset_id, table_id):
"""
Expand Down
8 changes: 5 additions & 3 deletions airflow/hooks/dbapi_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,22 +82,24 @@ def get_sqlalchemy_engine(self, engine_kwargs=None):
engine_kwargs = {}
return create_engine(self.get_uri(), **engine_kwargs)

def get_pandas_df(self, sql, parameters=None):
def get_pandas_df(self, sql, parameters=None, **kwargs):
"""
Executes the sql and returns a pandas dataframe
:param sql: the sql statement to be executed (str) or a list of
sql statements to execute
:type sql: str or list
:param parameters: The parameters to render the SQL query with.
:type parameters: mapping or iterable
:type parameters: dict or iterable
:param kwargs: (optional) passed into pandas.io.sql.read_sql method
:type kwargs: dict
"""
if sys.version_info[0] < 3:
sql = sql.encode('utf-8')
import pandas.io.sql as psql

with closing(self.get_conn()) as conn:
return psql.read_sql(sql, con=conn, params=parameters)
return psql.read_sql(sql, con=conn, params=parameters, **kwargs)

def get_records(self, sql, parameters=None):
"""
Expand Down
6 changes: 4 additions & 2 deletions airflow/hooks/hive_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,14 +983,16 @@ def get_records(self, hql, schema='default'):
"""
return self.get_results(hql, schema=schema)['data']

def get_pandas_df(self, hql, schema='default'):
def get_pandas_df(self, hql, schema='default', **kwargs):
"""
Get a pandas dataframe from a Hive query
:param hql: hql to be executed.
:type hql: str or list
:param schema: target schema, default to 'default'.
:type schema: str
:param kwargs: (optional) passed into pandas.DataFrame constructor
:type kwargs: dict
:return: result of hql execution
:rtype: DataFrame
Expand All @@ -1004,6 +1006,6 @@ def get_pandas_df(self, hql, schema='default'):
"""
import pandas as pd
res = self.get_results(hql, schema=schema)
df = pd.DataFrame(res['data'])
df = pd.DataFrame(res['data'], **kwargs)
df.columns = [c[0] for c in res['header']]
return df
6 changes: 3 additions & 3 deletions airflow/hooks/presto_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def get_first(self, hql, parameters=None):
except DatabaseError as e:
raise PrestoException(self._get_pretty_exception_message(e))

def get_pandas_df(self, hql, parameters=None):
def get_pandas_df(self, hql, parameters=None, **kwargs):
"""
Get a pandas dataframe from a sql query.
"""
Expand All @@ -118,10 +118,10 @@ def get_pandas_df(self, hql, parameters=None):
raise PrestoException(self._get_pretty_exception_message(e))
column_descriptions = cursor.description
if data:
df = pandas.DataFrame(data)
df = pandas.DataFrame(data, **kwargs)
df.columns = [c[0] for c in column_descriptions]
else:
df = pandas.DataFrame()
df = pandas.DataFrame(**kwargs)
return df

def run(self, hql, parameters=None):
Expand Down

0 comments on commit 30f43ba

Please sign in to comment.