Skip to content

Commit

Permalink
Support NA values with boolean column (fix #100, fix #102, fix #103)
Browse files Browse the repository at this point in the history
  • Loading branch information
laughingman7743 committed Nov 23, 2019
1 parent d202135 commit 23a6df9
Show file tree
Hide file tree
Showing 11 changed files with 141 additions and 73 deletions.
18 changes: 13 additions & 5 deletions pyathena/connection.py
Expand Up @@ -9,10 +9,12 @@
from boto3.session import Session
from future.utils import iteritems

from pyathena.converter import TypeConverter
from pyathena.async_pandas_cursor import AsyncPandasCursor
from pyathena.converter import DefaultTypeConverter, DefaultPandasTypeConverter
from pyathena.cursor import Cursor
from pyathena.error import NotSupportedError
from pyathena.formatter import ParameterFormatter
from pyathena.formatter import DefaultParameterFormatter
from pyathena.pandas_cursor import PandasCursor
from pyathena.util import RetryConfig

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -65,8 +67,8 @@ def __init__(self, s3_staging_dir=None, region_name=None, schema_name='default',
**self._session_kwargs)
self._client = self._session.client('athena', region_name=region_name,
**self._client_kwargs)
self._converter = converter if converter else TypeConverter()
self._formatter = formatter if formatter else ParameterFormatter()
self._converter = converter
self._formatter = formatter if formatter else DefaultParameterFormatter()
self._retry_config = retry_config if retry_config else RetryConfig()
self.cursor_class = cursor_class

Expand Down Expand Up @@ -120,14 +122,20 @@ def __exit__(self, exc_type, exc_val, exc_tb):
def cursor(self, cursor=None, **kwargs):
if not cursor:
cursor = self.cursor_class
converter = kwargs.pop('converter', self._converter)
if not converter:
if cursor is PandasCursor or cursor is AsyncPandasCursor:
converter = DefaultPandasTypeConverter()
else:
converter = DefaultTypeConverter()
return cursor(connection=self,
s3_staging_dir=kwargs.pop('s3_staging_dir', self.s3_staging_dir),
schema_name=kwargs.pop('schema_name', self.schema_name),
work_group=kwargs.pop('work_group', self.work_group),
poll_interval=kwargs.pop('poll_interval', self.poll_interval),
encryption_option=kwargs.pop('encryption_option', self.encryption_option),
kms_key=kwargs.pop('kms_key', self.kms_key),
converter=kwargs.pop('converter', self._converter),
converter=converter,
formatter=kwargs.pop('formatter', self._formatter),
retry_config=kwargs.pop('retry_config', self._retry_config),
**kwargs)
Expand Down
84 changes: 70 additions & 14 deletions pyathena/converter.py
Expand Up @@ -5,8 +5,10 @@
import binascii
import json
import logging
from copy import deepcopy
from datetime import datetime
from decimal import Decimal
from distutils.util import strtobool

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -48,14 +50,9 @@ def _to_decimal(varchar_value):


def _to_boolean(varchar_value):
if varchar_value is None:
return None
elif varchar_value.lower() == 'true':
return True
elif varchar_value.lower() == 'false':
return False
else:
if varchar_value is None or varchar_value == '':
return None
return bool(strtobool(varchar_value))


def _to_binary(varchar_value):
Expand Down Expand Up @@ -98,16 +95,75 @@ def _to_default(varchar_value):
'decimal': _to_decimal,
'json': _to_json,
}
# Per-value converter functions keyed by Athena type name.
# DefaultPandasTypeConverter deep-copies this dict as its initial mappings;
# type names that are not listed here resolve to the default converter that
# class passes to Converter (_to_default) when looked up via Converter.get().
_DEFAULT_PANDAS_CONVERTERS = {
'boolean': _to_boolean,
'decimal': _to_decimal,
'varbinary': _to_binary,
'json': _to_json,
}


# Base class that keeps a type-name -> converter-callable table, a fallback
# callable and an optional table of types, and exposes helpers to look up,
# change and apply the converters.
class Converter(object):

# mappings: dict of type name -> callable taking the raw value.
# default:  callable returned by get() for type names missing from mappings
#           (None unless the caller supplies one).
# types:    stored as-is and exposed through the ``types`` property.
def __init__(self, mappings, default=None, types=None):
self._mappings = mappings
self._default = default
self._types = types

# NOTE(review): the next line looks like a leftover of the previous class
# header shown by the diff view, not part of Converter -- confirm.
class TypeConverter(object):
# The mappings dict itself (not a copy); set/remove/update mutate it in place.
@property
def mappings(self):
return self._mappings

# The ``types`` value given to __init__ (may be None).
@property
def types(self):
return self._types

# Return the converter registered for type_, or the default when absent.
def get(self, type_):
return self.mappings.get(type_, self._default)

# Register (or replace) the converter for type_.
def set(self, type_, converter):
self.mappings[type_] = converter

# Drop the converter for type_; a missing key is ignored.
def remove(self, type_):
self.mappings.pop(type_, None)

