Skip to content

Commit

Permalink
Change FieldExpr to TypeVar to work around Mappings being invariant i…
Browse files Browse the repository at this point in the history
…n the key type

See:
- python/typing#445
- python/typing#273
  • Loading branch information
gsakkis committed Sep 14, 2023
1 parent dde2ed0 commit c333423
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 31 deletions.
9 changes: 4 additions & 5 deletions beanie/odm/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
wrap_with_actions,
)
from beanie.odm.bulk import BulkWriter, Operation
from beanie.odm.fields import FieldExpr, IndexModel, PydanticObjectId
from beanie.odm.fields import IndexModel, PydanticObjectId
from beanie.odm.interfaces.find import BaseSettings, FindInterface
from beanie.odm.interfaces.update import UpdateMethods
from beanie.odm.links import Link, LinkedModelMixin, LinkInfo, LinkTypes
Expand Down Expand Up @@ -417,8 +417,7 @@ async def replace(
)

use_revision_id = self._settings.use_revision
find_query: Dict[FieldExpr, Any] = {"_id": self.id}

find_query = {"_id": self.id}
if use_revision_id and not ignore_revision:
find_query["revision_id"] = self._previous_revision_id
try:
Expand Down Expand Up @@ -548,7 +547,7 @@ async def replace_many(
:return: None
"""
ids_list = [document.id for document in documents]
if await cls.find(In(cls.id, ids_list)).count() != len(ids_list): # type: ignore[arg-type]
if await cls.find(In("_id", ids_list)).count() != len(ids_list):
raise ReplaceError(
"Some of the documents are not exist in the collection"
)
Expand Down Expand Up @@ -582,7 +581,7 @@ async def update(
arguments = list(args)
use_revision_id = self._settings.use_revision

find_query: Dict[FieldExpr, Any] = {
find_query = {
"_id": self.id if self.id is not None else PydanticObjectId()
}
if use_revision_id and not ignore_revision:
Expand Down
4 changes: 2 additions & 2 deletions beanie/odm/fields.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
from functools import cached_property
from typing import Any, Dict, Iterator, List, Mapping, Tuple, Union
from typing import Any, Dict, Iterator, List, Mapping, Tuple, TypeVar, Union

import bson
import pymongo
Expand Down Expand Up @@ -97,7 +97,7 @@ def __deepcopy__(self, memo: dict) -> Self:
return self


FieldExpr = Union[ExpressionField, str]
FieldExpr = TypeVar("FieldExpr", bound=Union[ExpressionField, str])


def convert_field_exprs_to_str(expression: Any) -> Any:
Expand Down
8 changes: 4 additions & 4 deletions beanie/odm/interfaces/update.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Any, Mapping, Optional, cast
from typing import Any, Mapping, Optional

from pymongo.client_session import ClientSession

Expand Down Expand Up @@ -51,7 +51,7 @@ class Sample(Document):
:return: self
"""
return self.update(
cast(Mapping[FieldExpr, Any], Set(expression)),
Set(expression),
session=session,
bulk_writer=bulk_writer,
**pymongo_kwargs,
Expand All @@ -75,7 +75,7 @@ def current_date(
:return: self
"""
return self.update(
cast(Mapping[FieldExpr, Any], CurrentDate(expression)),
CurrentDate(expression),
session=session,
bulk_writer=bulk_writer,
**pymongo_kwargs,
Expand Down Expand Up @@ -110,7 +110,7 @@ class Sample(Document):
:return: self
"""
return self.update(
cast(Mapping[FieldExpr, Any], Inc(expression)),
Inc(expression),
session=session,
bulk_writer=bulk_writer,
**pymongo_kwargs,
Expand Down
11 changes: 7 additions & 4 deletions beanie/odm/links.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,14 +272,17 @@ def eval_type(cls, t: Any) -> Type["Document"]:

async def fetch_link(self, field: FieldExpr) -> None:
if isinstance(field, ExpressionField):
field = str(field)
ref_obj = getattr(self, field, None)
attr = str(field)
else:
assert isinstance(field, str)
attr = field
ref_obj = getattr(self, attr, None)
if isinstance(ref_obj, Link):
value = await ref_obj.fetch(fetch_links=True)
setattr(self, field, value)
setattr(self, attr, value)
elif isinstance(ref_obj, list) and ref_obj:
values = await Link.fetch_list(ref_obj, fetch_links=True)
setattr(self, field, values)
setattr(self, attr, values)

async def fetch_all_links(self) -> None:
await asyncio.gather(
Expand Down
19 changes: 8 additions & 11 deletions beanie/odm/queries/find/many.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import beanie
from beanie.odm.bulk import BulkWriter
from beanie.odm.fields import (
ExpressionField,
FieldExpr,
SortDirection,
convert_field_exprs_to_str,
Expand Down Expand Up @@ -145,20 +144,18 @@ def sort(
pass
elif isinstance(arg, list):
self.sort(*arg)
elif isinstance(arg, tuple):
self._add_sort(*arg)
else:
self._add_sort(arg)
if isinstance(arg, tuple):
key, direction = arg
else:
key = arg
direction = None
self._add_sort(convert_field_exprs_to_str(key), direction)
return self

def _add_sort(
self, key: FieldExpr, direction: Optional[SortDirection] = None
):
if isinstance(key, ExpressionField):
key = str(key)
elif not isinstance(key, str):
def _add_sort(self, key: str, direction: Optional[SortDirection]) -> None:
if not isinstance(key, str):
raise TypeError(f"Sort key must be a string, not {type(key)}")

if direction is None:
if key.startswith("-"):
direction = SortDirection.DESCENDING
Expand Down
7 changes: 2 additions & 5 deletions beanie/odm/queries/find/one.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
Any,
Generator,
Generic,
List,
Mapping,
Optional,
Type,
Expand Down Expand Up @@ -202,9 +201,8 @@ async def count(self) -> int:
:return: int
"""
if self.fetch_links:
args = cast(List[Mapping[FieldExpr, Any]], self.find_expressions)
return await self.document_model.find_many(
*args,
*self.find_expressions,
session=self.session,
fetch_links=self.fetch_links,
**self.pymongo_kwargs,
Expand All @@ -226,9 +224,8 @@ async def _find(self, use_cache: bool, parse: bool) -> Optional[ModelT]:
doc = await self._find(use_cache=False, parse=False)
cache.set(cache_key, doc)
elif self.fetch_links:
args = cast(List[Mapping[FieldExpr, Any]], self.find_expressions)
doc = await self.document_model.find_many(
*args,
*self.find_expressions,
session=self.session,
fetch_links=self.fetch_links,
projection_model=self.projection_model,
Expand Down

0 comments on commit c333423

Please sign in to comment.