Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 0 additions & 109 deletions test/Microsoft.ML.Benchmarks/LegacyPredictionEngineBench.cs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,23 @@
using BenchmarkDotNet.Attributes;
using BenchmarkDotNet.Engines;
using Microsoft.ML.Data;
using Microsoft.ML.Legacy.Models;
using Microsoft.ML.Legacy.Trainers;
using Microsoft.ML.Legacy.Transforms;
using Microsoft.ML.Learners;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
using Microsoft.ML.Transforms.Text;

namespace Microsoft.ML.Benchmarks
{
#pragma warning disable 612, 618
public class StochasticDualCoordinateAscentClassifierBench : WithExtraMetrics
{
private readonly string _dataPath = Program.GetInvariantCultureDataPath("iris.txt");
private readonly string _sentimentDataPath = Program.GetInvariantCultureDataPath("wikipedia-detox-250-line-data.tsv");
private readonly Consumer _consumer = new Consumer(); // BenchmarkDotNet utility type used to prevent dead code elimination

private readonly MLContext _env = new MLContext(seed: 1);

private readonly int[] _batchSizes = new int[] { 1, 2, 5 };

private readonly IrisData _example = new IrisData()
{
SepalLength = 3.3f,
Expand All @@ -31,37 +32,47 @@ public class StochasticDualCoordinateAscentClassifierBench : WithExtraMetrics
PetalWidth = 5.1f,
};

private Legacy.PredictionModel<IrisData, IrisPrediction> _trainedModel;
private TransformerChain<MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>> _trainedModel;
private PredictionEngine<IrisData, IrisPrediction> _predictionEngine;
private IrisData[][] _batches;
private ClassificationMetrics _metrics;
private MultiClassClassifierMetrics _metrics;

protected override IEnumerable<Metric> GetMetrics()
{
if (_metrics != null)
yield return new Metric(
nameof(ClassificationMetrics.AccuracyMacro),
nameof(MultiClassClassifierMetrics.AccuracyMacro),
_metrics.AccuracyMacro.ToString("0.##", CultureInfo.InvariantCulture));
}

[Benchmark]
public Legacy.PredictionModel<IrisData, IrisPrediction> TrainIris() => Train(_dataPath);
public TransformerChain<MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>> TrainIris() => Train(_dataPath);

private Legacy.PredictionModel<IrisData, IrisPrediction> Train(string dataPath)
private TransformerChain<MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>> Train(string dataPath)
{
var pipeline = new Legacy.LearningPipeline();
var reader = new TextLoader(_env,
columns: new[]
{
new TextLoader.Column("Label", DataKind.R4, 0),
new TextLoader.Column("SepalLength", DataKind.R4, 1),
new TextLoader.Column("SepalWidth", DataKind.R4, 2),
new TextLoader.Column("PetalLength", DataKind.R4, 3),
new TextLoader.Column("PetalWidth", DataKind.R4, 4),
},
hasHeader: true
);

pipeline.Add(new Legacy.Data.TextLoader(dataPath).CreateFrom<IrisData>(useHeader: true));
pipeline.Add(new ColumnConcatenator(outputColumn: "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));
IDataView data = reader.Read(dataPath);

pipeline.Add(new StochasticDualCoordinateAscentClassifier());
var pipeline = new ColumnConcatenatingEstimator(_env, "Features", new[] { "SepalLength", "SepalWidth", "PetalLength", "PetalWidth" })
.Append(new SdcaMultiClassTrainer(_env, "Label", "Features"));

return pipeline.Train<IrisData, IrisPrediction>();
return pipeline.Fit(data);
}

[Benchmark]
public void TrainSentiment()
{
var env = new MLContext(seed: 1);
// Pipeline
var arguments = new TextLoader.Arguments()
{
Expand All @@ -85,9 +96,9 @@ public void TrainSentiment()
AllowQuoting = false,
AllowSparse = false
};
var loader = env.Data.ReadFromTextFile(_sentimentDataPath, arguments);
var loader = _env.Data.ReadFromTextFile(_sentimentDataPath, arguments);

var text = TextFeaturizingEstimator.Create(env,
var text = TextFeaturizingEstimator.Create(_env,
new TextFeaturizingEstimator.Arguments()
{
Column = new TextFeaturizingEstimator.Column
Expand All @@ -103,7 +114,7 @@ public void TrainSentiment()
WordFeatureExtractor = null,
}, loader);

var trans = WordEmbeddingsExtractingTransformer.Create(env,
var trans = WordEmbeddingsExtractingTransformer.Create(_env,
new WordEmbeddingsExtractingTransformer.Arguments()
{
Column = new WordEmbeddingsExtractingTransformer.Column[1]
Expand All @@ -118,7 +129,7 @@ public void TrainSentiment()
}, text);

// Train
var trainer = new SdcaMultiClassTrainer(env, "Label", "Features", maxIterations: 20);
var trainer = new SdcaMultiClassTrainer(_env, "Label", "Features", maxIterations: 20);
var predicted = trainer.Fit(trans);
_consumer.Consume(predicted);
}
Expand All @@ -127,41 +138,49 @@ public void TrainSentiment()
public void SetupPredictBenchmarks()
{
_trainedModel = Train(_dataPath);
_consumer.Consume(_trainedModel.Predict(_example));
_predictionEngine = _trainedModel.CreatePredictionEngine<IrisData, IrisPrediction>(_env);
_consumer.Consume(_predictionEngine.Predict(_example));

var reader = new TextLoader(_env,
columns: new[]
{
new TextLoader.Column("Label", DataKind.R4, 0),
new TextLoader.Column("SepalLength", DataKind.R4, 1),
new TextLoader.Column("SepalWidth", DataKind.R4, 2),
new TextLoader.Column("PetalLength", DataKind.R4, 3),
new TextLoader.Column("PetalWidth", DataKind.R4, 4),
},
hasHeader: true
);

var testData = new Legacy.Data.TextLoader(_dataPath).CreateFrom<IrisData>(useHeader: true);
var evaluator = new ClassificationEvaluator();
_metrics = evaluator.Evaluate(_trainedModel, testData);
IDataView testData = reader.Read(_dataPath);
IDataView scoredTestData = _trainedModel.Transform(testData);
var evaluator = new MultiClassClassifierEvaluator(_env, new MultiClassClassifierEvaluator.Arguments());
_metrics = evaluator.Evaluate(scoredTestData, DefaultColumnNames.Label, DefaultColumnNames.Score, DefaultColumnNames.PredictedLabel);

_batches = new IrisData[_batchSizes.Length][];
for (int i = 0; i < _batches.Length; i++)
{
var batch = new IrisData[_batchSizes[i]];
_batches[i] = batch;
for (int bi = 0; bi < batch.Length; bi++)
{
batch[bi] = _example;
}
_batches[i] = batch;
}
}

[Benchmark]
public float[] PredictIris() => _trainedModel.Predict(_example).PredictedLabels;
public float[] PredictIris() => _predictionEngine.Predict(_example).PredictedLabels;

[Benchmark]
public void PredictIrisBatchOf1() => Consume(_trainedModel.Predict(_batches[0]));
public void PredictIrisBatchOf1() => _trainedModel.Transform(_env.CreateStreamingDataView(_batches[0]));

[Benchmark]
public void PredictIrisBatchOf2() => Consume(_trainedModel.Predict(_batches[1]));
public void PredictIrisBatchOf2() => _trainedModel.Transform(_env.CreateStreamingDataView(_batches[1]));

[Benchmark]
public void PredictIrisBatchOf5() => Consume(_trainedModel.Predict(_batches[2]));

private void Consume(IEnumerable<IrisPrediction> predictions)
{
foreach (var prediction in predictions)
_consumer.Consume(prediction);
}
public void PredictIrisBatchOf5() => _trainedModel.Transform(_env.CreateStreamingDataView(_batches[2]));
}

public class IrisData
Expand All @@ -187,5 +206,4 @@ public class IrisPrediction
[ColumnName("Score")]
public float[] PredictedLabels;
}
#pragma warning restore 612, 618
}