From 3d8b7dba44b55e89c168bd6fe5aa8ee1b22c3650 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Mon, 18 Sep 2017 14:29:50 +0900 Subject: [PATCH] Redesign class hierarchy --- pyathena/async_cursor.py | 37 ++++--- pyathena/common.py | 93 ++++++++++++---- pyathena/cursor.py | 89 ++++++--------- pyathena/model.py | 223 +++++++------------------------------- pyathena/result_set.py | 229 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 399 insertions(+), 272 deletions(-) create mode 100644 pyathena/result_set.py diff --git a/pyathena/async_cursor.py b/pyathena/async_cursor.py index 5a9b9ac8..ec935d8c 100644 --- a/pyathena/async_cursor.py +++ b/pyathena/async_cursor.py @@ -6,8 +6,10 @@ from concurrent.futures.thread import ThreadPoolExecutor +from pyathena.common import CursorIterator from pyathena.cursor import BaseCursor -from pyathena.model import AthenaResultSet +from pyathena.error import ProgrammingError, NotSupportedError +from pyathena.result_set import AthenaResultSet _logger = logging.getLogger(__name__) @@ -19,30 +21,32 @@ def __init__(self, client, s3_staging_dir, schema_name, poll_interval, encryption_option, kms_key, converter, formatter, retry_exceptions, retry_attempt, retry_multiplier, retry_max_delay, retry_exponential_base, - max_workers=(os.cpu_count() or 1) * 5): + max_workers=(os.cpu_count() or 1) * 5, + arraysize=CursorIterator.DEFAULT_FETCH_SIZE): super(AsyncCursor, self).__init__(client, s3_staging_dir, schema_name, poll_interval, encryption_option, kms_key, converter, formatter, retry_exceptions, retry_attempt, retry_multiplier, retry_max_delay, retry_exponential_base) self._executor = ThreadPoolExecutor(max_workers=max_workers) + self._arraysize = arraysize + + @property + def arraysize(self): + return self._arraysize + + @arraysize.setter + def arraysize(self, value): + if value <= 0 or value > CursorIterator.DEFAULT_FETCH_SIZE: + raise ProgrammingError('MaxResults is more than maximum allowed length {0}.'.format( + CursorIterator.DEFAULT_FETCH_SIZE)) + self._arraysize = value def close(self, wait=False): self._executor.shutdown(wait=wait) def _description(self, query_id): result_set = self._collect_result_set(query_id) - return [ - ( - m.get('Name', None), - m.get('Type', None), - None, - None, - m.get('Precision', None), - m.get('Scale', None), - m.get('Nullable', None) - ) - for m in result_set.meta_data - ] + return result_set.description def description(self, query_id): return self._executor.submit(self._description, query_id) @@ -56,7 +60,7 @@ def poll(self, query_id): def _collect_result_set(self, query_id): query_execution = self._poll(query_id) return AthenaResultSet( - self._connection, self._converter, query_execution, self.arraysize, + self._connection, self._converter, query_execution, self._arraysize, self.retry_exceptions, self.retry_attempt, self.retry_multiplier, self.retry_max_delay, self.retry_exponential_base) @@ -64,5 +68,8 @@ def execute(self, operation, parameters=None): query_id = self._execute(operation, parameters) return query_id, self._executor.submit(self._collect_result_set, query_id) + def executemany(self, operation, seq_of_parameters): + raise NotSupportedError + def cancel(self, query_id): return self._executor.submit(self._cancel, query_id) diff --git a/pyathena/common.py b/pyathena/common.py index d708e8c1..deaba8e1 100644 --- a/pyathena/common.py +++ b/pyathena/common.py @@ -3,12 +3,11 @@ from __future__ import unicode_literals import logging import time -from abc import ABCMeta +from abc import ABCMeta, abstractmethod 
from future.utils import raise_from, with_metaclass -from pyathena.error import (DatabaseError, OperationalError, - ProgrammingError, NotSupportedError) +from pyathena.error import DatabaseError, OperationalError, ProgrammingError from pyathena.model import AthenaQueryExecution from pyathena.util import retry_api_call @@ -16,14 +15,66 @@ _logger = logging.getLogger(__name__) -class BaseCursor(with_metaclass(ABCMeta, object)): +class CursorIterator(with_metaclass(ABCMeta, object)): DEFAULT_FETCH_SIZE = 1000 + def __init__(self, arraysize=None): + self.arraysize = arraysize if arraysize else self.DEFAULT_FETCH_SIZE + self._rownumber = None + + @property + def arraysize(self): + return self._arraysize + + @arraysize.setter + def arraysize(self, value): + if value <= 0 or value > self.DEFAULT_FETCH_SIZE: + raise ProgrammingError('MaxResults is more than maximum allowed length {0}.'.format( + self.DEFAULT_FETCH_SIZE)) + self._arraysize = value + + @property + def rownumber(self): + return self._rownumber + + @property + def rowcount(self): + """By default, return -1 to indicate that this is not supported.""" + return -1 + + @abstractmethod + def fetchone(self): + pass + + @abstractmethod + def fetchmany(self): + pass + + @abstractmethod + def fetchall(self): + pass + + def __next__(self): + row = self.fetchone() + if row is None: + raise StopIteration + else: + return row + + next = __next__ + + def __iter__(self): + return self + + +class BaseCursor(with_metaclass(ABCMeta, object)): + def __init__(self, client, s3_staging_dir, schema_name, poll_interval, encryption_option, kms_key, converter, formatter, retry_exceptions, retry_attempt, retry_multiplier, retry_max_delay, retry_exponential_base): + super(BaseCursor, self).__init__() self._connection = client self._s3_staging_dir = s3_staging_dir self._schema_name = schema_name @@ -44,23 +95,10 @@ def __init__(self, client, s3_staging_dir, schema_name, poll_interval, self.retry_max_delay = retry_max_delay self.retry_exponential_base = retry_exponential_base - self._arraysize = self.DEFAULT_FETCH_SIZE - @property def connection(self): return self._connection - @property - def arraysize(self): - return self._arraysize - - @arraysize.setter - def arraysize(self, value): - if value <= 0 or value > self.DEFAULT_FETCH_SIZE: - raise ProgrammingError('MaxResults is more than maximum allowed length {0}.'.format( - self.DEFAULT_FETCH_SIZE)) - self._arraysize = value - def _query_execution(self, query_id): request = {'QueryExecutionId': query_id} try: @@ -81,7 +119,9 @@ def _query_execution(self, query_id): def _poll(self, query_id): while True: query_execution = self._query_execution(query_id) - if query_execution.state in ['SUCCEEDED', 'FAILED', 'CANCELLED']: + if query_execution.state in [AthenaQueryExecution.STATE_SUCCEEDED, + AthenaQueryExecution.STATE_FAILED, + AthenaQueryExecution.STATE_CANCELLED]: return query_execution else: time.sleep(self._poll_interval) @@ -129,8 +169,17 @@ def _execute(self, operation, parameters=None): else: return response.get('QueryExecutionId', None) + @abstractmethod + def execute(self, operation, parameters=None): + pass + + @abstractmethod def executemany(self, operation, seq_of_parameters): - raise NotSupportedError + pass + + @abstractmethod + def close(self): + pass def _cancel(self, query_id): request = {'QueryExecutionId': query_id} @@ -154,3 +203,9 @@ def setinputsizes(self, sizes): def setoutputsize(self, size, column=None): """Does nothing by default""" pass + + def __enter__(self): + return self + + def 
__exit__(self, exc_type, exc_val, exc_tb): + self.close() diff --git a/pyathena/cursor.py b/pyathena/cursor.py index c94acecb..9bffade5 100644 --- a/pyathena/cursor.py +++ b/pyathena/cursor.py @@ -3,16 +3,17 @@ from __future__ import unicode_literals import logging -from pyathena.common import BaseCursor -from pyathena.error import OperationalError, ProgrammingError -from pyathena.model import AthenaResultSet +from pyathena.common import BaseCursor, CursorIterator +from pyathena.error import OperationalError, ProgrammingError, NotSupportedError +from pyathena.model import AthenaQueryExecution +from pyathena.result_set import AthenaResultSet from pyathena.util import synchronized _logger = logging.getLogger(__name__) -class Cursor(BaseCursor): +class Cursor(BaseCursor, CursorIterator): def __init__(self, client, s3_staging_dir, schema_name, poll_interval, encryption_option, kms_key, converter, formatter, @@ -22,77 +23,74 @@ def __init__(self, client, s3_staging_dir, schema_name, poll_interval, encryption_option, kms_key, converter, formatter, retry_exceptions, retry_attempt, retry_multiplier, retry_max_delay, retry_exponential_base) - self._description = None self._query_id = None - self._meta_data = None self._result_set = None @property def rownumber(self): return self._result_set.rownumber if self._result_set else None - @property - def rowcount(self): - """By default, return -1 to indicate that this is not supported.""" - return -1 - @property def has_result_set(self): - return self._result_set and self._meta_data is not None + return self._result_set is not None @property def description(self): - if self._description or self._description == []: - return self._description if not self.has_result_set: return None - self._description = [ - ( - m.get('Name', None), - m.get('Type', None), - None, - None, - m.get('Precision', None), - m.get('Scale', None), - m.get('Nullable', None) - ) - for m in self._meta_data - ] - return self._description + return self._result_set.description @property def query_id(self): return self._query_id @property - def output_location(self): + def query(self): + if not self.has_result_set: + return None + return self._result_set.query + + @property + def state(self): + if not self.has_result_set: + return None + return self._result_set.state + + @property + def state_change_reason(self): if not self.has_result_set: return None - return self._result_set.query_execution.output_location + return self._result_set.state_change_reason @property def completion_date_time(self): if not self.has_result_set: return None - return self._result_set.query_execution.completion_date_time + return self._result_set.completion_date_time @property def submission_date_time(self): if not self.has_result_set: return None - return self._result_set.query_execution.submission_date_time + return self._result_set.submission_date_time @property def data_scanned_in_bytes(self): if not self.has_result_set: return None - return self._result_set.query_execution.data_scanned_in_bytes + return self._result_set.data_scanned_in_bytes @property def execution_time_in_millis(self): if not self.has_result_set: return None - return self._result_set.query_execution.execution_time_in_millis + return self._result_set.execution_time_in_millis + + @property + def output_location(self): + if not self.has_result_set: + return None + return self._result_set.output_location def close(self): pass @@ -101,22 +99,23 @@ def _reset_state(self): self._description = None self._query_id = None self._result_set = None - 
self._meta_data = None @synchronized def execute(self, operation, parameters=None): self._reset_state() self._query_id = self._execute(operation, parameters) query_execution = self._poll(self._query_id) - if query_execution.state == 'SUCCEEDED': + if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED: self._result_set = AthenaResultSet( self._connection, self._converter, query_execution, self.arraysize, self.retry_exceptions, self.retry_attempt, self.retry_multiplier, self.retry_max_delay, self.retry_exponential_base) - self._meta_data = self._result_set.meta_data else: raise OperationalError(query_execution.state_change_reason) + def executemany(self, operation, seq_of_parameters): + raise NotSupportedError + @synchronized def cancel(self): if not self._query_id: @@ -140,21 +139,3 @@ def fetchall(self): if not self.has_result_set: raise ProgrammingError('No result set.') return self._result_set.fetchall() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - - def __next__(self): - row = self.fetchone() - if row is None: - raise StopIteration - else: - return row - - next = __next__ - - def __iter__(self): - return self diff --git a/pyathena/model.py b/pyathena/model.py index aaa764d0..6d70f5df 100644 --- a/pyathena/model.py +++ b/pyathena/model.py @@ -1,14 +1,9 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import from __future__ import unicode_literals -import collections import logging -from future.utils import raise_from -from past.builtins.misc import xrange - -from pyathena.error import DataError, OperationalError, ProgrammingError -from pyathena.util import retry_api_call +from pyathena.error import DataError _logger = logging.getLogger(__name__) @@ -16,210 +11,70 @@ class AthenaQueryExecution(object): + STATE_SUCCEEDED = 'SUCCEEDED' + STATE_FAILED = 'FAILED' + STATE_CANCELLED = 'CANCELLED' + def __init__(self, response): query_execution = response.get('QueryExecution', None) if not query_execution: raise DataError('KeyError `QueryExecution`') - self.query_id = query_execution.get('QueryExecutionId', None) - if not self.query_id: + self._query_id = query_execution.get('QueryExecutionId', None) + if not self._query_id: raise DataError('KeyError `QueryExecutionId`') - self.query = query_execution.get('Query', None) - if not self.query: + self._query = query_execution.get('Query', None) + if not self._query: raise DataError('KeyError `Query`') status = query_execution.get('Status', None) if not status: raise DataError('KeyError `Status`') - self.state = status.get('State', None) - self.state_change_reason = status.get('StateChangeReason', None) - self.completion_date_time = status.get('CompletionDateTime', None) - self.submission_date_time = status.get('SubmissionDateTime', None) + self._state = status.get('State', None) + self._state_change_reason = status.get('StateChangeReason', None) + self._completion_date_time = status.get('CompletionDateTime', None) + self._submission_date_time = status.get('SubmissionDateTime', None) statistics = query_execution.get('Statistics', {}) - self.data_scanned_in_bytes = statistics.get('DataScannedInBytes', None) - self.execution_time_in_millis = statistics.get('EngineExecutionTimeInMillis', None) + self._data_scanned_in_bytes = statistics.get('DataScannedInBytes', None) + self._execution_time_in_millis = statistics.get('EngineExecutionTimeInMillis', None) result_conf = query_execution.get('ResultConfiguration', {}) - self.output_location = result_conf.get('OutputLocation', 
None) + self._output_location = result_conf.get('OutputLocation', None) + @property + def query_id(self): + return self._query_id -class AthenaResultSet(object): - - def __init__(self, connection, converter, query_execution, arraysize, - retry_exceptions, retry_attempt, retry_multiplier, - retry_max_delay, retry_exponential_base): - self._connection = connection - self._converter = converter - self._query_execution = query_execution - assert self._query_execution, 'Required argument `query_execution` not found.' - self._arraysize = arraysize + @property + def query(self): + return self._query - self.retry_exceptions = retry_exceptions - self.retry_attempt = retry_attempt - self.retry_multiplier = retry_multiplier - self.retry_max_delay = retry_max_delay - self.retry_exponential_base = retry_exponential_base + @property + def state(self): + return self._state - self._meta_data = None - self._rows = collections.deque() - self._next_token = None - self._rownumber = 0 + @property + def state_change_reason(self): + return self._state_change_reason - if self._query_execution.state == 'SUCCEEDED': - self._pre_fetch() + @property + def completion_date_time(self): + return self._completion_date_time @property - def meta_data(self): - return self._meta_data + def submission_date_time(self): + return self._submission_date_time @property - def query_execution(self): - return self._query_execution + def data_scanned_in_bytes(self): + return self._data_scanned_in_bytes @property - def rownumber(self): - return self._rownumber - - def __fetch(self, next_token=None): - if self._query_execution.state != 'SUCCEEDED': - raise ProgrammingError('QueryExecutionState is not SUCCEEDED.') - if not self._query_execution.query_id: - raise ProgrammingError('QueryExecutionId is none or empty.') - request = { - 'QueryExecutionId': self._query_execution.query_id, - 'MaxResults': self._arraysize, - } - if next_token: - request.update({'NextToken': next_token}) - try: - response = retry_api_call(self._connection.get_query_results, - exceptions=self.retry_exceptions, - attempt=self.retry_attempt, - multiplier=self.retry_multiplier, - max_delay=self.retry_max_delay, - exp_base=self.retry_exponential_base, - logger=_logger, - **request) - except Exception as e: - _logger.exception('Failed to fetch result set.') - raise_from(OperationalError(*e.args), e) - else: - return response - - def _fetch(self): - if not self._next_token: - raise ProgrammingError('NextToken is none or empty.') - response = self.__fetch(self._next_token) - self._process_rows(response) - - def _pre_fetch(self): - response = self.__fetch(None) - self._process_meta_data(response) - self._process_rows(response) - - def fetchone(self): - if not self._query_execution.query_id: - raise ProgrammingError('QueryExecutionId is none or empty.') - if not self._rows and self._next_token: - self._fetch() - if not self._rows: - return None - else: - self._rownumber += 1 - return self._rows.popleft() - - def fetchmany(self, size=None): - if not self._query_execution.query_id: - raise ProgrammingError('QueryExecutionId is none or empty.') - if not size or size <= 0: - size = self._arraysize - rows = [] - for _ in xrange(size): - row = self.fetchone() - if row: - rows.append(row) - else: - break - return rows - - def fetchall(self): - if not self._query_execution.query_id: - raise ProgrammingError('QueryExecutionId is none or empty.') - rows = [] - while True: - row = self.fetchone() - if row: - rows.append(row) - else: - break - return rows - - def 
_process_meta_data(self, response): - result_set = response.get('ResultSet', None) - if not result_set: - raise DataError('KeyError `ResultSet`') - meta_data = result_set.get('ResultSetMetadata', None) - if not meta_data: - raise DataError('KeyError `ResultSetMetadata`') - column_info = meta_data.get('ColumnInfo', None) - if column_info is None: - raise DataError('KeyError `ColumnInfo`') - self._meta_data = tuple(column_info) - - def _process_rows(self, response): - result_set = response.get('ResultSet', None) - if not result_set: - raise DataError('KeyError `ResultSet`') - rows = result_set.get('Rows', None) - if rows is None: - raise DataError('KeyError `Rows`') - processed_rows = [] - if len(rows) > 0: - offset = 1 if not self._next_token and self._is_first_row_column_labels(rows) else 0 - processed_rows = [ - tuple([self._converter.convert(meta.get('Type', None), - row.get('VarCharValue', None)) - for meta, row in zip(self._meta_data, rows[i].get('Data', []))]) - for i in xrange(offset, len(rows)) - ] - self._rows.extend(processed_rows) - self._next_token = response.get('NextToken', None) - - def _is_first_row_column_labels(self, rows): - first_row_data = rows[0].get('Data', []) - for meta, data in zip(self._meta_data, first_row_data): - if meta.get('Name', None) != data.get('VarCharValue', None): - return False - return True + def execution_time_in_millis(self): + return self._execution_time_in_millis @property - def is_closed(self): - return self._connection is None - - def close(self): - self._connection = None - self._query_execution = None - self._meta_data = None - self._rows = None - self._next_token = None - self._rownumber = 0 - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - - def __next__(self): - row = self.fetchone() - if row is None: - raise StopIteration - else: - return row - - next = __next__ - - def __iter__(self): - return self + def output_location(self): + return self._output_location diff --git a/pyathena/result_set.py b/pyathena/result_set.py new file mode 100644 index 00000000..26975cf9 --- /dev/null +++ b/pyathena/result_set.py @@ -0,0 +1,229 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import +from __future__ import unicode_literals +import collections +import logging + +from future.utils import raise_from +from past.builtins.misc import xrange + +from pyathena.common import CursorIterator +from pyathena.error import DataError, OperationalError, ProgrammingError +from pyathena.model import AthenaQueryExecution +from pyathena.util import retry_api_call + + +_logger = logging.getLogger(__name__) + + +class AthenaResultSet(CursorIterator): + + def __init__(self, connection, converter, query_execution, arraysize, + retry_exceptions, retry_attempt, retry_multiplier, + retry_max_delay, retry_exponential_base): + super(AthenaResultSet, self).__init__(arraysize) + self._connection = connection + self._converter = converter + self._query_execution = query_execution + assert self._query_execution, 'Required argument `query_execution` not found.' 
+
+        self.retry_exceptions = retry_exceptions
+        self.retry_attempt = retry_attempt
+        self.retry_multiplier = retry_multiplier
+        self.retry_max_delay = retry_max_delay
+        self.retry_exponential_base = retry_exponential_base
+
+        self._meta_data = None
+        self._rows = collections.deque()
+        self._next_token = None
+
+        if self._query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED:
+            self._rownumber = 0
+            self._pre_fetch()
+
+    @property
+    def meta_data(self):
+        return self._meta_data
+
+    @property
+    def query_id(self):
+        return self._query_execution.query_id
+
+    @property
+    def query(self):
+        return self._query_execution.query
+
+    @property
+    def state(self):
+        return self._query_execution.state
+
+    @property
+    def state_change_reason(self):
+        return self._query_execution.state_change_reason
+
+    @property
+    def completion_date_time(self):
+        return self._query_execution.completion_date_time
+
+    @property
+    def submission_date_time(self):
+        return self._query_execution.submission_date_time
+
+    @property
+    def data_scanned_in_bytes(self):
+        return self._query_execution.data_scanned_in_bytes
+
+    @property
+    def execution_time_in_millis(self):
+        return self._query_execution.execution_time_in_millis
+
+    @property
+    def output_location(self):
+        return self._query_execution.output_location
+
+    @property
+    def description(self):
+        if self._meta_data is None:
+            return None
+        return [
+            (
+                m.get('Name', None),
+                m.get('Type', None),
+                None,
+                None,
+                m.get('Precision', None),
+                m.get('Scale', None),
+                m.get('Nullable', None)
+            )
+            for m in self._meta_data
+        ]
+
+    def __fetch(self, next_token=None):
+        if self._query_execution.state != AthenaQueryExecution.STATE_SUCCEEDED:
+            raise ProgrammingError('QueryExecutionState is not SUCCEEDED.')
+        if not self._query_execution.query_id:
+            raise ProgrammingError('QueryExecutionId is none or empty.')
+        request = {
+            'QueryExecutionId': self._query_execution.query_id,
+            'MaxResults': self._arraysize,
+        }
+        if next_token:
+            request.update({'NextToken': next_token})
+        try:
+            response = retry_api_call(self._connection.get_query_results,
+                                      exceptions=self.retry_exceptions,
+                                      attempt=self.retry_attempt,
+                                      multiplier=self.retry_multiplier,
+                                      max_delay=self.retry_max_delay,
+                                      exp_base=self.retry_exponential_base,
+                                      logger=_logger,
+                                      **request)
+        except Exception as e:
+            _logger.exception('Failed to fetch result set.')
+            raise_from(OperationalError(*e.args), e)
+        else:
+            return response
+
+    def _fetch(self):
+        if not self._next_token:
+            raise ProgrammingError('NextToken is none or empty.')
+        response = self.__fetch(self._next_token)
+        self._process_rows(response)
+
+    def _pre_fetch(self):
+        response = self.__fetch()
+        self._process_meta_data(response)
+        self._process_rows(response)
+
+    def fetchone(self):
+        if not self._query_execution.query_id:
+            raise ProgrammingError('QueryExecutionId is none or empty.')
+        if not self._rows and self._next_token:
+            self._fetch()
+        if not self._rows:
+            return None
+        else:
+            self._rownumber += 1
+            return self._rows.popleft()
+
+    def fetchmany(self, size=None):
+        if not self._query_execution.query_id:
+            raise ProgrammingError('QueryExecutionId is none or empty.')
+        if not size or size <= 0:
+            size = self._arraysize
+        rows = []
+        for _ in xrange(size):
+            row = self.fetchone()
+            if row:
+                rows.append(row)
+            else:
+                break
+        return rows
+
+    def fetchall(self):
+        if not self._query_execution.query_id:
+            raise ProgrammingError('QueryExecutionId is none or empty.')
+        rows = []
+        while True:
+            row = self.fetchone()
+            if row:
+                rows.append(row)
+            else:
+                break
+        return rows
+
+    def _process_meta_data(self, response):
+        result_set = response.get('ResultSet', None)
+        if not result_set:
+            raise DataError('KeyError `ResultSet`')
+        meta_data = result_set.get('ResultSetMetadata', None)
+        if not meta_data:
+            raise DataError('KeyError `ResultSetMetadata`')
+        column_info = meta_data.get('ColumnInfo', None)
+        if column_info is None:
+            raise DataError('KeyError `ColumnInfo`')
+        self._meta_data = tuple(column_info)
+
+    def _process_rows(self, response):
+        result_set = response.get('ResultSet', None)
+        if not result_set:
+            raise DataError('KeyError `ResultSet`')
+        rows = result_set.get('Rows', None)
+        if rows is None:
+            raise DataError('KeyError `Rows`')
+        processed_rows = []
+        if len(rows) > 0:
+            offset = 1 if not self._next_token and self._is_first_row_column_labels(rows) else 0
+            processed_rows = [
+                tuple([self._converter.convert(meta.get('Type', None),
+                                               row.get('VarCharValue', None))
+                       for meta, row in zip(self._meta_data, rows[i].get('Data', []))])
+                for i in xrange(offset, len(rows))
+            ]
+        self._rows.extend(processed_rows)
+        self._next_token = response.get('NextToken', None)
+
+    def _is_first_row_column_labels(self, rows):
+        first_row_data = rows[0].get('Data', [])
+        for meta, data in zip(self._meta_data, first_row_data):
+            if meta.get('Name', None) != data.get('VarCharValue', None):
+                return False
+        return True
+
+    @property
+    def is_closed(self):
+        return self._connection is None
+
+    def close(self):
+        self._connection = None
+        self._query_execution = None
+        self._meta_data = None
+        self._rows = None
+        self._next_token = None
+        self._rownumber = None
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
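
A short usage sketch of the redesigned hierarchy follows; it is illustrative and not part of the patch. The connection wiring follows the pyathena README, and the staging bucket, region, and table name are placeholders. After this change, Cursor combines BaseCursor (query execution) with the new CursorIterator mixin (fetch and iteration protocol), while AsyncCursor.execute returns a (query_id, concurrent.futures.Future) pair whose future resolves to an AthenaResultSet, itself a CursorIterator.

    from pyathena import connect
    from pyathena.async_cursor import AsyncCursor

    # Synchronous cursor: BaseCursor.__enter__/__exit__ (added in common.py)
    # make it a context manager, and CursorIterator supplies __iter__/__next__.
    with connect(s3_staging_dir='s3://YOUR_S3_BUCKET/path/to/',
                 region_name='us-west-2').cursor() as cursor:
        cursor.execute('SELECT * FROM one_row')
        print(cursor.description)  # delegated to AthenaResultSet.description
        for row in cursor:
            print(row)

    # Asynchronous cursor: the resolved AthenaResultSet exposes the same
    # query-execution metadata properties and the same iteration protocol.
    with connect(s3_staging_dir='s3://YOUR_S3_BUCKET/path/to/',
                 region_name='us-west-2',
                 cursor_class=AsyncCursor).cursor() as cursor:
        query_id, future = cursor.execute('SELECT * FROM one_row')
        result_set = future.result()
        print(result_set.state)
        print(result_set.data_scanned_in_bytes)
        for row in result_set:
            print(row)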