Skip to content
This repository has been archived by the owner on Oct 13, 2023. It is now read-only.

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
hongqn committed Feb 24, 2012
1 parent 76706aa commit 874fb74
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 14 deletions.
7 changes: 0 additions & 7 deletions oursql/connections.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -47,13 +47,6 @@ def _connect(self):
self.password, self.db, False, self.charset) self.password, self.db, False, self.charset)


def query(self, sql, args): def query(self, sql, args):
result = self._query(sql, args)
if isinstance(result, tuple):
return result[0] # affected rows
else:
return result

def _query(self, sql, args):
try: try:
return self._umysql_conn.query(sql, args) return self._umysql_conn.query(sql, args)
except umysql.Error, exc: except umysql.Error, exc:
Expand Down
9 changes: 7 additions & 2 deletions oursql/cursors.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
from .utils import setdocstring from .utils import setdocstring


class Cursor(pymysql.cursors.Cursor): class Cursor(pymysql.cursors.Cursor):
setdocstring(pymysql.cursors.Cursor) @setdocstring(pymysql.cursors.Cursor)
def execute(self, query, args=()): def execute(self, query, args=None):
conn = self._get_db() conn = self._get_db()

if args is None:
args = ()
elif not isinstance(args, (tuple, list, dict)):
args = (args,)
return conn.query(query, args) return conn.query(query, args)
10 changes: 5 additions & 5 deletions tests/test_oursql_dbapi20.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
import oursql import oursql


def setup_module(): def setup_module():
@apply
def print_sqls(): def print_sqls():
print "patch" print "patch"
import oursql.connections import oursql.connections
orig_query = oursql.connections.Connection._query orig_query = oursql.connections.Connection.query
def _query(self, *a, **kw): def query(self, *a, **kw):
print a, kw print "QUERY:", a, kw
return orig_query(self, *a, **kw) return orig_query(self, *a, **kw)
oursql.connections.Connection._query = _query oursql.connections.Connection.query = query
print_sqls()


class test_oursql(dbapi20.DatabaseAPI20Test): class test_oursql(dbapi20.DatabaseAPI20Test):
driver = oursql driver = oursql
Expand Down

0 comments on commit 874fb74

Please sign in to comment.