From 9fdc23fa66cd265ce72fc1edfe893a49c6952613 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Wed, 7 Sep 2022 16:49:08 +0200 Subject: [PATCH] Use ApplyMask in other masking loss wrappers --- torch_em/loss/wrapper.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/torch_em/loss/wrapper.py b/torch_em/loss/wrapper.py index acb96d52..895e2ca0 100644 --- a/torch_em/loss/wrapper.py +++ b/torch_em/loss/wrapper.py @@ -41,27 +41,24 @@ def forward(self, prediction, target, **kwargs): # Loss transformations # + +class ApplyMask: + def __call__(self, prediction, target, mask): + mask.requires_grad = False + prediction = prediction * mask + target = target * mask + return prediction, target + + class ApplyAndRemoveMask: def __call__(self, prediction, target): assert target.dim() == prediction.dim(), f"{target.dim()}, {prediction.dim()}" assert target.size(1) == 2 * prediction.size(1), f"{target.size(1)}, {prediction.size(1)}" assert target.shape[2:] == prediction.shape[2:], f"{str(target.shape)}, {str(prediction.shape)}" - seperating_channel = target.size(1) // 2 mask = target[:, seperating_channel:] target = target[:, :seperating_channel] - mask.requires_grad = False - - # mask the prediction - prediction = prediction * mask - return prediction, target - - -class ApplyMask: - def __call__(self, prediction, target, mask): - mask.requires_grad = False - prediction = prediction * mask - target = target * mask + prediction, target = ApplyMask()(prediction, target, mask) return prediction, target @@ -71,7 +68,5 @@ def __init__(self, ignore_label=-1): def __call__(self, prediction, target): mask = (target != self.ignore_label) - mask.requires_grad = False - prediction = prediction * mask - target = target * mask + prediction, target = ApplyMask()(prediction, target, mask) return prediction, target