Skip to content

Commit

Permalink
Cleanup, fixes, improve test
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed May 5, 2018
1 parent f4ec2f3 commit e10434b
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 78 deletions.
Expand Up @@ -24,6 +24,7 @@
import org.deeplearning4j.nn.conf.weightnoise.DropConnect;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.util.GraphIndices;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit;
Expand Down Expand Up @@ -1519,14 +1520,19 @@ public void testTopoSortSaving(){

ComputationGraph cg = new ComputationGraph(conf);
cg.init();

GraphIndices indices = cg.calculateIndices();

int[] order = cg.topologicalSortOrder();
List<String> strOrder = cg.getConfiguration().getTopologicalOrderStr();
INDArray[] out1 = cg.output(in);

//Check it's the same after loading:
System.out.println("-----------");
ComputationGraph cg2 = TestUtils.testModelSerialization(cg);
int[] order2 = cg2.topologicalSortOrder();
List<String> strOrder2 = cg.getConfiguration().getTopologicalOrderStr();
assertArrayEquals(order, order2);
assertEquals(strOrder, strOrder2);

INDArray[] out2 = cg2.output(in);
assertArrayEquals(out1, out2);
Expand All @@ -1540,47 +1546,51 @@ public void testTopoSortSaving(){
cg3.setParams(cg2.params());

int[] order3 = cg3.topologicalSortOrder();
List<String> strOrder3 = cg.getConfiguration().getTopologicalOrderStr();
INDArray[] out3 = cg3.output(in);
assertArrayEquals(order, order3);
assertEquals(strOrder, strOrder3);
assertArrayEquals(out1, out3);


// //Now, change the order, and ensure the net is the same... note that we can do [l0, l1, l2] in any order
//
// List<List<String>> someValidOrders = new ArrayList<>();
// List<int[]> someValidOrderIdxs = new ArrayList<>();
// someValidOrders.add(Arrays.asList("in1", "in2", "l0", "l1", "l2", "l3", "l4", "l5", "l6", "l7"));
// someValidOrderIdxs.add(new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
// someValidOrders.add(Arrays.asList("in1", "in2", "l1", "l0", "l2", "l3", "l4", "l5", "l6", "l7"));
// someValidOrderIdxs.add(new int[]{0, 1, 3, 2, 4, 5, 6, 7, 8, 9});
// someValidOrders.add(Arrays.asList("in1", "in2", "l2", "l1", "l0", "l3", "l4", "l5", "l6", "l7"));
// someValidOrderIdxs.add(new int[]{0, 1, 4, 3, 2, 5, 6, 7, 8, 9});
// someValidOrders.add(Arrays.asList("in1", "in2", "l2", "l5", "l0", "l1", "l3", "l4", "l7", "l6"));
// someValidOrderIdxs.add(new int[]{0, 1, 4, 7, 2, 3, 5, 6, 9, 8});
//
// for( int i=0; i<someValidOrders.size(); i++ ){
// List<String> l = someValidOrders.get(i);
// int[] arr = someValidOrderIdxs.get(i);
//
// ComputationGraphConfiguration conf2 = conf.clone();
// conf2.setTopologicalOrderStr(l);
// conf2.setTopologicalOrder(arr);
//
// ComputationGraph g = new ComputationGraph(conf2);
// g.setParamTable(cg.paramTable());
// g.init();
// int[] origOrder = g.topologicalSortOrder();
//
// INDArray[] out4 = g.output(in);
// assertArrayEquals(out1, out4);
//
// ComputationGraph g2 = TestUtils.testModelSerialization(g);
// int[] loadedOrder = g2.topologicalSortOrder();
//
// assertArrayEquals(origOrder, loadedOrder);
//
// INDArray[] out5 = g.output(in);
// assertArrayEquals(out1, out5);
// }
//Now, change the order, and ensure the net is the same... note that we can do [l0, l1, l2] in any order
List<List<String>> someValidOrders = new ArrayList<>();
someValidOrders.add(Arrays.asList("in1", "in2", "l0", "l1-merge", "l1", "l2", "l3", "l4", "l5", "l6-merge", "l6", "l7"));
someValidOrders.add(Arrays.asList("in1", "in2", "l1-merge", "l1", "l0", "l2", "l3", "l4", "l5", "l6-merge", "l6", "l7"));
someValidOrders.add(Arrays.asList("in1", "in2", "l2", "l1-merge", "l1", "l0", "l3", "l4", "l5", "l6-merge", "l6", "l7"));
someValidOrders.add(Arrays.asList("in1", "in2", "l2", "l5", "l0", "l1-merge", "l1", "l3", "l4", "l7", "l6-merge", "l6"));

for(List<String> l : someValidOrders){
assertEquals(strOrder.size(), l.size());
}

for( int i=0; i<someValidOrders.size(); i++ ){
List<String> l = someValidOrders.get(i);
int[] arr = new int[l.size()];
int j=0;
for(String s : l){
arr[j++] = indices.getNameToIdx().get(s);
}

ComputationGraphConfiguration conf2 = conf.clone();
conf2.setTopologicalOrderStr(l);
conf2.setTopologicalOrder(arr);

ComputationGraph g = new ComputationGraph(conf2);
g.init();
g.setParamTable(cg.paramTable());
int[] origOrder = g.topologicalSortOrder();

INDArray[] out4 = g.output(in);
assertArrayEquals(out1, out4);

ComputationGraph g2 = TestUtils.testModelSerialization(g);
int[] loadedOrder = g2.topologicalSortOrder();

assertArrayEquals(origOrder, loadedOrder);

INDArray[] out5 = g.output(in);
assertArrayEquals(out1, out5);
}
}
}
Expand Up @@ -35,6 +35,7 @@
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.util.ComputationGraphUtil;
import org.deeplearning4j.nn.graph.util.GraphIndices;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.graph.vertex.impl.InputVertex;
Expand Down Expand Up @@ -154,6 +155,9 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
* (and hence also backward pass, which is the opposite to this) is conducted in the network.
*/
protected int[] topologicalOrder;

