Skip to content

Commit

Permalink
fix: adds MISSING placeholder (#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
cofin committed Dec 7, 2023
1 parent 68c8501 commit 2f76af0
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 33 deletions.
10 changes: 5 additions & 5 deletions advanced_alchemy/repository/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)
from advanced_alchemy.operations import Merge
from advanced_alchemy.repository._util import get_instrumented_attr, wrap_sqlalchemy_exception
from advanced_alchemy.repository.typing import ModelT
from advanced_alchemy.repository.typing import MISSING, ModelT
from advanced_alchemy.utils.deprecation import deprecated

if TYPE_CHECKING:
Expand Down Expand Up @@ -547,8 +547,8 @@ async def get_or_upsert(
)
if upsert:
for field_name, new_field_value in kwargs.items():
field = getattr(existing, field_name, None)
if field and field != new_field_value:
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
setattr(existing, field_name, new_field_value)
existing = await self._attach_to_session(existing, strategy="merge")
await self._flush_or_commit(auto_commit=auto_commit)
Expand Down Expand Up @@ -609,8 +609,8 @@ async def get_and_update(
existing = await self.get_one(**match_filter)
updated = False
for field_name, new_field_value in kwargs.items():
field = getattr(existing, field_name, None)
if field and field != new_field_value:
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
updated = True
setattr(existing, field_name, new_field_value)
existing = await self._attach_to_session(existing, strategy="merge")
Expand Down
10 changes: 5 additions & 5 deletions advanced_alchemy/repository/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)
from advanced_alchemy.operations import Merge
from advanced_alchemy.repository._util import get_instrumented_attr, wrap_sqlalchemy_exception
from advanced_alchemy.repository.typing import ModelT
from advanced_alchemy.repository.typing import MISSING, ModelT
from advanced_alchemy.utils.deprecation import deprecated

if TYPE_CHECKING:
Expand Down Expand Up @@ -548,8 +548,8 @@ def get_or_upsert(
)
if upsert:
for field_name, new_field_value in kwargs.items():
field = getattr(existing, field_name, None)
if field and field != new_field_value:
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
setattr(existing, field_name, new_field_value)
existing = self._attach_to_session(existing, strategy="merge")
self._flush_or_commit(auto_commit=auto_commit)
Expand Down Expand Up @@ -610,8 +610,8 @@ def get_and_update(
existing = self.get_one(**match_filter)
updated = False
for field_name, new_field_value in kwargs.items():
field = getattr(existing, field_name, None)
if field and field != new_field_value:
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
updated = True
setattr(existing, field_name, new_field_value)
existing = self._attach_to_session(existing, strategy="merge")
Expand Down
10 changes: 5 additions & 5 deletions advanced_alchemy/repository/memory/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
OrderBy,
SearchFilter,
)
from advanced_alchemy.repository.typing import ModelT
from advanced_alchemy.repository.typing import MISSING, ModelT
from advanced_alchemy.utils.deprecation import deprecated

