Skip to content
This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

Commit

Permalink
Merge pull request #122 from justusschock/fix-losstype-2dunet
Browse files Browse the repository at this point in the history
Make returned losses and metrics scalar
  • Loading branch information
justusschock committed Jun 5, 2019
2 parents f7ae91b + 6954d17 commit 62dccbe
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions delira/models/segmentation/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,13 +239,13 @@ def closure(model, data_dict: dict, optimizers: dict, losses={},

for key, crit_fn in losses.items():
_loss_val = crit_fn(preds["pred"], *data_dict.values())
loss_vals[key] = _loss_val.detach()
loss_vals[key] = _loss_val.item()
total_loss += _loss_val

with torch.no_grad():
for key, metric_fn in metrics.items():
metric_vals[key] = metric_fn(
preds["pred"], *data_dict.values())
preds["pred"], *data_dict.values()).item()

if optimizers:
optimizers['default'].zero_grad()
Expand Down

0 comments on commit 62dccbe

Please sign in to comment.