Skip to content

Commit

Permalink
Handle error handlers in Swagger specifications
Browse files Browse the repository at this point in the history
  • Loading branch information
noirbizarre committed Dec 26, 2015
1 parent 01b1525 commit a2b73da
Show file tree
Hide file tree
Showing 6 changed files with 239 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Current
-------

- Handle callable on API infos
- Handle documentation on error handlers
- Drop/merge flask_restful `flask_restful.RequestParser`
- Handle :class:`~flask_restplus.reqparse.RequestParser` into :meth:`~flask_restplus.Api.expect` decorator
- Handle schema for :mod:`~flask_restplus.inputs` parsers
Expand Down
29 changes: 27 additions & 2 deletions doc/syntaxic_sugar.rst
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,34 @@ that you can do with Flask/Blueprint ``@errorhandler`` decorator.
@api.errorhandler(FakeException)
def handle_fake_exception_with_header(error):
'''Return a custom message and 500 status code'''
'''Return a custom message and 400 status code'''
return {'message': error.message}, 400, {'My-Header': 'Value'}
You can also document the error:

.. code-block:: python
@api.errorhandler(FakeException)
@api.marshal_with(error_fields, code=400)
@api.header('My-Header', 'Some description')
def handle_fake_exception_with_header(error):
'''This is a custom error'''
return {'message': error.message}, 400, {'My-Header': 'Value'}
@api.route('/test/')
class TestResource(Resource):
def get(self):
'''
Do something
:raises CustomException: In case of something
'''
pass
In this example, the ``:raise:`` docstring will be automatically extracted
and the response 400 will be documented properly.


It also allows for overriding the default error handler when used wihtout parameter:

Expand Down
2 changes: 2 additions & 0 deletions flask_restplus/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,10 +587,12 @@ def mediatypes(self):


def mask_parse_error_handler(error):
'''When a mask can't be parsed'''
return {'message': 'Mask parse error: {0}'.format(error)}, 400


def mask_error_handler(error):
'''When any error occurs on mask'''
return {'message': 'Mask error: {0}'.format(error)}, 400


Expand Down
88 changes: 74 additions & 14 deletions flask_restplus/swagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import re

from inspect import isclass
from inspect import isclass, getdoc
from collections import Hashable
from six import string_types, itervalues, iteritems, iterkeys

Expand Down Expand Up @@ -51,6 +51,8 @@
DEFAULT_RESPONSE_DESCRIPTION = 'Success'
DEFAULT_RESPONSE = {'description': DEFAULT_RESPONSE_DESCRIPTION}

RE_RAISES = re.compile(r'^:raises\s+(?P<name>[\w\d_]+)\s*:\s*(?P<description>.*)$', re.MULTILINE)


def ref(model):
'''Return a reference to model in definitions'''
Expand Down Expand Up @@ -139,6 +141,40 @@ def _handle_arg_type(arg, param):
param['type'] = 'string'


def _param_to_header(param):
if 'in' in param:
del param['in']

typedef = param.get('type', 'string')
if isinstance(typedef, Hashable) and typedef in PY_TYPES:
param['type'] = PY_TYPES[typedef]
elif hasattr(typedef, '__schema__'):
param.update(typedef.__schema__)
else:
param['type'] = typedef
return param


def parse_docstring(obj):
raw = getdoc(obj)
summary = raw.strip(' \n').split('\n')[0].split('.')[0] if raw else None
raises = {}
details = raw.replace(summary, '').lstrip('. \n').strip(' \n') if raw else None
for match in RE_RAISES.finditer(raw or ''):
raises[match.group('name')] = match.group('description')
if details:
details = details.replace(match.group(0), '')
parsed = {
'raw': raw,
'summary': summary or None,
'details': details or None,
'returns': None,
'params': [],
'raises': raises,
}
return parsed


