[WIP] Capsnet layers #7391

Merged
45 commits merged on Apr 5, 2019
Changes from 43 commits
Commits (45)
0fce147
Merge pull request #1 from deeplearning4j/master
rnett Mar 29, 2019
e6dead4
Main capsule layer start, and utilities
rnett Mar 29, 2019
e7e8604
Config validation
rnett Mar 29, 2019
c1c8ccd
variable minibatch support
rnett Mar 29, 2019
fd2458c
bugfix
rnett Mar 29, 2019
83d32c9
capsule strength layer
rnett Mar 29, 2019
7d5f0e8
docstring
rnett Mar 29, 2019
f713125
faster dimension shrink
rnett Mar 29, 2019
b762744
add a keepDim option to point SDIndexes, which if used will only cont…
rnett Mar 29, 2019
7e634ab
keepDim test
rnett Mar 29, 2019
b910cf7
Merge branch 'master' into rnett-point-index-keep-dim
rnett Mar 29, 2019
46445f6
PrimaryCaps layer
rnett Mar 29, 2019
fe35614
fixes
rnett Mar 29, 2019
0a863fa
more small fixes
rnett Mar 29, 2019
39d4f3a
config and shape inference tests
rnett Mar 29, 2019
19b2a00
Merge branch 'master' into rnett-capsnet
rnett Mar 29, 2019
6acbd71
Changed default routings to 3 as per paper
rnett Mar 29, 2019
63ba472
better docstrings
rnett Mar 29, 2019
d70c97e
squash fixes
rnett Mar 29, 2019
151f6cc
better test matrix
rnett Mar 29, 2019
ef3817a
Merge remote-tracking branch 'origin/rnett-point-index-keep-dim' into…
rnett Mar 30, 2019
a6d199e
init weights to 1 (need to use param)
rnett Mar 30, 2019
e22c82c
Proper weight initialization
rnett Mar 30, 2019
6e73717
Undo changes to wrong files
rnett Mar 30, 2019
cabfda5
Single layer output tests
rnett Mar 30, 2019
16c5dc2
MNIST test (> 95% acc, p, r, f1)
rnett Mar 30, 2019
28abe35
need an updater...
rnett Mar 30, 2019
9783c05
cleanup
rnett Mar 30, 2019
6b8246e
Merge branch 'master' into rnett-capsnet
rnett Mar 30, 2019
f39ece0
added license to tests
rnett Mar 30, 2019
a3e7d7a
gradient check
rnett Mar 30, 2019
6214734
fixes
rnett Mar 31, 2019
f8853d3
optimize imports
rnett Mar 31, 2019
f967ed7
fixes
rnett Mar 31, 2019
e6bbab7
fixes
rnett Mar 31, 2019
a777d4c
fix CapsuleLayer output shape
rnett Apr 1, 2019
68fe093
fixes and test update
rnett Apr 1, 2019
4a34a45
typo fix
rnett Apr 1, 2019
285723c
test fix
rnett Apr 1, 2019
75b146d
shape comments
rnett Apr 1, 2019
146ea9f
variable description comments
rnett Apr 1, 2019
6faa787
Merge branch 'master' into rnett-capsnet
rnett Apr 1, 2019
fbd5a83
optimized imports
rnett Apr 1, 2019
3392604
better initialization
rnett Apr 4, 2019
78ddb17
Revert "better initialization"
rnett Apr 5, 2019
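
A note on commits b762744 and 7e634ab above, which add a keepDim option to point SDIndexes. The sketch below is a minimal illustration of the intended behavior, assuming the option landed as a boolean second argument to SDIndex.point (the exact final signature may differ): with keepDim set, indexing out a single point keeps that axis as size 1 instead of squeezing it away, which simplifies shape bookkeeping inside the capsule layers.

import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;

public class KeepDimSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        // [minibatch, capsules, capsuleDims], with the minibatch left variable
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 5, 8);

        // A plain point index squeezes the indexed axis: shape [minibatch, 8]
        SDVariable squeezed = in.get(SDIndex.all(), SDIndex.point(0), SDIndex.all());

        // Hypothetical keepDim variant per commit b762744: shape [minibatch, 1, 8]
        SDVariable kept = in.get(SDIndex.all(), SDIndex.point(0, true), SDIndex.all());
    }
}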
org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java (new file)
@@ -0,0 +1,123 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.deeplearning4j.gradientcheck;

import static org.junit.Assert.assertTrue;

import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.CapsuleLayer;
import org.deeplearning4j.nn.conf.layers.CapsuleStrengthLayer;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.conf.layers.PrimaryCapsules;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
import org.junit.Test;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;

public class CapsnetGradientCheckTest extends BaseDL4JTest {

    private static final boolean PRINT_RESULTS = true;
    private static final boolean RETURN_ON_FIRST_FAILURE = false;
    private static final double DEFAULT_EPS = 1e-6;
    private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
    private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;

    @Test
    public void testCapsNet() {

        int[] minibatchSizes = {8, 16};

        int width = 6;
        int height = 6;
        int inputDepth = 4;

        int[] primaryCapsDims = {2, 4};
        int[] primaryCapsChannels = {8};
        int[] capsules = {5};
        int[] capsuleDims = {4, 8};
        int[] routings = {1};

        Nd4j.getRandom().setSeed(12345);

        // Sweep the configuration grid; each combination gets its own gradient check.
        for (int routing : routings) {
            for (int primaryCapsDim : primaryCapsDims) {
                for (int primaryCapsChannel : primaryCapsChannels) {
                    for (int capsule : capsules) {
                        for (int capsuleDim : capsuleDims) {
                            for (int minibatchSize : minibatchSizes) {
                                INDArray input = Nd4j.rand(minibatchSize, inputDepth * height * width).mul(10)
                                        .reshape(-1, inputDepth, height, width);
                                INDArray labels = Nd4j.zeros(minibatchSize, capsule);
                                for (int i = 0; i < minibatchSize; i++) {
                                    labels.putScalar(new int[]{i, i % capsule}, 1.0);
                                }

                                MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                                        .seed(123)
                                        .updater(new NoOp())
                                        .weightInit(new WeightInitDistribution(new UniformDistribution(-6, 6)))
                                        .list()
                                        .layer(new PrimaryCapsules.Builder(primaryCapsDim, primaryCapsChannel)
                                                .kernelSize(3, 3)
                                                .stride(2, 2)
                                                .build())
                                        .layer(new CapsuleLayer.Builder(capsule, capsuleDim, routing).build())
                                        .layer(new CapsuleStrengthLayer.Builder().build())
                                        .layer(new ActivationLayer.Builder(new ActivationSoftmax()).build())
                                        .layer(new LossLayer.Builder(new LossNegativeLogLikelihood()).build())
                                        .setInputType(InputType.convolutional(height, width, inputDepth))
                                        .build();

                                MultiLayerNetwork net = new MultiLayerNetwork(conf);
                                net.init();

                                for (int i = 0; i < 4; i++) {
                                    System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
                                }

                                String msg = "minibatch=" + minibatchSize +
                                        ", PrimaryCaps: " + primaryCapsChannel +
                                        " channels, " + primaryCapsDim + " dimensions, Capsules: " + capsule +
                                        " capsules with " + capsuleDim + " dimensions and " + routing + " routings";
                                System.out.println(msg);

                                boolean gradOK = GradientCheckUtil
                                        .checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
                                                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input,
                                                labels);

                                assertTrue(msg, gradOK);

                                TestUtils.testModelSerialization(net);
                            }
                        }
                    }
                }
            }
        }
    }
}
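
For reference on the "squash fixes" and "Changed default routings to 3 as per paper" commits above: these layers follow Sabour et al. 2017, "Dynamic Routing Between Capsules", where each capsule's raw output vector s is passed through the squash nonlinearity, which preserves its direction while mapping its length into (0, 1):

    squash(s) = (||s||^2 / (1 + ||s||^2)) * (s / ||s||)

Three routing iterations is the default reported in that paper, which is why commit 6acbd71 changes the default.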
org/deeplearning4j/nn/layers/capsule/CapsNetMNISTTest.java (new file)
@@ -0,0 +1,86 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.deeplearning4j.nn.layers.capsule;

import static org.junit.Assert.assertTrue;

import java.io.IOException;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.CapsuleLayer;
import org.deeplearning4j.nn.conf.layers.CapsuleStrengthLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.conf.layers.PrimaryCapsules;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;

public class CapsNetMNISTTest extends BaseDL4JTest {

    @Test
    public void testCapsNetOnMNIST() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(123)
                .updater(new Adam())
                .list()
                .layer(new ConvolutionLayer.Builder()
                        .nOut(16)
                        .kernelSize(9, 9)
                        .stride(3, 3)
                        .build())
                .layer(new PrimaryCapsules.Builder(8, 8)
                        .kernelSize(7, 7)
                        .stride(2, 2)
                        .build())
                .layer(new CapsuleLayer.Builder(10, 16, 3).build())
                .layer(new CapsuleStrengthLayer.Builder().build())
                .layer(new ActivationLayer.Builder(new ActivationSoftmax()).build())
                .layer(new LossLayer.Builder(new LossNegativeLogLikelihood()).build())
                .setInputType(InputType.convolutionalFlat(28, 28, 1))
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        int rngSeed = 12345;
        try {
            MnistDataSetIterator mnistTrain = new MnistDataSetIterator(64, true, rngSeed);
            MnistDataSetIterator mnistTest = new MnistDataSetIterator(64, false, rngSeed);

            for (int i = 0; i < 2; i++) {
                model.fit(mnistTrain);
            }

            Evaluation eval = model.evaluate(mnistTest);

            assertTrue("Accuracy not over 95%", eval.accuracy() > 0.95);
            assertTrue("Precision not over 95%", eval.precision() > 0.95);
            assertTrue("Recall not over 95%", eval.recall() > 0.95);
            assertTrue("F1-score not over 95%", eval.f1() > 0.95);

        } catch (IOException e) {
            System.out.println("Could not load MNIST.");
        }
    }
}
org/deeplearning4j/nn/layers/capsule/CapsuleLayerTest.java (new file)
@@ -0,0 +1,96 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.deeplearning4j.nn.layers.capsule;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.CapsuleLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class CapsuleLayerTest extends BaseDL4JTest {

    @Override
    public DataType getDataType() {
        return DataType.FLOAT;
    }

    @Test
    public void testOutputType() {
        CapsuleLayer layer = new CapsuleLayer.Builder(10, 16, 5).build();

        InputType in1 = InputType.recurrent(5, 8);

        assertEquals(InputType.recurrent(10, 16), layer.getOutputType(0, in1));
    }

    @Test
    public void testInputType() {
        CapsuleLayer layer = new CapsuleLayer.Builder(10, 16, 5).build();

        InputType in1 = InputType.recurrent(5, 8);

        layer.setNIn(in1, true);

        assertEquals(5, layer.getInputCapsules());
        assertEquals(8, layer.getInputCapsuleDimensions());
    }

    @Test
    public void testConfig() {
        CapsuleLayer layer1 = new CapsuleLayer.Builder(10, 16, 5).build();

        assertEquals(10, layer1.getCapsules());
        assertEquals(16, layer1.getCapsuleDimensions());
        assertEquals(5, layer1.getRoutings());
        assertFalse(layer1.isHasBias());

        CapsuleLayer layer2 = new CapsuleLayer.Builder(10, 16, 5).hasBias(true).build();

        assertTrue(layer2.isHasBias());
    }

    @Test
    public void testLayer() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(123)
                .list()
                .layer(new CapsuleLayer.Builder(10, 16, 3).build())
                .setInputType(InputType.recurrent(10, 8))
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        INDArray emptyFeatures = Nd4j.zeros(64, 10, 8);

        long[] shape = model.output(emptyFeatures).shape();

        assertArrayEquals(new long[]{64, 10, 16}, shape);
    }
}
org/deeplearning4j/nn/layers/capsule/CapsuleStrengthLayerTest.java (new file)
@@ -0,0 +1,67 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.deeplearning4j.nn.layers.capsule;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;

import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.CapsuleStrengthLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class CapsuleStrengthLayerTest extends BaseDL4JTest {

    @Override
    public DataType getDataType() {
        return DataType.FLOAT;
    }

    @Test
    public void testOutputType() {
        CapsuleStrengthLayer layer = new CapsuleStrengthLayer.Builder().build();

        InputType in1 = InputType.recurrent(5, 8);

        assertEquals(InputType.feedForward(5), layer.getOutputType(0, in1));
    }

    @Test
    public void testLayer() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(123)
                .list()
                .layer(new CapsuleStrengthLayer.Builder().build())
                .setInputType(InputType.recurrent(5, 8))
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        // [minibatch, capsules, capsuleDims] matching the declared input type above
        INDArray emptyFeatures = Nd4j.zeros(64, 5, 8);

        long[] shape = model.output(emptyFeatures).shape();

        assertArrayEquals(new long[]{64, 5}, shape);
    }
}
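
A note on what the shape test above verifies: CapsuleStrengthLayer reduces each capsule's pose vector to a single scalar "strength", taking [minibatch, capsules, capsuleDims] down to [minibatch, capsules], so a softmax and loss can sit directly on top (as in the MNIST test earlier). A minimal ND4J sketch of the equivalent reduction, assuming the strength is the L2 norm of each capsule vector (consistent with the squash formulation, which encodes confidence as vector length):

import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class CapsuleStrengthSketch {
    public static void main(String[] args) {
        // [minibatch, capsules, capsuleDims], as in testLayer above
        INDArray capsulePoses = Nd4j.rand(new int[]{64, 5, 8});
        // L2 norm over the last axis collapses each pose vector to one scalar
        INDArray strengths = capsulePoses.norm2(2);
        System.out.println(Arrays.toString(strengths.shape())); // prints [64, 5]
    }
}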