Skip to content

Commit 6266a48

Browse files
committed
add sample
1 parent 4bc1ca0 commit 6266a48

File tree

6 files changed

+398
-1
lines changed

6 files changed

+398
-1
lines changed
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using Microsoft.ML;
5+
using Microsoft.ML.Data;
6+
7+
namespace Samples.Dynamic.Trainers.BinaryClassification
{
    /// <summary>
    /// Sample that trains the <c>LdSvm</c> binary classification trainer with
    /// its default settings on synthetic data, prints a handful of
    /// predictions and reports non-calibrated evaluation metrics.
    /// </summary>
    public static class LdSvm
    {
        // Number of feature values carried by every synthetic example.
        private const int FeatureCount = 50;

        /// <summary>
        /// Runs the sample end to end: generate data, train, score, evaluate.
        /// </summary>
        public static void Example()
        {
            // The MLContext serves as the catalog of available operations,
            // handles exception tracking and logging, and is the source of
            // randomness. A fixed seed keeps the printed results reproducible.
            var context = new MLContext(seed: 0);

            // 1000 synthetic training examples (generator default seed).
            IEnumerable<DataPoint> trainingRows = GenerateRandomDataPoints(1000);

            // Wrap the examples in an IDataView so that ML.NET can consume them.
            var trainingView = context.Data.LoadFromEnumerable(trainingRows);

            // Trainer with default settings.
            var trainer = context.BinaryClassification.Trainers.LdSvm();

            // Fit the trainer to the training data.
            var trainedModel = trainer.Fit(trainingView);

            // 500 test examples; a different generator seed makes them differ
            // from the training examples.
            IEnumerable<DataPoint> testRows =
                GenerateRandomDataPoints(500, seed: 123);
            var testView = context.Data.LoadFromEnumerable(testRows);

            // Score the test set with the trained model.
            var scoredTestView = trainedModel.Transform(testView);

            // Materialize the scored rows as a list.
            List<Prediction> scoredRows = context.Data
                .CreateEnumerable<Prediction>(scoredTestView,
                    reuseRowObject: false)
                .ToList();

            // Show the first five predictions next to their true labels.
            foreach (Prediction row in scoredRows.Take(5))
            {
                Console.WriteLine($"Label: {row.Label}, "
                    + $"Prediction: {row.PredictedLabel}");
            }

            // Expected output:
            // Label: True, Prediction: True
            // Label: False, Prediction: True
            // Label: True, Prediction: True
            // Label: True, Prediction: True
            // Label: False, Prediction: False

            // Compute and print the overall metrics for the test set.
            var metrics = context.BinaryClassification
                .EvaluateNonCalibrated(scoredTestView);

            PrintMetrics(metrics);

            // Expected output:
            // Accuracy: 0.82
            // AUC: 0.85
            // F1 Score: 0.81
            // Negative Precision: 0.82
            // Negative Recall: 0.82
            // Positive Precision: 0.81
            // Positive Recall: 0.81

            // TEST POSITIVE RATIO: 0.4760 (238.0/(238.0+262.0))
            // Confusion table
            // ||======================
            // PREDICTED || positive | negative | Recall
            // TRUTH ||======================
            // positive || 192 | 46 | 0.8067
            // negative || 46 | 216 | 0.8244
            // ||======================
            // Precision || 0.8067 | 0.8244 |
        }

        /// <summary>
        /// Lazily produces <paramref name="count"/> synthetic examples whose
        /// feature values are correlated with the label.
        /// </summary>
        /// <param name="count">Number of examples to produce.</param>
        /// <param name="seed">Seed for the random number generator.</param>
        /// <returns>A deferred sequence; each enumeration starts again from
        /// <paramref name="seed"/> and yields the same examples.</returns>
        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0)
        {
            var generator = new Random(seed);
            for (int row = 0; row < count; row++)
            {
                // The label is drawn before the features so the random
                // sequence is consumed in a fixed order.
                bool isPositive = (float)generator.NextDouble() > 0.5f;

                // Examples with a false label get every feature value
                // shifted up by a constant, which is what makes the features
                // informative about the label.
                var features = new float[FeatureCount];
                for (int slot = 0; slot < features.Length; slot++)
                {
                    float sample = (float)generator.NextDouble();
                    features[slot] = isPositive ? sample : sample + 0.1f;
                }

                yield return new DataPoint
                {
                    Label = isPositive,
                    Features = features
                };
            }
        }

