diff --git a/src/emcee/ensemble.py b/src/emcee/ensemble.py index 0bc83d73..2ee10790 100644 --- a/src/emcee/ensemble.py +++ b/src/emcee/ensemble.py @@ -261,6 +261,7 @@ def sample( thin=None, store=True, progress=False, + progress_kwargs=None, ): """Advance the chain as a generator @@ -287,6 +288,8 @@ def sample( ``'notebook'``, which shows a progress bar suitable for Jupyter notebooks. If ``False``, no progress bar will be shown. + progress_kwargs (Optional[dict]): A ``dict`` of keyword arguments + to be passed to the tqdm call. skip_initial_state_check (Optional[bool]): If ``True``, a check that the initial_state can fully explore the space will be skipped. (default: ``False``) @@ -383,10 +386,12 @@ def sample( model = Model( self.log_prob_fn, self.compute_log_prob, map_fn, self._random ) + if progress_kwargs is None: + progress_kwargs = {} # Inject the progress bar total = None if iterations is None else iterations * yield_step - with get_progress_bar(progress, total) as pbar: + with get_progress_bar(progress, total, **progress_kwargs) as pbar: i = 0 for _ in count() if iterations is None else range(iterations): for _ in range(yield_step): diff --git a/src/emcee/pbar.py b/src/emcee/pbar.py index 46fec550..fbfa8ba1 100644 --- a/src/emcee/pbar.py +++ b/src/emcee/pbar.py @@ -28,7 +28,7 @@ def update(self, count): pass -def get_progress_bar(display, total): +def get_progress_bar(display, total, **kwargs): """Get a progress bar interface with given properties If the tqdm library is not installed, this will always return a "progress @@ -38,6 +38,7 @@ def get_progress_bar(display, total): display (bool or str): Should the bar actually show the progress? Or a string to indicate which tqdm bar to use. total (int): The total size of the progress bar. + kwargs (dict): Optional keyword arguments to be passed to the tqdm call. """ if display: @@ -49,8 +50,8 @@ def get_progress_bar(display, total): return _NoOpPBar() else: if display is True: - return tqdm.tqdm(total=total) + return tqdm.tqdm(total=total, **kwargs) else: - return getattr(tqdm, "tqdm_" + display)(total=total) + return getattr(tqdm, "tqdm_" + display)(total=total, **kwargs) return _NoOpPBar()