Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 1 addition & 31 deletions src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2985,41 +2985,11 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)
void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator)
{
Host.CheckValue(writer, nameof(writer));
Host.CheckValue(schema, nameof(schema));
Host.CheckValueOrNull(calibrator);
string ensembleIni = TrainedEnsemble.ToTreeEnsembleIni(new FeaturesToContentMap(schema),
var ensembleIni = FastTreeIniFileUtils.TreeEnsembleToIni(Host, TrainedEnsemble, schema, calibrator,
InnerArgs, appendFeatureGain: true, includeZeroGainFeatures: false);
ensembleIni = AddCalibrationToIni(ensembleIni, calibrator);
writer.WriteLine(ensembleIni);
}

/// <summary>
/// Get the calibration summary in INI format
/// </summary>
private string AddCalibrationToIni(string ini, ICalibrator calibrator)
{
Host.AssertValue(ini);
Host.AssertValueOrNull(calibrator);

if (calibrator == null)
return ini;

if (calibrator is PlattCalibrator)
{
string calibratorEvaluatorIni = IniFileUtils.GetCalibratorEvaluatorIni(ini, calibrator as PlattCalibrator);
return IniFileUtils.AddEvaluator(ini, calibratorEvaluatorIni);
}
else
{
StringBuilder newSection = new StringBuilder();
newSection.AppendLine();
newSection.AppendLine();
newSection.AppendLine("[TLCCalibration]");
newSection.AppendLine("Type=" + calibrator.GetType().Name);
return ini + newSection;
}
}

JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input)
{
Host.CheckValue(ctx, nameof(ctx));
Expand Down
112 changes: 111 additions & 1 deletion src/Microsoft.ML.FastTree/GamTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Linq;
using System.Threading;
using Microsoft.ML;
using Microsoft.ML.Calibrator;
using Microsoft.ML.Command;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Core.Data;
Expand Down Expand Up @@ -647,7 +648,7 @@ public Stump(uint splitPoint, double lteValue, double gtValue)
}

