diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
index 6183bb06c7..a6542c56a8 100644
--- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
+++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
@@ -60,6 +60,12 @@ internal abstract class OnnxContext
         /// IDataView column to stop tracking
         public abstract void RemoveVariable(string variableName, bool removeColumn);
 
+        /// <summary>
+        /// Removes the graph input that belongs to a column and stops tracking that column. Used only by the ColumnSelectingTransformer.
+        /// </summary>
+        /// <param name="variableName">Name of the IDataView column to drop; implementations resolve it to the ONNX variable themselves.</param>
+        public abstract void RemoveInputVariable(string variableName);
+
         ///
         /// ONNX variables are referred to by name. At each stage of a ML.NET pipeline, the corresponding
         /// 's column names will map to a variable in the ONNX graph if the intermediate steps
diff --git a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs
index 346025ca94..df3665fa90 100644
--- a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs
+++ b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs
@@ -9,6 +9,7 @@
 using Microsoft.ML.CommandLine;
 using Microsoft.ML.Data;
 using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Model.OnnxConverter;
 using Microsoft.ML.Runtime;
 using Microsoft.ML.Transforms;
 
@@ -43,6 +44,7 @@ namespace Microsoft.ML.Transforms
     /// | -- | -- |
     /// | Does this estimator need to look at the data to train its parameters? | No |
     /// | Input columns data type | Any |
+    /// | Exportable to ONNX | Yes |
     ///
     /// The resulting
     /// operates on the schema of a given by dropping or keeping selected columns from the schema.
