Skip to content

Commit

Permalink
Make ScoreMapperSchema and its relatives not ISchema (#2107)
Browse files Browse the repository at this point in the history
* Make ScoreMapperSchema and its relatives not ISchema
1. Remove ScoreMapperSchema and its relatives entirely
2. Create static functions to generate commonly-used Schema

* Rename file

* Add helper function for creating sequence predictor's schema and its tests
  • Loading branch information
wschin committed Jan 11, 2019
1 parent 77410ea commit d26510f
Show file tree
Hide file tree
Showing 8 changed files with 288 additions and 423 deletions.
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Prediction/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ public Bound(IHostEnvironment env, SchemaBindableCalibratedPredictor parent, Rol
throw env.Except("Predictor does not output a score");
var scoreType = _predictor.OutputSchema[_scoreCol].Type;
env.Check(!scoreType.IsVector && scoreType is NumberType);
OutputSchema = Schema.Create(new BinaryClassifierSchema());
OutputSchema = ScoreSchemaFactory.CreateBinaryClassificationSchema();
}

public Func<int, bool> GetDependencies(Func<int, bool> predicate)
Expand Down
98 changes: 9 additions & 89 deletions src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -314,11 +314,8 @@ private protected override bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema
return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(featName));
}

private protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema)
{
var outputSchema = Schema.Create(new ScoreMapperSchema(ScoreType, _scoreColumnKind));
return new SingleValueRowMapper(schema, this, outputSchema);
}
private protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema) =>
new SingleValueRowMapper(schema, this, ScoreSchemaFactory.Create(ScoreType, _scoreColumnKind));

private static string GetScoreColumnKind(IPredictor predictor)
{
Expand Down Expand Up @@ -480,7 +477,7 @@ public CalibratedRowMapper(RoleMappedSchema schema, SchemaBindableBinaryPredicto

_parent = parent;
InputRoleMappedSchema = schema;
OutputSchema = Schema.Create(new BinaryClassifierSchema());
OutputSchema = ScoreSchemaFactory.CreateBinaryClassificationSchema();

if (schema.Feature?.Type is ColumnType typeSrc)
{
Expand Down Expand Up @@ -591,9 +588,9 @@ private static VersionInfo GetVersionInfo()
}

private readonly IQuantileValueMapper _qpred;
private readonly Double[] _quantiles;
private readonly double[] _quantiles;

public SchemaBindableQuantileRegressionPredictor(IPredictor predictor, Double[] quantiles)
public SchemaBindableQuantileRegressionPredictor(IPredictor predictor, double[] quantiles)
: base(predictor)
{
var qpred = Predictor as IQuantileValueMapper;
Expand All @@ -613,7 +610,7 @@ private SchemaBindableQuantileRegressionPredictor(IHostEnvironment env, ModelLoa
// *** Binary format ***
// <base info>
// int: the number of quantiles
// Double[]: the quantiles
// double[]: the quantiles

var qpred = Predictor as IQuantileValueMapper;
Contracts.CheckDecode(qpred != null);
Expand All @@ -633,7 +630,7 @@ public override void Save(ModelSaveContext ctx)
// *** Binary format ***
// <base info>
// int: the number of quantiles
// Double[]: the quantiles
// double[]: the quantiles

base.Save(ctx);
ctx.Writer.WriteDoubleArray(_quantiles);
Expand All @@ -646,10 +643,8 @@ public static SchemaBindableQuantileRegressionPredictor Create(IHostEnvironment
return new SchemaBindableQuantileRegressionPredictor(env, ctx);
}

private protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema)
{
return new SingleValueRowMapper(schema, this, Schema.Create(new SchemaImpl(ScoreType, _quantiles)));
}
private protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema) =>
new SingleValueRowMapper(schema, this, ScoreSchemaFactory.CreateQuantileRegressionSchema(ScoreType, _quantiles));

protected override Delegate GetPredictionGetter(Row input, int colSrc)
{
Expand Down Expand Up @@ -680,80 +675,5 @@ protected override Delegate GetPredictionGetter(Row input, int colSrc)
};
return del;
}

private sealed class SchemaImpl : ScoreMapperSchemaBase
{
private readonly string[] _slotNames;
private readonly MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>> _getSlotNames;

public SchemaImpl(ColumnType scoreType, Double[] quantiles)
: base(scoreType, MetadataUtils.Const.ScoreColumnKind.QuantileRegression)
{
Contracts.Assert(Utils.Size(quantiles) > 0);
_slotNames = new string[quantiles.Length];
for (int i = 0; i < _slotNames.Length; i++)
_slotNames[i] = string.Format("Quantile-{0}", quantiles[i]);
_getSlotNames = GetSlotNames;
}

public override int ColumnCount { get { return 1; } }

public override IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
{
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
Contracts.Assert(Utils.Size(_slotNames) > 0);
Contracts.Assert(col == 0);

var items = base.GetMetadataTypes(col);
items = items.Prepend(MetadataUtils.GetSlotNamesPair(_slotNames.Length));
return items;
}

public override ColumnType GetMetadataTypeOrNull(string kind, int col)
{
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
Contracts.CheckNonEmpty(kind, nameof(kind));
Contracts.Assert(Utils.Size(_slotNames) > 0);
Contracts.Assert(col == 0);

if (kind == MetadataUtils.Kinds.SlotNames)
return MetadataUtils.GetNamesType(_slotNames.Length);
return base.GetMetadataTypeOrNull(kind, col);
}

public override void GetMetadata<TValue>(string kind, int col, ref TValue value)
{
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
Contracts.CheckNonEmpty(kind, nameof(kind));
Contracts.Assert(Utils.Size(_slotNames) > 0);
Contracts.Assert(col == 0);
Contracts.Assert(_getSlotNames != null);

if (kind == MetadataUtils.Kinds.SlotNames)
_getSlotNames.Marshal(col, ref value);
else
base.GetMetadata(kind, col, ref value);
}

public override ColumnType GetColumnType(int col)
{
Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));
Contracts.Assert(col == 0);
Contracts.Assert(Utils.Size(_slotNames) > 0);
return new VectorType(NumberType.Float, _slotNames.Length);
}

private void GetSlotNames(int iinfo, ref VBuffer<ReadOnlyMemory<char>> dst)
{
Contracts.Assert(iinfo == 0);
Contracts.Assert(Utils.Size(_slotNames) > 0);

int size = Utils.Size(_slotNames);
var editor = VBufferEditor.Create(ref dst, size);
for (int i = 0; i < _slotNames.Length; i++)
editor.Values[i] = _slotNames[i].AsMemory();
dst = editor.Commit();
}
}
}
}
Loading

0 comments on commit d26510f

Please sign in to comment.