diff --git a/flask_stupe/validation.py b/flask_stupe/validation.py index a4ead5e..8b453f7 100644 --- a/flask_stupe/validation.py +++ b/flask_stupe/validation.py @@ -8,20 +8,37 @@ if marshmallow: - def schema_required(schema): - """Validate body of the request against the schema. - - Abort with a status code 400 if the schema yields errors.""" - def __inner(f): - @functools.wraps(f) - def __inner(*args, **kwargs): - json = request.get_json(force=True) - results = schema.load(json) - if results.errors: - abort(400, results.errors) - request.schema = results.data - return f(*args, **kwargs) + if marshmallow.__version__.startswith('3'): # pragma: no cover + def schema_required(schema): + """Validate body of the request against the schema. + + Abort with a status code 400 if the schema yields errors.""" + def __inner(f): + @functools.wraps(f) + def __inner(*args, **kwargs): + json = request.get_json(force=True) + try: + request.schema = schema.load(json) + except marshmallow.exceptions.ValidationError as e: + abort(400, e.messages) + return f(*args, **kwargs) + return __inner + return __inner + else: + def schema_required(schema): + """Validate body of the request against the schema. + + Abort with a status code 400 if the schema yields errors.""" + def __inner(f): + @functools.wraps(f) + def __inner(*args, **kwargs): + json = request.get_json(force=True) + results = schema.load(json) + if results.errors: + abort(400, results.errors) + request.schema = results.data + return f(*args, **kwargs) + return __inner return __inner - return __inner __all__.extend(["schema_required"])