Skip to content
This repository has been archived by the owner on Jul 8, 2023. It is now read-only.

Commit

Permalink
Draft: support annotate parameter in field to allow ORM annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
fjsj committed Jun 30, 2023
1 parent 16922ab commit 54db7ec
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 8 deletions.
15 changes: 14 additions & 1 deletion strawberry_django_plus/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from .descriptors import ModelProperty
from .permissions import filter_with_perms
from .utils import resolvers
from .utils.typing import TypeOrSequence
from .utils.typing import TypeOrMapping, TypeOrSequence

if TYPE_CHECKING:
from strawberry_django_plus.type import StrawberryDjangoType
Expand Down Expand Up @@ -87,6 +87,7 @@ def __init__(
only: Optional[TypeOrSequence[str]] = None,
select_related: Optional[TypeOrSequence[str]] = None,
prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None,
annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None,
disable_optimization: bool = False,
**kwargs,
):
Expand All @@ -95,6 +96,7 @@ def __init__(
only=only,
select_related=select_related,
prefetch_related=prefetch_related,
annotate=annotate,
)
super().__init__(*args, **kwargs)

Expand Down Expand Up @@ -486,6 +488,7 @@ def field(
only: Optional[TypeOrSequence[str]] = None,
select_related: Optional[TypeOrSequence[str]] = None,
prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None,
annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None,
disable_optimization: bool = False,
extensions: List[FieldExtension] = (), # type: ignore
) -> _T:
Expand Down Expand Up @@ -513,6 +516,7 @@ def field(
only: Optional[TypeOrSequence[str]] = None,
select_related: Optional[TypeOrSequence[str]] = None,
prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None,
annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None,
disable_optimization: bool = False,
extensions: List[FieldExtension] = (), # type: ignore
) -> Any:
Expand Down Expand Up @@ -540,6 +544,7 @@ def field(
only: Optional[TypeOrSequence[str]] = None,
select_related: Optional[TypeOrSequence[str]] = None,
prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None,
annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None,
disable_optimization: bool = False,
extensions: List[FieldExtension] = (), # type: ignore
) -> StrawberryDjangoField:
Expand All @@ -566,6 +571,7 @@ def field(
only: Optional[TypeOrSequence[str]] = None,
select_related: Optional[TypeOrSequence[str]] = None,
prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None,
annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None,
disable_optimization: bool = False,
extensions: List[FieldExtension] = (), # type: ignore
# This init parameter is used by pyright to determine whether this field
Expand Down Expand Up @@ -606,6 +612,7 @@ def field(
only=only,
select_related=select_related,
prefetch_related=prefetch_related,
annotate=annotate,
disable_optimization=disable_optimization,
extensions=extensions,
)
Expand All @@ -631,6 +638,7 @@ def node(
only: Optional[TypeOrSequence[str]] = None,
select_related: Optional[TypeOrSequence[str]] = None,
prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None,
annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None,
disable_optimization: bool = False,
extensions: List[FieldExtension] = (), # type: ignore
# This init parameter is used by pyright to determine whether this field
Expand Down Expand Up @@ -675,6 +683,7 @@ def node(
only=only,
select_related=select_related,
prefetch_related=prefetch_related,
annotate=annotate,
disable_optimization=disable_optimization,
extensions=extensions,
)
Expand All @@ -699,6 +708,7 @@ def connection(
only: Optional[TypeOrSequence[str]] = None,
select_related: Optional[TypeOrSequence[str]] = None,
prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None,
annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None,
disable_optimization: bool = False,
) -> Any:
...
Expand All @@ -725,6 +735,7 @@ def connection(
only: Optional[TypeOrSequence[str]] = None,
select_related: Optional[TypeOrSequence[str]] = None,
prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None,
annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None,
disable_optimization: bool = False,
) -> Any:
...
Expand All @@ -749,6 +760,7 @@ def connection(
only: Optional[TypeOrSequence[str]] = None,
select_related: Optional[TypeOrSequence[str]] = None,
prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None,
annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None,
disable_optimization: bool = False,
# This init parameter is used by pyright to determine whether this field
# is added in the constructor or not. It is not used to change
Expand Down Expand Up @@ -823,6 +835,7 @@ def connection(
only=only,
select_related=select_related,
prefetch_related=prefetch_related,
annotate=annotate,
disable_optimization=disable_optimization,
extensions=extensions,
)
Expand Down
66 changes: 60 additions & 6 deletions strawberry_django_plus/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from django.db import models
from django.db.models import Prefetch
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import BaseExpression
from django.db.models.fields.reverse_related import (
ManyToManyRel,
ManyToOneRel,
Expand Down Expand Up @@ -53,7 +54,7 @@
get_possible_type_definitions,
get_selections,
)
from .utils.typing import TypeOrSequence
from .utils.typing import TypeOrMapping, TypeOrSequence

try:
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
Expand All @@ -79,6 +80,7 @@
else:
_relation_fields = (models.ManyToManyField, ManyToManyRel, ManyToOneRel)
_sentinel = object()
_annotate_placeholder = "______annotate_placeholder______"
_interfaces: """
defaultdict[
Schema,
Expand All @@ -90,6 +92,8 @@

PrefetchCallable: TypeAlias = Callable[[GraphQLResolveInfo], Prefetch]
PrefetchType: TypeAlias = Union[str, Prefetch, PrefetchCallable]
AnnotateCallable: TypeAlias = Callable[[GraphQLResolveInfo], BaseExpression]
AnnotateType: TypeAlias = Union[BaseExpression, AnnotateCallable]


def _get_prefetch_queryset(
Expand Down Expand Up @@ -207,6 +211,15 @@ def _get_model_hints(
# Add annotations from the field if they exist
field_store = getattr(field, "store", None)
if field_store is not None:
if len(field_store.annotate) == 1 and _annotate_placeholder in field_store.annotate:
# This is a special case where we need to update the field name,
# because when field_store was created on __init__, the field name wasn't available.
# This allows for annotate expressions to be declared as:
# total: int = gql.django.field(annotate=Sum("price")) # noqa: ERA001
# Instead of the more redundant:
# total: int = gql.django.field(annotate={"total": Sum("price")}) # noqa: ERA001
field_store.annotate = {field.name: field_store.annotate[_annotate_placeholder]}

Check warning on line 221 in strawberry_django_plus/optimizer.py

View check run for this annotation

Codecov / codecov/patch

strawberry_django_plus/optimizer.py#L221

Added line #L221 was not covered by tests

store |= field_store.with_prefix(prefix, info=info) if prefix else field_store

# Then from the model property if one is defined
Expand Down Expand Up @@ -446,6 +459,8 @@ class OptimizerConfig:
Enable `QuerySet.select_related` optimizations
enable_prefetch_related:
Enable `QuerySet.prefetch_related` optimizations
enable_annotate:
Enable `QuerySet.annotate` optimizations
prefetch_custom_queryset:
Use custom instead of _base_manager for prefetch querysets
Expand All @@ -454,6 +469,7 @@ class OptimizerConfig:
enable_only: bool = dataclasses.field(default=True)
enable_select_related: bool = dataclasses.field(default=True)
enable_prefetch_related: bool = dataclasses.field(default=True)
enable_annotate: bool = dataclasses.field(default=True)
prefetch_custom_queryset: bool = dataclasses.field(default=False)


Expand All @@ -468,20 +484,23 @@ class OptimizerStore:
Set of values to optimize using `QuerySet.select_related`
prefetch_related:
Set of values to optimize using `QuerySet.prefetch_related`
annotate:
Dict to use on `QuerySet.annotate`
"""

only: List[str] = dataclasses.field(default_factory=list)
select_related: List[str] = dataclasses.field(default_factory=list)
prefetch_related: List[PrefetchType] = dataclasses.field(default_factory=list)
annotate: Dict[str, AnnotateType] = dataclasses.field(default_factory=dict)

def __bool__(self):
return any([self.only, self.select_related, self.prefetch_related])
return any([self.only, self.select_related, self.prefetch_related, self.annotate])

def __ior__(self, other: "OptimizerStore"):
self.only.extend(other.only)
self.select_related.extend(other.select_related)
self.prefetch_related.extend(other.prefetch_related)
self.annotate.update(other.annotate)
return self

def __or__(self, other: "OptimizerStore"):
Expand All @@ -492,6 +511,7 @@ def copy(self):
only=self.only[:],
select_related=self.select_related[:],
prefetch_related=self.prefetch_related[:],
annotate=self.annotate.copy(),
)

@classmethod
Expand All @@ -501,6 +521,7 @@ def with_hints(
only: Optional[TypeOrSequence[str]] = None,
select_related: Optional[TypeOrSequence[str]] = None,
prefetch_related: Optional[TypeOrSequence[PrefetchType]] = None,
annotate: Optional[TypeOrMapping[AnnotateType]] = None,
):
return cls(
only=[only] if isinstance(only, str) else list(only or []),
Expand All @@ -512,6 +533,12 @@ def with_hints(
if isinstance(prefetch_related, (str, Prefetch, Callable))
else list(prefetch_related or [])
),
annotate=(
# placeholder here, because field name is evaluated later on .annotate call:
{_annotate_placeholder: annotate}
if isinstance(annotate, (BaseExpression, Callable))
else dict(annotate or {})
),
)

def with_prefix(self, prefix: str, *, info: GraphQLResolveInfo):
Expand All @@ -529,10 +556,19 @@ def with_prefix(self, prefix: str, *, info: GraphQLResolveInfo):
else: # pragma:nocover
assert_never(p)

annotate = {}
for k, v in self.annotate.items():
if isinstance(v, Callable):
assert_type(v, AnnotateCallable)
v = v(info) # noqa: PLW2901

Check warning on line 563 in strawberry_django_plus/optimizer.py

View check run for this annotation

Codecov / codecov/patch

strawberry_django_plus/optimizer.py#L561-L563

Added lines #L561 - L563 were not covered by tests

annotate[f"{prefix}{LOOKUP_SEP}{k}"] = v

Check warning on line 565 in strawberry_django_plus/optimizer.py

View check run for this annotation

Codecov / codecov/patch

strawberry_django_plus/optimizer.py#L565

Added line #L565 was not covered by tests

return self.__class__(
only=[f"{prefix}{LOOKUP_SEP}{i}" for i in self.only],
select_related=[f"{prefix}{LOOKUP_SEP}{i}" for i in self.select_related],
prefetch_related=prefetch_related,
annotate=annotate,
)

def apply(
Expand Down Expand Up @@ -601,6 +637,17 @@ def apply(
if config.enable_only and self.only:
qs = qs.only(*self.only)

if config.enable_annotate and self.annotate:
to_annotate = {}
for k, v in self.annotate.items():
if isinstance(v, Callable):
assert_type(v, AnnotateCallable)
v = v(info) # noqa: PLW2901

Check warning on line 645 in strawberry_django_plus/optimizer.py

View check run for this annotation

Codecov / codecov/patch

strawberry_django_plus/optimizer.py#L641-L645

Added lines #L641 - L645 were not covered by tests

to_annotate[k] = v

Check warning on line 647 in strawberry_django_plus/optimizer.py

View check run for this annotation

Codecov / codecov/patch

strawberry_django_plus/optimizer.py#L647

Added line #L647 was not covered by tests

qs = qs.annotate(**to_annotate)

Check warning on line 649 in strawberry_django_plus/optimizer.py

View check run for this annotation

Codecov / codecov/patch

strawberry_django_plus/optimizer.py#L649

Added line #L649 was not covered by tests

return qs


Expand All @@ -620,6 +667,9 @@ class DjangoOptimizerExtension(SchemaExtension):
Enable `QuerySet.select_related` optimizations
enable_prefetch_related_optimization:
Enable `QuerySet.prefetch_related` optimizations
enable_annotate_optimization:
Enable `QuerySet.annotate` optimizations
Examples:
Add the following to your schema configuration.
Expand Down Expand Up @@ -647,13 +697,15 @@ def __init__(
enable_only_optimization: bool = True,
enable_select_related_optimization: bool = True,
enable_prefetch_related_optimization: bool = True,
enable_annotate_optimization: bool = True,
execution_context: Optional[ExecutionContext] = None,
prefetch_custom_queryset: bool = False,
):
super().__init__(execution_context=execution_context) # type: ignore
self._enable_ony = enable_only_optimization
self._enable_only = enable_only_optimization
self._enable_select_related = enable_select_related_optimization
self._enable_prefetch_related = enable_prefetch_related_optimization
self._enable_annotate = enable_annotate_optimization
self._prefetch_custom_queryset = prefetch_custom_queryset

def on_execute(self) -> Generator[None, None, None]:
Expand Down Expand Up @@ -684,10 +736,11 @@ def resolve(
if ret._result_cache is None: # type: ignore
config = OptimizerConfig(
enable_only=(
self._enable_ony and info.operation.operation == OperationType.QUERY
self._enable_only and info.operation.operation == OperationType.QUERY
),
enable_select_related=self._enable_select_related,
enable_prefetch_related=self._enable_prefetch_related,
enable_annotate=self._enable_annotate,
prefetch_custom_queryset=self._prefetch_custom_queryset,
)
return resolvers.resolve_qs(optimize(qs=ret, info=info, config=config))
Expand All @@ -705,9 +758,10 @@ def optimize(
return qs

config = OptimizerConfig(
enable_only=self._enable_ony and info.operation.operation == OperationType.QUERY,
enable_only=self._enable_only and info.operation.operation == OperationType.QUERY,
enable_select_related=self._enable_select_related,
enable_prefetch_related=self._enable_prefetch_related,
enable_annotate=self._enable_annotate,
prefetch_custom_queryset=self._prefetch_custom_queryset,
)
return optimize(qs, info, config=config, store=store)
3 changes: 2 additions & 1 deletion strawberry_django_plus/utils/typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, Iterable, Sequence, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, Sequence, TypeVar, Union

from django.contrib.auth.base_user import AbstractBaseUser
from graphql.type.definition import GraphQLResolveInfo
Expand All @@ -14,6 +14,7 @@

DictTree: TypeAlias = Dict[str, "DictTree"]
TypeOrSequence: TypeAlias = Union[_T, Sequence[_T]]
TypeOrMapping: TypeAlias = Union[_T, Mapping[str, _T]]
TypeOrIterable: TypeAlias = Union[_T, Iterable[_T]]
UserType: TypeAlias = Union[AbstractBaseUser, "AnonymousUser"]
ResolverInfo: TypeAlias = Union[Info[StrawberryDjangoContext, Any], GraphQLResolveInfo]
Expand Down

0 comments on commit 54db7ec

Please sign in to comment.