diff --git a/build/BranchInfo.props b/build/BranchInfo.props
index b6d49773ec..66be13628a 100644
--- a/build/BranchInfo.props
+++ b/build/BranchInfo.props
@@ -1,7 +1,7 @@
0
- 5
+ 6
0
preview
diff --git a/build/Dependencies.props b/build/Dependencies.props
index 0b6af3cdc9..e880e8c66b 100644
--- a/build/Dependencies.props
+++ b/build/Dependencies.props
@@ -10,7 +10,7 @@
2.1.2.2
0.0.0.5
4.5.0
- 0.11.0
+ 0.11.1
1.10.0
diff --git a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
index f9a1b5ef27..1471c580ba 100644
--- a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
+++ b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
@@ -16,6 +16,7 @@
+
diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs
index 6f21a1cb01..509a67bb4f 100644
--- a/src/Microsoft.ML.Core/Data/IEstimator.cs
+++ b/src/Microsoft.ML.Core/Data/IEstimator.cs
@@ -28,20 +28,42 @@ public enum VectorKind
VariableVector
}
+ ///
+ /// The column name.
+ ///
public readonly string Name;
+
+ ///
+ /// The type of the column: scalar, fixed vector or variable vector.
+ ///
public readonly VectorKind Kind;
- public readonly DataKind ItemKind;
+
+ ///
+ /// The 'raw' type of column item: must be a primitive type or a structured type.
+ ///
+ public readonly ColumnType ItemType;
+ ///
+ /// The flag whether the column is actually a key. If yes, is representing
+ /// the underlying primitive type.
+ ///
public readonly bool IsKey;
+ ///
+ /// The metadata kinds that are present for this column.
+ ///
public readonly string[] MetadataKinds;
- public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey, string[] metadataKinds = null)
+ public Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey, string[] metadataKinds = null)
{
Contracts.CheckNonEmpty(name, nameof(name));
Contracts.CheckValueOrNull(metadataKinds);
+ Contracts.CheckParam(!itemType.IsKey, nameof(itemType), "Item type cannot be a key");
+ Contracts.CheckParam(!itemType.IsVector, nameof(itemType), "Item type cannot be a vector");
+
+ Contracts.CheckParam(!isKey || KeyType.IsValidDataKind(itemType.RawKind), nameof(itemType), "The item type must be valid for a key");
Name = name;
Kind = vecKind;
- ItemKind = itemKind;
+ ItemType = itemType;
IsKey = isKey;
MetadataKinds = metadataKinds ?? new string[0];
}
@@ -51,7 +73,7 @@ public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey, st
/// requirement.
///
/// Namely, it returns true iff:
- /// - The , , , fields match.
+ /// - The , , , fields match.
/// - The of is a superset of our .
///
public bool IsCompatibleWith(Column inputColumn)
@@ -61,7 +83,7 @@ public bool IsCompatibleWith(Column inputColumn)
return false;
if (Kind != inputColumn.Kind)
return false;
- if (ItemKind != inputColumn.ItemKind)
+ if (!ItemType.Equals(inputColumn.ItemType))
return false;
if (IsKey != inputColumn.IsKey)
return false;
@@ -72,7 +94,7 @@ public bool IsCompatibleWith(Column inputColumn)
public string GetTypeString()
{
- string result = ItemKind.ToString();
+ string result = ItemType.ToString();
if (IsKey)
result = $"Key<{result}>";
if (Kind == VectorKind.Vector)
@@ -110,13 +132,15 @@ public static SchemaShape Create(ISchema schema)
else
vecKind = Column.VectorKind.Scalar;
- var kind = type.ItemType.RawKind;
+ ColumnType itemType = type.ItemType;
+ if (type.ItemType.IsKey)
+ itemType = PrimitiveType.FromKind(type.ItemType.RawKind);
var isKey = type.ItemType.IsKey;
var metadataNames = schema.GetMetadataTypes(iCol)
.Select(kvp => kvp.Key)
.ToArray();
- cols.Add(new Column(schema.GetColumnName(iCol), vecKind, kind, isKey, metadataNames));
+ cols.Add(new Column(schema.GetColumnName(iCol), vecKind, itemType, isKey, metadataNames));
}
}
return new SchemaShape(cols.ToArray());
diff --git a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs
new file mode 100644
index 0000000000..29c081ac35
--- /dev/null
+++ b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs
@@ -0,0 +1,35 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Core.Data;
+
+namespace Microsoft.ML.Runtime.Data
+{
+ ///
+ /// The trivial implementation of that already has
+ /// the transformer and returns it on every call to .
+ ///
+ /// Concrete implementations still have to provide the schema propagation mechanism, since
+ /// there is no easy way to infer it from the transformer.
+ ///
+ public abstract class TrivialEstimator : IEstimator
+ where TTransformer : class, ITransformer
+ {
+ protected readonly IHost Host;
+ protected readonly TTransformer Transformer;
+
+ protected TrivialEstimator(IHost host, TTransformer transformer)
+ {
+ Contracts.AssertValue(host);
+
+ Host = host;
+ Host.CheckValue(transformer, nameof(transformer));
+ Transformer = transformer;
+ }
+
+ public TTransformer Fit(IDataView input) => Transformer;
+
+ public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema);
+ }
+}
diff --git a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
index 184c0226bb..3c47f84faa 100644
--- a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
@@ -70,7 +70,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
var originalColumn = inputSchema.FindColumn(column.Source);
if (originalColumn != null)
{
- var col = new SchemaShape.Column(column.Name, originalColumn.Kind, originalColumn.ItemKind, originalColumn.IsKey, originalColumn.MetadataKinds);
+ var col = new SchemaShape.Column(column.Name, originalColumn.Kind, originalColumn.ItemType, originalColumn.IsKey, originalColumn.MetadataKinds);
resultDic[column.Name] = col;
}
else
diff --git a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs
new file mode 100644
index 0000000000..b301eac793
--- /dev/null
+++ b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs
@@ -0,0 +1,177 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.Runtime.Model;
+
+namespace Microsoft.ML.Runtime.Data
+{
+ public abstract class OneToOneTransformerBase : ITransformer, ICanSaveModel
+ {
+ protected readonly IHost Host;
+ protected readonly (string input, string output)[] ColumnPairs;
+
+ protected OneToOneTransformerBase(IHost host, (string input, string output)[] columns)
+ {
+ Contracts.AssertValue(host);
+ host.CheckValue(columns, nameof(columns));
+
+ var newNames = new HashSet();
+ foreach (var column in columns)
+ {
+ host.CheckNonEmpty(column.input, nameof(columns));
+ host.CheckNonEmpty(column.output, nameof(columns));
+
+ if (!newNames.Add(column.output))
+ throw Contracts.ExceptParam(nameof(columns), $"Output column '{column.output}' specified multiple times");
+ }
+
+ Host = host;
+ ColumnPairs = columns;
+ }
+
+ protected OneToOneTransformerBase(IHost host, ModelLoadContext ctx)
+ {
+ Host = host;
+ // *** Binary format ***
+ // int: number of added columns
+ // for each added column
+ // int: id of output column name
+ // int: id of input column name
+
+ int n = ctx.Reader.ReadInt32();
+ ColumnPairs = new (string input, string output)[n];
+ for (int i = 0; i < n; i++)
+ {
+ string output = ctx.LoadNonEmptyString();
+ string input = ctx.LoadNonEmptyString();
+ ColumnPairs[i] = (input, output);
+ }
+ }
+
+ public abstract void Save(ModelSaveContext ctx);
+
+ protected void SaveColumns(ModelSaveContext ctx)
+ {
+ Host.CheckValue(ctx, nameof(ctx));
+
+ // *** Binary format ***
+ // int: number of added columns
+ // for each added column
+ // int: id of output column name
+ // int: id of input column name
+
+ ctx.Writer.Write(ColumnPairs.Length);
+ for (int i = 0; i < ColumnPairs.Length; i++)
+ {
+ ctx.SaveNonEmptyString(ColumnPairs[i].output);
+ ctx.SaveNonEmptyString(ColumnPairs[i].input);
+ }
+ }
+
+ private void CheckInput(ISchema inputSchema, int col, out int srcCol)
+ {
+ Contracts.AssertValue(inputSchema);
+ Contracts.Assert(0 <= col && col < ColumnPairs.Length);
+
+ if (!inputSchema.TryGetColumnIndex(ColumnPairs[col].input, out srcCol))
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input);
+ CheckInputColumn(inputSchema, col, srcCol);
+ }
+
+ protected virtual void CheckInputColumn(ISchema inputSchema, int col, int srcCol)
+ {
+ // By default, there are no extra checks.
+ }
+
+ protected abstract IRowMapper MakeRowMapper(ISchema schema);
+
+ public ISchema GetOutputSchema(ISchema inputSchema)
+ {
+ Host.CheckValue(inputSchema, nameof(inputSchema));
+
+ // Check that all the input columns are present and correct.
+ for (int i = 0; i < ColumnPairs.Length; i++)
+ CheckInput(inputSchema, i, out int col);
+
+ return Transform(new EmptyDataView(Host, inputSchema)).Schema;
+ }
+
+ public IDataView Transform(IDataView input) => MakeDataTransform(input);
+
+ protected RowToRowMapperTransform MakeDataTransform(IDataView input)
+ {
+ Host.CheckValue(input, nameof(input));
+ return new RowToRowMapperTransform(Host, input, MakeRowMapper(input.Schema));
+ }
+
+ protected abstract class MapperBase : IRowMapper
+ {
+ protected readonly IHost Host;
+ protected readonly Dictionary ColMapNewToOld;
+ protected readonly ISchema InputSchema;
+ private readonly OneToOneTransformerBase _parent;
+
+ protected MapperBase(IHost host, OneToOneTransformerBase parent, ISchema inputSchema)
+ {
+ Contracts.AssertValue(host);
+ Contracts.AssertValue(parent);
+ Contracts.AssertValue(inputSchema);
+
+ Host = host;
+ _parent = parent;
+
+ ColMapNewToOld = new Dictionary();
+ for (int i = 0; i < _parent.ColumnPairs.Length; i++)
+ {
+ _parent.CheckInput(inputSchema, i, out int srcCol);
+ ColMapNewToOld.Add(i, srcCol);
+ }
+ InputSchema = inputSchema;
+ }
+ public Func GetDependencies(Func activeOutput)
+ {
+ var active = new bool[InputSchema.ColumnCount];
+ foreach (var pair in ColMapNewToOld)
+ if (activeOutput(pair.Key))
+ active[pair.Value] = true;
+ return col => active[col];
+ }
+
+ public abstract RowMapperColumnInfo[] GetOutputColumns();
+
+ public void Save(ModelSaveContext ctx) => _parent.Save(ctx);
+
+ public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer)
+ {
+ Contracts.Assert(input.Schema == InputSchema);
+ var result = new Delegate[_parent.ColumnPairs.Length];
+ var disposers = new Action[_parent.ColumnPairs.Length];
+ for (int i = 0; i < _parent.ColumnPairs.Length; i++)
+ {
+ if (!activeOutput(i))
+ continue;
+ int srcCol = ColMapNewToOld[i];
+ result[i] = MakeGetter(input, i, out disposers[i]);
+ }
+ if (disposers.Any(x => x != null))
+ {
+ disposer = () =>
+ {
+ foreach (var act in disposers)
+ act();
+ };
+ }
+ else
+ disposer = null;
+ return result;
+ }
+
+ protected abstract Delegate MakeGetter(IRow input, int iinfo, out Action disposer);
+ }
+ }
+}
diff --git a/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs b/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs
index 97c613485f..921309d7ef 100644
--- a/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs
+++ b/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs
@@ -16,7 +16,7 @@ public static class ImageAnalytics
public static CommonOutputs.TransformOutput ImageLoader(IHostEnvironment env, ImageLoaderTransform.Arguments input)
{
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImageLoaderTransform", input);
- var xf = new ImageLoaderTransform(h, input, input.Data);
+ var xf = ImageLoaderTransform.Create(h, input, input.Data);
return new CommonOutputs.TransformOutput()
{
Model = new TransformModel(h, xf, input.Data),
@@ -29,7 +29,7 @@ public static CommonOutputs.TransformOutput ImageLoader(IHostEnvironment env, Im
public static CommonOutputs.TransformOutput ImageResizer(IHostEnvironment env, ImageResizerTransform.Arguments input)
{
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImageResizerTransform", input);
- var xf = new ImageResizerTransform(h, input, input.Data);
+ var xf = ImageResizerTransform.Create(h, input, input.Data);
return new CommonOutputs.TransformOutput()
{
Model = new TransformModel(h, xf, input.Data),
@@ -42,7 +42,7 @@ public static CommonOutputs.TransformOutput ImageResizer(IHostEnvironment env, I
public static CommonOutputs.TransformOutput ImagePixelExtractor(IHostEnvironment env, ImagePixelExtractorTransform.Arguments input)
{
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImagePixelExtractorTransform", input);
- var xf = new ImagePixelExtractorTransform(h, input, input.Data);
+ var xf = ImagePixelExtractorTransform.Create(h, input, input.Data);
return new CommonOutputs.TransformOutput()
{
Model = new TransformModel(h, xf, input.Data),
@@ -55,7 +55,7 @@ public static CommonOutputs.TransformOutput ImagePixelExtractor(IHostEnvironment
public static CommonOutputs.TransformOutput ImageGrayscale(IHostEnvironment env, ImageGrayscaleTransform.Arguments input)
{
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImageGrayscaleTransform", input);
- var xf = new ImageGrayscaleTransform(h, input, input.Data);
+ var xf = ImageGrayscaleTransform.Create(h, input, input.Data);
return new CommonOutputs.TransformOutput()
{
Model = new TransformModel(h, xf, input.Data),
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs
index 7a267cf1b8..21076b2d86 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs
@@ -2,22 +2,31 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
-using System.Drawing;
-using System.Drawing.Imaging;
-using System.Text;
+using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.ImageAnalytics;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
-using Microsoft.ML.Runtime.ImageAnalytics;
+using System;
+using System.Collections.Generic;
+using System.Drawing;
+using System.Drawing.Imaging;
+using System.Linq;
+using System.Text;
-[assembly: LoadableClass(ImageGrayscaleTransform.Summary, typeof(ImageGrayscaleTransform), typeof(ImageGrayscaleTransform.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(ImageGrayscaleTransform.Summary, typeof(IDataTransform), typeof(ImageGrayscaleTransform), typeof(ImageGrayscaleTransform.Arguments), typeof(SignatureDataTransform),
ImageGrayscaleTransform.UserName, "ImageGrayscaleTransform", "ImageGrayscale")]
-[assembly: LoadableClass(ImageGrayscaleTransform.Summary, typeof(ImageGrayscaleTransform), null, typeof(SignatureLoadDataTransform),
+[assembly: LoadableClass(ImageGrayscaleTransform.Summary, typeof(IDataTransform), typeof(ImageGrayscaleTransform), null, typeof(SignatureLoadDataTransform),
+ ImageGrayscaleTransform.UserName, ImageGrayscaleTransform.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(ImageGrayscaleTransform), null, typeof(SignatureLoadModel),
+ ImageGrayscaleTransform.UserName, ImageGrayscaleTransform.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(IRowMapper), typeof(ImageGrayscaleTransform), null, typeof(SignatureLoadRowMapper),
ImageGrayscaleTransform.UserName, ImageGrayscaleTransform.LoaderSignature)]
namespace Microsoft.ML.Runtime.ImageAnalytics
@@ -28,7 +37,7 @@ namespace Microsoft.ML.Runtime.ImageAnalytics
/// Transform which takes one or many columns of type in IDataView and
/// convert them to greyscale representation of the same image.
///
- public sealed class ImageGrayscaleTransform : OneToOneTransformBase
+ public sealed class ImageGrayscaleTransform : OneToOneTransformerBase
{
public sealed class Column : OneToOneColumn
{
@@ -69,50 +78,57 @@ private static VersionInfo GetVersionInfo()
private const string RegistrationName = "ImageGrayscale";
- // Public constructor corresponding to SignatureDataTransform.
- public ImageGrayscaleTransform(IHostEnvironment env, Arguments args, IDataView input)
- : base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, t => t is ImageType ? null : "Expected Image type")
+ public IReadOnlyCollection<(string input, string output)> Columns => ColumnPairs.AsReadOnly();
+
+ public ImageGrayscaleTransform(IHostEnvironment env, params (string input, string output)[] columns)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), columns)
{
- Host.AssertNonEmpty(Infos);
- Host.Assert(Infos.Length == Utils.Size(args.Column));
- Metadata.Seal();
}
- private ImageGrayscaleTransform(IHost host, ModelLoadContext ctx, IDataView input)
- : base(host, ctx, input, t => t is ImageType ? null : "Expected Image type")
+ // Factory method for SignatureDataTransform.
+ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
- Host.AssertValue(ctx);
- // *** Binary format ***
- //
- Host.AssertNonEmpty(Infos);
- Metadata.Seal();
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(args, nameof(args));
+ env.CheckValue(input, nameof(input));
+ env.CheckValue(args.Column, nameof(args.Column));
+
+ return new ImageGrayscaleTransform(env, args.Column.Select(x => (x.Source ?? x.Name, x.Name)).ToArray())
+ .MakeDataTransform(input);
}
- public static ImageGrayscaleTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ public static ImageGrayscaleTransform Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
- var h = env.Register(RegistrationName);
- h.CheckValue(ctx, nameof(ctx));
- h.CheckValue(input, nameof(input));
+ var host = env.Register(RegistrationName);
+ host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
- return h.Apply("Loading Model", ch => new ImageGrayscaleTransform(h, ctx, input));
+ return new ImageGrayscaleTransform(host, ctx);
}
+ private ImageGrayscaleTransform(IHost host, ModelLoadContext ctx)
+ : base(host, ctx)
+ {
+ }
+
+ // Factory method for SignatureLoadDataTransform.
+ public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ => Create(env, ctx).MakeDataTransform(input);
+
+ // Factory method for SignatureLoadRowMapper.
+ public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
+ => Create(env, ctx).MakeRowMapper(inputSchema);
+
public override void Save(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
+
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
// *** Binary format ***
//
- SaveBase(ctx);
- }
-
- protected override ColumnType GetColumnTypeCore(int iinfo)
- {
- Host.Assert(0 <= iinfo & iinfo < Infos.Length);
- return Infos[iinfo].TypeSrc;
+ base.SaveColumns(ctx);
}
private static readonly ColorMatrix _grayscaleColorMatrix = new ColorMatrix(
@@ -125,47 +141,96 @@ protected override ColumnType GetColumnTypeCore(int iinfo)
new float[] {0, 0, 0, 0, 1}
});
- protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
+ protected override IRowMapper MakeRowMapper(ISchema schema)
+ => new Mapper(this, schema);
+
+ protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol)
{
- Host.AssertValueOrNull(ch);
- Host.AssertValue(input);
- Host.Assert(0 <= iinfo && iinfo < Infos.Length);
+ if (!(inputSchema.GetColumnType(srcCol) is ImageType))
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, "image", inputSchema.GetColumnType(srcCol).ToString());
+ }
- var src = default(Bitmap);
- var getSrc = GetSrcGetter(input, iinfo);
+ private sealed class Mapper : MapperBase
+ {
+ private ImageGrayscaleTransform _parent;
- disposer =
- () =>
- {
- if (src != null)
- {
- src.Dispose();
- src = null;
- }
- };
+ public Mapper(ImageGrayscaleTransform parent, ISchema inputSchema)
+ :base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
+ {
+ _parent = parent;
+ }
- ValueGetter del =
- (ref Bitmap dst) =>
- {
- if (dst != null)
- dst.Dispose();
-
- getSrc(ref src);
- if (src == null || src.Height <= 0 || src.Width <= 0)
- return;
-
- dst = new Bitmap(src.Width, src.Height);
- ImageAttributes attributes = new ImageAttributes();
- attributes.SetColorMatrix(_grayscaleColorMatrix);
- var srcRectangle = new Rectangle(0, 0, src.Width, src.Height);
- using (var g = Graphics.FromImage(dst))
+ public override RowMapperColumnInfo[] GetOutputColumns()
+ => _parent.ColumnPairs.Select((x, idx) => new RowMapperColumnInfo(x.output, InputSchema.GetColumnType(ColMapNewToOld[idx]), null)).ToArray();
+
+ protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer)
+ {
+ Contracts.AssertValue(input);
+ Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);
+
+ var src = default(Bitmap);
+ var getSrc = input.GetGetter(ColMapNewToOld[iinfo]);
+
+ disposer =
+ () =>
+ {
+ if (src != null)
+ {
+ src.Dispose();
+ src = null;
+ }
+ };
+
+ ValueGetter del =
+ (ref Bitmap dst) =>
{
- g.DrawImage(src, srcRectangle, 0, 0, src.Width, src.Height, GraphicsUnit.Pixel, attributes);
- }
- Host.Assert(dst.Width == src.Width && dst.Height == src.Height);
- };
+ if (dst != null)
+ dst.Dispose();
+
+ getSrc(ref src);
+ if (src == null || src.Height <= 0 || src.Width <= 0)
+ return;
+
+ dst = new Bitmap(src.Width, src.Height);
+ ImageAttributes attributes = new ImageAttributes();
+ attributes.SetColorMatrix(_grayscaleColorMatrix);
+ var srcRectangle = new Rectangle(0, 0, src.Width, src.Height);
+ using (var g = Graphics.FromImage(dst))
+ {
+ g.DrawImage(src, srcRectangle, 0, 0, src.Width, src.Height, GraphicsUnit.Pixel, attributes);
+ }
+ Contracts.Assert(dst.Width == src.Width && dst.Height == src.Height);
+ };
+
+ return del;
+ }
+ }
+ }
+
+ public sealed class ImageGrayscaleEstimator : TrivialEstimator
+ {
+ public ImageGrayscaleEstimator(IHostEnvironment env, params (string input, string output)[] columns)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageGrayscaleEstimator)), new ImageGrayscaleTransform(env, columns))
+ {
+ }
+
+ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ Host.CheckValue(inputSchema, nameof(inputSchema));
+ var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ foreach (var colInfo in Transformer.Columns)
+ {
+ var col = inputSchema.FindColumn(colInfo.input);
+
+ if (col == null)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.input);
+ if (!(col.ItemType is ImageType) || col.Kind != SchemaShape.Column.VectorKind.Scalar)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.input, new ImageType().ToString(), col.GetTypeString());
+
+ result[colInfo.output] = new SchemaShape.Column(colInfo.output, col.Kind, col.ItemType, col.IsKey, col.MetadataKinds);
+ }
- return del;
+ return new SchemaShape(result.Values);
}
}
}
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs
index 488c710743..20e1476feb 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs
@@ -2,31 +2,37 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
-using System.Drawing;
-using System.IO;
-using System.Text;
-using Microsoft.ML.Runtime.ImageAnalytics;
+using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.ImageAnalytics;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
+using System;
+using System.Collections.Generic;
+using System.Drawing;
+using System.IO;
+using System.Linq;
+using System.Text;
-[assembly: LoadableClass(ImageLoaderTransform.Summary, typeof(ImageLoaderTransform), typeof(ImageLoaderTransform.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(ImageLoaderTransform.Summary, typeof(IDataTransform), typeof(ImageLoaderTransform), typeof(ImageLoaderTransform.Arguments), typeof(SignatureDataTransform),
ImageLoaderTransform.UserName, "ImageLoaderTransform", "ImageLoader")]
-[assembly: LoadableClass(ImageLoaderTransform.Summary, typeof(ImageLoaderTransform), null, typeof(SignatureLoadDataTransform),
+[assembly: LoadableClass(ImageLoaderTransform.Summary, typeof(IDataTransform), typeof(ImageLoaderTransform), null, typeof(SignatureLoadDataTransform),
ImageLoaderTransform.UserName, ImageLoaderTransform.LoaderSignature)]
+[assembly: LoadableClass(typeof(ImageLoaderTransform), null, typeof(SignatureLoadModel), "", ImageLoaderTransform.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(IRowMapper), typeof(ImageLoaderTransform), null, typeof(SignatureLoadRowMapper), "", ImageLoaderTransform.LoaderSignature)]
+
namespace Microsoft.ML.Runtime.ImageAnalytics
{
- // REVIEW: Rewrite as LambdaTransform to simplify.
///
/// Transform which takes one or many columns of type and loads them as
///
- public sealed class ImageLoaderTransform : OneToOneTransformBase
+ public sealed class ImageLoaderTransform : OneToOneTransformerBase
{
public sealed class Column : OneToOneColumn
{
@@ -61,118 +67,177 @@ public sealed class Arguments : TransformInputBase
internal const string UserName = "Image Loader Transform";
public const string LoaderSignature = "ImageLoaderTransform";
- private static VersionInfo GetVersionInfo()
- {
- return new VersionInfo(
- modelSignature: "IMGLOADT",
- //verWrittenCur: 0x00010001, // Initial
- verWrittenCur: 0x00010002, // Swith from OpenCV to Bitmap
- verReadableCur: 0x00010002,
- verWeCanReadBack: 0x00010002,
- loaderSignature: LoaderSignature);
- }
+ public readonly string ImageFolder;
- private readonly ImageType _type;
- private readonly string _imageFolder;
+ public IReadOnlyCollection<(string input, string output)> Columns => ColumnPairs.AsReadOnly();
- private const string RegistrationName = "ImageLoader";
+ public ImageLoaderTransform(IHostEnvironment env, string imageFolder, params (string input, string output)[] columns)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageLoaderTransform)), columns)
+ {
+ ImageFolder = imageFolder;
+ }
- // Public constructor corresponding to SignatureDataTransform.
- public ImageLoaderTransform(IHostEnvironment env, Arguments args, IDataView input)
- : base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, TestIsText)
+ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView data)
{
- Host.AssertNonEmpty(Infos);
- _imageFolder = args.ImageFolder;
- Host.Assert(Infos.Length == Utils.Size(args.Column));
- _type = new ImageType();
- Metadata.Seal();
+ return new ImageLoaderTransform(env, args.ImageFolder, args.Column.Select(x => (x.Source ?? x.Name, x.Name)).ToArray())
+ .MakeDataTransform(data);
}
- private ImageLoaderTransform(IHost host, ModelLoadContext ctx, IDataView input)
- : base(host, ctx, input, TestIsText)
+ public static ImageLoaderTransform Create(IHostEnvironment env, ModelLoadContext ctx)
{
- Host.AssertValue(ctx);
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(ctx, nameof(ctx));
+
+ ctx.CheckAtModel(GetVersionInfo());
+ return new ImageLoaderTransform(env.Register(nameof(ImageLoaderTransform)), ctx);
+ }
+ private ImageLoaderTransform(IHost host, ModelLoadContext ctx)
+ : base(host, ctx)
+ {
// *** Binary format ***
//
- _imageFolder = ctx.Reader.ReadString();
- _type = new ImageType();
- Metadata.Seal();
+ // int: id of image folder
+
+ ImageFolder = ctx.LoadStringOrNull();
}
- public static ImageLoaderTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ // Factory method for SignatureLoadDataTransform.
+ public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ => Create(env, ctx).MakeDataTransform(input);
+
+ // Factory method for SignatureLoadRowMapper.
+ public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
+ => Create(env, ctx).MakeRowMapper(inputSchema);
+
+ protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol)
{
- Contracts.CheckValue(env, nameof(env));
- var h = env.Register(RegistrationName);
- h.CheckValue(ctx, nameof(ctx));
- h.CheckValue(input, nameof(input));
- ctx.CheckAtModel(GetVersionInfo());
- return h.Apply("Loading Model", ch => new ImageLoaderTransform(h, ctx, input));
+ if (!inputSchema.GetColumnType(srcCol).IsText)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, TextType.Instance.ToString(), inputSchema.GetColumnType(srcCol).ToString());
}
public override void Save(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
+
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
// *** Binary format ***
//
- ctx.Writer.Write(_imageFolder);
- SaveBase(ctx);
+ // int: id of image folder
+
+ base.SaveColumns(ctx);
+ ctx.SaveStringOrNull(ImageFolder);
}
- protected override ColumnType GetColumnTypeCore(int iinfo)
+ private static VersionInfo GetVersionInfo()
{
- Host.Check(0 <= iinfo && iinfo < Infos.Length);
- return _type;
+ return new VersionInfo(
+ modelSignature: "IMGLOADR",
+ //verWrittenCur: 0x00010001, // Initial
+ verWrittenCur: 0x00010002, // Swith from OpenCV to Bitmap
+ verReadableCur: 0x00010002,
+ verWeCanReadBack: 0x00010002,
+ loaderSignature: LoaderSignature);
}
- protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
+ protected override IRowMapper MakeRowMapper(ISchema schema)
+ => new Mapper(this, schema);
+
+ private sealed class Mapper : MapperBase
{
- Host.AssertValue(ch, nameof(ch));
- Host.AssertValue(input);
- Host.Assert(0 <= iinfo && iinfo < Infos.Length);
- disposer = null;
-
- var getSrc = GetSrcGetter(input, iinfo);
- DvText src = default;
- ValueGetter del =
- (ref Bitmap dst) =>
- {
- if (dst != null)
- {
- dst.Dispose();
- dst = null;
- }
+ private readonly ImageLoaderTransform _parent;
+ private readonly ImageType _imageType;
- getSrc(ref src);
+ public Mapper(ImageLoaderTransform parent, ISchema inputSchema)
+ : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
+ {
+ _imageType = new ImageType();
+ _parent = parent;
+ }
- if (src.Length > 0)
+ protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer)
+ {
+ Contracts.AssertValue(input);
+ Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);
+
+ disposer = null;
+ var getSrc = input.GetGetter(ColMapNewToOld[iinfo]);
+ DvText src = default;
+ ValueGetter del =
+ (ref Bitmap dst) =>
{
- // Catch exceptions and pass null through. Should also log failures...
- try
+ if (dst != null)
{
- string path = src.ToString();
- if (!string.IsNullOrWhiteSpace(_imageFolder))
- path = Path.Combine(_imageFolder, path);
- dst = new Bitmap(path);
+ dst.Dispose();
+ dst = null;
}
- catch (Exception e)
+
+ getSrc(ref src);
+
+ if (src.Length > 0)
{
- // REVIEW: We catch everything since the documentation for new Bitmap(string)
- // appears to be incorrect. When the file isn't found, it throws an ArgumentException,
- // while the documentation says FileNotFoundException. Not sure what it will throw
- // in other cases, like corrupted file, etc.
-
- // REVIEW : Log failures.
- ch.Info(e.Message);
- ch.Info(e.StackTrace);
- dst = null;
+ // Catch exceptions and pass null through. Should also log failures...
+ try
+ {
+ string path = src.ToString();
+ if (!string.IsNullOrWhiteSpace(_parent.ImageFolder))
+ path = Path.Combine(_parent.ImageFolder, path);
+ dst = new Bitmap(path);
+ }
+ catch (Exception)
+ {
+ // REVIEW: We catch everything since the documentation for new Bitmap(string)
+ // appears to be incorrect. When the file isn't found, it throws an ArgumentException,
+ // while the documentation says FileNotFoundException. Not sure what it will throw
+ // in other cases, like corrupted file, etc.
+
+ // REVIEW : Log failures.
+ dst = null;
+ }
}
- }
- };
- return del;
+ };
+ return del;
+ }
+
+ public override RowMapperColumnInfo[] GetOutputColumns()
+ => _parent.ColumnPairs.Select(x => new RowMapperColumnInfo(x.output, _imageType, null)).ToArray();
+ }
+ }
+
+ public sealed class ImageLoaderEstimator : TrivialEstimator
+ {
+ private readonly ImageType _imageType;
+
+ public ImageLoaderEstimator(IHostEnvironment env, string imageFolder, params (string input, string output)[] columns)
+ : this(env, new ImageLoaderTransform(env, imageFolder, columns))
+ {
+ }
+
+ public ImageLoaderEstimator(IHostEnvironment env, ImageLoaderTransform transformer)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageLoaderEstimator)), transformer)
+ {
+ _imageType = new ImageType();
+ }
+
+ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ Host.CheckValue(inputSchema, nameof(inputSchema));
+ var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ foreach (var (input, output) in Transformer.Columns)
+ {
+ var col = inputSchema.FindColumn(input);
+
+ if (col == null)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
+ if (!col.ItemType.IsText || col.Kind != SchemaShape.Column.VectorKind.Scalar)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, TextType.Instance.ToString(), col.GetTypeString());
+
+ result[output] = new SchemaShape.Column(output, SchemaShape.Column.VectorKind.Scalar, _imageType, false);
+ }
+
+ return new SchemaShape(result.Values);
}
}
}
diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs
index de0aa98124..0bd5bf7879 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs
@@ -3,8 +3,11 @@
// See the LICENSE file in the project root for more information.
using System;
+using System.Collections.Generic;
using System.Drawing;
+using System.Linq;
using System.Text;
+using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
@@ -13,19 +16,24 @@
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
-[assembly: LoadableClass(ImagePixelExtractorTransform.Summary, typeof(ImagePixelExtractorTransform), typeof(ImagePixelExtractorTransform.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(ImagePixelExtractorTransform.Summary, typeof(IDataTransform), typeof(ImagePixelExtractorTransform), typeof(ImagePixelExtractorTransform.Arguments), typeof(SignatureDataTransform),
ImagePixelExtractorTransform.UserName, "ImagePixelExtractorTransform", "ImagePixelExtractor")]
-[assembly: LoadableClass(ImagePixelExtractorTransform.Summary, typeof(ImagePixelExtractorTransform), null, typeof(SignatureLoadDataTransform),
+[assembly: LoadableClass(ImagePixelExtractorTransform.Summary, typeof(IDataTransform), typeof(ImagePixelExtractorTransform), null, typeof(SignatureLoadDataTransform),
+ ImagePixelExtractorTransform.UserName, ImagePixelExtractorTransform.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(ImagePixelExtractorTransform), null, typeof(SignatureLoadModel),
+ ImagePixelExtractorTransform.UserName, ImagePixelExtractorTransform.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(IRowMapper), typeof(ImagePixelExtractorTransform), null, typeof(SignatureLoadRowMapper),
ImagePixelExtractorTransform.UserName, ImagePixelExtractorTransform.LoaderSignature)]
namespace Microsoft.ML.Runtime.ImageAnalytics
{
- // REVIEW: Rewrite as LambdaTransform to simplify.
///
/// Transform which takes one or many columns of and convert them into vector representation.
///
- public sealed class ImagePixelExtractorTransform : OneToOneTransformBase
+ public sealed class ImagePixelExtractorTransform : OneToOneTransformerBase
{
public class Column : OneToOneColumn
{
@@ -110,24 +118,28 @@ public class Arguments : TransformInputBase
/// Which color channels are extracted. Note that these values are serialized so should not be modified.
///
[Flags]
- private enum ColorBits : byte
+ public enum ColorBits : byte
{
Alpha = 0x01,
Red = 0x02,
Green = 0x04,
Blue = 0x08,
+ Rgb = Red | Green | Blue,
All = Alpha | Red | Green | Blue
}
- private sealed class ColInfoEx
+ public sealed class ColumnInfo
{
+ public readonly string Input;
+ public readonly string Output;
+
public readonly ColorBits Colors;
public readonly byte Planes;
public readonly bool Convert;
- public readonly Single Offset;
- public readonly Single Scale;
+ public readonly float Offset;
+ public readonly float Scale;
public readonly bool Interleave;
public bool Alpha { get { return (Colors & ColorBits.Alpha) != 0; } }
@@ -135,8 +147,14 @@ private sealed class ColInfoEx
public bool Green { get { return (Colors & ColorBits.Green) != 0; } }
public bool Blue { get { return (Colors & ColorBits.Blue) != 0; } }
- public ColInfoEx(Column item, Arguments args)
+ internal ColumnInfo(Column item, Arguments args)
{
+ Contracts.CheckValue(item, nameof(item));
+ Contracts.CheckValue(args, nameof(args));
+
+ Input = item.Source ?? item.Name;
+ Output = item.Name;
+
if (item.UseAlpha ?? args.UseAlpha) { Colors |= ColorBits.Alpha; Planes++; }
if (item.UseRed ?? args.UseRed) { Colors |= ColorBits.Red; Planes++; }
if (item.UseGreen ?? args.UseGreen) { Colors |= ColorBits.Green; Planes++; }
@@ -160,10 +178,57 @@ public ColInfoEx(Column item, Arguments args)
}
}
- public ColInfoEx(ModelLoadContext ctx)
+ public ColumnInfo(string input, string output, ColorBits colors = ColorBits.Rgb, bool interleave = false)
+ : this(input, output, colors, interleave, false, 1f, 0f)
+ {
+ }
+
+ public ColumnInfo(string input, string output, ColorBits colors = ColorBits.Rgb, bool interleave = false, float scale = 1f, float offset = 0f)
+ : this(input, output, colors, interleave, true, scale, offset)
+ {
+ }
+
+ private ColumnInfo(string input, string output, ColorBits colors, bool interleave, bool convert, float scale, float offset)
+ {
+ Contracts.CheckNonEmpty(input, nameof(input));
+ Contracts.CheckNonEmpty(output, nameof(output));
+
+ Input = input;
+ Output = output;
+ Colors = colors;
+
+ if ((Colors & ColorBits.Alpha) == ColorBits.Alpha) Planes++;
+ if ((Colors & ColorBits.Red) == ColorBits.Red) Planes++;
+ if ((Colors & ColorBits.Green) == ColorBits.Green) Planes++;
+ if ((Colors & ColorBits.Blue) == ColorBits.Blue) Planes++;
+ Contracts.CheckParam(Planes > 0, nameof(colors), "Need to use at least one color plane");
+
+ Interleave = interleave;
+
+ Convert = convert;
+ if (!Convert)
+ {
+ Offset = 0;
+ Scale = 1;
+ }
+ else
+ {
+ Offset = offset;
+ Scale = scale;
+ Contracts.CheckParam(FloatUtils.IsFinite(Offset), nameof(offset));
+ Contracts.CheckParam(FloatUtils.IsFiniteNonZero(Scale), nameof(scale));
+ }
+ }
+
+ internal ColumnInfo(string input, string output, ModelLoadContext ctx)
{
+ Contracts.AssertNonEmpty(input);
+ Contracts.AssertNonEmpty(output);
Contracts.AssertValue(ctx);
+ Input = input;
+ Output = output;
+
// *** Binary format ***
// byte: colors
// byte: convert
@@ -193,7 +258,6 @@ public ColInfoEx(ModelLoadContext ctx)
public void Save(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
-
#if DEBUG
// This code is used in deserialization - assert that it matches what we computed above.
int planes = (int)Colors;
@@ -237,305 +301,368 @@ private static VersionInfo GetVersionInfo()
private const string RegistrationName = "ImagePixelExtractor";
- private readonly ColInfoEx[] _exes;
- private readonly VectorType[] _types;
+ private readonly ColumnInfo[] _columns;
+
+ public IReadOnlyCollection Columns => _columns.AsReadOnly();
+
+ public ImagePixelExtractorTransform(IHostEnvironment env, string inputColumn, string outputColumn,
+ ColorBits colors = ColorBits.Rgb, bool interleave = false)
+ : this(env, new ColumnInfo(inputColumn, outputColumn, colors, interleave))
+ {
+ }
+
+ public ImagePixelExtractorTransform(IHostEnvironment env, params ColumnInfo[] columns)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns))
+ {
+ _columns = columns.ToArray();
+ }
- // Public constructor corresponding to SignatureDataTransform.
- public ImagePixelExtractorTransform(IHostEnvironment env, Arguments args, IDataView input)
- : base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column, input,
- t => t is ImageType ? null : "Expected Image type")
+ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns)
{
- Host.AssertNonEmpty(Infos);
- Host.Assert(Infos.Length == Utils.Size(args.Column));
+ Contracts.CheckValue(columns, nameof(columns));
+ return columns.Select(x => (x.Input, x.Output)).ToArray();
+ }
- _exes = new ColInfoEx[Infos.Length];
- for (int i = 0; i < _exes.Length; i++)
+ // SignatureDataTransform.
+ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(args, nameof(args));
+ env.CheckValue(input, nameof(input));
+
+ env.CheckValue(args.Column, nameof(args.Column));
+
+ var columns = new ColumnInfo[args.Column.Length];
+ for (int i = 0; i < columns.Length; i++)
{
var item = args.Column[i];
- _exes[i] = new ColInfoEx(item, args);
+ columns[i] = new ColumnInfo(item, args);
}
- _types = ConstructTypes(true);
+ var transformer = new ImagePixelExtractorTransform(env, columns);
+ return new RowToRowMapperTransform(env, input, transformer.MakeRowMapper(input.Schema));
}
- private ImagePixelExtractorTransform(IHost host, ModelLoadContext ctx, IDataView input)
- : base(host, ctx, input, t => t is ImageType ? null : "Expected Image type")
+ public static ImagePixelExtractorTransform Create(IHostEnvironment env, ModelLoadContext ctx)
{
- Host.AssertValue(ctx);
+ Contracts.CheckValue(env, nameof(env));
+ var host = env.Register(RegistrationName);
+ host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel(GetVersionInfo());
- // *** Binary format ***
- //
- //
- // foreach added column
- // ColInfoEx
- Host.AssertNonEmpty(Infos);
- _exes = new ColInfoEx[Infos.Length];
- for (int i = 0; i < _exes.Length; i++)
- _exes[i] = new ColInfoEx(ctx);
-
- _types = ConstructTypes(false);
+ return new ImagePixelExtractorTransform(host, ctx);
}
- public static ImagePixelExtractorTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ private ImagePixelExtractorTransform(IHost host, ModelLoadContext ctx)
+ : base(host, ctx)
{
- Contracts.CheckValue(env, nameof(env));
- var h = env.Register(RegistrationName);
- h.CheckValue(ctx, nameof(ctx));
- h.CheckValue(input, nameof(input));
- ctx.CheckAtModel(GetVersionInfo());
+ // *** Binary format ***
+ //
- return h.Apply("Loading Model",
- ch =>
- {
- // *** Binary format ***
- // int: sizeof(Float)
- //
- int cbFloat = ctx.Reader.ReadInt32();
- ch.CheckDecode(cbFloat == sizeof(Single));
- return new ImagePixelExtractorTransform(h, ctx, input);
- });
+ // for each added column
+ // ColumnInfo
+
+ _columns = new ColumnInfo[ColumnPairs.Length];
+ for (int i = 0; i < _columns.Length; i++)
+ _columns[i] = new ColumnInfo(ColumnPairs[i].input, ColumnPairs[i].output, ctx);
}
+ // Factory method for SignatureLoadDataTransform.
+ public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ => Create(env, ctx).MakeDataTransform(input);
+
+ // Factory method for SignatureLoadRowMapper.
+ public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
+ => Create(env, ctx).MakeRowMapper(inputSchema);
+
public override void Save(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
+
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
// *** Binary format ***
- // int: sizeof(Float)
//
- // foreach added column
- // ColInfoEx
- ctx.Writer.Write(sizeof(Single));
- SaveBase(ctx);
-
- Host.Assert(_exes.Length == Infos.Length);
- for (int i = 0; i < _exes.Length; i++)
- _exes[i].Save(ctx);
- }
- private VectorType[] ConstructTypes(bool user)
- {
- var types = new VectorType[Infos.Length];
- for (int i = 0; i < Infos.Length; i++)
- {
- var info = Infos[i];
- var ex = _exes[i];
- Host.Assert(ex.Planes > 0);
+ // for each added column
+ // ColumnInfo
- var type = Source.Schema.GetColumnType(info.Source) as ImageType;
- Host.Assert(type != null);
- if (type.Height <= 0 || type.Width <= 0)
- {
- // REVIEW: Could support this case by making the destination column be variable sized.
- // However, there's no mechanism to communicate the dimensions through with the pixel data.
- string name = Source.Schema.GetColumnName(info.Source);
- throw user ?
- Host.ExceptUserArg(nameof(Arguments.Column), "Column '{0}' does not have known size", name) :
- Host.Except("Column '{0}' does not have known size", name);
- }
- int height = type.Height;
- int width = type.Width;
- Host.Assert(height > 0);
- Host.Assert(width > 0);
- Host.Assert((long)height * width <= int.MaxValue / 4);
-
- if (ex.Interleave)
- types[i] = new VectorType(ex.Convert ? NumberType.Float : NumberType.U1, height, width, ex.Planes);
- else
- types[i] = new VectorType(ex.Convert ? NumberType.Float : NumberType.U1, ex.Planes, height, width);
- }
- Metadata.Seal();
- return types;
+ base.SaveColumns(ctx);
+
+ foreach (ColumnInfo info in _columns)
+ info.Save(ctx);
}
- protected override ColumnType GetColumnTypeCore(int iinfo)
+ protected override IRowMapper MakeRowMapper(ISchema schema)
+ => new Mapper(this, schema);
+
+ protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol)
{
- Host.Assert(0 <= iinfo & iinfo < Infos.Length);
- return _types[iinfo];
+ var inputColName = _columns[col].Input;
+ var imageType = inputSchema.GetColumnType(srcCol) as ImageType;
+ if (imageType == null)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputColName, "image", inputSchema.GetColumnType(srcCol).ToString());
+ if (imageType.Height <= 0 || imageType.Width <= 0)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputColName, "known-size image", "unknown-size image");
+ if ((long)imageType.Height * imageType.Width > int.MaxValue / 4)
+ throw Host.Except("Image dimensions are too large");
}
- protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
+ private sealed class Mapper : MapperBase
{
- Host.AssertValueOrNull(ch);
- Host.AssertValue(input);
- Host.Assert(0 <= iinfo && iinfo < Infos.Length);
+ private readonly ImagePixelExtractorTransform _parent;
+ private readonly VectorType[] _types;
- if (_exes[iinfo].Convert)
- return GetGetterCore(input, iinfo, out disposer);
- return GetGetterCore(input, iinfo, out disposer);
- }
+ public Mapper(ImagePixelExtractorTransform parent, ISchema inputSchema)
+ : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
+ {
+ _parent = parent;
+ _types = ConstructTypes();
+ }
- //REVIEW Rewrite it to where TValue : IConvertible
- private ValueGetter> GetGetterCore(IRow input, int iinfo, out Action disposer)
- {
- var type = _types[iinfo];
- Host.Assert(type.DimCount == 3);
+ public override RowMapperColumnInfo[] GetOutputColumns()
+ => _parent._columns.Select((x, idx) => new RowMapperColumnInfo(x.Output, _types[idx], null)).ToArray();
- var ex = _exes[iinfo];
+ protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer)
+ {
+ Contracts.AssertValue(input);
+ Contracts.Assert(0 <= iinfo && iinfo < _parent._columns.Length);
- int planes = ex.Interleave ? type.GetDim(2) : type.GetDim(0);
- int height = ex.Interleave ? type.GetDim(0) : type.GetDim(1);
- int width = ex.Interleave ? type.GetDim(1) : type.GetDim(2);
+ if (_parent._columns[iinfo].Convert)
+ return GetGetterCore(input, iinfo, out disposer);
+ return GetGetterCore(input, iinfo, out disposer);
+ }
- int size = type.ValueCount;
- Host.Assert(size > 0);
- Host.Assert(size == planes * height * width);
- int cpix = height * width;
+ //REVIEW Rewrite it to where TValue : IConvertible
+ private ValueGetter> GetGetterCore(IRow input, int iinfo, out Action disposer)
+ {
+ var type = _types[iinfo];
+ Contracts.Assert(type.DimCount == 3);
- var getSrc = GetSrcGetter(input, iinfo);
- var src = default(Bitmap);
+ var ex = _parent._columns[iinfo];
- disposer =
- () =>
- {
- if (src != null)
- {
- src.Dispose();
- src = null;
- }
- };
+ int planes = ex.Interleave ? type.GetDim(2) : type.GetDim(0);
+ int height = ex.Interleave ? type.GetDim(0) : type.GetDim(1);
+ int width = ex.Interleave ? type.GetDim(1) : type.GetDim(2);
- return
- (ref VBuffer dst) =>
- {
- getSrc(ref src);
- Contracts.AssertValueOrNull(src);
+ int size = type.ValueCount;
+ Contracts.Assert(size > 0);
+ Contracts.Assert(size == planes * height * width);
+ int cpix = height * width;
+
+ var getSrc = input.GetGetter(ColMapNewToOld[iinfo]);
+ var src = default(Bitmap);
- if (src == null)
+ disposer =
+ () =>
{
- dst = new VBuffer(size, 0, dst.Values, dst.Indices);
- return;
- }
+ if (src != null)
+ {
+ src.Dispose();
+ src = null;
+ }
+ };
- Host.Check(src.PixelFormat == System.Drawing.Imaging.PixelFormat.Format32bppArgb);
- Host.Check(src.Height == height && src.Width == width);
+ return
+ (ref VBuffer dst) =>
+ {
+ getSrc(ref src);
+ Contracts.AssertValueOrNull(src);
- var values = dst.Values;
- if (Utils.Size(values) < size)
- values = new TValue[size];
+ if (src == null)
+ {
+ dst = new VBuffer(size, 0, dst.Values, dst.Indices);
+ return;
+ }
- Single offset = ex.Offset;
- Single scale = ex.Scale;
- Host.Assert(scale != 0);
+ Host.Check(src.PixelFormat == System.Drawing.Imaging.PixelFormat.Format32bppArgb);
+ Host.Check(src.Height == height && src.Width == width);
- var vf = values as Single[];
- var vb = values as byte[];
- Host.Assert(vf != null || vb != null);
- bool needScale = offset != 0 || scale != 1;
- Host.Assert(!needScale || vf != null);
+ var values = dst.Values;
+ if (Utils.Size(values) < size)
+ values = new TValue[size];
- bool a = ex.Alpha;
- bool r = ex.Red;
- bool g = ex.Green;
- bool b = ex.Blue;
+ float offset = ex.Offset;
+ float scale = ex.Scale;
+ Contracts.Assert(scale != 0);
- int h = height;
- int w = width;
+ var vf = values as float[];
+ var vb = values as byte[];
+ Contracts.Assert(vf != null || vb != null);
+ bool needScale = offset != 0 || scale != 1;
+ Contracts.Assert(!needScale || vf != null);
- if (ex.Interleave)
- {
- int idst = 0;
- for (int y = 0; y < h; ++y)
- for (int x = 0; x < w; x++)
- {
- var pb = src.GetPixel(y, x);
- if (vb != null)
+ bool a = ex.Alpha;
+ bool r = ex.Red;
+ bool g = ex.Green;
+ bool b = ex.Blue;
+
+ int h = height;
+ int w = width;
+
+ if (ex.Interleave)
+ {
+ int idst = 0;
+ for (int y = 0; y < h; ++y)
+ for (int x = 0; x < w; x++)
{
- if (a) { vb[idst++] = (byte)0; }
- if (r) { vb[idst++] = pb.R; }
- if (g) { vb[idst++] = pb.G; }
- if (b) { vb[idst++] = pb.B; }
+ var pb = src.GetPixel(y, x);
+ if (vb != null)
+ {
+ if (a) { vb[idst++] = (byte)0; }
+ if (r) { vb[idst++] = pb.R; }
+ if (g) { vb[idst++] = pb.G; }
+ if (b) { vb[idst++] = pb.B; }
+ }
+ else if (!needScale)
+ {
+ if (a) { vf[idst++] = 0.0f; }
+ if (r) { vf[idst++] = pb.R; }
+ if (g) { vf[idst++] = pb.G; }
+ if (b) { vf[idst++] = pb.B; }
+ }
+ else
+ {
+ if (a) { vf[idst++] = 0.0f; }
+ if (r) { vf[idst++] = (pb.R - offset) * scale; }
+ if (g) { vf[idst++] = (pb.B - offset) * scale; }
+ if (b) { vf[idst++] = (pb.G - offset) * scale; }
+ }
}
- else if (!needScale)
+ Contracts.Assert(idst == size);
+ }
+ else
+ {
+ int idstMin = 0;
+ if (ex.Alpha)
+ {
+ // The image only has rgb but we need to supply alpha as well, so fake it up,
+ // assuming that it is 0xFF.
+ if (vf != null)
{
- if (a) { vf[idst++] = 0.0f; }
- if (r) { vf[idst++] = pb.R; }
- if (g) { vf[idst++] = pb.G; }
- if (b) { vf[idst++] = pb.B; }
+ Single v = (0xFF - offset) * scale;
+ for (int i = 0; i < cpix; i++)
+ vf[i] = v;
}
else
{
- if (a) { vf[idst++] = 0.0f; }
- if (r) { vf[idst++] = (pb.R - offset) * scale; }
- if (g) { vf[idst++] = (pb.B - offset) * scale; }
- if (b) { vf[idst++] = (pb.G - offset) * scale; }
+ for (int i = 0; i < cpix; i++)
+ vb[i] = 0xFF;
}
- }
- Host.Assert(idst == size);
- }
- else
- {
- int idstMin = 0;
- if (ex.Alpha)
- {
- // The image only has rgb but we need to supply alpha as well, so fake it up,
- // assuming that it is 0xFF.
- if (vf != null)
- {
- Single v = (0xFF - offset) * scale;
- for (int i = 0; i < cpix; i++)
- vf[i] = v;
- }
- else
- {
- for (int i = 0; i < cpix; i++)
- vb[i] = 0xFF;
- }
- idstMin = cpix;
+ idstMin = cpix;
- // We've preprocessed alpha, avoid it in the
- // scan operation below.
- a = false;
- }
-
- for (int y = 0; y < h; ++y)
- {
- int idstBase = idstMin + y * w;
+ // We've preprocessed alpha, avoid it in the
+ // scan operation below.
+ a = false;
+ }
- // Note that the bytes are in order BGR[A]. We arrange the layers in order ARGB.
- if (vb != null)
+ for (int y = 0; y < h; ++y)
{
- for (int x = 0; x < w; x++, idstBase++)
+ int idstBase = idstMin + y * w;
+
+ // Note that the bytes are in order BGR[A]. We arrange the layers in order ARGB.
+ if (vb != null)
{
- var pb = src.GetPixel(x, y);
- int idst = idstBase;
- if (a) { vb[idst] = pb.A; idst += cpix; }
- if (r) { vb[idst] = pb.R; idst += cpix; }
- if (g) { vb[idst] = pb.G; idst += cpix; }
- if (b) { vb[idst] = pb.B; idst += cpix; }
+ for (int x = 0; x < w; x++, idstBase++)
+ {
+ var pb = src.GetPixel(x, y);
+ int idst = idstBase;
+ if (a) { vb[idst] = pb.A; idst += cpix; }
+ if (r) { vb[idst] = pb.R; idst += cpix; }
+ if (g) { vb[idst] = pb.G; idst += cpix; }
+ if (b) { vb[idst] = pb.B; idst += cpix; }
+ }
}
- }
- else if (!needScale)
- {
- for (int x = 0; x < w; x++, idstBase++)
+ else if (!needScale)
{
- var pb = src.GetPixel(x, y);
- int idst = idstBase;
- if (a) { vf[idst] = pb.A; idst += cpix; }
- if (r) { vf[idst] = pb.R; idst += cpix; }
- if (g) { vf[idst] = pb.G; idst += cpix; }
- if (b) { vf[idst] = pb.B; idst += cpix; }
+ for (int x = 0; x < w; x++, idstBase++)
+ {
+ var pb = src.GetPixel(x, y);
+ int idst = idstBase;
+ if (a) { vf[idst] = pb.A; idst += cpix; }
+ if (r) { vf[idst] = pb.R; idst += cpix; }
+ if (g) { vf[idst] = pb.G; idst += cpix; }
+ if (b) { vf[idst] = pb.B; idst += cpix; }
+ }
}
- }
- else
- {
- for (int x = 0; x < w; x++, idstBase++)
+ else
{
- var pb = src.GetPixel(x, y);
- int idst = idstBase;
- if (a) { vf[idst] = (pb.A - offset) * scale; idst += cpix; }
- if (r) { vf[idst] = (pb.R - offset) * scale; idst += cpix; }
- if (g) { vf[idst] = (pb.G - offset) * scale; idst += cpix; }
- if (b) { vf[idst] = (pb.B - offset) * scale; idst += cpix; }
+ for (int x = 0; x < w; x++, idstBase++)
+ {
+ var pb = src.GetPixel(x, y);
+ int idst = idstBase;
+ if (a) { vf[idst] = (pb.A - offset) * scale; idst += cpix; }
+ if (r) { vf[idst] = (pb.R - offset) * scale; idst += cpix; }
+ if (g) { vf[idst] = (pb.G - offset) * scale; idst += cpix; }
+ if (b) { vf[idst] = (pb.B - offset) * scale; idst += cpix; }
+ }
}
}
}
- }
- dst = new VBuffer(size, values, dst.Indices);
- };
+ dst = new VBuffer(size, values, dst.Indices);
+ };
+ }
+
+ private VectorType[] ConstructTypes()
+ {
+ var types = new VectorType[_parent._columns.Length];
+ for (int i = 0; i < _parent._columns.Length; i++)
+ {
+ var column = _parent._columns[i];
+ Contracts.Assert(column.Planes > 0);
+
+ var type = InputSchema.GetColumnType(ColMapNewToOld[i]) as ImageType;
+ Contracts.Assert(type != null);
+
+ int height = type.Height;
+ int width = type.Width;
+ Contracts.Assert(height > 0);
+ Contracts.Assert(width > 0);
+ Contracts.Assert((long)height * width <= int.MaxValue / 4);
+
+ if (column.Interleave)
+ types[i] = new VectorType(column.Convert ? NumberType.Float : NumberType.U1, height, width, column.Planes);
+ else
+ types[i] = new VectorType(column.Convert ? NumberType.Float : NumberType.U1, column.Planes, height, width);
+ }
+ return types;
+ }
+ }
+ }
+
+ public sealed class ImagePixelExtractorEstimator : TrivialEstimator
+ {
+ public ImagePixelExtractorEstimator(IHostEnvironment env, string inputColumn, string outputColumn,
+ ImagePixelExtractorTransform.ColorBits colors = ImagePixelExtractorTransform.ColorBits.Rgb, bool interleave = false)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImagePixelExtractorEstimator)), new ImagePixelExtractorTransform(env, inputColumn, outputColumn, colors, interleave))
+ {
+ }
+
+ public ImagePixelExtractorEstimator(IHostEnvironment env, params ImagePixelExtractorTransform.ColumnInfo[] columns)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImagePixelExtractorEstimator)), new ImagePixelExtractorTransform(env, columns))
+ {
+ }
+
+ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ Host.CheckValue(inputSchema, nameof(inputSchema));
+ var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ foreach (var colInfo in Transformer.Columns)
+ {
+ var col = inputSchema.FindColumn(colInfo.Input);
+
+ if (col == null)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
+ if (!(col.ItemType is ImageType) || col.Kind != SchemaShape.Column.VectorKind.Scalar)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, new ImageType().ToString(), col.GetTypeString());
+
+ var itemType = colInfo.Convert ? NumberType.R4 : NumberType.U1;
+ result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, itemType, false);
+ }
+
+ return new SchemaShape(result.Values);
}
}
}
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs
index dd1abc9181..ac11c7fa8d 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs
@@ -3,8 +3,11 @@
// See the LICENSE file in the project root for more information.
using System;
+using System.Collections.Generic;
using System.Drawing;
+using System.Linq;
using System.Text;
+using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
@@ -20,13 +23,19 @@
[assembly: LoadableClass(ImageResizerTransform.Summary, typeof(ImageResizerTransform), null, typeof(SignatureLoadDataTransform),
ImageResizerTransform.UserName, ImageResizerTransform.LoaderSignature)]
+[assembly: LoadableClass(typeof(ImageResizerTransform), null, typeof(SignatureLoadModel),
+ ImageResizerTransform.UserName, ImageResizerTransform.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(IRowMapper), typeof(ImageResizerTransform), null, typeof(SignatureLoadRowMapper),
+ ImageResizerTransform.UserName, ImageResizerTransform.LoaderSignature)]
+
namespace Microsoft.ML.Runtime.ImageAnalytics
{
// REVIEW: Rewrite as LambdaTransform to simplify.
///
/// Transform which takes one or many columns of and resize them to provided height and width.
///
- public sealed class ImageResizerTransform : OneToOneTransformBase
+ public sealed class ImageResizerTransform : OneToOneTransformerBase
{
public enum ResizingKind : byte
{
@@ -98,23 +107,30 @@ public class Arguments : TransformInputBase
}
///
- /// Extra information for each column (in addition to ColumnInfo).
+ /// Information for each column pair.
///
- private sealed class ColInfoEx
+ public sealed class ColumnInfo
{
+ public readonly string Input;
+ public readonly string Output;
+
public readonly int Width;
public readonly int Height;
public readonly ResizingKind Scale;
public readonly Anchor Anchor;
public readonly ColumnType Type;
- public ColInfoEx(int width, int height, ResizingKind scale, Anchor anchor)
+ public ColumnInfo(string input, string output, int width, int height, ResizingKind scale, Anchor anchor)
{
+ Contracts.CheckNonEmpty(input, nameof(input));
+ Contracts.CheckNonEmpty(output, nameof(output));
Contracts.CheckUserArg(width > 0, nameof(Column.ImageWidth));
Contracts.CheckUserArg(height > 0, nameof(Column.ImageHeight));
Contracts.CheckUserArg(Enum.IsDefined(typeof(ResizingKind), scale), nameof(Column.Resizing));
Contracts.CheckUserArg(Enum.IsDefined(typeof(Anchor), anchor), nameof(Column.CropAnchor));
+ Input = input;
+ Output = output;
Width = width;
Height = height;
Scale = scale;
@@ -133,53 +149,87 @@ private static VersionInfo GetVersionInfo()
return new VersionInfo(
modelSignature: "IMGSCALF",
//verWrittenCur: 0x00010001, // Initial
- verWrittenCur: 0x00010002, // Swith from OpenCV to Bitmap
- verReadableCur: 0x00010002,
- verWeCanReadBack: 0x00010002,
+ //verWrittenCur: 0x00010002, // Swith from OpenCV to Bitmap
+ verWrittenCur: 0x00010003, // No more sizeof(float)
+ verReadableCur: 0x00010003,
+ verWeCanReadBack: 0x00010003,
loaderSignature: LoaderSignature);
}
private const string RegistrationName = "ImageScaler";
- // This is parallel to Infos.
- private readonly ColInfoEx[] _exes;
+ private readonly ColumnInfo[] _columns;
+
+ public IReadOnlyCollection Columns => _columns.AsReadOnly();
+
+ public ImageResizerTransform(IHostEnvironment env, string inputColumn, string outputColumn,
+ int imageWidth, int imageHeight, ResizingKind resizing = ResizingKind.IsoCrop, Anchor cropAnchor = Anchor.Center)
+ : this(env, new ColumnInfo(inputColumn, outputColumn, imageWidth, imageHeight, resizing, cropAnchor))
+ {
+ }
+
+ public ImageResizerTransform(IHostEnvironment env, params ColumnInfo[] columns)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns))
+ {
+ _columns = columns.ToArray();
+ }
+
+ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns)
+ {
+ Contracts.CheckValue(columns, nameof(columns));
+ return columns.Select(x => (x.Input, x.Output)).ToArray();
+ }
- // Public constructor corresponding to SignatureDataTransform.
- public ImageResizerTransform(IHostEnvironment env, Arguments args, IDataView input)
- : base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, t => t is ImageType ? null : "Expected Image type")
+ // Factory method for SignatureDataTransform.
+ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
- Host.AssertNonEmpty(Infos);
- Host.Assert(Infos.Length == Utils.Size(args.Column));
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(args, nameof(args));
+ env.CheckValue(input, nameof(input));
+
+ env.CheckValue(args.Column, nameof(args.Column));
- _exes = new ColInfoEx[Infos.Length];
- for (int i = 0; i < _exes.Length; i++)
+ var cols = new ColumnInfo[args.Column.Length];
+ for (int i = 0; i < cols.Length; i++)
{
var item = args.Column[i];
- _exes[i] = new ColInfoEx(
+ cols[i] = new ColumnInfo(
+ item.Source ?? item.Name,
+ item.Name,
item.ImageWidth ?? args.ImageWidth,
item.ImageHeight ?? args.ImageHeight,
item.Resizing ?? args.Resizing,
item.CropAnchor ?? args.CropAnchor);
}
- Metadata.Seal();
+
+ return new ImageResizerTransform(env, cols).MakeDataTransform(input);
}
- private ImageResizerTransform(IHost host, ModelLoadContext ctx, IDataView input)
- : base(host, ctx, input, t => t is ImageType ? null : "Expected Image type")
+ public static ImageResizerTransform Create(IHostEnvironment env, ModelLoadContext ctx)
{
- Host.AssertValue(ctx);
+ Contracts.CheckValue(env, nameof(env));
+ var host = env.Register(RegistrationName);
+ host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel(GetVersionInfo());
+
+ return new ImageResizerTransform(host, ctx);
+ }
+
+ private ImageResizerTransform(IHost host, ModelLoadContext ctx)
+ : base(host, ctx)
+ {
// *** Binary format ***
- //
//
+
// for each added column
// int: width
// int: height
// byte: scaling kind
- Host.AssertNonEmpty(Infos);
+ // byte: anchor
- _exes = new ColInfoEx[Infos.Length];
- for (int i = 0; i < _exes.Length; i++)
+ _columns = new ColumnInfo[ColumnPairs.Length];
+ for (int i = 0; i < ColumnPairs.Length; i++)
{
int width = ctx.Reader.ReadInt32();
Host.CheckDecode(width > 0);
@@ -189,182 +239,224 @@ private ImageResizerTransform(IHost host, ModelLoadContext ctx, IDataView input)
Host.CheckDecode(Enum.IsDefined(typeof(ResizingKind), scale));
var anchor = (Anchor)ctx.Reader.ReadByte();
Host.CheckDecode(Enum.IsDefined(typeof(Anchor), anchor));
- _exes[i] = new ColInfoEx(width, height, scale, anchor);
+ _columns[i] = new ColumnInfo(ColumnPairs[i].input, ColumnPairs[i].output, width, height, scale, anchor);
}
- Metadata.Seal();
}
- public static ImageResizerTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
- {
- Contracts.CheckValue(env, nameof(env));
- var h = env.Register(RegistrationName);
- h.CheckValue(ctx, nameof(ctx));
- h.CheckValue(input, nameof(input));
- ctx.CheckAtModel(GetVersionInfo());
- return h.Apply("Loading Model",
- ch =>
- {
- // *** Binary format ***
- // int: sizeof(Float)
- //
- int cbFloat = ctx.Reader.ReadInt32();
- ch.CheckDecode(cbFloat == sizeof(Single));
- return new ImageResizerTransform(h, ctx, input);
- });
- }
+ // Factory method for SignatureLoadDataTransform.
+ public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ => Create(env, ctx).MakeDataTransform(input);
+
+ // Factory method for SignatureLoadRowMapper.
+ public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
+ => Create(env, ctx).MakeRowMapper(inputSchema);
public override void Save(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
+
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
// *** Binary format ***
- // int: sizeof(Float)
//
+
// for each added column
// int: width
// int: height
// byte: scaling kind
- ctx.Writer.Write(sizeof(Single));
- SaveBase(ctx);
+ // byte: anchor
+
+ base.SaveColumns(ctx);
- Host.Assert(_exes.Length == Infos.Length);
- for (int i = 0; i < _exes.Length; i++)
+ foreach (var col in _columns)
{
- var ex = _exes[i];
- ctx.Writer.Write(ex.Width);
- ctx.Writer.Write(ex.Height);
- Host.Assert((ResizingKind)(byte)ex.Scale == ex.Scale);
- ctx.Writer.Write((byte)ex.Scale);
- Host.Assert((Anchor)(byte)ex.Anchor == ex.Anchor);
- ctx.Writer.Write((byte)ex.Anchor);
+ ctx.Writer.Write(col.Width);
+ ctx.Writer.Write(col.Height);
+ Contracts.Assert((ResizingKind)(byte)col.Scale == col.Scale);
+ ctx.Writer.Write((byte)col.Scale);
+ Contracts.Assert((Anchor)(byte)col.Anchor == col.Anchor);
+ ctx.Writer.Write((byte)col.Anchor);
}
}
- protected override ColumnType GetColumnTypeCore(int iinfo)
+ protected override IRowMapper MakeRowMapper(ISchema schema)
+ => new Mapper(this, schema);
+
+ protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol)
{
- Host.Check(0 <= iinfo && iinfo < Infos.Length);
- return _exes[iinfo].Type;
+ if (!(inputSchema.GetColumnType(srcCol) is ImageType))
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _columns[col].Input, "image", inputSchema.GetColumnType(srcCol).ToString());
}
- protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
+ private sealed class Mapper : MapperBase
{
- Host.AssertValueOrNull(ch);
- Host.AssertValue(input);
- Host.Assert(0 <= iinfo && iinfo < Infos.Length);
-
- var src = default(Bitmap);
- var getSrc = GetSrcGetter(input, iinfo);
- var ex = _exes[iinfo];
-
- disposer =
- () =>
- {
- if (src != null)
- {
- src.Dispose();
- src = null;
- }
- };
-
- ValueGetter del =
- (ref Bitmap dst) =>
- {
- if (dst != null)
- dst.Dispose();
-
- getSrc(ref src);
- if (src == null || src.Height <= 0 || src.Width <= 0)
- return;
- if (src.Height == ex.Height && src.Width == ex.Width)
- {
- dst = src;
- return;
- }
-
- int sourceWidth = src.Width;
- int sourceHeight = src.Height;
- int sourceX = 0;
- int sourceY = 0;
- int destX = 0;
- int destY = 0;
- int destWidth = 0;
- int destHeight = 0;
- float aspect = 0;
- float widthAspect = 0;
- float heightAspect = 0;
-
- widthAspect = (float)ex.Width / sourceWidth;
- heightAspect = (float)ex.Height / sourceHeight;
-
- if (ex.Scale == ResizingKind.IsoPad)
+ private readonly ImageResizerTransform _parent;
+
+ public Mapper(ImageResizerTransform parent, ISchema inputSchema)
+ :base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
+ {
+ _parent = parent;
+ }
+
+ public override RowMapperColumnInfo[] GetOutputColumns()
+ => _parent._columns.Select(x => new RowMapperColumnInfo(x.Output, x.Type, null)).ToArray();
+
+ protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer)
+ {
+ Contracts.AssertValue(input);
+ Contracts.Assert(0 <= iinfo && iinfo < _parent._columns.Length);
+
+ var src = default(Bitmap);
+ var getSrc = input.GetGetter(ColMapNewToOld[iinfo]);
+ var info = _parent._columns[iinfo];
+
+ disposer =
+ () =>
{
- widthAspect = (float)ex.Width / sourceWidth;
- heightAspect = (float)ex.Height / sourceHeight;
- if (heightAspect < widthAspect)
+ if (src != null)
{
- aspect = heightAspect;
- destX = (int)((ex.Width - (sourceWidth * aspect)) / 2);
+ src.Dispose();
+ src = null;
}
- else
+ };
+
+ ValueGetter del =
+ (ref Bitmap dst) =>
+ {
+ if (dst != null)
+ dst.Dispose();
+
+ getSrc(ref src);
+ if (src == null || src.Height <= 0 || src.Width <= 0)
+ return;
+ if (src.Height == info.Height && src.Width == info.Width)
{
- aspect = widthAspect;
- destY = (int)((ex.Height - (sourceHeight * aspect)) / 2);
+ dst = src;
+ return;
}
- destWidth = (int)(sourceWidth * aspect);
- destHeight = (int)(sourceHeight * aspect);
- }
- else
- {
- if (heightAspect < widthAspect)
+ int sourceWidth = src.Width;
+ int sourceHeight = src.Height;
+ int sourceX = 0;
+ int sourceY = 0;
+ int destX = 0;
+ int destY = 0;
+ int destWidth = 0;
+ int destHeight = 0;
+ float aspect = 0;
+ float widthAspect = 0;
+ float heightAspect = 0;
+
+ widthAspect = (float)info.Width / sourceWidth;
+ heightAspect = (float)info.Height / sourceHeight;
+
+ if (info.Scale == ResizingKind.IsoPad)
{
- aspect = widthAspect;
- switch (ex.Anchor)
+ widthAspect = (float)info.Width / sourceWidth;
+ heightAspect = (float)info.Height / sourceHeight;
+ if (heightAspect < widthAspect)
{
- case Anchor.Top:
- destY = 0;
- break;
- case Anchor.Bottom:
- destY = (int)(ex.Height - (sourceHeight * aspect));
- break;
- default:
- destY = (int)((ex.Height - (sourceHeight * aspect)) / 2);
- break;
+ aspect = heightAspect;
+ destX = (int)((info.Width - (sourceWidth * aspect)) / 2);
}
+ else
+ {
+ aspect = widthAspect;
+ destY = (int)((info.Height - (sourceHeight * aspect)) / 2);
+ }
+
+ destWidth = (int)(sourceWidth * aspect);
+ destHeight = (int)(sourceHeight * aspect);
}
else
{
- aspect = heightAspect;
- switch (ex.Anchor)
+ if (heightAspect < widthAspect)
+ {
+ aspect = widthAspect;
+ switch (info.Anchor)
+ {
+ case Anchor.Top:
+ destY = 0;
+ break;
+ case Anchor.Bottom:
+ destY = (int)(info.Height - (sourceHeight * aspect));
+ break;
+ default:
+ destY = (int)((info.Height - (sourceHeight * aspect)) / 2);
+ break;
+ }
+ }
+ else
{
- case Anchor.Left:
- destX = 0;
- break;
- case Anchor.Right:
- destX = (int)(ex.Width - (sourceWidth * aspect));
- break;
- default:
- destX = (int)((ex.Width - (sourceWidth * aspect)) / 2);
- break;
+ aspect = heightAspect;
+ switch (info.Anchor)
+ {
+ case Anchor.Left:
+ destX = 0;
+ break;
+ case Anchor.Right:
+ destX = (int)(info.Width - (sourceWidth * aspect));
+ break;
+ default:
+ destX = (int)((info.Width - (sourceWidth * aspect)) / 2);
+ break;
+ }
}
+
+ destWidth = (int)(sourceWidth * aspect);
+ destHeight = (int)(sourceHeight * aspect);
+ }
+ dst = new Bitmap(info.Width, info.Height);
+ var srcRectangle = new Rectangle(sourceX, sourceY, sourceWidth, sourceHeight);
+ var destRectangle = new Rectangle(destX, destY, destWidth, destHeight);
+ using (var g = Graphics.FromImage(dst))
+ {
+ g.DrawImage(src, destRectangle, srcRectangle, GraphicsUnit.Pixel);
}
+ Contracts.Assert(dst.Width == info.Width && dst.Height == info.Height);
+ };
- destWidth = (int)(sourceWidth * aspect);
- destHeight = (int)(sourceHeight * aspect);
- }
- dst = new Bitmap(ex.Width, ex.Height);
- var srcRectangle = new Rectangle(sourceX, sourceY, sourceWidth, sourceHeight);
- var destRectangle = new Rectangle(destX, destY, destWidth, destHeight);
- using (var g = Graphics.FromImage(dst))
- {
- g.DrawImage(src, destRectangle, srcRectangle, GraphicsUnit.Pixel);
- }
- Host.Assert(dst.Width == ex.Width && dst.Height == ex.Height);
- };
+ return del;
+ }
+ }
+ }
+
+ public sealed class ImageResizerEstimator : TrivialEstimator
+ {
+ public ImageResizerEstimator(IHostEnvironment env, string inputColumn, string outputColumn,
+ int imageWidth, int imageHeight, ImageResizerTransform.ResizingKind resizing = ImageResizerTransform.ResizingKind.IsoCrop, ImageResizerTransform.Anchor cropAnchor = ImageResizerTransform.Anchor.Center)
+ : this(env, new ImageResizerTransform(env, inputColumn, outputColumn, imageWidth, imageHeight, resizing, cropAnchor))
+ {
+ }
+
+ public ImageResizerEstimator(IHostEnvironment env, params ImageResizerTransform.ColumnInfo[] columns)
+ : this(env, new ImageResizerTransform(env, columns))
+ {
+ }
+
+ public ImageResizerEstimator(IHostEnvironment env, ImageResizerTransform transformer)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageResizerEstimator)), transformer)
+ {
+ }
+
+ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ Host.CheckValue(inputSchema, nameof(inputSchema));
+ var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ foreach (var colInfo in Transformer.Columns)
+ {
+ var col = inputSchema.FindColumn(colInfo.Input);
+
+ if (col == null)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
+ if (!(col.ItemType is ImageType) || col.Kind != SchemaShape.Column.VectorKind.Scalar)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, new ImageType().ToString(), col.GetTypeString());
+
+ result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Scalar, colInfo.Type, false);
+ }
- return del;
+ return new SchemaShape(result.Values);
}
}
}
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageType.cs b/src/Microsoft.ML.ImageAnalytics/ImageType.cs
index 852ea09d9d..fd31302808 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImageType.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImageType.cs
@@ -35,7 +35,7 @@ public override bool Equals(ColumnType other)
return false;
if (Height != tmp.Height)
return false;
- return Width != tmp.Width;
+ return Width == tmp.Width;
}
public override bool Equals(object other)
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
index 2da2075728..52d2d3aef0 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
@@ -1407,8 +1407,8 @@ public LinearClassificationTrainer(IHostEnvironment env, Arguments args,
_positiveInstanceWeight = _args.PositiveInstanceWeight;
OutputColumns = new[]
{
- new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false),
- new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, DataKind.BL, false)
+ new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
+ new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)
};
}
@@ -1426,7 +1426,8 @@ protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar)
error();
- if (!labelCol.IsKey && labelCol.ItemKind != DataKind.R4 && labelCol.ItemKind != DataKind.R8 && labelCol.ItemKind != DataKind.BL)
+
+ if (!labelCol.IsKey && labelCol.ItemType != NumberType.R4 && labelCol.ItemType != NumberType.R8 && !labelCol.ItemType.IsBool)
error();
}
@@ -1434,17 +1435,17 @@ private static SchemaShape.Column MakeWeightColumn(string weightColumn)
{
if (weightColumn == null)
return null;
- return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false);
+ return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}
private static SchemaShape.Column MakeLabelColumn(string labelColumn)
{
- return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.BL, false);
+ return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
}
private static SchemaShape.Column MakeFeatureColumn(string featureColumn)
{
- return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, DataKind.R4, false);
+ return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
}
protected override TScalarPredictor CreatePredictor(VBuffer[] weights, Float[] bias)
diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs
index f775be92bd..c49593332b 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs
@@ -57,8 +57,8 @@ public SdcaMultiClassTrainer(IHostEnvironment env, Arguments args,
_args = args;
OutputColumns = new[]
{
- new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, DataKind.R4, false),
- new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, DataKind.U4, true)
+ new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false),
+ new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true)
};
}
@@ -76,7 +76,7 @@ protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar)
error();
- if (!labelCol.IsKey && labelCol.ItemKind != DataKind.R4 && labelCol.ItemKind != DataKind.R8)
+ if (!labelCol.IsKey && labelCol.ItemType != NumberType.R4 && labelCol.ItemType != NumberType.R8)
error();
}
@@ -84,17 +84,17 @@ private static SchemaShape.Column MakeWeightColumn(string weightColumn)
{
if (weightColumn == null)
return null;
- return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false);
+ return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}
private static SchemaShape.Column MakeLabelColumn(string labelColumn)
{
- return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.U4, true);
+ return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
}
private static SchemaShape.Column MakeFeatureColumn(string featureColumn)
{
- return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, DataKind.R4, false);
+ return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
}
///
diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs
index 163da10e7b..0b620959c3 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs
@@ -61,7 +61,7 @@ public SdcaRegressionTrainer(IHostEnvironment env, Arguments args, string featur
_args = args;
OutputColumns = new[]
{
- new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false)
+ new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false)
};
}
@@ -73,17 +73,17 @@ private static SchemaShape.Column MakeWeightColumn(string weightColumn)
{
if (weightColumn == null)
return null;
- return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false);
+ return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}
private static SchemaShape.Column MakeLabelColumn(string labelColumn)
{
- return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, DataKind.R4, false);
+ return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}
private static SchemaShape.Column MakeFeatureColumn(string featureColumn)
{
- return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, DataKind.R4, false);
+ return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
}
protected override LinearRegressionPredictor CreatePredictor(VBuffer[] weights, Float[] bias)
diff --git a/test/Microsoft.ML.Benchmarks/Harness/Metrics.cs b/test/Microsoft.ML.Benchmarks/Harness/Metrics.cs
new file mode 100644
index 0000000000..11b670cdcd
--- /dev/null
+++ b/test/Microsoft.ML.Benchmarks/Harness/Metrics.cs
@@ -0,0 +1,110 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using BenchmarkDotNet.Attributes;
+using BenchmarkDotNet.Columns;
+using BenchmarkDotNet.Reports;
+using BenchmarkDotNet.Running;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Microsoft.ML.Benchmarks
+{
+ public abstract class WithExtraMetrics
+ {
+ protected abstract IEnumerable GetMetrics();
+
+ ///
+ /// this method is executed after running the benchmrks
+ /// we use it as hack to simply print to console so ExtraMetricColumn can parse the output
+ ///
+ [GlobalCleanup]
+ public void ReportMetrics()
+ {
+ foreach (var metric in GetMetrics())
+ {
+ Console.WriteLine(metric.ToParsableString());
+ }
+ }
+ }
+
+ public class ExtraMetricColumn : IColumn
+ {
+ public string ColumnName => "Extra Metric";
+ public string Id => nameof(ExtraMetricColumn);
+ public string Legend => "Value of the provided extra metric";
+ public bool IsNumeric => true;
+ public bool IsDefault(Summary summary, BenchmarkCase benchmark) => true;
+ public bool IsAvailable(Summary summary) => true;
+ public bool AlwaysShow => true;
+ public ColumnCategory Category => ColumnCategory.Custom;
+ public int PriorityInCategory => 1;
+ public UnitType UnitType => UnitType.Dimensionless;
+ public string GetValue(Summary summary, BenchmarkCase benchmark) => GetValue(summary, benchmark, null);
+ public override string ToString() => ColumnName;
+
+ public string GetValue(Summary summary, BenchmarkCase benchmark, ISummaryStyle style)
+ {
+ if (!summary.HasReport(benchmark))
+ return "-";
+
+ var results = summary[benchmark].ExecuteResults;
+ if (results.Count != 1)
+ return "-";
+
+ var result = results.Single();
+ var buffer = new StringBuilder();
+
+ foreach (var line in result.ExtraOutput)
+ {
+ if (Metric.TryParse(line, out Metric metric))
+ {
+ if (buffer.Length > 0)
+ buffer.Append(", ");
+
+ buffer.Append(metric.ToColumnValue());
+ }
+ }
+
+ return buffer.Length > 0 ? buffer.ToString() : "-";
+ }
+ }
+
+ public struct Metric
+ {
+ private const string Prefix = "// Metric";
+ private const char Separator = '#';
+
+ public string Name { get; }
+ public string Value { get; }
+
+ public Metric(string name, string value) : this()
+ {
+ Name = name;
+ Value = value;
+ }
+
+ public string ToColumnValue()
+ => $"{Name}: {Value}";
+
+ public string ToParsableString()
+ => $"{Prefix} {Separator} {Name} {Separator} {Value}";
+
+ public static bool TryParse(string line, out Metric metric)
+ {
+ metric = default;
+
+ if (!line.StartsWith(Prefix))
+ return false;
+
+ var splitted = line.Split(Separator);
+
+ metric = new Metric(splitted[1].Trim(), splitted[2].Trim());
+
+ return true;
+ }
+ }
+}
diff --git a/test/Microsoft.ML.Benchmarks/Harness/ProjectGenerator.cs b/test/Microsoft.ML.Benchmarks/Harness/ProjectGenerator.cs
new file mode 100644
index 0000000000..7560efe562
--- /dev/null
+++ b/test/Microsoft.ML.Benchmarks/Harness/ProjectGenerator.cs
@@ -0,0 +1,54 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using BenchmarkDotNet.Extensions;
+using BenchmarkDotNet.Toolchains;
+using BenchmarkDotNet.Toolchains.CsProj;
+using System;
+using System.IO;
+using System.Linq;
+
+namespace Microsoft.ML.Benchmarks.Harness
+{
+ ///
+ /// to avoid side effects of benchmarks affect each other BenchmarkDotNet runs every benchmark in a standalone, dedicated process
+ /// however to do that it needs to be able to create, build and run new executable
+ ///
+ /// the problem with ML.NET is that it has native dependencies, which are NOT copied by MSBuild to the output folder
+ /// in case where A has native dependency and B references A
+ ///
+ /// this is why this class exists: to copy the native dependencies to folder with .exe
+ ///
+ public class ProjectGenerator : CsProjGenerator
+ {
+ public ProjectGenerator(string targetFrameworkMoniker) : base(targetFrameworkMoniker, platform => platform.ToConfig(), null)
+ {
+ }
+
+ protected override void CopyAllRequiredFiles(ArtifactsPaths artifactsPaths)
+ {
+ base.CopyAllRequiredFiles(artifactsPaths);
+
+ CopyMissingNativeDependencies(artifactsPaths);
+ }
+
+ private void CopyMissingNativeDependencies(ArtifactsPaths artifactsPaths)
+ {
+ var foldeWithAutogeneratedExe = Path.GetDirectoryName(artifactsPaths.ExecutablePath);
+ var folderWithNativeDependencies = Path.GetDirectoryName(typeof(ProjectGenerator).Assembly.Location);
+
+ foreach (var nativeDependency in Directory
+ .EnumerateFiles(folderWithNativeDependencies)
+ .Where(fileName => ContainsWithIgnoreCase(fileName, "native")))
+ {
+ File.Copy(
+ sourceFileName: nativeDependency,
+ destFileName: Path.Combine(foldeWithAutogeneratedExe, Path.GetFileName(nativeDependency)),
+ overwrite: true);
+ }
+ }
+
+ bool ContainsWithIgnoreCase(string text, string word) => text != null && text.IndexOf(word, StringComparison.InvariantCultureIgnoreCase) >= 0;
+ }
+}
diff --git a/test/Microsoft.ML.Benchmarks/KMeansAndLogisticRegressionBench.cs b/test/Microsoft.ML.Benchmarks/KMeansAndLogisticRegressionBench.cs
index 4c269e05fc..f96a5d8803 100644
--- a/test/Microsoft.ML.Benchmarks/KMeansAndLogisticRegressionBench.cs
+++ b/test/Microsoft.ML.Benchmarks/KMeansAndLogisticRegressionBench.cs
@@ -4,6 +4,7 @@
using BenchmarkDotNet.Attributes;
using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Internal.Calibration;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.KMeans;
@@ -13,21 +14,11 @@ namespace Microsoft.ML.Benchmarks
{
public class KMeansAndLogisticRegressionBench
{
- private static string s_dataPath;
+ private readonly string _dataPath = Program.GetInvariantCultureDataPath("adult.train");
[Benchmark]
- public IPredictor TrainKMeansAndLR() => TrainKMeansAndLRCore();
-
- [GlobalSetup]
- public void Setup()
+ public ParameterMixingCalibratedPredictor TrainKMeansAndLR()
{
- s_dataPath = Program.GetDataPath("adult.train");
- }
-
- private static IPredictor TrainKMeansAndLRCore()
- {
- string dataPath = s_dataPath;
-
using (var env = new TlcEnvironment(seed: 1))
{
// Pipeline
@@ -53,7 +44,7 @@ private static IPredictor TrainKMeansAndLRCore()
new TextLoader.Range() { Min = 10, Max = 12 }
})
}
- }, new MultiFileSource(dataPath));
+ }, new MultiFileSource(_dataPath));
IDataTransform trans = CategoricalTransform.Create(env, new CategoricalTransform.Arguments
{
@@ -83,4 +74,4 @@ private static IPredictor TrainKMeansAndLRCore()
}
}
}
-}
+}
\ No newline at end of file
diff --git a/test/Microsoft.ML.Benchmarks/Microsoft.ML.Benchmarks.csproj b/test/Microsoft.ML.Benchmarks/Microsoft.ML.Benchmarks.csproj
index 5a9f3e7467..dfa673ea82 100644
--- a/test/Microsoft.ML.Benchmarks/Microsoft.ML.Benchmarks.csproj
+++ b/test/Microsoft.ML.Benchmarks/Microsoft.ML.Benchmarks.csproj
@@ -20,4 +20,16 @@
+
+
+
+ PreserveNewest
+
+
+ PreserveNewest
+
+
+ PreserveNewest
+
+
\ No newline at end of file
diff --git a/test/Microsoft.ML.Benchmarks/Program.cs b/test/Microsoft.ML.Benchmarks/Program.cs
index 0b4e9edc52..5396e17b7b 100644
--- a/test/Microsoft.ML.Benchmarks/Program.cs
+++ b/test/Microsoft.ML.Benchmarks/Program.cs
@@ -6,11 +6,13 @@
using BenchmarkDotNet.Diagnosers;
using BenchmarkDotNet.Jobs;
using BenchmarkDotNet.Running;
-using BenchmarkDotNet.Columns;
-using BenchmarkDotNet.Reports;
-using BenchmarkDotNet.Toolchains.InProcess;
+using BenchmarkDotNet.Toolchains;
+using BenchmarkDotNet.Toolchains.CsProj;
+using BenchmarkDotNet.Toolchains.DotNetCli;
+using Microsoft.ML.Benchmarks.Harness;
+using System.Globalization;
using System.IO;
-using Microsoft.ML.Models;
+using System.Threading;
namespace Microsoft.ML.Benchmarks
{
@@ -28,52 +30,33 @@ static void Main(string[] args)
private static IConfig CreateCustomConfig()
=> DefaultConfig.Instance
.With(Job.Default
+ .WithWarmupCount(1) // for our time consuming benchmarks 1 warmup iteration is enough
.WithMaxIterationCount(20)
- .With(InProcessToolchain.Instance))
- .With(new ClassificationMetricsColumn("AccuracyMacro", "Macro-average accuracy of the model"))
+ .With(CreateToolchain()))
+ .With(new ExtraMetricColumn())
.With(MemoryDiagnoser.Default);
- internal static string GetDataPath(string name)
- => Path.GetFullPath(Path.Combine(_dataRoot, name));
-
- static readonly string _dataRoot;
- static Program()
+ ///
+ /// we need our own toolchain because MSBuild by default does not copy recursive native dependencies to the output
+ ///
+ private static IToolchain CreateToolchain()
{
- var currentAssemblyLocation = new FileInfo(typeof(Program).Assembly.Location);
- var rootDir = currentAssemblyLocation.Directory.Parent.Parent.Parent.Parent.FullName;
- _dataRoot = Path.Combine(rootDir, "test", "data");
+ var csProj = CsProjCoreToolchain.Current.Value;
+ var tfm = NetCoreAppSettings.Current.Value.TargetFrameworkMoniker;
+
+ return new Toolchain(
+ tfm,
+ new ProjectGenerator(tfm),
+ csProj.Builder,
+ csProj.Executor);
}
- }
-
- public class ClassificationMetricsColumn : IColumn
- {
- private readonly string _metricName;
- private readonly string _legend;
- public ClassificationMetricsColumn(string metricName, string legend)
+ internal static string GetInvariantCultureDataPath(string name)
{
- _metricName = metricName;
- _legend = legend;
- }
-
- public string ColumnName => _metricName;
- public string Id => _metricName;
- public string Legend => _legend;
- public bool IsNumeric => true;
- public bool IsDefault(Summary summary, BenchmarkCase benchmark) => true;
- public bool IsAvailable(Summary summary) => true;
- public bool AlwaysShow => true;
- public ColumnCategory Category => ColumnCategory.Custom;
- public int PriorityInCategory => 1;
- public UnitType UnitType => UnitType.Dimensionless;
+ // enforce Neutral Language as "en-us" because the input data files use dot as decimal separator (and it fails for cultures with ",")
+ Thread.CurrentThread.CurrentCulture = CultureInfo.InvariantCulture;
- public string GetValue(Summary summary, BenchmarkCase benchmark, ISummaryStyle style)
- {
- var property = typeof(ClassificationMetrics).GetProperty(_metricName);
- return property.GetValue(StochasticDualCoordinateAscentClassifierBench.s_metrics).ToString();
+ return Path.Combine(Path.GetDirectoryName(typeof(Program).Assembly.Location), "Input", name);
}
- public string GetValue(Summary summary, BenchmarkCase benchmark) => GetValue(summary, benchmark, null);
-
- public override string ToString() => ColumnName;
}
}
diff --git a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs
index 6e0b856dbd..b0c9235198 100644
--- a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs
+++ b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs
@@ -11,22 +11,19 @@
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
-using System;
using System.Collections.Generic;
+using System.Globalization;
namespace Microsoft.ML.Benchmarks
{
- public class StochasticDualCoordinateAscentClassifierBench
+ public class StochasticDualCoordinateAscentClassifierBench : WithExtraMetrics
{
- internal static ClassificationMetrics s_metrics;
- private static PredictionModel s_trainedModel;
- private static string s_dataPath;
- private static string s_sentimentDataPath;
- private static IrisData[][] s_batches;
- private static readonly int[] s_batchSizes = new int[] { 1, 2, 5 };
- private readonly Random r = new Random(0);
- private readonly Consumer _consumer = new Consumer();
- private static readonly IrisData s_example = new IrisData()
+ private readonly string _dataPath = Program.GetInvariantCultureDataPath("iris.txt");
+ private readonly string _sentimentDataPath = Program.GetInvariantCultureDataPath("wikipedia-detox-250-line-data.tsv");
+ private readonly Consumer _consumer = new Consumer(); // BenchmarkDotNet utility type used to prevent dead code elimination
+
+ private readonly int[] _batchSizes = new int[] { 1, 2, 5 };
+ private readonly IrisData _example = new IrisData()
{
SepalLength = 3.3f,
SepalWidth = 1.6f,
@@ -34,71 +31,36 @@ public class StochasticDualCoordinateAscentClassifierBench
PetalWidth = 5.1f,
};
- [GlobalSetup]
- public void Setup()
- {
- s_dataPath = Program.GetDataPath("iris.txt");
- s_sentimentDataPath = Program.GetDataPath("wikipedia-detox-250-line-data.tsv");
- s_trainedModel = TrainCore();
- IrisPrediction prediction = s_trainedModel.Predict(s_example);
-
- var testData = new Data.TextLoader(s_dataPath).CreateFrom(useHeader: true);
- var evaluator = new ClassificationEvaluator();
- s_metrics = evaluator.Evaluate(s_trainedModel, testData);
+ private PredictionModel _trainedModel;
+ private IrisData[][] _batches;
+ private ClassificationMetrics _metrics;
- s_batches = new IrisData[s_batchSizes.Length][];
- for (int i = 0; i < s_batches.Length; i++)
- {
- var batch = new IrisData[s_batchSizes[i]];
- s_batches[i] = batch;
- for (int bi = 0; bi < batch.Length; bi++)
- {
- batch[bi] = s_example;
- }
- }
+ protected override IEnumerable GetMetrics()
+ {
+ if (_metrics != null)
+ yield return new Metric(
+ nameof(ClassificationMetrics.AccuracyMacro),
+ _metrics.AccuracyMacro.ToString("0.##", CultureInfo.InvariantCulture));
}
[Benchmark]
- public PredictionModel TrainIris() => TrainCore();
-
- [Benchmark]
- public float[] PredictIris() => s_trainedModel.Predict(s_example).PredictedLabels;
-
- [Benchmark]
- public void PredictIrisBatchOf1() => Consume(s_trainedModel.Predict(s_batches[0]));
+ public PredictionModel TrainIris() => Train(_dataPath);
- [Benchmark]
- public void PredictIrisBatchOf2() => Consume(s_trainedModel.Predict(s_batches[1]));
-
- [Benchmark]
- public void PredictIrisBatchOf5() => Consume(s_trainedModel.Predict(s_batches[2]));
-
- [Benchmark]
- public IPredictor TrainSentiment() => TrainSentimentCore();
-
- private void Consume(IEnumerable predictions)
- {
- foreach (var prediction in predictions)
- _consumer.Consume(prediction);
- }
-
- private static PredictionModel TrainCore()
+ private PredictionModel Train(string dataPath)
{
var pipeline = new LearningPipeline();
- pipeline.Add(new Data.TextLoader(s_dataPath).CreateFrom(useHeader: true));
- pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
- "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));
+ pipeline.Add(new Data.TextLoader(dataPath).CreateFrom(useHeader: true));
+ pipeline.Add(new ColumnConcatenator(outputColumn: "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));
pipeline.Add(new StochasticDualCoordinateAscentClassifier());
- PredictionModel model = pipeline.Train();
- return model;
+ return pipeline.Train();
}
- private static IPredictor TrainSentimentCore()
+ [Benchmark]
+ public void TrainSentiment()
{
- var dataPath = s_sentimentDataPath;
using (var env = new TlcEnvironment(seed: 1))
{
// Pipeline
@@ -125,7 +87,7 @@ private static IPredictor TrainSentimentCore()
Type = DataKind.Text
}
}
- }, new MultiFileSource(dataPath));
+ }, new MultiFileSource(_sentimentDataPath));
var text = TextTransform.Create(env,
new TextTransform.Arguments()
@@ -145,7 +107,7 @@ private static IPredictor TrainSentimentCore()
WordFeatureExtractor = null,
}, loader);
- var trans = new WordEmbeddingsTransform(env,
+ var trans = new WordEmbeddingsTransform(env,
new WordEmbeddingsTransform.Arguments()
{
Column = new WordEmbeddingsTransform.Column[1]
@@ -162,32 +124,74 @@ private static IPredictor TrainSentimentCore()
// Train
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments() { MaxIterations = 20 });
var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
- return trainer.Train(trainRoles);
+
+ var predicted = trainer.Train(trainRoles);
+ _consumer.Consume(predicted);
}
}
- public class IrisData
+ [GlobalSetup(Targets = new string[] { nameof(PredictIris), nameof(PredictIrisBatchOf1), nameof(PredictIrisBatchOf2), nameof(PredictIrisBatchOf5) })]
+ public void SetupPredictBenchmarks()
{
- [Column("0")]
- public float Label;
+ _trainedModel = Train(_dataPath);
+ _consumer.Consume(_trainedModel.Predict(_example));
- [Column("1")]
- public float SepalLength;
+ var testData = new Data.TextLoader(_dataPath).CreateFrom(useHeader: true);
+ var evaluator = new ClassificationEvaluator();
+ _metrics = evaluator.Evaluate(_trainedModel, testData);
+
+ _batches = new IrisData[_batchSizes.Length][];
+ for (int i = 0; i < _batches.Length; i++)
+ {
+ var batch = new IrisData[_batchSizes[i]];
+ _batches[i] = batch;
+ for (int bi = 0; bi < batch.Length; bi++)
+ {
+ batch[bi] = _example;
+ }
+ }
+ }
- [Column("2")]
- public float SepalWidth;
+ [Benchmark]
+ public float[] PredictIris() => _trainedModel.Predict(_example).PredictedLabels;
- [Column("3")]
- public float PetalLength;
+ [Benchmark]
+ public void PredictIrisBatchOf1() => Consume(_trainedModel.Predict(_batches[0]));
- [Column("4")]
- public float PetalWidth;
- }
+ [Benchmark]
+ public void PredictIrisBatchOf2() => Consume(_trainedModel.Predict(_batches[1]));
+
+ [Benchmark]
+ public void PredictIrisBatchOf5() => Consume(_trainedModel.Predict(_batches[2]));
- public class IrisPrediction
+ private void Consume(IEnumerable predictions)
{
- [ColumnName("Score")]
- public float[] PredictedLabels;
+ foreach (var prediction in predictions)
+ _consumer.Consume(prediction);
}
}
+
+ public class IrisData
+ {
+ [Column("0")]
+ public float Label;
+
+ [Column("1")]
+ public float SepalLength;
+
+ [Column("2")]
+ public float SepalWidth;
+
+ [Column("3")]
+ public float PetalLength;
+
+ [Column("4")]
+ public float PetalWidth;
+ }
+
+ public class IrisPrediction
+ {
+ [ColumnName("Score")]
+ public float[] PredictedLabels;
+ }
}
diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs
index 2e4b598540..b4b1cf65fd 100644
--- a/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs
+++ b/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs
@@ -12,137 +12,92 @@ namespace Microsoft.ML.CpuMath.PerformanceTests
public class AvxPerformanceTests : PerformanceTests
{
[Benchmark]
- public void ManagedAddScalarUPerf()
- {
- AvxIntrinsics.AddScalarU(DEFAULT_SCALE, new Span(dst, 0, LEN));
- }
+ public void AddScalarU()
+ => AvxIntrinsics.AddScalarU(DEFAULT_SCALE, new Span(dst, 0, LEN));
[Benchmark]
- public void ManagedScaleUPerf()
- {
- AvxIntrinsics.ScaleU(DEFAULT_SCALE, new Span(dst, 0, LEN));
- }
+ public void ScaleU()
+ => AvxIntrinsics.ScaleU(DEFAULT_SCALE, new Span(dst, 0, LEN));
[Benchmark]
- public void ManagedScaleSrcUPerf()
- {
- AvxIntrinsics.ScaleSrcU(DEFAULT_SCALE, new Span(src, 0, LEN), new Span(dst, 0, LEN));
- }
+ public void ScaleSrcU()
+ => AvxIntrinsics.ScaleSrcU(DEFAULT_SCALE, new Span(src, 0, LEN), new Span(dst, 0, LEN));
[Benchmark]
- public void ManagedScaleAddUPerf()
- {
- AvxIntrinsics.ScaleAddU(DEFAULT_SCALE, DEFAULT_SCALE, new Span(dst, 0, LEN));
- }
+ public void ScaleAddU()
+ => AvxIntrinsics.ScaleAddU(DEFAULT_SCALE, DEFAULT_SCALE, new Span(dst, 0, LEN));
[Benchmark]
- public void ManagedAddScaleUPerf()
- {
- AvxIntrinsics.AddScaleU(DEFAULT_SCALE, new Span(src, 0, LEN), new Span(dst, 0, LEN));
- }
+ public void AddScaleU()
+ => AvxIntrinsics.AddScaleU(DEFAULT_SCALE, new Span(src, 0, LEN), new Span(dst, 0, LEN));
[Benchmark]
- public void ManagedAddScaleSUPerf()
- {
- AvxIntrinsics.AddScaleSU(DEFAULT_SCALE, new Span(src), new Span(idx, 0, IDXLEN), new Span(dst));
- }
+ public void AddScaleSU()
+ => AvxIntrinsics.AddScaleSU(DEFAULT_SCALE, new Span(src), new Span(idx, 0, IDXLEN), new Span(dst));
[Benchmark]
- public void ManagedAddScaleCopyUPerf()
- {
- AvxIntrinsics.AddScaleCopyU(DEFAULT_SCALE, new Span(src, 0, LEN), new Span(dst, 0, LEN), new Span(result, 0, LEN));
- }
+ public void AddScaleCopyU()
+ => AvxIntrinsics.AddScaleCopyU(DEFAULT_SCALE, new Span(src, 0, LEN), new Span(dst, 0, LEN), new Span(result, 0, LEN));
[Benchmark]
- public void ManagedAddUPerf()
- {
- AvxIntrinsics.AddU(new Span(src, 0, LEN), new Span(dst, 0, LEN));
- }
+ public void AddU()
+ => AvxIntrinsics.AddU(new Span(src, 0, LEN), new Span(dst, 0, LEN));
[Benchmark]
- public void ManagedAddSUPerf()
- {
- AvxIntrinsics.AddSU(new Span(src), new Span(idx, 0, IDXLEN), new Span(dst));
- }
-
+ public void AddSU()
+ => AvxIntrinsics.AddSU(new Span(src), new Span(idx, 0, IDXLEN), new Span(dst));
[Benchmark]
- public void ManagedMulElementWiseUPerf()
- {
- AvxIntrinsics.MulElementWiseU(new Span(src1, 0, LEN), new Span(src2, 0, LEN),
+ public void MulElementWiseU()
+ => AvxIntrinsics.MulElementWiseU(new Span(src1, 0, LEN), new Span(src2, 0, LEN),
new Span(dst, 0, LEN));
- }
[Benchmark]
- public float ManagedSumUPerf()
- {
- return AvxIntrinsics.SumU(new Span(src, 0, LEN));
- }
+ public float SumU()
+ => AvxIntrinsics.SumU(new Span(src, 0, LEN));
[Benchmark]
- public float ManagedSumSqUPerf()
- {
- return AvxIntrinsics.SumSqU(new Span(src, 0, LEN));
- }
+ public float SumSqU()
+ => AvxIntrinsics.SumSqU(new Span(src, 0, LEN));
[Benchmark]
- public float ManagedSumSqDiffUPerf()
- {
- return AvxIntrinsics.SumSqDiffU(DEFAULT_SCALE, new Span(src, 0, LEN));
- }
+ public float SumSqDiffU()
+ => AvxIntrinsics.SumSqDiffU(DEFAULT_SCALE, new Span(src, 0, LEN));
[Benchmark]
- public float ManagedSumAbsUPerf()
- {
- return AvxIntrinsics.SumAbsU(new Span(src, 0, LEN));
- }
+ public float SumAbsU()
+ => AvxIntrinsics.SumAbsU(new Span(src, 0, LEN));
[Benchmark]
- public float ManagedSumAbsDiffUPerf()
- {
- return AvxIntrinsics.SumAbsDiffU(DEFAULT_SCALE, new Span(src, 0, LEN));
- }
+ public float SumAbsDiffU()
+ => AvxIntrinsics.SumAbsDiffU(DEFAULT_SCALE, new Span(src, 0, LEN));
[Benchmark]
- public float ManagedMaxAbsUPerf()
- {
- return AvxIntrinsics.MaxAbsU(new Span(src, 0, LEN));
- }
+ public float MaxAbsU()
+ => AvxIntrinsics.MaxAbsU(new Span(src, 0, LEN));
[Benchmark]
- public float ManagedMaxAbsDiffUPerf()
- {
- return AvxIntrinsics.MaxAbsDiffU(DEFAULT_SCALE, new Span(src, 0, LEN));
- }
+ public float MaxAbsDiffU()
+ => AvxIntrinsics.MaxAbsDiffU(DEFAULT_SCALE, new Span(src, 0, LEN));
[Benchmark]
- public float ManagedDotUPerf()
- {
- return AvxIntrinsics.DotU(new Span(src, 0, LEN), new Span(dst, 0, LEN));
- }
+ public float DotU()
+ => AvxIntrinsics.DotU(new Span(src, 0, LEN), new Span(dst, 0, LEN));
[Benchmark]
- public float ManagedDotSUPerf()
- {
- return AvxIntrinsics.DotSU(new Span(src), new Span(dst), new Span(idx, 0, IDXLEN));
- }
+ public float DotSU()
+ => AvxIntrinsics.DotSU(new Span(src), new Span(dst), new Span(idx, 0, IDXLEN));
[Benchmark]
- public float ManagedDist2Perf()
- {
- return AvxIntrinsics.Dist2(new Span(src, 0, LEN), new Span(dst, 0, LEN));
- }
+ public float Dist2()
+ => AvxIntrinsics.Dist2(new Span(src, 0, LEN), new Span(dst, 0, LEN));
[Benchmark]
- public void ManagedSdcaL1UpdateUPerf()
- {
- AvxIntrinsics.SdcaL1UpdateU(DEFAULT_SCALE, new Span(src, 0, LEN), DEFAULT_SCALE, new Span(dst, 0, LEN), new Span(result, 0, LEN));
- }
+ public void SdcaL1UpdateU()
+ => AvxIntrinsics.SdcaL1UpdateU(DEFAULT_SCALE, new Span(src, 0, LEN), DEFAULT_SCALE, new Span(dst, 0, LEN), new Span(result, 0, LEN));
[Benchmark]
- public void ManagedSdcaL1UpdateSUPerf()
- {
- AvxIntrinsics.SdcaL1UpdateSU(DEFAULT_SCALE, new Span(src, 0, IDXLEN), new Span(idx, 0, IDXLEN), DEFAULT_SCALE, new Span(dst), new Span(result));
- }
+ public void SdcaL1UpdateSU()
+ => AvxIntrinsics.SdcaL1UpdateSU(DEFAULT_SCALE, new Span(src, 0, IDXLEN), new Span(idx, 0, IDXLEN), DEFAULT_SCALE, new Span(dst), new Span(result));
}
}
diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs
new file mode 100644
index 0000000000..b7eb3d233a
--- /dev/null
+++ b/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs
@@ -0,0 +1,232 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using BenchmarkDotNet.Attributes;
+using BenchmarkDotNet.Running;
+using Microsoft.ML.Runtime.Internal.CpuMath;
+
+namespace Microsoft.ML.CpuMath.PerformanceTests
+{
+ public class NativePerformanceTests : PerformanceTests
+ {
+ [Benchmark]
+ public unsafe void AddScalarU()
+ {
+ fixed (float* pdst = dst)
+ {
+ CpuMathNativeUtils.AddScalarU(DEFAULT_SCALE, pdst, LEN);
+ }
+ }
+
+ [Benchmark]
+ public unsafe void ScaleU()
+ {
+ fixed (float* pdst = dst)
+ {
+ CpuMathNativeUtils.ScaleU(DEFAULT_SCALE, pdst, LEN);
+ }
+ }
+
+ [Benchmark]
+ public unsafe void ScaleSrcU()
+ {
+ fixed (float* psrc = src)
+ fixed (float* pdst = dst)
+ {
+ CpuMathNativeUtils.ScaleSrcU(DEFAULT_SCALE, psrc, pdst, LEN);
+ }
+ }
+
+ [Benchmark]
+ public unsafe void ScaleAddU()
+ {
+ fixed (float* pdst = dst)
+ {
+ CpuMathNativeUtils.ScaleAddU(DEFAULT_SCALE, DEFAULT_SCALE, pdst, LEN);
+ }
+ }
+
+ [Benchmark]
+ public unsafe void AddScaleU()
+ {
+ fixed (float* psrc = src)
+ fixed (float* pdst = dst)
+ {
+ CpuMathNativeUtils.AddScaleU(DEFAULT_SCALE, psrc, pdst, LEN);
+ }
+ }
+
+ [Benchmark]
+ public unsafe void AddScaleSU()
+ {
+ fixed (float* psrc = src)
+ fixed (float* pdst = dst)
+ fixed (int* pidx = idx)
+ {
+ CpuMathNativeUtils.AddScaleSU(DEFAULT_SCALE, psrc, pidx, pdst, IDXLEN);
+ }
+ }
+
+ [Benchmark]
+ public unsafe void AddScaleCopyU()
+ {
+ fixed (float* psrc = src)
+ fixed (float* pdst = dst)
+ fixed (float* pres = result)
+ {
+ CpuMathNativeUtils.AddScaleCopyU(DEFAULT_SCALE, psrc, pdst, pres, LEN);
+ }
+ }
+
+ [Benchmark]
+ public unsafe void AddU()
+ {
+ fixed (float* psrc = src)
+ fixed (float* pdst = dst)
+ {
+ CpuMathNativeUtils.AddU(psrc, pdst, LEN);
+ }
+ }
+
+ [Benchmark]
+ public unsafe void AddSU()
+ {
+ fixed (float* psrc = src)
+ fixed (float* pdst = dst)
+ fixed (int* pidx = idx)
+ {
+ CpuMathNativeUtils.AddSU(psrc, pidx, pdst, IDXLEN);
+ }
+ }
+
+ [Benchmark]
+ public unsafe void MulElementWiseU()
+ {
+ fixed (float* psrc1 = src1)
+ fixed (float* psrc2 = src2)
+ fixed (float* pdst = dst)
+ {
+ CpuMathNativeUtils.MulElementWiseU(psrc1, psrc2, pdst, LEN);
+ }
+ }
+
+ [Benchmark]
+ public unsafe float SumU()
+ {
+ fixed (float* psrc = src)
+ {
+ return CpuMathNativeUtils.SumU(psrc, LEN);
+ }
+ }
+
+ [Benchmark]
+ public unsafe float SumSqU()
+ {
+ fixed (float* psrc = src)
+ {
+ return CpuMathNativeUtils.SumSqU(psrc, LEN);
+ }
+ }
+
+ [Benchmark]
+ public unsafe float SumSqDiffU()
+ {
+ fixed (float* psrc = src)
+ {
+ return CpuMathNativeUtils.SumSqDiffU(DEFAULT_SCALE, psrc, LEN);
+ }
+ }
+
+ [Benchmark]
+ public unsafe float SumAbsU()
+ {
+ fixed (float* psrc = src)
+ {
+ return CpuMathNativeUtils.SumAbsU(psrc, LEN);
+ }
+ }
+
+ [Benchmark]
+ public unsafe float SumAbsDiffU()
+ {
+ fixed (float* psrc = src)
+ {
+ return CpuMathNativeUtils.SumAbsDiffU(DEFAULT_SCALE, psrc, LEN);
+ }
+ }
+
+ [Benchmark]
+ public unsafe float MaxAbsU()
+ {
+ fixed (float* psrc = src)
+ {
+ return CpuMathNativeUtils.MaxAbsU(psrc, LEN);
+ }
+ }
+
+ [Benchmark]
+ public unsafe float MaxAbsDiffU()
+ {
+ fixed (float* psrc = src)
+ {
+ return CpuMathNativeUtils.MaxAbsDiffU(DEFAULT_SCALE, psrc, LEN);
+ }
+ }
+
+ [Benchmark]
+ public unsafe float DotU()
+ {
+ fixed (float* psrc = src)
+ fixed (float* pdst = dst)
+ {
+ return CpuMathNativeUtils.DotU(psrc, pdst, LEN);
+ }
+ }
+
+ [Benchmark]
+ public unsafe float DotSU()
+ {
+ fixed (float* psrc = src)
+ fixed (float* pdst = dst)
+ fixed (int* pidx = idx)
+ {
+ return CpuMathNativeUtils.DotSU(psrc, pdst, pidx, IDXLEN);
+ }
+ }
+
+ [Benchmark]
+ public unsafe float Dist2()
+ {
+ fixed (float* psrc = src)
+ fixed (float* pdst = dst)
+ {
+ return CpuMathNativeUtils.Dist2(psrc, pdst, LEN);
+ }
+ }
+
+ [Benchmark]
+ public unsafe void SdcaL1UpdateU()
+ {
+ fixed (float* psrc = src)
+ fixed (float* pdst = dst)
+ fixed (float* pres = result)
+ {
+ CpuMathNativeUtils.SdcaL1UpdateU(DEFAULT_SCALE, psrc, DEFAULT_SCALE, pdst, pres, LEN);
+ }
+ }
+
+ [Benchmark]
+ public unsafe void SdcaL1UpdateSU()
+ {
+ fixed (float* psrc = src)
+ fixed (float* pdst = dst)
+ fixed (float* pres = result)
+ fixed (int* pidx = idx)
+ {
+ CpuMathNativeUtils.SdcaL1UpdateSU(DEFAULT_SCALE, psrc, pidx, DEFAULT_SCALE, pdst, pres, IDXLEN);
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/Program.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/Program.cs
index 77656178c8..9ff474a9a1 100644
--- a/test/Microsoft.ML.CpuMath.PerformanceTests/Program.cs
+++ b/test/Microsoft.ML.CpuMath.PerformanceTests/Program.cs
@@ -2,6 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using BenchmarkDotNet.Configs;
+using BenchmarkDotNet.Jobs;
+using BenchmarkDotNet.Running;
+using BenchmarkDotNet.Toolchains.InProcess;
+
namespace Microsoft.ML.CpuMath.PerformanceTests
{
class Program
@@ -13,7 +18,7 @@ public static void Main(string[] args)
private static IConfig CreateCustomConfig()
=> DefaultConfig.Instance
- .With(Job.ShortRun
+ .With(Job.Default
.With(InProcessToolchain.Instance));
}
}
diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs
index 3188c64db9..21eb840bdc 100644
--- a/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs
+++ b/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs
@@ -12,355 +12,92 @@ namespace Microsoft.ML.CpuMath.PerformanceTests
public class SsePerformanceTests : PerformanceTests
{
[Benchmark]
- public unsafe void NativeAddScalarUPerf()
- {
- fixed (float* pdst = dst)
- {
- CpuMathNativeUtils.AddScalarU(DEFAULT_SCALE, pdst, LEN);
- }
- }
-
- [Benchmark]
- public void ManagedAddScalarUPerf()
- {
- SseIntrinsics.AddScalarU(DEFAULT_SCALE, new Span(dst, 0, LEN));
- }
-
- [Benchmark]
- public unsafe void NativeScaleUPerf()
- {
- fixed (float* pdst = dst)
- {
- CpuMathNativeUtils.ScaleU(DEFAULT_SCALE, pdst, LEN);
- }
- }
-
- [Benchmark]
- public void ManagedScaleUPerf()
- {
- SseIntrinsics.ScaleU(DEFAULT_SCALE, new Span(dst, 0, LEN));
- }
-
- [Benchmark]
- public unsafe void NativeScaleSrcUPerf()
- {
- fixed (float* psrc = src)
- fixed (float* pdst = dst)
- {
- CpuMathNativeUtils.ScaleSrcU(DEFAULT_SCALE, psrc, pdst, LEN);
- }
- }
-
- [Benchmark]
- public void ManagedScaleSrcUPerf()
- {
- SseIntrinsics.ScaleSrcU(DEFAULT_SCALE, new Span(src, 0, LEN), new Span(dst, 0, LEN));
- }
-
- [Benchmark]
- public unsafe void NativeScaleAddUPerf()
- {
- fixed (float* pdst = dst)
- {
- CpuMathNativeUtils.ScaleAddU(DEFAULT_SCALE, DEFAULT_SCALE, pdst, LEN);
- }
- }
-
- [Benchmark]
- public void ManagedScaleAddUPerf()
- {
- SseIntrinsics.ScaleAddU(DEFAULT_SCALE, DEFAULT_SCALE, new Span(dst, 0, LEN));
- }
-
- [Benchmark]
- public unsafe void NativeAddScaleUPerf()
- {
- fixed (float* psrc = src)
- fixed (float* pdst = dst)
- {
- CpuMathNativeUtils.AddScaleU(DEFAULT_SCALE, psrc, pdst, LEN);
- }
- }
-
- [Benchmark]
- public void ManagedAddScaleUPerf()
- {
- SseIntrinsics.AddScaleU(DEFAULT_SCALE, new Span(src, 0, LEN), new Span(dst, 0, LEN));
- }
-
- [Benchmark]
- public unsafe void NativeAddScaleSUPerf()
- {
- fixed (float* psrc = src)
- fixed (float* pdst = dst)
- fixed (int* pidx = idx)
- {
- CpuMathNativeUtils.AddScaleSU(DEFAULT_SCALE, psrc, pidx, pdst, IDXLEN);
- }
- }
-
+ public void AddScalarU()
+ => SseIntrinsics.AddScalarU(DEFAULT_SCALE, new Span(dst, 0, LEN));
+
[Benchmark]
- public void ManagedAddScaleSUPerf()
- {
- SseIntrinsics.AddScaleSU(DEFAULT_SCALE, new Span(src), new Span(idx, 0, IDXLEN), new Span(dst));
- }
-
+ public void ScaleU()
+ => SseIntrinsics.ScaleU(DEFAULT_SCALE, new Span(dst, 0, LEN));
+
[Benchmark]
- public unsafe void NativeAddScaleCopyUPerf()
- {
- fixed (float* psrc = src)
- fixed (float* pdst = dst)
- fixed (float* pres = result)
- {
- CpuMathNativeUtils.AddScaleCopyU(DEFAULT_SCALE, psrc, pdst, pres, LEN);
- }
- }
+ public void ScaleSrcU()
+ => SseIntrinsics.ScaleSrcU(DEFAULT_SCALE, new Span(src, 0, LEN), new Span(dst, 0, LEN));
[Benchmark]
- public void ManagedAddScaleCopyUPerf()
- {
- SseIntrinsics.AddScaleCopyU(DEFAULT_SCALE, new Span(src, 0, LEN), new Span(dst, 0, LEN), new Span(result, 0, LEN));
- }
-
+ public void ScaleAddU()
+ => SseIntrinsics.ScaleAddU(DEFAULT_SCALE, DEFAULT_SCALE, new Span(dst, 0, LEN));
+
[Benchmark]
- public unsafe void NativeAddUPerf()
- {
- fixed (float* psrc = src)
- fixed (float* pdst = dst)
- {
- CpuMathNativeUtils.AddU(psrc, pdst, LEN);
- }
- }
+ public void AddScaleU()
+ => SseIntrinsics.AddScaleU(DEFAULT_SCALE, new Span(src, 0, LEN), new Span(dst, 0, LEN));
[Benchmark]
- public void ManagedAddUPerf()
- {
- SseIntrinsics.AddU(new Span(src, 0, LEN), new Span(dst, 0, LEN));
- }
+ public void AddScaleSU()
+ => SseIntrinsics.AddScaleSU(DEFAULT_SCALE, new Span(src), new Span(idx, 0, IDXLEN), new Span(dst));
[Benchmark]
- public unsafe void NativeAddSUPerf()
- {
- fixed (float* psrc = src)
- fixed (float* pdst = dst)
- fixed (int* pidx = idx)
- {
- CpuMathNativeUtils.AddSU(psrc, pidx, pdst, IDXLEN);
- }
- }
+ public void AddScaleCopyU()
+ => SseIntrinsics.AddScaleCopyU(DEFAULT_SCALE, new Span(src, 0, LEN), new Span(dst, 0, LEN), new Span(result, 0, LEN));
[Benchmark]
- public void ManagedAddSUPerf()
- {
- SseIntrinsics.AddSU(new Span(src), new Span(idx, 0, IDXLEN), new Span(dst));
- }
-
+ public void AddU()
+ => SseIntrinsics.AddU(new Span(src, 0, LEN), new Span(dst, 0, LEN));
[Benchmark]
- public unsafe void NativeMulElementWiseUPerf()
- {
- fixed (float* psrc1 = src1)
- fixed (float* psrc2 = src2)
- fixed (float* pdst = dst)
- {
- CpuMathNativeUtils.MulElementWiseU(psrc1, psrc2, pdst, LEN);
- }
- }
+ public void AddSU()
+ => SseIntrinsics.AddSU(new Span(src), new Span(idx, 0, IDXLEN), new Span(dst));
[Benchmark]
- public void ManagedMulElementWiseUPerf()
- {
- SseIntrinsics.MulElementWiseU(new Span(src1, 0, LEN), new Span(src2, 0, LEN),
+ public void MulElementWiseU()
+ => SseIntrinsics.MulElementWiseU(new Span(src1, 0, LEN), new Span(src2, 0, LEN),
new Span(dst, 0, LEN));
- }
[Benchmark]
- public unsafe float NativeSumUPerf()
- {
- fixed (float* psrc = src)
- {
- return CpuMathNativeUtils.SumU(psrc, LEN);
- }
- }
+ public float SumU()
+ => SseIntrinsics.SumU(new Span(src, 0, LEN));
[Benchmark]
- public float ManagedSumUPerf()
- {
- return SseIntrinsics.SumU(new Span(src, 0, LEN));
- }
-
+ public float SumSqU()
+ => SseIntrinsics.SumSqU(new Span(src, 0, LEN));
+
[Benchmark]
- public unsafe float NativeSumSqUPerf()
- {
- fixed (float* psrc = src)
- {
- return CpuMathNativeUtils.SumSqU(psrc, LEN);
- }
- }
-
+ public float SumSqDiffU()
+ => SseIntrinsics.SumSqDiffU(DEFAULT_SCALE, new Span(src, 0, LEN));
+
[Benchmark]
- public float ManagedSumSqUPerf()
- {
- return SseIntrinsics.SumSqU(new Span(src, 0, LEN));
- }
+ public float SumAbsU()
+ => SseIntrinsics.SumAbsU(new Span(src, 0, LEN));
[Benchmark]
- public unsafe float NativeSumSqDiffUPerf()
- {
- fixed (float* psrc = src)
- {
- return CpuMathNativeUtils.SumSqDiffU(DEFAULT_SCALE, psrc, LEN);
- }
- }
-
+ public float SumAbsDiffU()
+ => SseIntrinsics.SumAbsDiffU(DEFAULT_SCALE, new Span(src, 0, LEN));
+
[Benchmark]
- public float ManagedSumSqDiffUPerf()
- {
- return SseIntrinsics.SumSqDiffU(DEFAULT_SCALE, new Span(src, 0, LEN));
- }
-
+ public float MaxAbsU()
+ => SseIntrinsics.MaxAbsU(new Span(src, 0, LEN));
+
[Benchmark]
- public unsafe float NativeSumAbsUPerf()
- {
- fixed (float* psrc = src)
- {
- return CpuMathNativeUtils.SumAbsU(psrc, LEN);
- }
- }
-
- [Benchmark]
- public float ManagedSumAbsUPerf()
- {
- return SseIntrinsics.SumAbsU(new Span(src, 0, LEN));
- }
-
- [Benchmark]
- public unsafe float NativeSumAbsDiffUPerf()
- {
- fixed (float* psrc = src)
- {
- return CpuMathNativeUtils.SumAbsDiffU(DEFAULT_SCALE, psrc, LEN);
- }
- }
-
- [Benchmark]
- public float ManagedSumAbsDiffUPerf()
- {
- return SseIntrinsics.SumAbsDiffU(DEFAULT_SCALE, new Span(src, 0, LEN));
- }
-
+ public float MaxAbsDiffU()
+ => SseIntrinsics.MaxAbsDiffU(DEFAULT_SCALE, new Span(src, 0, LEN));
+
[Benchmark]
- public unsafe float NativeMaxAbsUPerf()
- {
- fixed (float* psrc = src)
- {
- return CpuMathNativeUtils.MaxAbsU(psrc, LEN);
- }
- }
-
- [Benchmark]
- public float ManagedMaxAbsUPerf()
- {
- return SseIntrinsics.MaxAbsU(new Span(src, 0, LEN));
- }
-
- [Benchmark]
- public unsafe float NativeMaxAbsDiffUPerf()
- {
- fixed (float* psrc = src)
- {
- return CpuMathNativeUtils.MaxAbsDiffU(DEFAULT_SCALE, psrc, LEN);
- }
- }
-
- [Benchmark]
- public float ManagedMaxAbsDiffUPerf()
- {
- return SseIntrinsics.MaxAbsDiffU(DEFAULT_SCALE, new Span(src, 0, LEN));
- }
-
+ public float DotU()
+ => SseIntrinsics.DotU(new Span(src, 0, LEN), new Span(dst, 0, LEN));
+
[Benchmark]
- public unsafe float NativeDotUPerf()
- {
- fixed (float* psrc = src)
- fixed (float* pdst = dst)
- {
- return CpuMathNativeUtils.DotU(psrc, pdst, LEN);
- }
- }
-
- [Benchmark]
- public float ManagedDotUPerf()
- {
- return SseIntrinsics.DotU(new Span(src, 0, LEN), new Span(dst, 0, LEN));
- }
-
- [Benchmark]
- public unsafe float NativeDotSUPerf()
- {
- fixed (float* psrc = src)
- fixed (float* pdst = dst)
- fixed (int* pidx = idx)
- {
- return CpuMathNativeUtils.DotSU(psrc, pdst, pidx, IDXLEN);
- }
- }
-
- [Benchmark]
- public float ManagedDotSUPerf()
- {
- return SseIntrinsics.DotSU(new Span(src), new Span(dst), new Span(idx, 0, IDXLEN));
- }
-
- [Benchmark]
- public unsafe float NativeDist2Perf()
- {
- fixed (float* psrc = src)
- fixed (float* pdst = dst)
- {
- return CpuMathNativeUtils.Dist2(psrc, pdst, LEN);
- }
- }
-
- [Benchmark]
- public float ManagedDist2Perf()
- {
- return SseIntrinsics.Dist2(new Span(src, 0, LEN), new Span(dst, 0, LEN));
- }
-
- [Benchmark]
- public unsafe void NativeSdcaL1UpdateUPerf()
- {
- fixed (float* psrc = src)
- fixed (float* pdst = dst)
- fixed (float* pres = result)
- {
- CpuMathNativeUtils.SdcaL1UpdateU(DEFAULT_SCALE, psrc, DEFAULT_SCALE, pdst, pres, LEN);
- }
- }
-
+ public float DotSU()
+ => SseIntrinsics.DotSU(new Span(src), new Span(dst), new Span(idx, 0, IDXLEN));
+
[Benchmark]
- public void ManagedSdcaL1UpdateUPerf()
- {
- SseIntrinsics.SdcaL1UpdateU(DEFAULT_SCALE, new Span(src, 0, LEN), DEFAULT_SCALE, new Span(dst, 0, LEN), new Span(result, 0, LEN));
- }
+ public float Dist2()
+ => SseIntrinsics.Dist2(new Span(src, 0, LEN), new Span(dst, 0, LEN));
[Benchmark]
- public unsafe void NativeSdcaL1UpdateSUPerf()
- {
- fixed (float* psrc = src)
- fixed (float* pdst = dst)
- fixed (float* pres = result)
- fixed (int* pidx = idx)
- {
- CpuMathNativeUtils.SdcaL1UpdateSU(DEFAULT_SCALE, psrc, pidx, DEFAULT_SCALE, pdst, pres, IDXLEN);
- }
- }
+ public void SdcaL1UpdateU()
+ => SseIntrinsics.SdcaL1UpdateU(DEFAULT_SCALE, new Span(src, 0, LEN), DEFAULT_SCALE, new Span(dst, 0, LEN), new Span(result, 0, LEN));
[Benchmark]
- public void ManagedSdcaL1UpdateSUPerf()
- {
- SseIntrinsics.SdcaL1UpdateSU(DEFAULT_SCALE, new Span(src, 0, IDXLEN), new Span(idx, 0, IDXLEN), DEFAULT_SCALE, new Span(dst), new Span(result));
- }
+ public void SdcaL1UpdateSU()
+ => SseIntrinsics.SdcaL1UpdateSU(DEFAULT_SCALE, new Span(src, 0, IDXLEN), new Span(idx, 0, IDXLEN), DEFAULT_SCALE, new Span(dst), new Span(result));
}
}
diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs
index 760cf9c910..d189e627f5 100644
--- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs
+++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs
@@ -5,6 +5,8 @@
using System;
using System.Collections.Generic;
using System.IO;
+using System.Linq;
+using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Data.IO;
@@ -17,6 +19,108 @@ namespace Microsoft.ML.Runtime.RunTests
{
public abstract partial class TestDataPipeBase : TestDataViewBase
{
+ ///
+ /// 'Workout test' for an estimator.
+ /// Checks the following traits:
+ /// - the estimator is applicable to the validFitInput, and not applicable to validTransformInput and invalidInput;
+ /// - the fitted transformer is applicable to validFitInput and validTransformInput, and not applicable to invalidInput;
+ /// - fitted transformer can be saved and re-loaded into the transformer with the same behavior.
+ /// - schema propagation for fitted transformer conforms to schema propagation of estimator.
+ ///
+ protected void TestEstimatorCore(IEstimator estimator,
+ IDataView validFitInput, IDataView validTransformInput = null, IDataView invalidInput = null)
+ {
+ Contracts.AssertValue(estimator);
+ Contracts.AssertValue(validFitInput);
+ Contracts.AssertValueOrNull(validTransformInput);
+ Contracts.AssertValueOrNull(invalidInput);
+ Action mustFail = (Action action) =>
+ {
+ try
+ {
+ action();
+ Assert.False(true);
+ }
+ catch (ArgumentOutOfRangeException) { }
+ catch (InvalidOperationException) { }
+ };
+
+ // Schema propagation tests for estimator.
+ var outSchemaShape = estimator.GetOutputSchema(SchemaShape.Create(validFitInput.Schema));
+ if (validTransformInput != null)
+ {
+ mustFail(() => estimator.GetOutputSchema(SchemaShape.Create(validTransformInput.Schema)));
+ mustFail(() => estimator.Fit(validTransformInput));
+ }
+
+ if (invalidInput != null)
+ {
+ mustFail(() => estimator.GetOutputSchema(SchemaShape.Create(invalidInput.Schema)));
+ mustFail(() => estimator.Fit(invalidInput));
+ }
+
+ var transformer = estimator.Fit(validFitInput);
+ // Save and reload.
+ string modelPath = GetOutputPath(TestName + "-model.zip");
+ using (var fs = File.Create(modelPath))
+ transformer.SaveTo(Env, fs);
+
+ ITransformer loadedTransformer;
+ using (var fs = File.OpenRead(modelPath))
+ loadedTransformer = TransformerChain.LoadFrom(Env, fs);
+ DeleteOutputPath(modelPath);
+
+ // Run on train data.
+ Action checkOnData = (IDataView data) =>
+ {
+ var schema = transformer.GetOutputSchema(data.Schema);
+
+ // Loaded transformer needs to have the same schema propagation.
+ CheckSameSchemas(schema, loadedTransformer.GetOutputSchema(data.Schema));
+
+ var scoredTrain = transformer.Transform(data);
+ var scoredTrain2 = loadedTransformer.Transform(data);
+
+ // The schema of the transformed data must match the schema provided by schema propagation.
+ CheckSameSchemas(schema, scoredTrain.Schema);
+
+ // The schema and data of scored dataset must be identical between loaded
+ // and original transformer.
+ // This in turn means that the schema of loaded transformer matches for
+ // Transform and GetOutputSchema calls.
+ CheckSameSchemas(scoredTrain.Schema, scoredTrain2.Schema);
+ CheckSameValues(scoredTrain, scoredTrain2);
+ };
+
+ checkOnData(validFitInput);
+
+ if (validTransformInput != null)
+ checkOnData(validTransformInput);
+
+ if (invalidInput != null)
+ {
+ mustFail(() => transformer.GetOutputSchema(invalidInput.Schema));
+ mustFail(() => transformer.Transform(invalidInput));
+ mustFail(() => loadedTransformer.GetOutputSchema(invalidInput.Schema));
+ mustFail(() => loadedTransformer.Transform(invalidInput));
+ }
+
+ // Schema verification between estimator and transformer.
+ var scoredTrainSchemaShape = SchemaShape.Create(transformer.GetOutputSchema(validFitInput.Schema));
+ CheckSameSchemaShape(outSchemaShape, scoredTrainSchemaShape);
+ }
+
+ private void CheckSameSchemaShape(SchemaShape first, SchemaShape second)
+ {
+ Assert.True(first.Columns.Length == second.Columns.Length);
+ var sortedCols1 = first.Columns.OrderBy(x => x.Name);
+ var sortedCols2 = second.Columns.OrderBy(x => x.Name);
+
+ Assert.True(sortedCols1.Zip(sortedCols2,
+ (x, y) => x.IsCompatibleWith(y) && y.IsCompatibleWith(x))
+ .All(x => x));
+ }
+
// REVIEW: incorporate the testing for re-apply logic here?
///
/// Create PipeDataLoader from the given args, save it, re-load it, verify that the data of
@@ -878,41 +982,44 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ
{
switch (type.RawKind)
{
- case DataKind.I1:
- return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue);
- case DataKind.U1:
- return GetComparerOne(r1, r2, col, (x, y) => x == y);
- case DataKind.I2:
- return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue);
- case DataKind.U2:
- return GetComparerOne(r1, r2, col, (x, y) => x == y);
- case DataKind.I4:
- return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue);
- case DataKind.U4:
- return GetComparerOne(r1, r2, col, (x, y) => x == y);
- case DataKind.I8:
- return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue);
- case DataKind.U8:
- return GetComparerOne(r1, r2, col, (x, y) => x == y);
- case DataKind.R4:
- return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y));
- case DataKind.R8:
- if (exactDoubles)
- return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y));
- else
- return GetComparerOne(r1, r2, col, EqualWithEps);
- case DataKind.Text:
- return GetComparerOne(r1, r2, col, DvText.Identical);
- case DataKind.Bool:
- return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y));
- case DataKind.TimeSpan:
- return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y));
- case DataKind.DT:
- return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y));
- case DataKind.DZ:
- return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y));
- case DataKind.UG:
- return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y));
+ case DataKind.I1:
+ return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue);
+ case DataKind.U1:
+ return GetComparerOne