Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Api/CustomMappingTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public Mapper(CustomMappingTransformer<TSrc, TDst> parent, Schema inputSchema)
_typedSrc = TypedCursorable<TSrc>.Create(_host, emptyDataView, false, _parent.InputSchemaDefinition);
}

public Delegate[] CreateGetters(Row input, Func<int, bool> activeOutput, out Action disposer)
Delegate[] IRowMapper.CreateGetters(Row input, Func<int, bool> activeOutput, out Action disposer)
{
disposer = null;
// If no outputs are active, we short-circuit to empty array of getters.
Expand Down Expand Up @@ -158,7 +158,7 @@ private Delegate GetDstGetter<T>(Row input, int colIndex, Action refreshAction)
return combinedGetter;
}

public Func<int, bool> GetDependencies(Func<int, bool> activeOutput)
Func<int, bool> IRowMapper.GetDependencies(Func<int, bool> activeOutput)
{
if (Enumerable.Range(0, _parent.AddedSchema.Columns.Length).Any(activeOutput))
{
Expand All @@ -169,7 +169,7 @@ public Func<int, bool> GetDependencies(Func<int, bool> activeOutput)
return col => false;
}

public Schema.DetachedColumn[] GetOutputColumns()
Schema.DetachedColumn[] IRowMapper.GetOutputColumns()
{
var dstRow = new DataViewConstructionUtils.InputRow<TDst>(_host, _parent.AddedSchema);
// All the output columns of dstRow are our outputs.
Expand Down
25 changes: 15 additions & 10 deletions src/Microsoft.ML.Api/DataViewConstructionUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public static IDataView LoadPipeWithPredictor(IHostEnvironment env, Stream model
return pipe;
}

public sealed class InputRow<TRow> : InputRowBase<TRow>, IRowBackedBy<TRow>
public sealed class InputRow<TRow> : InputRowBase<TRow>
where TRow : class
{
private TRow _value;
Expand Down Expand Up @@ -416,7 +416,12 @@ public sealed class WrappedCursor : RowCursor
public override long Batch => _toWrap.Batch;
public override Schema Schema => _toWrap.Schema;

public override void Dispose() => _toWrap.Dispose();
protected override void Dispose(bool disposing)
{
if (disposing)
_toWrap.Dispose();
}

public override ValueGetter<TValue> GetGetter<TValue>(int col)
=> _toWrap.GetGetter<TValue>(col);
public override ValueGetter<UInt128> GetIdGetter() => _toWrap.GetIdGetter();
Expand All @@ -434,8 +439,8 @@ public abstract class DataViewCursorBase : InputRowBase<TRow>

protected readonly DataViewBase<TRow> DataView;
protected readonly IChannel Ch;

private long _position;

/// <summary>
/// Zero-based position of the cursor.
/// </summary>
Expand All @@ -462,14 +467,14 @@ protected DataViewCursorBase(IHostEnvironment env, DataViewBase<TRow> dataView,
/// </summary>
protected bool IsGood => State == CursorState.Good;

public virtual void Dispose()
protected sealed override void Dispose(bool disposing)
{
if (State != CursorState.Done)
{
Ch.Dispose();
_position = -1;
State = CursorState.Done;
}
if (State == CursorState.Done)
return;
Ch.Dispose();
_position = -1;
base.Dispose(disposing);
State = CursorState.Done;
}

public bool MoveNext()
Expand Down
25 changes: 20 additions & 5 deletions src/Microsoft.ML.Api/PredictionEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,15 @@ public override void Predict(TSrc example, ref TDst prediction)
/// </summary>
/// <typeparam name="TSrc">The user-defined type that holds the example.</typeparam>
/// <typeparam name="TDst">The user-defined type that holds the prediction.</typeparam>
public abstract class PredictionEngineBase<TSrc, TDst>
public abstract class PredictionEngineBase<TSrc, TDst> : IDisposable
where TSrc : class
where TDst : class, new()
{
private readonly DataViewConstructionUtils.InputRow<TSrc> _inputRow;
private readonly IRowReadableAs<TDst> _outputRow;
private readonly Action _disposer;
private bool _disposed;

[BestFriend]
private protected ITransformer Transformer { get; }

Expand Down Expand Up @@ -193,12 +195,14 @@ private protected PredictionEngineBase(IHostEnvironment env, ITransformer transf
PredictionEngineCore(env, _inputRow, makeMapper(_inputRow.Schema), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition, out _disposer, out _outputRow);
}

internal virtual void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow<TSrc> inputRow, IRowToRowMapper mapper, bool ignoreMissingColumns,
[BestFriend]
private protected virtual void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow<TSrc> inputRow, IRowToRowMapper mapper, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition, out Action disposer, out IRowReadableAs<TDst> outputRow)
{
var cursorable = TypedCursorable<TDst>.Create(env, new EmptyDataView(env, mapper.OutputSchema), ignoreMissingColumns, outputSchemaDefinition);
var outputRowLocal = mapper.GetRow(_inputRow, col => true, out disposer);
var outputRowLocal = mapper.GetRow(inputRow, col => true);
outputRow = cursorable.GetRow(outputRowLocal);
disposer = inputRow.Dispose;
}

protected virtual Func<Schema, IRowToRowMapper> TransformerChecker(IExceptionContext ectx, ITransformer transformer)
Expand All @@ -208,9 +212,20 @@ protected virtual Func<Schema, IRowToRowMapper> TransformerChecker(IExceptionCon
return transformer.GetRowToRowMapper;
}

~PredictionEngineBase()
public void Dispose()
{
Disposing(true);
GC.SuppressFinalize(this);
}

[BestFriend]
private protected void Disposing(bool disposing)
{
_disposer?.Invoke();
if (_disposed)
return;
if (disposing)
_disposer?.Invoke();
_disposed = true;
}

/// <summary>
Expand Down
32 changes: 12 additions & 20 deletions src/Microsoft.ML.Api/StatefulFilterTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ private StatefulFilterTransform(IHostEnvironment env, StatefulFilterTransform<TS
_bindings = new ColumnBindings(Schema.Create(newSource.Schema), DataViewConstructionUtils.GetSchemaColumns(_addedSchema));
}

public bool CanShuffle { get { return false; } }
public bool CanShuffle => false;

Schema IDataView.Schema => OutputSchema;

Expand Down Expand Up @@ -132,10 +132,7 @@ public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func
return new[] { GetRowCursor(predicate, rand) };
}

public IDataView Source
{
get { return _source; }
}
public IDataView Source => _source;

public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource)
{
Expand All @@ -158,10 +155,7 @@ private sealed class Cursor : RootCursorBase

private bool _disposed;

public override long Batch
{
get { return _input.Batch; }
}
public override long Batch => _input.Batch;

public Cursor(StatefulFilterTransform<TSrc, TDst, TState> parent, RowCursor<TSrc> input, Func<int, bool> predicate)
: base(parent.Host)
Expand Down Expand Up @@ -196,24 +190,22 @@ public Cursor(StatefulFilterTransform<TSrc, TDst, TState> parent, RowCursor<TSrc
_appendedRow = appendedDataView.GetRowCursor(appendedPredicate);
}

public override void Dispose()
protected override void Dispose(bool disposing)
{
if (!_disposed)
if (_disposed)
return;
if (disposing)
{
var disposableState = _state as IDisposable;
var disposableSrc = _src as IDisposable;
var disposableDst = _dst as IDisposable;
if (disposableState != null)
if (_state is IDisposable disposableState)
disposableState.Dispose();
if (disposableSrc != null)
if (_src is IDisposable disposableSrc)
disposableSrc.Dispose();
if (disposableDst != null)
if (_dst is IDisposable disposableDst)
disposableDst.Dispose();

_input.Dispose();
base.Dispose();
_disposed = true;
}
_disposed = true;
base.Dispose(disposing);
}

public override ValueGetter<UInt128> GetIdGetter()
Expand Down
79 changes: 34 additions & 45 deletions src/Microsoft.ML.Api/TypedCursor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace Microsoft.ML.Runtime.Api
/// </summary>
/// <typeparam name="TRow">The user-defined type that is being populated while cursoring.</typeparam>
[BestFriend]
internal interface IRowReadableAs<TRow>
internal interface IRowReadableAs<TRow> : IDisposable
where TRow : class
{
/// <summary>
Expand All @@ -28,22 +28,6 @@ internal interface IRowReadableAs<TRow>
void FillValues(TRow row);
}

/// <summary>
/// This interface is an <see cref="Row"/> with 'strongly typed' binding.
/// It can accept values of type <typeparamref name="TRow"/> and present the value as a row.
/// </summary>
/// <typeparam name="TRow">The user-defined type that provides the values while cursoring.</typeparam>
internal interface IRowBackedBy<TRow>
where TRow : class
{
/// <summary>
/// Accepts the fields of the user-supplied <paramref name="row"/> object and publishes the instance as a row.
/// If the row is accessed prior to any object being set, then the data accessors on the row should throw.
/// </summary>
/// <param name="row">The row object. Cannot be <c>null</c>.</param>
void ExtractValues(TRow row);
}

/// <summary>
/// This interface provides cursoring through a <see cref="IDataView"/> via a 'strongly typed' binding.
/// It can populate the user-supplied object's fields with the values of the current row.
Expand Down Expand Up @@ -253,36 +237,34 @@ public static TypedCursorable<TRow> Create(IHostEnvironment env, IDataView data,
return new TypedCursorable<TRow>(env, data, ignoreMissingColumns, outSchema);
}

private abstract class TypedRowBase
private abstract class TypedRowBase : WrappingRow
{
protected readonly IChannel Ch;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

protected readonly IChannel Ch; [](start = 12, length = 31)

Not related to PR, but why we have channel here?
All it does is just work as IExceptionContext. Why it has to be IChannel?

private readonly Row _input;
private readonly Action<TRow>[] _setters;

public long Batch => _input.Batch;

public long Position => _input.Position;

public Schema Schema => _input.Schema;
public override Schema Schema => base.Input.Schema;

public TypedRowBase(TypedCursorable<TRow> parent, Row input, string channelMessage)
: base(input)
{
Contracts.AssertValue(parent);
Contracts.AssertValue(parent._host);
Ch = parent._host.Start(channelMessage);
Ch.AssertValue(input);

_input = input;

int n = parent._pokes.Length;
Ch.Assert(n == parent._columns.Length);
Ch.Assert(n == parent._columnIndices.Length);
_setters = new Action<TRow>[n];
for (int i = 0; i < n; i++)
_setters[i] = GenerateSetter(_input, parent._columnIndices[i], parent._columns[i], parent._pokes[i], parent._peeks[i]);
_setters[i] = GenerateSetter(Input, parent._columnIndices[i], parent._columns[i], parent._pokes[i], parent._peeks[i]);
}

public ValueGetter<UInt128> GetIdGetter() => _input.GetIdGetter();
protected override void DisposeCore(bool disposing)
{
if (disposing)
Ch.Dispose();
}

private Action<TRow> GenerateSetter(Row input, int index, InternalSchemaDefinition.Column column, Delegate poke, Delegate peek)
{
Expand All @@ -292,7 +274,7 @@ private Action<TRow> GenerateSetter(Row input, int index, InternalSchemaDefiniti
Func<Row, int, Delegate, Delegate, Action<TRow>> del;
if (fieldType.IsArray)
{
Ch.Assert(colType.IsVector);
Ch.Assert(colType is VectorType);
// VBuffer<ReadOnlyMemory<char>> -> String[]
if (fieldType.GetElementType() == typeof(string))
{
Expand Down Expand Up @@ -459,14 +441,14 @@ public virtual void FillValues(TRow row)
setter(row);
}

public bool IsColumnActive(int col)
public override bool IsColumnActive(int col)
{
return _input.IsColumnActive(col);
return Input.IsColumnActive(col);
}

public ValueGetter<TValue> GetGetter<TValue>(int col)
public override ValueGetter<TValue> GetGetter<TValue>(int col)
{
return _input.GetGetter<TValue>(col);
return Input.GetGetter<TValue>(col);
}
}

Expand All @@ -481,6 +463,15 @@ public TypedRow(TypedCursorable<TRow> parent, Row input)
private sealed class RowImplementation : IRowReadableAs<TRow>
{
private readonly TypedRow _row;
private bool _disposed;

public void Dispose()
{
if (_disposed)
return;
_row.Dispose();
_disposed = true;
}

public RowImplementation(TypedRow row) => _row = row;

Expand All @@ -496,6 +487,7 @@ private sealed class RowImplementation : IRowReadableAs<TRow>
private sealed class RowCursorImplementation : RowCursor<TRow>
{
private readonly TypedCursor _cursor;
private bool _disposed;

public RowCursorImplementation(TypedCursor cursor) => _cursor = cursor;

Expand All @@ -504,7 +496,15 @@ private sealed class RowCursorImplementation : RowCursor<TRow>
public override long Batch => _cursor.Batch;
public override Schema Schema => _cursor.Schema;

public override void Dispose() { }
protected override void Dispose(bool disposing)
{
if (_disposed)
return;
if (disposing)
_cursor.Dispose();
_disposed = true;
base.Dispose(disposing);
}

public override void FillValues(TRow row) => _cursor.FillValues(row);
public override ValueGetter<TValue> GetGetter<TValue>(int col) => _cursor.GetGetter<TValue>(col);
Expand All @@ -518,7 +518,6 @@ public override void Dispose() { }
private sealed class TypedCursor : TypedRowBase
{
private readonly RowCursor _input;
private bool _disposed;

public TypedCursor(TypedCursorable<TRow> parent, RowCursor input)
: base(parent, input, "Cursor")
Expand All @@ -534,16 +533,6 @@ public override void FillValues(TRow row)

public CursorState State => _input.State;

public void Dispose()
{
if (!_disposed)
{
_input.Dispose();
Ch.Dispose();
_disposed = true;
}
}

public bool MoveNext() => _input.MoveNext();
public bool MoveMany(long count) => _input.MoveMany(count);
public RowCursor GetRootCursor() => _input.GetRootCursor();
Expand Down
Loading