#
# Loss transformations
#


class ApplyMask:
    """Zero out entries of a prediction/target pair wherever the mask is 0.

    Returns new (prediction, target) tensors; the inputs are not mutated.
    The mask is detached so no gradient flows through it.
    """

    def __call__(self, prediction, target, mask):
        # Use detach() instead of assigning `mask.requires_grad = False`:
        # that assignment raises a RuntimeError when `mask` is a non-leaf
        # tensor that requires grad, while detach() is safe for every input
        # and produces the same masked values.
        mask = mask.detach()
        prediction = prediction * mask
        target = target * mask
        return prediction, target


class ApplyAndRemoveMask:
    """Split a stacked target into (target, mask) halves along the channel
    axis and apply the mask to both the prediction and the target.

    Expects `target` to have twice as many channels as `prediction`: the
    first half is the actual target, the second half is the mask.
    """

    def __call__(self, prediction, target):
        assert target.dim() == prediction.dim(), f"{target.dim()}, {prediction.dim()}"
        assert target.size(1) == 2 * prediction.size(1), f"{target.size(1)}, {prediction.size(1)}"
        assert target.shape[2:] == prediction.shape[2:], f"{str(target.shape)}, {str(prediction.shape)}"
        # Fixed typo in the local variable name ('seperating' -> 'separating');
        # purely local, no interface change.
        separating_channel = target.size(1) // 2
        mask = target[:, separating_channel:]
        target = target[:, :separating_channel]
        # Delegate the actual masking so the logic lives in one place.
        prediction, target = ApplyMask()(prediction, target, mask)
        return prediction, target


class MaskIgnoreLabel:
    """Mask out all entries equal to an ignore label before computing a loss.

    NOTE(review): the class header is outside the visible diff context; the
    name is reconstructed from the hunk — confirm against the full file.
    """

    def __init__(self, ignore_label=-1):
        self.ignore_label = ignore_label

    def __call__(self, prediction, target):
        # Boolean mask: 0 wherever the target carries the ignore label.
        mask = (target != self.ignore_label)
        prediction, target = ApplyMask()(prediction, target, mask)
        return prediction, target