Fix incorrect batch sizes in KFoldIterator #6810

Merged
merged 16 commits into eclipse:master on Dec 14, 2018


@printomi
commented Dec 6, 2018

What changes were proposed in this pull request?

This PR will fix #6786
KFoldIterator has been modified following the scikit-learn implementation.

N samples are split into k batches. The first (N%k) batches contain (N/k)+1 samples, while the remaining batches contain (N/k) samples, where N/k is integer division.
If the number of samples (N) in the dataset is a multiple of k, all batches contain (N/k) samples.
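
As an illustration of the split rule described above, here is a minimal Java sketch. It is not the code from this PR: the class name `KFoldBatchSizes` and the helpers `batchSizes` and `batchStarts` are hypothetical, and the requirement that k lies between 1 and N is an assumption of the sketch, not something stated in the PR.

```java
import java.util.Arrays;

// Hypothetical helper, not part of KFoldIterator: it only illustrates the split rule.
class KFoldBatchSizes {

    /** Sizes of the k batches for n samples: the first n % k batches get one extra sample. */
    static int[] batchSizes(int n, int k) {
        // Assumption of this sketch: 1 <= k <= n, so that no batch is empty.
        if (k < 1 || k > n) {
            throw new IllegalArgumentException("k must be between 1 and n");
        }
        int base = n / k;       // integer division
        int remainder = n % k;  // number of batches that receive base + 1 samples
        int[] sizes = new int[k];
        for (int i = 0; i < k; i++) {
            sizes[i] = i < remainder ? base + 1 : base;
        }
        return sizes;
    }

    /** Start offset of each batch, i.e. the running sum of the preceding batch sizes. */
    static int[] batchStarts(int[] sizes) {
        int[] starts = new int[sizes.length];
        int offset = 0;
        for (int i = 0; i < sizes.length; i++) {
            starts[i] = offset;
            offset += sizes[i];
        }
        return starts;
    }

    public static void main(String[] args) {
        int[] sizes = batchSizes(10, 3);
        System.out.println(Arrays.toString(sizes));              // [4, 3, 3]
        System.out.println(Arrays.toString(batchStarts(sizes))); // [0, 4, 7]
    }
}
```

With N=10 and k=3, the first batch holds 4 samples and the other two hold 3 each, so every sample is used as test data exactly once across the k folds.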

Visibility of private fields and methods in KFoldIterator has been changed to protected, so that this class can be extended in the future. (Stratified k-fold and leave-k-out cross-validation iterators would be nice.)

An unnecessary copy of the DataSet has been eliminated from the KFoldIterator constructor. However, the next() method will maintain a copy in the fields train and test.

Unit tests have been modified too: enhancements, improved documentation comments, and a new test case.

How was this patch tested?

Manual test run in the directory deeplearning4j/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native with the following command:

mvn -e test -Dtest=KFoldIteratorTest -P testresources
printomi added 2 commits Dec 6, 2018
fix incorrect batch sizes in KFoldIterator, by following the scikit-learn implementation of sklearn.model_selection.KFold
fix comments in KFoldIterator, and make it extendable by changing visibility of private fields and methods to protected

@RobAltena RobAltena self-assigned this Dec 6, 2018

@RobAltena RobAltena added Bug Java labels Dec 6, 2018

@RobAltena

Contributor

commented Dec 6, 2018

This needs to be in the source code comments too.

I modified KFoldIterator following the scikit-learn implementation.

The 2nd file we should modify in this PR is the KFoldIteratorTest.java.

btw: I literally just did the KFoldIterator Example, which had been an open ticket since February.

@printomi

Author

commented Dec 7, 2018

Thanks for the comment, I will include a comment about this major change of implementation in the author field of the KFoldIterator JavaDoc.
I will also use your minor maintenance changes in the unit test update.

printomi added 3 commits Dec 7, 2018
Change naming of KFoldIteratorTest.RandomDataSet to KFoldIteratorTest.KBatchRandomDataSet, and its variables from fold to batch. Improve documentation.
@printomi

Author

commented Dec 7, 2018

Finally, I updated KFoldIteratorTest, and simplified KFoldIterator so that the main logic is more straightforward.
I see that I made a conflicting change: it is nothing breaking, just a function that is now placed above an inner class. I will try to resolve it by moving this function back to its original place in the file.

printomi added 2 commits Dec 7, 2018

@printomi printomi changed the title from "[WIP] Fix incorrect batch sizes in KFoldIterator" to "Fix incorrect batch sizes in KFoldIterator" Dec 7, 2018

printomi added 3 commits Dec 7, 2018
Merge remote-tracking branch 'upstream/master'
Merged upstream, because the previous code contained compilation errors
(missing symbols in Java files)
@printomi

Author

commented Dec 8, 2018

I have just removed an unnecessary copy of the DataSet from KFoldIterator.
I don't think I can improve this PR any further, so a review would be great now.
@RobAltena , could you check if the tests pass?

printomi added 2 commits Dec 10, 2018
@printomi

Author

commented Dec 10, 2018

A manual test is in progress. I have just merged the upstream branch. I am rebuilding libnd4j, then recompiling and running the tests.

@printomi

Author

commented Dec 10, 2018

All tests passed with latest code. I used the following command in the directory deeplearning4j/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native to run the tests:

mvn -e test -Dtest=KFoldIteratorTest -P testresources
printomi added 3 commits Dec 11, 2018
@RobAltena
Contributor

left a comment

I have just the one tiny issue with the PR, but I am struggling with my local dev environment.

I could build nd4j locally and make sure the tests pass. When I then fetch in this pull request and run the tests I get an error:

java.lang.IllegalArgumentException: NDArrayIndex is out of range. Beginning index: 42 must be less than its size: 42
    at org.nd4j.linalg.dataset.KFoldIteratorTest.checkTestFoldContent(KFoldIteratorTest.java:64)

That must have run the tests against the old version. I tried this again and this time the old tests passed on the old version and the new tests passed on the new version. So we just have the one typo in the comments to fix.

* Create a k-fold cross-validation iterator given the dataset and k=10 train-test splits.
* N number of samples are split into k batches. The first (N%k) batches contain (N/k)+1 samples, while the remaining batches contain (N/k) samples.
* In case the number of samples (N) in the dataset is a multiple of k, all batches will contain (N/k) samples.
* @param k number of folds (optional, defaults to 10)

@RobAltena

RobAltena Dec 13, 2018

Contributor

There is no parameter k in this constructor.

@RobAltena
Contributor

left a comment

See the comment above for one tiny correction in the comments.
I was able to build locally and pass the tests.

@agibsonccc agibsonccc merged commit 4670eee into eclipse:master Dec 14, 2018

1 of 2 checks passed

codeclimate Code Climate encountered an error attempting to analyze this pull request.
Codacy/PR Quality Review Up to standards. A positive pull request.