Skip to content

Commit

Permalink
#6366 RecordReaderMultiDataSetIterator: better validation/errors; 5D …
Browse files Browse the repository at this point in the history
…(3D CNN) NDArrayWritable support
  • Loading branch information
AlexDBlack committed Sep 6, 2018
1 parent 813e409 commit 6ee913b
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 2 deletions.
Expand Up @@ -21,16 +21,21 @@
import org.apache.commons.compress.utils.IOUtils;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.datavec.api.conf.Configuration;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.reader.BaseRecordReader;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.CollectionInputSplit;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
Expand All @@ -55,6 +60,7 @@

import static org.junit.Assert.*;
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
import static org.nd4j.linalg.indexing.NDArrayIndex.point;

public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
Expand Down Expand Up @@ -792,8 +798,131 @@ public void testExcludeStringColCSV() throws Exception {

assertEquals(expFeatures, mds.getFeatures(0));
assertEquals(expLabels, mds.getLabels(0));
}


private static final int nX = 32;
private static final int nY = 32;
private static final int nZ = 28;


@Test
public void testRRMDSI5D() {
int batchSize = 5;

CustomRecordReader recordReader = new CustomRecordReader();
DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize,
1, /* Index of label in records */
2 /* number of different labels */);

int count = 0;
while(dataIter.hasNext()){
DataSet ds = dataIter.next();

int offset = 5*count;
for( int i=0; i<5; i++ ){
INDArray act = ds.getFeatures().get(interval(i,i,true), all(), all(), all(), all());
INDArray exp = Nd4j.valueArrayOf(new int[]{1, 1, nZ, nX, nY}, i + offset );
assertEquals(exp, act);
}
count++;
}

assertEquals(2, count);
}


static class CustomRecordReader extends BaseRecordReader {

int n = 0;

CustomRecordReader() { }

@Override
public boolean batchesSupported() {
return false;
}

@Override
public List<List<Writable>> next(int num) {
throw new RuntimeException("Not implemented");
}

@Override
public List<Writable> next() {
INDArray nd = Nd4j.create(new float[nZ*nY*nX], new int[] {1, 1, nZ, nY, nX }, 'C').assign(n);
final List<Writable>res = RecordConverter.toRecord(nd);
res.add(new IntWritable(0));
n++;
return res;
}

@Override
public boolean hasNext() {
return n<10;
}

final static ArrayList<String> labels = new ArrayList<>(2);
static {
labels.add("lbl0");
labels.add("lbl1");
}
@Override
public List<String> getLabels() {
return labels;
}

@Override
public void reset() {
n = 0;
}

@Override
public boolean resetSupported() {
return true;
}

@Override
public List<Writable> record(URI uri, DataInputStream dataInputStream) {
return next();
}

@Override
public Record nextRecord() {
List<Writable> r = next();
return new org.datavec.api.records.impl.Record(r, null);
}

@Override
public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
throw new RuntimeException("Not implemented");
}

@Override
public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) {
throw new RuntimeException("Not implemented");
}

@Override
public void close() {
}

@Override
public void setConf(Configuration conf) {
}

@Override
public Configuration getConf() {
return null;
}

@Override
public void initialize(InputSplit split) {
n = 0;
}
@Override
public void initialize(Configuration conf, InputSplit split) {
n = 0;
}
}
}
Expand Up @@ -33,6 +33,7 @@
import org.datavec.api.writable.Writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.deeplearning4j.datasets.datavec.exception.ZeroLengthSequenceException;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
Expand Down Expand Up @@ -540,6 +541,15 @@ private INDArray convertWritablesHelper(List<List<Writable>> list, int minValues
}

private void putExample(INDArray arr, INDArray singleExample, int exampleIdx) {
Preconditions.checkState(singleExample.size(0) == 1 && singleExample.rank() == arr.rank(), "Cannot put array: array should have leading dimension of 1 " +
"and equal rank to output array. Attempting to put array of shape %s into output array of shape %s", singleExample.shape(), arr.shape());

long[] arrShape = arr.shape();
long[] singleShape = singleExample.shape();
for( int i=1; i<arr.rank(); i++ ){
Preconditions.checkState(arrShape[i] == singleShape[i], "Single example array and output arrays differ at position %s:" +
"single example shape %s, output array shape %s", i, singleShape, arrShape);
}
switch (arr.rank()) {
case 2:
arr.put(new INDArrayIndex[] {NDArrayIndex.point(exampleIdx), NDArrayIndex.all()}, singleExample);
Expand All @@ -552,8 +562,12 @@ private void putExample(INDArray arr, INDArray singleExample, int exampleIdx) {
arr.put(new INDArrayIndex[] {NDArrayIndex.point(exampleIdx), NDArrayIndex.all(), NDArrayIndex.all(),
NDArrayIndex.all()}, singleExample);
break;
case 5:
arr.put(new INDArrayIndex[] {NDArrayIndex.point(exampleIdx), NDArrayIndex.all(), NDArrayIndex.all(),
NDArrayIndex.all(), NDArrayIndex.all()}, singleExample);
break;
default:
throw new RuntimeException("Unexpected rank: " + arr.rank());
throw new RuntimeException("Unexpected array rank: " + arr.rank() + " with shape " + Arrays.toString(arr.shape()) + " input arrays should be rank 2 to 5 inclusive");
}
}

Expand Down
Expand Up @@ -24,6 +24,7 @@

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

/**
* Import Keras Tokenizer
Expand All @@ -44,7 +45,7 @@ public void importTest() throws IOException, InvalidKerasConfigurationException
KerasTokenizer tokenizer = KerasTokenizer.fromJson(configResource.getFile().getAbsolutePath());

assertEquals(100, tokenizer.getNumWords().intValue());
assertEquals(tokenizer.isLower();
assertTrue(tokenizer.isLower());
assertEquals(" ", tokenizer.getSplit());
assertFalse(tokenizer.isCharLevel());
assertEquals(0, tokenizer.getDocumentCount().intValue());
Expand Down

0 comments on commit 6ee913b

Please sign in to comment.