diff --git a/docs/code/MlNetCookBook.md b/docs/code/MlNetCookBook.md index f7c88d9b0e..61d6858cc1 100644 --- a/docs/code/MlNetCookBook.md +++ b/docs/code/MlNetCookBook.md @@ -885,7 +885,7 @@ IEstimator dynamicPipe = learningPipeline.AsDynamic; var binaryTrainer = mlContext.BinaryClassification.Trainers.AveragedPerceptron("Label", "Features"); // Append the OVA learner to the pipeline. -dynamicPipe = dynamicPipe.Append(new Ova(mlContext, binaryTrainer)); +dynamicPipe = dynamicPipe.Append(mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer)); // At this point, we have a choice. We could continue working with the dynamically-typed pipeline, and // ultimately call dynamicPipe.Fit(data.AsDynamic) to get the model, or we could go back into the static world. diff --git a/src/Microsoft.ML.Core/Prediction/IPredictor.cs b/src/Microsoft.ML.Core/Prediction/IPredictor.cs index 682cec417d..eccd0423be 100644 --- a/src/Microsoft.ML.Core/Prediction/IPredictor.cs +++ b/src/Microsoft.ML.Core/Prediction/IPredictor.cs @@ -44,27 +44,12 @@ public interface IPredictorProducing : IPredictor { } - /// - /// Strongly typed generic predictor that takes data instances (feature containers) - /// and produces predictions for them. - /// - /// Type of features container (instance) on which to make predictions - /// Type of prediction result - public interface IPredictor : IPredictorProducing - { - /// - /// Produce a prediction for provided features - /// - /// Data instance - /// Prediction - TResult Predict(TFeatures features); - } - /// /// A predictor that produces values and distributions of the indicated types. - /// REVIEW: Determine whether this is just a temporary shim or long term solution. + /// Note that from a public API perspective this is bad. 
/// - public interface IDistPredictorProducing : IPredictorProducing + [BestFriend] + internal interface IDistPredictorProducing : IPredictorProducing { } } diff --git a/src/Microsoft.ML.Data/Dirty/ModelParametersBase.cs b/src/Microsoft.ML.Data/Dirty/ModelParametersBase.cs index 0366772768..b86df7d35b 100644 --- a/src/Microsoft.ML.Data/Dirty/ModelParametersBase.cs +++ b/src/Microsoft.ML.Data/Dirty/ModelParametersBase.cs @@ -69,7 +69,8 @@ private protected virtual void SaveCore(ModelSaveContext ctx) /// /// This emits a warning if there is Normalizer sub-model. /// - public static bool WarnOnOldNormalizer(ModelLoadContext ctx, Type typePredictor, IChannelProvider provider) + [BestFriend] + private protected static bool WarnOnOldNormalizer(ModelLoadContext ctx, Type typePredictor, IChannelProvider provider) { Contracts.CheckValue(provider, nameof(provider)); provider.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs index 10115f2565..435a1ba1a8 100644 --- a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs +++ b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs @@ -45,7 +45,7 @@ public abstract class LearnerInputBase { /// /// The data to be used for training. Used only in entry-points, since in the API the expected mechanism is - /// that the user iwll use the or some other train + /// that the user will use the or some other train /// method. 
/// [BestFriend] diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index ae023af9f3..6aa21bd53b 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -116,7 +116,6 @@ internal interface ISelfCalibratingPredictor IPredictor Calibrate(IChannel ch, IDataView data, ICalibratorTrainer caliTrainer, int maxRows); } - [BestFriend] public abstract class CalibratedPredictorBase : IDistPredictorProducing, ICanSaveInIniFormat, @@ -131,7 +130,7 @@ public abstract class CalibratedPredictorBase : public ICalibrator Calibrator { get; } public PredictionKind PredictionKind => SubPredictor.PredictionKind; - protected CalibratedPredictorBase(IHostEnvironment env, string name, IPredictorProducing predictor, ICalibrator calibrator) + private protected CalibratedPredictorBase(IHostEnvironment env, string name, IPredictorProducing predictor, ICalibrator calibrator) { Contracts.CheckValue(env, nameof(env)); env.CheckNonWhiteSpace(name, nameof(name)); @@ -185,20 +184,20 @@ IList> ICanGetSummaryInKeyValuePairs.GetSummaryInKe return null; } - protected void SaveCore(ModelSaveContext ctx) + private protected void SaveCore(ModelSaveContext ctx) { ctx.SaveModel(SubPredictor, ModelFileUtils.DirPredictor); ctx.SaveModel(Calibrator, @"Calibrator"); } - protected static IPredictorProducing GetPredictor(IHostEnvironment env, ModelLoadContext ctx) + private protected static IPredictorProducing GetPredictor(IHostEnvironment env, ModelLoadContext ctx) { IPredictorProducing predictor; ctx.LoadModel, SignatureLoadModel>(env, out predictor, ModelFileUtils.DirPredictor); return predictor; } - protected static ICalibrator GetCalibrator(IHostEnvironment env, ModelLoadContext ctx) + private protected static ICalibrator GetCalibrator(IHostEnvironment env, ModelLoadContext ctx) { ICalibrator calibrator; ctx.LoadModel(env, out calibrator, @"Calibrator"); @@ -221,7 +220,7 @@ public abstract 
class ValueMapperCalibratedPredictorBase : CalibratedPredictorBa bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (_mapper as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; - protected ValueMapperCalibratedPredictorBase(IHostEnvironment env, string name, IPredictorProducing predictor, ICalibrator calibrator) + private protected ValueMapperCalibratedPredictorBase(IHostEnvironment env, string name, IPredictorProducing predictor, ICalibrator calibrator) : base(env, name, predictor, calibrator) { Contracts.AssertValue(Host); diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index fbc801410e..12d21b516c 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -37,13 +37,15 @@ namespace Microsoft.ML.Data /// The Scorer used by this public abstract class PredictionTransformerBase : IPredictionTransformer where TScorer : RowToRowScorerBase - where TModel : class, IPredictor + where TModel : class { /// /// The model. 
/// public TModel Model { get; } + private protected IPredictor ModelAsPredictor => (IPredictor)Model; + protected const string DirModel = "Model"; protected const string DirTransSchema = "TrainSchema"; protected readonly IHost Host; @@ -55,19 +57,22 @@ public abstract class PredictionTransformerBase : IPredictionTr protected abstract TScorer Scorer { get; set; } - protected PredictionTransformerBase(IHost host, TModel model, Schema trainSchema) + [BestFriend] + private protected PredictionTransformerBase(IHost host, TModel model, Schema trainSchema) { Contracts.CheckValue(host, nameof(host)); Host = host; - Host.CheckValue(trainSchema, nameof(trainSchema)); + Host.CheckValue(model, nameof(model)); + Host.CheckParam(model is IPredictor, nameof(model)); Model = model; Host.CheckValue(trainSchema, nameof(trainSchema)); TrainSchema = trainSchema; } - protected PredictionTransformerBase(IHost host, ModelLoadContext ctx) + [BestFriend] + private protected PredictionTransformerBase(IHost host, ModelLoadContext ctx) { Host = host; @@ -146,7 +151,7 @@ protected void SaveModel(ModelSaveContext ctx) /// The model used to transform the data. /// The scorer used on this PredictionTransformer. public abstract class SingleFeaturePredictionTransformerBase : PredictionTransformerBase, ISingleFeaturePredictionTransformer, ICanSaveModel - where TModel : class, IPredictor + where TModel : class where TScorer : RowToRowScorerBase { /// @@ -168,7 +173,7 @@ public abstract class SingleFeaturePredictionTransformerBase : /// The model used for scoring. /// The schema of the training data. /// The feature column name. 
- public SingleFeaturePredictionTransformerBase(IHost host, TModel model, Schema trainSchema, string featureColumn) + private protected SingleFeaturePredictionTransformerBase(IHost host, TModel model, Schema trainSchema, string featureColumn) : base(host, model, trainSchema) { FeatureColumn = featureColumn; @@ -179,10 +184,10 @@ public SingleFeaturePredictionTransformerBase(IHost host, TModel model, Schema t else FeatureColumnType = trainSchema[col].Type; - BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model); + BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, ModelAsPredictor); } - internal SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx) + private protected SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx) : base(host, ctx) { FeatureColumn = ctx.LoadStringOrNull(); @@ -194,7 +199,7 @@ internal SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx else FeatureColumnType = TrainSchema[col].Type; - BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model); + BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, ModelAsPredictor); } public override Schema GetOutputSchema(Schema inputSchema) @@ -237,12 +242,13 @@ protected virtual GenericScorer GetGenericScorer() /// /// An implementation of the public sealed class AnomalyPredictionTransformer : SingleFeaturePredictionTransformerBase - where TModel : class, IPredictorProducing + where TModel : class { public readonly string ThresholdColumn; public readonly float Threshold; - public AnomalyPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn, + [BestFriend] + internal AnomalyPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn, float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(AnomalyPredictionTransformer)),model, inputSchema, featureColumn) { @@ 
-253,8 +259,8 @@ public AnomalyPredictionTransformer(IHostEnvironment env, TModel model, Schema i SetScorer(); } - public AnomalyPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer)), ctx) + internal AnomalyPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(AnomalyPredictionTransformer)), ctx) { // *** Binary format *** // @@ -305,12 +311,13 @@ private static VersionInfo GetVersionInfo() /// /// An implementation of the public sealed class BinaryPredictionTransformer : SingleFeaturePredictionTransformerBase - where TModel : class, IPredictorProducing + where TModel : class { public readonly string ThresholdColumn; public readonly float Threshold; - public BinaryPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn, + [BestFriend] + internal BinaryPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn, float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer)), model, inputSchema, featureColumn) { @@ -373,11 +380,12 @@ private static VersionInfo GetVersionInfo() /// /// An implementation of the public sealed class MulticlassPredictionTransformer : SingleFeaturePredictionTransformerBase - where TModel : class, IPredictorProducing> + where TModel : class { private readonly string _trainLabelColumn; - public MulticlassPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn, string labelColumn) + [BestFriend] + internal MulticlassPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn, string labelColumn) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MulticlassPredictionTransformer)), model, inputSchema, 
featureColumn) { Host.CheckValueOrNull(labelColumn); @@ -434,9 +442,10 @@ private static VersionInfo GetVersionInfo() /// /// An implementation of the public sealed class RegressionPredictionTransformer : SingleFeaturePredictionTransformerBase - where TModel : class, IPredictorProducing + where TModel : class { - public RegressionPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn) + [BestFriend] + internal RegressionPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RegressionPredictionTransformer)), model, inputSchema, featureColumn) { Scorer = GetGenericScorer(); @@ -475,9 +484,10 @@ private static VersionInfo GetVersionInfo() /// /// An implementation of the public sealed class RankingPredictionTransformer : SingleFeaturePredictionTransformerBase - where TModel : class, IPredictorProducing + where TModel : class { - public RankingPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn) + [BestFriend] + internal RankingPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RankingPredictionTransformer)), model, inputSchema, featureColumn) { Scorer = GetGenericScorer(); @@ -516,9 +526,10 @@ private static VersionInfo GetVersionInfo() /// /// An implementation of the public sealed class ClusteringPredictionTransformer : SingleFeaturePredictionTransformerBase - where TModel : class, IPredictorProducing> + where TModel : class { - public ClusteringPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn, + [BestFriend] + internal ClusteringPredictionTransformer(IHostEnvironment env, TModel model, Schema inputSchema, string featureColumn, float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score) : 
base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ClusteringPredictionTransformer)), model, inputSchema, featureColumn) { @@ -529,7 +540,7 @@ public ClusteringPredictionTransformer(IHostEnvironment env, TModel model, Schem Scorer = new ClusteringScorer(Host, args, new EmptyDataView(Host, inputSchema), BindableMapper.Bind(Host, schema), schema); } - public ClusteringPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx) + internal ClusteringPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ClusteringPredictionTransformer)), ctx) { // *** Binary format *** diff --git a/src/Microsoft.ML.Data/Training/ITrainerEstimator.cs b/src/Microsoft.ML.Data/Training/ITrainerEstimator.cs index 0384203a9e..f02d8afa9e 100644 --- a/src/Microsoft.ML.Data/Training/ITrainerEstimator.cs +++ b/src/Microsoft.ML.Data/Training/ITrainerEstimator.cs @@ -8,7 +8,7 @@ namespace Microsoft.ML.Training { public interface ITrainerEstimator : IEstimator where TTransformer : ISingleFeaturePredictionTransformer - where TPredictor : IPredictor + where TPredictor : class { TrainerInfo Info { get; } diff --git a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs index 7db8d2e47d..4100987977 100644 --- a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs +++ b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs @@ -14,9 +14,9 @@ namespace Microsoft.ML.Training /// A 'simple trainer' accepts one feature column and one label column, also optionally a weight column. /// It produces a 'prediction transformer'. 
/// - public abstract class TrainerEstimatorBase : ITrainerEstimator, ITrainer + public abstract class TrainerEstimatorBase : ITrainerEstimator, ITrainer where TTransformer : ISingleFeaturePredictionTransformer - where TModel : IPredictor + where TModel : class { /// /// A standard string to use in errors or warnings by subclasses, to communicate the idea that no valid @@ -85,10 +85,12 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) /// protected abstract SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema); - TModel ITrainer.Train(TrainContext context) + IPredictor ITrainer.Train(TrainContext context) { Host.CheckValue(context, nameof(context)); - return TrainModelCore(context); + var pred = TrainModelCore(context) as IPredictor; + Host.Check(pred != null, "Training did not return a predictor."); + return pred; } private void CheckInputSchema(SchemaShape inputSchema) @@ -147,7 +149,7 @@ protected TTransformer TrainTransformer(IDataView trainSet, private protected virtual RoleMappedData MakeRoles(IDataView data) => new RoleMappedData(data, label: LabelColumn.Name, feature: FeatureColumn.Name, weight: WeightColumn.Name); - IPredictor ITrainer.Train(TrainContext context) => ((ITrainer)this).Train(context); + IPredictor ITrainer.Train(TrainContext context) => ((ITrainer)this).Train(context); } /// @@ -157,7 +159,8 @@ private protected virtual RoleMappedData MakeRoles(IDataView data) => /// public abstract class TrainerEstimatorBaseWithGroupId : TrainerEstimatorBase where TTransformer : ISingleFeaturePredictionTransformer - where TModel : IPredictor + where TModel : class + { /// /// The optional groupID column that the ranking trainers expects. 
diff --git a/src/Microsoft.ML.Data/Training/TrainerUtils.cs b/src/Microsoft.ML.Data/Training/TrainerUtils.cs index 77c20c46d7..eb2d653658 100644 --- a/src/Microsoft.ML.Data/Training/TrainerUtils.cs +++ b/src/Microsoft.ML.Data/Training/TrainerUtils.cs @@ -10,6 +10,7 @@ using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Transforms; namespace Microsoft.ML.Training { @@ -398,6 +399,93 @@ public static SchemaShape.Column MakeR4ScalarWeightColumn(Optional weigh return default; return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); } + + /// + /// This is a shim class to translate the more contemporaneous + /// style transformers into the older now disfavored idiom, for components that still + /// need to operate via that older mechanism. (Mostly command line invocations, and so on.). + /// + /// The type of the new model parameters. + /// The type corresponding to the legacy predictor. + private sealed class TrainerEstimatorToTrainerShim : ITrainer + where TModel : class, TPredictor + where TPredictor : IPredictor + { + public TrainerInfo Info { get; } + public PredictionKind PredictionKind { get; } + + private readonly ITrainerEstimator, TModel> _trainer; + private readonly IHostEnvironment _env; + + public TrainerEstimatorToTrainerShim(IHostEnvironment env, ITrainerEstimator, TModel> trainer) + { + Contracts.AssertValue(env); + _env = env; + _env.AssertValue(trainer); + _env.Assert(trainer is ITrainer); + + var oldTrainer = (ITrainer)trainer; + Info = oldTrainer.Info; + PredictionKind = oldTrainer.PredictionKind; + + _trainer = trainer; + } + + public TPredictor Train(TrainContext context) + { + _env.CheckValue(context, nameof(context)); + // For the purpose of mapping into the estimator, we assume that the input estimator does not have + // any custom overrides for the column names defined. 
+ var tschema = context.TrainingSet.Schema; + var nameMap = new List<(string outName, string inName)>(); + if (tschema.Feature?.Name is string fname && fname != DefaultColumnNames.Features) + nameMap.Add((DefaultColumnNames.Features, fname)); + if (tschema.Label?.Name is string lname && lname != DefaultColumnNames.Label) + nameMap.Add((DefaultColumnNames.Label, lname)); + if (tschema.Weight?.Name is string wname && wname != DefaultColumnNames.Weight) + nameMap.Add((DefaultColumnNames.Weight, wname)); + if (tschema.Group?.Name is string gname && gname != DefaultColumnNames.GroupId) + nameMap.Add((DefaultColumnNames.GroupId, gname)); + if (tschema.Group?.Name is string iname && iname != DefaultColumnNames.Item) + nameMap.Add((DefaultColumnNames.Item, iname)); + if (tschema.Group?.Name is string uname && uname != DefaultColumnNames.User) + nameMap.Add((DefaultColumnNames.User, uname)); + + var data = context.TrainingSet.Data; + if (nameMap.Count > 0) + { + var estimator = new ColumnCopyingEstimator(_env, nameMap.ToArray()); + data = estimator.Fit(data).Transform(data); + } + var predictionTransformer = _trainer.Fit(data); + var model = predictionTransformer.Model; + if (model is TPredictor pred) + return pred; + throw _env.Except($"Training resulted in a model of type {model.GetType().Name}."); + } + + IPredictor ITrainer.Train(TrainContext context) => Train(context); + } + + /// + /// This is a shim for legacy code that takes the more modern + /// interface, and maps it to the legacy code that wants an . The goal should be to + /// remove reliance on that interface if possible, but this may not be practical in the immediate term, so for the benefit + /// of scenarios like this we have this convenience function. + /// + /// The trainer estimator type. + /// The type of the model produced by the estimator. + /// The type of the predictor to be produced by the predictor. + /// The host environment. + /// The trainer estimator. 
+ /// An implementation of the legacy trainer interface. + public static ITrainer MapTrainerEstimatorToTrainer(IHostEnvironment env, T trainer) + where T : ITrainerEstimator, TModel>, ITrainer + where TModel : class, TPredictor + where TPredictor : IPredictor + { + return new TrainerEstimatorToTrainerShim(env, trainer); + } } /// diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/DiversityMeasure.cs b/src/Microsoft.ML.Ensemble/EntryPoints/DiversityMeasure.cs index f295697f6e..83d8fd0a96 100644 --- a/src/Microsoft.ML.Ensemble/EntryPoints/DiversityMeasure.cs +++ b/src/Microsoft.ML.Ensemble/EntryPoints/DiversityMeasure.cs @@ -14,19 +14,19 @@ namespace Microsoft.ML.Ensemble.EntryPoints { [TlcModule.Component(Name = DisagreementDiversityMeasure.LoadName, FriendlyName = DisagreementDiversityMeasure.UserName)] - public sealed class DisagreementDiversityFactory : ISupportBinaryDiversityMeasureFactory + internal sealed class DisagreementDiversityFactory : ISupportBinaryDiversityMeasureFactory { public IBinaryDiversityMeasure CreateComponent(IHostEnvironment env) => new DisagreementDiversityMeasure(); } [TlcModule.Component(Name = RegressionDisagreementDiversityMeasure.LoadName, FriendlyName = DisagreementDiversityMeasure.UserName)] - public sealed class RegressionDisagreementDiversityFactory : ISupportRegressionDiversityMeasureFactory + internal sealed class RegressionDisagreementDiversityFactory : ISupportRegressionDiversityMeasureFactory { public IRegressionDiversityMeasure CreateComponent(IHostEnvironment env) => new RegressionDisagreementDiversityMeasure(); } [TlcModule.Component(Name = MultiDisagreementDiversityMeasure.LoadName, FriendlyName = DisagreementDiversityMeasure.UserName)] - public sealed class MultiDisagreementDiversityFactory : ISupportMulticlassDiversityMeasureFactory + internal sealed class MultiDisagreementDiversityFactory : ISupportMulticlassDiversityMeasureFactory { public IMulticlassDiversityMeasure CreateComponent(IHostEnvironment env) => new 
MultiDisagreementDiversityMeasure(); } diff --git a/src/Microsoft.ML.Ensemble/FeatureSubsetModel.cs b/src/Microsoft.ML.Ensemble/FeatureSubsetModel.cs index fadd096810..733573d39f 100644 --- a/src/Microsoft.ML.Ensemble/FeatureSubsetModel.cs +++ b/src/Microsoft.ML.Ensemble/FeatureSubsetModel.cs @@ -8,18 +8,23 @@ namespace Microsoft.ML.Ensemble { - public sealed class FeatureSubsetModel where TPredictor : IPredictor + internal sealed class FeatureSubsetModel { - public readonly TPredictor Predictor; + public readonly IPredictorProducing Predictor; public readonly BitArray SelectedFeatures; public readonly int Cardinality; public KeyValuePair[] Metrics { get; set; } - public FeatureSubsetModel(TPredictor predictor, BitArray features = null, + public FeatureSubsetModel(IPredictorProducing predictor, BitArray features = null, KeyValuePair[] metrics = null) { - Predictor = predictor; + if (!(predictor is IPredictorProducing predictorProducing)) + { + throw Contracts.ExceptParam(nameof(predictor), + $"Input predictor did not have the expected output type {typeof(TOutput).Name}."); + } + Predictor = predictorProducing; int card; if (features != null && (card = Utils.GetCardinality(features)) < features.Count) { diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs index 6dbbb96235..81d0c545f9 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs @@ -122,7 +122,7 @@ private void CheckMeta() throw Contracts.Except("Stacking predictor output type is unsupported: {0}", ivm.OutputType); } - public void Train(List>> models, RoleMappedData data, IHostEnvironment env) + public void Train(List> models, RoleMappedData data, IHostEnvironment env) { Contracts.CheckValue(env, nameof(env)); var host = env.Register(Stacking.LoadName); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/IOutputCombiner.cs 
b/src/Microsoft.ML.Ensemble/OutputCombiners/IOutputCombiner.cs index 53aec1f8d8..bbbbecf217 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/IOutputCombiner.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/IOutputCombiner.cs @@ -30,7 +30,7 @@ public interface IOutputCombiner : IOutputCombiner internal interface IStackingTrainer { - void Train(List>> models, RoleMappedData data, IHostEnvironment env); + void Train(List> models, RoleMappedData data, IHostEnvironment env); Single ValidationDatasetProportion { get; } } diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/BaseDisagreementDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/BaseDisagreementDiversityMeasure.cs index f6ac5da19e..b4106a7b1d 100644 --- a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/BaseDisagreementDiversityMeasure.cs +++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/BaseDisagreementDiversityMeasure.cs @@ -8,10 +8,10 @@ namespace Microsoft.ML.Ensemble.Selector.DiversityMeasure { - public abstract class BaseDisagreementDiversityMeasure : IDiversityMeasure + internal abstract class BaseDisagreementDiversityMeasure : IDiversityMeasure { - public List> CalculateDiversityMeasure(IList>> models, - ConcurrentDictionary>, TOutput[]> predictions) + public List> CalculateDiversityMeasure(IList> models, + ConcurrentDictionary, TOutput[]> predictions) { Contracts.Assert(models.Count > 1); Contracts.Assert(predictions.Count == models.Count); diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/DisagreementDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/DisagreementDiversityMeasure.cs index c5af6268c3..bb0127003f 100644 --- a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/DisagreementDiversityMeasure.cs +++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/DisagreementDiversityMeasure.cs @@ -12,7 +12,7 @@ namespace Microsoft.ML.Ensemble.Selector.DiversityMeasure { - public class 
DisagreementDiversityMeasure : BaseDisagreementDiversityMeasure, IBinaryDiversityMeasure + internal sealed class DisagreementDiversityMeasure : BaseDisagreementDiversityMeasure, IBinaryDiversityMeasure { public const string UserName = "Disagreement Diversity Measure"; public const string LoadName = "DisagreementDiversityMeasure"; diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/ModelDiversityMetric.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/ModelDiversityMetric.cs index 621671c2f0..b182ad9963 100644 --- a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/ModelDiversityMetric.cs +++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/ModelDiversityMetric.cs @@ -6,10 +6,10 @@ namespace Microsoft.ML.Ensemble.Selector.DiversityMeasure { - public class ModelDiversityMetric + internal sealed class ModelDiversityMetric { - public FeatureSubsetModel> ModelX { get; set; } - public FeatureSubsetModel> ModelY { get; set; } + public FeatureSubsetModel ModelX { get; set; } + public FeatureSubsetModel ModelY { get; set; } public Single DiversityNumber { get; set; } } } diff --git a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/MultiDisagreementDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/MultiDisagreementDiversityMeasure.cs index 1645ae9a8a..ffb6839818 100644 --- a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/MultiDisagreementDiversityMeasure.cs +++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/MultiDisagreementDiversityMeasure.cs @@ -14,7 +14,7 @@ namespace Microsoft.ML.Ensemble.Selector.DiversityMeasure { - public class MultiDisagreementDiversityMeasure : BaseDisagreementDiversityMeasure>, IMulticlassDiversityMeasure + internal sealed class MultiDisagreementDiversityMeasure : BaseDisagreementDiversityMeasure>, IMulticlassDiversityMeasure { public const string LoadName = "MultiDisagreementDiversityMeasure"; diff --git 
a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/RegressionDisagreementDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/RegressionDisagreementDiversityMeasure.cs index 67b9eb0d21..2b60e44f1c 100644 --- a/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/RegressionDisagreementDiversityMeasure.cs +++ b/src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/RegressionDisagreementDiversityMeasure.cs @@ -12,7 +12,7 @@ namespace Microsoft.ML.Ensemble.Selector.DiversityMeasure { - public class RegressionDisagreementDiversityMeasure : BaseDisagreementDiversityMeasure, IRegressionDiversityMeasure + internal sealed class RegressionDisagreementDiversityMeasure : BaseDisagreementDiversityMeasure, IRegressionDiversityMeasure { public const string LoadName = "RegressionDisagreementDiversityMeasure"; diff --git a/src/Microsoft.ML.Ensemble/Selector/IDiversityMeasure.cs b/src/Microsoft.ML.Ensemble/Selector/IDiversityMeasure.cs index 4209761c94..8ac30f3818 100644 --- a/src/Microsoft.ML.Ensemble/Selector/IDiversityMeasure.cs +++ b/src/Microsoft.ML.Ensemble/Selector/IDiversityMeasure.cs @@ -11,33 +11,33 @@ namespace Microsoft.ML.Ensemble.Selector { - public interface IDiversityMeasure + internal interface IDiversityMeasure { - List> CalculateDiversityMeasure(IList>> models, - ConcurrentDictionary>, TOutput[]> predictions); + List> CalculateDiversityMeasure(IList> models, + ConcurrentDictionary, TOutput[]> predictions); } - public delegate void SignatureEnsembleDiversityMeasure(); + internal delegate void SignatureEnsembleDiversityMeasure(); - public interface IBinaryDiversityMeasure : IDiversityMeasure + internal interface IBinaryDiversityMeasure : IDiversityMeasure { } - public interface IRegressionDiversityMeasure : IDiversityMeasure + internal interface IRegressionDiversityMeasure : IDiversityMeasure { } - public interface IMulticlassDiversityMeasure : IDiversityMeasure> + internal interface IMulticlassDiversityMeasure : IDiversityMeasure> { } 
[TlcModule.ComponentKind("EnsembleBinaryDiversityMeasure")] - public interface ISupportBinaryDiversityMeasureFactory : IComponentFactory + internal interface ISupportBinaryDiversityMeasureFactory : IComponentFactory { } [TlcModule.ComponentKind("EnsembleRegressionDiversityMeasure")] - public interface ISupportRegressionDiversityMeasureFactory : IComponentFactory + internal interface ISupportRegressionDiversityMeasureFactory : IComponentFactory { } [TlcModule.ComponentKind("EnsembleMulticlassDiversityMeasure")] - public interface ISupportMulticlassDiversityMeasureFactory : IComponentFactory + internal interface ISupportMulticlassDiversityMeasureFactory : IComponentFactory { } } diff --git a/src/Microsoft.ML.Ensemble/Selector/ISubModelSelector.cs b/src/Microsoft.ML.Ensemble/Selector/ISubModelSelector.cs index 64ecf45bce..e5b35082ee 100644 --- a/src/Microsoft.ML.Ensemble/Selector/ISubModelSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/ISubModelSelector.cs @@ -11,9 +11,9 @@ namespace Microsoft.ML.Ensemble.Selector { internal interface ISubModelSelector { - IList>> Prune(IList>> models); + IList> Prune(IList> models); - void CalculateMetrics(FeatureSubsetModel> model, ISubsetSelector subsetSelector, Subset subset, + void CalculateMetrics(FeatureSubsetModel model, ISubsetSelector subsetSelector, Subset subset, Batch batch, bool needMetrics); Single ValidationDatasetProportion { get; } diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseBestPerformanceSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseBestPerformanceSelector.cs index cc6292e9b2..55add475f5 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseBestPerformanceSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseBestPerformanceSelector.cs @@ -21,13 +21,13 @@ protected BaseBestPerformanceSelector(ArgumentsBase args, IHostEnvironment env, { } - public override void CalculateMetrics(FeatureSubsetModel> model, + public override void 
CalculateMetrics(FeatureSubsetModel model, ISubsetSelector subsetSelector, Subset subset, Batch batch, bool needMetrics) { base.CalculateMetrics(model, subsetSelector, subset, batch, true); } - public override IList>> Prune(IList>> models) + public override IList> Prune(IList> models) { using (var ch = Host.Start("Pruning")) { @@ -65,7 +65,7 @@ protected static string FindMetricName(Type type, object value) return null; } - private sealed class ModelPerformanceComparer : IComparer>> + private sealed class ModelPerformanceComparer : IComparer> { private readonly string _metricName; private readonly bool _isAscMetric; @@ -78,7 +78,7 @@ public ModelPerformanceComparer(string metricName, bool isAscMetric) _isAscMetric = isAscMetric; } - public int Compare(FeatureSubsetModel> x, FeatureSubsetModel> y) + public int Compare(FeatureSubsetModel x, FeatureSubsetModel y) { if (x == null || y == null) return (x == null ? 0 : 1) - (y == null ? 0 : 1); diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs index 98d5c47c02..e1edba726e 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs @@ -20,14 +20,14 @@ public abstract class DiverseSelectorArguments : ArgumentsBase } private readonly IComponentFactory> _diversityMetricType; - private ConcurrentDictionary>, TOutput[]> _predictions; + private ConcurrentDictionary, TOutput[]> _predictions; private protected BaseDiverseSelector(IHostEnvironment env, DiverseSelectorArguments args, string name, IComponentFactory> diversityMetricType) : base(args, env, name) { _diversityMetricType = diversityMetricType; - _predictions = new ConcurrentDictionary>, TOutput[]>(); + _predictions = new ConcurrentDictionary, TOutput[]>(); } protected IDiversityMeasure CreateDiversityMetric() @@ -35,7 +35,7 @@ protected 
IDiversityMeasure CreateDiversityMetric() return _diversityMetricType.CreateComponent(Host); } - public override void CalculateMetrics(FeatureSubsetModel> model, + public override void CalculateMetrics(FeatureSubsetModel model, ISubsetSelector subsetSelector, Subset subset, Batch batch, bool needMetrics) { base.CalculateMetrics(model, subsetSelector, subset, batch, needMetrics); @@ -67,7 +67,7 @@ public override void CalculateMetrics(FeatureSubsetModel /// /// - public override IList>> Prune(IList>> models) + public override IList> Prune(IList> models) { if (models.Count <= 1) return models; @@ -85,7 +85,7 @@ public override IList>> Prune(IL modelCountToBeSelected++; // 3. Take the most diverse classifiers - var selectedModels = new List>>(); + var selectedModels = new List>(); foreach (var item in sortedModels) { if (selectedModels.Count < modelCountToBeSelected) @@ -113,8 +113,8 @@ public override IList>> Prune(IL return selectedModels; } - public abstract List> CalculateDiversityMeasure(IList>> models, - ConcurrentDictionary>, TOutput[]> predictions); + public abstract List> CalculateDiversityMeasure(IList> models, + ConcurrentDictionary, TOutput[]> predictions); public class ModelDiversityComparer : IComparer> { diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs index 78995456c8..8d1189d049 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs @@ -25,7 +25,7 @@ protected BaseSubModelSelector(IHostEnvironment env, string name) Host = env.Register(name); } - protected void Print(IChannel ch, IList>> models, string metricName) + protected void Print(IChannel ch, IList> models, string metricName) { // REVIEW: The output format was faithfully reproduced from the original format, but it's unclear // to me that this is right. 
Why have two bars in the header line, but only one bar in the results? @@ -49,7 +49,7 @@ protected void Print(IChannel ch, IList>> Prune(IList>> models) + public virtual IList> Prune(IList> models) { return models; } @@ -69,7 +69,7 @@ private IEvaluator GetEvaluator(IHostEnvironment env) } } - public virtual void CalculateMetrics(FeatureSubsetModel> model, + public virtual void CalculateMetrics(FeatureSubsetModel model, ISubsetSelector subsetSelector, Subset subset, Batch batch, bool needMetrics) { if (!needMetrics || model == null || model.Metrics != null) diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorBinary.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorBinary.cs index cb56769b79..e2a0147f01 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorBinary.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorBinary.cs @@ -42,8 +42,8 @@ public BestDiverseSelectorBinary(IHostEnvironment env, Arguments args) } - public override List> CalculateDiversityMeasure(IList> models, - ConcurrentDictionary, Single[]> predictions) + public override List> CalculateDiversityMeasure(IList> models, + ConcurrentDictionary, Single[]> predictions) { var diversityMetric = CreateDiversityMetric(); return diversityMetric.CalculateDiversityMeasure(models, predictions); diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorMultiClass.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorMultiClass.cs index d5e6b9ae91..a25fe4d8d1 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorMultiClass.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorMultiClass.cs @@ -43,8 +43,8 @@ public BestDiverseSelectorMultiClass(IHostEnvironment env, Arguments args) protected override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; - public override List>> 
CalculateDiversityMeasure(IList> models, - ConcurrentDictionary, VBuffer[]> predictions) + public override List>> CalculateDiversityMeasure(IList>> models, + ConcurrentDictionary>, VBuffer[]> predictions) { Host.Assert(models.Count > 1); Host.Assert(predictions.Count == models.Count); diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorRegression.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorRegression.cs index a276995a9a..ef310eab7e 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorRegression.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorRegression.cs @@ -40,8 +40,8 @@ public BestDiverseSelectorRegression(IHostEnvironment env, Arguments args) { } - public override List> CalculateDiversityMeasure(IList> models, - ConcurrentDictionary, Single[]> predictions) + public override List> CalculateDiversityMeasure(IList> models, + ConcurrentDictionary, Single[]> predictions) { var diversityMetric = CreateDiversityMetric(); return diversityMetric.CalculateDiversityMeasure(models, predictions); diff --git a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs index bd74ae8b27..83637c333f 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs @@ -12,7 +12,9 @@ using Microsoft.ML.Ensemble.OutputCombiners; using Microsoft.ML.Ensemble.Selector; using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Learners; using Microsoft.ML.Trainers.Online; +using Microsoft.ML.Training; [assembly: LoadableClass(EnsembleTrainer.Summary, typeof(EnsembleTrainer), typeof(EnsembleTrainer.Arguments), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) }, @@ -58,7 +60,11 @@ public Arguments() BasePredictors = new[] { ComponentFactoryUtils.CreateFromFunction( - env => new 
LinearSvmTrainer(env)) + env => { + var trainerEstimator = new LinearSvmTrainer(env); + return TrainerUtils.MapTrainerEstimatorToTrainer(env, trainerEstimator); + }) }; } } @@ -81,7 +87,7 @@ private EnsembleTrainer(IHostEnvironment env, Arguments args, PredictionKind pre public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - private protected override TScalarPredictor CreatePredictor(List> models) + private protected override TScalarPredictor CreatePredictor(List> models) { if (models.All(m => m.Predictor is TDistPredictor)) return new EnsembleDistributionModelParameters(Host, PredictionKind, CreateModels(models), Combiner); @@ -98,12 +104,12 @@ public IPredictor CombineModels(IEnumerable models) { Host.CheckParam(models.All(m => m is TDistPredictor), nameof(models)); return new EnsembleDistributionModelParameters(Host, p.PredictionKind, - models.Select(k => new FeatureSubsetModel((TDistPredictor)k)).ToArray(), combiner); + models.Select(k => new FeatureSubsetModel((TDistPredictor)k)).ToArray(), combiner); } Host.CheckParam(models.All(m => m is TScalarPredictor), nameof(models)); return new EnsembleModelParameters(Host, p.PredictionKind, - models.Select(k => new FeatureSubsetModel((TScalarPredictor)k)).ToArray(), combiner); + models.Select(k => new FeatureSubsetModel((TScalarPredictor)k)).ToArray(), combiner); } } } \ No newline at end of file diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs index b573795e48..e63f2059e5 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs @@ -22,7 +22,7 @@ namespace Microsoft.ML.Ensemble { using TDistPredictor = IDistPredictorProducing; - public sealed class EnsembleDistributionModelParameters : EnsembleModelParametersBase, + public sealed class 
EnsembleDistributionModelParameters : EnsembleModelParametersBase, TDistPredictor, IValueMapperDist { internal const string UserName = "Ensemble Distribution Executor"; @@ -62,8 +62,8 @@ private static VersionInfo GetVersionInfo() /// Array of sub-models that you want to ensemble together. /// The combiner class to use to ensemble the models. /// The weights assigned to each model to be ensembled. - public EnsembleDistributionModelParameters(IHostEnvironment env, PredictionKind kind, - FeatureSubsetModel[] models, IOutputCombiner combiner, Single[] weights = null) + internal EnsembleDistributionModelParameters(IHostEnvironment env, PredictionKind kind, + FeatureSubsetModel[] models, IOutputCombiner combiner, Single[] weights = null) : base(env, RegistrationName, models, combiner, weights) { PredictionKind = kind; diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs index 298dbf163e..032ff39695 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs @@ -19,12 +19,10 @@ namespace Microsoft.ML.Ensemble { - using TScalarPredictor = IPredictorProducing; - /// /// A class for artifacts of ensembled models. /// - public sealed class EnsembleModelParameters : EnsembleModelParametersBase, IValueMapper + public sealed class EnsembleModelParameters : EnsembleModelParametersBase, IValueMapper { internal const string UserName = "Ensemble Executor"; internal const string LoaderSignature = "EnsembleFloatExec"; @@ -58,8 +56,8 @@ private static VersionInfo GetVersionInfo() /// Array of sub-models that you want to ensemble together. /// The combiner class to use to ensemble the models. /// The weights assigned to each model to be ensembled. 
- public EnsembleModelParameters(IHostEnvironment env, PredictionKind kind, - FeatureSubsetModel[] models, IOutputCombiner combiner, Single[] weights = null) + internal EnsembleModelParameters(IHostEnvironment env, PredictionKind kind, + FeatureSubsetModel[] models, IOutputCombiner combiner, Single[] weights = null) : base(env, LoaderSignature, models, combiner, weights) { PredictionKind = kind; diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParametersBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParametersBase.cs index f4a1d6d72d..5e79bbe41e 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParametersBase.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParametersBase.cs @@ -13,19 +13,18 @@ namespace Microsoft.ML.Ensemble { - public abstract class EnsembleModelParametersBase : ModelParametersBase, + public abstract class EnsembleModelParametersBase : ModelParametersBase, IPredictorProducing, ICanSaveInTextFormat, ICanSaveSummary - where TPredictor : class, IPredictorProducing { private const string SubPredictorFmt = "SubPredictor_{0:000}"; - protected readonly FeatureSubsetModel[] Models; - protected readonly IOutputCombiner Combiner; - protected readonly Single[] Weights; + private protected readonly FeatureSubsetModel[] Models; + private protected readonly IOutputCombiner Combiner; + private protected readonly Single[] Weights; private const uint VerOld = 0x00010002; - internal EnsembleModelParametersBase(IHostEnvironment env, string name, FeatureSubsetModel[] models, + private protected EnsembleModelParametersBase(IHostEnvironment env, string name, FeatureSubsetModel[] models, IOutputCombiner combiner, Single[] weights) : base(env, name) { @@ -38,7 +37,7 @@ internal EnsembleModelParametersBase(IHostEnvironment env, string name, FeatureS Weights = weights; } - protected EnsembleModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx) + private protected EnsembleModelParametersBase(IHostEnvironment env, 
string name, ModelLoadContext ctx) : base(env, name, ctx) { // *** Binary format *** @@ -62,13 +61,13 @@ protected EnsembleModelParametersBase(IHostEnvironment env, string name, ModelLo Host.CheckDecode(weightCount == 0 || weightCount == count); Weights = ctx.Reader.ReadFloatArray(weightCount); - Models = new FeatureSubsetModel[count]; + Models = new FeatureSubsetModel[count]; var ver = ctx.Header.ModelVerWritten; for (int i = 0; i < count; i++) { ctx.LoadModel(Host, out IPredictor p, string.Format(SubPredictorFmt, i)); - var predictor = p as TPredictor; - Host.Check(p != null, "Inner predictor type not compatible with the ensemble type."); + var predictor = p as IPredictorProducing; + Host.Check(predictor != null, "Inner predictor type not compatible with the ensemble type."); var features = ctx.Reader.ReadBitArray(); int numMetrics = ctx.Reader.ReadInt32(); Host.CheckDecode(numMetrics >= 0); @@ -81,7 +80,7 @@ protected EnsembleModelParametersBase(IHostEnvironment env, string name, ModelLo ctx.Reader.ReadBoolByte(); metrics[j] = new KeyValuePair(metricName, metricValue); } - Models[i] = new FeatureSubsetModel(predictor, features, metrics); + Models[i] = new FeatureSubsetModel(predictor, features, metrics); } ctx.LoadModel, SignatureLoadModel>(Host, out Combiner, @"Combiner"); } diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs index 305e05c788..ef7ebc6d4d 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs @@ -129,7 +129,7 @@ private TPredictor TrainCore(IChannel ch, RoleMappedData data) validationDataSetProportion = Math.Max(validationDataSetProportion, stackingTrainer.ValidationDatasetProportion); var needMetrics = Args.ShowMetrics || Combiner is IWeightedAverager; - var models = new List>>(); + var models = new List>(); _subsetSelector.Initialize(data, NumModels, Args.BatchSize, 
validationDataSetProportion); int batchNumber = 1; @@ -137,7 +137,7 @@ private TPredictor TrainCore(IChannel ch, RoleMappedData data) { // 2. Core train ch.Info("Training {0} learners for the batch {1}", Trainers.Length, batchNumber++); - var batchModels = new FeatureSubsetModel>[Trainers.Length]; + var batchModels = new FeatureSubsetModel[Trainers.Length]; Parallel.ForEach(_subsetSelector.GetSubsets(batch, Host.Rand), new ParallelOptions() { MaxDegreeOfParallelism = Args.TrainParallel ? -1 : 1 }, @@ -149,7 +149,7 @@ private TPredictor TrainCore(IChannel ch, RoleMappedData data) { if (EnsureMinimumFeaturesSelected(subset)) { - var model = new FeatureSubsetModel>( + var model = new FeatureSubsetModel( Trainers[(int)index].Train(subset.Data), subset.SelectedFeatures, null); @@ -184,7 +184,7 @@ private TPredictor TrainCore(IChannel ch, RoleMappedData data) return CreatePredictor(models); } - private protected abstract TPredictor CreatePredictor(List>> models); + private protected abstract TPredictor CreatePredictor(List> models); private bool EnsureMinimumFeaturesSelected(Subset subset) { @@ -199,7 +199,7 @@ private bool EnsureMinimumFeaturesSelected(Subset subset) return false; } - private protected virtual void PrintMetrics(IChannel ch, List>> models) + private protected virtual void PrintMetrics(IChannel ch, List> models) { // REVIEW: The formatting of this method is bizarre and seemingly not even self-consistent // w.r.t. its usage of |. Is this intentional? 
@@ -212,12 +212,12 @@ private protected virtual void PrintMetrics(IChannel ch, List string.Format("| {0} |", m.Value))), model.Predictor.GetType().Name); } - private protected static FeatureSubsetModel[] CreateModels(List>> models) where T : IPredictor + private protected static FeatureSubsetModel[] CreateModels(List> models) where T : IPredictorProducing { - var subsetModels = new FeatureSubsetModel[models.Count]; + var subsetModels = new FeatureSubsetModel[models.Count]; for (int i = 0; i < models.Count; i++) { - subsetModels[i] = new FeatureSubsetModel( + subsetModels[i] = new FeatureSubsetModel( (T)models[i].Predictor, models[i].SelectedFeatures, models[i].Metrics); diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassModelParameters.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassModelParameters.cs index 98b0857da1..134e58d51d 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassModelParameters.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassModelParameters.cs @@ -18,7 +18,7 @@ namespace Microsoft.ML.Ensemble { using TVectorPredictor = IPredictorProducing>; - public sealed class EnsembleMultiClassModelParameters : EnsembleModelParametersBase>, IValueMapper + public sealed class EnsembleMultiClassModelParameters : EnsembleModelParametersBase>, IValueMapper { internal const string UserName = "Ensemble Multiclass Executor"; internal const string LoaderSignature = "EnsemMcExec"; @@ -51,7 +51,7 @@ private static VersionInfo GetVersionInfo() /// Array of sub-models that you want to ensemble together. /// The combiner class to use to ensemble the models. /// The weights assigned to each model to be ensembled. 
- public EnsembleMultiClassModelParameters(IHostEnvironment env, FeatureSubsetModel[] models, + internal EnsembleMultiClassModelParameters(IHostEnvironment env, FeatureSubsetModel>[] models, IMultiClassOutputCombiner combiner, Single[] weights = null) : base(env, RegistrationName, models, combiner, weights) { diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs index 7d13483f4a..d194dfc947 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs @@ -14,6 +14,7 @@ using Microsoft.ML.Ensemble.Selector; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Learners; +using Microsoft.ML.Training; [assembly: LoadableClass(MulticlassDataPartitionEnsembleTrainer.Summary, typeof(MulticlassDataPartitionEnsembleTrainer), typeof(MulticlassDataPartitionEnsembleTrainer.Arguments), @@ -51,7 +52,7 @@ public sealed class Arguments : ArgumentsBase // REVIEW: If we make this public again it should be an *estimator* of this type of predictor, rather than the (deprecated) ITrainer. 
[Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))] - public IComponentFactory>[] BasePredictors; + internal IComponentFactory>[] BasePredictors; internal override IComponentFactory>[] GetPredictorFactories() => BasePredictors; @@ -60,7 +61,17 @@ public Arguments() BasePredictors = new[] { ComponentFactoryUtils.CreateFromFunction( - env => new MulticlassLogisticRegression(env, LabelColumn, FeatureColumn)) + env => { + // Note that this illustrates a fundamental problem with the mixture of `ITrainer` and `ITrainerEstimator` + // present in this class. The options to the estimator have no way of being communicated to the `ITrainer` + // implementation, so there is a fundamental disconnect if someone chooses to ever use the *estimator* with + // non-default column names. Unfortunately no method of resolving this temporarily strikes me as being any + // less laborious than the proper fix, which is that this "meta" component should itself be a trainer + // estimator, as opposed to a regular trainer. 
+ var trainerEstimator = new MulticlassLogisticRegression(env, LabelColumn, FeatureColumn); + return TrainerUtils.MapTrainerEstimatorToTrainer(env, trainerEstimator); + }) }; } } @@ -83,7 +94,7 @@ private MulticlassDataPartitionEnsembleTrainer(IHostEnvironment env, Arguments a public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; - private protected override EnsembleMultiClassModelParameters CreatePredictor(List> models) + private protected override EnsembleMultiClassModelParameters CreatePredictor(List>> models) { return new EnsembleMultiClassModelParameters(Host, CreateModels(models), Combiner as IMultiClassOutputCombiner); } @@ -95,7 +106,7 @@ public IPredictor CombineModels(IEnumerable models) var combiner = _outputCombiner.CreateComponent(Host); var predictor = new EnsembleMultiClassModelParameters(Host, - models.Select(k => new FeatureSubsetModel((TVectorPredictor)k)).ToArray(), + models.Select(k => new FeatureSubsetModel>((TVectorPredictor)k)).ToArray(), combiner); return predictor; } diff --git a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs index 4545ce3c34..569d4cd829 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs @@ -13,7 +13,9 @@ using Microsoft.ML.Ensemble.OutputCombiners; using Microsoft.ML.Ensemble.Selector; using Microsoft.ML.Internal.Internallearn; +using Microsoft.ML.Learners; using Microsoft.ML.Trainers.Online; +using Microsoft.ML.Training; [assembly: LoadableClass(typeof(RegressionEnsembleTrainer), typeof(RegressionEnsembleTrainer.Arguments), new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer) }, @@ -54,7 +56,11 @@ public Arguments() BasePredictors = new[] { ComponentFactoryUtils.CreateFromFunction( - env => new OnlineGradientDescentTrainer(env, DefaultColumnNames.Label, 
DefaultColumnNames.Features)) + env => { + var trainerEstimator = new OnlineGradientDescentTrainer(env); + return TrainerUtils.MapTrainerEstimatorToTrainer(env, trainerEstimator); + }) }; } } @@ -77,7 +83,7 @@ private RegressionEnsembleTrainer(IHostEnvironment env, Arguments args, Predicti public override PredictionKind PredictionKind => PredictionKind.Regression; - private protected override TScalarPredictor CreatePredictor(List> models) + private protected override TScalarPredictor CreatePredictor(List> models) { return new EnsembleModelParameters(Host, PredictionKind, CreateModels(models), Combiner); } @@ -91,7 +97,7 @@ public IPredictor CombineModels(IEnumerable models) var p = models.First(); var predictor = new EnsembleModelParameters(Host, p.PredictionKind, - models.Select(k => new FeatureSubsetModel((TScalarPredictor)k)).ToArray(), combiner); + models.Select(k => new FeatureSubsetModel((TScalarPredictor)k)).ToArray(), combiner); return predictor; } diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs index 111b7150e7..e84ec0117b 100644 --- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs +++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs @@ -14,7 +14,7 @@ namespace Microsoft.ML.Trainers.FastTree public abstract class BoostingFastTreeTrainerBase : FastTreeTrainerBase where TTransformer : ISingleFeaturePredictionTransformer where TArgs : BoostedTreeArgs, new() - where TModel : IPredictorProducing + where TModel : class { protected BoostingFastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label) : base(env, args, label) { diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index df61a37420..eb840cc068 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -52,7 +52,7 @@ public abstract class FastTreeTrainerBase : TrainerEstimatorBaseWithGroupId where TTransformer : ISingleFeaturePredictionTransformer 
where TArgs : TreeArgs, new() - where TModel : IPredictorProducing + where TModel : class { protected readonly TArgs Args; protected readonly bool AllowGC; diff --git a/src/Microsoft.ML.FastTree/GamClassification.cs b/src/Microsoft.ML.FastTree/GamClassification.cs index f12679033f..e8a85550ad 100644 --- a/src/Microsoft.ML.FastTree/GamClassification.cs +++ b/src/Microsoft.ML.FastTree/GamClassification.cs @@ -31,7 +31,7 @@ namespace Microsoft.ML.Trainers.FastTree { public sealed class BinaryClassificationGamTrainer : - GamTrainerBase>, IPredictorProducing> + GamTrainerBase, CalibratedPredictorBase> { public sealed class Options : ArgumentsBase { @@ -104,7 +104,7 @@ private static bool[] ConvertTargetsToBool(double[] targets) return boolArray; } - private protected override IPredictorProducing TrainModelCore(TrainContext context) + private protected override CalibratedPredictorBase TrainModelCore(TrainContext context) { TrainBase(context); var predictor = new BinaryClassificationGamModelParameters(Host, @@ -139,10 +139,10 @@ protected override void DefinePruningTest() PruningTest = new TestHistory(validTest, PruningLossIndex); } - protected override BinaryPredictionTransformer> MakeTransformer(IPredictorProducing model, Schema trainSchema) - => new BinaryPredictionTransformer>(Host, model, trainSchema, FeatureColumn.Name); + protected override BinaryPredictionTransformer MakeTransformer(CalibratedPredictorBase model, Schema trainSchema) + => new BinaryPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); - public BinaryPredictionTransformer> Train(IDataView trainData, IDataView validationData = null) + public BinaryPredictionTransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData); protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index cdfc4a8b5b..87bcc370ba 
100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -56,7 +56,7 @@ namespace Microsoft.ML.Trainers.FastTree public abstract partial class GamTrainerBase : TrainerEstimatorBase where TTransformer: ISingleFeaturePredictionTransformer where TArgs : GamTrainerBase.ArgumentsBase, new() - where TPredictor : IPredictorProducing + where TPredictor : class { public abstract class ArgumentsBase : LearnerInputBaseWithWeight { diff --git a/src/Microsoft.ML.FastTree/RandomForest.cs b/src/Microsoft.ML.FastTree/RandomForest.cs index 400e8241b0..943b10f621 100644 --- a/src/Microsoft.ML.FastTree/RandomForest.cs +++ b/src/Microsoft.ML.FastTree/RandomForest.cs @@ -10,7 +10,7 @@ namespace Microsoft.ML.Trainers.FastTree { public abstract class RandomForestTrainerBase : FastTreeTrainerBase where TArgs : FastForestArgumentsBase, new() - where TModel : IPredictorProducing + where TModel : class where TTransformer: ISingleFeaturePredictionTransformer { private readonly bool _quantileEnabled; diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index 44de6c9f9f..58cc2fb684 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -41,7 +41,7 @@ namespace Microsoft.ML.Data /// 2. An indicator vector for the leaves that the feature vector falls on in the tree ensemble. /// 3. An indicator vector for the internal nodes on the paths that the feature vector falls on in the tree ensemble. 
/// - public sealed class TreeEnsembleFeaturizerBindableMapper : ISchemaBindableMapper, ICanSaveModel + internal sealed class TreeEnsembleFeaturizerBindableMapper : ISchemaBindableMapper, ICanSaveModel { public static class OutputColumnNames { diff --git a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs index c9c78d06c9..18a5d97b97 100644 --- a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs +++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs @@ -35,7 +35,6 @@ namespace Microsoft.ML.Trainers.HalLearners { /// - [BestFriend] public sealed class OlsLinearRegressionTrainer : TrainerEstimatorBase, OlsLinearRegressionModelParameters> { public sealed class Options : LearnerInputBaseWithWeight diff --git a/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs b/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs index 45d732048b..af5ecd4a11 100644 --- a/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs +++ b/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs @@ -34,7 +34,6 @@ namespace Microsoft.ML.Trainers.SymSgd using TPredictor = IPredictorWithFeatureWeights; /// - [BestFriend] public sealed class SymSgdClassificationTrainer : TrainerEstimatorBase, TPredictor> { internal const string LoadNameValue = "SymbolicSGD"; diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index a8e84e3372..1ce9b0ccf4 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -29,7 +29,7 @@ internal static class LightGbmShared /// public abstract class LightGbmTrainerBase : TrainerEstimatorBaseWithGroupId where TTransformer : ISingleFeaturePredictionTransformer - where TModel : IPredictorProducing + where TModel : class { private sealed class CategoricalMetaData { diff --git a/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs 
b/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs index e40740e2e0..438380ae98 100644 --- a/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs +++ b/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs @@ -398,7 +398,7 @@ public sealed class MatrixFactorizationPredictionTransformer : PredictionTransfo /// A string attached to the output column name of this transformer public MatrixFactorizationPredictionTransformer(IHostEnvironment env, MatrixFactorizationPredictor model, Schema trainSchema, string matrixColumnIndexColumnName, string matrixRowIndexColumnName, string scoreColumnNameSuffix = "") - :base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MatrixFactorizationPredictionTransformer)), model, trainSchema) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MatrixFactorizationPredictionTransformer)), model, trainSchema) { Host.CheckNonEmpty(matrixColumnIndexColumnName, nameof(matrixRowIndexColumnName)); Host.CheckNonEmpty(matrixColumnIndexColumnName, nameof(matrixRowIndexColumnName)); @@ -434,7 +434,7 @@ private RoleMappedSchema GetSchema() /// the original transform is saved. 
/// public MatrixFactorizationPredictionTransformer(IHostEnvironment host, ModelLoadContext ctx) - :base(Contracts.CheckRef(host, nameof(host)).Register(nameof(MatrixFactorizationPredictionTransformer)), ctx) + : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(MatrixFactorizationPredictionTransformer)), ctx) { // *** Binary format *** // diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs index 8a4238a378..c6c6602c52 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs @@ -18,7 +18,7 @@ namespace Microsoft.ML.Learners { public abstract class LbfgsTrainerBase : TrainerEstimatorBase where TTransformer : ISingleFeaturePredictionTransformer - where TModel : IPredictor + where TModel : class where TArgs : LbfgsTrainerBase.ArgumentsBase, new () { public abstract class ArgumentsBase : LearnerInputBaseWithWeight diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs index eac950e68a..fdd32f2f20 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs @@ -18,15 +18,15 @@ namespace Microsoft.ML.Learners { using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; - public abstract class MetaMulticlassTrainer : ITrainerEstimator, ITrainer + public abstract class MetaMulticlassTrainer : ITrainerEstimator, ITrainer where TTransformer : ISingleFeaturePredictionTransformer - where TModel : IPredictor + where TModel : class { public abstract class ArgumentsBase { [Argument(ArgumentType.Multiple, HelpText = "Base predictor", ShortName = "p", SortOrder = 4, SignatureType 
= typeof(SignatureBinaryClassifierTrainer))] [TGUI(Label = "Predictor Type", Description = "Type of underlying binary predictor")] - public IComponentFactory PredictorType; + internal IComponentFactory PredictorType; [Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", SortOrder = 150, NullName = "", SignatureType = typeof(SignatureCalibrator))] public IComponentFactory Calibrator = new PlattCalibratorTrainerFactory(); @@ -43,19 +43,19 @@ public abstract class ArgumentsBase /// public readonly SchemaShape.Column LabelColumn; - protected readonly ArgumentsBase Args; - protected readonly IHost Host; - protected readonly ICalibratorTrainer Calibrator; - protected readonly TScalarTrainer Trainer; + private protected readonly ArgumentsBase Args; + private protected readonly IHost Host; + private protected readonly ICalibratorTrainer Calibrator; + private protected readonly TScalarTrainer Trainer; public PredictionKind PredictionKind => PredictionKind.MultiClassClassification; - protected SchemaShape.Column[] OutputColumns; + private protected SchemaShape.Column[] OutputColumns; public TrainerInfo Info { get; } /// - /// Initializes the from the Arguments class. + /// Initializes the from the class. /// /// The private instance of the . /// The legacy arguments class. @@ -122,7 +122,7 @@ private protected IDataView MapLabelsCore(ColumnType type, InPredicate equ /// /// The trainig context for this learner. /// The trained model. 
- TModel ITrainer.Train(TrainContext context) + IPredictor ITrainer.Train(TrainContext context) { Host.CheckValue(context, nameof(context)); var data = context.TrainingSet; @@ -135,7 +135,7 @@ TModel ITrainer.Train(TrainContext context) using (var ch = Host.Start("Training")) { - var pred = TrainCore(ch, data, count); + var pred = TrainCore(ch, data, count) as IPredictor; ch.Check(pred != null, "Training did not result in a predictor"); return pred; } @@ -203,7 +203,7 @@ private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) return cols; } - IPredictor ITrainer.Train(TrainContext context) => ((ITrainer)this).Train(context); + IPredictor ITrainer.Train(TrainContext context) => ((ITrainer)this).Train(context); /// /// Fits the data to the trainer. diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs index f1b5543142..cf7a083985 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs @@ -84,7 +84,7 @@ internal Ova(IHostEnvironment env, Arguments args) /// Whether to treat missing labels as having negative labels, instead of keeping them missing. /// Number of instances to train the calibrator. /// Use probabilities (vs. raw outputs) to identify top-score category. 
- public Ova(IHostEnvironment env, + internal Ova(IHostEnvironment env, TScalarTrainer binaryEstimator, string labelColumn = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, @@ -163,7 +163,7 @@ private IDataView MapLabels(RoleMappedData data, int cls) } if (lab.Type == NumberType.R8) { - Double key = cls; + double key = cls; return MapLabelsCore(NumberType.R8, (in double val) => key == val, data); } diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs index c2e999322e..5e3dd33f8b 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs @@ -88,7 +88,7 @@ internal Pkpd(IHostEnvironment env, Arguments args) /// The name of the label colum. /// Whether to treat missing labels as having negative labels, instead of keeping them missing. /// Number of instances to train the calibrator. - public Pkpd(IHostEnvironment env, + internal Pkpd(IHostEnvironment env, TScalarTrainer binaryEstimator, string labelColumn = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs index abbfe31549..bf7a88d27d 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs @@ -64,7 +64,7 @@ internal class AveragedDefaultArgs : OnlineDefaultArgs public abstract class AveragedLinearTrainer : OnlineLinearTrainer where TTransformer : ISingleFeaturePredictionTransformer - where TModel : IPredictor + where TModel : class { protected readonly new AveragedLinearArguments Args; protected IScalarOutputLoss LossFunction; diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs 
b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs index 5b7e4cbf94..2128c8e008 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs @@ -50,7 +50,7 @@ internal class OnlineDefaultArgs public abstract class OnlineLinearTrainer : TrainerEstimatorBase where TTransformer : ISingleFeaturePredictionTransformer - where TModel : IPredictor + where TModel : class { protected readonly OnlineLinearArguments Args; protected readonly string Name; diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs index 73c6191b9c..9237f79abe 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs @@ -49,7 +49,7 @@ namespace Microsoft.ML.Trainers public abstract class LinearTrainerBase : TrainerEstimatorBase where TTransformer : ISingleFeaturePredictionTransformer - where TModel : IPredictor + where TModel : class { private const string RegisterName = nameof(LinearTrainerBase); @@ -153,7 +153,7 @@ private protected virtual int ComputeNumThreads(FloatLabelCursor.Factory cursorF public abstract class SdcaTrainerBase : StochasticTrainerBase where TTransformer : ISingleFeaturePredictionTransformer - where TModel : IPredictor + where TModel : class where TArgs : SdcaTrainerBase.ArgumentsBase, new() { // REVIEW: Making it even faster and more accurate: diff --git a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs index 309dd2a433..550939892f 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs @@ -14,7 +14,7 @@ namespace Microsoft.ML.Learners { public abstract class StochasticTrainerBase : TrainerEstimatorBase where TTransformer : 
ISingleFeaturePredictionTransformer - where TModel : IPredictor + where TModel : class { public StochasticTrainerBase(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight = default) : base(host, feature, label, weight) diff --git a/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs b/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs index 5cc4d87846..6442e4eb01 100644 --- a/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs +++ b/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs @@ -14,6 +14,7 @@ namespace Microsoft.ML { using LROptions = LogisticRegression.Options; using SgdOptions = StochasticGradientDescentClassificationTrainer.Options; + using TLegacyPredictor = IPredictorProducing; /// /// TrainerEstimator extension methods. @@ -454,16 +455,21 @@ public static MultiClassNaiveBayesTrainer NaiveBayes(this MulticlassClassificati /// Whether to treat missing labels as having negative labels, instead of keeping them missing. /// Number of instances to train the calibrator. /// Use probabilities (vs. raw outputs) to identify top-score category. - public static Ova OneVersusAll(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, - ITrainerEstimator>, IPredictorProducing> binaryEstimator, + /// The type of the model. This type parameter will usually be inferred automatically from . 
+ public static Ova OneVersusAll(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + ITrainerEstimator, TModel> binaryEstimator, string labelColumn = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, ICalibratorTrainer calibrator = null, int maxCalibrationExamples = 1000000000, bool useProbabilities = true) + where TModel : class { Contracts.CheckValue(catalog, nameof(catalog)); - return new Ova(CatalogUtils.GetEnvironment(catalog), binaryEstimator, labelColumn, imputeMissingLabelsAsNegative, calibrator, maxCalibrationExamples, useProbabilities); + var env = CatalogUtils.GetEnvironment(catalog); + if (!(binaryEstimator is ITrainerEstimator>, IPredictorProducing> est)) + throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model."); + return new Ova(env, est, labelColumn, imputeMissingLabelsAsNegative, calibrator, maxCalibrationExamples, useProbabilities); } /// @@ -482,15 +488,20 @@ public static Ova OneVersusAll(this MulticlassClassificationCatalog.MulticlassCl /// The name of the label colum. /// Whether to treat missing labels as having negative labels, instead of keeping them missing. /// Number of instances to train the calibrator. - public static Pkpd PairwiseCoupling(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, - ITrainerEstimator>, IPredictorProducing> binaryEstimator, + /// The type of the model. This type parameter will usually be inferred automatically from . 
+ public static Pkpd PairwiseCoupling(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + ITrainerEstimator, TModel> binaryEstimator, string labelColumn = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, ICalibratorTrainer calibrator = null, - int maxCalibrationExamples = 1000000000) + int maxCalibrationExamples = 1_000_000_000) + where TModel : class { Contracts.CheckValue(catalog, nameof(catalog)); - return new Pkpd(CatalogUtils.GetEnvironment(catalog), binaryEstimator, labelColumn, imputeMissingLabelsAsNegative, calibrator, maxCalibrationExamples); + var env = CatalogUtils.GetEnvironment(catalog); + if (!(binaryEstimator is ITrainerEstimator>, IPredictorProducing> est)) + throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model."); + return new Pkpd(env, est, labelColumn, imputeMissingLabelsAsNegative, calibrator, maxCalibrationExamples); } /// diff --git a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs index 628112ccff..6c1738b297 100644 --- a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs +++ b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs @@ -130,7 +130,6 @@ private FastForestRegressionModelParameters FitModel(IEnumerable pre IDataView view = dvBuilder.GetDataView(); _host.Assert(view.GetRowCount() == targets.Length, "This data view will have as many rows as there have been evaluations"); - RoleMappedData data = new RoleMappedData(view, DefaultColumnNames.Label, DefaultColumnNames.Features); using (IChannel ch = _host.Start("Single training")) { @@ -142,11 +141,13 @@ private FastForestRegressionModelParameters FitModel(IEnumerable pre FeatureFraction = _args.SplitRatio, NumTrees = _args.NumOfTrees, MinDocumentsInLeafs = _args.NMinForSplit, + LabelColumn = DefaultColumnNames.Label, + FeatureColumn = DefaultColumnNames.Features, }); - var predictor = trainer.Train(data); + var predictor = 
trainer.Train(view); // Return random forest predictor. - return predictor; + return predictor.Model; } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs index 2e70de1441..0960e32da4 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs @@ -662,7 +662,7 @@ private void MixMatch(string dataPath) var binaryTrainer = mlContext.BinaryClassification.Trainers.AveragedPerceptron("Label", "Features"); // Append the OVA learner to the pipeline. - dynamicPipe = dynamicPipe.Append(new Ova(mlContext, binaryTrainer)); + dynamicPipe = dynamicPipe.Append(mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer)); // At this point, we have a choice. We could continue working with the dynamically-typed pipeline, and // ultimately call dynamicPipe.Fit(data.AsDynamic) to get the model, or we could go back into the static world. 
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs index b3c33095d0..1782706b6c 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs @@ -28,9 +28,9 @@ public void Metacomponents() new SdcaBinaryTrainer.Options { MaxIterations = 100, Shuffle = true, NumThreads = 1, }); var pipeline = new ColumnConcatenatingEstimator (ml, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") - .Append(new ValueToKeyMappingEstimator(ml, "Label"), TransformerScope.TrainTest) - .Append(new Ova(ml, sdcaTrainer)) - .Append(new KeyToValueMappingEstimator(ml, "PredictedLabel")); + .Append(ml.Transforms.Conversion.MapValueToKey("Label"), TransformerScope.TrainTest) + .Append(ml.MulticlassClassification.Trainers.OneVersusAll(sdcaTrainer)) + .Append(ml.Transforms.Conversion.MapKeyToValue(("PredictedLabel"))); var model = pipeline.Fit(data); } diff --git a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs index 47cab13ce5..ecea5411a6 100644 --- a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs +++ b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs @@ -34,10 +34,8 @@ public void OvaLogisticRegression() var data = reader.Read(GetDataPath(dataPath)); // Pipeline - var pipeline = new Ova( - mlContext, - mlContext.BinaryClassification.Trainers.LogisticRegression(), - useProbabilities: false); + var logReg = mlContext.BinaryClassification.Trainers.LogisticRegression(); + var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll(logReg, useProbabilities: false); var model = pipeline.Fit(data); var predictions = model.Transform(data); @@ -68,11 +66,9 @@ public void OvaAveragedPerceptron() var data = mlContext.Data.Cache(reader.Read(GetDataPath(dataPath))); // Pipeline - var pipeline = new Ova( - mlContext, - 
mlContext.BinaryClassification.Trainers.AveragedPerceptron( - new AveragedPerceptronTrainer.Options { Shuffle = true, Calibrator = null }), - useProbabilities: false); + var ap = mlContext.BinaryClassification.Trainers.AveragedPerceptron( + new AveragedPerceptronTrainer.Options { Shuffle = true, Calibrator = null }); + var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll(ap, useProbabilities: false); var model = pipeline.Fit(data); var predictions = model.Transform(data); @@ -103,8 +99,7 @@ public void OvaFastTree() var data = reader.Read(GetDataPath(dataPath)); // Pipeline - var pipeline = new Ova( - mlContext, + var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll( mlContext.BinaryClassification.Trainers.FastTree(new FastTreeBinaryClassificationTrainer.Options { NumThreads = 1 }), useProbabilities: false); @@ -137,7 +132,7 @@ public void OvaLinearSvm() var data = mlContext.Data.Cache(reader.Read(GetDataPath(dataPath))); // Pipeline - var pipeline = new Ova(mlContext, + var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll( mlContext.BinaryClassification.Trainers.LinearSupportVectorMachines(new LinearSvmTrainer.Options { NumIterations = 100 }), useProbabilities: false); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs index 86b1ce026c..7e0a34b1b6 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs @@ -26,7 +26,10 @@ public void OVAWithAllConstructorArgs() var averagePerceptron = ML.BinaryClassification.Trainers.AveragedPerceptron( new AveragedPerceptronTrainer.Options { Shuffle = true, Calibrator = null }); - pipeline = pipeline.Append(new Ova(Env, averagePerceptron, "Label", true, calibrator: calibrator, 10000, true)) + var ova = ML.MulticlassClassification.Trainers.OneVersusAll(averagePerceptron, 
imputeMissingLabelsAsNegative: true, + calibrator: calibrator, maxCalibrationExamples: 10000, useProbabilities: true); + + pipeline = pipeline.Append(ova) .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel")); TestEstimatorCore(pipeline, data); @@ -43,7 +46,7 @@ public void OVAUncalibrated() var sdcaTrainer = ML.BinaryClassification.Trainers.StochasticDualCoordinateAscent( new SdcaBinaryTrainer.Options { MaxIterations = 100, Shuffle = true, NumThreads = 1, Calibrator = null }); - pipeline = pipeline.Append(new Ova(Env, sdcaTrainer, useProbabilities: false)) + pipeline = pipeline.Append(ML.MulticlassClassification.Trainers.OneVersusAll(sdcaTrainer, useProbabilities: false)) .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel")); TestEstimatorCore(pipeline, data); @@ -61,8 +64,8 @@ public void Pkpd() var sdcaTrainer = ML.BinaryClassification.Trainers.StochasticDualCoordinateAscent( new SdcaBinaryTrainer.Options { MaxIterations = 100, Shuffle = true, NumThreads = 1 }); - pipeline = pipeline.Append(new Pkpd(Env, sdcaTrainer)) - .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel")); + pipeline = pipeline.Append(ML.MulticlassClassification.Trainers.PairwiseCoupling(sdcaTrainer)) + .Append(ML.Transforms.Conversion.MapKeyToValue("PredictedLabel")); TestEstimatorCore(pipeline, data); Done(); @@ -84,7 +87,7 @@ public void MetacomponentsFeaturesRenamed() var pipeline = new ColumnConcatenatingEstimator(Env, "Vars", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(new ValueToKeyMappingEstimator(Env, "Label"), TransformerScope.TrainTest) - .Append(new Ova(Env, sdcaTrainer)) + .Append(ML.MulticlassClassification.Trainers.OneVersusAll(sdcaTrainer)) .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel")); var model = pipeline.Fit(data);