Skip to content

Commit a7fa16c

Browse files
authored
Fixed bugs in OptionalColumnTransform and added bool support (#4815)
* - Cleaned up OnnxContext's initializer interface - Cleaned up column comparison functionality on OnnxConversionTest - Fixed bugs in OptionalColumnTransform's onnx export and added support for boolean initializers * Fixed doc issues pointed out by code review
1 parent 6b6f04b commit a7fa16c

File tree

6 files changed

+352
-196
lines changed

6 files changed

+352
-196
lines changed

src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs

Lines changed: 86 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
56
using System.Collections.Generic;
67
using Microsoft.ML.Data;
78

@@ -130,7 +131,16 @@ public OnnxNode CreateNode(string opType, string input, string output, string na
130131
public abstract List<long> RetrieveShapeOrNull(string variableName);
131132

132133
/// <summary>
133-
/// Call this function can declare a global float
134+
/// Call this function to declare a global bool scalar
135+
/// </summary>
136+
/// <param name="value">The boolean value which is going to be added</param>
137+
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
138+
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
139+
/// <returns>The initializer's ONNX name</returns>
140+
public abstract string AddInitializer(bool value, string name = null, bool makeUniqueName = true);
141+
142+
/// <summary>
143+
/// Call this function to declare a global float scalar
134144
/// </summary>
135145
/// <param name="value">The float number which is going to be added</param>
136146
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
@@ -139,16 +149,17 @@ public OnnxNode CreateNode(string opType, string input, string output, string na
139149
public abstract string AddInitializer(float value, string name = null, bool makeUniqueName = true);
140150

141151
/// <summary>
142-
/// Call this function can declare a global long
152+
/// Call this function to declare a global scalar of type Int32 or a smaller integer type
143153
/// </summary>
144-
/// <param name="value">The long number which is going to be added into the ONNX graph</param>
154+
/// <param name="value">The integer value which is going to be added</param>
155+
/// <param name="type">The type of integer to be added, e.g. typeof(short). Use this for all integer types Int32 and smaller</param>
145156
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
146157
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
147158
/// <returns>The initializer's ONNX name</returns>
148-
public abstract string AddInitializer(long value, string name = null, bool makeUniqueName = true);
159+
public abstract string AddInitializer(int value, Type type, string name = null, bool makeUniqueName = true);
149160

150161
/// <summary>
151-
/// Call this function can declare a global string
162+
/// Call this function to declare a global string scalar
152163
/// </summary>
153164
/// <param name="value">The string which is going to be added into the ONNX graph</param>
154165
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
@@ -157,43 +168,103 @@ public OnnxNode CreateNode(string opType, string input, string output, string na
157168
public abstract string AddInitializer(string value, string name = null, bool makeUniqueName = true);
158169

159170
/// <summary>
160-
/// Call this function can declare a global float tensor
171+
/// Call this function to declare a global long scalar
172+
/// </summary>
173+
/// <param name="value">The long number which is going to be added into the ONNX graph</param>
174+
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
175+
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
176+
/// <returns>The initializer's ONNX name</returns>
177+
public abstract string AddInitializer(long value, string name = null, bool makeUniqueName = true);
178+
179+
/// <summary>
180+
/// Call this function to declare a global double scalar
181+
/// </summary>
182+
/// <param name="value">The double number which is going to be added into the ONNX graph</param>
183+
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
184+
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
185+
/// <returns>The initializer's ONNX name</returns>
186+
public abstract string AddInitializer(double value, string name = null, bool makeUniqueName = true);
187+
188+
/// <summary>
189+
/// Call this function to declare a global ulong or uint scalar
190+
/// </summary>
191+
/// <param name="value">The unsigned integer (ulong or uint) value which is going to be added into the ONNX graph</param>
192+
/// <param name="isUint64">true if value contains a ulong value and false if it contains uint </param>
193+
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
194+
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
195+
/// <returns>The initializer's ONNX name</returns>
196+
public abstract string AddInitializer(ulong value, bool isUint64, string name = null, bool makeUniqueName = true);
197+
198+
/// <summary>
199+
/// Call this function to declare a global bool tensor
200+
/// </summary>
201+
/// <param name="values">The boolean values which are going to be added into the ONNX graph</param>
202+
/// <param name="dims">The shape of values</param>
203+
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
204+
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
205+
/// <returns>The initializer's ONNX name</returns>
206+
public abstract string AddInitializer(IEnumerable<bool> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);
207+
208+
/// <summary>
209+
/// Call this function to declare a global float tensor
161210
/// </summary>
162211
/// <param name="values">The floats which are going to be added into the ONNX graph</param>
163-
/// <param name="dims">The shape that the floats</param>
212+
/// <param name="dims">The shape of values</param>
164213
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
165214
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
166215
/// <returns>The initializer's ONNX name</returns>
167216
public abstract string AddInitializer(IEnumerable<float> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);
168217

169218
/// <summary>
170-
/// Call this function can declare a global long tensor
219+
/// Call this function to declare a global tensor of type Int32 or a smaller integer type
220+
/// </summary>
221+
/// <param name="values">The ints which are going to be added into the ONNX graph</param>
222+
/// <param name="type">The type of ints which are going to be added into the ONNX graph, e.g. typeof(short). Use this for adding array initializers of integer types Int32 and smaller</param>
223+
/// <param name="dims">The shape of values</param>
224+
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
225+
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
226+
/// <returns>The initializer's ONNX name</returns>
227+
public abstract string AddInitializer(IEnumerable<int> values, Type type, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);
228+
229+
/// <summary>
230+
/// Call this function to declare a global string tensor
231+
/// </summary>
232+
/// <param name="values">The strings which are going to be added into the ONNX graph</param>
233+
/// <param name="dims">The shape of values</param>
234+
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
235+
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
236+
/// <returns>The initializer's ONNX name</returns>
237+
public abstract string AddInitializer(IEnumerable<string> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);
238+
239+
/// <summary>
240+
/// Call this function to declare a global long tensor
171241
/// </summary>
172242
/// <param name="values">The longs which are going to be added into the ONNX graph</param>
173-
/// <param name="dims">The shape that the floats</param>
243+
/// <param name="dims">The shape of values</param>
174244
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
175245
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
176246
/// <returns>The initializer's ONNX name</returns>
177247
public abstract string AddInitializer(IEnumerable<long> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);
178248

179249
/// <summary>
180-
/// Call this function can declare a global double tensor
250+
/// Call this function to declare a global double tensor
181251
/// </summary>
182252
/// <param name="values">The doubles which are going to be added into the ONNX graph</param>
183-
/// <param name="dims">The shape that the doubles</param>
253+
/// <param name="dims">The shape of values</param>
184254
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
185255
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
186256
/// <returns>The initializer's ONNX name</returns>
187257
public abstract string AddInitializer(IEnumerable<double> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);
188258

189259
/// <summary>
190-
/// Call this function can declare a global string tensor
260+
/// Call this function to declare a global ulong or uint tensor
191261
/// </summary>
192-
/// <param name="values">The strings which are going to be added into the ONNX graph</param>
193-
/// <param name="dims">The shape that the strings</param>
262+
/// <param name="values">The unsigned integers which are going to be added into the ONNX graph</param>
263+
/// <param name="isUint64">Set to true if values contain ulong values, and to false if they contain uint values</param>
264+
/// <param name="dims">The shape of values</param>
194265
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
195266
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
196267
/// <returns>The initializer's ONNX name</returns>
197-
public abstract string AddInitializer(IEnumerable<string> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);
268+
public abstract string AddInitializer(IEnumerable<ulong> values, bool isUint64, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);
198269
}
199270
}

src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -279,13 +279,27 @@ public override List<long> RetrieveShapeOrNull(string variableName)
279279
}
280280

281281
/// Adds constant tensor into the graph.
282+
public override string AddInitializer(bool value, string name = null, bool makeUniqueName = true)
283+
{
284+
name = AddVariable(name ?? "bool", makeUniqueName);
285+
_initializers.Add(OnnxUtils.MakeInt32(name, typeof(bool), value ? 1 : 0));
286+
return name;
287+
}
288+
282289
public override string AddInitializer(float value, string name = null, bool makeUniqueName = true)
283290
{
284291
name = AddVariable(name ?? "float", makeUniqueName);
285292
_initializers.Add(OnnxUtils.MakeFloat(name, value));
286293
return name;
287294
}
288295

296+
public override string AddInitializer(int value, Type type, string name = null, bool makeUniqueName = true)
297+
{
298+
name = AddVariable(name ?? "int32", makeUniqueName);
299+
_initializers.Add(OnnxUtils.MakeInt32(name, type, value));
300+
return name;
301+
}
302+
289303
public override string AddInitializer(string value, string name = null, bool makeUniqueName = true)
290304
{
291305
name = AddVariable(name ?? "string", makeUniqueName);
@@ -300,6 +314,31 @@ public override string AddInitializer(long value, string name = null, bool makeU
300314
return name;
301315
}
302316

317+
public override string AddInitializer(double value, string name = null, bool makeUniqueName = true)
318+
{
319+
name = AddVariable(name ?? "double", makeUniqueName);
320+
_initializers.Add(OnnxUtils.MakeDouble(name, value));
321+
return name;
322+
}
323+
324+
public override string AddInitializer(ulong value, bool isUint64, string name = null, bool makeUniqueName = true)
325+
{
326+
name = AddVariable(name ?? "uint64", makeUniqueName);
327+
_initializers.Add(OnnxUtils.MakeUInt(name, isUint64, value));
328+
return name;
329+
}
330+
331+
public override string AddInitializer(IEnumerable<bool> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
332+
{
333+
_host.CheckValue(values, nameof(values));
334+
if (dims != null)
335+
_host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size");
336+
337+
name = AddVariable(name ?? "bools", makeUniqueName);
338+
_initializers.Add(OnnxUtils.MakeInt32s(name, typeof(bool), values.Select(v => Convert.ToInt32(v)), dims));
339+
return name;
340+
}
341+
303342
public override string AddInitializer(IEnumerable<float> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
304343
{
305344
_host.CheckValue(values, nameof(values));
@@ -311,6 +350,28 @@ public override string AddInitializer(IEnumerable<float> values, IEnumerable<lon
311350
return name;
312351
}
313352

353+
public override string AddInitializer(IEnumerable<int> values, Type type, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
354+
{
355+
_host.CheckValue(values, nameof(values));
356+
if (dims != null)
357+
_host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size");
358+
359+
name = AddVariable(name ?? "int32s", makeUniqueName);
360+
_initializers.Add(OnnxUtils.MakeInt32s(name, type, values, dims));
361+
return name;
362+
}
363+
364+
public override string AddInitializer(IEnumerable<string> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
365+
{
366+
_host.CheckValue(values, nameof(values));
367+
if (dims != null)
368+
_host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size");
369+
370+
name = AddVariable(name ?? "strings", makeUniqueName);
371+
_initializers.Add(OnnxUtils.MakeStrings(name, values, dims));
372+
return name;
373+
}
374+
314375
public override string AddInitializer(IEnumerable<long> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
315376
{
316377
_host.CheckValue(values, nameof(values));
@@ -328,19 +389,19 @@ public override string AddInitializer(IEnumerable<double> values, IEnumerable<lo
328389
if (dims != null)
329390
_host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size");
330391

331-
name = AddVariable(name ?? "double", makeUniqueName);
332-
_initializers.Add(OnnxUtils.MakeDouble(name, values, dims));
392+
name = AddVariable(name ?? "doubles", makeUniqueName);
393+
_initializers.Add(OnnxUtils.MakeDoubles(name, values, dims));
333394
return name;
334395
}
335396

336-
public override string AddInitializer(IEnumerable<string> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
397+
public override string AddInitializer(IEnumerable<ulong> values, bool isUint64, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
337398
{
338399
_host.CheckValue(values, nameof(values));
339400
if (dims != null)
340401
_host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size");
341402

342-
name = AddVariable(name ?? "strings", makeUniqueName);
343-
_initializers.Add(OnnxUtils.MakeStrings(name, values, dims));
403+
name = AddVariable(name ?? "uints", makeUniqueName);
404+
_initializers.Add(OnnxUtils.MakeUInts(name, isUint64, values, dims));
344405
return name;
345406
}
346407

src/Microsoft.ML.OnnxConverter/OnnxUtils.cs

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,8 +410,66 @@ public static TensorProto MakeInt64s(string name, IEnumerable<long> values, IEnu
410410
return tensor;
411411
}
412412

413+
// Make int32 and smaller integer types scalar in ONNX from native C# number
414+
public static TensorProto MakeInt32(string name, Type type, int value)
415+
{
416+
var tensor = new TensorProto();
417+
tensor.Name = name;
418+
tensor.DataType = (int)ConvertToTensorProtoType(type);
419+
tensor.Int32Data.Add(value);
420+
return tensor;
421+
}
422+
423+
// Make int32 and smaller integer types vector (i.e., 1-D tensor) with dims=null. Otherwise, dims is used as the shape of the produced tensor.
424+
public static TensorProto MakeInt32s(string name, Type type, IEnumerable<int> values, IEnumerable<long> dims = null)
425+
{
426+
var tensor = new TensorProto();
427+
tensor.Name = name;
428+
tensor.DataType = (int)ConvertToTensorProtoType(type);
429+
tensor.Int32Data.AddRange(values);
430+
if (dims != null)
431+
tensor.Dims.AddRange(dims);
432+
else
433+
tensor.Dims.Add(values.Count());
434+
return tensor;
435+
}
436+
437+
// Make ulong and uint integer types scalar in ONNX from native C# number
438+
public static TensorProto MakeUInt(string name, bool isUint64, ulong value)
439+
{
440+
var tensor = new TensorProto();
441+
tensor.Name = name;
442+
tensor.DataType = (int)ConvertToTensorProtoType(isUint64 ? typeof(ulong) : typeof(uint));
443+
tensor.Uint64Data.Add(value);
444+
return tensor;
445+
}
446+
447+
// Make ulong and uint integer vector (i.e., 1-D tensor) with dims=null. Otherwise, dims is used as the shape of the produced tensor.
448+
public static TensorProto MakeUInts(string name, bool isUint64, IEnumerable<ulong> values, IEnumerable<long> dims = null)
449+
{
450+
var tensor = new TensorProto();
451+
tensor.Name = name;
452+
tensor.DataType = (int)ConvertToTensorProtoType(isUint64 ? typeof(ulong) : typeof(uint));
453+
tensor.Uint64Data.AddRange(values);
454+
if (dims != null)
455+
tensor.Dims.AddRange(dims);
456+
else
457+
tensor.Dims.Add(values.Count());
458+
return tensor;
459+
}
460+
461+
// Make double scalar in ONNX from native C# number
462+
public static TensorProto MakeDouble(string name, double value)
463+
{
464+
var tensor = new TensorProto();
465+
tensor.Name = name;
466+
tensor.DataType = (int)TensorProto.Types.DataType.Double;
467+
tensor.DoubleData.Add(value);
468+
return tensor;
469+
}
470+
413471
// Make double vector (i.e., 1-D tensor) with dims=null. Otherwise, dims is used as the shape of the produced tensor.
414-
public static TensorProto MakeDouble(string name, IEnumerable<double> values, IEnumerable<long> dims = null)
472+
public static TensorProto MakeDoubles(string name, IEnumerable<double> values, IEnumerable<long> dims = null)
415473
{
416474
var tensor = new TensorProto();
417475
tensor.Name = name;

0 commit comments

Comments
 (0)