Skip to content

Commit

Permalink
feat(validator): add request and response validation
Browse files Browse the repository at this point in the history
  • Loading branch information
sleistner committed Jun 22, 2019
1 parent 5778777 commit 0029a97
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 32 deletions.
8 changes: 8 additions & 0 deletions lambda_handlers/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,11 @@ class NotFoundError(LambdaError):

class ValidationError(LambdaError):
pass


class RequestValidationError(ValidationError):
pass


class ResponseValidationError(ValidationError):
pass
26 changes: 18 additions & 8 deletions lambda_handlers/handlers/http_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
from lambda_handlers.errors import (
NotFoundError,
BadRequestError,
ValidationError
ValidationError,
ResponseValidationError
)
from lambda_handlers.response import CorsHeaders
from lambda_handlers.response.builder import (
ok,
not_found,
bad_request,
bad_implementation,
internal_server_error
)
from lambda_handlers.handlers.lambda_handler import LambdaHandler
Expand Down Expand Up @@ -50,21 +52,29 @@ def __init__(self, cors=None, body_format=None, output_format=None, validation=N
self._cors = cors or CorsHeaders(origin='*', credentials=True)

def before(self, event, context):
self._validate(event, context)
self._validate_request(event, context)
self._parse_body(event)
return event, context

def after(self, result):
if not isinstance(result, APIGatewayProxyResult) and 'statusCode' not in result:
result = ok(result)
return self._create_response(result)
response = self._create_response(result)
self._validate_response(response)
return response

def on_exception(self, exception):
return self._create_response(self._handle_error(exception))

def _validate(self, event, context):
def _validate_request(self, event, context):
if self._validator:
self._validator(event, context)
transformed_event = transformed_context = self._validator.validate_request(event, context)
event.update(transformed_event)
context.update(transformed_context)

def _validate_response(self, response):
if self._validator:
self._validator.validate_response(response)

def _parse_body(self, event):
if 'body' in event:
Expand All @@ -87,9 +97,9 @@ def _create_headers(self, headers: Headers) -> Headers:
def _handle_error(self, error) -> APIGatewayProxyResult:
if isinstance(error, NotFoundError):
return not_found(str(error))
if isinstance(error, ValidationError):
return bad_request(str(error))
if isinstance(error, BadRequestError):
if isinstance(error, ResponseValidationError):
return bad_implementation(str(error))
if isinstance(error, (BadRequestError, ValidationError)):
return bad_request(str(error))

logger.error(error)
Expand Down
9 changes: 5 additions & 4 deletions lambda_handlers/response/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def forbidden(description: str) -> APIGatewayProxyResult:
return _build_request(error, HTTPStatus.FORBIDDEN)


def bad_implementation(description: str = None) -> APIGatewayProxyResult:
return internal_server_error(description)


def internal_server_error(description: str = None) -> APIGatewayProxyResult:
error = InternalServerError(description or 'InternalServerError')
return _build_request(error, HTTPStatus.INTERNAL_SERVER_ERROR)
Expand All @@ -40,10 +44,7 @@ def created(result: str) -> APIGatewayProxyResult:
return _build_request(result, HTTPStatus.CREATED)


def _build_request(
result: Union[LambdaError, Any],
status_code: HTTPStatus
) -> APIGatewayProxyResult:
def _build_request(result: Union[LambdaError, Any], status_code: HTTPStatus) -> APIGatewayProxyResult:
if isinstance(result, LambdaError):
body = {'errors': result.description}
else:
Expand Down
67 changes: 47 additions & 20 deletions lambda_handlers/validators/validator.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,65 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, Callable

from lambda_handlers.errors import ValidationError
from lambda_handlers.errors import (
RequestValidationError,
ResponseValidationError
)


class Validator(ABC):

def __init__(self, path=None, query=None, body=None):
def __init__(self, path=None, query=None, body=None, request=None, response=None):
self._path_parameters_schema = path
self._query_string_parameters_schema = query
self._body_schema = body
self._request_schema = request
self._response_schema = response

def __call__(self, event, context):
cumulative_errors = []
def validate_request(self, event, context) -> Tuple[Any, Any]:
if self._request_schema:
data, errors = self.validate(event, self._request_schema())
else:
data, errors = self._validate_request_contexts(event, context)

def _validate(key, schema):
data, errors = self.on_validate(event.get(key, {}), schema)
if errors:
cumulative_errors.append(errors)
elif key in event:
event[key].update(data)
if errors:
description = self.format_errors(errors)
raise RequestValidationError(description)

return data, context

def validate_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
if not self._response_schema:
return response

data, errors = self.validate(response, self._response_schema())

if self._path_parameters_schema:
_validate('pathParameters', self._path_parameters_schema())
if errors:
description = self.format_errors(errors)
raise ResponseValidationError(description)

if self._query_string_parameters_schema:
_validate('queryStringParameters', self._query_string_parameters_schema())
return data

if self._body_schema:
_validate('body', self._body_schema())
def _validate_request_contexts(self, event, context) -> Tuple[Dict[str, Any], List[Any]]:
contexts = {
'pathParameters': self._path_parameters_schema,
'queryStringParameters': self._query_string_parameters_schema,
'body': self._body_schema
}
return self._validate_many(event, {key: schema() for key, schema in contexts.items() if schema})

def _validate_many(self, target: Dict[str, Any], definitions: Dict[str, Callable]) -> Tuple[Dict[str, Any], List[Any]]:
cumulative_errors = []
transformed_data = {}

for key, schema in definitions.items():
data, errors = self.validate(target.get(key, {}), schema)
if errors:
cumulative_errors.append(errors)
elif key in target:
transformed_data[key] = data

if cumulative_errors:
description = self.format_errors(cumulative_errors)
raise ValidationError(description)
return transformed_data, cumulative_errors

@abstractmethod
def validate(self, instance: Any, schema: Any) -> Tuple[Any, List[Any]]:
Expand Down

0 comments on commit 0029a97

Please sign in to comment.