diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
index 6183bb06c7..a6542c56a8 100644
--- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
+++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
@@ -60,6 +60,12 @@ internal abstract class OnnxContext
/// IDataView column to stop tracking
public abstract void RemoveVariable(string variableName, bool removeColumn);
+ ///
+ /// Removes a variable from the input columns list. This function is used only by the ColumnSelectingTransformer.
+ ///
+ /// ONNX variable to remove.
+ public abstract void RemoveInputVariable(string variableName);
+
///
/// ONNX variables are referred to by name. At each stage of a ML.NET pipeline, the corresponding
/// 's column names will map to a variable in the ONNX graph if the intermediate steps
diff --git a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs
index 346025ca94..df3665fa90 100644
--- a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs
+++ b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs
@@ -9,6 +9,7 @@
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
@@ -43,6 +44,7 @@ namespace Microsoft.ML.Transforms
/// | -- | -- |
/// | Does this estimator need to look at the data to train its parameters? | No |
/// | Input columns data type | Any |
+ /// | Exportable to ONNX | Yes |
///
/// The resulting
/// operates on the schema of a given by dropping or keeping selected columns from the schema.
@@ -520,7 +522,7 @@ private sealed class Mapper
{
private readonly IHost _host;
private readonly DataViewSchema _inputSchema;
- private readonly int[] _outputToInputMap;
+ public readonly int[] OutputToInputMap;
public DataViewSchema InputSchema => _inputSchema;
@@ -531,17 +533,17 @@ public Mapper(ColumnSelectingTransformer transform, DataViewSchema inputSchema)
_host = transform._host.Register(nameof(Mapper));
_inputSchema = inputSchema;
- _outputToInputMap = BuildOutputToInputMap(transform.SelectColumns,
+ OutputToInputMap = BuildOutputToInputMap(transform.SelectColumns,
transform.KeepColumns,
transform.KeepHidden,
_inputSchema);
- OutputSchema = GenerateOutputSchema(_outputToInputMap, _inputSchema);
+ OutputSchema = GenerateOutputSchema(OutputToInputMap, _inputSchema);
}
public int GetInputIndex(int outputIndex)
{
- _host.Assert(0 <= outputIndex && outputIndex < _outputToInputMap.Length);
- return _outputToInputMap[outputIndex];
+ _host.Assert(0 <= outputIndex && outputIndex < OutputToInputMap.Length);
+ return OutputToInputMap[outputIndex];
}
private static int[] BuildOutputToInputMap(IEnumerable selectedColumns,
@@ -648,7 +650,7 @@ public override ValueGetter GetGetter(DataViewSchema.Column colu
public override bool IsColumnActive(DataViewSchema.Column column) => true;
}
- private sealed class SelectColumnsDataTransform : IDataTransform, IRowToRowMapper, ITransformTemplate
+ private sealed class SelectColumnsDataTransform : IDataTransform, IRowToRowMapper, ITransformTemplate, ITransformCanSaveOnnx
{
private readonly IHost _host;
private readonly ColumnSelectingTransformer _transform;
@@ -725,6 +727,31 @@ DataViewRow IRowToRowMapper.GetRow(DataViewRow input, IEnumerable new SelectColumnsDataTransform(env, _transform, new Mapper(_transform, newSource.Schema), newSource);
+
+ public bool CanSaveOnnx(OnnxContext ctx) => true;
+
+ public void SaveAsOnnx(OnnxContext ctx)
+ {
+ var droppedCols = new HashSet(Enumerable.Range(0, InputSchema.Count));
+
+ var outputToInputMap = _mapper.OutputToInputMap;
+ for(int i = 0; i < outputToInputMap.Length; i++)
+ {
+ var srcCol = InputSchema[outputToInputMap[i]];
+ var dstCol = OutputSchema[i];
+ var srcVariable = ctx.GetVariableName(srcCol.Name);
+ var dstVariable = ctx.AddIntermediateVariable(dstCol.Type, dstCol.Name, true);
+ string opType = "Identity";
+ ctx.CreateNode(opType, srcVariable, dstVariable, ctx.GetNodeName(opType), "");
+
+ droppedCols.Remove(srcCol.Index);
+ }
+
+ foreach (var srcCol in droppedCols)
+ {
+ ctx.RemoveInputVariable(InputSchema[srcCol].Name);
+ }
+ }
}
private sealed class Cursor : SynchronizedCursorBase
diff --git a/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs b/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs
index 8105a81126..41e05a7053 100644
--- a/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs
+++ b/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs
@@ -247,6 +247,15 @@ public void AddInputVariable(DataViewType type, string colName)
_inputs.Add(OnnxUtils.GetModelArgs(type, colName));
}
+ public override void RemoveInputVariable(string colName)
+ {
+ var variableName = TryGetVariableName(colName);
+ _host.CheckValue(variableName, nameof(variableName));
+
+ RemoveVariable(variableName, true);
+ _inputs.Remove(_inputs.Single(modelArg => modelArg.Name == variableName));
+ }
+
///
/// Retrieve the shape of an ONNX variable. Returns null if no shape for the specified variable can be found.
///
diff --git a/test/BaselineOutput/Common/Onnx/Transforms/SelectColumns.txt b/test/BaselineOutput/Common/Onnx/Transforms/SelectColumns.txt
new file mode 100644
index 0000000000..636d1f4da9
--- /dev/null
+++ b/test/BaselineOutput/Common/Onnx/Transforms/SelectColumns.txt
@@ -0,0 +1,248 @@
+{
+ "irVersion": "6",
+ "producerName": "ML.NET",
+ "producerVersion": "##VERSION##",
+ "domain": "machinelearning.dotnet",
+ "graph": {
+ "node": [
+ {
+ "input": [
+ "Size"
+ ],
+ "output": [
+ "Size0"
+ ],
+ "name": "Identity",
+ "opType": "Identity"
+ },
+ {
+ "input": [
+ "Shape"
+ ],
+ "output": [
+ "Shape0"
+ ],
+ "name": "Identity0",
+ "opType": "Identity"
+ },
+ {
+ "input": [
+ "Thickness"
+ ],
+ "output": [
+ "Thickness0"
+ ],
+ "name": "Identity1",
+ "opType": "Identity"
+ },
+ {
+ "input": [
+ "Label"
+ ],
+ "output": [
+ "Label0"
+ ],
+ "name": "Identity2",
+ "opType": "Identity"
+ },
+ {
+ "input": [
+ "Size0"
+ ],
+ "output": [
+ "Size1"
+ ],
+ "name": "Identity3",
+ "opType": "Identity"
+ },
+ {
+ "input": [
+ "Shape0"
+ ],
+ "output": [
+ "Shape1"
+ ],
+ "name": "Identity4",
+ "opType": "Identity"
+ },
+ {
+ "input": [
+ "Thickness0"
+ ],
+ "output": [
+ "Thickness1"
+ ],
+ "name": "Identity5",
+ "opType": "Identity"
+ },
+ {
+ "input": [
+ "Label0"
+ ],
+ "output": [
+ "Label1"
+ ],
+ "name": "Identity6",
+ "opType": "Identity"
+ }
+ ],
+ "name": "model",
+ "input": [
+ {
+ "name": "Label",
+ "type": {
+ "tensorType": {
+ "elemType": 9,
+ "shape": {
+ "dim": [
+ {
+ "dimValue": "1"
+ },
+ {
+ "dimValue": "1"
+ }
+ ]
+ }
+ }
+ }
+ },
+ {
+ "name": "Thickness",
+ "type": {
+ "tensorType": {
+ "elemType": 6,
+ "shape": {
+ "dim": [
+ {
+ "dimValue": "1"
+ },
+ {
+ "dimValue": "1"
+ }
+ ]
+ }
+ }
+ }
+ },
+ {
+ "name": "Size",
+ "type": {
+ "tensorType": {
+ "elemType": 6,
+ "shape": {
+ "dim": [
+ {
+ "dimValue": "1"
+ },
+ {
+ "dimValue": "1"
+ }
+ ]
+ }
+ }
+ }
+ },
+ {
+ "name": "Shape",
+ "type": {
+ "tensorType": {
+ "elemType": 6,
+ "shape": {
+ "dim": [
+ {
+ "dimValue": "1"
+ },
+ {
+ "dimValue": "1"
+ }
+ ]
+ }
+ }
+ }
+ }
+ ],
+ "output": [
+ {
+ "name": "Size1",
+ "type": {
+ "tensorType": {
+ "elemType": 6,
+ "shape": {
+ "dim": [
+ {
+ "dimValue": "1"
+ },
+ {
+ "dimValue": "1"
+ }
+ ]
+ }
+ }
+ }
+ },
+ {
+ "name": "Shape1",
+ "type": {
+ "tensorType": {
+ "elemType": 6,
+ "shape": {
+ "dim": [
+ {
+ "dimValue": "1"
+ },
+ {
+ "dimValue": "1"
+ }
+ ]
+ }
+ }
+ }
+ },
+ {
+ "name": "Thickness1",
+ "type": {
+ "tensorType": {
+ "elemType": 6,
+ "shape": {
+ "dim": [
+ {
+ "dimValue": "1"
+ },
+ {
+ "dimValue": "1"
+ }
+ ]
+ }
+ }
+ }
+ },
+ {
+ "name": "Label1",
+ "type": {
+ "tensorType": {
+ "elemType": 9,
+ "shape": {
+ "dim": [
+ {
+ "dimValue": "1"
+ },
+ {
+ "dimValue": "1"
+ }
+ ]
+ }
+ }
+ }
+ }
+ ]
+ },
+ "opsetImport": [
+ {
+ "domain": "ai.onnx.ml",
+ "version": "2"
+ },
+ {
+ "version": "11"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
index 0ab47d8a27..1bdadba595 100644
--- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs
+++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
@@ -1307,6 +1307,68 @@ public void FeatureSelectionOnnxTest()
}
+
+ [Fact]
+ public void SelectColumnsOnnxTest()
+ {
+ var mlContext = new MLContext(seed: 1);
+
+ string dataPath = GetDataPath("breast-cancer.txt");
+
+ var dataView = ML.Data.LoadFromTextFile(dataPath, new[] {
+ new TextLoader.Column("Label", DataKind.Boolean, 0),
+ new TextLoader.Column("Thickness", DataKind.Int32, 1),
+ new TextLoader.Column("Size", DataKind.Int32, 2),
+ new TextLoader.Column("Shape", DataKind.Int32, 3),
+ new TextLoader.Column("Adhesion", DataKind.Int32, 4),
+ new TextLoader.Column("EpithelialSize", DataKind.Int32, 5),
+ new TextLoader.Column("BlandChromatin", DataKind.Int32, 7),
+ new TextLoader.Column("NormalNucleoli", DataKind.Int32, 8),
+ new TextLoader.Column("Mitoses", DataKind.Int32, 9),
+ });
+
+ var pipeline = mlContext.Transforms.SelectColumns(new[] { "Size", "Shape", "Thickness", "Label" });
+
+ var model = pipeline.Fit(dataView);
+ var transformedData = model.Transform(dataView);
+ var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);
+
+ var onnxFileName = "selectcolumns.onnx";
+ var onnxModelPath = GetOutputPath(onnxFileName);
+
+ SaveOnnxModel(onnxModel, onnxModelPath, null);
+
+ if (IsOnnxRuntimeSupported())
+ {
+ // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
+ string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
+ string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
+ var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
+ var onnxTransformer = onnxEstimator.Fit(dataView);
+ var onnxResult = onnxTransformer.Transform(dataView);
+
+ // Verify that onnx output has only the four columns we selected from the input
+ Assert.Equal(4, outputNames.Length);
+ Assert.Equal("Size1", outputNames[0]);
+ Assert.Equal("Shape1", outputNames[1]);
+ Assert.Equal("Thickness1", outputNames[2]);
+ Assert.Equal("Label1", outputNames[3]);
+
+ CompareSelectedScalarColumns("Size", "Size1", transformedData, onnxResult);
+ CompareSelectedScalarColumns("Shape", "Shape1", transformedData, onnxResult);
+ CompareSelectedScalarColumns("Thickness", "Thickness1", transformedData, onnxResult);
+ CompareSelectedScalarColumns("Label", "Label1", transformedData, onnxResult);
+ }
+
+ onnxFileName = "SelectColumns.txt";
+ var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "Transforms");
+ var onnxTextModelPath = GetOutputPath(subDir, onnxFileName);
+ SaveOnnxModel(onnxModel, null, onnxTextModelPath);
+ CheckEquality(subDir, onnxFileName, digitsOfPrecision: 1);
+
+ Done();
+ }
+
private void CompareResults(string leftColumnName, string rightColumnName, IDataView left, IDataView right)
{
var leftColumn = left.Schema[leftColumnName];