Skip to content

Commit

Permalink
ND4J indexing fixes + DL4J fix (#6353)
Browse files Browse the repository at this point in the history
* #6327 INDArray.put with SpecifiedIndex

* #6341 - SpecifiedIndex with single value no longer collapses dimensions

* Another indexing fix

* #6343 TransferLearning nOutReplace fix
  • Loading branch information
AlexDBlack committed Sep 4, 2018
1 parent d5c7b3f commit 2c32b5c
Show file tree
Hide file tree
Showing 9 changed files with 245 additions and 37 deletions.
Expand Up @@ -22,15 +22,18 @@
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.constraint.UnitNormConstraint;
import org.deeplearning4j.nn.conf.distribution.ConstantDistribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.weightnoise.DropConnect;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
Expand Down Expand Up @@ -479,4 +482,50 @@ public void testObjectOverrides(){
assertNull(l.getConstraints());
assertEquals(0.0, l.getL2(), 0.0);
}


@Test
public void testTransferLearningSubsequent() {
String inputName = "in";
String outputName = "out";

final String firstConv = "firstConv";
final String secondConv = "secondConv";
final INDArray input = Nd4j.create(6,6,6,6);
final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder()
.weightInit(new ConstantDistribution(666))
.graphBuilder()
.addInputs(inputName)
.setOutputs(outputName)
.setInputTypes(InputType.inferInputTypes(input))
.addLayer(firstConv, new Convolution2D.Builder(3, 3)
.nOut(10)
.build(), inputName)
.addLayer(secondConv, new Convolution2D.Builder(1, 1)
.nOut(3)
.build(), firstConv)
.addLayer(outputName, new OutputLayer.Builder()
.nOut(2)
.lossFunction(LossFunctions.LossFunction.MSE)
.build(), secondConv)
.build());
graph.init();

final ComputationGraph newGraph = new TransferLearning
.GraphBuilder(graph)
.nOutReplace(firstConv, 7, new ConstantDistribution(333))
.nOutReplace(secondConv, 3, new ConstantDistribution(111))
.removeVertexAndConnections(outputName)
.addLayer(outputName, new OutputLayer.Builder()
.nIn(48).nOut(2)
.lossFunction(LossFunctions.LossFunction.MSE)
.build(), new CnnToFeedForwardPreProcessor(4,4,3), secondConv)
.setOutputs(outputName)
.build();
newGraph.init();

assertEquals("Incorrect # inputs", 7, newGraph.layerInputSize(secondConv));

newGraph.outputSingle(input);
}
}
Expand Up @@ -24,6 +24,7 @@
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.constraint.UnitNormConstraint;
import org.deeplearning4j.nn.conf.distribution.ConstantDistribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
Expand Down Expand Up @@ -631,4 +632,36 @@ public void testObjectOverrides(){
assertEquals(0.0, l.getL2(), 0.0);
}


@Test
public void testTransferLearningSubsequent() {
final INDArray input = Nd4j.create(6,6,6,6);
final MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
.weightInit(new ConstantDistribution(666))
.list()
.setInputType(InputType.inferInputTypes(input)[0])
.layer(new Convolution2D.Builder(3, 3).nOut(10).build())
.layer(new Convolution2D.Builder(1, 1).nOut(3).build())
.layer(new OutputLayer.Builder().nOut(2).lossFunction(LossFunctions.LossFunction.MSE)
.build()).build());
net.init();

MultiLayerNetwork newGraph = new TransferLearning
.Builder(net)
.fineTuneConfiguration(new FineTuneConfiguration.Builder().build())
.nOutReplace(0, 7, new ConstantDistribution(333))
.nOutReplace(1, 3, new ConstantDistribution(111))
.removeLayersFromOutput(1)
.addLayer(new OutputLayer.Builder()
.nIn(48).nOut(2)
.lossFunction(LossFunctions.LossFunction.MSE)
.build())
.setInputPreProcessor(2, new CnnToFeedForwardPreProcessor(4,4,3))
.build();
newGraph.init();

