diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt
index d24d7e1c3f..4358fbaad9 100644
--- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt
+++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt
@@ -447,6 +447,16 @@
           }
         ]
       },
+      {
+        "input": [
+          "F3"
+        ],
+        "output": [
+          "F3.output"
+        ],
+        "name": "Identity",
+        "opType": "Identity"
+      },
       {
         "input": [
           "PredictedLabel"
@@ -454,7 +464,7 @@
         "output": [
           "PredictedLabel.output"
         ],
-        "name": "Identity",
+        "name": "Identity0",
         "opType": "Identity"
       },
       {
@@ -464,7 +474,7 @@
         "output": [
           "Score.output"
         ],
-        "name": "Identity0",
+        "name": "Identity1",
         "opType": "Identity"
       },
       {
@@ -474,7 +484,7 @@
         "output": [
           "Probability.output"
         ],
-        "name": "Identity1",
+        "name": "Identity2",
         "opType": "Identity"
       }
     ],
@@ -531,9 +541,45 @@
             }
           }
         }
+      },
+      {
+        "name": "F3",
+        "type": {
+          "tensorType": {
+            "elemType": 8,
+            "shape": {
+              "dim": [
+                {
+                  "dimValue": "-1"
+                },
+                {
+                  "dimValue": "5"
+                }
+              ]
+            }
+          }
+        }
       }
     ],
     "output": [
+      {
+        "name": "F3.output",
+        "type": {
+          "tensorType": {
+            "elemType": 8,
+            "shape": {
+              "dim": [
+                {
+                  "dimValue": "-1"
+                },
+                {
+                  "dimValue": "5"
+                }
+              ]
+            }
+          }
+        }
+      },
       {
         "name": "PredictedLabel.output",
         "type": {
@@ -806,6 +852,24 @@
           }
         }
       },
+      {
+        "name": "F3.output",
+        "type": {
+          "tensorType": {
+            "elemType": 8,
+            "shape": {
+              "dim": [
+                {
+                  "dimValue": "-1"
+                },
+                {
+                  "dimValue": "5"
+                }
+              ]
+            }
+          }
+        }
+      },
       {
         "name": "PredictedLabel.output",
         "type": {
diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
index 5163dd59d5..aca9d09c00 100644
--- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs
+++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
@@ -9,7 +9,6 @@
 using System.Runtime.InteropServices;
 using System.Text.RegularExpressions;
 using Google.Protobuf;
-using Google.Protobuf.WellKnownTypes;
 using Microsoft.ML.Data;
 using Microsoft.ML.EntryPoints;
 using Microsoft.ML.Model.OnnxConverter;
@@ -125,6 +124,9 @@ private class BreastCancerCatFeatureExample
 
             [LoadColumn(2)]
             public string F2;
+
+            [LoadColumn(3, 7), VectorType(6)]
+            public string[] F3;
         }
 
         private class BreastCancerMulticlassExample
@@ -1162,6 +1164,43 @@ public void PcaOnnxConversionTest()
             Done();
         }
 
+        /// <summary>
+        /// Fits a one-hot hash encoding over the "F3" column of the breast-cancer data,
+        /// exports the fitted model to ONNX, and, when ONNX Runtime is supported, checks
+        /// that the ONNX model's "Output" column matches the ML.NET transform's.
+        /// </summary>
+        [Fact]
+        public void OneHotHashEncodingOnnxConversionTest()
+        {
+            var mlContext = new MLContext();
+            string dataPath = GetDataPath("breast-cancer.txt");
+
+            // Typed load: the columns (including the F3 vector) come from BreastCancerCatFeatureExample.
+            var dataView = ML.Data.LoadFromTextFile<BreastCancerCatFeatureExample>(dataPath);
+            var pipe = ML.Transforms.Categorical.OneHotHashEncoding(new[]{
+                new OneHotHashEncodingEstimator.ColumnOptions("Output", "F3", useOrderedHashing:false),
+            });
+            var model = pipe.Fit(dataView);
+            var transformedData = model.Transform(dataView);
+            var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);
+
+            var onnxFileName = "OneHotHashEncoding.onnx";
+            var onnxModelPath = GetOutputPath(onnxFileName);
+            SaveOnnxModel(onnxModel, onnxModelPath, null);
+
+            if (IsOnnxRuntimeSupported())
+            {
+                // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
+                string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
+                string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
+                var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
+                var onnxTransformer = onnxEstimator.Fit(dataView);
+                var onnxResult = onnxTransformer.Transform(dataView);
+                CompareSelectedColumns("Output", "Output", transformedData, onnxResult);
+            }
+            Done();
+        }
+
         [Theory]
         [CombinatorialData]
         // Due to lack of Onnxruntime support, long/ulong, double, floats, and OrderedHashing are not supported.