diff --git a/torch_em/trainer/wandb_logger.py b/torch_em/trainer/wandb_logger.py index 06e5fba3..b312d710 100644 --- a/torch_em/trainer/wandb_logger.py +++ b/torch_em/trainer/wandb_logger.py @@ -26,7 +26,7 @@ def __init__(self, trainer): wandb.watch(trainer.model) - def _log_images(self, x, y, prediction, name, gradients=None): + def _log_images(self, step, x, y, prediction, name, gradients=None): selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2] @@ -35,7 +35,7 @@ def _log_images(self, x, y, prediction, name, gradients=None): # to numpy and channel last image = image.numpy().transpose((1, 2, 0)) - wandb.log({f"images_{name}/input": [wandb.Image(image, caption='Input Data')]}) + wandb.log({f"images_{name}/input": [wandb.Image(image, caption='Input Data')]}, step=step) grid_image = grid_image.numpy().transpose((1, 2, 0)) @@ -43,18 +43,18 @@ def _log_images(self, x, y, prediction, name, gradients=None): if gradients is not None: grid_name += '_gradients' - wandb.log({f"images_{name}/{grid_name}": [wandb.Image(grid_image, caption=grid_name)]}) + wandb.log({f"images_{name}/{grid_name}": [wandb.Image(grid_image, caption=grid_name)]}, step=step) def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False): - wandb.log({"train/loss": loss}) + wandb.log({"train/loss": loss}, step=step) if step % self.log_image_interval == 0: gradients = prediction.grad if log_gradients else None - self._log_images(x, y, prediction, 'train', + self._log_images(step, x, y, prediction, 'train', gradients=gradients) def log_validation(self, step, metric, loss, x, y, prediction): - wandb.log({"validation/loss": loss, "validation/metric": metric}) - self._log_images(x, y, prediction, 'validation') + wandb.log({"validation/loss": loss, "validation/metric": metric}, step=step) + self._log_images(step, x, y, prediction, "validation") def get_wandb(self): return wandb