Skip to content
This repository has been archived by the owner on Oct 13, 2023. It is now read-only.

Commit

Permalink
copy regexp from MySQL-python 1.2.4
Browse files Browse the repository at this point in the history
Fix #7
  • Loading branch information
hongqn committed Dec 6, 2013
1 parent ac04b8a commit d3973b3
Show file tree
Hide file tree
Showing 19 changed files with 76 additions and 26 deletions.
4 changes: 3 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,7 @@ language: python
python:
- "2.6"
- "2.7"
install: "pip install ."
install:
- "pip install ."
- "pip install -r requirements.txt"
script: nosetests
5 changes: 5 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
PyMySQL==0.5
gevent==1.0
greenlet==0.4.1
mock==1.0.1
nose==1.3.0
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@
"Topic :: Software Development :: Libraries :: Python Modules",
],
test_suite = 'nose.collector',
tests_require = ['nose'],
tests_require = ['nose', 'mock'],
)
17 changes: 17 additions & 0 deletions tests/test_cursors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from mock import patch, Mock
import umysqldb.cursors as M


def test_executemany_should_handle_on_duplicate_key_update_clause():
with patch.object(M.Cursor, '_query') as mock_query:
sql = 'INSERT INTO _test(id,val) VALUES(%s,%s) ' \
'ON DUPLICATE KEY UPDATE id=VALUES(id),val=VALUES(val)'
args = [(44495, 1), (44495, 2)]
mock_conn = Mock()
mock_conn.literal.side_effect = repr
cursor = M.Cursor(mock_conn)
cursor.executemany(sql, args)
mock_query.assert_called_once_with(
'INSERT INTO _test(id,val) VALUES\n(%s,%s),(%s,%s)\n '
'ON DUPLICATE KEY UPDATE id=VALUES(id),val=VALUES(val)',
(44495, 1, 44495, 2))
1 change: 1 addition & 0 deletions tests/test_umysqldb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import umysqldb
import umysqldb.err


@raises(umysqldb.err.OperationalError)
def test_access_denied_should_raise_OperationalError():
umysqldb.connect(host='127.0.0.1', user='asdf', passwd='fdsa')
1 change: 1 addition & 0 deletions umysqldb/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from pymysql.constants import *
72 changes: 48 additions & 24 deletions umysqldb/cursors.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,45 @@
import sys
import re
import pymysql.cursors

from .util import setdocstring

#Thank you MySQLdb for the kind regex
INSERT_VALUES = re.compile(
r"(?P<start>.+values\s*)"
r"(?P<values>\(((?<!\\)'[^\)]*?\)[^\)]*(?<!\\)?'|[^\(\)]|(?:\([^\)]*\)))+\))"
r"(?P<end>.*)",
re.I)
# Thank you MySQLdb for the kind regex
restr = r"""
\s
values
\s*
(
\(
[^()']*
(?:
(?:
(?:\(
# ( - editor hightlighting helper
[^)]*
\))
|
'
[^\\']*
(?:\\.[^\\']*)*
'
)
[^()']*
)*
\)
)
"""

insert_values = re.compile(restr, re.S | re.I | re.X)


def _flatten(alist):
result = []
map(result.extend, alist)
return tuple(result)


class Cursor(pymysql.cursors.Cursor):

@setdocstring(pymysql.cursors.Cursor.execute)
def execute(self, query, args=None):
if args is None:
Expand All @@ -39,23 +62,22 @@ def execute(self, query, args=None):
def executemany(self, query, args):
if not args:
return
conn = self._get_db()
charset = conn.charset
db = self._get_db()
charset = db.charset
if isinstance(query, unicode):
query = query.encode(charset)
matched = INSERT_VALUES.match(query)
if not matched:

m = insert_values.search(query)
if not m:
self.rowcount = sum([self.execute(query, arg) for arg in args])
return self.rowcount

#Speed up a bulk insert MySQLdb style
start = matched.group('start')
values = matched.group('values')
end = matched.group('end')

sql_params = (values for i in range(len(args)))
multirow_query = '\n'.join([start, ','.join(sql_params), end])
# Speed up a bulk insert MySQLdb style
p = m.start(1)
e = m.end(1)
qv = m.group(1)
sql_params = (qv for i in range(len(args)))
multirow_query = '\n'.join([query[:p], ','.join(sql_params), query[e:]])
return self.execute(multirow_query, _flatten(args))

def _query(self, query, args=()):
Expand All @@ -71,12 +93,13 @@ def _query(self, query, args=()):


class DictCursor(Cursor):

"""A cursor which returns results as a dictionary"""

def execute(self, query, args=None):
result = super(DictCursor, self).execute(query, args)
if self.description:
self._fields = [ field[0] for field in self.description ]
self._fields = [field[0] for field in self.description]
return result

def fetchone(self):
Expand All @@ -94,7 +117,8 @@ def fetchmany(self, size=None):
if self._rows is None:
return None
end = self.rownumber + (size or self.arraysize)
result = [ dict(zip(self._fields, r)) for r in self._rows[self.rownumber:end] ]
result = [dict(zip(self._fields, r))
for r in self._rows[self.rownumber:end]]
self.rownumber = min(end, len(self._rows))
return tuple(result)

Expand All @@ -104,9 +128,9 @@ def fetchall(self):
if self._rows is None:
return None
if self.rownumber:
result = [ dict(zip(self._fields, r)) for r in self._rows[self.rownumber:] ]
result = [dict(zip(self._fields, r))
for r in self._rows[self.rownumber:]]
else:
result = [ dict(zip(self._fields, r)) for r in self._rows ]
result = [dict(zip(self._fields, r)) for r in self._rows]
self.rownumber = len(self._rows)
return tuple(result)

0 comments on commit d3973b3

Please sign in to comment.