|
| 1 | +from enum import Enum |
| 2 | +from typing import Any, Dict, List, Literal, Optional, Tuple, Union |
| 3 | + |
1 | 4 | from django.db.backends.ddl_references import Statement, Table
|
2 |
| -from django.db.models import Deferrable, F, Q |
3 |
| -from django.db.models.constraints import BaseConstraint |
| 5 | +from django.db.models import Deferrable, F, Func, Q |
| 6 | +from django.db.models.constraints import BaseConstraint, Deferrable |
| 7 | +from django.db.models.expressions import BaseExpression |
4 | 8 | from django.db.models.sql import Query
|
5 | 9 |
|
6 |
| -__all__ = ['ExclusionConstraint'] |
| 10 | +__all__ = ['ExclusionConstraint', 'ConstraintTrigger', 'TriggerEvent'] |
| 11 | + |
| 12 | + |
| 13 | +class TriggerEvent(Enum): |
| 14 | + INSERT = "INSERT" |
| 15 | + UPDATE = "UPDATE" |
| 16 | + DELETE = "DELETE" |
| 17 | + |
| 18 | + |
| 19 | +_TriggerEventLike = Union[Literal["INSERT", "UPDATE", "DELETE"], TriggerEvent] |
| 20 | + |
| 21 | + |
| 22 | +class ConstraintTrigger(BaseConstraint): |
| 23 | + template = """ |
| 24 | + CREATE CONSTRAINT TRIGGER %(name)s |
| 25 | + AFTER %(events)s ON %(table)s %(deferrable)s |
| 26 | + FOR EACH ROW %(condition)s |
| 27 | + EXECUTE PROCEDURE %(procedure)s |
| 28 | + """.strip() |
| 29 | + delete_template = "DROP TRIGGER %(name)s ON %(table)s" |
| 30 | + |
| 31 | + def __init__( |
| 32 | + self, |
| 33 | + *, |
| 34 | + name: str, |
| 35 | + events: Union[List[_TriggerEventLike], Tuple[_TriggerEventLike, ...]], |
| 36 | + function: Func, |
| 37 | + condition: Optional[BaseExpression] = None, |
| 38 | + deferrable: Optional[Deferrable] = None, |
| 39 | + ): |
| 40 | + if not events: |
| 41 | + raise ValueError( |
| 42 | + "ConstraintTrigger events must be a list of at least one TriggerEvent" |
| 43 | + ) |
| 44 | + self.events = tuple( |
| 45 | + e.value if isinstance(e, TriggerEvent) else str(e).upper() for e in events |
| 46 | + ) |
| 47 | + self.function = function |
| 48 | + self.condition = condition |
| 49 | + self.deferrable = deferrable |
| 50 | + super().__init__(name) |
| 51 | + |
| 52 | + def __eq__(self, other): |
| 53 | + if isinstance(other, self.__class__): |
| 54 | + return ( |
| 55 | + self.name == other.name |
| 56 | + and set(self.events) == set(other.events) |
| 57 | + and self.function == other.function |
| 58 | + and self.condition == other.condition |
| 59 | + and self.deferrable == other.deferrable |
| 60 | + ) |
| 61 | + return super().__eq__(other) |
| 62 | + |
| 63 | + def _get_condition_sql(self, compiler, schema_editor, query) -> str: |
| 64 | + if self.condition is None: |
| 65 | + return "" |
| 66 | + sql, params = self.condition.as_sql(compiler, schema_editor.connection) |
| 67 | + condition_sql = sql % tuple(schema_editor.quote_value(p) for p in params) |
| 68 | + return "WHEN %s" % (condition_sql) |
| 69 | + |
| 70 | + def _get_procedure_sql(self, compiler, schema_editor) -> str: |
| 71 | + sql, params = self.function.as_sql(compiler, schema_editor.connection) |
| 72 | + return sql % tuple(schema_editor.quote_value(p) for p in params) |
| 73 | + |
| 74 | + def create_sql(self, model, schema_editor) -> Statement: |
| 75 | + table = Table(model._meta.db_table, schema_editor.quote_name) |
| 76 | + query = Query(model, alias_cols=False) |
| 77 | + compiler = query.get_compiler(connection=schema_editor.connection) |
| 78 | + condition = self._get_condition_sql(compiler, schema_editor, query) |
| 79 | + return Statement( |
| 80 | + self.template, |
| 81 | + name=schema_editor.quote_name(self.name), |
| 82 | + events=" OR ".join(self.events), |
| 83 | + table=table, |
| 84 | + condition=condition, |
| 85 | + deferrable=schema_editor._deferrable_constraint_sql(self.deferrable), |
| 86 | + procedure=self._get_procedure_sql(compiler, schema_editor), |
| 87 | + ) |
| 88 | + |
| 89 | + def remove_sql(self, model, schema_editor) -> Statement: |
| 90 | + return Statement( |
| 91 | + self.delete_template, |
| 92 | + table=Table(model._meta.db_table, schema_editor.quote_name), |
| 93 | + name=schema_editor.quote_name(self.name), |
| 94 | + ) |
| 95 | + |
| 96 | + def deconstruct(self) -> Tuple[str, Tuple[Any, ...], Dict[str, Any]]: |
| 97 | + path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__) |
| 98 | + kwargs = { |
| 99 | + "name": self.name, |
| 100 | + "events": self.events, |
| 101 | + "function": self.function, |
| 102 | + } |
| 103 | + if self.condition: |
| 104 | + kwargs["condition"] = self.condition |
| 105 | + if self.deferrable is not None: |
| 106 | + kwargs["deferrable"] = self.deferrable |
| 107 | + return path, (), kwargs |
7 | 108 |
|
8 | 109 |
|
9 | 110 | class ExclusionConstraint(BaseConstraint):
|
|
0 commit comments