Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build/BranchInfo.props
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project>
<PropertyGroup>
<MajorVersion>0</MajorVersion>
<MinorVersion>5</MinorVersion>
<MinorVersion>6</MinorVersion>
<PatchVersion>0</PatchVersion>
<PreReleaseLabel>preview</PreReleaseLabel>
</PropertyGroup>
Expand Down
2 changes: 1 addition & 1 deletion build/Dependencies.props
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
<LightGBMPackageVersion>2.1.2.2</LightGBMPackageVersion>
<MlNetMklDepsPackageVersion>0.0.0.5</MlNetMklDepsPackageVersion>
<SystemDrawingCommonPackageVersion>4.5.0</SystemDrawingCommonPackageVersion>
<BenchmarkDotNetVersion>0.11.0</BenchmarkDotNetVersion>
<BenchmarkDotNetVersion>0.11.1</BenchmarkDotNetVersion>
<TensorFlowVersion>1.10.0</TensorFlowVersion>
</PropertyGroup>
</Project>
1 change: 1 addition & 0 deletions src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
<ProjectReference Include="..\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.csproj" />
<ProjectReference Include="..\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj" />
<ProjectReference Include="..\Microsoft.ML.HalLearners\Microsoft.ML.HalLearners.csproj" />
<ProjectReference Include="..\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj" />
<ProjectReference Include="..\Microsoft.ML.KMeansClustering\Microsoft.ML.KMeansClustering.csproj" />
<ProjectReference Include="..\Microsoft.ML.LightGBM\Microsoft.ML.LightGBM.csproj" />
<ProjectReference Include="..\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj" />
Expand Down
40 changes: 32 additions & 8 deletions src/Microsoft.ML.Core/Data/IEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,42 @@ public enum VectorKind
VariableVector
}

/// <summary>
/// The column name.
/// </summary>
public readonly string Name;

/// <summary>
/// The type of the column: scalar, fixed vector or variable vector.
/// </summary>
public readonly VectorKind Kind;
public readonly DataKind ItemKind;

/// <summary>
/// The 'raw' type of column item: must be a primitive type or a structured type.
/// </summary>
public readonly ColumnType ItemType;
/// <summary>
/// The flag whether the column is actually a key. If yes, <see cref="ItemType"/> is representing
/// the underlying primitive type.
/// </summary>
public readonly bool IsKey;
/// <summary>
/// The metadata kinds that are present for this column.
/// </summary>
public readonly string[] MetadataKinds;

public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey, string[] metadataKinds = null)
public Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey, string[] metadataKinds = null)
{
Contracts.CheckNonEmpty(name, nameof(name));
Contracts.CheckValueOrNull(metadataKinds);
Contracts.CheckParam(!itemType.IsKey, nameof(itemType), "Item type cannot be a key");
Contracts.CheckParam(!itemType.IsVector, nameof(itemType), "Item type cannot be a vector");

Contracts.CheckParam(!isKey || KeyType.IsValidDataKind(itemType.RawKind), nameof(itemType), "The item type must be valid for a key");

Name = name;
Kind = vecKind;
ItemKind = itemKind;
ItemType = itemType;
IsKey = isKey;
MetadataKinds = metadataKinds ?? new string[0];
}
Expand All @@ -51,7 +73,7 @@ public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey, st
/// requirement.
///
/// Namely, it returns true iff:
/// - The <see cref="Name"/>, <see cref="Kind"/>, <see cref="ItemKind"/>, <see cref="IsKey"/> fields match.
/// - The <see cref="Name"/>, <see cref="Kind"/>, <see cref="ItemType"/>, <see cref="IsKey"/> fields match.
/// - The <see cref="MetadataKinds"/> of <paramref name="inputColumn"/> is a superset of our <see cref="MetadataKinds"/>.
/// </summary>
public bool IsCompatibleWith(Column inputColumn)
Expand All @@ -61,7 +83,7 @@ public bool IsCompatibleWith(Column inputColumn)
return false;
if (Kind != inputColumn.Kind)
return false;
if (ItemKind != inputColumn.ItemKind)
if (!ItemType.Equals(inputColumn.ItemType))
return false;
if (IsKey != inputColumn.IsKey)
return false;
Expand All @@ -72,7 +94,7 @@ public bool IsCompatibleWith(Column inputColumn)