        /// <summary>
        /// One example: a label plus a fixed-size feature vector. A data set
        /// is a collection of such examples.
        /// </summary>
        private class DataPoint
        {
            public bool Label { get; set; }

            [VectorType(FeatureCount)]
            public float[] Features { get; set; }
        }

        /// <summary>
        /// Row shape used to read predictions back from the scored data.
        /// </summary>
        private class Prediction
        {
            /// <summary>Original label of the example.</summary>
            public bool Label { get; set; }

            /// <summary>Label predicted by the trained model.</summary>
            public bool PredictedLabel { get; set; }
        }

        /// <summary>
        /// Writes the metrics to the console, two decimals each, followed by
        /// the formatted confusion table.
        /// </summary>
        /// <param name="metrics">Metrics to print.</param>
        private static void PrintMetrics(BinaryClassificationMetrics metrics)
        {
            string[] summaryLines =
            {
                $"Accuracy: {metrics.Accuracy:F2}",
                $"AUC: {metrics.AreaUnderRocCurve:F2}",
                $"F1 Score: {metrics.F1Score:F2}",
                $"Negative Precision: " +
                    $"{metrics.NegativePrecision:F2}",
                $"Negative Recall: {metrics.NegativeRecall:F2}",
                $"Positive Precision: " +
                    $"{metrics.PositivePrecision:F2}",
                // The trailing "\n" leaves a blank line before the table.
                $"Positive Recall: {metrics.PositiveRecall:F2}\n",
            };

            foreach (string line in summaryLines)
            {
                Console.WriteLine(line);
            }

            Console.WriteLine(metrics.ConfusionMatrix.GetFormattedConfusionTable());
        }
    }
}
143+
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
<#@ include file="BinaryClassification.ttinclude"#>
2+
<#+
// Settings for the LdSvm sample. These fields are presumably read by the
// included BinaryClassification.ttinclude -- confirm there how each is used.

// Class name of the sample and the trainer method it demonstrates.
string ClassName = "LdSvm";
string Trainer = "LdSvm";

// No explicit trainer options and no extra using directive for this sample.
string TrainerOptions = null;
string OptionsInclude = "";
string Comments = "";

// Evaluation / data-handling switches.
bool IsCalibrated = false;
bool CacheData = false;

// Constants used when generating the synthetic data.
string LabelThreshold = "0.5f";
string DataSepValue = "0.1f";

// NOTE(review): whitespace inside the verbatim strings below is part of the
// string value -- confirm it matches the indentation wanted in the output.
string ExpectedOutputPerInstance = @"// Expected output:
// Label: True, Prediction: True
// Label: False, Prediction: True
// Label: True, Prediction: True
// Label: True, Prediction: True
// Label: False, Prediction: False";

string ExpectedOutput = @"// Expected output:
// Accuracy: 0.82
// AUC: 0.85
// F1 Score: 0.81
// Negative Precision: 0.82
// Negative Recall: 0.82
// Positive Precision: 0.81
// Positive Recall: 0.81

// TEST POSITIVE RATIO: 0.4760 (238.0/(238.0+262.0))
// Confusion table
// ||======================
// PREDICTED || positive | negative | Recall
// TRUTH ||======================
// positive || 192 | 46 | 0.8067
// negative || 46 | 216 | 0.8244
// ||======================
// Precision || 0.8067 | 0.8244 |";
#>
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using Microsoft.ML;
5+
using Microsoft.ML.Data;
6+
using Microsoft.ML.Trainers;
7+
8+
namespace Samples.Dynamic.Trainers.BinaryClassification
{
    /// <summary>
    /// Sample that trains the <c>LdSvm</c> binary classification trainer with
    /// explicit <c>LdSvmTrainer.Options</c> on synthetic data, prints a
    /// handful of predictions and reports non-calibrated evaluation metrics.
    /// </summary>
    public static class LdSvmWithOptions
    {
        // Number of feature values carried by every synthetic example.
        private const int FeatureCount = 50;

