Skip to content

Commit

Permalink
Fixed #30446 -- Resolved Value.output_field for stdlib types.
Browse files Browse the repository at this point in the history
This required implementing a limited form of dynamic dispatch to combine
expressions with numerical output. Refs #26355 should eventually provide
a better interface for that.
  • Loading branch information
charettes authored and felixxm committed Jul 15, 2020
1 parent d08e6f5 commit 1e38f11
Show file tree
Hide file tree
Showing 10 changed files with 122 additions and 39 deletions.
11 changes: 7 additions & 4 deletions django/contrib/gis/db/models/functions.py
Expand Up @@ -101,10 +101,13 @@ class SQLiteDecimalToFloatMixin:
is not acceptable by the GIS functions expecting numeric values.
"""
def as_sqlite(self, compiler, connection, **extra_context):
for expr in self.get_source_expressions():
if hasattr(expr, 'value') and isinstance(expr.value, Decimal):
expr.value = float(expr.value)
return super().as_sql(compiler, connection, **extra_context)
copy = self.copy()
copy.set_source_expressions([
Value(float(expr.value)) if hasattr(expr, 'value') and isinstance(expr.value, Decimal)
else expr
for expr in copy.get_source_expressions()
])
return copy.as_sql(compiler, connection, **extra_context)


class OracleToleranceMixin:
Expand Down
3 changes: 1 addition & 2 deletions django/contrib/postgres/fields/ranges.py
Expand Up @@ -173,8 +173,7 @@ class DateTimeRangeContains(PostgresOperatorLookup):
def process_rhs(self, compiler, connection):
# Transform rhs value for db lookup.
if isinstance(self.rhs, datetime.date):
output_field = models.DateTimeField() if isinstance(self.rhs, datetime.datetime) else models.DateField()
value = models.Value(self.rhs, output_field=output_field)
value = models.Value(self.rhs)
self.rhs = value.resolve_expression(compiler.query)
return super().process_rhs(compiler, connection)

Expand Down
66 changes: 59 additions & 7 deletions django/db/models/expressions.py
@@ -1,7 +1,9 @@
import copy
import datetime
import functools
import inspect
from decimal import Decimal
from uuid import UUID

from django.core.exceptions import EmptyResultSet, FieldError
from django.db import NotSupportedError, connection
Expand Down Expand Up @@ -56,12 +58,7 @@ class Combinable:
def _combine(self, other, connector, reversed):
if not hasattr(other, 'resolve_expression'):
# everything must be resolvable to an expression
output_field = (
fields.DurationField()
if isinstance(other, datetime.timedelta) else
None
)
other = Value(other, output_field=output_field)
other = Value(other)

if reversed:
return CombinedExpression(other, connector, self)
Expand Down Expand Up @@ -422,6 +419,25 @@ class Expression(BaseExpression, Combinable):
pass


_connector_combinators = {
connector: [
(fields.IntegerField, fields.DecimalField, fields.DecimalField),
(fields.DecimalField, fields.IntegerField, fields.DecimalField),
(fields.IntegerField, fields.FloatField, fields.FloatField),
(fields.FloatField, fields.IntegerField, fields.FloatField),
]
for connector in (Combinable.ADD, Combinable.SUB, Combinable.MUL, Combinable.DIV)
}


@functools.lru_cache(maxsize=128)
def _resolve_combined_type(connector, lhs_type, rhs_type):
combinators = _connector_combinators.get(connector, ())
for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
if issubclass(lhs_type, combinator_lhs_type) and issubclass(rhs_type, combinator_rhs_type):
return combined_type


class CombinedExpression(SQLiteNumericMixin, Expression):

def __init__(self, lhs, connector, rhs, output_field=None):
Expand All @@ -442,6 +458,19 @@ def get_source_expressions(self):
def set_source_expressions(self, exprs):
self.lhs, self.rhs = exprs

def _resolve_output_field(self):
try:
return super()._resolve_output_field()
except FieldError:
combined_type = _resolve_combined_type(
self.connector,
type(self.lhs.output_field),
type(self.rhs.output_field),
)
if combined_type is None:
raise
return combined_type()

def as_sql(self, compiler, connection):
expressions = []
expression_params = []
Expand Down Expand Up @@ -721,6 +750,30 @@ def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize
def get_group_by_cols(self, alias=None):
return []

def _resolve_output_field(self):
if isinstance(self.value, str):
return fields.CharField()
if isinstance(self.value, bool):
return fields.BooleanField()
if isinstance(self.value, int):
return fields.IntegerField()
if isinstance(self.value, float):
return fields.FloatField()
if isinstance(self.value, datetime.datetime):
return fields.DateTimeField()
if isinstance(self.value, datetime.date):
return fields.DateField()
if isinstance(self.value, datetime.time):
return fields.TimeField()
if isinstance(self.value, datetime.timedelta):
return fields.DurationField()
if isinstance(self.value, Decimal):
return fields.DecimalField()
if isinstance(self.value, bytes):
return fields.BinaryField()
if isinstance(self.value, UUID):
return fields.UUIDField()


class RawSQL(Expression):
def __init__(self, sql, params, output_field=None):
Expand Down Expand Up @@ -1177,7 +1230,6 @@ def as_oracle(self, compiler, connection):
copy.expression = Case(
When(self.expression, then=True),
default=False,
output_field=fields.BooleanField(),
)
return copy.as_sql(compiler, connection)
return self.as_sql(compiler, connection)
Expand Down
4 changes: 2 additions & 2 deletions django/db/models/lookups.py
Expand Up @@ -6,7 +6,7 @@
from django.core.exceptions import EmptyResultSet
from django.db.models.expressions import Case, Exists, Func, Value, When
from django.db.models.fields import (
BooleanField, CharField, DateTimeField, Field, IntegerField, UUIDField,
CharField, DateTimeField, Field, IntegerField, UUIDField,
)
from django.db.models.query_utils import RegisterLookupMixin
from django.utils.datastructures import OrderedSet
Expand Down Expand Up @@ -123,7 +123,7 @@ def as_oracle(self, compiler, connection):
exprs = []
for expr in (self.lhs, self.rhs):
if isinstance(expr, Exists):
expr = Case(When(expr, then=True), default=False, output_field=BooleanField())
expr = Case(When(expr, then=True), default=False)
wrapped = True
exprs.append(expr)
lookup = type(self)(*exprs) if wrapped else self
Expand Down
10 changes: 9 additions & 1 deletion docs/ref/models/expressions.txt
Expand Up @@ -484,7 +484,15 @@ The ``output_field`` argument should be a model field instance, like
after it's retrieved from the database. Usually no arguments are needed when
instantiating the model field as any arguments relating to data validation
(``max_length``, ``max_digits``, etc.) will not be enforced on the expression's
output value.
output value. If no ``output_field`` is specified it will be tentatively
inferred from the :py:class:`type` of the provided ``value``, if possible. For
example, passing an instance of :py:class:`datetime.datetime` as ``value``
would default ``output_field`` to :class:`~django.db.models.DateTimeField`.

.. versionchanged:: 3.2

Support for inferring a default ``output_field`` from the type of ``value``
was added.

``ExpressionWrapper()`` expressions
-----------------------------------
Expand Down
9 changes: 9 additions & 0 deletions docs/releases/3.2.txt
Expand Up @@ -233,6 +233,15 @@ Models
* The ``of`` argument of :meth:`.QuerySet.select_for_update()` is now allowed
on MySQL 8.0.1+.

* :class:`Value() <django.db.models.Value>` expression now
automatically resolves its ``output_field`` to the appropriate
:class:`Field <django.db.models.Field>` subclass based on the type of
it's provided ``value`` for :py:class:`bool`, :py:class:`bytes`,
:py:class:`float`, :py:class:`int`, :py:class:`str`,
:py:class:`datetime.date`, :py:class:`datetime.datetime`,
:py:class:`datetime.time`, :py:class:`datetime.timedelta`,
:py:class:`decimal.Decimal`, and :py:class:`uuid.UUID` instances.

Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~

Expand Down
6 changes: 1 addition & 5 deletions tests/aggregation/tests.py
Expand Up @@ -848,10 +848,6 @@ def test_nonfield_annotation(self):
book = Book.objects.annotate(val=Max(2, output_field=IntegerField())).first()
self.assertEqual(book.val, 2)

def test_missing_output_field_raises_error(self):
with self.assertRaisesMessage(FieldError, 'Cannot resolve expression type, unknown output_field'):
Book.objects.annotate(val=Max(2)).first()

def test_annotation_expressions(self):
authors = Author.objects.annotate(combined_ages=Sum(F('age') + F('friends__age'))).order_by('name')
authors2 = Author.objects.annotate(combined_ages=Sum('age') + Sum('friends__age')).order_by('name')
Expand Down Expand Up @@ -893,7 +889,7 @@ def test_order_of_precedence(self):

def test_combine_different_types(self):
msg = (
'Expression contains mixed types: FloatField, IntegerField. '
'Expression contains mixed types: FloatField, DecimalField. '
'You must set output_field.'
)
qs = Book.objects.annotate(sums=Sum('rating') + Sum('pages') + Sum('price'))
Expand Down
2 changes: 1 addition & 1 deletion tests/aggregation_regress/tests.py
Expand Up @@ -388,7 +388,7 @@ def test_sliced_conditional_aggregate(self):
)

def test_annotated_conditional_aggregate(self):
annotated_qs = Book.objects.annotate(discount_price=F('price') * 0.75)
annotated_qs = Book.objects.annotate(discount_price=F('price') * Decimal('0.75'))
self.assertAlmostEqual(
annotated_qs.aggregate(test=Avg(Case(
When(pages__lt=400, then='discount_price'),
Expand Down
38 changes: 33 additions & 5 deletions tests/expressions/tests.py
Expand Up @@ -3,15 +3,17 @@
import unittest
import uuid
from copy import deepcopy
from decimal import Decimal
from unittest import mock

from django.core.exceptions import FieldError
from django.db import DatabaseError, NotSupportedError, connection
from django.db.models import (
Avg, BooleanField, Case, CharField, Count, DateField, DateTimeField,
DurationField, Exists, Expression, ExpressionList, ExpressionWrapper, F,
Func, IntegerField, Max, Min, Model, OrderBy, OuterRef, Q, StdDev,
Subquery, Sum, TimeField, UUIDField, Value, Variance, When,
Avg, BinaryField, BooleanField, Case, CharField, Count, DateField,
DateTimeField, DecimalField, DurationField, Exists, Expression,
ExpressionList, ExpressionWrapper, F, FloatField, Func, IntegerField, Max,
Min, Model, OrderBy, OuterRef, Q, StdDev, Subquery, Sum, TimeField,
UUIDField, Value, Variance, When,
)
from django.db.models.expressions import Col, Combinable, Random, RawSQL, Ref
from django.db.models.functions import (
Expand Down Expand Up @@ -1711,6 +1713,30 @@ def test_compile_unresolved(self):
value = Value('foo', output_field=CharField())
self.assertEqual(value.as_sql(compiler, connection), ('%s', ['foo']))

def test_resolve_output_field(self):
value_types = [
('str', CharField),
(True, BooleanField),
(42, IntegerField),
(3.14, FloatField),
(datetime.date(2019, 5, 15), DateField),
(datetime.datetime(2019, 5, 15), DateTimeField),
(datetime.time(3, 16), TimeField),
(datetime.timedelta(1), DurationField),
(Decimal('3.14'), DecimalField),
(b'', BinaryField),
(uuid.uuid4(), UUIDField),
]
for value, ouput_field_type in value_types:
with self.subTest(type=type(value)):
expr = Value(value)
self.assertIsInstance(expr.output_field, ouput_field_type)

def test_resolve_output_field_failure(self):
msg = 'Cannot resolve expression type, unknown output_field'
with self.assertRaisesMessage(FieldError, msg):
Value(object()).output_field


class FieldTransformTests(TestCase):

Expand Down Expand Up @@ -1848,7 +1874,9 @@ def test_empty_group_by(self):
self.assertEqual(expr.get_group_by_cols(alias=None), [])

def test_non_empty_group_by(self):
expr = ExpressionWrapper(Lower(Value('f')), output_field=IntegerField())
value = Value('f')
value.output_field = None
expr = ExpressionWrapper(Lower(value), output_field=IntegerField())
group_by_cols = expr.get_group_by_cols(alias=None)
self.assertEqual(group_by_cols, [expr.expression])
self.assertEqual(group_by_cols[0].output_field, expr.output_field)
12 changes: 0 additions & 12 deletions tests/ordering/tests.py
@@ -1,7 +1,6 @@
from datetime import datetime
from operator import attrgetter

from django.core.exceptions import FieldError
from django.db.models import (
CharField, DateTimeField, F, Max, OuterRef, Subquery, Value,
)
Expand Down Expand Up @@ -439,17 +438,6 @@ def test_order_by_constant_value(self):
qs = Article.objects.order_by(Value('1', output_field=CharField()), '-headline')
self.assertSequenceEqual(qs, [self.a4, self.a3, self.a2, self.a1])

def test_order_by_constant_value_without_output_field(self):
msg = 'Cannot resolve expression type, unknown output_field'
qs = Article.objects.annotate(constant=Value('1')).order_by('constant')
for ordered_qs in (
qs,
qs.values('headline'),
Article.objects.order_by(Value('1')),
):
with self.subTest(ordered_qs=ordered_qs), self.assertRaisesMessage(FieldError, msg):
ordered_qs.first()

def test_related_ordering_duplicate_table_reference(self):
"""
An ordering referencing a model with an ordering referencing a model
Expand Down

0 comments on commit 1e38f11

Please sign in to comment.