SameDiff: Add TensorArray support for new execution, new tests #6976

Merged: 9 commits, Jan 11, 2019
@@ -41,6 +41,7 @@
import org.nd4j.graph.*;
import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.factory.DataBufferFactory;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -66,7 +67,7 @@
import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.shape.Eye;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayV3;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
import org.nd4j.linalg.api.ops.impl.transforms.Assert;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
@@ -92,7 +93,6 @@
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.DeviceLocalNDArray;
import org.nd4j.list.compat.TensorList;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.ConstantInitScheme;
@@ -167,8 +167,6 @@ public class SameDiff {
@Deprecated //TO BE REMOVED - to Variable
private Map<String, SDVariable> forwardVarForGrad;

private Map<String, TensorList> lists = new HashMap<>(); // Key - node name; Value - TensorList

// counter for auto-naming variables
private int variableId = 0;

@@ -9747,15 +9745,6 @@ public SameDiff getFunction(String functionName) {
}


public TensorList getListByName(@NonNull String name) {
return lists.get(name);
}

public void putListByName(@NonNull String name, TensorList list) {
lists.put(name, list);
}


/**
* Creates a while statement
*
@@ -9800,8 +9789,8 @@ public If ifStatement(SameDiffConditional conditional,
}


public TensorArrayV3 tensorArray() {
return new TensorArrayV3(this);
public TensorArray tensorArray(DataType dataType) {
return new TensorArray(this, dataType);
}

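For reference, a minimal usage sketch of the new method. SameDiff.create(), tensorArray(DataType) and outputVariable() are all visible in this PR; the variable names are illustrative:

    SameDiff sd = SameDiff.create();
    //New signature: the element data type must now be supplied explicitly
    TensorArray ta = sd.tensorArray(DataType.FLOAT);
    //The op's first output is a 'dummy' SDVariable that represents the array in the
    //graph; downstream ops (TensorArrayRead, TensorArrayWrite, ...) take it as input 0
    SDVariable taVar = ta.outputVariable();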
/**
@@ -30,6 +30,8 @@ public abstract class AbstractSession<T, O> {
protected final SameDiff sameDiff;
@Getter
protected final Map<VarId, T> nodeOutputs = new HashMap<>();
@Getter
protected final Map<VarId, List<T>> tensorArrays = new HashMap<>(); //Stores the contents of TensorArray ops
protected final Queue<VarId> availableForExec = new LinkedList<>();
/**
* Contains variables we *might* need to execute in process of getting outputs we want.
@@ -120,6 +122,7 @@ public Map<String, T> output(@NonNull List<String> variables, Map<String, T> pla
execInputs.clear();
execConstInputs.clear();
nodeOutputs.clear(); //TODO eventually we'll have cache here for later execs... main challenge is detecting in-place array modifications and invalidating old results
tensorArrays.clear();

//Step 1: determine subgraph structure we actually need to execute
//Basic plan: work backwards from the variables we want, based on the graph structure, to work out what
@@ -239,7 +242,8 @@ public Map<String, T> output(@NonNull List<String> variables, Map<String, T> pla
}
}
} else {
throw new IllegalStateException("Unable to execute variable " + varToExec + " of type " + sameDiff.getVariable(varToExec.getVariable()).getVariableType());
Variable v = sameDiff.getVariables().get(varToExec.getVariable());
throw new IllegalStateException("Unable to execute variable " + varToExec + " of type " + v.getVariable().getVariableType());
}
}

@@ -627,6 +631,15 @@ protected void addToExecInputs(boolean isConstOrPh, VarId inputVar, VarId forVar
}


protected static VarId lookup(String name, Collection<VarId> varIds){
for(VarId vid : varIds){
if(vid.getVariable().equals(name)){
return vid;
}
}
throw new RuntimeException("Could not find VarId for input " + name);
}

/*
VarId: identifies a variable in a specific frame and frame iteration
Used for 2 places:
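For orientation, a sketch of how the new tensorArrays field and lookup helper cooperate. The OUTER_FRAME constant and the newVarId(String, String, int) overload are assumptions from the surrounding class, and the names are illustrative:

    //A TensorArray op registers its backing list under the VarId of its output
    //variable, i.e. variable name plus frame and iteration:
    VarId taId = newVarId("tensorArray", OUTER_FRAME, 0); //assumed overload/constant
    tensorArrays.put(taId, new ArrayList<T>());
    //Downstream ops only know the *name* of that dummy variable; lookup(...)
    //recovers the frame/iteration-specific VarId from their input set:
    List<T> contents = tensorArrays.get(lookup("tensorArray", opInputs));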
@@ -2,7 +2,6 @@

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@@ -14,9 +13,13 @@
import org.nd4j.linalg.api.ops.impl.controlflow.If;
import org.nd4j.linalg.api.ops.impl.controlflow.While;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.*;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.*;
import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.util.ArrayUtil;

import java.util.*;

@@ -123,6 +126,182 @@ public INDArray[] getOutputs(DifferentialFunction op, FrameIter outputFrameIter,
Preconditions.checkNotNull(arr, "Input to LoopCond op must not be null");
Preconditions.checkState(arr.isScalar() && arr.dataType() == DataType.BOOL, "LoopCond input must be a scalar boolean, got %ndShape", arr);
return new INDArray[]{arr};
} else if(op instanceof BaseTensorOp){
//TensorOps - special cases...
if(op instanceof TensorArray){
//Create a TensorArray
VarId vid = newVarId(op.outputVariable().getVarName(), outputFrameIter);
Preconditions.checkState(!tensorArrays.containsKey(vid), "TensorArray already exists for %s when executing TensorArray op", vid);
tensorArrays.put(vid, new ArrayList<INDArray>());

// Note that TensorArray has 2 outputs: a 'dummy' SDVariable that represents it, and a second scalar 0.0 output
return new INDArray[]{Nd4j.scalar(true), Nd4j.scalar(0.0f)};
} else if(op instanceof TensorArrayRead){
//Do lookup and return
//Input 0 is the TensorArray (or dummy variable that represents it)
//Input 1 is the index
SDVariable idxSDV = op.arg(1);
INDArray idxArr = getArray(idxSDV, opInputs);
Preconditions.checkState(idxArr.isScalar(), "TensorArrayRead input argument 1 should be scalar - has shape %ndShape", idxArr);
int i = idxArr.getInt(0);

SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array
//Work out the frame/iteration:
VarId v = lookup(inTensorArray.getVarName(), opInputs);

List<INDArray> list = getTensorArrays().get(v);
Preconditions.checkState(list != null, "Could not find TensorArray for %s", v);
Preconditions.checkState(list.size() > i, "Cannot get index %s from TensorArray of size %s (array not present?) - VarId=%s", i, list.size(), v);

INDArray out = list.get(i);
return new INDArray[]{out};
} else if(op instanceof TensorArrayWrite) {
//TensorArrayWrite - like TensorArray, this op also returns a scalar 0.0
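//(The scalar 0.0 appears to mirror TensorFlow's "flow" output, which TensorArray
//ops thread through the graph to enforce execution ordering)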

SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array
//Work out the varid (frame/iteration) of the tensor array:
VarId tArr = lookup(inTensorArray.getVarName(), opInputs);

//Input 0 is the TensorArray (or dummy variable that represents it)
//Input 1 is the index
//Input 2 is the value to write

String idxName = op.arg(1).getVarName();
SDVariable idxSDV = sameDiff.getVariable(idxName);
INDArray idxArr = getArray(idxSDV, opInputs);
Preconditions.checkState(idxArr.isScalar(), "Index variable ID for TensorArrayWrite should be a scalar, got %ndShape", idxArr);
int idx = idxArr.getInt(0);

String inName = op.arg(2).getVarName();
SDVariable inSDV = sameDiff.getVariable(inName);
INDArray arr = getArray(inSDV, opInputs);
Preconditions.checkState(arr != null, "Could not find array for %s", inName);

Preconditions.checkState(tensorArrays.containsKey(tArr), "Tensor array does not exist for %s", tArr);
//TODO is this always safe to insert by index for all execution orders?
List<INDArray> l = tensorArrays.get(tArr);
while (l.size() <= idx) {
//Can't use set(int, E) if index >= size
l.add(null);
}
l.set(idx, arr);

//Return dummy array
return new INDArray[]{Nd4j.scalar(0.0f)};
} else if(op instanceof TensorArraySize) {
//Index 0 is the TensorArray (or dummy variable that represents it)
SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array
//Work out the varid (frame/iteration) of the tensor array:
VarId tArr = lookup(inTensorArray.getVarName(), opInputs);
List<INDArray> l = tensorArrays.get(tArr);
Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr);
return new INDArray[]{Nd4j.scalar(DataType.INT, l.size())};
} else if(op instanceof TensorArrayConcat) {
SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array
VarId tArr = lookup(inTensorArray.getVarName(), opInputs);
List<INDArray> l = tensorArrays.get(tArr);
//TODO - empty checks. But is size 0 OK?
INDArray concat = Nd4j.concat(0, l.toArray(new INDArray[l.size()]));
return new INDArray[]{concat};
} else if(op instanceof TensorArrayGather) {
//Input 0: the TensorArray
//Input 1: the indices (1d integer vector)

SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array
VarId tArr = lookup(inTensorArray.getVarName(), opInputs);
List<INDArray> l = tensorArrays.get(tArr);
Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr);

String indicesName = op.arg(1).getVarName();
SDVariable indicesSDV = sameDiff.getVariable(indicesName);
INDArray idxArr = getArray(indicesSDV, opInputs);
Preconditions.checkState(idxArr.isVector(), "Indices variable for TensorArrayGather should be a vector, got %ndShape for %s", idxArr, indicesName);
Preconditions.checkState(idxArr.dataType().isIntType(), "Indices variable for TensorArrayGather should be an integer type, got %s for array %s", idxArr.dataType(), indicesName);

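//e.g. indices [0,2] over TensorArray contents [a,b,c] selects [a,c], which
//Nd4j.pile stacks along a new leading dimension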
int[] idxArrInt = idxArr.toIntVector();
ArrayList<INDArray> newList = new ArrayList<>();
for (int id : idxArrInt) {
newList.add(l.get(id));
}
INDArray out = Nd4j.pile(newList);
return new INDArray[]{out};
} else if(op instanceof TensorArrayScatter) {
//Scatter values from a rank (N+1)d tensor into specific indices of the TensorArray
//Input 0: the TensorArray
//Input 1: the indices (1d integer vector)
//Input 2: The values to scatter
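//e.g. indices [2,0] with values of shape [2,...]: TensorArray[2] = values[0],
//TensorArray[0] = values[1]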

SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array
VarId tArr = lookup(inTensorArray.getVarName(), opInputs);
List<INDArray> l = tensorArrays.get(tArr);
Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr);

String indicesName = op.arg(1).getVarName();
SDVariable indicesSDV = sameDiff.getVariable(indicesName);
INDArray idxArr = getArray(indicesSDV, opInputs);
Preconditions.checkState(idxArr.isVector(), "Indices variable for TensorArrayScatter should be a vector, got %ndShape for %s", idxArr, indicesName);
Preconditions.checkState(idxArr.dataType().isIntType(), "Indices variable for TensorArrayScatter should be an integer type, got %s for array %s", idxArr.dataType(), indicesName);
int[] idxs = idxArr.toIntVector();

String valuesName = op.arg(2).getVarName();
SDVariable valuesSDV = sameDiff.getVariable(valuesName);
INDArray valuesArr = getArray(valuesSDV, opInputs);

int maxIdx = 0; //Grow the list to cover the largest target index, not just the number of indices
for (int i : idxs)
maxIdx = Math.max(maxIdx, i);
while (l.size() <= maxIdx) { //Can't use set(int, E) if index >= size
l.add(null);
}

INDArrayIndex[] idx = ArrayUtil.nTimes(valuesArr.rank(), NDArrayIndex.all(), INDArrayIndex.class);
for (int i = 0; i < idxs.length; i++) {
idx[0] = NDArrayIndex.point(i);
INDArray get = valuesArr.get(idx).dup();
int outIdx = idxs[i];
l.set(outIdx, get);
}

//Return dummy array
return new INDArray[]{Nd4j.scalar(0.0f)};
} else if(op instanceof TensorArraySplit){
//Split values from a rank (N+1)d tensor into sequential indices of the TensorArray
//For example, an orig array of shape [8,2] with split sizes (4,4) means TensorArray[0] = orig[0:4,:] and TensorArray[1] = orig[4:8,:]
//Input 0: the TensorArray
//Input 1: The values to split
//Input 2: the size of each split (1d integer vector)

SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array
VarId tArr = lookup(inTensorArray.getVarName(), opInputs);
List<INDArray> l = tensorArrays.get(tArr);
Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr);

String splitName = op.arg(1).getVarName();
INDArray splitArr = getArray(sameDiff.getVariable(splitName), opInputs);


String sizeName = op.arg(2).getVarName();
SDVariable sizeSDV = sameDiff.getVariable(sizeName);
INDArray sizeArr = getArray(sizeSDV, opInputs);
Preconditions.checkState(sizeArr.isVector(), "Sizes variable for TensorArraySplit should be a vector, got %ndShape for %s", sizeArr, sizeName);
Preconditions.checkState(sizeArr.dataType().isIntType(), "Sizes variable for TensorArraySplit should be an integer type, got %s for array %s", sizeArr.dataType(), sizeName);
int[] sizes = sizeArr.toIntVector();

while (l.size() < sizes.length) { //Can't use set(int, E) if index >= size; splits fill exactly indices 0..sizes.length-1
l.add(null);
}

INDArrayIndex[] idx = ArrayUtil.nTimes(splitArr.rank(), NDArrayIndex.all(), INDArrayIndex.class);
int soFar = 0;
for( int i=0; i<sizes.length; i++ ){
idx[0] = NDArrayIndex.interval(soFar, soFar + sizes[i]);
INDArray sub = splitArr.get(idx).dup();
l.set(i, sub);
soFar += sizes[i];
}
//Return dummy array
return new INDArray[]{Nd4j.scalar(0.0f)};
} else {
throw new IllegalStateException("Execution support not yet implemented for: " + op.getClass().getName());
}

} else if(op instanceof CustomOp){
CustomOp c = (CustomOp)op;
Nd4j.getExecutioner().exec(c);
@@ -149,13 +328,15 @@ public DifferentialFunction getAndParameterizeOp(String opName, FrameIter frameI

DifferentialFunction df = sameDiff.getFunctionById(opName);

//TODO We should clone these - probably - as we don't want them shared between threads/sessions!
//TODO We should clone these ops - probably - as we don't want them shared between threads/sessions!
//But let's only clone them *once* and cache in inference session - not on every exec

Preconditions.checkNotNull(df, "No differential function fond with name %s", opName);

if(df instanceof LoopCond || df instanceof Enter || df instanceof Exit || df instanceof NextIteration ||
df instanceof Merge || df instanceof Switch || df instanceof If || df instanceof While){
df instanceof Merge || df instanceof Switch || df instanceof If || df instanceof While ||
df instanceof BaseTensorOp){
//Control flow ops and tensor array ops (TensorArray, TensorArrayRead, etc.) don't need inputs set; execution is a special case
return df;
}

@@ -296,4 +477,15 @@ public DifferentialFunction getAndParameterizeOp(String opName, FrameIter frameI

return df;
}


protected INDArray getArray(SDVariable sdv, Collection<VarId> opInputs){
String n = sdv.getVarName();
if(sdv.getVariableType() == VariableType.CONSTANT || sdv.getVariableType() == VariableType.VARIABLE){
return getConstantOrVariable(n);
} else {
VarId inVarId = lookup(n, opInputs);
return nodeOutputs.get(inVarId);
}
}
}
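To make the TensorArraySplit slicing above concrete, here is a standalone sketch of the same loop applied to the [8,2] / sizes (4,4) example from the comments. It reuses the Nd4j, NDArrayIndex, INDArrayIndex and ArrayUtil imports added in this PR and is not part of the change:

    INDArray orig = Nd4j.linspace(1, 16, 16).reshape(8, 2); //[8,2] input
    int[] sizes = {4, 4};
    List<INDArray> l = new ArrayList<>();
    INDArrayIndex[] idx = ArrayUtil.nTimes(orig.rank(), NDArrayIndex.all(), INDArrayIndex.class);
    int soFar = 0;
    for (int i = 0; i < sizes.length; i++) {
        idx[0] = NDArrayIndex.interval(soFar, soFar + sizes[i]); //rows [soFar, soFar+sizes[i])
        l.add(orig.get(idx).dup()); //dup() so the slice doesn't alias orig
        soFar += sizes[i];
    }
    //Result: l.get(0) == orig[0:4,:], l.get(1) == orig[4:8,:]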
@@ -391,6 +391,9 @@ public String getNodeName(String name) {
if(ret.endsWith("/read")) {
ret = ret.replace("/read","");
}
if(ret.endsWith(":0")){
ret = ret.substring(0, ret.length()-2);
}
return ret;
}

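Illustrative effect of the added suffix handling (the input strings are examples, not taken from the diff):

    getNodeName("while/Merge:0"); //-> "while/Merge" (TF output-slot suffix ":0" stripped)
    getNodeName("embedding/read"); //-> "embedding" (existing "/read" handling)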
@@ -5004,7 +5004,8 @@ public INDArray get(INDArrayIndex... indexes) {
&& indexes[0] instanceof PointIndex && indexes[0].offset() == 0
&& indexes[1] instanceof NDArrayIndexAll
|| isColumnVector() && indexes[1] instanceof PointIndex && indexes[0].offset() == 0
&& indexes[0] instanceof NDArrayIndexAll)))
&& indexes[0] instanceof NDArrayIndexAll)) ||
(rank() == 1 && length() == 1 && indexes.length == 1 && indexes[0] instanceof PointIndex && indexes[0].current() == 0)) //Last one: point index on rank 1 size 1
return this;

indexes = NDArrayIndex.resolve(shapeInfoDataBuffer(), indexes);
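A small sketch of the case the added rank-1/size-1 condition covers (the values are illustrative):

    INDArray v = Nd4j.create(new float[]{3.0f}, new int[]{1}); //rank 1, length 1
    INDArray out = v.get(NDArrayIndex.point(0));
    //With the new short-circuit, 'out' is 'v' itself - no full index resolution needed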