# Merge another type-name -> converter dict into the table.
def update(self, mappings):
self.mappings.update(mappings)

# Look up the converter for type_ and apply it to value.
# If type_ is unregistered and no default was supplied, get() returns None
# and the call below raises TypeError.
def convert(self, type_, value):
converter = self.get(type_)
return converter(value)


class DefaultTypeConverter(Converter):

def __init__(self):
self._mappings = _DEFAULT_CONVERTERS
super(DefaultTypeConverter, self).__init__(
mappings=deepcopy(_DEFAULT_CONVERTERS), default=_to_default)

def convert(self, type_, varchar_value):
converter = self._mappings.get(type_, _to_default)
return converter(varchar_value)

def register_converter(self, type_, converter):
self._mappings[type_] = converter
class DefaultPandasTypeConverter(Converter):
    """Converter used when query results are loaded through pandas.

    ``mappings`` starts as a deep copy of ``_DEFAULT_PANDAS_CONVERTERS`` so
    per-instance changes never touch the module-level table, ``_to_default``
    is the fallback converter, and ``types`` is the Athena type name -> dtype
    table produced by ``_dtypes``.
    """

    def __init__(self):
        super(DefaultPandasTypeConverter, self).__init__(
            mappings=deepcopy(_DEFAULT_PANDAS_CONVERTERS),
            default=_to_default,
            types=self._dtypes)

    @property
    def _dtypes(self):
        """Return the Athena type name -> dtype table, built once per instance."""
        # The cache attribute uses a single leading underscore on purpose.
        # A double-underscore attribute is name-mangled inside the class body,
        # so the former ``hasattr(self, '__dtypes')`` check never saw the value
        # written by ``self.__dtypes = ...`` and the dict was rebuilt on every
        # access.
        try:
            return self._dtypes_cache
        except AttributeError:
            pass
        # Local import: pandas is only imported when the table is first built.
        import pandas as pd
        self._dtypes_cache = {
            'tinyint': pd.Int64Dtype(),
            'smallint': pd.Int64Dtype(),
            'integer': pd.Int64Dtype(),
            'bigint': pd.Int64Dtype(),
            'float': float,
            'real': float,
            'double': float,
            'char': str,
            'varchar': str,
            'array': str,
            'map': str,
            'row': str,
        }
        return self._dtypes_cache
36 changes: 18 additions & 18 deletions pyathena/formatter.py
Expand Up @@ -78,7 +78,24 @@ def _format_seq(formatter, escaper, val):
return '({0})'.format(','.join(results))


class ParameterFormatter(object):
# Formatter functions keyed by Python type (presumably the type of each query
# parameter value -- confirm against the lookup in format()).
# NOTE(review): DefaultParameterFormatter.__init__ assigns this dict to
# self.mappings without copying it, so a register_formatter() call that writes
# to self.mappings would also change this module-level table -- confirm that
# the sharing is intended.
_DEFAULT_FORMATTERS = {
type(None): _format_none,
date: _format_date,
datetime: _format_datetime,
int: _format_default,
float: _format_default,
long: _format_default,
Decimal: _format_default,
bool: _format_bool,
str: _format_str,
unicode: _format_str,
list: _format_seq,
set: _format_seq,
tuple: _format_seq,
}


class DefaultParameterFormatter(object):

def __init__(self):
self.mappings = _DEFAULT_FORMATTERS
Expand Down Expand Up @@ -113,20 +130,3 @@ def format(self, operation, parameters=None):

def register_formatter(self, type_, formatter):
self.mappings[type_] = formatter


_DEFAULT_FORMATTERS = {
type(None): _format_none,
date: _format_date,
datetime: _format_datetime,
int: _format_default,
float: _format_default,
long: _format_default,
Decimal: _format_default,
bool: _format_bool,
str: _format_str,
unicode: _format_str,
list: _format_seq,
set: _format_seq,
tuple: _format_seq,
}
35 changes: 4 additions & 31 deletions pyathena/result_set.py
Expand Up @@ -2,13 +2,10 @@
from __future__ import absolute_import
from __future__ import unicode_literals

import binascii
import collections
import io
import json
import logging
import re
from decimal import Decimal

from future.utils import raise_from
from past.builtins.misc import xrange
Expand Down Expand Up @@ -292,11 +289,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
class AthenaPandasResultSet(AthenaResultSet):

