// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.IO;
using System.IO.Compression;
using System.Linq;
using Microsoft.Data.DataView;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Trainers.Ensemble;
[assembly: LoadableClass(typeof(void), typeof(EnsembleCreator), null, typeof(SignatureEntryPointModule), "CreateEnsemble")]
namespace Microsoft.ML.Trainers.Ensemble
{
/// <summary>
/// A component to combine given models into an ensemble model.
/// </summary>
internal static class EnsembleCreator
{
    /// <summary>
    /// These are the combiner options for binary and multi class classifiers.
    /// </summary>
    public enum ClassifierCombiner
    {
        Median,
        Average,
        Vote,
    }

    /// <summary>
    /// These are the combiner options for regression and anomaly detection.
    /// </summary>
    public enum ScoreCombiner
    {
        Median,
        Average,
    }

    /// <summary>
    /// Base input for the pipeline-ensemble entry points: only the models are required,
    /// since pipeline ensembles keep each model's own transform pipeline.
    /// </summary>
    public abstract class PipelineInputBase
    {
        [Argument(ArgumentType.Required, ShortName = "models", HelpText = "The models to combine into an ensemble", SortOrder = 1)]
        public PredictorModel[] Models;
    }

    /// <summary>
    /// Base input for the predictor-ensemble entry points. These require all models to share
    /// one pipeline, so pipeline validation can optionally be performed.
    /// </summary>
    public abstract class InputBase
    {
        [Argument(ArgumentType.Required, ShortName = "models", HelpText = "The models to combine into an ensemble", SortOrder = 1)]
        public PredictorModel[] Models;

        [Argument(ArgumentType.AtMostOnce, ShortName = "validate", HelpText = "Whether to validate that all the pipelines are identical", SortOrder = 5)]
        public bool ValidatePipelines = true;
    }

    public sealed class ClassifierInput : InputBase
    {
        [Argument(ArgumentType.AtMostOnce, ShortName = "combiner", HelpText = "The combiner used to combine the scores", SortOrder = 2)]
        public ClassifierCombiner ModelCombiner = ClassifierCombiner.Median;
    }

    public sealed class PipelineClassifierInput : PipelineInputBase
    {
        [Argument(ArgumentType.AtMostOnce, ShortName = "combiner", HelpText = "The combiner used to combine the scores", SortOrder = 2)]
        public ClassifierCombiner ModelCombiner = ClassifierCombiner.Median;
    }

    public sealed class RegressionInput : InputBase
    {
        [Argument(ArgumentType.AtMostOnce, ShortName = "combiner", HelpText = "The combiner used to combine the scores", SortOrder = 2)]
        public ScoreCombiner ModelCombiner = ScoreCombiner.Median;
    }

    public sealed class PipelineRegressionInput : PipelineInputBase
    {
        [Argument(ArgumentType.AtMostOnce, ShortName = "combiner", HelpText = "The combiner used to combine the scores", SortOrder = 2)]
        public ScoreCombiner ModelCombiner = ScoreCombiner.Median;
    }

    public sealed class PipelineAnomalyInput : PipelineInputBase
    {
        [Argument(ArgumentType.AtMostOnce, ShortName = "combiner", HelpText = "The combiner used to combine the scores", SortOrder = 2)]
        public ScoreCombiner ModelCombiner = ScoreCombiner.Average;
    }

