isda loss segmentation code #10
Yes. This code maps the label "255" to "19".
```python
labels = ((1 - label_mask).mul(labels) + label_mask * C).long()
onehot = torch.zeros(N, C).cuda()
NxCxA_onehot = onehot.view(N, C, 1).expand(N, C, A)
```

255 to 19? Won't this cause an error? This makes the label range [0, 19], and it says our num_class is 20, so building the one-hot encoding will cause an error.
Thank you for the comment, but we have to point out that this does not cause an error! Since the label 255 should be ignored in segmentation (and CrossEntropyLoss() handles this automatically), here we simply treat it as the 20th class, so that the off-the-shelf code (e.g. EstimateCV) can be used directly. Otherwise, the indexing function would report an error if the labels contained both [0, 18] and 255. I agree that there is still room to improve the implementation. However, first, the current code is correct. Second, we observe only a small additional time cost (~5%) with it.
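The remapping argument above can be illustrated with a minimal pure-Python sketch (not the repository's PyTorch code). It assumes 19 real classes with IDs 0..18, ignore label 255, and C = 20, so that 255 lands on the extra index 19. The helper names `remap_ignore` and `one_hot` are hypothetical.

```python
# Hypothetical sketch of the label remap discussed in this thread.
# Assumptions: 19 real classes (IDs 0..18), ignore label 255, C = 20.
IGNORE_LABEL = 255
NUM_CLASSES = 19          # real classes, IDs 0..18
C = NUM_CLASSES + 1       # one extra slot, index 19, for the remapped ignore label


def remap_ignore(labels, ignore_label=IGNORE_LABEL, target=NUM_CLASSES):
    """Element-wise version of ((1 - label_mask) * labels + label_mask * target)."""
    out = []
    for y in labels:
        mask = 1 if y == ignore_label else 0
        out.append((1 - mask) * y + mask * target)
    return out


def one_hot(labels, num_columns):
    """N x num_columns one-hot rows; raises IndexError if a label is >= num_columns."""
    rows = []
    for y in labels:
        row = [0] * num_columns
        row[y] = 1  # out-of-range labels (e.g. a raw 255) fail here
        rows.append(row)
    return rows


labels = [0, 18, 255, 7, 255]
remapped = remap_ignore(labels)
print(remapped)           # [0, 18, 19, 7, 19]
print(max(remapped) < C)  # True: every label fits in a width-20 one-hot

onehot = one_hot(remapped, C)  # succeeds: 5 rows of width 20

# The two failure modes: raw 255 labels, or only 19 columns after remapping.
for raw_labels, width in ((labels, C), (remapped, NUM_CLASSES)):
    try:
        one_hot(raw_labels, width)
    except IndexError:
        print("IndexError for width", width)
```

With C = 20 the largest remapped label is 19, so the one-hot indexing stays in range; an error appears only if 255 is left untouched or if the extra column is dropped.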
Another important point is that this remapping is used only in ISDAloss; it does not happen when computing the cross-entropy loss.
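To illustrate the separation described above, here is a hedged pure-Python sketch: the cross-entropy term works on the original labels and skips 255 (similar to passing `ignore_index=255` to PyTorch's `CrossEntropyLoss`), while only a separate copy of the labels is remapped for the ISDA branch. The function name `cross_entropy_ignore` and the toy value of 3 classes are assumptions for illustration only.

```python
import math

IGNORE_LABEL = 255
NUM_CLASSES = 3  # toy value for this example (the Cityscapes setup uses 19)


def cross_entropy_ignore(logits, labels, ignore_index=IGNORE_LABEL):
    """Mean cross-entropy over entries whose label != ignore_index.

    Returns 0.0 if every entry is ignored (an arbitrary choice for this sketch).
    """
    total, count = 0.0, 0
    for scores, y in zip(logits, labels):
        if y == ignore_index:
            continue  # ignored pixels contribute nothing to the loss
        m = max(scores)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_sum_exp - scores[y]
        count += 1
    return total / count if count else 0.0


labels = [0, IGNORE_LABEL, 2]
logits = [[0.0, 0.0, 0.0], [5.0, 1.0, 1.0], [0.0, 0.0, 0.0]]

# The cross-entropy term sees the ORIGINAL labels, so 255 is simply skipped.
ce = cross_entropy_ignore(logits, labels)

# Only the ISDA branch works on a remapped copy (255 -> extra index NUM_CLASSES).
isda_labels = [NUM_CLASSES if y == IGNORE_LABEL else y for y in labels]

print(ce)           # log(3), about 1.0986
print(labels)       # [0, 255, 2]  (unchanged)
print(isda_labels)  # [0, 3, 2]
```

Because the remapped list is a new object, the extra class index never reaches the cross-entropy computation.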
Thank you for your reply, I understand what you are saying. But I do not think your current code is correct: as I said above, it will cause an index out of range error.
Thank you for the discussion. C is 20, actually. Please see this code:
I believe the best way to check whether there is a bug is to actually run the code. Maybe you can try it.
Thanks. I tried to run your code, but setting up the pytorch-segmentation-toolbox environment failed. Thank you again for your reply.
https://github.com/blackfeather-wang/ISDA-for-Deep-Networks/blob/master/Semantic%20segmentation%20on%20Cityscapes/train_isda.py#L180
```python
labels = ((1 - label_mask).mul(labels) + label_mask * 19).long()
```
Is 19 the num_class?