From 5099c4335af9241e72a16bea52504c2cef71cc17 Mon Sep 17 00:00:00 2001 From: JonasHell Date: Wed, 22 Feb 2023 12:07:39 +0100 Subject: [PATCH] add args to adjust tqdm output --- torch_em/util/prediction.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch_em/util/prediction.py b/torch_em/util/prediction.py index acdfa55a..cd159f03 100644 --- a/torch_em/util/prediction.py +++ b/torch_em/util/prediction.py @@ -108,6 +108,8 @@ def predict_with_halo( postprocess=None, with_channels=False, skip_block=None, + disable_tqdm=False, + tqdm_desc='predict with halo' ): """ Run block-wise network prediction with halo. @@ -125,6 +127,8 @@ def predict_with_halo( postprocess [callable] - function to postprocess the network predictions (default: None) with_channels [bool] - whether the input has a channel axis (default: False) skip_block [callable] - function to evaluate wheter a given input block should be skipped (default: None) + disable_tqdm [bool] - flag that allows to disable tqdm output (e.g. if function is called multiple times) + tqdm_desc [str] - description shown by the tqdm output """ devices = [torch.device(gpu) for gpu in gpu_ids] models = [ @@ -191,6 +195,6 @@ def predict_block(block_id): n_blocks = blocking.numberOfBlocks with futures.ThreadPoolExecutor(n_workers) as tp: - list(tqdm(tp.map(predict_block, range(n_blocks)), total=n_blocks)) + list(tqdm(tp.map(predict_block, range(n_blocks)), total=n_blocks, disable=disable_tqdm, desc=tqdm_desc)) return output