    /// <summary>
    /// Extracts the shared pipeline from the first model (as an empty starting view plus the
    /// transformed <see cref="RoleMappedData"/>), and, when <see cref="InputBase.ValidatePipelines"/>
    /// is set, verifies that every subsequent model's pipeline serializes identically.
    /// </summary>
    /// <exception cref="Exception">Thrown via the channel if a model's pipeline differs from the first.</exception>
    private static void GetPipeline(IHostEnvironment env, InputBase input, out IDataView startingData, out RoleMappedData transformedData)
    {
        Contracts.AssertValue(env);
        env.AssertValue(input);
        env.AssertNonEmpty(input.Models);

        Schema inputSchema = null;
        startingData = null;
        transformedData = null;
        // Lazily computed serialization of the first model's pipeline, reused for every comparison.
        byte[][] transformedDataSerialized = null;
        string[] transformedDataZipEntryNames = null;
        for (int i = 0; i < input.Models.Length; i++)
        {
            var model = input.Models[i];
            var inputData = new EmptyDataView(env, model.TransformModel.InputSchema);
            model.PrepareData(env, inputData, out RoleMappedData transformedDataCur, out IPredictor _);
            if (inputSchema == null)
            {
                env.Assert(i == 0);
                inputSchema = model.TransformModel.InputSchema;
                startingData = inputData;
                transformedData = transformedDataCur;
            }
            else if (input.ValidatePipelines)
            {
                using (var ch = env.Start("Validating pipeline"))
                {
                    if (transformedDataSerialized == null)
                    {
                        ch.Assert(transformedDataZipEntryNames == null);
                        SerializeRoleMappedData(env, ch, transformedData, out transformedDataSerialized,
                            out transformedDataZipEntryNames);
                    }
                    CheckSamePipeline(env, ch, transformedDataCur, transformedDataSerialized, transformedDataZipEntryNames);
                }
            }
        }
    }

    /// <summary>
    /// Combines binary classifiers sharing a single pipeline into one ensemble predictor model.
    /// </summary>
    [TlcModule.EntryPoint(Name = "Models.BinaryEnsemble", Desc = "Combine binary classifiers into an ensemble", UserName = EnsembleTrainer.UserNameValue)]
    public static CommonOutputs.BinaryClassificationOutput CreateBinaryEnsemble(IHostEnvironment env, ClassifierInput input)
    {
        Contracts.CheckValue(env, nameof(env));
        var host = env.Register("CombineModels");
        host.CheckValue(input, nameof(input));
        host.CheckNonEmpty(input.Models, nameof(input.Models));
        GetPipeline(host, input, out IDataView startingData, out RoleMappedData transformedData);

        var args = new EnsembleTrainer.Arguments();
        switch (input.ModelCombiner)
        {
            case ClassifierCombiner.Median:
                args.OutputCombiner = new MedianFactory();
                break;
            case ClassifierCombiner.Average:
                args.OutputCombiner = new AverageFactory();
                break;
            case ClassifierCombiner.Vote:
                args.OutputCombiner = new VotingFactory();
                break;
            default:
                throw host.Except("Unknown combiner kind");
        }
        var trainer = new EnsembleTrainer(host, args);
        var ensemble = trainer.CombineModels(input.Models.Select(pm => pm.Predictor as IPredictorProducing<float>));
        var predictorModel = new PredictorModelImpl(host, transformedData, startingData, ensemble);
        var output = new CommonOutputs.BinaryClassificationOutput { PredictorModel = predictorModel };
        return output;
    }

    /// <summary>
    /// Combines regression models sharing a single pipeline into one ensemble predictor model.
    /// </summary>
    [TlcModule.EntryPoint(Name = "Models.RegressionEnsemble", Desc = "Combine regression models into an ensemble", UserName = RegressionEnsembleTrainer.UserNameValue)]
    public static CommonOutputs.RegressionOutput CreateRegressionEnsemble(IHostEnvironment env, RegressionInput input)
    {
        Contracts.CheckValue(env, nameof(env));
        var host = env.Register("CombineModels");
        host.CheckValue(input, nameof(input));
        host.CheckNonEmpty(input.Models, nameof(input.Models));
        GetPipeline(host, input, out IDataView startingData, out RoleMappedData transformedData);

        var args = new RegressionEnsembleTrainer.Arguments();
        switch (input.ModelCombiner)
        {
            case ScoreCombiner.Median:
                args.OutputCombiner = new MedianFactory();
                break;
            case ScoreCombiner.Average:
                args.OutputCombiner = new AverageFactory();
                break;
            default:
                throw host.Except("Unknown combiner kind");
        }
        var trainer = new RegressionEnsembleTrainer(host, args);
        var ensemble = trainer.CombineModels(input.Models.Select(pm => pm.Predictor as IPredictorProducing<float>));
        var predictorModel = new PredictorModelImpl(host, transformedData, startingData, ensemble);
        var output = new CommonOutputs.RegressionOutput { PredictorModel = predictorModel };
        return output;
    }

