Skip to content

Commit

Permalink
Merge a13825f into b15027a
Browse files Browse the repository at this point in the history
  • Loading branch information
dbartenstein committed Aug 19, 2021
2 parents b15027a + a13825f commit babbf69
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 12 deletions.
34 changes: 32 additions & 2 deletions cachalot/tests/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
OperationalError)
from django.db.models import Case, Count, Q, Value, When
from django.db.models.expressions import RawSQL, Subquery, OuterRef, Exists
from django.db.models.functions import Now
from django.db.models.functions import Coalesce, Now
from django.db.transaction import TransactionManagementError
from django.test import TransactionTestCase, skipUnlessDBFeature, override_settings
from pytz import UTC
Expand Down Expand Up @@ -372,7 +372,7 @@ def test_annotate_subquery(self):
self.assert_tables(qs, User, Test)
self.assert_query_cached(qs, [self.user, self.admin])

def test_annotate_case_with_when(self):
def test_annotate_case_with_when_and_query_in_default(self):
tests = Test.objects.filter(owner=OuterRef('pk')).values('name')
qs = User.objects.annotate(
first_test=Case(
Expand All @@ -383,6 +383,36 @@ def test_annotate_case_with_when(self):
self.assert_tables(qs, User, Test)
self.assert_query_cached(qs, [self.user, self.admin])

def test_annotate_case_with_when(self):
tests = Test.objects.filter(owner=OuterRef('pk')).values('name')
qs = User.objects.annotate(
first_test=Case(
When(Q(pk=1), then=Subquery(tests[:1])),
default=Value('noname')
)
)
self.assert_tables(qs, User, Test)
self.assert_query_cached(qs, [self.user, self.admin])

def test_annotate_coalesce(self):
tests = Test.objects.filter(owner=OuterRef('pk')).values('name')
qs = User.objects.annotate(
name=Coalesce(
Subquery(tests[:1]),
Value('notest')
)
)
self.assert_tables(qs, User, Test)
self.assert_query_cached(qs, [self.user, self.admin])

def test_annotate_raw(self):
qs = User.objects.annotate(
perm_id=RawSQL('SELECT id FROM auth_permission WHERE id = %s',
(self.t1__permission.pk,))
)
self.assert_tables(qs, User, Permission)
self.assert_query_cached(qs, [self.user, self.admin])

def test_only(self):
with self.assertNumQueries(1):
t1 = Test.objects.only('name').first()
Expand Down
35 changes: 25 additions & 10 deletions cachalot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from django.contrib.postgres.functions import TransactionNow
from django.db import connections
from django.db.models import Case, Exists, QuerySet, Subquery
from django.db.models import Exists, QuerySet, Subquery
from django.db.models.expressions import BaseExpression, RawSQL
from django.db.models.functions import Now
from django.db.models.sql import Query, AggregateQuery
from django.db.models.sql.where import ExtraWhere, WhereNode, NothingNode
Expand Down Expand Up @@ -178,18 +179,32 @@ def __update_annotated_subquery(_annotation: Subquery):
else:
tables.update(_get_tables(db_alias, _annotation.query))

def flatten(expression: BaseExpression):
"""
Recursively yield this expression and all subexpressions, in
depth-first order.
Taken from Django 3.2 as the previous Django versions don’t check
for existence of flatten.
"""
yield expression
for expr in element.get_source_expressions():
if expr:
if hasattr(expr, 'flatten'):
yield from flatten(expr)
else:
yield expr

# Gets tables in subquery annotations.
for annotation in query.annotations.values():
if isinstance(annotation, Case):
for case in annotation.cases:
for subquery in _find_subqueries_in_where(case.condition.children):
tables.update(_get_tables(db_alias, subquery))
if isinstance(annotation.default, Subquery):
__update_annotated_subquery(annotation.default)
elif isinstance(annotation, Subquery):
__update_annotated_subquery(annotation)
elif type(annotation) in UNCACHABLE_FUNCS:
if type(annotation) in UNCACHABLE_FUNCS:
raise UncachableQuery
for element in flatten(annotation):
if isinstance(element, Subquery):
__update_annotated_subquery(element)
elif isinstance(element, RawSQL):
sql = repr(element).lower()
tables.update(_get_tables_from_sql(connections[db_alias], sql))
# Gets tables in WHERE subqueries.
for subquery in _find_subqueries_in_where(query.where.children):
tables.update(_get_tables(db_alias, subquery))
Expand Down

0 comments on commit babbf69

Please sign in to comment.