diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java
index 59a206ae286d..6e190634cef8 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java
@@ -19,10 +19,13 @@
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.WorkspaceMode;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
+import org.deeplearning4j.nn.conf.dropout.GaussianNoise;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
 import org.deeplearning4j.nn.conf.layers.variational.*;
 import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
+import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
@@ -31,6 +34,7 @@
 import org.nd4j.linalg.activations.impl.ActivationTanH;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
+import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.NDArrayIndex;
 import org.nd4j.linalg.learning.config.Sgd;
@@ -449,4 +453,35 @@ public void testReconstructionErrorSimple() {
             }
         }
     }
+
+
+    @Test
+    public void testVaeWeightNoise(){
+
+        for(boolean ws : new boolean[]{false, true}) {
+
+            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                    .seed(12345L)
+                    .trainingWorkspaceMode(ws ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
+                    .inferenceWorkspaceMode(ws ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
+                    .weightNoise(new WeightNoise(new org.deeplearning4j.nn.conf.distribution.NormalDistribution(0.1, 0.3)))
+                    .list().layer(0,
+                            new VariationalAutoencoder.Builder().nIn(10).nOut(3)
+                                    .encoderLayerSizes(5).decoderLayerSizes(6)
+                                    .pzxActivationFunction(Activation.TANH)
+                                    .reconstructionDistribution(new GaussianReconstructionDistribution())
+                                    .activation(new ActivationTanH())
+                                    .build())
+                    .pretrain(true).backprop(false).build();
+
+            MultiLayerNetwork net = new MultiLayerNetwork(conf);
+            net.init();
+
+            INDArray arr = Nd4j.rand(3, 10);
+            net.pretrainLayer(0, arr);
+
+        }
+
+
+    }
 }
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.java
index ac56bb98a285..23b7f3d062de 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.java
@@ -155,10 +155,10 @@ public double score() {
         return score;
     }
 
-    protected INDArray getParamWithNoise(String param, boolean training, LayerWorkspaceMgr workspaceMgr){
+    protected INDArray getParamWithNoise(String param, boolean training, LayerWorkspaceMgr workspaceMgr){
         INDArray p;
         if(layerConf().getWeightNoise() != null){
-            if(training && weightNoiseParams.size() > 0 ){
+            if(training && weightNoiseParams.size() > 0 && weightNoiseParams.containsKey(param) ){
                 //Re-use these weights for both forward pass and backprop - don't want to use 2 different params here
                 //These should be cleared during backprop
                 return weightNoiseParams.get(param);