From 568b24e24a084fa7490c087497585dddb5870384 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Wed, 28 Nov 2018 16:06:16 -0800 Subject: [PATCH 1/3] Stop using IColumn in several places. * Stop using IColumn in predicted label scorers. * Stop using IColumn in static schema shape analysis. * Stop using IColumn in linear model statistics creation. --- src/Microsoft.ML.Core/Data/MetadataBuilder.cs | 46 ++++++++++---- src/Microsoft.ML.Core/Data/MetadataUtils.cs | 29 +++++++++ src/Microsoft.ML.Core/Utilities/Utils.cs | 9 +++ .../Scorers/PredictedLabelScorerBase.cs | 47 +++++++------- .../StaticPipe/StaticSchemaShape.cs | 17 +++--- .../Standard/LinearPredictor.cs | 31 +++------- .../MulticlassLogisticRegression.cs | 11 ++-- .../Standard/ModelStatistics.cs | 61 +++++++++---------- 8 files changed, 148 insertions(+), 103 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/MetadataBuilder.cs b/src/Microsoft.ML.Core/Data/MetadataBuilder.cs index 0b72f53a25..be07f1fc1b 100644 --- a/src/Microsoft.ML.Core/Data/MetadataBuilder.cs +++ b/src/Microsoft.ML.Core/Data/MetadataBuilder.cs @@ -16,11 +16,11 @@ namespace Microsoft.ML.Data /// public sealed class MetadataBuilder { - private readonly List<(string Name, ColumnType Type, Delegate Getter)> _items; + private readonly List<(string Name, ColumnType Type, Delegate Getter, Schema.Metadata Metadata)> _items; public MetadataBuilder() { - _items = new List<(string Name, ColumnType Type, Delegate Getter)>(); + _items = new List<(string Name, ColumnType Type, Delegate Getter, Schema.Metadata Metadata)>(); } /// @@ -40,7 +40,7 @@ public void Add(Schema.Metadata metadata, Func selector) foreach (var column in metadata.Schema) { if (selector(column.Name)) - _items.Add((column.Name, column.Type, metadata.Getters[column.Index])); + _items.Add((column.Name, column.Type, metadata.Getters[column.Index], column.Metadata)); } } @@ -51,13 +51,17 @@ public void Add(Schema.Metadata metadata, Func selector) /// The metadata name. /// The metadata type. /// The getter delegate. - public void Add(string name, ColumnType type, ValueGetter getter) + /// Metadata of the input column. Note that metadata on a metadata column is somewhat rare + /// except for certain types (for example, slot names for a vector, key values for something of key type). + public void Add(string name, ColumnType type, ValueGetter getter, Schema.Metadata metadata = null) { Contracts.CheckNonEmpty(name, nameof(name)); Contracts.CheckValue(type, nameof(type)); Contracts.CheckValue(getter, nameof(getter)); - Contracts.CheckParam(type.RawType == typeof(TValue), nameof(getter)); - _items.Add((name, type, getter)); + Contracts.CheckParam(type.RawType == typeof(TValue), nameof(type)); + Contracts.CheckValueOrNull(metadata); + + _items.Add((name, type, getter, metadata)); } /// @@ -67,11 +71,31 @@ public void Add(string name, ColumnType type, ValueGetter getter /// The metadata type. /// The getter delegate that provides the value. Note that the type of the getter is still checked /// inside this method. - public void Add(string name, ColumnType type, Delegate getter) + /// Metadata of the input column. Note that metadata on a metadata column is somewhat rare + /// except for certain types (for example, slot names for a vector, key values for something of key type). + public void Add(string name, ColumnType type, Delegate getter, Schema.Metadata metadata = null) { Contracts.CheckNonEmpty(name, nameof(name)); Contracts.CheckValue(type, nameof(type)); - Utils.MarshalActionInvoke(AddDelegate, type.RawType, name, type, getter); + Contracts.CheckValueOrNull(metadata); + Utils.MarshalActionInvoke(AddDelegate, type.RawType, name, type, getter, metadata); + } + + /// + /// Add one metadata column for a primitive value type. + /// + /// The metadata name. + /// The metadata type. + /// The value of the metadata. + /// Metadata of the input column. Note that metadata on a metadata column is somewhat rare + /// except for certain types (for example, slot names for a vector, key values for something of key type). + public void AddRawValue(string name, PrimitiveType type, TValue value, Schema.Metadata metadata = null) + { + Contracts.CheckNonEmpty(name, nameof(name)); + Contracts.CheckValue(type, nameof(type)); + Contracts.CheckParam(type.RawType == typeof(TValue), nameof(type)); + Contracts.CheckValueOrNull(metadata); + Add(name, type, (ref TValue dst) => dst = value, metadata); } /// @@ -100,11 +124,11 @@ public Schema.Metadata GetMetadata() { var builder = new SchemaBuilder(); foreach (var item in _items) - builder.AddColumn(item.Name, item.Type, null); + builder.AddColumn(item.Name, item.Type, item.Metadata); return new Schema.Metadata(builder.GetSchema(), _items.Select(x => x.Getter).ToArray()); } - private void AddDelegate(string name, ColumnType type, Delegate getter) + private void AddDelegate(string name, ColumnType type, Delegate getter, Schema.Metadata metadata) { Contracts.AssertNonEmpty(name); Contracts.AssertValue(type); @@ -112,7 +136,7 @@ private void AddDelegate(string name, ColumnType type, Delegate getter) var typedGetter = getter as ValueGetter; Contracts.CheckParam(typedGetter != null, nameof(getter)); - _items.Add((name, type, typedGetter)); + _items.Add((name, type, typedGetter, metadata)); } } } diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs index 9178c1c129..6e2bd5230f 100644 --- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs +++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs @@ -494,5 +494,34 @@ public static bool TryGetCategoricalFeatureIndices(Schema schema, int colIndex, cols.AddRange(GetTrainerOutputMetadata()); return cols; } + + private sealed class MetadataRow : IRow + { + private readonly Schema.Metadata _metadata; + + public MetadataRow(Schema.Metadata metadata) + { + Contracts.AssertValue(metadata); + _metadata = metadata; + } + + public Schema Schema => _metadata.Schema; + public long Position => 0; + public long Batch => 0; + public ValueGetter GetGetter(int col) => _metadata.GetGetter(col); + public ValueGetter GetIdGetter() => (ref UInt128 dst) => dst = default; + public bool IsColumnActive(int col) => true; + } + + /// + /// Presents a as a an . + /// + /// The metadata to wrap. + /// A row that wraps an input metadata. + public static IRow MetadataAsRow(Schema.Metadata metadata) + { + Contracts.CheckValue(metadata, nameof(metadata)); + return new MetadataRow(metadata); + } } } \ No newline at end of file diff --git a/src/Microsoft.ML.Core/Utilities/Utils.cs b/src/Microsoft.ML.Core/Utilities/Utils.cs index 0fcf55a449..3395bdfe06 100644 --- a/src/Microsoft.ML.Core/Utilities/Utils.cs +++ b/src/Microsoft.ML.Core/Utilities/Utils.cs @@ -1076,6 +1076,15 @@ public static void MarshalActionInvoke(Action + /// A four-argument version of . + /// + public static void MarshalActionInvoke(Action act, Type genArg, TArg1 arg1, TArg2 arg2, TArg3 arg3, TArg4 arg4) + { + var meth = MarshalActionInvokeCheckAndCreate(genArg, act); + meth.Invoke(act.Target, new object[] { arg1, arg2, arg3, arg4 }); + } + public static string GetDescription(this Enum value) { Type type = value.GetType(); diff --git a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs index e080e9f7be..8aff6ff69a 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs @@ -41,7 +41,7 @@ protected sealed class BindingsImpl : BindingsBase private readonly MetadataUtils.MetadataGetter> _getScoreColumnKind; private readonly MetadataUtils.MetadataGetter> _getScoreValueKind; - private readonly IRow _predColMetadata; + private readonly Schema.Metadata _predColMetadata; private BindingsImpl(Schema input, ISchemaBoundRowMapper mapper, string suffix, string scoreColumnKind, bool user, int scoreColIndex, ColumnType predColType) : base(input, mapper, suffix, user, DefaultColumnNames.PredictedLabel) @@ -59,42 +59,39 @@ private BindingsImpl(Schema input, ISchemaBoundRowMapper mapper, string suffix, // REVIEW: This logic is very specific to multiclass, which is deeply // regrettable, but the class structure as designed and the status of this schema // bearing object makes pushing the logic into the multiclass scorer almost impossible. - if (predColType.IsKey) + if (predColType is KeyType predColKeyType && predColKeyType.Count > 0) { - ColumnType scoreSlotsType = mapper.OutputSchema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, scoreColIndex); - if (scoreSlotsType != null && scoreSlotsType.IsKnownSizeVector && - scoreSlotsType.VectorSize == predColType.KeyCount) + var scoreColMetadata = mapper.OutputSchema[scoreColIndex].Metadata; + + var slotColumn = scoreColMetadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames); + if (slotColumn?.Type is VectorType slotColVecType && slotColVecType.Size == predColKeyType.Count) { - Contracts.Assert(scoreSlotsType.VectorSize > 0); - IColumn col = Utils.MarshalInvoke(KeyValueMetadataFromMetadata, - scoreSlotsType.RawType, mapper.OutputSchema, scoreColIndex, MetadataUtils.Kinds.SlotNames); - _predColMetadata = RowColumnUtils.GetRow(null, col); + Contracts.Assert(slotColVecType.Size > 0); + _predColMetadata = Utils.MarshalInvoke(KeyValueMetadataFromMetadata, slotColVecType.RawType, + scoreColMetadata, slotColumn.Value); } else { - scoreSlotsType = mapper.OutputSchema.GetMetadataTypeOrNull(MetadataUtils.Kinds.TrainingLabelValues, scoreColIndex); - if (scoreSlotsType != null && scoreSlotsType.IsKnownSizeVector && - scoreSlotsType.VectorSize == predColType.KeyCount) + var trainLabelColumn = scoreColMetadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.TrainingLabelValues); + if (trainLabelColumn?.Type is VectorType trainLabelColVecType && trainLabelColVecType.Size == predColKeyType.Count) { - Contracts.Assert(scoreSlotsType.VectorSize > 0); - IColumn col = Utils.MarshalInvoke(KeyValueMetadataFromMetadata, - scoreSlotsType.RawType, mapper.OutputSchema, scoreColIndex, MetadataUtils.Kinds.TrainingLabelValues); - _predColMetadata = RowColumnUtils.GetRow(null, col); + Contracts.Assert(trainLabelColVecType.Size > 0); + _predColMetadata = Utils.MarshalInvoke(KeyValueMetadataFromMetadata, trainLabelColVecType.RawType, + scoreColMetadata, trainLabelColumn.Value); } } } } - private static IColumn KeyValueMetadataFromMetadata(ISchema schema, int col, string metadataName) + private static Schema.Metadata KeyValueMetadataFromMetadata(Schema.Metadata meta, Schema.Column metaCol) { - Contracts.AssertValue(schema); - Contracts.Assert(0 <= col && col < schema.ColumnCount); - var type = schema.GetMetadataTypeOrNull(metadataName, col); - Contracts.AssertValue(type); - Contracts.Assert(type.RawType == typeof(T)); - - ValueGetter getter = (ref T val) => schema.GetMetadata(metadataName, col, ref val); - return RowColumnUtils.GetColumn(MetadataUtils.Kinds.KeyValues, type, getter); + Contracts.AssertValue(meta); + Contracts.Assert(0 <= metaCol.Index && metaCol.Index < meta.Schema.ColumnCount); + Contracts.Assert(metaCol.Type.RawType == typeof(T)); + var getter = meta.GetGetter(metaCol.Index); + var builder = new MetadataBuilder(); + builder.Add(MetadataUtils.Kinds.KeyValues, metaCol.Type, meta.GetGetter(metaCol.Index)); + return builder.GetMetadata(); } public static BindingsImpl Create(Schema input, ISchemaBoundRowMapper mapper, string suffix, diff --git a/src/Microsoft.ML.Data/StaticPipe/StaticSchemaShape.cs b/src/Microsoft.ML.Data/StaticPipe/StaticSchemaShape.cs index 08d6936721..6ffd089f24 100644 --- a/src/Microsoft.ML.Data/StaticPipe/StaticSchemaShape.cs +++ b/src/Microsoft.ML.Data/StaticPipe/StaticSchemaShape.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Reflection; using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; @@ -51,7 +52,7 @@ public static StaticSchemaShape Make(ParameterInfo info) /// /// The context on which to throw exceptions /// The schema to check - public void Check(IExceptionContext ectx, ISchema schema) + public void Check(IExceptionContext ectx, Schema schema) { Contracts.AssertValue(ectx); ectx.AssertValue(schema); @@ -60,7 +61,7 @@ public void Check(IExceptionContext ectx, ISchema schema) { if (!schema.TryGetColumnIndex(pair.Key, out int colIdx)) throw ectx.ExceptParam(nameof(schema), $"Column named '{pair.Key}' was not found"); - var col = RowColumnUtils.GetColumn(schema, colIdx); + var col = schema[colIdx]; var type = GetTypeOrNull(col); if ((type != null && !pair.Value.IsAssignableFromStaticPipeline(type)) || (type == null && IsStandard(ectx, pair.Value))) { @@ -234,9 +235,8 @@ private static bool IsStandardCore(Type t) /// The column /// The .NET type for the static pipelines that should be used to reflect this type, given /// both the characteristics of the as well as one or two crucial pieces of metadata - private static Type GetTypeOrNull(IColumn col) + private static Type GetTypeOrNull(Schema.Column col) { - Contracts.AssertValue(col); var t = col.Type; Type vecType = null; @@ -278,13 +278,14 @@ private static Type GetTypeOrNull(IColumn col) { // Check to see if we have key value metadata of the appropriate type, size, and whatnot. var meta = col.Metadata; - if (meta.Schema.TryGetColumnIndex(MetadataUtils.Kinds.KeyValues, out int kvcol)) + if (meta.Schema.TryGetColumnIndex(MetadataUtils.Kinds.KeyValues, out int kvcolIndex)) { - var kvType = meta.Schema.GetColumnType(kvcol); - if (kvType.VectorSize == kt.Count) + var kvcol = meta.Schema[kvcolIndex]; + var kvType = kvcol.Type; + if (kvType is VectorType kvVecType && kvVecType.Size == kt.Count) { Contracts.Assert(kt.Count > 0); - var subtype = GetTypeOrNull(RowColumnUtils.GetColumn(meta, kvcol)); + var subtype = GetTypeOrNull(kvcol); if (subtype != null && subtype.IsGenericType) { var sgtype = subtype.GetGenericTypeDefinition(); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs index f8fb515d9f..fd21cc5377 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs @@ -20,6 +20,7 @@ using Microsoft.ML.Runtime.Model.Pfa; using Microsoft.ML.Runtime.Numeric; using Newtonsoft.Json.Linq; +using Microsoft.ML.Data; // This is for deserialization from a model repository. [assembly: LoadableClass(typeof(IPredictorProducing), typeof(LinearBinaryPredictor), null, typeof(SignatureLoadModel), @@ -347,27 +348,18 @@ public void SaveAsCode(TextWriter writer, RoleMappedSchema schema) public virtual IRow GetSummaryIRowOrNull(RoleMappedSchema schema) { - var cols = new List(); - var names = default(VBuffer>); MetadataUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, Weight.Length, ref names); - var slotNamesCol = RowColumnUtils.GetColumn(MetadataUtils.Kinds.SlotNames, - new VectorType(TextType.Instance, Weight.Length), ref names); - var slotNamesRow = RowColumnUtils.GetRow(null, slotNamesCol); + var subBuilder = new MetadataBuilder(); + subBuilder.AddSlotNames(Weight.Length, (ref VBuffer> dst) => names.CopyTo(ref dst)); var colType = new VectorType(NumberType.R4, Weight.Length); - - // Add the bias and the weight columns. - var bias = Bias; - cols.Add(RowColumnUtils.GetColumn("Bias", NumberType.R4, ref bias)); - var weights = Weight; - cols.Add(RowColumnUtils.GetColumn("Weights", colType, ref weights, slotNamesRow)); - return RowColumnUtils.GetRow(null, cols.ToArray()); + var builder = new MetadataBuilder(); + builder.AddRawValue("Bias", NumberType.R4, Bias); + builder.Add("Weights", colType, (ref VBuffer dst) => Weight.CopyTo(ref dst), subBuilder.GetMetadata()); + return MetadataUtils.MetadataAsRow(builder.GetMetadata()); } - public virtual IRow GetStatsIRowOrNull(RoleMappedSchema schema) - { - return null; - } + public virtual IRow GetStatsIRowOrNull(RoleMappedSchema schema) => null; public abstract void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator = null); @@ -514,13 +506,10 @@ public override IRow GetStatsIRowOrNull(RoleMappedSchema schema) { if (_stats == null) return null; - var cols = new List(); var names = default(VBuffer>); MetadataUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, Weight.Length, ref names); - - // Add the stat columns. - _stats.AddStatsColumns(cols, this, schema, in names); - return RowColumnUtils.GetRow(null, cols.ToArray()); + var meta = _stats.MakeStatisticsMetadata(this, schema, in names); + return MetadataUtils.MetadataAsRow(meta); } public override void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator = null) diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index b159aaa4be..6d8a411e7c 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -440,9 +440,9 @@ public MulticlassLogisticRegressionPredictor(IHostEnvironment env, VBuffer= 2, "numClasses must be at least 2."); + Contracts.CheckParam(numClasses >= 2, nameof(numClasses), "Must be at least 2."); _numClasses = numClasses; - Contracts.Check(numFeatures >= 1, "numFeatures must be positive."); + Contracts.CheckParam(numFeatures >= 1, nameof(numFeatures), "Must be positive."); _numFeatures = numFeatures; Contracts.Check(Utils.Size(weights) == _numClasses); Contracts.Check(Utils.Size(bias) == _numClasses); @@ -992,10 +992,9 @@ public IRow GetStatsIRowOrNull(RoleMappedSchema schema) if (_stats == null) return null; - var cols = new List(); - var names = default(VBuffer>); - _stats.AddStatsColumns(cols, null, schema, in names); - return RowColumnUtils.GetRow(null, cols.ToArray()); + VBuffer> names = default; + var meta = _stats.MakeStatisticsMetadata(null, schema, in names); + return MetadataUtils.MetadataAsRow(meta); } } diff --git a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs index e71cb85257..67db04c177 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.CpuMath; @@ -408,54 +409,50 @@ public void SaveSummaryInKeyValuePairs(LinearBinaryPredictor parent, } } - public void AddStatsColumns(List list, LinearBinaryPredictor parent, RoleMappedSchema schema, in VBuffer> names) + internal Schema.Metadata MakeStatisticsMetadata(LinearBinaryPredictor parent, RoleMappedSchema schema, in VBuffer> names) { - _env.AssertValue(list); _env.AssertValueOrNull(parent); _env.AssertValue(schema); - long count = _trainingExampleCount; - list.Add(RowColumnUtils.GetColumn("Count of training examples", NumberType.I8, ref count)); - var dev = _deviance; - list.Add(RowColumnUtils.GetColumn("Residual Deviance", NumberType.R4, ref dev)); - var nullDev = _nullDeviance; - list.Add(RowColumnUtils.GetColumn("Null Deviance", NumberType.R4, ref nullDev)); - var aic = 2 * _paramCount + _deviance; - list.Add(RowColumnUtils.GetColumn("AIC", NumberType.R4, ref aic)); + var builder = new MetadataBuilder(); + + builder.AddRawValue("Count of training examples", NumberType.I8, _trainingExampleCount); + builder.AddRawValue("Residual Deviance", NumberType.R4, _deviance); + builder.AddRawValue("Null Deviance", NumberType.R4, _nullDeviance); + builder.AddRawValue("AIC", NumberType.R4, 2 * _paramCount + _deviance); if (parent == null) - return; + return builder.GetMetadata(); - Single biasStdErr; - Single biasZScore; - Single biasPValue; - if (!TryGetBiasStatistics(parent.Statistics, parent.Bias, out biasStdErr, out biasZScore, out biasPValue)) - return; + if (!TryGetBiasStatistics(parent.Statistics, parent.Bias, out float biasStdErr, out float biasZScore, out float biasPValue)) + return builder.GetMetadata(); var biasEstimate = parent.Bias; - list.Add(RowColumnUtils.GetColumn("BiasEstimate", NumberType.R4, ref biasEstimate)); - list.Add(RowColumnUtils.GetColumn("BiasStandardError", NumberType.R4, ref biasStdErr)); - list.Add(RowColumnUtils.GetColumn("BiasZScore", NumberType.R4, ref biasZScore)); - list.Add(RowColumnUtils.GetColumn("BiasPValue", NumberType.R4, ref biasPValue)); + builder.AddRawValue("BiasEstimate", NumberType.R4, biasEstimate); + builder.AddRawValue("BiasStandardError", NumberType.R4, biasStdErr); + builder.AddRawValue("BiasZScore", NumberType.R4, biasZScore); + builder.AddRawValue("BiasPValue", NumberType.R4, biasPValue); - var weights = default(VBuffer); + var weights = default(VBuffer); parent.GetFeatureWeights(ref weights); - var estimate = default(VBuffer); - var stdErr = default(VBuffer); - var zScore = default(VBuffer); - var pValue = default(VBuffer); + var estimate = default(VBuffer); + var stdErr = default(VBuffer); + var zScore = default(VBuffer); + var pValue = default(VBuffer); ValueGetter>> getSlotNames; GetUnorderedCoefficientStatistics(parent.Statistics, in weights, in names, ref estimate, ref stdErr, ref zScore, ref pValue, out getSlotNames); - var slotNamesCol = RowColumnUtils.GetColumn(MetadataUtils.Kinds.SlotNames, - new VectorType(TextType.Instance, stdErr.Length), getSlotNames); - var slotNamesRow = RowColumnUtils.GetRow(null, slotNamesCol); + var subMetaBuilder = new MetadataBuilder(); + subMetaBuilder.AddSlotNames(stdErr.Length, getSlotNames); + var subMeta = subMetaBuilder.GetMetadata(); var colType = new VectorType(NumberType.R4, stdErr.Length); - list.Add(RowColumnUtils.GetColumn("Estimate", colType, ref estimate, slotNamesRow)); - list.Add(RowColumnUtils.GetColumn("StandardError", colType, ref stdErr, slotNamesRow)); - list.Add(RowColumnUtils.GetColumn("ZScore", colType, ref zScore, slotNamesRow)); - list.Add(RowColumnUtils.GetColumn("PValue", colType, ref pValue, slotNamesRow)); + builder.Add("Estimate", colType, (ref VBuffer dst) => estimate.CopyTo(ref dst), subMeta); + builder.Add("StandardError", colType, (ref VBuffer dst) => stdErr.CopyTo(ref dst), subMeta); + builder.Add("ZScore", colType, (ref VBuffer dst) => zScore.CopyTo(ref dst), subMeta); + builder.Add("PValue", colType, (ref VBuffer dst) => pValue.CopyTo(ref dst), subMeta); + + return builder.GetMetadata(); } private string DecorateProbabilityString(Single probZ) From ecc03082982677487631c2e5233cf5cf3a79092b Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Thu, 29 Nov 2018 00:05:21 -0800 Subject: [PATCH 2/3] Stop using IColumn in more places. * Stop using IColumn in FastTree statistics creation. * Stop using IColumn in benchmarking. * Stop using IColumn in many tests. * Add minor conveniences to metadata builder. * Allow metadata builder to have metadata of metadata. * Add appropriate validation in certain places. * Put warnings on an inappropriate method dealing with Batch that should not exist. --- src/Microsoft.ML.Core/Data/MetadataBuilder.cs | 2 +- src/Microsoft.ML.Core/Data/Schema.cs | 7 ++- src/Microsoft.ML.Core/Data/SchemaBuilder.cs | 4 ++ src/Microsoft.ML.Data/Data/IColumn.cs | 29 ---------- src/Microsoft.ML.Data/Data/RowCursorUtils.cs | 24 ++------ src/Microsoft.ML.FastTree/FastTree.cs | 11 ++-- .../Standard/LinearPredictor.cs | 2 +- .../Standard/ModelStatistics.cs | 16 ++--- test/Microsoft.ML.Benchmarks/HashBench.cs | 58 +++++++++++++++---- .../StaticPipeTests.cs | 33 +++++------ .../Transformers/HashTests.cs | 23 ++++---- 11 files changed, 100 insertions(+), 109 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/MetadataBuilder.cs b/src/Microsoft.ML.Core/Data/MetadataBuilder.cs index be07f1fc1b..06bf090567 100644 --- a/src/Microsoft.ML.Core/Data/MetadataBuilder.cs +++ b/src/Microsoft.ML.Core/Data/MetadataBuilder.cs @@ -89,7 +89,7 @@ public void Add(string name, ColumnType type, Delegate getter, Schema.Metadata m /// The value of the metadata. /// Metadata of the input column. Note that metadata on a metadata column is somewhat rare /// except for certain types (for example, slot names for a vector, key values for something of key type). - public void AddRawValue(string name, PrimitiveType type, TValue value, Schema.Metadata metadata = null) + public void AddPrimitiveValue(string name, PrimitiveType type, TValue value, Schema.Metadata metadata = null) { Contracts.CheckNonEmpty(name, nameof(name)); Contracts.CheckValue(type, nameof(type)); diff --git a/src/Microsoft.ML.Core/Data/Schema.cs b/src/Microsoft.ML.Core/Data/Schema.cs index e21bd41b86..ec45753331 100644 --- a/src/Microsoft.ML.Core/Data/Schema.cs +++ b/src/Microsoft.ML.Core/Data/Schema.cs @@ -194,7 +194,7 @@ public sealed class Metadata /// public Schema Schema { get; } - public static Metadata Empty { get; } = new Metadata(new Schema(Enumerable.Empty()), new Delegate[0]); + public static Metadata Empty { get; } = new Metadata(new Schema(new Column[0]), new Delegate[0]); /// /// Create a metadata row by supplying the schema columns and the getter delegates for all the values. @@ -256,11 +256,12 @@ public void GetValue(string kind, ref TValue value) /// /// This constructor should only be called by . /// - internal Schema(IEnumerable columns) + /// The input columns. The constructed instance takes ownership of the array. + internal Schema(Column[] columns) { Contracts.CheckValue(columns, nameof(columns)); - _columns = columns.ToArray(); + _columns = columns; _nameMap = new Dictionary(); for (int i = 0; i < _columns.Length; i++) { diff --git a/src/Microsoft.ML.Core/Data/SchemaBuilder.cs b/src/Microsoft.ML.Core/Data/SchemaBuilder.cs index 9fb22a7026..28017f1e7d 100644 --- a/src/Microsoft.ML.Core/Data/SchemaBuilder.cs +++ b/src/Microsoft.ML.Core/Data/SchemaBuilder.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using System; using System.Collections.Generic; @@ -32,6 +33,9 @@ public SchemaBuilder() /// The column metadata. public void AddColumn(string name, ColumnType type, Schema.Metadata metadata) { + Contracts.CheckNonEmpty(name, nameof(name)); + Contracts.CheckValue(type, nameof(type)); + Contracts.CheckValueOrNull(metadata); _items.Add((name, type, metadata)); } diff --git a/src/Microsoft.ML.Data/Data/IColumn.cs b/src/Microsoft.ML.Data/Data/IColumn.cs index 9d2146ee37..c204d5157a 100644 --- a/src/Microsoft.ML.Data/Data/IColumn.cs +++ b/src/Microsoft.ML.Data/Data/IColumn.cs @@ -203,35 +203,6 @@ public static IRow GetRow(ICounted counted, params IColumn[] columns) return new RowColumnRow(counted, columns); } - /// - /// Given a column, returns a deep-copied memory-materialized version of it. Note that - /// it is acceptable for the column to be inactive: the returned column will likewise - /// be inactive. - /// - /// - /// A memory materialized version of which may be, - /// under appropriate circumstances, the input object itself - public static IColumn CloneColumn(IColumn column) - { - Contracts.CheckValue(column, nameof(column)); - return Utils.MarshalInvoke(CloneColumnCore, column.Type.RawType, column); - } - - private static IColumn CloneColumnCore(IColumn column) - { - Contracts.Assert(column is IValueColumn); - IRow meta = column.Metadata; - if (meta != null) - meta = RowCursorUtils.CloneRow(meta); - - var tcolumn = (IValueColumn)column; - if (!tcolumn.IsActive) - return new InactiveImpl(tcolumn.Name, meta, tcolumn.Type); - T val = default(T); - tcolumn.GetGetter()(ref val); - return GetColumn(tcolumn.Name, tcolumn.Type, ref val, meta); - } - /// /// The implementation for a simple wrapping of an . /// diff --git a/src/Microsoft.ML.Data/Data/RowCursorUtils.cs b/src/Microsoft.ML.Data/Data/RowCursorUtils.cs index 869f263f3a..57098d2895 100644 --- a/src/Microsoft.ML.Data/Data/RowCursorUtils.cs +++ b/src/Microsoft.ML.Data/Data/RowCursorUtils.cs @@ -326,7 +326,10 @@ private static Func GetIsNewGroupDelegateCore(IRow cursor, int col) }; } - public static Func GetIsNewBatchDelegate(IRow cursor, int batchSize) + [Obsolete("The usages of this appear to be based on a total misunderstanding of what Batch actually is. It is a mechanism " + + "to enable sharding and recovery of parallelized data, and has nothing to do with actual data.")] + [BestFriend] + internal static Func GetIsNewBatchDelegate(IRow cursor, int batchSize) { Contracts.CheckParam(batchSize > 0, nameof(batchSize), "Batch size must be > 0"); long lastNewBatchPosition = -1; @@ -454,20 +457,6 @@ public static ValueGetter> GetLabelGetter(ISlotCursor cursor) }; } - /// - /// Returns a row that is a deep in-memory copy of an input row. Note that inactive - /// columns are allowed in this row, and their activity or inactivity will be reflected - /// in the output row. Note that the deep copy includes a copy of the metadata as well. - /// - /// The input row - /// A deep in-memory copy of the input row - public static IRow CloneRow(IRow row) - { - Contracts.CheckValue(row, nameof(row)); - return RowColumnUtils.GetRow(null, - Utils.BuildArray(row.Schema.ColumnCount, c => RowColumnUtils.GetColumn(row, c))); - } - /// /// Fetches the value of the column by name, in the given row. /// Used by the evaluators to retrieve the metrics from the results IDataView. @@ -487,8 +476,7 @@ public static T Fetch(IExceptionContext ectx, IRow row, string name) /// but want to save it somewhere using a .) /// Note that it is not possible for this method to ensure that the input does not /// change, so users of this convenience must take care of what they do with the input row or the data - /// source it came from, while the returned dataview is potentially being used; if this is somehow - /// difficult it may be wise to use to first have a deep copy of the resulting row. + /// source it came from, while the returned dataview is potentially being used. /// /// An environment used to create the host for the resulting data view /// A row, whose columns must all be active @@ -507,7 +495,7 @@ private sealed class OneRowDataView : IDataView private readonly IHost _host; // A channel provider is required for creating the cursor. public Schema Schema => _row.Schema; - public bool CanShuffle { get { return true; } } // The shuffling is even uniformly IID!! :) + public bool CanShuffle => true; // The shuffling is even uniformly IID!! :) public OneRowDataView(IHostEnvironment env, IRow row) { diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 1238b8392c..73d032dce5 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -3305,13 +3306,15 @@ public IRow GetSummaryIRowOrNull(RoleMappedSchema schema) { var names = default(VBuffer>); MetadataUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, NumFeatures, ref names); - var slotNamesCol = RowColumnUtils.GetColumn(MetadataUtils.Kinds.SlotNames, - new VectorType(TextType.Instance, NumFeatures), ref names); - var slotNamesRow = RowColumnUtils.GetRow(null, slotNamesCol); + var metaBuilder = new MetadataBuilder(); + metaBuilder.AddSlotNames(NumFeatures, names.CopyTo); var weights = default(VBuffer); GetFeatureWeights(ref weights); - return RowColumnUtils.GetRow(null, RowColumnUtils.GetColumn("Gains", new VectorType(NumberType.R4, NumFeatures), ref weights, slotNamesRow)); + var builder = new MetadataBuilder(); + builder.Add>("Gains", new VectorType(NumberType.R4, NumFeatures), weights.CopyTo, metaBuilder.GetMetadata()); + + return MetadataUtils.MetadataAsRow(builder.GetMetadata()); } public IRow GetStatsIRowOrNull(RoleMappedSchema schema) diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs index fd21cc5377..72692c4016 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs @@ -354,7 +354,7 @@ public virtual IRow GetSummaryIRowOrNull(RoleMappedSchema schema) subBuilder.AddSlotNames(Weight.Length, (ref VBuffer> dst) => names.CopyTo(ref dst)); var colType = new VectorType(NumberType.R4, Weight.Length); var builder = new MetadataBuilder(); - builder.AddRawValue("Bias", NumberType.R4, Bias); + builder.AddPrimitiveValue("Bias", NumberType.R4, Bias); builder.Add("Weights", colType, (ref VBuffer dst) => Weight.CopyTo(ref dst), subBuilder.GetMetadata()); return MetadataUtils.MetadataAsRow(builder.GetMetadata()); } diff --git a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs index 67db04c177..4626ee5b31 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs @@ -416,10 +416,10 @@ internal Schema.Metadata MakeStatisticsMetadata(LinearBinaryPredictor parent, Ro var builder = new MetadataBuilder(); - builder.AddRawValue("Count of training examples", NumberType.I8, _trainingExampleCount); - builder.AddRawValue("Residual Deviance", NumberType.R4, _deviance); - builder.AddRawValue("Null Deviance", NumberType.R4, _nullDeviance); - builder.AddRawValue("AIC", NumberType.R4, 2 * _paramCount + _deviance); + builder.AddPrimitiveValue("Count of training examples", NumberType.I8, _trainingExampleCount); + builder.AddPrimitiveValue("Residual Deviance", NumberType.R4, _deviance); + builder.AddPrimitiveValue("Null Deviance", NumberType.R4, _nullDeviance); + builder.AddPrimitiveValue("AIC", NumberType.R4, 2 * _paramCount + _deviance); if (parent == null) return builder.GetMetadata(); @@ -428,10 +428,10 @@ internal Schema.Metadata MakeStatisticsMetadata(LinearBinaryPredictor parent, Ro return builder.GetMetadata(); var biasEstimate = parent.Bias; - builder.AddRawValue("BiasEstimate", NumberType.R4, biasEstimate); - builder.AddRawValue("BiasStandardError", NumberType.R4, biasStdErr); - builder.AddRawValue("BiasZScore", NumberType.R4, biasZScore); - builder.AddRawValue("BiasPValue", NumberType.R4, biasPValue); + builder.AddPrimitiveValue("BiasEstimate", NumberType.R4, biasEstimate); + builder.AddPrimitiveValue("BiasStandardError", NumberType.R4, biasStdErr); + builder.AddPrimitiveValue("BiasZScore", NumberType.R4, biasZScore); + builder.AddPrimitiveValue("BiasPValue", NumberType.R4, biasPValue); var weights = default(VBuffer); parent.GetFeatureWeights(ref weights); diff --git a/test/Microsoft.ML.Benchmarks/HashBench.cs b/test/Microsoft.ML.Benchmarks/HashBench.cs index 027411f016..6db0768287 100644 --- a/test/Microsoft.ML.Benchmarks/HashBench.cs +++ b/test/Microsoft.ML.Benchmarks/HashBench.cs @@ -17,35 +17,69 @@ namespace Microsoft.ML.Benchmarks { public class HashBench { - private sealed class Counted : ICounted + private sealed class Row : IRow { + public Schema Schema { get; } + public long Position { get; set; } public long Batch => 0; - public ValueGetter GetIdGetter() => (ref UInt128 val) => val = new UInt128((ulong)Position, 0); + + private readonly Delegate _getter; + + public bool IsColumnActive(int col) + { + if (col != 0) + throw new Exception(); + return true; + } + + public ValueGetter GetGetter(int col) + { + if (col != 0) + throw new Exception(); + if (_getter is ValueGetter typedGetter) + return typedGetter; + throw new Exception(); + } + + public static Row Create(ColumnType type, ValueGetter getter) + { + if (type.RawType != typeof(T)) + throw new Exception(); + return new Row(type, getter); + } + + private Row(ColumnType type, Delegate getter) + { + var builder = new SchemaBuilder(); + builder.AddColumn("Foo", type, null); + Schema = builder.GetSchema(); + _getter = getter; + } } private const int Count = 100_000; private readonly IHostEnvironment _env = new MLContext(); - private Counted _counted; + private Row _inRow; private ValueGetter _getter; private ValueGetter> _vecGetter; - private void InitMap(T val, ColumnType type, int hashBits = 20) + private void InitMap(T val, ColumnType type, int hashBits = 20, ValueGetter getter = null) { - var col = RowColumnUtils.GetColumn("Foo", type, ref val); - _counted = new Counted(); - var inRow = RowColumnUtils.GetRow(_counted, col); + if (getter == null) + getter = (ref T dst) => dst = val; + _inRow = Row.Create(type, getter); // One million features is a nice, typical number. var info = new HashingTransformer.ColumnInfo("Foo", "Bar", hashBits: hashBits); var xf = new HashingTransformer(_env, new[] { info }); - var mapper = xf.GetRowToRowMapper(inRow.Schema); + var mapper = xf.GetRowToRowMapper(_inRow.Schema); mapper.OutputSchema.TryGetColumnIndex("Bar", out int outCol); - var outRow = mapper.GetRow(inRow, c => c == outCol, out var _); + var outRow = mapper.GetRow(_inRow, c => c == outCol, out var _); if (type is VectorType) _vecGetter = outRow.GetGetter>(outCol); else @@ -61,14 +95,14 @@ private void RunScalar() for (int i = 0; i < Count; ++i) { _getter(ref val); - ++_counted.Position; + ++_inRow.Position; } } private void InitDenseVecMap(T[] vals, PrimitiveType itemType, int hashBits = 20) { var vbuf = new VBuffer(vals.Length, vals); - InitMap(vbuf, new VectorType(itemType, vals.Length), hashBits); + InitMap(vbuf, new VectorType(itemType, vals.Length), hashBits, vbuf.CopyTo); } /// @@ -80,7 +114,7 @@ private void RunVector() for (int i = 0; i < Count; ++i) { _vecGetter(ref val); - ++_counted.Position; + ++_inRow.Position; } } diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index 4b69791e8c..35da6afbc0 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -254,44 +254,37 @@ public void AssertStaticSimpleFailure() hello: c.Text.Scalar))); } - private sealed class MetaCounted : ICounted - { - public long Position => 0; - public long Batch => 0; - public ValueGetter GetIdGetter() => (ref UInt128 v) => v = default; - } - [Fact] public void AssertStaticKeys() { var env = new MLContext(0); - var counted = new MetaCounted(); // We'll test a few things here. First, the case where the key-value metadata is text. var metaValues1 = new VBuffer>(3, new[] { "a".AsMemory(), "b".AsMemory(), "c".AsMemory() }); - var meta1 = RowColumnUtils.GetColumn(MetadataUtils.Kinds.KeyValues, new VectorType(TextType.Instance, 3), ref metaValues1); - uint value1 = 2; - var col1 = RowColumnUtils.GetColumn("stay", new KeyType(typeof(uint), 0, 3), ref value1, RowColumnUtils.GetRow(counted, meta1)); + var metaBuilder = new MetadataBuilder(); + metaBuilder.AddKeyValues>(3, TextType.Instance, metaValues1.CopyTo); + + var builder = new MetadataBuilder(); + builder.AddPrimitiveValue("stay", new KeyType(typeof(uint), 0, 3), 2u, metaBuilder.GetMetadata()); // Next the case where those values are ints. var metaValues2 = new VBuffer(3, new int[] { 1, 2, 3, 4 }); - var meta2 = RowColumnUtils.GetColumn(MetadataUtils.Kinds.KeyValues, new VectorType(NumberType.I4, 4), ref metaValues2); + metaBuilder = new MetadataBuilder(); + metaBuilder.AddKeyValues(3, NumberType.I4, metaValues2.CopyTo); var value2 = new VBuffer(2, 0, null, null); - var col2 = RowColumnUtils.GetColumn("awhile", new VectorType(new KeyType(typeof(byte), 2, 4), 2), ref value2, RowColumnUtils.GetRow(counted, meta2)); + builder.Add>("awhile", new VectorType(new KeyType(typeof(byte), 2, 3), 2), value2.CopyTo, metaBuilder.GetMetadata()); // Then the case where a value of that kind exists, but is of not of the right kind, in which case it should not be identified as containing that metadata. - var metaValues3 = (float)2; - var meta3 = RowColumnUtils.GetColumn(MetadataUtils.Kinds.KeyValues, NumberType.R4, ref metaValues3); - var value3 = (ushort)1; - var col3 = RowColumnUtils.GetColumn("and", new KeyType(typeof(ushort), 0, 2), ref value3, RowColumnUtils.GetRow(counted, meta3)); + metaBuilder = new MetadataBuilder(); + metaBuilder.AddPrimitiveValue(MetadataUtils.Kinds.KeyValues, NumberType.R4, 2f); + builder.AddPrimitiveValue("and", new KeyType(typeof(ushort), 0, 2), (ushort)1, metaBuilder.GetMetadata()); // Then a final case where metadata of that kind is actaully simply altogether absent. var value4 = new VBuffer(5, 0, null, null); - var col4 = RowColumnUtils.GetColumn("listen", new VectorType(new KeyType(typeof(uint), 0, 2)), ref value4); + builder.Add>("listen", new VectorType(new KeyType(typeof(uint), 0, 2)), value4.CopyTo); // Finally compose a trivial data view out of all this. - var row = RowColumnUtils.GetRow(counted, col1, col2, col3, col4); - var view = RowCursorUtils.RowAsDataView(env, row); + var view = RowCursorUtils.RowAsDataView(env, MetadataUtils.MetadataAsRow(builder.GetMetadata())); // Whew! I'm glad that's over with. Let us start running the test in ernest. // First let's do a direct match of the types to ensure that works. diff --git a/test/Microsoft.ML.Tests/Transformers/HashTests.cs b/test/Microsoft.ML.Tests/Transformers/HashTests.cs index 417d9ed388..e5845db090 100644 --- a/test/Microsoft.ML.Tests/Transformers/HashTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/HashTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Data; using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Utilities; @@ -133,19 +134,13 @@ public void TestOldSavingAndLoading() } } - private sealed class Counted : ICounted - { - public long Position => 0; - public long Batch => 0; - public ValueGetter GetIdGetter() => (ref UInt128 val) => val = default; - } - private void HashTestCore(T val, PrimitiveType type, uint expected, uint expectedOrdered, uint expectedOrdered3) { const int bits = 10; - var col = RowColumnUtils.GetColumn("Foo", type, ref val); - var inRow = RowColumnUtils.GetRow(new Counted(), col); + var builder = new MetadataBuilder(); + builder.AddPrimitiveValue("Foo", type, val); + var inRow = MetadataUtils.MetadataAsRow(builder.GetMetadata()); // First do an unordered hash. var info = new HashingTransformer.ColumnInfo("Foo", "Bar", hashBits: bits); @@ -174,8 +169,9 @@ private void HashTestCore(T val, PrimitiveType type, uint expected, uint expe // at least in the first position, and in the unordered case, the last position. const int vecLen = 5; var denseVec = new VBuffer(vecLen, Utils.CreateArray(vecLen, val)); - col = RowColumnUtils.GetColumn("Foo", new VectorType(type, vecLen), ref denseVec); - inRow = RowColumnUtils.GetRow(new Counted(), col); + builder = new MetadataBuilder(); + builder.Add("Foo", new VectorType(type, vecLen), (ref VBuffer dst) => denseVec.CopyTo(ref dst)); + inRow = MetadataUtils.MetadataAsRow(builder.GetMetadata()); info = new HashingTransformer.ColumnInfo("Foo", "Bar", hashBits: bits, ordered: false); xf = new HashingTransformer(Env, new[] { info }); @@ -207,8 +203,9 @@ private void HashTestCore(T val, PrimitiveType type, uint expected, uint expe // Let's now do a sparse vector. var sparseVec = new VBuffer(10, 3, Utils.CreateArray(3, val), new[] { 0, 3, 7 }); - col = RowColumnUtils.GetColumn("Foo", new VectorType(type, vecLen), ref sparseVec); - inRow = RowColumnUtils.GetRow(new Counted(), col); + builder = new MetadataBuilder(); + builder.Add("Foo", new VectorType(type, vecLen), (ref VBuffer dst) => sparseVec.CopyTo(ref dst)); + inRow = MetadataUtils.MetadataAsRow(builder.GetMetadata()); info = new HashingTransformer.ColumnInfo("Foo", "Bar", hashBits: bits, ordered: false); xf = new HashingTransformer(Env, new[] { info }); From f60a727873ca0c6ffc1a4151d95e9bc218e39fb3 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Sat, 1 Dec 2018 09:30:08 -0800 Subject: [PATCH 3/3] Remove IColumn. --- src/Microsoft.ML.Data/Data/IColumn.cs | 638 -------------------------- 1 file changed, 638 deletions(-) delete mode 100644 src/Microsoft.ML.Data/Data/IColumn.cs diff --git a/src/Microsoft.ML.Data/Data/IColumn.cs b/src/Microsoft.ML.Data/Data/IColumn.cs deleted file mode 100644 index c204d5157a..0000000000 --- a/src/Microsoft.ML.Data/Data/IColumn.cs +++ /dev/null @@ -1,638 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading; -using Microsoft.ML.Data; -using Microsoft.ML.Runtime.Internal.Utilities; - -namespace Microsoft.ML.Runtime.Data -{ - /// - /// This interface is an analogy to that encapsulates the contents of a single - /// column. - /// - /// Note that in the same sense that is not thread safe, implementors of this interface - /// by similar token must not be considered thread safe by users of the interface, and by the same token - /// implementors should feel free to write their implementations with the expectation that only one thread - /// will be calling it at a time. - /// - /// Similarly, in the same sense that an can have its values "change under it" by having - /// the underlying cursor move, so too might this item have its values change under it, and they will if - /// they were directly instantiated from a row. - /// - /// Generally actual implementors of this interface should not implement this directly, but instead implement - /// . - /// - // REVIEW: It is possible we may want to make this ICounted, but let's not start with - // that assumption. The use cases I have in mind are that we'll still, on the side, have an - // IRow lying around. - public interface IColumn - { - /// - /// The name of a column. This string should always be non-empty. - /// - string Name { get; } - - /// - /// The type of the column. - /// - ColumnType Type { get; } - - // REVIEW: This property anticipates a time when we get away with metadata accessors - // altogether, and just have the metadata for a column be represented as a row. - /// - /// The metadata for a column, or null if this column has no metadata. - /// - IRow Metadata { get; } - - /// - /// Whether the column should be considered active or not. - /// - bool IsActive { get; } - - /// - /// The value getter, as a . Implementators should just pass through - /// . - /// - /// The generic getter delegate - Delegate GetGetter(); - } - - /// - /// The type specific interface for a . - /// - /// The type of values in this column. This should agree with the - /// field of . - public interface IValueColumn : IColumn - { - new ValueGetter GetGetter(); - } - - public static class RowColumnUtils - { - /// - /// Exposes a single column in a row. - /// - /// The row to wrap - /// The column to expose - /// A row column instance - public static IColumn GetColumn(IRow row, int col) - { - Contracts.CheckValue(row, nameof(row)); - Contracts.CheckParam(0 <= col && col < row.Schema.ColumnCount, nameof(col)); - - Func func = GetColumnCore; - return Utils.MarshalInvoke(func, row.Schema.GetColumnType(col).RawType, row, col); - } - - private static IColumn GetColumnCore(IRow row, int col) - { - Contracts.AssertValue(row); - Contracts.Assert(0 <= col && col < row.Schema.ColumnCount); - Contracts.Assert(row.Schema.GetColumnType(col).RawType == typeof(T)); - - return new RowWrap(row, col); - } - - /// - /// Exposes a single column in a schema. The column is considered inactive. - /// - /// The schema to get the data for - /// The column to get - /// A column with false - public static IColumn GetColumn(ISchema schema, int col) - { - Contracts.CheckValue(schema, nameof(schema)); - Contracts.CheckParam(0 <= col && col < schema.ColumnCount, nameof(col)); - - Func func = GetColumnCore; - return Utils.MarshalInvoke(func, schema.GetColumnType(col).RawType, schema, col); - } - - private static IColumn GetColumnCore(ISchema schema, int col) - { - Contracts.AssertValue(schema); - Contracts.Assert(0 <= col && col < schema.ColumnCount); - Contracts.Assert(schema.GetColumnType(col).RawType == typeof(T)); - - return new SchemaWrap(schema, col); - } - - /// - /// Constructs a column out of a value. This will store the input value, not make a copy. - /// - /// The type of the value - /// The column name, which must be non-empty - /// The type of the column, whose raw type must be - /// The value to store in the column - /// Optionally, metadata for the column - /// A column with this value - public static IColumn GetColumn(string name, ColumnType type, ref T value, IRow meta = null) - { - Contracts.CheckNonEmpty(name, nameof(name)); - Contracts.CheckValue(type, nameof(type)); - Contracts.CheckParam(type.RawType == typeof(T), nameof(type), "Mismatch on object type and column type"); - if (type.IsVector) - return Utils.MarshalInvoke(GetColumnVecCore, type.ItemType.RawType, name, type.AsVector, (object)value, meta); - Contracts.CheckParam(type.IsPrimitive, nameof(type), "Type must be either vector or primitive"); - Contracts.CheckValueOrNull(meta); - return Utils.MarshalInvoke(GetColumnOneCore, type.RawType, name, type, (object)value, meta); - } - - private static IColumn GetColumnVecCore(string name, VectorType type, object value, IRow meta) - { - // REVIEW: Ugh. Nasty. Any alternative to boxing? - Contracts.AssertNonEmpty(name); - Contracts.AssertValue(type); - Contracts.Assert(type.IsVector); - Contracts.Assert(type.ItemType.RawType == typeof(T)); - Contracts.Assert(value is VBuffer); - Contracts.AssertValueOrNull(meta); - VBuffer typedVal = (VBuffer)value; - return new ConstVecImpl(name, meta, type, typedVal); - } - - private static IColumn GetColumnOneCore(string name, ColumnType type, object value, IRow meta) - { - Contracts.AssertNonEmpty(name); - Contracts.AssertValue(type); - Contracts.Assert(type.IsPrimitive); - Contracts.Assert(type.RawType == typeof(T)); - Contracts.Assert(value is T); - Contracts.AssertValueOrNull(meta); - T typedVal = (T)value; - return new ConstOneImpl(name, meta, type, typedVal); - } - - /// - /// Constructs a column out of a getter. - /// - /// The type of the value - /// The column name, which must be non-empty - /// The type of the column, whose raw type must be - /// The getter for the column - /// Optionally, metadata for the column - /// A column with this getter - public static IColumn GetColumn(string name, ColumnType type, ValueGetter getter, IRow meta = null) - { - Contracts.CheckNonEmpty(name, nameof(name)); - Contracts.CheckValue(type, nameof(type)); - Contracts.CheckParam(type.RawType == typeof(T), nameof(type), "Mismatch on object type and column type"); - Contracts.CheckValue(getter, nameof(getter)); - Contracts.CheckValueOrNull(meta); - - return new GetterImpl(name, meta, type, getter); - } - - /// - /// Wraps a set of row columns as a row. - /// - /// The counted object that the output row will wrap for its own implementation of - /// , or if null, the output row will yield default values for those implementations, - /// that is, a totally static row - /// A set of row columns - /// A row with items derived from - public static IRow GetRow(ICounted counted, params IColumn[] columns) - { - Contracts.CheckValueOrNull(counted); - Contracts.CheckValue(columns, nameof(columns)); - return new RowColumnRow(counted, columns); - } - - /// - /// The implementation for a simple wrapping of an . - /// - private sealed class RowWrap : IValueColumn - { - private readonly IRow _row; - private readonly int _col; - private MetadataRow _meta; - - public string Name => _row.Schema.GetColumnName(_col); - public ColumnType Type => _row.Schema.GetColumnType(_col); - public bool IsActive => _row.IsColumnActive(_col); - - public IRow Metadata - { - get - { - if (_meta == null) - Interlocked.CompareExchange(ref _meta, new MetadataRow(_row.Schema, _col, x => true), null); - return _meta; - } - } - - public RowWrap(IRow row, int col) - { - Contracts.AssertValue(row); - Contracts.Assert(0 <= col && col < row.Schema.ColumnCount); - Contracts.Assert(row.Schema.GetColumnType(col).RawType == typeof(T)); - - _row = row; - _col = col; - } - - Delegate IColumn.GetGetter() - => GetGetter(); - - public ValueGetter GetGetter() - => _row.GetGetter(_col); - } - - /// - /// The base class for a few implementations that do not "go" anywhere. - /// - private abstract class DefaultCounted : ICounted - { - public long Position => 0; - public long Batch => 0; - public ValueGetter GetIdGetter() - => IdGetter; - - private static void IdGetter(ref UInt128 id) - => id = default; - } - - /// - /// Simple wrapper for a schema column, considered inctive with no getter. - /// - /// The type of the getter - private sealed class SchemaWrap : IValueColumn - { - private readonly ISchema _schema; - private readonly int _col; - private MetadataRow _meta; - - public string Name => _schema.GetColumnName(_col); - public ColumnType Type => _schema.GetColumnType(_col); - public bool IsActive => false; - - public IRow Metadata - { - get - { - if (_meta == null) - Interlocked.CompareExchange(ref _meta, new MetadataRow(_schema, _col, x => true), null); - return _meta; - } - } - - public SchemaWrap(ISchema schema, int col) - { - Contracts.AssertValue(schema); - Contracts.Assert(0 <= col && col < schema.ColumnCount); - Contracts.Assert(schema.GetColumnType(col).RawType == typeof(T)); - - _schema = schema; - _col = col; - } - - Delegate IColumn.GetGetter() - => GetGetter(); - - public ValueGetter GetGetter() - => throw Contracts.Except("Column not active"); - } - - /// - /// This class exists to present metadata as stored in an for one particular - /// column as an . This class will cease to be necessary at the point when all - /// metadata implementations are just simple s. - /// - public sealed class MetadataRow : IRow - { - public Schema Schema => _schemaImpl.AsSchema; - - private readonly ISchema _metaSchema; - private readonly int _col; - private readonly SchemaImpl _schemaImpl; - - private readonly KeyValuePair[] _map; - - long ICounted.Position => 0; - long ICounted.Batch => 0; - ValueGetter ICounted.GetIdGetter() - => IdGetter; - - private static void IdGetter(ref UInt128 id) - => id = default; - - private sealed class SchemaImpl : ISchema - { - private readonly MetadataRow _parent; - private readonly Dictionary _nameToCol; - public Schema AsSchema { get; } - - public int ColumnCount { get { return _parent._map.Length; } } - - public SchemaImpl(MetadataRow parent) - { - Contracts.AssertValue(parent); - _parent = parent; - _nameToCol = new Dictionary(ColumnCount); - for (int i = 0; i < _parent._map.Length; ++i) - _nameToCol[_parent._map[i].Key] = i; - - AsSchema = Schema.Create(this); - } - - public string GetColumnName(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return _parent._map[col].Key; - } - - public ColumnType GetColumnType(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return _parent._map[col].Value; - } - - public bool TryGetColumnIndex(string name, out int col) - { - return _nameToCol.TryGetValue(name, out col); - } - - public IEnumerable> GetMetadataTypes(int col) - { - return Enumerable.Empty>(); - } - - public void GetMetadata(string kind, int col, ref TValue value) - { - throw MetadataUtils.ExceptGetMetadata(); - } - - public ColumnType GetMetadataTypeOrNull(string kind, int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return null; - } - } - - public MetadataRow(ISchema schema, int col, Func takeMetadata) - { - Contracts.CheckValue(schema, nameof(schema)); - Contracts.CheckParam(0 <= col && col < schema.ColumnCount, nameof(col)); - Contracts.CheckValue(takeMetadata, nameof(takeMetadata)); - - _metaSchema = schema; - _col = col; - _map = _metaSchema.GetMetadataTypes(_col).Where(x => takeMetadata(x.Key)).ToArray(); - _schemaImpl = new SchemaImpl(this); - } - - public bool IsColumnActive(int col) - { - Contracts.CheckParam(0 <= col && col < _map.Length, nameof(col)); - return true; - } - - public ValueGetter GetGetter(int col) - { - Contracts.CheckParam(0 <= col && col < _map.Length, nameof(col)); - // REVIEW: On type mismatch, this will throw a metadata exception, which is not really - // appropriate. However, since this meant to be a shim anyway, we will tolerate imperfection. - return (ref TValue dst) => _metaSchema.GetMetadata(_map[col].Key, _col, ref dst); - } - } - - /// - /// This is used for a few implementations that need to store their own name, - /// metadata, and type themselves. - /// - private abstract class SimpleColumnBase : IValueColumn - { - public string Name { get; } - public IRow Metadata { get; } - public ColumnType Type { get; } - public abstract bool IsActive { get; } - - public SimpleColumnBase(string name, IRow meta, ColumnType type) - { - Contracts.CheckNonEmpty(name, nameof(name)); - Contracts.CheckValueOrNull(meta); - Contracts.CheckValue(type, nameof(type)); - Contracts.CheckParam(type.RawType == typeof(T), nameof(type), "Mismatch between CLR type and column type"); - - Name = name; - Metadata = meta; - Type = type; - } - - Delegate IColumn.GetGetter() - { - return GetGetter(); - } - - public abstract ValueGetter GetGetter(); - } - - private sealed class InactiveImpl : SimpleColumnBase - { - public override bool IsActive { get { return false; } } - - public InactiveImpl(string name, IRow meta, ColumnType type) - : base(name, meta, type) - { - } - - public override ValueGetter GetGetter() - { - throw Contracts.Except("Can't get getter for inactive column"); - } - } - - private sealed class ConstOneImpl : SimpleColumnBase - { - private readonly T _value; - - public override bool IsActive => true; - - public ConstOneImpl(string name, IRow meta, ColumnType type, T value) - : base(name, meta, type) - { - Contracts.Assert(type.IsPrimitive); - _value = value; - } - - public override ValueGetter GetGetter() - { - return Getter; - } - - private void Getter(ref T val) - { - val = _value; - } - } - - private sealed class ConstVecImpl : SimpleColumnBase> - { - private readonly VBuffer _value; - - public override bool IsActive { get { return true; } } - - public ConstVecImpl(string name, IRow meta, ColumnType type, VBuffer value) - : base(name, meta, type) - { - _value = value; - } - - public override ValueGetter> GetGetter() - { - return Getter; - } - - private void Getter(ref VBuffer val) - { - _value.CopyTo(ref val); - } - } - - private sealed class GetterImpl : SimpleColumnBase - { - private readonly ValueGetter _getter; - - public override bool IsActive => _getter != null; - - public GetterImpl(string name, IRow meta, ColumnType type, ValueGetter getter) - : base(name, meta, type) - { - Contracts.CheckValueOrNull(getter); - _getter = getter; - } - - public override ValueGetter GetGetter() - { - Contracts.Check(IsActive, "column is not active"); - return _getter; - } - } - - /// - /// An that is an amalgation of multiple implementers. - /// - private sealed class RowColumnRow : IRow - { - private static readonly DefaultCountedImpl _defCount = new DefaultCountedImpl(); - private readonly ICounted _counted; - private readonly IColumn[] _columns; - private readonly SchemaImpl _schema; - - public Schema Schema => _schema.AsSchema; - public long Position => _counted.Position; - public long Batch => _counted.Batch; - - public RowColumnRow(ICounted counted, IColumn[] columns) - { - Contracts.AssertValueOrNull(counted); - Contracts.AssertValue(columns); - _counted = counted ?? _defCount; - _columns = columns; - _schema = new SchemaImpl(this); - } - - public ValueGetter GetGetter(int col) - { - Contracts.CheckParam(IsColumnActive(col), nameof(col), "requested column not active"); - var rowCol = _columns[col] as IValueColumn; - if (rowCol == null) - throw Contracts.Except("Invalid TValue: '{0}'", typeof(TValue)); - return rowCol.GetGetter(); - } - - public bool IsColumnActive(int col) - { - Contracts.CheckParam(0 <= col && col < _columns.Length, nameof(col)); - return _columns[col].IsActive; - } - - public ValueGetter GetIdGetter() - { - return _counted.GetIdGetter(); - } - - private sealed class SchemaImpl : ISchema - { - private readonly RowColumnRow _parent; - private readonly Dictionary _nameToIndex; - - public Schema AsSchema { get; } - - public int ColumnCount => _parent._columns.Length; - - public SchemaImpl(RowColumnRow parent) - { - Contracts.AssertValue(parent); - _parent = parent; - _nameToIndex = new Dictionary(); - for (int i = 0; i < _parent._columns.Length; ++i) - _nameToIndex[_parent._columns[i].Name] = i; - AsSchema = Schema.Create(this); - } - - public void GetMetadata(string kind, int col, ref TValue value) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - var meta = _parent._columns[col].Metadata; - int mcol; - if (meta == null || !meta.Schema.TryGetColumnIndex(kind, out mcol)) - throw MetadataUtils.ExceptGetMetadata(); - // REVIEW: Again, since this is a shim, not going to sweat the potential for inappropriate exception message. - meta.GetGetter(mcol)(ref value); - } - - public ColumnType GetMetadataTypeOrNull(string kind, int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - var meta = _parent._columns[col].Metadata; - int mcol; - if (meta == null || !meta.Schema.TryGetColumnIndex(kind, out mcol)) - return null; - return meta.Schema.GetColumnType(mcol); - } - - public IEnumerable> GetMetadataTypes(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - // REVIEW: An IRow can have collisions in names, whereas there is no notion of this in metadata types. - // Since I intend to remove this soon anyway and the number of usages of this will be very low, I am just going - // to tolerate the potential for strangeness here, since it will practically never arise until we reorganize - // the whole thing. - var meta = _parent._columns[col].Metadata; - if (meta == null) - yield break; - var schema = meta.Schema; - for (int i = 0; i < schema.ColumnCount; ++i) - yield return new KeyValuePair(schema.GetColumnName(i), schema.GetColumnType(i)); - } - - public ColumnType GetColumnType(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return _parent._columns[col].Type; - } - - public string GetColumnName(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return _parent._columns[col].Name; - } - - public bool TryGetColumnIndex(string name, out int col) - { - return _nameToIndex.TryGetValue(name, out col); - } - } - - private sealed class DefaultCountedImpl : DefaultCounted - { - } - } - } -}