Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement __eq__ for validators #8925

Merged
merged 2 commits into from Apr 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/api-guide/validators.md
Expand Up @@ -53,7 +53,7 @@ If we open up the Django shell using `manage.py shell` we can now

The interesting bit here is the `reference` field. We can see that the uniqueness constraint is being explicitly enforced by a validator on the serializer field.

Because of this more explicit style REST framework includes a few validator classes that are not available in core Django. These classes are detailed below.
Because of this more explicit style REST framework includes a few validator classes that are not available in core Django. These classes are detailed below. REST framework validators, like their Django counterparts, implement the `__eq__` method, allowing you to compare instances for equality.

---

Expand Down
37 changes: 37 additions & 0 deletions rest_framework/validators.py
Expand Up @@ -79,6 +79,15 @@ def __repr__(self):
smart_repr(self.queryset)
)

def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (self.message == other.message
and self.requires_context == other.requires_context
and self.queryset == other.queryset
and self.lookup == other.lookup
)


class UniqueTogetherValidator:
"""
Expand Down Expand Up @@ -166,6 +175,16 @@ def __repr__(self):
smart_repr(self.fields)
)

def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (self.message == other.message
and self.requires_context == other.requires_context
and self.missing_message == other.missing_message
and self.queryset == other.queryset
and self.fields == other.fields
)


class ProhibitSurrogateCharactersValidator:
message = _('Surrogate characters are not allowed: U+{code_point:X}.')
Expand All @@ -177,6 +196,13 @@ def __call__(self, value):
message = self.message.format(code_point=ord(surrogate_character))
raise ValidationError(message, code=self.code)

def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (self.message == other.message
and self.code == other.code
)


class BaseUniqueForValidator:
message = None
Expand Down Expand Up @@ -230,6 +256,17 @@ def __call__(self, attrs, serializer):
self.field: message
}, code='unique')

def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (self.message == other.message
and self.missing_message == other.missing_message
and self.requires_context == other.requires_context
and self.queryset == other.queryset
and self.field == other.field
and self.date_field == other.date_field
)

def __repr__(self):
return '<%s(queryset=%s, field=%s, date_field=%s)>' % (
self.__class__.__name__,
Expand Down
11 changes: 11 additions & 0 deletions tests/test_validators.py
@@ -1,4 +1,5 @@
import datetime
from unittest.mock import MagicMock

import pytest
from django.db import DataError, models
Expand Down Expand Up @@ -787,3 +788,13 @@ def test_validator_raises_error_when_abstract_method_called(self):
validator.filter_queryset(
attrs=None, queryset=None, field_name='', date_field_name=''
)

def test_equality_operator(self):
mock_queryset = MagicMock()
validator = BaseUniqueForValidator(queryset=mock_queryset, field='foo',
date_field='bar')
validator2 = BaseUniqueForValidator(queryset=mock_queryset, field='foo',
date_field='bar')
assert validator == validator2
validator2.date_field = "bar2"
assert validator != validator2