diff --git a/src/Microsoft.ML.Api/CustomMappingTransformer.cs b/src/Microsoft.ML.Api/CustomMappingTransformer.cs index 3e4102aad3..963996d099 100644 --- a/src/Microsoft.ML.Api/CustomMappingTransformer.cs +++ b/src/Microsoft.ML.Api/CustomMappingTransformer.cs @@ -111,7 +111,7 @@ public Mapper(CustomMappingTransformer parent, Schema inputSchema) _typedSrc = TypedCursorable.Create(_host, emptyDataView, false, _parent.InputSchemaDefinition); } - public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) + public Delegate[] CreateGetters(Row input, Func activeOutput, out Action disposer) { disposer = null; // If no outputs are active, we short-circuit to empty array of getters. @@ -147,7 +147,7 @@ public Delegate[] CreateGetters(IRow input, Func activeOutput, out Ac return result; } - private Delegate GetDstGetter(IRow input, int colIndex, Action refreshAction) + private Delegate GetDstGetter(Row input, int colIndex, Action refreshAction) { var getter = input.GetGetter(colIndex); ValueGetter combinedGetter = (ref T dst) => diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index 1d05d67f3b..d68882f805 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -125,22 +125,20 @@ protected override TRow GetCurrentRowObject() } /// - /// A row that consumes items of type , and provides an . This + /// A row that consumes items of type , and provides an . This /// is in contrast to which consumes a data view row and publishes them as the output type. /// /// The input data type. - public abstract class InputRowBase : IRow + public abstract class InputRowBase : Row where TRow : class { private readonly int _colCount; private readonly Delegate[] _getters; protected readonly IHost Host; - public long Batch => 0; + public override long Batch => 0; - public Schema Schema { get; } - - public abstract long Position { get; } + public override Schema Schema { get; } public InputRowBase(IHostEnvironment env, Schema schema, InternalSchemaDefinition schemaDef, Delegate[] peeks, Func predicate) { @@ -332,7 +330,7 @@ private Delegate CreateKeyGetterDelegate(Delegate peekDel, ColumnType colT protected abstract TRow GetCurrentRowObject(); - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { CheckColumnInRange(col); return _getters[col] != null; @@ -344,7 +342,7 @@ private void CheckColumnInRange(int columnIndex) throw Host.Except("Column index must be between 0 and {0}", _colCount); } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { if (!IsColumnActive(col)) throw Host.Except("Column {0} is not active in the cursor", col); @@ -355,8 +353,6 @@ public ValueGetter GetGetter(int col) throw Host.Except("Invalid TValue in GetGetter for column #{0}: '{1}'", col, typeof(TValue)); return fn; } - - public abstract ValueGetter GetIdGetter(); } /// @@ -400,16 +396,37 @@ protected DataViewBase(IHostEnvironment env, string name, InternalSchemaDefiniti public abstract long? GetRowCount(); - public abstract IRowCursor GetRowCursor(Func predicate, Random rand = null); + public abstract RowCursor GetRowCursor(Func predicate, Random rand = null); - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { consolidator = null; return new[] { GetRowCursor(predicate, rand) }; } - public abstract class DataViewCursorBase : InputRowBase, IRowCursor + public sealed class WrappedCursor : RowCursor + { + private readonly DataViewCursorBase _toWrap; + + public WrappedCursor(DataViewCursorBase toWrap) => _toWrap = toWrap; + + public override CursorState State => _toWrap.State; + public override long Position => _toWrap.Position; + public override long Batch => _toWrap.Batch; + public override Schema Schema => _toWrap.Schema; + + public override void Dispose() => _toWrap.Dispose(); + public override ValueGetter GetGetter(int col) + => _toWrap.GetGetter(col); + public override ValueGetter GetIdGetter() => _toWrap.GetIdGetter(); + public override RowCursor GetRootCursor() => this; + public override bool IsColumnActive(int col) => _toWrap.IsColumnActive(col); + public override bool MoveMany(long count) => _toWrap.MoveMany(count); + public override bool MoveNext() => _toWrap.MoveNext(); + } + + public abstract class DataViewCursorBase : InputRowBase { // There is no real concept of multiple inheritance and for various reasons it was better to // descend from the row class as opposed to wrapping it, so much of this class is regrettably @@ -524,14 +541,6 @@ protected virtual bool MoveManyCore(long count) /// . /// protected abstract bool MoveNextCore(); - - /// - /// Returns a cursor that can be used for invoking , , - /// , and , with results identical to calling - /// those on this cursor. Generally, if the root cursor is not the same as this cursor, using - /// the root cursor will be faster. - /// - public ICursor GetRootCursor() => this; } } @@ -561,10 +570,10 @@ public override bool CanShuffle return _data.Count; } - public override IRowCursor GetRowCursor(Func predicate, Random rand = null) + public override RowCursor GetRowCursor(Func predicate, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); - return new Cursor(Host, "ListDataView", this, predicate, rand); + return new WrappedCursor(new Cursor(Host, "ListDataView", this, predicate, rand)); } private sealed class Cursor : DataViewCursorBase @@ -660,9 +669,9 @@ public override bool CanShuffle return (_data as ICollection)?.Count; } - public override IRowCursor GetRowCursor(Func predicate, Random rand = null) + public override RowCursor GetRowCursor(Func predicate, Random rand = null) { - return new Cursor(Host, this, predicate); + return new WrappedCursor (new Cursor(Host, this, predicate)); } /// @@ -677,7 +686,7 @@ public void SetData(IEnumerable data) _data = data; } - private class Cursor : DataViewCursorBase + private sealed class Cursor : DataViewCursorBase { private readonly IEnumerator _enumerator; private TRow _currentRow; @@ -731,15 +740,9 @@ public SingleRowLoopDataView(IHostEnvironment env, InternalSchemaDefinition sche { } - public override bool CanShuffle - { - get { return false; } - } + public override bool CanShuffle => false; - public override long? GetRowCount() - { - return null; - } + public override long? GetRowCount() => null; public void SetCurrentRowObject(TRow value) { @@ -747,10 +750,10 @@ public void SetCurrentRowObject(TRow value) _current = value; } - public override IRowCursor GetRowCursor(Func predicate, Random rand = null) + public override RowCursor GetRowCursor(Func predicate, Random rand = null) { Contracts.Assert(_current != null, "The current object must be set prior to cursoring"); - return new Cursor(Host, this, predicate); + return new WrappedCursor (new Cursor(Host, this, predicate)); } private sealed class Cursor : DataViewCursorBase @@ -773,10 +776,7 @@ public override ValueGetter GetIdGetter() }; } - protected override TRow GetCurrentRowObject() - { - return _currentRow; - } + protected override TRow GetCurrentRowObject() => _currentRow; protected override bool MoveNextCore() { diff --git a/src/Microsoft.ML.Api/StatefulFilterTransform.cs b/src/Microsoft.ML.Api/StatefulFilterTransform.cs index a8341639ff..f07ef7daad 100644 --- a/src/Microsoft.ML.Api/StatefulFilterTransform.cs +++ b/src/Microsoft.ML.Api/StatefulFilterTransform.cs @@ -108,7 +108,7 @@ private StatefulFilterTransform(IHostEnvironment env, StatefulFilterTransform predicate, Random rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); @@ -120,7 +120,7 @@ public IRowCursor GetRowCursor(Func predicate, Random rand = null) return new Cursor(this, input, predicate); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { Contracts.CheckValue(predicate, nameof(predicate)); Contracts.CheckParam(n >= 0, nameof(n)); @@ -144,13 +144,13 @@ public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource) return new StatefulFilterTransform(env, this, newSource); } - private sealed class Cursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { private readonly StatefulFilterTransform _parent; - private readonly IRowCursor _input; + private readonly RowCursor _input; // This is used to serve getters for the columns we produce. - private readonly IRow _appendedRow; + private readonly Row _appendedRow; private readonly TSrc _src; private readonly TDst _dst; @@ -163,7 +163,7 @@ public override long Batch get { return _input.Batch; } } - public Cursor(StatefulFilterTransform parent, IRowCursor input, Func predicate) + public Cursor(StatefulFilterTransform parent, RowCursor input, Func predicate) : base(parent.Host) { Ch.AssertValue(input); @@ -240,9 +240,9 @@ private void RunLambda(out bool isRowAccepted) isRowAccepted = _parent._filterFunc(_src, _dst, _state); } - public Schema Schema => _parent._bindings.Schema; + public override Schema Schema => _parent._bindings.Schema; - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Contracts.CheckParam(0 <= col && col < Schema.ColumnCount, nameof(col)); bool isSrc; @@ -252,7 +252,7 @@ public bool IsColumnActive(int col) return _appendedRow.IsColumnActive(iCol); } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Contracts.CheckParam(0 <= col && col < Schema.ColumnCount, nameof(col)); bool isSrc; diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs index c13d277305..4607697662 100644 --- a/src/Microsoft.ML.Api/TypedCursor.cs +++ b/src/Microsoft.ML.Api/TypedCursor.cs @@ -13,11 +13,12 @@ namespace Microsoft.ML.Runtime.Api { /// - /// This interface is an with 'strongly typed' binding. + /// This interface is an with 'strongly typed' binding. /// It can populate the user-supplied object's fields with the values of the current row. /// /// The user-defined type that is being populated while cursoring. - public interface IRowReadableAs : IRow + [BestFriend] + internal interface IRowReadableAs where TRow : class { /// @@ -28,11 +29,11 @@ public interface IRowReadableAs : IRow } /// - /// This interface is an with 'strongly typed' binding. + /// This interface is an with 'strongly typed' binding. /// It can accept values of type and present the value as a row. /// /// The user-defined type that provides the values while cursoring. - public interface IRowBackedBy : IRow + internal interface IRowBackedBy where TRow : class { /// @@ -48,9 +49,10 @@ public interface IRowBackedBy : IRow /// It can populate the user-supplied object's fields with the values of the current row. /// /// The user-defined type that is being populated while cursoring. - public interface IRowCursor : IRowReadableAs, ICursor + public abstract class RowCursor : RowCursor, IRowReadableAs where TRow : class { + public abstract void FillValues(TRow row); } /// @@ -63,13 +65,13 @@ public interface ICursorable /// /// Get a new cursor. /// - IRowCursor GetCursor(); + RowCursor GetCursor(); /// /// Get a new randomized cursor. /// /// The random seed to use. - IRowCursor GetRandomizedCursor(int randomSeed); + RowCursor GetRandomizedCursor(int randomSeed); } /// @@ -163,7 +165,7 @@ private static bool IsCompatibleType(ColumnType colType, MemberInfo memberInfo) /// /// Create and return a new cursor. /// - public IRowCursor GetCursor() + public RowCursor GetCursor() { return GetCursor(x => false); } @@ -172,14 +174,14 @@ public IRowCursor GetCursor() /// Create and return a new randomized cursor. /// /// The random seed to use. - public IRowCursor GetRandomizedCursor(int randomSeed) + public RowCursor GetRandomizedCursor(int randomSeed) { return GetCursor(x => false, randomSeed); } - public IRowReadableAs GetRow(IRow input) + public IRowReadableAs GetRow(Row input) { - return new TypedRow(this, input); + return new RowImplementation(new TypedRow(this, input)); } /// @@ -188,14 +190,14 @@ public IRowReadableAs GetRow(IRow input) /// Predicate that denotes which additional columns to include in the cursor, /// in addition to the columns that are needed for populating the object. /// The random seed to use. If null, the cursor will be non-randomized. - public IRowCursor GetCursor(Func additionalColumnsPredicate, int? randomSeed = null) + public RowCursor GetCursor(Func additionalColumnsPredicate, int? randomSeed = null) { _host.CheckValue(additionalColumnsPredicate, nameof(additionalColumnsPredicate)); Random rand = randomSeed.HasValue ? RandomUtils.Create(randomSeed.Value) : null; var cursor = _data.GetRowCursor(GetDependencies(additionalColumnsPredicate), rand); - return new TypedCursor(this, cursor); + return new RowCursorImplementation(new TypedCursor(this, cursor)); } public Func GetDependencies(Func additionalColumnsPredicate) @@ -211,7 +213,7 @@ public Func GetDependencies(Func additionalColumnsPredicat /// in addition to the columns that are needed for populating the object. /// Number of cursors to create /// Random generator to use - public IRowCursor[] GetCursorSet(out IRowCursorConsolidator consolidator, + public RowCursor[] GetCursorSet(out IRowCursorConsolidator consolidator, Func additionalColumnsPredicate, int n, Random rand) { _host.CheckValue(additionalColumnsPredicate, nameof(additionalColumnsPredicate)); @@ -226,8 +228,8 @@ public IRowCursor[] GetCursorSet(out IRowCursorConsolidator consolidator, _host.AssertNonEmpty(inputs); return inputs - .Select(rc => (IRowCursor)(new TypedCursor(this, rc))) - .ToArray(); + .Select(rc => (RowCursor)(new RowCursorImplementation(new TypedCursor(this, rc)))) + .ToArray(); } /// @@ -251,10 +253,10 @@ public static TypedCursorable Create(IHostEnvironment env, IDataView data, return new TypedCursorable(env, data, ignoreMissingColumns, outSchema); } - private abstract class TypedRowBase : IRowReadableAs + private abstract class TypedRowBase { protected readonly IChannel Ch; - private readonly IRow _input; + private readonly Row _input; private readonly Action[] _setters; public long Batch => _input.Batch; @@ -263,7 +265,7 @@ private abstract class TypedRowBase : IRowReadableAs public Schema Schema => _input.Schema; - public TypedRowBase(TypedCursorable parent, IRow input, string channelMessage) + public TypedRowBase(TypedCursorable parent, Row input, string channelMessage) { Contracts.AssertValue(parent); Contracts.AssertValue(parent._host); @@ -280,17 +282,14 @@ public TypedRowBase(TypedCursorable parent, IRow input, string channelMess _setters[i] = GenerateSetter(_input, parent._columnIndices[i], parent._columns[i], parent._pokes[i], parent._peeks[i]); } - public ValueGetter GetIdGetter() - { - return _input.GetIdGetter(); - } + public ValueGetter GetIdGetter() => _input.GetIdGetter(); - private Action GenerateSetter(IRow input, int index, InternalSchemaDefinition.Column column, Delegate poke, Delegate peek) + private Action GenerateSetter(Row input, int index, InternalSchemaDefinition.Column column, Delegate poke, Delegate peek) { var colType = input.Schema.GetColumnType(index); var fieldType = column.OutputType; var genericType = fieldType; - Func> del; + Func> del; if (fieldType.IsArray) { Ch.Assert(colType.IsVector); @@ -349,7 +348,7 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit // REVIEW: The converting getter invokes a type conversion delegate on every call, so it's inherently slower // than the 'direct' getter. We don't have good indication of this to the user, and the selection // of affected types is pretty arbitrary (signed integers and bools, but not uints and floats). - private Action CreateConvertingVBufferSetter(IRow input, int col, Delegate poke, Delegate peek, Func convert) + private Action CreateConvertingVBufferSetter(Row input, int col, Delegate poke, Delegate peek, Func convert) { var getter = input.GetGetter>(col); var typedPoke = poke as Poke; @@ -371,7 +370,7 @@ private Action CreateConvertingVBufferSetter(IRow input, int c }; } - private Action CreateDirectVBufferSetter(IRow input, int col, Delegate poke, Delegate peek) + private Action CreateDirectVBufferSetter(Row input, int col, Delegate poke, Delegate peek) { var getter = input.GetGetter>(col); var typedPoke = poke as Poke; @@ -409,7 +408,7 @@ private Action CreateDirectVBufferSetter(IRow input, int col, Delega }; } - private static Action CreateConvertingActionSetter(IRow input, int col, Delegate poke, Func convert) + private static Action CreateConvertingActionSetter(Row input, int col, Delegate poke, Func convert) { var getter = input.GetGetter(col); var typedPoke = poke as Poke; @@ -423,7 +422,7 @@ private static Action CreateConvertingActionSetter(IRow input, }; } - private static Action CreateDirectSetter(IRow input, int col, Delegate poke, Delegate peek) + private static Action CreateDirectSetter(Row input, int col, Delegate poke, Delegate peek) { // Awkward to have a parameter that's always null, but slightly more convenient for generalizing the setter. Contracts.Assert(peek == null); @@ -438,7 +437,7 @@ private static Action CreateDirectSetter(IRow input, int col, Delega }; } - private Action CreateVBufferToVBufferSetter(IRow input, int col, Delegate poke, Delegate peek) + private Action CreateVBufferToVBufferSetter(Row input, int col, Delegate poke, Delegate peek) { var getter = input.GetGetter>(col); var typedPoke = poke as Poke>; @@ -473,18 +472,55 @@ public ValueGetter GetGetter(int col) private sealed class TypedRow : TypedRowBase { - public TypedRow(TypedCursorable parent, IRow input) + public TypedRow(TypedCursorable parent, Row input) : base(parent, input, "Row") { } } - private sealed class TypedCursor : TypedRowBase, IRowCursor + private sealed class RowImplementation : IRowReadableAs + { + private readonly TypedRow _row; + + public RowImplementation(TypedRow row) => _row = row; + + public long Position => _row.Position; + public long Batch => _row.Batch; + public Schema Schema => _row.Schema; + public void FillValues(TRow row) => _row.FillValues(row); + public ValueGetter GetGetter(int col) => _row.GetGetter(col); + public ValueGetter GetIdGetter() => _row.GetIdGetter(); + public bool IsColumnActive(int col) => _row.IsColumnActive(col); + } + + private sealed class RowCursorImplementation : RowCursor { - private readonly IRowCursor _input; + private readonly TypedCursor _cursor; + + public RowCursorImplementation(TypedCursor cursor) => _cursor = cursor; + + public override CursorState State => _cursor.State; + public override long Position => _cursor.Position; + public override long Batch => _cursor.Batch; + public override Schema Schema => _cursor.Schema; + + public override void Dispose() { } + + public override void FillValues(TRow row) => _cursor.FillValues(row); + public override ValueGetter GetGetter(int col) => _cursor.GetGetter(col); + public override ValueGetter GetIdGetter() => _cursor.GetIdGetter(); + public override RowCursor GetRootCursor() => _cursor.GetRootCursor(); + public override bool IsColumnActive(int col) => _cursor.IsColumnActive(col); + public override bool MoveMany(long count) => _cursor.MoveMany(count); + public override bool MoveNext() => _cursor.MoveNext(); + } + + private sealed class TypedCursor : TypedRowBase + { + private readonly RowCursor _input; private bool _disposed; - public TypedCursor(TypedCursorable parent, IRowCursor input) + public TypedCursor(TypedCursorable parent, RowCursor input) : base(parent, input, "Cursor") { _input = input; @@ -496,7 +532,7 @@ public override void FillValues(TRow row) base.FillValues(row); } - public CursorState State { get { return _input.State; } } + public CursorState State => _input.State; public void Dispose() { @@ -508,20 +544,9 @@ public void Dispose() } } - public bool MoveNext() - { - return _input.MoveNext(); - } - - public bool MoveMany(long count) - { - return _input.MoveMany(count); - } - - public ICursor GetRootCursor() - { - return _input.GetRootCursor(); - } + public bool MoveNext() => _input.MoveNext(); + public bool MoveMany(long count) => _input.MoveMany(count); + public RowCursor GetRootCursor() => _input.GetRootCursor(); } } diff --git a/src/Microsoft.ML.Core/Data/ICursor.cs b/src/Microsoft.ML.Core/Data/ICursor.cs deleted file mode 100644 index 476f07245c..0000000000 --- a/src/Microsoft.ML.Core/Data/ICursor.cs +++ /dev/null @@ -1,106 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Float = System.Single; - -using System; - -namespace Microsoft.ML.Runtime.Data -{ - /// - /// This is a base interface for an and . It contains only the - /// positional properties, no behavioral methods, and no data. - /// - public interface ICounted - { - /// - /// This is incremented for ICursor when the underlying contents changes, giving clients a way to detect change. - /// Generally it's -1 when the object is in an invalid state. In particular, for an , this is -1 - /// when the is or . - /// - /// Note that this position is not position within the underlying data, but position of this cursor only. - /// If one, for example, opened a set of parallel streaming cursors, or a shuffled cursor, each such cursor's - /// first valid entry would always have position 0. - /// - long Position { get; } - - /// - /// This provides a means for reconciling multiple streams of counted things. Generally, in each stream, - /// batch numbers should be non-decreasing. Furthermore, any given batch number should only appear in one - /// of the streams. Order is determined by batch number. The reconciler ensures that each stream (that is - /// still active) has at least one item available, then takes the item with the smallest batch number. - /// - /// Note that there is no suggestion that the batches for a particular entry will be consistent from - /// cursoring to cursoring, except for the consistency in resulting in the same overall ordering. The same - /// entry could have different batch numbers from one cursoring to another. There is also no requirement - /// that any given batch number must appear, at all. - /// - long Batch { get; } - - /// - /// A getter for a 128-bit ID value. It is common for objects to serve multiple - /// instances to iterate over what is supposed to be the same data, for example, in a - /// a cursor set will produce the same data as a serial cursor, just partitioned, and a shuffled cursor - /// will produce the same data as a serial cursor or any other shuffled cursor, only shuffled. The ID - /// exists for applications that need to reconcile which entry is actually which. Ideally this ID should - /// be unique, but for practical reasons, it suffices if collisions are simply extremely improbable. - /// - /// Note that this ID, while it must be consistent for multiple streams according to the semantics - /// above, is not considered part of the data per se. So, to take the example of a data view specifically, - /// a single data view must render consistent IDs across all cursorings, but there is no suggestion at - /// all that if the "same" data were presented in a different data view (as by, say, being transformed, - /// cached, saved, or whatever), that the IDs between the two different data views would have any - /// discernable relationship. - ValueGetter GetIdGetter(); - } - - /// - /// Defines the possible states of a cursor. - /// - public enum CursorState - { - NotStarted, - Good, - Done - } - - /// - /// The basic cursor interface. is incremented by - /// and . When the cursor state is or - /// , is -1. Otherwise, - /// >= 0. - /// - public interface ICursor : ICounted, IDisposable - { - /// - /// Returns the state of the cursor. Before the first call to or - /// this should be . After - /// any call those move functions that returns true, this should return - /// , - /// - CursorState State { get; } - - /// - /// Advance to the next row. When the cursor is first created, this method should be called to - /// move to the first row. Returns false if there are no more rows. - /// - bool MoveNext(); - - /// - /// Logically equivalent to calling the given number of times. The - /// parameter must be positive. Note that cursor implementations may be - /// able to optimize this. - /// - bool MoveMany(long count); - - /// - /// Returns a cursor that can be used for invoking , , - /// , and , with results identical to calling those - /// on this cursor. Generally, if the root cursor is not the same as this cursor, using the - /// root cursor will be faster. As an aside, note that this is not necessarily the case of - /// values from . - /// - ICursor GetRootCursor(); - } -} \ No newline at end of file diff --git a/src/Microsoft.ML.Core/Data/IDataView.cs b/src/Microsoft.ML.Core/Data/IDataView.cs index 7d5e0c1790..c67bba60aa 100644 --- a/src/Microsoft.ML.Core/Data/IDataView.cs +++ b/src/Microsoft.ML.Core/Data/IDataView.cs @@ -88,7 +88,7 @@ public interface IDataView /// a getter for an inactive columns will throw. The predicate must be /// non-null. To activate all columns, pass "col => true". /// - IRowCursor GetRowCursor(Func needCol, Random rand = null); + RowCursor GetRowCursor(Func needCol, Random rand = null); /// /// This constructs a set of parallel batch cursors. The value n is a recommended limit @@ -102,7 +102,7 @@ public interface IDataView /// should return the "same" row as would have been returned through the regular serial cursor, /// but all rows should be returned by exactly one of the cursors returned from this cursor. /// The cursors can have their values reconciled downstream through the use of the - /// property. + /// property. /// /// This is an object that can be used to reconcile the /// returned array of cursors. When the array of cursors is of length 1, it is legal, @@ -111,7 +111,7 @@ public interface IDataView /// The suggested degree of parallelism. /// An instance /// - IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, + RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func needCol, int n, Random rand = null); /// @@ -129,7 +129,7 @@ public interface IRowCursorConsolidator /// /// Create a consolidated cursor from the given parallel cursor set. /// - IRowCursor CreateCursor(IChannelProvider provider, IRowCursor[] inputs); + RowCursor CreateCursor(IChannelProvider provider, RowCursor[] inputs); } /// @@ -139,36 +139,119 @@ public interface IRowCursorConsolidator public delegate void ValueGetter(ref TValue value); /// - /// A logical row. May be a row of an IDataView or a stand-alone row. If/when its contents - /// change, its ICounted.Counter value is incremented. + /// A logical row. May be a row of an or a stand-alone row. If/when its contents + /// change, its value is changed. /// - public interface IRow : ICounted + public abstract class Row { + /// + /// This is incremented when the underlying contents changes, giving clients a way to detect change. + /// Generally it's -1 when the object is in an invalid state. In particular, for an , this is -1 + /// when the is or . + /// + /// Note that this position is not position within the underlying data, but position of this cursor only. + /// If one, for example, opened a set of parallel streaming cursors, or a shuffled cursor, each such cursor's + /// first valid entry would always have position 0. + /// + public abstract long Position { get; } + + /// + /// This provides a means for reconciling multiple streams of counted things. Generally, in each stream, + /// batch numbers should be non-decreasing. Furthermore, any given batch number should only appear in one + /// of the streams. Order is determined by batch number. The reconciler ensures that each stream (that is + /// still active) has at least one item available, then takes the item with the smallest batch number. + /// + /// Note that there is no suggestion that the batches for a particular entry will be consistent from + /// cursoring to cursoring, except for the consistency in resulting in the same overall ordering. The same + /// entry could have different batch numbers from one cursoring to another. There is also no requirement + /// that any given batch number must appear, at all. + /// + public abstract long Batch { get; } + + /// + /// A getter for a 128-bit ID value. It is common for objects to serve multiple + /// instances to iterate over what is supposed to be the same data, for example, in a + /// a cursor set will produce the same data as a serial cursor, just partitioned, and a shuffled cursor + /// will produce the same data as a serial cursor or any other shuffled cursor, only shuffled. The ID + /// exists for applications that need to reconcile which entry is actually which. Ideally this ID should + /// be unique, but for practical reasons, it suffices if collisions are simply extremely improbable. + /// + /// Note that this ID, while it must be consistent for multiple streams according to the semantics + /// above, is not considered part of the data per se. So, to take the example of a data view specifically, + /// a single data view must render consistent IDs across all cursorings, but there is no suggestion at + /// all that if the "same" data were presented in a different data view (as by, say, being transformed, + /// cached, saved, or whatever), that the IDs between the two different data views would have any + /// discernable relationship. + public abstract ValueGetter GetIdGetter(); + /// /// Returns whether the given column is active in this row. /// - bool IsColumnActive(int col); + public abstract bool IsColumnActive(int col); /// /// Returns a value getter delegate to fetch the given column value from the row. /// This throws if the column is not active in this row, or if the type /// differs from this column's type. /// - ValueGetter GetGetter(int col); + public abstract ValueGetter GetGetter(int col); /// /// Gets a , which provides name and type information for variables /// (i.e., columns in ML.NET's type system) stored in this row. /// - Schema Schema { get; } + public abstract Schema Schema { get; } + + } + /// + /// Defines the possible states of a cursor. + /// + public enum CursorState + { + NotStarted, + Good, + Done } /// - /// A cursor through rows of an . Note that this includes/is an - /// , as well as an . + /// The basic cursor base class to cursor through rows of an . Note that + /// this is also an . The is incremented by + /// and . When the cursor state is or + /// , is -1. Otherwise, + /// >= 0. /// - public interface IRowCursor : ICursor, IRow + public abstract class RowCursor : Row, IDisposable { + /// + /// Returns the state of the cursor. Before the first call to or + /// this should be . After + /// any call those move functions that returns , this should return + /// , + /// + public abstract CursorState State { get; } + + /// + /// Advance to the next row. When the cursor is first created, this method should be called to + /// move to the first row. Returns false if there are no more rows. + /// + public abstract bool MoveNext(); + + /// + /// Logically equivalent to calling the given number of times. The + /// parameter must be positive. Note that cursor implementations may be + /// able to optimize this. + /// + public abstract bool MoveMany(long count); + + /// + /// Returns a cursor that can be used for invoking , , + /// , and , with results identical to calling those + /// on this cursor. Generally, if the root cursor is not the same as this cursor, using the + /// root cursor will be faster. As an aside, note that this is not necessarily the case of + /// values from . + /// + public abstract RowCursor GetRootCursor(); + public abstract void Dispose(); } } \ No newline at end of file diff --git a/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs b/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs index c4b3147d2c..f0e3f45adb 100644 --- a/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs +++ b/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs @@ -66,7 +66,7 @@ public interface ISchemaBoundRowMapper : ISchemaBoundMapper, IRowToRowMapper } /// - /// This interface maps an input to an output . Typically, the output contains + /// This interface maps an input to an output . Typically, the output contains /// both the input columns and new columns added by the implementing class, although some implementations may /// return a subset of the input columns. /// This interface is similar to , except it does not have any input role mappings, @@ -93,26 +93,26 @@ public interface IRowToRowMapper Func GetDependencies(Func predicate); /// - /// Get an with the indicated active columns, based on the input . + /// Get an with the indicated active columns, based on the input . /// The active columns are those for which returns true. Getting values on inactive /// columns of the returned row will throw. Null predicates are disallowed. /// - /// The of should be the same object as + /// The of should be the same object as /// . Implementors of this method should throw if that is not the case. Conversely, /// the returned value must have the same schema as . /// - /// This method creates a live connection between the input and the output . In particular, when the getters of the output are invoked, they invoke the - /// getters of the input row and base the output values on the current values of the input . - /// The output values are re-computed when requested through the getters. + /// This method creates a live connection between the input and the output . In particular, when the getters of the output are invoked, they invoke the + /// getters of the input row and base the output values on the current values of the input . + /// The output values are re-computed when requested through the getters. /// /// The optional should be invoked by any user of this row mapping, once it no - /// longer needs the . If no action is needed when the cursor is Disposed, the implementation + /// longer needs the . If no action is needed when the cursor is Disposed, the implementation /// should set to null, otherwise it should be set to a delegate to be /// invoked by the code calling this object. (For example, a wrapping cursor's /// method. It's best for this action to be idempotent - calling it multiple times should be equivalent to /// calling it once. /// - IRow GetRow(IRow input, Func active, out Action disposer); + Row GetRow(Row input, Func active, out Action disposer); } } diff --git a/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs b/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs index ecec0b1a0d..1f95d941cb 100644 --- a/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs @@ -6,31 +6,30 @@ namespace Microsoft.ML.Runtime.Data { /// /// Base class for a cursor has an input cursor, but still needs to do work on - /// MoveNext/MoveMany. + /// / . /// [BestFriend] - internal abstract class LinkedRootCursorBase : RootCursorBase - where TInput : class, ICursor + internal abstract class LinkedRootCursorBase : RootCursorBase { - private readonly ICursor _root; /// Gets the input cursor. - protected TInput Input { get; } + protected RowCursor Input { get; } /// /// Returns the root cursor of the input. It should be used to perform MoveNext or MoveMany operations. - /// Note that GetRootCursor() returns "this", NOT Root. Root is used to advance our input, not for - /// clients of this cursor. That's why it is protected, not public. + /// Note that returns , not . + /// is used to advance our input, not for clients of this cursor. That is why it is + /// protected, not public. /// - protected ICursor Root { get { return _root; } } + protected RowCursor Root { get; } - protected LinkedRootCursorBase(IChannelProvider provider, TInput input) + protected LinkedRootCursorBase(IChannelProvider provider, RowCursor input) : base(provider) { Ch.AssertValue(input, nameof(input)); Input = input; - _root = Input.GetRootCursor(); + Root = Input.GetRootCursor(); } public override void Dispose() diff --git a/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs b/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs index 0ed4dd19f9..78b4fd142b 100644 --- a/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs @@ -14,7 +14,7 @@ internal abstract class LinkedRowFilterCursorBase : LinkedRowRootCursorBase { public override long Batch => Input.Batch; - protected LinkedRowFilterCursorBase(IChannelProvider provider, IRowCursor input, Schema schema, bool[] active) + protected LinkedRowFilterCursorBase(IChannelProvider provider, RowCursor input, Schema schema, bool[] active) : base(provider, input, schema, active) { } diff --git a/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs b/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs index 188522ad3d..fb045ec6e4 100644 --- a/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs @@ -7,20 +7,20 @@ namespace Microsoft.ML.Runtime.Data { /// - /// A base class for a that has an input cursor, but still needs - /// to do work on /. Note + /// A base class for a that has an input cursor, but still needs + /// to do work on /. Note /// that the default assumes /// that each input column is exposed as an output column with the same column index. /// [BestFriend] - internal abstract class LinkedRowRootCursorBase : LinkedRootCursorBase, IRowCursor + internal abstract class LinkedRowRootCursorBase : LinkedRootCursorBase { private readonly bool[] _active; /// Gets row's schema. - public Schema Schema { get; } + public sealed override Schema Schema { get; } - protected LinkedRowRootCursorBase(IChannelProvider provider, IRowCursor input, Schema schema, bool[] active) + protected LinkedRowRootCursorBase(IChannelProvider provider, RowCursor input, Schema schema, bool[] active) : base(provider, input) { Ch.CheckValue(schema, nameof(schema)); @@ -29,13 +29,13 @@ protected LinkedRowRootCursorBase(IChannelProvider provider, IRowCursor input, S Schema = schema; } - public bool IsColumnActive(int col) + public sealed override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < Schema.Count); return _active == null || _active[col]; } - public virtual ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { return Input.GetGetter(col); } diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs index 6e2bd5230f..d5cbe6928b 100644 --- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs +++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs @@ -495,7 +495,7 @@ public static bool TryGetCategoricalFeatureIndices(Schema schema, int colIndex, return cols; } - private sealed class MetadataRow : IRow + private sealed class MetadataRow : Row { private readonly Schema.Metadata _metadata; @@ -505,20 +505,20 @@ public MetadataRow(Schema.Metadata metadata) _metadata = metadata; } - public Schema Schema => _metadata.Schema; - public long Position => 0; - public long Batch => 0; - public ValueGetter GetGetter(int col) => _metadata.GetGetter(col); - public ValueGetter GetIdGetter() => (ref UInt128 dst) => dst = default; - public bool IsColumnActive(int col) => true; + public override Schema Schema => _metadata.Schema; + public override long Position => 0; + public override long Batch => 0; + public override ValueGetter GetGetter(int col) => _metadata.GetGetter(col); + public override ValueGetter GetIdGetter() => (ref UInt128 dst) => dst = default; + public override bool IsColumnActive(int col) => true; } /// - /// Presents a as a an . + /// Presents a as a an . /// /// The metadata to wrap. /// A row that wraps an input metadata. - public static IRow MetadataAsRow(Schema.Metadata metadata) + public static Row MetadataAsRow(Schema.Metadata metadata) { Contracts.CheckValue(metadata, nameof(metadata)); return new MetadataRow(metadata); diff --git a/src/Microsoft.ML.Core/Data/RootCursorBase.cs b/src/Microsoft.ML.Core/Data/RootCursorBase.cs index 5b64a40d6a..75d58d98bb 100644 --- a/src/Microsoft.ML.Core/Data/RootCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/RootCursorBase.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using Microsoft.ML.Data; namespace Microsoft.ML.Runtime.Data { @@ -15,23 +16,21 @@ namespace Microsoft.ML.Runtime.Data /// This cursor base class returns "this" from . That is, all /// / calls will be seen by this cursor. For a cursor /// that has an input cursor and does NOT need notification on /, - /// use . + /// use . /// [BestFriend] - internal abstract class RootCursorBase : ICursor + internal abstract class RootCursorBase : RowCursor { protected readonly IChannel Ch; + private CursorState _state; + private long _position; /// /// Zero-based position of the cursor. /// - public long Position { get; private set; } + public sealed override long Position => _position; - public abstract long Batch { get; } - - public abstract ValueGetter GetIdGetter(); - - public CursorState State { get; private set; } + public sealed override CursorState State => _state; /// /// Convenience property for checking whether the current state of the cursor is . @@ -39,7 +38,7 @@ internal abstract class RootCursorBase : ICursor protected bool IsGood => State == CursorState.Good; /// - /// Creates an instance of the RootCursorBase class + /// Creates an instance of the class /// /// Channel provider protected RootCursorBase(IChannelProvider provider) @@ -47,21 +46,21 @@ protected RootCursorBase(IChannelProvider provider) Contracts.CheckValue(provider, nameof(provider)); Ch = provider.Start("Cursor"); - Position = -1; - State = CursorState.NotStarted; + _position = -1; + _state = CursorState.NotStarted; } - public virtual void Dispose() + public override void Dispose() { if (State != CursorState.Done) { Ch.Dispose(); - Position = -1; - State = CursorState.Done; + _position = -1; + _state = CursorState.Done; } } - public bool MoveNext() + public sealed override bool MoveNext() { if (State == CursorState.Done) return false; @@ -71,8 +70,8 @@ public bool MoveNext() { Ch.Assert(State == CursorState.NotStarted || State == CursorState.Good); - Position++; - State = CursorState.Good; + _position++; + _state = CursorState.Good; return true; } @@ -80,7 +79,7 @@ public bool MoveNext() return false; } - public bool MoveMany(long count) + public sealed override bool MoveMany(long count) { // Note: If we decide to allow count == 0, then we need to special case // that MoveNext() has never been called. It's not entirely clear what the return @@ -95,8 +94,8 @@ public bool MoveMany(long count) { Ch.Assert(State == CursorState.NotStarted || State == CursorState.Good); - Position += count; - State = CursorState.Good; + _position += count; + _state = CursorState.Good; return true; } @@ -137,6 +136,6 @@ protected virtual bool MoveManyCore(long count) /// those on this cursor. Generally, if the root cursor is not the same as this cursor, using /// the root cursor will be faster. /// - public ICursor GetRootCursor() => this; + public override RowCursor GetRootCursor() => this; } } \ No newline at end of file diff --git a/src/Microsoft.ML.Core/Data/Schema.cs b/src/Microsoft.ML.Core/Data/Schema.cs index ec45753331..90c87fc54a 100644 --- a/src/Microsoft.ML.Core/Data/Schema.cs +++ b/src/Microsoft.ML.Core/Data/Schema.cs @@ -14,7 +14,7 @@ namespace Microsoft.ML.Data { /// - /// This class represents the of an object like, for interstance, an or an . + /// This class represents the of an object like, for interstance, an or an . /// On the high level, the schema is a collection of 'columns'. Each column has the following properties: /// - Column name. /// - Column type. diff --git a/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs b/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs index da60c84ccf..b2641d985e 100644 --- a/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Data; + namespace Microsoft.ML.Runtime.Data { /// @@ -11,28 +13,27 @@ namespace Microsoft.ML.Runtime.Data /// Dispose is virtual with the default implementation delegating to the input cursor. /// [BestFriend] - internal abstract class SynchronizedCursorBase : ICursor - where TBase : class, ICursor + internal abstract class SynchronizedCursorBase : RowCursor { protected readonly IChannel Ch; - private readonly ICursor _root; + private readonly RowCursor _root; private bool _disposed; - protected TBase Input { get; } + protected RowCursor Input { get; } - public long Position => _root.Position; + public sealed override long Position => _root.Position; - public long Batch => _root.Batch; + public sealed override long Batch => _root.Batch; - public CursorState State => _root.State; + public sealed override CursorState State => _root.State; /// /// Convenience property for checking whether the current state is CursorState.Good. /// protected bool IsGood => _root.State == CursorState.Good; - protected SynchronizedCursorBase(IChannelProvider provider, TBase input) + protected SynchronizedCursorBase(IChannelProvider provider, RowCursor input) { Contracts.AssertValue(provider, "provider"); Ch = provider.Start("Cursor"); @@ -42,7 +43,7 @@ protected SynchronizedCursorBase(IChannelProvider provider, TBase input) _root = Input.GetRootCursor(); } - public virtual void Dispose() + public override void Dispose() { if (!_disposed) { @@ -52,24 +53,12 @@ public virtual void Dispose() } } - public bool MoveNext() - { - return _root.MoveNext(); - } + public sealed override bool MoveNext() => _root.MoveNext(); - public bool MoveMany(long count) - { - return _root.MoveMany(count); - } + public sealed override bool MoveMany(long count) => _root.MoveMany(count); - public ICursor GetRootCursor() - { - return _root; - } + public sealed override RowCursor GetRootCursor() => _root; - public ValueGetter GetIdGetter() - { - return Input.GetIdGetter(); - } + public sealed override ValueGetter GetIdGetter() => Input.GetIdGetter(); } } diff --git a/src/Microsoft.ML.Data/Data/DataViewUtils.cs b/src/Microsoft.ML.Data/Data/DataViewUtils.cs index 4d8752a6fe..f81672008b 100644 --- a/src/Microsoft.ML.Data/Data/DataViewUtils.cs +++ b/src/Microsoft.ML.Data/Data/DataViewUtils.cs @@ -114,7 +114,7 @@ public static int GetThreadCount(IHost host, int num = 0, bool preferOne = false /// Try to create a cursor set from upstream and consolidate it here. The host determines /// the target cardinality of the cursor set. /// - public static bool TryCreateConsolidatingCursor(out IRowCursor curs, + public static bool TryCreateConsolidatingCursor(out RowCursor curs, IDataView view, Func predicate, IHost host, Random rand) { Contracts.CheckValue(host, nameof(host)); @@ -146,19 +146,19 @@ public static bool TryCreateConsolidatingCursor(out IRowCursor curs, /// cardinality. If not all the active columns are cachable, this will only /// produce the given input cursor. /// - public static IRowCursor[] CreateSplitCursors(out IRowCursorConsolidator consolidator, - IChannelProvider provider, IRowCursor input, int num) + public static RowCursor[] CreateSplitCursors(out IRowCursorConsolidator consolidator, + IChannelProvider provider, RowCursor input, int num) { Contracts.CheckValue(provider, nameof(provider)); provider.CheckValue(input, nameof(input)); consolidator = null; if (num <= 1) - return new IRowCursor[1] { input }; + return new RowCursor[1] { input }; // If any active columns are not cachable, we can't split. if (!AllCachable(input.Schema, input.IsColumnActive)) - return new IRowCursor[1] { input }; + return new RowCursor[1] { input }; // REVIEW: Should we limit the cardinality to some reasonable size? @@ -205,7 +205,7 @@ public static bool IsCachable(this ColumnType type) /// that is, they all are non-null, have the same schemas, and the same /// set of columns are active. /// - public static bool SameSchemaAndActivity(IRowCursor[] cursors) + public static bool SameSchemaAndActivity(RowCursor[] cursors) { // There must be something to actually consolidate. if (Utils.Size(cursors) == 0) @@ -239,7 +239,7 @@ public static bool SameSchemaAndActivity(IRowCursor[] cursors) /// Given a parallel cursor set, this consolidates them into a single cursor. The batchSize /// is a hint used for efficiency. /// - public static IRowCursor ConsolidateGeneric(IChannelProvider provider, IRowCursor[] inputs, int batchSize) + public static RowCursor ConsolidateGeneric(IChannelProvider provider, RowCursor[] inputs, int batchSize) { Contracts.CheckValue(provider, nameof(provider)); provider.CheckNonEmpty(inputs, nameof(inputs)); @@ -309,12 +309,12 @@ public Consolidator(Splitter splitter) _splitter = splitter; } - public IRowCursor CreateCursor(IChannelProvider provider, IRowCursor[] inputs) + public RowCursor CreateCursor(IChannelProvider provider, RowCursor[] inputs) { return Consolidate(provider, inputs, 128, ref _splitter._consolidateCachePools); } - public static IRowCursor Consolidate(IChannelProvider provider, IRowCursor[] inputs, int batchSize, ref object[] ourPools) + public static RowCursor Consolidate(IChannelProvider provider, RowCursor[] inputs, int batchSize, ref object[] ourPools) { Contracts.AssertValue(provider); using (var ch = provider.Start("Consolidate")) @@ -323,14 +323,14 @@ public static IRowCursor Consolidate(IChannelProvider provider, IRowCursor[] inp } } - private static IRowCursor ConsolidateCore(IChannelProvider provider, IRowCursor[] inputs, ref object[] ourPools, IChannel ch) + private static RowCursor ConsolidateCore(IChannelProvider provider, RowCursor[] inputs, ref object[] ourPools, IChannel ch) { ch.CheckNonEmpty(inputs, nameof(inputs)); if (inputs.Length == 1) return inputs[0]; ch.CheckParam(SameSchemaAndActivity(inputs), nameof(inputs), "Inputs not compatible for consolidation"); - IRowCursor cursor = inputs[0]; + RowCursor cursor = inputs[0]; var schema = cursor.Schema; ch.CheckParam(AllCachable(schema, cursor.IsColumnActive), nameof(inputs), "Inputs had some uncachable input columns"); @@ -494,7 +494,7 @@ private static MadeObjectPool GetPoolCore(object[] pools, int poolIdx) } } - public static IRowCursor[] Split(out IRowCursorConsolidator consolidator, IChannelProvider provider, Schema schema, IRowCursor input, int cthd) + public static RowCursor[] Split(out IRowCursorConsolidator consolidator, IChannelProvider provider, Schema schema, RowCursor input, int cthd) { Contracts.AssertValue(provider, "provider"); @@ -506,7 +506,7 @@ public static IRowCursor[] Split(out IRowCursorConsolidator consolidator, IChann } } - private IRowCursor[] SplitCore(out IRowCursorConsolidator consolidator, IChannelProvider ch, IRowCursor input, int cthd) + private RowCursor[] SplitCore(out IRowCursorConsolidator consolidator, IChannelProvider ch, RowCursor input, int cthd) { Contracts.AssertValue(ch); ch.AssertValue(input); @@ -524,7 +524,7 @@ private IRowCursor[] SplitCore(out IRowCursorConsolidator consolidator, IChannel int[] colToActive; Utils.BuildSubsetMaps(_schema.ColumnCount, input.IsColumnActive, out activeToCol, out colToActive); - Func createFunc = CreateInPipe; + Func createFunc = CreateInPipe; var inGenMethod = createFunc.GetMethodInfo().GetGenericMethodDefinition(); object[] arguments = new object[] { input, 0 }; // Only one set of in-pipes, one per column, as well as for extra side information. @@ -644,7 +644,7 @@ private IRowCursor[] SplitCore(out IRowCursorConsolidator consolidator, IChannel /// /// An in pipe creator intended to be used from the splitter only. /// - private InPipe CreateInPipe(IRow input, int col) + private InPipe CreateInPipe(Row input, int col) { Contracts.AssertValue(input); Contracts.Assert(0 <= col && col < _schema.ColumnCount); @@ -654,7 +654,7 @@ private InPipe CreateInPipe(IRow input, int col) /// /// An in pipe creator intended to be used from the splitter only. /// - private InPipe CreateIdInPipe(IRow input) + private InPipe CreateIdInPipe(Row input) { Contracts.AssertValue(input); return CreateInPipeCore(_schema.ColumnCount + (int)ExtraIndex.Id, input.GetIdGetter()); @@ -849,7 +849,7 @@ public void SetAll(OutPipe[] pipes) /// /// This helps a cursor present the results of a . Practically its role - /// really is to just provide a stable delegate for the . + /// really is to just provide a stable delegate for the . /// There is one of these created per column, per output cursor, i.e., in splitting /// there are n of these created per column, and when consolidating only one of these /// is created per column. @@ -999,7 +999,7 @@ protected override void Getter(ref T value) /// objects from the input blocking collection, and yields the /// values stored therein through the help of objects. /// - private sealed class Cursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { private readonly Schema _schema; private readonly int[] _activeToCol; @@ -1014,12 +1014,9 @@ private sealed class Cursor : RootCursorBase, IRowCursor private long _batch; private bool _disposed; - public Schema Schema => _schema; + public override Schema Schema => _schema; - public override long Batch - { - get { return _batch; } - } + public override long Batch => _batch; /// /// Constructs one of the split cursors. @@ -1114,13 +1111,13 @@ protected override bool MoveNextCore() return true; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.CheckParam(0 <= col && col < _colToActive.Length, nameof(col)); return _colToActive[col] >= 0; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.CheckParam(IsColumnActive(col), nameof(col), "requested column not active"); var getter = _getters[_colToActive[col]] as ValueGetter; @@ -1136,9 +1133,9 @@ public ValueGetter GetGetter(int col) /// at the cost of being totally synchronous, that is, there is no parallel benefit from /// having split the input cursors. /// - internal sealed class SynchronousConsolidatingCursor : RootCursorBase, IRowCursor + internal sealed class SynchronousConsolidatingCursor : RootCursorBase { - private readonly IRowCursor[] _cursors; + private readonly RowCursor[] _cursors; private readonly Delegate[] _getters; private readonly Schema _schema; @@ -1152,7 +1149,7 @@ internal sealed class SynchronousConsolidatingCursor : RootCursorBase, IRowCurso // Index into _cursors array pointing to the current cursor, or -1 if this cursor is not in Good state. private int _icursor; // If this cursor is in Good state then this should equal _cursors[_icursor], else null. - private IRowCursor _currentCursor; + private RowCursor _currentCursor; private bool _disposed; private readonly struct CursorStats @@ -1171,9 +1168,9 @@ public CursorStats(long batch, int idx) // input batch as our own batch. Should we suppress it? public override long Batch { get { return _batch; } } - public Schema Schema => _schema; + public override Schema Schema => _schema; - public SynchronousConsolidatingCursor(IChannelProvider provider, IRowCursor[] cursors) + public SynchronousConsolidatingCursor(IChannelProvider provider, RowCursor[] cursors) : base(provider) { Ch.CheckNonEmpty(cursors, nameof(cursors)); @@ -1199,7 +1196,7 @@ private void InitHeap() { for (int i = 0; i < _cursors.Length; ++i) { - IRowCursor cursor = _cursors[i]; + RowCursor cursor = _cursors[i]; Ch.Assert(cursor.State == CursorState.NotStarted); if (cursor.MoveNext()) _mins.Add(new CursorStats(cursor.Batch, i)); @@ -1291,13 +1288,13 @@ protected override bool MoveNextCore() return true; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.CheckParam(0 <= col && col < _colToActive.Length, nameof(col)); return _colToActive[col] >= 0; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.CheckParam(IsColumnActive(col), nameof(col), "requested column not active"); var getter = _getters[_colToActive[col]] as ValueGetter; @@ -1307,7 +1304,7 @@ public ValueGetter GetGetter(int col) } } - public static ValueGetter>[] PopulateGetterArray(IRowCursor cursor, List colIndices) + public static ValueGetter>[] PopulateGetterArray(RowCursor cursor, List colIndices) { var n = colIndices.Count; var getters = new ValueGetter>[n]; @@ -1335,7 +1332,7 @@ public static ValueGetter>[] PopulateGetterArray(IRowCursor return getters; } - public static ValueGetter> GetSingleValueGetter(IRow cursor, int i, ColumnType colType) + public static ValueGetter> GetSingleValueGetter(Row cursor, int i, ColumnType colType) { var floatGetter = cursor.GetGetter(i); T v = default(T); @@ -1365,7 +1362,7 @@ public static ValueGetter> GetSingleValueGetter(IRow cur return getter; } - public static ValueGetter> GetVectorFlatteningGetter(IRow cursor, int colIndex, ColumnType colType) + public static ValueGetter> GetVectorFlatteningGetter(Row cursor, int colIndex, ColumnType colType) { var vecGetter = cursor.GetGetter>(colIndex); var vbuf = default(VBuffer); diff --git a/src/Microsoft.ML.Data/Data/IRowSeekable.cs b/src/Microsoft.ML.Data/Data/IRowSeekable.cs index 9270ddc7f3..17514612af 100644 --- a/src/Microsoft.ML.Data/Data/IRowSeekable.cs +++ b/src/Microsoft.ML.Data/Data/IRowSeekable.cs @@ -14,18 +14,20 @@ namespace Microsoft.ML.Runtime.Data /// public interface IRowSeekable { - IRowSeeker GetSeeker(Func predicate); + RowSeeker GetSeeker(Func predicate); Schema Schema { get; } } /// /// Represents a row seeker with random access that can retrieve a specific row by the row index. - /// For IRowSeeker, when the state is valid (that is when MoveTo() returns true), it returns the - /// current row index. Otherwise it's -1. + /// For , when the state is valid (that is when + /// returns ), it returns the current row index. Otherwise it's -1. /// - public interface IRowSeeker : IRow, IDisposable + public abstract class RowSeeker : Row, IDisposable { + public abstract void Dispose(); + /// /// Moves the seeker to a row at a specific row index. /// If the row index specified is out of range (less than zero or not less than the @@ -33,6 +35,6 @@ public interface IRowSeeker : IRow, IDisposable /// /// The row index to move to. /// True if a row with specified index is found; false otherwise. - bool MoveTo(long rowIndex); + public abstract bool MoveTo(long rowIndex); } } diff --git a/src/Microsoft.ML.Data/Data/ITransposeDataView.cs b/src/Microsoft.ML.Data/Data/ITransposeDataView.cs index 4c68dcbdd1..c090b185c6 100644 --- a/src/Microsoft.ML.Data/Data/ITransposeDataView.cs +++ b/src/Microsoft.ML.Data/Data/ITransposeDataView.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; + namespace Microsoft.ML.Runtime.Data { // REVIEW: There are a couple problems. Firstly, what to do about cases where @@ -14,8 +16,8 @@ namespace Microsoft.ML.Runtime.Data /// /// A view of data where columns can optionally be accessed slot by slot, as opposed to row /// by row in a typical dataview. A slot-accessible column can be accessed with a slot-by-slot - /// cursor via an (naturally, as opposed to row-by-row through an - /// ). This interface is intended to be implemented by classes that + /// cursor via an (naturally, as opposed to row-by-row through an + /// ). This interface is intended to be implemented by classes that /// want to provide an option for an alternate way of accessing the data stored in a /// . /// @@ -36,26 +38,7 @@ public interface ITransposeDataView : IDataView /// Presents a cursor over the slots of a transposable column, or throws if the column /// is not transposable. /// - ISlotCursor GetSlotCursor(int col); - } - - /// - /// A cursor that allows slot-by-slot access of data. - /// - public interface ISlotCursor : ICursor - { - /// - /// The slot type for this cursor. Note that this should equal the - /// for the column from which this slot cursor - /// was created. - /// - VectorType GetSlotType(); - - /// - /// A getter delegate for the slot values. The type must correspond - /// to the item type from . - /// - ValueGetter> GetGetter(); + SlotCursor GetSlotCursor(int col); } /// @@ -65,8 +48,8 @@ public interface ITransposeSchema : ISchema { /// /// Analogous to , except instead of returning the type of value - /// accessible through the , returns the item type of value accessible - /// through the . This will return null iff this particular + /// accessible through the , returns the item type of value accessible + /// through the . This will return null iff this particular /// column is not transposable, that is, it cannot be viewed in a slotwise fashion. Observe from /// the return type that this will always be a vector type. This vector type should be of fixed /// size and one dimension. diff --git a/src/Microsoft.ML.Data/Data/RowCursorUtils.cs b/src/Microsoft.ML.Data/Data/RowCursorUtils.cs index 57098d2895..5e79d05448 100644 --- a/src/Microsoft.ML.Data/Data/RowCursorUtils.cs +++ b/src/Microsoft.ML.Data/Data/RowCursorUtils.cs @@ -23,17 +23,17 @@ public static class RowCursorUtils /// The row to get the getter for /// The column index, which must be active on that row /// The getter as a delegate - public static Delegate GetGetterAsDelegate(IRow row, int col) + public static Delegate GetGetterAsDelegate(Row row, int col) { Contracts.CheckValue(row, nameof(row)); Contracts.CheckParam(0 <= col && col < row.Schema.ColumnCount, nameof(col)); Contracts.CheckParam(row.IsColumnActive(col), nameof(col), "column was not active"); - Func getGetter = GetGetterAsDelegateCore; + Func getGetter = GetGetterAsDelegateCore; return Utils.MarshalInvoke(getGetter, row.Schema.GetColumnType(col).RawType, row, col); } - private static Delegate GetGetterAsDelegateCore(IRow row, int col) + private static Delegate GetGetterAsDelegateCore(Row row, int col) { return row.GetGetter(col); } @@ -44,7 +44,7 @@ private static Delegate GetGetterAsDelegateCore(IRow row, int col) /// . /// /// - public static Delegate GetGetterAs(ColumnType typeDst, IRow row, int col) + public static Delegate GetGetterAs(ColumnType typeDst, Row row, int col) { Contracts.CheckValue(typeDst, nameof(typeDst)); Contracts.CheckParam(typeDst.IsPrimitive, nameof(typeDst)); @@ -55,7 +55,7 @@ public static Delegate GetGetterAs(ColumnType typeDst, IRow row, int col) var typeSrc = row.Schema.GetColumnType(col); Contracts.Check(typeSrc.IsPrimitive, "Source column type must be primitive"); - Func> del = GetGetterAsCore; + Func> del = GetGetterAsCore; var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(typeSrc.RawType, typeDst.RawType); return (Delegate)methodInfo.Invoke(null, new object[] { typeSrc, typeDst, row, col }); } @@ -64,7 +64,7 @@ public static Delegate GetGetterAs(ColumnType typeDst, IRow row, int col) /// Given a destination type, IRow, and column index, return a ValueGetter{TDst} for the column /// with a conversion to typeDst, if needed. /// - public static ValueGetter GetGetterAs(ColumnType typeDst, IRow row, int col) + public static ValueGetter GetGetterAs(ColumnType typeDst, Row row, int col) { Contracts.CheckValue(typeDst, nameof(typeDst)); Contracts.CheckParam(typeDst.IsPrimitive, nameof(typeDst)); @@ -76,12 +76,12 @@ public static ValueGetter GetGetterAs(ColumnType typeDst, IRow row, var typeSrc = row.Schema.GetColumnType(col); Contracts.Check(typeSrc.IsPrimitive, "Source column type must be primitive"); - Func> del = GetGetterAsCore; + Func> del = GetGetterAsCore; var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(typeSrc.RawType, typeof(TDst)); return (ValueGetter)methodInfo.Invoke(null, new object[] { typeSrc, typeDst, row, col }); } - private static ValueGetter GetGetterAsCore(ColumnType typeSrc, ColumnType typeDst, IRow row, int col) + private static ValueGetter GetGetterAsCore(ColumnType typeSrc, ColumnType typeDst, Row row, int col) { Contracts.Assert(typeof(TSrc) == typeSrc.RawType); Contracts.Assert(typeof(TDst) == typeDst.RawType); @@ -112,7 +112,7 @@ private static ValueGetter GetGetterAsCore(ColumnType typeSrc, /// into the required type. This method can be useful if you want to output a value /// as a string in a generic way, but don't really care how you do it. /// - public static ValueGetter GetGetterAsStringBuilder(IRow row, int col) + public static ValueGetter GetGetterAsStringBuilder(Row row, int col) { Contracts.CheckValue(row, nameof(row)); Contracts.CheckParam(0 <= col && col < row.Schema.ColumnCount, nameof(col)); @@ -123,7 +123,7 @@ public static ValueGetter GetGetterAsStringBuilder(IRow row, int return Utils.MarshalInvoke(GetGetterAsStringBuilderCore, typeSrc.RawType, typeSrc, row, col); } - private static ValueGetter GetGetterAsStringBuilderCore(ColumnType typeSrc, IRow row, int col) + private static ValueGetter GetGetterAsStringBuilderCore(ColumnType typeSrc, Row row, int col) { Contracts.Assert(typeof(TSrc) == typeSrc.RawType); @@ -142,9 +142,9 @@ private static ValueGetter GetGetterAsStringBuilderCore(Col /// /// Given the item type, typeDst, a row, and column index, return a ValueGetter for the vector-valued /// column with a conversion to a vector of typeDst, if needed. This is the weakly typed version of - /// . + /// . /// - public static Delegate GetVecGetterAs(PrimitiveType typeDst, IRow row, int col) + public static Delegate GetVecGetterAs(PrimitiveType typeDst, Row row, int col) { Contracts.CheckValue(typeDst, nameof(typeDst)); Contracts.CheckValue(row, nameof(row)); @@ -163,7 +163,7 @@ public static Delegate GetVecGetterAs(PrimitiveType typeDst, IRow row, int col) /// Given the item type, typeDst, a row, and column index, return a ValueGetter{VBuffer{TDst}} for the /// vector-valued column with a conversion to a vector of typeDst, if needed. /// - public static ValueGetter> GetVecGetterAs(PrimitiveType typeDst, IRow row, int col) + public static ValueGetter> GetVecGetterAs(PrimitiveType typeDst, Row row, int col) { Contracts.CheckValue(typeDst, nameof(typeDst)); Contracts.CheckParam(typeDst.RawType == typeof(TDst), nameof(typeDst)); @@ -183,7 +183,7 @@ public static ValueGetter> GetVecGetterAs(PrimitiveType type /// Given the item type, typeDst, and a slot cursor, return a ValueGetter{VBuffer{TDst}} for the /// vector-valued column with a conversion to a vector of typeDst, if needed. /// - public static ValueGetter> GetVecGetterAs(PrimitiveType typeDst, ISlotCursor cursor) + public static ValueGetter> GetVecGetterAs(PrimitiveType typeDst, SlotCursor cursor) { Contracts.CheckValue(typeDst, nameof(typeDst)); Contracts.CheckParam(typeDst.RawType == typeof(TDst), nameof(typeDst)); @@ -200,12 +200,12 @@ public static ValueGetter> GetVecGetterAs(PrimitiveType type /// private abstract class GetterFactory { - public static GetterFactory Create(IRow row, int col) + public static GetterFactory Create(Row row, int col) { return new RowImpl(row, col); } - public static GetterFactory Create(ISlotCursor cursor) + public static GetterFactory Create(SlotCursor cursor) { return new SlotImpl(cursor); } @@ -214,10 +214,10 @@ public static GetterFactory Create(ISlotCursor cursor) private sealed class RowImpl : GetterFactory { - private readonly IRow _row; + private readonly Row _row; private readonly int _col; - public RowImpl(IRow row, int col) + public RowImpl(Row row, int col) { _row = row; _col = col; @@ -231,9 +231,9 @@ public override ValueGetter GetGetter() private sealed class SlotImpl : GetterFactory { - private readonly ISlotCursor _cursor; + private readonly SlotCursor _cursor; - public SlotImpl(ISlotCursor cursor) + public SlotImpl(SlotCursor cursor) { _cursor = cursor; } @@ -294,7 +294,7 @@ private static ValueGetter> GetVecGetterAsCore(VectorT /// is different than it was, in the last call. This is practically useful for determining /// group boundaries. Note that the delegate will return true on the first row. /// - public static Func GetIsNewGroupDelegate(IRow cursor, int col) + public static Func GetIsNewGroupDelegate(Row cursor, int col) { Contracts.CheckValue(cursor, nameof(cursor)); Contracts.Check(0 <= col && col < cursor.Schema.ColumnCount); @@ -303,7 +303,7 @@ public static Func GetIsNewGroupDelegate(IRow cursor, int col) return Utils.MarshalInvoke(GetIsNewGroupDelegateCore, type.RawType, cursor, col); } - private static Func GetIsNewGroupDelegateCore(IRow cursor, int col) + private static Func GetIsNewGroupDelegateCore(Row cursor, int col) { var getter = cursor.GetGetter(col); bool first = true; @@ -329,7 +329,7 @@ private static Func GetIsNewGroupDelegateCore(IRow cursor, int col) [Obsolete("The usages of this appear to be based on a total misunderstanding of what Batch actually is. It is a mechanism " + "to enable sharding and recovery of parallelized data, and has nothing to do with actual data.")] [BestFriend] - internal static Func GetIsNewBatchDelegate(IRow cursor, int batchSize) + internal static Func GetIsNewBatchDelegate(Row cursor, int batchSize) { Contracts.CheckParam(batchSize > 0, nameof(batchSize), "Batch size must be > 0"); long lastNewBatchPosition = -1; @@ -366,7 +366,7 @@ public static string TestGetLabelGetter(ColumnType type, bool allowKeys) return allowKeys ? "Expected R4, R8, Bool or Key type" : "Expected R4, R8 or Bool type"; } - public static ValueGetter GetLabelGetter(IRow cursor, int labelIndex) + public static ValueGetter GetLabelGetter(Row cursor, int labelIndex) { var type = cursor.Schema.GetColumnType(labelIndex); @@ -388,7 +388,7 @@ public static ValueGetter GetLabelGetter(IRow cursor, int labelIndex) return GetLabelGetterNotFloat(cursor, labelIndex); } - private static ValueGetter GetLabelGetterNotFloat(IRow cursor, int labelIndex) + private static ValueGetter GetLabelGetterNotFloat(Row cursor, int labelIndex) { var type = cursor.Schema.GetColumnType(labelIndex); @@ -425,7 +425,7 @@ private static ValueGetter GetLabelGetterNotFloat(IRow cursor, int label }; } - public static ValueGetter> GetLabelGetter(ISlotCursor cursor) + public static ValueGetter> GetLabelGetter(SlotCursor cursor) { var type = cursor.GetSlotType().ItemType; if (type == NumberType.R4) @@ -461,7 +461,7 @@ public static ValueGetter> GetLabelGetter(ISlotCursor cursor) /// Fetches the value of the column by name, in the given row. /// Used by the evaluators to retrieve the metrics from the results IDataView. /// - public static T Fetch(IExceptionContext ectx, IRow row, string name) + public static T Fetch(IExceptionContext ectx, Row row, string name) { if (!row.Schema.TryGetColumnIndex(name, out int col)) throw ectx.Except($"Could not find column '{name}'"); @@ -472,7 +472,7 @@ public static T Fetch(IExceptionContext ectx, IRow row, string name) /// /// Given a row, returns a one-row data view. This is useful for cases where you have a row, and you - /// wish to use some facility normally only exposed to dataviews. (For example, you have an + /// wish to use some facility normally only exposed to dataviews. (For example, you have an /// but want to save it somewhere using a .) /// Note that it is not possible for this method to ensure that the input does not /// change, so users of this convenience must take care of what they do with the input row or the data @@ -481,7 +481,7 @@ public static T Fetch(IExceptionContext ectx, IRow row, string name) /// An environment used to create the host for the resulting data view /// A row, whose columns must all be active /// A single-row data view incorporating that row - public static IDataView RowAsDataView(IHostEnvironment env, IRow row) + public static IDataView RowAsDataView(IHostEnvironment env, Row row) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(row, nameof(row)); @@ -491,13 +491,13 @@ public static IDataView RowAsDataView(IHostEnvironment env, IRow row) private sealed class OneRowDataView : IDataView { - private readonly IRow _row; + private readonly Row _row; private readonly IHost _host; // A channel provider is required for creating the cursor. public Schema Schema => _row.Schema; public bool CanShuffle => true; // The shuffling is even uniformly IID!! :) - public OneRowDataView(IHostEnvironment env, IRow row) + public OneRowDataView(IHostEnvironment env, Row row) { Contracts.AssertValue(env); _host = env.Register("OneRowDataView"); @@ -507,7 +507,7 @@ public OneRowDataView(IHostEnvironment env, IRow row) _row = row; } - public IRowCursor GetRowCursor(Func needCol, Random rand = null) + public RowCursor GetRowCursor(Func needCol, Random rand = null) { _host.CheckValue(needCol, nameof(needCol)); _host.CheckValueOrNull(rand); @@ -515,12 +515,12 @@ public IRowCursor GetRowCursor(Func needCol, Random rand = null) return new Cursor(_host, this, active); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func needCol, int n, Random rand = null) + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func needCol, int n, Random rand = null) { _host.CheckValue(needCol, nameof(needCol)); _host.CheckValueOrNull(rand); consolidator = null; - return new IRowCursor[] { GetRowCursor(needCol, rand) }; + return new RowCursor[] { GetRowCursor(needCol, rand) }; } public long? GetRowCount() @@ -528,12 +528,12 @@ public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Fun return 1; } - private sealed class Cursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { private readonly OneRowDataView _parent; private readonly bool[] _active; - public Schema Schema => _parent.Schema; + public override Schema Schema => _parent.Schema; public override long Batch { get { return 0; } } public Cursor(IHost host, OneRowDataView parent, bool[] active) @@ -551,7 +551,7 @@ protected override bool MoveNextCore() return State == CursorState.NotStarted; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.CheckParam(0 <= col && col < Schema.ColumnCount, nameof(col)); Ch.CheckParam(IsColumnActive(col), nameof(col), "Requested column is not active"); @@ -564,7 +564,7 @@ public ValueGetter GetGetter(int col) }; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.CheckParam(0 <= col && col < Schema.ColumnCount, nameof(col)); // We present the "illusion" that this column is not active, even though it must be diff --git a/src/Microsoft.ML.Data/Data/SlotCursor.cs b/src/Microsoft.ML.Data/Data/SlotCursor.cs new file mode 100644 index 0000000000..7e151254e0 --- /dev/null +++ b/src/Microsoft.ML.Data/Data/SlotCursor.cs @@ -0,0 +1,142 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.Runtime.Data +{ + /// + /// A cursor that allows slot-by-slot access of data. This is to + /// what is to . + /// + public abstract class SlotCursor : IDisposable + { + [BestFriend] + private protected readonly IChannel Ch; + private CursorState _state; + + /// + /// Whether the cursor is in a state where it can serve up data, that is, + /// has been called and returned . + /// + [BestFriend] + private protected bool IsGood => _state == CursorState.Good; + + [BestFriend] + private protected SlotCursor(IChannelProvider provider) + { + Contracts.AssertValue(provider); + Ch = provider.Start("Slot Cursor"); + _state = CursorState.NotStarted; + } + + /// + /// The slot index. Incremented by one when is called and returns . + /// When initially created, or after returns , this will be -1. + /// + public abstract int SlotIndex { get; } + + /// + /// Advance to the next slot. When the cursor is first created, this method should be called to + /// move to the first slot. Returns if there are no more slots. + /// + public abstract bool MoveNext(); + + /// + /// The slot type for this cursor. Note that this should equal the + /// for the column from which this slot cursor + /// was created. + /// + public abstract VectorType GetSlotType(); + + /// + /// A getter delegate for the slot values. The type must correspond + /// to the item type from . + /// + public abstract ValueGetter> GetGetter(); + + public virtual void Dispose() + { + if (_state != CursorState.Done) + { + Ch.Dispose(); + _state = CursorState.Done; + } + } + + /// + /// For wrapping another slot cursor from which we get and , + /// but not the data or type accesors. Somewhat analogous to the + /// for s. + /// + [BestFriend] + internal abstract class SynchronizedSlotCursor : SlotCursor + { + private readonly SlotCursor _root; + + public SynchronizedSlotCursor(IChannelProvider provider, SlotCursor cursor) + : base(provider) + { + Contracts.AssertValue(cursor); + // If the input is itself a sync-base, we can walk up the chain to get its root, + // thereby making things more efficient. + _root = cursor is SynchronizedSlotCursor sync ? sync._root : cursor; + } + + public override bool MoveNext() + => _root.MoveNext(); + + public override int SlotIndex => _root.SlotIndex; + } + + /// + /// A useful base class for common implementations, somewhat + /// analogous to the for s. + /// + [BestFriend] + internal abstract class RootSlotCursor : SlotCursor + { + private int _slotIndex; + + public RootSlotCursor(IChannelProvider provider) + : base(provider) + { + _slotIndex = -1; + } + + public override int SlotIndex => _slotIndex; + + public override void Dispose() + { + base.Dispose(); + _slotIndex = -1; + } + + public override bool MoveNext() + { + if (_state == CursorState.Done) + return false; + + Ch.Assert(_state == CursorState.NotStarted || _state == CursorState.Good); + if (MoveNextCore()) + { + Ch.Assert(_state == CursorState.NotStarted || _state == CursorState.Good); + + _slotIndex++; + _state = CursorState.Good; + return true; + } + + Dispose(); + return false; + } + + /// + /// Core implementation of . Called only if this method + /// has not yet previously returned . + /// + protected abstract bool MoveNextCore(); + } + } +} diff --git a/src/Microsoft.ML.Data/DataDebuggerPreview.cs b/src/Microsoft.ML.Data/DataDebuggerPreview.cs index ba1049eb2b..296495b88c 100644 --- a/src/Microsoft.ML.Data/DataDebuggerPreview.cs +++ b/src/Microsoft.ML.Data/DataDebuggerPreview.cs @@ -63,7 +63,7 @@ internal DataDebuggerPreview(IDataView data, int maxRows = Defaults.MaxRows) public override string ToString() => $"{Schema.Count} columns, {RowView.Length} rows"; - private Action> MakeSetter(IRow row, int col) + private Action> MakeSetter(Row row, int col) { var getter = row.GetGetter(col); string name = row.Schema[col].Name; diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs index 172a6d695e..aa33db97b8 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs @@ -1234,7 +1234,7 @@ private TableOfContentsEntry CreateRowIndexEntry(string rowIndexName) return entry; } - private IRowCursor GetRowCursorCore(Func predicate, Random rand = null) + private RowCursor GetRowCursorCore(Func predicate, Random rand = null) { if (rand != null && _randomShufflePoolRows > 0) { @@ -1247,23 +1247,23 @@ private IRowCursor GetRowCursorCore(Func predicate, Random rand = nul return new Cursor(this, predicate, rand); } - public IRowCursor GetRowCursor(Func predicate, Random rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); return GetRowCursorCore(predicate, rand); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); consolidator = null; - return new IRowCursor[] { GetRowCursorCore(predicate, rand) }; + return new RowCursor[] { GetRowCursorCore(predicate, rand) }; } - private sealed class Cursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { private const string _badCursorState = "cursor is either not started or is ended, and cannot get values"; @@ -1285,7 +1285,7 @@ private sealed class Cursor : RootCursorBase, IRowCursor private volatile bool _disposed; - public Schema Schema => _parent.Schema; + public override Schema Schema => _parent.Schema; public override long Batch { @@ -2009,7 +2009,7 @@ public override Delegate GetGetter() } } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.CheckParam(0 <= col && col < _colToActivesIndex.Length, nameof(col)); return _colToActivesIndex[col] >= 0; @@ -2070,7 +2070,7 @@ protected override bool MoveNextCore() return more; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.CheckParam(0 <= col && col < _colToActivesIndex.Length, nameof(col)); Ch.CheckParam(_colToActivesIndex[col] >= 0, nameof(col), "requested column not active"); diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinarySaver.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinarySaver.cs index 9a47f932de..d8c060c031 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinarySaver.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinarySaver.cs @@ -88,7 +88,7 @@ protected WritePipe(BinarySaver parent) /// /// Returns an appropriate generic WritePipe{T} for the given column. /// - public static WritePipe Create(BinarySaver parent, IRowCursor cursor, ColumnCodec col) + public static WritePipe Create(BinarySaver parent, RowCursor cursor, ColumnCodec col) { Type writePipeType = typeof(WritePipe<>).MakeGenericType(col.Codec.Type.RawType); return (WritePipe)Activator.CreateInstance(writePipeType, parent, cursor, col); @@ -109,7 +109,7 @@ private sealed class WritePipe : WritePipe private MemoryStream _currentStream; private T _value; - public WritePipe(BinarySaver parent, IRowCursor cursor, ColumnCodec col) + public WritePipe(BinarySaver parent, RowCursor cursor, ColumnCodec col) : base(parent) { var codec = col.Codec as IValueCodec; @@ -581,7 +581,7 @@ private void FetchWorker(BlockingCollection toCompress, IDataView data, HashSet activeSet = new HashSet(activeColumns.Select(col => col.SourceIndex)); long blockIndex = 0; int remainingInBlock = rowsPerBlock; - using (IRowCursor cursor = data.GetRowCursor(activeSet.Contains)) + using (RowCursor cursor = data.GetRowCursor(activeSet.Contains)) { WritePipe[] pipes = new WritePipe[activeColumns.Length]; for (int c = 0; c < activeColumns.Length; ++c) @@ -746,7 +746,7 @@ private int RowsPerBlockHeuristic(IDataView data, ColumnCodec[] actives) EstimatorDelegate del = EstimatorCore; MethodInfo methInfo = del.GetMethodInfo().GetGenericMethodDefinition(); - using (IRowCursor cursor = data.GetRowCursor(active.Contains, rand)) + using (RowCursor cursor = data.GetRowCursor(active.Contains, rand)) { object[] args = new object[] { cursor, null, null, null }; var writers = new IValueWriter[actives.Length]; @@ -776,10 +776,10 @@ private int RowsPerBlockHeuristic(IDataView data, ColumnCodec[] actives) } } - private delegate void EstimatorDelegate(IRowCursor cursor, ColumnCodec col, + private delegate void EstimatorDelegate(RowCursor cursor, ColumnCodec col, out Func fetchWriteEstimator, out IValueWriter writer); - private void EstimatorCore(IRowCursor cursor, ColumnCodec col, + private void EstimatorCore(RowCursor cursor, ColumnCodec col, out Func fetchWriteEstimator, out IValueWriter writer) { ValueGetter getter = cursor.GetGetter(col.SourceIndex); diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs index e281cd0876..2ed8081bad 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs @@ -568,14 +568,14 @@ private static string GenerateTag(int index) public ITransposeSchema TransposeSchema { get; } - public IRowCursor GetRowCursor(Func predicate, Random rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); return View.GetRowCursor(predicate, rand); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); @@ -583,7 +583,7 @@ public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, return View.GetRowCursorSet(out consolidator, predicate, n, rand); } - public ISlotCursor GetSlotCursor(int col) + public SlotCursor GetSlotCursor(int col) { _host.CheckParam(0 <= col && col < Schema.ColumnCount, nameof(col)); if (TransposeSchema?.GetSlotType(col) == null) diff --git a/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs index 4efb11f0fe..5ad393f2d2 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs @@ -293,16 +293,16 @@ public void Save(ModelSaveContext ctx) return null; } - public IRowCursor GetRowCursor(Func needCol, Random rand = null) + public RowCursor GetRowCursor(Func needCol, Random rand = null) { return new Cursor(_host, this, _files, needCol, rand); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func needCol, int n, Random rand = null) + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func needCol, int n, Random rand = null) { consolidator = null; var cursor = new Cursor(_host, this, _files, needCol, rand); - return new IRowCursor[] { cursor }; + return new RowCursor[] { cursor }; } /// @@ -362,7 +362,7 @@ private IDataLoader CreateLoaderFromBytes(byte[] loaderBytes, IMultiStreamSource } } - private sealed class Cursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { private PartitionedFileLoader _parent; @@ -372,7 +372,7 @@ private sealed class Cursor : RootCursorBase, IRowCursor private Delegate[] _subGetters; // Cached getters of the sub-cursor. private ReadOnlyMemory[] _colValues; // Column values cached from the file path. - private IRowCursor _subCursor; // Sub cursor of the current file. + private RowCursor _subCursor; // Sub cursor of the current file. private IEnumerator _fileOrder; @@ -397,9 +397,9 @@ public Cursor(IChannelProvider provider, PartitionedFileLoader parent, IMultiStr public override long Batch => 0; - public Schema Schema => _parent.Schema; + public override Schema Schema => _parent.Schema; - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col)); @@ -423,7 +423,7 @@ public override ValueGetter GetIdGetter() }; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < Schema.Count); return _active[col]; diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs index 225dfe1e98..e748967714 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs @@ -1364,7 +1364,7 @@ public BoundLoader(TextLoader reader, IMultiStreamSource files) public Schema Schema => _reader._bindings.AsSchema; - public IRowCursor GetRowCursor(Func predicate, Random rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); @@ -1372,7 +1372,7 @@ public IRowCursor GetRowCursor(Func predicate, Random rand = null) return Cursor.Create(_reader, _files, active); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs index ac70417f6b..1a2f94da3b 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs @@ -14,7 +14,7 @@ namespace Microsoft.ML.Runtime.Data { public sealed partial class TextLoader { - private sealed class Cursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { // Lines are divided into batches and processed a batch at a time. This enables // parallel parsing. @@ -133,7 +133,7 @@ private Cursor(TextLoader parent, ParseStats stats, bool[] active, LineReader re } } - public static IRowCursor Create(TextLoader parent, IMultiStreamSource files, bool[] active) + public static RowCursor Create(TextLoader parent, IMultiStreamSource files, bool[] active) { // Note that files is allowed to be empty. Contracts.AssertValue(parent); @@ -150,7 +150,7 @@ public static IRowCursor Create(TextLoader parent, IMultiStreamSource files, boo return new Cursor(parent, stats, active, reader, srcNeeded, cthd); } - public static IRowCursor[] CreateSet(out IRowCursorConsolidator consolidator, + public static RowCursor[] CreateSet(out IRowCursorConsolidator consolidator, TextLoader parent, IMultiStreamSource files, bool[] active, int n) { // Note that files is allowed to be empty. @@ -168,11 +168,11 @@ public static IRowCursor[] CreateSet(out IRowCursorConsolidator consolidator, if (cthd <= 1) { consolidator = null; - return new IRowCursor[1] { new Cursor(parent, stats, active, reader, srcNeeded, 1) }; + return new RowCursor[1] { new Cursor(parent, stats, active, reader, srcNeeded, 1) }; } consolidator = new Consolidator(cthd); - var cursors = new IRowCursor[cthd]; + var cursors = new RowCursor[cthd]; try { for (int i = 0; i < cursors.Length; i++) @@ -273,7 +273,7 @@ public static string GetEmbeddedArgs(IMultiStreamSource files) return sb.ToString(); } - public Schema Schema => _bindings.AsSchema; + public override Schema Schema => _bindings.AsSchema; public override void Dispose() { @@ -301,13 +301,13 @@ protected override bool MoveNextCore() return false; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < _bindings.Infos.Length); return _active == null || _active[col]; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col)); var fn = _getters[col] as ValueGetter; @@ -834,7 +834,7 @@ public Consolidator(int cthd) _cthd = cthd; } - public IRowCursor CreateCursor(IChannelProvider provider, IRowCursor[] inputs) + public RowCursor CreateCursor(IChannelProvider provider, RowCursor[] inputs) { Contracts.AssertValue(provider); int cthd = Interlocked.Exchange(ref _cthd, 0); diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs index 9474605079..e55db4df55 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs @@ -47,7 +47,7 @@ private abstract class ValueWriter { public readonly int Source; - public static ValueWriter Create(IRowCursor cursor, int col, char sep) + public static ValueWriter Create(RowCursor cursor, int col, char sep) { Contracts.AssertValue(cursor); @@ -148,7 +148,7 @@ private sealed class VecValueWriter : ValueWriterBase private readonly VBuffer> _slotNames; private readonly int _slotCount; - public VecValueWriter(IRowCursor cursor, VectorType type, int source, char sep) + public VecValueWriter(RowCursor cursor, VectorType type, int source, char sep) : base(type.ItemType, source, sep) { _getSrc = cursor.GetGetter>(source); @@ -213,7 +213,7 @@ private sealed class ValueWriter : ValueWriterBase private T _src; private string _columnName; - public ValueWriter(IRowCursor cursor, PrimitiveType type, int source, char sep) + public ValueWriter(RowCursor cursor, PrimitiveType type, int source, char sep) : base(type, source, sep) { _getSrc = cursor.GetGetter(source); @@ -573,7 +573,7 @@ public State(TextSaver parent, TextWriter writer, ValueWriter[] pipes, bool hasH _mpslotichLim = new int[128]; } - public void Run(IRowCursor cursor, ref long count, out int minLen, out int maxLen) + public void Run(RowCursor cursor, ref long count, out int minLen, out int maxLen) { minLen = int.MaxValue; maxLen = 0; diff --git a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs index 0954bae382..f57e08b8c7 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs @@ -668,7 +668,7 @@ public VectorType GetSlotType(int col) return _header.RowCount; } - public IRowCursor GetRowCursor(Func predicate, Random rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); @@ -677,16 +677,16 @@ public IRowCursor GetRowCursor(Func predicate, Random rand = null) return new Cursor(this, predicate); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); if (HasRowData) return _schemaEntry.GetView().GetRowCursorSet(out consolidator, predicate, n, rand); consolidator = null; - return new IRowCursor[] { GetRowCursor(predicate, rand) }; + return new RowCursor[] { GetRowCursor(predicate, rand) }; } - public ISlotCursor GetSlotCursor(int col) + public SlotCursor GetSlotCursor(int col) { _host.CheckParam(0 <= col && col < _header.ColumnCount, nameof(col)); var view = _entries[col].GetViewOrNull(); @@ -699,7 +699,7 @@ public ISlotCursor GetSlotCursor(int col) // We don't want the type error, if there is one, to be handled by the get-getter, because // at the point we've gotten the interior cursor, but not yet constructed the slot cursor. ColumnType cursorType = TransposeSchema.GetSlotType(col).ItemType; - IRowCursor inputCursor = view.GetRowCursor(c => true); + RowCursor inputCursor = view.GetRowCursor(c => true); try { return Utils.MarshalInvoke(GetSlotCursorCore, cursorType.RawType, inputCursor); @@ -714,41 +714,56 @@ public ISlotCursor GetSlotCursor(int col) } } - private ISlotCursor GetSlotCursorCore(IRowCursor inputCursor) + private SlotCursor GetSlotCursorCore(RowCursor inputCursor) { return new SlotCursor(this, inputCursor); } - private sealed class SlotCursor : SynchronizedCursorBase, ISlotCursor + private sealed class SlotCursor : SlotCursor { private readonly TransposeLoader _parent; private readonly ValueGetter> _getter; + private readonly RowCursor _rowCursor; - private IHost Host { get { return _parent._host; } } - - public SlotCursor(TransposeLoader parent, IRowCursor cursor) - : base(parent._host, cursor) + public SlotCursor(TransposeLoader parent, RowCursor cursor) + : base(parent._host) { _parent = parent; - Ch.Assert(cursor.Schema.ColumnCount == 1); - Ch.Assert(cursor.Schema.GetColumnType(0).RawType == typeof(VBuffer)); - _getter = Input.GetGetter>(0); - } + Ch.AssertValue(cursor); + Ch.Assert(cursor.Schema.Count == 1); + Ch.Assert(cursor.Schema[0].Type.RawType == typeof(VBuffer)); + Ch.Assert(cursor.Schema[0].Type is VectorType); + _rowCursor = cursor; - public VectorType GetSlotType() - { - var type = Input.Schema.GetColumnType(0).AsVector; - Ch.AssertValue(type); - return type; + _getter = _rowCursor.GetGetter>(0); } - public ValueGetter> GetGetter() + public override VectorType GetSlotType() + => (VectorType)_rowCursor.Schema[0].Type; + + public override ValueGetter> GetGetter() { ValueGetter> getter = _getter as ValueGetter>; if (getter == null) throw Ch.Except("Invalid TValue: '{0}'", typeof(TValue)); return getter; } + + public override bool MoveNext() + { + return _rowCursor.MoveNext(); + } + + public override int SlotIndex + { + get + { + long pos = _rowCursor.Position; + Contracts.Assert(pos <= int.MaxValue); + return (int)pos; + } + } + } private Transposer EnsureAndGetTransposer(int col) @@ -777,16 +792,16 @@ private Transposer EnsureAndGetTransposer(int col) return _colTransposers[col]; } - private sealed class Cursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { private readonly TransposeLoader _parent; private readonly int[] _actives; private readonly int[] _colToActivesIndex; - private readonly ICursor[] _transCursors; + private readonly SlotCursor[] _transCursors; private readonly Delegate[] _getters; private bool _disposed; - public Schema Schema => _parent.Schema; + public override Schema Schema => _parent.Schema; public override long Batch { get { return 0; } } @@ -802,7 +817,7 @@ public Cursor(TransposeLoader parent, Func pred) Ch.Assert(!_parent.HasRowData); Utils.BuildSubsetMaps(_parent._header.ColumnCount, pred, out _actives, out _colToActivesIndex); - _transCursors = new ICursor[_actives.Length]; + _transCursors = new SlotCursor[_actives.Length]; _getters = new Delegate[_actives.Length]; // The following will fill in both the _transCursors and _getters arrays. for (int i = 0; i < _actives.Length; ++i) @@ -841,7 +856,7 @@ private void InitOne(int col) var type = Schema.GetColumnType(col); Ch.Assert(typeof(T) == type.RawType); var trans = _parent.EnsureAndGetTransposer(col); - ISlotCursor cursor = trans.GetSlotCursor(0); + SlotCursor cursor = trans.GetSlotCursor(0); ValueGetter> getter = cursor.GetGetter(); VBuffer buff = default(VBuffer); ValueGetter oneGetter = @@ -862,7 +877,7 @@ private void InitVec(int col) Ch.Assert(type.IsVector); Ch.Assert(typeof(T) == type.ItemType.RawType); var trans = _parent.EnsureAndGetTransposer(col); - ISlotCursor cursor = trans.GetSlotCursor(0); + SlotCursor cursor = trans.GetSlotCursor(0); ValueGetter> getter = cursor.GetGetter(); int i = _colToActivesIndex[col]; _getters[i] = getter; @@ -892,25 +907,13 @@ protected override bool MoveNextCore() return more; } - protected override bool MoveManyCore(long count) - { - Ch.Assert(State != CursorState.Done); - bool more = Position < _parent._header.RowCount - count; - for (int i = 0; i < _transCursors.Length; ++i) - { - bool cMore = _transCursors[i].MoveMany(count); - Ch.Assert(cMore == more); - } - return more; - } - - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.CheckParam(0 <= col && col <= _colToActivesIndex.Length, nameof(col)); return _colToActivesIndex[col] >= 0; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.CheckParam(0 <= col && col <= _colToActivesIndex.Length, nameof(col)); Ch.CheckParam(IsColumnActive(col), nameof(col), "requested column not active"); diff --git a/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs b/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs index d20d832e88..e9a100bbef 100644 --- a/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs +++ b/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs @@ -146,7 +146,7 @@ private void CheckSchemaConsistency() return sum; } - public IRowCursor GetRowCursor(Func needCol, Random rand = null) + public RowCursor GetRowCursor(Func needCol, Random rand = null) { _host.CheckValue(needCol, nameof(needCol)); if (rand == null || !_canShuffle) @@ -154,20 +154,20 @@ public IRowCursor GetRowCursor(Func needCol, Random rand = null) return new RandCursor(this, needCol, rand, _counts); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { consolidator = null; - return new IRowCursor[] { GetRowCursor(predicate, rand) }; + return new RowCursor[] { GetRowCursor(predicate, rand) }; } - private abstract class CursorBase : RootCursorBase, IRowCursor + private abstract class CursorBase : RootCursorBase { protected readonly IDataView[] Sources; protected readonly Delegate[] Getters; public override long Batch => 0; - public Schema Schema { get; } + public sealed override Schema Schema { get; } public CursorBase(AppendRowsDataView parent) : base(parent._host) @@ -189,7 +189,7 @@ protected Delegate CreateGetter(int col) protected abstract ValueGetter CreateTypedGetter(int col); - public ValueGetter GetGetter(int col) + public sealed override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col), "The column must be active against the defined predicate."); if (!(Getters[col] is ValueGetter)) @@ -197,7 +197,7 @@ public ValueGetter GetGetter(int col) return Getters[col] as ValueGetter; } - public bool IsColumnActive(int col) + public sealed override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < Schema.ColumnCount, "Column index is out of range"); return Getters[col] != null; @@ -209,7 +209,7 @@ public bool IsColumnActive(int col) /// private sealed class Cursor : CursorBase { - private IRowCursor _currentCursor; + private RowCursor _currentCursor; private ValueGetter _currentIdGetter; private int _currentSourceIndex; @@ -299,7 +299,7 @@ public override void Dispose() /// private sealed class RandCursor : CursorBase { - private readonly IRowCursor[] _cursorSet; + private readonly RowCursor[] _cursorSet; private readonly MultinomialWithoutReplacementSampler _sampler; private readonly Random _rand; private int _currentSourceIndex; @@ -313,7 +313,7 @@ public RandCursor(AppendRowsDataView parent, Func needCol, Random ran _rand = rand; Ch.AssertValue(counts); Ch.Assert(Sources.Length == counts.Length); - _cursorSet = new IRowCursor[counts.Length]; + _cursorSet = new RowCursor[counts.Length]; for (int i = 0; i < counts.Length; i++) { Ch.Assert(counts[i] >= 0); @@ -374,7 +374,7 @@ public override void Dispose() if (State != CursorState.Done) { Ch.Dispose(); - foreach (IRowCursor c in _cursorSet) + foreach (RowCursor c in _cursorSet) c.Dispose(); base.Dispose(); } diff --git a/src/Microsoft.ML.Data/DataView/ArrayDataViewBuilder.cs b/src/Microsoft.ML.Data/DataView/ArrayDataViewBuilder.cs index c11d49923c..38e0e1b5bf 100644 --- a/src/Microsoft.ML.Data/DataView/ArrayDataViewBuilder.cs +++ b/src/Microsoft.ML.Data/DataView/ArrayDataViewBuilder.cs @@ -226,29 +226,29 @@ public DataView(IHostEnvironment env, ArrayDataViewBuilder builder, int rowCount _rowCount = rowCount; } - public IRowCursor GetRowCursor(Func predicate, Random rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); - return new RowCursor(_host, this, predicate, rand); + return new Cursor(_host, this, predicate, rand); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); consolidator = null; - return new IRowCursor[] { new RowCursor(_host, this, predicate, rand) }; + return new RowCursor[] { new Cursor(_host, this, predicate, rand) }; } - private sealed class RowCursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { private readonly DataView _view; private readonly BitArray _active; private readonly int[] _indices; - public Schema Schema => _view.Schema; + public override Schema Schema => _view.Schema; public override long Batch { @@ -256,7 +256,7 @@ public override long Batch get { return 0; } } - public RowCursor(IChannelProvider provider, DataView view, Func predicate, Random rand) + public Cursor(IChannelProvider provider, DataView view, Func predicate, Random rand) : base(provider) { Ch.AssertValue(view); @@ -298,13 +298,13 @@ public override ValueGetter GetIdGetter() } } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col & col < Schema.ColumnCount); return _active[col]; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(0 <= col & col < Schema.ColumnCount); Ch.Check(_active[col], "column is not active"); diff --git a/src/Microsoft.ML.Data/DataView/CacheDataView.cs b/src/Microsoft.ML.Data/DataView/CacheDataView.cs index 6d670db5ef..9f2e0ab447 100644 --- a/src/Microsoft.ML.Data/DataView/CacheDataView.cs +++ b/src/Microsoft.ML.Data/DataView/CacheDataView.cs @@ -203,7 +203,7 @@ public int MapInputToCacheColumnIndex(int inputIndex) return _rowCount; } - public IRowCursor GetRowCursor(Func predicate, Random rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); @@ -235,7 +235,7 @@ private int[] GetPermutationOrNull(Random rand) return Utils.GetRandomPermutation(rand, (int)_rowCount); } - private IRowCursor GetRowCursorWaiterCore(TWaiter waiter, Func predicate, Random rand) + private RowCursor GetRowCursorWaiterCore(TWaiter waiter, Func predicate, Random rand) where TWaiter : struct, IWaiter { _host.AssertValue(predicate); @@ -247,7 +247,7 @@ private IRowCursor GetRowCursorWaiterCore(TWaiter waiter, Func.Create(waiter, perm)); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); @@ -258,7 +258,7 @@ public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, if (n <= 1) { consolidator = null; - return new IRowCursor[] { GetRowCursor(predicate, rand) }; + return new RowCursor[] { GetRowCursor(predicate, rand) }; } consolidator = new Consolidator(); @@ -273,13 +273,13 @@ public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, /// private sealed class Consolidator : IRowCursorConsolidator { - public IRowCursor CreateCursor(IChannelProvider provider, IRowCursor[] inputs) + public RowCursor CreateCursor(IChannelProvider provider, RowCursor[] inputs) { return DataViewUtils.ConsolidateGeneric(provider, inputs, _batchSize); } } - private IRowCursor[] GetRowCursorSetWaiterCore(TWaiter waiter, Func predicate, int n, Random rand) + private RowCursor[] GetRowCursorSetWaiterCore(TWaiter waiter, Func predicate, int n, Random rand) where TWaiter : struct, IWaiter { _host.AssertValue(predicate); @@ -287,7 +287,7 @@ private IRowCursor[] GetRowCursorSetWaiterCore(TWaiter waiter, Func(TWaiter waiter, Func(Func predicate, TIndex index) + private RowCursor CreateCursor(Func predicate, TIndex index) where TIndex : struct, IIndex { Contracts.AssertValue(predicate); return new RowCursor(this, predicate, index); } - public IRowSeeker GetSeeker(Func predicate) + public RowSeeker GetSeeker(Func predicate) { _host.CheckValue(predicate, nameof(predicate)); // The seeker needs to know the row count when it validates the row index to move to. @@ -320,11 +320,11 @@ public IRowSeeker GetSeeker(Func predicate) return GetSeeker(predicate, waiter); } - private IRowSeeker GetSeeker(Func predicate, TWaiter waiter) + private RowSeeker GetSeeker(Func predicate, TWaiter waiter) where TWaiter : struct, IWaiter { _host.AssertValue(predicate); - return new RowSeeker(this, predicate, waiter); + return new RowSeeker(new RowSeekerCore(this, predicate, waiter)); } /// @@ -339,7 +339,7 @@ private void KickoffFiller(int[] columns) _host.AssertValue(columns); HashSet taskColumns = null; - IRowCursor cursor; + RowCursor cursor; ColumnCache[] caches; OrderedWaiter waiter; lock (_cacheLock) @@ -390,7 +390,7 @@ private void KickoffFiller(int[] columns) /// The caches we must fill and, at the end of the cursor, freeze /// The waiter to increment as we cache each additional row /// - private void Filler(IRowCursor cursor, ColumnCache[] caches, OrderedWaiter waiter) + private void Filler(RowCursor cursor, ColumnCache[] caches, OrderedWaiter waiter) { _host.AssertValue(cursor); _host.AssertValue(caches); @@ -464,15 +464,15 @@ internal void Wait() } } - private sealed class RowCursor : RowCursorSeekerBase, IRowCursor + private sealed class RowCursor : RowCursorSeekerBase where TIndex : struct, IIndex { private CursorState _state; private readonly TIndex _index; - public CursorState State { get { return _state; } } + public override CursorState State => _state; - public long Batch { get { return _index.Batch; } } + public override long Batch => _index.Batch; public RowCursor(CacheDataView parent, Func predicate, TIndex index) : base(parent, predicate) @@ -481,17 +481,11 @@ public RowCursor(CacheDataView parent, Func predicate, TIndex index) _index = index; } - public ValueGetter GetIdGetter() - { - return _index.GetIdGetter(); - } + public override ValueGetter GetIdGetter() => _index.GetIdGetter(); - public ICursor GetRootCursor() - { - return this; - } + public override RowCursor GetRootCursor() => this; - public bool MoveNext() + public override bool MoveNext() { if (_state == CursorState.Done) { @@ -502,7 +496,7 @@ public bool MoveNext() Ch.Assert(_state == CursorState.NotStarted || _state == CursorState.Good); if (_index.MoveNext()) { - Position++; + PositionCore++; Ch.Assert(Position >= 0); _state = CursorState.Good; return true; @@ -513,7 +507,7 @@ public bool MoveNext() return false; } - public bool MoveMany(long count) + public override bool MoveMany(long count) { // Note: If we decide to allow count == 0, then we need to special case // that MoveNext() has never been called. It's not entirely clear what the return @@ -529,7 +523,7 @@ public bool MoveMany(long count) Ch.Assert(_state == CursorState.NotStarted || _state == CursorState.Good); if (_index.MoveMany(count)) { - Position += count; + PositionCore += count; _state = CursorState.Good; Ch.Assert(Position >= 0); return true; @@ -556,14 +550,41 @@ protected override ValueGetter CreateGetterDelegateCore(ColumnCa } } - private sealed class RowSeeker : RowCursorSeekerBase, IRowSeeker - where TWaiter : struct, IWaiter + private sealed class RowSeeker : RowSeeker + where TWaiter : struct, IWaiter + { + private readonly RowSeekerCore _internal; + + public RowSeeker(RowSeekerCore toWrap) + { + Contracts.AssertValue(toWrap); + _internal = toWrap; + } + + public override long Position => _internal.Position; + public override long Batch => _internal.Batch; + public override Schema Schema => _internal.Schema; + + public override void Dispose() + { + } + + public override ValueGetter GetGetter(int col) => _internal.GetGetter(col); + public override ValueGetter GetIdGetter() => _internal.GetIdGetter(); + public override bool IsColumnActive(int col) => _internal.IsColumnActive(col); + public override bool MoveTo(long rowIndex) => _internal.MoveTo(rowIndex); + } + + private sealed class RowSeekerCore : RowCursorSeekerBase + where TWaiter : struct, IWaiter { private readonly TWaiter _waiter; - public long Batch { get { return 0; } } + public override long Batch => 0; - public ValueGetter GetIdGetter() + public override CursorState State => throw new NotImplementedException(); + + public override ValueGetter GetIdGetter() { return (ref UInt128 val) => @@ -573,7 +594,7 @@ public ValueGetter GetIdGetter() }; } - public RowSeeker(CacheDataView parent, Func predicate, TWaiter waiter) + public RowSeekerCore(CacheDataView parent, Func predicate, TWaiter waiter) : base(parent, predicate) { _waiter = waiter; @@ -585,11 +606,11 @@ public bool MoveTo(long rowIndex) { // If requested row index is out of range, the row seeker // returns false and sets its position to -1. - Position = -1; + PositionCore = -1; return false; } - Position = rowIndex; + PositionCore = rowIndex; return true; } @@ -601,6 +622,10 @@ protected override ValueGetter CreateGetterDelegateCore(ColumnCa { return (ref TValue value) => cache.Fetch((int)Position, ref value); } + + public override bool MoveNext() => throw Ch.ExceptNotSupp(); + public override bool MoveMany(long count) => throw Ch.ExceptNotSupp(); + public override RowCursor GetRootCursor() => throw Ch.ExceptNotSupp(); } private interface IWaiter @@ -675,7 +700,7 @@ private sealed class WaiterWaiter : IWaiter /// /// If this is true, then a could be used instead. /// - public bool IsTrivial { get { return _waiters.Length == 0; } } + public bool IsTrivial => _waiters.Length == 0; private WaiterWaiter(CacheDataView parent, Func pred) { @@ -722,7 +747,7 @@ public static Wrapper Create(CacheDataView parent, Func pred) { private readonly WaiterWaiter _waiter; - public bool IsTrivial { get { return _waiter.IsTrivial; } } + public bool IsTrivial => _waiter.IsTrivial; public Wrapper(WaiterWaiter waiter) { @@ -730,7 +755,7 @@ public Wrapper(WaiterWaiter waiter) _waiter = waiter; } - public bool Wait(long pos) { return _waiter.Wait(pos); } + public bool Wait(long pos) => _waiter.Wait(pos); } } @@ -758,7 +783,7 @@ private interface IIndex /// /// Moves to the next index. Once this or has returned /// false, it should never be called again. (This in constrast to public - /// objects, whose move methods are robust to that usage.) + /// objects, whose move methods are robust to that usage.) /// /// Whether the next index is available. bool MoveNext(); @@ -842,11 +867,11 @@ public Wrapper(SequenceIndex index) _index = index; } - public long Batch { get { return _index.Batch; } } - public long GetIndex() { return _index.GetIndex(); } - public ValueGetter GetIdGetter() { return _index.GetIdGetter(); } - public bool MoveNext() { return _index.MoveNext(); } - public bool MoveMany(long count) { return _index.MoveMany(count); } + public long Batch => _index.Batch; + public long GetIndex() => _index.GetIndex(); + public ValueGetter GetIdGetter() => _index.GetIdGetter(); + public bool MoveNext() => _index.MoveNext(); + public bool MoveMany(long count) => _index.MoveMany(count); } } @@ -933,11 +958,11 @@ public Wrapper(RandomIndex index) _index = index; } - public long Batch { get { return _index.Batch; } } - public long GetIndex() { return _index.GetIndex(); } - public ValueGetter GetIdGetter() { return _index.GetIdGetter(); } - public bool MoveNext() { return _index.MoveNext(); } - public bool MoveMany(long count) { return _index.MoveMany(count); } + public long Batch => _index.Batch; + public long GetIndex() => _index.GetIndex(); + public ValueGetter GetIdGetter() => _index.GetIdGetter(); + public bool MoveNext() => _index.MoveNext(); + public bool MoveMany(long count) => _index.MoveMany(count); } } @@ -1103,11 +1128,11 @@ public Wrapper(BlockSequenceIndex index) _index = index; } - public long Batch { get { return _index.Batch; } } - public long GetIndex() { return _index.GetIndex(); } - public ValueGetter GetIdGetter() { return _index.GetIdGetter(); } - public bool MoveNext() { return _index.MoveNext(); } - public bool MoveMany(long count) { return _index.MoveMany(count); } + public long Batch => _index.Batch; + public long GetIndex() => _index.GetIndex(); + public ValueGetter GetIdGetter() => _index.GetIdGetter(); + public bool MoveNext() => _index.MoveNext(); + public bool MoveMany(long count) => _index.MoveMany(count); } } @@ -1211,34 +1236,35 @@ public Wrapper(BlockRandomIndex index) _index = index; } - public long Batch { get { return _index.Batch; } } - public long GetIndex() { return _index.GetIndex(); } - public ValueGetter GetIdGetter() { return _index.GetIdGetter(); } - public bool MoveNext() { return _index.MoveNext(); } - public bool MoveMany(long count) { return _index.MoveMany(count); } + public long Batch => _index.Batch; + public long GetIndex() => _index.GetIndex(); + public ValueGetter GetIdGetter() => _index.GetIdGetter(); + public bool MoveNext() => _index.MoveNext(); + public bool MoveMany(long count) => _index.MoveMany(count); } } - private abstract class RowCursorSeekerBase : IDisposable + private abstract class RowCursorSeekerBase : RowCursor { protected readonly CacheDataView Parent; protected readonly IChannel Ch; + protected long PositionCore; private readonly int[] _colToActivesIndex; private readonly Delegate[] _getters; private bool _disposed; - public Schema Schema => Parent.Schema; + public sealed override Schema Schema => Parent.Schema; - public long Position { get; protected set; } + public sealed override long Position => PositionCore; protected RowCursorSeekerBase(CacheDataView parent, Func predicate) { Contracts.AssertValue(parent); Parent = parent; Ch = parent._host.Start("Cursor"); - Position = -1; + PositionCore = -1; // Set up the mapping from active columns. int colLim = Schema.ColumnCount; @@ -1259,24 +1285,24 @@ protected RowCursorSeekerBase(CacheDataView parent, Func predicate) } } - public bool IsColumnActive(int col) + public sealed override bool IsColumnActive(int col) { Ch.CheckParam(0 <= col && col < _colToActivesIndex.Length, nameof(col)); return _colToActivesIndex[col] >= 0; } - public void Dispose() + public sealed override void Dispose() { if (!_disposed) { DisposeCore(); - Position = -1; + PositionCore = -1; Ch.Dispose(); _disposed = true; } } - public ValueGetter GetGetter(int col) + public sealed override ValueGetter GetGetter(int col) { if (!IsColumnActive(col)) throw Ch.Except("Column #{0} is requested but not active in the cursor", col); @@ -1348,7 +1374,7 @@ protected ColumnCache(IExceptionContext ctx, OrderedWaiter waiter) /// The column of the cursor we are wrapping. /// The waiter for the filler associated with this column /// - public static ColumnCache Create(CacheDataView parent, IRowCursor input, int srcCol, OrderedWaiter waiter) + public static ColumnCache Create(CacheDataView parent, RowCursor input, int srcCol, OrderedWaiter waiter) { Contracts.AssertValue(parent); var host = parent._host; @@ -1368,7 +1394,7 @@ public static ColumnCache Create(CacheDataView parent, IRowCursor input, int src if (_pipeConstructorTypes == null) { Interlocked.CompareExchange(ref _pipeConstructorTypes, - new Type[] { typeof(CacheDataView), typeof(IRowCursor), typeof(int), typeof(OrderedWaiter) }, null); + new Type[] { typeof(CacheDataView), typeof(RowCursor), typeof(int), typeof(OrderedWaiter) }, null); } var constructor = pipeType.GetConstructor(_pipeConstructorTypes); return (ColumnCache)constructor.Invoke(new object[] { parent, input, srcCol, waiter }); @@ -1416,7 +1442,7 @@ private sealed class ImplVec : ColumnCache> // Temporary working reusable storage for caching the source data. private VBuffer _temp; - public ImplVec(CacheDataView parent, IRowCursor input, int srcCol, OrderedWaiter waiter) + public ImplVec(CacheDataView parent, RowCursor input, int srcCol, OrderedWaiter waiter) : base(parent, input, srcCol, waiter) { var type = input.Schema.GetColumnType(srcCol); @@ -1499,7 +1525,7 @@ private sealed class ImplOne : ColumnCache private T[] _values; private ValueGetter _getter; - public ImplOne(CacheDataView parent, IRowCursor input, int srcCol, OrderedWaiter waiter) + public ImplOne(CacheDataView parent, RowCursor input, int srcCol, OrderedWaiter waiter) : base(parent, input, srcCol, waiter) { _getter = input.GetGetter(srcCol); @@ -1534,7 +1560,7 @@ public override void Freeze() private abstract class ColumnCache : ColumnCache { - public ColumnCache(CacheDataView parent, IRowCursor input, int srcCol, OrderedWaiter waiter) + public ColumnCache(CacheDataView parent, RowCursor input, int srcCol, OrderedWaiter waiter) : base(parent._host, waiter) { Contracts.AssertValue(input); diff --git a/src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs b/src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs index eb319cb7bc..326a750a26 100644 --- a/src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs +++ b/src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs @@ -43,7 +43,7 @@ public Func GetDependencies(Func predicate) return toReturn; } - public IRow GetRow(IRow input, Func active, out Action disposer) + public Row GetRow(Row input, Func active, out Action disposer) { Contracts.CheckValue(input, nameof(input)); Contracts.CheckValue(active, nameof(active)); @@ -73,7 +73,7 @@ public IRow GetRow(IRow input, Func active, out Action disposer) for (int i = deps.Length - 1; i >= 1; --i) deps[i - 1] = InnerMappers[i].GetDependencies(deps[i]); - IRow result = input; + Row result = input; for (int i = 0; i < InnerMappers.Length; ++i) { result = InnerMappers[i].GetRow(result, deps[i], out var localDisp); @@ -90,12 +90,12 @@ public IRow GetRow(IRow input, Func active, out Action disposer) return result; } - private sealed class SubsetActive : IRow + private sealed class SubsetActive : Row { - private readonly IRow _row; + private readonly Row _row; private Func _pred; - public SubsetActive(IRow row, Func pred) + public SubsetActive(Row row, Func pred) { Contracts.AssertValue(row); Contracts.AssertValue(pred); @@ -103,12 +103,12 @@ public SubsetActive(IRow row, Func pred) _pred = pred; } - public Schema Schema => _row.Schema; - public long Position => _row.Position; - public long Batch => _row.Batch; - public ValueGetter GetGetter(int col) => _row.GetGetter(col); - public ValueGetter GetIdGetter() => _row.GetIdGetter(); - public bool IsColumnActive(int col) => _pred(col); + public override Schema Schema => _row.Schema; + public override long Position => _row.Position; + public override long Batch => _row.Batch; + public override ValueGetter GetGetter(int col) => _row.GetGetter(col); + public override ValueGetter GetIdGetter() => _row.GetIdGetter(); + public override bool IsColumnActive(int col) => _pred(col); } } } diff --git a/src/Microsoft.ML.Data/DataView/EmptyDataView.cs b/src/Microsoft.ML.Data/DataView/EmptyDataView.cs index 90aedca48a..466da6b098 100644 --- a/src/Microsoft.ML.Data/DataView/EmptyDataView.cs +++ b/src/Microsoft.ML.Data/DataView/EmptyDataView.cs @@ -28,14 +28,14 @@ public EmptyDataView(IHostEnvironment env, Schema schema) public long? GetRowCount() => 0; - public IRowCursor GetRowCursor(Func needCol, Random rand = null) + public RowCursor GetRowCursor(Func needCol, Random rand = null) { _host.CheckValue(needCol, nameof(needCol)); _host.CheckValueOrNull(rand); return new Cursor(_host, Schema, needCol); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func needCol, int n, Random rand = null) + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func needCol, int n, Random rand = null) { _host.CheckValue(needCol, nameof(needCol)); _host.CheckValueOrNull(rand); @@ -43,11 +43,11 @@ public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Fun return new[] { new Cursor(_host, Schema, needCol) }; } - private sealed class Cursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { private readonly bool[] _active; - public Schema Schema { get; } + public override Schema Schema { get; } public override long Batch => 0; public Cursor(IChannelProvider provider, Schema schema, Func needCol) @@ -71,9 +71,9 @@ public override ValueGetter GetIdGetter() protected override bool MoveNextCore() => false; - public bool IsColumnActive(int col) => 0 <= col && col < _active.Length && _active[col]; + public override bool IsColumnActive(int col) => 0 <= col && col < _active.Length && _active[col]; - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col), "Can't get getter for inactive column"); return diff --git a/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs b/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs index 817ba98ed1..e5b5ae1654 100644 --- a/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs +++ b/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs @@ -150,7 +150,7 @@ protected override ColumnType GetColumnTypeCore(int iinfo) return _typeDst; } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + protected override Delegate GetGetterCore(IChannel ch, Row input, int iinfo, out Action disposer) { Host.AssertValueOrNull(ch); Host.AssertValue(input); diff --git a/src/Microsoft.ML.Data/DataView/LambdaFilter.cs b/src/Microsoft.ML.Data/DataView/LambdaFilter.cs index 0cec2a7a58..92e39caede 100644 --- a/src/Microsoft.ML.Data/DataView/LambdaFilter.cs +++ b/src/Microsoft.ML.Data/DataView/LambdaFilter.cs @@ -106,7 +106,7 @@ public override void Save(ModelSaveContext ctx) return null; } - protected override IRowCursor GetRowCursorCore(Func predicate, Random rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Host.AssertValue(predicate, "predicate"); Host.AssertValueOrNull(rand); @@ -114,10 +114,10 @@ protected override IRowCursor GetRowCursorCore(Func predicate, Random bool[] active; Func inputPred = GetActive(predicate, out active); var input = Source.GetRowCursor(inputPred, rand); - return new RowCursor(this, input, active); + return new Cursor(this, input, active); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, + public override RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); @@ -129,9 +129,9 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid Host.AssertNonEmpty(inputs); // No need to split if this is given 1 input cursor. - var cursors = new IRowCursor[inputs.Length]; + var cursors = new RowCursor[inputs.Length]; for (int i = 0; i < inputs.Length; i++) - cursors[i] = new RowCursor(this, inputs[i], active); + cursors[i] = new Cursor(this, inputs[i], active); return cursors; } @@ -147,13 +147,13 @@ private Func GetActive(Func predicate, out bool[] active) } // REVIEW: Should this cache the source value like MissingValueFilter does? - private sealed class RowCursor : LinkedRowFilterCursorBase + private sealed class Cursor : LinkedRowFilterCursorBase { private readonly ValueGetter _getSrc; private readonly InPredicate _pred; private T1 _src; - public RowCursor(Impl parent, IRowCursor input, bool[] active) + public Cursor(Impl parent, RowCursor input, bool[] active) : base(parent.Host, input, parent.OutputSchema, active) { _getSrc = Input.GetGetter(parent._colSrc); diff --git a/src/Microsoft.ML.Data/DataView/OpaqueDataView.cs b/src/Microsoft.ML.Data/DataView/OpaqueDataView.cs index 7d93533625..25613791d0 100644 --- a/src/Microsoft.ML.Data/DataView/OpaqueDataView.cs +++ b/src/Microsoft.ML.Data/DataView/OpaqueDataView.cs @@ -27,12 +27,12 @@ public OpaqueDataView(IDataView source) return _source.GetRowCount(); } - public IRowCursor GetRowCursor(Func predicate, Random rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { return _source.GetRowCursor(predicate, rand); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { return _source.GetRowCursorSet(out consolidator, predicate, n, rand); diff --git a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs index d5ccb2d063..a4839d710b 100644 --- a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs +++ b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs @@ -38,7 +38,7 @@ public interface IRowMapper : ICanSaveModel /// array should be equal to the number of columns added by the IRowMapper. It should contain the getter for the /// i'th output column if activeOutput(i) is true, and null otherwise. /// - Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer); + Delegate[] CreateGetters(Row input, Func activeOutput, out Action disposer); /// /// Returns information about the output columns, including their name, type and any metadata information. @@ -178,14 +178,14 @@ private Func GetActiveOutputColumns(bool[] active) return null; } - protected override IRowCursor GetRowCursorCore(Func predicate, Random rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Func predicateInput; var active = GetActive(predicate, out predicateInput); - return new RowCursor(Host, Source.GetRowCursor(predicateInput, rand), this, active); + return new Cursor(Host, Source.GetRowCursor(predicateInput, rand), this, active); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) + public override RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); @@ -200,9 +200,9 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid inputs = DataViewUtils.CreateSplitCursors(out consolidator, Host, inputs[0], n); Host.AssertNonEmpty(inputs); - var cursors = new IRowCursor[inputs.Length]; + var cursors = new RowCursor[inputs.Length]; for (int i = 0; i < inputs.Length; i++) - cursors[i] = new RowCursor(Host, inputs[i], this, active); + cursors[i] = new Cursor(Host, inputs[i], this, active); return cursors; } @@ -235,7 +235,7 @@ public Func GetDependencies(Func predicate) public Schema InputSchema => Source.Schema; - public IRow GetRow(IRow input, Func active, out Action disposer) + public Row GetRow(Row input, Func active, out Action disposer) { Host.CheckValue(input, nameof(input)); Host.CheckValue(active, nameof(active)); @@ -251,7 +251,7 @@ public IRow GetRow(IRow input, Func active, out Action disposer) var pred = GetActiveOutputColumns(activeArr); var getters = _mapper.CreateGetters(input, pred, out disp); disposer += disp; - return new Row(input, this, OutputSchema, getters); + return new RowImpl(input, this, OutputSchema, getters); } } @@ -285,20 +285,20 @@ public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource) } } - private sealed class Row : IRow + private sealed class RowImpl : Row { - private readonly IRow _input; + private readonly Row _input; private readonly Delegate[] _getters; private readonly RowToRowMapperTransform _parent; - public long Batch { get { return _input.Batch; } } + public override long Batch => _input.Batch; - public long Position { get { return _input.Position; } } + public override long Position => _input.Position; - public Schema Schema { get; } + public override Schema Schema { get; } - public Row(IRow input, RowToRowMapperTransform parent, Schema schema, Delegate[] getters) + public RowImpl(Row input, RowToRowMapperTransform parent, Schema schema, Delegate[] getters) { _input = input; _parent = parent; @@ -306,7 +306,7 @@ public Row(IRow input, RowToRowMapperTransform parent, Schema schema, Delegate[] _getters = getters; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { bool isSrc; int index = _parent._bindings.MapColumnIndex(out isSrc, col); @@ -320,9 +320,9 @@ public ValueGetter GetGetter(int col) return fn; } - public ValueGetter GetIdGetter() => _input.GetIdGetter(); + public override ValueGetter GetIdGetter() => _input.GetIdGetter(); - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { bool isSrc; int index = _parent._bindings.MapColumnIndex(out isSrc, col); @@ -332,16 +332,16 @@ public bool IsColumnActive(int col) } } - private sealed class RowCursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : SynchronizedCursorBase { private readonly Delegate[] _getters; private readonly bool[] _active; private readonly ColumnBindings _bindings; private readonly Action _disposer; - public Schema Schema => _bindings.Schema; + public override Schema Schema => _bindings.Schema; - public RowCursor(IChannelProvider provider, IRowCursor input, RowToRowMapperTransform parent, bool[] active) + public Cursor(IChannelProvider provider, RowCursor input, RowToRowMapperTransform parent, bool[] active) : base(provider, input) { var pred = parent.GetActiveOutputColumns(active); @@ -350,13 +350,13 @@ public RowCursor(IChannelProvider provider, IRowCursor input, RowToRowMapperTran _bindings = parent._bindings; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < _bindings.Schema.ColumnCount); return _active[col]; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col)); diff --git a/src/Microsoft.ML.Data/DataView/SimpleRow.cs b/src/Microsoft.ML.Data/DataView/SimpleRow.cs index 7cc8ec7059..55f67e6de2 100644 --- a/src/Microsoft.ML.Data/DataView/SimpleRow.cs +++ b/src/Microsoft.ML.Data/DataView/SimpleRow.cs @@ -11,41 +11,40 @@ namespace Microsoft.ML.Runtime.Data { /// - /// An implementation of that gets its , , - /// and from an input row. The constructor requires a schema and array of getter + /// An implementation of that gets its , , + /// and from an input row. The constructor requires a schema and array of getter /// delegates. A null delegate indicates an inactive column. The delegates are assumed to be of the appropriate type /// (this does not validate the type). /// REVIEW: Should this validate that the delegates are of the appropriate type? It wouldn't be difficult /// to do so. /// - public sealed class SimpleRow : IRow + public sealed class SimpleRow : Row { - private readonly Schema _schema; - private readonly IRow _input; + private readonly Row _input; private readonly Delegate[] _getters; - public Schema Schema { get { return _schema; } } + public override Schema Schema { get; } - public long Position { get { return _input.Position; } } + public override long Position => _input.Position; - public long Batch { get { return _input.Batch; } } + public override long Batch => _input.Batch; - public SimpleRow(Schema schema, IRow input, Delegate[] getters) + public SimpleRow(Schema schema, Row input, Delegate[] getters) { Contracts.CheckValue(schema, nameof(schema)); Contracts.CheckValue(input, nameof(input)); Contracts.Check(Utils.Size(getters) == schema.ColumnCount); - _schema = schema; + Schema = schema; _input = input; _getters = getters ?? new Delegate[0]; } - public ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { return _input.GetIdGetter(); } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Contracts.CheckParam(0 <= col && col < _getters.Length, nameof(col), "Invalid col value in GetGetter"); Contracts.Check(IsColumnActive(col)); @@ -55,7 +54,7 @@ public ValueGetter GetGetter(int col) return fn; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Contracts.Check(0 <= col && col < _getters.Length); return _getters[col] != null; diff --git a/src/Microsoft.ML.Data/DataView/Transposer.cs b/src/Microsoft.ML.Data/DataView/Transposer.cs index 2befe3a726..e6ddcd57b5 100644 --- a/src/Microsoft.ML.Data/DataView/Transposer.cs +++ b/src/Microsoft.ML.Data/DataView/Transposer.cs @@ -229,7 +229,7 @@ private static int[] CheckIndices(IHost host, IDataView view, int[] columns) return columns; } - public ISlotCursor GetSlotCursor(int col) + public SlotCursor GetSlotCursor(int col) { _host.CheckParam(0 <= col && col < _tschema.ColumnCount, nameof(col)); if (_inputToTransposed[col] == -1) @@ -249,7 +249,7 @@ public ISlotCursor GetSlotCursor(int col) return Utils.MarshalInvoke(GetSlotCursorCore, type, col); } - private ISlotCursor GetSlotCursorCore(int col) + private SlotCursor GetSlotCursorCore(int col) { if (_tschema.GetColumnType(col).IsVector) return new SlotCursorVec(this, col); @@ -265,12 +265,12 @@ private ISlotCursor GetSlotCursorCore(int col) public bool CanShuffle { get { return _view.CanShuffle; } } - public IRowCursor GetRowCursor(Func predicate, Random rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { return _view.GetRowCursor(predicate, rand); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { return _view.GetRowCursorSet(out consolidator, predicate, n, rand); } @@ -357,14 +357,12 @@ public VectorType GetSlotType(int col) } } - private abstract class SlotCursor : RootCursorBase, ISlotCursor + private abstract class SlotCursor : SlotCursor.RootSlotCursor { private readonly Transposer _parent; private readonly int _col; private ValueGetter> _getter; - public override long Batch { get { return 0; } } - protected SlotCursor(Transposer parent, int col) : base(parent._host) { @@ -373,17 +371,7 @@ protected SlotCursor(Transposer parent, int col) _col = col; } - public override ValueGetter GetIdGetter() - { - return - (ref UInt128 val) => - { - Ch.Check(IsGood, "Cannot call ID getter in current state"); - val = new UInt128((ulong)Position, 0); - }; - } - - public ValueGetter> GetGetter() + public override ValueGetter> GetGetter() { if (_getter == null) _getter = GetGetterCore(); @@ -393,7 +381,7 @@ public ValueGetter> GetGetter() return getter; } - public VectorType GetSlotType() + public override VectorType GetSlotType() { return _parent.TransposeSchema.GetSlotType(_col); } @@ -406,6 +394,7 @@ private sealed class SlotCursorOne : SlotCursor private readonly IDataView _view; private readonly int _col; private readonly int _len; + private bool _moved; public SlotCursorOne(Transposer parent, int col) : base(parent, col) @@ -435,7 +424,7 @@ public SlotCursorOne(Transposer parent, int col) protected override bool MoveNextCore() { // We only can move next on one slot, since this is a scalar column. - return State == CursorState.NotStarted; + return _moved = !_moved; } protected override ValueGetter> GetGetterCore() @@ -577,7 +566,7 @@ public SlotCursorVec(Transposer parent, int col) /// private void EnsureValid() { - Ch.Check(State == CursorState.Good, "Cursor is not in good state, cannot get values"); + Ch.Check(IsGood, "Cursor is not in good state, cannot get values"); Ch.Assert(_slotCurr >= 0); if (_colStored == _colCurr) return; @@ -867,7 +856,7 @@ private void OutputColumnToSplitterIndices(int col, out int splitInd, out int sp splitCol = _colToSplitCol[col]; } - public IRowCursor GetRowCursor(Func predicate, Random rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); bool[] activeSplitters; @@ -875,7 +864,7 @@ public IRowCursor GetRowCursor(Func predicate, Random rand = null) return new Cursor(_host, this, _input.GetRowCursor(srcPred, rand), predicate, activeSplitters); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); @@ -1000,7 +989,7 @@ public void GetMetadata(string kind, int col, ref TValue value) /// There is one instance of these per column, implementing the possible splitting /// of one column from a into multiple columns. The instance /// describes the resulting split columns through its implementation of - /// , and then can be bound to an to provide + /// , and then can be bound to an to provide /// that splitting functionality. /// private abstract class Splitter : NoMetadataSchema @@ -1059,10 +1048,10 @@ public static Splitter Create(IDataView view, int col) } /// - /// Given an input , create the containing the split + /// Given an input , create the containing the split /// version of the columns. /// - public abstract IRow Bind(IRow row, Func pred); + public abstract Row Bind(Row row, Func pred); private static Splitter CreateCore(IDataView view, int col) { @@ -1097,17 +1086,17 @@ public override string GetColumnName(int col) } #endregion - private abstract class RowBase : IRow + private abstract class RowBase : Row where TSplitter : Splitter { protected readonly TSplitter Parent; - protected readonly IRow Input; + protected readonly Row Input; - public Schema Schema => Parent.AsSchema; - public long Position => Input.Position; - public long Batch => Input.Batch; + public sealed override Schema Schema => Parent.AsSchema; + public sealed override long Position => Input.Position; + public sealed override long Batch => Input.Batch; - public RowBase(TSplitter parent, IRow input) + public RowBase(TSplitter parent, Row input) { Contracts.AssertValue(parent); Contracts.AssertValue(input); @@ -1116,14 +1105,10 @@ public RowBase(TSplitter parent, IRow input) Input = input; } - public ValueGetter GetIdGetter() + public sealed override ValueGetter GetIdGetter() { return Input.GetIdGetter(); } - - public abstract bool IsColumnActive(int col); - - public abstract ValueGetter GetGetter(int col); } /// @@ -1150,20 +1135,20 @@ public override ColumnType GetColumnType(int col) return _view.Schema.GetColumnType(SrcCol); } - public override IRow Bind(IRow row, Func pred) + public override Row Bind(Row row, Func pred) { Contracts.AssertValue(row); Contracts.Assert(row.Schema == _view.Schema); Contracts.AssertValue(pred); Contracts.Assert(row.IsColumnActive(SrcCol)); - return new Row(this, row, pred(0)); + return new RowImpl(this, row, pred(0)); } - private sealed class Row : RowBase> + private sealed class RowImpl : RowBase> { private readonly bool _isActive; - public Row(NoSplitter parent, IRow input, bool isActive) + public RowImpl(NoSplitter parent, Row input, bool isActive) : base(parent, input) { Contracts.Assert(Parent.ColumnCount == 1); @@ -1236,16 +1221,16 @@ public override ColumnType GetColumnType(int col) return _types[col]; } - public override IRow Bind(IRow row, Func pred) + public override Row Bind(Row row, Func pred) { Contracts.AssertValue(row); Contracts.Assert(row.Schema == _view.Schema); Contracts.AssertValue(pred); Contracts.Assert(row.IsColumnActive(SrcCol)); - return new Row(this, row, pred); + return new RowImpl(this, row, pred); } - private sealed class Row : RowBase> + private sealed class RowImpl : RowBase> { // Counter of the last valid input, updated by EnsureValid. private long _lastValid; @@ -1260,7 +1245,7 @@ private sealed class Row : RowBase> // Getters. private readonly ValueGetter>[] _getters; - public Row(ColumnSplitter parent, IRow input, Func pred) + public RowImpl(ColumnSplitter parent, Row input, Func pred) : base(parent, input) { _inputGetter = input.GetGetter>(Parent.SrcCol); @@ -1367,17 +1352,17 @@ private void EnsureValid() } /// - /// The cursor implementation creates the s using , + /// The cursor implementation creates the s using , /// then collates the results from those rows as effectively one big row. /// - private sealed class Cursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : SynchronizedCursorBase { private readonly DataViewSlicer _slicer; - private readonly IRow[] _sliceRows; + private readonly Row[] _sliceRows; - public Schema Schema => _slicer.Schema; + public override Schema Schema => _slicer.Schema; - public Cursor(IChannelProvider provider, DataViewSlicer slicer, IRowCursor input, Func pred, bool[] activeSplitters) + public Cursor(IChannelProvider provider, DataViewSlicer slicer, RowCursor input, Func pred, bool[] activeSplitters) : base(provider, input) { Ch.AssertValue(slicer); @@ -1385,7 +1370,7 @@ public Cursor(IChannelProvider provider, DataViewSlicer slicer, IRowCursor input Ch.Assert(Utils.Size(activeSplitters) == slicer._splitters.Length); _slicer = slicer; - _sliceRows = new IRow[_slicer._splitters.Length]; + _sliceRows = new Row[_slicer._splitters.Length]; var activeSrc = new bool[slicer._splitters.Length]; var activeSrcSet = new HashSet(); int offset = 0; @@ -1403,7 +1388,7 @@ public Cursor(IChannelProvider provider, DataViewSlicer slicer, IRowCursor input } } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < Schema.ColumnCount, "col"); int splitInd; @@ -1412,7 +1397,7 @@ public bool IsColumnActive(int col) return _sliceRows[splitInd] != null && _sliceRows[splitInd].IsColumnActive(splitCol); } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col)); int splitInd; @@ -1446,7 +1431,7 @@ public static void GetSingleSlotValue(this ITransposeDataView view, int col, } /// - /// The is parameterized by a type that becomes the + /// The is parameterized by a type that becomes the /// type parameter for a , and this is generally preferable and more /// sensible but for various reasons it's often a lot simpler to have a get-getter be over /// the actual type returned by the getter, that is, parameterize this by the actual @@ -1457,7 +1442,7 @@ public static void GetSingleSlotValue(this ITransposeDataView view, int col, /// The cursor to get the getter for /// The exception contxt /// The value getter - public static ValueGetter GetGetterWithVectorType(this ISlotCursor cursor, IExceptionContext ctx = null) + public static ValueGetter GetGetterWithVectorType(this SlotCursor cursor, IExceptionContext ctx = null) { Contracts.CheckValueOrNull(ctx); ctx.CheckValue(cursor, nameof(cursor)); @@ -1478,15 +1463,15 @@ public static ValueGetter GetGetterWithVectorType(this ISlotCurs /// /// Given a slot cursor, construct a single-column equivalent row cursor, with the single column /// active and having the same type. This is useful to exploit the many utility methods that exist - /// to handle and but that know nothing about - /// , without having to rewrite all of them. This is, however, rather + /// to handle and but that know nothing about + /// , without having to rewrite all of them. This is, however, rather /// something of a hack; whenever possible or reasonable the slot cursor should be used directly. /// The name of this column is always "Waffles". /// /// The channel provider used in creating the wrapping row cursor /// The slot cursor to wrap /// A row cursor with a single active column with the same type as the slot type - public static IRowCursor GetRowCursorShim(IChannelProvider provider, ISlotCursor cursor) + public static RowCursor GetRowCursorShim(IChannelProvider provider, SlotCursor cursor) { Contracts.CheckValue(provider, nameof(provider)); provider.CheckValue(cursor, nameof(cursor)); @@ -1494,7 +1479,7 @@ public static IRowCursor GetRowCursorShim(IChannelProvider provider, ISlotCursor return Utils.MarshalInvoke(GetRowCursorShimCore, cursor.GetSlotType().ItemType.RawType, provider, cursor); } - private static IRowCursor GetRowCursorShimCore(IChannelProvider provider, ISlotCursor cursor) + private static RowCursor GetRowCursorShimCore(IChannelProvider provider, SlotCursor cursor) { return new SlotRowCursorShim(provider, cursor); } @@ -1508,11 +1493,10 @@ public sealed class SlotDataView : IDataView private readonly ITransposeDataView _data; private readonly int _col; private readonly ColumnType _type; - private readonly SchemaImpl _schemaImpl; - public Schema Schema => _schemaImpl.AsSchema; + public Schema Schema { get; } - public bool CanShuffle { get { return false; } } + public bool CanShuffle => false; public SlotDataView(IHostEnvironment env, ITransposeDataView data, int col) { @@ -1525,7 +1509,10 @@ public SlotDataView(IHostEnvironment env, ITransposeDataView data, int col) _data = data; _col = col; - _schemaImpl = new SchemaImpl(this); + + var builder = new SchemaBuilder(); + builder.AddColumn(_data.Schema[_col].Name, _type, null); + Schema = builder.GetSchema(); } public long? GetRowCount() @@ -1536,113 +1523,50 @@ public SlotDataView(IHostEnvironment env, ITransposeDataView data, int col) return valueCount; } - public IRowCursor GetRowCursor(Func predicate, Random rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); return Utils.MarshalInvoke(GetRowCursor, _type.ItemType.RawType, predicate(0)); } - private IRowCursor GetRowCursor(bool active) + private RowCursor GetRowCursor(bool active) { return new Cursor(this, active); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); consolidator = null; - return new IRowCursor[] { GetRowCursor(predicate, rand) }; + return new RowCursor[] { GetRowCursor(predicate, rand) }; } - private sealed class SchemaImpl : ISchema - { - private readonly SlotDataView _parent; - - private IHost Host { get { return _parent._host; } } - - public Schema AsSchema { get; } - - public int ColumnCount { get { return 1; } } - - public SchemaImpl(SlotDataView parent) - { - Contracts.AssertValue(parent); - _parent = parent; - AsSchema = Schema.Create(this); - } - - public ColumnType GetColumnType(int col) - { - Host.CheckParam(col == 0, nameof(col)); - return _parent._type; - } - - public string GetColumnName(int col) - { - Host.CheckParam(col == 0, nameof(col)); - // There is no real need for this to have the real name as the internal IDV - // substream does not have its name accessed, but we'll save it just the same. - // I am tempted though to just have this thing always claim its name is 'Pancakes'. - return _parent._data.Schema.GetColumnName(_parent._col); - } - - public bool TryGetColumnIndex(string name, out int col) - { - if (name == GetColumnName(0)) - { - col = 0; - return true; - } - col = -1; - return false; - } - - // No metadata. The top level IDV will hold the schema information, including metadata. - // This per-column dataview schema information is just minimally functional. - - public IEnumerable> GetMetadataTypes(int col) - { - Host.CheckParam(col == 0, nameof(col)); - return Enumerable.Empty>(); - } - - public ColumnType GetMetadataTypeOrNull(string kind, int col) - { - Host.CheckNonEmpty(kind, nameof(kind)); - Host.CheckParam(col == 0, nameof(col)); - return null; - } - - public void GetMetadata(string kind, int col, ref TValue value) - { - Host.CheckNonEmpty(kind, nameof(kind)); - Host.CheckParam(col == 0, nameof(col)); - throw MetadataUtils.ExceptGetMetadata(); - } - } - - private sealed class Cursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { private readonly SlotDataView _parent; + private readonly SlotCursor _slotCursor; private readonly Delegate _getter; - public Schema Schema => _parent.Schema; + public override Schema Schema => _parent.Schema; + + public override long Batch => 0; public Cursor(SlotDataView parent, bool active) - : base(parent._host, parent._data.GetSlotCursor(parent._col)) + : base(parent._host) { _parent = parent; + _slotCursor = _parent._data.GetSlotCursor(parent._col); if (active) - _getter = Input.GetGetter(); + _getter = _slotCursor.GetGetter(); } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.CheckParam(col == 0, nameof(col)); return _getter != null; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.CheckParam(col == 0, nameof(col)); Ch.CheckParam(_getter != null, nameof(col), "requested column not active"); @@ -1652,98 +1576,61 @@ public ValueGetter GetGetter(int col) throw Ch.Except("Invalid TValue: '{0}'", typeof(TValue)); return getter; } - } - } - - // REVIEW: This shim class is very similar to the above shim class, except at the - // cursor level, not the cursorable level. Is there some non-horrifying way to unify both, somehow? - private sealed class SlotRowCursorShim : SynchronizedCursorBase, IRowCursor - { - private readonly SchemaImpl _schema; - - public Schema Schema => _schema.AsSchema; - - private sealed class SchemaImpl : ISchema - { - private readonly SlotRowCursorShim _parent; - private readonly VectorType _type; - private IChannel Ch { get { return _parent.Ch; } } + public override ValueGetter GetIdGetter() => GetId; - public Schema AsSchema { get; } - - public int ColumnCount { get { return 1; } } - - public SchemaImpl(SlotRowCursorShim parent, VectorType slotType) + private void GetId(ref UInt128 id) { - Contracts.AssertValue(parent); - _parent = parent; - Ch.AssertValue(slotType); - _type = slotType; - AsSchema = Schema.Create(this); + Ch.Check(_slotCursor.SlotIndex >= 0, "Cannot get ID with cursor in current state."); + id = new UInt128((ulong)_slotCursor.SlotIndex, 0); } - public ColumnType GetColumnType(int col) - { - Ch.CheckParam(col == 0, nameof(col)); - return _type; - } + protected override bool MoveNextCore() => _slotCursor.MoveNext(); + } + } - public string GetColumnName(int col) - { - Ch.CheckParam(col == 0, nameof(col)); - return "Waffles"; - } + // REVIEW: This shim class is very similar to the above shim class, except at the + // cursor level, not the cursorable level. Is there some non-horrifying way to unify both, somehow? + private sealed class SlotRowCursorShim : RootCursorBase + { + private readonly SlotCursor _slotCursor; - public bool TryGetColumnIndex(string name, out int col) - { - if (name == GetColumnName(0)) - { - col = 0; - return true; - } - col = -1; - return false; - } + public override Schema Schema { get; } - public IEnumerable> GetMetadataTypes(int col) - { - Ch.CheckParam(col == 0, nameof(col)); - return Enumerable.Empty>(); - } + public override long Batch => 0; - public ColumnType GetMetadataTypeOrNull(string kind, int col) - { - Ch.CheckNonEmpty(kind, nameof(kind)); - Ch.CheckParam(col == 0, nameof(col)); - return null; - } + public SlotRowCursorShim(IChannelProvider provider, SlotCursor cursor) + : base(provider) + { + Contracts.AssertValue(cursor); - public void GetMetadata(string kind, int col, ref TValue value) - { - Ch.CheckNonEmpty(kind, nameof(kind)); - Ch.CheckParam(col == 0, nameof(col)); - throw MetadataUtils.ExceptGetMetadata(); - } + _slotCursor = cursor; + var builder = new SchemaBuilder(); + builder.AddColumn("Waffles", cursor.GetSlotType(), null); + Schema = builder.GetSchema(); } - public SlotRowCursorShim(IChannelProvider provider, ISlotCursor cursor) - : base(provider, cursor) + public override bool IsColumnActive(int col) { - _schema = new SchemaImpl(this, Input.GetSlotType()); + Ch.CheckParam(col == 0, nameof(col)); + return true; } - public bool IsColumnActive(int col) + public override ValueGetter GetGetter(int col) { Ch.CheckParam(col == 0, nameof(col)); - return true; + return _slotCursor.GetGetterWithVectorType(Ch); } - public ValueGetter GetGetter(int col) + public override ValueGetter GetIdGetter() => GetId; + + private void GetId(ref UInt128 id) { - Ch.CheckParam(col == 0, nameof(col)); - return Input.GetGetterWithVectorType(Ch); + Ch.Check(_slotCursor.SlotIndex >= 0, "Cannot get ID with cursor in current state."); + id = new UInt128((ulong)_slotCursor.SlotIndex, 0); } + + protected override bool MoveNextCore() => _slotCursor.MoveNext(); } /// diff --git a/src/Microsoft.ML.Data/DataView/ZipDataView.cs b/src/Microsoft.ML.Data/DataView/ZipDataView.cs index 6bd6356b2a..7b344a02be 100644 --- a/src/Microsoft.ML.Data/DataView/ZipDataView.cs +++ b/src/Microsoft.ML.Data/DataView/ZipDataView.cs @@ -71,7 +71,7 @@ private ZipDataView(IHost host, IDataView[] sources) return min; } - public IRowCursor GetRowCursor(Func predicate, Random rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); @@ -89,31 +89,31 @@ public IRowCursor GetRowCursor(Func predicate, Random rand = null) } /// - /// Create an with no requested columns on a data view. + /// Create an with no requested columns on a data view. /// Potentially, this can be optimized by calling GetRowCount(lazy:true) first, and if the count is not known, /// wrapping around GetCursor(). /// - private IRowCursor GetMinimumCursor(IDataView dv) + private RowCursor GetMinimumCursor(IDataView dv) { _host.AssertValue(dv); return dv.GetRowCursor(x => false); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { consolidator = null; - return new IRowCursor[] { GetRowCursor(predicate, rand) }; + return new RowCursor[] { GetRowCursor(predicate, rand) }; } - private sealed class Cursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { - private readonly IRowCursor[] _cursors; + private readonly RowCursor[] _cursors; private readonly CompositeSchema _compositeSchema; private readonly bool[] _isColumnActive; public override long Batch { get { return 0; } } - public Cursor(ZipDataView parent, IRowCursor[] srcCursors, Func predicate) + public Cursor(ZipDataView parent, RowCursor[] srcCursors, Func predicate) : base(parent._host) { Ch.AssertNonEmpty(srcCursors); @@ -167,15 +167,15 @@ protected override bool MoveManyCore(long count) return true; } - public Schema Schema => _compositeSchema.AsSchema; + public override Schema Schema => _compositeSchema.AsSchema; - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { _compositeSchema.CheckColumnInRange(col); return _isColumnActive[col]; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { int dv; int srcCol; diff --git a/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs b/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs index 14a23ab580..7a5b98da5b 100644 --- a/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs +++ b/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs @@ -254,7 +254,7 @@ public override void Save(ModelSaveContext ctx) return null; } - protected override IRowCursor GetRowCursorCore(Func predicate, Random rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Host.AssertValue(predicate, "predicate"); Host.AssertValueOrNull(rand); @@ -262,10 +262,10 @@ protected override IRowCursor GetRowCursorCore(Func predicate, Random var inputPred = _bindings.GetDependencies(predicate); var active = _bindings.GetActive(predicate); var input = Source.GetRowCursor(inputPred, rand); - return new RowCursor(Host, _bindings, input, active); + return new Cursor(Host, _bindings, input, active); } - public sealed override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, + public sealed override RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); @@ -277,18 +277,18 @@ public sealed override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator c Host.AssertNonEmpty(inputs); // No need to split if this is given 1 input cursor. - var cursors = new IRowCursor[inputs.Length]; + var cursors = new RowCursor[inputs.Length]; for (int i = 0; i < inputs.Length; i++) - cursors[i] = new RowCursor(Host, _bindings, inputs[i], active); + cursors[i] = new Cursor(Host, _bindings, inputs[i], active); return cursors; } - private sealed class RowCursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : SynchronizedCursorBase { private readonly Bindings _bindings; private readonly bool[] _active; - public RowCursor(IChannelProvider provider, Bindings bindings, IRowCursor input, bool[] active) + public Cursor(IChannelProvider provider, Bindings bindings, RowCursor input, bool[] active) : base(provider, input) { Ch.AssertValue(bindings); @@ -298,15 +298,15 @@ public RowCursor(IChannelProvider provider, Bindings bindings, IRowCursor input, _active = active; } - public Schema Schema => _bindings.AsSchema; + public override Schema Schema => _bindings.AsSchema; - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < _bindings.ColumnCount); return _active == null || _active[col]; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col)); diff --git a/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs b/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs index 153f442d32..b04101f5b7 100644 --- a/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs +++ b/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs @@ -128,9 +128,9 @@ public interface ICanGetSummaryInKeyValuePairs public interface ICanGetSummaryAsIRow { - IRow GetSummaryIRowOrNull(RoleMappedSchema schema); + Row GetSummaryIRowOrNull(RoleMappedSchema schema); - IRow GetStatsIRowOrNull(RoleMappedSchema schema); + Row GetStatsIRowOrNull(RoleMappedSchema schema); } public interface ICanGetSummaryAsIDataView diff --git a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs index 8eae667c53..994414d8f4 100644 --- a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs +++ b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs @@ -236,7 +236,7 @@ public Func GetDependencies(Func predicate) public Schema InputSchema => _rootSchema; - public IRow GetRow(IRow input, Func active, out Action disposer) + public Row GetRow(Row input, Func active, out Action disposer) { _ectx.Assert(IsCompositeRowToRowMapper(_chain)); _ectx.AssertValue(input); diff --git a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs index 7b8bd9651e..361f580c71 100644 --- a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs @@ -497,7 +497,7 @@ private void FinishOtherMetrics() } } - public override void InitializeNextPass(IRow row, RoleMappedSchema schema) + public override void InitializeNextPass(Row row, RoleMappedSchema schema) { Host.Assert(!_streaming && PassNum < 2 || PassNum < 1); Host.AssertValue(schema.Label); diff --git a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs index 528f24efce..4dfe37f9f6 100644 --- a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs @@ -609,7 +609,7 @@ public Aggregator(IHostEnvironment env, ReadOnlyMemory[] classNames, bool } } - public override void InitializeNextPass(IRow row, RoleMappedSchema schema) + public override void InitializeNextPass(Row row, RoleMappedSchema schema) { Host.AssertValue(schema.Label); Host.Assert(PassNum < 1); @@ -981,7 +981,7 @@ public override Func GetDependencies(Func activeOutput) return col => activeOutput(AssignedCol) && col == ScoreIndex; } - public override Delegate[] CreateGetters(IRow input, Func activeCols, out Action disposer) + public override Delegate[] CreateGetters(Row input, Func activeCols, out Action disposer) { Host.Assert(LabelIndex >= 0); Host.Assert(ScoreIndex >= 0); diff --git a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs index 7d015da046..797854dcbf 100644 --- a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs @@ -484,7 +484,7 @@ private void ProcessRowSecondPass() WeightedCounters.UpdateSecondPass(in _features, _indicesArr); } - public override void InitializeNextPass(IRow row, RoleMappedSchema schema) + public override void InitializeNextPass(Row row, RoleMappedSchema schema) { AssertValid(assertGetters: false); @@ -646,7 +646,7 @@ public override Func GetDependencies(Func activeOutput) (activeOutput(ClusterIdCol) || activeOutput(SortedClusterCol) || activeOutput(SortedClusterScoreCol)); } - public override Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) + public override Delegate[] CreateGetters(Row input, Func activeOutput, out Action disposer) { disposer = null; diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs index 14c8d58649..829dd0043a 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs @@ -253,7 +253,7 @@ public bool Start() /// /// This method should get the getters of the new IRow that are needed for the next pass. /// - public abstract void InitializeNextPass(IRow row, RoleMappedSchema schema); + public abstract void InitializeNextPass(Row row, RoleMappedSchema schema); /// /// Call the getters once, and process the input as necessary. @@ -324,7 +324,7 @@ protected virtual List GetWarningsCore() // When a new value is encountered, it uses a callback for creating a new aggregator. protected abstract class AggregatorDictionaryBase { - protected IRow Row; + protected Row Row; protected readonly Func CreateAgg; protected readonly RoleMappedSchema Schema; @@ -346,7 +346,7 @@ protected AggregatorDictionaryBase(RoleMappedSchema schema, string stratCol, Fun /// /// Gets the stratification column getter for the new IRow. /// - public abstract void Reset(IRow row); + public abstract void Reset(Row row); public static AggregatorDictionaryBase Create(RoleMappedSchema schema, string stratCol, ColumnType stratType, Func createAgg) @@ -397,7 +397,7 @@ public GenericAggregatorDictionary(RoleMappedSchema schema, string stratCol, Col _dict = new Dictionary(); } - public override void Reset(IRow row) + public override void Reset(Row row) { Row = row; int col; @@ -505,6 +505,6 @@ public virtual void Save(ModelSaveContext ctx) public abstract Schema.DetachedColumn[] GetOutputColumns(); - public abstract Delegate[] CreateGetters(IRow input, Func activeCols, out Action disposer); + public abstract Delegate[] CreateGetters(Row input, Func activeCols, out Action disposer); } } diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs index 8889ccd6cf..205c24a8cf 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs @@ -978,7 +978,7 @@ private static IDataView AddVarLengthColumn(IHostEnvironment env, IDataVie (in VBuffer src, ref VBuffer dst) => src.CopyTo(ref dst)); } - private static List GetMetricNames(IChannel ch, Schema schema, IRow row, Func ignoreCol, + private static List GetMetricNames(IChannel ch, Schema schema, Row row, Func ignoreCol, ValueGetter[] getters, ValueGetter>[] vBufferGetters) { ch.AssertValue(schema); diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs index 102e43a894..59a0c2c700 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs @@ -74,7 +74,7 @@ public class BinaryClassificationMetrics /// public double Auprc { get; } - protected private static T Fetch(IExceptionContext ectx, IRow row, string name) + protected private static T Fetch(IExceptionContext ectx, Row row, string name) { if (!row.Schema.TryGetColumnIndex(name, out int col)) throw ectx.Except($"Could not find column '{name}'"); @@ -83,7 +83,7 @@ protected private static T Fetch(IExceptionContext ectx, IRow row, string nam return val; } - internal BinaryClassificationMetrics(IExceptionContext ectx, IRow overallResult) + internal BinaryClassificationMetrics(IExceptionContext ectx, Row overallResult) { double Fetch(string name) => Fetch(ectx, overallResult, name); Auc = Fetch(BinaryClassifierEvaluator.Auc); diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/CalibratedBinaryClassificationMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/CalibratedBinaryClassificationMetrics.cs index c4f0861224..7a8787eeff 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/CalibratedBinaryClassificationMetrics.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/CalibratedBinaryClassificationMetrics.cs @@ -42,7 +42,7 @@ public sealed class CalibratedBinaryClassificationMetrics : BinaryClassification /// public double Entropy { get; } - internal CalibratedBinaryClassificationMetrics(IExceptionContext ectx, IRow overallResult) + internal CalibratedBinaryClassificationMetrics(IExceptionContext ectx, Row overallResult) : base(ectx, overallResult) { double Fetch(string name) => Fetch(ectx, overallResult, name); diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/ClusteringMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/ClusteringMetrics.cs index fc7e87e150..62b3352b63 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/ClusteringMetrics.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/ClusteringMetrics.cs @@ -35,7 +35,7 @@ public sealed class ClusteringMetrics /// public double Dbi { get; } - internal ClusteringMetrics(IExceptionContext ectx, IRow overallResult, bool calculateDbi) + internal ClusteringMetrics(IExceptionContext ectx, Row overallResult, bool calculateDbi) { double Fetch(string name) => RowCursorUtils.Fetch(ectx, overallResult, name); diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/MultiClassClassifierMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/MultiClassClassifierMetrics.cs index 4eff184abc..7acdce7025 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/MultiClassClassifierMetrics.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/MultiClassClassifierMetrics.cs @@ -81,7 +81,7 @@ public sealed class MultiClassClassifierMetrics /// public double[] PerClassLogLoss { get; } - internal MultiClassClassifierMetrics(IExceptionContext ectx, IRow overallResult, int topK) + internal MultiClassClassifierMetrics(IExceptionContext ectx, Row overallResult, int topK) { double FetchDouble(string name) => RowCursorUtils.Fetch(ectx, overallResult, name); AccuracyMicro = FetchDouble(MultiClassClassifierEvaluator.AccuracyMicro); diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/RankerMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/RankerMetrics.cs index 975d4a494c..a3749f0ecc 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/RankerMetrics.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/RankerMetrics.cs @@ -24,7 +24,7 @@ public sealed class RankerMetrics /// public double[] Dcg { get; } - private static T Fetch(IExceptionContext ectx, IRow row, string name) + private static T Fetch(IExceptionContext ectx, Row row, string name) { if (!row.Schema.TryGetColumnIndex(name, out int col)) throw ectx.Except($"Could not find column '{name}'"); @@ -33,7 +33,7 @@ private static T Fetch(IExceptionContext ectx, IRow row, string name) return val; } - internal RankerMetrics(IExceptionContext ectx, IRow overallResult) + internal RankerMetrics(IExceptionContext ectx, Row overallResult) { VBuffer Fetch(string name) => Fetch>(ectx, overallResult, name); diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/RegressionMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/RegressionMetrics.cs index 8a9fd96b31..8bda753e21 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/RegressionMetrics.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/RegressionMetrics.cs @@ -53,7 +53,7 @@ public sealed class RegressionMetrics /// public double RSquared { get; } - internal RegressionMetrics(IExceptionContext ectx, IRow overallResult) + internal RegressionMetrics(IExceptionContext ectx, Row overallResult) { double Fetch(string name) => RowCursorUtils.Fetch(ectx, overallResult, name); L1 = Fetch(RegressionEvaluator.L1); diff --git a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs index 273599f7f5..cb12bf92e4 100644 --- a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs @@ -387,7 +387,7 @@ public Aggregator(IHostEnvironment env, ReadOnlyMemory[] classNames, int s ClassNames = classNames; } - public override void InitializeNextPass(IRow row, RoleMappedSchema schema) + public override void InitializeNextPass(Row row, RoleMappedSchema schema) { Host.Assert(PassNum < 1); Host.AssertValue(schema.Label); @@ -662,7 +662,7 @@ public override Func GetDependencies(Func activeOutput) activeOutput(SortedClassesCol) || activeOutput(LogLossCol)); } - public override Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) + public override Delegate[] CreateGetters(Row input, Func activeOutput, out Action disposer) { disposer = null; diff --git a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs index 1b1f8f6f05..fb88a5bb60 100644 --- a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs @@ -299,7 +299,7 @@ public Aggregator(IHostEnvironment env, IRegressionLoss lossFunction, int size, WeightedCounters = Weighted ? new Counters(lossFunction, _size) : null; } - public override void InitializeNextPass(IRow row, RoleMappedSchema schema) + public override void InitializeNextPass(Row row, RoleMappedSchema schema) { Contracts.Assert(PassNum < 1); Contracts.AssertValue(schema.Label); @@ -457,7 +457,7 @@ public override Schema.DetachedColumn[] GetOutputColumns() return infos; } - public override Delegate[] CreateGetters(IRow input, Func activeCols, out Action disposer) + public override Delegate[] CreateGetters(Row input, Func activeCols, out Action disposer) { Host.Assert(LabelIndex >= 0); Host.Assert(ScoreIndex >= 0); diff --git a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs index ab05c1ee0c..36eabc8a71 100644 --- a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs @@ -380,7 +380,7 @@ private ValueGetter>> CreateSlotNamesGetter(string }; } - public override Delegate[] CreateGetters(IRow input, Func activeCols, out Action disposer) + public override Delegate[] CreateGetters(Row input, Func activeCols, out Action disposer) { Host.Assert(LabelIndex >= 0); Host.Assert(ScoreIndex >= 0); diff --git a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs index fe750fce59..a2fea8415d 100644 --- a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs @@ -440,7 +440,7 @@ public Aggregator(IHostEnvironment env, Double[] labelGains, int truncationLevel GroupId = new List>(); } - public override void InitializeNextPass(IRow row, RoleMappedSchema schema) + public override void InitializeNextPass(Row row, RoleMappedSchema schema) { Contracts.Assert(PassNum < 1); Contracts.AssertValue(schema.Label); @@ -610,12 +610,12 @@ public void Save(ModelSaveContext ctx) return _transform.GetRowCount(); } - public IRowCursor GetRowCursor(Func needCol, Random rand = null) + public RowCursor GetRowCursor(Func needCol, Random rand = null) { return _transform.GetRowCursor(needCol, rand); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func needCol, int n, Random rand = null) + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func needCol, int n, Random rand = null) { return _transform.GetRowCursorSet(out consolidator, needCol, n, rand); } @@ -775,7 +775,7 @@ private void Copy(Double[] src, ref VBuffer dst) dst = editor.Commit(); } - protected override ValueGetter GetLabelGetter(IRow row) + protected override ValueGetter GetLabelGetter(Row row) { var lb = RowCursorUtils.GetLabelGetter(row, _bindings.LabelIndex); return @@ -787,12 +787,12 @@ protected override ValueGetter GetLabelGetter(IRow row) }; } - protected override ValueGetter GetScoreGetter(IRow row) + protected override ValueGetter GetScoreGetter(Row row) { return row.GetGetter(_bindings.ScoreIndex); } - protected override RowCursorState InitializeState(IRow input) + protected override RowCursorState InitializeState(Row input) { return new RowCursorState(_truncationLevel); } diff --git a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs index 729e0bffe0..b4fdf358a0 100644 --- a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs @@ -259,7 +259,7 @@ public override Schema.DetachedColumn[] GetOutputColumns() return infos; } - public override Delegate[] CreateGetters(IRow input, Func activeCols, out Action disposer) + public override Delegate[] CreateGetters(Row input, Func activeCols, out Action disposer) { Host.Assert(LabelIndex >= 0); Host.Assert(ScoreIndex >= 0); diff --git a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluatorBase.cs b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluatorBase.cs index dcd2ce618e..27856f0bcb 100644 --- a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluatorBase.cs +++ b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluatorBase.cs @@ -191,7 +191,7 @@ protected RegressionAggregatorBase(IHostEnvironment env, IRegressionLoss lossFun Weighted = weighted; } - public override void InitializeNextPass(IRow row, RoleMappedSchema schema) + public override void InitializeNextPass(Row row, RoleMappedSchema schema) { Contracts.Assert(PassNum < 1); Contracts.AssertValue(schema.Label); diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index a44cf64753..90670d518f 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -559,7 +559,7 @@ public Func GetDependencies(Func predicate) return _predictor.GetInputColumnRoles(); } - public IRow GetRow(IRow input, Func predicate, out Action disposer) + public Row GetRow(Row input, Func predicate, out Action disposer) { Func predictorPredicate = col => false; for (int i = 0; i < OutputSchema.ColumnCount; i++) @@ -584,12 +584,12 @@ public IRow GetRow(IRow input, Func predicate, out Action disposer) return new SimpleRow(OutputSchema, predictorRow, getters); } - private Delegate GetPredictorGetter(IRow input, int col) + private Delegate GetPredictorGetter(Row input, int col) { return input.GetGetter(col); } - private Delegate GetProbGetter(IRow input) + private Delegate GetProbGetter(Row input) { var scoreGetter = RowCursorUtils.GetGetterAs(NumberType.R4, input, _scoreCol); ValueGetter probGetter = diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs index 68fdb15ab7..84acb75898 100644 --- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs @@ -220,7 +220,7 @@ public override IDataTransform ApplyToData(IHostEnvironment env, IDataView newSo return new BinaryClassifierScorer(env, this, newSource); } - protected override Delegate GetPredictedLabelGetter(IRow output, out Delegate scoreGetter) + protected override Delegate GetPredictedLabelGetter(Row output, out Delegate scoreGetter) { Host.AssertValue(output); Host.Assert(output.Schema == Bindings.RowMapper.OutputSchema); diff --git a/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs b/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs index 1035735bc3..5b5796ca59 100644 --- a/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs @@ -91,7 +91,7 @@ public override IDataTransform ApplyToData(IHostEnvironment env, IDataView newSo return new ClusteringScorer(env, this, newSource); } - protected override Delegate GetPredictedLabelGetter(IRow output, out Delegate scoreGetter) + protected override Delegate GetPredictedLabelGetter(Row output, out Delegate scoreGetter) { Contracts.AssertValue(output); Contracts.Assert(output.Schema == Bindings.RowMapper.OutputSchema); diff --git a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs index 5062068a40..3a738887ec 100644 --- a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs +++ b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs @@ -214,24 +214,24 @@ public void Save(ModelSaveContext ctx) ctx.Writer.WriteBoolByte(Stringify); } - public Delegate GetTextContributionGetter(IRow input, int colSrc, VBuffer> slotNames) + public Delegate GetTextContributionGetter(Row input, int colSrc, VBuffer> slotNames) { Contracts.CheckValue(input, nameof(input)); Contracts.Check(0 <= colSrc && colSrc < input.Schema.ColumnCount); var typeSrc = input.Schema.GetColumnType(colSrc); - Func>, ValueGetter>> del = GetTextValueGetter; + Func>, ValueGetter>> del = GetTextValueGetter; var meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(typeSrc.RawType); return (Delegate)meth.Invoke(this, new object[] { input, colSrc, slotNames }); } - public Delegate GetContributionGetter(IRow input, int colSrc) + public Delegate GetContributionGetter(Row input, int colSrc) { Contracts.CheckValue(input, nameof(input)); Contracts.Check(0 <= colSrc && colSrc < input.Schema.ColumnCount); var typeSrc = input.Schema.GetColumnType(colSrc); - Func>> del = GetValueGetter; + Func>> del = GetValueGetter; // REVIEW: Assuming Feature contributions will be VBuffer. // For multiclass LR it needs to be(VBuffer[]. @@ -249,7 +249,7 @@ private ReadOnlyMemory GetSlotName(int index, VBuffer : slotName; } - private ValueGetter> GetTextValueGetter(IRow input, int colSrc, VBuffer> slotNames) + private ValueGetter> GetTextValueGetter(Row input, int colSrc, VBuffer> slotNames) { Contracts.AssertValue(input); Contracts.AssertValue(Predictor); @@ -292,7 +292,7 @@ private ValueGetter> GetTextValueGetter(IRow input, i }; } - private ValueGetter> GetValueGetter(IRow input, int colSrc) + private ValueGetter> GetValueGetter(Row input, int colSrc) { Contracts.AssertValue(input); Contracts.AssertValue(Predictor); @@ -385,7 +385,7 @@ public Func GetDependencies(Func predicate) return col => false; } - public IRow GetOutputRow(IRow input, Func predicate, out Action disposer) + public Row GetOutputRow(Row input, Func predicate, out Action disposer) { Contracts.AssertValue(input); Contracts.AssertValue(predicate); @@ -419,7 +419,7 @@ public Func GetGenericPredicate(Func predicate) yield return RoleMappedSchema.ColumnRole.Feature.Bind(InputRoleMappedSchema.Feature.Name); } - public IRow GetRow(IRow input, Func active, out Action disposer) + public Row GetRow(Row input, Func active, out Action disposer) { return GetOutputRow(input, active, out disposer); } diff --git a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs index 99097dc12b..f301792603 100644 --- a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs @@ -261,7 +261,7 @@ public override IDataTransform ApplyToData(IHostEnvironment env, IDataView newSo return new GenericScorer(env, this, newSource); } - protected override Delegate[] GetGetters(IRow output, Func predicate) + protected override Delegate[] GetGetters(Row output, Func predicate) { Host.Assert(_bindings.DerivedColumnCount == 0); Host.AssertValue(output); diff --git a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs index 3ee4377443..0324d950ce 100644 --- a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs @@ -307,7 +307,7 @@ public Func GetDependencies(Func predicate) return _mapper.GetInputColumnRoles(); } - public IRow GetRow(IRow input, Func predicate, out Action disposer) + public Row GetRow(Row input, Func predicate, out Action disposer) { var innerRow = _mapper.GetRow(input, predicate, out disposer); return new RowImpl(innerRow, OutputSchema); @@ -386,17 +386,17 @@ public void GetMetadata(string kind, int col, ref TValue value) } } - private sealed class RowImpl : IRow + private sealed class RowImpl : Row { - private readonly IRow _row; + private readonly Row _row; private readonly Schema _schema; - public long Batch { get { return _row.Batch; } } - public long Position { get { return _row.Position; } } + public override long Batch => _row.Batch; + public override long Position => _row.Position; // The schema is of course the only difference from _row. - public Schema Schema => _schema; + public override Schema Schema => _schema; - public RowImpl(IRow row, Schema schema) + public RowImpl(Row row, Schema schema) { Contracts.AssertValue(row); Contracts.AssertValue(schema); @@ -405,17 +405,17 @@ public RowImpl(IRow row, Schema schema) _schema = schema; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { return _row.IsColumnActive(col); } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { return _row.GetGetter(col); } - public ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { return _row.GetIdGetter(); } @@ -551,7 +551,7 @@ public override IDataTransform ApplyToData(IHostEnvironment env, IDataView newSo return new MultiClassClassifierScorer(env, this, newSource); } - protected override Delegate GetPredictedLabelGetter(IRow output, out Delegate scoreGetter) + protected override Delegate GetPredictedLabelGetter(Row output, out Delegate scoreGetter) { Host.AssertValue(output); Host.Assert(output.Schema == Bindings.RowMapper.OutputSchema); diff --git a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs index 8aff6ff69a..488f0101fe 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs @@ -400,7 +400,7 @@ protected override bool WantParallelCursors(Func predicate) return Bindings.AnyNewColumnsActive(predicate); } - protected override Delegate[] GetGetters(IRow output, Func predicate) + protected override Delegate[] GetGetters(Row output, Func predicate) { Host.Assert(Bindings.DerivedColumnCount == 1); Host.AssertValue(output); @@ -432,10 +432,10 @@ protected override Delegate[] GetGetters(IRow output, Func predicate) return getters; } - protected abstract Delegate GetPredictedLabelGetter(IRow output, out Delegate scoreGetter); + protected abstract Delegate GetPredictedLabelGetter(Row output, out Delegate scoreGetter); protected void EnsureCachedPosition(ref long cachedPosition, ref TScore score, - IRow boundRow, ValueGetter scoreGetter) + Row boundRow, ValueGetter scoreGetter) { if (cachedPosition != boundRow.Position) { diff --git a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs index 8cb39d2f82..34913c5d1b 100644 --- a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs @@ -114,7 +114,7 @@ private static bool[] GetActive(BindingsBase bindings, Func predicate /// protected abstract bool WantParallelCursors(Func predicate); - protected override IRowCursor GetRowCursorCore(Func predicate, Random rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Contracts.AssertValue(predicate); Contracts.AssertValueOrNull(rand); @@ -124,10 +124,10 @@ protected override IRowCursor GetRowCursorCore(Func predicate, Random Func predicateMapper; var active = GetActive(bindings, predicate, out predicateInput, out predicateMapper); var input = Source.GetRowCursor(predicateInput, rand); - return new RowCursor(Host, this, input, active, predicateMapper); + return new Cursor(Host, this, input, active, predicateMapper); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) + public override RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); @@ -143,13 +143,13 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid inputs = DataViewUtils.CreateSplitCursors(out consolidator, Host, inputs[0], n); Contracts.AssertNonEmpty(inputs); - var cursors = new IRowCursor[inputs.Length]; + var cursors = new RowCursor[inputs.Length]; for (int i = 0; i < inputs.Length; i++) - cursors[i] = new RowCursor(Host, this, inputs[i], active, predicateMapper); + cursors[i] = new Cursor(Host, this, inputs[i], active, predicateMapper); return cursors; } - protected override Delegate[] CreateGetters(IRow input, Func active, out Action disp) + protected override Delegate[] CreateGetters(Row input, Func active, out Action disp) { var bindings = GetBindings(); Func predicateInput; @@ -173,9 +173,9 @@ protected override Func GetDependenciesCore(Func predicate /// Create and fill an array of getters of size InfoCount. The indices of the non-null entries in the /// result should be exactly those for which predicate(iinfo) is true. /// - protected abstract Delegate[] GetGetters(IRow output, Func predicate); + protected abstract Delegate[] GetGetters(Row output, Func predicate); - protected static Delegate[] GetGettersFromRow(IRow row, Func predicate) + protected static Delegate[] GetGettersFromRow(Row row, Func predicate) { Contracts.AssertValue(row); Contracts.AssertValue(predicate); @@ -189,19 +189,19 @@ protected static Delegate[] GetGettersFromRow(IRow row, Func predicat return getters; } - protected static Delegate GetGetterFromRow(IRow row, int col) + protected static Delegate GetGetterFromRow(Row row, int col) { Contracts.AssertValue(row); Contracts.Assert(0 <= col && col < row.Schema.ColumnCount); Contracts.Assert(row.IsColumnActive(col)); var type = row.Schema.GetColumnType(col); - Func> del = GetGetterFromRow; + Func> del = GetGetterFromRow; var meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(type.RawType); return (Delegate)meth.Invoke(null, new object[] { row, col }); } - protected static ValueGetter GetGetterFromRow(IRow output, int col) + protected static ValueGetter GetGetterFromRow(Row output, int col) { Contracts.AssertValue(output); Contracts.Assert(0 <= col && col < output.Schema.ColumnCount); @@ -215,16 +215,16 @@ protected override int MapColumnIndex(out bool isSrc, int col) return bindings.MapColumnIndex(out isSrc, col); } - private sealed class RowCursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : SynchronizedCursorBase { private readonly BindingsBase _bindings; private readonly bool[] _active; private readonly Delegate[] _getters; private readonly Action _disposer; - public Schema Schema { get; } + public override Schema Schema { get; } - public RowCursor(IChannelProvider provider, RowToRowScorerBase parent, IRowCursor input, bool[] active, Func predicateMapper) + public Cursor(IChannelProvider provider, RowToRowScorerBase parent, RowCursor input, bool[] active, Func predicateMapper) : base(provider, input) { Ch.AssertValue(parent); @@ -255,13 +255,13 @@ public override void Dispose() base.Dispose(); } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < _bindings.ColumnCount); return _active[col]; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col)); diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs index 4d44d07543..c5f6301154 100644 --- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs +++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs @@ -144,18 +144,18 @@ public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) protected abstract ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema); - protected virtual Delegate GetPredictionGetter(IRow input, int colSrc) + protected virtual Delegate GetPredictionGetter(Row input, int colSrc) { Contracts.AssertValue(input); Contracts.Assert(0 <= colSrc && colSrc < input.Schema.ColumnCount); var typeSrc = input.Schema.GetColumnType(colSrc); - Func> del = GetValueGetter; + Func> del = GetValueGetter; var meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(typeSrc.RawType, ScoreType.RawType); return (Delegate)meth.Invoke(this, new object[] { input, colSrc }); } - private ValueGetter GetValueGetter(IRow input, int colSrc) + private ValueGetter GetValueGetter(Row input, int colSrc) { Contracts.AssertValue(input); Contracts.Assert(ValueMapper != null); @@ -223,7 +223,7 @@ public Func GetDependencies(Func predicate) public Schema InputSchema => InputRoleMappedSchema.Schema; - public IRow GetRow(IRow input, Func predicate, out Action disposer) + public Row GetRow(Row input, Func predicate, out Action disposer) { Contracts.AssertValue(input); Contracts.AssertValue(predicate); @@ -510,7 +510,7 @@ public Func GetDependencies(Func predicate) yield return RoleMappedSchema.ColumnRole.Feature.Bind(InputRoleMappedSchema.Feature?.Name); } - private Delegate[] CreateGetters(IRow input, bool[] active) + private Delegate[] CreateGetters(Row input, bool[] active) { Contracts.Assert(Utils.Size(active) == 2); Contracts.Assert(_parent._distMapper != null); @@ -553,7 +553,7 @@ private Delegate[] CreateGetters(IRow input, bool[] active) private static void EnsureCachedResultValueMapper(ValueMapper, Float, Float> mapper, ref long cachedPosition, ValueGetter> featureGetter, ref VBuffer features, - ref Float score, ref Float prob, IRow input) + ref Float score, ref Float prob, Row input) { Contracts.AssertValue(mapper); if (cachedPosition != input.Position) @@ -566,7 +566,7 @@ private static void EnsureCachedResultValueMapper(ValueMapper, Fl } } - public IRow GetRow(IRow input, Func predicate, out Action disposer) + public Row GetRow(Row input, Func predicate, out Action disposer) { Contracts.AssertValue(input); var active = Utils.BuildArray(OutputSchema.ColumnCount, predicate); @@ -657,7 +657,7 @@ protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema sch return new SingleValueRowMapper(schema, this, Schema.Create(new SchemaImpl(ScoreType, _quantiles))); } - protected override Delegate GetPredictionGetter(IRow input, int colSrc) + protected override Delegate GetPredictionGetter(Row input, int colSrc) { Contracts.AssertValue(input); Contracts.Assert(0 <= colSrc && colSrc < input.Schema.ColumnCount); diff --git a/src/Microsoft.ML.Data/Training/TrainerUtils.cs b/src/Microsoft.ML.Data/Training/TrainerUtils.cs index 41e984f522..b641744421 100644 --- a/src/Microsoft.ML.Data/Training/TrainerUtils.cs +++ b/src/Microsoft.ML.Data/Training/TrainerUtils.cs @@ -237,14 +237,14 @@ private static Func CreatePredicate(RoleMappedData data, CursOpt opt, /// Create a row cursor for the RoleMappedData with the indicated standard columns active. /// This does not verify that the columns exist, but merely activates the ones that do exist. /// - public static IRowCursor CreateRowCursor(this RoleMappedData data, CursOpt opt, Random rand, IEnumerable extraCols = null) + public static RowCursor CreateRowCursor(this RoleMappedData data, CursOpt opt, Random rand, IEnumerable extraCols = null) => data.Data.GetRowCursor(CreatePredicate(data, opt, extraCols), rand); /// /// Create a row cursor set for the RoleMappedData with the indicated standard columns active. /// This does not verify that the columns exist, but merely activates the ones that do exist. /// - public static IRowCursor[] CreateRowCursorSet(this RoleMappedData data, out IRowCursorConsolidator consolidator, + public static RowCursor[] CreateRowCursorSet(this RoleMappedData data, out IRowCursorConsolidator consolidator, CursOpt opt, int n, Random rand, IEnumerable extraCols = null) => data.Data.GetRowCursorSet(out consolidator, CreatePredicate(data, opt, extraCols), n, rand); @@ -258,7 +258,7 @@ private static void AddOpt(HashSet cols, ColumnInfo info) /// /// Get the getter for the feature column, assuming it is a vector of float. /// - public static ValueGetter> GetFeatureFloatVectorGetter(this IRow row, RoleMappedSchema schema) + public static ValueGetter> GetFeatureFloatVectorGetter(this Row row, RoleMappedSchema schema) { Contracts.CheckValue(row, nameof(row)); Contracts.CheckValue(schema, nameof(schema)); @@ -271,7 +271,7 @@ public static ValueGetter> GetFeatureFloatVectorGetter(this IRow /// /// Get the getter for the feature column, assuming it is a vector of float. /// - public static ValueGetter> GetFeatureFloatVectorGetter(this IRow row, RoleMappedData data) + public static ValueGetter> GetFeatureFloatVectorGetter(this Row row, RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); return GetFeatureFloatVectorGetter(row, data.Schema); @@ -281,7 +281,7 @@ public static ValueGetter> GetFeatureFloatVectorGetter(this IRow /// Get a getter for the label as a float. This assumes that the label column type /// has already been validated as appropriate for the kind of training being done. /// - public static ValueGetter GetLabelFloatGetter(this IRow row, RoleMappedSchema schema) + public static ValueGetter GetLabelFloatGetter(this Row row, RoleMappedSchema schema) { Contracts.CheckValue(row, nameof(row)); Contracts.CheckValue(schema, nameof(schema)); @@ -295,7 +295,7 @@ public static ValueGetter GetLabelFloatGetter(this IRow row, RoleMappedSc /// Get a getter for the label as a float. This assumes that the label column type /// has already been validated as appropriate for the kind of training being done. /// - public static ValueGetter GetLabelFloatGetter(this IRow row, RoleMappedData data) + public static ValueGetter GetLabelFloatGetter(this Row row, RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); return GetLabelFloatGetter(row, data.Schema); @@ -304,7 +304,7 @@ public static ValueGetter GetLabelFloatGetter(this IRow row, RoleMappedDa /// /// Get the getter for the weight column, or null if there is no weight column. /// - public static ValueGetter GetOptWeightFloatGetter(this IRow row, RoleMappedSchema schema) + public static ValueGetter GetOptWeightFloatGetter(this Row row, RoleMappedSchema schema) { Contracts.CheckValue(row, nameof(row)); Contracts.CheckValue(schema, nameof(schema)); @@ -317,7 +317,7 @@ public static ValueGetter GetOptWeightFloatGetter(this IRow row, RoleMapp return RowCursorUtils.GetGetterAs(NumberType.Float, row, col.Index); } - public static ValueGetter GetOptWeightFloatGetter(this IRow row, RoleMappedData data) + public static ValueGetter GetOptWeightFloatGetter(this Row row, RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); return GetOptWeightFloatGetter(row, data.Schema); @@ -326,7 +326,7 @@ public static ValueGetter GetOptWeightFloatGetter(this IRow row, RoleMapp /// /// Get the getter for the group column, or null if there is no group column. /// - public static ValueGetter GetOptGroupGetter(this IRow row, RoleMappedSchema schema) + public static ValueGetter GetOptGroupGetter(this Row row, RoleMappedSchema schema) { Contracts.CheckValue(row, nameof(row)); Contracts.CheckValue(schema, nameof(schema)); @@ -339,7 +339,7 @@ public static ValueGetter GetOptGroupGetter(this IRow row, RoleMappedSche return RowCursorUtils.GetGetterAs(NumberType.U8, row, col.Index); } - public static ValueGetter GetOptGroupGetter(this IRow row, RoleMappedData data) + public static ValueGetter GetOptGroupGetter(this Row row, RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); return GetOptGroupGetter(row, data.Schema); @@ -393,7 +393,7 @@ public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn, b /// /// This is the base class for a data cursor. Data cursors are specially typed - /// "convenience" cursor-like objects, less general than a but + /// "convenience" cursor-like objects, less general than a but /// more convenient for common access patterns that occur in machine learning. For /// example, the common idiom of iterating over features/labels/weights while skipping /// "bad" features, labels, and weights. There will be two typical access patterns for @@ -404,9 +404,9 @@ public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn, b /// public abstract class TrainingCursorBase : IDisposable { - public IRow Row { get { return _cursor; } } + public Row Row { get { return _cursor; } } - private readonly IRowCursor _cursor; + private readonly RowCursor _cursor; private readonly Action _signal; private long _skipCount; @@ -420,7 +420,7 @@ public abstract class TrainingCursorBase : IDisposable /// /// /// This method is called - protected TrainingCursorBase(IRowCursor input, Action signal) + protected TrainingCursorBase(RowCursor input, Action signal) { Contracts.AssertValue(input); Contracts.AssertValueOrNull(signal); @@ -428,7 +428,7 @@ protected TrainingCursorBase(IRowCursor input, Action signal) _signal = signal; } - protected static IRowCursor CreateCursor(RoleMappedData data, CursOpt opt, Random rand, params int[] extraCols) + protected static RowCursor CreateCursor(RoleMappedData data, CursOpt opt, Random rand, params int[] extraCols) { Contracts.AssertValue(data); Contracts.AssertValueOrNull(rand); @@ -596,7 +596,7 @@ public TCurs[] CreateSet(int n, Random rand = null, params int[] extraCols) /// , whose return value is used to call /// this action. /// - protected abstract TCurs CreateCursorCore(IRowCursor input, RoleMappedData data, CursOpt opt, Action signal); + protected abstract TCurs CreateCursorCore(RowCursor input, RoleMappedData data, CursOpt opt, Action signal); /// /// Accumulates signals from cursors, anding them together. Once it has @@ -658,7 +658,7 @@ public StandardScalarCursor(RoleMappedData data, CursOpt opt, Random rand = null { } - protected StandardScalarCursor(IRowCursor input, RoleMappedData data, CursOpt opt, Action signal = null) + protected StandardScalarCursor(RowCursor input, RoleMappedData data, CursOpt opt, Action signal = null) : base(input, signal) { Contracts.AssertValue(data); @@ -723,7 +723,7 @@ public Factory(RoleMappedData data, CursOpt opt) { } - protected override StandardScalarCursor CreateCursorCore(IRowCursor input, RoleMappedData data, CursOpt opt, Action signal) + protected override StandardScalarCursor CreateCursorCore(RowCursor input, RoleMappedData data, CursOpt opt, Action signal) => new StandardScalarCursor(input, data, opt, signal); } } @@ -748,7 +748,7 @@ public FeatureFloatVectorCursor(RoleMappedData data, CursOpt opt = CursOpt.Featu { } - protected FeatureFloatVectorCursor(IRowCursor input, RoleMappedData data, CursOpt opt, Action signal = null) + protected FeatureFloatVectorCursor(RowCursor input, RoleMappedData data, CursOpt opt, Action signal = null) : base(input, data, opt, signal) { if ((opt & CursOpt.Features) != 0 && data.Schema.Feature != null) @@ -789,7 +789,7 @@ public Factory(RoleMappedData data, CursOpt opt = CursOpt.Features) { } - protected override FeatureFloatVectorCursor CreateCursorCore(IRowCursor input, RoleMappedData data, CursOpt opt, Action signal) + protected override FeatureFloatVectorCursor CreateCursorCore(RowCursor input, RoleMappedData data, CursOpt opt, Action signal) { return new FeatureFloatVectorCursor(input, data, opt, signal); } @@ -816,7 +816,7 @@ public FloatLabelCursor(RoleMappedData data, CursOpt opt = CursOpt.Label, { } - protected FloatLabelCursor(IRowCursor input, RoleMappedData data, CursOpt opt, Action signal = null) + protected FloatLabelCursor(RowCursor input, RoleMappedData data, CursOpt opt, Action signal = null) : base(input, data, opt, signal) { if ((opt & CursOpt.Label) != 0 && data.Schema.Label != null) @@ -856,7 +856,7 @@ public Factory(RoleMappedData data, CursOpt opt = CursOpt.Label) { } - protected override FloatLabelCursor CreateCursorCore(IRowCursor input, RoleMappedData data, CursOpt opt, Action signal) + protected override FloatLabelCursor CreateCursorCore(RowCursor input, RoleMappedData data, CursOpt opt, Action signal) { return new FloatLabelCursor(input, data, opt, signal); } @@ -885,7 +885,7 @@ public MultiClassLabelCursor(int classCount, RoleMappedData data, CursOpt opt = { } - protected MultiClassLabelCursor(int classCount, IRowCursor input, RoleMappedData data, CursOpt opt, Action signal = null) + protected MultiClassLabelCursor(int classCount, RowCursor input, RoleMappedData data, CursOpt opt, Action signal = null) : base(input, data, opt, signal) { Contracts.Assert(classCount >= 0); @@ -934,7 +934,7 @@ public Factory(int classCount, RoleMappedData data, CursOpt opt = CursOpt.Label) _classCount = classCount; } - protected override MultiClassLabelCursor CreateCursorCore(IRowCursor input, RoleMappedData data, CursOpt opt, Action signal) + protected override MultiClassLabelCursor CreateCursorCore(RowCursor input, RoleMappedData data, CursOpt opt, Action signal) { return new MultiClassLabelCursor(_classCount, input, data, opt, signal); } diff --git a/src/Microsoft.ML.Data/Transforms/BindingsWrappedRowCursor.cs b/src/Microsoft.ML.Data/Transforms/BindingsWrappedRowCursor.cs index fecac3d005..114bbc4fdd 100644 --- a/src/Microsoft.ML.Data/Transforms/BindingsWrappedRowCursor.cs +++ b/src/Microsoft.ML.Data/Transforms/BindingsWrappedRowCursor.cs @@ -13,11 +13,11 @@ namespace Microsoft.ML.Runtime.Data /// inconvenient or inefficient to handle the "no output selected" case in their /// own implementation. /// - internal sealed class BindingsWrappedRowCursor : SynchronizedCursorBase, IRowCursor + internal sealed class BindingsWrappedRowCursor : SynchronizedCursorBase { private readonly ColumnBindingsBase _bindings; - public Schema Schema => _bindings.AsSchema; + public override Schema Schema => _bindings.AsSchema; /// /// Creates a wrapped version of the cursor @@ -25,7 +25,7 @@ internal sealed class BindingsWrappedRowCursor : SynchronizedCursorBaseChannel provider /// The input cursor /// The bindings object, - public BindingsWrappedRowCursor(IChannelProvider provider, IRowCursor input, ColumnBindingsBase bindings) + public BindingsWrappedRowCursor(IChannelProvider provider, RowCursor input, ColumnBindingsBase bindings) : base(provider, input) { Ch.CheckValue(input, nameof(input)); @@ -34,7 +34,7 @@ public BindingsWrappedRowCursor(IChannelProvider provider, IRowCursor input, Col _bindings = bindings; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col & col < _bindings.ColumnCount, "col"); bool isSrc; @@ -42,7 +42,7 @@ public bool IsColumnActive(int col) return isSrc && Input.IsColumnActive(col); } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col), "col"); bool isSrc; diff --git a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs index c33c2e55ff..327c029767 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs @@ -650,7 +650,7 @@ private void GetSlotNames(ref VBuffer> dst) bldr.GetResult(ref dst); } - public Delegate MakeGetter(IRow input) + public Delegate MakeGetter(Row input) { if (_isIdentity) return Utils.MarshalInvoke(MakeIdentityGetter, OutputType.RawType, input); @@ -658,13 +658,13 @@ public Delegate MakeGetter(IRow input) return Utils.MarshalInvoke(MakeGetter, OutputType.ItemType.RawType, input); } - private Delegate MakeIdentityGetter(IRow input) + private Delegate MakeIdentityGetter(Row input) { Contracts.Assert(SrcIndices.Length == 1); return input.GetGetter(SrcIndices[0]); } - private Delegate MakeGetter(IRow input) + private Delegate MakeGetter(Row input) { var srcGetterOnes = new ValueGetter[SrcIndices.Length]; var srcGetterVecs = new ValueGetter>[SrcIndices.Length]; @@ -847,7 +847,7 @@ public override Func GetDependencies(Func activeOutput) public override void Save(ModelSaveContext ctx) => _parent.Save(ctx); - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { disposer = null; return _columns[iinfo].MakeGetter(input); diff --git a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs index 6ab8fe9b7b..4479856cc0 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs @@ -175,13 +175,13 @@ internal Mapper(ColumnCopyingTransformer parent, Schema inputSchema, (string Sou _columns = columns; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _columns.Length); disposer = null; - Delegate MakeGetter(IRow row, int index) + Delegate MakeGetter(Row row, int index) => input.GetGetter(index); input.Schema.TryGetColumnIndex(_columns[iinfo].Source, out int colIndex); diff --git a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs index afbbbe59b8..33d61e4560 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs @@ -579,32 +579,32 @@ private static Schema GenerateOutputSchema(IEnumerable map, } } - private sealed class Row : IRow + private sealed class RowImpl : Row { private readonly Mapper _mapper; - private readonly IRow _input; - public Row(IRow input, Mapper mapper) + private readonly Row _input; + public RowImpl(Row input, Mapper mapper) { _mapper = mapper; _input = input; } - public long Position => _input.Position; + public override long Position => _input.Position; - public long Batch => _input.Batch; + public override long Batch => _input.Batch; - Schema IRow.Schema => _mapper.OutputSchema; + public override Schema Schema => _mapper.OutputSchema; - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { int index = _mapper.GetInputIndex(col); return _input.GetGetter(index); } - public ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() => _input.GetIdGetter(); - public bool IsColumnActive(int col) => true; + public override bool IsColumnActive(int col) => true; } private sealed class SelectColumnsDataTransform : IDataTransform, IRowToRowMapper, ITransformTemplate @@ -633,7 +633,7 @@ public SelectColumnsDataTransform(IHostEnvironment env, ColumnSelectingTransform public long? GetRowCount() => Source.GetRowCount(); - public IRowCursor GetRowCursor(Func needCol, Random rand = null) + public RowCursor GetRowCursor(Func needCol, Random rand = null) { _host.AssertValue(needCol, nameof(needCol)); _host.AssertValueOrNull(rand); @@ -644,10 +644,10 @@ public IRowCursor GetRowCursor(Func needCol, Random rand = null) // Build the active state for the output var active = Utils.BuildArray(_mapper.OutputSchema.ColumnCount, needCol); - return new RowCursor(_host, _mapper, inputRowCursor, active); + return new Cursor(_host, _mapper, inputRowCursor, active); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func needCol, int n, Random rand = null) + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func needCol, int n, Random rand = null) { _host.CheckValue(needCol, nameof(needCol)); _host.CheckValueOrNull(rand); @@ -661,10 +661,10 @@ public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Fun _host.AssertNonEmpty(inputs); // No need to split if this is given 1 input cursor. - var cursors = new IRowCursor[inputs.Length]; + var cursors = new RowCursor[inputs.Length]; for (int i = 0; i < inputs.Length; i++) { - cursors[i] = new RowCursor(_host, _mapper, inputs[i], active); + cursors[i] = new Cursor(_host, _mapper, inputs[i], active); } return cursors; } @@ -684,22 +684,22 @@ public Func GetDependencies(Func activeOutput) return col => active[col]; } - public IRow GetRow(IRow input, Func active, out Action disposer) + public Row GetRow(Row input, Func active, out Action disposer) { disposer = null; - return new Row(input, _mapper); + return new RowImpl(input, _mapper); } public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource) => new SelectColumnsDataTransform(env, _transform, new Mapper(_transform, newSource.Schema), newSource); } - private sealed class RowCursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : SynchronizedCursorBase { private readonly Mapper _mapper; - private readonly IRowCursor _inputCursor; + private readonly RowCursor _inputCursor; private readonly bool[] _active; - public RowCursor(IChannelProvider provider, Mapper mapper, IRowCursor input, bool[] active) + public Cursor(IChannelProvider provider, Mapper mapper, RowCursor input, bool[] active) : base(provider, input) { _mapper = mapper; @@ -707,15 +707,15 @@ public RowCursor(IChannelProvider provider, Mapper mapper, IRowCursor input, boo _active = active; } - public Schema Schema => _mapper.OutputSchema; + public override Schema Schema => _mapper.OutputSchema; - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { int index = _mapper.GetInputIndex(col); return _inputCursor.GetGetter(index); } - public bool IsColumnActive(int col) => _active[col]; + public override bool IsColumnActive(int col) => _active[col]; } } } diff --git a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs index d1f6c70bee..e8a75d4edd 100644 --- a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs @@ -692,7 +692,7 @@ private void CombineRanges( newRangeMax = maxRange2; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); @@ -711,7 +711,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac return MakeVecGetter(input, iinfo); } - private Delegate MakeOneTrivialGetter(IRow input, int iinfo) + private Delegate MakeOneTrivialGetter(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); @@ -734,7 +734,7 @@ private void OneTrivialGetter(ref TDst value) value = default(TDst); } - private Delegate MakeVecTrivialGetter(IRow input, int iinfo) + private Delegate MakeVecTrivialGetter(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); @@ -757,19 +757,19 @@ private void VecTrivialGetter(ref VBuffer value) VBufferUtils.Resize(ref value, 1, 0); } - private Delegate MakeVecGetter(IRow input, int iinfo) + private Delegate MakeVecGetter(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); Host.Assert(_srcTypes[iinfo].IsVector); Host.Assert(!_suppressed[iinfo]); - Func>> del = MakeVecGetter; + Func>> del = MakeVecGetter; var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(_srcTypes[iinfo].ItemType.RawType); return (Delegate)methodInfo.Invoke(this, new object[] { input, iinfo }); } - private ValueGetter> MakeVecGetter(IRow input, int iinfo) + private ValueGetter> MakeVecGetter(Row input, int iinfo) { var srcGetter = GetSrcGetter>(input, iinfo); var typeDst = _dstTypes[iinfo]; @@ -786,7 +786,7 @@ private ValueGetter> MakeVecGetter(IRow input, int iinfo) }; } - private ValueGetter GetSrcGetter(IRow input, int iinfo) + private ValueGetter GetSrcGetter(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); @@ -795,12 +795,12 @@ private ValueGetter GetSrcGetter(IRow input, int iinfo) return input.GetGetter(src); } - private Delegate GetSrcGetter(ColumnType typeDst, IRow row, int iinfo) + private Delegate GetSrcGetter(ColumnType typeDst, Row row, int iinfo) { Host.CheckValue(typeDst, nameof(typeDst)); Host.CheckValue(row, nameof(row)); - Func> del = GetSrcGetter; + Func> del = GetSrcGetter; var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(typeDst.RawType); return (Delegate)methodInfo.Invoke(this, new object[] { row, iinfo }); } diff --git a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs index 72231d5eb9..0bf2687c39 100644 --- a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs @@ -332,7 +332,7 @@ public override void Save(ModelSaveContext ctx) return null; } - protected override IRowCursor GetRowCursorCore(Func predicate, Random rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Host.AssertValue(predicate, "predicate"); Host.AssertValueOrNull(rand); @@ -340,10 +340,10 @@ protected override IRowCursor GetRowCursorCore(Func predicate, Random var inputPred = _bindings.GetDependencies(predicate); var active = _bindings.GetActive(predicate); var input = Source.GetRowCursor(inputPred); - return new RowCursor(Host, _bindings, input, active); + return new Cursor(Host, _bindings, input, active); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, + public override RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); @@ -351,7 +351,7 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid var inputPred = _bindings.GetDependencies(predicate); var active = _bindings.GetActive(predicate); - IRowCursor input; + RowCursor input; if (n > 1 && ShouldUseParallelCursors(predicate) != false) { @@ -360,9 +360,9 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid if (inputs.Length != 1) { - var cursors = new IRowCursor[inputs.Length]; + var cursors = new RowCursor[inputs.Length]; for (int i = 0; i < inputs.Length; i++) - cursors[i] = new RowCursor(Host, _bindings, inputs[i], active); + cursors[i] = new Cursor(Host, _bindings, inputs[i], active); return cursors; } input = inputs[0]; @@ -371,10 +371,10 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid input = Source.GetRowCursor(inputPred); consolidator = null; - return new IRowCursor[] { new RowCursor(Host, _bindings, input, active) }; + return new RowCursor[] { new Cursor(Host, _bindings, input, active) }; } - private sealed class RowCursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : SynchronizedCursorBase { private readonly Bindings _bindings; private readonly bool[] _active; @@ -383,7 +383,7 @@ private sealed class RowCursor : SynchronizedCursorBase, IRowCursor private readonly TauswortheHybrid[] _rngs; private readonly long[] _lastCounters; - public RowCursor(IChannelProvider provider, Bindings bindings, IRowCursor input, bool[] active) + public Cursor(IChannelProvider provider, Bindings bindings, RowCursor input, bool[] active) : base(provider, input) { Ch.CheckValue(bindings, nameof(bindings)); @@ -408,15 +408,15 @@ public RowCursor(IChannelProvider provider, Bindings bindings, IRowCursor input, } } - public Schema Schema => _bindings.AsSchema; + public override Schema Schema => _bindings.AsSchema; - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < _bindings.ColumnCount); return _active == null || _active[col]; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col)); diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs index 60e799ddb6..1e9151f44a 100644 --- a/src/Microsoft.ML.Data/Transforms/Hashing.cs +++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs @@ -273,7 +273,7 @@ internal HashingTransformer(IHostEnvironment env, IDataView input, ColumnInfo[] } if (Utils.Size(sourceColumnsForInvertHash) > 0) { - using (IRowCursor srcCursor = input.GetRowCursor(sourceColumnsForInvertHash.Contains)) + using (RowCursor srcCursor = input.GetRowCursor(sourceColumnsForInvertHash.Contains)) { using (var ch = Host.Start("Invert hash building")) { @@ -307,7 +307,7 @@ internal HashingTransformer(IHostEnvironment env, IDataView input, ColumnInfo[] } } - private Delegate GetGetterCore(IRow input, int iinfo, out Action disposer) + private Delegate GetGetterCore(Row input, int iinfo, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _columns.Length); @@ -394,7 +394,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV } #region Getters - private ValueGetter ComposeGetterOne(IRow input, int iinfo, int srcCol, ColumnType srcType) + private ValueGetter ComposeGetterOne(Row input, int iinfo, int srcCol, ColumnType srcType) { Host.Assert(HashingEstimator.IsColumnTypeValid(srcType)); @@ -452,7 +452,7 @@ private ValueGetter ComposeGetterOne(IRow input, int iinfo, int srcCol, Co } } - private ValueGetter> ComposeGetterVec(IRow input, int iinfo, int srcCol, ColumnType srcType) + private ValueGetter> ComposeGetterVec(Row input, int iinfo, int srcCol, ColumnType srcType) { Host.Assert(srcType.IsVector); Host.Assert(HashingEstimator.IsColumnTypeValid(srcType.ItemType)); @@ -505,7 +505,7 @@ private ValueGetter> ComposeGetterVec(IRow input, int iinfo, int s } } - private ValueGetter> ComposeGetterVecCore(IRow input, int iinfo, int srcCol, ColumnType srcType) + private ValueGetter> ComposeGetterVecCore(Row input, int iinfo, int srcCol, ColumnType srcType) where THash : struct, IHasher { Host.Assert(srcType.IsVector); @@ -709,7 +709,7 @@ public uint HashCore(uint seed, uint mask, in long value) } } - private static ValueGetter MakeScalarHashGetter(IRow input, int srcCol, uint seed, uint mask) + private static ValueGetter MakeScalarHashGetter(Row input, int srcCol, uint seed, uint mask) where THash : struct, IHasher { Contracts.Assert(Utils.IsPowerOfTwo(mask + 1)); @@ -914,18 +914,18 @@ private void AddMetaKeyValues(int i, MetadataBuilder builder) builder.AddKeyValues(_parent._kvTypes[i].VectorSize, _parent._kvTypes[i].ItemType.AsPrimitive, getter); } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) => _parent.GetGetterCore(input, iinfo, out disposer); + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) => _parent.GetGetterCore(input, iinfo, out disposer); } private abstract class InvertHashHelper { - protected readonly IRow Row; + protected readonly Row Row; private readonly bool _includeSlot; private readonly ColumnInfo _ex; private readonly ColumnType _srcType; private readonly int _srcCol; - private InvertHashHelper(IRow row, ColumnInfo ex) + private InvertHashHelper(Row row, ColumnInfo ex) { Contracts.AssertValue(row); Row = row; @@ -946,13 +946,13 @@ private InvertHashHelper(IRow row, ColumnInfo ex) /// The extra column info /// The number of input hashed valuPres to accumulate per output hash value /// A hash getter, built on top of . - public static InvertHashHelper Create(IRow row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) + public static InvertHashHelper Create(Row row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) { row.Schema.TryGetColumnIndex(ex.Input, out int srcCol); ColumnType typeSrc = row.Schema.GetColumnType(srcCol); Type t = typeSrc.IsVector ? (ex.Ordered ? typeof(ImplVecOrdered<>) : typeof(ImplVec<>)) : typeof(ImplOne<>); t = t.MakeGenericType(typeSrc.ItemType.RawType); - var consTypes = new Type[] { typeof(IRow), typeof(ColumnInfo), typeof(int), typeof(Delegate) }; + var consTypes = new Type[] { typeof(Row), typeof(ColumnInfo), typeof(int), typeof(Delegate) }; var constructorInfo = t.GetConstructor(consTypes); return (InvertHashHelper)constructorInfo.Invoke(new object[] { row, ex, invertHashMaxCount, dstGetter }); } @@ -1029,7 +1029,7 @@ private abstract class Impl : InvertHashHelper { protected readonly InvertHashCollector Collector; - protected Impl(IRow row, ColumnInfo ex, int invertHashMaxCount) + protected Impl(Row row, ColumnInfo ex, int invertHashMaxCount) : base(row, ex) { Contracts.AssertValue(row); @@ -1062,7 +1062,7 @@ private sealed class ImplOne : Impl private T _value; private uint _hash; - public ImplOne(IRow row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) + public ImplOne(Row row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) : base(row, ex, invertHashMaxCount) { _srcGetter = Row.GetGetter(_srcCol); @@ -1096,7 +1096,7 @@ private sealed class ImplVec : Impl private VBuffer _value; private VBuffer _hash; - public ImplVec(IRow row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) + public ImplVec(Row row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) : base(row, ex, invertHashMaxCount) { _srcGetter = Row.GetGetter>(_srcCol); @@ -1130,7 +1130,7 @@ private sealed class ImplVecOrdered : Impl> private VBuffer _value; private VBuffer _hash; - public ImplVecOrdered(IRow row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) + public ImplVecOrdered(Row row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) : base(row, ex, invertHashMaxCount) { _srcGetter = Row.GetGetter>(_srcCol); diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs index fd6af73412..0fe412ee84 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs @@ -211,7 +211,7 @@ public void SaveAsPfa(BoundPfaContext ctx) ctx.DeclareVar(toDeclare.ToArray()); } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _types.Length); @@ -294,7 +294,7 @@ protected KeyToValueMap(Mapper mapper, PrimitiveType typeVal, int iinfo) InfoIndex = iinfo; } - public abstract Delegate GetMappingGetter(IRow input); + public abstract Delegate GetMappingGetter(Row input); public abstract JToken SavePfa(BoundPfaContext ctx, JToken srcToken); } @@ -346,7 +346,7 @@ private void MapKey(in TKey src, ReadOnlySpan values, ref TValue dst) dst = _na; } - public override Delegate GetMappingGetter(IRow input) + public override Delegate GetMappingGetter(Row input) { // When constructing the getter, there are a few cases we have to consider: // If scalar then it's just a straightforward mapping. diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs index 363bc4c83f..98193ad24e 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs @@ -438,7 +438,7 @@ private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) dst = new VBuffer(ranges.Length, ranges); } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _infos.Length); @@ -456,7 +456,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac /// This is for the singleton case. This should be equivalent to both Bag and Ord over /// a vector of size one. /// - private ValueGetter> MakeGetterOne(IRow input, int iinfo) + private ValueGetter> MakeGetterOne(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(_infos[iinfo].TypeSrc.IsKey); @@ -489,7 +489,7 @@ private ValueGetter> MakeGetterOne(IRow input, int iinfo) /// /// This is for the bagging case - vector input and outputs should be added. /// - private ValueGetter> MakeGetterBag(IRow input, int iinfo) + private ValueGetter> MakeGetterBag(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(_infos[iinfo].TypeSrc.IsVector); @@ -533,7 +533,7 @@ private ValueGetter> MakeGetterBag(IRow input, int iinfo) /// /// This is for the indicator (non-bagging) case - vector input and outputs should be concatenated. /// - private ValueGetter> MakeGetterInd(IRow input, int iinfo) + private ValueGetter> MakeGetterInd(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(_infos[iinfo].TypeSrc.IsVector); diff --git a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs index f63b97333b..3832ab6f8a 100644 --- a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs @@ -165,7 +165,7 @@ private bool PassThrough(string kind, int iinfo) return kind != MetadataUtils.Kinds.KeyValues; } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + protected override Delegate GetGetterCore(IChannel ch, Row input, int iinfo, out Action disposer) { Contracts.AssertValueOrNull(ch); Contracts.AssertValue(input); @@ -190,34 +190,34 @@ protected override VectorType GetSlotTypeCore(int iinfo) return _slotType; } - protected override ISlotCursor GetSlotCursorCore(int iinfo) + protected override SlotCursor GetSlotCursorCore(int iinfo) { Host.Assert(0 <= iinfo && iinfo < Infos.Length); Host.AssertValue(Infos[iinfo].SlotTypeSrc); - ISlotCursor cursor = InputTranspose.GetSlotCursor(Infos[iinfo].Source); - return new SlotCursor(Host, cursor, GetSlotTypeCore(iinfo)); + var cursor = InputTranspose.GetSlotCursor(Infos[iinfo].Source); + return new SlotCursorImpl(Host, cursor, GetSlotTypeCore(iinfo)); } - private sealed class SlotCursor : SynchronizedCursorBase, ISlotCursor + private sealed class SlotCursorImpl : SlotCursor.SynchronizedSlotCursor { private readonly Delegate _getter; private readonly VectorType _type; - public SlotCursor(IChannelProvider provider, ISlotCursor cursor, VectorType typeDst) + public SlotCursorImpl(IChannelProvider provider, SlotCursor cursor, VectorType typeDst) : base(provider, cursor) { Ch.AssertValue(typeDst); - _getter = RowCursorUtils.GetLabelGetter(Input); + _getter = RowCursorUtils.GetLabelGetter(cursor); _type = typeDst; } - public VectorType GetSlotType() + public override VectorType GetSlotType() { return _type; } - public ValueGetter> GetGetter() + public override ValueGetter> GetGetter() { ValueGetter> getter = _getter as ValueGetter>; if (getter == null) diff --git a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs index 3734245258..2c5acafd1d 100644 --- a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs @@ -163,7 +163,7 @@ protected override ColumnType GetColumnTypeCore(int iinfo) return BoolType.Instance; } - protected override Delegate GetGetterCore(IChannel ch, IRow input, + protected override Delegate GetGetterCore(IChannel ch, Row input, int iinfo, out Action disposer) { Host.AssertValue(ch); @@ -175,7 +175,7 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, return GetGetter(ch, input, iinfo); } - private ValueGetter GetGetter(IChannel ch, IRow input, int iinfo) + private ValueGetter GetGetter(IChannel ch, Row input, int iinfo) { Host.AssertValue(ch); ch.AssertValue(input); diff --git a/src/Microsoft.ML.Data/Transforms/NAFilter.cs b/src/Microsoft.ML.Data/Transforms/NAFilter.cs index 36c9f43368..1f7a192627 100644 --- a/src/Microsoft.ML.Data/Transforms/NAFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/NAFilter.cs @@ -204,7 +204,7 @@ private static bool TestType(ColumnType type) return null; } - protected override IRowCursor GetRowCursorCore(Func predicate, Random rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Host.AssertValue(predicate, "predicate"); Host.AssertValueOrNull(rand); @@ -212,10 +212,10 @@ protected override IRowCursor GetRowCursorCore(Func predicate, Random bool[] active; Func inputPred = GetActive(predicate, out active); var input = Source.GetRowCursor(inputPred, rand); - return new RowCursor(this, input, active); + return new Cursor(this, input, active); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, + public override RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); @@ -227,9 +227,9 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid Host.AssertNonEmpty(inputs); // No need to split if this is given 1 input cursor. - var cursors = new IRowCursor[inputs.Length]; + var cursors = new RowCursor[inputs.Length]; for (int i = 0; i < inputs.Length; i++) - cursors[i] = new RowCursor(this, inputs[i], active); + cursors[i] = new Cursor(this, inputs[i], active); return cursors; } @@ -245,13 +245,13 @@ private Func GetActive(Func predicate, out bool[] active) return col => activeInput[col]; } - private sealed class RowCursor : LinkedRowFilterCursorBase + private sealed class Cursor : LinkedRowFilterCursorBase { private abstract class Value { - protected readonly RowCursor Cursor; + protected readonly Cursor Cursor; - protected Value(RowCursor cursor) + protected Value(Cursor cursor) { Contracts.AssertValue(cursor); Cursor = cursor; @@ -261,7 +261,7 @@ protected Value(RowCursor cursor) public abstract Delegate GetGetter(); - public static Value Create(RowCursor cursor, ColInfo info) + public static Value Create(Cursor cursor, ColInfo info) { Contracts.AssertValue(cursor); Contracts.AssertValue(info); @@ -269,18 +269,18 @@ public static Value Create(RowCursor cursor, ColInfo info) MethodInfo meth; if (info.Type is VectorType vecType) { - Func> d = CreateVec; + Func> d = CreateVec; meth = d.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(vecType.ItemType.RawType); } else { - Func> d = CreateOne; + Func> d = CreateOne; meth = d.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(info.Type.RawType); } return (Value)meth.Invoke(null, new object[] { cursor, info }); } - private static ValueOne CreateOne(RowCursor cursor, ColInfo info) + private static ValueOne CreateOne(Cursor cursor, ColInfo info) { Contracts.AssertValue(cursor); Contracts.AssertValue(info); @@ -292,7 +292,7 @@ private static ValueOne CreateOne(RowCursor cursor, ColInfo info) return new ValueOne(cursor, getSrc, hasBad); } - private static ValueVec CreateVec(RowCursor cursor, ColInfo info) + private static ValueVec CreateVec(Cursor cursor, ColInfo info) { Contracts.AssertValue(cursor); Contracts.AssertValue(info); @@ -310,7 +310,7 @@ private abstract class TypedValue : Value private readonly InPredicate _hasBad; public T Src; - protected TypedValue(RowCursor cursor, ValueGetter getSrc, InPredicate hasBad) + protected TypedValue(Cursor cursor, ValueGetter getSrc, InPredicate hasBad) : base(cursor) { Contracts.AssertValue(getSrc); @@ -330,7 +330,7 @@ private sealed class ValueOne : TypedValue { private readonly ValueGetter _getter; - public ValueOne(RowCursor cursor, ValueGetter getSrc, InPredicate hasBad) + public ValueOne(Cursor cursor, ValueGetter getSrc, InPredicate hasBad) : base(cursor, getSrc, hasBad) { _getter = GetValue; @@ -352,7 +352,7 @@ private sealed class ValueVec : TypedValue> { private readonly ValueGetter> _getter; - public ValueVec(RowCursor cursor, ValueGetter> getSrc, InPredicate> hasBad) + public ValueVec(Cursor cursor, ValueGetter> getSrc, InPredicate> hasBad) : base(cursor, getSrc, hasBad) { _getter = GetValue; @@ -374,7 +374,7 @@ public override Delegate GetGetter() private readonly NAFilter _parent; private readonly Value[] _values; - public RowCursor(NAFilter parent, IRowCursor input, bool[] active) + public Cursor(NAFilter parent, RowCursor input, bool[] active) : base(parent.Host, input, parent.OutputSchema, active) { _parent = parent; diff --git a/src/Microsoft.ML.Data/Transforms/NopTransform.cs b/src/Microsoft.ML.Data/Transforms/NopTransform.cs index 0d0a619c1c..3d4219a60d 100644 --- a/src/Microsoft.ML.Data/Transforms/NopTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/NopTransform.cs @@ -118,12 +118,12 @@ public bool CanShuffle return Source.GetRowCount(); } - public IRowCursor GetRowCursor(Func predicate, Random rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { return Source.GetRowCursor(predicate, rand); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { return Source.GetRowCursorSet(out consolidator, predicate, n, rand); } @@ -133,7 +133,7 @@ public Func GetDependencies(Func predicate) return predicate; } - public IRow GetRow(IRow input, Func active, out Action disposer) + public Row GetRow(Row input, Func active, out Action disposer) { Contracts.CheckValue(input, nameof(input)); Contracts.CheckValue(active, nameof(active)); diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs index 7c3bdd363c..a57f6270a1 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs @@ -361,7 +361,7 @@ private AffineColumnFunction(IHost host) public bool CanSaveOnnx(OnnxContext ctx) => true; public abstract bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount); - public abstract Delegate GetGetter(IRow input, int icol); + public abstract Delegate GetGetter(Row input, int icol); public abstract void AttachMetadata(MetadataDispatcher.Builder bldr, ColumnType typeSrc); @@ -480,7 +480,7 @@ private CdfColumnFunction(IHost host) public bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount) => throw Host.ExceptNotSupp(); - public abstract Delegate GetGetter(IRow input, int icol); + public abstract Delegate GetGetter(Row input, int icol); public abstract void AttachMetadata(MetadataDispatcher.Builder bldr, ColumnType typeSrc); public abstract NormalizingTransformer.NormalizerModelParametersBase GetNormalizerModelParams(); @@ -609,7 +609,7 @@ protected BinColumnFunction(IHost host) public bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount) => throw Host.ExceptNotSupp(); - public abstract Delegate GetGetter(IRow input, int icol); + public abstract Delegate GetGetter(Row input, int icol); public void AttachMetadata(MetadataDispatcher.Builder bldr, ColumnType typeSrc) { @@ -732,7 +732,7 @@ private abstract class SupervisedBinFunctionBuilderBase : IColumnFunctionBuilder protected readonly int LabelCardinality; private readonly ValueGetter _labelGetterSrc; - protected SupervisedBinFunctionBuilderBase(IHost host, long lim, int labelColId, IRow dataRow) + protected SupervisedBinFunctionBuilderBase(IHost host, long lim, int labelColId, Row dataRow) { Contracts.CheckValue(host, nameof(host)); Host = host; @@ -742,7 +742,7 @@ protected SupervisedBinFunctionBuilderBase(IHost host, long lim, int labelColId, _labelGetterSrc = GetLabelGetter(dataRow, labelColId, out LabelCardinality); } - private ValueGetter GetLabelGetter(IRow row, int col, out int labelCardinality) + private ValueGetter GetLabelGetter(Row row, int col, out int labelCardinality) { // The label column type is checked as part of args validation. var type = row.Schema.GetColumnType(col); @@ -816,7 +816,7 @@ private abstract class OneColumnSupervisedBinFunctionBuilderBase : Super protected readonly List ColValues; protected OneColumnSupervisedBinFunctionBuilderBase(IHost host, long lim, int valueColId, int labelColId, - IRow dataRow) + Row dataRow) : base(host, lim, labelColId, dataRow) { _colGetterSrc = dataRow.GetGetter(valueColId); @@ -844,7 +844,7 @@ private abstract class VecColumnSupervisedBinFunctionBuilderBase : Super protected readonly List[] ColValues; protected readonly int ColumnSlotCount; - protected VecColumnSupervisedBinFunctionBuilderBase(IHost host, long lim, int valueColId, int labelColId, IRow dataRow) + protected VecColumnSupervisedBinFunctionBuilderBase(IHost host, long lim, int valueColId, int labelColId, Row dataRow) : base(host, lim, labelColId, dataRow) { _colValueGetter = dataRow.GetGetter>(valueColId); @@ -899,7 +899,7 @@ protected override bool AcceptColumnValue() internal static partial class MinMaxUtils { public static IColumnFunctionBuilder CreateBuilder(MinMaxArguments args, IHost host, - int icol, int srcIndex, ColumnType srcType, IRowCursor cursor) + int icol, int srcIndex, ColumnType srcType, RowCursor cursor) { Contracts.AssertValue(host); host.AssertValue(args); @@ -912,7 +912,7 @@ public static IColumnFunctionBuilder CreateBuilder(MinMaxArguments args, IHost h } public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.MinMaxColumn column, IHost host, - int srcIndex, ColumnType srcType, IRowCursor cursor) + int srcIndex, ColumnType srcType, RowCursor cursor) { if (srcType.IsNumber) { @@ -935,7 +935,7 @@ public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.MinMaxCo internal static partial class MeanVarUtils { public static IColumnFunctionBuilder CreateBuilder(MeanVarArguments args, IHost host, - int icol, int srcIndex, ColumnType srcType, IRowCursor cursor) + int icol, int srcIndex, ColumnType srcType, RowCursor cursor) { Contracts.AssertValue(host); host.AssertValue(args); @@ -949,7 +949,7 @@ public static IColumnFunctionBuilder CreateBuilder(MeanVarArguments args, IHost } public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.MeanVarColumn column, IHost host, - int srcIndex, ColumnType srcType, IRowCursor cursor) + int srcIndex, ColumnType srcType, RowCursor cursor) { Contracts.AssertValue(host); @@ -975,7 +975,7 @@ public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.MeanVarC internal static partial class LogMeanVarUtils { public static IColumnFunctionBuilder CreateBuilder(LogMeanVarArguments args, IHost host, - int icol, int srcIndex, ColumnType srcType, IRowCursor cursor) + int icol, int srcIndex, ColumnType srcType, RowCursor cursor) { Contracts.AssertValue(host); host.AssertValue(args); @@ -988,7 +988,7 @@ public static IColumnFunctionBuilder CreateBuilder(LogMeanVarArguments args, IHo } public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.LogMeanVarColumn column, IHost host, - int srcIndex, ColumnType srcType, IRowCursor cursor) + int srcIndex, ColumnType srcType, RowCursor cursor) { Contracts.AssertValue(host); host.AssertValue(column); @@ -1014,7 +1014,7 @@ public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.LogMeanV internal static partial class BinUtils { public static IColumnFunctionBuilder CreateBuilder(BinArguments args, IHost host, - int icol, int srcIndex, ColumnType srcType, IRowCursor cursor) + int icol, int srcIndex, ColumnType srcType, RowCursor cursor) { Contracts.AssertValue(host); host.AssertValue(args); @@ -1028,7 +1028,7 @@ public static IColumnFunctionBuilder CreateBuilder(BinArguments args, IHost host } public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.BinningColumn column, IHost host, - int srcIndex, ColumnType srcType, IRowCursor cursor) + int srcIndex, ColumnType srcType, RowCursor cursor) { Contracts.AssertValue(host); @@ -1053,7 +1053,7 @@ public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.BinningC internal static class SupervisedBinUtils { public static IColumnFunctionBuilder CreateBuilder(SupervisedBinArguments args, IHost host, - int icol, int srcIndex, ColumnType srcType, IRowCursor cursor) + int icol, int srcIndex, ColumnType srcType, RowCursor cursor) { Contracts.AssertValue(host); host.AssertValue(args); diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs index 5792486012..42a03e3eb6 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs @@ -583,7 +583,7 @@ public override bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int fe return true; } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { var getSrc = input.GetGetter(icol); ValueGetter del = @@ -659,7 +659,7 @@ public override bool OnnxInfo(OnnxContext ctx, OnnxNode node, int featureCount) return true; } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { var getSrc = input.GetGetter>(icol); var bldr = new BufferBuilder(R8Adder.Instance); @@ -901,7 +901,7 @@ public override void Save(ModelSaveContext ctx) CdfNormSerializationUtils.SaveModel(ctx, UseLog, new[] { Mean }, new[] { Stddev }); } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { if (Stddev <= TFloat.Epsilon) { @@ -956,7 +956,7 @@ public override void Save(ModelSaveContext ctx) CdfNormSerializationUtils.SaveModel(ctx, UseLog, Mean, Stddev); } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { var getSrc = input.GetGetter>(icol); var bldr = new BufferBuilder(R8Adder.Instance); @@ -1085,7 +1085,7 @@ public override void Save(ModelSaveContext ctx) c => BinNormSerializationUtils.SaveModel(c, new[] { _binUpperBounds }, saveText: true)); } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { var getSrc = input.GetGetter(icol); ValueGetter del = @@ -1170,7 +1170,7 @@ public override void Save(ModelSaveContext ctx) ctx.SaveSubModel("BinNormalizer", c => BinNormSerializationUtils.SaveModel(c, _binUpperBounds, saveText: true)); } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { var getSrc = input.GetGetter>(icol); var bldr = new BufferBuilder(R8Adder.Instance); @@ -1842,7 +1842,7 @@ public sealed class SupervisedBinOneColumnFunctionBuilder : OneColumnSupervisedB private readonly int _numBins; private readonly int _minBinSize; - private SupervisedBinOneColumnFunctionBuilder(IHost host, long lim, bool fix, int numBins, int minBinSize, int valueColumnId, int labelColumnId, IRow dataRow) + private SupervisedBinOneColumnFunctionBuilder(IHost host, long lim, bool fix, int numBins, int minBinSize, int valueColumnId, int labelColumnId, Row dataRow) : base(host, lim, valueColumnId, labelColumnId, dataRow) { _fix = fix; @@ -1862,7 +1862,7 @@ public override IColumnFunction CreateColumnFunction() return BinColumnFunction.Create(Host, binUpperBounds, _fix); } - public static IColumnFunctionBuilder Create(SupervisedBinArguments args, IHost host, int argsColumnIndex, int valueColumnId, int labelColumnId, IRow dataRow) + public static IColumnFunctionBuilder Create(SupervisedBinArguments args, IHost host, int argsColumnIndex, int valueColumnId, int labelColumnId, Row dataRow) { var lim = args.Column[argsColumnIndex].MaxTrainingExamples ?? args.MaxTrainingExamples; host.CheckUserArg(lim > 1, nameof(args.MaxTrainingExamples), "Must be greater than 1"); @@ -1880,7 +1880,7 @@ public sealed class SupervisedBinVecColumnFunctionBuilder : VecColumnSupervisedB private readonly int _numBins; private readonly int _minBinSize; - private SupervisedBinVecColumnFunctionBuilder(IHost host, long lim, bool fix, int numBins, int minBinSize, int valueColumnId, int labelColumnId, IRow dataRow) + private SupervisedBinVecColumnFunctionBuilder(IHost host, long lim, bool fix, int numBins, int minBinSize, int valueColumnId, int labelColumnId, Row dataRow) : base(host, lim, valueColumnId, labelColumnId, dataRow) { _fix = fix; @@ -1902,7 +1902,7 @@ public override IColumnFunction CreateColumnFunction() return BinColumnFunction.Create(Host, binUpperBounds, _fix); } - public static IColumnFunctionBuilder Create(SupervisedBinArguments args, IHost host, int argsColumnIndex, int valueColumnId, int labelColumnId, IRow dataRow) + public static IColumnFunctionBuilder Create(SupervisedBinArguments args, IHost host, int argsColumnIndex, int valueColumnId, int labelColumnId, Row dataRow) { var lim = args.Column[argsColumnIndex].MaxTrainingExamples ?? args.MaxTrainingExamples; host.CheckUserArg(lim > 1, nameof(args.MaxTrainingExamples), "Must be greater than 1"); diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs index 54e6693f76..d310dbb2e4 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs @@ -585,7 +585,7 @@ public override bool OnnxInfo(OnnxContext ctx, OnnxNode node, int featureCount) return true; } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { var getSrc = input.GetGetter(icol); ValueGetter del = @@ -660,7 +660,7 @@ public override bool OnnxInfo(OnnxContext ctx, OnnxNode node, int featureCount) return true; } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { var getSrc = input.GetGetter>(icol); var bldr = new BufferBuilder(R4Adder.Instance); @@ -905,7 +905,7 @@ public override void Save(ModelSaveContext ctx) CdfNormSerializationUtils.SaveModel(ctx, UseLog, new[] { Mean }, new[] { Stddev }); } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { if (Stddev <= TFloat.Epsilon) { @@ -960,7 +960,7 @@ public override void Save(ModelSaveContext ctx) CdfNormSerializationUtils.SaveModel(ctx, UseLog, Mean, Stddev); } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { var getSrc = input.GetGetter>(icol); var bldr = new BufferBuilder(R4Adder.Instance); @@ -1090,7 +1090,7 @@ public override void Save(ModelSaveContext ctx) c => BinNormSerializationUtils.SaveModel(c, new[] { _binUpperBounds }, saveText: true)); } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { var getSrc = input.GetGetter(icol); ValueGetter del = @@ -1175,7 +1175,7 @@ public override void Save(ModelSaveContext ctx) ctx.SaveSubModel("BinNormalizer", c => BinNormSerializationUtils.SaveModel(c, _binUpperBounds, saveText: true)); } - public override Delegate GetGetter(IRow input, int icol) + public override Delegate GetGetter(Row input, int icol) { var getSrc = input.GetGetter>(icol); var bldr = new BufferBuilder(R4Adder.Instance); @@ -1850,7 +1850,7 @@ public sealed class SupervisedBinOneColumnFunctionBuilder : OneColumnSupervisedB private readonly int _numBins; private readonly int _minBinSize; - private SupervisedBinOneColumnFunctionBuilder(IHost host, long lim, bool fix, int numBins, int minBinSize, int valueColumnId, int labelColumnId, IRow dataRow) + private SupervisedBinOneColumnFunctionBuilder(IHost host, long lim, bool fix, int numBins, int minBinSize, int valueColumnId, int labelColumnId, Row dataRow) : base(host, lim, valueColumnId, labelColumnId, dataRow) { _fix = fix; @@ -1870,7 +1870,7 @@ public override IColumnFunction CreateColumnFunction() return BinColumnFunction.Create(Host, binUpperBounds, _fix); } - public static IColumnFunctionBuilder Create(SupervisedBinArguments args, IHost host, int argsColumnIndex, int valueColumnId, int labelColumnId, IRow dataRow) + public static IColumnFunctionBuilder Create(SupervisedBinArguments args, IHost host, int argsColumnIndex, int valueColumnId, int labelColumnId, Row dataRow) { var lim = args.Column[argsColumnIndex].MaxTrainingExamples ?? args.MaxTrainingExamples; host.CheckUserArg(lim > 1, nameof(args.MaxTrainingExamples), "Must be greater than 1"); @@ -1888,7 +1888,7 @@ public sealed class SupervisedBinVecColumnFunctionBuilder : VecColumnSupervisedB private readonly int _numBins; private readonly int _minBinSize; - private SupervisedBinVecColumnFunctionBuilder(IHost host, long lim, bool fix, int numBins, int minBinSize, int valueColumnId, int labelColumnId, IRow dataRow) + private SupervisedBinVecColumnFunctionBuilder(IHost host, long lim, bool fix, int numBins, int minBinSize, int valueColumnId, int labelColumnId, Row dataRow) : base(host, lim, valueColumnId, labelColumnId, dataRow) { _fix = fix; @@ -1910,7 +1910,7 @@ public override IColumnFunction CreateColumnFunction() return BinColumnFunction.Create(Host, binUpperBounds, _fix); } - public static IColumnFunctionBuilder Create(SupervisedBinArguments args, IHost host, int argsColumnIndex, int valueColumnId, int labelColumnId, IRow dataRow) + public static IColumnFunctionBuilder Create(SupervisedBinArguments args, IHost host, int argsColumnIndex, int valueColumnId, int labelColumnId, Row dataRow) { var lim = args.Column[argsColumnIndex].MaxTrainingExamples ?? args.MaxTrainingExamples; host.CheckUserArg(lim > 1, nameof(args.MaxTrainingExamples), "Must be greater than 1"); diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs b/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs index cc3f4e352f..e84d4eb8a4 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs @@ -55,7 +55,7 @@ public interface IColumnAggregator internal interface IColumnFunction : ICanSaveModel { - Delegate GetGetter(IRow input, int icol); + Delegate GetGetter(Row input, int icol); void AttachMetadata(MetadataDispatcher.Builder bldr, ColumnType typeSrc); diff --git a/src/Microsoft.ML.Data/Transforms/Normalizer.cs b/src/Microsoft.ML.Data/Transforms/Normalizer.cs index 1434b680a7..222fea019c 100644 --- a/src/Microsoft.ML.Data/Transforms/Normalizer.cs +++ b/src/Microsoft.ML.Data/Transforms/Normalizer.cs @@ -74,7 +74,7 @@ private protected ColumnBase(string input, string output, long maxTrainingExampl MaxTrainingExamples = maxTrainingExamples; } - internal abstract IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, IRowCursor cursor); + internal abstract IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, RowCursor cursor); internal static ColumnBase Create(string input, string output, NormalizerMode mode) { @@ -112,7 +112,7 @@ public MinMaxColumn(string input, string output = null, long maxTrainingExamples { } - internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, IRowCursor cursor) + internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, RowCursor cursor) => NormalizeTransform.MinMaxUtils.CreateBuilder(this, host, srcIndex, srcType, cursor); } @@ -127,7 +127,7 @@ public MeanVarColumn(string input, string output = null, UseCdf = useCdf; } - internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, IRowCursor cursor) + internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, RowCursor cursor) => NormalizeTransform.MeanVarUtils.CreateBuilder(this, host, srcIndex, srcType, cursor); } @@ -142,7 +142,7 @@ public LogMeanVarColumn(string input, string output = null, UseCdf = useCdf; } - internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, IRowCursor cursor) + internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, RowCursor cursor) => NormalizeTransform.LogMeanVarUtils.CreateBuilder(this, host, srcIndex, srcType, cursor); } @@ -157,7 +157,7 @@ public BinningColumn(string input, string output = null, NumBins = numBins; } - internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, IRowCursor cursor) + internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, RowCursor cursor) => NormalizeTransform.BinUtils.CreateBuilder(this, host, srcIndex, srcType, cursor); } @@ -516,7 +516,7 @@ private void IsNormalizedGetter(ref bool dst) dst = true; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { disposer = null; return _parent.Columns[iinfo].ColumnFunction.GetGetter(input, ColMapNewToOld[iinfo]); diff --git a/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs b/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs index 38eea76459..625bd22e9e 100644 --- a/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs +++ b/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs @@ -152,15 +152,15 @@ public virtual void Save(ModelSaveContext ctx) return Source.GetRowCount(); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); consolidator = null; - return new IRowCursor[] { GetRowCursor(predicate, rand) }; + return new RowCursor[] { GetRowCursor(predicate, rand) }; } - public IRowCursor GetRowCursor(Func predicate, Random rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); @@ -178,14 +178,14 @@ public IRowCursor GetRowCursor(Func predicate, Random rand = null) return GetRowCursorCore(predicate); } - private IRowCursor GetRowCursorCore(Func predicate) + private RowCursor GetRowCursorCore(Func predicate) { var bindings = GetBindings(); var active = bindings.GetActive(predicate); Contracts.Assert(active.Length == bindings.ColumnCount); var predInput = bindings.GetDependencies(predicate); - return new RowCursor(this, Source.GetRowCursor(predInput, null), Source.GetRowCursor(predInput, null), active); + return new Cursor(this, Source.GetRowCursor(predInput, null), Source.GetRowCursor(predInput, null), active); } /// @@ -199,17 +199,17 @@ private IRowCursor GetRowCursorCore(Func predicate) /// /// Get the getter for the first input column. /// - protected abstract ValueGetter GetLabelGetter(IRow row); + protected abstract ValueGetter GetLabelGetter(Row row); /// /// Get the getter for the second input column. /// - protected abstract ValueGetter GetScoreGetter(IRow row); + protected abstract ValueGetter GetScoreGetter(Row row); /// /// Return a new state object. /// - protected abstract TState InitializeState(IRow input); + protected abstract TState InitializeState(Row input); /// /// Update the state object with one example. @@ -222,11 +222,11 @@ private IRowCursor GetRowCursorCore(Func predicate) /// protected abstract void UpdateState(TState state); - private sealed class RowCursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { private readonly PerGroupTransformBase _parent; - private readonly IRowCursor _groupCursor; - private readonly IRowCursor _input; + private readonly RowCursor _groupCursor; + private readonly RowCursor _input; private readonly bool[] _active; private readonly Delegate[] _getters; @@ -237,11 +237,11 @@ private sealed class RowCursor : RootCursorBase, IRowCursor private readonly ValueGetter _labelGetter; private readonly ValueGetter _scoreGetter; - public Schema Schema => _parent.OutputSchema; + public override Schema Schema => _parent.OutputSchema; - public override long Batch { get { return 0; } } + public override long Batch => 0; - public RowCursor(PerGroupTransformBase parent, IRowCursor input, IRowCursor groupCursor, bool[] active) + public Cursor(PerGroupTransformBase parent, RowCursor input, RowCursor groupCursor, bool[] active) : base(parent.Host) { Ch.AssertValue(parent); @@ -267,13 +267,13 @@ public RowCursor(PerGroupTransformBase parent, IRowCurso _scoreGetter = _parent.GetScoreGetter(_groupCursor); } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < _parent.GetBindings().ColumnCount); return _active[col]; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Contracts.CheckParam(IsColumnActive(col), nameof(col), "requested column is not active"); @@ -324,8 +324,8 @@ protected override bool MoveNextCore() // Read the whole group from the auxiliary cursor. while (_groupCursor.State != CursorState.Done && !_newGroupInGroupCursorDel()) { - TLabel label = default(TLabel); - TScore score = default(TScore); + TLabel label = default; + TScore score = default; _labelGetter(ref label); _scoreGetter(ref score); _parent.ProcessExample(_state, label, score); diff --git a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs index f02a05dfd9..88413f9ef2 100644 --- a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs @@ -204,7 +204,7 @@ public override void Save(ModelSaveContext ctx) return null; } - protected override IRowCursor GetRowCursorCore(Func predicate, Random rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Host.AssertValue(predicate, "predicate"); Host.AssertValueOrNull(rand); @@ -215,7 +215,7 @@ protected override IRowCursor GetRowCursorCore(Func predicate, Random return CreateCursorCore(input, active); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, + public override RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); @@ -227,13 +227,13 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid Host.AssertNonEmpty(inputs); // No need to split if this is given 1 input cursor. - var cursors = new IRowCursor[inputs.Length]; + var cursors = new RowCursor[inputs.Length]; for (int i = 0; i < inputs.Length; i++) cursors[i] = CreateCursorCore(inputs[i], active); return cursors; } - private IRowCursor CreateCursorCore(IRowCursor input, bool[] active) + private RowCursor CreateCursorCore(RowCursor input, bool[] active) { if (_type == NumberType.R4) return new SingleRowCursor(this, input, active); @@ -268,7 +268,7 @@ private abstract class RowCursorBase : LinkedRowFilterCursorBase private readonly Double _min; private readonly Double _max; - protected RowCursorBase(RangeFilter parent, IRowCursor input, bool[] active) + protected RowCursorBase(RangeFilter parent, RowCursor input, bool[] active) : base(parent.Host, input, parent.OutputSchema, active) { Parent = parent; @@ -319,15 +319,15 @@ public override ValueGetter GetGetter(int col) return fn; } - public static IRowCursor CreateKeyRowCursor(RangeFilter filter, IRowCursor input, bool[] active) + public static RowCursor CreateKeyRowCursor(RangeFilter filter, RowCursor input, bool[] active) { Contracts.Assert(filter._type.IsKey); - Func del = CreateKeyRowCursor; + Func del = CreateKeyRowCursor; var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(filter._type.RawType); - return (IRowCursor)methodInfo.Invoke(null, new object[] { filter, input, active }); + return (RowCursor)methodInfo.Invoke(null, new object[] { filter, input, active }); } - private static IRowCursor CreateKeyRowCursor(RangeFilter filter, IRowCursor input, bool[] active) + private static RowCursor CreateKeyRowCursor(RangeFilter filter, RowCursor input, bool[] active) { Contracts.Assert(filter._type.IsKey); return new KeyRowCursor(filter, input, active); @@ -340,7 +340,7 @@ private sealed class SingleRowCursor : RowCursorBase private readonly ValueGetter _getter; private Single _value; - public SingleRowCursor(RangeFilter parent, IRowCursor input, bool[] active) + public SingleRowCursor(RangeFilter parent, RowCursor input, bool[] active) : base(parent, input, active) { Ch.Assert(Parent._type == NumberType.R4); @@ -373,7 +373,7 @@ private sealed class DoubleRowCursor : RowCursorBase private readonly ValueGetter _getter; private Double _value; - public DoubleRowCursor(RangeFilter parent, IRowCursor input, bool[] active) + public DoubleRowCursor(RangeFilter parent, RowCursor input, bool[] active) : base(parent, input, active) { Ch.Assert(Parent._type == NumberType.R8); @@ -408,7 +408,7 @@ private sealed class KeyRowCursor : RowCursorBase private readonly ValueMapper _conv; private readonly int _count; - public KeyRowCursor(RangeFilter parent, IRowCursor input, bool[] active) + public KeyRowCursor(RangeFilter parent, RowCursor input, bool[] active) : base(parent, input, active) { Ch.Assert(Parent._type.KeyCount > 0); diff --git a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs index 121d9c43fb..76da06d846 100644 --- a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs @@ -223,7 +223,7 @@ internal static bool CanShuffleAll(ISchema schema) /// /// Utility to take a cursor, and get a shuffled version of this cursor. /// - public static IRowCursor GetShuffledCursor(IChannelProvider provider, int poolRows, IRowCursor cursor, Random rand) + public static RowCursor GetShuffledCursor(IChannelProvider provider, int poolRows, RowCursor cursor, Random rand) { Contracts.CheckValue(provider, nameof(provider)); @@ -236,7 +236,7 @@ public static IRowCursor GetShuffledCursor(IChannelProvider provider, int poolRo if (poolRows == 1) return cursor; - return new RowCursor(provider, poolRows, cursor, rand); + return new Cursor(provider, poolRows, cursor, rand); } public override bool CanShuffle { get { return true; } } @@ -249,7 +249,7 @@ public static IRowCursor GetShuffledCursor(IChannelProvider provider, int poolRo return false; } - protected override IRowCursor GetRowCursorCore(Func predicate, Random rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Host.AssertValue(predicate, "predicate"); Host.AssertValueOrNull(rand); @@ -286,16 +286,16 @@ protected override IRowCursor GetRowCursorCore(Func predicate, Random // source cursor. if (rand == null || _poolRows == 1) return input; - return new RowCursor(Host, _poolRows, input, rand); + return new Cursor(Host, _poolRows, input, rand); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, + public override RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); consolidator = null; - return new IRowCursor[] { GetRowCursorCore(predicate, rand) }; + return new RowCursor[] { GetRowCursorCore(predicate, rand) }; } /// @@ -344,7 +344,7 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid /// The result is something functionally equivalent to but but considerably faster than the /// simple implementation described in the first paragraph. /// - private sealed class RowCursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { /// /// Pipes, in addition to column values, will also communicate extra information @@ -465,7 +465,7 @@ public void Fetch(int idx, ref T value) private const int _bufferDepth = 3; private readonly int _poolRows; - private readonly IRowCursor _input; + private readonly RowCursor _input; private readonly Random _rand; // This acts as mapping from the "circular" index to the actual index within the pipe. @@ -496,15 +496,12 @@ public void Fetch(int idx, ref T value) private readonly int[] _colToActivesIndex; - public Schema Schema { get { return _input.Schema; } } + public override Schema Schema => _input.Schema; - public override long Batch - { - // REVIEW: Implement cursor set support. - get { return 0; } - } + // REVIEW: Implement cursor set support. + public override long Batch => 0; - public RowCursor(IChannelProvider provider, int poolRows, IRowCursor input, Random rand) + public Cursor(IChannelProvider provider, int poolRows, RowCursor input, Random rand) : base(provider) { Ch.AssertValue(input); @@ -669,7 +666,7 @@ protected override bool MoveNextCore() return true; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.CheckParam(0 <= col && col < _colToActivesIndex.Length, nameof(col)); Ch.Assert((_colToActivesIndex[col] >= 0) == _input.IsColumnActive(col)); @@ -706,7 +703,7 @@ private ValueGetter CreateGetterDelegate(ShufflePipe pipe) return getter; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.CheckParam(0 <= col && col < _colToActivesIndex.Length, nameof(col)); Ch.CheckParam(_colToActivesIndex[col] >= 0, nameof(col), "requested column not active"); diff --git a/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs index cce93e439b..356a0cc546 100644 --- a/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs +++ b/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs @@ -69,7 +69,7 @@ protected MapperBase(IHost host, Schema inputSchema) public Schema.DetachedColumn[] GetOutputColumns() => _outputColumns.Value; - public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) + public Delegate[] CreateGetters(Row input, Func activeOutput, out Action disposer) { // REVIEW: it used to be that the mapper's input schema in the constructor was required to be reference-equal to the schema // of the input row. @@ -98,7 +98,7 @@ public Delegate[] CreateGetters(IRow input, Func activeOutput, out Ac return result; } - protected abstract Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer); + protected abstract Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer); public abstract Func GetDependencies(Func activeOutput); diff --git a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs index 7f0399fc6d..12a8cf2138 100644 --- a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs @@ -188,26 +188,26 @@ public override void Save(ModelSaveContext ctx) return false; } - protected override IRowCursor GetRowCursorCore(Func predicate, Random rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Host.AssertValue(predicate); Host.AssertValueOrNull(rand); var input = Source.GetRowCursor(predicate); var activeColumns = Utils.BuildArray(OutputSchema.ColumnCount, predicate); - return new RowCursor(Host, input, OutputSchema, activeColumns, _skip, _take); + return new Cursor(Host, input, OutputSchema, activeColumns, _skip, _take); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, + public override RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); consolidator = null; - return new IRowCursor[] { GetRowCursorCore(predicate) }; + return new RowCursor[] { GetRowCursorCore(predicate) }; } - private sealed class RowCursor : LinkedRowRootCursorBase + private sealed class Cursor : LinkedRowRootCursorBase { private readonly long _skip; private readonly long _take; @@ -219,7 +219,7 @@ public override long Batch { get { return 0; } } - public RowCursor(IChannelProvider provider, IRowCursor input, Schema schema, bool[] active, long skip, long take) + public Cursor(IChannelProvider provider, RowCursor input, Schema schema, bool[] active, long skip, long take) : base(provider, input, schema, active) { Ch.Assert(skip >= 0); diff --git a/src/Microsoft.ML.Data/Transforms/TransformBase.cs b/src/Microsoft.ML.Data/Transforms/TransformBase.cs index cbf60b6967..4618b756f7 100644 --- a/src/Microsoft.ML.Data/Transforms/TransformBase.cs +++ b/src/Microsoft.ML.Data/Transforms/TransformBase.cs @@ -60,7 +60,7 @@ protected TransformBase(IHost host, IDataView input) public abstract Schema OutputSchema { get; } - public IRowCursor GetRowCursor(Func predicate, Random rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); @@ -72,7 +72,7 @@ public IRowCursor GetRowCursor(Func predicate, Random rand = null) // When the input wants to be split, this puts the consolidation after this transform // instead of before. This is likely to produce better performance, for example, when // this is RangeFilter. - IRowCursor curs; + RowCursor curs; if (useParallel != false && DataViewUtils.TryCreateConsolidatingCursor(out curs, this, predicate, Host, rng)) { @@ -93,9 +93,9 @@ public IRowCursor GetRowCursor(Func predicate, Random rand = null) /// /// Create a single (non-parallel) row cursor. /// - protected abstract IRowCursor GetRowCursorCore(Func predicate, Random rand = null); + protected abstract RowCursor GetRowCursorCore(Func predicate, Random rand = null); - public abstract IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, + public abstract RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null); } @@ -168,7 +168,7 @@ public Func GetDependencies(Func predicate) public Schema InputSchema => Source.Schema; - public IRow GetRow(IRow input, Func active, out Action disposer) + public Row GetRow(Row input, Func active, out Action disposer) { Host.CheckValue(input, nameof(input)); Host.CheckValue(active, nameof(active)); @@ -180,29 +180,29 @@ public IRow GetRow(IRow input, Func active, out Action disposer) Action disp; var getters = CreateGetters(input, active, out disp); disposer += disp; - return new Row(input, this, OutputSchema, getters); + return new RowImpl(input, this, OutputSchema, getters); } } - protected abstract Delegate[] CreateGetters(IRow input, Func active, out Action disp); + protected abstract Delegate[] CreateGetters(Row input, Func active, out Action disp); protected abstract int MapColumnIndex(out bool isSrc, int col); - private sealed class Row : IRow + private sealed class RowImpl : Row { private readonly Schema _schema; - private readonly IRow _input; + private readonly Row _input; private readonly Delegate[] _getters; private readonly RowToRowMapperTransformBase _parent; - public long Batch { get { return _input.Batch; } } + public override long Batch => _input.Batch; - public long Position { get { return _input.Position; } } + public override long Position => _input.Position; - public Schema Schema { get { return _schema; } } + public override Schema Schema => _schema; - public Row(IRow input, RowToRowMapperTransformBase parent, Schema schema, Delegate[] getters) + public RowImpl(Row input, RowToRowMapperTransformBase parent, Schema schema, Delegate[] getters) { _input = input; _parent = parent; @@ -210,7 +210,7 @@ public Row(IRow input, RowToRowMapperTransformBase parent, Schema schema, Delega _getters = getters; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { bool isSrc; int index = _parent.MapColumnIndex(out isSrc, col); @@ -224,12 +224,12 @@ public ValueGetter GetGetter(int col) return fn; } - public ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { return _input.GetIdGetter(); } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { bool isSrc; int index = _parent.MapColumnIndex(out isSrc, col); @@ -683,9 +683,9 @@ protected virtual void ActivateSourceColumns(int iinfo, bool[] active) /// otherwise it should be set to a delegate to be invoked by the cursor's Dispose method. It's best /// for this action to be idempotent - calling it multiple times should be equivalent to calling it once. /// - protected abstract Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer); + protected abstract Delegate GetGetterCore(IChannel ch, Row input, int iinfo, out Action disposer); - protected ValueGetter GetSrcGetter(IRow input, int iinfo) + protected ValueGetter GetSrcGetter(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < Infos.Length); @@ -694,12 +694,12 @@ protected ValueGetter GetSrcGetter(IRow input, int iinfo) return input.GetGetter(src); } - protected Delegate GetSrcGetter(ColumnType typeDst, IRow row, int iinfo) + protected Delegate GetSrcGetter(ColumnType typeDst, Row row, int iinfo) { Host.CheckValue(typeDst, nameof(typeDst)); Host.CheckValue(row, nameof(row)); - Func> del = GetSrcGetter; + Func> del = GetSrcGetter; var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(typeDst.RawType); return (Delegate)methodInfo.Invoke(this, new object[] { row, iinfo }); } @@ -727,7 +727,7 @@ protected virtual bool WantParallelCursors(Func predicate) return _bindings.AnyNewColumnsActive(predicate); } - protected override IRowCursor GetRowCursorCore(Func predicate, Random rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Host.AssertValue(predicate, "predicate"); Host.AssertValueOrNull(rand); @@ -735,10 +735,10 @@ protected override IRowCursor GetRowCursorCore(Func predicate, Random var inputPred = _bindings.GetDependencies(predicate); var active = _bindings.GetActive(predicate); var input = Source.GetRowCursor(inputPred, rand); - return new RowCursor(Host, this, input, active); + return new Cursor(Host, this, input, active); } - public sealed override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, + public sealed override RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); @@ -753,9 +753,9 @@ public sealed override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator c inputs = DataViewUtils.CreateSplitCursors(out consolidator, Host, inputs[0], n); Host.AssertNonEmpty(inputs); - var cursors = new IRowCursor[inputs.Length]; + var cursors = new RowCursor[inputs.Length]; for (int i = 0; i < inputs.Length; i++) - cursors[i] = new RowCursor(Host, this, inputs[i], active); + cursors[i] = new Cursor(Host, this, inputs[i], active); return cursors; } @@ -770,7 +770,7 @@ protected Exception ExceptGetSlotCursor(int col) OutputSchema[col].Name); } - public ISlotCursor GetSlotCursor(int col) + public SlotCursor GetSlotCursor(int col) { Host.CheckParam(0 <= col && col < _bindings.ColumnCount, nameof(col)); @@ -795,7 +795,7 @@ public ISlotCursor GetSlotCursor(int col) /// null for all new columns, and so reaching this is only possible if there is a /// bug. /// - protected virtual ISlotCursor GetSlotCursorCore(int iinfo) + protected virtual SlotCursor GetSlotCursorCore(int iinfo) { Host.Assert(false); throw Host.ExceptNotImpl("Data view indicated it could transpose a column, but apparently it could not"); @@ -811,7 +811,7 @@ protected override Func GetDependenciesCore(Func predicate return _bindings.GetDependencies(predicate); } - protected override Delegate[] CreateGetters(IRow input, Func active, out Action disposer) + protected override Delegate[] CreateGetters(Row input, Func active, out Action disposer) { Func activeInfos = iinfo => @@ -836,7 +836,7 @@ protected override Delegate[] CreateGetters(IRow input, Func active, } } - private sealed class RowCursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : SynchronizedCursorBase { private readonly Bindings _bindings; private readonly bool[] _active; @@ -844,7 +844,7 @@ private sealed class RowCursor : SynchronizedCursorBase, IRowCursor private readonly Delegate[] _getters; private readonly Action[] _disposers; - public RowCursor(IChannelProvider provider, OneToOneTransformBase parent, IRowCursor input, bool[] active) + public Cursor(IChannelProvider provider, OneToOneTransformBase parent, RowCursor input, bool[] active) : base(provider, input) { Ch.AssertValue(parent); @@ -880,9 +880,9 @@ public override void Dispose() base.Dispose(); } - public Schema Schema => _bindings.AsSchema; + public override Schema Schema => _bindings.AsSchema; - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col)); @@ -898,7 +898,7 @@ public ValueGetter GetGetter(int col) return fn; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < _bindings.ColumnCount); return _active == null || _active[col]; diff --git a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs index 331e5fff71..7c2d42168b 100644 --- a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs +++ b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs @@ -468,7 +468,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() return result; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs index 15b10c3e74..81d8e927b4 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs @@ -768,7 +768,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() return result; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); @@ -777,7 +777,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac return Utils.MarshalInvoke(MakeGetter, type.RawType, input, iinfo); } - private Delegate MakeGetter(IRow row, int src) => _termMap[src].GetMappingGetter(row); + private Delegate MakeGetter(Row row, int src) => _termMap[src].GetMappingGetter(row); private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) { diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs index c136bda9c6..87c112b87d 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs @@ -280,7 +280,7 @@ private Trainer(Builder bldr, int max) /// the input type to the desired type /// The builder we add items to /// An associated training pipe - public static Trainer Create(IRow row, int col, bool autoConvert, int count, Builder bldr) + public static Trainer Create(Row row, int col, bool autoConvert, int count, Builder bldr) { Contracts.AssertValue(row); var schema = row.Schema; @@ -297,7 +297,7 @@ public static Trainer Create(IRow row, int col, bool autoConvert, int count, Bui return Utils.MarshalInvoke(CreateOne, bldr.ItemType.RawType, row, col, autoConvert, count, bldr); } - private static Trainer CreateOne(IRow row, int col, bool autoConvert, int count, Builder bldr) + private static Trainer CreateOne(Row row, int col, bool autoConvert, int count, Builder bldr) { Contracts.AssertValue(row); Contracts.AssertValue(bldr); @@ -313,7 +313,7 @@ private static Trainer CreateOne(IRow row, int col, bool autoConvert, int cou return new ImplOne(inputGetter, count, bldrT); } - private static Trainer CreateVec(IRow row, int col, int count, Builder bldr) + private static Trainer CreateVec(Row row, int col, int count, Builder bldr) { Contracts.AssertValue(row); Contracts.AssertValue(bldr); @@ -849,7 +849,7 @@ public static BoundTermMap CreateCore(IHostEnvironment env, Schema schema, Te return new Impl(env, schema, mapT, infos, textMetadata, iinfo); } - public abstract Delegate GetMappingGetter(IRow row); + public abstract Delegate GetMappingGetter(Row row); /// /// Allows us to optionally register metadata. It is also perfectly legal for @@ -890,7 +890,7 @@ private static uint MapDefault(ValueMapper map) return dst; } - public override Delegate GetMappingGetter(IRow input) + public override Delegate GetMappingGetter(Row input) { // When constructing the getter, there are a few cases we have to consider: // If scalar then it's just a straightforward mapping. diff --git a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs index 067a4c37bb..dcce4f0018 100644 --- a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs @@ -103,12 +103,12 @@ public Func GetDependencies(Func predicate) yield break; } - public IRow GetRow(IRow input, Func predicate, out Action disposer) + public Row GetRow(Row input, Func predicate, out Action disposer) { return new SimpleRow(OutputSchema, input, new[] { CreateScoreGetter(input, predicate, out disposer) }); } - public abstract Delegate CreateScoreGetter(IRow input, Func mapperPredicate, out Action disposer); + public abstract Delegate CreateScoreGetter(Row input, Func mapperPredicate, out Action disposer); } // A generic base class for pipeline ensembles. This class contains the combiner. @@ -124,7 +124,7 @@ public Bound(SchemaBindablePipelineEnsemble parent, RoleMappedSchema schema) _combiner = parent.Combiner; } - public override Delegate CreateScoreGetter(IRow input, Func mapperPredicate, out Action disposer) + public override Delegate CreateScoreGetter(Row input, Func mapperPredicate, out Action disposer) { disposer = null; @@ -158,7 +158,7 @@ public override Delegate CreateScoreGetter(IRow input, Func mapperPre return scoreGetter; } - public ValueGetter GetLabelGetter(IRow input, int i, out Action disposer) + public ValueGetter GetLabelGetter(Row input, int i, out Action disposer) { Parent.Host.Assert(0 <= i && i < Mappers.Length); Parent.Host.Check(Mappers[i].InputRoleMappedSchema.Label != null, "Mapper was not trained using a label column"); @@ -168,7 +168,7 @@ public ValueGetter GetLabelGetter(IRow input, int i, out Action disposer return RowCursorUtils.GetLabelGetter(pipelineRow, Mappers[i].InputRoleMappedSchema.Label.Index); } - public ValueGetter GetWeightGetter(IRow input, int i, out Action disposer) + public ValueGetter GetWeightGetter(Row input, int i, out Action disposer) { Parent.Host.Assert(0 <= i && i < Mappers.Length); diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 73d032dce5..997d14b281 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -1441,7 +1441,7 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB pch.SetHeader(new ProgressHeader("features"), e => e.SetProgress(0, iFeature, features.Length)); while (cursor.MoveNext()) { - iFeature = checked((int)cursor.Position); + iFeature = cursor.SlotIndex; if (!localConstructBinFeatures[iFeature]) continue; @@ -1670,19 +1670,19 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB return result; } - private void GetFeatureValues(ISlotCursor cursor, int iFeature, ValueGetter> getter, + private void GetFeatureValues(SlotCursor cursor, int iFeature, ValueGetter> getter, ref VBuffer temp, ref VBuffer doubleTemp, ValueMapper, VBuffer> copier) { while (cursor.MoveNext()) { - Contracts.Assert(iFeature >= checked((int)cursor.Position)); + Contracts.Assert(iFeature >= cursor.SlotIndex); - if (iFeature == checked((int)cursor.Position)) + if (iFeature == cursor.SlotIndex) break; } - Contracts.Assert(cursor.Position == iFeature); + Contracts.Assert(cursor.SlotIndex == iFeature); getter(ref temp); copier(in temp, ref doubleTemp); @@ -1700,13 +1700,13 @@ private static ValueGetter> SubsetGetter(ValueGetter> g /// Returns a slot dropper object that has ranges of slots to be dropped, /// based on an examination of the feature values. /// - private static SlotDropper ConstructDropSlotRanges(ISlotCursor cursor, + private static SlotDropper ConstructDropSlotRanges(SlotCursor cursor, ValueGetter> getter, ref VBuffer temp) { // The iteration here is slightly differently from a usual cursor iteration. Here, temp // already holds the value of the cursor's current position, and we don't really want // to re-fetch it, and the cursor is necessarily advanced. - Contracts.Assert(cursor.State == CursorState.Good); + Contracts.Assert(cursor.SlotIndex >= 0); BitArray rowHasMissing = new BitArray(temp.Length); for (; ; ) { @@ -3302,7 +3302,7 @@ public int GetLeaf(int treeId, in VBuffer features, ref List path) return TrainedEnsemble.GetTreeAt(treeId).GetLeaf(in features, ref path); } - public IRow GetSummaryIRowOrNull(RoleMappedSchema schema) + public Row GetSummaryIRowOrNull(RoleMappedSchema schema) { var names = default(VBuffer>); MetadataUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, NumFeatures, ref names); @@ -3317,7 +3317,7 @@ public IRow GetSummaryIRowOrNull(RoleMappedSchema schema) return MetadataUtils.MetadataAsRow(builder.GetMetadata()); } - public IRow GetStatsIRowOrNull(RoleMappedSchema schema) + public Row GetStatsIRowOrNull(RoleMappedSchema schema) { return null; } diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index 5f1e46b6dc..1e9cda953f 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -208,7 +208,7 @@ public BoundMapper(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper OutputSchema = Schema.Create(new SchemaImpl(ectx, owner, treeValueType, leafIdType, pathIdType)); } - public IRow GetRow(IRow input, Func predicate, out Action disposer) + public Row GetRow(Row input, Func predicate, out Action disposer) { _ectx.CheckValue(input, nameof(input)); _ectx.CheckValue(predicate, nameof(predicate)); @@ -216,7 +216,7 @@ public IRow GetRow(IRow input, Func predicate, out Action disposer) return new SimpleRow(OutputSchema, input, CreateGetters(input, predicate)); } - private Delegate[] CreateGetters(IRow input, Func predicate) + private Delegate[] CreateGetters(Row input, Func predicate) { _ectx.AssertValue(input); _ectx.AssertValue(predicate); @@ -259,7 +259,7 @@ private Delegate[] CreateGetters(IRow input, Func predicate) private sealed class State { private readonly IExceptionContext _ectx; - private readonly IRow _input; + private readonly Row _input; private readonly FastTreePredictionWrapper _ensemble; private readonly int _numTrees; private readonly int _numLeaves; @@ -276,7 +276,7 @@ private sealed class State private long _cachedLeafBuilderPosition; private long _cachedPathBuilderPosition; - public State(IExceptionContext ectx, IRow input, FastTreePredictionWrapper ensemble, int numLeaves, int featureIndex) + public State(IExceptionContext ectx, Row input, FastTreePredictionWrapper ensemble, int numLeaves, int featureIndex) { Contracts.AssertValue(ectx); _ectx = ectx; diff --git a/src/Microsoft.ML.HalLearners/VectorWhitening.cs b/src/Microsoft.ML.HalLearners/VectorWhitening.cs index 9c9cd820b4..0964154eb6 100644 --- a/src/Microsoft.ML.HalLearners/VectorWhitening.cs +++ b/src/Microsoft.ML.HalLearners/VectorWhitening.cs @@ -689,7 +689,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() return result; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); @@ -714,7 +714,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac return del; } - private ValueGetter GetSrcGetter(IRow input, int iinfo) + private ValueGetter GetSrcGetter(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); diff --git a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs index 7b89bdb173..eb2f444a92 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs @@ -173,7 +173,7 @@ public Mapper(ImageGrayscaleTransform parent, Schema inputSchema) protected override Schema.DetachedColumn[] GetOutputColumnsCore() => _parent.ColumnPairs.Select((x, idx) => new Schema.DetachedColumn(x.output, InputSchema[ColMapNewToOld[idx]].Type, null)).ToArray(); - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs index 81b3d339e5..f996560011 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs @@ -161,7 +161,7 @@ public Mapper(ImageLoaderTransform parent, Schema inputSchema) _parent = parent; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs index ceecacb0f9..8b14382bcb 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs @@ -429,7 +429,7 @@ public Mapper(ImagePixelExtractorTransform parent, Schema inputSchema) protected override Schema.DetachedColumn[] GetOutputColumnsCore() => _parent._columns.Select((x, idx) => new Schema.DetachedColumn(x.Output, _types[idx], null)).ToArray(); - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _parent._columns.Length); @@ -440,7 +440,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac } //REVIEW Rewrite it to where TValue : IConvertible - private ValueGetter> GetGetterCore(IRow input, int iinfo, out Action disposer) + private ValueGetter> GetGetterCore(Row input, int iinfo, out Action disposer) where TValue : struct { var type = _types[iinfo]; diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs index e5054b515f..8241b2ff12 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs @@ -306,7 +306,7 @@ public Mapper(ImageResizerTransform parent, Schema inputSchema) protected override Schema.DetachedColumn[] GetOutputColumnsCore() => _parent._columns.Select(x => new Schema.DetachedColumn(x.Output, x.Type, null)).ToArray(); - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _parent._columns.Length); diff --git a/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs b/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs index 54b9148e86..ffaa91aa51 100644 --- a/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs @@ -330,7 +330,7 @@ protected override ColumnType GetColumnTypeCore(int iinfo) return _types[iinfo]; } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + protected override Delegate GetGetterCore(IChannel ch, Row input, int iinfo, out Action disposer) { Host.AssertValueOrNull(ch); Host.AssertValue(input); @@ -351,7 +351,7 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, ou } - private ValueGetter GetterFromType(IRow input, int iinfo, ColInfoEx ex, bool needScale) where TValue : IConvertible + private ValueGetter GetterFromType(Row input, int iinfo, ColInfoEx ex, bool needScale) where TValue : IConvertible { var getSrc = GetSrcGetter>(input, iinfo); var src = default(VBuffer); diff --git a/src/Microsoft.ML.Legacy/Models/ConfusionMatrix.cs b/src/Microsoft.ML.Legacy/Models/ConfusionMatrix.cs index 2fee41ee3a..99df7bd672 100644 --- a/src/Microsoft.ML.Legacy/Models/ConfusionMatrix.cs +++ b/src/Microsoft.ML.Legacy/Models/ConfusionMatrix.cs @@ -52,7 +52,7 @@ internal static List Create(IHostEnvironment env, IDataView con throw env.Except($"ConfusionMatrix data view did not contain a {nameof(MetricKinds.ColumnNames.Count)} column."); } - IRowCursor cursor = confusionMatrix.GetRowCursor(col => col == countColumn); + RowCursor cursor = confusionMatrix.GetRowCursor(col => col == countColumn); var slots = default(VBuffer>); confusionMatrix.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, countColumn, ref slots); var slotsValues = slots.GetValues(); diff --git a/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs b/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs index e414261336..4390ad31ab 100644 --- a/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs +++ b/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs @@ -353,7 +353,7 @@ private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamed } } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { disposer = null; Host.AssertValue(input); @@ -367,7 +367,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac return Utils.MarshalInvoke(MakeGetter, type, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCache); } - private Delegate MakeGetter(IRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames, OutputCache outputCache) + private Delegate MakeGetter(Row input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames, OutputCache outputCache) { Host.AssertValue(input); ValueGetter> valuegetter = (ref VBuffer dst) => @@ -384,7 +384,7 @@ private Delegate MakeGetter(IRow input, int iinfo, INamedOnnxValueGetter[] sr return valuegetter; } - private static INamedOnnxValueGetter[] GetNamedOnnxValueGetters(IRow input, + private static INamedOnnxValueGetter[] GetNamedOnnxValueGetters(Row input, string[] inputColNames, int[] inputColIndices, bool[] isInputVector, @@ -400,14 +400,14 @@ private static INamedOnnxValueGetter[] GetNamedOnnxValueGetters(IRow input, return srcNamedOnnxValueGetters; } - private static INamedOnnxValueGetter CreateNamedOnnxValueGetter(IRow input, System.Type onnxType, bool isVector, string colName, int colIndex, OnnxShape onnxShape) + private static INamedOnnxValueGetter CreateNamedOnnxValueGetter(Row input, System.Type onnxType, bool isVector, string colName, int colIndex, OnnxShape onnxShape) { var type = OnnxUtils.OnnxToMlNetType(onnxType).RawType; Contracts.AssertValue(type); return Utils.MarshalInvoke(CreateNameOnnxValueGetter, type, input, isVector, colName, colIndex, onnxShape); } - private static INamedOnnxValueGetter CreateNameOnnxValueGetter(IRow input, bool isVector, string colName, int colIndex, OnnxShape onnxShape) + private static INamedOnnxValueGetter CreateNameOnnxValueGetter(Row input, bool isVector, string colName, int colIndex, OnnxShape onnxShape) { if (isVector) return new NamedOnnxValueGetterVec(input, colName, colIndex, onnxShape); @@ -419,7 +419,7 @@ private class NameOnnxValueGetter : INamedOnnxValueGetter private readonly ValueGetter _srcgetter; private readonly string _colName; - public NameOnnxValueGetter(IRow input, string colName, int colIndex) + public NameOnnxValueGetter(Row input, string colName, int colIndex) { _colName = colName; _srcgetter = input.GetGetter(colIndex); @@ -439,7 +439,7 @@ private class NamedOnnxValueGetterVec : INamedOnnxValueGetter private readonly string _colName; private VBuffer _vBuffer; private VBuffer _vBufferDense; - public NamedOnnxValueGetterVec(IRow input, string colName, int colIndex, OnnxShape tensorShape) + public NamedOnnxValueGetterVec(Row input, string colName, int colIndex, OnnxShape tensorShape) { _srcgetter = input.GetGetter>(colIndex); _tensorShape = tensorShape; diff --git a/src/Microsoft.ML.PCA/PcaTransform.cs b/src/Microsoft.ML.PCA/PcaTransform.cs index 1a6a5c85f2..385b5d57a8 100644 --- a/src/Microsoft.ML.PCA/PcaTransform.cs +++ b/src/Microsoft.ML.PCA/PcaTransform.cs @@ -609,7 +609,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() return result; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _numColumns); diff --git a/src/Microsoft.ML.Parquet/ParquetLoader.cs b/src/Microsoft.ML.Parquet/ParquetLoader.cs index 5627f9c0a1..e5fb2609d5 100644 --- a/src/Microsoft.ML.Parquet/ParquetLoader.cs +++ b/src/Microsoft.ML.Parquet/ParquetLoader.cs @@ -391,19 +391,19 @@ private static Stream OpenStream(string filename) return _rowCount; } - public IRowCursor GetRowCursor(Func predicate, Random rand = null) + public RowCursor GetRowCursor(Func predicate, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); return new Cursor(this, predicate, rand); } - public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) + public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); consolidator = null; - return new IRowCursor[] { GetRowCursor(predicate, rand) }; + return new RowCursor[] { GetRowCursor(predicate, rand) }; } public void Save(ModelSaveContext ctx) @@ -433,7 +433,7 @@ public void Save(ModelSaveContext ctx) } } - private sealed class Cursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { private readonly ParquetLoader _loader; private readonly Stream _fileStream; @@ -586,11 +586,11 @@ protected override bool MoveNextCore() return false; } - public ML.Data.Schema Schema => _loader.Schema; + public override ML.Data.Schema Schema => _loader.Schema; public override long Batch => 0; - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.CheckParam(IsColumnActive(col), nameof(col), "requested column not active"); @@ -612,7 +612,7 @@ public override ValueGetter GetIdGetter() }; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.CheckParam(0 <= col && col < _colToActivesIndex.Length, nameof(col)); return _colToActivesIndex[col] >= 0; diff --git a/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs b/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs index 465989a265..1993752f02 100644 --- a/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs +++ b/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs @@ -339,7 +339,7 @@ private void CheckInputSchema(ISchema schema, int matrixColumnIndexCol, int matr _env.CheckParam(type.Equals(_parent.MatrixRowIndexType), nameof(schema), msg); } - private Delegate[] CreateGetter(IRow input, bool[] active) + private Delegate[] CreateGetter(Row input, bool[] active) { _env.CheckValue(input, nameof(input)); _env.Assert(Utils.Size(active) == OutputSchema.ColumnCount); @@ -358,7 +358,7 @@ private Delegate[] CreateGetter(IRow input, bool[] active) return getters; } - public IRow GetRow(IRow input, Func predicate, out Action disposer) + public Row GetRow(Row input, Func predicate, out Action disposer) { var active = Utils.BuildArray(OutputSchema.ColumnCount, predicate); var getters = CreateGetter(input, active); diff --git a/src/Microsoft.ML.Recommender/SafeTrainingAndModelBuffer.cs b/src/Microsoft.ML.Recommender/SafeTrainingAndModelBuffer.cs index 33bb90ae0b..b1305a5a90 100644 --- a/src/Microsoft.ML.Recommender/SafeTrainingAndModelBuffer.cs +++ b/src/Microsoft.ML.Recommender/SafeTrainingAndModelBuffer.cs @@ -249,7 +249,7 @@ private unsafe void Dispose(bool disposing) } } - private MFNode[] ConstructLabeledNodesFrom(IChannel ch, ICursor cursor, ValueGetter labGetter, + private MFNode[] ConstructLabeledNodesFrom(IChannel ch, RowCursor cursor, ValueGetter labGetter, ValueGetter rowGetter, ValueGetter colGetter, int rowCount, int colCount) { @@ -303,7 +303,7 @@ private MFNode[] ConstructLabeledNodesFrom(IChannel ch, ICursor cursor, ValueGet } public unsafe void Train(IChannel ch, int rowCount, int colCount, - ICursor cursor, ValueGetter labGetter, + RowCursor cursor, ValueGetter labGetter, ValueGetter rowGetter, ValueGetter colGetter) { if (_pMFModel != null) @@ -333,9 +333,9 @@ public unsafe void Train(IChannel ch, int rowCount, int colCount, } public unsafe void TrainWithValidation(IChannel ch, int rowCount, int colCount, - ICursor cursor, ValueGetter labGetter, + RowCursor cursor, ValueGetter labGetter, ValueGetter rowGetter, ValueGetter colGetter, - ICursor validCursor, ValueGetter validLabGetter, + RowCursor validCursor, ValueGetter validLabGetter, ValueGetter validRowGetter, ValueGetter validColGetter) { if (_pMFModel != null) diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs index 11ce748cb2..df486453b8 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs @@ -96,7 +96,7 @@ public FieldAwareFactorizationMachineScalarRowMapper(IHostEnvironment env, RoleM } } - public IRow GetRow(IRow input, Func predicate, out Action action) + public Row GetRow(Row input, Func predicate, out Action action) { var latentSum = new AlignedArray(_pred.FieldCount * _pred.FieldCount * _pred.LatentDimAligned, 16); var featureBuffer = new VBuffer(); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs index 72692c4016..ad4fca42df 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs @@ -346,7 +346,7 @@ public void SaveAsCode(TextWriter writer, RoleMappedSchema schema) public abstract void SaveSummary(TextWriter writer, RoleMappedSchema schema); - public virtual IRow GetSummaryIRowOrNull(RoleMappedSchema schema) + public virtual Row GetSummaryIRowOrNull(RoleMappedSchema schema) { var names = default(VBuffer>); MetadataUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, Weight.Length, ref names); @@ -359,7 +359,7 @@ public virtual IRow GetSummaryIRowOrNull(RoleMappedSchema schema) return MetadataUtils.MetadataAsRow(builder.GetMetadata()); } - public virtual IRow GetStatsIRowOrNull(RoleMappedSchema schema) => null; + public virtual Row GetStatsIRowOrNull(RoleMappedSchema schema) => null; public abstract void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator = null); @@ -502,7 +502,7 @@ public IList> GetSummaryInKeyValuePairs(RoleMappedS return results; } - public override IRow GetStatsIRowOrNull(RoleMappedSchema schema) + public override Row GetStatsIRowOrNull(RoleMappedSchema schema) { if (_stats == null) return null; diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index 6d8a411e7c..e2b280ca1e 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -982,12 +982,12 @@ public IDataView GetSummaryDataView(RoleMappedSchema schema) return bldr.GetDataView(); } - public IRow GetSummaryIRowOrNull(RoleMappedSchema schema) + public Row GetSummaryIRowOrNull(RoleMappedSchema schema) { return null; } - public IRow GetStatsIRowOrNull(RoleMappedSchema schema) + public Row GetStatsIRowOrNull(RoleMappedSchema schema) { if (_stats == null) return null; diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index b4ced664c1..e29435cc63 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -540,21 +540,21 @@ private void UpdateModelOnDisk(string modelDir, Arguments args) } } - private static ITensorValueGetter CreateTensorValueGetter(IRow input, bool isVector, int colIndex, TFShape tfShape) + private static ITensorValueGetter CreateTensorValueGetter(Row input, bool isVector, int colIndex, TFShape tfShape) { if (isVector) return new TensorValueGetterVec(input, colIndex, tfShape); return new TensorValueGetter(input, colIndex, tfShape); } - private static ITensorValueGetter CreateTensorValueGetter(IRow input, TFDataType tfType, bool isVector, int colIndex, TFShape tfShape) + private static ITensorValueGetter CreateTensorValueGetter(Row input, TFDataType tfType, bool isVector, int colIndex, TFShape tfShape) { var type = TFTensor.TypeFromTensorType(tfType); Contracts.AssertValue(type); return Utils.MarshalInvoke(CreateTensorValueGetter, type, input, isVector, colIndex, tfShape); } - private static ITensorValueGetter[] GetTensorValueGetters(IRow input, + private static ITensorValueGetter[] GetTensorValueGetters(Row input, int[] inputColIndices, bool[] isInputVector, TFDataType[] tfInputTypes, @@ -861,7 +861,7 @@ public OutputCache() } } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { disposer = null; Host.AssertValue(input); @@ -875,7 +875,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac return Utils.MarshalInvoke(MakeGetter, type, input, iinfo, srcTensorGetters, activeOutputColNames, outputCache); } - private Delegate MakeGetter(IRow input, int iinfo, ITensorValueGetter[] srcTensorGetters, string[] activeOutputColNames, OutputCache outputCache) + private Delegate MakeGetter(Row input, int iinfo, ITensorValueGetter[] srcTensorGetters, string[] activeOutputColNames, OutputCache outputCache) { Host.AssertValue(input); ValueGetter> valuegetter = (ref VBuffer dst) => @@ -964,7 +964,7 @@ private class TensorValueGetter : ITensorValueGetter private readonly TFShape _tfShape; private int _position; - public TensorValueGetter(IRow input, int colIndex, TFShape tfShape) + public TensorValueGetter(Row input, int colIndex, TFShape tfShape) { _srcgetter = input.GetGetter(colIndex); _tfShape = tfShape; @@ -1010,7 +1010,7 @@ private class TensorValueGetterVec : ITensorValueGetter private readonly T[] _bufferedData; private int _position; - public TensorValueGetterVec(IRow input, int colIndex, TFShape tfShape) + public TensorValueGetterVec(Row input, int colIndex, TFShape tfShape) { _srcgetter = input.GetGetter>(colIndex); _tfShape = tfShape; diff --git a/src/Microsoft.ML.TimeSeries/PredictionFunction.cs b/src/Microsoft.ML.TimeSeries/PredictionFunction.cs index 792b06d6b3..8dc4cc27b3 100644 --- a/src/Microsoft.ML.TimeSeries/PredictionFunction.cs +++ b/src/Microsoft.ML.TimeSeries/PredictionFunction.cs @@ -25,16 +25,16 @@ internal interface IStatefulTransformer : ITransformer IStatefulTransformer Clone(); } - internal interface IStatefulRow : IRow + internal abstract class StatefulRow : Row { - Action GetPinger(); + public abstract Action GetPinger(); } internal interface IStatefulRowMapper : IRowMapper { void CloneState(); - Action CreatePinger(IRow input, Func activeOutput, out Action disposer); + Action CreatePinger(Row input, Func activeOutput, out Action disposer); } /// @@ -98,8 +98,8 @@ public TimeSeriesPredictionFunction(IHostEnvironment env, ITransformer transform { } - internal IRow GetStatefulRows(IRow input, IRowToRowMapper mapper, Func active, - List rows, out Action disposer) + internal Row GetStatefulRows(Row input, IRowToRowMapper mapper, Func active, + List rows, out Action disposer) { Contracts.CheckValue(input, nameof(input)); Contracts.CheckValue(active, nameof(active)); @@ -123,8 +123,8 @@ internal IRow GetStatefulRows(IRow input, IRowToRowMapper mapper, Func= 1; --i) deps[i - 1] = innerMappers[i].GetDependencies(deps[i]); - IRow result = input; + Row result = input; for (int i = 0; i < innerMappers.Length; ++i) { Action localDisp; result = GetStatefulRows(result, innerMappers[i], deps[i], rows, out localDisp); - if (result is IStatefulRow) - rows.Add((IStatefulRow)result); + if (result is StatefulRow statefulResult) + rows.Add(statefulResult); if (localDisp != null) { @@ -158,25 +158,21 @@ internal IRow GetStatefulRows(IRow input, IRowToRowMapper mapper, Func CreatePinger(List rows) + private Action CreatePinger(List rows) { - Action[] pingers = new Action[rows.Count]; - int index = 0; + if (rows.Count == 0) + return position => { }; + Action pinger = null; foreach (var row in rows) - pingers[index++] = row.GetPinger(); - - return (long position) => - { - foreach (var ping in pingers) - ping(position); - }; + pinger += row.GetPinger(); + return pinger; } internal override void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow inputRow, IRowToRowMapper mapper, bool ignoreMissingColumns, SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition, out Action disposer, out IRowReadableAs outputRow) { - List rows = new List(); - IRow outputRowLocal = outputRowLocal = GetStatefulRows(inputRow, mapper, col => true, rows, out disposer); + List rows = new List(); + Row outputRowLocal = outputRowLocal = GetStatefulRows(inputRow, mapper, col => true, rows, out disposer); var cursorable = TypedCursorable.Create(env, new EmptyDataView(env, mapper.OutputSchema), ignoreMissingColumns, outputSchemaDefinition); _pinger = CreatePinger(rows); outputRow = cursorable.GetRow(outputRowLocal); diff --git a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs index 1c867a5f19..8b41ee8edf 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs @@ -625,7 +625,7 @@ public Func GetDependencies(Func activeOutput) public void Save(ModelSaveContext ctx) => _parent.Save(ctx); - public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) + public Delegate[] CreateGetters(Row input, Func activeOutput, out Action disposer) { disposer = null; var getters = new Delegate[1]; @@ -637,7 +637,7 @@ public Delegate[] CreateGetters(IRow input, Func activeOutput, out Ac private delegate void ProcessData(ref TInput src, ref VBuffer dst); - private Delegate MakeGetter(IRow input, TState state) + private Delegate MakeGetter(Row input, TState state) { _host.AssertValue(input); var srcGetter = input.GetGetter(_inputColumnIndex); @@ -653,7 +653,7 @@ private Delegate MakeGetter(IRow input, TState state) return valueGetter; } - public Action CreatePinger(IRow input, Func activeOutput, out Action disposer) + public Action CreatePinger(Row input, Func activeOutput, out Action disposer) { disposer = null; Action pinger = null; @@ -663,7 +663,7 @@ public Action CreatePinger(IRow input, Func activeOutput, out A return pinger; } - private Action MakePinger(IRow input, TState state) + private Action MakePinger(Row input, TState state) { _host.AssertValue(input); var srcGetter = input.GetGetter(_inputColumnIndex); diff --git a/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs b/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs index c6f38380f4..97bb3eb53a 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs @@ -352,7 +352,7 @@ private void InitFunction(TState state) return false; } - protected override IRowCursor GetRowCursorCore(Func predicate, Random rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { var srcCursor = _transform.GetRowCursor(predicate, rand); return new Cursor(this, srcCursor); @@ -365,35 +365,35 @@ protected override IRowCursor GetRowCursorCore(Func predicate, Random return _transform.GetRowCount(); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) + public override RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { consolidator = null; - return new IRowCursor[] { GetRowCursorCore(predicate, rand) }; + return new RowCursor[] { GetRowCursorCore(predicate, rand) }; } /// /// A wrapper around the cursor which replaces the schema. /// - private sealed class Cursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : SynchronizedCursorBase { private readonly SequentialTransformBase _parent; - public Cursor(SequentialTransformBase parent, IRowCursor input) + public Cursor(SequentialTransformBase parent, RowCursor input) : base(parent.Host, input) { Ch.Assert(input.Schema.ColumnCount == parent.OutputSchema.ColumnCount); _parent = parent; } - public Schema Schema { get { return _parent.OutputSchema; } } + public override Schema Schema { get { return _parent.OutputSchema; } } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < Schema.ColumnCount, "col"); return Input.IsColumnActive(col); } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col), "col"); return Input.GetGetter(col); diff --git a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs index 8b6d74426c..7af0eb1cf1 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs @@ -429,7 +429,7 @@ private void InitFunction(TState state) public override bool CanShuffle { get { return false; } } - protected override IRowCursor GetRowCursorCore(Func predicate, Random rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { var srcCursor = _transform.GetRowCursor(predicate, rand); var clone = (SequentialDataTransform)MemberwiseClone(); @@ -448,10 +448,10 @@ protected override IRowCursor GetRowCursorCore(Func predicate, Random return _transform.GetRowCount(); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) + public override RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { consolidator = null; - return new IRowCursor[] { GetRowCursorCore(predicate, rand) }; + return new RowCursor[] { GetRowCursorCore(predicate, rand) }; } public override void Save(ModelSaveContext ctx) @@ -478,26 +478,26 @@ public Func GetDependencies(Func predicate) return col => false; } - public IRow GetRow(IRow input, Func active, out Action disposer) => - new Row(_bindings.Schema, input, _mapper.CreateGetters(input, active, out disposer), + public Row GetRow(Row input, Func active, out Action disposer) => + new RowImpl(_bindings.Schema, input, _mapper.CreateGetters(input, active, out disposer), _mapper.CreatePinger(input, active, out disposer)); } - private sealed class Row : IStatefulRow + private sealed class RowImpl : StatefulRow { private readonly Schema _schema; - private readonly IRow _input; + private readonly Row _input; private readonly Delegate[] _getters; private readonly Action _pinger; - public Schema Schema { get { return _schema; } } + public override Schema Schema => _schema; - public long Position { get { return _input.Position; } } + public override long Position => _input.Position; - public long Batch { get { return _input.Batch; } } + public override long Batch => _input.Batch; - public Row(Schema schema, IRow input, Delegate[] getters, Action pinger) + public RowImpl(Schema schema, Row input, Delegate[] getters, Action pinger) { Contracts.CheckValue(schema, nameof(schema)); Contracts.CheckValue(input, nameof(input)); @@ -508,12 +508,12 @@ public Row(Schema schema, IRow input, Delegate[] getters, Action pinger) _pinger = pinger; } - public ValueGetter GetIdGetter() + public override ValueGetter GetIdGetter() { return _input.GetIdGetter(); } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Contracts.CheckParam(0 <= col && col < _getters.Length, nameof(col), "Invalid col value in GetGetter"); Contracts.Check(IsColumnActive(col)); @@ -523,10 +523,10 @@ public ValueGetter GetGetter(int col) return fn; } - public Action GetPinger() => + public override Action GetPinger() => _pinger as Action ?? throw Contracts.Except("Invalid TValue in GetPinger: '{0}'", typeof(long)); - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Contracts.Check(0 <= col && col < _getters.Length); return _getters[col] != null; @@ -536,26 +536,26 @@ public bool IsColumnActive(int col) /// /// A wrapper around the cursor which replaces the schema. /// - private sealed class Cursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : SynchronizedCursorBase { private readonly SequentialDataTransform _parent; - public Cursor(IHost host, SequentialDataTransform parent, IRowCursor input) + public Cursor(IHost host, SequentialDataTransform parent, RowCursor input) : base(host, input) { Ch.Assert(input.Schema.ColumnCount == parent.OutputSchema.ColumnCount); _parent = parent; } - public Schema Schema { get { return _parent.OutputSchema; } } + public override Schema Schema => _parent.OutputSchema; - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < Schema.ColumnCount, "col"); return Input.IsColumnActive(col); } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col), "col"); return Input.GetGetter(col); @@ -688,14 +688,14 @@ private Func GetActiveOutputColumns(bool[] active) return null; } - protected override IRowCursor GetRowCursorCore(Func predicate, Random rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Func predicateInput; var active = GetActive(predicate, out predicateInput); - return new RowCursor(Host, Source.GetRowCursor(predicateInput, rand), this, active); + return new Cursor(Host, Source.GetRowCursor(predicateInput, rand), this, active); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) + public override RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); @@ -710,9 +710,9 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid inputs = DataViewUtils.CreateSplitCursors(out consolidator, Host, inputs[0], n); Host.AssertNonEmpty(inputs); - var cursors = new IRowCursor[inputs.Length]; + var cursors = new RowCursor[inputs.Length]; for (int i = 0; i < inputs.Length; i++) - cursors[i] = new RowCursor(Host, inputs[i], this, active); + cursors[i] = new Cursor(Host, inputs[i], this, active); return cursors; } @@ -745,7 +745,7 @@ public Func GetDependencies(Func predicate) Schema IRowToRowMapper.InputSchema => Source.Schema; - public IRow GetRow(IRow input, Func active, out Action disposer) + public Row GetRow(Row input, Func active, out Action disposer) { Host.CheckValue(input, nameof(input)); Host.CheckValue(active, nameof(active)); @@ -766,21 +766,21 @@ public IRow GetRow(IRow input, Func active, out Action disposer) } } - private sealed class StatefulRow : IStatefulRow + private sealed class StatefulRow : TimeSeries.StatefulRow { - private readonly IRow _input; + private readonly Row _input; private readonly Delegate[] _getters; private readonly Action _pinger; private readonly TimeSeriesRowToRowMapperTransform _parent; - public long Batch { get { return _input.Batch; } } + public override long Batch => _input.Batch; - public long Position { get { return _input.Position; } } + public override long Position => _input.Position; - public Schema Schema { get; } + public override Schema Schema { get; } - public StatefulRow(IRow input, TimeSeriesRowToRowMapperTransform parent, + public StatefulRow(Row input, TimeSeriesRowToRowMapperTransform parent, Schema schema, Delegate[] getters, Action pinger) { _input = input; @@ -790,7 +790,7 @@ public StatefulRow(IRow input, TimeSeriesRowToRowMapperTransform parent, _pinger = pinger; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { bool isSrc; int index = _parent._bindings.MapColumnIndex(out isSrc, col); @@ -804,12 +804,12 @@ public ValueGetter GetGetter(int col) return fn; } - public Action GetPinger() => + public override Action GetPinger() => _pinger as Action ?? throw Contracts.Except("Invalid TValue in GetPinger: '{0}'", typeof(long)); - public ValueGetter GetIdGetter() => _input.GetIdGetter(); + public override ValueGetter GetIdGetter() => _input.GetIdGetter(); - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { bool isSrc; int index = _parent._bindings.MapColumnIndex(out isSrc, col); @@ -819,16 +819,16 @@ public bool IsColumnActive(int col) } } - private sealed class RowCursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : SynchronizedCursorBase { private readonly Delegate[] _getters; private readonly bool[] _active; private readonly ColumnBindings _bindings; private readonly Action _disposer; - public Schema Schema => _bindings.Schema; + public override Schema Schema => _bindings.Schema; - public RowCursor(IChannelProvider provider, IRowCursor input, TimeSeriesRowToRowMapperTransform parent, bool[] active) + public Cursor(IChannelProvider provider, RowCursor input, TimeSeriesRowToRowMapperTransform parent, bool[] active) : base(provider, input) { var pred = parent.GetActiveOutputColumns(active); @@ -837,13 +837,13 @@ public RowCursor(IChannelProvider provider, IRowCursor input, TimeSeriesRowToRow _bindings = parent._bindings; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < _bindings.Schema.ColumnCount); return _active[col]; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col)); diff --git a/src/Microsoft.ML.Transforms/BootstrapSamplingTransformer.cs b/src/Microsoft.ML.Transforms/BootstrapSamplingTransformer.cs index 7476920482..3418dcf64c 100644 --- a/src/Microsoft.ML.Transforms/BootstrapSamplingTransformer.cs +++ b/src/Microsoft.ML.Transforms/BootstrapSamplingTransformer.cs @@ -164,35 +164,35 @@ public static BootstrapSamplingTransformer Create(IHostEnvironment env, ModelLoa return false; } - protected override IRowCursor GetRowCursorCore(Func predicate, Random rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { // We do not use the input random because this cursor does not support shuffling. var rgen = new TauswortheHybrid(_state); var input = Source.GetRowCursor(predicate, _shuffleInput ? new TauswortheHybrid(rgen) : null); - IRowCursor cursor = new RowCursor(this, input, rgen); + RowCursor cursor = new Cursor(this, input, rgen); if (_poolSize > 1) cursor = RowShufflingTransformer.GetShuffledCursor(Host, _poolSize, cursor, new TauswortheHybrid(rgen)); return cursor; } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) + public override RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { var cursor = GetRowCursorCore(predicate, rand); consolidator = null; - return new IRowCursor[] { cursor }; + return new RowCursor[] { cursor }; } - private sealed class RowCursor : LinkedRootCursorBase, IRowCursor + private sealed class Cursor : LinkedRootCursorBase { private int _remaining; private readonly BootstrapSamplingTransformer _parent; private readonly Random _rgen; - public override long Batch { get { return 0; } } + public override long Batch => 0; - public Schema Schema { get { return Input.Schema; } } + public override Schema Schema => Input.Schema; - public RowCursor(BootstrapSamplingTransformer parent, IRowCursor input, Random rgen) + public Cursor(BootstrapSamplingTransformer parent, RowCursor input, Random rgen) : base(parent.Host, input) { Ch.AssertValue(rgen); @@ -211,12 +211,12 @@ public override ValueGetter GetIdGetter() }; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { return Input.GetGetter(col); } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { return Input.IsColumnActive(col); } diff --git a/src/Microsoft.ML.Transforms/CountFeatureSelection.cs b/src/Microsoft.ML.Transforms/CountFeatureSelection.cs index d597422854..d330a06f37 100644 --- a/src/Microsoft.ML.Transforms/CountFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/CountFeatureSelection.cs @@ -285,26 +285,26 @@ public static long[][] Train(IHostEnvironment env, IDataView input, string[] col public static bool IsValidColumnType(ColumnType type) => type == NumberType.R4 || type == NumberType.R8 || type.IsText; - private static CountAggregator GetOneAggregator(IRow row, ColumnType colType, int colSrc) + private static CountAggregator GetOneAggregator(Row row, ColumnType colType, int colSrc) { - Func del = GetOneAggregator; + Func del = GetOneAggregator; var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(colType.RawType); return (CountAggregator)methodInfo.Invoke(null, new object[] { row, colType, colSrc }); } - private static CountAggregator GetOneAggregator(IRow row, ColumnType colType, int colSrc) + private static CountAggregator GetOneAggregator(Row row, ColumnType colType, int colSrc) { return new CountAggregator(colType, row.GetGetter(colSrc)); } - private static CountAggregator GetVecAggregator(IRow row, ColumnType colType, int colSrc) + private static CountAggregator GetVecAggregator(Row row, ColumnType colType, int colSrc) { - Func del = GetVecAggregator; + Func del = GetVecAggregator; var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(colType.ItemType.RawType); return (CountAggregator)methodInfo.Invoke(null, new object[] { row, colType, colSrc }); } - private static CountAggregator GetVecAggregator(IRow row, ColumnType colType, int colSrc) + private static CountAggregator GetVecAggregator(Row row, ColumnType colType, int colSrc) { return new CountAggregator(colType, row.GetGetter>(colSrc)); } diff --git a/src/Microsoft.ML.Transforms/GcnTransform.cs b/src/Microsoft.ML.Transforms/GcnTransform.cs index e35a41d5e3..e8d13b4944 100644 --- a/src/Microsoft.ML.Transforms/GcnTransform.cs +++ b/src/Microsoft.ML.Transforms/GcnTransform.cs @@ -460,7 +460,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() return result; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); diff --git a/src/Microsoft.ML.Transforms/GroupTransform.cs b/src/Microsoft.ML.Transforms/GroupTransform.cs index 1de6dea27b..7020eb6ede 100644 --- a/src/Microsoft.ML.Transforms/GroupTransform.cs +++ b/src/Microsoft.ML.Transforms/GroupTransform.cs @@ -156,7 +156,7 @@ public override void Save(ModelSaveContext ctx) public override Schema OutputSchema => _groupSchema.AsSchema; - protected override IRowCursor GetRowCursorCore(Func predicate, Random rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); @@ -173,12 +173,12 @@ protected override IRowCursor GetRowCursorCore(Func predicate, Random public override bool CanShuffle { get { return false; } } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) + public override RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); consolidator = null; - return new IRowCursor[] { GetRowCursorCore(predicate) }; + return new RowCursor[] { GetRowCursorCore(predicate) }; } /// @@ -430,7 +430,7 @@ public void GetMetadata(string kind, int col, ref TValue value) /// - The group column getters are taken directly from the trailing cursor. /// - The keep column getters are provided by the aggregators. /// - private sealed class Cursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase { /// /// This class keeps track of the previous group key and tests the current group key against the previous one. @@ -439,7 +439,7 @@ private sealed class GroupKeyColumnChecker { public readonly Func IsSameKey; - private static Func MakeSameChecker(IRow row, int col) + private static Func MakeSameChecker(Row row, int col) { T oldValue = default(T); T newValue = default(T); @@ -465,12 +465,12 @@ private static Func MakeSameChecker(IRow row, int col) }; } - public GroupKeyColumnChecker(IRow row, int col) + public GroupKeyColumnChecker(Row row, int col) { Contracts.AssertValue(row); var type = row.Schema.GetColumnType(col); - Func> del = MakeSameChecker; + Func> del = MakeSameChecker; var mi = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(type.RawType); IsSameKey = (Func)mi.Invoke(null, new object[] { row, col }); } @@ -480,7 +480,7 @@ public GroupKeyColumnChecker(IRow row, int col) // REVIEW: Currently, it always produces dense buffers. The anticipated use cases don't include many // default values at the moment. /// - /// This class handles the aggregation of one 'keep' column into a vector. It wraps around an 's + /// This class handles the aggregation of one 'keep' column into a vector. It wraps around an 's /// column, reads the data and aggregates. /// private abstract class KeepColumnAggregator @@ -489,7 +489,7 @@ private abstract class KeepColumnAggregator public abstract void SetSize(int size); public abstract void ReadValue(int position); - public static KeepColumnAggregator Create(IRow row, int col) + public static KeepColumnAggregator Create(Row row, int col) { Contracts.AssertValue(row); var colType = row.Schema.GetColumnType(col); @@ -497,7 +497,7 @@ public static KeepColumnAggregator Create(IRow row, int col) var type = typeof(ListAggregator<>); - var cons = type.MakeGenericType(colType.RawType).GetConstructor(new[] { typeof(IRow), typeof(int) }); + var cons = type.MakeGenericType(colType.RawType).GetConstructor(new[] { typeof(Row), typeof(int) }); return cons.Invoke(new object[] { row, col }) as KeepColumnAggregator; } @@ -508,7 +508,7 @@ private sealed class ListAggregator : KeepColumnAggregator private TValue[] _buffer; private int _size; - public ListAggregator(IRow row, int col) + public ListAggregator(Row row, int col) { Contracts.AssertValue(row); _srcGetter = row.GetGetter(col); @@ -546,15 +546,15 @@ public override void ReadValue(int position) private readonly bool[] _active; private readonly int _groupCount; - private readonly IRowCursor _leadingCursor; - private readonly IRowCursor _trailingCursor; + private readonly RowCursor _leadingCursor; + private readonly RowCursor _trailingCursor; private readonly GroupKeyColumnChecker[] _groupCheckers; private readonly KeepColumnAggregator[] _aggregators; public override long Batch { get { return 0; } } - public Schema Schema => _parent.OutputSchema; + public override Schema Schema => _parent.OutputSchema; public Cursor(GroupTransform parent, Func predicate) : base(parent.Host) @@ -601,7 +601,7 @@ public override ValueGetter GetIdGetter() return _trailingCursor.GetIdGetter(); } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { _parent._groupSchema.CheckColumnInRange(col); return _active[col]; @@ -669,7 +669,7 @@ public override void Dispose() base.Dispose(); } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { _parent._groupSchema.CheckColumnInRange(col); if (!_active[col]) diff --git a/src/Microsoft.ML.Transforms/HashJoiningTransform.cs b/src/Microsoft.ML.Transforms/HashJoiningTransform.cs index bf674b8ed5..e8ce9c1c15 100644 --- a/src/Microsoft.ML.Transforms/HashJoiningTransform.cs +++ b/src/Microsoft.ML.Transforms/HashJoiningTransform.cs @@ -456,7 +456,7 @@ private void GetSlotNames(int iinfo, ref VBuffer> dst) private static MethodInfo _methGetterVecToVec; private static MethodInfo _methGetterVecToOne; - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + protected override Delegate GetGetterCore(IChannel ch, Row input, int iinfo, out Action disposer) { Host.AssertValueOrNull(ch); Host.AssertValue(input); @@ -466,17 +466,17 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, ou // Construct MethodInfos templates that we need for the generic methods. if (_methGetterOneToOne == null) { - Func> del = ComposeGetterOneToOne; + Func> del = ComposeGetterOneToOne; Interlocked.CompareExchange(ref _methGetterOneToOne, del.GetMethodInfo().GetGenericMethodDefinition(), null); } if (_methGetterVecToVec == null) { - Func>> del = ComposeGetterVecToVec; + Func>> del = ComposeGetterVecToVec; Interlocked.CompareExchange(ref _methGetterVecToVec, del.GetMethodInfo().GetGenericMethodDefinition(), null); } if (_methGetterVecToOne == null) { - Func> del = ComposeGetterVecToOne; + Func> del = ComposeGetterVecToOne; Interlocked.CompareExchange(ref _methGetterVecToOne, del.GetMethodInfo().GetGenericMethodDefinition(), null); } @@ -502,7 +502,7 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, ou /// Input type. Must be a non-vector /// Row inout /// Index of the getter - private ValueGetter ComposeGetterOneToOne(IRow input, int iinfo) + private ValueGetter ComposeGetterOneToOne(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(!Infos[iinfo].TypeSrc.IsVector); @@ -526,7 +526,7 @@ private ValueGetter ComposeGetterOneToOne(IRow input, int iinfo) /// Input type. Must be a vector /// Row input /// Index of the getter - private ValueGetter> ComposeGetterVecToVec(IRow input, int iinfo) + private ValueGetter> ComposeGetterVecToVec(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(Infos[iinfo].TypeSrc.IsVector); @@ -590,7 +590,7 @@ private ValueGetter> ComposeGetterVecToVec(IRow input, int i /// Input type. Must be a vector /// Row input /// Index of the getter - private ValueGetter ComposeGetterVecToOne(IRow input, int iinfo) + private ValueGetter ComposeGetterVecToOne(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(Infos[iinfo].TypeSrc.IsVector); diff --git a/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs b/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs index 818806083c..f883f29c45 100644 --- a/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs +++ b/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs @@ -352,7 +352,7 @@ private void GetSlotNames(int iinfo, ref VBuffer> dst) dst = editor.Commit(); } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _infos.Length); @@ -367,7 +367,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac /// /// This is for the scalar case. /// - private ValueGetter> MakeGetterOne(IRow input, int iinfo) + private ValueGetter> MakeGetterOne(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(_infos[iinfo].TypeSrc.IsKey); @@ -397,7 +397,7 @@ private ValueGetter> MakeGetterOne(IRow input, int iinfo) /// /// This is for the indicator case - vector input and outputs should be concatenated. /// - private ValueGetter> MakeGetterInd(IRow input, int iinfo) + private ValueGetter> MakeGetterInd(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(_infos[iinfo].TypeSrc.IsVector); diff --git a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs index 456ba0dc58..90191211dc 100644 --- a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs @@ -191,18 +191,18 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() } return result; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); disposer = null; - Func>> del = MakeVecGetter; + Func>> del = MakeVecGetter; var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(_srcTypes[iinfo].ItemType.RawType); return (Delegate)methodInfo.Invoke(this, new object[] { input, iinfo }); } - private ValueGetter> MakeVecGetter(IRow input, int iinfo) + private ValueGetter> MakeVecGetter(Row input, int iinfo) { var srcGetter = input.GetGetter>(_srcCols[iinfo]); var buffer = default(VBuffer); diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs index 72a52e5cdc..a58603c16e 100644 --- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs +++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs @@ -237,7 +237,7 @@ private void GetSlotNames(int iinfo, ref VBuffer> dst) dst = editor.Commit(); } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + protected override Delegate GetGetterCore(IChannel ch, Row input, int iinfo, out Action disposer) { Host.AssertValueOrNull(ch); Host.AssertValue(input); diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs index 049911f48e..c54f6693bd 100644 --- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs @@ -224,7 +224,7 @@ private static Delegate GetIsNADelegate(ColumnType type) return Runtime.Data.Conversion.Conversions.Instance.GetIsNAPredicate(type.ItemType); } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _infos.Length); @@ -238,10 +238,10 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac /// /// Getter generator for single valued inputs. /// - private ValueGetter ComposeGetterOne(IRow input, int iinfo) + private ValueGetter ComposeGetterOne(Row input, int iinfo) => Utils.MarshalInvoke(ComposeGetterOne, _infos[iinfo].InputType.RawType, input, iinfo); - private ValueGetter ComposeGetterOne(IRow input, int iinfo) + private ValueGetter ComposeGetterOne(Row input, int iinfo) { var getSrc = input.GetGetter(ColMapNewToOld[iinfo]); var src = default(T); @@ -260,10 +260,10 @@ private ValueGetter ComposeGetterOne(IRow input, int iinfo) /// /// Getter generator for vector valued inputs. /// - private ValueGetter> ComposeGetterVec(IRow input, int iinfo) + private ValueGetter> ComposeGetterVec(Row input, int iinfo) => Utils.MarshalInvoke(ComposeGetterVec, _infos[iinfo].InputType.ItemType.RawType, input, iinfo); - private ValueGetter> ComposeGetterVec(IRow input, int iinfo) + private ValueGetter> ComposeGetterVec(Row input, int iinfo) { var getSrc = input.GetGetter>(ColMapNewToOld[iinfo]); var isNA = (InPredicate)_infos[iinfo].InputIsNA; diff --git a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs index 41e7b5f857..b1b48f7864 100644 --- a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs +++ b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs @@ -645,7 +645,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() return result; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _infos.Length); @@ -659,13 +659,13 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac /// /// Getter generator for single valued inputs. /// - private Delegate ComposeGetterOne(IRow input, int iinfo) + private Delegate ComposeGetterOne(Row input, int iinfo) => Utils.MarshalInvoke(ComposeGetterOne, _infos[iinfo].TypeSrc.RawType, input, iinfo); /// /// Replaces NA values for scalars. /// - private Delegate ComposeGetterOne(IRow input, int iinfo) + private Delegate ComposeGetterOne(Row input, int iinfo) { var getSrc = input.GetGetter(ColMapNewToOld[iinfo]); var src = default(T); @@ -685,13 +685,13 @@ private Delegate ComposeGetterOne(IRow input, int iinfo) /// /// Getter generator for vector valued inputs. /// - private Delegate ComposeGetterVec(IRow input, int iinfo) + private Delegate ComposeGetterVec(Row input, int iinfo) => Utils.MarshalInvoke(ComposeGetterVec, _infos[iinfo].TypeSrc.ItemType.RawType, input, iinfo); /// /// Replaces NA values for vectors. /// - private Delegate ComposeGetterVec(IRow input, int iinfo) + private Delegate ComposeGetterVec(Row input, int iinfo) { var getSrc = input.GetGetter>(ColMapNewToOld[iinfo]); var isNA = (InPredicate)_isNAs[iinfo]; diff --git a/src/Microsoft.ML.Transforms/MissingValueReplacingUtils.cs b/src/Microsoft.ML.Transforms/MissingValueReplacingUtils.cs index 8466d1b5ef..f9cdde6ad6 100644 --- a/src/Microsoft.ML.Transforms/MissingValueReplacingUtils.cs +++ b/src/Microsoft.ML.Transforms/MissingValueReplacingUtils.cs @@ -14,7 +14,7 @@ namespace Microsoft.ML.Transforms public sealed partial class MissingValueReplacingTransformer { - private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, ReplacementKind? kind, bool bySlot, IRowCursor cursor, int col) + private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, ReplacementKind? kind, bool bySlot, RowCursor cursor, int col) { ch.Assert(type.ItemType.IsNumber); if (!type.IsVector) @@ -150,7 +150,7 @@ private abstract class StatAggregator : StatAggregator /// public long RowCount { get { return _rowCount; } } - protected StatAggregator(IChannel ch, IRowCursor cursor, int col) + protected StatAggregator(IChannel ch, RowCursor cursor, int col) : base(ch) { Ch.AssertValue(cursor); @@ -178,7 +178,7 @@ private abstract class StatAggregatorAcrossSlots : StatAggregator< /// public UInt128 ValueCount { get { return _valueCount; } } - protected StatAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col) + protected StatAggregatorAcrossSlots(IChannel ch, RowCursor cursor, int col) : base(ch, cursor, col) { } @@ -199,7 +199,7 @@ protected sealed override void ProcessRow(in VBuffer src) private abstract class StatAggregatorBySlot : StatAggregator, TStatItem[]> { - protected StatAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col) + protected StatAggregatorBySlot(IChannel ch, ColumnType type, RowCursor cursor, int col) : base(ch, cursor, col) { Ch.AssertValue(type); @@ -235,7 +235,7 @@ private abstract class MinMaxAggregatorOne : StatAggregator : StatAggregato /// public long ValuesProcessed { get { return _valuesProcessed; } } - protected MinMaxAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col, bool returnMax) + protected MinMaxAggregatorAcrossSlots(IChannel ch, RowCursor cursor, int col, bool returnMax) : base(ch, cursor, col) { ReturnMax = returnMax; @@ -300,7 +300,7 @@ private abstract class MinMaxAggregatorBySlot : StatAggregator // The count of the number of times ProcessValue has been called on a specific slot (used for tracking sparsity). private readonly long[] _valuesProcessed; - protected MinMaxAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col, bool returnMax) + protected MinMaxAggregatorBySlot(IChannel ch, ColumnType type, RowCursor cursor, int col, bool returnMax) : base(ch, type, cursor, col) { Ch.AssertValue(type); @@ -540,7 +540,7 @@ private static class R4 // mean of a set of Single values. Conversion to Single happens in GetStat. public sealed class MeanAggregatorOne : StatAggregator { - public MeanAggregatorOne(IChannel ch, IRowCursor cursor, int col) + public MeanAggregatorOne(IChannel ch, RowCursor cursor, int col) : base(ch, cursor, col) { } @@ -560,7 +560,7 @@ public override object GetStat() public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots { - public MeanAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col) + public MeanAggregatorAcrossSlots(IChannel ch, RowCursor cursor, int col) : base(ch, cursor, col) { } @@ -580,7 +580,7 @@ public override object GetStat() public sealed class MeanAggregatorBySlot : StatAggregatorBySlot { - public MeanAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col) + public MeanAggregatorBySlot(IChannel ch, ColumnType type, RowCursor cursor, int col) : base(ch, type, cursor, col) { } @@ -606,7 +606,7 @@ public override object GetStat() public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne { - public MinMaxAggregatorOne(IChannel ch, IRowCursor cursor, int col, bool returnMax) + public MinMaxAggregatorOne(IChannel ch, RowCursor cursor, int col, bool returnMax) : base(ch, cursor, col, returnMax) { Stat = ReturnMax ? Single.NegativeInfinity : Single.PositiveInfinity; @@ -627,7 +627,7 @@ protected override void ProcessValueMax(in Single val) public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots { - public MinMaxAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col, bool returnMax) + public MinMaxAggregatorAcrossSlots(IChannel ch, RowCursor cursor, int col, bool returnMax) : base(ch, cursor, col, returnMax) { Stat = ReturnMax ? Single.NegativeInfinity : Single.PositiveInfinity; @@ -659,7 +659,7 @@ public override object GetStat() public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot { - public MinMaxAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col, bool returnMax) + public MinMaxAggregatorBySlot(IChannel ch, ColumnType type, RowCursor cursor, int col, bool returnMax) : base(ch, type, cursor, col, returnMax) { Single bound = ReturnMax ? Single.NegativeInfinity : Single.PositiveInfinity; @@ -701,7 +701,7 @@ private static class R8 { public sealed class MeanAggregatorOne : StatAggregator { - public MeanAggregatorOne(IChannel ch, IRowCursor cursor, int col) + public MeanAggregatorOne(IChannel ch, RowCursor cursor, int col) : base(ch, cursor, col) { } @@ -719,7 +719,7 @@ public override object GetStat() public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots { - public MeanAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col) + public MeanAggregatorAcrossSlots(IChannel ch, RowCursor cursor, int col) : base(ch, cursor, col) { } @@ -737,7 +737,7 @@ public override object GetStat() public sealed class MeanAggregatorBySlot : StatAggregatorBySlot { - public MeanAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col) + public MeanAggregatorBySlot(IChannel ch, ColumnType type, RowCursor cursor, int col) : base(ch, type, cursor, col) { } @@ -759,7 +759,7 @@ public override object GetStat() public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne { - public MinMaxAggregatorOne(IChannel ch, IRowCursor cursor, int col, bool returnMax) + public MinMaxAggregatorOne(IChannel ch, RowCursor cursor, int col, bool returnMax) : base(ch, cursor, col, returnMax) { Stat = ReturnMax ? Double.NegativeInfinity : Double.PositiveInfinity; @@ -780,7 +780,7 @@ protected override void ProcessValueMax(in Double val) public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots { - public MinMaxAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col, bool returnMax) + public MinMaxAggregatorAcrossSlots(IChannel ch, RowCursor cursor, int col, bool returnMax) : base(ch, cursor, col, returnMax) { Stat = ReturnMax ? Double.NegativeInfinity : Double.PositiveInfinity; @@ -812,7 +812,7 @@ public override object GetStat() public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot { - public MinMaxAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col, bool returnMax) + public MinMaxAggregatorBySlot(IChannel ch, ColumnType type, RowCursor cursor, int col, bool returnMax) : base(ch, type, cursor, col, returnMax) { Double bound = ReturnMax ? Double.MinValue : Double.MaxValue; diff --git a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs index d2b3fe483e..0f24deb240 100644 --- a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs +++ b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs @@ -297,7 +297,7 @@ public override void Save(ModelSaveContext ctx) return null; } - protected override IRowCursor GetRowCursorCore(Func predicate, Random rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Host.AssertValue(predicate, "predicate"); Host.AssertValueOrNull(rand); @@ -305,10 +305,10 @@ protected override IRowCursor GetRowCursorCore(Func predicate, Random var inputPred = _bindings.GetDependencies(predicate); var active = _bindings.GetActive(predicate); var input = Source.GetRowCursor(inputPred); - return new RowCursor(Host, _bindings, input, active); + return new Cursor(Host, _bindings, input, active); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, + public override RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); @@ -316,7 +316,7 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid var inputPred = _bindings.GetDependencies(predicate); var active = _bindings.GetActive(predicate); - IRowCursor input; + RowCursor input; if (n > 1 && ShouldUseParallelCursors(predicate) != false) { @@ -325,9 +325,9 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid if (inputs.Length != 1) { - var cursors = new IRowCursor[inputs.Length]; + var cursors = new RowCursor[inputs.Length]; for (int i = 0; i < inputs.Length; i++) - cursors[i] = new RowCursor(Host, _bindings, inputs[i], active); + cursors[i] = new Cursor(Host, _bindings, inputs[i], active); return cursors; } input = inputs[0]; @@ -336,7 +336,7 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid input = Source.GetRowCursor(inputPred); consolidator = null; - return new IRowCursor[] { new RowCursor(Host, _bindings, input, active) }; + return new RowCursor[] { new Cursor(Host, _bindings, input, active) }; } protected override Func GetDependenciesCore(Func predicate) @@ -349,7 +349,7 @@ protected override int MapColumnIndex(out bool isSrc, int col) return _bindings.MapColumnIndex(out isSrc, col); } - protected override Delegate[] CreateGetters(IRow input, Func active, out Action disposer) + protected override Delegate[] CreateGetters(Row input, Func active, out Action disposer) { Func activeInfos = iinfo => @@ -370,7 +370,7 @@ protected override Delegate[] CreateGetters(IRow input, Func active, getters[iinfo] = MakeGetter(iinfo); else { - Func> srcDel = GetSrcGetter; + Func> srcDel = GetSrcGetter; var meth = srcDel.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(_bindings.ColumnTypes[iinfo].ItemType.RawType); getters[iinfo] = (Delegate)meth.Invoke(this, new object[] { input, iinfo }); } @@ -379,7 +379,7 @@ protected override Delegate[] CreateGetters(IRow input, Func active, } } - private ValueGetter GetSrcGetter(IRow input, int iinfo) + private ValueGetter GetSrcGetter(Row input, int iinfo) { return input.GetGetter(_bindings.SrcCols[iinfo]); } @@ -403,13 +403,13 @@ private Delegate MakeGetterVec(int length) VBufferUtils.Resize(ref value, length, 0)); } - private sealed class RowCursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : SynchronizedCursorBase { private readonly Bindings _bindings; private readonly bool[] _active; private readonly Delegate[] _getters; - public RowCursor(IChannelProvider provider, Bindings bindings, IRowCursor input, bool[] active) + public Cursor(IChannelProvider provider, Bindings bindings, RowCursor input, bool[] active) : base(provider, input) { Ch.CheckValue(bindings, nameof(bindings)); @@ -427,15 +427,15 @@ public RowCursor(IChannelProvider provider, Bindings bindings, IRowCursor input, } } - public Schema Schema => _bindings.AsSchema; + public override Schema Schema => _bindings.AsSchema; - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < _bindings.ColumnCount); return _active == null || _active[col]; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col)); diff --git a/src/Microsoft.ML.Transforms/ProduceIdTransform.cs b/src/Microsoft.ML.Transforms/ProduceIdTransform.cs index e2fc0359ad..c9f0f46d9b 100644 --- a/src/Microsoft.ML.Transforms/ProduceIdTransform.cs +++ b/src/Microsoft.ML.Transforms/ProduceIdTransform.cs @@ -136,7 +136,7 @@ public override void Save(ModelSaveContext ctx) _bindings.Save(ctx); } - protected override IRowCursor GetRowCursorCore(Func predicate, Random rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Host.AssertValue(predicate, "predicate"); Host.AssertValueOrNull(rand); @@ -145,19 +145,19 @@ protected override IRowCursor GetRowCursorCore(Func predicate, Random var input = Source.GetRowCursor(inputPred, rand); bool active = predicate(_bindings.MapIinfoToCol(0)); - return new RowCursor(Host, _bindings, input, active); + return new Cursor(Host, _bindings, input, active); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) + public override RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); Host.CheckValueOrNull(rand); var inputPred = _bindings.GetDependencies(predicate); - IRowCursor[] cursors = Source.GetRowCursorSet(out consolidator, inputPred, n, rand); + RowCursor[] cursors = Source.GetRowCursorSet(out consolidator, inputPred, n, rand); bool active = predicate(_bindings.MapIinfoToCol(0)); for (int c = 0; c < cursors.Length; ++c) - cursors[c] = new RowCursor(Host, _bindings, cursors[c], active); + cursors[c] = new Cursor(Host, _bindings, cursors[c], active); return cursors; } @@ -167,14 +167,14 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid return null; } - private sealed class RowCursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : SynchronizedCursorBase { private readonly Bindings _bindings; private readonly bool _active; - public Schema Schema => _bindings.AsSchema; + public override Schema Schema => _bindings.AsSchema; - public RowCursor(IChannelProvider provider, Bindings bindings, IRowCursor input, bool active) + public Cursor(IChannelProvider provider, Bindings bindings, RowCursor input, bool active) : base(provider, input) { Ch.CheckValue(bindings, nameof(bindings)); @@ -182,7 +182,7 @@ public RowCursor(IChannelProvider provider, Bindings bindings, IRowCursor input, _active = active; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.CheckParam(0 <= col && col < _bindings.ColumnCount, nameof(col)); bool isSrc; @@ -193,7 +193,7 @@ public bool IsColumnActive(int col) return _active; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.CheckParam(0 <= col && col < _bindings.ColumnCount, nameof(col)); Ch.CheckParam(IsColumnActive(col), nameof(col)); diff --git a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs index 0984f40f97..89a61a5ec7 100644 --- a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs +++ b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs @@ -540,7 +540,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() return result; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); @@ -550,7 +550,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac return GetterFromFloatType(input, iinfo); } - private ValueGetter> GetterFromVectorType(IRow input, int iinfo) + private ValueGetter> GetterFromVectorType(Row input, int iinfo) { var getSrc = input.GetGetter>(_srcCols[iinfo]); var src = default(VBuffer); @@ -567,7 +567,7 @@ private ValueGetter> GetterFromVectorType(IRow input, int iinfo) } - private ValueGetter> GetterFromFloatType(IRow input, int iinfo) + private ValueGetter> GetterFromFloatType(Row input, int iinfo) { var getSrc = input.GetGetter(_srcCols[iinfo]); var src = default(float); diff --git a/src/Microsoft.ML.Transforms/TermLookupTransformer.cs b/src/Microsoft.ML.Transforms/TermLookupTransformer.cs index 70c90a64d1..83a36c55af 100644 --- a/src/Microsoft.ML.Transforms/TermLookupTransformer.cs +++ b/src/Microsoft.ML.Transforms/TermLookupTransformer.cs @@ -114,7 +114,7 @@ public static VecValueMap CreateVector(VectorType type) return new VecValueMap(type); } - public abstract void Train(IExceptionContext ectx, IRowCursor cursor, int colTerm, int colValue); + public abstract void Train(IExceptionContext ectx, RowCursor cursor, int colTerm, int colValue); public abstract Delegate GetGetter(ValueGetter> getSrc); } @@ -136,7 +136,7 @@ protected ValueMap(ColumnType type) /// /// Bind this value map to the given cursor for "training". /// - public override void Train(IExceptionContext ectx, IRowCursor cursor, int colTerm, int colValue) + public override void Train(IExceptionContext ectx, RowCursor cursor, int colTerm, int colValue) { Contracts.AssertValue(ectx); ectx.Assert(_terms == null); @@ -694,7 +694,7 @@ private void SetMetadata() md.Seal(); } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + protected override Delegate GetGetterCore(IChannel ch, Row input, int iinfo, out Action disposer) { Host.AssertValueOrNull(ch); Host.AssertValue(input); diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 8978401c68..f05ab6161d 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -720,7 +720,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() return result; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); @@ -729,7 +729,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac return GetTopic(input, iinfo); } - private ValueGetter> GetTopic(IRow input, int iinfo) + private ValueGetter> GetTopic(Row input, int iinfo) { var getSrc = RowCursorUtils.GetVecGetterAs(NumberType.R8, input, _srcCols[iinfo]); var src = default(VBuffer); diff --git a/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs b/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs index a84c7102d0..c57b8ffcb4 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs @@ -378,8 +378,8 @@ public NgramHashingTransformer(IHostEnvironment env, Arguments args, IDataView i string[][] friendlyNames = args.Column.Select(c => c.FriendlyNames).ToArray(); var helper = new InvertHashHelper(this, friendlyNames, inputPred, invertHashMaxCounts); - using (IRowCursor srcCursor = input.GetRowCursor(inputPred)) - using (var dstCursor = new RowCursor(this, srcCursor, active, helper.Decorate)) + using (RowCursor srcCursor = input.GetRowCursor(inputPred)) + using (var dstCursor = new Cursor(this, srcCursor, active, helper.Decorate)) { var allGetters = InvertHashHelper.CallAllGetters(dstCursor); while (dstCursor.MoveNext()) @@ -628,7 +628,7 @@ private void AssertValid(uint[] ngram, int ngramLength, int lim, int icol) return null; } - protected override IRowCursor GetRowCursorCore(Func predicate, Random rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { Host.AssertValue(predicate, "predicate"); Host.AssertValueOrNull(rand); @@ -636,10 +636,10 @@ protected override IRowCursor GetRowCursorCore(Func predicate, Random var inputPred = _bindings.GetDependencies(predicate); var active = _bindings.GetActive(predicate); var input = Source.GetRowCursor(inputPred, rand); - return new RowCursor(this, input, active); + return new Cursor(this, input, active); } - public sealed override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, + public sealed override RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { Host.CheckValue(predicate, nameof(predicate)); @@ -654,9 +654,9 @@ public sealed override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator c inputs = DataViewUtils.CreateSplitCursors(out consolidator, Host, inputs[0], n); Host.AssertNonEmpty(inputs); - var cursors = new IRowCursor[inputs.Length]; + var cursors = new RowCursor[inputs.Length]; for (int i = 0; i < inputs.Length; i++) - cursors[i] = new RowCursor(this, inputs[i], active); + cursors[i] = new Cursor(this, inputs[i], active); return cursors; } @@ -665,7 +665,7 @@ protected override Func GetDependenciesCore(Func predicate return _bindings.GetDependencies(predicate); } - protected override Delegate[] CreateGetters(IRow input, Func active, out Action disp) + protected override Delegate[] CreateGetters(Row input, Func active, out Action disp) { Func activeInfos = iinfo => @@ -693,7 +693,7 @@ protected override int MapColumnIndex(out bool isSrc, int col) return _bindings.MapColumnIndex(out isSrc, col); } - private Delegate MakeGetter(IChannel ch, IRow input, int iinfo, FinderDecorator decorator = null) + private Delegate MakeGetter(IChannel ch, Row input, int iinfo, FinderDecorator decorator = null) { ch.Assert(_bindings.Infos[iinfo].SrcTypes.All(t => t.IsVector && t.ItemType.IsKey)); @@ -728,15 +728,15 @@ private Delegate MakeGetter(IChannel ch, IRow input, int iinfo, FinderDecorator private delegate NgramIdFinder FinderDecorator(int iinfo, NgramIdFinder finder); - private sealed class RowCursor : SynchronizedCursorBase, IRowCursor + private sealed class Cursor : SynchronizedCursorBase { private readonly Bindings _bindings; private readonly bool[] _active; private readonly Delegate[] _getters; - public Schema Schema => _bindings.AsSchema; + public override Schema Schema => _bindings.AsSchema; - public RowCursor(NgramHashingTransformer parent, IRowCursor input, bool[] active, FinderDecorator decorator = null) + public Cursor(NgramHashingTransformer parent, RowCursor input, bool[] active, FinderDecorator decorator = null) : base(parent.Host, input) { Ch.AssertValue(parent); @@ -760,13 +760,13 @@ private bool IsIndexActive(int iinfo) return _active == null || _active[_bindings.MapIinfoToCol(iinfo)]; } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < _bindings.ColumnCount); return _active == null || _active[col]; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col)); @@ -827,7 +827,7 @@ public InvertHashHelper(NgramHashingTransformer parent, string[][] friendlyNames /// Construct an action that calls all the getters for a row, so as to easily force computation /// of lazily computed values. This will have the side effect of calling the decorator. /// - public static Action CallAllGetters(IRow row) + public static Action CallAllGetters(Row row) { var colCount = row.Schema.ColumnCount; List getters = new List(); @@ -845,14 +845,14 @@ public static Action CallAllGetters(IRow row) }; } - private static Action GetNoOpGetter(IRow row, int col) + private static Action GetNoOpGetter(Row row, int col) { - Func func = GetNoOpGetter; + Func func = GetNoOpGetter; var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(row.Schema.GetColumnType(col).RawType); return (Action)meth.Invoke(null, new object[] { row, col }); } - private static Action GetNoOpGetter(IRow row, int col) + private static Action GetNoOpGetter(Row row, int col) { T value = default(T); var getter = row.GetGetter(col); diff --git a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs index 5f548c4632..ca9794cef6 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs @@ -665,7 +665,7 @@ private NgramIdFinder GetNgramIdFinder(int iinfo) }; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); diff --git a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs index 52ba66f0c6..eeffa7b8f7 100644 --- a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs @@ -425,7 +425,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() return result; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); @@ -1001,7 +1001,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() return result; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); diff --git a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs index 51bc1555bf..eb84af33a3 100644 --- a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs +++ b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs @@ -280,7 +280,7 @@ private static Dictionary CombinedDiacriticsMap } } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); @@ -299,7 +299,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac return MakeGetterOne(input, iinfo); } - private ValueGetter> MakeGetterOne(IRow input, int iinfo) + private ValueGetter> MakeGetterOne(Row input, int iinfo) { var getSrc = input.GetGetter>(ColMapNewToOld[iinfo]); Host.AssertValue(getSrc); @@ -313,7 +313,7 @@ private ValueGetter> MakeGetterOne(IRow input, int iinfo) }; } - private ValueGetter>> MakeGetterVec(IRow input, int iinfo) + private ValueGetter>> MakeGetterVec(Row input, int iinfo) { var getSrc = input.GetGetter>>(ColMapNewToOld[iinfo]); Host.AssertValue(getSrc); diff --git a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs index 590b98af5e..8d4a46652d 100644 --- a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs +++ b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs @@ -398,7 +398,7 @@ private void AppendCharRepr(char c, StringBuilder bldr) } } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); @@ -409,7 +409,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac return MakeGetterVec(input, iinfo); } - private ValueGetter> MakeGetterOne(IRow input, int iinfo) + private ValueGetter> MakeGetterOne(Row input, int iinfo) { Host.AssertValue(input); var getSrc = input.GetGetter>(ColMapNewToOld[iinfo]); @@ -438,7 +438,7 @@ private ValueGetter> MakeGetterOne(IRow input, int iinfo) }; } - private ValueGetter> MakeGetterVec(IRow input, int iinfo) + private ValueGetter> MakeGetterVec(Row input, int iinfo) { Host.AssertValue(input); diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs index a2056e8af5..585b435c70 100644 --- a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs +++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs @@ -559,7 +559,7 @@ private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstV nodeP.AddAttribute("axis", 1); } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); @@ -567,7 +567,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac return GetGetterVec(input, iinfo); } - private ValueGetter> GetGetterVec(IRow input, int iinfo) + private ValueGetter> GetGetterVec(Row input, int iinfo) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); diff --git a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs index 4419e08888..307821d1cb 100644 --- a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs +++ b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs @@ -257,7 +257,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() return result; } - protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _parent._columns.Length); @@ -272,7 +272,7 @@ protected override Delegate MakeGetter(IRow input, int iinfo, Func ac return MakeGetterVec(input, iinfo); } - private ValueGetter>> MakeGetterOne(IRow input, int iinfo) + private ValueGetter>> MakeGetterOne(Row input, int iinfo) { Host.AssertValue(input); var getSrc = input.GetGetter>(ColMapNewToOld[iinfo]); @@ -298,7 +298,7 @@ private ValueGetter>> MakeGetterOne(IRow input, int }; } - private ValueGetter>> MakeGetterVec(IRow input, int iinfo) + private ValueGetter>> MakeGetterVec(Row input, int iinfo) { Host.AssertValue(input); diff --git a/src/Microsoft.ML.Transforms/UngroupTransform.cs b/src/Microsoft.ML.Transforms/UngroupTransform.cs index 6c58376f84..0467a5763d 100644 --- a/src/Microsoft.ML.Transforms/UngroupTransform.cs +++ b/src/Microsoft.ML.Transforms/UngroupTransform.cs @@ -179,19 +179,19 @@ public override bool CanShuffle get { return false; } } - protected override IRowCursor GetRowCursorCore(Func predicate, Random rand = null) + protected override RowCursor GetRowCursorCore(Func predicate, Random rand = null) { var activeInput = _schemaImpl.GetActiveInput(predicate); var inputCursor = Source.GetRowCursor(col => activeInput[col], null); return new Cursor(Host, inputCursor, _schemaImpl, predicate); } - public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, + public override RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, Random rand = null) { var activeInput = _schemaImpl.GetActiveInput(predicate); var inputCursors = Source.GetRowCursorSet(out consolidator, col => activeInput[col], n, null); - return Utils.BuildArray(inputCursors.Length, + return Utils.BuildArray(inputCursors.Length, x => new Cursor(Host, inputCursors[x], _schemaImpl, predicate)); } @@ -441,7 +441,7 @@ public void GetMetadata(string kind, int col, ref TValue value) } } - private sealed class Cursor : LinkedRootCursorBase, IRowCursor + private sealed class Cursor : LinkedRootCursorBase { private readonly SchemaImpl _schemaImpl; @@ -467,7 +467,7 @@ private sealed class Cursor : LinkedRootCursorBase, IRowCursor // Parallel to columns. private int[] _colSizes; - public Cursor(IChannelProvider provider, IRowCursor input, SchemaImpl schema, Func predicate) + public Cursor(IChannelProvider provider, RowCursor input, SchemaImpl schema, Func predicate) : base(provider, input) { _schemaImpl = schema; @@ -582,15 +582,15 @@ private Func MakeSizeGetter(int col) }; } - public Schema Schema => _schemaImpl.AsSchema; + public override Schema Schema => _schemaImpl.AsSchema; - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { Ch.Check(0 <= col && col < _schemaImpl.ColumnCount); return _active[col]; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { Ch.CheckParam(0 <= col && col < _schemaImpl.ColumnCount, nameof(col)); diff --git a/test/Microsoft.ML.Benchmarks/CacheDataViewBench.cs b/test/Microsoft.ML.Benchmarks/CacheDataViewBench.cs new file mode 100644 index 0000000000..d01491f36f --- /dev/null +++ b/test/Microsoft.ML.Benchmarks/CacheDataViewBench.cs @@ -0,0 +1,97 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using BenchmarkDotNet.Attributes; +using Microsoft.ML.Runtime.Data; +using System; + +namespace Microsoft.ML.Benchmarks +{ + public class CacheDataViewBench + { + private const int Length = 100000; + + // Global. + private IDataView _cacheDataView; + // Per iteration. + private RowCursor _cursor; + private ValueGetter _getter; + + private RowSeeker _seeker; + private long[] _positions; + + [GlobalSetup(Targets = new[] { nameof(CacheWithCursor), nameof(CacheWithSeeker) })] + public void Setup() + { + var ctx = new MLContext(); + var builder = new ArrayDataViewBuilder(ctx); + int[] values = new int[Length]; + for (int i = 0; i < values.Length; ++i) + values[i] = i; + builder.AddColumn("A", NumberType.I4, values); + var dv = builder.GetDataView(); + var cacheDv = ctx.Data.Cache(dv); + + var col = cacheDv.Schema.GetColumnOrNull("A").Value; + // First do one pass through. + using (var cursor = cacheDv.GetRowCursor(colIndex => colIndex == col.Index)) + { + var getter = cursor.GetGetter(col.Index); + int val = 0; + int count = 0; + while (cursor.MoveNext()) + { + getter(ref val); + if (val != cursor.Position) + throw new Exception($"Unexpected value {val} at {cursor.Position}"); + count++; + } + if (count != Length) + throw new Exception($"Expected {Length} values in cache but only saw {count}"); + } + _cacheDataView = cacheDv; + + // Only needed for seeker, but may as well set it. + _positions = new long[Length]; + var rand = new Random(0); + for (int i = 0; i < _positions.Length; ++i) + _positions[i] = rand.Next(Length); + } + + [IterationSetup(Target = nameof(CacheWithCursor))] + public void CacheWithCursorSetup() + { + var col = _cacheDataView.Schema.GetColumnOrNull("A").Value; + _cursor = _cacheDataView.GetRowCursor(colIndex => colIndex == col.Index); + _getter = _cursor.GetGetter(col.Index); + } + + [Benchmark] + public void CacheWithCursor() + { + int val = 0; + while (_cursor.MoveNext()) + _getter(ref val); + } + + [IterationSetup(Target = nameof(CacheWithSeeker))] + public void CacheWithSeekerSetup() + { + var col = _cacheDataView.Schema.GetColumnOrNull("A").Value; + _seeker = ((IRowSeekable)_cacheDataView).GetSeeker(colIndex => colIndex == col.Index); + _getter = _seeker.GetGetter(col.Index); + } + + [Benchmark] + public void CacheWithSeeker() + { + int val = 0; + foreach (long pos in _positions) + { + _seeker.MoveTo(pos); + _getter(ref val); + } + } + } +} diff --git a/test/Microsoft.ML.Benchmarks/HashBench.cs b/test/Microsoft.ML.Benchmarks/HashBench.cs index 6db0768287..4a518be564 100644 --- a/test/Microsoft.ML.Benchmarks/HashBench.cs +++ b/test/Microsoft.ML.Benchmarks/HashBench.cs @@ -17,26 +17,26 @@ namespace Microsoft.ML.Benchmarks { public class HashBench { - private sealed class Row : IRow + private sealed class RowImpl : Row { - public Schema Schema { get; } + public long PositionValue; - public long Position { get; set; } - - public long Batch => 0; - public ValueGetter GetIdGetter() + public override Schema Schema { get; } + public override long Position => PositionValue; + public override long Batch => 0; + public override ValueGetter GetIdGetter() => (ref UInt128 val) => val = new UInt128((ulong)Position, 0); private readonly Delegate _getter; - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { if (col != 0) throw new Exception(); return true; } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { if (col != 0) throw new Exception(); @@ -45,14 +45,14 @@ public ValueGetter GetGetter(int col) throw new Exception(); } - public static Row Create(ColumnType type, ValueGetter getter) + public static RowImpl Create(ColumnType type, ValueGetter getter) { if (type.RawType != typeof(T)) throw new Exception(); - return new Row(type, getter); + return new RowImpl(type, getter); } - private Row(ColumnType type, Delegate getter) + private RowImpl(ColumnType type, Delegate getter) { var builder = new SchemaBuilder(); builder.AddColumn("Foo", type, null); @@ -65,7 +65,7 @@ private Row(ColumnType type, Delegate getter) private readonly IHostEnvironment _env = new MLContext(); - private Row _inRow; + private RowImpl _inRow; private ValueGetter _getter; private ValueGetter> _vecGetter; @@ -73,7 +73,7 @@ private void InitMap(T val, ColumnType type, int hashBits = 20, ValueGetter dst = val; - _inRow = Row.Create(type, getter); + _inRow = RowImpl.Create(type, getter); // One million features is a nice, typical number. var info = new HashingTransformer.ColumnInfo("Foo", "Bar", hashBits: hashBits); var xf = new HashingTransformer(_env, new[] { info }); @@ -95,7 +95,7 @@ private void RunScalar() for (int i = 0; i < Count; ++i) { _getter(ref val); - ++_inRow.Position; + ++_inRow.PositionValue; } } @@ -114,7 +114,7 @@ private void RunVector() for (int i = 0; i < Count; ++i) { _vecGetter(ref val); - ++_inRow.Position; + ++_inRow.PositionValue; } } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs b/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs index 1d102ca37d..08e8039b51 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs @@ -34,7 +34,7 @@ protected bool EqualTypes(ColumnType type1, ColumnType type2, bool exactTypes) return !exactTypes && type1 is VectorType vt1 && type2 is VectorType vt2 && vt1.ItemType.Equals(vt2.ItemType) && vt1.Size == vt2.Size; } - protected Func GetIdComparer(IRow r1, IRow r2, out ValueGetter idGetter) + protected Func GetIdComparer(Row r1, Row r2, out ValueGetter idGetter) { var g1 = r1.GetIdGetter(); idGetter = g1; @@ -50,7 +50,7 @@ protected Func GetIdComparer(IRow r1, IRow r2, out ValueGetter id }; } - protected Func GetComparerOne(IRow r1, IRow r2, int col, Func fn) + protected Func GetComparerOne(Row r1, Row r2, int col, Func fn) { var g1 = r1.GetGetter(col); var g2 = r2.GetGetter(col); @@ -73,7 +73,7 @@ private static bool EqualWithEps(Double x, Double y) // bitwise comparison is needed because Abs(Inf-Inf) and Abs(NaN-NaN) are not 0s. return FloatUtils.GetBits(x) == FloatUtils.GetBits(y) || Math.Abs(x - y) < DoubleEps; } - protected Func GetComparerVec(IRow r1, IRow r2, int col, int size, Func fn) + protected Func GetComparerVec(Row r1, Row r2, int col, int size, Func fn) { var g1 = r1.GetGetter>(col); var g2 = r2.GetGetter>(col); @@ -153,7 +153,7 @@ protected bool CompareVec(in VBuffer v1, in VBuffer v2, int size, Func< return false; } } - protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType type, bool exactDoubles) + protected Func GetColumnComparer(Row r1, Row r2, int col, ColumnType type, bool exactDoubles) { if (type is VectorType vecType) { @@ -312,7 +312,7 @@ protected bool CheckSameValues(IDataView view1, IDataView view2, bool exactTypes return all; } - protected bool CheckSameValues(IRowCursor curs1, IRowCursor curs2, bool exactTypes, bool exactDoubles, bool checkId, bool checkIdCollisions = true) + protected bool CheckSameValues(RowCursor curs1, RowCursor curs2, bool exactTypes, bool exactDoubles, bool checkId, bool checkIdCollisions = true) { Contracts.Assert(curs1.Schema.ColumnCount == curs2.Schema.ColumnCount); @@ -392,13 +392,13 @@ protected bool CheckSameValues(IRowCursor curs1, IRowCursor curs2, bool exactTyp } } - protected bool CheckSameValues(IRowCursor curs1, IDataView view2, bool exactTypes = true, bool exactDoubles = true, bool checkId = true) + protected bool CheckSameValues(RowCursor curs1, IDataView view2, bool exactTypes = true, bool exactDoubles = true, bool checkId = true) { Contracts.Assert(curs1.Schema.ColumnCount == view2.Schema.ColumnCount); // Get a cursor for each column. int colLim = curs1.Schema.ColumnCount; - var cursors = new IRowCursor[colLim]; + var cursors = new RowCursor[colLim]; try { for (int col = 0; col < colLim; col++) diff --git a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs index a08795ca4c..f96f6f205a 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs @@ -672,7 +672,7 @@ private void CombineAndTestTreeEnsembles(IDataView idv, IPredictorModel[] fastTr Assert.True(scoredArray[i].Schema.TryGetColumnIndex("PredictedLabel", out predColArray[i])); } - var cursors = new IRowCursor[predCount]; + var cursors = new RowCursor[predCount]; for (int i = 0; i < predCount; i++) cursors[i] = scoredArray[i].GetRowCursor(c => c == scoreColArray[i] || c == probColArray[i] || c == predColArray[i]); @@ -850,7 +850,7 @@ private void CombineAndTestEnsembles(IDataView idv, string name, string options, } } - var cursors = new IRowCursor[predCount]; + var cursors = new RowCursor[predCount]; for (int i = 0; i < predCount; i++) cursors[i] = scoredArray[i].GetRowCursor(c => c == scoreColArray[i] || c == probColArray[i] || c == predColArray[i]); diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index e2bc679830..bde25f5228 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -865,7 +865,7 @@ protected bool CheckSameValues(IDataView view1, IDataView view2, bool exactTypes return all; } - protected bool CheckSameValues(IRowCursor curs1, IRowCursor curs2, bool exactTypes, bool exactDoubles, bool checkId, bool checkIdCollisions = true) + protected bool CheckSameValues(RowCursor curs1, RowCursor curs2, bool exactTypes, bool exactDoubles, bool checkId, bool checkIdCollisions = true) { Contracts.Assert(curs1.Schema.ColumnCount == curs2.Schema.ColumnCount); @@ -946,13 +946,13 @@ protected bool CheckSameValues(IRowCursor curs1, IRowCursor curs2, bool exactTyp } } - protected bool CheckSameValues(IRowCursor curs1, IDataView view2, bool exactTypes = true, bool exactDoubles = true, bool checkId = true) + protected bool CheckSameValues(RowCursor curs1, IDataView view2, bool exactTypes = true, bool exactDoubles = true, bool checkId = true) { Contracts.Assert(curs1.Schema.ColumnCount == view2.Schema.ColumnCount); // Get a cursor for each column. int colLim = curs1.Schema.ColumnCount; - var cursors = new IRowCursor[colLim]; + var cursors = new RowCursor[colLim]; try { for (int col = 0; col < colLim; col++) @@ -1029,7 +1029,7 @@ protected bool CheckSameValues(IRowCursor curs1, IDataView view2, bool exactType } } - protected Func GetIdComparer(IRow r1, IRow r2, out ValueGetter idGetter) + protected Func GetIdComparer(Row r1, Row r2, out ValueGetter idGetter) { var g1 = r1.GetIdGetter(); idGetter = g1; @@ -1045,7 +1045,7 @@ protected Func GetIdComparer(IRow r1, IRow r2, out ValueGetter id }; } - protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType type, bool exactDoubles) + protected Func GetColumnComparer(Row r1, Row r2, int col, ColumnType type, bool exactDoubles) { if (!type.IsVector) { @@ -1174,7 +1174,7 @@ private static bool EqualWithEpsSingle(float x, float y) return FloatUtils.GetBits(x) == FloatUtils.GetBits(y) || Math.Abs(x - y) < SingleEps; } - protected Func GetComparerOne(IRow r1, IRow r2, int col, Func fn) + protected Func GetComparerOne(Row r1, Row r2, int col, Func fn) { var g1 = r1.GetGetter(col); var g2 = r2.GetGetter(col); @@ -1191,7 +1191,7 @@ protected Func GetComparerOne(IRow r1, IRow r2, int col, Func GetComparerVec(IRow r1, IRow r2, int col, int size, Func fn) + protected Func GetComparerVec(Row r1, Row r2, int col, int size, Func fn) { var g1 = r1.GetGetter>(col); var g2 = r2.GetGetter>(col);