Skip to content

Commit

Permalink
Add async stream reader (#722)
Browse files Browse the repository at this point in the history
  • Loading branch information
martindurant committed Apr 8, 2023
1 parent d957d70 commit d5d0d98
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
21 changes: 21 additions & 0 deletions s3fs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from fsspec.utils import infer_storage_options, tokenize, setup_logging as setup_logger
from fsspec.asyn import (
AsyncFileSystem,
AbstractAsyncStreamedFile,
sync,
sync_wrapper,
FSTimeoutError,
Expand Down Expand Up @@ -1938,6 +1939,11 @@ async def _invalidate_region_cache(self):

invalidate_region_cache = sync_wrapper(_invalidate_region_cache)

async def open_async(self, path, mode="rb", **kwargs):
    """Open *path* for asynchronous, streamed reading.

    Parameters
    ----------
    path : str
        ``bucket/key`` location of the object.
    mode : str
        Must be a binary mode (contain ``"b"``).
    **kwargs
        Extra options; ``compression`` is rejected because the streamed
        reader yields raw bytes only.

    Returns
    -------
    S3AsyncStreamedFile

    Raises
    ------
    ValueError
        If a text mode is requested or ``compression`` is set.
    """
    if "b" not in mode or kwargs.get("compression"):
        # Explain the failure instead of raising a bare ValueError.
        raise ValueError(
            "open_async only supports binary, uncompressed modes; "
            f"got mode={mode!r}, compression={kwargs.get('compression')!r}"
        )
    return S3AsyncStreamedFile(self, path, mode)


class S3File(AbstractBufferedFile):
"""
Expand Down Expand Up @@ -2277,6 +2283,21 @@ def _abort_mpu(self):
self.mpu = None


class S3AsyncStreamedFile(AbstractAsyncStreamedFile):
    """Read-only, forward-only async view of an S3 object.

    The GetObject call is issued lazily on the first ``read``; subsequent
    reads pull from the same streaming response body.
    """

    def __init__(self, fs, path, mode):
        # fs: the owning S3FileSystem; path: "bucket/key[?versionId=...]".
        self.fs = fs
        self.path = path
        self.mode = mode
        # Streaming response body, opened lazily by read().
        self.r = None

    async def read(self, length=-1):
        """Read up to *length* bytes from the object stream.

        A negative *length* reads the remainder of the stream.  Returns
        ``b""`` once the stream is exhausted.
        """
        if self.r is None:
            bucket, key, version_id = self.fs.split_path(self.path)
            # Honour an explicit version in the path; previously the
            # version id was unpacked and then silently discarded.
            extra = {"VersionId": version_id} if version_id else {}
            r = await self.fs._call_s3(
                "get_object", Bucket=bucket, Key=key, **extra
            )
            self.r = r["Body"]
        return await self.r.read(length)


def _fetch_range(fs, bucket, key, version_id, start, end, req_kw=None):
if req_kw is None:
req_kw = {}
Expand Down
24 changes: 24 additions & 0 deletions s3fs/tests/test_s3fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2574,3 +2574,27 @@ def test_cp_two_files(s3):
target + "/file0",
target + "/file1",
]


def test_async_stream(s3_base):
    """Round-trip a payload through the async streamed reader.

    Pipes data into a fresh bucket, reads it back in 1000-byte chunks via
    ``open_async``, and checks the reassembled bytes match the original.
    """
    fn = test_bucket_name + "/target"
    data = b"hello world" * 1000
    out = []

    async def read_stream():
        fs = S3FileSystem(
            anon=False,
            client_kwargs={"endpoint_url": endpoint_uri},
            skip_instance_cache=True,
        )
        await fs._mkdir(test_bucket_name)
        await fs._pipe(fn, data)
        # Fixed typo: was "block_seze", silently swallowed by **kwargs.
        f = await fs.open_async(fn, mode="rb", block_size=1000)
        while True:
            got = await f.read(1000)
            if not got:
                break
            out.append(got)

    asyncio.run(read_stream())
    assert b"".join(out) == data

0 comments on commit d5d0d98

Please sign in to comment.