diff --git a/pyathena/connection.py b/pyathena/connection.py index b63b4acd..6ef75b9f 100644 --- a/pyathena/connection.py +++ b/pyathena/connection.py @@ -9,10 +9,12 @@ from boto3.session import Session from future.utils import iteritems -from pyathena.converter import TypeConverter +from pyathena.async_pandas_cursor import AsyncPandasCursor +from pyathena.converter import DefaultTypeConverter, DefaultPandasTypeConverter from pyathena.cursor import Cursor from pyathena.error import NotSupportedError -from pyathena.formatter import ParameterFormatter +from pyathena.formatter import DefaultParameterFormatter +from pyathena.pandas_cursor import PandasCursor from pyathena.util import RetryConfig _logger = logging.getLogger(__name__) @@ -65,8 +67,8 @@ def __init__(self, s3_staging_dir=None, region_name=None, schema_name='default', **self._session_kwargs) self._client = self._session.client('athena', region_name=region_name, **self._client_kwargs) - self._converter = converter if converter else TypeConverter() - self._formatter = formatter if formatter else ParameterFormatter() + self._converter = converter + self._formatter = formatter if formatter else DefaultParameterFormatter() self._retry_config = retry_config if retry_config else RetryConfig() self.cursor_class = cursor_class @@ -120,6 +122,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): def cursor(self, cursor=None, **kwargs): if not cursor: cursor = self.cursor_class + converter = kwargs.pop('converter', self._converter) + if not converter: + if cursor is PandasCursor or cursor is AsyncPandasCursor: + converter = DefaultPandasTypeConverter() + else: + converter = DefaultTypeConverter() return cursor(connection=self, s3_staging_dir=kwargs.pop('s3_staging_dir', self.s3_staging_dir), schema_name=kwargs.pop('schema_name', self.schema_name), @@ -127,7 +135,7 @@ def cursor(self, cursor=None, **kwargs): poll_interval=kwargs.pop('poll_interval', self.poll_interval), encryption_option=kwargs.pop('encryption_option', self.encryption_option), kms_key=kwargs.pop('kms_key', self.kms_key), - converter=kwargs.pop('converter', self._converter), + converter=converter, formatter=kwargs.pop('formatter', self._formatter), retry_config=kwargs.pop('retry_config', self._retry_config), **kwargs) diff --git a/pyathena/converter.py b/pyathena/converter.py index 222996ce..a3a1eb7c 100644 --- a/pyathena/converter.py +++ b/pyathena/converter.py @@ -5,9 +5,13 @@ import binascii import json import logging +from abc import ABCMeta, abstractmethod +from copy import deepcopy from datetime import datetime from decimal import Decimal +from future.utils import with_metaclass + _logger = logging.getLogger(__name__) @@ -100,14 +104,72 @@ def _to_default(varchar_value): } -class TypeConverter(object): +class Converter(with_metaclass(ABCMeta, object)): + + def __init__(self, mappings, default=None): + self._mappings = deepcopy(mappings) + self._default = default + + @property + def mappings(self): + return self._mappings + + def get(self, type_): + return self.mappings.get(type_, self._default) + + def set(self, type_, converter): + self.mappings[type_] = converter + + def remove(self, type_): + self.mappings.pop(type_) + + def update(self, mappings): + self.mappings.update(mappings) + + @abstractmethod + def convert(self, type_, value): + raise NotImplemented # pragma: nocover + + +class DefaultTypeConverter(Converter): def __init__(self): - self._mappings = _DEFAULT_CONVERTERS + super(DefaultTypeConverter, self).__init__( + mappings=_DEFAULT_CONVERTERS, default=_to_default) - def convert(self, type_, varchar_value): - converter = self._mappings.get(type_, _to_default) - return converter(varchar_value) + def convert(self, type_, value): + converter = self.get(type_) + return converter(value) - def register_converter(self, type_, converter): - self._mappings[type_] = converter + +class DefaultPandasTypeConverter(Converter): + + def __init__(self): + super(DefaultPandasTypeConverter, self).__init__(mappings=self._dtypes) + + @property + def _dtypes(self): + if not hasattr(self, '__dtypes'): + import pandas as pd + self.__dtypes = { + 'boolean': object, + 'tinyint': pd.Int64Dtype(), + 'smallint': pd.Int64Dtype(), + 'integer': pd.Int64Dtype(), + 'bigint': pd.Int64Dtype(), + 'float': float, + 'real': float, + 'double': float, + 'char': str, + 'varchar': str, + 'array': str, + 'map': str, + 'row': str, + 'decimal': object, + 'json': object, + 'varbinary': object, + } + return self.__dtypes + + def convert(self, type_, value): + pass # pragma: nocover diff --git a/pyathena/formatter.py b/pyathena/formatter.py index 66b94f5a..103d366b 100644 --- a/pyathena/formatter.py +++ b/pyathena/formatter.py @@ -78,7 +78,24 @@ def _format_seq(formatter, escaper, val): return '({0})'.format(','.join(results)) -class ParameterFormatter(object): +_DEFAULT_FORMATTERS = { + type(None): _format_none, + date: _format_date, + datetime: _format_datetime, + int: _format_default, + float: _format_default, + long: _format_default, + Decimal: _format_default, + bool: _format_bool, + str: _format_str, + unicode: _format_str, + list: _format_seq, + set: _format_seq, + tuple: _format_seq, +} + + +class DefaultParameterFormatter(object): def __init__(self): self.mappings = _DEFAULT_FORMATTERS @@ -113,20 +130,3 @@ def format(self, operation, parameters=None): def register_formatter(self, type_, formatter): self.mappings[type_] = formatter - - -_DEFAULT_FORMATTERS = { - type(None): _format_none, - date: _format_date, - datetime: _format_datetime, - int: _format_default, - float: _format_default, - long: _format_default, - Decimal: _format_default, - bool: _format_bool, - str: _format_str, - unicode: _format_str, - list: _format_seq, - set: _format_seq, - tuple: _format_seq, -} diff --git a/pyathena/result_set.py b/pyathena/result_set.py index 0ac8874d..32a67da4 100644 --- a/pyathena/result_set.py +++ b/pyathena/result_set.py @@ -9,6 +9,7 @@ import logging import re from decimal import Decimal +from distutils.util import strtobool from future.utils import raise_from from past.builtins.misc import xrange @@ -293,9 +294,10 @@ class AthenaPandasResultSet(AthenaResultSet): _pattern_output_location = re.compile(r'^s3://(?P[a-zA-Z0-9.\-_]+)/(?P.+)$') _converters = { - 'decimal': Decimal, - 'varbinary': lambda b: binascii.a2b_hex(''.join(b.split(' '))), - 'json': json.loads, + 'boolean': lambda b: bool(strtobool(b)) if b else None, + 'decimal': lambda d: Decimal(d) if d else None, + 'varbinary': lambda b: binascii.a2b_hex(''.join(b.split(' '))) if b else None, + 'json': lambda j: json.loads(j) if j else None, } _parse_dates = [ 'date', @@ -330,31 +332,11 @@ def _parse_output_location(cls, output_location): else: raise DataError('Unknown `output_location` format.') - @property - def _dtypes(self): - if not hasattr(self, '__dtypes'): - import pandas as pd - self.__dtypes = { - 'boolean': bool, - 'tinyint': pd.Int64Dtype(), - 'smallint': pd.Int64Dtype(), - 'integer': pd.Int64Dtype(), - 'bigint': pd.Int64Dtype(), - 'float': float, - 'real': float, - 'double': float, - 'char': str, - 'varchar': str, - 'array': str, - 'map': str, - 'row': str, - } - return self.__dtypes - @property def dtypes(self): return { - d[0]: self._dtypes[d[1]] for d in self.description if d[1] in self._dtypes + d[0]: self._converter.mappings[d[1]] for d in self.description + if d[1] in self._converter.mappings } @property diff --git a/scripts/test_data/boolean_na_values.tsv b/scripts/test_data/boolean_na_values.tsv new file mode 100644 index 00000000..385f321b --- /dev/null +++ b/scripts/test_data/boolean_na_values.tsv @@ -0,0 +1,3 @@ +true false +false + diff --git a/scripts/test_data/delete_test_data.sh b/scripts/test_data/delete_test_data.sh index 809dd2c4..006e5516 100755 --- a/scripts/test_data/delete_test_data.sh +++ b/scripts/test_data/delete_test_data.sh @@ -4,3 +4,4 @@ aws s3 rm ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/one_row/one_row.tsv aws s3 rm ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/one_row_complex/one_row_complex.tsv aws s3 rm ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/many_rows/many_rows.tsv aws s3 rm ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/integer_na_values/integer_na_values.tsv +aws s3 rm ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/boolean_na_values/boolean_na_values.tsv diff --git a/scripts/test_data/upload_test_data.sh b/scripts/test_data/upload_test_data.sh index f27a001c..e022e3f7 100755 --- a/scripts/test_data/upload_test_data.sh +++ b/scripts/test_data/upload_test_data.sh @@ -4,3 +4,4 @@ aws s3 cp $(dirname $0)/one_row.tsv ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/on aws s3 cp $(dirname $0)/one_row_complex.tsv ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/one_row_complex/one_row_complex.tsv aws s3 cp $(dirname $0)/many_rows.tsv ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/many_rows/many_rows.tsv aws s3 cp $(dirname $0)/integer_na_values.tsv ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/integer_na_values/integer_na_values.tsv +aws s3 cp $(dirname $0)/boolean_na_values.tsv ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/boolean_na_values/boolean_na_values.tsv diff --git a/tests/conftest.py b/tests/conftest.py index 9603989f..a227cb8b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -56,6 +56,8 @@ def _create_table(cursor): ENV.s3_staging_dir, S3_PREFIX, 'partition_table') location_integer_na_values = '{0}{1}/{2}/'.format( ENV.s3_staging_dir, S3_PREFIX, 'integer_na_values') + location_boolean_na_values = '{0}{1}/{2}/'.format( + ENV.s3_staging_dir, S3_PREFIX, 'boolean_na_values') for q in read_query( os.path.join(BASE_PATH, 'sql', 'create_table.sql')): cursor.execute(q.format(schema=SCHEMA, @@ -63,4 +65,5 @@ def _create_table(cursor): location_many_rows=location_many_rows, location_one_row_complex=location_one_row_complex, location_partition_table=location_partition_table, - location_integer_na_values=location_integer_na_values)) + location_integer_na_values=location_integer_na_values, + location_boolean_na_values=location_boolean_na_values)) diff --git a/tests/sql/create_table.sql b/tests/sql/create_table.sql index d1852753..0208f6bf 100644 --- a/tests/sql/create_table.sql +++ b/tests/sql/create_table.sql @@ -45,4 +45,12 @@ CREATE EXTERNAL TABLE IF NOT EXISTS {schema}.integer_na_values ( b INT ) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\n' STORED AS TEXTFILE -LOCATION '{location_integer_na_values}' +LOCATION '{location_integer_na_values}'; + +DROP TABLE IF EXISTS {schema}.boolean_na_values; +CREATE EXTERNAL TABLE IF NOT EXISTS {schema}.boolean_na_values ( + a BOOLEAN, + b BOOLEAN +) +ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\n' STORED AS TEXTFILE +LOCATION '{location_boolean_na_values}'; diff --git a/tests/test_formatter.py b/tests/test_formatter.py index aa1cbd1e..7214bd41 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -7,14 +7,14 @@ from decimal import Decimal from pyathena.error import ProgrammingError -from pyathena.formatter import ParameterFormatter +from pyathena.formatter import DefaultParameterFormatter -class TestParameterFormatter(unittest.TestCase): +class TestDefaultParameterFormatter(unittest.TestCase): # TODO More DDL statement test case & Complex parameter format test case - FORMATTER = ParameterFormatter() + FORMATTER = DefaultParameterFormatter() def format(self, operation, parameters=None): return self.FORMATTER.format(operation, parameters) diff --git a/tests/test_pandas_cursor.py b/tests/test_pandas_cursor.py index 86df2fca..5273ddaa 100644 --- a/tests/test_pandas_cursor.py +++ b/tests/test_pandas_cursor.py @@ -221,7 +221,7 @@ def test_complex_as_pandas(self, cursor): df['col_decimal'].dtype.type, ]) self.assertEqual(dtypes, tuple([ - np.bool_, + bool, np.int64, np.int64, np.int64, @@ -342,3 +342,18 @@ def test_integer_na_values(self, cursor): (1, np.nan), (np.nan, np.nan), ]) + + @with_pandas_cursor + def test_boolean_na_values(self, cursor): + df = cursor.execute(""" + SELECT * FROM boolean_na_values + """).as_pandas() + rows = [tuple([ + row['a'], + row['b'], + ]) for _, row in df.iterrows()] + self.assertEqual(rows, [ + (True, False), + (False, None), + (None, None), + ])