Skip to content

Commit

Permalink
Support UniqueConstraint
Browse files Browse the repository at this point in the history
  • Loading branch information
kalekseev committed Jul 29, 2020
1 parent 599e2b1 commit 9a23276
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 36 deletions.
78 changes: 42 additions & 36 deletions rest_framework/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1373,6 +1373,17 @@ def get_extra_kwargs(self):

return extra_kwargs

def get_unique_together_constraints(self, model):
for parent_class in [model] + list(model._meta.parents):
for unique_together in parent_class._meta.unique_together:
yield unique_together, model._default_manager
for constraint in parent_class._meta.constraints:
if isinstance(constraint, models.UniqueConstraint):
yield (
constraint.fields,
model._default_manager.filter(constraint.condition) if constraint.condition else model._default_manager
)

def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs):
"""
Return any additional field options that need to be included as a
Expand Down Expand Up @@ -1401,12 +1412,11 @@ def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs

unique_constraint_names -= {None}

# Include each of the `unique_together` field names,
# Include each of the `unique_together` and `UniqueConstraint` field names,
# so long as all the field names are included on the serializer.
for parent_class in [model] + list(model._meta.parents):
for unique_together_list in parent_class._meta.unique_together:
if set(field_names).issuperset(set(unique_together_list)):
unique_constraint_names |= set(unique_together_list)
for unique_together_list, queryset in self.get_unique_together_constraints(model):
if set(field_names).issuperset(set(unique_together_list)):
unique_constraint_names |= set(unique_together_list)

# Now we have all the field names that have uniqueness constraints
# applied, we can add the extra 'required=...' or 'default=...'
Expand Down Expand Up @@ -1503,11 +1513,6 @@ def get_unique_together_validators(self):
"""
Determine a default set of validators for any unique_together constraints.
"""
model_class_inheritance_tree = (
[self.Meta.model] +
list(self.Meta.model._meta.parents)
)

# The field names we're passing though here only include fields
# which may map onto a model field. Any dotted field name lookups
# cannot map to a field, and must be a traversal, so we're not
Expand All @@ -1533,34 +1538,35 @@ def get_unique_together_validators(self):
# Note that we make sure to check `unique_together` both on the
# base model class, but also on any parent classes.
validators = []
for parent_class in model_class_inheritance_tree:
for unique_together in parent_class._meta.unique_together:
# Skip if serializer does not map to all unique together sources
if not set(source_map).issuperset(set(unique_together)):
continue

for source in unique_together:
assert len(source_map[source]) == 1, (
"Unable to create `UniqueTogetherValidator` for "
"`{model}.{field}` as `{serializer}` has multiple "
"fields ({fields}) that map to this model field. "
"Either remove the extra fields, or override "
"`Meta.validators` with a `UniqueTogetherValidator` "
"using the desired field names."
.format(
model=self.Meta.model.__name__,
serializer=self.__class__.__name__,
field=source,
fields=', '.join(source_map[source]),
)
)
for unique_together, queryset in self.get_unique_together_constraints(self.Meta.model):
if len(unique_together) < 2:
continue
# Skip if serializer does not map to all unique together sources
if not set(source_map).issuperset(set(unique_together)):
continue

field_names = tuple(source_map[f][0] for f in unique_together)
validator = UniqueTogetherValidator(
queryset=parent_class._default_manager,
fields=field_names
for source in unique_together:
assert len(source_map[source]) == 1, (
"Unable to create `UniqueTogetherValidator` for "
"`{model}.{field}` as `{serializer}` has multiple "
"fields ({fields}) that map to this model field. "
"Either remove the extra fields, or override "
"`Meta.validators` with a `UniqueTogetherValidator` "
"using the desired field names."
.format(
model=self.Meta.model.__name__,
serializer=self.__class__.__name__,
field=source,
fields=', '.join(source_map[source]),
)
)
validators.append(validator)

field_names = tuple(source_map[f][0] for f in unique_together)
validator = UniqueTogetherValidator(
queryset=queryset,
fields=field_names
)
validators.append(validator)
return validators

def get_unique_for_date_validators(self):
Expand Down
70 changes: 70 additions & 0 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,76 @@ def filter(self, **kwargs):
assert queryset.called_with == {'race_name': 'bar', 'position': 1}


class UniqueConstraintModel(models.Model):
race_name = models.CharField(max_length=100)
position = models.IntegerField()
global_id = models.IntegerField()

class Meta:
constraints = [
models.UniqueConstraint(
name="unique_constraint_model_global_id_uniq",
fields=('global_id',),
),
models.UniqueConstraint(
name="unique_constraint_model_together_uniq",
fields=('race_name', 'position'),
condition=models.Q(race_name='example'),
)
]


class UniqueConstraintSerializer(serializers.ModelSerializer):
class Meta:
model = UniqueConstraintModel
fields = '__all__'


class TestUniqueConstraintValidation(TestCase):
def setUp(self):
self.instance = UniqueConstraintModel.objects.create(
race_name='example',
position=1,
global_id=1
)
UniqueConstraintModel.objects.create(
race_name='example',
position=2,
global_id=2
)
UniqueConstraintModel.objects.create(
race_name='other',
position=1,
global_id=3
)

def test_repr(self):
serializer = UniqueConstraintSerializer()
expected = dedent("""
UniqueConstraintSerializer():
id = IntegerField(label='ID', read_only=True)
race_name = CharField(max_length=100, required=True)
position = IntegerField(required=True)
global_id = IntegerField(required=True)
class Meta:
validators = [<UniqueTogetherValidator(queryset=<QuerySet [<UniqueConstraintModel: UniqueConstraintModel object (1)>, <UniqueConstraintModel: UniqueConstraintModel object (2)>]>, fields=('race_name', 'position'))>]
""")
assert repr(serializer) == expected

def test_fields_and_queryset(self):
"""
UniqueConstraint fields and condition attributes must be passed
to UniqueTogetherValidator as fields and queryset
"""
serializer = UniqueConstraintSerializer()
assert len(serializer.validators) == 1
validator = serializer.validators[0]
assert validator.fields == ('race_name', 'position')
assert set(validator.queryset.values_list(flat=True)) == set(
UniqueConstraintModel.objects.filter(race_name='example').values_list(flat=True)
)


# Tests for `UniqueForDateValidator`
# ----------------------------------

Expand Down

0 comments on commit 9a23276

Please sign in to comment.