Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TqdmCallback: make it work on absolute_update too #1480

Merged
merged 1 commit into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 10 additions & 9 deletions fsspec/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,27 +213,28 @@ class TqdmCallback(Callback):

def __init__(self, tqdm_kwargs=None, *args, **kwargs):
try:
import tqdm
from tqdm import tqdm

self._tqdm = tqdm
except ImportError as exce:
raise ImportError(
"Using TqdmCallback requires tqdm to be installed"
) from exce

self._tqdm_cls = tqdm
self._tqdm_kwargs = tqdm_kwargs or {}
self.tqdm = None
super().__init__(*args, **kwargs)

def set_size(self, size):
self.tqdm = self._tqdm.tqdm(total=size, **self._tqdm_kwargs)

def relative_update(self, inc=1):
self.tqdm.update(inc)
def call(self, *args, **kwargs):
if self.tqdm is None:
self.tqdm = self._tqdm_cls(total=self.size, **self._tqdm_kwargs)
self.tqdm.total = self.size
self.tqdm.update(self.value - self.tqdm.n)

def __del__(self):
if hasattr(self.tqdm, "close"):
if self.tqdm is not None:
self.tqdm.close()
self.tqdm = None
self.tqdm = None


_DEFAULT_CALLBACK = NoOpCallback()
8 changes: 4 additions & 4 deletions fsspec/tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ def relative_update(self, inc=1):
def test_tqdm_callback(tqdm_kwargs, mocker):

callback = TqdmCallback(tqdm_kwargs=tqdm_kwargs)
mocker.patch.object(callback, "_tqdm")
mocker.patch.object(callback, "_tqdm_cls")
callback.set_size(10)
for _ in callback.wrap(range(10)):
...

assert callback.tqdm.update.call_count == 10
assert callback.tqdm.update.call_count == 11
martindurant marked this conversation as resolved.
Show resolved Hide resolved
if not tqdm_kwargs:
callback._tqdm.tqdm.assert_called_with(total=10)
callback._tqdm_cls.assert_called_with(total=10)
else:
callback._tqdm.tqdm.assert_called_with(total=10, **tqdm_kwargs)
callback._tqdm_cls.assert_called_with(total=10, **tqdm_kwargs)