diff --git a/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs b/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs
index b0fa32250b..7abfa534e4 100644
--- a/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs
+++ b/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs
@@ -2,15 +2,14 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
-using System;
-using System.Collections.Generic;
-using System.Linq;
 using Microsoft.ML.Data;
 using Microsoft.ML.Runtime;
 using Microsoft.ML.Runtime.CommandLine;
 using Microsoft.ML.Runtime.Data;
 using Microsoft.ML.Runtime.Internal.Utilities;
 using Microsoft.ML.Runtime.Model;
+using System;
+using System.Linq;
 
 [assembly: LoadableClass(typeof(ChooseColumnsByIndexTransform), typeof(ChooseColumnsByIndexTransform.Arguments), typeof(SignatureDataTransform),
     "", "ChooseColumnsByIndexTransform", "ChooseColumnsByIndex")]
@@ -31,158 +30,164 @@ public sealed class Arguments
             public bool Drop;
         }
 
-        private sealed class Bindings : ISchema
+        private sealed class Bindings
         {
-            public readonly int[] Sources;
-
-            private readonly Schema _input;
-            private readonly Dictionary<string, int> _nameToIndex;
-
-            // The following argument is used only to inform serialization.
-            private readonly int[] _dropped;
-
-            public Schema AsSchema { get; }
-
-            public Bindings(Arguments args, Schema schemaInput)
+            /// <summary>
+            /// A collection of source column indexes after removing those we want to drop. Specifically, j=_sources[i] means
+            /// that the i-th output column in the output schema is the j-th column in the input schema.
+            /// </summary>
+            private readonly int[] _sources;
+
+            /// <summary>
+            /// Input schema of this transform. It's useful when determining column dependencies and other
+            /// relations between input and output schemas.
+            /// </summary>
+            private readonly Schema _sourceSchema;
+
+            /// <summary>
+            /// Some column indexes in the input schema. <see cref="_sources"/> is computed from <see cref="_selectedColumnIndexes"/>
+            /// and <see cref="_drop"/>.
+            /// </summary>
+            private readonly int[] _selectedColumnIndexes;
+
+            /// <summary>
+            /// True, if this transform drops selected columns indexed by <see cref="_selectedColumnIndexes"/>.
+            /// </summary>
+            private readonly bool _drop;
+
+            // This transform's output schema.
+            internal Schema OutputSchema { get; }
+
+            internal Bindings(Arguments args, Schema sourceSchema)
             {
                 Contracts.AssertValue(args);
-                Contracts.AssertValue(schemaInput);
+                Contracts.AssertValue(sourceSchema);
+
+                _sourceSchema = sourceSchema;
 
-                _input = schemaInput;
+                // Store user-specified arguments as the major state of this transform. Only the major states will
+                // be saved and all other attributes can be reconstructed from them. The index array is copied so that
+                // later writes to the caller's array cannot change this transform; a null array is treated as empty.
+                _drop = args.Drop;
+                _selectedColumnIndexes = args.Index == null ? new int[0] : args.Index.ToArray();
 
-                int[] indexCopy = args.Index == null ? new int[0] : args.Index.ToArray();
-                BuildNameDict(indexCopy, args.Drop, out Sources, out _dropped, out _nameToIndex, user: true);
+                // Compute actually used attributes in runtime from those major states.
+                ComputeSources(_drop, _selectedColumnIndexes, _sourceSchema, out _sources, user: true);
 
-                AsSchema = Schema.Create(this);
+                // All necessary fields in this class are set, so we can compute output schema now.
+                OutputSchema = ComputeOutputSchema();
             }
 
-            private void BuildNameDict(int[] indexCopy, bool drop, out int[] sources, out int[] dropped, out Dictionary<string, int> nameToCol, bool user)
+            /// <summary>
+            /// Common method of computing <see cref="_sources"/> from necessary parameters. This function is used in constructors.
+            /// Throws if a selected index does not address a column of <paramref name="sourceSchema"/> or if no output column is left.
+            /// <paramref name="user"/> selects whether a bad index is reported as a user-argument error or as a model decoding error.
+            /// </summary>
+            private static void ComputeSources(bool drop, int[] selectedColumnIndexes, Schema sourceSchema, out int[] sources, bool user)
             {
-                Contracts.AssertValue(indexCopy);
-                foreach (int col in indexCopy)
-                {
-                    if (col < 0 || _input.ColumnCount <= col)
-                    {
-                        const string fmt = "Column index {0} invalid for input with {1} columns";
-                        if (user)
-                            throw Contracts.ExceptUserArg(nameof(Arguments.Index), fmt, col, _input.ColumnCount);
-                        else
-                            throw Contracts.ExceptDecode(fmt, col, _input.ColumnCount);
-                    }
-                }
+                Contracts.AssertValue(selectedColumnIndexes);
+
+                // Reject invalid indexes up front. Without this check, drop mode would silently ignore them (Except
+                // just finds nothing to remove) and keep mode would pass negative indexes on to the schema lookup.
+                foreach (int col in selectedColumnIndexes)
+                {
+                    if (col < 0 || sourceSchema.ColumnCount <= col)
+                    {
+                        const string fmt = "Column index {0} invalid for input with {1} columns";
+                        if (user)
+                            throw Contracts.ExceptUserArg(nameof(Arguments.Index), fmt, col, sourceSchema.ColumnCount);
+                        else
+                            throw Contracts.ExceptDecode(fmt, col, sourceSchema.ColumnCount);
+                    }
+                }
+
+                // Compute the mapping, sources, from output column index to input column index.
                 if (drop)
-                {
-                    sources = Enumerable.Range(0, _input.ColumnCount).Except(indexCopy).ToArray();
-                    dropped = indexCopy;
-                }
-                else
-                {
-                    sources = indexCopy;
-                    dropped = null;
-                }
-                if (user)
-                    Contracts.CheckUserArg(sources.Length > 0, nameof(Arguments.Index), "Choose columns by index has no output columns");
+                    // Drop columns indexed by args.Index
+                    sources = Enumerable.Range(0, sourceSchema.ColumnCount).Except(selectedColumnIndexes).ToArray();
                 else
-                    Contracts.CheckDecode(sources.Length > 0, "Choose columns by index has no output columns");
-                nameToCol = new Dictionary<string, int>();
-                for (int c = 0; c < sources.Length; ++c)
-                    nameToCol[_input.GetColumnName(sources[c])] = c;
-            }
-
-            public Bindings(ModelLoadContext ctx, Schema schemaInput)
-            {
-                Contracts.AssertValue(ctx);
-                Contracts.AssertValue(schemaInput);
-
-                _input = schemaInput;
-
-                // *** Binary format ***
-                // bool(as byte): whether the indicated source columns are columns to keep, or drop
-                // int: number of source column indices
-                // int[]: source column indices
+                    // Keep columns indexed by args.Index
+                    sources = selectedColumnIndexes;
 
-                bool isDrop = ctx.Reader.ReadBoolByte();
-                BuildNameDict(ctx.Reader.ReadIntArray() ?? new int[0], isDrop, out Sources, out _dropped, out _nameToIndex, user: false);
-                AsSchema = Schema.Create(this);
+                // Make sure the output of this transform is meaningful.
+                Contracts.Check(sources.Length > 0, "Choose columns by index has no output column.");
             }
 
-            public void Save(ModelSaveContext ctx)
+            /// <summary>
+            /// After <see cref="_sourceSchema"/> and <see cref="_sources"/> are set, pick up selected columns from <see cref="_sourceSchema"/> to create <see cref="OutputSchema"/>.
+            /// Note that <see cref="_sources"/> tells us what columns in <see cref="_sourceSchema"/> are put into <see cref="OutputSchema"/>.
+            /// </summary>
+            private Schema ComputeOutputSchema()
             {
-                Contracts.AssertValue(ctx);
+                var schemaBuilder = new SchemaBuilder();
+                for (int i = 0; i < _sources.Length; ++i)
+                {
+                    // selectedIndex is a column index of input schema. Note that the input column indexed by _sources[i] in _sourceSchema is sent
+                    // to the i-th column in the output schema.
+                    var selectedIndex = _sources[i];
 
-                // *** Binary format ***
-                // bool(as byte): whether the indicated columns are columns to keep, or drop
-                // int: number of source column indices
-                // int[]: source column indices
+                    // No range check is needed here: ComputeSources has already rejected every index that
+                    // does not address a column of _sourceSchema, and it is the only writer of _sources.
 
-                ctx.Writer.WriteBoolByte(_dropped != null);
-                ctx.Writer.WriteIntArray(_dropped ?? Sources);
+                    // Copy the selected column into output schema.
+                    var selectedColumn = _sourceSchema[selectedIndex];
+                    schemaBuilder.AddColumn(selectedColumn.Name, selectedColumn.Type, selectedColumn.Metadata);
+                }
+                return schemaBuilder.GetSchema();
             }
 
-            public int ColumnCount
+            internal Bindings(ModelLoadContext ctx, Schema sourceSchema)
             {
-                get { return Sources.Length; }
-            }
+                Contracts.AssertValue(ctx);
+                Contracts.AssertValue(sourceSchema);
 
-            public bool TryGetColumnIndex(string name, out int col)
-            {
-                Contracts.CheckValueOrNull(name);
-                if (name == null)
-                {
-                    col = default(int);
-                    return false;
-                }
-                return _nameToIndex.TryGetValue(name, out col);
-            }
+                _sourceSchema = sourceSchema;
 
-            public string GetColumnName(int col)
-            {
-                Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
-                return _input.GetColumnName(Sources[col]);
-            }
+                // *** Binary format ***
+                // bool (as byte): operation mode
+                // int[]: selected source column indices
+                _drop = ctx.Reader.ReadBoolByte();
+                _selectedColumnIndexes = ctx.Reader.ReadIntArray() ?? new int[0];
 
-            public ColumnType GetColumnType(int col)
-            {
-                Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
-                return _input.GetColumnType(Sources[col]);
-            }
+                // Compute actually used attributes in runtime from those major states.
+                ComputeSources(_drop, _selectedColumnIndexes, _sourceSchema, out _sources, user: false);
 
-            public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
-            {
-                Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
-                return _input.GetMetadataTypes(Sources[col]);
+                // All necessary fields in this class are set, so we can compute output schema now.
+                OutputSchema = ComputeOutputSchema();
             }
 
-            public ColumnType GetMetadataTypeOrNull(string kind, int col)
+            internal void Save(ModelSaveContext ctx)
             {
-                Contracts.CheckNonEmpty(kind, nameof(kind));
-                Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
-                return _input.GetMetadataTypeOrNull(kind, Sources[col]);
-            }
+                Contracts.AssertValue(ctx);
 
-            public void GetMetadata<TValue>(string kind, int col, ref TValue value)
-            {
-                Contracts.CheckNonEmpty(kind, nameof(kind));
-                Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
-                _input.GetMetadata(kind, Sources[col], ref value);
+                // *** Binary format ***
+                // bool (as byte): operation mode
+                // int[]: selected source column indices
+                ctx.Writer.WriteBoolByte(_drop);
+                ctx.Writer.WriteIntArray(_selectedColumnIndexes);
             }
 
             internal bool[] GetActive(Func<int, bool> predicate)
             {
-                return Utils.BuildArray(ColumnCount, predicate);
+                return Utils.BuildArray(OutputSchema.ColumnCount, predicate);
             }
 
             internal Func<int, bool> GetDependencies(Func<int, bool> predicate)
             {
                 Contracts.AssertValue(predicate);
-                var active = new bool[_input.ColumnCount];
-                for (int i = 0; i < Sources.Length; i++)
+                var active = new bool[_sourceSchema.ColumnCount];
+                for (int i = 0; i < _sources.Length; i++)
                 {
                     if (predicate(i))
-                        active[Sources[i]] = true;
+                        active[_sources[i]] = true;
                 }
                 return col => 0 <= col && col < active.Length && active[col];
             }
+
+            /// <summary>
+            /// Given the column index in the output schema, this function returns its source column's index in the input schema.
+            /// </summary>
+            internal int GetSourceColumnIndex(int outputColumnIndex) => _sources[outputColumnIndex];
         }
 
         public const string LoaderSignature = "ChooseColumnsIdxTrans";
