[WIP] QA, fixes, DL4J net convertDataType methods #7531

Merged: 53 commits, merged Apr 17, 2019

Commits (changes shown from all 53 commits)
de3da82
Fix BaseNDArray.equalsWithEps issue for scalars of different ranks
AlexDBlack Apr 11, 2019
8856583
#7447 Fix slice on row vector
AlexDBlack Apr 11, 2019
b3a3935
#7483 Remove old deserialization warnings
AlexDBlack Apr 11, 2019
164b567
#6861 SameDiff datatype validation, round 1
AlexDBlack Apr 11, 2019
3a799d6
#6861 SameDiff datatype validation, round 2
AlexDBlack Apr 11, 2019
69cc986
#6861 SameDiff datatype validation, round 3
AlexDBlack Apr 11, 2019
658dcab
More rank 2 minimum shape fixes
AlexDBlack Apr 11, 2019
fc9f9b9
Multiple test fixes after changing rank2 minimum shapes
AlexDBlack Apr 11, 2019
7e4d5be
Test fixes
AlexDBlack Apr 11, 2019
6977ae1
#7520 add MultiLayerNetwork.convertDataType(DataType) + test
AlexDBlack Apr 11, 2019
39e8228
Datatype cleanup and fixes
AlexDBlack Apr 12, 2019
75c85b3
DL4J: Fixes for global (default) vs. network datatypes
AlexDBlack Apr 12, 2019
26ddccc
Fix incorrect datatype when arrays (different to default dtype) are d…
AlexDBlack Apr 12, 2019
8c25dd2
Multiple fixes, improve tests
AlexDBlack Apr 12, 2019
2f9968a
Test
AlexDBlack Apr 12, 2019
17c693d
#7532 New network datatype configuration
AlexDBlack Apr 13, 2019
8ddbb54
Pass network dtype to layer/vertex initialization
AlexDBlack Apr 13, 2019
1075605
Yolo datatype fixes
AlexDBlack Apr 13, 2019
6ae3f97
More fixes, more tests
AlexDBlack Apr 13, 2019
b495a2a
More fixes, more tests
AlexDBlack Apr 13, 2019
6e64e0a
Fix bug in PoolHelperVertex backprop
AlexDBlack Apr 13, 2019
b4c323a
Vertex dtype tests; misc fixes
AlexDBlack Apr 13, 2019
5548436
Fix for BaseReduce3Op dtype
AlexDBlack Apr 13, 2019
4bb0b6c
More fix; finally all layers/vertices/preprocessors tested for dtypes
AlexDBlack Apr 13, 2019
79d7270
Fix slices()
AlexDBlack Apr 13, 2019
2d2a559
Fixes - gradient check dtype issues
AlexDBlack Apr 15, 2019
b2e9552
Pass network dtype when constructing layers
AlexDBlack Apr 15, 2019
ddcc6bf
Pass network dtype when constructing vertices
AlexDBlack Apr 15, 2019
3cc46da
Layer dtype/casting fixes
AlexDBlack Apr 15, 2019
265046d
Various fixes
AlexDBlack Apr 15, 2019
69f079c
Fix Shape.elementWiseStride for 1d view case
AlexDBlack Apr 15, 2019
6a88228
#7092 INDArray.get(point,x)/get(x,point) returns 1d array
AlexDBlack Apr 15, 2019
479ab92
More 1d getRow/getCol fixes
AlexDBlack Apr 15, 2019
7bcef15
Indexing/sub-array fixes
AlexDBlack Apr 15, 2019
9d68d31
More test and indexing fixes
AlexDBlack Apr 15, 2019
4f648e5
More test fixes, add getRow(i,keepDim) and getColumn(i,keepDim)
AlexDBlack Apr 15, 2019
81c3b26
More indexing/test fixes
AlexDBlack Apr 16, 2019
1b8273e
More fixes
AlexDBlack Apr 16, 2019
08dac0b
More fixes
AlexDBlack Apr 16, 2019
b8c9399
More fixes
AlexDBlack Apr 16, 2019
f0be8b4
#7550 Evaluation dtype tests + fixes
AlexDBlack Apr 16, 2019
6b5380d
Nd4j.gemm result dtype fix
AlexDBlack Apr 16, 2019
da891cc
Next round of fixes
AlexDBlack Apr 16, 2019
22964ed
Even more dtype fixes...
AlexDBlack Apr 16, 2019
c38ca8d
Datavec and more DL4J fixes
AlexDBlack Apr 16, 2019
d8aee16
Next round of fixes
AlexDBlack Apr 16, 2019
fc2c378
DL4J cuDNN helpers - dtype improvements/fixes
AlexDBlack Apr 16, 2019
95ae9f0
Another round of fixes
AlexDBlack Apr 17, 2019
3c2f244
Datavec fixes
AlexDBlack Apr 17, 2019
275726d
DL4J Fixes
AlexDBlack Apr 17, 2019
58e9981
Keras/Spark/elementwisevertex fixes
AlexDBlack Apr 17, 2019
a156576
Final (hopefully) fixes
AlexDBlack Apr 17, 2019
ff7683d
Last set of fixes
AlexDBlack Apr 17, 2019
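
The headline feature is the new MultiLayerNetwork.convertDataType(DataType) from commit 6977ae1 (#7520), alongside the per-network datatype configuration of #7532. A minimal sketch of the conversion call, assuming the merged signature; the network structure, sizes, and prints are placeholders, not code from this PR:

```java
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class ConvertDataTypeSketch {
    public static void main(String[] args) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .activation(Activation.IDENTITY).nIn(4).nOut(3).build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();   // parameters are created in the default datatype (FLOAT)

        // New in this PR: cast parameters (and updater state) to another type.
        // Assumed here: returns a converted copy rather than mutating in place.
        MultiLayerNetwork netDouble = net.convertDataType(DataType.DOUBLE);
        System.out.println(netDouble.params().dataType());   // DOUBLE
    }
}
```

The diffs below are mostly fallout from two related changes named in the commit messages: rank-1 vectors becoming first-class (removal of the rank-2 minimum shape) and stricter datatype handling.
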
@@ -771,6 +771,8 @@ public IntegerRanges getRange(String name, String defaultValue) {
      */
     public Collection<String> getStringCollection(String name) {
         String valueString = get(name);
+        if(valueString == null)
+            return null;
         return Arrays.asList(StringUtils.split(valueString, ","));
     }

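For callers, the effect of the added guard, sketched; the package of this Configuration class is an assumption:

```java
import java.util.Collection;
import org.datavec.api.conf.Configuration;   // assumed package for the class in this hunk

public class NullGuardSketch {
    public static void main(String[] args) {
        Configuration conf = new Configuration();
        // Key was never set: get(name) returns null, and with the guard above
        // getStringCollection now returns null instead of failing on the
        // null valueString inside StringUtils.split.
        Collection<String> values = conf.getStringCollection("no.such.key");
        System.out.println(values == null ? "unset" : values);
    }
}
```
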
@@ -158,7 +158,7 @@ else if (idx != null)
     }

     protected INDArray makeBOWNDArray(Collection<Integer> indices) {
-        INDArray counts = Nd4j.zeros(vocabulary.size());
+        INDArray counts = Nd4j.zeros(1, vocabulary.size());
         for (Integer idx : indices)
             counts.putScalar(idx, counts.getDouble(idx) + 1);
         Nd4j.getExecutioner().commit();

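Context for this one-argument-to-two change, sketched: with the rank-2 minimum removed elsewhere in this PR, Nd4j.zeros(n) now produces a true rank-1 vector, so code that needs a [1, n] row has to ask for it explicitly:

```java
import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RowVectorShapes {
    public static void main(String[] args) {
        INDArray v1 = Nd4j.zeros(5);      // rank 1, shape [5] under the new semantics
        INDArray v2 = Nd4j.zeros(1, 5);   // rank 2 row vector, shape [1, 5]

        System.out.println(v1.rank() + " " + Arrays.toString(v1.shape()));   // 1 [5]
        System.out.println(v2.rank() + " " + Arrays.toString(v2.shape()));   // 2 [1, 5]
    }
}
```

That is why makeBOWNDArray now allocates zeros(1, vocabulary.size()): downstream writables expect a 2d row.
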
@@ -56,7 +56,7 @@ public StringListToIndicesNDArrayTransform(@JsonProperty("columnName") String co

     @Override
     protected INDArray makeBOWNDArray(Collection<Integer> indices) {
-        INDArray counts = Nd4j.zeros(indices.size());
+        INDArray counts = Nd4j.zeros(1, indices.size());
         List<Integer> indicesSorted = new ArrayList<>(indices);
         Collections.sort(indicesSorted);
         for (int i = 0; i < indicesSorted.size(); i++)

@@ -306,8 +306,8 @@ private static List<List<Writable>> getClassificationWritableMatrix(DataSet data
        List<List<Writable>> writableMatrix = new ArrayList<>();

        for (int i = 0; i < dataSet.numExamples(); i++) {
-            List<Writable> writables = toRecord(dataSet.getFeatures().getRow(i));
-            writables.add(new IntWritable(Nd4j.argMax(dataSet.getLabels().getRow(i), 1).getInt(0)));
+            List<Writable> writables = toRecord(dataSet.getFeatures().getRow(i, true));
+            writables.add(new IntWritable(Nd4j.argMax(dataSet.getLabels().getRow(i)).getInt(0)));

            writableMatrix.add(writables);
        }

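A sketch of the two idioms this hunk adopts: getRow(i, true) from commit 4f648e5 keeps the row 2d, while plain getRow(i) is now rank 1 (#7092), which is also why argMax no longer needs a dimension argument:

```java
import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class GetRowKeepDim {
    public static void main(String[] args) {
        INDArray m = Nd4j.rand(3, 4);

        INDArray r1 = m.getRow(0);        // now rank 1: shape [4]
        INDArray r2 = m.getRow(0, true);  // keepDim = true: shape [1, 4]

        System.out.println(Arrays.toString(r1.shape()));   // [4]
        System.out.println(Arrays.toString(r2.shape()));   // [1, 4]

        // argMax over a rank-1 row needs no explicit dimension:
        System.out.println(Nd4j.argMax(r1).getInt(0));
    }
}
```
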
@@ -56,6 +56,7 @@
 import org.joda.time.DateTimeZone;
 import org.junit.Assert;
 import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;

 import java.io.ByteArrayInputStream;
@@ -1438,7 +1439,8 @@ public void testStringListToCountsNDArrayTransform() throws Exception {

        List<Writable> out = t.map(l);

-        assertEquals(Collections.singletonList(new NDArrayWritable(Nd4j.create(new double[]{2,3,0}, new long[]{1,3}, Nd4j.dataType()))), out);
+        INDArray exp = Nd4j.create(new double[]{2,3,0}, new long[]{1,3}, Nd4j.dataType());
+        assertEquals(Collections.singletonList(new NDArrayWritable(exp)), out);

        String json = JsonMappers.getMapper().writeValueAsString(t);
        Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToCountsNDArrayTransform.class);

@@ -180,11 +180,13 @@ public INDArray asRowVector(Frame image) throws IOException {
    }

    public INDArray asRowVector(Mat image) throws IOException {
-        return asMatrix(image).ravel();
+        INDArray arr = asMatrix(image);
+        return arr.reshape('c', 1, arr.length());
    }

    public INDArray asRowVector(org.opencv.core.Mat image) throws IOException {
-        return asMatrix(image).ravel();
+        INDArray arr = asMatrix(image);
+        return arr.reshape('c', 1, arr.length());
    }

    static Mat convert(PIX pix) {

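The reasoning, sketched: ravel() now returns a rank-1 array, but asRowVector is contractually a [1, length] matrix, so the loader reshapes explicitly in c order:

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RavelVsReshape {
    public static void main(String[] args) {
        INDArray img = Nd4j.rand(new int[]{3, 32, 32});    // stand-in for a decoded image

        INDArray flat = img.ravel();                        // rank 1: shape [3072]
        INDArray row  = img.reshape('c', 1, img.length());  // rank 2: shape [1, 3072]

        System.out.println(flat.rank() + " vs " + row.rank());   // 1 vs 2
    }
}
```
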
@@ -173,7 +173,7 @@ record = rr.nextRecord();
            assertEquals(42, transform.getCurrentImage().getHeight());
            INDArray labelArray = ((NDArrayWritable)next.get(1)).get();
            BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
-            assertEquals(nonzeroCount[i++], labelArray.ravel().sum(1).getInt(0));
+            assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
        }

        ImageTransform transform2 = new ResizeImageTransform(1024, 2048);
@@ -186,7 +186,7 @@ record = rr.nextRecord();
            assertEquals(2048, transform2.getCurrentImage().getHeight());
            INDArray labelArray = ((NDArrayWritable)next.get(1)).get();
            BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
-            assertEquals(nonzeroCount[i++], labelArray.ravel().sum(1).getInt(0));
+            assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
        }

        //Make sure image flip does not break labels and are correct for new image size dimensions:
@@ -201,7 +201,7 @@ record = rr.nextRecord();
            List<Writable> next = rrTransform3.next();
            INDArray labelArray = ((NDArrayWritable)next.get(1)).get();
            BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
-            assertEquals(nonzeroCount[i++], labelArray.ravel().sum(1).getInt(0));
+            assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
        }

        //Test that doing a downscale with the native image loader directly instead of a transform does not cause an exception:
@@ -217,7 +217,7 @@ record = rr.nextRecord();

            INDArray labelArray = ((NDArrayWritable)next.get(1)).get();
            BooleanIndexing.replaceWhere(labelArray, 1, Conditions.notEquals(0));
-            assertEquals(nonzeroCount[i++], labelArray.ravel().sum(1).getInt(0));
+            assertEquals(nonzeroCount[i++], labelArray.sum().getInt(0));
        }
    }

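All four edits in this file are the same simplification. A sketch of the reduction shapes involved, under the new semantics:

```java
import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class SumShapes {
    public static void main(String[] args) {
        INDArray a = Nd4j.rand(3, 4);

        System.out.println(Arrays.toString(a.sum().shape()));         // [] - scalar
        System.out.println(Arrays.toString(a.sum(1).shape()));        // [3] - dim dropped
        System.out.println(Arrays.toString(a.sum(true, 1).shape()));  // [3, 1] - dim kept

        // labelArray.ravel().sum(1) therefore collapses to labelArray.sum()
        System.out.println(a.sum().getInt(0));
    }
}
```
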
@@ -19,6 +19,7 @@
 import org.apache.spark.api.java.function.Function;
 import org.datavec.api.writable.NDArrayWritable;
 import org.datavec.api.writable.Writable;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.NDArrayIndex;
@@ -53,13 +54,13 @@ public INDArray call(List<Writable> c) throws Exception {
            }
        }

-        INDArray arr = Nd4j.zeros(length);
+        INDArray arr = Nd4j.zeros(DataType.FLOAT, 1, length);
        int idx = 0;
        for (Writable w : c) {
            if (w instanceof NDArrayWritable) {
                INDArray subArr = ((NDArrayWritable) w).get();
                int subLength = subArr.columns();
-                arr.get(NDArrayIndex.interval(idx, idx + subLength)).assign(subArr);
+                arr.get(NDArrayIndex.point(0), NDArrayIndex.interval(idx, idx + subLength)).assign(subArr);
                idx += subLength;
            } else {
                arr.putScalar(idx++, w.toDouble());

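Two changes meet in this hunk: the accumulator is allocated with an explicit DataType so it no longer inherits whatever the global default happens to be, and, since a [1, length] array is rank 2, the interval assignment now indexes both dimensions. A sketch:

```java
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class ExplicitDtypeAssign {
    public static void main(String[] args) {
        INDArray arr = Nd4j.zeros(DataType.FLOAT, 1, 10);   // FLOAT regardless of the default

        // point(0) pins the row; the view over columns 2..5 is then filled in place
        arr.get(NDArrayIndex.point(0), NDArrayIndex.interval(2, 6)).assign(1.0);

        System.out.println(arr.dataType());   // FLOAT
        System.out.println(arr);              // ones in columns 2..5, zeros elsewhere
    }
}
```
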
@@ -31,6 +31,7 @@
 import org.datavec.api.writable.NDArrayWritable;
 import org.datavec.spark.BaseSparkTest;
 import org.datavec.python.PythonTransform;
+import org.junit.Ignore;
 import org.junit.Test;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;

@@ -56,13 +56,13 @@ public void testAnalysis() throws Exception {

        List<List<Writable>> data = new ArrayList<>();
        data.add(Arrays.asList((Writable) new IntWritable(0), new DoubleWritable(1.0), new LongWritable(1000),
-                new Text("A"), new NDArrayWritable(Nd4j.valueArrayOf(10, 100.0))));
+                new Text("A"), new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 100.0))));
        data.add(Arrays.asList((Writable) new IntWritable(5), new DoubleWritable(0.0), new LongWritable(2000),
-                new Text("A"), new NDArrayWritable(Nd4j.valueArrayOf(10, 200.0))));
+                new Text("A"), new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 200.0))));
        data.add(Arrays.asList((Writable) new IntWritable(3), new DoubleWritable(10.0), new LongWritable(3000),
-                new Text("A"), new NDArrayWritable(Nd4j.valueArrayOf(10, 300.0))));
+                new Text("A"), new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 300.0))));
        data.add(Arrays.asList((Writable) new IntWritable(-1), new DoubleWritable(-1.0), new LongWritable(20000),
-                new Text("B"), new NDArrayWritable(Nd4j.valueArrayOf(10, 400.0))));
+                new Text("B"), new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 400.0))));

        JavaRDD<List<Writable>> rdd = sc.parallelize(data);

@@ -1204,8 +1204,8 @@ public void testRecordReaderDataSetIteratorConcat() {
        DataSetIterator iter = new RecordReaderDataSetIterator(rr, 1, 5, 3);

        DataSet ds = iter.next();
-        INDArray expF = Nd4j.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8, 9});
-        INDArray expL = Nd4j.create(new float[] {0, 1, 0});
+        INDArray expF = Nd4j.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8, 9}, new int[]{1,9});
+        INDArray expL = Nd4j.create(new float[] {0, 1, 0}, new int[]{1,3});

        assertEquals(expF, ds.getFeatures());
        assertEquals(expL, ds.getLabels());
@@ -1222,7 +1222,7 @@ public void testRecordReaderDataSetIteratorConcat2() {
        DataSetIterator iter = new RecordReaderDataSetIterator(rr, 1);

        DataSet ds = iter.next();
-        INDArray expF = Nd4j.create(new float[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
+        INDArray expF = Nd4j.create(new float[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, new int[]{1,10});

        assertEquals(expF, ds.getFeatures());
    }

@@ -212,9 +212,9 @@ public void testSplittingCSV() throws Exception {
            assertNotNull(lmds[i]);

            //Get the subsets of the original iris data
-            INDArray expIn1 = fds.get(all(), point(0));
-            INDArray expIn2 = fds.get(all(), NDArrayIndex.interval(1, 2, true));
-            INDArray expOut1 = fds.get(all(), point(3));
+            INDArray expIn1 = fds.get(all(), interval(0,0,true));
+            INDArray expIn2 = fds.get(all(), interval(1, 2, true));
+            INDArray expOut1 = fds.get(all(), interval(3,3,true));
            INDArray expOut2 = lds;

            assertEquals(expIn1, fmds[0]);
@@ -693,14 +693,14 @@ public void testTimeSeriesRandomOffset() {
        INDArray f = mds.getFeatures(0);
        INDArray l = mds.getLabels(0);

-        INDArray expF1 = Nd4j.create(new double[] {1.0});
-        INDArray expL1 = Nd4j.create(new double[] {2.0});
+        INDArray expF1 = Nd4j.create(new double[] {1.0}, new int[]{1,1});
+        INDArray expL1 = Nd4j.create(new double[] {2.0}, new int[]{1,1});

-        INDArray expF2 = Nd4j.create(new double[] {10, 20, 30});
-        INDArray expL2 = Nd4j.create(new double[] {11, 21, 31});
+        INDArray expF2 = Nd4j.create(new double[] {10, 20, 30}, new int[]{1,3});
+        INDArray expL2 = Nd4j.create(new double[] {11, 21, 31}, new int[]{1,3});

-        INDArray expF3 = Nd4j.create(new double[] {100, 200, 300, 400, 500});
-        INDArray expL3 = Nd4j.create(new double[] {101, 201, 301, 401, 501});
+        INDArray expF3 = Nd4j.create(new double[] {100, 200, 300, 400, 500}, new int[]{1,5});
+        INDArray expL3 = Nd4j.create(new double[] {101, 201, 301, 401, 501}, new int[]{1,5});

        assertEquals(expF1, f.get(point(0), all(),
                NDArrayIndex.interval(expOffsetSeq1, expOffsetSeq1 + 1)));

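The indexing rule behind the point-to-interval swap, sketched: point(i) collapses that dimension (so a column comes back rank 1), while the inclusive interval(i, i, true) keeps it, preserving the [rows, 1] column shape these tests compare against:

```java
import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import static org.nd4j.linalg.indexing.NDArrayIndex.*;

public class PointVsInterval {
    public static void main(String[] args) {
        INDArray m = Nd4j.rand(4, 3);

        INDArray c1 = m.get(all(), point(0));             // rank 1: shape [4]
        INDArray c2 = m.get(all(), interval(0, 0, true)); // rank 2: shape [4, 1]

        System.out.println(Arrays.toString(c1.shape()));  // [4]
        System.out.println(Arrays.toString(c2.shape()));  // [4, 1]
    }
}
```
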
@@ -44,7 +44,7 @@ public void testDSI(){
            assertArrayEquals(new long[]{3,5}, ds.getLabels().shape());

            assertTrue(ds.getFeatures().minNumber().doubleValue() >= 0.0 && ds.getFeatures().maxNumber().doubleValue() <= 1.0);
-            assertEquals(Nd4j.ones(3,1), ds.getLabels().sum(1));
+            assertEquals(Nd4j.ones(3), ds.getLabels().sum(1));
        }
        assertEquals(5, count);
    }

@@ -168,7 +168,7 @@ public void testTimeTermination() {
        EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
        EarlyStoppingConfiguration<ComputationGraph> esConf = new EarlyStoppingConfiguration.Builder<ComputationGraph>()
                .epochTerminationConditions(new MaxEpochsTerminationCondition(10000))
-                .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(10, TimeUnit.SECONDS),
+                .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(5, TimeUnit.SECONDS),
                        new MaxScoreIterationTerminationCondition(50)) //Initial score is ~8
                .scoreCalculator(new DataSetLossCalculator(irisIter, true))
                .modelSaver(saver).build();
@@ -184,7 +184,7 @@

        assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition,
                result.getTerminationReason());
-        String expDetails = new MaxTimeIterationTerminationCondition(10, TimeUnit.SECONDS).toString();
+        String expDetails = new MaxTimeIterationTerminationCondition(5, TimeUnit.SECONDS).toString();
        assertEquals(expDetails, result.getTerminationDetails());
    }

@@ -73,7 +73,7 @@ public void testSerde() {
            evalLabel.putScalar(i, i % 3, 1.0);
        }
        INDArray evalProb = Nd4j.rand(10, 3);
-        evalProb.diviColumnVector(evalProb.sum(1));
+        evalProb.diviColumnVector(evalProb.sum(true,1));
        evaluation.eval(evalLabel, evalProb);
        roc3.eval(evalLabel, evalProb);
        ec.eval(evalLabel, evalProb);

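Why keepDims is needed here, sketched: diviColumnVector expects a [rows, 1] operand, and sum(1) now drops the reduced dimension, so the test normalizes rows with sum(true, 1) instead:

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class KeepDimsNormalize {
    public static void main(String[] args) {
        INDArray probs = Nd4j.rand(10, 3);

        // sum(1) is now shape [10]; sum(true, 1) keeps shape [10, 1]
        INDArray rowSums = probs.sum(true, 1);
        probs.diviColumnVector(rowSums);    // in-place row normalization

        System.out.println(probs.sum(1));   // each entry ~1.0
    }
}
```
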
@@ -14,6 +14,7 @@
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.learning.config.NoOp;
@@ -70,6 +71,7 @@ public void testSelfAttentionLayer() {


        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                .dataType(DataType.DOUBLE)
                .activation(Activation.TANH)
                .updater(new NoOp())
                .weightInit(WeightInit.XAVIER)
@@ -134,6 +136,7 @@ public void testLearnedSelfAttentionLayer() {


        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                .dataType(DataType.DOUBLE)
                .activation(Activation.TANH)
                .updater(new NoOp())
                .weightInit(WeightInit.XAVIER)
@@ -167,6 +170,7 @@ public void testRecurrentAttentionLayer_differingTimeSteps(){
        int layerSize = 8;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                .dataType(DataType.DOUBLE)
                .activation(Activation.IDENTITY)
                .updater(new NoOp())
                .weightInit(WeightInit.XAVIER)
@@ -233,6 +237,7 @@ public void testRecurrentAttentionLayer() {


        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                .dataType(DataType.DOUBLE)
                .activation(Activation.IDENTITY)
                .updater(new NoOp())
                .weightInit(WeightInit.XAVIER)
@@ -294,6 +299,7 @@ public void testAttentionVertex() {


        ComputationGraphConfiguration graph = new NeuralNetConfiguration.Builder()
+                .dataType(DataType.DOUBLE)
                .activation(Activation.TANH)
                .updater(new NoOp())
                .weightInit(WeightInit.XAVIER)
@@ -360,6 +366,7 @@ public void testAttentionVertexSameInput() {


        ComputationGraphConfiguration graph = new NeuralNetConfiguration.Builder()
+                .dataType(DataType.DOUBLE)
                .activation(Activation.TANH)
                .updater(new NoOp())
                .weightInit(WeightInit.XAVIER)

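The .dataType(DataType.DOUBLE) lines throughout these tests use the per-network datatype configuration added in commit 17c693d (#7532); gradient checks need double precision, and previously that meant flipping the global default. A minimal sketch of the builder usage, with placeholder layers and sizes:

```java
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class NetworkDtypeConfig {
    public static void main(String[] args) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .dataType(DataType.DOUBLE)   // parameters, activations, updater state in double
                .list()
                .layer(new DenseLayer.Builder().nIn(4).nOut(8).build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(8).nOut(3).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        System.out.println(net.params().dataType());   // DOUBLE
    }
}
```
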
@@ -80,6 +80,7 @@ public void testGradient2dSimple() {

        MultiLayerConfiguration.Builder builder =
                new NeuralNetConfiguration.Builder().updater(new NoOp())
+                        .dataType(DataType.DOUBLE)
                        .seed(12345L)
                        .dist(new NormalDistribution(0, 1)).list()
                        .layer(0, new DenseLayer.Builder().nIn(4).nOut(3)
@@ -125,6 +126,7 @@ public void testGradientCnnSimple() {

        for(boolean useLogStd : new boolean[]{true, false}) {
            MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
+                    .dataType(DataType.DOUBLE)
                    .updater(new NoOp()).seed(12345L)
                    .dist(new NormalDistribution(0, 2)).list()
                    .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2)
@@ -201,6 +203,7 @@ public void testGradientBNWithCNNandSubsamplingcCnfigurableProfiler() {
            Activation outputActivation = outputActivations[i];

            MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(12345)
+                    .dataType(DataType.DOUBLE)
                    .l2(l2vals[j])
                    .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
                    .updater(new NoOp())
@@ -310,6 +313,7 @@ public void testGradientBNWithCNNandSubsampling() {
            Activation outputActivation = outputActivations[i];

            MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(12345)
+                    .dataType(DataType.DOUBLE)
                    .l2(l2vals[j])
                    .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
                    .updater(new NoOp())
@@ -419,6 +423,7 @@ public void testGradientDense() {

            MultiLayerConfiguration.Builder builder =
                    new NeuralNetConfiguration.Builder()
+                            .dataType(DataType.DOUBLE)
                            .l2(l2vals[j])
                            .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT)
                            .updater(new NoOp())
@@ -495,6 +500,7 @@ public void testGradient2dFixedGammaBeta() {

        for(boolean useLogStd : new boolean[]{true, false}) {
            MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp())
+                    .dataType(DataType.DOUBLE)
                    .seed(12345L)
                    .dist(new NormalDistribution(0, 1)).list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.IDENTITY).build())
@@ -540,6 +546,7 @@ public void testGradientCnnFixedGammaBeta() {

        for(boolean useLogStd : new boolean[]{true, false}) {
            MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp())
+                    .dataType(DataType.DOUBLE)
                    .seed(12345L)
                    .dist(new NormalDistribution(0, 2)).list()
                    .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2)
@@ -584,6 +591,7 @@ public void testBatchNormCompGraphSimple() {
        for(boolean useLogStd : new boolean[]{true, false}) {

            ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).updater(new NoOp())
+                    .dataType(DataType.DOUBLE)
                    .weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in")
                    .setInputTypes(InputType.convolutional(height, width, channels))
                    .addLayer("bn", new BatchNormalization.Builder().useLogStd(useLogStd).build(), "in")
@@ -655,6 +663,7 @@ public void testGradientBNWithCNNandSubsamplingCompGraph() {
            Activation outputActivation = outputActivations[i];

            ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
+                    .dataType(DataType.DOUBLE)
                    .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
                    .updater(new NoOp())
                    .dist(new UniformDistribution(-2, 2)).seed(12345L).graphBuilder()