Skip to content

Commit

Permalink
Refactored the empty/full result logic in WhereNode.as_sql()
Browse files Browse the repository at this point in the history
Made sure the WhereNode.as_sql() handles various EmptyResultSet and
FullResultSet conditions correctly. Also, got rid of the FullResultSet
exception class. It is now represented by '', [] return value in the
as_sql() methods.
  • Loading branch information
akaariai committed Jul 1, 2012
1 parent 2b9fb2e commit bd283aa
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 38 deletions.
3 changes: 0 additions & 3 deletions django/db/models/sql/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
class EmptyResultSet(Exception):
pass

class FullResultSet(Exception):
pass

class MultiJoin(Exception):
"""
Used by join construction code to indicate the point at which a
Expand Down
83 changes: 48 additions & 35 deletions django/db/models/sql/where.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from django.utils import tree
from django.db.models.fields import Field
from django.db.models.sql.datastructures import EmptyResultSet, FullResultSet
from django.db.models.sql.datastructures import EmptyResultSet
from django.db.models.sql.aggregates import Aggregate

# Connection types
Expand Down Expand Up @@ -75,57 +75,70 @@ def add(self, data, connector):
def as_sql(self, qn, connection):
"""
Returns the SQL version of the where clause and the value to be
substituted in. Returns None, None if this node is empty.
If 'node' is provided, that is the root of the SQL generation
(generally not needed except by the internal implementation for
recursion).
substituted in. Returns '', [] if this node matches everything,
None, [] if this node is empty, and raises EmptyResultSet if this
node can't match anything.
"""
if not self.children:
return None, []
# Note that the logic here is made slightly more complex than
# necessary because there are two kind of empty nodes: Nodes
# containing 0 children, and nodes that are known to match everything.
# A match-everything node is different than empty node (which also
# technically matches everything) for backwards compatibility reasons.
# Refs #5261.
result = []
result_params = []
empty = True
everything_childs, nothing_childs = 0, 0
non_empty_childs = len(self.children)

for child in self.children:
try:
if hasattr(child, 'as_sql'):
sql, params = child.as_sql(qn=qn, connection=connection)
else:
# A leaf node in the tree.
sql, params = self.make_atom(child, qn, connection)

except EmptyResultSet:
if self.connector == AND and not self.negated:
# We can bail out early in this particular case (only).
raise
elif self.negated:
empty = False
continue
except FullResultSet:
if self.connector == OR:
if self.negated:
empty = True
break
# We match everything. No need for any constraints.
nothing_childs += 1
else:
if sql:
result.append(sql)
result_params.extend(params)
else:
if sql is None:
# Skip empty childs totally.
non_empty_childs -= 1
continue
everything_childs += 1
# Check if this node matches nothing or everything.
# First check the amount of full nodes and empty nodes
# to make this node empty/full.
if self.connector == AND:
full_needed, empty_needed = non_empty_childs, 1
else:
full_needed, empty_needed = 1, non_empty_childs
# Now, check if this node is full/empty using the
# counts.
if empty_needed - nothing_childs <= 0:
if self.negated:
return '', []
else:
raise EmptyResultSet
if full_needed - everything_childs <= 0:
if self.negated:
empty = True
continue

empty = False
if sql:
result.append(sql)
result_params.extend(params)
if empty:
raise EmptyResultSet
raise EmptyResultSet
else:
return '', []

if non_empty_childs == 0:
# All the child nodes were empty, so this one is empty, too.
return None, []
conn = ' %s ' % self.connector
sql_string = conn.join(result)
if sql_string:
if self.negated:
sql_string = 'NOT (%s)' % sql_string
elif len(self.children) != 1:
if len(result) > 1:
sql_string = '(%s)' % sql_string
if self.negated:
sql_string = 'NOT %s' % sql_string
return sql_string, result_params

def make_atom(self, child, qn, connection):
Expand Down Expand Up @@ -261,7 +274,7 @@ class EverythingNode(object):
"""

def as_sql(self, qn=None, connection=None):
raise FullResultSet
return '', []

