Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/Microsoft.ML.Data/Transforms/KeyToValue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -505,19 +505,20 @@ public override bool SaveOnnx(OnnxContext ctx, string srcVariableName, string ds
// ONNX expects the input keys to be int64s, but the input data can come from an ML.NET node that
// may output a uint32, so cast it here to ensure that the data is treated correctly.
opType = "Cast";
var castNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "CastNodeOutput");
var srcShape = (int)ctx.RetrieveShapeOrNull(srcVariableName)[1];
var castNodeOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, srcShape), "CastNodeOutput");
var castNode = ctx.CreateNode(opType, srcVariableName, castNodeOutput, ctx.GetNodeName(opType), "");
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Int64).ToType();
castNode.AddAttribute("to", t);

var labelEncoderOutput = dstVariableName;
var labelEncoderInput = srcVariableName;
if (TypeOutput == NumberDataViewType.Double || TypeOutput == BooleanDataViewType.Instance)
labelEncoderOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "CastNodeOutput");
labelEncoderOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, srcShape), "CastNodeOutput");
else if (TypeOutput == NumberDataViewType.Int64 || TypeOutput == NumberDataViewType.UInt16 ||
TypeOutput == NumberDataViewType.Int32 || TypeOutput == NumberDataViewType.Int16 ||
TypeOutput == NumberDataViewType.UInt64 || TypeOutput == NumberDataViewType.UInt32)
labelEncoderOutput = ctx.AddIntermediateVariable(TextDataViewType.Instance, "CastNodeOutput");
labelEncoderOutput = ctx.AddIntermediateVariable(new VectorDataViewType(TextDataViewType.Instance, srcShape), "CastNodeOutput");

opType = "LabelEncoder";
var node = ctx.CreateNode(opType, castNodeOutput, labelEncoderOutput, ctx.GetNodeName(opType));
Expand Down
24 changes: 14 additions & 10 deletions src/Microsoft.ML.Data/Transforms/KeyToVector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -689,27 +689,31 @@ private JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo info, JToke

/// <summary>
/// Emits the ONNX nodes for one column: a "Cast" of <paramref name="srcVariableName"/> to int64, an
/// "OneHotEncoder" over the key range of the source item type, and a final reduction node that writes
/// <paramref name="dstVariableName"/>.
/// </summary>
/// <param name="ctx">ONNX context used to look up shapes, add intermediate variables and create nodes.</param>
/// <param name="iinfo">Index into <c>_parent._columns</c> for the column being exported.</param>
/// <param name="info">Column info; its <c>TypeSrc</c> supplies the value count and the key count.</param>
/// <param name="srcVariableName">Name of the ONNX variable fed into the Cast node.</param>
/// <param name="dstVariableName">Name of the ONNX variable that receives the final output.</param>
// NOTE(review): this listing appears to interleave the removed and the added side of a diff
// (castOutput and encodedVariableName are each declared twice, and both a Squeeze and a ReduceSum
// node write dstVariableName) - confirm which lines are current before relying on it.
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
var shape = ctx.RetrieveShapeOrNull(srcVariableName);
// The shape must be present in order to calculate the reduction axes. It is generally not null
// because the inputs and outputs of a transform are declared with shapes.
Contracts.CheckValue(shape, nameof(shape));
// Value count of the source column type; used below as a dimension of the intermediate variables.
var dim = info.TypeSrc.GetValueCount();

string opType = "Cast";
var castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, opType);
var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, dim), opType);
var castNode = ctx.CreateNode(opType, srcVariableName, castOutput, ctx.GetNodeName(opType), "");
// Target element type of the cast is long (int64).
castNode.AddAttribute("to", typeof(long));

opType = "OneHotEncoder";
var isOutputCountVector = _parent._columns[iinfo].OutputCountVector;
// Key count of the source item type; it sizes the encoded dimension and the category list.
var categoryRange = info.TypeSrc.GetItemType().GetKeyCountAsInt32(Host);
var encodedVariableName = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, 1, categoryRange), "encoded");
var typeShape = new VectorDataViewType(NumberDataViewType.Single, dim, categoryRange);

