Added onnx export support for WordTokenizingTransformer and NgramExtractingTransformer #4451

Merged 3 commits on Nov 13, 2019
23 changes: 11 additions & 12 deletions src/Microsoft.ML.Data/Transforms/KeyToVector.cs
@@ -606,16 +606,11 @@ public void SaveAsOnnx(OnnxContext ctx)
ColInfo info = _infos[iinfo];
string inputColumnName = info.InputColumnName;
if (!ctx.ContainsColumn(inputColumnName))
{
ctx.RemoveColumn(info.Name, false);
continue;
}

if (!SaveAsOnnxCore(ctx, iinfo, info, ctx.GetVariableName(inputColumnName),
ctx.AddIntermediateVariable(_types[iinfo], info.Name)))
{
ctx.RemoveColumn(info.Name, true);
}
var srcVariableName = ctx.GetVariableName(inputColumnName);
var dstVariableName = ctx.AddIntermediateVariable(_types[iinfo], info.Name);
SaveAsOnnxCore(ctx, iinfo, info, srcVariableName, dstVariableName);
}
}

@@ -692,7 +687,7 @@ private JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo info, JToke
PfaUtils.Call("cast.fanoutDouble", -1, 0, keyCount, false), PfaUtils.FuncRef("u." + funcName));
}

private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
var shape = ctx.RetrieveShapeOrNull(srcVariableName);
// Make sure the shape is present for calculating the reduction axes. The shape here is generally not null
@@ -703,8 +698,13 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
// default ONNX LabelEncoder just matches the behavior of Bag=false.
var encodedVariableName = _parent._columns[iinfo].OutputCountVector ? ctx.AddIntermediateVariable(null, "encoded", true) : dstVariableName;

string opType = "OneHotEncoder";
var node = ctx.CreateNode(opType, srcVariableName, encodedVariableName, ctx.GetNodeName(opType));
string opType = "Cast";
var castOutput = ctx.AddIntermediateVariable(info.TypeSrc, opType, true);
var castNode = ctx.CreateNode(opType, srcVariableName, castOutput, ctx.GetNodeName(opType), "");
castNode.AddAttribute("to", typeof(long));

opType = "OneHotEncoder";
var node = ctx.CreateNode(opType, castOutput, encodedVariableName, ctx.GetNodeName(opType));
node.AddAttribute("cats_int64s", Enumerable.Range(0, info.TypeSrc.GetItemType().GetKeyCountAsInt32(Host)).Select(x => (long)x));
node.AddAttribute("zeros", true);
if (_parent._columns[iinfo].OutputCountVector)
@@ -717,7 +717,6 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
reduceNode.AddAttribute("axes", new long[] { shape.Count - 1 });
reduceNode.AddAttribute("keepdims", 0);
}
return true;
}
}
}
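A note on the KeyToVector change: the exporter no longer returns a success flag and always emits a Cast to int64 ahead of the OneHotEncoder node, plus a ReduceSum when OutputCountVector (bag mode) is set. As a rough mental model of what that exported chain computes for one row of a bag column, here is a minimal self-contained sketch; it is illustrative only, with hypothetical key values, and is not code from this PR.

```csharp
using System;

// Illustrative sketch only (not ML.NET or ONNX code): what the exported
// Cast -> OneHotEncoder -> ReduceSum chain computes for one row of a bag
// (OutputCountVector) column. The key values and keyCount below are hypothetical.
static class KeyToVectorSketch
{
    public static long[] BagOfKeys(uint[] keys, int keyCount)
    {
        var counts = new long[keyCount];
        foreach (var key in keys)
        {
            long k = key;       // the Cast node: OneHotEncoder consumes int64 values
            if (k < keyCount)   // "zeros" = true: unmatched categories become all-zero rows
                counts[k]++;    // ReduceSum over the slot axis accumulates the one-hot rows
        }
        return counts;
    }

    public static void Main()
        => Console.WriteLine(string.Join(", ", BagOfKeys(new uint[] { 2, 0, 2, 5 }, keyCount: 4)));
}
```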
70 changes: 59 additions & 11 deletions src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs
@@ -768,22 +768,70 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b

private Delegate MakeGetter<T>(DataViewRow row, int src) => _termMap[src].GetMappingGetter(row);

private IEnumerable<T> GetTermsAndIds<T>(int iinfo, out long[] termIds)
{
var terms = default(VBuffer<T>);
var map = (TermMap<T>)_termMap[iinfo].Map;
map.GetTerms(ref terms);

var termValues = terms.DenseValues();
var keyMapper = map.GetKeyMapper();

int i = 0;
termIds = new long[map.Count];
foreach (var term in termValues)
{
uint id = 0;
keyMapper(term, ref id);
termIds[i++] = id;
}
return termValues;
}

private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
if (!(info.TypeSrc.GetItemType() is TextDataViewType))
OnnxNode node;
long[] termIds;
string opType = "LabelEncoder";
var labelEncoderOutput = ctx.AddIntermediateVariable(_types[iinfo], "LabelEncoderOutput", true);

if (info.TypeSrc.GetItemType().Equals(TextDataViewType.Instance))
{
node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType));
var terms = GetTermsAndIds<ReadOnlyMemory<char>>(iinfo, out termIds);
node.AddAttribute("keys_strings", terms);
}
else if (info.TypeSrc.GetItemType().Equals(NumberDataViewType.Single))
{
node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType));
var terms = GetTermsAndIds<float>(iinfo, out termIds);
node.AddAttribute("keys_floats", terms);
}
else
{
// LabelEncoder-2 in ORT v1 only supports the following mappings
// int64-> float
// int64-> string
// float -> int64
// float -> string
// string -> int64
// string -> float
// In ML.NET the output of ValueToKeyMappingTransformer is always an integer type.
// Therefore the only input types we can accept for Onnx conversion are strings and floats handled above.
return false;
}

var terms = default(VBuffer<ReadOnlyMemory<char>>);
TermMap<ReadOnlyMemory<char>> map = (TermMap<ReadOnlyMemory<char>>)_termMap[iinfo].Map;
map.GetTerms(ref terms);
string opType = "LabelEncoder";
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
node.AddAttribute("classes_strings", terms.DenseValues());
node.AddAttribute("default_int64", -1);
// default_string needs to be an empty string but there is a BUG in Lotus that
// throws a validation error when default_string is empty. As a workaround, set
// default_string to a space.
node.AddAttribute("default_string", " ");
node.AddAttribute("values_int64s", termIds);

// ONNX outputs an Int64, but ML.NET outputs a key type, so cast it here.
InternalDataKind dataKind;
InternalDataKindExtensions.TryGetDataKind(_parent._unboundMaps[iinfo].OutputType.RawType, out dataKind);

opType = "Cast";
var castNode = ctx.CreateNode(opType, labelEncoderOutput, dstVariableName, ctx.GetNodeName(opType), "");
castNode.AddAttribute("to", dataKind.ToType());

return true;
}
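To make the new attribute pairing concrete: GetTermsAndIds walks the fitted TermMap and returns the terms alongside their key ids, which become the LabelEncoder's keys_strings (or keys_floats) and values_int64s attributes; the trailing Cast then converts the encoder's int64 output back to the column's key type. Below is a small self-contained sketch of that pairing with a made-up vocabulary; in the transformer the terms and ids come from the trained map.

```csharp
using System;
using System.Linq;

// Hypothetical vocabulary for illustration; not code from this PR.
class LabelEncoderAttributeSketch
{
    static void Main()
    {
        var vocabulary = new[] { "red", "green", "blue" };

        // ML.NET key ids are 1-based (0 is reserved for missing values).
        string[] keysStrings = vocabulary;                                 // -> "keys_strings"
        long[] valuesInt64s = Enumerable.Range(1, vocabulary.Length)
                                        .Select(i => (long)i).ToArray();   // -> "values_int64s"

        // LabelEncoder maps "red" -> 1 and "blue" -> 3; unseen terms get the default value.
        // The Cast node then turns the int64 result back into the UInt32-backed key type.
        Console.WriteLine(string.Join(", ",
            keysStrings.Zip(valuesInt64s, (term, id) => $"{term} -> {id}")));
    }
}
```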

19 changes: 15 additions & 4 deletions src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs
@@ -433,9 +433,8 @@ public static NamedOnnxValue CreateScalarNamedOnnxValue<T>(string name, T data)
throw new NotImplementedException($"Not implemented type {typeof(T)}");

if (typeof(T) == typeof(ReadOnlyMemory<char>))
{
return NamedOnnxValue.CreateFromTensor<string>(name, new DenseTensor<string>(new string[] { data.ToString() }, new int[] { 1, 1 }, false));
}
return NamedOnnxValue.CreateFromTensor<string>(name, new DenseTensor<string>(new string[] { data.ToString() }, new int[] { 1, 1 }));

return NamedOnnxValue.CreateFromTensor<T>(name, new DenseTensor<T>(new T[] { data }, new int[] { 1, 1 }));
}

@@ -452,7 +451,19 @@ public static NamedOnnxValue CreateNamedOnnxValue<T>(string name, ReadOnlySpan<T
{
if (!_onnxTypeMap.Contains(typeof(T)))
throw new NotImplementedException($"Not implemented type {typeof(T)}");
return NamedOnnxValue.CreateFromTensor<T>(name, new DenseTensor<T>(data.ToArray(), shape.Select(x => (int)x).ToArray()));

var dimensions = shape.Select(x => (int)x).ToArray();

if (typeof(T) == typeof(ReadOnlyMemory<char>))
{
string[] stringData = new string[data.Length];
for (int i = 0; i < data.Length; i++)
stringData[i] = data[i].ToString();

return NamedOnnxValue.CreateFromTensor<string>(name, new DenseTensor<string>(stringData, dimensions));
}

return NamedOnnxValue.CreateFromTensor<T>(name, new DenseTensor<T>(data.ToArray(), dimensions));
}

/// <summary>
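The new string branches exist because ONNX Runtime tensors cannot be created over ReadOnlyMemory<char>, so the values are copied into a string[] first. A rough usage sketch of that conversion follows; it assumes the Microsoft.ML.OnnxRuntime package (namespace names may differ by version) and uses illustrative names rather than the internal helper itself.

```csharp
using System;
using System.Linq;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;

class StringTensorSketch
{
    // Mirrors the ReadOnlyMemory<char> branch above: copy each span into a string,
    // then wrap the array in a DenseTensor<string> with the caller-supplied shape.
    static NamedOnnxValue ToStringTensor(string name, ReadOnlyMemory<char>[] data, int[] dimensions)
    {
        var stringData = data.Select(x => x.ToString()).ToArray();
        return NamedOnnxValue.CreateFromTensor(name, new DenseTensor<string>(stringData, dimensions));
    }

    static void Main()
    {
        var value = ToStringTensor("Text",
            new[] { "hello".AsMemory(), "world".AsMemory() },
            new[] { 1, 2 });
        Console.WriteLine(value.Name);
    }
}
```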
158 changes: 157 additions & 1 deletion src/Microsoft.ML.Transforms/Text/NgramTransform.cs
@@ -12,6 +12,7 @@
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms.Text;

@@ -124,6 +125,7 @@ private sealed class TransformInfo
public readonly bool[] NonEmptyLevels;
public readonly int NgramLength;
public readonly int SkipLength;
public readonly bool UseAllLengths;
public readonly NgramExtractingEstimator.WeightingCriteria Weighting;

public bool RequireIdf => Weighting == NgramExtractingEstimator.WeightingCriteria.Idf || Weighting == NgramExtractingEstimator.WeightingCriteria.TfIdf;
@@ -133,6 +135,7 @@ public TransformInfo(NgramExtractingEstimator.ColumnOptions info)
NgramLength = info.NgramLength;
SkipLength = info.SkipLength;
Weighting = info.Weighting;
UseAllLengths = info.UseAllLengths;
NonEmptyLevels = new bool[NgramLength];
}

@@ -469,7 +472,7 @@ private protected override void SaveModel(ModelSaveContext ctx)

private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);

private sealed class Mapper : OneToOneMapperBase
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
{
private readonly DataViewType[] _srcTypes;
private readonly int[] _srcCols;
@@ -551,6 +554,81 @@ private void GetSlotNames(int iinfo, int size, ref VBuffer<ReadOnlyMemory<char>>
dst = dstEditor.Commit();
}

private IEnumerable<long> GetNgramData(int iinfo, out long[] ngramCounts, out double[] weights, out List<long> indexes)
{
var transformInfo = _parent._transformInfos[iinfo];
var itemType = _srcTypes[iinfo].GetItemType();

Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);
Host.Assert(InputSchema[_srcCols[iinfo]].HasKeyValues());

// Get the key values of the unigrams.
var keyCount = itemType.GetKeyCountAsInt32(Host);

var maxNGramLength = transformInfo.NgramLength;

var pool = _parent._ngramMaps[iinfo];

// the ngrams in ML.NET are sequentially organized. e.g. {a, a|b, b, b|c...}
// in onnx, they need to be separated by type. e.g. {a, b, c, a|b, b|c...}
// since the resulting vectors need to match, we need to create a mapping
// between the two and store it in the node attributes

// create separate lists to track the ids of 1-grams, 2-grams, etc.
// because they need to be in adjacent regions in the same list
// when supplied to onnx
// We later concatenate all these separate n-gram lists
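// Hypothetical worked example: for an ML.NET pool ordered {a, a|b, b, b|c}
// (slot indexes 0..3), the ONNX attributes come out as
//   pool_int64s   = [idA, idB, idA, idB, idB, idC]  (all 1-grams first, then all 2-grams)
//   ngram_counts  = [0, 2]                          (start offset of each n-gram block)
//   ngram_indexes = [0, 2, 1, 3]                    (ONNX position -> original ML.NET slot)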
var ngramIds = new List<long>[maxNGramLength];
var ngramIndexes = new List<long>[maxNGramLength];
for (int i = 0; i < ngramIds.Length; i++)
{
ngramIds[i] = new List<long>();
ngramIndexes[i] = new List<long>();
//ngramWeights[i] = new List<float>();
}

weights = new double[pool.Count];

uint[] ngram = new uint[maxNGramLength];
for (int i = 0; i < pool.Count; i++)
{
var n = pool.GetById(i, ref ngram);
Host.Assert(n >= 0);

// add the id of each gram to the corresponding ids list
for (int j = 0; j < n; j++)
ngramIds[n - 1].Add(ngram[j]);

// add the indexes to the corresponding list
ngramIndexes[n - 1].Add(i);

if (transformInfo.RequireIdf)
weights[i] = _parent._invDocFreqs[iinfo][i];
else
weights[i] = 1.0f;
}

// initialize the ngramCounts array with start-index of each n-gram
int start = 0;
ngramCounts = new long[maxNGramLength];
for (int i = 0; i < maxNGramLength; i++)
{
ngramCounts[i] = start;
start += ngramIds[i].Count;
}

// concatenate all the lists and return
IEnumerable<long> allNGramIds = ngramIds[0];
indexes = ngramIndexes[0];
for (int i = 1; i < maxNGramLength; i++)
{
allNGramIds = Enumerable.Concat(allNGramIds, ngramIds[i]);
indexes = indexes.Concat(ngramIndexes[i]).ToList();
}

return allNGramIds;
}

private void ComposeNgramString(uint[] ngram, int count, StringBuilder sb, int keyCount, in VBuffer<ReadOnlyMemory<char>> terms)
{
Host.AssertValue(sb);
@@ -660,6 +738,84 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
}
return del;
}

public bool CanSaveOnnx(OnnxContext ctx) => true;

public void SaveAsOnnx(OnnxContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));

int numColumns = _parent.ColumnPairs.Length;
for (int iinfo = 0; iinfo < numColumns; ++iinfo)
{
string inputColumnName = _parent.ColumnPairs[iinfo].inputColumnName;
if (!ctx.ContainsColumn(inputColumnName))
continue;

string outputColumnName = _parent.ColumnPairs[iinfo].outputColumnName;
string dstVariableName = ctx.AddIntermediateVariable(_srcTypes[iinfo], outputColumnName, true);
SaveAsOnnxCore(ctx, iinfo, ctx.GetVariableName(inputColumnName), dstVariableName);
}
}

private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
VBuffer<ReadOnlyMemory<char>> slotNames = default;
GetSlotNames(iinfo, 0, ref slotNames);

var transformInfo = _parent._transformInfos[iinfo];

// TfIdfVectorizer accepts string, int32 and int64 tensors.
// The ML.NET NgramTransform, however, only accepts keys as input, i.e. the result
// of the ValueToKeyMapping transformer, which outputs UInt32 values.
// So, if the input is UInt32 or UInt64, cast it here to Int32/Int64.
string opType;
var vectorType = _srcTypes[iinfo] as VectorDataViewType;

if ((vectorType != null) &&
((vectorType.RawType == typeof(VBuffer<UInt32>)) || (vectorType.RawType == typeof(VBuffer<UInt64>))))
{
var dataKind = vectorType.RawType == typeof(VBuffer<UInt32>) ? DataKind.Int32 : DataKind.Int64;

opType = "Cast";
string castOutput = ctx.AddIntermediateVariable(_srcTypes[iinfo], "CastOutput", true);

var castNode = ctx.CreateNode(opType, srcVariableName, castOutput, ctx.GetNodeName(opType), "");
var t = InternalDataKindExtensions.ToInternalDataKind(dataKind).ToType();
castNode.AddAttribute("to", t);

srcVariableName = castOutput;
}

opType = "TfIdfVectorizer";
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType), "");
node.AddAttribute("max_gram_length", transformInfo.NgramLength);
node.AddAttribute("max_skip_count", transformInfo.SkipLength);
node.AddAttribute("min_gram_length", transformInfo.UseAllLengths ? 1 : transformInfo.NgramLength);

string mode;
if (transformInfo.RequireIdf)
{
mode = transformInfo.Weighting == NgramExtractingEstimator.WeightingCriteria.Idf ? "IDF" : "TFIDF";
}
else
{
mode = "TF";
}
node.AddAttribute("mode", mode);

long[] ngramCounts;
double[] ngramWeights;
List<long> ngramIndexes;

var ngramIds = GetNgramData(iinfo, out ngramCounts, out ngramWeights, out ngramIndexes);

node.AddAttribute("ngram_counts", ngramCounts);
node.AddAttribute("pool_int64s", ngramIds);
node.AddAttribute("ngram_indexes", ngramIndexes);
node.AddAttribute("weights", ngramWeights);
}

}
}

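Taken together, the changes above let a typical text pipeline (WordTokenizingTransformer, then ValueToKeyMappingTransformer, then NgramExtractingTransformer) be exported to ONNX, which is what the PR title advertises. The following is a minimal sketch of exercising that path; the data, column names, and output file name are made up, and ConvertToOnnx assumes the Microsoft.ML.OnnxConverter package is referenced.

```csharp
using System.IO;
using Microsoft.ML;

// Minimal, illustrative sketch of the newly exportable path:
// tokenize -> map tokens to keys -> extract n-grams -> export to ONNX.
// All column names, data and the file name are hypothetical.
class Input
{
    public string Text { get; set; }
}

class Program
{
    static void Main()
    {
        var ml = new MLContext();
        var data = ml.Data.LoadFromEnumerable(new[]
        {
            new Input { Text = "the quick brown fox" },
            new Input { Text = "the lazy dog" }
        });

        var pipeline = ml.Transforms.Text.TokenizeIntoWords("Tokens", "Text")
            .Append(ml.Transforms.Conversion.MapValueToKey("Keys", "Tokens"))
            .Append(ml.Transforms.Text.ProduceNgrams("Ngrams", "Keys",
                ngramLength: 2, useAllLengths: true));

        var model = pipeline.Fit(data);

        // The NgramExtractingTransformer mapper now emits a TfIdfVectorizer node
        // (with an optional Cast), per the NgramTransform.cs changes above.
        using (var stream = File.Create("ngram-pipeline.onnx"))
            ml.Model.ConvertToOnnx(model, data, stream);
    }
}
```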