From e942272809c54dff4073f3598aa9fc230d3b83e1 Mon Sep 17 00:00:00 2001 From: Sebastien Corbin Date: Fri, 17 Apr 2020 15:47:31 +0200 Subject: [PATCH 1/2] Add test for OuterRef usage in CTE --- tests/test_cte.py | 51 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 49 insertions(+), 2 deletions(-) diff --git a/tests/test_cte.py b/tests/test_cte.py index 6310ae3..6cceb4d 100644 --- a/tests/test_cte.py +++ b/tests/test_cte.py @@ -5,8 +5,10 @@ from unittest import SkipTest from django.db.models import IntegerField, TextField -from django.db.models.aggregates import Count, Sum -from django.db.models.expressions import Exists, F, OuterRef, Subquery, Value +from django.db.models.aggregates import Count, Max, Min, Sum +from django.db.models.expressions import ( + Exists, ExpressionWrapper, F, OuterRef, Subquery, Value, +) from django.db.models.functions import Concat from django.test import TestCase @@ -323,3 +325,48 @@ def test_delete_cte_query(self): ('earth', 33), ('proxima centauri', 2000), ]) + + def test_outerref_in_cte_query(self): + # This query is meant to return the difference between min and max + # order of each region, through a subquery + min_and_max = With( + Order.objects + .filter(region=OuterRef("pk")) + .values('region') # This is to force group by region_id + .annotate( + amount_min=Min("amount"), + amount_max=Max("amount"), + ) + .values('amount_min', 'amount_max') + ) + regions = ( + Region.objects + .annotate( + difference=Subquery( + min_and_max.queryset().with_cte(min_and_max).annotate( + difference=ExpressionWrapper( + F('amount_max') - F('amount_min'), + output_field=int_field, + ), + ).values('difference')[:1], + output_field=IntegerField() + ) + ) + .order_by("name") + ) + print(regions.query) + + data = [(r.name, r.difference) for r in regions] + self.assertEqual(data, [ + ("bernard's star", None), + ('deimos', None), + ('earth', 3), + ('mars', 2), + ('mercury', 2), + ('moon', 2), + ('phobos', None), + ('proxima centauri', 0), + ('proxima centauri b', 2), + ('sun', 0), + ('venus', 3) + ]) From 31754cd5a8caa554c92ce262d85737dd40f41aa6 Mon Sep 17 00:00:00 2001 From: Daniel Miller Date: Fri, 24 Apr 2020 16:48:07 -0400 Subject: [PATCH 2/2] Fix #15 Resolve OuterRef for CTE subqueries --- django_cte/expressions.py | 50 +++++++++++++++++++++++++++++++++++++++ django_cte/query.py | 6 +++++ 2 files changed, 56 insertions(+) create mode 100644 django_cte/expressions.py diff --git a/django_cte/expressions.py b/django_cte/expressions.py new file mode 100644 index 0000000..64e20ec --- /dev/null +++ b/django_cte/expressions.py @@ -0,0 +1,50 @@ +import django +from django.db.models import Subquery + + +class CTESubqueryResolver(object): + + def __init__(self, annotation): + self.annotation = annotation + + def resolve_expression(self, *args, **kw): + # source: django.db.models.expressions.Subquery.resolve_expression + # --- begin copied code (lightly adapted) --- # + + # Need to recursively resolve these. + def resolve_all(child): + if hasattr(child, 'children'): + [resolve_all(_child) for _child in child.children] + if hasattr(child, 'rhs'): + child.rhs = resolve(child.rhs) + + def resolve(child): + if hasattr(child, 'resolve_expression'): + resolved = child.resolve_expression(*args, **kw) + # Add table alias to the parent query's aliases to prevent + # quoting. + if hasattr(resolved, 'alias') and \ + resolved.alias != resolved.target.model._meta.db_table: + get_query(clone).external_aliases.add(resolved.alias) + return resolved + return child + + # --- end copied code --- # + + if django.VERSION < (3, 0): + def get_query(clone): + return clone.queryset.query + else: + def get_query(clone): + return clone.query + + # NOTE this uses the old (pre-Django 3) way of resolving. + # Should a different technique should be used on Django 3+? + clone = self.annotation.resolve_expression(*args, **kw) + if isinstance(self.annotation, Subquery): + for cte in get_query(clone)._with_ctes: + resolve_all(cte.query.where) + for key, value in cte.query.annotations.items(): + if isinstance(value, Subquery): + cte.query.annotations[key] = resolve(value) + return clone diff --git a/django_cte/query.py b/django_cte/query.py index a8d7ff3..756bdab 100644 --- a/django_cte/query.py +++ b/django_cte/query.py @@ -10,6 +10,8 @@ SQLUpdateCompiler, ) +from .expressions import CTESubqueryResolver + class CTEQuery(Query): """A Query which processes SQL compilation through the CTE compiler""" @@ -42,6 +44,10 @@ def get_compiler(self, using=None, connection=None): klass = COMPILER_TYPES.get(self.__class__, CTEQueryCompiler) return klass(self, connection, using) + def add_annotation(self, annotation, *args, **kw): + annotation = CTESubqueryResolver(annotation) + super(CTEQuery, self).add_annotation(annotation, *args, **kw) + def __chain(self, _name, klass=None, *args, **kwargs): klass = QUERY_TYPES.get(klass, self.__class__) clone = getattr(super(CTEQuery, self), _name)(klass, *args, **kwargs)