implement gradient masking with cropping #90

Closed · wants to merge 6 commits

Conversation

JonasHell (Contributor):

Current state of the gradient masking with cropping.

TODO: allow for a different mask in each channel
TODO: make it more efficient

torch_em/loss/wrapper.py
n_channels = prediction.shape[channel_dim]
# for now only support same mask for all channels
# to be able to stack/concat afterwards
# TODO: make this more efficient, and for different
constantinpape (Owner):

For the case where we have a mask of lower dimensionality this can be done much more simply (I am using numpy here for simplicity, but the same works with torch, I just checked):

import numpy as np

x = np.random.rand(128, 128, 3)
y = np.random.rand(128, 128) > 0.5
z = x[y]
print(z.shape)
# prints (8186, 3) -- 8186 is the number of elements in the mask, i.e. it's not always the same number

So you just need to remove the singleton channel dimension from the mask, swap the channel dimension to the back in the prediction / target, apply the mask, and bring the channel dimension back to its original position.

I still need to think about the case where we have a full mask.
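
A minimal torch sketch of the singleton-mask procedure described above; the tensor shapes and variable names are illustrative assumptions, not the actual ApplyMask code:

import torch

prediction = torch.rand(2, 3, 128, 128)   # (N, C, H, W)
mask = torch.rand(2, 1, 128, 128) > 0.5   # mask with a singleton channel dimension
m = mask[:, 0]                            # drop the singleton channel -> (N, H, W)
p = prediction.permute(0, 2, 3, 1)        # move channels to the back -> (N, H, W, C)
cropped = p[m]                            # boolean indexing -> (n_masked, C), n_masked varies with the mask
cropped = cropped.transpose(0, 1)         # channel dimension back to the front -> (C, n_masked)
print(cropped.shape)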

JonasHell (Contributor, Author):

True, I'll implement it. Thank you :)

JonasHell (Contributor, Author):

done

@constantinpape (Owner) left a comment:

Looks good now. We should improve the error that is thrown when the mask is not a singleton. Also, could you add a test for ApplyMask here https://github.com/constantinpape/torch-em/blob/main/test/loss/test_loss_wrapper.py that checks that the shape after calling it is correct (N x C x flattened spatial)?
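
A sketch of what such a test could look like; the ApplyMask constructor argument and call signature are assumptions here and would need to match the actual wrapper. The mask is repeated across the batch so the cropped output stays rectangular:

import unittest
import torch
from torch_em.loss.wrapper import ApplyMask

class TestApplyMaskShape(unittest.TestCase):
    def test_ApplyMask_output_shape_crop(self):
        N, C, H, W = 2, 1, 32, 32
        prediction = torch.rand(N, C, H, W)
        target = torch.rand(N, C, H, W)
        # same spatial mask for every sample, singleton channel dimension
        mask = (torch.rand(1, 1, H, W) > 0.5).expand(N, 1, H, W)
        p, t = ApplyMask(masking_method="crop")(prediction, target, mask)
        # expected: batch and channel dimensions preserved, spatial dimensions flattened
        self.assertEqual(p.dim(), 3)
        self.assertEqual(tuple(p.shape[:2]), (N, C))
        self.assertEqual(p.shape, t.shape)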

mask = mask.type(torch.bool)
# TODO: make this work for different
# masks per channel
assert mask.shape[channel_dim] == 1
constantinpape (Owner):

The case where we have a full mask is complicated, and I can't see an option right now to make it work while also preserving the N x C structure. I think we can go ahead with the current solution, but we should throw a more meaningful error message here. I suggest raising a ValueError, which should also indicate that using masking_method=multiply will fix the issue.
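
One possible shape for that error, as a sketch; the variable names follow the snippet above, the exact message is not the final implementation:

# replaces the bare assert on the mask's channel dimension
if mask.shape[channel_dim] != 1:
    raise ValueError(
        "ApplyMask with masking_method='crop' only supports a mask with a singleton channel dimension, "
        f"got {mask.shape[channel_dim]} channels. Use masking_method='multiply' for channel-wise masks."
    )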

JonasHell (Contributor, Author):

done

@@ -32,6 +32,50 @@ def test_masking(self):
# print((grad[~mask] == 0).sum())
self.assertTrue((grad[~mask] == 0).all())

def test_ApplyMask_output_shape_crop(self):
constantinpape (Owner):

Can you add a test where the number of channels in p and t is > 1, but the mask only has a single channel?
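
Sketched in the same style as the test above (same assumed imports and ApplyMask signature), the multi-channel case could look like:

def test_ApplyMask_output_shape_crop_multichannel(self):
    N, C, H, W = 2, 3, 32, 32
    prediction = torch.rand(N, C, H, W)
    target = torch.rand(N, C, H, W)
    # prediction / target have several channels, the mask has a single one
    mask = (torch.rand(1, 1, H, W) > 0.5).expand(N, 1, H, W)
    p, t = ApplyMask(masking_method="crop")(prediction, target, mask)
    self.assertEqual(tuple(p.shape[:2]), (N, C))
    self.assertEqual(p.shape, t.shape)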

JonasHell (Contributor, Author):

done

@JonasHell deleted the branch constantinpape:main on December 9, 2022 15:50
@JonasHell closed this on December 9, 2022
@JonasHell deleted the main branch on December 9, 2022 15:50
@JonasHell restored the main branch on December 9, 2022 15:52