Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Api/CustomMappingTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
var addedCols = DataViewConstructionUtils.GetSchemaColumns(Transformer.AddedSchema);
var addedSchemaShape = SchemaShape.Create(SchemaBuilder.MakeSchema(addedCols));

var result = inputSchema.Columns.ToDictionary(x => x.Name);
var result = inputSchema.ToDictionary(x => x.Name);
var inputDef = InternalSchemaDefinition.Create(typeof(TSrc), Transformer.InputSchemaDefinition);
foreach (var col in inputDef.Columns)
{
Expand All @@ -223,7 +223,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
}
}

foreach (var addedCol in addedSchemaShape.Columns)
foreach (var addedCol in addedSchemaShape)
result[addedCol.Name] = addedCol;

return new SchemaShape(result.Values);
Expand Down
61 changes: 36 additions & 25 deletions src/Microsoft.ML.Core/Data/IEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;

namespace Microsoft.ML.Core.Data
Expand All @@ -16,13 +17,17 @@ namespace Microsoft.ML.Core.Data
/// This is more relaxed than the proper <see cref="ISchema"/>, since it's only a subset of the columns,
/// and also since it doesn't specify exact <see cref="ColumnType"/>'s for vectors and keys.
/// </summary>
public sealed class SchemaShape
public sealed class SchemaShape : IReadOnlyList<SchemaShape.Column>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

: IReadOnlyList<SchemaShape.Column> [](start = 36, length = 35)

This is not right! You can make this a ReadOnlyList with proper indexers and so forth. Or you could expose this Columns thing as you do below. You CANNOT DO BOTH.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're absolutely right. Thanks a lot for pointing this out.


In reply to: 239226291 [](ancestors = 239226291)