    /// <summary>
    /// Combines binary classification models (each with its own pipeline) into a pipeline ensemble.
    /// </summary>
    [TlcModule.EntryPoint(Name = "Models.BinaryPipelineEnsemble", Desc = "Combine binary classification models into an ensemble")]
    public static CommonOutputs.BinaryClassificationOutput CreateBinaryPipelineEnsemble(IHostEnvironment env, PipelineClassifierInput input)
    {
        Contracts.CheckValue(env, nameof(env));
        var host = env.Register("CombineModels");
        host.CheckValue(input, nameof(input));
        host.CheckNonEmpty(input.Models, nameof(input.Models));

        IBinaryOutputCombiner combiner;
        switch (input.ModelCombiner)
        {
            case ClassifierCombiner.Median:
                combiner = new Median(host);
                break;
            case ClassifierCombiner.Average:
                combiner = new Average(host);
                break;
            case ClassifierCombiner.Vote:
                combiner = new Voting(host);
                break;
            default:
                throw host.Except("Unknown combiner kind");
        }
        var ensemble = SchemaBindablePipelineEnsembleBase.Create(host, input.Models, combiner, MetadataUtils.Const.ScoreColumnKind.BinaryClassification);
        return CreatePipelineEnsemble<CommonOutputs.BinaryClassificationOutput>(host, input.Models, ensemble);
    }

    /// <summary>
    /// Combines regression models (each with its own pipeline) into a pipeline ensemble.
    /// </summary>
    [TlcModule.EntryPoint(Name = "Models.RegressionPipelineEnsemble", Desc = "Combine regression models into an ensemble")]
    public static CommonOutputs.RegressionOutput CreateRegressionPipelineEnsemble(IHostEnvironment env, PipelineRegressionInput input)
    {
        Contracts.CheckValue(env, nameof(env));
        var host = env.Register("CombineModels");
        host.CheckValue(input, nameof(input));
        host.CheckNonEmpty(input.Models, nameof(input.Models));

        IRegressionOutputCombiner combiner;
        switch (input.ModelCombiner)
        {
            case ScoreCombiner.Median:
                combiner = new Median(host);
                break;
            case ScoreCombiner.Average:
                combiner = new Average(host);
                break;
            default:
                throw host.Except("Unknown combiner kind");
        }
        var ensemble = SchemaBindablePipelineEnsembleBase.Create(host, input.Models, combiner, MetadataUtils.Const.ScoreColumnKind.Regression);
        return CreatePipelineEnsemble<CommonOutputs.RegressionOutput>(host, input.Models, ensemble);
    }

    /// <summary>
    /// Combines multiclass classifiers (each with its own pipeline) into a pipeline ensemble.
    /// </summary>
    [TlcModule.EntryPoint(Name = "Models.MultiClassPipelineEnsemble", Desc = "Combine multiclass classifiers into an ensemble")]
    public static CommonOutputs.MulticlassClassificationOutput CreateMultiClassPipelineEnsemble(IHostEnvironment env, PipelineClassifierInput input)
    {
        Contracts.CheckValue(env, nameof(env));
        var host = env.Register("CombineModels");
        host.CheckValue(input, nameof(input));
        host.CheckNonEmpty(input.Models, nameof(input.Models));

        IOutputCombiner<VBuffer<Single>> combiner;
        switch (input.ModelCombiner)
        {
            case ClassifierCombiner.Median:
                combiner = new MultiMedian(host, new MultiMedian.Arguments() { Normalize = true });
                break;
            case ClassifierCombiner.Average:
                combiner = new MultiAverage(host, new MultiAverage.Arguments() { Normalize = true });
                break;
            case ClassifierCombiner.Vote:
                combiner = new MultiVoting(host);
                break;
            default:
                throw host.Except("Unknown combiner kind");
        }
        var ensemble = SchemaBindablePipelineEnsembleBase.Create(host, input.Models, combiner, MetadataUtils.Const.ScoreColumnKind.MultiClassClassification);
        return CreatePipelineEnsemble<CommonOutputs.MulticlassClassificationOutput>(host, input.Models, ensemble);
    }

