diff --git a/src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs
index d16ced2de7..62c20bcdd1 100644
--- a/src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs
+++ b/src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs
@@ -129,6 +129,21 @@ public enum EvaluateMetricType
             MeanSquaredError
         };

+        /// <summary>
+        /// The type of regression objective to use.
+        /// </summary>
+        public enum RegressionObjective
+        {
+            /// <summary>
+            /// Standard L2 (least squares) regression.
+            /// </summary>
+            Regression,
+            /// <summary>
+            /// Quantile regression. Use <see cref="Alpha"/> to set the target quantile.
+            /// </summary>
+            Quantile
+        };
+
         /// <summary>
         /// Determines what evaluation metric to use.
         /// </summary>
@@ -137,6 +152,25 @@ public enum EvaluateMetricType
             ShortName = "em")]
         public EvaluateMetricType EvaluationMetric = EvaluateMetricType.RootMeanSquaredError;

+        /// <summary>
+        /// The regression objective type. Use <see cref="RegressionObjective.Quantile"/> with <see cref="Alpha"/>
+        /// for quantile regression.
+        /// </summary>
+        [Argument(ArgumentType.AtMostOnce,
+            HelpText = "Regression objective type. Use 'Quantile' for quantile regression.",
+            ShortName = "obj")]
+        public RegressionObjective Objective = RegressionObjective.Regression;
+
+        /// <summary>
+        /// The quantile to predict when <see cref="Objective"/> is <see cref="RegressionObjective.Quantile"/>.
+        /// Must be in the open interval (0, 1). For example, 0.05 for the 5th percentile or
+        /// 0.95 for the 95th percentile.
+        /// </summary>
+        [Argument(ArgumentType.AtMostOnce,
+            HelpText = "The alpha (quantile) value for quantile regression. Must be in (0, 1).",
+            ShortName = "qa")]
+        public double Alpha = 0.5;
+
         static Options()
         {
             NameMapping.Add(nameof(EvaluateMetricType), "metric");
@@ -145,6 +179,7 @@ static Options()
             NameMapping.Add(nameof(EvaluateMetricType.MeanAbsoluteError), "mae");
             NameMapping.Add(nameof(EvaluateMetricType.RootMeanSquaredError), "rmse");
             NameMapping.Add(nameof(EvaluateMetricType.MeanSquaredError), "mse");
+            NameMapping.Add(nameof(Objective), "_regression_objective");
         }

         internal override Dictionary<string, object> ToDictionary(IHost host)
@@ -240,7 +275,19 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data)

         private protected override void CheckAndUpdateParametersBeforeTraining(IChannel ch, RoleMappedData data, float[] labels, int[] groups)
         {
-            GbmOptions["objective"] = "regression";
+            var regressionOptions = (Options)LightGbmTrainerOptions;
+
+            if (regressionOptions.Objective == Options.RegressionObjective.Quantile)
+            {
+                Contracts.CheckUserArg(regressionOptions.Alpha > 0 && regressionOptions.Alpha < 1,
+                    nameof(Options.Alpha), "Alpha for quantile regression must be in the open interval (0, 1).");
+                GbmOptions["objective"] = "quantile";
+                GbmOptions["alpha"] = regressionOptions.Alpha;
+            }
+            else
+            {
+                GbmOptions["objective"] = "regression";
+            }
         }

         private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
index 55e9439cf4..b30265cea2 100644
--- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json
+++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
@@ -13266,6 +13266,36 @@
           "IsNullable": false,
           "Default": "RootMeanSquaredError"
         },
+        {
+          "Name": "Objective",
+          "Type": {
+            "Kind": "Enum",
+            "Values": [
+              "Regression",
+              "Quantile"
+            ]
+          },
+          "Desc": "Regression objective type. Use 'Quantile' for quantile regression.",
+          "Aliases": [
+            "obj"
+          ],
+          "Required": false,
+          "SortOrder": 150.0,
+          "IsNullable": false,
+          "Default": "Regression"
+        },
+        {
+          "Name": "Alpha",
+          "Type": "Float",
+          "Desc": "The alpha (quantile) value for quantile regression. Must be in (0, 1).",
+          "Aliases": [
+            "qa"
+          ],
+          "Required": false,
+          "SortOrder": 150.0,
+          "IsNullable": false,
+          "Default": 0.5
+        },
         {
           "Name": "MaximumBinCountPerFeature",
           "Type": "Int",
diff --git a/test/BaselineOutput/Common/EntryPoints/netcoreapp/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/netcoreapp/core_manifest.json
index a2cc2b6836..88e80589b7 100644
--- a/test/BaselineOutput/Common/EntryPoints/netcoreapp/core_manifest.json
+++ b/test/BaselineOutput/Common/EntryPoints/netcoreapp/core_manifest.json
@@ -13266,6 +13266,36 @@
           "IsNullable": false,
           "Default": "RootMeanSquaredError"
         },
+        {
+          "Name": "Objective",
+          "Type": {
+            "Kind": "Enum",
+            "Values": [
+              "Regression",
+              "Quantile"
+            ]
+          },
+          "Desc": "Regression objective type. Use 'Quantile' for quantile regression.",
+          "Aliases": [
+            "obj"
+          ],
+          "Required": false,
+          "SortOrder": 150.0,
+          "IsNullable": false,
+          "Default": "Regression"
+        },
+        {
+          "Name": "Alpha",
+          "Type": "Float",
+          "Desc": "The alpha (quantile) value for quantile regression. Must be in (0, 1).",
+          "Aliases": [
+            "qa"
+          ],
+          "Required": false,
+          "SortOrder": 150.0,
+          "IsNullable": false,
+          "Default": 0.5
+        },
         {
           "Name": "MaximumBinCountPerFeature",
           "Type": "Int",
diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
index b36dfa574a..978c18afdf 100644
--- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
+++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
@@ -244,6 +244,127 @@ public void LightGBMRegressorEstimator()
             Done();
         }

