Skip to content

Commit 11cf91f

Browse files
committed
Refs #31622 -- Implement CREATE CONSTRAINT TRIGGER on Postgres
1 parent 221feb6 commit 11cf91f

File tree

1 file changed

+104
-3
lines changed

1 file changed

+104
-3
lines changed

django/contrib/postgres/constraints.py

Lines changed: 104 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,110 @@
1+
from enum import Enum
2+
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
3+
14
from django.db.backends.ddl_references import Statement, Table
2-
from django.db.models import Deferrable, F, Q
3-
from django.db.models.constraints import BaseConstraint
5+
from django.db.models import Deferrable, F, Func, Q
6+
from django.db.models.constraints import BaseConstraint, Deferrable
7+
from django.db.models.expressions import BaseExpression
48
from django.db.models.sql import Query
59

6-
__all__ = ['ExclusionConstraint']
10+
__all__ = ['ExclusionConstraint', 'ConstraintTrigger', 'TriggerEvent']
11+
12+
13+
class TriggerEvent(Enum):
14+
INSERT = "INSERT"
15+
UPDATE = "UPDATE"
16+
DELETE = "DELETE"
17+
18+
19+
_TriggerEventLike = Union[Literal["INSERT", "UPDATE", "DELETE"], TriggerEvent]
20+
21+
22+
class ConstraintTrigger(BaseConstraint):
23+
template = """
24+
CREATE CONSTRAINT TRIGGER %(name)s
25+
AFTER %(events)s ON %(table)s %(deferrable)s
26+
FOR EACH ROW %(condition)s
27+
EXECUTE PROCEDURE %(procedure)s
28+
""".strip()
29+
delete_template = "DROP TRIGGER %(name)s ON %(table)s"
30+
31+
def __init__(
32+
self,
33+
*,
34+
name: str,
35+
events: Union[List[_TriggerEventLike], Tuple[_TriggerEventLike, ...]],
36+
function: Func,
37+
condition: Optional[BaseExpression] = None,
38+
deferrable: Optional[Deferrable] = None,
39+
):
40+
if not events:
41+
raise ValueError(
42+
"ConstraintTrigger events must be a list of at least one TriggerEvent"
43+
)
44+
self.events = tuple(
45+
e.value if isinstance(e, TriggerEvent) else str(e).upper() for e in events
46+
)
47+
self.function = function
48+
self.condition = condition
49+
self.deferrable = deferrable
50+
super().__init__(name)
51+
52+
def __eq__(self, other):
53+
if isinstance(other, self.__class__):
54+
return (
55+
self.name == other.name
56+
and set(self.events) == set(other.events)
57+
and self.function == other.function
58+
and self.condition == other.condition
59+
and self.deferrable == other.deferrable
60+
)
61+
return super().__eq__(other)
62+
63+
def _get_condition_sql(self, compiler, schema_editor, query) -> str:
64+
if self.condition is None:
65+
return ""
66+
sql, params = self.condition.as_sql(compiler, schema_editor.connection)
67+
condition_sql = sql % tuple(schema_editor.quote_value(p) for p in params)
68+
return "WHEN %s" % (condition_sql)
69+
70+
def _get_procedure_sql(self, compiler, schema_editor) -> str:
71+
sql, params = self.function.as_sql(compiler, schema_editor.connection)
72+
return sql % tuple(schema_editor.quote_value(p) for p in params)
73+
74+
def create_sql(self, model, schema_editor) -> Statement:
75+
table = Table(model._meta.db_table, schema_editor.quote_name)
76+
query = Query(model, alias_cols=False)
77+
compiler = query.get_compiler(connection=schema_editor.connection)
78+
condition = self._get_condition_sql(compiler, schema_editor, query)
79+
return Statement(
80+
self.template,
81+
name=schema_editor.quote_name(self.name),
82+
events=" OR ".join(self.events),
83+
table=table,
84+
condition=condition,
85+
deferrable=schema_editor._deferrable_constraint_sql(self.deferrable),
86+
procedure=self._get_procedure_sql(compiler, schema_editor),
87+
)
88+
89+
def remove_sql(self, model, schema_editor) -> Statement:
90+
return Statement(
91+
self.delete_template,
92+
table=Table(model._meta.db_table, schema_editor.quote_name),
93+
name=schema_editor.quote_name(self.name),
94+
)
95+
96+
def deconstruct(self) -> Tuple[str, Tuple[Any, ...], Dict[str, Any]]:
97+
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
98+
kwargs = {
99+
"name": self.name,
100+
"events": self.events,
101+
"function": self.function,
102+
}
103+
if self.condition:
104+
kwargs["condition"] = self.condition
105+
if self.deferrable is not None:
106+
kwargs["deferrable"] = self.deferrable
107+
return path, (), kwargs
7108

8109

9110
class ExclusionConstraint(BaseConstraint):

0 commit comments

Comments
 (0)