[WIP] Capsnet layers #7391

Merged 45 commits on Apr 5, 2019. The diff below shows the changes from a single commit.

Commits:
0fce147  Merge pull request #1 from deeplearning4j/master (rnett, Mar 29, 2019)
e6dead4  Main capsule layer start, and utilities (rnett, Mar 29, 2019)
e7e8604  Config validation (rnett, Mar 29, 2019)
c1c8ccd  variable minibatch support (rnett, Mar 29, 2019)
fd2458c  bugfix (rnett, Mar 29, 2019)
83d32c9  capsule strength layer (rnett, Mar 29, 2019)
7d5f0e8  docstring (rnett, Mar 29, 2019)
f713125  faster dimension shrink (rnett, Mar 29, 2019)
b762744  add a keepDim option to point SDIndexes, which if used will only cont… (rnett, Mar 29, 2019)
7e634ab  keepDim test (rnett, Mar 29, 2019)
b910cf7  Merge branch 'master' into rnett-point-index-keep-dim (rnett, Mar 29, 2019)
46445f6  PrimaryCaps layer (rnett, Mar 29, 2019)
fe35614  fixes (rnett, Mar 29, 2019)
0a863fa  more small fixes (rnett, Mar 29, 2019)
39d4f3a  config and shape inference tests (rnett, Mar 29, 2019)
19b2a00  Merge branch 'master' into rnett-capsnet (rnett, Mar 29, 2019)
6acbd71  Changed default routings to 3 as per paper (rnett, Mar 29, 2019)
63ba472  better docstrings (rnett, Mar 29, 2019)
d70c97e  squash fixes (rnett, Mar 29, 2019)
151f6cc  better test matrix (rnett, Mar 29, 2019)
ef3817a  Merge remote-tracking branch 'origin/rnett-point-index-keep-dim' into… (rnett, Mar 30, 2019)
a6d199e  init weights to 1 (need to use param) (rnett, Mar 30, 2019)
e22c82c  Proper weight initialization (rnett, Mar 30, 2019)
6e73717  Undo changes to wrong files (rnett, Mar 30, 2019)
cabfda5  Single layer output tests (rnett, Mar 30, 2019)
16c5dc2  MNIST test (> 95% acc, p, r, f1) (rnett, Mar 30, 2019)
28abe35  need an updater... (rnett, Mar 30, 2019)
9783c05  cleanup (rnett, Mar 30, 2019)
6b8246e  Merge branch 'master' into rnett-capsnet (rnett, Mar 30, 2019)
f39ece0  added license to tests (rnett, Mar 30, 2019)
a3e7d7a  gradient check (rnett, Mar 30, 2019)
6214734  fixes (rnett, Mar 31, 2019)
f8853d3  optimize imports (rnett, Mar 31, 2019)
f967ed7  fixes (rnett, Mar 31, 2019)
e6bbab7  fixes (rnett, Mar 31, 2019)
a777d4c  fix CapsuleLayer output shape (rnett, Apr 1, 2019)
68fe093  fixes and test update (rnett, Apr 1, 2019)
4a34a45  typo fix (rnett, Apr 1, 2019)
285723c  test fix (rnett, Apr 1, 2019)
75b146d  shape comments (rnett, Apr 1, 2019)
146ea9f  variable description comments (rnett, Apr 1, 2019)
6faa787  Merge branch 'master' into rnett-capsnet (rnett, Apr 1, 2019)
fbd5a83  optimized imports (rnett, Apr 1, 2019)
3392604  better initialization (rnett, Apr 4, 2019)
78ddb17  Revert "better initialization" (rnett, Apr 5, 2019)
Changes from commit 46445f620d0f86bf30cc5e0fb9c50bfbedde1ffd ("PrimaryCaps layer"), committed by rnett on Mar 29, 2019.
CapsuleLayer.java

@@ -68,13 +68,15 @@ public CapsuleLayer(Builder builder){

         if(capsules <= 0 || capsuleDimensions <= 0 || routings <= 0){
             throw new IllegalArgumentException("Invalid configuration for Capsule Layer (layer name = \""
-                    + layerName + "\"): capsules, capsuleDimensions, and routings must be > 0. Got: "
+                    + layerName + "\"):"
+                    + " capsules, capsuleDimensions, and routings must be > 0. Got: "
                     + capsules + ", " + capsuleDimensions + ", " + routings);
         }

         if(inputCapsules < 0 || inputCapsuleDimensions < 0){
             throw new IllegalArgumentException("Invalid configuration for Capsule Layer (layer name = \""
-                    + layerName + "\"): inputCapsules and inputCapsuleDimensions must be >= 0. Got: "
+                    + layerName + "\"):"
+                    + " inputCapsules and inputCapsuleDimensions must be >= 0 if set. Got: "
                     + inputCapsules + ", " + inputCapsuleDimensions);
         }

@@ -104,7 +106,6 @@ public SDVariable defineLayer(SameDiff SD, SDVariable input, Map<String, SDVariable> paramTable) {

        SDVariable uHat = weights.times(tiled).sum(true, 3)
                .reshape(-1, inputCapsules, capsules, capsuleDimensions, 1);

        //TODO better way of getting rid of dim 3
        SDVariable b = SD.zerosLike(uHat).get(SDIndex.all(), SDIndex.all(), SDIndex.all(), SDIndex.interval(0, 1), SDIndex.interval(0, 1));

AlexDBlack (Contributor) commented on Apr 4, 2019:
This seems off.
You create a zeros array of shape [mb, inputCaps, caps, capsDimensions, 1]
But then proceed to immediately get a [mb, inputCaps, caps, 1, 1] subset from it?
That's unnecessarily inefficient. Why not make a [mb, inputCaps, caps, 1, 1] in the first place?
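One way to realize the reviewer's suggestion (a hypothetical sketch, not code from this PR; it uses only the SDVariable, SDIndex, and zerosLike calls already visible in this diff): slice uHat down to the routing-logit shape first, then take zerosLike of the already-small result, so the full [mb, inputCaps, caps, capsDimensions, 1] zeros array is never materialized.

    // Slice first, so zerosLike only ever sees the [mb, inputCaps, caps, 1, 1] shape
    SDVariable bTemplate = uHat.get(SDIndex.all(), SDIndex.all(), SDIndex.all(),
            SDIndex.interval(0, 1), SDIndex.interval(0, 1));
    SDVariable b = SD.zerosLike(bTemplate);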


        //TODO convert to SameDiff.whileLoop?

@@ -144,12 +145,12 @@ public void defineParameters(SDLayerParams params) {

     @Override
     public void initializeParameters(Map<String, INDArray> params) {
-        //TODO
         try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
             for (Map.Entry<String, INDArray> e : params.entrySet()) {
                 if (BIAS_PARAM.equals(e.getKey())) {
                     e.getValue().assign(0);
                 } else if(WEIGHT_PARAM.equals(e.getKey())){
+                    //TODO use weightInit
                     e.getValue().assign(0);
                 }
             }
CapsuleStrengthLayer.java

@@ -33,6 +33,11 @@
  * @author Ryan Nett
  */
 public class CapsuleStrengthLayer extends SameDiffLambdaLayer {
+
+    public CapsuleStrengthLayer(Builder builder){
+        super();
+    }
+
     @Override
     public SDVariable defineLayer(SameDiff SD, SDVariable layerInput) {
         return SD.norm2("caps_strength", layerInput, 2);

@@ -54,7 +59,7 @@ public InputType getOutputType(int layerIndex, InputType inputType) {

     @Override
     public <E extends Layer> E build() {
-        return (E) new CapsuleStrengthLayer();
+        return (E) new CapsuleStrengthLayer(this);
     }
 }
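For reference, a minimal standalone sketch of what this lambda computes (a hypothetical example, not part of the PR; it assumes SameDiff's placeholder API and a [minibatch, capsules, capsuleDimensions] input). The L2 norm along dimension 2 collapses each capsule's vector into a single activation strength:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;

    SameDiff sd = SameDiff.create();
    // Assume 10 capsules of 16 dimensions each; -1 leaves the minibatch size open
    SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 10, 16);
    SDVariable strength = sd.norm2("caps_strength", in, 2);  // shape [minibatch, 10]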

PrimaryCapsules.java

@@ -16,6 +16,324 @@

 package org.deeplearning4j.nn.conf.layers;

-public class PrimaryCapsules {
-    //TODO

import java.util.Map;
import lombok.AccessLevel;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional;
import org.deeplearning4j.nn.conf.inputs.InputType.Type;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.deeplearning4j.util.ValidationUtils;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.factory.Nd4j;

/**
 * An implementation of the PrimaryCaps layer from Dynamic Routing Between Capsules.
 *
 * It is a 2D convolution whose output is reshaped into capsules.
 *
 * From <a href="http://papers.nips.cc/paper/6975-dynamic-routing-between-capsules.pdf">Dynamic Routing Between Capsules</a>
 *
 * @author Ryan Nett
 */
@Data
@EqualsAndHashCode(callSuper = true)
public class PrimaryCapsules extends SameDiffLayer {

    private int[] kernelSize;
    private int[] stride;
    private int[] padding;
    private int[] dilation;
    private int inputChannels;
    private int channels;

    private boolean hasBias;

    private int capsules;
    private int capsuleDimensions;

    private ConvolutionMode convolutionMode = ConvolutionMode.Truncate;

    private boolean useRelu = false;
    private double leak = 0;

    private static final String WEIGHT_PARAM = "weight";
    private static final String BIAS_PARAM = "bias";

    public PrimaryCapsules(Builder builder){
        super(builder);

        this.kernelSize = builder.kernelSize;
        this.stride = builder.stride;
        this.padding = builder.padding;
        this.dilation = builder.dilation;
        this.channels = builder.channels;
        this.hasBias = builder.hasBias;
        this.capsules = builder.capsules;
        this.capsuleDimensions = builder.capsuleDimensions;
        this.convolutionMode = builder.convolutionMode;
        this.useRelu = builder.useRelu;
        this.leak = builder.leak;

        if(capsuleDimensions <= 0 || channels <= 0){
            throw new IllegalArgumentException("Invalid configuration for Primary Capsules (layer name = \""
                    + layerName + "\"):"
                    + " capsuleDimensions and channels must be > 0. Got: "
                    + capsuleDimensions + ", " + channels);
        }

        if(capsules < 0){
            throw new IllegalArgumentException("Invalid configuration for Primary Capsules (layer name = \""
                    + layerName + "\"):"
                    + " capsules must be >= 0 if set. Got: "
                    + capsules);
        }

    }

    @Override
    public SDVariable defineLayer(SameDiff SD, SDVariable input, Map<String, SDVariable> paramTable) {

        Conv2DConfig conf = Conv2DConfig.builder()
                .kH(kernelSize[0]).kW(kernelSize[1])
                .sH(stride[0]).sW(stride[1])
                .pH(padding[0]).pW(padding[1])
                .dH(dilation[0]).dW(dilation[1])
                .isSameMode(convolutionMode == ConvolutionMode.Same)
                .build();

        SDVariable conved;

        if(hasBias){
            conved = SD.cnn.conv2d(input, paramTable.get(WEIGHT_PARAM), paramTable.get(BIAS_PARAM), conf);
        } else {
            conved = SD.cnn.conv2d(input, paramTable.get(WEIGHT_PARAM), conf);
        }

        if(useRelu){
            if(leak == 0) {
                conved = SD.nn.relu(conved, 0);
            } else {
                conved = SD.nn.leakyRelu(conved, leak);
            }
        }

        return conved.reshape(-1, capsules, capsuleDimensions);
    }
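For concreteness, a shape trace through defineLayer under the paper's PrimaryCaps settings (assumed values, not fixed by this diff): an input of [mb, 256, 20, 20] convolved with a 9x9 kernel at stride 2 in Truncate mode with channels = 256 yields [mb, 256, 6, 6]; with capsules = 1152 and capsuleDimensions = 8, the final reshape(-1, capsules, capsuleDimensions) produces [mb, 1152, 8], i.e. one 8-dimensional vector per primary capsule (256 * 6 * 6 = 1152 * 8 = 9216 elements per example).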

    @Override
    public void defineParameters(SDLayerParams params) {
        params.clear();
        params.addWeightParam(WEIGHT_PARAM,
                kernelSize[0], kernelSize[1], inputChannels, channels);

        if(hasBias){
            params.addBiasParam(BIAS_PARAM, channels);
        }
    }

    @Override
    public void initializeParameters(Map<String, INDArray> params) {
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            for (Map.Entry<String, INDArray> e : params.entrySet()) {
                if (BIAS_PARAM.equals(e.getKey())) {
                    e.getValue().assign(0);
                } else if(WEIGHT_PARAM.equals(e.getKey())){
                    double fanIn = inputChannels * kernelSize[0] * kernelSize[1];
                    double fanOut = channels * kernelSize[0] * kernelSize[1] / ((double) stride[0] * stride[1]);
                    WeightInitUtil.initWeights(fanIn, fanOut, e.getValue().shape(), weightInit, null, 'c',
                            e.getValue());
                }
            }
        }
    }
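As a quick sanity check on the fan computation above, with assumed paper-style sizes (inputChannels = 256, channels = 256, a 9x9 kernel, stride 2; none of these values are fixed by this diff):

    double fanIn  = 256 * 9 * 9;              // 20736: inputs feeding each output activation
    double fanOut = 256 * 9 * 9 / (2.0 * 2);  // 5184: contributions per input, scaled down by the stride area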

    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        if (inputType == null || inputType.getType() != Type.CNN) {
            throw new IllegalStateException("Invalid input for Primary Capsules layer (layer name = \""
                    + layerName + "\"): expect CNN input. Got: " + inputType);
        }
        return InputType.recurrent(capsules, capsuleDimensions);
    }

    @Override
    public void setNIn(InputType inputType, boolean override) {
        if (inputType == null || inputType.getType() != Type.CNN) {
            throw new IllegalStateException("Invalid input for Primary Capsules layer (layer name = \""
                    + layerName + "\"): expect CNN input. Got: " + inputType);
        }

        InputTypeConvolutional ci = (InputTypeConvolutional) inputType;

        this.inputChannels = (int) ci.getChannels();

        InputTypeConvolutional out = (InputTypeConvolutional) InputTypeUtil.getOutputTypeCnnLayers(
                inputType, kernelSize, stride, padding, dilation, convolutionMode,
                channels, -1, getLayerName(), PrimaryCapsules.class);

        this.capsules = (int) (out.getChannels() * out.getHeight() * out.getWidth() / capsuleDimensions);
    }
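A worked example of this capsule-count inference, using the paper's PrimaryCaps settings (assumed here for illustration; the diff does not fix these numbers):

    int inH = 20, inW = 20, kernel = 9, stride = 2;            // 20x20x256 CNN input, 9x9 conv
    int channels = 256, capsuleDimensions = 8;
    int outH = (inH - kernel) / stride + 1;                    // 6, with ConvolutionMode.Truncate
    int outW = (inW - kernel) / stride + 1;                    // 6
    int capsules = channels * outH * outW / capsuleDimensions; // 256 * 6 * 6 / 8 = 1152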

    @Getter
    @Setter
    public static class Builder extends SameDiffLayer.Builder<Builder>{

        @Setter(AccessLevel.NONE)
        private int[] kernelSize = new int[]{9, 9};

        @Setter(AccessLevel.NONE)
        private int[] stride = new int[]{2, 2};

        @Setter(AccessLevel.NONE)
        private int[] padding = new int[]{0, 0};

        @Setter(AccessLevel.NONE)
        private int[] dilation = new int[]{1, 1};

        private int channels = 32;

        private boolean hasBias = true;

        private int capsules;
        private int capsuleDimensions;

        private ConvolutionMode convolutionMode = ConvolutionMode.Truncate;

        private boolean useRelu = false;
        private double leak = 0;

        public void setKernelSize(int... kernelSize){
            this.kernelSize = ValidationUtils.validate2NonNegative(kernelSize, true, "kernelSize");
        }

        public void setStride(int... stride){
            this.stride = ValidationUtils.validate2NonNegative(stride, true, "stride");
        }

        public void setPadding(int... padding){
            this.padding = ValidationUtils.validate2NonNegative(padding, true, "padding");
        }

        public void setDilation(int... dilation){
            this.dilation = ValidationUtils.validate2NonNegative(dilation, true, "dilation");
        }

        public Builder(int capsuleDimensions, int channels,
                       int[] kernelSize, int[] stride, int[] padding, int[] dilation,
                       ConvolutionMode convolutionMode){
            this.capsuleDimensions = capsuleDimensions;
            this.channels = channels;
            this.setKernelSize(kernelSize);
            this.setStride(stride);
            this.setPadding(padding);
            this.setDilation(dilation);
            this.convolutionMode = convolutionMode;
        }

        public Builder(int capsuleDimensions, int channels,
                       int[] kernelSize, int[] stride, int[] padding, int[] dilation){
            this(capsuleDimensions, channels, kernelSize, stride, padding, dilation, ConvolutionMode.Truncate);
        }

        public Builder(int capsuleDimensions, int channels,
                       int[] kernelSize, int[] stride, int[] padding){
            this(capsuleDimensions, channels, kernelSize, stride, padding, new int[]{1, 1}, ConvolutionMode.Truncate);
        }

        public Builder(int capsuleDimensions, int channels,
                       int[] kernelSize, int[] stride){
            this(capsuleDimensions, channels, kernelSize, stride, new int[]{0, 0}, new int[]{1, 1}, ConvolutionMode.Truncate);
        }

        public Builder(int capsuleDimensions, int channels,
                       int[] kernelSize){
            this(capsuleDimensions, channels, kernelSize, new int[]{2, 2}, new int[]{0, 0}, new int[]{1, 1}, ConvolutionMode.Truncate);
        }

        public Builder(int capsuleDimensions, int channels){
            this(capsuleDimensions, channels, new int[]{9, 9}, new int[]{2, 2}, new int[]{0, 0}, new int[]{1, 1}, ConvolutionMode.Truncate);
        }

        public Builder kernelSize(int... kernelSize){
            this.setKernelSize(kernelSize);
            return this;
        }

        public Builder stride(int... stride){
            this.setStride(stride);
            return this;
        }

        public Builder padding(int... padding){
            this.setPadding(padding);
            return this;
        }

        public Builder dilation(int... dilation){
            this.setDilation(dilation);
            return this;
        }

        public Builder channels(int channels){
            this.channels = channels;
            return this;
        }

        public Builder nOut(int nOut){
            return channels(nOut);
        }

        public Builder capsuleDimensions(int capsuleDimensions){
            this.capsuleDimensions = capsuleDimensions;
            return this;
        }

        public Builder capsules(int capsules){
            this.capsules = capsules;
            return this;
        }

        public Builder hasBias(boolean hasBias){
            this.hasBias = hasBias;
            return this;
        }

        public Builder convolutionMode(ConvolutionMode convolutionMode){
            this.convolutionMode = convolutionMode;
            return this;
        }

        public Builder useReLU(boolean useRelu){
            this.useRelu = useRelu;
            return this;
        }

        public Builder useReLU(){
            return useReLU(true);
        }

        public Builder useLeakyReLU(double leak){
            this.useRelu = true;
            this.leak = leak;
            return this;
        }

        @Override
        public <E extends Layer> E build() {
            return (E) new PrimaryCapsules(this);
        }
    }

}
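Finally, a hedged usage sketch of the builder above (hypothetical wiring; only the constructor and setters defined in this diff are used, and the surrounding network configuration is omitted). With capsules left unset, setNIn infers it from the incoming CNN activation shape:

    // 8-dimensional capsules from a 256-channel, 9x9, stride-2 convolution, as in the paper
    PrimaryCapsules primaryCaps = new PrimaryCapsules.Builder(8, 256)
            .kernelSize(9, 9)
            .stride(2, 2)
            .build();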