
[WIP] Capsnet layers #7391

Merged (45 commits, merged into eclipse:master on Apr 5, 2019)

Conversation

@rnett (Contributor) commented Mar 29, 2019

What changes were proposed in this pull request?

From #7386

Adds DigitCaps, PrimaryCaps, and a capsule-strength layer from Dynamic Routing Between Capsules (as CapsuleLayer, PrimaryCapsules, and CapsuleStrengthLayer, respectively), as well as a utility class.

All are implemented using SameDiff.
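
For illustration, the layers are meant to be used roughly like this (a sketch only; the builder signatures and hyperparameters are assumptions based on the WIP code and the test network in the gradient-check output below, and may change):

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;

public class CapsNetConfigSketch {
    public static MultiLayerConfiguration mnistConfig() {
        return new NeuralNetConfiguration.Builder()
                .seed(123)
                .list()
                // conventional conv feature extractor in front of the capsule layers
                .layer(new ConvolutionLayer.Builder().nOut(16).kernelSize(9, 9).stride(3, 3).build())
                // primary capsules: 8 channels of 8-dimensional capsules
                .layer(new PrimaryCapsules.Builder(8, 8).kernelSize(7, 7).stride(2, 2).build())
                // 10 output capsules of 16 dimensions, 3 routing iterations
                .layer(new CapsuleLayer.Builder(10, 16, 3).build())
                // capsule length ("strength") as the per-class score
                .layer(new CapsuleStrengthLayer.Builder().build())
                .layer(new ActivationLayer.Builder().activation(Activation.SOFTMAX).build())
                .layer(new LossLayer.Builder(new LossNegativeLogLikelihood()).build())
                .setInputType(InputType.convolutionalFlat(28, 28, 1))
                .build();
    }
}

This mirrors the small test network reported in the gradient-check log below (1312 / 50240 / 10240 parameters for the first three layers).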

How was this patch tested?

None yet. Will need tests for all 3 layers, as well as bad configuration tests.

TODO

  • CapsuleLayer
  • PrimaryCapsules
  • CapsuleStrengthLayer
  • Docstrings
  • Tests
  • Arbiter?
rnett added 24 commits Mar 29, 2019
add a keepDim option to point SDIndexes, which if used will only contract the dimension to 1 instead of removing it
@rnett (Contributor, Author) commented Mar 30, 2019

Planned tests:

  • CapsuleLayer
    • outputType
    • input type inference
    • config
    • output shape of single layer network
  • PrimaryCapsules
    • outputType
    • input type inference
    • config
    • output shape of single layer network
  • CapsuleStrength
    • outputType
    • config
    • output shape of single layer network
  • MNIST, 2 epochs, >95% accuracy
  • Gradient Checks
@treo commented Mar 30, 2019

Maybe also add gradient checks for your layers here. As it is still early days for SameDiff, a more sophisticated use of it, like your layers, may surface some bugs.

Gradient checks for layers are pretty simple to do; for an example see https://github.com/deeplearning4j/deeplearning4j/pull/7311/files#diff-956c3d9bc48af5c418208bce04cd0027
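
Roughly, such a check looks like this (a sketch assuming a small double-precision capsnet MultiLayerNetwork built elsewhere; shapes and thresholds here are illustrative):

import org.deeplearning4j.gradientcheck.GradientCheckUtil;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class CapsNetGradCheckSketch {
    // "net" is assumed to be a small capsnet MultiLayerNetwork built in double precision
    // (gradient checks are meaningless in float32).
    public static boolean check(MultiLayerNetwork net) {
        // small random inputs (flattened 28x28), scaled down a bit for numerical stability
        INDArray input = Nd4j.rand(new int[]{8, 784}).castTo(DataType.DOUBLE).muli(0.1);
        INDArray labels = Nd4j.zeros(8, 10).castTo(DataType.DOUBLE);
        for (int i = 0; i < 8; i++) {
            labels.putScalar(i, i % 10, 1.0);   // simple one-hot labels
        }

        // epsilon, maxRelError, minAbsError, print, exitOnFirstFailure, input, labels
        return GradientCheckUtil.checkGradients(net, 1e-6, 1e-3, 1e-8, true, false, input, labels);
    }
}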

@rnett (Contributor, Author) commented Mar 30, 2019

I will do that. I ran into a couple of bugs and reported them, but what I have currently seems to be working: I can get >95% accuracy on MNIST with 2 epochs and a very scaled-down network compared to the original CapsNet. Of course, it's possible there is an issue somewhere in there that just isn't big enough to be detected, but it looks fairly good.

rnett added 9 commits Mar 30, 2019
@rnett (Contributor, Author) commented Apr 1, 2019

I'm getting errors in the gradient check. However, the reported gradients are relatively close, so I'm not sure if it's my code or compounding SameDiff errors.

Report:

nParams, layer 0: 1312
nParams, layer 1: 50240
nParams, layer 2: 10240
nParams, layer 3: 0
minibatch=64, PrimaryCaps: 8 channels, 8 dimensions, Capsules: 10 capsules with 16 dimensions and 3 routings
[main] INFO org.deeplearning4j.gradientcheck.GradientCheckUtil - Layer 0: ConvolutionLayer - params [b, W]
[main] INFO org.deeplearning4j.gradientcheck.GradientCheckUtil - Layer 1: SameDiffLayer - params [weight, bias]          // PrimaryCapsules, these are parameters for the convolution
[main] INFO org.deeplearning4j.gradientcheck.GradientCheckUtil - Layer 2: SameDiffLayer - params [weight]                // CapsuleLayer
[main] INFO org.deeplearning4j.gradientcheck.GradientCheckUtil - Layer 3: SameDiffLayer - params []                      // CapsuleStrengthLayer
[main] INFO org.deeplearning4j.gradientcheck.GradientCheckUtil - Layer 4: ActivationLayer - params []
[main] INFO org.deeplearning4j.gradientcheck.GradientCheckUtil - Layer 5: LossLayer - params []
[main] INFO org.deeplearning4j.gradientcheck.GradientCheckUtil - Param 0 (0_b) FAILED: grad= -2.379222777046915E-4, numericalGrad= -2.300193369109138E-4, relError= 0.016888732583165585, scorePlus=2.3062806561712503, scoreMinus= 2.306280656631289, paramValue = 0.0
[main] INFO org.deeplearning4j.gradientcheck.GradientCheckUtil - Param 1 (0_b) passed: grad= 5.929871790586867E-4, numericalGrad= 5.91871440747127E-4, relError= 9.416636659507323E-4
[main] INFO org.deeplearning4j.gradientcheck.GradientCheckUtil - Param 2 (0_b) FAILED: grad= -3.268174111421275E-4, numericalGrad= -3.2263214322370004E-4, relError= 0.006444331034323755, scorePlus=2.3062806560786377, scoreMinus= 2.306280656723902, paramValue = 0.0
[main] INFO org.deeplearning4j.gradientcheck.GradientCheckUtil - Param 3 (0_b) FAILED: grad= 1.9390171313378663E-4, numericalGrad= 1.9152812669176456E-4, relError= 0.006158284068240222, scorePlus=2.306280656592798, scoreMinus= 2.3062806562097418, paramValue = 0.0
[main] INFO org.deeplearning4j.gradientcheck.GradientCheckUtil - Param 4 (0_b) FAILED: grad= 4.883863696451556E-4, numericalGrad= 4.786830931635677E-4, relError= 0.01003369132699733, scorePlus=2.306280656879953, scoreMinus= 2.306280655922587, paramValue = 0.0
[main] INFO org.deeplearning4j.gradientcheck.GradientCheckUtil - Param 5 (0_b) FAILED: grad= -5.9985731824785914E-5, numericalGrad= -6.629985449535525E-5, relError= 0.049998759593693244, scorePlus=2.3062806563349705, scoreMinus= 2.30628065646757, paramValue = 0.0
[main] INFO org.deeplearning4j.gradientcheck.GradientCheckUtil - Param 6 (0_b) FAILED: grad= 2.681832847767592E-4, numericalGrad= 2.6690916143934373E-4, relError= 0.0023811274975481696, scorePlus=2.3062806566681795, scoreMinus= 2.306280656134361, paramValue = 0.0
[main] INFO org.deeplearning4j.gradientcheck.GradientCheckUtil - Param 7 (0_b) FAILED: grad= 7.39580096531447E-4, numericalGrad= 7.302829452271453E-4, relError= 0.006325182034088197, scorePlus=2.306280657131553, scoreMinus= 2.306280655670987, paramValue = 0.0
[main] INFO org.deeplearning4j.gradientcheck.GradientCheckUtil - Param 8 (0_b) FAILED: grad= -6.130691658394211E-6, numericalGrad= -8.705702825295702E-6, relError= 0.17356044082895455, scorePlus=2.3062806563925644, scoreMinus= 2.306280656409976, paramValue = 0.0
[main] INFO org.deeplearning4j.gradientcheck.GradientCheckUtil - Param 9 (0_b) FAILED: grad= 4.509005322153342E-4, numericalGrad= 4.4923442743538544E-4, relError= 0.0018509499737631159, scorePlus=2.3062806568505048, scoreMinus= 2.306280655952036, paramValue = 0.0
[main] INFO org.deeplearning4j.gradientcheck.GradientCheckUtil - Param 10 (0_b) FAILED: grad= 1.9728276810339123E-4, numericalGrad= 1.9635360004599534E-4, relError= 0.0023604730979615053, scorePlus=2.3062806565976235, scoreMinus= 2.3062806562049163, paramValue = 0.0
rnett added 2 commits Apr 1, 2019
@rnett (Contributor, Author) commented Apr 1, 2019

Alright, so the tests pass if you set routings to 1. Routings is the loop limit for the internal SameDiff loop, so either this is a SameDiff compounding-errors issue, or there is something in the post-short-circuit operations that causes this.

Tests should now all be passing; I'm just adding a few comments.
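
For context, the routing loop in question is roughly the procedure from the paper:

\begin{aligned}
c_{ij} &= \operatorname{softmax}_j(b_{ij}) \\
s_j &= \sum_i c_{ij}\,\hat{u}_{j|i} \\
v_j &= \operatorname{squash}(s_j) \\
b_{ij} &\leftarrow b_{ij} + \hat{u}_{j|i} \cdot v_j
\end{aligned}

repeated `routings` times, with the logits b initialized to zero and only the last iteration's v returned.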

@treo commented Apr 1, 2019

Make sure that you've got the latest master merged into this branch and run the gradient checks again. There was an actual bug in SameDiff where, under some circumstances, recurrent gradients were not added into the total (#7393); you should have those changes on your branch for the fix.

rnett added 2 commits Apr 1, 2019
@rnett (Contributor, Author) commented Apr 1, 2019

It is in there. No change in the test.

@AlexDBlack (Contributor) left a comment

Overall this looks great. I can't see any obvious errors here in the math or implementation (though I'm not as familiar with CapsNet as with other layers/architectures).

The fact that it always fails with 2 routings is a little concerning, but I agree it's probably just numerical precision issues... there are no systematic deviations from the expected gradient values. We've got a whole lot of L2 norms and softmax in there (I've seen similar behaviour with L2 norm and gradient checks before). And the whole point of the squashing function is to push some vectors towards 0 length/norm, which will definitely compound gradient check issues.

I played around with the scale of the inputs/weights and values like eps; I could get it closer to passing, but not much better.
If we want to keep the routings=2 gradient check tests, the only thing I can think to do is split up the architecture to check the individual layers - so test primary caps, then the capsule layer, etc. separately from the others. Fewer layers in one architecture (hence fewer norm2/softmax ops) means fewer numerical stability issues, in theory.


// b is the logits of the routing procedure
// [mb, inputCapsules, capsules, 1, 1]
SDVariable b = SD.zerosLike(uHat).get(SDIndex.all(), SDIndex.all(), SDIndex.all(), SDIndex.interval(0, 1), SDIndex.interval(0, 1));

@AlexDBlack (Contributor) commented Apr 4, 2019

This seems off.
You create a zeros array of shape [mb, inputCaps, caps, capsDimensions, 1], but then proceed to immediately get a [mb, inputCaps, caps, 1, 1] subset from it?
That's unnecessarily inefficient. Why not make a [mb, inputCaps, caps, 1, 1] array in the first place?

public static SDVariable squash(SameDiff SD, SDVariable x, int dim){
    // squared L2 norm along the capsule dimension (keepDims so it broadcasts back over x)
    SDVariable squaredNorm = SD.math.square(x).sum(true, dim);
    SDVariable scale = SD.math.sqrt(squaredNorm.plus(1e-5));
    // squash(x) = (||x||^2 / (1 + ||x||^2)) * (x / ||x||), with 1e-5 added for stability at ||x|| = 0
    return x.times(squaredNorm).div(squaredNorm.plus(1.0).times(scale));
}

@AlexDBlack (Contributor) commented Apr 4, 2019

Why the addition of 1e-5? Normally this is for numerical precision, but I don't believe we risk underflowing here even if the L2 norm is 0 (we get 0/1 in that case).

@rnett (Contributor, Author) commented Apr 4, 2019

I added it because I saw it in every implementation I looked at, wasn't sure why, and so didn't want to remove it.

From here, it seems to be there to keep the gradient stable if the squaredNorm is zero. Is that an issue SameDiff will have?
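
In equation form, what squash computes above (with the stabilizing epsilon) is roughly:

\operatorname{squash}(x) = \frac{\lVert x \rVert^2}{1 + \lVert x \rVert^2} \cdot \frac{x}{\sqrt{\lVert x \rVert^2 + \epsilon}}, \qquad \epsilon = 10^{-5}

so at zero norm the second factor is 0/\sqrt{\epsilon} rather than 0/0, which keeps the backward pass finite.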

@AlexDBlack (Contributor) commented Apr 5, 2019

Ah, makes sense, let's leave it then.

SDVariable v = CapsuleUtils.squash(SD, s, 3);

if(i == routings - 1){
return SD.squeeze(SD.squeeze(v, 1), 3);

@AlexDBlack (Contributor) commented Apr 4, 2019

Not a big deal, but a single reshape op would be more efficient...

@rnett (Contributor, Author) commented Apr 4, 2019

For issues 1 and 3, part of the reason I did it that way is that I had no way to get the minibatch size, because of variable-sized minibatches. Is there a way around that? E.g. for issue 3, I couldn't reshape because I don't know what the minibatch size would be, and I can't use -1 there.

@AlexDBlack (Contributor) commented Apr 5, 2019

Re: no minibatch size - it's possible, but on reflection it's not easy... OK, let's leave it for now until we have something better: #7445

As for the SD.zero("b_var", 1, inputCapsules, capsules, 1, 1) - that'll make it a broadcast softmax (the same value for all examples in the minibatch) later, in the first iteration's SDVariable c = CapsuleUtils.softmax(SD, b, 2, 5);, which is not what we want... Maybe let's change this back; then this should be good to go (we'll make a note on the issue I just opened for the inefficient zeros/subset thing, and fix it in the same PR).

Revert "better initialization"
This reverts commit 3392604.

@AlexDBlack AlexDBlack merged commit ee614a0 into eclipse:master Apr 5, 2019

1 check failed: Codacy/PR Quality Review - Not up to standards. This pull request quality could be better.