Skip to content

Commit

Permalink
fix: #2196 - incorrect handling of mutable headers in ASGIResponse (#…
Browse files Browse the repository at this point in the history
…2308)

* fix headers handling in ASGIResponse

---------

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>
  • Loading branch information
provinzkraut committed Sep 17, 2023
1 parent ec3b19b commit 359fa32
Show file tree
Hide file tree
Showing 14 changed files with 265 additions and 131 deletions.
4 changes: 1 addition & 3 deletions litestar/handlers/http_handlers/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from litestar.enums import HttpMethod
from litestar.exceptions import ValidationException
from litestar.status_codes import HTTP_200_OK, HTTP_201_CREATED, HTTP_204_NO_CONTENT
from litestar.utils import encode_headers

if TYPE_CHECKING:
from litestar.app import Litestar
Expand Down Expand Up @@ -60,7 +59,6 @@ def create_data_handler(
A handler function.
"""
raw_headers = encode_headers(normalize_headers(headers).items(), cookies, [])

async def handler(
data: Any,
Expand All @@ -82,7 +80,7 @@ async def handler(
if after_request:
response = await after_request(response) # type: ignore[arg-type,misc]

return response.to_asgi_response(app=None, request=request, encoded_headers=raw_headers)
return response.to_asgi_response(app=None, request=request, headers=normalize_headers(headers), cookies=cookies)

return handler

Expand Down
65 changes: 40 additions & 25 deletions litestar/response/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, Mapping, TypeVar, overload
import itertools
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Iterable, Literal, Mapping, TypeVar, overload

from litestar.datastructures.cookie import Cookie
from litestar.datastructures.headers import ETag
from litestar.datastructures.headers import ETag, MutableScopeHeaders
from litestar.enums import MediaType, OpenAPIMediaType
from litestar.exceptions import ImproperlyConfiguredException
from litestar.serialization import default_serializer, encode_json, encode_msgpack, get_serializer
from litestar.status_codes import HTTP_200_OK, HTTP_204_NO_CONTENT, HTTP_304_NOT_MODIFIED
from litestar.utils.deprecation import warn_deprecation
from litestar.utils.helpers import encode_headers, filter_cookies, get_enum_string_value
from litestar.utils.deprecation import deprecated, warn_deprecation
from litestar.utils.helpers import get_enum_string_value

if TYPE_CHECKING:
from typing import Optional
Expand Down Expand Up @@ -41,10 +42,11 @@ class ASGIResponse:
"background",
"body",
"content_length",
"encoded_headers",
"encoding",
"is_head_response",
"status_code",
"_encoded_cookies",
"headers",
)

_should_set_content_length: ClassVar[bool] = True
Expand All @@ -56,10 +58,10 @@ def __init__(
background: BackgroundTask | BackgroundTasks | None = None,
body: bytes | str = b"",
content_length: int | None = None,
cookies: list[Cookie] | None = None,
encoded_headers: list[tuple[bytes, bytes]] | None = None,
cookies: Iterable[Cookie] | None = None,
encoded_headers: Iterable[tuple[bytes, bytes]] | None = None,
encoding: str = "utf-8",
headers: dict[str, Any] | None = None,
headers: dict[str, Any] | Iterable[tuple[str, str]] | None = None,
is_head_response: bool = False,
media_type: MediaType | str | None = None,
status_code: int | None = None,
Expand All @@ -80,8 +82,17 @@ def __init__(
"""
body = body.encode() if isinstance(body, str) else body
status_code = status_code or HTTP_200_OK
encoded_headers = encoded_headers or []
headers = headers or {}
self.headers = MutableScopeHeaders()

if encoded_headers is not None:
warn_deprecation("3.0", kind="parameter", deprecated_name="encoded_headers", alternative="headers")
for header_name, header_value in encoded_headers:
self.headers.add(header_name.decode("latin-1"), header_value.decode("latin-1"))

if headers is not None:
for k, v in headers.items() if isinstance(headers, dict) else headers:
self.headers.add(k, v) # pyright: ignore

media_type = get_enum_string_value(media_type or MediaType.JSON)

status_allows_body = (
Expand All @@ -99,27 +110,31 @@ def __init__(
)
body = b""
else:
encoded_headers.append(
(
b"content-type",
(f"{media_type}; charset={encoding}" if media_type.startswith("text/") else media_type).encode(
"latin-1"
),
),
self.headers.setdefault(
"content-type", (f"{media_type}; charset={encoding}" if media_type.startswith("text/") else media_type)
)

if self._should_set_content_length and "content-length" not in headers:
encoded_headers.append((b"content-length", str(content_length).encode("latin-1")))
if self._should_set_content_length:
self.headers.setdefault("content-length", str(content_length))

self.background = background
self.body = body
self.content_length = content_length
cookies = cookies or []
self.encoded_headers = encode_headers(headers.items(), cookies, encoded_headers)
self._encoded_cookies = tuple(
cookie.to_encoded_header() for cookie in (cookies or ()) if not cookie.documentation_only
)
self.encoding = encoding
self.is_head_response = is_head_response
self.status_code = status_code

@property
@deprecated("3.0", kind="property", alternative="encode_headers()")
def encoded_headers(self) -> list[tuple[bytes, bytes]]:
return self.encode_headers()

def encode_headers(self) -> list[tuple[bytes, bytes]]:
return [*self.headers.headers, *self._encoded_cookies]

async def after_response(self) -> None:
"""Execute after the response is sent.
Expand All @@ -141,7 +156,7 @@ async def start_response(self, send: Send) -> None:
event: HTTPResponseStartEvent = {
"type": "http.response.start",
"status": self.status_code,
"headers": self.encoded_headers,
"headers": self.encode_headers(),
}
await send(event)

Expand Down Expand Up @@ -379,8 +394,8 @@ def to_asgi_response(
request: Request,
*,
background: BackgroundTask | BackgroundTasks | None = None,
cookies: list[Cookie] | None = None,
encoded_headers: list[tuple[bytes, bytes]] | None = None,
cookies: Iterable[Cookie] | None = None,
encoded_headers: Iterable[tuple[bytes, bytes]] | None = None,
headers: dict[str, str] | None = None,
is_head_response: bool = False,
media_type: MediaType | str | None = None,
Expand Down Expand Up @@ -415,7 +430,7 @@ def to_asgi_response(
)

headers = {**headers, **self.headers} if headers is not None else self.headers
cookies = self.cookies if cookies is None else filter_cookies(self.cookies, cookies)
cookies = self.cookies if cookies is None else itertools.chain(self.cookies, cookies)

if type_encoders:
type_encoders = {**type_encoders, **(self.response_type_encoders or {})}
Expand Down
76 changes: 47 additions & 29 deletions litestar/response/file.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import itertools
from email.utils import formatdate
from inspect import iscoroutine
from mimetypes import encodings_map, guess_type
from typing import TYPE_CHECKING, Any, AsyncGenerator, Coroutine, Literal, cast
from typing import TYPE_CHECKING, Any, AsyncGenerator, Coroutine, Iterable, Literal, cast
from urllib.parse import quote
from zlib import adler32

Expand All @@ -13,7 +14,7 @@
from litestar.response.base import Response
from litestar.response.streaming import ASGIStreamingResponse
from litestar.utils.deprecation import warn_deprecation
from litestar.utils.helpers import filter_cookies, get_enum_string_value
from litestar.utils.helpers import get_enum_string_value

if TYPE_CHECKING:
from os import PathLike
Expand Down Expand Up @@ -87,65 +88,80 @@ class ASGIFileResponse(ASGIStreamingResponse):
def __init__(
self,
*,
background: BackgroundTask | BackgroundTasks | None = None,
body: bytes | str = b"",
chunk_size: int = ONE_MEGABYTE,
content_disposition_type: Literal["attachment", "inline"] = "attachment",
encoded_headers: list[tuple[bytes, bytes]] | None = None,
content_length: int | None = None,
cookies: Iterable[Cookie] | None = None,
encoded_headers: Iterable[tuple[bytes, bytes]] | None = None,
encoding: str = "utf-8",
etag: ETag | None = None,
file_info: FileInfo | Coroutine[None, None, FileInfo] | None = None,
file_path: str | PathLike | Path,
file_system: FileSystemProtocol | None = None,
filename: str = "",
headers: dict[str, str] | None = None,
is_head_response: bool = False,
media_type: MediaType | str | None = None,
stat_result: stat_result_type | None = None,
**kwargs: Any,
status_code: int | None = None,
) -> None:
"""A low-level ASGI response, streaming a file as response body.
Args:
background: A background task or a list of background tasks to be executed after the response is sent.
body: encoded content to send in the response body.
chunk_size: The chunk size to use.
content_disposition_type: The type of the ``Content-Disposition``. Either ``inline`` or ``attachment``.
content_length: The response content length.
cookies: The response cookies.
encoded_headers: A list of encoded headers.
encoding: The response encoding.
etag: An etag.
file_info: A file info.
file_path: A path to a file.
file_system: A file system adapter.
filename: The name of the file.
headers: A dictionary of headers.
headers: The response headers.
is_head_response: A boolean indicating if the response is a HEAD response.
media_type: The media type of the file.
stat_result: A stat result.
**kwargs: Additional keyword arguments, propagated to :class:`ASGIResponse <.response.base.ASGIResponse>`.
status_code: The response status code.
"""
encoded_headers = encoded_headers or []
headers = headers or {}
headers.pop("content-length", None)
headers.pop("etag", None)
headers.pop("last-modified", None)

if not media_type:
mimetype, content_encoding = guess_type(filename) if filename else (None, None)
media_type = mimetype or "application/octet-stream"
if content_encoding is not None:
headers.update({"content-encoding": content_encoding})

self.adapter = FileSystemAdapter(file_system or BaseLocalFileSystem())

super().__init__(
iterator=async_file_iterator(file_path=file_path, chunk_size=chunk_size, adapter=self.adapter),
headers=headers,
media_type=media_type,
cookies=cookies,
background=background,
status_code=status_code,
body=body,
content_length=content_length,
encoding=encoding,
is_head_response=is_head_response,
encoded_headers=encoded_headers,
)

quoted_filename = quote(filename)
is_utf8 = quoted_filename == filename
if is_utf8:
content_disposition = f'{content_disposition_type}; filename="{filename}"'
else:
content_disposition = f"{content_disposition_type}; filename*=utf-8''{quoted_filename}"

encoded_headers.append((b"content-disposition", content_disposition.encode("ascii")))
self.headers.setdefault("content-disposition", content_disposition)

self.adapter = FileSystemAdapter(file_system or BaseLocalFileSystem())

super().__init__(
iterator=async_file_iterator(file_path=file_path, chunk_size=chunk_size, adapter=self.adapter),
encoded_headers=encoded_headers,
headers=headers,
media_type=media_type,
**kwargs,
)
self.chunk_size = chunk_size
self.etag = etag
self.file_path = file_path
Expand Down Expand Up @@ -199,15 +215,17 @@ async def start_response(self, send: Send) -> None:
raise ImproperlyConfiguredException(f"{self.file_path} is not a file")

self.content_length = fs_info["size"]
self.encoded_headers.append((b"content-length", str(self.content_length).encode("ascii")))

self.encoded_headers.append((b"last-modified", formatdate(fs_info["mtime"], usegmt=True).encode("ascii")))
self.headers.setdefault("content-length", str(self.content_length))
self.headers.setdefault("last-modified", formatdate(fs_info["mtime"], usegmt=True))

if self.etag:
self.encoded_headers.append((b"etag", self.etag.to_header().encode("ascii")))
self.headers.setdefault("etag", self.etag.to_header())
else:
etag = create_etag_for_file(path=self.file_path, modified_time=fs_info["mtime"], file_size=fs_info["size"])
self.encoded_headers.append((b"etag", etag.encode("ascii")))
self.headers.setdefault(
"etag",
create_etag_for_file(path=self.file_path, modified_time=fs_info["mtime"], file_size=fs_info["size"]),
)

await super().start_response(send=send)

Expand Down Expand Up @@ -305,8 +323,8 @@ def to_asgi_response(
request: Request,
*,
background: BackgroundTask | BackgroundTasks | None = None,
cookies: list[Cookie] | None = None,
encoded_headers: list[tuple[bytes, bytes]] | None = None,
encoded_headers: Iterable[tuple[bytes, bytes]] | None = None,
cookies: Iterable[Cookie] | None = None,
headers: dict[str, str] | None = None,
is_head_response: bool = False,
media_type: MediaType | str | None = None,
Expand Down Expand Up @@ -340,7 +358,7 @@ def to_asgi_response(
)

headers = {**headers, **self.headers} if headers is not None else self.headers
cookies = self.cookies if cookies is None else filter_cookies(self.cookies, cookies)
cookies = self.cookies if cookies is None else itertools.chain(self.cookies, cookies)

media_type = self.media_type or media_type
if media_type is not None:
Expand All @@ -353,7 +371,7 @@ def to_asgi_response(
content_disposition_type=self.content_disposition_type, # pyright: ignore
content_length=0,
cookies=cookies,
encoded_headers=encoded_headers or [],
encoded_headers=encoded_headers,
encoding=self.encoding,
etag=self.etag,
file_info=self.file_info,
Expand Down
Loading

0 comments on commit 359fa32

Please sign in to comment.