Skip to content

Commit

Permalink
Make test_BERT_Squad test work
Browse files Browse the repository at this point in the history
Provide a workaround for keras_prelu_ImageNet_small
  • Loading branch information
yuslepukhin committed Apr 4, 2023
1 parent 9fb3811 commit 6587309
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ static NamedOnnxValue LoadTensorPb(Onnx.TensorProto tensor, string nodeName, Nod
if (!((protoDt == metaElementType) ||
(protoDt == TensorElementType.UInt16 &&
(metaElementType == TensorElementType.BFloat16 || metaElementType == TensorElementType.Float16))))
throw new InvalidDataException($"Loaded tensor type: {protoDt} is expected to be equal to: {metaElementType}");
throw new InvalidDataException($"For node: '{nodeName}' metadata expects: '{metaElementType}' but loaded loaded tensor type: '{protoDt}'");

// Tensors within Sequences may have no dimensions as the standard allows
// different dimensions for each tensor element of the sequence
Expand Down Expand Up @@ -124,9 +124,10 @@ internal static NamedOnnxValue CreateNamedOnnxValueFromTensorRawData(string node
return CreateNamedOnnxValueFromRawData<Float16>(nodeName, rawData, sizeof(ushort), intDims);
case TensorElementType.BFloat16:
return CreateNamedOnnxValueFromRawData<BFloat16>(nodeName, rawData, sizeof(ushort), intDims);
case TensorElementType.String:
throw new ArgumentException("For string tensors of type use: CreateNamedOnnxValueFromStringTensor.");
default:
throw new InvalidDataException($"Tensors of type: " + elementType.ToString() +
" not currently supported here, use: CreateNamedOnnxValueFromStringTensor.");
throw new NotImplementedException($"Tensors of type: {elementType} not currently supported by this function");
}
}

Expand Down Expand Up @@ -175,7 +176,7 @@ internal static NamedOnnxValue LoadOnnxValueFromFilePb(string fullFilename, stri
return CreateNamedOnnxValueFromOptional(opt, nodeName, nodeMeta);
}
default:
throw new ArgumentException($"Unable to load value type {nodeMeta.OnnxValueType} not implemented");
throw new NotImplementedException($"Unable to load value type: {nodeMeta.OnnxValueType} not implemented");
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,7 @@ private void TestTensorRTProviderOptions()
// { "test_bidaf", "Does not run in opset9, runs in other opsets. The model runs but I don't have a data set to debug output locally. Tensors of type ElementType not currently supported in the LoadTensorFromFile." },
{ "test_mnist", "Does not run in opset9, runs in other opsets. The model runs but I don't have a data set to debug output locally. Tensors of type ElementType not currently supported in the LoadTensorFromFile" },
{ "BERT_Squad", "Could not find an implementation for the nodeMeta bert / embeddings / one_hot:OneHot(9)" },
{ "test_BERT_Squad", "Test tensor data element type does not match metadata: Int64 is expected to be equal to: Float" },
{ "keras_prelu_ImageNet_small", "Unable to match file: input_1.pb to input/output metadata"},

{ "mlperf_ssd_mobilenet_300", "Could not find file output_0.pb" },
{ "tf_resnet_v1_50", "result mismatch when Conv BN Fusion is applied" },
{ "tf_resnet_v1_101", "result mismatch when Conv BN Fusion is applied" },
Expand All @@ -274,19 +273,11 @@ private void TestTensorRTProviderOptions()

{ "test_mul_uint8", "Could not find an implementation for Mul(14) node with name" },

