implement gradient masking with cropping #90
Conversation
torch_em/loss/wrapper.py
Outdated
```python
n_channels = prediction.shape[channel_dim]
# for now only support same mask for all channels
# to be able to stack/concat afterwards
# TODO: make this more efficient, and for different
```
For the case where we have a mask of lower dimensionality this can be done much more simply (I am using numpy here for simplicity, but the same works with torch, I just checked):

```python
x = np.random.rand(128, 128, 3)
y = np.random.rand(128, 128) > 0.5
z = x[y]
print(z.shape)
# (8186, 3) -- 8186 is the number of elements in the mask, i.e. it's not always the same number
```

So you just need to remove the singleton channel dimension in the mask, swap the channel dimension to the back in the prediction / target, apply the mask, and bring the channel dimension back to its original position.
I still need to think about the case where we have a full mask.
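The procedure described above can be sketched end-to-end like this (a numpy sketch, in the spirit of the snippet above; the reviewer notes the same indexing works with torch — array names and sizes here are illustrative):

```python
import numpy as np

np.random.seed(0)
n, c = 2, 3
prediction = np.random.rand(n, c, 16, 16)   # N x C x spatial
mask = np.random.rand(n, 1, 16, 16) > 0.5   # single mask channel

channel_dim = 1
# 1. drop the singleton channel dimension of the mask
mask_ = mask.squeeze(channel_dim)                  # N x spatial
# 2. move the channel dimension of the prediction to the back
pred_ = np.moveaxis(prediction, channel_dim, -1)   # N x spatial x C
# 3. apply the boolean mask; batch and spatial dims are flattened together
cropped = pred_[mask_]                             # n_masked x C
# 4. bring the channel dimension back to the front
cropped = np.moveaxis(cropped, -1, 0)              # C x n_masked
print(cropped.shape)
```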
True, I'll implement it. Thank you :)
done
Looks good now. We should improve the error that is thrown when the mask channel dimension is not a singleton. Also, could you add a test for `ApplyMask` here https://github.com/constantinpape/torch-em/blob/main/test/loss/test_loss_wrapper.py that checks that the shape after calling it is correct (N x C x flattened spatial)?
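A sketch of what such a shape test could check. The `apply_mask_crop` helper below is a hypothetical stand-in that only mirrors the crop behavior discussed in this thread (the real class lives in torch_em/loss/wrapper.py); note that boolean indexing flattens batch and spatial dimensions together, so the leading dimension becomes 1:

```python
import numpy as np

def apply_mask_crop(prediction, target, mask, channel_dim=1):
    # hypothetical stand-in for ApplyMask's crop-style masking
    mask_ = mask.squeeze(channel_dim).astype(bool)

    def crop(x):
        x = np.moveaxis(x, channel_dim, -1)   # N x spatial x C
        x = x[mask_]                          # n_masked x C
        return np.moveaxis(x, -1, 0)[None]    # 1 x C x n_masked

    return crop(prediction), crop(target)

shape = (4, 2, 32, 32)
p = np.random.rand(*shape)
t = np.random.rand(*shape)
m = np.random.rand(4, 1, 32, 32) > 0.5

p_cropped, t_cropped = apply_mask_crop(p, t, m)
n_masked = int(m.sum())
print(p_cropped.shape)  # (1, 2, n_masked)
```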
torch_em/loss/wrapper.py
Outdated
```python
mask = mask.type(torch.bool)
# TODO: make this work for different
# masks per channel
assert mask.shape[channel_dim] == 1
```
The case where we have a full mask is complicated, and I can't see an option right now to make it work while also preserving the N x C structure. I think we can go ahead with the current solution, but we should throw a more meaningful error message here. I suggest throwing a `ValueError` that also indicates that using `masking_method=multiply` will fix the issue.
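The suggested error handling could look something like this (an illustrative sketch; the function name and exact wording are placeholders, not the implementation that was merged):

```python
def check_mask_channels(mask_shape, channel_dim=1):
    # raise a meaningful ValueError instead of a bare assertion
    # when the mask has more than one channel
    n_mask_channels = mask_shape[channel_dim]
    if n_mask_channels != 1:
        raise ValueError(
            f"Crop masking expects a mask with a singleton channel dimension, "
            f"got {n_mask_channels} channels. "
            "Use masking_method='multiply' to support multi-channel masks."
        )
```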
done
```python
@@ -32,6 +32,50 @@ def test_masking(self):
    # print((grad[~mask] == 0).sum())
    self.assertTrue((grad[~mask] == 0).all())

def test_ApplyMask_output_shape_crop(self):
```
Can you add a test where the number of channels in `p` and `t` is > 1, but the mask only has a single channel?
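A sketch of the requested case, again in numpy (illustrative names and sizes): three prediction channels, a single mask channel, verifying that the same spatial mask is applied to every channel.

```python
import numpy as np

np.random.seed(42)
p = np.random.rand(2, 3, 8, 8)          # 3 channels
m = np.random.rand(2, 1, 8, 8) > 0.5    # single mask channel

mask_ = m.squeeze(1)
cropped = np.moveaxis(np.moveaxis(p, 1, -1)[mask_], -1, 0)  # C x n_masked

# the same spatial mask is applied to every channel
for c in range(3):
    np.testing.assert_array_equal(cropped[c], p[:, c][mask_])
print(cropped.shape)
```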
done
Force-pushed from e242f70 to 66e1ae9.
current state of the gradient masking with cropping.
TODO: allow for a different mask in each channel
TODO: make it more efficient