Skip to content

Commit

Permalink
Merge pull request #4 from FynnBe/wandb_log_step
Browse files Browse the repository at this point in the history
specify log step in wand.log calls
  • Loading branch information
constantinpape committed May 20, 2021
2 parents 0cacdc1 + ee74099 commit 6db02a2
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions torch_em/trainer/wandb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, trainer):

wandb.watch(trainer.model)

def _log_images(self, x, y, prediction, name, gradients=None):
def _log_images(self, step, x, y, prediction, name, gradients=None):

selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]

Expand All @@ -35,26 +35,26 @@ def _log_images(self, x, y, prediction, name, gradients=None):

# to numpy and channel last
image = image.numpy().transpose((1, 2, 0))
wandb.log({f"images_{name}/input": [wandb.Image(image, caption='Input Data')]})
wandb.log({f"images_{name}/input": [wandb.Image(image, caption='Input Data')]}, step=step)

grid_image = grid_image.numpy().transpose((1, 2, 0))

grid_name = 'raw_targets_predictions'
if gradients is not None:
grid_name += '_gradients'

wandb.log({f"images_{name}/{grid_name}": [wandb.Image(grid_image, caption=grid_name)]})
wandb.log({f"images_{name}/{grid_name}": [wandb.Image(grid_image, caption=grid_name)]}, step=step)

def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
wandb.log({"train/loss": loss})
wandb.log({"train/loss": loss}, step=step)
if step % self.log_image_interval == 0:
gradients = prediction.grad if log_gradients else None
self._log_images(x, y, prediction, 'train',
self._log_images(step, x, y, prediction, 'train',
gradients=gradients)

def log_validation(self, step, metric, loss, x, y, prediction):
wandb.log({"validation/loss": loss, "validation/metric": metric})
self._log_images(x, y, prediction, 'validation')
wandb.log({"validation/loss": loss, "validation/metric": metric}, step=step)
self._log_images(step, x, y, prediction, "validation")

def get_wandb(self):
return wandb

0 comments on commit 6db02a2

Please sign in to comment.