Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/code/MlNetCookBook.md
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,7 @@ IEstimator<ITransformer> dynamicPipe = learningPipeline.AsDynamic;
var binaryTrainer = mlContext.BinaryClassification.Trainers.AveragedPerceptron("Label", "Features");

// Append the OVA learner to the pipeline.
dynamicPipe = dynamicPipe.Append(new Ova(mlContext, binaryTrainer));
dynamicPipe = dynamicPipe.Append(mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer));

// At this point, we have a choice. We could continue working with the dynamically-typed pipeline, and
// ultimately call dynamicPipe.Fit(data.AsDynamic) to get the model, or we could go back into the static world.
Expand Down
21 changes: 3 additions & 18 deletions src/Microsoft.ML.Core/Prediction/IPredictor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,27 +44,12 @@ public interface IPredictorProducing<out TResult> : IPredictor
{
}

/// <summary>
/// Strongly typed generic predictor that takes data instances (feature containers)
/// and produces predictions for them.
/// </summary>
/// <typeparam name="TFeatures"> Type of features container (instance) on which to make predictions</typeparam>
/// <typeparam name="TResult"> Type of prediction result</typeparam>
public interface IPredictor<in TFeatures, out TResult> : IPredictorProducing<TResult>
{
/// <summary>
/// Produce a prediction for provided features
/// </summary>
/// <param name="features"> Data instance </param>
/// <returns> Prediction </returns>
TResult Predict(TFeatures features);
}

