Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding a LoadColumnNameAttribute #4308

Merged
merged 4 commits into from Oct 10, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
114 changes: 89 additions & 25 deletions src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoader.cs
Expand Up @@ -125,12 +125,21 @@ internal static DatabaseLoader CreateDatabaseLoader<TInput>(IHostEnvironment hos
var column = new Column();
column.Name = mappingAttrName?.Name ?? memberInfo.Name;

var mappingAttr = memberInfo.GetCustomAttribute<LoadColumnAttribute>();
var indexMappingAttr = memberInfo.GetCustomAttribute<LoadColumnAttribute>();
var nameMappingAttr = memberInfo.GetCustomAttribute<LoadColumnNameAttribute>();

if (mappingAttr is object)
if (indexMappingAttr is object)
{
var sources = mappingAttr.Sources.Select((source) => Range.FromTextLoaderRange(source)).ToArray();
column.Source = sources;
if (nameMappingAttr is object)
{
throw Contracts.Except($"Cannot specify both {nameof(LoadColumnAttribute)} and {nameof(LoadColumnNameAttribute)}");
}

column.Source = indexMappingAttr.Sources.Select((source) => Range.FromTextLoaderRange(source)).ToArray();
}
else if (nameMappingAttr is object)
{
column.Source = nameMappingAttr.Sources.Select((source) => new Range(source)).ToArray();
}

InternalDataKind dk;
Expand Down Expand Up @@ -228,7 +237,7 @@ public Column(string name, DbType dbType, Range[] source, KeyCount keyCount = nu
public DbType Type = DbType.Single;

/// <summary>
/// Source index range(s) of the column.
/// Source index or name range(s) of the column.
/// </summary>
[Argument(ArgumentType.Multiple, HelpText = "Source index range(s) of the column", ShortName = "src")]
public Range[] Source;
Expand All @@ -241,7 +250,7 @@ public Column(string name, DbType dbType, Range[] source, KeyCount keyCount = nu
}

/// <summary>
/// Specifies the range of indices of input columns that should be mapped to an output column.
/// Specifies the range of indices or names of input columns that should be mapped to an output column.
/// </summary>
public sealed class Range
{
Expand All @@ -256,6 +265,19 @@ public Range(int index)
Contracts.CheckParam(index >= 0, nameof(index), "Must be non-negative");
Min = index;
Max = index;
Name = null;
}

/// <summary>
/// Creates a range that refers to exactly one input column, identified by its name
/// instead of by its position. Will result in a scalar column.
/// </summary>
/// <param name="name">The name of the field of the table to read. Must not be <c>null</c>.</param>
public Range(string name)
{
    Contracts.CheckValue(name, nameof(name));

    Name = name;

    // A name-based range carries no positional information; both index bounds hold the -1 sentinel.
    Min = Max = -1;
}

/// <summary>
Expand All @@ -278,15 +300,30 @@ public Range(int min, int max)
/// <summary>
/// The minimum index of the column, inclusive.
/// </summary>
/// <remarks>
/// This is <c>-1</c> if the range represents a column name.
tannergooding marked this conversation as resolved.
Show resolved Hide resolved
/// </remarks>
[Argument(ArgumentType.Required, HelpText = "First index in the range")]
public int Min;

/// <summary>
/// The maximum index of the column, inclusive.
/// </summary>
/// <remarks>
/// This is <c>-1</c> if the range represents a column name.
/// </remarks>
[Argument(ArgumentType.AtMostOnce, HelpText = "Last index in the range")]
public int Max;

/// <summary>
/// The name of the input column.
/// </summary>
/// <remarks>
/// This is <c>null</c> if the range represents an index.
/// </remarks>
[Argument(ArgumentType.AtMostOnce, HelpText = "Name of the column")]
public string Name;
tannergooding marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>
/// Force scalar columns to be treated as vectors of length one.
/// </summary>
Expand Down Expand Up @@ -318,17 +355,28 @@ public sealed class Options
/// </summary>
internal readonly struct Segment
{
    /// <summary>
    /// Name of the source column, or <c>null</c> when the segment is addressed by index.
    /// </summary>
    public readonly string Name;

    /// <summary>
    /// First source index covered by the segment (inclusive), or <c>-1</c> for a named segment.
    /// </summary>
    public readonly int Min;

    /// <summary>
    /// Exclusive upper limit of the covered source indices, or <c>-1</c> for a named segment.
    /// </summary>
    public readonly int Lim;

    /// <summary>
    /// Whether a scalar column should be treated as a vector of length one.
    /// </summary>
    public readonly bool ForceVector;

    /// <summary>
    /// Builds an index-based segment covering the source indices from <paramref name="min"/>
    /// up to, but not including, <paramref name="lim"/>.
    /// </summary>
    public Segment(int min, int lim, bool forceVector)
    {
        Contracts.Assert(min >= 0 && lim > min);

        ForceVector = forceVector;
        Lim = lim;
        Min = min;
        Name = null;
    }

    /// <summary>
    /// Builds a name-based segment that refers to the single source column called
    /// <paramref name="name"/>.
    /// </summary>
    public Segment(string name, bool forceVector)
    {
        Contracts.Assert(name is object);

        ForceVector = forceVector;

        // Named segments have no index bounds; both are set to the -1 sentinel.
        Min = Lim = -1;
        Name = name;
    }
}

/// <summary>
Expand Down Expand Up @@ -368,19 +416,23 @@ public static ColInfo Create(string name, PrimitiveDataViewType itemType, Segmen
if (segs != null)
{
var order = Utils.GetIdentityPermutation(segs.Length);
Array.Sort(order, (x, y) => segs[x].Min.CompareTo(segs[y].Min));

// Check that the segments are disjoint.
for (int i = 1; i < order.Length; i++)
if (segs[0].Name is null)
tannergooding marked this conversation as resolved.
Show resolved Hide resolved
{
int a = order[i - 1];
int b = order[i];
Contracts.Assert(segs[a].Min <= segs[b].Min);
if (segs[a].Lim > segs[b].Min)
Array.Sort(order, (x, y) => segs[x].Min.CompareTo(segs[y].Min));

// Check that the segments are disjoint.
for (int i = 1; i < order.Length; i++)
{
throw user ?
Contracts.ExceptUserArg(nameof(Column.Source), "Intervals specified for column '{0}' overlap", name) :
Contracts.ExceptDecode("Intervals specified for column '{0}' overlap", name);
int a = order[i - 1];
int b = order[i];
Contracts.Assert(segs[a].Min <= segs[b].Min);
if (segs[a].Lim > segs[b].Min)
{
throw user ?
Contracts.ExceptUserArg(nameof(Column.Source), "Intervals specified for column '{0}' overlap", name) :
Contracts.ExceptDecode("Intervals specified for column '{0}' overlap", name);
}
}
}

Expand All @@ -389,7 +441,7 @@ public static ColInfo Create(string name, PrimitiveDataViewType itemType, Segmen
for (int i = 0; i < segs.Length; i++)
{
var seg = segs[i];
size += seg.Lim - seg.Min;
size += (seg.Name is null) ? seg.Lim - seg.Min : 1;
}
Contracts.Assert(size >= segs.Length);

Expand Down Expand Up @@ -454,15 +506,23 @@ public Bindings(DatabaseLoader parent, Column[] cols)
for (int i = 0; i < segs.Length; i++)
{
var range = col.Source[i];

int min = range.Min;
ch.CheckUserArg(0 <= min, nameof(range.Min));

Segment seg;

int max = range.Max;
ch.CheckUserArg(min <= max, nameof(range.Max));
seg = new Segment(min, max + 1, range.ForceVector);
if (range.Name is null)
{
int min = range.Min;
ch.CheckUserArg(0 <= min, nameof(range.Min));

int max = range.Max;
ch.CheckUserArg(min <= max, nameof(range.Max));
seg = new Segment(min, max + 1, range.ForceVector);
}
else
{
string columnName = range.Name;
ch.CheckUserArg(columnName != null, nameof(range.Name));
seg = new Segment(columnName, range.ForceVector);
}

segs[i] = seg;
}
Expand Down Expand Up @@ -490,6 +550,7 @@ public Bindings(ModelLoadContext ctx, DatabaseLoader parent)
// ulong: count for key range
// int: number of segments
// foreach segment:
// string id: name
// int: min
// int: lim
// byte: force vector (verWrittenCur: verIsVectorSupported)
Expand Down Expand Up @@ -532,11 +593,12 @@ public Bindings(ModelLoadContext ctx, DatabaseLoader parent)
segs = new Segment[cseg];
for (int iseg = 0; iseg < cseg; iseg++)
{
string columnName = ctx.LoadStringOrNull();
int min = ctx.Reader.ReadInt32();
int lim = ctx.Reader.ReadInt32();
Contracts.CheckDecode(0 <= min && min < lim);
bool forceVector = ctx.Reader.ReadBoolByte();
segs[iseg] = new Segment(min, lim, forceVector);
segs[iseg] = (columnName is null) ? new Segment(min, lim, forceVector) : new Segment(columnName, forceVector);
}
}

Expand All @@ -563,6 +625,7 @@ internal void Save(ModelSaveContext ctx)
// ulong: count for key range
// int: number of segments
// foreach segment:
// string id: name
// int: min
// int: lim
// byte: force vector (verWrittenCur: verIsVectorSupported)
Expand All @@ -588,6 +651,7 @@ internal void Save(ModelSaveContext ctx)
ctx.Writer.Write(info.Segments.Length);
foreach (var seg in info.Segments)
{
ctx.SaveStringOrNull(seg.Name);
ctx.Writer.Write(seg.Min);
ctx.Writer.Write(seg.Lim);
ctx.Writer.WriteBoolByte(seg.ForceVector);
Expand Down