/
constraints.py
369 lines (331 loc) · 14.9 KB
/
constraints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
from enum import Enum

from django.core.exceptions import FieldError, ValidationError
from django.db import connections
from django.db.models.expressions import Exists, ExpressionList, F, OrderBy
from django.db.models.indexes import IndexExpression
from django.db.models.lookups import Exact
from django.db.models.query_utils import Q
from django.db.models.sql.query import Query
from django.db.utils import DEFAULT_DB_ALIAS
from django.utils.translation import gettext_lazy as _
# Names exported by a star-import of this module.
__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
class BaseConstraint:
    """
    Base class for named model constraints.

    Stores the constraint name and the message used when validation fails.
    The SQL-generating methods and ``validate()`` raise NotImplementedError
    here and must be provided by subclasses.
    """

    # Message template used when no custom message is supplied; it is
    # interpolated with ``{"name": self.name}``.
    default_violation_error_message = _("Constraint “%(name)s” is violated.")
    violation_error_message = None

    def __init__(self, name, violation_error_message=None):
        """
        Store the constraint name and pick the violation message.

        When ``violation_error_message`` is None, the class-level
        ``default_violation_error_message`` is used instead.
        """
        self.name = name
        self.violation_error_message = (
            self.default_violation_error_message
            if violation_error_message is None
            else violation_error_message
        )

    @property
    def contains_expressions(self):
        """Return False; the base constraint holds no expressions."""
        return False

    def constraint_sql(self, model, schema_editor):
        """Must be overridden by subclasses."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def create_sql(self, model, schema_editor):
        """Must be overridden by subclasses."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def remove_sql(self, model, schema_editor):
        """Must be overridden by subclasses."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        """Must be overridden by subclasses."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def get_violation_error_message(self):
        """Return the violation message with the constraint name filled in."""
        params = {"name": self.name}
        return self.violation_error_message % params

    def deconstruct(self):
        """
        Return a ``(path, args, kwargs)`` triple describing this constraint.

        The dotted path has ``django.db.models.constraints`` shortened to
        ``django.db.models``. ``violation_error_message`` is included in
        ``kwargs`` only when it is set and differs from the default.
        """
        cls = self.__class__
        path = "%s.%s" % (cls.__module__, cls.__name__)
        path = path.replace("django.db.models.constraints", "django.db.models")
        kwargs = {"name": self.name}
        message = self.violation_error_message
        if message is not None and message != self.default_violation_error_message:
            kwargs["violation_error_message"] = message
        return (path, (), kwargs)

    def clone(self):
        """Return a new instance built from this constraint's deconstruction."""
        _path, args, kwargs = self.deconstruct()
        return self.__class__(*args, **kwargs)
class CheckConstraint(BaseConstraint):
    """
    Constraint expressed as a boolean condition (``check``).

    ``check`` must expose a truthy ``conditional`` attribute; otherwise
    construction fails with TypeError.
    """

    def __init__(self, *, check, name, violation_error_message=None):
        """
        Store ``check`` and delegate name/message handling to the base class.

        Raises TypeError when ``check`` is not conditional.
        """
        self.check = check
        if not getattr(check, "conditional", False):
            raise TypeError(
                "CheckConstraint.check must be a Q instance or boolean expression."
            )
        super().__init__(name, violation_error_message=violation_error_message)

    def _get_check_sql(self, model, schema_editor):
        """Compile ``self.check`` to SQL with its parameters quoted inline."""
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.check)
        connection = schema_editor.connection
        compiler = query.get_compiler(connection=connection)
        sql, params = where.as_sql(compiler, connection)
        # Parameters are interpolated into the SQL text via quote_value()
        # rather than passed separately.
        quoted_params = tuple(schema_editor.quote_value(param) for param in params)
        return sql % quoted_params

    def constraint_sql(self, model, schema_editor):
        """Return the result of ``schema_editor._check_sql()`` for this check."""
        check_sql = self._get_check_sql(model, schema_editor)
        return schema_editor._check_sql(self.name, check_sql)

    def create_sql(self, model, schema_editor):
        """Return the result of ``schema_editor._create_check_sql()``."""
        check_sql = self._get_check_sql(model, schema_editor)
        return schema_editor._create_check_sql(model, self.name, check_sql)

    def remove_sql(self, model, schema_editor):
        """Return the result of ``schema_editor._delete_check_sql()``."""
        return schema_editor._delete_check_sql(model, self.name)

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        """
        Raise ValidationError when ``instance`` does not satisfy the check.

        The check is evaluated against the instance's field value map. A
        FieldError raised during evaluation is ignored, in which case no
        validation error is reported.
        """
        against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
        try:
            satisfied = Q(self.check).check(against, using=using)
        except FieldError:
            return
        if not satisfied:
            raise ValidationError(self.get_violation_error_message())

    def __repr__(self):
        """Return a debug representation showing the check and the name."""
        parts = (self.__class__.__qualname__, self.check, repr(self.name))
        return "<%s: check=%s name=%s>" % parts

    def __eq__(self, other):
        """
        Compare name, check and violation message with another
        CheckConstraint; defer to the parent comparison for other types.
        """
        if not isinstance(other, CheckConstraint):
            return super().__eq__(other)
        return (
            self.name == other.name
            and self.check == other.check
            and self.violation_error_message == other.violation_error_message
        )

    def deconstruct(self):
        """Extend the base deconstruction with the ``check`` keyword."""
        path, args, kwargs = super().deconstruct()
        kwargs["check"] = self.check
        return path, args, kwargs
