-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Implement ICanSaveInIniFormat interface for GamPredictor #1929
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a3d1c9e
49250c2
e6e918f
dbe8695
0d49c9d
e22a82b
cd783c2
195a08b
344a15a
9101a32
363bd09
d94a279
e08ec6c
30e20be
81fcdec
74668ba
39eb429
37fa779
1e0e3fa
345e6b9
8022e33
2728c79
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,6 +8,7 @@ | |
| using System.Linq; | ||
| using System.Threading; | ||
| using Microsoft.ML; | ||
| using Microsoft.ML.Calibrator; | ||
| using Microsoft.ML.Command; | ||
| using Microsoft.ML.CommandLine; | ||
| using Microsoft.ML.Core.Data; | ||
|
|
@@ -647,7 +648,7 @@ public Stump(uint splitPoint, double lteValue, double gtValue) | |
| } | ||
|
|
||
| public abstract class GamModelParametersBase : ModelParametersBase<float>, IValueMapper, ICalculateFeatureContribution, | ||
| IFeatureContributionMapper, ICanSaveInTextFormat, ICanSaveSummary | ||
| IFeatureContributionMapper, ICanSaveInTextFormat, ICanSaveSummary, ICanSaveInIniFormat | ||
| { | ||
| private readonly double[][] _binUpperBounds; | ||
| private readonly double[][] _binEffects; | ||
|
|
@@ -833,6 +834,7 @@ private void Map(in VBuffer<float> features, ref float response) | |
|
|
||
| double value = Intercept; | ||
| var featuresValues = features.GetValues(); | ||
|
|
||
| if (features.IsDense) | ||
| { | ||
| for (int i = 0; i < featuresValues.Length; ++i) | ||
|
|
@@ -1028,6 +1030,114 @@ private void GetFeatureContributions(in VBuffer<float> features, ref VBuffer<flo | |
| Numeric.VectorUtils.SparsifyNormalize(ref contributions, top, bottom, normalize); | ||
| } | ||
|
|
||
| void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator) | ||
| { | ||
| Host.CheckValue(writer, nameof(writer)); | ||
| var ensemble = new TreeEnsemble(); | ||
|
|
||
| for (int featureIndex = 0; featureIndex < _numFeatures; featureIndex++) | ||
| { | ||
| var effects = _binEffects[featureIndex]; | ||
| var binThresholds = _binUpperBounds[featureIndex]; | ||
|
|
||
| Host.Assert(effects.Length == binThresholds.Length); | ||
| var numLeaves = effects.Length; | ||
| var numInternalNodes = numLeaves - 1; | ||
|
|
||
| var splitFeatures = Enumerable.Repeat(featureIndex, numInternalNodes).ToArray(); | ||
| var (treeThresholds, lteChild, gtChild) = CreateBalancedTree(numInternalNodes, binThresholds); | ||
| var tree = CreateRegressionTree(numLeaves, splitFeatures, treeThresholds, lteChild, gtChild, effects); | ||
| ensemble.AddTree(tree); | ||
| } | ||
|
|
||
| // Adding the intercept as a dummy tree with the output values being the model intercept, | ||
| // works for reaching parity. | ||
| var interceptTree = CreateRegressionTree( | ||
| numLeaves: 2, | ||
| splitFeatures: new[] { 0 }, | ||
| rawThresholds: new[] { 0f }, | ||
| lteChild: new[] { ~0 }, | ||
| gtChild: new[] { ~1 }, | ||
| leafValues: new[] { Intercept, Intercept }); | ||
| ensemble.AddTree(interceptTree); | ||
|
|
||
| var ini = FastTreeIniFileUtils.TreeEnsembleToIni( | ||
| Host, ensemble, schema, calibrator, string.Empty, false, false); | ||
|
|
||
| // Remove the SplitGain values which are all 0. | ||
| // It's eaiser to remove them here, than to modify the FastTree code. | ||
| var goodLines = ini.Split(new[] { '\n' }).Where(line => !line.StartsWith("SplitGain=")); | ||
| ini = string.Join("\n", goodLines); | ||
| writer.WriteLine(ini); | ||
| } | ||
|
|
||
| // GAM bins should be converted to balanced trees / binary search trees | ||
| // so that scoring takes O(log(n)) instead of O(n). The following utility | ||
| // creates a balanced tree. | ||
| private (float[], int[], int[]) CreateBalancedTree(int numInternalNodes, double[] binThresholds) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does returning a list limit what C# or .NET versions we can target? If so, we can switch this to an
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tuples are heavily used as return type in Static API, so using it here doesn't create more version restrictions than what we already have in the rest of the code. I think this way it fine. In reply to: 243140927 [](ancestors = 243140927) |
||
| { | ||
| var binIndices = Enumerable.Range(0, numInternalNodes).ToArray(); | ||
| var internalNodeIndices = new List<int>(); | ||
| var lteChild = new List<int>(); | ||
| var gtChild = new List<int>(); | ||
| var internalNodeId = numInternalNodes; | ||
|
|
||
| CreateBalancedTreeRecursive( | ||
| 0, binIndices.Length - 1, internalNodeIndices, lteChild, gtChild, ref internalNodeId); | ||
| // internalNodeId should have been counted all the way down to 0 (root node) | ||
| Host.Assert(internalNodeId == 0); | ||
|
|
||
| var tree = ( | ||
| thresholds: internalNodeIndices.Select(x => (float)binThresholds[binIndices[x]]).ToArray(), | ||
| lteChild: lteChild.ToArray(), | ||
| gtChild: gtChild.ToArray()); | ||
| return tree; | ||
| } | ||
|
|
||
| private int CreateBalancedTreeRecursive(int lower, int upper, | ||
| List<int> internalNodeIndices, List<int> lteChild, List<int> gtChild, ref int internalNodeId) | ||
| { | ||
| if (lower > upper) | ||
| { | ||
| // Base case: we've reached a leaf node | ||
| Host.Assert(lower == upper + 1); | ||
| return ~lower; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why negate lower? #Resolved
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The leaf values are negated in the ini format (to distinguish from the internal nodes). The base case happens when we reach leaf nodes. I didn't want to add too much comments about the ini format details, because frankly the rest of the code is the only documentation we have for that format. In reply to: 243443992 [](ancestors = 243443992) |
||
| } | ||
| else | ||
| { | ||
| // This is postorder traversal algorithm and populating the internalNodeIndices/lte/gt lists in reverse. | ||
| // Preorder is the only option, because we need the results of both left/right recursions for populating the lists. | ||
| // As a result, lists are populated in reverse, because the root node should be the first item on the lists. | ||
| // Binary search tree algorithm (recursive splitting to half) is used for creating balanced tree. | ||
| var mid = (lower + upper) / 2; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So I think this is good information - but for me its missing a key point, which is you are building a balanced tree by dividing the array in half, then taking each half and dividing that half where each midway point becomes a node in the tree.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct. I had mentioned binary-search-tree in the function above, which is exactly the key point you mentioned. I'll re-mention it here. In reply to: 243444717 [](ancestors = 243444717) |
||
| var left = CreateBalancedTreeRecursive( | ||
| lower, mid - 1, internalNodeIndices, lteChild, gtChild, ref internalNodeId); | ||
| var right = CreateBalancedTreeRecursive( | ||
| mid + 1, upper, internalNodeIndices, lteChild, gtChild, ref internalNodeId); | ||
| internalNodeIndices.Insert(0, mid); | ||
| lteChild.Insert(0, left); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks like we have one depth too many. Won't the final Right and Left children be the same? #Resolved
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. RegressionTree has internal nodes and leaf nodes. That extra depth ends up in the leaf nodes which ends the recursion. The final left and right nodes will be two leaf nodes, OR 1 leaf node and 1 internal node (which will subsequently end up with 2 leaf nodes). In reply to: 243455772 [](ancestors = 243455772) |
||
| gtChild.Insert(0, right); | ||
| return --internalNodeId; | ||
| } | ||
| } | ||
| private static RegressionTree CreateRegressionTree( | ||
| int numLeaves, int[] splitFeatures, float[] rawThresholds, int[] lteChild, int[] gtChild, double[] leafValues) | ||
| { | ||
| var numInternalNodes = numLeaves - 1; | ||
| return RegressionTree.Create( | ||
| numLeaves: numLeaves, | ||
| splitFeatures: splitFeatures, | ||
| rawThresholds: rawThresholds, | ||
| lteChild: lteChild, | ||
| gtChild: gtChild.ToArray(), | ||
| leafValues: leafValues, | ||
| // Ignored arguments | ||
| splitGain: new double[numInternalNodes], | ||
| defaultValueForMissing: new float[numInternalNodes], | ||
| categoricalSplitFeatures: new int[numInternalNodes][], | ||
| categoricalSplit: new bool[numInternalNodes]); | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// The GAM model visualization command. Because the data access commands must access private members of | ||
| /// <see cref="GamModelParametersBase"/>, it is convenient to have the command itself nested within the base | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| // Licensed to the .NET Foundation under one or more agreements. | ||
| // The .NET Foundation licenses this file to you under the MIT license. | ||
| // See the LICENSE file in the project root for more information. | ||
|
|
||
| using System.Text; | ||
| using Microsoft.ML.Calibrator; | ||
| using Microsoft.ML.Data; | ||
| using Microsoft.ML.Internal.Calibration; | ||
| using Microsoft.ML.Internal.Utilities; | ||
|
|
||
| namespace Microsoft.ML.Trainers.FastTree.Internal | ||
| { | ||
| internal static class FastTreeIniFileUtils | ||
| { | ||
| public static string TreeEnsembleToIni( | ||
| IHost host, TreeEnsemble ensemble, RoleMappedSchema schema, ICalibrator calibrator, | ||
| string trainingParams, bool appendFeatureGain, bool includeZeroGainFeatures) | ||
| { | ||
| host.CheckValue(ensemble, nameof(ensemble)); | ||
| host.CheckValue(schema, nameof(schema)); | ||
|
|
||
| string ensembleIni = ensemble.ToTreeEnsembleIni(new FeaturesToContentMap(schema), | ||
| trainingParams, appendFeatureGain, includeZeroGainFeatures); | ||
| ensembleIni = AddCalibrationToIni(host, ensembleIni, calibrator); | ||
| return ensembleIni; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Get the calibration summary in INI format | ||
| /// </summary> | ||
| private static string AddCalibrationToIni(IHost host, string ini, ICalibrator calibrator) | ||
| { | ||
| host.AssertValue(ini); | ||
| host.AssertValueOrNull(calibrator); | ||
|
|
||
| if (calibrator == null) | ||
| return ini; | ||
|
|
||
| if (calibrator is PlattCalibrator) | ||
| { | ||
| string calibratorEvaluatorIni = IniFileUtils.GetCalibratorEvaluatorIni(ini, calibrator as PlattCalibrator); | ||
| return IniFileUtils.AddEvaluator(ini, calibratorEvaluatorIni); | ||
| } | ||
| else | ||
| { | ||
| StringBuilder newSection = new StringBuilder(); | ||
| newSection.AppendLine(); | ||
| newSection.AppendLine(); | ||
| newSection.AppendLine("[TLCCalibration]"); | ||
| newSection.AppendLine("Type=" + calibrator.GetType().Name); | ||
| return ini + newSection; | ||
| } | ||
| } | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,19 @@ | |
| // The .NET Foundation licenses this file to you under the MIT license. | ||
| // See the LICENSE file in the project root for more information. | ||
|
|
||
| using System; | ||
| using System.Collections.Generic; | ||
| using System.IO; | ||
| using System.Threading; | ||
| using Microsoft.ML; | ||
| using Microsoft.ML.Data; | ||
| using Microsoft.ML.Internal.Calibration; | ||
| using Microsoft.ML.Internal.Utilities; | ||
| using Microsoft.ML.Internal.Internallearn; | ||
| using Microsoft.ML.Trainers.FastTree; | ||
| using Microsoft.ML.Tools; | ||
| using Xunit; | ||
| using Xunit.Abstractions; | ||
|
|
||
| namespace Microsoft.ML.RunTests | ||
| { | ||
|
|
@@ -497,4 +510,95 @@ public ProcessDebugInformation RunCommandLine(string commandLine, string dir) | |
| } | ||
| } | ||
| #endif | ||
|
|
||
| public sealed class TestIniModels : TestDataPipeBase | ||
| { | ||
| public TestIniModels(ITestOutputHelper output) : base(output) | ||
| { | ||
| } | ||
|
|
||
| [Fact] | ||
| public void TestGamRegressionIni() | ||
| { | ||
| var mlContext = new MLContext(seed: 0); | ||
| var idv = mlContext.Data.CreateTextReader( | ||
| new TextLoader.Arguments() | ||
| { | ||
| HasHeader = false, | ||
| Column = new[] | ||
| { | ||
| new TextLoader.Column("Label", DataKind.R4, 0), | ||
| new TextLoader.Column("Features", DataKind.R4, 1, 9) | ||
| } | ||
| }).Read(GetDataPath("breast-cancer.txt")); | ||
|
|
||
| var pipeline = mlContext.Transforms.ReplaceMissingValues("Features") | ||
| .Append(mlContext.Regression.Trainers.GeneralizedAdditiveModels()); | ||
| var model = pipeline.Fit(idv); | ||
| var data = model.Transform(idv); | ||
|
|
||
| var roleMappedSchema = new RoleMappedSchema(data.Schema, false, | ||
| new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Feature, "Features"), | ||
| new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Label, "Label")); | ||
|
|
||
| string modelIniPath = GetOutputPath(FullTestName + "-model.ini"); | ||
| using (Stream iniStream = File.Create(modelIniPath)) | ||
| using (StreamWriter iniWriter = Utils.OpenWriter(iniStream)) | ||
| ((ICanSaveInIniFormat)model.LastTransformer.Model).SaveAsIni(iniWriter, roleMappedSchema); | ||
|
|
||
| var results = mlContext.Regression.Evaluate(data); | ||
|
|
||
| // Getting parity results from maml.exe: | ||
| // maml.exe ini ini=model.ini out=model_ini.zip data=breast-cancer.txt loader=TextLoader{col=Label:R4:0 col=Features:R4:1-9} xf=NAHandleTransform{col=Features slot=- ind=-} kind=Regression | ||
| Assert.Equal(0.093256807643323947, results.L1); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure about this test. On one hand, it's good to make sure that it can produce the score. On the other hand, it's a new baseline test for GAMs. I'd say remove it from ML.NET. #ByDesign
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So, do you want me to remove both TestGamRegressionIni and TestGamBinaryClassificationIni tests? In reply to: 243141824 [](ancestors = 243141824)
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We definitely need a test that checks that the new Ini file produces the same results as the GAM. It would be ideal if you could compare the the outputs with a full file compare instead of using Assert.Equal. Is it possible to do that? In reply to: 244387146 [](ancestors = 244387146,243141824)
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it's better to keep these parity tests and also prevent regressions cause by refactoring in fasttree code. The problem with ini file comparison is the same as testing baseline comparison, which needs to extract numbers in the file and compare them with some tolerance, because verbatim comparison will fail due to changes in the higher decimal places. We already have those facilities but they're fitted for train-test / CV type tasks that are common in ML. I didn't want to modify those files for this one-off ini-file comparison. Also half of this testing is done in TLC, b/c the ini parsing code is only in TLC. The baseline metrics actually come from TLC. More importantly, the ini-format is not an open-source format and is Microsoft IP. If we want to check-in an ini file for baseline, we should have LCA clear that. For those reasons, I think this simple test is good enough for now. In reply to: 244822621 [](ancestors = 244822621,244387146,243141824) |
||
| Assert.Equal(0.025707474358979077, results.L2); | ||
| Assert.Equal(0.16033550560926635, results.Rms); | ||
| Assert.Equal(0.88620288753853549, results.RSquared); | ||
| } | ||
|
|
||
| [Fact] | ||
| public void TestGamBinaryClassificationIni() | ||
| { | ||
| var mlContext = new MLContext(seed: 0); | ||
| var idv = mlContext.Data.CreateTextReader( | ||
| new TextLoader.Arguments() | ||
| { | ||
| HasHeader = false, | ||
| Column = new[] | ||
| { | ||
| new TextLoader.Column("Label", DataKind.BL, 0), | ||
| new TextLoader.Column("Features", DataKind.R4, 1, 9) | ||
| } | ||
| }).Read(GetDataPath("breast-cancer.txt")); | ||
|
|
||
| var pipeline = mlContext.Transforms.ReplaceMissingValues("Features") | ||
| .Append(mlContext.BinaryClassification.Trainers.GeneralizedAdditiveModels()); | ||
| var model = pipeline.Fit(idv); | ||
| var data = model.Transform(idv); | ||
|
|
||
| var roleMappedSchema = new RoleMappedSchema(data.Schema, false, | ||
| new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Feature, "Features"), | ||
| new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Label, "Label")); | ||
|
|
||
| var calibratedPredictor = model.LastTransformer.Model as CalibratedPredictor; | ||
| var predictor = calibratedPredictor.SubPredictor as ICanSaveInIniFormat; | ||
| string modelIniPath = GetOutputPath(FullTestName + "-model.ini"); | ||
|
|
||
| using (Stream iniStream = File.Create(modelIniPath)) | ||
| using (StreamWriter iniWriter = Utils.OpenWriter(iniStream)) | ||
| predictor.SaveAsIni(iniWriter, roleMappedSchema, calibratedPredictor.Calibrator); | ||
|
|
||
| var results = mlContext.BinaryClassification.Evaluate(data); | ||
|
|
||
| // Getting parity results from maml.exe: | ||
| // maml.exe ini ini=model.ini out=model_ini.zip data=breast-cancer.txt loader=TextLoader{col=Label:R4:0 col=Features:R4:1-9} xf=NAHandleTransform{col=Features slot=- ind=-} kind=Binary | ||
| Assert.Equal(0.99545199224483139, results.Auc); | ||
| Assert.Equal(0.96995708154506433, results.Accuracy); | ||
| Assert.Equal(0.95081967213114749, results.PositivePrecision); | ||
| Assert.Equal(0.96265560165975106, results.PositiveRecall); | ||
| Assert.Equal(0.95670103092783509, results.F1Score); | ||
| Assert.Equal(0.11594021906091197, results.LogLoss); | ||
| } | ||
| } | ||
|
|
||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a good comment for review, but for the final check-in, just keep this line. #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed the top 2 lines of comments.
In reply to: 243141367 [](ancestors = 243141367)