Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Primary Support of UniqueConstraint #7438

Merged
merged 1 commit into from
Mar 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
82 changes: 46 additions & 36 deletions rest_framework/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1398,6 +1398,23 @@ def get_extra_kwargs(self):

return extra_kwargs

def get_unique_together_constraints(self, model):
"""
Returns iterator of (fields, queryset), each entry describes an unique together
constraint on `fields` in `queryset`.
"""
for parent_class in [model] + list(model._meta.parents):
for unique_together in parent_class._meta.unique_together:
yield unique_together, model._default_manager
for constraint in parent_class._meta.constraints:
if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1:
yield (
constraint.fields,
model._default_manager
if constraint.condition is None
else model._default_manager.filter(constraint.condition)
)

def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs):
"""
Return any additional field options that need to be included as a
Expand Down Expand Up @@ -1426,12 +1443,11 @@ def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs

unique_constraint_names -= {None}

# Include each of the `unique_together` field names,
# Include each of the `unique_together` and `UniqueConstraint` field names,
# so long as all the field names are included on the serializer.
for parent_class in [model] + list(model._meta.parents):
for unique_together_list in parent_class._meta.unique_together:
if set(field_names).issuperset(unique_together_list):
unique_constraint_names |= set(unique_together_list)
for unique_together_list, queryset in self.get_unique_together_constraints(model):
if set(field_names).issuperset(unique_together_list):
unique_constraint_names |= set(unique_together_list)

# Now we have all the field names that have uniqueness constraints
# applied, we can add the extra 'required=...' or 'default=...'
Expand Down Expand Up @@ -1526,11 +1542,6 @@ def get_unique_together_validators(self):
"""
Determine a default set of validators for any unique_together constraints.
"""
model_class_inheritance_tree = (
[self.Meta.model] +
list(self.Meta.model._meta.parents)
)

# The field names we're passing though here only include fields
# which may map onto a model field. Any dotted field name lookups
# cannot map to a field, and must be a traversal, so we're not
Expand All @@ -1556,34 +1567,33 @@ def get_unique_together_validators(self):
# Note that we make sure to check `unique_together` both on the
# base model class, but also on any parent classes.
validators = []
for parent_class in model_class_inheritance_tree:
for unique_together in parent_class._meta.unique_together:
# Skip if serializer does not map to all unique together sources
if not set(source_map).issuperset(unique_together):
continue

for source in unique_together:
assert len(source_map[source]) == 1, (
"Unable to create `UniqueTogetherValidator` for "
"`{model}.{field}` as `{serializer}` has multiple "
"fields ({fields}) that map to this model field. "
"Either remove the extra fields, or override "
"`Meta.validators` with a `UniqueTogetherValidator` "
"using the desired field names."
.format(
model=self.Meta.model.__name__,
serializer=self.__class__.__name__,
field=source,
fields=', '.join(source_map[source]),
)
)
for unique_together, queryset in self.get_unique_together_constraints(self.Meta.model):
# Skip if serializer does not map to all unique together sources
if not set(source_map).issuperset(unique_together):
continue

field_names = tuple(source_map[f][0] for f in unique_together)
validator = UniqueTogetherValidator(
queryset=parent_class._default_manager,
fields=field_names
for source in unique_together:
assert len(source_map[source]) == 1, (
"Unable to create `UniqueTogetherValidator` for "
"`{model}.{field}` as `{serializer}` has multiple "
"fields ({fields}) that map to this model field. "
"Either remove the extra fields, or override "
"`Meta.validators` with a `UniqueTogetherValidator` "
"using the desired field names."
.format(
model=self.Meta.model.__name__,
serializer=self.__class__.__name__,
field=source,
fields=', '.join(source_map[source]),
)
)
validators.append(validator)

field_names = tuple(source_map[f][0] for f in unique_together)
validator = UniqueTogetherValidator(
queryset=queryset,
fields=field_names
)
validators.append(validator)
return validators

def get_unique_for_date_validators(self):
Expand Down
29 changes: 24 additions & 5 deletions rest_framework/utils/field_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,29 @@ def get_detail_view_name(model):
}


def get_unique_validators(field_name, model_field):
"""
Returns a list of UniqueValidators that should be applied to the field.
"""
field_set = set([field_name])
conditions = {
c.condition
for c in model_field.model._meta.constraints
if isinstance(c, models.UniqueConstraint) and set(c.fields) == field_set
}
if getattr(model_field, 'unique', False):
conditions.add(None)
if not conditions:
return
unique_error_message = get_unique_error_message(model_field)
queryset = model_field.model._default_manager
for condition in conditions:
yield UniqueValidator(
queryset=queryset if condition is None else queryset.filter(condition),
message=unique_error_message
)


def get_field_kwargs(field_name, model_field):
"""
Creates a default instance of a basic non-relational field.
Expand Down Expand Up @@ -216,11 +239,7 @@ def get_field_kwargs(field_name, model_field):
if not isinstance(validator, validators.MinLengthValidator)
]

