Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sql alchemy integration #27

Merged
merged 14 commits into from Jan 15, 2022
2 changes: 1 addition & 1 deletion .flake8
@@ -1,4 +1,4 @@
[flake8]
max-line-length = 120
max-complexity = 12
ignore = E501, C408, B008, W503
ignore = E501, C408, B008, B009, W503
4 changes: 3 additions & 1 deletion .pre-commit-config.yaml
Expand Up @@ -59,9 +59,10 @@ repos:
pydantic_factories,
pyyaml,
starlette,
sqlalchemy
]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v0.930"
rev: "v0.931"
hooks:
- id: mypy
exclude: "test_*"
Expand All @@ -74,4 +75,5 @@ repos:
pydantic,
pydantic_factories,
starlette,
sqlalchemy
]
2 changes: 1 addition & 1 deletion mypy.ini
@@ -1,5 +1,5 @@
[mypy]
# plugins = pydantic.mypy
plugins = pydantic.mypy, sqlalchemy.ext.mypy.plugin

warn_unused_ignores = True
warn_redundant_casts = True
Expand Down
244 changes: 220 additions & 24 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Expand Up @@ -50,7 +50,7 @@ pytest = "*"
pytest-asyncio = "*"
pytest-cov = "*"
uvicorn = "*"
anyio = "^3.4.0"
sqlalchemy = {extras = ["mypy"], version = "*"}

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
34 changes: 29 additions & 5 deletions starlite/app.py
@@ -1,9 +1,9 @@
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, cast

from openapi_schema_pydantic import OpenAPI, Schema
from openapi_schema_pydantic.util import construct_open_api_with_schema_class
from pydantic import Extra, validate_arguments
from pydantic.typing import NoArgAnyCallable
from pydantic.typing import AnyCallable, NoArgAnyCallable
from starlette.datastructures import State
from starlette.exceptions import ExceptionMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException
Expand All @@ -20,7 +20,9 @@
from starlite.config import CORSConfig, OpenAPIConfig
from starlite.enums import MediaType
from starlite.exceptions import HTTPException
from starlite.handlers import BaseRouteHandler
from starlite.openapi.path_item import create_path_item
from starlite.plugins.base import PluginProtocol
from starlite.provide import Provide
from starlite.request import Request
from starlite.response import Response
Expand All @@ -32,6 +34,7 @@
MiddlewareProtocol,
ResponseHeader,
)
from starlite.utils import create_function_signature_model

DEFAULT_OPENAPI_CONFIG = OpenAPIConfig(title="Starlite API", version="1.0.0")

