From dac5fe28a675ddf22da263a15938d6c3963d8d63 Mon Sep 17 00:00:00 2001 From: Johannes Hoppe Date: Sat, 14 Sep 2019 19:56:28 -0700 Subject: [PATCH] Resolved #30969 -- Added support for query expressions as default values Ref ticket-31206 --- django/db/backends/base/schema.py | 30 ++++- django/db/backends/oracle/schema.py | 3 + django/db/models/base.py | 2 +- django/db/models/expressions.py | 2 + django/db/models/fields/__init__.py | 14 +- django/db/models/query_utils.py | 1 + docs/ref/checks.txt | 2 + docs/ref/models/fields.txt | 22 +++ docs/releases/4.0.txt | 3 + .../test_ordinary_fields.py | 30 +++++ tests/queries/models.py | 44 ++++-- tests/queries/test_db_returning.py | 47 ++++++- tests/schema/models.py | 2 + tests/schema/test_logging.py | 4 +- tests/schema/tests.py | 125 +++++++++++++++++- 15 files changed, 307 insertions(+), 24 deletions(-) diff --git a/django/db/backends/base/schema.py b/django/db/backends/base/schema.py index c409464ecaf7e..5b2b81a599374 100644 --- a/django/db/backends/base/schema.py +++ b/django/db/backends/base/schema.py @@ -121,6 +121,15 @@ def __exit__(self, exc_type, exc_value, traceback): # Core utility functions + def prepare_param(self, node): + sql, params = None, (node,) + if hasattr(node, 'as_sql'): + compiler = self.connection.ops.compiler('SQLCompiler')( + query=None, connection=self.connection, using=None + ) + sql, params = compiler.compile(node) + return sql, tuple(params) + def execute(self, sql, params=()): """Execute the given SQL statement, with optional parameters.""" # Don't perform the transactional DDL check if SQL is being collected @@ -132,17 +141,25 @@ def execute(self, sql, params=()): ) # Account for non-string statement objects. sql = str(sql) + prepared_params = None + if params is not None: + prepared_params = [self.prepare_param(p) for p in params] + params_sql = tuple(filter(None, (p[0] for p in prepared_params))) + if params_sql: + sql %= params_sql + prepared_params = tuple(q for p in prepared_params for q in p[1]) + # Log the command we're running, then run it - logger.debug("%s; (params %r)", sql, params, extra={'params': params, 'sql': sql}) + logger.debug("%s; (params %r)", sql, prepared_params, extra={'params': prepared_params, 'sql': sql}) if self.collect_sql: ending = "" if sql.rstrip().endswith(";") else ";" - if params is not None: - self.collected_sql.append((sql % tuple(map(self.quote_value, params))) + ending) + if prepared_params is not None: + self.collected_sql.append((sql % tuple(map(self.quote_value, prepared_params))) + ending) else: self.collected_sql.append(sql + ending) else: with self.connection.cursor() as cursor: - cursor.execute(sql, params) + cursor.execute(sql, prepared_params) def quote_name(self, name): return self.connection.ops.quote_name(name) @@ -314,7 +331,10 @@ def _effective_default(field): def effective_default(self, field): """Return a field's effective database default value.""" - return field.get_db_prep_save(self._effective_default(field), self.connection) + default = self._effective_default(field) + if hasattr(default, 'as_sql'): + return default + return field.get_db_prep_save(default, self.connection) def quote_value(self, value): """ diff --git a/django/db/backends/oracle/schema.py b/django/db/backends/oracle/schema.py index cf875fc5e3ea6..dc102a2ae3039 100644 --- a/django/db/backends/oracle/schema.py +++ b/django/db/backends/oracle/schema.py @@ -152,6 +152,9 @@ def _generate_temp_name(self, for_name): return self.normalize_name(for_name + "_" + suffix) def prepare_default(self, value): + if hasattr(value, 'as_sql'): + sql, params = self.prepare_param(value) + return sql % tuple(self.quote_value(p) for p in params) return self.quote_value(value) def _field_should_be_indexed(self, model, field): diff --git a/django/db/models/base.py b/django/db/models/base.py index 0f8af9f9204e3..11b7f04dd34ef 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -872,7 +872,7 @@ def _save_table(self, raw=False, cls=None, force_insert=False, results = self._do_insert(cls._base_manager, using, fields, returning_fields, raw) if results: for value, field in zip(results[0], returning_fields): - setattr(self, field.attname, value) + setattr(self, field.attname, field.to_python(value)) return updated def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_update): diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index feb04d4585bff..15b9a700c707a 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -581,6 +581,8 @@ def as_sql(self, compiler, connection): class F(Combinable): """An object capable of resolving references to existing query objects.""" + contains_column_references = True + def __init__(self, name): """ Arguments: diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 51155f57769f4..6a7a131da9dcd 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -197,6 +197,7 @@ def check(self, **kwargs): return [ *self._check_field_name(), *self._check_choices(), + *self._check_default_expressions(), *self._check_db_index(), *self._check_null_allowed_for_primary_keys(), *self._check_backend_specific_checks(**kwargs), @@ -236,6 +237,15 @@ def _check_field_name(self): else: return [] + def _check_default_expressions(self): + if hasattr(self.default, 'contains_column_references') and self.default.contains_column_references: + return [checks.Error( + "'default' expressions cannot reference other fields (e.g. not contain a Q, or F expression).", + obj=self, + id='fields.E011', + )] + return [] + @classmethod def _choices_is_value(cls, value): return isinstance(value, (str, Promise)) or not is_iterable(value) @@ -762,7 +772,7 @@ def db_returning(self): Private API intended only to be used by Django itself. Currently only the PostgreSQL backend supports returning multiple fields on a model. """ - return False + return self.has_default() and hasattr(self.default, 'as_sql') def set_attributes_from_name(self, name): self.name = self.name or name @@ -2420,6 +2430,8 @@ def get_prep_value(self, value): def get_db_prep_value(self, value, connection, prepared=False): if value is None: return None + if hasattr(value, 'as_sql'): + return value if not isinstance(value, uuid.UUID): value = self.to_python(value) diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index 188b6408507e8..d4a7afae5220c 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -36,6 +36,7 @@ class Q(tree.Node): OR = 'OR' default = AND conditional = True + contains_column_references = True def __init__(self, *args, _connector=None, _negated=False, **kwargs): super().__init__(children=[*args, *sorted(kwargs.items())], connector=_connector, negated=_negated) diff --git a/docs/ref/checks.txt b/docs/ref/checks.txt index f304da7e1168c..f33d3be8e1905 100644 --- a/docs/ref/checks.txt +++ b/docs/ref/checks.txt @@ -185,6 +185,8 @@ Model fields ``choices`` (```` characters). * **fields.E010**: ```` default should be a callable instead of an instance so that it's not shared between all field instances. +* **fields.E011**: ```` default expression may not reference database + columns (e.g. contain F or Q expressions). * **fields.E100**: ``AutoField``\s must set primary_key=True. * **fields.E110**: ``BooleanField``\s do not accept null values. *This check appeared before support for null values was added in Django 2.1.* diff --git a/docs/ref/models/fields.txt b/docs/ref/models/fields.txt index 98475605041b7..4c7455f3c7ae2 100644 --- a/docs/ref/models/fields.txt +++ b/docs/ref/models/fields.txt @@ -368,6 +368,28 @@ The default value is used when new model instances are created and a value isn't provided for the field. When the field is a primary key, the default is also used when the field is set to ``None``. +.. versionadded:: 4.0 + +The MariaDB (10.5+), Oracle and PostgreSQL database backends support +:doc:`Query Expressions ` as ``default`` values:: + + from django.db import models + from django.db.models.functions import ExtractYear, Now, Pi + + class ReturningModel(models.Model) + created = models.DateTimeField(editable=False, default=Now()) + pi = models.FloatField(default=Pi()) + year = models.PositiveSmallIntegerField(default=ExtractYear(Now())) + +.. note:: + The query expressions are not stored in the database as column default + values, but expressions are passed directly to the database during inserts + and updates. Therefore, the defaults will not be available should you + perform any raw SQL ``INSERT`` or ``UPDATE``; you must explicitly provide + an equivalent SQL expression. You will not be able to use the ``DEFAULT` or + ``DEFAULT VALUES`` keywords to populate columns for fields using query + expressions. + ``editable`` ------------ diff --git a/docs/releases/4.0.txt b/docs/releases/4.0.txt index 60c449bc88218..03a51ba16136c 100644 --- a/docs/releases/4.0.txt +++ b/docs/releases/4.0.txt @@ -248,6 +248,9 @@ Models * :class:`~django.db.models.DurationField` now supports multiplying and dividing by scalar values on SQLite. +* The MariaDB, Oracle and PostgreSQL backends support + :doc:`Query Expressions ` as + :attr:`~django.db.models.Field.default` values for model fields. Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/invalid_models_tests/test_ordinary_fields.py b/tests/invalid_models_tests/test_ordinary_fields.py index 6eddd853af56f..011471898105c 100644 --- a/tests/invalid_models_tests/test_ordinary_fields.py +++ b/tests/invalid_models_tests/test_ordinary_fields.py @@ -3,6 +3,7 @@ from django.core.checks import Error, Warning as DjangoWarning from django.db import connection, models +from django.db.models.functions import Now from django.test import ( SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature, ) @@ -334,6 +335,35 @@ class Model(models.Model): ), ]) + def test_default_expressions(self): + class Model(models.Model): + created = models.DateField(default=Now()) + + field = Model._meta.get_field('created') + self.assertEqual(field.check(), []) + + def test_default_expressions_failure(self): + class Model(models.Model): + q = models.BooleanField(default=models.Q(pk__gte=100)) + wrapped_q = models.BooleanField( + default=models.ExpressionWrapper(models.Q(pk__gte=100), output_field=models.BooleanField()) + ) + max_pk = models.BooleanField( + default=models.Max('pk') + ) + + for name in ('q', 'wrapped_q', 'max_pk'): + with self.subTest(name): + field = Model._meta.get_field(name) + self.assertEqual(field.check(), [ + Error( + "'default' expressions cannot reference other fields" + " (e.g. not contain a Q, or F expression).", + obj=field, + id='fields.E011', + ), + ]) + def test_bad_db_index_value(self): class Model(models.Model): field = models.CharField(max_length=10, db_index='bad') diff --git a/tests/queries/models.py b/tests/queries/models.py index 383f633be9a7c..c6da70f43cad6 100644 --- a/tests/queries/models.py +++ b/tests/queries/models.py @@ -2,7 +2,11 @@ Various complex queries that have been problematic in the past. """ from django.db import models -from django.db.models.functions import Now +from django.db.models import ExpressionWrapper, Value +from django.db.models.expressions import Func, RawSQL +from django.db.models.functions import ( + Cast, Coalesce, ExtractYear, Now, Pi, TruncDay, +) class DumbCategory(models.Model): @@ -727,20 +731,44 @@ class CustomDbColumn(models.Model): ip_address = models.GenericIPAddressField(null=True) -class CreatedField(models.DateTimeField): - db_returning = True +class NativeUUID4: + """UUID as implemented by the database.""" + def as_sql(self, compiler, connection): + raise NotImplementedError - def __init__(self, *args, **kwargs): - kwargs.setdefault('default', Now) - super().__init__(*args, **kwargs) + def as_postgresql(self, compiler, connection): + return "uuid_generate_v4()", () + + def as_oracle(self, compiler, connection): + return "SYS_GUID()", () + + def as_mysql(self, compiler, connection): + return "UUID()", () + + +class Mod(Func): + function = 'MOD' class ReturningModel(models.Model): - created = CreatedField(editable=False) + created = models.DateTimeField(editable=False, default=Now()) + created_date = models.DateField(default=TruncDay(Now())) + year = models.PositiveSmallIntegerField(default=ExtractYear(Now())) + pi = models.FloatField(default=Pi()) + expr_wrapper = models.PositiveBigIntegerField(default=ExpressionWrapper( + Value(1) + Value(2), + output_field=models.PositiveBigIntegerField(), + )) + coalesce_val = models.PositiveBigIntegerField( + default=Coalesce(Value(None), Value(1337), output_field=models.PositiveBigIntegerField()), + ) + raw_sql = models.IntegerField(default=RawSQL('5 * %s', (15,), output_field=models.IntegerField())) + uuid = models.UUIDField(default=NativeUUID4()) + is_odd = models.BooleanField(default=Cast(Mod(Value(10), Value(2)), output_field=models.BooleanField())) class NonIntegerPKReturningModel(models.Model): - created = CreatedField(editable=False, primary_key=True) + created = models.DateTimeField(editable=False, primary_key=True, default=Now()) class JSONFieldNullable(models.Model): diff --git a/tests/queries/test_db_returning.py b/tests/queries/test_db_returning.py index 9ba352a7ab7f6..eb7e112751ca5 100644 --- a/tests/queries/test_db_returning.py +++ b/tests/queries/test_db_returning.py @@ -1,4 +1,5 @@ import datetime +import uuid from django.db import connection from django.test import TestCase, skipUnlessDBFeature @@ -7,8 +8,16 @@ from .models import DumbCategory, NonIntegerPKReturningModel, ReturningModel +class UUIDExtensionTestCase(TestCase): + def setUp(self): + super().setUp() + if connection.vendor == 'postgresql': + with connection.cursor() as cursor: + cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";') + + @skipUnlessDBFeature('can_return_columns_from_insert') -class ReturningValuesTests(TestCase): +class ReturningValuesTests(UUIDExtensionTestCase): def test_insert_returning(self): with CaptureQueriesContext(connection) as captured_queries: DumbCategory.objects.create() @@ -49,3 +58,39 @@ def test_bulk_insert(self): with self.subTest(obj=obj): self.assertTrue(obj.pk) self.assertIsInstance(obj.created, datetime.datetime) + + +@skipUnlessDBFeature('can_return_columns_from_insert') +class DatabaseDefaultsTests(UUIDExtensionTestCase): + def test_now(self): + obj = ReturningModel.objects.create() + self.assertIsInstance(obj.created, datetime.datetime) + + def test_truncate(self): + obj = ReturningModel.objects.create() + self.assertIsInstance(obj.year, int) + + def test_pi(self): + obj = ReturningModel.objects.create() + self.assertAlmostEqual(obj.pi, 3.1415926535897) # decimal precision varies + + def test_expression_wrapper(self): + obj = ReturningModel.objects.create() + self.assertEqual(obj.expr_wrapper, 3) + + def test_coalesce_value(self): + obj = ReturningModel.objects.create() + self.assertEqual(obj.coalesce_val, 1337) + + def test_raw_sql(self): + obj = ReturningModel.objects.create() + self.assertEqual(obj.raw_sql, 75) + + def test_custom_uuid(self): + obj = ReturningModel.objects.create() + self.assertIsInstance(obj.uuid, uuid.UUID) + + def test_boolean_function(self): + obj = ReturningModel.objects.create() + self.assertIsInstance(obj.is_odd, bool) + self.assertFalse(obj.is_odd) diff --git a/tests/schema/models.py b/tests/schema/models.py index 75e4de0874973..0aa3b3f16a885 100644 --- a/tests/schema/models.py +++ b/tests/schema/models.py @@ -1,5 +1,6 @@ from django.apps.registry import Apps from django.db import models +from django.db.models.functions import Now # Because we want to test creation and deletion of these as separate things, # these models are all inserted into a separate Apps so the main test @@ -13,6 +14,7 @@ class Author(models.Model): height = models.PositiveIntegerField(null=True, blank=True) weight = models.IntegerField(null=True, blank=True) uuid = models.UUIDField(null=True) + date_of_birth = models.DateTimeField(default=Now) class Meta: apps = new_apps diff --git a/tests/schema/test_logging.py b/tests/schema/test_logging.py index 453bdd798e38d..5e68951321fc6 100644 --- a/tests/schema/test_logging.py +++ b/tests/schema/test_logging.py @@ -7,12 +7,12 @@ class SchemaLoggerTests(TestCase): def test_extra_args(self): editor = connection.schema_editor(collect_sql=True) sql = 'SELECT * FROM foo WHERE id in (%s, %s)' - params = [42, 1337] + params = (42, 1337) with self.assertLogs('django.db.backends.schema', 'DEBUG') as cm: editor.execute(sql, params) self.assertEqual(cm.records[0].sql, sql) self.assertEqual(cm.records[0].params, params) self.assertEqual( cm.records[0].getMessage(), - 'SELECT * FROM foo WHERE id in (%s, %s); (params [42, 1337])', + 'SELECT * FROM foo WHERE id in (%s, %s); (params (42, 1337))', ) diff --git a/tests/schema/tests.py b/tests/schema/tests.py index eb9be178db2af..ace176ac0d502 100644 --- a/tests/schema/tests.py +++ b/tests/schema/tests.py @@ -12,13 +12,16 @@ from django.db.models import ( CASCADE, PROTECT, AutoField, BigAutoField, BigIntegerField, BinaryField, BooleanField, CharField, CheckConstraint, DateField, DateTimeField, - DecimalField, F, FloatField, ForeignKey, ForeignObject, Index, - IntegerField, JSONField, ManyToManyField, Model, OneToOneField, OrderBy, - PositiveIntegerField, Q, SlugField, SmallAutoField, SmallIntegerField, - TextField, TimeField, UniqueConstraint, UUIDField, Value, + DecimalField, ExpressionWrapper, F, FloatField, ForeignKey, ForeignObject, + Index, IntegerField, JSONField, ManyToManyField, Model, OneToOneField, + OrderBy, PositiveIntegerField, Q, SlugField, SmallAutoField, + SmallIntegerField, TextField, TimeField, UniqueConstraint, UUIDField, + Value, ) from django.db.models.fields.json import KeyTextTransform -from django.db.models.functions import Abs, Cast, Collate, Lower, Random, Upper +from django.db.models.functions import ( + Abs, Cast, Collate, Lower, Now, Random, TruncDay, Upper, +) from django.db.models.indexes import IndexExpression from django.db.transaction import TransactionManagementError, atomic from django.test import ( @@ -3098,7 +3101,7 @@ def test_func_index_nonexistent_field(self): index = Index(Lower('nonexistent'), name='func_nonexistent_idx') msg = ( "Cannot resolve keyword 'nonexistent' into field. Choices are: " - "height, id, name, uuid, weight" + "date_of_birth, height, id, name, uuid, weight" ) with self.assertRaisesMessage(FieldError, msg): with connection.schema_editor() as editor: @@ -3333,6 +3336,116 @@ def test_add_field_use_effective_default(self): item = cursor.fetchall()[0] self.assertEqual(item[0], None if connection.features.interprets_empty_strings_as_nulls else '') + @skipUnlessDBFeature('can_return_columns_from_insert') + def test_add_field_default_db_returning(self): + # Create the table. + with connection.schema_editor() as editor: + editor.create_model(Author) + # Add new field with database default. + Author.objects.create(name='author 1') + new_field = DateTimeField(default=Now()) + new_field.set_attributes_from_name('db_returning') + with connection.schema_editor() as editor: + editor.add_field(Author, new_field) + # Field was added with the right default. + with connection.cursor() as cursor: + cursor.execute("SELECT db_returning FROM schema_author;") + item = cursor.fetchall()[0] + self.assertIsNotNone(item[0]) + + @skipUnlessDBFeature('can_return_columns_from_insert') + def test_alter_field_default_db_returning(self): + # Create the table. + with connection.schema_editor() as editor: + editor.create_model(BookWithoutAuthor) + BookWithoutAuthor.objects.create(title='book 1', pub_date=datetime.datetime.now()) + # Alter to add field with database default. + old_field = BookWithoutAuthor._meta.get_field('pub_date') + new_field = DateTimeField(default=Now) + new_field.set_attributes_from_name('pub_date') + with connection.schema_editor() as editor: + editor.alter_field(BookWithoutAuthor, old_field, new_field, strict=True) + + @skipUnlessDBFeature('can_return_columns_from_insert') + def test_create_model_default_db_returning(self): + # Create the table. + with connection.schema_editor() as editor: + with self.assertLogs('django.db.backends.schema', 'DEBUG') as cm: + editor.create_model(Author) + + if connection.vendor == "oracle": + self.assertIn( + '"DATE_OF_BIRTH" TIMESTAMP NOT NULL', + cm.records[0].sql + ) + else: + self.assertIn( + '"date_of_birth" timestamp with time zone NOT NULL', + cm.records[0].sql + ) + self.assertEqual(cm.records[0].params, None) + + @skipUnlessDBFeature('can_return_columns_from_insert') + def test_add_field_default_db_returning_expression(self): + # Create the table. + with connection.schema_editor() as editor: + editor.create_model(Author) + # Add database default to a nullable field. + new_field = IntegerField(default=ExpressionWrapper(Value(1) * Value(3), output_field=IntegerField())) + new_field.set_attributes_from_name('place_of_birth') + with connection.schema_editor() as editor: + with self.assertLogs('django.db.backends.schema', 'DEBUG') as cm: + editor.add_field(Author, new_field) + + if connection.vendor == 'oracle': + self.assertEqual( + cm.records[0].sql, + 'ALTER TABLE "SCHEMA_AUTHOR" ADD "PLACE_OF_BIRTH" NUMBER(11) DEFAULT (1 * 3) NOT NULL' + ) + else: + self.assertEqual( + cm.records[0].sql, + 'ALTER TABLE "schema_author" ADD COLUMN "place_of_birth" integer DEFAULT (%s * %s) NOT NULL' + ) + self.assertEqual(cm.records[0].params, (1, 3)) + + @skipUnlessDBFeature('can_return_columns_from_insert') + def test_alter_nullable_field_default_db_returning(self): + # Create the table. + with connection.schema_editor() as editor: + editor.create_model(Author) + # Add database default to a nullable field. + old_field = Author._meta.get_field('weight') + new_field = IntegerField(default=ExpressionWrapper(Value(1) * Value(3), output_field=IntegerField())) + new_field.set_attributes_from_name('weight') + with connection.schema_editor() as editor: + with self.assertLogs('django.db.backends.schema', 'DEBUG') as cm: + editor.alter_field(Author, old_field, new_field, strict=True) + + if connection.vendor == 'oracle': + self.assertEqual( + cm.records[0].sql, + 'ALTER TABLE "SCHEMA_AUTHOR" MODIFY "WEIGHT" DEFAULT (1 * 3)' + ) + else: + self.assertEqual( + cm.records[0].sql, + 'ALTER TABLE "schema_author" ALTER COLUMN "weight" SET DEFAULT (%s * %s)' + ) + self.assertEqual(cm.records[0].params, (1, 3)) + + @skipUnlessDBFeature('can_return_columns_from_insert') + def test_effective_default_db_default(self): + """#31206 - effective_default() should be handle database defaults""" + new_field = DateTimeField(default=TruncDay(Now())) + new_field.set_attributes_from_name('date_of_birth') + with connection.schema_editor() as editor: + editor.create_model(Author) + # unit test + editor.effective_default(DateTimeField(default=Now())) + # integration test + editor.alter_field(Author, Author._meta.get_field('date_of_birth'), new_field) + def test_add_field_default_dropped(self): # Create the table with connection.schema_editor() as editor: