Skip to content

Commit

Permalink
Resolved #30969 -- Added support for query expressions as default values
Browse files Browse the repository at this point in the history
  • Loading branch information
codingjoe committed Jun 6, 2021
1 parent ecf8af7 commit dac5fe2
Show file tree
Hide file tree
Showing 15 changed files with 307 additions and 24 deletions.
30 changes: 25 additions & 5 deletions django/db/backends/base/schema.py
Expand Up @@ -121,6 +121,15 @@ def __exit__(self, exc_type, exc_value, traceback):

# Core utility functions

def prepare_param(self, node):
sql, params = None, (node,)
if hasattr(node, 'as_sql'):
compiler = self.connection.ops.compiler('SQLCompiler')(
query=None, connection=self.connection, using=None
)
sql, params = compiler.compile(node)
return sql, tuple(params)

def execute(self, sql, params=()):
"""Execute the given SQL statement, with optional parameters."""
# Don't perform the transactional DDL check if SQL is being collected
Expand All @@ -132,17 +141,25 @@ def execute(self, sql, params=()):
)
# Account for non-string statement objects.
sql = str(sql)
prepared_params = None
if params is not None:
prepared_params = [self.prepare_param(p) for p in params]
params_sql = tuple(filter(None, (p[0] for p in prepared_params)))
if params_sql:
sql %= params_sql
prepared_params = tuple(q for p in prepared_params for q in p[1])

# Log the command we're running, then run it
logger.debug("%s; (params %r)", sql, params, extra={'params': params, 'sql': sql})
logger.debug("%s; (params %r)", sql, prepared_params, extra={'params': prepared_params, 'sql': sql})
if self.collect_sql:
ending = "" if sql.rstrip().endswith(";") else ";"
if params is not None:
self.collected_sql.append((sql % tuple(map(self.quote_value, params))) + ending)
if prepared_params is not None:
self.collected_sql.append((sql % tuple(map(self.quote_value, prepared_params))) + ending)
else:
self.collected_sql.append(sql + ending)
else:
with self.connection.cursor() as cursor:
cursor.execute(sql, params)
cursor.execute(sql, prepared_params)

def quote_name(self, name):
return self.connection.ops.quote_name(name)
Expand Down Expand Up @@ -314,7 +331,10 @@ def _effective_default(field):

def effective_default(self, field):
"""Return a field's effective database default value."""
return field.get_db_prep_save(self._effective_default(field), self.connection)
default = self._effective_default(field)
if hasattr(default, 'as_sql'):
return default
return field.get_db_prep_save(default, self.connection)

