
Revert "Merge branch 'augmentations'"
This reverts commit eb93b42, reversing changes made to f7f23cf.
JonasHell committed Dec 9, 2022
1 parent fa18729 commit e242f70
Showing 2 changed files with 7 additions and 100 deletions.
55 changes: 0 additions & 55 deletions test/loss/test_loss_wrapper.py
@@ -32,61 +32,6 @@ def test_masking(self):
         # print((grad[~mask] == 0).sum())
         self.assertTrue((grad[~mask] == 0).all())
 
-    def test_ApplyMask_output_shape_crop(self):
-        from torch_em.loss import ApplyMask
-
-        # _crop batch_size=1
-        shape = (1, 1, 10, 128, 128)
-        p = torch.rand(*shape)
-        t = torch.rand(*shape)
-        m = torch.rand(*shape) > .5
-        p_masked, t_masked = ApplyMask()(p, t, m)
-        out_shape = (m.sum(), shape[1])
-        self.assertTrue(p_masked.shape == out_shape)
-        self.assertTrue(t_masked.shape == out_shape)
-
-        # _crop batch_size>1
-        shape = (5, 1, 10, 128, 128)
-        p = torch.rand(*shape)
-        t = torch.rand(*shape)
-        m = torch.rand(*shape) > .5
-        p_masked, t_masked = ApplyMask()(p, t, m)
-        out_shape = (m.sum(), shape[1])
-        self.assertTrue(p_masked.shape == out_shape)
-        self.assertTrue(t_masked.shape == out_shape)
-
-        # _crop n_channels>1
-        shape = (1, 2, 10, 128, 128)
-        p = torch.rand(*shape)
-        t = torch.rand(*shape)
-        m = torch.rand(*shape) > .5
-        with self.assertRaises(ValueError):
-            p_masked, t_masked = ApplyMask()(p, t, m)
-
-        # _crop different shapes
-        shape_pt = (5, 2, 10, 128, 128)
-        p = torch.rand(*shape_pt)
-        t = torch.rand(*shape_pt)
-        shape_m = (5, 1, 10, 128, 128)
-        m = torch.rand(*shape_m) > .5
-        p_masked, t_masked = ApplyMask()(p, t, m)
-        out_shape = (m.sum(), shape_pt[1])
-        self.assertTrue(p_masked.shape == out_shape)
-        self.assertTrue(t_masked.shape == out_shape)
-
-    def test_ApplyMask_output_shape_multiply(self):
-        from torch_em.loss import ApplyMask
-
-        # _multiply
-        shape = (2, 5, 10, 128, 128)
-        p = torch.rand(*shape)
-        t = torch.rand(*shape)
-        m = torch.rand(*shape) > .5
-
-        p_masked, t_masked = ApplyMask(masking_method="multiply")(p, t, m)
-        self.assertTrue(p_masked.shape == shape)
-        self.assertTrue(t_masked.shape == shape)
-
 
 if __name__ == '__main__':
     unittest.main()
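
For context: the deleted tests above exercised the crop-style masking that this revert removes. The sketch below reconstructs that behavior standalone from the removed ApplyMask._crop shown in the wrapper.py diff further down; the helper name crop_mask is hypothetical, not a torch_em API.

import torch

def crop_mask(prediction, target, mask, channel_dim=1):
    # crop-style masking: requires a mask with a singleton channel axis
    assert mask.shape[channel_dim] == 1
    mask = mask.type(torch.bool).squeeze(channel_dim)
    # move the channel axis to the end so boolean indexing flattens the rest
    prediction = prediction.moveaxis(channel_dim, -1)
    target = target.moveaxis(channel_dim, -1)
    # result has shape (number of unmasked elements, n_channels)
    return prediction[mask], target[mask]

p = torch.rand(5, 2, 10, 128, 128)
t = torch.rand(5, 2, 10, 128, 128)
m = torch.rand(5, 1, 10, 128, 128) > .5
p_masked, t_masked = crop_mask(p, t, m)
assert p_masked.shape == (int(m.sum()), 2)  # the out_shape checked in the deleted tests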
52 changes: 7 additions & 45 deletions torch_em/loss/wrapper.py
@@ -1,4 +1,3 @@
-import torch
 import torch.nn as nn
 
 
@@ -44,67 +43,30 @@ def forward(self, prediction, target, **kwargs):
 
 
 class ApplyMask:
-    def _crop(prediction, target, mask, channel_dim):
-        if mask.shape[channel_dim] != 1:
-            raise ValueError(
-                "_crop only supports a mask with a singleton channel axis. \
-                Please consider using masking_method=multiply."
-            )
-        mask = mask.type(torch.bool)
-        # remove singleton axis
-        mask = mask.squeeze(channel_dim)
-        # move channel axis to end
-        prediction = prediction.moveaxis(channel_dim, -1)
-        target = target.moveaxis(channel_dim, -1)
-        # output has shape N x C
-        # correct for torch_em.loss.dice.flatten_samples
-        return prediction[mask], target[mask]
-
-    def _multiply(prediction, target, mask, channel_dim):
+    def __call__(self, prediction, target, mask):
+        mask.requires_grad = False
         prediction = prediction * mask
         target = target * mask
         return prediction, target
-
-    MASKING_FUNCS = {
-        "crop": _crop,
-        "multiply": _multiply,
-    }
-
-    def __init__(self, masking_method="crop", channel_dim=1):
-        if masking_method not in self.MASKING_FUNCS.keys():
-            raise ValueError(f"{masking_method} is not available, please use one of {list(self.MASKING_FUNCS.keys())}.")
-        self.masking_func = self.MASKING_FUNCS[masking_method]
-        self.channel_dim = channel_dim
-
-        self.init_kwargs = {
-            "masking_method": masking_method,
-            "channel_dim": channel_dim,
-        }
-
-    def __call__(self, prediction, target, mask):
-        mask.requires_grad = False
-        return self.masking_func(prediction, target, mask, self.channel_dim)
 
 
-class ApplyAndRemoveMask(ApplyMask):
+class ApplyAndRemoveMask:
     def __call__(self, prediction, target):
         assert target.dim() == prediction.dim(), f"{target.dim()}, {prediction.dim()}"
         assert target.size(1) == 2 * prediction.size(1), f"{target.size(1)}, {prediction.size(1)}"
         assert target.shape[2:] == prediction.shape[2:], f"{str(target.shape)}, {str(prediction.shape)}"
         seperating_channel = target.size(1) // 2
         mask = target[:, seperating_channel:]
         target = target[:, :seperating_channel]
-        prediction, target = super().__call__(prediction, target, mask)
+        prediction, target = ApplyMask()(prediction, target, mask)
         return prediction, target
 
 
-class MaskIgnoreLabel(ApplyMask):
-    def __init__(self, ignore_label=-1, masking_method="crop", channel_dim=1):
-        super().__init__(masking_method, channel_dim)
+class MaskIgnoreLabel:
+    def __init__(self, ignore_label=-1):
         self.ignore_label = ignore_label
-        self.init_kwargs["ignore_label"] = ignore_label
 
     def __call__(self, prediction, target):
         mask = (target != self.ignore_label)
-        prediction, target = super().__call__(prediction, target, mask)
+        prediction, target = ApplyMask()(prediction, target, mask)
        return prediction, target
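
For reference, a minimal usage sketch of the restored, multiply-only masking, assuming imports directly from the wrapper module shown above; shapes and the ignore region are arbitrary illustrative choices.

import torch
from torch_em.loss.wrapper import ApplyMask, MaskIgnoreLabel

pred = torch.rand(2, 1, 32, 32)
target = torch.rand(2, 1, 32, 32)
target[:, :, :8] = -1  # region carrying the default ignore label

# MaskIgnoreLabel derives mask = (target != -1) and delegates to ApplyMask,
# which zeroes masked positions in both tensors by multiplication.
masked_pred, masked_target = MaskIgnoreLabel()(pred, target)
assert masked_pred.shape == pred.shape  # multiplicative masking preserves shape

Unlike the reverted crop variant, the multiplied-out positions still enter a mean-reduced loss as zeros; they are masked, not removed.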
