diff --git a/test/loss/test_loss_wrapper.py b/test/loss/test_loss_wrapper.py
index 44cfc332..9d0864c1 100644
--- a/test/loss/test_loss_wrapper.py
+++ b/test/loss/test_loss_wrapper.py
@@ -3,34 +3,151 @@ class TestLossWrapper(unittest.TestCase):
-    def test_masking(self):
-        from torch_em.loss import (ApplyAndRemoveMask,
-                                   DiceLoss,
-                                   LossWrapper)
-        loss = LossWrapper(DiceLoss(),
-                           transform=ApplyAndRemoveMask())
+    def test_ApplyAndRemove_grad_masking(self):
+        from torch_em.loss import (ApplyAndRemoveMask,
+                                   ApplyMask,
+                                   DiceLoss,
+                                   LossWrapper)
+        shape = (1, 1, 128, 128)
+        for masking_func in ApplyMask.MASKING_FUNCS:
+            transform = ApplyAndRemoveMask(
+                masking_method=masking_func
+            )
+            loss = LossWrapper(DiceLoss(), transform=transform)
+
+            x = torch.rand(*shape)
+            x.requires_grad = True
+            x.retain_grad = True
+
+            y = torch.rand(*shape)
+            mask = torch.rand(*shape) > .5
+            y = torch.cat([
+                y, mask.to(dtype=y.dtype)
+            ], dim=1)
+            lval = loss(x, y)
+            self.assertTrue(0. < lval.item() < 1.)
+            lval.backward()
+
+            grad = x.grad.numpy()
+            mask = mask.numpy()
+            # print((grad[mask] == 0).sum())
+            self.assertFalse((grad[mask] == 0).all())
+            # print((grad[~mask] == 0).sum())
+            self.assertTrue((grad[~mask] == 0).all())
+
+    def test_MaskIgnoreLabel_grad_masking(self):
+        from torch_em.loss import (MaskIgnoreLabel,
+                                   ApplyMask,
+                                   DiceLoss,
+                                   LossWrapper)
         shape = (1, 1, 128, 128)
-        x = torch.rand(*shape)
-        x.requires_grad = True
-        x.retain_grad = True
-
-        y = torch.rand(*shape)
-        mask = torch.rand(*shape) > .5
-        y = torch.cat([
-            y, mask.to(dtype=y.dtype)
-        ], dim=1)
-
-        lval = loss(x, y)
-        self.assertTrue(0. < lval.item() < 1.)
-        lval.backward()
-
-        grad = x.grad.numpy()
-        mask = mask.numpy()
-        # print((grad[mask] == 0).sum())
-        self.assertFalse((grad[mask] == 0).all())
-        # print((grad[~mask] == 0).sum())
-        self.assertTrue((grad[~mask] == 0).all())
+        ignore_label = -1
+        for masking_func in ApplyMask.MASKING_FUNCS:
+            transform = MaskIgnoreLabel(
+                masking_method=masking_func,
+                ignore_label=ignore_label
+            )
+            loss = LossWrapper(DiceLoss(), transform=transform)
+
+            x = torch.rand(*shape)
+            x.requires_grad = True
+            x.retain_grad = True
+
+            y = torch.rand(*shape)
+            mask = torch.rand(*shape) > .5
+            y[mask] = ignore_label
+
+            lval = loss(x, y)
+            self.assertTrue(0. < lval.item() < 1.)
+            lval.backward()
+
+            grad = x.grad.numpy()
+            mask = mask.numpy()
+            self.assertFalse((grad[~mask] == 0).all())
+            self.assertTrue((grad[mask] == 0).all())
+
+    def test_ApplyMask_grad_masking(self):
+        from torch_em.loss import (ApplyMask,
+                                   DiceLoss,
+                                   LossWrapper)
+        shape = (1, 1, 128, 128)
+        for masking_func in ApplyMask.MASKING_FUNCS:
+            transform = ApplyMask(
+                masking_method=masking_func
+            )
+            loss = LossWrapper(DiceLoss(), transform=transform)
+
+            x = torch.rand(*shape)
+            x.requires_grad = True
+            x.retain_grad = True
+
+            y = torch.rand(*shape)
+            mask = torch.rand(*shape) > .5
+
+            lval = loss(x, y, mask=mask)
+            self.assertTrue(0. < lval.item() < 1.)
+            lval.backward()
+
+            grad = x.grad.numpy()
+            mask = mask.numpy()
+            self.assertFalse((grad[mask] == 0).all())
+            self.assertTrue((grad[~mask] == 0).all())
+
+    def test_ApplyMask_output_shape_crop(self):
+        from torch_em.loss import ApplyMask
+
+        # _crop batch_size=1
+        shape = (1, 1, 10, 128, 128)
+        p = torch.rand(*shape)
+        t = torch.rand(*shape)
+        m = torch.rand(*shape) > .5
+        p_masked, t_masked = ApplyMask()(p, t, m)
+        out_shape = (m.sum(), shape[1])
+        self.assertTrue(p_masked.shape == out_shape)
+        self.assertTrue(t_masked.shape == out_shape)
+
+        # _crop batch_size>1
+        shape = (5, 1, 10, 128, 128)
+        p = torch.rand(*shape)
+        t = torch.rand(*shape)
+        m = torch.rand(*shape) > .5
+        p_masked, t_masked = ApplyMask()(p, t, m)
+        out_shape = (m.sum(), shape[1])
+        self.assertTrue(p_masked.shape == out_shape)
+        self.assertTrue(t_masked.shape == out_shape)
+
+        # _crop n_channels>1
+        shape = (1, 2, 10, 128, 128)
+        p = torch.rand(*shape)
+        t = torch.rand(*shape)
+        m = torch.rand(*shape) > .5
+        with self.assertRaises(ValueError):
+            p_masked, t_masked = ApplyMask()(p, t, m)
+
+        # _crop different shapes
+        shape_pt = (5, 2, 10, 128, 128)
+        p = torch.rand(*shape_pt)
+        t = torch.rand(*shape_pt)
+        shape_m = (5, 1, 10, 128, 128)
+        m = torch.rand(*shape_m) > .5
+        p_masked, t_masked = ApplyMask()(p, t, m)
+        out_shape = (m.sum(), shape_pt[1])
+        self.assertTrue(p_masked.shape == out_shape)
+        self.assertTrue(t_masked.shape == out_shape)
+
+    def test_ApplyMask_output_shape_multiply(self):
+        from torch_em.loss import ApplyMask
+
+        # _multiply
+        shape = (2, 5, 10, 128, 128)
+        p = torch.rand(*shape)
+        t = torch.rand(*shape)
+        m = torch.rand(*shape) > .5
+
+        p_masked, t_masked = ApplyMask(masking_method="multiply")(p, t, m)
+        self.assertTrue(p_masked.shape == shape)
+        self.assertTrue(t_masked.shape == shape)
 
 
 if __name__ == '__main__':
diff --git a/torch_em/loss/wrapper.py b/torch_em/loss/wrapper.py
index 895e2ca0..52312a84 100644
--- a/torch_em/loss/wrapper.py
+++ b/torch_em/loss/wrapper.py
@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn
 
 
@@ -43,14 +44,49 @@ def forward(self, prediction, target, **kwargs):
 
 
 class ApplyMask:
-    def __call__(self, prediction, target, mask):
-        mask.requires_grad = False
+    def _crop(prediction, target, mask, channel_dim):
+        if mask.shape[channel_dim] != 1:
+            raise ValueError(
+                "_crop only supports a mask with a singleton channel axis. "
+                "Please consider using masking_method=multiply."
+            )
+        mask = mask.type(torch.bool)
+        # remove singleton axis
+        mask = mask.squeeze(channel_dim)
+        # move channel axis to end
+        prediction = prediction.moveaxis(channel_dim, -1)
+        target = target.moveaxis(channel_dim, -1)
+        # output has shape N x C
+        # correct for torch_em.loss.dice.flatten_samples
+        return prediction[mask], target[mask]
+
+    def _multiply(prediction, target, mask, channel_dim):
         prediction = prediction * mask
         target = target * mask
         return prediction, target
 
+    MASKING_FUNCS = {
+        "crop": _crop,
+        "multiply": _multiply,
+    }
+
+    def __init__(self, masking_method="crop", channel_dim=1):
+        if masking_method not in self.MASKING_FUNCS.keys():
+            raise ValueError(f"{masking_method} is not available, please use one of {list(self.MASKING_FUNCS.keys())}.")
+        self.masking_func = self.MASKING_FUNCS[masking_method]
+        self.channel_dim = channel_dim
+
+        self.init_kwargs = {
+            "masking_method": masking_method,
+            "channel_dim": channel_dim,
+        }
+
+    def __call__(self, prediction, target, mask):
+        mask.requires_grad = False
+        return self.masking_func(prediction, target, mask, self.channel_dim)
+
 
-class ApplyAndRemoveMask:
+class ApplyAndRemoveMask(ApplyMask):
     def __call__(self, prediction, target):
         assert target.dim() == prediction.dim(), f"{target.dim()}, {prediction.dim()}"
         assert target.size(1) == 2 * prediction.size(1), f"{target.size(1)}, {prediction.size(1)}"
@@ -58,15 +94,17 @@ def __call__(self, prediction, target):
         seperating_channel = target.size(1) // 2
         mask = target[:, seperating_channel:]
         target = target[:, :seperating_channel]
-        prediction, target = ApplyMask()(prediction, target, mask)
+        prediction, target = super().__call__(prediction, target, mask)
         return prediction, target
 
 
-class MaskIgnoreLabel:
-    def __init__(self, ignore_label=-1):
+class MaskIgnoreLabel(ApplyMask):
+    def __init__(self, ignore_label=-1, masking_method="crop", channel_dim=1):
+        super().__init__(masking_method, channel_dim)
         self.ignore_label = ignore_label
+        self.init_kwargs["ignore_label"] = ignore_label
 
     def __call__(self, prediction, target):
         mask = (target != self.ignore_label)
-        prediction, target = ApplyMask()(prediction, target, mask)
+        prediction, target = super().__call__(prediction, target, mask)
         return prediction, target
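
For reference, a minimal usage sketch of the extended masking API, assembled from the tests in this diff (the random tensors are illustrative only; it assumes LossWrapper forwards the mask keyword to the transform, as test_ApplyMask_grad_masking exercises):

import torch
from torch_em.loss import ApplyMask, DiceLoss, LossWrapper

# Restrict the dice loss to the masked pixels; "multiply" keeps the input
# shape, while "crop" selects the masked pixels and returns an N x C tensor.
loss = LossWrapper(DiceLoss(), transform=ApplyMask(masking_method="crop"))

pred = torch.rand(1, 1, 128, 128, requires_grad=True)
target = torch.rand(1, 1, 128, 128)
mask = torch.rand(1, 1, 128, 128) > 0.5

# The mask is passed through LossWrapper's **kwargs to the transform,
# so only pixels inside the mask receive a gradient.
lval = loss(pred, target, mask=mask)
lval.backward()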