Skip to content

Commit d154d19

Browse files
committed
adding test - onehothas
1 parent 59dbdea commit d154d19

File tree

1 file changed

+34
-1
lines changed

1 file changed

+34
-1
lines changed

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
using System.Runtime.InteropServices;
1010
using System.Text.RegularExpressions;
1111
using Google.Protobuf;
12-
using Google.Protobuf.WellKnownTypes;
1312
using Microsoft.ML.Data;
1413
using Microsoft.ML.EntryPoints;
1514
using Microsoft.ML.Model.OnnxConverter;
@@ -125,6 +124,9 @@ private class BreastCancerCatFeatureExample
125124

126125
[LoadColumn(2)]
127126
public string F2;
127+
128+
[LoadColumn(3, 7), VectorType(6)]
129+
public string[] F3;
128130
}
129131

130132
private class BreastCancerMulticlassExample
@@ -1162,6 +1164,37 @@ public void PcaOnnxConversionTest()
11621164
Done();
11631165
}
11641166

1167+
[Fact]
1168+
public void OneHotHashEncodingOnnxConversionTest()
1169+
{
1170+
var mlContext = new MLContext();
1171+
string dataPath = GetDataPath("breast-cancer.txt");
1172+
1173+
var dataView = ML.Data.LoadFromTextFile<BreastCancerCatFeatureExample>(dataPath);
1174+
var pipe = ML.Transforms.Categorical.OneHotHashEncoding(new[]{
1175+
new OneHotHashEncodingEstimator.ColumnOptions("Output", "F3", useOrderedHashing:false),
1176+
});
1177+
var model = pipe.Fit(dataView);
1178+
var transformedData = model.Transform(dataView);
1179+
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);
1180+
1181+
var onnxFileName = "OneHotHashEncoding.onnx";
1182+
var onnxModelPath = GetOutputPath(onnxFileName);
1183+
SaveOnnxModel(onnxModel, onnxModelPath, null);
1184+
1185+
if (IsOnnxRuntimeSupported())
1186+
{
1187+
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
1188+
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
1189+
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
1190+
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
1191+
var onnxTransformer = onnxEstimator.Fit(dataView);
1192+
var onnxResult = onnxTransformer.Transform(dataView);
1193+
CompareSelectedColumns<float>("Output", "Output", transformedData, onnxResult);
1194+
}
1195+
Done();
1196+
}
1197+
11651198
[Theory]
11661199
[CombinatorialData]
11671200
// Due to lack of Onnxruntime support, long/ulong, double, floats, and OrderedHashing are not supported.

0 commit comments

Comments
 (0)