Skip to content

Commit

Permalink
CORS refactor (WIP) (#105)
Browse files Browse the repository at this point in the history
* WIP cors refactor

* fix unit tests

* add cors handler tests
  • Loading branch information
keotl committed Jul 4, 2020
1 parent 86ff0c5 commit 98e8992
Show file tree
Hide file tree
Showing 22 changed files with 264 additions and 114 deletions.
9 changes: 7 additions & 2 deletions e2e_test/app/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import e2e_test.app.static
from e2e_test.app import components
from jivago.config.debug_jivago_context import DebugJivagoContext
from jivago.config.router.cors_rule import CorsRule
from jivago.config.router.filtering.filtering_rule import FilteringRule
from jivago.config.router.router_builder import RouterBuilder
from jivago.jivago_application import JivagoApplication
from jivago.wsgi.filter.system_filters.default_filters import JIVAGO_DEFAULT_FILTERS
from jivago.wsgi.routing.routing_rule import RoutingRule
from jivago.wsgi.routing.serving.static_file_routing_table import StaticFileRoutingTable
from jivago.wsgi.routing.table.auto_discovering_routing_table import AutoDiscoveringRoutingTable
Expand All @@ -25,10 +27,13 @@ def configure_service_locator(self):

def create_router_config(self) -> RouterBuilder:
return RouterBuilder() \
.add_rule(FilteringRule("*", self.get_default_filters())) \
.add_rule(FilteringRule("*", JIVAGO_DEFAULT_FILTERS)) \
.add_rule(RoutingRule("/api", AutoDiscoveringRoutingTable(self.registry, self.root_package_name))) \
.add_rule(RoutingRule("/static", StaticFileRoutingTable(os.path.dirname(e2e_test.app.static.__file__),
allowed_extensions=['.txt'])))
allowed_extensions=['.txt']))) \
.add_rule(CorsRule("/", {'Access-Control-Allow-Origin': 'http://jivago.io',
'Access-Control-Allow-Headers': '*',
'Access-Control-Allow-Methods': '*'}))


application = JivagoApplication(components, context=DemoContext)
Expand Down
9 changes: 7 additions & 2 deletions e2e_test/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
from e2e_test import tests
from e2e_test.app import components
from jivago.config.debug_jivago_context import DebugJivagoContext
from jivago.config.router.cors_rule import CorsRule
from jivago.config.router.filtering.auto_discovering_filtering_rule import AutoDiscoveringFilteringRule
from jivago.config.router.filtering.filtering_rule import FilteringRule
from jivago.config.router.router_builder import RouterBuilder
from jivago.jivago_application import JivagoApplication
from jivago.lang.annotations import Override
from jivago.wsgi.filter.system_filters.default_filters import JIVAGO_DEFAULT_FILTERS
from jivago.wsgi.routing.routing_rule import RoutingRule
from jivago.wsgi.routing.serving.static_file_routing_table import StaticFileRoutingTable
from jivago.wsgi.routing.table.auto_discovering_routing_table import AutoDiscoveringRoutingTable
Expand All @@ -30,11 +32,14 @@ def configure_service_locator(self):

def create_router_config(self) -> RouterBuilder:
return RouterBuilder() \
.add_rule(FilteringRule("*", self.get_default_filters())) \
.add_rule(FilteringRule("*", JIVAGO_DEFAULT_FILTERS)) \
.add_rule(AutoDiscoveringFilteringRule("*", self.registry, self.root_package_name)) \
.add_rule(RoutingRule("/api", AutoDiscoveringRoutingTable(self.registry, self.root_package_name))) \
.add_rule(RoutingRule("/static", StaticFileRoutingTable(os.path.dirname(e2e_test.app.static.__file__),
allowed_extensions=['.txt'])))
allowed_extensions=['.txt']))) \
.add_rule(CorsRule("/", {'Access-Control-Allow-Origin': 'http://jivago.io',
'Access-Control-Allow-Headers': '*',
'Access-Control-Allow-Methods': '*'}))

