Skip to content

Commit 9674fe8

Browse files
Merge pull request from GHSA-p24m-863f-fm6q
* implement limitation to form multipart parts * Adds multipart_form_part_limit config to Body function. A limit defined through the Body function takes precedence over the app limit. * Replace "DDoS" with "DoS" in docstrings. * Provide `maxsplit` arg to `body.split()`. `multipart_form_part_limit + 3` parts are required to determine if body exceeds limit. We discard the first item of the result via [1:] (+1) and we discard the last item of the result [:-1] (+1) and need to observe whether there are N+1 items (+1), which gets us to 3 extra items after multipart_form_part_limit. (thanks @das7pad). * Fix type errors. --------- Co-authored-by: Na'aman Hirschfeld <nhirschfeld@gmail.com>
1 parent cbd10bb commit 9674fe8

File tree

9 files changed

+149
-56
lines changed

9 files changed

+149
-56
lines changed

Diff for: starlite/app.py

+6
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ class Starlite(Router):
135135
"get_logger",
136136
"logger",
137137
"logging_config",
138+
"multipart_form_part_limit",
138139
"on_shutdown",
139140
"on_startup",
140141
"openapi_config",
@@ -175,6 +176,7 @@ def __init__(
175176
initial_state: Optional["InitialStateType"] = None,
176177
logging_config: Union["BaseLoggingConfig", "EmptyType", None] = Empty,
177178
middleware: Optional[List["Middleware"]] = None,
179+
multipart_form_part_limit: int = 1000,
178180
on_app_init: Optional[List["OnAppInitHandler"]] = None,
179181
on_shutdown: Optional[List["LifeSpanHandler"]] = None,
180182
on_startup: Optional[List["LifeSpanHandler"]] = None,
@@ -238,6 +240,8 @@ def __init__(
238240
initial_state: An object from which to initialize the app state.
239241
logging_config: A subclass of :class:`BaseLoggingConfig <starlite.config.logging.BaseLoggingConfig>`.
240242
middleware: A list of :class:`Middleware <starlite.types.Middleware>`.
243+
multipart_form_part_limit: The maximal number of allowed parts in a multipart/formdata request.
244+
This limit is intended to protect from DoS attacks.
241245
on_app_init: A sequence of :class:`OnAppInitHandler <starlite.types.OnAppInitHandler>` instances. Handlers receive
242246
an instance of :class:`AppConfig <starlite.config.app.AppConfig>` that will have been initially populated with
243247
the parameters passed to :class:`Starlite <starlite.app.Starlite>`, and must return an instance of same. If more
@@ -300,6 +304,7 @@ def __init__(
300304
initial_state=initial_state or {},
301305
logging_config=logging_config if logging_config is not Empty else LoggingConfig() if debug else None, # type: ignore[arg-type]
302306
middleware=middleware or [],
307+
multipart_form_part_limit=multipart_form_part_limit,
303308
on_shutdown=on_shutdown or [],
304309
on_startup=on_startup or [],
305310
openapi_config=openapi_config,
@@ -343,6 +348,7 @@ def __init__(
343348
self.static_files_config = config.static_files_config
344349
self.template_engine = config.template_config.engine_instance if config.template_config else None
345350
self.websocket_class = config.websocket_class or WebSocket
351+
self.multipart_form_part_limit = config.multipart_form_part_limit
346352

347353
super().__init__(
348354
after_request=config.after_request,

Diff for: starlite/config/app.py

+2
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ class Config(BaseConfig):
178178
"""A mapping of types to callables that transform them into types supported for serialization."""
179179
websocket_class: Optional[Type[WebSocket]]
180180
"""An optional subclass of :class:`WebSocket <starlite.connection.websocket.WebSocket>` to use for websocket connections."""
181+
multipart_form_part_limit: int
182+
"""The maximal number of allowed parts in a multipart/formdata request. This limit is intended to protect from DoS attacks."""
181183

182184
@validator("allowed_hosts", always=True)
183185
def validate_allowed_hosts( # pylint: disable=no-self-argument

Diff for: starlite/connection/request.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,9 @@ async def form(self) -> FormMultiDict:
150150
content_type, options = self.content_type
151151
if content_type == RequestEncodingType.MULTI_PART:
152152
self._form = self.scope["_form"] = form_values = parse_multipart_form( # type: ignore[typeddict-item]
153-
body=await self.body(), boundary=options.get("boundary", "").encode()
153+
body=await self.body(),
154+
boundary=options.get("boundary", "").encode(),
155+
multipart_form_part_limit=self.app.multipart_form_part_limit,
154156
)
155157
return FormMultiDict(form_values)
156158
if content_type == RequestEncodingType.URL_ENCODED:

Diff for: starlite/kwargs/extractors.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from starlite.enums import ParamType, RequestEncodingType
2020
from starlite.exceptions import ValidationException
2121
from starlite.multipart import parse_multipart_form
22+
from starlite.params import BodyKwarg
2223
from starlite.parsers import (
2324
parse_headers,
2425
parse_query_string,
@@ -289,15 +290,25 @@ def create_multipart_extractor(
289290
Returns:
290291
An extractor function.
291292
"""
293+
body_kwarg_multipart_form_part_limit: Optional[int] = None
294+
if signature_field.kwarg_model and isinstance(signature_field.kwarg_model, BodyKwarg):
295+
body_kwarg_multipart_form_part_limit = signature_field.kwarg_model.multipart_form_part_limit
292296

293297
async def extract_multipart(
294298
connection: "Request[Any, Any]",
295299
) -> Any:
300+
multipart_form_part_limit = (
301+
body_kwarg_multipart_form_part_limit
302+
if body_kwarg_multipart_form_part_limit is not None
303+
else connection.app.multipart_form_part_limit
304+
)
296305
connection.scope["_form"] = form_values = ( # type: ignore[typeddict-item]
297306
connection.scope["_form"] # type: ignore[typeddict-item]
298307
if "_form" in connection.scope
299308
else parse_multipart_form(
300-
body=await connection.body(), boundary=connection.content_type[-1].get("boundary", "").encode()
309+
body=await connection.body(),
310+
boundary=connection.content_type[-1].get("boundary", "").encode(),
311+
multipart_form_part_limit=multipart_form_part_limit,
301312
)
302313
)
303314

Diff for: starlite/multipart.py

+76-52
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,11 @@
3030
from urllib.parse import unquote
3131

3232
from starlite.datastructures.upload_file import UploadFile
33-
from starlite.exceptions import SerializationException
33+
from starlite.exceptions import SerializationException, ValidationException
3434
from starlite.utils.serialization import decode_json
3535

36-
_token, _quoted = r"([\w!#$%&'*+\-.^_`|~]+)", r'"([^"]*)"'
36+
_token = r"([\w!#$%&'*+\-.^_`|~]+)"
37+
_quoted = r'"([^"]*)"'
3738
_param = re.compile(rf";\s*{_token}=(?:{_token}|{_quoted})", re.ASCII)
3839
_firefox_quote_escape = re.compile(r'\\"(?!; |\s*$)')
3940

@@ -59,67 +60,90 @@ def parse_content_header(value: str) -> Tuple[str, Dict[str, str]]:
5960
return value.strip().lower(), options
6061

6162

62-
def parse_multipart_form(body: bytes, boundary: bytes) -> Dict[str, Any]:
63+
def parse_body(body: bytes, boundary: bytes, multipart_form_part_limit: int) -> List[bytes]:
64+
"""Split the body using the boundary
65+
and validate the number of form parts is within the allowed limit.
66+
67+
:param body: The form body.
68+
:param boundary: The boundary used to separate form components.
69+
:param multipart_form_part_limit: The limit of allowed form components
70+
:return:
71+
A list of form components.
72+
"""
73+
if not (body and boundary):
74+
return []
75+
76+
form_parts = body.split(boundary, multipart_form_part_limit + 3)[1:-1]
77+
78+
if len(form_parts) > multipart_form_part_limit:
79+
raise ValidationException(
80+
f"number of multipart components exceeds the allowed limit of {multipart_form_part_limit}, "
81+
f"this potentially indicates a DoS attack"
82+
)
83+
84+
return form_parts
85+
86+
87+
def parse_multipart_form(body: bytes, boundary: bytes, multipart_form_part_limit: int = 1000) -> Dict[str, Any]:
6388
"""Parse multipart form data.
6489
6590
Args:
6691
body: Body of the request.
6792
boundary: Boundary of the multipart message.
93+
multipart_form_part_limit: Limit of the number of parts allowed.
6894
6995
Returns:
7096
A dictionary of parsed results.
7197
"""
7298

7399
fields: DefaultDict[str, List[Any]] = defaultdict(list)
74100

75-
if body and boundary:
76-
form_parts = body.split(boundary)
77-
for form_part in form_parts[1:-1]:
78-
file_name = None
79-
content_type = "text/plain"
80-
content_charset = "utf-8"
81-
field_name = None
82-
line_index = 2
83-
line_end_index = 0
84-
headers: List[Tuple[str, str]] = []
85-
86-
while line_end_index != -1:
87-
line_end_index = form_part.find(b"\r\n", line_index)
88-
form_line = form_part[line_index:line_end_index].decode("utf-8")
89-
90-
if not form_line:
91-
break
92-
93-
line_index = line_end_index + 2
94-
colon_index = form_line.index(":")
95-
current_idx = colon_index + 2
96-
form_header_field = form_line[0:colon_index].lower()
97-
form_header_value, form_parameters = parse_content_header(form_line[current_idx:])
98-
99-
if form_header_field == "content-disposition":
100-
field_name = form_parameters.get("name")
101-
file_name = form_parameters.get("filename")
102-
103-
if file_name is None and (filename_with_asterisk := form_parameters.get("filename*")):
104-
encoding, _, value = decode_rfc2231(filename_with_asterisk)
105-
file_name = unquote(value, encoding=encoding or content_charset)
106-
107-
elif form_header_field == "content-type":
108-
content_type = form_header_value
109-
content_charset = form_parameters.get("charset", "utf-8")
110-
headers.append((form_header_field, form_header_value))
111-
112-
if field_name:
113-
post_data = form_part[line_index:-4].lstrip(b"\r\n")
114-
if file_name:
115-
form_file = UploadFile(
116-
content_type=content_type, filename=file_name, file_data=post_data, headers=dict(headers)
117-
)
118-
fields[field_name].append(form_file)
119-
else:
120-
try:
121-
fields[field_name].append(decode_json(post_data))
122-
except SerializationException:
123-
fields[field_name].append(post_data.decode(content_charset))
101+
for form_part in parse_body(body=body, boundary=boundary, multipart_form_part_limit=multipart_form_part_limit):
102+
file_name = None
103+
content_type = "text/plain"
104+
content_charset = "utf-8"
105+
field_name = None
106+
line_index = 2
107+
line_end_index = 0
108+
headers: List[Tuple[str, str]] = []
109+
110+
while line_end_index != -1:
111+
line_end_index = form_part.find(b"\r\n", line_index)
112+
form_line = form_part[line_index:line_end_index].decode("utf-8")
113+
114+
if not form_line:
115+
break
116+
117+
line_index = line_end_index + 2
118+
colon_index = form_line.index(":")
119+
current_idx = colon_index + 2
120+
form_header_field = form_line[0:colon_index].lower()
121+
form_header_value, form_parameters = parse_content_header(form_line[current_idx:])
122+
123+
if form_header_field == "content-disposition":
124+
field_name = form_parameters.get("name")
125+
file_name = form_parameters.get("filename")
126+
127+
if file_name is None and (filename_with_asterisk := form_parameters.get("filename*")):
128+
encoding, _, value = decode_rfc2231(filename_with_asterisk)
129+
file_name = unquote(value, encoding=encoding or content_charset)
130+
131+
elif form_header_field == "content-type":
132+
content_type = form_header_value
133+
content_charset = form_parameters.get("charset", "utf-8")
134+
headers.append((form_header_field, form_header_value))
135+
136+
if field_name:
137+
post_data = form_part[line_index:-4].lstrip(b"\r\n")
138+
if file_name:
139+
form_file = UploadFile(
140+
content_type=content_type, filename=file_name, file_data=post_data, headers=dict(headers)
141+
)
142+
fields[field_name].append(form_file)
143+
else:
144+
try:
145+
fields[field_name].append(decode_json(post_data))
146+
except SerializationException:
147+
fields[field_name].append(post_data.decode(content_charset))
124148

125149
return {k: v if len(v) > 1 else v[0] for k, v in fields.items()}

Diff for: starlite/params.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,8 @@ class BodyKwarg:
286286
287287
Equivalent to pattern in the OpenAPI specification.
288288
"""
289+
multipart_form_part_limit: Optional[int] = field(default=None)
290+
"""The maximal number of allowed parts in a multipart/formdata request. This limit is intended to protect from DoS attacks."""
289291

290292
def __hash__(self) -> int: # pragma: no cover
291293
"""Hash the dataclass in a safe way.
@@ -315,7 +317,8 @@ def Body(
315317
max_items: Optional[int] = None,
316318
min_length: Optional[int] = None,
317319
max_length: Optional[int] = None,
318-
regex: Optional[str] = None
320+
regex: Optional[str] = None,
321+
multipart_form_part_limit: Optional[int] = None
319322
) -> Any:
320323
"""Create an extended request body kwarg definition.
321324
@@ -354,6 +357,8 @@ def Body(
354357
maxLength in the OpenAPI specification.
355358
regex: A string representing a regex against which the given string will be matched.
356359
Equivalent to pattern in the OpenAPI specification.
360+
multipart_form_part_limit: The maximal number of allowed parts in a multipart/formdata request.
361+
This limit is intended to protect from DoS attacks.
357362
"""
358363
return BodyKwarg(
359364
media_type=media_type,
@@ -374,6 +379,7 @@ def Body(
374379
min_length=min_length,
375380
max_length=max_length,
376381
regex=regex,
382+
multipart_form_part_limit=multipart_form_part_limit,
377383
)
378384

379385

Diff for: starlite/testing/create_test_client.py

+4
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def create_test_client(
7878
initial_state: Optional[Union["ImmutableState", Dict[str, Any], Iterable[Tuple[str, Any]]]] = None,
7979
logging_config: Optional["BaseLoggingConfig"] = None,
8080
middleware: Optional[List["Middleware"]] = None,
81+
multipart_form_part_limit: int = 1000,
8182
on_app_init: Optional[List["OnAppInitHandler"]] = None,
8283
on_shutdown: Optional[List["LifeSpanHandler"]] = None,
8384
on_startup: Optional[List["LifeSpanHandler"]] = None,
@@ -160,6 +161,8 @@ def test_my_handler() -> None:
160161
initial_state: An object from which to initialize the app state.
161162
logging_config: A subclass of :class:`BaseLoggingConfig <starlite.config.logging.BaseLoggingConfig>`.
162163
middleware: A list of :class:`Middleware <starlite.types.Middleware>`.
164+
multipart_form_part_limit: The maximal number of allowed parts in a multipart/formdata request.
165+
This limit is intended to protect from DoS attacks.
163166
on_app_init: A sequence of :class:`OnAppInitHandler <starlite.types.OnAppInitHandler>` instances. Handlers receive
164167
an instance of :class:`AppConfig <starlite.config.app.AppConfig>` that will have been initially populated with
165168
the parameters passed to :class:`Starlite <starlite.app.Starlite>`, and must return an instance of same. If more
@@ -210,6 +213,7 @@ def test_my_handler() -> None:
210213
initial_state=initial_state,
211214
logging_config=logging_config,
212215
middleware=middleware,
216+
multipart_form_part_limit=multipart_form_part_limit,
213217
on_app_init=on_app_init,
214218
on_shutdown=on_shutdown,
215219
on_startup=on_startup,

Diff for: tests/app/test_app_config.py

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def app_config_object() -> AppConfig:
3535
initial_state={},
3636
logging_config=None,
3737
middleware=[],
38+
multipart_form_part_limit=1000,
3839
on_shutdown=[],
3940
on_startup=[],
4041
openapi_config=None,

Diff for: tests/kwargs/test_multipart_data.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from starlite import Body, Request, RequestEncodingType, post
1212
from starlite.datastructures import UploadFile
13-
from starlite.status_codes import HTTP_201_CREATED
13+
from starlite.status_codes import HTTP_201_CREATED, HTTP_400_BAD_REQUEST
1414
from starlite.testing import create_test_client
1515
from tests import Person, PersonFactory
1616
from tests.kwargs import Form
@@ -405,3 +405,40 @@ async def hello_world(data: Optional[UploadFile] = Body(media_type=RequestEncodi
405405
with create_test_client(route_handlers=[hello_world]) as client:
406406
response = client.post("/")
407407
assert response.status_code == HTTP_201_CREATED
408+
409+
410+
@pytest.mark.parametrize("limit", (1000, 100, 10))
411+
def test_multipart_form_part_limit(limit: int) -> None:
412+
@post("/")
413+
async def hello_world(data: List[UploadFile] = Body(media_type=RequestEncodingType.MULTI_PART)) -> None:
414+
assert len(data) == limit
415+
416+
with create_test_client(route_handlers=[hello_world], multipart_form_part_limit=limit) as client:
417+
data = {str(i): "a" for i in range(limit)}
418+
response = client.post("/", files=data)
419+
assert response.status_code == HTTP_201_CREATED
420+
421+
data = {str(i): "a" for i in range(limit)}
422+
data[str(limit + 1)] = "b"
423+
response = client.post("/", files=data)
424+
assert response.status_code == HTTP_400_BAD_REQUEST
425+
426+
427+
def test_multipart_form_part_limit_body_param_precedence() -> None:
428+
app_limit = 100
429+
route_limit = 10
430+
431+
@post("/")
432+
async def hello_world(
433+
data: List[UploadFile] = Body(media_type=RequestEncodingType.MULTI_PART, multipart_form_part_limit=route_limit)
434+
) -> None:
435+
assert len(data) == route_limit
436+
437+
with create_test_client(route_handlers=[hello_world], multipart_form_part_limit=app_limit) as client:
438+
data = {str(i): "a" for i in range(route_limit)}
439+
response = client.post("/", files=data)
440+
assert response.status_code == HTTP_201_CREATED
441+
442+
data = {str(i): "a" for i in range(route_limit + 1)}
443+
response = client.post("/", files=data)
444+
assert response.status_code == HTTP_400_BAD_REQUEST

0 commit comments

Comments
 (0)