Merge pull request #49 from lsst/tickets/DM-38589
DM-38589: Fix repeated reads with stream handle
timj committed Apr 11, 2023
2 parents c73dc92 + 45af22f commit 267b0d9
Showing 5 changed files with 140 additions and 15 deletions.
2 changes: 2 additions & 0 deletions doc/changes/DM-38589.bugfix.rst
@@ -0,0 +1,2 @@
* Fix EOF detection with S3 and HTTP resource handles when using repeated ``read()``.
* Ensure that HTTP reads with resource handles using byte ranges correctly disable remote compression.
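The second bullet matters because HTTP range requests address offsets in the representation the server sends: if the server negotiates gzip, a byte range may be applied to the compressed stream rather than the file itself. A minimal sketch of the header combination the handle now sends for partial reads (the URL is a placeholder, not from this commit):

```python
# Sketch of a byte-range request with transfer compression disabled.
import requests

session = requests.Session()
resp = session.get(
    "https://example.org/data.bin",  # illustrative URL
    timeout=60.0,
    headers={
        # Ask for bytes 0-99 of the *uncompressed* resource.
        "Range": "bytes=0-99",
        # Forbid gzip so the range cannot be interpreted against a
        # compressed representation of the file.
        "Accept-Encoding": "identity",
    },
)
if resp.status_code == requests.codes.partial:  # 206 Partial Content
    chunk = resp.content  # exactly the bytes that were requested
```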
66 changes: 57 additions & 9 deletions python/lsst/resources/_resourceHandles/_httpResourceHandle.py
@@ -50,10 +50,12 @@ def __init__(

self._closed = CloseStatus.OPEN
self._current_position = 0
self._eof = False

def close(self) -> None:
self._closed = CloseStatus.CLOSED
self._completeBuffer = None
self._eof = True

@property
def closed(self) -> bool:
@@ -63,7 +65,9 @@ def fileno(self) -> int:
raise io.UnsupportedOperation("HttpReadResourceHandle does not have a file number")

def flush(self) -> None:
raise io.UnsupportedOperation("HttpReadResourceHandles are read only")
modes = set(self._mode)
if {"w", "x", "a", "+"} & modes:
raise io.UnsupportedOperation("HttpReadResourceHandles are read only")

@property
def isatty(self) -> Union[bool, Callable[[], bool]]:
@@ -79,6 +83,7 @@ def readlines(self, size: int = -1) -> Iterable[bytes]:
raise io.UnsupportedOperation("HttpReadResourceHandles Do not support line by line reading")

def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
self._eof = False
if whence == io.SEEK_CUR and (self._current_position + offset) >= 0:
self._current_position += offset
elif whence == io.SEEK_SET and offset >= 0:
@@ -110,6 +115,10 @@ def writelines(self, b: Iterable[bytes], /) -> None:
raise io.UnsupportedOperation("HttpReadResourceHandles are read only")

def read(self, size: int = -1) -> bytes:
if self._eof:
# At EOF so always return an empty byte string.
return b""

# Branch for the case where the complete file has already been read.
if self._completeBuffer is not None:
result = self._completeBuffer.read(size)
@@ -122,34 +131,73 @@ def read(self, size: int = -1) -> bytes:
self._completeBuffer = io.BytesIO()
with time_this(self._log, msg="Read from remote resource %s", args=(self._url,)):
resp = self._session.get(self._url, stream=False, timeout=self._timeout)
if (code := resp.status_code) not in (200, 206):
if (code := resp.status_code) not in (requests.codes.ok, requests.codes.partial):
raise FileNotFoundError(f"Unable to read resource {self._url}; status code: {code}")
self._completeBuffer.write(resp.content)
self._current_position = self._completeBuffer.tell()

return self._completeBuffer.getbuffer().tobytes()

# a partial read is required, either because a size has been specified,
# or a read has previously been done.
# A partial read is required, either because a size has been specified,
# or a read has previously been done. Any time we specify a byte range
# we must disable the gzip compression on the server since we want
# to address ranges in the uncompressed file. If we send ranges that
# are interpreted by the server as offsets into the compressed file
# then that is at least confusing and also there is no guarantee that
# the bytes can be uncompressed.

end_pos = self._current_position + (size - 1) if size >= 0 else ""
headers = {"Range": f"bytes={self._current_position}-{end_pos}"}
headers = {"Range": f"bytes={self._current_position}-{end_pos}", "Accept-Encoding": "identity"}

with time_this(self._log, msg="Read from remote resource %s", args=(self._url,)):
with time_this(
self._log, msg="Read from remote resource %s using headers %s", args=(self._url, headers)
):
resp = self._session.get(self._url, stream=False, timeout=self._timeout, headers=headers)

if (code := resp.status_code) not in (200, 206):
if resp.status_code == requests.codes.range_not_satisfiable:
# Must have run off the end of the file. A standard file handle
# will treat this as EOF so be consistent with that. Do not change
# the current position.
self._eof = True
return b""

if (code := resp.status_code) not in (requests.codes.ok, requests.codes.partial):
raise FileNotFoundError(
f"Unable to read resource {self._url}, or bytes are out of range; status code: {code}"
)

len_content = len(resp.content)

# Verify that this is not actually the whole file and that the server
# did not lie about supporting ranges.
if len(resp.content) > size or code != 206:
if len_content > size or code != requests.codes.partial:
self._completeBuffer = io.BytesIO()
self._completeBuffer.write(resp.content)
self._completeBuffer.seek(0)
return self.read(size=size)

self._current_position += size
# The Content-Range response header should tell us the total number
# of bytes in the file and the end position of the byte range that
# was returned.
if "Content-Range" in resp.headers:
content_range = resp.headers["Content-Range"]
units, range_string = content_range.split(" ")
if units == "bytes":
range, total = range_string.split("/")
if "-" in range:
_, end = range.split("-")
end_pos = int(end)
if total != "*":
if end_pos >= int(total) - 1:
self._eof = True
else:
self._log.warning("Requested byte range from server but instead got: %s", content_range)

# Try to guess that we overran the end. This will not help if we
# read exactly the number of bytes needed to reach the end; in that
# case one more read is required and will return a 416.
if len_content < size:
self._eof = True

self._current_position += len_content
return resp.content
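As a standalone illustration of the Content-Range handling added above, a header such as ``bytes 26-51/52`` can be parsed to decide whether the read reached the end of the file. The helper name below is illustrative, not from the commit:

```python
# Illustrative helper: parse a Content-Range header and report the
# end offset plus whether EOF was reached.
def parse_content_range(content_range: str) -> tuple[int | None, bool]:
    units, _, range_string = content_range.partition(" ")
    if units != "bytes":
        # Unexpected units; EOF status cannot be determined.
        return None, False
    byte_range, _, total = range_string.partition("/")
    if "-" not in byte_range:
        return None, False
    _, end = byte_range.split("-")
    end_pos = int(end)
    # A total of "*" means the server does not know the full size.
    at_eof = total != "*" and end_pos >= int(total) - 1
    return end_pos, at_eof

assert parse_content_range("bytes 0-25/52") == (25, False)
assert parse_content_range("bytes 26-51/52") == (51, True)
```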
30 changes: 25 additions & 5 deletions python/lsst/resources/_resourceHandles/_s3ResourceHandle.py
@@ -18,6 +18,7 @@
from logging import Logger
from typing import TYPE_CHECKING, Iterable, Mapping, Optional

from botocore.exceptions import ClientError
from lsst.utils.timer import time_this

from ..s3utils import all_retryable_errors, backoff, max_retry_time
@@ -83,6 +84,8 @@ def __init__(
self._last_flush_position: Optional[int] = None
self._warned = False
self._readable = bool({"r", "+"} & set(self._mode))
self._max_size: int | None = None
self._recursing = False
if {"w", "a", "x", "+"} & set(self._mode):
self._writable = True
self._multiPartUpload = client.create_multipart_upload(Bucket=bucket, Key=key)
@@ -260,16 +263,33 @@ def read(self, size: int = -1) -> bytes:
self._buffer.seek(self._position)
return self._buffer.read(size)
# otherwise fetch the appropriate bytes from the remote resource
if self._max_size is not None and self._position >= self._max_size:
return b""
if size > 0:
stop = f"{self._position + size - 1}"
else:
stop = ""
args = {"Range": f"bytes={self._position}-{stop}"}
response = self._client.get_object(Bucket=self._bucket, Key=self._key, **args)
contents = response["Body"].read()
response["Body"].close()
self._position = len(contents)
return contents
try:
response = self._client.get_object(Bucket=self._bucket, Key=self._key, **args)
contents = response["Body"].read()
response["Body"].close()
self._position += len(contents)
return contents
except ClientError as exc:
if exc.response["ResponseMetadata"]["HTTPStatusCode"] == 416:
if self._recursing:
# This means the function has attempted to read the whole
# byte range and failed again, meaning the previous byte
# was the last byte
return b""
self._recursing = True
result = self.read()
self._max_size = self._position
self._recursing = False
return result
else:
raise

def write(self, b: bytes) -> int:
if self.writable():
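The S3 path reaches EOF differently from HTTP: an out-of-range ``GetObject`` raises ``ClientError`` with HTTP status 416 rather than returning an empty body. A condensed sketch of that translation, assuming a standard ``boto3`` client (bucket and key are caller-supplied placeholders):

```python
# Condensed sketch of the 416-to-EOF translation above, using boto3.
import boto3
from botocore.exceptions import ClientError

client = boto3.client("s3")

def read_range(bucket: str, key: str, start: int, size: int) -> bytes:
    stop = f"{start + size - 1}" if size > 0 else ""
    try:
        response = client.get_object(
            Bucket=bucket, Key=key, Range=f"bytes={start}-{stop}"
        )
        body = response["Body"]
        try:
            return body.read()
        finally:
            body.close()
    except ClientError as exc:
        if exc.response["ResponseMetadata"]["HTTPStatusCode"] == 416:
            # The requested range starts past the end of the object,
            # so report EOF exactly as a local file handle would.
            return b""
        raise
```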
51 changes: 50 additions & 1 deletion python/lsst/resources/tests.py
@@ -52,7 +52,7 @@ def _check_open(
**kwargs
Additional keyword arguments to forward to all calls to `open`.
"""
text_content = "wxyz🙂"
text_content = "abcdefghijklmnopqrstuvwxyz🙂"
bytes_content = uuid.uuid4().bytes
content_by_mode_suffix = {
"": text_content,
@@ -84,6 +84,55 @@
# Read the file we created and check the contents.
with uri.open("r" + mode_suffix, **kwargs) as read_buffer:
test_case.assertEqual(read_buffer.read(), content)
# Check that we can read bytes in a loop and get EOF
with uri.open("r" + mode_suffix, **kwargs) as read_buffer:
# Seek off the end of the file; a read should return empty content.
read_buffer.seek(1024)
test_case.assertEqual(read_buffer.tell(), 1024)
content_read = read_buffer.read() # Read as much as we can.
test_case.assertEqual(len(content_read), 0, f"Read: {content_read!r}, expected empty.")

# First read more than the content length.
read_buffer.seek(0)
size = len(content) * 3
chunk_read = read_buffer.read(size)
test_case.assertEqual(chunk_read, content)

# Repeated reads should always return empty string.
chunk_read = read_buffer.read(size)
test_case.assertEqual(len(chunk_read), 0)
chunk_read = read_buffer.read(size)
test_case.assertEqual(len(chunk_read), 0)

# Go back to start of file and read in smaller chunks.
read_buffer.seek(0)
size = len(content) // 3

content_read = empty_content_by_mode_suffix[mode_suffix]
n_reads = 0
while chunk_read := read_buffer.read(size):
content_read += chunk_read
n_reads += 1
if n_reads > 10: # In case EOF never hits because of bug.
raise AssertionError(
f"Failed to stop reading from file after {n_reads} loops. "
f"Read {len(content_read)} bytes/characters. Expected {len(content)}."
)
test_case.assertEqual(content_read, content)

# Go back to start of file and read the entire thing.
read_buffer.seek(0)
content_read = read_buffer.read()
test_case.assertEqual(content_read, content)

# Seek off the end of the file; a read should return empty content.
# We run this check twice since in some cases the handle will
# cache knowledge of the file size.
read_buffer.seek(1024)
test_case.assertEqual(read_buffer.tell(), 1024)
content_read = read_buffer.read()
test_case.assertEqual(len(content_read), 0, f"Read: {content_read!r}, expected empty.")

# Write two copies of the content, overwriting the single copy there.
with uri.open("w" + mode_suffix, **kwargs) as write_buffer:
write_buffer.write(double_content)
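The behaviour these tests demand is exactly what a local in-memory buffer already provides, which the following comparison sketch (not part of the commit) makes explicit:

```python
# Local buffers already show the EOF semantics the remote handles
# must now match: seeks past the end succeed and reads there are empty.
import io

buffer = io.BytesIO(b"abcdefghij")
buffer.seek(1024)               # Seeking past the end is allowed.
assert buffer.tell() == 1024
assert buffer.read() == b""     # Reading there yields empty bytes.

buffer.seek(0)
chunks = []
while chunk := buffer.read(4):  # Chunked reads terminate at EOF.
    chunks.append(chunk)
assert b"".join(chunks) == b"abcdefghij"
```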
6 changes: 6 additions & 0 deletions tests/test_http.py
@@ -171,6 +171,9 @@ def test_dav_file_handle(self):
self.assertIsNotNone(handle._completeBuffer)
self.assertEqual(result, contents)

# Check that flush works on a read-only handle.
handle.flush()

# Verify reading as a string handle works as expected.
with remote_file.open("r") as handle:
self.assertIsInstance(handle, io.TextIOWrapper)
@@ -182,6 +185,9 @@
result = handle.read()
self.assertEqual(result, contents)

# Check that flush works on a read-only handle.
handle.flush()

# Verify that write modes invoke the default base method
with remote_file.open("w") as handle:
self.assertIsInstance(handle, io.StringIO)
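These flush checks exercise the mode-sensitive ``flush()`` added to the HTTP handle: flushing a read-only handle is now a no-op, while any write-capable mode still raises. A reduced standalone sketch of that check (not the handle class itself):

```python
# Reduced sketch of the flush() mode check adopted by the HTTP handle.
import io

def check_flush(mode: str) -> None:
    # Any write-capable mode ("w", "x", "a", or "+") must still raise.
    if {"w", "x", "a", "+"} & set(mode):
        raise io.UnsupportedOperation("handle is read only")
    # Read-only handles treat flush() as a harmless no-op.

check_flush("rb")  # no-op for a read-only handle
try:
    check_flush("r+b")
except io.UnsupportedOperation:
    pass  # expected: "+" implies write access
```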
