From b0f95d380dc7b03e073aa74d8a1cc1797cd49337 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Tue, 26 Dec 2023 21:19:49 +0545 Subject: [PATCH] TqdmCallback: make it work on absolute_update too --- fsspec/callbacks.py | 19 ++++++++++--------- fsspec/tests/test_callbacks.py | 8 ++++---- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/fsspec/callbacks.py b/fsspec/callbacks.py index 8f1095f1d..fdb938ced 100644 --- a/fsspec/callbacks.py +++ b/fsspec/callbacks.py @@ -213,27 +213,28 @@ class TqdmCallback(Callback): def __init__(self, tqdm_kwargs=None, *args, **kwargs): try: - import tqdm + from tqdm import tqdm - self._tqdm = tqdm except ImportError as exce: raise ImportError( "Using TqdmCallback requires tqdm to be installed" ) from exce + self._tqdm_cls = tqdm self._tqdm_kwargs = tqdm_kwargs or {} + self.tqdm = None super().__init__(*args, **kwargs) - def set_size(self, size): - self.tqdm = self._tqdm.tqdm(total=size, **self._tqdm_kwargs) - - def relative_update(self, inc=1): - self.tqdm.update(inc) + def call(self, *args, **kwargs): + if self.tqdm is None: + self.tqdm = self._tqdm_cls(total=self.size, **self._tqdm_kwargs) + self.tqdm.total = self.size + self.tqdm.update(self.value - self.tqdm.n) def __del__(self): - if hasattr(self.tqdm, "close"): + if self.tqdm is not None: self.tqdm.close() - self.tqdm = None + self.tqdm = None _DEFAULT_CALLBACK = NoOpCallback() diff --git a/fsspec/tests/test_callbacks.py b/fsspec/tests/test_callbacks.py index 74d31e85f..0b3b448c5 100644 --- a/fsspec/tests/test_callbacks.py +++ b/fsspec/tests/test_callbacks.py @@ -44,13 +44,13 @@ def relative_update(self, inc=1): def test_tqdm_callback(tqdm_kwargs, mocker): callback = TqdmCallback(tqdm_kwargs=tqdm_kwargs) - mocker.patch.object(callback, "_tqdm") + mocker.patch.object(callback, "_tqdm_cls") callback.set_size(10) for _ in callback.wrap(range(10)): ... - assert callback.tqdm.update.call_count == 10 + assert callback.tqdm.update.call_count == 11 if not tqdm_kwargs: - callback._tqdm.tqdm.assert_called_with(total=10) + callback._tqdm_cls.assert_called_with(total=10) else: - callback._tqdm.tqdm.assert_called_with(total=10, **tqdm_kwargs) + callback._tqdm_cls.assert_called_with(total=10, **tqdm_kwargs)