def get_banner(self) -> List[str]:
return []
Expand Down
53 changes: 53 additions & 0 deletions e2e_test/tests/test_cors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import anachronos

from e2e_test.runner import http


class CorsTests(anachronos.TestCase):

def test_cors_headers_injected_on_success(self):
res = http.get("/")

self.assertEqual("http://jivago.io", res.headers["Access-Control-Allow-Origin"])
self.assertEqual("*", res.headers["Access-Control-Allow-Headers"])
self.assertEqual("*", res.headers["Access-Control-Allow-Methods"])

def test_cors_headers_injected_on_error(self):
res = http.get("/error")

self.assertEqual("http://jivago.io", res.headers["Access-Control-Allow-Origin"])
self.assertEqual("*", res.headers["Access-Control-Allow-Headers"])
self.assertEqual("*", res.headers["Access-Control-Allow-Methods"])

def test_cors_preflight_succeeds_for_allowed_origin(self):
res = http.options("/", headers={"Origin": "http://jivago.io"})

self.assertEqual(200, res.status_code)
self.assertEqual("http://jivago.io", res.headers["Access-Control-Allow-Origin"])
self.assertEqual("*", res.headers["Access-Control-Allow-Headers"])
self.assertEqual("*", res.headers["Access-Control-Allow-Methods"])

def test_cors_preflight_fails_on_missing_origin(self):
res = http.options("/")

self.assertEqual(400, res.status_code)
self.assertEqual("http://jivago.io", res.headers["Access-Control-Allow-Origin"])
self.assertEqual("*", res.headers["Access-Control-Allow-Headers"])
self.assertEqual("*", res.headers["Access-Control-Allow-Methods"])

def test_cors_preflight_fails_on_disallowed_origin(self):
res = http.options("/", headers={"Origin": "http://hello.example.com"})

self.assertEqual(400, res.status_code)
self.assertEqual("http://jivago.io", res.headers["Access-Control-Allow-Origin"])
self.assertEqual("*", res.headers["Access-Control-Allow-Headers"])
self.assertEqual("*", res.headers["Access-Control-Allow-Methods"])

def test_cors_preflight_fails_on_unknown_path(self):
res = http.options("/foo/bar/unknown/path")

self.assertEqual(404, res.status_code)


if __name__ == '__main__':
anachronos.run_tests()
16 changes: 7 additions & 9 deletions jivago/config/debug_jivago_context.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,25 @@
from typing import List, Type

from jivago.config.production_jivago_context import ProductionJivagoContext
from jivago.config.router.cors_rule import CorsRule
from jivago.config.router.router_builder import RouterBuilder
from jivago.lang.annotations import Override
from jivago.lang.registry import Registry
from jivago.wsgi.filter.filter import Filter
from jivago.wsgi.filter.system_filters.error_handling.debug_exception_filter import DebugExceptionFilter
from jivago.wsgi.filter.system_filters.error_handling.unknown_exception_filter import UnknownExceptionFilter


class DebugJivagoContext(ProductionJivagoContext):
"""Jivago context for easy development. Automatically configures CORS requests and provides stacktrace logging."""

def __init__(self, root_package: "Module", registry: Registry, banner: bool = True):
super().__init__(root_package, registry, banner=banner)

@Override
def configure_service_locator(self):
super().configure_service_locator()
self.serviceLocator.bind(UnknownExceptionFilter, DebugExceptionFilter)

@Override
def create_router_config(self) -> RouterBuilder:
return super().create_router_config().add_rule(CorsRule("/", {"Access-Control-Allow-Origin": '*',
'Access-Control-Allow-Headers': '*',
'Access-Control-Allow-Methods': '*'}))

@Override
def get_default_filters(self) -> List[Type[Filter]]:
production_filters = super().get_default_filters()
production_filters.remove(UnknownExceptionFilter)
return [DebugExceptionFilter] + production_filters
29 changes: 10 additions & 19 deletions jivago/config/production_jivago_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import List, Type, Union
from typing import List