{
public readonly Column[] Columns;
private readonly Column[] _columns;

private static readonly SchemaShape _empty = new SchemaShape(Enumerable.Empty<Column>());

public sealed class Column
public int Count => _columns.Count();

public Column this[int index] => _columns[index];

public struct Column
{
public enum VectorKind
{
Expand Down Expand Up @@ -55,13 +60,13 @@ public enum VectorKind
/// </summary>
public readonly SchemaShape Metadata;

public Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey, SchemaShape metadata = null)
[BestFriend]
internal Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey, SchemaShape metadata = null)
{
Contracts.CheckNonEmpty(name, nameof(name));
Contracts.CheckValueOrNull(metadata);
Contracts.CheckParam(!itemType.IsKey, nameof(itemType), "Item type cannot be a key");
Contracts.CheckParam(!itemType.IsVector, nameof(itemType), "Item type cannot be a vector");

Contracts.CheckParam(!isKey || KeyType.IsValidDataKind(itemType.RawKind), nameof(itemType), "The item type must be valid for a key");

Name = name;
Expand All @@ -80,9 +85,10 @@ public Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey,
/// - The columns of <see cref="Metadata"/> of <paramref name="inputColumn"/> is a superset of our <see cref="Metadata"/> columns.
/// - Each such metadata column is itself compatible with the input metadata column.
/// </summary>
public bool IsCompatibleWith(Column inputColumn)
[BestFriend]
internal bool IsCompatibleWith(Column inputColumn)
{
Contracts.CheckValue(inputColumn, nameof(inputColumn));
Contracts.Check(inputColumn.IsValid, nameof(inputColumn));
if (Name != inputColumn.Name)
return false;
if (Kind != inputColumn.Kind)
Expand All @@ -91,7 +97,7 @@ public bool IsCompatibleWith(Column inputColumn)
return false;
if (IsKey != inputColumn.IsKey)
return false;
foreach (var metaCol in Metadata.Columns)
foreach (var metaCol in Metadata)
{
if (!inputColumn.Metadata.TryFindColumn(metaCol.Name, out var inputMetaCol))
return false;
Expand All @@ -101,7 +107,8 @@ public bool IsCompatibleWith(Column inputColumn)
return true;
}

public string GetTypeString()
[BestFriend]
internal string GetTypeString()
{
string result = ItemType.ToString();
if (IsKey)
Expand All @@ -112,13 +119,20 @@ public string GetTypeString()
result = $"VarVector<{result}>";
return result;
}

/// <summary>
/// Return if this structure is not identical to the default value of <see cref="Column"/>. If true,
/// it means this structure is initialized properly and therefore considered as valid.
/// </summary>
[BestFriend]
internal bool IsValid => Name != null;
}

public SchemaShape(IEnumerable<Column> columns)
{
Contracts.CheckValue(columns, nameof(columns));
Columns = columns.ToArray();
Contracts.CheckParam(columns.All(c => c != null), nameof(columns), "No items should be null.");
_columns = columns.ToArray();
Contracts.CheckParam(columns.All(c => c.IsValid), nameof(columns), "Some items are not initialized properly.");
}

/// <summary>
Expand Down Expand Up @@ -151,7 +165,8 @@ internal static void GetColumnTypeShape(ColumnType type,
/// <summary>
/// Create a schema shape out of the fully defined schema.
/// </summary>
public static SchemaShape Create(Schema schema)
[BestFriend]
internal static SchemaShape Create(Schema schema)
{
Contracts.CheckValue(schema, nameof(schema));
var cols = new List<Column>();
Expand Down Expand Up @@ -179,25 +194,23 @@ public static SchemaShape Create(Schema schema)
/// <summary>
/// Returns if there is a column with a specified <paramref name="name"/> and if so stores it in <paramref name="column"/>.
/// </summary>
public bool TryFindColumn(string name, out Column column)
[BestFriend]
internal bool TryFindColumn(string name, out Column column)
{
Contracts.CheckValue(name, nameof(name));
column = Columns.FirstOrDefault(x => x.Name == name);
return column != null;
column = _columns.FirstOrDefault(x => x.Name == name);
return column.IsValid;
}

public IEnumerator<Column> GetEnumerator() => ((IEnumerable<Column>)_columns).GetEnumerator();

IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

// REVIEW: I think we should have an IsCompatible method to check if it's OK to use one schema shape
// as an input to another schema shape. I started writing, but realized that there's more than one way to check for
// the 'compatibility': as in, 'CAN be compatible' vs. 'WILL be compatible'.
}

/// <summary>
/// Exception class for schema validation errors.
/// </summary>
public class SchemaException : Exception
{
}

/// <summary>
/// The 'data reader' takes a certain kind of input and turns it into an <see cref="IDataView"/>.
/// </summary>
Expand Down Expand Up @@ -246,7 +259,6 @@ public interface ITransformer
/// <summary>
/// Schema propagation for transformers.
/// Returns the output schema of the data, if the input schema is like the one provided.
/// Throws <see cref="SchemaException"/> if the input schema is not valid for the transformer.
/// </summary>
Schema GetOutputSchema(Schema inputSchema);

Expand Down Expand Up @@ -288,7 +300,6 @@ public interface IEstimator<out TTransformer>
/// <summary>
/// Schema propagation for estimators.
/// Returns the output schema shape of the estimator, if the input schema shape is like the one provided.
/// Throws <see cref="SchemaException"/> iff the input schema is not valid for the estimator.
/// </summary>
SchemaShape GetOutputSchema(SchemaShape inputSchema);
}
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Core/Data/MetadataUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ public static bool IsNormalized(this Schema schema, int col)
/// of a scalar <see cref="BoolType"/> type, which we assume, if set, should be <c>true</c>.</returns>
public static bool IsNormalized(this SchemaShape.Column col)
{
Contracts.CheckValue(col, nameof(col));
Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly");
return col.Metadata.TryFindColumn(Kinds.IsNormalized, out var metaCol)
&& metaCol.Kind == SchemaShape.Column.VectorKind.Scalar && !metaCol.IsKey
&& metaCol.ItemType == BoolType.Instance;
Expand All @@ -382,7 +382,7 @@ public static bool IsNormalized(this SchemaShape.Column col)
/// <see cref="Kinds.SlotNames"/> metadata of definite sized vectors of text.</returns>
public static bool HasSlotNames(this SchemaShape.Column col)
{
Contracts.CheckValue(col, nameof(col));
Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly");
return col.Kind == SchemaShape.Column.VectorKind.Vector
&& col.Metadata.TryFindColumn(Kinds.SlotNames, out var metaCol)
&& metaCol.Kind == SchemaShape.Column.VectorKind.Vector && !metaCol.IsKey
Expand Down
4 changes: 4 additions & 0 deletions src/Microsoft.ML.Core/Utilities/Contracts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,10 @@ public static void CheckAlive(this IHostEnvironment env)
public static void CheckValueOrNull<T>(T val) where T : class
{
}

/// <summary>
/// This documents that the parameter can legally be null.
/// </summary>
[Conditional("INVARIANT_CHECKS")]
public static void CheckValueOrNull<T>(this IExceptionContext ctx, T val) where T : class
{
Expand Down
18 changes: 9 additions & 9 deletions src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,22 @@ public FakeSchema(IHostEnvironment env, SchemaShape inputShape)
{
_env = env;
_shape = inputShape;
_colMap = Enumerable.Range(0, _shape.Columns.Length)
.ToDictionary(idx => _shape.Columns[idx].Name, idx => idx);
_colMap = Enumerable.Range(0, _shape.Count)
.ToDictionary(idx => _shape[idx].Name, idx => idx);
}

public int ColumnCount => _shape.Columns.Length;
public int ColumnCount => _shape.Count;

public string GetColumnName(int col)
{
_env.Check(0 <= col && col < ColumnCount);
return _shape.Columns[col].Name;
return _shape[col].Name;
}

public ColumnType GetColumnType(int col)
{
_env.Check(0 <= col && col < ColumnCount);
var inputCol = _shape.Columns[col];
var inputCol = _shape[col];
return MakeColumnType(inputCol);
}

Expand All @@ -66,7 +66,7 @@ private static ColumnType MakeColumnType(SchemaShape.Column inputCol)
public void GetMetadata<TValue>(string kind, int col, ref TValue value)
{
_env.Check(0 <= col && col < ColumnCount);
var inputCol = _shape.Columns[col];
var inputCol = _shape[col];
var metaShape = inputCol.Metadata;
if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn))
throw _env.ExceptGetMetadata();
Expand All @@ -89,7 +89,7 @@ public void GetMetadata<TValue>(string kind, int col, ref TValue value)
public ColumnType GetMetadataTypeOrNull(string kind, int col)
{
_env.Check(0 <= col && col < ColumnCount);
var inputCol = _shape.Columns[col];
var inputCol = _shape[col];
var metaShape = inputCol.Metadata;
if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn))
return null;
Expand All @@ -99,12 +99,12 @@ public ColumnType GetMetadataTypeOrNull(string kind, int col)
public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
{
_env.Check(0 <= col && col < ColumnCount);
var inputCol = _shape.Columns[col];
var inputCol = _shape[col];
var metaShape = inputCol.Metadata;
if (metaShape == null)
return Enumerable.Empty<KeyValuePair<string, ColumnType>>();

return metaShape.Columns.Select(c => new KeyValuePair<string, ColumnType>(c.Name, MakeColumnType(c)));
return metaShape.Select(c => new KeyValuePair<string, ColumnType>(c.Name, MakeColumnType(c)));
}
}
}
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/StaticPipe/StaticSchemaShape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ public void Check(IExceptionContext ectx, SchemaShape shape)

private static Type GetTypeOrNull(SchemaShape.Column col)
{
Contracts.AssertValue(col);
Contracts.Assert(col.IsValid);

Type vecType = null;

Expand Down
41 changes: 15 additions & 26 deletions src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,11 @@ public abstract class TrainerEstimatorBase<TTransformer, TModel> : ITrainerEstim
private protected TrainerEstimatorBase(IHost host,
SchemaShape.Column feature,
SchemaShape.Column label,
SchemaShape.Column weight = null)
SchemaShape.Column weight = default)
{
Contracts.CheckValue(host, nameof(host));
Host = host;
Host.CheckValue(feature, nameof(feature));
Host.CheckValueOrNull(label);
Host.CheckValueOrNull(weight);
Host.CheckParam(feature.IsValid, nameof(feature), "not initialized properly");

FeatureColumn = feature;
LabelColumn = label;
Expand All @@ -76,7 +74,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)

CheckInputSchema(inputSchema);

var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);
var outColumns = inputSchema.ToDictionary(x => x.Name);
foreach (var col in GetOutputColumnsCore(inputSchema))
outColumns[col.Name] = col;

