From a5372376ca34de2be6df71ea6e9b723826ff64d8 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Wed, 31 Jan 2018 18:34:00 +1100 Subject: [PATCH 1/3] Add recurrent weight init option for RNNs --- .../nn/conf/layers/AbstractLSTM.java | 1 + .../nn/conf/layers/BaseRecurrentLayer.java | 32 ++++++++++++++++++- .../nn/params/GravesLSTMParamInitializer.java | 13 +++++++- .../nn/params/LSTMParamInitializer.java | 25 ++++++++++++--- .../nn/params/SimpleRnnParamInitializer.java | 26 ++++++++++----- 5 files changed, 83 insertions(+), 14 deletions(-) diff --git a/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AbstractLSTM.java b/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AbstractLSTM.java index d433bb8698a8..1a48d68f69a4 100644 --- a/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AbstractLSTM.java +++ b/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AbstractLSTM.java @@ -20,6 +20,7 @@ import lombok.*; import org.deeplearning4j.nn.params.LSTMParamInitializer; +import org.deeplearning4j.nn.weights.WeightInit; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationSigmoid; diff --git a/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java b/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java index 232867441e1a..ddf9c4397a85 100644 --- a/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java +++ b/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java @@ -6,7 +6,9 @@ import lombok.ToString; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.weights.WeightInit; import java.util.Arrays; import java.util.List; @@ -17,6 +19,9 @@ @EqualsAndHashCode(callSuper = true) public abstract class BaseRecurrentLayer extends FeedForwardLayer { + protected WeightInit weightInitRecurrent; + protected Distribution distRecurrent; + protected BaseRecurrentLayer(Builder builder) { super(builder); } @@ -57,6 +62,8 @@ public InputPreProcessor getPreProcessorForInputType(InputType inputType) { public static abstract class Builder> extends FeedForwardLayer.Builder { protected List recurrentConstraints; protected List inputWeightConstraints; + protected WeightInit weightInitRecurrent; + protected Distribution distRecurrent; /** * Set constraints to be applied to the RNN recurrent weight parameters of this layer. Default: no constraints.
@@ -81,6 +88,29 @@ public T constrainInputWeights(LayerConstraint... constraints) { this.inputWeightConstraints = Arrays.asList(constraints); return (T) this; } - } + /** + * Set the weight initialization for the recurrent weights. Note that if this is not set explicitly, the same + * weight initialization as the layer input weights is also used for the recurrent weights. + * + * @param weightInit Weight initialization for the recurrent weights only. + */ + public T weightInitRecurrent(WeightInit weightInit){ + this.weightInitRecurrent = weightInit; + return (T) this; + } + + /** + * Set the weight initialization for the recurrent weights, based on the specified distribution. Note that if this + * is not set explicitly, the same weight initialization as the layer input weights is also used for the recurrent + * weights. + * + * @param dist Distribution to use for initializing the recurrent weights + */ + public T weightInitRecurrent(Distribution dist){ + this.weightInitRecurrent = WeightInit.DISTRIBUTION; + this.distRecurrent = dist; + return (T) this; + } + } } diff --git a/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/GravesLSTMParamInitializer.java b/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/GravesLSTMParamInitializer.java index b4393f877e85..0215a794691b 100644 --- a/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/GravesLSTMParamInitializer.java +++ b/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/GravesLSTMParamInitializer.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.Distributions; import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInitUtil; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.distribution.Distribution; @@ -128,10 +129,20 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi int[] inputWShape = new int[] {nLast, 4 * nL}; int[] recurrentWShape = new int[] {nL, 4 * nL + 3}; + WeightInit rwInit; + Distribution rwDist; + if(layerConf.getWeightInitRecurrent() != null){ + rwInit = layerConf.getWeightInitRecurrent(); + rwDist = Distributions.createDistribution(layerConf.getDistRecurrent()); + } else { + rwInit = layerConf.getWeightInit(); + rwDist = dist; + } + params.put(INPUT_WEIGHT_KEY, WeightInitUtil.initWeights(fanIn, fanOut, inputWShape, layerConf.getWeightInit(), dist, inputWeightView)); params.put(RECURRENT_WEIGHT_KEY, WeightInitUtil.initWeights(fanIn, fanOut, recurrentWShape, - layerConf.getWeightInit(), dist, recurrentWeightView)); + rwInit, rwDist, recurrentWeightView)); biasView.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(nL, 2 * nL)}, Nd4j.valueArrayOf(1, nL, forgetGateInit)); //Order: input, forget, output, input modulation, i.e., IFOG} /*The above line initializes the forget gate biases to specified value. 
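As a usage sketch (not part of the patch itself), the builder methods introduced above would be configured like this. The class and method names come from this series; WeightInit.XAVIER and the 10x10 layer sizes are illustrative assumptions only:

    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
    import org.deeplearning4j.nn.conf.layers.LSTM;
    import org.deeplearning4j.nn.weights.WeightInit;

    public class RecurrentInitExample {
        public static void main(String[] args) {
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER)  // applies to input weights; recurrent
                                                    // weights fall back to this if not overridden
                    .list()
                    .layer(new LSTM.Builder().nIn(10).nOut(10)
                            // Override for the recurrent weights only:
                            .weightInitRecurrent(new UniformDistribution(-0.1, 0.1))
                            .build())
                    .build();
        }
    }

The Distribution overload sets weightInitRecurrent to WeightInit.DISTRIBUTION internally, so the two builder methods share a single pair of fields on the layer.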
diff --git a/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/LSTMParamInitializer.java b/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/LSTMParamInitializer.java index c0b222f66bc9..289463e3ea82 100644 --- a/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/LSTMParamInitializer.java +++ b/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/LSTMParamInitializer.java @@ -23,6 +23,7 @@ import org.deeplearning4j.nn.conf.distribution.Distributions; import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInitUtil; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.distribution.Distribution; @@ -49,6 +50,12 @@ public static LSTMParamInitializer getInstance() { public final static String BIAS_KEY = DefaultParamInitializer.BIAS_KEY; public final static String INPUT_WEIGHT_KEY = DefaultParamInitializer.WEIGHT_KEY; + private static final List LAYER_PARAM_KEYS = Collections.unmodifiableList( + Arrays.asList(INPUT_WEIGHT_KEY, RECURRENT_WEIGHT_KEY, BIAS_KEY)); + private static final List WEIGHT_KEYS = Collections.unmodifiableList( + Arrays.asList(INPUT_WEIGHT_KEY, RECURRENT_WEIGHT_KEY)); + private static final List BIAS_KEYS = Collections.unmodifiableList(Collections.singletonList(BIAS_KEY)); + @Override public int numParams(NeuralNetConfiguration conf) { return numParams(conf.getLayer()); @@ -70,17 +77,17 @@ public int numParams(Layer l) { @Override public List paramKeys(Layer layer) { - return Arrays.asList(INPUT_WEIGHT_KEY, RECURRENT_WEIGHT_KEY, BIAS_KEY); + return LAYER_PARAM_KEYS; } @Override public List weightKeys(Layer layer) { - return Arrays.asList(INPUT_WEIGHT_KEY, RECURRENT_WEIGHT_KEY); + return WEIGHT_KEYS; } @Override public List biasKeys(Layer layer) { - return Collections.singletonList(BIAS_KEY); + return BIAS_KEYS; } @Override @@ -128,10 +135,20 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi int[] inputWShape = new int[] {nLast, 4 * nL}; int[] recurrentWShape = new int[] {nL, 4 * nL}; + WeightInit rwInit; + Distribution rwDist; + if(layerConf.getWeightInitRecurrent() != null){ + rwInit = layerConf.getWeightInitRecurrent(); + rwDist = Distributions.createDistribution(layerConf.getDistRecurrent()); + } else { + rwInit = layerConf.getWeightInit(); + rwDist = dist; + } + params.put(INPUT_WEIGHT_KEY, WeightInitUtil.initWeights(fanIn, fanOut, inputWShape, layerConf.getWeightInit(), dist, inputWeightView)); params.put(RECURRENT_WEIGHT_KEY, WeightInitUtil.initWeights(fanIn, fanOut, recurrentWShape, - layerConf.getWeightInit(), dist, recurrentWeightView)); + rwInit, rwDist, recurrentWeightView)); biasView.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(nL, 2 * nL)}, Nd4j.valueArrayOf(1, nL, forgetGateInit)); //Order: input, forget, output, input modulation, i.e., IFOG} /*The above line initializes the forget gate biases to specified value. 
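The null-check fallback is now repeated verbatim in GravesLSTMParamInitializer, LSTMParamInitializer, and (below) SimpleRnnParamInitializer. A hypothetical consolidation, not part of this patch, could factor out the selection; this sketch mirrors the patch's behavior, including passing a possibly-null distRecurrent to Distributions.createDistribution:

    import org.deeplearning4j.nn.conf.distribution.Distributions;
    import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
    import org.deeplearning4j.nn.weights.WeightInit;
    import org.nd4j.linalg.api.rng.distribution.Distribution;

    final class RecurrentInitSelection {
        final WeightInit init;
        final Distribution dist;

        private RecurrentInitSelection(WeightInit init, Distribution dist) {
            this.init = init;
            this.dist = dist;
        }

        static RecurrentInitSelection resolve(BaseRecurrentLayer layerConf, Distribution inputDist) {
            if (layerConf.getWeightInitRecurrent() != null) {
                // A recurrent-specific init was configured; build its distribution
                // (distRecurrent may be null for non-DISTRIBUTION inits, as in the patch).
                return new RecurrentInitSelection(layerConf.getWeightInitRecurrent(),
                        Distributions.createDistribution(layerConf.getDistRecurrent()));
            }
            // No override: reuse the input-weight init and its distribution.
            return new RecurrentInitSelection(layerConf.getWeightInit(), inputDist);
        }
    }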
diff --git a/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/SimpleRnnParamInitializer.java b/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/SimpleRnnParamInitializer.java index dd4e6746f2ca..5b7b4fc6bcb7 100644 --- a/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/SimpleRnnParamInitializer.java +++ b/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/SimpleRnnParamInitializer.java @@ -5,6 +5,7 @@ import org.deeplearning4j.nn.conf.distribution.Distributions; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; +import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInitUtil; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.distribution.Distribution; @@ -77,18 +78,27 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi Map m; - if(initializeParams){ + if (initializeParams) { Distribution dist = Distributions.createDistribution(c.getDist()); - m = getSubsets(paramsView, nIn, nOut, false); - INDArray w = WeightInitUtil.initWeights(nIn, nOut, new int[]{nIn, nOut}, c.getWeightInit(), dist, 'f', m.get(WEIGHT_KEY)); - m.put(WEIGHT_KEY, w); - - INDArray rw = WeightInitUtil.initWeights(nOut, nOut, new int[]{nOut, nOut}, c.getWeightInit(), dist, 'f', m.get(RECURRENT_WEIGHT_KEY)); + m = getSubsets(paramsView, nIn, nOut, false); + INDArray w = WeightInitUtil.initWeights(nIn, nOut, new int[]{nIn, nOut}, c.getWeightInit(), dist, 'f', m.get(WEIGHT_KEY)); + m.put(WEIGHT_KEY, w); + + WeightInit rwInit; + Distribution rwDist; + if (c.getWeightInitRecurrent() != null) { + rwInit = c.getWeightInitRecurrent(); + rwDist = Distributions.createDistribution(c.getDistRecurrent()); + } else { + rwInit = c.getWeightInit(); + rwDist = dist; + } + + INDArray rw = WeightInitUtil.initWeights(nOut, nOut, new int[]{nOut, nOut}, rwInit, rwDist, 'f', m.get(RECURRENT_WEIGHT_KEY)); m.put(RECURRENT_WEIGHT_KEY, rw); - } else { - m = getSubsets(paramsView, nIn, nOut, true); + m = getSubsets(paramsView, nIn, nOut, true); } conf.addVariable(WEIGHT_KEY); From e4469cb4a21772e4c653e7d82ce1ac4478425712 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Wed, 31 Jan 2018 19:04:40 +1100 Subject: [PATCH 2/3] Add test for RNN recurrent weight init --- .../recurrent/TestRecurrentWeightInit.java | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRecurrentWeightInit.java diff --git a/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRecurrentWeightInit.java b/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRecurrentWeightInit.java new file mode 100644 index 000000000000..8b0147699e75 --- /dev/null +++ b/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRecurrentWeightInit.java @@ -0,0 +1,62 @@ +package org.deeplearning4j.nn.layers.recurrent; + +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.distribution.UniformDistribution; +import org.deeplearning4j.nn.conf.layers.GravesLSTM; +import org.deeplearning4j.nn.conf.layers.LSTM; +import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; + +import static org.junit.Assert.assertTrue; + +public class TestRecurrentWeightInit { + + @Test + public void testRWInit() { 
+ + for (boolean rwInit : new boolean[]{false, true}) { + for (int i = 0; i < 3; i++) { + + NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder() + .weightInit(new UniformDistribution(0, 1)) + .list(); + + switch (i) { + case 0: + b.layer(new LSTM.Builder().nIn(10).nOut(10) + .weightInitRecurrent(new UniformDistribution(2,3)) + .build()); + break; + case 1: + b.layer(new GravesLSTM.Builder().nIn(10).nOut(10) + .weightInitRecurrent(new UniformDistribution(2,3)) + .build()); + break; + case 2: + b.layer(new SimpleRnn.Builder().nIn(10).nOut(10) + .weightInitRecurrent(new UniformDistribution(2,3)).build()); + break; + default: + throw new RuntimeException(); + } + + MultiLayerNetwork net = new MultiLayerNetwork(b.build()); + net.init(); + + INDArray rw = net.getParam("0_RW"); + double min = rw.minNumber().doubleValue(); + double max = rw.maxNumber().doubleValue(); + if(rwInit){ + assertTrue(String.valueOf(min), min >= 0.0); + assertTrue(String.valueOf(max), max <= 1.0); + } else { + assertTrue(String.valueOf(min), min >= 2.0); + assertTrue(String.valueOf(max), max <= 3.0); + } + } + } + } + +} From 02211b689da9a2e2dd9cc7c206dc0868a48377f5 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Wed, 31 Jan 2018 19:15:26 +1100 Subject: [PATCH 3/3] Fix implementation and test for recurrent weight init config --- .../recurrent/TestRecurrentWeightInit.java | 56 ++++++++++++------- .../nn/conf/layers/BaseRecurrentLayer.java | 2 + 2 files changed, 38 insertions(+), 20 deletions(-) diff --git a/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRecurrentWeightInit.java b/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRecurrentWeightInit.java index 8b0147699e75..6678ec38d92b 100644 --- a/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRecurrentWeightInit.java +++ b/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRecurrentWeightInit.java @@ -23,23 +23,39 @@ public void testRWInit() { .weightInit(new UniformDistribution(0, 1)) .list(); - switch (i) { - case 0: - b.layer(new LSTM.Builder().nIn(10).nOut(10) - .weightInitRecurrent(new UniformDistribution(2,3)) - .build()); - break; - case 1: - b.layer(new GravesLSTM.Builder().nIn(10).nOut(10) - .weightInitRecurrent(new UniformDistribution(2,3)) - .build()); - break; - case 2: - b.layer(new SimpleRnn.Builder().nIn(10).nOut(10) - .weightInitRecurrent(new UniformDistribution(2,3)).build()); - break; - default: - throw new RuntimeException(); + if(rwInit) { + switch (i) { + case 0: + b.layer(new LSTM.Builder().nIn(10).nOut(10) + .weightInitRecurrent(new UniformDistribution(2, 3)) + .build()); + break; + case 1: + b.layer(new GravesLSTM.Builder().nIn(10).nOut(10) + .weightInitRecurrent(new UniformDistribution(2, 3)) + .build()); + break; + case 2: + b.layer(new SimpleRnn.Builder().nIn(10).nOut(10) + .weightInitRecurrent(new UniformDistribution(2, 3)).build()); + break; + default: + throw new RuntimeException(); + } + } else { + switch (i) { + case 0: + b.layer(new LSTM.Builder().nIn(10).nOut(10).build()); + break; + case 1: + b.layer(new GravesLSTM.Builder().nIn(10).nOut(10).build()); + break; + case 2: + b.layer(new SimpleRnn.Builder().nIn(10).nOut(10).build()); + break; + default: + throw new RuntimeException(); + } } MultiLayerNetwork net = new MultiLayerNetwork(b.build()); @@ -49,11 +65,11 @@ public void testRWInit() { double min = rw.minNumber().doubleValue(); double max = rw.maxNumber().doubleValue(); if(rwInit){ - 
assertTrue(String.valueOf(min), min >= 0.0); - assertTrue(String.valueOf(max), max <= 1.0); - } else { assertTrue(String.valueOf(min), min >= 2.0); assertTrue(String.valueOf(max), max <= 3.0); + } else { + assertTrue(String.valueOf(min), min >= 0.0); + assertTrue(String.valueOf(max), max <= 1.0); } } } diff --git a/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java b/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java index ddf9c4397a85..c0f0692ce562 100644 --- a/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java +++ b/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java @@ -24,6 +24,8 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer { protected BaseRecurrentLayer(Builder builder) { super(builder); + this.weightInitRecurrent = builder.weightInitRecurrent; + this.distRecurrent = builder.distRecurrent; } @Override
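For completeness, a small standalone sketch mirroring the corrected test shows the end-to-end behavior after all three patches. The class names and the "0_RW" parameter key are taken from the test above; the layer sizes and distribution bounds simply reuse the test's values:

    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
    import org.deeplearning4j.nn.conf.layers.LSTM;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.api.ndarray.INDArray;

    public class RecurrentWeightInitDemo {
        public static void main(String[] args) {
            MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
                    .weightInit(new UniformDistribution(0, 1))  // input weights in [0, 1]
                    .list()
                    .layer(new LSTM.Builder().nIn(10).nOut(10)
                            .weightInitRecurrent(new UniformDistribution(2, 3))  // recurrent in [2, 3]
                            .build())
                    .build());
            net.init();

            // "0_RW" is layer 0's recurrent weight parameter (RECURRENT_WEIGHT_KEY);
            // the input weights remain under "0_W".
            INDArray rw = net.getParam("0_RW");
            System.out.println("recurrent weights in [" + rw.minNumber() + ", " + rw.maxNumber() + "]");
            // Expected: roughly [2, 3], while "0_W" stays within [0, 1].
        }
    }

Note the role of the PATCH 3 constructor fix: without copying weightInitRecurrent and distRecurrent from the builder, getWeightInitRecurrent() was always null at init time and the override above would silently have no effect.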