Skip to content

Commit

Permalink
#5351 Fix EvaluativeListener workspace issue
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed May 28, 2018
1 parent a589d2c commit df681f0
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 22 deletions.
Expand Up @@ -43,6 +43,7 @@
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
Expand Down Expand Up @@ -1249,6 +1250,40 @@ public void testConfusionMatrixString(){

System.out.println("\n\n\n\n");
System.out.println(e.stats(false, true));
}



@Test
public void testEvaluativeListenerSimple(){
//Sanity check: https://github.com/deeplearning4j/deeplearning4j/issues/5351

// Network config
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()

.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(42)
.updater(new Sgd(1e-6)).list()
.layer(0, new DenseLayer.Builder().nIn(4).nOut(2).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).weightInit(WeightInit.XAVIER)
.activation(Activation.SOFTMAX).build())
.build();

// Instantiate model
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

// Train-test split
DataSetIterator iter = new IrisDataSetIterator(30, 150);
DataSetIterator iterTest = new IrisDataSetIterator(30, 150);

net.setListeners(new EvaluativeListener(iterTest, 3));

for( int i=0; i<10; i++ ){
net.fit(iter);
}


}
}
Expand Up @@ -152,7 +152,8 @@ public EvaluativeListener(@NonNull MultiDataSet multiDataSet, int frequency, @No
*/
@Override
public void iterationDone(Model model, int iteration, int epoch) {
// no-op
if (invocationType == InvocationType.ITERATION_END)
invokeListener(model);
}

@Override
Expand All @@ -167,27 +168,6 @@ public void onEpochEnd(Model model) {
invokeListener(model);
}

@Override
public void onForwardPass(Model model, List<INDArray> activations) {
// no-op
}

@Override
public void onForwardPass(Model model, Map<String, INDArray> activations) {
// no-op
}

@Override
public void onGradientCalculation(Model model) {
// no-op
}

@Override
public void onBackwardPass(Model model) {
if (invocationType == InvocationType.ITERATION_END)
invokeListener(model);
}

protected void invokeListener(Model model) {
if (iterationCount.get() == null)
iterationCount.set(new AtomicLong(0));
Expand Down

0 comments on commit df681f0

Please sign in to comment.