Skip to content

Commit

Permalink
Added functionality to give a description in the progress bar (#401)
Browse files Browse the repository at this point in the history
* Update pbar.py

* Update ensemble.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update ensemble.py

* Update pbar.py

* Update ensemble.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
arm61 and pre-commit-ci[bot] committed Aug 6, 2021
1 parent 91c19d6 commit bf0964a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
7 changes: 6 additions & 1 deletion src/emcee/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def sample(
thin=None,
store=True,
progress=False,
progress_kwargs=None,
):
"""Advance the chain as a generator
Expand All @@ -287,6 +288,8 @@ def sample(
``'notebook'``, which shows a progress bar suitable for
Jupyter notebooks. If ``False``, no progress bar will be
shown.
progress_kwargs (Optional[dict]): A ``dict`` of keyword arguments
to be passed to the tqdm call.
skip_initial_state_check (Optional[bool]): If ``True``, a check
that the initial_state can fully explore the space will be
skipped. (default: ``False``)
Expand Down Expand Up @@ -383,10 +386,12 @@ def sample(
model = Model(
self.log_prob_fn, self.compute_log_prob, map_fn, self._random
)
if progress_kwargs is None:
progress_kwargs = {}

# Inject the progress bar
total = None if iterations is None else iterations * yield_step
with get_progress_bar(progress, total) as pbar:
with get_progress_bar(progress, total, **progress_kwargs) as pbar:
i = 0
for _ in count() if iterations is None else range(iterations):
for _ in range(yield_step):
Expand Down
7 changes: 4 additions & 3 deletions src/emcee/pbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def update(self, count):
pass


def get_progress_bar(display, total):
def get_progress_bar(display, total, **kwargs):
"""Get a progress bar interface with given properties
If the tqdm library is not installed, this will always return a "progress
Expand All @@ -38,6 +38,7 @@ def get_progress_bar(display, total):
display (bool or str): Should the bar actually show the progress? Or a
string to indicate which tqdm bar to use.
total (int): The total size of the progress bar.
kwargs (dict): Optional keyword arguments to be passed to the tqdm call.
"""
if display:
Expand All @@ -49,8 +50,8 @@ def get_progress_bar(display, total):
return _NoOpPBar()
else:
if display is True:
return tqdm.tqdm(total=total)
return tqdm.tqdm(total=total, **kwargs)
else:
return getattr(tqdm, "tqdm_" + display)(total=total)
return getattr(tqdm, "tqdm_" + display)(total=total, **kwargs)

return _NoOpPBar()

0 comments on commit bf0964a

Please sign in to comment.