
Mis-implementation of JS divergence #27

Closed
greatwallet opened this issue Sep 9, 2021 · 2 comments

@greatwallet

greatwallet commented Sep 9, 2021

Hi, according to the definition of the JS divergence (as given in your supplementary file), the JS divergence is the difference between the entropy of the average of the probabilities and the average of the entropies:

JS(p_1, \dots, p_n) = H\Big(\frac{1}{n}\sum_{i=1}^{n} p_i\Big) - \frac{1}{n}\sum_{i=1}^{n} H(p_i)
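
For concreteness, here is a minimal sketch of this definition in PyTorch (the function and variable names are illustrative, not taken from the repository):

    import torch
    from torch.distributions import Categorical

    def js_divergence(prob_list):
        # prob_list: per-classifier probability maps of shape (num_pixels, num_classes),
        # already normalized (e.g. via softmax over the class dimension).
        mean_probs = torch.stack(prob_list).mean(dim=0)
        # First term: entropy of the average probabilities, H(mean of p_i).
        full_entropy = Categorical(probs=mean_probs).entropy()
        # Second term: average of the entropies, mean of H(p_i).
        mean_entropy = torch.stack(
            [Categorical(probs=p).entropy() for p in prob_list]
        ).mean(dim=0)
        return full_entropy - mean_entropy  # per-pixel JS divergence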

However, in your code the first term of the JS divergence, i.e. the entropy of the average probabilities, is implemented as:

    full_entropy = Categorical(logits=mean_seg).entropy()

where mean_seg is defined as the average segmentation map over the 10 outputs of the ensembled pixel_classifiers.

Specifically, I traced the implementation of mean_seg:

    mean_seg = mean_seg / len(all_seg)

where mean_seg is accumulated as

    if mean_seg is None:
        mean_seg = img_seg
    else:
        mean_seg += img_seg

and img_seg comes from

    img_seg = classifier(affine_layers)
    img_seg = img_seg.squeeze()
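
Putting the trace together, the quantity that ends up in Categorical(logits=...) is an average of raw logits, roughly as in the condensed sketch below (in forward order; classifiers and affine_layers are placeholders for the repo's variables):

    mean_seg = None
    all_seg = []
    for classifier in classifiers:            # the 10 ensembled pixel_classifiers
        img_seg = classifier(affine_layers)   # raw, unnormalized class scores
        img_seg = img_seg.squeeze()
        if mean_seg is None:
            mean_seg = img_seg
        else:
            mean_seg += img_seg
        all_seg.append(img_seg)
    mean_seg = mean_seg / len(all_seg)        # average of logits, not of probabilities
    full_entropy = Categorical(logits=mean_seg).entropy()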

In fact, the img_seg tensors are all unnormalized probabilities, i.e. logits in the sense of PyTorch's distribution arguments. It looks like the code averages over logits instead of probabilities (since nn.Sigmoid is commented out in pixel_classifier):

    class pixel_classifier(nn.Module):
        def __init__(self, numpy_class, dim):
            super(pixel_classifier, self).__init__()
            if numpy_class < 32:
                self.layers = nn.Sequential(
                    nn.Linear(dim, 128),
                    nn.ReLU(),
                    nn.BatchNorm1d(num_features=128),
                    nn.Linear(128, 32),
                    nn.ReLU(),
                    nn.BatchNorm1d(num_features=32),
                    nn.Linear(32, numpy_class),
                    # nn.Sigmoid()
                )
            else:
                self.layers = nn.Sequential(
                    nn.Linear(dim, 256),
                    nn.ReLU(),
                    nn.BatchNorm1d(num_features=256),
                    nn.Linear(256, 128),
                    nn.ReLU(),
                    nn.BatchNorm1d(num_features=128),
                    nn.Linear(128, numpy_class),
                    # nn.Sigmoid()
                )
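
One way to match the definition would be to normalize each classifier's output before averaging, e.g. with a softmax over the class dimension. The sketch below assumes the last dimension indexes classes and is not the repository's actual fix:

    import torch
    import torch.nn.functional as F
    from torch.distributions import Categorical

    all_probs = []
    for classifier in classifiers:                    # placeholder for the ensemble
        img_seg = classifier(affine_layers).squeeze()
        all_probs.append(F.softmax(img_seg, dim=-1))  # normalize per classifier
    mean_probs = torch.stack(all_probs).mean(dim=0)   # average of probabilities
    full_entropy = Categorical(probs=mean_probs).entropy()  # H(mean of p_i)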

TL;DR

Softmax does not commute with averaging (a linear operation), so applying softmax to the mean of the logits is not the same as taking the mean of the softmaxed probabilities. This leads to a mis-implementation of the JS divergence:

\mathrm{softmax}\Big(\frac{1}{n}\sum_{i=1}^{n} z_i\Big) \neq \frac{1}{n}\sum_{i=1}^{n} \mathrm{softmax}(z_i)
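
A quick numerical check of this non-commutation (made-up logits for a single 3-class pixel from two classifiers):

    import torch
    import torch.nn.functional as F

    z = torch.tensor([[ 2.0, 0.0, -1.0],
                      [-1.0, 3.0,  0.0]])
    p_of_mean = F.softmax(z.mean(dim=0), dim=-1)   # softmax of the averaged logits
    mean_of_p = F.softmax(z, dim=-1).mean(dim=0)   # average of the softmaxed logits
    print(torch.allclose(p_of_mean, mean_of_p))    # False: the two results differ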

@arieling
Collaborator

arieling commented Sep 9, 2021

Thank you very much for pointing this out. We are looking into it.

@arieling
Collaborator

@greatwallet Thank you again for pointing this bug out!
We have fixed the bug in commit d9564d4.
The numbers in the README have also been updated.