assertEquals("Incorrect # inputs", 7, newGraph.layerInputSize(1));

newGraph.output(input);
}
}
Expand Up @@ -476,6 +476,8 @@ public static class GraphBuilder {
private WorkspaceMode workspaceMode;
private Boolean validateOutputLayerConfig = null;

private Map<String,Integer> nInFromNewConfig = new HashMap<>();

/**
* Computation Graph to tweak for transfer learning
* @param origGraph
Expand Down Expand Up @@ -601,6 +603,14 @@ private GraphBuilder nOutReplace(String layerName, int nOut, WeightInit scheme,
layerImplF.setDist(dist);
layerImplF.setNOut(nOut);

if(editedVertices.contains(layerName) && editedConfigBuilder.getVertices().get(layerName) instanceof LayerVertex
&& nInFromNewConfig.containsKey(layerName)){
Layer l = ((LayerVertex)editedConfigBuilder.getVertices().get(layerName)).getLayerConf().getLayer();
if(l instanceof FeedForwardLayer){
layerImplF.setNIn(nInFromNewConfig.get(layerName));
}
}

editedConfigBuilder.removeVertex(layerName, false);
LayerVertex lv = (LayerVertex) origConfig.getVertices().get(layerName);
String[] lvInputs = origConfig.getVertexInputs().get(layerName).toArray(new String[0]);
Expand Down Expand Up @@ -631,6 +641,8 @@ private GraphBuilder nOutReplace(String layerName, int nOut, WeightInit scheme,
layerImplF.setDist(distNext);
layerImplF.setNIn(nOut);

nInFromNewConfig.put(fanoutVertexName, nOut);

editedConfigBuilder.removeVertex(fanoutVertexName, false);
lv = (LayerVertex) origConfig.getVertices().get(fanoutVertexName);
lvInputs = origConfig.getVertexInputs().get(fanoutVertexName).toArray(new String[0]);
Expand Down
Expand Up @@ -2463,19 +2463,59 @@ else if(indices.isRowVector()) {
@Override
public INDArray put(INDArrayIndex[] indices, INDArray element) {
Nd4j.getCompressor().autoDecompress(this);
if (indices[0] instanceof SpecifiedIndex && element.isVector()) {
indices[0].reset();
int cnt = 0;
while (indices[0].hasNext()) {
long idx = indices[0].next();
// FIXME: LONG
putScalar((int) idx, element.getDouble(cnt));
cnt++;
boolean isSpecifiedIndex = false;
for(INDArrayIndex idx : indices){
if(idx instanceof SpecifiedIndex){
isSpecifiedIndex = true;
break;
}
return this;
} else {
}

if(!isSpecifiedIndex){
return get(indices).assign(element);
} else {
//Can't get a view, so we'll do it in subsets instead
// This is inefficient, but it is correct...
int numSpecified = 0;
List<long[]> specifiedIdxs = new ArrayList<>();
List<Integer> specifiedIdxDims = new ArrayList<>();

INDArrayIndex[] destinationIndices = indices.clone(); //Shallow clone
INDArrayIndex[] sourceIndices = indices.clone();
for( int i=0; i<indices.length; i++){
INDArrayIndex idx = indices[i];
if(idx instanceof SpecifiedIndex){
numSpecified++;
long[] idxs = ((SpecifiedIndex) idx).getIndexes();
specifiedIdxs.add(idxs);
specifiedIdxDims.add(i);
} else if(idx instanceof PointIndex){
//Example: [2,3,3].put(point(1), ..., [1,x,y]) -> can't use point(1) on [1,x,y]
sourceIndices[i] = NDArrayIndex.point(0);
}
}
int[] counts = new int[specifiedIdxs.size()];
int[] dims = new int[specifiedIdxDims.size()];
for( int i=0; i<specifiedIdxs.size(); i++ ){
counts[i] = specifiedIdxs.get(i).length;
dims[i] = specifiedIdxDims.get(i);
}

NdIndexIterator iter = new NdIndexIterator(counts);
while(iter.hasNext()){
long[] iterationIdxs = iter.next();
for(int i=0; i<iterationIdxs.length; i++ ){
long[] indicesForDim = specifiedIdxs.get(i);
destinationIndices[dims[i]] = NDArrayIndex.point(indicesForDim[(int)iterationIdxs[i]]);
sourceIndices[dims[i]] = NDArrayIndex.point(iterationIdxs[i]);
}

INDArray sourceView = element.get(sourceIndices);
INDArray destinationView = this.get(destinationIndices);
destinationView.assign(sourceView);
}
}
return this;
}

@Override
Expand Down
Expand Up @@ -240,6 +240,17 @@ public static INDArrayIndex all() {
return new NDArrayIndexAll(true);
}

    /**
     * Returns an instance of {@link SpecifiedIndex}, selecting exactly the given
     * positions along one dimension.
     * Note that SpecifiedIndex works differently than the other indexing options, in that a get
     * operation using it always returns a copy of the relevant subset of the underlying array,
     * never a view. This means that INDArray.get(..., indices(x,y,z), ...)
     * will be a copy of the relevant subset of the array.
     * @param indices Indices to get
     * @return a SpecifiedIndex over the given positions
     */
    public static INDArrayIndex indices(long... indices){
        return new SpecifiedIndex(indices);
    }


/**
* Represents adding a new dimension
Expand Down Expand Up @@ -313,6 +324,8 @@ public static INDArrayIndex[] resolve(DataBuffer shapeInfo, INDArrayIndex... int
IntervalIndex intervalIndex = (IntervalIndex) intendedIndexes[i];
ret[i] = new SpecifiedIndex(ArrayUtil.range(intervalIndex.begin, intervalIndex.end(),
intervalIndex.stride()));
} else if(intendedIndexes[i] instanceof PointIndex){
ret[i] = intendedIndexes[i];
}
}
}
Expand Down
Expand Up @@ -644,31 +644,6 @@ else if (idx instanceof IntervalIndex && !(idx instanceof NDArrayIndexAll)
else
this.offset += ArrayUtil.calcOffsetLong2(accumShape, accumOffsets, accumStrides)
/ Math.max(1, numIntervals);


//collapse singular dimensions with specified index
List<Integer> removeShape = new ArrayList<>();
for (int i = 0; i < Math.min(this.shapes.length, indexes.length); i++) {
if (this.shapes[i] == 1 && indexes[i] instanceof SpecifiedIndex) {
removeShape.add(i);
}
}


if (!removeShape.isEmpty()) {
List<Long> newShape = new ArrayList<>();
List<Long> newStrides = new ArrayList<>();
for (int i = 0; i < this.shapes.length; i++) {
if (!removeShape.contains(i)) {
newShape.add(this.shapes[i]);
newStrides.add(this.strides[i]);
}
}

this.shapes = Longs.toArray(newShape);
this.strides = Longs.toArray(newStrides);
}

}

public void resolveFixedDimensionsCOO(INDArrayIndex... indexes) {
Expand Down
Expand Up @@ -169,6 +169,12 @@ public static class SpecifiedIndexesGenerator implements Generator<Generator<Lis
*/
public SpecifiedIndexesGenerator(INDArrayIndex[] indexes) {
this.indexes = indexes;
for(int i=0; i<indexes.length; i++ ){
//Replace point indices with specified indices
if(indexes[i] instanceof PointIndex){
indexes[i] = new SpecifiedIndex(indexes[i].current());
}
}
}

@Override
Expand Down
Expand Up @@ -65,6 +65,7 @@
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.SpecifiedIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.ops.transforms.Transforms;
Expand Down Expand Up @@ -6878,6 +6879,86 @@ public void testStack(){
}
}