if getattr(model_field, 'unique', False):
validator = UniqueValidator(
queryset=model_field.model._default_manager,
message=get_unique_error_message(model_field))
validator_kwarg.append(validator)
validator_kwarg += get_unique_validators(field_name, model_field)

if validator_kwarg:
kwargs['validators'] = validator_kwarg
Expand Down
100 changes: 100 additions & 0 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,106 @@ def filter(self, **kwargs):
assert queryset.called_with == {'race_name': 'bar', 'position': 1}


class UniqueConstraintModel(models.Model):
race_name = models.CharField(max_length=100)
position = models.IntegerField()
global_id = models.IntegerField()
fancy_conditions = models.IntegerField(null=True)

class Meta:
constraints = [
models.UniqueConstraint(
name="unique_constraint_model_global_id_uniq",
fields=('global_id',),
),
models.UniqueConstraint(
name="unique_constraint_model_fancy_1_uniq",
fields=('fancy_conditions',),
condition=models.Q(global_id__lte=1)
),
models.UniqueConstraint(
name="unique_constraint_model_fancy_3_uniq",
fields=('fancy_conditions',),
condition=models.Q(global_id__gte=3)
),
models.UniqueConstraint(
name="unique_constraint_model_together_uniq",
fields=('race_name', 'position'),
condition=models.Q(race_name='example'),
)
]


class UniqueConstraintSerializer(serializers.ModelSerializer):
class Meta:
model = UniqueConstraintModel
fields = '__all__'


class TestUniqueConstraintValidation(TestCase):
def setUp(self):
self.instance = UniqueConstraintModel.objects.create(
race_name='example',
position=1,
global_id=1
)
UniqueConstraintModel.objects.create(
race_name='example',
position=2,
global_id=2
)
UniqueConstraintModel.objects.create(
race_name='other',
position=1,
global_id=3
)

def test_repr(self):
serializer = UniqueConstraintSerializer()
# the order of validators isn't deterministic so delete
# fancy_conditions field that has two of them
del serializer.fields['fancy_conditions']
expected = dedent("""
UniqueConstraintSerializer():
id = IntegerField(label='ID', read_only=True)
race_name = CharField(max_length=100, required=True)
position = IntegerField(required=True)
global_id = IntegerField(validators=[<UniqueValidator(queryset=UniqueConstraintModel.objects.all())>])
class Meta:
validators = [<UniqueTogetherValidator(queryset=<QuerySet [<UniqueConstraintModel: UniqueConstraintModel object (1)>, <UniqueConstraintModel: UniqueConstraintModel object (2)>]>, fields=('race_name', 'position'))>]
""")
assert repr(serializer) == expected

def test_unique_together_field(self):
"""
UniqueConstraint fields and condition attributes must be passed
to UniqueTogetherValidator as fields and queryset
"""
serializer = UniqueConstraintSerializer()
assert len(serializer.validators) == 1
validator = serializer.validators[0]
assert validator.fields == ('race_name', 'position')
assert set(validator.queryset.values_list(flat=True)) == set(
UniqueConstraintModel.objects.filter(race_name='example').values_list(flat=True)
)

def test_single_field_uniq_validators(self):
"""
UniqueConstraint with single field must be transformed into
field's UniqueValidator
"""
serializer = UniqueConstraintSerializer()
assert len(serializer.validators) == 1
validators = serializer.fields['global_id'].validators
assert len(validators) == 1
assert validators[0].queryset == UniqueConstraintModel.objects

validators = serializer.fields['fancy_conditions'].validators
assert len(validators) == 2
ids_in_qs = {frozenset(v.queryset.values_list(flat=True)) for v in validators}
assert ids_in_qs == {frozenset([1]), frozenset([3])}


# Tests for `UniqueForDateValidator`
# ----------------------------------

Expand Down