Redesign class hierarchy
laughingman7743 committed Sep 18, 2017
1 parent 8bc7ffe commit 45c6142
Showing 5 changed files with 397 additions and 272 deletions.
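At a high level, this commit factors the DB-API fetch/iteration protocol out of BaseCursor into an abstract CursorIterator mixin in pyathena/common.py, moves AthenaResultSet into a new pyathena.result_set module, and has the concrete cursors compose these pieces. A minimal sketch of the resulting hierarchy, with class bodies elided and module paths taken from the imports in the diffs below:

# Sketch only -- names and module paths come from the diffs below, bodies elided.
from pyathena.common import BaseCursor, CursorIterator   # CursorIterator: abstract fetch/iteration mixin
from pyathena.result_set import AthenaResultSet          # result set now lives outside pyathena.model

class Cursor(BaseCursor, CursorIterator):                 # synchronous DB-API cursor (pyathena/cursor.py)
    ...

class AsyncCursor(BaseCursor):                            # non-blocking cursor (pyathena/async_cursor.py)
    ...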
37 changes: 22 additions & 15 deletions pyathena/async_cursor.py
@@ -6,8 +6,10 @@

from concurrent.futures.thread import ThreadPoolExecutor

from pyathena.common import CursorIterator
from pyathena.cursor import BaseCursor
from pyathena.model import AthenaResultSet
from pyathena.error import ProgrammingError, NotSupportedError
from pyathena.result_set import AthenaResultSet


_logger = logging.getLogger(__name__)
@@ -19,30 +21,32 @@ def __init__(self, client, s3_staging_dir, schema_name, poll_interval,
encryption_option, kms_key, converter, formatter,
retry_exceptions, retry_attempt, retry_multiplier,
retry_max_delay, retry_exponential_base,
max_workers=(os.cpu_count() or 1) * 5):
max_workers=(os.cpu_count() or 1) * 5,
arraysize=CursorIterator.DEFAULT_FETCH_SIZE):
super(AsyncCursor, self).__init__(client, s3_staging_dir, schema_name, poll_interval,
encryption_option, kms_key, converter, formatter,
retry_exceptions, retry_attempt, retry_multiplier,
retry_max_delay, retry_exponential_base)
self._executor = ThreadPoolExecutor(max_workers=max_workers)
self._arraysize = arraysize

@property
def arraysize(self):
return self._arraysize

@arraysize.setter
def arraysize(self, value):
if value <= 0 or value > CursorIterator.DEFAULT_FETCH_SIZE:
raise ProgrammingError('MaxResults is more than maximum allowed length {0}.'.format(
CursorIterator.DEFAULT_FETCH_SIZE))
self._arraysize = value

def close(self, wait=False):
self._executor.shutdown(wait=wait)

def _description(self, query_id):
result_set = self._collect_result_set(query_id)
return [
(
m.get('Name', None),
m.get('Type', None),
None,
None,
m.get('Precision', None),
m.get('Scale', None),
m.get('Nullable', None)
)
for m in result_set.meta_data
]
return result_set.description

def description(self, query_id):
return self._executor.submit(self._description, query_id)
@@ -56,13 +60,16 @@ def poll(self, query_id):
def _collect_result_set(self, query_id):
query_execution = self._poll(query_id)
return AthenaResultSet(
self._connection, self._converter, query_execution, self.arraysize,
self._connection, self._converter, query_execution, self._arraysize,
self.retry_exceptions, self.retry_attempt, self.retry_multiplier,
self.retry_max_delay, self.retry_exponential_base)

def execute(self, operation, parameters=None):
query_id = self._execute(operation, parameters)
return query_id, self._executor.submit(self._collect_result_set, query_id)

def executemany(self, operation, seq_of_parameters):
raise NotSupportedError

def cancel(self, query_id):
return self._executor.submit(self._cancel, query_id)
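For orientation, here is a hedged usage sketch of the non-blocking interface implemented above: execute() returns the query ID together with a concurrent.futures.Future that resolves to an AthenaResultSet on the cursor's ThreadPoolExecutor. The staging bucket, region, and the cursor_class keyword are assumptions for illustration, not part of this diff:

# Assumed setup -- bucket, region and the cursor_class keyword are illustrative.
from pyathena import connect
from pyathena.async_cursor import AsyncCursor

cursor = connect(s3_staging_dir='s3://YOUR_BUCKET/path/to/',
                 region_name='us-west-2',
                 cursor_class=AsyncCursor).cursor()
query_id, future = cursor.execute('SELECT * FROM one_row')   # returns immediately
result_set = future.result()          # blocks until the background query finishes
print(result_set.state)               # e.g. 'SUCCEEDED'
print(result_set.fetchall())
cursor.close()                        # close(wait=True) would wait for running futures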
93 changes: 74 additions & 19 deletions pyathena/common.py
@@ -3,27 +3,78 @@
from __future__ import unicode_literals
import logging
import time
from abc import ABCMeta
from abc import ABCMeta, abstractmethod

from future.utils import raise_from, with_metaclass

from pyathena.error import (DatabaseError, OperationalError,
ProgrammingError, NotSupportedError)
from pyathena.error import DatabaseError, OperationalError, ProgrammingError
from pyathena.model import AthenaQueryExecution
from pyathena.util import retry_api_call


_logger = logging.getLogger(__name__)


class BaseCursor(with_metaclass(ABCMeta, object)):
class CursorIterator(with_metaclass(ABCMeta, object)):

DEFAULT_FETCH_SIZE = 1000

def __init__(self, arraysize=None):
self.arraysize = arraysize if arraysize else self.DEFAULT_FETCH_SIZE
self._rownumber = None

@property
def arraysize(self):
return self._arraysize

@arraysize.setter
def arraysize(self, value):
if value <= 0 or value > self.DEFAULT_FETCH_SIZE:
raise ProgrammingError('MaxResults is more than maximum allowed length {0}.'.format(
self.DEFAULT_FETCH_SIZE))
self._arraysize = value

@property
def rownumber(self):
return self._rownumber

@property
def rowcount(self):
"""By default, return -1 to indicate that this is not supported."""
return -1

@abstractmethod
def fetchone(self):
pass

@abstractmethod
def fetchmany(self):
pass

@abstractmethod
def fetchall(self):
pass

def __next__(self):
row = self.fetchone()
if row is None:
raise StopIteration
else:
return row

next = __next__

def __iter__(self):
return self


class BaseCursor(with_metaclass(ABCMeta, object)):

def __init__(self, client, s3_staging_dir, schema_name, poll_interval,
encryption_option, kms_key, converter, formatter,
retry_exceptions, retry_attempt, retry_multiplier,
retry_max_delay, retry_exponential_base):
super(BaseCursor, self).__init__()
self._connection = client
self._s3_staging_dir = s3_staging_dir
self._schema_name = schema_name
@@ -44,23 +95,10 @@ def __init__(self, client, s3_staging_dir, schema_name, poll_interval,
self.retry_max_delay = retry_max_delay
self.retry_exponential_base = retry_exponential_base

self._arraysize = self.DEFAULT_FETCH_SIZE

@property
def connection(self):
return self._connection

@property
def arraysize(self):
return self._arraysize

@arraysize.setter
def arraysize(self, value):
if value <= 0 or value > self.DEFAULT_FETCH_SIZE:
raise ProgrammingError('MaxResults is more than maximum allowed length {0}.'.format(
self.DEFAULT_FETCH_SIZE))
self._arraysize = value

def _query_execution(self, query_id):
request = {'QueryExecutionId': query_id}
try:
@@ -81,7 +119,9 @@ def _query_execution(self, query_id):
def _poll(self, query_id):
while True:
query_execution = self._query_execution(query_id)
if query_execution.state in ['SUCCEEDED', 'FAILED', 'CANCELLED']:
if query_execution.state in [AthenaQueryExecution.STATE_SUCCEEDED,
AthenaQueryExecution.STATE_FAILED,
AthenaQueryExecution.STATE_CANCELLED]:
return query_execution
else:
time.sleep(self._poll_interval)
@@ -129,8 +169,17 @@ def _execute(self, operation, parameters=None):
else:
return response.get('QueryExecutionId', None)

@abstractmethod
def execute(self, operation, parameters=None):
pass

@abstractmethod
def executemany(self, operation, seq_of_parameters):
raise NotSupportedError
pass

@abstractmethod
def close(self):
pass

def _cancel(self, query_id):
request = {'QueryExecutionId': query_id}
@@ -154,3 +203,9 @@ def setinputsizes(self, sizes):
def setoutputsize(self, size, column=None):
"""Does nothing by default"""
pass

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
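Taken together, the abstract CursorIterator mixin and the new __enter__/__exit__ pair on BaseCursor mean a concrete cursor can be driven both as an iterator and as a context manager. A minimal sketch, assuming a working connection (bucket, region and table name are placeholders, not from this diff):

# Placeholders only -- s3_staging_dir, region_name and the table name are illustrative.
from contextlib import closing
from pyathena import connect

with closing(connect(s3_staging_dir='s3://YOUR_BUCKET/path/to/',
                     region_name='us-west-2')) as conn:
    with conn.cursor() as cursor:                  # __enter__/__exit__ from BaseCursor
        cursor.execute('SELECT col FROM one_row')
        for row in cursor:                         # __iter__/__next__ from CursorIterator
            print(row)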
89 changes: 35 additions & 54 deletions pyathena/cursor.py
@@ -3,16 +3,17 @@
from __future__ import unicode_literals
import logging

from pyathena.common import BaseCursor
from pyathena.error import OperationalError, ProgrammingError
from pyathena.model import AthenaResultSet
from pyathena.common import BaseCursor, CursorIterator
from pyathena.error import OperationalError, ProgrammingError, NotSupportedError
from pyathena.model import AthenaQueryExecution
from pyathena.result_set import AthenaResultSet
from pyathena.util import synchronized


_logger = logging.getLogger(__name__)


class Cursor(BaseCursor):
class Cursor(BaseCursor, CursorIterator):

def __init__(self, client, s3_staging_dir, schema_name, poll_interval,
encryption_option, kms_key, converter, formatter,
@@ -22,77 +23,74 @@ def __init__(self, client, s3_staging_dir, schema_name, poll_interval,
encryption_option, kms_key, converter, formatter,
retry_exceptions, retry_attempt, retry_multiplier,
retry_max_delay, retry_exponential_base)
self._description = None
self._query_id = None
self._meta_data = None
self._result_set = None

@property
def rownumber(self):
return self._result_set.rownumber if self._result_set else None

@property
def rowcount(self):
"""By default, return -1 to indicate that this is not supported."""
return -1

@property
def has_result_set(self):
return self._result_set and self._meta_data is not None
return self._result_set is not None

@property
def description(self):
if self._description or self._description == []:
return self._description
if not self.has_result_set:
return None
self._description = [
(
m.get('Name', None),
m.get('Type', None),
None,
None,
m.get('Precision', None),
m.get('Scale', None),
m.get('Nullable', None)
)
for m in self._meta_data
]
return self._description
return self._result_set.description

@property
def query_id(self):
return self._query_id

@property
def output_location(self):
def query(self):
if not self.has_result_set:
return None
return self._result_set.query

@property
def state(self):
if not self.has_result_set:
return None
return self._result_set.state

@property
def state_change_reason(self):
if not self.has_result_set:
return None
return self._result_set.query_execution.output_location
return self._result_set.state_change_reason

@property
def completion_date_time(self):
if not self.has_result_set:
return None
return self._result_set.query_execution.completion_date_time
return self._result_set.completion_date_time

@property
def submission_date_time(self):
if not self.has_result_set:
return None
return self._result_set.query_execution.submission_date_time
return self._result_set.submission_date_time

@property
def data_scanned_in_bytes(self):
if not self.has_result_set:
return None
return self._result_set.query_execution.data_scanned_in_bytes
return self._result_set.data_scanned_in_bytes

@property
def execution_time_in_millis(self):
if not self.has_result_set:
return None
return self._result_set.query_execution.execution_time_in_millis
return self._result_set.execution_time_in_millis

@property
def output_location(self):
if not self.has_result_set:
return None
return self._result_set.output_location

def close(self):
pass
@@ -101,22 +99,23 @@ def _reset_state(self):
self._description = None
self._query_id = None
self._result_set = None
self._meta_data = None

@synchronized
def execute(self, operation, parameters=None):
self._reset_state()
self._query_id = self._execute(operation, parameters)
query_execution = self._poll(self._query_id)
if query_execution.state == 'SUCCEEDED':
if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED:
self._result_set = AthenaResultSet(
self._connection, self._converter, query_execution, self.arraysize,
self.retry_exceptions, self.retry_attempt, self.retry_multiplier,
self.retry_max_delay, self.retry_exponential_base)
self._meta_data = self._result_set.meta_data
else:
raise OperationalError(query_execution.state_change_reason)

def executemany(self, operation, seq_of_parameters):
raise NotSupportedError

@synchronized
def cancel(self):
if not self._query_id:
@@ -140,21 +139,3 @@ def fetchall(self):
if not self.has_result_set:
raise ProgrammingError('No result set.')
return self._result_set.fetchall()

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

def __next__(self):
row = self.fetchone()
if row is None:
raise StopIteration
else:
return row

next = __next__

def __iter__(self):
return self
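The query-execution metadata properties above (query, state, data_scanned_in_bytes, output_location, and so on) now delegate to the underlying AthenaResultSet instead of a locally cached _meta_data. A hedged example of reading them after a successful execute(), with connection parameters as placeholders:

# Connection parameters are placeholders; the property names come from the diff above.
from pyathena import connect

cursor = connect(s3_staging_dir='s3://YOUR_BUCKET/path/to/',
                 region_name='us-west-2').cursor()
cursor.execute('SELECT 1')
print(cursor.query_id)                  # Athena QueryExecutionId
print(cursor.state)                     # e.g. AthenaQueryExecution.STATE_SUCCEEDED
print(cursor.data_scanned_in_bytes)     # delegated to the underlying AthenaResultSet
print(cursor.output_location)           # s3://... location of the CSV Athena wrote
cursor.close()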