@@ -520,7 +522,7 @@ private sealed class Mapper
         {
             private readonly IHost _host;
             private readonly DataViewSchema _inputSchema;
-            private readonly int[] _outputToInputMap;
+            public readonly int[] OutputToInputMap;
 
             public DataViewSchema InputSchema => _inputSchema;
 
@@ -531,17 +533,17 @@ public Mapper(ColumnSelectingTransformer transform, DataViewSchema inputSchema)
                 _host = transform._host.Register(nameof(Mapper));
                 _inputSchema = inputSchema;
 
-                _outputToInputMap = BuildOutputToInputMap(transform.SelectColumns,
+                OutputToInputMap = BuildOutputToInputMap(transform.SelectColumns,
                     transform.KeepColumns,
                     transform.KeepHidden,
                     _inputSchema);
-                OutputSchema = GenerateOutputSchema(_outputToInputMap, _inputSchema);
+                OutputSchema = GenerateOutputSchema(OutputToInputMap, _inputSchema);
             }
 
             public int GetInputIndex(int outputIndex)
             {
-                _host.Assert(0 <= outputIndex && outputIndex < _outputToInputMap.Length);
-                return _outputToInputMap[outputIndex];
+                _host.Assert(0 <= outputIndex && outputIndex < OutputToInputMap.Length);
+                return OutputToInputMap[outputIndex];
             }
 
             private static int[] BuildOutputToInputMap(IEnumerable selectedColumns,
@@ -648,7 +650,7 @@ public override ValueGetter GetGetter(DataViewSchema.Column colu
             public override bool IsColumnActive(DataViewSchema.Column column) => true;
         }
 
-        private sealed class SelectColumnsDataTransform : IDataTransform, IRowToRowMapper, ITransformTemplate
+        private sealed class SelectColumnsDataTransform : IDataTransform, IRowToRowMapper, ITransformTemplate, ITransformCanSaveOnnx
         {
             private readonly IHost _host;
             private readonly ColumnSelectingTransformer _transform;
@@ -725,6 +727,45 @@ DataViewRow IRowToRowMapper.GetRow(DataViewRow input, IEnumerable
                 new SelectColumnsDataTransform(env, _transform, new Mapper(_transform, newSource.Schema), newSource);
+
+            public bool CanSaveOnnx(OnnxContext ctx) => true;
+
+            /// <summary>
+            /// Emits this column selection into <paramref name="ctx"/>: one "Identity" node per output
+            /// column, then a <see cref="OnnxContext.RemoveInputVariable"/> call for every input column
+            /// name that no output column reads from.
+            /// </summary>
+            /// <param name="ctx">ONNX context that receives the nodes and the input removals.</param>
+            public void SaveAsOnnx(OnnxContext ctx)
+            {
+                const string opType = "Identity";
+                var outputToInputMap = _mapper.OutputToInputMap;
+
+                // Names of the input columns that feed at least one output column.
+                var keptNames = new HashSet<string>();
+                for (int i = 0; i < outputToInputMap.Length; i++)
+                {
+                    var srcCol = InputSchema[outputToInputMap[i]];
+                    var dstCol = OutputSchema[i];
+                    var srcVariable = ctx.GetVariableName(srcCol.Name);
+                    var dstVariable = ctx.AddIntermediateVariable(dstCol.Type, dstCol.Name, true);
+                    ctx.CreateNode(opType, srcVariable, dstVariable, ctx.GetNodeName(opType), "");
+
+                    keptNames.Add(srcCol.Name);
+                }
+
+                // RemoveInputVariable is keyed by column name, so decide by name rather than by index:
+                // skip names that are still in use and never ask for the same name twice.
+                var removedNames = new HashSet<string>();
+                for (int i = 0; i < InputSchema.Count; i++)
+                {
+                    var name = InputSchema[i].Name;
+                    if (!keptNames.Contains(name) && removedNames.Add(name))
+                    {
+                        ctx.RemoveInputVariable(name);
+                    }
+                }
+            }
         }
 
         private sealed class Cursor : SynchronizedCursorBase
diff --git a/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs b/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs
index 8105a81126..41e05a7053 100644
--- a/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs
+++ b/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs
@@ -247,6 +247,23 @@ public void AddInputVariable(DataViewType type, string colName)
             _inputs.Add(OnnxUtils.GetModelArgs(type, colName));
         }
 
+        /// <summary>
+        /// Removes the graph input bound to the column <paramref name="colName"/> and stops tracking
+        /// that column's variable.
+        /// </summary>
+        /// <param name="colName">Column name whose current variable is looked up and removed.</param>
+        public override void RemoveInputVariable(string colName)
+        {
+            var variableName = TryGetVariableName(colName);
+            _host.CheckValue(variableName, nameof(colName));
+
+            // Locate the input entry before mutating anything, so that a variable that is not a graph
+            // input fails here and leaves the context untouched.
+            var input = _inputs.Single(modelArg => modelArg.Name == variableName);
+            RemoveVariable(variableName, true);
+            _inputs.Remove(input);
+        }
+
         ///
         /// Retrieve the shape of an ONNX variable. Returns null if no shape for the specified variable can be found.
///
diff --git a/test/BaselineOutput/Common/Onnx/Transforms/SelectColumns.txt b/test/BaselineOutput/Common/Onnx/Transforms/SelectColumns.txt
new file mode 100644
index 0000000000..636d1f4da9
--- /dev/null
+++ b/test/BaselineOutput/Common/Onnx/Transforms/SelectColumns.txt
@@ -0,0 +1,248 @@
+{
+  "irVersion": "6",
+  "producerName": "ML.NET",
+  "producerVersion": "##VERSION##",
+  "domain": "machinelearning.dotnet",
+  "graph": {
+    "node": [
+      {
+        "input": [
+          "Size"
+        ],
+        "output": [
+          "Size0"
+        ],
+        "name": "Identity",
+        "opType": "Identity"
+      },
+      {
+        "input": [
+          "Shape"
+        ],
+        "output": [
+          "Shape0"
+        ],
+        "name": "Identity0",
+        "opType": "Identity"
+      },
+      {
+        "input": [
+          "Thickness"
+        ],
+        "output": [
+          "Thickness0"
+        ],
+        "name": "Identity1",
+        "opType": "Identity"
+      },
+      {
+        "input": [
+          "Label"
+        ],
+        "output": [
+          "Label0"
+        ],
+        "name": "Identity2",
+        "opType": "Identity"
+      },
+      {
+        "input": [
+          "Size0"
+        ],
+        "output": [
+          "Size1"
+        ],
+        "name": "Identity3",
+        "opType": "Identity"
+      },
+      {
+        "input": [
+          "Shape0"
+        ],
+        "output": [
+          "Shape1"
+        ],
+        "name": "Identity4",
+        "opType": "Identity"
+      },
+      {
+        "input": [
+          "Thickness0"
+        ],
+        "output": [
+          "Thickness1"
+        ],
+        "name": "Identity5",
+        "opType": "Identity"
+      },
+      {
+        "input": [
+          "Label0"
+        ],
+        "output": [
+          "Label1"
+        ],
+        "name": "Identity6",
+        "opType": "Identity"
+      }
+    ],
+    "name": "model",
+    "input": [
+      {
+        "name": "Label",
+        "type": {
+          "tensorType": {
+            "elemType": 9,
+            "shape": {
+              "dim": [
+                {
+                  "dimValue": "1"
+                },
+                {
+                  "dimValue": "1"
+                }
+              ]
+            }
+          }
+        }
+      },
+      {
+        "name": "Thickness",
+        "type": {
+          "tensorType": {
+            "elemType": 6,
+            "shape": {
+              "dim": [
+                {
+                  "dimValue": "1"
+                },
+                {
+                  "dimValue": "1"
+                }
+              ]
+            }
+          }
+        }
+      },
+      {
+        "name": "Size",
+        "type": {
+          "tensorType": {
+            "elemType": 6,
+            "shape": {
+              "dim": [
+                {
+                  "dimValue": "1"
+                },
+                {
+                  "dimValue": "1"
+                }
+              ]
+            }
+          }
+        }
+      },
+      {
+        "name": "Shape",
+        "type": {
+          "tensorType": {
+            "elemType": 6,
+            "shape": {
+              "dim": [
+                {
+                  "dimValue": "1"
+                },
+                {
+                  "dimValue": "1"
+                }
+              ]
+            }
+          }
+        }
+      }
+    ],
+    "output": [
+      {
+        "name": "Size1",
+        "type": {
+          "tensorType": {
+            "elemType": 6,
+            "shape": {
+              "dim": [
+                {
+                  "dimValue": "1"
+                },
+                {
+                  "dimValue": "1"
+                }
+              ]
+            }
+          }
+        }
+      },
+      {
+        "name": "Shape1",
+        "type": {
+          "tensorType": {
+            "elemType": 6,
+            "shape": {
+              "dim": [
+                {
+                  "dimValue": "1"
+                },
+                {
+                  "dimValue": "1"
+                }
+              ]
+            }
+          }
+        }
+      },
+      {
+        "name": "Thickness1",
+        "type": {
+          "tensorType": {
+            "elemType": 6,
+            "shape": {
+              "dim": [
+                {
+                  "dimValue": "1"
+                },
+                {
+                  "dimValue": "1"
+                }
+              ]
+            }
+          }
+        }
+      },
+      {
+        "name": "Label1",
+        "type": {
+          "tensorType": {
+            "elemType": 9,
+            "shape": {
+              "dim": [
+                {
+                  "dimValue": "1"
+                },
+                {
+                  "dimValue": "1"
+                }
+              ]
+            }
+          }
+        }
+      }
+    ]
+  },
+  "opsetImport": [
+    {
+      "domain": "ai.onnx.ml",
+      "version": "2"
+    },
+    {
+      "version": "11"
+    }
+  ]
+}
\ No newline at end of file
diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
index 0ab47d8a27..1bdadba595 100644
--- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs
+++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
@@ -1307,6 +1307,74 @@ public void FeatureSelectionOnnxTest()
         }
 
+
+        [Fact]
+        public void SelectColumnsOnnxTest()
+        {
+            var mlContext = new MLContext(seed: 1);
+
+            string dataPath = GetDataPath("breast-cancer.txt");
+
+            // Load through the same context that builds and exports the pipeline.
+            var dataView = mlContext.Data.LoadFromTextFile(dataPath, new[] {
+                new TextLoader.Column("Label", DataKind.Boolean, 0),
+                new TextLoader.Column("Thickness", DataKind.Int32, 1),
+                new TextLoader.Column("Size", DataKind.Int32, 2),
+                new TextLoader.Column("Shape", DataKind.Int32, 3),
+                new TextLoader.Column("Adhesion", DataKind.Int32, 4),
+                new TextLoader.Column("EpithelialSize", DataKind.Int32, 5),
+                new TextLoader.Column("BlandChromatin", DataKind.Int32, 7),
+                new TextLoader.Column("NormalNucleoli", DataKind.Int32, 8),
+                new TextLoader.Column("Mitoses", DataKind.Int32, 9),
+            });
+
+            var pipeline = mlContext.Transforms.SelectColumns(new[] { "Size", "Shape", "Thickness", "Label" });
+
+            var model = pipeline.Fit(dataView);
+            var transformedData = model.Transform(dataView);
+            var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);
+
+            var onnxFileName = "selectcolumns.onnx";
+            var onnxModelPath = GetOutputPath(onnxFileName);
+
+            SaveOnnxModel(onnxModel, onnxModelPath, null);
+
+            // The graph signature depends only on the exported model, so it is checked on every platform.
+            string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
+            string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
+
+            // Columns that were not selected must no longer be graph inputs.
+            Assert.Equal(new[] { "Label", "Thickness", "Size", "Shape" }, inputNames);
+
+            // Verify that onnx output has only the four columns we selected from the input
+            Assert.Equal(4, outputNames.Length);
+            Assert.Equal("Size1", outputNames[0]);
+            Assert.Equal("Shape1", outputNames[1]);
+            Assert.Equal("Thickness1", outputNames[2]);
+            Assert.Equal("Label1", outputNames[3]);
+
+            if (IsOnnxRuntimeSupported())
+            {
+                // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
+                var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
+                var onnxTransformer = onnxEstimator.Fit(dataView);
+                var onnxResult = onnxTransformer.Transform(dataView);
+
+                CompareSelectedScalarColumns("Size", "Size1", transformedData, onnxResult);
+                CompareSelectedScalarColumns("Shape", "Shape1", transformedData, onnxResult);
+                CompareSelectedScalarColumns("Thickness", "Thickness1", transformedData, onnxResult);
+                CompareSelectedScalarColumns("Label", "Label1", transformedData, onnxResult);
+            }
+
+            onnxFileName = "SelectColumns.txt";
+            var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "Transforms");
+            var onnxTextModelPath = GetOutputPath(subDir, onnxFileName);
+            SaveOnnxModel(onnxModel, null, onnxTextModelPath);
+            CheckEquality(subDir, onnxFileName, digitsOfPrecision: 1);
+
+            Done();
+        }
+
         private void CompareResults(string leftColumnName, string rightColumnName, IDataView left, IDataView right)
         {
var leftColumn = left.Schema[leftColumnName];