_pattern_output_location = re.compile(r'^s3://(?P<bucket>[a-zA-Z0-9.\-_]+)/(?P<key>.+)$')
_converters = {
'decimal': Decimal,
'varbinary': lambda b: binascii.a2b_hex(''.join(b.split(' '))),
'json': json.loads,
}
_parse_dates = [
'date',
'time',
Expand Down Expand Up @@ -330,37 +322,18 @@ def _parse_output_location(cls, output_location):
else:
raise DataError('Unknown `output_location` format.')

@property
def _dtypes(self):
if not hasattr(self, '__dtypes'):
import pandas as pd
self.__dtypes = {
'boolean': bool,
'tinyint': pd.Int64Dtype(),
'smallint': pd.Int64Dtype(),
'integer': pd.Int64Dtype(),
'bigint': pd.Int64Dtype(),
'float': float,
'real': float,
'double': float,
'char': str,
'varchar': str,
'array': str,
'map': str,
'row': str,
}
return self.__dtypes

@property
def dtypes(self):
return {
d[0]: self._dtypes[d[1]] for d in self.description if d[1] in self._dtypes
d[0]: self._converter.types[d[1]] for d in self.description
if d[1] in self._converter.types
}

@property
def converters(self):
return {
d[0]: self._converters[d[1]] for d in self.description if d[1] in self._converters
d[0]: self._converter.mappings[d[1]] for d in self.description
if d[1] in self._converter.mappings
}

@property
Expand Down
3 changes: 3 additions & 0 deletions scripts/test_data/boolean_na_values.tsv
@@ -0,0 +1,3 @@
true false
false

1 change: 1 addition & 0 deletions scripts/test_data/delete_test_data.sh
Expand Up @@ -4,3 +4,4 @@ aws s3 rm ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/one_row/one_row.tsv
aws s3 rm ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/one_row_complex/one_row_complex.tsv
aws s3 rm ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/many_rows/many_rows.tsv
aws s3 rm ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/integer_na_values/integer_na_values.tsv
aws s3 rm ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/boolean_na_values/boolean_na_values.tsv
1 change: 1 addition & 0 deletions scripts/test_data/upload_test_data.sh
Expand Up @@ -4,3 +4,4 @@ aws s3 cp $(dirname $0)/one_row.tsv ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/on
aws s3 cp $(dirname $0)/one_row_complex.tsv ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/one_row_complex/one_row_complex.tsv
aws s3 cp $(dirname $0)/many_rows.tsv ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/many_rows/many_rows.tsv
aws s3 cp $(dirname $0)/integer_na_values.tsv ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/integer_na_values/integer_na_values.tsv
aws s3 cp $(dirname $0)/boolean_na_values.tsv ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/boolean_na_values/boolean_na_values.tsv
5 changes: 4 additions & 1 deletion tests/conftest.py
Expand Up @@ -56,11 +56,14 @@ def _create_table(cursor):
ENV.s3_staging_dir, S3_PREFIX, 'partition_table')
location_integer_na_values = '{0}{1}/{2}/'.format(
ENV.s3_staging_dir, S3_PREFIX, 'integer_na_values')
location_boolean_na_values = '{0}{1}/{2}/'.format(
ENV.s3_staging_dir, S3_PREFIX, 'boolean_na_values')
for q in read_query(
os.path.join(BASE_PATH, 'sql', 'create_table.sql')):
cursor.execute(q.format(schema=SCHEMA,
location_one_row=location_one_row,
location_many_rows=location_many_rows,
location_one_row_complex=location_one_row_complex,
location_partition_table=location_partition_table,
location_integer_na_values=location_integer_na_values))
location_integer_na_values=location_integer_na_values,
location_boolean_na_values=location_boolean_na_values))
10 changes: 9 additions & 1 deletion tests/sql/create_table.sql
Expand Up @@ -45,4 +45,12 @@ CREATE EXTERNAL TABLE IF NOT EXISTS {schema}.integer_na_values (
b INT
)
ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\n' STORED AS TEXTFILE
LOCATION '{location_integer_na_values}'
LOCATION '{location_integer_na_values}';

DROP TABLE IF EXISTS {schema}.boolean_na_values;
CREATE EXTERNAL TABLE IF NOT EXISTS {schema}.boolean_na_values (
a BOOLEAN,
b BOOLEAN
)
ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\n' STORED AS TEXTFILE
LOCATION '{location_boolean_na_values}';
6 changes: 3 additions & 3 deletions tests/test_formatter.py
Expand Up @@ -7,14 +7,14 @@
from decimal import Decimal

from pyathena.error import ProgrammingError
from pyathena.formatter import ParameterFormatter
from pyathena.formatter import DefaultParameterFormatter


class TestParameterFormatter(unittest.TestCase):
class TestDefaultParameterFormatter(unittest.TestCase):

# TODO More DDL statement test case & Complex parameter format test case

FORMATTER = ParameterFormatter()
FORMATTER = DefaultParameterFormatter()

def format(self, operation, parameters=None):
return self.FORMATTER.format(operation, parameters)
Expand Down
15 changes: 15 additions & 0 deletions tests/test_pandas_cursor.py
Expand Up @@ -342,3 +342,18 @@ def test_integer_na_values(self, cursor):
(1, np.nan),
(np.nan, np.nan),
])

@with_pandas_cursor
def test_boolean_na_values(self, cursor):
    # Load the boolean_na_values table as a DataFrame and compare it, row by
    # row as (a, b) tuples, with the expected values including missing ones.
    frame = cursor.execute("""
SELECT * FROM boolean_na_values
""").as_pandas()
    expected = [
        (True, False),
        (False, None),
        (None, None),
    ]
    actual = [(record['a'], record['b']) for _, record in frame.iterrows()]
    self.assertEqual(actual, expected)

0 comments on commit 23a6df9

Please sign in to comment.