diff --git a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs index fdb26ec26d..959461a104 100644 --- a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs @@ -15,6 +15,7 @@ using Microsoft.ML.Data.IO; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model.OnnxConverter; using Microsoft.ML.Runtime; using Microsoft.ML.Transforms.Text; @@ -343,7 +344,7 @@ private static Stream GetResourceFileStreamOrNull(StopWordsRemovingEstimator.Lan return assembly.GetManifestResourceStream($"{assembly.GetName().Name}.Text.StopWords.{lang.ToString()}.txt"); } - private sealed class Mapper : MapperBase + private sealed class Mapper : MapperBase, ISaveAsOnnx { private readonly DataViewType[] _types; private readonly StopWordsRemovingTransformer _parent; @@ -351,6 +352,8 @@ private sealed class Mapper : MapperBase private readonly bool?[] _resourcesExist; private readonly Dictionary _colMapNewToOld; + public bool CanSaveOnnx(OnnxContext ctx) => true; + public Mapper(StopWordsRemovingTransformer parent, DataViewSchema inputSchema) : base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent) { @@ -438,6 +441,45 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func); + UpdateLanguage(ref langToUse, null, ref lang); + + var words = StopWords[iinfo].Select(item => Convert.ToString(item.Value)); + node.AddAttribute("stopwords", StopWords[iinfo].Select(item => Convert.ToString(item.Value))); + + opType = "Unsqueeze"; + squeezeOutput = ctx.AddIntermediateVariable(_types[iinfo], "SqueezeOutput"); + node = ctx.CreateNode(opType, stringNormalizerOutput, dstVariableName, ctx.GetNodeName(opType), ""); + node.AddAttribute("axes", new long[] { 0 }); + } + private void UpdateLanguage(ref StopWordsRemovingEstimator.Language langToUse, ValueGetter> getLang, ref ReadOnlyMemory langTxt) { if (getLang != null) @@ -490,7 +532,7 @@ private protected override Func GetDependenciesCore(Func a /// | Does this estimator need to look at the data to train its parameters? | No | /// | Input column data type | Vector of [Text](xref:Microsoft.ML.Data.TextDataViewType) | /// | Output column data type | Variable-sized vector of [Text](xref:Microsoft.ML.Data.TextDataViewType) | - /// | Exportable to ONNX | No | + /// | Exportable to ONNX | Yes | /// /// The resulting creates a new column, named as specified in the output column name parameter, /// and fills it with a vector of words containing all of the words in the input column **except the predefined list of stopwords for the specified language. @@ -1016,11 +1058,13 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Dat private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema); - private sealed class Mapper : OneToOneMapperBase + private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx { private readonly DataViewType[] _types; private readonly CustomStopWordsRemovingTransformer _parent; + public bool CanSaveOnnx(OnnxContext ctx) => true; + public Mapper(CustomStopWordsRemovingTransformer parent, DataViewSchema inputSchema) : base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), parent, inputSchema) { @@ -1084,6 +1128,43 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func Convert.ToString(item.Value))); + + opType = "Unsqueeze"; + squeezeOutput = ctx.AddIntermediateVariable(_types[iinfo], "SqueezeOutput"); + node = ctx.CreateNode(opType, stringNormalizerOutput, dstVariableName, ctx.GetNodeName(opType), ""); + node.AddAttribute("axes", new long[] { 0 }); + } } } @@ -1098,7 +1179,7 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func creates a new column, named as specified by the output column name parameter, and /// fills it with a vector of words containing all of the words in the input column except those given by the stopwords parameter. diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index ca1b6d0ae0..dfd6aa2136 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -974,8 +974,8 @@ public void OneHotHashEncodingOnnxConversionTest() var mlContext = new MLContext(); string dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename); - var dataView = ML.Data.LoadFromTextFile(dataPath); - var pipeline = ML.Transforms.Categorical.OneHotHashEncoding(new[]{ + var dataView = mlContext.Data.LoadFromTextFile(dataPath); + var pipeline = mlContext.Transforms.Categorical.OneHotHashEncoding(new[]{ new OneHotHashEncodingEstimator.ColumnOptions("Output", "F3", useOrderedHashing:false), }); var onnxFileName = "OneHotHashEncoding.onnx"; @@ -1343,6 +1343,54 @@ public void NgramOnnxConversionTest( Done(); } + [Fact] + public void CustomStopWordsRemovingEstimatorOnnxTest() + { + var mlContext = new MLContext(); + + var pipeline = mlContext.Transforms.Text.TokenizeIntoWords("Words", "Text") + .Append(mlContext.Transforms.Text.RemoveStopWords( + "WordsWithoutStopWords", "Words", stopwords: + new[] { "cat", "sat", "on" })); + + var samples = new List() + { + new TextData(){ Text = "cat sat on mat" }, + new TextData(){ Text = "mat not fit cat" }, + new TextData(){ Text = "a cat think mat bad" }, + }; + var dataView = mlContext.Data.LoadFromEnumerable(samples); + var onnxFileName = $"CustomStopWordsRemovingEstimator.onnx"; + + TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("WordsWithoutStopWords")}); + + Done(); + } + + [Fact] + public void StopWordsRemovingEstimatorOnnxTest() + { + var mlContext = new MLContext(); + + var pipeline = mlContext.Transforms.Text.TokenizeIntoWords("Words", "Text") + .Append(mlContext.Transforms.Text.RemoveDefaultStopWords( + "WordsWithoutStopWords", "Words", language: + StopWordsRemovingEstimator.Language.English)); + + var samples = new List() + { + new TextData(){ Text = "a go cat sat on mat" }, + new TextData(){ Text = "a mat not fit go cat" }, + new TextData(){ Text = "cat think mat bad a" }, + }; + var dataView = mlContext.Data.LoadFromEnumerable(samples); + var onnxFileName = $"StopWordsRemovingEstimator.onnx"; + + TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("WordsWithoutStopWords") }); + + Done(); + } + [Theory] [InlineData(DataKind.Boolean)] [InlineData(DataKind.SByte)]