Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 105 additions & 118 deletions src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
using System;
using System.Linq;

[assembly: LoadableClass(typeof(ChooseColumnsByIndexTransform), typeof(ChooseColumnsByIndexTransform.Arguments), typeof(SignatureDataTransform),
"", "ChooseColumnsByIndexTransform", "ChooseColumnsByIndex")]
Expand All @@ -31,158 +30,146 @@ public sealed class Arguments
public bool Drop;
}

private sealed class Bindings : ISchema
private sealed class Bindings
{
public readonly int[] Sources;

private readonly Schema _input;
private readonly Dictionary<string, int> _nameToIndex;

// The following argument is used only to inform serialization.
private readonly int[] _dropped;

public Schema AsSchema { get; }

public Bindings(Arguments args, Schema schemaInput)
/// <summary>
/// A collection of source column indexes after removing those we want to drop. Specifically, j=_sources[i] means
/// that the i-th output column in the output schema is the j-th column in the input schema.
/// </summary>
private readonly int[] _sources;

/// <summary>
/// Input schema of this transform. It's useful when determining column dependencies and other
/// relations between input and output schemas.
/// </summary>
private readonly Schema _sourceSchema;

/// <summary>
/// Some column indexes in the input schema. <see cref="_sources"/> is computed from <see cref="_selectedColumnIndexes"/>
/// and <see cref="_drop"/>.
/// </summary>
private readonly int[] _selectedColumnIndexes;

/// <summary>
/// True, if this transform drops selected columns indexed by <see cref="_selectedColumnIndexes"/>.
/// </summary>
private readonly bool _drop;

// This transform's output schema.
internal Schema OutputSchema { get; }

internal Bindings(Arguments args, Schema sourceSchema)
{
Contracts.AssertValue(args);
Contracts.AssertValue(schemaInput);
Contracts.AssertValue(sourceSchema);

_sourceSchema = sourceSchema;

_input = schemaInput;
// Store user-specified arguments as the major state of this transform. Only the major states will
// be saved and all other attributes can be reconstructed from them.
_drop = args.Drop;
_selectedColumnIndexes = args.Index;

int[] indexCopy = args.Index == null ? new int[0] : args.Index.ToArray();
BuildNameDict(indexCopy, args.Drop, out Sources, out _dropped, out _nameToIndex, user: true);
// Compute actually used attributes in runtime from those major states.
ComputeSources(_drop, _selectedColumnIndexes, _sourceSchema, out _sources);

AsSchema = Schema.Create(this);
// All necessary fields in this class are set, so we can compute output schema now.
OutputSchema = ComputeOutputSchema();
}

private void BuildNameDict(int[] indexCopy, bool drop, out int[] sources, out int[] dropped, out Dictionary<string, int> nameToCol, bool user)
/// <summary>
/// Common method of computing <see cref="_sources"/> from necessary parameters. This function is used in constructors.
/// </summary>
private static void ComputeSources(bool drop, int[] selectedColumnIndexes, Schema sourceSchema, out int[] sources)
{
Contracts.AssertValue(indexCopy);
foreach (int col in indexCopy)
{
if (col < 0 || _input.ColumnCount <= col)
{
const string fmt = "Column index {0} invalid for input with {1} columns";
if (user)
throw Contracts.ExceptUserArg(nameof(Arguments.Index), fmt, col, _input.ColumnCount);
else
throw Contracts.ExceptDecode(fmt, col, _input.ColumnCount);
}
}
// Compute the mapping, <see cref="_sources"/>, from output column index to input column index.
if (drop)
{
sources = Enumerable.Range(0, _input.ColumnCount).Except(indexCopy).ToArray();
dropped = indexCopy;
}
else
{
sources = indexCopy;
dropped = null;
}
if (user)
Contracts.CheckUserArg(sources.Length > 0, nameof(Arguments.Index), "Choose columns by index has no output columns");
// Drop columns indexed by args.Index
sources = Enumerable.Range(0, sourceSchema.ColumnCount).Except(selectedColumnIndexes).ToArray();
else
Contracts.CheckDecode(sources.Length > 0, "Choose columns by index has no output columns");
nameToCol = new Dictionary<string, int>();
for (int c = 0; c < sources.Length; ++c)
nameToCol[_input.GetColumnName(sources[c])] = c;
}

public Bindings(ModelLoadContext ctx, Schema schemaInput)
{
Contracts.AssertValue(ctx);
Contracts.AssertValue(schemaInput);

_input = schemaInput;

// *** Binary format ***
// bool(as byte): whether the indicated source columns are columns to keep, or drop
// int: number of source column indices
// int[]: source column indices
// Keep columns indexed by args.Index
sources = selectedColumnIndexes;

bool isDrop = ctx.Reader.ReadBoolByte();
BuildNameDict(ctx.Reader.ReadIntArray() ?? new int[0], isDrop, out Sources, out _dropped, out _nameToIndex, user: false);
AsSchema = Schema.Create(this);
// Make sure the output of this transform is meaningful.
Contracts.Check(sources.Length > 0, "Choose columns by index has no output column.");
}

public void Save(ModelSaveContext ctx)
/// <summary>
/// After <see cref="_sourceSchema"/> and <see cref="_sources"/> are set, pick up selected columns from <see cref="_sourceSchema"/> to create <see cref="OutputSchema"/>
/// Note that <see cref="_sources"/> tells us what columns in <see cref="_sourceSchema"/> are put into <see cref="OutputSchema"/>.
/// </summary>
private Schema ComputeOutputSchema()
{
Contracts.AssertValue(ctx);
var schemaBuilder = new SchemaBuilder();
for (int i = 0; i < _sources.Length; ++i)
{
// selectedIndex is an column index of input schema. Note that the input column indexed by _sources[i] in _sourceSchema is sent
// to the i-th column in the output schema.
var selectedIndex = _sources[i];

// *** Binary format ***
// bool(as byte): whether the indicated columns are columns to keep, or drop
// int: number of source column indices
// int[]: source column indices
// The dropped/kept columns are determined by user-specified arguments, so we throw if a bad configuration is provided.
string fmt = string.Format("Column index {0} invalid for input with {1} columns", selectedIndex, _sourceSchema.ColumnCount);
Contracts.Check(selectedIndex < _sourceSchema.ColumnCount, fmt);

ctx.Writer.WriteBoolByte(_dropped != null);
ctx.Writer.WriteIntArray(_dropped ?? Sources);
// Copy the selected column into output schema.
var selectedColumn = _sourceSchema[selectedIndex];
schemaBuilder.AddColumn(selectedColumn.Name, selectedColumn.Type, selectedColumn.Metadata);
}
return schemaBuilder.GetSchema();
}

public int ColumnCount
internal Bindings(ModelLoadContext ctx, Schema sourceSchema)
{
get { return Sources.Length; }
}
Contracts.AssertValue(ctx);
Contracts.AssertValue(sourceSchema);

public bool TryGetColumnIndex(string name, out int col)
{
Contracts.CheckValueOrNull(name);
if (name == null)
{
col = default(int);
return false;
}
return _nameToIndex.TryGetValue(name, out col);
}
_sourceSchema = sourceSchema;

public string GetColumnName(int col)
{
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
return _input.GetColumnName(Sources[col]);
}
// *** Binary format ***
// bool (as byte): operation mode
// int[]: selected source column indices
_drop = ctx.Reader.ReadBoolByte();
_selectedColumnIndexes = ctx.Reader.ReadIntArray();

public ColumnType GetColumnType(int col)
{
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
return _input.GetColumnType(Sources[col]);
}
// Compute actually used attributes in runtime from those major states.
ComputeSources(_drop, _selectedColumnIndexes, _sourceSchema, out _sources);

public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
{
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
return _input.GetMetadataTypes(Sources[col]);
_sourceSchema = sourceSchema;
OutputSchema = ComputeOutputSchema();
}

public ColumnType GetMetadataTypeOrNull(string kind, int col)
internal void Save(ModelSaveContext ctx)
{
Contracts.CheckNonEmpty(kind, nameof(kind));
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
return _input.GetMetadataTypeOrNull(kind, Sources[col]);
}
Contracts.AssertValue(ctx);

public void GetMetadata<TValue>(string kind, int col, ref TValue value)
{
Contracts.CheckNonEmpty(kind, nameof(kind));
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
_input.GetMetadata(kind, Sources[col], ref value);
// *** Binary format ***
// bool (as byte): operation mode
// int[]: selected source column indices
ctx.Writer.WriteBoolByte(_drop);
ctx.Writer.WriteIntArray(_selectedColumnIndexes);
}

internal bool[] GetActive(Func<int, bool> predicate)
{
return Utils.BuildArray(ColumnCount, predicate);
return Utils.BuildArray(OutputSchema.ColumnCount, predicate);
}

internal Func<int, bool> GetDependencies(Func<int, bool> predicate)
{
Contracts.AssertValue(predicate);
var active = new bool[_input.ColumnCount];
for (int i = 0; i < Sources.Length; i++)
var active = new bool[_sourceSchema.ColumnCount];
for (int i = 0; i < _sources.Length; i++)
{
if (predicate(i))
active[Sources[i]] = true;
active[_sources[i]] = true;
}
return col => 0 <= col && col < active.Length && active[col];
}

/// <summary>
/// Given the column index in the output schema, this function returns its source column's index in the input schema.
/// </summary>
internal int GetSourceColumnIndex(int outputColumnIndex) => _sources[outputColumnIndex];
}