def relabel_aliases(self, change_map, node=None):
return
Expand Down
86 changes: 86 additions & 0 deletions tests/regressiontests/queries/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from django.db import DatabaseError, connection, connections, DEFAULT_DB_ALIAS
from django.db.models import Count
from django.db.models.query import Q, ITER_CHUNK_SIZE, EmptyQuerySet
from django.db.models.sql.where import WhereNode, EverythingNode, NothingNode
from django.db.models.sql.datastructures import EmptyResultSet
from django.test import TestCase, skipUnlessDBFeature
from django.test.utils import str_prefix
from django.utils import unittest
Expand Down Expand Up @@ -1316,10 +1318,23 @@ def test_ticket9848(self):
)

def test_ticket5261(self):
# Test different empty excludes.
self.assertQuerysetEqual(
Note.objects.exclude(Q()),
['<Note: n1>', '<Note: n2>']
)
self.assertQuerysetEqual(
Note.objects.filter(~Q()),
['<Note: n1>', '<Note: n2>']
)
self.assertQuerysetEqual(
Note.objects.filter(~Q()|~Q()),
['<Note: n1>', '<Note: n2>']
)
self.assertQuerysetEqual(
Note.objects.exclude(~Q()&~Q()),
['<Note: n1>', '<Note: n2>']
)


class SelectRelatedTests(TestCase):
Expand Down Expand Up @@ -2020,3 +2035,74 @@ def test_evaluated_proxy_count(self):
self.assertEqual(qs.count(), 1)
str(qs.query)
self.assertEqual(qs.count(), 1)

class WhereNodeTest(TestCase):
class DummyNode(object):
def as_sql(self, qn, connection):
return 'dummy', []

def test_empty_full_handling_conjunction(self):
qn = connection.ops.quote_name
w = WhereNode(children=[EverythingNode()])
self.assertEquals(w.as_sql(qn, connection), ('', []))
w.negate()
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w = WhereNode(children=[NothingNode()])
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w.negate()
self.assertEquals(w.as_sql(qn, connection), ('', []))
w = WhereNode(children=[EverythingNode(), EverythingNode()])
self.assertEquals(w.as_sql(qn, connection), ('', []))
w.negate()
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w = WhereNode(children=[EverythingNode(), self.DummyNode()])
self.assertEquals(w.as_sql(qn, connection), ('dummy', []))
w = WhereNode(children=[self.DummyNode(), self.DummyNode()])
self.assertEquals(w.as_sql(qn, connection), ('(dummy AND dummy)', []))
w.negate()
self.assertEquals(w.as_sql(qn, connection), ('NOT (dummy AND dummy)', []))
w = WhereNode(children=[NothingNode(), self.DummyNode()])
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w.negate()
self.assertEquals(w.as_sql(qn, connection), ('', []))

def test_empty_full_handling_disjunction(self):
qn = connection.ops.quote_name
w = WhereNode(children=[EverythingNode()], connector='OR')
self.assertEquals(w.as_sql(qn, connection), ('', []))
w.negate()
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w = WhereNode(children=[NothingNode()], connector='OR')
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w.negate()
self.assertEquals(w.as_sql(qn, connection), ('', []))
w = WhereNode(children=[EverythingNode(), EverythingNode()], connector='OR')
self.assertEquals(w.as_sql(qn, connection), ('', []))
w.negate()
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w = WhereNode(children=[EverythingNode(), self.DummyNode()], connector='OR')
self.assertEquals(w.as_sql(qn, connection), ('', []))
w.negate()
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w = WhereNode(children=[self.DummyNode(), self.DummyNode()], connector='OR')
self.assertEquals(w.as_sql(qn, connection), ('(dummy OR dummy)', []))
w.negate()
self.assertEquals(w.as_sql(qn, connection), ('NOT (dummy OR dummy)', []))
w = WhereNode(children=[NothingNode(), self.DummyNode()], connector='OR')
self.assertEquals(w.as_sql(qn, connection), ('dummy', []))
w.negate()
self.assertEquals(w.as_sql(qn, connection), ('NOT dummy', []))

def test_empty_nodes(self):
qn = connection.ops.quote_name
empty_w = WhereNode()
w = WhereNode(children=[empty_w, empty_w])
self.assertEquals(w.as_sql(qn, connection), (None, []))
w.negate()
self.assertEquals(w.as_sql(qn, connection), (None, []))
w.connector = 'OR'
self.assertEquals(w.as_sql(qn, connection), (None, []))
w.negate()
self.assertEquals(w.as_sql(qn, connection), (None, []))
w = WhereNode(children=[empty_w, NothingNode()], connector='OR')
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)

0 comments on commit bd283aa

Please sign in to comment.