// An intermediate "encoded" variable is only introduced when a count vector is requested for a
// vector-typed source; otherwise the encoder writes straight to dstVariableName.
var encodedVariableName = (isOutputCountVector && info.TypeSrc is VectorDataViewType) ?
ctx.AddIntermediateVariable(typeShape, "encoded") : dstVariableName;
var node = ctx.CreateNode(opType, castOutput, encodedVariableName, ctx.GetNodeName(opType));
// Categories are the int64 values 1..categoryRange.
node.AddAttribute("cats_int64s", Enumerable.Range(1, categoryRange).Select(x => (long)x));
node.AddAttribute("zeros", true);

// OneHotEncoder adds one additional dimension, so we remove it below
opType = "Squeeze";
var reduceNode = ctx.CreateNode(opType, encodedVariableName, dstVariableName, ctx.GetNodeName(opType), "");
// The squeezed axis is the last index of the source shape.
reduceNode.AddAttribute("axes", new long[] { shape.Count - 1 });
if (_parent._columns[iinfo].OutputCountVector && info.TypeSrc is VectorDataViewType)
{
// Sum the encoded values over axis 1 and drop that axis (keepdims = 0).
opType = "ReduceSum";
var reduceNode = ctx.CreateNode(opType, encodedVariableName, dstVariableName, ctx.GetNodeName(opType), "");
reduceNode.AddAttribute("axes", new long[] { 1 });
reduceNode.AddAttribute("keepdims", 0);
}
}
}
}
Expand Down
14 changes: 10 additions & 4 deletions src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,8 @@ private IEnumerable<T> GetTermsAndIds<T>(int iinfo, out long[] termIds)
private void CastInputToString<T>(OnnxContext ctx, out OnnxNode node, out long[] termIds, string srcVariableName, int iinfo,
string opType, string labelEncoderOutput)
{
var castOutput = ctx.AddIntermediateVariable(TextDataViewType.Instance, "castOutput");
var srcShape = ctx.RetrieveShapeOrNull(srcVariableName);
var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(TextDataViewType.Instance, (int)srcShape[1]), "castOutput");
var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), "");
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.String).ToType();
castNode.AddAttribute("to", t);
Expand All @@ -799,21 +800,23 @@ private void CastInputToString<T>(OnnxContext ctx, out OnnxNode node, out long[]
/// <summary>
/// Creates a "Cast" node that converts <paramref name="srcVariableName"/> to Single, then creates the
/// <paramref name="opType"/> node that consumes the cast output, writes to
/// <paramref name="labelEncoderOutput"/> and carries the terms as its "keys_floats" attribute.
/// </summary>
/// <typeparam name="T">Term type passed to <c>GetTermsAndIds</c>; each term is converted with <c>Convert.ToSingle</c>.</typeparam>
/// <param name="ctx">ONNX context used to look up shapes, add intermediate variables and create nodes.</param>
/// <param name="node">Receives the node created for <paramref name="opType"/>.</param>
/// <param name="termIds">Receives the term ids produced by <c>GetTermsAndIds</c>.</param>
/// <param name="srcVariableName">Name of the ONNX variable fed into the Cast node.</param>
/// <param name="iinfo">Column index forwarded to <c>GetTermsAndIds</c>.</param>
/// <param name="opType">Operator type of the node created after the cast.</param>
/// <param name="labelEncoderOutput">Name of the output variable of the created node.</param>
// NOTE(review): this listing appears to interleave the removed and the added side of a diff
// (castOutput is declared twice) - confirm which declaration is current.
private void CastInputToFloat<T>(OnnxContext ctx, out OnnxNode node, out long[] termIds, string srcVariableName, int iinfo,
string opType, string labelEncoderOutput)
{
var castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "castOutput");
// NOTE(review): the name RetrieveShapeOrNull suggests the result can be null, yet srcShape[1] is
// indexed below without a check - TODO confirm a shape is always registered for the source variable.
var srcShape = ctx.RetrieveShapeOrNull(srcVariableName);
var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, (int)srcShape[1]), "castOutput");
var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), "");
// Target type of the cast is the runtime type that corresponds to DataKind.Single.
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
castNode.AddAttribute("to", t);
node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType));
var terms = GetTermsAndIds<T>(iinfo, out termIds);
node.AddAttribute("keys_floats", terms.Select(item => Convert.ToSingle(item)));
}