from jivago.config.abstract_context import AbstractContext
from jivago.config.exception_mapper_binder import ExceptionMapperBinder
Expand All @@ -26,18 +26,14 @@
from jivago.serialization.deserializer import Deserializer
from jivago.serialization.object_mapper import ObjectMapper
from jivago.serialization.serializer import Serializer
from jivago.templating.template_filter import TemplateFilter
from jivago.templating.view_template_repository import ViewTemplateRepository
from jivago.wsgi.annotations import Resource
from jivago.wsgi.filter.filter import Filter
from jivago.wsgi.filter.system_filters.banner_filter import BannerFilter, DummyBannerFilter
from jivago.wsgi.filter.system_filters.body_serialization_filter import BodySerializationFilter
from jivago.wsgi.filter.system_filters.error_handling.application_exception_filter import ApplicationExceptionFilter
from jivago.wsgi.filter.system_filters.error_handling.unknown_exception_filter import UnknownExceptionFilter
from jivago.wsgi.filter.system_filters.jivago_banner_filter import JivagoBannerFilter
from jivago.wsgi.request.http_form_deserialization_filter import HttpFormDeserializationFilter
from jivago.wsgi.filter.system_filters.default_filters import JIVAGO_DEFAULT_FILTERS
from jivago.wsgi.request.http_status_code_resolver import HttpStatusCodeResolver
from jivago.wsgi.request.json_serialization_filter import JsonSerializationFilter
from jivago.wsgi.request.partial_content_handler import PartialContentHandler
from jivago.wsgi.routing.cors.cors_headers_injection_filter import CorsHeadersInjectionFilter
from jivago.wsgi.routing.routing_rule import RoutingRule
from jivago.wsgi.routing.table.auto_discovering_routing_table import AutoDiscoveringRoutingTable

Expand Down Expand Up @@ -70,14 +66,15 @@ def configure_service_locator(self):
cache = ScopeCache(scope, scoped_classes)
self.serviceLocator.register_scope(cache)

Stream(self.get_default_filters()).forEach(lambda f: self.serviceLocator.bind(f, f))
Stream(JIVAGO_DEFAULT_FILTERS).forEach(lambda f: self.serviceLocator.bind(f, f))

# TODO better way to handle Jivago Dependencies
self.serviceLocator.bind(Registry, Registry.INSTANCE)
self.serviceLocator.bind(TaskScheduler, TaskScheduler(self.serviceLocator))
self.serviceLocator.bind(Deserializer, Deserializer(Registry.INSTANCE))
self.serviceLocator.bind(Serializer, Serializer())
self.serviceLocator.bind(ViewTemplateRepository, ViewTemplateRepository(self.get_views_folder_path()))
self.serviceLocator.bind(CorsHeadersInjectionFilter, CorsHeadersInjectionFilter)
self.serviceLocator.bind(BodySerializationFilter, BodySerializationFilter)
self.serviceLocator.bind(PartialContentHandler, PartialContentHandler)
self.serviceLocator.bind(HttpStatusCodeResolver, HttpStatusCodeResolver)
Expand All @@ -90,6 +87,9 @@ def configure_service_locator(self):

ExceptionMapperBinder().bind(self.serviceLocator)

if not self.banner:
self.serviceLocator.bind(BannerFilter, DummyBannerFilter)

def scopes(self) -> List[type]:
return [Singleton, BackgroundWorker]

Expand All @@ -104,19 +104,10 @@ def get_config_file_locations(self) -> List[str]:
@Override
def create_router_config(self) -> RouterBuilder:
return RouterBuilder() \
.add_rule(FilteringRule("*", self.get_default_filters())) \
.add_rule(FilteringRule("*", JIVAGO_DEFAULT_FILTERS)) \
.add_rule(AutoDiscoveringFilteringRule("*", self.registry, self.root_package_name)) \
.add_rule(RoutingRule("/", AutoDiscoveringRoutingTable(self.registry, self.root_package_name)))

