-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add 2D Probabilistic UNet #119
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! I just had a brief initial look, and found one obvious thing that needs to be change. I will have a closer look at the rest in the next few days.
torch_em/model/probabilistic_unet.py
Outdated
from torch_em.model import UNet2d | ||
from torch_em.loss.dice import DiceLossWithLogits | ||
|
||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't define this globally. Instead pass this as an argument to the functions or arguments that need it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had a bit closer look and left a few more comments. In addition to addressing those we should also add a test, see for example https://github.com/constantinpape/torch-em/blob/main/test/model/test_unet.py.
torch_em/model/probabilistic_unet.py
Outdated
num_classes=None | ||
): | ||
|
||
super(Encoder, self).__init__() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This style of writing the super
calls is outdated. It's better to just write super().__init__()
instead.
torch_em/model/probabilistic_unet.py
Outdated
num_classes=None | ||
): | ||
|
||
super(AxisAlignedConvGaussian, self).__init__() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See above.
torch_em/model/probabilistic_unet.py
Outdated
use_tile=True | ||
): | ||
|
||
super(Fcomb, self).__init__() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See above.
torch_em/model/probabilistic_unet.py
Outdated
rl_swap=False | ||
): | ||
|
||
super(ProbabilisticUnet, self).__init__() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See above.
torch_em/model/probabilistic_unet.py
Outdated
""" | ||
|
||
def __init__( | ||
self, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a minor thing, but the usual style for these inits is to have only one indent:
def __init__(
self,
aa,
bb,
cc,
dd,
...
):
...
(also relevant in several places below).
torch_em/model/probabilistic_unet.py
Outdated
return self.last_layer(output) | ||
|
||
|
||
class ProbabilisticUnet(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's call it ProbabilisticUNet
to be consistent with the other UNet naming.
Thanks for the feedback. The discussions for naming consistency (#119 (comment)), updating the indent (#119 (comment)) and style of (I'll work on adding a test for |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One more minor comment, besides this the changes look good.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good now!
torch_em.model.UNet2d
backbone