Permalink
Browse files

Fixed query argument encoding in py3k for bytestrings on non-latin1 connections
  • Loading branch information...
1 parent 0554168 commit 8940883b692b9b8c725ff18fd794e3ca42dc912d @johnsoft committed Feb 6, 2012
Showing with 43 additions and 45 deletions.
  1. +2 −2 pymysql/connections.py
  2. +29 −36 pymysql/converters.py
  3. +12 −7 pymysql/cursors.py
View
@@ -614,7 +614,7 @@ def close(self):
def autocommit(self, value):
''' Set whether or not to commit after every execute() '''
try:
- self._execute_command(COM_QUERY, "SET AUTOCOMMIT = %s" % \
+ self._execute_command(COM_QUERY, b"SET AUTOCOMMIT = " +
self.escape(value))
self.read_packet()
except:
@@ -709,7 +709,7 @@ def ping(self, reconnect=True):
def set_charset(self, charset):
try:
if charset:
- self._execute_command(COM_QUERY, "SET NAMES %s" %
+ self._execute_command(COM_QUERY, b"SET NAMES " +
self.escape(charset))
self.read_packet()
self.charset = charset
View
@@ -21,19 +21,8 @@
'\'': '\\\'', '"': '\\"', '\\': '\\\\'}
def escape_item(val, charset):
- if type(val) in [tuple, list, set]:
- return escape_sequence(val, charset)
- if type(val) is dict:
- return escape_dict(val, charset)
- if PYTHON3 and hasattr(val, "decode") and not isinstance(val, unicode):
- # deal with py3k bytes
- val = val.decode(charset)
encoder = encoders[type(val)]
- val = encoder(val)
- if type(val) is str:
- return val
- val = val.encode(charset)
- return val
+ return encoder(val, charset)
def escape_dict(val, charset):
n = {}
@@ -47,55 +36,58 @@ def escape_sequence(val, charset):
for item in val:
quoted = escape_item(item, charset)
n.append(quoted)
- return "(" + ",".join(n) + ")"
+ return b"(" + b",".join(n) + b")"
def escape_set(val, charset):
val = map(lambda x: escape_item(x, charset), val)
- return ','.join(val)
+ return b','.join(val)
-def escape_bool(value):
- return str(int(value))
+def escape_bool(value, charset):
+ return str(int(value)).encode(charset)
-def escape_object(value):
- return str(value)
+def escape_via_str(value, charset):
+ return str(value).encode(charset)
-escape_int = escape_long = escape_object
+escape_int = escape_long = escape_via_str
-def escape_float(value):
- return ('%.15g' % value)
+def escape_float(value, charset):
+ return escape_string('%.15g' % value, charset)
-def escape_string(value):
- return ("'%s'" % ESCAPE_REGEX.sub(
+def escape_bytes(value, charset):
+ return escape_string(value.decode('latin-1'), 'latin-1') # latin-1 is a 1:1 mapping from chars to bytes
+
+def escape_string(value, charset):
+ ret = ("'%s'" % ESCAPE_REGEX.sub(
lambda match: ESCAPE_MAP.get(match.group(0)), value))
+ return ret.encode(charset)
-def escape_unicode(value):
- return escape_string(value)
+escape_unicode = escape_string
def escape_None(value):
- return 'NULL'
+ return b'NULL'
-def escape_timedelta(obj):
+def escape_timedelta(obj, charset):
seconds = int(obj.seconds) % 60
minutes = int(obj.seconds // 60) % 60
hours = int(obj.seconds // 3600) % 24 + int(obj.days) * 24
- return escape_string('%02d:%02d:%02d' % (hours, minutes, seconds))
+ return escape_string('%02d:%02d:%02d' % (hours, minutes, seconds), charset)
-def escape_time(obj):
+def escape_time(obj, charset):
s = "%02d:%02d:%02d" % (int(obj.hour), int(obj.minute),
int(obj.second))
if obj.microsecond:
s += ".%f" % obj.microsecond
- return escape_string(s)
+ return escape_string(s, charset)
-def escape_datetime(obj):
- return escape_string(obj.strftime("%Y-%m-%d %H:%M:%S"))
+def escape_datetime(obj, charset):
+ return escape_string(obj.strftime("%Y-%m-%d %H:%M:%S"), charset)
-def escape_date(obj):
- return escape_string(obj.strftime("%Y-%m-%d"))
+def escape_date(obj, charset):
+ return escape_string(obj.strftime("%Y-%m-%d"), charset)
-def escape_struct_time(obj):
- return escape_datetime(datetime.datetime(*obj[:6]))
+def escape_struct_time(obj, charset):
+ return escape_datetime(datetime.datetime(*obj[:6]), charset)
def convert_datetime(connection, field, obj):
"""Returns a DATETIME or TIMESTAMP column value as a datetime object:
@@ -291,6 +283,7 @@ def convert_float(connection, field, data):
int: escape_int,
long: escape_long,
float: escape_float,
+ bytes: escape_bytes,
str: escape_string,
unicode: escape_unicode,
tuple: escape_sequence,
View
@@ -90,23 +90,28 @@ def execute(self, query, args=None):
charset = conn.charset
del self.messages[:]
- # TODO: make sure that conn.escape is correct
+ # Ideally we would use %-formatting on byte strings here, however py3k
+ # did away with this functionality. So instead, we represent the query
+ # and escaped args as strs, which contain not actual characters, but
+ # instead raw bytes encoded in the connection charset.
+
+ if isinstance(query, str) and charset != 'latin-1':
+ query = query.encode(charset).decode('latin-1')
if args is not None:
if isinstance(args, tuple) or isinstance(args, list):
- escaped_args = tuple(conn.escape(arg) for arg in args)
+ escaped_args = tuple(conn.escape(arg).decode('latin-1') for arg in args)
elif isinstance(args, dict):
- escaped_args = dict((key, conn.escape(val)) for (key, val) in args.items())
+ escaped_args = dict((key, conn.escape(val).decode('latin-1'))
+ for (key, val) in args.items())
else:
#If it's not a dictionary let's try escaping it anyways.
#Worst case it will throw a Value error
- escaped_args = conn.escape(args)
+ escaped_args = conn.escape(args).decode('latin-1')
query = query % escaped_args
- if isinstance(query, unicode):
- query = query.encode(charset)
-
+ query = query.encode('latin-1')
result = 0
try:
result = self._query(query)

0 comments on commit 8940883

Please sign in to comment.