private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
OnnxNode node;
long[] termIds;
string opType = "LabelEncoder";
OnnxNode castNode;
var labelEncoderOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "LabelEncoderOutput");
var labelEncoderOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, _types[iinfo].GetValueCount()), "LabelEncoderOutput");

var type = info.TypeSrc.GetItemType();
if (type.Equals(TextDataViewType.Instance))
Expand Down Expand Up @@ -876,7 +879,10 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
return false;
}

node.AddAttribute("default_int64", -1);
// Unknown keys should map to 0.
node.AddAttribute("default_int64", 0);
node.AddAttribute("default_string", "0");
node.AddAttribute("default_float", 0f);
node.AddAttribute("values_int64s", termIds);

// Onnx outputs an Int64, but ML.NET outputs a keytype. So cast it here
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,17 @@
},
{
"name": "default_int64",
"i": "-1",
"type": "INT"
},
{
"name": "default_string",
"s": "MA==",
"type": "STRING"
},
{
"name": "default_float",
"type": "FLOAT"
},
{
"name": "values_int64s",
"ints": [
Expand Down Expand Up @@ -94,7 +102,7 @@
"Cast"
],
"output": [
"encoded"
"F21"
],
"name": "OneHotEncoder",
"opType": "OneHotEncoder",
Expand Down Expand Up @@ -123,25 +131,6 @@
],
"domain": "ai.onnx.ml"
},
{
"input": [
"encoded"
],
"output": [
"F21"
],
"name": "Squeeze",
"opType": "Squeeze",
"attribute": [
{
"name": "axes",
"ints": [
"1"
],
"type": "INTS"
}
]
},
{
"input": [
"F21"
Expand Down Expand Up @@ -673,27 +662,6 @@
}
}
},
{
"name": "encoded",
"type": {
"tensorType": {
"elemType": 1,
"shape": {
"dim": [
{
"dimValue": "-1"
},
{
"dimValue": "1"
},
{
"dimValue": "10"
}
]
}
}
}
},
{
"name": "F22",
"type": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,17 @@
},
{
"name": "default_int64",
"i": "-1",
"type": "INT"
},
{
"name": "default_string",
"s": "MA==",
"type": "STRING"
},
{
"name": "default_float",
"type": "FLOAT"
},
{
"name": "values_int64s",
"ints": [
Expand Down Expand Up @@ -92,7 +100,7 @@
"Cast"
],
"output": [
"encoded"
"F21"
],
"name": "OneHotEncoder",
"opType": "OneHotEncoder",
Expand Down Expand Up @@ -120,25 +128,6 @@
],
"domain": "ai.onnx.ml"
},
{
"input": [
"encoded"
],
"output": [
"F21"
],
"name": "Squeeze",
"opType": "Squeeze",
"attribute": [
{
"name": "axes",
"ints": [
"1"
],
"type": "INTS"
}
]
},
{
"input": [
"F1",
Expand Down Expand Up @@ -1022,27 +1011,6 @@
}
}
},
{
"name": "encoded",
"type": {
"tensorType": {
"elemType": 1,
"shape": {
"dim": [
{
"dimValue": "-1"
},
{
"dimValue": "1"
},
{
"dimValue": "9"
}
]
}
}
}
},
{
"name": "VectorFeaturizerOutput",
"type": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,17 @@
},
{
"name": "default_int64",
"i": "-1",
"type": "INT"
},
{
"name": "default_string",
"s": "MA==",
"type": "STRING"
},
{
"name": "default_float",
"type": "FLOAT"
},
{
"name": "values_int64s",
"ints": [
Expand Down
Loading