public abstract class GamModelParametersBase : ModelParametersBase<float>, IValueMapper, ICalculateFeatureContribution,
IFeatureContributionMapper, ICanSaveInTextFormat, ICanSaveSummary
IFeatureContributionMapper, ICanSaveInTextFormat, ICanSaveSummary, ICanSaveInIniFormat
{
private readonly double[][] _binUpperBounds;
private readonly double[][] _binEffects;
Expand Down Expand Up @@ -833,6 +834,7 @@ private void Map(in VBuffer<float> features, ref float response)

double value = Intercept;
var featuresValues = features.GetValues();

if (features.IsDense)
{
for (int i = 0; i < featuresValues.Length; ++i)
Expand Down Expand Up @@ -1028,6 +1030,114 @@ private void GetFeatureContributions(in VBuffer<float> features, ref VBuffer<flo
Numeric.VectorUtils.SparsifyNormalize(ref contributions, top, bottom, normalize);
}

void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator)
{
Host.CheckValue(writer, nameof(writer));
var ensemble = new TreeEnsemble();

for (int featureIndex = 0; featureIndex < _numFeatures; featureIndex++)
{
var effects = _binEffects[featureIndex];
var binThresholds = _binUpperBounds[featureIndex];

Host.Assert(effects.Length == binThresholds.Length);
var numLeaves = effects.Length;
var numInternalNodes = numLeaves - 1;

var splitFeatures = Enumerable.Repeat(featureIndex, numInternalNodes).ToArray();
var (treeThresholds, lteChild, gtChild) = CreateBalancedTree(numInternalNodes, binThresholds);
var tree = CreateRegressionTree(numLeaves, splitFeatures, treeThresholds, lteChild, gtChild, effects);
ensemble.AddTree(tree);
}

// Adding the intercept as a dummy tree with the output values being the model intercept,
Copy link
Contributor

@rogancarr rogancarr Dec 20, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good comment for review, but for the final check-in, just keep this line. #Resolved

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the top 2 lines of comments.


In reply to: 243141367 [](ancestors = 243141367)

// works for reaching parity.
var interceptTree = CreateRegressionTree(
numLeaves: 2,
splitFeatures: new[] { 0 },
rawThresholds: new[] { 0f },
lteChild: new[] { ~0 },
gtChild: new[] { ~1 },
leafValues: new[] { Intercept, Intercept });
ensemble.AddTree(interceptTree);

var ini = FastTreeIniFileUtils.TreeEnsembleToIni(
Host, ensemble, schema, calibrator, string.Empty, false, false);

// Remove the SplitGain values which are all 0.
// It's eaiser to remove them here, than to modify the FastTree code.
var goodLines = ini.Split(new[] { '\n' }).Where(line => !line.StartsWith("SplitGain="));
ini = string.Join("\n", goodLines);
writer.WriteLine(ini);
}

// GAM bins should be converted to balanced trees / binary search trees
// so that scoring takes O(log(n)) instead of O(n). The following utility
// creates a balanced tree.
private (float[], int[], int[]) CreateBalancedTree(int numInternalNodes, double[] binThresholds)
Copy link
Contributor

@rogancarr rogancarr Dec 20, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does returning a list limit what C# or .NET versions we can target? If so, we can switch this to an out pattern. #ByDesign

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tuples are heavily used as return type in Static API, so using it here doesn't create more version restrictions than what we already have in the rest of the code. I think this way it fine.


In reply to: 243140927 [](ancestors = 243140927)

{
var binIndices = Enumerable.Range(0, numInternalNodes).ToArray();
var internalNodeIndices = new List<int>();
var lteChild = new List<int>();
var gtChild = new List<int>();
var internalNodeId = numInternalNodes;

CreateBalancedTreeRecursive(
0, binIndices.Length - 1, internalNodeIndices, lteChild, gtChild, ref internalNodeId);
// internalNodeId should have been counted all the way down to 0 (root node)
Host.Assert(internalNodeId == 0);

var tree = (
thresholds: internalNodeIndices.Select(x => (float)binThresholds[binIndices[x]]).ToArray(),
lteChild: lteChild.ToArray(),
gtChild: gtChild.ToArray());
return tree;
}

private int CreateBalancedTreeRecursive(int lower, int upper,
List<int> internalNodeIndices, List<int> lteChild, List<int> gtChild, ref int internalNodeId)
{
if (lower > upper)
{
// Base case: we've reached a leaf node
Host.Assert(lower == upper + 1);
return ~lower;
Copy link
Member

@singlis singlis Dec 20, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why negate lower? #Resolved

Copy link
Author

@shmoradims shmoradims Dec 28, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The leaf values are negated in the ini format (to distinguish from the internal nodes). The base case happens when we reach leaf nodes.

I didn't want to add too much comments about the ini format details, because frankly the rest of the code is the only documentation we have for that format.


In reply to: 243443992 [](ancestors = 243443992)

}
else
{
// This is postorder traversal algorithm and populating the internalNodeIndices/lte/gt lists in reverse.
// Preorder is the only option, because we need the results of both left/right recursions for populating the lists.
// As a result, lists are populated in reverse, because the root node should be the first item on the lists.
// Binary search tree algorithm (recursive splitting to half) is used for creating balanced tree.
var mid = (lower + upper) / 2;
Copy link
Member

@singlis singlis Dec 20, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I think this is good information - but for me its missing a key point, which is you are building a balanced tree by dividing the array in half, then taking each half and dividing that half where each midway point becomes a node in the tree.
#Resolved

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. I had mentioned binary-search-tree in the function above, which is exactly the key point you mentioned. I'll re-mention it here.


In reply to: 243444717 [](ancestors = 243444717)

var left = CreateBalancedTreeRecursive(
lower, mid - 1, internalNodeIndices, lteChild, gtChild, ref internalNodeId);
var right = CreateBalancedTreeRecursive(
mid + 1, upper, internalNodeIndices, lteChild, gtChild, ref internalNodeId);
internalNodeIndices.Insert(0, mid);
lteChild.Insert(0, left);
Copy link
Contributor

@rogancarr rogancarr Dec 21, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like we have one depth too many. Won't the final Right and Left children be the same? #Resolved

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RegressionTree has internal nodes and leaf nodes. That extra depth ends up in the leaf nodes which ends the recursion. The final left and right nodes will be two leaf nodes, OR 1 leaf node and 1 internal node (which will subsequently end up with 2 leaf nodes).


In reply to: 243455772 [](ancestors = 243455772)

gtChild.Insert(0, right);
return --internalNodeId;
}
}
private static RegressionTree CreateRegressionTree(
int numLeaves, int[] splitFeatures, float[] rawThresholds, int[] lteChild, int[] gtChild, double[] leafValues)
{
var numInternalNodes = numLeaves - 1;
return RegressionTree.Create(
numLeaves: numLeaves,
splitFeatures: splitFeatures,
rawThresholds: rawThresholds,
lteChild: lteChild,
gtChild: gtChild.ToArray(),
leafValues: leafValues,
// Ignored arguments
splitGain: new double[numInternalNodes],
defaultValueForMissing: new float[numInternalNodes],
categoricalSplitFeatures: new int[numInternalNodes][],
categoricalSplit: new bool[numInternalNodes]);
}

/// <summary>
/// The GAM model visualization command. Because the data access commands must access private members of
/// <see cref="GamModelParametersBase"/>, it is convenient to have the command itself nested within the base
Expand Down
55 changes: 55 additions & 0 deletions src/Microsoft.ML.FastTree/Utils/FastTreeIniFileUtils.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Text;
using Microsoft.ML.Calibrator;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Calibration;
using Microsoft.ML.Internal.Utilities;

namespace Microsoft.ML.Trainers.FastTree.Internal
{
internal static class FastTreeIniFileUtils
{
public static string TreeEnsembleToIni(
IHost host, TreeEnsemble ensemble, RoleMappedSchema schema, ICalibrator calibrator,
string trainingParams, bool appendFeatureGain, bool includeZeroGainFeatures)
{
host.CheckValue(ensemble, nameof(ensemble));
host.CheckValue(schema, nameof(schema));

string ensembleIni = ensemble.ToTreeEnsembleIni(new FeaturesToContentMap(schema),
trainingParams, appendFeatureGain, includeZeroGainFeatures);
ensembleIni = AddCalibrationToIni(host, ensembleIni, calibrator);
return ensembleIni;
}

/// <summary>
/// Get the calibration summary in INI format
/// </summary>
private static string AddCalibrationToIni(IHost host, string ini, ICalibrator calibrator)
{
host.AssertValue(ini);
host.AssertValueOrNull(calibrator);

if (calibrator == null)
return ini;

if (calibrator is PlattCalibrator)
{
string calibratorEvaluatorIni = IniFileUtils.GetCalibratorEvaluatorIni(ini, calibrator as PlattCalibrator);
return IniFileUtils.AddEvaluator(ini, calibratorEvaluatorIni);
}
else
{
StringBuilder newSection = new StringBuilder();
newSection.AppendLine();
newSection.AppendLine();
newSection.AppendLine("[TLCCalibration]");
newSection.AppendLine("Type=" + calibrator.GetType().Name);
return ini + newSection;
}
}
}
}
104 changes: 104 additions & 0 deletions test/Microsoft.ML.Predictor.Tests/TestIniModels.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,19 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Calibration;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Trainers.FastTree;
using Microsoft.ML.Tools;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.ML.RunTests
{
Expand Down Expand Up @@ -497,4 +510,95 @@ public ProcessDebugInformation RunCommandLine(string commandLine, string dir)
}
}
#endif

