Let NgramHashingTransformer load models written at version 0x00010002.

Models of that version store sizeof(float) first, then the output/input
names of every column, then each column's n-gram settings; version
0x00010003 models store one complete ColumnInfo record per column.
Create() now validates the model header itself and selects the legacy
reader when ModelVerWritten < 0x00010003. A back-compat model
(test/data/backcompat/ngram.zip) and a test that loads it are added.

diff --git a/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs b/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs
index 047ebdcd05..f4f35f5dd5 100644
--- a/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs
+++ b/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs
@@ -165,11 +165,13 @@ private static VersionInfo GetVersionInfo()
                 // verWrittenCur: 0x00010002, // Invert hash key values, hash fix
                 verWrittenCur: 0x00010003, // Get rid of writing float size in model context and change saving format
                 verReadableCur: 0x00010003,
-                verWeCanReadBack: 0x00010003,
+                verWeCanReadBack: 0x00010002,
                 loaderSignature: LoaderSignature,
                 loaderAssemblyName: typeof(NgramHashingTransformer).Assembly.FullName);
         }
 
+        private const int VersionTransformer = 0x00010003;
+
         ///
         /// Describes how the transformer handles one pair of mulitple inputs - singular output columns.
         ///
@@ -242,6 +244,7 @@ public ColumnInfo(string[] inputs, string output,
                 InvertHash = invertHash;
                 RehashUnigrams = rehashUnigrams;
             }
+
             internal ColumnInfo(ModelLoadContext ctx)
             {
                 Contracts.AssertValue(ctx);
@@ -275,6 +278,36 @@ internal ColumnInfo(ModelLoadContext ctx)
                 AllLengths = ctx.Reader.ReadBoolByte();
             }
 
+            internal ColumnInfo(ModelLoadContext ctx, string[] inputs, string output)
+            {
+                Contracts.AssertValue(ctx);
+                Contracts.CheckValue(inputs, nameof(inputs));
+                Contracts.CheckParam(!inputs.Any(r => string.IsNullOrWhiteSpace(r)), nameof(inputs),
+                    "Contained some null or empty items");
+                Inputs = inputs;
+                Output = output;
+                // *** Binary format ***
+                // (no names are read here: inputs/output come from the caller)
+                // int: NgramLength
+                // int: SkipLength
+                // int: HashBits
+                // uint: Seed
+                // byte: Rehash
+                // byte: Ordered
+                // byte: AllLengths
+                NgramLength = ctx.Reader.ReadInt32();
+                Contracts.CheckDecode(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength);
+                SkipLength = ctx.Reader.ReadInt32();
+                Contracts.CheckDecode(0 <= SkipLength && SkipLength <= NgramBufferBuilder.MaxSkipNgramLength);
+                Contracts.CheckDecode(SkipLength <= NgramBufferBuilder.MaxSkipNgramLength - NgramLength);
+                HashBits = ctx.Reader.ReadInt32();
+                Contracts.CheckDecode(1 <= HashBits && HashBits <= 30);
+                Seed = ctx.Reader.ReadUInt32();
+                RehashUnigrams = ctx.Reader.ReadBoolByte();
+                Ordered = ctx.Reader.ReadBoolByte();
+                AllLengths = ctx.Reader.ReadBoolByte();
+            }
+
             internal void Save(ModelSaveContext ctx)
             {
                 Contracts.AssertValue(ctx);
@@ -416,19 +449,56 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx,
         private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema)
             => Create(env, ctx).MakeRowMapper(inputSchema);
 
