Skip to content

Commit

Permalink
Support NA values with boolean column (fix #100, fix #102, fix #103)
Browse files Browse the repository at this point in the history
  • Loading branch information
laughingman7743 committed Nov 23, 2019
1 parent d202135 commit cea7d46
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 35 deletions.
11 changes: 9 additions & 2 deletions pyathena/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
from boto3.session import Session
from future.utils import iteritems

from pyathena.converter import TypeConverter
from pyathena.async_pandas_cursor import AsyncPandasCursor
from pyathena.converter import DefaultTypeConverter, PandasTypeConverter
from pyathena.cursor import Cursor
from pyathena.error import NotSupportedError
from pyathena.formatter import ParameterFormatter
from pyathena.pandas_cursor import PandasCursor
from pyathena.util import RetryConfig

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -65,7 +67,12 @@ def __init__(self, s3_staging_dir=None, region_name=None, schema_name='default',
**self._session_kwargs)
self._client = self._session.client('athena', region_name=region_name,
**self._client_kwargs)
self._converter = converter if converter else TypeConverter()
if converter:
self._converter = converter
elif cursor_class is PandasCursor or cursor_class is AsyncPandasCursor:
self._converter = PandasTypeConverter()
else:
self._converter = DefaultTypeConverter()
self._formatter = formatter if formatter else ParameterFormatter()
self._retry_config = retry_config if retry_config else RetryConfig()
self.cursor_class = cursor_class
Expand Down
75 changes: 67 additions & 8 deletions pyathena/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
import binascii
import json
import logging
from abc import ABCMeta, abstractmethod
from copy import deepcopy
from datetime import datetime
from decimal import Decimal

from future.utils import with_metaclass

_logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -77,7 +81,7 @@ def _to_default(varchar_value):
return varchar_value


_DEFAULT_CONVERTERS = {
_DEFAULT_MAPPINGS = {
'boolean': _to_boolean,
'tinyint': _to_int,
'smallint': _to_int,
Expand All @@ -100,14 +104,69 @@ def _to_default(varchar_value):
}


class TypeConverter(object):
class Converter(with_metaclass(ABCMeta, object)):

def __init__(self, mappings, default=None):
self._mappings = deepcopy(mappings)
self._default = default

@property
def mappings(self):
return self._mappings

def get(self, type_):
return self.mappings.get(type_, self._default)

def set(self, type_, converter):
self.mappings[type_] = converter

def remove(self, type_):
self.mappings.pop(type_)

def update(self, mappings):
self.mappings.update(mappings)

@abstractmethod
def convert(self, type_, value):
raise NotImplemented # pragma: nocover


class DefaultTypeConverter(Converter):

def __init__(self):
self._mappings = _DEFAULT_CONVERTERS
super(DefaultTypeConverter, self).__init__(
mappings=_DEFAULT_MAPPINGS, default=_to_default)

def convert(self, type_, varchar_value):
converter = self._mappings.get(type_, _to_default)
return converter(varchar_value)
def convert(self, type_, value):
converter = self.get(type_)
return converter(value)

def register_converter(self, type_, converter):
self._mappings[type_] = converter

class PandasTypeConverter(Converter):

def __init__(self):
super(PandasTypeConverter, self).__init__(mappings=self._dtypes)

@property
def _dtypes(self):
if not hasattr(self, '__dtypes'):
import pandas as pd
self.__dtypes = {
'boolean': object,
'tinyint': pd.Int64Dtype(),
'smallint': pd.Int64Dtype(),
'integer': pd.Int64Dtype(),
'bigint': pd.Int64Dtype(),
'float': float,
'real': float,
'double': float,
'char': str,
'varchar': str,
'array': str,
'map': str,
'row': str,
}
return self.__dtypes

def convert(self, type_, value):
raise NotImplemented # pragma: nocover
23 changes: 1 addition & 22 deletions pyathena/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,31 +330,10 @@ def _parse_output_location(cls, output_location):
else:
raise DataError('Unknown `output_location` format.')

@property
def _dtypes(self):
if not hasattr(self, '__dtypes'):
import pandas as pd
self.__dtypes = {
'boolean': bool,
'tinyint': pd.Int64Dtype(),
'smallint': pd.Int64Dtype(),
'integer': pd.Int64Dtype(),
'bigint': pd.Int64Dtype(),
'float': float,
'real': float,
'double': float,
'char': str,
'varchar': str,
'array': str,
'map': str,
'row': str,
}
return self.__dtypes

@property
def dtypes(self):
return {
d[0]: self._dtypes[d[1]] for d in self.description if d[1] in self._dtypes
d[0]: self._converter.mappings[d[1]] for d in self.description if d[1] in self._converter.mappings
}

@property
Expand Down
3 changes: 3 additions & 0 deletions scripts/test_data/boolean_na_values.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
true false
false

1 change: 1 addition & 0 deletions scripts/test_data/delete_test_data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ aws s3 rm ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/one_row/one_row.tsv
aws s3 rm ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/one_row_complex/one_row_complex.tsv
aws s3 rm ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/many_rows/many_rows.tsv
aws s3 rm ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/integer_na_values/integer_na_values.tsv
aws s3 rm ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/boolean_na_values/boolean_na_values.tsv
1 change: 1 addition & 0 deletions scripts/test_data/upload_test_data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ aws s3 cp $(dirname $0)/one_row.tsv ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/on
aws s3 cp $(dirname $0)/one_row_complex.tsv ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/one_row_complex/one_row_complex.tsv
aws s3 cp $(dirname $0)/many_rows.tsv ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/many_rows/many_rows.tsv
aws s3 cp $(dirname $0)/integer_na_values.tsv ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/integer_na_values/integer_na_values.tsv
aws s3 cp $(dirname $0)/boolean_na_values.tsv ${AWS_ATHENA_S3_STAGING_DIR}test_pyathena/boolean_na_values/boolean_na_values.tsv
5 changes: 4 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,14 @@ def _create_table(cursor):
ENV.s3_staging_dir, S3_PREFIX, 'partition_table')
location_integer_na_values = '{0}{1}/{2}/'.format(
ENV.s3_staging_dir, S3_PREFIX, 'integer_na_values')
location_boolean_na_values = '{0}{1}/{2}/'.format(
ENV.s3_staging_dir, S3_PREFIX, 'boolean_na_values')
for q in read_query(
os.path.join(BASE_PATH, 'sql', 'create_table.sql')):
cursor.execute(q.format(schema=SCHEMA,
location_one_row=location_one_row,
location_many_rows=location_many_rows,
location_one_row_complex=location_one_row_complex,
location_partition_table=location_partition_table,
location_integer_na_values=location_integer_na_values))
location_integer_na_values=location_integer_na_values,
location_boolean_na_values=location_boolean_na_values))
10 changes: 9 additions & 1 deletion tests/sql/create_table.sql
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,12 @@ CREATE EXTERNAL TABLE IF NOT EXISTS {schema}.integer_na_values (
b INT
)
ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\n' STORED AS TEXTFILE
LOCATION '{location_integer_na_values}'
LOCATION '{location_integer_na_values}';

DROP TABLE IF EXISTS {schema}.boolean_na_values;
CREATE EXTERNAL TABLE IF NOT EXISTS {schema}.boolean_na_values (
a BOOLEAN,
b BOOLEAN
)
ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\n' STORED AS TEXTFILE
LOCATION '{location_boolean_na_values}';
17 changes: 16 additions & 1 deletion tests/test_pandas_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def test_complex_as_pandas(self, cursor):
df['col_decimal'].dtype.type,
])
self.assertEqual(dtypes, tuple([
np.bool_,
bool,
np.int64,
np.int64,
np.int64,
Expand Down Expand Up @@ -342,3 +342,18 @@ def test_integer_na_values(self, cursor):
(1, np.nan),
(np.nan, np.nan),
])

@with_pandas_cursor
def test_boolean_na_values(self, cursor):
df = cursor.execute("""
SELECT * FROM boolean_na_values
""").as_pandas()
rows = [tuple([
row['a'],
row['b'],
]) for _, row in df.iterrows()]
self.assertEqual(rows, [
(True, False),
(False, np.nan),
(np.nan, np.nan),
])

0 comments on commit cea7d46

Please sign in to comment.