Expand All @@ -102,7 +100,7 @@ private void CheckInputSchema(SchemaShape inputSchema)
if (!FeatureColumn.IsCompatibleWith(featureCol))
throw Host.Except($"Feature column '{FeatureColumn.Name}' is not compatible");

if (WeightColumn != null)
if (WeightColumn.IsValid)
{
if (!inputSchema.TryFindColumn(WeightColumn.Name, out var weightCol))
throw Host.Except($"Weight column '{WeightColumn.Name}' is not found");
Expand All @@ -112,7 +110,7 @@ private void CheckInputSchema(SchemaShape inputSchema)

// Special treatment for label column: we allow different types of labels, so the trainers
// may define their own requirements on the label column.
if (LabelColumn != null)
if (LabelColumn.IsValid)
{
if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
throw Host.Except($"Label column '{LabelColumn.Name}' is not found");
Expand All @@ -122,8 +120,8 @@ private void CheckInputSchema(SchemaShape inputSchema)

protected virtual void CheckLabelCompatible(SchemaShape.Column labelCol)
{
Contracts.CheckValue(labelCol, nameof(labelCol));
Contracts.AssertValue(LabelColumn);
Contracts.CheckParam(labelCol.IsValid, nameof(labelCol), "not initialized properly");
Host.Assert(LabelColumn.IsValid);

if (!LabelColumn.IsCompatibleWith(labelCol))
throw Host.Except($"Label column '{LabelColumn.Name}' is not compatible");
Expand All @@ -133,20 +131,12 @@ protected TTransformer TrainTransformer(IDataView trainSet,
IDataView validationSet = null, IPredictor initPredictor = null)
{
var cachedTrain = Info.WantCaching ? new CacheDataView(Host, trainSet, prefetch: null) : trainSet;
var cachedValid = Info.WantCaching && validationSet != null ? new CacheDataView(Host, validationSet, prefetch: null) : validationSet;

var trainRoles = MakeRoles(cachedTrain);
var trainRoleMapped = MakeRoles(cachedTrain);
var validRoleMapped = validationSet == null ? null : MakeRoles(cachedValid);

RoleMappedData validRoles;

if (validationSet == null)
validRoles = null;
else
{
var cachedValid = Info.WantCaching ? new CacheDataView(Host, validationSet, prefetch: null) : validationSet;
validRoles = MakeRoles(cachedValid);
}

var pred = TrainModelCore(new TrainContext(trainRoles, validRoles, null, initPredictor));
var pred = TrainModelCore(new TrainContext(trainRoleMapped, validRoleMapped, null, initPredictor));
return MakeTransformer(pred, trainSet.Schema);
}

Expand All @@ -156,7 +146,7 @@ protected TTransformer TrainTransformer(IDataView trainSet,
protected abstract TTransformer MakeTransformer(TModel model, Schema trainSchema);

protected virtual RoleMappedData MakeRoles(IDataView data) =>
new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, weight: WeightColumn?.Name);
new RoleMappedData(data, label: LabelColumn.Name, feature: FeatureColumn.Name, weight: WeightColumn.Name);

IPredictor ITrainer.Train(TrainContext context) => ((ITrainer<TModel>)this).Train(context);
}
Expand All @@ -178,16 +168,15 @@ public abstract class TrainerEstimatorBaseWithGroupId<TTransformer, TModel> : Tr
public TrainerEstimatorBaseWithGroupId(IHost host,
SchemaShape.Column feature,
SchemaShape.Column label,
SchemaShape.Column weight = null,
SchemaShape.Column groupId = null)
SchemaShape.Column weight = default,
SchemaShape.Column groupId = default)
:base(host, feature, label, weight)
{
Host.CheckValueOrNull(groupId);
GroupIdColumn = groupId;
}

protected override RoleMappedData MakeRoles(IDataView data) =>
new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, group: GroupIdColumn?.Name, weight: WeightColumn?.Name);
new RoleMappedData(data, label: LabelColumn.Name, feature: FeatureColumn.Name, group: GroupIdColumn.Name, weight: WeightColumn.Name);

}
}
Loading