From 98d9f5b54e3f71775e7a9a74ec8cbec38fcc1270 Mon Sep 17 00:00:00 2001 From: Lynx1820 Date: Thu, 2 Jul 2020 11:28:23 -0700 Subject: [PATCH 1/3] StopWordsRemoving transformer export to onnx --- .../Text/StopWordsRemovingTransformer.cs | 89 ++++++++++++++++++- test/Microsoft.ML.Tests/OnnxConversionTest.cs | 55 +++++++++++- 2 files changed, 138 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs index fdb26ec26d..805789e986 100644 --- a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs @@ -15,6 +15,7 @@ using Microsoft.ML.Data.IO; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model.OnnxConverter; using Microsoft.ML.Runtime; using Microsoft.ML.Transforms.Text; @@ -343,7 +344,7 @@ private static Stream GetResourceFileStreamOrNull(StopWordsRemovingEstimator.Lan return assembly.GetManifestResourceStream($"{assembly.GetName().Name}.Text.StopWords.{lang.ToString()}.txt"); } - private sealed class Mapper : MapperBase + private sealed class Mapper : MapperBase, ISaveAsOnnx { private readonly DataViewType[] _types; private readonly StopWordsRemovingTransformer _parent; @@ -351,6 +352,8 @@ private sealed class Mapper : MapperBase private readonly bool?[] _resourcesExist; private readonly Dictionary _colMapNewToOld; + public bool CanSaveOnnx(OnnxContext ctx) => true; + public Mapper(StopWordsRemovingTransformer parent, DataViewSchema inputSchema) : base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent) { @@ -438,6 +441,45 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func); + UpdateLanguage(ref langToUse, null, ref lang); + + var words = StopWords[(int)0].Select(item => Convert.ToString(item.Value)); + node.AddAttribute("stopwords", StopWords[(int)0].Select(item => Convert.ToString(item.Value))); + + opType = "Unsqueeze"; + squeezeOutput = ctx.AddIntermediateVariable(_types[iinfo], "SqueezeOutput"); + node = ctx.CreateNode(opType, stringNormalizerOutput, dstVariableName, ctx.GetNodeName(opType), ""); + node.AddAttribute("axes", new long[] { 0 }); + } + private void UpdateLanguage(ref StopWordsRemovingEstimator.Language langToUse, ValueGetter> getLang, ref ReadOnlyMemory langTxt) { if (getLang != null) @@ -490,7 +532,7 @@ private protected override Func GetDependenciesCore(Func a /// | Does this estimator need to look at the data to train its parameters? | No | /// | Input column data type | Vector of [Text](xref:Microsoft.ML.Data.TextDataViewType) | /// | Output column data type | Variable-sized vector of [Text](xref:Microsoft.ML.Data.TextDataViewType) | - /// | Exportable to ONNX | No | + /// | Exportable to ONNX | Yes | /// /// The resulting creates a new column, named as specified in the output column name parameter, /// and fills it with a vector of words containing all of the words in the input column **except the predefined list of stopwords for the specified language. @@ -1016,11 +1058,13 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Dat private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema); - private sealed class Mapper : OneToOneMapperBase + private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx { private readonly DataViewType[] _types; private readonly CustomStopWordsRemovingTransformer _parent; + public bool CanSaveOnnx(OnnxContext ctx) => true; + public Mapper(CustomStopWordsRemovingTransformer parent, DataViewSchema inputSchema) : base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), parent, inputSchema) { @@ -1084,6 +1128,43 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func Convert.ToString(item.Value))); + + opType = "Unsqueeze"; + squeezeOutput = ctx.AddIntermediateVariable(_types[iinfo], "SqueezeOutput"); + node = ctx.CreateNode(opType, stringNormalizerOutput, dstVariableName, ctx.GetNodeName(opType), ""); + node.AddAttribute("axes", new long[] { 0 }); + } } } @@ -1098,7 +1179,7 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func creates a new column, named as specified by the output column name parameter, and /// fills it with a vector of words containing all of the words in the input column except those given by the stopwords parameter. diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index ca1b6d0ae0..d4dc580aae 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -974,8 +974,8 @@ public void OneHotHashEncodingOnnxConversionTest() var mlContext = new MLContext(); string dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename); - var dataView = ML.Data.LoadFromTextFile(dataPath); - var pipeline = ML.Transforms.Categorical.OneHotHashEncoding(new[]{ + var dataView = mlContext.Data.LoadFromTextFile(dataPath); + var pipeline = mlContext.Transforms.Categorical.OneHotHashEncoding(new[]{ new OneHotHashEncodingEstimator.ColumnOptions("Output", "F3", useOrderedHashing:false), }); var onnxFileName = "OneHotHashEncoding.onnx"; @@ -1343,6 +1343,57 @@ public void NgramOnnxConversionTest( Done(); } + [Fact] + public void CustomStopWordsRemovingEstimatorOnnxTest() + { + var mlContext = new MLContext(); + + var pipeline = mlContext.Transforms.Text.TokenizeIntoWords("Words", "Text") + .Append(mlContext.Transforms.Text.RemoveStopWords( + "WordsWithoutStopWords", "Words", stopwords: + new[] { "cat", "sat", "on" })); + + var samples = new List() + { + new TextData(){ Text = "cat sat on mat" }, + new TextData(){ Text = "mat not fit cat" }, + new TextData(){ Text = "a cat think mat bad" }, + }; + + var dataView = mlContext.Data.LoadFromEnumerable(samples); + + var onnxFileName = $"CustomStopWordsRemovingEstimator.onnx"; + TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("WordsWithoutStopWords")}); + + Done(); + } + + [Fact] + public void StopWordsRemovingEstimatorOnnxTest() + { + var mlContext = new MLContext(); + + var pipeline = mlContext.Transforms.Text.TokenizeIntoWords("Words", "Text") + .Append(mlContext.Transforms.Text.RemoveDefaultStopWords( + "WordsWithoutStopWords", "Words", language: + StopWordsRemovingEstimator.Language.English)); + + var samples = new List() + { + new TextData(){ Text = "a go cat sat on mat" }, + new TextData(){ Text = "a mat not fit go cat" }, + new TextData(){ Text = "cat think mat bad a" }, + }; + + var dataView = mlContext.Data.LoadFromEnumerable(samples); + + var onnxFileName = $"StopWordsRemovingEstimator.onnx"; + + TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("WordsWithoutStopWords") }); + + Done(); + } + [Theory] [InlineData(DataKind.Boolean)] [InlineData(DataKind.SByte)] From 809d77f24d981cf484dccaa9648781d1d2d200ad Mon Sep 17 00:00:00 2001 From: Lynx1820 Date: Thu, 2 Jul 2020 12:04:53 -0700 Subject: [PATCH 2/3] format changes --- .../Text/StopWordsRemovingTransformer.cs | 4 ++-- test/Microsoft.ML.Tests/OnnxConversionTest.cs | 5 +---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs index 805789e986..14931e6c7d 100644 --- a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs @@ -1145,8 +1145,8 @@ public void SaveAsOnnx(OnnxContext ctx) } } - // Note: Since StringNormalizer only accepts inputs of [C] or [1,C], we squeeze the batch dimension which - // may exceed 1 + // Note: Since StringNormalizer only accepts inputs of shape [C] or [1,C], we temporarily squeeze the + // batch dimension which may exceed 1 private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName) { var opType = "Squeeze"; diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index d4dc580aae..dfd6aa2136 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -1359,10 +1359,9 @@ public void CustomStopWordsRemovingEstimatorOnnxTest() new TextData(){ Text = "mat not fit cat" }, new TextData(){ Text = "a cat think mat bad" }, }; - var dataView = mlContext.Data.LoadFromEnumerable(samples); - var onnxFileName = $"CustomStopWordsRemovingEstimator.onnx"; + TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("WordsWithoutStopWords")}); Done(); @@ -1384,9 +1383,7 @@ public void StopWordsRemovingEstimatorOnnxTest() new TextData(){ Text = "a mat not fit go cat" }, new TextData(){ Text = "cat think mat bad a" }, }; - var dataView = mlContext.Data.LoadFromEnumerable(samples); - var onnxFileName = $"StopWordsRemovingEstimator.onnx"; TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("WordsWithoutStopWords") }); From f55873cde9a3b0ea64b0bb140228254e44e21ca7 Mon Sep 17 00:00:00 2001 From: Lynx1820 Date: Fri, 10 Jul 2020 15:46:30 -0700 Subject: [PATCH 3/3] adding types --- .../Text/StopWordsRemovingTransformer.cs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs index 14931e6c7d..959461a104 100644 --- a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs @@ -459,20 +459,20 @@ public void SaveAsOnnx(OnnxContext ctx) private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName) { var opType = "Squeeze"; - var squeezeOutput = ctx.AddIntermediateVariable(null, "SqueezeOutput", true); + var squeezeOutput = ctx.AddIntermediateVariable(_types[iinfo], "SqueezeOutput", true); var node = ctx.CreateNode(opType, srcVariableName, squeezeOutput, ctx.GetNodeName(opType), ""); node.AddAttribute("axes", new long[] { 0 }); opType = "StringNormalizer"; - var stringNormalizerOutput = ctx.AddIntermediateVariable(null, "StringNormalizerOutput", true); + var stringNormalizerOutput = ctx.AddIntermediateVariable(_types[iinfo], "StringNormalizerOutput", true); node = ctx.CreateNode(opType, squeezeOutput, stringNormalizerOutput, ctx.GetNodeName(opType), ""); var langToUse = _parent._columns[iinfo].Language; var lang = default(ReadOnlyMemory); UpdateLanguage(ref langToUse, null, ref lang); - var words = StopWords[(int)0].Select(item => Convert.ToString(item.Value)); - node.AddAttribute("stopwords", StopWords[(int)0].Select(item => Convert.ToString(item.Value))); + var words = StopWords[iinfo].Select(item => Convert.ToString(item.Value)); + node.AddAttribute("stopwords", StopWords[iinfo].Select(item => Convert.ToString(item.Value))); opType = "Unsqueeze"; squeezeOutput = ctx.AddIntermediateVariable(_types[iinfo], "SqueezeOutput"); @@ -1150,12 +1150,12 @@ public void SaveAsOnnx(OnnxContext ctx) private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName) { var opType = "Squeeze"; - var squeezeOutput = ctx.AddIntermediateVariable(null, "SqueezeOutput", true); + var squeezeOutput = ctx.AddIntermediateVariable(_types[iinfo], "SqueezeOutput", true); var node = ctx.CreateNode(opType, srcVariableName, squeezeOutput, ctx.GetNodeName(opType), ""); node.AddAttribute("axes", new long[] { 0 }); opType = "StringNormalizer"; - var stringNormalizerOutput = ctx.AddIntermediateVariable(null, "StringNormalizerOutput", true); + var stringNormalizerOutput = ctx.AddIntermediateVariable(_types[iinfo], "StringNormalizerOutput", true); node = ctx.CreateNode(opType, squeezeOutput, stringNormalizerOutput, ctx.GetNodeName(opType), ""); var words = _parent._stopWordsMap.ToList(); node.AddAttribute("stopwords", words.Select(item => Convert.ToString(item.Value)));