-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1cce7b2
commit 8b3e84f
Showing
5 changed files
with
321 additions
and
5 deletions.
There are no files selected for viewing
76 changes: 76 additions & 0 deletions
76
deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/SameDiffTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
package org.deeplearning4j.samediff; | ||
|
||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | ||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | ||
import org.deeplearning4j.nn.conf.layers.DenseLayer; | ||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | ||
import org.deeplearning4j.nn.params.DefaultParamInitializer; | ||
import org.deeplearning4j.samediff.testlayers.SameDiffDense; | ||
import org.junit.Test; | ||
import org.nd4j.linalg.activations.Activation; | ||
import org.nd4j.linalg.api.ndarray.INDArray; | ||
|
||
import java.util.Map; | ||
|
||
import static org.junit.Assert.assertArrayEquals; | ||
import static org.junit.Assert.assertEquals; | ||
import static org.junit.Assert.assertNotNull; | ||
|
||
public class SameDiffTest { | ||
|
||
@Test | ||
public void testSameDiffDenseBasic(){ | ||
|
||
int nIn = 3; | ||
int nOut = 4; | ||
|
||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | ||
.list() | ||
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut).build()) | ||
.build(); | ||
|
||
MultiLayerNetwork net = new MultiLayerNetwork(conf); | ||
net.init(); | ||
|
||
Map<String,INDArray> pt1 = net.getLayer(0).paramTable(); | ||
assertNotNull(pt1); | ||
assertEquals(2, pt1.size()); | ||
assertNotNull(pt1.get(DefaultParamInitializer.WEIGHT_KEY)); | ||
assertNotNull(pt1.get(DefaultParamInitializer.BIAS_KEY)); | ||
|
||
assertArrayEquals(new int[]{nIn, nOut}, pt1.get(DefaultParamInitializer.WEIGHT_KEY).shape()); | ||
assertArrayEquals(new int[]{1, nOut}, pt1.get(DefaultParamInitializer.BIAS_KEY).shape()); | ||
} | ||
|
||
@Test | ||
public void testSameDiffDenseForward(){ | ||
|
||
int nIn = 3; | ||
int nOut = 4; | ||
|
||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | ||
.list() | ||
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut).build()) | ||
.build(); | ||
|
||
MultiLayerNetwork net = new MultiLayerNetwork(conf); | ||
net.init(); | ||
|
||
Map<String,INDArray> pt1 = net.paramTable(); | ||
assertNotNull(pt1); | ||
|
||
System.out.println(pt1); | ||
|
||
// MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() | ||
// .list() | ||
// .layer(new DenseLayer.Builder().activation(Activation.SIGMOID).nIn(nIn).nOut(nOut).build()) | ||
// .build(); | ||
// | ||
// MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); | ||
// net2.init(); | ||
|
||
|
||
|
||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
98 changes: 98 additions & 0 deletions
98
deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/testlayers/SameDiffDense.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
package org.deeplearning4j.samediff.testlayers; | ||
|
||
import org.deeplearning4j.nn.conf.InputPreProcessor; | ||
import org.deeplearning4j.nn.conf.inputs.InputType; | ||
import org.deeplearning4j.nn.conf.layers.Layer; | ||
import org.deeplearning4j.nn.conf.layers.samediff.BaseSameDiffLayer; | ||
import org.deeplearning4j.nn.params.DefaultParamInitializer; | ||
import org.nd4j.autodiff.samediff.SDVariable; | ||
import org.nd4j.autodiff.samediff.SameDiff; | ||
import org.nd4j.linalg.activations.IActivation; | ||
|
||
import java.util.*; | ||
|
||
public class SameDiffDense extends BaseSameDiffLayer { | ||
|
||
private static final List<String> W_KEYS = Collections.singletonList(DefaultParamInitializer.WEIGHT_KEY); | ||
private static final List<String> B_KEYS = Collections.singletonList(DefaultParamInitializer.BIAS_KEY); | ||
private static final List<String> PARAM_KEYS = Arrays.asList(DefaultParamInitializer.WEIGHT_KEY, DefaultParamInitializer.BIAS_KEY); | ||
|
||
private final Map<String,int[]> paramShapes; | ||
|
||
private int nIn; | ||
private int nOut; | ||
|
||
protected SameDiffDense(Builder builder) { | ||
super(builder); | ||
|
||
nIn = builder.nIn; | ||
nOut = builder.nOut; | ||
|
||
paramShapes = new HashMap<>(); | ||
paramShapes.put(DefaultParamInitializer.WEIGHT_KEY, new int[]{nIn, nOut}); | ||
paramShapes.put(DefaultParamInitializer.BIAS_KEY, new int[]{1, nOut}); | ||
} | ||
|
||
@Override | ||
public InputType getOutputType(int layerIndex, InputType inputType) { | ||
return null; | ||
} | ||
|
||
@Override | ||
public void setNIn(InputType inputType, boolean override) { | ||
if(override){ | ||
this.nIn = ((InputType.InputTypeFeedForward)inputType).getSize(); | ||
} | ||
} | ||
|
||
@Override | ||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) { | ||
return null; | ||
} | ||
|
||
@Override | ||
public List<String> weightKeys() { | ||
return W_KEYS; | ||
} | ||
|
||
@Override | ||
public List<String> biasKeys() { | ||
return B_KEYS; | ||
} | ||
|
||
@Override | ||
public Map<String, int[]> paramShapes() { | ||
return paramShapes; | ||
} | ||
|
||
@Override | ||
public void defineLayer(SameDiff sd, SDVariable layerInput, Map<String, SDVariable> paramTable) { | ||
SDVariable weights = paramTable.get(DefaultParamInitializer.WEIGHT_KEY); | ||
SDVariable bias = paramTable.get(DefaultParamInitializer.BIAS_KEY); | ||
|
||
SDVariable mmul = sd.mmul("mmul", layerInput, weights); | ||
SDVariable z = mmul.add("z", bias); | ||
SDVariable out = sd.sigmoid("out", z); | ||
} | ||
|
||
public static class Builder extends BaseSameDiffLayer.Builder<Builder> { | ||
|
||
private int nIn; | ||
private int nOut; | ||
|
||
public Builder nIn(int nIn){ | ||
this.nIn = nIn; | ||
return this; | ||
} | ||
|
||
public Builder nOut(int nOut){ | ||
this.nOut = nOut; | ||
return this; | ||
} | ||
|
||
@Override | ||
public SameDiffDense build() { | ||
return new SameDiffDense(this); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.