diff --git a/.gitignore b/.gitignore index 707e25cc..376f2fd4 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ .coverage .git .mypy_cache +.pytest_cache __pycache__ /*.egg-info /htmlcov diff --git a/apistar/codecs/__init__.py b/apistar/codecs/__init__.py index 4c2e7e5a..26c6cab9 100644 --- a/apistar/codecs/__init__.py +++ b/apistar/codecs/__init__.py @@ -1,5 +1,6 @@ from apistar.codecs.base import BaseCodec from apistar.codecs.download import DownloadCodec +from apistar.codecs.formdata import MultiPartCodec, URLEncodedCodec from apistar.codecs.jsondata import JSONCodec from apistar.codecs.jsonschema import JSONSchemaCodec from apistar.codecs.openapi import OpenAPICodec @@ -7,5 +8,5 @@ __all__ = [ 'BaseCodec', 'JSONCodec', 'JSONSchemaCodec', 'OpenAPICodec', 'TextCodec', - 'DownloadCodec' + 'DownloadCodec', 'MultiPartCodec', 'URLEncodedCodec', ] diff --git a/apistar/codecs/formdata.py b/apistar/codecs/formdata.py new file mode 100644 index 00000000..7704190e --- /dev/null +++ b/apistar/codecs/formdata.py @@ -0,0 +1,36 @@ +from io import BytesIO +from itertools import chain + +from werkzeug.datastructures import ImmutableMultiDict +from werkzeug.formparser import FormDataParser +from werkzeug.http import parse_options_header +from werkzeug.urls import url_decode + +from apistar.codecs.base import BaseCodec + + +class MultiPartCodec(BaseCodec): + media_type = 'multipart/form-data' + + def decode(self, bytestring, headers, **options): + try: + content_length = max(0, int(headers['content-length'])) + except (KeyError, ValueError, TypeError): + content_length = None + + try: + mime_type, mime_options = parse_options_header(headers['content-type']) + except KeyError: + mime_type, mime_options = '', {} + + body_file = BytesIO(bytestring) + parser = FormDataParser() + stream, form, files = parser.parse(body_file, mime_type, content_length, mime_options) + return ImmutableMultiDict(chain(form.items(), files.items())) + + +class URLEncodedCodec(BaseCodec): + media_type = 'application/x-www-form-urlencoded' + + def decode(self, bytestring, **options): + return url_decode(bytestring, cls=ImmutableMultiDict) diff --git a/apistar/server/validation.py b/apistar/server/validation.py index 6f695411..bf880a12 100644 --- a/apistar/server/validation.py +++ b/apistar/server/validation.py @@ -13,7 +13,11 @@ class RequestDataComponent(Component): def __init__(self): - self.codecs = [codecs.JSONCodec()] + self.codecs = [ + codecs.JSONCodec(), + codecs.URLEncodedCodec(), + codecs.MultiPartCodec(), + ] def can_handle_parameter(self, parameter: inspect.Parameter): return parameter.annotation is http.RequestData diff --git a/tests/test_http.py b/tests/test_http.py index 7efcfbc9..134eea51 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1,4 +1,5 @@ import pytest +from pytest import param from apistar import Route, http, test from apistar.server.app import App, ASyncApp @@ -75,6 +76,16 @@ def get_request_data(data: http.RequestData): return {'data': data} +def get_multipart_request_data(data: http.RequestData): + files = { + name: f if isinstance(f, str) else { + 'filename': f.filename, + 'content': f.read().decode('utf-8'), + } for name, f in data.items() + } + return {'data': files} + + def return_string(data: http.RequestData) -> str: return 'example content' @@ -107,6 +118,7 @@ def return_response(data: http.RequestData) -> http.Response: Route('/path_params/{example}/', 'GET', get_path_params), Route('/full_path_params/{+example}', 'GET', get_path_params, name='full_path_params'), Route('/request_data/', 'POST', get_request_data), + Route('/multipart_request_data/', 'POST', get_multipart_request_data), Route('/return_string/', 'GET', return_string), Route('/return_data/', 'GET', return_data), Route('/return_response/', 'GET', return_response), @@ -288,15 +300,41 @@ def test_full_path_params(client): assert response.json() == {'params': {'example': 'abc/def/'}} -def test_request_data(client): - response = client.post('/request_data/', json={'abc': 123}) - assert response.json() == {'data': {'abc': 123}} - response = client.post('/request_data/') - assert response.json() == {'data': None} - response = client.post('/request_data/', data=b'...', headers={'content-type': 'unknown'}) - assert response.status_code == 415 - response = client.post('/request_data/', data=b'...', headers={'content-type': 'application/json'}) - assert response.status_code == 400 +@pytest.mark.parametrize('request_params,response_status,response_json', [ + # JSON + param({'json': {'abc': 123}}, 200, {'data': {'abc': 123}}, id='valid json body'), + param({}, 200, {'data': None}, id='empty json body'), + + # Urlencoding + param({'data': {'abc': 123}}, 200, {'data': {'abc': '123'}}, id='valid urlencoded body'), + param( + {'headers': {'content-type': 'application/x-www-form-urlencoded'}}, 200, {'data': None}, + id='empty urlencoded body', + ), + + # Misc + param({'data': b'...', 'headers': {'content-type': 'unknown'}}, 415, None, id='unknown body type'), + param({'data': b'...', 'headers': {'content-type': 'application/json'}}, 400, None, id='json parse failure'), +]) +def test_request_data(request_params, response_status, response_json, client): + response = client.post('/request_data/', **request_params) + assert response.status_code == response_status + if response_json is not None: + assert response.json() == response_json + + +def test_multipart_request_data(client): + response = client.post('/multipart_request_data/', files={'a': ('b', '123')}, data={'b': '42'}) + assert response.status_code == 200 + assert response.json() == { + 'data': { + 'a': { + 'filename': 'b', + 'content': '123', + }, + 'b': '42', + } + } def test_return_string(client):