[WIP] DL4J/SameDiff Misc (#7145)
* Small SameDiff fix (variable creation)

* #7140 RRDSI (RecordReaderDataSetIterator): better validation for invalid indices

* GELU tests + polishing

* Deconv3d

* Deconv3d fixes, test

* Switch to FB (FlatBuffers) 1.10.0

* Small deconv3d tweaks

* Javadoc
AlexDBlack committed Feb 12, 2019
1 parent 41e1a88 commit fa9f1f2
Showing 13 changed files with 565 additions and 36 deletions.
@@ -30,6 +30,7 @@
 import org.datavec.api.records.reader.impl.ConcatenatingRecordReader;
 import org.datavec.api.records.reader.impl.collection.CollectionRecordReader;
 import org.datavec.api.writable.Writable;
+import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
@@ -277,6 +278,12 @@ private void initializeUnderlying(Record next) {
 
             underlyingIsDisjoint = false;
         } else if (labelIndex >= 0) {
+            Preconditions.checkState(labelIndex < next.getRecord().size(),
+                    "Invalid label (from) index: index must be in range 0 to %s inclusive (first record size - 1), got %s", next.getRecord().size()-1, labelIndex);
+            Preconditions.checkState(labelIndexTo < next.getRecord().size(),
+                    "Invalid label (to) index: index must be in range 0 to %s inclusive (first record size - 1), got %s", next.getRecord().size()-1, labelIndexTo);
+
+
             //Multiple inputs
             int firstFrom = 0;
             int firstTo = labelIndex - 1;
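With these checks, an out-of-range label index now fails fast on the first next() call with a descriptive IllegalStateException, instead of surfacing later as an obscure indexing error. A minimal sketch of the new behavior (the 5-column CSV and the reader setup here are illustrative assumptions, not part of this commit):

    // Records have 5 columns (valid label indices 0..4), but labelIndex = 10 is requested:
    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new File("data.csv")));            // hypothetical 5-column file
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 32, 10, 3);
    iter.next();   // throws IllegalStateException: "Invalid label (from) index: ..."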
@@ -431,6 +431,16 @@ public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
         return deconv2D.outputVariable();
     }
 
+    public SDVariable deconv3d(SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) {
+        DeConv3D d = new DeConv3D(sameDiff(), input, weights, bias, config);
+        return d.outputVariable();
+    }
+
+    public SDVariable[] deconv3dDerivative(SDVariable input, SDVariable weights, SDVariable bias, SDVariable grad, DeConv3DConfig config) {
+        DeConv3DDerivative d = new DeConv3DDerivative(sameDiff(), input, weights, bias, grad, config);
+        return d.outputVariables();
+    }
+
     /**
      * Conv3d operation.
      *
@@ -1123,8 +1133,12 @@ public SDVariable swishDerivative(SDVariable iX) {
         return new SwishDerivative(sameDiff(), iX, false).outputVariable();
     }
 
-    public SDVariable geluDerivative(SDVariable iX) {
-        return new GELUDerivative(sameDiff(), iX, false).outputVariable();
+    public SDVariable gelu(SDVariable iX, boolean precise) {
+        return new GELU(sameDiff(), iX, false, precise).outputVariable();
+    }
+
+    public SDVariable geluDerivative(SDVariable iX, boolean precise) {
+        return new GELUDerivative(sameDiff(), iX, false, precise).outputVariable();
     }
 
     public SDVariable sign(SDVariable iX) {
@@ -1674,7 +1674,8 @@ public INDArray eval() {
 
     @Override
     public String toString() {
-        return "SDVariable(name=\"" + varName + "\",variableType=" + variableType + ",dtype=" + dataType + ")";
+        return "SDVariable(name=\"" + varName + "\",variableType=" + variableType + ",dtype=" + dataType +
+                (variableType == VariableType.PLACEHOLDER && shape != null ? ",shape=" + Arrays.toString(shape) : "") + ")";
     }
 
     @Override
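The extra clause makes placeholder shapes visible when printing variables during graph construction. A small sketch (placeHolder is the existing SameDiff API, also used by this commit's var(SDVariable) changes below):

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 784);
    System.out.println(in);
    // Before: SDVariable(name="in",variableType=PLACEHOLDER,dtype=FLOAT)
    // After:  SDVariable(name="in",variableType=PLACEHOLDER,dtype=FLOAT,shape=[-1, 784])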
@@ -398,8 +398,8 @@ public SDVariable invokeGraphOn(SameDiff sameDiff) {
         Map<Integer, Integer> thisVertexIdToNew = new HashMap<>();
         int idx = 1;
         for (val var : variables()) {
-            val clone = cloner.deepCloneDontCloneInstances(var, var.getSameDiff());
-            val newVar = sameDiff.var(clone);
+            SDVariable clone = cloner.deepCloneDontCloneInstances(var, var.getSameDiff());
+            SDVariable newVar = sameDiff.var(clone);
             if (var.getArr() != null && var.getVariableType() != VariableType.ARRAY) { //ARRAY type = "activations" - are overwritten anyway
                 sameDiff.associateArrayWithVariable(var.getArr(), newVar);
             }
@@ -2213,24 +2213,32 @@ public SDVariable var(String name, org.nd4j.linalg.api.buffer.DataType dataType,
      * {@link NDArraySupplierInitScheme} is used to ensure that, if the variable's array is already
      * allocated, this {@link SameDiff} instance gets its own copy of that array.
      *
-     * @param arr
+     * @param v Variable
      * @return
      */
-    public SDVariable var(@NonNull final SDVariable arr) {
-        if (variables.containsKey(arr.getVarName()) && variables.get(arr.getVarName()).getVariable().getArr() != null)
-            return variables.get(arr.getVarName()).getVariable();
+    public SDVariable var(@NonNull final SDVariable v) {
+        if (variables.containsKey(v.getVarName()) && variables.get(v.getVarName()).getVariable().getArr() != null)
+            return variables.get(v.getVarName()).getVariable();
 
-        if (arr.getVarName() == null || arr.getVarName().length() < 1)
+        if (v.getVarName() == null || v.getVarName().length() < 1)
             throw new IllegalArgumentException("Name for variable must be defined");
 
-        VariableType vt = arr.getVariableType();
-        WeightInitScheme s = null;
-        if(vt == VariableType.CONSTANT || vt == VariableType.VARIABLE){
-            s = new NDArraySupplierInitScheme(arr.getArr());
-        }
-
-        SDVariable ret = new SDVariable(arr.getVarName(), arr.getVariableType(), this, arr.getShape(), arr.dataType(), s);
-        return addVariable(ret);
+        VariableType vt = v.getVariableType();
+        NDArraySupplierInitScheme s = null;
+        switch(vt){
+            case VARIABLE:
+                s = new NDArraySupplierInitScheme(v.getArr());
+                //Intentional fallthrough
+            case ARRAY:
+                SDVariable ret = new SDVariable(v.getVarName(), v.getVariableType(), this, v.getShape(), v.dataType(), s);
+                return addVariable(ret);
+            case CONSTANT:
+                return constant(v.getVarName(), v.getArr());
+            case PLACEHOLDER:
+                return placeHolder(v.getVarName(), v.dataType(), v.placeholderShape());
+            default:
+                throw new RuntimeException("Unknown/not supported variable type: " + vt);
+        }
     }
 
     private String getNewVarName() {
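In effect, var(SDVariable) now reproduces a variable according to its type instead of always minting a VARIABLE: constants are routed through constant(...), placeholders through placeHolder(...), and only VARIABLE/ARRAY types construct an SDVariable directly (VARIABLE additionally capturing the existing array via the supplier init scheme). A rough usage sketch, mirroring the invokeGraphOn cloning loop above:

    // Copy a variable from one SameDiff instance into another; the copy
    // keeps the original's name, type, and (for VARIABLE type) its array
    SameDiff sd1 = SameDiff.create();
    SDVariable w = sd1.var("w", Nd4j.rand(3, 4));   // VARIABLE type with an array attached
    SameDiff sd2 = SameDiff.create();
    SDVariable wCopy = sd2.var(w);                  // recreated in sd2 with the same type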
@@ -3258,6 +3266,19 @@ public SDVariable deconv2d(String name, SDVariable[] inputs, DeConv2DConfig deco
         return updateVariableNameAndReference(ret, name);
     }
 
+    /**
+     * 3D CNN deconvolution operation with or without optional bias
+     * @param name    Name of the output variable
+     * @param input   Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
+     * @param weights Weights array - shape [kD, kH, kW, oC, iC]
+     * @param bias    Bias array - optional, may be null. If non-null, must have shape [outputChannels]
+     * @param config  Configuration
+     */
+    public SDVariable deconv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config){
+        SDVariable ret = f().deconv3d(input, weights, bias, config);
+        return updateVariableNameAndReference(ret, name);
+    }
+
 
     /**
      * Convolution 3D operation without bias
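A hedged usage sketch for the new op. The shapes follow the javadoc above; the DeConv3DConfig builder field names (kD/kH/kW, sD/sH/sW) are assumed by analogy with DeConv2DConfig and should be checked against the actual config class:

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", Nd4j.rand(new int[]{2, 4, 4, 4, 3}));   // [bS, iD, iH, iW, iC] (NDHWC)
    SDVariable w = sd.var("w", Nd4j.rand(new int[]{2, 2, 2, 8, 3}));     // [kD, kH, kW, oC, iC]
    DeConv3DConfig conf = DeConv3DConfig.builder()                       // builder fields assumed
            .kD(2).kH(2).kW(2)
            .sD(1).sH(1).sW(1)
            .build();
    SDVariable out = sd.deconv3d("deconv3d", in, w, null, conf);         // bias is optional; null here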
@@ -5035,6 +5056,32 @@ public SDVariable relu6(String name, SDVariable x, double cutoff) {
         return updateVariableNameAndReference(result, name);
     }
 
+    /**
+     * GELU activation function - Gaussian Error Linear Units<br>
+     * For more details, see <i>Gaussian Error Linear Units (GELUs)</i> - <a href="https://arxiv.org/abs/1606.08415">https://arxiv.org/abs/1606.08415</a><br>
+     * This method uses the sigmoid approximation.
+     *
+     * @param x Input
+     * @return Output variable - GELU applied to the input
+     */
+    public SDVariable gelu(SDVariable x) {
+        return gelu(null, x);
+    }
+
+    /**
+     * GELU activation function - Gaussian Error Linear Units<br>
+     * For more details, see <i>Gaussian Error Linear Units (GELUs)</i> - <a href="https://arxiv.org/abs/1606.08415">https://arxiv.org/abs/1606.08415</a><br>
+     * This method uses the sigmoid approximation.
+     *
+     * @param name Name of the output variable. May be null.
+     * @param x Input
+     * @return Output variable - GELU applied to the input
+     */
+    public SDVariable gelu(String name, SDVariable x) {
+        SDVariable ret = f().gelu(x, false);    //Defaults to sigmoid approximation
+        return updateVariableNameAndReference(ret, name);
+    }
+
     /**
      * Softmax activation
      *
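For reference: GELU(x) = x * Phi(x), where Phi is the standard normal CDF. The sigmoid approximation used by these methods is x * sigmoid(1.702 * x); the precise flag on the factory method above presumably selects the more accurate tanh-based form from the paper. A quick usage sketch:

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 10);
    SDVariable act = sd.gelu("act", in);   // sigmoid-approximation GELU, per the javadoc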
@@ -333,6 +333,10 @@ public static Class<?> transformStrictOpClass(int opNum){
             return Expm1.class;
         case 52:
             return ATanh.class;
+        case 53:
+            return GELU.class;
+        case 54:
+            return GELUDerivative.class;
         default:
             throw new UnsupportedOperationException("No known transform strict op for op number: " + opNum);
     }
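This mapping resolves serialized legacy transform-strict op numbers back to op classes during graph import; entries 53 and 54 register the new GELU forward and backward ops. An illustrative lookup (the enclosing mapper class is outside this hunk, so the call site below is assumed):

    Class<?> fwd = transformStrictOpClass(53);   // GELU.class
    Class<?> bwd = transformStrictOpClass(54);   // GELUDerivative.class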
