Wasserstein loss function #7406

Merged
merged 11 commits into from Apr 2, 2019

2 participants
@rnett
Contributor

commented Mar 30, 2019

What changes were proposed in this pull request?

Added the Wasserstein loss function.

Forward pass (per-example score): mean(labels * output, dims=1)
Backward pass (gradient with respect to output): labels / labels.shape(1)
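
As a rough illustration of the two formulas above, here is a minimal plain-Java sketch using `double[][]` arrays in place of ND4J arrays. The class and method names (`WassersteinSketch`, `score`, `gradient`) are hypothetical and are not taken from this PR; the sketch also assumes an identity activation and no masking, which the actual loss class may handle differently.

```java
// Hypothetical sketch, not the code from this PR.
// Rows are examples, columns are output units.
class WassersteinSketch {

    // Forward pass: per-example mean over columns of labels * output.
    static double[] score(double[][] labels, double[][] output) {
        double[] result = new double[labels.length];
        for (int i = 0; i < labels.length; i++) {
            double sum = 0.0;
            for (int j = 0; j < labels[i].length; j++) {
                sum += labels[i][j] * output[i][j];
            }
            result[i] = sum / labels[i].length;
        }
        return result;
    }

    // Backward pass: d(score_i)/d(output_ij) = labels_ij / numColumns.
    static double[][] gradient(double[][] labels) {
        double[][] grad = new double[labels.length][];
        for (int i = 0; i < labels.length; i++) {
            int n = labels[i].length;
            grad[i] = new double[n];
            for (int j = 0; j < n; j++) {
                grad[i][j] = labels[i][j] / n;
            }
        }
        return grad;
    }

    public static void main(String[] args) {
        // Assumed labelling convention for illustration: -1 for one class, +1 for the other.
        double[][] labels = {{-1, -1}, {1, 1}};
        double[][] output = {{0.5, 2.0}, {1.0, 3.0}};
        double[] s = score(labels, output);
        double[][] g = gradient(labels);
        System.out.println(s[0] + " " + s[1]);
        System.out.println(g[0][0] + " " + g[1][1]);
    }
}
```

Because the score is linear in the output, the gradient depends only on the labels, which is why the backward pass does not reference the output at all.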

How was this patch tested?

None. Should I add tests? I didn't see tests for any other losses except for cross-entropy.

Note that the gradient doesn't match SameDiff's because of #7405.

rnett added some commits Mar 29, 2019

@rnett
Contributor Author

commented Mar 30, 2019

Note that the improved Wasserstein loss is better, but it requires second-order gradients, which I don't think we can calculate? Specifically, it adds a term to the loss that penalizes the deviation of the gradient norm from 1.
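
For reference, the comment above appears to refer to the gradient-penalty variant of the Wasserstein loss. That reading is an assumption; the comment does not name the variant. In its commonly published form, the critic loss is:

$$
L = \mathbb{E}_{\tilde{x} \sim P_g}\big[D(\tilde{x})\big] - \mathbb{E}_{x \sim P_r}\big[D(x)\big] + \lambda \, \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big]
$$

Here $D$ is the critic, $P_r$ and $P_g$ are the real and generated distributions, $\hat{x}$ is sampled on lines between real and generated points, and $\lambda$ is the penalty weight. Minimizing the last term with respect to the critic's parameters means differentiating through $\nabla_{\hat{x}} D(\hat{x})$, which is where the second-order gradient comes from.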

rnett added some commits Mar 30, 2019

@rnett
Contributor Author

commented Apr 1, 2019

Tests are in and the gradients pass.

The tests fail the serialization test because of an extra space at the end of the expected serialized string here:

	at org.junit.Assert.failNotEquals(Assert.java:834)
	at org.junit.Assert.assertEquals(Assert.java:118)
	at org.junit.Assert.assertEquals(Assert.java:144)
	at org.deeplearning4j.TestUtils.testModelSerialization(TestUtils.java:63)
	at org.deeplearning4j.gradientcheck.LossFunctionGradientCheck.lossFunctionGradientCheckLossLayer(LossFunctionGradientCheck.java:373)
@AlexDBlack
Member

commented Apr 2, 2019

The tests fail the serialization test because of an extra space at the end of the expected serialized string

I haven't seen that before in any of the tests :/
I'll check out and build locally

@AlexDBlack
Member

commented Apr 2, 2019

OK, so the problem here is the equals method in DifferentialFunction.
When the equality check fails, the assertion message is built with toString, which returns the network JSON, so the extra space is a red herring and not the cause of the failure.
Let's just switch to @EqualsAndHashCode(callSuper = false): the superclass fields aren't relevant or used, so we can ignore them. The test passes with that.
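
To illustrate the effect of that change, here is a hand-written sketch of roughly what Lombok generates for `@EqualsAndHashCode(callSuper = false)`: equality is decided only by fields declared in the subclass, and superclass state is ignored. All names here (`EqualsSketch`, `BaseFunction`, `SketchLoss`, `instanceId`, `weight`) are hypothetical stand-ins, not the DL4J classes or fields, and the real generated code includes extra details such as a `canEqual` check.

```java
// Hypothetical sketch, not the DL4J classes from this PR.
class EqualsSketch {

    // Stand-in for a superclass carrying per-instance bookkeeping state.
    static class BaseFunction {
        final String instanceId; // differs between otherwise identical objects

        BaseFunction(String instanceId) {
            this.instanceId = instanceId;
        }
    }

    // Stand-in for a loss class. equals/hashCode are written by hand to mirror
    // callSuper = false: only fields declared in this class are compared.
    static class SketchLoss extends BaseFunction {
        final double weight;

        SketchLoss(String instanceId, double weight) {
            super(instanceId);
            this.weight = weight;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof SketchLoss)) return false;
            SketchLoss other = (SketchLoss) o;
            // No super.equals(...) call: instanceId is deliberately ignored.
            return Double.compare(weight, other.weight) == 0;
        }

        @Override
        public int hashCode() {
            return Double.hashCode(weight);
        }
    }

    public static void main(String[] args) {
        SketchLoss original = new SketchLoss("a", 1.0);
        SketchLoss restored = new SketchLoss("b", 1.0); // e.g. after a save/load round trip
        System.out.println(original.equals(restored));
    }
}
```

With this shape, two instances that differ only in superclass state compare as equal, which is the behaviour a serialization round-trip test needs.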

@AlexDBlack AlexDBlack merged commit 66c8a66 into deeplearning4j:master Apr 2, 2019

1 check was pending

continuous-integration/jenkins/pr-head This commit is being built