        /// <summary>
        /// Runs the sample end to end: generate data, train, score, evaluate.
        /// </summary>
        public static void Example()
        {
            // The MLContext serves as the catalog of available operations,
            // handles exception tracking and logging, and is the source of
            // randomness. A fixed seed keeps the printed results reproducible.
            var context = new MLContext(seed: 0);

            // 1000 synthetic training examples (generator default seed).
            IEnumerable<DataPoint> trainingRows = GenerateRandomDataPoints(1000);

            // Wrap the examples in an IDataView so that ML.NET can consume them.
            var trainingView = context.Data.LoadFromEnumerable(trainingRows);

            // Settings passed to the trainer instead of its defaults.
            var trainerOptions = new LdSvmTrainer.Options
            {
                TreeDepth = 5,
                NumberOfIterations = 10000,
                Sigma = 0.1f,
            };

            // Trainer configured with the options above.
            var trainer = context.BinaryClassification.Trainers
                .LdSvm(trainerOptions);

            // Fit the trainer to the training data.
            var trainedModel = trainer.Fit(trainingView);

            // 500 test examples; a different generator seed makes them differ
            // from the training examples.
            IEnumerable<DataPoint> testRows =
                GenerateRandomDataPoints(500, seed: 123);
            var testView = context.Data.LoadFromEnumerable(testRows);

            // Score the test set with the trained model.
            var scoredTestView = trainedModel.Transform(testView);

            // Materialize the scored rows as a list.
            List<Prediction> scoredRows = context.Data
                .CreateEnumerable<Prediction>(scoredTestView,
                    reuseRowObject: false)
                .ToList();

            // Show the first five predictions next to their true labels.
            foreach (Prediction row in scoredRows.Take(5))
            {
                Console.WriteLine($"Label: {row.Label}, "
                    + $"Prediction: {row.PredictedLabel}");
            }

            // Expected output:
            // Label: True, Prediction: True
            // Label: False, Prediction: True
            // Label: True, Prediction: True
            // Label: True, Prediction: True
            // Label: False, Prediction: False

            // Compute and print the overall metrics for the test set.
            var metrics = context.BinaryClassification
                .EvaluateNonCalibrated(scoredTestView);

            PrintMetrics(metrics);

            // Expected output:
            // Accuracy: 0.80
            // AUC: 0.89
            // F1 Score: 0.79
            // Negative Precision: 0.81
            // Negative Recall: 0.81
            // Positive Precision: 0.79
            // Positive Recall: 0.79

            // TEST POSITIVE RATIO: 0.4760 (238.0/(238.0+262.0))
            // Confusion table
            // ||======================
            // PREDICTED || positive | negative | Recall
            // TRUTH ||======================
            // positive || 189 | 49 | 0.7941
            // negative || 50 | 212 | 0.8092
            // ||======================
            // Precision || 0.7908 | 0.8123 |
        }

        /// <summary>
        /// Lazily produces <paramref name="count"/> synthetic examples whose
        /// feature values are correlated with the label.
        /// </summary>
        /// <param name="count">Number of examples to produce.</param>
        /// <param name="seed">Seed for the random number generator.</param>
        /// <returns>A deferred sequence; each enumeration starts again from
        /// <paramref name="seed"/> and yields the same examples.</returns>
        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0)
        {
            var generator = new Random(seed);
            for (int row = 0; row < count; row++)
            {
                // The label is drawn before the features so the random
                // sequence is consumed in a fixed order.
                bool isPositive = (float)generator.NextDouble() > 0.5f;

                // Examples with a false label get every feature value
                // shifted up by a constant, which is what makes the features
                // informative about the label.
                var features = new float[FeatureCount];
                for (int slot = 0; slot < features.Length; slot++)
                {
                    float sample = (float)generator.NextDouble();
                    features[slot] = isPositive ? sample : sample + 0.1f;
                }

                yield return new DataPoint
                {
                    Label = isPositive,
                    Features = features
                };
            }
        }