public string GetTypeString()
{
string result = ItemKind.ToString();
string result = ItemType.ToString();
if (IsKey)
result = $"Key<{result}>";
if (Kind == VectorKind.Vector)
Expand Down Expand Up @@ -110,13 +132,15 @@ public static SchemaShape Create(ISchema schema)
else
vecKind = Column.VectorKind.Scalar;

var kind = type.ItemType.RawKind;
ColumnType itemType = type.ItemType;
if (type.ItemType.IsKey)
itemType = PrimitiveType.FromKind(type.ItemType.RawKind);
var isKey = type.ItemType.IsKey;

var metadataNames = schema.GetMetadataTypes(iCol)
.Select(kvp => kvp.Key)
.ToArray();
cols.Add(new Column(schema.GetColumnName(iCol), vecKind, kind, isKey, metadataNames));
cols.Add(new Column(schema.GetColumnName(iCol), vecKind, itemType, isKey, metadataNames));
}
}
return new SchemaShape(cols.ToArray());
Expand Down
35 changes: 35 additions & 0 deletions src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Core.Data;

namespace Microsoft.ML.Runtime.Data
{
/// <summary>
/// The trivial implementation of <see cref="IEstimator{TTransformer}"/> that already has
/// the transformer and returns it on every call to <see cref="Fit(IDataView)"/>.
///
/// Concrete implementations still have to provide the schema propagation mechanism, since
/// there is no easy way to infer it from the transformer.
/// </summary>
public abstract class TrivialEstimator<TTransformer> : IEstimator<TTransformer>
where TTransformer : class, ITransformer
{
protected readonly IHost Host;
protected readonly TTransformer Transformer;

protected TrivialEstimator(IHost host, TTransformer transformer)
{
Contracts.AssertValue(host);

Host = host;
Host.CheckValue(transformer, nameof(transformer));
Transformer = transformer;
}

public TTransformer Fit(IDataView input) => Transformer;

public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema);
}
}
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
var originalColumn = inputSchema.FindColumn(column.Source);
if (originalColumn != null)
{
var col = new SchemaShape.Column(column.Name, originalColumn.Kind, originalColumn.ItemKind, originalColumn.IsKey, originalColumn.MetadataKinds);
var col = new SchemaShape.Column(column.Name, originalColumn.Kind, originalColumn.ItemType, originalColumn.IsKey, originalColumn.MetadataKinds);
resultDic[column.Name] = col;
}
else
Expand Down
177 changes: 177 additions & 0 deletions src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Model;

namespace Microsoft.ML.Runtime.Data
{
public abstract class OneToOneTransformerBase : ITransformer, ICanSaveModel
{
protected readonly IHost Host;
protected readonly (string input, string output)[] ColumnPairs;

protected OneToOneTransformerBase(IHost host, (string input, string output)[] columns)
{
Contracts.AssertValue(host);
host.CheckValue(columns, nameof(columns));

var newNames = new HashSet<string>();
foreach (var column in columns)
{
host.CheckNonEmpty(column.input, nameof(columns));
host.CheckNonEmpty(column.output, nameof(columns));

if (!newNames.Add(column.output))
throw Contracts.ExceptParam(nameof(columns), $"Output column '{column.output}' specified multiple times");
}

Host = host;
ColumnPairs = columns;
}

protected OneToOneTransformerBase(IHost host, ModelLoadContext ctx)
{
Host = host;
// *** Binary format ***
// int: number of added columns
// for each added column
// int: id of output column name
// int: id of input column name

int n = ctx.Reader.ReadInt32();
ColumnPairs = new (string input, string output)[n];
for (int i = 0; i < n; i++)
{
string output = ctx.LoadNonEmptyString();
string input = ctx.LoadNonEmptyString();
ColumnPairs[i] = (input, output);
}
}

public abstract void Save(ModelSaveContext ctx);

protected void SaveColumns(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));

// *** Binary format ***
// int: number of added columns
// for each added column
// int: id of output column name
// int: id of input column name

ctx.Writer.Write(ColumnPairs.Length);
for (int i = 0; i < ColumnPairs.Length; i++)
{
ctx.SaveNonEmptyString(ColumnPairs[i].output);
ctx.SaveNonEmptyString(ColumnPairs[i].input);
}
}

private void CheckInput(ISchema inputSchema, int col, out int srcCol)
{
Contracts.AssertValue(inputSchema);
Contracts.Assert(0 <= col && col < ColumnPairs.Length);

if (!inputSchema.TryGetColumnIndex(ColumnPairs[col].input, out srcCol))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input);
CheckInputColumn(inputSchema, col, srcCol);
}

protected virtual void CheckInputColumn(ISchema inputSchema, int col, int srcCol)
{
// By default, there are no extra checks.
}

protected abstract IRowMapper MakeRowMapper(ISchema schema);

public ISchema GetOutputSchema(ISchema inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));

// Check that all the input columns are present and correct.
for (int i = 0; i < ColumnPairs.Length; i++)
CheckInput(inputSchema, i, out int col);