protected GraphIndices graphIndices;

/**
* A list of layers. Each of these layers is present in a GraphVertex, but are here for easy reference.
* This array also defines the order in which the getLayer(int) method returns layers.
Expand Down Expand Up @@ -430,8 +434,8 @@ public void init(INDArray parameters, boolean cloneParametersArray) {
// }

//First: build topological ordering, based on configuration. Used for forward pass, backprop and order of parameters/gradients
Indexes indexes = calculateIndexes();
topologicalOrder = indexes.getTopologicalSortOrder();
GraphIndices indices = calculateIndices();
topologicalOrder = indices.getTopologicalSortOrder();

//Initialization: create the GraphVertex objects, based on configuration structure
Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> configVertexMap = configuration.getVertices();
Expand Down Expand Up @@ -462,7 +466,7 @@ public void init(INDArray parameters, boolean cloneParametersArray) {
numParamsForVertex[i] = 0; //No parameters for input vertices
}
for(; i<topologicalOrder.length; i++ ){
String name = indexes.getIdxToName().get(i);
String name = indices.getIdxToName().get(i);
org.deeplearning4j.nn.conf.graph.GraphVertex n = configVertexMap.get(name);
numParamsForVertex[i] = n.numParams(true);
numParams += numParamsForVertex[i];
Expand Down Expand Up @@ -517,7 +521,7 @@ public void init(INDArray parameters, boolean cloneParametersArray) {
List<String> variables = defaultConfiguration.variables(false);
i = configuration.getNetworkInputs().size();
for(; i<topologicalOrder.length; i++ ){
String name = indexes.getIdxToName().get(i);
String name = indices.getIdxToName().get(i);
org.deeplearning4j.nn.conf.graph.GraphVertex n = configVertexMap.get(name);

GraphVertex gv = n.instantiate(this, name, vertexNumber, paramsViewForVertex[vertexNumber],
Expand Down Expand Up @@ -545,20 +549,6 @@ public void init(INDArray parameters, boolean cloneParametersArray) {
}
layers = tempLayerList.toArray(new Layer[numLayers]);

/////////////////////
if(!allNamesReverse.equals(indexes.getNameToIdx())){
throw new RuntimeException();
}

//Also validate Vertex[]:
for( int x=0; x<vertices.length; x++ ){
String name = vertices[x].getVertexName();
String expName = indexes.getIdxToName().get(x);
Preconditions.checkState(expName.equals(name), "%s - %s, %s", x, expName, name);
}

/////////////////////

//Create the lookup table, so we can find vertices easily by name
verticesMap = new HashMap<>();
for (GraphVertex gv : vertices) {
Expand Down Expand Up @@ -673,6 +663,8 @@ public void initGradientsView() {
if (!initCalled)
init();

GraphIndices indices = calculateIndices();

//Go through layers, and work out total number of parameters. Then allocate full parameters array
int numParams = 0;
int[] numParamsForVertex = new int[topologicalOrder.length];
Expand All @@ -681,12 +673,11 @@ public void initGradientsView() {
numParamsForVertex[i] = 0; //No parameters for input vertices
}
Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> configVertexMap = configuration.getVertices();
for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> nodeEntry : configVertexMap
.entrySet()) {
org.deeplearning4j.nn.conf.graph.GraphVertex n = nodeEntry.getValue();
for (; i < topologicalOrder.length; i++) {
String name = indices.getIdxToName().get(i);
org.deeplearning4j.nn.conf.graph.GraphVertex n = configVertexMap.get(name);
numParamsForVertex[i] = n.numParams(true);
numParams += numParamsForVertex[i];
i++;
}

if(numParams > 0) {
Expand Down Expand Up @@ -1075,14 +1066,7 @@ public void fit(INDArray[] inputs, INDArray[] labels, INDArray[] featureMaskArra
clearLayersStates();
}

@Data
@AllArgsConstructor
@Builder
private static class Indexes {
private int[] topologicalSortOrder;
private Map<String,Integer> nameToIdx;
private Map<Integer,String> idxToName;
}


/**
* Calculate a topological sort order for the vertices in the graph.
Expand All @@ -1094,12 +1078,12 @@ private static class Indexes {
* Specifically, gradients/params/forward pass are executed on vertex[topologicalSortOrder[i]], for i=0..nVertices-1
*/
public int[] topologicalSortOrder() {
return calculateIndexes().topologicalSortOrder;
return calculateIndices().getTopologicalSortOrder();
}

