Skip to content

Commit

Permalink
addressing comments from round 6
Browse files Browse the repository at this point in the history
  • Loading branch information
sfilipi committed Feb 13, 2019
1 parent b4c897c commit 8da2253
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 71 deletions.
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs
Expand Up @@ -78,7 +78,7 @@ public Row GetRow(Row input, Func<int, bool> active)
{
var outputColumns = InnerMappers[i].OutputSchema.Where(c => deps[i](c.Index));
var cols = InnerMappers[i].GetDependencies(outputColumns).ToArray();
deps[i - 1] = c => cols.Count() > 0 ? cols.Any(col => col.Index == c) : false;
deps[i - 1] = c => cols.Length > 0 ? cols.Any(col => col.Index == c) : false;
}

Row result = input;
Expand Down
132 changes: 66 additions & 66 deletions src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
Expand Up @@ -28,7 +28,7 @@ namespace Microsoft.ML.Data
/// </summary>
[BestFriend]
internal interface IRowMapper : ICanSaveModel
{
{
/// <summary>
/// Returns the input columns needed for the requested output columns.
/// </summary>
Expand All @@ -55,7 +55,7 @@ internal interface IRowMapper : ICanSaveModel
/// Returns the parent transformer which uses this mapper.
/// </summary>
ITransformer GetTransformer();
}
}

public delegate void SignatureLoadRowMapper(ModelLoadContext ctx, Schema schema);

Expand All @@ -66,7 +66,7 @@ internal interface IRowMapper : ICanSaveModel
/// </summary>
public sealed class RowToRowMapperTransform : RowToRowTransformBase, IRowToRowMapper,
ITransformCanSaveOnnx, ITransformCanSavePfa, ITransformTemplate
{
{
private readonly IRowMapper _mapper;
private readonly ColumnBindings _bindings;

Expand All @@ -76,15 +76,15 @@ public sealed class RowToRowMapperTransform : RowToRowTransformBase, IRowToRowMa
public const string RegistrationName = "RowToRowMapperTransform";
public const string LoaderSignature = "RowToRowMapper";
/// <summary>
/// Builds the version descriptor used when saving and loading this transform:
/// model signature "ROW MPPR", with the written, readable and read-back versions
/// all set to 0x00010001, and the loader identified by <see cref="LoaderSignature"/>
/// and the full name of the assembly that declares this type.
/// </summary>
private static VersionInfo GetVersionInfo() =>
    new VersionInfo(
        modelSignature: "ROW MPPR",
        verWrittenCur: 0x00010001, // Initial
        verReadableCur: 0x00010001,
        verWeCanReadBack: 0x00010001,
        loaderSignature: LoaderSignature,
        loaderAssemblyName: typeof(RowToRowMapperTransform).Assembly.FullName);

public override Schema OutputSchema => _bindings.Schema;

Expand All @@ -95,44 +95,44 @@ private static VersionInfo GetVersionInfo()
[BestFriend]
internal RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper mapper, Func<Schema, IRowMapper> mapperFactory)
    : base(env, RegistrationName, input)
{
    // The mapper is mandatory; the factory is optional and may be null.
    Contracts.CheckValue(mapper, nameof(mapper));
    Contracts.CheckValueOrNull(mapperFactory);

    _mapper = mapper;
    _mapperFactory = mapperFactory;

    // The bindings combine the input schema with the columns the mapper adds.
    var addedColumns = mapper.GetOutputColumns();
    _bindings = new ColumnBindings(input.Schema, addedColumns);
}

/// <summary>
/// Computes the schema produced by binding <paramref name="mapper"/>'s output columns
/// to <paramref name="inputSchema"/>, without constructing a transform.
/// </summary>
/// <param name="inputSchema">The schema the mapper is applied to; must not be null.</param>
/// <param name="mapper">The mapper whose output columns are added; must not be null.</param>
/// <returns>The schema of a <see cref="ColumnBindings"/> built from the two arguments.</returns>
[BestFriend]
internal static Schema GetOutputSchema(Schema inputSchema, IRowMapper mapper)
{
    Contracts.CheckValue(inputSchema, nameof(inputSchema));
    Contracts.CheckValue(mapper, nameof(mapper));

    var bindings = new ColumnBindings(inputSchema, mapper.GetOutputColumns());
    return bindings.Schema;
}

/// <summary>
/// Deserializing constructor: loads the mapper from the "Mapper" sub-model of
/// <paramref name="ctx"/> and rebuilds the column bindings against the schema of
/// <paramref name="input"/>.
/// </summary>
private RowToRowMapperTransform(IHost host, ModelLoadContext ctx, IDataView input)
    : base(host, input)
{
    // *** Binary format ***
    // _mapper

    ctx.LoadModel<IRowMapper, SignatureLoadRowMapper>(host, out _mapper, "Mapper", input.Schema);

    // Bindings are derived from the freshly loaded mapper's output columns.
    var addedColumns = _mapper.GetOutputColumns();
    _bindings = new ColumnBindings(input.Schema, addedColumns);
}

/// <summary>
/// Loads a <see cref="RowToRowMapperTransform"/> from a model context and applies it to
/// <paramref name="input"/>.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="ctx">The load context; must not be null and must match this transform's version info.</param>
/// <param name="input">The data view the loaded transform is bound to; must not be null.</param>
public static RowToRowMapperTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register(RegistrationName);

    // Validation order is preserved: context, version info, then input.
    host.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());
    host.CheckValue(input, nameof(input));

    return host.Apply("Loading Model", channel => new RowToRowMapperTransform(host, ctx, input));
}