def quote_value(self, value):
"""
Expand Down
3 changes: 3 additions & 0 deletions django/db/backends/oracle/schema.py
Expand Up @@ -152,6 +152,9 @@ def _generate_temp_name(self, for_name):
return self.normalize_name(for_name + "_" + suffix)

def prepare_default(self, value):
if hasattr(value, 'as_sql'):
sql, params = self.prepare_param(value)
return sql % tuple(self.quote_value(p) for p in params)
return self.quote_value(value)

def _field_should_be_indexed(self, model, field):
Expand Down
2 changes: 1 addition & 1 deletion django/db/models/base.py
Expand Up @@ -872,7 +872,7 @@ def _save_table(self, raw=False, cls=None, force_insert=False,
results = self._do_insert(cls._base_manager, using, fields, returning_fields, raw)
if results:
for value, field in zip(results[0], returning_fields):
setattr(self, field.attname, value)
setattr(self, field.attname, field.to_python(value))
return updated

def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_update):
Expand Down
2 changes: 2 additions & 0 deletions django/db/models/expressions.py
Expand Up @@ -581,6 +581,8 @@ def as_sql(self, compiler, connection):
class F(Combinable):
"""An object capable of resolving references to existing query objects."""

contains_column_references = True

def __init__(self, name):
"""
Arguments:
Expand Down
14 changes: 13 additions & 1 deletion django/db/models/fields/__init__.py
Expand Up @@ -197,6 +197,7 @@ def check(self, **kwargs):
return [
*self._check_field_name(),
*self._check_choices(),
*self._check_default_expressions(),
*self._check_db_index(),
*self._check_null_allowed_for_primary_keys(),
*self._check_backend_specific_checks(**kwargs),
Expand Down Expand Up @@ -236,6 +237,15 @@ def _check_field_name(self):
else:
return []

def _check_default_expressions(self):
if hasattr(self.default, 'contains_column_references') and self.default.contains_column_references:
return [checks.Error(
"'default' expressions cannot reference other fields (e.g. not contain a Q, or F expression).",
obj=self,
id='fields.E011',
)]
return []

@classmethod
def _choices_is_value(cls, value):
return isinstance(value, (str, Promise)) or not is_iterable(value)
Expand Down Expand Up @@ -762,7 +772,7 @@ def db_returning(self):
Private API intended only to be used by Django itself. Currently only
the PostgreSQL backend supports returning multiple fields on a model.
"""
return False
return self.has_default() and hasattr(self.default, 'as_sql')

def set_attributes_from_name(self, name):
self.name = self.name or name
Expand Down Expand Up @@ -2420,6 +2430,8 @@ def get_prep_value(self, value):
def get_db_prep_value(self, value, connection, prepared=False):
if value is None:
return None
if hasattr(value, 'as_sql'):
return value
if not isinstance(value, uuid.UUID):
value = self.to_python(value)

Expand Down
1 change: 1 addition & 0 deletions django/db/models/query_utils.py
Expand Up @@ -36,6 +36,7 @@ class Q(tree.Node):
OR = 'OR'
default = AND
conditional = True
contains_column_references = True

def __init__(self, *args, _connector=None, _negated=False, **kwargs):
super().__init__(children=[*args, *sorted(kwargs.items())], connector=_connector, negated=_negated)
Expand Down
2 changes: 2 additions & 0 deletions docs/ref/checks.txt
Expand Up @@ -185,6 +185,8 @@ Model fields
``choices`` (``<count>`` characters).
* **fields.E010**: ``<field>`` default should be a callable instead of an
instance so that it's not shared between all field instances.
* **fields.E011**: ``<field>`` default expression may not reference database
columns (e.g. contain F or Q expressions).
* **fields.E100**: ``AutoField``\s must set primary_key=True.
* **fields.E110**: ``BooleanField``\s do not accept null values. *This check
appeared before support for null values was added in Django 2.1.*
Expand Down
22 changes: 22 additions & 0 deletions docs/ref/models/fields.txt
Expand Up @@ -368,6 +368,28 @@ The default value is used when new model instances are created and a value
isn't provided for the field. When the field is a primary key, the default is
also used when the field is set to ``None``.

.. versionadded:: 4.0

The MariaDB (10.5+), Oracle and PostgreSQL database backends support
:doc:`Query Expressions </ref/models/expressions>` as ``default`` values::

from django.db import models
from django.db.models.functions import ExtractYear, Now, Pi

class ReturningModel(models.Model)
created = models.DateTimeField(editable=False, default=Now())
pi = models.FloatField(default=Pi())
year = models.PositiveSmallIntegerField(default=ExtractYear(Now()))

.. note::
The query expressions are not stored in the database as column default
values, but expressions are passed directly to the database during inserts
and updates. Therefore, the defaults will not be available should you
perform any raw SQL ``INSERT`` or ``UPDATE``; you must explicitly provide
an equivalent SQL expression. You will not be able to use the ``DEFAULT` or
``DEFAULT VALUES`` keywords to populate columns for fields using query
expressions.

``editable``
------------

Expand Down
3 changes: 3 additions & 0 deletions docs/releases/4.0.txt
Expand Up @@ -248,6 +248,9 @@ Models

* :class:`~django.db.models.DurationField` now supports multiplying and
dividing by scalar values on SQLite.
* The MariaDB, Oracle and PostgreSQL backends support
:doc:`Query Expressions </ref/models/expressions>` as
:attr:`~django.db.models.Field.default` values for model fields.

Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~
Expand Down
30 changes: 30 additions & 0 deletions tests/invalid_models_tests/test_ordinary_fields.py
Expand Up @@ -3,6 +3,7 @@

from django.core.checks import Error, Warning as DjangoWarning
from django.db import connection, models
from django.db.models.functions import Now
from django.test import (
SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature,
)
Expand Down Expand Up @@ -334,6 +335,35 @@ class Model(models.Model):
),
])

def test_default_expressions(self):
class Model(models.Model):
created = models.DateField(default=Now())

field = Model._meta.get_field('created')
self.assertEqual(field.check(), [])

def test_default_expressions_failure(self):
class Model(models.Model):
q = models.BooleanField(default=models.Q(pk__gte=100))
wrapped_q = models.BooleanField(
default=models.ExpressionWrapper(models.Q(pk__gte=100), output_field=models.BooleanField())
)
max_pk = models.BooleanField(
default=models.Max('pk')
)

for name in ('q', 'wrapped_q', 'max_pk'):
with self.subTest(name):
field = Model._meta.get_field(name)
self.assertEqual(field.check(), [
Error(
"'default' expressions cannot reference other fields"
" (e.g. not contain a Q, or F expression).",
obj=field,
id='fields.E011',
),
])

def test_bad_db_index_value(self):
class Model(models.Model):
field = models.CharField(max_length=10, db_index='bad')
Expand Down
44 changes: 36 additions & 8 deletions tests/queries/models.py
Expand Up @@ -2,7 +2,11 @@
Various complex queries that have been problematic in the past.
"""
from django.db import models
from django.db.models.functions import Now
from django.db.models import ExpressionWrapper, Value
from django.db.models.expressions import Func, RawSQL
from django.db.models.functions import (
Cast, Coalesce, ExtractYear, Now, Pi, TruncDay,
)


class DumbCategory(models.Model):
Expand Down Expand Up @@ -727,20 +731,44 @@ class CustomDbColumn(models.Model):
ip_address = models.GenericIPAddressField(null=True)


class CreatedField(models.DateTimeField):
db_returning = True
class NativeUUID4:
"""UUID as implemented by the database."""
def as_sql(self, compiler, connection):
raise NotImplementedError

def __init__(self, *args, **kwargs):
kwargs.setdefault('default', Now)
super().__init__(*args, **kwargs)
def as_postgresql(self, compiler, connection):
return "uuid_generate_v4()", ()

def as_oracle(self, compiler, connection):
return "SYS_GUID()", ()

def as_mysql(self, compiler, connection):
return "UUID()", ()


class Mod(Func):
function = 'MOD'


class ReturningModel(models.Model):
created = CreatedField(editable=False)
created = models.DateTimeField(editable=False, default=Now())
created_date = models.DateField(default=TruncDay(Now()))
year = models.PositiveSmallIntegerField(default=ExtractYear(Now()))
pi = models.FloatField(default=Pi())
expr_wrapper = models.PositiveBigIntegerField(default=ExpressionWrapper(
Value(1) + Value(2),
output_field=models.PositiveBigIntegerField(),
))
coalesce_val = models.PositiveBigIntegerField(
default=Coalesce(Value(None), Value(1337), output_field=models.PositiveBigIntegerField()),
)
raw_sql = models.IntegerField(default=RawSQL('5 * %s', (15,), output_field=models.IntegerField()))
uuid = models.UUIDField(default=NativeUUID4())
is_odd = models.BooleanField(default=Cast(Mod(Value(10), Value(2)), output_field=models.BooleanField()))


class NonIntegerPKReturningModel(models.Model):
created = CreatedField(editable=False, primary_key=True)
created = models.DateTimeField(editable=False, primary_key=True, default=Now())


class JSONFieldNullable(models.Model):
Expand Down
47 changes: 46 additions & 1 deletion tests/queries/test_db_returning.py
@@ -1,4 +1,5 @@
import datetime
import uuid

from django.db import connection
from django.test import TestCase, skipUnlessDBFeature
Expand All @@ -7,8 +8,16 @@
from .models import DumbCategory, NonIntegerPKReturningModel, ReturningModel


class UUIDExtensionTestCase(TestCase):
def setUp(self):
super().setUp()
if connection.vendor == 'postgresql':
with connection.cursor() as cursor:
cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')


@skipUnlessDBFeature('can_return_columns_from_insert')
class ReturningValuesTests(TestCase):
class ReturningValuesTests(UUIDExtensionTestCase):
def test_insert_returning(self):
with CaptureQueriesContext(connection) as captured_queries:
DumbCategory.objects.create()
Expand Down Expand Up @@ -49,3 +58,39 @@ def test_bulk_insert(self):
with self.subTest(obj=obj):
self.assertTrue(obj.pk)
self.assertIsInstance(obj.created, datetime.datetime)


@skipUnlessDBFeature('can_return_columns_from_insert')
class DatabaseDefaultsTests(UUIDExtensionTestCase):
def test_now(self):
obj = ReturningModel.objects.create()
self.assertIsInstance(obj.created, datetime.datetime)

def test_truncate(self):
obj = ReturningModel.objects.create()
self.assertIsInstance(obj.year, int)

def test_pi(self):
obj = ReturningModel.objects.create()
self.assertAlmostEqual(obj.pi, 3.1415926535897) # decimal precision varies

def test_expression_wrapper(self):
obj = ReturningModel.objects.create()
self.assertEqual(obj.expr_wrapper, 3)

def test_coalesce_value(self):
obj = ReturningModel.objects.create()
self.assertEqual(obj.coalesce_val, 1337)

def test_raw_sql(self):
obj = ReturningModel.objects.create()
self.assertEqual(obj.raw_sql, 75)

def test_custom_uuid(self):
obj = ReturningModel.objects.create()
self.assertIsInstance(obj.uuid, uuid.UUID)

def test_boolean_function(self):
obj = ReturningModel.objects.create()
self.assertIsInstance(obj.is_odd, bool)
self.assertFalse(obj.is_odd)
2 changes: 2 additions & 0 deletions tests/schema/models.py
@@ -1,5 +1,6 @@
from django.apps.registry import Apps
from django.db import models
from django.db.models.functions import Now

# Because we want to test creation and deletion of these as separate things,
# these models are all inserted into a separate Apps so the main test
Expand All @@ -13,6 +14,7 @@ class Author(models.Model):
height = models.PositiveIntegerField(null=True, blank=True)
weight = models.IntegerField(null=True, blank=True)
uuid = models.UUIDField(null=True)
date_of_birth = models.DateTimeField(default=Now)

class Meta:
apps = new_apps
Expand Down
4 changes: 2 additions & 2 deletions tests/schema/test_logging.py
Expand Up @@ -7,12 +7,12 @@ class SchemaLoggerTests(TestCase):
def test_extra_args(self):
editor = connection.schema_editor(collect_sql=True)
sql = 'SELECT * FROM foo WHERE id in (%s, %s)'
params = [42, 1337]
params = (42, 1337)
with self.assertLogs('django.db.backends.schema', 'DEBUG') as cm:
editor.execute(sql, params)
self.assertEqual(cm.records[0].sql, sql)
self.assertEqual(cm.records[0].params, params)
self.assertEqual(
cm.records[0].getMessage(),
'SELECT * FROM foo WHERE id in (%s, %s); (params [42, 1337])',
'SELECT * FROM foo WHERE id in (%s, %s); (params (42, 1337))',
)

0 comments on commit dac5fe2

Please sign in to comment.