class Swagger(object):
'''
A Swagger documentation wrapper for an API instance.
Expand Down Expand Up @@ -179,6 +215,9 @@ def as_dict(self):
paths = {}
tags = self.extract_tags(self.api)

# register errors
responses = self.register_errors()

for ns in self.api.namespaces:
for resource, urls, kwargs in ns.resources.values():
for url in urls:
Expand All @@ -195,6 +234,7 @@ def as_dict(self):
'security': self.security_requirements(self.api.security) or None,
'tags': tags,
'definitions': self.serialize_definitions() or None,
'responses': responses or None,
'host': self.get_host(),
}
return not_none(specs)
Expand Down Expand Up @@ -244,7 +284,7 @@ def extract_resource_doc(self, resource, url):
method_impl = method_impl.__func__
method_doc = merge(method_doc, getattr(method_impl, '__apidoc__', OrderedDict()))
if method_doc is not False:
method_doc['docstring'] = getattr(method_impl, '__doc__')
method_doc['docstring'] = parse_docstring(method_impl)
method_doc['params'] = self.merge_params(OrderedDict(), method_doc)
doc[method] = method_doc
return doc
Expand Down Expand Up @@ -276,6 +316,25 @@ def merge_params(self, params, doc):

return params

def register_errors(self):
responses = {}
for exception, handler in self.api._error_handlers.items():
doc = parse_docstring(handler)
response = {
'description': doc['summary']
}
apidoc = getattr(handler, '__apidoc__', {})
if 'params' in apidoc:
response['headers'] = dict(
(n, _param_to_header(o))
for n, o in apidoc['params'].items() if o.get('in') == 'header'
)
if 'responses' in apidoc:
_, model = list(apidoc['responses'].values())[0]
response['schema'] = self.serialize_schema(model)
responses[exception.__name__] = not_none(response)
return responses

def serialize_resource(self, ns, resource, url):
doc = self.extract_resource_doc(resource, url)
if doc is False:
Expand All @@ -291,7 +350,7 @@ def serialize_resource(self, ns, resource, url):
def serialize_operation(self, doc, method):
operation = {
'responses': self.responses_for(doc, method) or None,
'summary': self.summary_for(doc, method) or None,
'summary': doc[method]['docstring']['summary'],
'description': self.description_for(doc, method) or None,
'operationId': self.operation_id_for(doc, method),
'parameters': self.parameters_for(doc, method) or None,
Expand All @@ -308,24 +367,15 @@ def serialize_operation(self, doc, method):
operation['consumes'] = ['application/x-www-form-urlencoded', 'multipart/form-data']
return not_none(operation)

def summary_for(self, doc, method):
'''Extract the first sentence from the first docstring line'''
if not doc[method].get('docstring'):
return
first_line = doc[method]['docstring'].strip().split('\n')[0]
return first_line.split('.')[0]

def description_for(self, doc, method):
'''Extract the description metadata and fallback on the whole docstring'''
parts = []
if 'description' in doc:
parts.append(doc['description'])
if method in doc and 'description' in doc[method]:
parts.append(doc[method]['description'])
if doc[method].get('docstring'):
splitted = doc[method]['docstring'].strip().split('\n', 1)
if len(splitted) == 2:
parts.append(splitted[1].strip())
if doc[method]['docstring']['details']:
parts.append(doc[method]['docstring']['details'])

return '\n'.join(parts).strip()

Expand Down Expand Up @@ -391,6 +441,15 @@ def responses_for(self, doc, method):
responses[code] = DEFAULT_RESPONSE.copy()
responses[code]['schema'] = self.serialize_schema(d['model'])

if 'docstring' in d:
for name, description in d['docstring']['raises'].items():
for exception, handler in self.api._error_handlers.items():
error_responses = getattr(handler, '__apidoc__', {}).get('responses', {})
code = list(error_responses.keys())[0] if error_responses else None
if code and exception.__name__ == name:
responses[code] = {'$ref': '#/responses/{0}'.format(name)}
break

if not responses:
responses['200'] = DEFAULT_RESPONSE.copy()
return responses
Expand Down Expand Up @@ -439,6 +498,7 @@ def register_model(self, model):
self.register_model(specs.__parent__)
for field in itervalues(specs):
self.register_field(field)
return ref(model)

def register_field(self, field):
if isinstance(field, fields.Polymorph):
Expand Down
52 changes: 52 additions & 0 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,3 +452,55 @@ def test_handle_error_with_code(self):
response = api.handle_error(exception)
self.assertEquals(response.status_code, 500)
self.assertEquals(json.loads(response.data.decode()), {"foo": "bar"})

def test_errorhandler_swagger_doc(self):
api = restplus.Api(self.app)

class CustomException(RuntimeError):
pass

error = api.model('Error', {
'message': restplus.fields.String()
})

@api.route('/test/', endpoint='test')
class TestResource(restplus.Resource):
def get(self):
'''
Do something
:raises CustomException: In case of something
'''
pass

@api.errorhandler(CustomException)
@api.header('Custom-Header', 'Some custom header')
@api.marshal_with(error, code=503)
def handle_custom_exception(error):
'''Some description'''
pass

specs = self.get_specs()

self.assertIn('Error', specs['definitions'])
self.assertIn('CustomException', specs['responses'])

response = specs['responses']['CustomException']
self.assertEqual(response['description'], 'Some description')
self.assertEqual(response['schema'], {
'$ref': '#/definitions/Error'
})
self.assertEqual(response['headers'], {
'Custom-Header': {
'description': 'Some custom header',
'type': 'string'
}
})

operation = specs['paths']['/test/']['get']
self.assertIn('responses', operation)
self.assertEqual(operation['responses'], {
'503': {
'$ref': '#/responses/CustomException'
}
})
84 changes: 83 additions & 1 deletion tests/test_swagger_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from werkzeug.datastructures import FileStorage

from flask_restplus import fields, reqparse, Api, SpecsError
from flask_restplus.swagger import extract_path, extract_path_params, parser_to_params
from flask_restplus.swagger import extract_path, extract_path_params, parser_to_params, parse_docstring

from . import TestCase

Expand Down Expand Up @@ -293,3 +293,85 @@ def custom(value):
'format': 'custom-format',
}
})


class ParseDocstringTest(TestCase):
def test_empty(self):
def without_doc():
pass

parsed = parse_docstring(without_doc)

self.assertIsNone(parsed['raw'])
self.assertIsNone(parsed['summary'])
self.assertIsNone(parsed['details'])
self.assertIsNone(parsed['returns'])
self.assertEqual(parsed['raises'], {})
self.assertEqual(parsed['params'], [])

def test_single_line(self):
def func():
'''Some summary'''
pass

parsed = parse_docstring(func)

self.assertEqual(parsed['raw'], 'Some summary')
self.assertEqual(parsed['summary'], 'Some summary')
self.assertIsNone(parsed['details'])
self.assertIsNone(parsed['returns'])
self.assertEqual(parsed['raises'], {})
self.assertEqual(parsed['params'], [])

def test_multi_line(self):
def func():
'''
Some summary
Some details
'''
pass

parsed = parse_docstring(func)

self.assertEqual(parsed['raw'], 'Some summary\nSome details')
self.assertEqual(parsed['summary'], 'Some summary')
self.assertEqual(parsed['details'], 'Some details')
self.assertIsNone(parsed['returns'])
self.assertEqual(parsed['raises'], {})
self.assertEqual(parsed['params'], [])

def test_multi_line_and_dot(self):
def func():
'''
Some summary. bla bla
Some details
'''
pass

parsed = parse_docstring(func)

self.assertEqual(parsed['raw'], 'Some summary. bla bla\nSome details')
self.assertEqual(parsed['summary'], 'Some summary')
self.assertEqual(parsed['details'], 'bla bla\nSome details')
self.assertIsNone(parsed['returns'])
self.assertEqual(parsed['raises'], {})
self.assertEqual(parsed['params'], [])

def test_raises(self):
def func():
'''
Some summary.
:raises SomeException: in case of something
'''
pass

parsed = parse_docstring(func)

self.assertEqual(parsed['raw'], 'Some summary.\n:raises SomeException: in case of something')
self.assertEqual(parsed['summary'], 'Some summary')
self.assertIsNone(parsed['details'])
self.assertIsNone(parsed['returns'])
self.assertEqual(parsed['params'], [])
self.assertEqual(parsed['raises'], {
'SomeException': 'in case of something'
})

0 comments on commit a2b73da

Please sign in to comment.