Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: #2196 - incorrect handling of mutable headers in ASGIResponse #2308

Merged
merged 4 commits into from
Sep 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions litestar/handlers/http_handlers/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from litestar.enums import HttpMethod
from litestar.exceptions import ValidationException
from litestar.status_codes import HTTP_200_OK, HTTP_201_CREATED, HTTP_204_NO_CONTENT
from litestar.utils import encode_headers

if TYPE_CHECKING:
from litestar.app import Litestar
Expand Down Expand Up @@ -60,7 +59,6 @@ def create_data_handler(
A handler function.

"""
raw_headers = encode_headers(normalize_headers(headers).items(), cookies, [])

async def handler(
data: Any,
Expand All @@ -82,7 +80,7 @@ async def handler(
if after_request:
response = await after_request(response) # type: ignore[arg-type,misc]

return response.to_asgi_response(app=None, request=request, encoded_headers=raw_headers)
return response.to_asgi_response(app=None, request=request, headers=normalize_headers(headers), cookies=cookies)

return handler

Expand Down
65 changes: 40 additions & 25 deletions litestar/response/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, Mapping, TypeVar, overload
import itertools
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Iterable, Literal, Mapping, TypeVar, overload

from litestar.datastructures.cookie import Cookie
from litestar.datastructures.headers import ETag
from litestar.datastructures.headers import ETag, MutableScopeHeaders
from litestar.enums import MediaType, OpenAPIMediaType
from litestar.exceptions import ImproperlyConfiguredException
from litestar.serialization import default_serializer, encode_json, encode_msgpack, get_serializer
from litestar.status_codes import HTTP_200_OK, HTTP_204_NO_CONTENT, HTTP_304_NOT_MODIFIED
from litestar.utils.deprecation import warn_deprecation
from litestar.utils.helpers import encode_headers, filter_cookies, get_enum_string_value
from litestar.utils.deprecation import deprecated, warn_deprecation
from litestar.utils.helpers import get_enum_string_value

if TYPE_CHECKING:
from typing import Optional
Expand Down Expand Up @@ -41,10 +42,11 @@ class ASGIResponse:
"background",
"body",
"content_length",
"encoded_headers",
"encoding",
"is_head_response",
"status_code",
"_encoded_cookies",
"headers",
)

_should_set_content_length: ClassVar[bool] = True
Expand All @@ -56,10 +58,10 @@ def __init__(
background: BackgroundTask | BackgroundTasks | None = None,
body: bytes | str = b"",
content_length: int | None = None,
cookies: list[Cookie] | None = None,
encoded_headers: list[tuple[bytes, bytes]] | None = None,
cookies: Iterable[Cookie] | None = None,
encoded_headers: Iterable[tuple[bytes, bytes]] | None = None,
encoding: str = "utf-8",
headers: dict[str, Any] | None = None,
headers: dict[str, Any] | Iterable[tuple[str, str]] | None = None,
is_head_response: bool = False,
media_type: MediaType | str | None = None,
status_code: int | None = None,
Expand All @@ -80,8 +82,17 @@ def __init__(
"""
body = body.encode() if isinstance(body, str) else body
status_code = status_code or HTTP_200_OK
encoded_headers = encoded_headers or []
headers = headers or {}
self.headers = MutableScopeHeaders()

if encoded_headers is not None:
warn_deprecation("3.0", kind="parameter", deprecated_name="encoded_headers", alternative="headers")
for header_name, header_value in encoded_headers:
self.headers.add(header_name.decode("latin-1"), header_value.decode("latin-1"))

if headers is not None:
for k, v in headers.items() if isinstance(headers, dict) else headers:
self.headers.add(k, v) # pyright: ignore

media_type = get_enum_string_value(media_type or MediaType.JSON)

status_allows_body = (
Expand All @@ -99,27 +110,31 @@ def __init__(
)
body = b""
else:
encoded_headers.append(
(
b"content-type",
(f"{media_type}; charset={encoding}" if media_type.startswith("text/") else media_type).encode(
"latin-1"
),
),
self.headers.setdefault(
"content-type", (f"{media_type}; charset={encoding}" if media_type.startswith("text/") else media_type)
)

if self._should_set_content_length and "content-length" not in headers:
encoded_headers.append((b"content-length", str(content_length).encode("latin-1")))
if self._should_set_content_length:
self.headers.setdefault("content-length", str(content_length))

self.background = background
self.body = body
self.content_length = content_length
cookies = cookies or []
self.encoded_headers = encode_headers(headers.items(), cookies, encoded_headers)
self._encoded_cookies = tuple(
cookie.to_encoded_header() for cookie in (cookies or ()) if not cookie.documentation_only
)
self.encoding = encoding
self.is_head_response = is_head_response
self.status_code = status_code

@property
@deprecated("3.0", kind="property", alternative="encode_headers()")
def encoded_headers(self) -> list[tuple[bytes, bytes]]:
return self.encode_headers()

def encode_headers(self) -> list[tuple[bytes, bytes]]:
return [*self.headers.headers, *self._encoded_cookies]

async def after_response(self) -> None:
"""Execute after the response is sent.

Expand All @@ -141,7 +156,7 @@ async def start_response(self, send: Send) -> None:
event: HTTPResponseStartEvent = {
"type": "http.response.start",
"status": self.status_code,
"headers": self.encoded_headers,
"headers": self.encode_headers(),
}
await send(event)

Expand Down Expand Up @@ -379,8 +394,8 @@ def to_asgi_response(
request: Request,
*,
background: BackgroundTask | BackgroundTasks | None = None,
cookies: list[Cookie] | None = None,
encoded_headers: list[tuple[bytes, bytes]] | None = None,
cookies: Iterable[Cookie] | None = None,
encoded_headers: Iterable[tuple[bytes, bytes]] | None = None,
headers: dict[str, str] | None = None,
is_head_response: bool = False,
media_type: MediaType | str | None = None,
Expand Down Expand Up @@ -415,7 +430,7 @@ def to_asgi_response(
)

headers = {**headers, **self.headers} if headers is not None else self.headers
cookies = self.cookies if cookies is None else filter_cookies(self.cookies, cookies)
cookies = self.cookies if cookies is None else itertools.chain(self.cookies, cookies)

if type_encoders:
type_encoders = {**type_encoders, **(self.response_type_encoders or {})}
Expand Down
76 changes: 47 additions & 29 deletions litestar/response/file.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import itertools
from email.utils import formatdate
from inspect import iscoroutine
from mimetypes import encodings_map, guess_type
from typing import TYPE_CHECKING, Any, AsyncGenerator, Coroutine, Literal, cast
from typing import TYPE_CHECKING, Any, AsyncGenerator, Coroutine, Iterable, Literal, cast
from urllib.parse import quote
from zlib import adler32

Expand All @@ -13,7 +14,7 @@
from litestar.response.base import Response
from litestar.response.streaming import ASGIStreamingResponse
from litestar.utils.deprecation import warn_deprecation
from litestar.utils.helpers import filter_cookies, get_enum_string_value
from litestar.utils.helpers import get_enum_string_value

if TYPE_CHECKING:
from os import PathLike
Expand Down Expand Up @@ -87,65 +88,80 @@ class ASGIFileResponse(ASGIStreamingResponse):
def __init__(
self,
*,
background: BackgroundTask | BackgroundTasks | None = None,
body: bytes | str = b"",
chunk_size: int = ONE_MEGABYTE,
content_disposition_type: Literal["attachment", "inline"] = "attachment",
encoded_headers: list[tuple[bytes, bytes]] | None = None,
content_length: int | None = None,
cookies: Iterable[Cookie] | None = None,
encoded_headers: Iterable[tuple[bytes, bytes]] | None = None,
encoding: str = "utf-8",
etag: ETag | None = None,
file_info: FileInfo | Coroutine[None, None, FileInfo] | None = None,
file_path: str | PathLike | Path,
file_system: FileSystemProtocol | None = None,
filename: str = "",
headers: dict[str, str] | None = None,
is_head_response: bool = False,
media_type: MediaType | str | None = None,
stat_result: stat_result_type | None = None,
**kwargs: Any,
status_code: int | None = None,
) -> None:
"""A low-level ASGI response, streaming a file as response body.

Args:
background: A background task or a list of background tasks to be executed after the response is sent.
body: encoded content to send in the response body.
chunk_size: The chunk size to use.
content_disposition_type: The type of the ``Content-Disposition``. Either ``inline`` or ``attachment``.
content_length: The response content length.
cookies: The response cookies.
encoded_headers: A list of encoded headers.
encoding: The response encoding.
etag: An etag.
file_info: A file info.
file_path: A path to a file.
file_system: A file system adapter.
filename: The name of the file.
headers: A dictionary of headers.
headers: The response headers.
is_head_response: A boolean indicating if the response is a HEAD response.
media_type: The media type of the file.
stat_result: A stat result.
**kwargs: Additional keyword arguments, propagated to :class:`ASGIResponse <.response.base.ASGIResponse>`.
status_code: The response status code.
"""
encoded_headers = encoded_headers or []
headers = headers or {}
headers.pop("content-length", None)
headers.pop("etag", None)
headers.pop("last-modified", None)

if not media_type:
mimetype, content_encoding = guess_type(filename) if filename else (None, None)
media_type = mimetype or "application/octet-stream"
if content_encoding is not None:
headers.update({"content-encoding": content_encoding})

self.adapter = FileSystemAdapter(file_system or BaseLocalFileSystem())

super().__init__(
iterator=async_file_iterator(file_path=file_path, chunk_size=chunk_size, adapter=self.adapter),
headers=headers,
media_type=media_type,
cookies=cookies,
background=background,
status_code=status_code,
body=body,
content_length=content_length,
encoding=encoding,
is_head_response=is_head_response,
encoded_headers=encoded_headers,
)

quoted_filename = quote(filename)
is_utf8 = quoted_filename == filename
if is_utf8:
content_disposition = f'{content_disposition_type}; filename="{filename}"'
else:
content_disposition = f"{content_disposition_type}; filename*=utf-8''{quoted_filename}"

encoded_headers.append((b"content-disposition", content_disposition.encode("ascii")))
self.headers.setdefault("content-disposition", content_disposition)

self.adapter = FileSystemAdapter(file_system or BaseLocalFileSystem())

super().__init__(
iterator=async_file_iterator(file_path=file_path, chunk_size=chunk_size, adapter=self.adapter),
encoded_headers=encoded_headers,
headers=headers,
media_type=media_type,
**kwargs,
)
self.chunk_size = chunk_size
self.etag = etag
self.file_path = file_path
Expand Down Expand Up @@ -199,15 +215,17 @@ async def start_response(self, send: Send) -> None:
raise ImproperlyConfiguredException(f"{self.file_path} is not a file")

self.content_length = fs_info["size"]
self.encoded_headers.append((b"content-length", str(self.content_length).encode("ascii")))

self.encoded_headers.append((b"last-modified", formatdate(fs_info["mtime"], usegmt=True).encode("ascii")))
self.headers.setdefault("content-length", str(self.content_length))
self.headers.setdefault("last-modified", formatdate(fs_info["mtime"], usegmt=True))

if self.etag:
self.encoded_headers.append((b"etag", self.etag.to_header().encode("ascii")))
self.headers.setdefault("etag", self.etag.to_header())
else:
etag = create_etag_for_file(path=self.file_path, modified_time=fs_info["mtime"], file_size=fs_info["size"])
self.encoded_headers.append((b"etag", etag.encode("ascii")))
self.headers.setdefault(
"etag",
create_etag_for_file(path=self.file_path, modified_time=fs_info["mtime"], file_size=fs_info["size"]),
)

await super().start_response(send=send)

Expand Down Expand Up @@ -305,8 +323,8 @@ def to_asgi_response(
request: Request,
*,
background: BackgroundTask | BackgroundTasks | None = None,
cookies: list[Cookie] | None = None,
encoded_headers: list[tuple[bytes, bytes]] | None = None,
encoded_headers: Iterable[tuple[bytes, bytes]] | None = None,
cookies: Iterable[Cookie] | None = None,
headers: dict[str, str] | None = None,
is_head_response: bool = False,
media_type: MediaType | str | None = None,
Expand Down Expand Up @@ -340,7 +358,7 @@ def to_asgi_response(
)

headers = {**headers, **self.headers} if headers is not None else self.headers
cookies = self.cookies if cookies is None else filter_cookies(self.cookies, cookies)
cookies = self.cookies if cookies is None else itertools.chain(self.cookies, cookies)

media_type = self.media_type or media_type
if media_type is not None:
Expand All @@ -353,7 +371,7 @@ def to_asgi_response(
content_disposition_type=self.content_disposition_type, # pyright: ignore
content_length=0,
cookies=cookies,
encoded_headers=encoded_headers or [],
encoded_headers=encoded_headers,
encoding=self.encoding,
etag=self.etag,
file_info=self.file_info,
Expand Down
Loading