Skip to content

Commit

Permalink
added simple cache (#87)
Browse files Browse the repository at this point in the history
* added simple cache

* micro fix style

---------

Co-authored-by: Suren Khorenyan <s.khorenyan@mts.ai>
  • Loading branch information
CosmoV and mahenzon committed Apr 27, 2024
1 parent b9cfe93 commit 907d0a8
Show file tree
Hide file tree
Showing 4 changed files with 273 additions and 4 deletions.
3 changes: 2 additions & 1 deletion fastapi_jsonapi/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(
pagination_default_offset: Optional[int] = None,
pagination_default_limit: Optional[int] = None,
methods: Iterable[str] = (),
max_cache_size: int = 0,
) -> None:
"""
Initialize router items.
Expand Down Expand Up @@ -127,7 +128,7 @@ def __init__(
self.pagination_default_number: Optional[int] = pagination_default_number
self.pagination_default_offset: Optional[int] = pagination_default_offset
self.pagination_default_limit: Optional[int] = pagination_default_limit
self.schema_builder = SchemaBuilder(resource_type=resource_type)
self.schema_builder = SchemaBuilder(resource_type=resource_type, max_cache_size=max_cache_size)

dto = self.schema_builder.create_schemas(
schema=schema,
Expand Down
47 changes: 45 additions & 2 deletions fastapi_jsonapi/schema_builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""JSON API schemas builder class."""
from dataclasses import dataclass
from functools import lru_cache
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -122,8 +123,16 @@ class SchemaBuilder:
def __init__(
self,
resource_type: str,
max_cache_size: int = 0,
):
self._resource_type = resource_type
self._init_cache(max_cache_size)

def _init_cache(self, max_cache_size: int):
# TODO: remove crutch
self._get_info_from_schema_for_building_cached = lru_cache(maxsize=max_cache_size)(
self._get_info_from_schema_for_building_cached,
)

def _create_schemas_objects_list(self, schema: Type[BaseModel]) -> Type[JSONAPIResultListSchema]:
object_jsonapi_list_schema, list_jsonapi_schema = self.build_list_schemas(schema)
Expand Down Expand Up @@ -187,7 +196,7 @@ def build_schema_in(
) -> Tuple[Type[BaseJSONAPIDataInSchema], Type[BaseJSONAPIItemInSchema]]:
base_schema_name = schema_in.__name__.removesuffix("Schema") + schema_name_suffix

dto = self._get_info_from_schema_for_building(
dto = self._get_info_from_schema_for_building_wrapper(
base_name=base_schema_name,
schema=schema_in,
non_optional_relationships=non_optional_relationships,
Expand Down Expand Up @@ -258,6 +267,40 @@ def build_list_schemas(
includes=includes,
)

def _get_info_from_schema_for_building_cached(
self,
base_name: str,
schema: Type[BaseModel],
includes: Iterable[str],
non_optional_relationships: bool,
):
return self._get_info_from_schema_for_building(
base_name=base_name,
schema=schema,
includes=includes,
non_optional_relationships=non_optional_relationships,
)

def _get_info_from_schema_for_building_wrapper(
self,
base_name: str,
schema: Type[BaseModel],
includes: Iterable[str] = not_passed,
non_optional_relationships: bool = False,
):
"""
Wrapper function for return cached schema result
"""
if includes is not not_passed:
includes = tuple(includes)

return self._get_info_from_schema_for_building_cached(
base_name=base_name,
schema=schema,
includes=includes,
non_optional_relationships=non_optional_relationships,
)

def _get_info_from_schema_for_building(
self,
base_name: str,
Expand Down Expand Up @@ -494,7 +537,7 @@ def create_jsonapi_object_schemas(
if includes is not not_passed:
includes = set(includes)

dto = self._get_info_from_schema_for_building(
dto = self._get_info_from_schema_for_building_wrapper(
base_name=base_name,
schema=schema,
includes=includes,
Expand Down
7 changes: 6 additions & 1 deletion tests/fixtures/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,11 @@ def build_app_custom(
resource_type: str = "misc",
class_list: Type[ListViewBase] = ListViewBaseGeneric,
class_detail: Type[DetailViewBase] = DetailViewBaseGeneric,
max_cache_size: int = 0,
) -> FastAPI:
router: APIRouter = APIRouter()

RoutersJSONAPI(
jsonapi_routers = RoutersJSONAPI(
router=router,
path=path,
tags=["Misc"],
Expand All @@ -246,6 +247,7 @@ def build_app_custom(
schema_in_patch=schema_in_patch,
schema_in_post=schema_in_post,
model=model,
max_cache_size=max_cache_size,
)

app = build_app_plain()
Expand All @@ -254,6 +256,9 @@ def build_app_custom(
atomic = AtomicOperations()
app.include_router(atomic.router, prefix="")
init(app)

app.jsonapi_routers = jsonapi_routers

return app


Expand Down
220 changes: 220 additions & 0 deletions tests/test_api/test_api_sqla_with_includes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from itertools import chain, zip_longest
from json import dumps, loads
from typing import Dict, List, Literal, Set, Tuple
from unittest.mock import call, patch
from uuid import UUID, uuid4

import pytest
Expand All @@ -20,6 +21,7 @@
from starlette.datastructures import QueryParams

from fastapi_jsonapi.api import RoutersJSONAPI
from fastapi_jsonapi.schema_builder import SchemaBuilder
from fastapi_jsonapi.views.view_base import ViewBase
from tests.common import is_postgres_tests
from tests.fixtures.app import build_alphabet_app, build_app_custom
Expand Down Expand Up @@ -52,6 +54,8 @@
CustomUUIDItemAttributesSchema,
PostAttributesBaseSchema,
PostCommentAttributesBaseSchema,
PostCommentSchema,
PostSchema,
SelfRelationshipAttributesSchema,
SelfRelationshipSchema,
UserAttributesBaseSchema,
Expand Down Expand Up @@ -360,6 +364,215 @@ async def test_select_custom_fields_for_includes_without_requesting_includes(
"meta": {"count": 1, "totalPages": 1},
}

def _get_clear_mock_calls(self, mock_obj) -> list[call]:
mock_calls = mock_obj.mock_calls
return [call_ for call_ in mock_calls if call_ not in [call.__len__(), call.__str__()]]

def _prepare_info_schema_calls_to_assert(self, mock_calls) -> list[call]:
calls_to_check = []
for wrapper_call in mock_calls:
kwargs = wrapper_call.kwargs
kwargs["includes"] = sorted(kwargs["includes"], key=lambda x: x)

calls_to_check.append(
call(
*wrapper_call.args,
**kwargs,
),
)

return sorted(
calls_to_check,
key=lambda x: (x.kwargs["base_name"], x.kwargs["includes"]),
)

async def test_check_get_info_schema_cache(
self,
user_1: User,
):
resource_type = "user_with_cache"
with suppress(KeyError):
RoutersJSONAPI.all_jsonapi_routers.pop(resource_type)

app_with_cache = build_app_custom(
model=User,
schema=UserSchema,
schema_in_post=UserInSchemaAllowIdOnPost,
schema_in_patch=UserPatchSchema,
resource_type=resource_type,
# set cache size to enable caching
max_cache_size=128,
)

target_func_name = "_get_info_from_schema_for_building"
url = app_with_cache.url_path_for(f"get_{resource_type}_list")
params = {
"include": "posts,posts.comments",
}

expected_len_with_cache = 6
expected_len_without_cache = 10

with patch.object(
SchemaBuilder,
target_func_name,
wraps=app_with_cache.jsonapi_routers.schema_builder._get_info_from_schema_for_building,
) as wrapped_func:
async with AsyncClient(app=app_with_cache, base_url="http://test") as client:
response = await client.get(url, params=params)
assert response.status_code == status.HTTP_200_OK, response.text

calls_to_check = self._prepare_info_schema_calls_to_assert(self._get_clear_mock_calls(wrapped_func))

# there are no duplicated calls
assert calls_to_check == sorted(
[
call(
base_name="UserSchema",
schema=UserSchema,
includes=["posts"],
non_optional_relationships=False,
),
call(
base_name="UserSchema",
schema=UserSchema,
includes=["posts", "posts.comments"],
non_optional_relationships=False,
),
call(
base_name="PostSchema",
schema=PostSchema,
includes=[],
non_optional_relationships=False,
),
call(
base_name="PostSchema",
schema=PostSchema,
includes=["comments"],
non_optional_relationships=False,
),
call(
base_name="PostCommentSchema",
schema=PostCommentSchema,
includes=[],
non_optional_relationships=False,
),
call(
base_name="PostCommentSchema",
schema=PostCommentSchema,
includes=["posts"],
non_optional_relationships=False,
),
],
key=lambda x: (x.kwargs["base_name"], x.kwargs["includes"]),
)
assert wrapped_func.call_count == expected_len_with_cache

response = await client.get(url, params=params)
assert response.status_code == status.HTTP_200_OK, response.text

# there are no new calls
assert wrapped_func.call_count == expected_len_with_cache

resource_type = "user_without_cache"
with suppress(KeyError):
RoutersJSONAPI.all_jsonapi_routers.pop(resource_type)

app_without_cache = build_app_custom(
model=User,
schema=UserSchema,
schema_in_post=UserInSchemaAllowIdOnPost,
schema_in_patch=UserPatchSchema,
resource_type=resource_type,
max_cache_size=0,
)

with patch.object(
SchemaBuilder,
target_func_name,
wraps=app_without_cache.jsonapi_routers.schema_builder._get_info_from_schema_for_building,
) as wrapped_func:
async with AsyncClient(app=app_without_cache, base_url="http://test") as client:
response = await client.get(url, params=params)
assert response.status_code == status.HTTP_200_OK, response.text

calls_to_check = self._prepare_info_schema_calls_to_assert(self._get_clear_mock_calls(wrapped_func))

# there are duplicated calls
assert calls_to_check == sorted(
[
call(
base_name="UserSchema",
schema=UserSchema,
includes=["posts"],
non_optional_relationships=False,
),
call(
base_name="UserSchema",
schema=UserSchema,
includes=["posts"],
non_optional_relationships=False,
), # duplicate
call(
base_name="UserSchema",
schema=UserSchema,
includes=["posts", "posts.comments"],
non_optional_relationships=False,
),
call(
base_name="PostSchema",
schema=PostSchema,
includes=[],
non_optional_relationships=False,
),
call(
base_name="PostSchema",
schema=PostSchema,
includes=[],
non_optional_relationships=False,
), # duplicate
call(
base_name="PostSchema",
schema=PostSchema,
includes=[],
non_optional_relationships=False,
), # duplicate
call(
base_name="PostSchema",
schema=PostSchema,
includes=["comments"],
non_optional_relationships=False,
),
call(
base_name="PostSchema",
schema=PostSchema,
includes=["comments"],
non_optional_relationships=False,
), # duplicate
call(
base_name="PostCommentSchema",
schema=PostCommentSchema,
includes=[],
non_optional_relationships=False,
),
call(
base_name="PostCommentSchema",
schema=PostCommentSchema,
includes=["posts"],
non_optional_relationships=False,
), # duplicate
],
key=lambda x: (x.kwargs["base_name"], x.kwargs["includes"]),
)

assert wrapped_func.call_count == expected_len_without_cache

response = await client.get(url, params=params)
assert response.status_code == status.HTTP_200_OK, response.text

# there are new calls
assert wrapped_func.call_count == expected_len_without_cache * 2


class TestCreatePostAndComments:
async def test_get_posts_with_users(
Expand All @@ -371,6 +584,13 @@ async def test_get_posts_with_users(
user_1_posts: List[Post],
user_2_posts: List[Post],
):
call(
base_name="UserSchema",
schema=UserSchema,
includes=["posts"],
non_optional_relationships=False,
on_optional_relationships=False,
)
url = app.url_path_for("get_post_list")
url = f"{url}?include=user"
response = await client.get(url)
Expand Down

0 comments on commit 907d0a8

Please sign in to comment.