/// <summary>
/// A predictor that produces values and distributions of the indicated types.
/// REVIEW: Determine whether this is just a temporary shim or long term solution.
/// Note that from a public API perspective this is bad.
/// </summary>
public interface IDistPredictorProducing<out TResult, out TResultDistribution> : IPredictorProducing<TResult>
[BestFriend]
internal interface IDistPredictorProducing<out TResult, out TResultDistribution> : IPredictorProducing<TResult>
{
}
}
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Data/Dirty/ModelParametersBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ private protected virtual void SaveCore(ModelSaveContext ctx)
/// <summary>
/// This emits a warning if there is Normalizer sub-model.
/// </summary>
public static bool WarnOnOldNormalizer(ModelLoadContext ctx, Type typePredictor, IChannelProvider provider)
[BestFriend]
private protected static bool WarnOnOldNormalizer(ModelLoadContext ctx, Type typePredictor, IChannelProvider provider)
{
Contracts.CheckValue(provider, nameof(provider));
provider.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/EntryPoints/InputBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public abstract class LearnerInputBase
{
/// <summary>
/// The data to be used for training. Used only in entry-points, since in the API the expected mechanism is
/// that the user iwll use the <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> or some other train
/// that the user will use the <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> or some other train
/// method.
/// </summary>
[BestFriend]
Expand Down
11 changes: 5 additions & 6 deletions src/Microsoft.ML.Data/Prediction/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ internal interface ISelfCalibratingPredictor
IPredictor Calibrate(IChannel ch, IDataView data, ICalibratorTrainer caliTrainer, int maxRows);
}

[BestFriend]
public abstract class CalibratedPredictorBase :
IDistPredictorProducing<float, float>,
ICanSaveInIniFormat,
Expand All @@ -131,7 +130,7 @@ public abstract class CalibratedPredictorBase :
public ICalibrator Calibrator { get; }
public PredictionKind PredictionKind => SubPredictor.PredictionKind;

protected CalibratedPredictorBase(IHostEnvironment env, string name, IPredictorProducing<float> predictor, ICalibrator calibrator)
private protected CalibratedPredictorBase(IHostEnvironment env, string name, IPredictorProducing<float> predictor, ICalibrator calibrator)
{
Contracts.CheckValue(env, nameof(env));
env.CheckNonWhiteSpace(name, nameof(name));
Expand Down Expand Up @@ -185,20 +184,20 @@ IList<KeyValuePair<string, object>> ICanGetSummaryInKeyValuePairs.GetSummaryInKe
return null;
}

protected void SaveCore(ModelSaveContext ctx)
private protected void SaveCore(ModelSaveContext ctx)
{
ctx.SaveModel(SubPredictor, ModelFileUtils.DirPredictor);
ctx.SaveModel(Calibrator, @"Calibrator");
}

protected static IPredictorProducing<float> GetPredictor(IHostEnvironment env, ModelLoadContext ctx)
private protected static IPredictorProducing<float> GetPredictor(IHostEnvironment env, ModelLoadContext ctx)
{
IPredictorProducing<float> predictor;
ctx.LoadModel<IPredictorProducing<float>, SignatureLoadModel>(env, out predictor, ModelFileUtils.DirPredictor);
return predictor;
}

protected static ICalibrator GetCalibrator(IHostEnvironment env, ModelLoadContext ctx)
private protected static ICalibrator GetCalibrator(IHostEnvironment env, ModelLoadContext ctx)
{
ICalibrator calibrator;
ctx.LoadModel<ICalibrator, SignatureLoadModel>(env, out calibrator, @"Calibrator");
Expand All @@ -221,7 +220,7 @@ public abstract class ValueMapperCalibratedPredictorBase : CalibratedPredictorBa

bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (_mapper as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;

protected ValueMapperCalibratedPredictorBase(IHostEnvironment env, string name, IPredictorProducing<float> predictor, ICalibrator calibrator)
private protected ValueMapperCalibratedPredictorBase(IHostEnvironment env, string name, IPredictorProducing<float> predictor, ICalibrator calibrator)
: base(env, name, predictor, calibrator)
{
Contracts.AssertValue(Host);
Expand Down
59 changes: 35 additions & 24 deletions src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,15 @@ namespace Microsoft.ML.Data
/// <typeparam name="TScorer">The Scorer used by this <see cref="IPredictionTransformer{TModel}"/></typeparam>
public abstract class PredictionTransformerBase<TModel, TScorer> : IPredictionTransformer<TModel>
where TScorer : RowToRowScorerBase
where TModel : class, IPredictor
where TModel : class
{
/// <summary>
/// The model.
/// </summary>
public TModel Model { get; }

private protected IPredictor ModelAsPredictor => (IPredictor)Model;

protected const string DirModel = "Model";
protected const string DirTransSchema = "TrainSchema";
protected readonly IHost Host;
Expand All @@ -55,19 +57,22 @@ public abstract class PredictionTransformerBase<TModel, TScorer> : IPredictionTr

protected abstract TScorer Scorer { get; set; }

protected PredictionTransformerBase(IHost host, TModel model, Schema trainSchema)
[BestFriend]
private protected PredictionTransformerBase(IHost host, TModel model, Schema trainSchema)
{
Contracts.CheckValue(host, nameof(host));
Host = host;

Host.CheckValue(trainSchema, nameof(trainSchema));
Host.CheckValue(model, nameof(model));
Host.CheckParam(model is IPredictor, nameof(model));
Model = model;

Host.CheckValue(trainSchema, nameof(trainSchema));
TrainSchema = trainSchema;
}

protected PredictionTransformerBase(IHost host, ModelLoadContext ctx)
[BestFriend]
private protected PredictionTransformerBase(IHost host, ModelLoadContext ctx)
{
Host = host;

Expand Down Expand Up @@ -146,7 +151,7 @@ protected void SaveModel(ModelSaveContext ctx)
/// <typeparam name="TModel">The model used to transform the data.</typeparam>
/// <typeparam name="TScorer">The scorer used on this PredictionTransformer.</typeparam>
public abstract class SingleFeaturePredictionTransformerBase<TModel, TScorer> : PredictionTransformerBase<TModel, TScorer>, ISingleFeaturePredictionTransformer<TModel>, ICanSaveModel
where TModel : class, IPredictor
where TModel : class
where TScorer : RowToRowScorerBase
{
/// <summary>
Expand All @@ -168,7 +173,7 @@ public abstract class SingleFeaturePredictionTransformerBase<TModel, TScorer> :
/// <param name="model">The model used for scoring.</param>
/// <param name="trainSchema">The schema of the training data.</param>
/// <param name="featureColumn">The feature column name.</param>
public SingleFeaturePredictionTransformerBase(IHost host, TModel model, Schema trainSchema, string featureColumn)
private protected SingleFeaturePredictionTransformerBase(IHost host, TModel model, Schema trainSchema, string featureColumn)
: base(host, model, trainSchema)
{
FeatureColumn = featureColumn;
Expand All @@ -179,10 +184,10 @@ public SingleFeaturePredictionTransformerBase(IHost host, TModel model, Schema t
else
FeatureColumnType = trainSchema[col].Type;

BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, ModelAsPredictor);
}

internal SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx)
private protected SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx)
: base(host, ctx)
{
FeatureColumn = ctx.LoadStringOrNull();
Expand All @@ -194,7 +199,7 @@ internal SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx
else
FeatureColumnType = TrainSchema[col].Type;

BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model);
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, ModelAsPredictor);
}

public override Schema GetOutputSchema(Schema inputSchema)
Expand Down Expand Up @@ -237,12 +242,13 @@ protected virtual GenericScorer GetGenericScorer()
/// </summary>
/// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam>
public sealed class AnomalyPredictionTransformer<TModel> : SingleFeaturePredictionTransformerBase<TModel, BinaryClassifierScorer>
where TModel : class, IPredictorProducing<float>
where TModel : class
{
public readonly string ThresholdColumn;
public readonly float Threshold;

public AnomalyPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn,
[BestFriend]
internal AnomalyPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn,
float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(AnomalyPredictionTransformer<TModel>)),model, inputSchema, featureColumn)
{
Expand All @@ -253,8 +259,8 @@ public AnomalyPredictionTransformer(IHostEnvironment env, TModel model, Schema i
SetScorer();
}

public AnomalyPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer<TModel>)), ctx)
internal AnomalyPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(AnomalyPredictionTransformer<TModel>)), ctx)
{
// *** Binary format ***
// <base info>
Expand Down Expand Up @@ -305,12 +311,13 @@ private static VersionInfo GetVersionInfo()
/// </summary>
/// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam>
public sealed class BinaryPredictionTransformer<TModel> : SingleFeaturePredictionTransformerBase<TModel, BinaryClassifierScorer>
where TModel : class, IPredictorProducing<float>
where TModel : class
{
public readonly string ThresholdColumn;
public readonly float Threshold;

public BinaryPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn,
[BestFriend]
internal BinaryPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn,
float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer<TModel>)), model, inputSchema, featureColumn)
{
Expand Down Expand Up @@ -373,11 +380,12 @@ private static VersionInfo GetVersionInfo()
/// </summary>
/// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam>
public sealed class MulticlassPredictionTransformer<TModel> : SingleFeaturePredictionTransformerBase<TModel, MultiClassClassifierScorer>
where TModel : class, IPredictorProducing<VBuffer<float>>
where TModel : class
{
private readonly string _trainLabelColumn;

public MulticlassPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn, string labelColumn)
[BestFriend]
internal MulticlassPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn, string labelColumn)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MulticlassPredictionTransformer<TModel>)), model, inputSchema, featureColumn)
{
Host.CheckValueOrNull(labelColumn);
Expand Down Expand Up @@ -434,9 +442,10 @@ private static VersionInfo GetVersionInfo()
/// </summary>
/// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam>
public sealed class RegressionPredictionTransformer<TModel> : SingleFeaturePredictionTransformerBase<TModel, GenericScorer>
where TModel : class, IPredictorProducing<float>
where TModel : class
{
public RegressionPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn)
[BestFriend]
internal RegressionPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RegressionPredictionTransformer<TModel>)), model, inputSchema, featureColumn)
{
Scorer = GetGenericScorer();
Expand Down Expand Up @@ -475,9 +484,10 @@ private static VersionInfo GetVersionInfo()
/// </summary>
/// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam>
public sealed class RankingPredictionTransformer<TModel> : SingleFeaturePredictionTransformerBase<TModel, GenericScorer>
where TModel : class, IPredictorProducing<float>
where TModel : class
{
public RankingPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn)
[BestFriend]
internal RankingPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RankingPredictionTransformer<TModel>)), model, inputSchema, featureColumn)
{
Scorer = GetGenericScorer();
Expand Down Expand Up @@ -516,9 +526,10 @@ private static VersionInfo GetVersionInfo()
/// </summary>
/// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam>
public sealed class ClusteringPredictionTransformer<TModel> : SingleFeaturePredictionTransformerBase<TModel, ClusteringScorer>
where TModel : class, IPredictorProducing<VBuffer<float>>
where TModel : class
{
public ClusteringPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn,
[BestFriend]
internal ClusteringPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn,
float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ClusteringPredictionTransformer<TModel>)), model, inputSchema, featureColumn)
{
Expand All @@ -529,7 +540,7 @@ public ClusteringPredictionTransformer(IHostEnvironment env, TModel model, Schem
Scorer = new ClusteringScorer(Host, args, new EmptyDataView(Host, inputSchema), BindableMapper.Bind(Host, schema), schema);
}

public ClusteringPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
internal ClusteringPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ClusteringPredictionTransformer<TModel>)), ctx)
{
// *** Binary format ***
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Training/ITrainerEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace Microsoft.ML.Training
{
public interface ITrainerEstimator<out TTransformer, out TPredictor> : IEstimator<TTransformer>
where TTransformer : ISingleFeaturePredictionTransformer<TPredictor>
where TPredictor : IPredictor
where TPredictor : class
{
TrainerInfo Info { get; }

Expand Down
15 changes: 9 additions & 6 deletions src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ namespace Microsoft.ML.Training
/// A 'simple trainer' accepts one feature column and one label column, also optionally a weight column.
/// It produces a 'prediction transformer'.
/// </summary>
public abstract class TrainerEstimatorBase<TTransformer, TModel> : ITrainerEstimator<TTransformer, TModel>, ITrainer<TModel>
public abstract class TrainerEstimatorBase<TTransformer, TModel> : ITrainerEstimator<TTransformer, TModel>, ITrainer<IPredictor>
where TTransformer : ISingleFeaturePredictionTransformer<TModel>
where TModel : IPredictor
where TModel : class
{
/// <summary>
/// A standard string to use in errors or warnings by subclasses, to communicate the idea that no valid
Expand Down Expand Up @@ -85,10 +85,12 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
/// </summary>
protected abstract SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema);

TModel ITrainer<TModel>.Train(TrainContext context)
IPredictor ITrainer<IPredictor>.Train(TrainContext context)
{
Host.CheckValue(context, nameof(context));
return TrainModelCore(context);
var pred = TrainModelCore(context) as IPredictor;
Host.Check(pred != null, "Training did not return a predictor.");
return pred;
}

private void CheckInputSchema(SchemaShape inputSchema)
Expand Down Expand Up @@ -147,7 +149,7 @@ protected TTransformer TrainTransformer(IDataView trainSet,
private protected virtual RoleMappedData MakeRoles(IDataView data) =>
new RoleMappedData(data, label: LabelColumn.Name, feature: FeatureColumn.Name, weight: WeightColumn.Name);

IPredictor ITrainer.Train(TrainContext context) => ((ITrainer<TModel>)this).Train(context);
IPredictor ITrainer.Train(TrainContext context) => ((ITrainer<IPredictor>)this).Train(context);
}

/// <summary>
Expand All @@ -157,7 +159,8 @@ private protected virtual RoleMappedData MakeRoles(IDataView data) =>
/// </summary>
public abstract class TrainerEstimatorBaseWithGroupId<TTransformer, TModel> : TrainerEstimatorBase<TTransformer, TModel>
where TTransformer : ISingleFeaturePredictionTransformer<TModel>
where TModel : IPredictor
where TModel : class

{
/// <summary>
/// The optional groupID column that the ranking trainers expects.
Expand Down
Loading