@Test
public void testPutSpecifiedIndex(){
long[][] ss = new long[][]{{3,4}, {3,4,5}, {3,4,5,6}};
long[][] st = new long[][]{{4,4}, {4,4,5}, {4,4,5,6}};
long[][] ds = new long[][]{{1,4}, {1,4,5}, {1,4,5,6}};

for( int test=0; test<ss.length; test++ ) {
long[] shapeSource = ss[test];
long[] shapeTarget = st[test];
long[] diffShape = ds[test];

final INDArray source = Nd4j.ones(shapeSource);
final INDArray target = Nd4j.zeros(shapeTarget);

final INDArrayIndex[] targetIndexes = new INDArrayIndex[shapeTarget.length];
Arrays.fill(targetIndexes, NDArrayIndex.all());
int[] arr = new int[(int) shapeSource[0]];
for (int i = 0; i < arr.length; i++) {
arr[i] = i;
}
targetIndexes[0] = new SpecifiedIndex(arr);

// Works
//targetIndexes[0] = NDArrayIndex.interval(0, shapeSource[0]);

target.put(targetIndexes, source);
final INDArray expected = Nd4j.concat(0, Nd4j.ones(shapeSource), Nd4j.zeros(diffShape));
assertEquals("Expected array to be set!", expected, target);
}
}