@@ -245,7 +250,7 @@ public override void Save(ModelSaveContext ctx)
             _bindings.Save(ctx);
         }
 
-        public override Schema OutputSchema => _bindings.AsSchema;
+        public override Schema OutputSchema => _bindings.OutputSchema;
 
         protected override bool? ShouldUseParallelCursors(Func<int, bool> predicate)
         {
@@ -292,17 +297,17 @@ public Cursor(IChannelProvider provider, Bindings bindings, RowCursor input, boo
                 : base(provider, input)
             {
                 Ch.AssertValue(bindings);
-                Ch.Assert(active == null || active.Length == bindings.ColumnCount);
+                Ch.Assert(active == null || active.Length == bindings.OutputSchema.ColumnCount);
 
                 _bindings = bindings;
                 _active = active;
             }
 
-            public override Schema Schema => _bindings.AsSchema;
+            public override Schema Schema => _bindings.OutputSchema;
 
             public override bool IsColumnActive(int col)
             {
-                Ch.Check(0 <= col && col < _bindings.ColumnCount);
+                Ch.Check(0 <= col && col < _bindings.OutputSchema.ColumnCount);
                 return _active == null || _active[col];
             }
 
@@ -310,7 +315,7 @@ public override ValueGetter<TValue> GetGetter<TValue>(int col)
             {
                 Ch.Check(IsColumnActive(col));
 
-                var src = _bindings.Sources[col];
+                var src = _bindings.GetSourceColumnIndex(col);
                 return Input.GetGetter<TValue>(src);
             }
         }
diff --git a/test/BaselineOutput/Common/Command/SavePipeChooseColumnsByIndex-1-out.txt b/test/BaselineOutput/Common/Command/SavePipeChooseColumnsByIndex-1-out.txt
new file mode 100644
index 0000000000..6229f5397f
--- /dev/null
+++ b/test/BaselineOutput/Common/Command/SavePipeChooseColumnsByIndex-1-out.txt
@@ -0,0 +1,28 @@
+#@ TextLoader{
+#@ header+
+#@ sep=tab
+#@ col=Name:TX:0
+#@ col=Label:R4:1
+#@ }
+Name Label
+25 0
+38 0
+28 1
+44 1
+18 0
+34 0
+29 0
+63 1
+24 0
+55 0
+65 1
+36 0
+26 0
+58 0
+48 1
+43 1
+20 0
+43 0
+37 0
+40 1
+Wrote 20 rows of length 2
diff --git a/test/BaselineOutput/Common/Command/SavePipeChooseColumnsByIndex-out.txt
b/test/BaselineOutput/Common/Command/SavePipeChooseColumnsByIndex-out.txt new file mode 100644 index 0000000000..6229f5397f --- /dev/null +++ b/test/BaselineOutput/Common/Command/SavePipeChooseColumnsByIndex-out.txt @@ -0,0 +1,28 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=Name:TX:0 +#@ col=Label:R4:1 +#@ } +Name Label +25 0 +38 0 +28 1 +44 1 +18 0 +34 0 +29 0 +63 1 +24 0 +55 0 +65 1 +36 0 +26 0 +58 0 +48 1 +43 1 +20 0 +43 0 +37 0 +40 1 +Wrote 20 rows of length 2 diff --git a/test/BaselineOutput/Common/Command/SavePipeChooseColumnsByIndexDrop-1-out.txt b/test/BaselineOutput/Common/Command/SavePipeChooseColumnsByIndexDrop-1-out.txt new file mode 100644 index 0000000000..4f9cd70a5c --- /dev/null +++ b/test/BaselineOutput/Common/Command/SavePipeChooseColumnsByIndexDrop-1-out.txt @@ -0,0 +1,28 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=Cat:TX:0-7 +#@ col=Num:R4:8-13 +#@ } +Workclass education marital-status occupation relationship ethnicity sex native-country-region age fnlwgt education-num capital-gain capital-loss hours-per-week +Private 11th Never-married Machine-op-inspct Own-child Black Male United-States 25 226802 7 0 0 40 +Private HS-grad Married-civ-spouse Farming-fishing Husband White Male United-States 38 89814 9 0 0 50 +Local-gov Assoc-acdm Married-civ-spouse Protective-serv Husband White Male United-States 28 336951 12 0 0 40 +Private Some-college Married-civ-spouse Machine-op-inspct Husband Black Male United-States 44 160323 10 7688 0 40 +? Some-college Never-married ? Own-child White Female United-States 18 103497 10 0 0 30 +Private 10th Never-married Other-service Not-in-family White Male United-States 34 198693 6 0 0 30 +? HS-grad Never-married ? 
Unmarried Black Male United-States 29 227026 9 0 0 40 +Self-emp-not-inc Prof-school Married-civ-spouse Prof-specialty Husband White Male United-States 63 104626 15 3103 0 32 +Private Some-college Never-married Other-service Unmarried White Female United-States 24 369667 10 0 0 40 +Private 7th-8th Married-civ-spouse Craft-repair Husband White Male United-States 55 104996 4 0 0 10 +Private HS-grad Married-civ-spouse Machine-op-inspct Husband White Male United-States 65 184454 9 6418 0 40 +Federal-gov Bachelors Married-civ-spouse Adm-clerical Husband White Male United-States 36 212465 13 0 0 40 +Private HS-grad Never-married Adm-clerical Not-in-family White Female United-States 26 82091 9 0 0 39 +? HS-grad Married-civ-spouse ? Husband White Male United-States 58 299831 9 0 0 35 +Private HS-grad Married-civ-spouse Machine-op-inspct Husband White Male United-States 48 279724 9 3103 0 48 +Private Masters Married-civ-spouse Exec-managerial Husband White Male United-States 43 346189 14 0 0 50 +State-gov Some-college Never-married Other-service Own-child White Male United-States 20 444554 10 0 0 25 +Private HS-grad Married-civ-spouse Adm-clerical Wife White Female United-States 43 128354 9 0 0 30 +Private HS-grad Widowed Machine-op-inspct Unmarried White Female United-States 37 60548 9 0 0 20 +Private Doctorate Married-civ-spouse Prof-specialty Husband Asian-Pac-Islander Male ? 
40 85019 16 0 0 45 +Wrote 20 rows of length 14 diff --git a/test/BaselineOutput/Common/Command/SavePipeChooseColumnsByIndexDrop-out.txt b/test/BaselineOutput/Common/Command/SavePipeChooseColumnsByIndexDrop-out.txt new file mode 100644 index 0000000000..4f9cd70a5c --- /dev/null +++ b/test/BaselineOutput/Common/Command/SavePipeChooseColumnsByIndexDrop-out.txt @@ -0,0 +1,28 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=Cat:TX:0-7 +#@ col=Num:R4:8-13 +#@ } +Workclass education marital-status occupation relationship ethnicity sex native-country-region age fnlwgt education-num capital-gain capital-loss hours-per-week +Private 11th Never-married Machine-op-inspct Own-child Black Male United-States 25 226802 7 0 0 40 +Private HS-grad Married-civ-spouse Farming-fishing Husband White Male United-States 38 89814 9 0 0 50 +Local-gov Assoc-acdm Married-civ-spouse Protective-serv Husband White Male United-States 28 336951 12 0 0 40 +Private Some-college Married-civ-spouse Machine-op-inspct Husband Black Male United-States 44 160323 10 7688 0 40 +? Some-college Never-married ? Own-child White Female United-States 18 103497 10 0 0 30 +Private 10th Never-married Other-service Not-in-family White Male United-States 34 198693 6 0 0 30 +? HS-grad Never-married ? Unmarried Black Male United-States 29 227026 9 0 0 40 +Self-emp-not-inc Prof-school Married-civ-spouse Prof-specialty Husband White Male United-States 63 104626 15 3103 0 32 +Private Some-college Never-married Other-service Unmarried White Female United-States 24 369667 10 0 0 40 +Private 7th-8th Married-civ-spouse Craft-repair Husband White Male United-States 55 104996 4 0 0 10 +Private HS-grad Married-civ-spouse Machine-op-inspct Husband White Male United-States 65 184454 9 6418 0 40 +Federal-gov Bachelors Married-civ-spouse Adm-clerical Husband White Male United-States 36 212465 13 0 0 40 +Private HS-grad Never-married Adm-clerical Not-in-family White Female United-States 26 82091 9 0 0 39 +? 
HS-grad Married-civ-spouse ? Husband White Male United-States 58 299831 9 0 0 35
+Private HS-grad Married-civ-spouse Machine-op-inspct Husband White Male United-States 48 279724 9 3103 0 48
+Private Masters Married-civ-spouse Exec-managerial Husband White Male United-States 43 346189 14 0 0 50
+State-gov Some-college Never-married Other-service Own-child White Male United-States 20 444554 10 0 0 25
+Private HS-grad Married-civ-spouse Adm-clerical Wife White Female United-States 43 128354 9 0 0 30
+Private HS-grad Widowed Machine-op-inspct Unmarried White Female United-States 37 60548 9 0 0 20
+Private Doctorate Married-civ-spouse Prof-specialty Husband Asian-Pac-Islander Male ? 40 85019 16 0 0 45
+Wrote 20 rows of length 14
diff --git a/test/Microsoft.ML.TestFramework/TestCommandBase.cs b/test/Microsoft.ML.TestFramework/TestCommandBase.cs
index 378f98ccb0..d7545055b6 100644
--- a/test/Microsoft.ML.TestFramework/TestCommandBase.cs
+++ b/test/Microsoft.ML.TestFramework/TestCommandBase.cs
@@ -2093,5 +2093,45 @@ public void Datatypes()
             TestCore("savedata", intermediateData.Path, "loader=binary", "saver=text", textOutputPath.Arg("dout"));
             Done();
         }
