Skip to content

Commit

Permalink
Merge pull request #136 from anwai98/save-plt
Browse files Browse the repository at this point in the history
Update check_loader to save plots
  • Loading branch information
constantinpape committed Jul 14, 2023
2 parents 620ab64 + 3c4f56a commit 6fcb386
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions torch_em/util/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .util import ensure_array


def _check_plt(loader, n_samples, instance_labels, model=None, device=None):
def _check_plt(loader, n_samples, instance_labels, model=None, device=None, save_path=None):
import matplotlib.pyplot as plt
img_size = 5

Expand Down Expand Up @@ -65,7 +65,11 @@ def to_index(ns, rid, sid):
ax = fig.add_subplot(n_rows, n_samples, to_index(n_samples, 1 + n_target_channels + chan, ii))
ax.imshow(pred[chan], interpolation="nearest", cmap="Greys_r", aspect="auto")

plt.show()
if save_path is None:
plt.show()
else:
plt.savefig(save_path)
plt.close()


def _check_napari(loader, n_samples, instance_labels, model=None, device=None, rgb=False):
Expand Down Expand Up @@ -118,8 +122,8 @@ def check_trainer(trainer, n_samples, instance_labels=False, split="val", loader
_check_napari(loader, n_samples, instance_labels, model=model, device=trainer.device)


def check_loader(loader, n_samples, instance_labels=False, plt=False, rgb=False):
def check_loader(loader, n_samples, instance_labels=False, plt=False, rgb=False, save_path=None):
if plt:
_check_plt(loader, n_samples, instance_labels)
_check_plt(loader, n_samples, instance_labels, save_path=save_path)
else:
_check_napari(loader, n_samples, instance_labels, rgb=rgb)

0 comments on commit 6fcb386

Please sign in to comment.