Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,21 @@ public enum EvaluateMetricType
MeanSquaredError
};

/// <summary>
/// The type of regression objective to use. Controls which LightGBM booster
/// objective the trainer configures before training.
/// </summary>
public enum RegressionObjective
{
/// <summary>
/// Standard L2 (least squares) regression. Maps to the LightGBM "regression" objective.
/// </summary>
Regression,
/// <summary>
/// Quantile regression. Maps to the LightGBM "quantile" objective; use
/// <see cref="Alpha"/> to set the target quantile.
/// </summary>
Quantile
};

/// <summary>
/// Determines what evaluation metric to use.
/// </summary>
Expand All @@ -137,6 +152,25 @@ public enum EvaluateMetricType
ShortName = "em")]
public EvaluateMetricType EvaluationMetric = EvaluateMetricType.RootMeanSquaredError;

/// <summary>
/// The regression objective type. Use <see cref="RegressionObjective.Quantile"/> with
/// <see cref="Alpha"/> for quantile regression. Defaults to
/// <see cref="RegressionObjective.Regression"/> (standard L2 regression).
/// </summary>
[Argument(ArgumentType.AtMostOnce,
HelpText = "Regression objective type. Use 'Quantile' for quantile regression.",
ShortName = "obj")]
public RegressionObjective Objective = RegressionObjective.Regression;

/// <summary>
/// The quantile to predict when <see cref="Objective"/> is <see cref="RegressionObjective.Quantile"/>.
/// Must be in the open interval (0, 1). For example, 0.05 for the 5th percentile or
/// 0.95 for the 95th percentile. The default of 0.5 targets the median.
/// Ignored when <see cref="Objective"/> is <see cref="RegressionObjective.Regression"/>.
/// </summary>
[Argument(ArgumentType.AtMostOnce,
HelpText = "The alpha (quantile) value for quantile regression. Must be in (0, 1).",
ShortName = "qa")]
public double Alpha = 0.5;

static Options()
{
NameMapping.Add(nameof(EvaluateMetricType), "metric");
Expand All @@ -145,6 +179,7 @@ static Options()
NameMapping.Add(nameof(EvaluateMetricType.MeanAbsoluteError), "mae");
NameMapping.Add(nameof(EvaluateMetricType.RootMeanSquaredError), "rmse");
NameMapping.Add(nameof(EvaluateMetricType.MeanSquaredError), "mse");
NameMapping.Add(nameof(Objective), "_regression_objective");
}

internal override Dictionary<string, object> ToDictionary(IHost host)
Expand Down Expand Up @@ -240,7 +275,19 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data)

/// <summary>
/// Translates the user-facing regression options into LightGBM booster parameters.
/// Sets the "objective" key, and for quantile regression additionally validates
/// <see cref="Options.Alpha"/> and forwards it as the "alpha" parameter.
/// </summary>
private protected override void CheckAndUpdateParametersBeforeTraining(IChannel ch, RoleMappedData data, float[] labels, int[] groups)
{
var options = (Options)LightGbmTrainerOptions;

// Plain L2 regression: nothing else to configure.
if (options.Objective != Options.RegressionObjective.Quantile)
{
GbmOptions["objective"] = "regression";
return;
}

// Quantile regression: alpha must lie strictly inside (0, 1).
Contracts.CheckUserArg(options.Alpha > 0 && options.Alpha < 1,
nameof(Options.Alpha), "Alpha for quantile regression must be in the open interval (0, 1).");
GbmOptions["objective"] = "quantile";
GbmOptions["alpha"] = options.Alpha;
}

private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
Expand Down
30 changes: 30 additions & 0 deletions test/BaselineOutput/Common/EntryPoints/core_manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -13266,6 +13266,36 @@
"IsNullable": false,
"Default": "RootMeanSquaredError"
},
{
"Name": "Objective",
"Type": {
"Kind": "Enum",
"Values": [
"Regression",
"Quantile"
]
},
"Desc": "Regression objective type. Use 'Quantile' for quantile regression.",
"Aliases": [
"obj"
],
"Required": false,
"SortOrder": 150.0,
"IsNullable": false,
"Default": "Regression"
},
{
"Name": "Alpha",
"Type": "Float",
"Desc": "The alpha (quantile) value for quantile regression. Must be in (0, 1).",
"Aliases": [
"qa"
],
"Required": false,
"SortOrder": 150.0,
"IsNullable": false,
"Default": 0.5
},
{
"Name": "MaximumBinCountPerFeature",
"Type": "Int",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13266,6 +13266,36 @@
"IsNullable": false,
"Default": "RootMeanSquaredError"
},
{
"Name": "Objective",
"Type": {
"Kind": "Enum",
"Values": [
"Regression",
"Quantile"
]
},
"Desc": "Regression objective type. Use 'Quantile' for quantile regression.",
"Aliases": [
"obj"
],
"Required": false,
"SortOrder": 150.0,
"IsNullable": false,
"Default": "Regression"
},
{
"Name": "Alpha",
"Type": "Float",
"Desc": "The alpha (quantile) value for quantile regression. Must be in (0, 1).",
"Aliases": [
"qa"
],
"Required": false,
"SortOrder": 150.0,
"IsNullable": false,
"Default": 0.5
},
{
"Name": "MaximumBinCountPerFeature",
"Type": "Int",
Expand Down
121 changes: 121 additions & 0 deletions test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,127 @@ public void LightGBMRegressorEstimator()
Done();
}

