KFoldIterator calculates incorrect batch sizes #6786

Closed
printomi opened this issue Nov 28, 2018 · 9 comments · Fixed by #6810

@printomi
commented Nov 28, 2018

KFoldIterator calculates incorrect batch and lastBatch sizes. The latter can even be non-positive (lastBatch <= 0). For example, with N = 13 examples in allData and k = 6 folds, the current implementation calculates batch = 3 and lastBatch = -2, which is obviously not what we want.

The current constructor looks like this:

public KFoldIterator(int k, DataSet allData) {
	this.k = k;
	this.allData = allData.copy();
	if (k <= 1)
		throw new IllegalArgumentException();
	if (allData.numExamples() % k != 0) {
		this.batch = (int)Math.ceil(allData.numExamples() / (double)k);
		this.lastBatch = allData.numExamples() - (k-1) * this.batch;
	} else {
		this.batch = allData.numExamples() / k;
		this.lastBatch = allData.numExamples() / k;
	}
}
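To see how often this goes wrong, the size arithmetic of the constructor above can be pulled out into plain Java. This is a minimal sketch that assumes only batch and lastBatch matter; the class and method names (CurrentFoldSizes, currentSizes, badK) are hypothetical, and no ND4J types are involved.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical standalone copy of the fold-size arithmetic used by the
 * current KFoldIterator constructor (sizes only, no ND4J types).
 */
public class CurrentFoldSizes {

    /** Returns {batch, lastBatch} exactly as the current constructor computes them. */
    static int[] currentSizes(int n, int k) {
        int batch;
        int lastBatch;
        if (n % k != 0) {
            batch = (int) Math.ceil(n / (double) k);  // first k-1 folds round up
            lastBatch = n - (k - 1) * batch;          // last fold gets whatever is left
        } else {
            batch = n / k;
            lastBatch = n / k;
        }
        return new int[] {batch, lastBatch};
    }

    /** Lists every k in [2, n] for which the last fold would be empty or negative. */
    static List<Integer> badK(int n) {
        List<Integer> bad = new ArrayList<>();
        for (int k = 2; k <= n; k++) {
            if (currentSizes(n, k)[1] <= 0) {
                bad.add(k);
            }
        }
        return bad;
    }

    public static void main(String[] args) {
        int[] s = currentSizes(13, 6);
        System.out.println("N=13, k=6: batch=" + s[0] + ", lastBatch=" + s[1]);
        System.out.println("k with lastBatch <= 0 for N=13: " + badK(13));
    }
}
```

For N = 13 this reports k = 6, 8, 9, 10, 11 and 12 as producing a non-positive last batch, so the example in the description is not an isolated case.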

My suggested implementation is the following:

public KFoldIterator(int k, DataSet allData) {
	this.k = k;
	this.allData = allData.copy();
	if (k <= 1) {
		throw new IllegalArgumentException();
	}
	this.batch = allData.numExamples() / k;
	this.lastBatch = batch + allData.numExamples() % k;
}

My implementation would give batch = 2 and lastBatch = 3 in the above example.
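A matching sketch of the proposed arithmetic, again on plain ints with hypothetical names (ProposedFoldSizes, proposedSizes), checks for N = 13 that every k from 2 to N yields non-empty folds whose sizes add up to exactly N:

```java
/**
 * Hypothetical sketch of the proposed sizing: the first k-1 folds get
 * floor(N / k) examples and the last fold also absorbs the remainder.
 */
public class ProposedFoldSizes {

    /** Returns {batch, lastBatch} for the proposed constructor. */
    static int[] proposedSizes(int n, int k) {
        int batch = n / k;              // integer division is floor for non-negative ints
        int lastBatch = batch + n % k;  // remainder goes to the last fold
        return new int[] {batch, lastBatch};
    }

    public static void main(String[] args) {
        int n = 13;
        for (int k = 2; k <= n; k++) {
            int[] s = proposedSizes(n, k);
            int total = (k - 1) * s[0] + s[1];
            // Every fold must be non-empty and all folds together must cover N examples.
            if (s[0] <= 0 || s[1] <= 0 || total != n) {
                throw new IllegalStateException("bad split for k=" + k);
            }
        }
        int[] example = proposedSizes(13, 6);
        System.out.println("N=13, k=6: batch=" + example[0] + ", lastBatch=" + example[1]);
    }
}
```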

I have not tried your test (/pull/5993), but I have written a simple test class that passes with my implementation and fails with the current one. It tries every k from 2 to N and sums the number of test examples over all folds, which should always equal N.

import static org.junit.Assert.*;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

import org.nd4j.linalg.dataset.api.iterator.KFoldIterator;

public class KFoldIteratorTest {

	private static final int NUM_EXAMPLES = 13;
	private static final int NUM_FEATURES = 3;
	
	@Test
	public void testTestFold() {
		INDArray features = Nd4j.rand(new int[] {NUM_EXAMPLES, NUM_FEATURES});
		INDArray labels = Nd4j.ones(NUM_EXAMPLES, 1);
		DataSet dataSet = new DataSet(features, labels);
		
		for (int k = 2; k <= NUM_EXAMPLES; k++) {
			KFoldIterator kFoldIterator = new KFoldIterator(k, dataSet);
			int numTestExamples = 0;
			for (int i = 0; i < k; i++) {
				kFoldIterator.next();
				numTestExamples += kFoldIterator.testFold().numExamples();
			}
			assertEquals(NUM_EXAMPLES, numTestExamples);
		}
	}
	
}

The above test of KFoldIterator (version 1.0.0-beta3) fails with the following exception stack trace:

java.lang.IllegalArgumentException: Beginning index (15) in range must be less than or equal to end (13)
	at org.nd4j.base.Preconditions.throwEx(Preconditions.java:636)
	at org.nd4j.base.Preconditions.checkArgument(Preconditions.java:110)
	at org.nd4j.linalg.indexing.NDArrayIndex.interval(NDArrayIndex.java:666)
	at org.nd4j.linalg.indexing.NDArrayIndex.interval(NDArrayIndex.java:698)
	at org.nd4j.linalg.dataset.DataSet.getRange(DataSet.java:237)
	at org.nd4j.linalg.dataset.api.iterator.KFoldIterator.nextFold(KFoldIterator.java:188)
	at org.nd4j.linalg.dataset.api.iterator.KFoldIterator.next(KFoldIterator.java:158)
	at hu.printnet.anne.node.task.impl.KFoldIteratorTest.testTestFold(KFoldIteratorTest.java:28)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at org.junit.runners.model.FrameworkMethod$1.runReflectiveCall(FrameworkMethod.java:50)
	at org.junit.internal.runners.model.ReflectiveCallable.run(ReflectiveCallable.java:12)
	at org.junit.runners.model.FrameworkMethod.invokeExplosively(FrameworkMethod.java:47)
	at org.junit.internal.runners.statements.InvokeMethod.evaluate(InvokeMethod.java:17)
	at org.junit.runners.ParentRunner.runLeaf(ParentRunner.java:325)
	at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:78)
	at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:57)
	at org.junit.runners.ParentRunner$3.run(ParentRunner.java:290)
	at org.junit.runners.ParentRunner$1.schedule(ParentRunner.java:71)
	at org.junit.runners.ParentRunner.runChildren(ParentRunner.java:288)
	at org.junit.runners.ParentRunner.access$000(ParentRunner.java:58)
	at org.junit.runners.ParentRunner$2.evaluate(ParentRunner.java:268)
	at org.junit.runners.ParentRunner.run(ParentRunner.java:363)
	at org.eclipse.jdt.internal.junit4.runner.JUnit4TestReference.run(JUnit4TestReference.java:86)
	at org.eclipse.jdt.internal.junit.runner.TestExecution.run(TestExecution.java:38)
	at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.runTests(RemoteTestRunner.java:459)
	at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.runTests(RemoteTestRunner.java:678)
	at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.run(RemoteTestRunner.java:382)
	at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.main(RemoteTestRunner.java:192)
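The begin index 15 in the message lines up with the sizes from the issue description. A minimal sketch, assuming the last test fold starts right after the first k-1 folds, i.e. at (k-1)*batch (which equals N - lastBatch by construction): for N = 13 the first k in the test loop whose last fold would start past the end of the data is k = 6, where the start is 5 * 3 = 15. The names (LastFoldStart, lastFoldBegin, firstFailingK) are hypothetical, and the sketch does not call KFoldIterator.

```java
/**
 * Hypothetical plain-int sketch of where the last test fold would start under
 * the current sizing. Assumption: the last fold begins at (k-1)*batch,
 * which equals N - lastBatch.
 */
public class LastFoldStart {

    /** batch as computed by the current constructor. */
    static int currentBatch(int n, int k) {
        return (n % k != 0) ? (int) Math.ceil(n / (double) k) : n / k;
    }

    /** Start index of the last test fold: (k-1)*batch == n - lastBatch. */
    static int lastFoldBegin(int n, int k) {
        return (k - 1) * currentBatch(n, k);
    }

    /** Smallest k in [2, n] whose last fold would start past the end of the data, or -1. */
    static int firstFailingK(int n) {
        for (int k = 2; k <= n; k++) {
            if (lastFoldBegin(n, k) > n) {
                return k;
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        int n = 13;
        int k = firstFailingK(n);
        System.out.println("k=" + k + ": begin=" + lastFoldBegin(n, k) + ", end=" + n);
    }
}
```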

@printomi changed the title from "KfoldIterator calculates incorrect batch sizes" to "KFoldIterator calculates incorrect batch sizes" on Nov 28, 2018

@AlexDBlack

Contributor

commented Nov 29, 2018

@printomi What version of DL4J/ND4J are you using?
This looks like an issue that was fixed a few releases ago.
Could you try upgrading to 1.0.0-beta3 and running it again?

@printomi

Author

commented Nov 29, 2018

@AlexDBlack I am using version 1.0.0-beta3 of both ND4J and DL4J. The KFoldIterator constructor that I quoted above is from here

The problem with the current implementation is that when allData.numExamples() % k != 0, i.e. (N mod k) > 0, each of the first k-1 batches gets ceil(N / k) examples, one more than floor(N / k). The last batch size is then computed from this as whatever remains, N - (k-1) * ceil(N / k), which can be zero or negative.

The correct approach is to always use the integer part (floor) of N / k for the first k-1 batches, and N / k + (N mod k) for the last batch. This also eliminates an if clause from the current implementation.

@printomi

Author

commented Nov 30, 2018

The following equations may make my approach easier to understand:

N = allData.numExamples()
N = k*(N / k) + (N % k)
N = (k-1)*(N / k) + (N / k) + (N % k)
batch=N / k
lastBatch=(N / k) + (N % k)

where N / k is the integer quotient and N % k is the remainder (modulo) of the division.

@RobAltena RobAltena self-assigned this Dec 2, 2018

@RobAltena

Contributor

commented Dec 4, 2018

Hello @printomi, thanks for creating the issue; you clearly put a lot of effort into it, which we appreciate. I was able to reproduce the failing test. The new implementation passes this test, so it would be an improvement.

However, all the existing tests fail with the new implementation.

Please have a look at those. Then we can work on doing a pull request to improve the KFoldIterator.

@printomi

Author

commented Dec 5, 2018

I reviewed the existing tests, and found some problems:

  1. The test checkFolds() has a problem with the following line: https://github.com/deeplearning4j/deeplearning4j/blob/d9f8836dd64e3346906a50ff0d4cdfd21a30f37d/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/KFoldIteratorTest.java#L42
    On this line, a RandomDataSet is created with preset fold sizes (3, 3, 3, 2), i.e.
    3+3+3+2 = 11 examples, as used on the next line: https://github.com/deeplearning4j/deeplearning4j/blob/d9f8836dd64e3346906a50ff0d4cdfd21a30f37d/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/KFoldIteratorTest.java#L43
    Then, it creates a KFoldIterator with k = 4: https://github.com/deeplearning4j/deeplearning4j/blob/d9f8836dd64e3346906a50ff0d4cdfd21a30f37d/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/KFoldIteratorTest.java#L44
    Here is the difference, which causes this test to fail: for N = 11 and k = 4, the current implementation produces folds of sizes (3, 3, 3, 2), while my proposed implementation produces (2, 2, 2, 5).

  2. For the same reason, I would replace the following line of checkCornerCaseA() https://github.com/deeplearning4j/deeplearning4j/blob/d9f8836dd64e3346906a50ff0d4cdfd21a30f37d/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/KFoldIteratorTest.java#L89 with the following code:
    RandomDataSet randomDS = new RandomDataSet(new int[] {2, 3}, new int[] {1, 2});

  3. For the same reason, test5974() will have to be modified to the following:

    @Test
    public void test5974(){
        int N = 99;
        int k = 10;
        DataSet ds = new DataSet(Nd4j.linspace(1,N,N).transpose(), Nd4j.linspace(1,N,N).transpose());

        KFoldIterator iter = new KFoldIterator(k, ds);

        int batch = N / k;              // 9: test size of folds 0 to k-2
        int lastBatch = batch + N % k;  // 18: test size of the last fold

        int count = 0;
        while(iter.hasNext()){
            DataSet fold = iter.next();
            INDArray testFold;
            int countTrain;
            if(count < k-1){
                //Folds 0 to 8: should have 9 examples for test (N=99,k=10)
                testFold = Nd4j.linspace(batch*count+1, batch*count+batch, batch).transpose();
                countTrain = N - batch;
            } else {
                //Fold 9 should have 18 examples for test (N=99,k=10): values 82..99
                testFold = Nd4j.linspace(batch*count+1, batch*count+lastBatch, lastBatch).transpose();
                countTrain = N - lastBatch;
            }
            String s = String.valueOf(count);
            DataSet test = iter.testFold();
            assertEquals(s, testFold, test.getFeatures());
            assertEquals(s, testFold, test.getLabels());
            assertEquals(s, countTrain, fold.getFeatures().length());
            assertEquals(s, countTrain, fold.getLabels().length());
            count++;
        }
    }
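The fold sizes behind these three test changes can be listed side by side. This is a sketch with hypothetical names (TestCaseFoldSizes, current, proposed); the (N, k) pairs are assumptions read off the list above: (11, 4) for checkFolds(), (3, 2) for checkCornerCaseA() (inferred from the replacement fold sizes {1, 2}), and (99, 10) for test5974().

```java
import java.util.Arrays;

/**
 * Hypothetical comparison of per-fold test sizes under the current sizing
 * and the first proposed sizing, for the (N, k) pairs assumed above.
 */
public class TestCaseFoldSizes {

    /** Current sizing: ceil(N/k) for the first k-1 folds, the rest in the last one. */
    static int[] current(int n, int k) {
        int batch = (n % k != 0) ? (int) Math.ceil(n / (double) k) : n / k;
        return sizes(k, batch, n - (k - 1) * batch);
    }

    /** First proposal: floor(N/k) for the first k-1 folds, plus the remainder in the last one. */
    static int[] proposed(int n, int k) {
        int batch = n / k;
        return sizes(k, batch, batch + n % k);
    }

    /** Builds {batch, ..., batch, lastBatch} with k entries. */
    private static int[] sizes(int k, int batch, int lastBatch) {
        int[] out = new int[k];
        Arrays.fill(out, batch);
        out[k - 1] = lastBatch;
        return out;
    }

    public static void main(String[] args) {
        int[][] cases = {{11, 4}, {3, 2}, {99, 10}};
        for (int[] c : cases) {
            System.out.println("N=" + c[0] + ", k=" + c[1]
                    + " current=" + Arrays.toString(current(c[0], c[1]))
                    + " proposed=" + Arrays.toString(proposed(c[0], c[1])));
        }
    }
}
```

For (99, 10) the current sizing gives nine folds of 10 and a last fold of 9, while the first proposal gives nine folds of 9 and a last fold of 18, matching the comments in the modified test5974().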

I created an Excel worksheet comparing the batch sizes of the two implementations: KFoldIterator_batch_sizes.xlsx

TL;DR: the existing unit tests (by @eraly) hard-code the batch sizes produced by the current implementation. That's why they all fail with my proposed implementation of KFoldIterator.

@raver119 raver119 added Bug Java labels Dec 5, 2018

@RobAltena

Contributor

commented Dec 5, 2018

Well, that is an impressive effort. Thanks again. Would you submit a Pull Request for KFoldIterator.java and
KFoldIteratorTest.java? You can ping us here if you need any help.

@printomi

Author

commented Dec 6, 2018

I will be happy to create a Pull Request. However, I was not satisfied with my previously proposed implementation, because the batch sizes can differ a lot when 1 < (N % k) < k. For example, with N = 11 and k = 6, batch = 1 and lastBatch = 6.
I found a better implementation in scikit-learn, which sets the size of every batch to N / k and then increments the size of the first N % k batches by 1. This way, batch sizes differ by at most 1, which is good news if we expect the cross-validation to consist of similarly sized training and test sets.
We can implement this with the following KFoldIterator constructor and nextFold() method:

// new fields
private int N;                      // total number of examples
private int baseBatchSize;          // N / k: minimum size of every test fold
private int numIncrementedBatches;  // N % k: the first N % k folds get one extra example
/* ... */
// modified constructor
public KFoldIterator(int k, DataSet allData) {
    if (k <= 1) {
        throw new IllegalArgumentException();
    }
    this.k = k;
    this.N = allData.numExamples();
    this.baseBatchSize = N / k;
    this.numIncrementedBatches = N % k;
    this.allData = allData.copy();
}
/* ... */
private void nextFold() {
    int left;
    int right;
    if (kCursor < numIncrementedBatches) {
        // One of the first N % k folds: baseBatchSize + 1 examples
        left = kCursor * (baseBatchSize + 1);
        right = left + (baseBatchSize + 1);
    } else {
        // Remaining folds: baseBatchSize examples, placed after the enlarged folds
        left = numIncrementedBatches * (baseBatchSize + 1) + (kCursor - numIncrementedBatches) * baseBatchSize;
        right = left + baseBatchSize;
    }

    // Training set: everything outside the test range [left, right)
    List<DataSet> kMinusOneFoldList = new ArrayList<DataSet>();
    if (right < totalExamples()) {
        if (left > 0) {
            kMinusOneFoldList.add((DataSet) allData.getRange(0, left));
        }
        kMinusOneFoldList.add((DataSet) allData.getRange(right, totalExamples()));
        train = DataSet.merge(kMinusOneFoldList);
    } else {
        train = (DataSet) allData.getRange(0, left);
    }
    test = (DataSet) allData.getRange(left, right);

    kCursor++;

}

We will then have to modify the tests as well.
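The balanced split can be checked without ND4J by computing only the half-open test ranges [left, right) that nextFold() would pass to getRange(). This is a sketch under that assumption; BalancedFolds, testRange, foldSizes and coversExactly are hypothetical names. It prints the fold sizes for the two examples discussed in this thread and verifies that the ranges tile [0, N) without gaps or overlaps.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of the scikit-learn-style split: every fold gets N / k
 * examples and the first N % k folds get one extra. Only the test ranges
 * [left, right) are computed, using plain ints instead of ND4J DataSets.
 */
public class BalancedFolds {

    /** Returns {left, right} of the test range for fold number kCursor (0-based). */
    static int[] testRange(int n, int k, int kCursor) {
        int baseBatchSize = n / k;
        int numIncrementedBatches = n % k;
        int left;
        int right;
        if (kCursor < numIncrementedBatches) {
            left = kCursor * (baseBatchSize + 1);
            right = left + baseBatchSize + 1;
        } else {
            left = numIncrementedBatches * (baseBatchSize + 1)
                    + (kCursor - numIncrementedBatches) * baseBatchSize;
            right = left + baseBatchSize;
        }
        return new int[] {left, right};
    }

    /** Sizes of all k test folds. */
    static int[] foldSizes(int n, int k) {
        int[] sizes = new int[k];
        for (int i = 0; i < k; i++) {
            int[] r = testRange(n, k, i);
            sizes[i] = r[1] - r[0];
        }
        return sizes;
    }

    /** True if the k test ranges are contiguous and exactly cover [0, n). */
    static boolean coversExactly(int n, int k) {
        int expectedLeft = 0;
        for (int i = 0; i < k; i++) {
            int[] r = testRange(n, k, i);
            if (r[0] != expectedLeft) {
                return false;
            }
            expectedLeft = r[1];
        }
        return expectedLeft == n;
    }

    public static void main(String[] args) {
        System.out.println("N=11, k=6: " + Arrays.toString(foldSizes(11, 6)));
        System.out.println("N=13, k=6: " + Arrays.toString(foldSizes(13, 6)));
        for (int n = 2; n <= 50; n++) {
            for (int k = 2; k <= n; k++) {
                if (!coversExactly(n, k)) {
                    throw new IllegalStateException("gap or overlap for N=" + n + ", k=" + k);
                }
            }
        }
        System.out.println("All folds tile [0, N) for 2 <= k <= N <= 50");
    }
}
```

With N = 11 and k = 6 the sizes become (2, 2, 2, 2, 2, 1) instead of (1, 1, 1, 1, 1, 6) under the first proposal, so no fold differs from another by more than one example.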

@RobAltena

Contributor

commented Dec 6, 2018

Looks good. Let's take the conversation to the PR; there is no point copying and pasting code here. I am not closing this issue yet, as I am not sure whether WIP Pull Requests are the way to go. We will find out quickly if we are not supposed to.

@lock


commented Jan 13, 2019

This thread has been automatically locked since there has not been any recent activity after it was closed. Please open a new issue for related bugs.

@lock lock bot locked and limited conversation to collaborators Jan 13, 2019

4 participants