|
9 | 9 | using System.Runtime.InteropServices;
|
10 | 10 | using System.Text.RegularExpressions;
|
11 | 11 | using Google.Protobuf;
|
12 |
| -using Google.Protobuf.WellKnownTypes; |
13 | 12 | using Microsoft.ML.Data;
|
14 | 13 | using Microsoft.ML.EntryPoints;
|
15 | 14 | using Microsoft.ML.Model.OnnxConverter;
|
@@ -125,6 +124,9 @@ private class BreastCancerCatFeatureExample
|
125 | 124 |
|
126 | 125 | [LoadColumn(2)]
|
127 | 126 | public string F2;
|
| 127 | + |
| 128 | + [LoadColumn(3, 7), VectorType(6)] |
| 129 | + public string[] F3; |
128 | 130 | }
|
129 | 131 |
|
130 | 132 | private class BreastCancerMulticlassExample
|
@@ -1162,6 +1164,37 @@ public void PcaOnnxConversionTest()
|
1162 | 1164 | Done();
|
1163 | 1165 | }
|
1164 | 1166 |
|
| 1167 | + [Fact] |
| 1168 | + public void OneHotHashEncodingOnnxConversionTest() |
| 1169 | + { |
| 1170 | + var mlContext = new MLContext(); |
| 1171 | + string dataPath = GetDataPath("breast-cancer.txt"); |
| 1172 | + |
| 1173 | + var dataView = ML.Data.LoadFromTextFile<BreastCancerCatFeatureExample>(dataPath); |
| 1174 | + var pipe = ML.Transforms.Categorical.OneHotHashEncoding(new[]{ |
| 1175 | + new OneHotHashEncodingEstimator.ColumnOptions("Output", "F3", useOrderedHashing:false), |
| 1176 | + }); |
| 1177 | + var model = pipe.Fit(dataView); |
| 1178 | + var transformedData = model.Transform(dataView); |
| 1179 | + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView); |
| 1180 | + |
| 1181 | + var onnxFileName = "OneHotHashEncoding.onnx"; |
| 1182 | + var onnxModelPath = GetOutputPath(onnxFileName); |
| 1183 | + SaveOnnxModel(onnxModel, onnxModelPath, null); |
| 1184 | + |
| 1185 | + if (IsOnnxRuntimeSupported()) |
| 1186 | + { |
| 1187 | + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. |
| 1188 | + string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); |
| 1189 | + string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); |
| 1190 | + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); |
| 1191 | + var onnxTransformer = onnxEstimator.Fit(dataView); |
| 1192 | + var onnxResult = onnxTransformer.Transform(dataView); |
| 1193 | + CompareSelectedColumns<float>("Output", "Output", transformedData, onnxResult); |
| 1194 | + } |
| 1195 | + Done(); |
| 1196 | + } |
| 1197 | + |
1165 | 1198 | [Theory]
|
1166 | 1199 | [CombinatorialData]
|
1167 | 1200 | // Due to lack of Onnxruntime support, long/ulong, double, floats, and OrderedHashing are not supported.
|
|
0 commit comments