Skip to content

Commit

Permalink
[WIP] Misc DL4J/ND4J/DataVec Issues (#7340)
Browse files Browse the repository at this point in the history
* Add FirstDigitTransform (Benfords law) + tests

* Javadoc, polish

* #7325 Fix SameDiff.asFlatPrint

* Refactor DataVec readers to remove hard-coded use of Files, in favor of streams

* Add StreamInputSplit (partly complete)

* More tests, fixes

* Fixes for model import test failures

* DataVec fixes after earlier changes

* Another DataVec fix

* #7355 SameDiff array reuse fix

* #7343 SameDiff method for Pad op

* #7305 Fix getColumn on row vector (returning scalar, not view)

* #7168 Empty arrays - create only once

* #7002 Remove newFormat arg/field

* #7352 MultiLayerNetwork.output(DataSetIterator) validation

* Fixes

* Small fixes

* SameDiff variables: Switch to LinkedHashMap for consistent iteration order

* Fix validation NPE for LogFileWriter

* Reduce3 fixes

* Small test fix

* Small test fix

* Fix bad test

* Small test threshold tweak

* OpProfiler fix: null x array (random ops etc)

* Fix issue with array order not matching flattening order when Nd4j.ordering() == f - Nd4j.createFromArray
  • Loading branch information
AlexDBlack committed Mar 28, 2019
1 parent 1c64031 commit 633f9c7
Show file tree
Hide file tree
Showing 80 changed files with 1,558 additions and 618 deletions.
Expand Up @@ -17,8 +17,18 @@
package org.datavec.api.records.reader;

import org.datavec.api.records.listener.RecordListener;
import org.datavec.api.split.InputSplit;
import org.datavec.api.split.StreamInputSplit;
import org.datavec.api.split.streams.FileStreamCreatorFunction;
import org.datavec.api.writable.Writable;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.function.Function;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand All @@ -31,7 +41,9 @@
*/
public abstract class BaseRecordReader implements RecordReader {

protected InputSplit inputSplit;
protected List<RecordListener> listeners = new ArrayList<>();
protected Function<URI,InputStream> streamCreatorFn = new FileStreamCreatorFunction();

/** Invokes {@link RecordListener#recordRead(RecordReader, Object)} on all listeners. */
protected void invokeListeners(Object record) {
Expand All @@ -40,6 +52,17 @@ protected void invokeListeners(Object record) {
}
}

/**
 * Records the given split on this reader. If the split is a {@link StreamInputSplit} that supplies
 * a non-null stream creator function, that function replaces the one currently held by this reader;
 * otherwise the existing function is left untouched.
 *
 * @param split the split to read from
 */
@Override
public void initialize(InputSplit split) throws IOException, InterruptedException {
    this.inputSplit = split;
    if(!(split instanceof StreamInputSplit)){
        //Any other split type: keep whatever stream creator function is already set
        return;
    }
    Function<URI,InputStream> splitFn = ((StreamInputSplit)split).getStreamCreatorFn();
    if(splitFn != null){
        this.streamCreatorFn = splitFn;
    }
}

@Override
public List<RecordListener> getListeners() {
return listeners;
Expand Down
Expand Up @@ -16,7 +16,8 @@

package org.datavec.api.records.reader.impl;

import org.apache.commons.io.FileUtils;
import lombok.Getter;
import lombok.Setter;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaData;
Expand All @@ -29,10 +30,9 @@

import java.io.*;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.*;

/**
* File reader/writer
Expand All @@ -41,20 +41,20 @@
*/
public class FileRecordReader extends BaseRecordReader {

protected Iterator<File> iter;
protected Iterator<String> locationsIterator;
protected Iterator<URI> locationsIterator;
protected Configuration conf;
protected File currentFile;
protected URI currentUri;
protected List<String> labels;
protected boolean appendLabel = false;
protected InputSplit inputSplit;
@Getter @Setter
protected String charset = StandardCharsets.UTF_8.name(); //Using String as StandardCharsets.UTF_8 is not serializable

public FileRecordReader() {}

@Override
public void initialize(InputSplit split) throws IOException, InterruptedException {
super.initialize(split);
doInitialize(split);
this.inputSplit = split;
}


Expand All @@ -63,17 +63,16 @@ protected void doInitialize(InputSplit split) {
if (labels == null && appendLabel) {
URI[] locations = split.locations();
if (locations.length > 0) {
//root dir relative to example where the label is the parent directory and the root directory is
//recursively the parent of that
File parent = new File(locations[0]).getParentFile().getParentFile();
//calculate the labels relative to the parent file
labels = new ArrayList<>();

for (File labelDir : parent.listFiles())
labels.add(labelDir.getName());
Set<String> labels = new HashSet<>();
for(URI u : locations){
String[] pathSplit = u.toString().split("[/\\\\]");
labels.add(pathSplit[pathSplit.length-2]);
}
this.labels = new ArrayList<>(labels);
Collections.sort(this.labels);
}
}
locationsIterator = split.locationsPathIterator();
locationsIterator = split.locationsIterator();
}

@Override
Expand All @@ -89,14 +88,20 @@ public List<Writable> next() {
return nextRecord().getRecord();
}

private List<Writable> loadFromFile(File next) {
private List<Writable> loadFromStream(URI uri, InputStream next, Charset charset) {
List<Writable> ret = new ArrayList<>();
try {
ret.add(new Text(FileUtils.readFileToString(next)));
if (appendLabel)
ret.add(new IntWritable(labels.indexOf(next.getParentFile().getName())));
if(!(next instanceof BufferedInputStream)){
next = new BufferedInputStream(next);
}
String s = org.apache.commons.io.IOUtils.toString(next, charset);
ret.add(new Text(s));
if (appendLabel) {
int idx = getLabel(uri);
ret.add(new IntWritable(idx));
}
} catch (IOException e) {
e.printStackTrace();
throw new IllegalStateException("Error reading from input stream: " + uri);
}
return ret;
}
Expand All @@ -108,7 +113,16 @@ private List<Writable> loadFromFile(File next) {
* @return The index of the current file's parent directory
*/
public int getCurrentLabel() {
return labels.indexOf(currentFile.getParentFile().getName());
return getLabel(currentUri);
}

/**
 * Get the label index for the given URI. The label name is taken to be the path segment immediately
 * before the final segment of the URI string (i.e., the name of the containing directory), where
 * segments are separated by either '/' or '\'.
 *
 * @param uri URI to get the label index for
 * @return Index of the containing directory's name in the labels list, or -1 if that name is not in the
 *         list or the URI string contains no path separator at all
 */
public int getLabel(URI uri){
    String s = uri.toString();
    int lastIdx = Math.max(s.lastIndexOf('/'), s.lastIndexOf('\\')); //Note: if only one is found, the other being -1 is fine here
    if(lastIdx < 0){
        //No separator at all: there is no containing directory, hence no label. Without this check,
        // substring(.., -1) below would throw StringIndexOutOfBoundsException
        return -1;
    }
    //Search backwards from just before the last separator; -1 (not found) gives a start index of 0 below
    int secondLastIdx = Math.max(s.lastIndexOf('/', lastIdx - 1), s.lastIndexOf('\\', lastIdx - 1));
    String name = s.substring(secondLastIdx + 1, lastIdx);
    return labels.indexOf(name);
}

public List<String> getLabels() {
Expand All @@ -121,15 +135,7 @@ public void setLabels(List<String> labels) {

@Override
public boolean hasNext() {
if (iter != null && iter.hasNext()) {
return true;
}
if (!locationsIterator.hasNext()) {
return false;
}
// iter is exhausted, set to iterate of the next location
this.advanceToNextLocation();
return iter != null && iter.hasNext();
return locationsIterator.hasNext();
}

@Override
Expand Down Expand Up @@ -191,41 +197,17 @@ public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IO

@Override
public Record nextRecord() {
if (iter == null || !iter.hasNext()) {
this.advanceToNextLocation();
}
File next = iter.next();
this.currentFile = next;
URI next = locationsIterator.next();
invokeListeners(next);
List<Writable> ret = loadFromFile(next);

return new org.datavec.api.records.impl.Record(ret,
new RecordMetaDataURI(next.toURI(), FileRecordReader.class));
}

protected File nextFile() {
if (iter == null || !iter.hasNext()) {
this.advanceToNextLocation();
//Track the URI being read so that getCurrentLabel() has a location to resolve the label from
this.currentUri = next;
List<Writable> ret;
//try-with-resources: the stream is closed whether or not loading succeeds
try(InputStream s = streamCreatorFn.apply(next)) {
ret = loadFromStream(next, s, Charset.forName(charset));
} catch (IOException e){
//Chain the IOException as the cause so the underlying I/O failure is not lost
throw new RuntimeException("Error reading from stream for URI: " + next, e);
}
File next = iter.next();
this.currentFile = next;
return next;
}

protected void advanceToNextLocation () {
//File file;
String path = locationsIterator.next(); // should always have file:// preceding
if(!path.startsWith("file:")){
path = "file:///" + path;
}
if(path.contains("\\")){
path = path.replaceAll("\\\\","/");
}
File file = new File(URI.create(path));
if (file.isDirectory())
iter = FileUtils.iterateFiles(file, null, true);
else
iter = Collections.singletonList(file).iterator();
return new org.datavec.api.records.impl.Record(ret,new RecordMetaDataURI(next, FileRecordReader.class));
}

@Override
Expand All @@ -240,8 +222,13 @@ public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) throw
for (RecordMetaData meta : recordMetaDatas) {
URI uri = meta.getURI();

File f = new File(uri);
List<Writable> list = loadFromFile(f);
List<Writable> list;
//try-with-resources: the stream is closed whether or not loading succeeds
try(InputStream s = streamCreatorFn.apply(uri)) {
list = loadFromStream(uri, s, Charset.forName(charset));
} catch (IOException e){
//Chain the IOException as the cause so the underlying I/O failure is not lost
throw new RuntimeException("Error reading from stream for URI: " + uri, e);
}

out.add(new org.datavec.api.records.impl.Record(list, meta));
}

Expand Down
Expand Up @@ -48,12 +48,11 @@ public class LineRecordReader extends BaseRecordReader {
protected int splitIndex = 0;
protected int lineIndex = 0; //Line index within the current split
protected Configuration conf;
protected InputSplit inputSplit;
protected boolean initialized;

@Override
public void initialize(InputSplit split) throws IOException, InterruptedException {
this.inputSplit = split;
super.initialize(split);
this.iter = getIterator(0);
this.initialized = true;
}
Expand Down Expand Up @@ -82,7 +81,8 @@ public List<Writable> next() {
lineIndex = 0; //New split opened -> reset line index
try {
close();
iter = IOUtils.lineIterator(new InputStreamReader(locations[splitIndex].toURL().openStream()));
// iter = IOUtils.lineIterator(new InputStreamReader(locations[splitIndex].toURL().openStream()));
iter = getIterator(splitIndex);
onLocationOpen(locations[splitIndex]);
} catch (IOException e) {
e.printStackTrace();
Expand Down Expand Up @@ -113,7 +113,7 @@ public boolean hasNext() {
lineIndex = 0; //New split -> reset line count
try {
close();
iter = IOUtils.lineIterator(new InputStreamReader(locations[splitIndex].toURL().openStream()));
iter = getIterator(splitIndex);
onLocationOpen(locations[splitIndex]);
} catch (IOException e) {
e.printStackTrace();
Expand Down Expand Up @@ -201,14 +201,9 @@ protected Iterator<String> getIterator(int location) {
final Iterator<URI> uriIterator = inputSplit.locationsIterator();
while(uriIterator.hasNext()) uris.add(uriIterator.next());

this.locations = uris.toArray(new URI[0]);
this.locations = uris.toArray(new URI[uris.size()]);
if (locations.length > 0) {
InputStream inputStream;
try {
inputStream = locations[location].toURL().openStream();
} catch (IOException e) {
throw new RuntimeException(e);
}
InputStream inputStream = streamCreatorFn.apply(locations[location]);
iterator = IOUtils.lineIterator(new InputStreamReader(inputStream));
}
}
Expand Down
Expand Up @@ -72,17 +72,12 @@ public SequenceRecord nextSequence() {
if(!hasNext()){
throw new NoSuchElementException("No next element");
}
File next = iter.next();
invokeListeners(next);

List<List<Writable>> out;
try {
out = loadAndClose(new FileInputStream(next));
} catch (IOException e) {
throw new RuntimeException(e);
}
URI next = locationsIterator.next();
invokeListeners(next);

return new org.datavec.api.records.impl.SequenceRecord(out, new RecordMetaDataURI(next.toURI()));
List<List<Writable>> out = loadAndClose(streamCreatorFn.apply(next));
return new org.datavec.api.records.impl.SequenceRecord(out, new RecordMetaDataURI(next));
}

private List<List<Writable>> loadAndClose(InputStream inputStream) {
Expand Down
Expand Up @@ -74,14 +74,9 @@ public SequenceRecord nextSequence() {
throw new NoSuchElementException("No next element");
}

File next = iter.next();
List<List<Writable>> out;
try {
out = loadAndClose(new FileInputStream(next));
} catch (IOException e){
throw new RuntimeException(e);
}
return new org.datavec.api.records.impl.SequenceRecord(out, new RecordMetaDataURI(next.toURI()));
URI next = locationsIterator.next();
List<List<Writable>> out = loadAndClose(streamCreatorFn.apply(next));
return new org.datavec.api.records.impl.SequenceRecord(out, new RecordMetaDataURI(next));
}

@Override
Expand Down
Expand Up @@ -16,7 +16,10 @@

package org.datavec.api.records.reader.impl.jackson;

import lombok.Getter;
import lombok.Setter;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.datavec.api.conf.Configuration;
import org.datavec.api.io.labels.PathLabelGenerator;
import org.datavec.api.records.Record;
Expand All @@ -32,6 +35,8 @@

import java.io.*;
import java.net.URI;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.*;

/**
Expand Down Expand Up @@ -69,6 +74,8 @@ public class JacksonRecordReader extends BaseRecordReader {
private int labelPosition;
private InputSplit is;
private Random r;
@Getter @Setter
private String charset = StandardCharsets.UTF_8.name(); //Using String as StandardCharsets.UTF_8 is not serializable

private URI[] uris;
private int cursor = 0;
Expand Down Expand Up @@ -102,6 +109,7 @@ public JacksonRecordReader(FieldSelection selection, ObjectMapper mapper, boolea
public void initialize(InputSplit split) throws IOException, InterruptedException {
if (split instanceof FileSplit)
throw new UnsupportedOperationException("Cannot use JacksonRecordReader with FileSplit");
//Pass the method argument, not the inherited inputSplit field: the field still holds whatever an earlier
// initialize call left there (null on first use), so passing it would discard the new split and never
// pick up a StreamInputSplit's stream creator function
super.initialize(split);
this.uris = split.locations();
if (shuffle) {
List<URI> list = Arrays.asList(uris);
Expand All @@ -125,8 +133,8 @@ public List<Writable> next() {
URI uri = uris[cursor++];
invokeListeners(uri);
String fileAsString;
try {
fileAsString = FileUtils.readFileToString(new File(uri.toURL().getFile()));
try (InputStream s = streamCreatorFn.apply(uri)){
fileAsString = IOUtils.toString(s, charset);
} catch (IOException e) {
throw new RuntimeException("Error reading URI file", e);
}
Expand Down

0 comments on commit 633f9c7

Please sign in to comment.