Skip to content

Commit

Permalink
Fix issue with VAE + weight noise (#6289)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed Aug 27, 2018
1 parent 57ccb33 commit 9fde8d1
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
Expand Up @@ -19,10 +19,13 @@
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.dropout.GaussianNoise;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.variational.*;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
Expand All @@ -31,6 +34,7 @@
import org.nd4j.linalg.activations.impl.ActivationTanH;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.Sgd;
Expand Down Expand Up @@ -449,4 +453,35 @@ public void testReconstructionErrorSimple() {
}
}
}


@Test
public void testVaeWeightNoise(){

    // Regression test for "VAE + weight noise" (see commit message): VAE
    // layer-wise pretraining previously failed when a WeightNoise was
    // configured. Exercise the failing path both with and without workspaces.
    for(boolean ws : new boolean[]{false, true}) {

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345L)
                .trainingWorkspaceMode(ws ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
                .inferenceWorkspaceMode(ws ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
                // NormalDistribution is already imported above — no need for the
                // fully-qualified name here.
                .weightNoise(new WeightNoise(new NormalDistribution(0.1, 0.3)))
                .list().layer(0,
                        new VariationalAutoencoder.Builder().nIn(10).nOut(3)
                                .encoderLayerSizes(5).decoderLayerSizes(6)
                                .pzxActivationFunction(Activation.TANH)
                                .reconstructionDistribution(new GaussianReconstructionDistribution())
                                .activation(new ActivationTanH())
                                .build())
                .pretrain(true).backprop(false).build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // Unsupervised pretraining of the VAE layer: this is the call that
        // previously failed when weight noise was applied. The test passes if
        // no exception is thrown.
        INDArray arr = Nd4j.rand(3, 10);
        net.pretrainLayer(0, arr);
    }
}
}
Expand Up @@ -155,10 +155,10 @@ public double score() {
return score;
}

protected INDArray getParamWithNoise(String param, boolean training, LayerWorkspaceMgr workspaceMgr){
protected INDArray getParamWithNoise(String param, boolean training, LayerWorkspaceMgr workspaceMgr){
INDArray p;
if(layerConf().getWeightNoise() != null){
if(training && weightNoiseParams.size() > 0 ){
if(training && weightNoiseParams.size() > 0 && weightNoiseParams.containsKey(param) ){
//Re-use these weights for both forward pass and backprop - don't want to use 2 different params here
//These should be cleared during backprop
return weightNoiseParams.get(param);
Expand Down

0 comments on commit 9fde8d1

Please sign in to comment.