private protected override void SaveModel(ModelSaveContext ctx)
{
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
Expand All @@ -141,14 +141,14 @@ private protected override void SaveModel(ModelSaveContext ctx)
// _mapper

ctx.SaveModel(_mapper, "Mapper");
}
}

/// <summary>
/// Produces the set of active columns for the data view (as a bool[] of length bindings.ColumnCount),
/// and the needed active input columns, given a predicate for the needed active output columns.
/// </summary>
private bool[] GetActive(Func<int, bool> predicate, out IEnumerable<Schema.Column> inputColumns)
{
{
int n = _bindings.Schema.Count;
var active = Utils.BuildArray(n, predicate);
Contracts.Assert(active.Length == n);
Expand All @@ -163,13 +163,13 @@ private bool[] GetActive(Func<int, bool> predicate, out IEnumerable<Schema.Colum
var predicateIn = _mapper.GetDependencies(predicateOut);

// Combine the two sets of input columns.
inputColumns = _bindings.InputSchema.Where(col => activeInput[col.Index]|| predicateIn(col.Index));
inputColumns = _bindings.InputSchema.Where(col => activeInput[col.Index] || predicateIn(col.Index));

return active;
}
}

private Func<int, bool> GetActiveOutputColumns(bool[] active)
{
{
Contracts.AssertValue(active);
Contracts.Assert(active.Length == _bindings.Schema.Count);

Expand All @@ -179,26 +179,26 @@ private bool[] GetActive(Func<int, bool> predicate, out IEnumerable<Schema.Colum
Contracts.Assert(0 <= col && col < _bindings.AddedColumnIndices.Count);
return 0 <= col && col < _bindings.AddedColumnIndices.Count && active[_bindings.AddedColumnIndices[col]];
};
}
}

/// <summary>
/// Returns <c>true</c> when <paramref name="predicate"/> selects at least one of the
/// column indices added by this transform; otherwise returns <c>null</c>.
/// </summary>
protected override bool? ShouldUseParallelCursors(Func<int, bool> predicate)
{
    Host.AssertValue(predicate, "predicate");

    bool anyAddedColumnRequested = _bindings.AddedColumnIndices.Any(predicate);
    return anyAddedColumnRequested ? true : (bool?)null;
}

/// <summary>
/// Creates a cursor over the requested output columns. The request is translated into
/// an activity array plus the set of input columns needed, and the source cursor is
/// opened on exactly those input columns.
/// </summary>
protected override RowCursor GetRowCursorCore(IEnumerable<Schema.Column> columnsNeeded, Random rand = null)
{
    var outputPredicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
    var activeColumns = GetActive(outputPredicate, out IEnumerable<Schema.Column> requiredInputs);

    var sourceCursor = Source.GetRowCursor(requiredInputs, rand);
    return new Cursor(Host, sourceCursor, this, activeColumns);
}

public override RowCursor[] GetRowCursorSet(IEnumerable<Schema.Column> columnsNeeded, int n, Random rand = null)
{
{
Host.CheckValueOrNull(rand);

var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
Expand All @@ -215,89 +215,89 @@ public override RowCursor[] GetRowCursorSet(IEnumerable<Schema.Column> columnsNe
for (int i = 0; i < inputs.Length; i++)
cursors[i] = new Cursor(Host, inputs[i], this, active);
return cursors;
}
}

/// <summary>
/// Forwards ONNX export to the mapper when it implements <see cref="ISaveAsOnnx"/>.
/// Does nothing for mappers that do not; fails the check when the mapper reports it
/// cannot be saved for the given context.
/// </summary>
void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx)
{
    Host.CheckValue(ctx, nameof(ctx));

    // Mappers without ONNX support are silently skipped.
    if (!(_mapper is ISaveAsOnnx onnxMapper))
        return;

    Host.Check(onnxMapper.CanSaveOnnx(ctx), "Cannot be saved as ONNX.");
    onnxMapper.SaveAsOnnx(ctx);
}

/// <summary>
/// Forwards PFA export to the mapper when it implements <see cref="ISaveAsPfa"/>.
/// Does nothing for mappers that do not; fails the check when the mapper reports it
/// cannot be saved as PFA.
/// </summary>
void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx)
{
    Host.CheckValue(ctx, nameof(ctx));

    // Mappers without PFA support are silently skipped.
    var pfaMapper = _mapper as ISaveAsPfa;
    if (pfaMapper == null)
        return;

    Host.Check(pfaMapper.CanSavePfa, "Cannot be saved as PFA.");
    pfaMapper.SaveAsPfa(ctx);
}

