diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetSplitterTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetSplitterTests.java index 89112b56ce06..149736055189 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetSplitterTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetSplitterTests.java @@ -20,9 +20,15 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator; import org.junit.Test; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.exception.ND4JIllegalStateException; +import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.assertEquals; +import java.util.Collections; +import java.util.List; +import java.util.Random; + +import static org.junit.Assert.*; public class DataSetSplitterTests extends BaseDL4JTest { @Test @@ -39,7 +45,7 @@ public void testSplitter_1() throws Exception { int gcntTest = 0; int global = 0; // emulating epochs here - for (int e = 0; e < numEpochs; e++){ + for (int e = 0; e < numEpochs; e++) { int cnt = 0; while (train.hasNext()) { val data = train.next().getFeatures(); @@ -79,7 +85,7 @@ public void testSplitter_2() throws Exception { int gcntTest = 0; int global = 0; // emulating epochs here - for (int e = 0; e < numEpochs; e++){ + for (int e = 0; e < numEpochs; e++) { int cnt = 0; while (train.hasNext()) { val data = train.next().getFeatures(); @@ -117,7 +123,7 @@ public void testSplitter_3() throws Exception { int gcntTest = 0; int global = 0; // emulating epochs here - for (int e = 0; e < numEpochs; e++){ + for (int e = 0; e < numEpochs; e++) { int cnt = 0; while (train.hasNext()) { val data = train.next().getFeatures(); @@ -144,4 +150,245 @@ public void testSplitter_3() throws Exception { assertEquals(1000 * 
numEpochs, global); } + + @Test + public void testSplitter_4() { + val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5}); + + val splitter = new DataSetIteratorSplitter(back, 1000, new double[]{0.5, 0.3, 0.2}); + List iteratorList = splitter.getIterators(); + val numEpochs = 10; + int global = 0; + // emulating epochs here + for (int e = 0; e < numEpochs; e++) { + int iterNo = 0; + int perEpoch = 0; + for (val partIterator : iteratorList) { + int cnt = 0; + partIterator.reset(); + while (partIterator.hasNext()) { + val data = partIterator.next().getFeatures(); + assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, + (float) perEpoch, data.getFloat(0), 1e-5); + //gcntTrain++; + global++; + cnt++; + ++perEpoch; + } + ++iterNo; + } + } + + assertEquals(1000* numEpochs, global); + } + + @Test + public void testSplitter_5() { + val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5}); + + val splitter = new DataSetIteratorSplitter(back, new int[]{900, 100}); + + List iteratorList = splitter.getIterators(); + val numEpochs = 10; + + int global = 0; + // emulating epochs here + for (int e = 0; e < numEpochs; e++) { + int iterNo = 0; + int perEpoch = 0; + for (val partIterator : iteratorList) { + partIterator.reset(); + while (partIterator.hasNext()) { + int cnt = 0; + val data = partIterator.next().getFeatures(); + + assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, + (float) perEpoch, data.getFloat(0), 1e-5); + //gcntTrain++; + global++; + cnt++; + ++perEpoch; + } + ++iterNo; + } + } + + assertEquals(1000 * numEpochs, global); + } + + @Test + public void testSplitter_6() { + val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5}); + + // we're going to mimic train+test+validation split + val splitter = new DataSetIteratorSplitter(back, new int[]{800, 100, 100}); + + assertEquals(3, splitter.getIterators().size()); + + val trainIter = splitter.getIterators().get(0); + val testIter 
= splitter.getIterators().get(1); + val validationIter = splitter.getIterators().get(2); + + // we're going to have multiple epochs + int numEpochs = 10; + for (int e = 0; e < numEpochs; e++) { + int globalIter = 0; + trainIter.reset(); + testIter.reset(); + validationIter.reset(); + + boolean trained = false; + while (trainIter.hasNext()) { + trained = true; + val ds = trainIter.next(); + assertNotNull(ds); + + assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f); + globalIter++; + } + assertTrue("Failed at epoch [" + e + "]", trained); + assertEquals(800, globalIter); + + + // test set is used every epoch + boolean tested = false; + //testIter.reset(); + while (testIter.hasNext()) { + tested = true; + val ds = testIter.next(); + assertNotNull(ds); + + assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f); + globalIter++; + } + assertTrue("Failed at epoch [" + e + "]", tested); + assertEquals(900, globalIter); + + // validation set is used every 5 epochs + if (e % 5 == 0) { + boolean validated = false; + //validationIter.reset(); + while (validationIter.hasNext()) { + validated = true; + val ds = validationIter.next(); + assertNotNull(ds); + + assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f); + globalIter++; + } + assertTrue("Failed at epoch [" + e + "]", validated); + } + + // all 3 iterators have exactly 1000 elements combined + if (e % 5 == 0) + assertEquals(1000, globalIter); + else + assertEquals(900, globalIter); + trainIter.reset(); + } + } + + @Test + public void testUnorderedSplitter_1() { + val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5}); + + val splitter = new DataSetIteratorSplitter(back, new int[]{500, 500}); + + List iteratorList = splitter.getIterators(); + val numEpochs = 10; + + int global = 0; + // emulating epochs here + for (int 
e = 0; e < numEpochs; e++) { + + // Get data from second part, then rewind for the first one. + int cnt = 0; + int partNumber = 1; + while (iteratorList.get(partNumber).hasNext()) { + int farCnt = (1000 / 2) * (partNumber) + cnt; + val data = iteratorList.get(partNumber).next().getFeatures(); + + assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) farCnt, data.getFloat(0), 1e-5); + cnt++; + global++; + } + iteratorList.get(partNumber).reset(); + partNumber = 0; + cnt = 0; + while (iteratorList.get(0).hasNext()) { + val data = iteratorList.get(0).next().getFeatures(); + + assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); + global++; + } + } + } + + @Test + public void testUnorderedSplitter_2() { + val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5}); + + val splitter = new DataSetIteratorSplitter(back, new int[]{2}); + + List iteratorList = splitter.getIterators(); + + for (int partNumber = 0 ; partNumber < iteratorList.size(); ++partNumber) { + int cnt = 0; + while (iteratorList.get(partNumber).hasNext()) { + val data = iteratorList.get(partNumber).next().getFeatures(); + + assertEquals("Train failed on iteration " + cnt, (float) (500*partNumber + cnt), data.getFloat(0), 1e-5); + cnt++; + } + } + } + + @Test + public void testUnorderedSplitter_3() { + val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5}); + + val splitter = new DataSetIteratorSplitter(back, new int[]{10}); + + List iteratorList = splitter.getIterators(); + Random random = new Random(); + int[] indexes = new int[iteratorList.size()]; + for (int i = 0; i < indexes.length; ++i) { + indexes[i] = random.nextInt(iteratorList.size()); + } + + for (int partNumber : indexes) { + int cnt = 0; + while (iteratorList.get(partNumber).hasNext()) { + val data = iteratorList.get(partNumber).next().getFeatures(); + + assertEquals("Train failed on iteration " + cnt, (float) (500*partNumber + 
cnt), data.getFloat(0), 1e-5); + cnt++; + } + } + } + + @Test + public void testUnorderedSplitter_4() { + val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5}); + + // we're going to mimic train+test+validation split + val splitter = new DataSetIteratorSplitter(back, new int[]{80, 10, 5}); + + assertEquals(3, splitter.getIterators().size()); + + val trainIter = splitter.getIterators().get(0); // 0..79 + val testIter = splitter.getIterators().get(1); // 80 ..89 + val validationIter = splitter.getIterators().get(2); // 90..94 + + // we're skipping train/test and go for validation first. we're that crazy, right. + int valCnt = 0; + while (validationIter.hasNext()) { + val ds = validationIter.next(); + assertNotNull(ds); + + assertEquals("Validation failed on iteration " + valCnt, (float) valCnt + 90, ds.getFeatures().getFloat(0), 1e-5); + valCnt++; + } + assertEquals(5, valCnt); + } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultiDataSetSplitterTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultiDataSetSplitterTests.java index 6f624ecfda75..2e2853133a19 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultiDataSetSplitterTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultiDataSetSplitterTests.java @@ -18,11 +18,17 @@ import lombok.val; import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator; import org.deeplearning4j.datasets.iterator.tools.MultiDataSetGenerator; import org.junit.Test; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.exception.ND4JIllegalStateException; -import static org.junit.Assert.assertEquals; +import java.util.List; +import java.util.Random; + +import static 
org.junit.Assert.*; /** * @@ -150,4 +156,309 @@ public void testSplitter_3() throws Exception { assertEquals(1000 * numEpochs, global); } + + @Test + public void testMultiSplitter_1() { + val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5}); + + // we're going to mimic train+test+validation split + val splitter = new MultiDataSetIteratorSplitter(back, new int[]{800, 100, 100}); + + assertEquals(3, splitter.getIterators().size()); + + val trainIter = splitter.getIterators().get(0); + val testIter = splitter.getIterators().get(1); + val validationIter = splitter.getIterators().get(2); + + // we're going to have multiple epochs + int numEpochs = 10; + for (int e = 0; e < numEpochs; e++) { + int globalIter = 0; + trainIter.reset(); + testIter.reset(); + validationIter.reset(); + + boolean trained = false; + while (trainIter.hasNext()) { + trained = true; + val ds = trainIter.next(); + assertNotNull(ds); + + for (int i = 0; i < ds.getFeatures().length; ++i) { + assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f); + } + globalIter++; + } + assertTrue("Failed at epoch [" + e + "]", trained); + assertEquals(800, globalIter); + + + // test set is used every epoch + boolean tested = false; + //testIter.reset(); + while (testIter.hasNext()) { + tested = true; + val ds = testIter.next(); + assertNotNull(ds); + + for (int i = 0; i < ds.getFeatures().length; ++i) { + assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f); + } + globalIter++; + } + assertTrue("Failed at epoch [" + e + "]", tested); + assertEquals(900, globalIter); + + // validation set is used every 5 epochs + if (e % 5 == 0) { + boolean validated = false; + //validationIter.reset(); + while (validationIter.hasNext()) { + validated = true; + val ds = validationIter.next(); + assertNotNull(ds); + + for (int i = 0; i < ds.getFeatures().length; ++i) { + 
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f); + } + globalIter++; + } + assertTrue("Failed at epoch [" + e + "]", validated); + } + + // all 3 iterators have exactly 1000 elements combined + if (e % 5 == 0) + assertEquals(1000, globalIter); + else + assertEquals(900, globalIter); + trainIter.reset(); + } + } + + @Test + public void testSplitter_5() { + val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5}); + + val splitter = new MultiDataSetIteratorSplitter(back, new int[]{900, 100}); + + List iteratorList = splitter.getIterators(); + val numEpochs = 10; + + int global = 0; + // emulating epochs here + for (int e = 0; e < numEpochs; e++) { + int iterNo = 0; + int perEpoch = 0; + for (val partIterator : iteratorList) { + partIterator.reset(); + while (partIterator.hasNext()) { + int cnt = 0; + val data = partIterator.next().getFeatures(); + + for (int i = 0; i < data.length; ++i) { + assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, + (float) perEpoch, data[i].getFloat(0), 1e-5); + } + //gcntTrain++; + global++; + cnt++; + ++perEpoch; + } + ++iterNo; + } + } + + assertEquals(1000 * numEpochs, global); + } + + @Test + public void testSplitter_6() { + val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5}); + + // we're going to mimic train+test+validation split + val splitter = new MultiDataSetIteratorSplitter(back, new int[]{800, 100, 100}); + + assertEquals(3, splitter.getIterators().size()); + + val trainIter = splitter.getIterators().get(0); + val testIter = splitter.getIterators().get(1); + val validationIter = splitter.getIterators().get(2); + + // we're going to have multiple epochs + int numEpochs = 10; + for (int e = 0; e < numEpochs; e++) { + int globalIter = 0; + trainIter.reset(); + testIter.reset(); + validationIter.reset(); + + boolean trained = false; + while (trainIter.hasNext()) { + trained = true; + val ds = 
trainIter.next(); + assertNotNull(ds); + + for (int i = 0; i < ds.getFeatures().length; ++i) { + assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, + ds.getFeatures()[i].getDouble(0), 1e-5f); + } + globalIter++; + } + assertTrue("Failed at epoch [" + e + "]", trained); + assertEquals(800, globalIter); + + + // test set is used every epoch + boolean tested = false; + //testIter.reset(); + while (testIter.hasNext()) { + tested = true; + val ds = testIter.next(); + assertNotNull(ds); + for (int i = 0; i < ds.getFeatures().length; ++i) { + assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f); + } + globalIter++; + } + assertTrue("Failed at epoch [" + e + "]", tested); + assertEquals(900, globalIter); + + // validation set is used every 5 epochs + if (e % 5 == 0) { + boolean validated = false; + //validationIter.reset(); + while (validationIter.hasNext()) { + validated = true; + val ds = validationIter.next(); + assertNotNull(ds); + + for (int i = 0; i < ds.getFeatures().length; ++i) { + assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, + ds.getFeatures()[i].getDouble(0), 1e-5f); + } + globalIter++; + } + assertTrue("Failed at epoch [" + e + "]", validated); + } + + // all 3 iterators have exactly 1000 elements combined + if (e % 5 == 0) + assertEquals(1000, globalIter); + else + assertEquals(900, globalIter); + trainIter.reset(); + } + } + + @Test + public void testUnorderedSplitter_1() { + val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5}); + + val splitter = new MultiDataSetIteratorSplitter(back, new int[]{500, 500}); + + List iteratorList = splitter.getIterators(); + val numEpochs = 10; + + int global = 0; + // emulating epochs here + for (int e = 0; e < numEpochs; e++) { + + // Get data from second part, then rewind for the first one. 
+ int cnt = 0; + int partNumber = 1; + while (iteratorList.get(partNumber).hasNext()) { + int farCnt = (1000 / 2) * (partNumber) + cnt; + val data = iteratorList.get(partNumber).next().getFeatures(); + for (int i = 0; i < data.length; ++i) { + assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) farCnt, data[i].getFloat(0), 1e-5); + } + cnt++; + global++; + } + iteratorList.get(partNumber).reset(); + partNumber = 0; + cnt = 0; + while (iteratorList.get(0).hasNext()) { + val data = iteratorList.get(0).next().getFeatures(); + for (int i = 0; i < data.length; ++i) { + assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, + data[i].getFloat(0), 1e-5); + } + global++; + } + } + } + + @Test + public void testUnorderedSplitter_2() { + val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5}); + + val splitter = new MultiDataSetIteratorSplitter(back, new int[]{2}); + + List iteratorList = splitter.getIterators(); + + for (int partNumber = 0 ; partNumber < iteratorList.size(); ++partNumber) { + int cnt = 0; + while (iteratorList.get(partNumber).hasNext()) { + val data = iteratorList.get(partNumber).next().getFeatures(); + for (int i = 0; i < data.length; ++i) { + assertEquals("Train failed on iteration " + cnt, (float) (500 * partNumber + cnt), data[i].getFloat(0), 1e-5); + } + cnt++; + } + } + } + + @Test + public void testUnorderedSplitter_3() { + val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5}); + + val splitter = new MultiDataSetIteratorSplitter(back, new int[]{10}); + + List iteratorList = splitter.getIterators(); + Random random = new Random(); + int[] indexes = new int[iteratorList.size()]; + for (int i = 0; i < indexes.length; ++i) { + indexes[i] = random.nextInt(iteratorList.size()); + } + + for (int partNumber : indexes) { + int cnt = 0; + while (iteratorList.get(partNumber).hasNext()) { + val data = iteratorList.get(partNumber).next().getFeatures(); + for (int 
i = 0; i < data.length; ++i) { + assertEquals("Train failed on iteration " + cnt, (float) (500 * partNumber + cnt), + data[i].getFloat(0), 1e-5); + } + cnt++; + } + } + } + + @Test + public void testUnorderedSplitter_4() { + val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5}); + + // we're going to mimic train+test+validation split + val splitter = new MultiDataSetIteratorSplitter(back, new int[]{80, 10, 5}); + + assertEquals(3, splitter.getIterators().size()); + + val trainIter = splitter.getIterators().get(0); // 0..79 + val testIter = splitter.getIterators().get(1); // 80 ..89 + val validationIter = splitter.getIterators().get(2); // 90..94 + + // we're skipping train/test and go for validation first. we're that crazy, right. + int valCnt = 0; + while (validationIter.hasNext()) { + val ds = validationIter.next(); + assertNotNull(ds); + for (int i = 0; i < ds.getFeatures().length; ++i) { + assertEquals("Validation failed on iteration " + valCnt, (float) valCnt + 90, + ds.getFeatures()[i].getFloat(0), 1e-5); + } + valCnt++; + } + assertEquals(5, valCnt); + } } diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/DataSetIteratorSplitter.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/DataSetIteratorSplitter.java index 6248aa4a19a6..ac03d5cecc7b 100644 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/DataSetIteratorSplitter.java +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/DataSetIteratorSplitter.java @@ -24,6 +24,7 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.exception.ND4JIllegalStateException; +import java.util.ArrayList; import java.util.List; import 
java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; @@ -42,14 +43,20 @@ public class DataSetIteratorSplitter { protected DataSetIterator backedIterator; protected final long totalExamples; protected final double ratio; + protected final double[] ratios; protected final long numTrain; protected final long numTest; + protected final long numArbitrarySets; + protected final int[] splits; + protected AtomicLong counter = new AtomicLong(0); protected AtomicBoolean resetPending = new AtomicBoolean(false); protected DataSet firstTrain = null; + protected int partNumber = 0; + /** * The only constructor * @@ -71,17 +78,94 @@ public DataSetIteratorSplitter(@NonNull DataSetIterator baseIterator, long total this.backedIterator = baseIterator; this.totalExamples = totalBatches; this.ratio = ratio; + this.ratios = null; this.numTrain = (long) (totalExamples * ratio); this.numTest = totalExamples - numTrain; + this.numArbitrarySets = 2; + this.splits = null; log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!"); } + public DataSetIteratorSplitter(@NonNull DataSetIterator baseIterator, long totalBatches, double[] ratios) { + for (double ratio : ratios) { + if (!(ratio > 0.0 && ratio < 1.0)) + throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 > X < 1.0"); + } + + if (totalBatches < 0) + throw new ND4JIllegalStateException("totalExamples number should be positive value"); + + if (!baseIterator.resetSupported()) + throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split"); + + + this.backedIterator = baseIterator; + this.totalExamples = totalBatches; + this.ratio = 0.0; + this.ratios = ratios; + this.numTrain = 0; //(long) (totalExamples * ratio); + this.numTest = 0; //totalExamples - numTrain; + this.numArbitrarySets = ratios.length; + + this.splits = new int[this.ratios.length]; + for (int i = 0; i < 
this.splits.length; ++i) { + this.splits[i] = (int)(totalExamples * ratios[i]); + } + + log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!"); + } + + public DataSetIteratorSplitter(@NonNull DataSetIterator baseIterator, int[] splits) { + + /*if (!(simpleRatio > 0.0 && simpleRatio < 1.0)) + throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 > X < 1.0");*/ + + int totalBatches = 0; + for (val v:splits) + totalBatches += v; + + if (totalBatches < 0) + throw new ND4JIllegalStateException("totalExamples number should be positive value"); + + if (!baseIterator.resetSupported()) + throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split"); + + + this.backedIterator = baseIterator; + this.totalExamples = totalBatches; + this.ratio = 0.0; + this.ratios = null; + + this.numTrain = 0; //(long) (totalExamples * ratio); + this.numTest = 0; //totalExamples - numTrain; + this.splits = splits; + this.numArbitrarySets = splits.length; + + log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!"); + } + + public List getIterators() { + List retVal = new ArrayList<>(); + int partN = 0; + int bottom = 0; + for (final int split : splits) { + ScrollableDataSetIterator partIterator = + new ScrollableDataSetIterator(partN++, backedIterator, counter, resetPending, firstTrain, + new int[]{bottom,split}); + bottom += split; + retVal.add(partIterator); + } + return retVal; + } + + /** * This method returns train iterator instance * * @return */ + @Deprecated public DataSetIterator getTrainIterator() { return new DataSetIterator() { @Override @@ -184,6 +268,7 @@ public DataSet next() { * * @return */ + @Deprecated public DataSetIterator getTestIterator() { return new DataSetIterator() { @Override diff --git 
a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultiDataSetIteratorSplitter.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultiDataSetIteratorSplitter.java index b233faeac4e5..effa77f05745 100644 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultiDataSetIteratorSplitter.java +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultiDataSetIteratorSplitter.java @@ -21,9 +21,12 @@ import lombok.val; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.exception.ND4JIllegalStateException; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; @@ -43,6 +46,9 @@ public class MultiDataSetIteratorSplitter { protected final double ratio; protected final long numTrain; protected final long numTest; + protected final double[] ratios; + protected final long numArbitrarySets; + protected final int[] splits; protected AtomicLong counter = new AtomicLong(0); @@ -71,15 +77,87 @@ public MultiDataSetIteratorSplitter(@NonNull MultiDataSetIterator baseIterator, this.ratio = ratio; this.numTrain = (long) (totalExamples * ratio); this.numTest = totalExamples - numTrain; + this.ratios = null; + this.numArbitrarySets = 0; + this.splits = null; log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!"); } + public MultiDataSetIteratorSplitter(@NonNull MultiDataSetIterator baseIterator, long totalBatches, double[] ratios) { + for (double 
ratio : ratios) { + if (!(ratio > 0.0 && ratio < 1.0)) + throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 > X < 1.0"); + } + + if (totalBatches < 0) + throw new ND4JIllegalStateException("totalExamples number should be positive value"); + + if (!baseIterator.resetSupported()) + throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split"); + + + this.backedIterator = baseIterator; + this.totalExamples = totalBatches; + this.ratio = 0.0; + this.numTrain = (long) (totalExamples * ratio); + this.numTest = totalExamples - numTrain; + this.ratios = null; + this.numArbitrarySets = ratios.length; + + this.splits = new int[this.ratios.length]; + for (int i = 0; i < this.splits.length; ++i) { + this.splits[i] = (int)(totalExamples * ratios[i]); + } + + log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!"); + } + + public MultiDataSetIteratorSplitter(@NonNull MultiDataSetIterator baseIterator, int[] splits) { + + int totalBatches = 0; + for (val v:splits) + totalBatches += v; + + if (totalBatches < 0) + throw new ND4JIllegalStateException("totalExamples number should be positive value"); + + if (!baseIterator.resetSupported()) + throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split"); + + + this.backedIterator = baseIterator; + this.totalExamples = totalBatches; + this.ratio = 0.0; + this.numTrain = (long) (totalExamples * ratio); + this.numTest = totalExamples - numTrain; + this.ratios = null; + this.numArbitrarySets = splits.length; + this.splits = splits; + + log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!"); + } + + public List getIterators() { + List retVal = new ArrayList<>(); + int partN = 0; + int bottom = 0; + for (final int split : splits) { + ScrollableMultiDataSetIterator partIterator = + new 
package org.deeplearning4j.datasets.iterator;

import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;

/**
 * A read-only "window" over a slice of a backing {@link DataSetIterator}, used by iterator
 * splitters to expose one partition (e.g. train/test/validation) of a larger dataset.
 *
 * <p>Two slicing modes are supported:
 * <ul>
 *   <li>ratio-based: this part serves {@code totalExamples * ratio} consecutive DataSets;</li>
 *   <li>offset-based: this part serves the DataSets in the half-open range
 *       {@code [bottom, top)} of the backing iterator, rewinding and skipping the first
 *       {@code bottom} elements on the first {@link #next()} call.</li>
 * </ul>
 *
 * <p>The {@code counter} is shared with the owning splitter so consumption is tracked across
 * parts. Reset requests are deferred ({@code resetPending}) and applied lazily on the next
 * {@link #hasNext()} call. This class is NOT thread-safe beyond what the shared atomics provide.
 */
public class ScrollableDataSetIterator implements DataSetIterator {
    private int thisPart = 0;
    // Exclusive upper bound (in examples) of this partition within the backing iterator.
    private int top = 0;
    // Inclusive lower bound; number of leading examples to skip on first next().
    private int bottom = 0;
    protected DataSetIterator backedIterator;
    // Shared with the splitter: total examples consumed from the backing iterator.
    protected AtomicLong counter = new AtomicLong(0);

    // Deferred-reset flag; honoured lazily in hasNext().
    protected AtomicBoolean resetPending = new AtomicBoolean(false);
    protected DataSet firstTrain = null;
    private double ratio;
    private long totalExamples;
    // Cap on examples this part may serve (absolute position for offset mode).
    private long itemsPerPart;
    // Position of this part's cursor within the backing iterator, in examples.
    private long current;

    /**
     * Ratio-based slice: serves {@code totalExamples * ratio} examples starting at the
     * backing iterator's current position.
     *
     * @param num           ordinal of this partition
     * @param backedIterator underlying iterator the data comes from
     * @param counter       consumption counter shared with the splitter
     * @param resetPending  deferred-reset flag shared with the splitter
     * @param firstTrain    first training DataSet cached by the splitter (may be null)
     * @param ratio         fraction of totalExamples served by this part, in (0, 1]
     * @param totalExamples total number of examples in the backing iterator
     */
    public ScrollableDataSetIterator(int num, DataSetIterator backedIterator, AtomicLong counter,
                                     AtomicBoolean resetPending, DataSet firstTrain, double ratio,
                                     int totalExamples) {
        this.thisPart = num;
        this.backedIterator = backedIterator;
        this.counter = counter;
        this.resetPending = resetPending;
        this.firstTrain = firstTrain;
        this.ratio = ratio;
        this.totalExamples = totalExamples;
        this.itemsPerPart = (long) (totalExamples * ratio);
        // FIX: top was previously left at 0, which made hasNext()'s "current >= top" guard
        // trip immediately and the iterator yield nothing. For a ratio slice the window is
        // [0, itemsPerPart).
        this.top = (int) this.itemsPerPart;
        this.current = 0;
    }

    /**
     * Offset-based slice: serves the examples in {@code [itemsPerPart[0], itemsPerPart[0] + itemsPerPart[1])}.
     *
     * @param num           ordinal of this partition
     * @param backedIterator underlying iterator the data comes from
     * @param counter       consumption counter shared with the splitter
     * @param resetPending  deferred-reset flag shared with the splitter
     * @param firstTrain    first training DataSet cached by the splitter (may be null)
     * @param itemsPerPart  {@code [offset, length]} of this partition
     */
    public ScrollableDataSetIterator(int num, DataSetIterator backedIterator, AtomicLong counter,
                                     AtomicBoolean resetPending, DataSet firstTrain,
                                     int[] itemsPerPart) {
        this.thisPart = num;
        this.bottom = itemsPerPart[0];
        this.top = bottom + itemsPerPart[1];
        // NOTE: in offset mode itemsPerPart is the absolute end position, matching the
        // shared counter which counts from the start of the backing iterator.
        this.itemsPerPart = top;

        this.backedIterator = backedIterator;
        this.counter = counter;
        this.resetPending = resetPending;
        this.firstTrain = firstTrain;
        this.current = 0;
    }

    /** Not supported: batch size is fixed by the backing iterator. */
    @Override
    public DataSet next(int i) {
        throw new UnsupportedOperationException();
    }

    @Override
    public List<String> getLabels() {
        return backedIterator.getLabels();
    }

    @Override
    public int inputColumns() {
        return backedIterator.inputColumns();
    }

    /** Not supported: this is a read-only view. */
    @Override
    public void remove() {
        throw new UnsupportedOperationException();
    }

    @Override
    public int totalOutcomes() {
        return backedIterator.totalOutcomes();
    }

    @Override
    public boolean resetSupported() {
        return backedIterator.resetSupported();
    }

    @Override
    public boolean asyncSupported() {
        return backedIterator.asyncSupported();
    }

    /**
     * Defers the reset: the backing iterator is only rewound on the next {@link #hasNext()},
     * so several parts sharing one backing iterator do not thrash it.
     */
    @Override
    public void reset() {
        resetPending.set(true);
    }

    @Override
    public int batch() {
        return backedIterator.batch();
    }

    @Override
    public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
        backedIterator.setPreProcessor(dataSetPreProcessor);
    }

    @Override
    public DataSetPreProcessor getPreProcessor() {
        return backedIterator.getPreProcessor();
    }

    @Override
    public boolean hasNext() {
        // Apply a deferred reset before answering.
        if (resetPending.get()) {
            if (!resetSupported())
                throw new UnsupportedOperationException("Reset isn't supported by underlying iterator");
            backedIterator.reset();
            counter.set(0);
            current = 0;
            resetPending.set(false);
        }

        if (current >= top)
            return false;

        // There must be data left in the backing iterator AND this part must still be
        // within its quota (counter is the shared absolute position).
        return backedIterator.hasNext() && counter.get() < itemsPerPart;
    }

    @Override
    public DataSet next() {
        counter.incrementAndGet();
        if (current == 0 && bottom != 0) {
            // First call for a non-leading partition: rewind the backing iterator and
            // fast-forward past the preceding partitions' examples.
            backedIterator.reset();
            long cnt = current;
            for (; cnt < bottom; ++cnt) {
                if (backedIterator.hasNext())
                    backedIterator.next();
            }
            // +1 accounts for the element consumed by the next() call below.
            current = cnt + 1;
        } else {
            current++;
        }
        return backedIterator.next();
    }
}
package org.deeplearning4j.datasets.iterator;

import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;

/**
 * A read-only "window" over a slice of a backing {@link MultiDataSetIterator}, used by
 * iterator splitters to expose one partition (e.g. train/test/validation) of a larger dataset.
 *
 * <p>This part serves the MultiDataSets in the half-open range {@code [bottom, top)} of the
 * backing iterator, rewinding and skipping the first {@code bottom} elements on the first
 * {@link #next()} call. The {@code counter} is shared with the owning splitter so consumption
 * is tracked across parts; reset requests are deferred ({@code resetPending}) and applied
 * lazily on the next {@link #hasNext()} call. Not thread-safe beyond the shared atomics.
 */
public class ScrollableMultiDataSetIterator implements MultiDataSetIterator {
    private int thisPart = 0;
    // Exclusive upper bound (in examples) of this partition within the backing iterator.
    private int top = 0;
    // Inclusive lower bound; number of leading examples to skip on first next().
    private int bottom = 0;
    protected MultiDataSetIterator backedIterator;
    // Shared with the splitter: total examples consumed from the backing iterator.
    protected AtomicLong counter = new AtomicLong(0);

    // Deferred-reset flag; honoured lazily in hasNext().
    protected AtomicBoolean resetPending = new AtomicBoolean(false);
    protected DataSet firstTrain = null;
    protected MultiDataSet firstMultiTrain = null;
    // Cap on examples this part may serve (absolute end position, matches the shared counter).
    private long itemsPerPart;
    // Position of this part's cursor within the backing iterator, in examples.
    private long current;

    /**
     * Offset-based slice: serves the examples in
     * {@code [itemsPerPart[0], itemsPerPart[0] + itemsPerPart[1])}.
     *
     * @param num            ordinal of this partition
     * @param backedIterator underlying iterator the data comes from
     * @param counter        consumption counter shared with the splitter
     * @param firstTrain     first training MultiDataSet cached by the splitter (may be null)
     * @param itemsPerPart   {@code [offset, length]} of this partition
     */
    public ScrollableMultiDataSetIterator(int num, MultiDataSetIterator backedIterator, AtomicLong counter,
                                          MultiDataSet firstTrain, int[] itemsPerPart) {
        this.thisPart = num;
        this.bottom = itemsPerPart[0];
        this.top = bottom + itemsPerPart[1];
        this.itemsPerPart = top;

        this.counter = counter;
        this.firstMultiTrain = firstTrain;
        this.current = 0;
        this.backedIterator = backedIterator;
        // FIX: removed the no-op self-assignment "this.resetPending = resetPending;" — there
        // is no such constructor parameter; the field keeps its own fresh AtomicBoolean.
    }

    @Override
    public boolean resetSupported() {
        return backedIterator.resetSupported();
    }

    @Override
    public boolean asyncSupported() {
        return backedIterator.asyncSupported();
    }

    /**
     * Defers the reset: the backing iterator is only rewound on the next {@link #hasNext()},
     * so several parts sharing one backing iterator do not thrash it.
     */
    @Override
    public void reset() {
        resetPending.set(true);
    }

    @Override
    public void setPreProcessor(MultiDataSetPreProcessor dataSetPreProcessor) {
        backedIterator.setPreProcessor(dataSetPreProcessor);
    }

    /**
     * Not supported.
     *
     * <p>NOTE(review): asymmetric with {@link #setPreProcessor}, which delegates to the
     * backing iterator; delegating here too may be the intended behaviour — confirm before
     * changing, as callers may rely on the exception.
     */
    @Override
    public MultiDataSetPreProcessor getPreProcessor() {
        throw new UnsupportedOperationException();
    }

    @Override
    public boolean hasNext() {
        // Apply a deferred reset before answering.
        if (resetPending.get()) {
            if (!resetSupported())
                throw new UnsupportedOperationException("Reset isn't supported by underlying iterator");
            backedIterator.reset();
            counter.set(0);
            current = 0;
            resetPending.set(false);
        }

        if (current >= top)
            return false;

        // There must be data left in the backing iterator AND this part must still be
        // within its quota (counter is the shared absolute position).
        return backedIterator.hasNext() && counter.get() < itemsPerPart;
    }

    @Override
    public MultiDataSet next() {
        counter.incrementAndGet();
        if (current == 0 && bottom != 0) {
            // First call for a non-leading partition: rewind the backing iterator and
            // fast-forward past the preceding partitions' examples.
            backedIterator.reset();
            long cnt = current;
            for (; cnt < bottom; ++cnt) {
                if (backedIterator.hasNext())
                    backedIterator.next();
            }
            // +1 accounts for the element consumed by the next() call below.
            current = cnt + 1;
        } else {
            current++;
        }
        return backedIterator.next();
    }

    /** Not supported: batch size is fixed by the backing iterator. */
    @Override
    public MultiDataSet next(int i) {
        throw new UnsupportedOperationException();
    }
}
IntPointer dimensions, int rank); + public native @ByVal NDArray permute(@Const IntBuffer dimensions, int rank); + public native @ByVal NDArray permute(@Const int[] dimensions, int rank); public native void permute(@Const IntPointer dimensions, int rank, @ByRef NDArray target); public native void permute(@Const IntBuffer dimensions, int rank, @ByRef NDArray target); @@ -3841,12 +3841,12 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint public native void permute(@StdVector IntPointer dimensions, @ByRef NDArray target); public native void permute(@StdVector IntBuffer dimensions, @ByRef NDArray target); public native void permute(@StdVector int[] dimensions, @ByRef NDArray target); - public native NDArray permute(@Cast("Nd4jLong*") @StdVector LongPointer dimensions); - public native NDArray permute(@Cast("Nd4jLong*") @StdVector LongBuffer dimensions); - public native NDArray permute(@Cast("Nd4jLong*") @StdVector long[] dimensions); - public native NDArray permute(@Cast("const Nd4jLong*") LongPointer dimensions, int rank); - public native NDArray permute(@Cast("const Nd4jLong*") LongBuffer dimensions, int rank); - public native NDArray permute(@Cast("const Nd4jLong*") long[] dimensions, int rank); + public native @ByVal NDArray permute(@Cast("Nd4jLong*") @StdVector LongPointer dimensions); + public native @ByVal NDArray permute(@Cast("Nd4jLong*") @StdVector LongBuffer dimensions); + public native @ByVal NDArray permute(@Cast("Nd4jLong*") @StdVector long[] dimensions); + public native @ByVal NDArray permute(@Cast("const Nd4jLong*") LongPointer dimensions, int rank); + public native @ByVal NDArray permute(@Cast("const Nd4jLong*") LongBuffer dimensions, int rank); + public native @ByVal NDArray permute(@Cast("const Nd4jLong*") long[] dimensions, int rank); public native void permute(@Cast("const Nd4jLong*") LongPointer dimensions, int rank, @ByRef NDArray target); public native void permute(@Cast("const Nd4jLong*") LongBuffer dimensions, 
int rank, @ByRef NDArray target); @@ -3940,8 +3940,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint /** * apply transpose operation to the copy of this array, that is this array remains unaffected */ - public native NDArray transpose(); - public native @ByVal NDArray transp(); + public native @ByVal NDArray transpose(); /** * perform transpose operation and store result in target, this array remains unaffected @@ -4066,9 +4065,9 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * * if permute have been applied before or there are weird strides, then new buffer is allocated for new array */ - public native NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape); - public native NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape); - public native NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector long[] shape); + public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape); + public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape); + public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector long[] shape); /** * calculate strides and set given order