Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Import test improvements + more import mapping #6145

Merged
merged 19 commits into from
Aug 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -506,11 +506,16 @@ public static void logCoverageInformation(boolean logAdequatelyTested, boolean l
int tfOpsWithImportTests = 0;
if(logUntestedTFImport)
log.info(" --- Ops with TF Mapping but No TF Import Tests ---");
for(Map.Entry<String, DifferentialFunction> e : tfOpsMap.entrySet()){
String s = e.getKey();
List<String> tfOpsKeys = new ArrayList<>(tfOpsMap.keySet());
Collections.sort(tfOpsKeys);
Set<String> tfIgnored = excludeFromTfImportCoverage();
int tfImportIgnored = 0;
for(String s : tfOpsKeys){
Integer count = tfMappedOpsImportTestCounts.get(s);
if(count == null || count == 0){
if(logUntestedTFImport)
if(tfIgnored.contains(s)){
tfImportIgnored++;
} else if(logUntestedTFImport)
log.info("TF mapped op with no import tests: {}", s);
} else {
tfOpsWithImportTests++;
Expand Down Expand Up @@ -567,7 +572,7 @@ public static void logCoverageInformation(boolean logAdequatelyTested, boolean l
int countLibnd4jMapped = countTotalLibnd4jOps - nonMappedLibnd4jOps.size();
String fracLibnd4j = String.format("%.2f", 100.0 * (countLibnd4jMapped / (double)countTotalLibnd4jOps));

String fracTFMappedTested = String.format("%.2f", 100.0 * tfOpsWithImportTests / (double)totalTFMappedOps);
String fracTFMappedTested = String.format("%.2f", 100.0 * tfOpsWithImportTests / (double)(totalTFMappedOps-tfImportIgnored));

log.info("*****************************************************");
log.info("Op Validation: {} of {} classes with adequate tests ({}% coverage)", countAdequate, totalFwd, pc);
Expand All @@ -576,7 +581,7 @@ public static void logCoverageInformation(boolean logAdequatelyTested, boolean l
log.info("({} ops excluded from gradient check coverage)", excludedFromBackpropCoverage.size());
log.info("({} ops excluded from fwd+gradient tests)", excludedFromAllTestCoverage.size());
log.info("TF mapped ops: {} of {} ({}%)", countTfMapped, countTf, fracTfStr);
log.info("SD ops with TF import mapping + test {} of {} ({}%)", tfOpsWithImportTests, totalTFMappedOps, fracTFMappedTested);
log.info("SD ops with TF import mapping + test {} of {} ({}%) - {} ignored for coverage", tfOpsWithImportTests, (totalTFMappedOps-tfImportIgnored), fracTFMappedTested, tfImportIgnored);
log.info("Libnd4j mapped ops: {} of {} ({}%)", countLibnd4jMapped, countTotalLibnd4jOps, fracLibnd4j);
log.info("*****************************************************");
}
Expand Down Expand Up @@ -764,4 +769,35 @@ private static Set<Class> excludedFromGradientCheckCoverage() {
return new HashSet<>(list);
}

/**
* These ops are excluded from TF import test coverage, for various reasons
*/
private static Set<String> excludeFromTfImportCoverage(){
List<String> list = Arrays.asList(
"Reverse", //Can be excluded because "Reverse_v2" is synonym that TF uses with tf.reverse(...); ReverseV2 is also Java op that is synonym for same op
"LogSigmoid", //Not in ops.proto. Have tests for tf.log_sigmoid, but can't test LogSigmoid op directly: tf.log_sigmoid actually just uses "y = -tf.nn.softplus(-x)" - i.e., 3 separate ops :/
"HardSigmoid", //Also implemented as python, NOT a single native op
"SpaceToBatch", //Old name - SpaceToBatchNd is used in practice (inc. for tf.space_to_batch)
"BatchToSpace", //Old name - BatchToSpaceNd is used in practice

//All of the following ops - not available in TF (can't find them) - op mapping is wrong?
//TODO: Check these and remove the import mapping from the Java classes if they are indeed bad
"HardTanh",
"Swish",
"RDiv",
"DivScalar",
"LogX",
"RationalTanh",
"absargmax",
"absargmin",
"entropy_shannon", //This is a thing, but quite different from our op: https://www.tensorflow.org/versions/r1.2/api_docs/python/tf/contrib/bayesflow/entropy/entropy_shannon
"count_zero"



);

return new HashSet<>(list);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ public SameDiff importGraph(GRAPH_TYPE tfGraph) {
//map the names of the nodes while accumulating the vertex ids
//for each variable
for(Map.Entry<String,TENSOR_TYPE> entry : variablesForGraph.entrySet()) {
if(dataTypeForTensor(entry.getValue()) == DataBuffer.Type.UNKNOWN) {
DataBuffer.Type dt = dataTypeForTensor(entry.getValue());
if(dt == DataBuffer.Type.UNKNOWN && !unknownTypeNodeImportable(entry.getValue())) {
val var = importState.getSameDiff().var(entry.getKey(),null,new ZeroInitScheme('c'));
//mark as place holder for validating resolution later.
if(isPlaceHolder(entry.getValue())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,13 @@ public interface GraphMapper<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE,TENSOR_TYPE> {
*/
DataBuffer.Type dataTypeForTensor(TENSOR_TYPE tensorType);

/**
* If {@link #dataTypeForTensor(Object)} return UNKNOWN we *might* still be able
* to import it. This method will return true if it is importable in spite of unknown type
* @param tensor
* @return
*/
boolean unknownTypeNodeImportable(TENSOR_TYPE tensor);

/**
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,11 @@ public DataBuffer.Type dataTypeForTensor( onnx.OnnxProto3.TypeProto.Tensor tenso
return nd4jTypeFromOnnxType(tensorProto.getElemType());
}

@Override
public boolean unknownTypeNodeImportable(OnnxProto3.TypeProto.Tensor tensor) {
return false;
}


/**
* Convert an onnx type to the proper nd4j type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,8 @@ public String getNodeName(String name) {
@Override
public Map<String, NodeDef> variablesForGraph(GraphDef graphDef) {
Map<String,NodeDef> ret = new LinkedHashMap<>();
for(NodeDef nodeDef : graphDef.getNodeList()) {
List<NodeDef> nodeList = graphDef.getNodeList();
for(NodeDef nodeDef : nodeList) {
if(nodeDef.getName().endsWith("/read")) {
continue;
}
Expand Down Expand Up @@ -768,6 +769,19 @@ public DataBuffer.Type dataTypeForTensor(NodeDef tensorProto) {
}
}

@Override
public boolean unknownTypeNodeImportable(NodeDef tensorProto) {
DataType dt = null;
if(tensorProto.containsAttr("dtype")){
dt = tensorProto.getAttrOrThrow("dtype").getType();
} else if(tensorProto.containsAttr("T")){
dt = tensorProto.getAttrOrThrow("T").getType();
} else if(tensorProto.containsAttr("Tidx")){
dt = tensorProto.getAttrOrThrow("Tidx").getType();
}

return dt == DataType.DT_BOOL;
}


@Override
Expand Down Expand Up @@ -961,22 +975,22 @@ public INDArray mapTensorProto(TensorProto tfTensor) {
} else if (tfTensor.getDtype() == DataType.DT_INT64) {
if (tfTensor.getInt64ValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) {
//straight zero case
if(tfTensor.getDoubleValCount() < 1)
if (tfTensor.getDoubleValCount() < 1)
return Nd4j.trueScalar(0.0);

double val = (double) tfTensor.getInt64Val(0);
INDArray array = Nd4j.trueScalar(val);
return array;
} else if (tfTensor.getInt64ValCount() > 0) {
} else if (tfTensor.getInt64ValCount() > 0) {
double[] jArray = new double[tfTensor.getInt64ValCount()];
for (int e = 0; e < tfTensor.getInt64ValCount(); e++) {
jArray[e] = (double) tfTensor.getInt64Val(e);
jArray[e] = (double) tfTensor.getInt64Val(e);
}

// TF arrays are always C
INDArray array = Nd4j.create(jArray, arrayShape, 0, 'c');
return array;
} else if (tfTensor.getTensorContent().size() > 0){
} else if (tfTensor.getTensorContent().size() > 0) {
//throw new UnsupportedOperationException("To be implemented yet");
//Mapping INT bytebuffers should be converted to floating point
val bb = tfTensor.getTensorContent().asReadOnlyByteBuffer();
Expand All @@ -997,6 +1011,26 @@ public INDArray mapTensorProto(TensorProto tfTensor) {
//log.debug("Data: {}", Arrays.toString(array.data().asFloat()));
return array;
}
} else if (tfTensor.getDtype() == DataType.DT_BOOL){
if (tfTensor.getBoolValCount() == 1 || ArrayUtil.prod(arrayShape) == 1){
//straight zero case
if(tfTensor.getBoolValCount() < 1)
return Nd4j.trueScalar(0.0);

boolean val = tfTensor.getBoolVal(0);
return Nd4j.trueScalar(val ? 1.0 : 0.0);
} else if (tfTensor.getBoolValCount() > 0) {
float[] jArray = new float[tfTensor.getBoolValCount()];
for (int e = 0; e < tfTensor.getBoolValCount(); e++) {
jArray[e] = tfTensor.getBoolVal(e) ? 1.0f : 0.0f;
}

// TF arrays are always C
INDArray array = Nd4j.create(jArray, arrayShape, 'c');
return array;
} else if (tfTensor.getTensorContent().size() > 0) {
throw new UnsupportedOperationException("Not yet implemented for DataType.DT_BOOL");
}
} else {
throw new UnsupportedOperationException("Unknown dataType found: [" + tfTensor.getDtype() + "]");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.nd4j.linalg.api.ops.aggregates.Batch;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
Expand Down Expand Up @@ -559,29 +560,32 @@ public void profilingHookOut(CustomOp op, long timeStart) {
* @param op
*/
public static void validateDataType(DataBuffer.Type expectedType, Op op) {
if (op.x() != null && op.x().data().dataType() == DataBuffer.Type.COMPRESSED) {
if (op.x() != null && !Shape.isEmpty(op.x().shapeInfoJava()) && op.x().data().dataType() == DataBuffer.Type.COMPRESSED) {
Nd4j.getCompressor().decompressi(op.x());
}

if (op.y() != null && op.y().data().dataType() == DataBuffer.Type.COMPRESSED) {
if (op.y() != null && !Shape.isEmpty(op.y().shapeInfoJava()) && op.y().data().dataType() == DataBuffer.Type.COMPRESSED) {
Nd4j.getCompressor().decompressi(op.y());
}

if (op.z() != null && op.z().data().dataType() == DataBuffer.Type.COMPRESSED) {
if (op.z() != null && !Shape.isEmpty(op.z().shapeInfoJava()) && op.z().data().dataType() == DataBuffer.Type.COMPRESSED) {
Nd4j.getCompressor().decompressi(op.z());
}

if (op.x() != null && op.x().data().dataType() != expectedType
&& op.x().data().dataType() != DataBuffer.Type.COMPRESSED)
if (op.x() != null && !Shape.isEmpty(op.x().shapeInfoJava())
&& op.x().data().dataType() != expectedType
&& op.x().data().dataType() != DataBuffer.Type.COMPRESSED)
throw new ND4JIllegalStateException("op.X dataType is [" + op.x().data().dataType()
+ "] instead of expected [" + expectedType + "]");

if (op.z() != null && op.z().data().dataType() != expectedType
if (op.z() != null && !Shape.isEmpty(op.z().shapeInfoJava())
&& op.z().data().dataType() != expectedType
&& op.z().data().dataType() != DataBuffer.Type.COMPRESSED)
throw new ND4JIllegalStateException("op.Z dataType is [" + op.z().data().dataType()
+ "] instead of expected [" + expectedType + "]");

if (op.y() != null && op.y().data().dataType() != expectedType)
if (op.y() != null && !Shape.isEmpty(op.y().shapeInfoJava())
&& op.y().data().dataType() != expectedType)
throw new ND4JIllegalStateException("op.Y dataType is [" + op.y().data().dataType()
+ "] instead of expected [" + expectedType + "]");

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.nd4j.linalg.api.ops.impl.accum;

import org.nd4j.linalg.api.ops.DynamicCustomOp;

/**
* ArgMin function
*
* @author Alex Black
*/
public class ArgMin extends DynamicCustomOp {
@Override
public String opName() {
return "argmin";
}

@Override
public String tensorflowName() {
return "ArgMin";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ public String onnxName() {

@Override
public String tensorflowName() {
return "ReduceMin";
return "Min";
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, A
sameDiff.addPropertyToResolve(this,arg.getVarName());
}
}
iArguments.clear();
addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()), ArrayUtil.fromBoolean(mt.isTransposeB()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseIndexAccumulation;

Expand Down Expand Up @@ -80,7 +81,7 @@ public String onnxName() {

@Override
public String tensorflowName() {
return "argmax";
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseIndexAccumulation;

Expand Down Expand Up @@ -84,7 +85,7 @@ public String onnxName() {

@Override
public String tensorflowName() {
return "argmin";
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}


Expand Down
Loading