-        private NgramHashingTransformer(IHostEnvironment env, ModelLoadContext ctx) :
+        private NgramHashingTransformer(IHostEnvironment env, ModelLoadContext ctx, bool loadLegacy = false) :
             base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NgramHashingTransformer)))
         {
             Host.CheckValue(ctx, nameof(ctx));
-            ctx.CheckAtModel(GetVersionInfo());
+            if (loadLegacy)
+            {
+                int cbFloat = ctx.Reader.ReadInt32();
+                Host.CheckDecode(cbFloat == sizeof(float));
+            }
             var columnsLength = ctx.Reader.ReadInt32();
+            Host.CheckDecode(columnsLength > 0);
             var columns = new ColumnInfo[columnsLength];
+            if (!loadLegacy)
+            {
+                // *** Binary format ***
+                // int number of columns
+                // columns
+                for (int i = 0; i < columnsLength; i++)
+                    columns[i] = new ColumnInfo(ctx);
+            }
+            else
+            {
+                // *** Binary format ***
+                // int: number of added columns
+                // for each added column
+                //   int: id of output column name
+                //   int: number of input column names
+                //   int[]: ids of input column names
+                var outputs = new string[columnsLength];
+                var inputs = new string[columnsLength][];
+                for (int i = 0; i < columnsLength; i++)
+                {
+                    outputs[i] = ctx.LoadNonEmptyString();
 
-            // *** Binary format ***
-            // int number of columns
-            // columns
-            for (int i = 0; i < columnsLength; i++)
-                columns[i] = new ColumnInfo(ctx);
+                    int csrc = ctx.Reader.ReadInt32();
+                    Host.CheckDecode(csrc > 0);
+                    inputs[i] = new string[csrc];
+                    for (int j = 0; j < csrc; j++)
+                    {
+                        string src = ctx.LoadNonEmptyString();
+                        inputs[i][j] = src;
+                    }
+                }
+
+                // The per-column settings follow the block of names;
+                // each ColumnInfo reads its own settings and takes
+                // the names that were loaded above.
+                for (int i = 0; i < columnsLength; i++)
+                    columns[i] = new ColumnInfo(ctx, inputs[i], outputs[i]);
+            }
             _columns = columns.ToImmutableArray();
             TextModelHelper.LoadAll(Host, ctx, columnsLength, out _slotNames, out _slotNamesTypes);
         }
@@ -469,7 +539,9 @@ private static NgramHashingTransformer Create(IHostEnvironment env, ModelLoadCon
         {
             Contracts.CheckValue(env, nameof(env));
             var host = env.Register(nameof(NgramHashingTransformer));
-            return new NgramHashingTransformer(host, ctx);
+            host.CheckValue(ctx, nameof(ctx));
+            ctx.CheckAtModel(GetVersionInfo());
+            return new NgramHashingTransformer(host, ctx, ctx.Header.ModelVerWritten < VersionTransformer);
         }
 
         private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema);
diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs
index 4989287a64..0c33dc549b 100644
--- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs
@@ -7,6 +7,7 @@
 using Microsoft.ML;
 using Microsoft.ML.Data;
 using Microsoft.ML.Data.IO;
+using Microsoft.ML.Model;
 using Microsoft.ML.RunTests;
 using Microsoft.ML.StaticPipe;
 using Microsoft.ML.Tools;
@@ -172,7 +173,7 @@ public void StopWordsRemoverFromFactory()
             string sentimentDataPath = GetDataPath("wikipedia-detox-250-line-data.tsv");
             var data = TextLoader.Create(ML, new TextLoader.Arguments()
             {
-                Column = new []
+                Column = new[]
                 {
                     new TextLoader.Column("Text", DataKind.TX, 1)
                 }
@@ -212,7 +213,7 @@ public void WordBagWorkout()
                 .Read(sentimentDataPath);
 
             var est = new WordBagEstimator(Env, "text", "bag_of_words").
-                Append(new WordHashBagEstimator(Env, "text", "bag_of_wordshash", invertHash:-1));
+                Append(new WordHashBagEstimator(Env, "text", "bag_of_wordshash", invertHash: -1));
 
             // The following call fails because of the following issue
             // https://github.com/dotnet/machinelearning/issues/969
@@ -269,6 +270,23 @@ public void NgramWorkout()
             Done();
         }
 
+        [Fact]
+        public void TestNgramCompatColumns()
+        {
+            string dropModelPath = GetDataPath("backcompat/ngram.zip");
+            string sentimentDataPath = GetDataPath("wikipedia-detox-250-line-data.tsv");
+            var data = TextLoader.CreateReader(ML, ctx => (
+                    Sentiment: ctx.LoadBool(0),
+                    SentimentText: ctx.LoadText(1)), hasHeader: true)
+                .Read(sentimentDataPath);
+            using (FileStream fs = File.OpenRead(dropModelPath))
+            {
+                var result = ModelFileUtils.LoadTransforms(Env, data.AsDynamic, fs);
+                var featureColumn = result.Schema.GetColumnOrNull("Features");
+                Assert.NotNull(featureColumn);
+            }
+        }
+
         [Fact]
         public void LdaWorkout()
         {
diff --git a/test/data/backcompat/ngram.zip b/test/data/backcompat/ngram.zip
new file mode 100644
index 0000000000..d9184cad3c
Binary files /dev/null and b/test/data/backcompat/ngram.zip differ