Skip to content

Commit a9f3b4c

Browse files
Ivanidzo4ka authored and TomFinley committed
Support back compat for ngram hash (#1988)
* Support back compat
1 parent ab7b486 commit a9f3b4c

File tree

3 files changed

+100
-11
lines changed

3 files changed

+100
-11
lines changed

src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs

Lines changed: 80 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -165,11 +165,13 @@ private static VersionInfo GetVersionInfo()
165165
// verWrittenCur: 0x00010002, // Invert hash key values, hash fix
166166
verWrittenCur: 0x00010003, // Get rid of writing float size in model context and change saving format
167167
verReadableCur: 0x00010003,
168-
verWeCanReadBack: 0x00010003,
168+
verWeCanReadBack: 0x00010002,
169169
loaderSignature: LoaderSignature,
170170
loaderAssemblyName: typeof(NgramHashingTransformer).Assembly.FullName);
171171
}
172172

173+
private const int VersionTransformer = 0x00010003;
174+
173175
/// <summary>
174176
/// Describes how the transformer handles one pairing of multiple input columns to a single output column.
175177
/// </summary>
@@ -242,6 +244,7 @@ public ColumnInfo(string[] inputs, string output,
242244
InvertHash = invertHash;
243245
RehashUnigrams = rehashUnigrams;
244246
}
247+
245248
internal ColumnInfo(ModelLoadContext ctx)
246249
{
247250
Contracts.AssertValue(ctx);
@@ -275,6 +278,36 @@ internal ColumnInfo(ModelLoadContext ctx)
275278
AllLengths = ctx.Reader.ReadBoolByte();
276279
}
277280

281+
/// <summary>
/// Back-compat constructor for models whose column names are stored separately from the
/// per-column settings. The caller supplies the already-loaded names; only the remaining
/// settings are read from <paramref name="ctx"/>.
/// </summary>
/// <param name="ctx">Load context positioned at this column's serialized settings.</param>
/// <param name="inputs">Names of the input columns; must be non-null and contain no null or whitespace-only entries.</param>
/// <param name="output">Name of the output column; must not be null or whitespace-only.</param>
internal ColumnInfo(ModelLoadContext ctx, string[] inputs, string output)
{
    Contracts.AssertValue(ctx);
    Contracts.CheckValue(inputs, nameof(inputs));
    Contracts.CheckParam(!inputs.Any(r => string.IsNullOrWhiteSpace(r)), nameof(inputs),
        "Contained some null or empty items");
    // Hold the output name to the same rule as the input names so an unusable
    // column name fails here instead of later when the schema is built.
    Contracts.CheckParam(!string.IsNullOrWhiteSpace(output), nameof(output),
        "Must not be null or empty");
    Inputs = inputs;
    Output = output;

    // *** Binary format ***
    // (column names are NOT read here; they are passed in by the caller)
    // int: NgramLength
    // int: SkipLength
    // int: HashBits
    // uint: Seed
    // byte: Rehash
    // byte: Ordered
    // byte: AllLengths
    NgramLength = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength);
    SkipLength = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(0 <= SkipLength && SkipLength <= NgramBufferBuilder.MaxSkipNgramLength);
    // The n-gram length and skip length together must fit within the builder's limit.
    Contracts.CheckDecode(SkipLength <= NgramBufferBuilder.MaxSkipNgramLength - NgramLength);
    HashBits = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(1 <= HashBits && HashBits <= 30);
    Seed = ctx.Reader.ReadUInt32();
    RehashUnigrams = ctx.Reader.ReadBoolByte();
    Ordered = ctx.Reader.ReadBoolByte();
    AllLengths = ctx.Reader.ReadBoolByte();
}
310+
278311
internal void Save(ModelSaveContext ctx)
279312
{
280313
Contracts.AssertValue(ctx);
@@ -416,19 +449,56 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx,
416449
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema)
417450
=> Create(env, ctx).MakeRowMapper(inputSchema);
418451