public const string LoaderSignature = "ChooseColumnsIdxTrans";
Expand Down Expand Up @@ -245,7 +232,7 @@ public override void Save(ModelSaveContext ctx)
_bindings.Save(ctx);
}

public override Schema OutputSchema => _bindings.AsSchema;
public override Schema OutputSchema => _bindings.OutputSchema;

protected override bool? ShouldUseParallelCursors(Func<int, bool> predicate)
{
Expand Down Expand Up @@ -292,25 +279,25 @@ public Cursor(IChannelProvider provider, Bindings bindings, RowCursor input, boo
: base(provider, input)
{
Ch.AssertValue(bindings);
Ch.Assert(active == null || active.Length == bindings.ColumnCount);
Ch.Assert(active == null || active.Length == bindings.OutputSchema.ColumnCount);

_bindings = bindings;
_active = active;
}

public override Schema Schema => _bindings.AsSchema;
public override Schema Schema => _bindings.OutputSchema;

public override bool IsColumnActive(int col)
{
Ch.Check(0 <= col && col < _bindings.ColumnCount);
Ch.Check(0 <= col && col < _bindings.OutputSchema.ColumnCount);
return _active == null || _active[col];
}

public override ValueGetter<TValue> GetGetter<TValue>(int col)
{
Ch.Check(IsColumnActive(col));

var src = _bindings.Sources[col];
var src = _bindings.GetSourceColumnIndex(col);
return Input.GetGetter<TValue>(src);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#@ TextLoader{
#@ header+
#@ sep=tab
#@ col=Name:TX:0
#@ col=Label:R4:1
#@ }
Name Label
25 0
38 0
28 1
44 1
18 0
34 0
29 0
63 1
24 0
55 0
65 1
36 0
26 0
58 0
48 1
43 1
20 0
43 0
37 0
40 1
Wrote 20 rows of length 2
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#@ TextLoader{
#@ header+
#@ sep=tab
#@ col=Name:TX:0
#@ col=Label:R4:1
#@ }
Name Label
25 0
38 0
28 1
44 1
18 0
34 0
29 0
63 1
24 0
55 0
65 1
36 0
26 0
58 0
48 1
43 1
20 0
43 0
37 0
40 1
Wrote 20 rows of length 2
Loading