Skip to content

Commit

Permalink
Fix py http parser not treating 204/304/1xx as an empty body (aio-lib…
Browse files Browse the repository at this point in the history
…s#7755)

<!-- Thank you for your contribution! -->

There was a disagreement on how to handle 204/304/1xx responses between
the c parser and py parser.

204/304/1xx are now always treated as an empty body in the py parser to
match the c parser and comply with
https://datatracker.ietf.org/doc/html/rfc9112#section-6.3

An empty chunked response body `0\r\n\r\n` will no longer be read for
204, 304, 1xx responses when using the py parser. This matches the
behavior of the c parser.

(cherry picked from commit 6f1315b)
  • Loading branch information
bdraco committed Oct 30, 2023
1 parent 2ee44e0 commit 327e28e
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 33 deletions.
1 change: 1 addition & 0 deletions CHANGES/7755.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix py http parser not treating 204/304/1xx as an empty body
3 changes: 2 additions & 1 deletion aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
ceil_timeout,
get_env_proxy_for_url,
get_running_loop,
method_must_be_empty_body,
sentinel,
strip_auth_from_url,
)
Expand Down Expand Up @@ -583,7 +584,7 @@ async def _request(
assert conn.protocol is not None
conn.protocol.set_response_params(
timer=timer,
skip_payload=method.upper() == "HEAD",
skip_payload=method_must_be_empty_body(method),
read_until_eof=read_until_eof,
auto_decompress=auto_decompress,
read_timeout=real_timeout.sock_read,
Expand Down
6 changes: 4 additions & 2 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
ServerDisconnectedError,
ServerTimeoutError,
)
from .helpers import BaseTimerContext
from .helpers import BaseTimerContext, status_code_must_be_empty_body
from .http import HttpResponseParser, RawResponseMessage
from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader

Expand Down Expand Up @@ -241,7 +241,9 @@ def data_received(self, data: bytes) -> None:

self._payload = payload

if self._skip_payload or message.code in (204, 304):
if self._skip_payload or status_code_must_be_empty_body(
message.code
):
self.feed_data((message, EMPTY_PAYLOAD), 0)
else:
self.feed_data((message, payload), 0)
Expand Down
13 changes: 13 additions & 0 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,3 +968,16 @@ def parse_http_date(date_str: Optional[str]) -> Optional[datetime.datetime]:
with suppress(ValueError):
return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc)
return None


def method_must_be_empty_body(method: str) -> bool:
    """Tell whether a request method implies a response without a body.

    The method name is upper-cased before comparison, so the check is
    case-insensitive. Only CONNECT and HEAD qualify.
    """
    # https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
    # https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.2
    normalized = method.upper()
    if normalized == hdrs.METH_CONNECT:
        return True
    return normalized == hdrs.METH_HEAD


def status_code_must_be_empty_body(code: int) -> bool:
    """Tell whether a response status code implies an empty body.

    True for every informational (1xx) code and for 204 and 304;
    False for anything else.
    """
    # https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
    if 100 <= code < 200:
        return True
    return code in {204, 304}