public sealed class TestIniModels : TestDataPipeBase
{
public TestIniModels(ITestOutputHelper output) : base(output)
{
}

[Fact]
public void TestGamRegressionIni()
{
var mlContext = new MLContext(seed: 0);
var idv = mlContext.Data.CreateTextReader(
new TextLoader.Arguments()
{
HasHeader = false,
Column = new[]
{
new TextLoader.Column("Label", DataKind.R4, 0),
new TextLoader.Column("Features", DataKind.R4, 1, 9)
}
}).Read(GetDataPath("breast-cancer.txt"));

var pipeline = mlContext.Transforms.ReplaceMissingValues("Features")
.Append(mlContext.Regression.Trainers.GeneralizedAdditiveModels());
var model = pipeline.Fit(idv);
var data = model.Transform(idv);

var roleMappedSchema = new RoleMappedSchema(data.Schema, false,
new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Feature, "Features"),
new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Label, "Label"));

string modelIniPath = GetOutputPath(FullTestName + "-model.ini");
using (Stream iniStream = File.Create(modelIniPath))
using (StreamWriter iniWriter = Utils.OpenWriter(iniStream))
((ICanSaveInIniFormat)model.LastTransformer.Model).SaveAsIni(iniWriter, roleMappedSchema);

var results = mlContext.Regression.Evaluate(data);

