Add recurrent weight init configuration option for all RNN layers #4579

Merged: 3 commits, Feb 5, 2018
@@ -0,0 +1,78 @@
package org.deeplearning4j.nn.layers.recurrent;

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;

import static org.junit.Assert.assertTrue;

public class TestRecurrentWeightInit {

    @Test
    public void testRWInit() {

        for (boolean rwInit : new boolean[]{false, true}) {
            for (int i = 0; i < 3; i++) {

                NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder()
                        .weightInit(new UniformDistribution(0, 1))
                        .list();

                if(rwInit) {
                    switch (i) {
                        case 0:
                            b.layer(new LSTM.Builder().nIn(10).nOut(10)
                                    .weightInitRecurrent(new UniformDistribution(2, 3))
                                    .build());
                            break;
                        case 1:
                            b.layer(new GravesLSTM.Builder().nIn(10).nOut(10)
                                    .weightInitRecurrent(new UniformDistribution(2, 3))
                                    .build());
                            break;
                        case 2:
                            b.layer(new SimpleRnn.Builder().nIn(10).nOut(10)
                                    .weightInitRecurrent(new UniformDistribution(2, 3)).build());
                            break;
                        default:
                            throw new RuntimeException();
                    }
                } else {
                    switch (i) {
                        case 0:
                            b.layer(new LSTM.Builder().nIn(10).nOut(10).build());
                            break;
                        case 1:
                            b.layer(new GravesLSTM.Builder().nIn(10).nOut(10).build());
                            break;
                        case 2:
                            b.layer(new SimpleRnn.Builder().nIn(10).nOut(10).build());
                            break;
                        default:
                            throw new RuntimeException();
                    }
                }

                MultiLayerNetwork net = new MultiLayerNetwork(b.build());
                net.init();

                INDArray rw = net.getParam("0_RW");
                double min = rw.minNumber().doubleValue();
                double max = rw.maxNumber().doubleValue();
                if(rwInit){
                    assertTrue(String.valueOf(min), min >= 2.0);
                    assertTrue(String.valueOf(max), max <= 3.0);
                } else {
                    assertTrue(String.valueOf(min), min >= 0.0);
                    assertTrue(String.valueOf(max), max <= 1.0);
                }
            }
        }
    }

}
@@ -20,6 +20,7 @@

import lombok.*;
import org.deeplearning4j.nn.params.LSTMParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
@@ -6,7 +6,9 @@
import lombok.ToString;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.weights.WeightInit;

import java.util.Arrays;
import java.util.List;
@@ -17,8 +19,13 @@
@EqualsAndHashCode(callSuper = true)
public abstract class BaseRecurrentLayer extends FeedForwardLayer {

    protected WeightInit weightInitRecurrent;
    protected Distribution distRecurrent;

    protected BaseRecurrentLayer(Builder builder) {
        super(builder);
        this.weightInitRecurrent = builder.weightInitRecurrent;
        this.distRecurrent = builder.distRecurrent;
    }

    @Override
@@ -57,6 +64,8 @@ public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
    public static abstract class Builder<T extends Builder<T>> extends FeedForwardLayer.Builder<T> {
        protected List<LayerConstraint> recurrentConstraints;
        protected List<LayerConstraint> inputWeightConstraints;
        protected WeightInit weightInitRecurrent;
        protected Distribution distRecurrent;

        /**
         * Set constraints to be applied to the RNN recurrent weight parameters of this layer. Default: no constraints.<br>
@@ -81,6 +90,29 @@ public T constrainInputWeights(LayerConstraint... constraints) {
            this.inputWeightConstraints = Arrays.asList(constraints);
            return (T) this;
        }

        /**
         * Set the weight initialization for the recurrent weights. Note that if this is not set explicitly, the same
         * weight initialization as the layer input weights is also used for the recurrent weights.
         *
         * @param weightInit Weight initialization for the recurrent weights only.
         */
        public T weightInitRecurrent(WeightInit weightInit){
            this.weightInitRecurrent = weightInit;
            return (T) this;
        }

        /**
         * Set the weight initialization for the recurrent weights, based on the specified distribution. Note that if this
         * is not set explicitly, the same weight initialization as the layer input weights is also used for the recurrent
         * weights.
         *
         * @param dist Distribution to use for initializing the recurrent weights
         */
        public T weightInitRecurrent(Distribution dist){
            this.weightInitRecurrent = WeightInit.DISTRIBUTION;
            this.distRecurrent = dist;
            return (T) this;
        }
    }
}
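For reference, a minimal usage sketch of the new builder option, mirroring the test file at the top of this diff; the example class name, layer sizes, and distributions are illustrative only and not part of the change itself:

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

public class RecurrentWeightInitExample {
    public static void main(String[] args) {
        // Input weights ("W") are drawn from U(0, 1); only the recurrent weights ("RW") use U(2, 3).
        MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
                .weightInit(new UniformDistribution(0, 1))
                .list()
                .layer(new LSTM.Builder().nIn(10).nOut(10)
                        .weightInitRecurrent(new UniformDistribution(2, 3))
                        .build())
                .build());
        net.init();

        // The recurrent weight matrix of layer 0 is exposed under the "0_RW" parameter key.
        INDArray rw = net.getParam("0_RW");
        System.out.println("0_RW min/max: " + rw.minNumber() + " / " + rw.maxNumber());
    }
}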
@@ -22,6 +22,7 @@
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
@@ -128,10 +129,20 @@ public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsVi
        int[] inputWShape = new int[] {nLast, 4 * nL};
        int[] recurrentWShape = new int[] {nL, 4 * nL + 3};

        WeightInit rwInit;
        Distribution rwDist;
        if(layerConf.getWeightInitRecurrent() != null){
            rwInit = layerConf.getWeightInitRecurrent();
            rwDist = Distributions.createDistribution(layerConf.getDistRecurrent());
        } else {
            rwInit = layerConf.getWeightInit();
            rwDist = dist;
        }

        params.put(INPUT_WEIGHT_KEY, WeightInitUtil.initWeights(fanIn, fanOut, inputWShape,
                layerConf.getWeightInit(), dist, inputWeightView));
        params.put(RECURRENT_WEIGHT_KEY, WeightInitUtil.initWeights(fanIn, fanOut, recurrentWShape,
                rwInit, rwDist, recurrentWeightView));
        biasView.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(nL, 2 * nL)},
                Nd4j.valueArrayOf(1, nL, forgetGateInit)); //Order: input, forget, output, input modulation, i.e., IFOG}
        /*The above line initializes the forget gate biases to specified value.
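The else branch above is the backwards-compatible path: when weightInitRecurrent is not configured, the recurrent weights simply reuse the layer's regular weight init and distribution. A hedged sketch of the enum-based overload added in this PR follows; the XAVIER/UNIFORM choices, layer sizes, and class name are illustrative and not taken from this diff:

import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.weights.WeightInit;

public class EnumOverloadExample {
    public static GravesLSTM buildLayer() {
        // Input weights use XAVIER; the recurrent weights alone use UNIFORM.
        return new GravesLSTM.Builder()
                .nIn(10).nOut(10)
                .weightInit(WeightInit.XAVIER)
                .weightInitRecurrent(WeightInit.UNIFORM)
                .build();
    }
}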
@@ -23,6 +23,7 @@
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
@@ -49,6 +50,12 @@ public static LSTMParamInitializer getInstance() {
    public final static String BIAS_KEY = DefaultParamInitializer.BIAS_KEY;
    public final static String INPUT_WEIGHT_KEY = DefaultParamInitializer.WEIGHT_KEY;

    private static final List<String> LAYER_PARAM_KEYS = Collections.unmodifiableList(
            Arrays.asList(INPUT_WEIGHT_KEY, RECURRENT_WEIGHT_KEY, BIAS_KEY));
    private static final List<String> WEIGHT_KEYS = Collections.unmodifiableList(
            Arrays.asList(INPUT_WEIGHT_KEY, RECURRENT_WEIGHT_KEY));
    private static final List<String> BIAS_KEYS = Collections.unmodifiableList(Collections.singletonList(BIAS_KEY));

    @Override
    public int numParams(NeuralNetConfiguration conf) {
        return numParams(conf.getLayer());
@@ -70,17 +77,17 @@ public int numParams(Layer l) {

    @Override
    public List<String> paramKeys(Layer layer) {
        return LAYER_PARAM_KEYS;
    }

    @Override
    public List<String> weightKeys(Layer layer) {
        return WEIGHT_KEYS;
    }

    @Override
    public List<String> biasKeys(Layer layer) {
        return BIAS_KEYS;
    }

    @Override
@@ -128,10 +135,20 @@ public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsVi
        int[] inputWShape = new int[] {nLast, 4 * nL};
        int[] recurrentWShape = new int[] {nL, 4 * nL};

        WeightInit rwInit;
        Distribution rwDist;
        if(layerConf.getWeightInitRecurrent() != null){
            rwInit = layerConf.getWeightInitRecurrent();
            rwDist = Distributions.createDistribution(layerConf.getDistRecurrent());
        } else {
            rwInit = layerConf.getWeightInit();
            rwDist = dist;
        }

        params.put(INPUT_WEIGHT_KEY, WeightInitUtil.initWeights(fanIn, fanOut, inputWShape,
                layerConf.getWeightInit(), dist, inputWeightView));
        params.put(RECURRENT_WEIGHT_KEY, WeightInitUtil.initWeights(fanIn, fanOut, recurrentWShape,
                rwInit, rwDist, recurrentWeightView));
        biasView.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(nL, 2 * nL)},
                Nd4j.valueArrayOf(1, nL, forgetGateInit)); //Order: input, forget, output, input modulation, i.e., IFOG}
        /*The above line initializes the forget gate biases to specified value.
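Aside from the recurrent init change, the paramKeys/weightKeys/biasKeys methods above now return shared immutable constants instead of allocating a new list on every call. A generic sketch of that pattern follows; the class and key names here are placeholders, not DL4J API:

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class CachedKeysSketch {
    // Built once; Arrays.asList is fixed-size but still allows set(), so it is wrapped
    // in unmodifiableList before being shared with callers.
    private static final List<String> WEIGHT_KEYS =
            Collections.unmodifiableList(Arrays.asList("W", "RW"));

    public List<String> weightKeys() {
        return WEIGHT_KEYS; // no per-call allocation, and callers cannot mutate the shared list
    }
}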
@@ -5,6 +5,7 @@
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
@@ -77,18 +78,27 @@ public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsVi

        Map<String,INDArray> m;

        if (initializeParams) {
            Distribution dist = Distributions.createDistribution(c.getDist());

            m = getSubsets(paramsView, nIn, nOut, false);
            INDArray w = WeightInitUtil.initWeights(nIn, nOut, new int[]{nIn, nOut}, c.getWeightInit(), dist, 'f', m.get(WEIGHT_KEY));
            m.put(WEIGHT_KEY, w);

            WeightInit rwInit;
            Distribution rwDist;
            if (c.getWeightInitRecurrent() != null) {
                rwInit = c.getWeightInitRecurrent();
                rwDist = Distributions.createDistribution(c.getDistRecurrent());
            } else {
                rwInit = c.getWeightInit();
                rwDist = dist;
            }

            INDArray rw = WeightInitUtil.initWeights(nOut, nOut, new int[]{nOut, nOut}, rwInit, rwDist, 'f', m.get(RECURRENT_WEIGHT_KEY));
            m.put(RECURRENT_WEIGHT_KEY, rw);

        } else {
            m = getSubsets(paramsView, nIn, nOut, true);
        }

        conf.addVariable(WEIGHT_KEY);