/// <summary>
/// LightGbmRegressionTrainer with quantile objective TrainerEstimator test.
/// Verifies that choosing <see cref="LightGbmRegressionTrainer.Options.RegressionObjective.Quantile"/>
/// forwards the "quantile" objective and the alpha value into the booster parameters.
/// </summary>
[LightGBMFact]
public void LightGBMQuantileRegressorEstimator()
{
var dataView = GetRegressionPipeline();

var trainer = ML.Regression.Trainers.LightGbm(new LightGbmRegressionTrainer.Options
{
Objective = LightGbmRegressionTrainer.Options.RegressionObjective.Quantile,
Alpha = 0.5,
NumberOfIterations = 10,
NumberOfLeaves = 5,
});

TestEstimatorCore(trainer, dataView);
// Fit is called for its side effect of populating the booster parameters;
// the fitted model itself is not needed for the assertions below.
trainer.Fit(dataView, dataView);

var gbmParameters = trainer.GetGbmParameters();
Assert.True(gbmParameters.ContainsKey("objective"));
Assert.Equal("quantile", gbmParameters["objective"]);
Assert.True(gbmParameters.ContainsKey("alpha"));
Assert.Equal(0.5, gbmParameters["alpha"]);

Done();
}

/// <summary>
/// Verify that quantile regression predictions with different alpha values
/// produce appropriately ordered results (lower quantile less than upper quantile).
/// </summary>
[LightGBMFact]
public void LightGBMQuantileRegressorPredictionOrdering()
{
var dataView = GetRegressionPipeline();

// Both models share every hyperparameter except the target quantile.
LightGbmRegressionTrainer.Options MakeOptions(double alpha) => new LightGbmRegressionTrainer.Options
{
Objective = LightGbmRegressionTrainer.Options.RegressionObjective.Quantile,
Alpha = alpha,
NumberOfIterations = 50,
NumberOfLeaves = 10,
Seed = 42,
Deterministic = true,
};

// One model for the 5th percentile, one for the 95th.
var modelLow = ML.Regression.Trainers.LightGbm(MakeOptions(0.05)).Fit(dataView);
var modelHigh = ML.Regression.Trainers.LightGbm(MakeOptions(0.95)).Fit(dataView);

var predictionsLow = modelLow.Transform(dataView);
var predictionsHigh = modelHigh.Transform(dataView);

var scoresLow = predictionsLow.GetColumn<float>(predictionsLow.Schema["Score"]).ToArray();
var scoresHigh = predictionsHigh.GetColumn<float>(predictionsHigh.Schema["Score"]).ToArray();

Assert.Equal(scoresLow.Length, scoresHigh.Length);
Assert.True(scoresLow.Length > 0);

// The 95th percentile predictions should generally be at least as large as the
// 5th percentile predictions. Allow a small numerical tolerance and a limited
// number of crossings since the models are trained independently.
const float tolerance = 1e-4f;
var orderedCount = 0;
for (int i = 0; i < scoresLow.Length; i++)
{
if (scoresHigh[i] + tolerance >= scoresLow[i])
orderedCount++;
}
var orderedRatio = (float)orderedCount / scoresLow.Length;

Assert.True(orderedRatio >= 0.90f,
$"Expected the 95th percentile prediction to be >= the 5th percentile prediction for most rows, " +
$"but only {orderedCount} of {scoresLow.Length} rows satisfied the condition " +
$"({orderedRatio:P2}, tolerance={tolerance}).");

Done();
}

/// <summary>
/// Verify that invalid Alpha values are rejected for quantile regression.
/// Alpha must lie strictly inside (0, 1); both boundary values and values
/// outside the interval should cause training to throw.
/// </summary>
[LightGBMFact]
public void LightGBMQuantileRegressorInvalidAlpha()
{
var dataView = GetRegressionPipeline();

// Helper: build a quantile trainer with the given alpha and expect Fit to reject it.
void AssertInvalidAlpha(double alpha)
{
var trainer = ML.Regression.Trainers.LightGbm(new LightGbmRegressionTrainer.Options
{
Objective = LightGbmRegressionTrainer.Options.RegressionObjective.Quantile,
Alpha = alpha,
});
Assert.Throws<ArgumentOutOfRangeException>(() => trainer.Fit(dataView));
}

AssertInvalidAlpha(0.0);   // lower boundary is excluded
AssertInvalidAlpha(1.0);   // upper boundary is excluded
AssertInvalidAlpha(-0.1);  // below the interval
AssertInvalidAlpha(1.1);   // above the interval

Done();
}

/// <summary>
/// RegressionGamTrainer TrainerEstimator test
Expand Down
Loading