Skip to content

Commit 145437f

Browse files
committed
fix for key2value
1 parent 40f0298 commit 145437f

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

src/Microsoft.ML.Data/Transforms/KeyToValue.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -505,19 +505,20 @@ public override bool SaveOnnx(OnnxContext ctx, string srcVariableName, string ds
505505
// Onnx expects the input keys to be int64s. But the input data can come from an ML.NET node that
506506
// may output a uint32. So cast it here to ensure that the data is treated correctly
507507
opType = "Cast";
508-
var castNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "CastNodeOutput");
508+
var srcShape = (int)ctx.RetrieveShapeOrNull(srcVariableName)[1];
509+
var castNodeOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, srcShape), "CastNodeOutput");
509510
var castNode = ctx.CreateNode(opType, srcVariableName, castNodeOutput, ctx.GetNodeName(opType), "");
510511
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Int64).ToType();
511512
castNode.AddAttribute("to", t);
512513

513514
var labelEncoderOutput = dstVariableName;
514515
var labelEncoderInput = srcVariableName;
515516
if (TypeOutput == NumberDataViewType.Double || TypeOutput == BooleanDataViewType.Instance)
516-
labelEncoderOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "CastNodeOutput");
517+
labelEncoderOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, srcShape), "CastNodeOutput");
517518
else if (TypeOutput == NumberDataViewType.Int64 || TypeOutput == NumberDataViewType.UInt16 ||
518519
TypeOutput == NumberDataViewType.Int32 || TypeOutput == NumberDataViewType.Int16 ||
519520
TypeOutput == NumberDataViewType.UInt64 || TypeOutput == NumberDataViewType.UInt32)
520-
labelEncoderOutput = ctx.AddIntermediateVariable(TextDataViewType.Instance, "CastNodeOutput");
521+
labelEncoderOutput = ctx.AddIntermediateVariable(new VectorDataViewType(TextDataViewType.Instance, srcShape), "CastNodeOutput");
521522

522523
opType = "LabelEncoder";
523524
var node = ctx.CreateNode(opType, castNodeOutput, labelEncoderOutput, ctx.GetNodeName(opType));

src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -787,7 +787,8 @@ private IEnumerable<T> GetTermsAndIds<T>(int iinfo, out long[] termIds)
787787
private void CastInputToString<T>(OnnxContext ctx, out OnnxNode node, out long[] termIds, string srcVariableName, int iinfo,
788788
string opType, string labelEncoderOutput)
789789
{
790-
var castOutput = ctx.AddIntermediateVariable(TextDataViewType.Instance, "castOutput");
790+
var srcShape = ctx.RetrieveShapeOrNull(srcVariableName);
791+
var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(TextDataViewType.Instance, (int)srcShape[1]), "castOutput");
791792
var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), "");
792793
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.String).ToType();
793794
castNode.AddAttribute("to", t);
@@ -799,7 +800,8 @@ private void CastInputToString<T>(OnnxContext ctx, out OnnxNode node, out long[]
799800
private void CastInputToFloat<T>(OnnxContext ctx, out OnnxNode node, out long[] termIds, string srcVariableName, int iinfo,
800801
string opType, string labelEncoderOutput)
801802
{
802-
var castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "castOutput");
803+
var srcShape = ctx.RetrieveShapeOrNull(srcVariableName);
804+
var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, (int)srcShape[1]), "castOutput");
803805
var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), "");
804806
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
805807
castNode.AddAttribute("to", t);
@@ -813,7 +815,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
813815
long[] termIds;
814816
string opType = "LabelEncoder";
815817
OnnxNode castNode;
816-
var labelEncoderOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "LabelEncoderOutput");
818+
var labelEncoderOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, _types[iinfo].GetValueCount()), "LabelEncoderOutput");
817819

818820
var type = info.TypeSrc.GetItemType();
819821
if (type.Equals(TextDataViewType.Instance))

0 commit comments

Comments
 (0)