        /// <summary>
        /// One example: a label plus a fixed-size feature vector. A data set
        /// is a collection of such examples.
        /// </summary>
        private class DataPoint
        {
            public bool Label { get; set; }

            [VectorType(FeatureCount)]
            public float[] Features { get; set; }
        }

        /// <summary>
        /// Row shape used to read predictions back from the scored data.
        /// </summary>
        private class Prediction
        {
            /// <summary>Original label of the example.</summary>
            public bool Label { get; set; }

            /// <summary>Label predicted by the trained model.</summary>
            public bool PredictedLabel { get; set; }
        }

        /// <summary>
        /// Writes the metrics to the console, two decimals each, followed by
        /// the formatted confusion table.
        /// </summary>
        /// <param name="metrics">Metrics to print.</param>
        private static void PrintMetrics(BinaryClassificationMetrics metrics)
        {
            string[] summaryLines =
            {
                $"Accuracy: {metrics.Accuracy:F2}",
                $"AUC: {metrics.AreaUnderRocCurve:F2}",
                $"F1 Score: {metrics.F1Score:F2}",
                $"Negative Precision: " +
                    $"{metrics.NegativePrecision:F2}",
                $"Negative Recall: {metrics.NegativeRecall:F2}",
                $"Positive Precision: " +
                    $"{metrics.PositivePrecision:F2}",
                // The trailing "\n" leaves a blank line before the table.
                $"Positive Recall: {metrics.PositiveRecall:F2}\n",
            };

            foreach (string line in summaryLines)
            {
                Console.WriteLine(line);
            }

            Console.WriteLine(metrics.ConfusionMatrix.GetFormattedConfusionTable());
        }
    }
}
152+
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
<#@ include file="BinaryClassification.ttinclude"#>
2+
<#+
// Settings for the LdSvmWithOptions sample. These fields are presumably read
// by the included BinaryClassification.ttinclude -- confirm there how each
// is used.

// Class name of the sample and the trainer method it demonstrates.
string ClassName = "LdSvmWithOptions";
string Trainer = "LdSvm";

// Extra using directive needed for the options type, and optional comments.
string OptionsInclude = "using Microsoft.ML.Trainers;";
string Comments = "";

// Evaluation / data-handling switches.
bool IsCalibrated = false;
bool CacheData = false;

// Constants used when generating the synthetic data.
string LabelThreshold = "0.5f";
string DataSepValue = "0.1f";

// NOTE(review): whitespace inside the verbatim strings below is part of the
// string value -- confirm it matches the indentation wanted in the output.
string TrainerOptions = @"LdSvmTrainer.Options
{
TreeDepth = 5,
NumberOfIterations = 10000,
Sigma = 0.1f,
}";

string ExpectedOutputPerInstance = @"// Expected output:
// Label: True, Prediction: True
// Label: False, Prediction: True
// Label: True, Prediction: True
// Label: True, Prediction: True
// Label: False, Prediction: False";

string ExpectedOutput = @"// Expected output:
// Accuracy: 0.80
// AUC: 0.89
// F1 Score: 0.79
// Negative Precision: 0.81
// Negative Recall: 0.81
// Positive Precision: 0.79
// Positive Recall: 0.79

// TEST POSITIVE RATIO: 0.4760 (238.0/(238.0+262.0))
// Confusion table
// ||======================
// PREDICTED || positive | negative | Recall
// TRUTH ||======================
// positive || 189 | 49 | 0.7941
// negative || 50 | 212 | 0.8092
// ||======================
// Precision || 0.7908 | 0.8123 |";
#>

0 commit comments

Comments
 (0)