Skip to content

Commit

Permalink
Merge pull request #86 from constantinpape/update-loss-masking
Browse files Browse the repository at this point in the history
Use ApplyMask in other masking loss wrappers
  • Loading branch information
constantinpape committed Sep 7, 2022
2 parents 5ce9eff + 9fdc23f commit c4d8ce3
Showing 1 changed file with 11 additions and 16 deletions.
27 changes: 11 additions & 16 deletions torch_em/loss/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,27 +41,24 @@ def forward(self, prediction, target, **kwargs):
# Loss transformations
#


class ApplyMask:
def __call__(self, prediction, target, mask):
mask.requires_grad = False
prediction = prediction * mask
target = target * mask
return prediction, target


class ApplyAndRemoveMask:
def __call__(self, prediction, target):
assert target.dim() == prediction.dim(), f"{target.dim()}, {prediction.dim()}"
assert target.size(1) == 2 * prediction.size(1), f"{target.size(1)}, {prediction.size(1)}"
assert target.shape[2:] == prediction.shape[2:], f"{str(target.shape)}, {str(prediction.shape)}"

seperating_channel = target.size(1) // 2
mask = target[:, seperating_channel:]
target = target[:, :seperating_channel]
mask.requires_grad = False

# mask the prediction
prediction = prediction * mask
return prediction, target


class ApplyMask:
def __call__(self, prediction, target, mask):
mask.requires_grad = False
prediction = prediction * mask
target = target * mask
prediction, target = ApplyMask()(prediction, target, mask)
return prediction, target


Expand All @@ -71,7 +68,5 @@ def __init__(self, ignore_label=-1):

def __call__(self, prediction, target):
mask = (target != self.ignore_label)
mask.requires_grad = False
prediction = prediction * mask
target = target * mask
prediction, target = ApplyMask()(prediction, target, mask)
return prediction, target

0 comments on commit c4d8ce3

Please sign in to comment.