#6520 Fix setLearningRate(double) for the no updater state (SGD, etc) case
AlexDBlack committed Oct 8, 2018
1 parent 8ac3647 commit 2d7dade
Showing 7 changed files with 46 additions and 8 deletions.
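
In short: calling setLearningRate(double) on a net whose updaters hold no state (plain SGD, NoOp) previously failed, because changing the learning rate rebuilds the updater and then restores the old updater state, which is null for such nets (see the guard added to BaseMultiLayerUpdater.setStateViewArray below). A minimal sketch of the call pattern this commit fixes, assuming the DL4J API of this era, with imports as in the test file below:

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .updater(new Sgd(0.1))   //Stateless updater: the net's updater state view array is null
            .list()
            .layer(new DenseLayer.Builder().nIn(10).nOut(10).build())
            .layer(new OutputLayer.Builder().nIn(10).nOut(10)
                    .lossFunction(LossFunctions.LossFunction.MSE).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    net.setLearningRate(0.5);   //Before this commit: failed while restoring the (null) updater state
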
@@ -33,6 +33,7 @@
 import org.nd4j.linalg.learning.config.Adam;
 import org.nd4j.linalg.learning.config.NoOp;
 import org.nd4j.linalg.learning.config.RmsProp;
+import org.nd4j.linalg.learning.config.Sgd;
 import org.nd4j.linalg.lossfunctions.LossFunctions;
 import org.nd4j.linalg.schedule.ExponentialSchedule;
 import org.nd4j.linalg.schedule.ScheduleType;
@@ -140,6 +141,35 @@ public void testChangeLrMLN(){
         assertEquals(net.getUpdater().getStateViewArray(), net3.getUpdater().getStateViewArray());
     }
 
+    @Test
+    public void testChangeLrSGD() {
+        //Simple test for nets with stateless updaters (SGD, etc) - i.e., no updater state to keep in sync
+        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                .activation(Activation.TANH)
+                .seed(12345)
+                .updater(new Sgd(0.1))
+                .list()
+                .layer(new DenseLayer.Builder().nIn(10).nOut(10).build())
+                .layer(new DenseLayer.Builder().nIn(10).nOut(10).build())
+                .layer(new OutputLayer.Builder().nIn(10).nOut(10).lossFunction(LossFunctions.LossFunction.MSE).build())
+                .build();
+
+        MultiLayerNetwork net = new MultiLayerNetwork(conf);
+        net.init();
+        net.setLearningRate(1.0);      //Global LR change
+        net.setLearningRate(1, 0.5);   //Per-layer LR change
+        assertEquals(1.0, net.getLearningRate(0), 0.0);
+        assertEquals(0.5, net.getLearningRate(1), 0.0);
+
+        //Same checks on the ComputationGraph version:
+        ComputationGraph cg = net.toComputationGraph();
+        cg.setLearningRate(2.0);
+        cg.setLearningRate("1", 2.5);
+        assertEquals(2.0, cg.getLearningRate("0"), 0.0);
+        assertEquals(2.5, cg.getLearningRate("1"), 0.0);
+
+    }
+
 @Test
 public void testChangeLrMLNSchedule(){
     //First: Set LR for a *single* layer and compare vs. equivalent net config
@@ -1140,10 +1140,10 @@ private void copyConfigToLayer(String layerName, Layer layer) {
 
         //Configure updaters:
         if(iUpdater != null && bLayer.getIUpdater() == null){
-            bLayer.setIUpdater(iUpdater);
+            bLayer.setIUpdater(iUpdater.clone()); //Clone the updater to avoid shared instances - in case of setLearningRate calls later
         }
         if(biasUpdater != null && bLayer.getBiasUpdater() == null){
-            bLayer.setBiasUpdater(biasUpdater);
+            bLayer.setBiasUpdater(biasUpdater.clone()); //Clone the updater to avoid shared instances - in case of setLearningRate calls later
         }
 
         if(bLayer.getIUpdater() == null && iUpdater == null && bLayer.initializer().numParams(bLayer) > 0){
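
Why the clone matters: without it, copyConfigToLayer hands every layer a reference to the single IUpdater instance held by the global builder configuration, so a later per-layer setLearningRate call mutates the rate of every layer sharing that instance. An illustrative sketch of the aliasing, assuming the IUpdater interface's setLrAndSchedule and clone methods of this era:

    IUpdater global = new Sgd(0.1);
    IUpdater layer0 = global;             //Pre-fix: both layers alias one instance
    IUpdater layer1 = global;
    layer1.setLrAndSchedule(0.5, null);   //Intended for layer 1 only...
    //...but layer0 now reports 0.5 as well
    IUpdater owned = global.clone();      //Post-fix: each layer owns an independent copy
    owned.setLrAndSchedule(0.9, null);    //Leaves global (and any layer still using it) untouched
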
@@ -306,7 +306,8 @@ public Layer[] getLayers() {
      * Get a given layer by name.
      */
     public Layer getLayer(String name) {
-        return verticesMap.get(name).getLayer(); //TODO checks
+        Preconditions.checkState(verticesMap.containsKey(name), "Layer with name %s does not exist in the network", name);
+        return verticesMap.get(name).getLayer();
     }
 
     /**
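
With the precondition in place, a lookup with an unknown name fails fast with a descriptive IllegalStateException rather than the NullPointerException previously thrown by verticesMap.get(name).getLayer(). For example, with a hypothetical graph and vertex name:

    Layer l = graph.getLayer("dense_3");   //Throws: "Layer with name dense_3 does not exist in the network"
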
@@ -3752,7 +3752,7 @@ public void setLearningRate(int layerNumber, ISchedule newLr){
      * @return Learning rate for the specified layer, or null
      */
     public Double getLearningRate(int layerNumber){
-        return NetworkUtils.getLearningRate(this, layerIndex);
+        return NetworkUtils.getLearningRate(this, layerNumber);
     }
 
     /**
@@ -211,6 +211,13 @@ public BaseMultiLayerUpdater(T network, INDArray updaterState) {
      * @param viewArray The new updater state
      */
     public void setStateViewArray(INDArray viewArray) {
+        if(this.updaterStateViewArray == null){
+            if(viewArray == null){
+                return; //No op - for example, SGD and NoOp updaters - i.e., no stored state
+            } else {
+                throw new IllegalStateException("Attempting to set a non-null updater state view array on an updater that has no state");
+            }
+        }
         if (this.updaterStateViewArray.length() != viewArray.length())
             throw new IllegalStateException("Invalid input: view arrays differ in length. " + "Expected length "
                     + this.updaterStateViewArray.length() + ", got length " + viewArray.length());
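
This guard is the core of the fix: for a net whose updaters are all stateless, getStateViewArray() returns null, and the updater rebuild triggered by setLearningRate restores exactly that null state. A sketch of the round trip that now succeeds; the cast is an assumption, since the one-argument setStateViewArray is defined on BaseMultiLayerUpdater rather than the Updater interface:

    INDArray state = net.getUpdater().getStateViewArray();            //null for an all-SGD net
    ((MultiLayerUpdater) net.getUpdater()).setStateViewArray(state);  //Now a no-op; previously an NPE on updaterStateViewArray.length()
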
@@ -1405,7 +1405,6 @@ public INDArray assign(final INDArray arr) {
                 " with x.shape=%ndShape and y.shape=%ndShape", this, arr );
         Nd4j.getExecutioner().exec(new org.nd4j.linalg.api.ops.impl.transforms.Set(this, arr, this, length()));
         return this;
-
     }
 
     @Override
@@ -959,11 +959,12 @@ public static INDArray gemm(INDArray a,
      * depending on setting of arguments transposeA and transposeB.<br>
      * Note that matrix c MUST be fortran order, have zero offset and have c.data().length == c.length().
      * i.e., the result array must not be a view. An exception will be thrown otherwise.<br>
+     * (Note: some views are allowed, if and only if they are f order and contiguous in the buffer apart from an
+     * offset. Put another way, they must be f order and have strides identical to a non-view/default array of the same shape)<br>
      * Don't use this unless you know about level 3 blas and NDArray storage orders.
      * @param a First matrix
      * @param b Second matrix
-     * @param c result matrix. Used in calculation (assuming beta != 0) and result is stored in this. f order,
-     *          zero offset and length == data.length only
+     * @param c result matrix. Used in calculation (assuming beta != 0) and result is stored in this. f order and not a view (except as noted above)
      * @param transposeA if true: transpose matrix a before mmul
      * @param transposeB if true: transpose matrix b before mmul
      * @return result, i.e., matrix c is returned for convenience
@@ -975,7 +976,7 @@ public static INDArray gemm(INDArray a,
                                 boolean transposeB,
                                 double alpha,
                                 double beta) {
-        Preconditions.checkState(c.length() == 1 || (c.ordering() == 'f' && Shape.hasDefaultStridesForShape(c) && !c.isView()),
+        Preconditions.checkState(c.length() == 1 || (c.ordering() == 'f' && Shape.hasDefaultStridesForShape(c)), //Note: some views are OK here - see javadoc
                 "C (result) array is not F order or is a view. Nd4j.gemm requires the result array to be F order " +
                 "and not a view. C (result) array: [%ndSInfo]", c);
         getBlasWrapper().level3().gemm(a, b, c, transposeA, transposeB, alpha, beta);
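
For reference, a call that satisfies the relaxed precondition: the result array must be f order with default strides, so a c-order or non-contiguous-view result still fails, while a contiguous f-order buffer passes. A minimal sketch, assuming the Nd4j factory methods of this era:

    INDArray a = Nd4j.rand(3, 4);
    INDArray b = Nd4j.rand(4, 5);
    INDArray c = Nd4j.createUninitialized(new int[]{3, 5}, 'f');   //f order, default strides, not a view
    Nd4j.gemm(a, b, c, false, false, 1.0, 0.0);                    //c = 1.0 * (a x b) + 0.0 * c
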
