Skip to content

Commit

Permalink
Changes:
Browse files Browse the repository at this point in the history
- fix tests (inversion never noticed)
- allow traversing nested attributes with embed_parent
- allow prefetching across unidirectional relations (and simplify it)
- prefetches are now handled in parallel
- use run_sync for prefetches instead of the loop stuff
- move crawl_relationship for preventing import loop
- replace extra usage on the way
  • Loading branch information
devkral committed Jul 3, 2024
1 parent ef22048 commit 2ab7088
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 186 deletions.
183 changes: 66 additions & 117 deletions edgy/core/db/models/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from edgy.core.db.fields.base import RelationshipField
from edgy.core.db.models.base import EdgyBaseModel
from edgy.core.db.relationships.utils import crawl_relationship
from edgy.core.utils.sync import run_sync
from edgy.exceptions import QuerySetError

if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -30,7 +32,7 @@ def from_sqla_row(
is_defer_fields: bool = False,
exclude_secrets: bool = False,
using_schema: Union[str, None] = None,
) -> Optional[Type["Model"]]:
) -> Optional["Model"]:
"""
Class method to convert a SQLAlchemy Row result into a EdgyModel row type.
Expand All @@ -52,11 +54,18 @@ def from_sqla_row(

for related in select_related:
field_name = related.split("__", 1)[0]
field = cls.meta.fields_mapping[field_name]
try:
field = cls.meta.fields_mapping[field_name]
except KeyError:
raise QuerySetError(
detail=f"Selected field \"{field_name}\" does not exist on {cls}."
) from None
if isinstance(field, RelationshipField):
model_class, _, remainder = field.traverse_field(related)
else:
raise Exception("invalid field")
raise QuerySetError(
detail=f"Selected field \"{field_name}\" is not a RelationshipField on {cls}."
) from None
if remainder:
item[field_name] = model_class.from_sqla_row(
row,
Expand Down Expand Up @@ -114,7 +123,7 @@ def from_sqla_row(

# We need to generify the model fields to make sure we can populate the
# model without mandatory fields
model = cast("Type[Model]", cls.proxy_model(**item))
model = cast("Model", cls.proxy_model(**item))

# Apply the schema to the model
model = cls.__apply_schema(model, using_schema)
Expand All @@ -133,7 +142,7 @@ def from_sqla_row(
item[column.name] = row[column]

model = (
cast("Type[Model]", cls(**item)) if not exclude_secrets else cast("Type[Model]", cls.proxy_model(**item))
cast("Model", cls(**item)) if not exclude_secrets else cast("Model", cls.proxy_model(**item))
)

# Apply the schema to the model
Expand All @@ -144,11 +153,11 @@ def from_sqla_row(
return model

@classmethod
def __apply_schema(cls, model: Type["Model"], schema: Optional[str] = None) -> Type["Model"]:
def __apply_schema(cls, model: "Model", schema: Optional[str] = None) -> "Model":
# Apply the schema to the model
if schema is not None:
model.table = model.build(schema) # type: ignore
model.proxy_model.table = model.proxy_model.build(schema) # type: ignore
model.table = model.build(schema)
model.proxy_model.table = model.proxy_model.build(schema)
return model

@classmethod
Expand All @@ -162,131 +171,71 @@ def __should_ignore_related_name(cls, related_name: str, select_related: Sequenc
return True
return False


@staticmethod
def __check_prefetch_collision(model: "Model", related: "Prefetch") -> None:
if hasattr(model, related.to_attr) or related.to_attr in model.meta.fields_mapping or related.to_attr in model.meta.managers:
raise QuerySetError(
f"Conflicting attribute to_attr='{related.related_name}' with '{related.to_attr}' in {model.__class__.__name__}"
)

@classmethod
async def __set_prefetch(cls, row: "Row", model: "Model", related: "Prefetch") -> None:
cls.__check_prefetch_collision(model, related)
clauses = []
for pkcol in cls.pkcolumns:
clauses.append(getattr(model.table.columns, pkcol) == row[pkcol])
queryset = related.queryset
crawl_result = crawl_relationship(model.__class__, related.related_name, traverse_last=True)
if queryset is None:
if crawl_result.reverse_path is False:
queryset = model.__class__.query.all()
else:
queryset = crawl_result.model_class.query.all()

if queryset.model_class == model.__class__:
# queryset is of this model
queryset = queryset.select_related(related.related_name)
queryset.embed_parent = (related.related_name, "")
elif crawl_result.reverse_path is False:
QuerySetError(
detail=(
f"Creating a reverse path is not possible, unidirectional fields used."
f"You may want to use as queryset a queryset of model class {model!r}."
)
)
else:
# queryset is of the target model
queryset = queryset.select_related(crawl_result.reverse_path)

setattr(model, related.to_attr, await queryset.filter(*clauses))

@classmethod
def __handle_prefetch_related(
cls,
row: "Row",
model: Type["Model"],
# for advancing
model: "Model",
prefetch_related: Sequence["Prefetch"],
parent_cls: Optional[Type["Model"]] = None,
# for going back
reverse_path: str = "",
is_nested: bool = False,
) -> Type["Model"]:
) -> "Model":
"""
Handles any prefetch related scenario from the model.
Loads in advance all the models needed for a specific record
Recursively checks for the related field and validates if there is any conflicting
attribute. If there is, a `QuerySetError` is raised.
"""
if not parent_cls:
parent_cls = model

for related in prefetch_related:
if not is_nested:
# Check for conflicting names
# If to_attr has the same name of any
if hasattr(parent_cls, related.to_attr):
raise QuerySetError(
f"Conflicting attribute to_attr='{related.related_name}' with '{related.to_attr}' in {parent_cls.__class__.__name__}"
)

if not is_nested:
reverse_path = ""

if "__" in related.related_name:
first_part, remainder = related.related_name.split("__", 1)

field = cls.meta.fields_mapping[first_part]
if isinstance(field, RelationshipField):
model_class, reverse_part, remainder = field.traverse_field(related.related_name)
if not reverse_part:
raise Exception("No reverse relation possible (missing related_name)")
else:
raise Exception("invalid field")

# Build the new nested Prefetch object
remainder_prefetch = related.__class__(
related_name=remainder, to_attr=related.to_attr, queryset=related.queryset
)
if reverse_path:
reverse_path = f"{reverse_part}__{reverse_path}"
else:
reverse_path = reverse_part

# Recursively continue the process of handling the
# new prefetch
model_class.__handle_prefetch_related(
row,
model,
prefetch_related=[remainder_prefetch],
reverse_path=reverse_path,
parent_cls=model,
is_nested=True,
)
queries = []

# Check for individual not nested querysets
elif related.queryset is not None and not is_nested:
extra = {}
for pkcol in cls.pkcolumns:
filter_by_pk = row[pkcol]
extra[f"{related.related_name}__{pkcol}"] = filter_by_pk
related.queryset.extra = extra

# Execute the queryset
records = asyncio.get_event_loop().run_until_complete(cls.run_query(queryset=related.queryset))
setattr(model, related.to_attr, records)
else:
records = cls.process_nested_prefetch_related(
row,
prefetch_related=related,
reverse_path=reverse_path,
parent_cls=model,
queryset=related.queryset,
)
for related in prefetch_related:

setattr(model, related.to_attr, records)
# Check for conflicting names
# If to_attr has the same name of any
cls.__check_prefetch_collision(model=model, related=related)
queries.append(cls.__set_prefetch(row=row, model=model, related=related))
run_sync(asyncio.gather(*queries))
return model

@classmethod
def process_nested_prefetch_related(
cls,
row: "Row",
prefetch_related: "Prefetch",
parent_cls: Type["Model"],
reverse_path: str,
queryset: "QuerySet",
) -> List[Type["Model"]]:
"""
Processes the nested prefetch related names.
"""
# Get the related field
field = cls.meta.fields_mapping[prefetch_related.related_name]
if isinstance(field, RelationshipField):
model_class, reverse_part, remainder = field.traverse_field(prefetch_related.related_name)
if not reverse_part:
raise Exception("No backward relation possible (missing related_name)")
else:
raise Exception("invalid field")

if reverse_path:
reverse_path = f"{reverse_part}__{reverse_path}"
else:
reverse_path = reverse_part

# TODO: related_field.clean would be better
# fix this later when allowing selecting fields for fireign keys
# Extract foreign key value
extra = {}
for pkcol in parent_cls.pkcolumns:
filter_by_pk = row[pkcol]
extra[f"{reverse_path}__{pkcol}"] = filter_by_pk

records = asyncio.get_event_loop().run_until_complete(cls.run_query(model_class, extra, queryset))
return records

@classmethod
async def run_query(
cls,
Expand All @@ -302,6 +251,6 @@ async def run_query(
return await model.query.filter(**extra) # type: ignore

if extra:
queryset.extra = extra
queryset = queryset.filter(**extra)

return await queryset
78 changes: 13 additions & 65 deletions edgy/core/db/querysets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
Dict,
Generator,
List,
Literal,
NamedTuple,
Optional,
Sequence,
Set,
Expand All @@ -26,6 +24,7 @@
from edgy.core.db.querysets.mixins import EdgyModel, QuerySetPropsMixin, TenancyMixin
from edgy.core.db.querysets.prefetch import PrefetchMixin
from edgy.core.db.querysets.protocols import AwaitableQuery
from edgy.core.db.relationships.utils import crawl_relationship
from edgy.core.utils.models import DateParser, ModelParser
from edgy.exceptions import MultipleObjectsReturned, ObjectNotFound, QuerySetError
from edgy.protocols.queryset import QuerySetProtocol
Expand All @@ -37,67 +36,6 @@
from edgy.core.db.models import Model


class RelationshipCrawlResult(NamedTuple):
model_class: Type["Model"]
field_name: str
operator: str
forward_path: str
reverse_path: Union[str, Literal[False]]

def crawl_relationship(model_class: Type["Model"], path: str, callback_fn: Any=None) -> RelationshipCrawlResult:
field = None
forward_prefix_path = ""
reverse_path: Union[str, Literal[False]] = ""
operator: str = "exact"
field_name: str = path
while path:
splitted = path.split("__", 1)
field_name = splitted[0]
field = model_class.meta.fields_mapping.get(field_name)
if isinstance(field, RelationshipField) and len(splitted) == 2:
model_class, reverse_part, path = field.traverse_field(path)
if field.is_cross_db():
raise NotImplementedError("We cannot cross databases yet, this feature is planned")
reverse = not isinstance(field, BaseForeignKey)
if reverse_part and reverse_path is not False:
if reverse_path:
reverse_path = f"{reverse_part}__{reverse_path}"
else:
reverse_path = reverse_part
else:
reverse_path = False

if callback_fn:
callback_fn(model_class=model_class, field=field, reverse_path=reverse_path, forward_path=forward_prefix_path, reverse=reverse, operator=None)
if forward_prefix_path:
forward_prefix_path = f"{forward_prefix_path}__{field_name}"
else:
forward_prefix_path = field_name
elif len(splitted) == 2:
if "__" not in splitted[1] and splitted[1] in settings.filter_operators:
operator = splitted[1]
break
else:
raise ValueError(f"Tried to cross field: {field_name} of type {field!r}, remainder: {splitted[1]}")
else:
operator = "exact"
break

if reverse_path is not False:
if reverse_path:
reverse_path = f"{field_name}__{reverse_path}"
else:
reverse_path = field_name
if callback_fn and field is not None:
callback_fn(model_class=model_class, field=field, reverse_path=reverse_path, forward_path=forward_prefix_path, reverse=False, operator=operator)
return RelationshipCrawlResult(
model_class=model_class,
field_name=field_name,
operator=operator,
forward_path=forward_prefix_path,
reverse_path=reverse_path,
)

def clean_query_kwargs(model: Type["Model"], kwargs: Dict[str, Any]) -> Dict[str, Any]:
new_kwargs: Dict[str, Any] = {}
for key, val in kwargs.items():
Expand Down Expand Up @@ -225,11 +163,19 @@ def _build_tables_select_from_relationship(self) -> Any:
former_table = None
while select_path:
field_name = select_path.split("__", 1)[0]
try:
field = model_class.meta.fields_mapping[field_name]
except KeyError:
raise QuerySetError(
detail=f"Selected field \"{field_name}\" does not exist on {model_class}."
) from None
field = model_class.fields[field_name]
if isinstance(field, RelationshipField):
model_class, reverse_part, select_path = field.traverse_field(select_path)
else:
raise ValueError(f"{field_name}: invalid field type: {field!r}")
raise QuerySetError(
detail=f"Selected field \"{field_name}\" is not a RelationshipField on {model_class}."
)
if isinstance(field, BaseForeignKey):
foreign_key = field
reverse = False
Expand Down Expand Up @@ -808,7 +754,9 @@ async def get_or_none(self, **kwargs: Any) -> Union[EdgyModel, None]:
def embed_parent_in_result(self, result: Any) -> Any:
if not self.embed_parent:
return result
new_result = getattr(result, self.embed_parent[0])
new_result = result
for part in self.embed_parent[0].split("__"):
new_result = getattr(new_result, part)
if self.embed_parent[1]:
setattr(new_result, self.embed_parent[1], result)
return new_result
Expand Down
Loading

0 comments on commit 2ab7088

Please sign in to comment.