Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Do not merge] Example of adding hooks to mutation #78

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions platformics/api/core/strawberry_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from strawberry.extensions import FieldExtension
from strawberry.field import StrawberryField
from strawberry.types import Info
from typing import Any, Awaitable, Callable


def get_func_with_only_deps(func: typing.Callable[..., typing.Any]) -> typing.Callable[..., typing.Any]:
Expand All @@ -35,6 +36,52 @@ def get_func_with_only_deps(func: typing.Callable[..., typing.Any]) -> typing.Ca
return newfunc


class RegisteredPlatformicsPlugins:
plugins: dict[str, typing.Callable[..., typing.Any]] = {}

@classmethod
def register(cls, callback_order: str, type: str, action: str, callback: typing.Callable[..., typing.Any]) -> None:
cls.plugins[f"{callback_order}:{type}:{action}"] = callback

@classmethod
def getCallback(cls, callback_order: str, type: str, action: str) -> typing.Callable[..., typing.Any] | None:
return cls.plugins.get(f"{callback_order}:{type}:{action}")


def register_plugin(callback_order: str, type: str, action: str) -> Callable[..., Callable[..., Any]]:
def decorator_register(func: Callable[..., Any]) -> Callable[..., Any]:
RegisteredPlatformicsPlugins.register(callback_order, type, action, func)
return func

return decorator_register


class PlatformicsPluginExtension(FieldExtension):
def __init__(self, type: str, action: str) -> None:
self.type = type
self.action = action
self.strawberry_field_names = ["self"]

async def resolve_async(
self,
next_: typing.Callable[..., typing.Any],
source: typing.Any,
info: Info,
**kwargs: dict[str, typing.Any],
) -> typing.Any:
before_callback = RegisteredPlatformicsPlugins.getCallback("before", self.type, self.action)
if before_callback:
before_callback(source, info, **kwargs)

result = await next_(source, info, **kwargs)

after_callback = RegisteredPlatformicsPlugins.getCallback("after", self.type, self.action)
if after_callback:
result = after_callback(result, source, info, **kwargs)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should before and after callbacks accept the same number of args?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we'd also send the result to the after callbacks so that they can manipulate the result


return result


class DependencyExtension(FieldExtension):
def __init__(self) -> None:
self.dependency_args: list[typing.Any] = []
Expand Down
8 changes: 4 additions & 4 deletions platformics/codegen/templates/api/types/class_name.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ from fastapi import Depends
from platformics.api.core.errors import PlatformicsError
from platformics.api.core.deps import get_cerbos_client, get_db_session, require_auth_principal, is_system_user
from platformics.api.core.query_input_types import aggregator_map, orderBy, EnumComparators, DatetimeComparators, IntComparators, FloatComparators, StrComparators, UUIDComparators, BoolComparators
from platformics.api.core.strawberry_extensions import DependencyExtension
from platformics.api.core.strawberry_extensions import DependencyExtension, PlatformicsPluginExtension
from platformics.security.authorization import CerbosAction, get_resource_query
from sqlalchemy import inspect
from sqlalchemy.engine.row import RowMapping
Expand Down Expand Up @@ -500,7 +500,7 @@ async def resolve_{{ cls.plural_snake_name }}_aggregate(
return aggregate_output

{%- if cls.create_fields %}
@strawberry.mutation(extensions=[DependencyExtension()])
@strawberry.mutation(extensions=[DependencyExtension(), PlatformicsPluginExtension("{{ cls.snake_name }}", "create")])
async def create_{{ cls.snake_name }}(
input: {{ cls.name }}CreateInput,
session: AsyncSession = Depends(get_db_session, use_cache=False),
Expand Down Expand Up @@ -559,7 +559,7 @@ async def create_{{ cls.snake_name }}(


{%- if cls.mutable_fields %}
@strawberry.mutation(extensions=[DependencyExtension()])
@strawberry.mutation(extensions=[DependencyExtension(), PlatformicsPluginExtension("{{ cls.snake_name }}", "update")])
async def update_{{ cls.snake_name }}(
input: {{ cls.name }}UpdateInput,
where: {{ cls.name }}WhereClauseMutations,
Expand Down Expand Up @@ -634,7 +634,7 @@ async def update_{{ cls.snake_name }}(
{%- endif %}


@strawberry.mutation(extensions=[DependencyExtension()])
@strawberry.mutation(extensions=[DependencyExtension(), PlatformicsPluginExtension("{{ cls.snake_name }}", "delete")])
async def delete_{{ cls.snake_name }}(
where: {{ cls.name }}WhereClauseMutations,
session: AsyncSession = Depends(get_db_session, use_cache=False),
Expand Down
2 changes: 2 additions & 0 deletions test_app/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ async def patched_session() -> typing.AsyncGenerator[AsyncSession, None]:
def raise_exception() -> str:
raise Exception("Unexpected error")


# Subclass Query with an additional field to test Exception handling.
@strawberry.type
class MyQuery(Query):
Expand All @@ -239,6 +240,7 @@ def uncaught_exception(self) -> str:
# Trigger an AttributeException
return self.kaboom # type: ignore


@pytest_asyncio.fixture()
async def api_test_schema(async_db: AsyncDB) -> FastAPI:
"""
Expand Down
11 changes: 11 additions & 0 deletions test_app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,24 @@

import strawberry
import uvicorn
from api.types.sample import SampleCreateInput
from platformics.api.setup import get_app, get_strawberry_config
from platformics.api.core.error_handler import HandleErrors
from platformics.settings import APISettings
from database import models
from platformics.api.core.strawberry_extensions import register_plugin

from api.mutations import Mutation
from api.queries import Query
from typing import Any
from strawberry.types import Info


@register_plugin("before", "sample", "create")
def validate_sample_name(source: Any, info: Info, **kwargs: SampleCreateInput) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think **kwargs might include everything that gets sent to the resolver, so in addition to SampleCreateInput, session, cerbos_client, principal, is_system_user, etc.

if kwargs["input"].name == "foo":
raise ValueError("Sample name cannot be 'foo'")


settings = APISettings.model_validate({}) # Workaround for https://github.com/pydantic/pydantic/issues/3753
schema = strawberry.Schema(query=Query, mutation=Mutation, config=get_strawberry_config(), extensions=[HandleErrors()])
Expand Down