Skip to content

Commit

Permalink
Addresses art049#162
Browse files Browse the repository at this point in the history
Allow passing exclude keyword arguments to engine.save.
  • Loading branch information
Nicolas Martinez committed Sep 7, 2021
1 parent f20f08f commit 9d3643c
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 6 deletions.
32 changes: 28 additions & 4 deletions odmantic/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,13 @@ async def find_one(
return results[0]

async def _save(
self, instance: ModelType, session: AsyncIOMotorClientSession
self,
instance: ModelType,
session: AsyncIOMotorClientSession,
*,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> ModelType:
"""Perform an atomic save operation in the specified session"""
save_tasks = []
Expand All @@ -310,7 +316,12 @@ async def _save(
instance.__fields_modified__ | instance.__mutable_fields__
) - set([instance.__primary_field__])
if len(fields_to_update) > 0:
doc = instance.doc(include=fields_to_update)
doc = instance.doc(
include=fields_to_update,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
exclude_unset=exclude_unset,
)
collection = self.get_collection(type(instance))
await collection.update_one(
{"_id": getattr(instance, instance.__primary_field__)},
Expand All @@ -320,7 +331,14 @@ async def _save(
object.__setattr__(instance, "__fields_modified__", set())
return instance

async def save(self, instance: ModelType) -> ModelType:
async def save(
self,
instance: ModelType,
*,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> ModelType:
"""Persist an instance to the database
This method behaves as an 'upsert' operation. If a document already exists
Expand All @@ -347,7 +365,13 @@ async def save(self, instance: ModelType) -> ModelType:

async with await self.client.start_session() as s:
async with s.start_transaction():
await self._save(instance, s)
await self._save(
instance,
s,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
return instance

async def save_all(self, instances: Sequence[ModelType]) -> List[ModelType]:
Expand Down
15 changes: 13 additions & 2 deletions odmantic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,14 @@ def dict( # type: ignore # Missing deprecated/ unsupported parameters
exclude_none=exclude_none,
)

def doc(self, include: Optional["AbstractSetIntStr"] = None) -> Dict[str, Any]:
def doc(
self,
include: Optional["AbstractSetIntStr"] = None,
*,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> Dict[str, Any]:
"""Generate a document representation of the instance (as a dictionary).
Args:
Expand All @@ -648,7 +655,11 @@ def doc(self, include: Optional["AbstractSetIntStr"] = None) -> Dict[str, Any]:
Returns:
the document associated to the instance
"""
raw_doc = self.dict()
raw_doc = self.dict(
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
doc: Dict[str, Any] = {}
for field_name, field in self.__odm_fields__.items():
if include is not None and field_name not in include:
Expand Down

0 comments on commit 9d3643c

Please sign in to comment.