+
+        // Keep mode: the transform outputs only the loader columns at indexes 3 and 0, in that order.
+        [TestCategory("DataPipeSerialization")]
+        [Fact]
+        public void SavePipeChooseColumnsByIndex()
+        {
+            const string loaderSpec = "loader=text{header+ col=Label:0 col=Cat:TX:1-8 col=Num:9-14 col=Name:TX:9}";
+            string inputFile = GetDataPath("adult.tiny.with-schema.txt");
+
+            OutputPath savedModel = ModelPath();
+            string transformSpec = "xf=ChooseColumnsByIndex{ind=3 ind=0}";
+            // Step 1: apply the transform directly on top of the text loader.
+            TestCore("showdata", inputFile, loaderSpec, transformSpec);
+
+            _step++;
+
+            // Step 2: run showdata again, this time pointing in= at the model path instead of passing loader arguments.
+            TestCore("showdata", inputFile, string.Format("in={{{0}}}", savedModel.Path), "");
+            Done();
+        }
+
+        // Drop mode: the transform outputs every loader column except those at indexes 3 and 0.
+        [TestCategory("DataPipeSerialization")]
+        [Fact]
+        public void SavePipeChooseColumnsByIndexDrop()
+        {
+            const string loaderSpec = "loader=text{header+ col=Label:0 col=Cat:TX:1-8 col=Num:9-14 col=Name:TX:9}";
+            string inputFile = GetDataPath("adult.tiny.with-schema.txt");
+
+            OutputPath savedModel = ModelPath();
+            string transformSpec = "xf=ChooseColumnsByIndex{ind=3 ind=0 drop+}";
+            // Step 1: apply the transform directly on top of the text loader.
+            TestCore("showdata", inputFile, loaderSpec, transformSpec);
+
+            _step++;
+
+            // Step 2: run showdata again, this time pointing in= at the model path instead of passing loader arguments.
+            TestCore("showdata", inputFile, string.Format("in={{{0}}}", savedModel.Path), "");
+            Done();
+        }
     }
 }