Skip to content

Add semaphore to AsyncFileSystemWrapper #1908

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions fsspec/implementations/asyn_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fsspec.asyn import AsyncFileSystem, running_async


def async_wrapper(func, obj=None):
def async_wrapper(func, obj=None, semaphore=None):
"""
Wraps a synchronous function to make it awaitable.

Expand All @@ -16,6 +16,8 @@ def async_wrapper(func, obj=None):
The synchronous function to wrap.
obj : object, optional
The instance to bind the function to, if applicable.
semaphore : asyncio.Semaphore, optional
A semaphore to limit concurrent calls.

Returns
-------
Expand All @@ -25,6 +27,9 @@ def async_wrapper(func, obj=None):

@functools.wraps(func)
async def wrapper(*args, **kwargs):
if semaphore:
async with semaphore:
return await asyncio.to_thread(func, *args, **kwargs)
return await asyncio.to_thread(func, *args, **kwargs)

return wrapper
Expand Down Expand Up @@ -52,6 +57,8 @@ def __init__(
asynchronous=None,
target_protocol=None,
target_options=None,
semaphore=None,
max_concurrent_tasks=None,
**kwargs,
):
if asynchronous is None:
Expand All @@ -62,6 +69,7 @@ def __init__(
else:
self.sync_fs = fsspec.filesystem(target_protocol, **target_options)
self.protocol = self.sync_fs.protocol
self.semaphore = semaphore
self._wrap_all_sync_methods()

@property
Expand All @@ -83,7 +91,7 @@ def _wrap_all_sync_methods(self):

method = getattr(self.sync_fs, method_name)
if callable(method) and not inspect.iscoroutinefunction(method):
async_method = async_wrapper(method, obj=self)
async_method = async_wrapper(method, obj=self, semaphore=self.semaphore)
setattr(self, f"_{method_name}", async_method)

@classmethod
Expand Down
57 changes: 57 additions & 0 deletions fsspec/implementations/tests/test_asyn_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,49 @@
import asyncio
import os
from itertools import cycle

import pytest

import fsspec
from fsspec.asyn import AsyncFileSystem
from fsspec.implementations.asyn_wrapper import AsyncFileSystemWrapper
from fsspec.implementations.local import LocalFileSystem

from .test_local import csv_files, filetexts


class LockedFileSystem(AsyncFileSystem):
"""
A mock file system that simulates a synchronous locking file systems with delays.
"""

def __init__(
self,
asynchronous: bool = False,
delays=None,
) -> None:
self.lock = asyncio.Lock()
self.delays = cycle((0.03, 0.01) if delays is None else delays)

super().__init__(asynchronous=asynchronous)

async def _cat_file(self, path, start=None, end=None) -> bytes:
await self._simulate_io_operation(path)
return path.encode()

async def _await_io(self) -> None:
await asyncio.sleep(next(self.delays))

async def _simulate_io_operation(self, path) -> None:
await self._check_active()
async with self.lock:
await self._await_io()

async def _check_active(self) -> None:
if self.lock.locked():
raise RuntimeError("Concurrent requests!")


@pytest.mark.asyncio
async def test_is_async_default():
fs = fsspec.filesystem("file")
Expand Down Expand Up @@ -161,3 +195,26 @@ def test_open(tmpdir):
)
with of as f:
assert f.read() == b"hello"


@pytest.mark.asyncio
async def test_semaphore_synchronous():
fs = AsyncFileSystemWrapper(
LockedFileSystem(), asynchronous=False, semaphore=asyncio.Semaphore(1)
)

paths = [f"path_{i}" for i in range(1, 3)]
results = await asyncio.gather(*(fs._cat_file(path) for path in paths))

assert set(results) == {path.encode() for path in paths}


@pytest.mark.asyncio
async def test_deadlock_when_asynchronous():
fs = AsyncFileSystemWrapper(
LockedFileSystem(), asynchronous=False, semaphore=asyncio.Semaphore(3)
)
paths = [f"path_{i}" for i in range(1, 3)]

with pytest.raises(RuntimeError, match="Concurrent requests!"):
await asyncio.gather(*(fs._cat_file(path) for path in paths))
Loading