diff --git a/src/Microsoft.ML.OnnxConverter/SaveOnnxCommand.cs b/src/Microsoft.ML.OnnxConverter/SaveOnnxCommand.cs
index e3d1a8bd8e..57eb574aad 100644
--- a/src/Microsoft.ML.OnnxConverter/SaveOnnxCommand.cs
+++ b/src/Microsoft.ML.OnnxConverter/SaveOnnxCommand.cs
@@ -57,9 +57,20 @@ public sealed class Arguments : DataCommand.ArgumentsBase
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present.", ShortName = "pred", SortOrder = 9)]
public bool? LoadPredictor;
- [Argument(ArgumentType.Required, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 10)]
+ ///
+ /// Entry point API can save either or .
+ /// is used when the saved model is typed to .
+ ///
+ [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 10)]
public TransformModel Model;
+ ///
+ /// Entry point API can save either or .
+ /// is used when the saved model is typed to .
+ ///
+ [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Predictor model that needs to be converted to ONNX format.", SortOrder = 12)]
+ public PredictorModel PredictiveModel;
+
[Argument(ArgumentType.AtMostOnce, HelpText = "The targeted ONNX version. It can be either \"Stable\" or \"Experimental\". If \"Experimental\" is used, produced model can contain components that is not officially supported in ONNX standard.", SortOrder = 11)]
public OnnxVersion OnnxVersion;
}
@@ -72,6 +83,7 @@ public sealed class Arguments : DataCommand.ArgumentsBase
private readonly HashSet _inputsToDrop;
private readonly HashSet _outputsToDrop;
private readonly TransformModel _model;
+ private readonly PredictorModel _predictiveModel;
private const string ProducerName = "ML.NET";
private const long ModelVersion = 0;
@@ -96,7 +108,13 @@ public SaveOnnxCommand(IHostEnvironment env, Arguments args)
_inputsToDrop = CreateDropMap(args.InputsToDropArray ?? args.InputsToDrop?.Split(','));
_outputsToDrop = CreateDropMap(args.OutputsToDropArray ?? args.OutputsToDrop?.Split(','));
_domain = args.Domain;
+
+ if (args.Model != null && args.PredictiveModel != null)
+ throw env.Except(nameof(args.Model) + " and " + nameof(args.PredictiveModel) +
+ " cannot be specified at the same time when calling ONNX converter. Please check the content of " + nameof(args) + ".");
+
_model = args.Model;
+ _predictiveModel = args.PredictiveModel;
}
private static HashSet CreateDropMap(string[] toDrop)
@@ -198,7 +216,7 @@ private void Run(IChannel ch)
IDataView view;
RoleMappedSchema trainSchema = null;
- if (_model == null)
+ if (_model == null && _predictiveModel == null)
{
if (string.IsNullOrEmpty(ImplOptions.InputModelFile))
{
@@ -213,8 +231,16 @@ private void Run(IChannel ch)
view = loader;
}
- else
+ else if (_model != null)
+ {
view = _model.Apply(Host, new EmptyDataView(Host, _model.InputSchema));
+ }
+ else
+ {
+ view = _predictiveModel.TransformModel.Apply(Host, new EmptyDataView(Host, _predictiveModel.TransformModel.InputSchema));
+ rawPred = _predictiveModel.Predictor;
+ trainSchema = _predictiveModel.GetTrainingSchema(Host);
+ }
// Create the ONNX context for storing global information
var assembly = System.Reflection.Assembly.GetExecutingAssembly();
diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
index 5b3d23c4f5..8fef229723 100644
--- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json
+++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
@@ -2275,9 +2275,10 @@
"Name": "Model",
"Type": "TransformModel",
"Desc": "Model that needs to be converted to ONNX format.",
- "Required": true,
+ "Required": false,
"SortOrder": 10.0,
- "IsNullable": false
+ "IsNullable": false,
+ "Default": null
},
{
"Name": "OnnxVersion",
@@ -2293,6 +2294,15 @@
"SortOrder": 11.0,
"IsNullable": false,
"Default": "Stable"
+ },
+ {
+ "Name": "PredictiveModel",
+ "Type": "PredictorModel",
+ "Desc": "Predictor model that needs to be converted to ONNX format.",
+ "Required": false,
+ "SortOrder": 12.0,
+ "IsNullable": false,
+ "Default": null
}
],
"Outputs": []
diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
index fc67b53e89..af87577042 100644
--- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs
+++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
@@ -10,6 +10,7 @@
using System.Text.RegularExpressions;
using Google.Protobuf;
using Microsoft.ML.Data;
+using Microsoft.ML.EntryPoints;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.RunTests;
using Microsoft.ML.Runtime;
@@ -186,7 +187,7 @@ void CommandLineOnnxConversionTest()
string modelPath = GetOutputPath("ModelWithLessIO.zip");
var trainingPathArgs = $"data={dataPath} out={modelPath}";
var trainingArgs = " loader=text{col=Label:BL:0 col=F1:R4:1-8 col=F2:TX:9} xf=Cat{col=F2} xf=Concat{col=Features:F1,F2} tr=ft{numberOfThreads=1 numberOfLeaves=8 numberOfTrees=3} seed=1";
- Assert.Equal(0, Maml.Main(new[] { "train " + trainingPathArgs + trainingArgs}));
+ Assert.Equal(0, Maml.Main(new[] { "train " + trainingPathArgs + trainingArgs }));
var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "BinaryClassification", "BreastCancer");
var onnxTextName = "ModelWithLessIO.txt";
@@ -403,6 +404,127 @@ public void MulticlassLogisticRegressionOnnxConversionTest()
Done();
}
+ [Fact]
+ public void LoadingPredictorModelAndOnnxConversionTest()
+ {
+ string dataPath = GetDataPath("iris.txt");
+ string modelPath = Path.GetTempPath() + Guid.NewGuid().ToString() + ".model.bin";
+ string onnxPath = Path.GetTempPath() + Guid.NewGuid().ToString() + ".model.onnx";
+ string onnxJsonPath = Path.GetTempPath() + Guid.NewGuid().ToString() + ".model.onnx.json";
+
+ string inputGraph = string.Format(@"
+ {{
+ 'Inputs': {{
+ 'inputFile': '{0}'
+ }},
+ 'Nodes': [
+ {{
+ 'Name': 'Data.TextLoader',
+ 'Inputs':
+ {{
+ 'InputFile': '$inputFile',
+ 'Arguments':
+ {{
+ 'UseThreads': true,
+ 'HeaderFile': null,
+ 'MaxRows': null,
+ 'AllowQuoting': true,
+ 'AllowSparse': true,
+ 'InputSize': null,
+ 'TrimWhitespace': false,
+ 'HasHeader': false,
+ 'Column':
+ [
+ {{'Name':'Sepal_Width','Type':null,'Source':[{{'Min':2,'Max':2,'AutoEnd':false,'VariableEnd':false,'AllOther':false,'ForceVector':false}}],'KeyCount':null}},
+ {{'Name':'Petal_Length','Type':null,'Source':[{{'Min':3,'Max':4,'AutoEnd':false,'VariableEnd':false,'AllOther':false,'ForceVector':false}}],'KeyCount':null}},
+ ]
+ }}
+ }},
+ 'Outputs':
+ {{
+ 'Data': '$training_data'
+ }}
+ }},
+ {{
+ 'Inputs': {{
+ 'FeatureColumnName': 'Petal_Length',
+ 'LabelColumnName': 'Sepal_Width',
+ 'TrainingData': '$training_data',
+ }},
+ 'Name': 'Trainers.StochasticDualCoordinateAscentRegressor',
+ 'Outputs': {{
+ 'PredictorModel': '$output_model'
+ }}
+ }}
+ ],
+ 'Outputs': {{
+ 'output_model': '{1}'
+ }}
+ }}", dataPath.Replace("\\", "\\\\"), modelPath.Replace("\\", "\\\\"));
+
+ // Write entry point graph into file so that it can be invoke by graph runner below.
+ var jsonPath = DeleteOutputPath("graph.json");
+ File.WriteAllLines(jsonPath, new[] { inputGraph });
+
+ // Execute the saved entry point graph to produce a predictive model.
+ var args = new ExecuteGraphCommand.Arguments() { GraphPath = jsonPath };
+ var cmd = new ExecuteGraphCommand(Env, args);
+ cmd.Run();
+
+ // Make entry point graph to conduct ONNX conversion.
+ inputGraph = string.Format(@"
+ {{
+ 'Inputs': {{
+ 'model': '{0}'
+ }},
+ 'Nodes': [
+ {{
+ 'Inputs': {{
+ 'Domain': 'com.microsoft.models',
+ 'Json': '{1}',
+ 'PredictiveModel': '$model',
+ 'Onnx': '{2}',
+ 'OnnxVersion': 'Experimental'
+ }},
+ 'Name': 'Models.OnnxConverter',
+ 'Outputs': {{}}
+ }}
+ ],
+ 'Outputs': {{}}
+ }}
+ ", modelPath.Replace("\\", "\\\\"), onnxJsonPath.Replace("\\", "\\\\"), onnxPath.Replace("\\", "\\\\"));
+
+ // Write entry point graph for ONNX conversion into file so that it can be invoke by graph runner below.
+ jsonPath = DeleteOutputPath("graph.json");
+ File.WriteAllLines(jsonPath, new[] { inputGraph });
+
+ // Onnx converter's assembly is not loaded by default, so we need to register it before calling it.
+ Env.ComponentCatalog.RegisterAssembly(typeof(OnnxExportExtensions).Assembly);
+
+ // Execute the saved entry point graph to convert the saved model to ONNX format.
+ args = new ExecuteGraphCommand.Arguments() { GraphPath = jsonPath };
+ cmd = new ExecuteGraphCommand(Env, args);
+ cmd.Run();
+
+ // Load the resulted ONNX model from the file so that we can check if the conversion looks good.
+ var model = new OnnxCSharpToProtoWrapper.ModelProto();
+ using (var modelStream = File.OpenRead(onnxPath))
+ model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(modelStream);
+
+ // Make sure a PredictorModel is loaded by seeing if a predictive model exists. In this the
+ // predictive model is "LinearRegressor" (converted from StochasticDualCoordinateAscentRegressor
+ // in the original training entry-point graph.
+ Assert.Equal("Scaler", model.Graph.Node[0].OpType);
+ Assert.Equal("LinearRegressor", model.Graph.Node[1].OpType);
+
+ File.Delete(modelPath);
+ File.Delete(onnxPath);
+ File.Delete(onnxJsonPath);
+
+ Done();
+ }
+
+
[Fact]
public void RemoveVariablesInPipelineTest()
{
@@ -451,7 +573,7 @@ public void RemoveVariablesInPipelineTest()
private class SmallSentimentExample
{
- [LoadColumn(0,3), VectorType(4)]
+ [LoadColumn(0, 3), VectorType(4)]
public string[] Tokens;
}