Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow custom headers in multipart/form-data requests #1936

Merged
merged 15 commits into from Jan 13, 2022
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 26 additions & 8 deletions httpx/_multipart.py
Expand Up @@ -78,23 +78,41 @@ def __init__(self, name: str, value: FileTypes) -> None:

fileobj: FileContent

headers: typing.Dict[str, str] = {}
content_type: typing.Optional[str] = None

# This large tuple based API largely mirror's requests' API
# It would be good to think of better APIs for this that we could include in httpx 2.0
# since variable length tuples (especially of 4 elements) are quite unwieldly
if isinstance(value, tuple):
try:
filename, fileobj, content_type = value # type: ignore
except ValueError:
if len(value) == 2:
# neither the 3rd parameter (content_type) nor the 4th (headers) was included
filename, fileobj = value # type: ignore
content_type = guess_content_type(filename)
elif len(value) == 3:
filename, fileobj, content_type = value # type: ignore
else:
# all 4 parameters included
filename, fileobj, content_type, headers = value # type: ignore
headers = {k.title(): v for k, v in headers.items()}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should .title() case here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah... I see the comparison case. Huh. Fiddly.

else:
filename = Path(str(getattr(value, "name", "upload"))).name
fileobj = value

if content_type is None:
content_type = guess_content_type(filename)

if content_type is not None and "Content-Type" not in headers:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps...

has_content_type_header = any(["content-type" in key.lower() for key in headers])
if content_type is not None and not has_content_type_header:
    ...

?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I adapted it to any("content-type" in key.lower() for key in headers) (so it'll stop early).
Also removed the {header.title() ...} line.

# note that unlike requests, we ignore the content_type
# provided in the 3rd tuple element if it is also included in the headers
# requests does the opposite (it overwrites the header with the 3rd tuple element)
headers["Content-Type"] = content_type

if isinstance(fileobj, (str, io.StringIO)):
raise TypeError(f"Expected bytes or bytes-like object got: {type(fileobj)}")

self.filename = filename
self.file = fileobj
self.content_type = content_type
self.headers = headers
self._consumed = False

def get_length(self) -> int:
Expand Down Expand Up @@ -122,9 +140,9 @@ def render_headers(self) -> bytes:
if self.filename:
filename = format_form_param("filename", self.filename)
parts.extend([b"; ", filename])
if self.content_type is not None:
content_type = self.content_type.encode()
parts.extend([b"\r\nContent-Type: ", content_type])
for header_name, header_value in self.headers.items():
key, val = f"\r\n{header_name}: ".encode(), header_value.encode()
parts.extend([key, val])
parts.append(b"\r\n\r\n")
self._headers = b"".join(parts)

Expand Down
2 changes: 2 additions & 0 deletions httpx/_types.py
Expand Up @@ -89,6 +89,8 @@
Tuple[Optional[str], FileContent],
# (filename, file (or bytes), content_type)
Tuple[Optional[str], FileContent, Optional[str]],
# (filename, file (or bytes), content_type, headers)
Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],
]
RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]

Expand Down
52 changes: 52 additions & 0 deletions tests/test_multipart.py
Expand Up @@ -94,6 +94,58 @@ def test_multipart_file_tuple():
assert multipart["file"] == [b"<file content>"]


@pytest.mark.parametrize("content_type", [None, "text/plain"])
def test_multipart_file_tuple_headers(content_type: typing.Optional[str]):
file_name = "test.txt"
expected_content_type = "text/plain"
headers = {"Expires": "0"}

files = {"file": (file_name, io.BytesIO(b"<file content>"), content_type, headers)}
with mock.patch("os.urandom", return_value=os.urandom(16)):
boundary = os.urandom(16).hex()

headers, stream = encode_request(data={}, files=files)
assert isinstance(stream, typing.Iterable)

content = (
f'--{boundary}\r\nContent-Disposition: form-data; name="file"; '
f'filename="{file_name}"\r\nExpires: 0\r\nContent-Type: '
f"{expected_content_type}\r\n\r\n<file content>\r\n--{boundary}--\r\n"
"".encode("ascii")
)
assert headers == {
"Content-Type": f"multipart/form-data; boundary={boundary}",
"Content-Length": str(len(content)),
}
assert content == b"".join(stream)


def test_multipart_headers_include_content_type() -> None:
"""Content-Type from 4th tuple parameter (headers) should override the 3rd parameter (content_type)"""
file_name = "test.txt"
expected_content_type = "image/png"
headers = {"Content-Type": "image/png"}

files = {"file": (file_name, io.BytesIO(b"<file content>"), "text_plain", headers)}
with mock.patch("os.urandom", return_value=os.urandom(16)):
boundary = os.urandom(16).hex()

headers, stream = encode_request(data={}, files=files)
assert isinstance(stream, typing.Iterable)

content = (
f'--{boundary}\r\nContent-Disposition: form-data; name="file"; '
f'filename="{file_name}"\r\nContent-Type: '
f"{expected_content_type}\r\n\r\n<file content>\r\n--{boundary}--\r\n"
"".encode("ascii")
)
assert headers == {
"Content-Type": f"multipart/form-data; boundary={boundary}",
"Content-Length": str(len(content)),
}
assert content == b"".join(stream)


def test_multipart_encode(tmp_path: typing.Any) -> None:
path = str(tmp_path / "name.txt")
with open(path, "wb") as f:
Expand Down