diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs index a6db9f0051..2720cff798 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs @@ -11,6 +11,7 @@ using Microsoft.ML.CommandLine; using Microsoft.ML.Data; using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model.OnnxConverter; using Microsoft.ML.Model.Pfa; using Microsoft.ML.Runtime; using Microsoft.ML.Transforms; @@ -152,7 +153,7 @@ private protected override void SaveModel(ModelSaveContext ctx) private protected override IRowMapper MakeRowMapper(DataViewSchema inputSchema) => new Mapper(this, inputSchema); - private sealed class Mapper : OneToOneMapperBase, ISaveAsPfa + private sealed class Mapper : OneToOneMapperBase, ISaveAsPfa, ISaveAsOnnx { private readonly KeyToValueMappingTransformer _parent; private readonly DataViewType[] _types; @@ -298,6 +299,8 @@ protected KeyToValueMap(Mapper mapper, PrimitiveDataViewType typeVal, int iinfo) public abstract Delegate GetMappingGetter(DataViewRow input); public abstract JToken SavePfa(BoundPfaContext ctx, JToken srcToken); + + public abstract bool SaveOnnx(OnnxContext ctx, string srcVariableName, string dstVariableName); } private class KeyToValueMap : KeyToValueMap @@ -494,8 +497,65 @@ public override JToken SavePfa(BoundPfaContext ctx, JToken srcToken) } return PfaUtils.If(PfaUtils.Call("<", srcToken, 0), defaultToken, PfaUtils.Index(cellRef, srcToken)); } + + public override bool SaveOnnx(OnnxContext ctx, string srcVariableName, string dstVariableName) + { + string opType; + + // Onnx expects the input keys to be int64s. But the input data can come from an ML.NET node that + // may output a uint32. So cast it here to ensure that the data is treated correctly + opType = "Cast"; + var castNodeOutput = ctx.AddIntermediateVariable(TypeOutput, "CastNodeOutput", true); + var castNode = ctx.CreateNode(opType, srcVariableName, castNodeOutput, ctx.GetNodeName(opType), ""); + var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Int64).ToType(); + castNode.AddAttribute("to", t); + + opType = "LabelEncoder"; + var node = ctx.CreateNode(opType, castNodeOutput, dstVariableName, ctx.GetNodeName(opType)); + var keys = Array.ConvertAll(Enumerable.Range(1, _values.Length).ToArray(), item => Convert.ToInt64(item)); + node.AddAttribute("keys_int64s", keys); + + if (TypeOutput == NumberDataViewType.Int64) + { + long[] values = Array.ConvertAll(_values.GetValues().ToArray(), item => Convert.ToInt64(item)); + node.AddAttribute("values_int64s", values); + } + else if (TypeOutput == NumberDataViewType.Single) + { + float[] values = Array.ConvertAll(_values.GetValues().ToArray(), item => Convert.ToSingle(item)); + node.AddAttribute("values_floats", values); + } + else if (TypeOutput == TextDataViewType.Instance) + { + string[] values = Array.ConvertAll(_values.GetValues().ToArray(), item => Convert.ToString(item)); + node.AddAttribute("values_strings", values); + } + else + return false; + + return true; + } } + public bool CanSaveOnnx(OnnxContext ctx) => true; + + public void SaveAsOnnx(OnnxContext ctx) + { + for (int iinfo = 0; iinfo < _parent.ColumnPairs.Length; ++iinfo) + { + var info = _parent.ColumnPairs[iinfo]; + var inputColumnName = info.inputColumnName; + + if (!ctx.ContainsColumn(inputColumnName)) + continue; + + var dstVariableName = ctx.AddIntermediateVariable(_types[iinfo], info.outputColumnName, true); + if (!_kvMaps[iinfo].SaveOnnx(ctx, inputColumnName, dstVariableName)) + { + ctx.RemoveColumn(inputColumnName, true); + } + } + } } } diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 1ef88c3889..17a04d7fc7 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -1040,6 +1040,42 @@ public void OptionalColumnOnnxTest() Done(); } + [Fact] + private void KeyToValueOnnxConversionTest() + { + var mlContext = new MLContext(seed: 1); + + string dataPath = GetDataPath("breast-cancer.txt"); + var dataView = mlContext.Data.LoadFromTextFile(dataPath, + separatorChar: '\t', + hasHeader: true); + + var pipeline = mlContext.Transforms.Conversion.MapValueToKey("LabelKey", "Label"). + Append(mlContext.Transforms.Conversion.MapKeyToValue("LabelValue", "LabelKey")); + + var model = pipeline.Fit(dataView); + var transformedData = model.Transform(dataView); + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView); + + var onnxFileName = "KeyToValue.onnx"; + var onnxModelPath = GetOutputPath(onnxFileName); + + SaveOnnxModel(onnxModel, onnxModelPath, null); + + if (IsOnnxRuntimeSupported()) + { + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. + string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); + var onnxTransformer = onnxEstimator.Fit(dataView); + var onnxResult = onnxTransformer.Transform(dataView); + CompareSelectedScalarColumns>(transformedData.Schema[3].Name, outputNames[3], transformedData, onnxResult); + } + + Done(); + } + private void CreateDummyExamplesToMakeComplierHappy() { var dummyExample = new BreastCancerFeatureVector() { Features = null }; @@ -1105,6 +1141,34 @@ private void CompareSelectedVectorColumns(string leftColumnName, string right } } + private void CompareSelectedScalarColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right) + { + var leftColumn = left.Schema[leftColumnName]; + var rightColumn = right.Schema[rightColumnName]; + + using (var expectedCursor = left.GetRowCursor(leftColumn)) + using (var actualCursor = right.GetRowCursor(rightColumn)) + { + T expected = default; + VBuffer actual = default; + var expectedGetter = expectedCursor.GetGetter(leftColumn); + var actualGetter = actualCursor.GetGetter>(rightColumn); + while (expectedCursor.MoveNext() && actualCursor.MoveNext()) + { + expectedGetter(ref expected); + actualGetter(ref actual); + var actualVal = actual.GetItemOrDefault(0); + + Assert.Equal(1, actual.Length); + + if (typeof(T) == typeof(ReadOnlyMemory)) + Assert.Equal(expected.ToString(), actualVal.ToString()); + else + Assert.Equal(expected, actualVal); + } + } + } + private void CompareSelectedR8VectorColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6) { var leftColumn = left.Schema[leftColumnName];