    /// <summary>
    /// Combines anomaly detection models (each with its own pipeline) into a pipeline ensemble.
    /// </summary>
    [TlcModule.EntryPoint(Name = "Models.AnomalyPipelineEnsemble", Desc = "Combine anomaly detection models into an ensemble")]
    public static CommonOutputs.AnomalyDetectionOutput CreateAnomalyPipelineEnsemble(IHostEnvironment env, PipelineAnomalyInput input)
    {
        Contracts.CheckValue(env, nameof(env));
        var host = env.Register("CombineModels");
        host.CheckValue(input, nameof(input));
        host.CheckNonEmpty(input.Models, nameof(input.Models));

        IRegressionOutputCombiner combiner;
        switch (input.ModelCombiner)
        {
            case ScoreCombiner.Median:
                combiner = new Median(host);
                break;
            case ScoreCombiner.Average:
                combiner = new Average(host);
                break;
            default:
                throw host.Except("Unknown combiner kind");
        }
        var ensemble = SchemaBindablePipelineEnsembleBase.Create(host, input.Models, combiner, MetadataUtils.Const.ScoreColumnKind.AnomalyDetection);
        return CreatePipelineEnsemble<CommonOutputs.AnomalyDetectionOutput>(host, input.Models, ensemble);
    }

    /// <summary>
    /// Wraps a pipeline ensemble into the requested entry-point output type, using the first
    /// model's input schema as the starting data view.
    /// </summary>
    private static TOut CreatePipelineEnsemble<TOut>(IHostEnvironment env, PredictorModel[] predictors, SchemaBindablePipelineEnsembleBase ensemble)
        where TOut : CommonOutputs.TrainerOutput, new()
    {
        var inputSchema = predictors[0].TransformModel.InputSchema;
        var dv = new EmptyDataView(env, inputSchema);

        // The role mappings are specific to the individual predictors.
        var rmd = new RoleMappedData(dv);
        var predictorModel = new PredictorModelImpl(env, rmd, dv, ensemble);
        var output = new TOut { PredictorModel = predictorModel };
        return output;
    }

    /// <summary>
    /// Reads up to <paramref name="count"/> bytes from <paramref name="stream"/> into
    /// <paramref name="buffer"/> starting at offset 0, looping over partial reads.
    /// <see cref="Stream.Read(byte[], int, int)"/> may legally return fewer bytes than
    /// requested, so a single call is not enough. Returns the number of bytes actually read,
    /// which is less than <paramref name="count"/> only if the stream ended early.
    /// </summary>
    private static int ReadFully(Stream stream, byte[] buffer, int count)
    {
        int total = 0;
        while (total < count)
        {
            int read = stream.Read(buffer, total, count - total);
            if (read <= 0)
                break;
            total += read;
        }
        return total;
    }

    /// <summary>
    /// This method takes a <see cref="RoleMappedData"/> as input, saves it as an in-memory <see cref="ZipArchive"/>
    /// and returns two arrays indexed by the entries in the zip:
    /// 1. An array of byte arrays, containing the byte sequences of each entry.
    /// 2. An array of strings, containing the name of each entry.
    ///
    /// This method is used for comparing pipelines. Its outputs can be passed to <see cref="CheckSamePipeline"/>
    /// to check if this pipeline is identical to another pipeline.
    /// </summary>
    internal static void SerializeRoleMappedData(IHostEnvironment env, IChannel ch, RoleMappedData data,
        out byte[][] dataSerialized, out string[] dataZipEntryNames)
    {
        Contracts.CheckValue(env, nameof(env));
        env.CheckValue(ch, nameof(ch));
        ch.CheckValue(data, nameof(data));

        using (var ms = new MemoryStream())
        {
            TrainUtils.SaveModel(env, ch, ms, null, data);
            // leaveOpen: true so disposing the archive does not dispose ms out from
            // under the enclosing using block.
            using (var zip = new ZipArchive(ms, ZipArchiveMode.Read, leaveOpen: true))
            {
                // Order entries by name so comparisons are independent of archive layout.
                var entries = zip.Entries.OrderBy(e => e.FullName).ToArray();
                dataSerialized = new byte[Utils.Size(entries)][];
                dataZipEntryNames = new string[Utils.Size(entries)];
                for (int i = 0; i < Utils.Size(entries); i++)
                {
                    dataZipEntryNames[i] = entries[i].FullName;
                    dataSerialized[i] = new byte[entries[i].Length];
                    using (var s = entries[i].Open())
                    {
                        // Loop over partial reads; a single Read call may return fewer bytes.
                        int bytesRead = ReadFully(s, dataSerialized[i], dataSerialized[i].Length);
                        ch.Check(bytesRead == dataSerialized[i].Length, "Failed to read the whole zip entry.");
                    }
                }
            }
        }
    }

