From bf0964ae2376ec7ca4e9ccbd0b62f22c5fe12719 Mon Sep 17 00:00:00 2001 From: Andrew McCluskey Date: Fri, 6 Aug 2021 16:53:13 +0200 Subject: [PATCH] Added functionality to give a description in the progress bar (#401) * Update pbar.py * Update ensemble.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update ensemble.py * Update pbar.py * Update ensemble.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/emcee/ensemble.py | 7 ++++++- src/emcee/pbar.py | 7 ++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/emcee/ensemble.py b/src/emcee/ensemble.py index 0bc83d73..2ee10790 100644 --- a/src/emcee/ensemble.py +++ b/src/emcee/ensemble.py @@ -261,6 +261,7 @@ def sample( thin=None, store=True, progress=False, + progress_kwargs=None, ): """Advance the chain as a generator @@ -287,6 +288,8 @@ def sample( ``'notebook'``, which shows a progress bar suitable for Jupyter notebooks. If ``False``, no progress bar will be shown. + progress_kwargs (Optional[dict]): A ``dict`` of keyword arguments + to be passed to the tqdm call. skip_initial_state_check (Optional[bool]): If ``True``, a check that the initial_state can fully explore the space will be skipped. (default: ``False``) @@ -383,10 +386,12 @@ def sample( model = Model( self.log_prob_fn, self.compute_log_prob, map_fn, self._random ) + if progress_kwargs is None: + progress_kwargs = {} # Inject the progress bar total = None if iterations is None else iterations * yield_step - with get_progress_bar(progress, total) as pbar: + with get_progress_bar(progress, total, **progress_kwargs) as pbar: i = 0 for _ in count() if iterations is None else range(iterations): for _ in range(yield_step): diff --git a/src/emcee/pbar.py b/src/emcee/pbar.py index 46fec550..fbfa8ba1 100644 --- a/src/emcee/pbar.py +++ b/src/emcee/pbar.py @@ -28,7 +28,7 @@ def update(self, count): pass -def get_progress_bar(display, total): +def get_progress_bar(display, total, **kwargs): """Get a progress bar interface with given properties If the tqdm library is not installed, this will always return a "progress @@ -38,6 +38,7 @@ def get_progress_bar(display, total): display (bool or str): Should the bar actually show the progress? Or a string to indicate which tqdm bar to use. total (int): The total size of the progress bar. + kwargs (dict): Optional keyword arguments to be passed to the tqdm call. """ if display: @@ -49,8 +50,8 @@ def get_progress_bar(display, total): return _NoOpPBar() else: if display is True: - return tqdm.tqdm(total=total) + return tqdm.tqdm(total=total, **kwargs) else: - return getattr(tqdm, "tqdm_" + display)(total=total) + return getattr(tqdm, "tqdm_" + display)(total=total, **kwargs) return _NoOpPBar()