Skip to content

Commit

Permalink
Support for Presto decimals (#430)
Browse files Browse the repository at this point in the history
* Support for Presto decimals

* lower
  • Loading branch information
serenajiang committed Mar 7, 2022
1 parent 8df7254 commit 3547bd6
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 9 deletions.
18 changes: 12 additions & 6 deletions pyhive/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from __future__ import unicode_literals

from builtins import object
from decimal import Decimal

from pyhive import common
from pyhive.common import DBAPITypeObject
# Make all exceptions visible in this module per DB-API
Expand All @@ -34,6 +36,11 @@

_logger = logging.getLogger(__name__)

TYPES_CONVERTER = {
"decimal": Decimal,
# As of Presto 0.69, binary data is returned as the varbinary type in base64 format
"varbinary": base64.b64decode
}

class PrestoParamEscaper(common.ParamEscaper):
def escape_datetime(self, item, format):
Expand Down Expand Up @@ -307,14 +314,13 @@ def _fetch_more(self):
"""Fetch the next URI and update state"""
self._process_response(self._requests_session.get(self._nextUri, **self._requests_kwargs))

def _decode_binary(self, rows):
# As of Presto 0.69, binary data is returned as the varbinary type in base64 format
# This function decodes base64 data in place
def _process_data(self, rows):
for i, col in enumerate(self.description):
if col[1] == 'varbinary':
col_type = col[1].split("(")[0].lower()
if col_type in TYPES_CONVERTER:
for row in rows:
if row[i] is not None:
row[i] = base64.b64decode(row[i])
row[i] = TYPES_CONVERTER[col_type](row[i])

def _process_response(self, response):
"""Given the JSON response from Presto's REST API, update the internal state with the next
Expand All @@ -341,7 +347,7 @@ def _process_response(self, response):
if 'data' in response_json:
assert self._columns
new_data = response_json['data']
self._decode_binary(new_data)
self._process_data(new_data)
self._data += map(tuple, new_data)
if 'nextUri' not in response_json:
self._state = self._STATE_FINISHED
Expand Down
4 changes: 3 additions & 1 deletion pyhive/tests/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import contextlib
import os
from decimal import Decimal

import requests

from pyhive import exc
Expand Down Expand Up @@ -93,7 +95,7 @@ def test_complex(self, cursor):
{"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON
[1, 2], # struct is returned as a list of elements
# '{0:1}',
'0.1',
Decimal('0.1'),
)]
self.assertEqual(rows, expected)
# catch unicode/str
Expand Down
4 changes: 3 additions & 1 deletion pyhive/tests/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import contextlib
import os
from decimal import Decimal

import requests

from pyhive import exc
Expand Down Expand Up @@ -89,7 +91,7 @@ def test_complex(self, cursor):
{"1": 2, "3": 4}, # Trino converts all keys to strings so that they're valid JSON
[1, 2], # struct is returned as a list of elements
# '{0:1}',
'0.1',
Decimal('0.1'),
)]
self.assertEqual(rows, expected)
# catch unicode/str
Expand Down
2 changes: 1 addition & 1 deletion pyhive/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _process_response(self, response):
if 'data' in response_json:
assert self._columns
new_data = response_json['data']
self._decode_binary(new_data)
self._process_data(new_data)
self._data += map(tuple, new_data)
if 'nextUri' not in response_json:
self._state = self._STATE_FINISHED
Expand Down

0 comments on commit 3547bd6

Please sign in to comment.