Skip to content

Commit

Permalink
Fixed #30581 -- Added support for Meta.constraints validation.
Browse files Browse the repository at this point in the history
Thanks Simon Charette, Keryn Knight, and Mariusz Felisiak for reviews.
  • Loading branch information
Gagaro authored and felixxm committed May 10, 2022
1 parent 441103a commit 6671058
Show file tree
Hide file tree
Showing 17 changed files with 852 additions and 88 deletions.
50 changes: 47 additions & 3 deletions django/contrib/postgres/constraints.py
@@ -1,11 +1,13 @@
import warnings

from django.contrib.postgres.indexes import OpClass
from django.db import NotSupportedError
from django.core.exceptions import ValidationError
from django.db import DEFAULT_DB_ALIAS, NotSupportedError
from django.db.backends.ddl_references import Expressions, Statement, Table
from django.db.models import BaseConstraint, Deferrable, F, Q
from django.db.models.expressions import ExpressionList
from django.db.models.expressions import Exists, ExpressionList
from django.db.models.indexes import IndexExpression
from django.db.models.lookups import PostgresOperatorLookup
from django.db.models.sql import Query
from django.utils.deprecation import RemovedInDjango50Warning

Expand All @@ -32,6 +34,7 @@ def __init__(
deferrable=None,
include=None,
opclasses=(),
violation_error_message=None,
):
if index_type and index_type.lower() not in {"gist", "spgist"}:
raise ValueError(
Expand Down Expand Up @@ -78,7 +81,7 @@ def __init__(
category=RemovedInDjango50Warning,
stacklevel=2,
)
super().__init__(name=name)
super().__init__(name=name, violation_error_message=violation_error_message)

def _get_expressions(self, schema_editor, query):
expressions = []
Expand Down Expand Up @@ -197,3 +200,44 @@ def __repr__(self):
"" if not self.include else " include=%s" % repr(self.include),
"" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
)

