Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CuDNN improvements #4864

Merged
merged 8 commits into from Mar 29, 2018

Large diffs are not rendered by default.

Expand Up @@ -132,7 +132,6 @@ public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilo
//Epsilons in shape: [miniBatch, depth, outH, outW]
//Epsilons out shape: [miniBatch, depth, inH, inW]


int poolingMode;
switch (poolingType) {
case AVG:
Expand All @@ -145,9 +144,9 @@ public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilo
return null;
}

if (!Shape.strideDescendingCAscendingF(epsilon)) {
if (!Shape.strideDescendingCAscendingF(epsilon) || epsilon.isView()) {
// apparently not supported by cuDNN
epsilon = epsilon.dup();
epsilon = epsilon.dup('c');
}

int[] srcStride = input.stride();
Expand Down
@@ -0,0 +1,25 @@
package org.deeplearning4j;

import org.junit.After;
import org.junit.Before;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Base class for tests: configures the ND4J executioner's profiling mode before each test
 * and destroys the current thread's workspaces after each test.
 */
public class BaseDL4JTest {

    /**
     * Supplies the profiling mode that {@link #beforeTest()} installs on the executioner.
     * This implementation returns {@code SCOPE_PANIC}; the method is public and non-final,
     * so a subclass can override it to request a different mode.
     *
     * @return the profiling mode to apply before each test
     */
    public OpExecutioner.ProfilingMode getProfilingMode() {
        return OpExecutioner.ProfilingMode.SCOPE_PANIC;
    }

    /**
     * Runs before every test: applies {@link #getProfilingMode()} to the ND4J executioner.
     */
    @Before
    public void beforeTest() {
        OpExecutioner executioner = Nd4j.getExecutioner();
        executioner.setProfilingMode(getProfilingMode());
    }

    /**
     * Runs after every test: tears down all workspaces belonging to the current thread.
     */
    @After
    public void afterTest() {
        // Best-effort isolation: workspaces left over from one test should not be visible to the next
        Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
    }

}
102 changes: 102 additions & 0 deletions deeplearning4j-cuda/src/test/java/org/deeplearning4j/TestUtils.java
@@ -0,0 +1,102 @@
package org.deeplearning4j;

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Random;

import static org.junit.Assert.assertEquals;

/**
 * Static helpers for tests: model serialization round-trip checks and generation of
 * random one-hot / Bernoulli test arrays.
 */
public class TestUtils {

    /**
     * Seed used by the convenience overloads that take neither a seed nor a {@link Random},
     * so that the data they generate is identical from run to run.
     */
    private static final long DEFAULT_RNG_SEED = 12345L;

    /**
     * Serializes the network to an in-memory byte array, restores it, and asserts that the
     * restored network has an equal layer-wise configuration and equal parameters.
     *
     * @param net network to round-trip
     * @return the network restored from the serialized bytes
     * @throws RuntimeException wrapping any {@link IOException} raised while writing or restoring
     */
    public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){

        try {
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            // NOTE(review): the 'true' flags presumably mean "save/load updater state" - confirm against ModelSerializer
            ModelSerializer.writeModel(net, baos, true);
            byte[] bytes = baos.toByteArray();

            ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
            MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);

            assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
            assertEquals(net.params(), restored.params());

            return restored;
        } catch (IOException e){
            //Should never happen: all I/O is against in-memory streams
            throw new RuntimeException(e);
        }
    }

    /**
     * Serializes the graph to an in-memory byte array, restores it, and asserts that the
     * restored graph has an equal configuration and equal parameters.
     *
     * @param net computation graph to round-trip
     * @return the graph restored from the serialized bytes
     * @throws RuntimeException wrapping any {@link IOException} raised while writing or restoring
     */
    public static ComputationGraph testModelSerialization(ComputationGraph net){

        try {
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            // NOTE(review): the 'true' flags presumably mean "save/load updater state" - confirm against ModelSerializer
            ModelSerializer.writeModel(net, baos, true);
            byte[] bytes = baos.toByteArray();

            ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
            ComputationGraph restored = ModelSerializer.restoreComputationGraph(bais, true);

            assertEquals(net.getConfiguration(), restored.getConfiguration());
            assertEquals(net.params(), restored.params());

            return restored;
        } catch (IOException e){
            //Should never happen: all I/O is against in-memory streams
            throw new RuntimeException(e);
        }
    }

    /**
     * Creates a one-hot array using a fixed default seed, so repeated calls with the same
     * arguments produce the same array.
     *
     * @param examples number of rows
     * @param nOut     number of columns; one randomly chosen column per row is set to 1.0
     * @return array created with shape arguments (examples, nOut)
     */
    public static INDArray randomOneHot(int examples, int nOut){
        return randomOneHot(examples, nOut, new Random(DEFAULT_RNG_SEED));
    }

    /**
     * Creates a one-hot array using a {@link Random} seeded with {@code rngSeed}.
     *
     * @param examples number of rows
     * @param nOut     number of columns; one randomly chosen column per row is set to 1.0
     * @param rngSeed  seed for the random column selection
     * @return array created with shape arguments (examples, nOut)
     */
    public static INDArray randomOneHot(int examples, int nOut, long rngSeed){
        return randomOneHot(examples, nOut, new Random(rngSeed));
    }

    /**
     * Creates a one-hot array: for each row, one column index drawn from {@code rng} is set to 1.0.
     *
     * @param examples number of rows
     * @param nOut     number of columns (passed to {@link Random#nextInt(int)}, so must be positive)
     * @param rng      source of randomness for the column selection
     * @return array created with shape arguments (examples, nOut)
     */
    public static INDArray randomOneHot(int examples, int nOut, Random rng){
        INDArray arr = Nd4j.create(examples, nOut);
        for( int i=0; i<examples; i++ ){
            arr.putScalar(i, rng.nextInt(nOut), 1.0);
        }
        return arr;
    }

    /**
     * Creates a one-hot time series array using a fixed default seed, so repeated calls with
     * the same arguments produce the same array (consistent with {@link #randomOneHot(int, int)}).
     *
     * @param minibatch size of the first dimension
     * @param outSize   size of the second dimension; one index along it is set per (example, time step)
     * @param tsLength  size of the third dimension
     * @return array of shape [minibatch, outSize, tsLength]
     */
    public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength){
        // Seeded rather than new Random(): unseeded data makes test failures impossible to reproduce
        return randomOneHotTimeSeries(minibatch, outSize, tsLength, new Random(DEFAULT_RNG_SEED));
    }

    /**
     * Creates a one-hot time series array using a {@link Random} seeded with {@code rngSeed}.
     *
     * @param minibatch size of the first dimension
     * @param outSize   size of the second dimension; one index along it is set per (example, time step)
     * @param tsLength  size of the third dimension
     * @param rngSeed   seed for the random index selection
     * @return array of shape [minibatch, outSize, tsLength]
     */
    public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, long rngSeed){
        return randomOneHotTimeSeries(minibatch, outSize, tsLength, new Random(rngSeed));
    }

    /**
     * Creates a one-hot time series array in 'f' order: for every (example, time step) pair,
     * one index along the second dimension drawn from {@code rng} is set to 1.0.
     *
     * @param minibatch size of the first dimension
     * @param outSize   size of the second dimension (passed to {@link Random#nextInt(int)}, so must be positive)
     * @param tsLength  size of the third dimension
     * @param rng       source of randomness for the index selection
     * @return array of shape [minibatch, outSize, tsLength]
     */
    public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng){
        INDArray out = Nd4j.create(new int[]{minibatch, outSize, tsLength}, 'f');
        for( int i=0; i<minibatch; i++ ){
            for( int j=0; j<tsLength; j++ ){
                out.putScalar(i, rng.nextInt(outSize), j, 1.0);
            }
        }
        return out;
    }

    /**
     * Creates an array of the given shape by executing a Bernoulli distribution op with p = 0.5.
     *
     * @param shape shape of the array to create
     * @return the array the op was executed on
     */
    public static INDArray randomBernoulli(int... shape) {
        return randomBernoulli(0.5, shape);
    }

    /**
     * Creates an uninitialized array of the given shape and executes a
     * {@code BernoulliDistribution} op on it with parameter {@code p}.
     *
     * @param p     parameter passed to the Bernoulli op (presumably the probability of a 1 - confirm against the op)
     * @param shape shape of the array to create
     * @return the array the op was executed on
     */
    public static INDArray randomBernoulli(double p, int... shape){
        INDArray ret = Nd4j.createUninitialized(shape);
        Nd4j.getExecutioner().exec(new BernoulliDistribution(ret, p));
        return ret;
    }
}