Skip to content

Commit

Permalink
[4.1.x] Fixed #33829 -- Made BaseConstraint.deconstruct() and equalit…
Browse files Browse the repository at this point in the history
…y handle violation_error_message.

Regression in 6671058.

Backport of ccbf714 from main
  • Loading branch information
twidi authored and felixxm committed Jul 8, 2022
1 parent 585ed2f commit a3d35af
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 3 deletions.
1 change: 1 addition & 0 deletions django/contrib/postgres/constraints.py
Expand Up @@ -186,6 +186,7 @@ def __eq__(self, other):
and self.deferrable == other.deferrable
and self.include == other.include
and self.opclasses == other.opclasses
and self.violation_error_message == other.violation_error_message
)
return super().__eq__(other)

Expand Down
20 changes: 17 additions & 3 deletions django/db/models/constraints.py
Expand Up @@ -14,12 +14,15 @@


class BaseConstraint:
violation_error_message = _("Constraint “%(name)s” is violated.")
default_violation_error_message = _("Constraint “%(name)s” is violated.")
violation_error_message = None

def __init__(self, name, violation_error_message=None):
self.name = name
if violation_error_message is not None:
self.violation_error_message = violation_error_message
else:
self.violation_error_message = self.default_violation_error_message

@property
def contains_expressions(self):
Expand All @@ -43,7 +46,13 @@ def get_violation_error_message(self):
def deconstruct(self):
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
path = path.replace("django.db.models.constraints", "django.db.models")
return (path, (), {"name": self.name})
kwargs = {"name": self.name}
if (
self.violation_error_message is not None
and self.violation_error_message != self.default_violation_error_message
):
kwargs["violation_error_message"] = self.violation_error_message
return (path, (), kwargs)

def clone(self):
_, args, kwargs = self.deconstruct()
Expand Down Expand Up @@ -94,7 +103,11 @@ def __repr__(self):

def __eq__(self, other):
if isinstance(other, CheckConstraint):
return self.name == other.name and self.check == other.check
return (
self.name == other.name
and self.check == other.check
and self.violation_error_message == other.violation_error_message
)
return super().__eq__(other)

def deconstruct(self):
Expand Down Expand Up @@ -273,6 +286,7 @@ def __eq__(self, other):
and self.include == other.include
and self.opclasses == other.opclasses
and self.expressions == other.expressions
and self.violation_error_message == other.violation_error_message
)
return super().__eq__(other)

Expand Down
77 changes: 77 additions & 0 deletions tests/constraints/tests.py
Expand Up @@ -65,6 +65,29 @@ def test_custom_violation_error_message(self):
)
self.assertEqual(c.get_violation_error_message(), "custom base_name message")

def test_custom_violation_error_message_clone(self):
constraint = BaseConstraint(
"base_name",
violation_error_message="custom %(name)s message",
).clone()
self.assertEqual(
constraint.get_violation_error_message(),
"custom base_name message",
)

def test_deconstruction(self):
constraint = BaseConstraint(
"base_name",
violation_error_message="custom %(name)s message",
)
path, args, kwargs = constraint.deconstruct()
self.assertEqual(path, "django.db.models.BaseConstraint")
self.assertEqual(args, ())
self.assertEqual(
kwargs,
{"name": "base_name", "violation_error_message": "custom %(name)s message"},
)


class CheckConstraintTests(TestCase):
def test_eq(self):
Expand All @@ -84,6 +107,28 @@ def test_eq(self):
models.CheckConstraint(check=check2, name="price"),
)
self.assertNotEqual(models.CheckConstraint(check=check1, name="price"), 1)
self.assertNotEqual(
models.CheckConstraint(check=check1, name="price"),
models.CheckConstraint(
check=check1, name="price", violation_error_message="custom error"
),
)
self.assertNotEqual(
models.CheckConstraint(
check=check1, name="price", violation_error_message="custom error"
),
models.CheckConstraint(
check=check1, name="price", violation_error_message="other custom error"
),
)
self.assertEqual(
models.CheckConstraint(
check=check1, name="price", violation_error_message="custom error"
),
models.CheckConstraint(
check=check1, name="price", violation_error_message="custom error"
),
)

def test_repr(self):
constraint = models.CheckConstraint(
Expand Down Expand Up @@ -216,6 +261,38 @@ def test_eq(self):
self.assertNotEqual(
models.UniqueConstraint(fields=["foo", "bar"], name="unique"), 1
)
self.assertNotEqual(
models.UniqueConstraint(fields=["foo", "bar"], name="unique"),
models.UniqueConstraint(
fields=["foo", "bar"],
name="unique",
violation_error_message="custom error",
),
)
self.assertNotEqual(
models.UniqueConstraint(
fields=["foo", "bar"],
name="unique",
violation_error_message="custom error",
),
models.UniqueConstraint(
fields=["foo", "bar"],
name="unique",
violation_error_message="other custom error",
),
)
self.assertEqual(
models.UniqueConstraint(
fields=["foo", "bar"],
name="unique",
violation_error_message="custom error",
),
models.UniqueConstraint(
fields=["foo", "bar"],
name="unique",
violation_error_message="custom error",
),
)

def test_eq_with_condition(self):
self.assertEqual(
Expand Down
22 changes: 22 additions & 0 deletions tests/postgres_tests/test_constraints.py
Expand Up @@ -444,17 +444,39 @@ def test_eq(self):
)
self.assertNotEqual(constraint_2, constraint_9)
self.assertNotEqual(constraint_7, constraint_8)

constraint_10 = ExclusionConstraint(
name="exclude_overlapping",
expressions=[
(F("datespan"), RangeOperators.OVERLAPS),
(F("room"), RangeOperators.EQUAL),
],
condition=Q(cancelled=False),
violation_error_message="custom error",
)
constraint_11 = ExclusionConstraint(
name="exclude_overlapping",
expressions=[
(F("datespan"), RangeOperators.OVERLAPS),
(F("room"), RangeOperators.EQUAL),
],
condition=Q(cancelled=False),
violation_error_message="other custom error",
)
self.assertEqual(constraint_1, constraint_1)
self.assertEqual(constraint_1, mock.ANY)
self.assertNotEqual(constraint_1, constraint_2)
self.assertNotEqual(constraint_1, constraint_3)
self.assertNotEqual(constraint_1, constraint_4)
self.assertNotEqual(constraint_1, constraint_10)
self.assertNotEqual(constraint_2, constraint_3)
self.assertNotEqual(constraint_2, constraint_4)
self.assertNotEqual(constraint_2, constraint_7)
self.assertNotEqual(constraint_4, constraint_5)
self.assertNotEqual(constraint_5, constraint_6)
self.assertNotEqual(constraint_1, object())
self.assertNotEqual(constraint_10, constraint_11)
self.assertEqual(constraint_10, constraint_10)

def test_deconstruct(self):
constraint = ExclusionConstraint(
Expand Down

0 comments on commit a3d35af

Please sign in to comment.