return Transform(new EmptyDataView(Host, inputSchema)).Schema;
}

public IDataView Transform(IDataView input) => MakeDataTransform(input);

protected RowToRowMapperTransform MakeDataTransform(IDataView input)
{
Host.CheckValue(input, nameof(input));
return new RowToRowMapperTransform(Host, input, MakeRowMapper(input.Schema));
}

protected abstract class MapperBase : IRowMapper
{
protected readonly IHost Host;
protected readonly Dictionary<int, int> ColMapNewToOld;
protected readonly ISchema InputSchema;
private readonly OneToOneTransformerBase _parent;

protected MapperBase(IHost host, OneToOneTransformerBase parent, ISchema inputSchema)
{
Contracts.AssertValue(host);
Contracts.AssertValue(parent);
Contracts.AssertValue(inputSchema);

Host = host;
_parent = parent;

ColMapNewToOld = new Dictionary<int, int>();
for (int i = 0; i < _parent.ColumnPairs.Length; i++)
{
_parent.CheckInput(inputSchema, i, out int srcCol);
ColMapNewToOld.Add(i, srcCol);
}
InputSchema = inputSchema;
}
public Func<int, bool> GetDependencies(Func<int, bool> activeOutput)
{
var active = new bool[InputSchema.ColumnCount];
foreach (var pair in ColMapNewToOld)
if (activeOutput(pair.Key))
active[pair.Value] = true;
return col => active[col];
}

public abstract RowMapperColumnInfo[] GetOutputColumns();

public void Save(ModelSaveContext ctx) => _parent.Save(ctx);

public Delegate[] CreateGetters(IRow input, Func<int, bool> activeOutput, out Action disposer)
{
Contracts.Assert(input.Schema == InputSchema);
var result = new Delegate[_parent.ColumnPairs.Length];
var disposers = new Action[_parent.ColumnPairs.Length];
for (int i = 0; i < _parent.ColumnPairs.Length; i++)
{
if (!activeOutput(i))
continue;
int srcCol = ColMapNewToOld[i];
result[i] = MakeGetter(input, i, out disposers[i]);
}
if (disposers.Any(x => x != null))
{
disposer = () =>
{
foreach (var act in disposers)
act();
};
}
else
disposer = null;
return result;
}

protected abstract Delegate MakeGetter(IRow input, int iinfo, out Action disposer);
}
}
}
8 changes: 4 additions & 4 deletions src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public static class ImageAnalytics
public static CommonOutputs.TransformOutput ImageLoader(IHostEnvironment env, ImageLoaderTransform.Arguments input)
{
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImageLoaderTransform", input);
var xf = new ImageLoaderTransform(h, input, input.Data);
var xf = ImageLoaderTransform.Create(h, input, input.Data);
return new CommonOutputs.TransformOutput()
{
Model = new TransformModel(h, xf, input.Data),
Expand All @@ -29,7 +29,7 @@ public static CommonOutputs.TransformOutput ImageLoader(IHostEnvironment env, Im
public static CommonOutputs.TransformOutput ImageResizer(IHostEnvironment env, ImageResizerTransform.Arguments input)
{
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImageResizerTransform", input);
var xf = new ImageResizerTransform(h, input, input.Data);
var xf = ImageResizerTransform.Create(h, input, input.Data);
return new CommonOutputs.TransformOutput()
{
Model = new TransformModel(h, xf, input.Data),
Expand All @@ -42,7 +42,7 @@ public static CommonOutputs.TransformOutput ImageResizer(IHostEnvironment env, I
public static CommonOutputs.TransformOutput ImagePixelExtractor(IHostEnvironment env, ImagePixelExtractorTransform.Arguments input)
{
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImagePixelExtractorTransform", input);
var xf = new ImagePixelExtractorTransform(h, input, input.Data);
var xf = ImagePixelExtractorTransform.Create(h, input, input.Data);
return new CommonOutputs.TransformOutput()
{
Model = new TransformModel(h, xf, input.Data),
Expand All @@ -55,7 +55,7 @@ public static CommonOutputs.TransformOutput ImagePixelExtractor(IHostEnvironment
public static CommonOutputs.TransformOutput ImageGrayscale(IHostEnvironment env, ImageGrayscaleTransform.Arguments input)
{
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImageGrayscaleTransform", input);
var xf = new ImageGrayscaleTransform(h, input, input.Data);
var xf = ImageGrayscaleTransform.Create(h, input, input.Data);
return new CommonOutputs.TransformOutput()
{
Model = new TransformModel(h, xf, input.Data),
Expand Down
Loading