Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable OnnxTransformer to accept KeyDataViewTypes as if they were UInt32 #4824

Merged
merged 8 commits into from Feb 12, 2020
9 changes: 8 additions & 1 deletion src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
Expand Up @@ -389,7 +389,14 @@ private sealed class Mapper : MapperBase
throw Host.Except($"Variable length input columns not supported");

if (type.GetItemType() != inputNodeInfo.DataViewType.GetItemType())
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputNodeInfo.DataViewType.GetItemType().ToString(), type.ToString());
{
// If the ONNX model input node expects a type that mismatches with the type of the input IDataView column that is provided
// then throw an exception.
// This is done except in the case where the ONNX model input node expects a UInt32 but the input column is actually KeyDataViewType
// This is done to support a corner case originated in NimbusML. For more info, see: https://github.com/microsoft/NimbusML/issues/426
if (!(type.GetItemType() is KeyDataViewType && inputNodeInfo.DataViewType.GetItemType().RawType == typeof(UInt32)))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputNodeInfo.DataViewType.GetItemType().ToString(), type.ToString());
}

// If the column is one dimension we make sure that the total size of the Onnx shape matches.
// Compute the total size of the known dimensions of the shape.
Expand Down
69 changes: 69 additions & 0 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Expand Up @@ -1450,6 +1450,75 @@ public void CopyColumnsOnnxTest()
Done();
}

[Fact]
public void UseKeyDataViewTypeAsUInt32InOnnxInput()
{
// In this test an onnx model which expect a uin32 input column is applied to a KeyDataViewType input column
// This, is done as needed by NimbusML. For more context see: https://github.com/microsoft/NimbusML/issues/426

// Step 1: Load the Iris Dataset and apply a Value To Key Mapping to it.
// Save the resulting dataview in .idv format eliminating all hidden columns
var mlContext = new MLContext();
var loader = mlContext.Data.CreateTextLoader(
columns: new[]
{
new TextLoader.Column("Label", DataKind.String, 0),
new TextLoader.Column("SepalLength", DataKind.Single, 1),
new TextLoader.Column("SepalWidth", DataKind.Single, 2),
new TextLoader.Column("PetalLength", DataKind.Single, 3),
new TextLoader.Column("PetalWidth", DataKind.Single, 4)
},
hasHeader: false
);

string dataPath = GetDataPath("iris.txt");
var originalData = loader.Load(dataPath);
var pipeline1 = mlContext.Transforms.Conversion.MapValueToKey("Label");
var mappedData = pipeline1.Fit(originalData).Transform(originalData);

string mappedDataPath = GetOutputPath("kdvt-as-uint32-mapped-data.idv");
using (FileStream stream = new FileStream(mappedDataPath, FileMode.Create))
antoniovs1029 marked this conversation as resolved.
Show resolved Hide resolved
mlContext.Data.SaveAsBinary(mappedData, stream, keepHidden: false);

// Step 2: Load back the saved .idv
// This IDataView will have a Label column of type KeyDataViewType
// It's necessary to do this, because if I were to use mappedData directly inside the next
// steps, then when saving the ONNX model, it would actually also save the ValueToKeyTransformer part
// and that wouldn't reproduce the scenario.
IDataView reloadedData = mlContext.Data.LoadFromBinary(mappedDataPath);

// Step 3: Create ONNX model which simply applies Identity to Label column
var pipeline2 = mlContext.Transforms.CopyColumns("Label", "Label");
var model = pipeline2.Fit(reloadedData);

var onnxModelPath = GetOutputPath("onnxmodel1-kdvt-as-uint32.onnx");
using (FileStream stream = new FileStream(onnxModelPath, FileMode.Create))
mlContext.Model.ConvertToOnnx(model, reloadedData, stream);

// Step 4: Get input and output names of model
var onnxProtoBufModel = mlContext.Model.ConvertToOnnxProtobuf(model, reloadedData);
string[] inputNames = onnxProtoBufModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxProtoBufModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();

// Step 5: Apply Onnx Model
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxResult = onnxEstimator.Fit(reloadedData).Transform(reloadedData);

// Step 6: Compare results to an onnx model created using the mappedData IDataView
// Notice that this ONNX model would actually include the steps to do the ValueToKeyTransformer mapping
// And because of this, it can only be applied to reloadedData dataview, despite mappedData was used to create the model.
// If it's tried to apply this model to mappedData or reloadedData, it will throw an exception, since the ONNX model
// will expect a Label input of type string (which only originalData provides).
string onnxModelPath2 = GetOutputPath("onnxmodel2-kdvt-as-uint32.onnx");
using (FileStream stream = new FileStream(onnxModelPath2, FileMode.Create))
mlContext.Model.ConvertToOnnx(model, mappedData, stream);
var onnxEstimator2 = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath2);
var onnxResult2 = onnxEstimator2.Fit(originalData).Transform(originalData);

foreach (var name in outputNames)
CompareResults(name, name, onnxResult, onnxResult2);
}

[Fact]
public void FeatureSelectionOnnxTest()
{
Expand Down