Expand All @@ -41,6 +44,7 @@ class Starlite(Router):
def __init__(
self,
*,
route_handlers: List[ControllerRouterHandler],
allowed_hosts: Optional[List[str]] = None,
cors_config: Optional[CORSConfig] = None,
debug: bool = False,
Expand All @@ -54,10 +58,11 @@ def __init__(
redirect_slashes: bool = True,
response_class: Optional[Type[Response]] = None,
response_headers: Optional[Dict[str, ResponseHeader]] = None,
route_handlers: List[ControllerRouterHandler],
plugins: Optional[List[PluginProtocol]] = None
):
self.debug = debug
self.state = State()
self.plugins = plugins or []
super().__init__(
dependencies=dependencies,
guards=guards,
Expand Down Expand Up @@ -89,11 +94,30 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
scope["app"] = self
await self.middleware_stack(scope, receive, send)

def register(self, value: ControllerRouterHandler) -> None:
super().register(value=value)
def register(self, value: ControllerRouterHandler) -> None: # type: ignore[override]
"""
Register a Controller, Route instance or RouteHandler on the app.

Calls Router.register() and then creates a signature model for all handlers.
"""
handlers = super().register(value=value)
for route_handler in handlers:
self.create_handler_signature_model(route_handler=route_handler)
if hasattr(self, "asgi_router"):
self.asgi_router.routes = self.routes # type: ignore

def create_handler_signature_model(self, route_handler: BaseRouteHandler) -> None:
"""
Creates function signature models for all route handler functions and provider dependencies
"""
if not route_handler.signature_model:
route_handler.signature_model = create_function_signature_model(
fn=cast(AnyCallable, route_handler.fn), plugins=self.plugins
)
for provider in list(route_handler.resolve_dependencies().values()):
if not provider.signature_model:
provider.signature_model = create_function_signature_model(fn=provider.dependency, plugins=self.plugins)

def build_middleware_stack(
self,
user_middleware: List[Union[Middleware, Type[BaseHTTPMiddleware], Type[MiddlewareProtocol]]],
Expand Down
4 changes: 4 additions & 0 deletions starlite/exceptions.py
Expand Up @@ -24,6 +24,10 @@ def __repr__(self) -> str:
return self.__class__.__name__


class MissingDependencyException(ImportError, StarLiteException):
pass


class HTTPException(StarLiteException, StarletteHTTPException):
status_code = HTTP_500_INTERNAL_SERVER_ERROR
extra: Optional[Dict[str, Any]] = None
Expand Down
69 changes: 49 additions & 20 deletions starlite/handlers.py
Expand Up @@ -21,11 +21,12 @@
ImproperlyConfiguredException,
ValidationException,
)
from starlite.plugins.base import PluginMapping, get_plugin_for_value
from starlite.provide import Provide
from starlite.request import Request, WebSocket, get_model_kwargs_from_connection
from starlite.response import Response
from starlite.types import File, Guard, Redirect, ResponseHeader, Stream
from starlite.utils.model import create_function_signature_model
from starlite.utils import SignatureModel

if TYPE_CHECKING: # pragma: no cover
from starlite.routing import Router
Expand All @@ -35,6 +36,16 @@ class _empty:
"""Placeholder"""


def get_signature_model(value: Any) -> Type[SignatureModel]:
"""
Helper function to retrieve and validate the signature model from a provider or handler
"""
try:
return cast(Type[SignatureModel], getattr(value, "signature_model"))
except AttributeError as e: # pragma: no cover
raise ImproperlyConfiguredException(f"The 'signature_model' attribute for {value} is not set") from e


class BaseRouteHandler(BaseModel):
class Config:
arbitrary_types_allowed = True
Expand All @@ -49,7 +60,7 @@ class Config:
owner: Optional[Union[Controller, "Router"]] = None
resolved_dependencies: Union[Dict[str, Provide], Type[_empty]] = _empty
resolved_guards: Union[List[Guard], Type[_empty]] = _empty
signature_model: Optional[Type[BaseModel]] = None
signature_model: Optional[Type[SignatureModel]] = None

def resolve_guards(self) -> List[Guard]:
"""Returns all guards in the handlers scope, starting from highest to current layer"""
Expand Down Expand Up @@ -113,30 +124,46 @@ async def get_parameters_from_connection(self, connection: HTTPConnection) -> Di
"""
Parse the signature_model of the route handler return values matching function parameter keys as well as dependencies
"""
assert self.signature_model, "route handler has no signature model"
signature_model = get_signature_model(self)
try:
# dependency injection
dependencies: Dict[str, Any] = {}
for key, provider in self.resolve_dependencies().items():
provider_signature_model = get_signature_model(provider)
provider_kwargs = await get_model_kwargs_from_connection(
connection=connection, fields=provider.signature_model.__fields__
connection=connection, fields=provider_signature_model.__fields__
)
value = provider(**provider.signature_model(**provider_kwargs).dict())
value = provider(**provider_signature_model(**provider_kwargs).dict())
if isawaitable(value):
value = await value
dependencies[key] = value
model_kwargs = await get_model_kwargs_from_connection(
connection=connection,
fields={k: v for k, v in self.signature_model.__fields__.items() if k not in dependencies},
fields={k: v for k, v in signature_model.__fields__.items() if k not in dependencies},
)
# we return the model's attributes as a dict in order to preserve any nested models
fields = list(self.signature_model.__fields__.keys())
return {
key: self.signature_model( # pylint: disable=not-callable
**model_kwargs, **dependencies
).__getattribute__(key)
for key in fields
}
fields = list(signature_model.__fields__.keys())

output: Dict[str, Any] = {}
modelled_signature = signature_model(**model_kwargs, **dependencies)
for key in fields:
value = modelled_signature.__getattribute__(key)
plugin_mapping: Optional[PluginMapping] = signature_model.field_plugin_mappings.get(key)
if plugin_mapping:
if isinstance(value, (list, tuple)):
output[key] = [
plugin_mapping.plugin.from_pydantic_model_instance(
plugin_mapping.model_class, pydantic_model_instance=v
)
for v in value
]
else:
output[key] = plugin_mapping.plugin.from_pydantic_model_instance(
plugin_mapping.model_class, pydantic_model_instance=value
)
else:
output[key] = value
return output
except ValidationError as e:
raise ValidationException(
detail=f"Validation failed for {connection.method if isinstance(connection, Request) else 'websocket'} {connection.url}:\n\n{display_errors(e.errors())}"
Expand Down Expand Up @@ -170,7 +197,6 @@ def __call__(self, fn: AnyCallable) -> "HTTPRouteHandler":
Replaces a function with itself
"""
self.fn = fn
self.signature_model = create_function_signature_model(fn)
self.validate_handler_function()
return self

Expand Down Expand Up @@ -295,7 +321,12 @@ async def handle_request(self, request: Request) -> StarletteResponse:
return StreamingResponse(
content=data.iterator, status_code=status_code, media_type=media_type, headers=headers
)

plugin = get_plugin_for_value(data, request.app.plugins)
if plugin:
if isinstance(data, (list, tuple)):
data = [plugin.to_dict(datum) for datum in data]
else:
data = plugin.to_dict(data)
response_class = self.resolve_response_class()
return response_class(
headers=headers,
Expand Down Expand Up @@ -334,7 +365,6 @@ def __call__(self, fn: AnyCallable) -> "WebsocketRouteHandler":
Replaces a function with itself
"""
self.fn = fn
self.signature_model = create_function_signature_model(fn)
self.validate_handler_function()
return self

Expand All @@ -343,11 +373,10 @@ def validate_handler_function(self) -> None:
Validates the route handler function once it's set by inspecting its return annotations
"""
super().validate_handler_function()
signature_model = cast(BaseModel, self.signature_model)
return_annotation = Signature.from_callable(cast(AnyCallable, self.fn)).return_annotation
signature = Signature.from_callable(cast(AnyCallable, self.fn))

assert return_annotation is None, "websocket handler functions should return 'None' values"
assert "socket" in signature_model.__fields__, "websocket handlers must set a 'socket' kwarg"
assert signature.return_annotation is None, "websocket handler functions should return 'None' values"
assert "socket" in signature.parameters, "websocket handlers must set a 'socket' kwarg"

async def handle_websocket(self, web_socket: WebSocket) -> None:
"""
Expand Down
7 changes: 3 additions & 4 deletions starlite/openapi/path_item.py
@@ -1,13 +1,13 @@
from typing import TYPE_CHECKING, cast

from openapi_schema_pydantic import Operation, PathItem
from pydantic import BaseModel
from pydantic.typing import AnyCallable
from starlette.routing import get_name

from starlite.openapi.parameters import create_parameters
from starlite.openapi.request_body import create_request_body
from starlite.openapi.responses import create_responses
from starlite.utils.model import create_function_signature_model

if TYPE_CHECKING: # pragma: no cover
from starlite.routing import HTTPRoute
Expand All @@ -20,8 +20,7 @@ def create_path_item(route: "HTTPRoute", create_examples: bool) -> PathItem:
path_item = PathItem()
for http_method, route_handler in route.route_handler_map.items():
if route_handler.include_in_schema:
route_handler_fn = cast(AnyCallable, route_handler.fn)
handler_fields = create_function_signature_model(fn=route_handler_fn).__fields__
handler_fields = cast(BaseModel, route_handler.signature_model).__fields__
parameters = (
create_parameters(
route_handler=route_handler,
Expand All @@ -32,7 +31,7 @@ def create_path_item(route: "HTTPRoute", create_examples: bool) -> PathItem:
or None
)
raises_validation_error = bool("data" in handler_fields or path_item.parameters or parameters)
handler_name = get_name(route_handler_fn)
handler_name = get_name(cast(AnyCallable, route_handler.fn))
request_body = None
if "data" in handler_fields:
request_body = create_request_body(field=handler_fields["data"], generate_examples=create_examples)
Expand Down
Empty file added starlite/plugins/__init__.py
Empty file.
51 changes: 51 additions & 0 deletions starlite/plugins/base.py
@@ -0,0 +1,51 @@
from typing import Any, Dict, List, NamedTuple, Optional, TypeVar

from pydantic import BaseModel
from typing_extensions import Protocol, Type, get_args, runtime_checkable

T = TypeVar("T")


@runtime_checkable
class PluginProtocol(Protocol[T]): # pragma: no cover
def to_pydantic_model_class(self, model_class: Type[T], **kwargs: Any) -> Type[BaseModel]: # pragma: no cover
"""
Given a model_class T, convert it to a pydantic model class
"""
...

@staticmethod
def is_plugin_supported_type(value: Any) -> bool:
"""
Given a value of indeterminate type, determine if this value is supported by the plugin.
"""
...

def from_pydantic_model_instance(self, model_class: Type[T], pydantic_model_instance: BaseModel) -> T:
"""
Given a dict of parsed values, create an instance of the plugin's model class
"""
...

def to_dict(self, model_instance: T) -> Dict[str, Any]:
"""
Given an instance of the model, return a dictionary of values that can be serialized
"""
...


def get_plugin_for_value(value: Any, plugins: List[PluginProtocol]) -> Optional[PluginProtocol]:
"""Helper function to returns a plugin to handle a given value, if any plugin supports it"""
if value and isinstance(value, (list, tuple)):
value = value[0]
if get_args(value):
value = get_args(value)[0]
for plugin in plugins:
if plugin.is_plugin_supported_type(value):
return plugin
return None


class PluginMapping(NamedTuple):
plugin: PluginProtocol
model_class: Any