class Deferrable(Enum):
    """Deferral modes accepted by ``UniqueConstraint(deferrable=...)``."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    # A similar format was proposed for Python 3.10.
    def __repr__(self):
        """Return ``"<ClassQualname>.<MEMBER>"`` instead of the Enum default."""
        return "%s.%s" % (self.__class__.__qualname__, self.name)
class UniqueConstraint(BaseConstraint):
    """
    Uniqueness constraint over a set of fields or a set of expressions.

    Exactly one of ``fields`` or positional ``expressions`` must be given.
    Optional ``condition``, ``deferrable``, ``include`` and ``opclasses``
    are validated for type and mutual compatibility at construction time.
    """

    def __init__(
        self,
        *expressions,
        fields=(),
        name=None,
        condition=None,
        deferrable=None,
        include=None,
        opclasses=(),
        violation_error_message=None,
    ):
        """
        Validate the arguments and store them on the instance.

        ``expressions`` given as strings are wrapped in ``F()``. ``fields``
        and ``include`` are stored as tuples; ``opclasses`` is stored as
        passed.

        Raises ValueError when the name is missing, when neither or both of
        ``fields``/``expressions`` are given, when an argument has the wrong
        type, or when incompatible options are combined (see the individual
        messages below).
        """
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if not isinstance(condition, (type(None), Q)):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use django.contrib.postgres.indexes.OpClass() instead."
            )
        if not isinstance(deferrable, (type(None), Deferrable)):
            raise ValueError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if not isinstance(include, (type(None), list, tuple)):
            raise ValueError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, (list, tuple)):
            raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        self.opclasses = opclasses
        # Plain strings are treated as field references.
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )
        super().__init__(name, violation_error_message=violation_error_message)

    @property
    def contains_expressions(self):
        """Return True when the constraint was defined with expressions."""
        return bool(self.expressions)

    def _get_condition_sql(self, model, schema_editor):
        """
        Compile ``self.condition`` to SQL with parameters quoted inline.

        Returns None when no condition is set.
        """
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler(connection=schema_editor.connection)
        sql, params = where.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def _get_index_expressions(self, model, schema_editor):
        """
        Return the constraint's expressions as a resolved ExpressionList.

        Each expression is wrapped in an IndexExpression configured for the
        schema editor's connection. Returns None when there are no
        expressions.
        """
        if not self.expressions:
            return None
        index_expressions = []
        for expression in self.expressions:
            index_expression = IndexExpression(expression)
            index_expression.set_wrapper_classes(schema_editor.connection)
            index_expressions.append(index_expression)
        return ExpressionList(*index_expressions).resolve_expression(
            Query(model, alias_cols=False),
        )

    def constraint_sql(self, model, schema_editor):
        """Return the result of ``schema_editor._unique_sql()`` for this constraint."""
        fields = [model._meta.get_field(field_name) for field_name in self.fields]
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def create_sql(self, model, schema_editor):
        """Return the result of ``schema_editor._create_unique_sql()``."""
        fields = [model._meta.get_field(field_name) for field_name in self.fields]
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._create_unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def remove_sql(self, model, schema_editor):
        """Return the result of ``schema_editor._delete_unique_sql()``."""
        condition = self._get_condition_sql(model, schema_editor)
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._delete_unique_sql(
            model,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def __repr__(self):
        """Return a debug representation listing only the options that are set."""
        return "<%s:%s%s%s%s%s%s%s>" % (
            self.__class__.__qualname__,
            "" if not self.fields else " fields=%s" % repr(self.fields),
            "" if not self.expressions else " expressions=%s" % repr(self.expressions),
            " name=%s" % repr(self.name),
            "" if self.condition is None else " condition=%s" % self.condition,
            "" if self.deferrable is None else " deferrable=%r" % self.deferrable,
            "" if not self.include else " include=%s" % repr(self.include),
            "" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
        )

    def __eq__(self, other):
        """
        Compare every stored option with another UniqueConstraint; defer to
        the parent comparison for other types.
        """
        if isinstance(other, UniqueConstraint):
            return (
                self.name == other.name
                and self.fields == other.fields
                and self.condition == other.condition
                and self.deferrable == other.deferrable
                and self.include == other.include
                and self.opclasses == other.opclasses
                and self.expressions == other.expressions
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self):
        """
        Extend the base deconstruction with the options that are set.

        The expressions are returned as the positional ``args``.
        """
        path, args, kwargs = super().deconstruct()
        if self.fields:
            kwargs["fields"] = self.fields
        if self.condition:
            kwargs["condition"] = self.condition
        if self.deferrable:
            kwargs["deferrable"] = self.deferrable
        if self.include:
            kwargs["include"] = self.include
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        return path, self.expressions, kwargs

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        """
        Raise ValidationError when ``instance`` would violate this constraint.

        A queryset on ``model`` (database ``using``) is narrowed to rows that
        match the instance on the constrained fields or expressions, then the
        instance's own row is excluded when it is not being added.

        Validation is skipped (returns None) when a constrained field or a
        field referenced by an expression is listed in ``exclude``, or when a
        constrained field value is None (or "" on backends that interpret
        empty strings as NULL).

        Without a condition, a matching row triggers the error. With a
        condition, the error is raised only when the instance satisfies the
        condition and a matching row satisfying it exists; a FieldError
        raised during that evaluation is ignored.
        """
        queryset = model._default_manager.using(using)
        if self.fields:
            lookup_kwargs = {}
            for field_name in self.fields:
                if exclude and field_name in exclude:
                    return
                field = model._meta.get_field(field_name)
                lookup_value = getattr(instance, field.attname)
                if lookup_value is None or (
                    lookup_value == ""
                    and connections[using].features.interprets_empty_strings_as_nulls
                ):
                    # A composite constraint containing NULL value cannot cause
                    # a violation since NULL != NULL in SQL.
                    return
                lookup_kwargs[field.name] = lookup_value
            queryset = queryset.filter(**lookup_kwargs)
        else:
            # Ignore constraints with excluded fields.
            if exclude:
                for expression in self.expressions:
                    if hasattr(expression, "flatten"):
                        for expr in expression.flatten():
                            if isinstance(expr, F) and expr.name in exclude:
                                return
                    elif isinstance(expression, F) and expression.name in exclude:
                        return
            # Map each field reference to the instance's value so that every
            # expression can be compared against its own value computed from
            # the instance.
            replacements = {
                F(field): value
                for field, value in instance._get_field_value_map(
                    meta=model._meta, exclude=exclude
                ).items()
            }
            expressions = []
            for expr in self.expressions:
                # Ignore ordering: an OrderBy wrapper (e.g. F("x").desc())
                # is not usable as the left-hand side of an Exact lookup, so
                # compare on the wrapped expression instead.
                if isinstance(expr, OrderBy):
                    expr = expr.expression
                expressions.append(Exact(expr, expr.replace_expressions(replacements)))
            queryset = queryset.filter(*expressions)
        model_class_pk = instance._get_pk_val(model._meta)
        if not instance._state.adding and model_class_pk is not None:
            # Don't let an existing row conflict with itself.
            queryset = queryset.exclude(pk=model_class_pk)
        if not self.condition:
            if queryset.exists():
                if self.expressions:
                    raise ValidationError(self.get_violation_error_message())
                # When fields are defined, use the unique_error_message() for
                # backward compatibility.
                # NOTE: the loop rebinds ``model`` to the model paired with
                # the constraint list that contains this very object.
                for model, constraints in instance.get_constraints():
                    for constraint in constraints:
                        if constraint is self:
                            raise ValidationError(
                                instance.unique_error_message(model, self.fields)
                            )
        else:
            against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
            try:
                if (self.condition & Exists(queryset.filter(self.condition))).check(
                    against, using=using
                ):
                    raise ValidationError(self.get_violation_error_message())
            except FieldError:
                pass