From f69e61819c6198f16d44cbe2ade7dbd045b46199 Mon Sep 17 00:00:00 2001 From: "Harish S. Kulkarni" Date: Fri, 7 Feb 2020 18:15:32 -0800 Subject: [PATCH 1/2] - Cleaned up OnnxContext's initializer interface - Cleaned up column comparison functionality on OnnxConversionTest - Fixed bugs in OptionalColumnTransform's onnx export and added support for boolean initializers --- .../Model/Onnx/OnnxContext.cs | 101 ++++++- .../OnnxContextImpl.cs | 71 ++++- src/Microsoft.ML.OnnxConverter/OnnxUtils.cs | 60 +++- src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs | 6 +- .../OptionalColumnTransform.cs | 29 +- test/Microsoft.ML.Tests/OnnxConversionTest.cs | 281 ++++++++---------- 6 files changed, 352 insertions(+), 196 deletions(-) diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs index a6542c56a8..b572bf9d15 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Collections.Generic; using Microsoft.ML.Data; @@ -130,7 +131,16 @@ public OnnxNode CreateNode(string opType, string input, string output, string na public abstract List RetrieveShapeOrNull(string variableName); /// - /// Call this function can declare a global float + /// Call this function to declare a global bool + /// + /// The boolean value which is going to be added + /// A string used as a seed to generate this initializer's name in the ONNX graph. + /// Whether a unique name should be picked for this initializer. + /// The initializer's ONNX name + public abstract string AddInitializer(bool value, string name = null, bool makeUniqueName = true); + + /// + /// Call this function to declare a global float /// /// The float number which is going to be added /// A string used as a seed to generate this initializer's name in the ONNX graph. @@ -139,16 +149,17 @@ public OnnxNode CreateNode(string opType, string input, string output, string na public abstract string AddInitializer(float value, string name = null, bool makeUniqueName = true); /// - /// Call this function can declare a global long + /// Call this function to declare a global float /// - /// The long number which is going to be added into the ONNX graph + /// The float number which is going to be added + /// The type of integer to be added, e.g. typeof(short). Use this for all integer types smaller than Int32 /// A string used as a seed to generate this initializer's name in the ONNX graph. /// Whether a unique name should be picked for this initializer. /// The initializer's ONNX name - public abstract string AddInitializer(long value, string name = null, bool makeUniqueName = true); + public abstract string AddInitializer(int value, Type type, string name = null, bool makeUniqueName = true); /// - /// Call this function can declare a global string + /// Call this function to declare a global string /// /// The string which is going to be added into the ONNX graph /// A string used as a seed to generate this initializer's name in the ONNX graph. @@ -157,43 +168,103 @@ public OnnxNode CreateNode(string opType, string input, string output, string na public abstract string AddInitializer(string value, string name = null, bool makeUniqueName = true); /// - /// Call this function can declare a global float tensor + /// Call this function to declare a global long + /// + /// The long number which is going to be added into the ONNX graph + /// A string used as a seed to generate this initializer's name in the ONNX graph. + /// Whether a unique name should be picked for this initializer. + /// The initializer's ONNX name + public abstract string AddInitializer(long value, string name = null, bool makeUniqueName = true); + + /// + /// Call this function to declare a global double + /// + /// The double number which is going to be added into the ONNX graph + /// A string used as a seed to generate this initializer's name in the ONNX graph. + /// Whether a unique name should be picked for this initializer. + /// The initializer's ONNX name + public abstract string AddInitializer(double value, string name = null, bool makeUniqueName = true); + + /// + /// Call this function to declare a global ulong or uint + /// + /// The long number which is going to be added into the ONNX graph + /// true if value contains a ulong value and false if it contains uint + /// A string used as a seed to generate this initializer's name in the ONNX graph. + /// Whether a unique name should be picked for this initializer. + /// The initializer's ONNX name + public abstract string AddInitializer(ulong value, bool isUint64, string name = null, bool makeUniqueName = true); + + /// + /// Call this function to declare a global bool tensor + /// + /// The boolean values which are going to be added into the ONNX graph + /// The shape of values + /// A string used as a seed to generate this initializer's name in the ONNX graph. + /// Whether a unique name should be picked for this initializer. + /// The initializer's ONNX name + public abstract string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true); + + /// + /// Call this function to declare a global float tensor /// /// The floats which are going to be added into the ONNX graph - /// The shape that the floats + /// The shape of values /// A string used as a seed to generate this initializer's name in the ONNX graph. /// Whether a unique name should be picked for this initializer. /// The initializer's ONNX name public abstract string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true); /// - /// Call this function can declare a global long tensor + /// Call this function to declare a global long tensor + /// + /// The ints which are going to be added into the ONNX graph + /// The type of ints which are going to be added into the ONNX graph, e.g. typeof(short). Use this for adding array initializers of integer types smaller than Int32 + /// The shape of values + /// A string used as a seed to generate this initializer's name in the ONNX graph. + /// Whether a unique name should be picked for this initializer. + /// The initializer's ONNX name + public abstract string AddInitializer(IEnumerable values, Type type, IEnumerable dims, string name = null, bool makeUniqueName = true); + + /// + /// Call this function to declare a global string tensor + /// + /// The strings which are going to be added into the ONNX graph + /// The shape of values + /// A string used as a seed to generate this initializer's name in the ONNX graph. + /// Whether a unique name should be picked for this initializer. + /// The initializer's ONNX name + public abstract string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true); + + /// + /// Call this function to declare a global long tensor /// /// The longs which are going to be added into the ONNX graph - /// The shape that the floats + /// The shape of values /// A string used as a seed to generate this initializer's name in the ONNX graph. /// Whether a unique name should be picked for this initializer. /// The initializer's ONNX name public abstract string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true); /// - /// Call this function can declare a global double tensor + /// Call this function to declare a global double tensor /// /// The doubles which are going to be added into the ONNX graph - /// The shape that the doubles + /// The shape of values /// A string used as a seed to generate this initializer's name in the ONNX graph. /// Whether a unique name should be picked for this initializer. /// The initializer's ONNX name public abstract string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true); /// - /// Call this function can declare a global string tensor + /// Call this function to declare a global double tensor /// - /// The strings which are going to be added into the ONNX graph - /// The shape that the strings + /// The unsigned integers which are going to be added into the ONNX graph + /// Set to true if values contain ulong values false if they contain uint values + /// The shape of values /// A string used as a seed to generate this initializer's name in the ONNX graph. /// Whether a unique name should be picked for this initializer. /// The initializer's ONNX name - public abstract string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true); + public abstract string AddInitializer(IEnumerable values, bool isUint64, IEnumerable dims, string name = null, bool makeUniqueName = true); } } diff --git a/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs b/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs index 41e05a7053..98866529f4 100644 --- a/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs +++ b/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs @@ -279,6 +279,13 @@ public override List RetrieveShapeOrNull(string variableName) } /// Adds constant tensor into the graph. + public override string AddInitializer(bool value, string name = null, bool makeUniqueName = true) + { + name = AddVariable(name ?? "bool", makeUniqueName); + _initializers.Add(OnnxUtils.MakeInt32(name, typeof(bool), value ? 1 : 0)); + return name; + } + public override string AddInitializer(float value, string name = null, bool makeUniqueName = true) { name = AddVariable(name ?? "float", makeUniqueName); @@ -286,6 +293,13 @@ public override string AddInitializer(float value, string name = null, bool make return name; } + public override string AddInitializer(int value, Type type, string name = null, bool makeUniqueName = true) + { + name = AddVariable(name ?? "int32", makeUniqueName); + _initializers.Add(OnnxUtils.MakeInt32(name, type, value)); + return name; + } + public override string AddInitializer(string value, string name = null, bool makeUniqueName = true) { name = AddVariable(name ?? "string", makeUniqueName); @@ -300,6 +314,31 @@ public override string AddInitializer(long value, string name = null, bool makeU return name; } + public override string AddInitializer(double value, string name = null, bool makeUniqueName = true) + { + name = AddVariable(name ?? "double", makeUniqueName); + _initializers.Add(OnnxUtils.MakeDouble(name, value)); + return name; + } + + public override string AddInitializer(ulong value, bool isUint64, string name = null, bool makeUniqueName = true) + { + name = AddVariable(name ?? "uint64", makeUniqueName); + _initializers.Add(OnnxUtils.MakeUInt(name, isUint64, value)); + return name; + } + + public override string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true) + { + _host.CheckValue(values, nameof(values)); + if (dims != null) + _host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size"); + + name = AddVariable(name ?? "bools", makeUniqueName); + _initializers.Add(OnnxUtils.MakeInt32s(name, typeof(bool), values.Select(v => Convert.ToInt32(v)), dims)); + return name; + } + public override string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true) { _host.CheckValue(values, nameof(values)); @@ -311,6 +350,28 @@ public override string AddInitializer(IEnumerable values, IEnumerable values, Type type, IEnumerable dims, string name = null, bool makeUniqueName = true) + { + _host.CheckValue(values, nameof(values)); + if (dims != null) + _host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size"); + + name = AddVariable(name ?? "int32s", makeUniqueName); + _initializers.Add(OnnxUtils.MakeInt32s(name, type, values, dims)); + return name; + } + + public override string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true) + { + _host.CheckValue(values, nameof(values)); + if (dims != null) + _host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size"); + + name = AddVariable(name ?? "strings", makeUniqueName); + _initializers.Add(OnnxUtils.MakeStrings(name, values, dims)); + return name; + } + public override string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true) { _host.CheckValue(values, nameof(values)); @@ -328,19 +389,19 @@ public override string AddInitializer(IEnumerable values, IEnumerable x * y) == values.Count(), "Number of elements doesn't match tensor size"); - name = AddVariable(name ?? "double", makeUniqueName); - _initializers.Add(OnnxUtils.MakeDouble(name, values, dims)); + name = AddVariable(name ?? "doubles", makeUniqueName); + _initializers.Add(OnnxUtils.MakeDoubles(name, values, dims)); return name; } - public override string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true) + public override string AddInitializer(IEnumerable values, bool isUint64, IEnumerable dims, string name = null, bool makeUniqueName = true) { _host.CheckValue(values, nameof(values)); if (dims != null) _host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size"); - name = AddVariable(name ?? "strings", makeUniqueName); - _initializers.Add(OnnxUtils.MakeStrings(name, values, dims)); + name = AddVariable(name ?? "uints", makeUniqueName); + _initializers.Add(OnnxUtils.MakeUInts(name, isUint64, values, dims)); return name; } diff --git a/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs b/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs index d6b21c2005..9eb993b084 100644 --- a/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs +++ b/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs @@ -410,8 +410,66 @@ public static TensorProto MakeInt64s(string name, IEnumerable values, IEnu return tensor; } + // Make int32 and smaller integer types scalar in ONNX from native C# number + public static TensorProto MakeInt32(string name, Type type, int value) + { + var tensor = new TensorProto(); + tensor.Name = name; + tensor.DataType = (int)ConvertToTensorProtoType(type); + tensor.Int32Data.Add(value); + return tensor; + } + + // Make int32 and smaller integer types vector (i.e., 1-D tensor) with dims=null. Otherwise, dims is used as the shape of the produced tensor. + public static TensorProto MakeInt32s(string name, Type type, IEnumerable values, IEnumerable dims = null) + { + var tensor = new TensorProto(); + tensor.Name = name; + tensor.DataType = (int)ConvertToTensorProtoType(type); + tensor.Int32Data.AddRange(values); + if (dims != null) + tensor.Dims.AddRange(dims); + else + tensor.Dims.Add(values.Count()); + return tensor; + } + + // Make ulong and uint integer types scalar in ONNX from native C# number + public static TensorProto MakeUInt(string name, bool isUint64, ulong value) + { + var tensor = new TensorProto(); + tensor.Name = name; + tensor.DataType = (int)ConvertToTensorProtoType(isUint64 ? typeof(ulong) : typeof(uint)); + tensor.Uint64Data.Add(value); + return tensor; + } + + // Make ulong and uint integer vector (i.e., 1-D tensor) with dims=null. Otherwise, dims is used as the shape of the produced tensor. + public static TensorProto MakeUInts(string name, bool isUint64, IEnumerable values, IEnumerable dims = null) + { + var tensor = new TensorProto(); + tensor.Name = name; + tensor.DataType = (int)ConvertToTensorProtoType(isUint64 ? typeof(ulong) : typeof(uint)); + tensor.Uint64Data.AddRange(values); + if (dims != null) + tensor.Dims.AddRange(dims); + else + tensor.Dims.Add(values.Count()); + return tensor; + } + + // Make int32 and smaller integer types scalar in ONNX from native C# number + public static TensorProto MakeDouble(string name, double value) + { + var tensor = new TensorProto(); + tensor.Name = name; + tensor.DataType = (int)TensorProto.Types.DataType.Double; + tensor.DoubleData.Add(value); + return tensor; + } + // Make double vector (i.e., 1-D tensor) with dims=null. Otherwise, dims is used as the shape of the produced tensor. - public static TensorProto MakeDouble(string name, IEnumerable values, IEnumerable dims = null) + public static TensorProto MakeDoubles(string name, IEnumerable values, IEnumerable dims = null) { var tensor = new TensorProto(); tensor.Name = name; diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs index ace64905f9..525a8c46fc 100644 --- a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs +++ b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs @@ -393,7 +393,9 @@ internal sealed class OnnxUtils typeof(UInt32), typeof(UInt64), typeof(ReadOnlyMemory), - typeof(Boolean) + typeof(Boolean), + typeof(SByte), + typeof(Byte) }; private static Dictionary _typeToKindMap= new Dictionary @@ -408,6 +410,8 @@ internal sealed class OnnxUtils { typeof(UInt64) , InternalDataKind.U8}, { typeof(String) , InternalDataKind.TX}, { typeof(Boolean) , InternalDataKind.BL}, + { typeof(SByte) , InternalDataKind.I1}, + { typeof(Byte) , InternalDataKind.U1}, }; /// diff --git a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs index 18909da5ef..5de451ba7d 100644 --- a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs +++ b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs @@ -534,19 +534,28 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, else size = 1; - // REVIEW: - // AddInitializer only supports long, float, double and string. - // Here we are casting ulong to long. Fixing this would involve - // adding additional functions to OnnxContext. - if (type == typeof(float)) + if ((type == typeof(int)) || + (type == typeof(short)) || (type == typeof(ushort)) || + (type == typeof(sbyte)) || (type == typeof(byte))) + ctx.AddInitializer(new int[size], type, new long[] { 1, size }, inputColumnName, false); + else if (type == typeof(uint) || (type == typeof(ulong))) + ctx.AddInitializer(new ulong[size], type == typeof(ulong), new long[] { 1, size }, inputColumnName, false); + else if (type == typeof(bool)) + ctx.AddInitializer(new bool[size], new long[] { 1, size }, inputColumnName, false); + else if (type == typeof(long)) + ctx.AddInitializer(new long[size], new long[] { 1, size }, inputColumnName, false); + else if (type == typeof(float)) ctx.AddInitializer(new float[size], new long[] { 1, size }, inputColumnName, false); else if (type == typeof(double)) ctx.AddInitializer(new double[size], new long[] { 1, size }, inputColumnName, false); - else if ((type == typeof(long)) || (type == typeof(int)) || (type == typeof(short)) || (type == typeof(sbyte)) || - (type == typeof(ulong)) || (type == typeof(uint)) || (type == typeof(ushort)) || (type == typeof(byte))) - ctx.AddInitializer(new long[size], new long[] { 1, size }, inputColumnName, false); - else if (type == typeof(string)) - ctx.AddInitializer(new string[size], new long[] { 1, size }, inputColumnName, false); + else if ((type == typeof(string)) || (columnType is TextDataViewType)) + { + string[] values = new string[size]; + for (int i = 0; i < size; i++) + values[i] = ""; + + ctx.AddInitializer(values, new long[] { 1, size }, inputColumnName, false); + } else return false; diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 75f5b52466..40b8edac2b 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -95,7 +95,7 @@ public void SimpleEndToEndOnnxConversionTest() var onnxResult = onnxTransformer.Transform(data); // Step 4: Compare ONNX and ML.NET results. - CompareSelectedR4ScalarColumns("Score", "Score.onnx", transformedData, onnxResult, 1); + CompareSelectedColumns("Score", "Score.onnx", transformedData, onnxResult, 1); } // Step 5: Check ONNX model's text format. This test will be not necessary if Step 3 and Step 4 can run on Linux and @@ -186,7 +186,7 @@ public void KmeansOnnxConversionTest() var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(data); var onnxResult = onnxTransformer.Transform(data); - CompareSelectedR4VectorColumns("Score", "Score.onnx", transformedData, onnxResult, 3); + CompareSelectedColumns("Score", "Score.onnx", transformedData, onnxResult, 3); } // Check ONNX model's text format. We save the produced ONNX model as a text file and compare it against @@ -241,7 +241,7 @@ public void RegressionTrainersOnnxConversionTest() var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(dataView); var onnxResult = onnxTransformer.Transform(dataView); - CompareSelectedR4ScalarColumns(transformedData.Schema[2].Name, outputNames[2], transformedData, onnxResult, 3); // compare score results + CompareSelectedColumns(transformedData.Schema[2].Name, outputNames[2], transformedData, onnxResult, 3); // compare score results } // Compare the Onnx graph to a baseline if OnnxRuntime is not supported else @@ -302,8 +302,8 @@ public void BinaryClassificationTrainersOnnxConversionTest() var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(dataView); var onnxResult = onnxTransformer.Transform(dataView); - CompareSelectedR4ScalarColumns(transformedData.Schema[5].Name, outputNames[3], transformedData, onnxResult, 3); //compare scores - CompareSelectedScalarColumns(transformedData.Schema[4].Name, outputNames[2], transformedData, onnxResult); //compare predicted labels + CompareSelectedColumns(transformedData.Schema[5].Name, outputNames[3], transformedData, onnxResult, 3); //compare scores + CompareSelectedColumns(transformedData.Schema[4].Name, outputNames[2], transformedData, onnxResult); //compare predicted labels } } Done(); @@ -337,8 +337,8 @@ public void TestVectorWhiteningOnnxConversionTest() var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(dataView); var onnxResult = onnxTransformer.Transform(dataView); - CompareSelectedR4VectorColumns(transformedData.Schema[2].Name, outputNames[2], transformedData, onnxResult); // whitened1 - CompareSelectedR4VectorColumns(transformedData.Schema[3].Name, outputNames[3], transformedData, onnxResult); // whitened2 + CompareSelectedColumns(transformedData.Schema[2].Name, outputNames[2], transformedData, onnxResult); // whitened1 + CompareSelectedColumns(transformedData.Schema[3].Name, outputNames[3], transformedData, onnxResult); // whitened2 } Done(); } @@ -393,9 +393,9 @@ public void PlattCalibratorOnnxConversionTest() var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(dataView); var onnxResult = onnxTransformer.Transform(dataView); - CompareSelectedR4ScalarColumns(transformedData.Schema[5].Name, outputNames[3], transformedData, onnxResult, 3); //compare scores - CompareSelectedScalarColumns(transformedData.Schema[4].Name, outputNames[2], transformedData, onnxResult); //compare predicted labels - CompareSelectedR4ScalarColumns(transformedData.Schema.Last().Name, outputNames.Last(), transformedData, onnxResult, 3); //compare probabilities + CompareSelectedColumns(transformedData.Schema[5].Name, outputNames[3], transformedData, onnxResult, 3); //compare scores + CompareSelectedColumns(transformedData.Schema[4].Name, outputNames[2], transformedData, onnxResult); //compare predicted labels + CompareSelectedColumns(transformedData.Schema.Last().Name, outputNames.Last(), transformedData, onnxResult, 3); //compare probabilities } } Done(); @@ -443,7 +443,7 @@ public void PlattCalibratorOnnxConversionTest2() var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(data); var onnxResult = onnxTransformer.Transform(data); - CompareSelectedR4ScalarColumns(transformedData.Schema.Last().Name, outputNames.Last(), transformedData, onnxResult, 3); //compare probabilities + CompareSelectedColumns(transformedData.Schema.Last().Name, outputNames.Last(), transformedData, onnxResult, 3); //compare probabilities } Done(); } @@ -492,7 +492,7 @@ public void LpNormOnnxConversionTest( var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(dataView); var onnxResult = onnxTransformer.Transform(dataView); - CompareSelectedR4VectorColumns(nameof(DataPoint.Features), outputNames[0], transformedData, onnxResult, 3); + CompareSelectedColumns(nameof(DataPoint.Features), outputNames[0], transformedData, onnxResult, 3); } Done(); @@ -947,7 +947,7 @@ public void TokenizingByCharactersOnnxConversionTest(bool useMarkerCharacters) var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(dataView); var onnxResult = onnxTransformer.Transform(dataView); - CompareSelectedVectorColumns(transformedData.Schema[2].Name, outputNames[2], transformedData, onnxResult); //compare scores + CompareSelectedColumns(transformedData.Schema[2].Name, outputNames[2], transformedData, onnxResult); //compare scores } Done(); } @@ -956,8 +956,18 @@ public void TokenizingByCharactersOnnxConversionTest(bool useMarkerCharacters) // These are the supported conversions // ML.NET does not allow any conversions between signed and unsigned numeric types // Onnx does not seem to support casting a string to any type - // Though the onnx docs claim support for byte and sbyte, - // CreateNamedOnnxValue in OnnxUtils.cs throws a NotImplementedException for those two + [InlineData(DataKind.SByte, DataKind.SByte)] + [InlineData(DataKind.SByte, DataKind.Int16)] + [InlineData(DataKind.SByte, DataKind.Int32)] + [InlineData(DataKind.SByte, DataKind.Int64)] + [InlineData(DataKind.SByte, DataKind.Single)] + [InlineData(DataKind.SByte, DataKind.Double)] + [InlineData(DataKind.Byte, DataKind.Byte)] + [InlineData(DataKind.Byte, DataKind.UInt16)] + [InlineData(DataKind.Byte, DataKind.UInt32)] + [InlineData(DataKind.Byte, DataKind.UInt64)] + [InlineData(DataKind.Byte, DataKind.Single)] + [InlineData(DataKind.Byte, DataKind.Double)] [InlineData(DataKind.Int16, DataKind.Int16)] [InlineData(DataKind.Int16, DataKind.Int32)] [InlineData(DataKind.Int16, DataKind.Int64)] @@ -1053,7 +1063,7 @@ public void PcaOnnxConversionTest() var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(dataView); var onnxResult = onnxTransformer.Transform(dataView); - CompareSelectedR4VectorColumns(model.ColumnPairs[0].outputColumnName, outputNames[2], transformedData, onnxResult); + CompareSelectedColumns(model.ColumnPairs[0].outputColumnName, outputNames[2], transformedData, onnxResult); } } @@ -1114,7 +1124,7 @@ public void IndicateMissingValuesOnnxConversionTest() var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(dataView); var onnxResult = onnxTransformer.Transform(dataView); - CompareSelectedVectorColumns(model.LastTransformer.ColumnPairs[0].outputColumnName, outputNames[1], transformedData, onnxResult); + CompareSelectedColumns(model.LastTransformer.ColumnPairs[0].outputColumnName, outputNames[1], transformedData, onnxResult); } CheckEquality(subDir, onnxTextName, parseOption: NumberParseOption.UseSingle); @@ -1156,7 +1166,7 @@ public void ValueToKeyandKeyToValueMappingOnnxConversionTest(DataKind valueType) var onnxResult = onnxTransformer.Transform(dataView); CompareResults(mlnetResult.Schema[2].Name, outputNames[2], mlnetResult, onnxResult); //compare output values - CompareSelectedVectorColumns(mlnetResult.Schema[1].Name, outputNames[1], mlnetResult, onnxResult); //compare keys + CompareSelectedColumns(mlnetResult.Schema[1].Name, outputNames[1], mlnetResult, onnxResult); //compare keys } Done(); } @@ -1197,7 +1207,7 @@ public void WordTokenizerOnnxConversionTest() var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxFilePath); var onnxTransformer = onnxEstimator.Fit(dataView); var onnxResult = onnxTransformer.Transform(dataView); - CompareSelectedVectorColumns>(transformedData.Schema[1].Name, outputNames[1], transformedData, onnxResult); + CompareSelectedColumns>(transformedData.Schema[1].Name, outputNames[1], transformedData, onnxResult); } Done(); @@ -1261,25 +1271,38 @@ public void NgramOnnxConversionTest( var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxFilePath); var onnxTransformer = onnxEstimator.Fit(dataView); var onnxResult = onnxTransformer.Transform(dataView); - CompareSelectedR4VectorColumns(transformedData.Schema[transformedData.Schema.Count-1].Name, outputNames[outputNames.Length-1], transformedData, onnxResult, 3); //comparing Ngrams + CompareSelectedColumns(transformedData.Schema[transformedData.Schema.Count-1].Name, outputNames[outputNames.Length-1], transformedData, onnxResult, 3); //comparing Ngrams } } Done(); } - [Fact] - public void OptionalColumnOnnxTest() + [Theory] + [InlineData(DataKind.Boolean)] + [InlineData(DataKind.SByte)] + [InlineData(DataKind.Byte)] + [InlineData(DataKind.Int16)] + [InlineData(DataKind.UInt16)] + [InlineData(DataKind.Int32)] + [InlineData(DataKind.UInt32)] + [InlineData(DataKind.Int64)] + [InlineData(DataKind.UInt64)] + [InlineData(DataKind.Single)] + [InlineData(DataKind.Double)] + [InlineData(DataKind.String)] + public void OptionalColumnOnnxTest(DataKind dataKind) { var mlContext = new MLContext(seed: 1); - var samples = new List() - { - new BreastCancerCatFeatureExample() { Label = false, F1 = 0.0f, F2 = "F2"}, - new BreastCancerCatFeatureExample() { Label = true, F1 = 0.1f, F2 = "F2"}, - }; + string dataPath = GetDataPath("breast-cancer.txt"); + + var dataView = ML.Data.LoadFromTextFile(dataPath, new[] { + new TextLoader.Column("Label", dataKind, 0), + new TextLoader.Column("Thickness", DataKind.Single, 1), + }); + IHostEnvironment env = mlContext as IHostEnvironment; - var dataView = mlContext.Data.LoadFromEnumerable(samples); - var args = new OptionalColumnTransform.Arguments { Columns = new[] { "F1" }, Data = dataView }; + var args = new OptionalColumnTransform.Arguments { Columns = new[] { "Label" }, Data = dataView }; var transform = OptionalColumnTransform.MakeOptional(env, args); var ctx = new OnnxContextImpl(mlContext, "model", "ML.NET", "0", 0, "machinelearning.dotnet", OnnxVersion.Stable); @@ -1292,7 +1315,7 @@ public void OptionalColumnOnnxTest() onnxModel = SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, ch, root, sink, transforms, null, null); } - var onnxFileName = "optionalcol.onnx"; + var onnxFileName = $"optionalcol-{dataKind}.onnx"; var onnxModelPath = GetOutputPath(onnxFileName); var onnxTextFileName = "optionalcol.txt"; var onnxTextPath = GetOutputPath(onnxTextFileName); @@ -1305,7 +1328,7 @@ public void OptionalColumnOnnxTest() var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(dataView); var onnxResult = onnxTransformer.Transform(dataView); - CompareSelectedR4ScalarColumns(transform.Model.OutputSchema[2].Name, outputNames[1], outputData, onnxResult); + CompareResults("Label", "Label.onnx", outputData, onnxResult); } Done(); } @@ -1340,7 +1363,7 @@ public void KeyToValueOnnxConversionTest() var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(dataView); var onnxResult = onnxTransformer.Transform(dataView); - CompareSelectedScalarColumns>(transformedData.Schema[3].Name, outputNames[3], transformedData, onnxResult); + CompareSelectedColumns>(transformedData.Schema[3].Name, outputNames[3], transformedData, onnxResult); } Done(); @@ -1410,8 +1433,8 @@ public void MulticlassTrainersOnnxConversionTest() var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(dataView); var onnxResult = onnxTransformer.Transform(dataView); - CompareSelectedScalarColumns(transformedData.Schema[5].Name, outputNames[2], transformedData, onnxResult); //compare predicted labels - CompareSelectedR4VectorColumns(transformedData.Schema[6].Name, outputNames[3], transformedData, onnxResult, 4); //compare scores + CompareSelectedColumns(transformedData.Schema[5].Name, outputNames[2], transformedData, onnxResult); //compare predicted labels + CompareSelectedColumns(transformedData.Schema[6].Name, outputNames[3], transformedData, onnxResult, 4); //compare scores } } Done(); @@ -1445,7 +1468,7 @@ public void CopyColumnsOnnxTest() var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(dataView); var onnxResult = onnxTransformer.Transform(dataView); - CompareSelectedR4ScalarColumns(model.ColumnPairs[0].outputColumnName, outputNames[2], transformedData, onnxResult); + CompareSelectedColumns(model.ColumnPairs[0].outputColumnName, outputNames[2], transformedData, onnxResult); } Done(); } @@ -1493,10 +1516,10 @@ public void FeatureSelectionOnnxTest() var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(dataView); var onnxResult = onnxTransformer.Transform(dataView); - CompareSelectedR4ScalarColumns("FeatureSelectMIScalarFloat", "FeatureSelectMIScalarFloat.onnx", transformedData, onnxResult); - CompareSelectedR4VectorColumns("FeatureSelectMIVectorFloat", "FeatureSelectMIVectorFloat.onnx", transformedData, onnxResult); - CompareSelectedR4ScalarColumns("ScalFeatureSelectMissing690", "ScalFeatureSelectMissing690.onnx", transformedData, onnxResult); - CompareSelectedR8VectorColumns("VecFeatureSelectMissing690", "VecFeatureSelectMissing690.onnx", transformedData, onnxResult); + CompareSelectedColumns("FeatureSelectMIScalarFloat", "FeatureSelectMIScalarFloat.onnx", transformedData, onnxResult); + CompareSelectedColumns("FeatureSelectMIVectorFloat", "FeatureSelectMIVectorFloat.onnx", transformedData, onnxResult); + CompareSelectedColumns("ScalFeatureSelectMissing690", "ScalFeatureSelectMissing690.onnx", transformedData, onnxResult); + CompareSelectedColumns("VecFeatureSelectMissing690", "VecFeatureSelectMissing690.onnx", transformedData, onnxResult); } Done(); } @@ -1549,10 +1572,10 @@ public void SelectColumnsOnnxTest() Assert.Equal("Thickness.onnx", outputNames[2]); Assert.Equal("Label.onnx", outputNames[3]); - CompareSelectedScalarColumns("Size", "Size.onnx", transformedData, onnxResult); - CompareSelectedScalarColumns("Shape", "Shape.onnx", transformedData, onnxResult); - CompareSelectedScalarColumns("Thickness", "Thickness.onnx", transformedData, onnxResult); - CompareSelectedScalarColumns("Label", "Label.onnx", transformedData, onnxResult); + CompareSelectedColumns("Size", "Size.onnx", transformedData, onnxResult); + CompareSelectedColumns("Shape", "Shape.onnx", transformedData, onnxResult); + CompareSelectedColumns("Thickness", "Thickness.onnx", transformedData, onnxResult); + CompareSelectedColumns("Label", "Label.onnx", transformedData, onnxResult); } onnxFileName = "SelectColumns.txt"; @@ -1564,7 +1587,7 @@ public void SelectColumnsOnnxTest() Done(); } - private void CompareResults(string leftColumnName, string rightColumnName, IDataView left, IDataView right) + private void CompareResults(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6) { var leftColumn = left.Schema[leftColumnName]; var rightColumn = right.Schema[rightColumnName]; @@ -1573,57 +1596,33 @@ private void CompareResults(string leftColumnName, string rightColumnName, IData Assert.Equal(leftType, rightType); if (leftType == NumberDataViewType.SByte) - CompareSelectedVectorColumns(leftColumnName, rightColumnName, left, right); + CompareSelectedColumns(leftColumnName, rightColumnName, left, right); else if (leftType == NumberDataViewType.Byte) - CompareSelectedVectorColumns(leftColumnName, rightColumnName, left, right); + CompareSelectedColumns(leftColumnName, rightColumnName, left, right); else if (leftType == NumberDataViewType.Int16) - CompareSelectedVectorColumns(leftColumnName, rightColumnName, left, right); + CompareSelectedColumns(leftColumnName, rightColumnName, left, right); else if (leftType == NumberDataViewType.UInt16) - CompareSelectedVectorColumns(leftColumnName, rightColumnName, left, right); + CompareSelectedColumns(leftColumnName, rightColumnName, left, right); else if (leftType == NumberDataViewType.Int32) - CompareSelectedVectorColumns(leftColumnName, rightColumnName, left, right); + CompareSelectedColumns(leftColumnName, rightColumnName, left, right); else if (leftType == NumberDataViewType.UInt32) - CompareSelectedVectorColumns(leftColumnName, rightColumnName, left, right); + CompareSelectedColumns(leftColumnName, rightColumnName, left, right); else if (leftType == NumberDataViewType.Int64) - CompareSelectedVectorColumns(leftColumnName, rightColumnName, left, right); + CompareSelectedColumns(leftColumnName, rightColumnName, left, right); else if (leftType == NumberDataViewType.UInt64) - CompareSelectedVectorColumns(leftColumnName, rightColumnName, left, right); + CompareSelectedColumns(leftColumnName, rightColumnName, left, right); else if (leftType == NumberDataViewType.Single) - CompareSelectedR4VectorColumns(leftColumnName, rightColumnName, left, right); + CompareSelectedColumns(leftColumnName, rightColumnName, left, right, precision); else if (leftType == NumberDataViewType.Double) - CompareSelectedVectorColumns(leftColumnName, rightColumnName, left, right); + CompareSelectedColumns(leftColumnName, rightColumnName, left, right, precision); + else if (leftType == BooleanDataViewType.Instance) + CompareSelectedColumns(leftColumnName, rightColumnName, left, right); else if (leftType == TextDataViewType.Instance) - CompareSelectedVectorColumns>(leftColumnName, rightColumnName, left, right); - } - - private void CompareSelectedVectorColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right) - { - var leftColumn = left.Schema[leftColumnName]; - var rightColumn = right.Schema[rightColumnName]; - - using (var expectedCursor = left.GetRowCursor(leftColumn)) - using (var actualCursor = right.GetRowCursor(rightColumn)) - { - VBuffer expected = default; - VBuffer actual = default; - var expectedGetter = expectedCursor.GetGetter>(leftColumn); - var actualGetter = actualCursor.GetGetter>(rightColumn); - while (expectedCursor.MoveNext() && actualCursor.MoveNext()) - { - expectedGetter(ref expected); - actualGetter(ref actual); + CompareSelectedColumns>(leftColumnName, rightColumnName, left, right); - Assert.Equal(expected.Length, actual.Length); - for (int i = 0; i < expected.Length; ++i) - if (typeof(T) == typeof(ReadOnlyMemory)) - Assert.Equal(expected.GetItemOrDefault(i).ToString(), actual.GetItemOrDefault(i).ToString()); - else - Assert.Equal(expected.GetItemOrDefault(i), actual.GetItemOrDefault(i)); - } - } } - private void CompareSelectedR8VectorColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6) + private void CompareSelectedColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6) { var leftColumn = left.Schema[leftColumnName]; var rightColumn = right.Schema[rightColumnName]; @@ -1631,103 +1630,57 @@ private void CompareSelectedR8VectorColumns(string leftColumnName, string rightC using (var expectedCursor = left.GetRowCursor(leftColumn)) using (var actualCursor = right.GetRowCursor(rightColumn)) { - VBuffer expected = default; - VBuffer actual = default; - var expectedGetter = expectedCursor.GetGetter>(leftColumn); - var actualGetter = actualCursor.GetGetter>(rightColumn); - while (expectedCursor.MoveNext() && actualCursor.MoveNext()) - { - expectedGetter(ref expected); - actualGetter(ref actual); + T expectedScalar = default; + VBuffer expectedVector = default; - Assert.Equal(expected.Length, actual.Length); - for (int i = 0; i < expected.Length; ++i) - Assert.Equal(expected.GetItemOrDefault(i), actual.GetItemOrDefault(i), precision); - } - } - } + ValueGetter expectedScalarGetter = default; + ValueGetter> expectedVectorGetter = default; - private void CompareSelectedR4VectorColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6) - { - var leftColumn = left.Schema[leftColumnName]; - var rightColumn = right.Schema[rightColumnName]; + VBuffer actual = default; - using (var expectedCursor = left.GetRowCursor(leftColumn)) - using (var actualCursor = right.GetRowCursor(rightColumn)) - { - VBuffer expected = default; - VBuffer actual = default; - var expectedGetter = expectedCursor.GetGetter>(leftColumn); - var actualGetter = actualCursor.GetGetter>(rightColumn); + if (leftColumn.Type is VectorDataViewType) + expectedVectorGetter = expectedCursor.GetGetter>(leftColumn); + else + expectedScalarGetter = expectedCursor.GetGetter(leftColumn); + + var actualGetter = actualCursor.GetGetter>(rightColumn); while (expectedCursor.MoveNext() && actualCursor.MoveNext()) { - expectedGetter(ref expected); actualGetter(ref actual); - Assert.Equal(expected.Length, actual.Length); - for (int i = 0; i < expected.Length; ++i) + if (leftColumn.Type is VectorDataViewType) { - // We are using float values. But the Assert.Equal function takes doubles. - // And sometimes the converted doubles are different in their precision. - // So make sure we compare floats - float exp = expected.GetItemOrDefault(i); - float act = actual.GetItemOrDefault(i); - CompareNumbersWithTolerance(exp, act, null, precision); - } - } - } - } + expectedVectorGetter(ref expectedVector); + Assert.Equal(expectedVector.Length, actual.Length); - private void CompareSelectedR4ScalarColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6) - { - var leftColumn = left.Schema[leftColumnName]; - var rightColumn = right.Schema[rightColumnName]; - - using (var expectedCursor = left.GetRowCursor(leftColumn)) - using (var actualCursor = right.GetRowCursor(rightColumn)) - { - float expected = default; - VBuffer actual = default; - var expectedGetter = expectedCursor.GetGetter(leftColumn); - var actualGetter = actualCursor.GetGetter>(rightColumn); - while (expectedCursor.MoveNext() && actualCursor.MoveNext()) - { - expectedGetter(ref expected); - actualGetter(ref actual); + for (int i = 0; i < expectedVector.Length; ++i) + CompareScalarValues(expectedVector.GetItemOrDefault(i), actual.GetItemOrDefault(i), precision); + } + else + { + expectedScalarGetter(ref expectedScalar); + Assert.Equal(1, actual.Length); - // Scalar such as R4 (float) is converted to [1, 1]-tensor in ONNX format for consitency of making batch prediction. - Assert.Equal(1, actual.Length); - CompareNumbersWithTolerance(expected, actual.GetItemOrDefault(0), null, precision); + var actualVal = actual.GetItemOrDefault(0); + CompareScalarValues(expectedScalar, actualVal, precision); + } } } } - private void CompareSelectedScalarColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right) + private void CompareScalarValues(T expected, T actual, int precision) { - var leftColumn = left.Schema[leftColumnName]; - var rightColumn = right.Schema[rightColumnName]; - - using (var expectedCursor = left.GetRowCursor(leftColumn)) - using (var actualCursor = right.GetRowCursor(rightColumn)) - { - T expected = default; - VBuffer actual = default; - var expectedGetter = expectedCursor.GetGetter(leftColumn); - var actualGetter = actualCursor.GetGetter>(rightColumn); - while (expectedCursor.MoveNext() && actualCursor.MoveNext()) - { - expectedGetter(ref expected); - actualGetter(ref actual); - var actualVal = actual.GetItemOrDefault(0); - - Assert.Equal(1, actual.Length); - - if (typeof(T) == typeof(ReadOnlyMemory)) - Assert.Equal(expected.ToString(), actualVal.ToString()); - else - Assert.Equal(expected, actualVal); - } - } + if (typeof(T) == typeof(ReadOnlyMemory)) + Assert.Equal(expected.ToString(), actual.ToString()); + else if (typeof(T) == typeof(double)) + Assert.Equal(Convert.ToDouble(expected), Convert.ToDouble(actual), precision); + else if (typeof(T) == typeof(float)) + // We are using float values. But the Assert.Equal function takes doubles. + // And sometimes the converted doubles are different in their precision. + // So make sure we compare floats + CompareNumbersWithTolerance(Convert.ToSingle(expected), Convert.ToSingle(actual), null, precision); + else + Assert.Equal(expected, actual); } private void SaveOnnxModel(ModelProto model, string binaryFormatPath, string textFormatPath) From 81f62ba532dad175e6d7fb3a80ba30e9c4f91b24 Mon Sep 17 00:00:00 2001 From: "Harish S. Kulkarni" Date: Mon, 10 Feb 2020 11:50:08 -0800 Subject: [PATCH 2/2] Fixed doc issues pointed out by code review --- .../Model/Onnx/OnnxContext.cs | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs index b572bf9d15..913db98559 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs @@ -131,7 +131,7 @@ public OnnxNode CreateNode(string opType, string input, string output, string na public abstract List RetrieveShapeOrNull(string variableName); /// - /// Call this function to declare a global bool + /// Call this function to declare a global bool scalar /// /// The boolean value which is going to be added /// A string used as a seed to generate this initializer's name in the ONNX graph. @@ -140,7 +140,7 @@ public OnnxNode CreateNode(string opType, string input, string output, string na public abstract string AddInitializer(bool value, string name = null, bool makeUniqueName = true); /// - /// Call this function to declare a global float + /// Call this function to declare a global float scalar /// /// The float number which is going to be added /// A string used as a seed to generate this initializer's name in the ONNX graph. @@ -149,17 +149,17 @@ public OnnxNode CreateNode(string opType, string input, string output, string na public abstract string AddInitializer(float value, string name = null, bool makeUniqueName = true); /// - /// Call this function to declare a global float + /// Call this function to declare a global integer scalar or smaller types /// /// The float number which is going to be added - /// The type of integer to be added, e.g. typeof(short). Use this for all integer types smaller than Int32 + /// The type of integer to be added, e.g. typeof(short). Use this for all integer types Int32 and smaller /// A string used as a seed to generate this initializer's name in the ONNX graph. /// Whether a unique name should be picked for this initializer. /// The initializer's ONNX name public abstract string AddInitializer(int value, Type type, string name = null, bool makeUniqueName = true); /// - /// Call this function to declare a global string + /// Call this function to declare a global string scalar /// /// The string which is going to be added into the ONNX graph /// A string used as a seed to generate this initializer's name in the ONNX graph. @@ -168,7 +168,7 @@ public OnnxNode CreateNode(string opType, string input, string output, string na public abstract string AddInitializer(string value, string name = null, bool makeUniqueName = true); /// - /// Call this function to declare a global long + /// Call this function to declare a global long scalar /// /// The long number which is going to be added into the ONNX graph /// A string used as a seed to generate this initializer's name in the ONNX graph. @@ -177,7 +177,7 @@ public OnnxNode CreateNode(string opType, string input, string output, string na public abstract string AddInitializer(long value, string name = null, bool makeUniqueName = true); /// - /// Call this function to declare a global double + /// Call this function to declare a global double scalar /// /// The double number which is going to be added into the ONNX graph /// A string used as a seed to generate this initializer's name in the ONNX graph. @@ -186,7 +186,7 @@ public OnnxNode CreateNode(string opType, string input, string output, string na public abstract string AddInitializer(double value, string name = null, bool makeUniqueName = true); /// - /// Call this function to declare a global ulong or uint + /// Call this function to declare a global ulong or uint scalar /// /// The long number which is going to be added into the ONNX graph /// true if value contains a ulong value and false if it contains uint @@ -216,7 +216,7 @@ public OnnxNode CreateNode(string opType, string input, string output, string na public abstract string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true); /// - /// Call this function to declare a global long tensor + /// Call this function to declare a global tensor of integer or smaller types /// /// The ints which are going to be added into the ONNX graph /// The type of ints which are going to be added into the ONNX graph, e.g. typeof(short). Use this for adding array initializers of integer types smaller than Int32 @@ -257,7 +257,7 @@ public OnnxNode CreateNode(string opType, string input, string output, string na public abstract string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true); /// - /// Call this function to declare a global double tensor + /// Call this function to declare a global ulong tensor /// /// The unsigned integers which are going to be added into the ONNX graph /// Set to true if values contain ulong values false if they contain uint values