def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
queryset = model._default_manager.using(using)
replacement_map = instance._get_field_value_map(
meta=model._meta, exclude=exclude
)
lookups = []
for idx, (expression, operator) in enumerate(self.expressions):
if isinstance(expression, str):
expression = F(expression)
if isinstance(expression, F):
if exclude and expression.name in exclude:
return
rhs_expression = replacement_map.get(expression.name, expression)
else:
rhs_expression = expression.replace_references(replacement_map)
if exclude:
for expr in rhs_expression.flatten():
if isinstance(expr, F) and expr.name in exclude:
return
# Remove OpClass because it only has sense during the constraint
# creation.
if isinstance(expression, OpClass):
expression = expression.get_source_expressions()[0]
if isinstance(rhs_expression, OpClass):
rhs_expression = rhs_expression.get_source_expressions()[0]
lookup = PostgresOperatorLookup(lhs=expression, rhs=rhs_expression)
lookup.postgres_operator = operator
lookups.append(lookup)
queryset = queryset.filter(*lookups)
model_class_pk = instance._get_pk_val(model._meta)
if not instance._state.adding and model_class_pk is not None:
queryset = queryset.exclude(pk=model_class_pk)
if not self.condition:
if queryset.exists():
raise ValidationError(self.get_violation_error_message())
else:
if (self.condition & Exists(queryset.filter(self.condition))).check(
replacement_map, using=using
):
raise ValidationError(self.get_violation_error_message())
93 changes: 81 additions & 12 deletions django/db/models/base.py
Expand Up @@ -28,6 +28,7 @@
from django.db.models.constants import LOOKUP_SEP
from django.db.models.constraints import CheckConstraint, UniqueConstraint
from django.db.models.deletion import CASCADE, Collector
from django.db.models.expressions import RawSQL
from django.db.models.fields.related import (
ForeignObjectRel,
OneToOneField,
Expand Down Expand Up @@ -1189,6 +1190,16 @@ def _get_next_or_previous_in_order(self, is_next):
setattr(self, cachename, obj)
return getattr(self, cachename)

def _get_field_value_map(self, meta, exclude=None):
if exclude is None:
exclude = set()
meta = meta or self._meta
return {
field.name: Value(getattr(self, field.attname), field)
for field in meta.local_concrete_fields
if field.name not in exclude
}

def prepare_database_save(self, field):
if self.pk is None:
raise ValueError(
Expand Down Expand Up @@ -1221,7 +1232,7 @@ def validate_unique(self, exclude=None):
if errors:
raise ValidationError(errors)

def _get_unique_checks(self, exclude=None):
def _get_unique_checks(self, exclude=None, include_meta_constraints=False):
"""
Return a list of checks to perform. Since validate_unique() could be
called from a ModelForm, some fields may have been excluded; we can't
Expand All @@ -1234,13 +1245,15 @@ def _get_unique_checks(self, exclude=None):
unique_checks = []

unique_togethers = [(self.__class__, self._meta.unique_together)]
constraints = [(self.__class__, self._meta.total_unique_constraints)]
constraints = []
if include_meta_constraints:
constraints = [(self.__class__, self._meta.total_unique_constraints)]
for parent_class in self._meta.get_parent_list():
if parent_class._meta.unique_together:
unique_togethers.append(
(parent_class, parent_class._meta.unique_together)
)
if parent_class._meta.total_unique_constraints:
if include_meta_constraints and parent_class._meta.total_unique_constraints:
constraints.append(
(parent_class, parent_class._meta.total_unique_constraints)
)
Expand All @@ -1251,10 +1264,11 @@ def _get_unique_checks(self, exclude=None):
# Add the check if the field isn't excluded.
unique_checks.append((model_class, tuple(check)))

for model_class, model_constraints in constraints:
for constraint in model_constraints:
if not any(name in exclude for name in constraint.fields):
unique_checks.append((model_class, constraint.fields))
if include_meta_constraints:
for model_class, model_constraints in constraints:
for constraint in model_constraints:
if not any(name in exclude for name in constraint.fields):
unique_checks.append((model_class, constraint.fields))

# These are checks for the unique_for_<date/year/month>.
date_checks = []
Expand Down Expand Up @@ -1410,10 +1424,35 @@ def unique_error_message(self, model_class, unique_check):
params=params,
)

def full_clean(self, exclude=None, validate_unique=True):
def get_constraints(self):
constraints = [(self.__class__, self._meta.constraints)]
for parent_class in self._meta.get_parent_list():
if parent_class._meta.constraints:
constraints.append((parent_class, parent_class._meta.constraints))
return constraints

def validate_constraints(self, exclude=None):
constraints = self.get_constraints()
using = router.db_for_write(self.__class__, instance=self)

errors = {}
for model_class, model_constraints in constraints:
for constraint in model_constraints:
try:
constraint.validate(model_class, self, exclude=exclude, using=using)
except ValidationError as e:
if e.code == "unique" and len(constraint.fields) == 1:
errors.setdefault(constraint.fields[0], []).append(e)
else:
errors = e.update_error_dict(errors)
if errors:
raise ValidationError(errors)

def full_clean(self, exclude=None, validate_unique=True, validate_constraints=True):
"""
Call clean_fields(), clean(), and validate_unique() on the model.
Raise a ValidationError for any errors that occur.
Call clean_fields(), clean(), validate_unique(), and
validate_constraints() on the model. Raise a ValidationError for any
errors that occur.
"""
errors = {}
if exclude is None:
Expand Down Expand Up @@ -1443,6 +1482,16 @@ def full_clean(self, exclude=None, validate_unique=True):
except ValidationError as e:
errors = e.update_error_dict(errors)

# Run constraints checks, but only for fields that passed validation.
if validate_constraints:
for name in errors:
if name != NON_FIELD_ERRORS and name not in exclude:
exclude.add(name)
try:
self.validate_constraints(exclude=exclude)
except ValidationError as e:
errors = e.update_error_dict(errors)

if errors:
raise ValidationError(errors)

Expand Down Expand Up @@ -2339,8 +2388,28 @@ def _check_constraints(cls, databases):
connection.features.supports_table_check_constraints
or "supports_table_check_constraints"
not in cls._meta.required_db_features
) and isinstance(constraint.check, Q):
references.update(cls._get_expr_references(constraint.check))
):
if isinstance(constraint.check, Q):
references.update(
cls._get_expr_references(constraint.check)
)
if any(
isinstance(expr, RawSQL)
for expr in constraint.check.flatten()
):
errors.append(
checks.Warning(
f"Check constraint {constraint.name!r} contains "
f"RawSQL() expression and won't be validated "
f"during the model full_clean().",
hint=(
"Silence this warning if you don't care about "
"it."
),
obj=cls,
id="models.W045",
),
)
for field_name, *lookups in references:
# pk is an alias that won't be found by opts.get_field.
if field_name != "pk":
Expand Down
92 changes: 87 additions & 5 deletions django/db/models/constraints.py
@@ -1,16 +1,25 @@
from enum import Enum

from django.db.models.expressions import ExpressionList, F
from django.core.exceptions import FieldError, ValidationError
from django.db import connections
from django.db.models.expressions import Exists, ExpressionList, F
from django.db.models.indexes import IndexExpression
from django.db.models.lookups import Exact
from django.db.models.query_utils import Q
from django.db.models.sql.query import Query
from django.db.utils import DEFAULT_DB_ALIAS
from django.utils.translation import gettext_lazy as _

__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]


class BaseConstraint:
def __init__(self, name):
violation_error_message = _("Constraint “%(name)s” is violated.")

def __init__(self, name, violation_error_message=None):
self.name = name
if violation_error_message is not None:
self.violation_error_message = violation_error_message

@property
def contains_expressions(self):
Expand All @@ -25,6 +34,12 @@ def create_sql(self, model, schema_editor):
def remove_sql(self, model, schema_editor):
raise NotImplementedError("This method must be implemented by a subclass.")

def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
raise NotImplementedError("This method must be implemented by a subclass.")

def get_violation_error_message(self):
return self.violation_error_message % {"name": self.name}

def deconstruct(self):
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
path = path.replace("django.db.models.constraints", "django.db.models")
Expand All @@ -36,13 +51,13 @@ def clone(self):


class CheckConstraint(BaseConstraint):
def __init__(self, *, check, name):
def __init__(self, *, check, name, violation_error_message=None):
self.check = check
if not getattr(check, "conditional", False):
raise TypeError(
"CheckConstraint.check must be a Q instance or boolean expression."
)
super().__init__(name)
super().__init__(name, violation_error_message=violation_error_message)

def _get_check_sql(self, model, schema_editor):
query = Query(model=model, alias_cols=False)
Expand All @@ -62,6 +77,14 @@ def create_sql(self, model, schema_editor):
def remove_sql(self, model, schema_editor):
return schema_editor._delete_check_sql(model, self.name)

def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
try:
if not Q(self.check).check(against, using=using):
raise ValidationError(self.get_violation_error_message())
except FieldError:
pass

def __repr__(self):
return "<%s: check=%s name=%s>" % (
self.__class__.__qualname__,
Expand Down Expand Up @@ -99,6 +122,7 @@ def __init__(
deferrable=None,
include=None,
opclasses=(),
violation_error_message=None,
):
if not name:
raise ValueError("A unique constraint must be named.")
Expand Down Expand Up @@ -148,7 +172,7 @@ def __init__(
F(expression) if isinstance(expression, str) else expression
for expression in expressions
)
super().__init__(name)
super().__init__(name, violation_error_message=violation_error_message)

@property
def contains_expressions(self):
Expand Down Expand Up @@ -265,3 +289,61 @@ def deconstruct(self):
if self.opclasses:
kwargs["opclasses"] = self.opclasses
return path, self.expressions, kwargs

def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
queryset = model._default_manager.using(using)
if self.fields:
lookup_kwargs = {}
for field_name in self.fields:
if exclude and field_name in exclude:
return
field = model._meta.get_field(field_name)
lookup_value = getattr(instance, field.attname)
if lookup_value is None or (
lookup_value == ""
and connections[using].features.interprets_empty_strings_as_nulls
):
# A composite constraint containing NULL value cannot cause
# a violation since NULL != NULL in SQL.
return
lookup_kwargs[field.name] = lookup_value
queryset = queryset.filter(**lookup_kwargs)
else:
# Ignore constraints with excluded fields.
if exclude:
for expression in self.expressions:
for expr in expression.flatten():
if isinstance(expr, F) and expr.name in exclude:
return
replacement_map = instance._get_field_value_map(
meta=model._meta, exclude=exclude
)
expressions = [
Exact(expr, expr.replace_references(replacement_map))
for expr in self.expressions
]
queryset = queryset.filter(*expressions)
model_class_pk = instance._get_pk_val(model._meta)
if not instance._state.adding and model_class_pk is not None:
queryset = queryset.exclude(pk=model_class_pk)
if not self.condition:
if queryset.exists():
if self.expressions:
raise ValidationError(self.get_violation_error_message())
# When fields are defined, use the unique_error_message() for
# backward compatibility.
for model, constraints in instance.get_constraints():
for constraint in constraints:
if constraint is self:
raise ValidationError(
instance.unique_error_message(model, self.fields)
)
else:
against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
try:
if (self.condition & Exists(queryset.filter(self.condition))).check(
against, using=using
):
raise ValidationError(self.get_violation_error_message())
except FieldError:
pass

0 comments on commit 6671058

Please sign in to comment.