Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 80 additions & 9 deletions src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,13 @@ private static VersionInfo GetVersionInfo()
// verWrittenCur: 0x00010002, // Invert hash key values, hash fix
verWrittenCur: 0x00010003, // Get rid of writing float size in model context and change saving format
verReadableCur: 0x00010003,
verWeCanReadBack: 0x00010003,
verWeCanReadBack: 0x00010002,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(NgramHashingTransformer).Assembly.FullName);
}

// Model version at which the float size stopped being written and the saving format changed
// (the 0x00010003 entry in GetVersionInfo). Models written with a version below this value are
// loaded through the legacy path, so this constant identifies that specific format change and
// must not be bumped merely because the current model version changes.
private const int VersionTransformer = 0x00010003;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VersionTransformer [](start = 26, length = 18)

It would be more intuitive, I think, if called: CurrentVersion.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The trouble with naming it current version is that it is not a descriptive name, and it could change in the future. It is better to describe what changed then. The reason is, if we think about how these are used, they are used in the conditional checks during loading to do this or that. So, if we had some model format change in such a way that we added some information "Foo," then we should name that VersionFoo or something, so that the inevitable conditional test, if (header.Version >= VersionFoo) or what have you, makes sense. That would not be the case if it were named CurrentVersion. It would also lead to bugs, since people might say, "hey, the current version changed, I'll change this." But it is absolutely essential that they do not, since we are using that field for a very specific test. So it is absolutely essential that it not be named something like CurrentVersion.


/// <summary>
/// Describes how the transformer handles one pair of multiple input columns and a single output column.
/// </summary>
Expand Down Expand Up @@ -242,6 +244,7 @@ public ColumnInfo(string[] inputs, string output,
InvertHash = invertHash;
RehashUnigrams = rehashUnigrams;
}

internal ColumnInfo(ModelLoadContext ctx)
{
Contracts.AssertValue(ctx);
Expand Down Expand Up @@ -275,6 +278,36 @@ internal ColumnInfo(ModelLoadContext ctx)
AllLengths = ctx.Reader.ReadBoolByte();
}

/// <summary>
/// Builds a column description whose input and output column names are supplied by the caller,
/// while the remaining settings are read from <paramref name="ctx"/>.
/// </summary>
/// <param name="ctx">Model load context positioned at the per-column settings.</param>
/// <param name="inputs">Names of the input columns; none may be null, empty or whitespace.</param>
/// <param name="output">Name of the output column.</param>
internal ColumnInfo(ModelLoadContext ctx, string[] inputs, string output)
{
    Contracts.AssertValue(ctx);
    Contracts.CheckValue(inputs, nameof(inputs));
    bool everyNameUsable = inputs.All(name => !string.IsNullOrWhiteSpace(name));
    Contracts.CheckParam(everyNameUsable, nameof(inputs), "Contained some null or empty items");

    // The column names are not part of what this overload reads from the stream.
    Inputs = inputs;
    Output = output;

    // *** Binary format ***
    // int: NgramLength
    // int: SkipLength
    // int: HashBits
    // uint: Seed
    // byte: Rehash
    // byte: Ordered
    // byte: AllLengths
    var reader = ctx.Reader;

    NgramLength = reader.ReadInt32();
    Contracts.CheckDecode(NgramLength > 0 && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength);

    SkipLength = reader.ReadInt32();
    Contracts.CheckDecode(SkipLength >= 0 && SkipLength <= NgramBufferBuilder.MaxSkipNgramLength);
    // Ngram length and skip length together must fit within the maximum as well.
    Contracts.CheckDecode(SkipLength <= NgramBufferBuilder.MaxSkipNgramLength - NgramLength);

    HashBits = reader.ReadInt32();
    Contracts.CheckDecode(HashBits >= 1 && HashBits <= 30);

    Seed = reader.ReadUInt32();
    RehashUnigrams = reader.ReadBoolByte();
    Ordered = reader.ReadBoolByte();
    AllLengths = reader.ReadBoolByte();
}

internal void Save(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
Expand Down Expand Up @@ -416,19 +449,56 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx,
/// <summary>
/// Creates the transformer from <paramref name="ctx"/> and returns its row mapper
/// for <paramref name="inputSchema"/>.
/// </summary>
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema)
{
    var transformer = Create(env, ctx);
    return transformer.MakeRowMapper(inputSchema);
}