// Getting parity results from maml.exe:
// maml.exe ini ini=model.ini out=model_ini.zip data=breast-cancer.txt loader=TextLoader{col=Label:R4:0 col=Features:R4:1-9} xf=NAHandleTransform{col=Features slot=- ind=-} kind=Regression
Assert.Equal(0.093256807643323947, results.L1);
Copy link
Contributor

@rogancarr rogancarr Dec 20, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about this test. On one hand, it's good to make sure that it can produce the score. On the other hand, it's a new baseline test for GAMs. I'd say remove it from ML.NET. #ByDesign

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, do you want me to remove both TestGamRegressionIni and TestGamBinaryClassificationIni tests?


In reply to: 243141824 [](ancestors = 243141824)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We definitely need a test that checks that the new Ini file produces the same results as the GAM. It would be ideal if you could compare the the outputs with a full file compare instead of using Assert.Equal. Is it possible to do that?


In reply to: 244387146 [](ancestors = 244387146,243141824)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's better to keep these parity tests and also prevent regressions cause by refactoring in fasttree code.

The problem with ini file comparison is the same as testing baseline comparison, which needs to extract numbers in the file and compare them with some tolerance, because verbatim comparison will fail due to changes in the higher decimal places. We already have those facilities but they're fitted for train-test / CV type tasks that are common in ML. I didn't want to modify those files for this one-off ini-file comparison. Also half of this testing is done in TLC, b/c the ini parsing code is only in TLC. The baseline metrics actually come from TLC.

More importantly, the ini-format is not an open-source format and is Microsoft IP. If we want to check-in an ini file for baseline, we should have LCA clear that.

For those reasons, I think this simple test is good enough for now.


In reply to: 244822621 [](ancestors = 244822621,244387146,243141824)

Assert.Equal(0.025707474358979077, results.L2);
Assert.Equal(0.16033550560926635, results.Rms);
Assert.Equal(0.88620288753853549, results.RSquared);
}

[Fact]
public void TestGamBinaryClassificationIni()
{
var mlContext = new MLContext(seed: 0);
var idv = mlContext.Data.CreateTextReader(
new TextLoader.Arguments()
{
HasHeader = false,
Column = new[]
{
new TextLoader.Column("Label", DataKind.BL, 0),
new TextLoader.Column("Features", DataKind.R4, 1, 9)
}
}).Read(GetDataPath("breast-cancer.txt"));

var pipeline = mlContext.Transforms.ReplaceMissingValues("Features")
.Append(mlContext.BinaryClassification.Trainers.GeneralizedAdditiveModels());
var model = pipeline.Fit(idv);
var data = model.Transform(idv);

var roleMappedSchema = new RoleMappedSchema(data.Schema, false,
new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Feature, "Features"),
new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Label, "Label"));

var calibratedPredictor = model.LastTransformer.Model as CalibratedPredictor;
var predictor = calibratedPredictor.SubPredictor as ICanSaveInIniFormat;
string modelIniPath = GetOutputPath(FullTestName + "-model.ini");

using (Stream iniStream = File.Create(modelIniPath))
using (StreamWriter iniWriter = Utils.OpenWriter(iniStream))
predictor.SaveAsIni(iniWriter, roleMappedSchema, calibratedPredictor.Calibrator);

var results = mlContext.BinaryClassification.Evaluate(data);

// Getting parity results from maml.exe:
// maml.exe ini ini=model.ini out=model_ini.zip data=breast-cancer.txt loader=TextLoader{col=Label:R4:0 col=Features:R4:1-9} xf=NAHandleTransform{col=Features slot=- ind=-} kind=Binary
Assert.Equal(0.99545199224483139, results.Auc);
Assert.Equal(0.96995708154506433, results.Accuracy);
Assert.Equal(0.95081967213114749, results.PositivePrecision);
Assert.Equal(0.96265560165975106, results.PositiveRecall);
Assert.Equal(0.95670103092783509, results.F1Score);
Assert.Equal(0.11594021906091197, results.LogLoss);
}
}

}