diff --git a/src/Microsoft.ML.Data/DataView/ArrayDataViewBuilder.cs b/src/Microsoft.ML.Data/DataView/ArrayDataViewBuilder.cs index 9c106c2996..882b7e33cb 100644 --- a/src/Microsoft.ML.Data/DataView/ArrayDataViewBuilder.cs +++ b/src/Microsoft.ML.Data/DataView/ArrayDataViewBuilder.cs @@ -75,12 +75,18 @@ public void AddColumn(string name, PrimitiveType type, params T[] values) /// Constructs a new key column from an array where values are copied to output simply /// by being assigned. /// - public void AddColumn(string name, ValueGetter>> getKeyValues, ulong keyMin, int keyCount, params uint[] values) + /// The name of the column. + /// The delegate that does a reverse lookup based upon the given key. This is for metadata creation + /// The minimum to use. + /// The count of unique keys specified in values + /// The values to add to the column. Note that since this is creating a column, the values will be offset by 1. + public void AddColumn(string name, ValueGetter>> getKeyValues, ulong keyMin, int keyCount, params T1[] values) { _host.CheckValue(getKeyValues, nameof(getKeyValues)); _host.CheckParam(keyCount > 0, nameof(keyCount)); CheckLength(name, values); - _columns.Add(new AssignmentColumn(new KeyType(DataKind.U4, keyMin, keyCount), values)); + values.GetType().GetElementType().TryGetDataKind(out DataKind kind); + _columns.Add(new AssignmentColumn(new KeyType(kind, keyMin, keyCount), values)); _getKeyValues.Add(name, getKeyValues); _names.Add(name); } diff --git a/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs b/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs index c6b4ab5800..db6dcd3d14 100644 --- a/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs +++ b/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs @@ -4,7 +4,9 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.Conversions; +using System.Collections.Generic; namespace Microsoft.ML { @@ -125,5 +127,23 @@ public static ValueToKeyMappingEstimator MapValueToKey(this TransformsCatalog.Co string termsColumn = null, IComponentFactory loaderFactory = null) => new ValueToKeyMappingEstimator(CatalogUtils.GetEnvironment(catalog), columns, file, termsColumn, loaderFactory); + + /// + /// Maps specified keys to specified values + /// + /// The key type. + /// The value type. + /// The categorical transform's catalog + /// The list of keys to use for the mapping. The mapping is 1-1 with values. This list must be the same length as values and + /// cannot contain duplicate keys. + /// The list of values to pair with the keys for the mapping. This list must be equal to the same length as keys. + /// The columns to apply this transform on. + /// + public static ValueMappingEstimator ValueMap( + this TransformsCatalog.ConversionTransforms catalog, + IEnumerable keys, + IEnumerable values, + params (string source, string name)[] columns) + => new ValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), keys, values, columns); } } diff --git a/src/Microsoft.ML.Data/Transforms/ValueMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueMappingTransformer.cs new file mode 100644 index 0000000000..15a0ccccfc --- /dev/null +++ b/src/Microsoft.ML.Data/Transforms/ValueMappingTransformer.cs @@ -0,0 +1,975 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
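A minimal usage sketch for the ValueMap catalog extension added above (illustrative only; it assumes an MLContext named mlContext, an input IDataView named data with a text column "Label", and that the conversions catalog is exposed as mlContext.Transforms.Conversion):

// The mapping is given as two parallel lists; keys must be unique and both lists the same length.
var keys = new List<string>() { "Wirtschaft", "Gesundheit", "Sport" };
var values = new List<int>() { 1, 2, 3 };

// Maps the source column "Label" through the dictionary into a new column "LabelIndex".
var pipeline = mlContext.Transforms.Conversion.ValueMap(keys, values, ("Label", "LabelIndex"));
var transformedData = pipeline.Fit(data).Transform(data);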
+ +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Data.IO; +using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Transforms.Conversions; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; + +[assembly: LoadableClass(ValueMappingTransformer.Summary, typeof(IDataTransform), typeof(ValueMappingTransformer), + typeof(ValueMappingTransformer.Arguments), typeof(SignatureDataTransform), + ValueMappingTransformer.UserName, "ValueMapping", "ValueMappingTransformer", ValueMappingTransformer.ShortName, + "TermLookup", "Lookup", "LookupTransform", DocName = "transform/ValueMappingTransformer.md")] + +[assembly: LoadableClass(ValueMappingTransformer.Summary, typeof(IDataTransform), typeof(ValueMappingTransformer), null, typeof(SignatureLoadDataTransform), + "Value Mapping Transform", ValueMappingTransformer.LoaderSignature, ValueMappingTransformer.TermLookupLoaderSignature)] + +[assembly: LoadableClass(ValueMappingTransformer.Summary, typeof(ValueMappingTransformer), null, typeof(SignatureLoadModel), + "Value Mapping Transform", ValueMappingTransformer.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(ValueMappingTransformer), null, typeof(SignatureLoadRowMapper), + ValueMappingTransformer.UserName, ValueMappingTransformer.LoaderSignature)] + +namespace Microsoft.ML.Transforms.Conversions +{ + /// + /// The ValueMappingEstimator is a 1-1 mapping from a key to value. The key type and value type are specified + /// through TKey and TValue. Arrays are supported for vector types which can be used as either a key or a value + /// or both. The mapping is specified, not trained by providiing a list of keys and a list of values. + /// + /// Specifies the key type. + /// Specifies the value type. + public sealed class ValueMappingEstimator : TrivialEstimator> + { + private (string input, string output)[] _columns; + + /// + /// Constructs the ValueMappingEstimator, key type -> value type mapping + /// + /// The environment to use. + /// The list of keys of TKey. + /// The list of values of TValue. + /// The list of columns to apply. + public ValueMappingEstimator(IHostEnvironment env, IEnumerable keys, IEnumerable values, params (string input, string output)[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ValueMappingEstimator)), + new ValueMappingTransformer(env, keys, values, false, columns)) + { + _columns = columns; + } + + /// + /// Constructs the ValueMappingEstimator, key type -> value type mapping + /// + /// The environment to use. + /// The list of keys of TKey. + /// The list of values of TValue. + /// Specifies to treat the values as a . + /// The list of columns to apply. + public ValueMappingEstimator(IHostEnvironment env, IEnumerable keys, IEnumerable values, bool treatValuesAsKeyType, params (string input, string output)[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ValueMappingEstimator)), + new ValueMappingTransformer(env, keys, values, treatValuesAsKeyType, columns)) + { + _columns = columns; + } + + /// + /// Constructs the ValueMappingEstimator, key type -> value array type mapping + /// + /// The environment to use. + /// The list of keys of TKey. + /// The list of values of TValue[]. + /// The list of columns to apply. 
+ public ValueMappingEstimator(IHostEnvironment env, IEnumerable keys, IEnumerable values, params (string input, string output)[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ValueMappingEstimator)), + new ValueMappingTransformer(env, keys, values, columns)) + { + _columns = columns; + } + + /// + /// Retrieves the output schema given the input schema + /// + /// Input schema + /// Returns the generated output schema + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + + var resultDic = inputSchema.ToDictionary(x => x.Name); + var vectorKind = Transformer.ValueColumnType.IsVector ? SchemaShape.Column.VectorKind.Vector : SchemaShape.Column.VectorKind.Scalar; + var isKey = Transformer.ValueColumnType.IsKey; + var columnType = (isKey) ? PrimitiveType.FromKind(DataKind.U4) : + Transformer.ValueColumnType; + foreach (var (Input, Output) in _columns) + { + if (!inputSchema.TryFindColumn(Input, out var originalColumn)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Input); + + // Get the type from TOutputType + var col = new SchemaShape.Column(Output, vectorKind, columnType, isKey, originalColumn.Metadata); + resultDic[Output] = col; + } + return new SchemaShape(resultDic.Values); + } + } + + /// + /// The DataViewHelper provides a set of static functions to create a DataView given a list of keys and values. + /// + internal class DataViewHelper + { + /// + /// Helper function to retrieve the Primitie type given a Type + /// + internal static PrimitiveType GetPrimitiveType(Type rawType, out bool isVectorType) + { + Type type = rawType; + isVectorType = false; + if (type.IsArray) + { + type = rawType.GetElementType(); + isVectorType = true; + } + + if (!type.TryGetDataKind(out DataKind kind)) + throw new InvalidOperationException($"Unsupported type {type} used in mapping."); + + return PrimitiveType.FromKind(kind); + } + + /// + /// Helper function for a reverse lookup given value. This is used for generating the metadata of the value column. 
+ /// + + private static ValueGetter>> GetKeyValueGetter(TKey[] keys) + { + return + (ref VBuffer> dst) => + { + var editor = VBufferEditor.Create(ref dst, keys.Length); + for (int i = 0; i < keys.Length; i++) + editor.Values[i] = keys[i].ToString().AsMemory(); + dst = editor.Commit(); + }; + } + + /// + /// Helper function to create an IDataView given a list of key and vector-based values + /// + internal static IDataView CreateDataView(IHostEnvironment env, + IEnumerable keys, + IEnumerable values, + string keyColumnName, + string valueColumnName) + { + var keyType = GetPrimitiveType(typeof(TKey), out bool isKeyVectorType); + var valueType = GetPrimitiveType(typeof(TValue), out bool isValueVectorType); + var dataViewBuilder = new ArrayDataViewBuilder(env); + dataViewBuilder.AddColumn(keyColumnName, keyType, keys.ToArray()); + dataViewBuilder.AddColumn(valueColumnName, valueType, values.ToArray()); + return dataViewBuilder.GetDataView(); + } + + /// + /// Helper function that builds the IDataView given a list of keys and non-vector values + /// + internal static IDataView CreateDataView(IHostEnvironment env, + IEnumerable keys, + IEnumerable values, + string keyColumnName, + string valueColumnName, + bool treatValuesAsKeyTypes) + { + var keyType = GetPrimitiveType(typeof(TKey), out bool isKeyVectorType); + var valueType = GetPrimitiveType(typeof(TValue), out bool isValueVectorType); + + var dataViewBuilder = new ArrayDataViewBuilder(env); + dataViewBuilder.AddColumn(keyColumnName, keyType, keys.ToArray()); + if (treatValuesAsKeyTypes) + { + // When treating the values as KeyTypes, generate the unique + // set of values. This is used for generating the metadata of + // the column. + HashSet valueSet = new HashSet(); + HashSet keySet = new HashSet(); + for (int i = 0; i < values.Count(); ++i) + { + var v = values.ElementAt(i); + if (valueSet.Contains(v)) + continue; + valueSet.Add(v); + + var k = keys.ElementAt(i); + keySet.Add(k); + } + var metaKeys = keySet.ToArray(); + + // Key Values are treated in one of two ways: + // If the values are of type uint or ulong, these values are used directly as the keys types and no new keys are created. + // If the values are not of uint or ulong, then key values are generated as uints starting from 1, since 0 is missing key. + if (valueType.RawKind == DataKind.U4) + { + uint[] indices = values.Select((x) => Convert.ToUInt32(x)).ToArray(); + dataViewBuilder.AddColumn(valueColumnName, GetKeyValueGetter(metaKeys), 0, metaKeys.Length, indices); + } + else if (valueType.RawKind == DataKind.U8) + { + ulong[] indices = values.Select((x) => Convert.ToUInt64(x)).ToArray(); + dataViewBuilder.AddColumn(valueColumnName, GetKeyValueGetter(metaKeys), 0, metaKeys.Length, indices); + } + else + { + // When generating the indices, treat each value as being unique, i.e. two values that are the same will + // be assigned the same index. The dictionary is used to maintain uniqueness, indices will contain + // the full list of indices (equal to the same length of values). 
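// For example, given keys { "k1", "k2", "k3", "k4" } and values { "A", "B", "A", "C" }, the three distinct values are assigned key indices 1, 2 and 3,
// so indices becomes { 1, 2, 1, 3 } and metaKeys holds { "k1", "k2", "k4" }, which GetKeyValueGetter exposes as the KeyValues metadata.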
+ Dictionary keyTypeValueMapping = new Dictionary(); + uint[] indices = new uint[values.Count()]; + // Start the index at 1 + uint index = 1; + for (int i = 0; i < values.Count(); ++i) + { + TValue value = values.ElementAt(i); + if (!keyTypeValueMapping.ContainsKey(value)) + { + keyTypeValueMapping.Add(value, index); + index++; + } + + var keyValue = keyTypeValueMapping[value]; + indices[i] = keyValue; + } + + dataViewBuilder.AddColumn(valueColumnName, GetKeyValueGetter(metaKeys), 0, metaKeys.Count(), indices); + } + } + else + dataViewBuilder.AddColumn(valueColumnName, valueType, values.ToArray()); + + return dataViewBuilder.GetDataView(); + } + } + + /// + /// The ValueMappingTransformer is a 1-1 mapping from a key to value. The key type and value type are specified + /// through TKey and TValue. Arrays are supported for vector types which can be used as either a key or a value + /// or both. The mapping is specified, not trained by providiing a list of keys and a list of values. + /// + /// Specifies the key type + /// Specifies the value type + public sealed class ValueMappingTransformer : ValueMappingTransformer + { + /// + /// Constructs a ValueMappingTransformer with a key type to value type. + /// + /// The environment to use. + /// The list of keys that are TKey. + /// The list of values that are TValue. + /// Specifies to treat the values as a . + /// The specified columns to apply + public ValueMappingTransformer(IHostEnvironment env, IEnumerable keys, IEnumerable values, bool treatValuesAsKeyTypes, (string input, string output)[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ValueMappingTransformer)), + ConvertToDataView(env, keys, values, treatValuesAsKeyTypes), KeyColumnName, ValueColumnName, columns) + { } + + /// + /// Constructs a ValueMappingTransformer with a key type to value array type. + /// + /// The environment to use. + /// The list of keys that are TKey. + /// The list of values that are TValue[]. + /// The specified columns to apply. + public ValueMappingTransformer(IHostEnvironment env, IEnumerable keys, IEnumerable values, (string input, string output)[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ValueMappingTransformer)), + ConvertToDataView(env, keys, values), KeyColumnName, ValueColumnName, columns) + { } + + private static IDataView ConvertToDataView(IHostEnvironment env, IEnumerable keys, IEnumerable values, bool treatValuesAsKeyValue) + => DataViewHelper.CreateDataView(env, + keys, + values, + ValueMappingTransformer.KeyColumnName, + ValueMappingTransformer.ValueColumnName, + treatValuesAsKeyValue); + + // Handler for vector value types + private static IDataView ConvertToDataView(IHostEnvironment env, IEnumerable keys, IEnumerable values) + => DataViewHelper.CreateDataView(env, keys, values, ValueMappingTransformer.KeyColumnName, ValueMappingTransformer.ValueColumnName); + } + + public class ValueMappingTransformer : OneToOneTransformerBase + { + internal const string Summary = "Maps text values columns to new columns using a map dataset."; + internal const string LoaderSignature = "ValueMappingTransformer"; + internal const string UserName = "Value Mapping Transform"; + internal const string ShortName = "ValueMap"; + + internal const string TermLookupLoaderSignature = "TermLookupTransform"; + + // Stream names for the binary idv streams. 
+ private const string DefaultMapName = "DefaultMap.idv"; + protected static string KeyColumnName = "Key"; + protected static string ValueColumnName = "Value"; + private ValueMap _valueMap; + private Schema.Metadata _valueMetadata; + private byte[] _dataView; + + public ColumnType ValueColumnType => _valueMap.ValueType; + public Schema.Metadata ValueColumnMetadata => _valueMetadata; + + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "VALUMAPG", + verWrittenCur: 0x00010001, // Initial. + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(ValueMappingTransformer).Assembly.FullName); + } + + private static VersionInfo GetTermLookupVersionInfo() + { + return new VersionInfo( + modelSignature: "TXTLOOKT", + // verWrittenCur: 0x00010001, // Initial. + verWrittenCur: 0x00010002, // Dropped sizeof(Float). + verReadableCur: 0x00010002, + verWeCanReadBack: 0x00010002, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(ValueMappingTransformer).Assembly.FullName); + } + + public sealed class Column : OneToOneColumn + { + public static Column Parse(string str) + { + var res = new Column(); + if (res.TryParse(str)) + return res; + return null; + } + + public bool TryUnparse(StringBuilder sb) + { + Contracts.AssertValue(sb); + return TryUnparseCore(sb); + } + } + + public sealed class Arguments + { + [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)] + public Column[] Column; + + [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The data file containing the terms", ShortName = "data", SortOrder = 2)] + public string DataFile; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The name of the column containing the keys", ShortName = "keyCol, term, TermColumn")] + public string KeyColumn; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The name of the column containing the values", ShortName = "valueCol, value")] + public string ValueColumn; + + [Argument(ArgumentType.Multiple, HelpText = "The data loader", NullName = "", SignatureType = typeof(SignatureDataLoader))] + public IComponentFactory Loader; + + [Argument(ArgumentType.AtMostOnce, + HelpText = "Specifies whether the values are key values or numeric, only valid when loader is not specified and the type of data is not an idv.", + ShortName = "key")] + public bool ValuesAsKeyType = true; + } + + protected ValueMappingTransformer(IHostEnvironment env, IDataView lookupMap, + string keyColumn, string valueColumn, (string input, string output)[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ValueMappingTransformer)), columns) + { + Host.CheckNonEmpty(keyColumn, nameof(keyColumn), "A key column must be specified when passing in an IDataView for the value mapping"); + Host.CheckNonEmpty(valueColumn, nameof(valueColumn), "A value column must be specified when passing in an IDataView for the value mapping"); + _valueMap = CreateValueMapFromDataView(lookupMap, keyColumn, valueColumn); + int valueColumnIdx = 0; + Host.Assert(lookupMap.Schema.TryGetColumnIndex(valueColumn, out valueColumnIdx)); + _valueMetadata = lookupMap.Schema[valueColumnIdx].Metadata; + + // Create the byte array of the original IDataView, this is used for saving out the data. 
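// Serializing the map as a binary IDataView is what allows Save and the ModelLoadContext-based Create to round-trip the transformer without the original lookup source.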
+ _dataView = GetBytesFromDataView(Host, lookupMap, keyColumn, valueColumn); + } + + private ValueMap CreateValueMapFromDataView(IDataView dataView, string keyColumn, string valueColumn) + { + // Confirm that the key and value columns exist in the dataView + Host.Check(dataView.Schema.TryGetColumnIndex(keyColumn, out int keyIdx), "Key column " + keyColumn + " does not exist in the given dataview"); + Host.Check(dataView.Schema.TryGetColumnIndex(valueColumn, out int valueIdx), "Value column " + valueColumn + " does not exist in the given dataview"); + var keyType = dataView.Schema[keyIdx].Type; + var valueType = dataView.Schema[valueIdx].Type; + var valueMap = ValueMap.Create(keyType, valueType, _valueMetadata); + using (var cursor = dataView.GetRowCursor(c => c == keyIdx || c == valueIdx)) + valueMap.Train(Host, cursor); + return valueMap; + } + + private static TextLoader.Column GenerateValueColumn(IHostEnvironment env, + IDataView loader, + string valueColumnName, + int keyIdx, + int valueIdx, + string fileName) + { + // Scan the source to determine the min max of the column + ulong keyMin = ulong.MaxValue; + ulong keyMax = ulong.MinValue; + + // scan the input to create convert the values as key types + using (var cursor = loader.GetRowCursor(c => true)) + { + using (var ch = env.Start($"Processing key values from file {fileName}")) + { + var getKey = cursor.GetGetter>(keyIdx); + var getValue = cursor.GetGetter>(valueIdx); + int countNonKeys = 0; + + ReadOnlyMemory key = default; + ReadOnlyMemory value = default; + while (cursor.MoveNext()) + { + getKey(ref key); + getValue(ref value); + + ulong res; + // Try to parse the text as a key value between 1 and ulong.MaxValue. If this succeeds and res>0, + // we update max and min accordingly. If res==0 it means the value is missing, in which case we ignore it for + // computing max and min. + if (Microsoft.ML.Runtime.Data.Conversion.Conversions.Instance.TryParseKey(in value, 1, ulong.MaxValue, out res)) + { + if (res < keyMin && res != 0) + keyMin = res; + if (res > keyMax) + keyMax = res; + } + // If parsing as key did not succeed, the value can still be 0, so we try parsing it as a ulong. If it succeeds, + // then the value is 0, and we update min accordingly. + else if (Microsoft.ML.Runtime.Data.Conversion.Conversions.Instance.TryParse(in value, out res)) + { + keyMin = 0; + } + //If parsing as a ulong fails, we increment the counter for the non-key values. 
+ else + { + if (countNonKeys < 5) + ch.Warning($"Key '{key}' in mapping file is mapped to non key value '{value}'"); + countNonKeys++; + } + } + + if (countNonKeys > 0) + ch.Warning($"Found {countNonKeys} non key values in the file '{fileName}'"); + if (keyMin > keyMax) + { + keyMin = 0; + keyMax = uint.MaxValue - 1; + ch.Warning($"Did not find any valid key values in the file '{fileName}'"); + } + else + ch.Info($"Found key values in the range {keyMin} to {keyMax} in the file '{fileName}'"); + } + } + + TextLoader.Column valueColumn = new TextLoader.Column(valueColumnName, DataKind.U4, 1); + if (keyMax - keyMin < (ulong)int.MaxValue) + { + valueColumn.KeyRange = new KeyRange(keyMin, keyMax); + } + else if (keyMax - keyMin < (ulong)uint.MaxValue) + { + valueColumn.KeyRange = new KeyRange(keyMin); + } + else + { + valueColumn.Type = DataKind.U8; + valueColumn.KeyRange = new KeyRange(keyMin); + } + + return valueColumn; + } + + private static ValueMappingTransformer CreateTransformInvoke(IHostEnvironment env, + IDataView idv, + string keyColumnName, + string valueColumnName, + bool treatValuesAsKeyTypes, + (string input, string output)[] columns) + { + // Read in the data + // Scan the input to convert the values as key types + List keys = new List(); + List values = new List(); + + idv.Schema.TryGetColumnIndex(keyColumnName, out int keyIdx); + idv.Schema.TryGetColumnIndex(valueColumnName, out int valueIdx); + using (var cursor = idv.GetRowCursor(c => true)) + { + using (var ch = env.Start("Processing key values")) + { + TKey key = default; + TValue value = default; + var getKey = cursor.GetGetter(keyIdx); + var getValue = cursor.GetGetter(valueIdx); + while (cursor.MoveNext()) + { + try + { + getKey(ref key); + } + catch (InvalidOperationException) + { + ch.Warning("Invalid key parsed, row will be skipped."); + continue; + } + + try + { + getValue(ref value); + } + catch (InvalidOperationException) + { + ch.Warning($"Invalid value parsed for key {key}, row will be skipped."); + continue; + } + + keys.Add(key); + values.Add(value); + } + } + } + + return new ValueMappingTransformer(env, keys, values, treatValuesAsKeyTypes, columns); + } + + private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.Assert(!string.IsNullOrWhiteSpace(args.DataFile)); + env.CheckValueOrNull(args.KeyColumn); + env.CheckValueOrNull(args.ValueColumn); + + var keyColumnName = (string.IsNullOrEmpty(args.KeyColumn)) ? KeyColumnName : args.KeyColumn; + var valueColumnName = (string.IsNullOrEmpty(args.ValueColumn)) ? ValueColumnName : args.ValueColumn; + + IMultiStreamSource fileSource = new MultiFileSource(args.DataFile); + IDataView loader; + if (args.Loader != null) + { + loader = args.Loader.CreateComponent(env, fileSource); + } + else + { + var extension = Path.GetExtension(args.DataFile); + if (extension.Equals(".idv", StringComparison.OrdinalIgnoreCase)) + loader = new BinaryLoader(env, new BinaryLoader.Arguments(), fileSource); + else if (extension.Equals(".tdv")) + loader = new TransposeLoader(env, new TransposeLoader.Arguments(), fileSource); + else + { + // The user has not specified how to load this file. This will attempt to load the + // data file as two text columns.
If the user has also specified ValuesAsKeyTypes, + // this will default to the key column as a text column and the value column as a uint column + + // Set the keyColumnName and valueColumnName to the default values. + keyColumnName = KeyColumnName; + valueColumnName = ValueColumnName; + TextLoader.Column keyColumn = default; + TextLoader.Column valueColumn = default; + + // Default to a text loader. KeyType and ValueType are assumed to be string + // types unless ValuesAsKeyType is specified. + if (args.ValuesAsKeyType) + { + keyColumn = new TextLoader.Column(keyColumnName, DataKind.TXT, 0); + valueColumn = new TextLoader.Column(valueColumnName, DataKind.TXT, 1); + var txtArgs = new TextLoader.Arguments() + { + Column = new TextLoader.Column[] + { + keyColumn, + valueColumn + } + }; + + try + { + var textLoader = TextLoader.ReadFile(env, txtArgs, fileSource); + valueColumn = GenerateValueColumn(env, textLoader, valueColumnName, 0, 1, args.DataFile); + } + catch (Exception ex) + { + throw env.Except(ex, $"Failed to parse the lookup file '{args.DataFile}' in ValueMappingTransformer"); + } + } + else + { + keyColumn = new TextLoader.Column(keyColumnName, DataKind.TXT, 0); + valueColumn = new TextLoader.Column(valueColumnName, DataKind.R4, 1); + } + + loader = TextLoader.Create( + env, + new TextLoader.Arguments() + { + Column = new TextLoader.Column[] + { + keyColumn, + valueColumn + } + }, + fileSource); + } + } + + env.AssertValue(loader); + env.Assert(loader.Schema.TryGetColumnIndex(keyColumnName, out int keyColumnIndex)); + env.Assert(loader.Schema.TryGetColumnIndex(valueColumnName, out int valueColumnIndex)); + + ValueMappingTransformer transformer = null; + (string Source, string Name)[] columns = args.Column.Select(x => (x.Source, x.Name)).ToArray(); + transformer = new ValueMappingTransformer(env, loader, keyColumnName, valueColumnName, columns); + return transformer.MakeDataTransform(input); + } + + /// + /// Helper function to determine the model version that is being loaded. + /// + private static bool CheckModelVersion(ModelLoadContext ctx, VersionInfo versionInfo) + { + try + { + ctx.CheckVersionInfo(versionInfo); + return true; + } + catch (Exception) + { + //consume + return false; + } + } + + protected static ValueMappingTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + + // Check for the TermLookup model signature first, for backwards compatibility + var termLookupModel = CheckModelVersion(ctx, GetTermLookupVersionInfo()); + env.Check(termLookupModel || CheckModelVersion(ctx, GetVersionInfo())); + + // *** Binary format *** + // int: number of added columns + // for each added column + // string: output column name + // string: input column name + // Binary stream of mapping + + var length = ctx.Reader.ReadInt32(); + var columns = new (string Source, string Name)[length]; + for (int i = 0; i < length; i++) + { + columns[i].Name = ctx.LoadNonEmptyString(); + columns[i].Source = ctx.LoadNonEmptyString(); + } + + byte[] rgb = null; + Action fn = r => rgb = ReadAllBytes(env, r); + + if (!ctx.TryLoadBinaryStream(DefaultMapName, fn)) + throw env.ExceptDecode(); + + var binaryLoader = GetLoader(env, rgb); + var keyColumnName = (termLookupModel) ?
"Term" : KeyColumnName; + return new ValueMappingTransformer(env, binaryLoader, keyColumnName, ValueColumnName, columns); + } + + private static byte[] ReadAllBytes(IExceptionContext ectx, BinaryReader rdr) + { + Contracts.AssertValue(ectx); + ectx.AssertValue(rdr); + ectx.Assert(rdr.BaseStream.CanSeek); + + long size = rdr.BaseStream.Length; + ectx.CheckDecode(size <= int.MaxValue); + + var rgb = new byte[(int)size]; + int cb = rdr.Read(rgb, 0, rgb.Length); + ectx.CheckDecode(cb == rgb.Length); + + return rgb; + } + + protected static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + => Create(env, ctx).MakeDataTransform(input); + + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); + + protected static PrimitiveType GetPrimitiveType(Type rawType, out bool isVectorType) + { + Type type = rawType; + isVectorType = false; + if (type.IsArray) + { + type = rawType.GetElementType(); + isVectorType = true; + } + + if (!type.TryGetDataKind(out DataKind kind)) + throw Contracts.Except($"Unsupported type {type} used in mapping."); + + return PrimitiveType.FromKind(kind); + } + + public override void Save(ModelSaveContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + ctx.SetVersionInfo(GetVersionInfo()); + SaveColumns(ctx); + + // Save out the byte stream of the IDataView of the data source + ctx.SaveBinaryStream(DefaultMapName, w => w.Write(_dataView)); + } + + /// + /// Base class that contains the mapping of keys to values. + /// + private abstract class ValueMap + { + public readonly ColumnType KeyType; + public readonly ColumnType ValueType; + + public ValueMap(ColumnType keyType, ColumnType valueType) + { + KeyType = keyType; + ValueType = valueType; + } + + public static ValueMap Create(ColumnType keyType, ColumnType valueType, Schema.Metadata valueMetadata) + { + Func del = CreateValueMapInvoke; + var meth = del.Method.GetGenericMethodDefinition().MakeGenericMethod(keyType.RawType, valueType.RawType); + return (ValueMap)meth.Invoke(null, new object[] { keyType, valueType, valueMetadata }); + } + + private static ValueMap CreateValueMapInvoke(ColumnType keyType, + ColumnType valueType, + Schema.Metadata valueMetadata) + => new ValueMap(keyType, valueType, valueMetadata); + + public abstract void Train(IHostEnvironment env, RowCursor cursor); + + public abstract Delegate GetGetter(Row input, int index); + + public abstract IDataView GetDataView(IHostEnvironment env); + } + + /// + /// Implementation mapping class that maps a key of TKey to a specified value of TValue. + /// + private class ValueMap : ValueMap + { + private Dictionary _mapping; + private TValue _missingValue; + private Schema.Metadata _valueMetadata; + + private Dictionary CreateDictionary() + { + if (typeof(TKey) == typeof(ReadOnlyMemory)) + return new Dictionary, TValue>(new ReadOnlyMemoryUtils.ReadonlyMemoryCharComparer()) as Dictionary; + return new Dictionary(); + } + + public ValueMap(ColumnType keyType, ColumnType valueType, Schema.Metadata valueMetadata) + : base(keyType, valueType) + { + _mapping = CreateDictionary(); + _valueMetadata = valueMetadata; + } + + /// + /// Generates the mapping based on the IDataView + /// + public override void Train(IHostEnvironment env, RowCursor cursor) + { + // Validate that the conversion is supported for non-vector types + bool identity; + ValueMapper, TValue> conv; + + // For keys that are not in the mapping, the missingValue will be returned. 
+ _missingValue = default; + if (!ValueType.IsVector) + { + // For handling missing values, this follows how a missing value is handled when loading from a text source. + // First check if there is a String->ValueType conversion method. If so, call the conversion method with an + // empty string, the returned value will be the new missing value. + // NOTE this will return NA for R4 and R8 types. + if (Microsoft.ML.Runtime.Data.Conversion.Conversions.Instance.TryGetStandardConversion, TValue>( + TextType.Instance, + ValueType, + out conv, + out identity)) + { + TValue value = default; + conv(string.Empty.AsMemory(), ref value); + _missingValue = value; + } + } + + var keyGetter = cursor.GetGetter(0); + var valueGetter = cursor.GetGetter(1); + while (cursor.MoveNext()) + { + TKey key = default; + TValue value = default; + keyGetter(ref key); + valueGetter(ref value); + if (_mapping.ContainsKey(key)) + throw env.Except($"Duplicate keys in data '{key}'"); + + _mapping.Add(key, value); + } + } + + public override Delegate GetGetter(Row input, int index) + { + var src = default(TKey); + ValueGetter getSrc = input.GetGetter(index); + ValueGetter retVal = + (ref TValue dst) => + { + getSrc(ref src); + if (_mapping.ContainsKey(src)) + { + if (ValueType.IsVector) + dst = Utils.MarshalInvoke(GetVector, ValueType.ItemType.RawType, _mapping[src]); + else + dst = Utils.MarshalInvoke(GetValue, ValueType.RawType, _mapping[src]); + } + else + dst = _missingValue; + }; + return retVal; + } + + public override IDataView GetDataView(IHostEnvironment env) + => DataViewHelper.CreateDataView(env, + _mapping.Keys, + _mapping.Values, + ValueMappingTransformer.KeyColumnName, + ValueMappingTransformer.ValueColumnName, + ValueType.IsKey); + + private static TValue GetVector(TValue value) + { + if (value is VBuffer valueRef) + { + VBuffer dest = default; + valueRef.CopyTo(ref dest); + if (dest is TValue destRef) + return destRef; + } + + return default; + } + + private static TValue GetValue(TValue value) => value; + } + + /// + /// Retrieves the byte array given a dataview and columns + /// + private static byte[] GetBytesFromDataView(IHost host, IDataView lookup, string keyColumn, string valueColumn) + { + Contracts.AssertValue(host); + host.AssertValue(lookup); + host.AssertNonEmpty(keyColumn); + host.AssertNonEmpty(valueColumn); + + var schema = lookup.Schema; + + if (!schema.GetColumnOrNull(keyColumn).HasValue) + throw host.ExceptUserArg(nameof(Arguments.KeyColumn), $"Key column not found: '{keyColumn}'"); + if (!schema.GetColumnOrNull(valueColumn).HasValue) + throw host.ExceptUserArg(nameof(Arguments.ValueColumn), $"Value column not found: '{valueColumn}'"); + + var cols = new List<(string Source, string Name)>() + { + (keyColumn, KeyColumnName), + (valueColumn, ValueColumnName) + }; + + var view = new ColumnCopyingTransformer(host, cols.ToArray()).Transform(lookup); + view = ColumnSelectingTransformer.CreateKeep(host, view, cols.Select(x => x.Name).ToArray()); + + var saver = new BinarySaver(host, new BinarySaver.Arguments()); + using (var strm = new MemoryStream()) + { + saver.SaveData(strm, view, 0, 1); + return strm.ToArray(); + } + } + + private static BinaryLoader GetLoader(IHostEnvironment env, byte[] bytes) + { + env.AssertValue(env); + env.AssertValue(bytes); + + var strm = new MemoryStream(bytes, writable: false); + return new BinaryLoader(env, new BinaryLoader.Arguments(), strm); + } + + private protected override IRowMapper MakeRowMapper(Schema schema) + { + return new Mapper(this, schema, _valueMap, 
_valueMetadata, ColumnPairs); + } + + private sealed class Mapper : OneToOneMapperBase + { + private readonly Schema _inputSchema; + private readonly ValueMap _valueMap; + private readonly Schema.Metadata _valueMetadata; + private readonly (string Source, string Name)[] _columns; + private readonly ValueMappingTransformer _parent; + + internal Mapper(ValueMappingTransformer transform, + Schema inputSchema, + ValueMap valueMap, + Schema.Metadata valueMetadata, + (string input, string output)[] columns) + : base(transform.Host.Register(nameof(Mapper)), transform, inputSchema) + { + _inputSchema = inputSchema; + _valueMetadata = valueMetadata; + _valueMap = valueMap; + _columns = columns; + _parent = transform; + } + + protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) + { + Host.AssertValue(input); + Host.Assert(0 <= iinfo && iinfo < _columns.Length); + disposer = null; + + return _valueMap.GetGetter(input, ColMapNewToOld[iinfo]); + } + + protected override Schema.DetachedColumn[] GetOutputColumnsCore() + { + var result = new Schema.DetachedColumn[_columns.Length]; + for (int i = 0; i < _columns.Length; i++) + { + var srcCol = _inputSchema[_columns[i].Source]; + result[i] = new Schema.DetachedColumn(_columns[i].Name, _valueMap.ValueType, _valueMetadata); + } + return result; + } + } + } +} diff --git a/src/Microsoft.ML.Transforms/TermLookupTransformer.cs b/src/Microsoft.ML.Transforms/TermLookupTransformer.cs deleted file mode 100644 index 0b3b1a2c22..0000000000 --- a/src/Microsoft.ML.Transforms/TermLookupTransformer.cs +++ /dev/null @@ -1,705 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Transforms.Categorical; -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Reflection; -using System.Text; - -[assembly: LoadableClass(TermLookupTransformer.Summary, typeof(TermLookupTransformer), typeof(TermLookupTransformer.Arguments), typeof(SignatureDataTransform), - "Term Lookup Transform", "TermLookup", "Lookup", "LookupTransform", "TermLookupTransform")] - -[assembly: LoadableClass(TermLookupTransformer.Summary, typeof(TermLookupTransformer), null, typeof(SignatureLoadDataTransform), - "Term Lookup Transform", TermLookupTransformer.LoaderSignature)] - -namespace Microsoft.ML.Transforms.Categorical -{ - using Conditional = System.Diagnostics.ConditionalAttribute; - - /// - /// This transform maps text values columns to new columns using a map dataset provided through its arguments. 
- /// - public sealed class TermLookupTransformer : OneToOneTransformBase - { - public sealed class Column : OneToOneColumn - { - public static Column Parse(string str) - { - var res = new Column(); - if (res.TryParse(str)) - return res; - return null; - } - - public bool TryUnparse(StringBuilder sb) - { - Contracts.AssertValue(sb); - return TryUnparseCore(sb); - } - } - - public sealed class Arguments - { - [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)] - public Column[] Column; - - [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The data file containing the terms", ShortName = "data", SortOrder = 2)] - public string DataFile; - - [Argument(ArgumentType.Multiple, HelpText = "The data loader", NullName = "", SignatureType = typeof(SignatureDataLoader))] - public IComponentFactory Loader; - - [Argument(ArgumentType.AtMostOnce, HelpText = "The name of the text column containing the terms", ShortName = "term")] - public string TermColumn; - - [Argument(ArgumentType.AtMostOnce, HelpText = "The name of the column containing the values", ShortName = "value")] - public string ValueColumn; - - [Argument(ArgumentType.AtMostOnce, - HelpText = "If term and value columns are unspecified, specifies whether the values are key values or numeric.", ShortName = "key")] - public bool KeyValues = true; - } - - /// - /// Holds the values that the terms map to. - /// - private abstract class ValueMap - { - public readonly ColumnType Type; - - protected ValueMap(ColumnType type) - { - Type = type; - } - - public static ValueMap Create(ColumnType type) - { - Contracts.AssertValue(type); - - if (!type.IsVector) - { - Func> del = CreatePrimitive; - var meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(type.RawType); - return (ValueMap)meth.Invoke(null, new object[] { type }); - } - else - { - Func> del = CreateVector; - var meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(type.ItemType.RawType); - return (ValueMap)meth.Invoke(null, new object[] { type }); - } - } - - public static OneValueMap CreatePrimitive(PrimitiveType type) - { - Contracts.AssertValue(type); - Contracts.Assert(type.RawType == typeof(TVal)); - return new OneValueMap(type); - } - - public static VecValueMap CreateVector(VectorType type) - { - Contracts.AssertValue(type); - Contracts.Assert(type.ItemType.RawType == typeof(TVal)); - return new VecValueMap(type); - } - - public abstract void Train(IExceptionContext ectx, RowCursor cursor, int colTerm, int colValue); - - public abstract Delegate GetGetter(ValueGetter> getSrc); - } - - /// - /// Holds the values that the terms map to - where the destination type is TRes. - /// - private abstract class ValueMap : ValueMap - { - private NormStr.Pool _terms; - private TRes[] _values; - - protected ValueMap(ColumnType type) - : base(type) - { - Contracts.Assert(type.RawType == typeof(TRes)); - } - - /// - /// Bind this value map to the given cursor for "training". 
- /// - public override void Train(IExceptionContext ectx, RowCursor cursor, int colTerm, int colValue) - { - Contracts.AssertValue(ectx); - ectx.Assert(_terms == null); - ectx.Assert(_values == null); - ectx.AssertValue(cursor); - ectx.Assert(0 <= colTerm && colTerm < cursor.Schema.Count); - ectx.Assert(cursor.Schema[colTerm].Type.IsText); - ectx.Assert(0 <= colValue && colValue < cursor.Schema.Count); - ectx.Assert(cursor.Schema[colValue].Type.Equals(Type)); - - var getTerm = cursor.GetGetter>(colTerm); - var getValue = cursor.GetGetter(colValue); - var terms = new NormStr.Pool(); - var values = new List(); - - ReadOnlyMemory term = default; - while (cursor.MoveNext()) - { - getTerm(ref term); - // REVIEW: Should we trim? - term = ReadOnlyMemoryUtils.TrimSpaces(term); - var nstr = terms.Add(term); - if (nstr.Id != values.Count) - throw ectx.Except("Duplicate term in lookup data: '{0}'", nstr); - - TRes res = default(TRes); - getValue(ref res); - values.Add(res); - ectx.Assert(terms.Count == values.Count); - } - - _terms = terms; - _values = values.ToArray(); - ectx.Assert(_terms.Count == _values.Length); - } - - /// - /// Given the term getter, produce a value getter from this value map. - /// - public override Delegate GetGetter(ValueGetter> getTerm) - { - Contracts.Assert(_terms != null); - Contracts.Assert(_values != null); - Contracts.Assert(_terms.Count == _values.Length); - - return GetGetterCore(getTerm); - } - - private ValueGetter GetGetterCore(ValueGetter> getTerm) - { - var src = default(ReadOnlyMemory); - return - (ref TRes dst) => - { - getTerm(ref src); - src = ReadOnlyMemoryUtils.TrimSpaces(src); - var nstr = _terms.Get(src); - if (nstr == null) - GetMissing(ref dst); - else - { - Contracts.Assert(0 <= nstr.Id && nstr.Id < _values.Length); - CopyValue(in _values[nstr.Id], ref dst); - } - }; - } - - protected abstract void GetMissing(ref TRes dst); - - protected abstract void CopyValue(in TRes src, ref TRes dst); - } - - /// - /// Holds the values that the terms map to when the destination type is a PrimitiveType (non-vector). - /// - private sealed class OneValueMap : ValueMap - { - private readonly TRes _badValue; - - public OneValueMap(PrimitiveType type) - : base(type) - { - // REVIEW: This uses the fact that standard conversions map NA to NA to get the NA for TRes. - // We should probably have a mapping from type to its bad value somewhere, perhaps in Conversions. - bool identity; - ValueMapper, TRes> conv; - if (Runtime.Data.Conversion.Conversions.Instance.TryGetStandardConversion, TRes>(TextType.Instance, type, - out conv, out identity)) - { - //Empty string will map to NA for R4 and R8, the only two types that can - //handle missing values. - var bad = String.Empty.AsMemory(); - conv(in bad, ref _badValue); - } - } - - protected override void GetMissing(ref TRes dst) - { - dst = _badValue; - } - - protected override void CopyValue(in TRes src, ref TRes dst) - { - dst = src; - } - } - - /// - /// Holds the values that the terms map to when the destination type is a VectorType. - /// TItem is the represtation type for the vector's ItemType. 
- /// - private sealed class VecValueMap : ValueMap> - { - public VecValueMap(VectorType type) - : base(type) - { - } - - protected override void GetMissing(ref VBuffer dst) - { - VBufferUtils.Resize(ref dst, Type.VectorSize, 0); - } - - protected override void CopyValue(in VBuffer src, ref VBuffer dst) - { - src.CopyTo(ref dst); - } - } - - public const string LoaderSignature = "TermLookupTransform"; - - internal const string Summary = "Maps text values columns to new columns using a map dataset."; - - private static VersionInfo GetVersionInfo() - { - return new VersionInfo( - modelSignature: "TXTLOOKT", - // verWrittenCur: 0x00010001, // Initial. - verWrittenCur: 0x00010002, // Dropped sizeof(Float). - verReadableCur: 0x00010002, - verWeCanReadBack: 0x00010002, - loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(TermLookupTransformer).Assembly.FullName); - } - - // This is the byte array containing the binary .idv file contents for the lookup data. - // This is persisted; the _termMap and _valueMap are constructed from it. - private readonly byte[] _bytes; - - // The BinaryLoader over the byte array above. We keep this - // active simply for metadata requests. - private readonly BinaryLoader _ldr; - - // The value map. - private readonly ValueMap _valueMap; - - // Stream names for the binary idv streams. - private const string DefaultMapName = "DefaultMap.idv"; - - private const string RegistrationName = "TextLookup"; - - /// - /// Public constructor corresponding to SignatureDataTransform. - /// - public TermLookupTransformer(IHostEnvironment env, Arguments args, IDataView input) - : base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, - input, TestIsText) - { - Host.AssertNonEmpty(Infos); - Host.Assert(Infos.Length == Utils.Size(args.Column)); - - Host.CheckUserArg(!string.IsNullOrWhiteSpace(args.DataFile), nameof(args.DataFile), "must specify dataFile"); - Host.CheckUserArg(string.IsNullOrEmpty(args.TermColumn) == string.IsNullOrEmpty(args.ValueColumn), nameof(args.TermColumn), - "Either both term and value column should be specified, or neither."); - - using (var ch = Host.Start("Training")) - { - _bytes = GetBytes(Host, Infos, args); - _ldr = GetLoader(Host, _bytes); - _valueMap = Train(ch, _ldr); - SetMetadata(); - } - } - - public TermLookupTransformer(IHostEnvironment env, IDataView input, IDataView lookup, string sourceTerm, string sourceValue, string targetTerm, string targetValue) - : base(env, RegistrationName, new[] { new Column { Name = sourceValue, Source = sourceTerm } }, input, TestIsText) - { - Host.AssertNonEmpty(Infos); - Host.CheckValue(input, nameof(input)); - Host.CheckValue(lookup, nameof(lookup)); - Host.Assert(Infos.Length == 1); - Host.CheckNonEmpty(targetTerm, nameof(targetTerm), "Term column must be specified when passing in a data view as lookup table."); - Host.CheckNonEmpty(targetValue, nameof(targetValue), "Value column must be specified when passing in a data view as lookup table."); - - using (var ch = Host.Start("Training")) - { - _bytes = GetBytesFromDataView(Host, lookup, targetTerm, targetValue); - _ldr = GetLoader(Host, _bytes); - _valueMap = Train(ch, _ldr); - SetMetadata(); - } - } - - // This method is called if only a datafile is specified, without a loader/term and value columns. - // It determines the type of the Value column and returns the appropriate TextLoader component factory. 
- private static IComponentFactory GetLoaderFactory(string filename, bool keyValues, IHost host) - { - Contracts.AssertValue(host); - - // If the user specified non-key values, we define the value column to be numeric. - if (!keyValues) - return ComponentFactoryUtils.CreateFromFunction( - (env, files) => new TextLoader( - env, new[] - { - new TextLoader.Column("Term", DataKind.TX, 0), - new TextLoader.Column("Value", DataKind.Num, 1) - }, dataSample: files).Read(files) as IDataLoader); - - // If the user specified key values, we scan the values to determine the range of the key type. - ulong min = ulong.MaxValue; - ulong max = ulong.MinValue; - try - { - var file = new MultiFileSource(filename); - var data = new TextLoader(host, new[] - { - new TextLoader.Column("Term", DataKind.TX, 0), - new TextLoader.Column("Value", DataKind.TX, 1) - }, - dataSample: file - ).Read(file); - - using (var cursor = data.GetRowCursor(c => true)) - { - var getTerm = cursor.GetGetter>(0); - var getVal = cursor.GetGetter>(1); - ReadOnlyMemory txt = default; - - using (var ch = host.Start("Creating Text Lookup Loader")) - { - long countNonKeys = 0; - while (cursor.MoveNext()) - { - getVal(ref txt); - ulong res; - // Try to parse the text as a key value between 1 and ulong.MaxValue. If this succeeds and res>0, - // we update max and min accordingly. If res==0 it means the value is missing, in which case we ignore it for - // computing max and min. - if (Runtime.Data.Conversion.Conversions.Instance.TryParseKey(in txt, 1, ulong.MaxValue, out res)) - { - if (res < min && res != 0) - min = res; - if (res > max) - max = res; - } - // If parsing as key did not succeed, the value can still be 0, so we try parsing it as a ulong. If it succeeds, - // then the value is 0, and we update min accordingly. - else if (Runtime.Data.Conversion.Conversions.Instance.TryParse(in txt, out res)) - { - ch.Assert(res == 0); - min = 0; - } - //If parsing as a ulong fails, we increment the counter for the non-key values. - else - { - var term = default(ReadOnlyMemory); - getTerm(ref term); - if (countNonKeys < 5) - ch.Warning("Term '{0}' in mapping file is mapped to non key value '{1}'", term, txt); - countNonKeys++; - } - } - if (countNonKeys > 0) - ch.Warning("Found {0} non key values in the file '{1}'", countNonKeys, filename); - if (min > max) - { - min = 0; - max = uint.MaxValue - 1; - ch.Warning("did not find any valid key values in the file '{0}'", filename); - } - else - ch.Info("Found key values in the range {0} to {1} in the file '{2}'", min, max, filename); - } - } - } - catch (Exception e) - { - throw host.Except(e, "Failed to parse the lookup file '{0}' in TermLookupTransform", filename); - } - - TextLoader.Column valueColumn = new TextLoader.Column("Value", DataKind.U4, 1); - if (max - min < (ulong)int.MaxValue) - { - valueColumn.KeyRange = new KeyRange(min, max); - } - else if (max - min < (ulong)uint.MaxValue) - { - valueColumn.KeyRange = new KeyRange(min); - } - else - { - valueColumn.Type = DataKind.U8; - valueColumn.KeyRange = new KeyRange(min); - } - - return ComponentFactoryUtils.CreateFromFunction( - (env, files) => new TextLoader( - env, - columns: new[] - { - new TextLoader.Column("Term", DataKind.TX, 0), - valueColumn - }, - dataSample: files).Read(files) as IDataLoader); - } - - // This saves the lookup data as a byte array encoded as a binary .idv file. 
- private static byte[] GetBytes(IHost host, ColInfo[] infos, Arguments args) - { - Contracts.AssertValue(host); - host.AssertNonEmpty(infos); - host.AssertValue(args); - - string dataFile = args.DataFile; - IComponentFactory loaderFactory = args.Loader; - string termColumn; - string valueColumn; - if (!string.IsNullOrEmpty(args.TermColumn)) - { - host.Assert(!string.IsNullOrEmpty(args.ValueColumn)); - termColumn = args.TermColumn; - valueColumn = args.ValueColumn; - } - else - { - var ext = Path.GetExtension(dataFile); - if (loaderFactory != null || string.Equals(ext, ".idv", StringComparison.OrdinalIgnoreCase)) - throw host.ExceptUserArg(nameof(args.TermColumn), "Term and value columns needed."); - loaderFactory = GetLoaderFactory(args.DataFile, args.KeyValues, host); - termColumn = "Term"; - valueColumn = "Value"; - } - return GetBytesOne(host, dataFile, loaderFactory, termColumn, valueColumn); - } - - private static byte[] GetBytesFromDataView(IHost host, IDataView lookup, string termColumn, string valueColumn) - { - Contracts.AssertValue(host); - host.AssertValue(lookup); - host.AssertNonEmpty(termColumn); - host.AssertNonEmpty(valueColumn); - - int colTerm; - int colValue; - var schema = lookup.Schema; - - if (!schema.TryGetColumnIndex(termColumn, out colTerm)) - throw host.ExceptUserArg(nameof(Arguments.TermColumn), "column not found: '{0}'", termColumn); - if (!schema.TryGetColumnIndex(valueColumn, out colValue)) - throw host.ExceptUserArg(nameof(Arguments.ValueColumn), "column not found: '{0}'", valueColumn); - - // REVIEW: Should we allow term to be a vector of text (each term in the vector - // would map to the same value)? - var typeTerm = schema[colTerm].Type; - host.CheckUserArg(typeTerm.IsText, nameof(Arguments.TermColumn), "term column must contain text"); - var typeValue = schema[colValue].Type; - var cols = new List<(string Source, string Name)>() - { - (termColumn, "Term"), - (valueColumn, "Value") - }; - - var view = new ColumnCopyingTransformer(host, cols.ToArray()).Transform(lookup); - view = ColumnSelectingTransformer.CreateKeep(host, view, cols.Select(x=>x.Name).ToArray()); - - var saver = new BinarySaver(host, new BinarySaver.Arguments()); - using (var strm = new MemoryStream()) - { - saver.SaveData(strm, view, 0, 1); - return strm.ToArray(); - } - } - - private static byte[] GetBytesOne(IHost host, string dataFile, IComponentFactory loaderFactory, - string termColumn, string valueColumn) - { - Contracts.AssertValue(host); - host.Assert(!string.IsNullOrWhiteSpace(dataFile)); - host.AssertNonEmpty(termColumn); - host.AssertNonEmpty(valueColumn); - - IMultiStreamSource fileSource = new MultiFileSource(dataFile); - IDataLoader loader; - if (loaderFactory == null) - { - // REVIEW: Should there be defaults for loading from text? - var ext = Path.GetExtension(dataFile); - bool isBinary = string.Equals(ext, ".idv", StringComparison.OrdinalIgnoreCase); - bool isTranspose = string.Equals(ext, ".tdv", StringComparison.OrdinalIgnoreCase); - if (!isBinary && !isTranspose) - throw host.ExceptUserArg(nameof(Arguments.Loader), "must specify the loader"); - host.Assert(isBinary != isTranspose); // One or the other must be true. 
- if (isBinary) - { - loader = new BinaryLoader(host, new BinaryLoader.Arguments(), fileSource); - } - else - { - loader = new TransposeLoader(host, new TransposeLoader.Arguments(), fileSource); - } - } - else - { - loader = loaderFactory.CreateComponent(host, fileSource); - } - - return GetBytesFromDataView(host, loader, termColumn, valueColumn); - } - - private static BinaryLoader GetLoader(IHostEnvironment env, byte[] bytes) - { - env.AssertValue(env); - env.AssertValue(bytes); - - var strm = new MemoryStream(bytes, writable: false); - return new BinaryLoader(env, new BinaryLoader.Arguments(), strm); - } - - private static ValueMap Train(IExceptionContext ectx, BinaryLoader ldr) - { - Contracts.AssertValue(ectx); - ectx.AssertValue(ldr); - ectx.Assert(ldr.Schema.Count == 2); - - // REVIEW: Should we allow term to be a vector of text (each term in the vector - // would map to the same value)? - ectx.Assert(ldr.Schema[0].Type.IsText); - - var schema = ldr.Schema; - var typeValue = schema[1].Type; - - // REVIEW: We should know the number of rows - use that info to set initial capacity. - var values = ValueMap.Create(typeValue); - using (var cursor = ldr.GetRowCursor(c => true)) - values.Train(ectx, cursor, 0, 1); - return values; - } - - private TermLookupTransformer(IChannel ch, ModelLoadContext ctx, IHost host, IDataView input) - : base(host, ctx, input, TestIsText) - { - Host.AssertValue(ch); - - // *** Binary format *** - // - ch.AssertNonEmpty(Infos); - - // Extra streams: - // DefaultMap.idv - byte[] rgb = null; - Action fn = r => rgb = ReadAllBytes(ch, r); - - if (!ctx.TryLoadBinaryStream(DefaultMapName, fn)) - throw ch.ExceptDecode(); - _bytes = rgb; - - // Process the bytes into the loader and map. - _ldr = GetLoader(Host, _bytes); - ValidateLoader(ch, _ldr); - _valueMap = Train(ch, _ldr); - SetMetadata(); - } - - private static byte[] ReadAllBytes(IExceptionContext ectx, BinaryReader rdr) - { - Contracts.AssertValue(ectx); - ectx.AssertValue(rdr); - ectx.Assert(rdr.BaseStream.CanSeek); - - long size = rdr.BaseStream.Length; - ectx.CheckDecode(size <= int.MaxValue); - - var rgb = new byte[(int)size]; - int cb = rdr.Read(rgb, 0, rgb.Length); - ectx.CheckDecode(cb == rgb.Length); - - return rgb; - } - - public static TermLookupTransformer Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) - { - Contracts.CheckValue(env, nameof(env)); - var h = env.Register(RegistrationName); - h.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(GetVersionInfo()); - h.CheckValue(input, nameof(input)); - return h.Apply("Loading Model", ch => new TermLookupTransformer(ch, ctx, h, input)); - } - - public override void Save(ModelSaveContext ctx) - { - Host.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(); - ctx.SetVersionInfo(GetVersionInfo()); - - // *** Binary format *** - // - SaveBase(ctx); - - // Extra streams: - // DefaultMap.idv - Host.Assert(_ldr != null); - Host.AssertValue(_bytes); - DebugValidateLoader(_ldr); - ctx.SaveBinaryStream(DefaultMapName, w => w.Write(_bytes)); - } - - [Conditional("DEBUG")] - private static void DebugValidateLoader(BinaryLoader ldr) - { - Contracts.Assert(ldr != null); - Contracts.Assert(ldr.Schema.Count == 2); - Contracts.Assert(ldr.Schema[0].Type.IsText); - } - - private static void ValidateLoader(IExceptionContext ectx, BinaryLoader ldr) - { - if (ldr == null) - return; - ectx.CheckDecode(ldr.Schema.Count == 2); - ectx.CheckDecode(ldr.Schema[0].Type.IsText); - } - - protected override ColumnType GetColumnTypeCore(int iinfo) - { - 
Contracts.Assert(0 <= iinfo & iinfo < Infos.Length); - return _valueMap.Type; - } - - private void SetMetadata() - { - // Metadata is passed through from the Value column of the map data view. - var md = Metadata; - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) - { - using (var bldr = md.BuildMetadata(iinfo, _ldr.Schema, 1)) - { - // No additional metadata. - } - } - md.Seal(); - } - - protected override Delegate GetGetterCore(IChannel ch, Row input, int iinfo, out Action disposer) - { - Host.AssertValueOrNull(ch); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - disposer = null; - - var getSrc = GetSrcGetter>(input, iinfo); - return _valueMap.GetGetter(getSrc); - } - } -} diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers-Schema.txt index 6402981be7..f45353c79a 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers-Schema.txt +++ b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers-Schema.txt @@ -34,7 +34,7 @@ StringLabel: Key Metadata 'KeyValues': Vec: Length=7, Count=7 [0] 'Wirtschaft', [1] 'Gesundheit', [2] 'Deutschland', [3] 'Ausland', [4] 'Unterhaltung', [5] 'Sport', [6] 'Technik & Wissen' ----- TermLookupTransformer ---- +---- RowToRowMapperTransform ---- 6 columns: RawLabel: Text Names: Vec diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers1-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers1-Schema.txt index 1e8446cc3e..17c269f1d8 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers1-Schema.txt +++ b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers1-Schema.txt @@ -7,7 +7,7 @@ Features: Vec Metadata 'SlotNames': Vec: Length=2, Count=2 [0] 'weg fuer milliardenhilfe frei', [1] 'vor dem parlamentsgebaeude toben strassenkaempfe zwischen demonstranten drinnen haben die griechischen abgeordneten das drastische sparpaket am abend endgueltig beschlossen die entscheidung ist eine wichtige voraussetzung fuer die auszahlung von weiteren acht milliarden euro hilfsgeldern athen das griechische parlament hat einem umfassenden sparpaket endgueltig zugestimmt' ----- TermLookupTransformer ---- +---- RowToRowMapperTransform ---- 4 columns: RawLabel: Text Names: Vec diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers2-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers2-Schema.txt index f40e727ef0..2ad6cfab86 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers2-Schema.txt +++ b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers2-Schema.txt @@ -7,7 +7,7 @@ Features: Vec Metadata 'SlotNames': Vec: Length=2, Count=2 [0] 'weg fuer milliardenhilfe frei', [1] 'vor dem parlamentsgebaeude toben strassenkaempfe zwischen demonstranten drinnen haben die griechischen abgeordneten das drastische sparpaket am abend endgueltig beschlossen die entscheidung ist eine wichtige voraussetzung fuer die auszahlung von weiteren acht milliarden euro hilfsgeldern athen das griechische parlament hat einem umfassenden sparpaket endgueltig zugestimmt' ----- TermLookupTransformer ---- +---- RowToRowMapperTransform ---- 4 columns: RawLabel: Text Names: Vec diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers3-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers3-Schema.txt index eb6fccd5db..64ac99b379 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers3-Schema.txt +++ 
b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers3-Schema.txt @@ -7,7 +7,7 @@ Features: Vec Metadata 'SlotNames': Vec: Length=2, Count=2 [0] 'weg fuer milliardenhilfe frei', [1] 'vor dem parlamentsgebaeude toben strassenkaempfe zwischen demonstranten drinnen haben die griechischen abgeordneten das drastische sparpaket am abend endgueltig beschlossen die entscheidung ist eine wichtige voraussetzung fuer die auszahlung von weiteren acht milliarden euro hilfsgeldern athen das griechische parlament hat einem umfassenden sparpaket endgueltig zugestimmt' ----- TermLookupTransformer ---- +---- RowToRowMapperTransform ---- 4 columns: RawLabel: Text Names: Vec diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers4-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers4-Schema.txt index 750b267e78..f35b76301f 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers4-Schema.txt +++ b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers4-Schema.txt @@ -7,7 +7,7 @@ Features: Vec Metadata 'SlotNames': Vec: Length=2, Count=2 [0] 'weg fuer milliardenhilfe frei', [1] 'vor dem parlamentsgebaeude toben strassenkaempfe zwischen demonstranten drinnen haben die griechischen abgeordneten das drastische sparpaket am abend endgueltig beschlossen die entscheidung ist eine wichtige voraussetzung fuer die auszahlung von weiteren acht milliarden euro hilfsgeldern athen das griechische parlament hat einem umfassenden sparpaket endgueltig zugestimmt' ----- TermLookupTransformer ---- +---- RowToRowMapperTransform ---- 4 columns: RawLabel: Text Names: Vec @@ -17,7 +17,7 @@ Metadata 'SlotNames': Vec: Length=2, Count=2 [0] 'weg fuer milliardenhilfe frei', [1] 'vor dem parlamentsgebaeude toben strassenkaempfe zwischen demonstranten drinnen haben die griechischen abgeordneten das drastische sparpaket am abend endgueltig beschlossen die entscheidung ist eine wichtige voraussetzung fuer die auszahlung von weiteren acht milliarden euro hilfsgeldern athen das griechische parlament hat einem umfassenden sparpaket endgueltig zugestimmt' FileLabelNum: R4 ----- TermLookupTransformer ---- +---- RowToRowMapperTransform ---- 5 columns: RawLabel: Text Names: Vec diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers4-out.txt b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers4-out.txt index bf71ea5899..4d443c426b 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers4-out.txt +++ b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers4-out.txt @@ -1,15 +1,25 @@ Bad value at line 5 in column Value Processed 7 rows with 1 bad values and 0 format errors Bad value at line 5 in column Value +Processed 7 rows with 1 bad values and 0 format errors + Bad value at line 5 in column Value Processed 7 rows with 1 bad values and 0 format errors Wrote 7 rows across 2 columns in %Time% -Warning: Term 'Wirtschaft' in mapping file is mapped to non key value '3.14' -Warning: Term 'Gesundheit' in mapping file is mapped to non key value '0.1' -Warning: Term 'Deutschland' in mapping file is mapped to non key value '1.5' -Warning: Term 'Ausland' in mapping file is mapped to non key value '0.5' -Warning: Term 'Unterhaltung' in mapping file is mapped to non key value '1a' +Warning: Key 'Wirtschaft' in mapping file is mapped to non key value '3.14' +Warning: Key 'Gesundheit' in mapping file is mapped to non key value '0.1' +Warning: Key 'Deutschland' in mapping file is mapped to non key value '1.5' +Warning: Key 'Ausland' in mapping file is mapped to non key value 
'0.5' +Warning: Key 'Unterhaltung' in mapping file is mapped to non key value '1a' Warning: Found 7 non key values in the file '%Output% -Warning: did not find any valid key values in the file '%Output% +Warning: Did not find any valid key values in the file '%Output% + Bad value at line 1 in column Value + Bad value at line 2 in column Value + Bad value at line 3 in column Value + Bad value at line 4 in column Value + Bad value at line 5 in column Value + Bad value at line 6 in column Value + Bad value at line 7 in column Value +Processed 7 rows with 7 bad values and 0 format errors Bad value at line 1 in column Value Bad value at line 2 in column Value Bad value at line 3 in column Value @@ -27,6 +37,8 @@ Processed 7 rows with 7 bad values and 0 format errors Bad value at line 7 in column Value Processed 7 rows with 7 bad values and 0 format errors Wrote 7 rows across 2 columns in %Time% +Wrote 7 rows across 2 columns in %Time% +Wrote 7 rows across 2 columns in %Time% Wrote 119 rows of length 3 Wrote 119 rows across 3 columns in %Time% --- Progress log --- @@ -36,48 +48,66 @@ Wrote 119 rows across 3 columns in %Time% [2] 'BinarySaver' started. [2] (%Time%) 7 rows [2] 'BinarySaver' finished in %Time%. -[3] 'TextSaver: saving data' started. -[3] (%Time%) 119 rows -[3] 'TextSaver: saving data' finished in %Time%. -[4] 'BinarySaver #2' started. +[3] 'BinarySaver #2' started. +[3] (%Time%) 7 rows +[3] 'BinarySaver #2' finished in %Time%. +[4] 'TextSaver: saving data' started. [4] (%Time%) 119 rows -[4] 'BinarySaver #2' finished in %Time%. +[4] 'TextSaver: saving data' finished in %Time%. [5] 'BinarySaver #3' started. -[5] (%Time%) 7 rows +[5] (%Time%) 119 rows [5] 'BinarySaver #3' finished in %Time%. -[6] 'TextSaver: saving data #2' started. -[6] (%Time%) 119 rows -[6] 'TextSaver: saving data #2' finished in %Time%. -[7] 'BinarySaver #4' started. -[7] (%Time%) 119 rows -[7] 'BinarySaver #4' finished in %Time%. -[8] 'BinarySaver #5' started. -[8] (%Time%) 7 rows -[8] 'BinarySaver #5' finished in %Time%. -[9] 'TextSaver: saving data #3' started. +[6] 'BinarySaver #4' started. +[6] (%Time%) 7 rows +[6] 'BinarySaver #4' finished in %Time%. +[7] 'BinarySaver #5' started. +[7] (%Time%) 7 rows +[7] 'BinarySaver #5' finished in %Time%. +[8] 'TextSaver: saving data #2' started. +[8] (%Time%) 119 rows +[8] 'TextSaver: saving data #2' finished in %Time%. +[9] 'BinarySaver #6' started. [9] (%Time%) 119 rows -[9] 'TextSaver: saving data #3' finished in %Time%. -[10] 'BinarySaver #6' started. -[10] (%Time%) 119 rows -[10] 'BinarySaver #6' finished in %Time%. -[11] 'BinarySaver #7' started. +[9] 'BinarySaver #6' finished in %Time%. +[10] 'BinarySaver #7' started. +[10] (%Time%) 7 rows +[10] 'BinarySaver #7' finished in %Time%. +[11] 'BinarySaver #8' started. [11] (%Time%) 7 rows -[11] 'BinarySaver #7' finished in %Time%. -[12] 'TextSaver: saving data #4' started. +[11] 'BinarySaver #8' finished in %Time%. +[12] 'TextSaver: saving data #3' started. [12] (%Time%) 119 rows -[12] 'TextSaver: saving data #4' finished in %Time%. -[13] 'BinarySaver #8' started. +[12] 'TextSaver: saving data #3' finished in %Time%. +[13] 'BinarySaver #9' started. [13] (%Time%) 119 rows -[13] 'BinarySaver #8' finished in %Time%. -[14] 'BinarySaver #9' started. +[13] 'BinarySaver #9' finished in %Time%. +[14] 'BinarySaver #10' started. [14] (%Time%) 7 rows -[14] 'BinarySaver #9' finished in %Time%. -[15] 'BinarySaver #10' started. +[14] 'BinarySaver #10' finished in %Time%. +[15] 'BinarySaver #11' started. 
[15] (%Time%) 7 rows -[15] 'BinarySaver #10' finished in %Time%. -[16] 'TextSaver: saving data #5' started. +[15] 'BinarySaver #11' finished in %Time%. +[16] 'TextSaver: saving data #4' started. [16] (%Time%) 119 rows -[16] 'TextSaver: saving data #5' finished in %Time%. -[17] 'BinarySaver #11' started. +[16] 'TextSaver: saving data #4' finished in %Time%. +[17] 'BinarySaver #12' started. [17] (%Time%) 119 rows -[17] 'BinarySaver #11' finished in %Time%. +[17] 'BinarySaver #12' finished in %Time%. +[18] 'BinarySaver #13' started. +[18] (%Time%) 7 rows +[18] 'BinarySaver #13' finished in %Time%. +[19] 'BinarySaver #14' started. +[19] (%Time%) 7 rows +[19] 'BinarySaver #14' finished in %Time%. +[20] 'BinarySaver #15' started. +[20] (%Time%) 7 rows +[20] 'BinarySaver #15' finished in %Time%. +[21] 'BinarySaver #16' started. +[21] (%Time%) 7 rows +[21] 'BinarySaver #16' finished in %Time%. +[22] 'TextSaver: saving data #5' started. +[22] (%Time%) 119 rows +[22] 'TextSaver: saving data #5' finished in %Time%. +[23] 'BinarySaver #17' started. +[23] (%Time%) 119 rows +[23] 'BinarySaver #17' finished in %Time%. diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers5-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers5-Schema.txt index 68614d1599..f46173577b 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers5-Schema.txt +++ b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers5-Schema.txt @@ -7,7 +7,7 @@ Features: Vec Metadata 'SlotNames': Vec: Length=2, Count=2 [0] 'weg fuer milliardenhilfe frei', [1] 'vor dem parlamentsgebaeude toben strassenkaempfe zwischen demonstranten drinnen haben die griechischen abgeordneten das drastische sparpaket am abend endgueltig beschlossen die entscheidung ist eine wichtige voraussetzung fuer die auszahlung von weiteren acht milliarden euro hilfsgeldern athen das griechische parlament hat einem umfassenden sparpaket endgueltig zugestimmt' ----- TermLookupTransformer ---- +---- RowToRowMapperTransform ---- 4 columns: RawLabel: Text Names: Vec diff --git a/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs new file mode 100644 index 0000000000..99fcfa0020 --- /dev/null +++ b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs @@ -0,0 +1,440 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.Data;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Model;
+using Microsoft.ML.Runtime.RunTests;
+using Microsoft.ML.Runtime.Tools;
+using Microsoft.ML.Transforms.Conversions;
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace Microsoft.ML.Tests.Transformers
+{
+    public class ValueMappingTests : TestDataPipeBase
+    {
+        public ValueMappingTests(ITestOutputHelper output) : base(output)
+        {
+        }
+
+        class TestClass
+        {
+            public string A;
+            public string B;
+            public string C;
+        }
+
+        class TestWrong
+        {
+            public string A;
+            public float B;
+        }
+
+        public class TestTermLookup
+        {
+            public string Label;
+            public int GroupId;
+
+            [VectorType(2107)]
+            public float[] Features;
+        };
+
+
+        [Fact]
+        public void ValueMapOneValueTest()
+        {
+            var data = new[] { new TestClass() { A = "bar", B = "test", C = "foo" } };
+            var dataView = ComponentCreation.CreateDataView(Env, data);
+
+            IEnumerable<ReadOnlyMemory<char>> keys = new List<ReadOnlyMemory<char>>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() };
+            IEnumerable<int> values = new List<int>() { 1, 2, 3, 4 };
+
+            var estimator = new ValueMappingEstimator<ReadOnlyMemory<char>, int>(Env, keys, values, new[] { ("A", "D"), ("B", "E"), ("C", "F") });
+            var t = estimator.Fit(dataView);
+
+            var result = t.Transform(dataView);
+            var cursor = result.GetRowCursor((col) => true);
+            var getterD = cursor.GetGetter<int>(3);
+            var getterE = cursor.GetGetter<int>(4);
+            var getterF = cursor.GetGetter<int>(5);
+            cursor.MoveNext();
+
+            int dValue = 0;
+            getterD(ref dValue);
+            Assert.Equal(2, dValue);
+            int eValue = 0;
+            getterE(ref eValue);
+            Assert.Equal(3, eValue);
+            int fValue = 0;
+            getterF(ref fValue);
+            Assert.Equal(1, fValue);
+        }
+
+        [Fact]
+        public void ValueMapVectorValueTest()
+        {
+            var data = new[] { new TestClass() { A = "bar", B = "test", C = "foo" } };
+            var dataView = ComponentCreation.CreateDataView(Env, data);
+
+            IEnumerable<ReadOnlyMemory<char>> keys = new List<ReadOnlyMemory<char>>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory() };
+            List<int[]> values = new List<int[]>() {
+                new int[] { 2, 3, 4 },
+                new int[] { 100, 200 },
+                new int[] { 400, 500, 600, 700 } };
+
+            var estimator = new ValueMappingEstimator<ReadOnlyMemory<char>, int>(Env, keys, values, new[] { ("A", "D"), ("B", "E"), ("C", "F") });
+            var t = estimator.Fit(dataView);
+
+            var result = t.Transform(dataView);
+            var cursor = result.GetRowCursor((col) => true);
+            var getterD = cursor.GetGetter<VBuffer<int>>(3);
+            var getterE = cursor.GetGetter<VBuffer<int>>(4);
+            var getterF = cursor.GetGetter<VBuffer<int>>(5);
+            cursor.MoveNext();
+
+            var valuesArray = values.ToArray();
+            VBuffer<int> dValue = default;
+            getterD(ref dValue);
+            Assert.Equal(values[1].Length, dValue.Length);
+            VBuffer<int> eValue = default;
+            getterE(ref eValue);
+            Assert.Equal(values[2].Length, eValue.Length);
+            VBuffer<int> fValue = default;
+            getterF(ref fValue);
+            Assert.Equal(values[0].Length, fValue.Length);
+        }
+
+        [Fact]
+        public void ValueMappingMissingKey()
+        {
+            var data = new[] { new TestClass() { A = "barTest", B = "test", C = "foo" } };
+            var dataView = ComponentCreation.CreateDataView(Env, data);
+
+            IEnumerable<ReadOnlyMemory<char>> keys = new List<ReadOnlyMemory<char>>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() };
+            IEnumerable<int> values = new List<int>() { 1, 2, 3, 4 };
+
+            var estimator = new ValueMappingEstimator<ReadOnlyMemory<char>, int>(Env, keys, values, new[] { ("A", "D"), ("B", "E"), ("C", "F") });
+            var t = estimator.Fit(dataView);
+
+            var result = t.Transform(dataView);
+            var cursor = result.GetRowCursor((col) => true);
+            var getterD = cursor.GetGetter<int>(3);
+            var getterE = cursor.GetGetter<int>(4);
+            var getterF = cursor.GetGetter<int>(5);
+            cursor.MoveNext();
+
+            int dValue = 1;
+            getterD(ref dValue);
+            Assert.Equal(0, dValue);
+            int eValue = 0;
+            getterE(ref eValue);
+            Assert.Equal(3, eValue);
+            int fValue = 0;
+            getterF(ref fValue);
+            Assert.Equal(1, fValue);
+        }
+
+        [Fact]
+        void TestDuplicateKeys()
+        {
+            var data = new[] { new TestClass() { A = "barTest", B = "test", C = "foo" } };
+            var dataView = ComponentCreation.CreateDataView(Env, data);
+
+            IEnumerable<ReadOnlyMemory<char>> keys = new List<ReadOnlyMemory<char>>() { "foo".AsMemory(), "foo".AsMemory() };
+            IEnumerable<int> values = new List<int>() { 1, 2 };
+
+            Assert.Throws<InvalidOperationException>(() => new ValueMappingEstimator<ReadOnlyMemory<char>, int>(Env, keys, values, new[] { ("A", "D"), ("B", "E"), ("C", "F") }));
+        }
+
+        [Fact]
+        public void ValueMappingOutputSchema()
+        {
+            var data = new[] { new TestClass() { A = "barTest", B = "test", C = "foo" } };
+            var dataView = ComponentCreation.CreateDataView(Env, data);
+
+            IEnumerable<ReadOnlyMemory<char>> keys = new List<ReadOnlyMemory<char>>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() };
+            IEnumerable<int> values = new List<int>() { 1, 2, 3, 4 };
+
+            var estimator = new ValueMappingEstimator<ReadOnlyMemory<char>, int>(Env, keys, values, new[] { ("A", "D"), ("B", "E"), ("C", "F") });
+            var outputSchema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema));
+            Assert.Equal(6, outputSchema.Count());
+            Assert.True(outputSchema.TryFindColumn("D", out SchemaShape.Column dColumn));
+            Assert.True(outputSchema.TryFindColumn("E", out SchemaShape.Column eColumn));
+            Assert.True(outputSchema.TryFindColumn("F", out SchemaShape.Column fColumn));
+
+            Assert.Equal(typeof(int), dColumn.ItemType.RawType);
+            Assert.False(dColumn.IsKey);
+
+            Assert.Equal(typeof(int), eColumn.ItemType.RawType);
+            Assert.False(eColumn.IsKey);
+
+            Assert.Equal(typeof(int), fColumn.ItemType.RawType);
+            Assert.False(fColumn.IsKey);
+        }
+
+        [Fact]
+        public void ValueMappingWithValuesAsKeyTypesOutputSchema()
+        {
+            var data = new[] { new TestClass() { A = "bar", B = "test", C = "foo" } };
+            var dataView = ComponentCreation.CreateDataView(Env, data);
+
+            IEnumerable<ReadOnlyMemory<char>> keys = new List<ReadOnlyMemory<char>>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() };
+            IEnumerable<ReadOnlyMemory<char>> values = new List<ReadOnlyMemory<char>>() { "t".AsMemory(), "s".AsMemory(), "u".AsMemory(), "v".AsMemory() };
+
+            var estimator = new ValueMappingEstimator<ReadOnlyMemory<char>, ReadOnlyMemory<char>>(Env, keys, values, true, new[] { ("A", "D"), ("B", "E"), ("C", "F") });
+            var outputSchema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema));
+            Assert.Equal(6, outputSchema.Count());
+            Assert.True(outputSchema.TryFindColumn("D", out SchemaShape.Column dColumn));
+            Assert.True(outputSchema.TryFindColumn("E", out SchemaShape.Column eColumn));
+            Assert.True(outputSchema.TryFindColumn("F", out SchemaShape.Column fColumn));
+
+            Assert.Equal(typeof(uint), dColumn.ItemType.RawType);
+            Assert.True(dColumn.IsKey);
+
+            Assert.Equal(typeof(uint), eColumn.ItemType.RawType);
+            Assert.True(eColumn.IsKey);
+
+            Assert.Equal(typeof(uint), fColumn.ItemType.RawType);
+            Assert.True(fColumn.IsKey);
+
+            var t = estimator.Fit(dataView);
+        }
+
+        [Fact]
+        public void ValueMappingValuesAsUintKeyTypes()
+        {
+            var data = new[] { new TestClass() { A = "bar", B = "test2", C = "wahoo" } };
+            var dataView = ComponentCreation.CreateDataView(Env, data);
+
+            IEnumerable<ReadOnlyMemory<char>> keys = new List<ReadOnlyMemory<char>>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() };
+
+            // These are the expected key type values.
+            IEnumerable<uint> values = new List<uint>() { 51, 25, 42, 61 };
+
+            var estimator = new ValueMappingEstimator<ReadOnlyMemory<char>, uint>(Env, keys, values, true, new[] { ("A", "D"), ("B", "E"), ("C", "F") });
+
+            var t = estimator.Fit(dataView);
+
+            var result = t.Transform(dataView);
+            var cursor = result.GetRowCursor((col) => true);
+            var getterD = cursor.GetGetter<uint>(3);
+            var getterE = cursor.GetGetter<uint>(4);
+            var getterF = cursor.GetGetter<uint>(5);
+            cursor.MoveNext();
+
+            // The expected values contain the actual uints and are not generated.
+            uint dValue = 1;
+            getterD(ref dValue);
+            Assert.Equal<uint>(25, dValue);
+
+            // Should be 0, as test2 is a missing key.
+            uint eValue = 0;
+            getterE(ref eValue);
+            Assert.Equal<uint>(0, eValue);
+
+            // Testing the last key.
+            uint fValue = 0;
+            getterF(ref fValue);
+            Assert.Equal<uint>(61, fValue);
+        }
+
+
+        [Fact]
+        public void ValueMappingValuesAsUlongKeyTypes()
+        {
+            var data = new[] { new TestClass() { A = "bar", B = "test2", C = "wahoo" } };
+            var dataView = ComponentCreation.CreateDataView(Env, data);
+
+            IEnumerable<ReadOnlyMemory<char>> keys = new List<ReadOnlyMemory<char>>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() };
+
+            // These are the expected key type values.
+            IEnumerable<ulong> values = new List<ulong>() { 51, Int32.MaxValue, 42, 61 };
+
+            var estimator = new ValueMappingEstimator<ReadOnlyMemory<char>, ulong>(Env, keys, values, true, new[] { ("A", "D"), ("B", "E"), ("C", "F") });
+
+            var t = estimator.Fit(dataView);
+
+            var result = t.Transform(dataView);
+            var cursor = result.GetRowCursor((col) => true);
+            var getterD = cursor.GetGetter<ulong>(3);
+            var getterE = cursor.GetGetter<ulong>(4);
+            var getterF = cursor.GetGetter<ulong>(5);
+            cursor.MoveNext();
+
+            // The expected values contain the actual ulongs and are not generated.
+            ulong dValue = 1;
+            getterD(ref dValue);
+            Assert.Equal<ulong>(Int32.MaxValue, dValue);
+
+            // Should be 0, as test2 is a missing key.
+            ulong eValue = 0;
+            getterE(ref eValue);
+            Assert.Equal<ulong>(0, eValue);
+
+            // Testing the last key.
+            ulong fValue = 0;
+            getterF(ref fValue);
+            Assert.Equal<ulong>(61, fValue);
+        }
+
+        [Fact]
+        public void ValueMappingValuesAsStringKeyTypes()
+        {
+            var data = new[] { new TestClass() { A = "bar", B = "test", C = "notfound" } };
+            var dataView = ComponentCreation.CreateDataView(Env, data);
+
+            IEnumerable<ReadOnlyMemory<char>> keys = new List<ReadOnlyMemory<char>>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() };
+
+            // Generate the list of strings for the key type values; foo1 is intentionally duplicated to verify that the same key index is returned.
+            IEnumerable<ReadOnlyMemory<char>> values = new List<ReadOnlyMemory<char>>() { "foo1".AsMemory(), "foo2".AsMemory(), "foo1".AsMemory(), "foo3".AsMemory() };
+
+            var estimator = new ValueMappingEstimator<ReadOnlyMemory<char>, ReadOnlyMemory<char>>(Env, keys, values, true, new[] { ("A", "D"), ("B", "E"), ("C", "F") });
+            var t = estimator.Fit(dataView);
+
+            var result = t.Transform(dataView);
+            var cursor = result.GetRowCursor((col) => true);
+            var getterD = cursor.GetGetter<uint>(3);
+            var getterE = cursor.GetGetter<uint>(4);
+            var getterF = cursor.GetGetter<uint>(5);
+            cursor.MoveNext();
+
+            // The expected values contain the generated key type values starting from 1.
+            uint dValue = 1;
+            getterD(ref dValue);
+            Assert.Equal<uint>(2, dValue);
+
+            // eValue will equal 1 since foo1 occurs first.
+            uint eValue = 0;
+            getterE(ref eValue);
+            Assert.Equal<uint>(1, eValue);
+
+            // fValue will be 0 since it's missing.
+            uint fValue = 0;
+            getterF(ref fValue);
+            Assert.Equal<uint>(0, fValue);
+        }
+
+        [Fact]
+        public void ValueMappingWorkout()
+        {
+            var data = new[] { new TestClass() { A = "bar", B = "test", C = "foo" } };
+            var dataView = ComponentCreation.CreateDataView(Env, data);
+            var badData = new[] { new TestWrong() { A = "bar", B = 1.2f } };
+            var badDataView = ComponentCreation.CreateDataView(Env, badData);
+
+            IEnumerable<ReadOnlyMemory<char>> keys = new List<ReadOnlyMemory<char>>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() };
+            IEnumerable<int> values = new List<int>() { 1, 2, 3, 4 };
+
+            // Workout on value mapping.
+            var est = ML.Transforms.Conversion.ValueMap(keys, values, new[] { ("A", "D"), ("B", "E"), ("C", "F") });
+            TestEstimatorCore(est, validFitInput: dataView, invalidInput: badDataView);
+        }
+
+        [Fact]
+        void TestCommandLine()
+        {
+            var dataFile = GetDataPath("QuotingData.csv");
+            Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0 col=B:R4:1 col=C:R4:2} xf=valuemap{keyCol=ID valueCol=Text data="
+                + dataFile
+                + @" col=A:B loader=Text{col=ID:U8:0 col=Text:TX:1 sep=, header=+} } in=f:\1.txt" }), (int)0);
+        }
+
+        [Fact]
+        void TestCommandLineNoLoader()
+        {
+            var dataFile = GetDataPath("lm.labels.txt");
+            Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0 col=B:R4:1 col=C:R4:2} xf=valuemap{data="
+                + dataFile
+                + @" col=A:B } in=f:\1.txt" }), (int)0);
+        }
+
+        [Fact]
+        void TestCommandLineNoLoaderWithColumnNames()
+        {
+            var dataFile = GetDataPath("lm.labels.txt");
+            Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0 col=B:R4:1 col=C:R4:2} xf=valuemap{data="
+                + dataFile
+                + @" col=A:B keyCol=foo valueCol=bar} in=f:\1.txt" }), (int)0);
+        }
+
+        [Fact]
+        void TestCommandLineNoLoaderWithoutTreatValuesAsKeys()
+        {
+            var dataFile = GetDataPath("lm.labels.txt");
+            Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0 col=B:R4:1 col=C:R4:2} xf=valuemap{data="
+                + dataFile
+                + @" col=A:B valuesAsKeyType=-} in=f:\1.txt" }), (int)0);
+        }
+
+        [Fact]
+        void TestSavingAndLoading()
+        {
+            var data = new[] { new TestClass() { A = "bar", B = "foo", C = "test", } };
+            var dataView = ComponentCreation.CreateDataView(Env, data);
+            var est = new ValueMappingEstimator<ReadOnlyMemory<char>, int>(Env,
+                new List<ReadOnlyMemory<char>>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory() },
+                new List<int>() { 2, 43, 56 },
+                new[] { ("A", "D"), ("B", "E") });
+            var transformer = est.Fit(dataView);
+            using (var ms = new MemoryStream())
+            {
+                transformer.SaveTo(Env, ms);
+                ms.Position = 0;
+                var loadedTransformer = TransformerChain.LoadFrom(Env, ms);
+                var result = loadedTransformer.Transform(dataView);
+                Assert.Equal(5, result.Schema.Count);
+                Assert.True(result.Schema.TryGetColumnIndex("D", out int col));
+                Assert.True(result.Schema.TryGetColumnIndex("E", out col));
+            }
+        }
+
+
+        [Fact]
+        void TestValueMapBackCompatTermLookup()
+        {
+            // Model generated with: xf=drop{col=A}
+            // Expected output: Features Label B C
+            var data = new[] { new TestTermLookup() { Label = "good", GroupId = 1 } };
+            var dataView = ComponentCreation.CreateDataView(Env, data);
+            string termLookupModelPath = GetDataPath("backcompat/termlookup.zip");
+            using (FileStream fs = File.OpenRead(termLookupModelPath))
+            {
+                var result = ModelFileUtils.LoadTransforms(Env, dataView, fs);
+                Assert.True(result.Schema.TryGetColumnIndex("Features", out int featureIdx));
+                Assert.True(result.Schema.TryGetColumnIndex("Label", out int labelIdx));
+                Assert.True(result.Schema.TryGetColumnIndex("GroupId", out int groupIdx));
+            }
+        }
+
+        [Fact]
+        void TestValueMapBackCompatTermLookupKeyTypeValue()
+        {
+            // Model generated with: xf=drop{col=A}
+            // Expected output: Features Label B C
+            var data = new[] { new TestTermLookup() { Label = "Good", GroupId = 1 } };
+            var dataView = ComponentCreation.CreateDataView(Env, data);
+            string termLookupModelPath = GetDataPath("backcompat/termlookup_with_key.zip");
+            using (FileStream fs = File.OpenRead(termLookupModelPath))
+            {
+                var result = ModelFileUtils.LoadTransforms(Env, dataView, fs);
+                Assert.True(result.Schema.TryGetColumnIndex("Features", out int featureIdx));
+                Assert.True(result.Schema.TryGetColumnIndex("Label", out int labelIdx));
+                Assert.True(result.Schema.TryGetColumnIndex("GroupId", out int groupIdx));
+
+                Assert.True(result.Schema[labelIdx].Type.IsKey);
+                Assert.Equal(5, result.Schema[labelIdx].Type.ItemType.KeyCount);
+
+                var t = result.GetColumn<uint>(Env, "Label");
+                uint s = t.First();
+                Assert.Equal((uint)3, s);
+            }
+        }
+    }
+}
diff --git a/test/data/backcompat/termlookup.zip b/test/data/backcompat/termlookup.zip
new file mode 100644
index 0000000000..242d5782ed
Binary files /dev/null and b/test/data/backcompat/termlookup.zip differ
diff --git a/test/data/backcompat/termlookup_with_key.zip b/test/data/backcompat/termlookup_with_key.zip
new file mode 100644
index 0000000000..a5ec53b396
Binary files /dev/null and b/test/data/backcompat/termlookup_with_key.zip differ
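
Usage note for reviewers: the sketch below is illustrative only and is not part of this change. It shows one way the ValueMap conversion exercised by ValueMappingWorkout above could be wired up end to end from a test derived from TestDataPipeBase. The EducationRow type, the sample rows, and the Education/EducationOrdinal column names are hypothetical; the behavior it relies on (positional key-to-value pairing, rejection of duplicate keys, and missing inputs mapping to the default of the value type) is taken from the tests in this diff.

// Illustrative sketch (hypothetical test, not part of this diff): maps an input
// string column to an integer ordinal via the new ValueMap catalog extension,
// following the same pattern as ValueMappingWorkout above.
public class ValueMapUsageExample : TestDataPipeBase
{
    public ValueMapUsageExample(ITestOutputHelper output) : base(output)
    {
    }

    class EducationRow
    {
        // Hypothetical column name, used only for this sketch.
        public string Education;
    }

    [Fact]
    public void MapEducationToOrdinal()
    {
        // In-memory rows, created the same way as in the tests above.
        var data = new[]
        {
            new EducationRow() { Education = "0-5yrs" },
            new EducationRow() { Education = "12+yrs" },
            new EducationRow() { Education = "unknown" }
        };
        var dataView = ComponentCreation.CreateDataView(Env, data);

        // The mapping is positional: keys[i] maps to values[i]; duplicate keys are rejected.
        IEnumerable<ReadOnlyMemory<char>> keys = new List<ReadOnlyMemory<char>>()
            { "0-5yrs".AsMemory(), "6-11yrs".AsMemory(), "12+yrs".AsMemory() };
        IEnumerable<int> values = new List<int>() { 1, 2, 3 };

        // Build the estimator through the conversions catalog, fit it, and transform.
        var estimator = ML.Transforms.Conversion.ValueMap(keys, values, new[] { ("Education", "EducationOrdinal") });
        var transformer = estimator.Fit(dataView);
        var transformed = transformer.Transform(dataView);

        // The mapped column is appended to the schema; "unknown" falls back to default(int) == 0,
        // mirroring the missing-key behavior verified by ValueMappingMissingKey.
        Assert.True(transformed.Schema.TryGetColumnIndex("EducationOrdinal", out int col));
    }
}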