/// <summary>
/// Deserializing constructor. Reads the column descriptions and slot names from <paramref name="ctx"/>.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="ctx">Model load context positioned at this transformer's model.</param>
/// <param name="loadLegacy">
/// When true, the stream is read in the older layout: a leading float size, then the output and
/// input column names of every column, then the remaining per-column settings.
/// When false (the default), each column is read in full by <c>ColumnInfo(ModelLoadContext)</c>.
/// </param>
private NgramHashingTransformer(IHostEnvironment env, ModelLoadContext ctx, bool loadLegacy = false) :
    base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NgramHashingTransformer)))
{
    Host.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    if (loadLegacy)
    {
        // Legacy models start with the size of a float.
        int cbFloat = ctx.Reader.ReadInt32();
        Host.CheckDecode(cbFloat == sizeof(float));
    }

    var columnsLength = ctx.Reader.ReadInt32();
    Host.CheckDecode(columnsLength > 0);
    var columns = new ColumnInfo[columnsLength];

    if (!loadLegacy)
    {
        // *** Binary format ***
        // int number of columns
        // columns
        for (int i = 0; i < columnsLength; i++)
            columns[i] = new ColumnInfo(ctx);
    }
    else
    {
        // *** Binary format ***
        // int: number of added columns
        // for each added column
        //   int: id of output column name
        //   int: number of input column names
        //   int[]: ids of input column names
        var outputs = new string[columnsLength];
        var inputs = new string[columnsLength][];
        for (int i = 0; i < columnsLength; i++)
        {
            outputs[i] = ctx.LoadNonEmptyString();

            int csrc = ctx.Reader.ReadInt32();
            Host.CheckDecode(csrc > 0);
            inputs[i] = new string[csrc];
            for (int j = 0; j < csrc; j++)
                inputs[i][j] = ctx.LoadNonEmptyString();
        }

        // *** Binary format ***
        // columns (settings only; the names were read above)
        // All names must be read before any column settings, so this is a second pass.
        for (int i = 0; i < columnsLength; i++)
            columns[i] = new ColumnInfo(ctx, inputs[i], outputs[i]);
    }

    _columns = columns.ToImmutableArray();
    TextModelHelper.LoadAll(Host, ctx, columnsLength, out _slotNames, out _slotNamesTypes);
}
Expand Down Expand Up @@ -469,7 +539,8 @@ private static NgramHashingTransformer Create(IHostEnvironment env, ModelLoadCon
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register(nameof(NgramHashingTransformer));
return new NgramHashingTransformer(host, ctx);
ctx.CheckAtModel(GetVersionInfo());
return new NgramHashingTransformer(host, ctx, ctx.Header.ModelVerWritten < VersionTransformer);
}

/// <summary>
/// Creates a <see cref="Mapper"/> for this transformer over the given <paramref name="schema"/>.
/// </summary>
private protected override IRowMapper MakeRowMapper(Schema schema)
{
    return new Mapper(this, schema);
}
Expand Down
22 changes: 20 additions & 2 deletions test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Model;
using Microsoft.ML.RunTests;
using Microsoft.ML.StaticPipe;
using Microsoft.ML.Tools;
Expand Down Expand Up @@ -172,7 +173,7 @@ public void StopWordsRemoverFromFactory()
string sentimentDataPath = GetDataPath("wikipedia-detox-250-line-data.tsv");
var data = TextLoader.Create(ML, new TextLoader.Arguments()
{
Column = new []
Column = new[]
{
new TextLoader.Column("Text", DataKind.TX, 1)
}
Expand Down Expand Up @@ -212,7 +213,7 @@ public void WordBagWorkout()
.Read(sentimentDataPath);

var est = new WordBagEstimator(Env, "text", "bag_of_words").
Append(new WordHashBagEstimator(Env, "text", "bag_of_wordshash", invertHash:-1));
Append(new WordHashBagEstimator(Env, "text", "bag_of_wordshash", invertHash: -1));

// The following call fails because of the following issue
// https://github.com/dotnet/machinelearning/issues/969
Expand Down Expand Up @@ -269,6 +270,23 @@ public void NgramWorkout()
Done();
}

[Fact]
void TestNgramCompatColumns()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the new test! Should we be testing the specific output values too? (or is this somewhere I'm not seeing it?)

{
string dropModelPath = GetDataPath("backcompat/ngram.zip");
string sentimentDataPath = GetDataPath("wikipedia-detox-250-line-data.tsv");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"wikipedia-detox-250-line-data.tsv" [](start = 51, length = 35)

Should we stick to TestDatasets.Sentiment.trainFilename?

var data = TextLoader.CreateReader(ML, ctx => (
Sentiment: ctx.LoadBool(0),
SentimentText: ctx.LoadText(1)), hasHeader: true)
.Read(sentimentDataPath);
using (FileStream fs = File.OpenRead(dropModelPath))
{
var result = ModelFileUtils.LoadTransforms(Env, data.AsDynamic, fs);
var featureColumn = result.Schema.GetColumnOrNull("Features");
Assert.NotNull(featureColumn);
}
}

[Fact]
public void LdaWorkout()
{
Expand Down
Binary file added test/data/backcompat/ngram.zip
Binary file not shown.