419-
private NgramHashingTransformer(IHostEnvironment env, ModelLoadContext ctx) :
452+
/// <summary>
/// Deserializing constructor. Reads the column descriptions and slot-name data from
/// <paramref name="ctx"/>, in either the current or the legacy on-disk layout.
/// </summary>
/// <param name="env">Host environment used to register this transformer's host.</param>
/// <param name="ctx">Load context to read from. This constructor does not call
/// <c>ctx.CheckAtModel</c>; the <c>Create</c> factory does so before constructing.</param>
/// <param name="loadLegacy">
/// When <c>true</c>, read the legacy layout: a leading float-size int, then all column
/// names, then the per-column settings. When <c>false</c>, each column is read whole
/// by <see cref="ColumnInfo"/>.
/// </param>
private NgramHashingTransformer(IHostEnvironment env, ModelLoadContext ctx, bool loadLegacy = false) :
    base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NgramHashingTransformer)))
{
    Host.CheckValue(ctx, nameof(ctx));

    if (loadLegacy)
    {
        // *** Binary format (legacy only) ***
        // int: sizeof(float)
        int cbFloat = ctx.Reader.ReadInt32();
        Host.CheckDecode(cbFloat == sizeof(float));
    }

    var columnsLength = ctx.Reader.ReadInt32();
    Host.CheckDecode(columnsLength > 0);
    var columns = new ColumnInfo[columnsLength];

    if (!loadLegacy)
    {
        // *** Binary format ***
        // int number of columns
        // columns
        for (int i = 0; i < columnsLength; i++)
            columns[i] = new ColumnInfo(ctx);
    }
    else
    {
        // *** Binary format ***
        // int: number of added columns
        // for each added column
        //   int: id of output column name
        //   int: number of input column names
        //   int[]: ids of input column names
        // for each added column
        //   per-column settings (read by the legacy ColumnInfo constructor)
        //
        // All names come before any settings, so they are buffered first
        // and the ColumnInfo objects are built in a second pass.
        var outputs = new string[columnsLength];
        var inputs = new string[columnsLength][];
        for (int i = 0; i < columnsLength; i++)
        {
            outputs[i] = ctx.LoadNonEmptyString();

            int csrc = ctx.Reader.ReadInt32();
            Host.CheckDecode(csrc > 0);
            inputs[i] = new string[csrc];
            for (int j = 0; j < csrc; j++)
                inputs[i][j] = ctx.LoadNonEmptyString();
        }

        for (int i = 0; i < columnsLength; i++)
            columns[i] = new ColumnInfo(ctx, inputs[i], outputs[i]);
    }

    _columns = columns.ToImmutableArray();
    TextModelHelper.LoadAll(Host, ctx, columnsLength, out _slotNames, out _slotNamesTypes);
}
@@ -469,7 +539,8 @@ private static NgramHashingTransformer Create(IHostEnvironment env, ModelLoadCon
469539
{
470540
Contracts.CheckValue(env, nameof(env));
471541
var host = env.Register(nameof(NgramHashingTransformer));
472-
return new NgramHashingTransformer(host, ctx);
542+
ctx.CheckAtModel(GetVersionInfo());
543+
return new NgramHashingTransformer(host, ctx, ctx.Header.ModelVerWritten < VersionTransformer);
473544
}
474545

475546
private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema);

test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using Microsoft.ML;
88
using Microsoft.ML.Data;
99
using Microsoft.ML.Data.IO;
10+
using Microsoft.ML.Model;
1011
using Microsoft.ML.RunTests;
1112
using Microsoft.ML.StaticPipe;
1213
using Microsoft.ML.Tools;
@@ -172,7 +173,7 @@ public void StopWordsRemoverFromFactory()
172173
string sentimentDataPath = GetDataPath("wikipedia-detox-250-line-data.tsv");
173174
var data = TextLoader.Create(ML, new TextLoader.Arguments()
174175
{
175-
Column = new []
176+
Column = new[]
176177
{
177178
new TextLoader.Column("Text", DataKind.TX, 1)
178179
}
@@ -212,7 +213,7 @@ public void WordBagWorkout()
212213
.Read(sentimentDataPath);
213214

214215
var est = new WordBagEstimator(Env, "text", "bag_of_words").
215-
Append(new WordHashBagEstimator(Env, "text", "bag_of_wordshash", invertHash:-1));
216+
Append(new WordHashBagEstimator(Env, "text", "bag_of_wordshash", invertHash: -1));
216217

217218
// The following call fails because of the following issue
218219
// https://github.com/dotnet/machinelearning/issues/969
@@ -269,6 +270,23 @@ public void NgramWorkout()
269270
Done();
270271
}
271272

273+
/// <summary>
/// Back-compat check: loads the transforms stored in the checked-in model file
/// <c>backcompat/ngram.zip</c>, applies them to the sentiment data set, and verifies
/// that the resulting schema exposes a "Features" column.
/// </summary>
[Fact]
public void TestNgramCompatColumns()
{
    string dropModelPath = GetDataPath("backcompat/ngram.zip");
    string sentimentDataPath = GetDataPath("wikipedia-detox-250-line-data.tsv");
    var data = TextLoader.CreateReader(ML, ctx => (
            Sentiment: ctx.LoadBool(0),
            SentimentText: ctx.LoadText(1)), hasHeader: true)
        .Read(sentimentDataPath);

    using (FileStream fs = File.OpenRead(dropModelPath))
    {
        // A load failure throws here; the assertion below only checks the output schema.
        var result = ModelFileUtils.LoadTransforms(Env, data.AsDynamic, fs);
        var featureColumn = result.Schema.GetColumnOrNull("Features");
        Assert.NotNull(featureColumn);
    }
}
289+
272290
[Fact]
273291
public void LdaWorkout()
274292
{

test/data/backcompat/ngram.zip

273 KB
Binary file not shown.

0 commit comments

Comments
 (0)