/// <summary>
/// Given a set of output columns, return the input columns that are needed to generate those output columns.
/// </summary>
/// <param name="dependingColumns">The output columns whose input dependencies are requested.</param>
/// <returns>The input columns produced by <c>GetActive</c> for the corresponding predicate.</returns>
IEnumerable<Schema.Column> IRowToRowMapper.GetDependencies(IEnumerable<Schema.Column> dependingColumns)
{
    var predicate = RowCursorUtils.FromColumnsToPredicate(dependingColumns, OutputSchema);

    // Only the out parameter is of interest here; the returned activity array is discarded.
    // GetActive is called exactly once: a second call declaring the same out variable
    // would not compile and would repeat the dependency computation.
    GetActive(predicate, out var inputColumns);
    return inputColumns;
}

/// <summary>
/// The schema of the source data view this transform is bound to.
/// </summary>
public Schema InputSchema
{
    get { return Source.Schema; }
}

/// <summary>
/// Wraps <paramref name="input"/> in a row exposing this transform's output schema.
/// </summary>
/// <param name="input">A row whose schema must equal the schema of the source this mapper is bound to.</param>
/// <param name="active">Predicate over output column indices selecting the columns to activate.</param>
/// <returns>A row backed by <paramref name="input"/> plus the getters created by the mapper.</returns>
public Row GetRow(Row input, Func<int, bool> active)
{
    Host.CheckValue(input, nameof(input));
    Host.CheckValue(active, nameof(active));
    Host.Check(input.Schema == Source.Schema, "Schema of input row must be the same as the schema the mapper is bound to");

    using (var ch = Host.Start("GetEntireRow"))
    {
        // Evaluate the caller's predicate once per output column.
        var activeColumns = new bool[OutputSchema.Count];
        for (int col = 0; col < OutputSchema.Count; col++)
        {
            activeColumns[col] = active(col);
        }

        // Translate to a predicate over the mapper's own columns, then build the getters.
        var mapperPredicate = GetActiveOutputColumns(activeColumns);
        var getters = _mapper.CreateGetters(input, mapperPredicate, out Action disposer);
        return new RowImpl(input, this, OutputSchema, getters, disposer);
    }
}

/// <summary>
/// Re-applies this transform to a different data view. When a mapper factory is
/// available, a new mapper is created for the new source's schema. Otherwise the
/// transform is saved to an in-memory repository and loaded back against
/// <paramref name="newSource"/>.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="newSource">The data view to bind to; must not be null.</param>
/// <returns>A transform equivalent to this one, bound to <paramref name="newSource"/>.</returns>
IDataTransform ITransformTemplate.ApplyToData(IHostEnvironment env, IDataView newSource)
{
    Contracts.CheckValue(env, nameof(env));

    Contracts.CheckValue(newSource, nameof(newSource));
    if (_mapperFactory != null)
    {
        var newMapper = _mapperFactory(newSource.Schema);
        return new RowToRowMapperTransform(env.Register(nameof(RowToRowMapperTransform)), newSource, newMapper, _mapperFactory);
    }

    // Revert to serialization. This was how it worked in all the cases, now it's only when we can't re-create the mapper.
    using (var stream = new MemoryStream())
    {
        // Exactly one writer is opened on the stream; it is disposed before the
        // stream is rewound so the saved model is complete when read back.
        using (var rep = RepositoryWriter.CreateNew(stream, env))
        {
            ModelSaveContext.SaveModel(rep, this, "model");
            rep.Commit();
        }

        stream.Position = 0;
        using (var rep = RepositoryReader.Open(stream, env))
        {
            IDataTransform newData;
            ModelLoadContext.LoadModel<IDataTransform, SignatureLoadDataTransform>(env,
                out newData, rep, "model", newSource);
            return newData;
        }
    }
}

