diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/conf/Configuration.java b/datavec/datavec-api/src/main/java/org/datavec/api/conf/Configuration.java
index 2c76103bc665..db736c11f571 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/conf/Configuration.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/conf/Configuration.java
@@ -771,6 +771,8 @@ public IntegerRanges getRange(String name, String defaultValue) {
      */
     public Collection<String> getStringCollection(String name) {
         String valueString = get(name);
+        if(valueString == null)
+            return null;
         return Arrays.asList(StringUtils.split(valueString, ","));
     }
 
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCountsNDArrayTransform.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCountsNDArrayTransform.java
index 4331511bd6e7..a62af332013e 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCountsNDArrayTransform.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCountsNDArrayTransform.java
@@ -158,7 +158,7 @@ else if (idx != null)
     }
 
     protected INDArray makeBOWNDArray(Collection<Integer> indices) {
-        INDArray counts = Nd4j.zeros(vocabulary.size());
+        INDArray counts = Nd4j.zeros(1, vocabulary.size());
         for (Integer idx : indices)
             counts.putScalar(idx, counts.getDouble(idx) + 1);
         Nd4j.getExecutioner().commit();
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToIndicesNDArrayTransform.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToIndicesNDArrayTransform.java
index 73ef321be277..3a4ca58ba714 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToIndicesNDArrayTransform.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToIndicesNDArrayTransform.java
@@ -56,7 +56,7 @@ public StringListToIndicesNDArrayTransform(@JsonProperty("columnName") String co
 
     @Override
     protected INDArray makeBOWNDArray(Collection<Integer> indices) {
-        INDArray counts = Nd4j.zeros(indices.size());
+        INDArray counts = Nd4j.zeros(1, indices.size());
         List<Integer> indicesSorted = new ArrayList<>(indices);
         Collections.sort(indicesSorted);
         for (int i = 0; i < indicesSorted.size(); i++)
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java b/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java
index 7730d2d2aebf..e3ba797c8d0a 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java
@@ -306,8 +306,8 @@ private static List<List<Writable>> getClassificationWritableMatrix(DataSet data
         List<List<Writable>> writableMatrix = new ArrayList<>();
 
         for (int i = 0; i < dataSet.numExamples(); i++) {
-            List<Writable> writables = toRecord(dataSet.getFeatures().getRow(i));
-            writables.add(new IntWritable(Nd4j.argMax(dataSet.getLabels().getRow(i), 1).getInt(0)));
+            List<Writable> writables = toRecord(dataSet.getFeatures().getRow(i, true));
+            writables.add(new IntWritable(Nd4j.argMax(dataSet.getLabels().getRow(i)).getInt(0)));
             writableMatrix.add(writables);
         }
 
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java
index bcc8d7abeb5a..5aca0400096b 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java
@@ -56,6 +56,7 @@
 import org.joda.time.DateTimeZone;
 import org.junit.Assert;
 import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 
 import java.io.ByteArrayInputStream;
@@ -1438,7 +1439,8 @@ public void testStringListToCountsNDArrayTransform() throws Exception {
 
         List<Writable> out = t.map(l);
-        assertEquals(Collections.singletonList(new NDArrayWritable(Nd4j.create(new double[]{2,3,0}, new long[]{1,3}, Nd4j.dataType()))), out);
+        INDArray exp = Nd4j.create(new double[]{2,3,0}, new long[]{1,3}, Nd4j.dataType());
+        assertEquals(Collections.singletonList(new NDArrayWritable(exp)), out);
 
         String json = JsonMappers.getMapper().writeValueAsString(t);
         Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToCountsNDArrayTransform.class);
diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java
index a0c81ce56bab..36b4e59c4b2e 100644
--- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java
+++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java
@@ -180,11 +180,13 @@ public INDArray asRowVector(Frame image) throws IOException {
     }
 
     public INDArray asRowVector(Mat image) throws IOException {
-        return asMatrix(image).ravel();
+        INDArray arr = asMatrix(image);
+        return arr.reshape('c', 1, arr.length());
     }
 
     public INDArray asRowVector(org.opencv.core.Mat image) throws IOException {
-        return asMatrix(image).ravel();
+        INDArray arr = asMatrix(image);
+        return arr.reshape('c', 1, arr.length());
     }
 
     static Mat convert(PIX pix) {
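// Illustrative aside (not part of the patch): the changes above consistently replace results that
// were previously treated as implicit 1 x n arrays (Nd4j.zeros(n), ravel()) with explicit [1, n]
// row vectors, since newer ND4J versions make Nd4j.zeros(n) a true rank-1 array. A minimal sketch
// of the pattern, assuming a recent ND4J release; class name and sizes are illustrative only:
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RowVectorSketch {
    public static void main(String[] args) {
        INDArray flat = Nd4j.zeros(5);      // rank 1, shape [5]
        INDArray row  = Nd4j.zeros(1, 5);   // rank 2, shape [1, 5] - what the patched code now builds

        System.out.println(java.util.Arrays.toString(flat.shape())); // [5]
        System.out.println(java.util.Arrays.toString(row.shape()));  // [1, 5]

        // Flattening into an explicit row vector, as NativeImageLoader.asRowVector(...) now does,
        // instead of ravel(), which returns a rank-1 array:
        INDArray img = Nd4j.rand(3, 4);
        INDArray asRow = img.reshape('c', 1, img.length());          // shape [1, 12]
        System.out.println(java.util.Arrays.toString(asRow.shape()));
    }
}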
diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java
index a75e2a9143be..70debdc71e32 100644
--- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java
+++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java
@@ -173,7 +173,7 @@ record = rr.nextRecord();
             assertEquals(42, transform.getCurrentImage().getHeight());
             INDArray labelArray = ((NDArrayWritable)next.get(1)).get();
             BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
-            assertEquals(nonzeroCount[i++], labelArray.ravel().sum(1).getInt(0));
+            assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
         }
 
         ImageTransform transform2 = new ResizeImageTransform(1024, 2048);
@@ -186,7 +186,7 @@ record = rr.nextRecord();
             assertEquals(2048, transform2.getCurrentImage().getHeight());
             INDArray labelArray = ((NDArrayWritable)next.get(1)).get();
             BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
-            assertEquals(nonzeroCount[i++], labelArray.ravel().sum(1).getInt(0));
+            assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
         }
 
         //Make sure image flip does not break labels and are correct for new image size dimensions:
@@ -201,7 +201,7 @@ record = rr.nextRecord();
             List<Writable> next = rrTransform3.next();
             INDArray labelArray = ((NDArrayWritable)next.get(1)).get();
             BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
-            assertEquals(nonzeroCount[i++], labelArray.ravel().sum(1).getInt(0));
+            assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
         }
 
         //Test that doing a downscale with the native image loader directly instead of a transform does not cause an exception:
@@ -217,7 +217,7 @@ record = rr.nextRecord();
             INDArray labelArray = ((NDArrayWritable)next.get(1)).get();
             BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
-            assertEquals(nonzeroCount[i++], labelArray.ravel().sum(1).getInt(0));
+            assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
         }
     }
 
diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/misc/WritablesToNDArrayFunction.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/misc/WritablesToNDArrayFunction.java
index 46cee66d079b..cec8b650f776 100644
--- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/misc/WritablesToNDArrayFunction.java
+++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/misc/WritablesToNDArrayFunction.java
@@ -19,6 +19,7 @@
 import org.apache.spark.api.java.function.Function;
 import org.datavec.api.writable.NDArrayWritable;
 import org.datavec.api.writable.Writable;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.NDArrayIndex;
@@ -53,13 +54,13 @@ public INDArray call(List<Writable> c) throws Exception {
             }
         }
 
-        INDArray arr = Nd4j.zeros(length);
+        INDArray arr = Nd4j.zeros(DataType.FLOAT, 1, length);
         int idx = 0;
         for (Writable w : c) {
             if (w instanceof NDArrayWritable) {
                 INDArray subArr = ((NDArrayWritable) w).get();
                 int subLength = subArr.columns();
-                arr.get(NDArrayIndex.interval(idx, idx + subLength)).assign(subArr);
+                arr.get(NDArrayIndex.point(0), NDArrayIndex.interval(idx, idx + subLength)).assign(subArr);
                 idx += subLength;
             } else {
                 arr.putScalar(idx++, w.toDouble());
diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/ExecutionTest.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/ExecutionTest.java
index 4d16404d7e33..b59c8b32cfb8 100644
--- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/ExecutionTest.java
+++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/ExecutionTest.java
@@ -31,6 +31,7 @@
 import org.datavec.api.writable.NDArrayWritable;
 import org.datavec.spark.BaseSparkTest;
 import org.datavec.python.PythonTransform;
+import org.junit.Ignore;
 import org.junit.Test;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java
index 65995bb550ff..12c248104877 100644
--- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java
+++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java
@@ -56,13 +56,13 @@ public void testAnalysis() throws Exception {
         List<List<Writable>> data = new ArrayList<>();
         data.add(Arrays.asList((Writable) new IntWritable(0), new DoubleWritable(1.0), new LongWritable(1000),
-                        new Text("A"), new NDArrayWritable(Nd4j.valueArrayOf(10, 100.0))));
+                        new Text("A"), new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 100.0))));
         data.add(Arrays.asList((Writable) new IntWritable(5), new
DoubleWritable(0.0), new LongWritable(2000), - new Text("A"), new NDArrayWritable(Nd4j.valueArrayOf(10, 200.0)))); + new Text("A"), new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 200.0)))); data.add(Arrays.asList((Writable) new IntWritable(3), new DoubleWritable(10.0), new LongWritable(3000), - new Text("A"), new NDArrayWritable(Nd4j.valueArrayOf(10, 300.0)))); + new Text("A"), new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 300.0)))); data.add(Arrays.asList((Writable) new IntWritable(-1), new DoubleWritable(-1.0), new LongWritable(20000), - new Text("B"), new NDArrayWritable(Nd4j.valueArrayOf(10, 400.0)))); + new Text("B"), new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 400.0)))); JavaRDD> rdd = sc.parallelize(data); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java index 65eb487b08c8..aff85479fe51 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java @@ -1204,8 +1204,8 @@ public void testRecordReaderDataSetIteratorConcat() { DataSetIterator iter = new RecordReaderDataSetIterator(rr, 1, 5, 3); DataSet ds = iter.next(); - INDArray expF = Nd4j.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8, 9}); - INDArray expL = Nd4j.create(new float[] {0, 1, 0}); + INDArray expF = Nd4j.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8, 9}, new int[]{1,9}); + INDArray expL = Nd4j.create(new float[] {0, 1, 0}, new int[]{1,3}); assertEquals(expF, ds.getFeatures()); assertEquals(expL, ds.getLabels()); @@ -1222,7 +1222,7 @@ public void testRecordReaderDataSetIteratorConcat2() { DataSetIterator iter = new RecordReaderDataSetIterator(rr, 1); DataSet ds = iter.next(); - INDArray expF = Nd4j.create(new float[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + INDArray expF = Nd4j.create(new float[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, new int[]{1,10}); assertEquals(expF, ds.getFeatures()); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java index 54485bd57e65..9ad7770f0c35 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java @@ -212,9 +212,9 @@ public void testSplittingCSV() throws Exception { assertNotNull(lmds[i]); //Get the subsets of the original iris data - INDArray expIn1 = fds.get(all(), point(0)); - INDArray expIn2 = fds.get(all(), NDArrayIndex.interval(1, 2, true)); - INDArray expOut1 = fds.get(all(), point(3)); + INDArray expIn1 = fds.get(all(), interval(0,0,true)); + INDArray expIn2 = fds.get(all(), interval(1, 2, true)); + INDArray expOut1 = fds.get(all(), interval(3,3,true)); INDArray expOut2 = lds; assertEquals(expIn1, fmds[0]); @@ -693,14 +693,14 @@ public void testTimeSeriesRandomOffset() { INDArray f = mds.getFeatures(0); INDArray l = mds.getLabels(0); - INDArray expF1 = Nd4j.create(new double[] {1.0}); - INDArray expL1 = Nd4j.create(new double[] {2.0}); + INDArray expF1 = Nd4j.create(new 
double[] {1.0}, new int[]{1,1}); + INDArray expL1 = Nd4j.create(new double[] {2.0}, new int[]{1,1}); - INDArray expF2 = Nd4j.create(new double[] {10, 20, 30}); - INDArray expL2 = Nd4j.create(new double[] {11, 21, 31}); + INDArray expF2 = Nd4j.create(new double[] {10, 20, 30}, new int[]{1,3}); + INDArray expL2 = Nd4j.create(new double[] {11, 21, 31}, new int[]{1,3}); - INDArray expF3 = Nd4j.create(new double[] {100, 200, 300, 400, 500}); - INDArray expL3 = Nd4j.create(new double[] {101, 201, 301, 401, 501}); + INDArray expF3 = Nd4j.create(new double[] {100, 200, 300, 400, 500}, new int[]{1,5}); + INDArray expL3 = Nd4j.create(new double[] {101, 201, 301, 401, 501}, new int[]{1,5}); assertEquals(expF1, f.get(point(0), all(), NDArrayIndex.interval(expOffsetSeq1, expOffsetSeq1 + 1))); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/RandomDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/RandomDataSetIteratorTest.java index 0975022ee164..9ddb6abe88f9 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/RandomDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/RandomDataSetIteratorTest.java @@ -44,7 +44,7 @@ public void testDSI(){ assertArrayEquals(new long[]{3,5}, ds.getLabels().shape()); assertTrue(ds.getFeatures().minNumber().doubleValue() >= 0.0 && ds.getFeatures().maxNumber().doubleValue() <= 1.0); - assertEquals(Nd4j.ones(3,1), ds.getLabels().sum(1)); + assertEquals(Nd4j.ones(3), ds.getLabels().sum(1)); } assertEquals(5, count); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java index 9ce343ab96d6..e467bcda4bd8 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java @@ -168,7 +168,7 @@ public void testTimeTermination() { EarlyStoppingModelSaver saver = new InMemoryModelSaver<>(); EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder() .epochTerminationConditions(new MaxEpochsTerminationCondition(10000)) - .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(10, TimeUnit.SECONDS), + .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(5, TimeUnit.SECONDS), new MaxScoreIterationTerminationCondition(50)) //Initial score is ~8 .scoreCalculator(new DataSetLossCalculator(irisIter, true)) .modelSaver(saver).build(); @@ -184,7 +184,7 @@ public void testTimeTermination() { assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, result.getTerminationReason()); - String expDetails = new MaxTimeIterationTerminationCondition(10, TimeUnit.SECONDS).toString(); + String expDetails = new MaxTimeIterationTerminationCondition(5, TimeUnit.SECONDS).toString(); assertEquals(expDetails, result.getTerminationDetails()); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalJsonTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalJsonTest.java index 955b82f33f1c..4f9bb11f5a71 100644 --- 
a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalJsonTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalJsonTest.java @@ -73,7 +73,7 @@ public void testSerde() { evalLabel.putScalar(i, i % 3, 1.0); } INDArray evalProb = Nd4j.rand(10, 3); - evalProb.diviColumnVector(evalProb.sum(1)); + evalProb.diviColumnVector(evalProb.sum(true,1)); evaluation.eval(evalLabel, evalProb); roc3.eval(evalLabel, evalProb); ec.eval(evalLabel, evalProb); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java index 1e37d5c5db56..29ee66c46fc5 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java @@ -14,6 +14,7 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.NoOp; @@ -70,6 +71,7 @@ public void testSelfAttentionLayer() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .activation(Activation.TANH) .updater(new NoOp()) .weightInit(WeightInit.XAVIER) @@ -134,6 +136,7 @@ public void testLearnedSelfAttentionLayer() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .activation(Activation.TANH) .updater(new NoOp()) .weightInit(WeightInit.XAVIER) @@ -167,6 +170,7 @@ public void testRecurrentAttentionLayer_differingTimeSteps(){ int layerSize = 8; MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .activation(Activation.IDENTITY) .updater(new NoOp()) .weightInit(WeightInit.XAVIER) @@ -233,6 +237,7 @@ public void testRecurrentAttentionLayer() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .activation(Activation.IDENTITY) .updater(new NoOp()) .weightInit(WeightInit.XAVIER) @@ -294,6 +299,7 @@ public void testAttentionVertex() { ComputationGraphConfiguration graph = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .activation(Activation.TANH) .updater(new NoOp()) .weightInit(WeightInit.XAVIER) @@ -360,6 +366,7 @@ public void testAttentionVertexSameInput() { ComputationGraphConfiguration graph = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .activation(Activation.TANH) .updater(new NoOp()) .weightInit(WeightInit.XAVIER) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java index 35f30d793044..11b3ee0d0cfd 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java @@ -80,6 +80,7 @@ public void testGradient2dSimple() { MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp()) + .dataType(DataType.DOUBLE) .seed(12345L) .dist(new NormalDistribution(0, 1)).list() .layer(0, new 
DenseLayer.Builder().nIn(4).nOut(3) @@ -125,6 +126,7 @@ public void testGradientCnnSimple() { for(boolean useLogStd : new boolean[]{true, false}) { MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()).seed(12345L) .dist(new NormalDistribution(0, 2)).list() .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2) @@ -201,6 +203,7 @@ public void testGradientBNWithCNNandSubsamplingcCnfigurableProfiler() { Activation outputActivation = outputActivations[i]; MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .l2(l2vals[j]) .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) .updater(new NoOp()) @@ -310,6 +313,7 @@ public void testGradientBNWithCNNandSubsampling() { Activation outputActivation = outputActivations[i]; MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .l2(l2vals[j]) .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) .updater(new NoOp()) @@ -419,6 +423,7 @@ public void testGradientDense() { MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .l2(l2vals[j]) .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT) .updater(new NoOp()) @@ -495,6 +500,7 @@ public void testGradient2dFixedGammaBeta() { for(boolean useLogStd : new boolean[]{true, false}) { MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp()) + .dataType(DataType.DOUBLE) .seed(12345L) .dist(new NormalDistribution(0, 1)).list() .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.IDENTITY).build()) @@ -540,6 +546,7 @@ public void testGradientCnnFixedGammaBeta() { for(boolean useLogStd : new boolean[]{true, false}) { MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp()) + .dataType(DataType.DOUBLE) .seed(12345L) .dist(new NormalDistribution(0, 2)).list() .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2) @@ -584,6 +591,7 @@ public void testBatchNormCompGraphSimple() { for(boolean useLogStd : new boolean[]{true, false}) { ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).updater(new NoOp()) + .dataType(DataType.DOUBLE) .weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in") .setInputTypes(InputType.convolutional(height, width, channels)) .addLayer("bn", new BatchNormalization.Builder().useLogStd(useLogStd).build(), "in") @@ -655,6 +663,7 @@ public void testGradientBNWithCNNandSubsamplingCompGraph() { Activation outputActivation = outputActivations[i]; ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) .updater(new NoOp()) .dist(new UniformDistribution(-2, 2)).seed(12345L).graphBuilder() diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java index f375b3684522..64748f9329fd 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java @@ 
-81,6 +81,7 @@ public void testCnn1DWithLocallyConnected1D() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list() .layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel) @@ -159,6 +160,7 @@ public void testCnn1DWithCropping1D() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list() .layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel) @@ -239,6 +241,7 @@ public void testCnn1DWithZeroPadding1D() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list() .layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel) @@ -317,6 +320,7 @@ public void testCnn1DWithSubsampling1D() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list() .layer(0, new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel) @@ -379,6 +383,7 @@ public void testCnn1dWithMasking(){ Nd4j.getRandom().setSeed(12345); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(Activation.TANH) .dist(new NormalDistribution(0, 1)).convolutionMode(cm) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java index 72b0cd6be567..f23e965a156d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java @@ -109,6 +109,7 @@ public void testCnn3DPlain() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL) .dist(new NormalDistribution(0, 1)) .list() @@ -212,6 +213,7 @@ public void testCnn3DZeroPadding() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL) .dist(new NormalDistribution(0, 1)) .list() @@ -307,6 +309,7 @@ public void testCnn3DPooling() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .weightInit(WeightInit.XAVIER) .dist(new NormalDistribution(0, 1)) @@ -395,6 +398,7 @@ public void testCnn3DUpsampling() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL) .dist(new NormalDistribution(0, 1)) .list() @@ -489,6 +493,7 @@ public void testCnn3DCropping() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL) .dist(new NormalDistribution(0, 1)) .list() diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java 
b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java index cbaca38be48b..c13923c3fa57 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java @@ -86,6 +86,7 @@ public void testGradientCNNMLN() { Activation outputActivation = outputActivations[i]; MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).updater(new NoOp()) .weightInit(WeightInit.XAVIER).seed(12345L).list() .layer(0, new ConvolutionLayer.Builder(1, 1).nOut(6).activation(afn).build()) @@ -169,6 +170,7 @@ public void testGradientCNNL1L2MLN() { double l1 = l1vals[k]; MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .l2(l2).l1(l1).l2Bias(biasL2[k]).l1Bias(biasL1[k]) .optimizationAlgo( OptimizationAlgorithm.CONJUGATE_GRADIENT) @@ -255,6 +257,7 @@ public void testCnnWithSpaceToDepth() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1)) .list().layer(new ConvolutionLayer.Builder(kernel).nIn(inputDepth).hasBias(false) @@ -317,6 +320,7 @@ public void testCnnWithSpaceToBatch() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()).weightInit(new NormalDistribution(0, 1)) .list().layer(new ConvolutionLayer.Builder(kernel).nIn(inputDepth) .nOut(3).build())//output: (5-2+0)/1+1 = 4 @@ -381,6 +385,7 @@ public void testCnnWithUpsampling() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1)) .list().layer(new ConvolutionLayer.Builder(kernel, @@ -449,7 +454,7 @@ public void testCnnWithSubsampling() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()) - + .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 1)) .list().layer(0, new ConvolutionLayer.Builder(kernel, @@ -519,6 +524,7 @@ public void testCnnWithSubsamplingV2() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()) + .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 1)) .list().layer(0, new ConvolutionLayer.Builder(kernel, @@ -577,6 +583,7 @@ public void testCnnLocallyConnected2D() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new NoOp()) + .dataType(DataType.DOUBLE) .activation(afn) .list() .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1) @@ -643,6 +650,7 @@ public void testCnnMultiLayer() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new NoOp()) + .dataType(DataType.DOUBLE) .activation(afn) .list() .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1) @@ -707,6 +715,7 @@ public void testCnnSamePaddingMode() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(Activation.TANH).convolutionMode(Same).list() .layer(0, new ConvolutionLayer.Builder().name("layer 0").kernelSize(k, k) @@ -776,6 +785,7 @@ public void testCnnSamePaddingModeStrided() { .stride(stride, stride).padding(0, 0).build(); 
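// Illustrative aside (not part of the patch): most of the remaining hunks in these test classes
// simply add .dataType(DataType.DOUBLE) to the NeuralNetConfiguration builders used by the
// gradient-check tests, because numerical gradient checking compares analytic and numeric
// gradients and needs double precision to stay within tolerance. A minimal sketch of the pattern,
// assuming a DL4J version whose NeuralNetConfiguration.Builder exposes dataType(...); the class
// name and layer sizes here are arbitrary:
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class GradientCheckConfigSketch {
    public static MultiLayerConfiguration build() {
        return new NeuralNetConfiguration.Builder()
                .dataType(DataType.DOUBLE)   // FP64 parameters and activations for the gradient check
                .updater(new NoOp())         // no parameter updates while gradients are compared
                .seed(12345)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH).build())
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nIn(3).nOut(3).activation(Activation.SOFTMAX).build())
                .build();
    }
}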
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(Activation.TANH).convolutionMode(Same).list() .layer(0, convFirst ? convLayer : poolLayer) @@ -839,6 +849,7 @@ public void testCnnZeroPaddingLayer() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()) + .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 1)).list() .layer(0, new ConvolutionLayer.Builder(kernel, stride, padding) .nIn(inputDepth).nOut(3).build())//output: (6-2+0)/1+1 = 5 @@ -916,6 +927,7 @@ public void testDeconvolution2D() { } NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(act) .list() @@ -990,6 +1002,7 @@ public void testSeparableConv2D() { } NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(Activation.TANH) .convolutionMode(cm) @@ -1062,6 +1075,7 @@ public void testCnnDilated() { } NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(Activation.TANH).convolutionMode(cm).list() .layer(new ConvolutionLayer.Builder().name("layer 0") @@ -1136,6 +1150,7 @@ public void testCropping2DLayer() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .convolutionMode(ConvolutionMode.Same) .weightInit(new NormalDistribution(0, 1)).list() @@ -1211,6 +1226,7 @@ public void testDepthwiseConv2D() { } NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(Activation.TANH) .convolutionMode(cm) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java index 62fd64878cb6..d692920282eb 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java @@ -33,6 +33,7 @@ import org.deeplearning4j.nn.weights.WeightInitDistribution; import org.junit.Test; import org.nd4j.linalg.activations.impl.ActivationSoftmax; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.NoOp; @@ -77,6 +78,7 @@ public void testCapsNet() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .seed(123) .updater(new NoOp()) .weightInit(new WeightInitDistribution(new UniformDistribution(-6, 6))) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java index de205ff200a8..b511d36176d5 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java @@ -91,7 +91,7 @@ public void testDropoutGradient() { } NeuralNetConfiguration.ListBuilder builder 
= new NeuralNetConfiguration.Builder() - + .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0,1)) .convolutionMode(ConvolutionMode.Same) .dropOut(dropout) @@ -148,7 +148,7 @@ public void testCompGraphMultiInput(){ int mb = 3; ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() - + .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0,1)) .convolutionMode(ConvolutionMode.Same) .dropOut(new GaussianDropout(0.1)) //0.33 stdev. Gaussian dropout: out = in * N(1,stdev) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java index e7d82fc6a8d7..a107b450add8 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java @@ -70,6 +70,7 @@ public void testLSTMGlobalPoolingBasicMultiLayer() { for (PoolingType pt : poolingTypes) { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1.0)).seed(12345L).list() .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH) @@ -133,6 +134,7 @@ public void testCnnGlobalPoolingBasicMultiLayer() { for (PoolingType pt : poolingTypes) { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1.0)).seed(12345L).list() .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nOut(layerDepth) @@ -188,6 +190,7 @@ public void testLSTMWithMasking() { for (PoolingType pt : poolingTypes) { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1.0)).seed(12345L).list() .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH) @@ -273,6 +276,7 @@ public void testCnnGlobalPoolingMasking() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1.0)).convolutionMode(ConvolutionMode.Same) .seed(12345L).list() diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java index 1dbbd1d316d6..0f83e38566ad 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java @@ -72,6 +72,7 @@ public void testMinibatchApplication() { IrisDataSetIterator iter = new IrisDataSetIterator(30, 150); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().miniBatch(false) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new NoOp()) .list() .layer(0, @@ -161,6 +162,7 @@ public void testGradientMLP2LayerIrisSimple() { Activation outputActivation = outputActivations[i]; MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).updater(new 
NoOp()) .seed(12345L) .list().layer(0, @@ -250,6 +252,7 @@ public void testGradientMLP2LayerIrisL1L2Simple() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l2(l2).l1(l1) + .dataType(DataType.DOUBLE) .l2Bias(biasL2[k]).l1Bias(biasL1[k]) .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT) .seed(12345L) @@ -320,6 +323,7 @@ public void testEmbeddingLayerPreluSimple() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l2(0.2).l1(0.1) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345L) .list().layer(new EmbeddingLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) .updater(new NoOp()).build()) @@ -357,6 +361,7 @@ public void testEmbeddingLayerSimple() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l2(0.2).l1(0.1) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345L) .list().layer(0, new EmbeddingLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) @@ -423,6 +428,7 @@ public void testAutoEncoder() { Nd4j.getRandom().setSeed(12345); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .l2(l2).l1(l1) .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT) @@ -483,6 +489,7 @@ public void elementWiseMultiplicationLayerTest(){ for(Activation a : new Activation[]{Activation.IDENTITY, Activation.TANH}) { ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).updater(new NoOp()) .seed(12345L) .weightInit(new UniformDistribution(0, 1)) @@ -551,6 +558,7 @@ public void testEmbeddingSequenceLayer(){ for(int inputRank : new int[]{2,3}) { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .seed(12345) .updater(new NoOp()) .weightInit(new NormalDistribution(0, 1)) @@ -657,6 +665,7 @@ public void testGradientWeightDecay() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l2(l2).l1(l1) + .dataType(DataType.DOUBLE) .l2Bias(biasL2[k]).l1Bias(biasL1[k]) .weightDecay(wdVals[k]).weightDecayBias(wdBias[k]) .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT) @@ -718,6 +727,7 @@ public void testGradientMLP2LayerIrisLayerNorm() { Activation outputActivation = outputActivations[i]; MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).updater(new NoOp()) .seed(12345L) .list().layer(0, diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java index aa5a870eb794..d0c6094be735 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java @@ -71,6 +71,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { public void testBasicIris() { Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new 
NormalDistribution(0, 1)).updater(new NoOp()) .graphBuilder().addInputs("input") @@ -117,6 +118,7 @@ public void testBasicIris() { public void testBasicIrisWithMerging() { Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)).updater(new NoOp()) .graphBuilder().addInputs("input") @@ -174,6 +176,7 @@ public void testBasicIrisWithElementWiseNode() { Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) .updater(new NoOp()).graphBuilder().addInputs("input") @@ -232,6 +235,7 @@ public void testBasicIrisWithElementWiseNodeInputSizeGreaterThanTwo() { Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) .updater(new NoOp()).graphBuilder().addInputs("input") @@ -287,6 +291,7 @@ public void testCnnDepthMerge() { Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 0.1)) .updater(new NoOp()).graphBuilder().addInputs("input") @@ -334,6 +339,7 @@ public void testLSTMWithMerging() { Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new UniformDistribution(0.2, 0.6)) .updater(new NoOp()).graphBuilder().addInputs("input") @@ -394,6 +400,7 @@ public void testLSTMWithMerging() { public void testLSTMWithSubset() { Nd4j.getRandom().setSeed(1234); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(1234) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) .updater(new NoOp()).graphBuilder().addInputs("input").setOutputs("out") @@ -436,6 +443,7 @@ public void testLSTMWithLastTimeStepVertex() { Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) .updater(new NoOp()).graphBuilder().addInputs("input").setOutputs("out") @@ -489,6 +497,7 @@ public void testLSTMWithDuplicateToTimeSeries() { Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) .updater(new NoOp()).graphBuilder() @@ -543,6 +552,7 @@ public void testLSTMWithReverseTimeSeriesVertex() { Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) .updater(new NoOp()).graphBuilder() @@ -606,6 +616,7 @@ public void 
testMultipleInputsLayer() { Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) .updater(new NoOp()).activation(Activation.TANH).graphBuilder().addInputs("i0", "i1", "i2") @@ -649,6 +660,7 @@ public void testMultipleInputsLayer() { public void testMultipleOutputsLayer() { Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) .updater(new NoOp()).activation(Activation.TANH).graphBuilder().addInputs("i0") @@ -689,6 +701,7 @@ public void testMultipleOutputsLayer() { public void testMultipleOutputsMergeVertex() { Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) .updater(new NoOp()).activation(Activation.TANH).graphBuilder().addInputs("i0", "i1", "i2") @@ -737,6 +750,7 @@ public void testMultipleOutputsMergeCnn() { Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) .updater(new NoOp()).activation(Activation.TANH).graphBuilder().addInputs("input") @@ -787,6 +801,7 @@ public void testBasicIrisTripletStackingL2Loss() { Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) .updater(new NoOp()).graphBuilder() @@ -861,6 +876,7 @@ public void testBasicCenterLoss() { for (double lambda : new double[] {0.0, 0.5, 2.0}) { ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new GaussianDistribution(0, 1)) .updater(new NoOp()).graphBuilder().addInputs("input1") @@ -926,6 +942,7 @@ public void testCnnPoolCenterLoss() { for (double lambda : new double[] {0.0, 0.5, 2.0}) { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1.0)).seed(12345L).list() .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nOut(3).build()) @@ -979,6 +996,7 @@ public void testCnnPoolCenterLoss() { public void testBasicL2() { Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) .activation(Activation.TANH).updater(new NoOp()).graphBuilder() @@ -1031,6 +1049,7 @@ public void testBasicStackUnstack() { Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) .activation(Activation.TANH).updater(new 
NoOp()).graphBuilder() @@ -1086,6 +1105,7 @@ public void testBasicStackUnstackDebug() { Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) .activation(Activation.TANH).updater(new NoOp()).graphBuilder() @@ -1146,6 +1166,7 @@ public void testBasicStackUnstackVariableLengthTS() { Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) .activation(Activation.TANH).updater(new NoOp()).graphBuilder() @@ -1209,6 +1230,7 @@ public void testBasicTwoOutputs() { Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) .activation(Activation.TANH).updater(new NoOp()).graphBuilder() @@ -1265,6 +1287,7 @@ public void testL2NormalizeVertex2d() { Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) .activation(Activation.TANH).updater(new NoOp()).graphBuilder() @@ -1312,6 +1335,7 @@ public void testL2NormalizeVertex4d() { int dIn = 2; ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) .activation(Activation.TANH).updater(new NoOp()).graphBuilder() @@ -1364,6 +1388,7 @@ public void testGraphEmbeddingLayerSimple() { } ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().l2(0.2).l1(0.1) + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345L) .updater(new NoOp()).graphBuilder().addInputs("in") .addLayer("0", new EmbeddingLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java index 1dd9095341ac..91737d6fe8b0 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java @@ -43,8 +43,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import static org.nd4j.linalg.indexing.NDArrayIndex.all; -import static org.nd4j.linalg.indexing.NDArrayIndex.point; +import static org.nd4j.linalg.indexing.NDArrayIndex.*; /**Gradient checking tests with masking (i.e., variable length time series inputs, one-to-many and many-to-one etc) */ @@ -128,6 +127,7 @@ public void gradientCheckMaskingOutputSimple() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L) + .dataType(DataType.DOUBLE) .list() .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize) .dist(new NormalDistribution(0, 1)) @@ -171,6 +171,7 @@ public 
void testBidirectionalLSTMMasking() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .updater(new NoOp()) + .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 1.0)).seed(12345L).list() .layer(0, new GravesBidirectionalLSTM.Builder().nIn(nIn).nOut(layerSize) .activation(Activation.TANH).build()) @@ -258,6 +259,7 @@ public void testPerOutputMaskingMLP() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()) + .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 1)).seed(12345) .list() .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH) @@ -353,6 +355,7 @@ public void testPerOutputMaskingRnn() { Nd4j.getRandom().setSeed(12345); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()) + .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 1)).seed(12345) .list() .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH) @@ -384,6 +387,7 @@ public void testPerOutputMaskingRnn() { //Check the equivalent compgraph: Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration cg = new NeuralNetConfiguration.Builder().updater(new NoOp()) + .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 2)).seed(12345) .graphBuilder().addInputs("in") .addLayer("0", new GravesLSTM.Builder().nIn(nIn).nOut(layerSize) @@ -414,6 +418,7 @@ public void testOutputLayerMasking(){ int mb = 10; int tsLength = 5; MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .weightInit(new NormalDistribution(0,2)) .updater(new NoOp()) .list() @@ -444,10 +449,10 @@ public void testOutputLayerMasking(){ continue; } - INDArray fView = f.get(point(i), all(),all()); + INDArray fView = f.get(interval(i,i,true), all(),all()); fView.assign(Nd4j.rand(fView.shape())); - INDArray lView = l.get(point(i), all()); + INDArray lView = l.get(interval(i,i,true), all()); lView.assign(TestUtils.randomOneHot(1, lView.size(1))); double score2 = net.score(new DataSet(f,l,null,lm)); @@ -464,6 +469,7 @@ public void testOutputLayerMaskingCG(){ int mb = 10; int tsLength = 5; ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .weightInit(new NormalDistribution(0,2)) .updater(new NoOp()) .graphBuilder() @@ -497,10 +503,10 @@ public void testOutputLayerMaskingCG(){ continue; } - INDArray fView = f.get(point(i), all(),all()); + INDArray fView = f.get(interval(i,i,true), all(),all()); fView.assign(Nd4j.rand(fView.shape())); - INDArray lView = l.get(point(i), all()); + INDArray lView = l.get(interval(i,i,true), all()); lView.assign(TestUtils.randomOneHot(1, lView.size(1))); double score2 = net.score(new DataSet(f,l,null,lm)); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java index 0cad498c919e..18fbcce453d1 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java @@ -69,6 +69,7 @@ public void testGradientLRNSimple() { } MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp()) + .dataType(DataType.DOUBLE) .seed(12345L) .dist(new NormalDistribution(0, 2)).list() .layer(0, new 
ConvolutionLayer.Builder().nOut(6).kernelSize(2, 2).stride(1, 1) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java index 45e6e646a3f9..f51ca0a67045 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java @@ -88,7 +88,9 @@ public void testLSTMBasicMultiLayer() { } MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345L).list() + new NeuralNetConfiguration.Builder().seed(12345L) + .dataType(DataType.DOUBLE) + .list() .layer(0, l0).layer(1, l1) .layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT) @@ -178,6 +180,7 @@ public void testGradientLSTMFull() { NeuralNetConfiguration.Builder conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .seed(12345L) .dist(new NormalDistribution(0, 1)).updater(new NoOp()); @@ -266,6 +269,7 @@ public void testGradientLSTMEdgeCases() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L) + .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 1)) .updater(new NoOp()).list().layer(0, layer) .layer(1, new RnnOutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX) @@ -343,6 +347,7 @@ public void testGradientGravesBidirectionalLSTMFull() { conf.l1Bias(biasL1[k]); MultiLayerConfiguration mlc = conf.seed(12345L) + .dataType(DataType.DOUBLE) .list().layer(0, new GravesBidirectionalLSTM.Builder().nIn(nIn).nOut(layerSize) @@ -411,6 +416,7 @@ public void testGradientGravesBidirectionalLSTMEdgeCases() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L) + .dataType(DataType.DOUBLE) .list() .layer(0, new GravesBidirectionalLSTM.Builder().nIn(nIn).nOut(layerSize) @@ -459,6 +465,7 @@ public void testGradientCnnFfRnn() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()).seed(12345) + .dataType(DataType.DOUBLE) .dist(new UniformDistribution(-2, 2)).list() .layer(0, new ConvolutionLayer.Builder(5, 5).nIn(3).nOut(5).stride(1, 1) .activation(Activation.TANH).build()) //Out: (10-5)/1+1 = 6 -> 6x6x5 diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java index 3f95ec7af5e7..bf06551a17cc 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java @@ -167,6 +167,7 @@ public void lossFunctionGradientCheck() { Nd4j.getRandom().setSeed(12345); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345) .updater(new NoOp()) .dist(new UniformDistribution(-2, 2)).list() @@ -328,6 +329,7 @@ public void lossFunctionGradientCheckLossLayer() { } Nd4j.getRandom().setSeed(12345); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345) .updater(new NoOp()) .dist(new 
UniformDistribution(-2, 2)).list() @@ -615,6 +617,7 @@ public void lossFunctionWeightedGradientCheck() { Nd4j.getRandom().setSeed(12345); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345) .updater(new NoOp()) // .dist(new UniformDistribution(-3, 3)) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java index 318b1419862f..f28eb0c366f7 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java @@ -66,6 +66,7 @@ public void testGradientNoBiasDenseOutput() { for (boolean outHasBias : new boolean[]{true, false}) { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .seed(12345L) .list() @@ -136,6 +137,7 @@ public void testGradientNoBiasRnnOutput() { for (boolean rnnOutHasBias : new boolean[]{true, false}) { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .seed(12345L) .list() @@ -196,6 +198,7 @@ public void testGradientNoBiasEmbedding() { for (boolean embeddingHasBias : new boolean[]{true, false}) { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .seed(12345L) .list() @@ -262,7 +265,7 @@ public void testCnnWithSubsamplingNoBias() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()) - + .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 1)) .list() .layer(new ConvolutionLayer.Builder(kernel, diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java index dea0aad7a34e..2552b6072b7d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java @@ -105,6 +105,7 @@ public void testRnnLossLayer() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L) + .dataType(DataType.DOUBLE) .updater(new NoOp()) .list() .layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH) @@ -210,6 +211,7 @@ public void testCnnLossLayer() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L) + .dataType(DataType.DOUBLE) .updater(new NoOp()) .convolutionMode(ConvolutionMode.Same) .list() @@ -356,6 +358,7 @@ public void testCnn3dLossLayer() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L) + .dataType(DataType.DOUBLE) .updater(new NoOp()) .convolutionMode(ConvolutionMode.Same) .list() diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java index ae5cdaea8749..9fcc0fe5649b 100644 --- 
a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java @@ -99,6 +99,7 @@ public void testBidirectionalWrapper() { System.out.println("Starting test: " + name); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .weightInit(WeightInit.XAVIER) .list() @@ -173,6 +174,7 @@ public void testSimpleRnn() { System.out.println("Starting test: " + name); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .weightInit(WeightInit.XAVIER) .activation(Activation.TANH) @@ -245,6 +247,7 @@ public void testLastTimeStepLayer(){ } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .activation(Activation.TANH) .updater(new NoOp()) .weightInit(WeightInit.XAVIER) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java index 6d6e21ec26ab..fc62be1fa465 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java @@ -167,7 +167,7 @@ public void testMaskLayer() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .updater(new NoOp()) .activation(Activation.TANH) - + .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0,2)) .list() .layer(l1) @@ -199,6 +199,7 @@ public void testFrozenWithBackprop(){ for( int minibatch : new int[]{1,5}) { MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .seed(12345) .updater(Updater.NONE) .list() diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java index b5f3bc76013a..b4b679beb8e0 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java @@ -95,6 +95,7 @@ public void testVaeAsMLP() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l2(l2).l1(l1) + .dataType(DataType.DOUBLE) .updater(new NoOp()) .l2Bias(biasL2[i]).l1Bias(biasL1[i]) .updater(new NoOp()).seed(12345L).list() @@ -174,6 +175,7 @@ public void testVaePretrain() { Activation pxzAfn = pxzAfns[i]; MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l2(l2) + .dataType(DataType.DOUBLE) .l1(l1).l2Bias(biasL2[i]).l1Bias(biasL1[i]).updater(new NoOp()) .seed(12345L).weightInit(WeightInit.XAVIER).list() .layer(0, new VariationalAutoencoder.Builder().nIn(4).nOut(3) @@ -262,6 +264,7 @@ public void testVaePretrainReconstructionDistributions() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l2(0.2).l1(0.3) + .dataType(DataType.DOUBLE) .updater(new NoOp()) .seed(12345L).dist(new NormalDistribution(0, 1)) .list().layer(0, @@ -307,6 +310,7 @@ public void testVaePretrainMultipleSamples() { INDArray features = Nd4j.rand(minibatch, 4); MultiLayerConfiguration conf = new 
NeuralNetConfiguration.Builder().l2(0.2).l1(0.3) + .dataType(DataType.DOUBLE) .updater(new NoOp()) .seed(12345L).weightInit(WeightInit.XAVIER).list() .layer(0, new VariationalAutoencoder.Builder().nIn(4).nOut(3).encoderLayerSizes(5, 6) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java index 3fb58b56ded9..9ee04bb8b2df 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java @@ -103,6 +103,7 @@ public void testYoloOutputLayer() { INDArray labels = yoloLabels(mb, c, h, w); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(a) .l1(l1[i]).l2(l2[i]) @@ -209,6 +210,7 @@ public void yoloGradientCheckRealData() throws Exception { iter.setPreProcessor(new ImagePreProcessingScaler()); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .convolutionMode(ConvolutionMode.Same) .updater(new NoOp()) .dist(new GaussianDistribution(0,0.1)) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/ComputationGraphConfigurationTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/ComputationGraphConfigurationTest.java index 6596df0314d6..40c8e4d637ff 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/ComputationGraphConfigurationTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/ComputationGraphConfigurationTest.java @@ -40,6 +40,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.junit.Test; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; @@ -336,7 +337,7 @@ public int maxVertexInputs() { @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, - INDArray paramsView, boolean initializeParams) { + INDArray paramsView, boolean initializeParams, DataType networkDatatype) { throw new UnsupportedOperationException("Not supported"); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/NeuralNetConfigurationTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/NeuralNetConfigurationTest.java index 8cded1ffb4d1..0857ffded64d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/NeuralNetConfigurationTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/NeuralNetConfigurationTest.java @@ -120,7 +120,7 @@ public void testRNG() { long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer model = conf.getLayer().instantiate(conf, null, 0, params, true); + Layer model = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); INDArray modelWeights = model.getParam(DefaultParamInitializer.WEIGHT_KEY); @@ -131,7 +131,7 @@ public void testRNG() { long numParams2 = 
conf2.getLayer().initializer().numParams(conf); INDArray params2 = Nd4j.create(1, numParams); - Layer model2 = conf2.getLayer().instantiate(conf2, null, 0, params2, true); + Layer model2 = conf2.getLayer().instantiate(conf2, null, 0, params2, true, params.dataType()); INDArray modelWeights2 = model2.getParam(DefaultParamInitializer.WEIGHT_KEY); assertEquals(modelWeights, modelWeights2); @@ -209,7 +209,7 @@ private static Layer getLayer(int nIn, int nOut, IWeightInit weightInit, boolean NeuralNetConfiguration conf = getConfig(nIn, nOut, weightInit, preTrain); long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - return conf.getLayer().instantiate(conf, null, 0, params, true); + return conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java index 241b64716470..6cc891488b86 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java @@ -526,7 +526,7 @@ public void testSpatialDropoutValues3D(){ for( int j=0; j<8; j++ ){ double value = out.getDouble(i,j,0); assertTrue( value == 0 || value == 2.0); - INDArray exp = Nd4j.valueArrayOf(new int[]{1,12}, value); + INDArray exp = Nd4j.valueArrayOf(new int[]{12}, value); INDArray act = out.get(point(i), point(j), all()); assertEquals(exp, act); @@ -555,7 +555,7 @@ public void testSpatialDropoutValues3D(){ for( int j=0; j<8; j++ ){ double value = out.getDouble(m,j,0); assertTrue( value == 0 || value == 2.0); - INDArray exp = Nd4j.valueArrayOf(new int[]{1, 12}, value); + INDArray exp = Nd4j.valueArrayOf(new int[]{12}, value); INDArray act = out.get(point(m), point(j), all()); assertEquals(exp, act); @@ -577,7 +577,7 @@ public void testSpatialDropoutValues3D(){ for( int j=0; j<8; j++ ){ double value = out.getDouble(m,j,0); assertTrue( value == 0 || value == 10.0); - INDArray exp = Nd4j.valueArrayOf(new int[]{1,12}, value); + INDArray exp = Nd4j.valueArrayOf(new int[]{12}, value); INDArray act = out.get(point(m), point(j), all()); assertEquals(exp, act); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java index 25667600221b..c66c8878e680 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java @@ -31,6 +31,7 @@ import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.activations.impl.ActivationTanH; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.distribution.impl.UniformDistribution; import org.nd4j.linalg.factory.Nd4j; @@ -192,6 +193,7 @@ public void testElementWiseVertexFullAdd() { int midsz = 13; int outputsz = 11; ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + .dataType(DataType.DOUBLE) .biasInit(0.0).updater(new Sgd()) 
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() .addInputs("input1", "input2", "input3") @@ -367,6 +369,7 @@ public void testElementWiseVertexFullProduct() { int midsz = 13; int outputsz = 11; ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + .dataType(DataType.DOUBLE) .biasInit(0.0).updater(new Sgd()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() .addInputs("input1", "input2", "input3") @@ -541,6 +544,7 @@ public void testElementWiseVertexFullSubtract() { int midsz = 13; int outputsz = 11; ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + .dataType(DataType.DOUBLE) .biasInit(0.0).updater(new Sgd()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() .addInputs("input1", "input2") diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java index 88b6d6b91e04..e2ac5542ecf5 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java @@ -32,6 +32,7 @@ import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.activations.impl.ActivationTanH; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Sgd; @@ -138,6 +139,7 @@ public void testComprehensive() { {0.55, 0.60, 0.65, 0.70, 0.75}, {0.80, 0.85, 0.90, 0.95, 0.99}}); ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + .dataType(DataType.DOUBLE) .updater(new Sgd(0.01)) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() .addInputs("input") @@ -176,7 +178,7 @@ public void testComprehensive() { // First things first, let's calculate the score. // FIXME: int cast int batchsz = (int) input.shape()[0]; - INDArray z = input.mmul(W).add(b.repmat(batchsz, 1)); + INDArray z = input.castTo(W.dataType()).mmul(W).add(b.repmat(batchsz, 1)); INDArray a = a1.getActivation(z.dup(), true).add(sf); // activation modifies it's input!! INDArray q = a.mmul(V).add(c.repmat(batchsz, 1)); INDArray o = nullsafe(a2.getActivation(q.dup(), true)); @@ -196,8 +198,8 @@ public void testComprehensive() { * dq1/dv11 = a1 dq2/dV12 = a1 dq3/dV13 = a1 ... * dq1/dv21 = a2 dq2... */ - INDArray dEdo = Nd4j.zeros(target.shape()); - dEdo.addi(o).subi(target).muli(2); // This should be of size batchsz x outputsz + INDArray dEdo = target.like(); //Nd4j.zeros(target.shape()); + dEdo.addi(o.castTo(dEdo.dataType())).subi(target).muli(2); // This should be of size batchsz x outputsz dEdo.divi(target.shape()[1]); // Why? Because the LossFunction divides by the _element size_ of the output. 
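// --- Illustrative sketch, not part of the patch: the explicit casts and the target.like() call above
// follow one pattern used throughout this PR. With the network pinned to DOUBLE while the global default
// dtype may be FLOAT or HALF, arrays built from Nd4j defaults must be created with (or cast to) a matching
// dtype before they are combined; like() is preferred over Nd4j.zeros(shape) because the result inherits
// both the shape and the dtype of the source array. Assumes the INDArray/Nd4j/DataType imports already
// present in ShiftVertexTest; all names below are made up for illustration.
INDArray wSketch = Nd4j.rand(DataType.DOUBLE, 3, 4);                   // "network" parameters, double
INDArray xSketch = Nd4j.rand(DataType.FLOAT, 2, 3);                    // input created under a float default
INDArray zSketch = xSketch.castTo(wSketch.dataType()).mmul(wSketch);   // cast first, then mmul
INDArray tSketch = Nd4j.rand(DataType.DOUBLE, 2, 4);
INDArray gSketch = tSketch.like();                                     // zeros with tSketch's shape and dtype
gSketch.addi(zSketch.castTo(gSketch.dataType())).subi(tSketch).muli(2);
// --- end sketch ---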
Pair derivs2 = a2.backprop(q, dEdo); @@ -246,7 +248,7 @@ public void testComprehensive() { } private static double sum_errors(INDArray a, INDArray b) { - INDArray o = a.sub(b); + INDArray o = a.sub(b.castTo(a.dataType())); return o.mul(o).sumNumber().doubleValue(); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/misc/TestGraphVertex.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/misc/TestGraphVertex.java index b9999e2468f0..a69ccd33ea84 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/misc/TestGraphVertex.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/misc/TestGraphVertex.java @@ -25,6 +25,7 @@ import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @AllArgsConstructor @@ -58,7 +59,7 @@ public int maxVertexInputs() { @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, - INDArray paramsView, boolean initializeParams) { + INDArray paramsView, boolean initializeParams, DataType networkDatatype) { throw new UnsupportedOperationException("Not supported"); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java index 471149346389..e1ce77cd36bc 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java @@ -60,7 +60,7 @@ public void testRnnToFeedForwardPreProcessor() { long numParams = nnc.getLayer().initializer().numParams(nnc); INDArray params = Nd4j.create(1, numParams); - DenseLayer layer = (DenseLayer) nnc.getLayer().instantiate(nnc, null, 0, params, true); + DenseLayer layer = (DenseLayer) nnc.getLayer().instantiate(nnc, null, 0, params, true, params.dataType()); layer.setInputMiniBatchSize(miniBatchSize); INDArray activations3dc = Nd4j.create(new int[] {miniBatchSize, layerSize, timeSeriesLength}, 'c'); @@ -87,8 +87,8 @@ public void testRnnToFeedForwardPreProcessor() { //(example=0,t=0), (example=0,t=1), (example=0,t=2), ..., (example=1,t=0), (example=1,t=1), ... 
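// --- Illustrative sketch, not part of the patch: the getRow(i, true) changes below rely on the
// two-argument form keeping the leading dimension, so a row comes back as a [1, n] matrix view
// (which the assertArrayEquals(rowc.shape(), new long[]{1, layerSize}) check expects) rather than
// a rank-1 vector of length n. Assumes the Nd4j/INDArray imports already present in this test class.
INDArray mSketch = Nd4j.rand(4, 3);
INDArray rowVector = mSketch.getRow(1);        // rank-1 view, shape [3]
INDArray rowMatrix = mSketch.getRow(1, true);  // rank-2 view, shape [1, 3] - keeps the row dimension
// --- end sketch ---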
int nRows = activations2dc.rows(); for (int i = 0; i < nRows; i++) { - INDArray rowc = activations2dc.getRow(i); - INDArray rowf = activations2df.getRow(i); + INDArray rowc = activations2dc.getRow(i, true); + INDArray rowf = activations2df.getRow(i, true); assertArrayEquals(rowc.shape(), new long[] {1, layerSize}); assertEquals(rowc, rowf); @@ -98,7 +98,7 @@ public void testRnnToFeedForwardPreProcessor() { //f order reshaping int time = i / miniBatchSize; int origExampleNum = i % miniBatchSize; - INDArray expectedRow = activations3dc.tensorAlongDimension(time, 1, 0).getRow(origExampleNum); + INDArray expectedRow = activations3dc.tensorAlongDimension(time, 1, 0).getRow(origExampleNum, true); assertEquals(expectedRow, rowc); assertEquals(expectedRow, rowf); } @@ -145,7 +145,7 @@ public void testFeedForwardToRnnPreProcessor() { val numParams = nnc.getLayer().initializer().numParams(nnc); INDArray params = Nd4j.create(1, numParams); - DenseLayer layer = (DenseLayer) nnc.getLayer().instantiate(nnc, null, 0, params, true); + DenseLayer layer = (DenseLayer) nnc.getLayer().instantiate(nnc, null, 0, params, true, params.dataType()); layer.setInputMiniBatchSize(miniBatchSize); INDArray rand = Nd4j.rand(miniBatchSize * timeSeriesLength, layerSize); @@ -170,9 +170,9 @@ public void testFeedForwardToRnnPreProcessor() { int time = i / miniBatchSize; int example = i % miniBatchSize; - INDArray row2d = activations2dc.getRow(i); - INDArray row3dc = activations3dc.tensorAlongDimension(time, 0, 1).getRow(example); - INDArray row3df = activations3df.tensorAlongDimension(time, 0, 1).getRow(example); + INDArray row2d = activations2dc.getRow(i, true); + INDArray row3dc = activations3dc.tensorAlongDimension(time, 0, 1).getRow(example, true); + INDArray row3df = activations3df.tensorAlongDimension(time, 0, 1).getRow(example, true); assertEquals(row2d, row3dc); assertEquals(row2d, row3df); @@ -232,7 +232,7 @@ public void testCnnToRnnPreProcessor() { val numParams = nnc.getLayer().initializer().numParams(nnc); INDArray params = Nd4j.create(1, numParams); ConvolutionLayer layer = - (ConvolutionLayer) nnc.getLayer().instantiate(nnc, null, 0, params, true); + (ConvolutionLayer) nnc.getLayer().instantiate(nnc, null, 0, params, true, params.dataType()); layer.setInputMiniBatchSize(miniBatchSize); INDArray activationsCnn = Nd4j.rand(new int[] {miniBatchSize * timeSeriesLength, nChannels, @@ -314,7 +314,7 @@ public void testRnnToCnnPreProcessor() { val numParams = nnc.getLayer().initializer().numParams(nnc); INDArray params = Nd4j.create(1, numParams); ConvolutionLayer layer = - (ConvolutionLayer) nnc.getLayer().instantiate(nnc, null, 0, params, true); + (ConvolutionLayer) nnc.getLayer().instantiate(nnc, null, 0, params, true, params.dataType()); layer.setInputMiniBatchSize(miniBatchSize); val shape_rnn = new long[] {miniBatchSize, nChannels * inputHeight * inputWidth, diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java index b4acd2a80000..0ebc598bc68b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java @@ -34,6 +34,7 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.junit.Test; import org.nd4j.linalg.activations.Activation; +import 
org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.dataset.DataSet; @@ -259,7 +260,7 @@ public void testDropConnectValues() { INDArray outTrain = d.getParameter(l, "W", 0, 0, true, LayerWorkspaceMgr.noWorkspaces()); assertNotEquals(l.getParam("W"), outTrain); - assertEquals(l.getParam("W"), Nd4j.ones(10, 10)); + assertEquals(l.getParam("W"), Nd4j.ones(DataType.FLOAT, 10, 10)); int countZeros = Nd4j.getExecutioner().exec(new MatchCondition(outTrain, Conditions.equals(0))).getInt(0); int countOnes = Nd4j.getExecutioner().exec(new MatchCondition(outTrain, Conditions.equals(1))).getInt(0); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java new file mode 100644 index 000000000000..74e119990c78 --- /dev/null +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java @@ -0,0 +1,1259 @@ +package org.deeplearning4j.nn.dtypes; + +import com.google.common.collect.ImmutableSet; +import com.google.common.reflect.ClassPath; +import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.TestUtils; +import org.deeplearning4j.nn.conf.*; +import org.deeplearning4j.nn.conf.distribution.UniformDistribution; +import org.deeplearning4j.nn.conf.dropout.AlphaDropout; +import org.deeplearning4j.nn.conf.dropout.GaussianDropout; +import org.deeplearning4j.nn.conf.dropout.GaussianNoise; +import org.deeplearning4j.nn.conf.dropout.SpatialDropout; +import org.deeplearning4j.nn.conf.graph.*; +import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex; +import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex; +import org.deeplearning4j.nn.conf.graph.rnn.ReverseTimeSeriesVertex; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D; +import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D; +import org.deeplearning4j.nn.conf.layers.convolutional.Cropping3D; +import org.deeplearning4j.nn.conf.layers.misc.ElementWiseMultiplicationLayer; +import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer; +import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop; +import org.deeplearning4j.nn.conf.layers.misc.RepeatVector; +import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer; +import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; +import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; +import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; +import org.deeplearning4j.nn.conf.layers.util.MaskLayer; +import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; +import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; +import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; +import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer; +import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor; +import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor; +import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnn3DPreProcessor; +import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.layers.util.IdentityLayer; +import 
org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor; +import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor; +import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor; +import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.nn.weights.WeightInitDistribution; +import org.junit.AfterClass; +import org.junit.Ignore; +import org.junit.Test; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.activations.impl.ActivationSoftmax; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.dataset.MultiDataSet; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.learning.config.Adam; +import org.nd4j.linalg.learning.config.Nesterovs; +import org.nd4j.linalg.learning.config.NoOp; +import org.nd4j.linalg.lossfunctions.LossFunctions; +import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood; + +import java.io.IOException; +import java.lang.reflect.Modifier; +import java.util.*; + +import static org.junit.Assert.*; + +@Slf4j +public class DTypeTests extends BaseDL4JTest { + + protected static Set> seenLayers = new HashSet<>(); + protected static Set> seenPreprocs = new HashSet<>(); + protected static Set> seenVertices = new HashSet<>(); + + protected static Set> ignoreClasses = new HashSet<>(Arrays.>asList( + Pooling2D.class, //Alias for SubsamplingLayer + Convolution2D.class, //Alias for ConvolutionLayer + Pooling1D.class, //Alias for Subsampling1D + Convolution1D.class //Alias for Convolution1DLayer + )); + + @AfterClass + public static void after() { + ImmutableSet info; + try { + //Dependency note: this ClassPath class was added in Guava 14 + info = com.google.common.reflect.ClassPath.from(DTypeTests.class.getClassLoader()) + .getTopLevelClassesRecursive("org.deeplearning4j"); + } catch (IOException e) { + //Should never happen + throw new RuntimeException(e); + } + + Set> layerClasses = new HashSet<>(); + Set> preprocClasses = new HashSet<>(); + Set> vertexClasses = new HashSet<>(); + for (ClassPath.ClassInfo ci : info) { + Class clazz; + try { + clazz = Class.forName(ci.getName()); + } catch (ClassNotFoundException e) { + //Should never happen as this was found on the classpath + throw new RuntimeException(e); + } + + if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface()) { + continue; + } + + if(clazz.getName().toLowerCase().contains("custom") || clazz.getName().contains("samediff.testlayers") + || clazz.getName().toLowerCase().contains("test") || ignoreClasses.contains(clazz)){ + continue; + } + + if (Layer.class.isAssignableFrom(clazz)) { + layerClasses.add(clazz); + } else if (InputPreProcessor.class.isAssignableFrom(clazz)) { + preprocClasses.add(clazz); + } else if (GraphVertex.class.isAssignableFrom(clazz)) { + vertexClasses.add(clazz); + } + } + + boolean fail = false; + if (seenLayers.size() < layerClasses.size()) { + for (Class c : layerClasses) { + if (!seenLayers.contains(c) && !ignoreClasses.contains(c)) { + log.warn("Layer class not tested for global vs. 
network datatypes: {}", c); + fail = true; + } + } + } + if (seenPreprocs.size() < preprocClasses.size()) { + for (Class c : preprocClasses) { + if (!seenPreprocs.contains(c) && !ignoreClasses.contains(c)) { + log.warn("Preprocessor class not tested for global vs. network datatypes: {}", c); + fail = true; + } + } + } + if (seenVertices.size() < vertexClasses.size()) { + for (Class c : vertexClasses) { + if (!seenVertices.contains(c) && !ignoreClasses.contains(c)) { + log.warn("GraphVertex class not tested for global vs. network datatypes: {}", c); + fail = true; + } + } + } + + if (fail) { + fail("Tested " + seenLayers.size() + " of " + layerClasses.size() + " layers, " + seenPreprocs.size() + " of " + preprocClasses.size() + + " preprocessors, " + seenVertices.size() + " of " + vertexClasses.size() + " vertices"); + } + } + + public static void logUsedClasses(MultiLayerNetwork net) { + MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); + for (NeuralNetConfiguration nnc : conf.getConfs()) { + Layer l = nnc.getLayer(); + seenLayers.add(l.getClass()); + if (l instanceof BaseWrapperLayer) { + BaseWrapperLayer bwl = (BaseWrapperLayer) l; + seenLayers.add(bwl.getUnderlying().getClass()); + } else if (l instanceof Bidirectional) { + seenLayers.add(((Bidirectional) l).getFwd().getClass()); + } + } + + Map preprocs = conf.getInputPreProcessors(); + if (preprocs != null) { + for (InputPreProcessor ipp : preprocs.values()) { + seenPreprocs.add(ipp.getClass()); + } + } + } + + public static void logUsedClasses(ComputationGraph net) { + ComputationGraphConfiguration conf = net.getConfiguration(); + for(GraphVertex gv : conf.getVertices().values()){ + seenVertices.add(gv.getClass()); + if(gv instanceof LayerVertex){ + seenLayers.add(((LayerVertex) gv).getLayerConf().getLayer().getClass()); + InputPreProcessor ipp = ((LayerVertex) gv).getPreProcessor(); + if(ipp != null){ + seenPreprocs.add(ipp.getClass()); + } + } else if(gv instanceof PreprocessorVertex){ + seenPreprocs.add(((PreprocessorVertex) gv).getPreProcessor().getClass()); + } + } + + } + + @Test + public void testMultiLayerNetworkTypeConversion() { + + for (DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + Nd4j.setDefaultDataTypes(dt, dt); + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .seed(12345) + .weightInit(WeightInit.XAVIER) + .updater(new Adam(0.01)) + .dataType(DataType.DOUBLE) + .list() + .layer(new DenseLayer.Builder().activation(Activation.TANH).nIn(10).nOut(10).build()) + .layer(new DenseLayer.Builder().activation(Activation.TANH).nIn(10).nOut(10).build()) + .layer(new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + INDArray inD = Nd4j.rand(DataType.DOUBLE, 1, 10); + INDArray lD = Nd4j.create(DataType.DOUBLE, 1, 10); + net.fit(inD, lD); + + INDArray outDouble = net.output(inD); + net.setInput(inD); + net.setLabels(lD); + net.computeGradientAndScore(); + double scoreDouble = net.score(); + INDArray grads = net.getFlattenedGradients(); + INDArray u = net.getUpdater().getStateViewArray(); + assertEquals(DataType.DOUBLE, net.params().dataType()); + assertEquals(DataType.DOUBLE, grads.dataType()); + assertEquals(DataType.DOUBLE, u.dataType()); + + + MultiLayerNetwork netFloat = net.convertDataType(DataType.FLOAT); + netFloat.initGradientsView(); + assertEquals(DataType.FLOAT, 
netFloat.params().dataType()); + assertEquals(DataType.FLOAT, netFloat.getFlattenedGradients().dataType()); + assertEquals(DataType.FLOAT, netFloat.getUpdater(true).getStateViewArray().dataType()); + INDArray inF = inD.castTo(DataType.FLOAT); + INDArray lF = lD.castTo(DataType.FLOAT); + INDArray outFloat = netFloat.output(inF); + netFloat.setInput(inF); + netFloat.setLabels(lF); + netFloat.computeGradientAndScore(); + double scoreFloat = netFloat.score(); + INDArray gradsFloat = netFloat.getFlattenedGradients(); + INDArray uFloat = netFloat.getUpdater().getStateViewArray(); + + assertEquals(scoreDouble, scoreFloat, 1e-6); + assertEquals(outDouble.castTo(DataType.FLOAT), outFloat); + assertEquals(grads.castTo(DataType.FLOAT), gradsFloat); + INDArray uCast = u.castTo(DataType.FLOAT); + assertTrue(uCast.equalsWithEps(uFloat, 1e-4)); + + MultiLayerNetwork netFP16 = net.convertDataType(DataType.HALF); + netFP16.initGradientsView(); + assertEquals(DataType.HALF, netFP16.params().dataType()); + assertEquals(DataType.HALF, netFP16.getFlattenedGradients().dataType()); + assertEquals(DataType.HALF, netFP16.getUpdater(true).getStateViewArray().dataType()); + + INDArray inH = inD.castTo(DataType.HALF); + INDArray lH = lD.castTo(DataType.HALF); + INDArray outHalf = netFP16.output(inH); + netFP16.setInput(inH); + netFP16.setLabels(lH); + netFP16.computeGradientAndScore(); + double scoreHalf = netFP16.score(); + INDArray gradsHalf = netFP16.getFlattenedGradients(); + INDArray uHalf = netFP16.getUpdater().getStateViewArray(); + + assertEquals(scoreDouble, scoreHalf, 1e-4); + boolean outHalfEq = outDouble.castTo(DataType.HALF).equalsWithEps(outHalf, 1e-3); + assertTrue(outHalfEq); + boolean gradsHalfEq = grads.castTo(DataType.HALF).equalsWithEps(gradsHalf, 1e-3); + assertTrue(gradsHalfEq); + INDArray uHalfCast = u.castTo(DataType.HALF); + assertTrue(uHalfCast.equalsWithEps(uHalf, 1e-4)); + } + } + + @Test + public void testComputationGraphTypeConversion() { + + for (DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + Nd4j.setDefaultDataTypes(dt, dt); + + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + .seed(12345) + .weightInit(WeightInit.XAVIER) + .updater(new Adam(0.01)) + .dataType(DataType.DOUBLE) + .graphBuilder() + .addInputs("in") + .layer("l0", new DenseLayer.Builder().activation(Activation.TANH).nIn(10).nOut(10).build(), "in") + .layer("l1", new DenseLayer.Builder().activation(Activation.TANH).nIn(10).nOut(10).build(), "l0") + .layer("out", new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "l1") + .setOutputs("out") + .build(); + + ComputationGraph net = new ComputationGraph(conf); + net.init(); + + INDArray inD = Nd4j.rand(DataType.DOUBLE, 1, 10); + INDArray lD = Nd4j.create(DataType.DOUBLE, 1, 10); + net.fit(new DataSet(inD, lD)); + + INDArray outDouble = net.outputSingle(inD); + net.setInput(0, inD); + net.setLabels(lD); + net.computeGradientAndScore(); + double scoreDouble = net.score(); + INDArray grads = net.getFlattenedGradients(); + INDArray u = net.getUpdater().getStateViewArray(); + assertEquals(DataType.DOUBLE, net.params().dataType()); + assertEquals(DataType.DOUBLE, grads.dataType()); + assertEquals(DataType.DOUBLE, u.dataType()); + + + ComputationGraph netFloat = net.convertDataType(DataType.FLOAT); + netFloat.initGradientsView(); + assertEquals(DataType.FLOAT, netFloat.params().dataType()); + assertEquals(DataType.FLOAT, 
netFloat.getFlattenedGradients().dataType()); + assertEquals(DataType.FLOAT, netFloat.getUpdater(true).getStateViewArray().dataType()); + INDArray inF = inD.castTo(DataType.FLOAT); + INDArray lF = lD.castTo(DataType.FLOAT); + INDArray outFloat = netFloat.outputSingle(inF); + netFloat.setInput(0, inF); + netFloat.setLabels(lF); + netFloat.computeGradientAndScore(); + double scoreFloat = netFloat.score(); + INDArray gradsFloat = netFloat.getFlattenedGradients(); + INDArray uFloat = netFloat.getUpdater().getStateViewArray(); + + assertEquals(scoreDouble, scoreFloat, 1e-6); + assertEquals(outDouble.castTo(DataType.FLOAT), outFloat); + assertEquals(grads.castTo(DataType.FLOAT), gradsFloat); + INDArray uCast = u.castTo(DataType.FLOAT); + assertTrue(uCast.equalsWithEps(uFloat, 1e-4)); + + ComputationGraph netFP16 = net.convertDataType(DataType.HALF); + netFP16.initGradientsView(); + assertEquals(DataType.HALF, netFP16.params().dataType()); + assertEquals(DataType.HALF, netFP16.getFlattenedGradients().dataType()); + assertEquals(DataType.HALF, netFP16.getUpdater(true).getStateViewArray().dataType()); + + INDArray inH = inD.castTo(DataType.HALF); + INDArray lH = lD.castTo(DataType.HALF); + INDArray outHalf = netFP16.outputSingle(inH); + netFP16.setInput(0, inH); + netFP16.setLabels(lH); + netFP16.computeGradientAndScore(); + double scoreHalf = netFP16.score(); + INDArray gradsHalf = netFP16.getFlattenedGradients(); + INDArray uHalf = netFP16.getUpdater().getStateViewArray(); + + assertEquals(scoreDouble, scoreHalf, 1e-4); + boolean outHalfEq = outDouble.castTo(DataType.HALF).equalsWithEps(outHalf, 1e-3); + assertTrue(outHalfEq); + boolean gradsHalfEq = grads.castTo(DataType.HALF).equalsWithEps(gradsHalf, 1e-3); + assertTrue(gradsHalfEq); + INDArray uHalfCast = u.castTo(DataType.HALF); + assertTrue(uHalfCast.equalsWithEps(uHalf, 1e-4)); + } + } + + + @Test @Ignore //TODO JVM crash + public void testDtypesModelVsGlobalDtypeCnn() { + for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + Nd4j.setDefaultDataTypes(globalDtype, globalDtype); + for (DataType networkDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + for (int outputLayer = 0; outputLayer < 5; outputLayer++) { + assertEquals(globalDtype, Nd4j.dataType()); + assertEquals(globalDtype, Nd4j.defaultFloatingPointType()); + + String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype + ", outputLayer=" + outputLayer; + + Layer ol; + Layer secondLast; + switch (outputLayer) { + case 0: + ol = new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(); + secondLast = new GlobalPoolingLayer(PoolingType.MAX); + break; + case 1: + ol = new LossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(); + secondLast = new FrozenLayerWithBackprop(new DenseLayer.Builder().nOut(10).activation(Activation.SIGMOID).build()); + break; + case 2: + ol = new CenterLossOutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(); + secondLast = new VariationalAutoencoder.Builder().encoderLayerSizes(10).decoderLayerSizes(10).nOut(10).activation(Activation.SIGMOID).build(); + break; + case 3: + ol = new CnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(); + secondLast = new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 
1).nOut(3).activation(Activation.TANH).build(); + break; + case 4: + ol = new Yolo2OutputLayer.Builder().boundingBoxPriors(Nd4j.create(new double[][]{{1.0, 1.0}, {2.0, 2.0}}).castTo(networkDtype)).build(); + secondLast = new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nOut(14).activation(Activation.TANH).build(); + break; + default: + throw new RuntimeException(); + } + + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(networkDtype) + .convolutionMode(ConvolutionMode.Same) + .updater(new Adam(1e-2)) + .list() + .layer(new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nOut(3).activation(Activation.TANH).build()) + .layer(new LocalResponseNormalization()) + .layer(new DropoutLayer(0.5)) + .layer(new DropoutLayer(new AlphaDropout(0.5))) + .layer(new DropoutLayer(new GaussianDropout(0.5))) + .layer(new DropoutLayer(new GaussianNoise(0.1))) + .layer(new DropoutLayer(new SpatialDropout(0.5))) + .layer(new SubsamplingLayer.Builder().poolingType(SubsamplingLayer.PoolingType.AVG).kernelSize(3, 3).stride(2, 2).build()) + .layer(new Pooling2D.Builder().poolingType(SubsamplingLayer.PoolingType.AVG).kernelSize(2, 2).stride(1, 1).build()) + .layer(new Deconvolution2D.Builder().kernelSize(2, 2).stride(2, 2).nOut(3).activation(Activation.TANH).build()) +// .layer(new LocallyConnected2D.Builder().nOut(3).kernelSize(2,2).stride(1,1).activation(Activation.SIGMOID).build()) //EXCEPTION + .layer(new ZeroPaddingLayer(1, 1)) + .layer(new Cropping2D(1, 1)) + .layer(new IdentityLayer()) + .layer(new Upsampling2D.Builder().size(2).build()) + .layer(new SubsamplingLayer.Builder().kernelSize(2,2).stride(2,2).build()) + .layer(new DepthwiseConvolution2D.Builder().nOut(3).activation(Activation.RELU).build()) + .layer(new SeparableConvolution2D.Builder().nOut(3).activation(Activation.HARDTANH).build()) + .layer(new MaskLayer()) + .layer(new BatchNormalization.Builder().build()) + .layer(new ActivationLayer(Activation.LEAKYRELU)) + .layer(secondLast) + .layer(ol) + .setInputType(InputType.convolutionalFlat(28, 28, 1)) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + net.initGradientsView(); + assertEquals(msg, networkDtype, net.params().dataType()); + assertEquals(msg, networkDtype, net.getFlattenedGradients().dataType()); + assertEquals(msg, networkDtype, net.getUpdater(true).getStateViewArray().dataType()); + + INDArray in = Nd4j.rand(networkDtype, 2, 28*28); + INDArray label; + if (outputLayer < 3) { + label = TestUtils.randomOneHot(2, 10).castTo(networkDtype); + } else if (outputLayer == 3) { + //CNN loss + label = Nd4j.rand(networkDtype, 2, 3, 28, 28); + } else if (outputLayer == 4) { + //YOLO + label = Nd4j.ones(networkDtype, 2, 6, 28, 28); + } else { + throw new IllegalStateException(); + } + + INDArray out = net.output(in); + assertEquals(msg, networkDtype, out.dataType()); + List ff = net.feedForward(in); + for (int i = 0; i < ff.size(); i++) { + String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? 
"input" : net.getLayer(i - 1).conf().getLayer().getClass().getSimpleName()); + assertEquals(s, networkDtype, ff.get(i).dataType()); + } + + net.setInput(in); + net.setLabels(label); + net.computeGradientAndScore(); + + net.fit(new DataSet(in, label)); + + logUsedClasses(net); + + //Now, test mismatched dtypes for input/labels: + for (DataType inputLabelDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + log.info(msg + " - input/label type: " + inputLabelDtype); + INDArray in2 = in.castTo(inputLabelDtype); + INDArray label2 = label.castTo(inputLabelDtype); + net.output(in2); + net.setInput(in2); + net.setLabels(label2); + net.computeGradientAndScore(); + + net.fit(new DataSet(in2, label2)); + } + } + } + } + } + + @Test @Ignore //TODO JVM CRASH + public void testDtypesModelVsGlobalDtypeCnn3d() { + for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + Nd4j.setDefaultDataTypes(globalDtype, globalDtype); + for (DataType networkDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + for (int outputLayer = 0; outputLayer < 3; outputLayer++) { + assertEquals(globalDtype, Nd4j.dataType()); + assertEquals(globalDtype, Nd4j.defaultFloatingPointType()); + + String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype + ", outputLayer=" + outputLayer; + + Layer ol; + Layer secondLast; + switch (outputLayer) { + case 0: + ol = new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(); + secondLast = new GlobalPoolingLayer(PoolingType.AVG); + break; + case 1: + ol = new Cnn3DLossLayer.Builder(Convolution3D.DataFormat.NCDHW).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(); + secondLast = new Convolution3D.Builder().nOut(3).activation(Activation.ELU).build(); + break; + case 2: + ol = new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(); + secondLast = new Convolution3D.Builder().nOut(3).activation(Activation.ELU).build(); + break; + default: + throw new RuntimeException(); + } + + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(networkDtype) + .convolutionMode(ConvolutionMode.Same) + .updater(new Nesterovs(1e-2, 0.9)) + .list() + .layer(new Convolution3D.Builder().kernelSize(2, 2, 2).stride(1, 1, 1).nOut(3).activation(Activation.TANH).build()) + .layer(new Convolution3D.Builder().kernelSize(2, 2, 2).stride(1, 1, 1).nOut(3).activation(Activation.TANH).build()) + .layer(new Subsampling3DLayer.Builder().poolingType(PoolingType.AVG).kernelSize(2, 2, 2).stride(2, 2, 2).build()) + .layer(new Cropping3D.Builder(1, 1, 1, 1, 1, 1).build()) + .layer(new ZeroPadding3DLayer.Builder(1, 1, 1, 1, 1, 1).build()) + .layer(new ActivationLayer(Activation.LEAKYRELU)) + .layer(new Upsampling3D.Builder().size(2).build()) + .layer(secondLast) + .layer(ol) + .setInputType(InputType.convolutional3D(Convolution3D.DataFormat.NCDHW, 28, 28, 28, 1)) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + net.initGradientsView(); + assertEquals(msg, networkDtype, net.params().dataType()); + assertEquals(msg, networkDtype, net.getFlattenedGradients().dataType()); + assertEquals(msg, networkDtype, net.getUpdater(true).getStateViewArray().dataType()); + + INDArray in = Nd4j.rand(networkDtype, 2, 1, 28, 28, 28); + INDArray label; + if (outputLayer == 0) { + label = TestUtils.randomOneHot(2, 
10).castTo(networkDtype); + } else if(outputLayer == 1){ + //CNN3D loss + label = Nd4j.rand(networkDtype, 2, 3, 28, 28, 28); + } else if(outputLayer == 2){ + label = TestUtils.randomOneHot(2, 10).castTo(networkDtype); + } else { + throw new RuntimeException(); + } + + INDArray out = net.output(in); + assertEquals(msg, networkDtype, out.dataType()); + List ff = net.feedForward(in); + for (int i = 0; i < ff.size(); i++) { + String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? "input" : net.getLayer(i - 1).conf().getLayer().getClass().getSimpleName()); + assertEquals(s, networkDtype, ff.get(i).dataType()); + } + + net.setInput(in); + net.setLabels(label); + net.computeGradientAndScore(); + + net.fit(new DataSet(in, label)); + + logUsedClasses(net); + + //Now, test mismatched dtypes for input/labels: + for (DataType inputLabelDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + INDArray in2 = in.castTo(inputLabelDtype); + INDArray label2 = label.castTo(inputLabelDtype); + net.output(in2); + net.setInput(in2); + net.setLabels(label2); + net.computeGradientAndScore(); + + net.fit(new DataSet(in2, label2)); + } + } + } + } + } + + @Test @Ignore //TODO TEMP - crashing + public void testDtypesModelVsGlobalDtypeCnn1d() { + //Nd4jCpu.Environment.getInstance().setUseMKLDNN(false); + + for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + Nd4j.setDefaultDataTypes(globalDtype, globalDtype); + for (DataType networkDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + for (int outputLayer = 0; outputLayer < 3; outputLayer++) { + assertEquals(globalDtype, Nd4j.dataType()); + assertEquals(globalDtype, Nd4j.defaultFloatingPointType()); + + String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype + ", outputLayer=" + outputLayer; + + Layer ol; + Layer secondLast; + switch (outputLayer) { + case 0: + ol = new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(); + secondLast = new GlobalPoolingLayer(PoolingType.MAX); + break; + case 1: + ol = new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nOut(5).build(); + secondLast = new Convolution1D.Builder().kernelSize(2).nOut(5).build(); + break; + case 2: + ol = new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(); + secondLast = new Convolution1D.Builder().kernelSize(2).nOut(5).build(); + break; + default: + throw new RuntimeException(); + } + + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .trainingWorkspaceMode(WorkspaceMode.NONE) + .inferenceWorkspaceMode(WorkspaceMode.NONE) + .dataType(networkDtype) + .convolutionMode(ConvolutionMode.Same) + .updater(new Adam(1e-2)) + .list() + .layer(new Convolution1D.Builder().kernelSize(2).stride(1).nOut(3).activation(Activation.TANH).build()) + .layer(new Subsampling1DLayer.Builder().poolingType(PoolingType.MAX).kernelSize(5).stride(1).build()) + .layer(new Cropping1D.Builder(1).build()) + .layer(new ZeroPadding1DLayer(1)) + .layer(new Upsampling1D.Builder(2).build()) + .layer(secondLast) + .layer(ol) + .setInputType(InputType.recurrent(5, 10)) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + net.initGradientsView(); + assertEquals(msg, networkDtype, net.params().dataType()); + assertEquals(msg, networkDtype, net.getFlattenedGradients().dataType()); + 
assertEquals(msg, networkDtype, net.getUpdater(true).getStateViewArray().dataType()); + + INDArray in = Nd4j.rand(networkDtype, 2, 5, 10); + INDArray label; + if (outputLayer == 0) { + //OutputLayer + label = TestUtils.randomOneHot(2, 10).castTo(networkDtype); + } else { + //RnnOutputLayer, RnnLossLayer + label = Nd4j.rand(networkDtype, 2, 5, 20); //Longer sequence due to upsampling + } + + INDArray out = net.output(in); + assertEquals(msg, networkDtype, out.dataType()); + List ff = net.feedForward(in); + for (int i = 0; i < ff.size(); i++) { + String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? "input" : net.getLayer(i - 1).conf().getLayer().getClass().getSimpleName()); + assertEquals(s, networkDtype, ff.get(i).dataType()); + } + + net.setInput(in); + net.setLabels(label); + net.computeGradientAndScore(); + + net.fit(new DataSet(in, label)); + + logUsedClasses(net); + + //Now, test mismatched dtypes for input/labels: + for (DataType inputLabelDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + System.out.println(msg + " - " + inputLabelDtype); + INDArray in2 = in.castTo(inputLabelDtype); + INDArray label2 = label.castTo(inputLabelDtype); + net.output(in2); + net.setInput(in2); + net.setLabels(label2); + net.computeGradientAndScore(); + + net.fit(new DataSet(in2, label2)); + } + } + } + } + } + + @Test + public void testDtypesModelVsGlobalDtypeMisc() { + for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + Nd4j.setDefaultDataTypes(globalDtype, globalDtype); + for (DataType networkDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + assertEquals(globalDtype, Nd4j.dataType()); + assertEquals(globalDtype, Nd4j.defaultFloatingPointType()); + + String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype; + + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(networkDtype) + .convolutionMode(ConvolutionMode.Same) + .updater(new Adam(1e-2)) + .list() + .layer(new SpaceToBatchLayer.Builder().blocks(1, 1).build()) + .layer(new SpaceToDepthLayer.Builder().blocks(2).build()) + .layer(new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()) + .setInputType(InputType.convolutional(28, 28, 5)) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + net.initGradientsView(); + assertEquals(msg, networkDtype, net.params().dataType()); + assertEquals(msg, networkDtype, net.getFlattenedGradients().dataType()); + assertEquals(msg, networkDtype, net.getUpdater(true).getStateViewArray().dataType()); + + INDArray in = Nd4j.rand(networkDtype, 2, 5, 28, 28); + INDArray label = TestUtils.randomOneHot(2, 10).castTo(networkDtype); + + INDArray out = net.output(in); + assertEquals(msg, networkDtype, out.dataType()); + List ff = net.feedForward(in); + for (int i = 0; i < ff.size(); i++) { + String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? 
"input" : net.getLayer(i - 1).conf().getLayer().getClass().getSimpleName()); + assertEquals(s, networkDtype, ff.get(i).dataType()); + } + + net.setInput(in); + net.setLabels(label); + net.computeGradientAndScore(); + + net.fit(new DataSet(in, label)); + + logUsedClasses(net); + + //Now, test mismatched dtypes for input/labels: + for (DataType inputLabelDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + INDArray in2 = in.castTo(inputLabelDtype); + INDArray label2 = label.castTo(inputLabelDtype); + net.output(in2); + net.setInput(in2); + net.setLabels(label2); + net.computeGradientAndScore(); + + net.fit(new DataSet(in2, label2)); + } + } + } + } + + @Test + public void testDtypesModelVsGlobalDtypeRnn() { + for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + Nd4j.setDefaultDataTypes(globalDtype, globalDtype); + for (DataType networkDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + for (int outputLayer = 0; outputLayer < 3; outputLayer++) { + assertEquals(globalDtype, Nd4j.dataType()); + assertEquals(globalDtype, Nd4j.defaultFloatingPointType()); + + String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype + ", outputLayer=" + outputLayer; + + Layer ol; + Layer secondLast; + switch (outputLayer) { + case 0: + ol = new RnnOutputLayer.Builder().nOut(5).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(); + secondLast = new LSTM.Builder().nOut(5).activation(Activation.TANH).build(); + break; + case 1: + ol = new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(); + secondLast = new LSTM.Builder().nOut(5).activation(Activation.TANH).build(); + break; + case 2: + ol = new OutputLayer.Builder().nOut(5).build(); + secondLast = new LastTimeStep(new LSTM.Builder().nOut(5).activation(Activation.TANH).build()); + break; + default: + throw new RuntimeException(); + } + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(networkDtype) + .convolutionMode(ConvolutionMode.Same) + .updater(new Adam(1e-2)) + .list() + .layer(new LSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build()) + .layer(new GravesLSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build()) + .layer(new DenseLayer.Builder().nOut(5).build()) + .layer(new GravesBidirectionalLSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build()) + .layer(new Bidirectional(new LSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build())) + .layer(new SimpleRnn.Builder().nIn(10).nOut(5).build()) + .layer(new MaskZeroLayer.Builder().underlying(new SimpleRnn.Builder().nIn(5).nOut(5).build()).maskValue(0.0).build()) + .layer(secondLast) + .layer(ol) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + net.initGradientsView(); + assertEquals(msg, networkDtype, net.params().dataType()); + assertEquals(msg, networkDtype, net.getFlattenedGradients().dataType()); + assertEquals(msg, networkDtype, net.getUpdater(true).getStateViewArray().dataType()); + + INDArray in = Nd4j.rand(networkDtype, 2, 5, 4); + INDArray label; + if (outputLayer == 2) { + label = TestUtils.randomOneHot(2, 5).castTo(networkDtype); + } else { + label = TestUtils.randomOneHotTimeSeries(2, 5, 4).castTo(networkDtype); + } + + + INDArray out = net.output(in); + assertEquals(msg, networkDtype, out.dataType()); + List ff = net.feedForward(in); + for (int i = 0; i < ff.size(); i++) { + 
assertEquals(msg, networkDtype, ff.get(i).dataType()); + } + + net.setInput(in); + net.setLabels(label); + net.computeGradientAndScore(); + + net.fit(new DataSet(in, label, Nd4j.ones(networkDtype, 2, 4), outputLayer == 2 ? null : Nd4j.ones(networkDtype, 2, 4))); + + logUsedClasses(net); + + //Now, test mismatched dtypes for input/labels: + for (DataType inputLabelDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + INDArray in2 = in.castTo(inputLabelDtype); + INDArray label2 = label.castTo(inputLabelDtype); + net.output(in2); + net.setInput(in2); + net.setLabels(label2); + net.computeGradientAndScore(); + + net.fit(new DataSet(in2, label2)); + } + } + } + } + } + + @Test + public void testCapsNetDtypes(){ + for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + Nd4j.setDefaultDataTypes(globalDtype, globalDtype); + for (DataType networkDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + assertEquals(globalDtype, Nd4j.dataType()); + assertEquals(globalDtype, Nd4j.defaultFloatingPointType()); + + String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype; + + int primaryCapsDim = 2; + int primaryCapsChannel = 8; + int capsule = 5; + int minibatchSize = 8; + int routing = 1; + int capsuleDim = 4; + int height = 6; + int width = 6; + int inputDepth = 4; + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(networkDtype) + .seed(123) + .updater(new NoOp()) + .weightInit(new WeightInitDistribution(new UniformDistribution(-6, 6))) + .list() + .layer(new PrimaryCapsules.Builder(primaryCapsDim, primaryCapsChannel) + .kernelSize(3, 3) + .stride(2, 2) + .build()) + .layer(new CapsuleLayer.Builder(capsule, capsuleDim, routing).build()) + .layer(new CapsuleStrengthLayer.Builder().build()) + .layer(new ActivationLayer.Builder(new ActivationSoftmax()).build()) + .layer(new LossLayer.Builder(new LossNegativeLogLikelihood()).build()) + .setInputType(InputType.convolutional(height, width, inputDepth)) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + INDArray in = Nd4j.rand(networkDtype, minibatchSize, inputDepth * height * width).mul(10) + .reshape(-1, inputDepth, height, width); + INDArray label = Nd4j.zeros(networkDtype, minibatchSize, capsule); + for (int i = 0; i < minibatchSize; i++) { + label.putScalar(new int[]{i, i % capsule}, 1.0); + } + + INDArray out = net.output(in); + assertEquals(msg, networkDtype, out.dataType()); + List ff = net.feedForward(in); + for (int i = 0; i < ff.size(); i++) { + String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? 
"input" : net.getLayer(i - 1).conf().getLayer().getClass().getSimpleName()); + assertEquals(s, networkDtype, ff.get(i).dataType()); + } + + net.setInput(in); + net.setLabels(label); + net.computeGradientAndScore(); + + net.fit(new DataSet(in, label)); + + logUsedClasses(net); + + //Now, test mismatched dtypes for input/labels: + for (DataType inputLabelDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + INDArray in2 = in.castTo(inputLabelDtype); + INDArray label2 = label.castTo(inputLabelDtype); + net.output(in2); + net.setInput(in2); + net.setLabels(label2); + net.computeGradientAndScore(); + + net.fit(new DataSet(in2, label2)); + } + } + } + } + + @Test + public void testEmbeddingDtypes(){ + for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + Nd4j.setDefaultDataTypes(globalDtype, globalDtype); + for (DataType networkDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + for(boolean frozen : new boolean[]{false, true}) { + for (int test = 0; test < 3; test++) { + assertEquals(globalDtype, Nd4j.dataType()); + assertEquals(globalDtype, Nd4j.defaultFloatingPointType()); + + String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype + ", test=" + test; + + ComputationGraphConfiguration.GraphBuilder conf = new NeuralNetConfiguration.Builder() + .dataType(networkDtype) + .seed(123) + .updater(new NoOp()) + .weightInit(new WeightInitDistribution(new UniformDistribution(-6, 6))) + .graphBuilder() + .addInputs("in") + .setOutputs("out"); + + INDArray input; + if (test == 0) { + if(frozen) { + conf.layer("0", new FrozenLayer(new EmbeddingLayer.Builder().nIn(5).nOut(5).build()), "in"); + } else { + conf.layer("0", new EmbeddingLayer.Builder().nIn(5).nOut(5).build(), "in"); + } + input = Nd4j.rand(networkDtype, 10, 1).muli(5).castTo(DataType.INT); + conf.setInputTypes(InputType.feedForward(1)); + } else if(test == 1){ + if(frozen){ + conf.layer("0", new FrozenLayer(new EmbeddingSequenceLayer.Builder().nIn(5).nOut(5).build()), "in"); + } else { + conf.layer("0", new EmbeddingSequenceLayer.Builder().nIn(5).nOut(5).build(), "in"); + } + conf.layer("gp", new GlobalPoolingLayer.Builder(PoolingType.PNORM).pnorm(2).poolingDimensions(2).build(), "0"); + input = Nd4j.rand(networkDtype, 10, 1, 5).muli(5).castTo(DataType.INT); + conf.setInputTypes(InputType.recurrent(1)); + } else { + conf.layer("0", new RepeatVector.Builder().repetitionFactor(5).nOut(5).build(), "in"); + conf.layer("gp", new GlobalPoolingLayer.Builder(PoolingType.SUM).build(), "0"); + input = Nd4j.rand(networkDtype, 10, 5); + conf.setInputTypes(InputType.feedForward(5)); + } + + conf.appendLayer("el", new ElementWiseMultiplicationLayer.Builder().nOut(5).build()) + .appendLayer("ae", new AutoEncoder.Builder().nOut(5).build()) + .appendLayer("prelu", new PReLULayer.Builder().nOut(5).inputShape(5).build()) + .appendLayer("out", new OutputLayer.Builder().nOut(10).build()); + + ComputationGraph net = new ComputationGraph(conf.build()); + net.init(); + + INDArray label = Nd4j.zeros(networkDtype, 10, 10); + + INDArray out = net.outputSingle(input); + assertEquals(msg, networkDtype, out.dataType()); + Map ff = net.feedForward(input, false); + for (Map.Entry e : ff.entrySet()) { + if (e.getKey().equals("in")) + continue; + String s = msg + " - layer: " + e.getKey(); + assertEquals(s, networkDtype, e.getValue().dataType()); + } + + net.setInput(0, input); + net.setLabels(label); + net.computeGradientAndScore(); + + net.fit(new DataSet(input, 
label)); + + logUsedClasses(net); + + //Now, test mismatched dtypes for input/labels: + for (DataType inputLabelDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + INDArray in2 = input.castTo(inputLabelDtype); + INDArray label2 = label.castTo(inputLabelDtype); + net.output(in2); + net.setInput(0, in2); + net.setLabels(label2); + net.computeGradientAndScore(); + + net.fit(new DataSet(in2, label2)); + } + } + } + } + } + } + + @Test + public void testVertexDtypes(){ + for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + Nd4j.setDefaultDataTypes(globalDtype, globalDtype); + for (DataType networkDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + assertEquals(globalDtype, Nd4j.dataType()); + assertEquals(globalDtype, Nd4j.defaultFloatingPointType()); + + INDArray[] in = null; + for (int test = 0; test < 8; test++) { + String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype + ", test=" + test; + + ComputationGraphConfiguration.GraphBuilder b = new NeuralNetConfiguration.Builder() + .dataType(networkDtype) + .seed(123) + .updater(new NoOp()) + .weightInit(WeightInit.XAVIER) + .convolutionMode(ConvolutionMode.Same) + .graphBuilder(); + + switch (test){ + case 0: + b.addInputs("in") + .addLayer("l", new ConvolutionLayer.Builder().kernelSize(2,2).stride(1,1).nOut(1).build(), "in") + .addVertex("preproc", new PreprocessorVertex(new CnnToRnnPreProcessor(28, 28, 1)), "l") + .addLayer("out", new OutputLayer.Builder().nOut(10).build(), "preproc") + .setInputTypes(InputType.convolutional(28, 28, 1)) + .setOutputs("out"); + in = new INDArray[]{Nd4j.rand(networkDtype, 2, 1, 28, 28)}; + break; + case 1: + b.addInputs("in") + .addLayer("l", new DenseLayer.Builder().nOut(16).build(), "in") + .addVertex("preproc", new PreprocessorVertex(new FeedForwardToCnn3DPreProcessor(2,2,2,2, true)), "l") + .addVertex("preproc2", new PreprocessorVertex(new PermutePreprocessor(0,2,3,4,1)), "preproc") + .addVertex("preproc3", new PreprocessorVertex(new ReshapePreprocessor(new long[]{2,2,2,2}, new long[]{16})), "preproc2") + .addLayer("out", new OutputLayer.Builder().nIn(16).nOut(10).build(), "preproc3") + .setInputTypes(InputType.feedForward(5)) + .setOutputs("out"); + in = new INDArray[]{Nd4j.rand(networkDtype, 2, 5)}; + break; + case 2: + b.addInputs("in") + .addLayer("1", new ConvolutionLayer.Builder().kernelSize(2,2).stride(1,1).nOut(1).build(), "in") + .addVertex("1a", new PoolHelperVertex(), "1") + .addVertex("2", new ShiftVertex(1), "1a") + .addVertex("3", new ScaleVertex(2), "2") + .addVertex("4", new ReshapeVertex(2, -1), "3") + .addVertex("5", new SubsetVertex(0, 99), "4") + .addVertex("6", new L2NormalizeVertex(), "5") + .addLayer("out", new OCNNOutputLayer.Builder().hiddenLayerSize(10).nIn(100).build(), "6") + .setInputTypes(InputType.convolutional(28, 28, 1)) + .setOutputs("out"); + in = new INDArray[]{Nd4j.rand(networkDtype, 2, 1, 28, 28)}; + break; + case 3: + b.addInputs("in1", "in2", "in3") + .addVertex("1", new ElementWiseVertex(ElementWiseVertex.Op.Add), "in1", "in2") + .addVertex("2a", new UnstackVertex(0, 2), "1") + .addVertex("2b", new UnstackVertex(1, 2), "1") + .addVertex("3", new StackVertex(), "2a", "2b") + .addVertex("4", new DuplicateToTimeSeriesVertex("in3"), "3") + .addVertex("5", new ReverseTimeSeriesVertex(), "4") + .addLayer("6", new GlobalPoolingLayer(PoolingType.AVG), "5") + .addVertex("7", new LastTimeStepVertex("in3"), "in3") + .addVertex("8", new MergeVertex(), "6", "7") + 
.addVertex("9", new PreprocessorVertex(new ComposableInputPreProcessor()), "8") + .addLayer("out", new OutputLayer.Builder().nOut(10).build(), "9") + .setInputTypes(InputType.feedForward(8), InputType.feedForward(8), InputType.recurrent(8)) + .setOutputs("out"); + in = new INDArray[]{Nd4j.rand(networkDtype, 2, 8), Nd4j.rand(networkDtype, 2, 8), Nd4j.rand(networkDtype, 2, 8, 5)}; + break; + case 4: + b.addInputs("in1", "in2") + .addLayer("1", new LSTM.Builder().nOut(8).build(), "in1") + .addVertex("preproc1", new PreprocessorVertex(new RnnToCnnPreProcessor(2,2,2)), "1") + .addVertex("preproc2", new PreprocessorVertex(new CnnToRnnPreProcessor(2,2,2)), "preproc1") + .addLayer("pool", new GlobalPoolingLayer(), "preproc2") + .addLayer("pool2", new GlobalPoolingLayer(), "in2") + .addLayer("out", new OutputLayer.Builder().nOut(10).build(), "pool", "pool2") + .setInputTypes(InputType.recurrent(8), InputType.convolutional(28, 28, 1)) + .setOutputs("out"); + in = new INDArray[]{Nd4j.rand(networkDtype, 2, 8, 5), Nd4j.rand(networkDtype, 2, 1, 28, 28)}; + break; + case 5: + b.addInputs("in1", "in2") + .addVertex("fv", new FrozenVertex(new ScaleVertex(2.0)), "in1") + .addLayer("1", new DenseLayer.Builder().nOut(5).build(), "fv") + .addLayer("2", new DenseLayer.Builder().nOut(5).build(), "in2") + .addVertex("v", new L2Vertex(), "1", "2") + .addLayer("out", new OutputLayer.Builder().nOut(10).build(), "v") + .setInputTypes(InputType.feedForward(5), InputType.feedForward(5)) + .setOutputs("out"); + in = new INDArray[]{Nd4j.rand(networkDtype, 2, 5), Nd4j.rand(networkDtype, 2, 5)}; + break; + case 6: + b.addInputs("in") + .addLayer("1", new LSTM.Builder().nOut(5).build(), "in") + .addVertex("2", new PreprocessorVertex(new KerasFlattenRnnPreprocessor(5,4)), "1") + .addLayer("out", new OutputLayer.Builder().nOut(10).build(), "2") + .setOutputs("out") + .setInputTypes(InputType.recurrent(5, 4)); + in = new INDArray[]{Nd4j.rand(networkDtype, 2, 5, 4)}; + break; + case 7: + b.addInputs("in") + .addLayer("1", new ConvolutionLayer.Builder().kernelSize(2,2).nOut(5).convolutionMode(ConvolutionMode.Same).build(), "in") + .addVertex("2", new PreprocessorVertex(new TensorFlowCnnToFeedForwardPreProcessor(28,28,5)), "1") + .addLayer("out", new OutputLayer.Builder().nOut(10).build(), "2") + .setOutputs("out") + .setInputTypes(InputType.convolutional(28, 28, 1)); + in = new INDArray[]{Nd4j.rand(networkDtype, 2, 1, 28, 28)}; + break; + } + + ComputationGraph net = new ComputationGraph(b.build()); + net.init(); + + INDArray label = TestUtils.randomOneHot(2, 10).castTo(networkDtype); + + INDArray out = net.outputSingle(in); + assertEquals(msg, networkDtype, out.dataType()); + Map ff = net.feedForward(in, false); + for (Map.Entry e : ff.entrySet()) { + if (e.getKey().equals("in")) + continue; + String s = msg + " - layer: " + e.getKey(); + assertEquals(s, networkDtype, e.getValue().dataType()); + } + + net.setInputs(in); + net.setLabels(label); + net.computeGradientAndScore(); + + net.fit(new MultiDataSet(in, new INDArray[]{label})); + + logUsedClasses(net); + + //Now, test mismatched dtypes for input/labels: + for (DataType inputLabelDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + INDArray[] in2 = new INDArray[in.length]; + for( int i=0; i ff = net.feedForward(in, false); + for (Map.Entry e : ff.entrySet()) { + if (e.getKey().equals("in")) + continue; + String s = msg + " - layer: " + e.getKey(); + assertEquals(s, networkDtype, e.getValue().dataType()); + } + + net.setInputs(in); + 
net.setLabels(label); + net.computeGradientAndScore(); + + net.fit(new MultiDataSet(in, new INDArray[]{label})); + + logUsedClasses(net); + + //Now, test mismatched dtypes for input/labels: + for (DataType inputLabelDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + INDArray[] in2 = new INDArray[in.length]; + for( int i=0; i actTrainIntegrated = netIntegrated.feedForward(true); @@ -289,11 +289,13 @@ public void testDropoutLayerWithConvMnist() throws Exception { assertEquals(actTrainIntegrated.get(1), actTrainSeparate.get(1)); assertEquals(actTrainIntegrated.get(2), actTrainSeparate.get(3)); + netIntegrated.setInput(next.getFeatures().dup()); + netSeparate.setInput(next.getFeatures().dup()); Nd4j.getRandom().setSeed(12345); List actTestIntegrated = netIntegrated.feedForward(false); Nd4j.getRandom().setSeed(12345); List actTestSeparate = netSeparate.feedForward(false); - assertEquals(actTestIntegrated.get(1), actTrainSeparate.get(1)); + assertEquals(actTestIntegrated.get(1), actTestSeparate.get(1)); assertEquals(actTestIntegrated.get(2), actTestSeparate.get(3)); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java index 2e490a18fb9e..2d746a1372ec 100755 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java @@ -34,6 +34,7 @@ import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.junit.Test; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -66,7 +67,7 @@ public void testSetParams() { long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); OutputLayer l = (OutputLayer) conf.getLayer().instantiate(conf, - Collections.singletonList(new ScoreIterationListener(1)), 0, params, true); + Collections.singletonList(new ScoreIterationListener(1)), 0, params, true, params.dataType()); params = l.params(); l.setParams(params); assertEquals(params, l.params()); @@ -547,7 +548,7 @@ public void testCnnOutputLayerSoftmax(){ assertTrue(min >= 0 && max <= 1.0); INDArray sum = out.sum(1); - assertEquals(Nd4j.ones(2,4,5), sum); + assertEquals(Nd4j.ones(DataType.FLOAT,2,4,5), sum); } @Test diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java index beeeef24fea6..2e37f728ec92 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java @@ -40,7 +40,7 @@ private Layer getRepeatVectorLayer() { NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().seed(123) .layer(new RepeatVector.Builder(REPEAT).build()).build(); return conf.getLayer().instantiate(conf, null, 0, - null, false); + null, false, Nd4j.defaultFloatingPointType()); } @Test diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java 
b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java index 49fd70f9046c..d56f167e5d18 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java @@ -50,7 +50,7 @@ public void testAutoEncoderSeed() { long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true); + Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(Nd4j.create(1, numParams)); layer.fit(data.getFeatures(), LayerWorkspaceMgr.noWorkspaces()); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java index 75d3a5499132..d4be87a2a569 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java @@ -22,15 +22,12 @@ import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.Convolution3D; -import org.deeplearning4j.nn.gradient.Gradient; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.primitives.Pair; import java.util.Arrays; @@ -94,7 +91,7 @@ private Layer getConvolution3DLayer(ConvolutionMode mode) { .build(); long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.ones(1, numParams); - return conf.getLayer().instantiate(conf, null, 0, params, true); + return conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); } public INDArray getData() throws Exception { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java index 8d267531f402..86278a79395c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java @@ -33,12 +33,9 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.junit.Before; import org.junit.Test; import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.convolution.Convolution; @@ -212,7 +209,7 @@ public void testCNNBiasInit() { val numParams = conf.getLayer().initializer().numParams(conf); INDArray params 
= Nd4j.create(1, numParams); - Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true); + Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); assertEquals(1, layer.getParam("b").size(0)); } @@ -273,7 +270,7 @@ private static Layer getCNNConfig(int nIn, int nOut, int[] kernelSize, int[] str val numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - return conf.getLayer().instantiate(conf, null, 0, params, true); + return conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); } public Layer getMNISTConfig() { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepthTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepthTest.java index bd430f3a4007..d0e937dbed3b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepthTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepthTest.java @@ -61,7 +61,7 @@ private Layer getSpaceToDepthLayer() { NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123) .layer(new SpaceToDepthLayer.Builder(blockSize, dataFormat).build()).build(); - return conf.getLayer().instantiate(conf, null, 0, null, true); + return conf.getLayer().instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType()); } @Test diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SubsamplingLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SubsamplingLayerTest.java index 134d8b4848d0..69f8c22dbd01 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SubsamplingLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SubsamplingLayerTest.java @@ -170,7 +170,7 @@ private Layer getSubsamplingLayer(SubsamplingLayer.PoolingType pooling) { .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123) .layer(new SubsamplingLayer.Builder(pooling, new int[] {2, 2}).build()).build(); - return conf.getLayer().instantiate(conf, null, 0, null, true); + return conf.getLayer().instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType()); } public INDArray getData() throws Exception { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling1DTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling1DTest.java index c486560164d9..abc6745fc136 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling1DTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling1DTest.java @@ -108,7 +108,7 @@ private Layer getUpsampling1DLayer() { .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123) .layer(new Upsampling1D.Builder(size).build()).build(); return conf.getLayer().instantiate(conf, null, 0, - null, true); + null, true, Nd4j.defaultFloatingPointType()); } public INDArray getData() throws Exception { diff --git 
a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling2DTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling2DTest.java index b34161659186..24bf3dd09ab6 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling2DTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling2DTest.java @@ -110,7 +110,7 @@ private Layer getUpsamplingLayer() { NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123) .layer(new Upsampling2D.Builder(size).build()).build(); - return conf.getLayer().instantiate(conf, null, 0, null, true); + return conf.getLayer().instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType()); } public INDArray getData() throws Exception { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayer.java index f4544cf0839b..333995706949 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayer.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayer.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonProperty; @@ -49,9 +50,9 @@ public CustomLayer(@JsonProperty("someCustomParameter") double someCustomParamet @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { - CustomLayerImpl ret = new CustomLayerImpl(conf); + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { + CustomLayerImpl ret = new CustomLayerImpl(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayerImpl.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayerImpl.java index e9a85b55587a..ffca8a5dc7ef 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayerImpl.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayerImpl.java @@ -18,6 +18,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.layers.BaseLayer; +import org.nd4j.linalg.api.buffer.DataType; /** * @@ -26,8 +27,8 @@ * Created by Alex on 26/08/2016. 
*/ public class CustomLayerImpl extends BaseLayer { - public CustomLayerImpl(NeuralNetConfiguration conf) { - super(conf); + public CustomLayerImpl(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @Override diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java index 585e2bee641e..ff79adf0a3de 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java @@ -28,6 +28,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.lossfunctions.LossFunctions; @@ -52,8 +53,8 @@ protected CustomOutputLayer(Builder builder) { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { - CustomOutputLayerImpl ret = new CustomOutputLayerImpl(conf); + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + CustomOutputLayerImpl ret = new CustomOutputLayerImpl(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayerImpl.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayerImpl.java index 4d4df9dcb9a4..281eba7d82a6 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayerImpl.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayerImpl.java @@ -20,14 +20,15 @@ import org.deeplearning4j.nn.layers.BaseOutputLayer; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; /** * Created by Alex on 28/08/2016. 
*/ public class CustomOutputLayerImpl extends BaseOutputLayer { - public CustomOutputLayerImpl(NeuralNetConfiguration conf) { - super(conf); + public CustomOutputLayerImpl(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @Override diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java index c6660202f240..1e9b7533a3aa 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java @@ -55,7 +55,7 @@ public void testDenseBiasInit() { long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true); + Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, Nd4j.defaultFloatingPointType()); assertEquals(1, layer.getParam("b").size(0)); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java index e0dfa0aa0942..2acb555a6f5c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java @@ -42,7 +42,6 @@ import org.junit.Before; import org.junit.Test; import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.ndarray.INDArray; @@ -106,7 +105,7 @@ public void testDnnForwardPass() { System.out.println(Arrays.toString(mean.data().asFloat())); assertArrayEquals(new float[nOut], mean.data().asFloat(), 1e-6f); - assertEquals(Nd4j.ones(1, nOut), stdev); + assertEquals(Nd4j.ones(nOut), stdev); //If we fix gamma/beta: expect different mean and variance... double gamma = 2.0; @@ -134,7 +133,7 @@ protected static Layer getLayer(int nOut, double epsilon, boolean lockGammaBeta, if (numParams > 0) { params = Nd4j.create(1, numParams); } - Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true); + Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params == null ? 
Nd4j.defaultFloatingPointType() : params.dataType()); if (numParams > 0) { layer.setBackpropGradientsViewArray(Nd4j.create(1, numParams)); } @@ -171,8 +170,8 @@ public void testDnnForwardBackward() { //Check backprop INDArray epsilon = Nd4j.rand(minibatch, nIn); //dL/dy - INDArray dldgammaExp = epsilon.mul(xHat).sum(0); - INDArray dldbetaExp = epsilon.sum(0); + INDArray dldgammaExp = epsilon.mul(xHat).sum(true, 0); + INDArray dldbetaExp = epsilon.sum(true, 0); INDArray dldxhat = epsilon.mulRowVector(gamma); INDArray dldvar = dldxhat.mul(input.subRowVector(mean)).mul(-0.5) @@ -316,7 +315,9 @@ public void testCnnForwardBackward() { int effectiveMinibatch = minibatch * hw * hw; INDArray dldgammaExp = epsilon.mul(xHat).sum(0, 2, 3); + dldgammaExp = dldgammaExp.reshape(1, dldgammaExp.length()); INDArray dldbetaExp = epsilon.sum(0, 2, 3); + dldbetaExp = dldbetaExp.reshape(1, dldbetaExp.length()); INDArray dldxhat = Nd4j.getExecutioner().exec(new BroadcastMulOp(epsilon, gamma, epsilon.dup(), 1)); //epsilon.mulRowVector(gamma); @@ -548,7 +549,7 @@ public void checkMeanVarianceEstimate() throws Exception { INDArray estVar; if(useLogStd){ INDArray log10std = net.getLayer(0).getParam(BatchNormalizationParamInitializer.GLOBAL_LOG_STD); - estVar = Nd4j.valueArrayOf(log10std.shape(), 10.0); + estVar = Nd4j.valueArrayOf(log10std.shape(), 10.0).castTo(log10std.dataType()); Transforms.pow(estVar, log10std, false); // stdev = 10^(log10(stdev)) estVar.muli(estVar); } else { @@ -613,7 +614,7 @@ public void checkMeanVarianceEstimateCNN() throws Exception { INDArray estVar; if(useLogStd){ INDArray log10std = net.getLayer(0).getParam(BatchNormalizationParamInitializer.GLOBAL_LOG_STD); - estVar = Nd4j.valueArrayOf(log10std.shape(), 10.0); + estVar = Nd4j.valueArrayOf(log10std.shape(), 10.0).castTo(log10std.dataType()); Transforms.pow(estVar, log10std, false); // stdev = 10^(log10(stdev)) estVar.muli(estVar); } else { @@ -675,7 +676,7 @@ public void checkMeanVarianceEstimateCNNCompareModes() throws Exception { INDArray globalVar = net.getParam("0_" + BatchNormalizationParamInitializer.GLOBAL_VAR); INDArray log10std = net2.getParam("0_" + BatchNormalizationParamInitializer.GLOBAL_LOG_STD); - INDArray globalVar2 = Nd4j.valueArrayOf(log10std.shape(), 10.0); + INDArray globalVar2 = Nd4j.valueArrayOf(log10std.shape(), 10.0).castTo(log10std.dataType()); Transforms.pow(globalVar2, log10std, false); // stdev = 10^(log10(stdev)) globalVar2.muli(globalVar2); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/LocalResponseTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/LocalResponseTest.java index 4f484721abda..f88d999cb455 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/LocalResponseTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/LocalResponseTest.java @@ -112,7 +112,7 @@ public void doBefore() { .layer(new LocalResponseNormalization.Builder().k(2).n(5).alpha(1e-4).beta(0.75).build()) .build(); - layer = new LocalResponseNormalization().instantiate(conf, null, 0, null, false); + layer = new LocalResponseNormalization().instantiate(conf, null, 0, null, false, Nd4j.defaultFloatingPointType()); activationsActual = layer.activate(x, false, LayerWorkspaceMgr.noWorkspaces()); } @@ -203,7 +203,7 @@ public void testLrnManual() { NeuralNetConfiguration nnc = new 
NeuralNetConfiguration.Builder().layer(lrn).build(); org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization layer = (org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization) lrn.instantiate(nnc, - null, 0, null, false); + null, 0, null, false, Nd4j.defaultFloatingPointType()); INDArray outAct = layer.activate(in, true, LayerWorkspaceMgr.noWorkspaces()); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java index 85da7f18d9ca..e700fada4744 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java @@ -189,6 +189,7 @@ private MultiLayerNetwork getSingleLayer() { public MultiLayerNetwork getGradientCheckNetwork(int numHidden) { MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .seed(42).updater(new NoOp()).miniBatch(false) .list(new DenseLayer.Builder().activation(new ActivationIdentity()).nIn(4).nOut(4).build(), new org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer.Builder().nIn(4) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java index 887e125eab33..957d22a08b5b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java @@ -101,7 +101,7 @@ public void testMaskingRnn() { NDArrayIndex.interval(0, tsLength)); INDArray outSubset = net.output(inputSubset); - INDArray outputMaskedSubset = outputMasked.getRow(i); + INDArray outputMaskedSubset = outputMasked.getRow(i,true); assertEquals(outSubset, outputMaskedSubset); } @@ -286,7 +286,7 @@ public void testMaskingCnnDim3() { assertArrayEquals(new long[] {1, depthIn, height, width - i}, subset.shape()); INDArray outSubset = net.output(subset); - INDArray outMaskedSubset = outMasked.getRow(i); + INDArray outMaskedSubset = outMasked.getRow(i, true); assertEquals("minibatch: " + i, outSubset, outMaskedSubset); } @@ -345,7 +345,7 @@ public void testMaskingCnnDim2() { assertArrayEquals(new long[] {1, depthIn, height - i, width}, subset.shape()); INDArray outSubset = net.output(subset); - INDArray outMaskedSubset = outMasked.getRow(i); + INDArray outMaskedSubset = outMasked.getRow(i, true); assertEquals("minibatch: " + i, outSubset, outMaskedSubset); } @@ -410,7 +410,7 @@ public void testMaskingCnnDim23() { net.clear(); net.clearLayerMaskArrays(); INDArray outSubset = net.output(subset); - INDArray outMaskedSubset = outMasked.getRow(i); + INDArray outMaskedSubset = outMasked.getRow(i,true); assertEquals("minibatch: " + i + ", " + pt, outSubset, outMaskedSubset); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java index 022f23d71494..f7a4c087f48c 100644 --- 
a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java @@ -46,6 +46,7 @@ import org.deeplearning4j.util.TimeSeriesUtils; import org.junit.Test; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.MultiDataSet; @@ -398,6 +399,7 @@ public void testSimpleBidirectional() { for (Bidirectional.Mode m : modes) { MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .activation(Activation.TANH) .weightInit(WeightInit.XAVIER) .trainingWorkspaceMode(wsm) @@ -411,6 +413,7 @@ public void testSimpleBidirectional() { net1.init(); MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .activation(Activation.TANH) .weightInit(WeightInit.XAVIER) .updater(new Adam()) @@ -521,6 +524,7 @@ public void testSimpleBidirectionalCompGraph() { for (Bidirectional.Mode m : modes) { ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .activation(Activation.TANH) .weightInit(WeightInit.XAVIER) .trainingWorkspaceMode(wsm) @@ -536,6 +540,7 @@ public void testSimpleBidirectionalCompGraph() { net1.init(); ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .activation(Activation.TANH) .weightInit(WeightInit.XAVIER) .updater(new Adam()) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java index da3ef552253a..297067862c04 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java @@ -61,7 +61,7 @@ public void testBidirectionalLSTMGravesForwardBasic() { val numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); final GravesBidirectionalLSTM layer = - (GravesBidirectionalLSTM) conf.getLayer().instantiate(conf, null, 0, params, true); + (GravesBidirectionalLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); //Data: has shape [miniBatchSize,nIn,timeSeriesLength]; //Output/activations has shape [miniBatchsize,nHiddenUnits,timeSeriesLength]; @@ -108,7 +108,7 @@ private static void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHi long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); GravesBidirectionalLSTM lstm = - (GravesBidirectionalLSTM) conf.getLayer().instantiate(conf, null, 0, params, true); + (GravesBidirectionalLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); lstm.setBackpropGradientsViewArray(Nd4j.create(1, conf.getLayer().initializer().numParams(conf))); //Set input, do a forward pass: lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces()); @@ -175,7 +175,7 @@ public void testGravesBidirectionalLSTMForwardPassHelper() throws Exception { long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = 
Nd4j.create(1, numParams); final GravesBidirectionalLSTM lstm = - (GravesBidirectionalLSTM) conf.getLayer().instantiate(conf, null, 0, params, true); + (GravesBidirectionalLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); final INDArray input = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}); lstm.setInput(input, LayerWorkspaceMgr.noWorkspaces()); @@ -234,7 +234,7 @@ public void testGetSetParmas() { long numParams = confBidirectional.getLayer().initializer().numParams(confBidirectional); INDArray params = Nd4j.create(1, numParams); final GravesBidirectionalLSTM bidirectionalLSTM = (GravesBidirectionalLSTM) confBidirectional.getLayer() - .instantiate(confBidirectional, null, 0, params, true); + .instantiate(confBidirectional, null, 0, params, true, params.dataType()); final INDArray sig = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}); @@ -280,9 +280,9 @@ public void testSimpleForwardsAndBackwardsActivation() { long numParamsBD = confBidirectional.getLayer().initializer().numParams(confBidirectional); INDArray paramsBD = Nd4j.create(1, numParamsBD); final GravesBidirectionalLSTM bidirectionalLSTM = (GravesBidirectionalLSTM) confBidirectional.getLayer() - .instantiate(confBidirectional, null, 0, paramsBD, true); + .instantiate(confBidirectional, null, 0, paramsBD, true, params.dataType()); final GravesLSTM forwardsLSTM = - (GravesLSTM) confForwards.getLayer().instantiate(confForwards, null, 0, params, true); + (GravesLSTM) confForwards.getLayer().instantiate(confForwards, null, 0, params, true, params.dataType()); bidirectionalLSTM.setBackpropGradientsViewArray( Nd4j.create(1, confBidirectional.getLayer().initializer().numParams(confBidirectional))); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java index 4b2e155ab425..a0fc0f99d13c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java @@ -62,7 +62,7 @@ public void testLSTMGravesForwardBasic() { val numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - GravesLSTM layer = (GravesLSTM) conf.getLayer().instantiate(conf, null, 0, params, true); + GravesLSTM layer = (GravesLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); //Data: has shape [miniBatchSize,nIn,timeSeriesLength]; //Output/activations has shape [miniBatchsize,nHiddenUnits,timeSeriesLength]; @@ -108,7 +108,7 @@ private static void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHi val numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - GravesLSTM lstm = (GravesLSTM) conf.getLayer().instantiate(conf, null, 0, params, true); + GravesLSTM lstm = (GravesLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); lstm.setBackpropGradientsViewArray(Nd4j.create(1, conf.getLayer().initializer().numParams(conf))); //Set input, do a forward pass: lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces()); @@ -159,7 +159,7 @@ public void testGravesLSTMForwardPassHelper() throws Exception { val numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - GravesLSTM 
lstm = (GravesLSTM) conf.getLayer().instantiate(conf, null, 0, params, true); + GravesLSTM lstm = (GravesLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); INDArray input = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}); lstm.setInput(input, LayerWorkspaceMgr.noWorkspaces()); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java index b78c2e61c699..5ca613c590ba 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java @@ -31,6 +31,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.junit.Test; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.NoOp; @@ -161,7 +162,7 @@ public void testDropoutRecurrentLayers(){ assertEquals(s, net.params(), netD.params()); assertEquals(s, net.params(), netD2.params()); - INDArray f = Nd4j.rand(new int[]{3, 10, 10}); + INDArray f = Nd4j.rand(DataType.FLOAT, new int[]{3, 10, 10}); //Output: test mode -> no dropout INDArray out1 = net.output(f); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java index ddc007b5b759..bf8b964b155e 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java @@ -25,6 +25,7 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.junit.Test; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.NoOp; @@ -45,7 +46,7 @@ public void testSimpleRnn(){ int nIn = 5; int layerSize = 6; int tsLength = 7; - INDArray in = Nd4j.rand(new int[]{m, nIn, tsLength}); + INDArray in = Nd4j.rand(DataType.FLOAT, new int[]{m, nIn, tsLength}); // in.get(all(), all(), interval(1,tsLength)).assign(0); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() @@ -108,6 +109,6 @@ public void testBiasInit(){ net.init(); INDArray bArr = net.getParam("0_b"); - assertEquals(Nd4j.valueArrayOf(new long[]{1,layerSize}, 100.0), bArr); + assertEquals(Nd4j.valueArrayOf(new long[]{1,layerSize}, 100.0f), bArr); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java index 2736a6b798aa..a5ed47039ff7 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java @@ -33,6 +33,7 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.junit.Test; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; 
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.NoOp; @@ -129,6 +130,7 @@ public void testSameDiffConvForward() { log.info("Starting test: " + msg); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .seed(12345) .list() .layer(new SameDiffConv.Builder() @@ -161,6 +163,7 @@ public void testSameDiffConvForward() { assertNotNull(net.paramTable()); MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .weightInit(WeightInit.XAVIER) .seed(12345) .list() @@ -260,6 +263,7 @@ public void testSameDiffConvGradient() { int outW = cm == ConvolutionMode.Same ? imgW : (imgW-2); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .seed(12345) .updater(new NoOp()) .trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java index 87966ad309e6..a0adf36fd031 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java @@ -33,6 +33,7 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.junit.Test; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -390,6 +391,7 @@ public void gradientCheck() { Nd4j.getRandom().setSeed(12345); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .seed(12345) .updater(new NoOp()) .trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java index a804cb8bc04b..4d7fca598dbb 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java @@ -30,6 +30,7 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.junit.Test; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -61,6 +62,7 @@ public void testSameDiffDenseVertex() { for (Activation a : afns) { log.info("Starting test - " + a + " - minibatch " + minibatch + ", workspaces: " + workspaces); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE) .inferenceWorkspaceMode(workspaces ? 
WorkspaceMode.ENABLED : WorkspaceMode.NONE) .updater(new Sgd(0.0)) @@ -77,6 +79,7 @@ public void testSameDiffDenseVertex() { netSD.init(); ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) + .trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE) .inferenceWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE) // .updater(new Sgd(1.0)) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java index 4b0c9a627fde..c96cf0ad8359 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java @@ -31,6 +31,7 @@ import org.deeplearning4j.nn.layers.samediff.testlayers.SameDiffSimpleLambdaVertex; import org.junit.Test; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.MultiDataSet; @@ -120,6 +121,7 @@ public void testSameDiffLamdaLayerBasic(){ public void testSameDiffLamdaVertexBasic(){ Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) + .seed(12345) .updater(new Adam(0.01)) .graphBuilder() @@ -134,6 +136,7 @@ public void testSameDiffLamdaVertexBasic(){ //Equavalent, not using SameDiff Lambda: ComputationGraphConfiguration confStd = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) + .seed(12345) .updater(new Adam(0.01)) .graphBuilder() diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffOutput.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffOutput.java index 257a4cb949f5..af826b254cb5 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffOutput.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffOutput.java @@ -89,7 +89,7 @@ public void testOutputMSELossLayer(){ @Test - public void testMSEOutputLayer(){ + public void testMSEOutputLayer(){ //Failing 2019/04/17 - https://github.com/deeplearning4j/deeplearning4j/issues/7560 Nd4j.getRandom().setSeed(12345); for(Activation a : new Activation[]{Activation.IDENTITY, Activation.TANH, Activation.SOFTMAX}) { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/iter/WSTestDataSetIterator.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/iter/WSTestDataSetIterator.java index dd8f9dc2f9c9..cd2759f9ef39 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/iter/WSTestDataSetIterator.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/iter/WSTestDataSetIterator.java @@ -50,7 +50,7 @@ public DataSet nextOne(){ return new DataSet( features, - vectors.getRow(7), + vectors.getRow(7, true), Nd4j.ones(1, 10), null ); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java 
b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java index d464ec2c6b68..95c04d154875 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java @@ -29,6 +29,7 @@ import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.junit.Test; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative; @@ -258,13 +259,13 @@ private static void testIrisMiniBatchGradients(int miniBatchSize, int[] hiddenLa INDArray[] layerActivations = new INDArray[nLayers]; for (int i = 0; i < nLayers; i++) { INDArray layerInput = (i == 0 ? x : layerActivations[i - 1]); - layerZs[i] = layerInput.mmul(layerWeights[i]).addiRowVector(layerBiases[i]); + layerZs[i] = layerInput.castTo(layerWeights[i].dataType()).mmul(layerWeights[i]).addiRowVector(layerBiases[i]); layerActivations[i] = (i == nLayers - 1 ? doSoftmax(layerZs[i].dup()) : doSigmoid(layerZs[i].dup())); } //Do backward pass: INDArray[] deltas = new INDArray[nLayers]; - deltas[nLayers - 1] = layerActivations[nLayers - 1].sub(y); //Out - labels; shape=[miniBatchSize,nOut]; + deltas[nLayers - 1] = layerActivations[nLayers - 1].sub(y.castTo(layerActivations[nLayers-1].dataType())); //Out - labels; shape=[miniBatchSize,nOut]; assertArrayEquals(deltas[nLayers - 1].shape(), new long[] {miniBatchSize, 3}); for (int i = nLayers - 2; i >= 0; i--) { INDArray sigmaPrimeOfZ; @@ -279,7 +280,7 @@ private static void testIrisMiniBatchGradients(int miniBatchSize, int[] hiddenLa for (int i = 0; i < nLayers; i++) { INDArray prevActivations = (i == 0 ? x : layerActivations[i - 1]); //Raw gradients, so not yet divided by mini-batch size (division is done in BaseUpdater) - dLdw[i] = deltas[i].transpose().mmul(prevActivations).transpose(); //Shape: [nIn, nOut] + dLdw[i] = deltas[i].transpose().castTo(prevActivations.dataType()).mmul(prevActivations).transpose(); //Shape: [nIn, nOut] dLdb[i] = deltas[i].sum(true, 0); //Shape: [1,nOut] int nIn = (i == 0 ? 
4 : hiddenLayerSizes[i - 1]); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java index 53ba6b3be02a..a1512139fae9 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java @@ -509,7 +509,7 @@ public void testScoreExamples() { assertArrayEquals(new long[] {3, 1}, scoresNoRegularization.shape()); for (int i = 0; i < 3; i++) { - DataSet singleEx = new DataSet(input.getRow(i), output.getRow(i)); + DataSet singleEx = new DataSet(input.getRow(i,true), output.getRow(i,true)); double score = net.score(singleEx); double scoreNoReg = netNoReg.score(singleEx); @@ -726,7 +726,7 @@ public void testApplyingPreTrainConfigAndParams() { assertTrue(paramTable.containsKey("0_vb")); // check vb exists for pretrain layer aePre.setParam("0_vb", Nd4j.ones(10)); params = aePre.getParam("0_vb"); - assertEquals(Nd4j.ones(10), params); // check set params for vb + assertEquals(Nd4j.ones(1,10), params); // check set params for vb // Test pretrain false, expect same for true because its not changed when applying update @@ -1136,7 +1136,7 @@ public void testExternalErrors() { Pair extErrorGrad = e.backpropGradient(olEpsilon, LayerWorkspaceMgr.noWorkspaces()); int nParamsDense = 10 * 10 + 10; - assertEquals(sGrad.gradient().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nParamsDense)), + assertEquals(sGrad.gradient().get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nParamsDense)), extErrorGrad.getFirst().gradient()); Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); @@ -1254,6 +1254,7 @@ public void testInputActivationGradient(){ Nd4j.setDataType(DataType.DOUBLE); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .seed(12345) .activation(Activation.TANH) .list() @@ -1460,30 +1461,30 @@ public void testMLNUpdaterBlocks(){ //Initially updater view array is set out like: //[m0w, m0b, m1w, m1b, m2w, m2b][v0w, v0b, v1w, v1b, v2w, v2b] long soFar = 0; - INDArray m0w = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar+5*3)).assign(0); //m0w + INDArray m0w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+5*3)).assign(0); //m0w soFar += 5*3; - INDArray m0b = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar+3)).assign(1); //m0b + INDArray m0b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+3)).assign(1); //m0b soFar += 3; - INDArray m1w = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar+3*2)).assign(2); //m1w + INDArray m1w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+3*2)).assign(2); //m1w soFar += 3*2; - INDArray m1b = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar+2)).assign(3); //m1b + INDArray m1b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+2)).assign(3); //m1b soFar += 2; - INDArray m2w = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar+2*1)).assign(4); //m2w + INDArray m2w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+2*1)).assign(4); //m2w soFar += 2*1; - INDArray m2b = viewArray.get(NDArrayIndex.point(0), 
NDArrayIndex.interval(soFar, soFar+1)).assign(5); //m2b + INDArray m2b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+1)).assign(5); //m2b soFar += 1; - INDArray v0w = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar+5*3)).assign(6); //v0w + INDArray v0w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+5*3)).assign(6); //v0w soFar += 5*3; - INDArray v0b = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar+3)).assign(7); //v0b + INDArray v0b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+3)).assign(7); //v0b soFar += 3; - INDArray v1w = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar+3*2)).assign(8); //v1w + INDArray v1w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+3*2)).assign(8); //v1w soFar += 3*2; - INDArray v1b = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar+2)).assign(9); //v1b + INDArray v1b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+2)).assign(9); //v1b soFar += 2; - INDArray v2w = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar+2*1)).assign(10); //v2w + INDArray v2w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+2*1)).assign(10); //v2w soFar += 2*1; - INDArray v2b = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar+1)).assign(11); //v2b + INDArray v2b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+1)).assign(11); //v2b soFar += 1; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java index 42c185cfb429..93e9bb9c7b0e 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java @@ -314,7 +314,7 @@ public void testRnnTimeStepLayers() { INDArray expOutSubset; if (inLength == 1) { val sizes = new long[]{fullOutL3.size(0), fullOutL3.size(1), 1}; - expOutSubset = Nd4j.create(sizes); + expOutSubset = Nd4j.create(DataType.FLOAT, sizes); expOutSubset.tensorAlongDimension(0, 1, 0).assign(fullOutL3.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(startTimeRange))); } else { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestSetGetParameters.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestSetGetParameters.java index 6fda7968fae1..2d4583a7b7fe 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestSetGetParameters.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestSetGetParameters.java @@ -65,7 +65,7 @@ public void testSetParameters() { assertEquals(initParams, initParamsAfter); //Now, try the other way: get(set(random)) - INDArray randomParams = Nd4j.rand(initParams.shape()); + INDArray randomParams = Nd4j.rand(initParams.dataType(), initParams.shape()); net.setParams(randomParams.dup()); assertEquals(net.params(), randomParams); @@ -102,7 +102,7 @@ public void testSetParametersRNN() { assertEquals(initParams, initParamsAfter); //Now, try the other way: get(set(random)) - 
INDArray randomParams = Nd4j.rand(initParams.shape()); + INDArray randomParams = Nd4j.rand(initParams.dataType(), initParams.shape()); net.setParams(randomParams.dup()); assertEquals(net.params(), randomParams); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java index 09c13cd0b928..959ccbd2269b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java @@ -543,7 +543,7 @@ public void testMaskingLstmAndBidirectionalLstmGlobalPooling() { NDArrayIndex.interval(0, tsLength - i)}; INDArray inputSubset = input.get(idx); INDArray expExampleOut = net.output(inputSubset); - INDArray actualExampleOut = outMasked.getRow(i); + INDArray actualExampleOut = outMasked.getRow(i, true); // System.out.println(i); assertEquals(expExampleOut, actualExampleOut); } @@ -554,7 +554,7 @@ public void testMaskingLstmAndBidirectionalLstmGlobalPooling() { for (int i = 0; i < minibatch; i++) { INDArrayIndex[] idx = new INDArrayIndex[] {NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.interval(0, tsLength - i)}; - DataSet dsSingle = new DataSet(input.get(idx), labels.getRow(i)); + DataSet dsSingle = new DataSet(input.get(idx), labels.getRow(i,true)); INDArray exampleSingleScore = net.scoreExamples(dsSingle, false); double exp = exampleSingleScore.getDouble(0); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/TestGradientNormalization.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/TestGradientNormalization.java index d487f2b803b6..1cde13166c37 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/TestGradientNormalization.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/TestGradientNormalization.java @@ -50,7 +50,7 @@ public void testRenormalizatonPerLayer() { long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true); + Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); INDArray gradArray = Nd4j.rand(1, 220).muli(10).subi(5); layer.setBackpropGradientsViewArray(gradArray); INDArray weightGrad = Shape.newShapeNoCopy(gradArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 200)), @@ -96,11 +96,11 @@ public void testRenormalizationPerParamType() { long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true); + Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(Nd4j.create(params.shape())); Updater updater = UpdaterCreator.getUpdater(layer); INDArray weightGrad = Nd4j.rand(10, 20); - INDArray biasGrad = Nd4j.rand(1, 10); + INDArray biasGrad = Nd4j.rand(1, 20); INDArray weightGradCopy = weightGrad.dup(); INDArray biasGradCopy = biasGrad.dup(); Gradient gradient = new DefaultGradient(); @@ -129,7 +129,7 @@ public void testAbsValueClippingPerElement() { long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = 
Nd4j.create(1, numParams); - Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true); + Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); INDArray gradArray = Nd4j.rand(1, 220).muli(10).subi(5); layer.setBackpropGradientsViewArray(gradArray); INDArray weightGrad = Shape.newShapeNoCopy(gradArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 200)), @@ -185,7 +185,7 @@ public void testL2ClippingPerLayer() { val numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true); + Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); INDArray gradArray = Nd4j.rand(1, 220).muli(t == 0 ? 0.05 : 10).subi(t == 0 ? 0 : 5); layer.setBackpropGradientsViewArray(gradArray); INDArray weightGrad = @@ -240,11 +240,11 @@ public void testL2ClippingPerParamType() { val numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true); + Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(Nd4j.create(params.shape())); Updater updater = UpdaterCreator.getUpdater(layer); INDArray weightGrad = Nd4j.rand(10, 20).muli(0.05); - INDArray biasGrad = Nd4j.rand(1, 10).muli(10); + INDArray biasGrad = Nd4j.rand(1, 20).muli(10); INDArray weightGradCopy = weightGrad.dup(); INDArray biasGradCopy = biasGrad.dup(); Gradient gradient = new DefaultGradient(); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java index cdbfda08d19b..f5049d21174d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java @@ -37,9 +37,9 @@ import org.junit.Before; import org.junit.Test; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.learning.*; import org.nd4j.linalg.learning.config.*; import org.nd4j.linalg.lossfunctions.LossFunctions; @@ -93,7 +93,7 @@ public void testAdaDeltaUpdate() { long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true); + BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); @@ -159,7 +159,7 @@ public void testAdaGradUpdater() { long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true); + BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); int updaterStateSize = (int) 
layer.layerConf().getIUpdater().stateSize(numParams); @@ -203,7 +203,7 @@ public void testAdamUpdater() { long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true); + BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); @@ -266,7 +266,7 @@ public void testNadamUpdater() { long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true); + BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); @@ -357,7 +357,7 @@ public void testAdaMaxUpdater() { long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true); + BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); @@ -413,7 +413,7 @@ public void testNestorovsUpdater() { long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true); + BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); @@ -460,7 +460,7 @@ public void testRMSPropUpdater() { long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true); + BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); @@ -507,7 +507,7 @@ public void testSGDUpdater() { long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true); + BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); @@ -541,7 +541,7 @@ public void testNoOpUpdater() { long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true); + Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = 
UpdaterCreator.getUpdater(layer); @@ -637,8 +637,8 @@ public void testMultiLayerUpdater() throws Exception { for (int j = 0; j < net.getnLayers(); j++) { //Generate test gradient: - INDArray wGrad = Nd4j.rand(nIns[j], nOuts[j]); - INDArray bGrad = Nd4j.rand(1, nOuts[j]); + INDArray wGrad = Nd4j.rand(DataType.FLOAT, nIns[j], nOuts[j]); + INDArray bGrad = Nd4j.rand(DataType.FLOAT, 1, nOuts[j]); String wKey = j + "_" + DefaultParamInitializer.WEIGHT_KEY; String bKey = j + "_" + DefaultParamInitializer.BIAS_KEY; @@ -747,7 +747,7 @@ public void testPretrain() { .build(); long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true); + BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); @@ -792,7 +792,7 @@ public void testPretrain() { gradientCopyPreUpdate.setFlattenedGradient(g); params = Nd4j.create(1, numParams); - layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true); + layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); updater = UpdaterCreator.getUpdater(layer); assertEquals(lr, ((Sgd)layer.layerConf().getIUpdater()).getLearningRate(), 1e-4); @@ -1037,17 +1037,17 @@ public void testDivisionByMinibatch2(){ INDArray view = ((BaseMultiLayerUpdater) net.getUpdater()).getFlattenedGradientsView(); view.assign(Nd4j.linspace(1, view.length(), view.length(), Nd4j.dataType())); - INDArray expView1 = view.get(point(0), interval(0, 10*9 + 9 + 2*9)); + INDArray expView1 = view.get(interval(0,0,true), interval(0, 10*9 + 9 + 2*9)); assertEquals(expView1, l.get(0)); long start2 = (10*9 + 9 + 2*9) + 2*9; long length2 = 9*8 + 8 + 2*8; - INDArray expView2 = view.get(point(0), interval(start2, start2 + length2)); + INDArray expView2 = view.get(interval(0,0,true), interval(start2, start2 + length2)); assertEquals(expView2, l.get(1)); long start3 = start2 + length2 + 2*8; long length3 = 8*7 + 7; - INDArray expView3 = view.get(point(0), interval(start3, start3 + length3)); + INDArray expView3 = view.get(interval(0,0,true), interval(start3, start3 + length3)); assertEquals(expView3, l.get(2)); } @@ -1093,17 +1093,17 @@ public void testDivisionByMinibatch3() throws Exception{ INDArray view = ((BaseMultiLayerUpdater) net.getUpdater()).getFlattenedGradientsView(); view.assign(Nd4j.linspace(1, view.length(), view.length(), Nd4j.dataType())); - INDArray expView1 = view.get(point(0), interval(0, 2*6)); + INDArray expView1 = view.get(interval(0,0,true), interval(0, 2*6)); assertEquals(expView1, l.get(0)); long start2 = 2*6 + 2*6; long length2 = 6*5*2*2 + 5 + 2*5; - INDArray expView2 = view.get(point(0), interval(start2, start2 + length2)); + INDArray expView2 = view.get(interval(0,0,true), interval(start2, start2 + length2)); assertEquals(expView2, l.get(1)); long start3 = start2 + length2 + 2*5; long length3 = 5*4*2*2 + 4 + 2*4; - INDArray expView3 = view.get(point(0), interval(start3, start3 + length3)); + INDArray expView3 = view.get(interval(0,0,true), interval(start3, start3 + length3)); assertEquals(expView3, l.get(2)); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/WeightInitIdentityTest.java 
b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/WeightInitIdentityTest.java index ee104b3d0e6a..44e5d96dd690 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/WeightInitIdentityTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/WeightInitIdentityTest.java @@ -7,6 +7,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.junit.Test; import org.nd4j.linalg.activations.impl.ActivationIdentity; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -24,7 +25,7 @@ public class WeightInitIdentityTest { */ @Test public void testIdConv1D() { - final INDArray input = Nd4j.randn(new long[] {1,5,7}); + final INDArray input = Nd4j.randn(DataType.FLOAT, 1,5,7); final String inputName = "input"; final String conv = "conv"; final String output = "output"; @@ -51,7 +52,7 @@ public void testIdConv1D() { */ @Test public void testIdConv2D() { - final INDArray input = Nd4j.randn(1,5,7,11); + final INDArray input = Nd4j.randn(DataType.FLOAT,1,5,7,11); final String inputName = "input"; final String conv = "conv"; final String output = "output"; @@ -78,7 +79,7 @@ public void testIdConv2D() { */ @Test public void testIdConv3D() { - final INDArray input = Nd4j.randn(1,5,7,11,13); + final INDArray input = Nd4j.randn(DataType.FLOAT, 1,5,7,11,13); final String inputName = "input"; final String conv = "conv"; final String output = "output"; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/BackTrackLineSearchTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/BackTrackLineSearchTest.java index 7de4cd3a5bb2..6975de250b34 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/BackTrackLineSearchTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/BackTrackLineSearchTest.java @@ -40,7 +40,6 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Adam; -import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.Collections; @@ -173,7 +172,7 @@ private static OutputLayer getIrisLogisticLayerConfig(Activation activationFunct val numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - return (OutputLayer) conf.getLayer().instantiate(conf, null, 0, params, true); + return (OutputLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); } /////////////////////////////////////////////////////////////////////////// diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java index dfbe6e2c0a44..3f4c20575aa2 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java @@ -89,9 +89,9 @@ public void regressionTestMLP1() throws Exception { assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6); int numParams = (int)net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams, 
Nd4j.dataType()), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params()); int updaterSize = (int) new Nesterovs().stateSize(net.numParams()); - assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()), net.getUpdater().getStateViewArray()); + assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray()); } @Test @@ -129,9 +129,9 @@ public void regressionTestMLP2() throws Exception { assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l1)); int numParams = (int)net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params()); int updaterSize = (int) new RmsProp().stateSize(numParams); - assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()), net.getUpdater().getStateViewArray()); + assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray()); } @Test @@ -174,8 +174,8 @@ public void regressionTestCNN1() throws Exception { assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); int numParams = (int)net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params()); int updaterSize = (int) new RmsProp().stateSize(numParams); - assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()), net.getUpdater().getStateViewArray()); + assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray()); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java index 0a8ef698c8c3..cbb72d0a9290 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java @@ -94,9 +94,9 @@ public void regressionTestMLP1() throws Exception { assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6); int numParams = (int)net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params()); int updaterSize = (int) new Nesterovs().stateSize(numParams); - assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()), net.getUpdater().getStateViewArray()); + assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray()); } @Test @@ -138,9 +138,9 @@ public void regressionTestMLP2() throws Exception { assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5); int numParams = (int)net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params()); int updaterSize = (int) new RmsProp().stateSize(numParams); - assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, 
Nd4j.dataType()), net.getUpdater().getStateViewArray()); + assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray()); } @Test @@ -185,9 +185,9 @@ public void regressionTestCNN1() throws Exception { assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor); int numParams = (int)net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params()); int updaterSize = (int) new RmsProp().stateSize(numParams); - assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()), net.getUpdater().getStateViewArray()); + assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray()); } @Test diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java index f77bc54948aa..35999283e682 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java @@ -94,9 +94,9 @@ public void regressionTestMLP1() throws Exception { assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6); long numParams = (int)net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.params()); int updaterSize = (int) new Nesterovs().stateSize(numParams); - assertEquals(Nd4j.linspace(1, updaterSize, updaterSize), net.getUpdater().getStateViewArray()); + assertEquals(Nd4j.linspace(1, updaterSize, updaterSize).reshape(1,numParams), net.getUpdater().getStateViewArray()); } @Test @@ -138,9 +138,9 @@ public void regressionTestMLP2() throws Exception { assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5); long numParams = net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.params()); int updaterSize = (int) new RmsProp().stateSize(numParams); - assertEquals(Nd4j.linspace(1, updaterSize, updaterSize), net.getUpdater().getStateViewArray()); + assertEquals(Nd4j.linspace(1, updaterSize, updaterSize).reshape(1,numParams), net.getUpdater().getStateViewArray()); } @Test @@ -185,9 +185,9 @@ public void regressionTestCNN1() throws Exception { assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor); long numParams = net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.params()); int updaterSize = (int) new RmsProp().stateSize(numParams); - assertEquals(Nd4j.linspace(1, updaterSize, updaterSize), net.getUpdater().getStateViewArray()); + assertEquals(Nd4j.linspace(1, updaterSize, updaterSize).reshape(1,numParams), net.getUpdater().getStateViewArray()); } @Test diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java index abbcd0d9f8d1..9d14d7e199f3 100644 --- 
a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java @@ -99,9 +99,9 @@ public void regressionTestMLP1() throws Exception { assertEquals(0.15, n.getLearningRate(), 1e-6); int numParams = (int)net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params()); int updaterSize = (int) new Nesterovs().stateSize(numParams); - assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()), net.getUpdater().getStateViewArray()); + assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray()); } @Test @@ -149,9 +149,9 @@ public void regressionTestMLP2() throws Exception { assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5); int numParams = (int)net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params()); int updaterSize = (int) new RmsProp().stateSize(numParams); - assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()), net.getUpdater().getStateViewArray()); + assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray()); } @Test @@ -201,9 +201,9 @@ public void regressionTestCNN1() throws Exception { assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor); int numParams = (int)net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params()); int updaterSize = (int) new RmsProp().stateSize(numParams); - assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()), net.getUpdater().getStateViewArray()); + assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray()); } @Test diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java index c021db43d849..d5cb6e59bc92 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java @@ -16,6 +16,7 @@ package org.deeplearning4j.regressiontest; +import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.api.Layer; @@ -47,6 +48,7 @@ import static org.junit.Assert.*; +@Slf4j public class RegressionTest100a extends BaseDL4JTest { @Override @@ -248,9 +250,14 @@ public void testYoloHouseNumber() throws Exception { } } - INDArray outAct = net.outputSingle(in); + INDArray outAct = net.outputSingle(in).castTo(outExp.dataType()); - assertEquals(outExp, outAct.castTo(outExp.dataType())); + boolean eq = outExp.equalsWithEps(outAct, 1e-4); + if(!eq){ + log.info("Expected: {}", outExp); + log.info("Actual: {}", outAct); + } + assertTrue("Output not equal", eq); } diff --git 
a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java index 94a9fae7e68f..7627b3091811 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java @@ -18,14 +18,14 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; -import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.BackpropType; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.graph.LayerVertex; -import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.LSTM; +import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; -import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyLayerDeserializer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInitXavier; @@ -46,7 +46,8 @@ import java.util.Arrays; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; public class RegressionTest100b3 extends BaseDL4JTest { @@ -104,6 +105,8 @@ public void testCustomLayer() throws Exception { List activations = net.feedForward(in); + assertEquals(dt, net.getLayerWiseConfigurations().getDataType()); + assertEquals(dt, net.params().dataType()); assertEquals(dtype, outExp, outAct); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayer.java index 5ee60fd211dc..20ccac1fb3c7 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayer.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayer.java @@ -30,6 +30,7 @@ import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; @@ -67,12 +68,12 @@ public void setSecondActivationFunction(IActivation secondActivationFunction) { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection iterationListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { //The instantiate method is how we go from the configuration class (i.e., this class) to the implementation class // (i.e., a CustomLayerImpl instance) //For the most part, it's the same for each type of layer - CustomLayerImpl myCustomLayer = new CustomLayerImpl(conf); + CustomLayerImpl myCustomLayer = new CustomLayerImpl(conf, networkDataType); myCustomLayer.setListeners(iterationListeners); //Set the iteration listeners, 
if any myCustomLayer.setIndex(layerIndex); //Integer index of the layer diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayerImpl.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayerImpl.java index 195641c60545..e6c4260fd543 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayerImpl.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayerImpl.java @@ -23,6 +23,7 @@ import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -35,8 +36,8 @@ */ public class CustomLayerImpl extends BaseLayer { //Generic parameter here: the configuration class type - public CustomLayerImpl(NeuralNetConfiguration conf) { - super(conf); + public CustomLayerImpl(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java index b7bea40836ab..b961153dd5ca 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java @@ -150,6 +150,7 @@ public void testCompareMlpTrainingIris(){ //Create equivalent DL4J net MultiLayerConfiguration mlc = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .weightInit(WeightInit.XAVIER).seed(12345) .l1(l1Val).l2(l2Val) .l1Bias(l1Val).l2Bias(l2Val) @@ -213,6 +214,7 @@ public void testCompareMlpTrainingIris(){ //Check training with updater mlc = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .weightInit(WeightInit.XAVIER).seed(12345) .l1(l1Val).l2(l2Val) .l1Bias(l1Val).l2Bias(l2Val) diff --git a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/BaseCudnnHelper.java b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/BaseCudnnHelper.java index a23224e570f0..25f26a69cd7c 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/BaseCudnnHelper.java +++ b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/BaseCudnnHelper.java @@ -16,6 +16,7 @@ package org.deeplearning4j.nn.layers; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.bytedeco.javacpp.*; import org.nd4j.jita.allocator.impl.AtomicAllocator; @@ -179,15 +180,24 @@ public TensorArray(TensorArray a) { protected static final int TENSOR_FORMAT = CUDNN_TENSOR_NCHW; - protected int dataType = Nd4j.dataType() == DataType.DOUBLE ? CUDNN_DATA_DOUBLE - : Nd4j.dataType() == DataType.FLOAT ? CUDNN_DATA_FLOAT : CUDNN_DATA_HALF; - protected int dataTypeSize = - Nd4j.dataType() == DataType.DOUBLE ? 8 : Nd4j.dataType() == DataType.FLOAT ? 
4 : 2; + protected final DataType nd4jDataType; + protected final int dataType; + protected final int dataTypeSize; // both CUDNN_DATA_HALF and CUDNN_DATA_FLOAT need a float value for alpha and beta - protected Pointer alpha = dataType == CUDNN_DATA_DOUBLE ? new DoublePointer(1.0) : new FloatPointer(1.0f); - protected Pointer beta = dataType == CUDNN_DATA_DOUBLE ? new DoublePointer(0.0) : new FloatPointer(0.0f); + protected final Pointer alpha; + protected final Pointer beta; protected SizeTPointer sizeInBytes = new SizeTPointer(1); + public BaseCudnnHelper(@NonNull DataType dataType){ + this.nd4jDataType = dataType; + this.dataType = dataType == DataType.DOUBLE ? CUDNN_DATA_DOUBLE + : dataType == DataType.FLOAT ? CUDNN_DATA_FLOAT : CUDNN_DATA_HALF; + this.dataTypeSize = dataType == DataType.DOUBLE ? 8 : dataType == DataType.FLOAT ? 4 : 2; + // both CUDNN_DATA_HALF and CUDNN_DATA_FLOAT need a float value for alpha and beta + this.alpha = this.dataType == CUDNN_DATA_DOUBLE ? new DoublePointer(1.0) : new FloatPointer(1.0f); + this.beta = this.dataType == CUDNN_DATA_DOUBLE ? new DoublePointer(0.0) : new FloatPointer(0.0f); + } + public static int toCudnnDataType(DataType type){ switch (type){ case DOUBLE: diff --git a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/convolution/CudnnConvolutionHelper.java b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/convolution/CudnnConvolutionHelper.java index be502474a369..fd5fc373a464 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/convolution/CudnnConvolutionHelper.java +++ b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/convolution/CudnnConvolutionHelper.java @@ -37,6 +37,7 @@ import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.conf.CudaEnvironment; import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.GridExecutioner; import org.nd4j.linalg.api.shape.Shape; @@ -67,6 +68,10 @@ @Slf4j public class CudnnConvolutionHelper extends BaseCudnnHelper implements ConvolutionHelper { + public CudnnConvolutionHelper(DataType dataType) { + super(dataType); + } + private static class CudnnConvolutionContext extends CudnnContext { private static class Deallocator extends CudnnConvolutionContext implements Pointer.Deallocator { @@ -252,7 +257,7 @@ public Pair backpropGradient(INDArray input, INDArray weight log.trace("CudnnConvolutionHelper backward algorithm selection: mode {}, filter algorithm {}, data algorithm {}", mode, fa, da); } - INDArray epsNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, new int[] {(int) miniBatch,(int) inDepth, (int) inH, (int) inW}, 'c'); + INDArray epsNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, weights.dataType(), new long[] {(int) miniBatch,(int) inDepth, (int) inH, (int) inW}, 'c'); val dstStride = epsNext.stride(); @@ -363,7 +368,7 @@ public INDArray preOutput(INDArray input, INDArray weights, INDArray bias, int[] if (Nd4j.getExecutioner() instanceof GridExecutioner) ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); - INDArray z = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new int[] {(int) miniBatch, (int) outDepth, outSize[0], outSize[1]}); + INDArray z = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, weights.dataType(), new long[] {(int) miniBatch, (int) outDepth, outSize[0], 
outSize[1]}); code = cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW, (int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3]); @@ -626,12 +631,12 @@ public static CudnnForwardArgs getCudnnForwardArgs(INDArray input, int[] kernel, input.size(3) + (manualPadRight ? 1 : 0)}; INDArray newInput; if(poolingType == null || poolingType != PoolingType.MAX){ - newInput = Nd4j.create(newShape); + newInput = Nd4j.create(input.dataType(), newShape); } else { //For max pooling, we don't want to include the padding in the maximum values. But, CuDNN doesn't knowm // that these values are padding and hence should be excluded. Instead: We'll use -infinity so that, // if the 'real' (non-padding) values are all < 0, we take the real value, not the padding value - newInput = Nd4j.valueArrayOf(newShape, Double.NEGATIVE_INFINITY); + newInput = Nd4j.valueArrayOf(newShape, Double.NEGATIVE_INFINITY, input.dataType()); } newInput.put(new INDArrayIndex[]{all(), all(), interval(0,input.size(2)), interval(0, input.size(3))}, input); diff --git a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/CudnnSubsamplingHelper.java b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/CudnnSubsamplingHelper.java index 6c5727a3bdbb..aacf19f680b2 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/CudnnSubsamplingHelper.java +++ b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/CudnnSubsamplingHelper.java @@ -28,6 +28,7 @@ import org.nd4j.jita.allocator.Allocator; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.conf.CudaEnvironment; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.GridExecutioner; import org.nd4j.linalg.api.shape.Shape; @@ -56,6 +57,10 @@ @Slf4j public class CudnnSubsamplingHelper extends BaseCudnnHelper implements SubsamplingHelper { + public CudnnSubsamplingHelper(DataType dataType) { + super(dataType); + } + private static class CudnnSubsamplingContext extends CudnnContext { private static class Deallocator extends CudnnSubsamplingContext implements Pointer.Deallocator { @@ -167,7 +172,7 @@ public Pair backpropGradient(INDArray input, INDArray epsilo checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0], kernel[1], pad[0], pad[1], strides[0], strides[1])); - INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, new int[] {(int) miniBatch, (int) depth, (int) inH, (int) inW}, 'c'); + INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[] {(int) miniBatch, (int) depth, (int) inH, (int) inW}, 'c'); val dstStride = outEpsilon.stride(); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW, @@ -243,7 +248,7 @@ public INDArray activate(INDArray input, boolean training, int[] kernel, int[] s checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW, (int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3])); - INDArray reduced = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new int[] 
{(int) miniBatch, (int) inDepth, outH, outW}, 'c'); + INDArray reduced = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), new long[] {(int) miniBatch, (int) inDepth, outH, outW}, 'c'); val dstStride = reduced.stride(); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) outH, (int) outW, diff --git a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/dropout/CudnnDropoutHelper.java b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/dropout/CudnnDropoutHelper.java index 0bfb9e683977..1fc258afceb3 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/dropout/CudnnDropoutHelper.java +++ b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/dropout/CudnnDropoutHelper.java @@ -25,6 +25,7 @@ import org.nd4j.jita.allocator.Allocator; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.conf.CudaEnvironment; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.jcublas.context.CudaContext; @@ -110,6 +111,10 @@ protected void destroyHandles() { private SizeTPointer reserveSizeBytesPtr; private float lastInitializedP; + public CudnnDropoutHelper(DataType dataType){ + super(dataType); + } + @Override public void applyDropout(INDArray input, INDArray resultArray, double dropoutInputRetainProb) { float p = (float)(1.0 - dropoutInputRetainProb); //CuDNN uses p = probability of setting to 0. We use p = probability of retaining diff --git a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper.java b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper.java index 405306a3fbad..4298da6d7e5d 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper.java +++ b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper.java @@ -16,12 +16,9 @@ package org.deeplearning4j.nn.layers.normalization; -import lombok.Getter; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.bytedeco.javacpp.DoublePointer; import org.bytedeco.javacpp.Pointer; -import org.bytedeco.javacpp.indexer.DoubleBufferIndexer; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.BaseCudnnHelper; @@ -31,26 +28,22 @@ import org.nd4j.jita.allocator.Allocator; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.conf.CudaEnvironment; -import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.GridExecutioner; import org.nd4j.linalg.api.shape.Shape; -import org.nd4j.linalg.factory.NDArrayFactory; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.jcublas.JCublasNDArray; import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; -import java.util.Collections; import java.util.HashMap; import java.util.Map; import org.bytedeco.cuda.cudart.*; import org.bytedeco.cuda.cudnn.*; -import static 
org.bytedeco.cuda.global.cudart.*; + import static org.bytedeco.cuda.global.cudnn.*; /** @@ -61,6 +54,10 @@ @Slf4j public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements BatchNormalizationHelper { + public CudnnBatchNormalizationHelper(DataType dataType) { + super(dataType); + } + private static class CudnnBatchNormalizationContext extends CudnnContext { private static class Deallocator extends CudnnBatchNormalizationContext implements Pointer.Deallocator { @@ -134,7 +131,7 @@ public Pair backpropGradient(INDArray input, INDArray epsilo val inH = (int) input.size(2); val inW = (int) input.size(3); - final boolean isHalf = (Nd4j.dataType() == DataType.HALF); + final boolean isHalf = (input.dataType() == DataType.HALF); INDArray gammaOrig = null; INDArray dGammaViewOrig = null; INDArray dBetaViewOrig = null; @@ -171,7 +168,7 @@ public Pair backpropGradient(INDArray input, INDArray epsilo checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW, (int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3])); - INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, new int[] {(int) miniBatch, (int) depth, (int) inH, (int) inW}, 'c'); + INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[] {miniBatch, depth, inH, inW}, 'c'); val dstStride = ArrayUtil.toInts(nextEpsilon.stride()); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW, @@ -220,7 +217,7 @@ public Pair backpropGradient(INDArray input, INDArray epsilo public INDArray preOutput(INDArray x, boolean training, int[] shape, INDArray gamma, INDArray beta, INDArray mean, INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr) { this.eps = eps; - final boolean isHalf = (Nd4j.dataType() == DataType.HALF); + final boolean isHalf = (x.dataType() == DataType.HALF); INDArray origGamma = gamma; INDArray origBeta = beta; INDArray origMean = mean; @@ -249,7 +246,7 @@ public INDArray preOutput(INDArray x, boolean training, int[] shape, INDArray ga checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW, srcStride[0], srcStride[1], srcStride[2], srcStride[3])); - INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new int[] {miniBatch, inDepth, inH, inW}, 'c'); + INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), new long[] {miniBatch, inDepth, inH, inW}, 'c'); val dstStride = ArrayUtil.toInts(activations.stride()); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW, @@ -274,16 +271,20 @@ public INDArray preOutput(INDArray x, boolean training, int[] shape, INDArray ga checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream()))); if (training) { if(meanCache == null || meanCache.length() < mean.length()){ - meanCache = Nd4j.createUninitializedDetached((int)mean.length()); - if(Nd4j.dataType() == DataType.HALF){ + try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + meanCache = Nd4j.createUninitialized(x.dataType(), mean.length()); + } + if(x.dataType() == DataType.HALF){ try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { meanCache = meanCache.castTo(DataType.FLOAT); } } } if(varCache == null || varCache.length() < mean.length()){ 
- varCache = Nd4j.createUninitializedDetached((int)mean.length()); - if(Nd4j.dataType() == DataType.HALF){ + try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + varCache = Nd4j.createUninitialized(x.dataType(), mean.length()); + } + if(nd4jDataType == DataType.HALF){ try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { varCache = varCache.castTo(DataType.FLOAT); } @@ -325,8 +326,8 @@ public INDArray preOutput(INDArray x, boolean training, int[] shape, INDArray ga } @Override - public INDArray getMeanCache() { - if(Nd4j.dataType() == DataType.HALF){ + public INDArray getMeanCache(DataType dataType) { + if(dataType == DataType.HALF){ //Buffer is FP32 return meanCache.castTo(DataType.HALF); } @@ -334,15 +335,15 @@ public INDArray getMeanCache() { } @Override - public INDArray getVarCache() { + public INDArray getVarCache(DataType dataType) { INDArray ret; - if(Nd4j.dataType() == DataType.HALF){ + if(dataType == DataType.HALF){ INDArray vc = varCache.castTo(DataType.HALF); ret = vc.mul(vc).rdivi(1.0).subi(eps); } else { ret = varCache.mul(varCache).rdivi(1.0).subi(eps); } - if(Nd4j.dataType() == DataType.HALF){ + if(dataType == DataType.HALF){ //Buffer is FP32 return ret.castTo(DataType.HALF); } diff --git a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnLocalResponseNormalizationHelper.java b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnLocalResponseNormalizationHelper.java index bd700abd00e0..7c47dbd2b402 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnLocalResponseNormalizationHelper.java +++ b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnLocalResponseNormalizationHelper.java @@ -26,6 +26,7 @@ import org.nd4j.jita.allocator.Allocator; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.conf.CudaEnvironment; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.GridExecutioner; import org.nd4j.linalg.api.shape.Shape; @@ -52,6 +53,10 @@ @Slf4j public class CudnnLocalResponseNormalizationHelper extends BaseCudnnHelper implements LocalResponseNormalizationHelper { + public CudnnLocalResponseNormalizationHelper(DataType dataType) { + super(dataType); + } + private static class CudnnLocalResponseNormalizationContext extends CudnnContext { private static class Deallocator extends CudnnLocalResponseNormalizationContext implements Pointer.Deallocator { @@ -152,7 +157,7 @@ public Pair backpropGradient(INDArray input, INDArray epsilo deltaStride[0], deltaStride[1], deltaStride[2], deltaStride[3])); checkCudnn(cudnnSetLRNDescriptor(cudnnContext.lrnDesc, (int) n, alpha, beta, k)); - INDArray nextEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, new int[] {miniBatch, depth, inH, inW}, 'c'); + INDArray nextEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[] {miniBatch, depth, inH, inW}, 'c'); val dstStride = ArrayUtil.toInts(nextEpsilon.stride()); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW, @@ -195,7 +200,7 @@ public INDArray activate(INDArray input, boolean training, double k, double n, d checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW, srcStride[0], 
srcStride[1], srcStride[2], srcStride[3])); - activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new int[] {miniBatch, inDepth, inH, inW}, 'c'); + activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), new long[] {miniBatch, inDepth, inH, inW}, 'c'); val dstStride = ArrayUtil.toInts(activations.stride()); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW, diff --git a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/recurrent/CudnnLSTMHelper.java b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/recurrent/CudnnLSTMHelper.java index 84131d7a8d18..d3f42e0691ef 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/recurrent/CudnnLSTMHelper.java +++ b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/recurrent/CudnnLSTMHelper.java @@ -30,6 +30,7 @@ import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.activations.impl.ActivationTanH; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; @@ -55,6 +56,10 @@ @Slf4j public class CudnnLSTMHelper extends BaseCudnnHelper implements LSTMHelper { + public CudnnLSTMHelper(DataType dataType) { + super(dataType); + } + private static class CudnnLSTMContext extends CudnnContext { private static class Deallocator extends CudnnLSTMContext implements Pointer.Deallocator { @@ -218,7 +223,7 @@ public Pair backpropGradient(final NeuralNetConfiguration co INDArray x = toCOrder(input.permute(2, 0, 1)); INDArray dy = toCOrder(epsilon.permute(2, 0, 1)); - INDArray dx = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, new long[] {timeSeriesLength, miniBatchSize, prevLayerSize}, 'c'); + INDArray dx = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, inputWeights.dataType(), new long[] {timeSeriesLength, miniBatchSize, prevLayerSize}, 'c'); INDArray iwGradientsOut = gradientViews.get(inputWeightKey); INDArray rwGradientsOut = gradientViews.get(recurrentWeightKey); //Order: {I,F,O,G} @@ -394,10 +399,10 @@ public FwdPassReturn activate(final Layer layer, final NeuralNetConfiguration co INDArray prevMemCell = toCOrder(prevMemCellState); INDArray outputActivations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, - new long[] {timeSeriesLength, miniBatchSize, hiddenLayerSize * (BIDIRECTIONAL ? 2 : 1)}, 'c'); - INDArray finalMemCellState = Nd4j.createUninitialized( + inputWeights.dataType(), new long[] {timeSeriesLength, miniBatchSize, hiddenLayerSize * (BIDIRECTIONAL ? 2 : 1)}, 'c'); + INDArray finalMemCellState = Nd4j.createUninitialized( inputWeights.dataType(), new long[] {/*numLayers * (bidirectional ? 2 : 1),*/ miniBatchSize, hiddenLayerSize}, 'c'); - INDArray finalStepActivations = Nd4j.createUninitialized( + INDArray finalStepActivations = Nd4j.createUninitialized( inputWeights.dataType(), new long[] {/*numLayers * (bidirectional ? 
2 : 1),*/ miniBatchSize, hiddenLayerSize}, 'c'); FwdPassReturn toReturn = new FwdPassReturn(); diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/TestDataTypes.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/TestDataTypes.java index da2ac9cce0aa..693260e5fab7 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/TestDataTypes.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/TestDataTypes.java @@ -79,56 +79,60 @@ public void testDataTypesSimple() throws Exception { Map outMapTrain = new HashMap<>(); Map outMapTest = new HashMap<>(); - for(DataType type : new DataType[]{DataType.HALF, DataType.FLOAT, DataType.DOUBLE}) { - log.info("Starting test: {}", type); - Nd4j.setDataType(type); - assertEquals(type, Nd4j.dataType()); - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .convolutionMode(ConvolutionMode.Same) - .activation(Activation.TANH) - .seed(12345) - .weightInit(WeightInit.XAVIER) - .list() - .layer(new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0).nOut(3).build()) - .layer(new SubsamplingLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0).build()) - .layer(new BatchNormalization()) - .layer(new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0).nOut(3).build()) - .layer(new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)) - .build(); - - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - - - Field f1 = org.deeplearning4j.nn.layers.convolution.ConvolutionLayer.class.getDeclaredField("helper"); - f1.setAccessible(true); - - Field f2 = org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer.class.getDeclaredField("helper"); - f2.setAccessible(true); - - Field f3 = org.deeplearning4j.nn.layers.normalization.BatchNormalization.class.getDeclaredField("helper"); - f3.setAccessible(true); - - assertNotNull(f1.get(net.getLayer(0))); - assertNotNull(f2.get(net.getLayer(1))); - assertNotNull(f3.get(net.getLayer(2))); - assertNotNull(f1.get(net.getLayer(3))); - - DataSet ds = new MnistDataSetIterator(32, true, 12345).next(); - - //Simple sanity checks: - //System.out.println("STARTING FIT"); - net.fit(ds); - net.fit(ds); - - //System.out.println("STARTING OUTPUT"); - INDArray outTrain = net.output(ds.getFeatures(), false); - INDArray outTest = net.output(ds.getFeatures(), true); - - outMapTrain.put(type, outTrain.castTo(DataType.DOUBLE)); - outMapTest.put(type, outTest.castTo(DataType.DOUBLE)); + for(DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + Nd4j.setDefaultDataTypes(globalDtype, globalDtype); + for(DataType netDType : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + log.info("Starting test: global dtype = {}, net dtype = {}", globalDtype, netDType); + assertEquals(globalDtype, Nd4j.dataType()); + assertEquals(globalDtype, Nd4j.defaultFloatingPointType()); + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(netDType) + .convolutionMode(ConvolutionMode.Same) + .activation(Activation.TANH) + .seed(12345) + .weightInit(WeightInit.XAVIER) + .list() + .layer(new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0).nOut(3).build()) + .layer(new SubsamplingLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0).build()) + .layer(new 
BatchNormalization.Builder().eps(1e-3).build()) + .layer(new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0).nOut(3).build()) + .layer(new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()) + .setInputType(InputType.convolutionalFlat(28, 28, 1)) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + + Field f1 = org.deeplearning4j.nn.layers.convolution.ConvolutionLayer.class.getDeclaredField("helper"); + f1.setAccessible(true); + + Field f2 = org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer.class.getDeclaredField("helper"); + f2.setAccessible(true); + + Field f3 = org.deeplearning4j.nn.layers.normalization.BatchNormalization.class.getDeclaredField("helper"); + f3.setAccessible(true); + + assertNotNull(f1.get(net.getLayer(0))); + assertNotNull(f2.get(net.getLayer(1))); + assertNotNull(f3.get(net.getLayer(2))); + assertNotNull(f1.get(net.getLayer(3))); + + DataSet ds = new MnistDataSetIterator(32, true, 12345).next(); + + //Simple sanity checks: + //System.out.println("STARTING FIT"); + net.fit(ds); + net.fit(ds); + + //System.out.println("STARTING OUTPUT"); + INDArray outTrain = net.output(ds.getFeatures(), false); + INDArray outTest = net.output(ds.getFeatures(), true); + + outMapTrain.put(netDType, outTrain.castTo(DataType.DOUBLE)); + outMapTest.put(netDType, outTest.castTo(DataType.DOUBLE)); + } } Nd4j.setDataType(DataType.DOUBLE); diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/ValidateCuDNN.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/ValidateCuDNN.java index 81987017f7fd..ce1dfc9e09f2 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/ValidateCuDNN.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/ValidateCuDNN.java @@ -27,11 +27,13 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.util.CuDNNValidationUtil; +import org.junit.BeforeClass; import org.junit.Test; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationELU; import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.activations.impl.ActivationSoftmax; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; import org.nd4j.linalg.dataset.DataSet; @@ -60,9 +62,10 @@ public void validateConvLayers() { int channels = 3; IActivation activation = new ActivationIdentity(); MultiLayerConfiguration multiLayerConfiguration = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .weightInit(WeightInit.XAVIER).seed(42) .activation(new ActivationELU()) - .updater(new Nesterovs(1e-2, 0.9)) + .updater(new Nesterovs(1e-3, 0.9)) .list( new Convolution2D.Builder().nOut(96) .kernelSize(11, 11).biasInit(0.0) @@ -132,6 +135,7 @@ public void validateConvLayersSimpleBN() { int channels = 3; IActivation activation = new ActivationIdentity(); MultiLayerConfiguration multiLayerConfiguration = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .weightInit(WeightInit.XAVIER).seed(42) .activation(new ActivationELU()) .updater(Nesterovs.builder() @@ -187,6 +191,7 @@ public void validateConvLayersLRN() { int channels = 3; IActivation activation = new ActivationIdentity(); MultiLayerConfiguration multiLayerConfiguration = new 
NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .weightInit(WeightInit.XAVIER).seed(42) .activation(new ActivationELU()) .updater(Nesterovs.builder() diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/TestConvolution.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/TestConvolution.java index dc96d33e3f94..7195b6d5aa34 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/TestConvolution.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/TestConvolution.java @@ -43,6 +43,7 @@ import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.io.ClassPathResource; @@ -198,8 +199,9 @@ public void validateXceptionImport() throws Exception { int inSize = 256; ComputationGraph model = KerasModelImport.importKerasModelAndWeights( fExtracted.getAbsolutePath(), new int[]{inSize, inSize, 3}, false); + model = model.convertDataType(DataType.DOUBLE); - INDArray in = Nd4j.rand(new int[]{1, 3, inSize, inSize}); + INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, 3, inSize, inSize}); CuDNNTestUtils.assertHelpersPresent(model.getLayers()); Map withCudnn = model.feedForward(in, false); @@ -261,7 +263,7 @@ public void testGradientNorm() throws Exception { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(seed) - + .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0.0, 0.01)) .activation(Activation.RELU) .updater(new Adam(5e-3)) diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java index b4afadb36afa..ede6bb276fa0 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java @@ -86,6 +86,7 @@ public void testGradientCNNMLN() { Activation outputActivation = outputActivations[i]; MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).updater(new NoOp()) .weightInit(WeightInit.XAVIER).seed(12345L).list() .layer(0, new ConvolutionLayer.Builder(1, 1).nOut(6).activation(afn) @@ -171,6 +172,7 @@ public void testGradientCNNL1L2MLN() { double l1 = l1vals[k]; MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .l2(l2).l1(l1).l2Bias(biasL2[k]).l1Bias(biasL1[k]) .optimizationAlgo( OptimizationAlgorithm.CONJUGATE_GRADIENT) @@ -258,6 +260,7 @@ public void testCnnWithSpaceToDepth() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1)) .list().layer(new ConvolutionLayer.Builder(kernel).nIn(inputDepth).hasBias(false) @@ -321,6 +324,7 @@ public void testCnnWithSpaceToBatch() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1)) .list().layer(new ConvolutionLayer.Builder(kernel).nIn(inputDepth) @@ -387,6 +391,7 @@ public void 
testCnnWithUpsampling() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1)) .list().layer(new ConvolutionLayer.Builder(kernel, @@ -456,7 +461,7 @@ public void testCnnWithSubsampling() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()) - + .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 1)) .list().layer(0, new ConvolutionLayer.Builder(kernel, @@ -528,6 +533,7 @@ public void testCnnWithSubsamplingV2() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()) + .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 1)) .list().layer(0, new ConvolutionLayer.Builder(kernel, @@ -592,6 +598,7 @@ public void testCnnMultiLayer() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new NoOp()) + .dataType(DataType.DOUBLE) .activation(afn) .list() .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1) @@ -659,6 +666,7 @@ public void testCnnSamePaddingMode() { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(Activation.TANH).convolutionMode(Same).list() .layer(0, new ConvolutionLayer.Builder().name("layer 0").kernelSize(k, k) @@ -733,6 +741,7 @@ public void testCnnSamePaddingModeStrided() { .stride(stride, stride).padding(0, 0).build(); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(Activation.TANH).convolutionMode(Same).list() .layer(0, convFirst ? convLayer : poolLayer) @@ -796,6 +805,7 @@ public void testCnnZeroPaddingLayer() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()) + .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 1)).list() .layer(0, new ConvolutionLayer.Builder(kernel, stride, padding) .cudnnAllowFallback(false) @@ -874,6 +884,7 @@ public void testDeconvolution2D() { } NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(act) .list() @@ -951,6 +962,7 @@ public void testDepthwiseConv2D() { } NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(Activation.TANH) .convolutionMode(cm) @@ -1024,6 +1036,7 @@ public void testSeparableConv2D() { } NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(Activation.TANH) .convolutionMode(cm) @@ -1099,6 +1112,7 @@ public void testCnnDilated() { } NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(Activation.TANH).convolutionMode(cm).list() .layer(new ConvolutionLayer.Builder().name("layer 0") @@ -1176,6 +1190,7 @@ public void testCropping2DLayer() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()) .convolutionMode(ConvolutionMode.Same) .weightInit(new NormalDistribution(0, 1)).list() diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CuDNNGradientChecks.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CuDNNGradientChecks.java index 
3ef0b25ed4ea..79e147cf09e1 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CuDNNGradientChecks.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CuDNNGradientChecks.java @@ -109,6 +109,7 @@ public void testConvolutional() throws Exception { } MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT) .dist(new UniformDistribution(-1, 1)) .updater(new NoOp()).seed(12345L).list() @@ -202,6 +203,7 @@ public void testConvolutionalNoBias() throws Exception { } MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .dist(new UniformDistribution(-1, 1)) .updater(new NoOp()).seed(12345L) .list() @@ -265,6 +267,7 @@ public void testBatchNormCnn() throws Exception { } MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp()) + .dataType(DataType.DOUBLE) .seed(12345L) .dist(new NormalDistribution(0, 2)).list() .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2) @@ -323,6 +326,7 @@ public void testLRN() throws Exception { } MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp()) + .dataType(DataType.DOUBLE) .seed(12345L) .dist(new NormalDistribution(0, 2)).list() .layer(0, new ConvolutionLayer.Builder().nOut(6).kernelSize(2, 2).stride(1, 1) @@ -380,6 +384,7 @@ public void testLSTM() throws Exception { } MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()).seed(12345L) .dist(new NormalDistribution(0, 2)).list() .layer(0, new LSTM.Builder().nIn(input.size(1)).nOut(lstmLayerSize) @@ -437,6 +442,7 @@ public void testLSTM2() throws Exception { } MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .updater(new NoOp()).seed(12345L) .dist(new NormalDistribution(0, 2)).list() .layer(0, new LSTM.Builder().nIn(input.size(1)).nOut(lstmLayerSize) @@ -513,6 +519,7 @@ public void testCnnDilated() throws Exception { } NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder().seed(12345) + .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(Activation.TANH).convolutionMode(cm).list() .layer(new ConvolutionLayer.Builder().name("layer 0") @@ -589,7 +596,7 @@ public void testDropout() { NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder() .seed(12345) - + .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 1)) .convolutionMode(ConvolutionMode.Same) .dropOut(dropout) @@ -612,15 +619,6 @@ public void testDropout() { MultiLayerNetwork mln = new MultiLayerNetwork(conf); mln.init(); - for (Layer l : mln.getLayers()) { - Dropout d = (Dropout) l.conf().getLayer().getIDropout(); - assertNotNull(d); - CudnnDropoutHelper h = (CudnnDropoutHelper) d.getHelper(); - assertNotNull(h); - } - - String msg = (cnn ? 
"CNN" : "Dense") + ": " + dropout.getClass().getSimpleName(); - INDArray f; if (cnn) { f = Nd4j.rand(new int[]{minibatch, 3, 8, 8}).muli(10).subi(5); @@ -629,6 +627,17 @@ public void testDropout() { } INDArray l = TestUtils.randomOneHot(minibatch, 10); + mln.output(f, true); + + for (Layer layer : mln.getLayers()) { + Dropout d = (Dropout) layer.conf().getLayer().getIDropout(); + assertNotNull(d); + CudnnDropoutHelper h = (CudnnDropoutHelper) d.getHelper(); + assertNotNull(h); + } + + String msg = (cnn ? "CNN" : "Dense") + ": " + dropout.getClass().getSimpleName(); + //Consumer function to enforce CuDNN RNG repeatability - otherwise will fail due to randomness (inconsistent // dropout mask between forward passes) Consumer c = new Consumer() { @@ -661,6 +670,7 @@ public void testDenseBatchNorm(){ MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) .seed(12345) .weightInit(WeightInit.XAVIER) .updater(new NoOp()) diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnDropout.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnDropout.java index 7b2846ec351a..3edf564b5738 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnDropout.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnDropout.java @@ -19,6 +19,7 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.layers.dropout.CudnnDropoutHelper; import org.junit.Test; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.factory.Nd4j; @@ -38,7 +39,7 @@ public void testCudnnDropoutSimple() { double pRetain = 0.25; double valueIfKept = 1.0 / pRetain; - CudnnDropoutHelper d = new CudnnDropoutHelper(); + CudnnDropoutHelper d = new CudnnDropoutHelper(DataType.DOUBLE); INDArray out = Nd4j.createUninitialized(shape); d.applyDropout(in, out, pRetain); diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnLSTM.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnLSTM.java index a4ef6926219a..08b57aa65a6f 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnLSTM.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnLSTM.java @@ -28,6 +28,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.Test; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -142,6 +143,7 @@ public void validateImplMultiLayer() throws Exception { } MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()) + .dataType(DataType.DOUBLE) .inferenceWorkspaceMode(WorkspaceMode.NONE).trainingWorkspaceMode(WorkspaceMode.NONE) .seed(12345L) .dist(new NormalDistribution(0, 2)).list() @@ -201,7 +203,7 @@ public void validateImplMultiLayer() throws Exception { mln1.computeGradientAndScore(); mln2.computeGradientAndScore(); - assertEquals(mln1.score(), mln2.score(), 1e-8); + assertEquals(mln1.score(), mln2.score(), 1e-5); assertEquals(mln1.getFlattenedGradients(), mln2.getFlattenedGradients()); diff --git 
a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/PermutePreprocessor.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/PermutePreprocessor.java index 5c08f5cc8017..6f2250b13c41 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/PermutePreprocessor.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/PermutePreprocessor.java @@ -44,7 +44,7 @@ public class PermutePreprocessor extends BaseInputPreProcessor { private int[] permutationIndices; private boolean hasLeadingDimension = false; - public PermutePreprocessor(@JsonProperty("permutationIndices") int[] permutationIndices) { + public PermutePreprocessor(@JsonProperty("permutationIndices") int... permutationIndices) { this.permutationIndices = permutationIndices; } @@ -91,7 +91,7 @@ public InputType getOutputType(InputType inputType) throws InvalidInputTypeExcep } else if (inputType instanceof InputType.InputTypeRecurrent) { InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent) inputType; return InputType.recurrent(it.getTimeSeriesLength(), it.getSize()); - } else if (inputType instanceof InputType.InputTypeFeedForward) { + } else if (inputType instanceof InputType.InputTypeFeedForward || inputType instanceof InputType.InputTypeConvolutional3D) { return inputType; } else { throw new InvalidInputTypeException("Unsupported Input type " + inputType); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java index fd2596a2d0ad..3a0ca1410d13 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java @@ -762,7 +762,7 @@ private static INDArray[] getPredictions(Hdf5Archive archive, boolean tensorFlow } private static void compareINDArrays(String label, INDArray a, INDArray b, double eps) { - INDArray diff = a.sub(b); + INDArray diff = a.sub(b.castTo(a.dataType())); double min = diff.minNumber().doubleValue(); double max = diff.maxNumber().doubleValue(); log.info(label + ": " + a.equalsWithEps(b, eps) + ", " + min + ", " + max); @@ -772,7 +772,7 @@ private static void compareINDArrays(String label, INDArray a, INDArray b, doubl // skip too small absolute inputs if (Math.abs(aAbsMax) > threshold && Math.abs(bAbsMax) > threshold) { - assertTrue(a.equalsWithEps(b, eps)); + assertTrue(a.equalsWithEps(b.castTo(a.dataType()), eps)); } } diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTree.java index e63f59346875..3509279f235a 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTree.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTree.java @@ -407,7 
+407,7 @@ public Thread newThread(Runnable r) { Nd4j.getWorkspaceManager().getAndActivateWorkspace(workspaceConfiguration, "VPTREE_WORSKPACE"); int randomPoint = MathUtils.randomNumberBetween(0, items.rows() - 1, Nd4j.getRandom()); - INDArray basePoint = items.getRow(randomPoint); + INDArray basePoint = items.getRow(randomPoint, true); INDArray distancesArr = Nd4j.create(items.rows(), 1); ret.point = basePoint; ret.index = randomPoint; @@ -428,10 +428,10 @@ public Thread newThread(Runnable r) { continue; if (distancesArr.getDouble(i) < medianDistance) { - leftPoints.add(items.getRow(i)); + leftPoints.add(items.getRow(i, true)); leftIndices.add(i); } else { - rightPoints.add(items.getRow(i)); + rightPoints.add(items.getRow(i, true)); rightIndices.add(i); } } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTable.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTable.java index bdb3634afa42..5be56964cf6f 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTable.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTable.java @@ -534,7 +534,7 @@ public INDArray vector(String word) { if (idx < 0) return null; } - return syn0.getRow(idx); + return syn0.getRow(idx, true); } @Override diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java index 14ddc9641bd5..3fb328484866 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java @@ -297,7 +297,7 @@ public void fit() { val randArray = Nd4j.rand(new int[] {1, configuration.getLayersSize()}, rng).subi(0.5) .divi(configuration.getLayersSize()); - lookupTable.getWeights().getRow(realElement.getIndex()).assign(randArray); + lookupTable.getWeights().getRow(realElement.getIndex(), true).assign(randArray); realElement.setInit(true); } } @@ -313,7 +313,7 @@ public void fit() { INDArray randArray = Nd4j.rand(new int[] {1, configuration.getLayersSize()}, rng).subi(0.5) .divi(configuration.getLayersSize()); - lookupTable.getWeights().getRow(realElement.getIndex()).assign(randArray); + lookupTable.getWeights().getRow(realElement.getIndex(), true).assign(randArray); realElement.setInit(true); } } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java index 29676e93cd2b..ec4bf99fb993 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java @@ -250,7 +250,7 @@ public void testMinibatchPadding() throws Exception { INDArray 
expF = Nd4j.vstack(expEx0, expEx1, zeros); INDArray expM = Nd4j.vstack(expM0, expM1, zeros); INDArray expL = Nd4j.createFromArray(new float[][]{{1, 0}, {0, 1}, {0, 0}, {0, 0}}); - INDArray expLM = Nd4j.create(Nd4j.defaultFloatingPointType(), 4, 1); + INDArray expLM = Nd4j.create(DataType.FLOAT, 4, 1); expLM.putScalar(0, 0, 1); expLM.putScalar(1, 0, 1); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java index 96bccc72d225..09d59068027d 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java @@ -186,7 +186,7 @@ public void testParagraphVectorsModelling1() throws Exception { File fullFile = File.createTempFile("paravec", "tests"); fullFile.deleteOnExit(); - INDArray originalSyn1_17 = ((InMemoryLookupTable) vec.getLookupTable()).getSyn1().getRow(17).dup(); + INDArray originalSyn1_17 = ((InMemoryLookupTable) vec.getLookupTable()).getSyn1().getRow(17, true).dup(); WordVectorSerializer.writeParagraphVectors(vec, fullFile); @@ -322,7 +322,7 @@ public void testParagraphVectorsModelling1() throws Exception { ParagraphVectors restoredVectors = WordVectorSerializer.readParagraphVectors(fullFile); restoredVectors.setTokenizerFactory(t); - INDArray restoredSyn1_17 = ((InMemoryLookupTable) restoredVectors.getLookupTable()).getSyn1().getRow(17).dup(); + INDArray restoredSyn1_17 = ((InMemoryLookupTable) restoredVectors.getLookupTable()).getSyn1().getRow(17, true).dup(); assertEquals(originalSyn1_17, restoredSyn1_17); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java index 4f9144138d6c..70d106a593bb 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java @@ -200,6 +200,18 @@ public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, doub + "DataTypeUtil.setDTypeForContext(DataType.DOUBLE); before using GradientCheckUtil"); } + DataType netDataType = mln.getLayerWiseConfigurations().getDataType(); + if (netDataType != DataType.DOUBLE) { + throw new IllegalStateException("Cannot perform gradient check: Network datatype is not set to double precision (" + + "is: " + netDataType + "). Double precision must be used for gradient checks. Create network with .dataType(DataType.DOUBLE) before using GradientCheckUtil"); + } + + if(netDataType != mln.params().dataType()){ + throw new IllegalStateException("Parameters datatype does not match network configuration datatype (" + + "is: " + mln.params().dataType() + "). 
If network datatype is set to DOUBLE, parameters must also be DOUBLE."); + } + + //Check network configuration: int layerCount = 0; for (NeuralNetConfiguration n : mln.getLayerWiseConfigurations().getConfs()) { @@ -469,6 +481,17 @@ public static boolean checkGradients(ComputationGraph graph, double epsilon, dou + "DataTypeUtil.setDTypeForContext(DataType.DOUBLE); before using GradientCheckUtil"); } + DataType netDataType = graph.getConfiguration().getDataType(); + if (netDataType != DataType.DOUBLE) { + throw new IllegalStateException("Cannot perform gradient check: Network datatype is not set to double precision (" + + "is: " + netDataType + "). Double precision must be used for gradient checks. Create network with .dataType(DataType.DOUBLE) before using GradientCheckUtil"); + } + + if(netDataType != graph.params().dataType()){ + throw new IllegalStateException("Parameters datatype does not match network configuration datatype (" + + "is: " + graph.params().dataType() + "). If network datatype is set to DOUBLE, parameters must also be DOUBLE."); + } + //Check configuration int layerCount = 0; for (String vertexName : graph.getConfiguration().getVertices().keySet()) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java index b3fad2ef1181..4955677d0ed4 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java @@ -37,6 +37,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.shade.jackson.databind.JsonNode; import org.nd4j.shade.jackson.databind.ObjectMapper; import org.slf4j.Logger; @@ -81,7 +82,11 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { @Setter protected CacheMode cacheMode; - protected boolean validateOutputLayerConfig = true; //Default for 10.0.-beta3 and earlier nets + @Getter + @Setter + protected DataType dataType = DataType.FLOAT; //Default to float for 1.0.0-beta3 and earlier nets + + protected boolean validateOutputLayerConfig = true; //Default for 1.0.0-beta3 and earlier nets /** * List of inputs to the network, by name @@ -315,6 +320,7 @@ public ComputationGraphConfiguration clone() { conf.cacheMode = this.cacheMode; conf.defaultConfiguration.cacheMode = this.cacheMode; conf.validateOutputLayerConfig = this.validateOutputLayerConfig; + conf.dataType = this.dataType; return conf; } @@ -1120,6 +1126,7 @@ private ComputationGraphConfiguration buildConfig(){ conf.inferenceWorkspaceMode = globalConfiguration.inferenceWorkspaceMode; conf.cacheMode = globalConfiguration.cacheMode; conf.validateOutputLayerConfig = validateOutputConfig; + conf.dataType = globalConfiguration.dataType; conf.defaultConfiguration = globalConfiguration.build(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java index ebc8c4b7d8d6..de3373323ec3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java +++ 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java @@ -32,6 +32,7 @@ import org.deeplearning4j.util.OutputLayerUtil; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT; @@ -76,6 +77,10 @@ public class MultiLayerConfiguration implements Serializable, Cloneable { @Setter protected CacheMode cacheMode; + @Getter + @Setter + protected DataType dataType = DataType.FLOAT; //Default to float for deserialization of beta3 and earlier nets + //Counter for the number of parameter updates so far // This is important for learning rate schedules, for example, and is stored here to ensure it is persisted // for Spark and model serialization @@ -369,6 +374,7 @@ public MultiLayerConfiguration clone() { clone.trainingWorkspaceMode = this.trainingWorkspaceMode; clone.cacheMode = this.cacheMode; clone.validateOutputLayerConfig = this.validateOutputLayerConfig; + clone.dataType = this.dataType; return clone; @@ -454,6 +460,7 @@ public static class Builder { protected CacheMode cacheMode = CacheMode.NONE; protected boolean validateOutputConfig = true; protected boolean validateTbpttConfig = true; + protected DataType dataType; /** * Specify the processors. @@ -595,6 +602,15 @@ public Builder validateTbpttConfig(boolean validate){ return this; } + /** + * Set the DataType for the network parameters and activations for all layers in the network. Default: Float + * @param dataType Datatype to use for parameters and activations + */ + public Builder dataType(@NonNull DataType dataType){ + this.dataType = dataType; + return this; + } + public MultiLayerConfiguration build() { //Validate BackpropType setting @@ -680,6 +696,7 @@ public MultiLayerConfiguration build() { conf.trainingWorkspaceMode = trainingWorkspaceMode; conf.inferenceWorkspaceMode = inferenceWorkspaceMode; conf.cacheMode = cacheMode; + conf.dataType = dataType; Nd4j.getRandom().setSeed(conf.getConf(0).getSeed()); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java index a9b485edab4a..d8e77a55a4c9 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java @@ -51,6 +51,7 @@ import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationSigmoid; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.Sgd; @@ -98,6 +99,8 @@ public class NeuralNetConfiguration implements Serializable, Cloneable { // this field defines preOutput cache protected CacheMode cacheMode; + protected DataType dataType = DataType.FLOAT; //Default to float for deserialization of legacy format nets + //Counter for the number of parameter updates so far for this layer. //Note that this is only used for pretrain layers (AE, VAE) - MultiLayerConfiguration and ComputationGraphConfiguration //contain counters for standard backprop training. 
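The dataType(DataType) builder method added above is the main user-facing API in this change set. The sketch below is illustrative only and is not part of the patch; the class name and layer sizes are hypothetical, and it assumes the builder/field additions shown in this diff plus the floating-point-only precondition added to NeuralNetConfiguration.Builder.

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class NetworkDataTypeExample {
    public static void main(String[] args) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .dataType(DataType.FLOAT)   //New in this change: per-network dtype for parameters and activations
                .list()
                .layer(new DenseLayer.Builder().nIn(784).nOut(100).activation(Activation.RELU).build())
                .layer(new OutputLayer.Builder().nIn(100).nOut(10)
                        .activation(Activation.SOFTMAX)
                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        //With this change, the parameters should be FLOAT regardless of the Nd4j global default dtype
        System.out.println(net.params().dataType());
    }
}

Note that gradient checks become stricter with this patch: per the GradientCheckUtil additions earlier in the diff, the configuration must be built with .dataType(DataType.DOUBLE) and the parameters must actually be DOUBLE, otherwise an IllegalStateException is thrown.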
@@ -261,6 +264,7 @@ public MultiLayerConfiguration build() { .tBPTTBackwardLength(tbpttBackLength).setInputType(this.inputType) .trainingWorkspaceMode(wsmTrain).cacheMode(globalConfig.cacheMode) .inferenceWorkspaceMode(wsmTest).confs(list).validateOutputLayerConfig(validateOutputConfig) + .dataType(globalConfig.dataType) .build(); } @@ -496,6 +500,7 @@ public static class Builder implements Cloneable { protected boolean setTWM = false; protected boolean setIWM = false; protected CacheMode cacheMode = CacheMode.NONE; + protected DataType dataType = DataType.FLOAT; protected ConvolutionMode convolutionMode = ConvolutionMode.Truncate; protected ConvolutionLayer.AlgoMode cudnnAlgoMode = ConvolutionLayer.AlgoMode.PREFER_FASTEST; @@ -1152,6 +1157,18 @@ public Builder constrainWeights(LayerConstraint... constraints) { return this; } + + /** + * Set the DataType for the network parameters and activations. Must be a floating point type: {@link DataType#DOUBLE}, + * {@link DataType#FLOAT} or {@link DataType#HALF}.
+ */ + public Builder dataType(@NonNull DataType dataType){ + Preconditions.checkState(dataType == DataType.DOUBLE || dataType == DataType.FLOAT || dataType == DataType.HALF, + "Data type must be a floating point type: one of DOUBLE, FLOAT, or HALF. Got datatype: %s", dataType); + this.dataType = dataType; + return this; + } + /** * Return a configuration based on this builder * @@ -1168,6 +1185,7 @@ public NeuralNetConfiguration build() { conf.stepFunction = stepFunction; conf.miniBatch = miniBatch; conf.cacheMode = this.cacheMode; + conf.dataType = this.dataType; configureLayer(layer); if (layer instanceof FrozenLayer) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/AlphaDropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/AlphaDropout.java index c94fcc8ebc03..c48048fa75fa 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/AlphaDropout.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/AlphaDropout.java @@ -133,7 +133,7 @@ public INDArray applyDropout(INDArray inputActivations, INDArray output, int ite } lastPValue = pValue; - mask = workspaceMgr.createUninitialized(ArrayType.INPUT, output.shape(), output.ordering()); + mask = workspaceMgr.createUninitialized(ArrayType.INPUT, output.dataType(), output.shape(), output.ordering()); Nd4j.getExecutioner().exec(new BernoulliDistribution(mask, pValue)); //a * (x * d + alphaPrime * (1-d)) + b diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java index fd3f6e3c8694..899b4382a482 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp; import org.nd4j.linalg.api.ops.random.impl.DropOutInverted; @@ -72,6 +73,7 @@ public class Dropout implements IDropout { private ISchedule pSchedule; private transient INDArray mask; private transient DropoutHelper helper; + private boolean initializedHelper = false; /** * @param activationRetainProbability Probability of retaining an activation - see {@link Dropout} javadoc @@ -97,18 +99,17 @@ public Dropout(ISchedule activationRetainProbabilitySchedule){ protected Dropout(@JsonProperty("p") double activationRetainProbability, @JsonProperty("pSchedule") ISchedule activationRetainProbabilitySchedule) { this.p = activationRetainProbability; this.pSchedule = activationRetainProbabilitySchedule; - initializeHelper(); } /** * Initialize the CuDNN dropout helper, if possible */ - protected void initializeHelper(){ + protected void initializeHelper(DataType dataType){ String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); if("CUDA".equalsIgnoreCase(backend)) { try { helper = Class.forName("org.deeplearning4j.nn.layers.dropout.CudnnDropoutHelper") - .asSubclass(DropoutHelper.class).newInstance(); + .asSubclass(DropoutHelper.class).getConstructor(DataType.class).newInstance(dataType); 
log.debug("CudnnDropoutHelper successfully initialized"); if (!helper.checkSupported()) { helper = null; @@ -121,6 +122,7 @@ protected void initializeHelper(){ // benefit from them cudnn, they will get a warning from those } } + initializedHelper = true; } @@ -135,6 +137,10 @@ public INDArray applyDropout(INDArray inputActivations, INDArray output, int ite currP = p; } + if(!initializedHelper){ + initializeHelper(output.dataType()); + } + if(helper != null){ helper.applyDropout(inputActivations, output, p); return output; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianDropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianDropout.java index 09331f161b9a..4f364ee362fd 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianDropout.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianDropout.java @@ -85,7 +85,7 @@ public INDArray applyDropout(INDArray inputActivations, INDArray output, int ite double stdev = Math.sqrt(r / (1.0 - r)); - noise = workspaceMgr.createUninitialized(ArrayType.INPUT, inputActivations.shape(), inputActivations.ordering()); + noise = workspaceMgr.createUninitialized(ArrayType.INPUT, output.dataType(), inputActivations.shape(), inputActivations.ordering()); Nd4j.getExecutioner().exec(new GaussianDistribution(noise, 1.0, stdev)); return Nd4j.getExecutioner().exec(new OldMulOp(inputActivations, noise, output)); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianNoise.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianNoise.java index a45bc4221549..89b5edf6cdd3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianNoise.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianNoise.java @@ -66,7 +66,7 @@ public INDArray applyDropout(INDArray inputActivations, INDArray output, int ite currS = stddev; } - INDArray noise = Nd4j.createUninitialized(inputActivations.shape(), inputActivations.ordering()); + INDArray noise = Nd4j.createUninitialized(output.dataType(), inputActivations.shape(), inputActivations.ordering()); Nd4j.getExecutioner().exec(new GaussianDistribution(noise, 0, currS)); Nd4j.getExecutioner().exec(new OldAddOp(inputActivations, noise, output)); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/SpatialDropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/SpatialDropout.java index a2985c911f8c..aee73dbfa872 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/SpatialDropout.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/SpatialDropout.java @@ -97,7 +97,7 @@ public INDArray applyDropout(INDArray inputActivations, INDArray output, int ite val minibatch = inputActivations.size(0); val dim1 = inputActivations.size(1); - mask = workspaceMgr.createUninitialized(ArrayType.INPUT, minibatch, dim1).assign(1.0); + mask = workspaceMgr.createUninitialized(ArrayType.INPUT, output.dataType(), minibatch, dim1).assign(1.0); Nd4j.getExecutioner().exec(new DropOutInverted(mask, currP)); Broadcast.mul(inputActivations, mask, output, 0, 1); diff --git 
a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertex.java index 3c5c2975b33a..6033f0030b00 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertex.java @@ -23,6 +23,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonProperty; @@ -96,7 +97,7 @@ public int maxVertexInputs() { @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, - INDArray paramsView, boolean initializeParams) { + INDArray paramsView, boolean initializeParams, DataType networkDatatype) { org.deeplearning4j.nn.graph.vertex.impl.ElementWiseVertex.Op op; switch (this.op) { case Add: @@ -117,7 +118,7 @@ public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGra default: throw new RuntimeException(); } - return new org.deeplearning4j.nn.graph.vertex.impl.ElementWiseVertex(graph, name, idx, op); + return new org.deeplearning4j.nn.graph.vertex.impl.ElementWiseVertex(graph, name, idx, op, networkDatatype); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/FrozenVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/FrozenVertex.java index 72755d64e3bc..5059395bb8c6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/FrozenVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/FrozenVertex.java @@ -16,18 +16,14 @@ package org.deeplearning4j.nn.conf.graph; -import lombok.AllArgsConstructor; import lombok.Data; import lombok.EqualsAndHashCode; -import org.deeplearning4j.nn.api.TrainingConfig; -import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.learning.config.IUpdater; -import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.shade.jackson.annotation.JsonProperty; /** @@ -70,8 +66,8 @@ public int maxVertexInputs() { } @Override - public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams) { - org.deeplearning4j.nn.graph.vertex.GraphVertex u = underlying.instantiate(graph, name, idx, paramsView, initializeParams); + public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams, DataType networkDatatype) { + org.deeplearning4j.nn.graph.vertex.GraphVertex u = underlying.instantiate(graph, name, idx, paramsView, initializeParams, networkDatatype); return new org.deeplearning4j.nn.graph.vertex.impl.FrozenVertex(u); } diff --git 
a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/GraphVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/GraphVertex.java index 9f457a674d94..b7eb50dcc41f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/GraphVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/GraphVertex.java @@ -16,16 +16,13 @@ package org.deeplearning4j.nn.conf.graph; -import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex; -import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex; -import org.deeplearning4j.nn.conf.graph.rnn.ReverseTimeSeriesVertex; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyGraphVertexDeserializerHelper; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.shade.jackson.annotation.JsonSubTypes; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import java.io.Serializable; @@ -70,10 +67,11 @@ public abstract class GraphVertex implements Cloneable, Serializable { * @param idx The index of the GraphVertex * @param paramsView A view of the full parameters array * @param initializeParams If true: initialize the parameters. If false: make no change to the values in the paramsView array + * @param networkDatatype The datatype to use for the network parameters and activations * @return The implementation GraphVertex object (i.e., implementation, not the configuration) */ public abstract org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, - int idx, INDArray paramsView, boolean initializeParams); + int idx, INDArray paramsView, boolean initializeParams, DataType networkDatatype); /** * Determine the type of output for this GraphVertex, given the specified inputs. 
Given that a GraphVertex may do arbitrary diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/L2NormalizeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/L2NormalizeVertex.java index 22c5d5aaa0f4..f299ca3b6a61 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/L2NormalizeVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/L2NormalizeVertex.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonProperty; @@ -78,9 +79,9 @@ public int maxVertexInputs() { @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, - INDArray paramsView, boolean initializeParams) { + INDArray paramsView, boolean initializeParams, DataType networkDatatype) { - return new org.deeplearning4j.nn.graph.vertex.impl.L2NormalizeVertex(graph, name, idx, dimension, eps); + return new org.deeplearning4j.nn.graph.vertex.impl.L2NormalizeVertex(graph, name, idx, dimension, eps, networkDatatype); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/L2Vertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/L2Vertex.java index 41a98d448e2d..53fe591bb744 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/L2Vertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/L2Vertex.java @@ -23,6 +23,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; /** @@ -86,8 +87,8 @@ public int hashCode() { @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, - INDArray paramsView, boolean initializeParams) { - return new org.deeplearning4j.nn.graph.vertex.impl.L2Vertex(graph, name, idx, null, null, eps); + INDArray paramsView, boolean initializeParams, DataType networkDatatype) { + return new org.deeplearning4j.nn.graph.vertex.impl.L2Vertex(graph, name, idx, null, null, eps, networkDatatype); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/LayerVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/LayerVertex.java index 2d504fca4b0e..4e15e36172cb 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/LayerVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/LayerVertex.java @@ -25,6 +25,7 @@ import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Arrays; @@ -98,19 +99,19 @@ public int maxVertexInputs() { @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String 
name, int idx, - INDArray paramsView, boolean initializeParams) { + INDArray paramsView, boolean initializeParams, DataType networkDatatype) { //Now, we need to work out if this vertex is an output vertex or not... boolean isOutput = graph.getConfiguration().getNetworkOutputs().contains(name); org.deeplearning4j.nn.api.Layer layer = - layerConf.getLayer().instantiate(layerConf, null, idx, paramsView, initializeParams); + layerConf.getLayer().instantiate(layerConf, null, idx, paramsView, initializeParams, networkDatatype); if(layer == null) { throw new IllegalStateException("Encountered null layer during initialization for layer:" + layerConf.getLayer().getClass().getSimpleName() + " initialization returned null layer?"); } - return new org.deeplearning4j.nn.graph.vertex.impl.LayerVertex(graph, name, idx, layer, preProcessor, isOutput); + return new org.deeplearning4j.nn.graph.vertex.impl.LayerVertex(graph, name, idx, layer, preProcessor, isOutput, networkDatatype); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java index 149931c86f3d..b8c79192f5f4 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; /** A MergeVertex is used to combine the activations of two or more layers/GraphVertex by means of concatenation/merging.
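All of the configuration-side graph vertices in this patch receive the same treatment: instantiate(...) gains a trailing DataType parameter and forwards it to the runtime vertex it constructs, so the network-wide data type is fixed at vertex-creation time. As a rough sketch of what the equivalent update looks like for a user-defined vertex configuration (DoublingVertex is hypothetical and not part of this patch; it simply reuses the built-in runtime ScaleVertex, whose new constructor appears in the ScaleVertex hunk below):

    import org.deeplearning4j.nn.conf.graph.ScaleVertex;
    import org.deeplearning4j.nn.graph.ComputationGraph;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;

    // Hypothetical custom vertex configuration, shown only to illustrate the new signature.
    public class DoublingVertex extends ScaleVertex {

        public DoublingVertex() {
            super(2.0);   // scale all incoming activations by 2
        }

        @Override
        public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx,
                        INDArray paramsView, boolean initializeParams, DataType networkDatatype) {
            // The extra argument is passed straight through to the runtime vertex,
            // matching the pattern used by every built-in vertex in this patch.
            return new org.deeplearning4j.nn.graph.vertex.impl.ScaleVertex(graph, name, idx, 2.0, networkDatatype);
        }
    }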
@@ -74,8 +75,8 @@ public String toString() { @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, - INDArray paramsView, boolean initializeParams) { - return new org.deeplearning4j.nn.graph.vertex.impl.MergeVertex(graph, name, idx); + INDArray paramsView, boolean initializeParams, DataType networkDatatype) { + return new org.deeplearning4j.nn.graph.vertex.impl.MergeVertex(graph, name, idx, networkDatatype); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/PoolHelperVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/PoolHelperVertex.java index 81f90077d21c..6e2213b4e258 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/PoolHelperVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/PoolHelperVertex.java @@ -23,6 +23,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; /** @@ -66,8 +67,8 @@ public int maxVertexInputs() { @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, - INDArray paramsView, boolean initializeParams) { - return new org.deeplearning4j.nn.graph.vertex.impl.PoolHelperVertex(graph, name, idx); + INDArray paramsView, boolean initializeParams, DataType networkDatatype) { + return new org.deeplearning4j.nn.graph.vertex.impl.PoolHelperVertex(graph, name, idx, networkDatatype); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/PreprocessorVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/PreprocessorVertex.java index 68caa13d4930..90f905c60cbd 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/PreprocessorVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/PreprocessorVertex.java @@ -25,6 +25,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; /** @@ -80,8 +81,8 @@ public int maxVertexInputs() { @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, - INDArray paramsView, boolean initializeParams) { - return new org.deeplearning4j.nn.graph.vertex.impl.PreprocessorVertex(graph, name, idx, preProcessor); + INDArray paramsView, boolean initializeParams, DataType networkDatatype) { + return new org.deeplearning4j.nn.graph.vertex.impl.PreprocessorVertex(graph, name, idx, preProcessor, networkDatatype); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/ReshapeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/ReshapeVertex.java index 089e048ea5ce..76eeef8af85f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/ReshapeVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/ReshapeVertex.java @@ -23,6 +23,7 @@ import 
org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonProperty; @@ -99,8 +100,8 @@ public int maxVertexInputs() { @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, - INDArray paramsView, boolean initializeParams) { - return new org.deeplearning4j.nn.graph.vertex.impl.ReshapeVertex(graph, name, idx, reshapeOrder, newShape, maskShape); + INDArray paramsView, boolean initializeParams, DataType networkDatatype) { + return new org.deeplearning4j.nn.graph.vertex.impl.ReshapeVertex(graph, name, idx, reshapeOrder, newShape, maskShape, networkDatatype); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/ScaleVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/ScaleVertex.java index 2ba968b5946c..1859586c69ff 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/ScaleVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/ScaleVertex.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonProperty; @@ -78,9 +79,9 @@ public int maxVertexInputs() { @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, - INDArray paramsView, boolean initializeParams) { + INDArray paramsView, boolean initializeParams, DataType networkDatatype) { - return new org.deeplearning4j.nn.graph.vertex.impl.ScaleVertex(graph, name, idx, scaleFactor); + return new org.deeplearning4j.nn.graph.vertex.impl.ScaleVertex(graph, name, idx, scaleFactor, networkDatatype); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/ShiftVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/ShiftVertex.java index fe9114fda6ee..b50f16d5e4d0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/ShiftVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/ShiftVertex.java @@ -25,6 +25,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonProperty; @@ -79,9 +80,9 @@ public int maxVertexInputs() { @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, - INDArray paramsView, boolean initializeParams) { + INDArray paramsView, boolean initializeParams, DataType networkDatatype) { - return new org.deeplearning4j.nn.graph.vertex.impl.ShiftVertex(graph, name, idx, shiftFactor); + return new org.deeplearning4j.nn.graph.vertex.impl.ShiftVertex(graph, name, idx, shiftFactor, networkDatatype); } @Override diff --git 
a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/StackVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/StackVertex.java index 8f1888f2b233..d5e91493881e 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/StackVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/StackVertex.java @@ -17,13 +17,13 @@ package org.deeplearning4j.nn.conf.graph; -import lombok.val; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; /** @@ -70,8 +70,8 @@ public int hashCode() { @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, - INDArray paramsView, boolean initializeParams) { - return new org.deeplearning4j.nn.graph.vertex.impl.StackVertex(graph, name, idx); + INDArray paramsView, boolean initializeParams, DataType networkDatatype) { + return new org.deeplearning4j.nn.graph.vertex.impl.StackVertex(graph, name, idx, networkDatatype); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/SubsetVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/SubsetVertex.java index 0e7a75c1828d..cb3692af2b86 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/SubsetVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/SubsetVertex.java @@ -23,6 +23,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonProperty; @@ -88,8 +89,8 @@ public int maxVertexInputs() { @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, - INDArray paramsView, boolean initializeParams) { - return new org.deeplearning4j.nn.graph.vertex.impl.SubsetVertex(graph, name, idx, from, to); + INDArray paramsView, boolean initializeParams, DataType networkDatatype) { + return new org.deeplearning4j.nn.graph.vertex.impl.SubsetVertex(graph, name, idx, from, to, networkDatatype); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/UnstackVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/UnstackVertex.java index 2511db936352..a5d6c72f438f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/UnstackVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/UnstackVertex.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonProperty; @@ -53,8 +54,8 @@ 
public UnstackVertex(@JsonProperty("from") int from, @JsonProperty("stackSize") @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, - INDArray paramsView, boolean initializeParams) { - return new org.deeplearning4j.nn.graph.vertex.impl.UnstackVertex(graph, name, idx, null, null, from, stackSize); + INDArray paramsView, boolean initializeParams, DataType networkDatatype) { + return new org.deeplearning4j.nn.graph.vertex.impl.UnstackVertex(graph, name, idx, null, null, from, stackSize, networkDatatype); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/rnn/DuplicateToTimeSeriesVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/rnn/DuplicateToTimeSeriesVertex.java index 4db77d0f6a1a..4560fd1f92b0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/rnn/DuplicateToTimeSeriesVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/rnn/DuplicateToTimeSeriesVertex.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonProperty; @@ -90,8 +91,8 @@ public int maxVertexInputs() { @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, - INDArray paramsView, boolean initializeParams) { - return new org.deeplearning4j.nn.graph.vertex.impl.rnn.DuplicateToTimeSeriesVertex(graph, name, idx, inputName); + INDArray paramsView, boolean initializeParams, DataType networkDatatype) { + return new org.deeplearning4j.nn.graph.vertex.impl.rnn.DuplicateToTimeSeriesVertex(graph, name, idx, inputName, networkDatatype); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/rnn/LastTimeStepVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/rnn/LastTimeStepVertex.java index 2ddd1e4e1a52..b95dccc9b7ae 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/rnn/LastTimeStepVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/rnn/LastTimeStepVertex.java @@ -23,6 +23,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonProperty; @@ -88,8 +89,8 @@ public int maxVertexInputs() { @Override public org.deeplearning4j.nn.graph.vertex.impl.rnn.LastTimeStepVertex instantiate(ComputationGraph graph, - String name, int idx, INDArray paramsView, boolean initializeParams) { - return new org.deeplearning4j.nn.graph.vertex.impl.rnn.LastTimeStepVertex(graph, name, idx, maskArrayInputName); + String name, int idx, INDArray paramsView, boolean initializeParams, DataType networkDatatype) { + return new org.deeplearning4j.nn.graph.vertex.impl.rnn.LastTimeStepVertex(graph, name, idx, maskArrayInputName, networkDatatype); } @Override diff --git 
a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/rnn/ReverseTimeSeriesVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/rnn/ReverseTimeSeriesVertex.java index 4a88be9fd7ca..eb4fccb19be8 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/rnn/ReverseTimeSeriesVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/rnn/ReverseTimeSeriesVertex.java @@ -23,6 +23,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; /** @@ -89,8 +90,9 @@ public int maxVertexInputs() { return 1; } - public org.deeplearning4j.nn.graph.vertex.impl.rnn.ReverseTimeSeriesVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams) { - return new org.deeplearning4j.nn.graph.vertex.impl.rnn.ReverseTimeSeriesVertex(graph, name, idx, maskArrayInputName); + public org.deeplearning4j.nn.graph.vertex.impl.rnn.ReverseTimeSeriesVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, + boolean initializeParams, DataType networkDatatype) { + return new org.deeplearning4j.nn.graph.vertex.impl.rnn.ReverseTimeSeriesVertex(graph, name, idx, maskArrayInputName, networkDatatype); } public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ActivationLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ActivationLayer.java index 5039e9dc0a02..c4aaadc9836a 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ActivationLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ActivationLayer.java @@ -28,6 +28,7 @@ import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; @@ -72,8 +73,8 @@ public ActivationLayer clone() { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { - org.deeplearning4j.nn.layers.ActivationLayer ret = new org.deeplearning4j.nn.layers.ActivationLayer(conf); + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + org.deeplearning4j.nn.layers.ActivationLayer ret = new org.deeplearning4j.nn.layers.ActivationLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AutoEncoder.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AutoEncoder.java index 3573ba9d7a77..69788c5a3748 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AutoEncoder.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AutoEncoder.java @@ -25,6 +25,7 @@ import 
org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.PretrainParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; @@ -52,9 +53,9 @@ private AutoEncoder(Builder builder) { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder ret = - new org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder(conf); + new org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java index 0ccb58e0b425..53c00acac627 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java @@ -29,6 +29,7 @@ import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.NoOp; @@ -85,11 +86,11 @@ public BatchNormalization clone() { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNOutSet("BatchNormalization", getLayerName(), layerIndex, getNOut()); org.deeplearning4j.nn.layers.normalization.BatchNormalization ret = - new org.deeplearning4j.nn.layers.normalization.BatchNormalization(conf); + new org.deeplearning4j.nn.layers.normalization.BatchNormalization(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java index f3c9173cb837..34213038a918 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.params.CenterLossParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.activations.impl.ActivationSoftmax; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.NoOp; @@ -66,10 +67,10 @@ protected CenterLossOutputLayer(Builder builder) { 
@Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("CenterLossOutputLayer", getLayerName(), layerIndex, getNIn(), getNOut()); - Layer ret = new org.deeplearning4j.nn.layers.training.CenterLossOutputLayer(conf); + Layer ret = new org.deeplearning4j.nn.layers.training.CenterLossOutputLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java index c28f499151fd..b73265763ac7 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java @@ -27,9 +27,9 @@ import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.lossfunctions.ILossFunction; -import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import java.util.Collection; import java.util.Map; @@ -72,9 +72,9 @@ private Cnn3DLossLayer(Builder builder) { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.convolution.Cnn3DLossLayer ret = - new org.deeplearning4j.nn.layers.convolution.Cnn3DLossLayer(conf); + new org.deeplearning4j.nn.layers.convolution.Cnn3DLossLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java index 698275ed3ee3..7b25ff797e97 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java @@ -30,6 +30,7 @@ import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; @@ -69,9 +70,9 @@ private CnnLossLayer(Builder builder) { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.convolution.CnnLossLayer ret = - new org.deeplearning4j.nn.layers.convolution.CnnLossLayer(conf); + new 
org.deeplearning4j.nn.layers.convolution.CnnLossLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java index dabba7d1c0d8..b65f94d00664 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java @@ -27,13 +27,12 @@ import org.deeplearning4j.util.Convolution1DUtils; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; import java.util.Map; -import static org.deeplearning4j.nn.conf.layers.InputTypeUtil.getOutputTypeCnnLayers; - /** * 1D (temporal) convolutional layer. This layer accepts RNN InputTypes instead of CNN InputTypes * @@ -60,12 +59,12 @@ private Convolution1DLayer(Builder builder) { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("Convolution1DLayer", getLayerName(), layerIndex, getNIn(), getNOut()); org.deeplearning4j.nn.layers.convolution.Convolution1DLayer ret = - new org.deeplearning4j.nn.layers.convolution.Convolution1DLayer(conf); + new org.deeplearning4j.nn.layers.convolution.Convolution1DLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java index 00adacaeb370..61475bf98a33 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java @@ -29,7 +29,7 @@ import org.deeplearning4j.util.Convolution3DUtils; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; -import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; @@ -95,10 +95,10 @@ public Convolution3D clone() { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection iterationListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("Convolution3D", getLayerName(), layerIndex, getNIn(), getNOut()); - Convolution3DLayer ret = new Convolution3DLayer(conf); + Convolution3DLayer ret = new Convolution3DLayer(conf, networkDataType); ret.setListeners(iterationListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java index 76dffbf51836..3d2e35d248f7 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java @@ -30,6 +30,7 @@ import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; @@ -161,11 +162,11 @@ public ConvolutionLayer clone() { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("ConvolutionLayer", getLayerName(), layerIndex, getNIn(), getNOut()); org.deeplearning4j.nn.layers.convolution.ConvolutionLayer ret = - new org.deeplearning4j.nn.layers.convolution.ConvolutionLayer(conf); + new org.deeplearning4j.nn.layers.convolution.ConvolutionLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java index 3f6953025206..03b6ec405bb4 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java @@ -29,6 +29,7 @@ import org.deeplearning4j.nn.params.DeconvolutionParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; @@ -84,11 +85,11 @@ public Deconvolution2D clone() { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("Deconvolution2D", getLayerName(), layerIndex, getNIn(), getNOut()); org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer ret = - new org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer(conf); + new org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java index f4e4f56c4566..67cac076d11b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java @@ -25,6 +25,7 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.DefaultParamInitializer; import 
org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; @@ -52,11 +53,11 @@ private DenseLayer(Builder builder) { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("DenseLayer", getLayerName(), layerIndex, getNIn(), getNOut()); org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer ret = - new org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer(conf); + new org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java index ebe4b7219d71..03fec1191d4f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java @@ -26,6 +26,8 @@ import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.*; @@ -49,6 +51,7 @@ public class DepthwiseConvolution2D extends ConvolutionLayer { protected DepthwiseConvolution2D(Builder builder) { super(builder); + Preconditions.checkState(builder.depthMultiplier > 0, "Depth multiplier must be > 0, got %s", builder.depthMultiplier); this.depthMultiplier = builder.depthMultiplier; this.nOut = this.nIn * this.depthMultiplier; @@ -65,10 +68,10 @@ public DepthwiseConvolution2D clone() { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("DepthwiseConvolution2D", getLayerName(), layerIndex, getNIn(), getNOut()); - DepthwiseConvolution2DLayer ret = new DepthwiseConvolution2DLayer(conf); + DepthwiseConvolution2DLayer ret = new DepthwiseConvolution2DLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); @@ -95,6 +98,14 @@ public InputType getOutputType(int layerIndex, InputType inputType) { nOut, layerIndex, getLayerName(), DepthwiseConvolution2DLayer.class); } + @Override + public void setNIn(InputType inputType, boolean override) { + super.setNIn(inputType, override); + + if(nOut == 0 || override){ + nOut = this.nIn * this.depthMultiplier; + } + } @Getter @Setter diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DropoutLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DropoutLayer.java index 6c44236ae238..94cad0a9804f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DropoutLayer.java +++ 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DropoutLayer.java @@ -27,6 +27,7 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.regularization.Regularization; @@ -51,6 +52,14 @@ private DropoutLayer(Builder builder) { super(builder); } + public DropoutLayer(double activationRetainProb){ + this(new Builder().dropOut(activationRetainProb)); + } + + public DropoutLayer(IDropout dropout){ + this(new Builder().dropOut(dropout)); + } + @Override public DropoutLayer clone() { return (DropoutLayer) super.clone(); @@ -58,9 +67,9 @@ public DropoutLayer clone() { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { - org.deeplearning4j.nn.layers.DropoutLayer ret = new org.deeplearning4j.nn.layers.DropoutLayer(conf); + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { + org.deeplearning4j.nn.layers.DropoutLayer ret = new org.deeplearning4j.nn.layers.DropoutLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java index 3226a4d34bfa..0227fee23e42 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java @@ -29,6 +29,7 @@ import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer; import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; @@ -63,9 +64,9 @@ private EmbeddingLayer(Builder builder) { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer ret = - new org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer(conf); + new org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java index 2b20a86e01c2..3e5766af208c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java @@ -19,6 +19,7 @@ import lombok.*; import org.deeplearning4j.nn.api.Layer; 
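Besides the signature change, DropoutLayer picks up two convenience constructors (one taking a retain probability, one taking an IDropout instance), so a standalone dropout layer no longer has to go through its Builder. A minimal usage sketch, with layer sizes chosen arbitrarily for illustration:

    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.layers.DenseLayer;
    import org.deeplearning4j.nn.conf.layers.DropoutLayer;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .list()
            .layer(new DenseLayer.Builder().nIn(784).nOut(100).activation(Activation.RELU).build())
            .layer(new DropoutLayer(0.8))   // new constructor: retain 80% of activations
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nIn(100).nOut(10).activation(Activation.SOFTMAX).build())
            .build();

As the constructor bodies above show, new DropoutLayer(0.8) is equivalent to configuring the same layer via new DropoutLayer.Builder().dropOut(0.8).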
import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; @@ -29,6 +30,7 @@ import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer; import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; @@ -66,9 +68,9 @@ private EmbeddingSequenceLayer(Builder builder) { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingSequenceLayer ret = - new org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingSequenceLayer(conf); + new org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingSequenceLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); @@ -80,9 +82,9 @@ public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.pooling.GlobalPoolingLayer ret = - new org.deeplearning4j.nn.layers.pooling.GlobalPoolingLayer(conf); + new org.deeplearning4j.nn.layers.pooling.GlobalPoolingLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.java index ecbdb739df17..d7aa869a1ac7 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.java @@ -29,6 +29,7 @@ import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationSigmoid; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.*; @@ -81,9 +82,9 @@ protected void initializeConstraints(org.deeplearning4j.nn.conf.layers.Layer.Bui @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.recurrent.GravesBidirectionalLSTM ret = - new org.deeplearning4j.nn.layers.recurrent.GravesBidirectionalLSTM(conf); + new org.deeplearning4j.nn.layers.recurrent.GravesBidirectionalLSTM(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesLSTM.java 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesLSTM.java index 2f74ce4dd6cb..60fe918a8543 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesLSTM.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesLSTM.java @@ -28,6 +28,7 @@ import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationSigmoid; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.ArrayList; @@ -79,10 +80,10 @@ protected void initializeConstraints(org.deeplearning4j.nn.conf.layers.Layer.Bui @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("GravesLSTM", getLayerName(), layerIndex, getNIn(), getNOut()); org.deeplearning4j.nn.layers.recurrent.GravesLSTM ret = - new org.deeplearning4j.nn.layers.recurrent.GravesLSTM(conf); + new org.deeplearning4j.nn.layers.recurrent.GravesLSTM(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LSTM.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LSTM.java index 10bb0c7a7f6a..684f0df712ee 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LSTM.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LSTM.java @@ -28,6 +28,7 @@ import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationSigmoid; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.ArrayList; @@ -75,9 +76,9 @@ protected void initializeConstraints(org.deeplearning4j.nn.conf.layers.Layer.Bui @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("LSTM", getLayerName(), layerIndex, getNIn(), getNOut()); - org.deeplearning4j.nn.layers.recurrent.LSTM ret = new org.deeplearning4j.nn.layers.recurrent.LSTM(conf); + org.deeplearning4j.nn.layers.recurrent.LSTM ret = new org.deeplearning4j.nn.layers.recurrent.LSTM(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java index a4b210ff0484..0dcd121d7838 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java @@ -16,7 +16,6 @@ package org.deeplearning4j.nn.conf.layers; -import jdk.nashorn.internal.objects.annotations.Property; import lombok.Data; import lombok.Getter; import 
lombok.NoArgsConstructor; @@ -32,6 +31,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyLayerDeserializerHelper; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.regularization.Regularization; @@ -149,8 +149,8 @@ public Layer clone() { } public abstract org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams); + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType); /** * @return The parameter initializer for this model diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java index 6e5a9d7f0ce8..dfc2df9c857d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.regularization.Regularization; @@ -67,10 +68,10 @@ public LocalResponseNormalization clone() { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization ret = - new org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization(conf); + new org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java index d2ccb77fbb7b..adede664adc7 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java @@ -30,6 +30,7 @@ import org.nd4j.autodiff.samediff.SDIndex; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; @@ -128,6 +129,10 @@ public void setNIn(InputType inputType, boolean override) { InputType.InputTypeRecurrent c = (InputType.InputTypeRecurrent) inputType; this.nIn = c.getSize(); } + if(featureDim <= 0 || override){ + InputType.InputTypeRecurrent c = 
(InputType.InputTypeRecurrent) inputType; + this.featureDim = kernel * (int) c.getSize(); + } } @Override @@ -137,6 +142,7 @@ public InputPreProcessor getPreProcessorForInputType(InputType inputType) { @Override public void defineParameters(SDLayerParams params) { + Preconditions.checkState(featureDim > 0, "Cannot initialize layer: Feature dimension is set to %s", featureDim); params.clear(); val weightsShape = new long[] {outputSize, featureDim, nOut}; params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape); @@ -164,12 +170,8 @@ public void initializeParameters(Map params) { @Override public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map paramTable, SDVariable mask) { - SDVariable w = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY); // (outH, featureDim, nOut) - // System.out.println(Arrays.toString(w.getShape())); - long[] inputShape = layerInput.getShape(); - long miniBatch = inputShape[0]; int outH = outputSize; int sH = stride; int kH = kernel; @@ -180,14 +182,11 @@ public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, MapemptyMap(), "out"); -// System.out.println(Arrays.toString(concatOutput.getShape())); SDVariable mmulResult = sameDiff.mmul(concatOutput, w); // (outH, miniBatch, nOut) -// System.out.println(Arrays.toString(mmulResult.getShape())); SDVariable result = sameDiff.permute(mmulResult, 1, 2, 0); // (miniBatch, nOut, outH) diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java index 031702dbcc7b..f96ddf1eea73 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java @@ -128,6 +128,7 @@ public void setNIn(InputType inputType, boolean override) { if (nIn <= 0 || override) { InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; this.nIn = c.getChannels(); + this.featureDim = kernel[0] * kernel[1] * (int) nIn; } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LossLayer.java index f9ad012252d0..e26dbd83fea1 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LossLayer.java @@ -29,6 +29,7 @@ import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions; @@ -59,8 +60,8 @@ protected LossLayer(Builder builder) { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { - org.deeplearning4j.nn.layers.LossLayer ret = new org.deeplearning4j.nn.layers.LossLayer(conf); + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + org.deeplearning4j.nn.layers.LossLayer ret = new org.deeplearning4j.nn.layers.LossLayer(conf, networkDataType); 
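To make the new contract explicit: after the Layer.java hunk above, the abstract method that every layer configuration must implement reads (generic type written out, assuming the pre-existing Collection<TrainingListener> parameter is unchanged apart from the added argument):

    public abstract org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
                    Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
                    boolean initializeParams, DataType networkDataType);

Every concrete layer configuration in this patch satisfies it by constructing its runtime layer with (conf, networkDataType), as the surrounding hunks show, so custom Layer subclasses maintained outside this repository need the same one-line signature update.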
ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java index 10563387e468..75d86460598d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.activations.impl.ActivationSoftmax; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; @@ -52,10 +53,10 @@ protected OutputLayer(Builder builder) { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("OutputLayer", getLayerName(), layerIndex, getNIn(), getNOut()); - org.deeplearning4j.nn.layers.OutputLayer ret = new org.deeplearning4j.nn.layers.OutputLayer(conf); + org.deeplearning4j.nn.layers.OutputLayer ret = new org.deeplearning4j.nn.layers.OutputLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java index 2fb1d07f2f39..db9a19eccd68 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java @@ -24,9 +24,10 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; -import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.params.PReLUParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; @@ -62,8 +63,8 @@ private PReLULayer(Builder builder) { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { - org.deeplearning4j.nn.layers.feedforward.PReLU ret = new org.deeplearning4j.nn.layers.feedforward.PReLU(conf); + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + org.deeplearning4j.nn.layers.feedforward.PReLU ret = new org.deeplearning4j.nn.layers.feedforward.PReLU(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java index 2124a8910216..209c61bcab6c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java @@ -29,6 +29,7 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions; @@ -62,9 +63,9 @@ private RnnLossLayer(Builder builder) { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.recurrent.RnnLossLayer ret = - new org.deeplearning4j.nn.layers.recurrent.RnnLossLayer(conf); + new org.deeplearning4j.nn.layers.recurrent.RnnLossLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java index c1a960fc23e5..078673f5d717 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java @@ -28,6 +28,7 @@ import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.activations.impl.ActivationSoftmax; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; @@ -57,11 +58,11 @@ private RnnOutputLayer(Builder builder) { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("RnnOutputLayer", getLayerName(), layerIndex, getNIn(), getNOut()); org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer ret = - new org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer(conf); + new org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java index a7ac62962109..181cc53111e0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java @@ -27,6 +27,7 @@ import 
org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.*; @@ -125,11 +126,11 @@ public SeparableConvolution2D clone() { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("SeparableConvolution2D", getLayerName(), layerIndex, getNIn(), getNOut()); org.deeplearning4j.nn.layers.convolution.SeparableConvolution2DLayer ret = - new org.deeplearning4j.nn.layers.convolution.SeparableConvolution2DLayer(conf); + new org.deeplearning4j.nn.layers.convolution.SeparableConvolution2DLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java index 3783943e8149..5d946f2c7c5f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; @@ -79,10 +80,10 @@ public SpaceToBatchLayer clone() { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.convolution.SpaceToBatch ret = - new org.deeplearning4j.nn.layers.convolution.SpaceToBatch(conf); + new org.deeplearning4j.nn.layers.convolution.SpaceToBatch(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); @@ -196,7 +197,7 @@ public Builder(int[] blocks, int[][] padding) { * @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and width * dimensions */ - public T blocks(int[] blocks) { + public T blocks(int... 
blocks) { this.setBlocks(blocks); return (T) this; } @@ -218,6 +219,8 @@ public T name(String layerName) { @Override @SuppressWarnings("unchecked") public SpaceToBatchLayer build() { + if(padding == null) + setPadding(new int[][] {{0, 0}, {0, 0}}); return new SpaceToBatchLayer(this); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java index f96fc9afa08e..aeca265f83cb 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java @@ -25,6 +25,7 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; @@ -76,10 +77,10 @@ public SpaceToDepthLayer clone() { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.convolution.SpaceToDepth ret = - new org.deeplearning4j.nn.layers.convolution.SpaceToDepth(conf); + new org.deeplearning4j.nn.layers.convolution.SpaceToDepth(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java index 7a380709368a..de491290fa88 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java @@ -27,6 +27,7 @@ import org.deeplearning4j.util.Convolution1DUtils; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; @@ -60,10 +61,10 @@ private Subsampling1DLayer(Builder builder) { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.convolution.subsampling.Subsampling1DLayer ret = - new org.deeplearning4j.nn.layers.convolution.subsampling.Subsampling1DLayer(conf); + new org.deeplearning4j.nn.layers.convolution.subsampling.Subsampling1DLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java index 
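With the varargs blocks(...) setter and the zero-padding fallback added to build() above, a SpaceToBatch layer no longer needs an explicit padding array. A hedged sketch; the no-argument Builder constructor is assumed accessible here (if it is not, the blocks-only constructor serves the same purpose):

import org.deeplearning4j.nn.conf.layers.SpaceToBatchLayer;

public class SpaceToBatchSketch {
    public static void main(String[] args) {
        SpaceToBatchLayer stb = new SpaceToBatchLayer.Builder()   // assumed accessible constructor
                .blocks(2, 2)    // varargs form now accepted
                .build();        // padding defaults to {{0, 0}, {0, 0}} when not set
        System.out.println(stb);
    }
}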
0ec770fd31e2..877e216dab4c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java @@ -25,12 +25,12 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; -import org.deeplearning4j.optimize.api.IterationListener; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.Convolution3DUtils; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.regularization.Regularization; @@ -113,10 +113,10 @@ public Subsampling3DLayer clone() { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection iterationListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { + Collection iterationListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.convolution.subsampling.Subsampling3DLayer ret = - new org.deeplearning4j.nn.layers.convolution.subsampling.Subsampling3DLayer(conf); + new org.deeplearning4j.nn.layers.convolution.subsampling.Subsampling3DLayer(conf, networkDataType); ret.setListeners(iterationListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); @@ -443,6 +443,11 @@ public T poolingType(PoolingType poolingType) { return (T) this; } + public T poolingType(org.deeplearning4j.nn.conf.layers.PoolingType poolingType){ + this.setPoolingType(poolingType); + return (T) this; + } + public T dilation(int dDepth, int dHeight, int dWidth) { this.setDilation(new int[] {dDepth, dHeight, dWidth}); return (T) this; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java index 003b9c5d69cd..9eff7a91a528 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java @@ -28,6 +28,7 @@ import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; @@ -117,10 +118,10 @@ public SubsamplingLayer clone() { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer ret = - new org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer(conf); + new org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); @@ -339,7 +340,8 @@ 
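Subsampling3DLayer.Builder above, and SubsamplingLayer.Builder just below, gain an overload that accepts the shared org.deeplearning4j.nn.conf.layers.PoolingType enum directly. A short sketch for the 2D case; kernel and stride values are arbitrary:

import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;

public class PoolingTypeOverloadSketch {
    public static void main(String[] args) {
        SubsamplingLayer pool = new SubsamplingLayer.Builder()
                .kernelSize(2, 2)
                .stride(2, 2)
                .poolingType(PoolingType.MAX)   // conf-level enum, no manual conversion needed
                .build();
        System.out.println(pool);
    }
}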
public void setPadding(int... padding) { this.padding = ValidationUtils.validate2NonNegative(padding,false, "padding"); } - public void setDilation(int... dilation) { + + public void setDilation(int[] dilation) { this.dilation = ValidationUtils.validate2NonNegative(dilation, false, "dilation"); } } @@ -454,6 +456,11 @@ public T poolingType(PoolingType poolingType) { return (T) this; } + public T poolingType(org.deeplearning4j.nn.conf.layers.PoolingType poolingType){ + this.setPoolingType(poolingType); + return (T) this; + } + public T pnorm(int pnorm) { this.setPnorm(pnorm); return (T) this; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java index f5e13aff18fe..1e1fcfab17c9 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java @@ -29,6 +29,7 @@ import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; @@ -64,10 +65,10 @@ protected Upsampling1D(UpsamplingBuilder builder) { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling1D ret = - new org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling1D(conf); + new org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling1D(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java index c6e056854e10..12bdfc53b0e6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java @@ -25,7 +25,7 @@ import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyIntArrayDeserializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; -import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; @@ -73,10 +73,10 @@ public Upsampling2D clone() { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling2D ret = - new org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling2D(conf); + new org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling2D(conf, 
networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java index 78a486e901cf..c38ab0a534a6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java @@ -22,10 +22,10 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; -import org.deeplearning4j.optimize.api.IterationListener; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; @@ -58,10 +58,10 @@ public Upsampling3D clone() { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection iterationListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { + Collection iterationListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling3D ret = - new org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling3D(conf); + new org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling3D(conf, networkDataType); ret.setListeners(iterationListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java index 75ad0b657059..e888a2904ac9 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java @@ -26,7 +26,7 @@ import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; -import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Arrays; @@ -64,10 +64,10 @@ public ZeroPadding1DLayer(int[] padding) { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.convolution.ZeroPadding1DLayer ret = - new org.deeplearning4j.nn.layers.convolution.ZeroPadding1DLayer(conf); + new org.deeplearning4j.nn.layers.convolution.ZeroPadding1DLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); Map paramTable = initializer().init(conf, layerParamsView, initializeParams); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding3DLayer.java 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding3DLayer.java index 18645332732d..8dfd594a6ace 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding3DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding3DLayer.java @@ -24,9 +24,9 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; -import org.deeplearning4j.optimize.api.IterationListener; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Arrays; @@ -53,10 +53,10 @@ private ZeroPadding3DLayer(Builder builder) { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection iterationListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { + Collection iterationListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.convolution.ZeroPadding3DLayer ret = - new org.deeplearning4j.nn.layers.convolution.ZeroPadding3DLayer(conf); + new org.deeplearning4j.nn.layers.convolution.ZeroPadding3DLayer(conf, networkDataType); ret.setListeners(iterationListeners); ret.setIndex(layerIndex); Map paramTable = initializer().init(conf, layerParamsView, initializeParams); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java index 18395a7886bb..48463b76b147 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java @@ -26,6 +26,7 @@ import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Arrays; @@ -66,10 +67,10 @@ private ZeroPaddingLayer(Builder builder) { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer ret = - new org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer(conf); + new org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); Map paramTable = initializer().init(conf, layerParamsView, initializeParams); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java index 73387ca5f88a..6ae0bae8f4a4 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java +++ 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java @@ -28,9 +28,9 @@ import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; -import java.util.Arrays; import java.util.Collection; import java.util.Map; @@ -75,9 +75,9 @@ protected Cropping1D(Builder builder) { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { - Cropping1DLayer ret = new Cropping1DLayer(conf); + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { + Cropping1DLayer ret = new Cropping1DLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); Map paramTable = initializer().init(conf, layerParamsView, initializeParams); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java index aa547028c781..497bb9a063e1 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java @@ -29,9 +29,9 @@ import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; -import java.util.Arrays; import java.util.Collection; import java.util.Map; @@ -81,9 +81,9 @@ protected Cropping2D(Builder builder) { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { - Cropping2DLayer ret = new Cropping2DLayer(conf); + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { + Cropping2DLayer ret = new Cropping2DLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); Map paramTable = initializer().init(conf, layerParamsView, initializeParams); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java index 5090dda66440..e06b4b4e747b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java @@ -17,7 +17,6 @@ package org.deeplearning4j.nn.conf.layers.convolutional; import lombok.*; -import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -25,16 +24,13 @@ import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.NoParamLayer; import 
org.deeplearning4j.nn.conf.memory.LayerMemoryReport; -import org.deeplearning4j.nn.layers.convolution.Cropping2DLayer; import org.deeplearning4j.nn.layers.convolution.Cropping3DLayer; -import org.deeplearning4j.optimize.api.IterationListener; import org.deeplearning4j.optimize.api.TrainingListener; -import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; -import java.util.Arrays; import java.util.Collection; import java.util.Map; @@ -88,9 +84,9 @@ protected Cropping3D(Builder builder) { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection iterationListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { - Cropping3DLayer ret = new Cropping3DLayer(conf); + Collection iterationListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { + Cropping3DLayer ret = new Cropping3DLayer(conf, networkDataType); ret.setListeners(iterationListeners); ret.setIndex(layerIndex); Map paramTable = initializer().init(conf, layerParamsView, initializeParams); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java index dbf66601a043..ef88dc8b7ac2 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.ElementWiseParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; @@ -60,13 +61,13 @@ public ElementWiseMultiplicationLayer clone() { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { if (this.nIn != this.nOut) { throw new IllegalStateException("Element wise layer must have the same input and output size. 
Got nIn=" + nIn + ", nOut=" + nOut); } org.deeplearning4j.nn.layers.feedforward.elementwise.ElementWiseMultiplicationLayer ret = - new org.deeplearning4j.nn.layers.feedforward.elementwise.ElementWiseMultiplicationLayer(conf); + new org.deeplearning4j.nn.layers.feedforward.elementwise.ElementWiseMultiplicationLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java index dc7b5113858f..7b2317a8f5f3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java @@ -30,6 +30,7 @@ import org.deeplearning4j.nn.conf.serde.FrozenLayerDeserializer; import org.deeplearning4j.nn.params.FrozenLayerParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.regularization.Regularization; @@ -77,12 +78,12 @@ public Layer clone() { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { //Need to be able to instantiate a layer, from a config - for JSON -> net type situations org.deeplearning4j.nn.api.Layer underlying = layer.instantiate(getInnerConf(conf), trainingListeners, - layerIndex, layerParamsView, initializeParams); + layerIndex, layerParamsView, initializeParams, networkDataType); NeuralNetConfiguration nncUnderlying = underlying.conf(); if (nncUnderlying.variables() != null) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayerWithBackprop.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayerWithBackprop.java index 2e8611841220..468c310329b7 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayerWithBackprop.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayerWithBackprop.java @@ -17,20 +17,14 @@ package org.deeplearning4j.nn.conf.layers.misc; import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.Getter; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.layers.LayerConstraint; -import org.deeplearning4j.nn.conf.GradientNormalization; -import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; -import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; -import org.deeplearning4j.nn.params.FrozenLayerParamInitializer; import org.deeplearning4j.nn.params.FrozenLayerWithBackpropParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import 
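As the guard above makes explicit, an element-wise multiplication layer only instantiates when nIn equals nOut. A minimal configuration sketch; sizes are arbitrary and the standard feed-forward Builder is assumed:

import org.deeplearning4j.nn.conf.layers.misc.ElementWiseMultiplicationLayer;

public class ElementWiseSketch {
    public static void main(String[] args) {
        ElementWiseMultiplicationLayer ew = new ElementWiseMultiplicationLayer.Builder()
                .nIn(32)
                .nOut(32)    // must equal nIn, otherwise instantiate(...) throws IllegalStateException
                .build();
        System.out.println(ew);
    }
}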
org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.regularization.Regularization; @@ -67,12 +61,12 @@ public Layer clone() { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { //Need to be able to instantiate a layer, from a config - for JSON -> net type situations org.deeplearning4j.nn.api.Layer underlying = getUnderlying().instantiate(getInnerConf(conf), trainingListeners, - layerIndex, layerParamsView, initializeParams); + layerIndex, layerParamsView, initializeParams, networkDataType); NeuralNetConfiguration nncUnderlying = underlying.conf(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java index 2bf4335860bf..424130f17dd4 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java @@ -25,6 +25,7 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; @@ -63,9 +64,9 @@ public ParamInitializer initializer() { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { - org.deeplearning4j.nn.layers.RepeatVector ret = new org.deeplearning4j.nn.layers.RepeatVector(conf); + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { + org.deeplearning4j.nn.layers.RepeatVector ret = new org.deeplearning4j.nn.layers.RepeatVector(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java index 3be638cfb3f0..804fc3d78c0b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java @@ -29,6 +29,7 @@ import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor; import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.lossfunctions.ILossFunction; @@ -95,9 +96,9 @@ private Yolo2OutputLayer(Builder builder) { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) 
{ + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer ret = - new org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer(conf); + new org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java index 70fb7b6ecb3b..4d32f22e5cdc 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java @@ -18,7 +18,6 @@ import lombok.*; import org.deeplearning4j.nn.api.ParamInitializer; -import org.deeplearning4j.nn.api.layers.RecurrentLayer; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -31,6 +30,7 @@ import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer; import org.deeplearning4j.nn.params.BidirectionalParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.regularization.Regularization; @@ -126,19 +126,19 @@ public long getNIn() { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { NeuralNetConfiguration c1 = conf.clone(); NeuralNetConfiguration c2 = conf.clone(); c1.setLayer(fwd); c2.setLayer(bwd); long n = layerParamsView.length() / 2; - INDArray fp = layerParamsView.get(point(0), interval(0, n)); - INDArray bp = layerParamsView.get(point(0), interval(n, 2 * n)); - org.deeplearning4j.nn.api.Layer f = fwd.instantiate(c1, trainingListeners, layerIndex, fp, initializeParams); + INDArray fp = layerParamsView.get(interval(0,0,true), interval(0, n)); + INDArray bp = layerParamsView.get(interval(0,0,true), interval(n, 2 * n)); + org.deeplearning4j.nn.api.Layer f = fwd.instantiate(c1, trainingListeners, layerIndex, fp, initializeParams, networkDataType); - org.deeplearning4j.nn.api.Layer b = bwd.instantiate(c2, trainingListeners, layerIndex, bp, initializeParams); + org.deeplearning4j.nn.api.Layer b = bwd.instantiate(c2, trainingListeners, layerIndex, bp, initializeParams, networkDataType); BidirectionalLayer ret = new BidirectionalLayer(conf, f, b, layerParamsView); Map paramTable = initializer().init(conf, layerParamsView, initializeParams); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.java index d94baa92c332..52c048472f27 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.java +++ 
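The Bidirectional change above swaps point(0) for interval(0, 0, true) when splitting the flattened parameter view, so the forward/backward sub-views keep a rank-2 [1, n] shape rather than collapsing the leading dimension. A standalone ND4J illustration of that difference; the array contents are arbitrary:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

import java.util.Arrays;

public class RowViewSketch {
    public static void main(String[] args) {
        INDArray params = Nd4j.linspace(1, 10, 10).reshape('c', 1, 10);
        long n = params.length() / 2;

        // point(0) can strip the leading dimension from the view...
        INDArray collapsed = params.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, n));
        // ...whereas an inclusive interval over row 0 keeps the [1, n] shape
        INDArray kept = params.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(0, n));

        System.out.println(Arrays.toString(collapsed.shape()));   // may print [5]
        System.out.println(Arrays.toString(kept.shape()));        // [1, 5]
    }
}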
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; @@ -51,12 +52,12 @@ public Layer getUnderlying() { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { NeuralNetConfiguration conf2 = conf.clone(); conf2.setLayer(((LastTimeStep) conf2.getLayer()).getUnderlying()); return new LastTimeStepLayer(underlying.instantiate(conf2, trainingListeners, layerIndex, layerParamsView, - initializeParams)); + initializeParams, networkDataType)); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java index d37bf2074b9e..7bc91c17eb00 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java @@ -29,6 +29,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.params.SimpleRnnParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; @@ -60,11 +61,11 @@ private SimpleRnn() { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("SimpleRnn", getLayerName(), layerIndex, getNIn(), getNOut()); org.deeplearning4j.nn.layers.recurrent.SimpleRnn ret = - new org.deeplearning4j.nn.layers.recurrent.SimpleRnn(conf); + new org.deeplearning4j.nn.layers.recurrent.SimpleRnn(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/AbstractSameDiffLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/AbstractSameDiffLayer.java index 8dff70b763f7..74d5f450eb58 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/AbstractSameDiffLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/AbstractSameDiffLayer.java @@ -33,6 +33,7 @@ import org.deeplearning4j.nn.weights.WeightInitUtil; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.NetworkUtils; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.regularization.L1Regularization; @@ -136,8 +137,8 @@ public 
void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig @Override public abstract org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams); + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType); //================================================================================================================== diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLayer.java index 20e90522c87d..d9655a58f0ab 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLayer.java @@ -27,6 +27,7 @@ import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; @@ -96,10 +97,10 @@ public void validateInput(INDArray input){/* no-op */} @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.samediff.SameDiffLayer ret = - new org.deeplearning4j.nn.layers.samediff.SameDiffLayer(conf); + new org.deeplearning4j.nn.layers.samediff.SameDiffLayer(conf, networkDataType); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(conf, layerParamsView, initializeParams); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffOutputLayer.java index 527043b63fd0..d6c4892f37fd 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffOutputLayer.java @@ -20,6 +20,7 @@ import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; @@ -89,10 +90,10 @@ public boolean labelsRequired() { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { org.deeplearning4j.nn.layers.samediff.SameDiffOutputLayer ret = - new org.deeplearning4j.nn.layers.samediff.SameDiffOutputLayer(conf); + new org.deeplearning4j.nn.layers.samediff.SameDiffOutputLayer(conf, networkDataType); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map 
paramTable = initializer().init(conf, layerParamsView, initializeParams); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffVertex.java index f11db47fd61a..69187f755af0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffVertex.java @@ -29,6 +29,7 @@ import org.deeplearning4j.nn.layers.samediff.SameDiffGraphVertex; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.regularization.Regularization; @@ -124,9 +125,9 @@ public int maxVertexInputs() { @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, - INDArray paramsView, boolean initializeParams) { + INDArray paramsView, boolean initializeParams, DataType networkDatatype) { this.name = name; - return new SameDiffGraphVertex(this, graph, name, idx, paramsView, initializeParams); + return new SameDiffGraphVertex(this, graph, name, idx, paramsView, initializeParams, networkDatatype); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskLayer.java index a27f2e58d676..1eaa661dfc97 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskLayer.java @@ -21,11 +21,11 @@ import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.NoParamLayer; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.regularization.Regularization; @@ -43,9 +43,9 @@ public class MaskLayer extends NoParamLayer { @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { - org.deeplearning4j.nn.layers.util.MaskLayer ret = new org.deeplearning4j.nn.layers.util.MaskLayer(conf); + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { + org.deeplearning4j.nn.layers.util.MaskLayer ret = new org.deeplearning4j.nn.layers.util.MaskLayer(conf, networkDataType); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(conf, layerParamsView, initializeParams); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java index 5c24e8ef8a35..fd9c64ba6643 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java @@ -27,6 +27,7 @@ import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonProperty; @@ -60,14 +61,14 @@ public MaskZeroLayer(@JsonProperty("underlying") Layer underlying, @JsonProperty @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { NeuralNetConfiguration conf2 = conf.clone(); conf2.setLayer(((BaseWrapperLayer) conf2.getLayer()).getUnderlying()); org.deeplearning4j.nn.api.Layer underlyingLayer = - underlying.instantiate(conf2, trainingListeners, layerIndex, layerParamsView, initializeParams); + underlying.instantiate(conf2, trainingListeners, layerIndex, layerParamsView, initializeParams, networkDataType); return new org.deeplearning4j.nn.layers.recurrent.MaskZeroLayer(underlyingLayer, maskingValue); } @@ -120,6 +121,16 @@ public Builder setMaskValue(double maskValue) { return this; } + public Builder underlying(Layer underlying){ + setUnderlying(underlying); + return this; + } + + public Builder maskValue(double maskValue){ + setMaskValue(maskValue); + return this; + } + @Override @SuppressWarnings("unchecked") public MaskZeroLayer build() { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/BernoulliReconstructionDistribution.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/BernoulliReconstructionDistribution.java index cf95e08bd801..509c3bfb1058 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/BernoulliReconstructionDistribution.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/BernoulliReconstructionDistribution.java @@ -99,7 +99,7 @@ public INDArray exampleNegLogProbability(INDArray x, INDArray preOutDistribution } private INDArray calcLogProbArray(INDArray x, INDArray preOutDistributionParams) { - x = x.castTo(Nd4j.defaultFloatingPointType()); + x = x.castTo(preOutDistributionParams.dataType()); INDArray output = preOutDistributionParams.dup(); activationFn.getActivation(output, false); @@ -118,7 +118,7 @@ private INDArray calcLogProbArray(INDArray x, INDArray preOutDistributionParams) public INDArray gradient(INDArray x, INDArray preOutDistributionParams) { INDArray output = preOutDistributionParams.dup(); activationFn.getActivation(output, true); - x = x.castTo(Nd4j.defaultFloatingPointType()); + x = x.castTo(preOutDistributionParams.dataType()); INDArray diff = x.sub(output); INDArray outOneMinusOut = output.rsub(1.0).muli(output); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/CompositeReconstructionDistribution.java 
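The two fluent setters added to MaskZeroLayer.Builder above allow the wrapper to be configured without going through its constructor. A hedged sketch; the no-argument Builder constructor is assumed accessible, and the wrapped LSTM sizes are arbitrary:

import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;

public class MaskZeroBuilderSketch {
    public static void main(String[] args) {
        MaskZeroLayer masked = new MaskZeroLayer.Builder()   // assumed accessible constructor
                .underlying(new LSTM.Builder().nIn(10).nOut(20).build())
                .maskValue(0.0)   // timesteps equal to this value are masked in the underlying layer
                .build();
        System.out.println(masked);
    }
}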
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/CompositeReconstructionDistribution.java index ec40e377c863..8b89371211c0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/CompositeReconstructionDistribution.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/CompositeReconstructionDistribution.java @@ -198,7 +198,7 @@ public INDArray exampleNegLogProbability(INDArray x, INDArray preOutDistribution public INDArray gradient(INDArray x, INDArray preOutDistributionParams) { int inputSoFar = 0; int paramsSoFar = 0; - INDArray gradient = Nd4j.createUninitialized(preOutDistributionParams.shape()); + INDArray gradient = preOutDistributionParams.ulike(); for (int i = 0; i < distributionSizes.length; i++) { int thisInputSize = distributionSizes[i]; int thisParamsSize = reconstructionDistributions[i].distributionInputSize(thisInputSize); @@ -233,7 +233,7 @@ public INDArray generateAtMean(INDArray preOutDistributionParams) { private INDArray randomSample(INDArray preOutDistributionParams, boolean isMean) { int inputSoFar = 0; int paramsSoFar = 0; - INDArray out = Nd4j.createUninitialized(new long[] {preOutDistributionParams.size(0), totalSize}); + INDArray out = Nd4j.createUninitialized(preOutDistributionParams.dataType(), new long[] {preOutDistributionParams.size(0), totalSize}); for (int i = 0; i < distributionSizes.length; i++) { int thisDataSize = distributionSizes[i]; int thisParamsSize = reconstructionDistributions[i].distributionInputSize(thisDataSize); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/GaussianReconstructionDistribution.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/GaussianReconstructionDistribution.java index b4c8c9ca1884..9f41a2ac5538 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/GaussianReconstructionDistribution.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/GaussianReconstructionDistribution.java @@ -140,7 +140,7 @@ public INDArray gradient(INDArray x, INDArray preOutDistributionParams) { INDArray dLdsigma = sigma.rdiv(-1).addi(xSubMeanSq.divi(sigma3)); INDArray dLdlogSigma2 = sigma.divi(2).muli(dLdsigma); - INDArray dLdx = Nd4j.createUninitialized(output.shape()); + INDArray dLdx = Nd4j.createUninitialized(preOutDistributionParams.dataType(), output.shape()); dLdx.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(0, size)}, dLdmu); dLdx.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size)}, dLdlogSigma2); dLdx.negi(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/VariationalAutoencoder.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/VariationalAutoencoder.java index 15560126b3c8..9e545defc0fd 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/VariationalAutoencoder.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/VariationalAutoencoder.java @@ -30,6 +30,7 @@ import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationIdentity; +import 
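The reconstruction-distribution changes above stop assuming Nd4j.defaultFloatingPointType() and instead take shape, order and data type from the pre-output activations (ulike(), createUninitialized(dataType, ...)). The same pattern in isolation; shapes and values are arbitrary:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class DataTypeAwareAllocSketch {
    public static void main(String[] args) {
        INDArray preOut = Nd4j.rand(3, 8).castTo(DataType.FLOAT);

        // same shape, ordering and data type as preOut, contents uninitialized
        INDArray gradient = preOut.ulike();

        // explicit data type instead of the global default floating-point type
        INDArray sample = Nd4j.createUninitialized(preOut.dataType(), preOut.size(0), 4);

        System.out.println(gradient.dataType() + " " + sample.dataType());
    }
}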
org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions; @@ -75,11 +76,11 @@ private VariationalAutoencoder(Builder builder) { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("VariationalAutoencoder", getLayerName(), layerIndex, getNIn(), getNOut()); org.deeplearning4j.nn.layers.variational.VariationalAutoencoder ret = - new org.deeplearning4j.nn.layers.variational.VariationalAutoencoder(conf); + new org.deeplearning4j.nn.layers.variational.VariationalAutoencoder(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java index 7b46c6dd73ea..f76cb0dad0f7 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java @@ -26,6 +26,7 @@ import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationIdentity; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.lossfunctions.ILossFunction; @@ -105,11 +106,11 @@ public ILossFunction getLossFn() { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("OCNNOutputLayer", getLayerName(), layerIndex, getNIn(), getNOut()); org.deeplearning4j.nn.layers.ocnn.OCNNOutputLayer ret = - new org.deeplearning4j.nn.layers.ocnn.OCNNOutputLayer(conf); + new org.deeplearning4j.nn.layers.ocnn.OCNNOutputLayer(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/ComposableInputPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/ComposableInputPreProcessor.java index c09c41b965f5..f41d3368dd61 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/ComposableInputPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/ComposableInputPreProcessor.java @@ -21,6 +21,7 @@ import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.workspace.ArrayType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -45,7 +46,7 @@ public ComposableInputPreProcessor(@JsonProperty("inputPreProcessors") InputPreP public INDArray 
preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { for (InputPreProcessor preProcessor : inputPreProcessors) input = preProcessor.preProcess(input, miniBatchSize, workspaceMgr); - return input; + return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input); } @Override @@ -55,7 +56,7 @@ public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr w for (int i = inputPreProcessors.length - 1; i >= 0; i--) { output = inputPreProcessors[i].backprop(output, miniBatchSize, workspaceMgr); } - return output; + return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/DropConnect.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/DropConnect.java index ec557131ba97..76c8c220c1a3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/DropConnect.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/DropConnect.java @@ -92,7 +92,7 @@ public INDArray getParameter(Layer layer, String paramKey, int iteration, int ep if (train && init.isWeightParam(layer.conf().getLayer(), paramKey) || (applyToBiases && init.isBiasParam(layer.conf().getLayer(), paramKey))) { - INDArray out = workspaceMgr.createUninitialized(ArrayType.INPUT, param.shape(), param.ordering()); + INDArray out = workspaceMgr.createUninitialized(ArrayType.INPUT, param.dataType(), param.shape(), param.ordering()); Nd4j.getExecutioner().exec(new DropOut(param, out, p)); return out; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/WeightNoise.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/WeightNoise.java index 839593db0c56..fea4462477cb 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/WeightNoise.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/WeightNoise.java @@ -82,8 +82,8 @@ public INDArray getParameter(Layer layer, String paramKey, int iteration, int ep (applyToBias && init.isBiasParam(layer.conf().getLayer(), paramKey))) { org.nd4j.linalg.api.rng.distribution.Distribution dist = Distributions.createDistribution(distribution); - INDArray noise = dist.sample(param.shape()); - INDArray out = workspaceMgr.createUninitialized(ArrayType.INPUT, param.shape(), param.ordering()); + INDArray noise = dist.sample(param.ulike()); + INDArray out = workspaceMgr.createUninitialized(ArrayType.INPUT, param.dataType(), param.shape(), param.ordering()); if (additive) { Nd4j.getExecutioner().exec(new OldAddOp(param, noise,out)); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java index 94ceabadee58..c5d59715a567 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java @@ -28,13 +28,13 @@ import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter; import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.nn.api.*; +import org.deeplearning4j.nn.api.Updater; import org.deeplearning4j.nn.api.layers.IOutputLayer; import 
org.deeplearning4j.nn.api.layers.RecurrentLayer; import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; -import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.util.ComputationGraphUtil; @@ -65,6 +65,7 @@ import org.nd4j.evaluation.classification.ROC; import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.evaluation.regression.RegressionEvaluation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.enums.AllocationPolicy; @@ -450,6 +451,19 @@ public void init(INDArray parameters, boolean cloneParametersArray) { if (initCalled) return; + DataType netDtype = getConfiguration().getDataType(); + if(parameters != null && parameters.dataType() != netDtype){ + if(cloneParametersArray){ + try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { + parameters = parameters.castTo(netDtype); + } + } else { + throw new IllegalStateException("Error initializing network: Network datatype is set to " + netDtype + + " but provided array has datatype " + parameters.dataType() + " with cloneParametersArray argument" + + " set to false. Cannot initialize net with specified datatype array if that array does not match network datatype"); + } + } + if (configuration.getTrainingWorkspaceMode() == null) configuration.setTrainingWorkspaceMode(WorkspaceMode.NONE); @@ -482,7 +496,7 @@ public void init(INDArray parameters, boolean cloneParametersArray) { //Create network input vertices: int vertexNumber = 0; for (String name : networkInputNames) { - GraphVertex gv = new InputVertex(this, name, vertexNumber, null); //Output vertices: set later + GraphVertex gv = new InputVertex(this, name, vertexNumber, null, netDtype); //Output vertices: set later allNamesReverse.put(name, vertexNumber); vertices[vertexNumber++] = gv; } @@ -516,7 +530,7 @@ public void init(INDArray parameters, boolean cloneParametersArray) { initializeParams = false; } else if(numParams > 0){ - flattenedParams = Nd4j.create(1, numParams); + flattenedParams = Nd4j.create(netDtype, 1, numParams); initializeParams = true; } else { flattenedParams = null; @@ -536,7 +550,7 @@ public void init(INDArray parameters, boolean cloneParametersArray) { for (int vertexIdx : topologicalOrder) { long nParamsThisVertex = numParamsForVertex[vertexIdx]; if (nParamsThisVertex != 0) { - paramsViewForVertex[vertexIdx] = flattenedParams.get(NDArrayIndex.point(0), + paramsViewForVertex[vertexIdx] = flattenedParams.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(paramOffsetSoFar, paramOffsetSoFar + nParamsThisVertex)); } i++; @@ -554,7 +568,7 @@ public void init(INDArray parameters, boolean cloneParametersArray) { org.deeplearning4j.nn.conf.graph.GraphVertex n = configVertexMap.get(name); GraphVertex gv = n.instantiate(this, name, vertexNumber, paramsViewForVertex[vertexNumber], - initializeParams); + initializeParams, netDtype); if(gv == null){ throw new IllegalStateException("Encountered null layer/vertex during initialization for layer \"" + name + @@ -766,7 +780,7 @@ public void initGradientsView() { for (int vertexIdx : topologicalOrder) { long nParamsThisVertex = 
numParamsForVertex[vertexIdx]; if (nParamsThisVertex != 0) { - INDArray gradientView = flattenedGradients.get(NDArrayIndex.point(0), + INDArray gradientView = flattenedGradients.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(paramOffsetSoFar, paramOffsetSoFar + nParamsThisVertex)); vertices[vertexIdx].setBackpropGradientsViewArray(gradientView); } @@ -3210,7 +3224,7 @@ public void setParams(INDArray params) { long range = layer.numParams(); if (range <= 0) continue; //Some layers: no parameters (subsampling etc) - INDArray get = params.get(NDArrayIndex.point(0), NDArrayIndex.interval(idx, range + idx)); + INDArray get = params.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(idx, range + idx)); layer.setParams(get); idx += range; } @@ -3237,7 +3251,7 @@ public void setBackpropGradientsViewArray(INDArray gradient) { long range = layer.numParams(); if (range <= 0) continue; //Some layers: no parameters (subsampling etc) - layer.setBackpropGradientsViewArray(gradient.get(NDArrayIndex.point(0), + layer.setBackpropGradientsViewArray(gradient.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(paramsSoFar, paramsSoFar + range))); paramsSoFar += range; } @@ -4481,6 +4495,37 @@ public static ComputationGraph load(File f, boolean loadUpdater) throws IOExcept return ModelSerializer.restoreComputationGraph(f, loadUpdater); } + /** + * Return a copy of the network with the parameters and activations set to use the specified (floating point) data type. + * If the existing datatype is the same as the requested datatype, the original network will be returned unchanged. + * Only floating point datatypes (DOUBLE, FLOAT, HALF) may be used. + * + * @param dataType Datatype to convert the network to + * @return The network, set to use the specified datatype for the parameters and activations + */ + public ComputationGraph convertDataType(@NonNull DataType dataType){ + Preconditions.checkState(dataType.isFPType(), "Invalid DataType: %s. Can only convert network to a floating point type", dataType); + if(dataType == params().dataType()){ + return this; + } + + try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + INDArray newParams = params().castTo(dataType); + String jsonConfig = getConfiguration().toJson(); + ComputationGraphConfiguration newConf = ComputationGraphConfiguration.fromJson(jsonConfig); + newConf.setDataType(dataType); + ComputationGraph newNet = new ComputationGraph(newConf); + newNet.init(newParams, false); + + Updater u = getUpdater(false); + if(u != null && u.getStateViewArray() != null){ + INDArray oldUpdaterState = u.getStateViewArray(); + newNet.getUpdater(true).getStateViewArray().assign(oldUpdaterState); + } + return newNet; + } + } + /** + * Set the learning rate for all layers in the network to the specified value. Note that if any learning rate + * schedules are currently present, these will be removed in favor of the new (fixed) learning rate.
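A minimal usage sketch for the ComputationGraph.convertDataType(DataType) method introduced in the hunk above. The graph configuration below (input name, layer sizes, loss function) is purely illustrative and not part of this patch; only the conversion calls follow the new method: the same instance is returned when the requested type already matches, otherwise a copy with cast parameters and copied updater state is created.

    import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.deeplearning4j.nn.graph.ComputationGraph;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    public class ConvertDataTypeSketch {
        public static void main(String[] args) {
            // Illustrative graph: one input feeding a single output layer
            ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                    .graphBuilder()
                    .addInputs("in")
                    .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                            .nIn(10).nOut(2).activation(Activation.IDENTITY).build(), "in")
                    .setOutputs("out")
                    .build();

            ComputationGraph net = new ComputationGraph(conf);
            net.init();   // parameters created in the network's configured datatype

            // Copy of the network with parameters/activations in half precision
            ComputationGraph halfNet = net.convertDataType(DataType.HALF);

            // Requesting the datatype the network already uses returns the same instance
            ComputationGraph same = halfNet.convertDataType(DataType.HALF);
            System.out.println(halfNet.params().dataType() + ", sameInstance=" + (same == halfNet));
        }
    }

Per the Preconditions check in the method, passing a non-floating-point type (for example an integer type) is rejected; only DOUBLE, FLOAT and HALF are accepted.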
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java index 232c106f9f3f..555c64f94d03 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.api.TrainingConfig; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.impl.LayerVertex; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -61,13 +62,16 @@ public abstract class BaseGraphVertex implements GraphVertex { @Setter @Getter protected boolean outputVertex; + protected DataType dataType; + protected BaseGraphVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, - VertexIndices[] outputVertices) { + VertexIndices[] outputVertices, DataType dataType) { this.graph = graph; this.vertexName = name; this.vertexIndex = vertexIndex; this.inputVertices = inputVertices; this.outputVertices = outputVertices; + this.dataType = dataType; this.inputs = new INDArray[(inputVertices != null ? inputVertices.length : 0)]; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java index 0c56d8369633..d9fe6aceeee0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java @@ -51,13 +51,13 @@ public enum Op { private Op op; private int nInForwardPass; - public ElementWiseVertex(ComputationGraph graph, String name, int vertexIndex, Op op) { - this(graph, name, vertexIndex, null, null, op); + public ElementWiseVertex(ComputationGraph graph, String name, int vertexIndex, Op op, DataType dataType) { + this(graph, name, vertexIndex, null, null, op, dataType); } public ElementWiseVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, - VertexIndices[] outputVertices, Op op) { - super(graph, name, vertexIndex, inputVertices, outputVertices); + VertexIndices[] outputVertices, Op op, DataType dataType) { + super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); this.op = op; } @@ -82,29 +82,32 @@ public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) { switch (op) { case Add: - INDArray sum = workspaceMgr.dup(ArrayType.ACTIVATIONS, inputs[0]); + INDArray sum = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, dataType, inputs[0].shape()); + sum.assign(inputs[0]); for (int i = 1; i < inputs.length; i++) { - sum.addi(inputs[i]); + sum.addi(inputs[i].castTo(dataType)); } return sum; case Average: - INDArray average = workspaceMgr.dup(ArrayType.ACTIVATIONS, inputs[0]); + INDArray average = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, dataType, inputs[0].shape()); + average.assign(inputs[0]); for (int i = 1; i < inputs.length; i++) { - average.addi(inputs[i]); + average.addi(inputs[i].castTo(dataType)); } return average.divi(inputs.length); case Subtract: if (inputs.length != 2) 
throw new IllegalArgumentException("ElementWise subtraction only supports 2 inputs"); - return Nd4j.getExecutioner().exec(new OldSubOp(inputs[0], inputs[1], workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, inputs[0].shape()))); + return Nd4j.getExecutioner().exec(new OldSubOp(inputs[0], inputs[1], workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, inputs[0].dataType(), inputs[0].shape()))); case Product: - INDArray product = workspaceMgr.dup(ArrayType.ACTIVATIONS, inputs[0]); + INDArray product = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, dataType, inputs[0].shape()); + product.assign(inputs[0]); for (int i = 1; i < inputs.length; i++) { - product.muli(inputs[i]); + product.muli(inputs[i].castTo(dataType)); } return product; case Max: - INDArray max = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, inputs[0].shape(), inputs[0].ordering()); + INDArray max = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, inputs[0].dataType(), inputs[0].shape(), inputs[0].ordering()); CustomOp op = DynamicCustomOp.builder("mergemax") .addInputs(inputs) .addOutputs(max) diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/InputVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/InputVertex.java index 0c08437e20c7..07216b14a13e 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/InputVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/InputVertex.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -32,8 +33,8 @@ public class InputVertex extends BaseGraphVertex { - public InputVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] outputVertices) { - super(graph, name, vertexIndex, null, outputVertices); + public InputVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] outputVertices, DataType dataType) { + super(graph, name, vertexIndex, null, outputVertices, dataType); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2NormalizeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2NormalizeVertex.java index 89d2ffa331c5..4ec17dc9534b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2NormalizeVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2NormalizeVertex.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp; @@ -47,13 +48,13 @@ public class L2NormalizeVertex extends BaseGraphVertex { private int[] dimension; private double eps; - public L2NormalizeVertex(ComputationGraph graph, String name, int vertexIndex, int[] dimension, double eps) { - this(graph, 
name, vertexIndex, null, null, dimension, eps); + public L2NormalizeVertex(ComputationGraph graph, String name, int vertexIndex, int[] dimension, double eps, DataType dataType) { + this(graph, name, vertexIndex, null, null, dimension, eps, dataType); } public L2NormalizeVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, - VertexIndices[] outputVertices, int[] dimension, double eps) { - super(graph, name, vertexIndex, inputVertices, outputVertices); + VertexIndices[] outputVertices, int[] dimension, double eps, DataType dataType) { + super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); this.dimension = dimension; this.eps = eps; } @@ -123,7 +124,7 @@ public Pair doBackward(boolean tbptt, LayerWorkspaceMgr wo Nd4j.getExecutioner().exec(new BroadcastMulOp(xDivNorm3, dx, xDivNorm3, 0)); //1/|x|_2 * dLda - above - dLdx = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, epsilon.shape(), epsilon.ordering()); + dLdx = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, epsilon.dataType(), epsilon.shape(), epsilon.ordering()); Nd4j.getExecutioner().exec(new BroadcastDivOp(epsilon, norm, dLdx, 0)); dLdx.subi(xDivNorm3); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2Vertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2Vertex.java index 384866138707..b9931eb77429 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2Vertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2Vertex.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp; @@ -43,13 +44,13 @@ public class L2Vertex extends BaseGraphVertex { private double eps; - public L2Vertex(ComputationGraph graph, String name, int vertexIndex, double eps) { - this(graph, name, vertexIndex, null, null, eps); + public L2Vertex(ComputationGraph graph, String name, int vertexIndex, double eps, DataType dataType) { + this(graph, name, vertexIndex, null, null, eps, dataType); } public L2Vertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, - VertexIndices[] outputVertices, double eps) { - super(graph, name, vertexIndex, inputVertices, outputVertices); + VertexIndices[] outputVertices, double eps, DataType dataType) { + super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); this.eps = eps; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java index b9bc13f38254..886c719f94aa 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java @@ -31,6 +31,7 @@ import org.deeplearning4j.nn.layers.BaseOutputLayer; import org.deeplearning4j.nn.layers.FrozenLayer; import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop; +import org.nd4j.linalg.api.buffer.DataType; 
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -55,14 +56,14 @@ public class LayerVertex extends BaseGraphVertex { * Create a network input vertex: */ public LayerVertex(ComputationGraph graph, String name, int vertexIndex, Layer layer, - InputPreProcessor layerPreProcessor, boolean outputVertex) { - this(graph, name, vertexIndex, null, null, layer, layerPreProcessor, outputVertex); + InputPreProcessor layerPreProcessor, boolean outputVertex, DataType dataType) { + this(graph, name, vertexIndex, null, null, layer, layerPreProcessor, outputVertex, dataType); } public LayerVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, VertexIndices[] outputVertices, Layer layer, InputPreProcessor layerPreProcessor, - boolean outputVertex) { - super(graph, name, vertexIndex, inputVertices, outputVertices); + boolean outputVertex, DataType dataType) { + super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); this.graph = graph; this.vertexName = name; this.vertexIndex = vertexIndex; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java index 8abd0fdb4522..73cd7db4d4e0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java @@ -49,13 +49,13 @@ public class MergeVertex extends BaseGraphVertex { private long[][] forwardPassShapes; private int fwdPassRank; - public MergeVertex(ComputationGraph graph, String name, int vertexIndex) { - this(graph, name, vertexIndex, null, null); + public MergeVertex(ComputationGraph graph, String name, int vertexIndex, DataType dataType) { + this(graph, name, vertexIndex, null, null, dataType); } public MergeVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, - VertexIndices[] outputVertices) { - super(graph, name, vertexIndex, inputVertices, outputVertices); + VertexIndices[] outputVertices, DataType dataType) { + super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); } @Override @@ -85,12 +85,17 @@ public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) { return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, inputs[0]); } - forwardPassShapes = new long[inputs.length][0]; - val nExamples = inputs[0].size(0); + INDArray[] in = new INDArray[inputs.length]; + for( int i=0; i doBackward(boolean tbptt, LayerWorkspaceMgr wo //Split the epsilons in the opposite way that the activations were merged INDArray[] out = new INDArray[forwardPassShapes.length]; for (int i = 0; i < out.length; i++) - out[i] = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, forwardPassShapes[i]); + out[i] = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, epsilon.dataType(), forwardPassShapes[i]); int cumulative = 0; switch (fwdPassRank) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PoolHelperVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PoolHelperVertex.java index e5edcdf3e324..d13368969e0c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PoolHelperVertex.java 
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PoolHelperVertex.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or; import org.nd4j.linalg.factory.Nd4j; @@ -39,13 +40,13 @@ */ public class PoolHelperVertex extends BaseGraphVertex { - public PoolHelperVertex(ComputationGraph graph, String name, int vertexIndex) { - this(graph, name, vertexIndex, null, null); + public PoolHelperVertex(ComputationGraph graph, String name, int vertexIndex, DataType dataType) { + this(graph, name, vertexIndex, null, null, dataType); } public PoolHelperVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, - VertexIndices[] outputVertices) { - super(graph, name, vertexIndex, inputVertices, outputVertices); + VertexIndices[] outputVertices, DataType dataType) { + super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); } @Override @@ -76,7 +77,11 @@ public Pair doBackward(boolean tbptt, LayerWorkspaceMgr wo if (!canDoBackward()) throw new IllegalStateException("Cannot do backward pass: errors not set"); - return new Pair<>(null, new INDArray[] {workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon)}); + INDArray out = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, epsilon.dataType(), epsilon.size(0), epsilon.size(1), 1+epsilon.size(2), 1+epsilon.size(3)); + out.get(NDArrayIndex.all(), NDArrayIndex.all(),NDArrayIndex.interval(1, inputs[0].size(2)), NDArrayIndex.interval(1, inputs[0].size(3))) + .assign(epsilon); + + return new Pair<>(null, new INDArray[] {out}); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PreprocessorVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PreprocessorVertex.java index 4f956324f286..906a3a35226d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PreprocessorVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PreprocessorVertex.java @@ -23,6 +23,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -35,13 +36,13 @@ public class PreprocessorVertex extends BaseGraphVertex { private InputPreProcessor preProcessor; - public PreprocessorVertex(ComputationGraph graph, String name, int vertexIndex, InputPreProcessor preProcessor) { - this(graph, name, vertexIndex, null, null, preProcessor); + public PreprocessorVertex(ComputationGraph graph, String name, int vertexIndex, InputPreProcessor preProcessor, DataType dataType) { + this(graph, name, vertexIndex, null, null, preProcessor, dataType); } public PreprocessorVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, - VertexIndices[] outputVertices, InputPreProcessor 
preProcessor, DataType dataType) { + super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); this.preProcessor = preProcessor; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ReshapeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ReshapeVertex.java index 5984733bb78e..828d190181f9 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ReshapeVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ReshapeVertex.java @@ -25,6 +25,7 @@ import org.deeplearning4j.nn.graph.vertex.VertexIndices; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; @@ -44,13 +45,13 @@ public class ReshapeVertex extends BaseGraphVertex { private int[] maskShape; - public ReshapeVertex(ComputationGraph graph, String name, int vertexIndex, char order, int[] newShape, int[] maskShape) { - this(graph, name, vertexIndex, null, null, order, newShape, maskShape); + public ReshapeVertex(ComputationGraph graph, String name, int vertexIndex, char order, int[] newShape, int[] maskShape, DataType dataType) { + this(graph, name, vertexIndex, null, null, order, newShape, maskShape, dataType); } public ReshapeVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, - VertexIndices[] outputVertices, char order, int[] newShape, int[] maskShape) { - super(graph, name, vertexIndex, inputVertices, outputVertices); + VertexIndices[] outputVertices, char order, int[] newShape, int[] maskShape, DataType dataType) { + super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); this.order = order; this.newShape = newShape; this.maskShape = maskShape; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ScaleVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ScaleVertex.java index 1098172e94aa..0bea6c3ec0fc 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ScaleVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ScaleVertex.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; @@ -39,13 +40,13 @@ public class ScaleVertex extends BaseGraphVertex { private double scaleFactor; - public ScaleVertex(ComputationGraph graph, String name, int vertexIndex, double scaleFactor) { - this(graph, name, vertexIndex, null, null, scaleFactor); + public ScaleVertex(ComputationGraph graph, String name, int vertexIndex, double scaleFactor, DataType dataType) { + this(graph, name, vertexIndex, null, null, scaleFactor, dataType); } public ScaleVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, - VertexIndices[] outputVertices, double scaleFactor) { - super(graph, name, vertexIndex, inputVertices, outputVertices); + VertexIndices[] outputVertices, double scaleFactor, 
DataType dataType) { + super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); this.scaleFactor = scaleFactor; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ShiftVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ShiftVertex.java index b21cb3a12f9b..b4f1718948fc 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ShiftVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ShiftVertex.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; @@ -48,13 +49,13 @@ public class ShiftVertex extends BaseGraphVertex { private double shiftFactor; - public ShiftVertex(ComputationGraph graph, String name, int vertexIndex, double shiftFactor) { - this(graph, name, vertexIndex, null, null, shiftFactor); + public ShiftVertex(ComputationGraph graph, String name, int vertexIndex, double shiftFactor, DataType dataType) { + this(graph, name, vertexIndex, null, null, shiftFactor, dataType); } public ShiftVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, - VertexIndices[] outputVertices, double shiftFactor) { - super(graph, name, vertexIndex, inputVertices, outputVertices); + VertexIndices[] outputVertices, double shiftFactor, DataType dataType) { + super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); this.shiftFactor = shiftFactor; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/StackVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/StackVertex.java index abf023e3823e..37e091e44aef 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/StackVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/StackVertex.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -46,13 +47,13 @@ public class StackVertex extends BaseGraphVertex { private long[][] lastInputShapes; - public StackVertex(ComputationGraph graph, String name, int vertexIndex) { - this(graph, name, vertexIndex, null, null); + public StackVertex(ComputationGraph graph, String name, int vertexIndex, DataType dataType) { + this(graph, name, vertexIndex, null, null, dataType); } public StackVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, - VertexIndices[] outputVertices) { - super(graph, name, vertexIndex, inputVertices, outputVertices); + VertexIndices[] outputVertices, DataType dataType) { + super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); } @Override @@ -100,7 +101,7 @@ public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) { } outShape[2] = maxLength; - INDArray out 
= workspaceMgr.create(ArrayType.ACTIVATIONS, outShape); + INDArray out = workspaceMgr.create(ArrayType.ACTIVATIONS, inputs[0].dataType(), outShape); long numExamples = inputs[0].size(0); lastInputShapes = new long[inputs.length][0]; for (int i = 0; i < inputs.length; i++) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/SubsetVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/SubsetVertex.java index 45e765046a1d..0e148185d4fc 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/SubsetVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/SubsetVertex.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -44,13 +45,13 @@ public class SubsetVertex extends BaseGraphVertex { private int to; //inclusive private long[] forwardShape; - public SubsetVertex(ComputationGraph graph, String name, int vertexIndex, int from, int to) { - this(graph, name, vertexIndex, null, null, from, to); + public SubsetVertex(ComputationGraph graph, String name, int vertexIndex, int from, int to, DataType dataType) { + this(graph, name, vertexIndex, null, null, from, to, dataType); } public SubsetVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, - VertexIndices[] outputVertices, int from, int to) { - super(graph, name, vertexIndex, inputVertices, outputVertices); + VertexIndices[] outputVertices, int from, int to, DataType dataType) { + super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); this.from = from; this.to = to; } @@ -96,7 +97,7 @@ public Pair doBackward(boolean tbptt, LayerWorkspaceMgr wo if (!canDoBackward()) throw new IllegalStateException("Cannot do backward pass: error not set"); - INDArray out = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, forwardShape); + INDArray out = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, epsilon.dataType(), forwardShape); switch (forwardShape.length) { case 2: out.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(from, to, true)}, epsilon); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.java index 2883a4b43c32..4ca04c418c16 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -47,13 +48,13 @@ public class UnstackVertex extends BaseGraphVertex { private long forwardShape[]; private int step; - public UnstackVertex(ComputationGraph graph, String name, int vertexIndex, 
int from, int stackSize) { - this(graph, name, vertexIndex, null, null, from, stackSize); + public UnstackVertex(ComputationGraph graph, String name, int vertexIndex, int from, int stackSize, DataType dataType) { + this(graph, name, vertexIndex, null, null, from, stackSize, dataType); } public UnstackVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, - VertexIndices[] outputVertices, int from, int stackSize) { - super(graph, name, vertexIndex, inputVertices, outputVertices); + VertexIndices[] outputVertices, int from, int stackSize, DataType dataType) { + super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); this.from = from; this.stackSize = stackSize; } @@ -106,7 +107,7 @@ public Pair doBackward(boolean tbptt, LayerWorkspaceMgr wo if (!canDoBackward()) throw new IllegalStateException("Cannot do backward pass: error not set"); - INDArray out = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, forwardShape); + INDArray out = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, inputs[0].dataType(), forwardShape); int start = from * step; int end = (from + 1) * step; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/DuplicateToTimeSeriesVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/DuplicateToTimeSeriesVertex.java index 8be902fee2c0..293ce4af2a4f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/DuplicateToTimeSeriesVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/DuplicateToTimeSeriesVertex.java @@ -23,6 +23,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -45,13 +46,13 @@ public class DuplicateToTimeSeriesVertex extends BaseGraphVertex { private String inputName; private int inputVertexIndex; - public DuplicateToTimeSeriesVertex(ComputationGraph graph, String name, int vertexIndex, String inputVertexName) { - this(graph, name, vertexIndex, null, null, inputVertexName); + public DuplicateToTimeSeriesVertex(ComputationGraph graph, String name, int vertexIndex, String inputVertexName, DataType dataType) { + this(graph, name, vertexIndex, null, null, inputVertexName, dataType); } public DuplicateToTimeSeriesVertex(ComputationGraph graph, String name, int vertexIndex, - VertexIndices[] inputVertices, VertexIndices[] outputVertices, String inputName) { - super(graph, name, vertexIndex, inputVertices, outputVertices); + VertexIndices[] inputVertices, VertexIndices[] outputVertices, String inputName, DataType dataType) { + super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); this.inputName = inputName; this.inputVertexIndex = graph.getConfiguration().getNetworkInputs().indexOf(inputName); if (inputVertexIndex == -1) @@ -81,7 +82,7 @@ public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) { val tsLength = graph.getInput(inputVertexIndex).size(2); val outShape = new long[] {inputs[0].size(0), inputs[0].size(1), tsLength}; - INDArray out = workspaceMgr.create(ArrayType.ACTIVATIONS, outShape, 'f'); + INDArray out = workspaceMgr.create(ArrayType.ACTIVATIONS, inputs[0].dataType(), 
outShape, 'f'); for (int i = 0; i < tsLength; i++) { out.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(i)}, inputs[0]); } @@ -91,7 +92,7 @@ public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) { @Override public Pair doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr) { //Because we duplicated for each time step: simply need to sum along time for errors/epsilons - INDArray ret = epsilon.sum(workspaceMgr.create(ArrayType.ACTIVATION_GRAD, epsilon.size(0), epsilon.size(1)), 2); + INDArray ret = epsilon.sum(workspaceMgr.create(ArrayType.ACTIVATION_GRAD, epsilon.dataType(), epsilon.size(0), epsilon.size(1)), 2); return new Pair<>(null, new INDArray[] {ret}); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java index 369b05144179..a1dd47469525 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java @@ -23,6 +23,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.INDArrayIndex; @@ -49,14 +50,14 @@ public class LastTimeStepVertex extends BaseGraphVertex { /** Indexes of the time steps that were extracted, for each example */ private int[] fwdPassTimeSteps; - public LastTimeStepVertex(ComputationGraph graph, String name, int vertexIndex, String inputName) { - this(graph, name, vertexIndex, null, null, inputName); + public LastTimeStepVertex(ComputationGraph graph, String name, int vertexIndex, String inputName, DataType dataType) { + this(graph, name, vertexIndex, null, null, inputName, dataType); } public LastTimeStepVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, - VertexIndices[] outputVertices, String inputName) { - super(graph, name, vertexIndex, inputVertices, outputVertices); + VertexIndices[] outputVertices, String inputName, DataType dataType) { + super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); this.inputName = inputName; this.inputIdx = graph.getConfiguration().getNetworkInputs().indexOf(inputName); if (inputIdx == -1) @@ -94,7 +95,7 @@ public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) { fwdPassTimeSteps = null; //Null -> last time step for all examples } else { val outShape = new long[] {inputs[0].size(0), inputs[0].size(1)}; - out = workspaceMgr.create(ArrayType.ACTIVATIONS, outShape); + out = workspaceMgr.create(ArrayType.ACTIVATIONS, inputs[0].dataType(), outShape); //Want the index of the last non-zero entry in the mask array. 
//Check a little here by using mulRowVector([0,1,2,3,...]) and argmax @@ -122,7 +123,7 @@ public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) { public Pair doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr) { //Allocate the appropriate sized array: - INDArray epsilonsOut = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, fwdPassShape, 'f'); + INDArray epsilonsOut = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, epsilon.dataType(), fwdPassShape, 'f'); if (fwdPassTimeSteps == null) { //Last time step for all examples diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/ReverseTimeSeriesVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/ReverseTimeSeriesVertex.java index 60bff8fdc11e..d4f5a4bb81a2 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/ReverseTimeSeriesVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/ReverseTimeSeriesVertex.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -45,8 +46,8 @@ public class ReverseTimeSeriesVertex extends BaseGraphVertex { private final String inputName; private final int inputIdx; - public ReverseTimeSeriesVertex(ComputationGraph graph, String name, int vertexIndex, String inputName) { - super(graph, name, vertexIndex, null, null); + public ReverseTimeSeriesVertex(ComputationGraph graph, String name, int vertexIndex, String inputName, DataType dataType) { + super(graph, name, vertexIndex, null, null, dataType); this.inputName = inputName; @@ -136,7 +137,7 @@ private static INDArray revertTimeSeries(INDArray input, INDArray mask, LayerWor val m = input.size(2); // Create empty output - INDArray out = workspaceMgr.create(type, input.shape(), 'f'); + INDArray out = workspaceMgr.create(type, input.dataType(), input.shape(), 'f'); // Iterate over all samples for (int s = 0; s < n; s++) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java index aababef74007..fdf12d2f8442 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java @@ -31,6 +31,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.api.ConvexOptimizer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; @@ -54,18 +55,15 @@ public abstract class AbstractLayer paramTable(boolean backpropParamsOnly) { } protected void applyMask(INDArray to) { - to.muliColumnVector(maskArray); + to.muliColumnVector(maskArray.castTo(to.dataType())); } @Override @@ -291,7 +289,7 @@ protected void applyDropOutIfNecessary(boolean training, LayerWorkspaceMgr works if(inputModificationAllowed){ result = input; } else { - result = workspaceMgr.createUninitialized(ArrayType.INPUT, input.shape(), 
input.ordering()); + result = workspaceMgr.createUninitialized(ArrayType.INPUT, input.dataType(), input.shape(), input.ordering()); } input = layerConf().getIDropout().applyDropout(input, result, getIterationCount(), getEpochCount(), workspaceMgr); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ActivationLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ActivationLayer.java index 4e5e73a35df1..2d9e350b3d2d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ActivationLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ActivationLayer.java @@ -21,6 +21,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -37,12 +38,8 @@ */ public class ActivationLayer extends AbstractLayer { - public ActivationLayer(NeuralNetConfiguration conf) { - super(conf); - } - - public ActivationLayer(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); + public ActivationLayer(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java index 6c6939828966..405889f0e8d5 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.Solver; import org.deeplearning4j.optimize.api.ConvexOptimizer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm; @@ -56,13 +57,8 @@ public abstract class BaseLayer weightNoiseParams = new HashMap<>(); - public BaseLayer(NeuralNetConfiguration conf) { - super(conf); - } - - public BaseLayer(NeuralNetConfiguration conf, INDArray input) { - this(conf); - this.input = input; + public BaseLayer(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } public LayerConfT layerConf() { @@ -92,7 +88,7 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac INDArray W = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, true, workspaceMgr); - INDArray epsilonNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, new long[]{W.size(0), delta.size(0)}, 'f'); + INDArray epsilonNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, delta.dataType(), new long[]{W.size(0), delta.size(0)}, 'f'); if(hasLayerNorm()) { INDArray g = getParam(DefaultParamInitializer.GAIN_KEY); @@ -300,6 +296,8 @@ protected Pair preOutputWithPreNorm(boolean training, boolea INDArray b = getParamWithNoise(DefaultParamInitializer.BIAS_KEY, training, workspaceMgr); INDArray g = (hasLayerNorm() ? 
getParam(DefaultParamInitializer.GAIN_KEY) : null); + INDArray input = this.input.castTo(dataType); + //Input validation: if (input.rank() != 2 || input.columns() != W.rows()) { if (input.rank() != 2) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java index 27a390aa7553..7647361f8269 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.Solver; import org.nd4j.evaluation.classification.Evaluation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -59,12 +60,8 @@ public abstract class BaseOutputLayer backpropGradient(INDArray epsilon, LayerWorkspac INDArray delta = pair.getSecond(); INDArray w = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, true, workspaceMgr); - INDArray epsilonNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, new long[]{w.size(0), delta.size(0)}, 'f'); + INDArray epsilonNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, delta.dataType(), new long[]{w.size(0), delta.size(0)}, 'f'); epsilonNext = w.mmuli(delta.transpose(), epsilonNext).transpose(); //Normally we would clear weightNoiseParams here - but we want to reuse them for forward + backward + score @@ -336,9 +333,9 @@ protected INDArray preOutput2d(boolean training, LayerWorkspaceMgr workspaceMgr) protected void applyMask(INDArray to) { //For output layers: can be either per-example masking, or per- if (maskArray.isColumnVectorOrScalar()) { - to.muliColumnVector(maskArray); + to.muliColumnVector(maskArray.castTo(to.dataType())); } else if (Arrays.equals(to.shape(), maskArray.shape())) { - to.muli(maskArray); + to.muli(maskArray.castTo(to.dataType())); } else { throw new IllegalStateException("Invalid mask array: per-example masking should be a column vector, " + "per output masking arrays should be the same shape as the output/labels arrays. 
Mask shape: " diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BasePretrainNetwork.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BasePretrainNetwork.java index 91edb2ab4c04..e36e307cb88a 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BasePretrainNetwork.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BasePretrainNetwork.java @@ -23,6 +23,7 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.params.PretrainParamInitializer; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.regularization.Regularization; @@ -45,12 +46,8 @@ public abstract class BasePretrainNetwork { - public BasePretrainNetwork(NeuralNetConfiguration conf) { - super(conf); - } - - public BasePretrainNetwork(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); + public BasePretrainNetwork(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @@ -62,8 +59,8 @@ public BasePretrainNetwork(NeuralNetConfiguration conf, INDArray input) { * @return the binomial sampled corrupted input */ public INDArray getCorruptedInput(INDArray x, double corruptionLevel) { - INDArray corrupted = Nd4j.getDistributions().createBinomial(1, 1 - corruptionLevel).sample(x.shape()); - corrupted.muli(x.castTo(Nd4j.defaultFloatingPointType())); + INDArray corrupted = Nd4j.getDistributions().createBinomial(1, 1 - corruptionLevel).sample(x.ulike()); + corrupted.muli(x.castTo(corrupted.dataType())); return corrupted; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/DropoutLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/DropoutLayer.java index bdbf74117759..6af4025543d0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/DropoutLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/DropoutLayer.java @@ -20,6 +20,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.ArrayType; @@ -30,12 +31,8 @@ */ public class DropoutLayer extends BaseLayer { - public DropoutLayer(NeuralNetConfiguration conf) { - super(conf); - } - - public DropoutLayer(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); + public DropoutLayer(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @Override @@ -79,7 +76,7 @@ public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { if(inputModificationAllowed){ result = input; } else { - result = workspaceMgr.createUninitialized(ArrayType.INPUT, input.shape(), input.ordering()); + result = workspaceMgr.createUninitialized(ArrayType.INPUT, input.dataType(), input.shape(), input.ordering()); } ret = layerConf().getIDropout().applyDropout(input, result, getIterationCount(), getEpochCount(), workspaceMgr); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/LossLayer.java 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/LossLayer.java index 776aa7133a58..94fa17b598f4 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/LossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/LossLayer.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.optimize.Solver; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -55,12 +56,8 @@ public class LossLayer extends BaseLayer { - public OutputLayer(NeuralNetConfiguration conf) { - super(conf); - } - - public OutputLayer(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); + public OutputLayer(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @Override protected INDArray getLabels2d(LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) { - return workspaceMgr.castTo(arrayType, Nd4j.defaultFloatingPointType(), labels, false); + return labels; } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/RepeatVector.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/RepeatVector.java index af3ab6291159..d802aad33018 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/RepeatVector.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/RepeatVector.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling2D; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; @@ -45,12 +46,8 @@ */ public class RepeatVector extends AbstractLayer { - public RepeatVector(NeuralNetConfiguration conf) { - super(conf); - } - - public RepeatVector(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); + public RepeatVector(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cnn3DLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cnn3DLossLayer.java index cdc4229b71bb..6c70c7253c07 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cnn3DLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cnn3DLossLayer.java @@ -28,6 +28,7 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.BaseLayer; import org.deeplearning4j.util.ConvolutionUtils; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -62,8 +63,8 @@ public class Cnn3DLossLayer extends BaseLayer backpropGradient(INDArray epsilon, LayerWorkspac if (labels == null) throw new IllegalStateException("Labels are not set (null)"); + Preconditions.checkState(input.equalShapes(labels), "Input and label arrays do not have same shape: %ndShape vs. 
%ndShape",input, labels); + INDArray input2d = ConvolutionUtils.reshape4dTo2d(input, workspaceMgr, ArrayType.FF_WORKING_MEM); INDArray labels2d = ConvolutionUtils.reshape4dTo2d(labels, workspaceMgr, ArrayType.FF_WORKING_MEM); INDArray maskReshaped = ConvolutionUtils.reshapeMaskIfRequired(maskArray, input, workspaceMgr, ArrayType.FF_WORKING_MEM); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution1DLayer.java index a84b987d4cb3..1ffd19062e38 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution1DLayer.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.util.ConvolutionUtils; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Broadcast; import org.nd4j.linalg.primitives.Pair; @@ -47,13 +48,10 @@ * @author dave@skymind.io */ public class Convolution1DLayer extends ConvolutionLayer { - public Convolution1DLayer(NeuralNetConfiguration conf) { - super(conf); + public Convolution1DLayer(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } - public Convolution1DLayer(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); - } @Override public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution3DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution3DLayer.java index c7f6a6fe8240..f540d738b205 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution3DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution3DLayer.java @@ -27,6 +27,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.util.Convolution3DUtils; import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -42,12 +43,8 @@ */ public class Convolution3DLayer extends ConvolutionLayer { - public Convolution3DLayer(NeuralNetConfiguration conf) { - super(conf); - } - - public Convolution3DLayer(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); + public Convolution3DLayer(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @@ -98,7 +95,7 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac INDArray weightGradView = gradientViews.get(Convolution3DParamInitializer.WEIGHT_KEY); - INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, + INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, weights.dataType(), miniBatch * outEpsChannels * inD * inH * inW); if (isNCDHW) outEpsilon = outEpsilon.reshape('c', miniBatch, outEpsChannels, inD, inH, inW); @@ -242,8 +239,7 @@ protected Pair preOutput(boolean training, boolean forBackpr int outH = outSize[1]; int outW = outSize[2]; - INDArray output = 
workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, - miniBatch*outWeightChannels*outD*outH*outW); + INDArray output = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, weights.dataType(),miniBatch*outWeightChannels*outD*outH*outW); if (isNCDHW) output = output.reshape('c', miniBatch, outWeightChannels, outD, outH, outW); else diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java index d696642115fe..1b1ed0910942 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java @@ -30,6 +30,7 @@ import org.deeplearning4j.nn.params.ConvolutionParamInitializer; import org.deeplearning4j.util.ConvolutionUtils; import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; @@ -61,23 +62,18 @@ public class ConvolutionLayer extends BaseLayer backpropGradient(INDArray epsilon, LayerWorkspac INDArray weights = getParamWithNoise(ConvolutionParamInitializer.WEIGHT_KEY, true, workspaceMgr); INDArray bias = getParamWithNoise(ConvolutionParamInitializer.BIAS_KEY, true, workspaceMgr); + INDArray input = this.input.castTo(dataType); //No op if correct type + // FIXME: int cast int miniBatch = (int) input.size(0); int inH = (int) input.size(2); @@ -208,7 +206,7 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac //to get old order from required order: permute(0,3,4,5,1,2) INDArray im2col2d = p.getSecond(); //Re-use im2col2d array from forward pass if available; recalculate if not if (im2col2d == null) { - INDArray col = Nd4j.createUninitialized(new int[] {miniBatch, outH, outW, inDepth, kH, kW}, 'c'); + INDArray col = Nd4j.createUninitialized(dataType, new long[] {miniBatch, outH, outW, inDepth, kH, kW}, 'c'); INDArray col2 = col.permute(0, 3, 4, 5, 1, 2); Convolution.im2col(input, kH, kW, strides[0], strides[1], pad[0], pad[1], dilation[0], dilation[1], convolutionMode == ConvolutionMode.Same, col2); @@ -234,7 +232,7 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac //Current col2im implementation expects input with order: [miniBatch,channels,kH,kW,outH,outW] //currently have [kH,kW,inDepth,outW,outH,miniBatch] -> permute first eps6d = eps6d.permute(5, 2, 1, 0, 4, 3); - INDArray epsNextOrig = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, new int[] {inDepth, miniBatch, inH, inW}, 'c'); + INDArray epsNextOrig = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, eps6d.dataType(), new long[] {inDepth, miniBatch, inH, inW}, 'c'); //Note: we are execute col2im in a way that the output array should be used in a stride 1 muli in the layer below... 
(same strides as zs/activations) INDArray epsNext = epsNextOrig.permute(1, 0, 2, 3); @@ -308,6 +306,8 @@ protected Pair preOutput(boolean training, boolean forBackpr validateInputRank(); + INDArray input = this.input.castTo(dataType); + // FIXME: int cast int miniBatch = (int) input.size(0); int outDepth = (int) weights.size(0); @@ -387,7 +387,7 @@ protected Pair preOutput(boolean training, boolean forBackpr //To get this: create an array of the order we want, permute it to the order required by im2col implementation, and then do im2col on that //to get old order from required order: permute(0,3,4,5,1,2) //Post reshaping: rows are such that minibatch varies slowest, outW fastest as we step through the rows post-reshape - INDArray col = Nd4j.createUninitialized(new int[] {miniBatch, outH, outW, inDepth, kH, kW}, 'c'); + INDArray col = Nd4j.createUninitialized(weights.dataType(), new long[] {miniBatch, outH, outW, inDepth, kH, kW}, 'c'); INDArray col2 = col.permute(0, 3, 4, 5, 1, 2); INDArray im2ColIn = input.castTo(col2.dataType()); //No op if already (for example) float Convolution.im2col(im2ColIn, kH, kW, strides[0], strides[1], pad[0], pad[1], dilation[0], dilation[1], @@ -402,7 +402,7 @@ protected Pair preOutput(boolean training, boolean forBackpr INDArray reshapedW = permutedW.reshape('f', kW * kH * inDepth, outDepth); //Do the MMUL; c and f orders in, f order out. output shape: [miniBatch*outH*outW,depthOut] - INDArray z = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new long[]{im2col2d.size(0), reshapedW.size(1)}, 'f'); + INDArray z = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, weights.dataType(), new long[]{im2col2d.size(0), reshapedW.size(1)}, 'f'); im2col2d.mmuli(reshapedW, z); //Add biases, before reshaping. Note that biases are [1,depthOut] and currently z is [miniBatch*outH*outW,depthOut] -> addiRowVector diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping1DLayer.java index ef2887962db4..a1fc4d6d8221 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping1DLayer.java @@ -25,6 +25,8 @@ import org.deeplearning4j.nn.layers.AbstractLayer; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; @@ -42,8 +44,8 @@ public class Cropping1DLayer extends AbstractLayer { private int[] cropping; //[padTop, padBottom] - public Cropping1DLayer(NeuralNetConfiguration conf) { - super(conf); + public Cropping1DLayer(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); this.cropping = ((org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D) conf.getLayer()).getCropping(); } @@ -65,8 +67,8 @@ public Type type() { @Override public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { val inShape = input.shape(); - INDArray epsNext = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, inShape, 'c'); - INDArray epsNextSubset = inputSubset(epsNext, ArrayType.ACTIVATION_GRAD, workspaceMgr); + INDArray epsNext = 
workspaceMgr.create(ArrayType.ACTIVATION_GRAD, dataType, inShape, 'c'); + INDArray epsNextSubset = epsNext.get(all(), all(), interval(cropping[0], epsNext.size(2)-cropping[1])); epsNextSubset.assign(epsilon); return new Pair<>((Gradient) new DefaultGradient(), epsNext); } @@ -80,7 +82,7 @@ public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { @Override public Layer clone() { - return new Cropping2DLayer(conf.clone()); + return new Cropping2DLayer(conf.clone(), dataType); } @Override @@ -89,6 +91,12 @@ public double calcRegularizationScore(boolean backpropParamsOnly){ } private INDArray inputSubset(INDArray from, ArrayType arrayType, LayerWorkspaceMgr workspaceMgr){ - return workspaceMgr.leverageTo(arrayType, from.get(all(), all(), interval(cropping[0], from.size(2)-cropping[1]))); + try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(arrayType)){ + if(from.dataType() == dataType){ + return from.get(all(), all(), interval(cropping[0], from.size(2)-cropping[1])).dup(from.ordering()); + } else { + return from.get(all(), all(), interval(cropping[0], from.size(2)-cropping[1])).castTo(dataType); + } + } } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java index 30bf8f6e6eff..da2cf1629c13 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.AbstractLayer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -40,8 +41,8 @@ public class Cropping2DLayer extends AbstractLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { val inShape = input.shape(); - INDArray epsNext = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, inShape, 'c'); + INDArray epsNext = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, input.dataType(), inShape, 'c'); INDArray epsNextSubset = inputSubset(epsNext); epsNextSubset.assign(epsilon); return new Pair<>((Gradient) new DefaultGradient(), epsNext); @@ -81,7 +82,7 @@ public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { @Override public Layer clone() { - return new Cropping2DLayer(conf.clone()); + return new Cropping2DLayer(conf.clone(), dataType); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping3DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping3DLayer.java index 612e11da0076..2a4040707152 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping3DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping3DLayer.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.layers.AbstractLayer; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; @@ 
-41,8 +42,8 @@ public class Cropping3DLayer extends AbstractLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { val inShape = input.shape(); - INDArray epsNext = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, inShape, 'c'); + INDArray epsNext = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, input.dataType(), inShape, 'c'); INDArray epsNextSubset = inputSubset(epsNext); epsNextSubset.assign(epsilon); return new Pair<>((Gradient) new DefaultGradient(), epsNext); @@ -82,7 +83,7 @@ public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { @Override public Layer clone() { - return new Cropping3DLayer(conf.clone()); + return new Cropping3DLayer(conf.clone(), dataType); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution2DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution2DLayer.java index 87cb3a3af3a3..8a6c187c247f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution2DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution2DLayer.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.params.DeconvolutionParamInitializer; import org.deeplearning4j.util.ConvolutionUtils; import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -54,12 +55,8 @@ */ public class Deconvolution2DLayer extends ConvolutionLayer { - public Deconvolution2DLayer(NeuralNetConfiguration conf) { - super(conf); - } - - public Deconvolution2DLayer(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); + public Deconvolution2DLayer(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @@ -80,15 +77,14 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac INDArray weights = getParamWithNoise(DeconvolutionParamInitializer.WEIGHT_KEY, true, workspaceMgr); - // FIXME: int cast - int miniBatch = (int) input.size(0); - int inH = (int) input.size(2); - int inW = (int) input.size(3); + long miniBatch = input.size(0); + long inH = input.size(2); + long inW = input.size(3); - int inDepth = (int) weights.size(0); + long inDepth = weights.size(0); - int kH = (int) weights.size(2); - int kW = (int) weights.size(3); + long kH = weights.size(2); + long kW = weights.size(3); int[] dilation = layerConf().getDilation(); int[] kernel = layerConf().getKernelSize(); @@ -97,7 +93,7 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac int[] outSize; if (convolutionMode == ConvolutionMode.Same) { outSize = ConvolutionUtils.getDeconvolutionOutputSize(input, kernel, strides, null, convolutionMode, dilation); - pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {inH, inW}, kernel, strides, dilation); + pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int)inH, (int)inW}, kernel, strides, dilation); } else { pad = layerConf().getPadding(); outSize = ConvolutionUtils.getDeconvolutionOutputSize(input, kernel, strides, pad, convolutionMode, dilation); @@ -106,12 +102,12 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac INDArray biasGradView = gradientViews.get(DeconvolutionParamInitializer.BIAS_KEY); INDArray weightGradView = gradientViews.get(DeconvolutionParamInitializer.WEIGHT_KEY); - 
INDArray outEps = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, new int[]{miniBatch, inDepth, inH, inW}, 'c'); + INDArray outEps = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, weights.dataType(), new long[]{miniBatch, inDepth, inH, inW}, 'c'); Integer sameMode = (convolutionMode == ConvolutionMode.Same) ? 1 : 0; int[] args = new int[] { - kH, kW, strides[0], strides[1], + (int)kH, (int)kW, strides[0], strides[1], pad[0], pad[1], dilation[0], dilation[1], sameMode }; @@ -215,7 +211,7 @@ protected Pair preOutput(boolean training , boolean forBackp val miniBatch = input.size(0); - INDArray output = workspaceMgr.create(ArrayType.ACTIVATIONS, new long[]{miniBatch, outDepth, outH, outW}, 'c'); + INDArray output = workspaceMgr.create(ArrayType.ACTIVATIONS, input.dataType(), new long[]{miniBatch, outDepth, outH, outW}, 'c'); int sameMode = (convolutionMode == ConvolutionMode.Same) ? 1 : 0; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/DepthwiseConvolution2DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/DepthwiseConvolution2DLayer.java index f055b8aa3066..d183064f836b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/DepthwiseConvolution2DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/DepthwiseConvolution2DLayer.java @@ -28,6 +28,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.util.ConvolutionUtils; import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -48,12 +49,8 @@ */ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { - public DepthwiseConvolution2DLayer(NeuralNetConfiguration conf) { - super(conf); - } - - public DepthwiseConvolution2DLayer(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); + public DepthwiseConvolution2DLayer(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @@ -102,7 +99,7 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac INDArray weightGradView = gradientViews.get(DepthwiseConvolutionParamInitializer.WEIGHT_KEY); INDArray outEpsilon = workspaceMgr.create( - ArrayType.ACTIVATION_GRAD, new int[]{miniBatch, inDepth, inH, inW}, 'c'); + ArrayType.ACTIVATION_GRAD, depthWiseWeights.dataType(), new long[]{miniBatch, inDepth, inH, inW}, 'c'); Integer sameMode = (convolutionMode == ConvolutionMode.Same) ? 1 : 0; @@ -209,7 +206,7 @@ protected Pair preOutput(boolean training, boolean forBackpr val miniBatch = input.size(0); INDArray output = workspaceMgr.create( - ArrayType.ACTIVATIONS, new long[]{miniBatch, outDepth, outH, outW}, 'c'); + ArrayType.ACTIVATIONS, depthWiseWeights.dataType(), new long[]{miniBatch, outDepth, outH, outW}, 'c'); Integer sameMode = (convolutionMode == ConvolutionMode.Same) ? 
1 : 0; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java index ea15ef1babbe..4511eb4f7e9d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java @@ -28,6 +28,7 @@ import org.deeplearning4j.nn.params.SeparableConvolutionParamInitializer; import org.deeplearning4j.util.ConvolutionUtils; import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; @@ -61,12 +62,8 @@ */ public class SeparableConvolution2DLayer extends ConvolutionLayer { - public SeparableConvolution2DLayer(NeuralNetConfiguration conf) { - super(conf); - } - - public SeparableConvolution2DLayer(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); + public SeparableConvolution2DLayer(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @@ -117,7 +114,7 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac INDArray depthWiseWeightGradView = gradientViews.get(SeparableConvolutionParamInitializer.DEPTH_WISE_WEIGHT_KEY); INDArray pointWiseWeightGradView = gradientViews.get(SeparableConvolutionParamInitializer.POINT_WISE_WEIGHT_KEY); - INDArray outEpsilon = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, new int[]{miniBatch, inDepth, inH, inW}, 'c'); + INDArray outEpsilon = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, depthWiseWeights.dataType(), new long[]{miniBatch, inDepth, inH, inW}, 'c'); Integer sameMode = (convolutionMode == ConvolutionMode.Same) ? 1 : 0; @@ -232,7 +229,7 @@ protected Pair preOutput(boolean training , boolean forBackp int outW = outSize[1]; val miniBatch = input.size(0); - INDArray output = workspaceMgr.create(ArrayType.ACTIVATIONS, new long[]{miniBatch, outDepth, outH, outW}, 'c'); + INDArray output = workspaceMgr.create(ArrayType.ACTIVATIONS, depthWiseWeights.dataType(), new long[]{miniBatch, outDepth, outH, outW}, 'c'); Integer sameMode = (convolutionMode == ConvolutionMode.Same) ? 
1 : 0; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java index 0231f615444d..c0770161da20 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java @@ -23,6 +23,7 @@ import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.AbstractLayer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -55,12 +56,8 @@ @Slf4j public class SpaceToBatch extends AbstractLayer { - public SpaceToBatch(NeuralNetConfiguration conf) { - super(conf); - } - - public SpaceToBatch(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); + public SpaceToBatch(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } private int[] getBlocks() { @@ -92,13 +89,14 @@ public Type type() { public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); - // FIXME: int cast - int miniBatch = (int) input.size(0); - int inDepth = (int) input.size(1); - int inH = (int) input.size(2); - int inW = (int) input.size(3); + INDArray input = this.input.castTo(dataType); //Cast to network dtype if required (no-op if already correct type) + + long miniBatch = input.size(0); + long inDepth = input.size(1); + long inH = input.size(2); + long inW = input.size(3); - INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, new int[]{miniBatch, inDepth, inH, inW}, 'c'); + INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[]{miniBatch, inDepth, inH, inW}, 'c'); Gradient gradient = new DefaultGradient(); @@ -128,23 +126,22 @@ protected INDArray preOutput(boolean training, boolean forBackprop, LayerWorkspa return preOutput; } - // FIXME: int cast - int inMiniBatch = (int) input.size(0); - int depth = (int) input.size(1); - int inH = (int) input.size(2); - int inW = (int) input.size(3); + long inMiniBatch = input.size(0); + long depth = input.size(1); + long inH = input.size(2); + long inW = input.size(3); int[] blocks = getBlocks(); int[][] padding = getPadding(); - int paddedH = inH + padding[0][0] + padding[0][1]; - int paddedW = inW + padding[1][0] + padding[1][1]; + long paddedH = inH + padding[0][0] + padding[0][1]; + long paddedW = inW + padding[1][0] + padding[1][1]; - int outH = paddedH / blocks[0]; - int outW = paddedW / blocks[1]; - int outMiniBatch = inMiniBatch * blocks[0] * blocks[1]; + long outH = paddedH / blocks[0]; + long outW = paddedW / blocks[1]; + long outMiniBatch = inMiniBatch * blocks[0] * blocks[1]; - INDArray out = workspaceMgr.create(ArrayType.ACTIVATIONS, new int[]{outMiniBatch, depth, outH, outW}, 'c'); + INDArray out = workspaceMgr.create(ArrayType.ACTIVATIONS, input.dataType(), new long[]{outMiniBatch, depth, outH, outW}, 'c'); CustomOp op = DynamicCustomOp.builder("space_to_batch") .addInputs(input, getBlocksArray(), getPaddingArray()) diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepth.java 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepth.java index b54cd5019516..5516738fc878 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepth.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepth.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.AbstractLayer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -54,12 +55,8 @@ @Slf4j public class SpaceToDepth extends AbstractLayer { - public SpaceToDepth(NeuralNetConfiguration conf) { - super(conf); - } - - public SpaceToDepth(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); + public SpaceToDepth(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } private int getBlockSize() { @@ -84,7 +81,9 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac int inH = (int) input.size(2); int inW = (int) input.size(3); - INDArray outEpsilon = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, new int[]{1, miniBatch * inDepth * inH * inW}, 'c'); + INDArray input = this.input.castTo(dataType); //No-op if already correct type + + INDArray outEpsilon = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[]{1, miniBatch * inDepth * inH * inW}, 'c'); INDArray reshapedEpsilon; if (isNHWC() == 1) { @@ -135,7 +134,7 @@ protected INDArray preOutput(boolean training, boolean forBackprop, LayerWorkspa int outW = inW / blockSize; int outDepth = depth * blockSize * blockSize; - INDArray out = workspaceMgr.create(ArrayType.ACTIVATIONS, new int[]{1, miniBatch * outDepth * outH * outW}, 'c'); + INDArray out = workspaceMgr.create(ArrayType.ACTIVATIONS, input.dataType(), new long[]{1, miniBatch * outDepth * outH * outW}, 'c'); INDArray reshapedOut; if (isNHWC() == 1) { reshapedOut = out.reshape('c', miniBatch, outH, outW, outDepth); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ZeroPadding1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ZeroPadding1DLayer.java index 08d72673ed30..f2ed6d6f1ed2 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ZeroPadding1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ZeroPadding1DLayer.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.AbstractLayer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -39,8 +40,8 @@ public class ZeroPadding1DLayer extends AbstractLayer backpropGradient(INDArray epsilon, LayerWorkspac + " array as epsilon for Subsampling1DLayer backprop with shape " + Arrays.toString(epsilon.shape()) + ". Expected rank 3 array with shape [minibatchSize, features, length]. 
" + layerId()); - if(maskArray != null){ INDArray maskOut = feedForwardMaskArray(maskArray, MaskState.Active, (int)epsilon.size(0)).getFirst(); Preconditions.checkState(epsilon.size(0) == maskOut.size(0) && epsilon.size(2) == maskOut.size(1), @@ -73,7 +69,7 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac // add singleton fourth dimension to input and next layer's epsilon INDArray origInput = input; - input = input.reshape(input.size(0), input.size(1), input.size(2), 1); + input = input.castTo(dataType).reshape(input.size(0), input.size(1), input.size(2), 1); epsilon = epsilon.reshape(epsilon.size(0), epsilon.size(1), epsilon.size(2), 1); // call 2D SubsamplingLayer's backpropGradient method @@ -96,7 +92,7 @@ public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { // add singleton fourth dimension to input INDArray origInput = input; - input = input.reshape(input.size(0), input.size(1), input.size(2), 1); + input = input.castTo(dataType).reshape(input.size(0), input.size(1), input.size(2), 1); // call 2D SubsamplingLayer's activate method INDArray acts = super.activate(training, workspaceMgr); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/Subsampling3DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/Subsampling3DLayer.java index 8c673a4f22b7..b02baa4c3f74 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/Subsampling3DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/Subsampling3DLayer.java @@ -28,6 +28,7 @@ import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.util.Convolution3DUtils; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -47,16 +48,12 @@ public class Subsampling3DLayer extends AbstractLayer backpropGradient(INDArray epsilon, LayerWorkspac pad = layerConf().getPadding(); } - INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, - isNCDHW ? new int[]{miniBatch, inChannels, inD, inH, inW} : new int[]{miniBatch, inD, inH, inW, inChannels}, 'c'); + INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, epsilon.dataType(), + isNCDHW ? new long[]{miniBatch, inChannels, inD, inH, inW} : new long[]{miniBatch, inD, inH, inW, inChannels}, 'c'); int[] intArgs = new int[]{ @@ -179,8 +176,8 @@ public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { String opName = layerConf().getPoolingType() == PoolingType.MAX ? "maxpool3dnew" : "avgpool3dnew"; - INDArray output = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, - isNCDHW ? new int[]{miniBatch, inChannels, outD, outH, outW} : new int[]{miniBatch, outD, outH, outW, inChannels}, 'c'); + INDArray output = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), + isNCDHW ? 
new long[]{miniBatch, inChannels, outD, outH, outW} : new long[]{miniBatch, outD, outH, outW, inChannels}, 'c'); int[] intArgs = new int[]{ kernel[0], kernel[1], kernel[2], diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/SubsamplingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/SubsamplingLayer.java index 85c521bc1a65..e0c6190d7131 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/SubsamplingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/SubsamplingLayer.java @@ -29,6 +29,7 @@ import org.deeplearning4j.nn.layers.LayerHelper; import org.deeplearning4j.nn.layers.mkldnn.MKLDNNSubsamplingHelper; import org.deeplearning4j.util.ConvolutionUtils; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; @@ -62,24 +63,19 @@ public class SubsamplingLayer extends AbstractLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); + INDArray input = this.input.castTo(dataType); + // FIXME: int cast int miniBatch = (int) input.size(0); int inDepth = (int) input.size(1); @@ -139,7 +137,6 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac int outH = outSize[0]; int outW = outSize[1]; - if (helper != null && (helperCountFail == 0 || !layerConf().isCudnnAllowFallback())) { Pair ret = null; try{ @@ -206,12 +203,12 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac INDArray epsilon1d; if (cOrderStrides) { //"Dense/Output layer above strides... i.e., standard c-order strides - col6d = Nd4j.create(new int[] {miniBatch, inDepth, outH, outW, kernel[0], kernel[1]}, 'c'); + col6d = Nd4j.create(dataType, new long[] {miniBatch, inDepth, outH, outW, kernel[0], kernel[1]}, 'c'); col6dPermuted = col6d.permute(0, 1, 4, 5, 2, 3); epsilon1d = epsilon.reshape('c', ArrayUtil.prod(epsilon.length()), 1); //zero copy reshape } else { //"CNN layer above" strides... - col6d = Nd4j.create(new int[] {inDepth, miniBatch, outH, outW, kernel[0], kernel[1]}, 'c'); + col6d = Nd4j.create(dataType, new long[] {inDepth, miniBatch, outH, outW, kernel[0], kernel[1]}, 'c'); col6dPermuted = col6d.permute(1, 0, 4, 5, 2, 3); INDArray epsilonTemp = epsilon.permute(1, 0, 2, 3); @@ -276,7 +273,7 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac // c-order [channels*H*W, H*W, W, 1] strides //To achieve this: [channels, miniBatch, H, W] in c order, then permute to [miniBatch, channels, H, W] //This gives us proper strides of 1 on the muli... 
- INDArray tempEpsilon = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, new int[] {inDepth, miniBatch, inH, inW}, 'c'); + INDArray tempEpsilon = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, dataType, new long[] {inDepth, miniBatch, inH, inW}, 'c'); INDArray outEpsilon = tempEpsilon.permute(1, 0, 2, 3); Convolution.col2im(col6dPermuted, outEpsilon, strides[0], strides[1], pad[0], pad[1], inputHeight, inputWidth, dilation[0], dilation[1]); @@ -314,6 +311,8 @@ public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { + layerId()); } + INDArray input = this.input.castTo(dataType); + // FIXME: int cast int miniBatch = (int) input.size(0); int inDepth = (int) input.size(1); @@ -360,7 +359,7 @@ public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { //Similar to convolution layer forward pass: do im2col, but permute so that pooling can be done with efficient strides... //Current im2col implementation expects input with shape [miniBatch,channels,kH,kW,outH,outW] - INDArray output = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new int[]{miniBatch, inDepth, outH, outW}, 'c'); + INDArray output = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), new long[]{miniBatch, inDepth, outH, outW}, 'c'); LegacyPooling2D.Pooling2DType pt; double extra = 0.0; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling1D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling1D.java index 8345bb57f08b..9f84e28959f6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling1D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling1D.java @@ -22,6 +22,8 @@ import org.deeplearning4j.nn.conf.layers.BaseUpsamplingLayer; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.workspace.ArrayType; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -46,12 +48,8 @@ public class Upsampling1D extends Upsampling2D { - public Upsampling1D(NeuralNetConfiguration conf) { - super(conf); - } - - public Upsampling1D(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); + public Upsampling1D(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @@ -65,7 +63,7 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac epsilon = epsilon.repeat(3, size[0]); INDArray originalInput = input; - input = input.reshape(input.size(0), input.size(1), input.size(2), 1); + input = input.castTo(dataType).reshape(input.size(0), input.size(1), input.size(2), 1); // FIXME: int cast int miniBatch = (int) input.size(0); @@ -74,7 +72,7 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac int inW = (int) input.size(3); - INDArray outEpsilon = Nd4j.create(miniBatch * inDepth * inH * inW); + INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), miniBatch * inDepth * inH * inW); INDArray reshapedEpsilon = outEpsilon.reshape('c', miniBatch, inDepth, inH, inW); int[] intArgs = new int[] {1}; // 1 is for NCHW @@ -111,7 +109,7 @@ public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { // add singleton fourth 
dimension to input INDArray origInput = input; - input = input.reshape(input.size(0), input.size(1), input.size(2), 1); + input = input.castTo(dataType).reshape(input.size(0), input.size(1), input.size(2), 1); // call 2D SubsamplingLayer's activate method INDArray acts = super.activate(training, workspaceMgr); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling2D.java index d8366c1e52f7..9e2a2d8143ec 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling2D.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.layers.AbstractLayer; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; @@ -47,12 +48,8 @@ public class Upsampling2D extends AbstractLayer { - public Upsampling2D(NeuralNetConfiguration conf) { - super(conf); - } - - public Upsampling2D(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); + public Upsampling2D(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @Override @@ -71,7 +68,7 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac int inH = (int) input.size(2); int inW = (int) input.size(3); - INDArray reshapedEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, new int[]{miniBatch, inDepth, inH, inW}, 'c'); + INDArray reshapedEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, epsilon.dataType(), new long[]{miniBatch, inDepth, inH, inW}, 'c'); Gradient gradient = new DefaultGradient(); @@ -119,7 +116,7 @@ protected INDArray preOutput(boolean training, boolean forBackprop, LayerWorkspa int outH = inH * size[0]; int outW = inW * size[1]; - INDArray reshapedOutput = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new int[]{miniBatch, inDepth, outH, outW}, 'c'); + INDArray reshapedOutput = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), new long[]{miniBatch, inDepth, outH, outW}, 'c'); int[] intArgs = new int[] {size[0], size[1], 1}; // 1 is for NCHW diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling3D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling3D.java index 09273c7caa33..0c35d9fcd924 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling3D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling3D.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.layers.AbstractLayer; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; @@ -47,15 +48,10 @@ public class Upsampling3D extends AbstractLayer { - public Upsampling3D(NeuralNetConfiguration conf) { - 
super(conf); + public Upsampling3D(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } - public Upsampling3D(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); - } - - @Override public double calcRegularizationScore(boolean backpropParamsOnly){ return 0; @@ -82,7 +78,7 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac int[] intArgs = new int[] {1}; // 1 is channels first INDArray reshapedEpsilon = workspaceMgr.createUninitialized( - ArrayType.ACTIVATION_GRAD, new int[]{miniBatch, inChannels, inD, inH, inW}, 'c'); + ArrayType.ACTIVATION_GRAD, epsilon.dataType(), new long[]{miniBatch, inChannels, inD, inH, inW}, 'c'); Gradient gradient = new DefaultGradient(); @@ -119,22 +115,21 @@ protected INDArray preOutput(boolean training, boolean forBackprop, LayerWorkspa return preOutput; } - // FIXME: int cast - int miniBatch = (int) input.size(0); - int inChannels = (int) input.size(1); - int inD = (int) input.size(2); - int inH = (int) input.size(3); - int inW = (int) input.size(4); + long miniBatch = (int) input.size(0); + long inChannels = (int) input.size(1); + long inD = (int) input.size(2); + long inH = (int) input.size(3); + long inW = (int) input.size(4); int[] size = getSize(); - int outD = inD * size[0]; - int outH = inH * size[1]; - int outW = inW * size[2]; + long outD = inD * size[0]; + long outH = inH * size[1]; + long outW = inW * size[2]; int[] intArgs = new int[] {size[0], size[1], size[2], 1}; // 1 is channels first INDArray reshapedOutput = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, - new int[]{miniBatch, inChannels, outD, outH, outW}, 'c'); + input.dataType(), new long[]{miniBatch, inChannels, outD, outH, outW}, 'c'); CustomOp upsampling = DynamicCustomOp.builder("upsampling3d") diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/PReLU.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/PReLU.java index a8b4873a1680..b35d946aad09 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/PReLU.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/PReLU.java @@ -27,6 +27,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationPReLU; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; @@ -45,12 +46,8 @@ public class PReLU extends BaseLayer { - public AutoEncoder(NeuralNetConfiguration conf) { - super(conf); - } - - public AutoEncoder(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); + public AutoEncoder(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @Override @@ -59,7 +56,7 @@ public Pair sampleVisibleGivenHidden(INDArray h) { public INDArray encode(INDArray v, boolean training, LayerWorkspaceMgr workspaceMgr) { INDArray W = getParamWithNoise(PretrainParamInitializer.WEIGHT_KEY, training, workspaceMgr); INDArray hBias = getParamWithNoise(PretrainParamInitializer.BIAS_KEY, training, workspaceMgr); - INDArray ret = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, v.size(0), W.size(1)); + INDArray ret = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, W.dataType(), v.size(0), W.size(1)); INDArray preAct = v.castTo(W.dataType()).mmuli(W, 
ret).addiRowVector(hBias); ret = layerConf().getActivationFn().getActivation(preAct, training); return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, ret); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseLayer.java index b95420631643..fa0b893b3748 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseLayer.java @@ -18,6 +18,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.layers.BaseLayer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -25,12 +26,8 @@ * @author Adam Gibson */ public class DenseLayer extends BaseLayer { - public DenseLayer(NeuralNetConfiguration conf) { - super(conf); - } - - public DenseLayer(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); + public DenseLayer(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/elementwise/ElementWiseMultiplicationLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/elementwise/ElementWiseMultiplicationLayer.java index 8fe48b913983..f7a05ba5c60e 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/elementwise/ElementWiseMultiplicationLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/elementwise/ElementWiseMultiplicationLayer.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.BaseLayer; import org.deeplearning4j.nn.params.DefaultParamInitializer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -43,12 +44,8 @@ */ public class ElementWiseMultiplicationLayer extends BaseLayer { - public ElementWiseMultiplicationLayer(NeuralNetConfiguration conf){ - super(conf); - } - - public ElementWiseMultiplicationLayer(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); + public ElementWiseMultiplicationLayer(NeuralNetConfiguration conf, DataType dataType){ + super(conf, dataType); } @Override @@ -61,6 +58,8 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac applyMask(delta); } + INDArray input = this.input.castTo(dataType); + Gradient ret = new DefaultGradient(); INDArray weightGrad = gradientViews.get(ElementWiseParamInitializer.WEIGHT_KEY); @@ -105,9 +104,11 @@ public INDArray preOutput(boolean training, LayerWorkspaceMgr workspaceMgr) { + W.shapeInfoToString() + ") " + layerId()); } + INDArray input = this.input.castTo(dataType); + applyDropOutIfNecessary(training, workspaceMgr); - INDArray ret = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.shape(), 'c'); + INDArray ret = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), input.shape(), 'c'); ret.assign(input.mulRowVector(W).addiRowVector(b)); diff --git 
a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayer.java index f230ac31627b..fb0cdef0cf06 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayer.java @@ -18,6 +18,7 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.custom.ScatterUpdate; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.exception.DL4JInvalidInputException; @@ -45,8 +46,8 @@ public class EmbeddingLayer extends BaseLayer { private static final int[] DIM_1 = new int[]{1}; - public EmbeddingLayer(NeuralNetConfiguration conf) { - super(conf); + public EmbeddingLayer(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @Override @@ -57,7 +58,7 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac INDArray delta = layerConf().getActivationFn().backprop(z, epsilon).getFirst(); //TODO handle activation function params if (maskArray != null) { - delta.muliColumnVector(maskArray); + delta.muliColumnVector(maskArray.castTo(dataType)); } INDArray weightGradients = gradientViews.get(DefaultParamInitializer.WEIGHT_KEY); @@ -113,7 +114,7 @@ protected INDArray preOutput(boolean training, LayerWorkspaceMgr workspaceMgr) { INDArray weights = getParam(DefaultParamInitializer.WEIGHT_KEY); INDArray bias = getParam(DefaultParamInitializer.BIAS_KEY); - INDArray destination = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.size(0), weights.size(1)); + INDArray destination = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, weights.dataType(), input.size(0), weights.size(1)); INDArray rows = Nd4j.pullRows(weights, destination, 1, indexes); if(hasBias()){ rows.addiRowVector(bias); @@ -128,7 +129,7 @@ public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { INDArray ret = layerConf().getActivationFn().getActivation(rows, training); if (maskArray != null) { - ret.muliColumnVector(maskArray); + ret.muliColumnVector(maskArray.castTo(dataType)); } return ret; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingSequenceLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingSequenceLayer.java index 216f7c755fbc..5da9bdece131 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingSequenceLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingSequenceLayer.java @@ -27,6 +27,7 @@ import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.custom.ScatterUpdate; import org.nd4j.linalg.factory.Broadcast; @@ -51,8 +52,8 @@ public class EmbeddingSequenceLayer extends BaseLayer { private static final int[] WEIGHT_DIM = new int[]{1}; - public EmbeddingSequenceLayer(NeuralNetConfiguration conf) { - super(conf); + public 
EmbeddingSequenceLayer(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } private int[] indexes; @@ -149,7 +150,7 @@ protected INDArray preOutput(boolean training, LayerWorkspaceMgr workspaceMgr) { val nOut = layerConf().getNOut(); INDArray destination = workspaceMgr.createUninitialized( - ArrayType.ACTIVATIONS, new long[]{minibatch * inputLength, nOut}, 'c'); + ArrayType.ACTIVATIONS, weights.dataType(), new long[]{minibatch * inputLength, nOut}, 'c'); INDArray rows = Nd4j.pullRows(weights, destination, 1, indexes); if (hasBias()) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java index f879028c275a..1dcc556b6777 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java @@ -46,6 +46,10 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper { private INDArray meanCache; private INDArray varCache; + public MKLDNNBatchNormHelper(DataType dataType){ + + } + @Override public boolean checkSupported(double eps, boolean fixedGammaBeta) { return !fixedGammaBeta && BaseMKLDNNHelper.mklDnnEnabled(); @@ -131,12 +135,12 @@ public INDArray preOutput(INDArray x, boolean training, int[] shape, INDArray ga } @Override - public INDArray getMeanCache() { + public INDArray getMeanCache(DataType dataType) { return meanCache; } @Override - public INDArray getVarCache() { + public INDArray getVarCache(DataType dataType) { return varCache; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java index b0b529cb280d..2884f4cedddc 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java @@ -48,6 +48,10 @@ public class MKLDNNConvHelper implements ConvolutionHelper { protected OpContext context; protected OpContext contextBwd; + public MKLDNNConvHelper(DataType dataType){ + + } + @Override public boolean checkSupported() { return BaseMKLDNNHelper.mklDnnEnabled(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLocalResponseNormalizationHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLocalResponseNormalizationHelper.java index cc828aa290a9..1000065f3613 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLocalResponseNormalizationHelper.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLocalResponseNormalizationHelper.java @@ -21,6 +21,7 @@ import org.deeplearning4j.nn.layers.normalization.LocalResponseNormalizationHelper; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization; @@ -41,6 +42,10 @@ public class 
MKLDNNLocalResponseNormalizationHelper extends BaseMKLDNNHelper imp protected OpContext context; + public MKLDNNLocalResponseNormalizationHelper(DataType dataType){ + + } + @Override public boolean checkSupported(double k, double n, double alpha, double beta) { return BaseMKLDNNHelper.mklDnnEnabled(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNSubsamplingHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNSubsamplingHelper.java index e3f155dcb272..ff4b52ac53ac 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNSubsamplingHelper.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNSubsamplingHelper.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.util.ConvolutionUtils; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.OpContext; @@ -48,6 +49,10 @@ public class MKLDNNSubsamplingHelper implements SubsamplingHelper { protected OpContext context; + public MKLDNNSubsamplingHelper(DataType dataType){ + + } + @Override public boolean checkSupported() { return BaseMKLDNNHelper.mklDnnEnabled(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java index d28276d903ac..7f76c2459ca3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java @@ -28,6 +28,7 @@ import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp; @@ -66,8 +67,8 @@ public class BatchNormalization extends BaseLayer backpropGradient(INDArray epsilon, LayerWorkspac val batchSize = epsilon.size(0); // number examples in batch org.deeplearning4j.nn.conf.layers.BatchNormalization layerConf = layerConf(); + INDArray input = this.input.castTo(dataType); //No-op if correct type + INDArray globalMean = params.get(BatchNormalizationParamInitializer.GLOBAL_MEAN); INDArray globalVar = params.get(BatchNormalizationParamInitializer.GLOBAL_VAR); //One of log10std will be null depending on config INDArray globalLog10Std = params.get(BatchNormalizationParamInitializer.GLOBAL_LOG_STD); @@ -121,8 +124,8 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac INDArray dGlobalLog10StdView = gradientViews.get(BatchNormalizationParamInitializer.GLOBAL_LOG_STD); if (layerConf.isLockGammaBeta()) { val tempShape = new long[] {1, shape[1]}; - dGammaView = Nd4j.createUninitialized(tempShape, 'c'); - dBetaView = Nd4j.createUninitialized(tempShape, 'c'); + dGammaView = Nd4j.createUninitialized(dataType, tempShape, 'c'); + dBetaView = Nd4j.createUninitialized(dataType, tempShape, 'c'); } else { gamma = 
getParam(BatchNormalizationParamInitializer.GAMMA); dGammaView = gradientViews.get(BatchNormalizationParamInitializer.GAMMA); @@ -135,7 +138,7 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac if (helper != null && (helperCountFail == 0 || !layerConf().isCudnnAllowFallback())){ //Note that cudnn does not support dense (2d) batch norm case as of v5.1 if (layerConf.isLockGammaBeta()) { - gamma = Nd4j.valueArrayOf(new long[] {1, shape[1]}, layerConf.getGamma()); + gamma = Nd4j.createUninitialized(dataType, 1, shape[1]).assign(layerConf.getGamma()); } INDArray in; @@ -196,8 +199,8 @@ However, because of distributed training (gradient sharing), we don't want to do These make zero difference for local training (other than perhaps when using FP16), but the latter is more numerically stable and is scaled better for distributed training */ - INDArray batchMean = helper.getMeanCache(); - INDArray batchVar = helper.getVarCache(); + INDArray batchMean = helper.getMeanCache(dataType); + INDArray batchVar = helper.getVarCache(dataType); Nd4j.getExecutioner().exec(new OldSubOp(globalMean, batchMean, dGlobalMeanView)); //deltaGlobalMean = globalMean[t] - batchMean dGlobalMeanView.muli(1-layerConf().getDecay()); @@ -207,7 +210,7 @@ These make zero difference for local training (other than perhaps when using FP1 //First: we have log10(var[i]) from last iteration, hence can calculate var[i] and stdev[i] //Need to calculate log10{std[i]) - log10(std[i+1]) as the "update" //Note, var[i+1] = d*var[i] + (1-d)*batchVar - INDArray vari = Nd4j.valueArrayOf(globalLog10Std.shape(), 10.0); + INDArray vari = Nd4j.createUninitialized(dataType, globalLog10Std.shape()).assign(10.0); Transforms.pow(vari, globalLog10Std, false); //variance = (10^log10(s))^2 vari.muli(vari); @@ -230,11 +233,11 @@ These make zero difference for local training (other than perhaps when using FP1 INDArray batchVar; if (epsilon.rank() == 2) { if(xHat == null && helper != null){ - INDArray mean = helper.getMeanCache(); - std = Transforms.sqrt(helper.getVarCache().addi(layerConf().getEps())); - xMu = Nd4j.createUninitialized(input.shape(), input.ordering()); + INDArray mean = helper.getMeanCache(dataType); + std = Transforms.sqrt(helper.getVarCache(dataType).addi(layerConf().getEps())); + xMu = Nd4j.createUninitialized(dataType, input.shape(), input.ordering()); xMu = Nd4j.getExecutioner().exec(new BroadcastSubOp(input, mean, xMu, 1)); - xHat = Nd4j.createUninitialized(input.shape(), input.ordering()); + xHat = Nd4j.createUninitialized(dataType, input.shape(), input.ordering()); xHat = Nd4j.getExecutioner().exec(new BroadcastDivOp(xMu, std,xHat, 1)); } @@ -276,11 +279,11 @@ These make zero difference for local training (other than perhaps when using FP1 batchVar = input.var(false, 0); } else if (epsilon.rank() == 4) { if(xHat == null && helper != null){ - INDArray mean = helper.getMeanCache(); - std = Transforms.sqrt(helper.getVarCache().addi(layerConf().getEps())); - xMu = Nd4j.createUninitialized(input.shape(), input.ordering()); + INDArray mean = helper.getMeanCache(dataType); + std = Transforms.sqrt(helper.getVarCache(dataType).addi(layerConf().getEps())); + xMu = Nd4j.createUninitialized(dataType, input.shape(), input.ordering()); xMu = Nd4j.getExecutioner().exec(new BroadcastSubOp(input, mean, xMu, 1)); - xHat = Nd4j.createUninitialized(input.shape(), input.ordering()); + xHat = Nd4j.createUninitialized(dataType, input.shape(), input.ordering()); xHat = Nd4j.getExecutioner().exec(new BroadcastDivOp(xMu, std,xHat, 1)); } @@ 
-292,7 +295,7 @@ These make zero difference for local training (other than perhaps when using FP1 } else { //Standard case dxhat = Nd4j.getExecutioner().exec(new BroadcastMulOp(epsilon, gamma, - Nd4j.createUninitialized(epsilon.shape(), epsilon.ordering()), 1)); + Nd4j.createUninitialized(epsilon.dataType(), epsilon.shape(), epsilon.ordering()), 1)); } //dL/dVariance @@ -345,7 +348,7 @@ However, because of distributed training (gradient sharing), we don't want to do //First: we have log10(var[i]) from last iteration, hence can calculate var[i] and stdev[i] //Need to calculate log10{std[i]) - log10(std[i+1]) as the "update" //Note, var[i+1] = d*var[i] + (1-d)*batchVar - INDArray vari = Nd4j.valueArrayOf(globalLog10Std.shape(), 10.0); + INDArray vari = Nd4j.valueArrayOf(globalLog10Std.shape(), 10.0, globalMean.dataType()); Transforms.pow(vari, globalLog10Std, false); //variance = (10^log10(s))^2 vari.muli(vari); @@ -398,6 +401,8 @@ public INDArray preOutput(INDArray x, TrainingMode training, LayerWorkspaceMgr w throw new IllegalArgumentException("input.size(1) does not match expected input size of " + layerConf().getNIn() + " - got input array with shape " + Arrays.toString(x.shape())); } + x = x.castTo(dataType); //No-op if correct type + INDArray activations; // TODO add this directly in layer or get the layer prior... // batchnorm true but need to clarify if activation before or after @@ -413,8 +418,8 @@ public INDArray preOutput(INDArray x, TrainingMode training, LayerWorkspaceMgr w if (helper != null && input.rank() == 4) { //TODO: don't create these each iteration, when using cudnn val gammaBetaShape = new long[] {1, layerConf().getNOut()}; - gamma = Nd4j.valueArrayOf(gammaBetaShape, layerConf().getGamma()); - beta = Nd4j.valueArrayOf(gammaBetaShape, layerConf().getBeta()); + gamma = Nd4j.valueArrayOf(gammaBetaShape, layerConf().getGamma(), dataType); + beta = Nd4j.valueArrayOf(gammaBetaShape, layerConf().getBeta(), dataType); } } else { gamma = getParam(BatchNormalizationParamInitializer.GAMMA); @@ -436,7 +441,7 @@ public INDArray preOutput(INDArray x, TrainingMode training, LayerWorkspaceMgr w if(globalVarView == null){ //May be null when useLogStd is true INDArray log10s = getParam(BatchNormalizationParamInitializer.GLOBAL_LOG_STD); - globalVarView = Transforms.pow(Nd4j.valueArrayOf(log10s.shape(), 10.0), log10s, false); + globalVarView = Transforms.pow(Nd4j.valueArrayOf(log10s.shape(), 10.0, dataType), log10s, false); globalVarView.muli(globalVarView); } @@ -495,7 +500,7 @@ public INDArray preOutput(INDArray x, TrainingMode training, LayerWorkspaceMgr w if(layerConf().isUseLogStd()){ //var = (10^(log10(s)))^2 INDArray log10s = getParam(BatchNormalizationParamInitializer.GLOBAL_LOG_STD); - var = Transforms.pow(Nd4j.valueArrayOf(log10s.shape(), 10.0), log10s); + var = Transforms.pow(Nd4j.valueArrayOf(log10s.shape(), 10.0, dataType), log10s); var.muli(var); } else { var = getParam(BatchNormalizationParamInitializer.GLOBAL_VAR); @@ -527,9 +532,9 @@ public INDArray preOutput(INDArray x, TrainingMode training, LayerWorkspaceMgr w } else if (x.rank() == 4) { if (!Shape.strideDescendingCAscendingF(x)) x = x.dup(); //TODO: temp Workaround for broadcast bug. 
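Throughout the BatchNormalization hunks above, default-dtype factory calls are replaced by dtype-explicit ones so that temporaries match the layer's configured dtype rather than the global default. A small sketch of the two idioms, using only the signatures that appear in the diff; dataType, input and globalLog10Std stand for the layer field and arrays already in scope:

    // constant-valued array in the layer dtype (was Nd4j.valueArrayOf(shape, 10.0), which used the default dtype)
    INDArray vari = Nd4j.valueArrayOf(globalLog10Std.shape(), 10.0, dataType);
    // uninitialized temporary in the layer dtype with a specific ordering
    INDArray xMu  = Nd4j.createUninitialized(dataType, input.shape(), input.ordering());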
To be removed when fixed - xMu = workspaceMgr.createUninitialized(ArrayType.INPUT, x.shape(), x.ordering()); + xMu = workspaceMgr.createUninitialized(ArrayType.INPUT, x.dataType(), x.shape(), x.ordering()); xMu = Nd4j.getExecutioner().exec(new BroadcastSubOp(x, mean,xMu, 1)); - xHat = workspaceMgr.createUninitialized(ArrayType.INPUT, x.shape(), x.ordering()); + xHat = workspaceMgr.createUninitialized(ArrayType.INPUT, x.dataType(), x.shape(), x.ordering()); xHat = Nd4j.getExecutioner().exec(new BroadcastDivOp(xMu, std,xHat, 1)); if (layerConf.isLockGammaBeta()) { @@ -545,7 +550,7 @@ public INDArray preOutput(INDArray x, TrainingMode training, LayerWorkspaceMgr w } } else { //Standard case: gamma and beta are learned per parameter - activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.shape(), x.ordering()); + activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), x.shape(), x.ordering()); activations = Nd4j.getExecutioner().exec(new BroadcastMulOp(xHat, gamma, activations, 1)); activations = Nd4j.getExecutioner().exec(new BroadcastAddOp(activations, beta, activations, 1)); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationHelper.java index e8cf5360df99..8190c32845ea 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationHelper.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationHelper.java @@ -18,6 +18,7 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.LayerHelper; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -36,7 +37,7 @@ Pair backpropGradient(INDArray input, INDArray epsilon, int[ INDArray preOutput(INDArray x, boolean training, int[] shape, INDArray gamma, INDArray beta, INDArray mean, INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr); - INDArray getMeanCache(); + INDArray getMeanCache(DataType dataType); - INDArray getVarCache(); + INDArray getVarCache(DataType dataType); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java index 119ba330e4ed..fdf594aa1e1c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.layers.AbstractLayer; import org.deeplearning4j.nn.layers.LayerHelper; import org.deeplearning4j.nn.layers.mkldnn.MKLDNNLocalResponseNormalizationHelper; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp; import org.nd4j.linalg.factory.Nd4j; @@ -73,18 +74,13 @@ public class LocalResponseNormalization protected LocalResponseNormalizationHelper helper = null; protected int helperCountFail = 0; - public 
LocalResponseNormalization(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); - initializeHelper(); - } - @Override public Layer clone() { - return new LocalResponseNormalization(conf.clone()); + return new LocalResponseNormalization(conf.clone(), dataType); } - public LocalResponseNormalization(NeuralNetConfiguration conf) { - super(conf); + public LocalResponseNormalization(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); initializeHelper(); } @@ -93,7 +89,7 @@ void initializeHelper() { if("CUDA".equalsIgnoreCase(backend)) { try { helper = Class.forName("org.deeplearning4j.nn.layers.normalization.CudnnLocalResponseNormalizationHelper") - .asSubclass(LocalResponseNormalizationHelper.class).newInstance(); + .asSubclass(LocalResponseNormalizationHelper.class).getConstructor(DataType.class).newInstance(dataType); log.debug("CudnnLocalResponseNormalizationHelper successfully initialized"); } catch (Throwable t) { if (!(t instanceof ClassNotFoundException)) { @@ -187,7 +183,7 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac } // gx = gy * unitScale**-beta - 2 * alpha * beta * sumPart/unitScale * a^i_{x,y} - rearranged for more in-place ops - INDArray nextEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, epsilon.shape(), epsilon.ordering()); + INDArray nextEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, epsilon.dataType(), epsilon.shape(), epsilon.ordering()); Nd4j.getExecutioner().exec(new OldMulOp(epsilon, scale, nextEpsilon)); nextEpsilon.subi(sumPart.muli(input).divi(unitScale).muli(2 * alpha * beta)); return new Pair<>(retGradient, nextEpsilon); @@ -250,7 +246,7 @@ private Triple activateHelper(boolean training, Laye INDArray unitScale = null; INDArray scale = null; - INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.shape(), input.ordering()); + INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), input.shape(), input.ordering()); if(forBackprop) { // unitScale = (k + alpha * sum_{j=max(0, i - n/2)}^{max(N-1, i + n/2)} (a^j_{x,y})^2 ) unitScale = sumPart.mul(alpha).addi(k); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java index 6d3b65d37566..2610603bc2b0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java @@ -94,8 +94,8 @@ public class Yolo2OutputLayer extends AbstractLayer dLc/diou = 2*1^(obj)*(iou-predicted) + 2 * lambdaNoObj * 1^(noobj) * (iou-predicted) = 2*(iou-predicted) * (1^(obj) + lambdaNoObj * 1^(noobj)) INDArray twoIOUSubPredicted = iou.subi(predictedConfidence).muli(2.0); //Shape: [mb, b, h, w]. Note that when an object is present, IOU and confidence are the same. 
In-place to avoid copy op (iou no longer needed) - INDArray dLc_dIOU = twoIOUSubPredicted.muli(mask1_ij_noobj.castTo(Nd4j.defaultFloatingPointType()).muli(lambdaNoObj).addi(mask1_ij_obj)); + INDArray dLc_dIOU = twoIOUSubPredicted.muli(mask1_ij_noobj.castTo(input.dataType()).muli(lambdaNoObj).addi(mask1_ij_obj)); - INDArray dLc_dxy = Nd4j.createUninitialized(iouRet.dIOU_dxy.shape(), iouRet.dIOU_dxy.ordering()); + INDArray dLc_dxy = Nd4j.createUninitialized(iouRet.dIOU_dxy.dataType(), iouRet.dIOU_dxy.shape(), iouRet.dIOU_dxy.ordering()); Broadcast.mul(iouRet.dIOU_dxy, dLc_dIOU, dLc_dxy, 0, 1, 3, 4); //[mb, b, h, w] x [mb, b, 2, h, w] - INDArray dLc_dwh = Nd4j.createUninitialized(iouRet.dIOU_dwh.shape(), iouRet.dIOU_dwh.ordering()); + INDArray dLc_dwh = Nd4j.createUninitialized(iouRet.dIOU_dwh.dataType(), iouRet.dIOU_dwh.shape(), iouRet.dIOU_dwh.ordering()); Broadcast.mul(iouRet.dIOU_dwh, dLc_dIOU, dLc_dwh, 0, 1, 3, 4); //[mb, b, h, w] x [mb, b, 2, h, w] @@ -436,16 +438,16 @@ private static IOURet calculateIOULabelPredicted(INDArray labelTL, INDArray labe int gridW = (int) labelTL.size(3); //Add grid positions to the predicted XY values (to get predicted XY in terms of grid cell units in image, // from (0 to 1 in grid cell) format) - INDArray linspaceX = Nd4j.linspace(0, gridW-1, gridW, Nd4j.dataType()); - INDArray linspaceY = Nd4j.linspace(0, gridH-1, gridH, Nd4j.dataType()); - INDArray grid = Nd4j.createUninitialized(new int[]{2, gridH, gridW}, 'c'); + INDArray linspaceX = Nd4j.linspace(0, gridW-1, gridW, predictedWH.dataType()); + INDArray linspaceY = Nd4j.linspace(0, gridH-1, gridH, predictedWH.dataType()); + INDArray grid = Nd4j.createUninitialized(predictedWH.dataType(), new long[]{2, gridH, gridW}, 'c'); INDArray gridX = grid.get(point(0), all(), all()); INDArray gridY = grid.get(point(1), all(), all()); Broadcast.copy(gridX, linspaceX, gridX, 1); Broadcast.copy(gridY, linspaceY, gridY, 0); //Calculate X/Y position overall (in grid box units) from "position in current grid box" format - INDArray predictedXY = Nd4j.createUninitialized(predictedXYinGridBox.shape(), predictedXYinGridBox.ordering()); + INDArray predictedXY = predictedXYinGridBox.ulike();; Broadcast.add(predictedXYinGridBox, grid, predictedXY, 2,3,4); // [2, H, W] to [mb, b, 2, H, W] @@ -453,9 +455,9 @@ private static IOURet calculateIOULabelPredicted(INDArray labelTL, INDArray labe INDArray predictedTL_XY = halfWH.rsub(predictedXY); //xy - 0.5 * wh INDArray predictedBR_XY = halfWH.add(predictedXY); //xy + 0.5 * wh - INDArray maxTL = Nd4j.createUninitialized(predictedTL_XY.shape(), predictedTL_XY.ordering()); //Shape: [mb, b, 2, H, W] + INDArray maxTL = predictedTL_XY.ulike(); //Shape: [mb, b, 2, H, W] Broadcast.max(predictedTL_XY, labelTL, maxTL, 0, 2, 3, 4); - INDArray minBR = Nd4j.createUninitialized(predictedBR_XY.shape(), predictedBR_XY.ordering()); + INDArray minBR = predictedBR_XY.ulike(); Broadcast.min(predictedBR_XY, labelBR, minBR, 0, 2, 3, 4); INDArray diff = minBR.sub(maxTL); @@ -479,7 +481,7 @@ private static IOURet calculateIOULabelPredicted(INDArray labelTL, INDArray labe Broadcast.mul(intMask, objectPresentMaskBool, intMask, 0, 2, 3); //Mask the intersection area: should be 0 if no intersection - intMask = intMask.castTo(Nd4j.defaultFloatingPointType()); + intMask = intMask.castTo(predictedWH.dataType()); intersectionArea.muli(intMask); @@ -501,11 +503,11 @@ private static IOURet calculateIOULabelPredicted(INDArray labelTL, INDArray labe //Finally, calculate derivatives: INDArray maskMaxTL = 
Nd4j.createUninitialized(DataType.BOOL, maxTL.shape(), maxTL.ordering()); //1 if predicted Top/Left is max, 0 otherwise Broadcast.gt(predictedTL_XY, labelTL, maskMaxTL, 0, 2, 3, 4); // z = x > y - maskMaxTL = maskMaxTL.castTo(Nd4j.defaultFloatingPointType()); + maskMaxTL = maskMaxTL.castTo(predictedWH.dataType()); INDArray maskMinBR = Nd4j.createUninitialized(DataType.BOOL, maxTL.shape(), maxTL.ordering()); //1 if predicted Top/Left is max, 0 otherwise Broadcast.lt(predictedBR_XY, labelBR, maskMinBR, 0, 2, 3, 4); // z = x < y - maskMinBR = maskMinBR.castTo(Nd4j.defaultFloatingPointType()); + maskMinBR = maskMinBR.castTo(predictedWH.dataType()); //dI/dx = lambda * (1^(min(x1+w1/2) - 1^(max(x1-w1/2)) //dI/dy = omega * (1^(min(y1+h1/2) - 1^(max(y1-h1/2)) @@ -526,18 +528,18 @@ private static IOURet calculateIOULabelPredicted(INDArray labelTL, INDArray labe INDArray uPlusIDivU2 = uPlusI.div(u2); //Shape: [mb, b, h, w] BooleanIndexing.replaceWhere(uPlusIDivU2, 0.0, Conditions.isNan()); //Handle 0/0 - INDArray dIOU_dxy = Nd4j.createUninitialized(new int[]{mb, b, 2, h, w}, 'c'); + INDArray dIOU_dxy = Nd4j.createUninitialized(predictedWH.dataType(), new long[]{mb, b, 2, h, w}, 'c'); Broadcast.mul(dI_dxy, uPlusIDivU2, dIOU_dxy, 0, 1, 3, 4); //[mb, b, h, w] x [mb, b, 2, h, w] - INDArray predictedHW = Nd4j.createUninitialized(new int[]{mb, b, 2, h, w}, predictedWH.ordering()); + INDArray predictedHW = Nd4j.createUninitialized(predictedWH.dataType(), new long[]{mb, b, 2, h, w}, predictedWH.ordering()); //Next 2 lines: permuting the order... WH to HW along dimension 2 predictedHW.get(all(), all(), point(0), all(), all()).assign(predictedWH.get(all(), all(), point(1), all(), all())); predictedHW.get(all(), all(), point(1), all(), all()).assign(predictedWH.get(all(), all(), point(0), all(), all())); - INDArray Ihw = Nd4j.createUninitialized(predictedHW.shape(), predictedHW.ordering()); + INDArray Ihw = predictedHW.ulike();; Broadcast.mul(predictedHW, intersectionArea, Ihw, 0, 1, 3, 4 ); //Predicted_wh: [mb, b, 2, h, w]; intersection: [mb, b, h, w] - INDArray dIOU_dwh = Nd4j.createUninitialized(new int[]{mb, b, 2, h, w}, 'c'); + INDArray dIOU_dwh = Nd4j.createUninitialized(predictedHW.dataType(), new long[]{mb, b, 2, h, w}, 'c'); Broadcast.mul(dI_dwh, uPlusI, dIOU_dwh, 0, 1, 3, 4); dIOU_dwh.subi(Ihw); Broadcast.div(dIOU_dwh, u2, dIOU_dwh, 0, 1, 3, 4); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/YoloUtils.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/YoloUtils.java index 4a787287c8c2..6bff5f207021 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/YoloUtils.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/YoloUtils.java @@ -51,7 +51,7 @@ public static INDArray activate(INDArray boundingBoxPriors, INDArray input, Laye int b = (int) boundingBoxPriors.size(0); int c = (int) (input.size(1)/b)-5; //input.size(1) == b * (5 + C) -> C = (input.size(1)/b) - 5 - INDArray output = layerWorkspaceMgr.create(ArrayType.ACTIVATIONS, input.shape(), 'c'); + INDArray output = layerWorkspaceMgr.create(ArrayType.ACTIVATIONS, input.dataType(), input.shape(), 'c'); INDArray output5 = output.reshape('c', mb, b, 5+c, h, w); INDArray output4 = output; //output.get(all(), interval(0,5*b), all(), all()); INDArray input4 = input.dup('c'); //input.get(all(), interval(0,5*b), all(), all()).dup('c'); diff --git 
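Several of the Yolo2OutputLayer allocations above switch to ulike(), which creates an uninitialized array with the same shape, ordering and data type as the receiver; it replaces the shape-and-ordering-only createUninitialized calls that silently used the default floating-point type. A minimal sketch of the equivalence (variable names are illustrative; assumes the DataType and Nd4j imports used in the file):

    INDArray predicted = Nd4j.create(DataType.FLOAT, new long[]{2, 3, 4}, 'c');
    // before: dtype comes from the global default and may not match 'predicted'
    INDArray oldStyle  = Nd4j.createUninitialized(predicted.shape(), predicted.ordering());
    // after: same shape, ordering and dtype as 'predicted'
    INDArray newStyle  = predicted.ulike();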
a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayer.java index a0fc62a6af18..5dda9e697cea 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayer.java @@ -28,6 +28,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationReLU; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Broadcast; @@ -63,20 +64,13 @@ public class OCNNOutputLayer extends BaseOutputLayer backpropGradient(INDArray epsilon, LayerWorkspac long inputShape = (( org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) this.getConf().getLayer()).getNIn(); INDArray delta = pair.getSecond(); //4 x 150 - INDArray epsilonNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, new long[]{inputShape, delta.length()}, 'f'); + INDArray epsilonNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[]{inputShape, delta.length()}, 'f'); epsilonNext = epsilonNext.assign(delta.broadcast(epsilonNext.shape())).transpose(); //Normally we would clear weightNoiseParams here - but we want to reuse them for forward + backward + score @@ -183,7 +177,7 @@ else if(conf.getLastEpochSinceRUpdated() != epochCount) { //dG -> sigmoid derivative INDArray firstVertDerivV = layerConf().getActivationFn() - .backprop(xTimesV.dup(),Nd4j.ones(xTimesV.shape())) + .backprop(xTimesV.dup(),Nd4j.ones(input.dataType(), xTimesV.shape())) .getFirst().muliRowVector(getParam(W_KEY).neg()); firstVertDerivV = firstVertDerivV.muliColumnVector(delta) .reshape('f',input.size(0),1,layerConf().getHiddenSize()); @@ -195,7 +189,7 @@ else if(conf.getLastEpochSinceRUpdated() != epochCount) { shape[i] = Math.max(firstVertDerivV.size(i),secondTermDerivV.size(i)); } - INDArray firstDerivVBroadcast = Nd4j.createUninitialized(shape); + INDArray firstDerivVBroadcast = Nd4j.createUninitialized(input.dataType(), shape); INDArray mulResult = firstVertDerivV.broadcast(firstDerivVBroadcast); int[] bcDims = {0,1}; @@ -257,10 +251,10 @@ private INDArray doOutput(boolean training,LayerWorkspaceMgr workspaceMgr) { INDArray v = getParamWithNoise(V_KEY,training,workspaceMgr); applyDropOutIfNecessary(training, workspaceMgr); - INDArray first = Nd4j.createUninitialized(input.size(0), v.size(1)); + INDArray first = Nd4j.createUninitialized(input.dataType(), input.size(0), v.size(1)); input.mmuli(v, first); INDArray act2d = layerConf().getActivationFn().getActivation(first, training); - INDArray output = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS,input.size(0)); + INDArray output = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), input.size(0)); act2d.mmuli(w.reshape(w.length()), output); this.labels = output; return output; @@ -320,7 +314,7 @@ public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivati @Override public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { INDArray preAct = preOutput.rsub(getParam(R_KEY).getDouble(0)); - INDArray target = relu.backprop(preAct,Nd4j.ones(preAct.shape())).getFirst(); + INDArray target = 
relu.backprop(preAct,Nd4j.ones(preOutput.dataType(), preAct.shape())).getFirst(); return target; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ocnn/OCNNParamInitializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ocnn/OCNNParamInitializer.java index 3b5d1e3004ba..e9e0746e01e4 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ocnn/OCNNParamInitializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ocnn/OCNNParamInitializer.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitUtil; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -105,6 +106,7 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi Map params = Collections.synchronizedMap(new LinkedHashMap()); val nIn = ocnnOutputLayer.getNIn(); int hiddenLayer = ocnnOutputLayer.getHiddenSize(); + Preconditions.checkState(hiddenLayer > 0, "OCNNOutputLayer hidden layer state: must be non-zero."); val firstLayerWeightLength = hiddenLayer; val secondLayerLength = nIn * hiddenLayer; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingLayer.java index e8282d642134..413fa0983f78 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingLayer.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.AbstractLayer; import org.deeplearning4j.util.MaskedReductionUtil; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp; @@ -76,8 +77,8 @@ public class GlobalPoolingLayer extends AbstractLayer backpropGradient(INDArray epsilon, LayerWorkspac epsilon = epsilon.reshape(epsilon.ordering(), origShape[0], origShape[1]); } + INDArray input = this.input.castTo(dataType); //No-op if already correct dtype + Gradient retGradient = new DefaultGradient(); //Empty: no params int[] poolDim = null; @@ -267,7 +270,7 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac epsilonNd = MaskedReductionUtil.maskedPoolingEpsilonTimeSeries(poolingType, input, maskArray, epsilon, pNorm); } else if (input.rank() == 4) { - epsilonNd = MaskedReductionUtil.maskedPoolingEpsilonCnn(poolingType, input, maskArray, epsilon, pNorm); + epsilonNd = MaskedReductionUtil.maskedPoolingEpsilonCnn(poolingType, input, maskArray, epsilon, pNorm, dataType); } else { throw new UnsupportedOperationException(layerId()); } @@ -301,13 +304,13 @@ private INDArray epsilonHelperFullArray(INDArray inputArray, INDArray epsilon, i for (int d : poolDim) { n *= inputArray.size(d); } - INDArray ret = Nd4j.create(inputArray.shape()); + INDArray ret = inputArray.ulike(); Nd4j.getExecutioner().exec(new BroadcastCopyOp(ret, epsilon, ret, broadcastDims)); ret.divi(n); return ret; case SUM: - INDArray retSum = Nd4j.create(inputArray.shape()); + INDArray retSum = inputArray.ulike(); 
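The OCNNOutputLayer changes above apply the same dtype discipline when an all-ones epsilon is passed to IActivation.backprop: the ones array is created in the dtype of the pre-activations rather than the global default, so no mixed-precision buffers are produced. A short sketch under that assumption, with a ReLU activation as in the layer:

    IActivation relu = new ActivationReLU();
    INDArray preAct  = Nd4j.create(DataType.FLOAT, new long[]{4, 10}, 'c');
    // ones in the same dtype as preAct
    INDArray ones    = Nd4j.ones(preAct.dataType(), preAct.shape());
    INDArray dLdz    = relu.backprop(preAct.dup(), ones).getFirst();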
Nd4j.getExecutioner().exec(new BroadcastCopyOp(retSum, epsilon, retSum, broadcastDims)); return retSum; case PNORM: diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BaseRecurrentLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BaseRecurrentLayer.java index bef14b7255d4..1590adcc7e85 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BaseRecurrentLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BaseRecurrentLayer.java @@ -19,6 +19,7 @@ import org.deeplearning4j.nn.api.layers.RecurrentLayer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.layers.BaseLayer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.HashMap; @@ -40,12 +41,8 @@ public abstract class BaseRecurrentLayer tBpttStateMap = new ConcurrentHashMap<>(); - public BaseRecurrentLayer(NeuralNetConfiguration conf) { - super(conf); - } - - public BaseRecurrentLayer(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); + public BaseRecurrentLayer(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } /** diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalLayer.java index 260c7827856e..dabd417ec0b2 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalLayer.java @@ -285,8 +285,8 @@ public void setParams(INDArray params) { public void setParamsViewArray(INDArray params) { this.paramsView = params; val n = params.length(); - fwd.setParamsViewArray(params.get(point(0), interval(0, n))); - bwd.setParamsViewArray(params.get(point(0), interval(n, 2*n))); + fwd.setParamsViewArray(params.get(interval(0, 0, true), interval(0, n))); + bwd.setParamsViewArray(params.get(interval(0, 0, true), interval(n, 2*n))); } @Override @@ -302,8 +302,8 @@ public void setBackpropGradientsViewArray(INDArray gradients) { this.gradientView = gradients; val n = gradients.length() / 2; - INDArray g1 = gradients.get(point(0), interval(0,n)); - INDArray g2 = gradients.get(point(0), interval(n, 2*n)); + INDArray g1 = gradients.get(interval(0, 0, true), interval(0,n)); + INDArray g2 = gradients.get(interval(0, 0, true), interval(n, 2*n)); fwd.setBackpropGradientsViewArray(g1); bwd.setBackpropGradientsViewArray(g2); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java index 164810f3e361..78e15e167c1d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.params.GravesBidirectionalLSTMParamInitializer; +import org.nd4j.linalg.api.buffer.DataType; import 
org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -57,12 +58,8 @@ public class GravesBidirectionalLSTM protected FwdPassReturn cachedPassForward; protected FwdPassReturn cachedPassBackward; - public GravesBidirectionalLSTM(NeuralNetConfiguration conf) { - super(conf); - } - - public GravesBidirectionalLSTM(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); + public GravesBidirectionalLSTM(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java index d1d9ced3f956..a2f38b32432d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.params.GravesLSTMParamInitializer; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -49,12 +50,8 @@ public class GravesLSTM extends BaseRecurrentLayer 1 || forBackprop) { wFFTranspose = Shape.toMmulCompatible(wFFTranspose); @@ -166,19 +164,19 @@ static public FwdPassReturn activateHelper(final BaseLayer layer, final NeuralNe if (training && cacheMode != CacheMode.NONE && workspaceMgr.hasConfiguration(ArrayType.FF_CACHE) && workspaceMgr.isWorkspaceOpen(ArrayType.FF_CACHE)) { try (MemoryWorkspace wsB = workspaceMgr.notifyScopeBorrowed(ArrayType.FF_CACHE)) { - outputActivations = Nd4j.create(new int[] {miniBatchSize, hiddenLayerSize, timeSeriesLength}, 'f'); //F order to keep time steps together + outputActivations = Nd4j.create(inputWeights.dataType(), new long[] {miniBatchSize, hiddenLayerSize, timeSeriesLength}, 'f'); //F order to keep time steps together toReturn.fwdPassOutput = outputActivations; } } else { - outputActivations = workspaceMgr.create(ArrayType.ACTIVATIONS, new int[] {miniBatchSize, hiddenLayerSize, timeSeriesLength}, 'f'); //F order to keep time steps together + outputActivations = workspaceMgr.create(ArrayType.ACTIVATIONS, input.dataType(), new long[] {miniBatchSize, hiddenLayerSize, timeSeriesLength}, 'f'); //F order to keep time steps together toReturn.fwdPassOutput = outputActivations; } } else { - outputActivations = workspaceMgr.create(ArrayType.ACTIVATIONS, new int[] {miniBatchSize, hiddenLayerSize, timeSeriesLength}, 'f'); //F order to keep time steps together + outputActivations = workspaceMgr.create(ArrayType.ACTIVATIONS, input.dataType(), new long[] {miniBatchSize, hiddenLayerSize, timeSeriesLength}, 'f'); //F order to keep time steps together toReturn.fwdPassOutput = outputActivations; } - Level1 l1BLAS = Nd4j.getBlasWrapper().level1(); + //Level1 l1BLAS = Nd4j.getBlasWrapper().level1(); //Input validation: check input data matches nIn if (input.size(1) != inputWeights.size(0)) { @@ -197,7 +195,7 @@ static public FwdPassReturn activateHelper(final BaseLayer layer, final NeuralNe //initialize prevOutputActivations to zeroes if (prevOutputActivations == null) { - prevOutputActivations = Nd4j.zeros(new int[] {miniBatchSize, hiddenLayerSize}); + 
prevOutputActivations = Nd4j.zeros(input.dataType(), new long[] {miniBatchSize, hiddenLayerSize}); } if (helper != null) { @@ -232,7 +230,7 @@ static public FwdPassReturn activateHelper(final BaseLayer layer, final NeuralNe ifogActivations.addiRowVector(biases); INDArray inputActivations = - ifogActivations.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize)); + ifogActivations.get(all(), interval(0, hiddenLayerSize)); if (forBackprop) { if(shouldCache(training, cacheMode, workspaceMgr)){ cacheEnter(training, cacheMode, workspaceMgr); @@ -253,11 +251,11 @@ static public FwdPassReturn activateHelper(final BaseLayer layer, final NeuralNe } } - INDArray forgetGateActivations = ifogActivations.get(NDArrayIndex.all(), - NDArrayIndex.interval(hiddenLayerSize, 2 * hiddenLayerSize)); + INDArray forgetGateActivations = ifogActivations.get(all(), + interval(hiddenLayerSize, 2 * hiddenLayerSize)); if (hasPeepholeConnections) { INDArray pmcellWFF = prevMemCellState.dup('f').muliRowVector(wFFTranspose); - l1BLAS.axpy(pmcellWFF.length(), 1.0, pmcellWFF, forgetGateActivations); //y = a*x + y i.e., forgetGateActivations.addi(pmcellWFF) + forgetGateActivations.addi(pmcellWFF); } //Above line: treats matrix as a vector. Can only do this because we're sure both pwcelWFF and forgetGateACtivations are f order, offset 0 and have same strides if (forBackprop && !sigmoidGates) { @@ -282,11 +280,11 @@ static public FwdPassReturn activateHelper(final BaseLayer layer, final NeuralNe } - INDArray inputModGateActivations = ifogActivations.get(NDArrayIndex.all(), - NDArrayIndex.interval(3 * hiddenLayerSize, 4 * hiddenLayerSize)); + INDArray inputModGateActivations = ifogActivations.get(all(), + interval(3 * hiddenLayerSize, 4 * hiddenLayerSize)); if (hasPeepholeConnections) { INDArray pmcellWGG = prevMemCellState.dup('f').muliRowVector(wGGTranspose); - l1BLAS.axpy(pmcellWGG.length(), 1.0, pmcellWGG, inputModGateActivations); //inputModGateActivations.addi(pmcellWGG) + inputModGateActivations.addi(pmcellWGG); } if (forBackprop && !sigmoidGates) { cacheEnter(training, cacheMode, workspaceMgr); @@ -317,13 +315,13 @@ static public FwdPassReturn activateHelper(final BaseLayer layer, final NeuralNe currentMemoryCellState = workspaceMgr.leverageTo(ArrayType.FF_WORKING_MEM, forgetGateActivations.muli(prevMemCellState)); //TODO optimize without the copy inputModMulInput = inputModGateActivations.muli(inputActivations); } - l1BLAS.axpy(currentMemoryCellState.length(), 1.0, inputModMulInput, currentMemoryCellState); //currentMemoryCellState.addi(inputModMulInput) + currentMemoryCellState.addi(inputModMulInput); - INDArray outputGateActivations = ifogActivations.get(NDArrayIndex.all(), - NDArrayIndex.interval(2 * hiddenLayerSize, 3 * hiddenLayerSize)); + INDArray outputGateActivations = ifogActivations.get(all(), + interval(2 * hiddenLayerSize, 3 * hiddenLayerSize)); if (hasPeepholeConnections) { INDArray pmcellWOO = currentMemoryCellState.dup('f').muliRowVector(wOOTranspose); - l1BLAS.axpy(pmcellWOO.length(), 1.0, pmcellWOO, outputGateActivations); //outputGateActivations.addi(pmcellWOO) + outputGateActivations.addi(pmcellWOO); } if (forBackprop && !sigmoidGates) { cacheEnter(training, cacheMode, workspaceMgr); @@ -365,7 +363,7 @@ static public FwdPassReturn activateHelper(final BaseLayer layer, final NeuralNe // incorrectly using activations from masked time steps (i.e., want 0 initialization in both directions) //We *also* need to apply this to the memory cells, as they are carried forward //Mask array has shape 
[minibatch, timeSeriesLength] -> get column - INDArray timeStepMaskColumn = maskArray.getColumn(time); + INDArray timeStepMaskColumn = maskArray.getColumn(time, true); currHiddenUnitActivations.muliColumnVector(timeStepMaskColumn); currentMemoryCellState.muliColumnVector(timeStepMaskColumn); } @@ -427,7 +425,7 @@ private static void cacheExit(boolean training, CacheMode cacheMode, LayerWorksp } static public Pair backpropGradientHelper(final NeuralNetConfiguration conf, - final IActivation gateActivationFn, final INDArray input, final INDArray recurrentWeights, //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG] + final IActivation gateActivationFn, INDArray input, final INDArray recurrentWeights, //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG] final INDArray inputWeights, //Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg] final INDArray epsilon, final boolean truncatedBPTT, final int tbpttBackwardLength, final FwdPassReturn fwdPass, final boolean forwards, final String inputWeightKey, @@ -437,6 +435,7 @@ static public Pair backpropGradientHelper(final NeuralNetCon final LSTMHelper helper, final LayerWorkspaceMgr workspaceMgr) { + input = input.castTo(inputWeights.dataType()); //No-op if //Expect errors to have shape: [miniBatchSize,n^(L+1),timeSeriesLength] val hiddenLayerSize = recurrentWeights.size(0); //i.e., n^L @@ -449,28 +448,28 @@ static public Pair backpropGradientHelper(final NeuralNetCon INDArray wOOTranspose = null; INDArray wGGTranspose = null; if (hasPeepholeConnections) { - wFFTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize)).transpose(); - wOOTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize + 1)).transpose(); - wGGTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize + 2)).transpose(); + wFFTranspose = recurrentWeights.get(all(), point(4 * hiddenLayerSize)).reshape(1, recurrentWeights.size(0)); + wOOTranspose = recurrentWeights.get(all(), point(4 * hiddenLayerSize + 1)).reshape(1, recurrentWeights.size(0)); + wGGTranspose = recurrentWeights.get(all(), point(4 * hiddenLayerSize + 2)).reshape(1, recurrentWeights.size(0)); } - INDArray wIFOG = recurrentWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize)); + INDArray wIFOG = recurrentWeights.get(all(), interval(0, 4 * hiddenLayerSize)); //F order here so that content for time steps are together - INDArray epsilonNext = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, new long[] {miniBatchSize, prevLayerSize, timeSeriesLength}, 'f'); //i.e., what would be W^L*(delta^L)^T. Shape: [m,n^(L-1),T] + INDArray epsilonNext = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[] {miniBatchSize, prevLayerSize, timeSeriesLength}, 'f'); //i.e., what would be W^L*(delta^L)^T. 
Shape: [m,n^(L-1),T] INDArray nablaCellStateNext = null; - INDArray deltaifogNext = Nd4j.create(new long[] {miniBatchSize, 4 * hiddenLayerSize}, 'f'); - INDArray deltaiNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize)); - INDArray deltafNext = deltaifogNext.get(NDArrayIndex.all(), - NDArrayIndex.interval(hiddenLayerSize, 2 * hiddenLayerSize)); - INDArray deltaoNext = deltaifogNext.get(NDArrayIndex.all(), - NDArrayIndex.interval(2 * hiddenLayerSize, 3 * hiddenLayerSize)); - INDArray deltagNext = deltaifogNext.get(NDArrayIndex.all(), - NDArrayIndex.interval(3 * hiddenLayerSize, 4 * hiddenLayerSize)); + INDArray deltaifogNext = Nd4j.create(inputWeights.dataType(), new long[] {miniBatchSize, 4 * hiddenLayerSize}, 'f'); + INDArray deltaiNext = deltaifogNext.get(all(), interval(0, hiddenLayerSize)); + INDArray deltafNext = deltaifogNext.get(all(), + interval(hiddenLayerSize, 2 * hiddenLayerSize)); + INDArray deltaoNext = deltaifogNext.get(all(), + interval(2 * hiddenLayerSize, 3 * hiddenLayerSize)); + INDArray deltagNext = deltaifogNext.get(all(), + interval(3 * hiddenLayerSize, 4 * hiddenLayerSize)); - Level1 l1BLAS = Nd4j.getBlasWrapper().level1(); +// Level1 l1BLAS = Nd4j.getBlasWrapper().level1(); long endIdx = 0; if (truncatedBPTT) { @@ -487,14 +486,14 @@ static public Pair backpropGradientHelper(final NeuralNetCon bGradientsOut.assign(0); INDArray rwGradientsIFOG = - rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize)); + rwGradientsOut.get(all(), interval(0, 4 * hiddenLayerSize)); INDArray rwGradientsFF = null; INDArray rwGradientsOO = null; INDArray rwGradientsGG = null; if (hasPeepholeConnections) { - rwGradientsFF = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize)); - rwGradientsOO = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize + 1)); - rwGradientsGG = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize + 2)); + rwGradientsFF = rwGradientsOut.get(all(), NDArrayIndex.point(4 * hiddenLayerSize)).reshape(1, recurrentWeights.size(0)); + rwGradientsOO = rwGradientsOut.get(all(), NDArrayIndex.point(4 * hiddenLayerSize + 1)).reshape(1, recurrentWeights.size(0)); + rwGradientsGG = rwGradientsOut.get(all(), NDArrayIndex.point(4 * hiddenLayerSize + 2)).reshape(1, recurrentWeights.size(0)); } if (helper != null) { @@ -528,10 +527,9 @@ static public Pair backpropGradientHelper(final NeuralNetCon INDArray nablaCellState; if (iTimeIndex != timeSeriesLength - 1 && hasPeepholeConnections) { nablaCellState = deltafNext.dup('f').muliRowVector(wFFTranspose); - l1BLAS.axpy(nablaCellState.length(), 1.0, deltagNext.dup('f').muliRowVector(wGGTranspose), - nablaCellState); + nablaCellState.addi(deltagNext.dup('f').muliRowVector(wGGTranspose)); } else { - nablaCellState = Nd4j.create(new long[]{miniBatchSize, hiddenLayerSize}, 'f'); + nablaCellState = Nd4j.create(inputWeights.dataType(), new long[]{miniBatchSize, hiddenLayerSize}, 'f'); } INDArray prevMemCellState = (iTimeIndex == 0 ? 
fwdPass.prevMemCell : fwdPass.memCellState[(int) (time - inext)]); @@ -567,15 +565,14 @@ static public Pair backpropGradientHelper(final NeuralNetCon //Memory cell error: INDArray temp = afn.backprop(currMemCellState.dup('f'), ao.muli(nablaOut)).getFirst(); //TODO activation functions with params - l1BLAS.axpy(nablaCellState.length(), 1.0, temp, nablaCellState); + nablaCellState.addi(temp); if (hasPeepholeConnections) { INDArray deltaMulRowWOO = deltao.dup('f').muliRowVector(wOOTranspose); - l1BLAS.axpy(nablaCellState.length(), 1.0, deltaMulRowWOO, nablaCellState); //nablaCellState.addi(deltao.mulRowVector(wOOTranspose)); + nablaCellState.addi(deltaMulRowWOO); } if (iTimeIndex != timeSeriesLength - 1) { INDArray nextForgetGateAs = fwdPass.fa[time + inext]; - val length = nablaCellState.length(); - l1BLAS.axpy(length, 1.0, nextForgetGateAs.muli(nablaCellStateNext), nablaCellState); //nablaCellState.addi(nextForgetGateAs.mul(nablaCellStateNext)) + nablaCellState.addi(nextForgetGateAs.muli(nablaCellStateNext)); } @@ -610,7 +607,7 @@ static public Pair backpropGradientHelper(final NeuralNetCon deltag.muli(ai); deltag.muli(nablaCellState); } else { - INDArray temp2 = Nd4j.getExecutioner().exec(new OldMulOp(ai, nablaCellState, Nd4j.createUninitialized(ai.shape(), 'f'))); + INDArray temp2 = Nd4j.getExecutioner().exec(new OldMulOp(ai, nablaCellState, Nd4j.createUninitialized(inputWeights.dataType(), ai.shape(), 'f'))); deltag.assign(gateActivationFn.backprop(fwdPass.gz[time], temp2).getFirst()); //TODO activation functions with params; optimize (no assign) } @@ -619,7 +616,7 @@ static public Pair backpropGradientHelper(final NeuralNetCon //Network input delta: INDArray zi = fwdPass.iz[time]; INDArray deltai = deltaiNext; - temp = Nd4j.getExecutioner().exec(new OldMulOp(ag, nablaCellState, Nd4j.createUninitialized(deltai.shape(), 'f'))); + temp = Nd4j.getExecutioner().exec(new OldMulOp(ag, nablaCellState, Nd4j.createUninitialized(inputWeights.dataType(), deltai.shape(), 'f'))); deltai.assign(afn.backprop(zi, temp).getFirst()); //TODO activation functions with params; also: optimize this (no assign) //Shape: [m,n^L] @@ -629,7 +626,7 @@ static public Pair backpropGradientHelper(final NeuralNetCon if (maskArray != null) { //Mask array is present: bidirectional RNN -> need to zero out these errors to avoid using errors from a masked time step // to calculate the parameter gradients. 
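The LSTMHelpers hunks above drop the Level1 BLAS axpy calls in favour of plain in-place addition; with a coefficient of 1.0 the two are equivalent (y = 1.0*x + y), and addi avoids the BLAS wrapper while staying dtype-agnostic. A minimal sketch of the equivalence:

    INDArray x = Nd4j.create(DataType.FLOAT, new long[]{1, 5}, 'c').assign(2.0);
    INDArray y = Nd4j.create(DataType.FLOAT, new long[]{1, 5}, 'c').assign(3.0);
    // old: Nd4j.getBlasWrapper().level1().axpy(y.length(), 1.0, x, y);   // y = 1.0*x + y
    y.addi(x);                                                            // same result, in place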
Mask array has shape [minibatch, timeSeriesLength] -> get column(this time step) - timeStepMaskColumn = maskArray.getColumn(time); + timeStepMaskColumn = maskArray.getColumn(time, true); deltaifogNext.muliColumnVector(timeStepMaskColumn); //Later, the deltaifogNext is used to calculate: input weight gradients, recurrent weight gradients, bias gradients } @@ -642,12 +639,12 @@ static public Pair backpropGradientHelper(final NeuralNetCon Nd4j.gemm(prevLayerActivationSlice, deltaifogNext, iwGradientsOut, true, false, 1.0, 1.0); } else { INDArray iwGradients_i = - iwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize)); + iwGradientsOut.get(all(), interval(0, hiddenLayerSize)); Nd4j.gemm(prevLayerActivationSlice, deltai, iwGradients_i, true, false, 1.0, 1.0); - INDArray iwGradients_og = iwGradientsOut.get(NDArrayIndex.all(), - NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)); - INDArray deltaog = deltaifogNext.get(NDArrayIndex.all(), - NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)); + INDArray iwGradients_og = iwGradientsOut.get(all(), + interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)); + INDArray deltaog = deltaifogNext.get(all(), + interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)); Nd4j.gemm(prevLayerActivationSlice, deltaog, iwGradients_og, true, false, 1.0, 1.0); } @@ -661,28 +658,26 @@ static public Pair backpropGradientHelper(final NeuralNetCon //Shape: [1,n^L]. sum(0) is sum over examples in mini-batch. //Can use axpy here because result of sum and rwGradients[4 to 6] have order Nd4j.order(), via Nd4j.create() if (hasPeepholeConnections) { - INDArray dLdwFF = deltaf.dup('f').muli(prevMemCellState).sum(0); //mul not mmul because these weights are from unit j->j only (whereas other recurrent weights are i->j for all i,j) - l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwFF, rwGradientsFF); //rwGradients[4].addi(dLdwFF); //dL/dw_{FF} - INDArray dLdwGG = deltag.dup('f').muli(prevMemCellState).sum(0); - l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwGG, rwGradientsGG); //rwGradients[6].addi(dLdwGG); + INDArray dLdwFF = deltaf.dup('f').muli(prevMemCellState).sum(true, 0); //mul not mmul because these weights are from unit j->j only (whereas other recurrent weights are i->j for all i,j) + rwGradientsFF.addi(dLdwFF); + INDArray dLdwGG = deltag.dup('f').muli(prevMemCellState).sum(true, 0); + rwGradientsGG.addi(dLdwGG); } } if (hasPeepholeConnections) { - INDArray dLdwOO = deltao.dup('f').muli(currMemCellState).sum(0); //Expected shape: [n^L,1]. sum(0) is sum over examples in mini-batch. - l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwOO, rwGradientsOO); //rwGradients[5].addi(dLdwOO); //dL/dw_{OOxy} + INDArray dLdwOO = deltao.dup('f').muli(currMemCellState).sum(true, 0); //Expected shape: [n^L,1]. sum(0) is sum over examples in mini-batch. 
+ rwGradientsOO.addi(dLdwOO); } if (iTimeIndex > 0 || prevHiddenUnitActivation != null) { //For time == 0 && no prevMemCellState, equivalent to muli by 0 //Note that prevHiddenUnitActivation may be non-null at t=0 for TBPTT - l1BLAS.axpy(4 * hiddenLayerSize, 1.0, deltaifogNext.sum(0), bGradientsOut); + bGradientsOut.addi(deltaifogNext.sum(true, 0)); } else { - l1BLAS.axpy(hiddenLayerSize, 1.0, deltai.sum(0), bGradientsOut.get(point(0), interval(0, hiddenLayerSize))); //bGradients_i += deltai.sum(0) - INDArray ogBiasToAdd = deltaifogNext.get(NDArrayIndex.all(), - NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)).sum(0); - INDArray ogBiasGrad = bGradientsOut.get(NDArrayIndex.point(0), - NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)); - l1BLAS.axpy(2 * hiddenLayerSize, 1.0, ogBiasToAdd, ogBiasGrad); + bGradientsOut.get(interval(0,0,true), interval(0, hiddenLayerSize)).addi(deltai.sum(true, 0)); + INDArray ogBiasToAdd = deltaifogNext.get(all(), interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)).sum(true, 0); + INDArray ogBiasGrad = bGradientsOut.get(interval(0,0,true), interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)); + ogBiasGrad.addi(ogBiasToAdd); } //Calculate epsilonNext - i.e., equiv. to what would be (w^L*(d^(Lt))^T)^T in a normal network @@ -693,12 +688,10 @@ static public Pair backpropGradientHelper(final NeuralNetCon Nd4j.gemm(deltaifogNext, inputWeights, epsilonNextSlice, false, true, 1.0, 1.0); } else { //No contribution from forget gate at t=0 - INDArray wi = inputWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize)); + INDArray wi = inputWeights.get(all(), interval(0, hiddenLayerSize)); Nd4j.gemm(deltai, wi, epsilonNextSlice, false, true, 1.0, 1.0); - INDArray deltaog = deltaifogNext.get(NDArrayIndex.all(), - NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)); - INDArray wog = inputWeights.get(NDArrayIndex.all(), - NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)); + INDArray deltaog = deltaifogNext.get(all(), interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)); + INDArray wog = inputWeights.get(all(), interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)); Nd4j.gemm(deltaog, wog, epsilonNextSlice, false, true, 1.0, 1.0); //epsilonNextSlice.addi(deltao.mmul(woTranspose)).addi(deltag.mmul(wgTranspose)); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LastTimeStepLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LastTimeStepLayer.java index fb257bf65ae7..8e9f7c8f11f9 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LastTimeStepLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LastTimeStepLayer.java @@ -59,7 +59,7 @@ public Type type() { @Override public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { - INDArray newEps = Nd4j.create(origOutputShape, 'f'); + INDArray newEps = Nd4j.create(epsilon.dataType(), origOutputShape, 'f'); if(lastTimeStepIdxs == null){ //no mask case newEps.put(new INDArrayIndex[]{all(), all(), point(origOutputShape[2]-1)}, epsilon); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayer.java index aea696416883..4e01ea084c71 100644 --- 
a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayer.java @@ -22,7 +22,9 @@ import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer; +import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; import lombok.NonNull; @@ -74,8 +76,8 @@ private void setMaskFromInput(INDArray input) { throw new IllegalArgumentException("Expected input of shape [batch_size, timestep_input_size, timestep], " + "got shape "+Arrays.toString(input.shape()) + " instead"); } - INDArray mask = input.eq(maskingValue).sum(1).neq(input.shape()[1]); - underlying.setMaskArray(mask); + INDArray mask = input.eq(maskingValue).castTo(input.dataType()).sum(1).neq(input.shape()[1]); + underlying.setMaskArray(mask.detach()); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java index 5b2fcaeefac0..e032a2bf11a6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.BaseLayer; import org.deeplearning4j.util.TimeSeriesUtils; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -50,8 +51,8 @@ public class RnnLossLayer extends BaseLayer implements IOutputLayer { @Setter @Getter protected INDArray labels; - public RnnLossLayer(NeuralNetConfiguration conf) { - super(conf); + public RnnLossLayer(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @Override @@ -174,6 +175,8 @@ public boolean isPretrainLayer() { @Override public Pair feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) { + if(maskArray == null) + return null; this.maskArray = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(maskArray, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT); //TODO this.maskState = currentMaskState; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java index 0e2768566f68..e6cac62fe3e8 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java @@ -23,6 +23,7 @@ import org.deeplearning4j.nn.layers.BaseOutputLayer; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.util.TimeSeriesUtils; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.ILossFunction; @@ -43,12 +44,8 @@ */ public class RnnOutputLayer extends BaseOutputLayer { - public 
RnnOutputLayer(NeuralNetConfiguration conf) { - super(conf); - } - - public RnnOutputLayer(NeuralNetConfiguration conf, INDArray input) { - super(conf, input); + public RnnOutputLayer(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @Override @@ -116,7 +113,7 @@ protected INDArray preOutput2d(boolean training, LayerWorkspaceMgr workspaceMgr) protected INDArray getLabels2d(LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) { if (labels.rank() == 3) return TimeSeriesUtils.reshape3dTo2d(labels, workspaceMgr, arrayType); - return workspaceMgr.castTo(arrayType, Nd4j.defaultFloatingPointType(), labels, false); + return labels; } @Override @@ -134,10 +131,10 @@ public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { if (maskArray != null) { if(!maskArray.isColumnVectorOrScalar() || Arrays.equals(maskArray.shape(), act2d.shape())){ //Per output masking - act2d.muli(maskArray); + act2d.muli(maskArray.castTo(act2d.dataType())); } else { //Per time step masking - act2d.muliColumnVector(maskArray); + act2d.muliColumnVector(maskArray.castTo(act2d.dataType())); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java index e91aa9461475..9c05ec1d9dd6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java @@ -25,6 +25,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.base.Preconditions; import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp; @@ -49,8 +50,8 @@ public class SimpleRnn extends BaseRecurrentLayer { public static final String STATE_KEY_PREV_ACTIVATION = "prevAct"; - public SimpleRnn(NeuralNetConfiguration conf) { - super(conf); + public SimpleRnn(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @Override @@ -90,6 +91,8 @@ public Pair tbpttBackpropGradient(INDArray epsilon, int tbpt val nOut = layerConf().getNOut(); + INDArray input = this.input.castTo(dataType); //No-op if correct type + //First: Do forward pass to get gate activations and Zs Quad p = activateHelper(null, true, true, workspaceMgr); @@ -97,15 +100,15 @@ public Pair tbpttBackpropGradient(INDArray epsilon, int tbpt INDArray rw = getParamWithNoise(SimpleRnnParamInitializer.RECURRENT_WEIGHT_KEY, true, workspaceMgr); INDArray b = getParamWithNoise(SimpleRnnParamInitializer.BIAS_KEY, true, workspaceMgr); INDArray g = (hasLayerNorm() ? getParamWithNoise(SimpleRnnParamInitializer.GAIN_KEY, true, workspaceMgr) : null); - INDArray gx = (g != null ? g.get(point(0), interval(0, nOut)) : null); - INDArray gr = (g != null ? g.get(point(0), interval(nOut, nOut * 2)) : null); + INDArray gx = (g != null ? g.get(interval(0, 0, true), interval(0, nOut)) : null); + INDArray gr = (g != null ? g.get(interval(0, 0, true), interval(nOut, nOut * 2)) : null); INDArray wg = gradientViews.get(SimpleRnnParamInitializer.WEIGHT_KEY); INDArray rwg = gradientViews.get(SimpleRnnParamInitializer.RECURRENT_WEIGHT_KEY); INDArray bg = gradientViews.get(SimpleRnnParamInitializer.BIAS_KEY); INDArray gg = (hasLayerNorm() ? 
gradientViews.get(SimpleRnnParamInitializer.GAIN_KEY) : null); - INDArray gxg = (gg != null ? gg.get(point(0), interval(0, nOut)) : null); - INDArray grg = (gg != null ? gg.get(point(0), interval(nOut, nOut * 2)) : null); + INDArray gxg = (gg != null ? gg.get(interval(0, 0, true), interval(0, nOut)) : null); + INDArray grg = (gg != null ? gg.get(interval(0, 0, true), interval(nOut, nOut * 2)) : null); gradientsFlattened.assign(0); @@ -113,7 +116,7 @@ public Pair tbpttBackpropGradient(INDArray epsilon, int tbpt val tsLength = input.size(2); - INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.shape(), 'f'); + INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape(), 'f'); INDArray dldzNext = null; long end; @@ -146,7 +149,7 @@ public Pair tbpttBackpropGradient(INDArray epsilon, int tbpt //Mask array: shape [minibatch, tsLength] //If mask array is present (for example, with bidirectional RNN) -> need to zero out these errors to // avoid using errors from a masked time step to calculate the parameter gradients - maskCol = maskArray.getColumn(i); + maskCol = maskArray.getColumn(i, true); dldzCurrent.muliColumnVector(maskCol); } @@ -216,6 +219,8 @@ private Quad activateHelper(INDArray prevS "3D input expected to RNN layer expected, got " + input.rank()); applyDropOutIfNecessary(training, workspaceMgr); + + INDArray input = this.input.castTo(dataType); //No-op if correct type val m = input.size(0); val tsLength = input.size(2); val nOut = layerConf().getNOut(); @@ -224,13 +229,13 @@ private Quad activateHelper(INDArray prevS INDArray rw = getParamWithNoise(SimpleRnnParamInitializer.RECURRENT_WEIGHT_KEY, training, workspaceMgr); INDArray b = getParamWithNoise(SimpleRnnParamInitializer.BIAS_KEY, training, workspaceMgr); INDArray g = (hasLayerNorm() ? getParamWithNoise(SimpleRnnParamInitializer.GAIN_KEY, training, workspaceMgr) : null); - INDArray gx = (g != null ? g.get(point(0), interval(0, nOut)) : null); - INDArray gr = (g != null ? g.get(point(0), interval(nOut, nOut * 2)) : null); + INDArray gx = (g != null ? g.get(interval(0, 0, true), interval(0, nOut)) : null); + INDArray gr = (g != null ? g.get(interval(0, 0, true), interval(nOut, nOut * 2)) : null); - INDArray out = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new long[]{m, nOut, tsLength}, 'f'); - INDArray outZ = (forBackprop ? workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, out.shape()) : null); - INDArray outPreNorm = (forBackprop && hasLayerNorm() ? workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, out.shape(), 'f') : null); - INDArray recPreNorm = (forBackprop && hasLayerNorm() ? workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, out.shape(), 'f') : null); + INDArray out = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, w.dataType(), new long[]{m, nOut, tsLength}, 'f'); + INDArray outZ = (forBackprop ? workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, w.dataType(), out.shape()) : null); + INDArray outPreNorm = (forBackprop && hasLayerNorm() ? workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, w.dataType(), out.shape(), 'f') : null); + INDArray recPreNorm = (forBackprop && hasLayerNorm() ? 
workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, w.dataType(), out.shape(), 'f') : null); if(input.ordering() != 'f' || Shape.strideDescendingCAscendingF(input)) input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'f'); @@ -258,9 +263,9 @@ private Quad activateHelper(INDArray prevS if(i > 0 || prevStepOut != null){ if(hasLayerNorm()){ - INDArray currRecPreNorm = forBackprop ? recPreNorm.get(all(), all(), point(i)) : workspaceMgr.createUninitialized(ArrayType.FF_WORKING_MEM, currOut.shape(), 'f');; + INDArray currRecPreNorm = forBackprop ? recPreNorm.get(all(), all(), point(i)) : workspaceMgr.createUninitialized(ArrayType.FF_WORKING_MEM, currOut.dataType(), currOut.shape(), 'f');; Nd4j.gemm(prevStepOut, rw, currRecPreNorm, false, false, 1.0, 0.0); - INDArray recNorm = workspaceMgr.createUninitialized(ArrayType.FF_WORKING_MEM, currOut.shape(), 'f'); + INDArray recNorm = workspaceMgr.createUninitialized(ArrayType.FF_WORKING_MEM, currOut.dataType(), currOut.shape(), 'f'); Nd4j.getExecutioner().exec(new LayerNorm(currRecPreNorm, gr, recNorm, 1)); currOut.addi(recNorm); }else { @@ -280,9 +285,10 @@ private Quad activateHelper(INDArray prevS //Apply mask, if present: if(maskArray != null){ //Mask should be shape [minibatch, tsLength] - Nd4j.getExecutioner().exec(new BroadcastMulOp(out, maskArray, out, 0, 2)); + INDArray mask = maskArray.castTo(dataType); + Nd4j.getExecutioner().exec(new BroadcastMulOp(out, mask, out, 0, 2)); if(forBackprop){ - Nd4j.getExecutioner().exec(new BroadcastMulOp(outZ, maskArray, outZ, 0, 2)); + Nd4j.getExecutioner().exec(new BroadcastMulOp(outZ, mask, outZ, 0, 2)); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java index 75f721b96ed0..74a2cfb0dd31 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java @@ -32,6 +32,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; @@ -66,8 +67,8 @@ public class SameDiffGraphVertex extends BaseGraphVertex { private int minibatchSize; public SameDiffGraphVertex(SameDiffVertex config, ComputationGraph graph, String name, int vertexIndex, - INDArray paramsView, boolean initParams) { - super(graph, name, vertexIndex, null, null); + INDArray paramsView, boolean initParams, DataType dataType) { + super(graph, name, vertexIndex, null, null, dataType); this.config = config; SDVertexParams vp = config.getVertexParams(); paramTable = SameDiffParamInitializer.getInstance().subsetAndReshape(vp.getParameterKeys(), @@ -109,7 +110,7 @@ public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) { if(maskArrays != null && maskArrays[i] != null) { sameDiff.associateArrayWithVariable(maskArrays[i].dup(), maskName); }else{ - sameDiff.associateArrayWithVariable(createMask(inputs[i].shape()), maskName); + sameDiff.associateArrayWithVariable(createMask(dataType, inputs[i].shape()), maskName); } } @@ -140,7 +141,7 @@ public Pair doBackward(boolean tbptt, LayerWorkspaceMgr wo 
if(maskArrays != null && maskArrays[i] != null) { sameDiff.associateArrayWithVariable(maskArrays[i].dup(), maskName); }else{ - sameDiff.associateArrayWithVariable(createMask(inputs[i].shape()), maskName); + sameDiff.associateArrayWithVariable(createMask(dataType, inputs[i].shape()), maskName); } } fn.updateVariable(outputVar.getVarName(), epsilon.dup()); @@ -200,7 +201,7 @@ protected void doInit(){ val inputShape = inputs[i++].shape().clone(); SDVariable inputVar = sameDiff.var(s, inputShape); inputVars.put(s, inputVar); - SDVariable maskVar = sameDiff.constant(s + "_mask", createMask(inputShape)); + SDVariable maskVar = sameDiff.constant(s + "_mask", createMask(dataType, inputShape)); maskVars.put(s, maskVar); } @@ -252,12 +253,15 @@ public INDArray getGradientsViewArray() { return gradients; } - static INDArray createMask(long[] shape){ + //Package private + static INDArray createMask(DataType dataType, long[] shape){ switch (shape.length){ case 2: // FF-Type input - return Nd4j.ones(shape[0], 1); + return Nd4j.ones(dataType,shape[0], 1); case 3: // RNN-Type input - return Nd4j.ones(shape[0], shape[2]); + return Nd4j.ones(dataType, shape[0], shape[2]); + case 4: //CNN input + return Nd4j.ones(dataType, shape[0], 1, 1, 1); default: Preconditions.throwEx("Can not create all-ones-mask for given input shape %s.", Arrays.toString(shape)); return null; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java index 7e2febebe4a3..55f86e06602e 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java @@ -27,6 +27,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; @@ -53,8 +54,8 @@ public class SameDiffLayer extends AbstractLayer { protected Map gradTable; - public SameDiffLayer(NeuralNetConfiguration conf){ - super(conf); + public SameDiffLayer(NeuralNetConfiguration conf, DataType dataType){ + super(conf, dataType); } @@ -82,15 +83,13 @@ public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { } try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { -// sameDiff.clearExecutionCache(); org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf(); bl.validateInput(input); - sameDiff.associateArrayWithVariable(input.dup(), sameDiff.getVariable(INPUT_KEY)); if(maskArray != null){ sameDiff.associateArrayWithVariable(maskArray, sameDiff.getVariable(MASK_KEY)); }else{ - sameDiff.associateArrayWithVariable(SameDiffGraphVertex.createMask(input.shape()), sameDiff.getVariable(MASK_KEY)); + sameDiff.associateArrayWithVariable(SameDiffGraphVertex.createMask(dataType, input.shape()), sameDiff.getVariable(MASK_KEY)); } for(String s : paramTable.keySet() ) { sameDiff.associateArrayWithVariable(paramTable.get(s), s); @@ -119,7 +118,7 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac if(maskArray != null){ sameDiff.associateArrayWithVariable(maskArray, 
sameDiff.getVariable(MASK_KEY)); }else{ - sameDiff.associateArrayWithVariable(SameDiffGraphVertex.createMask(input.shape()), sameDiff.getVariable(MASK_KEY)); + sameDiff.associateArrayWithVariable(SameDiffGraphVertex.createMask(dataType, input.shape()), sameDiff.getVariable(MASK_KEY)); } fn.updateVariable(outputVar.getVarName(), epsilon.dup()); @@ -227,17 +226,16 @@ protected void doInit(){ Map p = paramTable(); val inputShape = input.shape().clone(); -// inputShape[0] = -1; //TODO THIS DOESN'T ENABLE VARIABLE SIZE MINIBATCHES - SDVariable inputVar = sameDiff.var(INPUT_KEY, inputShape); + SDVariable inputVar = sameDiff.var(INPUT_KEY, dataType, inputShape); Map paramShapes = layerConf().getLayerParams().getParamShapes(); Map params = new LinkedHashMap<>(); for (String s : paramShapes.keySet()) { val ps = paramShapes.get(s); - SDVariable v = sameDiff.var(s, ps); + SDVariable v = sameDiff.var(s, dataType, ps); params.put(s, v); } - SDVariable mask = sameDiff.constant(MASK_KEY, SameDiffGraphVertex.createMask(inputShape)); + SDVariable mask = sameDiff.constant(MASK_KEY, SameDiffGraphVertex.createMask(dataType, inputShape)); SDVariable layerOutput = bl.defineLayer(sameDiff, inputVar, params, mask); Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null"); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java index a269b81013c5..30c4adcc9c8b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java @@ -30,6 +30,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; @@ -59,12 +60,10 @@ public class SameDiffOutputLayer extends AbstractLayer gradTable; - public SameDiffOutputLayer(NeuralNetConfiguration conf){ - super(conf); + public SameDiffOutputLayer(NeuralNetConfiguration conf, DataType dataType){ + super(conf, dataType); } - - @Override public Layer clone() { throw new UnsupportedOperationException(); @@ -248,11 +247,11 @@ protected void doInit(){ Map p = paramTable(); val inputShape = input.shape().clone(); - SDVariable inputVar = sameDiff.var(INPUT_KEY, inputShape); + SDVariable inputVar = sameDiff.var(INPUT_KEY, dataType, inputShape); SDVariable labelVar = null; if(layerConf().labelsRequired()){ long[] labelShape = labels == null ? 
new long[]{1} : labels.shape().clone(); - labelVar = sameDiff.var(LABELS_KEY, labelShape); + labelVar = sameDiff.var(LABELS_KEY, dataType, labelShape); } Map paramShapes = layerConf().getLayerParams().getParamShapes(); Map params = new LinkedHashMap<>(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/training/CenterLossOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/training/CenterLossOutputLayer.java index 03eea68d2b9e..df697798147b 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/training/CenterLossOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/training/CenterLossOutputLayer.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.BaseOutputLayer; import org.deeplearning4j.nn.params.CenterLossParamInitializer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.ILossFunction; @@ -45,12 +46,8 @@ public class CenterLossOutputLayer extends BaseOutputLayer backpropGradient(INDArray epsilon, LayerWorkspac // centers INDArray centers = params.get(CenterLossParamInitializer.CENTER_KEY); - INDArray centersForExamples = labels.mmul(centers); + INDArray l = labels.castTo(centers.dataType()); //Ensure correct dtype (same as params); no-op if already correct dtype + INDArray centersForExamples = l.mmul(centers); INDArray dLcdai = input.sub(centersForExamples); INDArray w = getParamWithNoise(CenterLossParamInitializer.WEIGHT_KEY, true, workspaceMgr); - INDArray epsilonNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, new long[]{w.size(0), delta.size(0)}, 'f'); + INDArray epsilonNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, w.dataType(), new long[]{w.size(0), delta.size(0)}, 'f'); epsilonNext = w.mmuli(delta.transpose(), epsilonNext).transpose(); double lambda = layerConf().getLambda(); epsilonNext.addi(dLcdai.muli(lambda)); // add center loss here @@ -201,10 +200,11 @@ private Pair getGradientsAndDelta(INDArray preOut, LayerWork double alpha = layerConf().getAlpha(); INDArray centers = params.get(CenterLossParamInitializer.CENTER_KEY); - INDArray centersForExamples = labels.mmul(centers); + INDArray l = labels.castTo(centers.dataType()); //Ensure correct dtype (same as params); no-op if already correct dtype + INDArray centersForExamples = l.mmul(centers); INDArray diff = centersForExamples.sub(input).muli(alpha); - INDArray numerator = labels.transpose().mmul(diff); - INDArray denominator = labels.sum(0).reshape(labels.size(1), 1).addi(1.0); + INDArray numerator = l.transpose().mmul(diff); + INDArray denominator = l.sum(0).reshape(l.size(1), 1).addi(1.0); INDArray deltaC; if (layerConf().getGradientCheck()) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/util/MaskLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/util/MaskLayer.java index 62510db74beb..5436bc2b1763 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/util/MaskLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/util/MaskLayer.java @@ -21,6 +21,7 @@ import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.AbstractLayer; 
+import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Broadcast; import org.nd4j.linalg.primitives.Pair; @@ -38,8 +39,8 @@ public class MaskLayer extends AbstractLayer { private Gradient emptyGradient = new DefaultGradient(); - public MaskLayer(NeuralNetConfiguration conf) { - super(conf); + public MaskLayer(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @Override @@ -86,7 +87,7 @@ private static INDArray applyMask(INDArray input, INDArray maskArray, LayerWorks Arrays.toString(input.shape()) + ", expected 2d mask array with shape [minibatch, sequenceLength]." + " Got mask with shape: "+ Arrays.toString(maskArray.shape())); } - INDArray fwd = workspaceMgr.createUninitialized(type, input.shape(), 'f'); + INDArray fwd = workspaceMgr.createUninitialized(type, input.dataType(), input.shape(), 'f'); Broadcast.mul(input, maskArray, fwd, 0, 2); return fwd; case 4: @@ -102,7 +103,7 @@ private static INDArray applyMask(INDArray input, INDArray maskArray, LayerWorks dimensions = Arrays.copyOfRange(dimensions, 0, count); } - INDArray fwd2 = workspaceMgr.createUninitialized(type, input.shape(), 'c'); + INDArray fwd2 = workspaceMgr.createUninitialized(type, input.dataType(), input.shape(), 'c'); Broadcast.mul(input, maskArray, fwd2, dimensions); return fwd2; default: diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.java index d229d48f4005..9331dda557d5 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.java @@ -38,6 +38,7 @@ import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.api.blas.Level1; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -88,6 +89,7 @@ public class VariationalAutoencoder implements Layer { protected IActivation pzxActivationFn; protected int numSamples; protected CacheMode cacheMode = CacheMode.NONE; + protected DataType dataType; protected boolean zeroedPretrainParamGradients = false; @@ -98,8 +100,9 @@ public class VariationalAutoencoder implements Layer { @Getter @Setter protected int epochCount; - public VariationalAutoencoder(NeuralNetConfiguration conf) { + public VariationalAutoencoder(NeuralNetConfiguration conf, DataType dataType) { this.conf = conf; + this.dataType = dataType; this.encoderLayerSizes = ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) conf.getLayer()) @@ -213,7 +216,7 @@ public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) { for (int l = 0; l < numSamples; l++) { //Default (and in most cases) numSamples == 1 double gemmCConstant = (l == 0 ? 
0.0 : 1.0); //0 for first one (to get rid of previous buffer data), otherwise 1 (for adding) - INDArray e = Nd4j.randn(minibatch, size); + INDArray e = Nd4j.randn(dataType, minibatch, size); INDArray z = pzxSigma.mul(e).addi(meanZ); //z = mu + sigma * e, with e ~ N(0,1) @@ -666,6 +669,8 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac zeroedPretrainParamGradients = true; } + INDArray input = this.input.castTo(dataType); + Gradient gradient = new DefaultGradient(); VAEFwdHelper fwd = doForward(true, true, workspaceMgr); @@ -713,7 +718,7 @@ public Pair backpropGradient(INDArray epsilon, LayerWorkspac gradient.gradientForVariable().put(bKey, dLdB); if(i == 0) { - epsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, new long[]{weights.size(0), currentDelta.size(0)}, 'f'); + epsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, currentDelta.dataType(), new long[]{weights.size(0), currentDelta.size(0)}, 'f'); weights.mmuli(currentDelta.transpose(), epsilon); epsilon = epsilon.transpose(); } else { @@ -767,7 +772,7 @@ private VAEFwdHelper doForward(boolean training, boolean forBackprop, LayerWorks INDArray mW = getParamWithNoise(VariationalAutoencoderParamInitializer.PZX_MEAN_W, training, workspaceMgr); INDArray mB = getParamWithNoise(VariationalAutoencoderParamInitializer.PZX_MEAN_B, training, workspaceMgr); - INDArray pzxMean = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new long[]{current.size(0), mW.size(1)}, 'f'); + INDArray pzxMean = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, mW.dataType(), new long[]{current.size(0), mW.size(1)}, 'f'); pzxMean = current.mmuli(mW, pzxMean).addiRowVector(mB); @@ -937,7 +942,7 @@ public void fit() { */ public INDArray reconstructionProbability(INDArray data, int numSamples) { INDArray reconstructionLogProb = reconstructionLogProbability(data, numSamples); - return Transforms.exp(reconstructionLogProb, false); + return Transforms.exp(reconstructionLogProb.castTo(DataType.DOUBLE), false); //Cast to double to reduce risk of numerical underflow } /** @@ -960,6 +965,8 @@ public INDArray reconstructionLogProbability(INDArray data, int numSamples) { + layerId()); } + data = data.castTo(dataType); + //Forward pass through the encoder and mean for P(Z|X) LayerWorkspaceMgr workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); //TODO add workspace support to this method setInput(data, workspaceMgr); @@ -997,7 +1004,7 @@ public INDArray reconstructionLogProbability(INDArray data, int numSamples) { INDArray sumReconstructionNegLogProbability = null; for (int i = 0; i < numSamples; i++) { - INDArray e = Nd4j.randn(minibatch, size); + INDArray e = Nd4j.randn(dataType, minibatch, size); INDArray z = e.muli(pzxSigma).addi(meanZ); //z = mu + sigma * e, with e ~ N(0,1) //Do forward pass through decoder diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java index 5bbe831f26e2..92b9e47f433d 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java @@ -46,7 +46,6 @@ import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer; import org.deeplearning4j.nn.layers.LayerHelper; import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer; -import 
org.deeplearning4j.nn.updater.MultiLayerUpdater; import org.deeplearning4j.nn.updater.UpdaterCreator; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -63,6 +62,7 @@ import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.ROC; import org.nd4j.evaluation.classification.ROCMultiClass; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.enums.AllocationPolicy; @@ -614,6 +614,19 @@ public void init(INDArray parameters, boolean cloneParametersArray) { if (initCalled) return; + DataType netDtype = getLayerWiseConfigurations().getDataType(); + if(parameters != null && parameters.dataType() != netDtype){ + if(cloneParametersArray){ + try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { + parameters = parameters.castTo(netDtype); + } + } else { + throw new IllegalStateException("Error initializing network: Network datatype is set to " + netDtype + + " but provided array has datatype " + parameters.dataType() + " with cloneParametersArray argument" + + " set to false. Cannot initialize net with specified datatype array if that array does not match network datatype"); + } + } + if (layerMap == null) layerMap = new LinkedHashMap<>(); @@ -665,7 +678,7 @@ public void init(INDArray parameters, boolean cloneParametersArray) { initializeParams = false; } else if(paramLength > 0){ - flattenedParams = Nd4j.create(1, paramLength); + flattenedParams = Nd4j.create(netDtype, 1, paramLength); initializeParams = true; } else { //Edge case: 0 params in network @@ -683,7 +696,7 @@ public void init(INDArray parameters, boolean cloneParametersArray) { for (int i = 0; i < nLayers; i++) { INDArray paramsView; if (nParamsPerLayer[i] > 0) { - paramsView = flattenedParams.get(NDArrayIndex.point(0), + paramsView = flattenedParams.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(paramCountSoFar, paramCountSoFar + nParamsPerLayer[i])); } else { paramsView = null; @@ -691,7 +704,7 @@ public void init(INDArray parameters, boolean cloneParametersArray) { paramCountSoFar += nParamsPerLayer[i]; NeuralNetConfiguration conf = layerWiseConfigurations.getConf(i); - layers[i] = conf.getLayer().instantiate(conf, trainingListeners, i, paramsView, initializeParams); + layers[i] = conf.getLayer().instantiate(conf, trainingListeners, i, paramsView, initializeParams, netDtype); layerMap.put(conf.getLayer().getLayerName(), layers[i]); } initCalled = true; @@ -777,7 +790,7 @@ public void initGradientsView() { for (int i = 0; i < layers.length; i++) { if (nParamsPerLayer[i] == 0) continue; //This layer doesn't have any parameters... 
- INDArray thisLayerGradView = flattenedGradients.get(NDArrayIndex.point(0), + INDArray thisLayerGradView = flattenedGradients.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(paramsSoFar, paramsSoFar + nParamsPerLayer[i])); layers[i].setBackpropGradientsViewArray(thisLayerGradView); paramsSoFar += nParamsPerLayer[i]; @@ -1481,7 +1494,7 @@ public void setParams(INDArray params) { long range = layer.numParams(); if (range <= 0) continue; //Some layers: no parameters (subsampling, etc) - INDArray get = params.get(NDArrayIndex.point(0), NDArrayIndex.interval(idx, range + idx)); + INDArray get = params.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(idx, range + idx)); layer.setParams(get); idx += range; } @@ -1504,7 +1517,7 @@ public void setBackpropGradientsViewArray(INDArray gradients) { for (Layer layer : layers) { if (layer.numParams() == 0) continue; - layer.setBackpropGradientsViewArray(gradients.get(NDArrayIndex.point(0), + layer.setBackpropGradientsViewArray(gradients.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(paramsSoFar, paramsSoFar + layer.numParams()))); paramsSoFar += layer.numParams(); } @@ -1896,7 +1909,7 @@ protected Pair calcBackpropGradients(INDArray epsilon, boole if(currPair.getSecond() != null) { //Edge case: may be null for Embedding layer, for example - validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, numLayers - 1, + validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i, false, "Backprop"); } @@ -3734,6 +3747,37 @@ public ComputationGraph toComputationGraph(){ return NetworkUtils.toComputationGraph(this); } + /** + * Return a copy of the network with the parameters and activations set to use the specified (floating point) data type. + * If the existing datatype is the same as the requested dataype, the original network will be returned unchanged. + * Only floating point datatypes (DOUBLE, FLOAT, HALF) may be used. + * + * @param dataType Datatype to convert the network to + * @return The network, set to use the specified datatype for the parameters and activations + */ + public MultiLayerNetwork convertDataType(@NonNull DataType dataType){ + Preconditions.checkState(dataType.isFPType(), "Invalid DataType: %s. Can only convert network to a floating point type", dataType); + if(dataType == params().dataType()){ + return this; + } + + try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + INDArray newParams = params().castTo(dataType); + String jsonConfig = getLayerWiseConfigurations().toJson(); + MultiLayerConfiguration newConf = MultiLayerConfiguration.fromJson(jsonConfig); + newConf.setDataType(dataType); + MultiLayerNetwork newNet = new MultiLayerNetwork(newConf); + newNet.init(newParams, false); + + Updater u = getUpdater(false); + if(u != null && u.getStateViewArray() != null){ + INDArray oldUpdaterState = u.getStateViewArray(); + newNet.getUpdater(true).getStateViewArray().assign(oldUpdaterState); + } + return newNet; + } + } + /** * Set the learning rate for all layers in the network to the specified value. Note that if any learning rate * schedules are currently present, these will be removed in favor of the new (fixed) learning rate.
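A note on the two substitutions that recur above and again in the parameter-initializer diffs that follow: NDArrayIndex.point(0) is swapped for NDArrayIndex.interval(0, 0, true) so that slicing a flattened 1 x N parameter/gradient view keeps the row dimension, and the Level1 axpy accumulation is replaced by a keep-dims sum plus addi. The snippet below is a minimal illustrative sketch, not code from this patch; sizes and variable names are made up.

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

// Flattened 1 x N view, as handed to the param initializers
INDArray flattened = Nd4j.create(DataType.FLOAT, 1, 12);

// point(0) selects row 0 and can drop that dimension, giving a vector view;
// interval(0, 0, true) keeps the leading dimension, giving the [1, 4] row view
// that the later reshape('f', nIn, nOut)-style calls expect.
INDArray asVector = flattened.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 4));
INDArray asRow = flattened.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(0, 4));

// sum(true, 0) is a keep-dims reduction: a [minibatch, n] delta sums to [1, n], which can be
// accumulated directly into the [1, n] gradient view with addi(), replacing l1BLAS.axpy.
INDArray delta = Nd4j.randn(DataType.FLOAT, 32, 4);
asRow.addi(delta.sum(true, 0));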
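The recurrent-layer hunks also settle on a consistent mask convention: per-time-step mask columns are taken with getColumn(t, true) so the keep-dims column view can be used with muliColumnVector, and masks are cast to the activation dtype before being applied. A small sketch under those assumptions, with hypothetical shapes and variable names:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

int t = 3;                                                  // example time step
INDArray mask = Nd4j.ones(DataType.FLOAT, 8, 10);           // [minibatch, timeSeriesLength]
INDArray delta = Nd4j.randn(DataType.DOUBLE, 8, 16);        // [minibatch, nOut] errors at step t
INDArray maskCol = mask.getColumn(t, true);                 // [minibatch, 1] view; 'true' keeps the dimension
delta.muliColumnVector(maskCol.castTo(delta.dataType()));   // zero out masked examples, dtype-aligned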
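Finally, a rough sketch of how the network-level dtype handling introduced above is meant to be used from calling code. Configuration construction is omitted; 'conf' and 'rawLabels' are assumed to exist, and convertDataType is the method added to MultiLayerNetwork in this diff:

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();                                                      // params are created with the configured network dtype

INDArray labels = rawLabels.castTo(net.params().dataType());     // align external arrays; no-op if types already match

MultiLayerNetwork halfNet = net.convertDataType(DataType.HALF);  // FP16 copy of params and updater state
                                                                 // (returns the same instance if the dtype already matches)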
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/BatchNormalizationParamInitializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/BatchNormalizationParamInitializer.java index 6048f4ea5632..534f99a1de08 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/BatchNormalizationParamInitializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/BatchNormalizationParamInitializer.java @@ -103,8 +103,8 @@ public Map init(NeuralNetConfiguration conf, INDArray paramVie long meanOffset = 0; if (!layer.isLockGammaBeta()) { //No gamma/beta parameters when gamma/beta are locked - INDArray gammaView = paramView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nOut)); - INDArray betaView = paramView.get(NDArrayIndex.point(0), NDArrayIndex.interval(nOut, 2 * nOut)); + INDArray gammaView = paramView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nOut)); + INDArray betaView = paramView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nOut, 2 * nOut)); params.put(GAMMA, createGamma(conf, gammaView, initializeParams)); conf.addVariable(GAMMA); @@ -115,8 +115,8 @@ public Map init(NeuralNetConfiguration conf, INDArray paramVie } INDArray globalMeanView = - paramView.get(NDArrayIndex.point(0), NDArrayIndex.interval(meanOffset, meanOffset + nOut)); - INDArray globalVarView = paramView.get(NDArrayIndex.point(0), + paramView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(meanOffset, meanOffset + nOut)); + INDArray globalVarView = paramView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(meanOffset + nOut, meanOffset + 2 * nOut)); if (initializeParams) { @@ -151,20 +151,20 @@ public Map getGradientsFromFlattened(NeuralNetConfiguration co Map out = new LinkedHashMap<>(); long meanOffset = 0; if (!layer.isLockGammaBeta()) { - INDArray gammaView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nOut)); - INDArray betaView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(nOut, 2 * nOut)); + INDArray gammaView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nOut)); + INDArray betaView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nOut, 2 * nOut)); out.put(GAMMA, gammaView); out.put(BETA, betaView); meanOffset = 2 * nOut; } out.put(GLOBAL_MEAN, - gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(meanOffset, meanOffset + nOut))); + gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(meanOffset, meanOffset + nOut))); if(layer.isUseLogStd()){ - out.put(GLOBAL_LOG_STD, gradientView.get(NDArrayIndex.point(0), + out.put(GLOBAL_LOG_STD, gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(meanOffset + nOut, meanOffset + 2 * nOut))); } else { - out.put(GLOBAL_VAR, gradientView.get(NDArrayIndex.point(0), + out.put(GLOBAL_VAR, gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(meanOffset + nOut, meanOffset + 2 * nOut))); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/BidirectionalParamInitializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/BidirectionalParamInitializer.java index 6b175e4c918c..c52010227cba 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/BidirectionalParamInitializer.java +++ 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/BidirectionalParamInitializer.java @@ -108,8 +108,8 @@ public boolean isBiasParam(Layer layer, String key) { @Override public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { val n = paramsView.length()/2; - INDArray forwardView = paramsView.get(point(0), interval(0, n)); - INDArray backwardView = paramsView.get(point(0), interval(n, 2*n)); + INDArray forwardView = paramsView.get(interval(0,0,true), interval(0, n)); + INDArray backwardView = paramsView.get(interval(0,0,true), interval(n, 2*n)); conf.clearVariables(); @@ -159,8 +159,8 @@ private List addPrefixes(List fwd, List bwd){ @Override public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { val n = gradientView.length()/2; - INDArray forwardView = gradientView.get(point(0), interval(0, n)); - INDArray backwardView = gradientView.get(point(0), interval(n, 2*n)); + INDArray forwardView = gradientView.get(interval(0,0,true), interval(0, n)); + INDArray backwardView = gradientView.get(interval(0,0,true), interval(n, 2*n)); Map origFwd = underlying.initializer().getGradientsFromFlattened(conf, forwardView); Map origBwd = underlying.initializer().getGradientsFromFlattened(conf, backwardView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/CenterLossParamInitializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/CenterLossParamInitializer.java index 461dfb7a7e2f..eabb4f85dac1 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/CenterLossParamInitializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/CenterLossParamInitializer.java @@ -67,9 +67,9 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi val bEndOffset = wEndOffset + nOut; val cEndOffset = bEndOffset + nIn * nOut; - INDArray weightView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, wEndOffset)); - INDArray biasView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(wEndOffset, bEndOffset)); - INDArray centerLossView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(bEndOffset, cEndOffset)) + INDArray weightView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, wEndOffset)); + INDArray biasView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(wEndOffset, bEndOffset)); + INDArray centerLossView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(bEndOffset, cEndOffset)) .reshape('c', nOut, nIn); params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams)); @@ -94,10 +94,10 @@ public Map getGradientsFromFlattened(NeuralNetConfiguration co val bEndOffset = wEndOffset + nOut; val cEndOffset = bEndOffset + nIn * nOut; // note: numClasses == nOut - INDArray weightGradientView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, wEndOffset)) + INDArray weightGradientView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, wEndOffset)) .reshape('f', nIn, nOut); - INDArray biasView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(wEndOffset, bEndOffset)); //Already a row vector - INDArray centerLossView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(bEndOffset, cEndOffset)) + INDArray biasView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(wEndOffset, 
bEndOffset)); //Already a row vector + INDArray centerLossView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(bEndOffset, cEndOffset)) .reshape('c', nOut, nIn); Map out = new LinkedHashMap<>(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/Convolution3DParamInitializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/Convolution3DParamInitializer.java index a1a324844bcc..054927d3157d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/Convolution3DParamInitializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/Convolution3DParamInitializer.java @@ -73,8 +73,8 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi val nOut = layerConf.getNOut(); if (layer.hasBias()) { - INDArray biasView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nOut)); - INDArray weightView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(nOut, numParams(conf))); + INDArray biasView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nOut)); + INDArray weightView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nOut, numParams(conf))); params.put(BIAS_KEY, createBias(conf, biasView, initializeParams)); params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams)); conf.addVariable(WEIGHT_KEY); @@ -99,9 +99,9 @@ public Map getGradientsFromFlattened(NeuralNetConfiguration co Map out = new LinkedHashMap<>(); if (layerConf.hasBias()) { - INDArray biasGradientView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nOut)); + INDArray biasGradientView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nOut)); INDArray weightGradientView = - gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(nOut, numParams(conf))) + gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nOut, numParams(conf))) .reshape('c', nOut, nIn, kernel[0], kernel[1], kernel[2]); out.put(BIAS_KEY, biasGradientView); out.put(WEIGHT_KEY, weightGradientView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/ConvolutionParamInitializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/ConvolutionParamInitializer.java index 38a0e2af64ff..a133fc8431a8 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/ConvolutionParamInitializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/ConvolutionParamInitializer.java @@ -112,8 +112,8 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi if(layer.hasBias()){ //Standard case - INDArray biasView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nOut)); - INDArray weightView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(nOut, numParams(conf))); + INDArray biasView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nOut)); + INDArray weightView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nOut, numParams(conf))); params.put(BIAS_KEY, createBias(conf, biasView, initializeParams)); params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams)); conf.addVariable(WEIGHT_KEY); @@ -140,9 +140,9 @@ public Map getGradientsFromFlattened(NeuralNetConfiguration co Map out = new LinkedHashMap<>(); if(layerConf.hasBias()){ 
//Standard case - INDArray biasGradientView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nOut)); + INDArray biasGradientView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nOut)); INDArray weightGradientView = - gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(nOut, numParams(conf))) + gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nOut, numParams(conf))) .reshape('c', nOut, nIn, kernel[0], kernel[1]); out.put(BIAS_KEY, biasGradientView); out.put(WEIGHT_KEY, weightGradientView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/DeconvolutionParamInitializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/DeconvolutionParamInitializer.java index 4fe2439e2328..67ae7d37c4a6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/DeconvolutionParamInitializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/DeconvolutionParamInitializer.java @@ -83,9 +83,9 @@ public Map getGradientsFromFlattened(NeuralNetConfiguration co Map out = new LinkedHashMap<>(); if(layerConf.hasBias()){ - INDArray biasGradientView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nOut)); + INDArray biasGradientView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nOut)); INDArray weightGradientView = - gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(nOut, numParams(conf))) + gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nOut, numParams(conf))) .reshape('c', nIn, nOut, kernel[0], kernel[1]); out.put(BIAS_KEY, biasGradientView); out.put(WEIGHT_KEY, weightGradientView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/DefaultParamInitializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/DefaultParamInitializer.java index a5449cf7777f..f309927174f6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/DefaultParamInitializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/DefaultParamInitializer.java @@ -110,14 +110,14 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi val nOut = layerConf.getNOut(); val nWeightParams = nIn * nOut; - INDArray weightView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nWeightParams)); + INDArray weightView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nWeightParams)); params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams)); conf.addVariable(WEIGHT_KEY); long offset = nWeightParams; if(hasBias(layerConf)){ - INDArray biasView = paramsView.get(NDArrayIndex.point(0), + INDArray biasView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(offset, offset + nOut)); params.put(BIAS_KEY, createBias(conf, biasView, initializeParams)); conf.addVariable(BIAS_KEY); @@ -125,7 +125,7 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi } if(hasLayerNorm(layerConf)){ - INDArray gainView = paramsView.get(NDArrayIndex.point(0), + INDArray gainView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(offset, offset + nOut)); params.put(GAIN_KEY, createGain(conf, gainView, initializeParams)); conf.addVariable(GAIN_KEY); @@ -142,7 +142,7 @@ public Map getGradientsFromFlattened(NeuralNetConfiguration co val nOut = 
layerConf.getNOut(); val nWeightParams = nIn * nOut; - INDArray weightGradientView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nWeightParams)) + INDArray weightGradientView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nWeightParams)) .reshape('f', nIn, nOut); Map out = new LinkedHashMap<>(); @@ -150,14 +150,14 @@ public Map getGradientsFromFlattened(NeuralNetConfiguration co long offset = nWeightParams; if(hasBias(layerConf)){ - INDArray biasView = gradientView.get(NDArrayIndex.point(0), + INDArray biasView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(offset, offset + nOut)); //Already a row vector out.put(BIAS_KEY, biasView); offset += nOut; } if(hasLayerNorm(layerConf)){ - INDArray gainView = gradientView.get(NDArrayIndex.point(0), + INDArray gainView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(offset, offset + nOut)); //Already a row vector out.put(GAIN_KEY, gainView); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/DepthwiseConvolutionParamInitializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/DepthwiseConvolutionParamInitializer.java index e5dde2cc6a29..220f591b3d62 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/DepthwiseConvolutionParamInitializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/DepthwiseConvolutionParamInitializer.java @@ -129,13 +129,13 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi val biasParams = numBiasParams(layerConf); INDArray depthWiseWeightView = paramsView.get( - NDArrayIndex.point(0), NDArrayIndex.interval(biasParams, biasParams + depthWiseParams)); + NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(biasParams, biasParams + depthWiseParams)); params.put(WEIGHT_KEY, createDepthWiseWeightMatrix(conf, depthWiseWeightView, initializeParams)); conf.addVariable(WEIGHT_KEY); if(layer.hasBias()){ - INDArray biasView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, biasParams)); + INDArray biasView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, biasParams)); params.put(BIAS_KEY, createBias(conf, biasView, initializeParams)); conf.addVariable(BIAS_KEY); } @@ -159,12 +159,12 @@ public Map getGradientsFromFlattened(NeuralNetConfiguration co val biasParams = numBiasParams(layerConf); INDArray depthWiseWeightGradientView = gradientView.get( - NDArrayIndex.point(0), NDArrayIndex.interval(biasParams, biasParams + depthWiseParams)) + NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(biasParams, biasParams + depthWiseParams)) .reshape('c', kernel[0], kernel[1], nIn, depthMultiplier); out.put(WEIGHT_KEY, depthWiseWeightGradientView); if(layerConf.hasBias()){ - INDArray biasGradientView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nOut)); + INDArray biasGradientView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nOut)); out.put(BIAS_KEY, biasGradientView); } return out; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/ElementWiseParamInitializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/ElementWiseParamInitializer.java index 11a8d3c9fea3..3158b36dcbe1 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/ElementWiseParamInitializer.java +++ 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/ElementWiseParamInitializer.java @@ -72,8 +72,8 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi val nIn = layerConf.getNIn(); val nWeightParams = nIn ; - INDArray weightView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nWeightParams)); - INDArray biasView = paramsView.get(NDArrayIndex.point(0), + INDArray weightView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nWeightParams)); + INDArray biasView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nWeightParams, nWeightParams + nIn)); @@ -102,8 +102,8 @@ public Map getGradientsFromFlattened(NeuralNetConfiguration co val nOut = layerConf.getNOut(); val nWeightParams = nIn ; - INDArray weightGradientView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nWeightParams)); - INDArray biasView = gradientView.get(NDArrayIndex.point(0), + INDArray weightGradientView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nWeightParams)); + INDArray biasView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nWeightParams, nWeightParams + nOut)); //Already a row vector Map out = new LinkedHashMap<>(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/GravesBidirectionalLSTMParamInitializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/GravesBidirectionalLSTMParamInitializer.java index 2c3022a1e1d7..fa379b93751a 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/GravesBidirectionalLSTMParamInitializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/GravesBidirectionalLSTMParamInitializer.java @@ -135,17 +135,17 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi val rwROffset = iwROffset + nParamsInput; val bROffset = rwROffset + nParamsRecurrent; - INDArray iwF = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, rwFOffset)); - INDArray rwF = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(rwFOffset, bFOffset)); - INDArray bF = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(bFOffset, iwROffset)); - INDArray iwR = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(iwROffset, rwROffset)); - INDArray rwR = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(rwROffset, bROffset)); - INDArray bR = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(bROffset, bROffset + nBias)); + INDArray iwF = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, rwFOffset)); + INDArray rwF = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(rwFOffset, bFOffset)); + INDArray bF = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(bFOffset, iwROffset)); + INDArray iwR = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(iwROffset, rwROffset)); + INDArray rwR = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(rwROffset, bROffset)); + INDArray bR = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(bROffset, bROffset + nBias)); if (initializeParams) { - bF.put(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(nL, 2 * nL)}, + bF.put(new INDArrayIndex[]{NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nL, 2 * nL)}, Nd4j.ones(1, nL).muli(forgetGateInit)); //Order: input, forget, output, input 
modulation, i.e., IFOG - bR.put(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(nL, 2 * nL)}, + bR.put(new INDArrayIndex[]{NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nL, 2 * nL)}, Nd4j.ones(1, nL).muli(forgetGateInit)); } /*The above line initializes the forget gate biases to specified value. @@ -205,16 +205,16 @@ public Map getGradientsFromFlattened(NeuralNetConfiguration co val rwROffset = iwROffset + nParamsInput; val bROffset = rwROffset + nParamsRecurrent; - INDArray iwFG = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, rwFOffset)).reshape('f', nLast, + INDArray iwFG = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, rwFOffset)).reshape('f', nLast, 4 * nL); - INDArray rwFG = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(rwFOffset, bFOffset)).reshape('f', + INDArray rwFG = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(rwFOffset, bFOffset)).reshape('f', nL, 4 * nL + 3); - INDArray bFG = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(bFOffset, iwROffset)); - INDArray iwRG = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(iwROffset, rwROffset)) + INDArray bFG = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(bFOffset, iwROffset)); + INDArray iwRG = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(iwROffset, rwROffset)) .reshape('f', nLast, 4 * nL); - INDArray rwRG = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(rwROffset, bROffset)).reshape('f', + INDArray rwRG = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(rwROffset, bROffset)).reshape('f', nL, 4 * nL + 3); - INDArray bRG = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(bROffset, bROffset + nBias)); + INDArray bRG = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(bROffset, bROffset + nBias)); Map out = new LinkedHashMap<>(); out.put(INPUT_WEIGHT_KEY_FORWARDS, iwFG); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/GravesLSTMParamInitializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/GravesLSTMParamInitializer.java index d9d513b17433..46f6af7f63a1 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/GravesLSTMParamInitializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/GravesLSTMParamInitializer.java @@ -112,10 +112,10 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi val nParamsIn = nLast * (4 * nL); val nParamsRecurrent = nL * (4 * nL + 3); val nBias = 4 * nL; - INDArray inputWeightView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nParamsIn)); - INDArray recurrentWeightView = paramsView.get(NDArrayIndex.point(0), + INDArray inputWeightView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nParamsIn)); + INDArray recurrentWeightView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nParamsIn, nParamsIn + nParamsRecurrent)); - INDArray biasView = paramsView.get(NDArrayIndex.point(0), + INDArray biasView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nParamsIn + nParamsRecurrent, nParamsIn + nParamsRecurrent + nBias)); if (initializeParams) { @@ -135,7 +135,7 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, inputWeightView)); 
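
Aside (not part of the patch): the recurring swap of NDArrayIndex.point(0) for NDArrayIndex.interval(0, 0, true) in these parameter initializers is about preserving the rank of the flattened parameter view. A minimal sketch, with made-up sizes and assuming the usual Nd4j imports, of the difference under the newer rank-preserving indexing rules:

    // assumes: import org.nd4j.linalg.api.ndarray.INDArray;
    //          import org.nd4j.linalg.factory.Nd4j;
    //          import org.nd4j.linalg.indexing.NDArrayIndex;
    INDArray paramsView = Nd4j.zeros(1, 12);                        // stand-in for a flattened [1, numParams] view
    INDArray collapsed = paramsView.get(NDArrayIndex.point(0),
            NDArrayIndex.interval(0, 4));                           // point(0) drops the leading axis -> rank 1, shape [4]
    INDArray rowView = paramsView.get(NDArrayIndex.interval(0, 0, true),
            NDArrayIndex.interval(0, 4));                           // inclusive interval keeps it -> rank 2, shape [1, 4]

The [1, n] form is what the later reshape('f', ...) calls and the "already a row vector" comments assume.
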
params.put(RECURRENT_WEIGHT_KEY, rwInit.init(fanIn, fanOut, recurrentWShape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, recurrentWeightView)); - biasView.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(nL, 2 * nL)}, + biasView.put(new INDArrayIndex[] {NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nL, 2 * nL)}, Nd4j.valueArrayOf(new long[]{1, nL}, forgetGateInit)); //Order: input, forget, output, input modulation, i.e., IFOG} /*The above line initializes the forget gate biases to specified value. * See Sutskever PhD thesis, pg19: @@ -172,12 +172,12 @@ public Map getGradientsFromFlattened(NeuralNetConfiguration co val nParamsIn = nLast * (4 * nL); val nParamsRecurrent = nL * (4 * nL + 3); val nBias = 4 * nL; - INDArray inputWeightGradView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nParamsIn)) + INDArray inputWeightGradView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nParamsIn)) .reshape('f', nLast, 4 * nL); INDArray recurrentWeightGradView = gradientView - .get(NDArrayIndex.point(0), NDArrayIndex.interval(nParamsIn, nParamsIn + nParamsRecurrent)) + .get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nParamsIn, nParamsIn + nParamsRecurrent)) .reshape('f', nL, 4 * nL + 3); - INDArray biasGradView = gradientView.get(NDArrayIndex.point(0), + INDArray biasGradView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nParamsIn + nParamsRecurrent, nParamsIn + nParamsRecurrent + nBias)); //already a row vector Map out = new LinkedHashMap<>(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/LSTMParamInitializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/LSTMParamInitializer.java index 1452e3c5a4ea..327596fb9a10 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/LSTMParamInitializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/LSTMParamInitializer.java @@ -118,10 +118,10 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi val nParamsIn = nLast * (4 * nL); val nParamsRecurrent = nL * (4 * nL); val nBias = 4 * nL; - INDArray inputWeightView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nParamsIn)); - INDArray recurrentWeightView = paramsView.get(NDArrayIndex.point(0), + INDArray inputWeightView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nParamsIn)); + INDArray recurrentWeightView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nParamsIn, nParamsIn + nParamsRecurrent)); - INDArray biasView = paramsView.get(NDArrayIndex.point(0), + INDArray biasView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nParamsIn + nParamsRecurrent, nParamsIn + nParamsRecurrent + nBias)); if (initializeParams) { @@ -140,7 +140,7 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi params.put(INPUT_WEIGHT_KEY, layerConf.getWeightInitFn().init(fanIn, fanOut, inputWShape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, inputWeightView)); params.put(RECURRENT_WEIGHT_KEY, rwInit.init(fanIn, fanOut, recurrentWShape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, recurrentWeightView)); - biasView.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(nL, 2 * nL)}, + biasView.put(new INDArrayIndex[] {NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nL, 2 * nL)}, Nd4j.valueArrayOf(new long[]{1, nL}, forgetGateInit)); //Order: input, forget, 
output, input modulation, i.e., IFOG} /*The above line initializes the forget gate biases to specified value. * See Sutskever PhD thesis, pg19: @@ -176,12 +176,12 @@ public Map getGradientsFromFlattened(NeuralNetConfiguration co val nParamsIn = nLast * (4 * nL); val nParamsRecurrent = nL * (4 * nL); val nBias = 4 * nL; - INDArray inputWeightGradView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nParamsIn)) + INDArray inputWeightGradView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nParamsIn)) .reshape('f', nLast, 4 * nL); INDArray recurrentWeightGradView = gradientView - .get(NDArrayIndex.point(0), NDArrayIndex.interval(nParamsIn, nParamsIn + nParamsRecurrent)) + .get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nParamsIn, nParamsIn + nParamsRecurrent)) .reshape('f', nL, 4 * nL); - INDArray biasGradView = gradientView.get(NDArrayIndex.point(0), + INDArray biasGradView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nParamsIn + nParamsRecurrent, nParamsIn + nParamsRecurrent + nBias)); //already a row vector Map out = new LinkedHashMap<>(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/PReLUParamInitializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/PReLUParamInitializer.java index 6513a2b767d4..8927fbe0ac7c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/PReLUParamInitializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/PReLUParamInitializer.java @@ -116,7 +116,7 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi throw new IllegalStateException( "Expected params view of length " + length + ", got length " + paramsView.length()); - INDArray weightView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length)); + INDArray weightView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, length)); params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams)); conf.addVariable(WEIGHT_KEY); @@ -128,7 +128,7 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { val length = numParams(conf); - INDArray weightGradientView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length)) + INDArray weightGradientView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, length)) .reshape('f', weightShape); Map out = new LinkedHashMap<>(); out.put(WEIGHT_KEY, weightGradientView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/PretrainParamInitializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/PretrainParamInitializer.java index 63e0618c02a6..7391da93d431 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/PretrainParamInitializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/PretrainParamInitializer.java @@ -57,7 +57,7 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi val nOut = layerConf.getNOut(); val nWeightParams = nIn * nOut; - INDArray visibleBiasView = paramsView.get(NDArrayIndex.point(0), + INDArray visibleBiasView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nWeightParams + nOut, nWeightParams + nOut + nIn)); params.put(VISIBLE_BIAS_KEY, 
createVisibleBias(conf, visibleBiasView, initializeParams)); conf.addVariable(VISIBLE_BIAS_KEY); @@ -87,7 +87,7 @@ public Map getGradientsFromFlattened(NeuralNetConfiguration co val nOut = layerConf.getNOut(); val nWeightParams = nIn * nOut; - INDArray vBiasView = gradientView.get(NDArrayIndex.point(0), + INDArray vBiasView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nWeightParams + nOut, nWeightParams + nOut + nIn)); out.put(VISIBLE_BIAS_KEY, vBiasView); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/SameDiffParamInitializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/SameDiffParamInitializer.java index 7e2937e25fed..ca9c10c80bee 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/SameDiffParamInitializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/SameDiffParamInitializer.java @@ -132,7 +132,7 @@ public Map subsetAndReshape(List params, Map 0 - parameter array shape: " + Arrays.toString(sh)); } - INDArray sub = view.get(point(0), interval(soFar, soFar + length)); + INDArray sub = view.get(interval(0,0,true), interval(soFar, soFar + length)); if(!Arrays.equals(sub.shape(), sh)){ char order = (sdl != null ? sdl.paramReshapeOrder(s) : sdv.paramReshapeOrder(s)); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/SeparableConvolutionParamInitializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/SeparableConvolutionParamInitializer.java index 80495d779b5d..796bf29d7251 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/SeparableConvolutionParamInitializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/SeparableConvolutionParamInitializer.java @@ -146,9 +146,9 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi val biasParams = numBiasParams(layerConf); INDArray depthWiseWeightView = paramsView.get( - NDArrayIndex.point(0), NDArrayIndex.interval(biasParams, biasParams + depthWiseParams)); + NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(biasParams, biasParams + depthWiseParams)); INDArray pointWiseWeightView = paramsView.get( - NDArrayIndex.point(0), NDArrayIndex.interval(biasParams + depthWiseParams, numParams(conf))); + NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(biasParams + depthWiseParams, numParams(conf))); params.put(DEPTH_WISE_WEIGHT_KEY, createDepthWiseWeightMatrix(conf, depthWiseWeightView, initializeParams)); conf.addVariable(DEPTH_WISE_WEIGHT_KEY); @@ -156,7 +156,7 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi conf.addVariable(POINT_WISE_WEIGHT_KEY); if(layer.hasBias()){ - INDArray biasView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, biasParams)); + INDArray biasView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, biasParams)); params.put(BIAS_KEY, createBias(conf, biasView, initializeParams)); conf.addVariable(BIAS_KEY); } @@ -181,16 +181,16 @@ public Map getGradientsFromFlattened(NeuralNetConfiguration co val biasParams = numBiasParams(layerConf); INDArray depthWiseWeightGradientView = gradientView.get( - NDArrayIndex.point(0), NDArrayIndex.interval(biasParams, biasParams + depthWiseParams)) + NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(biasParams, biasParams + depthWiseParams)) .reshape('c', depthMultiplier, nIn, kernel[0], 
kernel[1]); INDArray pointWiseWeightGradientView = gradientView.get( - NDArrayIndex.point(0), NDArrayIndex.interval(biasParams + depthWiseParams, numParams(conf))) + NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(biasParams + depthWiseParams, numParams(conf))) .reshape('c', nOut, nIn * depthMultiplier, 1, 1); out.put(DEPTH_WISE_WEIGHT_KEY, depthWiseWeightGradientView); out.put(POINT_WISE_WEIGHT_KEY, pointWiseWeightGradientView); if(layerConf.hasBias()){ - INDArray biasGradientView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nOut)); + INDArray biasGradientView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nOut)); out.put(BIAS_KEY, biasGradientView); } return out; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/SimpleRnnParamInitializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/SimpleRnnParamInitializer.java index 8e2463bdfc61..9f0ab62d3ed1 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/SimpleRnnParamInitializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/SimpleRnnParamInitializer.java @@ -146,10 +146,10 @@ public Map getGradientsFromFlattened(NeuralNetConfiguration co private static Map getSubsets(INDArray in, long nIn, long nOut, boolean reshape, boolean hasLayerNorm){ long pos = nIn * nOut; - INDArray w = in.get(point(0), interval(0, pos)); - INDArray rw = in.get(point(0), interval(pos, pos + nOut * nOut)); + INDArray w = in.get(interval(0,0,true), interval(0, pos)); + INDArray rw = in.get(interval(0,0,true), interval(pos, pos + nOut * nOut)); pos += nOut * nOut; - INDArray b = in.get(point(0), interval(pos, pos + nOut)); + INDArray b = in.get(interval(0,0,true), interval(pos, pos + nOut)); if(reshape){ w = w.reshape('f', nIn, nOut); @@ -162,7 +162,7 @@ private static Map getSubsets(INDArray in, long nIn, long nOut, m.put(BIAS_KEY, b); if(hasLayerNorm){ pos += nOut; - INDArray g = in.get(point(0), interval(pos, pos + 2 * nOut)); + INDArray g = in.get(interval(0,0,true), interval(pos, pos + 2 * nOut)); m.put(GAIN_KEY, g); } return m; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/VariationalAutoencoderParamInitializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/VariationalAutoencoderParamInitializer.java index f1562b7fdedc..f30ec84ae44f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/VariationalAutoencoderParamInitializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/VariationalAutoencoderParamInitializer.java @@ -212,10 +212,10 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi encoderLayerNIn = encoderLayerSizes[i - 1]; } val weightParamCount = encoderLayerNIn * encoderLayerSizes[i]; - INDArray weightView = paramsView.get(NDArrayIndex.point(0), + INDArray weightView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + weightParamCount)); soFar += weightParamCount; - INDArray biasView = paramsView.get(NDArrayIndex.point(0), + INDArray biasView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + encoderLayerSizes[i])); soFar += encoderLayerSizes[i]; @@ -235,9 +235,9 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi //Last encoder layer -> p(z|x) val nWeightsPzx = encoderLayerSizes[encoderLayerSizes.length - 
1] * nOut; INDArray pzxWeightsMean = - paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + nWeightsPzx)); + paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + nWeightsPzx)); soFar += nWeightsPzx; - INDArray pzxBiasMean = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + nOut)); + INDArray pzxBiasMean = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + nOut)); soFar += nOut; INDArray pzxWeightsMeanReshaped = createWeightMatrix(encoderLayerSizes[encoderLayerSizes.length - 1], nOut, @@ -252,9 +252,9 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi //Pretrain params INDArray pzxWeightsLogStdev2 = - paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + nWeightsPzx)); + paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + nWeightsPzx)); soFar += nWeightsPzx; - INDArray pzxBiasLogStdev2 = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + nOut)); + INDArray pzxBiasLogStdev2 = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + nOut)); soFar += nOut; INDArray pzxWeightsLogStdev2Reshaped = createWeightMatrix(encoderLayerSizes[encoderLayerSizes.length - 1], nOut, @@ -274,10 +274,10 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi decoderLayerNIn = decoderLayerSizes[i - 1]; } val weightParamCount = decoderLayerNIn * decoderLayerSizes[i]; - INDArray weightView = paramsView.get(NDArrayIndex.point(0), + INDArray weightView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + weightParamCount)); soFar += weightParamCount; - INDArray biasView = paramsView.get(NDArrayIndex.point(0), + INDArray biasView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + decoderLayerSizes[i])); soFar += decoderLayerSizes[i]; @@ -298,9 +298,9 @@ public Map init(NeuralNetConfiguration conf, INDArray paramsVi int nDistributionParams = layer.getOutputDistribution().distributionInputSize((int) nIn); int pxzWeightCount = decoderLayerSizes[decoderLayerSizes.length - 1] * nDistributionParams; INDArray pxzWeightView = - paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + pxzWeightCount)); + paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + pxzWeightCount)); soFar += pxzWeightCount; - INDArray pxzBiasView = paramsView.get(NDArrayIndex.point(0), + INDArray pxzBiasView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + nDistributionParams)); INDArray pxzWeightsReshaped = createWeightMatrix(decoderLayerSizes[decoderLayerSizes.length - 1], @@ -334,10 +334,10 @@ public Map getGradientsFromFlattened(NeuralNetConfiguration co encoderLayerNIn = encoderLayerSizes[i - 1]; } val weightParamCount = encoderLayerNIn * encoderLayerSizes[i]; - INDArray weightGradView = gradientView.get(NDArrayIndex.point(0), + INDArray weightGradView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + weightParamCount)); soFar += weightParamCount; - INDArray biasGradView = gradientView.get(NDArrayIndex.point(0), + INDArray biasGradView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + encoderLayerSizes[i])); soFar += encoderLayerSizes[i]; @@ -351,9 +351,9 @@ public Map getGradientsFromFlattened(NeuralNetConfiguration co //Last encoder layer -> p(z|x) val nWeightsPzx = 
encoderLayerSizes[encoderLayerSizes.length - 1] * nOut; INDArray pzxWeightsMean = - gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + nWeightsPzx)); + gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + nWeightsPzx)); soFar += nWeightsPzx; - INDArray pzxBiasMean = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + nOut)); + INDArray pzxBiasMean = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + nOut)); soFar += nOut; INDArray pzxWeightGradMeanReshaped = @@ -365,9 +365,9 @@ public Map getGradientsFromFlattened(NeuralNetConfiguration co //////////////////////////////////////////////////////// INDArray pzxWeightsLogStdev2 = - gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + nWeightsPzx)); + gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + nWeightsPzx)); soFar += nWeightsPzx; - INDArray pzxBiasLogStdev2 = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + nOut)); + INDArray pzxBiasLogStdev2 = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + nOut)); soFar += nOut; INDArray pzxWeightsLogStdev2Reshaped = createWeightMatrix(encoderLayerSizes[encoderLayerSizes.length - 1], nOut, @@ -384,10 +384,10 @@ public Map getGradientsFromFlattened(NeuralNetConfiguration co decoderLayerNIn = decoderLayerSizes[i - 1]; } long weightParamCount = decoderLayerNIn * decoderLayerSizes[i]; - INDArray weightView = gradientView.get(NDArrayIndex.point(0), + INDArray weightView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + weightParamCount)); soFar += weightParamCount; - INDArray biasView = gradientView.get(NDArrayIndex.point(0), + INDArray biasView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + decoderLayerSizes[i])); soFar += decoderLayerSizes[i]; @@ -406,9 +406,9 @@ public Map getGradientsFromFlattened(NeuralNetConfiguration co int nDistributionParams = layer.getOutputDistribution().distributionInputSize((int) nIn); int pxzWeightCount = decoderLayerSizes[decoderLayerSizes.length - 1] * nDistributionParams; INDArray pxzWeightView = - gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + pxzWeightCount)); + gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + pxzWeightCount)); soFar += pxzWeightCount; - INDArray pxzBiasView = gradientView.get(NDArrayIndex.point(0), + INDArray pxzBiasView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + nDistributionParams)); INDArray pxzWeightsReshaped = createWeightMatrix(decoderLayerSizes[decoderLayerSizes.length - 1], diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java index 029700d112b2..673010af2518 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java @@ -35,6 +35,7 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInitDistribution; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import 
org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; @@ -75,6 +76,7 @@ public static class Builder { private InputType inputType; private Boolean validateOutputLayerConfig; + private DataType dataType; /** * Multilayer Network to tweak for transfer learning @@ -83,6 +85,7 @@ public static class Builder { public Builder(MultiLayerNetwork origModel) { this.origModel = origModel; this.origConf = origModel.getLayerWiseConfigurations().clone(); + this.dataType = origModel.getLayerWiseConfigurations().getDataType(); this.inputPreProcessors = origConf.getInputPreProcessors(); } @@ -321,8 +324,8 @@ public Builder addLayer(Layer layer) { val numParams = layer.initializer().numParams(layerConf); INDArray params; if (numParams > 0) { - params = Nd4j.create(1, numParams); - org.deeplearning4j.nn.api.Layer someLayer = layer.instantiate(layerConf, null, 0, params, true); + params = Nd4j.create(origModel.getLayerWiseConfigurations().getDataType(), 1, numParams); + org.deeplearning4j.nn.api.Layer someLayer = layer.instantiate(layerConf, null, 0, params, true, dataType); appendParams.add(someLayer.params()); appendConfs.add(someLayer.conf()); } else { @@ -469,8 +472,8 @@ private void nInReplaceBuild(int layerNum, int nIn, IWeightInit init) { layerImplF.setWeightInitFn(init); layerImplF.setNIn(nIn); long numParams = layerImpl.initializer().numParams(layerConf); - INDArray params = Nd4j.create(1, numParams); - org.deeplearning4j.nn.api.Layer someLayer = layerImpl.instantiate(layerConf, null, 0, params, true); + INDArray params = Nd4j.create(origModel.getLayerWiseConfigurations().getDataType(), 1, numParams); + org.deeplearning4j.nn.api.Layer someLayer = layerImpl.instantiate(layerConf, null, 0, params, true, dataType); editedParams.set(layerNum, someLayer.params()); } @@ -487,8 +490,8 @@ private void nOutReplaceBuild(int layerNum, int nOut, IWeightInit scheme, IWeigh layerImplF.setWeightInitFn(scheme); layerImplF.setNOut(nOut); long numParams = layerImpl.initializer().numParams(layerConf); - INDArray params = Nd4j.create(1, numParams); - org.deeplearning4j.nn.api.Layer someLayer = layerImpl.instantiate(layerConf, null, 0, params, true); + INDArray params = Nd4j.create(origModel.getLayerWiseConfigurations().getDataType(), 1, numParams); + org.deeplearning4j.nn.api.Layer someLayer = layerImpl.instantiate(layerConf, null, 0, params, true, dataType); editedParams.set(layerNum, someLayer.params()); if (layerNum + 1 < editedConfs.size()) { @@ -500,8 +503,8 @@ private void nOutReplaceBuild(int layerNum, int nOut, IWeightInit scheme, IWeigh layerImplF.setNIn(nOut); numParams = layerImpl.initializer().numParams(layerConf); if (numParams > 0) { - params = Nd4j.create(1, numParams); - someLayer = layerImpl.instantiate(layerConf, null, 0, params, true); + params = Nd4j.create(origModel.getLayerWiseConfigurations().getDataType(), 1, numParams); + someLayer = layerImpl.instantiate(layerConf, null, 0, params, true, dataType); editedParams.set(layerNum + 1, someLayer.params()); } } @@ -545,6 +548,7 @@ private MultiLayerConfiguration constructConf() { MultiLayerConfiguration conf = new MultiLayerConfiguration.Builder().inputPreProcessors(inputPreProcessors) .setInputType(this.inputType).confs(allConfs) .validateOutputLayerConfig(validateOutputLayerConfig == null ? 
true : validateOutputLayerConfig) + .dataType(origConf.getDataType()) .build(); if (finetuneConfiguration != null) { finetuneConfiguration.applyToMultiLayerConfiguration(conf); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelper.java index 65fb3d958375..28ce92a01507 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelper.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelper.java @@ -299,7 +299,9 @@ private void initHelperMLN() { unFrozenSubsetMLN = new MultiLayerNetwork(new MultiLayerConfiguration.Builder() .inputPreProcessors(c.getInputPreProcessors()) .backpropType(c.getBackpropType()).tBPTTForwardLength(c.getTbpttFwdLength()) - .tBPTTBackwardLength(c.getTbpttBackLength()).confs(allConfs).build()); + .tBPTTBackwardLength(c.getTbpttBackLength()).confs(allConfs) + .dataType(origMLN.getLayerWiseConfigurations().getDataType()) + .build()); unFrozenSubsetMLN.init(); //copy over params for (int i = frozenInputLayer + 1; i < origMLN.getnLayers(); i++) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java index 89918d1c9f57..10195b59d6dc 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java @@ -101,9 +101,9 @@ public BaseMultiLayerUpdater(T network, INDArray updaterState) { INDArray gradientViewSubset = null; INDArray paramsViewSubset = null; if (paramSizeThisVariable > 0) { - paramsViewSubset = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(paramsViewSoFar, + paramsViewSubset = paramsView.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(paramsViewSoFar, paramsViewSoFar + paramSizeThisVariable)); - gradientViewSubset = gradientView.get(NDArrayIndex.point(0), NDArrayIndex + gradientViewSubset = gradientView.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex .interval(paramsViewSoFar, paramsViewSoFar + paramSizeThisVariable)); } @@ -163,14 +163,14 @@ public BaseMultiLayerUpdater(T network, INDArray updaterState) { int gradSize = ub.getParamOffsetEnd() - ub.getParamOffsetStart(); if (viewStateSize > 0) { - INDArray updaterViewSubset = updaterStateViewArray.get(NDArrayIndex.point(0), + INDArray updaterViewSubset = updaterStateViewArray.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(updaterViewSoFar, updaterViewSoFar + viewStateSize)); ub.setUpdaterView(updaterViewSubset); ub.setUpdaterViewRequiresInitialization(updaterRequiresInit); } if (gradSize > 0) { - INDArray gradientViewSubset = gradientView.get(NDArrayIndex.point(0), + INDArray gradientViewSubset = gradientView.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(paramsViewSoFar, paramsViewSoFar + gradSize)); ub.setGradientView(gradientViewSubset); } @@ -359,11 +359,12 @@ protected List getMinibatchDivisionSubsets(INDArray from){ Map paramTable = t.paramTable(false); for(String s : layerParams) { if(t.updaterDivideByMinibatch(s)){ - currentEnd += paramTable.get(s).length(); + long l = paramTable.get(s).length(); + currentEnd += l; } else { 
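
Referring back to the TransferLearning changes above: new parameter arrays are now created in the original model's data type rather than the global default, and that type is also passed through to Layer.instantiate. A rough sketch of the creation pattern, with placeholder values standing in for what the builder reads from the original configuration:

    // assumes: import org.nd4j.linalg.api.buffer.DataType;
    //          import org.nd4j.linalg.api.ndarray.INDArray;
    //          import org.nd4j.linalg.factory.Nd4j;
    DataType dtype = DataType.FLOAT;                      // in the patch: origModel.getLayerWiseConfigurations().getDataType()
    long numParams = 30;                                  // hypothetical parameter count for the added/replaced layer
    INDArray params = Nd4j.create(dtype, 1, numParams);   // previously Nd4j.create(1, numParams), i.e. the global default type
    assert params.dataType() == dtype;
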
//This param/gradient subset should be excluded if(currentEnd > currentStart){ - INDArray subset = from.get(NDArrayIndex.point(0), NDArrayIndex.interval(currentStart, currentEnd)); + INDArray subset = from.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(currentStart, currentEnd)); out.add(subset); } currentStart = paramsSoFar + paramTable.get(s).length(); @@ -375,7 +376,7 @@ protected List getMinibatchDivisionSubsets(INDArray from){ if(currentEnd > currentStart && currentStart < from.length()){ //Process last part of the gradient view array - INDArray subset = from.get(NDArrayIndex.point(0), NDArrayIndex.interval(currentStart, currentEnd)); + INDArray subset = from.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(currentStart, currentEnd)); out.add(subset); } return out; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java index 4b68c78c448b..f1849aa9694b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java @@ -491,7 +491,7 @@ public static INDArray reshapeCnn3dMask(@NonNull Convolution3D.DataFormat format int channelIdx = format == Convolution3D.DataFormat.NCDHW ? 1 : 4; lShape[channelIdx] = mask.size(channelIdx); //Keep existing channel size - INDArray bMask = workspaceMgr.createUninitialized(type, lShape, 'c'); + INDArray bMask = workspaceMgr.createUninitialized(type, mask.dataType(), lShape, 'c'); int[] bcDims = broadcastDims.toIntArray(); Nd4j.getExecutioner().exec(new BroadcastCopyOp(bMask, mask, bMask, bcDims)); return reshape5dTo2d(format, bMask, workspaceMgr, type); @@ -548,7 +548,7 @@ public static INDArray adapt2dMask(INDArray mask, INDArray output, LayerWorkspac //Use workaround for: https://github.com/deeplearning4j/nd4j/issues/2066 val s = output.shape(); - INDArray bMask = workspaceMgr.create(type, new long[]{s[0], 1, s[2], s[3]}, 'c'); + INDArray bMask = workspaceMgr.create(type, mask.dataType(), new long[]{s[0], 1, s[2], s[3]}, 'c'); Nd4j.getExecutioner().exec(new BroadcastCopyOp(bMask, mask, bMask, 0, 1)); INDArray bMaskPermute = bMask.permute(0, 2, 3, 1).dup('c'); //Not sure if dup is strictly necessary... diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/MaskedReductionUtil.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/MaskedReductionUtil.java index f3459cbb070e..3e48b6dbc52d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/MaskedReductionUtil.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/MaskedReductionUtil.java @@ -48,7 +48,7 @@ public class MaskedReductionUtil { private MaskedReductionUtil(){ } public static INDArray maskedPoolingTimeSeries(PoolingType poolingType, INDArray toReduce, INDArray mask, - int pnorm) { + int pnorm, DataType dataType) { if (toReduce.rank() != 3) { throw new IllegalArgumentException("Expect rank 3 array: got " + toReduce.rank()); } @@ -56,6 +56,8 @@ public static INDArray maskedPoolingTimeSeries(PoolingType poolingType, INDArray throw new IllegalArgumentException("Expect rank 2 array for mask: got " + mask.rank()); } + mask = mask.castTo(dataType); + //Sum pooling: easy. 
Multiply by mask, then sum as normal //Average pooling: as above, but do a broadcast element-wise divi by mask.sum(1) //Max pooling: set to -inf if mask is 0, then do max as normal @@ -65,20 +67,20 @@ public static INDArray maskedPoolingTimeSeries(PoolingType poolingType, INDArray //TODO This is ugly - replace it with something better... Need something like a Broadcast CAS op INDArray negInfMask; if(mask.dataType() == DataType.BOOL){ - negInfMask = Transforms.not(mask).castTo(Nd4j.defaultFloatingPointType()); + negInfMask = Transforms.not(mask).castTo(dataType); } else { negInfMask = mask.rsub(1.0); } BooleanIndexing.replaceWhere(negInfMask, Double.NEGATIVE_INFINITY, Conditions.equals(1.0)); - INDArray withInf = Nd4j.createUninitialized(toReduce.shape()); + INDArray withInf = Nd4j.createUninitialized(dataType, toReduce.shape()); Nd4j.getExecutioner().exec(new BroadcastAddOp(toReduce, negInfMask, withInf, 0, 2)); //At this point: all the masked out steps have value -inf, hence can't be the output of the MAX op return withInf.max(2); case AVG: case SUM: - INDArray masked = Nd4j.createUninitialized(toReduce.shape()); + INDArray masked = Nd4j.createUninitialized(dataType, toReduce.shape()); Nd4j.getExecutioner().exec(new BroadcastMulOp(toReduce, mask, masked, 0, 2)); INDArray summed = masked.sum(2); if (poolingType == PoolingType.SUM) { @@ -90,7 +92,7 @@ public static INDArray maskedPoolingTimeSeries(PoolingType poolingType, INDArray return summed; case PNORM: //Similar to average and sum pooling: there's no N term here, so we can just set the masked values to 0 - INDArray masked2 = Nd4j.createUninitialized(toReduce.shape()); + INDArray masked2 = Nd4j.createUninitialized(dataType, toReduce.shape()); Nd4j.getExecutioner().exec(new BroadcastMulOp(toReduce, mask, masked2, 0, 2)); INDArray abs = Transforms.abs(masked2, true); @@ -187,13 +189,15 @@ public static INDArray maskedPoolingEpsilonTimeSeries(PoolingType poolingType, I } - public static INDArray maskedPoolingConvolution(PoolingType poolingType, INDArray toReduce, INDArray mask, int pnorm) { + public static INDArray maskedPoolingConvolution(PoolingType poolingType, INDArray toReduce, INDArray mask, int pnorm, DataType dataType) { if(mask.rank() != 4){ //TODO BETTER ERROR MESSAGE EXPLAINING FORMAT //TODO ALSO HANDLE LEGACY FORMAT WITH WARNING WHERE POSSIBLE throw new IllegalStateException("Expected rank 4 mask array: Got array with shape " + Arrays.toString(mask.shape())); } + mask = mask.castTo(dataType); //no-op if already correct dtype + // [minibatch, channels, h, w] data with a mask array of shape [minibatch, 1, X, Y] // where X=(1 or inH) and Y=(1 or inW) @@ -214,20 +218,20 @@ public static INDArray maskedPoolingConvolution(PoolingType poolingType, INDArra //TODO This is ugly - replace it with something better... 
Need something like a Broadcast CAS op INDArray negInfMask; if(mask.dataType() == DataType.BOOL){ - negInfMask = Transforms.not(mask).castTo(Nd4j.defaultFloatingPointType()); + negInfMask = Transforms.not(mask).castTo(dataType); } else { negInfMask = mask.rsub(1.0); } BooleanIndexing.replaceWhere(negInfMask, Double.NEGATIVE_INFINITY, Conditions.equals(1.0)); - INDArray withInf = Nd4j.createUninitialized(toReduce.shape()); + INDArray withInf = Nd4j.createUninitialized(dataType, toReduce.shape()); Nd4j.getExecutioner().exec(new BroadcastAddOp(toReduce, negInfMask, withInf, dimensions)); //At this point: all the masked out steps have value -inf, hence can't be the output of the MAX op return withInf.max(2, 3); case AVG: case SUM: - INDArray masked = Nd4j.createUninitialized(toReduce.shape()); + INDArray masked = Nd4j.createUninitialized(dataType, toReduce.shape()); Nd4j.getExecutioner().exec(new BroadcastMulOp(toReduce, mask, masked, dimensions)); INDArray summed = masked.sum(2, 3); @@ -240,7 +244,7 @@ public static INDArray maskedPoolingConvolution(PoolingType poolingType, INDArra case PNORM: //Similar to average and sum pooling: there's no N term here, so we can just set the masked values to 0 - INDArray masked2 = Nd4j.createUninitialized(toReduce.shape()); + INDArray masked2 = Nd4j.createUninitialized(dataType, toReduce.shape()); Nd4j.getExecutioner().exec(new BroadcastMulOp(toReduce, mask, masked2, dimensions)); INDArray abs = Transforms.abs(masked2, true); @@ -255,7 +259,7 @@ public static INDArray maskedPoolingConvolution(PoolingType poolingType, INDArra public static INDArray maskedPoolingEpsilonCnn(PoolingType poolingType, INDArray input, INDArray mask, - INDArray epsilon2d, int pnorm) { + INDArray epsilon2d, int pnorm, DataType dataType) { // [minibatch, channels, h=1, w=X] or [minibatch, channels, h=X, w=1] data // with a mask array of shape [minibatch, X] @@ -263,6 +267,8 @@ public static INDArray maskedPoolingEpsilonCnn(PoolingType poolingType, INDArray //If masking along height: broadcast dimensions are [0,2] //If masking along width: broadcast dimensions are [0,3] + mask = mask.castTo(dataType); //No-op if correct type + //General case: must be equal or 1 on each dimension int[] dimensions = new int[4]; int count = 0; @@ -280,13 +286,13 @@ public static INDArray maskedPoolingEpsilonCnn(PoolingType poolingType, INDArray //TODO This is ugly - replace it with something better... 
Need something like a Broadcast CAS op INDArray negInfMask; if(mask.dataType() == DataType.BOOL){ - negInfMask = Transforms.not(mask).castTo(Nd4j.defaultFloatingPointType()); + negInfMask = Transforms.not(mask).castTo(dataType); } else { negInfMask = mask.rsub(1.0); } BooleanIndexing.replaceWhere(negInfMask, Double.NEGATIVE_INFINITY, Conditions.equals(1.0)); - INDArray withInf = Nd4j.createUninitialized(input.shape()); + INDArray withInf = Nd4j.createUninitialized(dataType, input.shape()); Nd4j.getExecutioner().exec(new BroadcastAddOp(input, negInfMask, withInf, dimensions)); //At this point: all the masked out steps have value -inf, hence can't be the output of the MAX op @@ -299,7 +305,7 @@ public static INDArray maskedPoolingEpsilonCnn(PoolingType poolingType, INDArray //if out = avg(in,dims) then dL/dIn = 1/N * dL/dOut //With masking: N differs for different time series - INDArray out = Nd4j.createUninitialized(input.shape(), 'f'); + INDArray out = Nd4j.createUninitialized(dataType, input.shape(), 'f'); //Broadcast copy op, then divide and mask to 0 as appropriate Nd4j.getExecutioner().exec(new BroadcastCopyOp(out, epsilon2d, out, 0, 1)); @@ -317,7 +323,7 @@ public static INDArray maskedPoolingEpsilonCnn(PoolingType poolingType, INDArray case PNORM: //Similar to average and sum pooling: there's no N term here, so we can just set the masked values to 0 - INDArray masked2 = Nd4j.createUninitialized(input.shape()); + INDArray masked2 = Nd4j.createUninitialized(dataType, input.shape()); Nd4j.getExecutioner().exec(new BroadcastMulOp(input, mask, masked2, dimensions)); INDArray abs = Transforms.abs(masked2, true); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java index 71240b769675..f37ac245c371 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java @@ -312,6 +312,10 @@ public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull File file, boo throw e; } } + + //Handle legacy config - no network DataType in config, in beta3 or earlier + if(params != null) + confFromJson.setDataType(params.dataType()); MultiLayerNetwork network = new MultiLayerNetwork(confFromJson); network.init(params, false); @@ -640,6 +644,11 @@ public static ComputationGraph restoreComputationGraph(@NonNull File file, boole throw new RuntimeException("Error deserializing JSON ComputationGraphConfiguration. Saved model appears to be " + "a MultiLayerNetwork - use ModelSerializer.restoreMultiLayerNetwork instead"); } + + //Handle legacy config - no network DataType in config, in beta3 or earlier + if(params != null) + confFromJson.setDataType(params.dataType()); + ComputationGraph cg = new ComputationGraph(confFromJson); cg.init(params, false); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/NetworkUtils.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/NetworkUtils.java index 5ad69a993a75..7c24a264b1a5 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/NetworkUtils.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/NetworkUtils.java @@ -60,6 +60,7 @@ public static ComputationGraph toComputationGraph(MultiLayerNetwork net) { // for the updater state... 
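
The same data-type threading shows up in the utility and serialization changes just above: masks are cast to the network's type before arithmetic, scratch arrays are created with an explicit DataType, and legacy models (beta3 or earlier, whose configs carry no DataType) fall back to the type of their parameter array. A small sketch of those idioms, with illustrative shapes and types only:

    // assumes: import org.nd4j.linalg.api.buffer.DataType;
    //          import org.nd4j.linalg.api.ndarray.INDArray;
    //          import org.nd4j.linalg.factory.Nd4j;
    DataType netType = DataType.FLOAT;                            // the network's configured data type
    INDArray boolMask = Nd4j.ones(2, 10).castTo(DataType.BOOL);   // e.g. a boolean time-series mask
    INDArray mask = boolMask.castTo(netType);                     // no-op when the mask already has the right type
    INDArray scratch = Nd4j.createUninitialized(netType, 2, 10);  // instead of inheriting the global default type
    // legacy saved models: the patch falls back to params.dataType() when the restored config has no data type
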
ComputationGraphConfiguration.GraphBuilder b = new NeuralNetConfiguration.Builder() + .dataType(net.getLayerWiseConfigurations().getDataType()) .graphBuilder(); MultiLayerConfiguration origConf = net.getLayerWiseConfigurations().clone(); @@ -452,7 +453,7 @@ protected static INDArray rebuildUpdaterStateArray(INDArray origUpdaterState, Li long soFar = 0; for( int sub=0; sub()); stateViewsPerParam.get(paramName).add(currSplit); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java index 0939b5143168..f356fab71081 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java @@ -257,7 +257,7 @@ public static INDArray reverseTimeSeries(INDArray in, LayerWorkspaceMgr workspac INDArray inReshape = in.reshape('f', in.size(0)*in.size(1), in.size(2)); - INDArray outReshape = workspaceMgr.create(arrayType, new long[]{inReshape.size(0), idxs.length}, 'f'); + INDArray outReshape = workspaceMgr.create(arrayType, in.dataType(), new long[]{inReshape.size(0), idxs.length}, 'f'); Nd4j.pullRows(inReshape, outReshape, 0, idxs); return workspaceMgr.leverageTo(arrayType, outReshape.reshape('f', in.size(0), in.size(1), in.size(2))); @@ -326,7 +326,7 @@ public static INDArray reverseTimeSeriesMask(INDArray mask, LayerWorkspaceMgr wo idxs[j++] = i; } - INDArray ret = workspaceMgr.createUninitialized(arrayType, new long[]{mask.size(0), idxs.length}, 'f'); + INDArray ret = workspaceMgr.createUninitialized(arrayType, mask.dataType(), new long[]{mask.size(0), idxs.length}, 'f'); return Nd4j.pullRows(mask, ret, 0, idxs); diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java b/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java index d4c095be97d9..f945c4c9214d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java +++ b/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java @@ -65,7 +65,7 @@ public void activate() { for (int i = 12; i < 16; i++) { params.putScalar(i, 1.0); } - Layer lstm = underlying.instantiate(conf, Collections.emptyList(), 0, params, false); + Layer lstm = underlying.instantiate(conf, Collections.emptyList(), 0, params, false, params.dataType()); double maskingValue = 0.0; MaskZeroLayer l = new MaskZeroLayer(lstm, maskingValue); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java index c1144ad2d883..609df6651ba1 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java @@ -97,8 +97,8 @@ public JavaSparkContext getContext() { protected JavaRDD getBasicSparkDataSet(int nRows, INDArray input, INDArray labels) { List list = new ArrayList<>(); for (int i = 0; i < nRows; i++) { - INDArray inRow = input.getRow(i).dup(); - INDArray outRow = labels.getRow(i).dup(); + INDArray inRow = input.getRow(i, true).dup(); + 
INDArray outRow = labels.getRow(i, true).dup(); DataSet ds = new DataSet(inRow, outRow); list.add(ds); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayer.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayer.java index 8a04c8d74dae..a4294161b4ee 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayer.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayer.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonProperty; @@ -49,9 +50,9 @@ public CustomLayer(@JsonProperty("someCustomParameter") double someCustomParamet @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams) { - CustomLayerImpl ret = new CustomLayerImpl(conf); + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { + CustomLayerImpl ret = new CustomLayerImpl(conf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayerImpl.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayerImpl.java index f37bd02a7e53..680edd1fab31 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayerImpl.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayerImpl.java @@ -18,6 +18,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.layers.BaseLayer; +import org.nd4j.linalg.api.buffer.DataType; /** * @@ -26,8 +27,8 @@ * Created by Alex on 26/08/2016. 
*/ public class CustomLayerImpl extends BaseLayer { - public CustomLayerImpl(NeuralNetConfiguration conf) { - super(conf); + public CustomLayerImpl(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); } @Override diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java index d27c22854143..1cc4028062aa 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java @@ -163,7 +163,7 @@ public void testDistributedScoring() { List> dataWithKeys = new ArrayList<>(); for (int i = 0; i < nRows; i++) { - DataSet ds = new DataSet(features.getRow(i).dup(), labels.getRow(i).dup()); + DataSet ds = new DataSet(features.getRow(i,true).dup(), labels.getRow(i,true).dup()); dataWithKeys.add(new Tuple2<>(String.valueOf(i), ds)); } JavaPairRDD dataWithKeysRdd = sc.parallelizePairs(dataWithKeys); @@ -188,7 +188,7 @@ public void testDistributedScoring() { List dataNoKeys = new ArrayList<>(); for (int i = 0; i < nRows; i++) { - dataNoKeys.add(new DataSet(features.getRow(i).dup(), labels.getRow(i).dup())); + dataNoKeys.add(new DataSet(features.getRow(i,true).dup(), labels.getRow(i,true).dup())); } JavaRDD dataNoKeysRdd = sc.parallelize(dataNoKeys); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java index a701f582d055..93506657a583 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java @@ -297,7 +297,7 @@ public void testDistributedScoring() { List> dataWithKeys = new ArrayList<>(); for (int i = 0; i < nRows; i++) { - DataSet ds = new DataSet(features.getRow(i).dup(), labels.getRow(i).dup()); + DataSet ds = new DataSet(features.getRow(i,true).dup(), labels.getRow(i,true).dup()); dataWithKeys.add(new Tuple2<>(String.valueOf(i), ds)); } JavaPairRDD dataWithKeysRdd = sc.parallelizePairs(dataWithKeys); @@ -322,7 +322,7 @@ public void testDistributedScoring() { List dataNoKeys = new ArrayList<>(); for (int i = 0; i < nRows; i++) { - dataNoKeys.add(new DataSet(features.getRow(i).dup(), labels.getRow(i).dup())); + dataNoKeys.add(new DataSet(features.getRow(i,true).dup(), labels.getRow(i,true).dup())); } JavaRDD dataNoKeysRdd = sc.parallelize(dataNoKeys); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index acd2698f22e0..82a770d6c583 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -999,60 
+999,6 @@ public List propertiesToResolveForFunction(DifferentialFunction function } - /** - * Returns true if the given function has ndarray properties to resolve. - * - * @param function the function to check - * @return true if the function has yet to be resolved properties - */ - public boolean hasPropertiesToResolve(DifferentialFunction function) { - return propertiesToResolve.containsKey(function.getOwnName()); - } - - - /** - * Get the property for a given function - * - * @param functionInstance the function to get the - * property for - * @param propertyName the name of the property to get - * @param the inferred return type - * @return the property for the given function - */ - public T getPropertyForFunction(DifferentialFunction functionInstance, String propertyName) { - if (!propertiesForFunction.containsKey(functionInstance.getOwnName())) { - return null; - } else { - val map = propertiesForFunction.get(functionInstance.getOwnName()); - return (T) map.get(propertyName); - - } - } - - /** - * Add a property for the given function - * - * @param functionFor the function add a property for - * @param propertyName the property name - * @param property the property value - */ - public void addPropertyForFunction(DifferentialFunction functionFor, String propertyName, INDArray property) { - addPropertyForFunction(functionFor, propertyName, (Object) property); - } - - - /** - * Add a property for the given function - * - * @param functionFor the function to add the property for - * @param propertyName the name of the property to add the value for - * @param property the property value to add - */ - public void addPropertyForFunction(DifferentialFunction functionFor, String propertyName, long property) { - addPropertyForFunction(functionFor, propertyName, (Object) property); - } - - private void addPropertyForFunction(DifferentialFunction functionFor, String propertyName, Object propertyValue) { if (!propertiesForFunction.containsKey(functionFor.getOwnName())) { Map fields = new LinkedHashMap<>(); @@ -5029,4 +4975,6 @@ public Map calculateOutputDataTypes( Map out = session.output(allVars, phValues); return out; } + + } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java index 2886c264472c..098c971d946c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java @@ -10,6 +10,8 @@ import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker; import org.nd4j.linalg.indexing.conditions.Condition; +import static org.nd4j.autodiff.samediff.ops.SDValidation.*; + /** * Core op creator methods available via SameDiff class directly * @@ -77,6 +79,7 @@ public SDVariable argmax(SDVariable in, int... dimensions) { * of rank (input rank) if keepdims = true */ public SDVariable argmax(String name, SDVariable in, boolean keepDims, int... dimensions) { + validateNumerical("argmax", in); SDVariable ret = f().argmax(in, keepDims, dimensions); return updateVariableNameAndReference(ret, name); } @@ -143,6 +146,7 @@ public SDVariable argmin(String name, SDVariable in, int... dimensions) { * of rank (input rank) if keepdims = true */ public SDVariable argmin(String name, SDVariable in, boolean keepDims, int... 
dimensions) { + validateNumerical("argmin", in); SDVariable ret = f().argmin(in, keepDims, dimensions); return updateVariableNameAndReference(ret, name); } @@ -242,6 +246,8 @@ public SDVariable[] batchMmul(SDVariable[] matricesA, SDVariable[] matricesB, */ public SDVariable[] batchMmul(String[] names, SDVariable[] matricesA, SDVariable[] matricesB, boolean transposeA, boolean transposeB) { + validateSameType("batchMmul", true, matricesA); + validateSameType("batchMmul", true, matricesB); SDVariable[] result = f().batchMmul(matricesA, matricesB, transposeA, transposeB); return updateVariableNamesAndReferences(result, names); } @@ -290,6 +296,7 @@ public SDVariable concat(int dimension, SDVariable... inputs) { * @see #stack(String, int, SDVariable...) */ public SDVariable concat(String name, int dimension, SDVariable... inputs) { + validateSameType("concat", false, inputs); SDVariable result = f().concat(dimension, inputs); return updateVariableNameAndReference(result, name); } @@ -317,6 +324,7 @@ public SDVariable cumprod(SDVariable in, boolean exclusive, boolean reverse, int * @return Output variable */ public SDVariable cumprod(String name, SDVariable in, boolean exclusive, boolean reverse, int... axis) { + validateNumerical("cumprod", in); SDVariable ret = f().cumprod(in, exclusive, reverse, axis); return updateVariableNameAndReference(ret, name); } @@ -344,6 +352,7 @@ public SDVariable cumsum(SDVariable in, boolean exclusive, boolean reverse, int. * @return Output variable */ public SDVariable cumsum(String name, SDVariable in, boolean exclusive, boolean reverse, int... axis) { + validateNumerical("cumsum", in); SDVariable ret = f().cumsum(in, exclusive, reverse, axis); return updateVariableNameAndReference(ret, name); } @@ -370,6 +379,7 @@ public SDVariable dot(SDVariable x, SDVariable y, int... dimensions) { * @return */ public SDVariable dot(String name, SDVariable x, SDVariable y, int... 
dimensions) { + SDValidation.validateNumerical("dot", x, y); SDVariable ret = f().dot(x, y, dimensions); return updateVariableNameAndReference(ret, name); } @@ -622,6 +632,7 @@ public SDVariable gt(SDVariable x, double y) { * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable gt(String name, SDVariable x, double y) { + validateNumerical("greater than (gt)", x); SDVariable result = f().gt(x, y); return updateVariableNameAndReference(result, name); } @@ -652,6 +663,7 @@ public SDVariable gt(SDVariable x, SDVariable y) { * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable gt(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("greater than (gt)", x, y); SDVariable result = f().gt(x, y); return updateVariableNameAndReference(result, name); } @@ -680,6 +692,7 @@ public SDVariable gte(SDVariable x, double y) { * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable gte(String name, SDVariable x, double y) { + validateNumerical("greater than or equal (gte)", x); SDVariable result = f().gte(x, y); return updateVariableNameAndReference(result, name); } @@ -710,6 +723,7 @@ public SDVariable gte(SDVariable x, SDVariable y) { * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable gte(String name, SDVariable x, SDVariable y) { + validateNumerical("greater than or equal (gte)", x, y); SDVariable result = f().gte(x, y); return updateVariableNameAndReference(result, name); } @@ -758,6 +772,7 @@ public SDVariable invertPermutation(SDVariable input) { * @return 1D inverted permutation */ public SDVariable invertPermutation(String name, SDVariable input) { + validateInteger("invert permutation", input); SDVariable ret = f().invertPermutation(input, false); return updateVariableNameAndReference(ret, name); } @@ -853,6 +868,7 @@ public SDVariable lt(SDVariable x, double y) { * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable lt(String name, SDVariable x, double y) { + validateNumerical("less than (lt)", x); SDVariable result = f().lt(x, y); return updateVariableNameAndReference(result, name); } @@ -883,6 +899,7 @@ public SDVariable lt(SDVariable x, SDVariable y) { * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable lt(String name, SDVariable x, SDVariable y) { + validateNumerical("less than (lt)", x, y); SDVariable result = f().lt(x, y); return updateVariableNameAndReference(result, name); } @@ -911,6 +928,7 @@ public SDVariable lte(SDVariable x, double y) { * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable lte(String name, SDVariable x, double y) { + validateNumerical("less than or equal (lte)", x); SDVariable result = f().lte(x, y); return updateVariableNameAndReference(result, name); } @@ -941,6 +959,7 @@ public SDVariable lte(SDVariable x, SDVariable y) { * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable lte(String name, SDVariable x, SDVariable y) { + validateNumerical("less than or equal (lte)", x, y); SDVariable result = f().lte(x, y); return updateVariableNameAndReference(result, name); } @@ -1051,6 +1070,7 @@ public SDVariable max(String name, SDVariable x, int... 
dimensions) { * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable max(String name, SDVariable x, boolean keepDims, int... dimensions) { + validateNumerical("max reduction", x); SDVariable result = f().max(x, keepDims, dimensions); return updateVariableNameAndReference(result, name); } @@ -1077,6 +1097,7 @@ public SDVariable max(SDVariable first, SDVariable second) { * @return Output variable */ public SDVariable max(String name, SDVariable first, SDVariable second) { + validateNumerical("pairwise maxiumum (max)", first, second); SDVariable result = f().max(first, second); return updateVariableNameAndReference(result, name); } @@ -1119,6 +1140,7 @@ public SDVariable mean(String name, SDVariable x, int... dimension) { * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable mean(String name, SDVariable x, boolean keepDims, int... dimension) { + validateNumerical("mean reduction", x); SDVariable result = f().mean(x, keepDims, dimension); return updateVariableNameAndReference(result, name); } @@ -1173,6 +1195,7 @@ public SDVariable min(String name, SDVariable x, int... dimensions) { * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable min(String name, SDVariable x, boolean keepDims, int... dimensions) { + validateNumerical("min reduction", x); SDVariable result = f().min(x, keepDims, dimensions); return updateVariableNameAndReference(result, name); @@ -1200,6 +1223,7 @@ public SDVariable min(SDVariable first, SDVariable second) { * @return Output variable */ public SDVariable min(String name, SDVariable first, SDVariable second) { + validateNumerical("mean (pairwise)", first, second); SDVariable result = f().min(first, second); return updateVariableNameAndReference(result, name); } @@ -1229,6 +1253,7 @@ public SDVariable mmul(SDVariable x, SDVariable y, MMulTranspose transpose) { * @return Output variable */ public SDVariable mmul(String name, SDVariable x, SDVariable y, MMulTranspose transpose) { + validateNumerical("matrix multiplication (mmul)", x, y); SDVariable result = f().mmul(x, y, transpose); return updateVariableNameAndReference(result, name); } @@ -1280,6 +1305,7 @@ public SDVariable neq(SDVariable x, double y) { * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable neq(String name, SDVariable x, double y) { + validateNumerical("not equals (neq)", x); SDVariable result = f().neq(x, y); return updateVariableNameAndReference(result, name); } @@ -1310,6 +1336,7 @@ public SDVariable neq(SDVariable x, SDVariable y) { * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable neq(String name, SDVariable x, SDVariable y) { + validateNumerical("not equals (neq)", x, y); SDVariable result = f().neq(x, y); return updateVariableNameAndReference(result, name); } @@ -1344,6 +1371,7 @@ public SDVariable norm1(String name, SDVariable x, int... dimensions) { * @return Output variable */ public SDVariable norm1(String name, SDVariable x, boolean keepDims, int... dimensions) { + validateNumerical("norm1 reduction", x); SDVariable result = f().norm1(x, keepDims, dimensions); return updateVariableNameAndReference(result, name); } @@ -1378,6 +1406,7 @@ public SDVariable norm2(String name, SDVariable x, int... dimensions) { * @return Output variable */ public SDVariable norm2(String name, SDVariable x, boolean keepDims, int... 
dimensions) { + validateNumerical("norm2 reduction", x); SDVariable result = f().norm2(x, keepDims, dimensions); return updateVariableNameAndReference(result, name); } @@ -1413,6 +1442,7 @@ public SDVariable normmax(String name, SDVariable x, int... dimensions) { * @return Output variable */ public SDVariable normmax(String name, SDVariable x, boolean keepDims, int... dimensions) { + validateNumerical("norm max reduction", x); SDVariable result = f().normmax(x, keepDims, dimensions); return updateVariableNameAndReference(result, name); } @@ -1442,6 +1472,7 @@ public SDVariable oneHot(String name, SDVariable indices, int depth, int axis, d * As per {@link #oneHot(String, SDVariable, int, int, double, double)} but allows configuring the output datatype */ public SDVariable oneHot(String name, SDVariable indices, int depth, int axis, double on, double off, DataType dataType) { + validateInteger("oneHot", "indices", indices); SDVariable ret = f().onehot(indices, depth, axis, on, off, dataType); return updateVariableNameAndReference(ret, name); } @@ -1517,6 +1548,7 @@ public SDVariable parallel_stack(SDVariable[] values) { * @see #stack(String, int, SDVariable...) */ public SDVariable parallel_stack(String name, SDVariable[] values) { + validateSameType("parallel_stack", false, values); SDVariable ret = f().parallel_stack(values); return updateVariableNameAndReference(ret, name); } @@ -1584,6 +1616,7 @@ public SDVariable prod(String name, SDVariable x, int... dimensions) { * @return Output variable: reduced array of rank (input rank - num dimensions) */ public SDVariable prod(String name, SDVariable x, boolean keepDims, int... dimensions) { + validateNumerical("product reduction (prod)", x); SDVariable result = f().prod(x, keepDims, dimensions); return updateVariableNameAndReference(result, name); } @@ -1802,6 +1835,7 @@ public SDVariable reshape(SDVariable x, SDVariable shape) { * @see #reshape(SDVariable, int[]) */ public SDVariable reshape(String name, SDVariable x, SDVariable shape) { + validateInteger("reshape", "shape", shape); SDVariable result = f().reshape(x, shape); return updateVariableNameAndReference(result, name); } @@ -1894,6 +1928,7 @@ public SDVariable scalarFloorMod(SDVariable in, Number value) { * @return Output variable */ public SDVariable scalarFloorMod(String name, SDVariable in, Number value) { + validateNumerical("floorMod", in); SDVariable ret = f().scalarFloorMod(in, value); return updateVariableNameAndReference(ret, name); } @@ -1918,6 +1953,7 @@ public SDVariable scalarMax(SDVariable in, Number value) { * @return Output variable */ public SDVariable scalarMax(String name, SDVariable in, Number value) { + validateNumerical("max", in); SDVariable ret = f().scalarMax(in, value); return updateVariableNameAndReference(ret, name); } @@ -1942,6 +1978,7 @@ public SDVariable scalarMin(SDVariable in, Number value) { * @return Output variable */ public SDVariable scalarMin(String name, SDVariable in, Number value) { + validateNumerical("min", in); SDVariable ret = f().scalarMin(in, value); return updateVariableNameAndReference(ret, name); } @@ -1991,6 +2028,7 @@ public SDVariable scatterAdd(SDVariable ref, SDVariable indices, SDVariable upda * @return The updated variable */ public SDVariable scatterAdd(String name, SDVariable ref, SDVariable indices, SDVariable updates) { + validateInteger("scatterAdd", "indices", indices); SDVariable ret = f().scatterAdd(ref, indices, updates); return updateVariableNameAndReference(ret, name); } @@ -2016,6 +2054,7 @@ public SDVariable 
scatterDiv(SDVariable ref, SDVariable indices, SDVariable upda * @return The updated variable */ public SDVariable scatterDiv(String name, SDVariable ref, SDVariable indices, SDVariable updates) { + validateInteger("scatterDiv", "indices", indices); SDVariable ret = f().scatterDiv(ref, indices, updates); return updateVariableNameAndReference(ret, name); } @@ -2041,6 +2080,7 @@ public SDVariable scatterMax(SDVariable ref, SDVariable indices, SDVariable upda * @return The updated variable */ public SDVariable scatterMax(String name, SDVariable ref, SDVariable indices, SDVariable updates) { + validateInteger("scatterMax", "indices", indices); SDVariable ret = f().scatterMax(ref, indices, updates); return updateVariableNameAndReference(ret, name); } @@ -2066,6 +2106,7 @@ public SDVariable scatterMin(SDVariable ref, SDVariable indices, SDVariable upda * @return The updated variable */ public SDVariable scatterMin(String name, SDVariable ref, SDVariable indices, SDVariable updates) { + validateInteger("scatterMin", "indices", indices); SDVariable ret = f().scatterMin(ref, indices, updates); return updateVariableNameAndReference(ret, name); } @@ -2091,6 +2132,7 @@ public SDVariable scatterMul(SDVariable ref, SDVariable indices, SDVariable upda * @return The updated variable */ public SDVariable scatterMul(String name, SDVariable ref, SDVariable indices, SDVariable updates) { + validateInteger("scatterMul", "indices", indices); SDVariable ret = f().scatterMul(ref, indices, updates); return updateVariableNameAndReference(ret, name); } @@ -2116,6 +2158,7 @@ public SDVariable scatterSub(SDVariable ref, SDVariable indices, SDVariable upda * @return The updated variable */ public SDVariable scatterSub(String name, SDVariable ref, SDVariable indices, SDVariable updates) { + validateInteger("scatterSub", "indices", indices); SDVariable ret = f().scatterSub(ref, indices, updates); return updateVariableNameAndReference(ret, name); } @@ -2142,6 +2185,7 @@ public SDVariable scatterUpdate(SDVariable ref, SDVariable indices, SDVariable u * @return The updated variable */ public SDVariable scatterUpdate(String name, SDVariable ref, SDVariable indices, SDVariable updates) { + validateInteger("scatterUpdate", "indices", indices); SDVariable ret = f().scatterUpdate(ref, indices, updates); return updateVariableNameAndReference(ret, name); } @@ -2168,6 +2212,8 @@ public SDVariable segmentMax(SDVariable data, SDVariable segmentIds) { * @return Segment max output */ public SDVariable segmentMax(String name, SDVariable data, SDVariable segmentIds) { + validateNumerical("segmentMax", "data", data); + validateInteger("segmentMax", "segmentIds", segmentIds); SDVariable ret = f().segmentMax(data, segmentIds); return updateVariableNameAndReference(ret, name); } @@ -2193,6 +2239,8 @@ public SDVariable segmentMean(SDVariable data, SDVariable segmentIds) { * @return Segment mean output */ public SDVariable segmentMean(String name, SDVariable data, SDVariable segmentIds) { + validateNumerical("segmentMean", "data", data); + validateInteger("segmentMean", "segmentIds", segmentIds); SDVariable ret = f().segmentMean(data, segmentIds); return updateVariableNameAndReference(ret, name); } @@ -2218,6 +2266,8 @@ public SDVariable segmentMin(SDVariable data, SDVariable segmentIds) { * @return Segment min output */ public SDVariable segmentMin(String name, SDVariable data, SDVariable segmentIds) { + validateNumerical("segmentMin", "data", data); + validateInteger("segmentMin", "segmentIds", segmentIds); SDVariable ret = 
f().segmentMin(data, segmentIds); return updateVariableNameAndReference(ret, name); } @@ -2243,6 +2293,8 @@ public SDVariable segmentProd(SDVariable data, SDVariable segmentIds) { * @return Segment product output */ public SDVariable segmentProd(String name, SDVariable data, SDVariable segmentIds) { + validateNumerical("segmentProd", "data", data); + validateInteger("segmentProd", "segmentIds", segmentIds); SDVariable ret = f().segmentProd(data, segmentIds); return updateVariableNameAndReference(ret, name); } @@ -2268,6 +2320,8 @@ public SDVariable segmentSum(SDVariable data, SDVariable segmentIds) { * @return Segment sum output */ public SDVariable segmentSum(String name, SDVariable data, SDVariable segmentIds) { + validateNumerical("segmentSum", "data", data); + validateInteger("segmentSum", "segmentIds", segmentIds); SDVariable ret = f().segmentSum(data, segmentIds); return updateVariableNameAndReference(ret, name); } @@ -2283,6 +2337,7 @@ public SDVariable sequenceMask(SDVariable lengths, int maxLen, DataType dataType * @see #sequenceMask(String, SDVariable, SDVariable, DataType) */ public SDVariable sequenceMask(String name, SDVariable lengths, int maxLen, DataType dataType) { + validateInteger("sequenceMask", "lengths", lengths); SDVariable ret = f().sequenceMask(lengths, maxLen, dataType); return updateVariableNameAndReference(ret, name); } @@ -2319,6 +2374,7 @@ public SDVariable sequenceMask(SDVariable lengths, SDVariable maxLen, DataType d * @return Output variable */ public SDVariable sequenceMask(String name, SDVariable lengths, SDVariable maxLen, DataType dataType) { + validateInteger("sequenceMask", "lengths", lengths); SDVariable ret = f().sequenceMask(lengths, maxLen, dataType); return updateVariableNameAndReference(ret, name); } @@ -2428,6 +2484,7 @@ public SDVariable squaredNorm(SDVariable x, int... dimensions) { * Squared L2 norm: see {@link #norm2(String, SDVariable, boolean, int...)} */ public SDVariable squaredNorm(String name, SDVariable x, boolean keepDims, int... dimensions) { + validateNumerical("squaredNorm", x); SDVariable result = f().squaredNorm(x, keepDims, dimensions); return updateVariableNameAndReference(result, name); } @@ -2489,6 +2546,7 @@ public SDVariable stack(int axis, SDVariable... values) { * @see #unstack(String[], SDVariable, int, int) */ public SDVariable stack(String name, int axis, SDVariable... values) { + validateSameType("stack", false, values); SDVariable ret = f().stack(values, axis); return updateVariableNameAndReference(ret, name); } @@ -2529,6 +2587,7 @@ public SDVariable standardDeviation(String name, SDVariable x, boolean biasCorre * @return Output variable: reduced array of rank (input rank - num dimensions) */ public SDVariable standardDeviation(String name, SDVariable x, boolean biasCorrected, boolean keepDims, int... dimensions) { + validateNumerical("standard deviation", x); SDVariable result = f().std(x, biasCorrected, keepDims, dimensions); return updateVariableNameAndReference(result, name); } @@ -2672,6 +2731,7 @@ public SDVariable sum(String name, SDVariable x, int... dimensions) { * of rank (input rank) if keepdims = true */ public SDVariable sum(String name, SDVariable x, boolean keepDims, int... 
dimensions) { + validateNumerical("sum reduction", x); SDVariable result = f().sum(x, keepDims, dimensions); return updateVariableNameAndReference(result, name); } @@ -2705,6 +2765,7 @@ public SDVariable tensorMmul(String name, SDVariable x, SDVariable y, int[][] dimensions) { + validateNumerical("tensorMmul", x, y); SDVariable result = f().tensorMmul(x, y, dimensions); return updateVariableNameAndReference(result, name); } @@ -2782,6 +2843,8 @@ public SDVariable unsortedSegmentMax(SDVariable data, SDVariable segmentIds, int * @return Unsorted segment max output */ public SDVariable unsortedSegmentMax(String name, SDVariable data, SDVariable segmentIds, int numSegments) { + validateNumerical("unsortedSegmentMax", "data", data); + validateInteger("unsortedSegmentMax", "segmentIds", segmentIds); SDVariable ret = f().unsortedSegmentMax(data, segmentIds, numSegments); return updateVariableNameAndReference(ret, name); } @@ -2807,6 +2870,8 @@ public SDVariable unsortedSegmentMean(SDVariable data, SDVariable segmentIds, in * @return Unsorted segment mean output */ public SDVariable unsortedSegmentMean(String name, SDVariable data, SDVariable segmentIds, int numSegments) { + validateNumerical("unsortedSegmentMean", "data", data); + validateInteger("unsortedSegmentMean", "segmentIds", segmentIds); SDVariable ret = f().unsortedSegmentMean(data, segmentIds, numSegments); return updateVariableNameAndReference(ret, name); } @@ -2832,6 +2897,8 @@ public SDVariable unsortedSegmentMin(SDVariable data, SDVariable segmentIds, int * @return Unsorted segment min output */ public SDVariable unsortedSegmentMin(String name, SDVariable data, SDVariable segmentIds, int numSegments) { + validateNumerical("unsortedSegmentMin", "data", data); + validateInteger("unsortedSegmentMin", "segmentIds", segmentIds); SDVariable ret = f().unsortedSegmentMin(data, segmentIds, numSegments); return updateVariableNameAndReference(ret, name); } @@ -2856,6 +2923,8 @@ public SDVariable unsortedSegmentProd(SDVariable data, SDVariable segmentIds, in * @return Unsorted segment product output */ public SDVariable unsortedSegmentProd(String name, SDVariable data, SDVariable segmentIds, int numSegments) { + validateNumerical("unsortedSegmentProd", "data", data); + validateInteger("unsortedSegmentProd", "segmentIds", segmentIds); SDVariable ret = f().unsortedSegmentProd(data, segmentIds, numSegments); return updateVariableNameAndReference(ret, name); } @@ -2904,6 +2973,8 @@ public SDVariable unsortedSegmentSum(@NonNull SDVariable data, @NonNull SDVariab * @return Unsorted segment sum output */ public SDVariable unsortedSegmentSum(String name, @NonNull SDVariable data, @NonNull SDVariable segmentIds, int numSegments) { + validateNumerical("unsortedSegmentSum", "data", data); + validateInteger("unsortedSegmentSum", "segmentIds", segmentIds); SDVariable ret = f().unsortedSegmentSum(data, segmentIds, numSegments); return updateVariableNameAndReference(ret, name); } @@ -2986,6 +3057,7 @@ public SDVariable variance(String name, @NonNull SDVariable x, boolean biasCorre * @return Output variable: reduced array of rank (input rank - num dimensions) */ public SDVariable variance(String name, @NonNull SDVariable x, boolean biasCorrected, boolean keepDims, int... 
dimensions) { + validateNumerical("variance", x); SDVariable result = f().variance(x, biasCorrected, keepDims, dimensions); return updateVariableNameAndReference(result, name); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java index bc03373a8b4e..e8c2f04420b2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java @@ -4,6 +4,9 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*; +import static org.nd4j.autodiff.samediff.ops.SDValidation.validateFloatingPoint; +import static org.nd4j.autodiff.samediff.ops.SDValidation.validateNumerical; + /** * SameDiff Convolutional Neural Network operations - CNN1d, 2d and 3d ops - as well as related functions.
* Accessible via {@link SameDiff#cnn()}
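The statically imported SDValidation helpers used throughout these hunks (validateNumerical, validateInteger, validateFloatingPoint, validateSameType, and validateBool further below) are not part of this diff. Conceptually they are thin data-type guards; the following is a minimal sketch only, assuming they throw an unchecked exception on a mismatch, and is not the actual SDValidation source:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.linalg.api.buffer.DataType;

    // Sketch of the guard pattern the new op creators rely on: inspect the variable's
    // declared data type up front and fail fast with a descriptive message, instead of
    // letting the underlying op fail later with a less helpful error.
    class SDValidationSketch {
        static void validateFloatingPoint(String opName, SDVariable v) {
            if (v == null)
                return;                                // optional inputs such as bias may be null
            DataType dt = v.dataType();
            if (!dt.isFPType())
                throw new IllegalStateException("Cannot perform op \"" + opName + "\" on variable \""
                        + v.getVarName() + "\" with non-floating-point data type " + dt);
        }
    }

Most of the hunks follow the same shape: the overload that takes a variable name validates its inputs first, then delegates to the op factory (f()) exactly as before.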
@@ -40,6 +43,7 @@ public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig * @return Result after applying average pooling on the input */ public SDVariable avgPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) { + validateFloatingPoint("avgPooling2d", input); SDVariable ret = f().avgPooling2d(input, pooling2DConfig); return updateVariableNameAndReference(ret, name); } @@ -68,6 +72,7 @@ public SDVariable avgPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig * @return Result after applying average pooling on the input */ public SDVariable avgPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) { + validateFloatingPoint("avgPooling3d", input); SDVariable ret = f().avgPooling3d(input, pooling3DConfig); return updateVariableNameAndReference(ret, name); } @@ -91,6 +96,7 @@ public SDVariable batchToSpace(SDVariable x, int[] blocks, int[][] crops) { * @see #spaceToBatch(String, SDVariable, int[], int[][]) */ public SDVariable batchToSpace(String name, SDVariable x, int[] blocks, int[][] crops) { + validateNumerical("batchToSpace", x); SDVariable ret = f().batchToSpace(x, blocks, crops); return updateVariableNameAndReference(ret, name); } @@ -144,6 +150,8 @@ public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv * @return */ public SDVariable conv1d(String name, SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) { + validateFloatingPoint("conv1d", input); + validateFloatingPoint("conv1d", weights); SDVariable ret = f().conv1d(input, weights, conv1DConfig); return updateVariableNameAndReference(ret, name); } @@ -172,6 +180,9 @@ public SDVariable conv2d(SDVariable layerInput, SDVariable weights, Conv2DConfig * @return result of conv2d op */ public SDVariable conv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, Conv2DConfig config) { + validateFloatingPoint("conv2d", "input", layerInput); + validateFloatingPoint("conv2d", "weights", weights); + validateFloatingPoint("conv2d", "bias", bias); SDVariable[] arr = new SDVariable[bias == null ? 2 : 3]; arr[0] = layerInput; arr[1] = weights; @@ -202,6 +213,8 @@ public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig config) { * @return result of convolution 2d operation */ public SDVariable conv2d(String name, SDVariable[] inputs, Conv2DConfig config) { + for(SDVariable v : inputs) + validateNumerical("conv2d", v); SDVariable ret = f().conv2d(inputs, config); return updateVariableNameAndReference(ret, name); } @@ -233,6 +246,9 @@ public SDVariable conv3d(SDVariable input, SDVariable weights, Conv3DConfig conv * @return Conv3d output variable */ public SDVariable conv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, Conv3DConfig conv3DConfig) { + validateFloatingPoint("conv3d", "input", input); + validateFloatingPoint("conv3d", "weights", weights); + validateFloatingPoint("conv3d", "bias", bias); SDVariable[] args; if (bias == null) { args = new SDVariable[]{input, weights}; @@ -297,6 +313,9 @@ public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, DeConv2DCo * @return result of deconv2d op */ public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, DeConv2DConfig deconv2DConfig) { + validateFloatingPoint("deconv2d", "input", layerInput); + validateFloatingPoint("deconv2d", "weights", weights); + validateFloatingPoint("deconv2d", "bias", bias); SDVariable[] arr = new SDVariable[bias == null ? 
2 : 3]; arr[0] = layerInput; arr[1] = weights; @@ -327,6 +346,8 @@ public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) { * @return result of deconv2d op */ public SDVariable deconv2d(String name, SDVariable[] inputs, DeConv2DConfig deconv2DConfig) { + for(SDVariable v : inputs) + validateNumerical("deconv2d", v); SDVariable ret = f().deconv2d(inputs, deconv2DConfig); return updateVariableNameAndReference(ret, name); } @@ -341,6 +362,9 @@ public SDVariable deconv2d(String name, SDVariable[] inputs, DeConv2DConfig deco * @param config Configuration */ public SDVariable deconv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) { + validateFloatingPoint("conv3d", input); + validateFloatingPoint("conv3d", weights); + validateFloatingPoint("conv3d", bias); SDVariable ret = f().deconv3d(input, weights, bias, config); return updateVariableNameAndReference(ret, name); } @@ -404,6 +428,9 @@ public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights * @return result of depthwise conv2d op */ public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable bias, Conv2DConfig config) { + validateFloatingPoint("depthwiseConv2d", "input", layerInput); + validateFloatingPoint("depthwiseConv2d", "depth weights", depthWeights); + validateFloatingPoint("depthwiseConv2d", "bias", bias); SDVariable[] arr = new SDVariable[bias == null ? 2 : 3]; arr[0] = layerInput; arr[1] = depthWeights; @@ -436,6 +463,8 @@ public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DC * @return result of depthwise conv2d op */ public SDVariable depthWiseConv2d(String name, SDVariable[] inputs, Conv2DConfig depthConv2DConfig) { + for(SDVariable v : inputs) + validateFloatingPoint("depthWiseConv2d", v); SDVariable ret = f().depthWiseConv2d(inputs, depthConv2DConfig); return updateVariableNameAndReference(ret, name); } @@ -541,6 +570,7 @@ public SDVariable localResponseNormalization(SDVariable inputs, LocalResponseNor */ public SDVariable localResponseNormalization(String name, SDVariable input, LocalResponseNormalizationConfig lrnConfig) { + validateFloatingPoint("local response normalization", input); SDVariable ret = f().localResponseNormalization(input, lrnConfig); return updateVariableNameAndReference(ret, name); } @@ -568,6 +598,7 @@ public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig * @return Result after applying max pooling on the input */ public SDVariable maxPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) { + validateNumerical("maxPooling2d", input); SDVariable ret = f().maxPooling2d(input, pooling2DConfig); return updateVariableNameAndReference(ret, name); } @@ -596,6 +627,7 @@ public SDVariable maxPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig * @return Result after applying max pooling on the input */ public SDVariable maxPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) { + validateNumerical("maxPooling3d", input); SDVariable ret = f().maxPooling3d(input, pooling3DConfig); return updateVariableNameAndReference(ret, name); } @@ -631,6 +663,10 @@ public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights */ public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable pointWeights, SDVariable bias, Conv2DConfig config) { + validateFloatingPoint("separableConv2d", "input", layerInput); + validateFloatingPoint("separableConv2d", 
"depthWeights", depthWeights); + validateFloatingPoint("separableConv2d", "pointWeights", pointWeights); + validateFloatingPoint("separableConv2d", "bias", bias); SDVariable[] arr = new SDVariable[bias == null ? 3 : 4]; arr[0] = layerInput; arr[1] = depthWeights; @@ -662,6 +698,8 @@ public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) { * @return result of separable convolution 2d operation */ public SDVariable sconv2d(String name, SDVariable[] inputs, Conv2DConfig conv2DConfig) { + for(SDVariable v : inputs) + validateFloatingPoint("sconv2d", v); SDVariable ret = f().sconv2d(inputs, conv2DConfig); return updateVariableNameAndReference(ret, name); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java index 81409c78e3c2..507992c9de27 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java @@ -9,6 +9,8 @@ import org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss; import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss; +import static org.nd4j.autodiff.samediff.ops.SDValidation.*; + /** * SameDiff loss functions
* Accessible via {@link SameDiff#loss()} @@ -39,6 +41,8 @@ public SDVariable absoluteDifference(String name, @NonNull SDVariable label, @No */ public SDVariable absoluteDifference(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce) { + validateFloatingPoint("absolute difference loss", "predictions", predictions); + validateNumerical("absolute difference loss", "labels", label); if (weights == null) weights = sd.scalar(null, predictions.dataType(), 1.0); SDVariable result = f().lossAbsoluteDifference(label, predictions, weights, lossReduce); @@ -78,6 +82,8 @@ public SDVariable cosineDistance(String name, @NonNull SDVariable label, @NonNul */ public SDVariable cosineDistance(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce, int dimension) { + validateFloatingPoint("cosine distance loss", "predictions", predictions); + validateNumerical("cosine distance loss", "labels", label); if (weights == null) weights = sd.scalar(null, predictions.dataType(), 1.0); SDVariable result = f().lossCosineDistance(label, predictions, weights, lossReduce, dimension); @@ -115,6 +121,8 @@ public SDVariable hingeLoss(String name, @NonNull SDVariable label, @NonNull SDV */ public SDVariable hingeLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce) { + validateFloatingPoint("hinge loss", "predictions", predictions); + validateNumerical("hinge loss", "labels", label); if (weights == null) weights = sd.scalar(null, predictions.dataType(), 1.0); SDVariable result = f().lossHinge(label, predictions, weights, lossReduce); @@ -157,6 +165,8 @@ public SDVariable huberLoss(String name, @NonNull SDVariable label, @NonNull SDV */ public SDVariable huberLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce, double delta) { + validateFloatingPoint("huber loss", "predictions", predictions); + validateNumerical("huber loss", "labels", label); if (weights == null) weights = sd.scalar(null, predictions.dataType(), 1.0); SDVariable result = f().lossHuber(label, predictions, weights, lossReduce, delta); @@ -190,6 +200,7 @@ public SDVariable l2Loss(@NonNull SDVariable var) { * @return L2 loss */ public SDVariable l2Loss(String name, @NonNull SDVariable var) { + validateNumerical("l2 loss", var); SDVariable result = f().lossL2(var); result = updateVariableNameAndReference(result, name); result.markAsLoss(); @@ -216,6 +227,8 @@ public SDVariable logLoss(String name, @NonNull SDVariable label, @NonNull SDVar */ public SDVariable logLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce, double epsilon) { + validateFloatingPoint("log loss", "predictions", predictions); + validateNumerical("log loss", "labels", label); if (weights == null) weights = sd.scalar(null, predictions.dataType(), 1.0); SDVariable result = f().lossLog(label, predictions, weights, lossReduce, epsilon); @@ -251,6 +264,8 @@ public SDVariable logPoisson(String name, @NonNull SDVariable label, @NonNull SD */ public SDVariable logPoisson(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce) { + validateFloatingPoint("log poisson loss", "predictions", predictions); + validateNumerical("log poisson loss", "labels", label); if (weights == 
null) weights = sd.scalar(null, predictions.dataType(), 1.0); SDVariable result = f().lossLogPoisson(label, predictions, weights, lossReduce); @@ -287,6 +302,8 @@ public SDVariable logPoissonFull(String name, @NonNull SDVariable label, @NonNul */ public SDVariable logPoissonFull(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce) { + validateFloatingPoint("log poisson (full) loss", "predictions", predictions); + validateNumerical("log poisson (full) loss", "labels", label); if (weights == null) weights = sd.scalar(null, predictions.dataType(), 1.0); SDVariable result = f().lossLogPoissonFull(label, predictions, weights, lossReduce); @@ -322,6 +339,8 @@ public SDVariable meanPairwiseSquaredError(String name, @NonNull SDVariable labe * @return Loss variable, scalar output */ public SDVariable meanPairwiseSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce) { + validateFloatingPoint("main pairwise squared error loss", "predictions", predictions); + validateNumerical("mean pairwise squared error loss", "labels", label); if (weights == null) weights = sd.scalar(null, predictions.dataType(), 1.0); SDVariable result = f().lossMeanPairwiseSquaredError(label, predictions, weights, lossReduce); @@ -351,6 +370,8 @@ public SDVariable meanSquaredError(String name, @NonNull SDVariable label, @NonN */ public SDVariable meanSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce) { + validateFloatingPoint("mean squared error loss", "predictions", predictions); + validateNumerical("mean squared error loss", "labels", label); if (weights == null) weights = sd.scalar(null, predictions.dataType(), 1.0); SDVariable result = f().lossMeanSquaredError(label, predictions, weights, lossReduce); @@ -396,6 +417,8 @@ public SDVariable sigmoidCrossEntropy(String name, @NonNull SDVariable label, @N */ public SDVariable sigmoidCrossEntropy(String name, @NonNull SDVariable label, @NonNull SDVariable predictionLogits, SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) { + validateFloatingPoint("sigmoid cross entropy loss", "predictions", predictionLogits); + validateNumerical("sigmoid cross entropy loss", "labels", label); if (weights == null) weights = sd.scalar(null, predictionLogits.dataType(), 1.0); SDVariable result = f().lossSigmoidCrossEntropy(label, predictionLogits, weights, lossReduce, labelSmoothing); @@ -440,6 +463,8 @@ public SDVariable softmaxCrossEntropy(String name, @NonNull SDVariable label, @N */ public SDVariable softmaxCrossEntropy(String name, @NonNull SDVariable oneHotLabels, @NonNull SDVariable logitPredictions, SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) { + validateFloatingPoint("softmax cross entropy loss", "predictions", logitPredictions); + validateNumerical("softmax cross entropy loss", "oneHotLabels", oneHotLabels); if (weights == null) weights = sd.scalar(null, logitPredictions.dataType(), 1.0); SDVariable result = f().lossSoftmaxCrossEntropy(oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing); @@ -473,6 +498,8 @@ public SDVariable sparseSoftmaxCrossEntropy(@NonNull SDVariable logits, @NonNull * @return Softmax cross entropy */ public SDVariable sparseSoftmaxCrossEntropy(String name, @NonNull SDVariable logits, @NonNull SDVariable labels) { + validateFloatingPoint("sparse softmax cross 
entropy", "logits (predictions)", logits); + validateInteger("sparse softmax cross entropy", "labels", labels); Preconditions.checkState(labels.dataType().isIntType(), "Labels variable must be an integer type: got %s", logits); SDVariable result = f().lossSparseSoftmaxCrossEntropy(logits, labels); result = updateVariableNameAndReference(result, name); @@ -504,6 +531,8 @@ public SDVariable weightedCrossEntropyWithLogits(SDVariable targets, SDVariable */ public SDVariable weightedCrossEntropyWithLogits(String name, SDVariable targets, SDVariable inputs, SDVariable weights) { + validateFloatingPoint("weighted cross entropy with logits", "inputs", inputs); + validateNumerical("weighted cross entropy with logits", "targets", targets); SDVariable result = f().weightedCrossEntropyWithLogits(targets, inputs, weights); result = updateVariableNameAndReference(result, name); result.markAsLoss(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java index 09cf15102b45..3b4f2c83fe46 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java @@ -13,6 +13,8 @@ import java.util.List; +import static org.nd4j.autodiff.samediff.ops.SDValidation.*; + /** * SameDiff math operations
* Accessible via {@link SameDiff#math()} @@ -42,6 +44,7 @@ public SDVariable abs(SDVariable x) { * @return Output variable */ public SDVariable abs(String name, SDVariable x) { + validateNumerical("abs", x); SDVariable result = f().abs(x); return updateVariableNameAndReference(result, name); } @@ -64,6 +67,7 @@ public SDVariable acos(SDVariable x) { * @return Output variable */ public SDVariable acos(String name, SDVariable x) { + validateNumerical("acos", x); SDVariable result = f().acos(x); return updateVariableNameAndReference(result, name); } @@ -86,6 +90,7 @@ public SDVariable acosh(SDVariable x) { * @return Output variable */ public SDVariable acosh(String name, SDVariable x) { + validateNumerical("acosh", x); SDVariable result = f().acosh(x); return updateVariableNameAndReference(result, name); } @@ -110,6 +115,7 @@ public SDVariable amax(SDVariable in, int... dimensions) { * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable amax(String name, SDVariable in, int... dimensions) { + validateNumerical("amax", in); SDVariable ret = f().amax(in, dimensions); return updateVariableNameAndReference(ret, name); } @@ -122,6 +128,7 @@ public SDVariable amax(String name, SDVariable in, int... dimensions) { * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable amean(SDVariable in, int... dimensions) { + validateNumerical("amean", in); return amean(null, in, dimensions); } @@ -134,6 +141,7 @@ public SDVariable amean(SDVariable in, int... dimensions) { * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable amean(String name, SDVariable in, int... dimensions) { + validateNumerical("amean", in); SDVariable ret = f().amean(in, dimensions); return updateVariableNameAndReference(ret, name); } @@ -158,6 +166,7 @@ public SDVariable amin(SDVariable in, int... dimensions) { * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable amin(String name, SDVariable in, int... dimensions) { + validateNumerical("amin", in); SDVariable ret = f().amin(in, dimensions); return updateVariableNameAndReference(ret, name); } @@ -188,6 +197,7 @@ public SDVariable and(SDVariable x, SDVariable y) { * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable and(String name, SDVariable x, SDVariable y) { + validateBool("boolean and", x, y); SDVariable result = f().and(x, y); return updateVariableNameAndReference(result, name); } @@ -210,6 +220,7 @@ public SDVariable asin(SDVariable x) { * @return Output variable */ public SDVariable asin(String name, SDVariable x) { + validateNumerical("asin", x); SDVariable result = f().asin(x); return updateVariableNameAndReference(result, name); } @@ -232,6 +243,7 @@ public SDVariable asinh(SDVariable x) { * @return Output variable */ public SDVariable asinh(String name, SDVariable x) { + validateNumerical("asinh", x); SDVariable result = f().asinh(x); return updateVariableNameAndReference(result, name); } @@ -256,6 +268,7 @@ public SDVariable asum(SDVariable in, int... dimensions) { * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable asum(String name, SDVariable in, int... 
dimensions) { + validateNumerical("asum", in); SDVariable ret = f().asum(in, dimensions); return updateVariableNameAndReference(ret, name); } @@ -278,6 +291,7 @@ public SDVariable atan(SDVariable x) { * @return Output variable */ public SDVariable atan(String name, SDVariable x) { + validateNumerical("atan", x); SDVariable result = f().atan(x); return updateVariableNameAndReference(result, name); } @@ -304,6 +318,7 @@ public SDVariable atan2(SDVariable y, SDVariable x) { * @return Output variable */ public SDVariable atan2(String name, SDVariable y, SDVariable x) { + validateNumerical("atan2", y, x); SDVariable ret = f().atan2(y, x); return updateVariableNameAndReference(ret, name); } @@ -326,6 +341,7 @@ public SDVariable atanh(SDVariable x) { * @return Output variable */ public SDVariable atanh(String name, SDVariable x) { + validateNumerical("atanh", x); SDVariable result = f().atanh(x); return updateVariableNameAndReference(result, name); } @@ -350,6 +366,7 @@ public SDVariable ceil(SDVariable x) { * @return Output variable */ public SDVariable ceil(String name, SDVariable x) { + validateFloatingPoint("ceil", x); SDVariable ret = f().ceil(x); return updateVariableNameAndReference(ret, name); } @@ -378,6 +395,7 @@ public SDVariable clipByNorm(SDVariable x, double clipValue) { * @return Output variable */ public SDVariable clipByNorm(String name, SDVariable x, double clipValue) { + validateFloatingPoint("clip by norm", x); SDVariable ret = f().clipByNorm(x, clipValue); return updateVariableNameAndReference(ret, name); } @@ -410,6 +428,7 @@ public SDVariable clipByNorm(SDVariable x, double clipValue, int... dimensions) * @return Output variable */ public SDVariable clipByNorm(String name, SDVariable x, double clipValue, int... dimensions) { + validateFloatingPoint("clip by norm", x); SDVariable ret = f().clipByNorm(x, clipValue, dimensions); return updateVariableNameAndReference(ret, name); } @@ -442,6 +461,7 @@ public SDVariable clipByValue(SDVariable x, double clipValueMin, double clipValu * @return Output variable */ public SDVariable clipByValue(String name, SDVariable x, double clipValueMin, double clipValueMax) { + validateNumerical("clip by value", x); SDVariable ret = f().clipByValue(x, clipValueMin, clipValueMax); return updateVariableNameAndReference(ret, name); } @@ -471,6 +491,8 @@ public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pre * @return Output variable (2D, shape [numClasses, numClasses}) */ public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, DataType dataType) { + validateInteger("confusionMatrix", "labels", labels); + validateInteger("confusionMatrix", "prediction", pred); SDVariable result = f().confusionMatrix(labels, pred, dataType); return updateVariableNameAndReference(result, name); } @@ -498,6 +520,8 @@ public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer nu * @return Output variable (2D, shape [numClasses, numClasses}) */ public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, Integer numClasses) { + validateInteger("confusionMatrix", "labels", labels); + validateInteger("confusionMatrix", "prediction", pred); SDVariable result = f().confusionMatrix(labels, pred, numClasses); return updateVariableNameAndReference(result, name); } @@ -525,6 +549,9 @@ public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, SDVariable * @return Output variable (2D, shape [numClasses, numClasses}) */ public SDVariable confusionMatrix(String name, SDVariable 
labels, SDVariable pred, SDVariable weights) { + validateInteger("confusionMatrix", "labels", labels); + validateInteger("confusionMatrix", "prediction", pred); + validateNumerical("confusionMatrix", "weights", weights); SDVariable result = f().confusionMatrix(labels, pred, weights); return updateVariableNameAndReference(result, name); } @@ -553,6 +580,9 @@ public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer nu * @return Output variable (2D, shape [numClasses, numClasses}) */ public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights) { + validateInteger("confusionMatrix", "labels", labels); + validateInteger("confusionMatrix", "prediction", pred); + validateNumerical("confusionMatrix", "weights", weights); SDVariable result = f().confusionMatrix(labels, pred, numClasses, weights); return updateVariableNameAndReference(result, name); } @@ -575,6 +605,7 @@ public SDVariable cos(SDVariable x) { * @return Output variable */ public SDVariable cos(String name, SDVariable x) { + validateNumerical("cos", x); SDVariable result = f().cos(x); return updateVariableNameAndReference(result, name); } @@ -597,6 +628,7 @@ public SDVariable cosh(SDVariable x) { * @return Output variable */ public SDVariable cosh(String name, SDVariable x) { + validateNumerical("cosh", x); SDVariable result = f().cosh(x); return updateVariableNameAndReference(result, name); } @@ -621,6 +653,7 @@ public SDVariable cosineDistance(SDVariable x, SDVariable y, int... dimensions) * @return Output variable */ public SDVariable cosineDistance(String name, SDVariable x, SDVariable y, int... dimensions) { + validateNumerical("cosine distance", x, y); SDVariable result = f().cosineDistance(x, y, dimensions); return updateVariableNameAndReference(result, name); } @@ -643,6 +676,7 @@ public SDVariable cosineSimilarity(SDVariable x, SDVariable y, int... dimensions * @return Output variable */ public SDVariable cosineSimilarity(String name, SDVariable x, SDVariable y, int... dimensions) { + validateNumerical("cosine similarity", x, y); SDVariable cosim = f().cosineSimilarity(x, y, dimensions); return updateVariableNameAndReference(cosim, name); } @@ -667,6 +701,7 @@ public SDVariable countNonZero(SDVariable input, int... dimensions) { * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable countNonZero(String name, SDVariable input, int... dimensions) { + validateNumerical("countNonZero", input); SDVariable res = f().countNonZero(input, dimensions); return updateVariableNameAndReference(res, name); } @@ -691,6 +726,7 @@ public SDVariable countZero(SDVariable input, int... dimensions) { * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable countZero(String name, SDVariable input, int... 
dimensions) { + validateNumerical("countNonZero", input); SDVariable res = f().countZero(input, dimensions); return updateVariableNameAndReference(res, name); } @@ -711,6 +747,7 @@ public SDVariable cross(SDVariable a, SDVariable b) { * @return Element-wise cross product */ public SDVariable cross(String name, SDVariable a, SDVariable b) { + validateNumerical("cross", a, b); SDVariable ret = f().cross(a, b); return updateVariableNameAndReference(ret, name); } @@ -733,6 +770,7 @@ public SDVariable cube(SDVariable x) { * @return Output variable */ public SDVariable cube(String name, SDVariable x) { + validateNumerical("cube", x); SDVariable result = f().cube(x); return updateVariableNameAndReference(result, name); } @@ -808,6 +846,7 @@ public SDVariable entropy(SDVariable in, int... dimensions) { * @return Output variable: reduced array of rank (input rank - num dimensions) */ public SDVariable entropy(String name, SDVariable in, int... dimensions) { + validateNumerical("entropy reduction", in); SDVariable ret = f().entropy(in, dimensions); return updateVariableNameAndReference(ret, name); } @@ -830,6 +869,7 @@ public SDVariable erf(SDVariable x) { * @return Output variable */ public SDVariable erf(String name, SDVariable x) { + validateNumerical("erf (error function)", x); SDVariable ret = f().erf(x); return updateVariableNameAndReference(ret, name); } @@ -852,6 +892,7 @@ public SDVariable erfc(SDVariable x) { * @return Output variable */ public SDVariable erfc(String name, SDVariable x) { + validateNumerical("erfc", x); SDVariable ret = f().erfc(x); return updateVariableNameAndReference(ret, name); } @@ -874,6 +915,7 @@ public SDVariable euclideanDistance(SDVariable x, SDVariable y, int... dimension * @return Output variable */ public SDVariable euclideanDistance(String name, SDVariable x, SDVariable y, int... dimensions) { + validateNumerical("euclidean distance", x, y); SDVariable result = f().euclideanDistance(x, y, dimensions); return updateVariableNameAndReference(result, name); } @@ -896,6 +938,7 @@ public SDVariable exp(SDVariable x) { * @return Output variable */ public SDVariable exp(String name, SDVariable x) { + validateNumerical("exp", x); SDVariable result = f().exp(x); return updateVariableNameAndReference(result, name); } @@ -918,6 +961,7 @@ public SDVariable expm1(SDVariable x) { * @return Output variable */ public SDVariable expm1(String name, SDVariable x) { + validateNumerical("expm1", x); SDVariable result = f().expm1(x); return updateVariableNameAndReference(result, name); } @@ -1122,6 +1166,7 @@ public SDVariable floor(SDVariable x) { * @return Output variable */ public SDVariable floor(String name, SDVariable x) { + validateFloatingPoint("floor", x); SDVariable result = f().floor(x); return updateVariableNameAndReference(result, name); } @@ -1145,6 +1190,7 @@ public SDVariable hammingDistance(SDVariable x, SDVariable y, int... dimensions) * @return Output variable */ public SDVariable hammingDistance(String name, SDVariable x, SDVariable y, int... dimensions) { + validateNumerical("hamming distance reduction", x, y); SDVariable result = f().hammingDistance(x, y, dimensions); return updateVariableNameAndReference(result, name); } @@ -1173,6 +1219,7 @@ public SDVariable iamax(String name, SDVariable in, int... dimensions) { * @see SameDiff#argmax(String, SDVariable, boolean, int...) */ public SDVariable iamax(String name, SDVariable in, boolean keepDims, int... 
dimensions) { + validateNumerical("iamax", in); SDVariable ret = f().iamax(in, keepDims, dimensions); return updateVariableNameAndReference(ret, name); } @@ -1210,6 +1257,7 @@ public SDVariable iamin(String name, SDVariable in, int... dimensions) { * @see SameDiff#argmin(String, SDVariable, boolean, int...) */ public SDVariable iamin(String name, SDVariable in, boolean keepDims, int... dimensions) { + validateNumerical("iamin", in); SDVariable ret = f().iamin(in, keepDims, dimensions); return updateVariableNameAndReference(ret, name); } @@ -1245,6 +1293,7 @@ public SDVariable isFinite(SDVariable x) { * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable isFinite(String name, SDVariable x) { + validateFloatingPoint("isFinite", x); SDVariable result = f().isFinite(x); return updateVariableNameAndReference(result, name); } @@ -1271,6 +1320,7 @@ public SDVariable isInfinite(SDVariable x) { * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable isInfinite(String name, SDVariable x) { + validateFloatingPoint("isInfinite", x); SDVariable result = f().isInfinite(x); return updateVariableNameAndReference(result, name); } @@ -1297,6 +1347,7 @@ public SDVariable isMax(SDVariable x) { * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable isMax(String name, SDVariable x) { + validateNumerical("isMax", x); SDVariable ret = f().isMax(x); return updateVariableNameAndReference(ret, name); } @@ -1323,6 +1374,7 @@ public SDVariable isNaN(SDVariable x) { * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable isNaN(String name, SDVariable x) { + validateFloatingPoint("isNaN", x); SDVariable result = f().isNaN(x); return updateVariableNameAndReference(result, name); } @@ -1349,6 +1401,7 @@ public SDVariable isNonDecreasing(SDVariable x) { * @return Scalar variable with value 1 if non-decreasing, or 0 otherwise */ public SDVariable isNonDecreasing(String name, SDVariable x) { + validateNumerical("isNonDecreasing", x); SDVariable result = f().isNonDecreasing(x); return updateVariableNameAndReference(result, name); } @@ -1376,6 +1429,7 @@ public SDVariable isStrictlyIncreasing(SDVariable x) { * @return Scalar variable with value 1 if strictly increasing, or 0 otherwise */ public SDVariable isStrictlyIncreasing(String name, SDVariable x) { + validateNumerical("isStrictlyIncreasing", x); SDVariable result = f().isStrictlyIncreasing(x); return updateVariableNameAndReference(result, name); } @@ -1404,6 +1458,7 @@ public SDVariable jaccardDistance(SDVariable x, SDVariable y, int... dimensions) * @return Output variable */ public SDVariable jaccardDistance(String name, SDVariable x, SDVariable y, int... 
dimensions) { + validateNumerical("Jaccard distance reduction", x, y); SDVariable result = f().jaccardDistance(x, y, dimensions); return updateVariableNameAndReference(result, name); } @@ -1477,6 +1532,7 @@ public SDVariable log(SDVariable x) { * @return Output variable */ public SDVariable log(String name, SDVariable x) { + validateNumerical("log", x); SDVariable result = f().log(x); return updateVariableNameAndReference(result, name); } @@ -1501,6 +1557,7 @@ public SDVariable log(SDVariable in, double base) { * @return Output variable */ public SDVariable log(String name, SDVariable in, double base) { + validateNumerical("log", in); SDVariable ret = f().log(in, base); return updateVariableNameAndReference(ret, name); } @@ -1523,6 +1580,7 @@ public SDVariable log1p(SDVariable x) { * @return Output variable */ public SDVariable log1p(String name, SDVariable x) { + validateNumerical("log1p", x); SDVariable result = f().log1p(x); return updateVariableNameAndReference(result, name); } @@ -1547,6 +1605,7 @@ public SDVariable logEntropy(SDVariable in, int... dimensions) { * @return Output variable: reduced array of rank (input rank - num dimensions) */ public SDVariable logEntropy(String name, SDVariable in, int... dimensions) { + validateNumerical("logEntropy reduction", in); SDVariable ret = f().logEntropy(in, dimensions); return updateVariableNameAndReference(ret, name); } @@ -1573,6 +1632,7 @@ public SDVariable logSumExp(SDVariable input, int... dimensions) { * @return Output variable */ public SDVariable logSumExp(String name, SDVariable input, int... dimensions) { + validateNumerical("logSumExp reduction", input); SDVariable ret = f().logSumExp(input, dimensions); return updateVariableNameAndReference(ret, name); } @@ -1596,6 +1656,7 @@ public SDVariable manhattanDistance(SDVariable x, SDVariable y, int... dimension * @return Output variable */ public SDVariable manhattanDistance(String name, SDVariable x, SDVariable y, int... dimensions) { + validateNumerical("manhattan distance", x, y); SDVariable result = f().manhattanDistance(x, y, dimensions); return updateVariableNameAndReference(result, name); } @@ -1617,6 +1678,7 @@ public SDVariable matrixDeterminant(SDVariable in) { * @return Matrix determinant variable */ public SDVariable matrixDeterminant(String name, SDVariable in) { + validateNumerical("matrix determinant", in); SDVariable ret = f().matrixDeterminant(in); return updateVariableNameAndReference(ret, name); } @@ -1638,6 +1700,7 @@ public SDVariable matrixInverse(SDVariable in) { * @return Matrix inverse variable */ public SDVariable matrixInverse(String name, SDVariable in) { + validateFloatingPoint("matrix inverse", in); SDVariable ret = f().matrixInverse(in); return updateVariableNameAndReference(ret, name); } @@ -1662,6 +1725,7 @@ public SDVariable mergeAdd(SDVariable... x) { * @return Output variable */ public SDVariable mergeAdd(String name, SDVariable... inputs) { + validateSameType("mergeAdd", true, inputs); SDVariable ret = f().mergeAdd(inputs); return updateVariableNameAndReference(ret, name); } @@ -1686,6 +1750,7 @@ public SDVariable mergeAvg(SDVariable... inputs) { * @return Output variable */ public SDVariable mergeAvg(String name, SDVariable... inputs) { + validateSameType("mergeAvg", true, inputs); SDVariable ret = f().mergeAvg(inputs); return updateVariableNameAndReference(ret, name); } @@ -1709,6 +1774,7 @@ public SDVariable mergeMax(SDVariable... x) { * @return Output variable */ public SDVariable mergeMax(String name, SDVariable... 
inputs) { + validateSameType("mergeMax", true, inputs); SDVariable ret = f().mergeMax(inputs); return updateVariableNameAndReference(ret, name); } @@ -1754,6 +1820,7 @@ public SDVariable[] meshgrid(List names, SDVariable... inputs) { public SDVariable[] meshgrid(List names, boolean cartesian, SDVariable... inputs) { Preconditions.checkState(names == null || names.size() == inputs.length, "Got %s names but %s inputs", (names == null ? 0 : names.size()), inputs.length); + validateSameType("meshgrid", false, inputs); SDVariable[] ret = f().meshgrid(cartesian, inputs); for (int i = 0; i < ret.length; i++) { ret[i] = updateVariableNameAndReference(ret[i], names == null ? null : names.get(i)); @@ -1777,6 +1844,7 @@ public SDVariable[] moments(SDVariable input, int... axes) { * @return Mean and variance variables */ public SDVariable[] moments(String[] name, SDVariable input, int... axes) { + validateNumerical("moments", input); SDVariable[] res = f().moments(input, axes); return sd.updateVariableNamesAndReferences(res, name); } @@ -1799,6 +1867,7 @@ public SDVariable neg(SDVariable x) { * @return Output variable */ public SDVariable neg(String name, SDVariable x) { + validateNumerical("neg", x); SDVariable result = f().neg(x); return updateVariableNameAndReference(result, name); } @@ -1852,6 +1921,7 @@ public SDVariable or(SDVariable x, SDVariable y) { * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable or(String name, SDVariable x, SDVariable y) { + validateBool("or", x, y); SDVariable result = f().or(x, y); return updateVariableNameAndReference(result, name); } @@ -1876,6 +1946,7 @@ public SDVariable pow(SDVariable x, double value) { * @return Output variable */ public SDVariable pow(String name, SDVariable x, double value) { + validateNumerical("pow", x); SDVariable result = f().pow(x, value); return updateVariableNameAndReference(result, name); } @@ -1900,6 +1971,7 @@ public SDVariable pow(SDVariable x, SDVariable y) { * @return Output variable */ public SDVariable pow(String name, SDVariable x, SDVariable y) { + validateNumerical("pow", x, y); SDVariable result = f().pow(x, y); return updateVariableNameAndReference(result, name); } @@ -1922,6 +1994,7 @@ public SDVariable reciprocal(SDVariable a) { * @return Output variable */ public SDVariable reciprocal(String name, SDVariable a) { + validateNumerical("reciprocal", a); SDVariable ret = f().reciprocal(a); return updateVariableNameAndReference(ret, name); } @@ -1946,6 +2019,7 @@ public SDVariable round(SDVariable x) { * @return Output variable */ public SDVariable round(String name, SDVariable x) { + validateFloatingPoint("round", x); SDVariable result = f().round(x); return updateVariableNameAndReference(result, name); } @@ -1968,6 +2042,7 @@ public SDVariable rsqrt(SDVariable x) { * @return Output variable */ public SDVariable rsqrt(String name, SDVariable x) { + validateNumerical("rsqrt", x); SDVariable result = f().rsqrt(x); return updateVariableNameAndReference(result, name); } @@ -2020,6 +2095,7 @@ public SDVariable shannonEntropy(SDVariable in, int... dimensions) { * @return Output variable: reduced array of rank (input rank - num dimensions) */ public SDVariable shannonEntropy(String name, SDVariable in, int... 
dimensions) { + validateNumerical("shannon entropy reduction", in); SDVariable ret = f().shannonEntropy(in, dimensions); return updateVariableNameAndReference(ret, name); } @@ -2048,6 +2124,7 @@ public SDVariable sign(SDVariable x) { * @return Output variable */ public SDVariable sign(String name, SDVariable x) { + validateNumerical("sign", x); SDVariable result = f().sign(x); return updateVariableNameAndReference(result, name); } @@ -2070,6 +2147,7 @@ public SDVariable sin(SDVariable x) { * @return Output variable */ public SDVariable sin(String name, SDVariable x) { + validateNumerical("sin", x); SDVariable result = f().sin(x); return updateVariableNameAndReference(result, name); } @@ -2092,6 +2170,7 @@ public SDVariable sinh(SDVariable x) { * @return Output variable */ public SDVariable sinh(String name, SDVariable x) { + validateNumerical("sinh", x); SDVariable result = f().sinh(x); return updateVariableNameAndReference(result, name); } @@ -2114,6 +2193,7 @@ public SDVariable sqrt(SDVariable x) { * @return Output variable */ public SDVariable sqrt(String name, SDVariable x) { + validateNumerical("sqrt", x); SDVariable result = f().sqrt(x); return updateVariableNameAndReference(result, name); } @@ -2136,6 +2216,7 @@ public SDVariable square(SDVariable x) { * @return Output variable */ public SDVariable square(String name, SDVariable x) { + validateNumerical("square", x); SDVariable result = f().square(x); return updateVariableNameAndReference(result, name); } @@ -2164,6 +2245,7 @@ public SDVariable step(SDVariable in, double cutoff) { * @return Output variable */ public SDVariable step(String name, SDVariable in, double cutoff) { + validateNumerical("step", in); SDVariable ret = f().step(in, cutoff); return updateVariableNameAndReference(ret, name); } @@ -2200,6 +2282,7 @@ public SDVariable standardize(SDVariable x, int... dimensions) { * @return Output variable */ public SDVariable standardize(String name, SDVariable x, int... 
dimensions) { + validateNumerical("standardize", x); SDVariable result = f().standardize(x, dimensions); return updateVariableNameAndReference(result, name); } @@ -2222,6 +2305,7 @@ public SDVariable tan(SDVariable x) { * @return Output variable */ public SDVariable tan(String name, SDVariable x) { + validateNumerical("tan", x); SDVariable result = f().tan(x); return updateVariableNameAndReference(result, name); } @@ -2244,6 +2328,7 @@ public SDVariable tanh(SDVariable x) { * @return Output variable */ public SDVariable tanh(String name, SDVariable x) { + validateNumerical("tanh", x); SDVariable result = f().tanh(x); return updateVariableNameAndReference(result, name); } @@ -2265,6 +2350,7 @@ public SDVariable trace(SDVariable in) { * @return Trace */ public SDVariable trace(String name, SDVariable in) { + validateNumerical("trace", in); SDVariable ret = f().trace(in); return updateVariableNameAndReference(ret, name); } @@ -2295,6 +2381,7 @@ public SDVariable xor(SDVariable x, SDVariable y) { * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable xor(String name, SDVariable x, SDVariable y) { + validateBool("xor", x, y); SDVariable result = f().xor(x, y); return updateVariableNameAndReference(result, name); } @@ -2317,6 +2404,7 @@ public SDVariable zeroFraction(SDVariable input) { * @return Reduced array of rank 0 (scalar) */ public SDVariable zeroFraction(String name, SDVariable input) { + validateNumerical("zeroFraction", input); SDVariable res = f().zeroFraction(input); return updateVariableNameAndReference(res, name); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java index 3f1dadc1dac9..90c144b79f65 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java @@ -8,6 +8,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import static org.nd4j.autodiff.samediff.ops.SDValidation.validateFloatingPoint; /** * SameDiff general neural network operations
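The validate* calls threaded through SDMath above are intended to fail fast, before graph execution, when an op receives a variable of an unsupported datatype. The following is a minimal usage sketch of that behaviour, not part of this changeset; it assumes the variable-creation helper SameDiff.var(String, DataType, long...) together with the SDMath signatures shown above:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;

    public class ValidationSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            // FLOAT input: accepted by numerical ops such as tanh (validateNumerical passes)
            SDVariable x = sd.var("x", DataType.FLOAT, 3, 4);
            SDVariable t = sd.math().tanh("tanh", x);

            // xor is validated as a boolean op, so FLOAT inputs are rejected up front
            try {
                sd.math().xor("badXor", x, x);
            } catch (IllegalStateException e) {
                System.out.println("Rejected before execution: " + e.getMessage());
            }
        }
    }

The same pattern applies to the SDNN and SDRandom changes below: floating-point-only activations reject integer/boolean inputs, and the random ops require an integer-typed shape variable.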
@@ -40,6 +41,11 @@ public SDVariable batchNorm(SDVariable input, SDVariable mean, public SDVariable batchNorm(String name, SDVariable input, SDVariable mean, SDVariable variance, SDVariable gamma, SDVariable beta, boolean applyGamma, boolean applyBeta, double epsilon, int... axis) { + validateFloatingPoint("batchNorm", "input", input); + validateFloatingPoint("batchNorm", "mean", mean); + validateFloatingPoint("batchNorm", "variance", variance); + validateFloatingPoint("batchNorm", "gamma", gamma); + validateFloatingPoint("batchNorm", "beta", beta); SDVariable res = f().batchNorm(input, mean, variance, gamma, beta, applyGamma, applyBeta, epsilon, axis); return updateVariableNameAndReference(res, name); } @@ -82,6 +88,8 @@ public SDVariable biasAdd(SDVariable input, SDVariable bias) { * @return Output variable */ public SDVariable biasAdd(String name, SDVariable input, SDVariable bias) { + validateFloatingPoint("biasAdd", "input", input); + validateFloatingPoint("biasAdd", "bias", bias); SDVariable ret = f().biasAdd(input, bias); return updateVariableNameAndReference(ret, name); } @@ -101,6 +109,7 @@ public SDVariable dropout(SDVariable input, double inputRetainProbability) { * @return */ public SDVariable dropout(String name, SDVariable input, double inputRetainProbability) { + validateFloatingPoint("dropout", input); SDVariable res = f().dropout(input, inputRetainProbability); return updateVariableNameAndReference(res, name); } @@ -133,6 +142,7 @@ public SDVariable elu(SDVariable x) { * @return Output variable */ public SDVariable elu(String name, SDVariable x) { + validateFloatingPoint("elu", x); SDVariable result = f().elu(x); return updateVariableNameAndReference(result, name); } @@ -157,6 +167,7 @@ public SDVariable eluDerivative(SDVariable x) { * @return Output variable */ public SDVariable eluDerivative(String name, SDVariable x) { + validateFloatingPoint("eluDerivative", x); SDVariable result = f().eluDerivative(x); return updateVariableNameAndReference(result, name); } @@ -183,6 +194,7 @@ public SDVariable gelu(SDVariable x) { * @return Output variable - GELU applied to the input */ public SDVariable gelu(String name, SDVariable x) { + validateFloatingPoint("gelu", x); SDVariable ret = f().gelu(x, false); //Defaults to si return updateVariableNameAndReference(ret, name); } @@ -211,6 +223,7 @@ public SDVariable hardSigmoid(SDVariable in) { * @return Output variable */ public SDVariable hardSigmoid(String name, SDVariable in) { + validateFloatingPoint("hard sigmoid", in); SDVariable ret = f().hardSigmoid(in); return updateVariableNameAndReference(ret, name); } @@ -239,6 +252,7 @@ public SDVariable hardTanh(SDVariable in) { * @return Output variable */ public SDVariable hardTanh(String name, SDVariable in) { + validateFloatingPoint("hard Tanh", in); SDVariable result = f().hardTanh(in); return updateVariableNameAndReference(result, name); } @@ -261,6 +275,7 @@ public SDVariable hardTanhDerivative(SDVariable x) { * @return Output variable */ public SDVariable hardTanhDerivative(String name, SDVariable x) { + validateFloatingPoint("hard Tanh derivative", x); SDVariable result = f().hardTanhDerivative(x); return updateVariableNameAndReference(result, name); } @@ -290,6 +305,7 @@ public SDVariable leakyRelu(SDVariable x, double alpha) { * @return Output variable */ public SDVariable leakyRelu(String name, SDVariable x, double alpha) { + validateFloatingPoint("leaky ReLU", x); SDVariable result = f().leakyRelu(x, alpha); return updateVariableNameAndReference(result, name); } @@ -303,6 
+319,7 @@ public SDVariable leakyRelu(String name, SDVariable x, double alpha) { * @return Output variable */ public SDVariable leakyReluDerivative(String name, SDVariable x, double alpha) { + validateFloatingPoint("leaky ReLU derivative", x); SDVariable result = f().leakyReluDerivative(x, alpha); return updateVariableNameAndReference(result, name); } @@ -325,6 +342,9 @@ public SDVariable linear(SDVariable input, SDVariable weights, SDVariable bias) * @return Output variable */ public SDVariable linear(String name, SDVariable input, SDVariable weights, SDVariable bias) { + validateFloatingPoint("linear", "input", input); + validateFloatingPoint("linear", "weights", weights); + validateFloatingPoint("linear", "bias", bias); SDVariable res = f().xwPlusB(input, weights, bias); return updateVariableNameAndReference(res, name); } @@ -347,6 +367,7 @@ public SDVariable logSigmoid(SDVariable x) { * @return Output variable */ public SDVariable logSigmoid(String name, SDVariable x) { + validateFloatingPoint("log sigmoid", x); SDVariable ret = f().logSigmoid(x); return updateVariableNameAndReference(ret, name); } @@ -369,6 +390,7 @@ public SDVariable logSoftmax(SDVariable x) { * @return Output variable */ public SDVariable logSoftmax(String name, SDVariable x) { + validateFloatingPoint("log softmax", x); SDVariable ret = f().logSoftmax(x); return updateVariableNameAndReference(ret, name); } @@ -397,6 +419,7 @@ public SDVariable relu(SDVariable x, double cutoff) { * @return Output variable */ public SDVariable relu(String name, SDVariable x, double cutoff) { + validateFloatingPoint("ReLU", x); SDVariable result = f().relu(x, cutoff); return updateVariableNameAndReference(result, name); } @@ -423,6 +446,7 @@ public SDVariable relu6(SDVariable x, double cutoff) { * @return Output variable */ public SDVariable relu6(String name, SDVariable x, double cutoff) { + validateFloatingPoint("ReLU6", x); SDVariable result = f().relu6(x, cutoff); return updateVariableNameAndReference(result, name); } @@ -445,6 +469,9 @@ public SDVariable reluLayer(SDVariable input, SDVariable weights, SDVariable bia * @return Output variable */ public SDVariable reluLayer(String name, SDVariable input, SDVariable weights, SDVariable bias) { + validateFloatingPoint("reluLayer", "input", input); + validateFloatingPoint("reluLayer", "weights", weights); + validateFloatingPoint("reluLayer", "bias", bias); SDVariable res = f().reluLayer(input, weights, bias); return updateVariableNameAndReference(res, name); } @@ -473,6 +500,7 @@ public SDVariable selu(SDVariable x) { * @return Output variable */ public SDVariable selu(String name, SDVariable x) { + validateFloatingPoint("selu", x); SDVariable ret = f().selu(x); return updateVariableNameAndReference(ret, name); } @@ -495,6 +523,7 @@ public SDVariable sigmoid(SDVariable x) { * @return Output variable */ public SDVariable sigmoid(String name, SDVariable x) { + validateFloatingPoint("sigmoid", x); SDVariable result = f().sigmoid(x); return updateVariableNameAndReference(result, name); } @@ -519,6 +548,7 @@ public SDVariable sigmoidDerivative(SDVariable x, SDVariable wrt) { * @return Output variable */ public SDVariable sigmoidDerivative(String name, SDVariable x, SDVariable wrt) { + validateFloatingPoint("sigmoidDerivative", x); SDVariable result = f().sigmoidDerivative(x, wrt); return updateVariableNameAndReference(result, name); } @@ -540,6 +570,7 @@ public SDVariable softmax(SDVariable x) { * @return Output variable */ public SDVariable softmax(String name, SDVariable x) { + 
validateFloatingPoint("softmax", x); SDVariable result = f().softmax(x); return updateVariableNameAndReference(result, name); } @@ -553,6 +584,7 @@ public SDVariable softmaxDerivative(String name, SDVariable x, SDVariable wrt) { } public SDVariable softmaxDerivative(String name, SDVariable x, SDVariable wrt, Integer dimension) { + validateFloatingPoint("softmaxDerivative", x); SDVariable result = f().softmaxDerivative(x, wrt, dimension); return updateVariableNameAndReference(result, name); } @@ -575,6 +607,7 @@ public SDVariable softplus(SDVariable x) { * @return Output variable */ public SDVariable softplus(String name, SDVariable x) { + validateFloatingPoint("softplus", x); SDVariable result = f().softplus(x); return updateVariableNameAndReference(result, name); } @@ -597,6 +630,7 @@ public SDVariable softsign(SDVariable x) { * @return Output variable */ public SDVariable softsign(String name, SDVariable x) { + validateFloatingPoint("softsign", x); SDVariable result = f().softsign(x); return updateVariableNameAndReference(result, name); } @@ -619,6 +653,7 @@ public SDVariable softsignDerivative(SDVariable x) { * @return Output varible */ public SDVariable softsignDerivative(String name, SDVariable x) { + validateFloatingPoint("softsignDerivative", x); SDVariable result = f().softsignDerivative(x); return updateVariableNameAndReference(result, name); } @@ -643,6 +678,7 @@ public SDVariable swish(SDVariable x) { * @return Output variable */ public SDVariable swish(String name, SDVariable x) { + validateFloatingPoint("swish", x); SDVariable ret = f().swish(x); return updateVariableNameAndReference(ret, name); } @@ -678,6 +714,9 @@ public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, * @return Output variable */ public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, SDVariable bias, int... dimensions) { + validateFloatingPoint("layerNorm", "input", input); + validateFloatingPoint("layerNorm", "gain", gain); + validateFloatingPoint("layerNorm", "bias", bias); SDVariable result = f().layerNorm(input, gain, bias, dimensions); return updateVariableNameAndReference(result, name); } @@ -704,6 +743,8 @@ public SDVariable layerNorm(SDVariable input, SDVariable gain, int... dimensions * @return Output variable */ public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, int... 
dimensions) { + validateFloatingPoint("layerNorm", "input", input); + validateFloatingPoint("layerNorm", "gain", gain); SDVariable result = f().layerNorm(input, gain, dimensions); return updateVariableNameAndReference(result, name); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java index bfe1049d8fba..840dc9736249 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java @@ -3,6 +3,9 @@ import org.nd4j.autodiff.functions.DifferentialFunctionFactory; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; + +import java.util.Arrays; /** * Abstract class for defining categories of operations - such as {@link SDMath} that is available via {@code SameDiff.math()} @@ -24,5 +27,4 @@ protected DifferentialFunctionFactory f() { protected SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName) { return sd.updateVariableNameAndReference(varToUpdate, newVarName); } - } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java index 81a8017d7bad..b05c73669ee3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java @@ -3,6 +3,8 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger; + /** * SameDiff random number generator operations
* Accessible via {@link SameDiff#random()} @@ -35,6 +37,7 @@ public SDVariable bernoulli(double p, SDVariable shape) { * @return New SDVariable */ public SDVariable bernoulli(String name, double p, SDVariable shape) { + validateInteger("bernoulli random", shape); SDVariable ret = f().randomBernoulli(p, shape); return updateVariableNameAndReference(ret, name); } @@ -113,6 +116,7 @@ public SDVariable exponential(double lambda, SDVariable shape) { * @return new SDVaribale */ public SDVariable exponential(String name, double lambda, SDVariable shape) { + validateInteger("exponential random", shape); SDVariable ret = f().randomExponential(lambda, shape); return updateVariableNameAndReference(ret, name); } @@ -159,6 +163,7 @@ public SDVariable normal(double mean, double stddev, SDVariable shape) { * @return New SDVariable */ public SDVariable normal(String name, double mean, double stddev, SDVariable shape) { + validateInteger("normal (Gaussian) random", shape); SDVariable ret = f().randomNormal(mean, stddev, shape); return updateVariableNameAndReference(ret, name); } @@ -229,6 +234,7 @@ public SDVariable uniform(double min, double max, SDVariable shape) { * @return New SDVariable */ public SDVariable uniform(String name, double min, double max, SDVariable shape) { + validateInteger("uniform random", shape); SDVariable ret = f().randomUniform(min, max, shape); return updateVariableNameAndReference(ret, name); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDValidation.java new file mode 100644 index 000000000000..802cec74c981 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDValidation.java @@ -0,0 +1,187 @@ +package org.nd4j.autodiff.samediff.ops; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.linalg.api.buffer.DataType; + +import java.util.Arrays; + +public class SDValidation { + + private SDValidation() { + } + + /** + * Validate that the operation is being applied on a numerical SDVariable (not boolean or utf8). + * Some operations (such as sum, norm2, add(Number) etc don't make sense when applied to boolean/utf8 arrays + * + * @param opName Operation name to print in the exception + * @param v Variable to perform operation on + */ + protected static void validateNumerical(String opName, SDVariable v) { + if (v == null) + return; + if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8) + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.getVarName() + "\" with non-numerical data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on a numerical SDVariable (not boolean or utf8). 
+ * Some operations (such as sum, norm2, add(Number) etc don't make sense when applied to boolean/utf8 arrays + * + * @param opName Operation name to print in the exception + * @param inputName Name of the input to the op to validate + * @param v Variable to validate datatype for (input to operation) + */ + protected static void validateNumerical(String opName, String inputName, SDVariable v) { + if (v == null) + return; + if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8) + throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be a numerical type; got variable \"" + + v.getVarName() + "\" with non-numerical data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on numerical SDVariables (not boolean or utf8). + * Some operations (such as sum, norm2, add(Number) etc don't make sense when applied to boolean/utf8 arrays + * + * @param opName Operation name to print in the exception + * @param v1 Variable to validate datatype for (input to operation) + * @param v2 Variable to validate datatype for (input to operation) + */ + protected static void validateNumerical(String opName, SDVariable v1, SDVariable v2) { + if (v1.dataType() == DataType.BOOL || v1.dataType() == DataType.UTF8 || v2.dataType() == DataType.BOOL || v2.dataType() == DataType.UTF8) + throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on variables \"" + v1.getVarName() + "\" and \"" + + v2.getVarName() + "\" if one or both variables are non-numerical: " + v1.dataType() + " and " + v2.dataType()); + } + + /** + * Validate that the operation is being applied on an integer type SDVariable + * + * @param opName Operation name to print in the exception + * @param v Variable to validate datatype for (input to operation) + */ + protected static void validateInteger(String opName, SDVariable v) { + if (v == null) + return; + if (!v.dataType().isIntType()) + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.getVarName() + "\" with non-integer data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on an integer type SDVariable + * + * @param opName Operation name to print in the exception + * @param inputName Name of the input to the op to validate + * @param v Variable to validate datatype for (input to operation) + */ + protected static void validateInteger(String opName, String inputName, SDVariable v) { + if (v == null) + return; + if (!v.dataType().isIntType()) + throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer type; got variable \"" + + v.getVarName() + "\" with non-integer data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on a floating point type SDVariable + * + * @param opName Operation name to print in the exception + * @param v Variable to validate datatype for (input to operation) + */ + protected static void validateFloatingPoint(String opName, SDVariable v) { + if (v == null) + return; + if (!v.dataType().isFPType()) + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.getVarName() + "\" with non-floating point data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on a floating point type SDVariable + * + * @param opName Operation name to print in the exception + * @param inputName Name of the input to the op to validate + * @param v Variable to validate datatype for (input to operation) + */ + 
protected static void validateFloatingPoint(String opName, String inputName, SDVariable v) { + if (v == null) + return; + if (!v.dataType().isFPType()) + throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be a floating point type; got variable \"" + + v.getVarName() + "\" with non-floating point data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on a boolean type SDVariable + * + * @param opName Operation name to print in the exception + * @param v Variable to validate datatype for (input to operation) + */ + protected static void validateBool(String opName, SDVariable v) { + if (v == null) + return; + if (v.dataType() != DataType.BOOL) + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.getVarName() + "\" with non-boolean data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on a boolean type SDVariable + * + * @param opName Operation name to print in the exception + * @param inputName Name of the input to the op to validate + * @param v Variable to validate datatype for (input to operation) + */ + protected static void validateBool(String opName, String inputName, SDVariable v) { + if (v == null) + return; + if (v.dataType() != DataType.BOOL) + throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be a boolean variable; got variable \"" + + v.getVarName() + "\" with non-boolean data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on boolean SDVariables + * + * @param opName Operation name to print in the exception + * @param v1 Variable to validate datatype for (input to operation) + * @param v2 Variable to validate datatype for (input to operation) + */ + protected static void validateBool(String opName, SDVariable v1, SDVariable v2) { + if (v1.dataType() != DataType.BOOL || v2.dataType() != DataType.BOOL) + throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on variables \"" + v1.getVarName() + "\" and \"" + + v2.getVarName() + "\" if one or both variables are non-boolean: " + v1.dataType() + " and " + v2.dataType()); + } + + /** + * Validate that the operation is being applied on arrays with the exact same datatype (which may optionally be + * restricted to numerical SDVariables only (not boolean or utf8)) + * + * @param opName Operation name to print in the exception + * @param numericalOnly If true, the variables must all be the same type, and must be numerical (not boolean/utf8) + * @param vars Variables to perform operation on + */ + protected static void validateSameType(String opName, boolean numericalOnly, SDVariable... 
vars) { + if (vars.length == 0) + return; + if (vars.length == 1) { + if (numericalOnly) { + validateNumerical(opName, vars[0]); + } + } else { + DataType first = vars[0].dataType(); + if (numericalOnly) + validateNumerical(opName, vars[0]); + for (int i = 1; i < vars.length; i++) { + if (first != vars[i].dataType()) { + String[] names = new String[vars.length]; + DataType[] dtypes = new DataType[vars.length]; + for (int j = 0; j < vars.length; j++) { + names[j] = vars[j].getVarName(); + dtypes[j] = vars[j].dataType(); + } + throw new IllegalStateException("Cannot perform operation \"" + opName + "\" to variables with different datatypes:" + + " Variable names " + Arrays.toString(names) + ", datatypes " + Arrays.toString(dtypes)); + } + } + } + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java index 268906f84349..4f7bd11b8e16 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java @@ -369,7 +369,10 @@ public static String validate(OpTestCase testCase) { for (int i = 0; i < outShapes.size(); i++) { val act = outShapes.get(i); val exp = testCase.expShapes().get(i); - if (!Objects.equals(exp, act)) { + if(!Objects.equals(exp.dataType(), act.dataType())){ + return "Shape function check failed for output " + i + ": expected shape " + exp + ", actual shape " + act; + } + if(!Arrays.equals(act.getShape(), exp.getShape())){ return "Shape function check failed for output " + i + ": expected shape " + exp + ", actual shape " + act; } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java index a1c29af32259..bd70e7a65d8f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java @@ -247,7 +247,7 @@ public Evaluation(List labels, INDArray costArray) { throw new IllegalArgumentException("Invalid cost array: Cost array values must be positive"); } this.labelsList = labels; - this.costArray = costArray; + this.costArray = costArray == null ? 
null : costArray.castTo(DataType.FLOAT); this.topN = 1; } @@ -453,7 +453,7 @@ public void eval(INDArray labels, INDArray predictions, INDArray mask, final Lis guessIndex = pClass1.gt(binaryDecisionThreshold); } else if (costArray != null) { //With a cost array: do argmax(cost * probability) instead of just argmax(probability) - guessIndex = Nd4j.argMax(predictions2d.mulRowVector(costArray), 1); + guessIndex = Nd4j.argMax(predictions2d.mulRowVector(costArray.castTo(predictions2d.dataType())), 1); } else { //Standard case: argmax guessIndex = Nd4j.argMax(predictions2d, 1); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java index 6723f26c73d7..20b6aec44379 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java @@ -192,6 +192,7 @@ public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArra if (maskArray != null) { //By multiplying by mask, we keep only those 1s that are actually present + maskArray = maskArray.castTo(truePositives.dataType()); truePositives.muli(maskArray); trueNegatives.muli(maskArray); falsePositives.muli(maskArray); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java index f6158a8639b9..90bbd392362c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java @@ -180,7 +180,7 @@ public void eval(INDArray labels, INDArray predictions, INDArray mask) { val nClasses = labels2d.size(1); if (rDiagBinPosCount == null) { - DataType dt = predictions.dataType(); + DataType dt = DataType.DOUBLE; //Initialize rDiagBinPosCount = Nd4j.create(DataType.LONG, reliabilityDiagNumBins, nClasses); rDiagBinTotalCount = Nd4j.create(DataType.LONG, reliabilityDiagNumBins, nClasses); @@ -239,7 +239,7 @@ public void eval(INDArray labels, INDArray predictions, INDArray mask) { INDArray numPredictionsCurrBin = currBinBitMask.sum(0); - rDiagBinSumPredictions.getRow(j).addi(maskedProbs.sum(0)); + rDiagBinSumPredictions.getRow(j).addi(maskedProbs.sum(0).castTo(rDiagBinSumPredictions.dataType())); rDiagBinPosCount.getRow(j).addi(isPosLabelForBin.sum(0).castTo(rDiagBinPosCount.dataType())); rDiagBinTotalCount.getRow(j).addi(numPredictionsCurrBin.castTo(rDiagBinTotalCount.dataType())); } @@ -299,14 +299,14 @@ public void eval(INDArray labels, INDArray predictions, INDArray mask) { //Counts for positive class only: values are in the current bin AND it's a positive label INDArray isPosLabelForBin = l.mul(currBinBitMask); - residualPlotByLabelClass.getRow(j).addi(isPosLabelForBin.sum(0)); + residualPlotByLabelClass.getRow(j).addi(isPosLabelForBin.sum(0).castTo(residualPlotByLabelClass.dataType())); int probNewTotalCount = probHistogramOverall.getInt(0, j) + currBinBitMaskProbs.sumNumber().intValue(); probHistogramOverall.putScalar(0, j, probNewTotalCount); INDArray isPosLabelForBinProbs = l.mul(currBinBitMaskProbs); 
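The castTo calls being added to the evaluation classes in this patch all follow one pattern: per-batch statistics keep the dtype of the incoming labels/predictions, while the long-lived accumulators have a fixed dtype (usually DOUBLE), so batch results are cast to the accumulator's dtype before any in-place addi. A small illustrative sketch of that pattern, using only factory methods that already appear in this diff (the variable names are invented):

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class AccumulateSketch {
        public static void main(String[] args) {
            // Long-lived per-column accumulator, kept in DOUBLE as in RegressionEvaluation
            INDArray acc = Nd4j.zeros(DataType.DOUBLE, 3);

            // One batch of FLOAT predictions; the column sums come back as FLOAT
            INDArray batch = Nd4j.create(DataType.FLOAT, 4, 3).assign(1.0);
            INDArray colSum = batch.sum(0);

            // Cast to the accumulator's dtype before the in-place add, mirroring the changes in this patch;
            // without the cast, the mixed FLOAT/DOUBLE in-place op is expected to be rejected
            acc.addi(colSum.castTo(acc.dataType()));
            System.out.println(acc);
        }
    }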
INDArray temp = isPosLabelForBinProbs.sum(0); - probHistogramByLabelClass.getRow(j).addi(temp); + probHistogramByLabelClass.getRow(j).addi(temp.castTo(probHistogramByLabelClass.dataType())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java index 8b2850ea2d4a..1bbee9334b9e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java @@ -22,6 +22,7 @@ import org.nd4j.evaluation.curves.PrecisionRecallCurve; import org.nd4j.evaluation.curves.RocCurve; import org.nd4j.evaluation.serde.ROCSerializer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; @@ -219,14 +220,14 @@ public RocCurve getRocCurve() { INDArray cumSumNeg = isNegative.cumsum(-1); val length = sorted.size(0); - INDArray t = Nd4j.create(new long[]{length + 2, 1}); + INDArray t = Nd4j.create(DataType.DOUBLE, length + 2, 1); t.put(new INDArrayIndex[]{interval(1, length + 1), all()}, sorted.getColumn(0)); - INDArray fpr = Nd4j.create(new long[]{length + 2, 1}); + INDArray fpr = Nd4j.create(DataType.DOUBLE, length + 2, 1); fpr.put(new INDArrayIndex[]{interval(1, length + 1), all()}, cumSumNeg.div(countActualNegative)); - INDArray tpr = Nd4j.create(new long[]{length + 2, 1}); + INDArray tpr = Nd4j.create(DataType.DOUBLE, length + 2, 1); tpr.put(new INDArrayIndex[]{interval(1, length + 1), all()}, cumSumPos.div(countActualPositive)); @@ -400,16 +401,16 @@ as class 0, all others are predicted as class 1 predicted positive at threshold: # values <= threshold, i.e., just i */ - INDArray t = Nd4j.create(new long[]{length + 2, 1}); + INDArray t = Nd4j.create(DataType.DOUBLE, length + 2, 1); t.put(new INDArrayIndex[]{interval(1, length + 1), all()}, sorted.getColumn(0)); - INDArray linspace = Nd4j.linspace(1, length, length, Nd4j.dataType()); - INDArray precision = cumSumPos.div(linspace.reshape(cumSumPos.shape())); - INDArray prec = Nd4j.create(new long[]{length + 2, 1}); + INDArray linspace = Nd4j.linspace(1, length, length, DataType.DOUBLE); + INDArray precision = cumSumPos.castTo(DataType.DOUBLE).div(linspace.reshape(cumSumPos.shape())); + INDArray prec = Nd4j.create(DataType.DOUBLE, length + 2, 1); prec.put(new INDArrayIndex[]{interval(1, length + 1), all()}, precision); //Recall/TPR - INDArray rec = Nd4j.create(new long[]{length + 2, 1}); + INDArray rec = Nd4j.create(DataType.DOUBLE, length + 2, 1); rec.put(new INDArrayIndex[]{interval(1, length + 1), all()}, cumSumPos.div(countActualPositive)); @@ -587,13 +588,13 @@ public void eval(INDArray labels, INDArray predictions, INDArray mask, List= probAndLabel.size(0)) { val newSize = probAndLabel.size(0) + Math.max(exactAllocBlockSize, labels2d.size(0)); - INDArray newProbAndLabel = Nd4j.create(new long[]{newSize, 2}, 'c'); + INDArray newProbAndLabel = Nd4j.create(DataType.DOUBLE, new long[]{newSize, 2}, 'c'); if (exampleCount > 0) { //If statement to handle edge case: no examples, but we need to re-allocate right away newProbAndLabel.get(interval(0, exampleCount), all()).assign( @@ -734,7 +735,7 @@ public void merge(ROC other) { if (this.exampleCount + other.exampleCount > this.probAndLabel.size(0)) { //Allocate new array 
val newSize = this.probAndLabel.size(0) + Math.max(other.probAndLabel.size(0), exactAllocBlockSize); - INDArray newProbAndLabel = Nd4j.create(newSize, 2); + INDArray newProbAndLabel = Nd4j.create(DataType.DOUBLE, newSize, 2); newProbAndLabel.put(new INDArrayIndex[]{interval(0, exampleCount), all()}, probAndLabel.get(interval(0, exampleCount), all())); probAndLabel = newProbAndLabel; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCMultiClass.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCMultiClass.java index 57214d66072d..b24fc3ccdc59 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCMultiClass.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCMultiClass.java @@ -199,8 +199,8 @@ public void eval(INDArray labels, INDArray predictions, INDArray mask, final Lis } for (int i = 0; i < n; i++) { - INDArray prob = predictions2d.getColumn(i); //Probability of class i - INDArray label = labels2d.getColumn(i); + INDArray prob = predictions2d.getColumn(i, true); //Probability of class i + INDArray label = labels2d.getColumn(i, true); //Workaround for: https://github.com/deeplearning4j/deeplearning4j/issues/7305 if(prob.rank() == 0) prob = prob.reshape(1,1); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java index 7b93fc93487d..8c28782bcb54 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java @@ -20,9 +20,9 @@ import lombok.EqualsAndHashCode; import lombok.val; import org.nd4j.evaluation.BaseEvaluation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.reduce.same.ASum; -import org.nd4j.linalg.api.ops.impl.transforms.same.Abs; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.serde.RowVectorDeserializer; import org.nd4j.linalg.lossfunctions.serde.RowVectorSerializer; @@ -160,17 +160,17 @@ private void initialize(int n) { if (columnNames == null || columnNames.size() != n) { columnNames = createDefaultColumnNames(n); } - exampleCountPerColumn = Nd4j.zeros(n); - labelsSumPerColumn = Nd4j.zeros(n); - sumSquaredErrorsPerColumn = Nd4j.zeros(n); - sumAbsErrorsPerColumn = Nd4j.zeros(n); - currentMean = Nd4j.zeros(n); - - currentPredictionMean = Nd4j.zeros(n); - sumOfProducts = Nd4j.zeros(n); - sumSquaredLabels = Nd4j.zeros(n); - sumSquaredPredicted = Nd4j.zeros(n); - sumLabels = Nd4j.zeros(n); + exampleCountPerColumn = Nd4j.zeros(DataType.DOUBLE, n); + labelsSumPerColumn = Nd4j.zeros(DataType.DOUBLE, n); + sumSquaredErrorsPerColumn = Nd4j.zeros(DataType.DOUBLE, n); + sumAbsErrorsPerColumn = Nd4j.zeros(DataType.DOUBLE, n); + currentMean = Nd4j.zeros(DataType.DOUBLE, n); + + currentPredictionMean = Nd4j.zeros(DataType.DOUBLE, n); + sumOfProducts = Nd4j.zeros(DataType.DOUBLE, n); + sumSquaredLabels = Nd4j.zeros(DataType.DOUBLE, n); + sumSquaredPredicted = Nd4j.zeros(DataType.DOUBLE, n); + sumLabels = Nd4j.zeros(DataType.DOUBLE, n); initialized = true; } @@ -234,19 +234,19 @@ public void eval(INDArray 
labels, INDArray predictions, INDArray maskArray) { predictions = predictions.mul(maskArray); } - labelsSumPerColumn.addi(labels.sum(0)); + labelsSumPerColumn.addi(labels.sum(0).castTo(labelsSumPerColumn.dataType())); INDArray error = predictions.sub(labels); INDArray absErrorSum = Nd4j.getExecutioner().exec(new ASum(error, 0)); INDArray squaredErrorSum = error.mul(error).sum(0); - sumAbsErrorsPerColumn.addi(absErrorSum); - sumSquaredErrorsPerColumn.addi(squaredErrorSum); + sumAbsErrorsPerColumn.addi(absErrorSum.castTo(labelsSumPerColumn.dataType())); + sumSquaredErrorsPerColumn.addi(squaredErrorSum.castTo(labelsSumPerColumn.dataType())); - sumOfProducts.addi(labels.mul(predictions).sum(0)); + sumOfProducts.addi(labels.mul(predictions).sum(0).castTo(labelsSumPerColumn.dataType())); - sumSquaredLabels.addi(labels.mul(labels).sum(0)); - sumSquaredPredicted.addi(predictions.mul(predictions).sum(0)); + sumSquaredLabels.addi(labels.mul(labels).sum(0).castTo(labelsSumPerColumn.dataType())); + sumSquaredPredicted.addi(predictions.mul(predictions).sum(0).castTo(labelsSumPerColumn.dataType())); val nRows = labels.size(0); @@ -255,14 +255,14 @@ public void eval(INDArray labels, INDArray predictions, INDArray maskArray) { if (maskArray == null) { newExampleCountPerColumn = exampleCountPerColumn.add(nRows); } else { - newExampleCountPerColumn = exampleCountPerColumn.add(maskArray.sum(0)); + newExampleCountPerColumn = exampleCountPerColumn.add(maskArray.sum(0).castTo(labelsSumPerColumn.dataType())); } - currentMean.muliRowVector(exampleCountPerColumn).addi(labels.sum(0)).diviRowVector(newExampleCountPerColumn); - currentPredictionMean.muliRowVector(exampleCountPerColumn).addi(predictions.sum(0)) + currentMean.muliRowVector(exampleCountPerColumn).addi(labels.sum(0).castTo(labelsSumPerColumn.dataType())).diviRowVector(newExampleCountPerColumn); + currentPredictionMean.muliRowVector(exampleCountPerColumn).addi(predictions.sum(0).castTo(labelsSumPerColumn.dataType())) .divi(newExampleCountPerColumn); exampleCountPerColumn = newExampleCountPerColumn; - sumLabels.addi(labels.sum(0)); + sumLabels.addi(labels.sum(0).castTo(labelsSumPerColumn.dataType())); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index c5f963e90132..ace29d65dcc2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -2114,13 +2114,10 @@ public INDArray putSlice(int slice, INDArray put) { INDArray view = slice(slice); - if (put.length() == 1) + if (put.length() == 1) { putScalar(slice, put.getDouble(0)); - else if (put.isVector()) - for (int i = 0; i < put.length(); i++) - view.putScalar(i, put.getDouble(i)); - else { - if(!view.equalShapes(put)){ + } else { + if(!(view.isVector() && put.isVector() && view.length() == put.length()) && !view.equalShapes(put)){ throw new IllegalStateException("Cannot put slice: array to be put (" + Arrays.toString(put.shape()) + ") and slice array (" + Arrays.toString(view.shape()) + ") have different shapes"); } @@ -2552,9 +2549,6 @@ public void setData(DataBuffer data) { */ @Override public long slices() { - if (isRowVector()) - return length(); - return size(0); } @@ -2574,7 +2568,7 @@ public INDArray subArray(ShapeOffsetResolution resolution) { int n = 
shape.length; if (offsets.length != n) - throw new IllegalArgumentException("Invalid offset " + Arrays.toString(offsets)); + throw new IllegalArgumentException("Invalid offset: number of offsets must be equal to rank " + Arrays.toString(offsets) + ", rank = " + n); if (stride.length != n) throw new IllegalArgumentException("Invalid stride " + Arrays.toString(stride)); @@ -3070,6 +3064,8 @@ protected DataBuffer strideOf() { @Override public int stride(int dimension) { int rank = jvmShapeInfo.rank; + Preconditions.checkArgument(dimension < rank, "Cannot get stride for dimension %s from rank %s array: " + + "dimension indices must be in range -rank <= dimension < rank", dimension, rank); if (dimension < 0) return (int) stride()[dimension + rank]; return (int) stride()[dimension]; @@ -3370,6 +3366,7 @@ public INDArray mmul(INDArray other, MMulTranspose mMulTranspose) { */ @Override public INDArray mmul(INDArray other) { + Preconditions.checkState(this.dataType() == other.dataType(), "Matrix multiplication: arrays must have same dtype: %s vs. %s", this.dataType(), other.dataType()); // FIXME: for 1D case, we probably want vector output here? long[] shape = {rows(), other.rank() == 1 ? 1 : other.columns()}; INDArray result = createUninitialized(this.dataType(), shape, 'f'); @@ -4101,14 +4098,8 @@ public INDArray slice(long slice) { if (slice >= slices) throw new IllegalArgumentException("Illegal slice " + slice); - if (jvmShapeInfo.rank == 0 || isVector()) { - if (slice == 0 || isVector()) { - return createScalarForIndex(slice, true); - } - else { - throw new IllegalArgumentException("Can't slice a 0-d NDArray"); - } - + if (jvmShapeInfo.rank == 0 ) { + throw new IllegalArgumentException("Can't slice a 0-d NDArray"); } @@ -4498,7 +4489,7 @@ public INDArray reshape(char order, boolean enforceView, long... newShape){ Nd4j.getCompressor().autoDecompress(this); // special case for empty reshape - if (this.length() == 1 && (newShape == null || newShape.length == 0)) { + if (this.length() == 1 && (newShape == null || newShape.length == 0) && this.elementWiseStride() == 1) { return Nd4j.create(this.data(), new int[0], new int[0], 0); } @@ -4901,7 +4892,7 @@ public INDArray ravel(char ordering) { */ @Override public INDArray ravel() { - return reshape(1, length()); + return reshape(length()); } /** @@ -4947,6 +4938,14 @@ else if (isColumnVector() && c > 0) return get(NDArrayIndex.all(), NDArrayIndex.point(c)); } + @Override + public INDArray getColumn(long c, boolean keepDim) { + INDArray col = getColumn(c); + if(!keepDim) + return col; + return col.reshape(col.length(), 1); + } + /** * Get whole rows from the passed indices. @@ -4980,29 +4979,21 @@ public INDArray getRows(int[] rindices) { public INDArray get(INDArrayIndex... 
indexes) { Nd4j.getCompressor().autoDecompress(this); - // besides of backward support for legacy "vectors" which were 2D - // we enforce number of indices provided to be equal to number of dimensions in this array - if (rank() == 2 && jvmShapeInfo.javaShapeInformation[1] == 1 && indexes.length == 1) - indexes = new INDArrayIndex[]{ NDArrayIndex.all(), indexes[0]}; - else if (rank() == 2 && jvmShapeInfo.javaShapeInformation[2] == 1 && indexes.length == 1) - indexes = new INDArrayIndex[]{indexes[0], NDArrayIndex.all()}; - else { - // we're padding remaining dimensions with all() index - if (indexes.length < this.rank()) { - val newIndexes = new INDArrayIndex[this.rank()]; - for (int e = 0; e < indexes.length; e++) - newIndexes[e] = indexes[e]; - - for (int e = indexes.length; e < newIndexes.length; e++) - newIndexes[e] = NDArrayIndex.all(); + // we're padding remaining dimensions with all() index + if (indexes.length < this.rank()) { + val newIndexes = new INDArrayIndex[this.rank()]; + for (int e = 0; e < indexes.length; e++) + newIndexes[e] = indexes[e]; - indexes = newIndexes; - } + for (int e = indexes.length; e < newIndexes.length; e++) + newIndexes[e] = NDArrayIndex.all(); - // never going to happen :/ - Preconditions.checkArgument(indexes != null && indexes.length >= this.rank(), "Number of indices should be greater or equal to rank of the INDArray"); + indexes = newIndexes; } + // never going to happen :/ + Preconditions.checkArgument(indexes != null && indexes.length >= this.rank(), "Number of indices should be greater or equal to rank of the INDArray"); + if(indexes.length > rank()) { int numNonNewAxis = 0; for(int i = 0; i < indexes.length; i++) { @@ -5039,8 +5030,7 @@ else if (rank() == 2 && jvmShapeInfo.javaShapeInformation[2] == 1 && indexes.len if (indexes.length < 1) throw new IllegalStateException("Invalid index found of zero length"); - // FIXME: LONG - int[] shape = LongUtils.toInts(resolution.getShapes()); + long[] shape = resolution.getShapes(); int numSpecifiedIndex = 0; for (int i = 0; i < indexes.length; i++) if (indexes[i] instanceof SpecifiedIndex) @@ -5048,7 +5038,7 @@ else if (rank() == 2 && jvmShapeInfo.javaShapeInformation[2] == 1 && indexes.len if (shape != null && numSpecifiedIndex > 0) { Generator>> gen = SpecifiedIndex.iterate(indexes); - INDArray ret = Nd4j.create(this.dataType(), ArrayUtil.toLongArray(shape), 'c'); + INDArray ret = Nd4j.create(this.dataType(), shape, 'c'); int count = 0; while (true) { try { @@ -5126,6 +5116,14 @@ else if (isRowVector() && r > 0) return result; } + @Override + public INDArray getRow(long r, boolean keepDim) { + INDArray row = getRow(r); + if(!keepDim) + return row; + return row.reshape(1, row.length()); + } + /** * This method allows you to compare INDArray against other INDArray, with variable eps @@ -5154,6 +5152,9 @@ public boolean equalsWithEps(Object o, double eps) { return n.equals(this); } + if (this.rank() != n.rank()) + return false; + if (this.length() != n.length()) return false; @@ -5465,8 +5466,6 @@ public INDArray broadcast(INDArray result) { Nd4j.getExecutioner().execAndReturn(new Tile(new INDArray[]{this.dup(this.ordering())},new INDArray[]{result},repeat)); } else Nd4j.getExecutioner().execAndReturn(new Tile(new INDArray[]{this},new INDArray[]{result},repeat)); - - //result = Nd4j.tile(this,repeat); } return result; @@ -6055,7 +6054,7 @@ public INDArray detach() { if (Nd4j.getMemoryManager().getCurrentWorkspace() == null) { if (!isView()) { Nd4j.getExecutioner().commit(); - DataBuffer buffer = 
Nd4j.createBuffer(this.length(), false); + DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); Nd4j.getMemoryManager().memcpy(buffer, this.data()); @@ -6074,7 +6073,7 @@ public INDArray detach() { if (!isView()) { Nd4j.getExecutioner().commit(); - DataBuffer buffer = Nd4j.createBuffer(this.length(), false); + DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); //Pointer.memcpy(buffer.pointer(), this.data.pointer(), this.lengthLong() * Nd4j.sizeOfDataType(this.data.dataType())); Nd4j.getMemoryManager().memcpy(buffer, this.data()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java index 8907559f7942..41750a72b99a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java @@ -1517,11 +1517,21 @@ public INDArray getColumn(long i) { return null; } + @Override + public INDArray getColumn(long i, boolean keepDim) { + return null; + } + @Override public INDArray getRow(long i) { return null; } + @Override + public INDArray getRow(long i, boolean keepDim) { + return null; + } + @Override public int columns() { return (int) columns; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java index 2c0d34f52a32..2835264b9834 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java @@ -2173,7 +2173,17 @@ public interface INDArray extends Serializable, AutoCloseable { INDArray getColumn(long i); /** - * Returns the specified row. + * Returns the specified column. Throws an exception if its not a matrix (rank 2). + * Returned array will either be 1D (keepDim = false) or 2D (keepDim = true) with shape [length, 1] + * + * @param i the row to get + * @param keepDim If true: return [length, 1] array. Otherwise: return [length] array + * @return the specified row + */ + INDArray getColumn(long i, boolean keepDim); + + /** + * Returns the specified row as a 1D vector. * Throws an exception if its not a matrix * * @param i the row to getScalar @@ -2181,6 +2191,16 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray getRow(long i); + /** + * Returns the specified row. Throws an exception if its not a matrix. + * Returned array will either be 1D (keepDim = false) or 2D (keepDim = true) with shape [1, length] + * + * @param i the row to get + * @param keepDim If true: return [1,length] array. 
Otherwise: return [length] array + * @return the specified row + */ + INDArray getRow(long i, boolean keepDim); + /** * Returns the number of columns in this matrix (throws exception if not 2d) * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/BaseReduce3Op.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/BaseReduce3Op.java index 5be2e644efef..93c287444cb9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/BaseReduce3Op.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/BaseReduce3Op.java @@ -97,6 +97,8 @@ public String tensorflowName() { @Override public DataType resultType() { + if(x.dataType().isFPType()) + return x.dataType(); return Nd4j.defaultFloatingPointType(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java index 20572ed97edf..984d44a255a0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java @@ -1823,7 +1823,7 @@ public static long elementWiseStride(long[] shape, long[] stride, boolean isFOrd return 1; if (shape.length == 1 && stride.length == 1) - return 1; + return stride[0]; int oldnd; long[] olddims = ArrayUtil.copy(shape); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/MultiDataSet.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/MultiDataSet.java index d57366701661..63fda447ff51 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/MultiDataSet.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/MultiDataSet.java @@ -354,7 +354,7 @@ private static INDArray getSubsetForExample(INDArray array, int idx) { //So (point,all,all) on a 3d input returns a 2d output. 
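The getRow/getColumn/ravel changes in this patch alter the rank of what callers receive: the single-argument forms now return true rank-1 vectors, while the new keepDim overloads preserve the old matrix-shaped [1, n] / [n, 1] views. A short, hypothetical usage sketch of the difference (only creation methods already used in this diff are assumed):

    import java.util.Arrays;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class RowColumnSketch {
        public static void main(String[] args) {
            INDArray m = Nd4j.create(DataType.FLOAT, 2, 3);

            INDArray row1d = m.getRow(0);           // rank 1, shape [3]
            INDArray row2d = m.getRow(0, true);     // rank 2, shape [1, 3]
            INDArray col2d = m.getColumn(1, true);  // rank 2, shape [3, 1]
            INDArray flat = m.ravel();              // rank 1, shape [6]

            System.out.println(Arrays.toString(row1d.shape()) + " vs " + Arrays.toString(row2d.shape()));
        }
    }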
Whereas, we want a 3d [1,x,y] output here switch (array.rank()) { case 2: - return array.get(NDArrayIndex.point(idx), NDArrayIndex.all()); + return array.get(NDArrayIndex.interval(idx, idx, true), NDArrayIndex.all()); case 3: return array.get(NDArrayIndex.interval(idx, idx, true), NDArrayIndex.all(), NDArrayIndex.all()); case 4: diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java index 935943e03f8e..446fb6253d47 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java @@ -210,7 +210,7 @@ public INDArray toFlattened(Collection matrices) { int linearIndex = 0; for (INDArray d : matrices) { INDArray vector = d.reshape(d.length()); - ret.put(new INDArrayIndex[] {NDArrayIndex.interval(linearIndex, linearIndex + d.length())}, vector); + ret.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(linearIndex, linearIndex + d.length())}, vector); linearIndex += d.length(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 9cd84c4fc3c1..7929e0994053 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -889,7 +889,7 @@ public static INDArray gemm(INDArray a, boolean transposeB) { long cRows = (transposeA ? a.columns() : a.rows()); long cCols = (transposeB ? 
b.rows() : b.columns()); - INDArray c = Nd4j.createUninitialized(new long[] {cRows, cCols}, 'f'); + INDArray c = Nd4j.createUninitialized(a.dataType(), new long[] {cRows, cCols}, 'f'); return gemm(a, b, c, transposeA, transposeB, 1.0, 0.0); } @@ -1927,7 +1927,7 @@ public static INDArray sortRows(final INDArray in, final int colIdx, final boole if (in.rows() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); - INDArray out = Nd4j.create(in.shape()); + INDArray out = Nd4j.create(in.data(), in.shape()); int nRows = (int) in.rows(); ArrayList list = new ArrayList(nRows); for (int i = 0; i < nRows; i++) @@ -3960,9 +3960,7 @@ public static INDArray create(double[] data, int[] shape) { */ public static INDArray create(double[] data, int[] shape, int[] stride, long offset) { if (shape.length == 1) { - if (shape[0] == data.length) { - shape = new int[] {1, data.length}; - } else + if (shape[0] != data.length) throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) + " doesn't match data length: " + data.length); } @@ -4030,9 +4028,7 @@ public static INDArray create(double[] data, int rows, int columns, int[] stride */ public static INDArray create(float[] data, int[] shape, long offset) { if (shape.length == 1) { - if (shape[0] == data.length) { - shape = new int[] {1, data.length}; - } else + if (shape[0] != data.length) throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) + " doesn't match data length: " + data.length); } @@ -4046,9 +4042,7 @@ public static INDArray create(float[] data, int[] shape, long offset) { public static INDArray create(float[] data, long[] shape, long offset) { if (shape.length == 1) { - if (shape[0] == data.length) { - shape = new long[] {1, data.length}; - } else + if (shape[0] != data.length) throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) + " doesn't match data length: " + data.length); } @@ -4069,9 +4063,7 @@ public static INDArray create(float[] data, long[] shape, long offset) { */ public static INDArray create(double[] data, int[] shape, long offset, char ordering) { if (shape.length == 1) { - if (shape[0] == data.length) { - shape = new int[] {1, data.length}; - } else + if (shape[0] != data.length) throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) + " doesn't match data length: " + data.length); } @@ -4085,9 +4077,7 @@ public static INDArray create(double[] data, int[] shape, long offset, char orde public static INDArray create(double[] data, long[] shape, long offset, char ordering) { if (shape.length == 1) { - if (shape[0] == data.length) { - shape = new long[] {1, data.length}; - } else + if (shape[0] != data.length) throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) + " doesn't match data length: " + data.length); } @@ -4109,9 +4099,7 @@ public static INDArray create(double[] data, long[] shape, long offset, char ord */ public static INDArray create(float[] data, int[] shape, int[] stride, long offset) { if (shape.length == 1) { - if (shape[0] == data.length) { - shape = new int[] {1, data.length}; - } else + if (shape[0] != data.length) throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) + " doesn't match data length: " + data.length); } @@ -4283,9 +4271,7 @@ public static INDArray create(DataType type, long... 
shape) { */ public static INDArray create(float[] data, int[] shape, int[] stride, char ordering, long offset) { if (shape.length == 1) { - if (shape[0] == data.length) { - shape = new int[] {1, data.length}; - } else + if (shape[0] != data.length) throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) + " doesn't match data length: " + data.length); } @@ -4518,9 +4504,7 @@ public static INDArray create(double[] data, int[] shape, int[] stride, long off */ public static INDArray create(double[] data, int[] shape, char ordering) { if (shape.length == 1) { - if (shape[0] == data.length) { - shape = new int[] {1, data.length}; - } else + if (shape[0] != data.length) throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) + " doesn't match data length: " + data.length); } @@ -4542,9 +4526,7 @@ public static INDArray create(double[] data, int[] shape, char ordering) { */ public static INDArray create(float[] data, int[] shape, char ordering) { if (shape.length == 1) { - if (shape[0] == data.length) { - shape = new int[] {1, data.length}; - } else + if (shape[0] != data.length) throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) + " doesn't match data length: " + data.length); } @@ -4601,9 +4583,7 @@ public static INDArray create(double[] data, int rows, int columns, int[] stride */ public static INDArray create(float[] data, int[] shape, int[] stride, long offset, char ordering) { if (shape.length == 1) { - if (shape[0] == data.length) { - shape = new int[] {1, data.length}; - } else + if (shape[0] != data.length) throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) + " doesn't match data length: " + data.length); } @@ -4787,12 +4767,6 @@ public static INDArray zeros(int rows, int columns, char ordering) { public static INDArray create(@NonNull int[] shape, char ordering) { if(shape.length == 0) return Nd4j.scalar(dataType(), 0.0); - //ensure shapes that wind up being scalar end up with the write shape - if (shape.length == 1 && shape[0] == 0) { - shape = new int[] {1, 1}; - } else if (shape.length == 1) { - shape = new int[] {1, shape[0]}; - } checkShapeValues(shape); @@ -4944,13 +4918,6 @@ public static INDArray createUninitializedDetached(int[] shape, char ordering) { if (shape.length == 0) return scalar(dataType(), 0.0); - //ensure shapes that wind up being scalar end up with the write shape - if (shape.length == 1 && shape[0] == 0) { - shape = new int[] {1, 1}; - } else if (shape.length == 1) { - shape = new int[] {1, shape[0]}; - } - checkShapeValues(shape); INDArray ret = INSTANCE.createUninitializedDetached(shape, ordering); @@ -4968,13 +4935,6 @@ public static INDArray createUninitializedDetached(long[] shape, char ordering) if (shape.length == 0) return scalar(dataType(), 0.0); - //ensure shapes that wind up being scalar end up with the write shape - if (shape.length == 1 && shape[0] == 0) { - shape = new long[] {1, 1}; - } else if (shape.length == 1) { - shape = new long[] {1, shape[0]}; - } - checkShapeValues(shape); INDArray ret = INSTANCE.createUninitializedDetached(shape, ordering); @@ -5032,21 +4992,14 @@ public static INDArray createUninitializedDetached(long[] shape) { * @return */ public static INDArray createUninitialized(int length) { - if (length < 1) - throw new IllegalStateException("INDArray length should be positive value"); - - int[] shape = new int[] {1, length}; - - INDArray ret = INSTANCE.createUninitialized(shape, order()); - 
logCreationIfNecessary(ret); - return ret; + return createUninitialized((long)length); } public static INDArray createUninitialized(long length) { if (length < 1) throw new IllegalStateException("INDArray length should be positive value"); - long[] shape = new long[] {1, length}; + long[] shape = new long[] {length}; INDArray ret = INSTANCE.createUninitialized(shape, order()); logCreationIfNecessary(ret); @@ -5063,7 +5016,7 @@ public static INDArray createUninitializedDetached(int length) { if (length < 1) throw new IllegalStateException("INDArray length should be positive value"); - int[] shape = new int[] {1, length}; + long[] shape = new long[] {length}; INDArray ret = INSTANCE.createUninitializedDetached(shape, order()); logCreationIfNecessary(ret); @@ -5079,9 +5032,7 @@ public static INDArray createUninitializedDetached(int length) { */ public static INDArray create(double[] data, int[] shape, long offset) { if (shape.length == 1) { - if (shape[0] == data.length) { - shape = new int[] {1, data.length}; - } else + if (shape[0] != data.length) throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) + " doesn't match data length: " + data.length); } @@ -5360,7 +5311,7 @@ public static INDArray valueArrayOf(long[] shape, long value, DataType type) { * @return the created ndarray */ public static INDArray valueArrayOf(long num, double value) { - INDArray ret = INSTANCE.valueArrayOf(new long[] {1, num}, value); + INDArray ret = INSTANCE.valueArrayOf(new long[] {num}, value); logCreationIfNecessary(ret); return ret; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/NDArrayIndex.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/NDArrayIndex.java index ac68577364ec..cbb869efb6b2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/NDArrayIndex.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/NDArrayIndex.java @@ -440,9 +440,12 @@ public static INDArrayIndex[] resolve(long[] shape, INDArrayIndex... intendedInd } protected static INDArrayIndex validate(long size, INDArrayIndex index) { - if ((index instanceof IntervalIndex || index instanceof PointIndex) && size <= index.current() && size > 1) + if ((index instanceof IntervalIndex || index instanceof PointIndex) && size <= index.current()) throw new IllegalArgumentException("NDArrayIndex is out of range. Beginning index: " + index.current() + " must be less than its size: " + size); + if (index instanceof IntervalIndex && index.end() > size) + throw new IllegalArgumentException("NDArrayIndex is out of range. End index: " + index.end() + + " must be less than its size: " + size); if (index instanceof IntervalIndex && size < index.end()) { long begin = ((IntervalIndex) index).begin; index = NDArrayIndex.interval(begin, index.stride(), size); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/PointIndex.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/PointIndex.java index a8af73484384..a7ffdedc222c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/PointIndex.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/PointIndex.java @@ -133,4 +133,9 @@ public int hashCode() { result = 31 * result + (notUsed ? 
1 : 0); return result; } + + @Override + public String toString(){ + return "Point(" + point + ")"; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/ShapeOffsetResolution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/ShapeOffsetResolution.java index 3106b8c2452f..2e6fdf498567 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/ShapeOffsetResolution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/ShapeOffsetResolution.java @@ -20,6 +20,7 @@ import lombok.Data; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.util.ArrayUtil; @@ -84,9 +85,10 @@ else if (indexes[i] instanceof NewAxis) newAxis++; else if (indexes[i] instanceof NDArrayIndexAll) numAll++; - } + Preconditions.checkState(pointIndex + interval + numAll <= arr.rank(), "Received more indices than rank of array (%s): %s", arr.rank(), indexes); + if(arr.rank() == 1 && indexes.length == 1){ if(indexes[0] instanceof PointIndex){ @@ -116,17 +118,37 @@ else if (indexes[i] instanceof NDArrayIndexAll) this.offsets = new long[arr.rank()]; return true; } else if (indexes[0] instanceof PointIndex && indexes[1] instanceof NDArrayIndexAll) { - this.shapes = new long[2]; - this.strides = new long[2]; - for (int i = 0; i < 2; i++) { - shapes[i] = 1; - strides[i] = arr.stride(i); + this.shapes = new long[1]; + this.strides = new long[1]; + this.offsets = new long[1]; + if(arr.size(0) == 1){ + //Row vector: [1,x] + shapes[0] = arr.size(1); + strides[0] = arr.stride(1); + this.offset = indexes[0].offset() * strides[0]; + } else { + //Column vector: [x, 1] + shapes[0] = 1; + strides[0] = arr.stride(0); + this.offset = indexes[0].offset() * strides[0]; } - - this.offsets = new long[arr.rank()]; - if(arr.isRowVector()) - this.offset = indexes[0].offset() * strides[1]; - else { + return true; + } else if(indexes[0] instanceof PointIndex && indexes[1] instanceof IntervalIndex){ + IntervalIndex i = (IntervalIndex)indexes[1]; + this.shapes = new long[1]; + this.strides = new long[1]; + this.offsets = new long[1]; + if(arr.size(0) == 1){ + //Row vector: [1,x] + shapes[0] = i.length(); + strides[0] = arr.stride(1) * indexes[1].stride(); + this.offset = indexes[1].offset() * arr.stride(1); + } else { + Preconditions.checkState(i.begin == 0 && i.end == 0, "Cannot get interval index along dimension 1 (begin=%s, end=%s) from array with shape %ndShape", + i.begin, i.end, arr); + //Column vector: [x, 1] + shapes[0] = 1; + strides[0] = arr.stride(0); this.offset = indexes[0].offset() * strides[0]; } return true; @@ -160,14 +182,10 @@ else if (indexes[i] instanceof NDArrayIndexAll) } if (indexes[0] instanceof PointIndex) { if (indexes.length > 1 && indexes[1] instanceof IntervalIndex) { - offset = indexes[1].offset(); - this.shapes = new long[2]; - shapes[0] = 1; - shapes[1] = indexes[1].length(); - this.strides = new long[2]; - strides[0] = 0; - strides[1] = indexes[1].stride(); - this.offsets = new long[2]; + this.shapes = new long[]{indexes[1].length()}; + this.strides = new long[]{indexes[1].stride() * arr.stride(1)}; + this.offsets = new long[]{indexes[1].offset() * arr.stride(0)}; + this.offset = indexes[1].offset() * arr.stride(1); return true; } } else if (indexes[0] instanceof IntervalIndex) { @@ -234,7 +252,7 @@ else if 
(indexes[i] instanceof NDArrayIndexAll) //specific easy case if (numSpecified < 1 && interval < 1 && newAxis < 1 && pointIndex > 0 && numAll > 0) { - int minDimensions = Math.max(arr.rank() - pointIndex, 2); + int minDimensions = arr.rank()-pointIndex; long[] shape = new long[minDimensions]; Arrays.fill(shape, 1); long[] stride = new long[minDimensions]; @@ -277,7 +295,6 @@ else if (indexes[i] instanceof NDArrayIndexAll) this.offsets = offsets; this.offset = offset; return true; - } //intervals and all @@ -499,39 +516,26 @@ else if (idx instanceof IntervalIndex && !(idx instanceof NDArrayIndexAll) } - - //fill in missing strides and shapes while (shapeIndex < shape.length) { - //scalar, should be 1 x 1 rather than the number of columns in the vector - if (Shape.isVector(shape)) { - accumShape.add(1L); - shapeIndex++; - } else - accumShape.add((long) shape[shapeIndex++]); + accumShape.add(shape[shapeIndex++]); } //fill in the rest of the offsets with zero - int delta = (shape.length <= 2 ? shape.length : shape.length - numPointIndexes); +// int delta = (shape.length <= 2 ? shape.length : shape.length - numPointIndexes); + int delta = shape.length - numPointIndexes; boolean needsFilledIn = accumShape.size() != accumStrides.size() && accumOffsets.size() != accumShape.size(); while (accumOffsets.size() < delta && needsFilledIn) accumOffsets.add(0L); - while (accumShape.size() < 2) { - if (Shape.isRowVectorShape(arr.shape())) - accumShape.add(0, 1L); - else - accumShape.add(1L); - } - while (strideIndex < accumShape.size()) { accumStrides.add((long) arr.stride(strideIndex++)); } - /** + /* * For each dimension * where we want to prepend a dimension * we need to add it at the index such that @@ -586,14 +590,7 @@ else if (idx instanceof IntervalIndex && !(idx instanceof NDArrayIndexAll) Collections.reverse(accumShape); } - if (arr.isMatrix() && indexes[0] instanceof PointIndex && indexes[1] instanceof IntervalIndex) { - this.shapes = new long[2]; - shapes[0] = 1; - IntervalIndex idx = (IntervalIndex) indexes[1]; - shapes[1] = idx.length(); - - } else - this.shapes = Longs.toArray(accumShape); + this.shapes = Longs.toArray(accumShape); boolean isColumnVector = Shape.isColumnVectorShape(this.shapes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/LossUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/LossUtil.java index 656302b044a7..9ac7e86cde38 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/LossUtil.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/LossUtil.java @@ -44,9 +44,9 @@ public static void applyMask(INDArray to, INDArray mask) { //Two possibilities exist: it's *per example* masking, or it's *per output* masking //These cases have different mask shapes. Per example: column vector. Per output: same shape as score array if (mask.isColumnVectorOrScalar()) { - to.muliColumnVector(mask); + to.muliColumnVector(mask.castTo(to.dataType())); } else if (Arrays.equals(to.shape(), mask.shape())) { - to.muli(mask); + to.muli(mask.castTo(to.dataType())); } else { throw new IllegalStateException("Invalid mask array: per-example masking should be a column vector, " + "per output masking arrays should be the same shape as the labels array. 
Mask shape: " diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossBinaryXENT.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossBinaryXENT.java index cfb7745272a5..ca04caf18f68 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossBinaryXENT.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossBinaryXENT.java @@ -129,6 +129,7 @@ private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation act if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype INDArray scoreArr; if (activationFn instanceof ActivationSoftmax) { @@ -197,6 +198,7 @@ public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype INDArray output = activationFn.getActivation(preOutput.dup(), true); if (clipEps > 0.0) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossCosineProximity.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossCosineProximity.java index b888a94883e3..12558c7875d0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossCosineProximity.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossCosineProximity.java @@ -55,6 +55,8 @@ public INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation acti if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype + /* mean of -(y.dot(yhat)/||y||*||yhat||) */ @@ -105,6 +107,7 @@ public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype INDArray yhat = activationFn.getActivation(preOutput.dup(), true); INDArray yL2norm = labels.norm2(1); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossFMeasure.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossFMeasure.java index ee4baa2d0325..25e7f3f65d91 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossFMeasure.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossFMeasure.java @@ -86,7 +86,7 @@ public LossFMeasure(@JsonProperty("beta") double beta) { @Override public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) { - + labels = labels.castTo(preOutput.dataType()); //No-op if already 
correct dtype double[] d = computeScoreNumDenom(labels, preOutput, activationFn, mask, average); double numerator = d[0]; double denominator = d[1]; @@ -148,6 +148,7 @@ public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivati @Override public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype double[] d = computeScoreNumDenom(labels, preOutput, activationFn, mask, false); double numerator = d[0]; double denominator = d[1]; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossHinge.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossHinge.java index 839b916225f7..cfe718f431f9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossHinge.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossHinge.java @@ -47,6 +47,7 @@ public INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation acti if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype /* y_hat is -1 or 1 hinge loss is max(0,1-y_hat*y) */ @@ -83,6 +84,8 @@ public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype + /* gradient is 0 if yhaty is >= 1 else gradient is gradient of the loss function = (1-yhaty) wrt preOutput = -y*derivative_of_yhat wrt preout diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossKLD.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossKLD.java index 764a3292a226..104f8b4381d1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossKLD.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossKLD.java @@ -51,6 +51,7 @@ private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation act if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype INDArray output = activationFn.getActivation(preOutput.dup(), true); // Clip output and labels to be between Nd4j.EPS_THREsHOLD and 1, i.e. 
a valid non-zero probability @@ -91,6 +92,7 @@ public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype INDArray output = activationFn.getActivation(preOutput.dup(), true); INDArray dLda = labels.div(output).negi(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL1.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL1.java index 256e115d80f9..4bcdddbcdee7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL1.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL1.java @@ -82,6 +82,7 @@ public INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation acti if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype INDArray scoreArr; INDArray output = activationFn.getActivation(preOutput.dup(), true); scoreArr = output.subi(labels); @@ -126,6 +127,7 @@ public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype INDArray output = activationFn.getActivation(preOutput.dup(), true); INDArray outSubLabels = output.sub(labels); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL2.java index 8b46367a7fc6..2744295b9e40 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL2.java @@ -80,6 +80,7 @@ protected INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation a if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype INDArray output = activationFn.getActivation(preOutput.dup(), true); INDArray scoreArr = output.rsubi(labels); scoreArr = scoreArr.muli(scoreArr); @@ -124,6 +125,7 @@ public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype INDArray output = activationFn.getActivation(preOutput.dup(), true); INDArray dLda = output.subi(labels).muli(2); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMAPE.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMAPE.java index abbf43a9701d..41e071d9d015 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMAPE.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMAPE.java @@ -81,6 +81,7 @@ public INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation acti if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype INDArray scoreArr; INDArray output = activationFn.getActivation(preOutput.dup(), true); scoreArr = output.rsubi(labels).divi(labels); @@ -126,6 +127,7 @@ public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype INDArray output = activationFn.getActivation(preOutput.dup(), true); INDArray actSubPredicted = labels.sub(output); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java index dc5df182302c..11f472126e13 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java @@ -106,6 +106,7 @@ private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation act if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype INDArray output = activationFn.getActivation(preOutput.dup(), true); if(activationFn instanceof ActivationSoftmax && softmaxClipEps > 0.0){ @@ -156,6 +157,7 @@ public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation } INDArray grad; INDArray output = activationFn.getActivation(preOutput.dup(), true); + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype if (activationFn instanceof ActivationSoftmax) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMSLE.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMSLE.java index 6d136f289be9..4569506b9054 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMSLE.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMSLE.java @@ -79,6 +79,7 @@ public INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation acti if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype INDArray scoreArr; INDArray output = activationFn.getActivation(preOutput.dup(), true); 
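The repeated labels = labels.castTo(preOutput.dataType()) lines added to the loss functions all follow one pattern: align the label dtype with the network output before any arithmetic. A small standalone illustration, assuming a recent ND4J build; the class name and arrays are invented for this sketch.

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class LabelCastSketch {
    public static void main(String[] args) {
        INDArray preOutput = Nd4j.rand(DataType.FLOAT, 3, 4);   // network output dtype
        INDArray labels = Nd4j.create(DataType.DOUBLE, 3, 4);   // labels arrived as double

        // Align dtypes first; castTo is a no-op if the dtype already matches
        labels = labels.castTo(preOutput.dataType());

        INDArray diff = preOutput.sub(labels);                  // both operands are FLOAT now
        System.out.println(diff.dataType());                    // FLOAT
    }
}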
scoreArr = Transforms.log(output.addi(1.0).divi(labels.add(1.0)), false); @@ -123,6 +124,7 @@ public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype INDArray output = activationFn.getActivation(preOutput.dup(), true); INDArray p1 = output.add(1.0); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMixtureDensity.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMixtureDensity.java index df95f5f737c9..f5f24667bc52 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMixtureDensity.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMixtureDensity.java @@ -214,6 +214,7 @@ public double computeScore(INDArray labels, INDArray preOutput, IActivation acti */ @Override public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype INDArray output = activationFn.getActivation(preOutput.dup(), false); MixtureDensityComponents mdc = extractComponents(output); INDArray scoreArr = negativeLogLikelihood(labels, mdc.alpha, mdc.mu, mdc.sigma); @@ -239,6 +240,7 @@ public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivati */ @Override public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype long nSamples = labels.size(0); INDArray output = activationFn.getActivation(preOutput.dup(), false); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMultiLabel.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMultiLabel.java index 4441b6b0a794..0bee4031e8f5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMultiLabel.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMultiLabel.java @@ -84,6 +84,7 @@ private void calculate(INDArray labels, INDArray preOutput, IActivation activati + " number of outputs (nOut = " + preOutput.size(1) + ") "); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype final INDArray postOutput = activationFn.getActivation(preOutput.dup(), true); final INDArray positive = labels; @@ -93,21 +94,21 @@ private void calculate(INDArray labels, INDArray preOutput, IActivation activati long examples = positive.size(0); for (int i = 0; i < examples; i++) { - final INDArray locCfn = postOutput.getRow(i); + final INDArray locCfn = postOutput.getRow(i, true); final long[] shape = locCfn.shape(); - final INDArray locPositive = positive.getRow(i); - final INDArray locNegative = negative.getRow(i); + final INDArray locPositive = positive.getRow(i, true); + final INDArray locNegative = negative.getRow(i, true); final Double locNormFactor = normFactor.getDouble(i); final int outSetSize = locNegative.sumNumber().intValue(); if(outSetSize == 0 || outSetSize == 
locNegative.columns()){ if (scoreOutput != null) { - scoreOutput.getRow(i).assign(0); + scoreOutput.getRow(i, true).assign(0); } if (gradientOutput != null) { - gradientOutput.getRow(i).assign(0); + gradientOutput.getRow(i, true).assign(0); } }else { final INDArray operandA = Nd4j.ones(shape[1], shape[0]).mmul(locCfn); @@ -122,15 +123,15 @@ private void calculate(INDArray labels, INDArray preOutput, IActivation activati if (scoreOutput != null) { if (mask != null) { final INDArray perLabel = classificationDifferences.sum(0); - LossUtil.applyMask(perLabel, mask.getRow(i)); - perLabel.sum(scoreOutput.getRow(i), 0); + LossUtil.applyMask(perLabel, mask.getRow(i, true)); + perLabel.sum(scoreOutput.getRow(i, true), 0); } else { - classificationDifferences.sum(scoreOutput.getRow(i), 0, 1); + classificationDifferences.sum(scoreOutput.getRow(i, true), 0, 1); } } if (gradientOutput != null) { - gradientOutput.getRow(i).assign(classificationDifferences.sum(true, 0).addi(classificationDifferences.sum(true,1).transposei().negi())); + gradientOutput.getRow(i, true).assign(classificationDifferences.sum(true, 0).addi(classificationDifferences.sum(true,1).transposei().negi())); } } } @@ -171,6 +172,7 @@ public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivati @Override public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype if (labels.size(1) != preOutput.size(1)) { throw new IllegalArgumentException( "Labels array numColumns (size(1) = " + labels.size(1) + ") does not match output layer" diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossPoisson.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossPoisson.java index a0418504207c..f0ba9887c370 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossPoisson.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossPoisson.java @@ -47,6 +47,7 @@ public INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation acti if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype /* mean of (yhat - y * log(yhat)) */ @@ -86,6 +87,7 @@ public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype INDArray yHat = activationFn.getActivation(preOutput.dup(), true); INDArray yDivyhat = labels.div(yHat); INDArray dLda = yDivyhat.rsubi(1); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSquaredHinge.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSquaredHinge.java index 7d8caab0e004..034c4b62fde3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSquaredHinge.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSquaredHinge.java @@ -48,6 +48,7 @@ public INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation acti if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype /* y_hat is -1 or 1 hinge loss is max(0,1-y_hat*y) */ @@ -85,6 +86,7 @@ public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype INDArray scoreArr = scoreArray(labels, preOutput, activationFn, mask); INDArray bitMaskRowCol = scoreArr.dup(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossWasserstein.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossWasserstein.java index f87edc63ee2d..40a618e2acc7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossWasserstein.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossWasserstein.java @@ -56,6 +56,7 @@ private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation act if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype INDArray output = activationFn.getActivation(preOutput.dup(), true); @@ -91,7 +92,7 @@ public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } - + labels = labels.castTo(preOutput.dataType()); //No-op if already correct dtype INDArray dLda = labels.div(labels.size(1)); if (mask != null && LossUtil.isPerOutputMasking(dLda, mask)) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/workspace/BaseWorkspaceMgr.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/workspace/BaseWorkspaceMgr.java index d3aae4c60d8d..b639258c1eb8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/workspace/BaseWorkspaceMgr.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/workspace/BaseWorkspaceMgr.java @@ -244,36 +244,12 @@ public INDArray validateArrayLocation(@NonNull T arrayType, @NonNull INDArray ar return array; } - @Override - public INDArray create(@NonNull T arrayType, @NonNull int[] shape) { - enforceExistsAndActive(arrayType); - return create(arrayType, shape, Nd4j.order()); - } - - @Override - public INDArray create(@NonNull T arrayType, @NonNull long... shape) { - return create(arrayType, Nd4j.dataType(), shape); - } - @Override public INDArray create(@NonNull T arrayType, @NonNull DataType dataType, @NonNull long... 
shape) { enforceExistsAndActive(arrayType); return create(arrayType, dataType, shape, Nd4j.order()); } - @Override - public INDArray create(@NonNull T arrayType, @NonNull int[] shape, @NonNull char order) { - enforceExistsAndActive(arrayType); - try(MemoryWorkspace ws = notifyScopeBorrowed(arrayType)){ - return Nd4j.create(shape, order); - } - } - - @Override - public INDArray create(@NonNull T arrayType, @NonNull long[] shape, @NonNull char order) { - return create(arrayType, Nd4j.dataType(), shape, order); - } - @Override public INDArray create(@NonNull T arrayType, @NonNull DataType dataType, @NonNull long[] shape, @NonNull char order) { enforceExistsAndActive(arrayType); @@ -282,34 +258,11 @@ public INDArray create(@NonNull T arrayType, @NonNull DataType dataType, @NonNul } } - @Override - public INDArray createUninitialized(@NonNull T arrayType, @NonNull int[] shape) { - return createUninitialized(arrayType, shape, Nd4j.order()); - } - - @Override - public INDArray createUninitialized(@NonNull T arrayType, @NonNull long... shape) { - return createUninitialized(arrayType, shape, Nd4j.order()); - } - @Override public INDArray createUninitialized(T arrayType, DataType dataType, long... shape){ return createUninitialized(arrayType, dataType, shape, Nd4j.order()); } - @Override - public INDArray createUninitialized(@NonNull T arrayType, @NonNull int[] shape, char order) { - enforceExistsAndActive(arrayType); - try(MemoryWorkspace ws = notifyScopeBorrowed(arrayType)){ - return Nd4j.createUninitialized(shape, order); - } - } - - @Override - public INDArray createUninitialized(@NonNull T arrayType, @NonNull long[] shape, char order) { - return createUninitialized(arrayType, Nd4j.dataType(), shape, order); - } - @Override public INDArray createUninitialized(@NonNull T arrayType, @NonNull DataType dataType, @NonNull long[] shape, char order) { enforceExistsAndActive(arrayType); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/workspace/WorkspaceMgr.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/workspace/WorkspaceMgr.java index 4bea6dec0ad2..af7359edfb12 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/workspace/WorkspaceMgr.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/workspace/WorkspaceMgr.java @@ -175,24 +175,6 @@ public interface WorkspaceMgr> { */ INDArray validateArrayLocation(T arrayType, INDArray array, boolean migrateIfInvalid, boolean exceptionIfDetached) throws ND4JWorkspaceException; - /** - * Create an array in the specified array type's workspace (or detached if none is specified). - * Equivalent to {@link org.nd4j.linalg.factory.Nd4j#create(int...)}, other than the array location - * @param arrayType Array type - * @param shape Shape - * @return Created arary - */ - INDArray create(T arrayType, int[] shape); - - /** - * Create an array in the specified array type's workspace (or detached if none is specified). - * Equivalent to {@link org.nd4j.linalg.factory.Nd4j#create(int...)}, other than the array location - * @param arrayType Array type - * @param shape Shape - * @return Created arary - */ - INDArray create(T arrayType, long... shape); - /** * Create an array in the specified array type's workspace (or detached if none is specified). 
* Equivalent to {@link org.nd4j.linalg.factory.Nd4j#create(int...)}, other than the array location @@ -203,28 +185,6 @@ public interface WorkspaceMgr> { */ INDArray create(T arrayType, DataType dataType, long... shape); - - /** - * Create an array in the specified array type's workspace (or detached if none is specified). - * Equivalent to {@link org.nd4j.linalg.factory.Nd4j#create(int[],char)}, other than the array location - * @param arrayType Array type - * @param shape Shape - * @param ordering Order of the array - * @return Created arary - */ - INDArray create(T arrayType, int[] shape, char ordering); - - /** - * Create an array in the specified array type's workspace (or detached if none is specified). - * Equivalent to {@link org.nd4j.linalg.factory.Nd4j#create(int[],char)}, other than the array location - * @param arrayType Array type - * @param shape Shape - * @param ordering Order of the array - * @return Created arary - */ - INDArray create(T arrayType, long[] shape, char ordering); - - /** * Create an array in the specified array type's workspace (or detached if none is specified). * Equivalent to {@link org.nd4j.linalg.factory.Nd4j#create(int[],char)}, other than the array location @@ -236,25 +196,6 @@ public interface WorkspaceMgr> { */ INDArray create(T arrayType, DataType dataType, long[] shape, char ordering); - - /** - * Create an uninitialized array in the specified array type's workspace (or detached if none is specified). - * Equivalent to {@link org.nd4j.linalg.factory.Nd4j#createUninitialized(int)} (int...)}, other than the array location - * @param arrayType Array type - * @param shape Shape - * @return Created array - */ - INDArray createUninitialized(T arrayType, int[] shape); - - /** - * Create an uninitialized array in the specified array type's workspace (or detached if none is specified). - * Equivalent to {@link org.nd4j.linalg.factory.Nd4j#createUninitialized(int)} (int...)}, other than the array location - * @param arrayType Array type - * @param shape Shape - * @return Created array - */ - INDArray createUninitialized(T arrayType, long... shape); - /** * Create an uninitialized array in the specified array type's workspace (or detached if none is specified). * Equivalent to {@link org.nd4j.linalg.factory.Nd4j#createUninitialized(int)} (int...)}, other than the array location @@ -265,26 +206,6 @@ public interface WorkspaceMgr> { */ INDArray createUninitialized(T arrayType, DataType dataType, long... shape); - /** - * Create an uninitialized array in the specified array type's workspace (or detached if none is specified). - * Equivalent to {@link org.nd4j.linalg.factory.Nd4j#createUninitialized(int[], char)}}, other than the array location - * @param arrayType Array type - * @param shape Shape - * @param order Order of the array - * @return Created array - */ - INDArray createUninitialized(T arrayType, int[] shape, char order); - - /** - * Create an uninitialized array in the specified array type's workspace (or detached if none is specified). - * Equivalent to {@link org.nd4j.linalg.factory.Nd4j#createUninitialized(int[], char)}}, other than the array location - * @param arrayType Array type - * @param shape Shape - * @param order Order of the array - * @return Created array - */ - INDArray createUninitialized(T arrayType, long[] shape, char order); - /** * Create an uninitialized array in the specified array type's workspace (or detached if none is specified). 
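The WorkspaceMgr hunks here remove every create/createUninitialized overload that inferred the data type from Nd4j.dataType(), leaving only the dtype-explicit variants. A compile-only sketch of what a caller looks like afterwards; the helper methods and their names are hypothetical, and a concrete WorkspaceMgr implementation (e.g. DL4J's LayerWorkspaceMgr) is assumed to exist elsewhere.

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.workspace.WorkspaceMgr;

public class WorkspaceCreateSketch {
    // Allocate an array in the given workspace with the same shape/dtype as a reference array.
    // The data type must now be passed explicitly; the old int[]/long[] overloads are gone.
    static <T extends Enum<T>> INDArray createLike(WorkspaceMgr<T> mgr, T arrayType, INDArray reference) {
        return mgr.create(arrayType, reference.dataType(), reference.shape(), Nd4j.order());
    }

    static <T extends Enum<T>> INDArray createUninitializedLike(WorkspaceMgr<T> mgr, T arrayType, INDArray reference) {
        return mgr.createUninitialized(arrayType, reference.dataType(), reference.shape(), Nd4j.order());
    }
}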
* Equivalent to {@link org.nd4j.linalg.factory.Nd4j#createUninitialized(int[], char)}}, other than the array location diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java index d5b8da7f17e5..091010e928b3 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java @@ -598,21 +598,27 @@ public INDArray concat(int dimension, INDArray... toConcat) { boolean allScalars = true; for (int i = 0; i < toConcat.length; i++) { + Preconditions.checkState(toConcat[i].rank() == outputShape.length, "Encountered different array ranks for concat: input[0].shape()=%ndShape, input[%s].shape()=%ndShape", + toConcat[0], i, toConcat[i]); + if (toConcat[i].isCompressed()) Nd4j.getCompressor().decompressi(toConcat[i]); - Preconditions.checkArgument(toConcat[i].dataType() == toConcat[0].dataType(), "All operands must have same data type"); + Preconditions.checkArgument(toConcat[i].dataType() == toConcat[0].dataType(), "All operands must have same data type: input 0 has type %s, input %s has type %s", + toConcat[0].dataType(), i, toConcat[i].dataType()); allScalars &= toConcat[i].rank() == 0; shapeInfoPointers.put(i, toConcat[i].shapeInfoDataBuffer().addressPointer()); dataPointers.put(i, toConcat[i].data().addressPointer()); sumAlongDim += toConcat[i].size(dimension); - for (int j = 0; j < toConcat[i].rank(); j++) + for (int j = 0; j < toConcat[i].rank(); j++) { + if (j != dimension && toConcat[i].size(j) != outputShape[j]) { throw new IllegalArgumentException( "Illegal concatenation at array " + i + " and shape element " + j); } + } //log.info("Shape[{}]: {}", i, Arrays.toString(toConcat[i].shapeInfoDataBuffer().asInt())); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index db264762f583..6127e5eff106 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative; @@ -1200,4 +1201,51 @@ public void exceptionThrown_WhenConf3DInvalid() { .build()); } } + + @Test + public void testLayerNormMixedOrders(){ + Nd4j.getRandom().setSeed(12345); + INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f'); + INDArray gain = Nd4j.rand(DataType.DOUBLE, 1, 8).dup('f'); + INDArray bias = Nd4j.rand(DataType.DOUBLE, 1, 8).dup('f'); + + INDArray outFF = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'f'); + INDArray outCC = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'c'); + INDArray outFC = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'c'); + INDArray outCF = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'f'); + + //F in, F out 
case + Nd4j.exec(DynamicCustomOp.builder("layer_norm") + .addInputs(input, gain, bias) + .addOutputs(outFF) + .addIntegerArguments(1) //Axis + .build()); + + //C in, C out case + Nd4j.exec(DynamicCustomOp.builder("layer_norm") + .addInputs(input.dup('c'), gain.dup('c'), bias.dup('c')) + .addOutputs(outCC) + .addIntegerArguments(1) //Axis + .build()); + + assertEquals(outFF, outCC); //OK + + //C in, F out case + outFF.assign(0); + Nd4j.exec(DynamicCustomOp.builder("layer_norm") + .addInputs(input.dup('c'), gain.dup('c'), bias.dup('c')) + .addOutputs(outCF) + .addIntegerArguments(1) //Axis + .build()); + assertEquals(outCC, outCF); //Fails here + + //F in, C out case + outFF.assign(0); + Nd4j.exec(DynamicCustomOp.builder("layer_norm") + .addInputs(input, gain, bias) + .addOutputs(outFC) + .addIntegerArguments(1) //Axis + .build()); + assertEquals(outCC, outFC); //Fails here + } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java index 9dc3d54709ff..1282a2bbc0c6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java @@ -136,7 +136,7 @@ public void testLoss2d() { //NOTE: both we and TF assume the inputs are normalized predictionsArr.diviColumnVector(predictionsArr.norm2(1)); labelsArr.diviColumnVector(labelsArr.norm2(1)); - expOut = predictionsArr.mul(labelsArr).sum(1).rsub(1.0); + expOut = predictionsArr.mul(labelsArr).sum(1).rsub(1.0).reshape(10,1); loss = sd.loss().cosineDistance("loss", labels, predictions, w, reduction, 1); break; case "hinge": @@ -227,7 +227,7 @@ public void testLoss2d() { loss = sd.loss().softmaxCrossEntropy("loss", labels, predictions, w, reduction, lblSmooth2); break; case "mpwse": - expOut = Nd4j.create(labelsArr.size(0)); + expOut = Nd4j.create(labelsArr.size(0), 1); double n = (double) labelsArr.size(1); for(int example = 0; example < labelsArr.size(0); example++){ for(int i = 0; i < labelsArr.size(1); i++){ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java index a5553715d2dd..d50a6f314340 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java @@ -818,7 +818,7 @@ public void testClipByNorm(){ norm2_1 = arr.norm2(1); assertEquals(Nd4j.ones(3), norm2_1); - INDArray scale = Nd4j.create(new double[]{1.1, 1.0, 0.9}, new int[]{3,1}); + INDArray scale = Nd4j.create(new double[]{1.1, 1.0, 0.9}, new int[]{3}); arr.muliColumnVector(scale); norm2_1 = arr.norm2(1); @@ -832,7 +832,7 @@ public void testClipByNorm(){ .build()); INDArray norm2_1b = out.norm2(1); - INDArray exp = Nd4j.create(new double[]{1.0, 1.0, norm2_1.getDouble(2)}, new int[]{3,1}); + INDArray exp = Nd4j.create(new double[]{1.0, 1.0, norm2_1.getDouble(2)}, new int[]{3}); assertEquals(exp, norm2_1b); } @@ -930,7 +930,7 @@ public void testClipByNorm0(){ INDArray norm2_0 = arr.norm2(0); arr.diviRowVector(norm2_0); - INDArray initNorm2 = Nd4j.create(new double[]{2.2, 2.1, 2.0, 1.9}, new int[]{1,4}); //Initial norm2s along dimension 0 + INDArray initNorm2 = 
Nd4j.create(new double[]{2.2, 2.1, 2.0, 1.9}, new int[]{4}); //Initial norm2s along dimension 0 arr.muliRowVector(initNorm2); norm2_0 = arr.norm2(0); @@ -983,8 +983,6 @@ public void testCumSum(){ {54, 42, 29, 15, 0} }); - INDArray axisArg = Nd4j.scalar(1); //Along dim 1 - for (boolean exclusive : new boolean[]{false, true}) { for (boolean reverse : new boolean[]{false, true}) { @@ -1006,7 +1004,7 @@ public void testCumSum(){ String err = OpValidation.validate(op); if(err != null){ // System.out.println(err); - failing.add(msg); + failing.add(msg + " (" + err + ")"); } } } @@ -1090,11 +1088,11 @@ public void testOneHot1(){ //Because it's on the diagonal, should be the same for all axis args... for( int i=-1; i<=0; i++ ) { - INDArray indicesArr = Nd4j.create(new double[]{0, 1, 2}); + INDArray indicesArr = Nd4j.createFromArray(0, 1, 2); int depth = 3; SameDiff sd = SameDiff.create(); - SDVariable indices = sd.var(indicesArr); + SDVariable indices = sd.constant(indicesArr); SDVariable oneHot = sd.oneHot(indices, depth, i, 1.0, 0.0, DataType.DOUBLE); INDArray exp = Nd4j.eye(3).castTo(DataType.DOUBLE); @@ -1131,10 +1129,10 @@ public void testOneHotOp(){ @Test public void testOneHot2() { - INDArray indicesArr = Nd4j.create(new double[]{0, 2, -1, 1}); + INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1); SameDiff sd = SameDiff.create(); - SDVariable indices = sd.var("indices", indicesArr); + SDVariable indices = sd.constant("indices", indicesArr); int depth = 3; int axis = -1; SDVariable oneHot = sd.oneHot("oneHot", indices, depth, axis, 5.0, 0.0, DataType.DOUBLE); @@ -1347,8 +1345,8 @@ public void testConfusionMatrix(){ SameDiff sd = SameDiff.create(); - SDVariable labels = sd.var("labels", Nd4j.create(new double[]{1, 2, 4}).castTo(dt)); - SDVariable predictions = sd.var("predictions", Nd4j.create(new double[]{2, 2, 4}).castTo(dt)); + SDVariable labels = sd.constant("labels", Nd4j.createFromArray(1, 2, 4)); + SDVariable predictions = sd.constant("predictions", Nd4j.createFromArray(2, 2, 4)); INDArray exp = Nd4j.create(new double[][]{ {0, 0, 0, 0, 0}, diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java index 75d8ef800fd0..1a2a77ada81c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java @@ -807,8 +807,8 @@ public void testAllAny() { String err = OpValidation.validate(new TestCase(sd) .gradientCheck(false) - .expected(all, Nd4j.create(new boolean[]{expAll[i]})) - .expected(any, Nd4j.create(new boolean[]{expAny[i]}))); + .expected(all, Nd4j.scalar(expAll[i])) + .expected(any, Nd4j.scalar(expAny[i]))); assertNull(err); } @@ -866,7 +866,7 @@ public void testIndexAccum() { reduce = sd.math().lastIndex(s, Conditions.greaterThan(0), dim); if (t == 0) exp = Nd4j.create(new double[]{2, 2, 2, 2}); else if (t == 1) exp = Nd4j.create(new double[]{3, 3, 3}); - else exp = Nd4j.create(new double[]{11}); + else exp = Nd4j.scalar(11.0); exp = exp.castTo(DataType.DOUBLE); name = "lastindex"; break; @@ -874,7 +874,7 @@ public void testIndexAccum() { reduce = sd.matchConditionCount("count", s, Conditions.greaterThan(0), false, dim); if (t == 0) exp = Nd4j.create(new double[]{3, 3, 3, 3}); else if (t == 1) exp = Nd4j.create(new double[]{4, 4, 4}); - else exp = Nd4j.create(new 
double[]{12}); + else exp = Nd4j.scalar(12.0); exp = exp.castTo(DataType.DOUBLE); name = "matchConditionCount"; break; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index dc35938bb173..f1030cfba161 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -931,7 +931,7 @@ public void testTransposeOp(){ OpTestCase op = new OpTestCase(new Transpose(arr, out)); INDArray exp = arr.transpose(); - op.expectedOutput(0, exp.dup('c')); + op.expectedOutput(0, exp.dup('f')); String err = OpValidation.validate(op); assertNull(err); } @@ -1518,13 +1518,12 @@ public void testGatherNdSingle() { SDVariable idxs = sameDiff.constant("idxs", arr2); SDVariable result = sameDiff.gatherNd(x, idxs); // build expected output array - INDArray expected = Nd4j.zeros(1, 3); + INDArray expected = Nd4j.zeros(3); for (int i=0; i<3; i++){ INDArray idx = arr2.get(point(i), NDArrayIndex.all()); - expected.get(NDArrayIndex.point(0), point(i)).assign( - arr1.get(point(idx.getInt(0)), + expected.putScalar(i, arr1.get(point(idx.getInt(0)), point(idx.getInt(1)), - point(idx.getInt(2)))); + point(idx.getInt(2))).getDouble(0)); } assertEquals(expected, result.eval()); } @@ -1734,7 +1733,7 @@ public void testStridedSliceShrinkAxisMask() { assertEquals(inArr.get(point(0), all(), all()), slice.getArr()); assertEquals(inArr.get(point(2), all(), all()), slice2.getArr()); - assertEquals(inArr.get(point(1), point(2), interval(1, 5)), slice3.getArr()); + assertEquals(inArr.get(point(1), point(2), interval(1, 5)).reshape(4), slice3.getArr()); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java index cba9b3cd082a..ddc6eef5ca6c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java @@ -936,7 +936,7 @@ public void testTransforms() { break; case 69: t = sd.rank(in).castTo(DataType.DOUBLE); - tc.expectedOutput(t.getVarName(), Nd4j.create(new double[]{ia.rank()})).gradientCheck(false); + tc.expectedOutput(t.getVarName(), Nd4j.scalar((double)ia.rank())).gradientCheck(false); break; case 70: t = sd.onesLike(in); @@ -1204,7 +1204,7 @@ public void testPairwiseTransforms() { String msg = "test: " + i + " - " + name; log.info("***** Starting test: {} *****", msg); - SDVariable loss = sd.mean("loss", t); + SDVariable loss = sd.mean("loss", t.castTo(DataType.DOUBLE)); sd.associateArrayWithVariable(ia, in1); sd.associateArrayWithVariable(ib, in2); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index 72bdbc8ddfd3..f97aba3a69f8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -181,11 +181,11 @@ public void testSum() { SameDiff sameDiff = SameDiff.create(); INDArray arr = 
Transforms.sigmoid(Nd4j.linspace(1, 4, 4, DataType.FLOAT)).reshape(1, 4); SDVariable x = sameDiff.var("x", arr); - SDVariable result = sameDiff.sum(x, 1); //[1,4].sum(1) == [1,1] + SDVariable result = sameDiff.sum(x, 1); //[1,4].sum(1) == [1] sameDiff.exec(Collections.emptyMap(), sameDiff.outputs()); - INDArray exp = Nd4j.scalar(arr.sumNumber().floatValue()); + INDArray exp = Nd4j.scalar(arr.sumNumber().floatValue()).reshape(1); INDArray resultArr = result.getArr(); assertEquals(exp, resultArr); } @@ -1047,7 +1047,7 @@ public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVa Map inputsSubset = new HashMap<>(); inputsSubset.put("y", inputs.get("y")); INDArray output = logisticGraph.exec(inputsSubset, Collections.singletonList("rsub")).get("rsub"); - INDArray assertion = Nd4j.create(new double[]{0, 0, 1, 0}); + INDArray assertion = Nd4j.create(new double[]{0, 0, 1, 0}, new int[]{4,1}); assertEquals(assertion, output); } @@ -1667,7 +1667,7 @@ public void testOnesLikeBackprop() { SDVariable out = sd.sum("oun", ones); INDArray outArr = sd.execAndEndResult(); - assertEquals(Nd4j.valueArrayOf(1, 12.0), outArr); + assertEquals(Nd4j.scalar(12.0), outArr); sd.execBackwards(Collections.emptyMap()); @@ -1769,7 +1769,7 @@ public void testPairwiseBooleanTransforms() { case 6: ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5)); ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5)); - t = sd.math().or(in1, in2); + t = sd.math().or(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL)); expOut = Transforms.or(ia, ib); break; case 7: @@ -1783,13 +1783,13 @@ public void testPairwiseBooleanTransforms() { case 9: ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5)); ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5)); - t = sd.math().and(in1, in2); + t = sd.math().and(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL)); expOut = Transforms.and(ia, ib); break; case 10: ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5)); ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5)); - t = sd.math().xor(in1, in2); + t = sd.math().xor(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL)); expOut = Transforms.xor(ia, ib); break; default: @@ -1819,7 +1819,7 @@ public void testBooleanChecks() { INDArray ia = Nd4j.randn(minibatch, nOut); SDVariable in1 = sd.var("in1", ia); - INDArray expOut = Nd4j.create(new boolean[]{true}); + INDArray expOut = Nd4j.scalar(true); SDVariable t; switch (i) { @@ -1974,14 +1974,14 @@ public void testSqueezeExpandChain() { @Test public void testConfusionMatrix() { - INDArray labels = Nd4j.create(new float[]{1, 2, 4}); - INDArray pred = Nd4j.create(new float[]{2, 2, 4}); - INDArray weights = Nd4j.create(new float[]{10, 100, 1000}); + INDArray labels = Nd4j.createFromArray(1, 2, 4); + INDArray pred = Nd4j.createFromArray(2, 2, 4); + INDArray weights = Nd4j.createFromArray(10, 100, 1000); Integer numClasses = 5; SameDiff sd = SameDiff.create(); - SDVariable labelsVar = sd.var("labels", labels); - SDVariable predictionsVar = sd.var("predictions", pred); - SDVariable weightsVar = sd.var("weights", weights); + SDVariable labelsVar = sd.constant("labels", labels); + SDVariable predictionsVar = sd.constant("predictions", pred); + SDVariable weightsVar = sd.constant("weights", weights); sd.math().confusionMatrix("cm", labelsVar, predictionsVar, numClasses, weightsVar); INDArray out = sd.execAndEndResult(); @@ -2267,20 +2267,20 @@ public void testGet() { INDArray arr = Nd4j.linspace(1, 100, 100).reshape('c', 
10L, 10L); SDVariable x = sd.var(arr); - INDArray expOut1 = arr.get(NDArrayIndex.point(4), NDArrayIndex.point(5)); + INDArray expOut1 = arr.get(NDArrayIndex.point(4), NDArrayIndex.point(5)).reshape(); SDVariable result1 = x.get(SDIndex.point(4), SDIndex.point(5)); assertEquals(expOut1, result1.eval()); - INDArray expOut2 = arr.get(NDArrayIndex.point(4), NDArrayIndex.all()); + INDArray expOut2 = arr.get(NDArrayIndex.point(4), NDArrayIndex.all()).reshape(10); SDVariable result2 = x.get(SDIndex.point(4), SDIndex.all()); assertEquals(expOut2, result2.eval()); - INDArray expOut3 = arr.get(NDArrayIndex.interval(3, 8)); + INDArray expOut3 = arr.get(NDArrayIndex.interval(3, 8)).reshape(5,10); SDVariable result3 = x.get(SDIndex.interval(3, 8)); assertEquals(expOut3, result3.eval()); - INDArray expOut4 = arr.get(NDArrayIndex.point(5), NDArrayIndex.interval(3, 8)); + INDArray expOut4 = arr.get(NDArrayIndex.point(5), NDArrayIndex.interval(3, 8)).reshape(5); SDVariable result4 = x.get(SDIndex.point(5), SDIndex.interval(3, 8)); assertEquals(expOut4, result4.eval()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java index 0a5df61150fe..7ec752a14f89 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java @@ -66,7 +66,7 @@ public void testEvaluationCustomBinaryThreshold() { e.eval(labels, probs); e05.eval(labels, probs); - e05v2.eval(labels.getColumn(1), probs.getColumn(1)); //"single output binary" case + e05v2.eval(labels.getColumn(1, true), probs.getColumn(1, true)); //"single output binary" case for (Evaluation e2 : new Evaluation[] {e05, e05v2}) { assertEquals(e.accuracy(), e2.accuracy(), 1e-6); @@ -102,7 +102,7 @@ public void testEvaluationCustomBinaryThreshold() { //Check the same thing, but the single binary output case: Evaluation e025v2 = new Evaluation(0.25); - e025v2.eval(labels.getColumn(1), probs.getColumn(1)); + e025v2.eval(labels.getColumn(1, true), probs.getColumn(1, true)); assertEquals(ex2.accuracy(), e025v2.accuracy(), 1e-6); assertEquals(ex2.f1(), e025v2.f1(), 1e-6); @@ -177,7 +177,7 @@ public void testEvaluationBinaryCustomThreshold() { EvaluationBinary eb05v2 = new EvaluationBinary(Nd4j.create(new double[] {0.5, 0.5}, new long[]{1,2})); for (int i = 0; i < nExamples; i++) { - eb05v2.eval(labels.getRow(i), probs.getRow(i)); + eb05v2.eval(labels.getRow(i, true), probs.getRow(i, true)); } for (EvaluationBinary eb2 : new EvaluationBinary[] {eb05, eb05v2}) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalTest.java index 4ee21d6034df..2ccef06bf881 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalTest.java @@ -91,39 +91,60 @@ public void testEval() { @Test public void testEval2() { - //Confusion matrix: - //actual 0 20 3 - //actual 1 10 5 - - Evaluation evaluation = new Evaluation(Arrays.asList("class0", "class1")); - INDArray predicted0 = Nd4j.create(new double[] {1, 0}, new long[]{1, 2}); - INDArray predicted1 = Nd4j.create(new double[] {0, 1}, new long[]{1, 2}); - INDArray actual0 = Nd4j.create(new double[] {1, 0}, new long[]{1, 2}); - INDArray actual1 = Nd4j.create(new double[] {0, 
1}, new long[]{1, 2}); - for (int i = 0; i < 20; i++) { - evaluation.eval(actual0, predicted0); - } + DataType dtypeBefore = Nd4j.defaultFloatingPointType(); + Evaluation first = null; + String sFirst = null; + try { + for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT}) { + Nd4j.setDefaultDataTypes(globalDtype, globalDtype.isFPType() ? globalDtype : DataType.DOUBLE); + for (DataType lpDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + + //Confusion matrix: + //actual 0 20 3 + //actual 1 10 5 + + Evaluation evaluation = new Evaluation(Arrays.asList("class0", "class1")); + INDArray predicted0 = Nd4j.create(new double[]{1, 0}, new long[]{1, 2}).castTo(lpDtype); + INDArray predicted1 = Nd4j.create(new double[]{0, 1}, new long[]{1, 2}).castTo(lpDtype); + INDArray actual0 = Nd4j.create(new double[]{1, 0}, new long[]{1, 2}).castTo(lpDtype); + INDArray actual1 = Nd4j.create(new double[]{0, 1}, new long[]{1, 2}).castTo(lpDtype); + for (int i = 0; i < 20; i++) { + evaluation.eval(actual0, predicted0); + } - for (int i = 0; i < 3; i++) { - evaluation.eval(actual0, predicted1); - } + for (int i = 0; i < 3; i++) { + evaluation.eval(actual0, predicted1); + } - for (int i = 0; i < 10; i++) { - evaluation.eval(actual1, predicted0); - } + for (int i = 0; i < 10; i++) { + evaluation.eval(actual1, predicted0); + } - for (int i = 0; i < 5; i++) { - evaluation.eval(actual1, predicted1); - } + for (int i = 0; i < 5; i++) { + evaluation.eval(actual1, predicted1); + } - assertEquals(20, evaluation.truePositives().get(0), 0); - assertEquals(3, evaluation.falseNegatives().get(0), 0); - assertEquals(10, evaluation.falsePositives().get(0), 0); - assertEquals(5, evaluation.trueNegatives().get(0), 0); + assertEquals(20, evaluation.truePositives().get(0), 0); + assertEquals(3, evaluation.falseNegatives().get(0), 0); + assertEquals(10, evaluation.falsePositives().get(0), 0); + assertEquals(5, evaluation.trueNegatives().get(0), 0); - assertEquals((20.0 + 5) / (20 + 3 + 10 + 5), evaluation.accuracy(), 1e-6); + assertEquals((20.0 + 5) / (20 + 3 + 10 + 5), evaluation.accuracy(), 1e-6); - System.out.println(evaluation.confusionToString()); + String s = evaluation.stats(); + + if(first == null) { + first = evaluation; + sFirst = s; + } else { + assertEquals(first, evaluation); + assertEquals(sFirst, s); + } + } + } + } finally { + Nd4j.setDefaultDataTypes(dtypeBefore, dtypeBefore); + } } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationBinaryTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationBinaryTest.java index d489f5dbf9e7..370499e2b52c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationBinaryTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationBinaryTest.java @@ -20,6 +20,7 @@ import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.EvaluationBinary; import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.factory.Nd4j; @@ -45,65 +46,86 @@ public char ordering() { @Test public void testEvaluationBinary() { //Compare EvaluationBinary to Evaluation class - - Nd4j.getRandom().setSeed(12345); - - int nExamples = 50; - int nOut = 4; - int[] shape = {nExamples, nOut}; - - INDArray labels = 
Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(shape), 0.5)); - - INDArray predicted = Nd4j.rand(shape); - INDArray binaryPredicted = predicted.gt(0.5); - - EvaluationBinary eb = new EvaluationBinary(); - eb.eval(labels, predicted); - - System.out.println(eb.stats()); - - double eps = 1e-6; - for (int i = 0; i < nOut; i++) { - INDArray lCol = labels.getColumn(i); - INDArray pCol = predicted.getColumn(i); - INDArray bpCol = binaryPredicted.getColumn(i); - - int countCorrect = 0; - int tpCount = 0; - int tnCount = 0; - for (int j = 0; j < lCol.length(); j++) { - if (lCol.getDouble(j) == bpCol.getDouble(j)) { - countCorrect++; - if (lCol.getDouble(j) == 1) { - tpCount++; - } else { - tnCount++; + DataType dtypeBefore = Nd4j.defaultFloatingPointType(); + EvaluationBinary first = null; + String sFirst = null; + try { + for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT}) { + Nd4j.setDefaultDataTypes(globalDtype, globalDtype.isFPType() ? globalDtype : DataType.DOUBLE); + for (DataType lpDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + + Nd4j.getRandom().setSeed(12345); + + int nExamples = 50; + int nOut = 4; + long[] shape = {nExamples, nOut}; + + INDArray labels = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(lpDtype, shape), 0.5)); + + INDArray predicted = Nd4j.rand(lpDtype, shape); + INDArray binaryPredicted = predicted.gt(0.5); + + EvaluationBinary eb = new EvaluationBinary(); + eb.eval(labels, predicted); + + //System.out.println(eb.stats()); + + double eps = 1e-6; + for (int i = 0; i < nOut; i++) { + INDArray lCol = labels.getColumn(i,true); + INDArray pCol = predicted.getColumn(i,true); + INDArray bpCol = binaryPredicted.getColumn(i,true); + + int countCorrect = 0; + int tpCount = 0; + int tnCount = 0; + for (int j = 0; j < lCol.length(); j++) { + if (lCol.getDouble(j) == bpCol.getDouble(j)) { + countCorrect++; + if (lCol.getDouble(j) == 1) { + tpCount++; + } else { + tnCount++; + } + } + } + double acc = countCorrect / (double) lCol.length(); + + Evaluation e = new Evaluation(); + e.eval(lCol, pCol); + + assertEquals(acc, eb.accuracy(i), eps); + assertEquals(e.accuracy(), eb.scoreForMetric(ACCURACY, i), eps); + assertEquals(e.precision(1), eb.scoreForMetric(PRECISION, i), eps); + assertEquals(e.recall(1), eb.scoreForMetric(RECALL, i), eps); + assertEquals(e.f1(1), eb.scoreForMetric(F1, i), eps); + assertEquals(e.falseAlarmRate(), eb.scoreForMetric(FAR, i), eps); + assertEquals(e.falsePositiveRate(1), eb.falsePositiveRate(i), eps); + + + assertEquals(tpCount, eb.truePositives(i)); + assertEquals(tnCount, eb.trueNegatives(i)); + + assertEquals((int) e.truePositives().get(1), eb.truePositives(i)); + assertEquals((int) e.trueNegatives().get(1), eb.trueNegatives(i)); + assertEquals((int) e.falsePositives().get(1), eb.falsePositives(i)); + assertEquals((int) e.falseNegatives().get(1), eb.falseNegatives(i)); + + assertEquals(nExamples, eb.totalCount(i)); + + String s = eb.stats(); + if(first == null) { + first = eb; + sFirst = s; + } else { + assertEquals(first, eb); + assertEquals(sFirst, s); + } } } } - double acc = countCorrect / (double) lCol.length(); - - Evaluation e = new Evaluation(); - e.eval(lCol, pCol); - - assertEquals(acc, eb.accuracy(i), eps); - assertEquals(e.accuracy(), eb.scoreForMetric(ACCURACY, i), eps); - assertEquals(e.precision(1), eb.scoreForMetric(PRECISION, i), eps); - assertEquals(e.recall(1), eb.scoreForMetric(RECALL, i), eps); 
- assertEquals(e.f1(1), eb.scoreForMetric(F1, i), eps); - assertEquals(e.falseAlarmRate(), eb.scoreForMetric(FAR, i), eps); - assertEquals(e.falsePositiveRate(1), eb.falsePositiveRate(i), eps); - - - assertEquals(tpCount, eb.truePositives(i)); - assertEquals(tnCount, eb.trueNegatives(i)); - - assertEquals((int) e.truePositives().get(1), eb.truePositives(i)); - assertEquals((int) e.trueNegatives().get(1), eb.trueNegatives(i)); - assertEquals((int) e.falsePositives().get(1), eb.falsePositives(i)); - assertEquals((int) e.falseNegatives().get(1), eb.falseNegatives(i)); - - assertEquals(nExamples, eb.totalCount(i)); + } finally { + Nd4j.setDefaultDataTypes(dtypeBefore, dtypeBefore); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java index 99001a17ecc1..2c4de944223e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java @@ -17,6 +17,7 @@ package org.nd4j.evaluation; import org.junit.Test; +import org.nd4j.evaluation.classification.EvaluationBinary; import org.nd4j.evaluation.classification.EvaluationCalibration; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; @@ -30,6 +31,7 @@ import java.util.Random; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; /** * Created by Alex on 05/07/2017. @@ -48,63 +50,92 @@ public char ordering() { @Test public void testReliabilityDiagram() { + DataType dtypeBefore = Nd4j.defaultFloatingPointType(); + EvaluationCalibration first = null; + String sFirst = null; + try { + for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT}) { + Nd4j.setDefaultDataTypes(globalDtype, globalDtype.isFPType() ? 
globalDtype : DataType.DOUBLE); + for (DataType lpDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { - //Test using 5 bins - format: binary softmax-style output - //Note: no values fall in fourth bin - //[0, 0.2) - INDArray bin0Probs = Nd4j.create(new double[][] {{1.0, 0.0}, {0.9, 0.1}, {0.85, 0.15}}); - INDArray bin0Labels = Nd4j.create(new double[][] {{1.0, 0.0}, {1.0, 0.0}, {0.0, 1.0}}); + //Test using 5 bins - format: binary softmax-style output + //Note: no values fall in fourth bin - //[0.2, 0.4) - INDArray bin1Probs = Nd4j.create(new double[][] {{0.8, 0.2}, {0.7, 0.3}, {0.65, 0.35}}); - INDArray bin1Labels = Nd4j.create(new double[][] {{1.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}}); + //[0, 0.2) + INDArray bin0Probs = Nd4j.create(new double[][]{{1.0, 0.0}, {0.9, 0.1}, {0.85, 0.15}}).castTo(lpDtype); + INDArray bin0Labels = Nd4j.create(new double[][]{{1.0, 0.0}, {1.0, 0.0}, {0.0, 1.0}}).castTo(lpDtype); - //[0.4, 0.6) - INDArray bin2Probs = Nd4j.create(new double[][] {{0.59, 0.41}, {0.5, 0.5}, {0.45, 0.55}}); - INDArray bin2Labels = Nd4j.create(new double[][] {{1.0, 0.0}, {0.0, 1.0}, {0.0, 1.0}}); + //[0.2, 0.4) + INDArray bin1Probs = Nd4j.create(new double[][]{{0.80, 0.20}, {0.7, 0.3}, {0.65, 0.35}}).castTo(lpDtype); + INDArray bin1Labels = Nd4j.create(new double[][]{{1.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}}).castTo(lpDtype); - //[0.6, 0.8) - //Empty + //[0.4, 0.6) + INDArray bin2Probs = Nd4j.create(new double[][]{{0.59, 0.41}, {0.5, 0.5}, {0.45, 0.55}}).castTo(lpDtype); + INDArray bin2Labels = Nd4j.create(new double[][]{{1.0, 0.0}, {0.0, 1.0}, {0.0, 1.0}}).castTo(lpDtype); - //[0.8, 1.0] - INDArray bin4Probs = Nd4j.create(new double[][] {{0.0, 1.0}, {0.1, 0.9}}); - INDArray bin4Labels = Nd4j.create(new double[][] {{0.0, 1.0}, {0.0, 1.0}}); + //[0.6, 0.8) + //Empty + //[0.8, 1.0] + INDArray bin4Probs = Nd4j.create(new double[][]{{0.0, 1.0}, {0.1, 0.9}}).castTo(lpDtype); + INDArray bin4Labels = Nd4j.create(new double[][]{{0.0, 1.0}, {0.0, 1.0}}).castTo(lpDtype); - INDArray probs = Nd4j.vstack(bin0Probs, bin1Probs, bin2Probs, bin4Probs); - INDArray labels = Nd4j.vstack(bin0Labels, bin1Labels, bin2Labels, bin4Labels); - EvaluationCalibration ec = new EvaluationCalibration(5, 5); - ec.eval(labels, probs); - - for (int i = 0; i < 1; i++) { - double[] avgBinProbsClass; - double[] fracPos; - if (i == 0) { - //Class 0: needs to be handled a little differently, due to threshold/edge cases (0.8, etc) - avgBinProbsClass = new double[] {0.05, (0.59 + 0.5 + 0.45) / 3, (0.65 + 0.7) / 2.0, - (0.8 + 0.85 + 0.9 + 1.0) / 4}; - fracPos = new double[] {0.0 / 2.0, 1.0 / 3, 1.0 / 2, 3.0 / 4}; - } else { - avgBinProbsClass = new double[] {bin0Probs.getColumn(i).meanNumber().doubleValue(), - bin1Probs.getColumn(i).meanNumber().doubleValue(), - bin2Probs.getColumn(i).meanNumber().doubleValue(), - bin4Probs.getColumn(i).meanNumber().doubleValue()}; - - fracPos = new double[] {bin0Labels.getColumn(i).sumNumber().doubleValue() / bin0Labels.size(0), - bin1Labels.getColumn(i).sumNumber().doubleValue() / bin1Labels.size(0), - bin2Labels.getColumn(i).sumNumber().doubleValue() / bin2Labels.size(0), - bin4Labels.getColumn(i).sumNumber().doubleValue() / bin4Labels.size(0)}; - } + INDArray probs = Nd4j.vstack(bin0Probs, bin1Probs, bin2Probs, bin4Probs); + INDArray labels = Nd4j.vstack(bin0Labels, bin1Labels, bin2Labels, bin4Labels); + + EvaluationCalibration ec = new EvaluationCalibration(5, 5); + ec.eval(labels, probs); + + for (int i = 0; i < 1; i++) { + double[] avgBinProbsClass; + double[] fracPos; + if (i 
== 0) { + //Class 0: needs to be handled a little differently, due to threshold/edge cases (0.8, etc) + avgBinProbsClass = new double[]{0.05, (0.59 + 0.5 + 0.45) / 3, (0.65 + 0.7) / 2.0, + (0.8 + 0.85 + 0.9 + 1.0) / 4}; + fracPos = new double[]{0.0 / 2.0, 1.0 / 3, 1.0 / 2, 3.0 / 4}; + } else { + avgBinProbsClass = new double[]{bin0Probs.getColumn(i).meanNumber().doubleValue(), + bin1Probs.getColumn(i).meanNumber().doubleValue(), + bin2Probs.getColumn(i).meanNumber().doubleValue(), + bin4Probs.getColumn(i).meanNumber().doubleValue()}; + + fracPos = new double[]{bin0Labels.getColumn(i).sumNumber().doubleValue() / bin0Labels.size(0), + bin1Labels.getColumn(i).sumNumber().doubleValue() / bin1Labels.size(0), + bin2Labels.getColumn(i).sumNumber().doubleValue() / bin2Labels.size(0), + bin4Labels.getColumn(i).sumNumber().doubleValue() / bin4Labels.size(0)}; + } - org.nd4j.evaluation.curves.ReliabilityDiagram rd = ec.getReliabilityDiagram(i); + org.nd4j.evaluation.curves.ReliabilityDiagram rd = ec.getReliabilityDiagram(i); - double[] x = rd.getMeanPredictedValueX(); - double[] y = rd.getFractionPositivesY(); + double[] x = rd.getMeanPredictedValueX(); + double[] y = rd.getFractionPositivesY(); - assertArrayEquals(avgBinProbsClass, x, 1e-6); - assertArrayEquals(fracPos, y, 1e-6); + assertArrayEquals(avgBinProbsClass, x, 1e-3); + assertArrayEquals(fracPos, y, 1e-3); + + String s = ec.stats(); + if(first == null) { + first = ec; + sFirst = s; + } else { +// assertEquals(first, ec); + assertEquals(sFirst, s); + assertTrue(first.getRDiagBinPosCount().equalsWithEps(ec.getRDiagBinPosCount(), lpDtype == DataType.HALF ? 1e-3 : 1e-5)); //Lower precision due to float + assertTrue(first.getRDiagBinTotalCount().equalsWithEps(ec.getRDiagBinTotalCount(), lpDtype == DataType.HALF ? 1e-3 : 1e-5)); + assertTrue(first.getRDiagBinSumPredictions().equalsWithEps(ec.getRDiagBinSumPredictions(), lpDtype == DataType.HALF ? 1e-3 : 1e-5)); + assertArrayEquals(first.getLabelCountsEachClass(), ec.getLabelCountsEachClass()); + assertArrayEquals(first.getPredictionCountsEachClass(), ec.getPredictionCountsEachClass()); + assertTrue(first.getProbHistogramOverall().equalsWithEps(ec.getProbHistogramOverall(), lpDtype == DataType.HALF ? 1e-3 : 1e-5)); + assertTrue(first.getProbHistogramByLabelClass().equalsWithEps(ec.getProbHistogramByLabelClass(), lpDtype == DataType.HALF ?
1e-3 : 1e-5)); + } + } + } + } + } finally { + Nd4j.setDefaultDataTypes(dtypeBefore, dtypeBefore); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCBinaryTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCBinaryTest.java index 9e7a2fb474d5..717cd139693a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCBinaryTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCBinaryTest.java @@ -17,10 +17,13 @@ package org.nd4j.evaluation; import org.junit.Test; +import org.nd4j.evaluation.classification.EvaluationBinary; import org.nd4j.evaluation.classification.ROC; import org.nd4j.evaluation.classification.ROCBinary; import org.nd4j.evaluation.curves.PrecisionRecallCurve; +import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.factory.Nd4j; @@ -46,63 +49,100 @@ public char ordering() { public void testROCBinary() { //Compare ROCBinary to ROC class - Nd4j.getRandom().setSeed(12345); - - int nExamples = 50; - int nOut = 4; - int[] shape = {nExamples, nOut}; - - for (int thresholdSteps : new int[] {30, 0}) { //0 == exact - - INDArray labels = - Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(shape), 0.5)); - - INDArray predicted = Nd4j.rand(shape); - INDArray binaryPredicted = predicted.gt(0.5); - - ROCBinary rb = new ROCBinary(thresholdSteps); - - for (int xe = 0; xe < 2; xe++) { - rb.eval(labels, predicted); - - System.out.println(rb.stats()); - - double eps = 1e-6; - for (int i = 0; i < nOut; i++) { - INDArray lCol = labels.getColumn(i); - INDArray pCol = predicted.getColumn(i); - - - ROC r = new ROC(thresholdSteps); - r.eval(lCol, pCol); - - double aucExp = r.calculateAUC(); - double auc = rb.calculateAUC(i); - - assertEquals(aucExp, auc, eps); - - long apExp = r.getCountActualPositive(); - long ap = rb.getCountActualPositive(i); - assertEquals(ap, apExp); - - long anExp = r.getCountActualNegative(); - long an = rb.getCountActualNegative(i); - assertEquals(anExp, an); - - PrecisionRecallCurve pExp = r.getPrecisionRecallCurve(); - PrecisionRecallCurve p = rb.getPrecisionRecallCurve(i); - - assertEquals(pExp, p); + DataType dtypeBefore = Nd4j.defaultFloatingPointType(); + ROCBinary first30 = null; + ROCBinary first0 = null; + String sFirst30 = null; + String sFirst0 = null; + try { + for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT}) { +// for (DataType globalDtype : new DataType[]{DataType.HALF}) { + Nd4j.setDefaultDataTypes(globalDtype, globalDtype.isFPType() ? 
globalDtype : DataType.DOUBLE); + for (DataType lpDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + String msg = "globalDtype=" + globalDtype + ", labelPredictionsDtype=" + lpDtype; + + int nExamples = 50; + int nOut = 4; + long[] shape = {nExamples, nOut}; + + for (int thresholdSteps : new int[]{30, 0}) { //0 == exact + + Nd4j.getRandom().setSeed(12345); + INDArray labels = + Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(DataType.DOUBLE, shape), 0.5)).castTo(lpDtype); + + Nd4j.getRandom().setSeed(12345); + INDArray predicted = Nd4j.rand(DataType.DOUBLE, shape).castTo(lpDtype); + + ROCBinary rb = new ROCBinary(thresholdSteps); + + for (int xe = 0; xe < 2; xe++) { + rb.eval(labels, predicted); + + //System.out.println(rb.stats()); + + double eps = lpDtype == DataType.HALF ? 1e-2 : 1e-6; + for (int i = 0; i < nOut; i++) { + INDArray lCol = labels.getColumn(i, true); + INDArray pCol = predicted.getColumn(i, true); + + + ROC r = new ROC(thresholdSteps); + r.eval(lCol, pCol); + + double aucExp = r.calculateAUC(); + double auc = rb.calculateAUC(i); + + assertEquals(msg, aucExp, auc, eps); + + long apExp = r.getCountActualPositive(); + long ap = rb.getCountActualPositive(i); + assertEquals(msg, ap, apExp); + + long anExp = r.getCountActualNegative(); + long an = rb.getCountActualNegative(i); + assertEquals(anExp, an); + + PrecisionRecallCurve pExp = r.getPrecisionRecallCurve(); + PrecisionRecallCurve p = rb.getPrecisionRecallCurve(i); + + assertEquals(msg, pExp, p); + } + + String s = rb.stats(); + + if(thresholdSteps == 0){ + if(first0 == null) { + first0 = rb; + sFirst0 = s; + } else if(lpDtype != DataType.HALF) { //Precision issues with FP16 + assertEquals(msg, sFirst0, s); + assertEquals(first0, rb); + } + } else { + if(first30 == null) { + first30 = rb; + sFirst30 = s; + } else if(lpDtype != DataType.HALF) { //Precision issues with FP16 + assertEquals(msg, sFirst30, s); + assertEquals(first30, rb); + } + } + +// rb.reset(); + rb = new ROCBinary(thresholdSteps); + } + } } - - rb.reset(); } + } finally { + Nd4j.setDefaultDataTypes(dtypeBefore, dtypeBefore); } } @Test public void testRocBinaryMerging() { - for (int nSteps : new int[] {30, 0}) { //0 == exact + for (int nSteps : new int[]{30, 0}) { //0 == exact int nOut = 4; int[] shape1 = {30, nOut}; int[] shape2 = {50, nOut}; @@ -133,21 +173,21 @@ public void testRocBinaryMerging() { @Test public void testROCBinaryPerOutputMasking() { - for (int nSteps : new int[] {30, 0}) { //0 == exact + for (int nSteps : new int[]{30, 0}) { //0 == exact //Here: we'll create a test array, then insert some 'masked out' values, and ensure we get the same results - INDArray mask = Nd4j.create(new double[][] {{1, 1, 1}, {0, 1, 1}, {1, 0, 1}, {1, 1, 0}, {1, 1, 1}}); + INDArray mask = Nd4j.create(new double[][]{{1, 1, 1}, {0, 1, 1}, {1, 0, 1}, {1, 1, 0}, {1, 1, 1}}); - INDArray labels = Nd4j.create(new double[][] {{0, 1, 0}, {1, 1, 0}, {0, 1, 1}, {0, 0, 1}, {1, 1, 1}}); + INDArray labels = Nd4j.create(new double[][]{{0, 1, 0}, {1, 1, 0}, {0, 1, 1}, {0, 0, 1}, {1, 1, 1}}); //Remove the 1 masked value for each column - INDArray labelsExMasked = Nd4j.create(new double[][] {{0, 1, 0}, {0, 1, 0}, {0, 0, 1}, {1, 1, 1}}); + INDArray labelsExMasked = Nd4j.create(new double[][]{{0, 1, 0}, {0, 1, 0}, {0, 0, 1}, {1, 1, 1}}); - INDArray predicted = Nd4j.create(new double[][] {{0.9, 0.4, 0.6}, {0.2, 0.8, 0.4}, {0.6, 0.1, 0.1}, - {0.3, 0.7, 0.2}, {0.8, 0.6, 0.6}}); + INDArray predicted = Nd4j.create(new double[][]{{0.9, 
0.4, 0.6}, {0.2, 0.8, 0.4}, {0.6, 0.1, 0.1}, + {0.3, 0.7, 0.2}, {0.8, 0.6, 0.6}}); INDArray predictedExMasked = Nd4j.create( - new double[][] {{0.9, 0.4, 0.6}, {0.6, 0.8, 0.4}, {0.3, 0.7, 0.1}, {0.8, 0.6, 0.6}}); + new double[][]{{0.9, 0.4, 0.6}, {0.6, 0.8, 0.4}, {0.3, 0.7, 0.1}, {0.8, 0.6, 0.6}}); ROCBinary rbMasked = new ROCBinary(nSteps); rbMasked.eval(labels, predicted, mask); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java index 581f126393a5..4d01d67eacf5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java @@ -17,8 +17,10 @@ package org.nd4j.evaluation; import org.junit.Test; +import org.nd4j.evaluation.classification.EvaluationCalibration; import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -82,33 +84,56 @@ public void testPerfectPredictions() { @Test public void testKnownValues() { - double[][] labelsD = new double[][] {{1, 2, 3}, {0.1, 0.2, 0.3}, {6, 5, 4}}; - double[][] predictedD = new double[][] {{2.5, 3.2, 3.8}, {2.15, 1.3, -1.2}, {7, 4.5, 3}}; - - double[] expMSE = {2.484166667, 0.966666667, 1.296666667}; - double[] expMAE = {1.516666667, 0.933333333, 1.1}; - double[] expRSE = {0.368813923, 0.246598639, 0.530937216}; - double[] expCorrs = {0.997013483, 0.968619605, 0.915603032}; - double[] expR2 = {0.63118608, 0.75340136 , 0.46906278}; - - INDArray labels = Nd4j.create(labelsD); - INDArray predicted = Nd4j.create(predictedD); - - RegressionEvaluation eval = new RegressionEvaluation(3); - - for (int xe = 0; xe < 2; xe++) { - eval.eval(labels, predicted); - - for (int col = 0; col < 3; col++) { - assertEquals(expMSE[col], eval.meanSquaredError(col), 1e-5); - assertEquals(expMAE[col], eval.meanAbsoluteError(col), 1e-5); - assertEquals(Math.sqrt(expMSE[col]), eval.rootMeanSquaredError(col), 1e-5); - assertEquals(expRSE[col], eval.relativeSquaredError(col), 1e-5); - assertEquals(expCorrs[col], eval.pearsonCorrelation(col), 1e-5); - assertEquals(expR2[col], eval.rSquared(col), 1e-5); - } - eval.reset(); + DataType dtypeBefore = Nd4j.defaultFloatingPointType(); + RegressionEvaluation first = null; + String sFirst = null; + try { + for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT}) { + Nd4j.setDefaultDataTypes(globalDtype, globalDtype.isFPType() ? 
globalDtype : DataType.DOUBLE); + for (DataType lpDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + + double[][] labelsD = new double[][]{{1, 2, 3}, {0.1, 0.2, 0.3}, {6, 5, 4}}; + double[][] predictedD = new double[][]{{2.5, 3.2, 3.8}, {2.15, 1.3, -1.2}, {7, 4.5, 3}}; + + double[] expMSE = {2.484166667, 0.966666667, 1.296666667}; + double[] expMAE = {1.516666667, 0.933333333, 1.1}; + double[] expRSE = {0.368813923, 0.246598639, 0.530937216}; + double[] expCorrs = {0.997013483, 0.968619605, 0.915603032}; + double[] expR2 = {0.63118608, 0.75340136, 0.46906278}; + + INDArray labels = Nd4j.create(labelsD).castTo(lpDtype); + INDArray predicted = Nd4j.create(predictedD).castTo(lpDtype); + + RegressionEvaluation eval = new RegressionEvaluation(3); + + for (int xe = 0; xe < 2; xe++) { + eval.eval(labels, predicted); + + for (int col = 0; col < 3; col++) { + assertEquals(expMSE[col], eval.meanSquaredError(col), lpDtype == DataType.HALF ? 1e-2 : 1e-4); + assertEquals(expMAE[col], eval.meanAbsoluteError(col), lpDtype == DataType.HALF ? 1e-2 : 1e-4); + assertEquals(Math.sqrt(expMSE[col]), eval.rootMeanSquaredError(col), lpDtype == DataType.HALF ? 1e-2 : 1e-4); + assertEquals(expRSE[col], eval.relativeSquaredError(col), lpDtype == DataType.HALF ? 1e-2 : 1e-4); + assertEquals(expCorrs[col], eval.pearsonCorrelation(col), lpDtype == DataType.HALF ? 1e-2 : 1e-4); + assertEquals(expR2[col], eval.rSquared(col), lpDtype == DataType.HALF ? 1e-2 : 1e-4); + } + + String s = eval.stats(); + if(first == null) { + first = eval; + sFirst = s; + } else if(lpDtype != DataType.HALF) { //Precision issues with FP16 + assertEquals(sFirst, s); + assertEquals(first, eval); + } + + eval = new RegressionEvaluation(3); + } + } + } + } finally { + Nd4j.setDefaultDataTypes(dtypeBefore, dtypeBefore); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java index 3e847d2bd04c..7dcfe6a350e2 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java @@ -229,10 +229,10 @@ public List processSubgraph(SameDiff sd, SubGraph subGraph) { // System.out.println(softmax); // System.out.println(Arrays.toString(softmax.data().asFloat())); - INDArray exp0 = Nd4j.createFromArray(0.99860954f, 0.0013904407f).reshape(1, 2); - INDArray exp1 = Nd4j.createFromArray(0.0005442508f, 0.99945575f).reshape(1, 2); - INDArray exp2 = Nd4j.createFromArray(0.9987967f, 0.0012033002f).reshape(1, 2); - INDArray exp3 = Nd4j.createFromArray(0.97409827f, 0.025901746f).reshape(1, 2); + INDArray exp0 = Nd4j.createFromArray(0.99860954f, 0.0013904407f); + INDArray exp1 = Nd4j.createFromArray(0.0005442508f, 0.99945575f); + INDArray exp2 = Nd4j.createFromArray(0.9987967f, 0.0012033002f); + INDArray exp3 = Nd4j.createFromArray(0.97409827f, 0.025901746f); assertEquals(exp0, softmax.getRow(0)); assertEquals(exp1, softmax.getRow(1)); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java index 717034306097..e676761fa421 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java @@ -116,8 +116,8 @@ public void testIndexingColVec() { assertEquals(i + 
1,rowVector.get(NDArrayIndex.point(0), NDArrayIndex.interval(i, j)).getInt(0)); assertEquals(i + 1,colVector.get(NDArrayIndex.interval(i, j), NDArrayIndex.point(0)).getInt(0)); System.out.println("Making sure index interval will not crash with begin/end vals..."); - jj = colVector.get(NDArrayIndex.interval(i, i + 10)); - jj = colVector.get(NDArrayIndex.interval(i, i + 10)); + jj = colVector.get(NDArrayIndex.interval(i, i + 1)); + jj = colVector.get(NDArrayIndex.interval(i, i + 1)); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/NDArrayTestsFortran.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/NDArrayTestsFortran.java index 1e457b10cba8..a9cadc14fc1a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/NDArrayTestsFortran.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/NDArrayTestsFortran.java @@ -129,7 +129,7 @@ public void testRowVectorGemm() { INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1).castTo(DataType.DOUBLE); INDArray other = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4).castTo(DataType.DOUBLE); INDArray result = linspace.mmul(other); - INDArray assertion = Nd4j.create(new double[] {30., 70., 110., 150.}); + INDArray assertion = Nd4j.create(new double[] {30., 70., 110., 150.}, new int[]{1,4}); assertEquals(assertion, result); } @@ -380,7 +380,7 @@ public void testScalar() { INDArray a = Nd4j.scalar(1.0f); assertEquals(true, a.isScalar()); - INDArray n = Nd4j.create(new float[] {1.0f}, new long[] {1, 1}); + INDArray n = Nd4j.create(new float[] {1.0f}, new long[0]); assertEquals(n, a); assertTrue(n.isScalar()); } @@ -477,9 +477,7 @@ public void testAssignOffset() { @Test public void testColumns() { INDArray arr = Nd4j.create(new long[] {3, 2}).castTo(DataType.DOUBLE); - INDArray column2 = arr.getColumn(0); - //assertEquals(true, Shape.shapeEquals(new long[]{3, 1}, column2.shape())); - INDArray column = Nd4j.create(new double[] {1, 2, 3}, new long[] {1, 3}); + INDArray column = Nd4j.create(new double[] {1, 2, 3}); arr.putColumn(0, column); INDArray firstColumn = arr.getColumn(0); @@ -487,14 +485,14 @@ public void testColumns() { assertEquals(column, firstColumn); - INDArray column1 = Nd4j.create(new double[] {4, 5, 6}, new long[] {1, 3}); + INDArray column1 = Nd4j.create(new double[] {4, 5, 6}); arr.putColumn(1, column1); INDArray testRow1 = arr.getColumn(1); assertEquals(column1, testRow1); INDArray evenArr = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2}); - INDArray put = Nd4j.create(new double[] {5, 6}, new long[] {1, 2}); + INDArray put = Nd4j.create(new double[] {5, 6}); evenArr.putColumn(1, put); INDArray testColumn = evenArr.getColumn(1); assertEquals(put, testColumn); @@ -502,12 +500,12 @@ public void testColumns() { INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(), new long[] {2, 2}).castTo(DataType.DOUBLE); INDArray column23 = n.getColumn(0); - INDArray column12 = Nd4j.create(new double[] {1, 2}, new long[] {1, 2}); + INDArray column12 = Nd4j.create(new double[] {1, 2}); assertEquals(column23, column12); INDArray column0 = n.getColumn(1); - INDArray column01 = Nd4j.create(new double[] {3, 4}, new long[] {1, 2}); + INDArray column01 = Nd4j.create(new double[] {3, 4}); assertEquals(column0, column01); @@ -540,12 +538,12 @@ public void testPutRow() { INDArray nLast = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(), new long[] {2, 2}).castTo(DataType.DOUBLE); INDArray row = nLast.getRow(1); - INDArray row1 
= Nd4j.create(new double[] {2, 4}, new long[] {1, 2}); + INDArray row1 = Nd4j.create(new double[] {2, 4}); assertEquals(row, row1); INDArray arr = Nd4j.create(new long[] {3, 2}).castTo(DataType.DOUBLE); - INDArray evenRow = Nd4j.create(new double[] {1, 2}, new long[] {1, 2}); + INDArray evenRow = Nd4j.create(new double[] {1, 2}); arr.putRow(0, evenRow); INDArray firstRow = arr.getRow(0); assertEquals(true, Shape.shapeEquals(new long[] {1, 2}, firstRow.shape())); @@ -553,7 +551,7 @@ public void testPutRow() { assertEquals(evenRow, testRowEven); - INDArray row12 = Nd4j.create(new double[] {5, 6}, new long[] {1, 2}); + INDArray row12 = Nd4j.create(new double[] {5, 6}); arr.putRow(1, row12); assertEquals(true, Shape.shapeEquals(new long[] {1, 2}, arr.getRow(0).shape())); INDArray testRow1 = arr.getRow(1); @@ -561,15 +559,14 @@ public void testPutRow() { INDArray multiSliceTest = Nd4j.create(Nd4j.linspace(1, 16, 16, DataType.DOUBLE).data(), new long[] {4, 2, 2}).castTo(DataType.DOUBLE); - INDArray test = Nd4j.create(new double[] {2, 10}, new long[] {1, 2}); - INDArray test2 = Nd4j.create(new double[] {6, 14}, new long[] {1, 2}); + INDArray test = Nd4j.create(new double[] {2, 10}); + INDArray test2 = Nd4j.create(new double[] {6, 14}); INDArray multiSliceRow1 = multiSliceTest.slice(1).getRow(0); INDArray multiSliceRow2 = multiSliceTest.slice(1).getRow(1); assertEquals(test, multiSliceRow1); assertEquals(test2, multiSliceRow2); - } @@ -601,10 +598,8 @@ public void testMmulF() { INDArray innerProduct = n.mmul(transposed); - INDArray scalar = Nd4j.scalar(385.0); + INDArray scalar = Nd4j.scalar(385.0).reshape(1,1); assertEquals(getFailureMessage(), scalar, innerProduct); - - } @@ -882,7 +877,7 @@ public void testNumVectorsAlongDimension() { @Test public void testBroadCast() { - INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1); + INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray broadCasted = n.broadcast(5, 4); for (int i = 0; i < broadCasted.rows(); i++) { assertEquals(n, broadCasted.getRow(i)); @@ -892,9 +887,9 @@ public void testBroadCast() { assertEquals(broadCasted, broadCast2); - INDArray columnBroadcast = n.transpose().broadcast(4, 5); + INDArray columnBroadcast = n.reshape(4,1).broadcast(4, 5); for (int i = 0; i < columnBroadcast.columns(); i++) { - assertEquals(columnBroadcast.getColumn(i), n.transpose()); + assertEquals(columnBroadcast.getColumn(i), n.reshape(4)); } INDArray fourD = Nd4j.create(1, 2, 1, 1); @@ -1065,13 +1060,13 @@ public void testGetColumnGetRow() { INDArray row = Nd4j.ones(1, 5); for (int i = 0; i < 5; i++) { INDArray col = row.getColumn(i); - assertArrayEquals(col.shape(), new long[] {1,1}); + assertArrayEquals(col.shape(), new long[] {1}); } INDArray col = Nd4j.ones(5, 1); for (int i = 0; i < 5; i++) { INDArray row2 = col.getRow(i); - assertArrayEquals(row2.shape(), new long[] {1, 1}); + assertArrayEquals(new long[] {1}, row2.shape()); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index f2df3db296a0..6cf581be6970 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -23,10 +23,7 @@ import org.apache.commons.io.FilenameUtils; import org.apache.commons.math3.stat.descriptive.rank.Percentile; import org.apache.commons.math3.util.FastMath; -import org.junit.After; -import org.junit.Before; -import 
org.junit.Ignore; -import org.junit.Test; +import org.junit.*; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.imports.TFGraphs.NodeReader; @@ -183,15 +180,14 @@ public void testDiag() { @Test public void testGetRowEdgeCase() { - INDArray orig = Nd4j.linspace(1,300,300, DataType.DOUBLE).reshape('c', 100, 3); - INDArray col = orig.getColumn(0); + INDArray col = orig.getColumn(0).reshape(100, 1); for( int i = 0; i < 100; i++) { INDArray row = col.getRow(i); INDArray rowDup = row.dup(); - double d = orig.getDouble(i,0); - double d2 = col.getDouble(i, 0); + double d = orig.getDouble(i, 0); + double d2 = col.getDouble(i); double dRowDup = rowDup.getDouble(0); double dRow = row.getDouble(0); @@ -506,7 +502,7 @@ public void testLength() { values2.put(1, 1, 2); - INDArray expected = Nd4j.repeat(Nd4j.scalar(DataType.DOUBLE, 2).reshape(1, 1), 2).reshape(2, 1); + INDArray expected = Nd4j.repeat(Nd4j.scalar(DataType.DOUBLE, 2).reshape(1, 1), 2).reshape(2); val accum = new EuclideanDistance(values, values2); accum.setDimensions(1); @@ -694,7 +690,7 @@ private static INDArray toFlattenedViaIterator(char order, INDArray... toFlatten for (INDArray i : toFlatten) length += i.length(); - INDArray out = Nd4j.create(1, length); + INDArray out = Nd4j.create(length); int i = 0; for (INDArray arr : toFlatten) { NdIndexIterator iter = new NdIndexIterator(order, arr.shape()); @@ -805,8 +801,8 @@ public void testIsMaxAlongDimension() { INDArray alongDim0 = Nd4j.getExecutioner().exec(new IsMax(orig.dup(), Nd4j.createUninitialized(DataType.BOOL, orig.shape()), 0)); INDArray alongDim1 = Nd4j.getExecutioner().exec(new IsMax(orig.dup(), Nd4j.createUninitialized(DataType.BOOL, orig.shape()), 1)); - INDArray expAlong0 = Nd4j.create(new boolean[]{true, true, true, true}); - INDArray expAlong1 = Nd4j.create(new boolean[] {false, false, true, false}); + INDArray expAlong0 = Nd4j.create(new boolean[]{true, true, true, true}).reshape(1,4); + INDArray expAlong1 = Nd4j.create(new boolean[] {false, false, true, false}).reshape(1,4); assertEquals(expAlong0, alongDim0); assertEquals(expAlong1, alongDim1); @@ -818,8 +814,8 @@ public void testIsMaxAlongDimension() { INDArray alongDim0col = Nd4j.getExecutioner().exec(new IsMax(col.dup(), Nd4j.createUninitialized(DataType.BOOL, col.shape()), 0)); INDArray alongDim1col = Nd4j.getExecutioner().exec(new IsMax(col.dup(), Nd4j.createUninitialized(DataType.BOOL, col.shape()),1)); - INDArray expAlong0col = Nd4j.create(new boolean[] {false, false, true, false}); - INDArray expAlong1col = Nd4j.create(new boolean[] {true, true, true, true}); + INDArray expAlong0col = Nd4j.create(new boolean[] {false, false, true, false}).reshape(4,1); + INDArray expAlong1col = Nd4j.create(new boolean[] {true, true, true, true}).reshape(4,1); @@ -936,7 +932,7 @@ public void testVStackDifferentOrders() { public void testVStackEdgeCase() { INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray vstacked = Nd4j.vstack(arr); - assertEquals(arr, vstacked); + assertEquals(arr.reshape(1,4), vstacked); } @@ -1336,7 +1332,7 @@ public void testToFlattened() { concat.add(arr.dup()); } - INDArray assertion = Nd4j.create(new double[] {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}); + INDArray assertion = Nd4j.create(new double[] {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}, new int[]{1,12}); INDArray flattened = Nd4j.toFlattened(concat); assertEquals(assertion, flattened); @@ -1383,7 +1379,7 @@ public void testSortWithIndicesDescending() { public void testGetFromRowVector() { INDArray matrix = 
Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray rowGet = matrix.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 2)); - assertArrayEquals(new long[] {1, 2}, rowGet.shape()); + assertArrayEquals(new long[] {2}, rowGet.shape()); } @Test @@ -1556,7 +1552,7 @@ public void testScalar() { INDArray a = Nd4j.scalar(1.0f); assertEquals(true, a.isScalar()); - INDArray n = Nd4j.create(new float[] {1.0f}, new long[] {1, 1}); + INDArray n = Nd4j.create(new float[] {1.0f}, new long[0]); assertEquals(n, a); assertTrue(n.isScalar()); } @@ -1606,7 +1602,7 @@ public void testColumns() { INDArray arr = Nd4j.create(new long[] {3, 2}); INDArray column2 = arr.getColumn(0); //assertEquals(true, Shape.shapeEquals(new long[]{3, 1}, column2.shape())); - INDArray column = Nd4j.create(new double[] {1, 2, 3}, new long[] {1, 3}); + INDArray column = Nd4j.create(new double[] {1, 2, 3}, new long[] {3}); arr.putColumn(0, column); INDArray firstColumn = arr.getColumn(0); @@ -1614,7 +1610,7 @@ public void testColumns() { assertEquals(column, firstColumn); - INDArray column1 = Nd4j.create(new double[] {4, 5, 6}, new long[] {1, 3}); + INDArray column1 = Nd4j.create(new double[] {4, 5, 6}, new long[] {3}); arr.putColumn(1, column1); //assertEquals(true, Shape.shapeEquals(new long[]{3, 1}, arr.getColumn(1).shape())); INDArray testRow1 = arr.getColumn(1); @@ -1622,7 +1618,7 @@ public void testColumns() { INDArray evenArr = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2}); - INDArray put = Nd4j.create(new double[] {5, 6}, new long[] {1, 2}); + INDArray put = Nd4j.create(new double[] {5, 6}, new long[] {2}); evenArr.putColumn(1, put); INDArray testColumn = evenArr.getColumn(1); assertEquals(put, testColumn); @@ -1630,12 +1626,12 @@ public void testColumns() { INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(), new long[] {2, 2}); INDArray column23 = n.getColumn(0); - INDArray column12 = Nd4j.create(new double[] {1, 3}, new long[] {1, 2}); + INDArray column12 = Nd4j.create(new double[] {1, 3}, new long[] {2}); assertEquals(column23, column12); INDArray column0 = n.getColumn(1); - INDArray column01 = Nd4j.create(new double[] {2, 4}, new long[] {1, 2}); + INDArray column01 = Nd4j.create(new double[] {2, 4}, new long[] {2}); assertEquals(column0, column01); @@ -1669,29 +1665,29 @@ public void testPutRow() { INDArray nLast = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(), new long[] {2, 2}); INDArray row = nLast.getRow(1); - INDArray row1 = Nd4j.create(new double[] {3, 4}, new long[] {1, 2}); + INDArray row1 = Nd4j.create(new double[] {3, 4}); assertEquals(row, row1); INDArray arr = Nd4j.create(new long[] {3, 2}); - INDArray evenRow = Nd4j.create(new double[] {1, 2}, new long[] {1, 2}); + INDArray evenRow = Nd4j.create(new double[] {1, 2}); arr.putRow(0, evenRow); INDArray firstRow = arr.getRow(0); - assertEquals(true, Shape.shapeEquals(new long[] {1, 2}, firstRow.shape())); + assertEquals(true, Shape.shapeEquals(new long[] {2}, firstRow.shape())); INDArray testRowEven = arr.getRow(0); assertEquals(evenRow, testRowEven); - INDArray row12 = Nd4j.create(new double[] {5, 6}, new long[] {1, 2}); + INDArray row12 = Nd4j.create(new double[] {5, 6}, new long[] {2}); arr.putRow(1, row12); - assertEquals(true, Shape.shapeEquals(new long[] {1, 2}, arr.getRow(0).shape())); + assertEquals(true, Shape.shapeEquals(new long[] {2}, arr.getRow(0).shape())); INDArray testRow1 = arr.getRow(1); assertEquals(row12, testRow1); INDArray multiSliceTest = Nd4j.create(Nd4j.linspace(1, 16, 16, 
DataType.DOUBLE).data(), new long[] {4, 2, 2}); - INDArray test = Nd4j.create(new double[] {5, 6}, new long[] {1, 2}); - INDArray test2 = Nd4j.create(new double[] {7, 8}, new long[] {1, 2}); + INDArray test = Nd4j.create(new double[] {5, 6}, new long[] {2}); + INDArray test2 = Nd4j.create(new double[] {7, 8}, new long[] {2}); INDArray multiSliceRow1 = multiSliceTest.slice(1).getRow(0); INDArray multiSliceRow2 = multiSliceTest.slice(1).getRow(1); @@ -1902,7 +1898,7 @@ public void testMmul() { INDArray innerProduct = n.mmul(transposed); - INDArray scalar = Nd4j.scalar(385.0); + INDArray scalar = Nd4j.scalar(385.0).reshape(1,1); assertEquals(getFailureMessage(), scalar, innerProduct); INDArray outerProduct = transposed.mmul(n); @@ -1910,7 +1906,7 @@ public void testMmul() { - INDArray three = Nd4j.create(new double[] {3, 4}, new long[] {1, 2}); + INDArray three = Nd4j.create(new double[] {3, 4}); INDArray test = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2}); INDArray sliceRow = test.slice(0).getRow(1); assertEquals(getFailureMessage(), three, sliceRow); @@ -2123,7 +2119,7 @@ public void testTile() { @Test public void testNegativeOneReshape() { INDArray arr = Nd4j.create(new double[] {0, 1, 2}); - INDArray newShape = arr.reshape(-1, 3); + INDArray newShape = arr.reshape(-1); assertEquals(newShape, arr); } @@ -2155,12 +2151,12 @@ public void test2DArraySlice() { */ for (int i = 0; i < 7; i++) { INDArray slice = array2D.slice(i, 1); - assertTrue(Arrays.equals(slice.shape(), new long[] {5, 1})); + assertArrayEquals(slice.shape(), new long[] {5}); } for (int i = 0; i < 5; i++) { INDArray slice = array2D.slice(i, 0); - assertTrue(Arrays.equals(slice.shape(), new long[] {1, 7})); + assertArrayEquals(slice.shape(), new long[]{7}); } } @@ -2194,7 +2190,7 @@ public void testGetRow() { INDArray arr = Nd4j.ones(10, 4); for (int i = 0; i < 10; i++) { INDArray row = arr.getRow(i); - assertArrayEquals(row.shape(), new long[] {1, 4}); + assertArrayEquals(row.shape(), new long[] {4}); } } @@ -2933,7 +2929,7 @@ public void testElementWiseAdd() { public void testSquareMatrix() { INDArray n = Nd4j.create(Nd4j.linspace(1, 8, 8, DataType.DOUBLE).data(), new long[] {2, 2, 2}); INDArray eightFirstTest = n.vectorAlongDimension(0, 2); - INDArray eightFirstAssertion = Nd4j.create(new double[] {1, 2}, new long[] {1, 2}); + INDArray eightFirstAssertion = Nd4j.create(new double[] {1, 2}); assertEquals(eightFirstAssertion, eightFirstTest); INDArray eightFirstTestSecond = n.vectorAlongDimension(1, 2); @@ -2952,7 +2948,7 @@ public void testNumVectorsAlongDimension() { @Test public void testBroadCast() { - INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, 4); + INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray broadCasted = n.broadcast(5, 4); for (int i = 0; i < broadCasted.rows(); i++) { INDArray row = broadCasted.getRow(i); @@ -2963,10 +2959,10 @@ public void testBroadCast() { assertEquals(broadCasted, broadCast2); - INDArray columnBroadcast = n.transpose().broadcast(4, 5); + INDArray columnBroadcast = n.reshape(4,1).broadcast(4, 5); for (int i = 0; i < columnBroadcast.columns(); i++) { INDArray column = columnBroadcast.getColumn(i); - assertEquals(column, n.transpose()); + assertEquals(column, n); } INDArray fourD = Nd4j.create(1, 2, 1, 1); @@ -4156,7 +4152,7 @@ public void testSpecialConcat1() { for (int x = 0; x < 10; x++) { assertEquals((double) x, matrix.getRow(x).meanNumber().doubleValue(), 0.1); - assertEquals(arrays.get(x), matrix.getRow(x)); + 
assertEquals(arrays.get(x), matrix.getRow(x).reshape(1,matrix.size(1))); } } } @@ -4175,7 +4171,7 @@ public void testSpecialConcat2() { for (int x = 0; x < 10; x++) { assertEquals((double) x, matrix.getRow(x).meanNumber().doubleValue(), 0.1); - assertEquals(arrays.get(x), matrix.getRow(x)); + assertEquals(arrays.get(x), matrix.getRow(x).reshape(1, matrix.size(1))); } } @@ -4383,7 +4379,7 @@ public void testNewBroadcastComparison1() { val initial = Nd4j.create(3, 5); val mask = Nd4j.create(new double[] {5, 4, 3, 2, 1}); val result = Nd4j.createUninitialized(DataType.BOOL, initial.shape()); - val exp = Nd4j.create(new boolean[] {true, true, true, false, false}).reshape(1, -1); + val exp = Nd4j.create(new boolean[] {true, true, true, false, false}); for (int i = 0; i < initial.columns(); i++) { initial.getColumn(i).assign(i); @@ -5268,7 +5264,7 @@ public void testNativeSort3_1() { @Test public void testNativeSortAlongDimension1() { INDArray array = Nd4j.create(1000, 1000); - INDArray exp1 = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE).reshape(1, -1); + INDArray exp1 = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE); INDArray dps = exp1.dup(); Nd4j.shuffle(dps, 0); @@ -5293,7 +5289,7 @@ public void testNativeSortAlongDimension1() { @Test public void testNativeSortAlongDimension3() { INDArray array = Nd4j.create(2000, 2000); - INDArray exp1 = Nd4j.linspace(1, 2000, 2000, DataType.DOUBLE).reshape(1, -1); + INDArray exp1 = Nd4j.linspace(1, 2000, 2000, DataType.DOUBLE); INDArray dps = exp1.dup(); Nd4j.getExecutioner().commit(); @@ -5719,10 +5715,10 @@ public void testGemmStrides() { // Get i-th column vector final INDArray xi = X.get(NDArrayIndex.all(), NDArrayIndex.point(i)); // Build outer product - val trans = xi.transpose(); + val trans = xi; final INDArray outerProduct = xi.mmul(trans); // Build outer product from duplicated column vectors - final INDArray outerProductDuped = xi.dup().mmul(xi.transpose().dup()); + final INDArray outerProductDuped = xi.dup().mmul(xi.dup()); // Matrices should equal //final boolean eq = outerProduct.equalsWithEps(outerProductDuped, 1e-5); //assertTrue(eq); @@ -5934,7 +5930,7 @@ public void testVectorGemv() { val outN = matrix.mmul(vectorN); val outL = matrix.mmul(vectorL); - assertEquals(outL, outN); + assertEquals(outL, outN.reshape(3,1)); assertEquals(1, outN.rank()); } @@ -6806,11 +6802,12 @@ public void testBroadcastInvalid(){ public void testGet(){ //https://github.com/deeplearning4j/deeplearning4j/issues/6133 INDArray m = Nd4j.linspace(0,99,100, DataType.DOUBLE).reshape('c', 10,10); - INDArray exp = Nd4j.create(new double[]{5, 15, 25, 35, 45, 55, 65, 75, 85, 95}, new int[]{10,1}); + INDArray exp = Nd4j.create(new double[]{5, 15, 25, 35, 45, 55, 65, 75, 85, 95}, new int[]{10}); INDArray col = m.getColumn(5); for(int i=0; i<10; i++ ){ - System.out.println(i + "\t" + col.slice(i)); + col.slice(i); +// System.out.println(i + "\t" + col.slice(i)); } //First element: index 5 @@ -7387,7 +7384,7 @@ public void testRepeatStrided() { INDArray array = Nd4j.arange(25).reshape(5, 5); // Get first column (shape 5x1) - INDArray slice = array.get(NDArrayIndex.all(), NDArrayIndex.point(0)); + INDArray slice = array.get(NDArrayIndex.all(), NDArrayIndex.point(0)).reshape(5,1); // Repeat column on sliced array (shape 5x3) INDArray repeatedSlice = slice.repeat(1, (long) 3); @@ -7412,7 +7409,7 @@ public void testGetColumnRowVector(){ INDArray arr = Nd4j.create(1,4); INDArray col = arr.getColumn(0); System.out.println(Arrays.toString(col.shape())); - assertArrayEquals(new 
long[]{1,1}, col.shape()); + assertArrayEquals(new long[]{1}, col.shape()); } @@ -7527,6 +7524,39 @@ public void testReduceKeepDimsShape(){ assertArrayEquals(new long[]{1, 4}, out2.shape()); } + @Test + public void testSliceRow(){ + double[] data = new double[]{15.0, 16.0}; + INDArray vector = Nd4j.createFromArray(data).reshape(1,2); + INDArray slice = vector.slice(0); + System.out.println(slice.shapeInfoToString()); + assertEquals(vector, slice); + slice.assign(-1); + assertEquals(Nd4j.createFromArray(-1.0, -1.0).reshape(1,2), vector); + } + + @Test + public void testSliceMatrix(){ + INDArray arr = Nd4j.arange(4).reshape(2,2); + System.out.println(arr.slice(0)); + System.out.println(); + System.out.println(arr.slice(1)); + } + + @Test + public void testScalarEq(){ + INDArray scalarRank2 = Nd4j.scalar(10.0).reshape(1,1); + INDArray scalarRank1 = Nd4j.scalar(10.0).reshape(1); + INDArray scalarRank0 = Nd4j.scalar(10.0); + + assertNotEquals(scalarRank0, scalarRank2); + assertNotEquals(scalarRank0, scalarRank1); + assertNotEquals(scalarRank1, scalarRank2); + assertEquals(scalarRank0, scalarRank0.dup()); + assertEquals(scalarRank1, scalarRank1.dup()); + assertEquals(scalarRank2, scalarRank2.dup()); + } + /////////////////////////////////////////////////////// protected static void fillJvmArray3D(float[][][] arr) { int cnt = 1; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java index 2b9f5559e0f5..7dcd2285f751 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java @@ -98,7 +98,7 @@ public void testBlasValidation2() { /** * Testing level3 blas */ - @Test(expected = ND4JIllegalStateException.class) + @Test(expected = IllegalStateException.class) public void testBlasValidation3() { INDArray x = Nd4j.create(100, 100); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java index 8a923c89cf28..8c325eeac357 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java @@ -142,7 +142,7 @@ public void testIndexGetDuplicate() { @Test public void testGetScalar() { - INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1); + INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); INDArray d = arr.get(NDArrayIndex.point(1)); assertTrue(d.isScalar()); assertEquals(2.0, d.getDouble(0), 1e-1); @@ -161,10 +161,10 @@ public void testVectorIndexing() { INDArray x = Nd4j.linspace(0, 10, 11, DataType.DOUBLE).reshape(1, 11).castTo(DataType.DOUBLE); int[] index = new int[] {5, 8, 9}; INDArray columnsTest = x.getColumns(index); - assertEquals(Nd4j.create(new double[] {5, 8, 9}), columnsTest); + assertEquals(Nd4j.create(new double[] {5, 8, 9}, new int[]{1,3}), columnsTest); int[] index2 = new int[] {2, 2, 4}; //retrieve the same columns twice INDArray columnsTest2 = x.getColumns(index2); - assertEquals(Nd4j.create(new double[] {2, 2, 4}), columnsTest2); + assertEquals(Nd4j.create(new double[] {2, 2, 4}, new int[]{1,3}), columnsTest2); } @@ -214,6 +214,49 @@ public void testGetIndicesVector() { 
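Note: the new testScalarEq above pins down that equality is shape-sensitive even for single values. A small restatement of the idea, assuming the usual ND4J imports:

    INDArray r0 = Nd4j.scalar(10.0);                // rank 0, shape []
    INDArray r1 = Nd4j.scalar(10.0).reshape(1);     // rank 1, shape [1]
    INDArray r2 = Nd4j.scalar(10.0).reshape(1, 1);  // rank 2, shape [1, 1]
    // all hold 10.0, but no two of them are equal; only same-rank duplicates compare equal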
assertEquals(test, result); } + @Test + public void testGetIndicesVectorView() { + INDArray matrix = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape('c',5, 5); + INDArray column = matrix.getColumn(0).reshape(1,5); + INDArray test = Nd4j.create(new double[] {6, 11}); + INDArray result = null; //column.get(NDArrayIndex.point(0), NDArrayIndex.interval(1, 3)); +// assertEquals(test, result); +// + INDArray column3 = matrix.getColumn(2).reshape(1,5); +// INDArray exp = Nd4j.create(new double[] {8, 13}); +// result = column3.get(NDArrayIndex.point(0), NDArrayIndex.interval(1, 3)); +// assertEquals(exp, result); + + INDArray exp2 = Nd4j.create(new double[] {8, 18}); + result = column3.get(NDArrayIndex.point(0), NDArrayIndex.interval(1, 2, 4)); + assertEquals(exp2, result); + } + + @Test + public void test2dGetPoint(){ + INDArray arr = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape('c',3,4); + for( int i=0; i<3; i++ ){ + INDArray exp = Nd4j.create(new double[]{i*4+1, i*4+2, i*4+3, i*4+4}); + INDArray row = arr.getRow(i); + INDArray get = arr.get(NDArrayIndex.point(i), NDArrayIndex.all()); + + assertEquals(1, row.rank()); + assertEquals(1, get.rank()); + assertEquals(exp, row); + assertEquals(exp, get); + } + + for( int i=0; i<4; i++ ){ + INDArray exp = Nd4j.create(new double[]{1+i, 5+i, 9+i}); + INDArray col = arr.getColumn(i); + INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.point(i)); + + assertEquals(1, col.rank()); + assertEquals(1, get.rank()); + assertEquals(exp, col); + assertEquals(exp, get); + } + } @Override diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java index 504158390c00..8bd5936bb024 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java @@ -163,7 +163,7 @@ public void testGetPointRowVector() { INDArray arr2 = arr.get(point(0), interval(0, 100)); assertEquals(100, arr2.length()); //Returning: length 0 - assertEquals(arr2, Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(1, -1)); + assertEquals(Nd4j.linspace(1, 100, 100, DataType.DOUBLE), arr2); } @Test @@ -195,15 +195,15 @@ public void testPutRowIndexing() { @Test public void testVectorIndexing2() { - INDArray wholeVector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1).get(interval(1, 2, 3, true)); - INDArray assertion = Nd4j.create(new double[] {2, 4}).reshape(1, -1); + INDArray wholeVector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).get(interval(1, 2, 3, true)); + INDArray assertion = Nd4j.create(new double[] {2, 4}); assertEquals(assertion, wholeVector); - INDArray wholeVectorTwo = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1).get(interval(1, 2, 4, true)); + INDArray wholeVectorTwo = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).get(interval(1, 2, 4, true)); assertEquals(assertion, wholeVectorTwo); - INDArray wholeVectorThree = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1).get(interval(1, 2, 4, false)); + INDArray wholeVectorThree = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).get(interval(1, 2, 4, false)); assertEquals(assertion, wholeVectorThree); - INDArray threeFiveAssertion = Nd4j.create(new double[] {3, 5}).reshape(1, -1); - INDArray threeFive = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1).get(interval(2, 2, 4, true)); + INDArray threeFiveAssertion = Nd4j.create(new double[] {3, 5}); + 
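Note: the test2dGetPoint addition above captures the new point() semantics. A compact sketch, assuming the usual ND4J and NDArrayIndex imports:

    INDArray arr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4);
    INDArray row = arr.get(NDArrayIndex.point(1), NDArrayIndex.all());   // values [5, 6, 7, 8], shape [4]
    INDArray col = arr.get(NDArrayIndex.all(), NDArrayIndex.point(2));   // values [3, 7, 11], shape [3]
    // a point() index now drops the indexed dimension, so both views are rank 1 and match getRow/getColumn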
INDArray threeFive = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).get(interval(2, 2, 4, true)); assertEquals(threeFiveAssertion, threeFive); } @@ -235,7 +235,7 @@ public void testIndexFor() { @Test public void testGetScalar() { - INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1); + INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); INDArray d = arr.get(point(1)); assertTrue(d.isScalar()); assertEquals(2.0, d.getDouble(0), 1e-1); @@ -264,7 +264,7 @@ public void testGetIndices2d() { INDArray secondRow = twoByTwo.getRow(1); INDArray firstAndSecondRow = twoByTwo.getRows(1, 2); INDArray firstRowViaIndexing = twoByTwo.get(interval(0, 1), NDArrayIndex.all()); - assertEquals(firstRow, firstRowViaIndexing); + assertEquals(firstRow.reshape(1,2), firstRowViaIndexing); INDArray secondRowViaIndexing = twoByTwo.get(point(1), NDArrayIndex.all()); assertEquals(secondRow, secondRowViaIndexing); @@ -272,7 +272,7 @@ public void testGetIndices2d() { assertEquals(firstAndSecondRow, firstAndSecondRowTest); INDArray individualElement = twoByTwo.get(interval(1, 2), interval(1, 2)); - assertEquals(Nd4j.create(new double[] {4}), individualElement); + assertEquals(Nd4j.create(new double[] {4}, new int[]{1,1}), individualElement); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/ShapeResolutionTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/ShapeResolutionTestsC.java index 087a60959a06..b5a635167ade 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/ShapeResolutionTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/ShapeResolutionTestsC.java @@ -55,12 +55,12 @@ public void testRowVectorShapeOneZeroOffset() { //row 0 resolution.exec(NDArrayIndex.point(0)); long[] oneIndexShape = ArrayUtil.copy(resolution.getShapes()); - assertArrayEquals(new long[] {1, 2}, oneIndexShape); + assertArrayEquals(new long[] {2}, oneIndexShape); long[] oneIndexOffsets = ArrayUtil.copy(resolution.getOffsets()); - assertArrayEquals(new long[] {0, 0}, oneIndexOffsets); + assertArrayEquals(new long[] {0}, oneIndexOffsets); assertEquals(0, resolution.getOffset()); long[] oneIndexStrides = ArrayUtil.copy(resolution.getStrides()); - assertArrayEquals(new long[] {1, 1}, oneIndexStrides); + assertArrayEquals(new long[] {1}, oneIndexStrides); } @@ -80,10 +80,10 @@ public void testRowVectorShapeOneOneOffset() { //row 0 resolution.exec(NDArrayIndex.point(1)); long[] oneIndexShape = ArrayUtil.copy(resolution.getShapes()); - assertArrayEquals(new long[] {1, 2}, oneIndexShape); + assertArrayEquals(new long[] {2}, oneIndexShape); assertEquals(2, resolution.getOffset()); long[] oneIndexStrides = ArrayUtil.copy(resolution.getStrides()); - assertArrayEquals(new long[] {1, 1}, oneIndexStrides); + assertArrayEquals(new long[] {1}, oneIndexStrides); } @@ -96,12 +96,12 @@ public void testRowVectorShapeTwoOneOffset() { //row 0 resolution.exec(NDArrayIndex.point(1), NDArrayIndex.all()); long[] oneIndexShape = ArrayUtil.copy(resolution.getShapes()); - assertArrayEquals(new long[] {1, 2}, oneIndexShape); + assertArrayEquals(new long[] {2}, oneIndexShape); long[] oneIndexOffsets = ArrayUtil.copy(resolution.getOffsets()); - assertArrayEquals(new long[] {0, 0}, oneIndexOffsets); + assertArrayEquals(new long[] {0}, oneIndexOffsets); assertEquals(2, resolution.getOffset()); long[] oneIndexStrides = ArrayUtil.copy(resolution.getStrides()); - assertArrayEquals(new long[] {1, 1}, oneIndexStrides); + 
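Note: the testGetIndices2d changes above illustrate when a size-1 dimension is kept: interval indices preserve it, point indices do not. A short sketch under that assumption:

    INDArray twoByTwo = Nd4j.create(new double[][]{{1, 2}, {3, 4}});
    INDArray keep2d = twoByTwo.get(NDArrayIndex.interval(0, 1), NDArrayIndex.all()); // shape [1, 2]
    INDArray rank1  = twoByTwo.get(NDArrayIndex.point(0), NDArrayIndex.all());       // shape [2]
    // hence the rank-1 firstRow has to be reshaped to [1, 2] before comparing against the interval-based view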
assertArrayEquals(new long[] {1}, oneIndexStrides); } @@ -113,8 +113,8 @@ public void testColumnVectorShapeZeroOffset() { resolution.exec(NDArrayIndex.all(), NDArrayIndex.point(0)); assertEquals(0, resolution.getOffset()); long[] strides = resolution.getStrides(); - assertArrayEquals(new long[] {2, 1}, resolution.getShapes()); - assertArrayEquals(new long[] {2, 1}, strides); + assertArrayEquals(new long[] {2}, resolution.getShapes()); + assertArrayEquals(new long[] {2}, strides); } @Test @@ -124,17 +124,8 @@ public void testColumnVectorShapeOneOffset() { resolution.exec(NDArrayIndex.all(), NDArrayIndex.point(1)); assertEquals(1, resolution.getOffset()); long[] strides = resolution.getStrides(); - assertArrayEquals(new long[] {2, 1}, resolution.getShapes()); - assertArrayEquals(new long[] {2, 1}, strides); - } - - - @Test - public void testPartiallyOutOfRangeIndices() { - INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2); - ShapeOffsetResolution resolution = new ShapeOffsetResolution(arr); - resolution.exec(NDArrayIndex.interval(0, 2), NDArrayIndex.interval(1, 4)); - assertArrayEquals(new long[] {2, 1}, resolution.getShapes()); + assertArrayEquals(new long[] {2}, resolution.getShapes()); + assertArrayEquals(new long[] {2}, strides); } @Test @@ -201,7 +192,7 @@ public void testFlatIndexPointInterval() { INDArray value = Nd4j.ones(1, 2).castTo(DataType.DOUBLE); zeros.put(new INDArrayIndex[] {x, y}, value); - INDArray assertion = Nd4j.create(new double[] {0.0, 1.0, 1.0, 0.0}); + INDArray assertion = Nd4j.create(new double[] {0.0, 1.0, 1.0, 0.0}, new int[]{1,4}); assertEquals(assertion, zeros); } @@ -213,7 +204,7 @@ public void testVectorIndexPointPoint() { INDArray value = Nd4j.ones(1, 1).castTo(DataType.DOUBLE); zeros.put(new INDArrayIndex[] {x, y}, value); - INDArray assertion = Nd4j.create(new double[] {0.0, 0.0, 1.0, 0.0}); + INDArray assertion = Nd4j.create(new double[] {0.0, 0.0, 1.0, 0.0}, new int[]{1,4}); assertEquals(assertion, zeros); } @@ -319,6 +310,24 @@ public void testVectorPointIndex(){ assertArrayEquals(new long[]{}, out.shape()); } + @Test + public void testPointIndex(){ + for( int i=0; i<3; i++ ) { + INDArray arr = Nd4j.linspace(DataType.DOUBLE, 1, 3, 1); + INDArray out = arr.get(NDArrayIndex.point(i)); + assertArrayEquals(new long[]{}, out.shape()); + INDArray exp = Nd4j.scalar((double)i+1); + assertEquals(exp, out); + assertTrue(out.isView()); + + INDArray exp2 = Nd4j.linspace(DataType.DOUBLE, 1, 3, 1); + exp2.putScalar(i, 10.0); + out.assign(10.0); + assertEquals(exp2, arr); + + } + } + @Override public char ordering() { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionTests.java index 56e5b405381c..383dafce98b0 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionTests.java @@ -124,7 +124,7 @@ public void testNoOpCompression1() { @Test public void testJVMCompression3() { Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); - INDArray exp = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f}); + INDArray exp = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f}).reshape(1,-1); BasicNDArrayCompressor.getInstance().setDefaultCompression("NOOP"); @@ -164,7 +164,7 @@ public void testThresholdCompressionZ() { log.info("Compressed length: {}", compressed.data().length()); // log.info("Compressed: {}", 
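Note: the new testPointIndex above also checks that the rank-0 result of a point() index is still a view into the original buffer. A brief sketch, reusing the linspace call from that test:

    INDArray arr = Nd4j.linspace(DataType.DOUBLE, 1, 3, 1);   // [1, 2, 3]
    INDArray second = arr.get(NDArrayIndex.point(1));         // rank-0 view of the value 2
    second.assign(10.0);                                       // writes through: arr is now [1, 10, 3]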
Arrays.toString(compressed.data().asInt())); - INDArray decompressed = Nd4j.create(initial.length()); + INDArray decompressed = Nd4j.create(1, initial.length()); Nd4j.getExecutioner().thresholdDecode(compressed, decompressed); log.info("Decompressed length: {}", decompressed.lengthLong()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index a637b315135d..1ccf3111a7f7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -336,8 +336,8 @@ public void testScatterUpdate1() { int[] dims = new int[]{1}; int[] indices = new int[]{1, 3}; - val exp0 = Nd4j.create(1, 5).assign(0); - val exp1 = Nd4j.create(1, 5).assign(1); + val exp0 = Nd4j.create(5).assign(0); + val exp1 = Nd4j.create(5).assign(1); ScatterUpdate op = new ScatterUpdate(matrix, updates, indices, dims, ScatterUpdate.UpdateOp.ADD); Nd4j.getExecutioner().exec(op); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java index 148770881b1c..f6098dd3da27 100755 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java @@ -79,7 +79,7 @@ public void testViewIterator2(){ for( int i=0; i<10; i++ ){ assertTrue(iter.hasNext()); DataSet d = iter.next(); - INDArray exp = f.getRow(i); + INDArray exp = f.getRow(i, true); assertEquals(exp, d.getFeatures()); assertEquals(exp, d.getLabels()); } @@ -430,7 +430,7 @@ public void testCnnMerge() { assertEquals(first, fMerged.get(interval(0, nExamples1), all(), all(), all())); - assertEquals(second, fMerged.get(interval(nExamples1, nExamples1 + nExamples2, true), all(), all(), all())); + assertEquals(second, fMerged.get(interval(nExamples1, nExamples1 + nExamples2), all(), all(), all())); assertEquals(labels1, lMerged.get(interval(0, nExamples1), all())); assertEquals(labels2, lMerged.get(interval(nExamples1, nExamples1 + nExamples2), all())); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiDataSetTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiDataSetTest.java index ff196dfb6d24..ebe4b3f8c432 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiDataSetTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiDataSetTest.java @@ -57,9 +57,9 @@ public void testMerging2d() { INDArray[] in = new INDArray[nRows]; INDArray[] out = new INDArray[nRows]; for (int i = 0; i < nRows; i++) - in[i] = expIn.getRow(i).dup(); + in[i] = expIn.getRow(i, true).dup(); for (int i = 0; i < nRows; i++) - out[i] = expOut.getRow(i).dup(); + out[i] = expOut.getRow(i, true).dup(); List list = new ArrayList<>(nRows); for (int i = 0; i < nRows; i++) { @@ -100,10 +100,10 @@ public void testMerging2dMultipleInOut() { list.add(new MultiDataSet(new INDArray[] {in0, in1}, new INDArray[] {out0, out1})); i++; } else { - INDArray in0 = expIn0.getRow(i).dup(); - INDArray in1 = expIn1.getRow(i).dup(); - INDArray out0 = expOut0.getRow(i).dup(); - INDArray out1 = expOut1.getRow(i).dup(); + INDArray in0 = expIn0.getRow(i, true).dup(); + INDArray in1 = expIn1.getRow(i, true).dup(); + 
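Note: the getRow(i) to getRow(i, true) swaps above keep per-example rows 2-D so they still match the [1, nColumns] arrays produced by the dataset iterators. A minimal sketch of the pattern, assuming the usual ND4J imports:

    INDArray features = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(10, 3);
    INDArray rank1Row = features.getRow(0);        // shape [3]
    INDArray keptRow  = features.getRow(0, true);  // shape [1, 3], comparable to a single-example feature matrix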
INDArray out0 = expOut0.getRow(i, true).dup(); + INDArray out1 = expOut1.getRow(i, true).dup(); list.add(new MultiDataSet(new INDArray[] {in0, in1}, new INDArray[] {out0, out1})); } } @@ -150,12 +150,12 @@ public void testMerging2dMultipleInOut2() { list.add(new MultiDataSet(new INDArray[] {in0, in1, in2}, new INDArray[] {out0, out1, out2})); i++; } else { - INDArray in0 = expIn0.getRow(i).dup(); - INDArray in1 = expIn1.getRow(i).dup(); - INDArray in2 = expIn2.getRow(i).dup(); - INDArray out0 = expOut0.getRow(i).dup(); - INDArray out1 = expOut1.getRow(i).dup(); - INDArray out2 = expOut2.getRow(i).dup(); + INDArray in0 = expIn0.getRow(i, true).dup(); + INDArray in1 = expIn1.getRow(i, true).dup(); + INDArray in2 = expIn2.getRow(i, true).dup(); + INDArray out0 = expOut0.getRow(i, true).dup(); + INDArray out1 = expOut1.getRow(i, true).dup(); + INDArray out2 = expOut2.getRow(i, true).dup(); list.add(new MultiDataSet(new INDArray[] {in0, in1, in2}, new INDArray[] {out0, out1, out2})); } } @@ -193,12 +193,12 @@ public void testMerging2dMultipleInOut3() { List list = new ArrayList<>(nRows); for (int i = 0; i < nRows; i++) { - INDArray in0 = expIn0.getRow(i).dup(); - INDArray in1 = expIn1.getRow(i).dup(); - INDArray in2 = expIn2.getRow(i).dup(); - INDArray out0 = expOut0.getRow(i).dup(); - INDArray out1 = expOut1.getRow(i).dup(); - INDArray out2 = expOut2.getRow(i).dup(); + INDArray in0 = expIn0.getRow(i, true).dup(); + INDArray in1 = expIn1.getRow(i, true).dup(); + INDArray in2 = expIn2.getRow(i, true).dup(); + INDArray out0 = expOut0.getRow(i, true).dup(); + INDArray out1 = expOut1.getRow(i, true).dup(); + INDArray out2 = expOut2.getRow(i, true).dup(); list.add(new MultiDataSet(new INDArray[] {in0, in1, in2}, new INDArray[] {out0, out1, out2})); } @@ -252,8 +252,8 @@ public void testMerging4dMultipleInOut() { NDArrayIndex.all()).dup(); INDArray in1 = expIn1.get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()).dup(); - INDArray out0 = expOut0.getRow(i).dup(); - INDArray out1 = expOut1.getRow(i).dup(); + INDArray out0 = expOut0.getRow(i, true).dup(); + INDArray out1 = expOut1.getRow(i, true).dup(); list.add(new MultiDataSet(new INDArray[] {in0, in1}, new INDArray[] {out0, out1})); } } @@ -531,7 +531,7 @@ public void testSplit() { assertArrayEquals(new long[] {1, 10, 10}, m.getFeatures(1).shape()); assertArrayEquals(new long[] {1, 5, 10, 10}, m.getFeatures(2).shape()); - assertEquals(features[0].get(NDArrayIndex.point(i), NDArrayIndex.all()), m.getFeatures(0)); + assertEquals(features[0].get(NDArrayIndex.interval(i,i,true), NDArrayIndex.all()), m.getFeatures(0)); assertEquals(features[1].get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all()), m.getFeatures(1)); assertEquals(features[2].get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all(), @@ -545,17 +545,17 @@ public void testSplit() { assertArrayEquals(new long[] {1, 10, 10}, m.getLabels(1).shape()); assertArrayEquals(new long[] {1, 5, 10, 10}, m.getLabels(2).shape()); - assertEquals(labels[0].get(NDArrayIndex.point(i), NDArrayIndex.all()), m.getLabels(0)); + assertEquals(labels[0].get(NDArrayIndex.interval(i,i,true), NDArrayIndex.all()), m.getLabels(0)); assertEquals(labels[1].get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all()), m.getLabels(1)); assertEquals(labels[2].get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()), m.getLabels(2)); assertNull(m.getFeaturesMaskArray(0)); - 
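Note: the testSplit changes above trade point(i) for an inclusive single-element interval so the sliced features, labels, and masks keep their leading dimension. A short sketch of the distinction:

    INDArray labels = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(3, 10);
    INDArray viaPoint    = labels.get(NDArrayIndex.point(1), NDArrayIndex.all());              // shape [10]
    INDArray viaInterval = labels.get(NDArrayIndex.interval(1, 1, true), NDArrayIndex.all());  // shape [1, 10]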
assertEquals(fMask[1].get(NDArrayIndex.point(i), NDArrayIndex.all()), m.getFeaturesMaskArray(1)); + assertEquals(fMask[1].get(NDArrayIndex.interval(i,i,true), NDArrayIndex.all()), m.getFeaturesMaskArray(1)); assertNull(m.getLabelsMaskArray(0)); - assertEquals(lMask[1].get(NDArrayIndex.point(i), NDArrayIndex.all()), m.getLabelsMaskArray(1)); + assertEquals(lMask[1].get(NDArrayIndex.interval(i,i,true), NDArrayIndex.all()), m.getLabelsMaskArray(1)); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessor3D4DTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessor3D4DTest.java index c530f36ee6fa..1612c6efdda2 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessor3D4DTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessor3D4DTest.java @@ -58,8 +58,8 @@ public void testBruteForce3d() { int timeSteps = 15; int samples = 100; //multiplier for the features - INDArray featureScaleA = Nd4j.create(new double[] {1, -2, 3}).reshape(3, 1); - INDArray featureScaleB = Nd4j.create(new double[] {2, 2, 3}).reshape(3, 1); + INDArray featureScaleA = Nd4j.create(new double[] {1, -2, 3}); + INDArray featureScaleB = Nd4j.create(new double[] {2, 2, 3}); Construct3dDataSet caseA = new Construct3dDataSet(featureScaleA, timeSteps, samples, 1); Construct3dDataSet caseB = new Construct3dDataSet(featureScaleB, timeSteps, samples, 1); @@ -352,18 +352,18 @@ public Construct3dDataSet(INDArray featureScale, int timeSteps, int samples, int //calculating stats // The theoretical mean should be the mean of 1,..samples*timesteps float theoreticalMean = origin - 1 + (samples * timeSteps + 1) / 2.0f; - expectedMean = Nd4j.create(new double[] {theoreticalMean, theoreticalMean, theoreticalMean}).reshape(3, 1).castTo(featureScale.dataType()); + expectedMean = Nd4j.create(new double[] {theoreticalMean, theoreticalMean, theoreticalMean}).castTo(featureScale.dataType()); expectedMean.muliColumnVector(featureScale); float stdNaturalNums = (float) Math.sqrt((samples * samples * timeSteps * timeSteps - 1) / 12); - expectedStd = Nd4j.create(new double[] {stdNaturalNums, stdNaturalNums, stdNaturalNums}).reshape(3, 1).castTo(Nd4j.defaultFloatingPointType()); + expectedStd = Nd4j.create(new double[] {stdNaturalNums, stdNaturalNums, stdNaturalNums}).castTo(Nd4j.defaultFloatingPointType()); expectedStd.muliColumnVector(Transforms.abs(featureScale, true)); //preprocessors use the population std so divides by n not (n-1) - expectedStd = expectedStd.dup().muli(Math.sqrt(maxN)).divi(Math.sqrt(maxN)).transpose(); + expectedStd = expectedStd.dup().muli(Math.sqrt(maxN)).divi(Math.sqrt(maxN)); //min max assumes all scaling values are +ve - expectedMin = Nd4j.ones(Nd4j.defaultFloatingPointType(), 3, 1).muliColumnVector(featureScale); - expectedMax = Nd4j.ones(Nd4j.defaultFloatingPointType(),3, 1).muli(samples * timeSteps).muliColumnVector(featureScale); + expectedMin = Nd4j.ones(Nd4j.defaultFloatingPointType(), 3).muliColumnVector(featureScale); + expectedMax = Nd4j.ones(Nd4j.defaultFloatingPointType(),3).muli(samples * timeSteps).muliColumnVector(featureScale); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java index 04445cb681e1..bd4bedf667cc 100644 --- 
a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java @@ -178,7 +178,7 @@ public void test2dAnd4() { public void testSliceAssign1() { INDArray array = Nd4j.zeros(4, 4); - INDArray patch = Nd4j.create(new float[] {1e-5f, 1e-5f, 1e-5f}).reshape(1, -1); + INDArray patch = Nd4j.create(new float[] {1e-5f, 1e-5f, 1e-5f}); INDArray slice = array.slice(1); int[] idx = new int[] {0, 1, 3}; @@ -186,11 +186,11 @@ public void testSliceAssign1() { INDArray subarray = slice.get(range); - System.out.println("Subarray: " + Arrays.toString(subarray.data().asFloat()) + " isView: " + subarray.isView()); + //System.out.println("Subarray: " + Arrays.toString(subarray.data().asFloat()) + " isView: " + subarray.isView()); slice.put(range, patch); - System.out.println("Array after being patched: " + Arrays.toString(array.data().asFloat())); + //System.out.println("Array after being patched: " + Arrays.toString(array.data().asFloat())); assertFalse(BooleanIndexing.and(array, Conditions.equals(0f))); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java index dbfa4458ca00..6c4fb96defc9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java @@ -180,9 +180,9 @@ public void testEuclideanDistance() { @Test public void testScalarMaxOp() { INDArray scalarMax = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).negi(); - INDArray postMax = Nd4j.ones(DataType.DOUBLE, 1, 6); + INDArray postMax = Nd4j.ones(DataType.DOUBLE, 6); Nd4j.getExecutioner().exec(new ScalarMax(scalarMax, 1)); - assertEquals(getFailureMessage(), scalarMax, postMax); + assertEquals(getFailureMessage(), postMax, scalarMax); } @Test @@ -400,7 +400,7 @@ public void testStridedLog() { INDArray slice = arr.slice(0); Log exp = new Log(slice); opExecutioner.exec(exp); - INDArray assertion = Nd4j.create(Nd4j.createBuffer(new double[] {0.0, 0.6931471824645996, 1.0986123085021973})).reshape(1, -1); + INDArray assertion = Nd4j.create(new double[] {0.0, 0.6931471824645996, 1.0986123085021973}); assertEquals(getFailureMessage(), assertion, slice); } @@ -415,7 +415,7 @@ public void testStridedExp() { expected[i] = (float) Math.exp(slice.getDouble(i)); Exp exp = new Exp(slice); opExecutioner.exec(exp); - assertEquals(getFailureMessage(), Nd4j.create(Nd4j.createBuffer(expected)).reshape(1, -1), slice); + assertEquals(getFailureMessage(), Nd4j.create(expected), slice); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java index d995eaee661e..ed6305bcb5bd 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java @@ -79,7 +79,7 @@ public void tearDown() { @Test public void testCrossBackendEquality1() { - int[] shape = {1, 12}; + int[] shape = {12}; double mean = 0; double standardDeviation = 1.0; INDArray exp = Nd4j.create(new double[] {-0.832718168582558, 1.3312306172061867, -0.27101354040045766, 1.0368130323476494, -0.6257379511224601, 0.30653534119847814, 0.28250229228899343, -0.5464191486048424, 
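Note: a recurring pattern in the assertions above is that Nd4j.create(double[]) and Nd4j.ones(dtype, n) now produce rank-1 arrays, so explicit reshape(1, -1) calls either disappear or move to the other side of the comparison. A minimal sketch, assuming the usual ND4J imports:

    INDArray v = Nd4j.create(new double[]{1, 2, 3});   // shape [3], no longer [1, 3]
    INDArray asRow = v.reshape(1, -1);                 // explicit [1, 3] row vector when one is genuinely required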
0.5182898732953277, 1.463107608378911, 0.5634855878214299, -1.4979616922031507}); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTests.java index ad31643de5ee..f11978756755 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTests.java @@ -103,7 +103,6 @@ public void testSixteenSecondDim() { INDArray arr = baseArr.tensorAlongDimension(i, 2); assertEquals("Failed at index " + i, assertions[i], arr); } - } @@ -115,7 +114,7 @@ public void testVectorAlongDimension() { INDArray vectorDimensionTest = arr.vectorAlongDimension(1, 2); assertEquals(assertion, vectorDimensionTest); INDArray zeroOne = arr.vectorAlongDimension(0, 1); - assertEquals(zeroOne, Nd4j.create(new float[] {1, 5, 9})); + assertEquals(Nd4j.create(new float[] {1, 5, 9}), zeroOne); INDArray testColumn2Assertion = Nd4j.create(new float[] {13, 17, 21}); INDArray testColumn2 = arr.vectorAlongDimension(1, 1); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java index 71f2a99ff7df..06b593df9deb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java @@ -187,7 +187,7 @@ public void testOtherReshape() { INDArray slice = nd.slice(1, 0); - INDArray vector = slice.reshape(1, 3); + INDArray vector = slice; for (int i = 0; i < vector.length(); i++) { System.out.println(vector.getDouble(i)); } @@ -197,7 +197,7 @@ public void testOtherReshape() { @Test public void testVectorAlongDimension() { INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2); - INDArray assertion = Nd4j.create(new double[] {3, 4}, new long[] {1, 2}); + INDArray assertion = Nd4j.create(new double[] {3, 4}); INDArray vectorDimensionTest = arr.vectorAlongDimension(1, 2); assertEquals(assertion, vectorDimensionTest); val vectorsAlongDimension1 = arr.vectorsAlongDimension(1); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java index 34c4c6ca429e..9b751642322e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java @@ -64,13 +64,10 @@ public void testConcatVertically() { INDArray vstack = Nd4j.vstack(slice1, slice2); assertEquals(arr3, vstack); - INDArray col1 = arr2.getColumn(0); - INDArray col2 = arr2.getColumn(1); + INDArray col1 = arr2.getColumn(0).reshape(5, 1); + INDArray col2 = arr2.getColumn(1).reshape(5, 1); INDArray vstacked = Nd4j.vstack(col1, col2); assertEquals(Nd4j.create(10, 1), vstacked); - - - } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTests.java index 2f9d3c3361c7..b67f684c7b55 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTests.java @@ -216,7 +216,7 @@ public void 
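Note: the ShapeTestsC change above drops the explicit [1, 2] expected shape because vectorAlongDimension now returns rank-1 vectors as well. A quick sketch:

    INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2);
    INDArray vec = arr.vectorAlongDimension(1, 2);     // shape [2], equal to Nd4j.create(new double[]{3, 4})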
concatGetBug() { assertEquals(first, fMerged.get(NDArrayIndex.interval(0, nExamples1), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all())); - INDArray get = fMerged.get(NDArrayIndex.interval(nExamples1, nExamples1 + nExamples2, true), NDArrayIndex.all(), + INDArray get = fMerged.get(NDArrayIndex.interval(nExamples1, nExamples1 + nExamples2), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()); assertEquals(second, get.dup()); //Passes assertEquals(second, get); //Fails @@ -228,8 +228,7 @@ public void testShape() { INDArray subarray = ndarray.get(NDArrayIndex.point(0), NDArrayIndex.all()); assertTrue(subarray.isRowVector()); val shape = subarray.shape(); - assertEquals(shape[0], 1); - assertEquals(shape[1], 2); + assertArrayEquals(new long[]{2}, shape); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java index e7f8a449eef0..834ad5689fd4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java @@ -172,22 +172,22 @@ public void testRowVectorInterval() { } INDArray first10a = row.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 10)); - assertArrayEquals(first10a.shape(), new int[] {1, 10}); + assertArrayEquals(first10a.shape(), new int[] {10}); for (int i = 0; i < 10; i++) assertTrue(first10a.getDouble(i) == i); - INDArray first10b = row.get(NDArrayIndex.interval(0, 10)); - assertArrayEquals(first10b.shape(), new int[] {1, 10}); + INDArray first10b = row.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 10)); + assertArrayEquals(first10b.shape(), new int[] {10}); for (int i = 0; i < 10; i++) assertTrue(first10b.getDouble(i) == i); INDArray last10a = row.get(NDArrayIndex.point(0), NDArrayIndex.interval(20, 30)); - assertArrayEquals(last10a.shape(), new int[] {1, 10}); + assertArrayEquals(last10a.shape(), new int[] {10}); for (int i = 0; i < 10; i++) - assertTrue(last10a.getDouble(i) == 20 + i); + assertEquals(i+20, last10a.getDouble(i), 1e-6); - INDArray last10b = row.get(NDArrayIndex.interval(20, 30)); - assertArrayEquals(last10b.shape(), new int[] {1, 10}); + INDArray last10b = row.get(NDArrayIndex.point(0), NDArrayIndex.interval(20, 30)); + assertArrayEquals(last10b.shape(), new int[] {10}); for (int i = 0; i < 10; i++) assertTrue(last10b.getDouble(i) == 20 + i); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTestsC.java index 7dbc5b08a808..0e696c884a1d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTestsC.java @@ -128,8 +128,8 @@ public void testGetRow() { @Test public void testVectorIndexing() { INDArray zeros = Nd4j.create(1, 400000); - INDArray get = zeros.get(NDArrayIndex.interval(0, 300000)); - assertArrayEquals(new long[] {1, 300000}, get.shape()); + INDArray get = zeros.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 300000)); + assertArrayEquals(new long[] {300000}, get.shape()); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java index 
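Note: the ConcatTestsC change above reshapes each column back to [5, 1] before stacking, since getColumn now returns rank-1 views. A brief sketch of the adjusted usage:

    INDArray m = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).reshape(5, 2);
    INDArray col0 = m.getColumn(0).reshape(5, 1);   // rank-1 [5] view reshaped to a proper column
    INDArray col1 = m.getColumn(1).reshape(5, 1);
    INDArray stacked = Nd4j.vstack(col0, col1);     // shape [10, 1]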
b9eddd75ee57..39c4a7b8f215 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java @@ -1175,6 +1175,35 @@ public void testBadGenerationLeverageMigrateDetach(){ } } + @Test + public void testDtypeLeverage(){ + + for(DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + for (DataType arrayDType : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + Nd4j.setDefaultDataTypes(globalDtype, globalDtype); + + WorkspaceConfiguration configOuter = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) + .policyAllocation(AllocationPolicy.OVERALLOCATE).policyLearning(LearningPolicy.NONE).build(); + WorkspaceConfiguration configInner = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) + .policyAllocation(AllocationPolicy.OVERALLOCATE).policyLearning(LearningPolicy.NONE).build(); + + try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configOuter, "ws")) { + INDArray arr = Nd4j.create(arrayDType, 3, 4); + try (MemoryWorkspace wsInner = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configOuter, "wsInner")) { + INDArray leveraged = arr.leverageTo("ws"); + assertTrue(leveraged.isAttached()); + assertEquals(arrayDType, leveraged.dataType()); + + INDArray detached = leveraged.detach(); + assertFalse(detached.isAttached()); + assertEquals(arrayDType, detached.dataType()); + } + } + } + } + Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); + } + @Override public char ordering() { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java index 0df9ec2fbedf..94ebbcabfe44 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java @@ -219,7 +219,7 @@ public void testViewDetach_1() { (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(configuration, "WS109"); INDArray row = Nd4j.linspace(1, 10, 10); - INDArray exp = Nd4j.create(1, 10).assign(2.0); + INDArray exp = Nd4j.create(10).assign(2.0); INDArray result = null; try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS109")) { INDArray matrix = Nd4j.create(10, 10); diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java index 8cbf9a7ba022..d22f4a1e39fe 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java @@ -1885,12 +1885,6 @@ else if (DataTypeUtil.getDtypeFromContext() == DataType.FLOAT || currentType == else if (DataTypeUtil.getDtypeFromContext() == DataType.HALF && currentType != DataType.INT) elementSize = 2; - if (currentType != DataTypeUtil.getDtypeFromContext() && currentType != DataType.HALF && currentType != DataType.INT - && currentType != DataType.LONG && !(DataTypeUtil.getDtypeFromContext() == DataType.DOUBLE)) { - log.warn("Loading a data stream with opType different from what is set globally. 
Expect precision loss"); - if (DataTypeUtil.getDtypeFromContext() == DataType.INT) - log.warn("Int to float/double widening UNSUPPORTED!!!"); - } pointerIndexerByCurrentType(currentType); if (currentType != DataType.COMPRESSED)