Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Negation was not being cloned properly on Q() objects. Additionally I

decided to add parens around negated Q() expressions.
  • Loading branch information...
commit cdc3fbc7f69f04f0a9f64f8662fadb11198c31dd 1 parent 50728f2
@coleifer authored
Showing with 17 additions and 6 deletions.
  1. +6 −4 peewee.py
  2. +11 −2 tests.py
View
10 peewee.py
@@ -168,11 +168,11 @@ def __invert__(self):
class Q(Leaf):
def __init__(self, lhs, op, rhs, negated=False):
+ super(Q, self).__init__()
self.lhs = lhs
self.op = op
self.rhs = rhs
self.negated = negated
- super(Q, self).__init__()
def clone(self):
return Q(self.lhs, self.op, self.rhs, self.negated)
@@ -180,8 +180,8 @@ def clone(self):
class DQ(Leaf):
def __init__(self, **query):
- self.query = query
super(DQ, self).__init__()
+ self.query = query
def clone(self):
return DQ(**self.query)
@@ -707,8 +707,10 @@ def parse_expr(self, expr, alias_map=None):
def parse_q(self, q, alias_map=None):
lhs_expr, lparams = self.parse_expr(q.lhs, alias_map)
rhs_expr, rparams = self.parse_expr(q.rhs, alias_map)
- not_expr = q.negated and 'NOT ' or ''
- return '%s%s %s %s' % (not_expr, lhs_expr, self.get_op(q.op), rhs_expr), lparams + rparams
+ parsed = '%s %s %s' % (lhs_expr, self.get_op(q.op), rhs_expr)
+ if q.negated:
+ parsed = '(NOT %s)' % parsed
+ return parsed, lparams + rparams
def parse_node(self, n, alias_map=None):
query = []
View
13 tests.py
@@ -401,7 +401,7 @@ def test_where_fk(self):
def test_where_negation(self):
sq = SelectQuery(Blog).where(~(Blog.title == 'foo'))
- self.assertWhere(sq, 'NOT blog."title" = ?', ['foo'])
+ self.assertWhere(sq, '(NOT blog."title" = ?)', ['foo'])
sq = SelectQuery(Blog).where(~((Blog.title == 'foo') | (Blog.title == 'bar')))
self.assertWhere(sq, '(NOT (blog."title" = ? OR blog."title" = ?))', ['foo', 'bar'])
@@ -429,7 +429,7 @@ def test_where_chaining_collapsing(self):
self.assertWhere(sq, '(users."id" = ?) AND (users."id" = ? OR users."id" = ?)', [1, 2, 3])
sq = SelectQuery(User).where(~(User.id == 1)).where(User.id == 2).where(~(User.id == 3))
- self.assertWhere(sq, '(users."id" = ? AND users."id" = ?) AND NOT users."id" = ?', [1, 2, 3])
+ self.assertWhere(sq, '((NOT users."id" = ?) AND users."id" = ?) AND (NOT users."id" = ?)', [1, 2, 3])
def test_grouping(self):
sq = SelectQuery(User).group_by(User.id)
@@ -1130,6 +1130,15 @@ def assertNM(self, q, exp):
query = NullModel.select().where(q).order_by(NullModel.id)
self.assertEqual([nm.char_field for nm in query], exp)
+ def test_null_query(self):
+ NullModel.delete().execute()
+ nm1 = NullModel.create(char_field='nm1')
+ nm2 = NullModel.create(char_field='nm2', int_field=1)
+ nm3 = NullModel.create(char_field='nm3', int_field=2, float_field=3.0)
+
+ q = ~(NullModel.int_field >> None)
+ self.assertNM(q, ['nm2', 'nm3'])
+
def test_field_types(self):
for field, values in self.field_data.items():
field_obj = getattr(NullModel, field)

0 comments on commit cdc3fbc

Please sign in to comment.
Something went wrong with that request. Please try again.