private sealed class RowImpl : WrappingRow
{
{
private readonly Delegate[] _getters;
private readonly RowToRowMapperTransform _parent;
private readonly Action _disposer;
Expand All @@ -306,21 +306,21 @@ private sealed class RowImpl : WrappingRow

/// <summary>
/// Creates a row that wraps <paramref name="input"/> and stores the parent transform,
/// output schema, getters and disposer supplied by the caller.
/// </summary>
public RowImpl(Row input, RowToRowMapperTransform parent, Schema schema, Delegate[] getters, Action disposer)
    : base(input)
{
    Schema = schema;
    _parent = parent;
    _getters = getters;
    // Invoked from DisposeCore when the row is disposed.
    _disposer = disposer;
}

/// <summary>
/// Runs the disposer supplied at construction, if any, when called with
/// <paramref name="disposing"/> set to <c>true</c>.
/// </summary>
protected override void DisposeCore(bool disposing)
{
    if (!disposing)
        return;

    _disposer?.Invoke();
}

public override ValueGetter<TValue> GetGetter<TValue>(int col)
{
{
bool isSrc;
int index = _parent._bindings.MapColumnIndex(out isSrc, col);
if (isSrc)
Expand All @@ -331,20 +331,20 @@ public override ValueGetter<TValue> GetGetter<TValue>(int col)
if (fn == null)
throw Contracts.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue));
return fn;
}
}

/// <summary>
/// Reports whether output column <paramref name="col"/> is active. Columns that map
/// to the source are delegated to the wrapped input row; added columns are active
/// when a getter exists for them.
/// </summary>
public override bool IsColumnActive(int col)
{
    int mapped = _parent._bindings.MapColumnIndex(out bool fromSource, col);
    return fromSource
        ? Input.IsColumnActive(mapped)
        : _getters[mapped] != null;
}
}

private sealed class Cursor : SynchronizedCursorBase
{
{
private readonly Delegate[] _getters;
private readonly bool[] _active;
private readonly ColumnBindings _bindings;
Expand All @@ -355,21 +355,21 @@ private sealed class Cursor : SynchronizedCursorBase

/// <summary>
/// Creates a cursor over <paramref name="input"/> that exposes the parent's bindings.
/// Getters for the added columns are created by the parent's mapper, restricted to the
/// columns marked in <paramref name="active"/>.
/// </summary>
public Cursor(IChannelProvider provider, RowCursor input, RowToRowMapperTransform parent, bool[] active)
    : base(provider, input)
{
    // Translate the per-output-column activity array into a predicate over the mapper's columns.
    var mapperPredicate = parent.GetActiveOutputColumns(active);
    _getters = parent._mapper.CreateGetters(input, mapperPredicate, out _disposer);

    _bindings = parent._bindings;
    _active = active;
}

/// <summary>
/// Reports whether output column <paramref name="col"/> was requested as active.
/// The index is checked against the number of columns in the bound schema.
/// </summary>
public override bool IsColumnActive(int col)
{
    bool inRange = 0 <= col && col < _bindings.Schema.Count;
    Ch.Check(inRange);
    return _active[col];
}

public override ValueGetter<TValue> GetGetter<TValue>(int col)
{
{
Ch.Check(IsColumnActive(col));

bool isSrc;
Expand All @@ -384,22 +384,22 @@ public override ValueGetter<TValue> GetGetter<TValue>(int col)
if (fn == null)
throw Ch.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue));
return fn;
}
}

/// <summary>
/// Disposes the cursor once. On the first call, the mapper-supplied disposer is run
/// when <paramref name="disposing"/> is <c>true</c>, the cursor is marked disposed,
/// and the base implementation is invoked. Later calls do nothing.
/// </summary>
protected override void Dispose(bool disposing)
{
    if (!_disposed)
    {
        if (disposing)
        {
            _disposer?.Invoke();
        }

        // Mark before calling the base so a repeated call returns immediately.
        _disposed = true;
        base.Dispose(disposing);
    }
}
}

/// <summary>
/// Returns the transformer reported by the underlying mapper.
/// </summary>
internal ITransformer GetTransformer() => _mapper.GetTransformer();
}
}

0 comments on commit 8da2253

Please sign in to comment.