/
Ensemble.cs
54 lines (46 loc) · 3.02 KB
/
Ensemble.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Trainers.Ensemble;
[assembly: LoadableClass(typeof(void), typeof(Ensemble), null, typeof(SignatureEntryPointModule), "TrainEnsemble")]
namespace Microsoft.ML.Trainers.Ensemble
{
/// <summary>
/// Entry-point module exposing the ensemble trainers (binary classification,
/// multiclass classification and regression) as <c>TlcModule.EntryPoint</c> methods.
/// Each method validates its inputs, registers a dedicated host and delegates the
/// actual training to <c>LearnerEntryPointsUtils.Train</c>.
/// </summary>
internal static class Ensemble
{
    /// <summary>
    /// Entry point <c>Trainers.EnsembleBinaryClassifier</c>: trains a binary-classification
    /// ensemble using <see cref="EnsembleTrainer"/>.
    /// </summary>
    /// <param name="env">Host environment; checked with <c>Contracts.CheckValue</c> before anything else.</param>
    /// <param name="input">Trainer arguments; checked on the registered host and then with <c>EntryPointUtils.CheckInputArgs</c>.</param>
    /// <returns>The output produced by <c>LearnerEntryPointsUtils.Train</c>.</returns>
    [TlcModule.EntryPoint(Name = "Trainers.EnsembleBinaryClassifier", Desc = "Train binary ensemble.", UserName = EnsembleTrainer.UserNameValue)]
    public static CommonOutputs.BinaryClassificationOutput CreateBinaryEnsemble(IHostEnvironment env, EnsembleTrainer.Arguments input)
    {
        Contracts.CheckValue(env, nameof(env));
        var trainHost = env.Register("TrainBinaryEnsemble");
        trainHost.CheckValue(input, nameof(input));
        EntryPointUtils.CheckInputArgs(trainHost, input);

        // Both helpers are handed to Train as delegates rather than invoked here,
        // so Train controls whether and when the trainer is built and the label column is looked up.
        EnsembleTrainer BuildTrainer() => new EnsembleTrainer(trainHost, input);
        string ResolveLabelColumn() =>
            LearnerEntryPointsUtils.FindColumn(trainHost, input.TrainingData.Schema, input.LabelColumn);

        return LearnerEntryPointsUtils.Train<EnsembleTrainer.Arguments, CommonOutputs.BinaryClassificationOutput>(
            trainHost, input, BuildTrainer, ResolveLabelColumn);
    }

    /// <summary>
    /// Entry point <c>Trainers.EnsembleClassification</c>: trains a multiclass-classification
    /// ensemble using <see cref="MulticlassDataPartitionEnsembleTrainer"/>.
    /// </summary>
    /// <param name="env">Host environment; checked with <c>Contracts.CheckValue</c> before anything else.</param>
    /// <param name="input">Trainer arguments; checked on the registered host and then with <c>EntryPointUtils.CheckInputArgs</c>.</param>
    /// <returns>The output produced by <c>LearnerEntryPointsUtils.Train</c>.</returns>
    // NOTE(review): UserName reuses EnsembleTrainer.UserNameValue (the binary trainer's name);
    // confirm whether the multiclass trainer's own user-facing name was intended here.
    [TlcModule.EntryPoint(Name = "Trainers.EnsembleClassification", Desc = "Train multiclass ensemble.", UserName = EnsembleTrainer.UserNameValue)]
    public static CommonOutputs.MulticlassClassificationOutput CreateMultiClassEnsemble(IHostEnvironment env, MulticlassDataPartitionEnsembleTrainer.Arguments input)
    {
        Contracts.CheckValue(env, nameof(env));
        var trainHost = env.Register("TrainMultiClassEnsemble");
        trainHost.CheckValue(input, nameof(input));
        EntryPointUtils.CheckInputArgs(trainHost, input);

        // Deferred factories: Train decides when to construct the trainer and resolve the label column.
        MulticlassDataPartitionEnsembleTrainer BuildTrainer() => new MulticlassDataPartitionEnsembleTrainer(trainHost, input);
        string ResolveLabelColumn() =>
            LearnerEntryPointsUtils.FindColumn(trainHost, input.TrainingData.Schema, input.LabelColumn);

        return LearnerEntryPointsUtils.Train<MulticlassDataPartitionEnsembleTrainer.Arguments, CommonOutputs.MulticlassClassificationOutput>(
            trainHost, input, BuildTrainer, ResolveLabelColumn);
    }

    /// <summary>
    /// Entry point <c>Trainers.EnsembleRegression</c>: trains a regression ensemble
    /// using <see cref="RegressionEnsembleTrainer"/>.
    /// </summary>
    /// <param name="env">Host environment; checked with <c>Contracts.CheckValue</c> before anything else.</param>
    /// <param name="input">Trainer arguments; checked on the registered host and then with <c>EntryPointUtils.CheckInputArgs</c>.</param>
    /// <returns>The output produced by <c>LearnerEntryPointsUtils.Train</c>.</returns>
    // NOTE(review): UserName reuses EnsembleTrainer.UserNameValue (the binary trainer's name);
    // confirm whether the regression trainer's own user-facing name was intended here.
    [TlcModule.EntryPoint(Name = "Trainers.EnsembleRegression", Desc = "Train regression ensemble.", UserName = EnsembleTrainer.UserNameValue)]
    public static CommonOutputs.RegressionOutput CreateRegressionEnsemble(IHostEnvironment env, RegressionEnsembleTrainer.Arguments input)
    {
        Contracts.CheckValue(env, nameof(env));
        var trainHost = env.Register("TrainRegressionEnsemble");
        trainHost.CheckValue(input, nameof(input));
        EntryPointUtils.CheckInputArgs(trainHost, input);

        // Deferred factories: Train decides when to construct the trainer and resolve the label column.
        RegressionEnsembleTrainer BuildTrainer() => new RegressionEnsembleTrainer(trainHost, input);
        string ResolveLabelColumn() =>
            LearnerEntryPointsUtils.FindColumn(trainHost, input.TrainingData.Schema, input.LabelColumn);

        return LearnerEntryPointsUtils.Train<RegressionEnsembleTrainer.Arguments, CommonOutputs.RegressionOutput>(
            trainHost, input, BuildTrainer, ResolveLabelColumn);
    }
}
}