{ "test_clip_default_int8_inbounds", "nodeMeta test error"},
{ "test_eyelike_with_dtype", "nodeMeta test error"},
{ "test_cast_STRING_to_FLOAT", "nodeMeta test error"},
{ "test_cast_FLOAT_to_DOUBLE", "nodeMeta test error"},
{ "test_cast_BFLOAT16_to_FLOAT", "nodeMeta test error"},
{ "test_cast_FLOAT_to_BFLOAT16", "nodeMeta test error"},
{ "test_cast_STRING_to_FLOAT", "Output mismatch"},
{ "test_cast_BFLOAT16_to_FLOAT", "Output mismatch"},
{ "test_cast_FLOAT_to_STRING", "Output strings can not be compared exactly"},
{ "test_castlike_STRING_to_FLOAT", "nodeMeta test error"},
{ "test_castlike_STRING_to_FLOAT_expanded", "nodeMeta test error"},
{ "test_castlike_FLOAT16_to_DOUBLE", "nodeMeta test error"},
{ "test_castlike_FLOAT16_to_DOUBLE_expanded", "nodeMeta test error"},
{ "test_castlike_FLOAT_to_DOUBLE", "nodeMeta test error"},
{ "test_castlike_FLOAT_to_DOUBLE_expanded", "nodeMeta test error"},
{ "test_castlike_STRING_to_FLOAT", "Output mismatch"},
{ "test_castlike_STRING_to_FLOAT_expanded", "Output mismatch"},
{ "test_castlike_BFLOAT16_to_FLOAT", "Length is expected to be equal to Count (metadata and expected data mismatch) "},
{ "test_castlike_BFLOAT16_to_FLOAT_expanded", "Length is expected to be equal to Count metadata and expected data mismatch"},
{ "test_castlike_FLOAT_to_BFLOAT16", "Length is expected to be equal to Count. Testdata dims length do not match that of model metadata"},
Expand Down Expand Up @@ -448,13 +439,14 @@ public static IEnumerable<object[]> GetSkippedModelForTest()
}
}

string MatchInputOutputWithFile(string fileName, InferenceSession session, bool input, out NodeMetadata result)
private string MatchInputOutputWithFile(string fileName, InferenceSession session, bool input, out NodeMetadata result)
{
string nodeName = string.Empty;
result = null;
var names = (input) ? session.InputNames : session.OutputNames;
var metadata = (input) ? session.InputMetadata : session.OutputMetadata;
string regEx = (input) ? @"input_(\d{1,}).pb" : @"output_(\d{1,}).pb";
var inpOut = (input) ? "input" : "output";

// Extract the number from the file name, if not try to match the input/output name with the name of the file.
try
Expand All @@ -469,7 +461,7 @@ string MatchInputOutputWithFile(string fileName, InferenceSession session, bool
}
else
{
throw new InvalidDataException($"Filename '{fileName}' input/output number '{num}' is out of range for '{names.Count}' inputs/outputs");
throw new InvalidDataException($"Filename '{fileName}' {inpOut} number '{num}' is out of range for '{names.Count}' {inpOut}(s)");
}
}
catch (Exception)
Expand All @@ -489,6 +481,73 @@ string MatchInputOutputWithFile(string fileName, InferenceSession session, bool
return nodeName;
}

// The numbering of the input files does not match the order of outputs
// listed in the metadata of test_BERT_Squad. Model metadata order:
// "unique_ids_raw_output___9:0", "segment_ids:0", "input_mask:0", "input_ids:0"
// The corr input files are: input_0.pb, input_3.pb, input_2.pb, input_1.pb
// Everything in reverse, but the 0.

// Previously, it worked because our test data has matching
// tensor names that we could match to metadata after we load the tensor.
// But now, we need to know ahead of time what Onnx type we load, and thus match
// metadata with the test data file before loading. Protobuf can happily load whatever
// and give you garbage.

private string MatchBertSquadInputs(string fileName)
{
string nodeName = string.Empty;
switch (fileName)
{
case "input_0.pb":
nodeName = "unique_ids_raw_output___9:0";
break;
case "input_1.pb":
nodeName = "input_ids:0";
break;
case "input_2.pb":
nodeName = "input_mask:0";
break;
case "input_3.pb":
nodeName = "segment_ids:0";
break;
default:
throw new InvalidDataException($"Unhandled input file name: '{fileName}' for test_BERT_Squad");
}
return nodeName;
}

// The model actually has only 3 outputs, but the Zoo version has 4 files are supplied.
// The numbering of the output files does not match the order of outputs
// listed in the metadata.

// Previously, it worked because our CI test data version has matching
// tensor names that we could match to metadata after we load the tensor.
// But now, we need to know ahead of time what Onnx type we load, and thus match
// metadata with the test data file before loading. Protobuf can happily load whatever
// and give you garbage.

