Skip to content

Commit 07e7c4a

Browse files
committed
implement chunking to the async file system
1 parent f7daa82 commit 07e7c4a

File tree

2 files changed

+31
-21
lines changed

2 files changed

+31
-21
lines changed

fsspec/asyn.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from contextlib import contextmanager
1010
from glob import has_magic
1111

12+
from .callbacks import as_callback, branch
1213
from .exceptions import FSTimeoutError
1314
from .spec import AbstractFileSystem
1415
from .utils import PY36, is_exception, other_paths
@@ -154,26 +155,29 @@ def _get_batch_size():
154155
return soft_limit // 8
155156

156157

157-
async def _throttled_gather(coros, batch_size=None, **gather_kwargs):
158+
async def _run_coros_in_chunks(coros, batch_size=None, callback=None, timeout=None):
158159
"""Run the given coroutines in smaller chunks to
159160
avoid crossing the file descriptor limit.
160161
161162
If the batch_size parameter is -1, then there will not be any throttling. If
162163
it is None, it will be inferred from the process resources (soft limit divided
163164
by 8) and fall back to 128 if the system doesn't support it."""
164165

166+
callback = as_callback(callback)
165167
if batch_size is None:
166168
batch_size = _get_batch_size()
167169

168170
if batch_size == -1:
169-
return await asyncio.gather(*coros, **gather_kwargs)
171+
batch_size = len(coros) or 1  # "or 1": an empty coros list must still pass the `batch_size > 0` assert
170172

171173
assert batch_size > 0
172174

173175
results = []
174176
for start in range(0, len(coros), batch_size):
175177
chunk = coros[start : start + batch_size]
176-
results.extend(await asyncio.gather(*chunk, **gather_kwargs))
178+
for coro in asyncio.as_completed(chunk, timeout=timeout):
179+
results.append(await coro)
180+
callback.call("relative_update", 1)
177181
return results
178182

179183

@@ -340,13 +344,16 @@ async def _put(self, lpath, rpath, recursive=False, **kwargs):
340344
fs = LocalFileSystem()
341345
lpaths = fs.expand_path(lpath, recursive=recursive)
342346
rpaths = other_paths(lpaths, rpath)
347+
callback = as_callback(kwargs.pop("callback", None))
343348
batch_size = kwargs.pop("batch_size", self.batch_size)
344-
return await _throttled_gather(
345-
[
346-
self._put_file(lpath, rpath, **kwargs)
347-
for lpath, rpath in zip(lpaths, rpaths)
348-
],
349-
batch_size=batch_size,
349+
350+
coros = []
351+
callback.lazy_call("set_size", len, lpaths)
352+
for lpath, rpath in zip(lpaths, rpaths):
353+
branch(callback, lpath, rpath, kwargs)
354+
coros.append(self._put_file(lpath, rpath, **kwargs))  # _put must call _put_file (as the replaced code did), not _get_file
355+
return await _run_coros_in_chunks(
356+
coros, batch_size=batch_size, callback=callback
350357
)
351358

352359
async def _get_file(self, rpath, lpath, **kwargs):
@@ -374,13 +381,16 @@ async def _get(self, rpath, lpath, recursive=False, **kwargs):
374381
rpaths = await self._expand_path(rpath, recursive=recursive)
375382
lpaths = other_paths(rpaths, lpath)
376383
[os.makedirs(os.path.dirname(lp), exist_ok=True) for lp in lpaths]
384+
callback = as_callback(kwargs.pop("callback", None))
377385
batch_size = kwargs.pop("batch_size", self.batch_size)
378-
return await _throttled_gather(
379-
[
380-
self._get_file(rpath, lpath, **kwargs)
381-
for lpath, rpath in zip(lpaths, rpaths)
382-
],
383-
batch_size=batch_size,
386+
387+
coros = []
388+
callback.lazy_call("set_size", len, lpaths)
389+
for lpath, rpath in zip(lpaths, rpaths):
390+
branch(callback, rpath, lpath, kwargs)
391+
coros.append(self._get_file(rpath, lpath, **kwargs))
392+
return await _run_coros_in_chunks(
393+
coros, batch_size=batch_size, callback=callback
384394
)
385395

386396
async def _isfile(self, path):

fsspec/tests/test_async.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import fsspec
1010
import fsspec.asyn
11-
from fsspec.asyn import _throttled_gather
11+
from fsspec.asyn import _run_coros_in_chunks
1212

1313

1414
def test_sync_methods():
@@ -72,7 +72,7 @@ def test_sync_wrapper_treat_timeout_0_as_none():
7272

7373

7474
@pytest.mark.skipif(sys.version_info < (3, 7), reason="no asyncio.run in <3.7")
75-
def test_throttled_gather(monkeypatch):
75+
def test_run_coros_in_chunks(monkeypatch):
7676
total_running = 0
7777

7878
async def runner():
@@ -90,7 +90,7 @@ async def main(**kwargs):
9090

9191
total_running = 0
9292
coros = [runner() for _ in range(32)]
93-
results = await _throttled_gather(coros, **kwargs)
93+
results = await _run_coros_in_chunks(coros, **kwargs)
9494
for result in results:
9595
if isinstance(result, Exception):
9696
raise result
@@ -99,16 +99,16 @@ async def main(**kwargs):
9999
assert sum(asyncio.run(main(batch_size=4))) == 32
100100

101101
with pytest.raises(ValueError):
102-
asyncio.run(main(batch_size=5, return_exceptions=True))
102+
asyncio.run(main(batch_size=5))
103103

104104
with pytest.raises(ValueError):
105-
asyncio.run(main(batch_size=-1, return_exceptions=True))
105+
asyncio.run(main(batch_size=-1))
106106

107107
assert sum(asyncio.run(main(batch_size=4))) == 32
108108

109109
monkeypatch.setitem(fsspec.config.conf, "gather_batch_size", 5)
110110
with pytest.raises(ValueError):
111-
asyncio.run(main(return_exceptions=True))
111+
asyncio.run(main())
112112
assert sum(asyncio.run(main(batch_size=4))) == 32 # override
113113

114114
monkeypatch.setitem(fsspec.config.conf, "gather_batch_size", 4)

0 commit comments

Comments
 (0)