diff --git a/fastapi_jsonapi/api.py b/fastapi_jsonapi/api.py index 93ea4606..9fbdfd40 100644 --- a/fastapi_jsonapi/api.py +++ b/fastapi_jsonapi/api.py @@ -78,6 +78,7 @@ def __init__( pagination_default_offset: Optional[int] = None, pagination_default_limit: Optional[int] = None, methods: Iterable[str] = (), + max_cache_size: int = 0, ) -> None: """ Initialize router items. @@ -127,7 +128,7 @@ def __init__( self.pagination_default_number: Optional[int] = pagination_default_number self.pagination_default_offset: Optional[int] = pagination_default_offset self.pagination_default_limit: Optional[int] = pagination_default_limit - self.schema_builder = SchemaBuilder(resource_type=resource_type) + self.schema_builder = SchemaBuilder(resource_type=resource_type, max_cache_size=max_cache_size) dto = self.schema_builder.create_schemas( schema=schema, diff --git a/fastapi_jsonapi/schema_builder.py b/fastapi_jsonapi/schema_builder.py index 3db08eeb..7601e2d1 100644 --- a/fastapi_jsonapi/schema_builder.py +++ b/fastapi_jsonapi/schema_builder.py @@ -1,5 +1,6 @@ """JSON API schemas builder class.""" from dataclasses import dataclass +from functools import lru_cache from typing import ( Any, Callable, @@ -122,8 +123,16 @@ class SchemaBuilder: def __init__( self, resource_type: str, + max_cache_size: int = 0, ): self._resource_type = resource_type + self._init_cache(max_cache_size) + + def _init_cache(self, max_cache_size: int): + # TODO: remove crutch + self._get_info_from_schema_for_building_cached = lru_cache(maxsize=max_cache_size)( + self._get_info_from_schema_for_building_cached, + ) def _create_schemas_objects_list(self, schema: Type[BaseModel]) -> Type[JSONAPIResultListSchema]: object_jsonapi_list_schema, list_jsonapi_schema = self.build_list_schemas(schema) @@ -187,7 +196,7 @@ def build_schema_in( ) -> Tuple[Type[BaseJSONAPIDataInSchema], Type[BaseJSONAPIItemInSchema]]: base_schema_name = schema_in.__name__.removesuffix("Schema") + schema_name_suffix - dto = self._get_info_from_schema_for_building( + dto = self._get_info_from_schema_for_building_wrapper( base_name=base_schema_name, schema=schema_in, non_optional_relationships=non_optional_relationships, @@ -258,6 +267,40 @@ def build_list_schemas( includes=includes, ) + def _get_info_from_schema_for_building_cached( + self, + base_name: str, + schema: Type[BaseModel], + includes: Iterable[str], + non_optional_relationships: bool, + ): + return self._get_info_from_schema_for_building( + base_name=base_name, + schema=schema, + includes=includes, + non_optional_relationships=non_optional_relationships, + ) + + def _get_info_from_schema_for_building_wrapper( + self, + base_name: str, + schema: Type[BaseModel], + includes: Iterable[str] = not_passed, + non_optional_relationships: bool = False, + ): + """ + Wrapper function for return cached schema result + """ + if includes is not not_passed: + includes = tuple(includes) + + return self._get_info_from_schema_for_building_cached( + base_name=base_name, + schema=schema, + includes=includes, + non_optional_relationships=non_optional_relationships, + ) + def _get_info_from_schema_for_building( self, base_name: str, @@ -494,7 +537,7 @@ def create_jsonapi_object_schemas( if includes is not not_passed: includes = set(includes) - dto = self._get_info_from_schema_for_building( + dto = self._get_info_from_schema_for_building_wrapper( base_name=base_name, schema=schema, includes=includes, diff --git a/tests/fixtures/app.py b/tests/fixtures/app.py index b0a68075..70998e58 100644 --- a/tests/fixtures/app.py +++ b/tests/fixtures/app.py @@ -232,10 +232,11 @@ def build_app_custom( resource_type: str = "misc", class_list: Type[ListViewBase] = ListViewBaseGeneric, class_detail: Type[DetailViewBase] = DetailViewBaseGeneric, + max_cache_size: int = 0, ) -> FastAPI: router: APIRouter = APIRouter() - RoutersJSONAPI( + jsonapi_routers = RoutersJSONAPI( router=router, path=path, tags=["Misc"], @@ -246,6 +247,7 @@ def build_app_custom( schema_in_patch=schema_in_patch, schema_in_post=schema_in_post, model=model, + max_cache_size=max_cache_size, ) app = build_app_plain() @@ -254,6 +256,9 @@ def build_app_custom( atomic = AtomicOperations() app.include_router(atomic.router, prefix="") init(app) + + app.jsonapi_routers = jsonapi_routers + return app diff --git a/tests/test_api/test_api_sqla_with_includes.py b/tests/test_api/test_api_sqla_with_includes.py index 82f3941d..0ea5ab9d 100644 --- a/tests/test_api/test_api_sqla_with_includes.py +++ b/tests/test_api/test_api_sqla_with_includes.py @@ -6,6 +6,7 @@ from itertools import chain, zip_longest from json import dumps, loads from typing import Dict, List, Literal, Set, Tuple +from unittest.mock import call, patch from uuid import UUID, uuid4 import pytest @@ -20,6 +21,7 @@ from starlette.datastructures import QueryParams from fastapi_jsonapi.api import RoutersJSONAPI +from fastapi_jsonapi.schema_builder import SchemaBuilder from fastapi_jsonapi.views.view_base import ViewBase from tests.common import is_postgres_tests from tests.fixtures.app import build_alphabet_app, build_app_custom @@ -52,6 +54,8 @@ CustomUUIDItemAttributesSchema, PostAttributesBaseSchema, PostCommentAttributesBaseSchema, + PostCommentSchema, + PostSchema, SelfRelationshipAttributesSchema, SelfRelationshipSchema, UserAttributesBaseSchema, @@ -360,6 +364,215 @@ async def test_select_custom_fields_for_includes_without_requesting_includes( "meta": {"count": 1, "totalPages": 1}, } + def _get_clear_mock_calls(self, mock_obj) -> list[call]: + mock_calls = mock_obj.mock_calls + return [call_ for call_ in mock_calls if call_ not in [call.__len__(), call.__str__()]] + + def _prepare_info_schema_calls_to_assert(self, mock_calls) -> list[call]: + calls_to_check = [] + for wrapper_call in mock_calls: + kwargs = wrapper_call.kwargs + kwargs["includes"] = sorted(kwargs["includes"], key=lambda x: x) + + calls_to_check.append( + call( + *wrapper_call.args, + **kwargs, + ), + ) + + return sorted( + calls_to_check, + key=lambda x: (x.kwargs["base_name"], x.kwargs["includes"]), + ) + + async def test_check_get_info_schema_cache( + self, + user_1: User, + ): + resource_type = "user_with_cache" + with suppress(KeyError): + RoutersJSONAPI.all_jsonapi_routers.pop(resource_type) + + app_with_cache = build_app_custom( + model=User, + schema=UserSchema, + schema_in_post=UserInSchemaAllowIdOnPost, + schema_in_patch=UserPatchSchema, + resource_type=resource_type, + # set cache size to enable caching + max_cache_size=128, + ) + + target_func_name = "_get_info_from_schema_for_building" + url = app_with_cache.url_path_for(f"get_{resource_type}_list") + params = { + "include": "posts,posts.comments", + } + + expected_len_with_cache = 6 + expected_len_without_cache = 10 + + with patch.object( + SchemaBuilder, + target_func_name, + wraps=app_with_cache.jsonapi_routers.schema_builder._get_info_from_schema_for_building, + ) as wrapped_func: + async with AsyncClient(app=app_with_cache, base_url="http://test") as client: + response = await client.get(url, params=params) + assert response.status_code == status.HTTP_200_OK, response.text + + calls_to_check = self._prepare_info_schema_calls_to_assert(self._get_clear_mock_calls(wrapped_func)) + + # there are no duplicated calls + assert calls_to_check == sorted( + [ + call( + base_name="UserSchema", + schema=UserSchema, + includes=["posts"], + non_optional_relationships=False, + ), + call( + base_name="UserSchema", + schema=UserSchema, + includes=["posts", "posts.comments"], + non_optional_relationships=False, + ), + call( + base_name="PostSchema", + schema=PostSchema, + includes=[], + non_optional_relationships=False, + ), + call( + base_name="PostSchema", + schema=PostSchema, + includes=["comments"], + non_optional_relationships=False, + ), + call( + base_name="PostCommentSchema", + schema=PostCommentSchema, + includes=[], + non_optional_relationships=False, + ), + call( + base_name="PostCommentSchema", + schema=PostCommentSchema, + includes=["posts"], + non_optional_relationships=False, + ), + ], + key=lambda x: (x.kwargs["base_name"], x.kwargs["includes"]), + ) + assert wrapped_func.call_count == expected_len_with_cache + + response = await client.get(url, params=params) + assert response.status_code == status.HTTP_200_OK, response.text + + # there are no new calls + assert wrapped_func.call_count == expected_len_with_cache + + resource_type = "user_without_cache" + with suppress(KeyError): + RoutersJSONAPI.all_jsonapi_routers.pop(resource_type) + + app_without_cache = build_app_custom( + model=User, + schema=UserSchema, + schema_in_post=UserInSchemaAllowIdOnPost, + schema_in_patch=UserPatchSchema, + resource_type=resource_type, + max_cache_size=0, + ) + + with patch.object( + SchemaBuilder, + target_func_name, + wraps=app_without_cache.jsonapi_routers.schema_builder._get_info_from_schema_for_building, + ) as wrapped_func: + async with AsyncClient(app=app_without_cache, base_url="http://test") as client: + response = await client.get(url, params=params) + assert response.status_code == status.HTTP_200_OK, response.text + + calls_to_check = self._prepare_info_schema_calls_to_assert(self._get_clear_mock_calls(wrapped_func)) + + # there are duplicated calls + assert calls_to_check == sorted( + [ + call( + base_name="UserSchema", + schema=UserSchema, + includes=["posts"], + non_optional_relationships=False, + ), + call( + base_name="UserSchema", + schema=UserSchema, + includes=["posts"], + non_optional_relationships=False, + ), # duplicate + call( + base_name="UserSchema", + schema=UserSchema, + includes=["posts", "posts.comments"], + non_optional_relationships=False, + ), + call( + base_name="PostSchema", + schema=PostSchema, + includes=[], + non_optional_relationships=False, + ), + call( + base_name="PostSchema", + schema=PostSchema, + includes=[], + non_optional_relationships=False, + ), # duplicate + call( + base_name="PostSchema", + schema=PostSchema, + includes=[], + non_optional_relationships=False, + ), # duplicate + call( + base_name="PostSchema", + schema=PostSchema, + includes=["comments"], + non_optional_relationships=False, + ), + call( + base_name="PostSchema", + schema=PostSchema, + includes=["comments"], + non_optional_relationships=False, + ), # duplicate + call( + base_name="PostCommentSchema", + schema=PostCommentSchema, + includes=[], + non_optional_relationships=False, + ), + call( + base_name="PostCommentSchema", + schema=PostCommentSchema, + includes=["posts"], + non_optional_relationships=False, + ), # duplicate + ], + key=lambda x: (x.kwargs["base_name"], x.kwargs["includes"]), + ) + + assert wrapped_func.call_count == expected_len_without_cache + + response = await client.get(url, params=params) + assert response.status_code == status.HTTP_200_OK, response.text + + # there are new calls + assert wrapped_func.call_count == expected_len_without_cache * 2 + class TestCreatePostAndComments: async def test_get_posts_with_users( @@ -371,6 +584,13 @@ async def test_get_posts_with_users( user_1_posts: List[Post], user_2_posts: List[Post], ): + call( + base_name="UserSchema", + schema=UserSchema, + includes=["posts"], + non_optional_relationships=False, + on_optional_relationships=False, + ) url = app.url_path_for("get_post_list") url = f"{url}?include=user" response = await client.get(url)