-
Notifications
You must be signed in to change notification settings - Fork 5
/
ExistingMiniBatchMultiDataSetIterator.java
105 lines (88 loc) · 2.47 KB
/
ExistingMiniBatchMultiDataSetIterator.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
package org.genericsystem.cv.nn;
import java.io.File;
import java.io.IOException;
import java.util.NoSuchElementException;
import java.util.Objects;

import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
/**
 * Iterates over {@link MultiDataSet} mini-batches that were previously saved to disk, one batch per file.
 *
 * <p>Adaptation of {@code ExistingMiniBatchDataSetIterator} for MultiDataSets. Batch files are located in a
 * root directory by formatting a filename pattern with a zero-based batch index
 * ({@code String.format(pattern, idx)}) and are returned in increasing index order. Every call to
 * {@link #next()} loads one file from disk; nothing is cached.
 *
 * <p>The number of batches is fixed when the iterator is constructed. It is the number of consecutive batch
 * files {@code pattern(0)}, {@code pattern(1)}, ... that exist in the root directory. Files that do not match
 * the pattern are ignored, and files added later are not seen.
 *
 * <p>NOTE(review): no synchronization is performed; presumably meant to be consumed by a single thread.
 */
public class ExistingMiniBatchMultiDataSetIterator implements MultiDataSetIterator {

    /** Default filename pattern: {@code dataset-0.bin}, {@code dataset-1.bin}, ... */
    public static final String DEFAULT_PATTERN = "dataset-%d.bin";

    /** Directory that holds the saved batches. */
    private final File rootDir;
    /** Filename pattern, formatted with the batch index to obtain each file name. */
    private final String pattern;
    /** Number of batches available; computed once in the constructor. */
    private final int totalBatches;
    /** Index of the batch that the next call to {@link #next()} will return. */
    private int currIdx;
    /** Optional pre-processor applied to every loaded batch; {@code null} means none. */
    private MultiDataSetPreProcessor multiDataSetPreProcessor;

    /**
     * Create with the given root directory, using the default filename pattern {@link #DEFAULT_PATTERN}.
     *
     * @param rootDir
     *            the root directory to use; created if it does not exist
     * @throws NullPointerException
     *             if {@code rootDir} is {@code null}
     * @throws IllegalArgumentException
     *             if {@code rootDir} is not a readable directory
     */
    public ExistingMiniBatchMultiDataSetIterator(File rootDir) {
        this(rootDir, DEFAULT_PATTERN);
    }

    /**
     * Create with the given root directory and filename pattern.
     *
     * @param rootDir
     *            The root directory to use; created if it does not exist
     * @param pattern
     *            The filename pattern to use. Used with {@code String.format(pattern,idx)}, where idx is an integer,
     *            starting at 0.
     * @throws NullPointerException
     *             if {@code rootDir} or {@code pattern} is {@code null}
     * @throws IllegalArgumentException
     *             if {@code rootDir} is not a readable directory
     */
    public ExistingMiniBatchMultiDataSetIterator(File rootDir, String pattern) {
        this.rootDir = Objects.requireNonNull(rootDir, "rootDir");
        this.pattern = Objects.requireNonNull(pattern, "pattern");
        // Ensure the directory exists; a freshly created directory simply yields zero batches.
        rootDir.mkdirs();
        this.totalBatches = countBatches(rootDir, pattern);
    }

    /** Resolves the file that holds batch {@code idx}. */
    private static File batchFile(File rootDir, String pattern, int idx) {
        return new File(rootDir, String.format(pattern, idx));
    }

    /**
     * Counts the consecutive batch files starting at index 0.
     *
     * <p>The count stops at the first missing index. It is bounded by the number of directory entries, which
     * also prevents an endless loop when the pattern has no format specifier and every index maps to the same
     * file name.
     */
    private static int countBatches(File rootDir, String pattern) {
        String[] entries = rootDir.list();
        if (entries == null) {
            // File.list() returns null when the path is not a directory or an I/O error occurs.
            throw new IllegalArgumentException("Not a readable directory: " + rootDir);
        }
        int count = 0;
        while (count < entries.length && batchFile(rootDir, pattern, count).isFile()) {
            count++;
        }
        return count;
    }

    /** Not supported: batches are stored with a fixed size and cannot be re-split. */
    @Override
    public MultiDataSet next(int num) {
        throw new UnsupportedOperationException("Unable to load custom number of examples");
    }

    /** Returns {@code true}: {@link #reset()} rewinds to the first batch. */
    @Override
    public boolean resetSupported() {
        return true;
    }

    /** Returns {@code true}: this iterator declares support for asynchronous iteration. */
    @Override
    public boolean asyncSupported() {
        return true;
    }

    /** Rewinds the iterator so that the next call to {@link #next()} returns batch 0 again. */
    @Override
    public void reset() {
        currIdx = 0;
    }

    /** Sets the pre-processor applied to each batch after loading; {@code null} disables pre-processing. */
    @Override
    public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
        this.multiDataSetPreProcessor = preProcessor;
    }

    /** Returns the current pre-processor, or {@code null} if none is set. */
    @Override
    public MultiDataSetPreProcessor getPreProcessor() {
        return multiDataSetPreProcessor;
    }

    /** Returns whether another batch file remains to be read. */
    @Override
    public boolean hasNext() {
        return currIdx < totalBatches;
    }

    /** Intentionally a no-op: saved batches are never deleted through this iterator. */
    @Override
    public void remove() {
        // Deliberately does nothing.
    }

    /**
     * Loads the next batch from disk and applies the pre-processor, if one is set.
     *
     * @return the next batch
     * @throws NoSuchElementException
     *             if all batches have already been returned
     * @throws IllegalStateException
     *             if the batch file cannot be read; the underlying {@link IOException} is the cause
     */
    @Override
    public MultiDataSet next() {
        if (!hasNext()) {
            throw new NoSuchElementException("No more batches: all " + totalBatches + " batches in " + rootDir
                    + " have been returned");
        }
        File file = batchFile(rootDir, pattern, currIdx);
        MultiDataSet ret;
        try {
            ret = read(file);
        } catch (IOException e) {
            throw new IllegalStateException("Unable to read dataset " + file, e);
        }
        if (multiDataSetPreProcessor != null) {
            multiDataSetPreProcessor.preProcess(ret);
        }
        // Advance only after a successful load so a failed read can be retried at the same index.
        currIdx++;
        return ret;
    }

    /** Deserializes a single batch from {@code file}. */
    private static MultiDataSet read(File file) throws IOException {
        MultiDataSet d = new MultiDataSet();
        d.load(file);
        return d;
    }
}