    /// <summary>
    /// This method compares two pipelines to make sure they are identical. The first pipeline is passed
    /// as a <see cref="RoleMappedData"/>, and the second as a double byte array and a string array. The double
    /// byte array and the string array are obtained by calling <see cref="SerializeRoleMappedData"/> on the
    /// second pipeline.
    /// The comparison is done by saving <see ref="dataToCompare"/> as an in-memory <see cref="ZipArchive"/>,
    /// and for each entry in it, comparing its name, and the byte sequence to the corresponding entries in
    /// <see ref="dataZipEntryNames"/> and <see ref="dataSerialized"/>.
    /// This method throws if for any of the entries the name/byte sequence are not identical.
    /// </summary>
    internal static void CheckSamePipeline(IHostEnvironment env, IChannel ch,
        RoleMappedData dataToCompare, byte[][] dataSerialized, string[] dataZipEntryNames)
    {
        Contracts.CheckValue(env, nameof(env));
        env.CheckValue(ch, nameof(ch));
        ch.CheckValue(dataToCompare, nameof(dataToCompare));
        ch.CheckValue(dataSerialized, nameof(dataSerialized));
        ch.CheckValue(dataZipEntryNames, nameof(dataZipEntryNames));
        if (dataZipEntryNames.Length != dataSerialized.Length)
        {
            throw ch.ExceptParam(nameof(dataSerialized),
                $"The length of {nameof(dataSerialized)} must be equal to the length of {nameof(dataZipEntryNames)}");
        }

        using (var ms = new MemoryStream())
        {
            // REVIEW: This can be done more efficiently by adding a custom type of repository that
            // doesn't actually save the data, but upon stream closure compares the results to the given repository
            // and then discards it. Currently, however, this cannot be done because ModelSaveContext does not use
            // an abstract class/interface, but rather the RepositoryWriter class.
            TrainUtils.SaveModel(env, ch, ms, null, dataToCompare);
            string errorMsg = "Models contain different pipelines, cannot ensemble them.";
            // leaveOpen: true so disposing the archive does not dispose ms out from
            // under the enclosing using block.
            using (var zip = new ZipArchive(ms, ZipArchiveMode.Read, leaveOpen: true))
            {
                var entries = zip.Entries.OrderBy(e => e.FullName).ToArray();
                ch.Check(dataSerialized.Length == Utils.Size(entries));
                byte[] buffer = null;
                for (int i = 0; i < dataSerialized.Length; i++)
                {
                    ch.Check(dataZipEntryNames[i] == entries[i].FullName, errorMsg);
                    int len = dataSerialized[i].Length;
                    if (Utils.Size(buffer) < len)
                        buffer = new byte[len];
                    using (var s = entries[i].Open())
                    {
                        // Loop over partial reads so a short Read is not mistaken for a
                        // pipeline mismatch.
                        ch.Check(ReadFully(s, buffer, len) == len, errorMsg);
                        for (int j = 0; j < len; j++)
                            ch.Check(buffer[j] == dataSerialized[i][j], errorMsg);
                        // The entry must also not contain any bytes beyond the expected length.
                        if (s.Read(buffer, 0, 1) > 0)
                            throw ch.Except(errorMsg);
                    }
                }
            }
        }
    }
}
}