Import test improvements + more import mapping (#6145)
* Add name to parameterized tests

* Op validation improvements

* More ignores

* Tweaks to op validation reporting

* Fix mmul import

* Fix TF import op names for 3 ops

* Fix Range import edge case

* More import fixes

* Add boolean array type import for TF

* LogicalAnd import fixes

* LogicalOr/Xor

* TFGraphMapper fix; fix Not op extra args

* Fix Equals import

* Fix loading of scalars for TF import tests

* More fixes/ignores

* Fix DefaultOpExecutioner issue with empty arrays

* Import support for: segment max/mean/min/prod/sum, matrix inverse; misc fixes

* Add ShapeN; fixes for Or/Xor; ignore known failing tests for now

* Final tweaks before merge
AlexDBlack committed Aug 16, 2018
1 parent ff90344 commit 1a48d74
Showing 32 changed files with 623 additions and 79 deletions.
@@ -506,11 +506,16 @@ public static void logCoverageInformation(boolean logAdequatelyTested, boolean l
int tfOpsWithImportTests = 0;
if(logUntestedTFImport)
log.info(" --- Ops with TF Mapping but No TF Import Tests ---");
for(Map.Entry<String, DifferentialFunction> e : tfOpsMap.entrySet()){
String s = e.getKey();
List<String> tfOpsKeys = new ArrayList<>(tfOpsMap.keySet());
Collections.sort(tfOpsKeys);
Set<String> tfIgnored = excludeFromTfImportCoverage();
int tfImportIgnored = 0;
for(String s : tfOpsKeys){
Integer count = tfMappedOpsImportTestCounts.get(s);
if(count == null || count == 0){
-                if(logUntestedTFImport)
+                if(tfIgnored.contains(s)){
+                    tfImportIgnored++;
+                } else if(logUntestedTFImport)
log.info("TF mapped op with no import tests: {}", s);
} else {
tfOpsWithImportTests++;
@@ -567,7 +572,7 @@ public static void logCoverageInformation(boolean logAdequatelyTested, boolean l
int countLibnd4jMapped = countTotalLibnd4jOps - nonMappedLibnd4jOps.size();
String fracLibnd4j = String.format("%.2f", 100.0 * (countLibnd4jMapped / (double)countTotalLibnd4jOps));

-        String fracTFMappedTested = String.format("%.2f", 100.0 * tfOpsWithImportTests / (double)totalTFMappedOps);
+        String fracTFMappedTested = String.format("%.2f", 100.0 * tfOpsWithImportTests / (double)(totalTFMappedOps-tfImportIgnored));

log.info("*****************************************************");
log.info("Op Validation: {} of {} classes with adequate tests ({}% coverage)", countAdequate, totalFwd, pc);
@@ -576,7 +581,7 @@ public static void logCoverageInformation(boolean logAdequatelyTested, boolean l
log.info("({} ops excluded from gradient check coverage)", excludedFromBackpropCoverage.size());
log.info("({} ops excluded from fwd+gradient tests)", excludedFromAllTestCoverage.size());
log.info("TF mapped ops: {} of {} ({}%)", countTfMapped, countTf, fracTfStr);
log.info("SD ops with TF import mapping + test {} of {} ({}%)", tfOpsWithImportTests, totalTFMappedOps, fracTFMappedTested);
log.info("SD ops with TF import mapping + test {} of {} ({}%) - {} ignored for coverage", tfOpsWithImportTests, (totalTFMappedOps-tfImportIgnored), fracTFMappedTested, tfImportIgnored);
log.info("Libnd4j mapped ops: {} of {} ({}%)", countLibnd4jMapped, countTotalLibnd4jOps, fracLibnd4j);
log.info("*****************************************************");
}
@@ -764,4 +769,35 @@ private static Set<Class> excludedFromGradientCheckCoverage() {
return new HashSet<>(list);
}

+    /**
+     * These ops are excluded from TF import test coverage, for various reasons
+     */
+    private static Set<String> excludeFromTfImportCoverage(){
+        List<String> list = Arrays.asList(
+                "Reverse",          //Can be excluded: "Reverse_v2" is the synonym TF uses for tf.reverse(...); ReverseV2 is also the Java op that maps to the same op
+                "LogSigmoid",       //Not in ops.proto. Have tests for tf.log_sigmoid, but can't test the LogSigmoid op directly: tf.log_sigmoid actually just uses "y = -tf.nn.softplus(-x)" - i.e., 3 separate ops :/
+                "HardSigmoid",      //Also implemented in Python, NOT a single native op
+                "SpaceToBatch",     //Old name - SpaceToBatchNd is used in practice (inc. for tf.space_to_batch)
+                "BatchToSpace",     //Old name - BatchToSpaceNd is used in practice
+
+                //All of the following ops - not available in TF (can't find them) - op mapping is wrong?
+                //TODO: Check these and remove the import mapping from the Java classes if they are indeed bad
+                "HardTanh",
+                "Swish",
+                "RDiv",
+                "DivScalar",
+                "LogX",
+                "RationalTanh",
+                "absargmax",
+                "absargmin",
+                "entropy_shannon",  //This is a thing, but quite different from our op: https://www.tensorflow.org/versions/r1.2/api_docs/python/tf/contrib/bayesflow/entropy/entropy_shannon
+                "count_zero"
+        );
+
+        return new HashSet<>(list);
+    }
+
}
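The net effect of the new exclusion set is to shrink the denominator of the reported coverage fraction, so coverage is measured only against ops actually expected to have import tests. A minimal standalone sketch of that arithmetic, with made-up counts rather than values from a real run:

    // Hypothetical counts, for illustration only
    int tfOpsWithImportTests = 150;
    int totalTFMappedOps = 300;
    int tfImportIgnored = 10;
    String frac = String.format("%.2f",
            100.0 * tfOpsWithImportTests / (double)(totalTFMappedOps - tfImportIgnored));
    // frac is "51.72", vs "50.00" if the 10 ignored ops stayed in the denominator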
@@ -183,7 +183,8 @@ public SameDiff importGraph(GRAPH_TYPE tfGraph) {
//map the names of the nodes while accumulating the vertex ids
//for each variable
for(Map.Entry<String,TENSOR_TYPE> entry : variablesForGraph.entrySet()) {
-            if(dataTypeForTensor(entry.getValue()) == DataBuffer.Type.UNKNOWN) {
+            DataBuffer.Type dt = dataTypeForTensor(entry.getValue());
+            if(dt == DataBuffer.Type.UNKNOWN && !unknownTypeNodeImportable(entry.getValue())) {
val var = importState.getSameDiff().var(entry.getKey(),null,new ZeroInitScheme('c'));
//mark as place holder for validating resolution later.
if(isPlaceHolder(entry.getValue())) {
@@ -208,6 +208,13 @@ public interface GraphMapper<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE,TENSOR_TYPE> {
*/
DataBuffer.Type dataTypeForTensor(TENSOR_TYPE tensorType);

+    /**
+     * If {@link #dataTypeForTensor(Object)} returns UNKNOWN we *might* still be able
+     * to import it. This method returns true if the tensor is importable in spite of the unknown type
+     * @param tensor the tensor to check
+     * @return true if the tensor can be imported despite its unknown data type
+     */
+    boolean unknownTypeNodeImportable(TENSOR_TYPE tensor);

/**
*
@@ -419,6 +419,11 @@ public DataBuffer.Type dataTypeForTensor( onnx.OnnxProto3.TypeProto.Tensor tenso
return nd4jTypeFromOnnxType(tensorProto.getElemType());
}

+    @Override
+    public boolean unknownTypeNodeImportable(OnnxProto3.TypeProto.Tensor tensor) {
+        return false;
+    }


/**
* Convert an onnx type to the proper nd4j type
@@ -395,7 +395,8 @@ public String getNodeName(String name) {
@Override
public Map<String, NodeDef> variablesForGraph(GraphDef graphDef) {
Map<String,NodeDef> ret = new LinkedHashMap<>();
-        for(NodeDef nodeDef : graphDef.getNodeList()) {
+        List<NodeDef> nodeList = graphDef.getNodeList();
+        for(NodeDef nodeDef : nodeList) {
if(nodeDef.getName().endsWith("/read")) {
continue;
}
@@ -768,6 +769,19 @@ public DataBuffer.Type dataTypeForTensor(NodeDef tensorProto) {
}
}

+    @Override
+    public boolean unknownTypeNodeImportable(NodeDef tensorProto) {
+        DataType dt = null;
+        if(tensorProto.containsAttr("dtype")){
+            dt = tensorProto.getAttrOrThrow("dtype").getType();
+        } else if(tensorProto.containsAttr("T")){
+            dt = tensorProto.getAttrOrThrow("T").getType();
+        } else if(tensorProto.containsAttr("Tidx")){
+            dt = tensorProto.getAttrOrThrow("Tidx").getType();
+        }
+
+        return dt == DataType.DT_BOOL;
+    }


@Override
@@ -961,22 +975,22 @@ public INDArray mapTensorProto(TensorProto tfTensor) {
} else if (tfTensor.getDtype() == DataType.DT_INT64) {
if (tfTensor.getInt64ValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) {
//straight zero case
-                if(tfTensor.getDoubleValCount() < 1)
+                if (tfTensor.getDoubleValCount() < 1)
return Nd4j.trueScalar(0.0);

double val = (double) tfTensor.getInt64Val(0);
INDArray array = Nd4j.trueScalar(val);
return array;
            } else if (tfTensor.getInt64ValCount() > 0) {
double[] jArray = new double[tfTensor.getInt64ValCount()];
for (int e = 0; e < tfTensor.getInt64ValCount(); e++) {
                    jArray[e] = (double) tfTensor.getInt64Val(e);
}

// TF arrays are always C
INDArray array = Nd4j.create(jArray, arrayShape, 0, 'c');
return array;
-            } else if (tfTensor.getTensorContent().size() > 0){
+            } else if (tfTensor.getTensorContent().size() > 0) {
//throw new UnsupportedOperationException("To be implemented yet");
//Mapping INT bytebuffers should be converted to floating point
val bb = tfTensor.getTensorContent().asReadOnlyByteBuffer();
@@ -997,6 +1011,26 @@ public INDArray mapTensorProto(TensorProto tfTensor) {
//log.debug("Data: {}", Arrays.toString(array.data().asFloat()));
return array;
}
+        } else if (tfTensor.getDtype() == DataType.DT_BOOL){
+            if (tfTensor.getBoolValCount() == 1 || ArrayUtil.prod(arrayShape) == 1){
+                //straight zero case
+                if(tfTensor.getBoolValCount() < 1)
+                    return Nd4j.trueScalar(0.0);
+
+                boolean val = tfTensor.getBoolVal(0);
+                return Nd4j.trueScalar(val ? 1.0 : 0.0);
+            } else if (tfTensor.getBoolValCount() > 0) {
+                float[] jArray = new float[tfTensor.getBoolValCount()];
+                for (int e = 0; e < tfTensor.getBoolValCount(); e++) {
+                    jArray[e] = tfTensor.getBoolVal(e) ? 1.0f : 0.0f;
+                }
+
+                // TF arrays are always C
+                INDArray array = Nd4j.create(jArray, arrayShape, 'c');
+                return array;
+            } else if (tfTensor.getTensorContent().size() > 0) {
+                throw new UnsupportedOperationException("Not yet implemented for DataType.DT_BOOL");
+            }
} else {
throw new UnsupportedOperationException("Unknown dataType found: [" + tfTensor.getDtype() + "]");
}
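Together with unknownTypeNodeImportable(...) above, the new DT_BOOL branch means boolean TF tensors are no longer dropped as UNKNOWN; their values arrive as 1.0/0.0 arrays. A rough sketch of the result for a hypothetical 2x2 boolean constant holding {true, false, true, true} (values and shape made up for illustration):

    // What the DT_BOOL branch effectively produces for {true, false, true, true}, shape [2, 2]
    float[] jArray = new float[]{1.0f, 0.0f, 1.0f, 1.0f};
    INDArray array = Nd4j.create(jArray, new int[]{2, 2}, 'c');   // C order, matching TF's layout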
@@ -28,6 +28,7 @@
import org.nd4j.linalg.api.ops.aggregates.Batch;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.rng.Random;
+import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
@@ -559,29 +560,32 @@ public void profilingHookOut(CustomOp op, long timeStart) {
* @param op
*/
public static void validateDataType(DataBuffer.Type expectedType, Op op) {
-        if (op.x() != null && op.x().data().dataType() == DataBuffer.Type.COMPRESSED) {
+        if (op.x() != null && !Shape.isEmpty(op.x().shapeInfoJava()) && op.x().data().dataType() == DataBuffer.Type.COMPRESSED) {
Nd4j.getCompressor().decompressi(op.x());
}

-        if (op.y() != null && op.y().data().dataType() == DataBuffer.Type.COMPRESSED) {
+        if (op.y() != null && !Shape.isEmpty(op.y().shapeInfoJava()) && op.y().data().dataType() == DataBuffer.Type.COMPRESSED) {
Nd4j.getCompressor().decompressi(op.y());
}

-        if (op.z() != null && op.z().data().dataType() == DataBuffer.Type.COMPRESSED) {
+        if (op.z() != null && !Shape.isEmpty(op.z().shapeInfoJava()) && op.z().data().dataType() == DataBuffer.Type.COMPRESSED) {
Nd4j.getCompressor().decompressi(op.z());
}

-        if (op.x() != null && op.x().data().dataType() != expectedType
-                && op.x().data().dataType() != DataBuffer.Type.COMPRESSED)
+        if (op.x() != null && !Shape.isEmpty(op.x().shapeInfoJava())
+                && op.x().data().dataType() != expectedType
+                && op.x().data().dataType() != DataBuffer.Type.COMPRESSED)
throw new ND4JIllegalStateException("op.X dataType is [" + op.x().data().dataType()
+ "] instead of expected [" + expectedType + "]");

-        if (op.z() != null && op.z().data().dataType() != expectedType
+        if (op.z() != null && !Shape.isEmpty(op.z().shapeInfoJava())
+                && op.z().data().dataType() != expectedType
&& op.z().data().dataType() != DataBuffer.Type.COMPRESSED)
throw new ND4JIllegalStateException("op.Z dataType is [" + op.z().data().dataType()
+ "] instead of expected [" + expectedType + "]");

-        if (op.y() != null && op.y().data().dataType() != expectedType)
+        if (op.y() != null && !Shape.isEmpty(op.y().shapeInfoJava())
+                && op.y().data().dataType() != expectedType)
throw new ND4JIllegalStateException("op.Y dataType is [" + op.y().data().dataType()
+ "] instead of expected [" + expectedType + "]");

@@ -0,0 +1,36 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.nd4j.linalg.api.ops.impl.accum;
+
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+/**
+ * ArgMin function
+ *
+ * @author Alex Black
+ */
+public class ArgMin extends DynamicCustomOp {
+    @Override
+    public String opName() {
+        return "argmin";
+    }
+
+    @Override
+    public String tensorflowName() {
+        return "ArgMin";
+    }
+}
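These two overrides are the whole import hook: opName() names the libnd4j custom op, and tensorflowName() is the key under which import resolves TF ArgMin nodes. A toy sketch of that lookup; the HashMap here is a stand-in for nd4j's internal tensorflowName-to-op registry, not the actual import code:

    // Simplified stand-in for the tensorflowName -> op registry consulted during TF import
    Map<String, DynamicCustomOp> tfNameToOp = new HashMap<>();
    ArgMin argMin = new ArgMin();
    tfNameToOp.put(argMin.tensorflowName(), argMin);      // registers under "ArgMin"
    DynamicCustomOp mapped = tfNameToOp.get("ArgMin");    // what a TF ArgMin NodeDef resolves to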
@@ -92,7 +92,7 @@ public String onnxName() {

@Override
public String tensorflowName() {
return "ReduceMin";
return "Min";
}


@@ -175,6 +175,8 @@ public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, A
sameDiff.addPropertyToResolve(this,arg.getVarName());
}
}
+        iArguments.clear();
+        addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()), ArrayUtil.fromBoolean(mt.isTransposeB()));
}
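Clearing iArguments before re-adding the transpose flags ensures the op's integer arguments reflect exactly the transpose-a/transpose-b attributes just read from the TF node, rather than accumulating on top of whatever the constructor installed; ArrayUtil.fromBoolean simply encodes each boolean as 0 or 1.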

@Override
@@ -18,6 +18,7 @@

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseIndexAccumulation;

@@ -80,7 +81,7 @@ public String onnxName() {

@Override
public String tensorflowName() {
return "argmax";
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}

@Override
@@ -18,6 +18,7 @@

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseIndexAccumulation;

@@ -84,7 +85,7 @@ public String onnxName() {

@Override
public String tensorflowName() {
return "argmin";
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
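With IMax and IMin no longer claiming the "argmax"/"argmin" TF names, import of those nodes presumably routes through the new DynamicCustomOp classes instead: ArgMin is added in this commit, and an analogous ArgMax counterpart is assumed among the other changed files.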


