Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure body is consumed only once (alternative to #847) #851

Merged
merged 1 commit into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion tests/unit/test_stubs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import contextlib
import http.client as httplib
from io import BytesIO
from tempfile import NamedTemporaryFile
from unittest import mock

from pytest import mark

from vcr import mode
from vcr import mode, use_cassette
from vcr.cassette import Cassette
from vcr.stubs import VCRHTTPSConnection

Expand All @@ -21,3 +24,52 @@ def testing_connect(*args):
vcr_connection.cassette = Cassette("test", record_mode=mode.ALL)
vcr_connection.real_connection.connect()
assert vcr_connection.real_connection.sock is not None

def test_body_consumed_once_stream(self, tmpdir, httpbin):
self._test_body_consumed_once(
tmpdir,
httpbin,
BytesIO(b"1234567890"),
BytesIO(b"9876543210"),
BytesIO(b"9876543210"),
)

def test_body_consumed_once_iterator(self, tmpdir, httpbin):
self._test_body_consumed_once(
tmpdir,
httpbin,
iter([b"1234567890"]),
iter([b"9876543210"]),
iter([b"9876543210"]),
)

# data2 and data3 should serve the same data, potentially as iterators
def _test_body_consumed_once(
self,
tmpdir,
httpbin,
data1,
data2,
data3,
):
with NamedTemporaryFile(dir=tmpdir, suffix=".yml") as f:
testpath = f.name
# NOTE: ``use_cassette`` is not okay with the file existing
# already. So we using ``.close()`` to not only
# close but also delete the empty file, before we start.
f.close()
host, port = httpbin.host, httpbin.port
match_on = ["method", "uri", "body"]
with use_cassette(testpath, match_on=match_on):
conn1 = httplib.HTTPConnection(host, port)
conn1.request("POST", "/anything", body=data1)
conn1.getresponse()
conn2 = httplib.HTTPConnection(host, port)
conn2.request("POST", "/anything", body=data2)
conn2.getresponse()
with use_cassette(testpath, match_on=match_on) as cass:
conn3 = httplib.HTTPConnection(host, port)
conn3.request("POST", "/anything", body=data3)
conn3.getresponse()
assert cass.play_counts[0] == 0
assert cass.play_counts[1] == 1
hartwork marked this conversation as resolved.
Show resolved Hide resolved
hartwork marked this conversation as resolved.
Show resolved Hide resolved
33 changes: 33 additions & 0 deletions tests/unit/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from io import BytesIO, StringIO

import pytest

from vcr import request
from vcr.util import read_body


@pytest.mark.parametrize(
"input_, expected_output",
[
(BytesIO(b"Stream"), b"Stream"),
(StringIO("Stream"), b"Stream"),
(iter(["StringIter"]), b"StringIter"),
(iter(["String", "Iter"]), b"StringIter"),
(iter([b"BytesIter"]), b"BytesIter"),
(iter([b"Bytes", b"Iter"]), b"BytesIter"),
(iter([70, 111, 111]), b"Foo"),
(iter([]), b""),
("String", b"String"),
(b"Bytes", b"Bytes"),
],
)
def test_read_body(input_, expected_output):
r = request.Request("POST", "http://host.com/", input_, {})
assert read_body(r) == expected_output


def test_unsupported_read_body():
r = request.Request("POST", "http://host.com/", iter([[]]), {})
with pytest.raises(ValueError) as excinfo:
assert read_body(r)
assert excinfo.value.args == ("Body type <class 'list'> not supported",)
11 changes: 9 additions & 2 deletions vcr/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from io import BytesIO
from urllib.parse import parse_qsl, urlparse

from .util import CaseInsensitiveDict
from .util import CaseInsensitiveDict, _is_nonsequence_iterator

log = logging.getLogger(__name__)

Expand All @@ -17,8 +17,11 @@ def __init__(self, method, uri, body, headers):
self.method = method
self.uri = uri
self._was_file = hasattr(body, "read")
self._was_iter = _is_nonsequence_iterator(body)
if self._was_file:
self.body = body.read()
elif self._was_iter:
self.body = list(body)
else:
self.body = body
self.headers = headers
Expand All @@ -36,7 +39,11 @@ def headers(self, value):

@property
def body(self):
return BytesIO(self._body) if self._was_file else self._body
if self._was_file:
return BytesIO(self._body)
if self._was_iter:
return iter(self._body)
return self._body

@body.setter
def body(self, value):
Expand Down
19 changes: 19 additions & 0 deletions vcr/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,28 @@ def composed(incoming):
return composed


def _is_nonsequence_iterator(obj):
return hasattr(obj, "__iter__") and not isinstance(
obj,
(bytearray, bytes, dict, list, str),
)


def read_body(request):
if hasattr(request.body, "read"):
return request.body.read()
if _is_nonsequence_iterator(request.body):
body = list(request.body)
if body:
if isinstance(body[0], str):
return "".join(body).encode("utf-8")
elif isinstance(body[0], (bytes, bytearray)):
return b"".join(body)
elif isinstance(body[0], int):
return bytes(body)
else:
hartwork marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(f"Body type {type(body[0])} not supported")
return b""
return request.body


Expand Down