@Test
public void testPutSpecifiedIndices2d(){

INDArray arr = Nd4j.create(3,4);
INDArray toPut = Nd4j.create(new double[]{1,2,3,4}, new int[]{2,2}, 'c');
INDArrayIndex[] indices = new INDArrayIndex[]{
NDArrayIndex.indices(0,2),
NDArrayIndex.indices(1,3)} ;

INDArray exp = Nd4j.create(new double[][]{
{0,1,0,2},
{0,0,0,0},
{0,3,0,4}});

arr.put(indices, toPut);
assertEquals(exp, arr);
}

@Test
public void testPutSpecifiedIndices3d(){

INDArray arr = Nd4j.create(2,3,4);
INDArray toPut = Nd4j.create(new double[]{1,2,3,4}, new int[]{1,2,2}, 'c');
INDArrayIndex[] indices = new INDArrayIndex[]{
NDArrayIndex.point(1),
NDArrayIndex.indices(0,2),
NDArrayIndex.indices(1,3)} ;

INDArray exp = Nd4j.create(2,3,4);
exp.putScalar(1, 0, 1, 1);
exp.putScalar(1, 0, 3, 2);
exp.putScalar(1, 2, 1, 3);
exp.putScalar(1, 2, 3, 4);

arr.put(indices, toPut);
assertEquals(exp, arr);
}

@Test
public void testSpecifiedIndexArraySize1() {
long[] shape = {2, 2, 2, 2};
INDArray in = Nd4j.create(shape);
INDArrayIndex[] idx1 = new INDArrayIndex[]{NDArrayIndex.all(), new SpecifiedIndex(0), NDArrayIndex.all(), NDArrayIndex.all()};

INDArray arr = in.get(idx1);
long[] expShape = new long[]{2,1,2,2};
assertArrayEquals(expShape, arr.shape());
}

///////////////////////////////////////////////////////
protected static void fillJvmArray3D(float[][][] arr) {
int cnt = 1;
Expand Down
Expand Up @@ -146,8 +146,7 @@ public void testPointPointInterval() {
@Test
public void testIntervalLowerBound() {
INDArray wholeArr = Nd4j.linspace(1, 24, 24).reshape(4, 2, 3);
INDArray subarray = wholeArr.get(interval(1, 3), new SpecifiedIndex(new int[] {0}),
new SpecifiedIndex(new int[] {0, 2}));
INDArray subarray = wholeArr.get(interval(1, 3), NDArrayIndex.point(0), NDArrayIndex.indices(0, 2));
INDArray assertion = Nd4j.create(new double[][] {{7, 9}, {13, 15}});

assertEquals(assertion, subarray);
Expand Down

0 comments on commit 2c32b5c

Please sign in to comment.