diff --git a/torch_em/util/debug.py b/torch_em/util/debug.py index b3fe1f81..ca7bfa90 100644 --- a/torch_em/util/debug.py +++ b/torch_em/util/debug.py @@ -3,7 +3,7 @@ from .util import ensure_array -def _check_plt(loader, n_samples, instance_labels, model=None, device=None): +def _check_plt(loader, n_samples, instance_labels, model=None, device=None, save_path=None): import matplotlib.pyplot as plt img_size = 5 @@ -65,7 +65,11 @@ def to_index(ns, rid, sid): ax = fig.add_subplot(n_rows, n_samples, to_index(n_samples, 1 + n_target_channels + chan, ii)) ax.imshow(pred[chan], interpolation="nearest", cmap="Greys_r", aspect="auto") - plt.show() + if save_path is None: + plt.show() + else: + plt.savefig(save_path) + plt.close() def _check_napari(loader, n_samples, instance_labels, model=None, device=None, rgb=False): @@ -118,8 +122,8 @@ def check_trainer(trainer, n_samples, instance_labels=False, split="val", loader _check_napari(loader, n_samples, instance_labels, model=model, device=trainer.device) -def check_loader(loader, n_samples, instance_labels=False, plt=False, rgb=False): +def check_loader(loader, n_samples, instance_labels=False, plt=False, rgb=False, save_path=None): if plt: - _check_plt(loader, n_samples, instance_labels) + _check_plt(loader, n_samples, instance_labels, save_path=save_path) else: _check_napari(loader, n_samples, instance_labels, rgb=rgb)