diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs
index 46aac0f066..9c5e7ec4cd 100644
--- a/src/Microsoft.ML.FastTree/FastTree.cs
+++ b/src/Microsoft.ML.FastTree/FastTree.cs
@@ -2985,41 +2985,11 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)
void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator)
{
Host.CheckValue(writer, nameof(writer));
- Host.CheckValue(schema, nameof(schema));
- Host.CheckValueOrNull(calibrator);
- string ensembleIni = TrainedEnsemble.ToTreeEnsembleIni(new FeaturesToContentMap(schema),
+ var ensembleIni = FastTreeIniFileUtils.TreeEnsembleToIni(Host, TrainedEnsemble, schema, calibrator,
InnerArgs, appendFeatureGain: true, includeZeroGainFeatures: false);
- ensembleIni = AddCalibrationToIni(ensembleIni, calibrator);
writer.WriteLine(ensembleIni);
}
- /// <summary>
- /// Get the calibration summary in INI format
- /// </summary>
- private string AddCalibrationToIni(string ini, ICalibrator calibrator)
- {
- Host.AssertValue(ini);
- Host.AssertValueOrNull(calibrator);
-
- if (calibrator == null)
- return ini;
-
- if (calibrator is PlattCalibrator)
- {
- string calibratorEvaluatorIni = IniFileUtils.GetCalibratorEvaluatorIni(ini, calibrator as PlattCalibrator);
- return IniFileUtils.AddEvaluator(ini, calibratorEvaluatorIni);
- }
- else
- {
- StringBuilder newSection = new StringBuilder();
- newSection.AppendLine();
- newSection.AppendLine();
- newSection.AppendLine("[TLCCalibration]");
- newSection.AppendLine("Type=" + calibrator.GetType().Name);
- return ini + newSection;
- }
- }
-
JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input)
{
Host.CheckValue(ctx, nameof(ctx));
diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs
index c5d213671c..3c6edf0acc 100644
--- a/src/Microsoft.ML.FastTree/GamTrainer.cs
+++ b/src/Microsoft.ML.FastTree/GamTrainer.cs
@@ -8,6 +8,7 @@
using System.Linq;
using System.Threading;
using Microsoft.ML;
+using Microsoft.ML.Calibrator;
using Microsoft.ML.Command;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Core.Data;
@@ -647,7 +648,7 @@ public Stump(uint splitPoint, double lteValue, double gtValue)
}
public abstract class GamModelParametersBase : ModelParametersBase<float>, IValueMapper, ICalculateFeatureContribution,
- IFeatureContributionMapper, ICanSaveInTextFormat, ICanSaveSummary
+ IFeatureContributionMapper, ICanSaveInTextFormat, ICanSaveSummary, ICanSaveInIniFormat
{
private readonly double[][] _binUpperBounds;
private readonly double[][] _binEffects;
@@ -833,6 +834,7 @@ private void Map(in VBuffer<float> features, ref float response)
double value = Intercept;
var featuresValues = features.GetValues();
+
if (features.IsDense)
{
for (int i = 0; i < featuresValues.Length; ++i)
@@ -1028,6 +1030,114 @@ private void GetFeatureContributions(in VBuffer<float> features, ref VBuffer<float> contributions)
+ var goodLines = ini.Split('\n').Where(line => !line.StartsWith("SplitGain="));
+ ini = string.Join("\n", goodLines);
+ writer.WriteLine(ini);
+ }
+
+ // GAM bins should be converted to balanced trees / binary search trees
+ // so that scoring takes O(log(n)) instead of O(n). The following utility
+ // creates a balanced tree.
+ private (float[], int[], int[]) CreateBalancedTree(int numInternalNodes, double[] binThresholds)
+ {
+ var binIndices = Enumerable.Range(0, numInternalNodes).ToArray();
+ var internalNodeIndices = new List<int>();
+ var lteChild = new List<int>();
+ var gtChild = new List<int>();
+ var internalNodeId = numInternalNodes;
+
+ CreateBalancedTreeRecursive(
+ 0, binIndices.Length - 1, internalNodeIndices, lteChild, gtChild, ref internalNodeId);
+ // internalNodeId should have been counted all the way down to 0 (root node)
+ Host.Assert(internalNodeId == 0);
+
+ var tree = (
+ thresholds: internalNodeIndices.Select(x => (float)binThresholds[binIndices[x]]).ToArray(),
+ lteChild: lteChild.ToArray(),
+ gtChild: gtChild.ToArray());
+ return tree;
+ }
+
+ private int CreateBalancedTreeRecursive(int lower, int upper,
+ List<int> internalNodeIndices, List<int> lteChild, List<int> gtChild, ref int internalNodeId)
+ {
+ if (lower > upper)
+ {
+ // Base case: we've reached a leaf node
+ Host.Assert(lower == upper + 1);
+ return ~lower;
+ }
+ else
+ {
+ // This is a postorder traversal that populates the internalNodeIndices/lteChild/gtChild lists in reverse.
+ // Postorder is the only option, because a node's entries can only be written once the results of both
+ // left/right recursions are known. Entries are therefore inserted at the front of the lists, so the root
+ // node ends up first. The range is split in half recursively (binary-search-tree style) to keep the tree balanced.
+ var mid = (lower + upper) / 2;
+ var left = CreateBalancedTreeRecursive(
+ lower, mid - 1, internalNodeIndices, lteChild, gtChild, ref internalNodeId);
+ var right = CreateBalancedTreeRecursive(
+ mid + 1, upper, internalNodeIndices, lteChild, gtChild, ref internalNodeId);
+ internalNodeIndices.Insert(0, mid);
+ lteChild.Insert(0, left);
+ gtChild.Insert(0, right);
+ return --internalNodeId;
+ }
+ }
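+
+ // Illustrative example (not part of the model): for numInternalNodes = 3 and
+ // binThresholds = { 1.5, 2.5, 3.5 }, CreateBalancedTree returns
+ //   thresholds = { 2.5, 3.5, 1.5 }, lteChild = { 2, -3, -1 }, gtChild = { 1, -4, -2 }.
+ // The root (node 0) tests 2.5 and routes to node 2 (tests 1.5) or node 1 (tests 3.5);
+ // negative children are leaves encoded as ~leafIndex, so -1 is leaf 0, -2 is leaf 1, and so on.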
+
+ private static RegressionTree CreateRegressionTree(
+ int numLeaves, int[] splitFeatures, float[] rawThresholds, int[] lteChild, int[] gtChild, double[] leafValues)
+ {
+ var numInternalNodes = numLeaves - 1;
+ return RegressionTree.Create(
+ numLeaves: numLeaves,
+ splitFeatures: splitFeatures,
+ rawThresholds: rawThresholds,
+ lteChild: lteChild,
+ gtChild: gtChild,
+ leafValues: leafValues,
+ // Ignored arguments
+ splitGain: new double[numInternalNodes],
+ defaultValueForMissing: new float[numInternalNodes],
+ categoricalSplitFeatures: new int[numInternalNodes][],
+ categoricalSplit: new bool[numInternalNodes]);
+ }
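+
+ // Sketch of how these helpers are meant to be combined (illustrative, not the exact code above):
+ // each feature's bins become one balanced tree, e.g.
+ //   var (thresholds, lteChild, gtChild) = CreateBalancedTree(numLeaves - 1, _binUpperBounds[i]);
+ //   var tree = CreateRegressionTree(numLeaves, splitFeatures, thresholds, lteChild, gtChild, _binEffects[i]);
+ // where numLeaves is _binEffects[i].Length and splitFeatures repeats the feature index once per internal node.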
+
/// <summary>
/// The GAM model visualization command. Because the data access commands must access private members of
/// <see cref="GamModelParametersBase"/>, it is convenient to have the command itself nested within the base
diff --git a/src/Microsoft.ML.FastTree/Utils/FastTreeIniFileUtils.cs b/src/Microsoft.ML.FastTree/Utils/FastTreeIniFileUtils.cs
new file mode 100644
index 0000000000..1a2a35136b
--- /dev/null
+++ b/src/Microsoft.ML.FastTree/Utils/FastTreeIniFileUtils.cs
@@ -0,0 +1,55 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Text;
+using Microsoft.ML.Calibrator;
+using Microsoft.ML.Data;
+using Microsoft.ML.Internal.Calibration;
+using Microsoft.ML.Internal.Utilities;
+
+namespace Microsoft.ML.Trainers.FastTree.Internal
+{
+ internal static class FastTreeIniFileUtils
+ {
+ public static string TreeEnsembleToIni(
+ IHost host, TreeEnsemble ensemble, RoleMappedSchema schema, ICalibrator calibrator,
+ string trainingParams, bool appendFeatureGain, bool includeZeroGainFeatures)
+ {
+ host.CheckValue(ensemble, nameof(ensemble));
+ host.CheckValue(schema, nameof(schema));
+
+ string ensembleIni = ensemble.ToTreeEnsembleIni(new FeaturesToContentMap(schema),
+ trainingParams, appendFeatureGain, includeZeroGainFeatures);
+ ensembleIni = AddCalibrationToIni(host, ensembleIni, calibrator);
+ return ensembleIni;
+ }
+
+ /// <summary>
+ /// Get the calibration summary in INI format
+ /// </summary>
+ private static string AddCalibrationToIni(IHost host, string ini, ICalibrator calibrator)
+ {
+ host.AssertValue(ini);
+ host.AssertValueOrNull(calibrator);
+
+ if (calibrator == null)
+ return ini;
+
+ if (calibrator is PlattCalibrator)
+ {
+ string calibratorEvaluatorIni = IniFileUtils.GetCalibratorEvaluatorIni(ini, calibrator as PlattCalibrator);
+ return IniFileUtils.AddEvaluator(ini, calibratorEvaluatorIni);
+ }
+ else
+ {
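+ // Illustrative note: for a non-Platt calibrator the code below appends a minimal section of the form
+ //
+ //   [TLCCalibration]
+ //   Type=<calibrator type name>
+ //
+ // so consumers of the INI file can at least identify which calibrator was applied.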
+ StringBuilder newSection = new StringBuilder();
+ newSection.AppendLine();
+ newSection.AppendLine();
+ newSection.AppendLine("[TLCCalibration]");
+ newSection.AppendLine("Type=" + calibrator.GetType().Name);
+ return ini + newSection;
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.ML.Predictor.Tests/TestIniModels.cs b/test/Microsoft.ML.Predictor.Tests/TestIniModels.cs
index 58332208a3..edb6bd5e86 100644
--- a/test/Microsoft.ML.Predictor.Tests/TestIniModels.cs
+++ b/test/Microsoft.ML.Predictor.Tests/TestIniModels.cs
@@ -2,6 +2,19 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Threading;
+using Microsoft.ML;
+using Microsoft.ML.Data;
+using Microsoft.ML.Internal.Calibration;
+using Microsoft.ML.Internal.Internallearn;
+using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Trainers.FastTree;
+using Microsoft.ML.Tools;
+using Xunit;
+using Xunit.Abstractions;
namespace Microsoft.ML.RunTests
{
@@ -497,4 +510,95 @@ public ProcessDebugInformation RunCommandLine(string commandLine, string dir)
}
}
#endif
+
+ public sealed class TestIniModels : TestDataPipeBase
+ {
+ public TestIniModels(ITestOutputHelper output) : base(output)
+ {
+ }
+
+ [Fact]
+ public void TestGamRegressionIni()
+ {
+ var mlContext = new MLContext(seed: 0);
+ var idv = mlContext.Data.CreateTextReader(
+ new TextLoader.Arguments()
+ {
+ HasHeader = false,
+ Column = new[]
+ {
+ new TextLoader.Column("Label", DataKind.R4, 0),
+ new TextLoader.Column("Features", DataKind.R4, 1, 9)
+ }
+ }).Read(GetDataPath("breast-cancer.txt"));
+
+ var pipeline = mlContext.Transforms.ReplaceMissingValues("Features")
+ .Append(mlContext.Regression.Trainers.GeneralizedAdditiveModels());
+ var model = pipeline.Fit(idv);
+ var data = model.Transform(idv);
+
+ var roleMappedSchema = new RoleMappedSchema(data.Schema, false,
+ new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Feature, "Features"),
+ new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Label, "Label"));
+
+ string modelIniPath = GetOutputPath(FullTestName + "-model.ini");
+ using (Stream iniStream = File.Create(modelIniPath))
+ using (StreamWriter iniWriter = Utils.OpenWriter(iniStream))
+ ((ICanSaveInIniFormat)model.LastTransformer.Model).SaveAsIni(iniWriter, roleMappedSchema);
+
+ var results = mlContext.Regression.Evaluate(data);
+
+ // Getting parity results from maml.exe:
+ // maml.exe ini ini=model.ini out=model_ini.zip data=breast-cancer.txt loader=TextLoader{col=Label:R4:0 col=Features:R4:1-9} xf=NAHandleTransform{col=Features slot=- ind=-} kind=Regression
+ Assert.Equal(0.093256807643323947, results.L1);
+ Assert.Equal(0.025707474358979077, results.L2);
+ Assert.Equal(0.16033550560926635, results.Rms);
+ Assert.Equal(0.88620288753853549, results.RSquared);
+ }
+
+ [Fact]
+ public void TestGamBinaryClassificationIni()
+ {
+ var mlContext = new MLContext(seed: 0);
+ var idv = mlContext.Data.CreateTextReader(
+ new TextLoader.Arguments()
+ {
+ HasHeader = false,
+ Column = new[]
+ {
+ new TextLoader.Column("Label", DataKind.BL, 0),
+ new TextLoader.Column("Features", DataKind.R4, 1, 9)
+ }
+ }).Read(GetDataPath("breast-cancer.txt"));
+
+ var pipeline = mlContext.Transforms.ReplaceMissingValues("Features")
+ .Append(mlContext.BinaryClassification.Trainers.GeneralizedAdditiveModels());
+ var model = pipeline.Fit(idv);
+ var data = model.Transform(idv);
+
+ var roleMappedSchema = new RoleMappedSchema(data.Schema, false,
+ new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Feature, "Features"),
+ new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Label, "Label"));
+
+ var calibratedPredictor = model.LastTransformer.Model as CalibratedPredictor;
+ var predictor = calibratedPredictor.SubPredictor as ICanSaveInIniFormat;
+ string modelIniPath = GetOutputPath(FullTestName + "-model.ini");
+
+ using (Stream iniStream = File.Create(modelIniPath))
+ using (StreamWriter iniWriter = Utils.OpenWriter(iniStream))
+ predictor.SaveAsIni(iniWriter, roleMappedSchema, calibratedPredictor.Calibrator);
+
+ var results = mlContext.BinaryClassification.Evaluate(data);
+
+ // Getting parity results from maml.exe:
+ // maml.exe ini ini=model.ini out=model_ini.zip data=breast-cancer.txt loader=TextLoader{col=Label:R4:0 col=Features:R4:1-9} xf=NAHandleTransform{col=Features slot=- ind=-} kind=Binary
+ Assert.Equal(0.99545199224483139, results.Auc);
+ Assert.Equal(0.96995708154506433, results.Accuracy);
+ Assert.Equal(0.95081967213114749, results.PositivePrecision);
+ Assert.Equal(0.96265560165975106, results.PositiveRecall);
+ Assert.Equal(0.95670103092783509, results.F1Score);
+ Assert.Equal(0.11594021906091197, results.LogLoss);
+ }
+ }
+
}