KFoldIterator - Splits wrongly, last fold is usually too small or even zero #5974

Closed
boschhd opened this issue Jul 26, 2018 · 2 comments

boschhd commented Jul 26, 2018

Expected behavior: KFoldIterator should split the dataset into k folds as evenly as possible.
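One way to read "as evenly as possible" is that fold sizes differ by at most one example. The following is a minimal sketch of that rule, not the KFoldIterator implementation; the class `EvenFoldSizes` and the method `foldSizes` are hypothetical names used only for illustration.

```java
import java.util.Arrays;

class EvenFoldSizes {

    // Hypothetical helper, not part of ND4J: sizes of k folds for n examples,
    // where the first (n % k) folds receive one extra example.
    static int[] foldSizes(int n, int k) {
        if (k < 2 || n < k) {
            throw new IllegalArgumentException("Need k >= 2 and n >= k");
        }
        int base = n / k;       // minimum size of every fold
        int remainder = n % k;  // number of folds that get one extra example
        int[] sizes = new int[k];
        for (int i = 0; i < k; i++) {
            sizes[i] = base + (i < remainder ? 1 : 0);
        }
        return sizes;
    }

    public static void main(String[] args) {
        // prints [10, 10, 10, 10, 10, 10, 10, 10, 10, 9]
        System.out.println(Arrays.toString(foldSizes(99, 10)));
    }
}
```

With this rule no fold is ever empty as long as n >= k, and the largest and smallest folds differ by at most one example.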

Observed behavior: The last fold is often very small, in the range of 0..(k-2), which could explain the high variance of results in issue #5343.

Explanation:
If the dataset size (n) is divisible by the desired number of folds (k) without remainder, it splits the dataset evenly.
However, if there is a remainder, it divides the data into only (k-1) full folds and assigns what is left over to the last fold. This creates a last fold of only n modulo (k-1) elements, that is, at most k-2! This means that the test set of the last fold will be extremely small and will probably produce NaN results in evaluation due to missing classes.

In the case that n is not divisible by k but is divisible by k-1 (for example n = 99, k = 10), this even creates an empty last fold, which will cause exceptions later on.
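The arithmetic behind this explanation can be sketched as follows. This is a reconstruction from the description in this issue, not the actual KFoldIterator source; the class `DescribedFoldSizes` and the method `foldSizes` are hypothetical names.

```java
import java.util.Arrays;

class DescribedFoldSizes {

    // Fold sizes as described in this issue (assumed behavior, not copied from ND4J):
    // an even split when n % k == 0, otherwise (k-1) folds of size n / (k-1)
    // and a last fold holding the remaining n % (k-1) examples.
    static int[] foldSizes(int n, int k) {
        if (k < 2) {
            throw new IllegalArgumentException("Need k >= 2");
        }
        int[] sizes = new int[k];
        if (n % k == 0) {
            Arrays.fill(sizes, n / k);
        } else {
            Arrays.fill(sizes, 0, k - 1, n / (k - 1));
            sizes[k - 1] = n % (k - 1);
        }
        return sizes;
    }

    public static void main(String[] args) {
        // prints [10, 10, 10, 10, 10, 10, 10, 10, 10, 10]
        System.out.println(Arrays.toString(foldSizes(100, 10)));
        // prints [11, 11, 11, 11, 11, 11, 11, 11, 11, 2]
        System.out.println(Arrays.toString(foldSizes(101, 10)));
        // prints [11, 11, 11, 11, 11, 11, 11, 11, 11, 0]
        System.out.println(Arrays.toString(foldSizes(99, 10)));
    }
}
```

The last call shows the empty fold: 99 is not divisible by 10, but it is divisible by 9, so nothing is left over for the tenth fold.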

Related lines in the file

Thanks to @RajaniVM for noticing the problem!

@AlexDBlack
Contributor

Thanks for the issue (and @RajaniVM for flagging) - easy to confirm with size 99, 10 splits:

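    @Test
    public void test5974(){
        // 99 examples; features and labels both hold the values 1..99
        DataSet ds = new DataSet(Nd4j.linspace(1,99,99).transpose(), Nd4j.linspace(1,99,99).transpose());

        // 10 folds: 99 is not divisible by 10, but it is divisible by 9
        KFoldIterator iter = new KFoldIterator(10, ds);

        // Iterating over the folds throws the exception shown below
        while(iter.hasNext()){
            DataSet fold = iter.next();
            System.out.println(fold);
        }
    }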

    java.lang.IllegalArgumentException: NDArrayIndex is out of range. Beginning index: 99 must be less than its size: 99
        at org.nd4j.linalg.indexing.NDArrayIndex.validate(NDArrayIndex.java:440)
        at org.nd4j.linalg.indexing.NDArrayIndex.resolve(NDArrayIndex.java:345)
        at org.nd4j.linalg.api.ndarray.BaseNDArray.get(BaseNDArray.java:5083)
        at org.nd4j.linalg.dataset.DataSet.getRange(DataSet.java:236)
        at org.nd4j.linalg.dataset.api.iterator.KFoldIterator.nextFold(KFoldIterator.java:193)
        at org.nd4j.linalg.dataset.api.iterator.KFoldIterator.next(KFoldIterator.java:163)
        at org.nd4j.linalg.dataset.KFoldIteratorTest.test5974(KFoldIteratorTest.java:182)

lock bot commented Sep 21, 2018

This thread has been automatically locked since there has not been any recent activity after it was closed. Please open a new issue for related bugs.
