Skip to content

Commit

Permalink
Add flag to prevent operations on manytomany field w/unsaved instances.
Browse files Browse the repository at this point in the history
By default, attempting to read or write a many-to-many field on an
unsaved model instance will now raise an exception. To disable this
behavior, specify `prevent_unsaved=False` when initializing your
ManyToManyField.

Refs #2765
  • Loading branch information
coleifer committed Aug 11, 2023
1 parent f107a90 commit ac30c56
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion peewee.py
Original file line number Diff line number Diff line change
Expand Up @@ -5574,6 +5574,9 @@ def __get__(self, instance, instance_type=None, force_query=False):
return [getattr(obj, self.dest_fk.name) for obj in backref]

src_id = getattr(instance, self.src_fk.rel_field.name)
if src_id is None and self.field._prevent_unsaved:
raise ValueError('Cannot get many-to-many "%s" for unsaved '
'instance "%s".' % (self.field, instance))
return (ManyToManyQuery(instance, self, self.rel_model)
.join(self.through_model)
.join(self.model)
Expand All @@ -5582,6 +5585,10 @@ def __get__(self, instance, instance_type=None, force_query=False):
return self.field

def __set__(self, instance, value):
src_id = getattr(instance, self.src_fk.rel_field.name)
if src_id is None and self.field._prevent_unsaved:
raise ValueError('Cannot set many-to-many "%s" for unsaved '
'instance "%s".' % (self.field, instance))
query = self.__get__(instance, force_query=True)
query.add(value, clear_existing=True)

Expand All @@ -5590,7 +5597,7 @@ class ManyToManyField(MetaField):
accessor_class = ManyToManyFieldAccessor

def __init__(self, model, backref=None, through_model=None, on_delete=None,
on_update=None, _is_backref=False):
on_update=None, prevent_unsaved=True, _is_backref=False):
if through_model is not None:
if not (isinstance(through_model, DeferredThroughModel) or
is_model(through_model)):
Expand All @@ -5604,6 +5611,7 @@ def __init__(self, model, backref=None, through_model=None, on_delete=None,
self._through_model = through_model
self._on_delete = on_delete
self._on_update = on_update
self._prevent_unsaved = prevent_unsaved
self._is_backref = _is_backref

def _get_descriptor(self):
Expand Down

0 comments on commit ac30c56

Please sign in to comment.