Skip to content

Commit

Permalink
feat: prefer ANY over IN for postgres (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
cofin committed Oct 18, 2023
1 parent 4826525 commit 9d8cf62
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 52 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ repos:
- id: unasyncd
additional_dependencies: ["ruff"]
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.0.292"
rev: "v0.1.0"
hooks:
- id: ruff
args: ["--fix"]
Expand All @@ -34,7 +34,7 @@ repos:
- id: codespell
exclude: "pdm.lock|examples/us_state_lookup.json"
- repo: https://github.com/psf/black
rev: 23.9.1
rev: 23.10.0
hooks:
- id: black
args: [--config=./pyproject.toml]
Expand Down Expand Up @@ -70,6 +70,6 @@ repos:
"litestar[cli]",
]
- repo: https://github.com/sphinx-contrib/sphinx-lint
rev: "v0.7.0"
rev: "v0.8.1"
hooks:
- id: sphinx-lint
60 changes: 56 additions & 4 deletions advanced_alchemy/repository/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Select,
StatementLambdaElement,
TextClause,
any_,
delete,
lambda_stmt,
over,
Expand Down Expand Up @@ -53,6 +54,9 @@ class SQLAlchemyAsyncRepository(Generic[ModelT]):
model_type: type[ModelT]
id_attribute: Any = "id"
match_fields: list[str] | str | None = None
_prefer_any: bool = False
prefer_any_dialects: tuple[str] | None = ("postgresql",)
"""List of dialects that prefer to use ``field.id = ANY(:1)`` instead of ``field.id IN (...)``."""

