More tests/implementation
AlexDBlack committed Dec 21, 2017
1 parent 8b3e84f commit 707b803
Showing 5 changed files with 179 additions and 26 deletions.
10 changes: 10 additions & 0 deletions deeplearning4j-core/pom.xml
@@ -36,6 +36,16 @@
</plugin>
</plugins>
</pluginManagement>
+ <plugins>
+     <plugin>
+         <groupId>org.apache.maven.plugins</groupId>
+         <artifactId>maven-compiler-plugin</artifactId>
+         <configuration>
+             <source>1.8</source>
+             <target>1.8</target>
+         </configuration>
+     </plugin>
+ </plugins>
</build>
<dependencyManagement>
<dependencies>
org/deeplearning4j/samediff/SameDiffTest.java
@@ -1,25 +1,31 @@
package org.deeplearning4j.samediff;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.samediff.testlayers.SameDiffDense;
import org.junit.Test;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.HashMap;
import java.util.Map;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

@Slf4j
public class SameDiffTest {

@Test
- public void testSameDiffDenseBasic(){
+ public void testSameDiffDenseBasic() {

int nIn = 3;
int nOut = 4;
@@ -32,7 +38,7 @@ public void testSameDiffDenseBasic(){
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

- Map<String,INDArray> pt1 = net.getLayer(0).paramTable();
+ Map<String, INDArray> pt1 = net.getLayer(0).paramTable();
assertNotNull(pt1);
assertEquals(2, pt1.size());
assertNotNull(pt1.get(DefaultParamInitializer.WEIGHT_KEY));
@@ -43,8 +49,9 @@ public void testSameDiffDenseBasic(){
}

@Test
- public void testSameDiffDenseForward(){
+ public void testSameDiffDenseForward() {

int minibatch = 5;
int nIn = 3;
int nOut = 4;

@@ -56,21 +63,143 @@ public void testSameDiffDenseForward(){
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

- Map<String,INDArray> pt1 = net.paramTable();
+ Map<String, INDArray> pt1 = net.paramTable();
assertNotNull(pt1);

System.out.println(pt1);

- // MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
- //         .list()
- //         .layer(new DenseLayer.Builder().activation(Activation.SIGMOID).nIn(nIn).nOut(nOut).build())
- //         .build();
- //
- // MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
- // net2.init();
+ MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
+         .list()
+         .layer(new DenseLayer.Builder().activation(Activation.SIGMOID).nIn(nIn).nOut(nOut).build())
+         .build();
+
+ MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
+ net2.init();

net.params().assign(net2.params());

INDArray in = Nd4j.rand(minibatch, nIn);
INDArray out = net.output(in);
INDArray outExp = net2.output(in);

assertEquals(outExp, out);
}

@Test
public void testShapeResolutionMinus1() {

int nIn = 3;
int nOut = 4;

int minibatch = 3;

// for(boolean useMinus1 : new boolean[]{false, true}) {
for (boolean useMinus1 : new boolean[]{true}) {
log.info("Starting: {}", (useMinus1 ? "minibatch -1" : "minibatch 3"));

int[] inShape;
if (useMinus1) {
inShape = new int[]{-1, nIn};
} else {
inShape = new int[]{minibatch, nIn};
}
int[] wShape = new int[]{nIn, nOut};
int[] bShape = new int[]{1, nOut};

SameDiff sd = SameDiff.create();
SDVariable layerInput = sd.var("in", inShape);
SDVariable weights = sd.var("W", wShape);
SDVariable bias = sd.var("b", bShape);

SDVariable mmul = sd.mmul("mmul", layerInput, weights);
SDVariable z = mmul.add("z", bias);
SDVariable out = sd.sigmoid("out", z);

INDArray in = Nd4j.rand(new int[]{minibatch, nIn});
INDArray w = Nd4j.rand(wShape);
INDArray b = Nd4j.rand(bShape);

Map<String, INDArray> m = new HashMap<>();
m.put("in", in);
m.put("W", w);
m.put("b", b);

sd.associateArrayWithVariable(in, sd.getVariable("in"));
sd.associateArrayWithVariable(w, sd.getVariable("W"));
sd.associateArrayWithVariable(b, sd.getVariable("b"));

// INDArray outArr = sd.execAndEndResult();

sd.addAsPlaceHolder("in");
sd.addAsPlaceHolder("W");
sd.addAsPlaceHolder("b");

sd.execWithPlaceHolder(m);

INDArray outArr = sd.getVariable("out").getArr();

assertArrayEquals(new int[]{minibatch, nOut}, outArr.shape());
}
}

@Test
public void debug() {

int nIn = 3;
int nOut = 4;

int minibatch = 3;

int[] inShape = new int[]{-1, nIn};
int[] wShape = new int[]{nIn, nOut};
int[] bShape = new int[]{1, nOut};

SameDiff sd = SameDiff.create();
SDVariable layerInput = sd.var("in", inShape);
SDVariable weights = sd.var("W", wShape);
SDVariable bias = sd.var("b", bShape);

assertArrayEquals(inShape, layerInput.getShape());
assertArrayEquals(wShape, weights.getShape());

SDVariable mmul = sd.mmul("mmul", layerInput, weights);
SDVariable z = mmul.add("z", bias);
SDVariable out = sd.sigmoid("out", z);

INDArray in = Nd4j.rand(new int[]{minibatch, nIn});
INDArray w = Nd4j.rand(wShape);
INDArray b = Nd4j.rand(bShape);

Map<String, INDArray> m = new HashMap<>();
m.put("in", in);
m.put("W", w);
m.put("b", b);

sd.associateArrayWithVariable(in, sd.getVariable("in"));
sd.associateArrayWithVariable(w, sd.getVariable("W"));
sd.associateArrayWithVariable(b, sd.getVariable("b"));

// INDArray outArr = sd.execAndEndResult();

sd.addAsPlaceHolder("in");
sd.addAsPlaceHolder("W");
sd.addAsPlaceHolder("b");

sd.execWithPlaceHolder(m);

INDArray outArr = sd.getVariable("out").getArr();

assertArrayEquals(new int[]{minibatch, nOut}, outArr.shape());
}

@Test
public void debug2() {
int[] inShape = new int[]{-1, 3};

SameDiff sd = SameDiff.create();
SDVariable layerInput = sd.var("in", inShape);

int[] actShape = layerInput.getShape(); //Getting: [1,3]
assertArrayEquals(inShape, actShape);
}
}
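The forward-pass test above checks the SameDiff-based dense layer against the built-in DenseLayer by copying parameters across and comparing outputs. As a minimal sketch (not part of this commit), the same single-layer forward pass can also be cross-checked directly with ND4J ops; the variables in, w, b, and out stand for the layer input, weights, bias, and network output, and org.nd4j.linalg.ops.transforms.Transforms is assumed on the classpath:

    // Hypothetical cross-check: a dense layer with sigmoid activation computes
    // sigmoid(in * W + b) for in [minibatch, nIn], w [nIn, nOut], b [1, nOut].
    INDArray manual = Transforms.sigmoid(in.mmul(w).addRowVector(b));
    assertEquals(manual, out);

The debug2 test records the behavior the placeholder tests work around: a variable declared with shape [-1, 3] currently reports [1, 3] from getShape(), which is presumably why testShapeResolutionMinus1 registers the inputs via addAsPlaceHolder and executes through execWithPlaceHolder instead of relying on the declared -1 dimension.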
org/deeplearning4j/samediff/testlayers/SameDiffDense.java
@@ -66,13 +66,15 @@ public Map<String, int[]> paramShapes() {
}

@Override
- public void defineLayer(SameDiff sd, SDVariable layerInput, Map<String, SDVariable> paramTable) {
+ public List<String> defineLayer(SameDiff sd, SDVariable layerInput, Map<String, SDVariable> paramTable) {
SDVariable weights = paramTable.get(DefaultParamInitializer.WEIGHT_KEY);
SDVariable bias = paramTable.get(DefaultParamInitializer.BIAS_KEY);

SDVariable mmul = sd.mmul("mmul", layerInput, weights);
SDVariable z = mmul.add("z", bias);
SDVariable out = sd.sigmoid("out", z);

+ return Collections.singletonList("out");
}

public static class Builder extends BaseSameDiffLayer.Builder<Builder> {
BaseSameDiffLayer.java
@@ -56,7 +56,7 @@ protected BaseSameDiffLayer(Builder builder){

public abstract Map<String,int[]> paramShapes();

- public abstract void defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String,SDVariable> paramTable);
+ public abstract List<String> defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String,SDVariable> paramTable);

//==================================================================================================================

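defineLayer now returns the names of the graph's output variables instead of void. A hypothetical subclass (not from this commit; the "W" param key, tanh activation, and variable names are illustrative only) would follow the same pattern as SameDiffDense above:

    // Sketch of a custom layer against the new contract: build the graph, then
    // return the output variable names (exactly one, per the doInit validation).
    // Requires java.util.Collections and java.util.List.
    @Override
    public List<String> defineLayer(SameDiff sd, SDVariable layerInput, Map<String, SDVariable> paramTable) {
        SDVariable w = paramTable.get("W");     // assumed param key
        SDVariable z = sd.mmul("mmul", layerInput, w);
        SDVariable out = sd.tanh("out", z);     // any activation op works here
        return Collections.singletonList("out");
    }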
SameDiffLayer.java
@@ -15,14 +15,15 @@

import java.util.HashMap;
import java.util.LinkedHashMap;
+ import java.util.List;
import java.util.Map;

public class SameDiffLayer extends AbstractLayer<BaseSameDiffLayer> {

private static final String INPUT_KEY = "input";

protected SameDiff sameDiff;
- protected String outputKey;
+ protected List<String> outputKeys;

protected INDArray params;
protected INDArray gradients;
@@ -56,13 +57,15 @@ public INDArray activate(boolean training) {
doInit();
}

- SameDiff sd = sameDiff.getFunction(outputKey);
//Build map:
- Map<String, INDArray> map = new HashMap<>(paramTable());
- map.put(INPUT_KEY, input);
+ // Map<String, INDArray> map = new HashMap<>(paramTable());
+ // map.put(INPUT_KEY, input);

+ sameDiff.associateArrayWithVariable(input, sameDiff.getVariable(INPUT_KEY));

try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
- return sd.execAndEndResult();
+ INDArray result = sameDiff.execAndEndResult();
+ return result;
}
}

@@ -76,16 +79,15 @@ public INDArray preOutput(boolean training) {
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
Gradient g = new DefaultGradient();

- SameDiff sd = sameDiff.getFunction(outputKey);
INDArray dLdIn;
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
- sd.execBackwards();
+ sameDiff.execBackwards();
for(String s : layerConf().paramKeys() ){
- INDArray pg = sd.grad(s).getArr();
+ INDArray pg = sameDiff.grad(s).getArr();
g.gradientForVariable().put(s, pg);
}

- dLdIn = sd.grad(INPUT_KEY).getArr();
+ dLdIn = sameDiff.grad(INPUT_KEY).getArr();
}

return new Pair<>(g, dLdIn);
@@ -183,14 +185,24 @@ protected void doInit(){
Map<String,INDArray > p = paramTable();

int[] inputShape = input.shape().clone();
- inputShape[0] = -1;
- SDVariable inputVar = sameDiff.var(INPUT_KEY, inputShape); //TODO WHAT ABOUT VARIABLE SIZES?
+ // inputShape[0] = -1; //TODO THIS DOESN'T ENABLE VARIABLE SIZE MINIBATCHES
+ SDVariable inputVar = sameDiff.var(INPUT_KEY, inputShape);
Map<String,int[]> paramShapes = layerConf().paramShapes();
Map<String,SDVariable> params = new LinkedHashMap<>();
for(String s : layerConf().paramKeys()){
int[] ps = paramShapes.get(s);
- params.put(s, sameDiff.var(s, ps));
+ SDVariable v = sameDiff.var(s, ps);
+ params.put(s, v);
}
- layerConf().defineLayer(sameDiff, inputVar, params);
+ List<String> outputKeys = layerConf().defineLayer(sameDiff, inputVar, params);
+ if(outputKeys == null || outputKeys.size() != 1){
+     throw new IllegalStateException("Invalid output keys: " + outputKeys);
+ }

for(Map.Entry<String,INDArray> e : p.entrySet()){
sameDiff.associateArrayWithVariable(e.getValue(), sameDiff.getVariable(e.getKey()));
}

+ this.outputKeys = outputKeys;
}
}
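Note that outputKeys is validated and stored in doInit() but activate() still returns sameDiff.execAndEndResult(), i.e. the last op's output. A sketch of how the stored key could be consulted instead, assuming exec() populates variable arrays the same way execWithPlaceHolder does in the tests above:

    // Hypothetical alternative for activate() (not in this commit): execute the
    // graph, then fetch the named output rather than the final op's result.
    sameDiff.exec();
    INDArray activations = sameDiff.getVariable(outputKeys.get(0)).getArr();
    return activations;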
