Skip to content

Commit

Permalink
Parameters substitution for SELECT queries
Browse files Browse the repository at this point in the history
  • Loading branch information
xzkostyan committed Oct 17, 2017
1 parent 70a68bd commit 622d9a5
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 12 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Changelog

## [Unreleased]
### Added
- Parameters substitution for SELECT queries.

### Fixed
- Columnar result returning from multiple blocks. Columns must be concatenated.

Expand Down
40 changes: 28 additions & 12 deletions src/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .connection import Connection
from .protocol import ServerPacketTypes
from .util.helpers import chunks
from .util.escape import escape_params


class QueryResult(object):
Expand Down Expand Up @@ -158,7 +159,7 @@ def execute(self, query, params=None, with_column_types=False,
self.connection.force_connect()

try:
is_insert = params is not None
is_insert = isinstance(params, (list, tuple))
if is_insert:
return self.process_insert_query(
query, params, external_tables=external_tables,
Expand All @@ -167,7 +168,7 @@ def execute(self, query, params=None, with_column_types=False,
)
else:
return self.process_ordinary_query(
query, with_column_types=with_column_types,
query, params=params, with_column_types=with_column_types,
external_tables=external_tables,
query_id=query_id, settings=settings,
types_check=types_check, columnar=columnar
Expand All @@ -177,14 +178,16 @@ def execute(self, query, params=None, with_column_types=False,
self.connection.disconnect()
raise

def execute_with_progress(self, query, with_column_types=False,
external_tables=None, query_id=None,
settings=None, types_check=False):
def execute_with_progress(
self, query, params=None, with_column_types=False,
external_tables=None, query_id=None, settings=None,
types_check=False):

self.connection.force_connect()

try:
return self.process_ordinary_query_with_progress(
query, with_column_types=with_column_types,
query, params=params, with_column_types=with_column_types,
external_tables=external_tables,
query_id=query_id, settings=settings, types_check=types_check
)
Expand All @@ -194,18 +197,27 @@ def execute_with_progress(self, query, with_column_types=False,
raise

def process_ordinary_query_with_progress(
self, query, with_column_types=False, external_tables=None,
query_id=None, settings=None, types_check=False, columnar=False):
self, query, params=None, with_column_types=False,
external_tables=None, query_id=None, settings=None,
types_check=False, columnar=False):

if params:
query = self.substitute_params(query, params)

self.connection.send_query(query, query_id=query_id, settings=settings)
self.connection.send_external_tables(external_tables,
types_check=types_check)
return self.receive_result(with_column_types=with_column_types,
progress=True, columnar=columnar)

def process_ordinary_query(self, query, with_column_types=False,
external_tables=None, query_id=None,
settings=None, types_check=False,
columnar=False):
def process_ordinary_query(
self, query, params=None, with_column_types=False,
external_tables=None, query_id=None, settings=None,
types_check=False, columnar=False):

if params:
query = self.substitute_params(query, params)

self.connection.send_query(query, query_id=query_id, settings=settings)
self.connection.send_external_tables(external_tables,
types_check=types_check)
Expand Down Expand Up @@ -255,3 +267,7 @@ def cancel(self, with_column_types=False):
self.connection.send_cancel()
# Client must still read until END_OF_STREAM packet.
return self.receive_result(with_column_types=with_column_types)

def substitute_params(self, query, params):
escaped = escape_params(params)
return query % escaped
51 changes: 51 additions & 0 deletions src/util/escape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from datetime import date, datetime

from enum import Enum

from .compat import text_type, string_types


escape_chars_map = {
"\b": "\\b",
"\f": "\\f",
"\r": "\\r",
"\n": "\\n",
"\t": "\\t",
"\0": "\\0",
"\a": "\\a",
"\v": "\\v",
"\\": "\\\\",
"'": "\\'"
}


def escape_param(item):
if item is None:
return 'NULL'

elif isinstance(item, datetime):
return "'%s'" % item.strftime('%Y-%m-%d %H:%M:%S')

elif isinstance(item, date):
return "'%s'" % item.strftime('%Y-%m-%d')

elif isinstance(item, string_types):
return "'%s'" % ''.join(escape_chars_map.get(c, c) for c in item)

elif isinstance(item, (list, tuple)):
return "[%s]" % ', '.join(text_type(escape_param(x)) for x in item)

elif isinstance(item, Enum):
return item.value

else:
return item


def escape_params(params):
escaped = {}

for key, value in params.items():
escaped[key] = escape_param(value)

return escaped
148 changes: 148 additions & 0 deletions tests/test_substitution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# coding=utf-8
from __future__ import unicode_literals

from datetime import date, datetime
from decimal import Decimal
from enum import Enum

from tests.testcase import BaseTestCase


class ParametersSubstitutionTestCase(BaseTestCase):
single_tpl = 'SELECT %(x)s'
double_tpl = 'SELECT %(x)s, %(y)s'

def assert_subst(self, tpl, params, sql):
self.assertEqual(self.client.substitute_params(tpl, params), sql)

def test_int(self):
params = {'x': 123}

self.assert_subst(self.single_tpl, params, 'SELECT 123')

rv = self.client.execute(self.single_tpl, params)
self.assertEqual(rv, [(123, )])

def test_null(self):
params = {'x': None}

self.assert_subst(self.single_tpl, params, 'SELECT NULL')

rv = self.client.execute(self.single_tpl, params)
self.assertEqual(rv, [(None, )])

def test_date(self):
d = date(2017, 10, 16)
params = {'x': d}

self.assert_subst(self.single_tpl, params, "SELECT '2017-10-16'")

rv = self.client.execute(self.single_tpl, params)
self.assertEqual(rv, [('2017-10-16', )])

tpl = 'SELECT CAST(%(x)s AS Date)'
self.assert_subst(tpl, params, "SELECT CAST('2017-10-16' AS Date)")

rv = self.client.execute(tpl, params)
self.assertEqual(rv, [(d, )])

def test_datetime(self):
dt = datetime(2017, 10, 16, 0, 18, 50)
params = {'x': dt}

self.assert_subst(self.single_tpl, params,
"SELECT '2017-10-16 00:18:50'")

rv = self.client.execute(self.single_tpl, params)
self.assertEqual(rv, [('2017-10-16 00:18:50', )])

tpl = 'SELECT CAST(%(x)s AS DateTime)'
self.assert_subst(tpl, params,
"SELECT CAST('2017-10-16 00:18:50' AS DateTime)")

rv = self.client.execute(tpl, params)
self.assertEqual(rv, [(dt, )])

def test_string(self):
params = {'x': 'test\t\n\x16', 'y': 'тест\t\n\x16'}

self.assert_subst(self.double_tpl, params,
"SELECT 'test\\t\\n\x16', 'тест\\t\\n\x16'")

rv = self.client.execute(self.double_tpl, params)
self.assertEqual(rv, [('test\t\n\x16', 'тест\t\n\x16')])

params = {'x': "'"}

self.assert_subst(self.single_tpl, params, "SELECT '\\''")

rv = self.client.execute(self.single_tpl, params)
self.assertEqual(rv, [("'", )])

params = {'x': "\\"}

self.assert_subst(self.single_tpl, params, "SELECT '\\\\'")

rv = self.client.execute(self.single_tpl, params)
self.assertEqual(rv, [("\\", )])

def test_array(self):
params = {'x': [1, None, 2]}

self.assert_subst(self.single_tpl, params, 'SELECT [1, NULL, 2]')

rv = self.client.execute(self.single_tpl, params)
self.assertEqual(rv, [((1, None, 2), )])

params = {'x': [[1, 2, 3], [4, 5], [6, 7]]}

self.assert_subst(self.single_tpl, params,
'SELECT [[1, 2, 3], [4, 5], [6, 7]]')

rv = self.client.execute(self.single_tpl, params)
self.assertEqual(rv, [(((1, 2, 3), (4, 5), (6, 7)), )])

def test_tuple(self):
params = {'x': (1, None, 2)}

self.assert_subst(self.single_tpl, params, 'SELECT [1, NULL, 2]')

rv = self.client.execute(self.single_tpl, params)
self.assertEqual(rv, [((1, None, 2), )])

params = {'x': ((1, 2, 3), (4, 5), (6, 7))}

self.assert_subst(self.single_tpl, params,
'SELECT [[1, 2, 3], [4, 5], [6, 7]]')

rv = self.client.execute(self.single_tpl, params)
self.assertEqual(rv, [(((1, 2, 3), (4, 5), (6, 7)), )])

def test_enum(self):

class A(Enum):
hello = -1
world = 2

params = {'x': A.hello, 'y': A.world}

self.assert_subst(self.double_tpl, params, 'SELECT -1, 2')

rv = self.client.execute(self.double_tpl, params)
self.assertEqual(rv, [(-1, 2)])

def test_float(self):
params = {'x': 1e-12, 'y': 123.45}

self.assert_subst(self.double_tpl, params, 'SELECT 1e-12, 123.45')

rv = self.client.execute(self.double_tpl, params)
self.assertEqual(rv, [(params['x'], params['y'])])

def test_decimal(self):
params = {'x': Decimal('1e-2'), 'y': Decimal('123.45')}

self.assert_subst(self.double_tpl, params, 'SELECT 0.01, 123.45')

rv = self.client.execute(self.double_tpl, params)
self.assertEqual(rv, [(0.01, 123.45)])

0 comments on commit 622d9a5

Please sign in to comment.