def __init__(
self,
Expand Down Expand Up @@ -93,6 +97,7 @@ def __init__(
msg = "Session improperly configure"
raise ValueError(msg)
self._dialect = self.session.bind.dialect
self._prefer_any = any(self._dialect.name == engine_type for engine_type in self.prefer_any_dialects or ())

@classmethod
def get_id_attribute_value(cls, item: ModelT | type[ModelT], id_attribute: str | None = None) -> Any:
Expand Down Expand Up @@ -255,6 +260,8 @@ async def delete_many(
id_attribute if id_attribute is not None else self.id_attribute,
)
instances: list[ModelT] = []
if self._prefer_any:
chunk_size = len(item_ids) + 1
chunk_size = self._get_insertmanyvalues_max_parameters(chunk_size)
for idx in range(0, len(item_ids), chunk_size):
chunk = item_ids[idx : min(idx + chunk_size, len(item_ids))]
Expand Down Expand Up @@ -325,8 +332,8 @@ def _get_base_stmt(
return lambda_stmt(lambda: statement)
return self.statement if statement is None else statement

@staticmethod
def _get_delete_many_statement(
self,
model_type: type[ModelT],
id_attribute: InstrumentedAttribute,
id_chunk: list[Any],
Expand All @@ -337,7 +344,10 @@ def _get_delete_many_statement(
statement = lambda_stmt(lambda: delete(model_type))
elif statement_type == "select":
statement = lambda_stmt(lambda: select(model_type))
statement += lambda s: s.where(id_attribute.in_(id_chunk))
if self._prefer_any:
statement += lambda s: s.where(any_(id_chunk) == id_attribute) # type: ignore[arg-type]
else:
statement += lambda s: s.where(id_attribute.in_(id_chunk))
if supports_returning and statement_type != "select":
statement += lambda s: s.returning(model_type)
return statement
Expand Down Expand Up @@ -1073,9 +1083,26 @@ def _apply_filters(
)

elif isinstance(filter_, (NotInCollectionFilter,)):
statement = self._filter_not_in_collection(filter_.field_name, filter_.values, statement=statement)
if filter_.values is not None: # noqa: PD011
if self._prefer_any:
statement = self._filter_not_any_collection(
filter_.field_name,
filter_.values,
statement=statement,
)
else:
statement = self._filter_not_in_collection(
filter_.field_name,
filter_.values,
statement=statement,
)

elif isinstance(filter_, (CollectionFilter,)):
statement = self._filter_in_collection(filter_.field_name, filter_.values, statement=statement)
if filter_.values is not None: # noqa: PD011
if self._prefer_any:
statement = self._filter_any_collection(filter_.field_name, filter_.values, statement=statement)
else:
statement = self._filter_in_collection(filter_.field_name, filter_.values, statement=statement)
elif isinstance(filter_, (OrderBy,)):
statement = self._order_by(statement, filter_.field_name, sort_desc=filter_.sort_order == "desc")
elif isinstance(filter_, (SearchFilter,)):
Expand Down Expand Up @@ -1124,6 +1151,31 @@ def _filter_not_in_collection(
statement += lambda s: s.where(field.notin_(values))
return statement

def _filter_any_collection(
    self,
    field_name: str | InstrumentedAttribute,
    values: abc.Collection[Any],
    statement: StatementLambdaElement,
) -> StatementLambdaElement:
    """Keep only rows whose column equals at least one element of ``values``.

    Args:
        field_name: Attribute name on ``self.model_type``, or the instrumented attribute.
        values: Candidate values to match against.
        statement: Lambda statement to extend.

    Returns:
        The statement with a ``field = ANY(values)`` criterion appended, or an
        always-false ``1=-1`` criterion when ``values`` is empty.
    """
    if values:
        column = get_instrumented_attr(self.model_type, field_name)
        statement += lambda s: s.where(any_(values) == column)  # type: ignore[arg-type]
    else:
        # Nothing can match an empty collection: add a predicate that is never true.
        statement += lambda s: s.where(text("1=-1"))
    return statement

def _filter_not_any_collection(
self,
field_name: str | InstrumentedAttribute,
values: abc.Collection[Any],
statement: StatementLambdaElement,
) -> StatementLambdaElement:
if not values:
return statement
field = get_instrumented_attr(self.model_type, field_name)
statement += lambda s: s.where(any_(values) != field) # type: ignore[arg-type]
return statement

def _filter_on_datetime_field(
self,
field_name: str | InstrumentedAttribute,
Expand Down
60 changes: 56 additions & 4 deletions advanced_alchemy/repository/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Select,
StatementLambdaElement,
TextClause,
any_,
delete,
lambda_stmt,
over,
Expand Down Expand Up @@ -54,6 +55,9 @@ class SQLAlchemySyncRepository(Generic[ModelT]):
model_type: type[ModelT]
id_attribute: Any = "id"
match_fields: list[str] | str | None = None
_prefer_any: bool = False
prefer_any_dialects: tuple[str] | None = ("postgresql",)
"""List of dialects that prefer to use ``field.id = ANY(:1)`` instead of ``field.id IN (...)``."""

def __init__(
self,
Expand Down Expand Up @@ -94,6 +98,7 @@ def __init__(
msg = "Session improperly configure"
raise ValueError(msg)
self._dialect = self.session.bind.dialect
self._prefer_any = any(self._dialect.name == engine_type for engine_type in self.prefer_any_dialects or ())

@classmethod
def get_id_attribute_value(cls, item: ModelT | type[ModelT], id_attribute: str | None = None) -> Any:
Expand Down Expand Up @@ -256,6 +261,8 @@ def delete_many(
id_attribute if id_attribute is not None else self.id_attribute,
)
instances: list[ModelT] = []
if self._prefer_any:
chunk_size = len(item_ids) + 1
chunk_size = self._get_insertmanyvalues_max_parameters(chunk_size)
for idx in range(0, len(item_ids), chunk_size):
chunk = item_ids[idx : min(idx + chunk_size, len(item_ids))]
Expand Down Expand Up @@ -326,8 +333,8 @@ def _get_base_stmt(
return lambda_stmt(lambda: statement)
return self.statement if statement is None else statement

@staticmethod
def _get_delete_many_statement(
self,
model_type: type[ModelT],
id_attribute: InstrumentedAttribute,
id_chunk: list[Any],
Expand All @@ -338,7 +345,10 @@ def _get_delete_many_statement(
statement = lambda_stmt(lambda: delete(model_type))
elif statement_type == "select":
statement = lambda_stmt(lambda: select(model_type))
statement += lambda s: s.where(id_attribute.in_(id_chunk))
if self._prefer_any:
statement += lambda s: s.where(any_(id_chunk) == id_attribute) # type: ignore[arg-type]
else:
statement += lambda s: s.where(id_attribute.in_(id_chunk))
if supports_returning and statement_type != "select":
statement += lambda s: s.returning(model_type)
return statement
Expand Down Expand Up @@ -1074,9 +1084,26 @@ def _apply_filters(
)

elif isinstance(filter_, (NotInCollectionFilter,)):
statement = self._filter_not_in_collection(filter_.field_name, filter_.values, statement=statement)
if filter_.values is not None: # noqa: PD011
if self._prefer_any:
statement = self._filter_not_any_collection(
filter_.field_name,
filter_.values,
statement=statement,
)
else:
statement = self._filter_not_in_collection(
filter_.field_name,
filter_.values,
statement=statement,
)

elif isinstance(filter_, (CollectionFilter,)):
statement = self._filter_in_collection(filter_.field_name, filter_.values, statement=statement)
if filter_.values is not None: # noqa: PD011
if self._prefer_any:
statement = self._filter_any_collection(filter_.field_name, filter_.values, statement=statement)
else:
statement = self._filter_in_collection(filter_.field_name, filter_.values, statement=statement)
elif isinstance(filter_, (OrderBy,)):
statement = self._order_by(statement, filter_.field_name, sort_desc=filter_.sort_order == "desc")
elif isinstance(filter_, (SearchFilter,)):
Expand Down Expand Up @@ -1125,6 +1152,31 @@ def _filter_not_in_collection(
statement += lambda s: s.where(field.notin_(values))
return statement

def _filter_any_collection(
    self,
    field_name: str | InstrumentedAttribute,
    values: abc.Collection[Any],
    statement: StatementLambdaElement,
) -> StatementLambdaElement:
    """Keep only rows whose column equals at least one element of ``values``.

    Args:
        field_name: Attribute name on ``self.model_type``, or the instrumented attribute.
        values: Candidate values to match against.
        statement: Lambda statement to extend.

    Returns:
        The statement with a ``field = ANY(values)`` criterion appended, or an
        always-false ``1=-1`` criterion when ``values`` is empty.
    """
    if values:
        column = get_instrumented_attr(self.model_type, field_name)
        statement += lambda s: s.where(any_(values) == column)  # type: ignore[arg-type]
    else:
        # Nothing can match an empty collection: add a predicate that is never true.
        statement += lambda s: s.where(text("1=-1"))
    return statement

def _filter_not_any_collection(
self,
field_name: str | InstrumentedAttribute,
values: abc.Collection[Any],
statement: StatementLambdaElement,
) -> StatementLambdaElement:
if not values:
return statement
field = get_instrumented_attr(self.model_type, field_name)
statement += lambda s: s.where(any_(values) != field) # type: ignore[arg-type]
return statement

def _filter_on_datetime_field(
self,
field_name: str | InstrumentedAttribute,
Expand Down

0 comments on commit 9d8cf62

Please sign in to comment.