diff --git a/docs/changelog.rst b/docs/changelog.rst index d048bbac..629268d7 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,6 +1,21 @@ Changelog ######### +**2.5.1** +********* + +Fix custom sql filtering, bring back backward compatibility +=========================================================== + +* Fix custom sql filtering support: bring back backward compatibility by `@mahenzon`_ in `#74 `_ +* Read version from file by `@mahenzon`_ in `#74 `_ + +Authors +""""""" + +* `@mahenzon`_ + + **2.5.0** ********* diff --git a/docs/conf.py b/docs/conf.py index 9bce1286..f08ddf4a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,9 +19,14 @@ import os import sys from datetime import datetime +from pathlib import Path sys.path.insert(0, os.path.abspath("..")) +BASE_DIR = Path(__file__).resolve().parent.parent +VERSION_FILEPATH = BASE_DIR / "fastapi_jsonapi" / "VERSION" +RELEASE_VERSION = VERSION_FILEPATH.read_text().strip() + # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. @@ -64,9 +69,9 @@ # built documents. # # The short X.Y version. -version = "2.5" +version = ".".join(RELEASE_VERSION.split(".", maxsplit=2)[:2]) # The full version, including alpha/beta/rc tags. -release = "2.5.2" +release = RELEASE_VERSION # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/examples/custom_filter_example.py b/examples/custom_filter_example.py index ce529974..4277ae5c 100644 --- a/examples/custom_filter_example.py +++ b/examples/custom_filter_example.py @@ -1,7 +1,8 @@ -from typing import Any +from typing import Any, Union from pydantic.fields import Field, ModelField from sqlalchemy.orm import InstrumentedAttribute +from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList from fastapi_jsonapi.schema_base import BaseModel @@ -11,7 +12,7 @@ def jsonb_contains_sql_filter( model_column: InstrumentedAttribute, value: dict[Any, Any], operator: str, -) -> tuple[Any, list[Any]]: +) -> Union[BinaryExpression, BooleanClauseList]: """ Any SQLA (or Tortoise) magic here @@ -19,10 +20,9 @@ def jsonb_contains_sql_filter( :param model_column: :param value: any dict :param operator: value 'jsonb_contains' - :return: one sqla filter and list of joins + :return: one sqla filter expression """ - filter_sqla = model_column.op("@>")(value) - return filter_sqla, [] + return model_column.op("@>")(value) class PictureSchema(BaseModel): diff --git a/fastapi_jsonapi/VERSION b/fastapi_jsonapi/VERSION new file mode 100644 index 00000000..73462a5a --- /dev/null +++ b/fastapi_jsonapi/VERSION @@ -0,0 +1 @@ +2.5.1 diff --git a/fastapi_jsonapi/__init__.py b/fastapi_jsonapi/__init__.py index 5d198f45..a9d73a18 100644 --- a/fastapi_jsonapi/__init__.py +++ b/fastapi_jsonapi/__init__.py @@ -1,4 +1,5 @@ """JSON API utils package.""" +from pathlib import Path from fastapi import FastAPI @@ -8,7 +9,7 @@ from fastapi_jsonapi.exceptions.json_api import HTTPException from fastapi_jsonapi.querystring import QueryStringManager -__version__ = "2.5.0" +__version__ = Path(__file__).parent.joinpath("VERSION").read_text().strip() __all__ = [ "init", diff --git a/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py b/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py index 3cdfe097..27aea356 100644 --- a/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py +++ b/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py @@ -1,6 +1,7 @@ """Helper to create sqlalchemy filters according to filter querystring parameter""" import inspect import logging +from collections.abc import Sequence from typing import ( Any, Callable, @@ -16,7 +17,7 @@ from pydantic import BaseConfig, BaseModel from pydantic.fields import ModelField from pydantic.validators import _VALIDATORS, find_validators -from sqlalchemy import and_, not_, or_ +from sqlalchemy import and_, false, not_, or_ from sqlalchemy.orm import aliased from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.util import AliasedClass @@ -396,11 +397,83 @@ def prepare_relationships_info( ) +def build_terminal_node_filter_expressions( + filter_item: Dict, + target_schema: Type[TypeSchema], + target_model: Type[TypeModel], + relationships_info: Dict[RelationshipPath, RelationshipFilteringInfo], +): + name: str = filter_item["name"] + if is_relationship_filter(name): + *relationship_path, field_name = name.split(RELATIONSHIP_SPLITTER) + relationship_info: RelationshipFilteringInfo = relationships_info[ + RELATIONSHIP_SPLITTER.join(relationship_path) + ] + model_column = get_model_column( + model=relationship_info.aliased_model, + schema=relationship_info.target_schema, + field_name=field_name, + ) + target_schema = relationship_info.target_schema + else: + field_name = name + model_column = get_model_column( + model=target_model, + schema=target_schema, + field_name=field_name, + ) + + schema_field = target_schema.__fields__[field_name] + + filter_operator = filter_item["op"] + custom_filter_expression: Callable = get_custom_filter_expression_callable( + schema_field=schema_field, + operator=filter_operator, + ) + if custom_filter_expression is None: + return build_filter_expression( + schema_field=schema_field, + model_column=model_column, + operator=get_operator( + model_column=model_column, + operator_name=filter_operator, + ), + value=filter_item["val"], + ) + + custom_call_result = custom_filter_expression( + schema_field=schema_field, + model_column=model_column, + value=filter_item["val"], + operator=filter_operator, + ) + if isinstance(custom_call_result, Sequence): + expected_len = 2 + if len(custom_call_result) != expected_len: + log.error( + "Invalid filter, returned sequence length is not %s: %s, len=%s", + expected_len, + custom_call_result, + len(custom_call_result), + ) + raise InvalidFilters(detail="Custom sql filter backend error.") + log.warning( + "Custom filter result of `[expr, [joins]]` is deprecated." + " Please return only filter expression from now on. " + "(triggered on schema field %s for filter operator %s on column %s)", + schema_field, + filter_operator, + model_column, + ) + custom_call_result = custom_call_result[0] + return custom_call_result + + def build_filter_expressions( - filter_item: Union[dict, list], + filter_item: Dict, target_schema: Type[TypeSchema], target_model: Type[TypeModel], - relationships_info: dict[RelationshipPath, RelationshipFilteringInfo], + relationships_info: Dict[RelationshipPath, RelationshipFilteringInfo], ) -> Union[BinaryExpression, BooleanClauseList]: """ Return sqla expressions. @@ -409,93 +482,59 @@ def build_filter_expressions( in where condition: query(Model).where(build_filter_expressions(...)) """ if is_terminal_node(filter_item): - name = filter_item["name"] + return build_terminal_node_filter_expressions( + filter_item=filter_item, + target_schema=target_schema, + target_model=target_model, + relationships_info=relationships_info, + ) - if is_relationship_filter(name): - *relationship_path, field_name = name.split(RELATIONSHIP_SPLITTER) - relationship_info: RelationshipFilteringInfo = relationships_info[ - RELATIONSHIP_SPLITTER.join(relationship_path) - ] - model_column = get_model_column( - model=relationship_info.aliased_model, - schema=relationship_info.target_schema, - field_name=field_name, - ) - target_schema = relationship_info.target_schema - else: - field_name = name - model_column = get_model_column( - model=target_model, - schema=target_schema, - field_name=field_name, - ) + if not isinstance(filter_item, dict): + log.warning("Could not build filtering expressions %s", locals()) + # dirty. refactor. + return not_(false()) - schema_field = target_schema.__fields__[field_name] + sqla_logic_operators = { + "or": or_, + "and": and_, + "not": not_, + } - custom_filter_expression = get_custom_filter_expression_callable( - schema_field=schema_field, - operator=filter_item["op"], + if len(logic_operators := set(filter_item.keys())) > 1: + msg = ( + f"In each logic node expected one of operators: {set(sqla_logic_operators.keys())} " + f"but got {len(logic_operators)}: {logic_operators}" ) - if custom_filter_expression: - return custom_filter_expression( - schema_field=schema_field, - model_column=model_column, - value=filter_item["val"], - operator=filter_item["op"], - ) - else: - return build_filter_expression( - schema_field=schema_field, - model_column=model_column, - operator=get_operator( - model_column=model_column, - operator_name=filter_item["op"], - ), - value=filter_item["val"], - ) + raise InvalidFilters(msg) - if isinstance(filter_item, dict): - sqla_logic_operators = { - "or": or_, - "and": and_, - "not": not_, - } - - if len(logic_operators := set(filter_item.keys())) > 1: - msg = ( - f"In each logic node expected one of operators: {set(sqla_logic_operators.keys())} " - f"but got {len(logic_operators)}: {logic_operators}" - ) - raise InvalidFilters(msg) - - if (logic_operator := logic_operators.pop()) not in set(sqla_logic_operators.keys()): - msg = f"Not found logic operator {logic_operator} expected one of {set(sqla_logic_operators.keys())}" - raise InvalidFilters(msg) - - op = sqla_logic_operators[logic_operator] - - if logic_operator == "not": - return op( - build_filter_expressions( - filter_item=filter_item[logic_operator], - target_schema=target_schema, - target_model=target_model, - relationships_info=relationships_info, - ), - ) + if (logic_operator := logic_operators.pop()) not in set(sqla_logic_operators.keys()): + msg = f"Not found logic operator {logic_operator} expected one of {set(sqla_logic_operators.keys())}" + raise InvalidFilters(msg) - expressions = [] - for filter_sub_item in filter_item[logic_operator]: - expressions.append( - build_filter_expressions( - filter_item=filter_sub_item, - target_schema=target_schema, - target_model=target_model, - relationships_info=relationships_info, - ), - ) + op = sqla_logic_operators[logic_operator] + + if logic_operator == "not": + return op( + build_filter_expressions( + filter_item=filter_item[logic_operator], + target_schema=target_schema, + target_model=target_model, + relationships_info=relationships_info, + ), + ) + + expressions = [] + for filter_sub_item in filter_item[logic_operator]: + expressions.append( + build_filter_expressions( + filter_item=filter_sub_item, + target_schema=target_schema, + target_model=target_model, + relationships_info=relationships_info, + ), + ) - return op(*expressions) + return op(*expressions) def create_filters_and_joins( diff --git a/pyproject.toml b/pyproject.toml index 1d59ef00..b9d4ab38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,7 @@ packages = [ [tool.poetry] name = "fastapi-jsonapi" -version = "2.5.0" +version = "2.5.1" description = "FastAPI extension to create REST web api according to JSON:API specification" authors = [ "Aleksei Nekrasov ", diff --git a/tests/fixtures/app.py b/tests/fixtures/app.py index 9968e6d8..eadb8fe7 100644 --- a/tests/fixtures/app.py +++ b/tests/fixtures/app.py @@ -243,4 +243,5 @@ def build_app_custom( atomic = AtomicOperations() app.include_router(atomic.router, prefix="") + init(app) return app diff --git a/tests/test_api/test_api_sqla_with_includes.py b/tests/test_api/test_api_sqla_with_includes.py index 9b2029df..ec182ff4 100644 --- a/tests/test_api/test_api_sqla_with_includes.py +++ b/tests/test_api/test_api_sqla_with_includes.py @@ -11,9 +11,11 @@ from fastapi import FastAPI, status from httpx import AsyncClient from pydantic import BaseModel, Field +from pydantic.fields import ModelField from pytest import fixture, mark, param, raises # noqa PT013 -from sqlalchemy import select +from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import InstrumentedAttribute from fastapi_jsonapi.views.view_base import ViewBase from tests.common import is_postgres_tests @@ -2112,16 +2114,186 @@ class UserWithNotNullableEmailSchema(UserSchema): response = await client.get(url, params=params) assert response.status_code == status.HTTP_400_BAD_REQUEST, response.text assert response.json() == { - "detail": { - "errors": [ - { - "detail": "The field `email` can't be null", - "source": {"parameter": "filters"}, - "status_code": status.HTTP_400_BAD_REQUEST, - "title": "Invalid filters querystring parameter.", - }, - ], - }, + "errors": [ + { + "detail": "The field `email` can't be null", + "source": {"parameter": "filters"}, + "status_code": status.HTTP_400_BAD_REQUEST, + "title": "Invalid filters querystring parameter.", + }, + ], + } + + async def test_custom_sql_filter_lower_string( + self, + async_session: AsyncSession, + user_1: User, + user_2: User, + ): + resource_type = "user_with_custom_lower_filter" + + assert user_1.id != user_2.id + + def lower_equals_sql_filter( + schema_field: ModelField, + model_column: InstrumentedAttribute, + value: str, + operator: str, + ): + return func.lower(model_column) == func.lower(value) + + class UserWithEmailFieldSchema(UserAttributesBaseSchema): + email: str = Field( + _lower_equals_sql_filter_=lower_equals_sql_filter, + ) + + app = build_app_custom( + model=User, + schema=UserWithEmailFieldSchema, + resource_type=resource_type, + ) + + name, _, domain = user_1.email.partition("@") + user_1.email = f"{name.upper()}@{domain}" + await async_session.commit() + params = { + "filter": dumps( + [ + { + "name": "email", + "op": "lower_equals", + "val": f"{name}@{domain.upper()}", + }, + ], + ), + } + url = app.url_path_for(f"get_{resource_type}_list") + async with AsyncClient(app=app, base_url="http://test") as client: + response = await client.get(url, params=params) + assert response.status_code == status.HTTP_200_OK, response.text + response_data = response.json()["data"] + + assert len(response_data) == 1 + assert response_data[0] == { + "id": str(user_1.id), + "type": resource_type, + "attributes": UserWithEmailFieldSchema.from_orm(user_1).dict(), + } + + async def test_custom_sql_filter_lower_string_old_style_with_joins( + self, + caplog, + async_session: AsyncSession, + user_1: User, + user_2: User, + ): + resource_type = "user_with_custom_lower_filter_old_style_joins" + + assert user_1.id != user_2.id + + def lower_equals_sql_filter( + schema_field: ModelField, + model_column: InstrumentedAttribute, + value: str, + operator: str, + ): + return func.lower(model_column) == func.lower(value), [] + + class UserWithEmailFieldFilterSchema(UserAttributesBaseSchema): + email: str = Field( + _lower_equals_sql_filter_=lower_equals_sql_filter, + ) + + app = build_app_custom( + model=User, + schema=UserWithEmailFieldFilterSchema, + resource_type=resource_type, + ) + + name, _, domain = user_1.email.partition("@") + user_1.email = f"{name.upper()}@{domain}" + await async_session.commit() + params = { + "filter": dumps( + [ + { + "name": "email", + "op": "lower_equals", + "val": f"{name}@{domain.upper()}", + }, + ], + ), + } + url = app.url_path_for(f"get_{resource_type}_list") + async with AsyncClient(app=app, base_url="http://test") as client: + response = await client.get(url, params=params) + assert response.status_code == status.HTTP_200_OK, response.text + response_data = response.json()["data"] + + assert len(response_data) == 1 + assert response_data[0] == { + "id": str(user_1.id), + "type": resource_type, + "attributes": UserWithEmailFieldFilterSchema.from_orm(user_1).dict(), + } + assert any( + # str from logs + "Please return only filter expression from now on" in record.msg + # check all records + for record in caplog.records + ) + + async def test_custom_sql_filter_invalid_result( + self, + caplog, + async_session: AsyncSession, + user_1: User, + ): + resource_type = "user_with_custom_invalid_sql_filter" + + def returns_invalid_number_of_params_filter( + schema_field: ModelField, + model_column: InstrumentedAttribute, + value: str, + operator: str, + ): + return 1, 2, 3 + + class UserWithInvalidEmailFieldFilterSchema(UserAttributesBaseSchema): + email: str = Field( + _custom_broken_filter_sql_filter_=returns_invalid_number_of_params_filter, + ) + + app = build_app_custom( + model=User, + schema=UserWithInvalidEmailFieldFilterSchema, + resource_type=resource_type, + ) + + params = { + "filter": dumps( + [ + { + "name": "email", + "op": "custom_broken_filter", + "val": "qwerty", + }, + ], + ), + } + url = app.url_path_for(f"get_{resource_type}_list") + async with AsyncClient(app=app, base_url="http://test") as client: + response = await client.get(url, params=params) + assert response.status_code == status.HTTP_400_BAD_REQUEST, response.text + assert response.json() == { + "errors": [ + { + "detail": "Custom sql filter backend error.", + "source": {"parameter": "filters"}, + "status_code": status.HTTP_400_BAD_REQUEST, + "title": "Invalid filters querystring parameter.", + }, + ], } async def test_composite_filter_by_one_field( diff --git a/tests/test_api/test_validators.py b/tests/test_api/test_validators.py index dae0ed10..ec0bba65 100644 --- a/tests/test_api/test_validators.py +++ b/tests/test_api/test_validators.py @@ -190,16 +190,14 @@ async def execute_request_and_check_response( res = await client.post(url, json=body) assert res.status_code == status.HTTP_400_BAD_REQUEST, res.text assert res.json() == { - "detail": { - "errors": [ - { - "detail": expected_detail, - "source": {"pointer": ""}, - "status_code": status.HTTP_400_BAD_REQUEST, - "title": "Bad Request", - }, - ], - }, + "errors": [ + { + "detail": expected_detail, + "source": {"pointer": ""}, + "status_code": status.HTTP_400_BAD_REQUEST, + "title": "Bad Request", + }, + ], } async def execute_request_twice_and_check_response(