Skip to content

Commit

Permalink
Register a authorization error response for endpoints which use @auth…
Browse files Browse the repository at this point in the history
…_required

Add new config:
- AUTO_AUTH_ERROR_RESPONSE
- AUTH_ERROR_STATUS_CODE
- AUTH_ERROR_DESCRIPTION
- AUTH_ERROR_SCHEMA
  • Loading branch information
greyli committed Mar 24, 2021
1 parent 3586945 commit c8992a8
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 15 deletions.
43 changes: 30 additions & 13 deletions apiflask/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,25 +470,42 @@ def add_response(status_code, schema, description):
if not view_func._spec.get('responses') and self.config['AUTO_200_RESPONSE']:
add_response('200', {}, descriptions['200'])

def add_response_and_schema(status_code, schema, schema_name, description):
if isinstance(schema, type):
schema = schema()
add_response(status_code, schema, description)
elif isinstance(schema, dict):
if schema_name not in spec.components._schemas:
spec.components.schema(schema_name, schema)
schema_ref = {'$ref': f'#/components/schemas/{schema_name}'}
add_response(status_code, schema_ref, description)
else:
raise RuntimeError(
'The schema must be a Marshamallow schema \
class or an OpenAPI schema dict.'
)

# add validation error response
if self.config['AUTO_VALIDATION_ERROR_RESPONSE']:
if view_func._spec.get('body') or view_func._spec.get('args'):
status_code = str(self.config['VALIDATION_ERROR_STATUS_CODE'])
description = self.config['VALIDATION_ERROR_DESCRIPTION']
schema = self.config['VALIDATION_ERROR_SCHEMA']
if isinstance(schema, type):
schema = schema()
add_response(status_code, schema, description)
elif isinstance(schema, dict):
if 'ValidationError' not in spec.components._schemas:
spec.components.schema('ValidationError', schema)
schema_ref = {'$ref': '#/components/schemas/ValidationError'}
add_response(status_code, schema_ref, description)
else:
raise RuntimeError(
'The schema must be a Marshamallow schema \
class or an OpenAPI schema dict.'
)
add_response_and_schema(
status_code, schema, 'ValidationError', description
)

# add authorization error response
if self.config['AUTO_AUTH_ERROR_RESPONSE']:
if view_func._spec.get('auth') or (
blueprint_name is not None and blueprint_name in auth_blueprints
):
status_code = str(self.config['AUTH_ERROR_STATUS_CODE'])
description = self.config['AUTH_ERROR_DESCRIPTION']
schema = self.config['AUTH_ERROR_SCHEMA']
add_response_and_schema(
status_code, schema, 'AuthorizationError', description
)

if view_func._spec.get('responses'):
for status_code, description in view_func._spec.get('responses').items():
Expand Down
4 changes: 4 additions & 0 deletions apiflask/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
VALIDATION_ERROR_STATUS_CODE = 400
VALIDATION_ERROR_DESCRIPTION = 'Validation error'
VALIDATION_ERROR_SCHEMA = http_error_schema
AUTO_AUTH_ERROR_RESPONSE = True
AUTH_ERROR_STATUS_CODE = 401
AUTH_ERROR_DESCRIPTION = 'Authorization error'
AUTH_ERROR_SCHEMA = http_error_schema
# Swagger UI and Redoc
DOCS_HIDE_BLUEPRINTS = []
DOCS_FAVICON = None
Expand Down
5 changes: 5 additions & 0 deletions tests/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,8 @@ class Meta:
class ValidationErrorSchema(Schema):
status_code = String(required=True)
message = String(required=True)


class AuthorizationErrorSchema(Schema):
status_code = String(required=True)
message = String(required=True)
80 changes: 78 additions & 2 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from apiflask.decorators import auth_required
import pytest
from openapi_spec_validator import validate_spec

from apiflask import APIBlueprint, input, output
from apiflask.schemas import EmptySchema, http_error_schema
from apiflask.security import HTTPBasicAuth

from .schemas import QuerySchema, FooSchema, ValidationErrorSchema
from .schemas import QuerySchema, FooSchema, ValidationErrorSchema, AuthorizationErrorSchema


def test_openapi_fields(app, client):
Expand Down Expand Up @@ -208,8 +210,8 @@ def foo():
assert rv.status_code == 200
validate_spec(rv.json)
assert bool('400' in rv.json['paths']['/foo']['post']['responses']) is config_value
assert bool('ValidationError' in rv.json['components']['schemas']) is config_value
if config_value:
assert 'ValidationError' in rv.json['components']['schemas']
assert '#/components/schemas/ValidationError' in \
rv.json['paths']['/foo']['post']['responses']['400'][
'content']['application/json']['schema']['$ref']
Expand Down Expand Up @@ -265,6 +267,80 @@ def foo():
app.spec


@pytest.mark.parametrize('config_value', [True, False])
def test_auto_auth_error_response(app, client, config_value):
app.config['AUTO_AUTH_ERROR_RESPONSE'] = config_value
auth = HTTPBasicAuth()

@app.post('/foo')
@auth_required(auth)
def foo():
pass

rv = client.get('/openapi.json')
assert rv.status_code == 200
validate_spec(rv.json)
assert bool('401' in rv.json['paths']['/foo']['post']['responses']) is config_value
if config_value:
assert 'AuthorizationError' in rv.json['components']['schemas']
assert '#/components/schemas/AuthorizationError' in \
rv.json['paths']['/foo']['post']['responses']['401'][
'content']['application/json']['schema']['$ref']


def test_auth_error_status_code_and_description(app, client):
app.config['AUTH_ERROR_STATUS_CODE'] = 403
app.config['AUTH_ERROR_DESCRIPTION'] = 'Bad'
auth = HTTPBasicAuth()

@app.post('/foo')
@auth_required(auth)
def foo():
pass

rv = client.get('/openapi.json')
assert rv.status_code == 200
validate_spec(rv.json)
assert rv.json['paths']['/foo']['post']['responses']['403'] is not None
assert rv.json['paths']['/foo']['post']['responses'][
'403']['description'] == 'Bad'


@pytest.mark.parametrize('schema', [
http_error_schema,
AuthorizationErrorSchema
])
def test_auth_error_schema(app, client, schema):
app.config['AUTH_ERROR_SCHEMA'] = schema
auth = HTTPBasicAuth()

@app.post('/foo')
@auth_required(auth)
def foo():
pass

rv = client.get('/openapi.json')
assert rv.status_code == 200
validate_spec(rv.json)
assert rv.json['paths']['/foo']['post']['responses']['401']
assert rv.json['paths']['/foo']['post']['responses']['401'][
'description'] == 'Authorization error'
assert 'AuthorizationError' in rv.json['components']['schemas']


def test_auth_error_schema_bad_type(app):
app.config['AUTH_ERROR_SCHEMA'] = 'schema'
auth = HTTPBasicAuth()

@app.post('/foo')
@auth_required(auth)
def foo():
pass

with pytest.raises(RuntimeError):
app.spec


def test_docs_hide_blueprints(app, client):
bp = APIBlueprint('foo', __name__, tag='test')

Expand Down

0 comments on commit c8992a8

Please sign in to comment.