diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs
index 46aac0f066..9c5e7ec4cd 100644
--- a/src/Microsoft.ML.FastTree/FastTree.cs
+++ b/src/Microsoft.ML.FastTree/FastTree.cs
@@ -2985,41 +2985,11 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)
         void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator)
         {
             Host.CheckValue(writer, nameof(writer));
-            Host.CheckValue(schema, nameof(schema));
-            Host.CheckValueOrNull(calibrator);
-            string ensembleIni = TrainedEnsemble.ToTreeEnsembleIni(new FeaturesToContentMap(schema),
+            var ensembleIni = FastTreeIniFileUtils.TreeEnsembleToIni(Host, TrainedEnsemble, schema, calibrator,
                 InnerArgs, appendFeatureGain: true, includeZeroGainFeatures: false);
-            ensembleIni = AddCalibrationToIni(ensembleIni, calibrator);
             writer.WriteLine(ensembleIni);
         }
 
-        /// <summary>
-        /// Get the calibration summary in INI format
-        /// </summary>
-        private string AddCalibrationToIni(string ini, ICalibrator calibrator)
-        {
-            Host.AssertValue(ini);
-            Host.AssertValueOrNull(calibrator);
-
-            if (calibrator == null)
-                return ini;
-
-            if (calibrator is PlattCalibrator)
-            {
-                string calibratorEvaluatorIni = IniFileUtils.GetCalibratorEvaluatorIni(ini, calibrator as PlattCalibrator);
-                return IniFileUtils.AddEvaluator(ini, calibratorEvaluatorIni);
-            }
-            else
-            {
-                StringBuilder newSection = new StringBuilder();
-                newSection.AppendLine();
-                newSection.AppendLine();
-                newSection.AppendLine("[TLCCalibration]");
-                newSection.AppendLine("Type=" + calibrator.GetType().Name);
-                return ini + newSection;
-            }
-        }
-
         JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input)
         {
             Host.CheckValue(ctx, nameof(ctx));
diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs
index c5d213671c..3c6edf0acc 100644
--- a/src/Microsoft.ML.FastTree/GamTrainer.cs
+++ b/src/Microsoft.ML.FastTree/GamTrainer.cs
@@ -8,6 +8,7 @@ using System.Linq;
 using System.Threading;
 using Microsoft.ML;
+using Microsoft.ML.Calibrator;
 using Microsoft.ML.Command;
 using Microsoft.ML.CommandLine;
 using Microsoft.ML.Core.Data;
@@ -647,7 +648,7 @@ public Stump(uint splitPoint, double lteValue, double gtValue)
     }
 
     public abstract class GamModelParametersBase : ModelParametersBase<float>, IValueMapper, ICalculateFeatureContribution,
-        IFeatureContributionMapper, ICanSaveInTextFormat, ICanSaveSummary
+        IFeatureContributionMapper, ICanSaveInTextFormat, ICanSaveSummary, ICanSaveInIniFormat
     {
         private readonly double[][] _binUpperBounds;
         private readonly double[][] _binEffects;
@@ -833,6 +834,7 @@ private void Map(in VBuffer<float> features, ref float response)
             double value = Intercept;
             var featuresValues = features.GetValues();
+
             if (features.IsDense)
             {
                 for (int i = 0; i < featuresValues.Length; ++i)
@@ -1028,6 +1030,114 @@ private void GetFeatureContributions(in VBuffer<float> features, ref VBuffer<float> contributions, int top, int bottom, bool normalize)
+        void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator)
+        {
+            Host.CheckValue(writer, nameof(writer));
+            Host.CheckValue(schema, nameof(schema));
+            Host.CheckValueOrNull(calibrator);
+
+            // Convert each feature's bins into a tree, then save the resulting ensemble in the INI format.
+            var ensemble = new TreeEnsemble();
+
+            for (int featureIndex = 0; featureIndex < _binUpperBounds.Length; featureIndex++)
+            {
+                var binThresholds = _binUpperBounds[featureIndex];
+                var binEffects = _binEffects[featureIndex];
+
+                var numLeaves = binEffects.Length;
+                var numInternalNodes = numLeaves - 1;
+                var splitFeatures = Enumerable.Repeat(featureIndex, numInternalNodes).ToArray();
+
+                var (rawThresholds, lteChild, gtChild) = CreateBalancedTree(numInternalNodes, binThresholds);
+                ensemble.AddTree(CreateRegressionTree(numLeaves, splitFeatures, rawThresholds, lteChild, gtChild, binEffects));
+            }
+
+            // Add the intercept as a dummy tree whose two leaves both return the intercept value.
+            ensemble.AddTree(CreateRegressionTree(
+                numLeaves: 2,
+                splitFeatures: new[] { 0 },
+                rawThresholds: new[] { float.MinValue },
+                lteChild: new[] { ~0 },
+                gtChild: new[] { ~1 },
+                leafValues: new[] { Intercept, Intercept }));
+
+            var ini = FastTreeIniFileUtils.TreeEnsembleToIni(Host, ensemble, schema, calibrator,
+                string.Empty, appendFeatureGain: false, includeZeroGainFeatures: false);
+
+            // The SplitGain values are all zero for GAM bins, so drop those lines from the INI.
+            var goodLines = ini.Split(new[] { '\n' }).Where(line => !line.StartsWith("SplitGain="));
+            ini = string.Join("\n", goodLines);
+            writer.WriteLine(ini);
+        }
+
+        // GAM bins should be converted to balanced trees / binary search trees
+        // so that scoring takes O(log(n)) instead of O(n). The following utility
+        // creates a balanced tree.
+        private (float[], int[], int[]) CreateBalancedTree(int numInternalNodes, double[] binThresholds)
+        {
+            var binIndices = Enumerable.Range(0, numInternalNodes).ToArray();
+            var internalNodeIndices = new List<int>();
+            var lteChild = new List<int>();
+            var gtChild = new List<int>();
+            var internalNodeId = numInternalNodes;
+
+            CreateBalancedTreeRecursive(
+                0, binIndices.Length - 1, internalNodeIndices, lteChild, gtChild, ref internalNodeId);
+            // internalNodeId should have been counted all the way down to 0 (root node)
+            Host.Assert(internalNodeId == 0);
+
+            var tree = (
+                thresholds: internalNodeIndices.Select(x => (float)binThresholds[binIndices[x]]).ToArray(),
+                lteChild: lteChild.ToArray(),
+                gtChild: gtChild.ToArray());
+            return tree;
+        }
+
+        private int CreateBalancedTreeRecursive(int lower, int upper,
+            List<int> internalNodeIndices, List<int> lteChild, List<int> gtChild, ref int internalNodeId)
+        {
+            if (lower > upper)
+            {
+                // Base case: we've reached a leaf node
+                Host.Assert(lower == upper + 1);
+                return ~lower;
+            }
+            else
+            {
+                // This is a postorder traversal, populating the internalNodeIndices/lteChild/gtChild lists in reverse.
+                // Postorder is the only option, because we need the results of both left/right recursions before populating the lists.
+                // As a result, the lists are populated in reverse, because the root node should be the first item on the lists.
+                // The binary search tree construction (recursively splitting the range in half) yields a balanced tree.
+                var mid = (lower + upper) / 2;
+                var left = CreateBalancedTreeRecursive(
+                    lower, mid - 1, internalNodeIndices, lteChild, gtChild, ref internalNodeId);
+                var right = CreateBalancedTreeRecursive(
+                    mid + 1, upper, internalNodeIndices, lteChild, gtChild, ref internalNodeId);
+                internalNodeIndices.Insert(0, mid);
+                lteChild.Insert(0, left);
+                gtChild.Insert(0, right);
+                return --internalNodeId;
+            }
+        }
+
+        private static RegressionTree CreateRegressionTree(
+            int numLeaves, int[] splitFeatures, float[] rawThresholds, int[] lteChild, int[] gtChild, double[] leafValues)
+        {
+            var numInternalNodes = numLeaves - 1;
+            return RegressionTree.Create(
+                numLeaves: numLeaves,
+                splitFeatures: splitFeatures,
+                rawThresholds: rawThresholds,
+                lteChild: lteChild,
+                gtChild: gtChild.ToArray(),
+                leafValues: leafValues,
+                // Ignored arguments
+                splitGain: new double[numInternalNodes],
+                defaultValueForMissing: new float[numInternalNodes],
+                categoricalSplitFeatures: new int[numInternalNodes][],
+                categoricalSplit: new bool[numInternalNodes]);
+        }
+
         /// <summary>
         /// The GAM model visualization command. Because the data access commands must access private members of
         /// <see cref="GamModelParametersBase"/>, it is convenient to have the command itself nested within the base
diff --git a/src/Microsoft.ML.FastTree/Utils/FastTreeIniFileUtils.cs b/src/Microsoft.ML.FastTree/Utils/FastTreeIniFileUtils.cs
new file mode 100644
index 0000000000..1a2a35136b
--- /dev/null
+++ b/src/Microsoft.ML.FastTree/Utils/FastTreeIniFileUtils.cs
@@ -0,0 +1,55 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Text;
+using Microsoft.ML.Calibrator;
+using Microsoft.ML.Data;
+using Microsoft.ML.Internal.Calibration;
+using Microsoft.ML.Internal.Utilities;
+
+namespace Microsoft.ML.Trainers.FastTree.Internal
+{
+    internal static class FastTreeIniFileUtils
+    {
+        public static string TreeEnsembleToIni(
+            IHost host, TreeEnsemble ensemble, RoleMappedSchema schema, ICalibrator calibrator,
+            string trainingParams, bool appendFeatureGain, bool includeZeroGainFeatures)
+        {
+            host.CheckValue(ensemble, nameof(ensemble));
+            host.CheckValue(schema, nameof(schema));
+
+            string ensembleIni = ensemble.ToTreeEnsembleIni(new FeaturesToContentMap(schema),
+                trainingParams, appendFeatureGain, includeZeroGainFeatures);
+            ensembleIni = AddCalibrationToIni(host, ensembleIni, calibrator);
+            return ensembleIni;
+        }
+
+        /// <summary>
+        /// Get the calibration summary in INI format
+        /// </summary>
+        private static string AddCalibrationToIni(IHost host, string ini, ICalibrator calibrator)
+        {
+            host.AssertValue(ini);
+            host.AssertValueOrNull(calibrator);
+
+            if (calibrator == null)
+                return ini;
+
+            if (calibrator is PlattCalibrator)
+            {
+                string calibratorEvaluatorIni = IniFileUtils.GetCalibratorEvaluatorIni(ini, calibrator as PlattCalibrator);
+                return IniFileUtils.AddEvaluator(ini, calibratorEvaluatorIni);
+            }
+            else
+            {
+                StringBuilder newSection = new StringBuilder();
+                newSection.AppendLine();
+                newSection.AppendLine();
+                newSection.AppendLine("[TLCCalibration]");
+                newSection.AppendLine("Type=" + calibrator.GetType().Name);
+                return ini + newSection;
+            }
+        }
+    }
+}
diff --git a/test/Microsoft.ML.Predictor.Tests/TestIniModels.cs b/test/Microsoft.ML.Predictor.Tests/TestIniModels.cs
index 58332208a3..edb6bd5e86 100644
--- a/test/Microsoft.ML.Predictor.Tests/TestIniModels.cs
+++ b/test/Microsoft.ML.Predictor.Tests/TestIniModels.cs
@@ -2,6 +2,19 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Threading;
+using Microsoft.ML;
+using Microsoft.ML.Data;
+using Microsoft.ML.Internal.Calibration;
+using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Internal.Internallearn;
+using Microsoft.ML.Trainers.FastTree;
+using Microsoft.ML.Tools;
+using Xunit;
+using Xunit.Abstractions;
 
 namespace Microsoft.ML.RunTests
 {
@@ -497,4 +510,95 @@ public ProcessDebugInformation RunCommandLine(string commandLine, string dir)
         }
     }
 #endif
+
+    public sealed class TestIniModels : TestDataPipeBase
+    {
+        public TestIniModels(ITestOutputHelper output) : base(output)
+        {
+        }
+
+        [Fact]
+        public void TestGamRegressionIni()
+        {
+            var mlContext = new MLContext(seed: 0);
+            var idv = mlContext.Data.CreateTextReader(
+                new TextLoader.Arguments()
+                {
+                    HasHeader = false,
+                    Column = new[]
+                    {
+                        new TextLoader.Column("Label", DataKind.R4, 0),
+                        new TextLoader.Column("Features", DataKind.R4, 1, 9)
+                    }
+                }).Read(GetDataPath("breast-cancer.txt"));
+
+            var pipeline = mlContext.Transforms.ReplaceMissingValues("Features")
+                .Append(mlContext.Regression.Trainers.GeneralizedAdditiveModels());
+            var model = pipeline.Fit(idv);
+            var data = model.Transform(idv);
+
+            var roleMappedSchema = new RoleMappedSchema(data.Schema, false,
+                new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Feature, "Features"),
+                new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Label, "Label"));
+
+            string modelIniPath = GetOutputPath(FullTestName + "-model.ini");
+            using (Stream iniStream = File.Create(modelIniPath))
+            using (StreamWriter iniWriter = Utils.OpenWriter(iniStream))
+                ((ICanSaveInIniFormat)model.LastTransformer.Model).SaveAsIni(iniWriter, roleMappedSchema);
+
+            var results = mlContext.Regression.Evaluate(data);
+
+            // Getting parity results from maml.exe:
+            // maml.exe ini ini=model.ini out=model_ini.zip data=breast-cancer.txt loader=TextLoader{col=Label:R4:0 col=Features:R4:1-9} xf=NAHandleTransform{col=Features slot=- ind=-} kind=Regression
+            Assert.Equal(0.093256807643323947, results.L1);
+            Assert.Equal(0.025707474358979077, results.L2);
+            Assert.Equal(0.16033550560926635, results.Rms);
+            Assert.Equal(0.88620288753853549, results.RSquared);
+        }
+
+        [Fact]
+        public void TestGamBinaryClassificationIni()
+        {
+            var mlContext = new MLContext(seed: 0);
+            var idv = mlContext.Data.CreateTextReader(
+                new TextLoader.Arguments()
+                {
+                    HasHeader = false,
+                    Column = new[]
+                    {
+                        new TextLoader.Column("Label", DataKind.BL, 0),
+                        new TextLoader.Column("Features", DataKind.R4, 1, 9)
+                    }
+                }).Read(GetDataPath("breast-cancer.txt"));
+
+            var pipeline = mlContext.Transforms.ReplaceMissingValues("Features")
+                .Append(mlContext.BinaryClassification.Trainers.GeneralizedAdditiveModels());
+            var model = pipeline.Fit(idv);
+            var data = model.Transform(idv);
+
+            var roleMappedSchema = new RoleMappedSchema(data.Schema, false,
+                new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Feature, "Features"),
+                new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Label, "Label"));
+
+            var calibratedPredictor = model.LastTransformer.Model as CalibratedPredictor;
+            var predictor = calibratedPredictor.SubPredictor as ICanSaveInIniFormat;
+            string modelIniPath = GetOutputPath(FullTestName + "-model.ini");
+
+            using (Stream iniStream = File.Create(modelIniPath))
+            using (StreamWriter iniWriter = Utils.OpenWriter(iniStream))
+                predictor.SaveAsIni(iniWriter, roleMappedSchema, calibratedPredictor.Calibrator);
+
+            var results = mlContext.BinaryClassification.Evaluate(data);
+
+            // Getting parity results from maml.exe:
+            // maml.exe ini ini=model.ini out=model_ini.zip data=breast-cancer.txt loader=TextLoader{col=Label:R4:0 col=Features:R4:1-9} xf=NAHandleTransform{col=Features slot=- ind=-} kind=Binary
+            Assert.Equal(0.99545199224483139, results.Auc);
+            Assert.Equal(0.96995708154506433, results.Accuracy);
+            Assert.Equal(0.95081967213114749, results.PositivePrecision);
+            Assert.Equal(0.96265560165975106, results.PositiveRecall);
+            Assert.Equal(0.95670103092783509, results.F1Score);
+            Assert.Equal(0.11594021906091197, results.LogLoss);
+        }
+    }
 }
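
For reference, the balanced-tree construction added to GamTrainer.cs above can be exercised in isolation. The sketch below is not part of the patch and is not ML.NET code; all type and member names are illustrative, and it stores the index of the bin threshold tested at each node rather than the raw threshold value. It mirrors the CreateBalancedTreeRecursive logic: a postorder recursion over the bin range, leaves encoded as the bitwise complement of the bin index, and the root placed at index 0, so a bin lookup walks O(log n) nodes instead of scanning all bins.

```csharp
// Standalone sketch of the balanced binary-search-tree layout used for GAM bins.
using System;
using System.Collections.Generic;

internal static class BalancedBinTreeSketch
{
    // Builds a balanced BST over internal nodes 0..numInternalNodes-1.
    // Returns parallel arrays: the bin threshold index tested at each node and
    // the left (<=) / right (>) children. Negative children (~binIndex) are leaves.
    public static (int[] TestBin, int[] LteChild, int[] GtChild) Build(int numInternalNodes)
    {
        var testBin = new List<int>();
        var lteChild = new List<int>();
        var gtChild = new List<int>();
        int nextNodeId = numInternalNodes;

        BuildRecursive(0, numInternalNodes - 1, testBin, lteChild, gtChild, ref nextNodeId);
        return (testBin.ToArray(), lteChild.ToArray(), gtChild.ToArray());
    }

    private static int BuildRecursive(int lower, int upper,
        List<int> testBin, List<int> lteChild, List<int> gtChild, ref int nextNodeId)
    {
        // Base case: an empty range means we fell off the tree into a leaf.
        if (lower > upper)
            return ~lower;

        // Split the range in half, recurse on both halves first (postorder),
        // then insert at the front so the root ends up at index 0.
        int mid = (lower + upper) / 2;
        int left = BuildRecursive(lower, mid - 1, testBin, lteChild, gtChild, ref nextNodeId);
        int right = BuildRecursive(mid + 1, upper, testBin, lteChild, gtChild, ref nextNodeId);

        testBin.Insert(0, mid);
        lteChild.Insert(0, left);
        gtChild.Insert(0, right);
        return --nextNodeId;
    }

    // Walks the tree for a feature value, given the per-bin upper bounds.
    public static int FindBin(double value, double[] binUpperBounds,
        int[] testBin, int[] lteChild, int[] gtChild)
    {
        if (testBin.Length == 0)
            return 0; // a single bin has no internal nodes

        int node = 0;
        while (node >= 0)
            node = value <= binUpperBounds[testBin[node]] ? lteChild[node] : gtChild[node];
        return ~node; // the leaf encodes the bin index
    }

    public static void Main()
    {
        // Five bins with these upper bounds, hence four internal nodes.
        var bounds = new[] { 1.0, 2.0, 3.0, 4.0, double.PositiveInfinity };
        var (testBin, lte, gt) = Build(numInternalNodes: bounds.Length - 1);

        foreach (var v in new[] { 0.5, 2.5, 10.0 })
            Console.WriteLine($"{v} -> bin {FindBin(v, bounds, testBin, lte, gt)}");
        // Expected output: 0.5 -> bin 0, 2.5 -> bin 2, 10 -> bin 4
    }
}
```

The same walk is what a FastTree-style regression tree performs at scoring time, which is why expressing each shape function as a balanced tree keeps INI-based scoring logarithmic in the number of bins.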