diff --git a/src/Microsoft.ML.Api/CustomMappingTransformer.cs b/src/Microsoft.ML.Api/CustomMappingTransformer.cs index d3081fc096..7c81412d78 100644 --- a/src/Microsoft.ML.Api/CustomMappingTransformer.cs +++ b/src/Microsoft.ML.Api/CustomMappingTransformer.cs @@ -111,7 +111,7 @@ public Mapper(CustomMappingTransformer parent, Schema inputSchema) _typedSrc = TypedCursorable.Create(_host, emptyDataView, false, _parent.InputSchemaDefinition); } - public Delegate[] CreateGetters(Row input, Func activeOutput, out Action disposer) + Delegate[] IRowMapper.CreateGetters(Row input, Func activeOutput, out Action disposer) { disposer = null; // If no outputs are active, we short-circuit to empty array of getters. @@ -158,7 +158,7 @@ private Delegate GetDstGetter(Row input, int colIndex, Action refreshAction) return combinedGetter; } - public Func GetDependencies(Func activeOutput) + Func IRowMapper.GetDependencies(Func activeOutput) { if (Enumerable.Range(0, _parent.AddedSchema.Columns.Length).Any(activeOutput)) { @@ -169,7 +169,7 @@ public Func GetDependencies(Func activeOutput) return col => false; } - public Schema.DetachedColumn[] GetOutputColumns() + Schema.DetachedColumn[] IRowMapper.GetOutputColumns() { var dstRow = new DataViewConstructionUtils.InputRow(_host, _parent.AddedSchema); // All the output columns of dstRow are our outputs. diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index d68882f805..9c75f7189f 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -76,7 +76,7 @@ public static IDataView LoadPipeWithPredictor(IHostEnvironment env, Stream model return pipe; } - public sealed class InputRow : InputRowBase, IRowBackedBy + public sealed class InputRow : InputRowBase where TRow : class { private TRow _value; @@ -416,7 +416,12 @@ public sealed class WrappedCursor : RowCursor public override long Batch => _toWrap.Batch; public override Schema Schema => _toWrap.Schema; - public override void Dispose() => _toWrap.Dispose(); + protected override void Dispose(bool disposing) + { + if (disposing) + _toWrap.Dispose(); + } + public override ValueGetter GetGetter(int col) => _toWrap.GetGetter(col); public override ValueGetter GetIdGetter() => _toWrap.GetIdGetter(); @@ -434,8 +439,8 @@ public abstract class DataViewCursorBase : InputRowBase protected readonly DataViewBase DataView; protected readonly IChannel Ch; - private long _position; + /// /// Zero-based position of the cursor. /// @@ -462,14 +467,14 @@ protected DataViewCursorBase(IHostEnvironment env, DataViewBase dataView, /// protected bool IsGood => State == CursorState.Good; - public virtual void Dispose() + protected sealed override void Dispose(bool disposing) { - if (State != CursorState.Done) - { - Ch.Dispose(); - _position = -1; - State = CursorState.Done; - } + if (State == CursorState.Done) + return; + Ch.Dispose(); + _position = -1; + base.Dispose(disposing); + State = CursorState.Done; } public bool MoveNext() diff --git a/src/Microsoft.ML.Api/PredictionEngine.cs b/src/Microsoft.ML.Api/PredictionEngine.cs index 78d4810568..b6080b4fc3 100644 --- a/src/Microsoft.ML.Api/PredictionEngine.cs +++ b/src/Microsoft.ML.Api/PredictionEngine.cs @@ -157,13 +157,15 @@ public override void Predict(TSrc example, ref TDst prediction) /// /// The user-defined type that holds the example. /// The user-defined type that holds the prediction. - public abstract class PredictionEngineBase + public abstract class PredictionEngineBase : IDisposable where TSrc : class where TDst : class, new() { private readonly DataViewConstructionUtils.InputRow _inputRow; private readonly IRowReadableAs _outputRow; private readonly Action _disposer; + private bool _disposed; + [BestFriend] private protected ITransformer Transformer { get; } @@ -193,12 +195,14 @@ private protected PredictionEngineBase(IHostEnvironment env, ITransformer transf PredictionEngineCore(env, _inputRow, makeMapper(_inputRow.Schema), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition, out _disposer, out _outputRow); } - internal virtual void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow inputRow, IRowToRowMapper mapper, bool ignoreMissingColumns, + [BestFriend] + private protected virtual void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow inputRow, IRowToRowMapper mapper, bool ignoreMissingColumns, SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition, out Action disposer, out IRowReadableAs outputRow) { var cursorable = TypedCursorable.Create(env, new EmptyDataView(env, mapper.OutputSchema), ignoreMissingColumns, outputSchemaDefinition); - var outputRowLocal = mapper.GetRow(_inputRow, col => true, out disposer); + var outputRowLocal = mapper.GetRow(inputRow, col => true); outputRow = cursorable.GetRow(outputRowLocal); + disposer = inputRow.Dispose; } protected virtual Func TransformerChecker(IExceptionContext ectx, ITransformer transformer) @@ -208,9 +212,20 @@ protected virtual Func TransformerChecker(IExceptionCon return transformer.GetRowToRowMapper; } - ~PredictionEngineBase() + public void Dispose() + { + Disposing(true); + GC.SuppressFinalize(this); + } + + [BestFriend] + private protected void Disposing(bool disposing) { - _disposer?.Invoke(); + if (_disposed) + return; + if (disposing) + _disposer?.Invoke(); + _disposed = true; } /// diff --git a/src/Microsoft.ML.Api/StatefulFilterTransform.cs b/src/Microsoft.ML.Api/StatefulFilterTransform.cs index f07ef7daad..6b63c37349 100644 --- a/src/Microsoft.ML.Api/StatefulFilterTransform.cs +++ b/src/Microsoft.ML.Api/StatefulFilterTransform.cs @@ -96,7 +96,7 @@ private StatefulFilterTransform(IHostEnvironment env, StatefulFilterTransform false; Schema IDataView.Schema => OutputSchema; @@ -132,10 +132,7 @@ public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func return new[] { GetRowCursor(predicate, rand) }; } - public IDataView Source - { - get { return _source; } - } + public IDataView Source => _source; public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource) { @@ -158,10 +155,7 @@ private sealed class Cursor : RootCursorBase private bool _disposed; - public override long Batch - { - get { return _input.Batch; } - } + public override long Batch => _input.Batch; public Cursor(StatefulFilterTransform parent, RowCursor input, Func predicate) : base(parent.Host) @@ -196,24 +190,22 @@ public Cursor(StatefulFilterTransform parent, RowCursor GetIdGetter() diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs index 4607697662..e73dc88b10 100644 --- a/src/Microsoft.ML.Api/TypedCursor.cs +++ b/src/Microsoft.ML.Api/TypedCursor.cs @@ -18,7 +18,7 @@ namespace Microsoft.ML.Runtime.Api /// /// The user-defined type that is being populated while cursoring. [BestFriend] - internal interface IRowReadableAs + internal interface IRowReadableAs : IDisposable where TRow : class { /// @@ -28,22 +28,6 @@ internal interface IRowReadableAs void FillValues(TRow row); } - /// - /// This interface is an with 'strongly typed' binding. - /// It can accept values of type and present the value as a row. - /// - /// The user-defined type that provides the values while cursoring. - internal interface IRowBackedBy - where TRow : class - { - /// - /// Accepts the fields of the user-supplied object and publishes the instance as a row. - /// If the row is accessed prior to any object being set, then the data accessors on the row should throw. - /// - /// The row object. Cannot be null. - void ExtractValues(TRow row); - } - /// /// This interface provides cursoring through a via a 'strongly typed' binding. /// It can populate the user-supplied object's fields with the values of the current row. @@ -253,36 +237,34 @@ public static TypedCursorable Create(IHostEnvironment env, IDataView data, return new TypedCursorable(env, data, ignoreMissingColumns, outSchema); } - private abstract class TypedRowBase + private abstract class TypedRowBase : WrappingRow { protected readonly IChannel Ch; - private readonly Row _input; private readonly Action[] _setters; - public long Batch => _input.Batch; - - public long Position => _input.Position; - - public Schema Schema => _input.Schema; + public override Schema Schema => base.Input.Schema; public TypedRowBase(TypedCursorable parent, Row input, string channelMessage) + : base(input) { Contracts.AssertValue(parent); Contracts.AssertValue(parent._host); Ch = parent._host.Start(channelMessage); Ch.AssertValue(input); - _input = input; - int n = parent._pokes.Length; Ch.Assert(n == parent._columns.Length); Ch.Assert(n == parent._columnIndices.Length); _setters = new Action[n]; for (int i = 0; i < n; i++) - _setters[i] = GenerateSetter(_input, parent._columnIndices[i], parent._columns[i], parent._pokes[i], parent._peeks[i]); + _setters[i] = GenerateSetter(Input, parent._columnIndices[i], parent._columns[i], parent._pokes[i], parent._peeks[i]); } - public ValueGetter GetIdGetter() => _input.GetIdGetter(); + protected override void DisposeCore(bool disposing) + { + if (disposing) + Ch.Dispose(); + } private Action GenerateSetter(Row input, int index, InternalSchemaDefinition.Column column, Delegate poke, Delegate peek) { @@ -292,7 +274,7 @@ private Action GenerateSetter(Row input, int index, InternalSchemaDefiniti Func> del; if (fieldType.IsArray) { - Ch.Assert(colType.IsVector); + Ch.Assert(colType is VectorType); // VBuffer> -> String[] if (fieldType.GetElementType() == typeof(string)) { @@ -459,14 +441,14 @@ public virtual void FillValues(TRow row) setter(row); } - public bool IsColumnActive(int col) + public override bool IsColumnActive(int col) { - return _input.IsColumnActive(col); + return Input.IsColumnActive(col); } - public ValueGetter GetGetter(int col) + public override ValueGetter GetGetter(int col) { - return _input.GetGetter(col); + return Input.GetGetter(col); } } @@ -481,6 +463,15 @@ public TypedRow(TypedCursorable parent, Row input) private sealed class RowImplementation : IRowReadableAs { private readonly TypedRow _row; + private bool _disposed; + + public void Dispose() + { + if (_disposed) + return; + _row.Dispose(); + _disposed = true; + } public RowImplementation(TypedRow row) => _row = row; @@ -496,6 +487,7 @@ private sealed class RowImplementation : IRowReadableAs private sealed class RowCursorImplementation : RowCursor { private readonly TypedCursor _cursor; + private bool _disposed; public RowCursorImplementation(TypedCursor cursor) => _cursor = cursor; @@ -504,7 +496,15 @@ private sealed class RowCursorImplementation : RowCursor public override long Batch => _cursor.Batch; public override Schema Schema => _cursor.Schema; - public override void Dispose() { } + protected override void Dispose(bool disposing) + { + if (_disposed) + return; + if (disposing) + _cursor.Dispose(); + _disposed = true; + base.Dispose(disposing); + } public override void FillValues(TRow row) => _cursor.FillValues(row); public override ValueGetter GetGetter(int col) => _cursor.GetGetter(col); @@ -518,7 +518,6 @@ public override void Dispose() { } private sealed class TypedCursor : TypedRowBase { private readonly RowCursor _input; - private bool _disposed; public TypedCursor(TypedCursorable parent, RowCursor input) : base(parent, input, "Cursor") @@ -534,16 +533,6 @@ public override void FillValues(TRow row) public CursorState State => _input.State; - public void Dispose() - { - if (!_disposed) - { - _input.Dispose(); - Ch.Dispose(); - _disposed = true; - } - } - public bool MoveNext() => _input.MoveNext(); public bool MoveMany(long count) => _input.MoveMany(count); public RowCursor GetRootCursor() => _input.GetRootCursor(); diff --git a/src/Microsoft.ML.Core/Data/IDataView.cs b/src/Microsoft.ML.Core/Data/IDataView.cs index c67bba60aa..64da333c3a 100644 --- a/src/Microsoft.ML.Core/Data/IDataView.cs +++ b/src/Microsoft.ML.Core/Data/IDataView.cs @@ -142,7 +142,7 @@ public interface IRowCursorConsolidator /// A logical row. May be a row of an or a stand-alone row. If/when its contents /// change, its value is changed. /// - public abstract class Row + public abstract class Row : IDisposable { /// /// This is incremented when the underlying contents changes, giving clients a way to detect change. @@ -202,6 +202,25 @@ public abstract class Row /// public abstract Schema Schema { get; } + /// + /// Implementation of dispose. Calls with . + /// + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + /// + /// The disposable method for the disposable pattern. This default implementation does nothing. + /// + /// Whether this was called from . + /// Subclasses that implement should call this method with + /// , but I hasten to add that implementing finalizers should be + /// avoided if at all possible.. + protected virtual void Dispose(bool disposing) + { + } } /// @@ -221,7 +240,7 @@ public enum CursorState /// , is -1. Otherwise, /// >= 0. /// - public abstract class RowCursor : Row, IDisposable + public abstract class RowCursor : Row { /// /// Returns the state of the cursor. Before the first call to or @@ -252,6 +271,5 @@ public abstract class RowCursor : Row, IDisposable /// values from . /// public abstract RowCursor GetRootCursor(); - public abstract void Dispose(); } } \ No newline at end of file diff --git a/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs b/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs index f0e3f45adb..ab6a0223a1 100644 --- a/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs +++ b/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs @@ -104,15 +104,9 @@ public interface IRowToRowMapper /// This method creates a live connection between the input and the output . In particular, when the getters of the output are invoked, they invoke the /// getters of the input row and base the output values on the current values of the input . - /// The output values are re-computed when requested through the getters. - /// - /// The optional should be invoked by any user of this row mapping, once it no - /// longer needs the . If no action is needed when the cursor is Disposed, the implementation - /// should set to null, otherwise it should be set to a delegate to be - /// invoked by the code calling this object. (For example, a wrapping cursor's - /// method. It's best for this action to be idempotent - calling it multiple times should be equivalent to - /// calling it once. + /// The output values are re-computed when requested through the getters. Also, the returned + /// will dispose when it is disposed. /// - Row GetRow(Row input, Func active, out Action disposer); + Row GetRow(Row input, Func active); } } diff --git a/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs b/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs index 1f95d941cb..9025e27af4 100644 --- a/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs @@ -32,12 +32,15 @@ protected LinkedRootCursorBase(IChannelProvider provider, RowCursor input) Root = Input.GetRootCursor(); } - public override void Dispose() + protected override void Dispose(bool disposing) { - if (State != CursorState.Done) + if (State == CursorState.Done) + return; + if (disposing) { Input.Dispose(); - base.Dispose(); + // The base class should set the state to done under these circumstances. + base.Dispose(true); } } } diff --git a/src/Microsoft.ML.Core/Data/RootCursorBase.cs b/src/Microsoft.ML.Core/Data/RootCursorBase.cs index 75d58d98bb..15c0b501ff 100644 --- a/src/Microsoft.ML.Core/Data/RootCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/RootCursorBase.cs @@ -50,14 +50,14 @@ protected RootCursorBase(IChannelProvider provider) _state = CursorState.NotStarted; } - public override void Dispose() + protected override void Dispose(bool disposing) { - if (State != CursorState.Done) - { + if (State == CursorState.Done) + return; + if (disposing) Ch.Dispose(); - _position = -1; - _state = CursorState.Done; - } + _position = -1; + _state = CursorState.Done; } public sealed override bool MoveNext() diff --git a/src/Microsoft.ML.Core/Data/SchemaBuilder.cs b/src/Microsoft.ML.Core/Data/SchemaBuilder.cs index 28017f1e7d..711360ac79 100644 --- a/src/Microsoft.ML.Core/Data/SchemaBuilder.cs +++ b/src/Microsoft.ML.Core/Data/SchemaBuilder.cs @@ -31,7 +31,7 @@ public SchemaBuilder() /// The column name. /// The column type. /// The column metadata. - public void AddColumn(string name, ColumnType type, Schema.Metadata metadata) + public void AddColumn(string name, ColumnType type, Schema.Metadata metadata = null) { Contracts.CheckNonEmpty(name, nameof(name)); Contracts.CheckValue(type, nameof(type)); diff --git a/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs b/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs index b2641d985e..83bd53a3e3 100644 --- a/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs @@ -43,14 +43,17 @@ protected SynchronizedCursorBase(IChannelProvider provider, RowCursor input) _root = Input.GetRootCursor(); } - public override void Dispose() + protected override void Dispose(bool disposing) { - if (!_disposed) + if (_disposed) + return; + if (disposing) { Input.Dispose(); Ch.Dispose(); - _disposed = true; } + base.Dispose(disposing); + _disposed = true; } public sealed override bool MoveNext() => _root.MoveNext(); diff --git a/src/Microsoft.ML.Core/Data/WrappingRow.cs b/src/Microsoft.ML.Core/Data/WrappingRow.cs new file mode 100644 index 0000000000..6f855ad225 --- /dev/null +++ b/src/Microsoft.ML.Core/Data/WrappingRow.cs @@ -0,0 +1,64 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.Runtime.Data +{ + /// + /// Convenient base class for implementors that wrap a single + /// as their input. The , , and + /// are taken from this . + /// + [BestFriend] + internal abstract class WrappingRow : Row + { + private bool _disposed; + + /// + /// The wrapped input row. + /// + protected Row Input { get; } + + public sealed override long Batch => Input.Batch; + public sealed override long Position => Input.Position; + public override ValueGetter GetIdGetter() => Input.GetIdGetter(); + + [BestFriend] + private protected WrappingRow(Row input) + { + Contracts.AssertValue(input); + Input = input; + } + + /// + /// This override of the dispose method by default only calls 's + /// method, but subclasses can enable additional functionality + /// via the functionality. + /// + /// + protected sealed override void Dispose(bool disposing) + { + if (_disposed) + return; + // Since the input was created first, and this instance may depend on it, we should + // dispose local resources first before potentially disposing the input row resources. + DisposeCore(disposing); + if (disposing) + Input.Dispose(); + _disposed = true; + } + + /// + /// Called from with in the case where + /// that method has never been called before, and right after has been + /// disposed. The default implementation does nothing. + /// + /// Whether this was called through the dispose path, as opposed + /// to the finalizer path. + protected virtual void DisposeCore(bool disposing) + { + } + } +} diff --git a/src/Microsoft.ML.Data/Data/DataViewUtils.cs b/src/Microsoft.ML.Data/Data/DataViewUtils.cs index f81672008b..0b0cef5343 100644 --- a/src/Microsoft.ML.Data/Data/DataViewUtils.cs +++ b/src/Microsoft.ML.Data/Data/DataViewUtils.cs @@ -55,7 +55,7 @@ public static string[] GetTempColumnNames(this ISchema schema, int n, string tag int j = 0; for (int i = 0; i < n; i++) { - for (;;) + for (; ; ) { string name = string.IsNullOrWhiteSpace(tag) ? string.Format("temp_{0:000}", j) : @@ -1056,23 +1056,21 @@ public Cursor(IChannelProvider provider, Schema schema, int[] activeToCol, int[] _quitAction = quitAction; } - public override void Dispose() + protected override void Dispose(bool disposing) { - if (!_disposed) + if (_disposed) + return; + if (disposing) { foreach (var pipe in _pipes) pipe.Unset(); - _disposed = true; - if (_quitAction != null) - _quitAction(); + _quitAction?.Invoke(); } - base.Dispose(); + _disposed = true; + base.Dispose(disposing); } - public override ValueGetter GetIdGetter() - { - return _idGetter; - } + public override ValueGetter GetIdGetter() => _idGetter; protected override bool MoveNextCore() { @@ -1203,11 +1201,12 @@ private void InitHeap() } } - public override void Dispose() + protected override void Dispose(bool disposing) { - if (!_disposed) + if (_disposed) + return; + if (disposing) { - _disposed = true; _batch = -1; _icursor = -1; _currentCursor = null; @@ -1215,7 +1214,8 @@ public override void Dispose() foreach (var cursor in _cursors) cursor.Dispose(); } - base.Dispose(); + _disposed = true; + base.Dispose(disposing); } public override ValueGetter GetIdGetter() diff --git a/src/Microsoft.ML.Data/Data/IRowSeekable.cs b/src/Microsoft.ML.Data/Data/IRowSeekable.cs index 17514612af..d1ae1ebb7e 100644 --- a/src/Microsoft.ML.Data/Data/IRowSeekable.cs +++ b/src/Microsoft.ML.Data/Data/IRowSeekable.cs @@ -24,10 +24,8 @@ public interface IRowSeekable /// For , when the state is valid (that is when /// returns ), it returns the current row index. Otherwise it's -1. /// - public abstract class RowSeeker : Row, IDisposable + public abstract class RowSeeker : Row { - public abstract void Dispose(); - /// /// Moves the seeker to a row at a specific row index. /// If the row index specified is out of range (less than zero or not less than the diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs index aa33db97b8..abbcc163a7 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs @@ -1363,60 +1363,65 @@ public Cursor(BinaryLoader parent, Func predicate, Random rand) _pipeTask = SetupDecompressTask(); } - public override void Dispose() + protected override void Dispose(bool disposing) { - if (!_disposed && _readerThread != null) + if (_disposed) + return; + if (disposing) { - // We should reach this block only in the event of a dispose - // before all rows have been iterated upon. + if (_readerThread != null) + { + // We should reach this block only in the event of a dispose + // before all rows have been iterated upon. - // First set the flag on the cursor. The stream-reader and the - // pipe-decompressor workers will detect this, stop their work, - // and do whatever "cleanup" is natural for them to perform. - _disposed = true; + // First set the flag on the cursor. The stream-reader and the + // pipe-decompressor workers will detect this, stop their work, + // and do whatever "cleanup" is natural for them to perform. + _disposed = true; - // In the disk read -> decompress -> codec read pipeline, we - // clean up in reverse order. - // 1. First we clear out any pending codec readers, for each pipe. - // 2. Then we join the pipe worker threads, which in turn should - // have cleared out all of the pending blocks to decompress. - // 3. Then finally we join against the reader thread. + // In the disk read -> decompress -> codec read pipeline, we + // clean up in reverse order. + // 1. First we clear out any pending codec readers, for each pipe. + // 2. Then we join the pipe worker threads, which in turn should + // have cleared out all of the pending blocks to decompress. + // 3. Then finally we join against the reader thread. - // This code is analogous to the stuff in MoveNextCore, except - // nothing is actually done with the resulting blocks. + // This code is analogous to the stuff in MoveNextCore, except + // nothing is actually done with the resulting blocks. - try - { - for (; ; ) + try { - // This cross-block-index access pattern is deliberate, as - // by having a consistent access pattern everywhere we can - // have much greater confidence this will never deadlock. - bool anyTrue = false; - for (int c = 0; c < _pipes.Length; ++c) - anyTrue |= _pipes[c].MoveNextCleanup(); - if (!anyTrue) - break; + for (; ; ) + { + // This cross-block-index access pattern is deliberate, as + // by having a consistent access pattern everywhere we can + // have much greater confidence this will never deadlock. + bool anyTrue = false; + for (int c = 0; c < _pipes.Length; ++c) + anyTrue |= _pipes[c].MoveNextCleanup(); + if (!anyTrue) + break; + } + } + catch (OperationCanceledException ex) + { + // REVIEW: Encountering this here means that we did not encounter + // the exception during normal cursoring, but at some later point. I feel + // we should not be tolerant of this, and should throw, though it might be + // an ambiguous point. + Contracts.Assert(ex.CancellationToken == _exMarshaller.Token); + _exMarshaller.ThrowIfSet(Ch); + Contracts.Assert(false); + } + finally + { + _pipeTask.Wait(); + _readerThread.Join(); } - } - catch (OperationCanceledException ex) - { - // REVIEW: Encountering this here means that we did not encounter - // the exception during normal cursoring, but at some later point. I feel - // we should not be tolerant of this, and should throw, though it might be - // an ambiguous point. - Contracts.Assert(ex.CancellationToken == _exMarshaller.Token); - _exMarshaller.ThrowIfSet(Ch); - Contracts.Assert(false); - } - finally - { - _pipeTask.Wait(); - _readerThread.Join(); } } - - base.Dispose(); + _disposed = true; + base.Dispose(disposing); } private Task SetupDecompressTask() diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs index 1a2f94da3b..dbeb1c891d 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs @@ -275,16 +275,18 @@ public static string GetEmbeddedArgs(IMultiStreamSource files) public override Schema Schema => _bindings.AsSchema; - public override void Dispose() + protected override void Dispose(bool disposing) { if (_disposed) return; - + if (disposing) + { + _ator.Dispose(); + _reader.Release(); + _stats.Release(); + } _disposed = true; - _ator.Dispose(); - _reader.Release(); - _stats.Release(); - base.Dispose(); + base.Dispose(disposing); } protected override bool MoveNextCore() diff --git a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs index f57e08b8c7..0b42c26ce0 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs @@ -824,15 +824,17 @@ public Cursor(TransposeLoader parent, Func pred) Init(_actives[i]); } - public override void Dispose() + protected override void Dispose(bool disposing) { - if (!_disposed) + if (_disposed) + return; + if (disposing) { - _disposed = true; for (int i = 0; i < _transCursors.Length; ++i) _transCursors[i].Dispose(); - base.Dispose(); } + _disposed = true; + base.Dispose(disposing); } /// diff --git a/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs b/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs index e9a100bbef..c0d15d87f2 100644 --- a/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs +++ b/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs @@ -280,15 +280,16 @@ protected override bool MoveNextCore() return true; } - public override void Dispose() + protected override void Dispose(bool disposing) { - if (State != CursorState.Done) + if (State == CursorState.Done) + return; + if (disposing) { Ch.Dispose(); - if (_currentCursor != null) - _currentCursor.Dispose(); - base.Dispose(); + _currentCursor?.Dispose(); } + base.Dispose(disposing); } } @@ -369,15 +370,17 @@ protected override bool MoveNextCore() return true; } - public override void Dispose() + protected override void Dispose(bool disposing) { - if (State != CursorState.Done) + if (State == CursorState.Done) + return; + if (disposing) { Ch.Dispose(); foreach (RowCursor c in _cursorSet) c.Dispose(); - base.Dispose(); } + base.Dispose(disposing); } } diff --git a/src/Microsoft.ML.Data/DataView/CacheDataView.cs b/src/Microsoft.ML.Data/DataView/CacheDataView.cs index 9f2e0ab447..0e9308dc22 100644 --- a/src/Microsoft.ML.Data/DataView/CacheDataView.cs +++ b/src/Microsoft.ML.Data/DataView/CacheDataView.cs @@ -565,10 +565,6 @@ public RowSeeker(RowSeekerCore toWrap) public override long Batch => _internal.Batch; public override Schema Schema => _internal.Schema; - public override void Dispose() - { - } - public override ValueGetter GetGetter(int col) => _internal.GetGetter(col); public override ValueGetter GetIdGetter() => _internal.GetIdGetter(); public override bool IsColumnActive(int col) => _internal.IsColumnActive(col); @@ -1291,15 +1287,18 @@ public sealed override bool IsColumnActive(int col) return _colToActivesIndex[col] >= 0; } - public sealed override void Dispose() + protected sealed override void Dispose(bool disposing) { - if (!_disposed) + if (_disposed) + return; + if (disposing) { DisposeCore(); PositionCore = -1; Ch.Dispose(); - _disposed = true; } + base.Dispose(disposing); + _disposed = true; } public sealed override ValueGetter GetGetter(int col) diff --git a/src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs b/src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs index 326a750a26..aaf2e9f49e 100644 --- a/src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs +++ b/src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs @@ -43,13 +43,12 @@ public Func GetDependencies(Func predicate) return toReturn; } - public Row GetRow(Row input, Func active, out Action disposer) + public Row GetRow(Row input, Func active) { Contracts.CheckValue(input, nameof(input)); Contracts.CheckValue(active, nameof(active)); Contracts.CheckParam(input.Schema == InputSchema, nameof(input), "Schema did not match original schema"); - disposer = null; if (InnerMappers.Length == 0) { bool differentActive = false; @@ -75,17 +74,7 @@ public Row GetRow(Row input, Func active, out Action disposer) Row result = input; for (int i = 0; i < InnerMappers.Length; ++i) - { - result = InnerMappers[i].GetRow(result, deps[i], out var localDisp); - if (localDisp != null) - { - if (disposer == null) - disposer = localDisp; - else - disposer = localDisp + disposer; - // We want the last disposer to be called first, so the order of the addition here is important. - } - } + result = InnerMappers[i].GetRow(result, deps[i]); return result; } diff --git a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs index a4839d710b..57382f890b 100644 --- a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs +++ b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs @@ -26,7 +26,8 @@ namespace Microsoft.ML.Runtime.Data /// ctor or Create method with , along with a corresponding /// . /// - public interface IRowMapper : ICanSaveModel + [BestFriend] + internal interface IRowMapper : ICanSaveModel { /// /// Returns the input columns needed for the requested output columns. @@ -36,7 +37,9 @@ public interface IRowMapper : ICanSaveModel /// /// Returns the getters for the output columns given an active set of output columns. The length of the getters /// array should be equal to the number of columns added by the IRowMapper. It should contain the getter for the - /// i'th output column if activeOutput(i) is true, and null otherwise. + /// i'th output column if activeOutput(i) is true, and null otherwise. If creating a or + /// out of this, the delegate (if non-null) should be called + /// from the dispose of either of those instances. /// Delegate[] CreateGetters(Row input, Func activeOutput, out Action disposer); @@ -81,7 +84,8 @@ private static VersionInfo GetVersionInfo() bool ICanSavePfa.CanSavePfa => _mapper is ICanSavePfa pfaMapper ? pfaMapper.CanSavePfa : false; - public RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper mapper, Func mapperFactory) + [BestFriend] + internal RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper mapper, Func mapperFactory) : base(env, RegistrationName, input) { Contracts.CheckValue(mapper, nameof(mapper)); @@ -91,7 +95,8 @@ public RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper _bindings = new ColumnBindings(Schema.Create(input.Schema), mapper.GetOutputColumns()); } - public static Schema GetOutputSchema(ISchema inputSchema, IRowMapper mapper) + [BestFriend] + internal static Schema GetOutputSchema(ISchema inputSchema, IRowMapper mapper) { Contracts.CheckValue(inputSchema, nameof(inputSchema)); Contracts.CheckValue(mapper, nameof(mapper)); @@ -235,23 +240,20 @@ public Func GetDependencies(Func predicate) public Schema InputSchema => Source.Schema; - public Row GetRow(Row input, Func active, out Action disposer) + public Row GetRow(Row input, Func active) { Host.CheckValue(input, nameof(input)); Host.CheckValue(active, nameof(active)); Host.Check(input.Schema == Source.Schema, "Schema of input row must be the same as the schema the mapper is bound to"); - disposer = null; using (var ch = Host.Start("GetEntireRow")) { - Action disp; var activeArr = new bool[OutputSchema.ColumnCount]; for (int i = 0; i < OutputSchema.ColumnCount; i++) activeArr[i] = active(i); var pred = GetActiveOutputColumns(activeArr); - var getters = _mapper.CreateGetters(input, pred, out disp); - disposer += disp; - return new RowImpl(input, this, OutputSchema, getters); + var getters = _mapper.CreateGetters(input, pred, out Action disp); + return new RowImpl(input, this, OutputSchema, getters, disp); } } @@ -285,25 +287,27 @@ public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource) } } - private sealed class RowImpl : Row + private sealed class RowImpl : WrappingRow { - private readonly Row _input; private readonly Delegate[] _getters; - private readonly RowToRowMapperTransform _parent; - - public override long Batch => _input.Batch; - - public override long Position => _input.Position; + private readonly Action _disposer; public override Schema Schema { get; } - public RowImpl(Row input, RowToRowMapperTransform parent, Schema schema, Delegate[] getters) + public RowImpl(Row input, RowToRowMapperTransform parent, Schema schema, Delegate[] getters, Action disposer) + : base(input) { - _input = input; _parent = parent; Schema = schema; _getters = getters; + _disposer = disposer; + } + + protected override void DisposeCore(bool disposing) + { + if (disposing) + _disposer?.Invoke(); } public override ValueGetter GetGetter(int col) @@ -311,7 +315,7 @@ public override ValueGetter GetGetter(int col) bool isSrc; int index = _parent._bindings.MapColumnIndex(out isSrc, col); if (isSrc) - return _input.GetGetter(index); + return Input.GetGetter(index); Contracts.Assert(_getters[index] != null); var fn = _getters[index] as ValueGetter; @@ -320,14 +324,12 @@ public override ValueGetter GetGetter(int col) return fn; } - public override ValueGetter GetIdGetter() => _input.GetIdGetter(); - public override bool IsColumnActive(int col) { bool isSrc; int index = _parent._bindings.MapColumnIndex(out isSrc, col); if (isSrc) - return _input.IsColumnActive((index)); + return Input.IsColumnActive((index)); return _getters[index] != null; } } @@ -338,6 +340,7 @@ private sealed class Cursor : SynchronizedCursorBase private readonly bool[] _active; private readonly ColumnBindings _bindings; private readonly Action _disposer; + private bool _disposed; public override Schema Schema => _bindings.Schema; @@ -374,10 +377,14 @@ public override ValueGetter GetGetter(int col) return fn; } - public override void Dispose() + protected override void Dispose(bool disposing) { - _disposer?.Invoke(); - base.Dispose(); + if (_disposed) + return; + if (disposing) + _disposer?.Invoke(); + _disposed = true; + base.Dispose(disposing); } } } diff --git a/src/Microsoft.ML.Data/DataView/SimpleRow.cs b/src/Microsoft.ML.Data/DataView/SimpleRow.cs index 55f67e6de2..01a2b719b6 100644 --- a/src/Microsoft.ML.Data/DataView/SimpleRow.cs +++ b/src/Microsoft.ML.Data/DataView/SimpleRow.cs @@ -13,45 +13,54 @@ namespace Microsoft.ML.Runtime.Data /// /// An implementation of that gets its , , /// and from an input row. The constructor requires a schema and array of getter - /// delegates. A null delegate indicates an inactive column. The delegates are assumed to be of the appropriate type - /// (this does not validate the type). + /// delegates. A delegate indicates an inactive column. The delegates are assumed to be + /// of the appropriate type (this does not validate the type). /// REVIEW: Should this validate that the delegates are of the appropriate type? It wouldn't be difficult /// to do so. /// - public sealed class SimpleRow : Row + [BestFriend] + internal sealed class SimpleRow : WrappingRow { - private readonly Row _input; private readonly Delegate[] _getters; + private readonly Action _disposer; public override Schema Schema { get; } - public override long Position => _input.Position; - - public override long Batch => _input.Batch; - - public SimpleRow(Schema schema, Row input, Delegate[] getters) + /// + /// Constructor. + /// + /// The schema for the row. + /// The row that is being wrapped by this row, where our , + /// , . + /// The collection of getter delegates, whose types should map those in a schema. + /// If one of these is , the corresponding column is considered inactive. + /// A method that, if non-null, will be called exactly once during + /// , prior to disposing . + public SimpleRow(Schema schema, Row input, Delegate[] getters, Action disposer = null) + : base(input) { Contracts.CheckValue(schema, nameof(schema)); Contracts.CheckValue(input, nameof(input)); - Contracts.Check(Utils.Size(getters) == schema.ColumnCount); + Contracts.Check(Utils.Size(getters) == schema.Count); + Contracts.CheckValueOrNull(disposer); Schema = schema; - _input = input; _getters = getters ?? new Delegate[0]; + _disposer = disposer; } - public override ValueGetter GetIdGetter() + protected override void DisposeCore(bool disposing) { - return _input.GetIdGetter(); + if (disposing) + _disposer?.Invoke(); } public override ValueGetter GetGetter(int col) { Contracts.CheckParam(0 <= col && col < _getters.Length, nameof(col), "Invalid col value in GetGetter"); Contracts.Check(IsColumnActive(col)); - var fn = _getters[col] as ValueGetter; - if (fn == null) - throw Contracts.Except("Unexpected TValue in GetGetter"); - return fn; + if (_getters[col] is ValueGetter fn) + return fn; + throw Contracts.Except("Unexpected TValue in GetGetter"); } public override bool IsColumnActive(int col) @@ -135,68 +144,6 @@ public void GetMetadata(string kind, int col, ref TValue value) protected abstract void GetMetadataCore(string kind, int col, ref TValue value); } - /// - /// An that takes all column names and types as constructor parameters. - /// The columns can optionally have text metadata. - /// - public sealed class SimpleSchema : SimpleSchemaBase - { - private readonly MetadataUtils.MetadataGetter>>[] _keyValueGetters; - - public SimpleSchema(IExceptionContext ectx, params KeyValuePair[] columns) - : base(ectx, columns) - { - _keyValueGetters = new MetadataUtils.MetadataGetter>>[ColumnCount]; - } - - public SimpleSchema(IExceptionContext ectx, KeyValuePair[] columns, - Dictionary>>> keyValues) - : this(ectx, columns) - { - foreach (var kvp in keyValues) - { - var name = kvp.Key; - var getter = kvp.Value; - if (!ColumnNameMap.TryGetValue(name, out int col)) - throw Ectx.ExceptParam(nameof(keyValues), $"Output schema does not contain column '{name}'"); - if (!Types[col].ItemType.IsKey) - throw Ectx.ExceptParam(nameof(keyValues), $"Column '{name}' is not a key column, so it cannot have key value metadata"); - _keyValueGetters[col] = getter; - } - } - - protected override IEnumerable> GetMetadataTypesCore(int col) - { - Ectx.Assert(0 <= col && col < ColumnCount); - if (_keyValueGetters[col] != null) - { - Ectx.Assert(Types[col].ItemType.IsKey); - yield return new KeyValuePair(MetadataUtils.Kinds.KeyValues, - new VectorType(TextType.Instance, Types[col].ItemType.KeyCount)); - } - } - - protected override ColumnType GetMetadataTypeOrNullCore(string kind, int col) - { - Ectx.Assert(0 <= col && col < ColumnCount); - if (kind == MetadataUtils.Kinds.KeyValues && _keyValueGetters[col] != null) - { - Ectx.Assert(Types[col].ItemType.IsKey); - return new VectorType(TextType.Instance, Types[col].ItemType.KeyCount); - } - return null; - } - - protected override void GetMetadataCore(string kind, int col, ref TValue value) - { - Ectx.Assert(0 <= col && col < ColumnCount); - if (kind == MetadataUtils.Kinds.KeyValues && _keyValueGetters[col] != null) - _keyValueGetters[col].Marshal(col, ref value); - else - throw Ectx.ExceptGetMetadata(); - } - } - public static class SimpleSchemaUtils { public static Schema Create(IExceptionContext ectx, params KeyValuePair[] columns) @@ -209,5 +156,4 @@ public static Schema Create(IExceptionContext ectx, params KeyValuePair : Row + private abstract class RowBase : WrappingRow where TSplitter : Splitter { protected readonly TSplitter Parent; - protected readonly Row Input; public sealed override Schema Schema => Parent.AsSchema; - public sealed override long Position => Input.Position; - public sealed override long Batch => Input.Batch; public RowBase(TSplitter parent, Row input) + : base(input) { Contracts.AssertValue(parent); Contracts.AssertValue(input); Contracts.Assert(input.IsColumnActive(parent.SrcCol)); Parent = parent; - Input = input; - } - - public sealed override ValueGetter GetIdGetter() - { - return Input.GetIdGetter(); } } @@ -1511,7 +1503,7 @@ public SlotDataView(IHostEnvironment env, ITransposeDataView data, int col) _col = col; var builder = new SchemaBuilder(); - builder.AddColumn(_data.Schema[_col].Name, _type, null); + builder.AddColumn(_data.Schema[_col].Name, _type); Schema = builder.GetSchema(); } @@ -1606,7 +1598,7 @@ public SlotRowCursorShim(IChannelProvider provider, SlotCursor cursor) _slotCursor = cursor; var builder = new SchemaBuilder(); - builder.AddColumn("Waffles", cursor.GetSlotType(), null); + builder.AddColumn("Waffles", cursor.GetSlotType()); Schema = builder.GetSchema(); } diff --git a/src/Microsoft.ML.Data/DataView/ZipDataView.cs b/src/Microsoft.ML.Data/DataView/ZipDataView.cs index 7b344a02be..9de68b9d55 100644 --- a/src/Microsoft.ML.Data/DataView/ZipDataView.cs +++ b/src/Microsoft.ML.Data/DataView/ZipDataView.cs @@ -110,6 +110,7 @@ private sealed class Cursor : RootCursorBase private readonly RowCursor[] _cursors; private readonly CompositeSchema _compositeSchema; private readonly bool[] _isColumnActive; + private bool _disposed; public override long Batch { get { return 0; } } @@ -124,11 +125,17 @@ public Cursor(ZipDataView parent, RowCursor[] srcCursors, Func predic _isColumnActive = Utils.BuildArray(_compositeSchema.ColumnCount, predicate); } - public override void Dispose() + protected override void Dispose(bool disposing) { - for (int i = _cursors.Length - 1; i >= 0; i--) - _cursors[i].Dispose(); - base.Dispose(); + if (_disposed) + return; + if (disposing) + { + for (int i = _cursors.Length - 1; i >= 0; i--) + _cursors[i].Dispose(); + } + _disposed = true; + base.Dispose(disposing); } public override ValueGetter GetIdGetter() diff --git a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs index 994414d8f4..5af6213603 100644 --- a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs +++ b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs @@ -236,7 +236,7 @@ public Func GetDependencies(Func predicate) public Schema InputSchema => _rootSchema; - public Row GetRow(Row input, Func active, out Action disposer) + public Row GetRow(Row input, Func active) { _ectx.Assert(IsCompositeRowToRowMapper(_chain)); _ectx.AssertValue(input); @@ -244,7 +244,6 @@ public Row GetRow(Row input, Func active, out Action disposer) _ectx.Check(input.Schema == InputSchema, "Schema of input row must be the same as the schema the mapper is bound to"); - disposer = null; var mappers = new List(); var actives = new List>(); var transform = _chain as IDataTransform; @@ -262,11 +261,7 @@ public Row GetRow(Row input, Func active, out Action disposer) actives.Reverse(); var row = input; for (int i = 0; i < mappers.Count; i++) - { - Action disp; - row = mappers[i].GetRow(row, actives[i], out disp); - disposer += disp; - } + row = mappers[i].GetRow(row, actives[i]); return row; } diff --git a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs index 4dfe37f9f6..3dd2ca8f5c 100644 --- a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs @@ -188,7 +188,7 @@ private ReadOnlyMemory[] GetClassNames(RoleMappedSchema schema) return names; } - protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) + private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) { Contracts.CheckValue(schema, nameof(schema)); Contracts.CheckParam(schema.Label != null, nameof(schema), "Could not find the label column"); @@ -968,7 +968,7 @@ public override void Save(ModelSaveContext ctx) ctx.Writer.WriteBoolByte(_useRaw); } - public override Func GetDependencies(Func activeOutput) + private protected override Func GetDependenciesCore(Func activeOutput) { if (_probIndex >= 0) { @@ -981,7 +981,7 @@ public override Func GetDependencies(Func activeOutput) return col => activeOutput(AssignedCol) && col == ScoreIndex; } - public override Delegate[] CreateGetters(Row input, Func activeCols, out Action disposer) + private protected override Delegate[] CreateGettersCore(Row input, Func activeCols, out Action disposer) { Host.Assert(LabelIndex >= 0); Host.Assert(ScoreIndex >= 0); @@ -1079,7 +1079,7 @@ private bool GetPredictedLabel(Single val) return Single.IsNaN(val) ? false : val > _threshold; } - public override Schema.DetachedColumn[] GetOutputColumns() + private protected override Schema.DetachedColumn[] GetOutputColumnsCore() { if (_probIndex >= 0) { diff --git a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs index 797854dcbf..ca2d36c3c4 100644 --- a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs @@ -142,7 +142,7 @@ protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string return new Aggregator(Host, schema.Feature, numClusters, _calculateDbi, schema.Weight != null, stratName); } - protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) + private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) { var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); int numClusters = scoreInfo.Type.VectorSize; @@ -638,7 +638,7 @@ public override void Save(ModelSaveContext ctx) ctx.Writer.Write(_numClusters); } - public override Func GetDependencies(Func activeOutput) + private protected override Func GetDependenciesCore(Func activeOutput) { return col => @@ -646,13 +646,13 @@ public override Func GetDependencies(Func activeOutput) (activeOutput(ClusterIdCol) || activeOutput(SortedClusterCol) || activeOutput(SortedClusterScoreCol)); } - public override Delegate[] CreateGetters(Row input, Func activeOutput, out Action disposer) + private protected override Delegate[] CreateGettersCore(Row input, Func activeCols, out Action disposer) { disposer = null; var getters = new Delegate[3]; - if (!activeOutput(ClusterIdCol) && !activeOutput(SortedClusterCol) && !activeOutput(SortedClusterScoreCol)) + if (!activeCols(ClusterIdCol) && !activeCols(SortedClusterCol) && !activeCols(SortedClusterScoreCol)) return getters; long cachedPosition = -1; @@ -675,7 +675,7 @@ public override Delegate[] CreateGetters(Row input, Func activeOutput } }; - if (activeOutput(ClusterIdCol)) + if (activeCols(ClusterIdCol)) { ValueGetter assignedFn = (ref uint dst) => @@ -686,7 +686,7 @@ public override Delegate[] CreateGetters(Row input, Func activeOutput getters[ClusterIdCol] = assignedFn; } - if (activeOutput(SortedClusterScoreCol)) + if (activeCols(SortedClusterScoreCol)) { ValueGetter> topKScoresFn = (ref VBuffer dst) => @@ -700,7 +700,7 @@ public override Delegate[] CreateGetters(Row input, Func activeOutput getters[SortedClusterScoreCol] = topKScoresFn; } - if (activeOutput(SortedClusterCol)) + if (activeCols(SortedClusterCol)) { ValueGetter> topKClassesFn = (ref VBuffer dst) => @@ -716,7 +716,7 @@ public override Delegate[] CreateGetters(Row input, Func activeOutput return getters; } - public override Schema.DetachedColumn[] GetOutputColumns() + private protected override Schema.DetachedColumn[] GetOutputColumnsCore() { var infos = new Schema.DetachedColumn[3]; infos[ClusterIdCol] = new Schema.DetachedColumn(ClusterId, _types[ClusterIdCol], null); diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs index 829dd0043a..906bc020b8 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs @@ -446,7 +446,8 @@ public override IDataTransform GetPerInstanceMetrics(RoleMappedData data) return new RowToRowMapperTransform(Host, data.Data, mapper, null); } - protected abstract IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema); + [BestFriend] + private protected abstract IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema); } /// @@ -501,10 +502,22 @@ public virtual void Save(ModelSaveContext ctx) ctx.SaveStringOrNull(LabelCol); } - public abstract Func GetDependencies(Func activeOutput); + Func IRowMapper.GetDependencies(Func activeOutput) + => GetDependenciesCore(activeOutput); - public abstract Schema.DetachedColumn[] GetOutputColumns(); + [BestFriend] + private protected abstract Func GetDependenciesCore(Func activeOutput); - public abstract Delegate[] CreateGetters(Row input, Func activeCols, out Action disposer); + Schema.DetachedColumn[] IRowMapper.GetOutputColumns() + => GetOutputColumnsCore(); + + [BestFriend] + private protected abstract Schema.DetachedColumn[] GetOutputColumnsCore(); + + Delegate[] IRowMapper.CreateGetters(Row input, Func activeCols, out Action disposer) + => CreateGettersCore(input, activeCols, out disposer); + + [BestFriend] + private protected abstract Delegate[] CreateGettersCore(Row input, Func activeCols, out Action disposer); } } diff --git a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs index cb12bf92e4..04316e3ed6 100644 --- a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs @@ -117,7 +117,7 @@ private ReadOnlyMemory[] GetClassNames(RoleMappedSchema schema) return names; } - protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) + private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) { Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column"); var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); @@ -648,7 +648,7 @@ public override void Save(ModelSaveContext ctx) ctx.SaveNonEmptyString(_classNames[i].ToString()); } - public override Func GetDependencies(Func activeOutput) + private protected override Func GetDependenciesCore(Func activeOutput) { Host.Assert(ScoreIndex >= 0); Host.Assert(LabelIndex >= 0); @@ -662,13 +662,13 @@ public override Func GetDependencies(Func activeOutput) activeOutput(SortedClassesCol) || activeOutput(LogLossCol)); } - public override Delegate[] CreateGetters(Row input, Func activeOutput, out Action disposer) + private protected override Delegate[] CreateGettersCore(Row input, Func activeCols, out Action disposer) { disposer = null; var getters = new Delegate[4]; - if (!activeOutput(AssignedCol) && !activeOutput(SortedClassesCol) && !activeOutput(SortedScoresCol) && !activeOutput(LogLossCol)) + if (!activeCols(AssignedCol) && !activeCols(SortedClassesCol) && !activeCols(SortedScoresCol) && !activeCols(LogLossCol)) return getters; long cachedPosition = -1; @@ -677,7 +677,7 @@ public override Delegate[] CreateGetters(Row input, Func activeOutput var scoresArr = new float[_numClasses]; int[] sortedIndices = new int[_numClasses]; - var labelGetter = activeOutput(LogLossCol) ? RowCursorUtils.GetLabelGetter(input, LabelIndex) : + var labelGetter = activeCols(LogLossCol) ? RowCursorUtils.GetLabelGetter(input, LabelIndex) : (ref float dst) => dst = float.NaN; var scoreGetter = input.GetGetter>(ScoreIndex); Action updateCacheIfNeeded = @@ -695,7 +695,7 @@ public override Delegate[] CreateGetters(Row input, Func activeOutput } }; - if (activeOutput(AssignedCol)) + if (activeCols(AssignedCol)) { ValueGetter assignedFn = (ref uint dst) => @@ -706,7 +706,7 @@ public override Delegate[] CreateGetters(Row input, Func activeOutput getters[AssignedCol] = assignedFn; } - if (activeOutput(SortedScoresCol)) + if (activeCols(SortedScoresCol)) { ValueGetter> topKScoresFn = (ref VBuffer dst) => @@ -720,7 +720,7 @@ public override Delegate[] CreateGetters(Row input, Func activeOutput getters[SortedScoresCol] = topKScoresFn; } - if (activeOutput(SortedClassesCol)) + if (activeCols(SortedClassesCol)) { ValueGetter> topKClassesFn = (ref VBuffer dst) => @@ -734,7 +734,7 @@ public override Delegate[] CreateGetters(Row input, Func activeOutput getters[SortedClassesCol] = topKClassesFn; } - if (activeOutput(LogLossCol)) + if (activeCols(LogLossCol)) { ValueGetter logLossFn = (ref double dst) => @@ -761,7 +761,7 @@ public override Delegate[] CreateGetters(Row input, Func activeOutput return getters; } - public override Schema.DetachedColumn[] GetOutputColumns() + private protected override Schema.DetachedColumn[] GetOutputColumnsCore() { var infos = new Schema.DetachedColumn[4]; diff --git a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs index fb88a5bb60..daa71a0cb1 100644 --- a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs @@ -47,7 +47,7 @@ public MultiOutputRegressionEvaluator(IHostEnvironment env, Arguments args) { } - protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) + private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) { Host.CheckParam(schema.Label != null, nameof(schema), "Could not find the label column"); var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); @@ -436,7 +436,7 @@ public override void Save(ModelSaveContext ctx) base.Save(ctx); } - public override Func GetDependencies(Func activeOutput) + private protected override Func GetDependenciesCore(Func activeOutput) { return col => @@ -446,7 +446,7 @@ public override Func GetDependencies(Func activeOutput) (col == ScoreIndex || col == LabelIndex); } - public override Schema.DetachedColumn[] GetOutputColumns() + private protected override Schema.DetachedColumn[] GetOutputColumnsCore() { var infos = new Schema.DetachedColumn[5]; infos[LabelOutput] = new Schema.DetachedColumn(LabelCol, _labelType, _labelMetadata); @@ -457,7 +457,7 @@ public override Schema.DetachedColumn[] GetOutputColumns() return infos; } - public override Delegate[] CreateGetters(Row input, Func activeCols, out Action disposer) + private protected override Delegate[] CreateGettersCore(Row input, Func activeCols, out Action disposer) { Host.Assert(LabelIndex >= 0); Host.Assert(ScoreIndex >= 0); diff --git a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs index 36eabc8a71..eb2f9f677d 100644 --- a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs @@ -39,7 +39,7 @@ public QuantileRegressionEvaluator(IHostEnvironment env, Arguments args) { } - protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) + private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) { Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column"); var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); @@ -345,13 +345,13 @@ public override void Save(ModelSaveContext ctx) ctx.SaveNonEmptyString(quantiles[i].ToString()); } - public override Func GetDependencies(Func activeOutput) + private protected override Func GetDependenciesCore(Func activeOutput) { return col => (activeOutput(L1Col) || activeOutput(L2Col)) && (col == ScoreIndex || col == LabelIndex); } - public override Schema.DetachedColumn[] GetOutputColumns() + private protected override Schema.DetachedColumn[] GetOutputColumnsCore() { var infos = new Schema.DetachedColumn[2]; @@ -380,7 +380,7 @@ private ValueGetter>> CreateSlotNamesGetter(string }; } - public override Delegate[] CreateGetters(Row input, Func activeCols, out Action disposer) + private protected override Delegate[] CreateGettersCore(Row input, Func activeCols, out Action disposer) { Host.Assert(LabelIndex >= 0); Host.Assert(ScoreIndex >= 0); diff --git a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs index b4fdf358a0..92a0765c88 100644 --- a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs @@ -69,7 +69,7 @@ protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string return new Aggregator(Host, LossFunction, schema.Weight != null, stratName); } - protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) + private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) { Contracts.CheckParam(schema.Label != null, nameof(schema), "Could not find the label column"); var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); @@ -245,13 +245,13 @@ public override void Save(ModelSaveContext ctx) base.Save(ctx); } - public override Func GetDependencies(Func activeOutput) + private protected override Func GetDependenciesCore(Func activeOutput) { return col => (activeOutput(L1Col) || activeOutput(L2Col)) && (col == ScoreIndex || col == LabelIndex); } - public override Schema.DetachedColumn[] GetOutputColumns() + private protected override Schema.DetachedColumn[] GetOutputColumnsCore() { var infos = new Schema.DetachedColumn[2]; infos[L1Col] = new Schema.DetachedColumn(L1, NumberType.R8, null); @@ -259,7 +259,7 @@ public override Schema.DetachedColumn[] GetOutputColumns() return infos; } - public override Delegate[] CreateGetters(Row input, Func activeCols, out Action disposer) + private protected override Delegate[] CreateGettersCore(Row input, Func activeCols, out Action disposer) { Host.Assert(LabelIndex >= 0); Host.Assert(ScoreIndex >= 0); diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 45860d2e39..17be128709 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -559,7 +559,7 @@ public Func GetDependencies(Func predicate) return _predictor.GetInputColumnRoles(); } - public Row GetRow(Row input, Func predicate, out Action disposer) + public Row GetRow(Row input, Func predicate) { Func predictorPredicate = col => false; for (int i = 0; i < OutputSchema.ColumnCount; i++) @@ -570,7 +570,7 @@ public Row GetRow(Row input, Func predicate, out Action disposer) break; } } - var predictorRow = _predictor.GetRow(input, predictorPredicate, out disposer); + var predictorRow = _predictor.GetRow(input, predictorPredicate); var getters = new Delegate[OutputSchema.ColumnCount]; for (int i = 0; i < OutputSchema.ColumnCount - 1; i++) { diff --git a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs index 3a738887ec..5eb1b91e49 100644 --- a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs +++ b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs @@ -353,8 +353,9 @@ public RowMapper(IHostEnvironment env, BindableMapper parent, RoleMappedSchema s if (parent.Stringify) { - _outputSchema = new SimpleSchema(_env, - new KeyValuePair(DefaultColumnNames.FeatureContributions, TextType.Instance)); + var builder = new SchemaBuilder(); + builder.AddColumn(DefaultColumnNames.FeatureContributions, TextType.Instance, null); + _outputSchema = builder.GetSchema(); if (InputSchema.HasSlotNames(InputRoleMappedSchema.Feature.Index, InputRoleMappedSchema.Feature.Type.VectorSize)) InputSchema.GetMetadata(MetadataUtils.Kinds.SlotNames, InputRoleMappedSchema.Feature.Index, ref _slotNames); @@ -385,28 +386,28 @@ public Func GetDependencies(Func predicate) return col => false; } - public Row GetOutputRow(Row input, Func predicate, out Action disposer) + public Row GetRow(Row input, Func active) { Contracts.AssertValue(input); - Contracts.AssertValue(predicate); + Contracts.AssertValue(active); var totalColumnsCount = 1 + _outputGenericSchema.ColumnCount; var getters = new Delegate[totalColumnsCount]; - if (predicate(totalColumnsCount - 1)) + if (active(totalColumnsCount - 1)) { getters[totalColumnsCount - 1] = _parent.Stringify ? _parent.GetTextContributionGetter(input, InputRoleMappedSchema.Feature.Index, _slotNames) : _parent.GetContributionGetter(input, InputRoleMappedSchema.Feature.Index); } - var genericRow = _genericRowMapper.GetRow(input, GetGenericPredicate(predicate), out disposer); + var genericRow = _genericRowMapper.GetRow(input, GetGenericPredicate(active)); for (var i = 0; i < _outputGenericSchema.ColumnCount; i++) { if (genericRow.IsColumnActive(i)) getters[i] = RowCursorUtils.GetGetterAsDelegate(genericRow, i); } - return new SimpleRow(OutputSchema, input, getters); + return new SimpleRow(OutputSchema, genericRow, getters); } public Func GetGenericPredicate(Func predicate) @@ -418,11 +419,6 @@ public Func GetGenericPredicate(Func predicate) { yield return RoleMappedSchema.ColumnRole.Feature.Bind(InputRoleMappedSchema.Feature.Name); } - - public Row GetRow(Row input, Func active, out Action disposer) - { - return GetOutputRow(input, active, out disposer); - } } private sealed class FeatureContributionSchema : ISchema diff --git a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs index 0324d950ce..41dff06100 100644 --- a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs @@ -307,9 +307,9 @@ public Func GetDependencies(Func predicate) return _mapper.GetInputColumnRoles(); } - public Row GetRow(Row input, Func predicate, out Action disposer) + public Row GetRow(Row input, Func predicate) { - var innerRow = _mapper.GetRow(input, predicate, out disposer); + var innerRow = _mapper.GetRow(input, predicate); return new RowImpl(innerRow, OutputSchema); } @@ -386,38 +386,30 @@ public void GetMetadata(string kind, int col, ref TValue value) } } - private sealed class RowImpl : Row + private sealed class RowImpl : WrappingRow { - private readonly Row _row; private readonly Schema _schema; - public override long Batch => _row.Batch; - public override long Position => _row.Position; // The schema is of course the only difference from _row. public override Schema Schema => _schema; public RowImpl(Row row, Schema schema) + : base(row) { Contracts.AssertValue(row); Contracts.AssertValue(schema); - _row = row; _schema = schema; } public override bool IsColumnActive(int col) { - return _row.IsColumnActive(col); + return Input.IsColumnActive(col); } public override ValueGetter GetGetter(int col) { - return _row.GetGetter(col); - } - - public override ValueGetter GetIdGetter() - { - return _row.GetIdGetter(); + return Input.GetGetter(col); } } } diff --git a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs index 34913c5d1b..4a490f04f2 100644 --- a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs @@ -155,8 +155,9 @@ protected override Delegate[] CreateGetters(Row input, Func active, o Func predicateInput; Func predicateMapper; GetActive(bindings, active, out predicateInput, out predicateMapper); - var output = bindings.RowMapper.GetRow(input, predicateMapper, out disp); + var output = bindings.RowMapper.GetRow(input, predicateMapper); Func activeInfos = iinfo => active(bindings.MapIinfoToCol(iinfo)); + disp = output.Dispose; return GetGetters(output, activeInfos); } @@ -220,7 +221,8 @@ private sealed class Cursor : SynchronizedCursorBase private readonly BindingsBase _bindings; private readonly bool[] _active; private readonly Delegate[] _getters; - private readonly Action _disposer; + private readonly Row _output; + private bool _disposed; public override Schema Schema { get; } @@ -236,23 +238,27 @@ public Cursor(IChannelProvider provider, RowToRowScorerBase parent, RowCursor in Ch.Assert(active.Length == _bindings.ColumnCount); _active = active; - var output = _bindings.RowMapper.GetRow(input, predicateMapper, out _disposer); + _output = _bindings.RowMapper.GetRow(input, predicateMapper); try { - Ch.Assert(output.Schema == _bindings.RowMapper.OutputSchema); - _getters = parent.GetGetters(output, iinfo => active[_bindings.MapIinfoToCol(iinfo)]); + Ch.Assert(_output.Schema == _bindings.RowMapper.OutputSchema); + _getters = parent.GetGetters(_output, iinfo => active[_bindings.MapIinfoToCol(iinfo)]); } catch (Exception) { - _disposer?.Invoke(); + _output.Dispose(); throw; } } - public override void Dispose() + protected override void Dispose(bool disposing) { - _disposer?.Invoke(); - base.Dispose(); + if (_disposed) + return; + if (disposing) + _output.Dispose(); + _disposed = true; + base.Dispose(disposing); } public override bool IsColumnActive(int col) diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs index 279d2a4ad0..e0244fb3a7 100644 --- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs +++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs @@ -223,7 +223,7 @@ public Func GetDependencies(Func predicate) public Schema InputSchema => InputRoleMappedSchema.Schema; - public Row GetRow(Row input, Func predicate, out Action disposer) + public Row GetRow(Row input, Func predicate) { Contracts.AssertValue(input); Contracts.AssertValue(predicate); @@ -231,7 +231,6 @@ public Row GetRow(Row input, Func predicate, out Action disposer) var getters = new Delegate[1]; if (predicate(0)) getters[0] = _parent.GetPredictionGetter(input, InputRoleMappedSchema.Feature.Index); - disposer = null; return new SimpleRow(OutputSchema, input, getters); } } @@ -566,12 +565,11 @@ private static void EnsureCachedResultValueMapper(ValueMapper, Fl } } - public Row GetRow(Row input, Func predicate, out Action disposer) + public Row GetRow(Row input, Func predicate) { Contracts.AssertValue(input); var active = Utils.BuildArray(OutputSchema.ColumnCount, predicate); var getters = CreateGetters(input, active); - disposer = null; return new SimpleRow(OutputSchema, input, getters); } } diff --git a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs index 327c029767..e3c47fa401 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs @@ -393,18 +393,18 @@ public static IDataTransform Create(IHostEnvironment env, TaggedArguments args, return transformer.MakeDataTransform(input); } - protected override IRowMapper MakeRowMapper(Schema inputSchema) => new Mapper(this, inputSchema); + private protected override IRowMapper MakeRowMapper(Schema inputSchema) => new Mapper(this, inputSchema); /// /// Factory method for SignatureLoadDataTransform. /// - public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) => new ColumnConcatenatingTransformer(env, ctx).MakeDataTransform(input); /// /// Factory method for SignatureLoadRowMapper. /// - public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) => new ColumnConcatenatingTransformer(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); private sealed class Mapper : MapperBase, ISaveAsOnnx, ISaveAsPfa @@ -829,7 +829,7 @@ public KeyValuePair SavePfaInfo(BoundPfaContext ctx) } } - public override Func GetDependencies(Func activeOutput) + private protected override Func GetDependenciesCore(Func activeOutput) { var active = new bool[InputSchema.ColumnCount]; for (int i = 0; i < _columns.Length; i++) diff --git a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs index 4c35b845a4..9cfd65bb38 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs @@ -158,7 +158,7 @@ public override void Save(ModelSaveContext ctx) SaveColumns(ctx); } - protected override IRowMapper MakeRowMapper(Schema inputSchema) + private protected override IRowMapper MakeRowMapper(Schema inputSchema) => new Mapper(this, inputSchema, ColumnPairs); private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx diff --git a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs index cd602bad1a..1193f391e9 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs @@ -579,31 +579,23 @@ private static Schema GenerateOutputSchema(IEnumerable map, } } - private sealed class RowImpl : Row + private sealed class RowImpl : WrappingRow { private readonly Mapper _mapper; - private readonly Row _input; public RowImpl(Row input, Mapper mapper) + : base(input) { _mapper = mapper; - _input = input; } - public override long Position => _input.Position; - - public override long Batch => _input.Batch; - public override Schema Schema => _mapper.OutputSchema; public override ValueGetter GetGetter(int col) { int index = _mapper.GetInputIndex(col); - return _input.GetGetter(index); + return Input.GetGetter(index); } - public override ValueGetter GetIdGetter() - => _input.GetIdGetter(); - public override bool IsColumnActive(int col) => true; } @@ -684,9 +676,8 @@ public Func GetDependencies(Func activeOutput) return col => active[col]; } - public Row GetRow(Row input, Func active, out Action disposer) + public Row GetRow(Row input, Func active) { - disposer = null; return new RowImpl(input, _mapper); } diff --git a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs index e8a75d4edd..80e7ac13fd 100644 --- a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs @@ -434,7 +434,7 @@ private static bool AreRangesValid(int[][] slotsMin, int[][] slotsMax) return true; } - protected override IRowMapper MakeRowMapper(Schema schema) + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); private sealed class Mapper : OneToOneMapperBase diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs index 960c033dad..71c4fa3147 100644 --- a/src/Microsoft.ML.Data/Transforms/Hashing.cs +++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs @@ -319,7 +319,7 @@ private Delegate GetGetterCore(Row input, int iinfo, out Action disposer) return ComposeGetterVec(input, iinfo, srcCol, srcType); } - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); // Factory method for SignatureLoadModel. private static HashingTransformer Create(IHostEnvironment env, ModelLoadContext ctx) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs index 6553b27239..60b692ad50 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs @@ -153,7 +153,7 @@ public override void Save(ModelSaveContext ctx) SaveColumns(ctx); } - protected override IRowMapper MakeRowMapper(Schema inputSchema) => new Mapper(this, inputSchema); + private protected override IRowMapper MakeRowMapper(Schema inputSchema) => new Mapper(this, inputSchema); private sealed class Mapper : OneToOneMapperBase, ISaveAsPfa { diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs index da67c1a3b5..cd4142bbad 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs @@ -229,7 +229,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx, ISaveAsPfa { diff --git a/src/Microsoft.ML.Data/Transforms/NopTransform.cs b/src/Microsoft.ML.Data/Transforms/NopTransform.cs index 3d4219a60d..c57a60d387 100644 --- a/src/Microsoft.ML.Data/Transforms/NopTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/NopTransform.cs @@ -133,13 +133,11 @@ public Func GetDependencies(Func predicate) return predicate; } - public Row GetRow(Row input, Func active, out Action disposer) + public Row GetRow(Row input, Func active) { Contracts.CheckValue(input, nameof(input)); Contracts.CheckValue(active, nameof(active)); Contracts.CheckParam(input.Schema == Source.Schema, nameof(input), "Schema of input row must be the same as the schema the mapper is bound to"); - - disposer = null; return input; } diff --git a/src/Microsoft.ML.Data/Transforms/Normalizer.cs b/src/Microsoft.ML.Data/Transforms/Normalizer.cs index 0e20466a2e..afc597cf50 100644 --- a/src/Microsoft.ML.Data/Transforms/Normalizer.cs +++ b/src/Microsoft.ML.Data/Transforms/Normalizer.cs @@ -478,7 +478,7 @@ protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCo public new IDataTransform MakeDataTransform(IDataView input) => base.MakeDataTransform(input); - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx, ISaveAsPfa { diff --git a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs index 0a74c8a39c..47b5ef4670 100644 --- a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs +++ b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs @@ -101,7 +101,7 @@ protected OneToOneMapperBase(IHost host, OneToOneTransformerBase parent, Schema } } - public override Func GetDependencies(Func activeOutput) + private protected override Func GetDependenciesCore(Func activeOutput) { var active = new bool[InputSchema.ColumnCount]; foreach (var pair in ColMapNewToOld) diff --git a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs index 76da06d846..56adc40906 100644 --- a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs @@ -495,6 +495,7 @@ public void Fetch(int idx, ref T value) private Exception _producerTaskException; private readonly int[] _colToActivesIndex; + private bool _disposed; public override Schema Schema => _input.Schema; @@ -554,14 +555,17 @@ public Cursor(IChannelProvider provider, int poolRows, RowCursor input, Random r _producerTask = LoopProducerWorker(); } - public override void Dispose() + protected override void Dispose(bool disposing) { - if (_producerTask.Status == TaskStatus.Running) + if (_disposed) + return; + if (disposing && _producerTask.Status == TaskStatus.Running) { _toProduce.Post(0); _producerTask.Wait(); } - base.Dispose(); + _disposed = true; + base.Dispose(disposing); } public static void PostAssert(ITargetBlock target, T item) diff --git a/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs index 356a0cc546..6796736da0 100644 --- a/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs +++ b/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs @@ -33,7 +33,8 @@ public IRowToRowMapper GetRowToRowMapper(Schema inputSchema) return new RowToRowMapperTransform(Host, new EmptyDataView(Host, inputSchema), MakeRowMapper(inputSchema), MakeRowMapper); } - protected abstract IRowMapper MakeRowMapper(Schema schema); + [BestFriend] + private protected abstract IRowMapper MakeRowMapper(Schema schema); public Schema GetOutputSchema(Schema inputSchema) { @@ -67,9 +68,9 @@ protected MapperBase(IHost host, Schema inputSchema) protected abstract Schema.DetachedColumn[] GetOutputColumnsCore(); - public Schema.DetachedColumn[] GetOutputColumns() => _outputColumns.Value; + Schema.DetachedColumn[] IRowMapper.GetOutputColumns() => _outputColumns.Value; - public Delegate[] CreateGetters(Row input, Func activeOutput, out Action disposer) + Delegate[] IRowMapper.CreateGetters(Row input, Func activeOutput, out Action disposer) { // REVIEW: it used to be that the mapper's input schema in the constructor was required to be reference-equal to the schema // of the input row. @@ -100,7 +101,11 @@ public Delegate[] CreateGetters(Row input, Func activeOutput, out Act protected abstract Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer); - public abstract Func GetDependencies(Func activeOutput); + Func IRowMapper.GetDependencies(Func activeOutput) + => GetDependenciesCore(activeOutput); + + [BestFriend] + private protected abstract Func GetDependenciesCore(Func activeOutput); public abstract void Save(ModelSaveContext ctx); } diff --git a/src/Microsoft.ML.Data/Transforms/TransformBase.cs b/src/Microsoft.ML.Data/Transforms/TransformBase.cs index 4618b756f7..41ddf322e5 100644 --- a/src/Microsoft.ML.Data/Transforms/TransformBase.cs +++ b/src/Microsoft.ML.Data/Transforms/TransformBase.cs @@ -168,19 +168,16 @@ public Func GetDependencies(Func predicate) public Schema InputSchema => Source.Schema; - public Row GetRow(Row input, Func active, out Action disposer) + public Row GetRow(Row input, Func active) { Host.CheckValue(input, nameof(input)); Host.CheckValue(active, nameof(active)); Host.Check(input.Schema == Source.Schema, "Schema of input row must be the same as the schema the mapper is bound to"); - disposer = null; using (var ch = Host.Start("GetEntireRow")) { - Action disp; - var getters = CreateGetters(input, active, out disp); - disposer += disp; - return new RowImpl(input, this, OutputSchema, getters); + var getters = CreateGetters(input, active, out Action disp); + return new RowImpl(input, this, OutputSchema, getters, disp); } } @@ -188,26 +185,29 @@ public Row GetRow(Row input, Func active, out Action disposer) protected abstract int MapColumnIndex(out bool isSrc, int col); - private sealed class RowImpl : Row + private sealed class RowImpl : WrappingRow { private readonly Schema _schema; - private readonly Row _input; private readonly Delegate[] _getters; + private readonly Action _disposer; private readonly RowToRowMapperTransformBase _parent; - public override long Batch => _input.Batch; - - public override long Position => _input.Position; - public override Schema Schema => _schema; - public RowImpl(Row input, RowToRowMapperTransformBase parent, Schema schema, Delegate[] getters) + public RowImpl(Row input, RowToRowMapperTransformBase parent, Schema schema, Delegate[] getters, Action disposer) + : base(input) { - _input = input; _parent = parent; _schema = schema; _getters = getters; + _disposer = disposer; + } + + protected override void DisposeCore(bool disposing) + { + if (disposing) + _disposer?.Invoke(); } public override ValueGetter GetGetter(int col) @@ -215,7 +215,7 @@ public override ValueGetter GetGetter(int col) bool isSrc; int index = _parent.MapColumnIndex(out isSrc, col); if (isSrc) - return _input.GetGetter(index); + return Input.GetGetter(index); Contracts.Assert(_getters[index] != null); var fn = _getters[index] as ValueGetter; @@ -224,17 +224,12 @@ public override ValueGetter GetGetter(int col) return fn; } - public override ValueGetter GetIdGetter() - { - return _input.GetIdGetter(); - } - public override bool IsColumnActive(int col) { bool isSrc; int index = _parent.MapColumnIndex(out isSrc, col); if (isSrc) - return _input.IsColumnActive((index)); + return Input.IsColumnActive((index)); return _getters[index] != null; } } @@ -842,7 +837,8 @@ private sealed class Cursor : SynchronizedCursorBase private readonly bool[] _active; private readonly Delegate[] _getters; - private readonly Action[] _disposers; + private readonly Action _disposer; + private bool _disposed; public Cursor(IChannelProvider provider, OneToOneTransformBase parent, RowCursor input, bool[] active) : base(provider, input) @@ -854,30 +850,29 @@ public Cursor(IChannelProvider provider, OneToOneTransformBase parent, RowCursor _active = active; _getters = new Delegate[parent.Infos.Length]; - // Build the delegates. - List disposers = null; + // Build the disposing delegate. + Action masterDisposer = null; for (int iinfo = 0; iinfo < _getters.Length; iinfo++) { if (!IsColumnActive(parent._bindings.MapIinfoToCol(iinfo))) continue; - Action disposer; - _getters[iinfo] = parent.GetGetterCore(Ch, Input, iinfo, out disposer); + _getters[iinfo] = parent.GetGetterCore(Ch, Input, iinfo, out Action disposer); if (disposer != null) - Utils.Add(ref disposers, disposer); + masterDisposer += disposer; } - - if (Utils.Size(disposers) > 0) - _disposers = disposers.ToArray(); + _disposer = masterDisposer; } - public override void Dispose() + protected override void Dispose(bool disposing) { - if (_disposers != null) + if (_disposed) + return; + if (disposing) { - foreach (var act in _disposers) - act(); + _disposer?.Invoke(); } - base.Dispose(); + _disposed = true; + base.Dispose(disposing); } public override Schema Schema => _bindings.AsSchema; diff --git a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs index 41be7c861f..3388f15c35 100644 --- a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs +++ b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs @@ -360,7 +360,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); internal static bool GetNewType(IExceptionContext ectx, ColumnType srcType, DataKind kind, KeyRange range, out PrimitiveType itemType) { diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs index 81d8e927b4..7a5bf18995 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs @@ -713,7 +713,7 @@ public TermMap GetTermMap(int iinfo) return _unboundMaps[iinfo]; } - protected override IRowMapper MakeRowMapper(Schema schema) + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx, ISaveAsPfa diff --git a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs index dcce4f0018..3fd5ef95e8 100644 --- a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs @@ -103,12 +103,13 @@ public Func GetDependencies(Func predicate) yield break; } - public Row GetRow(Row input, Func predicate, out Action disposer) + public Row GetRow(Row input, Func predicate) { - return new SimpleRow(OutputSchema, input, new[] { CreateScoreGetter(input, predicate, out disposer) }); + var scoreGetter = CreateScoreGetter(input, predicate, out Action disposer); + return new SimpleRow(OutputSchema, input, new[] { scoreGetter }, disposer); } - public abstract Delegate CreateScoreGetter(Row input, Func mapperPredicate, out Action disposer); + internal abstract Delegate CreateScoreGetter(Row input, Func mapperPredicate, out Action disposer); } // A generic base class for pipeline ensembles. This class contains the combiner. @@ -124,7 +125,7 @@ public Bound(SchemaBindablePipelineEnsemble parent, RoleMappedSchema schema) _combiner = parent.Combiner; } - public override Delegate CreateScoreGetter(Row input, Func mapperPredicate, out Action disposer) + internal override Delegate CreateScoreGetter(Row input, Func mapperPredicate, out Action disposer) { disposer = null; @@ -137,13 +138,12 @@ public override Delegate CreateScoreGetter(Row input, Func mapperPred // First get the output row from the pipelines. The input predicate of the predictor // is the output predicate of the pipeline. var inputPredicate = Mappers[i].GetDependencies(mapperPredicate); - var pipelineRow = BoundPipelines[i].GetRow(input, inputPredicate, out Action disp); - disposer += disp; + var pipelineRow = BoundPipelines[i].GetRow(input, inputPredicate); // Next we get the output row from the predictors. We activate the score column as output predicate. - var predictorRow = Mappers[i].GetRow(pipelineRow, col => col == ScoreCols[i], out disp); - disposer += disp; + var predictorRow = Mappers[i].GetRow(pipelineRow, col => col == ScoreCols[i]); getters[i] = predictorRow.GetGetter(ScoreCols[i]); + disposer += predictorRow.Dispose; } var comb = _combiner.GetCombiner(); @@ -164,7 +164,8 @@ public ValueGetter GetLabelGetter(Row input, int i, out Action disposer) Parent.Host.Check(Mappers[i].InputRoleMappedSchema.Label != null, "Mapper was not trained using a label column"); // The label should be in the output row of the i'th pipeline - var pipelineRow = BoundPipelines[i].GetRow(input, col => col == Mappers[i].InputRoleMappedSchema.Label.Index, out disposer); + var pipelineRow = BoundPipelines[i].GetRow(input, col => col == Mappers[i].InputRoleMappedSchema.Label.Index); + disposer = pipelineRow.Dispose; return RowCursorUtils.GetLabelGetter(pipelineRow, Mappers[i].InputRoleMappedSchema.Label.Index); } @@ -174,14 +175,16 @@ public ValueGetter GetWeightGetter(Row input, int i, out Action disposer if (Mappers[i].InputRoleMappedSchema.Weight == null) { - ValueGetter weight = (ref Single dst) => dst = 1; + ValueGetter weight = (ref float dst) => dst = 1; disposer = null; return weight; } // The weight should be in the output row of the i'th pipeline if it exists. var inputPredicate = Mappers[i].GetDependencies(col => col == Mappers[i].InputRoleMappedSchema.Weight.Index); - var pipelineRow = BoundPipelines[i].GetRow(input, inputPredicate, out disposer); - return pipelineRow.GetGetter(Mappers[i].InputRoleMappedSchema.Weight.Index); + var pipelineRow = BoundPipelines[i].GetRow(input, inputPredicate); + disposer = pipelineRow.Dispose; + return pipelineRow.GetGetter(Mappers[i].InputRoleMappedSchema.Weight.Index); + } } diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index 1e9cda953f..a6b5e4ebb7 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -208,11 +208,10 @@ public BoundMapper(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper OutputSchema = Schema.Create(new SchemaImpl(ectx, owner, treeValueType, leafIdType, pathIdType)); } - public Row GetRow(Row input, Func predicate, out Action disposer) + public Row GetRow(Row input, Func predicate) { _ectx.CheckValue(input, nameof(input)); _ectx.CheckValue(predicate, nameof(predicate)); - disposer = null; return new SimpleRow(OutputSchema, input, CreateGetters(input, predicate)); } diff --git a/src/Microsoft.ML.HalLearners/VectorWhitening.cs b/src/Microsoft.ML.HalLearners/VectorWhitening.cs index c18ef3f36a..da47e4e599 100644 --- a/src/Microsoft.ML.HalLearners/VectorWhitening.cs +++ b/src/Microsoft.ML.HalLearners/VectorWhitening.cs @@ -639,7 +639,7 @@ public static extern int Svd(Layout layout, SvdJob jobu, SvdJob jobvt, int m, int n, float[] a, int lda, float[] s, float[] u, int ldu, float[] vt, int ldvt, float[] superb); } - protected override IRowMapper MakeRowMapper(Schema schema) + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); private sealed class Mapper : OneToOneMapperBase diff --git a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs index 2a117dfeda..d7b8a96c54 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs @@ -152,7 +152,7 @@ public override void Save(ModelSaveContext ctx) new float[] {0, 0, 0, 0, 1} }); - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs index fd7707ecf7..9165a7e06d 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs @@ -147,7 +147,7 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(ImageLoaderTransform).Assembly.FullName); } - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); private sealed class Mapper : OneToOneMapperBase { diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs index ef1a71c47f..6b631c8065 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs @@ -400,7 +400,7 @@ public override void Save(ModelSaveContext ctx) info.Save(ctx); } - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs index ef4318c547..8a1ba5829d 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs @@ -285,7 +285,7 @@ public override void Save(ModelSaveContext ctx) } } - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { diff --git a/src/Microsoft.ML.Legacy/LearningPipelineDebugProxy.cs b/src/Microsoft.ML.Legacy/LearningPipelineDebugProxy.cs index 5b56bc28a0..a627657f62 100644 --- a/src/Microsoft.ML.Legacy/LearningPipelineDebugProxy.cs +++ b/src/Microsoft.ML.Legacy/LearningPipelineDebugProxy.cs @@ -147,8 +147,9 @@ private IDataView ExecutePipeline() catch (Exception e) { _pipelineExecutionException = e; - var fakeColumn = new KeyValuePair("Blank", TextType.Instance); - _preview = new EmptyDataView(_environment, Schema.Create(new SimpleSchema(_environment, fakeColumn))); + var builder = new SchemaBuilder(); + builder.AddColumn("Blank", TextType.Instance); + _preview = new EmptyDataView(_environment, builder.GetSchema()); } } } diff --git a/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs b/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs index dcffd150d5..1e20ebe350 100644 --- a/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs +++ b/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs @@ -221,7 +221,7 @@ public override void Save(ModelSaveContext ctx) foreach (var colName in Outputs) ctx.SaveNonEmptyString(colName); } - protected override IRowMapper MakeRowMapper(Schema inputSchema) => new Mapper(this, inputSchema); + private protected override IRowMapper MakeRowMapper(Schema inputSchema) => new Mapper(this, inputSchema); private static int[] AdjustDimensions(OnnxShape shape) { @@ -309,7 +309,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() return info; } - public override Func GetDependencies(Func activeOutput) + private protected override Func GetDependenciesCore(Func activeOutput) { return col => Enumerable.Range(0, _parent.Outputs.Length).Any(i => activeOutput(i)) && _inputColIndices.Any(i => i == col); } diff --git a/src/Microsoft.ML.PCA/PcaTransform.cs b/src/Microsoft.ML.PCA/PcaTransform.cs index 27ab5113d2..e3fee2639f 100644 --- a/src/Microsoft.ML.PCA/PcaTransform.cs +++ b/src/Microsoft.ML.PCA/PcaTransform.cs @@ -540,7 +540,7 @@ private float[][] PostProcess(float[][] y, float[] sigma, float[] z, int d, int return y; } - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { diff --git a/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs b/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs index 55e50c5e36..4af405b5d5 100644 --- a/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs +++ b/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs @@ -358,15 +358,14 @@ private Delegate[] CreateGetter(Row input, bool[] active) return getters; } - public Row GetRow(Row input, Func predicate, out Action disposer) + public Row GetRow(Row input, Func active) { - var active = Utils.BuildArray(OutputSchema.ColumnCount, predicate); - var getters = CreateGetter(input, active); - disposer = null; + var activeArray = Utils.BuildArray(OutputSchema.ColumnCount, active); + var getters = CreateGetter(input, activeArray); return new SimpleRow(OutputSchema, input, getters); } - public ISchemaBindableMapper Bindable { get { return _parent; } } + public ISchemaBindableMapper Bindable => _parent; } } diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs index df486453b8..1fa2a323b5 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs @@ -96,7 +96,7 @@ public FieldAwareFactorizationMachineScalarRowMapper(IHostEnvironment env, RoleM } } - public Row GetRow(Row input, Func predicate, out Action action) + public Row GetRow(Row input, Func predicate) { var latentSum = new AlignedArray(_pred.FieldCount * _pred.FieldCount * _pred.LatentDimAligned, 16); var featureBuffer = new VBuffer(); @@ -111,7 +111,6 @@ public Row GetRow(Row input, Func predicate, out Action action) inputGetters[f] = input.GetGetter>(_inputColumnIndexes[f]); } - action = null; var getters = new Delegate[2]; if (predicate(0)) { diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index 016082d44d..05234bffe0 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -679,7 +679,7 @@ internal static (TFDataType[] tfOutputTypes, ColumnType[] outputTypes) GetOutput return (tfOutputTypes, outputTypes); } - protected override IRowMapper MakeRowMapper(Schema inputSchema) => new Mapper(this, inputSchema); + private protected override IRowMapper MakeRowMapper(Schema inputSchema) => new Mapper(this, inputSchema); public override void Save(ModelSaveContext ctx) { @@ -913,7 +913,7 @@ private void UpdateCacheIfNeeded(long position, ITensorValueGetter[] srcTensorGe } } - public override Func GetDependencies(Func activeOutput) + private protected override Func GetDependenciesCore(Func activeOutput) { return col => Enumerable.Range(0, _parent.Outputs.Length).Any(i => activeOutput(i)) && _inputColIndices.Any(i => i == col); } diff --git a/src/Microsoft.ML.TimeSeries/PredictionFunction.cs b/src/Microsoft.ML.TimeSeries/PredictionFunction.cs index 8dc4cc27b3..11195e8874 100644 --- a/src/Microsoft.ML.TimeSeries/PredictionFunction.cs +++ b/src/Microsoft.ML.TimeSeries/PredictionFunction.cs @@ -98,16 +98,14 @@ public TimeSeriesPredictionFunction(IHostEnvironment env, ITransformer transform { } - internal Row GetStatefulRows(Row input, IRowToRowMapper mapper, Func active, - List rows, out Action disposer) + internal Row GetStatefulRows(Row input, IRowToRowMapper mapper, Func active, List rows) { Contracts.CheckValue(input, nameof(input)); Contracts.CheckValue(active, nameof(active)); - disposer = null; IRowToRowMapper[] innerMappers = new IRowToRowMapper[0]; - if (mapper is CompositeRowToRowMapper) - innerMappers = ((CompositeRowToRowMapper)mapper).InnerMappers; + if (mapper is CompositeRowToRowMapper compositeMapper) + innerMappers = compositeMapper.InnerMappers; if (innerMappers.Length == 0) { @@ -122,10 +120,9 @@ internal Row GetStatefulRows(Row input, IRowToRowMapper mapper, Func throw Contracts.ExceptParam(nameof(input), $"Mapper required column '{input.Schema.GetColumnName(c)}' active but it was not."); } - var row = mapper.GetRow(input, active, out disposer); + var row = mapper.GetRow(input, active); if (row is StatefulRow statefulRow) rows.Add(statefulRow); - return row; } @@ -140,21 +137,10 @@ internal Row GetStatefulRows(Row input, IRowToRowMapper mapper, Func Row result = input; for (int i = 0; i < innerMappers.Length; ++i) { - Action localDisp; - result = GetStatefulRows(result, innerMappers[i], deps[i], rows, out localDisp); + result = GetStatefulRows(result, innerMappers[i], deps[i], rows); if (result is StatefulRow statefulResult) rows.Add(statefulResult); - - if (localDisp != null) - { - if (disposer == null) - disposer = localDisp; - else - disposer = localDisp + disposer; - // We want the last disposer to be called first, so the order of the addition here is important. - } } - return result; } @@ -168,13 +154,14 @@ private Action CreatePinger(List rows) return pinger; } - internal override void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow inputRow, IRowToRowMapper mapper, bool ignoreMissingColumns, + private protected override void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow inputRow, IRowToRowMapper mapper, bool ignoreMissingColumns, SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition, out Action disposer, out IRowReadableAs outputRow) { List rows = new List(); - Row outputRowLocal = outputRowLocal = GetStatefulRows(inputRow, mapper, col => true, rows, out disposer); + Row outputRowLocal = outputRowLocal = GetStatefulRows(inputRow, mapper, col => true, rows); var cursorable = TypedCursorable.Create(env, new EmptyDataView(env, mapper.OutputSchema), ignoreMissingColumns, outputSchemaDefinition); _pinger = CreatePinger(rows); + disposer = outputRowLocal.Dispose; outputRow = cursorable.GetRow(outputRowLocal); } diff --git a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs index 7af0eb1cf1..bef7a167e7 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs @@ -259,7 +259,7 @@ private protected virtual void CloneCore(StateBase state) public bool IsRowToRowMapper => false; - public TState StateRef{ get; set; } + public TState StateRef { get; set; } public int StateRefCount; @@ -478,10 +478,12 @@ public Func GetDependencies(Func predicate) return col => false; } - public Row GetRow(Row input, Func active, out Action disposer) => - new RowImpl(_bindings.Schema, input, _mapper.CreateGetters(input, active, out disposer), - _mapper.CreatePinger(input, active, out disposer)); - + public Row GetRow(Row input, Func active) + { + var getters = _mapper.CreateGetters(input, active, out Action disposer); + var pingers = _mapper.CreatePinger(input, active, out Action pingerDisposer); + return new RowImpl(_bindings.Schema, input, getters, pingers, disposer + pingerDisposer); + } } private sealed class RowImpl : StatefulRow @@ -490,6 +492,8 @@ private sealed class RowImpl : StatefulRow private readonly Row _input; private readonly Delegate[] _getters; private readonly Action _pinger; + private readonly Action _disposer; + private bool _disposed; public override Schema Schema => _schema; @@ -497,7 +501,7 @@ private sealed class RowImpl : StatefulRow public override long Batch => _input.Batch; - public RowImpl(Schema schema, Row input, Delegate[] getters, Action pinger) + public RowImpl(Schema schema, Row input, Delegate[] getters, Action pinger, Action disposer) { Contracts.CheckValue(schema, nameof(schema)); Contracts.CheckValue(input, nameof(input)); @@ -506,13 +510,22 @@ public RowImpl(Schema schema, Row input, Delegate[] getters, Action pinger _input = input; _getters = getters ?? new Delegate[0]; _pinger = pinger; + _disposer = disposer; } - public override ValueGetter GetIdGetter() + protected override void Dispose(bool disposing) { - return _input.GetIdGetter(); + if (_disposed) + return; + if (disposing) + _disposer?.Invoke(); + _disposed = true; + base.Dispose(disposing); } + public override ValueGetter GetIdGetter() + => _input.GetIdGetter(); + public override ValueGetter GetGetter(int col) { Contracts.CheckParam(0 <= col && col < _getters.Length, nameof(col), "Invalid col value in GetGetter"); @@ -745,32 +758,30 @@ public Func GetDependencies(Func predicate) Schema IRowToRowMapper.InputSchema => Source.Schema; - public Row GetRow(Row input, Func active, out Action disposer) + public Row GetRow(Row input, Func active) { Host.CheckValue(input, nameof(input)); Host.CheckValue(active, nameof(active)); Host.Check(input.Schema == Source.Schema, "Schema of input row must be the same as the schema the mapper is bound to"); - disposer = null; using (var ch = Host.Start("GetEntireRow")) { - Action disp; var activeArr = new bool[OutputSchema.ColumnCount]; for (int i = 0; i < OutputSchema.ColumnCount; i++) activeArr[i] = active(i); var pred = GetActiveOutputColumns(activeArr); - var getters = _mapper.CreateGetters(input, pred, out disp); - disposer += disp; - return new StatefulRow(input, this, OutputSchema, getters, - _mapper.CreatePinger(input, pred, out disp)); + var getters = _mapper.CreateGetters(input, pred, out Action disp); + var pingers = _mapper.CreatePinger(input, pred, out Action pingerDisp); + return new StatefulRowImpl(input, this, OutputSchema, getters, pingers, disp + pingerDisp); } } - private sealed class StatefulRow : TimeSeries.StatefulRow + private sealed class StatefulRowImpl : StatefulRow { private readonly Row _input; private readonly Delegate[] _getters; private readonly Action _pinger; + private readonly Action _disposer; private readonly TimeSeriesRowToRowMapperTransform _parent; @@ -780,14 +791,21 @@ private sealed class StatefulRow : TimeSeries.StatefulRow public override Schema Schema { get; } - public StatefulRow(Row input, TimeSeriesRowToRowMapperTransform parent, - Schema schema, Delegate[] getters, Action pinger) + public StatefulRowImpl(Row input, TimeSeriesRowToRowMapperTransform parent, + Schema schema, Delegate[] getters, Action pinger, Action disposer) { _input = input; _parent = parent; Schema = schema; _getters = getters; _pinger = pinger; + _disposer = disposer; + } + + protected override void Dispose(bool disposing) + { + if (disposing) + _disposer?.Invoke(); } public override ValueGetter GetGetter(int col) @@ -825,6 +843,7 @@ private sealed class Cursor : SynchronizedCursorBase private readonly bool[] _active; private readonly ColumnBindings _bindings; private readonly Action _disposer; + private bool _disposed; public override Schema Schema => _bindings.Schema; @@ -854,19 +873,21 @@ public override ValueGetter GetGetter(int col) Ch.AssertValue(_getters); var getter = _getters[index]; - Ch.Assert(getter != null); - var fn = getter as ValueGetter; - if (fn == null) - throw Ch.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue)); - return fn; + Ch.AssertValue(getter); + if (getter is ValueGetter fn) + return fn; + throw Ch.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue)); } - public override void Dispose() + protected override void Dispose(bool disposing) { - _disposer?.Invoke(); - base.Dispose(); + if (_disposed) + return; + if (disposing) + _disposer?.Invoke(); + _disposed = true; + base.Dispose(disposing); } } } - } diff --git a/src/Microsoft.ML.Transforms/GcnTransform.cs b/src/Microsoft.ML.Transforms/GcnTransform.cs index 5112452c13..21ad23b9b5 100644 --- a/src/Microsoft.ML.Transforms/GcnTransform.cs +++ b/src/Microsoft.ML.Transforms/GcnTransform.cs @@ -421,7 +421,7 @@ public override void Save(ModelSaveContext ctx) col.Save(ctx); } - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, Schema.Create(schema)); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, Schema.Create(schema)); private sealed class Mapper : OneToOneMapperBase { diff --git a/src/Microsoft.ML.Transforms/GroupTransform.cs b/src/Microsoft.ML.Transforms/GroupTransform.cs index 7020eb6ede..f700b35590 100644 --- a/src/Microsoft.ML.Transforms/GroupTransform.cs +++ b/src/Microsoft.ML.Transforms/GroupTransform.cs @@ -552,7 +552,7 @@ public override void ReadValue(int position) private readonly GroupKeyColumnChecker[] _groupCheckers; private readonly KeepColumnAggregator[] _aggregators; - public override long Batch { get { return 0; } } + public override long Batch => 0; public override Schema Schema => _parent.OutputSchema; @@ -662,11 +662,19 @@ private bool IsSameGroup() return result; } - public override void Dispose() + private bool _disposed; + + protected override void Dispose(bool disposing) { - _leadingCursor.Dispose(); - _trailingCursor.Dispose(); - base.Dispose(); + if (_disposed) + return; + if (disposing) + { + _leadingCursor.Dispose(); + _trailingCursor.Dispose(); + } + _disposed = true; + base.Dispose(disposing); } public override ValueGetter GetGetter(int col) diff --git a/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs b/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs index 6159637da5..27456da815 100644 --- a/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs +++ b/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs @@ -162,7 +162,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); private sealed class Mapper : OneToOneMapperBase { diff --git a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs index c509881d31..64dae82ef7 100644 --- a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs @@ -140,7 +140,7 @@ public override void Save(ModelSaveContext ctx) SaveColumns(ctx); } - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); private sealed class Mapper : OneToOneMapperBase { diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs index 971aae846c..8273467e7e 100644 --- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs @@ -139,7 +139,7 @@ public override void Save(ModelSaveContext ctx) SaveColumns(ctx); } - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); private sealed class Mapper : OneToOneMapperBase { diff --git a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs index 7659d53c32..2f6ebef311 100644 --- a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs +++ b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs @@ -493,7 +493,7 @@ public static IDataTransform Create(IHostEnvironment env, IDataView input, param } // Factory method for SignatureLoadModel. - public static MissingValueReplacingTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + private static MissingValueReplacingTransformer Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); var host = env.Register(LoadName); @@ -505,11 +505,11 @@ public static MissingValueReplacingTransformer Create(IHostEnvironment env, Mode } // Factory method for SignatureLoadDataTransform. - public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) => Create(env, ctx).MakeDataTransform(input); // Factory method for SignatureLoadRowMapper. - public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); private VBuffer CreateVBuffer(T[] array) @@ -558,7 +558,7 @@ public override void Save(ModelSaveContext ctx) } } - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx { diff --git a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs index 8f4e2b6221..52d1f9d735 100644 --- a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs +++ b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs @@ -498,7 +498,7 @@ public override void Save(ModelSaveContext ctx) _transformInfos[i].Save(ctx, string.Format("MatrixGenerator{0}", i)); } - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); private sealed class Mapper : OneToOneMapperBase { diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index dd0738e50e..08ff810c2b 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -1068,10 +1068,8 @@ private static List>> Train(IHostEnvironment env, I return columnMappings; } - protected override IRowMapper MakeRowMapper(Schema schema) - { - return new Mapper(this, schema); - } + private protected override IRowMapper MakeRowMapper(Schema schema) + => new Mapper(this, schema); } /// diff --git a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs index 5351e3f7f4..170141c823 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs @@ -540,7 +540,7 @@ public override void Save(ModelSaveContext ctx) } } - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); private sealed class Mapper : OneToOneMapperBase { diff --git a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs index a4e0cee7e6..a9459df027 100644 --- a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs @@ -307,7 +307,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, Schema.Create(schema)); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, Schema.Create(schema)); private void CheckResources() { @@ -491,7 +491,7 @@ private bool ResourceExists(StopWordsRemovingEstimator.Language lang) (_resourcesExist[langVal] ?? (_resourcesExist[langVal] = GetResourceFileStreamOrNull(lang) != null).Value); } - public override Func GetDependencies(Func activeOutput) + private protected override Func GetDependenciesCore(Func activeOutput) { var active = new bool[InputSchema.ColumnCount]; foreach (var pair in _colMapNewToOld) @@ -966,7 +966,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, Schema.Create(schema)); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, Schema.Create(schema)); private sealed class Mapper : OneToOneMapperBase { diff --git a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs index 2a7fa68e4a..52ab910488 100644 --- a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs +++ b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs @@ -193,7 +193,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); private sealed class Mapper : OneToOneMapperBase { diff --git a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs index 949c4474bb..d267c6706a 100644 --- a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs +++ b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs @@ -183,7 +183,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); private sealed class Mapper : OneToOneMapperBase { diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs index b1c6c73fe5..f83c0c26d4 100644 --- a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs +++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs @@ -318,7 +318,7 @@ public override void Save(ModelSaveContext ctx) ctx.Writer.Write((uint)_modelKind); } - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { diff --git a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs index 1805e9c317..18f652ad6d 100644 --- a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs +++ b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs @@ -221,7 +221,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); - protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); + private protected override IRowMapper MakeRowMapper(Schema schema) => new Mapper(this, schema); private sealed class Mapper : OneToOneMapperBase, ISaveAsPfa { diff --git a/test/Microsoft.ML.Benchmarks/HashBench.cs b/test/Microsoft.ML.Benchmarks/HashBench.cs index 4a518be564..fe823a220c 100644 --- a/test/Microsoft.ML.Benchmarks/HashBench.cs +++ b/test/Microsoft.ML.Benchmarks/HashBench.cs @@ -79,7 +79,7 @@ private void InitMap(T val, ColumnType type, int hashBits = 20, ValueGetter c == outCol, out var _); + var outRow = mapper.GetRow(_inRow, c => c == outCol); if (type is VectorType) _vecGetter = outRow.GetGetter>(outCol); else diff --git a/test/Microsoft.ML.Tests/Transformers/HashTests.cs b/test/Microsoft.ML.Tests/Transformers/HashTests.cs index e5845db090..b328cb60fc 100644 --- a/test/Microsoft.ML.Tests/Transformers/HashTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/HashTests.cs @@ -147,7 +147,7 @@ private void HashTestCore(T val, PrimitiveType type, uint expected, uint expe var xf = new HashingTransformer(Env, new[] { info }); var mapper = xf.GetRowToRowMapper(inRow.Schema); mapper.OutputSchema.TryGetColumnIndex("Bar", out int outCol); - var outRow = mapper.GetRow(inRow, c => c == outCol, out var _); + var outRow = mapper.GetRow(inRow, c => c == outCol); var getter = outRow.GetGetter(outCol); uint result = 0; @@ -159,7 +159,7 @@ private void HashTestCore(T val, PrimitiveType type, uint expected, uint expe xf = new HashingTransformer(Env, new[] { info }); mapper = xf.GetRowToRowMapper(inRow.Schema); mapper.OutputSchema.TryGetColumnIndex("Bar", out outCol); - outRow = mapper.GetRow(inRow, c => c == outCol, out var _); + outRow = mapper.GetRow(inRow, c => c == outCol); getter = outRow.GetGetter(outCol); getter(ref result); @@ -177,7 +177,7 @@ private void HashTestCore(T val, PrimitiveType type, uint expected, uint expe xf = new HashingTransformer(Env, new[] { info }); mapper = xf.GetRowToRowMapper(inRow.Schema); mapper.OutputSchema.TryGetColumnIndex("Bar", out outCol); - outRow = mapper.GetRow(inRow, c => c == outCol, out var _); + outRow = mapper.GetRow(inRow, c => c == outCol); var vecGetter = outRow.GetGetter>(outCol); VBuffer vecResult = default; @@ -192,7 +192,7 @@ private void HashTestCore(T val, PrimitiveType type, uint expected, uint expe xf = new HashingTransformer(Env, new[] { info }); mapper = xf.GetRowToRowMapper(inRow.Schema); mapper.OutputSchema.TryGetColumnIndex("Bar", out outCol); - outRow = mapper.GetRow(inRow, c => c == outCol, out var _); + outRow = mapper.GetRow(inRow, c => c == outCol); vecGetter = outRow.GetGetter>(outCol); vecGetter(ref vecResult); @@ -211,7 +211,7 @@ private void HashTestCore(T val, PrimitiveType type, uint expected, uint expe xf = new HashingTransformer(Env, new[] { info }); mapper = xf.GetRowToRowMapper(inRow.Schema); mapper.OutputSchema.TryGetColumnIndex("Bar", out outCol); - outRow = mapper.GetRow(inRow, c => c == outCol, out var _); + outRow = mapper.GetRow(inRow, c => c == outCol); vecGetter = outRow.GetGetter>(outCol); vecGetter(ref vecResult); @@ -224,7 +224,7 @@ private void HashTestCore(T val, PrimitiveType type, uint expected, uint expe xf = new HashingTransformer(Env, new[] { info }); mapper = xf.GetRowToRowMapper(inRow.Schema); mapper.OutputSchema.TryGetColumnIndex("Bar", out outCol); - outRow = mapper.GetRow(inRow, c => c == outCol, out var _); + outRow = mapper.GetRow(inRow, c => c == outCol); vecGetter = outRow.GetGetter>(outCol); vecGetter(ref vecResult);