// Order in the metadata: unstack:1, unstack:0, unique_ids:0
// The files are in reverse order
private string MatchBertSquadOutputs(string fileName)
{
string nodeName = string.Empty;
switch (fileName)
{
case "output_0.pb": // Int64
nodeName = "unique_ids:0";
break;
case "output_1.pb":
nodeName = "unstack:0";
break;
case "output_2.pb":
nodeName = "unstack:1";
break;
default:
throw new InvalidDataException($"Unhandled output file name: '{fileName}' for test_BERT_Squad");
}
return nodeName;
}

[Theory(DisplayName = "TestPreTrainedModels")]
[MemberData(nameof(GetModelsForTest))]
[MemberData(nameof(GetSkippedModelForTest), Skip = "Skipped due to Error, please fix the error and enable the test")]
Expand Down Expand Up @@ -542,13 +601,48 @@ private void TestPreTrainedModels(string opsetDir, string modelName)
var outputContainer = new List<NamedOnnxValue>(outMeta.Count);
foreach (var f in testDataDir.EnumerateFiles("input_*.pb"))
{
var nodeName = MatchInputOutputWithFile(f.Name, session, true, out NodeMetadata nodeMeta);
inputContainer.Add(TestDataLoader.LoadOnnxValueFromFilePb(f.FullName, nodeName, nodeMeta));
if (modelName == "keras_prelu_ImageNet_small" && opset == "opset9")
{
// The model has 1 input, match all file names (they are different in each data set)
// to the same input
var nodeName = "p_re_lu_3_input";
var nodeMeta = inMeta[nodeName];
inputContainer.Add(TestDataLoader.LoadOnnxValueFromFilePb(f.FullName, nodeName, nodeMeta));
}
else if (modelName == "test_BERT_Squad" && opset == "opset8")
{
string nodeName = MatchBertSquadInputs(f.Name);
var nodeMeta = inMeta[nodeName];
inputContainer.Add(TestDataLoader.LoadOnnxValueFromFilePb(f.FullName, nodeName, nodeMeta));
}
else
{
var nodeName = MatchInputOutputWithFile(f.Name, session, true, out NodeMetadata nodeMeta);
inputContainer.Add(TestDataLoader.LoadOnnxValueFromFilePb(f.FullName, nodeName, nodeMeta));
}
}
foreach (var f in testDataDir.EnumerateFiles("output_*.pb"))
{
var nodeName = MatchInputOutputWithFile(f.Name, session, false, out NodeMetadata nodeMeta);
outputContainer.Add(TestDataLoader.LoadOnnxValueFromFilePb(f.FullName, nodeName, nodeMeta));
if (modelName == "keras_prelu_ImageNet_small" && opset == "opset9")
{
// The model has 1 output, match all file names (they are different in each data set)
// to the same output
var nodeName = "p_re_lu_3/add:0";
var nodeMeta = outMeta[nodeName];
outputContainer.Add(TestDataLoader.LoadOnnxValueFromFilePb(f.FullName, nodeName, nodeMeta));
}
else if (modelName == "test_BERT_Squad" && opset == "opset8")
{
string nodeName = MatchBertSquadOutputs(f.Name);
var nodeMeta = outMeta[nodeName];
outputContainer.Add(TestDataLoader.LoadOnnxValueFromFilePb(f.FullName, nodeName, nodeMeta));
}
else
{
// Otherwise, just match trailing filename number to the input name -> metadata
var nodeName = MatchInputOutputWithFile(f.Name, session, false, out NodeMetadata nodeMeta);
outputContainer.Add(TestDataLoader.LoadOnnxValueFromFilePb(f.FullName, nodeName, nodeMeta));
}
}

using (var resultCollection = session.Run(inputContainer))
Expand Down Expand Up @@ -576,7 +670,7 @@ private void TestPreTrainedModels(string opsetDir, string modelName)

Assert.Equal(outputValue.ValueType, outputMeta.OnnxValueType);

switch(outputValue.ValueType)
switch (outputValue.ValueType)
{
case OnnxValueType.ONNX_TYPE_TENSOR: // Only Dense tensors now
{
Expand Down

0 comments on commit 6587309

Please sign in to comment.