Commit 8b3e84f: More tests and implementation
AlexDBlack committed Dec 21, 2017
Parent: 1cce7b2

Showing 5 changed files with 321 additions and 5 deletions.
@@ -0,0 +1,76 @@
package org.deeplearning4j.samediff;

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.samediff.testlayers.SameDiffDense;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.Map;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

public class SameDiffTest {

@Test
public void testSameDiffDenseBasic(){

int nIn = 3;
int nOut = 4;

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.list()
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut).build())
.build();

MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

Map<String,INDArray> pt1 = net.getLayer(0).paramTable();
assertNotNull(pt1);
assertEquals(2, pt1.size());
assertNotNull(pt1.get(DefaultParamInitializer.WEIGHT_KEY));
assertNotNull(pt1.get(DefaultParamInitializer.BIAS_KEY));

assertArrayEquals(new int[]{nIn, nOut}, pt1.get(DefaultParamInitializer.WEIGHT_KEY).shape());
assertArrayEquals(new int[]{1, nOut}, pt1.get(DefaultParamInitializer.BIAS_KEY).shape());
}

@Test
public void testSameDiffDenseForward(){

int nIn = 3;
int nOut = 4;

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.list()
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut).build())
.build();

MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

Map<String,INDArray> pt1 = net.paramTable();
assertNotNull(pt1);

System.out.println(pt1);

// MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
// .list()
// .layer(new DenseLayer.Builder().activation(Activation.SIGMOID).nIn(nIn).nOut(nOut).build())
// .build();
//
// MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
// net2.init();



}

}
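
// Not part of this commit: one possible way to finish the comparison that
// testSameDiffDenseForward leaves commented out above. The sketch builds an
// equivalent standard DenseLayer network, copies the parameters across
// (assuming the two layers share the same flattened parameter layout), and
// compares the outputs. Nd4j would need an extra import
// (org.nd4j.linalg.factory.Nd4j); the input "in" and network "net2" are
// illustrative names only.
        MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new DenseLayer.Builder().activation(Activation.SIGMOID).nIn(nIn).nOut(nOut).build())
                .build();

        MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
        net2.init();
        net2.setParams(net.params().dup());     //Copy parameters across; assumes identical flattened layout

        INDArray in = Nd4j.rand(5, nIn);        //Hypothetical test input, minibatch of 5
        INDArray outSameDiff = net.output(in);
        INDArray outDense = net2.output(in);
        assertEquals(outDense, outSameDiff);    //SameDiff dense layer forward pass should match DenseLayer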
@@ -1,15 +1,21 @@
package org.deeplearning4j.samediff;

import org.junit.Test;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static org.junit.Assert.assertEquals;

public class SameDiffTest1 {

@@ -122,12 +128,13 @@ public void test3() {

@Test
public void test4() {
Nd4j.getRandom().setSeed(12345);

SameDiff sd = SameDiff.create();

INDArray iInput = Nd4j.rand(3,4);
INDArray iWeights = Nd4j.rand(4,5);
INDArray iBias = Nd4j.rand(1,5);
INDArray iBias = Nd4j.zeros(1, 5); //Nd4j.rand(1,5);

SDVariable input = sd.var("input", iInput);
SDVariable weights = sd.var("weights", iWeights);
@@ -138,13 +145,83 @@ public void test4() {
SDVariable out = sd.sigmoid("out", z);


INDArray outArr = out.eval();
// INDArray outArr = out.eval();
Pair<Map<SDVariable, DifferentialFunction>, List<DifferentialFunction>> m = sd.exec();

for(Map.Entry<SDVariable, DifferentialFunction> e : m.getFirst().entrySet()){
System.out.println(e.getKey().getVarName());
System.out.println(e.getKey().getArr());
}

System.out.println("------------\nAll variable values");

List<SDVariable> variables = sd.variables();
for(SDVariable s : variables){
System.out.println(s.getVarName());
System.out.println(s.getArr());
}

System.out.println("------------");

INDArray exp = iInput.mmul(iWeights).addiRowVector(iBias);

System.out.println(outArr);
System.out.println(Arrays.toString(outArr.dup().data().asFloat()));
System.out.println("Input:");
System.out.println(iInput);
System.out.println("Weights:");
System.out.println(iWeights);
System.out.println("Bias:");
System.out.println(iBias);

System.out.println("------------");

System.out.println("Expected:");
System.out.println(exp);
System.out.println("Actual:");
// System.out.println(outArr);
// System.out.println(Arrays.toString(outArr.dup().data().asFloat()));
}


@Test
public void test5() {
Nd4j.getRandom().setSeed(12345);

SameDiff sd = SameDiff.create();

INDArray iInput = Nd4j.rand(3,4);
INDArray iWeights = Nd4j.rand(4,5);
INDArray iBias = Nd4j.rand(1,5);

SDVariable input = sd.var("input", iInput);
SDVariable weights = sd.var("weights", iWeights);
SDVariable bias = sd.var("bias", iBias);

SDVariable mmul = sd.mmul("mmul", input, weights);
SDVariable z = mmul.add("z", bias);
SDVariable out = sd.sigmoid("out", z);

System.out.println("------------\nAll variable values");

sd.exec();

List<SDVariable> variables = sd.variables();
for(SDVariable s : variables){
System.out.println(s.getVarName());
System.out.println(s.getArr());
System.out.println("Data buffer: " + Arrays.toString(s.getArr().data().asFloat()));
}

System.out.println("------------");

List<String> varNames = variables.stream().map(SDVariable::getVarName).collect(Collectors.toList());
System.out.println("VarNames: " + varNames); //"z" and "out" appear twice

INDArray expMmul = iInput.mmul(iWeights);
INDArray expZ = expMmul.addRowVector(iBias);
INDArray expOut = Transforms.sigmoid(expZ, true);

assertEquals(expMmul, mmul.getArr());
assertEquals(expZ, z.getArr());
assertEquals(expOut, out.getArr());
}
}
@@ -0,0 +1,98 @@
package org.deeplearning4j.samediff.testlayers;

import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.samediff.BaseSameDiffLayer;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.IActivation;

import java.util.*;

public class SameDiffDense extends BaseSameDiffLayer {

private static final List<String> W_KEYS = Collections.singletonList(DefaultParamInitializer.WEIGHT_KEY);
private static final List<String> B_KEYS = Collections.singletonList(DefaultParamInitializer.BIAS_KEY);
private static final List<String> PARAM_KEYS = Arrays.asList(DefaultParamInitializer.WEIGHT_KEY, DefaultParamInitializer.BIAS_KEY);

private final Map<String,int[]> paramShapes;

private int nIn;
private int nOut;

protected SameDiffDense(Builder builder) {
super(builder);

nIn = builder.nIn;
nOut = builder.nOut;

paramShapes = new HashMap<>();
paramShapes.put(DefaultParamInitializer.WEIGHT_KEY, new int[]{nIn, nOut});
paramShapes.put(DefaultParamInitializer.BIAS_KEY, new int[]{1, nOut});
}

@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
return null;
}

@Override
public void setNIn(InputType inputType, boolean override) {
if(override){
this.nIn = ((InputType.InputTypeFeedForward)inputType).getSize();
}
}

@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
return null;
}

@Override
public List<String> weightKeys() {
return W_KEYS;
}

@Override
public List<String> biasKeys() {
return B_KEYS;
}

@Override
public Map<String, int[]> paramShapes() {
return paramShapes;
}

@Override
public void defineLayer(SameDiff sd, SDVariable layerInput, Map<String, SDVariable> paramTable) {
SDVariable weights = paramTable.get(DefaultParamInitializer.WEIGHT_KEY);
SDVariable bias = paramTable.get(DefaultParamInitializer.BIAS_KEY);

SDVariable mmul = sd.mmul("mmul", layerInput, weights);
SDVariable z = mmul.add("z", bias);
SDVariable out = sd.sigmoid("out", z);
}

public static class Builder extends BaseSameDiffLayer.Builder<Builder> {

private int nIn;
private int nOut;

public Builder nIn(int nIn){
this.nIn = nIn;
return this;
}

public Builder nOut(int nOut){
this.nOut = nOut;
return this;
}

@Override
public SameDiffDense build() {
return new SameDiffDense(this);
}
}
}
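
// For reference (not part of the commit): the graph built in defineLayer is a
// single fully connected layer with a sigmoid activation, out = sigmoid(input * W + b).
// A minimal standalone sketch of the same computation in plain ND4J, using only
// calls that already appear in the tests above (Nd4j.rand, mmul, addRowVector,
// Transforms.sigmoid); the sizes are illustrative.
        int nIn = 3;
        int nOut = 4;
        INDArray input = Nd4j.rand(3, nIn);             //Minibatch of 3 examples
        INDArray w = Nd4j.rand(nIn, nOut);              //Weights, shape [nIn, nOut]
        INDArray b = Nd4j.rand(1, nOut);                //Bias row vector, shape [1, nOut]

        INDArray z = input.mmul(w).addRowVector(b);     //Affine transform: input * W + b
        INDArray out = Transforms.sigmoid(z, true);     //Elementwise sigmoid (copy, not in-place)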
@@ -24,6 +24,10 @@ public class SameDiffLayer extends AbstractLayer<BaseSameDiffLayer> {
protected SameDiff sameDiff;
protected String outputKey;

protected INDArray params;
protected INDArray gradients;
protected Map<String,INDArray> paramTable;


public SameDiffLayer(NeuralNetConfiguration conf){
super(conf);
@@ -115,6 +119,65 @@ public double calcL1(boolean backpropParamsOnly) {
return l1Sum;
}

/**Returns the parameters of the neural network as a flattened row vector
* @return the parameters of the neural network
*/
@Override
public INDArray params() {
return params;
}

@Override
public INDArray getParam(String param) {
throw new UnsupportedOperationException("Not supported");
}

@Override
public void setParam(String key, INDArray val) {
throw new UnsupportedOperationException("Not supported");
}

@Override
public void setParams(INDArray params) {
if (params != null) {
throw new UnsupportedOperationException("Not supported");
}
}

protected void setParams(INDArray params, char order) {
throw new UnsupportedOperationException("Not supported");
}

@Override
public void setParamsViewArray(INDArray params) {
this.params = params;
}

@Override
public INDArray getGradientsViewArray() {
return gradients;
}

@Override
public void setBackpropGradientsViewArray(INDArray gradients) {
this.gradients = gradients;
}

@Override
public void setParamTable(Map<String, INDArray> paramTable) {
this.paramTable = paramTable;
}

@Override
public Map<String, INDArray> paramTable() {
return paramTable(false);
}

@Override
public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
return paramTable;
}

protected void doInit(){
sameDiff = SameDiff.create();
Map<String,INDArray > p = paramTable();