def get_default_filters(self) -> List[Union[Filter, Type[Filter]]]:
default_filters = [UnknownExceptionFilter, TemplateFilter, JsonSerializationFilter,
HttpFormDeserializationFilter,
BodySerializationFilter, ApplicationExceptionFilter]
if self.banner:
default_filters.append(JivagoBannerFilter)

return default_filters

def create_event_bus(self) -> EventBus:
return ReflectiveEventBusInitializer(self.service_locator(), self.registry,
self.root_package_name).create_message_bus()
12 changes: 7 additions & 5 deletions jivago/config/router/router_builder.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from jivago.config.router.cors_rule import CorsRule
from jivago.config.router.filtering.filtering_rule import FilteringRule
from jivago.config.router.router_config_rule import RouterConfigRule
from jivago.inject.service_locator import ServiceLocator
from jivago.lang.registry import Registry
from jivago.serialization.deserializer import Deserializer
from jivago.wsgi.filter.filter_chain_factory import FilterChainFactory
from jivago.config.router.filtering.filtering_rule import FilteringRule
from jivago.wsgi.invocation.route_handler_factory import RouteHandlerFactory
from jivago.wsgi.request.request_factory import RequestFactory
from jivago.wsgi.routing.cors.cors_request_handler_factory import CorsRequestHandlerFactory
from jivago.config.router.cors_rule import CorsRule
from jivago.wsgi.routing.cors.cors_handler import CorsHandler
from jivago.wsgi.routing.router import Router
from jivago.wsgi.routing.routing_rule import RoutingRule

Expand All @@ -29,10 +29,12 @@ def add_rule(self, rule: RouterConfigRule) -> "RouterBuilder":
return self

def build(self, registry: Registry, service_locator: ServiceLocator) -> Router:
cors_handler = CorsHandler(self.cors_rules)
service_locator.bind(CorsHandler, cors_handler)

filter_chain_factory = FilterChainFactory(self.filtering_rules, service_locator,
RouteHandlerFactory(service_locator,
Deserializer(registry),
self.routing_rules,
CorsRequestHandlerFactory(self.cors_rules))
)
cors_handler))
return Router(service_locator, RequestFactory(), filter_chain_factory)
5 changes: 4 additions & 1 deletion jivago/wsgi/filter/filter_chain_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from jivago.inject.service_locator import ServiceLocator
from jivago.lang.stream import Stream
from jivago.wsgi.filter.filter_chain import FilterChain
from jivago.wsgi.filter.system_filters.default_filters import JIVAGO_DEFAULT_OPTIONS_FILTERS
from jivago.wsgi.invocation.route_handler_factory import RouteHandlerFactory
from jivago.wsgi.request.request import Request

