Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Light GBM sample #2493

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
87 changes: 87 additions & 0 deletions docs/samples/Microsoft.ML.Samples/Dynamic/LightGbm.cs
@@ -0,0 +1,87 @@
using System;
using Microsoft.Data.DataView;
using Microsoft.ML.Data;

namespace Microsoft.ML.Samples.Dynamic
{
/// <summary>
/// Sample that trains a LightGBM binary classifier on the Adult dataset
/// and prints its evaluation metrics to the console.
/// </summary>
public class LightBgmExample
{
    /// <summary>
    /// Loads the Adult dataset, splits it into train and test sets, fits a
    /// featurization + LightGBM pipeline on the training portion, evaluates it
    /// on the held-out portion, and writes the metrics to standard output.
    /// Blocks on <see cref="Console.ReadLine"/> before returning.
    /// </summary>
    public static void LightBgm()
    {
        var mlContext = new MLContext();

        // Obtain the classification dataset through the samples helper;
        // the helper returns the path of the file to read.
        var adultDatasetPath = SamplesUtils.DatasetUtils.DownloadAdultDataset();

        // Layout of the file (zero-based index used by the loader in parentheses):
        //  1. age            (0)  numeric
        //  2. workclass      (1)  text/categorical
        //  3. fnlwgt         (2)  numeric
        //  4. education      (3)  text/categorical
        //  5. education-num  (4)  numeric
        //  6. marital-status (5)  text/categorical
        //  7. occupation     (6)  text/categorical
        //  8. relationship   (7)  text/categorical
        //  9. ethnicity      (8)  text/categorical
        // 10. sex            (9)  text/categorical
        // 11. capital-gain   (10) numeric
        // 12. capital-loss   (11) numeric
        // 13. hours-per-week (12) numeric
        // 14. native-country (13) text/categorical
        // 15. Label          (14) boolean, IsOver50K
        TextLoader.Column[] schema =
        {
            new TextLoader.Column("age", DataKind.R4, 0),
            new TextLoader.Column("workclass", DataKind.Text, 1),
            new TextLoader.Column("fnlwgt", DataKind.R4, 2),
            new TextLoader.Column("education", DataKind.Text, 3),
            new TextLoader.Column("education-num", DataKind.R4, 4),
            new TextLoader.Column("marital-status", DataKind.Text, 5),
            new TextLoader.Column("occupation", DataKind.Text, 6),
            new TextLoader.Column("relationship", DataKind.Text, 7),
            new TextLoader.Column("ethnicity", DataKind.Text, 8),
            new TextLoader.Column("sex", DataKind.Text, 9),
            new TextLoader.Column("capital-gain", DataKind.R4, 10),
            new TextLoader.Column("capital-loss", DataKind.R4, 11),
            new TextLoader.Column("hours-per-week", DataKind.R4, 12),
            new TextLoader.Column("native-country", DataKind.Text, 13),
            new TextLoader.Column("Label", DataKind.Bool, 14)
        };

        // Comma-separated file whose first row is a header.
        var loaderArguments = new TextLoader.Arguments
        {
            Separators = new[] { ',' },
            HasHeader = true,
            Columns = schema
        };
        var loader = mlContext.Data.CreateTextLoader(loaderArguments);

        IDataView fullData = loader.Read(adultDatasetPath);

        // Hold out 20% of the rows for evaluation.
        var (trainingSet, holdoutSet) = mlContext.BinaryClassification.TrainTestSplit(fullData, testFraction: 0.2);

        // Stage 1: gather the text columns into a single "Text" column.
        // NOTE(review): "occupation" is loaded above but is not part of this list - confirm whether that is intended.
        var gatherText = mlContext.Transforms.Concatenate("Text", "workclass", "education", "marital-status",
            "relationship", "ethnicity", "sex", "native-country");

        // Stage 2: featurize the gathered text.
        var withTextFeatures = gatherText.Append(mlContext.Transforms.Text.FeaturizeText("TextFeatures", "Text"));

        // Stage 3: combine the text features with the numeric columns into "Features".
        var withAllFeatures = withTextFeatures.Append(mlContext.Transforms.Concatenate("Features", "TextFeatures", "age", "fnlwgt",
            "education-num", "capital-gain", "capital-loss", "hours-per-week"));

        // Stage 4: the LightGBM binary classification trainer with its default arguments.
        var trainingPipeline = withAllFeatures.Append(mlContext.BinaryClassification.Trainers.LightGbm());

        var trainedModel = trainingPipeline.Fit(trainingSet);

        var scoredHoldout = trainedModel.Transform(holdoutSet);

        var evaluation = mlContext.BinaryClassification.Evaluate(scoredHoldout);

        // Writes one "<name>: <value>" line per metric.
        void Report(string metricName, double metricValue) => Console.WriteLine($"{metricName}: {metricValue}");

        // The trailing numbers are the values noted by the sample's author.
        Report("Accuracy", evaluation.Accuracy); // 0.87
        Report("AUC", evaluation.Auc); // 0.92
        Report("F1 Score", evaluation.F1Score); // 0.70

        Console.WriteLine();
        Report("Negative Precision", evaluation.NegativePrecision); // 0.90
        Report("Negative Recall", evaluation.NegativeRecall); // 0.93
        Report("Positive Precision", evaluation.PositivePrecision); // 0.75
        Report("Positive Recall", evaluation.PositiveRecall); // 0.66

        // Keep the console window open until the user presses Enter.
        Console.ReadLine();
    }
}
}
7 changes: 7 additions & 0 deletions src/Microsoft.ML.LightGBM/LightGbmCatalog.cs
Expand Up @@ -24,6 +24,13 @@ public static class LightGbmExtensions
/// <param name="numBoostRound">Number of iterations.</param>
/// <param name="minDataPerLeaf">The minimal number of documents allowed in a leaf of the tree, out of the subsampled data.</param>
/// <param name="learningRate">The learning rate.</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[Light GBM](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/LightGbm.cs)]
/// ]]>
/// </format>
/// </example>
public static LightGbmRegressorTrainer LightGbm(this RegressionCatalog.RegressionTrainers catalog,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
Expand Down