+        /// <summary>
+        /// LightGbmRegressionTrainer with quantile objective TrainerEstimator test
+        /// </summary>
+        [LightGBMFact]
+        public void LightGBMQuantileRegressorEstimator()
+        {
+            var dataView = GetRegressionPipeline();
+
+            var trainer = ML.Regression.Trainers.LightGbm(new LightGbmRegressionTrainer.Options
+            {
+                Objective = LightGbmRegressionTrainer.Options.RegressionObjective.Quantile,
+                Alpha = 0.5,
+                NumberOfIterations = 10,
+                NumberOfLeaves = 5,
+            });
+
+            TestEstimatorCore(trainer, dataView);
+            var model = trainer.Fit(dataView, dataView);
+
+            var gbmParameters = trainer.GetGbmParameters();
+            Assert.True(gbmParameters.ContainsKey("objective"));
+            Assert.Equal("quantile", gbmParameters["objective"]);
+            Assert.True(gbmParameters.ContainsKey("alpha"));
+            Assert.Equal(0.5, gbmParameters["alpha"]);
+
+            Done();
+        }
+
+        /// <summary>
+        /// Verify that quantile regression predictions with different alpha values
+        /// produce appropriately ordered results (lower quantile less than upper quantile).
+        /// </summary>
+        [LightGBMFact]
+        public void LightGBMQuantileRegressorPredictionOrdering()
+        {
+            var dataView = GetRegressionPipeline();
+
+            // Train model for the 5th percentile
+            var trainerLow = ML.Regression.Trainers.LightGbm(new LightGbmRegressionTrainer.Options
+            {
+                Objective = LightGbmRegressionTrainer.Options.RegressionObjective.Quantile,
+                Alpha = 0.05,
+                NumberOfIterations = 50,
+                NumberOfLeaves = 10,
+                Seed = 42,
+                Deterministic = true,
+            });
+
+            // Train model for the 95th percentile
+            var trainerHigh = ML.Regression.Trainers.LightGbm(new LightGbmRegressionTrainer.Options
+            {
+                Objective = LightGbmRegressionTrainer.Options.RegressionObjective.Quantile,
+                Alpha = 0.95,
+                NumberOfIterations = 50,
+                NumberOfLeaves = 10,
+                Seed = 42,
+                Deterministic = true,
+            });
+
+            var modelLow = trainerLow.Fit(dataView);
+            var modelHigh = trainerHigh.Fit(dataView);
+
+            var predictionsLow = modelLow.Transform(dataView);
+            var predictionsHigh = modelHigh.Transform(dataView);
+
+            var scoresLow = predictionsLow.GetColumn<float>(predictionsLow.Schema["Score"]).ToArray();
+            var scoresHigh = predictionsHigh.GetColumn<float>(predictionsHigh.Schema["Score"]).ToArray();
+
+            Assert.Equal(scoresLow.Length, scoresHigh.Length);
+            Assert.True(scoresLow.Length > 0);
+
+            // The 95th percentile predictions should generally be at least as large as the
+            // 5th percentile predictions. Allow a small numerical tolerance and a limited
+            // number of crossings since the models are trained independently.
+            const float tolerance = 1e-4f;
+            var orderedCount = Enumerable.Range(0, scoresLow.Length)
+                .Count(i => scoresHigh[i] + tolerance >= scoresLow[i]);
+            var orderedRatio = (float)orderedCount / scoresLow.Length;
+
+            Assert.True(orderedRatio >= 0.90f,
+                $"Expected the 95th percentile prediction to be >= the 5th percentile prediction for most rows, " +
+                $"but only {orderedCount} of {scoresLow.Length} rows satisfied the condition " +
+                $"({orderedRatio:P2}, tolerance={tolerance}).");
+
+            Done();
+        }
+
+        /// <summary>
+        /// Verify that invalid Alpha values are rejected for quantile regression.
+        /// </summary>
+        [LightGBMFact]
+        public void LightGBMQuantileRegressorInvalidAlpha()
+        {
+            var dataView = GetRegressionPipeline();
+
+            // Alpha = 0 should fail
+            var trainerZero = ML.Regression.Trainers.LightGbm(new LightGbmRegressionTrainer.Options
+            {
+                Objective = LightGbmRegressionTrainer.Options.RegressionObjective.Quantile,
+                Alpha = 0.0,
+            });
+            Assert.Throws<ArgumentOutOfRangeException>(() => trainerZero.Fit(dataView));
+
+            // Alpha = 1 should fail
+            var trainerOne = ML.Regression.Trainers.LightGbm(new LightGbmRegressionTrainer.Options
+            {
+                Objective = LightGbmRegressionTrainer.Options.RegressionObjective.Quantile,
+                Alpha = 1.0,
+            });
+            Assert.Throws<ArgumentOutOfRangeException>(() => trainerOne.Fit(dataView));
+
+            // Alpha = -0.1 should fail
+            var trainerNeg = ML.Regression.Trainers.LightGbm(new LightGbmRegressionTrainer.Options
+            {
+                Objective = LightGbmRegressionTrainer.Options.RegressionObjective.Quantile,
+                Alpha = -0.1,
+            });
+            Assert.Throws<ArgumentOutOfRangeException>(() => trainerNeg.Fit(dataView));
+
+            Done();
+        }

         /// <summary>
         /// RegressionGamTrainer TrainerEstimator test