from .base import AnyObject, CollectionT, InMemoryStore, SQLAlchemyInMemoryStore, SQLAlchemyMultiStore
Expand Down Expand Up @@ -343,8 +343,8 @@ async def get_or_upsert(
return (await self.add(self.model_type(**kwargs_)), True)
if upsert:
for field_name, new_field_value in kwargs_.items():
field = getattr(existing, field_name, None)
if field and field != new_field_value:
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
setattr(existing, field_name, new_field_value)
existing = await self.update(existing)
return existing, False
Expand All @@ -363,8 +363,8 @@ async def get_and_update(self, match_fields: list[str] | str | None = None, **kw
existing = await self.get_one(**match_filter)
updated = False
for field_name, new_field_value in kwargs_.items():
field = getattr(existing, field_name, None)
if field and field != new_field_value:
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
updated = True
setattr(existing, field_name, new_field_value)
existing = await self.update(existing)
Expand Down
10 changes: 5 additions & 5 deletions advanced_alchemy/repository/memory/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
OrderBy,
SearchFilter,
)
from advanced_alchemy.repository.typing import ModelT
from advanced_alchemy.repository.typing import MISSING, ModelT
from advanced_alchemy.utils.deprecation import deprecated

from .base import AnyObject, CollectionT, InMemoryStore, SQLAlchemyInMemoryStore, SQLAlchemyMultiStore
Expand Down Expand Up @@ -346,8 +346,8 @@ def get_or_upsert(
return (self.add(self.model_type(**kwargs_)), True)
if upsert:
for field_name, new_field_value in kwargs_.items():
field = getattr(existing, field_name, None)
if field and field != new_field_value:
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
setattr(existing, field_name, new_field_value)
existing = self.update(existing)
return existing, False
Expand All @@ -366,8 +366,8 @@ def get_and_update(self, match_fields: list[str] | str | None = None, **kwargs:
existing = self.get_one(**match_filter)
updated = False
for field_name, new_field_value in kwargs_.items():
field = getattr(existing, field_name, None)
if field and field != new_field_value:
field = getattr(existing, field_name, MISSING)
if field is not MISSING and field != new_field_value:
updated = True
setattr(existing, field_name, new_field_value)
existing = self.update(existing)
Expand Down
11 changes: 2 additions & 9 deletions advanced_alchemy/repository/memory/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sqlalchemy.orm import RelationshipProperty, Session, class_mapper, object_mapper

from advanced_alchemy.exceptions import AdvancedAlchemyError
from advanced_alchemy.repository.typing import ModelT
from advanced_alchemy.repository.typing import _MISSING, MISSING, ModelT

if TYPE_CHECKING:
from collections.abc import Iterable
Expand All @@ -29,13 +29,6 @@ class _NotSet:
pass


class _MISSING:
pass


MISSING = _MISSING()


class InMemoryStore(Generic[T]):
def __init__(self) -> None:
self._store: dict[Any, T] = {}
Expand Down Expand Up @@ -310,7 +303,7 @@ class Child(Base):
for relationship in obj_mapper.relationships:
for column in relationship.local_columns:
column_relationships[column] = relationship

# sourcery skip: assign-if-exp
if state := inspect(data):
new_attrs: dict[str, Any] = state.dict
else:
Expand Down
8 changes: 8 additions & 0 deletions advanced_alchemy/repository/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"RowT",
"SQLAlchemySyncRepositoryT",
"SQLAlchemyAsyncRepositoryT",
"MISSING",
)

T = TypeVar("T")
Expand All @@ -26,3 +27,10 @@

SQLAlchemySyncRepositoryT = TypeVar("SQLAlchemySyncRepositoryT", bound="SQLAlchemySyncRepository")
SQLAlchemyAsyncRepositoryT = TypeVar("SQLAlchemyAsyncRepositoryT", bound="SQLAlchemyAsyncRepository")


class _MISSING:
pass


MISSING = _MISSING()
8 changes: 4 additions & 4 deletions tests/unit/test_extensions/test_litestar/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def expected_field_defs(int_factory: Callable[[], int]) -> list[DTOFieldDefiniti
name="a",
),
model_name=ANY,
default_factory=Empty,
default_factory=Empty, # type: ignore[arg-type]
dto_field=DTOField(),
),
replace(
Expand All @@ -75,7 +75,7 @@ def expected_field_defs(int_factory: Callable[[], int]) -> list[DTOFieldDefiniti
name="b",
),
model_name=ANY,
default_factory=Empty,
default_factory=Empty, # type: ignore[arg-type]
dto_field=DTOField(mark=Mark.READ_ONLY),
),
metadata=ANY,
Expand All @@ -90,7 +90,7 @@ def expected_field_defs(int_factory: Callable[[], int]) -> list[DTOFieldDefiniti
name="c",
),
model_name=ANY,
default_factory=Empty,
default_factory=Empty, # type: ignore[arg-type]
dto_field=DTOField(),
),
metadata=ANY,
Expand All @@ -106,7 +106,7 @@ def expected_field_defs(int_factory: Callable[[], int]) -> list[DTOFieldDefiniti
default=1,
),
model_name=ANY,
default_factory=Empty,
default_factory=Empty, # type: ignore[arg-type]
dto_field=DTOField(),
),
metadata=ANY,
Expand Down

0 comments on commit 2f76af0

Please sign in to comment.