diff --git a/src/Microsoft.ML.OnnxConverter/SaveOnnxCommand.cs b/src/Microsoft.ML.OnnxConverter/SaveOnnxCommand.cs index e3d1a8bd8e..57eb574aad 100644 --- a/src/Microsoft.ML.OnnxConverter/SaveOnnxCommand.cs +++ b/src/Microsoft.ML.OnnxConverter/SaveOnnxCommand.cs @@ -57,9 +57,20 @@ public sealed class Arguments : DataCommand.ArgumentsBase [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present.", ShortName = "pred", SortOrder = 9)] public bool? LoadPredictor; - [Argument(ArgumentType.Required, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 10)] + /// <summary> + /// Entry point API can save either <see cref="TransformModel"/> or <see cref="PredictorModel"/>. + /// <see cref="Model"/> is used when the saved model is typed to <see cref="TransformModel"/>. + /// </summary> + [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 10)] public TransformModel Model; + /// <summary> + /// Entry point API can save either <see cref="TransformModel"/> or <see cref="PredictorModel"/>. + /// <see cref="PredictiveModel"/> is used when the saved model is typed to <see cref="PredictorModel"/>. + /// </summary> + [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Predictor model that needs to be converted to ONNX format.", SortOrder = 12)] + public PredictorModel PredictiveModel; + [Argument(ArgumentType.AtMostOnce, HelpText = "The targeted ONNX version. It can be either \"Stable\" or \"Experimental\". 
If \"Experimental\" is used, produced model can contain components that is not officially supported in ONNX standard.", SortOrder = 11)] public OnnxVersion OnnxVersion; } @@ -72,6 +83,7 @@ public sealed class Arguments : DataCommand.ArgumentsBase private readonly HashSet _inputsToDrop; private readonly HashSet _outputsToDrop; private readonly TransformModel _model; + private readonly PredictorModel _predictiveModel; private const string ProducerName = "ML.NET"; private const long ModelVersion = 0; @@ -96,7 +108,13 @@ public SaveOnnxCommand(IHostEnvironment env, Arguments args) _inputsToDrop = CreateDropMap(args.InputsToDropArray ?? args.InputsToDrop?.Split(',')); _outputsToDrop = CreateDropMap(args.OutputsToDropArray ?? args.OutputsToDrop?.Split(',')); _domain = args.Domain; + + if (args.Model != null && args.PredictiveModel != null) + throw env.Except(nameof(args.Model) + " and " + nameof(args.PredictiveModel) + + " cannot be specified at the same time when calling ONNX converter. Please check the content of " + nameof(args) + "."); + _model = args.Model; + _predictiveModel = args.PredictiveModel; } private static HashSet CreateDropMap(string[] toDrop) @@ -198,7 +216,7 @@ private void Run(IChannel ch) IDataView view; RoleMappedSchema trainSchema = null; - if (_model == null) + if (_model == null && _predictiveModel == null) { if (string.IsNullOrEmpty(ImplOptions.InputModelFile)) { @@ -213,8 +231,16 @@ private void Run(IChannel ch) view = loader; } - else + else if (_model != null) + { view = _model.Apply(Host, new EmptyDataView(Host, _model.InputSchema)); + } + else + { + view = _predictiveModel.TransformModel.Apply(Host, new EmptyDataView(Host, _predictiveModel.TransformModel.InputSchema)); + rawPred = _predictiveModel.Predictor; + trainSchema = _predictiveModel.GetTrainingSchema(Host); + } // Create the ONNX context for storing global information var assembly = System.Reflection.Assembly.GetExecutingAssembly(); diff --git 
a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 5b3d23c4f5..8fef229723 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -2275,9 +2275,10 @@ "Name": "Model", "Type": "TransformModel", "Desc": "Model that needs to be converted to ONNX format.", - "Required": true, + "Required": false, "SortOrder": 10.0, - "IsNullable": false + "IsNullable": false, + "Default": null }, { "Name": "OnnxVersion", @@ -2293,6 +2294,15 @@ "SortOrder": 11.0, "IsNullable": false, "Default": "Stable" + }, + { + "Name": "PredictiveModel", + "Type": "PredictorModel", + "Desc": "Predictor model that needs to be converted to ONNX format.", + "Required": false, + "SortOrder": 12.0, + "IsNullable": false, + "Default": null } ], "Outputs": [] diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index fc67b53e89..af87577042 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -10,6 +10,7 @@ using System.Text.RegularExpressions; using Google.Protobuf; using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; using Microsoft.ML.Model.OnnxConverter; using Microsoft.ML.RunTests; using Microsoft.ML.Runtime; @@ -186,7 +187,7 @@ void CommandLineOnnxConversionTest() string modelPath = GetOutputPath("ModelWithLessIO.zip"); var trainingPathArgs = $"data={dataPath} out={modelPath}"; var trainingArgs = " loader=text{col=Label:BL:0 col=F1:R4:1-8 col=F2:TX:9} xf=Cat{col=F2} xf=Concat{col=Features:F1,F2} tr=ft{numberOfThreads=1 numberOfLeaves=8 numberOfTrees=3} seed=1"; - Assert.Equal(0, Maml.Main(new[] { "train " + trainingPathArgs + trainingArgs})); + Assert.Equal(0, Maml.Main(new[] { "train " + trainingPathArgs + trainingArgs })); var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "BinaryClassification", 
"BreastCancer"); var onnxTextName = "ModelWithLessIO.txt"; @@ -403,6 +404,127 @@ public void MulticlassLogisticRegressionOnnxConversionTest() Done(); } + [Fact] + public void LoadingPredictorModelAndOnnxConversionTest() + { + string dataPath = GetDataPath("iris.txt"); + string modelPath = Path.GetTempPath() + Guid.NewGuid().ToString() + ".model.bin"; + string onnxPath = Path.GetTempPath() + Guid.NewGuid().ToString() + ".model.onnx"; + string onnxJsonPath = Path.GetTempPath() + Guid.NewGuid().ToString() + ".model.onnx.json"; + + string inputGraph = string.Format(@" + {{ + 'Inputs': {{ + 'inputFile': '{0}' + }}, + 'Nodes': [ + {{ + 'Name': 'Data.TextLoader', + 'Inputs': + {{ + 'InputFile': '$inputFile', + 'Arguments': + {{ + 'UseThreads': true, + 'HeaderFile': null, + 'MaxRows': null, + 'AllowQuoting': true, + 'AllowSparse': true, + 'InputSize': null, + 'TrimWhitespace': false, + 'HasHeader': false, + 'Column': + [ + {{'Name':'Sepal_Width','Type':null,'Source':[{{'Min':2,'Max':2,'AutoEnd':false,'VariableEnd':false,'AllOther':false,'ForceVector':false}}],'KeyCount':null}}, + {{'Name':'Petal_Length','Type':null,'Source':[{{'Min':3,'Max':4,'AutoEnd':false,'VariableEnd':false,'AllOther':false,'ForceVector':false}}],'KeyCount':null}}, + ] + }} + }}, + 'Outputs': + {{ + 'Data': '$training_data' + }} + }}, + {{ + 'Inputs': {{ + 'FeatureColumnName': 'Petal_Length', + 'LabelColumnName': 'Sepal_Width', + 'TrainingData': '$training_data', + }}, + 'Name': 'Trainers.StochasticDualCoordinateAscentRegressor', + 'Outputs': {{ + 'PredictorModel': '$output_model' + }} + }} + ], + 'Outputs': {{ + 'output_model': '{1}' + }} + }}", dataPath.Replace("\\", "\\\\"), modelPath.Replace("\\", "\\\\")); + + // Write the entry point graph into a file so that it can be invoked by the graph runner below. + var jsonPath = DeleteOutputPath("graph.json"); + File.WriteAllLines(jsonPath, new[] { inputGraph }); + + // Execute the saved entry point graph to produce a predictive model. 
+ var args = new ExecuteGraphCommand.Arguments() { GraphPath = jsonPath }; + var cmd = new ExecuteGraphCommand(Env, args); + cmd.Run(); + + // Make entry point graph to conduct ONNX conversion. + inputGraph = string.Format(@" + {{ + 'Inputs': {{ + 'model': '{0}' + }}, + 'Nodes': [ + {{ + 'Inputs': {{ + 'Domain': 'com.microsoft.models', + 'Json': '{1}', + 'PredictiveModel': '$model', + 'Onnx': '{2}', + 'OnnxVersion': 'Experimental' + }}, + 'Name': 'Models.OnnxConverter', + 'Outputs': {{}} + }} + ], + 'Outputs': {{}} + }} + ", modelPath.Replace("\\", "\\\\"), onnxJsonPath.Replace("\\", "\\\\"), onnxPath.Replace("\\", "\\\\")); + + // Write the entry point graph for ONNX conversion into a file so that it can be invoked by the graph runner below. + jsonPath = DeleteOutputPath("graph.json"); + File.WriteAllLines(jsonPath, new[] { inputGraph }); + + // Onnx converter's assembly is not loaded by default, so we need to register it before calling it. + Env.ComponentCatalog.RegisterAssembly(typeof(OnnxExportExtensions).Assembly); + + // Execute the saved entry point graph to convert the saved model to ONNX format. + args = new ExecuteGraphCommand.Arguments() { GraphPath = jsonPath }; + cmd = new ExecuteGraphCommand(Env, args); + cmd.Run(); + + // Load the resulting ONNX model from the file so that we can check if the conversion looks good. + var model = new OnnxCSharpToProtoWrapper.ModelProto(); + using (var modelStream = File.OpenRead(onnxPath)) + model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(modelStream); + + // Make sure a PredictorModel is loaded by seeing if a predictive model exists. In this case, the + // predictive model is "LinearRegressor" (converted from StochasticDualCoordinateAscentRegressor + // in the original training entry-point graph). 
+ Assert.Equal("Scaler", model.Graph.Node[0].OpType); + Assert.Equal("LinearRegressor", model.Graph.Node[1].OpType); + + File.Delete(modelPath); + File.Delete(onnxPath); + File.Delete(onnxJsonPath); + + Done(); + } + + [Fact] public void RemoveVariablesInPipelineTest() { @@ -451,7 +573,7 @@ public void RemoveVariablesInPipelineTest() private class SmallSentimentExample { - [LoadColumn(0,3), VectorType(4)] + [LoadColumn(0, 3), VectorType(4)] public string[] Tokens; }