64 changes: 35 additions & 29 deletions aiohttp/http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@
from . import hdrs
from .base_protocol import BaseProtocol
from .compression_utils import HAS_BROTLI, BrotliDecompressor, ZLibDecompressor
from .helpers import DEBUG, NO_EXTENSIONS, BaseTimerContext
from .helpers import (
DEBUG,
NO_EXTENSIONS,
BaseTimerContext,
method_must_be_empty_body,
status_code_must_be_empty_body,
)
from .http_exceptions import (
BadHttpMessage,
BadStatusLine,
Expand Down Expand Up @@ -345,10 +351,15 @@ def get_content_length() -> Optional[int]:
self._upgraded = msg.upgrade

method = getattr(msg, "method", self.method)
# code is only present on responses
code = getattr(msg, "code", 0)

assert self.protocol is not None
# calculate payload
if (
empty_body = status_code_must_be_empty_body(code) or bool(
method and method_must_be_empty_body(method)
)
if not empty_body and (
(length is not None and length > 0)
or msg.chunked
and not msg.upgrade
Expand Down Expand Up @@ -390,34 +401,29 @@ def get_content_length() -> Optional[int]:
auto_decompress=self._auto_decompress,
lax=self.lax,
)
elif not empty_body and length is None and self.read_until_eof:
payload = StreamReader(
self.protocol,
timer=self.timer,
loop=loop,
limit=self._limit,
)
payload_parser = HttpPayloadParser(
payload,
length=length,
chunked=msg.chunked,
method=method,
compression=msg.compression,
code=self.code,
readall=True,
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress,
lax=self.lax,
)
if not payload_parser.done:
self._payload_parser = payload_parser
else:
if (
getattr(msg, "code", 100) >= 199
and length is None
and self.read_until_eof
):
payload = StreamReader(
self.protocol,
timer=self.timer,
loop=loop,
limit=self._limit,
)
payload_parser = HttpPayloadParser(
payload,
length=length,
chunked=msg.chunked,
method=method,
compression=msg.compression,
code=self.code,
readall=True,
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress,
lax=self.lax,
)
if not payload_parser.done:
self._payload_parser = payload_parser
else:
payload = EMPTY_PAYLOAD
payload = EMPTY_PAYLOAD

messages.append((msg, payload))
else:
Expand Down
128 changes: 127 additions & 1 deletion tests/test_http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
HttpPayloadParser,
HttpRequestParserPy,
HttpResponseParserPy,
HttpVersion,
)

try:
Expand Down Expand Up @@ -1053,7 +1054,132 @@ def test_parse_no_length_payload(parser) -> None:
assert payload.is_eof()


def test_partial_url(parser) -> None:
def test_parse_content_length_payload_multiple(response: Any) -> None:
    """Feed two Content-Length responses in a row and verify each in turn.

    ``response`` is a pytest fixture defined outside this excerpt
    (presumably the response parser under test).
    """
    # One entry per response: bytes fed in, expected parsed header,
    # expected raw header, expected body bytes.
    exchanges = (
        (
            b"HTTP/1.1 200 OK\r\ncontent-length: 5\r\n\r\nfirst",
            ("Content-Length", "5"),
            (b"content-length", b"5"),
            b"first",
        ),
        (
            b"HTTP/1.1 200 OK\r\ncontent-length: 6\r\n\r\nsecond",
            ("Content-Length", "6"),
            (b"content-length", b"6"),
            b"second",
        ),
    )
    for wire, header, raw_header, body in exchanges:
        # The first element of feed_data()'s result holds the parsed
        # (message, payload) pairs; exactly one is expected per call here.
        parsed = response.feed_data(wire)[0]
        msg, payload = parsed[0]
        assert msg.version == HttpVersion(major=1, minor=1)
        assert msg.code == 200
        assert msg.reason == "OK"
        assert msg.headers == CIMultiDict([header])
        assert msg.raw_headers == (raw_header,)
        assert not msg.should_close
        assert msg.compression is None
        assert not msg.upgrade
        assert not msg.chunked
        assert payload.is_eof()
        assert body == b"".join(payload._buffer)


def test_parse_content_length_than_chunked_payload(response: Any) -> None:
    """Feed a Content-Length response, then a chunked one, verifying both.

    ``response`` is a pytest fixture defined outside this excerpt
    (presumably the response parser under test).
    """
    # One entry per response: bytes fed in, expected parsed header,
    # expected raw header, whether the message is chunked, expected body.
    exchanges = (
        (
            b"HTTP/1.1 200 OK\r\ncontent-length: 5\r\n\r\nfirst",
            ("Content-Length", "5"),
            (b"content-length", b"5"),
            False,
            b"first",
        ),
        (
            b"HTTP/1.1 200 OK\r\n"
            b"transfer-encoding: chunked\r\n\r\n"
            b"6\r\nsecond\r\n0\r\n\r\n",
            ("Transfer-Encoding", "chunked"),
            (b"transfer-encoding", b"chunked"),
            True,
            b"second",
        ),
    )
    for wire, header, raw_header, chunked, body in exchanges:
        parsed = response.feed_data(wire)[0]
        msg, payload = parsed[0]
        assert msg.version == HttpVersion(major=1, minor=1)
        assert msg.code == 200
        assert msg.reason == "OK"
        assert msg.headers == CIMultiDict([header])
        assert msg.raw_headers == (raw_header,)
        assert not msg.should_close
        assert msg.compression is None
        assert not msg.upgrade
        # Truthiness check, matching a plain `assert msg.chunked` /
        # `assert not msg.chunked` on each message.
        assert bool(msg.chunked) is chunked
        assert payload.is_eof()
        assert body == b"".join(payload._buffer)


@pytest.mark.parametrize("code", (204, 304, 101, 102))
def test_parse_chunked_payload_empty_body_than_another_chunked(
    response: Any, code: int
) -> None:
    """Feed a body-less chunked response, then a regular chunked one.

    The first response uses one of the parametrized status codes and is
    sent with only its headers (no chunk data); its payload must already
    be at EOF. The following 200 response must then parse normally and
    yield its chunked body.

    ``response`` is a pytest fixture defined outside this excerpt
    (presumably the response parser under test).
    """

    def check_chunked_message(msg: Any, payload: Any, status: int) -> None:
        # Assertions common to both responses in this test.
        assert msg.version == HttpVersion(major=1, minor=1)
        assert msg.code == status
        assert msg.reason == "OK"
        assert msg.headers == CIMultiDict(
            [
                ("Transfer-Encoding", "chunked"),
            ]
        )
        assert msg.raw_headers == ((b"transfer-encoding", b"chunked"),)
        assert not msg.should_close
        assert msg.compression is None
        assert not msg.upgrade
        assert msg.chunked
        assert payload.is_eof()

    # Headers only: no chunk data follows the blank line.
    status_line = f"HTTP/1.1 {code} OK\r\n".encode()
    headers_only = status_line + b"transfer-encoding: chunked\r\n\r\n"
    first_msg, first_payload = response.feed_data(headers_only)[0][0]
    check_chunked_message(first_msg, first_payload, code)

    # A complete chunked 200 response fed right after the body-less one.
    full_response = (
        b"HTTP/1.1 200 OK\r\n"
        b"transfer-encoding: chunked\r\n\r\n"
        b"6\r\nsecond\r\n0\r\n\r\n"
    )
    second_msg, second_payload = response.feed_data(full_response)[0][0]
    check_chunked_message(second_msg, second_payload, 200)
    assert b"second" == b"".join(second_payload._buffer)


def test_partial_url(parser: Any) -> None:
messages, upgrade, tail = parser.feed_data(b"GET /te")
assert len(messages) == 0
messages, upgrade, tail = parser.feed_data(b"st HTTP/1.1\r\n\r\n")
Expand Down

0 comments on commit 327e28e

Please sign in to comment.