CrossValidation

davidpicard edited this page Jul 4, 2012 · 4 revisions

All implemented methods share a common CrossValidation interface, which defines four main methods:

public interface CrossValidation {

    /**
     * perform learning and evaluations
     */
    public void run();

    /**
     * Tells the average score of the test
     * @return the average score
     */
    public double getAverageScore();

    /**
     * Tells the standard deviation of the test
     * @return the standard deviation
     */
    public double getStdDevScore();

    /**
     * Tells the scores of the tests, in order of evaluation
     * @return an array with the scores in order
     */
    public double[] getScores(); 

}
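To make the relationship between getScores(), getAverageScore() and getStdDevScore() concrete, here is a minimal sketch of how the two summary values could be derived from the array of per-test scores. ScoreStats and its methods are hypothetical names used only for illustration. The sketch also assumes a population standard deviation (dividing by n); the library may compute it differently.

```java
/** Hypothetical helper (not part of the library): summary statistics over per-test scores. */
final class ScoreStats {

    /** Arithmetic mean of the scores. */
    static double mean(double[] scores) {
        if (scores.length == 0) {
            throw new IllegalArgumentException("at least one score is required");
        }
        double sum = 0.0;
        for (double s : scores) {
            sum += s;
        }
        return sum / scores.length;
    }

    /**
     * Population standard deviation (divides by n).
     * Assumption: the library may instead divide by n - 1.
     */
    static double stdDev(double[] scores) {
        double m = mean(scores);
        double squared = 0.0;
        for (double s : scores) {
            squared += (s - m) * (s - m);
        }
        return Math.sqrt(squared / scores.length);
    }

    public static void main(String[] args) {
        // Accuracy values from three imaginary test runs.
        double[] scores = {0.8, 0.9, 1.0};
        System.out.println("Accuracy: " + mean(scores) + " +/- " + stdDev(scores));
    }
}
```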

Currently, three cross-validation techniques are available: RandomSplitCrossValidation, LeaveOneOutCrossValidation and NFoldCrossValidation. RandomSplitCrossValidation performs several evaluations of the classifier, each using a random split of the provided sample set. LeaveOneOutCrossValidation implements the well-known leave-one-out protocol. NFoldCrossValidation splits the data into n subsets, using n-1 of them for training and the remaining one for testing. CrossValidation is agnostic to the data type (as is the whole library) and also to the metric used (please refer to the Evaluator classes on this point).
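The n-fold idea can be sketched on plain sample indices. This is an illustration only: FoldSketch is a hypothetical class, and the round-robin assignment is an assumption, since the library may build its subsets differently (for example contiguous or shuffled blocks). Leave-one-out corresponds to the special case where n equals the number of samples.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch of n-fold index splitting (not the library's implementation). */
final class FoldSketch {

    /** Assigns indices 0..size-1 to n folds round-robin: index i goes to fold i % n. */
    static List<List<Integer>> folds(int size, int n) {
        if (n < 2 || n > size) {
            throw new IllegalArgumentException("need 2 <= n <= size");
        }
        List<List<Integer>> folds = new ArrayList<>();
        for (int k = 0; k < n; k++) {
            folds.add(new ArrayList<>());
        }
        for (int i = 0; i < size; i++) {
            folds.get(i % n).add(i);
        }
        return folds;
    }

    /** Training indices for one test fold: every index that is not in that fold. */
    static List<Integer> trainIndices(List<List<Integer>> folds, int testFold) {
        List<Integer> train = new ArrayList<>();
        for (int k = 0; k < folds.size(); k++) {
            if (k != testFold) {
                train.addAll(folds.get(k));
            }
        }
        return train;
    }

    public static void main(String[] args) {
        List<List<Integer>> folds = folds(10, 5);
        for (int k = 0; k < folds.size(); k++) {
            System.out.println("test=" + folds.get(k) + " train=" + trainIndices(folds, k));
        }
    }
}
```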

Given a Classifier<double[]> c and a List<TrainingSample<double[]>> l, you can initialize a CrossValidation as follows:

Evaluator<double[]> eval = new AccuracyEvaluator<double[]>();
RandomSplitCrossValidation<double[]> cv = new RandomSplitCrossValidation<double[]>(c, l, eval);
cv.setTrainPercent(0.80);
cv.setNbTest(10);

In this case, we use the accuracy metric computed by the AccuracyEvaluator. Each test trains on 80% of the samples and evaluates on the remaining 20%. The test is repeated 10 times, with a new random split each time.
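As a rough picture of what a single random split with a train percentage of 0.80 amounts to, the following sketch shuffles a copy of the sample list and cuts it at the 80% mark. RandomSplitSketch and split are hypothetical names, and the rounding of the cut point is an assumption. The library's own splitting logic may differ.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical sketch of one random train/test split (not the library's implementation). */
final class RandomSplitSketch {

    /** Returns a two-element list: the training part first, then the test part. */
    static <T> List<List<T>> split(List<T> samples, double trainPercent, Random rng) {
        if (trainPercent <= 0.0 || trainPercent >= 1.0) {
            throw new IllegalArgumentException("trainPercent must be strictly between 0 and 1");
        }
        // Shuffle a copy so the caller's list is left untouched.
        List<T> shuffled = new ArrayList<>(samples);
        Collections.shuffle(shuffled, rng);

        // Assumption: the cut point is rounded to the nearest sample.
        int cut = (int) Math.round(trainPercent * shuffled.size());

        List<List<T>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(shuffled.subList(0, cut)));
        parts.add(new ArrayList<>(shuffled.subList(cut, shuffled.size())));
        return parts;
    }

    public static void main(String[] args) {
        List<Integer> samples = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            samples.add(i);
        }
        // Ten repetitions, each with a fresh random split.
        Random rng = new Random();
        for (int t = 0; t < 10; t++) {
            List<List<Integer>> parts = split(samples, 0.80, rng);
            System.out.println("train=" + parts.get(0) + " test=" + parts.get(1));
        }
    }
}
```

Repeating the split with a fresh shuffle each time and collecting one score per repetition gives the kind of score array that the average and standard deviation summarize.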

To run the test, call the run() method:

cv.run();

Results can be obtained by calling the getAverageScore() and getStdDevScore() methods:

debug.println(1,"Accuracy: " + cv.getAverageScore() + " +/- "
        + cv.getStdDevScore());