Expand All @@ -19,7 +20,9 @@ def __init__(self, filtering_rules: List[FilteringRule],

def create_filter_chain(self, request: Request) -> FilterChain:
if request.method == 'OPTIONS':
filters = []
filters = Stream(JIVAGO_DEFAULT_OPTIONS_FILTERS) \
.map(self.service_locator.get) \
.toList()
else:
filters = Stream(self.filtering_rules) \
.filter(lambda rule: rule.matches(request.path)) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,17 @@
from jivago.wsgi.request.response import Response


class JivagoBannerFilter(Filter):
class BannerFilter(Filter):

@Override
def doFilter(self, request: Request, response: Response, chain: FilterChain):
chain.doFilter(request, response)

response.headers['X-Powered-By'] = f"Jivago {jivago.__version__}"


class DummyBannerFilter(BannerFilter):

@Override
def doFilter(self, request: Request, response: Response, chain: FilterChain):
chain.doFilter(request, response)
24 changes: 24 additions & 0 deletions jivago/wsgi/filter/system_filters/default_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from jivago.templating.template_filter import TemplateFilter
from jivago.wsgi.filter.system_filters.banner_filter import BannerFilter
from jivago.wsgi.filter.system_filters.body_serialization_filter import BodySerializationFilter
from jivago.wsgi.filter.system_filters.error_handling.application_exception_filter import ApplicationExceptionFilter
from jivago.wsgi.filter.system_filters.error_handling.unknown_exception_filter import UnknownExceptionFilter
from jivago.wsgi.request.http_form_deserialization_filter import HttpFormDeserializationFilter
from jivago.wsgi.request.json_serialization_filter import JsonSerializationFilter
from jivago.wsgi.routing.cors.cors_headers_injection_filter import CorsHeadersInjectionFilter

JIVAGO_DEFAULT_FILTERS = [BannerFilter,
CorsHeadersInjectionFilter,
UnknownExceptionFilter,
TemplateFilter,
JsonSerializationFilter,
HttpFormDeserializationFilter,
BodySerializationFilter,
ApplicationExceptionFilter]

JIVAGO_DEFAULT_OPTIONS_FILTERS = [BannerFilter,
CorsHeadersInjectionFilter,
UnknownExceptionFilter,
JsonSerializationFilter,
BodySerializationFilter,
ApplicationExceptionFilter]
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
import traceback

from jivago.lang.annotations import Override
from jivago.wsgi.filter.filter import Filter
from jivago.wsgi.filter.filter_chain import FilterChain
from jivago.wsgi.filter.system_filters.error_handling.unknown_exception_filter import UnknownExceptionFilter
from jivago.wsgi.request.request import Request
from jivago.wsgi.request.response import Response


class DebugExceptionFilter(Filter):
class DebugExceptionFilter(UnknownExceptionFilter):
"""Sends a stacktrace in the response body to help debugging. Enabled by default when using 'DebugJivagoContext'."""
LOGGER = logging.getLogger("DebugExceptionFilter")

@Override
Expand Down
11 changes: 5 additions & 6 deletions jivago/wsgi/invocation/route_handler_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from jivago.wsgi.invocation.route_handler import RouteHandler
from jivago.wsgi.methods import OPTIONS
from jivago.wsgi.request.request import Request
from jivago.wsgi.routing.cors.cors_request_handler_factory import CorsRequestHandlerFactory
from jivago.wsgi.routing.cors.cors_handler import CorsHandler
from jivago.wsgi.routing.exception.method_not_allowed_exception import MethodNotAllowedException
from jivago.wsgi.routing.exception.unknown_path_exception import UnknownPathException
from jivago.wsgi.routing.routing_rule import RoutingRule
Expand All @@ -17,9 +17,9 @@ class RouteHandlerFactory(object):
def __init__(self, service_locator: ServiceLocator,
deserializer: Deserializer,
routing_rules: List[RoutingRule],
cors_handler_factory: CorsRequestHandlerFactory):
cors_handler: CorsHandler):

self.cors_handler_factory = cors_handler_factory
self.cors_handler = cors_handler
self.routing_rules = routing_rules
self.deserializer = deserializer
self.service_locator = service_locator
Expand All @@ -35,15 +35,14 @@ def create_route_handlers(self, request: Request) -> Iterable[RouteHandler]:
raise UnknownPathException(request.path)

if self.is_cors_request(request) and OPTIONS not in routable_http_methods:
return Stream.of(self.cors_handler_factory.create_cors_preflight_handler(request.path))
return Stream.of(self.cors_handler.create_cors_preflight_handler(request.path))

if request.method_annotation not in routable_http_methods:
raise MethodNotAllowedException()

return Stream(self.routing_rules) \
.map(lambda rule: rule.create_route_handlers(request, self.service_locator, self.deserializer)) \
.flat() \
.map(lambda route_handler: self.cors_handler_factory.apply_cors_rules(request.path, route_handler))
.flat()

def is_cors_request(self, request: Request) -> bool:
return request.method_annotation == OPTIONS

0 comments on commit 98e8992

Please sign in to comment.