Confusion Matrix in CNNExample does not contain any information #72

Closed
Nikoschenk opened this issue Jan 22, 2016 · 2 comments

@Nikoschenk

Hi all,

I've modified CNNIrisExample.java in the package
org.deeplearning4j.examples.convolution
and am trying to print the confusion matrix.

However, a call to

System.out.println(eval.getConfusionMatrix().toCSV());

prints only zeros, so the matrix contains no information.

My minimal example:

package org.deeplearning4j.examples.convolution;

import java.io.File;
import java.io.IOException;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.setup.ConvolutionLayerSetup;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.Random;
import org.canova.api.conf.Configuration;
import org.canova.api.records.reader.RecordReader;
import org.canova.api.records.reader.impl.CSVRecordReader;
import org.canova.api.records.reader.impl.LibSvmRecordReader;
import org.canova.api.split.FileSplit;
import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.Layer;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.springframework.core.io.ClassPathResource;

/**
 * @author sonali
 */
public class CNNMy2Example {

    private static Logger log = LoggerFactory.getLogger(CNNIrisExample.class);

    public static void main(String[] args) throws IOException, InterruptedException {

        final int numRows = 2;
        final int numColumns = 2;
        int nChannels = 1;
        int iterations = 10;
        int seed = 123;
        int listenerFreq = 1;

        /**
         * Set a neural network configuration with multiple layers
         */
        log.info("Load data....");

        // 1. Get TRAINING data.
        RecordReader recordReaderTrain = new CSVRecordReader(0, ",");
        recordReaderTrain.initialize(new FileSplit(new File("myCSVinputTRAIN.txt")));
        DataSetIterator iteratorTrain = new RecordReaderDataSetIterator(recordReaderTrain, 12, 4, 3);
        DataSet irisTrain = iteratorTrain.next();
        irisTrain.normalizeZeroMeanZeroUnitVariance();
        System.out.println("Loaded " + irisTrain.labelCounts() + " training set.");

        // 2. Get TEST data.
        RecordReader recordReaderTest = new CSVRecordReader(0, ",");
        recordReaderTest.initialize(new FileSplit(new File("myCSVinputTEST.txt")));
        DataSetIterator iteratorTest = new RecordReaderDataSetIterator(recordReaderTest, 6, 4, 3);
        DataSet irisTest = iteratorTest.next();
        irisTest.normalizeZeroMeanZeroUnitVariance();
        System.out.println("Loaded " + irisTest.labelCounts() + " test set.");

        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .list(2)
                .layer(0, new ConvolutionLayer.Builder(new int[]{1, 1})
                        .nIn(nChannels)
                        .nOut(1000) // # nodes in hidden layer.
                        .activation("relu")
                        .weightInit(WeightInit.RELU)
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nOut(3) // # output classes.
                        .weightInit(WeightInit.XAVIER)
                        .activation("softmax")
                        .build())
                .backprop(true).pretrain(false);

        new ConvolutionLayerSetup(builder, numRows, numColumns, nChannels);

        MultiLayerConfiguration conf = builder.build();

        log.info("Build model....");
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(listenerFreq)));

        log.info("Train model....");
        System.out.println("Training on " + irisTrain.labelCounts());
        model.fit(irisTrain);

        log.info("Evaluate weights....");
        for (org.deeplearning4j.nn.api.Layer layer : model.getLayers()) {
            INDArray w = layer.getParam(DefaultParamInitializer.WEIGHT_KEY);
            //log.info("Weights: " + w);
        }

        Evaluation eval = new Evaluation(3);
        log.info("Evaluate model....");
        System.out.println("Testing on " + irisTest.labelCounts());

        INDArray output = model.output(irisTest.getFeatureMatrix());

        int[] predictedLabels = model.predict(irisTest.getFeatureMatrix());
        //for(int predictedLabel : predictedLabels) System.out.println(predictedLabel);

        // CONFUSION MATRIX DOES CONTAIN ONLY ZERO VALUES !!!
        System.out.println(eval.getConfusionMatrix().toCSV());
        // SAME HERE !
        System.out.println(eval.getConfusionMatrix());


        eval.eval(irisTest.getLabels(), output);
        log.info(eval.stats());

        log.info("\n");
        log.info("****************Example finished********************");
    }
}

myCSVinputTEST.txt
myCSVinputTRAIN.txt

@eraly
Contributor

eraly commented Mar 14, 2016

@Nikoschenk You have to run the .eval method before you look at the confusion matrix. Until eval.eval(...) has been called, the Evaluation object has not seen any predictions, so every cell is still zero. Simply move your print lines below the eval.eval call, like so:

        int[] predictedLabels = model.predict(irisTest.getFeatureMatrix());

        // Too early: eval.eval(...) has not run yet, so these would print only zeros.
        //System.out.println(eval.getConfusionMatrix().toCSV());
        //System.out.println(eval.getConfusionMatrix());

        eval.eval(irisTest.getLabels(), output);
        log.info(eval.stats());
        // The confusion matrix is populated now that eval.eval(...) has run.
        System.out.println(eval.getConfusionMatrix().toCSV());
        System.out.println(eval.getConfusionMatrix());

        log.info("\n");
        log.info("****************Example finished********************");
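
For anyone who wants to see the ordering issue in isolation, here is a minimal sketch that runs without DL4J. It assumes a tiny stand-in class, SimpleEvaluation, which is hypothetical and is not the real org.deeplearning4j.eval.Evaluation API; it only imitates the accumulate-then-read pattern described above, where the matrix stays at zero until eval is called. The labels in main are made-up illustration data, not taken from the attached CSV files.

```java
public class ConfusionMatrixOrderDemo {

    /** Hypothetical stand-in for an evaluator; NOT the DL4J Evaluation class. */
    static class SimpleEvaluation {
        // counts[actual][predicted]; Java initialises every cell to zero.
        final int[][] counts;

        SimpleEvaluation(int numClasses) {
            counts = new int[numClasses][numClasses];
        }

        /** Accumulates one count per (actual, predicted) pair. */
        void eval(int[] actual, int[] predicted) {
            if (actual.length != predicted.length) {
                throw new IllegalArgumentException("label arrays differ in length");
            }
            for (int i = 0; i < actual.length; i++) {
                counts[actual[i]][predicted[i]]++;
            }
        }

        /** Renders the matrix as comma-separated rows, one row per line. */
        String toCSV() {
            StringBuilder sb = new StringBuilder();
            for (int r = 0; r < counts.length; r++) {
                if (r > 0) {
                    sb.append('\n');
                }
                for (int c = 0; c < counts[r].length; c++) {
                    if (c > 0) {
                        sb.append(',');
                    }
                    sb.append(counts[r][c]);
                }
            }
            return sb.toString();
        }
    }

    public static void main(String[] args) {
        // Made-up labels for three classes (illustration only).
        int[] actual    = {0, 1, 2, 2, 1, 0};
        int[] predicted = {0, 1, 2, 1, 1, 0};

        SimpleEvaluation eval = new SimpleEvaluation(3);

        // Reading before eval(...): nothing has been accumulated, so all zeros.
        System.out.println("Before eval:\n" + eval.toCSV());

        eval.eval(actual, predicted);

        // Reading after eval(...): the counts are filled in.
        System.out.println("After eval:\n" + eval.toCSV());
    }
}
```

The only point of the sketch is the call order: the matrix is a running tally, so it has to be read after the labels and predictions have been passed to eval.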

@nyghtowl
Contributor

Addressed by @eraly