public Indexes calculateIndexes(){
// if (topologicalOrder != null)
// return topologicalOrder;
public GraphIndices calculateIndices(){
if(graphIndices != null)
return graphIndices;


//Get cached topological sort order from config, if present
Expand All @@ -1113,12 +1097,12 @@ public Indexes calculateIndexes(){
m2.put(t[i], s.get(i));
}

System.out.println("RETURNING CACHED TOPO SORT");
return Indexes.builder()
graphIndices = GraphIndices.builder()
.topologicalSortOrder(t)
.nameToIdx(m1)
.idxToName(m2)
.build();
return graphIndices;
}


Expand Down Expand Up @@ -1229,12 +1213,12 @@ public Indexes calculateIndexes(){
configuration.setTopologicalOrder(out);
configuration.setTopologicalOrderStr(s);

// return out;
return Indexes.builder()
graphIndices = GraphIndices.builder()
.topologicalSortOrder(out)
.nameToIdx(vertexNamesMap2)
.idxToName(vertexNamesMap)
.build();
return graphIndices;
}

@Override
Expand Down Expand Up @@ -3046,8 +3030,24 @@ public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
}

@Override
public void setParamTable(Map<String, INDArray> paramTable) {
throw new UnsupportedOperationException("Not implemented");
public void setParamTable(@NonNull Map<String, INDArray> paramTable) {
Preconditions.checkArgument(paramTable.keySet().equals(paramTable().keySet()), "Cannot set param table: parameter set keys are not equal");
Map<String,INDArray> current = paramTable();
//Check shapes before doing partial assigment to avoid leaving net in incorrect state
for(String s : current.keySet()){
INDArray arrCurrent = current.get(s);
INDArray arrNew = paramTable.get(s);
int[] shapeCurrent = arrCurrent.shape();
int[] shapeNew = arrNew.shape();
Preconditions.checkState(Arrays.equals(shapeCurrent, shapeNew), "Cannot set parameters: shape array for " +
"parameter \"%s\" does not match existing shape: parameter shape = %s, new param shape = %s", s, shapeCurrent, arrNew);
}

for(String s : current.keySet()) {
INDArray arrCurrent = current.get(s);
INDArray arrNew = paramTable.get(s);
arrCurrent.assign(arrNew);
}
}

@Override
Expand Down
@@ -0,0 +1,21 @@
package org.deeplearning4j.nn.graph.util;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;

import java.util.Map;

/**
* Simple helper class for ComputationGraph topological sort
*
* @author Alex Black
*/
@Data
@AllArgsConstructor
@Builder
public class GraphIndices {
private int[] topologicalSortOrder;
private Map<String,Integer> nameToIdx;
private Map<Integer,String> idxToName;
}

0 comments on commit e10434b

Please sign in to comment.