From 633f9c79ebf1345b81d8157491643f4fb95d36b2 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 28 Mar 2019 23:08:35 +1100 Subject: [PATCH] [WIP] Misc DL4J/ND4J/DataVec Issues (#7340) * Add FirstDigitTransform (Benfords law) + tests * Javadoc, polish * #7325 Fix SameDiff.asFlatPrint * Refactor DataVec readers to remove hard-coded use of Files, in favor of streams * Add StreamInputSplit (partly complete) * More tests, fixes * Fixes for model import test failures * DataVec fixes after earlier changes * Another DataVec fix * #7355 SameDiff array reuse fix * #7343 SameDiff method for Pad op * #7305 Fix getColumn on row vector (returning scalar, not view) * #7168 Empty arrays - create only once * #7002 Remove newFormat arg/field * #7352 MultiLayerNetwork.output(DataSetIterator) validation * Fixes * Small fixes * SameDiff variables: Switch to LinkedHashMap for consitent iteration order * Fix validation NPE for LogFileWriter * Reduce3 fixes * Small test fix * Small test fix * Fix bad test * Small test threshold tweak * OpProfiler fix: null x array (random ops etc) * Fix issue with array order not matching flattening order when Nd4j.ordering() == f - Nd4j.createFromArray --- .../api/records/reader/BaseRecordReader.java | 23 ++ .../records/reader/impl/FileRecordReader.java | 121 +++++------ .../records/reader/impl/LineRecordReader.java | 17 +- .../impl/csv/CSVSequenceRecordReader.java | 13 +- .../JacksonLineSequenceRecordReader.java | 11 +- .../impl/jackson/JacksonRecordReader.java | 12 +- .../impl/regex/RegexSequenceRecordReader.java | 13 +- .../api/split/CollectionInputSplit.java | 3 +- .../api/split/NumberedFileInputSplit.java | 12 +- .../datavec/api/split/StreamInputSplit.java | 154 +++++++++++++ .../streams/FileStreamCreatorFunction.java | 38 ++++ .../api/transform/TransformProcess.java | 42 +++- .../categorical/FirstDigitTransform.java | 202 ++++++++++++++++++ .../api/util/files/UriFromPathIterator.java | 10 +- .../records/reader/impl/LineReaderTest.java | 15 +- .../org/datavec/api/split/FileSplitTest.java | 103 --------- .../api/split/TestStreamInputSplit.java | 201 +++++++++++++++++ .../api/transform/transform/TestJsonYaml.java | 2 + .../transform/transform/TestTransforms.java | 49 ++++- .../codec/reader/BaseCodecRecordReader.java | 27 +-- .../datavec/poi/excel/ExcelRecordReader.java | 48 +++-- .../spark/transform/ExecutionTest.java | 51 +++++ .../nn/mkldnn/ValidateMKLDNN.java | 1 + .../KerasAtrousConvolution1D.java | 2 +- .../convolutional/KerasConvolution1D.java | 2 +- .../nn/conf/layers/util/MaskZeroLayer.java | 4 +- .../nn/multilayer/MultiLayerNetwork.java | 26 ++- .../DifferentialFunctionFactory.java | 4 + .../org/nd4j/autodiff/samediff/SameDiff.java | 80 ++++++- .../org/nd4j/autodiff/samediff/ops/SDNN.java | 80 +++++++ .../samediff/serde/FlatBuffersMapper.java | 2 - .../java/org/nd4j/graph/ui/LogFileWriter.java | 5 +- .../nd4j/linalg/api/ndarray/BaseNDArray.java | 5 +- .../linalg/api/ops/BaseIndexAccumulation.java | 6 +- .../nd4j/linalg/api/ops/BaseReduceBoolOp.java | 14 +- .../linalg/api/ops/BaseReduceFloatOp.java | 10 +- .../nd4j/linalg/api/ops/BaseReduceLongOp.java | 6 +- .../org/nd4j/linalg/api/ops/BaseReduceOp.java | 14 +- .../nd4j/linalg/api/ops/BaseReduceSameOp.java | 6 +- .../linalg/api/ops/IndexAccumulation.java | 6 - .../org/nd4j/linalg/api/ops/ReduceOp.java | 6 - .../api/ops/impl/reduce/BaseReduction.java | 104 --------- .../api/ops/impl/reduce/bool/IsInf.java | 2 +- .../api/ops/impl/reduce/bool/IsNaN.java | 2 +- .../api/ops/impl/reduce/floating/Mean.java | 6 +- .../linalg/api/ops/impl/reduce/same/Max.java | 4 +- .../linalg/api/ops/impl/reduce/same/Min.java | 4 +- .../linalg/api/ops/impl/reduce/same/Prod.java | 4 +- .../linalg/api/ops/impl/reduce/same/Sum.java | 4 +- .../api/ops/impl/reduce3/BaseReduce3Op.java | 14 +- .../api/ops/impl/reduce3/CosineDistance.java | 4 +- .../ops/impl/reduce3/CosineSimilarity.java | 4 +- .../nd4j/linalg/api/ops/impl/reduce3/Dot.java | 2 +- .../api/ops/impl/reduce3/EqualsWithEps.java | 2 +- .../ops/impl/reduce3/EuclideanDistance.java | 7 +- .../api/ops/impl/reduce3/HammingDistance.java | 6 +- .../api/ops/impl/reduce3/JaccardDistance.java | 6 +- .../ops/impl/reduce3/ManhattanDistance.java | 9 +- .../impl/summarystats/StandardDeviation.java | 2 +- .../api/ops/impl/summarystats/Variance.java | 8 +- .../linalg/api/ops/impl/transforms/Pad.java | 4 +- .../java/org/nd4j/linalg/api/shape/Shape.java | 4 + .../linalg/dimensionalityreduction/PCA.java | 5 +- .../java/org/nd4j/linalg/factory/Nd4j.java | 100 ++++++--- .../org/nd4j/linalg/profiler/OpProfiler.java | 12 +- .../ops/executioner/CudaExecutioner.java | 7 +- .../nativecpu/ops/NativeOpExecutioner.java | 10 +- .../java/org/nd4j/nativeblas/Nd4jCpu.java | 181 ++++++++++++---- .../opvalidation/ReductionOpValidation.java | 15 +- .../opvalidation/TransformOpValidation.java | 21 +- .../nd4j/autodiff/samediff/SameDiffTests.java | 21 ++ .../nd4j/autodiff/ui/FileReadWriteTests.java | 18 ++ .../evaluation/EvaluationCalibrationTest.java | 6 +- .../nd4j/imports/TFGraphs/BERTGraphTest.java | 3 + .../test/java/org/nd4j/linalg/LoneTest.java | 4 +- .../org/nd4j/linalg/NDArrayTestsFortran.java | 4 +- .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 65 ++++++ .../org/nd4j/linalg/dataset/DataSetTest.java | 5 +- .../nd4j/linalg/ops/OpExecutionerTests.java | 41 +++- .../linalg/api/buffer/BaseDataBuffer.java | 10 +- 80 files changed, 1558 insertions(+), 618 deletions(-) create mode 100644 datavec/datavec-api/src/main/java/org/datavec/api/split/StreamInputSplit.java create mode 100644 datavec/datavec-api/src/main/java/org/datavec/api/split/streams/FileStreamCreatorFunction.java create mode 100644 datavec/datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/FirstDigitTransform.java delete mode 100644 datavec/datavec-api/src/test/java/org/datavec/api/split/FileSplitTest.java create mode 100644 datavec/datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java delete mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/BaseReduction.java diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/BaseRecordReader.java b/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/BaseRecordReader.java index f0bfa496e1a7..986984a37f08 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/BaseRecordReader.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/BaseRecordReader.java @@ -17,8 +17,18 @@ package org.datavec.api.records.reader; import org.datavec.api.records.listener.RecordListener; +import org.datavec.api.split.InputSplit; +import org.datavec.api.split.StreamInputSplit; +import org.datavec.api.split.streams.FileStreamCreatorFunction; import org.datavec.api.writable.Writable; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.function.Function; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -31,7 +41,9 @@ */ public abstract class BaseRecordReader implements RecordReader { + protected InputSplit inputSplit; protected List listeners = new ArrayList<>(); + protected Function streamCreatorFn = new FileStreamCreatorFunction(); /** Invokes {@link RecordListener#recordRead(RecordReader, Object)} on all listeners. */ protected void invokeListeners(Object record) { @@ -40,6 +52,17 @@ protected void invokeListeners(Object record) { } } + @Override + public void initialize(InputSplit split) throws IOException, InterruptedException { + this.inputSplit = split; + if(split instanceof StreamInputSplit){ + StreamInputSplit s = (StreamInputSplit)split; + if(s.getStreamCreatorFn() != null){ + this.streamCreatorFn = s.getStreamCreatorFn(); + } + } + } + @Override public List getListeners() { return listeners; diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/FileRecordReader.java b/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/FileRecordReader.java index d35ecf75c78c..3497c14e6ea9 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/FileRecordReader.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/FileRecordReader.java @@ -16,7 +16,8 @@ package org.datavec.api.records.reader.impl; -import org.apache.commons.io.FileUtils; +import lombok.Getter; +import lombok.Setter; import org.datavec.api.conf.Configuration; import org.datavec.api.records.Record; import org.datavec.api.records.metadata.RecordMetaData; @@ -29,10 +30,9 @@ import java.io.*; import java.net.URI; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Iterator; -import java.util.List; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.*; /** * File reader/writer @@ -41,20 +41,20 @@ */ public class FileRecordReader extends BaseRecordReader { - protected Iterator iter; - protected Iterator locationsIterator; + protected Iterator locationsIterator; protected Configuration conf; - protected File currentFile; + protected URI currentUri; protected List labels; protected boolean appendLabel = false; - protected InputSplit inputSplit; + @Getter @Setter + protected String charset = StandardCharsets.UTF_8.name(); //Using String as StandardCharsets.UTF_8 is not serializable public FileRecordReader() {} @Override public void initialize(InputSplit split) throws IOException, InterruptedException { + super.initialize(split); doInitialize(split); - this.inputSplit = split; } @@ -63,17 +63,16 @@ protected void doInitialize(InputSplit split) { if (labels == null && appendLabel) { URI[] locations = split.locations(); if (locations.length > 0) { - //root dir relative to example where the label is the parent directory and the root directory is - //recursively the parent of that - File parent = new File(locations[0]).getParentFile().getParentFile(); - //calculate the labels relative to the parent file - labels = new ArrayList<>(); - - for (File labelDir : parent.listFiles()) - labels.add(labelDir.getName()); + Set labels = new HashSet<>(); + for(URI u : locations){ + String[] pathSplit = u.toString().split("[/\\\\]"); + labels.add(pathSplit[pathSplit.length-2]); + } + this.labels = new ArrayList<>(labels); + Collections.sort(this.labels); } } - locationsIterator = split.locationsPathIterator(); + locationsIterator = split.locationsIterator(); } @Override @@ -89,14 +88,20 @@ public List next() { return nextRecord().getRecord(); } - private List loadFromFile(File next) { + private List loadFromStream(URI uri, InputStream next, Charset charset) { List ret = new ArrayList<>(); try { - ret.add(new Text(FileUtils.readFileToString(next))); - if (appendLabel) - ret.add(new IntWritable(labels.indexOf(next.getParentFile().getName()))); + if(!(next instanceof BufferedInputStream)){ + next = new BufferedInputStream(next); + } + String s = org.apache.commons.io.IOUtils.toString(next, charset); + ret.add(new Text(s)); + if (appendLabel) { + int idx = getLabel(uri); + ret.add(new IntWritable(idx)); + } } catch (IOException e) { - e.printStackTrace(); + throw new IllegalStateException("Error reading from input stream: " + uri); } return ret; } @@ -108,7 +113,16 @@ private List loadFromFile(File next) { * @return The index of the current file's parent directory */ public int getCurrentLabel() { - return labels.indexOf(currentFile.getParentFile().getName()); + return getLabel(currentUri); + } + + public int getLabel(URI uri){ + String s = uri.toString(); + int lastIdx = Math.max(s.lastIndexOf('/'), s.lastIndexOf('\\')); //Note: if neither are found, -1 is fine here + String sub = s.substring(0, lastIdx); + int secondLastIdx = Math.max(sub.lastIndexOf('/'), sub.lastIndexOf('\\')); + String name = s.substring(secondLastIdx+1, lastIdx); + return labels.indexOf(name); } public List getLabels() { @@ -121,15 +135,7 @@ public void setLabels(List labels) { @Override public boolean hasNext() { - if (iter != null && iter.hasNext()) { - return true; - } - if (!locationsIterator.hasNext()) { - return false; - } - // iter is exhausted, set to iterate of the next location - this.advanceToNextLocation(); - return iter != null && iter.hasNext(); + return locationsIterator.hasNext(); } @Override @@ -191,41 +197,17 @@ public List record(URI uri, DataInputStream dataInputStream) throws IO @Override public Record nextRecord() { - if (iter == null || !iter.hasNext()) { - this.advanceToNextLocation(); - } - File next = iter.next(); - this.currentFile = next; + URI next = locationsIterator.next(); invokeListeners(next); - List ret = loadFromFile(next); - return new org.datavec.api.records.impl.Record(ret, - new RecordMetaDataURI(next.toURI(), FileRecordReader.class)); - } - - protected File nextFile() { - if (iter == null || !iter.hasNext()) { - this.advanceToNextLocation(); + List ret; + try(InputStream s = streamCreatorFn.apply(next)) { + ret = loadFromStream(next, s, Charset.forName(charset)); + } catch (IOException e){ + throw new RuntimeException("Error reading from stream for URI: " + next); } - File next = iter.next(); - this.currentFile = next; - return next; - } - protected void advanceToNextLocation () { - //File file; - String path = locationsIterator.next(); // should always have file:// preceding - if(!path.startsWith("file:")){ - path = "file:///" + path; - } - if(path.contains("\\")){ - path = path.replaceAll("\\\\","/"); - } - File file = new File(URI.create(path)); - if (file.isDirectory()) - iter = FileUtils.iterateFiles(file, null, true); - else - iter = Collections.singletonList(file).iterator(); + return new org.datavec.api.records.impl.Record(ret,new RecordMetaDataURI(next, FileRecordReader.class)); } @Override @@ -240,8 +222,13 @@ public List loadFromMetaData(List recordMetaDatas) throw for (RecordMetaData meta : recordMetaDatas) { URI uri = meta.getURI(); - File f = new File(uri); - List list = loadFromFile(f); + List list; + try(InputStream s = streamCreatorFn.apply(uri)) { + list = loadFromStream(uri, s, Charset.forName(charset)); + } catch (IOException e){ + throw new RuntimeException("Error reading from stream for URI: " + uri); + } + out.add(new org.datavec.api.records.impl.Record(list, meta)); } diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/LineRecordReader.java b/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/LineRecordReader.java index 1180efdbabcc..858810a35a66 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/LineRecordReader.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/LineRecordReader.java @@ -48,12 +48,11 @@ public class LineRecordReader extends BaseRecordReader { protected int splitIndex = 0; protected int lineIndex = 0; //Line index within the current split protected Configuration conf; - protected InputSplit inputSplit; protected boolean initialized; @Override public void initialize(InputSplit split) throws IOException, InterruptedException { - this.inputSplit = split; + super.initialize(split); this.iter = getIterator(0); this.initialized = true; } @@ -82,7 +81,8 @@ public List next() { lineIndex = 0; //New split opened -> reset line index try { close(); - iter = IOUtils.lineIterator(new InputStreamReader(locations[splitIndex].toURL().openStream())); +// iter = IOUtils.lineIterator(new InputStreamReader(locations[splitIndex].toURL().openStream())); + iter = getIterator(splitIndex); onLocationOpen(locations[splitIndex]); } catch (IOException e) { e.printStackTrace(); @@ -113,7 +113,7 @@ public boolean hasNext() { lineIndex = 0; //New split -> reset line count try { close(); - iter = IOUtils.lineIterator(new InputStreamReader(locations[splitIndex].toURL().openStream())); + iter = getIterator(splitIndex); onLocationOpen(locations[splitIndex]); } catch (IOException e) { e.printStackTrace(); @@ -201,14 +201,9 @@ protected Iterator getIterator(int location) { final Iterator uriIterator = inputSplit.locationsIterator(); while(uriIterator.hasNext()) uris.add(uriIterator.next()); - this.locations = uris.toArray(new URI[0]); + this.locations = uris.toArray(new URI[uris.size()]); if (locations.length > 0) { - InputStream inputStream; - try { - inputStream = locations[location].toURL().openStream(); - } catch (IOException e) { - throw new RuntimeException(e); - } + InputStream inputStream = streamCreatorFn.apply(locations[location]); iterator = IOUtils.lineIterator(new InputStreamReader(inputStream)); } } diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVSequenceRecordReader.java b/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVSequenceRecordReader.java index 68bf24ad145e..a3808edaa82e 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVSequenceRecordReader.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVSequenceRecordReader.java @@ -72,17 +72,12 @@ public SequenceRecord nextSequence() { if(!hasNext()){ throw new NoSuchElementException("No next element"); } - File next = iter.next(); - invokeListeners(next); - List> out; - try { - out = loadAndClose(new FileInputStream(next)); - } catch (IOException e) { - throw new RuntimeException(e); - } + URI next = locationsIterator.next(); + invokeListeners(next); - return new org.datavec.api.records.impl.SequenceRecord(out, new RecordMetaDataURI(next.toURI())); + List> out = loadAndClose(streamCreatorFn.apply(next)); + return new org.datavec.api.records.impl.SequenceRecord(out, new RecordMetaDataURI(next)); } private List> loadAndClose(InputStream inputStream) { diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonLineSequenceRecordReader.java b/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonLineSequenceRecordReader.java index cbd61f6f26bb..8352f98fc2b7 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonLineSequenceRecordReader.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonLineSequenceRecordReader.java @@ -74,14 +74,9 @@ public SequenceRecord nextSequence() { throw new NoSuchElementException("No next element"); } - File next = iter.next(); - List> out; - try { - out = loadAndClose(new FileInputStream(next)); - } catch (IOException e){ - throw new RuntimeException(e); - } - return new org.datavec.api.records.impl.SequenceRecord(out, new RecordMetaDataURI(next.toURI())); + URI next = locationsIterator.next(); + List> out = loadAndClose(streamCreatorFn.apply(next)); + return new org.datavec.api.records.impl.SequenceRecord(out, new RecordMetaDataURI(next)); } @Override diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonRecordReader.java b/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonRecordReader.java index 5334f123707f..cdae52c4f779 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonRecordReader.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonRecordReader.java @@ -16,7 +16,10 @@ package org.datavec.api.records.reader.impl.jackson; +import lombok.Getter; +import lombok.Setter; import org.apache.commons.io.FileUtils; +import org.apache.commons.io.IOUtils; import org.datavec.api.conf.Configuration; import org.datavec.api.io.labels.PathLabelGenerator; import org.datavec.api.records.Record; @@ -32,6 +35,8 @@ import java.io.*; import java.net.URI; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; import java.util.*; /** @@ -69,6 +74,8 @@ public class JacksonRecordReader extends BaseRecordReader { private int labelPosition; private InputSplit is; private Random r; + @Getter @Setter + private String charset = StandardCharsets.UTF_8.name(); //Using String as StandardCharsets.UTF_8 is not serializable private URI[] uris; private int cursor = 0; @@ -102,6 +109,7 @@ public JacksonRecordReader(FieldSelection selection, ObjectMapper mapper, boolea public void initialize(InputSplit split) throws IOException, InterruptedException { if (split instanceof FileSplit) throw new UnsupportedOperationException("Cannot use JacksonRecordReader with FileSplit"); + super.initialize(inputSplit); this.uris = split.locations(); if (shuffle) { List list = Arrays.asList(uris); @@ -125,8 +133,8 @@ public List next() { URI uri = uris[cursor++]; invokeListeners(uri); String fileAsString; - try { - fileAsString = FileUtils.readFileToString(new File(uri.toURL().getFile())); + try (InputStream s = streamCreatorFn.apply(uri)){ + fileAsString = IOUtils.toString(s, charset); } catch (IOException e) { throw new RuntimeException("Error reading URI file", e); } diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexSequenceRecordReader.java b/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexSequenceRecordReader.java index fb7201d4717a..410e5b9fe994 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexSequenceRecordReader.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexSequenceRecordReader.java @@ -27,6 +27,7 @@ import org.datavec.api.split.InputSplit; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; +import org.nd4j.base.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -156,17 +157,17 @@ public void reset() { @Override public SequenceRecord nextSequence() { - File next = this.nextFile(); + Preconditions.checkState(hasNext(), "No next element available"); + URI next = locationsIterator.next(); String fileContents; - try { - fileContents = FileUtils.readFileToString(next, charset.name()); + try (InputStream s = streamCreatorFn.apply(next)){ + fileContents = IOUtils.toString(s, charset); } catch (IOException e) { throw new RuntimeException(e); } - List> sequence = loadSequence(fileContents, next.toURI()); - return new org.datavec.api.records.impl.SequenceRecord(sequence, - new RecordMetaDataURI(next.toURI(), RegexSequenceRecordReader.class)); + List> sequence = loadSequence(fileContents, next); + return new org.datavec.api.records.impl.SequenceRecord(sequence, new RecordMetaDataURI(next, RegexSequenceRecordReader.class)); } @Override diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/split/CollectionInputSplit.java b/datavec/datavec-api/src/main/java/org/datavec/api/split/CollectionInputSplit.java index 4c3ed6f52f39..0f2c7f3f888c 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/split/CollectionInputSplit.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/split/CollectionInputSplit.java @@ -18,6 +18,7 @@ import java.io.*; import java.net.URI; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.LinkedList; @@ -34,7 +35,7 @@ public CollectionInputSplit(URI[] array){ } public CollectionInputSplit(Collection list) { - uriStrings = new LinkedList<>(); + uriStrings = new ArrayList<>(list.size()); for (URI uri : list) { uriStrings.add(uri.toString()); } diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/split/NumberedFileInputSplit.java b/datavec/datavec-api/src/main/java/org/datavec/api/split/NumberedFileInputSplit.java index e53587abedd6..8c1e368aa6b5 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/split/NumberedFileInputSplit.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/split/NumberedFileInputSplit.java @@ -122,8 +122,16 @@ public long length() { public URI[] locations() { URI[] uris = new URI[(int) length()]; int x = 0; - for (int i = minIdx; i <= maxIdx; i++) { - uris[x++] = Paths.get(String.format(baseString, i)).toUri(); + if(baseString.matches(".*:/.*")){ + //URI (has scheme) + for (int i = minIdx; i <= maxIdx; i++) { + uris[x++] = URI.create(String.format(baseString, i)); + } + } else { + //File, no scheme + for (int i = minIdx; i <= maxIdx; i++) { + uris[x++] = new File(String.format(baseString, i)).toURI(); + } } return uris; } diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/split/StreamInputSplit.java b/datavec/datavec-api/src/main/java/org/datavec/api/split/StreamInputSplit.java new file mode 100644 index 000000000000..24737dbab692 --- /dev/null +++ b/datavec/datavec-api/src/main/java/org/datavec/api/split/StreamInputSplit.java @@ -0,0 +1,154 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.api.split; + +import lombok.Data; +import lombok.NonNull; +import org.datavec.api.util.files.ShuffledListIterator; +import org.nd4j.linalg.function.Function; +import org.nd4j.linalg.util.MathUtils; + +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; +import java.util.*; + +/** + * StreamInputSplit is a way of specifying input as a bunch of URIs, as well as the way those URIs should be opened. + * For example, if data was stored remotely (HDFS, S3, etc) you could use StreamInputSplit to load them, doing two things: + * (a) providing the URIs of the remote data to the constructor
+ * (b) providing a {@code Function} that opens an InputStream for the given URI.
+ *
+ * Note: supports optional randomization (shuffling of order in which streams are read) via {@link #StreamInputSplit(List, Function, Random)} + * by providing a {@link Random} instance. If no Random instance is provided, the order will always be according to the + * order in the provided list of URIs. + * + * @author Alex Black + */ +@Data +public class StreamInputSplit implements InputSplit { + + protected List uris; + protected Function streamCreatorFn; + protected Random rng; + protected int[] order; + + /** + * Create a StreamInputSplit with no randomization + * + * @param uris The list of URIs to load + * @param streamCreatorFn The function to be used to create InputStream objects for a given URI. + */ + public StreamInputSplit(@NonNull List uris, @NonNull Function streamCreatorFn) { + this(uris, streamCreatorFn, null); + } + + /** + * Create a StreamInputSplit with optional randomization + * + * @param uris The list of URIs to load + * @param streamCreatorFn The function to be used to create InputStream objects for a given URI + * @param rng Random number generator instance. If non-null: streams will be iterated over in a random + * order. If null: no randomization (iteration order is according to the URIs list) + */ + public StreamInputSplit(@NonNull List uris, @NonNull Function streamCreatorFn, Random rng){ + this.uris = uris; + this.streamCreatorFn = streamCreatorFn; + this.rng = rng; + } + + @Override + public boolean canWriteToLocation(URI location) { + throw new UnsupportedOperationException(); + } + + @Override + public String addNewLocation() { + throw new UnsupportedOperationException(); + } + + @Override + public String addNewLocation(String location) { + throw new UnsupportedOperationException(); + } + + @Override + public void updateSplitLocations(boolean reset) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean needsBootstrapForWrite() { + throw new UnsupportedOperationException(); + } + + @Override + public void bootStrapForWrite() { + throw new UnsupportedOperationException(); + } + + @Override + public OutputStream openOutputStreamFor(String location) throws Exception { + throw new UnsupportedOperationException(); + } + + @Override + public InputStream openInputStreamFor(String location) throws Exception { + throw new UnsupportedOperationException(); + } + + @Override + public long length() { + return uris.size(); + } + + @Override + public URI[] locations() { + return uris.toArray(new URI[uris.size()]); + } + + @Override + public Iterator locationsIterator() { + if(rng == null){ + return uris.iterator(); + } else { + if(order == null){ + order = new int[uris.size()]; + for( int i=0; i(uris, order); + } + } + + @Override + public Iterator locationsPathIterator() { + throw new UnsupportedOperationException(); + } + + @Override + public void reset() { + //No op + } + + @Override + public boolean resetSupported() { + return true; + } +} diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/split/streams/FileStreamCreatorFunction.java b/datavec/datavec-api/src/main/java/org/datavec/api/split/streams/FileStreamCreatorFunction.java new file mode 100644 index 000000000000..53b8a415ff03 --- /dev/null +++ b/datavec/datavec-api/src/main/java/org/datavec/api/split/streams/FileStreamCreatorFunction.java @@ -0,0 +1,38 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.api.split.streams; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.function.Function; + +import java.io.*; +import java.net.URI; + +public class FileStreamCreatorFunction implements Function, Serializable { + + @Override + public InputStream apply(URI uri) { + Preconditions.checkState(uri.getScheme() == null || uri.getScheme().equalsIgnoreCase("file"), + "Attempting to open URI that is not a File URI; for other stream types, you must use an appropriate stream loader function. URI: %s", uri); + try { + return new FileInputStream(new File(uri)); + } catch (IOException e){ + throw new RuntimeException("Error loading stream for file: " + uri, e); + } + } + +} diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java index 126c312633cb..85905f15df28 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java @@ -38,10 +38,7 @@ import org.datavec.api.transform.sequence.window.ReduceSequenceByWindowTransform; import org.datavec.api.transform.sequence.window.WindowFunction; import org.datavec.api.transform.serde.JsonMappers; -import org.datavec.api.transform.transform.categorical.CategoricalToIntegerTransform; -import org.datavec.api.transform.transform.categorical.CategoricalToOneHotTransform; -import org.datavec.api.transform.transform.categorical.IntegerToCategoricalTransform; -import org.datavec.api.transform.transform.categorical.StringToCategoricalTransform; +import org.datavec.api.transform.transform.categorical.*; import org.datavec.api.transform.transform.column.*; import org.datavec.api.transform.transform.condition.ConditionalCopyValueTransform; import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransform; @@ -1409,6 +1406,43 @@ public Builder ndArrayDistanceTransform(String newColumnName, Distance distance, return transform(new NDArrayDistanceTransform(newColumnName, distance, firstCol, secondCol)); } + /** + * FirstDigitTransform converts a column to a categorical column, with values being the first digit of the number.
+ * For example, "3.1415" becomes "3" and "2.0" becomes "2".
+ * Negative numbers ignore the sign: "-7.123" becomes "7".
+ * Note that two {@link FirstDigitTransform.Mode}s are supported, which determines how non-numerical entries should be handled:
+ * EXCEPTION_ON_INVALID: output has 10 category values ("0", ..., "9"), and any non-numerical values result in an exception
+ * INCLUDE_OTHER_CATEGORY: output has 11 category values ("0", ..., "9", "Other"), all non-numerical values are mapped to "Other"
+ *
+ * FirstDigitTransform is useful (combined with {@link CategoricalToOneHotTransform} and Reductions) to implement + * Benford's law. + * + * @param inputColumn Input column name + * @param outputColumn Output column name. If same as input, input column is replaced + */ + public Builder firstDigitTransform(String inputColumn, String outputColumn){ + return firstDigitTransform(inputColumn, outputColumn, FirstDigitTransform.Mode.INCLUDE_OTHER_CATEGORY); + } + + /** + * FirstDigitTransform converts a column to a categorical column, with values being the first digit of the number.
+ * For example, "3.1415" becomes "3" and "2.0" becomes "2".
+ * Negative numbers ignore the sign: "-7.123" becomes "7".
+ * Note that two {@link FirstDigitTransform.Mode}s are supported, which determines how non-numerical entries should be handled:
+ * EXCEPTION_ON_INVALID: output has 10 category values ("0", ..., "9"), and any non-numerical values result in an exception
+ * INCLUDE_OTHER_CATEGORY: output has 11 category values ("0", ..., "9", "Other"), all non-numerical values are mapped to "Other"
+ *
+ * FirstDigitTransform is useful (combined with {@link CategoricalToOneHotTransform} and Reductions) to implement + * Benford's law. + * + * @param inputColumn Input column name + * @param outputColumn Output column name. If same as input, input column is replaced + * @param mode See {@link FirstDigitTransform.Mode} + */ + public Builder firstDigitTransform(String inputColumn, String outputColumn, FirstDigitTransform.Mode mode){ + return transform(new FirstDigitTransform(inputColumn, outputColumn, mode)); + } + /** * Create the TransformProcess object */ diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/FirstDigitTransform.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/FirstDigitTransform.java new file mode 100644 index 000000000000..f244a71ad2cf --- /dev/null +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/FirstDigitTransform.java @@ -0,0 +1,202 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.api.transform.transform.categorical; + +import org.datavec.api.transform.metadata.CategoricalMetaData; +import org.datavec.api.transform.metadata.ColumnMetaData; +import org.datavec.api.transform.schema.Schema; +import org.datavec.api.transform.transform.BaseTransform; +import org.datavec.api.writable.Text; +import org.datavec.api.writable.Writable; +import org.nd4j.base.Preconditions; +import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties; +import org.nd4j.shade.jackson.annotation.JsonProperty; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * FirstDigitTransform converts a column to a categorical column, with values being the first digit of the number.
+ * For example, "3.1415" becomes "3" and "2.0" becomes "2".
+ * Negative numbers ignore the sign: "-7.123" becomes "7".
+ * Note that two {@link Mode}s are supported, which determines how non-numerical entries should be handled:
+ * EXCEPTION_ON_INVALID: output has 10 category values ("0", ..., "9"), and any non-numerical values result in an exception
+ * INCLUDE_OTHER_CATEGORY: output has 11 category values ("0", ..., "9", "Other"), all non-numerical values are mapped to "Other"
+ *
+ * FirstDigitTransform is useful (combined with {@link CategoricalToOneHotTransform} and Reductions) to implement + * Benford's law. + + * + * @author Alex Black + */ +@JsonIgnoreProperties({"inputSchema", "columnIdx"}) +public class FirstDigitTransform extends BaseTransform { + public static final String OTHER_CATEGORY = "Other"; + + /** + * Mode determines how non-numerical entries should be handled:
+ * EXCEPTION_ON_INVALID: output has 10 category values ("0", ..., "9"), and any non-numerical values result in an exception
+ * INCLUDE_OTHER_CATEGORY: output has 11 category values ("0", ..., "9", "Other"), all non-numerical values are mapped to "Other"
+ */ + public enum Mode { + EXCEPTION_ON_INVALID, + INCLUDE_OTHER_CATEGORY + } + + protected String inputColumn; + protected String outputColumn; + protected Mode mode; + private int columnIdx = -1; + + /** + * @param inputColumn Input column name + * @param outputColumn Output column name. If same as input, input column is replaced + * @param mode See {@link FirstDigitTransform.Mode} + */ + public FirstDigitTransform(@JsonProperty("inputColumn") String inputColumn, @JsonProperty("outputColumn") String outputColumn, + @JsonProperty("mode") Mode mode){ + this.inputColumn = inputColumn; + this.outputColumn = outputColumn; + this.mode = mode; + } + + @Override + public List map(List writables) { + List out = new ArrayList<>(); + int i=0; + boolean inplace = inputColumn.equals(outputColumn); + for(Writable w : writables){ + if(i++ == columnIdx) { + if(!inplace){ + out.add(w); + } + + String s = w.toString(); + if (s.isEmpty()) { + if (mode == Mode.INCLUDE_OTHER_CATEGORY) { + out.add(new Text(OTHER_CATEGORY)); + } else { + throw new IllegalStateException("Encountered empty string in FirstDigitTransform that is set to Mode.EXCEPTION_ON_INVALID." + + " Either data contains an invalid (non-numerical) entry, or set FirstDigitTransform to Mode.INCLUDE_OTHER_CATEGORY"); + } + } else { + char first = s.charAt(0); + if (first == '-' && s.length() > 1) { + //Handle negatives + first = s.charAt(1); + } + if (first >= '0' && first <= '9') { + out.add(new Text(String.valueOf(first))); + } else { + if (mode == Mode.INCLUDE_OTHER_CATEGORY) { + out.add(new Text(OTHER_CATEGORY)); + } else { + String s2 = s; + if (s.length() > 100) { + s2 = s2.substring(0, 100) + "..."; + } + throw new IllegalStateException("Encountered string \"" + s2 + "\" with non-numerical first character in " + + "FirstDigitTransform that is set to Mode.EXCEPTION_ON_INVALID." + + " Either data contains an invalid (non-numerical) entry, or set FirstDigitTransform to Mode.INCLUDE_OTHER_CATEGORY"); + } + } + } + } else { + out.add(w); + } + } + return out; + } + + @Override + public Object map(Object input) { + throw new UnsupportedOperationException("Not supported"); + } + + @Override + public Object mapSequence(Object sequence) { + throw new UnsupportedOperationException("Not supported"); + } + + @Override + public String toString() { + return "FirstDigitTransform(input=\"" + inputColumn + "\",output=\"" + outputColumn + "\",mode=" + mode + ")"; + } + + @Override + public Schema transform(Schema inputSchema) { + List origNames = inputSchema.getColumnNames(); + List origMeta = inputSchema.getColumnMetaData(); + + Preconditions.checkState(origNames.contains(inputColumn), "Input column with name \"%s\" not found in schema", inputColumn); + Preconditions.checkState(inputColumn.equals(outputColumn) || !origNames.contains(outputColumn), + "Output column with name \"%s\" already exists in schema (only allowable if input column == output column)", outputColumn); + + List outMeta = new ArrayList<>(origNames.size()+1); + for( int i=0; i l = Collections.unmodifiableList( + mode == Mode.INCLUDE_OTHER_CATEGORY ? + Arrays.asList("0", "1", "2", "3", "4", "5", "6", "7", "8", "9", OTHER_CATEGORY) : + Arrays.asList("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")); + + CategoricalMetaData cm = new CategoricalMetaData(outputColumn, l); + + outMeta.add(cm); + } else { + outMeta.add(origMeta.get(i)); + } + } + + return inputSchema.newSchema(outMeta); + } + + @Override + public String outputColumnName() { + return outputColumn; + } + + @Override + public String[] outputColumnNames() { + return new String[]{outputColumn}; + } + + @Override + public String[] columnNames() { + return new String[]{inputColumn}; + } + + @Override + public String columnName() { + return inputColumn; + } + + @Override + public void setInputSchema(Schema schema){ + super.setInputSchema(schema); + + columnIdx = schema.getIndexOfColumn(inputColumn); + Preconditions.checkState(columnIdx >= 0, "Input column \"%s\" not found in schema", inputColumn); + } +} diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/util/files/UriFromPathIterator.java b/datavec/datavec-api/src/main/java/org/datavec/api/util/files/UriFromPathIterator.java index b2de829e1b40..ffbd36fe57b9 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/util/files/UriFromPathIterator.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/util/files/UriFromPathIterator.java @@ -18,6 +18,7 @@ import lombok.AllArgsConstructor; +import java.io.File; import java.net.URI; import java.net.URISyntaxException; import java.util.Iterator; @@ -45,7 +46,14 @@ public URI next() { throw new NoSuchElementException("No next element"); } try { - return new URI(paths.next()); + String s = paths.next(); + if(!s.matches(".*:/.*")){ + //No scheme - assume file for backward compatibility + return new File(s).toURI(); + } else { + return new URI(s); + } + } catch (URISyntaxException e) { throw new RuntimeException(e); } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java index e2f01ba99611..a209abb91fa3 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java @@ -26,7 +26,9 @@ import org.datavec.api.split.InputSplit; import org.datavec.api.split.InputStreamInputSplit; import org.datavec.api.writable.Writable; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.TemporaryFolder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -48,7 +50,8 @@ */ public class LineReaderTest { - private static Logger log = LoggerFactory.getLogger(LineReaderTest.class); + @Rule + public TemporaryFolder testDir = new TemporaryFolder(); @Test public void testLineReader() throws Exception { @@ -91,11 +94,7 @@ public void testLineReader() throws Exception { @Test public void testLineReaderMetaData() throws Exception { - String tempDir = System.getProperty("java.io.tmpdir"); - File tmpdir = new File(tempDir, "tmpdir-testLineReader"); - if (tmpdir.exists()) - tmpdir.delete(); - tmpdir.mkdir(); + File tmpdir = testDir.newFolder(); File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt")); File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt")); @@ -157,9 +156,7 @@ public void testLineReaderMetaData() throws Exception { @Test public void testLineReaderWithInputStreamInputSplit() throws Exception { - String tempDir = System.getProperty("java.io.tmpdir"); - File tmpdir = new File(tempDir, "tmpdir"); - tmpdir.mkdir(); + File tmpdir = testDir.newFolder(); File tmp1 = new File(tmpdir, "tmp1.txt.gz"); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/FileSplitTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/FileSplitTest.java deleted file mode 100644 index 23df52097e5c..000000000000 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/FileSplitTest.java +++ /dev/null @@ -1,103 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.api.split; - -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; - -import java.io.File; -import java.util.UUID; - - -/** - * Created by nyghtowl on 11/8/15. - */ -public class FileSplitTest { - - protected File file, file1, file2, file3, file4, file5, file6, newPath; - protected String[] allForms = {"jpg", "jpeg", "JPG", "JPEG"}; - private static String localPath = System.getProperty("java.io.tmpdir") + File.separator; - private static String testPath = localPath + "test" + File.separator; - - // These cannot run on TravisCI - uncomment when checking locally - - @Test(expected = IllegalArgumentException.class) - public void testEmptySplit() { - new FileSplit(new File("THE_TEST_WILL_PASS_UNLESS_YOU_CAN_PREDICT_THIS_UUID: " + UUID.randomUUID())); - } - - @Rule - public TemporaryFolder mainFolder = new TemporaryFolder(); - - // - // @Before - // public void doBefore() throws IOException { - // file = mainFolder.newFile("myfile.txt"); - // - // newPath = new File(testPath); - // - // newPath.mkdir(); - // - // file1 = File.createTempFile("myfile_1", ".jpg", newPath); - // file2 = File.createTempFile("myfile_2", ".txt", newPath); - // file3 = File.createTempFile("myfile_3", ".jpg", newPath); - // file4 = File.createTempFile("treehouse_4", ".csv", newPath); - // file5 = File.createTempFile("treehouse_5", ".csv", newPath); - // file6 = File.createTempFile("treehouse_6", ".jpg", newPath); - // - // } - // - // @Test - // public void testInitializeLoadSingleFile(){ - // InputSplit split = new FileSplit(file, allForms); - // assertEquals(split.locations()[0], file.toURI()); - // - // } - // - // @Test - // public void testInitializeLoadMulFiles() throws IOException{ - // InputSplit split = new FileSplit(newPath, allForms, true); - // assertEquals(3, split.locations().length); - // assertEquals(file1.toURI(), split.locations()[0]); - // assertEquals(file3.toURI(), split.locations()[1]); - // } - // - // @Test - // public void testInitializeMulFilesShuffle() throws IOException{ - // InputSplit split = new FileSplit(newPath, new Random(123)); - // InputSplit split2 = new FileSplit(newPath, new Random(123)); - // assertEquals(6, split.locations().length); - // assertEquals(6, split2.locations().length); - // assertEquals(split.locations()[3], split2.locations()[3]); - // } - // - // @After - // public void doAfter(){ - // mainFolder.delete(); - // file.delete(); - // file1.delete(); - // file2.delete(); - // file3.delete(); - // file4.delete(); - // file5.delete(); - // file6.delete(); - // newPath.delete(); - // - // } - -} diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java new file mode 100644 index 000000000000..c618c625d1b5 --- /dev/null +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java @@ -0,0 +1,201 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.api.split; + +import org.apache.commons.io.FileUtils; +import org.datavec.api.records.reader.impl.csv.CSVRecordReader; +import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; +import org.datavec.api.writable.Text; +import org.datavec.api.writable.Writable; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.nd4j.linalg.function.Function; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +public class TestStreamInputSplit { + + @Rule + public TemporaryFolder testDir = new TemporaryFolder(); + + @Test + public void testCsvSimple() throws Exception { + File dir = testDir.newFolder(); + File f1 = new File(dir, "file1.txt"); + File f2 = new File(dir, "file2.txt"); + + FileUtils.writeStringToFile(f1, "a,b,c\nd,e,f", StandardCharsets.UTF_8); + FileUtils.writeStringToFile(f2, "1,2,3", StandardCharsets.UTF_8); + + List uris = Arrays.asList(f1.toURI(), f2.toURI()); + + CSVRecordReader rr = new CSVRecordReader(); + + TestStreamFunction fn = new TestStreamFunction(); + InputSplit is = new StreamInputSplit(uris, fn); + rr.initialize(is); + + List> exp = new ArrayList<>(); + exp.add(Arrays.asList(new Text("a"), new Text("b"), new Text("c"))); + exp.add(Arrays.asList(new Text("d"), new Text("e"), new Text("f"))); + exp.add(Arrays.asList(new Text("1"), new Text("2"), new Text("3"))); + + List> act = new ArrayList<>(); + while(rr.hasNext()){ + act.add(rr.next()); + } + + assertEquals(exp, act); + + //Check that the specified stream loading function was used, not the default: + assertEquals(uris, fn.calledWithUris); + + rr.reset(); + int count = 0; + while(rr.hasNext()) { + count++; + rr.next(); + } + assertEquals(3, count); + } + + + @Test + public void testCsvSequenceSimple() throws Exception { + + File dir = testDir.newFolder(); + File f1 = new File(dir, "file1.txt"); + File f2 = new File(dir, "file2.txt"); + + FileUtils.writeStringToFile(f1, "a,b,c\nd,e,f", StandardCharsets.UTF_8); + FileUtils.writeStringToFile(f2, "1,2,3", StandardCharsets.UTF_8); + + List uris = Arrays.asList(f1.toURI(), f2.toURI()); + + CSVSequenceRecordReader rr = new CSVSequenceRecordReader(); + + TestStreamFunction fn = new TestStreamFunction(); + InputSplit is = new StreamInputSplit(uris, fn); + rr.initialize(is); + + List>> exp = new ArrayList<>(); + exp.add(Arrays.asList( + Arrays.asList(new Text("a"), new Text("b"), new Text("c")), + Arrays.asList(new Text("d"), new Text("e"), new Text("f")))); + exp.add(Arrays.asList( + Arrays.asList(new Text("1"), new Text("2"), new Text("3")))); + + List>> act = new ArrayList<>(); + while (rr.hasNext()) { + act.add(rr.sequenceRecord()); + } + + assertEquals(exp, act); + + //Check that the specified stream loading function was used, not the default: + assertEquals(uris, fn.calledWithUris); + + rr.reset(); + int count = 0; + while(rr.hasNext()) { + count++; + rr.sequenceRecord(); + } + assertEquals(2, count); + } + + @Test + public void testShuffle() throws Exception { + File dir = testDir.newFolder(); + File f1 = new File(dir, "file1.txt"); + File f2 = new File(dir, "file2.txt"); + File f3 = new File(dir, "file3.txt"); + + FileUtils.writeStringToFile(f1, "a,b,c", StandardCharsets.UTF_8); + FileUtils.writeStringToFile(f2, "1,2,3", StandardCharsets.UTF_8); + FileUtils.writeStringToFile(f3, "x,y,z", StandardCharsets.UTF_8); + + List uris = Arrays.asList(f1.toURI(), f2.toURI(), f3.toURI()); + + CSVSequenceRecordReader rr = new CSVSequenceRecordReader(); + + TestStreamFunction fn = new TestStreamFunction(); + InputSplit is = new StreamInputSplit(uris, fn, new Random(12345)); + rr.initialize(is); + + List>> act = new ArrayList<>(); + while (rr.hasNext()) { + act.add(rr.sequenceRecord()); + } + + rr.reset(); + List>> act2 = new ArrayList<>(); + while (rr.hasNext()) { + act2.add(rr.sequenceRecord()); + } + + rr.reset(); + List>> act3 = new ArrayList<>(); + while (rr.hasNext()) { + act3.add(rr.sequenceRecord()); + } + + assertEquals(3, act.size()); + assertEquals(3, act2.size()); + assertEquals(3, act3.size()); + + /* + System.out.println(act); + System.out.println("---------"); + System.out.println(act2); + System.out.println("---------"); + System.out.println(act3); + */ + + //Check not the same. With this RNG seed, results are different for first 3 resets + assertNotEquals(act, act2); + assertNotEquals(act2, act3); + assertNotEquals(act, act3); + } + + + public static class TestStreamFunction implements Function { + public List calledWithUris = new ArrayList<>(); + @Override + public InputStream apply(URI uri) { + calledWithUris.add(uri); //Just for testing to ensure function is used + try { + return new FileInputStream(new File(uri)); + } catch (IOException e){ + throw new RuntimeException(e); + } + } + } +} diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestJsonYaml.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestJsonYaml.java index 8e70a6c50a3c..9f647d365ba3 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestJsonYaml.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestJsonYaml.java @@ -146,6 +146,8 @@ public void testToFromJsonYaml() { .addConstantColumn("testColSeq", ColumnType.Integer, new DoubleWritable(0)) .offsetSequence(Collections.singletonList("testColSeq"), 1, SequenceOffsetTransform.OperationType.InPlace) .addConstantColumn("someTextCol", ColumnType.String, new Text("some values")) + .addConstantColumn("testFirstDigit", ColumnType.Double, new DoubleWritable(0)) + .firstDigitTransform("testFirstDigit", "testFirstDigitOut") .build(); String asJson = tp.toJson(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java index 5210a4b697c7..bcc8d7abeb5a 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java @@ -57,8 +57,6 @@ import org.junit.Assert; import org.junit.Test; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.shade.jackson.core.JsonFactory; -import org.nd4j.shade.jackson.databind.ObjectMapper; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -1559,4 +1557,51 @@ public void testTextToTermIndexSequenceTransform(){ TransformProcess tp2 = TransformProcess.fromJson(json); assertEquals(tp, tp2); } + + + @Test + public void testFirstDigitTransform(){ + Schema s = new Schema.Builder() + .addColumnString("data") + .addColumnDouble("double") + .addColumnString("stringNumber") + .build(); + + TransformProcess tp = new TransformProcess.Builder(s) + .firstDigitTransform("double", "fdDouble", FirstDigitTransform.Mode.EXCEPTION_ON_INVALID) + .firstDigitTransform("stringNumber", "stringNumber", FirstDigitTransform.Mode.INCLUDE_OTHER_CATEGORY) + .build(); + + Schema s2 = tp.getFinalSchema(); + assertEquals(Arrays.asList("data","double", "fdDouble", "stringNumber"), s2.getColumnNames()); + + assertEquals(Arrays.asList(ColumnType.String, ColumnType.Double, ColumnType.Categorical, ColumnType.Categorical), s2.getColumnTypes()); + + List> in = Arrays.asList( + Arrays.asList(new Text("a"), new DoubleWritable(3.14159), new Text("8e-4")), + Arrays.asList(new Text("b"), new DoubleWritable(2.71828), new Text("7e2")), + Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("6e8")), + Arrays.asList(new Text("c"), new DoubleWritable(-2), new Text("non numerical"))); + + List> expected = Arrays.asList( + Arrays.asList(new Text("a"), new DoubleWritable(3.14159), new Text("3"), new Text("8")), + Arrays.asList(new Text("b"), new DoubleWritable(2.71828), new Text("2"), new Text("7")), + Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("1"), new Text("6")), + Arrays.asList(new Text("c"), new DoubleWritable(-2), new Text("2"), new Text("Other"))); + + List> out = new ArrayList<>(); + for(List i : in){ + out.add(tp.execute(i)); + } + assertEquals(expected, out); + + //Test Benfords law use case: + TransformProcess tp2 = new TransformProcess.Builder(s) + .firstDigitTransform("double", "fdDouble", FirstDigitTransform.Mode.EXCEPTION_ON_INVALID) + .firstDigitTransform("stringNumber", "stringNumber", FirstDigitTransform.Mode.INCLUDE_OTHER_CATEGORY) + .removeColumns("data", "double") + .categoricalToOneHot("fdDouble", "stringNumber") + .reduce(new Reducer.Builder(ReduceOp.Sum).build()) + .build(); + } } diff --git a/datavec/datavec-data/datavec-data-codec/src/main/java/org/datavec/codec/reader/BaseCodecRecordReader.java b/datavec/datavec-data/datavec-data-codec/src/main/java/org/datavec/codec/reader/BaseCodecRecordReader.java index 38d58863a07c..e2d136474c5b 100644 --- a/datavec/datavec-data/datavec-data-codec/src/main/java/org/datavec/codec/reader/BaseCodecRecordReader.java +++ b/datavec/datavec-data/datavec-data-codec/src/main/java/org/datavec/codec/reader/BaseCodecRecordReader.java @@ -60,13 +60,10 @@ public abstract class BaseCodecRecordReader extends FileRecordReader implements @Override public List> sequenceRecord() { - if (iter == null || !iter.hasNext()) { - this.advanceToNextLocation(); - } - File next = iter.next(); + URI next = locationsIterator.next(); - try { - return loadData(next, null); + try (InputStream s = streamCreatorFn.apply(next)){ + return loadData(null, s); } catch (IOException e) { throw new RuntimeException(e); } @@ -116,16 +113,16 @@ public Configuration getConf() { @Override public SequenceRecord nextSequence() { - File next = this.nextFile(); + URI next = locationsIterator.next(); List> list; - try { - list = loadData(next, null); + try (InputStream s = streamCreatorFn.apply(next)){ + list = loadData(null, s); } catch (IOException e) { throw new RuntimeException(e); } return new org.datavec.api.records.impl.SequenceRecord(list, - new RecordMetaDataURI(next.toURI(), CodecRecordReader.class)); + new RecordMetaDataURI(next, CodecRecordReader.class)); } @Override @@ -137,14 +134,12 @@ public SequenceRecord loadSequenceFromMetaData(RecordMetaData recordMetaData) th public List loadSequenceFromMetaData(List recordMetaDatas) throws IOException { List out = new ArrayList<>(); for (RecordMetaData meta : recordMetaDatas) { - File f = new File(meta.getURI()); - - List> list = loadData(f, null); - out.add(new org.datavec.api.records.impl.SequenceRecord(list, meta)); + try (InputStream s = streamCreatorFn.apply(meta.getURI())){ + List> list = loadData(null, s); + out.add(new org.datavec.api.records.impl.SequenceRecord(list, meta)); + } } return out; } - - } diff --git a/datavec/datavec-excel/src/main/java/org/datavec/poi/excel/ExcelRecordReader.java b/datavec/datavec-excel/src/main/java/org/datavec/poi/excel/ExcelRecordReader.java index 559ef799e480..13184a9b14f5 100644 --- a/datavec/datavec-excel/src/main/java/org/datavec/poi/excel/ExcelRecordReader.java +++ b/datavec/datavec-excel/src/main/java/org/datavec/poi/excel/ExcelRecordReader.java @@ -31,6 +31,7 @@ import java.io.File; import java.io.IOException; +import java.io.InputStream; import java.util.ArrayList; import java.util.Iterator; import java.util.List; @@ -120,7 +121,7 @@ public Record nextRecord(){ Record record = new org.datavec.api.records.impl.Record(ret, new RecordMetaDataIndex( currRow.getRowNum(), - super.currentFile.toURI(), + super.currentUri, ExcelRecordReader.class)); return record; } @@ -132,7 +133,7 @@ else if(sheetIterator != null && sheetIterator.hasNext()) { Record record = new org.datavec.api.records.impl.Record(rowToRecord(currRow), new RecordMetaDataIndex( currRow.getRowNum(), - super.currentFile.toURI(), + super.currentUri, ExcelRecordReader.class)); return record; @@ -140,27 +141,30 @@ else if(sheetIterator != null && sheetIterator.hasNext()) { //finally extract workbooks from files and iterate over those starting again at top - File nextFile = super.nextFile(); - // Creating a Workbook from an Excel file (.xls or .xlsx) - try { - if(currWorkBook != null) { - currWorkBook.close(); - } - - this.currWorkBook = WorkbookFactory.create(nextFile); - this.sheetIterator = currWorkBook.sheetIterator(); - Sheet sheet = sheetIterator.next(); - rows = sheet.rowIterator(); - Row currRow = rows.next(); - Record record = new org.datavec.api.records.impl.Record(rowToRecord(currRow), - new RecordMetaDataIndex( - currRow.getRowNum(), - super.currentFile.toURI(), - ExcelRecordReader.class)); - return record; + try(InputStream is = streamCreatorFn.apply(super.locationsIterator.next())) { + // Creating a Workbook from an Excel file (.xls or .xlsx) + try { + if (currWorkBook != null) { + currWorkBook.close(); + } - } catch (Exception e) { - throw new IllegalStateException("Error processing row",e); + this.currWorkBook = WorkbookFactory.create(is); + this.sheetIterator = currWorkBook.sheetIterator(); + Sheet sheet = sheetIterator.next(); + rows = sheet.rowIterator(); + Row currRow = rows.next(); + Record record = new org.datavec.api.records.impl.Record(rowToRecord(currRow), + new RecordMetaDataIndex( + currRow.getRowNum(), + super.currentUri, + ExcelRecordReader.class)); + return record; + + } catch (Exception e) { + throw new IllegalStateException("Error processing row", e); + } + } catch (IOException e){ + throw new RuntimeException("Error reading from stream", e); } } diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/ExecutionTest.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/ExecutionTest.java index 95da226b2b33..c6f5f28b62db 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/ExecutionTest.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/ExecutionTest.java @@ -23,6 +23,7 @@ import org.datavec.api.transform.reduce.Reducer; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.SequenceSchema; +import org.datavec.api.transform.transform.categorical.FirstDigitTransform; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Text; @@ -231,4 +232,54 @@ public void testUniqueMultiCol(){ assertTrue(c1.contains(new Text("state0")) && c1.contains(new Text("state1")) && c1.contains(new Text("state2"))); } + + @Test + public void testFirstDigitTransformBenfordsLaw(){ + Schema s = new Schema.Builder() + .addColumnString("data") + .addColumnDouble("double") + .addColumnString("stringNumber") + .build(); + + List> in = Arrays.asList( + Arrays.asList(new Text("a"), new DoubleWritable(3.14159), new Text("8e-4")), + Arrays.asList(new Text("a2"), new DoubleWritable(3.14159), new Text("7e-4")), + Arrays.asList(new Text("b"), new DoubleWritable(2.71828), new Text("7e2")), + Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("6e8")), + Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.0")), + Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.1")), + Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.2")), + Arrays.asList(new Text("c"), new DoubleWritable(-2), new Text("non numerical"))); + + //Test Benfords law use case: + TransformProcess tp = new TransformProcess.Builder(s) + .firstDigitTransform("double", "fdDouble", FirstDigitTransform.Mode.EXCEPTION_ON_INVALID) + .firstDigitTransform("stringNumber", "stringNumber", FirstDigitTransform.Mode.INCLUDE_OTHER_CATEGORY) + .removeAllColumnsExceptFor("stringNumber") + .categoricalToOneHot("stringNumber") + .reduce(new Reducer.Builder(ReduceOp.Sum).build()) + .build(); + + JavaRDD> rdd = sc.parallelize(in); + + + List> out = SparkTransformExecutor.execute(rdd, tp).collect(); + assertEquals(1, out.size()); + + List l = out.get(0); + List exp = Arrays.asList( + new IntWritable(0), //0 + new IntWritable(0), //1 + new IntWritable(3), //2 + new IntWritable(0), //3 + new IntWritable(0), //4 + new IntWritable(0), //5 + new IntWritable(1), //6 + new IntWritable(2), //7 + new IntWritable(1), //8 + new IntWritable(0), //9 + new IntWritable(1)); //Other + assertEquals(exp, l); + } + } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java index 0910cfaa1c10..5a0d158b6a9d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java @@ -168,6 +168,7 @@ public void validateBatchNorm() { .features(f) .labels(l) .data(new SingletonDataSetIterator(new DataSet(f, l))) + .maxRelError(1e-4) .build(); LayerHelperValidationUtil.validateMLN(netWith, tc); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasAtrousConvolution1D.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasAtrousConvolution1D.java index 0f80b34d39ba..b7fa269f7a62 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasAtrousConvolution1D.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasAtrousConvolution1D.java @@ -92,7 +92,7 @@ public KerasAtrousConvolution1D(Map layerConfig, boolean enforce .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .activation(getIActivationFromConfig(layerConfig, conf)) .weightInit(weightInit.getWeightInitFunction(distribution)) - .dilation(getDilationRate(layerConfig, 1, conf, true)) + .dilation(getDilationRate(layerConfig, 1, conf, true)[0]) .l1(this.weightL1Regularization).l2(this.weightL2Regularization) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0]) diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java index 0b79f8a4ac96..351dca5fd3f0 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java @@ -114,7 +114,7 @@ public KerasConvolution1D(Map layerConfig, boolean enforceTraini if (padding != null) builder.padding(padding[0]); if (dilationRate != null) - builder.dilation(dilationRate); + builder.dilation(dilationRate[0]); if (biasConstraint != null) builder.constrainBias(biasConstraint); if (weightConstraint != null) diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java index e36277e82791..b66885a75f84 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java @@ -110,12 +110,12 @@ public static class Builder extends Layer.Builder { private double maskValue; public Builder setUnderlying(Layer underlying) { - this.setUnderlying(underlying); + this.underlying = underlying; return this; } public Builder setMaskValue(double maskValue) { - this.setMaskValue(maskValue); + this.maskValue = maskValue; return this; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java index cf7f68a657a3..2df173d49c01 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java @@ -2355,13 +2355,19 @@ public INDArray output(INDArray input) { /** * Generate the output for all examples/batches in the input iterator, and concatenate them into a single array. * See {@link #output(INDArray)}
- * NOTE: The output array can require a considerable amount of memory for iterators with a large number of examples + * NOTE 1: The output array can require a considerable amount of memory for iterators with a large number of examples
+ * NOTE 2: This method cannot be used for variable length time series outputs, as this would require padding arrays + * for some outputs, or returning a mask array (which cannot be done with this method). For variable length time + * series applications, use one of the other output methods. This method also cannot be used with fully convolutional + * networks with different output sizes (for example, segmentation on different input image sizes). + * * * @param iterator Data to pass through the network * @return output for all examples in the iterator, concatenated into a */ public INDArray output(DataSetIterator iterator, boolean train) { List outList = new ArrayList<>(); + long[] firstOutputShape = null; while (iterator.hasNext()) { DataSet next = iterator.next(); INDArray features = next.getFeatures(); @@ -2371,7 +2377,23 @@ public INDArray output(DataSetIterator iterator, boolean train) { INDArray fMask = next.getFeaturesMaskArray(); INDArray lMask = next.getLabelsMaskArray(); - outList.add(this.output(features, train, fMask, lMask)); + INDArray output = this.output(features, train, fMask, lMask); + outList.add(output); + if(firstOutputShape == null){ + firstOutputShape = output.shape(); + } else { + //Validate that shapes are the same (may not be, for some RNN variable length time series applications) + long[] currShape = output.shape(); + Preconditions.checkState(firstOutputShape.length == currShape.length, "Error during forward pass:" + + "different minibatches have different output array ranks - first minibatch shape %s, last minibatch shape %s", firstOutputShape, currShape); + for( int i=1; i variables = new HashMap<>(); //TODO concurrent maps required? Or lock? + private final Map variables = new LinkedHashMap<>(); //Use linked hash map to guarantee iteration order based on order they were added. Used in inputs() and flatbuffers serde @Getter private final Map ops = new HashMap<>(); @Getter @@ -826,6 +826,25 @@ public void associateArrayWithVariable(INDArray arr, SDVariable variable) { sessions.put(Thread.currentThread().getId(), new InferenceSession(this)); } + boolean duped = false; + if(arr.isAttached()) { + arr = arr.detach(); + duped = true; + } + if(arr.isView()) { + arr = arr.dup(); + duped = true; + } + + if(!duped && variable.getVariableType() == VariableType.VARIABLE) { + for (DeviceLocalNDArray otherArr : variablesArrays.values()) { + if (otherArr.get() == arr) { //Check for exact same object, to avoid array reuse (can result in unexpected behaviour) + arr = arr.dup(); + break; + } + } + } + switch(variable.getVariableType()){ case VARIABLE: variablesArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr)); @@ -881,12 +900,12 @@ public void putSubFunction(String name, SameDiff nameSpace) { /** - * Return the internal variable map + * Return a copy of the internal variable map * * @return Map of variables by name */ public Map variableMap() { - Map ret = new HashMap<>(); + Map ret = new LinkedHashMap<>(); for(Variable v : variables.values()){ ret.put(v.getName(), v.getVariable()); } @@ -2134,8 +2153,25 @@ public SDVariable var(String name, INDArray arr) { if (arr == null) throw new IllegalArgumentException("Array for " + name + " must not be null"); - if(arr.isAttached()) + boolean duped = false; + if(arr.isAttached()) { arr = arr.detach(); + duped = true; + } + if(arr.isView()) { + arr = arr.dup(); + duped = true; + } + + if(!duped) { + for (DeviceLocalNDArray otherArr : variablesArrays.values()) { + if (otherArr.get() == arr) { //Check for exact same object, to avoid array reuse (can result in unexpected behaviour) + arr = arr.dup(); + break; + } + } + } + SDVariable ret = new SDVariable(name, VariableType.VARIABLE, this, arr.shape(), arr.dataType(), new NDArraySupplierInitScheme(arr)); associateArrayWithVariable(arr, ret); @@ -4121,14 +4157,42 @@ public String asFlatPrint() { sb.append("\nExternal variables:\n\n"); for (int e = 0; e < graph.variablesLength(); e++) { val var = graph.variables(e); - INDArray ndarray; + INDArray ndarray = null; try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { - ndarray = Nd4j.createFromFlatArray(var.ndarray()); + FlatArray fa = var.ndarray(); + if(fa != null) { + ndarray = Nd4j.createFromFlatArray(fa); + } } sb.append(var.id().first()) - .append(":<").append(var.name()).append("> ") - .append(Arrays.toString(ndarray.shapeInfoDataBuffer().asInt())).append("; Values: ").append(Arrays.toString(ndarray.data().asFloat())).append(";\n"); + .append(":<").append(var.name()).append("> "); + if(ndarray == null){ + sb.append("").append("; Values: ").append("").append(";\n"); + } else { + sb.append(Arrays.toString(ndarray.shapeInfoDataBuffer().asInt())).append("; Values: "); + if(ndarray.data() == null){ + //Empty array + sb.append(""); + } else if(ndarray.dataType() == DataType.UTF8) { + sb.append(""); + } else { + if(ndarray.length() < 50){ + sb.append(Arrays.toString(ndarray.data().asFloat()).replaceAll(" ","")); + } else { + //Array is too long - only tak. last few values... + sb.append("["); + for( int i=0; i<50; i++ ){ + if(i > 0) + sb.append(","); + sb.append(ndarray.data().getFloat(i)); + } + sb.append("]"); + } + } + sb.append(";\n"); + } + } val map = Nd4j.getExecutioner().getCustomOperations(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java index 7920bde531cc..ae8c2811687f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java @@ -2,6 +2,8 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ops.impl.transforms.Pad; +import org.nd4j.linalg.factory.Nd4j; /** * SameDiff general neural network operations
@@ -701,4 +703,82 @@ public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, int. SDVariable result = f().layerNorm(input, gain, dimensions); return updateVariableNameAndReference(result, name); } + + /** + * See {@link #pad(SDVariable, SDVariable, double)} + */ + public SDVariable pad(SDVariable input, int[][] padding, double constant){ + return pad(input, sd.constant(Nd4j.createFromArray(padding)), constant); + } + + /** + * Perform padding on the given array, where padded values are the specified constant.
+ * Example:
+ * Input array:
+ * [1, 2]
+ * [3, 4]
+ * Padding array:
+ * [2, 0]
+ * [1, 1]
+ * Contant = 0
+ * Result:
+ * [0, 0, 0, 0]
+ * [0, 0, 0, 0]
+ * [0, 1, 2, 0]
+ * [0, 3, 4, 0]
+ *
+ * + * + * @param input Input array to pad + * @param padding Padding array + * @param constant Constant to use for padded values + * @return Padded array + */ + public SDVariable pad(SDVariable input, SDVariable padding, double constant){ + return pad(null, input, padding, Pad.Mode.CONSTANT, constant); + } + + /** + * As per {@link #pad(SDVariable, SDVariable, double)} but also supports multiple {@link Pad.Mode} modes.
+ * Example: + * Input array:
+ * [1, 2]
+ * [3, 4]
+ * [5, 6]
+ * Padding array:
+ * [2, 0]
+ * [1, 1]
+ * Contant = 0
+ * Result: CONSTANT mode
+ * [0, 0, 0, 0]
+ * [0, 0, 0, 0]
+ * [0, 1, 2, 0]
+ * [0, 3, 4, 0]
+ * [0, 5, 6, 0]
+ *
+ * Result: SYMMETRIC mode
+ * [3, 3, 4, 4]
+ * [1, 1, 2, 2]
+ * [1, 1, 2, 2]
+ * [3, 3, 4, 4]
+ * [5, 5, 6, 6]
+ *
+ * Result: REFLECT:
+ * [6, 5, 6, 0]
+ * [2, 3, 4, 3]
+ * [2, 1, 2, 1]
+ * [4, 3, 4, 3]
+ * [6, 5, 6, 5]
+ *
+ * @param outputName + * @param input + * @param padding + * @param mode + * @param constant + * @return + */ + public SDVariable pad(String outputName, SDVariable input, SDVariable padding, Pad.Mode mode, double constant){ + SDVariable out = f().pad(input, padding, mode, constant); + return updateVariableNameAndReference(out, outputName); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java index 6e5be36ff828..6903e889f982 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java @@ -385,12 +385,10 @@ public static DifferentialFunction fromFlatNode(FlatNode fn){ val ba = (BaseReduceOp) op; //Reduce3 ops are also all BaseAccumulations ba.setDimensions(dimensions); ba.setDimensionz(Shape.ndArrayDimFromInt(dimensions)); - ba.setNewFormat(true); //Always "new" format (i.e., rank 0 scalars, not rank 2) for SameDiff-based exec } else if(opType == Op.Type.INDEXREDUCE){ BaseIndexAccumulation bia = (BaseIndexAccumulation)op; bia.setDimensions(dimensions); bia.setDimensionz(Shape.ndArrayDimFromInt(dimensions)); - bia.setNewFormat(true); //Always "new" format (i.e., rank 0 scalars, not rank 2) for SameDiff-based exec } /* Op types that don't need any extra/special mapping: diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/ui/LogFileWriter.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/ui/LogFileWriter.java index 6b4547026f5d..58eb3c882439 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/ui/LogFileWriter.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/ui/LogFileWriter.java @@ -4,6 +4,7 @@ import com.google.flatbuffers.Table; import lombok.AllArgsConstructor; import lombok.Data; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SameDiff; @@ -289,9 +290,9 @@ public long writeScalarEvent(String name, long time, int iteration, int epoch, N return append(fbb, fbb2); } - public long writeHistogramEventDiscrete(String name, long time, int iteration, int epoch, List binLabels, INDArray y) throws IOException { + public long writeHistogramEventDiscrete(@NonNull String name, long time, int iteration, int epoch, List binLabels, @NonNull INDArray y) throws IOException { Preconditions.checkState(binLabels == null || binLabels.size() == y.length(), "Number of bin labels (if present) must " + - "be same as Y array length - got %s bins, array shape %ndShape", binLabels.size(), y.length()); + "be same as Y array length - got %s bins, array shape %ndShape", (binLabels == null ? 0L : binLabels.size()), y.length()); Preconditions.checkState(y.rank() == 1, "Y array must be rank 1, got Y array with shape %ndShape", y); //TODO add support for plugin, variable and frame/iter diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 45aeed61bf34..c09f101b5258 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -4689,7 +4689,7 @@ public INDArray sum(int... dimension) { @Override public INDArray sum(boolean keepDim, int... dimension) { validateNumericalArray("sum", false); - return Nd4j.getExecutioner().exec(new Sum(this, null, true, keepDim, dimension)); + return Nd4j.getExecutioner().exec(new Sum(this, null, keepDim, dimension)); } @@ -4901,9 +4901,6 @@ public INDArray getColumn(long c) { return this; else if (isColumnVector() && c > 0) throw new IllegalArgumentException("Illegal index for row"); - else if(isRowVector()) { - return Nd4j.scalar(getDouble(c)); - } return get(NDArrayIndex.all(), NDArrayIndex.point(c)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java index a738545d0a8e..75cb5dded6ee 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java @@ -43,8 +43,6 @@ @Data public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccumulation { protected boolean keepDims = false; - @Deprecated - protected boolean newFormat = false; public BaseIndexAccumulation(SameDiff sameDiff, SDVariable i_v, @@ -64,7 +62,6 @@ public BaseIndexAccumulation(SameDiff sameDiff, throw new IllegalArgumentException("Input not null variable."); } this.keepDims = keepDims; - this.newFormat = true; defineDimensions(dimensions); } @@ -93,7 +90,6 @@ public BaseIndexAccumulation(SameDiff sameDiff, throw new IllegalArgumentException("Input not null variable."); } this.keepDims = keepDims; - this.newFormat = true; defineDimensions(dimensions); } @@ -116,7 +112,7 @@ public List calculateOutputShape() { if(x == null) return Collections.emptyList(); - long[] reducedShape = Shape.getReducedShape(x.shape(), dimensions, keepDims, newFormat); + long[] reducedShape = Shape.getReducedShape(x.shape(), dimensions, keepDims); return Collections.singletonList(LongShapeDescriptor.fromShape(reducedShape, DataType.LONG)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java index ae4a2e0563ee..b0a75d0994f2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java @@ -16,7 +16,6 @@ package org.nd4j.linalg.api.ops; -import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -24,10 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; -import org.nd4j.linalg.exception.ND4JIllegalArgumentException; -import org.nd4j.linalg.exception.ND4JIllegalStateException; -import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -44,16 +40,16 @@ protected BaseReduceBoolOp(SameDiff sameDiff, SDVariable input, int... dimension super(sameDiff, input, dimensions); } - public BaseReduceBoolOp(INDArray x, INDArray z, boolean newFormat, boolean keepDims, int[] dimensions) { - super(x, null, z, newFormat, keepDims, dimensions); + public BaseReduceBoolOp(INDArray x, INDArray z, boolean keepDims, int[] dimensions) { + super(x, null, z, keepDims, dimensions); } public BaseReduceBoolOp(INDArray x, int... dimensions) { - this(x, null, true, false, dimensions); + this(x, null, false, dimensions); } public BaseReduceBoolOp(INDArray x, INDArray z, int... dimensions) { - this(x, z, true, false, dimensions); + this(x, z, false, dimensions); } protected BaseReduceBoolOp() { @@ -93,7 +89,7 @@ public List calculateOutputShape() { return Collections.emptyList(); //Calculate reduction shape. Note that reduction on scalar - returns a scalar - long[] reducedShape = x.length() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims(), newFormat); + long[] reducedShape = x.length() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims()); return Collections.singletonList(LongShapeDescriptor.fromShape(reducedShape, DataType.BOOL)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceFloatOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceFloatOp.java index b706720907ec..5788fe178f87 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceFloatOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceFloatOp.java @@ -34,8 +34,8 @@ public abstract class BaseReduceFloatOp extends BaseReduceOp implements ReduceFloatOp { - public BaseReduceFloatOp(INDArray x, INDArray y, INDArray z, boolean newFormat, boolean keepDims, int... dimensions){ - super(x, y, z, newFormat, keepDims, dimensions); + public BaseReduceFloatOp(INDArray x, INDArray y, INDArray z, boolean keepDims, int... dimensions){ + super(x, y, z, keepDims, dimensions); } protected BaseReduceFloatOp(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) { @@ -54,10 +54,6 @@ protected BaseReduceFloatOp(SameDiff sameDiff, SDVariable input, int... dimensio super(sameDiff, input, dimensions); } - public BaseReduceFloatOp(INDArray x, INDArray z, boolean newFormat, boolean keepDims, int[] dimensions) { - super(x, null, z, newFormat, keepDims, dimensions); - } - public BaseReduceFloatOp(INDArray input, INDArray output, boolean keepDims, int... dimensions){ super(input, null, output, dimensions); this.keepDims = keepDims; @@ -118,7 +114,7 @@ public List calculateOutputShape() { return Collections.emptyList(); //Calculate reduction shape. Note that reduction on scalar - returns a scalar - long[] reducedShape = x.length() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims(), newFormat); + long[] reducedShape = x.length() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims()); DataType retType = arg().dataType(); if(!retType.isFPType()) retType = Nd4j.defaultFloatingPointType(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceLongOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceLongOp.java index 54b1c22844ef..c37bf99863f8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceLongOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceLongOp.java @@ -49,10 +49,6 @@ public BaseReduceLongOp(INDArray x, int... dimensions) { super(x, dimensions); } - public BaseReduceLongOp(INDArray x, INDArray z, boolean newFormat, boolean keepDims, int[] dimensions) { - super(x, null, z, newFormat, keepDims, dimensions); - } - public BaseReduceLongOp(INDArray x, INDArray z, int... dimensions) { super(x, z, dimensions); } @@ -94,7 +90,7 @@ public List calculateOutputShape() { return Collections.emptyList(); //Calculate reduction shape. Note that reduction on scalar - returns a scalar - long[] reducedShape = x.length() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims(), newFormat); + long[] reducedShape = x.length() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims()); return Collections.singletonList(LongShapeDescriptor.fromShape(reducedShape, DataType.LONG)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java index 89e87a6f3b48..2f3736395991 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java @@ -50,13 +50,8 @@ */ @Slf4j public abstract class BaseReduceOp extends BaseOp implements ReduceOp { - protected Number finalResult; @Setter @Getter protected boolean keepDims = false; - - // flag for tf imported ops, shows that there's probably one more value appended in axis - @Setter @Getter @Deprecated - protected boolean newFormat = false; protected boolean isComplex = false; @@ -81,7 +76,6 @@ public BaseReduceOp(SameDiff sameDiff, throw new IllegalArgumentException("Input not null variable."); } - this.newFormat = true; defineDimensions(dimensions); } @@ -107,7 +101,6 @@ public BaseReduceOp(SameDiff sameDiff, throw new IllegalArgumentException("Input not null variable."); } - this.newFormat = true; defineDimensions(dimensions); } @@ -137,9 +130,8 @@ public BaseReduceOp(SameDiff sameDiff, public BaseReduceOp() {} - public BaseReduceOp(INDArray x, INDArray y, INDArray z, boolean newFormat, boolean keepDims, int[] dimensions) { + public BaseReduceOp(INDArray x, INDArray y, INDArray z, boolean keepDims, int[] dimensions) { super(x, y, z); - this.newFormat = newFormat; this.keepDims = keepDims; this.dimensions = dimensions; defineDimensions(dimensions); @@ -154,7 +146,7 @@ public BaseReduceOp(INDArray x, INDArray y, int... dimensions) { } public BaseReduceOp(INDArray x, INDArray y, INDArray z, int... dimensions) { - this(x, y, z, true, false, dimensions); + this(x, y, z, false, dimensions); } public BaseReduceOp(SameDiff sameDiff) { @@ -180,8 +172,6 @@ public boolean isKeepDims() { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - newFormat = true; - if (!attributesForNode.containsKey("axis") && !hasReductionIndices(nodeDef)) { this.dimensions = new int[] { Integer.MAX_VALUE }; } //Otherwise: dimensions are dynamically set during execution in InferenceSession diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceSameOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceSameOp.java index 0df8cd4a9f2a..5fd71bf94baa 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceSameOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceSameOp.java @@ -46,8 +46,8 @@ protected BaseReduceSameOp(SameDiff sameDiff, SDVariable input, int... dimension super(sameDiff, input, dimensions); } - public BaseReduceSameOp(INDArray x, INDArray z, boolean newFormat, boolean keepDims, int[] dimensions) { - super(x, null, z, newFormat, keepDims, dimensions); + public BaseReduceSameOp(INDArray x, INDArray z, boolean keepDims, int[] dimensions) { + super(x, null, z, keepDims, dimensions); } public BaseReduceSameOp(INDArray x, INDArray y, INDArray z, int... dimensions) { @@ -96,7 +96,7 @@ public List calculateOutputShape() { return Collections.emptyList(); //Calculate reduction shape. Note that reduction on scalar - returns a scalar - long[] reducedShape = x.length() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims(), newFormat); + long[] reducedShape = x.length() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims()); return Collections.singletonList(LongShapeDescriptor.fromShape(reducedShape, this.resultType())); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/IndexAccumulation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/IndexAccumulation.java index a096603546c5..0671e9a24160 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/IndexAccumulation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/IndexAccumulation.java @@ -52,12 +52,6 @@ public interface IndexAccumulation extends Op { */ boolean isKeepDims(); - /** - * This method returns true if scalar is 0D, false otherwise - * @return - */ - boolean isNewFormat(); - /** * This method returns dimensions for this op * @return diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/ReduceOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/ReduceOp.java index 282872826350..8f1814dfd116 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/ReduceOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/ReduceOp.java @@ -80,12 +80,6 @@ public interface ReduceOp extends Op { */ boolean isKeepDims(); - /** - * This method returns true if scalar is 0D, false otherwise - * @return - */ - boolean isNewFormat(); - /** * This method returns datatype for result array wrt given inputs * @return diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/BaseReduction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/BaseReduction.java deleted file mode 100644 index d5684bc72038..000000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/BaseReduction.java +++ /dev/null @@ -1,104 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.reduce; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.DynamicCustomOp; - - -/** - * @author Alex Black - */ - -public abstract class BaseReduction extends DynamicCustomOp { - - protected boolean keepDims; - protected int[] dimensions; - - /** - * - * @param input Input to be reduced - * @param keepDims If true: reduction dimensions were kept - * @param dimensions Dimensions to reduce. May be null - */ - public BaseReduction(SameDiff sameDiff, SDVariable input, boolean keepDims, int... dimensions) { - super(null, sameDiff, new SDVariable[]{input}, false); - this.keepDims = keepDims; - this.dimensions = dimensions; - addArgs(); - } - - /** - * - * @param input1 input 1 - * @param input2 input 2 - * @param keepDims If true: reduction dimensions were kept - * @param dimensions Dimensions to reduce. May be null - */ - public BaseReduction(SameDiff sameDiff, SDVariable input1, SDVariable input2, boolean keepDims, int... dimensions) { - super(null, sameDiff, new SDVariable[]{input1, input2}, false); - this.keepDims = keepDims; - this.dimensions = dimensions; - addArgs(); - } - - /** - * - * @param input input - * @param output Output array - i.e., gradient at the input to the reduction function - * @param keepDims If true: reduction dimensions were kept - * @param dimensions Dimensions to reduce. May be null - */ - public BaseReduction(INDArray input, INDArray output, boolean keepDims, int... dimensions){ - super(null, new INDArray[]{input}, (output == null ? null : new INDArray[]{output})); - this.keepDims = keepDims; - this.dimensions = dimensions; - addArgs(); - } - - /** - * - * @param input1 Pre-reduced input1 - * @param input2 Pre-reduced input2 - * @param output Output array - i.e., gradient at the input to the reduction function - * @param keepDims If true: reduction dimensions were kept - * @param dimensions Dimensions to reduce. May be null - */ - public BaseReduction(INDArray input1, INDArray input2, INDArray output, boolean keepDims, int... dimensions){ - super(null, new INDArray[]{input1, input2}, (output == null ? null : new INDArray[]{output})); - this.keepDims = keepDims; - this.dimensions = dimensions; - addArgs(); - } - - public BaseReduction(){} - - protected void addArgs(){ - addTArgument(keepDims ? 1 : 0); - if(dimensions != null && dimensions.length > 0){ - if(dimensions.length != 1 || dimensions[0] != Integer.MAX_VALUE ){ - //Integer.MAX_VALUE means "full array" but here no dimension args == full array - addIArgument(dimensions); - } - } - } - - public abstract String opName(); - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java index 9e69aa0a7786..6ce3c1eef078 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java @@ -42,7 +42,7 @@ public IsInf(SameDiff sameDiff, SDVariable i_v, int[] dims, boolean keepDims) { public IsInf() {} public IsInf(INDArray x, INDArray z) { - super(x, z, true, false, null); + super(x, z, false, null); } public IsInf(INDArray x) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java index 46a0078b248e..2d25391a9b4e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java @@ -42,7 +42,7 @@ public IsNaN(SameDiff sameDiff, SDVariable i_v, int[] dims, boolean keepDims) { public IsNaN() {} public IsNaN(INDArray x, INDArray z) { - super(x, z, true, false, null); + super(x, z, false, null); } public IsNaN(INDArray x) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java index c450048a9370..853ccae24086 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java @@ -45,8 +45,8 @@ public Mean(INDArray x, int... dimensions) { super(x, dimensions); } - public Mean(INDArray x, INDArray z, boolean newFormat, boolean keepDims, int... dimensions) { - super(x, z, newFormat, keepDims, dimensions); + public Mean(INDArray x, INDArray z, boolean keepDims, int... dimensions) { + super(x, z, keepDims, dimensions); } @Override @@ -61,8 +61,6 @@ public String opName() { @Override public List doDiff(List i_v1) { - if(!newFormat) - throw new IllegalStateException("Cannot doDiff with newFormat == false"); //If out = mean(in), then dL/dIn = 1/N * dL/dOut (broadcast to appropriate shape) //Note that N differs for "along dimension" vs. "whole array" reduce cases return Collections.singletonList(f().meanBp(arg(), i_v1.get(0), keepDims, dimensions)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Max.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Max.java index 0627ce0d2065..509a1fa4e8cf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Max.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Max.java @@ -59,8 +59,8 @@ public Max(INDArray x, INDArray z, int... axis) { super(x, null, z, axis); } - public Max(INDArray x, INDArray z, boolean newFormat, boolean keepDims, int... dimensions) { - super(x, z, newFormat, keepDims, dimensions); + public Max(INDArray x, INDArray z, boolean keepDims, int... dimensions) { + super(x, z, keepDims, dimensions); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Min.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Min.java index 0cfe13eef3ec..8169b04383ee 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Min.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Min.java @@ -45,8 +45,8 @@ public Min(INDArray x, INDArray z, int... dimensions) { super(x, null, z, dimensions); } - public Min(INDArray x, INDArray z, boolean newFormat, boolean keepDims, int... dimensions) { - super(x, z, newFormat, keepDims, dimensions); + public Min(INDArray x, INDArray z, boolean keepDims, int... dimensions) { + super(x, z, keepDims, dimensions); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Prod.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Prod.java index 6932b7aaf075..b735b15cf1ca 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Prod.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Prod.java @@ -49,8 +49,8 @@ public Prod(INDArray x, INDArray z, int... dimensions) { super(x, null, z, dimensions); } - public Prod(INDArray x, INDArray z, boolean newFormat, boolean keepDims, int... dimensions) { - super(x, z, newFormat, keepDims, dimensions); + public Prod(INDArray x, INDArray z, boolean keepDims, int... dimensions) { + super(x, z, keepDims, dimensions); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java index 9103e0de5ac9..ae70e44d9361 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java @@ -52,8 +52,8 @@ public Sum(INDArray x, INDArray z, int... dimensions) { super(x, null, z, dimensions); } - public Sum(INDArray x, INDArray z, boolean newFormat, boolean keepDims, int... dimensions) { - super(x, z, newFormat, keepDims, dimensions); + public Sum(INDArray x, INDArray z, boolean keepDims, int... dimensions) { + super(x, z, keepDims, dimensions); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/BaseReduce3Op.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/BaseReduce3Op.java index 1445bd7dd6e8..5be2e644efef 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/BaseReduce3Op.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/BaseReduce3Op.java @@ -57,21 +57,21 @@ public BaseReduce3Op(INDArray x, INDArray y, boolean allDistances, int... dimens } public BaseReduce3Op(INDArray x, INDArray y, INDArray z) { - this(x, y, z, false, null); + this(x, y, z, false, false, (int[])null); } - public BaseReduce3Op(INDArray x, INDArray y, INDArray z, boolean allDistances, int... dimensions) { - this(x, y, z, true, false, dimensions); - this.isComplex = allDistances; + public BaseReduce3Op(INDArray x, INDArray y, INDArray z, boolean keepDims, int... dimensions){ + this(x,y,z,keepDims, false); } - public BaseReduce3Op(INDArray x, INDArray y, INDArray z, boolean newFormat, boolean keepDims, int... dimensions){ - super(x, y, z, newFormat, keepDims, dimensions); + public BaseReduce3Op(INDArray x, INDArray y, INDArray z, boolean keepDims, boolean allDistances, int... dimensions){ + super(x, y, z, keepDims, dimensions); + this.isComplex = allDistances; extraArgs = new Object[]{0.0f, 0.0f}; } public BaseReduce3Op(INDArray x, INDArray y, INDArray z, int... dimensions) { - super(x, y, z, true, false, dimensions); + super(x, y, z, false, dimensions); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineDistance.java index 34f343e24980..29c7eb878e26 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineDistance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineDistance.java @@ -67,8 +67,8 @@ public CosineDistance(INDArray x, INDArray y, boolean allDistances, int... dimen this(x, y, null, allDistances, dimension); } - public CosineDistance(INDArray x, INDArray y, INDArray z, boolean newFormat, boolean keepDims, int... dimensions){ - super(x, y, z, newFormat, keepDims, dimensions); + public CosineDistance(INDArray x, INDArray y, INDArray z, boolean keepDims, boolean allDistances, int... dimensions){ + super(x, y, z, keepDims, allDistances, dimensions); extraArgs = new Object[]{0.0f, 0.0f}; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineSimilarity.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineSimilarity.java index d2a313b9354f..acba7a6049be 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineSimilarity.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineSimilarity.java @@ -75,8 +75,8 @@ public CosineSimilarity(INDArray x, INDArray y, boolean allDistances, int... dim this(x, y, null, allDistances, dimension); } - public CosineSimilarity(INDArray x, INDArray y, INDArray z, boolean newFormat, boolean keepDims, int... dimensions){ - super(x, y, z, newFormat, keepDims, dimensions); + public CosineSimilarity(INDArray x, INDArray y, INDArray z, boolean keepDims, boolean allDistances, int... dimensions){ + super(x, y, z, keepDims, allDistances, dimensions); extraArgs = new Object[]{0.0f, 0.0f}; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/Dot.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/Dot.java index cc0b31c5d032..7c00c7933448 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/Dot.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/Dot.java @@ -52,7 +52,7 @@ public Dot(INDArray x, INDArray y, INDArray z) { } public Dot(INDArray x, INDArray y, INDArray z, boolean newFormat, boolean keepDims, int... dimensions){ - super(x, y, z, newFormat, keepDims, dimensions); + super(x, y, z, keepDims, false, dimensions); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EqualsWithEps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EqualsWithEps.java index c902ec8af874..3bf3105f8972 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EqualsWithEps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EqualsWithEps.java @@ -48,7 +48,7 @@ public EqualsWithEps(SameDiff sameDiff, SDVariable i_v, SDVariable i_v2, int[] d public EqualsWithEps() {} public EqualsWithEps(INDArray x, INDArray y, INDArray z, double eps, int... dimensions) { - super(x, y, z,true, false, dimensions); + super(x, y, z, false, dimensions); this.extraArgs = new Object[] {eps}; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EuclideanDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EuclideanDistance.java index 8a73679fe787..c0f7caa67f3a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EuclideanDistance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EuclideanDistance.java @@ -66,12 +66,11 @@ public EuclideanDistance(INDArray x, INDArray y, boolean allDistances, int... di } public EuclideanDistance(INDArray x, INDArray y, INDArray z, boolean allDistances, int... dimensions) { - this(x, y, z, true, false, dimensions); - this.isComplex = allDistances; + this(x, y, z, false, allDistances, dimensions); } - public EuclideanDistance(INDArray x, INDArray y, INDArray z, boolean newFormat, boolean keepDims, int... dimensions){ - super(x, y, z, newFormat, keepDims, dimensions); + public EuclideanDistance(INDArray x, INDArray y, INDArray z, boolean keepDims, boolean allDistances, int... dimensions){ + super(x, y, z, keepDims, allDistances, dimensions); extraArgs = new Object[]{0.0f, 0.0f}; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/HammingDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/HammingDistance.java index 580d23725392..aa59416655a4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/HammingDistance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/HammingDistance.java @@ -48,7 +48,7 @@ public HammingDistance(INDArray x, INDArray y, int... dimensions) { } public HammingDistance(INDArray x, INDArray y, INDArray z, boolean allDistances, int... dimensions) { - this(x, y, z, true, false, dimensions); + this(x, y, z, false, allDistances, dimensions); this.isComplex = allDistances; } @@ -60,8 +60,8 @@ public HammingDistance(INDArray x, INDArray y, INDArray z) { this(x, y, z, false, null); } - public HammingDistance(INDArray x, INDArray y, INDArray z, boolean newFormat, boolean keepDims, int... dimensions){ - super(x, y, z, newFormat, keepDims, dimensions); + public HammingDistance(INDArray x, INDArray y, INDArray z, boolean keepDims, boolean allDistances, int... dimensions){ + super(x, y, z, keepDims, allDistances, dimensions); extraArgs = new Object[]{0.0f, 0.0f}; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/JaccardDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/JaccardDistance.java index 5e08980d04fd..59fe5aca7ede 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/JaccardDistance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/JaccardDistance.java @@ -49,7 +49,7 @@ public JaccardDistance(INDArray x, INDArray y, int... dimensions) { } public JaccardDistance(INDArray x, INDArray y, INDArray z, boolean allDistances, int... dimensions) { - this(x, y, z, true, false, dimensions); + this(x, y, z, false, allDistances, dimensions); this.isComplex = allDistances; } @@ -62,8 +62,8 @@ public JaccardDistance(INDArray x, INDArray y, boolean allDistances) { this.isComplex = allDistances; } - public JaccardDistance(INDArray x, INDArray y, INDArray z, boolean newFormat, boolean keepDims, int... dimensions){ - super(x, y, z, newFormat, keepDims, dimensions); + public JaccardDistance(INDArray x, INDArray y, INDArray z, boolean keepDims, boolean allDistances, int... dimensions){ + super(x, y, z, keepDims, allDistances, dimensions); extraArgs = new Object[]{0.0f, 0.0f}; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/ManhattanDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/ManhattanDistance.java index 9efaf4312f54..3071b30ece72 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/ManhattanDistance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/ManhattanDistance.java @@ -52,7 +52,7 @@ public ManhattanDistance(INDArray x, INDArray y, int... dimensions) { } public ManhattanDistance(INDArray x, INDArray y, boolean allDistances, int... dimensions) { - this(x, y, null, true, false, dimensions); + this(x, y, null, false, allDistances, dimensions); this.isComplex = allDistances; } @@ -61,12 +61,11 @@ public ManhattanDistance(INDArray x, INDArray y, INDArray z) { } public ManhattanDistance(INDArray x, INDArray y, INDArray z, boolean allDistances, int... dimensions) { - this(x, y, z, true, false, dimensions); - this.isComplex = allDistances; + this(x, y, z, false, allDistances, dimensions); } - public ManhattanDistance(INDArray x, INDArray y, INDArray z, boolean newFormat, boolean keepDims, int... dimensions){ - super(x, y, z, newFormat, keepDims, dimensions); + public ManhattanDistance(INDArray x, INDArray y, INDArray z, boolean keepDims, boolean allDistances, int... dimensions){ + super(x, y, z, keepDims, allDistances, dimensions); extraArgs = new Object[]{0.0f, 0.0f}; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/StandardDeviation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/StandardDeviation.java index 31d72c149186..86985fcc0bf8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/StandardDeviation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/StandardDeviation.java @@ -112,7 +112,7 @@ public List calculateOutputShape() { long[] inputShape = (argShape == null ? x().shape() : argShape); val ret = new ArrayList(1); - val reducedShape = Shape.getReducedShape(inputShape,dimensions, isKeepDims(), newFormat); + val reducedShape = Shape.getReducedShape(inputShape,dimensions, isKeepDims()); ret.add(LongShapeDescriptor.fromShape(reducedShape, resultType())); return ret; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java index ee58a85bcf8d..9f74d083b468 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java @@ -75,9 +75,9 @@ public Variance(INDArray x, boolean biasCorrected, int... dimensions) { defineDimensions(dimensions); } - public Variance(INDArray x, INDArray z, boolean newFormat, boolean keepDims, int... dimensions) { - super(x, null, z, newFormat, keepDims, dimensions); - this.biasCorrected = true; + public Variance(INDArray x, INDArray z, boolean biasCorrected, boolean keepDims, int... dimensions) { + super(x, null, z, keepDims, dimensions); + this.biasCorrected = biasCorrected; defineDimensions(dimensions); } @@ -165,7 +165,7 @@ public List calculateOutputShape() { long[] inputShape = (argShape == null ? x().shape() : argShape); val ret = new ArrayList(1); - val reducedShape = Shape.getReducedShape(inputShape,dimensions, isKeepDims(), newFormat); + val reducedShape = Shape.getReducedShape(inputShape,dimensions, isKeepDims()); ret.add(LongShapeDescriptor.fromShape(reducedShape, resultType())); return ret; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java index c9d953ba8dbc..55c011df825b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java @@ -45,10 +45,12 @@ public enum Mode {CONSTANT, REFLECT, SYMMETRIC} public Pad(){ } - public Pad(SameDiff sd, SDVariable in, SDVariable padding, Mode mode){ + public Pad(SameDiff sd, SDVariable in, SDVariable padding, Mode mode, double padValue) { super(sd, new SDVariable[]{in, padding}, false); + Preconditions.checkState(padding.dataType().isIntType(), "Padding array must be an integer datatype, got %s", padding.dataType()); this.mode = mode; addIArgument(mode.ordinal()); + addTArgument(padValue); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java index cf94aa2305e0..d94436fd2044 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java @@ -445,6 +445,10 @@ else if (dimensions.length == 1 && wholeShape.length == 2) { return result; } + public static long[] getReducedShape(long[] wholeShape, int[] dimensions, boolean keepDims) { + return getReducedShape(wholeShape, dimensions, keepDims, true); + } + public static long[] getReducedShape(long[] wholeShape, int[] dimensions, boolean keepDims, boolean newFormat) { // we need to normalize dimensions, in case they have negative values or unsorted, or whatever dimensions = Shape.normalizeAxis(wholeShape.length, dimensions); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dimensionalityreduction/PCA.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dimensionalityreduction/PCA.java index a8c44b80616c..1babd29de063 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dimensionalityreduction/PCA.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dimensionalityreduction/PCA.java @@ -331,11 +331,8 @@ public static INDArray[] covarianceMatrix(INDArray in) { long dlength = in.rows(); long vlength = in.columns(); - INDArray sum = Nd4j.create(vlength); INDArray product = Nd4j.create(vlength, vlength); - - for (int i = 0; i < vlength; i++) - sum.getColumn(i).assign(in.getColumn(i).sumNumber().doubleValue() / dlength); + INDArray sum = in.sum(0).divi(dlength); for (int i = 0; i < dlength; i++) { INDArray dx1 = in.getRow(i).sub(sum); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index a00255486ab6..2f73bc5e4dab 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -200,6 +200,8 @@ public class Nd4j { private final static Logger logger = Logger.getLogger(Nd4j.class.getName()); + protected static final INDArray[] EMPTY_ARRAYS = new INDArray[DataType.values().length]; + static { fallbackMode = new AtomicBoolean(false); Nd4j nd4j = new Nd4j(); @@ -3793,9 +3795,14 @@ public static INDArray empty() { * @return Empty INDArray */ public static INDArray empty(DataType type) { - val ret = INSTANCE.empty(type); - logCreationIfNecessary(ret); - return ret; + if(EMPTY_ARRAYS[type.ordinal()] == null){ + try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()){ + val ret = INSTANCE.empty(type); + EMPTY_ARRAYS[type.ordinal()] = ret; + logCreationIfNecessary(ret); + } + } + return EMPTY_ARRAYS[type.ordinal()]; } /** @@ -6766,7 +6773,8 @@ public static INDArray createFromArray(double... array) { Preconditions.checkNotNull(array, "Cannot create INDArray from null Java array"); if(array.length == 0) return Nd4j.empty(DataType.DOUBLE); - return create(array, new long[]{array.length}, DataType.DOUBLE); + long[] shape = new long[]{array.length}; + return create(array, shape, ArrayUtil.calcStrides(shape), 'c', DataType.DOUBLE); } /** @@ -6778,7 +6786,8 @@ public static INDArray createFromArray(float... array) { Preconditions.checkNotNull(array, "Cannot create INDArray from null Java array"); if(array.length == 0) return Nd4j.empty(DataType.FLOAT); - return create(array, new long[]{array.length}, DataType.FLOAT); + long[] shape = new long[]{array.length}; + return create(array, shape, ArrayUtil.calcStrides(shape), 'c', DataType.FLOAT); } /** @@ -6790,7 +6799,8 @@ public static INDArray createFromArray(int... array) { Preconditions.checkNotNull(array, "Cannot create INDArray from null Java array"); if(array.length == 0) return Nd4j.empty(DataType.INT); - return create(array, new long[]{array.length}, DataType.INT); + long[] shape = new long[]{array.length}; + return create(array, shape, ArrayUtil.calcStrides(shape), 'c', DataType.INT); } /** @@ -6802,7 +6812,8 @@ public static INDArray createFromArray(short... array) { Preconditions.checkNotNull(array, "Cannot create INDArray from null Java array"); if(array.length == 0) return Nd4j.empty(DataType.SHORT); - return create(array, new long[]{array.length}, DataType.SHORT); + long[] shape = new long[]{array.length}; + return create(array, shape, ArrayUtil.calcStrides(shape), 'c', DataType.SHORT); } /** @@ -6814,7 +6825,8 @@ public static INDArray createFromArray(byte... array) { Preconditions.checkNotNull(array, "Cannot create INDArray from null Java array"); if(array.length == 0) return Nd4j.empty(DataType.BYTE); - return create(array, new long[]{array.length}, DataType.BYTE); + long[] shape = new long[]{array.length}; + return create(array, shape, ArrayUtil.calcStrides(shape), 'c', DataType.BYTE); } /** @@ -6826,7 +6838,8 @@ public static INDArray createFromArray(long... array) { Preconditions.checkNotNull(array, "Cannot create INDArray from null Java array"); if(array.length == 0) return Nd4j.empty(DataType.LONG); - return create(array, new long[]{array.length}, DataType.LONG); + long[] shape = new long[]{array.length}; + return create(array, shape, ArrayUtil.calcStrides(shape), 'c', DataType.LONG); } /** @@ -6838,7 +6851,8 @@ public static INDArray createFromArray(boolean... array) { Preconditions.checkNotNull(array, "Cannot create INDArray from null Java array"); if(array.length == 0) return Nd4j.empty(DataType.BOOL); - return create(array, new long[]{array.length}, DataType.BOOL); + long[] shape = new long[]{array.length}; + return create(array, shape, ArrayUtil.calcStrides(shape), 'c', DataType.BOOL); } /////////////////// @@ -6853,7 +6867,8 @@ public static INDArray createFromArray(double[][] array) { ArrayUtil.assertNotRagged(array); if(array.length == 0 || array[0].length == 0) return Nd4j.empty(DataType.DOUBLE); - return create(ArrayUtil.flatten(array), new long[]{array.length, array[0].length}, DataType.DOUBLE); + long[] shape = new long[]{array.length, array[0].length}; + return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.DOUBLE); } /** @@ -6866,7 +6881,8 @@ public static INDArray createFromArray(float[][] array) { ArrayUtil.assertNotRagged(array); if(array.length == 0 || array[0].length == 0) return Nd4j.empty(DataType.FLOAT); - return create(ArrayUtil.flatten(array), new long[]{array.length, array[0].length}, DataType.FLOAT); + long[] shape = new long[]{array.length, array[0].length}; + return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.FLOAT); } /** @@ -6879,7 +6895,8 @@ public static INDArray createFromArray(long[][] array) { ArrayUtil.assertNotRagged(array); if(array.length == 0 || array[0].length == 0) return Nd4j.empty(DataType.LONG); - return create(ArrayUtil.flatten(array), new long[]{array.length, array[0].length}, DataType.LONG); + long[] shape = new long[]{array.length, array[0].length}; + return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.LONG); } /** @@ -6892,7 +6909,8 @@ public static INDArray createFromArray(int[][] array) { ArrayUtil.assertNotRagged(array); if(array.length == 0 || array[0].length == 0) return Nd4j.empty(DataType.INT); - return create(ArrayUtil.flatten(array), new long[]{array.length, array[0].length}, DataType.INT); + long[] shape = new long[]{array.length, array[0].length}; + return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.INT); } /** @@ -6905,7 +6923,8 @@ public static INDArray createFromArray(short[][] array) { ArrayUtil.assertNotRagged(array); if(array.length == 0 || array[0].length == 0) return Nd4j.empty(DataType.SHORT); - return create(ArrayUtil.flatten(array), new long[]{array.length, array[0].length}, DataType.SHORT); + long[] shape = new long[]{array.length, array[0].length}; + return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.SHORT); } /** @@ -6918,7 +6937,8 @@ public static INDArray createFromArray(byte[][] array) { ArrayUtil.assertNotRagged(array); if(array.length == 0 || array[0].length == 0) return Nd4j.empty(DataType.BYTE); - return create(ArrayUtil.flatten(array), new long[]{array.length, array[0].length}, DataType.BYTE); + long[] shape = new long[]{array.length, array[0].length}; + return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.BYTE); } /** @@ -6931,7 +6951,9 @@ public static INDArray createFromArray(boolean[][] array) { ArrayUtil.assertNotRagged(array); if(array.length == 0 || array[0].length == 0) return Nd4j.empty(DataType.BOOL); - return create(ArrayUtil.flatten(array), new long[]{array.length, array[0].length}, DataType.BOOL); + + long[] shape = new long[]{array.length, array[0].length}; + return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.BOOL); } /////////////////// @@ -6946,7 +6968,8 @@ public static INDArray createFromArray(double[][][] array) { ArrayUtil.assertNotRagged(array); if(array.length == 0 || array[0].length == 0 || array[0][0].length == 0) return Nd4j.empty(DataType.DOUBLE); - return create(ArrayUtil.flatten(array), new long[]{array.length, array[0].length, array[0][0].length}, DataType.DOUBLE); + long[] shape = new long[]{array.length, array[0].length, array[0][0].length}; + return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.DOUBLE); } /** @@ -6959,7 +6982,8 @@ public static INDArray createFromArray(float[][][] array) { ArrayUtil.assertNotRagged(array); if(array.length == 0 || array[0].length == 0 || array[0][0].length == 0) return Nd4j.empty(DataType.FLOAT); - return create(ArrayUtil.flatten(array), new long[]{array.length, array[0].length, array[0][0].length}, DataType.FLOAT); + long[] shape = new long[]{array.length, array[0].length, array[0][0].length}; + return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.FLOAT); } /** @@ -6972,7 +6996,9 @@ public static INDArray createFromArray(long[][][] array) { ArrayUtil.assertNotRagged(array); if(array.length == 0 || array[0].length == 0 || array[0][0].length == 0) return Nd4j.empty(DataType.LONG); - return create(ArrayUtil.flatten(array), new long[]{array.length, array[0].length, array[0][0].length}, DataType.LONG); + + long[] shape = new long[]{array.length, array[0].length, array[0][0].length}; + return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.LONG); } /** @@ -6985,7 +7011,9 @@ public static INDArray createFromArray(int[][][] array) { ArrayUtil.assertNotRagged(array); if(array.length == 0 || array[0].length == 0 || array[0][0].length == 0) return Nd4j.empty(DataType.INT); - return create(ArrayUtil.flatten(array), new long[]{array.length, array[0].length, array[0][0].length}, DataType.INT); + + long[] shape = new long[]{array.length, array[0].length, array[0][0].length}; + return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.INT); } /** @@ -6998,7 +7026,8 @@ public static INDArray createFromArray(short[][][] array) { ArrayUtil.assertNotRagged(array); if(array.length == 0 || array[0].length == 0 || array[0][0].length == 0) return Nd4j.empty(DataType.SHORT); - return create(ArrayUtil.flatten(array), new long[]{array.length, array[0].length, array[0][0].length}, DataType.SHORT); + long[] shape = new long[]{array.length, array[0].length, array[0][0].length}; + return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.SHORT); } /** @@ -7011,7 +7040,8 @@ public static INDArray createFromArray(byte[][][] array) { ArrayUtil.assertNotRagged(array); if(array.length == 0 || array[0].length == 0 || array[0][0].length == 0) return Nd4j.empty(DataType.BYTE); - return create(ArrayUtil.flatten(array), new long[]{array.length, array[0].length, array[0][0].length}, DataType.BYTE); + long[] shape = new long[]{array.length, array[0].length, array[0][0].length}; + return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.BYTE); } /** @@ -7024,7 +7054,8 @@ public static INDArray createFromArray(boolean[][][] array) { ArrayUtil.assertNotRagged(array); if(array.length == 0 || array[0].length == 0 || array[0][0].length == 0) return Nd4j.empty(DataType.BOOL); - return create(ArrayUtil.flatten(array), new long[]{array.length, array[0].length, array[0][0].length}, DataType.BOOL); + long[] shape = new long[]{array.length, array[0].length, array[0][0].length}; + return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.BOOL); } /////////////////// @@ -7039,7 +7070,8 @@ public static INDArray createFromArray(double[][][][] array) { ArrayUtil.assertNotRagged(array); if(array.length == 0 || array[0].length == 0 || array[0][0].length == 0 || array[0][0][0].length == 0) return Nd4j.empty(DataType.DOUBLE); - return create(ArrayUtil.flatten(array), new long[]{array.length, array[0].length, array[0][0].length, array[0][0][0].length}, DataType.DOUBLE); + long[] shape = new long[]{array.length, array[0].length, array[0][0].length, array[0][0][0].length}; + return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.DOUBLE); } /** @@ -7052,7 +7084,8 @@ public static INDArray createFromArray(float[][][][] array) { ArrayUtil.assertNotRagged(array); if(array.length == 0 || array[0].length == 0 || array[0][0].length == 0 || array[0][0][0].length == 0) return Nd4j.empty(DataType.FLOAT); - return create(ArrayUtil.flatten(array), new long[]{array.length, array[0].length, array[0][0].length, array[0][0][0].length}, DataType.FLOAT); + long[] shape = new long[]{array.length, array[0].length, array[0][0].length, array[0][0][0].length}; + return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.FLOAT); } /** @@ -7065,7 +7098,8 @@ public static INDArray createFromArray(long[][][][] array) { ArrayUtil.assertNotRagged(array); if(array.length == 0 || array[0].length == 0 || array[0][0].length == 0 || array[0][0][0].length == 0) return Nd4j.empty(DataType.LONG); - return create(ArrayUtil.flatten(array), new long[]{array.length, array[0].length, array[0][0].length, array[0][0][0].length}, DataType.LONG); + long[] shape = new long[]{array.length, array[0].length, array[0][0].length, array[0][0][0].length}; + return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.LONG); } /** @@ -7078,7 +7112,8 @@ public static INDArray createFromArray(int[][][][] array) { ArrayUtil.assertNotRagged(array); if(array.length == 0 || array[0].length == 0 || array[0][0].length == 0 || array[0][0][0].length == 0) return Nd4j.empty(DataType.INT); - return create(ArrayUtil.flatten(array), new long[]{array.length, array[0].length, array[0][0].length, array[0][0][0].length}, DataType.INT); + long[] shape = new long[]{array.length, array[0].length, array[0][0].length, array[0][0][0].length}; + return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.INT); } /** @@ -7091,7 +7126,8 @@ public static INDArray createFromArray(short[][][][] array) { ArrayUtil.assertNotRagged(array); if(array.length == 0 || array[0].length == 0 || array[0][0].length == 0 || array[0][0][0].length == 0) return Nd4j.empty(DataType.SHORT); - return create(ArrayUtil.flatten(array), new long[]{array.length, array[0].length, array[0][0].length, array[0][0][0].length}, DataType.SHORT); + long[] shape = new long[]{array.length, array[0].length, array[0][0].length, array[0][0][0].length}; + return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.SHORT); } /** @@ -7104,7 +7140,8 @@ public static INDArray createFromArray(byte[][][][] array) { ArrayUtil.assertNotRagged(array); if(array.length == 0 || array[0].length == 0 || array[0][0].length == 0 || array[0][0][0].length == 0) return Nd4j.empty(DataType.BYTE); - return create(ArrayUtil.flatten(array), new long[]{array.length, array[0].length, array[0][0].length, array[0][0][0].length}, DataType.BYTE); + long[] shape = new long[]{array.length, array[0].length, array[0][0].length, array[0][0][0].length}; + return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.BYTE); } /** @@ -7117,7 +7154,8 @@ public static INDArray createFromArray(boolean[][][][] array) { ArrayUtil.assertNotRagged(array); if(array.length == 0 || array[0].length == 0 || array[0][0].length == 0 || array[0][0][0].length == 0) return Nd4j.empty(DataType.BOOL); - return create(ArrayUtil.flatten(array), new long[]{array.length, array[0].length, array[0][0].length, array[0][0][0].length}, DataType.BOOL); + long[] shape = new long[]{array.length, array[0].length, array[0][0].length, array[0][0][0].length}; + return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.BOOL); } /////////////////// diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/profiler/OpProfiler.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/profiler/OpProfiler.java index c5368b0ad225..f7eb698ffef1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/profiler/OpProfiler.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/profiler/OpProfiler.java @@ -237,7 +237,7 @@ public void processOpCall(Op op) { String opClass = getOpClass(op); classCounter.incrementCount(opClass); - if(op.x() == null || op.x().data().address() == lastZ && op.z() == op.x() && op.y() == null) { + if(op.x() == null || (op.x() != null && op.x().data().address() == lastZ && op.z() == op.x() && op.y() == null)) { // we have possible shift here matchingCounter.incrementCount(prevOpMatching + " -> " + opClass); matchingCounterDetailed.incrementCount(prevOpMatchingDetailed + " -> " + opClass + " " + op.opName()); @@ -592,20 +592,20 @@ public void processBlasCall(boolean isGemm, INDArray... operands) { public PenaltyCause[] processOperands(INDArray x, INDArray y) { List penalties = new ArrayList<>(); - if (x.ordering() != y.ordering()) { + if (x != null && x.ordering() != y.ordering()) { penalties.add(PenaltyCause.MIXED_ORDER); } - if (x.elementWiseStride() < 1) { + if (x != null && x.elementWiseStride() < 1) { penalties.add(PenaltyCause.NON_EWS_ACCESS); - } else if (y.elementWiseStride() < 1) { + } else if (y != null && y.elementWiseStride() < 1) { penalties.add(PenaltyCause.NON_EWS_ACCESS); } - if (x.elementWiseStride() > 1) { + if (x != null && x.elementWiseStride() > 1) { penalties.add(PenaltyCause.STRIDED_ACCESS); - } else if (y.elementWiseStride() > 1) { + } else if (y != null && y.elementWiseStride() > 1) { penalties.add(PenaltyCause.STRIDED_ACCESS); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 2289ce714d74..02abbab3c51b 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -475,10 +475,7 @@ public INDArray exec(ReduceOp op) { long[] retShape; val wholeDims = Shape.wholeArrayDimension(dimension) || op.x().rank() == dimension.length || dimension.length == 0; if (wholeDims) - if (op.isNewFormat()) - retShape = new long[0]; - else - retShape = new long[] {1, 1}; + retShape = new long[0]; else retShape = ArrayUtil.removeIndex(maxShape, dimension); @@ -540,7 +537,7 @@ public INDArray exec(IndexAccumulation op) { val dimension = op.dimensions().toIntVector(); val wholeArray = Shape.wholeArrayDimension(dimension) || dimension.length == 0; if (op.z() == null) { - long[] retShape = wholeArray ? (op.isNewFormat() ? new long[]{} : new long[]{1, 1}) : ArrayUtil.removeIndex(op.x().shape(), dimension); + long[] retShape = wholeArray ? new long[]{} : ArrayUtil.removeIndex(op.x().shape(), dimension); //ensure vector is proper shape if (retShape.length == 1) { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index 366e108f462c..f6ed505fd20b 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -156,15 +156,12 @@ public INDArray exec(IndexAccumulation op) { val dimension = Shape.normalizeAxis(op.x().rank(), op.dimensions().toIntVector()); boolean keepDims; - boolean newFormat; if(op instanceof BaseIndexAccumulation) { keepDims = ((BaseIndexAccumulation) op).isKeepDims(); - newFormat = ((BaseIndexAccumulation) op).isNewFormat(); } else { keepDims = false; - newFormat = false; } - long[] retShape = reductionShape(op.x(), dimension, newFormat, keepDims); + long[] retShape = reductionShape(op.x(), dimension, true, keepDims); if(op.z() == null || op.x() == op.z()) { val ret = Nd4j.createUninitialized(DataType.LONG, retShape); @@ -238,16 +235,13 @@ public INDArray exec(ReduceOp op) { long[] maxShape = Shape.getMaxShape(op.x(),op.y()); boolean keepDims; - boolean newFormat; if(op instanceof BaseReduceOp) { keepDims = op.isKeepDims(); - newFormat = ((BaseReduceOp) op).isNewFormat(); } else { keepDims = false; - newFormat = true; } - long[] retShape = reductionShape(op.x(), dimension, newFormat, keepDims); + long[] retShape = reductionShape(op.x(), dimension, true, keepDims); if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape) && ArrayUtil.prodLong(retShape) > 1 && op.y() == null) return op.noOp(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 1172b0f8255b..113adba4f61e 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -2889,6 +2889,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint public native @Cast("bool") boolean isAll(); public native @Cast("bool") boolean isPoint(); + public native @Cast("bool") boolean isInterval(); public native @Cast("Nd4jLong*") @StdVector LongPointer getIndices(); public native @Cast("Nd4jLong") long stride(); @@ -2912,6 +2913,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint public NDIndexAll() { super((Pointer)null); allocate(); } private native void allocate(); + public native @Cast("bool") boolean isInterval(); } @@ -2922,6 +2924,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint public NDIndexPoint(@Cast("Nd4jLong") long point) { super((Pointer)null); allocate(point); } private native void allocate(@Cast("Nd4jLong") long point); + public native @Cast("bool") boolean isInterval(); } @Namespace("nd4j") public static class NDIndexInterval extends NDIndex { @@ -2933,6 +2936,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint private native void allocate(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end, @Cast("Nd4jLong") long stride/*=1*/); public NDIndexInterval(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end) { super((Pointer)null); allocate(start, end); } private native void allocate(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end); + public native @Cast("bool") boolean isInterval(); } @@ -4865,7 +4869,6 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint - @@ -4901,6 +4904,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #endif + // Parsed from graph/Variable.h /******************************************************************************* @@ -6520,9 +6524,9 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint @Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongBuffer oldShape, int newRank, @Cast("Nd4jLong*") LongBuffer newShape, @Cast("bool") boolean isFOrder); @Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") long[] oldShape, int newRank, @Cast("Nd4jLong*") long[] newShape, @Cast("bool") boolean isFOrder); - @Namespace("shape") public static native @Cast("bool") boolean reshapeCF(int oldRank, @Cast("const Nd4jLong*") LongPointer oldShapeInfo, int newRank, @Cast("const Nd4jLong*") LongPointer newShape, @Cast("const bool") boolean isFOrder, @Cast("Nd4jLong*") LongPointer newShapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean reshapeCF(int oldRank, @Cast("const Nd4jLong*") LongBuffer oldShapeInfo, int newRank, @Cast("const Nd4jLong*") LongBuffer newShape, @Cast("const bool") boolean isFOrder, @Cast("Nd4jLong*") LongBuffer newShapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean reshapeCF(int oldRank, @Cast("const Nd4jLong*") long[] oldShapeInfo, int newRank, @Cast("const Nd4jLong*") long[] newShape, @Cast("const bool") boolean isFOrder, @Cast("Nd4jLong*") long[] newShapeInfo); + @Namespace("shape") public static native @Cast("bool") boolean reshapeC(int oldRank, @Cast("const Nd4jLong*") LongPointer oldShapeInfo, int newRank, @Cast("const Nd4jLong*") LongPointer newShape, @Cast("Nd4jLong*") LongPointer newShapeInfo); + @Namespace("shape") public static native @Cast("bool") boolean reshapeC(int oldRank, @Cast("const Nd4jLong*") LongBuffer oldShapeInfo, int newRank, @Cast("const Nd4jLong*") LongBuffer newShape, @Cast("Nd4jLong*") LongBuffer newShapeInfo); + @Namespace("shape") public static native @Cast("bool") boolean reshapeC(int oldRank, @Cast("const Nd4jLong*") long[] oldShapeInfo, int newRank, @Cast("const Nd4jLong*") long[] newShape, @Cast("Nd4jLong*") long[] newShapeInfo); /** * Get the shape info buffer @@ -6944,6 +6948,11 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint @Namespace("shape") public static native int rank(@Const IntBuffer buffer); @Namespace("shape") public static native int rank(@Const int[] buffer); + // returns pointer on elementWiseStride + @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ews(@Cast("Nd4jLong*") LongPointer shapeInfo); + @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ews(@Cast("Nd4jLong*") LongBuffer shapeInfo); + @Namespace("shape") public static native @Cast("Nd4jLong*") long[] ews(@Cast("Nd4jLong*") long[] shapeInfo); + /** * Converts a raw int buffer of the layout: * rank @@ -6962,9 +6971,9 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * Returns the stride portion of an information * buffer */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer stride(@Cast("const Nd4jLong*") LongPointer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer stride(@Cast("const Nd4jLong*") LongBuffer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] stride(@Cast("const Nd4jLong*") long[] buffer); + @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer stride(@Cast("Nd4jLong*") LongPointer buffer); + @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer stride(@Cast("Nd4jLong*") LongBuffer buffer); + @Namespace("shape") public static native @Cast("Nd4jLong*") long[] stride(@Cast("Nd4jLong*") long[] buffer); /** * Compute the length of the given shape @@ -6996,6 +7005,13 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint @Namespace("shape") public static native char order(@Cast("const Nd4jLong*") LongBuffer buffer); @Namespace("shape") public static native char order(@Cast("const Nd4jLong*") long[] buffer); +/** + * Returns the type + */ + @Namespace("shape") public static native @Cast("Nd4jLong") long type(@Cast("const Nd4jLong*") LongPointer shapeInfo); + @Namespace("shape") public static native @Cast("Nd4jLong") long type(@Cast("const Nd4jLong*") LongBuffer shapeInfo); + @Namespace("shape") public static native @Cast("Nd4jLong") long type(@Cast("const Nd4jLong*") long[] shapeInfo); + /** * Returns the element wise stride for this information * buffer @@ -7575,27 +7591,30 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint @Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") LongBuffer shapeInfo); @Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") long[] shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("Nd4jLong*") long[] shapeInfo); - - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, @Cast("Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, @Cast("Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, @Cast("Nd4jLong*") long[] shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, @Cast("Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, @Cast("Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, @Cast("Nd4jLong*") long[] shapeInfo); - - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer strides); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer strides); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] strides); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer strides); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer strides); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] strides); - - @Namespace("shape") public static native void printIntArray(@Cast("Nd4jLong*") LongPointer arr,int length); - @Namespace("shape") public static native void printIntArray(@Cast("Nd4jLong*") LongBuffer arr,int length); - @Namespace("shape") public static native void printIntArray(@Cast("Nd4jLong*") long[] arr,int length); + @Namespace("shape") public static native void printShapeInfoLinear(@Cast("const Nd4jLong*") LongPointer shapeInfo); + @Namespace("shape") public static native void printShapeInfoLinear(@Cast("const Nd4jLong*") LongBuffer shapeInfo); + @Namespace("shape") public static native void printShapeInfoLinear(@Cast("const Nd4jLong*") long[] shapeInfo); + + @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, @Cast("const Nd4jLong*") LongPointer shapeInfo); + @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, @Cast("const Nd4jLong*") LongBuffer shapeInfo); + @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, @Cast("const Nd4jLong*") long[] shapeInfo); + @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, @Cast("const Nd4jLong*") LongPointer shapeInfo); + @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, @Cast("const Nd4jLong*") LongBuffer shapeInfo); + @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, @Cast("const Nd4jLong*") long[] shapeInfo); + + @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer strides); + @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer strides); + @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] strides); + @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer strides); + @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer strides); + @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] strides); + + @Namespace("shape") public static native void printIntArray(@Cast("const Nd4jLong*") LongPointer arr, int length); + @Namespace("shape") public static native void printIntArray(@Cast("const Nd4jLong*") LongBuffer arr, int length); + @Namespace("shape") public static native void printIntArray(@Cast("const Nd4jLong*") long[] arr, int length); + @Namespace("shape") public static native void printIntArray(@Const IntPointer arr, int length); + @Namespace("shape") public static native void printIntArray(@Const IntBuffer arr, int length); + @Namespace("shape") public static native void printIntArray(@Const int[] arr, int length); @Namespace("shape") public static native void printArray(FloatPointer arr,int length); @Namespace("shape") public static native void printArray(FloatBuffer arr,int length); @@ -7616,30 +7635,31 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // function calculates linear index of array min, min is sub-array of max, index to be returned is min-array's index and corresponds to maxIdx of max array // dimsToExclude - should be sorted in increasing order - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/); + @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/); + @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/); + @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); // function calculates absolute offset of min array, min is sub-array of max, offset to be returned corresponds to maxIdx of max array // dimsToExclude - should be sorted in increasing order - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/); + @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/); + @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/); + @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); // max array is outer for min array, min array is sub-array of max array // function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs) // dimsToExclude - should be sorted in increasing order - @Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/); + // dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated as maxRank - minRank + @Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); @Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); - @Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/); + @Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); @Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); - @Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/); + @Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); @Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); // calculate indexes of max-array, these output indexes correspond to one minIdx index of min-array which is sub-array of max-array @@ -7670,6 +7690,14 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint @Namespace("shape") public static native void shapeOldScalar(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*const") LongBuffer buffer, byte order); @Namespace("shape") public static native void shapeOldScalar(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*const") long[] buffer, byte order); + // calculate element-wise stride + // if array is scalar or unit length vector then ews = 1 + // if array is common vector then ews = stride of non-unity dimension + // if strides are normal set ews = 1, otherwise ews = 0 + @Namespace("shape") public static native void calcEws(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong") long len); + @Namespace("shape") public static native void calcEws(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong") long len); + @Namespace("shape") public static native void calcEws(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong") long len); + @@ -8081,6 +8109,10 @@ INLINEDEF _CUDA_HD void doPermuteShapeBuffer(Nd4jLong *shapeBuffer, int *rearran * for this shape information buffer */ +/** + * Returns type + */ + /** * Returns the element wise stride for this information * buffer @@ -8382,7 +8414,7 @@ INLINEDEF _CUDA_HD void doPermuteShapeBuffer(Nd4jLong *shapeBuffer, int *rearran ////////////////////////////////////////////////////////////////////////// // copy-past from java hasDefaultStridesForShape function -// INLINEDEF _CUDA_H bool reshapeCF(const int oldRank, Nd4jLong* oldShape, const int newRank, Nd4jLong* newShapeOf, bool isFOrder, Nd4jLong* target) { +// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, Nd4jLong* oldShape, const int newRank, Nd4jLong* newShapeOf, bool isFOrder, Nd4jLong* target) { // int oldnd; // Nd4jLong* olddims = shape::copyOf(oldRank, shape::shapeOf(oldShape)); // Nd4jLong* oldstrides = shape::copyOf(oldRank, shape::stride(oldShape)); @@ -8521,12 +8553,77 @@ INLINEDEF _CUDA_HD void doPermuteShapeBuffer(Nd4jLong *shapeBuffer, int *rearran // return true; // } +// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* oldShapeInfo, const int newRank, const Nd4jLong* newShape, const bool isFOrder, Nd4jLong* newShapeInfo) { + +// // PLEASE NOTE !: reshaping not-permuted (ews=1) array in f order (except insertion/elimination of unities) will definitely cause allocation of new buffer for array elements +// // also this function takes into account identical shapes automatically, namely in that case oldShapeInfo is completely copied to newShapeInfo + +// const int newOrder = isFOrder ? 102 : 99; +// const int oldOrder = oldShapeInfo[2 * oldRank + 3]; + +// newShapeInfo[0] = newRank; +// memcpy(newShapeInfo + 1, newShape, newRank * sizeof(Nd4jLong)); + +// Nd4jLong* newStrides = shape::stride(newShapeInfo); +// const Nd4jLong* oldShape = shape::shapeOf(const_cast(oldShapeInfo)); +// const Nd4jLong* oldStrides = shape::stride(const_cast(oldShapeInfo)); +// int oldStart(0), oldStop(1), newStart(0), newStop(1), newDim, oldDim; + + +// while (newStart < newRank && oldStart < oldRank) { + +// newDim = newShape[newStart]; +// oldDim = oldShape[oldStart]; + +// while (newDim != oldDim) +// if (newDim < oldDim) newDim *= newShape[newStop++]; +// else oldDim *= oldShape[oldStop++]; + +// // ------ Check whether the original axes can be combined ------ // +// for (int i = oldStart; i < oldStop - 1; i++) { + +// if(oldShape[i] == 1) { // ignore strides like {...,1,1,...} +// if(oldOrder == 102) ++oldStart; +// continue; +// } + +// if(oldOrder == 102 && oldStrides[i + 1] != oldShape[i] * oldStrides[i]) +// return false; // not contiguous enough +// if(oldOrder == 99 && oldStrides[i] != oldShape[i + 1] * oldStrides[i + 1]) +// return false; // not contiguous enough +// } + +// // ------ Calculate new strides for all axes currently worked with ------ // +// if(isFOrder) { +// newStrides[newStart] = oldStrides[oldStart]; +// for (int i = newStart + 1; i < newStop; ++i) +// newStrides[i] = newStrides[i - 1] * newShape[i - 1]; +// } +// else { +// newStrides[newStop - 1] = oldStrides[oldStop - 1]; +// for (int i = newStop - 1; i > newStart; --i) +// newStrides[i - 1] = newStrides[i] * newShape[i]; +// } + +// newStart = newStop++; +// oldStart = oldStop++; +// } + +// newShapeInfo[2 * newRank + 3] = shape::order(oldShapeInfo); // order +// newShapeInfo[2 * newRank + 2] = shape::elementWiseStride(oldShapeInfo); // ews +// newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo); // type + +// return true; +// } + +////////////////////////////////////////////////////////////////////// + // this function checks the consistence of dimensions with array rank (negative dimensions, too large dimensions, too big number of dimensions) // also it sorts input array of dimensions, this operation is also necessary for creating TAD object // max array is outer for min array, min array is sub-array of max array - // function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs) +// function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs) ////////////////////////////////////////////////////////////////////// @@ -8538,12 +8635,16 @@ INLINEDEF _CUDA_HD void doPermuteShapeBuffer(Nd4jLong *shapeBuffer, int *rearran ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// + + // #endif /* SHAPE_H_ */ + // Parsed from helpers/OpArgsHolder.h /******************************************************************************* @@ -10851,6 +10952,8 @@ INLINEDEF _CUDA_HD void doPermuteShapeBuffer(Nd4jLong *shapeBuffer, int *rearran // #define PARAMETRIC_XYZ() [&] (Parameters &p, ResultSet &x, ResultSet &y, ResultSet &z) // #define PARAMETRIC_XZ() [&] (Parameters &p, ResultSet &x, ResultSet &z) +// #define PARAMETRIC_D() [&] (Parameters &p) -> Context* + // #endif diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java index 10e0ad58ca4f..7ed22bf33a4f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java @@ -919,40 +919,43 @@ public void testReduce3_2() { INDArray expOut; SDVariable reduced; String name; + System.out.println(i); switch (i) { case 0: reduced = sd.math().manhattanDistance(in, in2, reduceDims); name = "manhattan"; - expOut = Nd4j.getExecutioner().exec(new ManhattanDistance(a, b, null, true, false, reduceDims)); + expOut = Nd4j.getExecutioner().exec(new ManhattanDistance(a, b, null, false, reduceDims)); break; case 1: reduced = sd.math().euclideanDistance(in, in2, reduceDims); name = "euclidean"; - expOut = Nd4j.getExecutioner().exec(new EuclideanDistance(a, b, null, true, false, reduceDims)); + expOut = Nd4j.getExecutioner().exec(new EuclideanDistance(a, b, null, false, reduceDims)); break; case 2: reduced = sd.math().cosineSimilarity(in, in2, reduceDims); name = "cosine"; - expOut = Nd4j.getExecutioner().exec(new CosineSimilarity(a, b, null, true, false, reduceDims)); + expOut = Nd4j.getExecutioner().exec(new CosineSimilarity(a, b, null, false, reduceDims)); break; case 3: reduced = sd.math().jaccardDistance(in, in2, reduceDims); name = "jaccard"; - expOut = Nd4j.getExecutioner().exec(new JaccardDistance(a, b, null, true, false, reduceDims)); + expOut = Nd4j.getExecutioner().exec(new JaccardDistance(a, b, null, false, reduceDims)); break; case 4: reduced = sd.math().hammingDistance(in, in2, reduceDims); name = "hamming"; - expOut = Nd4j.getExecutioner().exec(new HammingDistance(a, b, null, true, false, reduceDims)); + expOut = Nd4j.getExecutioner().exec(new HammingDistance(a, b, null, false, reduceDims)); break; case 5: reduced = sd.math().cosineDistance(in, in2, reduceDims); name = "reduced"; - expOut = Nd4j.getExecutioner().exec(new CosineDistance(a, b, null, true, false, reduceDims)); + expOut = Nd4j.getExecutioner().exec(new CosineDistance(a, b, null, false, reduceDims)); break; default: throw new RuntimeException(); } + System.out.println(i + " - end"); + long[] expShape; if (Arrays.equals(new int[]{0}, reduceDims)) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java index 443f964149a6..21c9234a43e8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java @@ -36,6 +36,7 @@ import org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod; import org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication; import org.nd4j.linalg.api.ops.impl.shape.Cross; +import org.nd4j.linalg.api.ops.impl.transforms.Pad; import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax; import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin; import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual; @@ -1463,15 +1464,20 @@ public void testPad(){ .addIntegerArguments(0) //0 = CONSTANT .build(); + INDArray exp = Nd4j.create(new double[]{10, 1, 1, 1, 1, 1, 10}); OpValidation.validate(new OpTestCase(op) - .expectedOutput(0, Nd4j.create(new double[]{10, 1, 1, 1, 1, 1, 10}))); + .expectedOutput(0, exp)); + + SameDiff sd = SameDiff.create(); + SDVariable s = sd.var("in", in); + SDVariable padded = sd.nn().pad(s, sd.constant(pad), 10.0); + String err2 = OpValidation.validate(new TestCase(sd).expected(padded, exp).gradientCheck(false)); + assertNull(err2); } @Test public void testMirrorPad(){ -// OpValidationSuite.ignoreFailing(); - INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2,3); INDArray pad = Nd4j.create(new double[][]{{1,1},{2,2}}).castTo(DataType.INT); @@ -1494,12 +1500,17 @@ public void testMirrorPad(){ .expectedOutput(0, exp)); assertNull(err); + + + SameDiff sd = SameDiff.create(); + SDVariable s = sd.var("in", in); + SDVariable padded = sd.nn().pad("pad", s, sd.constant(Nd4j.createFromArray(new int[][]{{1,1},{2,2}})), Pad.Mode.REFLECT, 0.0); + String err2 = OpValidation.validate(new TestCase(sd).expected(padded, exp).gradientCheck(false)); + assertNull(err2); } @Test public void testMirrorPad2(){ -// OpValidationSuite.ignoreFailing(); - INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2,3); INDArray pad = Nd4j.create(new double[][]{{1,1},{2,2}}).castTo(DataType.INT); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index fd83d095e720..247bd97f80a0 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -25,6 +25,8 @@ import org.junit.rules.TemporaryFolder; import org.nd4j.OpValidationSuite; import org.nd4j.autodiff.samediff.impl.DefaultSameDiffConditional; +import org.nd4j.autodiff.validation.OpValidation; +import org.nd4j.autodiff.validation.TestCase; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.buffer.DataType; @@ -2757,4 +2759,23 @@ public void testConvertToVariable(){ //Sanity check on training: sd.fit(new SingletonMultiDataSetIterator(new DataSet(inArr,null).toMultiDataSet()), 1); } + + @Test + public void testDoubleUseOfArray(){ + //If array is reused, gradient check will fail + INDArray a = Nd4j.rand(DataType.DOUBLE, new int[]{3, 4}); + SameDiff sd = SameDiff.create(); + SDVariable a1 = sd.var("a", a); + SDVariable a2 = sd.var("b", a); + a1.add(a2).norm2("out"); + String err = OpValidation.validate(new TestCase(sd) + .gradientCheck(true)); + assertNull(err); + + a1.setArray(a); + a2.setArray(a); + err = OpValidation.validate(new TestCase(sd) + .gradientCheck(true)); + assertNull(err); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/FileReadWriteTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/FileReadWriteTests.java index e4063d05f453..ffef35b98726 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/FileReadWriteTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/FileReadWriteTests.java @@ -144,4 +144,22 @@ public void testSimple() throws IOException { assertEquals(exp, arr); } } + + @Test + public void testNullBinLabels() throws Exception{ + File dir = testDir.newFolder(); + File f = new File(dir, "temp.bin"); + LogFileWriter w = new LogFileWriter(f); + + SameDiff sd = SameDiff.create(); + SDVariable v = sd.var("variable", DataType.DOUBLE, 3, 4); + SDVariable sum = v.sum(); + + w.writeGraphStructure(sd); + w.writeFinishStaticMarker(); + + w.registerEventName("name"); + INDArray arr = Nd4j.create(1); + w.writeHistogramEventDiscrete("name", System.currentTimeMillis(), 0, 0, null, arr); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java index 33a833771cc8..99001a17ecc1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java @@ -130,7 +130,7 @@ public void testLabelAndPredictionCounts() { int[] expPredictionCount = new int[(int) labels.size(1)]; INDArray argmax = Nd4j.argMax(arr, 1); for (int i = 0; i < argmax.length(); i++) { - expPredictionCount[argmax.getInt(i, 0)]++; + expPredictionCount[argmax.getInt(i)]++; } assertArrayEquals(expLabelCounts, ec.getLabelCountsEachClass()); @@ -163,7 +163,7 @@ public void testResidualPlots() { double binSize = 1.0 / numBins; for (int i = 0; i < minibatch; i++) { - int actualClassIdx = argmaxLabels.getInt(i, 0); + int actualClassIdx = argmaxLabels.getInt(i); for (int j = 0; j < nClasses; j++) { double labelSubProb = absLabelSubProb.getDouble(i, j); for (int k = 0; k < numBins; k++) { @@ -205,7 +205,7 @@ public void testResidualPlots() { int[] probCountsAllClasses = new int[numBins]; int[][] probCountsByClass = new int[nClasses][numBins]; //Histogram count of |label[x] - p(x)|; rows x are over classes for (int i = 0; i < minibatch; i++) { - int actualClassIdx = argmaxLabels.getInt(i, 0); + int actualClassIdx = argmaxLabels.getInt(i); for (int j = 0; j < nClasses; j++) { double prob = arr.getDouble(i, j); for (int k = 0; k < numBins; k++) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java index 5a7fd1ebbd43..cd4c854397a8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java @@ -160,6 +160,9 @@ public List processSubgraph(SameDiff sd, SubGraph subGraph) { } }); + //Small test / sanity check for asFlatPrint(): + sd.asFlatPrint(); + /* Output during inference: diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java index cfaa57d585ce..f79879c15cec 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java @@ -195,12 +195,12 @@ public void opsNotAllowed() { public void testArgMax() { int max = 63; INDArray A = Nd4j.linspace(1, max, max).reshape(1, max); - int currentArgMax = Nd4j.argMax(A).getInt(0, 0); + int currentArgMax = Nd4j.argMax(A).getInt(0); assertEquals(max - 1, currentArgMax); max = 64; A = Nd4j.linspace(1, max, max).reshape(1, max); - currentArgMax = Nd4j.argMax(A).getInt(0, 0); + currentArgMax = Nd4j.argMax(A).getInt(0); System.out.println("Returned argMax is " + currentArgMax); assertEquals(max - 1, currentArgMax); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/NDArrayTestsFortran.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/NDArrayTestsFortran.java index 5898ff8fe74f..1e457b10cba8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/NDArrayTestsFortran.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/NDArrayTestsFortran.java @@ -1062,10 +1062,10 @@ public void testNegativeShape() { @Test public void testGetColumnGetRow() { - INDArray row = Nd4j.ones(5).reshape(1, -1); + INDArray row = Nd4j.ones(1, 5); for (int i = 0; i < 5; i++) { INDArray col = row.getColumn(i); - assertArrayEquals(col.shape(), new long[] {}); + assertArrayEquals(col.shape(), new long[] {1,1}); } INDArray col = Nd4j.ones(5, 1); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index c2ea3e36da8b..191757b8ceeb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -7436,6 +7436,71 @@ public void testMeshgridDtypes() { Nd4j.meshgrid(Nd4j.createFromArray(1, 2, 3), Nd4j.createFromArray(4, 5, 6)); } + @Test + public void testGetColumnRowVector(){ + INDArray arr = Nd4j.create(1,4); + INDArray col = arr.getColumn(0); + System.out.println(Arrays.toString(col.shape())); + assertArrayEquals(new long[]{1,1}, col.shape()); + } + + + @Test + public void testEmptyArrayReuse(){ + //Empty arrays are immutable - no point creating them multiple times + INDArray ef1 = Nd4j.empty(DataType.FLOAT); + INDArray ef2 = Nd4j.empty(DataType.FLOAT); + assertTrue(ef1 == ef2); //Should be exact same object + + INDArray el1 = Nd4j.empty(DataType.LONG); + INDArray el2 = Nd4j.empty(DataType.LONG); + assertTrue(el1 == el2); //Should be exact same object + } + + @Test + public void testMaxViewF(){ + INDArray arr = Nd4j.create(DataType.DOUBLE, new long[]{8,2}, 'f').assign(999); + + INDArray view = arr.get(NDArrayIndex.interval(3,5), NDArrayIndex.all()); + view.assign(Nd4j.createFromArray(new double[][]{{1,2},{3,4}})); + + assertEquals(Nd4j.create(new double[]{3,4}), view.max(0)); + assertEquals(Nd4j.create(new double[]{2,4}), view.max(1)); + } + + @Test + public void testCreateF(){ + char origOrder = Nd4j.order(); + try { + Nd4j.factory().setOrder('f'); + + + INDArray arr = Nd4j.createFromArray(new double[][]{{1, 2, 3}, {4, 5, 6}}); + INDArray arr2 = Nd4j.createFromArray(new float[][]{{1, 2, 3}, {4, 5, 6}}); + INDArray arr3 = Nd4j.createFromArray(new int[][]{{1, 2, 3}, {4, 5, 6}}); + INDArray arr4 = Nd4j.createFromArray(new long[][]{{1, 2, 3}, {4, 5, 6}}); + INDArray arr5 = Nd4j.createFromArray(new short[][]{{1, 2, 3}, {4, 5, 6}}); + INDArray arr6 = Nd4j.createFromArray(new byte[][]{{1, 2, 3}, {4, 5, 6}}); + + INDArray exp = Nd4j.create(2, 3); + exp.putScalar(0, 0, 1.0); + exp.putScalar(0, 1, 2.0); + exp.putScalar(0, 2, 3.0); + exp.putScalar(1, 0, 4.0); + exp.putScalar(1, 1, 5.0); + exp.putScalar(1, 2, 6.0); + + assertEquals(exp, arr); + assertEquals(exp.castTo(DataType.FLOAT), arr2); + assertEquals(exp.castTo(DataType.INT), arr3); + assertEquals(exp.castTo(DataType.LONG), arr4); + assertEquals(exp.castTo(DataType.SHORT), arr5); + assertEquals(exp.castTo(DataType.BYTE), arr6); + } finally { + Nd4j.factory().setOrder(origOrder); + } + } + /////////////////////////////////////////////////////// protected static void fillJvmArray3D(float[][][] arr) { int cnt = 1; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java index 937240674040..148770881b1c 100755 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java @@ -881,8 +881,9 @@ public void testGetRangeMask() { //The feature mask does not have to be equal to the label mask, just in this ex it should be assertEquals(newDs.getLabelsMaskArray(), newDs.getFeaturesMaskArray()); //System.out.println(newDs); - assertEquals(Nd4j.linspace(numExamples + from, numExamples + to - 1, to - from, DataType.DOUBLE), - newDs.getLabelsMaskArray().sum(1)); + INDArray exp = Nd4j.linspace(numExamples + from, numExamples + to - 1, to - from, DataType.DOUBLE); + INDArray act = newDs.getLabelsMaskArray().sum(1); + assertEquals(exp, act); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java index e339fdd906e3..4cfe91e6b733 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java @@ -636,16 +636,37 @@ public void tescodtSum6d() { @Test public void testSum6d2() { - INDArray arr6 = Nd4j.linspace(1, 256, 256, DataType.DOUBLE).reshape(1, 1, 4, 4, 4, 4); - INDArray arr6s = arr6.sum(2, 3); - - assertEquals(136, arr6s.getDouble(0), 1e-1); - assertEquals(1160, arr6s.getDouble(1), 1e-1); - assertEquals(2184, arr6s.getDouble(2), 1e-1); - assertEquals(3208, arr6s.getDouble(3), 1e-1); - assertEquals(392, arr6s.getDouble(4), 1e-1); - assertEquals(1416, arr6s.getDouble(5), 1e-1); - assertEquals(2440, arr6s.getDouble(6), 1e-1); + char origOrder = Nd4j.order(); + try { + for (char order : new char[]{'c', 'f'}) { + Nd4j.factory().setOrder(order); + + INDArray arr6 = Nd4j.linspace(1, 256, 256, DataType.DOUBLE).reshape(1, 1, 4, 4, 4, 4); + INDArray arr6s = arr6.sum(2, 3); + + INDArray exp = Nd4j.create(DataType.DOUBLE, 1, 1, 4, 4); + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + double sum = 0; + for (int x = 0; x < 4; x++) { + for (int y = 0; y < 4; y++) { + sum += arr6.getDouble(0, 0, x, y, i, j); + } + } + + exp.putScalar(0, 0, i, j, sum); + } + } + assertEquals(exp, arr6s); + + System.out.println("ORDER: " + order); + for (int i = 0; i < 6; i++) { + System.out.println(arr6s.getDouble(i)); + } + } + } finally { + Nd4j.factory().setOrder(origOrder); + } } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java index 426cd256b581..9f2369fa19bb 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java @@ -1242,7 +1242,7 @@ public double getDouble(long i) { case UBYTE: return ((UByteIndexer) indexer).get(offset() + i); default: - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException("Cannot get double value from buffer of type " + dataType()); } } @@ -1268,7 +1268,7 @@ public long getLong(long i) { case BOOL: return ((BooleanIndexer) indexer).get(offset() + i) ? 1L : 0L; default: - throw new UnsupportedOperationException("Unsupported data type: " + dataType()); + throw new UnsupportedOperationException("Cannot get long value from buffer of type " + dataType()); } } @@ -1296,7 +1296,7 @@ protected short getShort(long i) { case FLOAT: return (short) ((FloatIndexer) indexer).get(offset() + i); default: - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException("Cannot get short value from buffer of type " + dataType()); } } @@ -1331,7 +1331,7 @@ public float getFloat(long i) { case FLOAT: return ((FloatIndexer) indexer).get(offset() + i); default: - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException("Cannot get float value from buffer of type " + dataType()); } } @@ -1357,7 +1357,7 @@ public int getInt(long i) { case FLOAT: return (int) ((FloatIndexer) indexer).get(offset() + i); default: - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException("Cannot get integer value from buffer of type " + dataType()); } }